summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--models/issues/comment_list.go12
-rw-r--r--models/issues/comment_list_test.go86
-rw-r--r--models/issues/issue_list.go18
-rw-r--r--models/issues/issue_list_test.go49
-rw-r--r--models/user/user.go14
-rw-r--r--models/user/user_test.go33
6 files changed, 188 insertions, 24 deletions
diff --git a/models/issues/comment_list.go b/models/issues/comment_list.go
index 61ac1c8f56..7a133d1c16 100644
--- a/models/issues/comment_list.go
+++ b/models/issues/comment_list.go
@@ -23,7 +23,7 @@ func (comments CommentList) LoadPosters(ctx context.Context) error {
}
posterIDs := container.FilterSlice(comments, func(c *Comment) (int64, bool) {
- return c.PosterID, c.Poster == nil && c.PosterID > 0
+ return c.PosterID, c.Poster == nil && user_model.IsValidUserID(c.PosterID)
})
posterMaps, err := getPostersByIDs(ctx, posterIDs)
@@ -33,7 +33,7 @@ func (comments CommentList) LoadPosters(ctx context.Context) error {
for _, comment := range comments {
if comment.Poster == nil {
- comment.Poster = getPoster(comment.PosterID, posterMaps)
+ comment.PosterID, comment.Poster = user_model.GetUserFromMap(comment.PosterID, posterMaps)
}
}
return nil
@@ -165,7 +165,7 @@ func (comments CommentList) loadOldMilestones(ctx context.Context) error {
func (comments CommentList) getAssigneeIDs() []int64 {
return container.FilterSlice(comments, func(comment *Comment) (int64, bool) {
- return comment.AssigneeID, comment.AssigneeID > 0
+ return comment.AssigneeID, user_model.IsValidUserID(comment.AssigneeID)
})
}
@@ -206,11 +206,7 @@ func (comments CommentList) loadAssignees(ctx context.Context) error {
}
for _, comment := range comments {
- comment.Assignee = assignees[comment.AssigneeID]
- if comment.Assignee == nil {
- comment.AssigneeID = user_model.GhostUserID
- comment.Assignee = user_model.NewGhostUser()
- }
+ comment.AssigneeID, comment.Assignee = user_model.GetUserFromMap(comment.AssigneeID, assignees)
}
return nil
}
diff --git a/models/issues/comment_list_test.go b/models/issues/comment_list_test.go
new file mode 100644
index 0000000000..66037d7358
--- /dev/null
+++ b/models/issues/comment_list_test.go
@@ -0,0 +1,86 @@
+// Copyright 2024 The Forgejo Authors
+// SPDX-License-Identifier: MIT
+
+package issues
+
+import (
+ "testing"
+
+ "code.gitea.io/gitea/models/db"
+ repo_model "code.gitea.io/gitea/models/repo"
+ "code.gitea.io/gitea/models/unittest"
+ user_model "code.gitea.io/gitea/models/user"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestCommentListLoadUser(t *testing.T) {
+ assert.NoError(t, unittest.PrepareTestDatabase())
+
+ issue := unittest.AssertExistsAndLoadBean(t, &Issue{})
+ repo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: issue.RepoID})
+ doer := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: repo.OwnerID})
+
+ for _, testCase := range []struct {
+ poster int64
+ assignee int64
+ user *user_model.User
+ }{
+ {
+ poster: user_model.ActionsUserID,
+ assignee: user_model.ActionsUserID,
+ user: user_model.NewActionsUser(),
+ },
+ {
+ poster: user_model.GhostUserID,
+ assignee: user_model.GhostUserID,
+ user: user_model.NewGhostUser(),
+ },
+ {
+ poster: doer.ID,
+ assignee: doer.ID,
+ user: doer,
+ },
+ {
+ poster: 0,
+ assignee: 0,
+ user: user_model.NewGhostUser(),
+ },
+ {
+ poster: -200,
+ assignee: -200,
+ user: user_model.NewGhostUser(),
+ },
+ {
+ poster: 200,
+ assignee: 200,
+ user: user_model.NewGhostUser(),
+ },
+ } {
+ t.Run(testCase.user.Name, func(t *testing.T) {
+ comment, err := CreateComment(db.DefaultContext, &CreateCommentOptions{
+ Type: CommentTypeComment,
+ Doer: testCase.user,
+ Repo: repo,
+ Issue: issue,
+ Content: "Hello",
+ })
+ assert.NoError(t, err)
+
+ list := CommentList{comment}
+
+ comment.PosterID = testCase.poster
+ comment.Poster = nil
+ assert.NoError(t, list.LoadPosters(db.DefaultContext))
+ require.NotNil(t, comment.Poster)
+ assert.Equal(t, testCase.user.ID, comment.Poster.ID)
+
+ comment.AssigneeID = testCase.assignee
+ comment.Assignee = nil
+ assert.NoError(t, list.loadAssignees(db.DefaultContext))
+ require.NotNil(t, comment.Assignee)
+ assert.Equal(t, testCase.user.ID, comment.Assignee.ID)
+ })
+ }
+}
diff --git a/models/issues/issue_list.go b/models/issues/issue_list.go
index fbfa7584a0..fe6c630a31 100644
--- a/models/issues/issue_list.go
+++ b/models/issues/issue_list.go
@@ -79,7 +79,7 @@ func (issues IssueList) LoadPosters(ctx context.Context) error {
}
posterIDs := container.FilterSlice(issues, func(issue *Issue) (int64, bool) {
- return issue.PosterID, issue.Poster == nil && issue.PosterID > 0
+ return issue.PosterID, issue.Poster == nil && user_model.IsValidUserID(issue.PosterID)
})
posterMaps, err := getPostersByIDs(ctx, posterIDs)
@@ -89,7 +89,7 @@ func (issues IssueList) LoadPosters(ctx context.Context) error {
for _, issue := range issues {
if issue.Poster == nil {
- issue.Poster = getPoster(issue.PosterID, posterMaps)
+ issue.PosterID, issue.Poster = user_model.GetUserFromMap(issue.PosterID, posterMaps)
}
}
return nil
@@ -115,20 +115,6 @@ func getPostersByIDs(ctx context.Context, posterIDs []int64) (map[int64]*user_mo
return posterMaps, nil
}
-func getPoster(posterID int64, posterMaps map[int64]*user_model.User) *user_model.User {
- if posterID == user_model.ActionsUserID {
- return user_model.NewActionsUser()
- }
- if posterID <= 0 {
- return nil
- }
- poster, ok := posterMaps[posterID]
- if !ok {
- return user_model.NewGhostUser()
- }
- return poster
-}
-
func (issues IssueList) getIssueIDs() []int64 {
ids := make([]int64, 0, len(issues))
for _, issue := range issues {
diff --git a/models/issues/issue_list_test.go b/models/issues/issue_list_test.go
index 10ba38a64b..50bbd5c667 100644
--- a/models/issues/issue_list_test.go
+++ b/models/issues/issue_list_test.go
@@ -9,9 +9,11 @@ import (
"code.gitea.io/gitea/models/db"
issues_model "code.gitea.io/gitea/models/issues"
"code.gitea.io/gitea/models/unittest"
+ user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/setting"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
)
func TestIssueList_LoadRepositories(t *testing.T) {
@@ -78,3 +80,50 @@ func TestIssueList_LoadAttributes(t *testing.T) {
assert.Equal(t, issue.ID == 1, issue.IsRead, "unexpected is_read value for issue[%d]", issue.ID)
}
}
+
+func TestIssueListLoadUser(t *testing.T) {
+ assert.NoError(t, unittest.PrepareTestDatabase())
+
+ issue := unittest.AssertExistsAndLoadBean(t, &issues_model.Issue{})
+ doer := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 1})
+
+ for _, testCase := range []struct {
+ poster int64
+ user *user_model.User
+ }{
+ {
+ poster: user_model.ActionsUserID,
+ user: user_model.NewActionsUser(),
+ },
+ {
+ poster: user_model.GhostUserID,
+ user: user_model.NewGhostUser(),
+ },
+ {
+ poster: doer.ID,
+ user: doer,
+ },
+ {
+ poster: 0,
+ user: user_model.NewGhostUser(),
+ },
+ {
+ poster: -200,
+ user: user_model.NewGhostUser(),
+ },
+ {
+ poster: 200,
+ user: user_model.NewGhostUser(),
+ },
+ } {
+ t.Run(testCase.user.Name, func(t *testing.T) {
+ list := issues_model.IssueList{issue}
+
+ issue.PosterID = testCase.poster
+ issue.Poster = nil
+ assert.NoError(t, list.LoadPosters(db.DefaultContext))
+ require.NotNil(t, issue.Poster)
+ assert.Equal(t, testCase.user.ID, issue.Poster.ID)
+ })
+ }
+}
diff --git a/models/user/user.go b/models/user/user.go
index 4dae957485..56a2bc38ae 100644
--- a/models/user/user.go
+++ b/models/user/user.go
@@ -906,6 +906,20 @@ func GetUserByIDs(ctx context.Context, ids []int64) ([]*User, error) {
return users, err
}
+func IsValidUserID(id int64) bool {
+ return id > 0 || id == GhostUserID || id == ActionsUserID
+}
+
+func GetUserFromMap(id int64, idMap map[int64]*User) (int64, *User) {
+ if user, ok := idMap[id]; ok {
+ return id, user
+ }
+ if id == ActionsUserID {
+ return ActionsUserID, NewActionsUser()
+ }
+ return GhostUserID, NewGhostUser()
+}
+
// GetPossibleUserByID returns the user if id > 0 or return system usrs if id < 0
func GetPossibleUserByID(ctx context.Context, id int64) (*User, error) {
switch id {
diff --git a/models/user/user_test.go b/models/user/user_test.go
index 7457256017..6d688a694b 100644
--- a/models/user/user_test.go
+++ b/models/user/user_test.go
@@ -35,6 +35,39 @@ func TestOAuth2Application_LoadUser(t *testing.T) {
assert.NotNil(t, user)
}
+func TestIsValidUserID(t *testing.T) {
+ assert.False(t, user_model.IsValidUserID(-30))
+ assert.False(t, user_model.IsValidUserID(0))
+ assert.True(t, user_model.IsValidUserID(user_model.GhostUserID))
+ assert.True(t, user_model.IsValidUserID(user_model.ActionsUserID))
+ assert.True(t, user_model.IsValidUserID(200))
+}
+
+func TestGetUserFromMap(t *testing.T) {
+ id := int64(200)
+ idMap := map[int64]*user_model.User{
+ id: {ID: id},
+ }
+
+ ghostID := int64(user_model.GhostUserID)
+ actionsID := int64(user_model.ActionsUserID)
+ actualID, actualUser := user_model.GetUserFromMap(-20, idMap)
+ assert.Equal(t, ghostID, actualID)
+ assert.Equal(t, ghostID, actualUser.ID)
+
+ actualID, actualUser = user_model.GetUserFromMap(0, idMap)
+ assert.Equal(t, ghostID, actualID)
+ assert.Equal(t, ghostID, actualUser.ID)
+
+ actualID, actualUser = user_model.GetUserFromMap(ghostID, idMap)
+ assert.Equal(t, ghostID, actualID)
+ assert.Equal(t, ghostID, actualUser.ID)
+
+ actualID, actualUser = user_model.GetUserFromMap(actionsID, idMap)
+ assert.Equal(t, actionsID, actualID)
+ assert.Equal(t, actionsID, actualUser.ID)
+}
+
func TestGetUserByName(t *testing.T) {
defer tests.AddFixtures("models/user/fixtures/")()
assert.NoError(t, unittest.PrepareTestDatabase())