diff options
author | JakobDev <jakobdev@gmx.de> | 2023-09-16 16:39:12 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-16 16:39:12 +0200 |
commit | f91dbbba98c841f11d99be998ed5dd98122a457c (patch) | |
tree | 9c6c935ccf745c5a1716f1330922354809cd39e0 | |
parent | Ui correction in mobile view nav bar left aligned items. (#27046) (diff) | |
download | forgejo-f91dbbba98c841f11d99be998ed5dd98122a457c.tar.xz forgejo-f91dbbba98c841f11d99be998ed5dd98122a457c.zip |
Next round of `db.DefaultContext` refactor (#27089)
Part of #27065
90 files changed, 434 insertions, 464 deletions
diff --git a/models/actions/schedule.go b/models/actions/schedule.go index b0bc40dadc..34d23f1c01 100644 --- a/models/actions/schedule.go +++ b/models/actions/schedule.go @@ -41,15 +41,15 @@ func init() { } // GetSchedulesMapByIDs returns the schedules by given id slice. -func GetSchedulesMapByIDs(ids []int64) (map[int64]*ActionSchedule, error) { +func GetSchedulesMapByIDs(ctx context.Context, ids []int64) (map[int64]*ActionSchedule, error) { schedules := make(map[int64]*ActionSchedule, len(ids)) - return schedules, db.GetEngine(db.DefaultContext).In("id", ids).Find(&schedules) + return schedules, db.GetEngine(ctx).In("id", ids).Find(&schedules) } // GetReposMapByIDs returns the repos by given id slice. -func GetReposMapByIDs(ids []int64) (map[int64]*repo_model.Repository, error) { +func GetReposMapByIDs(ctx context.Context, ids []int64) (map[int64]*repo_model.Repository, error) { repos := make(map[int64]*repo_model.Repository, len(ids)) - return repos, db.GetEngine(db.DefaultContext).In("id", ids).Find(&repos) + return repos, db.GetEngine(ctx).In("id", ids).Find(&repos) } var cronParser = cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor) diff --git a/models/actions/schedule_spec_list.go b/models/actions/schedule_spec_list.go index d379490b4e..2c017fdabc 100644 --- a/models/actions/schedule_spec_list.go +++ b/models/actions/schedule_spec_list.go @@ -23,9 +23,9 @@ func (specs SpecList) GetScheduleIDs() []int64 { return ids.Values() } -func (specs SpecList) LoadSchedules() error { +func (specs SpecList) LoadSchedules(ctx context.Context) error { scheduleIDs := specs.GetScheduleIDs() - schedules, err := GetSchedulesMapByIDs(scheduleIDs) + schedules, err := GetSchedulesMapByIDs(ctx, scheduleIDs) if err != nil { return err } @@ -34,7 +34,7 @@ func (specs SpecList) LoadSchedules() error { } repoIDs := specs.GetRepoIDs() - repos, err := GetReposMapByIDs(repoIDs) + repos, err := GetReposMapByIDs(ctx, repoIDs) if err != nil { return err } @@ -95,7 +95,7 @@ func FindSpecs(ctx context.Context, opts FindSpecOptions) (SpecList, int64, erro return nil, 0, err } - if err := specs.LoadSchedules(); err != nil { + if err := specs.LoadSchedules(ctx); err != nil { return nil, 0, err } return specs, total, nil diff --git a/models/admin/task.go b/models/admin/task.go index 8aa397ad35..c8bc95f981 100644 --- a/models/admin/task.go +++ b/models/admin/task.go @@ -48,11 +48,7 @@ type TranslatableMessage struct { } // LoadRepo loads repository of the task -func (task *Task) LoadRepo() error { - return task.loadRepo(db.DefaultContext) -} - -func (task *Task) loadRepo(ctx context.Context) error { +func (task *Task) LoadRepo(ctx context.Context) error { if task.Repo != nil { return nil } @@ -70,13 +66,13 @@ func (task *Task) loadRepo(ctx context.Context) error { } // LoadDoer loads do user -func (task *Task) LoadDoer() error { +func (task *Task) LoadDoer(ctx context.Context) error { if task.Doer != nil { return nil } var doer user_model.User - has, err := db.GetEngine(db.DefaultContext).ID(task.DoerID).Get(&doer) + has, err := db.GetEngine(ctx).ID(task.DoerID).Get(&doer) if err != nil { return err } else if !has { @@ -90,13 +86,13 @@ func (task *Task) LoadDoer() error { } // LoadOwner loads owner user -func (task *Task) LoadOwner() error { +func (task *Task) LoadOwner(ctx context.Context) error { if task.Owner != nil { return nil } var owner user_model.User - has, err := db.GetEngine(db.DefaultContext).ID(task.OwnerID).Get(&owner) + has, err := db.GetEngine(ctx).ID(task.OwnerID).Get(&owner) if err != nil { return err } else if !has { @@ -110,8 +106,8 @@ func (task *Task) LoadOwner() error { } // UpdateCols updates some columns -func (task *Task) UpdateCols(cols ...string) error { - _, err := db.GetEngine(db.DefaultContext).ID(task.ID).Cols(cols...).Update(task) +func (task *Task) UpdateCols(ctx context.Context, cols ...string) error { + _, err := db.GetEngine(ctx).ID(task.ID).Cols(cols...).Update(task) return err } @@ -169,12 +165,12 @@ func (err ErrTaskDoesNotExist) Unwrap() error { } // GetMigratingTask returns the migrating task by repo's id -func GetMigratingTask(repoID int64) (*Task, error) { +func GetMigratingTask(ctx context.Context, repoID int64) (*Task, error) { task := Task{ RepoID: repoID, Type: structs.TaskTypeMigrateRepo, } - has, err := db.GetEngine(db.DefaultContext).Get(&task) + has, err := db.GetEngine(ctx).Get(&task) if err != nil { return nil, err } else if !has { @@ -184,13 +180,13 @@ func GetMigratingTask(repoID int64) (*Task, error) { } // GetMigratingTaskByID returns the migrating task by repo's id -func GetMigratingTaskByID(id, doerID int64) (*Task, *migration.MigrateOptions, error) { +func GetMigratingTaskByID(ctx context.Context, id, doerID int64) (*Task, *migration.MigrateOptions, error) { task := Task{ ID: id, DoerID: doerID, Type: structs.TaskTypeMigrateRepo, } - has, err := db.GetEngine(db.DefaultContext).Get(&task) + has, err := db.GetEngine(ctx).Get(&task) if err != nil { return nil, nil, err } else if !has { @@ -205,12 +201,12 @@ func GetMigratingTaskByID(id, doerID int64) (*Task, *migration.MigrateOptions, e } // CreateTask creates a task on database -func CreateTask(task *Task) error { - return db.Insert(db.DefaultContext, task) +func CreateTask(ctx context.Context, task *Task) error { + return db.Insert(ctx, task) } // FinishMigrateTask updates database when migrate task finished -func FinishMigrateTask(task *Task) error { +func FinishMigrateTask(ctx context.Context, task *Task) error { task.Status = structs.TaskStatusFinished task.EndTime = timeutil.TimeStampNow() @@ -231,6 +227,6 @@ func FinishMigrateTask(task *Task) error { } task.PayloadContent = string(confBytes) - _, err = db.GetEngine(db.DefaultContext).ID(task.ID).Cols("status", "end_time", "payload_content").Update(task) + _, err = db.GetEngine(ctx).ID(task.ID).Cols("status", "end_time", "payload_content").Update(task) return err } diff --git a/models/auth/session.go b/models/auth/session.go index b60e6a903b..28f25170ee 100644 --- a/models/auth/session.go +++ b/models/auth/session.go @@ -4,6 +4,7 @@ package auth import ( + "context" "fmt" "code.gitea.io/gitea/models/db" @@ -22,8 +23,8 @@ func init() { } // UpdateSession updates the session with provided id -func UpdateSession(key string, data []byte) error { - _, err := db.GetEngine(db.DefaultContext).ID(key).Update(&Session{ +func UpdateSession(ctx context.Context, key string, data []byte) error { + _, err := db.GetEngine(ctx).ID(key).Update(&Session{ Data: data, Expiry: timeutil.TimeStampNow(), }) @@ -31,12 +32,12 @@ func UpdateSession(key string, data []byte) error { } // ReadSession reads the data for the provided session -func ReadSession(key string) (*Session, error) { +func ReadSession(ctx context.Context, key string) (*Session, error) { session := Session{ Key: key, } - ctx, committer, err := db.TxContext(db.DefaultContext) + ctx, committer, err := db.TxContext(ctx) if err != nil { return nil, err } @@ -55,24 +56,24 @@ func ReadSession(key string) (*Session, error) { } // ExistSession checks if a session exists -func ExistSession(key string) (bool, error) { +func ExistSession(ctx context.Context, key string) (bool, error) { session := Session{ Key: key, } - return db.GetEngine(db.DefaultContext).Get(&session) + return db.GetEngine(ctx).Get(&session) } // DestroySession destroys a session -func DestroySession(key string) error { - _, err := db.GetEngine(db.DefaultContext).Delete(&Session{ +func DestroySession(ctx context.Context, key string) error { + _, err := db.GetEngine(ctx).Delete(&Session{ Key: key, }) return err } // RegenerateSession regenerates a session from the old id -func RegenerateSession(oldKey, newKey string) (*Session, error) { - ctx, committer, err := db.TxContext(db.DefaultContext) +func RegenerateSession(ctx context.Context, oldKey, newKey string) (*Session, error) { + ctx, committer, err := db.TxContext(ctx) if err != nil { return nil, err } @@ -114,12 +115,12 @@ func RegenerateSession(oldKey, newKey string) (*Session, error) { } // CountSessions returns the number of sessions -func CountSessions() (int64, error) { - return db.GetEngine(db.DefaultContext).Count(&Session{}) +func CountSessions(ctx context.Context) (int64, error) { + return db.GetEngine(ctx).Count(&Session{}) } // CleanupSessions cleans up expired sessions -func CleanupSessions(maxLifetime int64) error { - _, err := db.GetEngine(db.DefaultContext).Where("expiry <= ?", timeutil.TimeStampNow().Add(-maxLifetime)).Delete(&Session{}) +func CleanupSessions(ctx context.Context, maxLifetime int64) error { + _, err := db.GetEngine(ctx).Where("expiry <= ?", timeutil.TimeStampNow().Add(-maxLifetime)).Delete(&Session{}) return err } diff --git a/models/auth/webauthn.go b/models/auth/webauthn.go index db5dd7eea5..d12713bd37 100644 --- a/models/auth/webauthn.go +++ b/models/auth/webauthn.go @@ -67,11 +67,7 @@ func (cred WebAuthnCredential) TableName() string { } // UpdateSignCount will update the database value of SignCount -func (cred *WebAuthnCredential) UpdateSignCount() error { - return cred.updateSignCount(db.DefaultContext) -} - -func (cred *WebAuthnCredential) updateSignCount(ctx context.Context) error { +func (cred *WebAuthnCredential) UpdateSignCount(ctx context.Context) error { _, err := db.GetEngine(ctx).ID(cred.ID).Cols("sign_count").Update(cred) return err } @@ -113,30 +109,18 @@ func (list WebAuthnCredentialList) ToCredentials() []webauthn.Credential { } // GetWebAuthnCredentialsByUID returns all WebAuthn credentials of the given user -func GetWebAuthnCredentialsByUID(uid int64) (WebAuthnCredentialList, error) { - return getWebAuthnCredentialsByUID(db.DefaultContext, uid) -} - -func getWebAuthnCredentialsByUID(ctx context.Context, uid int64) (WebAuthnCredentialList, error) { +func GetWebAuthnCredentialsByUID(ctx context.Context, uid int64) (WebAuthnCredentialList, error) { creds := make(WebAuthnCredentialList, 0) return creds, db.GetEngine(ctx).Where("user_id = ?", uid).Find(&creds) } // ExistsWebAuthnCredentialsForUID returns if the given user has credentials -func ExistsWebAuthnCredentialsForUID(uid int64) (bool, error) { - return existsWebAuthnCredentialsByUID(db.DefaultContext, uid) -} - -func existsWebAuthnCredentialsByUID(ctx context.Context, uid int64) (bool, error) { +func ExistsWebAuthnCredentialsForUID(ctx context.Context, uid int64) (bool, error) { return db.GetEngine(ctx).Where("user_id = ?", uid).Exist(&WebAuthnCredential{}) } // GetWebAuthnCredentialByName returns WebAuthn credential by id -func GetWebAuthnCredentialByName(uid int64, name string) (*WebAuthnCredential, error) { - return getWebAuthnCredentialByName(db.DefaultContext, uid, name) -} - -func getWebAuthnCredentialByName(ctx context.Context, uid int64, name string) (*WebAuthnCredential, error) { +func GetWebAuthnCredentialByName(ctx context.Context, uid int64, name string) (*WebAuthnCredential, error) { cred := new(WebAuthnCredential) if found, err := db.GetEngine(ctx).Where("user_id = ? AND lower_name = ?", uid, strings.ToLower(name)).Get(cred); err != nil { return nil, err @@ -147,11 +131,7 @@ func getWebAuthnCredentialByName(ctx context.Context, uid int64, name string) (* } // GetWebAuthnCredentialByID returns WebAuthn credential by id -func GetWebAuthnCredentialByID(id int64) (*WebAuthnCredential, error) { - return getWebAuthnCredentialByID(db.DefaultContext, id) -} - -func getWebAuthnCredentialByID(ctx context.Context, id int64) (*WebAuthnCredential, error) { +func GetWebAuthnCredentialByID(ctx context.Context, id int64) (*WebAuthnCredential, error) { cred := new(WebAuthnCredential) if found, err := db.GetEngine(ctx).ID(id).Get(cred); err != nil { return nil, err @@ -162,16 +142,12 @@ func getWebAuthnCredentialByID(ctx context.Context, id int64) (*WebAuthnCredenti } // HasWebAuthnRegistrationsByUID returns whether a given user has WebAuthn registrations -func HasWebAuthnRegistrationsByUID(uid int64) (bool, error) { - return db.GetEngine(db.DefaultContext).Where("user_id = ?", uid).Exist(&WebAuthnCredential{}) +func HasWebAuthnRegistrationsByUID(ctx context.Context, uid int64) (bool, error) { + return db.GetEngine(ctx).Where("user_id = ?", uid).Exist(&WebAuthnCredential{}) } // GetWebAuthnCredentialByCredID returns WebAuthn credential by credential ID -func GetWebAuthnCredentialByCredID(userID int64, credID []byte) (*WebAuthnCredential, error) { - return getWebAuthnCredentialByCredID(db.DefaultContext, userID, credID) -} - -func getWebAuthnCredentialByCredID(ctx context.Context, userID int64, credID []byte) (*WebAuthnCredential, error) { +func GetWebAuthnCredentialByCredID(ctx context.Context, userID int64, credID []byte) (*WebAuthnCredential, error) { cred := new(WebAuthnCredential) if found, err := db.GetEngine(ctx).Where("user_id = ? AND credential_id = ?", userID, credID).Get(cred); err != nil { return nil, err @@ -182,11 +158,7 @@ func getWebAuthnCredentialByCredID(ctx context.Context, userID int64, credID []b } // CreateCredential will create a new WebAuthnCredential from the given Credential -func CreateCredential(userID int64, name string, cred *webauthn.Credential) (*WebAuthnCredential, error) { - return createCredential(db.DefaultContext, userID, name, cred) -} - -func createCredential(ctx context.Context, userID int64, name string, cred *webauthn.Credential) (*WebAuthnCredential, error) { +func CreateCredential(ctx context.Context, userID int64, name string, cred *webauthn.Credential) (*WebAuthnCredential, error) { c := &WebAuthnCredential{ UserID: userID, Name: name, @@ -205,18 +177,14 @@ func createCredential(ctx context.Context, userID int64, name string, cred *weba } // DeleteCredential will delete WebAuthnCredential -func DeleteCredential(id, userID int64) (bool, error) { - return deleteCredential(db.DefaultContext, id, userID) -} - -func deleteCredential(ctx context.Context, id, userID int64) (bool, error) { +func DeleteCredential(ctx context.Context, id, userID int64) (bool, error) { had, err := db.GetEngine(ctx).ID(id).Where("user_id = ?", userID).Delete(&WebAuthnCredential{}) return had > 0, err } // WebAuthnCredentials implementns the webauthn.User interface -func WebAuthnCredentials(userID int64) ([]webauthn.Credential, error) { - dbCreds, err := GetWebAuthnCredentialsByUID(userID) +func WebAuthnCredentials(ctx context.Context, userID int64) ([]webauthn.Credential, error) { + dbCreds, err := GetWebAuthnCredentialsByUID(ctx, userID) if err != nil { return nil, err } diff --git a/models/auth/webauthn_test.go b/models/auth/webauthn_test.go index 6f2ec087c7..f1cf398adf 100644 --- a/models/auth/webauthn_test.go +++ b/models/auth/webauthn_test.go @@ -7,6 +7,7 @@ import ( "testing" auth_model "code.gitea.io/gitea/models/auth" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/models/unittest" "github.com/go-webauthn/webauthn/webauthn" @@ -16,11 +17,11 @@ import ( func TestGetWebAuthnCredentialByID(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - res, err := auth_model.GetWebAuthnCredentialByID(1) + res, err := auth_model.GetWebAuthnCredentialByID(db.DefaultContext, 1) assert.NoError(t, err) assert.Equal(t, "WebAuthn credential", res.Name) - _, err = auth_model.GetWebAuthnCredentialByID(342432) + _, err = auth_model.GetWebAuthnCredentialByID(db.DefaultContext, 342432) assert.Error(t, err) assert.True(t, auth_model.IsErrWebAuthnCredentialNotExist(err)) } @@ -28,7 +29,7 @@ func TestGetWebAuthnCredentialByID(t *testing.T) { func TestGetWebAuthnCredentialsByUID(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - res, err := auth_model.GetWebAuthnCredentialsByUID(32) + res, err := auth_model.GetWebAuthnCredentialsByUID(db.DefaultContext, 32) assert.NoError(t, err) assert.Len(t, res, 1) assert.Equal(t, "WebAuthn credential", res[0].Name) @@ -42,7 +43,7 @@ func TestWebAuthnCredential_UpdateSignCount(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) cred := unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{ID: 1}) cred.SignCount = 1 - assert.NoError(t, cred.UpdateSignCount()) + assert.NoError(t, cred.UpdateSignCount(db.DefaultContext)) unittest.AssertExistsIf(t, true, &auth_model.WebAuthnCredential{ID: 1, SignCount: 1}) } @@ -50,14 +51,14 @@ func TestWebAuthnCredential_UpdateLargeCounter(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) cred := unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{ID: 1}) cred.SignCount = 0xffffffff - assert.NoError(t, cred.UpdateSignCount()) + assert.NoError(t, cred.UpdateSignCount(db.DefaultContext)) unittest.AssertExistsIf(t, true, &auth_model.WebAuthnCredential{ID: 1, SignCount: 0xffffffff}) } func TestCreateCredential(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - res, err := auth_model.CreateCredential(1, "WebAuthn Created Credential", &webauthn.Credential{ID: []byte("Test")}) + res, err := auth_model.CreateCredential(db.DefaultContext, 1, "WebAuthn Created Credential", &webauthn.Credential{ID: []byte("Test")}) assert.NoError(t, err) assert.Equal(t, "WebAuthn Created Credential", res.Name) assert.Equal(t, []byte("Test"), res.CredentialID) diff --git a/models/issues/issue_test.go b/models/issues/issue_test.go index 747fbbc78c..b7fa7eff1c 100644 --- a/models/issues/issue_test.go +++ b/models/issues/issue_test.go @@ -385,7 +385,7 @@ func TestMilestoneList_LoadTotalTrackedTimes(t *testing.T) { unittest.AssertExistsAndLoadBean(t, &issues_model.Milestone{ID: 1}), } - assert.NoError(t, miles.LoadTotalTrackedTimes()) + assert.NoError(t, miles.LoadTotalTrackedTimes(db.DefaultContext)) assert.Equal(t, int64(3682), miles[0].TotalTrackedTime) } @@ -394,7 +394,7 @@ func TestLoadTotalTrackedTime(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) milestone := unittest.AssertExistsAndLoadBean(t, &issues_model.Milestone{ID: 1}) - assert.NoError(t, milestone.LoadTotalTrackedTime()) + assert.NoError(t, milestone.LoadTotalTrackedTime(db.DefaultContext)) assert.Equal(t, int64(3682), milestone.TotalTrackedTime) } diff --git a/models/issues/issue_watch.go b/models/issues/issue_watch.go index 1efc0ea687..b7e9504c67 100644 --- a/models/issues/issue_watch.go +++ b/models/issues/issue_watch.go @@ -30,8 +30,8 @@ func init() { type IssueWatchList []*IssueWatch // CreateOrUpdateIssueWatch set watching for a user and issue -func CreateOrUpdateIssueWatch(userID, issueID int64, isWatching bool) error { - iw, exists, err := GetIssueWatch(db.DefaultContext, userID, issueID) +func CreateOrUpdateIssueWatch(ctx context.Context, userID, issueID int64, isWatching bool) error { + iw, exists, err := GetIssueWatch(ctx, userID, issueID) if err != nil { return err } @@ -43,13 +43,13 @@ func CreateOrUpdateIssueWatch(userID, issueID int64, isWatching bool) error { IsWatching: isWatching, } - if _, err := db.GetEngine(db.DefaultContext).Insert(iw); err != nil { + if _, err := db.GetEngine(ctx).Insert(iw); err != nil { return err } } else { iw.IsWatching = isWatching - if _, err := db.GetEngine(db.DefaultContext).ID(iw.ID).Cols("is_watching", "updated_unix").Update(iw); err != nil { + if _, err := db.GetEngine(ctx).ID(iw.ID).Cols("is_watching", "updated_unix").Update(iw); err != nil { return err } } @@ -69,15 +69,15 @@ func GetIssueWatch(ctx context.Context, userID, issueID int64) (iw *IssueWatch, // CheckIssueWatch check if an user is watching an issue // it takes participants and repo watch into account -func CheckIssueWatch(user *user_model.User, issue *Issue) (bool, error) { - iw, exist, err := GetIssueWatch(db.DefaultContext, user.ID, issue.ID) +func CheckIssueWatch(ctx context.Context, user *user_model.User, issue *Issue) (bool, error) { + iw, exist, err := GetIssueWatch(ctx, user.ID, issue.ID) if err != nil { return false, err } if exist { return iw.IsWatching, nil } - w, err := repo_model.GetWatch(db.DefaultContext, user.ID, issue.RepoID) + w, err := repo_model.GetWatch(ctx, user.ID, issue.RepoID) if err != nil { return false, err } diff --git a/models/issues/issue_watch_test.go b/models/issues/issue_watch_test.go index 4f44487f56..d4ce8d8d3d 100644 --- a/models/issues/issue_watch_test.go +++ b/models/issues/issue_watch_test.go @@ -16,11 +16,11 @@ import ( func TestCreateOrUpdateIssueWatch(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - assert.NoError(t, issues_model.CreateOrUpdateIssueWatch(3, 1, true)) + assert.NoError(t, issues_model.CreateOrUpdateIssueWatch(db.DefaultContext, 3, 1, true)) iw := unittest.AssertExistsAndLoadBean(t, &issues_model.IssueWatch{UserID: 3, IssueID: 1}) assert.True(t, iw.IsWatching) - assert.NoError(t, issues_model.CreateOrUpdateIssueWatch(1, 1, false)) + assert.NoError(t, issues_model.CreateOrUpdateIssueWatch(db.DefaultContext, 1, 1, false)) iw = unittest.AssertExistsAndLoadBean(t, &issues_model.IssueWatch{UserID: 1, IssueID: 1}) assert.False(t, iw.IsWatching) } diff --git a/models/issues/label.go b/models/issues/label.go index 0087c933a6..f8dbb9e39c 100644 --- a/models/issues/label.go +++ b/models/issues/label.go @@ -199,8 +199,8 @@ func NewLabel(ctx context.Context, l *Label) error { } // NewLabels creates new labels -func NewLabels(labels ...*Label) error { - ctx, committer, err := db.TxContext(db.DefaultContext) +func NewLabels(ctx context.Context, labels ...*Label) error { + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } @@ -221,19 +221,19 @@ func NewLabels(labels ...*Label) error { } // UpdateLabel updates label information. -func UpdateLabel(l *Label) error { +func UpdateLabel(ctx context.Context, l *Label) error { color, err := label.NormalizeColor(l.Color) if err != nil { return err } l.Color = color - return updateLabelCols(db.DefaultContext, l, "name", "description", "color", "exclusive", "archived_unix") + return updateLabelCols(ctx, l, "name", "description", "color", "exclusive", "archived_unix") } // DeleteLabel delete a label -func DeleteLabel(id, labelID int64) error { - l, err := GetLabelByID(db.DefaultContext, labelID) +func DeleteLabel(ctx context.Context, id, labelID int64) error { + l, err := GetLabelByID(ctx, labelID) if err != nil { if IsErrLabelNotExist(err) { return nil @@ -241,7 +241,7 @@ func DeleteLabel(id, labelID int64) error { return err } - ctx, committer, err := db.TxContext(db.DefaultContext) + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } @@ -289,9 +289,9 @@ func GetLabelByID(ctx context.Context, labelID int64) (*Label, error) { } // GetLabelsByIDs returns a list of labels by IDs -func GetLabelsByIDs(labelIDs []int64, cols ...string) ([]*Label, error) { +func GetLabelsByIDs(ctx context.Context, labelIDs []int64, cols ...string) ([]*Label, error) { labels := make([]*Label, 0, len(labelIDs)) - return labels, db.GetEngine(db.DefaultContext).Table("label"). + return labels, db.GetEngine(ctx).Table("label"). In("id", labelIDs). Asc("name"). Cols(cols...). @@ -339,9 +339,9 @@ func GetLabelInRepoByID(ctx context.Context, repoID, labelID int64) (*Label, err // GetLabelIDsInRepoByNames returns a list of labelIDs by names in a given // repository. // it silently ignores label names that do not belong to the repository. -func GetLabelIDsInRepoByNames(repoID int64, labelNames []string) ([]int64, error) { +func GetLabelIDsInRepoByNames(ctx context.Context, repoID int64, labelNames []string) ([]int64, error) { labelIDs := make([]int64, 0, len(labelNames)) - return labelIDs, db.GetEngine(db.DefaultContext).Table("label"). + return labelIDs, db.GetEngine(ctx).Table("label"). Where("repo_id = ?", repoID). In("name", labelNames). Asc("name"). @@ -398,8 +398,8 @@ func GetLabelsByRepoID(ctx context.Context, repoID int64, sortType string, listO } // CountLabelsByRepoID count number of all labels that belong to given repository by ID. -func CountLabelsByRepoID(repoID int64) (int64, error) { - return db.GetEngine(db.DefaultContext).Where("repo_id = ?", repoID).Count(&Label{}) +func CountLabelsByRepoID(ctx context.Context, repoID int64) (int64, error) { + return db.GetEngine(ctx).Where("repo_id = ?", repoID).Count(&Label{}) } // GetLabelInOrgByName returns a label by name in given organization. @@ -442,13 +442,13 @@ func GetLabelInOrgByID(ctx context.Context, orgID, labelID int64) (*Label, error // GetLabelIDsInOrgByNames returns a list of labelIDs by names in a given // organization. -func GetLabelIDsInOrgByNames(orgID int64, labelNames []string) ([]int64, error) { +func GetLabelIDsInOrgByNames(ctx context.Context, orgID int64, labelNames []string) ([]int64, error) { if orgID <= 0 { return nil, ErrOrgLabelNotExist{0, orgID} } labelIDs := make([]int64, 0, len(labelNames)) - return labelIDs, db.GetEngine(db.DefaultContext).Table("label"). + return labelIDs, db.GetEngine(ctx).Table("label"). Where("org_id = ?", orgID). In("name", labelNames). Asc("name"). @@ -506,8 +506,8 @@ func GetLabelIDsByNames(ctx context.Context, labelNames []string) ([]int64, erro } // CountLabelsByOrgID count all labels that belong to given organization by ID. -func CountLabelsByOrgID(orgID int64) (int64, error) { - return db.GetEngine(db.DefaultContext).Where("org_id = ?", orgID).Count(&Label{}) +func CountLabelsByOrgID(ctx context.Context, orgID int64) (int64, error) { + return db.GetEngine(ctx).Where("org_id = ?", orgID).Count(&Label{}) } func updateLabelCols(ctx context.Context, l *Label, cols ...string) error { diff --git a/models/issues/label_test.go b/models/issues/label_test.go index 3f0e980b31..9f44cd3e03 100644 --- a/models/issues/label_test.go +++ b/models/issues/label_test.go @@ -48,7 +48,7 @@ func TestNewLabels(t *testing.T) { for _, label := range labels { unittest.AssertNotExistsBean(t, label) } - assert.NoError(t, issues_model.NewLabels(labels...)) + assert.NoError(t, issues_model.NewLabels(db.DefaultContext, labels...)) for _, label := range labels { unittest.AssertExistsAndLoadBean(t, label, unittest.Cond("id = ?", label.ID)) } @@ -81,7 +81,7 @@ func TestGetLabelInRepoByName(t *testing.T) { func TestGetLabelInRepoByNames(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - labelIDs, err := issues_model.GetLabelIDsInRepoByNames(1, []string{"label1", "label2"}) + labelIDs, err := issues_model.GetLabelIDsInRepoByNames(db.DefaultContext, 1, []string{"label1", "label2"}) assert.NoError(t, err) assert.Len(t, labelIDs, 2) @@ -93,7 +93,7 @@ func TestGetLabelInRepoByNames(t *testing.T) { func TestGetLabelInRepoByNamesDiscardsNonExistentLabels(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) // label3 doesn't exists.. See labels.yml - labelIDs, err := issues_model.GetLabelIDsInRepoByNames(1, []string{"label1", "label2", "label3"}) + labelIDs, err := issues_model.GetLabelIDsInRepoByNames(db.DefaultContext, 1, []string{"label1", "label2", "label3"}) assert.NoError(t, err) assert.Len(t, labelIDs, 2) @@ -166,7 +166,7 @@ func TestGetLabelInOrgByName(t *testing.T) { func TestGetLabelInOrgByNames(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - labelIDs, err := issues_model.GetLabelIDsInOrgByNames(3, []string{"orglabel3", "orglabel4"}) + labelIDs, err := issues_model.GetLabelIDsInOrgByNames(db.DefaultContext, 3, []string{"orglabel3", "orglabel4"}) assert.NoError(t, err) assert.Len(t, labelIDs, 2) @@ -178,7 +178,7 @@ func TestGetLabelInOrgByNames(t *testing.T) { func TestGetLabelInOrgByNamesDiscardsNonExistentLabels(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) // orglabel99 doesn't exists.. See labels.yml - labelIDs, err := issues_model.GetLabelIDsInOrgByNames(3, []string{"orglabel3", "orglabel4", "orglabel99"}) + labelIDs, err := issues_model.GetLabelIDsInOrgByNames(db.DefaultContext, 3, []string{"orglabel3", "orglabel4", "orglabel99"}) assert.NoError(t, err) assert.Len(t, labelIDs, 2) @@ -269,7 +269,7 @@ func TestUpdateLabel(t *testing.T) { } label.Color = update.Color label.Name = update.Name - assert.NoError(t, issues_model.UpdateLabel(update)) + assert.NoError(t, issues_model.UpdateLabel(db.DefaultContext, update)) newLabel := unittest.AssertExistsAndLoadBean(t, &issues_model.Label{ID: 1}) assert.EqualValues(t, label.ID, newLabel.ID) assert.EqualValues(t, label.Color, newLabel.Color) @@ -282,13 +282,13 @@ func TestUpdateLabel(t *testing.T) { func TestDeleteLabel(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) label := unittest.AssertExistsAndLoadBean(t, &issues_model.Label{ID: 1}) - assert.NoError(t, issues_model.DeleteLabel(label.RepoID, label.ID)) + assert.NoError(t, issues_model.DeleteLabel(db.DefaultContext, label.RepoID, label.ID)) unittest.AssertNotExistsBean(t, &issues_model.Label{ID: label.ID, RepoID: label.RepoID}) - assert.NoError(t, issues_model.DeleteLabel(label.RepoID, label.ID)) + assert.NoError(t, issues_model.DeleteLabel(db.DefaultContext, label.RepoID, label.ID)) unittest.AssertNotExistsBean(t, &issues_model.Label{ID: label.ID}) - assert.NoError(t, issues_model.DeleteLabel(unittest.NonexistentID, unittest.NonexistentID)) + assert.NoError(t, issues_model.DeleteLabel(db.DefaultContext, unittest.NonexistentID, unittest.NonexistentID)) unittest.CheckConsistencyFor(t, &issues_model.Label{}, &repo_model.Repository{}) } diff --git a/models/issues/milestone.go b/models/issues/milestone.go index c15b2a41fe..ad1d5d0453 100644 --- a/models/issues/milestone.go +++ b/models/issues/milestone.go @@ -103,8 +103,8 @@ func (m *Milestone) State() api.StateType { } // NewMilestone creates new milestone of repository. -func NewMilestone(m *Milestone) (err error) { - ctx, committer, err := db.TxContext(db.DefaultContext) +func NewMilestone(ctx context.Context, m *Milestone) (err error) { + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } @@ -140,9 +140,9 @@ func GetMilestoneByRepoID(ctx context.Context, repoID, id int64) (*Milestone, er } // GetMilestoneByRepoIDANDName return a milestone if one exist by name and repo -func GetMilestoneByRepoIDANDName(repoID int64, name string) (*Milestone, error) { +func GetMilestoneByRepoIDANDName(ctx context.Context, repoID int64, name string) (*Milestone, error) { var mile Milestone - has, err := db.GetEngine(db.DefaultContext).Where("repo_id=? AND name=?", repoID, name).Get(&mile) + has, err := db.GetEngine(ctx).Where("repo_id=? AND name=?", repoID, name).Get(&mile) if err != nil { return nil, err } @@ -153,8 +153,8 @@ func GetMilestoneByRepoIDANDName(repoID int64, name string) (*Milestone, error) } // UpdateMilestone updates information of given milestone. -func UpdateMilestone(m *Milestone, oldIsClosed bool) error { - ctx, committer, err := db.TxContext(db.DefaultContext) +func UpdateMilestone(ctx context.Context, m *Milestone, oldIsClosed bool) error { + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } @@ -211,8 +211,8 @@ func UpdateMilestoneCounters(ctx context.Context, id int64) error { } // ChangeMilestoneStatusByRepoIDAndID changes a milestone open/closed status if the milestone ID is in the repo. -func ChangeMilestoneStatusByRepoIDAndID(repoID, milestoneID int64, isClosed bool) error { - ctx, committer, err := db.TxContext(db.DefaultContext) +func ChangeMilestoneStatusByRepoIDAndID(ctx context.Context, repoID, milestoneID int64, isClosed bool) error { + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } @@ -238,8 +238,8 @@ func ChangeMilestoneStatusByRepoIDAndID(repoID, milestoneID int64, isClosed bool } // ChangeMilestoneStatus changes the milestone open/closed status. -func ChangeMilestoneStatus(m *Milestone, isClosed bool) (err error) { - ctx, committer, err := db.TxContext(db.DefaultContext) +func ChangeMilestoneStatus(ctx context.Context, m *Milestone, isClosed bool) (err error) { + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } @@ -269,8 +269,8 @@ func changeMilestoneStatus(ctx context.Context, m *Milestone, isClosed bool) err } // DeleteMilestoneByRepoID deletes a milestone from a repository. -func DeleteMilestoneByRepoID(repoID, id int64) error { - m, err := GetMilestoneByRepoID(db.DefaultContext, repoID, id) +func DeleteMilestoneByRepoID(ctx context.Context, repoID, id int64) error { + m, err := GetMilestoneByRepoID(ctx, repoID, id) if err != nil { if IsErrMilestoneNotExist(err) { return nil @@ -278,12 +278,12 @@ func DeleteMilestoneByRepoID(repoID, id int64) error { return err } - repo, err := repo_model.GetRepositoryByID(db.DefaultContext, m.RepoID) + repo, err := repo_model.GetRepositoryByID(ctx, m.RepoID) if err != nil { return err } - ctx, committer, err := db.TxContext(db.DefaultContext) + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } @@ -332,7 +332,8 @@ func updateRepoMilestoneNum(ctx context.Context, repoID int64) error { return err } -func (m *Milestone) loadTotalTrackedTime(ctx context.Context) error { +// LoadTotalTrackedTime loads the tracked time for the milestone +func (m *Milestone) LoadTotalTrackedTime(ctx context.Context) error { type totalTimesByMilestone struct { MilestoneID int64 Time int64 @@ -355,18 +356,13 @@ func (m *Milestone) loadTotalTrackedTime(ctx context.Context) error { return nil } -// LoadTotalTrackedTime loads the tracked time for the milestone -func (m *Milestone) LoadTotalTrackedTime() error { - return m.loadTotalTrackedTime(db.DefaultContext) -} - // InsertMilestones creates milestones of repository. -func InsertMilestones(ms ...*Milestone) (err error) { +func InsertMilestones(ctx context.Context, ms ...*Milestone) (err error) { if len(ms) == 0 { return nil } - ctx, committer, err := db.TxContext(db.DefaultContext) + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } diff --git a/models/issues/milestone_list.go b/models/issues/milestone_list.go index b0c29106a0..d5c9b1358c 100644 --- a/models/issues/milestone_list.go +++ b/models/issues/milestone_list.go @@ -100,9 +100,9 @@ func GetMilestoneIDsByNames(ctx context.Context, names []string) ([]int64, error } // SearchMilestones search milestones -func SearchMilestones(repoCond builder.Cond, page int, isClosed bool, sortType, keyword string) (MilestoneList, error) { +func SearchMilestones(ctx context.Context, repoCond builder.Cond, page int, isClosed bool, sortType, keyword string) (MilestoneList, error) { miles := make([]*Milestone, 0, setting.UI.IssuePagingNum) - sess := db.GetEngine(db.DefaultContext).Where("is_closed = ?", isClosed) + sess := db.GetEngine(ctx).Where("is_closed = ?", isClosed) if len(keyword) > 0 { sess = sess.And(builder.Like{"UPPER(name)", strings.ToUpper(keyword)}) } @@ -131,8 +131,9 @@ func SearchMilestones(repoCond builder.Cond, page int, isClosed bool, sortType, } // GetMilestonesByRepoIDs returns a list of milestones of given repositories and status. -func GetMilestonesByRepoIDs(repoIDs []int64, page int, isClosed bool, sortType string) (MilestoneList, error) { +func GetMilestonesByRepoIDs(ctx context.Context, repoIDs []int64, page int, isClosed bool, sortType string) (MilestoneList, error) { return SearchMilestones( + ctx, builder.In("repo_id", repoIDs), page, isClosed, @@ -141,7 +142,8 @@ func GetMilestonesByRepoIDs(repoIDs []int64, page int, isClosed bool, sortType s ) } -func (milestones MilestoneList) loadTotalTrackedTimes(ctx context.Context) error { +// LoadTotalTrackedTimes loads for every milestone in the list the TotalTrackedTime by a batch request +func (milestones MilestoneList) LoadTotalTrackedTimes(ctx context.Context) error { type totalTimesByMilestone struct { MilestoneID int64 Time int64 @@ -181,11 +183,6 @@ func (milestones MilestoneList) loadTotalTrackedTimes(ctx context.Context) error return nil } -// LoadTotalTrackedTimes loads for every milestone in the list the TotalTrackedTime by a batch request -func (milestones MilestoneList) LoadTotalTrackedTimes() error { - return milestones.loadTotalTrackedTimes(db.DefaultContext) -} - // CountMilestones returns number of milestones in given repository with other options func CountMilestones(ctx context.Context, opts GetMilestonesOption) (int64, error) { return db.GetEngine(ctx). @@ -194,8 +191,8 @@ func CountMilestones(ctx context.Context, opts GetMilestonesOption) (int64, erro } // CountMilestonesByRepoCond map from repo conditions to number of milestones matching the options` -func CountMilestonesByRepoCond(repoCond builder.Cond, isClosed bool) (map[int64]int64, error) { - sess := db.GetEngine(db.DefaultContext).Where("is_closed = ?", isClosed) +func CountMilestonesByRepoCond(ctx context.Context, repoCond builder.Cond, isClosed bool) (map[int64]int64, error) { + sess := db.GetEngine(ctx).Where("is_closed = ?", isClosed) if repoCond.IsValid() { sess.In("repo_id", builder.Select("id").From("repository").Where(repoCond)) } @@ -219,8 +216,8 @@ func CountMilestonesByRepoCond(repoCond builder.Cond, isClosed bool) (map[int64] } // CountMilestonesByRepoCondAndKw map from repo conditions and the keyword of milestones' name to number of milestones matching the options` -func CountMilestonesByRepoCondAndKw(repoCond builder.Cond, keyword string, isClosed bool) (map[int64]int64, error) { - sess := db.GetEngine(db.DefaultContext).Where("is_closed = ?", isClosed) +func CountMilestonesByRepoCondAndKw(ctx context.Context, repoCond builder.Cond, keyword string, isClosed bool) (map[int64]int64, error) { + sess := db.GetEngine(ctx).Where("is_closed = ?", isClosed) if len(keyword) > 0 { sess = sess.And(builder.Like{"UPPER(name)", strings.ToUpper(keyword)}) } @@ -257,11 +254,11 @@ func (m MilestonesStats) Total() int64 { } // GetMilestonesStatsByRepoCond returns milestone statistic information for dashboard by given conditions. -func GetMilestonesStatsByRepoCond(repoCond builder.Cond) (*MilestonesStats, error) { +func GetMilestonesStatsByRepoCond(ctx context.Context, repoCond builder.Cond) (*MilestonesStats, error) { var err error stats := &MilestonesStats{} - sess := db.GetEngine(db.DefaultContext).Where("is_closed = ?", false) + sess := db.GetEngine(ctx).Where("is_closed = ?", false) if repoCond.IsValid() { sess.And(builder.In("repo_id", builder.Select("id").From("repository").Where(repoCond))) } @@ -270,7 +267,7 @@ func GetMilestonesStatsByRepoCond(repoCond builder.Cond) (*MilestonesStats, erro return nil, err } - sess = db.GetEngine(db.DefaultContext).Where("is_closed = ?", true) + sess = db.GetEngine(ctx).Where("is_closed = ?", true) if repoCond.IsValid() { sess.And(builder.In("repo_id", builder.Select("id").From("repository").Where(repoCond))) } @@ -283,11 +280,11 @@ func GetMilestonesStatsByRepoCond(repoCond builder.Cond) (*MilestonesStats, erro } // GetMilestonesStatsByRepoCondAndKw returns milestone statistic information for dashboard by given repo conditions and name keyword. -func GetMilestonesStatsByRepoCondAndKw(repoCond builder.Cond, keyword string) (*MilestonesStats, error) { +func GetMilestonesStatsByRepoCondAndKw(ctx context.Context, repoCond builder.Cond, keyword string) (*MilestonesStats, error) { var err error stats := &MilestonesStats{} - sess := db.GetEngine(db.DefaultContext).Where("is_closed = ?", false) + sess := db.GetEngine(ctx).Where("is_closed = ?", false) if len(keyword) > 0 { sess = sess.And(builder.Like{"UPPER(name)", strings.ToUpper(keyword)}) } @@ -299,7 +296,7 @@ func GetMilestonesStatsByRepoCondAndKw(repoCond builder.Cond, keyword string) (* return nil, err } - sess = db.GetEngine(db.DefaultContext).Where("is_closed = ?", true) + sess = db.GetEngine(ctx).Where("is_closed = ?", true) if len(keyword) > 0 { sess = sess.And(builder.Like{"UPPER(name)", strings.ToUpper(keyword)}) } diff --git a/models/issues/milestone_test.go b/models/issues/milestone_test.go index e85d77ebc8..403eeaadb3 100644 --- a/models/issues/milestone_test.go +++ b/models/issues/milestone_test.go @@ -201,12 +201,12 @@ func TestCountMilestonesByRepoIDs(t *testing.T) { repo1OpenCount, repo1ClosedCount := milestonesCount(1) repo2OpenCount, repo2ClosedCount := milestonesCount(2) - openCounts, err := issues_model.CountMilestonesByRepoCond(builder.In("repo_id", []int64{1, 2}), false) + openCounts, err := issues_model.CountMilestonesByRepoCond(db.DefaultContext, builder.In("repo_id", []int64{1, 2}), false) assert.NoError(t, err) assert.EqualValues(t, repo1OpenCount, openCounts[1]) assert.EqualValues(t, repo2OpenCount, openCounts[2]) - closedCounts, err := issues_model.CountMilestonesByRepoCond(builder.In("repo_id", []int64{1, 2}), true) + closedCounts, err := issues_model.CountMilestonesByRepoCond(db.DefaultContext, builder.In("repo_id", []int64{1, 2}), true) assert.NoError(t, err) assert.EqualValues(t, repo1ClosedCount, closedCounts[1]) assert.EqualValues(t, repo2ClosedCount, closedCounts[2]) @@ -218,7 +218,7 @@ func TestGetMilestonesByRepoIDs(t *testing.T) { repo2 := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 2}) test := func(sortType string, sortCond func(*issues_model.Milestone) int) { for _, page := range []int{0, 1} { - openMilestones, err := issues_model.GetMilestonesByRepoIDs([]int64{repo1.ID, repo2.ID}, page, false, sortType) + openMilestones, err := issues_model.GetMilestonesByRepoIDs(db.DefaultContext, []int64{repo1.ID, repo2.ID}, page, false, sortType) assert.NoError(t, err) assert.Len(t, openMilestones, repo1.NumOpenMilestones+repo2.NumOpenMilestones) values := make([]int, len(openMilestones)) @@ -227,7 +227,7 @@ func TestGetMilestonesByRepoIDs(t *testing.T) { } assert.True(t, sort.IntsAreSorted(values)) - closedMilestones, err := issues_model.GetMilestonesByRepoIDs([]int64{repo1.ID, repo2.ID}, page, true, sortType) + closedMilestones, err := issues_model.GetMilestonesByRepoIDs(db.DefaultContext, []int64{repo1.ID, repo2.ID}, page, true, sortType) assert.NoError(t, err) assert.Len(t, closedMilestones, repo1.NumClosedMilestones+repo2.NumClosedMilestones) values = make([]int, len(closedMilestones)) @@ -262,7 +262,7 @@ func TestGetMilestonesStats(t *testing.T) { test := func(repoID int64) { repo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: repoID}) - stats, err := issues_model.GetMilestonesStatsByRepoCond(builder.And(builder.Eq{"repo_id": repoID})) + stats, err := issues_model.GetMilestonesStatsByRepoCond(db.DefaultContext, builder.And(builder.Eq{"repo_id": repoID})) assert.NoError(t, err) assert.EqualValues(t, repo.NumMilestones-repo.NumClosedMilestones, stats.OpenCount) assert.EqualValues(t, repo.NumClosedMilestones, stats.ClosedCount) @@ -271,7 +271,7 @@ func TestGetMilestonesStats(t *testing.T) { test(2) test(3) - stats, err := issues_model.GetMilestonesStatsByRepoCond(builder.And(builder.Eq{"repo_id": unittest.NonexistentID})) + stats, err := issues_model.GetMilestonesStatsByRepoCond(db.DefaultContext, builder.And(builder.Eq{"repo_id": unittest.NonexistentID})) assert.NoError(t, err) assert.EqualValues(t, 0, stats.OpenCount) assert.EqualValues(t, 0, stats.ClosedCount) @@ -279,7 +279,7 @@ func TestGetMilestonesStats(t *testing.T) { repo1 := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 1}) repo2 := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 2}) - milestoneStats, err := issues_model.GetMilestonesStatsByRepoCond(builder.In("repo_id", []int64{repo1.ID, repo2.ID})) + milestoneStats, err := issues_model.GetMilestonesStatsByRepoCond(db.DefaultContext, builder.In("repo_id", []int64{repo1.ID, repo2.ID})) assert.NoError(t, err) assert.EqualValues(t, repo1.NumOpenMilestones+repo2.NumOpenMilestones, milestoneStats.OpenCount) assert.EqualValues(t, repo1.NumClosedMilestones+repo2.NumClosedMilestones, milestoneStats.ClosedCount) @@ -293,7 +293,7 @@ func TestNewMilestone(t *testing.T) { Content: "milestoneContent", } - assert.NoError(t, issues_model.NewMilestone(milestone)) + assert.NoError(t, issues_model.NewMilestone(db.DefaultContext, milestone)) unittest.AssertExistsAndLoadBean(t, milestone) unittest.CheckConsistencyFor(t, &repo_model.Repository{ID: milestone.RepoID}, &issues_model.Milestone{}) } @@ -302,22 +302,22 @@ func TestChangeMilestoneStatus(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) milestone := unittest.AssertExistsAndLoadBean(t, &issues_model.Milestone{ID: 1}) - assert.NoError(t, issues_model.ChangeMilestoneStatus(milestone, true)) + assert.NoError(t, issues_model.ChangeMilestoneStatus(db.DefaultContext, milestone, true)) unittest.AssertExistsAndLoadBean(t, &issues_model.Milestone{ID: 1}, "is_closed=1") unittest.CheckConsistencyFor(t, &repo_model.Repository{ID: milestone.RepoID}, &issues_model.Milestone{}) - assert.NoError(t, issues_model.ChangeMilestoneStatus(milestone, false)) + assert.NoError(t, issues_model.ChangeMilestoneStatus(db.DefaultContext, milestone, false)) unittest.AssertExistsAndLoadBean(t, &issues_model.Milestone{ID: 1}, "is_closed=0") unittest.CheckConsistencyFor(t, &repo_model.Repository{ID: milestone.RepoID}, &issues_model.Milestone{}) } func TestDeleteMilestoneByRepoID(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - assert.NoError(t, issues_model.DeleteMilestoneByRepoID(1, 1)) + assert.NoError(t, issues_model.DeleteMilestoneByRepoID(db.DefaultContext, 1, 1)) unittest.AssertNotExistsBean(t, &issues_model.Milestone{ID: 1}) unittest.CheckConsistencyFor(t, &repo_model.Repository{ID: 1}) - assert.NoError(t, issues_model.DeleteMilestoneByRepoID(unittest.NonexistentID, unittest.NonexistentID)) + assert.NoError(t, issues_model.DeleteMilestoneByRepoID(db.DefaultContext, unittest.NonexistentID, unittest.NonexistentID)) } func TestUpdateMilestone(t *testing.T) { @@ -326,7 +326,7 @@ func TestUpdateMilestone(t *testing.T) { milestone := unittest.AssertExistsAndLoadBean(t, &issues_model.Milestone{ID: 1}) milestone.Name = " newMilestoneName " milestone.Content = "newMilestoneContent" - assert.NoError(t, issues_model.UpdateMilestone(milestone, milestone.IsClosed)) + assert.NoError(t, issues_model.UpdateMilestone(db.DefaultContext, milestone, milestone.IsClosed)) milestone = unittest.AssertExistsAndLoadBean(t, &issues_model.Milestone{ID: 1}) assert.EqualValues(t, "newMilestoneName", milestone.Name) unittest.CheckConsistencyFor(t, &issues_model.Milestone{}) @@ -361,7 +361,7 @@ func TestMigrate_InsertMilestones(t *testing.T) { RepoID: repo.ID, Name: name, } - err := issues_model.InsertMilestones(ms) + err := issues_model.InsertMilestones(db.DefaultContext, ms) assert.NoError(t, err) unittest.AssertExistsAndLoadBean(t, ms) repoModified := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: repo.ID}) diff --git a/models/issues/stopwatch.go b/models/issues/stopwatch.go index c8cd5ad33f..2c662bdb06 100644 --- a/models/issues/stopwatch.go +++ b/models/issues/stopwatch.go @@ -81,9 +81,9 @@ type UserStopwatch struct { } // GetUIDsAndNotificationCounts between the two provided times -func GetUIDsAndStopwatch() ([]*UserStopwatch, error) { +func GetUIDsAndStopwatch(ctx context.Context) ([]*UserStopwatch, error) { sws := []*Stopwatch{} - if err := db.GetEngine(db.DefaultContext).Where("issue_id != 0").Find(&sws); err != nil { + if err := db.GetEngine(ctx).Where("issue_id != 0").Find(&sws); err != nil { return nil, err } if len(sws) == 0 { @@ -107,9 +107,9 @@ func GetUIDsAndStopwatch() ([]*UserStopwatch, error) { } // GetUserStopwatches return list of all stopwatches of a user -func GetUserStopwatches(userID int64, listOptions db.ListOptions) ([]*Stopwatch, error) { +func GetUserStopwatches(ctx context.Context, userID int64, listOptions db.ListOptions) ([]*Stopwatch, error) { sws := make([]*Stopwatch, 0, 8) - sess := db.GetEngine(db.DefaultContext).Where("stopwatch.user_id = ?", userID) + sess := db.GetEngine(ctx).Where("stopwatch.user_id = ?", userID) if listOptions.Page != 0 { sess = db.SetSessionPagination(sess, &listOptions) } @@ -122,13 +122,13 @@ func GetUserStopwatches(userID int64, listOptions db.ListOptions) ([]*Stopwatch, } // CountUserStopwatches return count of all stopwatches of a user -func CountUserStopwatches(userID int64) (int64, error) { - return db.GetEngine(db.DefaultContext).Where("user_id = ?", userID).Count(&Stopwatch{}) +func CountUserStopwatches(ctx context.Context, userID int64) (int64, error) { + return db.GetEngine(ctx).Where("user_id = ?", userID).Count(&Stopwatch{}) } // StopwatchExists returns true if the stopwatch exists -func StopwatchExists(userID, issueID int64) bool { - _, exists, _ := getStopwatch(db.DefaultContext, userID, issueID) +func StopwatchExists(ctx context.Context, userID, issueID int64) bool { + _, exists, _ := getStopwatch(ctx, userID, issueID) return exists } @@ -168,15 +168,15 @@ func FinishIssueStopwatchIfPossible(ctx context.Context, user *user_model.User, } // CreateOrStopIssueStopwatch create an issue stopwatch if it's not exist, otherwise finish it -func CreateOrStopIssueStopwatch(user *user_model.User, issue *Issue) error { - _, exists, err := getStopwatch(db.DefaultContext, user.ID, issue.ID) +func CreateOrStopIssueStopwatch(ctx context.Context, user *user_model.User, issue *Issue) error { + _, exists, err := getStopwatch(ctx, user.ID, issue.ID) if err != nil { return err } if exists { - return FinishIssueStopwatch(db.DefaultContext, user, issue) + return FinishIssueStopwatch(ctx, user, issue) } - return CreateIssueStopwatch(db.DefaultContext, user, issue) + return CreateIssueStopwatch(ctx, user, issue) } // FinishIssueStopwatch if stopwatch exist then finish it otherwise return an error @@ -269,8 +269,8 @@ func CreateIssueStopwatch(ctx context.Context, user *user_model.User, issue *Iss } // CancelStopwatch removes the given stopwatch and logs it into issue's timeline. -func CancelStopwatch(user *user_model.User, issue *Issue) error { - ctx, committer, err := db.TxContext(db.DefaultContext) +func CancelStopwatch(ctx context.Context, user *user_model.User, issue *Issue) error { + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } diff --git a/models/issues/stopwatch_test.go b/models/issues/stopwatch_test.go index fa937ecbed..39958a7f36 100644 --- a/models/issues/stopwatch_test.go +++ b/models/issues/stopwatch_test.go @@ -26,20 +26,20 @@ func TestCancelStopwatch(t *testing.T) { issue2, err := issues_model.GetIssueByID(db.DefaultContext, 2) assert.NoError(t, err) - err = issues_model.CancelStopwatch(user1, issue1) + err = issues_model.CancelStopwatch(db.DefaultContext, user1, issue1) assert.NoError(t, err) unittest.AssertNotExistsBean(t, &issues_model.Stopwatch{UserID: user1.ID, IssueID: issue1.ID}) _ = unittest.AssertExistsAndLoadBean(t, &issues_model.Comment{Type: issues_model.CommentTypeCancelTracking, PosterID: user1.ID, IssueID: issue1.ID}) - assert.Nil(t, issues_model.CancelStopwatch(user1, issue2)) + assert.Nil(t, issues_model.CancelStopwatch(db.DefaultContext, user1, issue2)) } func TestStopwatchExists(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - assert.True(t, issues_model.StopwatchExists(1, 1)) - assert.False(t, issues_model.StopwatchExists(1, 2)) + assert.True(t, issues_model.StopwatchExists(db.DefaultContext, 1, 1)) + assert.False(t, issues_model.StopwatchExists(db.DefaultContext, 1, 2)) } func TestHasUserStopwatch(t *testing.T) { @@ -68,11 +68,11 @@ func TestCreateOrStopIssueStopwatch(t *testing.T) { issue2, err := issues_model.GetIssueByID(db.DefaultContext, 2) assert.NoError(t, err) - assert.NoError(t, issues_model.CreateOrStopIssueStopwatch(org3, issue1)) + assert.NoError(t, issues_model.CreateOrStopIssueStopwatch(db.DefaultContext, org3, issue1)) sw := unittest.AssertExistsAndLoadBean(t, &issues_model.Stopwatch{UserID: 3, IssueID: 1}) assert.LessOrEqual(t, sw.CreatedUnix, timeutil.TimeStampNow()) - assert.NoError(t, issues_model.CreateOrStopIssueStopwatch(user2, issue2)) + assert.NoError(t, issues_model.CreateOrStopIssueStopwatch(db.DefaultContext, user2, issue2)) unittest.AssertNotExistsBean(t, &issues_model.Stopwatch{UserID: 2, IssueID: 2}) unittest.AssertExistsAndLoadBean(t, &issues_model.TrackedTime{UserID: 2, IssueID: 2}) } diff --git a/models/organization/mini_org.go b/models/organization/mini_org.go index b1627b5e6c..b1b24624c5 100644 --- a/models/organization/mini_org.go +++ b/models/organization/mini_org.go @@ -4,6 +4,7 @@ package organization import ( + "context" "fmt" "strings" @@ -19,7 +20,7 @@ import ( type MinimalOrg = Organization // GetUserOrgsList returns all organizations the given user has access to -func GetUserOrgsList(user *user_model.User) ([]*MinimalOrg, error) { +func GetUserOrgsList(ctx context.Context, user *user_model.User) ([]*MinimalOrg, error) { schema, err := db.TableInfo(new(user_model.User)) if err != nil { return nil, err @@ -42,7 +43,7 @@ func GetUserOrgsList(user *user_model.User) ([]*MinimalOrg, error) { groupByStr := groupByCols.String() groupByStr = groupByStr[0 : len(groupByStr)-1] - sess := db.GetEngine(db.DefaultContext) + sess := db.GetEngine(ctx) sess = sess.Select(groupByStr+", count(distinct repo_id) as org_count"). Table("user"). Join("INNER", "team", "`team`.org_id = `user`.id"). diff --git a/models/repo/archiver.go b/models/repo/archiver.go index 70f53cfe15..6d0ed42877 100644 --- a/models/repo/archiver.go +++ b/models/repo/archiver.go @@ -72,7 +72,7 @@ var delRepoArchiver = new(RepoArchiver) // DeleteRepoArchiver delete archiver func DeleteRepoArchiver(ctx context.Context, archiver *RepoArchiver) error { - _, err := db.GetEngine(db.DefaultContext).ID(archiver.ID).Delete(delRepoArchiver) + _, err := db.GetEngine(ctx).ID(archiver.ID).Delete(delRepoArchiver) return err } @@ -113,8 +113,8 @@ func UpdateRepoArchiverStatus(ctx context.Context, archiver *RepoArchiver) error } // DeleteAllRepoArchives deletes all repo archives records -func DeleteAllRepoArchives() error { - _, err := db.GetEngine(db.DefaultContext).Where("1=1").Delete(new(RepoArchiver)) +func DeleteAllRepoArchives(ctx context.Context) error { + _, err := db.GetEngine(ctx).Where("1=1").Delete(new(RepoArchiver)) return err } @@ -133,10 +133,10 @@ func (opts FindRepoArchiversOption) toConds() builder.Cond { } // FindRepoArchives find repo archivers -func FindRepoArchives(opts FindRepoArchiversOption) ([]*RepoArchiver, error) { +func FindRepoArchives(ctx context.Context, opts FindRepoArchiversOption) ([]*RepoArchiver, error) { archivers := make([]*RepoArchiver, 0, opts.PageSize) start, limit := opts.GetSkipTake() - err := db.GetEngine(db.DefaultContext).Where(opts.toConds()). + err := db.GetEngine(ctx).Where(opts.toConds()). Asc("created_unix"). Limit(limit, start). Find(&archivers) @@ -144,7 +144,7 @@ func FindRepoArchives(opts FindRepoArchiversOption) ([]*RepoArchiver, error) { } // SetArchiveRepoState sets if a repo is archived -func SetArchiveRepoState(repo *Repository, isArchived bool) (err error) { +func SetArchiveRepoState(ctx context.Context, repo *Repository, isArchived bool) (err error) { repo.IsArchived = isArchived if isArchived { @@ -153,6 +153,6 @@ func SetArchiveRepoState(repo *Repository, isArchived bool) (err error) { repo.ArchivedUnix = timeutil.TimeStamp(0) } - _, err = db.GetEngine(db.DefaultContext).ID(repo.ID).Cols("is_archived", "archived_unix").NoAutoTime().Update(repo) + _, err = db.GetEngine(ctx).ID(repo.ID).Cols("is_archived", "archived_unix").NoAutoTime().Update(repo) return err } diff --git a/models/repo/topic.go b/models/repo/topic.go index 71302388b9..ca533fc1e0 100644 --- a/models/repo/topic.go +++ b/models/repo/topic.go @@ -92,9 +92,9 @@ func SanitizeAndValidateTopics(topics []string) (validTopics, invalidTopics []st } // GetTopicByName retrieves topic by name -func GetTopicByName(name string) (*Topic, error) { +func GetTopicByName(ctx context.Context, name string) (*Topic, error) { var topic Topic - if has, err := db.GetEngine(db.DefaultContext).Where("name = ?", name).Get(&topic); err != nil { + if has, err := db.GetEngine(ctx).Where("name = ?", name).Get(&topic); err != nil { return nil, err } else if !has { return nil, ErrTopicNotExist{name} @@ -192,8 +192,8 @@ func (opts *FindTopicOptions) toConds() builder.Cond { } // FindTopics retrieves the topics via FindTopicOptions -func FindTopics(opts *FindTopicOptions) ([]*Topic, int64, error) { - sess := db.GetEngine(db.DefaultContext).Select("topic.*").Where(opts.toConds()) +func FindTopics(ctx context.Context, opts *FindTopicOptions) ([]*Topic, int64, error) { + sess := db.GetEngine(ctx).Select("topic.*").Where(opts.toConds()) orderBy := "topic.repo_count DESC" if opts.RepoID > 0 { sess.Join("INNER", "repo_topic", "repo_topic.topic_id = topic.id") @@ -208,8 +208,8 @@ func FindTopics(opts *FindTopicOptions) ([]*Topic, int64, error) { } // CountTopics counts the number of topics matching the FindTopicOptions -func CountTopics(opts *FindTopicOptions) (int64, error) { - sess := db.GetEngine(db.DefaultContext).Where(opts.toConds()) +func CountTopics(ctx context.Context, opts *FindTopicOptions) (int64, error) { + sess := db.GetEngine(ctx).Where(opts.toConds()) if opts.RepoID > 0 { sess.Join("INNER", "repo_topic", "repo_topic.topic_id = topic.id") } @@ -231,8 +231,8 @@ func GetRepoTopicByName(ctx context.Context, repoID int64, topicName string) (*T } // AddTopic adds a topic name to a repository (if it does not already have it) -func AddTopic(repoID int64, topicName string) (*Topic, error) { - ctx, committer, err := db.TxContext(db.DefaultContext) +func AddTopic(ctx context.Context, repoID int64, topicName string) (*Topic, error) { + ctx, committer, err := db.TxContext(ctx) if err != nil { return nil, err } @@ -261,8 +261,8 @@ func AddTopic(repoID int64, topicName string) (*Topic, error) { } // DeleteTopic removes a topic name from a repository (if it has it) -func DeleteTopic(repoID int64, topicName string) (*Topic, error) { - topic, err := GetRepoTopicByName(db.DefaultContext, repoID, topicName) +func DeleteTopic(ctx context.Context, repoID int64, topicName string) (*Topic, error) { + topic, err := GetRepoTopicByName(ctx, repoID, topicName) if err != nil { return nil, err } @@ -271,26 +271,26 @@ func DeleteTopic(repoID int64, topicName string) (*Topic, error) { return nil, nil } - err = removeTopicFromRepo(db.DefaultContext, repoID, topic) + err = removeTopicFromRepo(ctx, repoID, topic) if err != nil { return nil, err } - err = syncTopicsInRepository(db.GetEngine(db.DefaultContext), repoID) + err = syncTopicsInRepository(db.GetEngine(ctx), repoID) return topic, err } // SaveTopics save topics to a repository -func SaveTopics(repoID int64, topicNames ...string) error { - topics, _, err := FindTopics(&FindTopicOptions{ +func SaveTopics(ctx context.Context, repoID int64, topicNames ...string) error { + topics, _, err := FindTopics(ctx, &FindTopicOptions{ RepoID: repoID, }) if err != nil { return err } - ctx, committer, err := db.TxContext(db.DefaultContext) + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } diff --git a/models/repo/topic_test.go b/models/repo/topic_test.go index aaed91bdd3..2b609e6d66 100644 --- a/models/repo/topic_test.go +++ b/models/repo/topic_test.go @@ -19,47 +19,47 @@ func TestAddTopic(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - topics, _, err := repo_model.FindTopics(&repo_model.FindTopicOptions{}) + topics, _, err := repo_model.FindTopics(db.DefaultContext, &repo_model.FindTopicOptions{}) assert.NoError(t, err) assert.Len(t, topics, totalNrOfTopics) - topics, total, err := repo_model.FindTopics(&repo_model.FindTopicOptions{ + topics, total, err := repo_model.FindTopics(db.DefaultContext, &repo_model.FindTopicOptions{ ListOptions: db.ListOptions{Page: 1, PageSize: 2}, }) assert.NoError(t, err) assert.Len(t, topics, 2) assert.EqualValues(t, 6, total) - topics, _, err = repo_model.FindTopics(&repo_model.FindTopicOptions{ + topics, _, err = repo_model.FindTopics(db.DefaultContext, &repo_model.FindTopicOptions{ RepoID: 1, }) assert.NoError(t, err) assert.Len(t, topics, repo1NrOfTopics) - assert.NoError(t, repo_model.SaveTopics(2, "golang")) + assert.NoError(t, repo_model.SaveTopics(db.DefaultContext, 2, "golang")) repo2NrOfTopics := 1 - topics, _, err = repo_model.FindTopics(&repo_model.FindTopicOptions{}) + topics, _, err = repo_model.FindTopics(db.DefaultContext, &repo_model.FindTopicOptions{}) assert.NoError(t, err) assert.Len(t, topics, totalNrOfTopics) - topics, _, err = repo_model.FindTopics(&repo_model.FindTopicOptions{ + topics, _, err = repo_model.FindTopics(db.DefaultContext, &repo_model.FindTopicOptions{ RepoID: 2, }) assert.NoError(t, err) assert.Len(t, topics, repo2NrOfTopics) - assert.NoError(t, repo_model.SaveTopics(2, "golang", "gitea")) + assert.NoError(t, repo_model.SaveTopics(db.DefaultContext, 2, "golang", "gitea")) repo2NrOfTopics = 2 totalNrOfTopics++ - topic, err := repo_model.GetTopicByName("gitea") + topic, err := repo_model.GetTopicByName(db.DefaultContext, "gitea") assert.NoError(t, err) assert.EqualValues(t, 1, topic.RepoCount) - topics, _, err = repo_model.FindTopics(&repo_model.FindTopicOptions{}) + topics, _, err = repo_model.FindTopics(db.DefaultContext, &repo_model.FindTopicOptions{}) assert.NoError(t, err) assert.Len(t, topics, totalNrOfTopics) - topics, _, err = repo_model.FindTopics(&repo_model.FindTopicOptions{ + topics, _, err = repo_model.FindTopics(db.DefaultContext, &repo_model.FindTopicOptions{ RepoID: 2, }) assert.NoError(t, err) diff --git a/models/repo/update.go b/models/repo/update.go index c4fba32ad2..6ddf1a8905 100644 --- a/models/repo/update.go +++ b/models/repo/update.go @@ -16,11 +16,11 @@ import ( ) // UpdateRepositoryOwnerNames updates repository owner_names (this should only be used when the ownerName has changed case) -func UpdateRepositoryOwnerNames(ownerID int64, ownerName string) error { +func UpdateRepositoryOwnerNames(ctx context.Context, ownerID int64, ownerName string) error { if ownerID == 0 { return nil } - ctx, committer, err := db.TxContext(db.DefaultContext) + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } @@ -36,8 +36,8 @@ func UpdateRepositoryOwnerNames(ownerID int64, ownerName string) error { } // UpdateRepositoryUpdatedTime updates a repository's updated time -func UpdateRepositoryUpdatedTime(repoID int64, updateTime time.Time) error { - _, err := db.GetEngine(db.DefaultContext).Exec("UPDATE repository SET updated_unix = ? WHERE id = ?", updateTime.Unix(), repoID) +func UpdateRepositoryUpdatedTime(ctx context.Context, repoID int64, updateTime time.Time) error { + _, err := db.GetEngine(ctx).Exec("UPDATE repository SET updated_unix = ? WHERE id = ?", updateTime.Unix(), repoID) return err } @@ -107,7 +107,7 @@ func (err ErrRepoFilesAlreadyExist) Unwrap() error { } // CheckCreateRepository check if could created a repository -func CheckCreateRepository(doer, u *user_model.User, name string, overwriteOrAdopt bool) error { +func CheckCreateRepository(ctx context.Context, doer, u *user_model.User, name string, overwriteOrAdopt bool) error { if !doer.CanCreateRepo() { return ErrReachLimitOfRepo{u.MaxRepoCreation} } @@ -116,7 +116,7 @@ func CheckCreateRepository(doer, u *user_model.User, name string, overwriteOrAdo return err } - has, err := IsRepositoryModelOrDirExist(db.DefaultContext, u, name) + has, err := IsRepositoryModelOrDirExist(ctx, u, name) if err != nil { return fmt.Errorf("IsRepositoryExist: %w", err) } else if has { @@ -136,18 +136,18 @@ func CheckCreateRepository(doer, u *user_model.User, name string, overwriteOrAdo } // ChangeRepositoryName changes all corresponding setting from old repository name to new one. -func ChangeRepositoryName(doer *user_model.User, repo *Repository, newRepoName string) (err error) { +func ChangeRepositoryName(ctx context.Context, doer *user_model.User, repo *Repository, newRepoName string) (err error) { oldRepoName := repo.Name newRepoName = strings.ToLower(newRepoName) if err = IsUsableRepoName(newRepoName); err != nil { return err } - if err := repo.LoadOwner(db.DefaultContext); err != nil { + if err := repo.LoadOwner(ctx); err != nil { return err } - has, err := IsRepositoryModelOrDirExist(db.DefaultContext, repo.Owner, newRepoName) + has, err := IsRepositoryModelOrDirExist(ctx, repo.Owner, newRepoName) if err != nil { return fmt.Errorf("IsRepositoryExist: %w", err) } else if has { @@ -171,7 +171,7 @@ func ChangeRepositoryName(doer *user_model.User, repo *Repository, newRepoName s } } - ctx, committer, err := db.TxContext(db.DefaultContext) + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } diff --git a/models/repo_transfer.go b/models/repo_transfer.go index 1c873cec57..630c243c8e 100644 --- a/models/repo_transfer.go +++ b/models/repo_transfer.go @@ -79,8 +79,8 @@ func (r *RepoTransfer) LoadAttributes(ctx context.Context) error { // CanUserAcceptTransfer checks if the user has the rights to accept/decline a repo transfer. // For user, it checks if it's himself // For organizations, it checks if the user is able to create repos -func (r *RepoTransfer) CanUserAcceptTransfer(u *user_model.User) bool { - if err := r.LoadAttributes(db.DefaultContext); err != nil { +func (r *RepoTransfer) CanUserAcceptTransfer(ctx context.Context, u *user_model.User) bool { + if err := r.LoadAttributes(ctx); err != nil { log.Error("LoadAttributes: %v", err) return false } @@ -89,7 +89,7 @@ func (r *RepoTransfer) CanUserAcceptTransfer(u *user_model.User) bool { return r.RecipientID == u.ID } - allowed, err := organization.CanCreateOrgRepo(db.DefaultContext, r.RecipientID, u.ID) + allowed, err := organization.CanCreateOrgRepo(ctx, r.RecipientID, u.ID) if err != nil { log.Error("CanCreateOrgRepo: %v", err) return false @@ -122,8 +122,8 @@ func deleteRepositoryTransfer(ctx context.Context, repoID int64) error { // CancelRepositoryTransfer marks the repository as ready and remove pending transfer entry, // thus cancel the transfer process. -func CancelRepositoryTransfer(repo *repo_model.Repository) error { - ctx, committer, err := db.TxContext(db.DefaultContext) +func CancelRepositoryTransfer(ctx context.Context, repo *repo_model.Repository) error { + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } @@ -199,7 +199,7 @@ func CreatePendingRepositoryTransfer(ctx context.Context, doer, newOwner *user_m } // TransferOwnership transfers all corresponding repository items from old user to new one. -func TransferOwnership(doer *user_model.User, newOwnerName string, repo *repo_model.Repository) (err error) { +func TransferOwnership(ctx context.Context, doer *user_model.User, newOwnerName string, repo *repo_model.Repository) (err error) { repoRenamed := false wikiRenamed := false oldOwnerName := doer.Name @@ -234,7 +234,7 @@ func TransferOwnership(doer *user_model.User, newOwnerName string, repo *repo_mo } }() - ctx, committer, err := db.TxContext(db.DefaultContext) + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } diff --git a/models/repo_transfer_test.go b/models/repo_transfer_test.go index 7364d4d02c..b55cef9473 100644 --- a/models/repo_transfer_test.go +++ b/models/repo_transfer_test.go @@ -25,7 +25,7 @@ func TestRepositoryTransfer(t *testing.T) { assert.NotNil(t, transfer) // Cancel transfer - assert.NoError(t, CancelRepositoryTransfer(repo)) + assert.NoError(t, CancelRepositoryTransfer(db.DefaultContext, repo)) transfer, err = GetPendingRepositoryTransfer(db.DefaultContext, repo) assert.Error(t, err) @@ -53,5 +53,5 @@ func TestRepositoryTransfer(t *testing.T) { assert.Error(t, err) // Cancel transfer - assert.NoError(t, CancelRepositoryTransfer(repo)) + assert.NoError(t, CancelRepositoryTransfer(db.DefaultContext, repo)) } diff --git a/models/user/follow.go b/models/user/follow.go index 7efecc26a7..f4dd2891ff 100644 --- a/models/user/follow.go +++ b/models/user/follow.go @@ -4,6 +4,8 @@ package user import ( + "context" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/timeutil" ) @@ -21,18 +23,18 @@ func init() { } // IsFollowing returns true if user is following followID. -func IsFollowing(userID, followID int64) bool { - has, _ := db.GetEngine(db.DefaultContext).Get(&Follow{UserID: userID, FollowID: followID}) +func IsFollowing(ctx context.Context, userID, followID int64) bool { + has, _ := db.GetEngine(ctx).Get(&Follow{UserID: userID, FollowID: followID}) return has } // FollowUser marks someone be another's follower. -func FollowUser(userID, followID int64) (err error) { - if userID == followID || IsFollowing(userID, followID) { +func FollowUser(ctx context.Context, userID, followID int64) (err error) { + if userID == followID || IsFollowing(ctx, userID, followID) { return nil } - ctx, committer, err := db.TxContext(db.DefaultContext) + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } @@ -53,12 +55,12 @@ func FollowUser(userID, followID int64) (err error) { } // UnfollowUser unmarks someone as another's follower. -func UnfollowUser(userID, followID int64) (err error) { - if userID == followID || !IsFollowing(userID, followID) { +func UnfollowUser(ctx context.Context, userID, followID int64) (err error) { + if userID == followID || !IsFollowing(ctx, userID, followID) { return nil } - ctx, committer, err := db.TxContext(db.DefaultContext) + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } diff --git a/models/user/follow_test.go b/models/user/follow_test.go index fc408d5257..c327d935ae 100644 --- a/models/user/follow_test.go +++ b/models/user/follow_test.go @@ -6,6 +6,7 @@ package user_test import ( "testing" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/models/unittest" user_model "code.gitea.io/gitea/models/user" @@ -14,9 +15,9 @@ import ( func TestIsFollowing(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - assert.True(t, user_model.IsFollowing(4, 2)) - assert.False(t, user_model.IsFollowing(2, 4)) - assert.False(t, user_model.IsFollowing(5, unittest.NonexistentID)) - assert.False(t, user_model.IsFollowing(unittest.NonexistentID, 5)) - assert.False(t, user_model.IsFollowing(unittest.NonexistentID, unittest.NonexistentID)) + assert.True(t, user_model.IsFollowing(db.DefaultContext, 4, 2)) + assert.False(t, user_model.IsFollowing(db.DefaultContext, 2, 4)) + assert.False(t, user_model.IsFollowing(db.DefaultContext, 5, unittest.NonexistentID)) + assert.False(t, user_model.IsFollowing(db.DefaultContext, unittest.NonexistentID, 5)) + assert.False(t, user_model.IsFollowing(db.DefaultContext, unittest.NonexistentID, unittest.NonexistentID)) } diff --git a/models/user/user.go b/models/user/user.go index b3956da1cb..63b95816ce 100644 --- a/models/user/user.go +++ b/models/user/user.go @@ -1246,7 +1246,7 @@ func IsUserVisibleToViewer(ctx context.Context, u, viewer *User) bool { } // If they follow - they see each over - follower := IsFollowing(u.ID, viewer.ID) + follower := IsFollowing(ctx, u.ID, viewer.ID) if follower { return true } diff --git a/models/user/user_test.go b/models/user/user_test.go index b15f0cbc59..971117482c 100644 --- a/models/user/user_test.go +++ b/models/user/user_test.go @@ -449,13 +449,13 @@ func TestFollowUser(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) testSuccess := func(followerID, followedID int64) { - assert.NoError(t, user_model.FollowUser(followerID, followedID)) + assert.NoError(t, user_model.FollowUser(db.DefaultContext, followerID, followedID)) unittest.AssertExistsAndLoadBean(t, &user_model.Follow{UserID: followerID, FollowID: followedID}) } testSuccess(4, 2) testSuccess(5, 2) - assert.NoError(t, user_model.FollowUser(2, 2)) + assert.NoError(t, user_model.FollowUser(db.DefaultContext, 2, 2)) unittest.CheckConsistencyFor(t, &user_model.User{}) } @@ -464,7 +464,7 @@ func TestUnfollowUser(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) testSuccess := func(followerID, followedID int64) { - assert.NoError(t, user_model.UnfollowUser(followerID, followedID)) + assert.NoError(t, user_model.UnfollowUser(db.DefaultContext, followerID, followedID)) unittest.AssertNotExistsBean(t, &user_model.Follow{UserID: followerID, FollowID: followedID}) } testSuccess(4, 2) diff --git a/modules/auth/webauthn/webauthn.go b/modules/auth/webauthn/webauthn.go index e732878f85..189d197333 100644 --- a/modules/auth/webauthn/webauthn.go +++ b/modules/auth/webauthn/webauthn.go @@ -68,7 +68,7 @@ func (u *User) WebAuthnIcon() string { // WebAuthnCredentials implementns the webauthn.User interface func (u *User) WebAuthnCredentials() []webauthn.Credential { - dbCreds, err := auth.GetWebAuthnCredentialsByUID(u.ID) + dbCreds, err := auth.GetWebAuthnCredentialsByUID(db.DefaultContext, u.ID) if err != nil { return nil } diff --git a/modules/context/repo.go b/modules/context/repo.go index f9c966d5be..7355dc9af2 100644 --- a/modules/context/repo.go +++ b/modules/context/repo.go @@ -740,7 +740,7 @@ func RepoAssignment(ctx *Context) context.CancelFunc { ctx.Data["RepoTransfer"] = repoTransfer if ctx.Doer != nil { - ctx.Data["CanUserAcceptTransfer"] = repoTransfer.CanUserAcceptTransfer(ctx.Doer) + ctx.Data["CanUserAcceptTransfer"] = repoTransfer.CanUserAcceptTransfer(ctx, ctx.Doer) } } diff --git a/modules/eventsource/manager_run.go b/modules/eventsource/manager_run.go index 35dfc62f1e..2785836b89 100644 --- a/modules/eventsource/manager_run.go +++ b/modules/eventsource/manager_run.go @@ -84,7 +84,7 @@ loop: then = now if setting.Service.EnableTimetracking { - usersStopwatches, err := issues_model.GetUIDsAndStopwatch() + usersStopwatches, err := issues_model.GetUIDsAndStopwatch(ctx) if err != nil { log.Error("Unable to get GetUIDsAndStopwatch: %v", err) return diff --git a/modules/indexer/issues/db/options.go b/modules/indexer/issues/db/options.go index 0d6a8406d3..4d3fa44ca6 100644 --- a/modules/indexer/issues/db/options.go +++ b/modules/indexer/issues/db/options.go @@ -97,7 +97,7 @@ func ToDBOptions(ctx context.Context, options *internal.SearchOptions) (*issue_m if len(options.IncludedLabelIDs) == 0 && len(options.IncludedAnyLabelIDs) > 0 { _ = ctx // issue_model.GetLabelsByIDs should be called with ctx, this line can be removed when it's done. - labels, err := issue_model.GetLabelsByIDs(options.IncludedAnyLabelIDs, "name") + labels, err := issue_model.GetLabelsByIDs(ctx, options.IncludedAnyLabelIDs, "name") if err != nil { return nil, fmt.Errorf("GetLabelsByIDs: %v", err) } diff --git a/modules/session/db.go b/modules/session/db.go index f86f7d1e9c..9909f2dc1e 100644 --- a/modules/session/db.go +++ b/modules/session/db.go @@ -8,6 +8,7 @@ import ( "sync" "code.gitea.io/gitea/models/auth" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/timeutil" "gitea.com/go-chi/session" @@ -71,7 +72,7 @@ func (s *DBStore) Release() error { return err } - return auth.UpdateSession(s.sid, data) + return auth.UpdateSession(db.DefaultContext, s.sid, data) } // Flush deletes all session data. @@ -97,7 +98,7 @@ func (p *DBProvider) Init(maxLifetime int64, connStr string) error { // Read returns raw session store by session ID. func (p *DBProvider) Read(sid string) (session.RawStore, error) { - s, err := auth.ReadSession(sid) + s, err := auth.ReadSession(db.DefaultContext, sid) if err != nil { return nil, err } @@ -117,7 +118,7 @@ func (p *DBProvider) Read(sid string) (session.RawStore, error) { // Exist returns true if session with given ID exists. func (p *DBProvider) Exist(sid string) bool { - has, err := auth.ExistSession(sid) + has, err := auth.ExistSession(db.DefaultContext, sid) if err != nil { panic("session/DB: error checking existence: " + err.Error()) } @@ -126,12 +127,12 @@ func (p *DBProvider) Exist(sid string) bool { // Destroy deletes a session by session ID. func (p *DBProvider) Destroy(sid string) error { - return auth.DestroySession(sid) + return auth.DestroySession(db.DefaultContext, sid) } // Regenerate regenerates a session store from old session ID to new one. func (p *DBProvider) Regenerate(oldsid, sid string) (_ session.RawStore, err error) { - s, err := auth.RegenerateSession(oldsid, sid) + s, err := auth.RegenerateSession(db.DefaultContext, oldsid, sid) if err != nil { return nil, err } @@ -151,7 +152,7 @@ func (p *DBProvider) Regenerate(oldsid, sid string) (_ session.RawStore, err err // Count counts and returns number of sessions. func (p *DBProvider) Count() int { - total, err := auth.CountSessions() + total, err := auth.CountSessions(db.DefaultContext) if err != nil { panic("session/DB: error counting records: " + err.Error()) } @@ -160,7 +161,7 @@ func (p *DBProvider) Count() int { // GC calls GC to clean expired sessions. func (p *DBProvider) GC() { - if err := auth.CleanupSessions(p.maxLifetime); err != nil { + if err := auth.CleanupSessions(db.DefaultContext, p.maxLifetime); err != nil { log.Printf("session/DB: error garbage collecting: %v", err) } } diff --git a/routers/api/v1/org/label.go b/routers/api/v1/org/label.go index 2dd4505a91..5a03059ded 100644 --- a/routers/api/v1/org/label.go +++ b/routers/api/v1/org/label.go @@ -50,7 +50,7 @@ func ListLabels(ctx *context.APIContext) { return } - count, err := issues_model.CountLabelsByOrgID(ctx.Org.Organization.ID) + count, err := issues_model.CountLabelsByOrgID(ctx, ctx.Org.Organization.ID) if err != nil { ctx.InternalServerError(err) return @@ -218,7 +218,7 @@ func EditLabel(ctx *context.APIContext) { l.Description = *form.Description } l.SetArchived(form.IsArchived != nil && *form.IsArchived) - if err := issues_model.UpdateLabel(l); err != nil { + if err := issues_model.UpdateLabel(ctx, l); err != nil { ctx.Error(http.StatusInternalServerError, "UpdateLabel", err) return } @@ -249,7 +249,7 @@ func DeleteLabel(ctx *context.APIContext) { // "404": // "$ref": "#/responses/notFound" - if err := issues_model.DeleteLabel(ctx.Org.Organization.ID, ctx.ParamsInt64(":id")); err != nil { + if err := issues_model.DeleteLabel(ctx, ctx.Org.Organization.ID, ctx.ParamsInt64(":id")); err != nil { ctx.Error(http.StatusInternalServerError, "DeleteLabel", err) return } diff --git a/routers/api/v1/repo/collaborators.go b/routers/api/v1/repo/collaborators.go index 4be43d46ad..206e3fb29b 100644 --- a/routers/api/v1/repo/collaborators.go +++ b/routers/api/v1/repo/collaborators.go @@ -234,7 +234,7 @@ func DeleteCollaborator(ctx *context.APIContext) { return } - if err := repo_service.DeleteCollaboration(ctx.Repo.Repository, collaborator.ID); err != nil { + if err := repo_service.DeleteCollaboration(ctx, ctx.Repo.Repository, collaborator.ID); err != nil { ctx.Error(http.StatusInternalServerError, "DeleteCollaboration", err) return } diff --git a/routers/api/v1/repo/issue.go b/routers/api/v1/repo/issue.go index c6248bacec..05dfa45e3d 100644 --- a/routers/api/v1/repo/issue.go +++ b/routers/api/v1/repo/issue.go @@ -413,7 +413,7 @@ func ListIssues(ctx *context.APIContext) { var labelIDs []int64 if splitted := strings.Split(ctx.FormString("labels"), ","); len(splitted) > 0 { - labelIDs, err = issues_model.GetLabelIDsInRepoByNames(ctx.Repo.Repository.ID, splitted) + labelIDs, err = issues_model.GetLabelIDsInRepoByNames(ctx, ctx.Repo.Repository.ID, splitted) if err != nil { ctx.Error(http.StatusInternalServerError, "GetLabelIDsInRepoByNames", err) return @@ -425,7 +425,7 @@ func ListIssues(ctx *context.APIContext) { for i := range part { // uses names and fall back to ids // non existent milestones are discarded - mile, err := issues_model.GetMilestoneByRepoIDANDName(ctx.Repo.Repository.ID, part[i]) + mile, err := issues_model.GetMilestoneByRepoIDANDName(ctx, ctx.Repo.Repository.ID, part[i]) if err == nil { mileIDs = append(mileIDs, mile.ID) continue diff --git a/routers/api/v1/repo/issue_label.go b/routers/api/v1/repo/issue_label.go index b050a397f2..2f9ad7060c 100644 --- a/routers/api/v1/repo/issue_label.go +++ b/routers/api/v1/repo/issue_label.go @@ -107,7 +107,7 @@ func AddIssueLabels(ctx *context.APIContext) { return } - if err = issue_service.AddLabels(issue, ctx.Doer, labels); err != nil { + if err = issue_service.AddLabels(ctx, issue, ctx.Doer, labels); err != nil { ctx.Error(http.StatusInternalServerError, "AddLabels", err) return } @@ -186,7 +186,7 @@ func DeleteIssueLabel(ctx *context.APIContext) { return } - if err := issue_service.RemoveLabel(issue, ctx.Doer, label); err != nil { + if err := issue_service.RemoveLabel(ctx, issue, ctx.Doer, label); err != nil { ctx.Error(http.StatusInternalServerError, "DeleteIssueLabel", err) return } @@ -237,7 +237,7 @@ func ReplaceIssueLabels(ctx *context.APIContext) { return } - if err := issue_service.ReplaceLabels(issue, ctx.Doer, labels); err != nil { + if err := issue_service.ReplaceLabels(ctx, issue, ctx.Doer, labels); err != nil { ctx.Error(http.StatusInternalServerError, "ReplaceLabels", err) return } @@ -298,7 +298,7 @@ func ClearIssueLabels(ctx *context.APIContext) { return } - if err := issue_service.ClearLabels(issue, ctx.Doer); err != nil { + if err := issue_service.ClearLabels(ctx, issue, ctx.Doer); err != nil { ctx.Error(http.StatusInternalServerError, "ClearLabels", err) return } @@ -317,7 +317,7 @@ func prepareForReplaceOrAdd(ctx *context.APIContext, form api.IssueLabelsOption) return nil, nil, err } - labels, err := issues_model.GetLabelsByIDs(form.Labels, "id", "repo_id", "org_id") + labels, err := issues_model.GetLabelsByIDs(ctx, form.Labels, "id", "repo_id", "org_id") if err != nil { ctx.Error(http.StatusInternalServerError, "GetLabelsByIDs", err) return nil, nil, err diff --git a/routers/api/v1/repo/issue_stopwatch.go b/routers/api/v1/repo/issue_stopwatch.go index 75fa863138..384532ab87 100644 --- a/routers/api/v1/repo/issue_stopwatch.go +++ b/routers/api/v1/repo/issue_stopwatch.go @@ -152,7 +152,7 @@ func DeleteIssueStopwatch(ctx *context.APIContext) { return } - if err := issues_model.CancelStopwatch(ctx.Doer, issue); err != nil { + if err := issues_model.CancelStopwatch(ctx, ctx.Doer, issue); err != nil { ctx.Error(http.StatusInternalServerError, "CancelStopwatch", err) return } @@ -182,7 +182,7 @@ func prepareIssueStopwatch(ctx *context.APIContext, shouldExist bool) (*issues_m return nil, errors.New("Cannot use time tracker") } - if issues_model.StopwatchExists(ctx.Doer.ID, issue.ID) != shouldExist { + if issues_model.StopwatchExists(ctx, ctx.Doer.ID, issue.ID) != shouldExist { if shouldExist { ctx.Error(http.StatusConflict, "StopwatchExists", "cannot stop/cancel a non existent stopwatch") err = errors.New("cannot stop/cancel a non existent stopwatch") @@ -218,13 +218,13 @@ func GetStopwatches(ctx *context.APIContext) { // "200": // "$ref": "#/responses/StopWatchList" - sws, err := issues_model.GetUserStopwatches(ctx.Doer.ID, utils.GetListOptions(ctx)) + sws, err := issues_model.GetUserStopwatches(ctx, ctx.Doer.ID, utils.GetListOptions(ctx)) if err != nil { ctx.Error(http.StatusInternalServerError, "GetUserStopwatches", err) return } - count, err := issues_model.CountUserStopwatches(ctx.Doer.ID) + count, err := issues_model.CountUserStopwatches(ctx, ctx.Doer.ID) if err != nil { ctx.InternalServerError(err) return diff --git a/routers/api/v1/repo/issue_subscription.go b/routers/api/v1/repo/issue_subscription.go index 1fec029465..ab9a037040 100644 --- a/routers/api/v1/repo/issue_subscription.go +++ b/routers/api/v1/repo/issue_subscription.go @@ -132,7 +132,7 @@ func setIssueSubscription(ctx *context.APIContext, watch bool) { return } - current, err := issues_model.CheckIssueWatch(user, issue) + current, err := issues_model.CheckIssueWatch(ctx, user, issue) if err != nil { ctx.Error(http.StatusInternalServerError, "CheckIssueWatch", err) return @@ -145,7 +145,7 @@ func setIssueSubscription(ctx *context.APIContext, watch bool) { } // Update watch state - if err := issues_model.CreateOrUpdateIssueWatch(user.ID, issue.ID, watch); err != nil { + if err := issues_model.CreateOrUpdateIssueWatch(ctx, user.ID, issue.ID, watch); err != nil { ctx.Error(http.StatusInternalServerError, "CreateOrUpdateIssueWatch", err) return } @@ -196,7 +196,7 @@ func CheckIssueSubscription(ctx *context.APIContext) { return } - watching, err := issues_model.CheckIssueWatch(ctx.Doer, issue) + watching, err := issues_model.CheckIssueWatch(ctx, ctx.Doer, issue) if err != nil { ctx.InternalServerError(err) return diff --git a/routers/api/v1/repo/label.go b/routers/api/v1/repo/label.go index e93c72a9f5..420d3ab5b4 100644 --- a/routers/api/v1/repo/label.go +++ b/routers/api/v1/repo/label.go @@ -55,7 +55,7 @@ func ListLabels(ctx *context.APIContext) { return } - count, err := issues_model.CountLabelsByRepoID(ctx.Repo.Repository.ID) + count, err := issues_model.CountLabelsByRepoID(ctx, ctx.Repo.Repository.ID) if err != nil { ctx.InternalServerError(err) return @@ -240,7 +240,7 @@ func EditLabel(ctx *context.APIContext) { l.Description = *form.Description } l.SetArchived(form.IsArchived != nil && *form.IsArchived) - if err := issues_model.UpdateLabel(l); err != nil { + if err := issues_model.UpdateLabel(ctx, l); err != nil { ctx.Error(http.StatusInternalServerError, "UpdateLabel", err) return } @@ -276,7 +276,7 @@ func DeleteLabel(ctx *context.APIContext) { // "404": // "$ref": "#/responses/notFound" - if err := issues_model.DeleteLabel(ctx.Repo.Repository.ID, ctx.ParamsInt64(":id")); err != nil { + if err := issues_model.DeleteLabel(ctx, ctx.Repo.Repository.ID, ctx.ParamsInt64(":id")); err != nil { ctx.Error(http.StatusInternalServerError, "DeleteLabel", err) return } diff --git a/routers/api/v1/repo/milestone.go b/routers/api/v1/repo/milestone.go index fff9493a23..1a86444660 100644 --- a/routers/api/v1/repo/milestone.go +++ b/routers/api/v1/repo/milestone.go @@ -163,7 +163,7 @@ func CreateMilestone(ctx *context.APIContext) { milestone.ClosedDateUnix = timeutil.TimeStampNow() } - if err := issues_model.NewMilestone(milestone); err != nil { + if err := issues_model.NewMilestone(ctx, milestone); err != nil { ctx.Error(http.StatusInternalServerError, "NewMilestone", err) return } @@ -225,7 +225,7 @@ func EditMilestone(ctx *context.APIContext) { milestone.IsClosed = *form.State == string(api.StateClosed) } - if err := issues_model.UpdateMilestone(milestone, oldIsClosed); err != nil { + if err := issues_model.UpdateMilestone(ctx, milestone, oldIsClosed); err != nil { ctx.Error(http.StatusInternalServerError, "UpdateMilestone", err) return } @@ -264,7 +264,7 @@ func DeleteMilestone(ctx *context.APIContext) { return } - if err := issues_model.DeleteMilestoneByRepoID(ctx.Repo.Repository.ID, m.ID); err != nil { + if err := issues_model.DeleteMilestoneByRepoID(ctx, ctx.Repo.Repository.ID, m.ID); err != nil { ctx.Error(http.StatusInternalServerError, "DeleteMilestoneByRepoID", err) return } @@ -286,7 +286,7 @@ func getMilestoneByIDOrName(ctx *context.APIContext) *issues_model.Milestone { } } - milestone, err := issues_model.GetMilestoneByRepoIDANDName(ctx.Repo.Repository.ID, mile) + milestone, err := issues_model.GetMilestoneByRepoIDANDName(ctx, ctx.Repo.Repository.ID, mile) if err != nil { if issues_model.IsErrMilestoneNotExist(err) { ctx.NotFound() diff --git a/routers/api/v1/repo/repo.go b/routers/api/v1/repo/repo.go index e06fc2df66..5f25fdce14 100644 --- a/routers/api/v1/repo/repo.go +++ b/routers/api/v1/repo/repo.go @@ -1003,14 +1003,14 @@ func updateRepoArchivedState(ctx *context.APIContext, opts api.EditRepoOption) e return err } if *opts.Archived { - if err := repo_model.SetArchiveRepoState(repo, *opts.Archived); err != nil { + if err := repo_model.SetArchiveRepoState(ctx, repo, *opts.Archived); err != nil { log.Error("Tried to archive a repo: %s", err) ctx.Error(http.StatusInternalServerError, "ArchiveRepoState", err) return err } log.Trace("Repository was archived: %s/%s", ctx.Repo.Owner.Name, repo.Name) } else { - if err := repo_model.SetArchiveRepoState(repo, *opts.Archived); err != nil { + if err := repo_model.SetArchiveRepoState(ctx, repo, *opts.Archived); err != nil { log.Error("Tried to un-archive a repo: %s", err) ctx.Error(http.StatusInternalServerError, "ArchiveRepoState", err) return err diff --git a/routers/api/v1/repo/topic.go b/routers/api/v1/repo/topic.go index c0c05154c4..d662b9b583 100644 --- a/routers/api/v1/repo/topic.go +++ b/routers/api/v1/repo/topic.go @@ -53,7 +53,7 @@ func ListTopics(ctx *context.APIContext) { RepoID: ctx.Repo.Repository.ID, } - topics, total, err := repo_model.FindTopics(opts) + topics, total, err := repo_model.FindTopics(ctx, opts) if err != nil { ctx.InternalServerError(err) return @@ -120,7 +120,7 @@ func UpdateTopics(ctx *context.APIContext) { return } - err := repo_model.SaveTopics(ctx.Repo.Repository.ID, validTopics...) + err := repo_model.SaveTopics(ctx, ctx.Repo.Repository.ID, validTopics...) if err != nil { log.Error("SaveTopics failed: %v", err) ctx.InternalServerError(err) @@ -172,7 +172,7 @@ func AddTopic(ctx *context.APIContext) { } // Prevent adding more topics than allowed to repo - count, err := repo_model.CountTopics(&repo_model.FindTopicOptions{ + count, err := repo_model.CountTopics(ctx, &repo_model.FindTopicOptions{ RepoID: ctx.Repo.Repository.ID, }) if err != nil { @@ -187,7 +187,7 @@ func AddTopic(ctx *context.APIContext) { return } - _, err = repo_model.AddTopic(ctx.Repo.Repository.ID, topicName) + _, err = repo_model.AddTopic(ctx, ctx.Repo.Repository.ID, topicName) if err != nil { log.Error("AddTopic failed: %v", err) ctx.InternalServerError(err) @@ -238,7 +238,7 @@ func DeleteTopic(ctx *context.APIContext) { return } - topic, err := repo_model.DeleteTopic(ctx.Repo.Repository.ID, topicName) + topic, err := repo_model.DeleteTopic(ctx, ctx.Repo.Repository.ID, topicName) if err != nil { log.Error("DeleteTopic failed: %v", err) ctx.InternalServerError(err) @@ -287,7 +287,7 @@ func TopicSearch(ctx *context.APIContext) { ListOptions: utils.GetListOptions(ctx), } - topics, total, err := repo_model.FindTopics(opts) + topics, total, err := repo_model.FindTopics(ctx, opts) if err != nil { ctx.InternalServerError(err) return diff --git a/routers/api/v1/repo/transfer.go b/routers/api/v1/repo/transfer.go index 8ff22a1193..326895918e 100644 --- a/routers/api/v1/repo/transfer.go +++ b/routers/api/v1/repo/transfer.go @@ -221,7 +221,7 @@ func acceptOrRejectRepoTransfer(ctx *context.APIContext, accept bool) error { return err } - if !repoTransfer.CanUserAcceptTransfer(ctx.Doer) { + if !repoTransfer.CanUserAcceptTransfer(ctx, ctx.Doer) { ctx.Error(http.StatusForbidden, "CanUserAcceptTransfer", nil) return fmt.Errorf("user does not have permissions to do this") } @@ -230,5 +230,5 @@ func acceptOrRejectRepoTransfer(ctx *context.APIContext, accept bool) error { return repo_service.TransferOwnership(ctx, repoTransfer.Doer, repoTransfer.Recipient, ctx.Repo.Repository, repoTransfer.Teams) } - return models.CancelRepositoryTransfer(ctx.Repo.Repository) + return models.CancelRepositoryTransfer(ctx, ctx.Repo.Repository) } diff --git a/routers/api/v1/user/follower.go b/routers/api/v1/user/follower.go index 1aa906ccb1..5815ed4f0b 100644 --- a/routers/api/v1/user/follower.go +++ b/routers/api/v1/user/follower.go @@ -151,7 +151,7 @@ func ListFollowing(ctx *context.APIContext) { } func checkUserFollowing(ctx *context.APIContext, u *user_model.User, followID int64) { - if user_model.IsFollowing(u.ID, followID) { + if user_model.IsFollowing(ctx, u.ID, followID) { ctx.Status(http.StatusNoContent) } else { ctx.NotFound() @@ -224,7 +224,7 @@ func Follow(ctx *context.APIContext) { // "404": // "$ref": "#/responses/notFound" - if err := user_model.FollowUser(ctx.Doer.ID, ctx.ContextUser.ID); err != nil { + if err := user_model.FollowUser(ctx, ctx.Doer.ID, ctx.ContextUser.ID); err != nil { ctx.Error(http.StatusInternalServerError, "FollowUser", err) return } @@ -248,7 +248,7 @@ func Unfollow(ctx *context.APIContext) { // "404": // "$ref": "#/responses/notFound" - if err := user_model.UnfollowUser(ctx.Doer.ID, ctx.ContextUser.ID); err != nil { + if err := user_model.UnfollowUser(ctx, ctx.Doer.ID, ctx.ContextUser.ID); err != nil { ctx.Error(http.StatusInternalServerError, "UnfollowUser", err) return } diff --git a/routers/init.go b/routers/init.go index 6369a39754..150a5c56f2 100644 --- a/routers/init.go +++ b/routers/init.go @@ -140,7 +140,7 @@ func InitWebInstalled(ctx context.Context) { mustInitCtx(ctx, models.Init) mustInitCtx(ctx, authmodel.Init) - mustInit(repo_service.Init) + mustInitCtx(ctx, repo_service.Init) // Booting long running goroutines. mustInit(indexer_service.Init) diff --git a/routers/web/admin/users.go b/routers/web/admin/users.go index 5562cc390c..af49b00ad6 100644 --- a/routers/web/admin/users.go +++ b/routers/web/admin/users.go @@ -243,7 +243,7 @@ func prepareUserInfo(ctx *context.Context) *user_model.User { ctx.ServerError("auth.HasTwoFactorByUID", err) return nil } - hasWebAuthn, err := auth.HasWebAuthnRegistrationsByUID(u.ID) + hasWebAuthn, err := auth.HasWebAuthnRegistrationsByUID(ctx, u.ID) if err != nil { ctx.ServerError("auth.HasWebAuthnRegistrationsByUID", err) return nil @@ -421,13 +421,13 @@ func EditUserPost(ctx *context.Context) { } } - wn, err := auth.GetWebAuthnCredentialsByUID(u.ID) + wn, err := auth.GetWebAuthnCredentialsByUID(ctx, u.ID) if err != nil { ctx.ServerError("auth.GetTwoFactorByUID", err) return } for _, cred := range wn { - if _, err := auth.DeleteCredential(cred.ID, u.ID); err != nil { + if _, err := auth.DeleteCredential(ctx, cred.ID, u.ID); err != nil { ctx.ServerError("auth.DeleteCredential", err) return } diff --git a/routers/web/auth/auth.go b/routers/web/auth/auth.go index b7a73e4379..8017602d99 100644 --- a/routers/web/auth/auth.go +++ b/routers/web/auth/auth.go @@ -243,7 +243,7 @@ func SignInPost(ctx *context.Context) { } // Check if the user has webauthn registration - hasWebAuthnTwofa, err := auth.HasWebAuthnRegistrationsByUID(u.ID) + hasWebAuthnTwofa, err := auth.HasWebAuthnRegistrationsByUID(ctx, u.ID) if err != nil { ctx.ServerError("UserSignIn", err) return diff --git a/routers/web/auth/linkaccount.go b/routers/web/auth/linkaccount.go index 745b4e818c..c6e3d1231b 100644 --- a/routers/web/auth/linkaccount.go +++ b/routers/web/auth/linkaccount.go @@ -185,7 +185,7 @@ func linkAccount(ctx *context.Context, u *user_model.User, gothUser goth.User, r } // If WebAuthn is enrolled -> Redirect to WebAuthn instead - regs, err := auth.GetWebAuthnCredentialsByUID(u.ID) + regs, err := auth.GetWebAuthnCredentialsByUID(ctx, u.ID) if err == nil && len(regs) > 0 { ctx.Redirect(setting.AppSubURL + "/user/webauthn") return diff --git a/routers/web/auth/oauth.go b/routers/web/auth/oauth.go index 640c01e203..40c91b3f85 100644 --- a/routers/web/auth/oauth.go +++ b/routers/web/auth/oauth.go @@ -237,7 +237,7 @@ func newAccessTokenResponse(ctx go_context.Context, grant *auth.OAuth2Grant, ser idToken.EmailVerified = user.IsActive } if grant.ScopeContains("groups") { - groups, err := getOAuthGroupsForUser(user) + groups, err := getOAuthGroupsForUser(ctx, user) if err != nil { log.Error("Error getting groups: %v", err) return nil, &AccessTokenError{ @@ -291,7 +291,7 @@ func InfoOAuth(ctx *context.Context) { Picture: ctx.Doer.AvatarLink(ctx), } - groups, err := getOAuthGroupsForUser(ctx.Doer) + groups, err := getOAuthGroupsForUser(ctx, ctx.Doer) if err != nil { ctx.ServerError("Oauth groups for user", err) return @@ -303,8 +303,8 @@ func InfoOAuth(ctx *context.Context) { // returns a list of "org" and "org:team" strings, // that the given user is a part of. -func getOAuthGroupsForUser(user *user_model.User) ([]string, error) { - orgs, err := org_model.GetUserOrgsList(user) +func getOAuthGroupsForUser(ctx go_context.Context, user *user_model.User) ([]string, error) { + orgs, err := org_model.GetUserOrgsList(ctx, user) if err != nil { return nil, fmt.Errorf("GetUserOrgList: %w", err) } @@ -1197,7 +1197,7 @@ func handleOAuth2SignIn(ctx *context.Context, source *auth.Source, u *user_model } // If WebAuthn is enrolled -> Redirect to WebAuthn instead - regs, err := auth.GetWebAuthnCredentialsByUID(u.ID) + regs, err := auth.GetWebAuthnCredentialsByUID(ctx, u.ID) if err == nil && len(regs) > 0 { ctx.Redirect(setting.AppSubURL + "/user/webauthn") return diff --git a/routers/web/auth/webauthn.go b/routers/web/auth/webauthn.go index 013e11eacc..b19e18aa8e 100644 --- a/routers/web/auth/webauthn.go +++ b/routers/web/auth/webauthn.go @@ -55,7 +55,7 @@ func WebAuthnLoginAssertion(ctx *context.Context) { return } - exists, err := auth.ExistsWebAuthnCredentialsForUID(user.ID) + exists, err := auth.ExistsWebAuthnCredentialsForUID(ctx, user.ID) if err != nil { ctx.ServerError("UserSignIn", err) return @@ -127,14 +127,14 @@ func WebAuthnLoginAssertionPost(ctx *context.Context) { } // Success! Get the credential and update the sign count with the new value we received. - dbCred, err := auth.GetWebAuthnCredentialByCredID(user.ID, cred.ID) + dbCred, err := auth.GetWebAuthnCredentialByCredID(ctx, user.ID, cred.ID) if err != nil { ctx.ServerError("GetWebAuthnCredentialByCredID", err) return } dbCred.SignCount = cred.Authenticator.SignCount - if err := dbCred.UpdateSignCount(); err != nil { + if err := dbCred.UpdateSignCount(ctx); err != nil { ctx.ServerError("UpdateSignCount", err) return } diff --git a/routers/web/explore/topic.go b/routers/web/explore/topic.go index 132ef23fa7..bb1be310de 100644 --- a/routers/web/explore/topic.go +++ b/routers/web/explore/topic.go @@ -23,7 +23,7 @@ func TopicSearch(ctx *context.Context) { }, } - topics, total, err := repo_model.FindTopics(opts) + topics, total, err := repo_model.FindTopics(ctx, opts) if err != nil { ctx.Error(http.StatusInternalServerError) return diff --git a/routers/web/org/home.go b/routers/web/org/home.go index 15386393e9..ec866eb6b3 100644 --- a/routers/web/org/home.go +++ b/routers/web/org/home.go @@ -131,7 +131,7 @@ func Home(ctx *context.Context) { var isFollowing bool if ctx.Doer != nil { - isFollowing = user_model.IsFollowing(ctx.Doer.ID, ctx.ContextUser.ID) + isFollowing = user_model.IsFollowing(ctx, ctx.Doer.ID, ctx.ContextUser.ID) } ctx.Data["Repos"] = repos diff --git a/routers/web/org/org_labels.go b/routers/web/org/org_labels.go index 2c7725e38d..f78bd00274 100644 --- a/routers/web/org/org_labels.go +++ b/routers/web/org/org_labels.go @@ -76,7 +76,7 @@ func UpdateLabel(ctx *context.Context) { l.Description = form.Description l.Color = form.Color l.SetArchived(form.IsArchived) - if err := issues_model.UpdateLabel(l); err != nil { + if err := issues_model.UpdateLabel(ctx, l); err != nil { ctx.ServerError("UpdateLabel", err) return } @@ -85,7 +85,7 @@ func UpdateLabel(ctx *context.Context) { // DeleteLabel delete a label func DeleteLabel(ctx *context.Context) { - if err := issues_model.DeleteLabel(ctx.Org.Organization.ID, ctx.FormInt64("id")); err != nil { + if err := issues_model.DeleteLabel(ctx, ctx.Org.Organization.ID, ctx.FormInt64("id")); err != nil { ctx.Flash.Error("DeleteLabel: " + err.Error()) } else { ctx.Flash.Success(ctx.Tr("repo.issues.label_deletion_success")) diff --git a/routers/web/repo/issue.go b/routers/web/repo/issue.go index 94c9382f23..f4aa357fac 100644 --- a/routers/web/repo/issue.go +++ b/routers/web/repo/issue.go @@ -1412,7 +1412,7 @@ func ViewIssue(ctx *context.Context) { if ctx.Doer != nil { iw.UserID = ctx.Doer.ID iw.IssueID = issue.ID - iw.IsWatching, err = issues_model.CheckIssueWatch(ctx.Doer, issue) + iw.IsWatching, err = issues_model.CheckIssueWatch(ctx, ctx.Doer, issue) if err != nil { ctx.ServerError("CheckIssueWatch", err) return @@ -1530,7 +1530,7 @@ func ViewIssue(ctx *context.Context) { if ctx.Repo.Repository.IsTimetrackerEnabled(ctx) { if ctx.IsSigned { // Deal with the stopwatch - ctx.Data["IsStopwatchRunning"] = issues_model.StopwatchExists(ctx.Doer.ID, issue.ID) + ctx.Data["IsStopwatchRunning"] = issues_model.StopwatchExists(ctx, ctx.Doer.ID, issue.ID) if !ctx.Data["IsStopwatchRunning"].(bool) { var exists bool var swIssue *issues_model.Issue @@ -2708,7 +2708,7 @@ func ListIssues(ctx *context.Context) { var labelIDs []int64 if splitted := strings.Split(ctx.FormString("labels"), ","); len(splitted) > 0 { - labelIDs, err = issues_model.GetLabelIDsInRepoByNames(ctx.Repo.Repository.ID, splitted) + labelIDs, err = issues_model.GetLabelIDsInRepoByNames(ctx, ctx.Repo.Repository.ID, splitted) if err != nil { ctx.Error(http.StatusInternalServerError, err.Error()) return @@ -2720,7 +2720,7 @@ func ListIssues(ctx *context.Context) { for i := range part { // uses names and fall back to ids // non existent milestones are discarded - mile, err := issues_model.GetMilestoneByRepoIDANDName(ctx.Repo.Repository.ID, part[i]) + mile, err := issues_model.GetMilestoneByRepoIDANDName(ctx, ctx.Repo.Repository.ID, part[i]) if err == nil { mileIDs = append(mileIDs, mile.ID) continue @@ -3037,7 +3037,7 @@ func NewComment(ctx *context.Context) { return } } else { - if err := stopTimerIfAvailable(ctx.Doer, issue); err != nil { + if err := stopTimerIfAvailable(ctx, ctx.Doer, issue); err != nil { ctx.ServerError("CreateOrStopIssueStopwatch", err) return } diff --git a/routers/web/repo/issue_label.go b/routers/web/repo/issue_label.go index 257610d3af..2d129490f5 100644 --- a/routers/web/repo/issue_label.go +++ b/routers/web/repo/issue_label.go @@ -145,7 +145,7 @@ func UpdateLabel(ctx *context.Context) { l.Color = form.Color l.SetArchived(form.IsArchived) - if err := issues_model.UpdateLabel(l); err != nil { + if err := issues_model.UpdateLabel(ctx, l); err != nil { ctx.ServerError("UpdateLabel", err) return } @@ -154,7 +154,7 @@ func UpdateLabel(ctx *context.Context) { // DeleteLabel delete a label func DeleteLabel(ctx *context.Context) { - if err := issues_model.DeleteLabel(ctx.Repo.Repository.ID, ctx.FormInt64("id")); err != nil { + if err := issues_model.DeleteLabel(ctx, ctx.Repo.Repository.ID, ctx.FormInt64("id")); err != nil { ctx.Flash.Error("DeleteLabel: " + err.Error()) } else { ctx.Flash.Success(ctx.Tr("repo.issues.label_deletion_success")) @@ -173,7 +173,7 @@ func UpdateIssueLabel(ctx *context.Context) { switch action := ctx.FormString("action"); action { case "clear": for _, issue := range issues { - if err := issue_service.ClearLabels(issue, ctx.Doer); err != nil { + if err := issue_service.ClearLabels(ctx, issue, ctx.Doer); err != nil { ctx.ServerError("ClearLabels", err) return } @@ -208,14 +208,14 @@ func UpdateIssueLabel(ctx *context.Context) { if action == "attach" { for _, issue := range issues { - if err = issue_service.AddLabel(issue, ctx.Doer, label); err != nil { + if err = issue_service.AddLabel(ctx, issue, ctx.Doer, label); err != nil { ctx.ServerError("AddLabel", err) return } } } else { for _, issue := range issues { - if err = issue_service.RemoveLabel(issue, ctx.Doer, label); err != nil { + if err = issue_service.RemoveLabel(ctx, issue, ctx.Doer, label); err != nil { ctx.ServerError("RemoveLabel", err) return } diff --git a/routers/web/repo/issue_stopwatch.go b/routers/web/repo/issue_stopwatch.go index 3e715437e6..d42af57329 100644 --- a/routers/web/repo/issue_stopwatch.go +++ b/routers/web/repo/issue_stopwatch.go @@ -22,7 +22,7 @@ func IssueStopwatch(c *context.Context) { var showSuccessMessage bool - if !issues_model.StopwatchExists(c.Doer.ID, issue.ID) { + if !issues_model.StopwatchExists(c, c.Doer.ID, issue.ID) { showSuccessMessage = true } @@ -31,7 +31,7 @@ func IssueStopwatch(c *context.Context) { return } - if err := issues_model.CreateOrStopIssueStopwatch(c.Doer, issue); err != nil { + if err := issues_model.CreateOrStopIssueStopwatch(c, c.Doer, issue); err != nil { c.ServerError("CreateOrStopIssueStopwatch", err) return } @@ -55,12 +55,12 @@ func CancelStopwatch(c *context.Context) { return } - if err := issues_model.CancelStopwatch(c.Doer, issue); err != nil { + if err := issues_model.CancelStopwatch(c, c.Doer, issue); err != nil { c.ServerError("CancelStopwatch", err) return } - stopwatches, err := issues_model.GetUserStopwatches(c.Doer.ID, db.ListOptions{}) + stopwatches, err := issues_model.GetUserStopwatches(c, c.Doer.ID, db.ListOptions{}) if err != nil { c.ServerError("GetUserStopwatches", err) return diff --git a/routers/web/repo/issue_watch.go b/routers/web/repo/issue_watch.go index d3d3a2af21..1cb5cc7162 100644 --- a/routers/web/repo/issue_watch.go +++ b/routers/web/repo/issue_watch.go @@ -47,7 +47,7 @@ func IssueWatch(ctx *context.Context) { return } - if err := issues_model.CreateOrUpdateIssueWatch(ctx.Doer.ID, issue.ID, watch); err != nil { + if err := issues_model.CreateOrUpdateIssueWatch(ctx, ctx.Doer.ID, issue.ID, watch); err != nil { ctx.ServerError("CreateOrUpdateIssueWatch", err) return } diff --git a/routers/web/repo/migrate.go b/routers/web/repo/migrate.go index a6125a1a58..b70901d5f2 100644 --- a/routers/web/repo/migrate.go +++ b/routers/web/repo/migrate.go @@ -232,13 +232,13 @@ func MigratePost(ctx *context.Context) { opts.Releases = false } - err = repo_model.CheckCreateRepository(ctx.Doer, ctxUser, opts.RepoName, false) + err = repo_model.CheckCreateRepository(ctx, ctx.Doer, ctxUser, opts.RepoName, false) if err != nil { handleMigrateError(ctx, ctxUser, err, "MigratePost", tpl, form) return } - err = task.MigrateRepository(ctx.Doer, ctxUser, opts) + err = task.MigrateRepository(ctx, ctx.Doer, ctxUser, opts) if err == nil { ctx.Redirect(ctxUser.HomeLink() + "/" + url.PathEscape(opts.RepoName)) return @@ -260,7 +260,7 @@ func setMigrationContextData(ctx *context.Context, serviceType structs.GitServic } func MigrateRetryPost(ctx *context.Context) { - if err := task.RetryMigrateTask(ctx.Repo.Repository.ID); err != nil { + if err := task.RetryMigrateTask(ctx, ctx.Repo.Repository.ID); err != nil { log.Error("Retry task failed: %v", err) ctx.ServerError("task.RetryMigrateTask", err) return @@ -269,7 +269,7 @@ func MigrateRetryPost(ctx *context.Context) { } func MigrateCancelPost(ctx *context.Context) { - migratingTask, err := admin_model.GetMigratingTask(ctx.Repo.Repository.ID) + migratingTask, err := admin_model.GetMigratingTask(ctx, ctx.Repo.Repository.ID) if err != nil { log.Error("GetMigratingTask: %v", err) ctx.Redirect(ctx.Repo.Repository.Link()) @@ -277,7 +277,7 @@ func MigrateCancelPost(ctx *context.Context) { } if migratingTask.Status == structs.TaskStatusRunning { taskUpdate := &admin_model.Task{ID: migratingTask.ID, Status: structs.TaskStatusFailed, Message: "canceled"} - if err = taskUpdate.UpdateCols("status", "message"); err != nil { + if err = taskUpdate.UpdateCols(ctx, "status", "message"); err != nil { ctx.ServerError("task.UpdateCols", err) return } diff --git a/routers/web/repo/milestone.go b/routers/web/repo/milestone.go index ad355ce5d7..df52ca3528 100644 --- a/routers/web/repo/milestone.go +++ b/routers/web/repo/milestone.go @@ -65,7 +65,7 @@ func Milestones(ctx *context.Context) { return } - stats, err := issues_model.GetMilestonesStatsByRepoCondAndKw(builder.And(builder.Eq{"id": ctx.Repo.Repository.ID}), keyword) + stats, err := issues_model.GetMilestonesStatsByRepoCondAndKw(ctx, builder.And(builder.Eq{"id": ctx.Repo.Repository.ID}), keyword) if err != nil { ctx.ServerError("GetMilestoneStats", err) return @@ -74,7 +74,7 @@ func Milestones(ctx *context.Context) { ctx.Data["ClosedCount"] = stats.ClosedCount if ctx.Repo.Repository.IsTimetrackerEnabled(ctx) { - if err := miles.LoadTotalTrackedTimes(); err != nil { + if err := miles.LoadTotalTrackedTimes(ctx); err != nil { ctx.ServerError("LoadTotalTrackedTimes", err) return } @@ -142,7 +142,7 @@ func NewMilestonePost(ctx *context.Context) { } deadline = time.Date(deadline.Year(), deadline.Month(), deadline.Day(), 23, 59, 59, 0, deadline.Location()) - if err = issues_model.NewMilestone(&issues_model.Milestone{ + if err = issues_model.NewMilestone(ctx, &issues_model.Milestone{ RepoID: ctx.Repo.Repository.ID, Name: form.Title, Content: form.Content, @@ -214,7 +214,7 @@ func EditMilestonePost(ctx *context.Context) { m.Name = form.Title m.Content = form.Content m.DeadlineUnix = timeutil.TimeStamp(deadline.Unix()) - if err = issues_model.UpdateMilestone(m, m.IsClosed); err != nil { + if err = issues_model.UpdateMilestone(ctx, m, m.IsClosed); err != nil { ctx.ServerError("UpdateMilestone", err) return } @@ -236,7 +236,7 @@ func ChangeMilestoneStatus(ctx *context.Context) { } id := ctx.ParamsInt64(":id") - if err := issues_model.ChangeMilestoneStatusByRepoIDAndID(ctx.Repo.Repository.ID, id, toClose); err != nil { + if err := issues_model.ChangeMilestoneStatusByRepoIDAndID(ctx, ctx.Repo.Repository.ID, id, toClose); err != nil { if issues_model.IsErrMilestoneNotExist(err) { ctx.NotFound("", err) } else { @@ -249,7 +249,7 @@ func ChangeMilestoneStatus(ctx *context.Context) { // DeleteMilestone delete a milestone func DeleteMilestone(ctx *context.Context) { - if err := issues_model.DeleteMilestoneByRepoID(ctx.Repo.Repository.ID, ctx.FormInt64("id")); err != nil { + if err := issues_model.DeleteMilestoneByRepoID(ctx, ctx.Repo.Repository.ID, ctx.FormInt64("id")); err != nil { ctx.Flash.Error("DeleteMilestoneByRepoID: " + err.Error()) } else { ctx.Flash.Success(ctx.Tr("repo.milestones.deletion_success")) diff --git a/routers/web/repo/pull.go b/routers/web/repo/pull.go index 0ef4a29f0c..63dfd0f7b5 100644 --- a/routers/web/repo/pull.go +++ b/routers/web/repo/pull.go @@ -1270,7 +1270,7 @@ func MergePullRequest(ctx *context.Context) { } log.Trace("Pull request merged: %d", pr.ID) - if err := stopTimerIfAvailable(ctx.Doer, issue); err != nil { + if err := stopTimerIfAvailable(ctx, ctx.Doer, issue); err != nil { ctx.ServerError("CreateOrStopIssueStopwatch", err) return } @@ -1326,9 +1326,9 @@ func CancelAutoMergePullRequest(ctx *context.Context) { ctx.Redirect(fmt.Sprintf("%s/pulls/%d", ctx.Repo.RepoLink, issue.Index)) } -func stopTimerIfAvailable(user *user_model.User, issue *issues_model.Issue) error { - if issues_model.StopwatchExists(user.ID, issue.ID) { - if err := issues_model.CreateOrStopIssueStopwatch(user, issue); err != nil { +func stopTimerIfAvailable(ctx *context.Context, user *user_model.User, issue *issues_model.Issue) error { + if issues_model.StopwatchExists(ctx, user.ID, issue.ID) { + if err := issues_model.CreateOrStopIssueStopwatch(ctx, user, issue); err != nil { return err } } diff --git a/routers/web/repo/repo.go b/routers/web/repo/repo.go index 799c2268de..b31ebb1971 100644 --- a/routers/web/repo/repo.go +++ b/routers/web/repo/repo.go @@ -344,7 +344,7 @@ func acceptOrRejectRepoTransfer(ctx *context.Context, accept bool) error { return err } - if !repoTransfer.CanUserAcceptTransfer(ctx.Doer) { + if !repoTransfer.CanUserAcceptTransfer(ctx, ctx.Doer) { return errors.New("user does not have enough permissions") } @@ -359,7 +359,7 @@ func acceptOrRejectRepoTransfer(ctx *context.Context, accept bool) error { } ctx.Flash.Success(ctx.Tr("repo.settings.transfer.success")) } else { - if err := models.CancelRepositoryTransfer(ctx.Repo.Repository); err != nil { + if err := models.CancelRepositoryTransfer(ctx, ctx.Repo.Repository); err != nil { return err } ctx.Flash.Success(ctx.Tr("repo.settings.transfer.rejected")) diff --git a/routers/web/repo/setting/collaboration.go b/routers/web/repo/setting/collaboration.go index 1e71d33c08..e217697cc0 100644 --- a/routers/web/repo/setting/collaboration.go +++ b/routers/web/repo/setting/collaboration.go @@ -127,7 +127,7 @@ func ChangeCollaborationAccessMode(ctx *context.Context) { // DeleteCollaboration delete a collaboration for a repository func DeleteCollaboration(ctx *context.Context) { - if err := repo_service.DeleteCollaboration(ctx.Repo.Repository, ctx.FormInt64("id")); err != nil { + if err := repo_service.DeleteCollaboration(ctx, ctx.Repo.Repository, ctx.FormInt64("id")); err != nil { ctx.Flash.Error("DeleteCollaboration: " + err.Error()) } else { ctx.Flash.Success(ctx.Tr("repo.settings.remove_collaborator_success")) diff --git a/routers/web/repo/setting/setting.go b/routers/web/repo/setting/setting.go index af09e240d5..7c85ff2078 100644 --- a/routers/web/repo/setting/setting.go +++ b/routers/web/repo/setting/setting.go @@ -799,7 +799,7 @@ func SettingsPost(ctx *context.Context) { return } - if err := models.CancelRepositoryTransfer(ctx.Repo.Repository); err != nil { + if err := models.CancelRepositoryTransfer(ctx, ctx.Repo.Repository); err != nil { ctx.ServerError("CancelRepositoryTransfer", err) return } @@ -863,7 +863,7 @@ func SettingsPost(ctx *context.Context) { return } - if err := repo_model.SetArchiveRepoState(repo, true); err != nil { + if err := repo_model.SetArchiveRepoState(ctx, repo, true); err != nil { log.Error("Tried to archive a repo: %s", err) ctx.Flash.Error(ctx.Tr("repo.settings.archive.error")) ctx.Redirect(ctx.Repo.RepoLink + "/settings") @@ -881,7 +881,7 @@ func SettingsPost(ctx *context.Context) { return } - if err := repo_model.SetArchiveRepoState(repo, false); err != nil { + if err := repo_model.SetArchiveRepoState(ctx, repo, false); err != nil { log.Error("Tried to unarchive a repo: %s", err) ctx.Flash.Error(ctx.Tr("repo.settings.unarchive.error")) ctx.Redirect(ctx.Repo.RepoLink + "/settings") diff --git a/routers/web/repo/topic.go b/routers/web/repo/topic.go index d22c3c6aa3..d0e706c5bd 100644 --- a/routers/web/repo/topic.go +++ b/routers/web/repo/topic.go @@ -45,7 +45,7 @@ func TopicsPost(ctx *context.Context) { return } - err := repo_model.SaveTopics(ctx.Repo.Repository.ID, validTopics...) + err := repo_model.SaveTopics(ctx, ctx.Repo.Repository.ID, validTopics...) if err != nil { log.Error("SaveTopics failed: %v", err) ctx.JSON(http.StatusInternalServerError, map[string]any{ diff --git a/routers/web/repo/view.go b/routers/web/repo/view.go index 26e9cedd3a..37da76e3e5 100644 --- a/routers/web/repo/view.go +++ b/routers/web/repo/view.go @@ -640,7 +640,7 @@ func safeURL(address string) string { func checkHomeCodeViewable(ctx *context.Context) { if len(ctx.Repo.Units) > 0 { if ctx.Repo.Repository.IsBeingCreated() { - task, err := admin_model.GetMigratingTask(ctx.Repo.Repository.ID) + task, err := admin_model.GetMigratingTask(ctx, ctx.Repo.Repository.ID) if err != nil { if admin_model.IsErrTaskDoesNotExist(err) { ctx.Data["Repo"] = ctx.Repo @@ -893,7 +893,7 @@ func renderLanguageStats(ctx *context.Context) { } func renderRepoTopics(ctx *context.Context) { - topics, _, err := repo_model.FindTopics(&repo_model.FindTopicOptions{ + topics, _, err := repo_model.FindTopics(ctx, &repo_model.FindTopicOptions{ RepoID: ctx.Repo.Repository.ID, }) if err != nil { diff --git a/routers/web/shared/user/header.go b/routers/web/shared/user/header.go index 649537ec63..16d9321e80 100644 --- a/routers/web/shared/user/header.go +++ b/routers/web/shared/user/header.go @@ -30,7 +30,7 @@ func prepareContextForCommonProfile(ctx *context.Context) { func PrepareContextForProfileBigAvatar(ctx *context.Context) { prepareContextForCommonProfile(ctx) - ctx.Data["IsFollowing"] = ctx.Doer != nil && user_model.IsFollowing(ctx.Doer.ID, ctx.ContextUser.ID) + ctx.Data["IsFollowing"] = ctx.Doer != nil && user_model.IsFollowing(ctx, ctx.Doer.ID, ctx.ContextUser.ID) ctx.Data["ShowUserEmail"] = setting.UI.ShowUserEmail && ctx.ContextUser.Email != "" && ctx.IsSigned && !ctx.ContextUser.KeepEmailPrivate // Show OpenID URIs diff --git a/routers/web/user/home.go b/routers/web/user/home.go index a88479e129..9efb536a7f 100644 --- a/routers/web/user/home.go +++ b/routers/web/user/home.go @@ -59,7 +59,7 @@ func getDashboardContextUser(ctx *context.Context) *user_model.User { } ctx.Data["ContextUser"] = ctxUser - orgs, err := organization.GetUserOrgsList(ctx.Doer) + orgs, err := organization.GetUserOrgsList(ctx, ctx.Doer) if err != nil { ctx.ServerError("GetUserOrgsList", err) return nil @@ -213,13 +213,13 @@ func Milestones(ctx *context.Context) { } } - counts, err := issues_model.CountMilestonesByRepoCondAndKw(userRepoCond, keyword, isShowClosed) + counts, err := issues_model.CountMilestonesByRepoCondAndKw(ctx, userRepoCond, keyword, isShowClosed) if err != nil { ctx.ServerError("CountMilestonesByRepoIDs", err) return } - milestones, err := issues_model.SearchMilestones(repoCond, page, isShowClosed, sortType, keyword) + milestones, err := issues_model.SearchMilestones(ctx, repoCond, page, isShowClosed, sortType, keyword) if err != nil { ctx.ServerError("SearchMilestones", err) return @@ -256,7 +256,7 @@ func Milestones(ctx *context.Context) { } if milestones[i].Repo.IsTimetrackerEnabled(ctx) { - err := milestones[i].LoadTotalTrackedTime() + err := milestones[i].LoadTotalTrackedTime(ctx) if err != nil { ctx.ServerError("LoadTotalTrackedTime", err) return @@ -265,7 +265,7 @@ func Milestones(ctx *context.Context) { i++ } - milestoneStats, err := issues_model.GetMilestonesStatsByRepoCondAndKw(repoCond, keyword) + milestoneStats, err := issues_model.GetMilestonesStatsByRepoCondAndKw(ctx, repoCond, keyword) if err != nil { ctx.ServerError("GetMilestoneStats", err) return @@ -275,7 +275,7 @@ func Milestones(ctx *context.Context) { if len(repoIDs) == 0 { totalMilestoneStats = milestoneStats } else { - totalMilestoneStats, err = issues_model.GetMilestonesStatsByRepoCondAndKw(userRepoCond, keyword) + totalMilestoneStats, err = issues_model.GetMilestonesStatsByRepoCondAndKw(ctx, userRepoCond, keyword) if err != nil { ctx.ServerError("GetMilestoneStats", err) return diff --git a/routers/web/user/profile.go b/routers/web/user/profile.go index 87505b94b1..71d10ab4c1 100644 --- a/routers/web/user/profile.go +++ b/routers/web/user/profile.go @@ -292,9 +292,9 @@ func Action(ctx *context.Context) { var err error switch ctx.FormString("action") { case "follow": - err = user_model.FollowUser(ctx.Doer.ID, ctx.ContextUser.ID) + err = user_model.FollowUser(ctx, ctx.Doer.ID, ctx.ContextUser.ID) case "unfollow": - err = user_model.UnfollowUser(ctx.Doer.ID, ctx.ContextUser.ID) + err = user_model.UnfollowUser(ctx, ctx.Doer.ID, ctx.ContextUser.ID) } if err != nil { diff --git a/routers/web/user/setting/security/security.go b/routers/web/user/setting/security/security.go index 1ce59fef09..5a17c161fe 100644 --- a/routers/web/user/setting/security/security.go +++ b/routers/web/user/setting/security/security.go @@ -59,7 +59,7 @@ func loadSecurityData(ctx *context.Context) { } ctx.Data["TOTPEnrolled"] = enrolled - credentials, err := auth_model.GetWebAuthnCredentialsByUID(ctx.Doer.ID) + credentials, err := auth_model.GetWebAuthnCredentialsByUID(ctx, ctx.Doer.ID) if err != nil { ctx.ServerError("GetWebAuthnCredentialsByUID", err) return diff --git a/routers/web/user/setting/security/webauthn.go b/routers/web/user/setting/security/webauthn.go index 990e506d6f..ce103528c5 100644 --- a/routers/web/user/setting/security/webauthn.go +++ b/routers/web/user/setting/security/webauthn.go @@ -29,7 +29,7 @@ func WebAuthnRegister(ctx *context.Context) { form.Name = strconv.FormatInt(time.Now().UnixNano(), 16) } - cred, err := auth.GetWebAuthnCredentialByName(ctx.Doer.ID, form.Name) + cred, err := auth.GetWebAuthnCredentialByName(ctx, ctx.Doer.ID, form.Name) if err != nil && !auth.IsErrWebAuthnCredentialNotExist(err) { ctx.ServerError("GetWebAuthnCredentialsByUID", err) return @@ -88,7 +88,7 @@ func WebauthnRegisterPost(ctx *context.Context) { return } - dbCred, err := auth.GetWebAuthnCredentialByName(ctx.Doer.ID, name) + dbCred, err := auth.GetWebAuthnCredentialByName(ctx, ctx.Doer.ID, name) if err != nil && !auth.IsErrWebAuthnCredentialNotExist(err) { ctx.ServerError("GetWebAuthnCredentialsByUID", err) return @@ -99,7 +99,7 @@ func WebauthnRegisterPost(ctx *context.Context) { } // Create the credential - _, err = auth.CreateCredential(ctx.Doer.ID, name, cred) + _, err = auth.CreateCredential(ctx, ctx.Doer.ID, name, cred) if err != nil { ctx.ServerError("CreateCredential", err) return @@ -112,7 +112,7 @@ func WebauthnRegisterPost(ctx *context.Context) { // WebauthnDelete deletes an security key by id func WebauthnDelete(ctx *context.Context) { form := web.GetForm(ctx).(*forms.WebauthnDeleteForm) - if _, err := auth.DeleteCredential(form.ID, ctx.Doer.ID); err != nil { + if _, err := auth.DeleteCredential(ctx, form.ID, ctx.Doer.ID); err != nil { ctx.ServerError("GetWebAuthnCredentialByID", err) return } diff --git a/routers/web/user/stop_watch.go b/routers/web/user/stop_watch.go index d262c777c3..cac446d84a 100644 --- a/routers/web/user/stop_watch.go +++ b/routers/web/user/stop_watch.go @@ -14,7 +14,7 @@ import ( // GetStopwatches get all stopwatches func GetStopwatches(ctx *context.Context) { - sws, err := issues_model.GetUserStopwatches(ctx.Doer.ID, db.ListOptions{ + sws, err := issues_model.GetUserStopwatches(ctx, ctx.Doer.ID, db.ListOptions{ Page: ctx.FormInt("page"), PageSize: convert.ToCorrectPageSize(ctx.FormInt("limit")), }) @@ -23,7 +23,7 @@ func GetStopwatches(ctx *context.Context) { return } - count, err := issues_model.CountUserStopwatches(ctx.Doer.ID) + count, err := issues_model.CountUserStopwatches(ctx, ctx.Doer.ID) if err != nil { ctx.Error(http.StatusInternalServerError, err.Error()) return diff --git a/routers/web/user/task.go b/routers/web/user/task.go index d92bf64af0..f35f40e6a0 100644 --- a/routers/web/user/task.go +++ b/routers/web/user/task.go @@ -14,7 +14,7 @@ import ( // TaskStatus returns task's status func TaskStatus(ctx *context.Context) { - task, opts, err := admin_model.GetMigratingTaskByID(ctx.ParamsInt64("task"), ctx.Doer.ID) + task, opts, err := admin_model.GetMigratingTaskByID(ctx, ctx.ParamsInt64("task"), ctx.Doer.ID) if err != nil { if admin_model.IsErrTaskDoesNotExist(err) { ctx.JSON(http.StatusNotFound, map[string]any{ diff --git a/services/issue/label.go b/services/issue/label.go index f830aab0e7..91f0308d9f 100644 --- a/services/issue/label.go +++ b/services/issue/label.go @@ -4,6 +4,8 @@ package issue import ( + "context" + "code.gitea.io/gitea/models/db" issues_model "code.gitea.io/gitea/models/issues" access_model "code.gitea.io/gitea/models/perm/access" @@ -12,49 +14,49 @@ import ( ) // ClearLabels clears all of an issue's labels -func ClearLabels(issue *issues_model.Issue, doer *user_model.User) error { +func ClearLabels(ctx context.Context, issue *issues_model.Issue, doer *user_model.User) error { if err := issues_model.ClearIssueLabels(issue, doer); err != nil { return err } - notify_service.IssueClearLabels(db.DefaultContext, doer, issue) + notify_service.IssueClearLabels(ctx, doer, issue) return nil } // AddLabel adds a new label to the issue. -func AddLabel(issue *issues_model.Issue, doer *user_model.User, label *issues_model.Label) error { +func AddLabel(ctx context.Context, issue *issues_model.Issue, doer *user_model.User, label *issues_model.Label) error { if err := issues_model.NewIssueLabel(issue, label, doer); err != nil { return err } - notify_service.IssueChangeLabels(db.DefaultContext, doer, issue, []*issues_model.Label{label}, nil) + notify_service.IssueChangeLabels(ctx, doer, issue, []*issues_model.Label{label}, nil) return nil } // AddLabels adds a list of new labels to the issue. -func AddLabels(issue *issues_model.Issue, doer *user_model.User, labels []*issues_model.Label) error { +func AddLabels(ctx context.Context, issue *issues_model.Issue, doer *user_model.User, labels []*issues_model.Label) error { if err := issues_model.NewIssueLabels(issue, labels, doer); err != nil { return err } - notify_service.IssueChangeLabels(db.DefaultContext, doer, issue, labels, nil) + notify_service.IssueChangeLabels(ctx, doer, issue, labels, nil) return nil } // RemoveLabel removes a label from issue by given ID. -func RemoveLabel(issue *issues_model.Issue, doer *user_model.User, label *issues_model.Label) error { - ctx, committer, err := db.TxContext(db.DefaultContext) +func RemoveLabel(ctx context.Context, issue *issues_model.Issue, doer *user_model.User, label *issues_model.Label) error { + dbCtx, committer, err := db.TxContext(ctx) if err != nil { return err } defer committer.Close() - if err := issue.LoadRepo(ctx); err != nil { + if err := issue.LoadRepo(dbCtx); err != nil { return err } - perm, err := access_model.GetUserRepoPermission(ctx, issue.Repo, doer) + perm, err := access_model.GetUserRepoPermission(dbCtx, issue.Repo, doer) if err != nil { return err } @@ -65,7 +67,7 @@ func RemoveLabel(issue *issues_model.Issue, doer *user_model.User, label *issues return issues_model.ErrRepoLabelNotExist{} } - if err := issues_model.DeleteIssueLabel(ctx, issue, label, doer); err != nil { + if err := issues_model.DeleteIssueLabel(dbCtx, issue, label, doer); err != nil { return err } @@ -73,13 +75,13 @@ func RemoveLabel(issue *issues_model.Issue, doer *user_model.User, label *issues return err } - notify_service.IssueChangeLabels(db.DefaultContext, doer, issue, nil, []*issues_model.Label{label}) + notify_service.IssueChangeLabels(ctx, doer, issue, nil, []*issues_model.Label{label}) return nil } // ReplaceLabels removes all current labels and add new labels to the issue. -func ReplaceLabels(issue *issues_model.Issue, doer *user_model.User, labels []*issues_model.Label) error { - old, err := issues_model.GetLabelsByIssueID(db.DefaultContext, issue.ID) +func ReplaceLabels(ctx context.Context, issue *issues_model.Issue, doer *user_model.User, labels []*issues_model.Label) error { + old, err := issues_model.GetLabelsByIssueID(ctx, issue.ID) if err != nil { return err } @@ -88,6 +90,6 @@ func ReplaceLabels(issue *issues_model.Issue, doer *user_model.User, labels []*i return err } - notify_service.IssueChangeLabels(db.DefaultContext, doer, issue, labels, old) + notify_service.IssueChangeLabels(ctx, doer, issue, labels, old) return nil } diff --git a/services/issue/label_test.go b/services/issue/label_test.go index af220601f1..90608c9e26 100644 --- a/services/issue/label_test.go +++ b/services/issue/label_test.go @@ -6,6 +6,7 @@ package issue import ( "testing" + "code.gitea.io/gitea/models/db" issues_model "code.gitea.io/gitea/models/issues" "code.gitea.io/gitea/models/unittest" user_model "code.gitea.io/gitea/models/user" @@ -32,7 +33,7 @@ func TestIssue_AddLabels(t *testing.T) { labels[i] = unittest.AssertExistsAndLoadBean(t, &issues_model.Label{ID: labelID}) } doer := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: test.doerID}) - assert.NoError(t, AddLabels(issue, doer, labels)) + assert.NoError(t, AddLabels(db.DefaultContext, issue, doer, labels)) for _, labelID := range test.labelIDs { unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: test.issueID, LabelID: labelID}) } @@ -55,7 +56,7 @@ func TestIssue_AddLabel(t *testing.T) { issue := unittest.AssertExistsAndLoadBean(t, &issues_model.Issue{ID: test.issueID}) label := unittest.AssertExistsAndLoadBean(t, &issues_model.Label{ID: test.labelID}) doer := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: test.doerID}) - assert.NoError(t, AddLabel(issue, doer, label)) + assert.NoError(t, AddLabel(db.DefaultContext, issue, doer, label)) unittest.AssertExistsAndLoadBean(t, &issues_model.IssueLabel{IssueID: test.issueID, LabelID: test.labelID}) } } diff --git a/services/mailer/incoming/incoming_handler.go b/services/mailer/incoming/incoming_handler.go index b594e35189..78f9f89fc9 100644 --- a/services/mailer/incoming/incoming_handler.go +++ b/services/mailer/incoming/incoming_handler.go @@ -170,7 +170,7 @@ func (h *UnsubscribeHandler) Handle(ctx context.Context, _ *MailContent, doer *u return nil } - return issues_model.CreateOrUpdateIssueWatch(doer.ID, issue.ID, false) + return issues_model.CreateOrUpdateIssueWatch(ctx, doer.ID, issue.ID, false) } return fmt.Errorf("unsupported unsubscribe reference: %v", ref) diff --git a/services/migrations/dump.go b/services/migrations/dump.go index 603954810c..07812002af 100644 --- a/services/migrations/dump.go +++ b/services/migrations/dump.go @@ -655,7 +655,7 @@ func DumpRepository(ctx context.Context, baseDir, ownerName string, opts base.Mi return err } - if err := migrateRepository(doer, downloader, uploader, opts, nil); err != nil { + if err := migrateRepository(ctx, doer, downloader, uploader, opts, nil); err != nil { if err1 := uploader.Rollback(); err1 != nil { log.Error("rollback failed: %v", err1) } @@ -727,7 +727,7 @@ func RestoreRepository(ctx context.Context, baseDir, ownerName, repoName string, return err } - if err = migrateRepository(doer, downloader, uploader, migrateOpts, nil); err != nil { + if err = migrateRepository(ctx, doer, downloader, uploader, migrateOpts, nil); err != nil { if err1 := uploader.Rollback(); err1 != nil { log.Error("rollback failed: %v", err1) } diff --git a/services/migrations/gitea_uploader.go b/services/migrations/gitea_uploader.go index 4c21efae44..9f1e613bb2 100644 --- a/services/migrations/gitea_uploader.go +++ b/services/migrations/gitea_uploader.go @@ -162,7 +162,7 @@ func (g *GiteaLocalUploader) CreateTopics(topics ...string) error { c++ } topics = topics[:c] - return repo_model.SaveTopics(g.repo.ID, topics...) + return repo_model.SaveTopics(g.ctx, g.repo.ID, topics...) } // CreateMilestones creates milestones @@ -205,7 +205,7 @@ func (g *GiteaLocalUploader) CreateMilestones(milestones ...*base.Milestone) err mss = append(mss, &ms) } - err := issues_model.InsertMilestones(mss...) + err := issues_model.InsertMilestones(g.ctx, mss...) if err != nil { return err } @@ -236,7 +236,7 @@ func (g *GiteaLocalUploader) CreateLabels(labels ...*base.Label) error { }) } - err := issues_model.NewLabels(lbs...) + err := issues_model.NewLabels(g.ctx, lbs...) if err != nil { return err } @@ -516,7 +516,6 @@ func (g *GiteaLocalUploader) CreateComments(comments ...*base.Comment) error { // CreatePullRequests creates pull requests func (g *GiteaLocalUploader) CreatePullRequests(prs ...*base.PullRequest) error { gprs := make([]*issues_model.PullRequest, 0, len(prs)) - ctx := db.DefaultContext for _, pr := range prs { gpr, err := g.newPullRequest(pr) if err != nil { @@ -529,12 +528,12 @@ func (g *GiteaLocalUploader) CreatePullRequests(prs ...*base.PullRequest) error gprs = append(gprs, gpr) } - if err := issues_model.InsertPullRequests(ctx, gprs...); err != nil { + if err := issues_model.InsertPullRequests(g.ctx, gprs...); err != nil { return err } for _, pr := range gprs { g.issues[pr.Issue.Index] = pr.Issue - pull.AddToTaskQueue(ctx, pr) + pull.AddToTaskQueue(g.ctx, pr) } return nil } diff --git a/services/migrations/gitea_uploader_test.go b/services/migrations/gitea_uploader_test.go index e42d9e9286..4c6dfddc08 100644 --- a/services/migrations/gitea_uploader_test.go +++ b/services/migrations/gitea_uploader_test.go @@ -44,7 +44,7 @@ func TestGiteaUploadRepo(t *testing.T) { uploader = NewGiteaLocalUploader(graceful.GetManager().HammerContext(), user, user.Name, repoName) ) - err := migrateRepository(user, downloader, uploader, base.MigrateOptions{ + err := migrateRepository(db.DefaultContext, user, downloader, uploader, base.MigrateOptions{ CloneAddr: "https://github.com/go-xorm/builder", RepoName: repoName, AuthUsername: "", diff --git a/services/migrations/migrate.go b/services/migrations/migrate.go index 0ebb3411fd..0b83f3b4a3 100644 --- a/services/migrations/migrate.go +++ b/services/migrations/migrate.go @@ -127,7 +127,7 @@ func MigrateRepository(ctx context.Context, doer *user_model.User, ownerName str uploader := NewGiteaLocalUploader(ctx, doer, ownerName, opts.RepoName) uploader.gitServiceType = opts.GitServiceType - if err := migrateRepository(doer, downloader, uploader, opts, messenger); err != nil { + if err := migrateRepository(ctx, doer, downloader, uploader, opts, messenger); err != nil { if err1 := uploader.Rollback(); err1 != nil { log.Error("rollback failed: %v", err1) } @@ -176,7 +176,7 @@ func newDownloader(ctx context.Context, ownerName string, opts base.MigrateOptio // migrateRepository will download information and then upload it to Uploader, this is a simple // process for small repository. For a big repository, save all the data to disk // before upload is better -func migrateRepository(doer *user_model.User, downloader base.Downloader, uploader base.Uploader, opts base.MigrateOptions, messenger base.Messenger) error { +func migrateRepository(ctx context.Context, doer *user_model.User, downloader base.Downloader, uploader base.Uploader, opts base.MigrateOptions, messenger base.Messenger) error { if messenger == nil { messenger = base.NilMessenger } diff --git a/services/mirror/mirror_pull.go b/services/mirror/mirror_pull.go index 321bd38fc3..d2b7d37eaa 100644 --- a/services/mirror/mirror_pull.go +++ b/services/mirror/mirror_pull.go @@ -540,7 +540,7 @@ func SyncPullMirror(ctx context.Context, repoID int64) bool { return false } - if err = repo_model.UpdateRepositoryUpdatedTime(m.RepoID, commitDate); err != nil { + if err = repo_model.UpdateRepositoryUpdatedTime(ctx, m.RepoID, commitDate); err != nil { log.Error("SyncMirrors [repo: %-v]: unable to update repository 'updated_unix': %v", m.Repo, err) return false } diff --git a/services/repository/archiver/archiver.go b/services/repository/archiver/archiver.go index 2e3defee8d..f6f03e75ae 100644 --- a/services/repository/archiver/archiver.go +++ b/services/repository/archiver/archiver.go @@ -346,7 +346,7 @@ func DeleteOldRepositoryArchives(ctx context.Context, olderThan time.Duration) e log.Trace("Doing: ArchiveCleanup") for { - archivers, err := repo_model.FindRepoArchives(repo_model.FindRepoArchiversOption{ + archivers, err := repo_model.FindRepoArchives(ctx, repo_model.FindRepoArchiversOption{ ListOptions: db.ListOptions{ PageSize: 100, Page: 1, @@ -374,7 +374,7 @@ func DeleteOldRepositoryArchives(ctx context.Context, olderThan time.Duration) e // DeleteRepositoryArchives deletes all repositories' archives. func DeleteRepositoryArchives(ctx context.Context) error { - if err := repo_model.DeleteAllRepoArchives(); err != nil { + if err := repo_model.DeleteAllRepoArchives(ctx); err != nil { return err } return storage.Clean(storage.RepoArchives) diff --git a/services/repository/collaboration.go b/services/repository/collaboration.go index 28824d83f5..eff33c71f3 100644 --- a/services/repository/collaboration.go +++ b/services/repository/collaboration.go @@ -5,6 +5,8 @@ package repository import ( + "context" + "code.gitea.io/gitea/models" "code.gitea.io/gitea/models/db" access_model "code.gitea.io/gitea/models/perm/access" @@ -12,13 +14,13 @@ import ( ) // DeleteCollaboration removes collaboration relation between the user and repository. -func DeleteCollaboration(repo *repo_model.Repository, uid int64) (err error) { +func DeleteCollaboration(ctx context.Context, repo *repo_model.Repository, uid int64) (err error) { collaboration := &repo_model.Collaboration{ RepoID: repo.ID, UserID: uid, } - ctx, committer, err := db.TxContext(db.DefaultContext) + ctx, committer, err := db.TxContext(ctx) if err != nil { return err } diff --git a/services/repository/collaboration_test.go b/services/repository/collaboration_test.go index 08159af7bc..c3d006bfd8 100644 --- a/services/repository/collaboration_test.go +++ b/services/repository/collaboration_test.go @@ -18,10 +18,10 @@ func TestRepository_DeleteCollaboration(t *testing.T) { repo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 4}) assert.NoError(t, repo.LoadOwner(db.DefaultContext)) - assert.NoError(t, DeleteCollaboration(repo, 4)) + assert.NoError(t, DeleteCollaboration(db.DefaultContext, repo, 4)) unittest.AssertNotExistsBean(t, &repo_model.Collaboration{RepoID: repo.ID, UserID: 4}) - assert.NoError(t, DeleteCollaboration(repo, 4)) + assert.NoError(t, DeleteCollaboration(db.DefaultContext, repo, 4)) unittest.AssertNotExistsBean(t, &repo_model.Collaboration{RepoID: repo.ID, UserID: 4}) unittest.CheckConsistencyFor(t, &repo_model.Repository{ID: repo.ID}) diff --git a/services/repository/push.go b/services/repository/push.go index 9b00b57e71..97da45f52b 100644 --- a/services/repository/push.go +++ b/services/repository/push.go @@ -292,7 +292,7 @@ func pushUpdates(optsList []*repo_module.PushUpdateOptions) error { } // Change repository last updated time. - if err := repo_model.UpdateRepositoryUpdatedTime(repo.ID, time.Now()); err != nil { + if err := repo_model.UpdateRepositoryUpdatedTime(ctx, repo.ID, time.Now()); err != nil { return fmt.Errorf("UpdateRepositoryUpdatedTime: %w", err) } diff --git a/services/repository/repository.go b/services/repository/repository.go index 60f9568b54..fb52980bbd 100644 --- a/services/repository/repository.go +++ b/services/repository/repository.go @@ -95,12 +95,12 @@ func PushCreateRepo(ctx context.Context, authUser, owner *user_model.User, repoN } // Init start repository service -func Init() error { +func Init(ctx context.Context) error { if err := repo_module.LoadRepoConfig(); err != nil { return err } - system_model.RemoveAllWithNotice(db.DefaultContext, "Clean up temporary repository uploads", setting.Repository.Upload.TempPath) - system_model.RemoveAllWithNotice(db.DefaultContext, "Clean up temporary repositories", repo_module.LocalCopyPath()) + system_model.RemoveAllWithNotice(ctx, "Clean up temporary repository uploads", setting.Repository.Upload.TempPath) + system_model.RemoveAllWithNotice(ctx, "Clean up temporary repositories", repo_module.LocalCopyPath()) if err := initPushQueue(); err != nil { return err } diff --git a/services/repository/transfer.go b/services/repository/transfer.go index 2edb61816f..574b6c6a56 100644 --- a/services/repository/transfer.go +++ b/services/repository/transfer.go @@ -37,7 +37,7 @@ func TransferOwnership(ctx context.Context, doer, newOwner *user_model.User, rep oldOwner := repo.Owner repoWorkingPool.CheckIn(fmt.Sprint(repo.ID)) - if err := models.TransferOwnership(doer, newOwner.Name, repo); err != nil { + if err := models.TransferOwnership(ctx, doer, newOwner.Name, repo); err != nil { repoWorkingPool.CheckOut(fmt.Sprint(repo.ID)) return err } @@ -70,7 +70,7 @@ func ChangeRepositoryName(ctx context.Context, doer *user_model.User, repo *repo // local copy's origin accordingly. repoWorkingPool.CheckIn(fmt.Sprint(repo.ID)) - if err := repo_model.ChangeRepositoryName(doer, repo, newRepoName); err != nil { + if err := repo_model.ChangeRepositoryName(ctx, doer, repo, newRepoName); err != nil { repoWorkingPool.CheckOut(fmt.Sprint(repo.ID)) return err } diff --git a/services/task/migrate.go b/services/task/migrate.go index ebf179045e..70e5abdee6 100644 --- a/services/task/migrate.go +++ b/services/task/migrate.go @@ -4,6 +4,7 @@ package task import ( + "context" "errors" "fmt" "strings" @@ -40,7 +41,7 @@ func handleCreateError(owner *user_model.User, err error) error { } } -func runMigrateTask(t *admin_model.Task) (err error) { +func runMigrateTask(ctx context.Context, t *admin_model.Task) (err error) { defer func() { if e := recover(); e != nil { err = fmt.Errorf("PANIC whilst trying to do migrate task: %v", e) @@ -48,9 +49,9 @@ func runMigrateTask(t *admin_model.Task) (err error) { } if err == nil { - err = admin_model.FinishMigrateTask(t) + err = admin_model.FinishMigrateTask(ctx, t) if err == nil { - notify_service.MigrateRepository(db.DefaultContext, t.Doer, t.Owner, t.Repo) + notify_service.MigrateRepository(ctx, t.Doer, t.Owner, t.Repo) return } @@ -63,14 +64,14 @@ func runMigrateTask(t *admin_model.Task) (err error) { t.Status = structs.TaskStatusFailed t.Message = err.Error() - if err := t.UpdateCols("status", "message", "end_time"); err != nil { + if err := t.UpdateCols(ctx, "status", "message", "end_time"); err != nil { log.Error("Task UpdateCols failed: %v", err) } // then, do not delete the repository, otherwise the users won't be able to see the last error }() - if err = t.LoadRepo(); err != nil { + if err = t.LoadRepo(ctx); err != nil { return err } @@ -79,10 +80,10 @@ func runMigrateTask(t *admin_model.Task) (err error) { return nil } - if err = t.LoadDoer(); err != nil { + if err = t.LoadDoer(ctx); err != nil { return err } - if err = t.LoadOwner(); err != nil { + if err = t.LoadOwner(ctx); err != nil { return err } @@ -100,7 +101,7 @@ func runMigrateTask(t *admin_model.Task) (err error) { t.StartTime = timeutil.TimeStampNow() t.Status = structs.TaskStatusRunning - if err = t.UpdateCols("start_time", "status"); err != nil { + if err = t.UpdateCols(ctx, "start_time", "status"); err != nil { return err } @@ -112,7 +113,7 @@ func runMigrateTask(t *admin_model.Task) (err error) { case <-ctx.Done(): return } - task, _ := admin_model.GetMigratingTask(t.RepoID) + task, _ := admin_model.GetMigratingTask(ctx, t.RepoID) if task != nil && task.Status != structs.TaskStatusRunning { log.Debug("MigrateTask[%d] by DoerID[%d] to RepoID[%d] for OwnerID[%d] is canceled due to status is not 'running'", t.ID, t.DoerID, t.RepoID, t.OwnerID) cancel() @@ -128,7 +129,7 @@ func runMigrateTask(t *admin_model.Task) (err error) { } bs, _ := json.Marshal(message) t.Message = string(bs) - _ = t.UpdateCols("message") + _ = t.UpdateCols(ctx, "message") }) if err == nil { diff --git a/services/task/task.go b/services/task/task.go index 3a40faef90..e15cab7b3c 100644 --- a/services/task/task.go +++ b/services/task/task.go @@ -4,6 +4,7 @@ package task import ( + "context" "fmt" admin_model "code.gitea.io/gitea/models/admin" @@ -27,10 +28,10 @@ import ( var taskQueue *queue.WorkerPoolQueue[*admin_model.Task] // Run a task -func Run(t *admin_model.Task) error { +func Run(ctx context.Context, t *admin_model.Task) error { switch t.Type { case structs.TaskTypeMigrateRepo: - return runMigrateTask(t) + return runMigrateTask(ctx, t) default: return fmt.Errorf("Unknown task type: %d", t.Type) } @@ -48,7 +49,7 @@ func Init() error { func handler(items ...*admin_model.Task) []*admin_model.Task { for _, task := range items { - if err := Run(task); err != nil { + if err := Run(db.DefaultContext, task); err != nil { log.Error("Run task failed: %v", err) } } @@ -56,8 +57,8 @@ func handler(items ...*admin_model.Task) []*admin_model.Task { } // MigrateRepository add migration repository to task -func MigrateRepository(doer, u *user_model.User, opts base.MigrateOptions) error { - task, err := CreateMigrateTask(doer, u, opts) +func MigrateRepository(ctx context.Context, doer, u *user_model.User, opts base.MigrateOptions) error { + task, err := CreateMigrateTask(ctx, doer, u, opts) if err != nil { return err } @@ -66,7 +67,7 @@ func MigrateRepository(doer, u *user_model.User, opts base.MigrateOptions) error } // CreateMigrateTask creates a migrate task -func CreateMigrateTask(doer, u *user_model.User, opts base.MigrateOptions) (*admin_model.Task, error) { +func CreateMigrateTask(ctx context.Context, doer, u *user_model.User, opts base.MigrateOptions) (*admin_model.Task, error) { // encrypt credentials for persistence var err error opts.CloneAddrEncrypted, err = secret.EncryptSecret(setting.SecretKey, opts.CloneAddr) @@ -97,11 +98,11 @@ func CreateMigrateTask(doer, u *user_model.User, opts base.MigrateOptions) (*adm PayloadContent: string(bs), } - if err := admin_model.CreateTask(task); err != nil { + if err := admin_model.CreateTask(ctx, task); err != nil { return nil, err } - repo, err := repo_service.CreateRepositoryDirectly(db.DefaultContext, doer, u, repo_service.CreateRepoOptions{ + repo, err := repo_service.CreateRepositoryDirectly(ctx, doer, u, repo_service.CreateRepoOptions{ Name: opts.RepoName, Description: opts.Description, OriginalURL: opts.OriginalURL, @@ -113,7 +114,7 @@ func CreateMigrateTask(doer, u *user_model.User, opts base.MigrateOptions) (*adm if err != nil { task.EndTime = timeutil.TimeStampNow() task.Status = structs.TaskStatusFailed - err2 := task.UpdateCols("end_time", "status") + err2 := task.UpdateCols(ctx, "end_time", "status") if err2 != nil { log.Error("UpdateCols Failed: %v", err2.Error()) } @@ -121,7 +122,7 @@ func CreateMigrateTask(doer, u *user_model.User, opts base.MigrateOptions) (*adm } task.RepoID = repo.ID - if err = task.UpdateCols("repo_id"); err != nil { + if err = task.UpdateCols(ctx, "repo_id"); err != nil { return nil, err } @@ -129,8 +130,8 @@ func CreateMigrateTask(doer, u *user_model.User, opts base.MigrateOptions) (*adm } // RetryMigrateTask retry a migrate task -func RetryMigrateTask(repoID int64) error { - migratingTask, err := admin_model.GetMigratingTask(repoID) +func RetryMigrateTask(ctx context.Context, repoID int64) error { + migratingTask, err := admin_model.GetMigratingTask(ctx, repoID) if err != nil { log.Error("GetMigratingTask: %v", err) return err @@ -144,7 +145,7 @@ func RetryMigrateTask(repoID int64) error { // Reset task status and messages migratingTask.Status = structs.TaskStatusQueued migratingTask.Message = "" - if err = migratingTask.UpdateCols("status", "message"); err != nil { + if err = migratingTask.UpdateCols(ctx, "status", "message"); err != nil { log.Error("task.UpdateCols failed: %v", err) return err } diff --git a/services/user/user.go b/services/user/user.go index 72bea0b468..5b2e74eb82 100644 --- a/services/user/user.go +++ b/services/user/user.go @@ -59,7 +59,7 @@ func RenameUser(ctx context.Context, u *user_model.User, newUserName string) err u.Name = oldUserName return err } - return repo_model.UpdateRepositoryOwnerNames(u.ID, newUserName) + return repo_model.UpdateRepositoryOwnerNames(ctx, u.ID, newUserName) } ctx, committer, err := db.TxContext(ctx) diff --git a/tests/integration/incoming_email_test.go b/tests/integration/incoming_email_test.go index b4478f5780..1284833864 100644 --- a/tests/integration/incoming_email_test.go +++ b/tests/integration/incoming_email_test.go @@ -154,7 +154,7 @@ func TestIncomingEmail(t *testing.T) { t.Run("Unsubscribe", func(t *testing.T) { defer tests.PrintCurrentTest(t)() - watching, err := issues_model.CheckIssueWatch(user, issue) + watching, err := issues_model.CheckIssueWatch(db.DefaultContext, user, issue) assert.NoError(t, err) assert.True(t, watching) @@ -169,7 +169,7 @@ func TestIncomingEmail(t *testing.T) { assert.NoError(t, handler.Handle(db.DefaultContext, content, user, payload)) - watching, err = issues_model.CheckIssueWatch(user, issue) + watching, err = issues_model.CheckIssueWatch(db.DefaultContext, user, issue) assert.NoError(t, err) assert.False(t, watching) }) |