diff options
Diffstat (limited to '')
-rw-r--r-- | models/unit/unit.go | 437 | ||||
-rw-r--r-- | models/unit/unit_test.go | 96 | ||||
-rw-r--r-- | models/unittest/consistency.go | 192 | ||||
-rw-r--r-- | models/unittest/fixtures.go | 144 | ||||
-rw-r--r-- | models/unittest/fscopy.go | 102 | ||||
-rw-r--r-- | models/unittest/mock_http.go | 115 | ||||
-rw-r--r-- | models/unittest/reflection.go | 40 | ||||
-rw-r--r-- | models/unittest/testdb.go | 267 | ||||
-rw-r--r-- | models/unittest/unit_tests.go | 164 |
9 files changed, 1557 insertions, 0 deletions
diff --git a/models/unit/unit.go b/models/unit/unit.go new file mode 100644 index 0000000..5a8b911 --- /dev/null +++ b/models/unit/unit.go @@ -0,0 +1,437 @@ +// Copyright 2017 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package unit + +import ( + "errors" + "fmt" + "strings" + "sync/atomic" + + "code.gitea.io/gitea/models/perm" + "code.gitea.io/gitea/modules/container" + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/setting" +) + +// Type is Unit's Type +type Type int + +// Enumerate all the unit types +const ( + TypeInvalid Type = iota // 0 invalid + TypeCode // 1 code + TypeIssues // 2 issues + TypePullRequests // 3 PRs + TypeReleases // 4 Releases + TypeWiki // 5 Wiki + TypeExternalWiki // 6 ExternalWiki + TypeExternalTracker // 7 ExternalTracker + TypeProjects // 8 Projects + TypePackages // 9 Packages + TypeActions // 10 Actions +) + +// Value returns integer value for unit type +func (u Type) Value() int { + return int(u) +} + +func (u Type) String() string { + switch u { + case TypeCode: + return "TypeCode" + case TypeIssues: + return "TypeIssues" + case TypePullRequests: + return "TypePullRequests" + case TypeReleases: + return "TypeReleases" + case TypeWiki: + return "TypeWiki" + case TypeExternalWiki: + return "TypeExternalWiki" + case TypeExternalTracker: + return "TypeExternalTracker" + case TypeProjects: + return "TypeProjects" + case TypePackages: + return "TypePackages" + case TypeActions: + return "TypeActions" + } + return fmt.Sprintf("Unknown Type %d", u) +} + +func (u Type) LogString() string { + return fmt.Sprintf("<UnitType:%d:%s>", u, u.String()) +} + +var ( + // AllRepoUnitTypes contains all the unit types + AllRepoUnitTypes = []Type{ + TypeCode, + TypeIssues, + TypePullRequests, + TypeReleases, + TypeWiki, + TypeExternalWiki, + TypeExternalTracker, + TypeProjects, + TypePackages, + TypeActions, + } + + // DefaultRepoUnits contains the default unit types + DefaultRepoUnits = []Type{ + TypeCode, + TypeIssues, + TypePullRequests, + TypeReleases, + TypeWiki, + TypeProjects, + TypePackages, + TypeActions, + } + + // ForkRepoUnits contains the default unit types for forks + DefaultForkRepoUnits = []Type{ + TypeCode, + TypePullRequests, + } + + // NotAllowedDefaultRepoUnits contains units that can't be default + NotAllowedDefaultRepoUnits = []Type{ + TypeExternalWiki, + TypeExternalTracker, + } + + disabledRepoUnitsAtomic atomic.Pointer[[]Type] // the units that have been globally disabled + + // AllowedRepoUnitGroups contains the units that have been globally enabled, + // with mutually exclusive units grouped together. + AllowedRepoUnitGroups = [][]Type{} +) + +// DisabledRepoUnitsGet returns the globally disabled units, it is a quick patch to fix data-race during testing. +// Because the queue worker might read when a test is mocking the value. FIXME: refactor to a clear solution later. +func DisabledRepoUnitsGet() []Type { + v := disabledRepoUnitsAtomic.Load() + if v == nil { + return nil + } + return *v +} + +func DisabledRepoUnitsSet(v []Type) { + disabledRepoUnitsAtomic.Store(&v) +} + +// Get valid set of default repository units from settings +func validateDefaultRepoUnits(defaultUnits, settingDefaultUnits []Type) []Type { + units := defaultUnits + + // Use setting if not empty + if len(settingDefaultUnits) > 0 { + units = make([]Type, 0, len(settingDefaultUnits)) + for _, settingUnit := range settingDefaultUnits { + if !settingUnit.CanBeDefault() { + log.Warn("Not allowed as default unit: %s", settingUnit.String()) + continue + } + units = append(units, settingUnit) + } + } + + // Remove disabled units + for _, disabledUnit := range DisabledRepoUnitsGet() { + for i, unit := range units { + if unit == disabledUnit { + units = append(units[:i], units[i+1:]...) + } + } + } + + return units +} + +// LoadUnitConfig load units from settings +func LoadUnitConfig() error { + disabledRepoUnits, invalidKeys := FindUnitTypes(setting.Repository.DisabledRepoUnits...) + if len(invalidKeys) > 0 { + log.Warn("Invalid keys in disabled repo units: %s", strings.Join(invalidKeys, ", ")) + } + DisabledRepoUnitsSet(disabledRepoUnits) + + setDefaultRepoUnits, invalidKeys := FindUnitTypes(setting.Repository.DefaultRepoUnits...) + if len(invalidKeys) > 0 { + log.Warn("Invalid keys in default repo units: %s", strings.Join(invalidKeys, ", ")) + } + DefaultRepoUnits = validateDefaultRepoUnits(DefaultRepoUnits, setDefaultRepoUnits) + if len(DefaultRepoUnits) == 0 { + return errors.New("no default repository units found") + } + setDefaultForkRepoUnits, invalidKeys := FindUnitTypes(setting.Repository.DefaultForkRepoUnits...) + if len(invalidKeys) > 0 { + log.Warn("Invalid keys in default fork repo units: %s", strings.Join(invalidKeys, ", ")) + } + DefaultForkRepoUnits = validateDefaultRepoUnits(DefaultForkRepoUnits, setDefaultForkRepoUnits) + if len(DefaultForkRepoUnits) == 0 { + return errors.New("no default fork repository units found") + } + + // Collect the allowed repo unit groups. Mutually exclusive units are + // grouped together. + AllowedRepoUnitGroups = [][]Type{} + for _, unit := range []Type{ + TypeCode, + TypePullRequests, + TypeProjects, + TypePackages, + TypeActions, + } { + // If unit is globally disabled, ignore it. + if unit.UnitGlobalDisabled() { + continue + } + + // If it is allowed, add it to the group list. + AllowedRepoUnitGroups = append(AllowedRepoUnitGroups, []Type{unit}) + } + + addMutuallyExclusiveGroup := func(unit1, unit2 Type) { + var list []Type + + if !unit1.UnitGlobalDisabled() { + list = append(list, unit1) + } + + if !unit2.UnitGlobalDisabled() { + list = append(list, unit2) + } + + if len(list) > 0 { + AllowedRepoUnitGroups = append(AllowedRepoUnitGroups, list) + } + } + + addMutuallyExclusiveGroup(TypeIssues, TypeExternalTracker) + addMutuallyExclusiveGroup(TypeWiki, TypeExternalWiki) + + return nil +} + +// UnitGlobalDisabled checks if unit type is global disabled +func (u Type) UnitGlobalDisabled() bool { + for _, ud := range DisabledRepoUnitsGet() { + if u == ud { + return true + } + } + return false +} + +// CanBeDefault checks if the unit type can be a default repo unit +func (u *Type) CanBeDefault() bool { + for _, nadU := range NotAllowedDefaultRepoUnits { + if *u == nadU { + return false + } + } + return true +} + +// Unit is a section of one repository +type Unit struct { + Type Type + Name string + NameKey string + URI string + DescKey string + Idx int + MaxAccessMode perm.AccessMode // The max access mode of the unit. i.e. Read means this unit can only be read. +} + +// IsLessThan compares order of two units +func (u Unit) IsLessThan(unit Unit) bool { + if (u.Type == TypeExternalTracker || u.Type == TypeExternalWiki) && unit.Type != TypeExternalTracker && unit.Type != TypeExternalWiki { + return false + } + return u.Idx < unit.Idx +} + +// MaxPerm returns the max perms of this unit +func (u Unit) MaxPerm() perm.AccessMode { + if u.Type == TypeExternalTracker || u.Type == TypeExternalWiki { + return perm.AccessModeRead + } + return perm.AccessModeAdmin +} + +// Enumerate all the units +var ( + UnitCode = Unit{ + TypeCode, + "code", + "repo.code", + "/", + "repo.code.desc", + 0, + perm.AccessModeOwner, + } + + UnitIssues = Unit{ + TypeIssues, + "issues", + "repo.issues", + "/issues", + "repo.issues.desc", + 1, + perm.AccessModeOwner, + } + + UnitExternalTracker = Unit{ + TypeExternalTracker, + "ext_issues", + "repo.ext_issues", + "/issues", + "repo.ext_issues.desc", + 1, + perm.AccessModeRead, + } + + UnitPullRequests = Unit{ + TypePullRequests, + "pulls", + "repo.pulls", + "/pulls", + "repo.pulls.desc", + 2, + perm.AccessModeOwner, + } + + UnitReleases = Unit{ + TypeReleases, + "releases", + "repo.releases", + "/releases", + "repo.releases.desc", + 3, + perm.AccessModeOwner, + } + + UnitWiki = Unit{ + TypeWiki, + "wiki", + "repo.wiki", + "/wiki", + "repo.wiki.desc", + 4, + perm.AccessModeOwner, + } + + UnitExternalWiki = Unit{ + TypeExternalWiki, + "ext_wiki", + "repo.ext_wiki", + "/wiki", + "repo.ext_wiki.desc", + 4, + perm.AccessModeRead, + } + + UnitProjects = Unit{ + TypeProjects, + "projects", + "repo.projects", + "/projects", + "repo.projects.desc", + 5, + perm.AccessModeOwner, + } + + UnitPackages = Unit{ + TypePackages, + "packages", + "repo.packages", + "/packages", + "packages.desc", + 6, + perm.AccessModeRead, + } + + UnitActions = Unit{ + TypeActions, + "actions", + "repo.actions", + "/actions", + "actions.unit.desc", + 7, + perm.AccessModeOwner, + } + + // Units contains all the units + Units = map[Type]Unit{ + TypeCode: UnitCode, + TypeIssues: UnitIssues, + TypeExternalTracker: UnitExternalTracker, + TypePullRequests: UnitPullRequests, + TypeReleases: UnitReleases, + TypeWiki: UnitWiki, + TypeExternalWiki: UnitExternalWiki, + TypeProjects: UnitProjects, + TypePackages: UnitPackages, + TypeActions: UnitActions, + } +) + +// FindUnitTypes give the unit key names and return valid unique units and invalid keys +func FindUnitTypes(nameKeys ...string) (res []Type, invalidKeys []string) { + m := make(container.Set[Type]) + for _, key := range nameKeys { + t := TypeFromKey(key) + if t == TypeInvalid { + invalidKeys = append(invalidKeys, key) + } else if m.Add(t) { + res = append(res, t) + } + } + return res, invalidKeys +} + +// TypeFromKey give the unit key name and return unit +func TypeFromKey(nameKey string) Type { + for t, u := range Units { + if strings.EqualFold(nameKey, u.NameKey) { + return t + } + } + return TypeInvalid +} + +// AllUnitKeyNames returns all unit key names +func AllUnitKeyNames() []string { + res := make([]string, 0, len(Units)) + for _, u := range Units { + res = append(res, u.NameKey) + } + return res +} + +// MinUnitAccessMode returns the minial permission of the permission map +func MinUnitAccessMode(unitsMap map[Type]perm.AccessMode) perm.AccessMode { + res := perm.AccessModeNone + for t, mode := range unitsMap { + // Don't allow `TypeExternal{Tracker,Wiki}` to influence this as they can only be set to READ perms. + if t == TypeExternalTracker || t == TypeExternalWiki { + continue + } + + // get the minial permission great than AccessModeNone except all are AccessModeNone + if mode > perm.AccessModeNone && (res == perm.AccessModeNone || mode < res) { + res = mode + } + } + return res +} diff --git a/models/unit/unit_test.go b/models/unit/unit_test.go new file mode 100644 index 0000000..a739677 --- /dev/null +++ b/models/unit/unit_test.go @@ -0,0 +1,96 @@ +// Copyright 2023 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package unit + +import ( + "testing" + + "code.gitea.io/gitea/modules/setting" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLoadUnitConfig(t *testing.T) { + t.Run("regular", func(t *testing.T) { + defer func(disabledRepoUnits, defaultRepoUnits, defaultForkRepoUnits []Type) { + DisabledRepoUnitsSet(disabledRepoUnits) + DefaultRepoUnits = defaultRepoUnits + DefaultForkRepoUnits = defaultForkRepoUnits + }(DisabledRepoUnitsGet(), DefaultRepoUnits, DefaultForkRepoUnits) + defer func(disabledRepoUnits, defaultRepoUnits, defaultForkRepoUnits []string) { + setting.Repository.DisabledRepoUnits = disabledRepoUnits + setting.Repository.DefaultRepoUnits = defaultRepoUnits + setting.Repository.DefaultForkRepoUnits = defaultForkRepoUnits + }(setting.Repository.DisabledRepoUnits, setting.Repository.DefaultRepoUnits, setting.Repository.DefaultForkRepoUnits) + + setting.Repository.DisabledRepoUnits = []string{"repo.issues"} + setting.Repository.DefaultRepoUnits = []string{"repo.code", "repo.releases", "repo.issues", "repo.pulls"} + setting.Repository.DefaultForkRepoUnits = []string{"repo.releases"} + require.NoError(t, LoadUnitConfig()) + assert.Equal(t, []Type{TypeIssues}, DisabledRepoUnitsGet()) + assert.Equal(t, []Type{TypeCode, TypeReleases, TypePullRequests}, DefaultRepoUnits) + assert.Equal(t, []Type{TypeReleases}, DefaultForkRepoUnits) + }) + t.Run("invalid", func(t *testing.T) { + defer func(disabledRepoUnits, defaultRepoUnits, defaultForkRepoUnits []Type) { + DisabledRepoUnitsSet(disabledRepoUnits) + DefaultRepoUnits = defaultRepoUnits + DefaultForkRepoUnits = defaultForkRepoUnits + }(DisabledRepoUnitsGet(), DefaultRepoUnits, DefaultForkRepoUnits) + defer func(disabledRepoUnits, defaultRepoUnits, defaultForkRepoUnits []string) { + setting.Repository.DisabledRepoUnits = disabledRepoUnits + setting.Repository.DefaultRepoUnits = defaultRepoUnits + setting.Repository.DefaultForkRepoUnits = defaultForkRepoUnits + }(setting.Repository.DisabledRepoUnits, setting.Repository.DefaultRepoUnits, setting.Repository.DefaultForkRepoUnits) + + setting.Repository.DisabledRepoUnits = []string{"repo.issues", "invalid.1"} + setting.Repository.DefaultRepoUnits = []string{"repo.code", "invalid.2", "repo.releases", "repo.issues", "repo.pulls"} + setting.Repository.DefaultForkRepoUnits = []string{"invalid.3", "repo.releases"} + require.NoError(t, LoadUnitConfig()) + assert.Equal(t, []Type{TypeIssues}, DisabledRepoUnitsGet()) + assert.Equal(t, []Type{TypeCode, TypeReleases, TypePullRequests}, DefaultRepoUnits) + assert.Equal(t, []Type{TypeReleases}, DefaultForkRepoUnits) + }) + t.Run("duplicate", func(t *testing.T) { + defer func(disabledRepoUnits, defaultRepoUnits, defaultForkRepoUnits []Type) { + DisabledRepoUnitsSet(disabledRepoUnits) + DefaultRepoUnits = defaultRepoUnits + DefaultForkRepoUnits = defaultForkRepoUnits + }(DisabledRepoUnitsGet(), DefaultRepoUnits, DefaultForkRepoUnits) + defer func(disabledRepoUnits, defaultRepoUnits, defaultForkRepoUnits []string) { + setting.Repository.DisabledRepoUnits = disabledRepoUnits + setting.Repository.DefaultRepoUnits = defaultRepoUnits + setting.Repository.DefaultForkRepoUnits = defaultForkRepoUnits + }(setting.Repository.DisabledRepoUnits, setting.Repository.DefaultRepoUnits, setting.Repository.DefaultForkRepoUnits) + + setting.Repository.DisabledRepoUnits = []string{"repo.issues", "repo.issues"} + setting.Repository.DefaultRepoUnits = []string{"repo.code", "repo.releases", "repo.issues", "repo.pulls", "repo.code"} + setting.Repository.DefaultForkRepoUnits = []string{"repo.releases", "repo.releases"} + require.NoError(t, LoadUnitConfig()) + assert.Equal(t, []Type{TypeIssues}, DisabledRepoUnitsGet()) + assert.Equal(t, []Type{TypeCode, TypeReleases, TypePullRequests}, DefaultRepoUnits) + assert.Equal(t, []Type{TypeReleases}, DefaultForkRepoUnits) + }) + t.Run("empty_default", func(t *testing.T) { + defer func(disabledRepoUnits, defaultRepoUnits, defaultForkRepoUnits []Type) { + DisabledRepoUnitsSet(disabledRepoUnits) + DefaultRepoUnits = defaultRepoUnits + DefaultForkRepoUnits = defaultForkRepoUnits + }(DisabledRepoUnitsGet(), DefaultRepoUnits, DefaultForkRepoUnits) + defer func(disabledRepoUnits, defaultRepoUnits, defaultForkRepoUnits []string) { + setting.Repository.DisabledRepoUnits = disabledRepoUnits + setting.Repository.DefaultRepoUnits = defaultRepoUnits + setting.Repository.DefaultForkRepoUnits = defaultForkRepoUnits + }(setting.Repository.DisabledRepoUnits, setting.Repository.DefaultRepoUnits, setting.Repository.DefaultForkRepoUnits) + + setting.Repository.DisabledRepoUnits = []string{"repo.issues", "repo.issues"} + setting.Repository.DefaultRepoUnits = []string{} + setting.Repository.DefaultForkRepoUnits = []string{"repo.releases", "repo.releases"} + require.NoError(t, LoadUnitConfig()) + assert.Equal(t, []Type{TypeIssues}, DisabledRepoUnitsGet()) + assert.ElementsMatch(t, []Type{TypeCode, TypePullRequests, TypeReleases, TypeWiki, TypePackages, TypeProjects, TypeActions}, DefaultRepoUnits) + assert.Equal(t, []Type{TypeReleases}, DefaultForkRepoUnits) + }) +} diff --git a/models/unittest/consistency.go b/models/unittest/consistency.go new file mode 100644 index 0000000..4e26de7 --- /dev/null +++ b/models/unittest/consistency.go @@ -0,0 +1,192 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package unittest + +import ( + "reflect" + "strconv" + "strings" + "testing" + + "code.gitea.io/gitea/models/db" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "xorm.io/builder" +) + +const ( + // these const values are copied from `models` package to prevent from cycle-import + modelsUserTypeOrganization = 1 + modelsRepoWatchModeDont = 2 + modelsCommentTypeComment = 0 +) + +var consistencyCheckMap = make(map[string]func(t *testing.T, bean any)) + +// CheckConsistencyFor test that all matching database entries are consistent +func CheckConsistencyFor(t *testing.T, beansToCheck ...any) { + for _, bean := range beansToCheck { + sliceType := reflect.SliceOf(reflect.TypeOf(bean)) + sliceValue := reflect.MakeSlice(sliceType, 0, 10) + + ptrToSliceValue := reflect.New(sliceType) + ptrToSliceValue.Elem().Set(sliceValue) + + require.NoError(t, db.GetEngine(db.DefaultContext).Table(bean).Find(ptrToSliceValue.Interface())) + sliceValue = ptrToSliceValue.Elem() + + for i := 0; i < sliceValue.Len(); i++ { + entity := sliceValue.Index(i).Interface() + checkForConsistency(t, entity) + } + } +} + +func checkForConsistency(t *testing.T, bean any) { + tb, err := db.TableInfo(bean) + require.NoError(t, err) + f := consistencyCheckMap[tb.Name] + if f == nil { + assert.FailNow(t, "unknown bean type: %#v", bean) + } + f(t, bean) +} + +func init() { + parseBool := func(v string) bool { + b, _ := strconv.ParseBool(v) + return b + } + parseInt := func(v string) int { + i, _ := strconv.Atoi(v) + return i + } + + checkForUserConsistency := func(t *testing.T, bean any) { + user := reflectionWrap(bean) + AssertCountByCond(t, "repository", builder.Eq{"owner_id": user.int("ID")}, user.int("NumRepos")) + AssertCountByCond(t, "star", builder.Eq{"uid": user.int("ID")}, user.int("NumStars")) + AssertCountByCond(t, "org_user", builder.Eq{"org_id": user.int("ID")}, user.int("NumMembers")) + AssertCountByCond(t, "team", builder.Eq{"org_id": user.int("ID")}, user.int("NumTeams")) + AssertCountByCond(t, "follow", builder.Eq{"user_id": user.int("ID")}, user.int("NumFollowing")) + AssertCountByCond(t, "follow", builder.Eq{"follow_id": user.int("ID")}, user.int("NumFollowers")) + if user.int("Type") != modelsUserTypeOrganization { + assert.EqualValues(t, 0, user.int("NumMembers"), "Unexpected number of members for user id: %d", user.int("ID")) + assert.EqualValues(t, 0, user.int("NumTeams"), "Unexpected number of teams for user id: %d", user.int("ID")) + } + } + + checkForRepoConsistency := func(t *testing.T, bean any) { + repo := reflectionWrap(bean) + assert.Equal(t, repo.str("LowerName"), strings.ToLower(repo.str("Name")), "repo: %+v", repo) + AssertCountByCond(t, "star", builder.Eq{"repo_id": repo.int("ID")}, repo.int("NumStars")) + AssertCountByCond(t, "milestone", builder.Eq{"repo_id": repo.int("ID")}, repo.int("NumMilestones")) + AssertCountByCond(t, "repository", builder.Eq{"fork_id": repo.int("ID")}, repo.int("NumForks")) + if repo.bool("IsFork") { + AssertExistsAndLoadMap(t, "repository", builder.Eq{"id": repo.int("ForkID")}) + } + + actual := GetCountByCond(t, "watch", builder.Eq{"repo_id": repo.int("ID")}. + And(builder.Neq{"mode": modelsRepoWatchModeDont})) + assert.EqualValues(t, repo.int("NumWatches"), actual, + "Unexpected number of watches for repo id: %d", repo.int("ID")) + + actual = GetCountByCond(t, "issue", builder.Eq{"is_pull": false, "repo_id": repo.int("ID")}) + assert.EqualValues(t, repo.int("NumIssues"), actual, + "Unexpected number of issues for repo id: %d", repo.int("ID")) + + actual = GetCountByCond(t, "issue", builder.Eq{"is_pull": false, "is_closed": true, "repo_id": repo.int("ID")}) + assert.EqualValues(t, repo.int("NumClosedIssues"), actual, + "Unexpected number of closed issues for repo id: %d", repo.int("ID")) + + actual = GetCountByCond(t, "issue", builder.Eq{"is_pull": true, "repo_id": repo.int("ID")}) + assert.EqualValues(t, repo.int("NumPulls"), actual, + "Unexpected number of pulls for repo id: %d", repo.int("ID")) + + actual = GetCountByCond(t, "issue", builder.Eq{"is_pull": true, "is_closed": true, "repo_id": repo.int("ID")}) + assert.EqualValues(t, repo.int("NumClosedPulls"), actual, + "Unexpected number of closed pulls for repo id: %d", repo.int("ID")) + + actual = GetCountByCond(t, "milestone", builder.Eq{"is_closed": true, "repo_id": repo.int("ID")}) + assert.EqualValues(t, repo.int("NumClosedMilestones"), actual, + "Unexpected number of closed milestones for repo id: %d", repo.int("ID")) + } + + checkForIssueConsistency := func(t *testing.T, bean any) { + issue := reflectionWrap(bean) + typeComment := modelsCommentTypeComment + actual := GetCountByCond(t, "comment", builder.Eq{"`type`": typeComment, "issue_id": issue.int("ID")}) + assert.EqualValues(t, issue.int("NumComments"), actual, "Unexpected number of comments for issue id: %d", issue.int("ID")) + if issue.bool("IsPull") { + prRow := AssertExistsAndLoadMap(t, "pull_request", builder.Eq{"issue_id": issue.int("ID")}) + assert.EqualValues(t, parseInt(prRow["index"]), issue.int("Index"), "Unexpected index for issue id: %d", issue.int("ID")) + } + } + + checkForPullRequestConsistency := func(t *testing.T, bean any) { + pr := reflectionWrap(bean) + issueRow := AssertExistsAndLoadMap(t, "issue", builder.Eq{"id": pr.int("IssueID")}) + assert.True(t, parseBool(issueRow["is_pull"])) + assert.EqualValues(t, parseInt(issueRow["index"]), pr.int("Index"), "Unexpected index for pull request id: %d", pr.int("ID")) + } + + checkForMilestoneConsistency := func(t *testing.T, bean any) { + milestone := reflectionWrap(bean) + AssertCountByCond(t, "issue", builder.Eq{"milestone_id": milestone.int("ID")}, milestone.int("NumIssues")) + + actual := GetCountByCond(t, "issue", builder.Eq{"is_closed": true, "milestone_id": milestone.int("ID")}) + assert.EqualValues(t, milestone.int("NumClosedIssues"), actual, "Unexpected number of closed issues for milestone id: %d", milestone.int("ID")) + + completeness := 0 + if milestone.int("NumIssues") > 0 { + completeness = milestone.int("NumClosedIssues") * 100 / milestone.int("NumIssues") + } + assert.Equal(t, completeness, milestone.int("Completeness")) + } + + checkForLabelConsistency := func(t *testing.T, bean any) { + label := reflectionWrap(bean) + issueLabels, err := db.GetEngine(db.DefaultContext).Table("issue_label"). + Where(builder.Eq{"label_id": label.int("ID")}). + Query() + require.NoError(t, err) + + assert.Len(t, issueLabels, label.int("NumIssues"), "Unexpected number of issue for label id: %d", label.int("ID")) + + issueIDs := make([]int, len(issueLabels)) + for i, issueLabel := range issueLabels { + issueIDs[i], _ = strconv.Atoi(string(issueLabel["issue_id"])) + } + + expected := int64(0) + if len(issueIDs) > 0 { + expected = GetCountByCond(t, "issue", builder.In("id", issueIDs).And(builder.Eq{"is_closed": true})) + } + assert.EqualValues(t, expected, label.int("NumClosedIssues"), "Unexpected number of closed issues for label id: %d", label.int("ID")) + } + + checkForTeamConsistency := func(t *testing.T, bean any) { + team := reflectionWrap(bean) + AssertCountByCond(t, "team_user", builder.Eq{"team_id": team.int("ID")}, team.int("NumMembers")) + AssertCountByCond(t, "team_repo", builder.Eq{"team_id": team.int("ID")}, team.int("NumRepos")) + } + + checkForActionConsistency := func(t *testing.T, bean any) { + action := reflectionWrap(bean) + if action.int("RepoID") != 1700 { // dangling intentional + repoRow := AssertExistsAndLoadMap(t, "repository", builder.Eq{"id": action.int("RepoID")}) + assert.Equal(t, parseBool(repoRow["is_private"]), action.bool("IsPrivate"), "Unexpected is_private field for action id: %d", action.int("ID")) + } + } + + consistencyCheckMap["user"] = checkForUserConsistency + consistencyCheckMap["repository"] = checkForRepoConsistency + consistencyCheckMap["issue"] = checkForIssueConsistency + consistencyCheckMap["pull_request"] = checkForPullRequestConsistency + consistencyCheckMap["milestone"] = checkForMilestoneConsistency + consistencyCheckMap["label"] = checkForLabelConsistency + consistencyCheckMap["team"] = checkForTeamConsistency + consistencyCheckMap["action"] = checkForActionConsistency +} diff --git a/models/unittest/fixtures.go b/models/unittest/fixtures.go new file mode 100644 index 0000000..63b26a0 --- /dev/null +++ b/models/unittest/fixtures.go @@ -0,0 +1,144 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +//nolint:forbidigo +package unittest + +import ( + "fmt" + "os" + "path/filepath" + "time" + + "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/modules/auth/password/hash" + "code.gitea.io/gitea/modules/setting" + + "github.com/go-testfixtures/testfixtures/v3" + "xorm.io/xorm" + "xorm.io/xorm/schemas" +) + +var fixturesLoader *testfixtures.Loader + +// GetXORMEngine gets the XORM engine +func GetXORMEngine(engine ...*xorm.Engine) (x *xorm.Engine) { + if len(engine) == 1 { + return engine[0] + } + return db.DefaultContext.(*db.Context).Engine().(*xorm.Engine) +} + +func OverrideFixtures(opts FixturesOptions, engine ...*xorm.Engine) func() { + old := fixturesLoader + if err := InitFixtures(opts, engine...); err != nil { + panic(err) + } + return func() { + fixturesLoader = old + } +} + +// InitFixtures initialize test fixtures for a test database +func InitFixtures(opts FixturesOptions, engine ...*xorm.Engine) (err error) { + e := GetXORMEngine(engine...) + var fixtureOptionFiles func(*testfixtures.Loader) error + if opts.Dir != "" { + fixtureOptionFiles = testfixtures.Directory(opts.Dir) + } else { + fixtureOptionFiles = testfixtures.Files(opts.Files...) + } + var fixtureOptionDirs []func(*testfixtures.Loader) error + if opts.Dirs != nil { + for _, dir := range opts.Dirs { + fixtureOptionDirs = append(fixtureOptionDirs, testfixtures.Directory(filepath.Join(opts.Base, dir))) + } + } + dialect := "unknown" + switch e.Dialect().URI().DBType { + case schemas.POSTGRES: + dialect = "postgres" + case schemas.MYSQL: + dialect = "mysql" + case schemas.SQLITE: + dialect = "sqlite3" + default: + fmt.Println("Unsupported RDBMS for integration tests") + os.Exit(1) + } + loaderOptions := []func(loader *testfixtures.Loader) error{ + testfixtures.Database(e.DB().DB), + testfixtures.Dialect(dialect), + testfixtures.DangerousSkipTestDatabaseCheck(), + fixtureOptionFiles, + } + loaderOptions = append(loaderOptions, fixtureOptionDirs...) + + if e.Dialect().URI().DBType == schemas.POSTGRES { + loaderOptions = append(loaderOptions, testfixtures.SkipResetSequences()) + } + + fixturesLoader, err = testfixtures.New(loaderOptions...) + if err != nil { + return err + } + + // register the dummy hash algorithm function used in the test fixtures + _ = hash.Register("dummy", hash.NewDummyHasher) + + setting.PasswordHashAlgo, _ = hash.SetDefaultPasswordHashAlgorithm("dummy") + + return err +} + +// LoadFixtures load fixtures for a test database +func LoadFixtures(engine ...*xorm.Engine) error { + e := GetXORMEngine(engine...) + var err error + // (doubt) database transaction conflicts could occur and result in ROLLBACK? just try for a few times. + for i := 0; i < 5; i++ { + if err = fixturesLoader.Load(); err == nil { + break + } + time.Sleep(200 * time.Millisecond) + } + if err != nil { + fmt.Printf("LoadFixtures failed after retries: %v\n", err) + } + // Now if we're running postgres we need to tell it to update the sequences + if e.Dialect().URI().DBType == schemas.POSTGRES { + results, err := e.QueryString(`SELECT 'SELECT SETVAL(' || + quote_literal(quote_ident(PGT.schemaname) || '.' || quote_ident(S.relname)) || + ', COALESCE(MAX(' ||quote_ident(C.attname)|| '), 1) ) FROM ' || + quote_ident(PGT.schemaname)|| '.'||quote_ident(T.relname)|| ';' + FROM pg_class AS S, + pg_depend AS D, + pg_class AS T, + pg_attribute AS C, + pg_tables AS PGT + WHERE S.relkind = 'S' + AND S.oid = D.objid + AND D.refobjid = T.oid + AND D.refobjid = C.attrelid + AND D.refobjsubid = C.attnum + AND T.relname = PGT.tablename + ORDER BY S.relname;`) + if err != nil { + fmt.Printf("Failed to generate sequence update: %v\n", err) + return err + } + for _, r := range results { + for _, value := range r { + _, err = e.Exec(value) + if err != nil { + fmt.Printf("Failed to update sequence: %s Error: %v\n", value, err) + return err + } + } + } + } + _ = hash.Register("dummy", hash.NewDummyHasher) + setting.PasswordHashAlgo, _ = hash.SetDefaultPasswordHashAlgorithm("dummy") + + return err +} diff --git a/models/unittest/fscopy.go b/models/unittest/fscopy.go new file mode 100644 index 0000000..74b12d5 --- /dev/null +++ b/models/unittest/fscopy.go @@ -0,0 +1,102 @@ +// Copyright 2022 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package unittest + +import ( + "errors" + "io" + "os" + "path" + "strings" + + "code.gitea.io/gitea/modules/util" +) + +// Copy copies file from source to target path. +func Copy(src, dest string) error { + // Gather file information to set back later. + si, err := os.Lstat(src) + if err != nil { + return err + } + + // Handle symbolic link. + if si.Mode()&os.ModeSymlink != 0 { + target, err := os.Readlink(src) + if err != nil { + return err + } + // NOTE: os.Chmod and os.Chtimes don't recognize symbolic link, + // which will lead "no such file or directory" error. + return os.Symlink(target, dest) + } + + sr, err := os.Open(src) + if err != nil { + return err + } + defer sr.Close() + + dw, err := os.Create(dest) + if err != nil { + return err + } + defer dw.Close() + + if _, err = io.Copy(dw, sr); err != nil { + return err + } + + // Set back file information. + if err = os.Chtimes(dest, si.ModTime(), si.ModTime()); err != nil { + return err + } + return os.Chmod(dest, si.Mode()) +} + +// CopyDir copy files recursively from source to target directory. +// +// The filter accepts a function that process the path info. +// and should return true for need to filter. +// +// It returns error when error occurs in underlying functions. +func CopyDir(srcPath, destPath string, filters ...func(filePath string) bool) error { + // Check if target directory exists. + if _, err := os.Stat(destPath); !errors.Is(err, os.ErrNotExist) { + return util.NewAlreadyExistErrorf("file or directory already exists: %s", destPath) + } + + err := os.MkdirAll(destPath, os.ModePerm) + if err != nil { + return err + } + + // Gather directory info. + infos, err := util.StatDir(srcPath, true) + if err != nil { + return err + } + + var filter func(filePath string) bool + if len(filters) > 0 { + filter = filters[0] + } + + for _, info := range infos { + if filter != nil && filter(info) { + continue + } + + curPath := path.Join(destPath, info) + if strings.HasSuffix(info, "/") { + err = os.MkdirAll(curPath, os.ModePerm) + } else { + err = Copy(path.Join(srcPath, info), curPath) + } + if err != nil { + return err + } + } + return nil +} diff --git a/models/unittest/mock_http.go b/models/unittest/mock_http.go new file mode 100644 index 0000000..aea2489 --- /dev/null +++ b/models/unittest/mock_http.go @@ -0,0 +1,115 @@ +// Copyright 2017 The Forgejo Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package unittest + +import ( + "bufio" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "os" + "slices" + "strings" + "testing" + + "code.gitea.io/gitea/modules/log" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Mocks HTTP responses of a third-party service (such as GitHub, GitLab…) +// This has two modes: +// - live mode: the requests made to the mock HTTP server are transmitted to the live +// service, and responses are saved as test data files +// - test mode: the responses to requests to the mock HTTP server are read from the +// test data files +func NewMockWebServer(t *testing.T, liveServerBaseURL, testDataDir string, liveMode bool) *httptest.Server { + mockServerBaseURL := "" + ignoredHeaders := []string{"cf-ray", "server", "date", "report-to", "nel", "x-request-id"} + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + path := NormalizedFullPath(r.URL) + log.Info("Mock HTTP Server: got request for path %s", r.URL.Path) + // TODO check request method (support POST?) + fixturePath := fmt.Sprintf("%s/%s_%s", testDataDir, r.Method, url.PathEscape(path)) + if liveMode { + liveURL := fmt.Sprintf("%s%s", liveServerBaseURL, path) + + request, err := http.NewRequest(r.Method, liveURL, nil) + require.NoError(t, err, "constructing an HTTP request to %s failed", liveURL) + for headerName, headerValues := range r.Header { + // do not pass on the encoding: let the Transport of the HTTP client handle that for us + if strings.ToLower(headerName) != "accept-encoding" { + for _, headerValue := range headerValues { + request.Header.Add(headerName, headerValue) + } + } + } + + response, err := http.DefaultClient.Do(request) + require.NoError(t, err, "HTTP request to %s failed: %s", liveURL) + assert.Less(t, response.StatusCode, 400, "unexpected status code for %s", liveURL) + + fixture, err := os.Create(fixturePath) + require.NoError(t, err, "failed to open the fixture file %s for writing", fixturePath) + defer fixture.Close() + fixtureWriter := bufio.NewWriter(fixture) + + for headerName, headerValues := range response.Header { + for _, headerValue := range headerValues { + if !slices.Contains(ignoredHeaders, strings.ToLower(headerName)) { + _, err := fixtureWriter.WriteString(fmt.Sprintf("%s: %s\n", headerName, headerValue)) + require.NoError(t, err, "writing the header of the HTTP response to the fixture file failed") + } + } + } + _, err = fixtureWriter.WriteString("\n") + require.NoError(t, err, "writing the header of the HTTP response to the fixture file failed") + fixtureWriter.Flush() + + log.Info("Mock HTTP Server: writing response to %s", fixturePath) + _, err = io.Copy(fixture, response.Body) + require.NoError(t, err, "writing the body of the HTTP response to %s failed", liveURL) + + err = fixture.Sync() + require.NoError(t, err, "writing the body of the HTTP response to the fixture file failed") + } + + fixture, err := os.ReadFile(fixturePath) + require.NoError(t, err, "missing mock HTTP response: "+fixturePath) + + w.WriteHeader(http.StatusOK) + + // replace any mention of the live HTTP service by the mocked host + stringFixture := strings.ReplaceAll(string(fixture), liveServerBaseURL, mockServerBaseURL) + // parse back the fixture file into a series of HTTP headers followed by response body + lines := strings.Split(stringFixture, "\n") + for idx, line := range lines { + colonIndex := strings.Index(line, ": ") + if colonIndex != -1 { + w.Header().Set(line[0:colonIndex], line[colonIndex+2:]) + } else { + // we reached the end of the headers (empty line), so what follows is the body + responseBody := strings.Join(lines[idx+1:], "\n") + _, err := w.Write([]byte(responseBody)) + require.NoError(t, err, "writing the body of the HTTP response failed") + break + } + } + })) + mockServerBaseURL = server.URL + return server +} + +func NormalizedFullPath(url *url.URL) string { + // TODO normalize path (remove trailing slash?) + // TODO normalize RawQuery (order query parameters?) + if len(url.Query()) == 0 { + return url.EscapedPath() + } + return fmt.Sprintf("%s?%s", url.EscapedPath(), url.RawQuery) +} diff --git a/models/unittest/reflection.go b/models/unittest/reflection.go new file mode 100644 index 0000000..141fc66 --- /dev/null +++ b/models/unittest/reflection.go @@ -0,0 +1,40 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package unittest + +import ( + "log" + "reflect" +) + +func fieldByName(v reflect.Value, field string) reflect.Value { + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + f := v.FieldByName(field) + if !f.IsValid() { + log.Panicf("can not read %s for %v", field, v) + } + return f +} + +type reflectionValue struct { + v reflect.Value +} + +func reflectionWrap(v any) *reflectionValue { + return &reflectionValue{v: reflect.ValueOf(v)} +} + +func (rv *reflectionValue) int(field string) int { + return int(fieldByName(rv.v, field).Int()) +} + +func (rv *reflectionValue) str(field string) string { + return fieldByName(rv.v, field).String() +} + +func (rv *reflectionValue) bool(field string) bool { + return fieldByName(rv.v, field).Bool() +} diff --git a/models/unittest/testdb.go b/models/unittest/testdb.go new file mode 100644 index 0000000..94a3253 --- /dev/null +++ b/models/unittest/testdb.go @@ -0,0 +1,267 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package unittest + +import ( + "context" + "fmt" + "log" + "os" + "path/filepath" + "strings" + "testing" + + "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/models/system" + "code.gitea.io/gitea/modules/auth/password/hash" + "code.gitea.io/gitea/modules/base" + "code.gitea.io/gitea/modules/git" + "code.gitea.io/gitea/modules/setting" + "code.gitea.io/gitea/modules/setting/config" + "code.gitea.io/gitea/modules/storage" + "code.gitea.io/gitea/modules/util" + + "github.com/stretchr/testify/require" + "xorm.io/xorm" + "xorm.io/xorm/names" +) + +// giteaRoot a path to the gitea root +var ( + giteaRoot string + fixturesDir string +) + +// FixturesDir returns the fixture directory +func FixturesDir() string { + return fixturesDir +} + +func fatalTestError(fmtStr string, args ...any) { + _, _ = fmt.Fprintf(os.Stderr, fmtStr, args...) + os.Exit(1) +} + +// InitSettings initializes config provider and load common settings for tests +func InitSettings() { + if setting.CustomConf == "" { + setting.CustomConf = filepath.Join(setting.CustomPath, "conf/app-unittest-tmp.ini") + _ = os.Remove(setting.CustomConf) + } + setting.InitCfgProvider(setting.CustomConf) + setting.LoadCommonSettings() + + if err := setting.PrepareAppDataPath(); err != nil { + log.Fatalf("Can not prepare APP_DATA_PATH: %v", err) + } + // register the dummy hash algorithm function used in the test fixtures + _ = hash.Register("dummy", hash.NewDummyHasher) + + setting.PasswordHashAlgo, _ = hash.SetDefaultPasswordHashAlgorithm("dummy") +} + +// TestOptions represents test options +type TestOptions struct { + FixtureFiles []string + SetUp func() error // SetUp will be executed before all tests in this package + TearDown func() error // TearDown will be executed after all tests in this package +} + +// MainTest a reusable TestMain(..) function for unit tests that need to use a +// test database. Creates the test database, and sets necessary settings. +func MainTest(m *testing.M, testOpts ...*TestOptions) { + searchDir, _ := os.Getwd() + for searchDir != "" { + if _, err := os.Stat(filepath.Join(searchDir, "go.mod")); err == nil { + break // The "go.mod" should be the one for Gitea repository + } + if dir := filepath.Dir(searchDir); dir == searchDir { + searchDir = "" // reaches the root of filesystem + } else { + searchDir = dir + } + } + if searchDir == "" { + panic("The tests should run in a Gitea repository, there should be a 'go.mod' in the root") + } + + giteaRoot = searchDir + setting.CustomPath = filepath.Join(giteaRoot, "custom") + InitSettings() + + fixturesDir = filepath.Join(giteaRoot, "models", "fixtures") + var opts FixturesOptions + if len(testOpts) == 0 || len(testOpts[0].FixtureFiles) == 0 { + opts.Dir = fixturesDir + } else { + for _, f := range testOpts[0].FixtureFiles { + if len(f) != 0 { + opts.Files = append(opts.Files, filepath.Join(fixturesDir, f)) + } + } + } + + if err := CreateTestEngine(opts); err != nil { + fatalTestError("Error creating test engine: %v\n", err) + } + + setting.AppURL = "https://try.gitea.io/" + setting.RunUser = "runuser" + setting.SSH.User = "sshuser" + setting.SSH.BuiltinServerUser = "builtinuser" + setting.SSH.Port = 3000 + setting.SSH.Domain = "try.gitea.io" + setting.Database.Type = "sqlite3" + setting.Repository.DefaultBranch = "master" // many test code still assume that default branch is called "master" + repoRootPath, err := os.MkdirTemp(os.TempDir(), "repos") + if err != nil { + fatalTestError("TempDir: %v\n", err) + } + setting.RepoRootPath = repoRootPath + appDataPath, err := os.MkdirTemp(os.TempDir(), "appdata") + if err != nil { + fatalTestError("TempDir: %v\n", err) + } + setting.AppDataPath = appDataPath + setting.AppWorkPath = giteaRoot + setting.StaticRootPath = giteaRoot + setting.GravatarSource = "https://secure.gravatar.com/avatar/" + + setting.Attachment.Storage.Path = filepath.Join(setting.AppDataPath, "attachments") + + setting.LFS.Storage.Path = filepath.Join(setting.AppDataPath, "lfs") + + setting.Avatar.Storage.Path = filepath.Join(setting.AppDataPath, "avatars") + + setting.RepoAvatar.Storage.Path = filepath.Join(setting.AppDataPath, "repo-avatars") + + setting.RepoArchive.Storage.Path = filepath.Join(setting.AppDataPath, "repo-archive") + + setting.Packages.Storage.Path = filepath.Join(setting.AppDataPath, "packages") + + setting.Actions.LogStorage.Path = filepath.Join(setting.AppDataPath, "actions_log") + + setting.Git.HomePath = filepath.Join(setting.AppDataPath, "home") + + setting.IncomingEmail.ReplyToAddress = "incoming+%{token}@localhost" + + config.SetDynGetter(system.NewDatabaseDynKeyGetter()) + + if err = storage.Init(); err != nil { + fatalTestError("storage.Init: %v\n", err) + } + if err = util.RemoveAll(repoRootPath); err != nil { + fatalTestError("util.RemoveAll: %v\n", err) + } + if err = CopyDir(filepath.Join(giteaRoot, "tests", "gitea-repositories-meta"), setting.RepoRootPath); err != nil { + fatalTestError("util.CopyDir: %v\n", err) + } + + if err = git.InitFull(context.Background()); err != nil { + fatalTestError("git.Init: %v\n", err) + } + ownerDirs, err := os.ReadDir(setting.RepoRootPath) + if err != nil { + fatalTestError("unable to read the new repo root: %v\n", err) + } + for _, ownerDir := range ownerDirs { + if !ownerDir.Type().IsDir() { + continue + } + repoDirs, err := os.ReadDir(filepath.Join(setting.RepoRootPath, ownerDir.Name())) + if err != nil { + fatalTestError("unable to read the new repo root: %v\n", err) + } + for _, repoDir := range repoDirs { + _ = os.MkdirAll(filepath.Join(setting.RepoRootPath, ownerDir.Name(), repoDir.Name(), "objects", "pack"), 0o755) + _ = os.MkdirAll(filepath.Join(setting.RepoRootPath, ownerDir.Name(), repoDir.Name(), "objects", "info"), 0o755) + _ = os.MkdirAll(filepath.Join(setting.RepoRootPath, ownerDir.Name(), repoDir.Name(), "refs", "heads"), 0o755) + _ = os.MkdirAll(filepath.Join(setting.RepoRootPath, ownerDir.Name(), repoDir.Name(), "refs", "tag"), 0o755) + } + } + + if len(testOpts) > 0 && testOpts[0].SetUp != nil { + if err := testOpts[0].SetUp(); err != nil { + fatalTestError("set up failed: %v\n", err) + } + } + + exitStatus := m.Run() + + if len(testOpts) > 0 && testOpts[0].TearDown != nil { + if err := testOpts[0].TearDown(); err != nil { + fatalTestError("tear down failed: %v\n", err) + } + } + + if err = util.RemoveAll(repoRootPath); err != nil { + fatalTestError("util.RemoveAll: %v\n", err) + } + if err = util.RemoveAll(appDataPath); err != nil { + fatalTestError("util.RemoveAll: %v\n", err) + } + os.Exit(exitStatus) +} + +// FixturesOptions fixtures needs to be loaded options +type FixturesOptions struct { + Dir string + Files []string + Dirs []string + Base string +} + +// CreateTestEngine creates a memory database and loads the fixture data from fixturesDir +func CreateTestEngine(opts FixturesOptions) error { + x, err := xorm.NewEngine("sqlite3", "file::memory:?cache=shared&_txlock=immediate") + if err != nil { + if strings.Contains(err.Error(), "unknown driver") { + return fmt.Errorf(`sqlite3 requires: import _ "github.com/mattn/go-sqlite3" or -tags sqlite,sqlite_unlock_notify%s%w`, "\n", err) + } + return err + } + x.SetMapper(names.GonicMapper{}) + db.SetDefaultEngine(context.Background(), x) + + if err = db.SyncAllTables(); err != nil { + return err + } + switch os.Getenv("GITEA_UNIT_TESTS_LOG_SQL") { + case "true", "1": + x.ShowSQL(true) + } + + return InitFixtures(opts) +} + +// PrepareTestDatabase load test fixtures into test database +func PrepareTestDatabase() error { + return LoadFixtures() +} + +// PrepareTestEnv prepares the environment for unit tests. Can only be called +// by tests that use the above MainTest(..) function. +func PrepareTestEnv(t testing.TB) { + require.NoError(t, PrepareTestDatabase()) + require.NoError(t, util.RemoveAll(setting.RepoRootPath)) + metaPath := filepath.Join(giteaRoot, "tests", "gitea-repositories-meta") + require.NoError(t, CopyDir(metaPath, setting.RepoRootPath)) + ownerDirs, err := os.ReadDir(setting.RepoRootPath) + require.NoError(t, err) + for _, ownerDir := range ownerDirs { + if !ownerDir.Type().IsDir() { + continue + } + repoDirs, err := os.ReadDir(filepath.Join(setting.RepoRootPath, ownerDir.Name())) + require.NoError(t, err) + for _, repoDir := range repoDirs { + _ = os.MkdirAll(filepath.Join(setting.RepoRootPath, ownerDir.Name(), repoDir.Name(), "objects", "pack"), 0o755) + _ = os.MkdirAll(filepath.Join(setting.RepoRootPath, ownerDir.Name(), repoDir.Name(), "objects", "info"), 0o755) + _ = os.MkdirAll(filepath.Join(setting.RepoRootPath, ownerDir.Name(), repoDir.Name(), "refs", "heads"), 0o755) + _ = os.MkdirAll(filepath.Join(setting.RepoRootPath, ownerDir.Name(), repoDir.Name(), "refs", "tag"), 0o755) + } + } + + base.SetupGiteaRoot() // Makes sure GITEA_ROOT is set +} diff --git a/models/unittest/unit_tests.go b/models/unittest/unit_tests.go new file mode 100644 index 0000000..157c676 --- /dev/null +++ b/models/unittest/unit_tests.go @@ -0,0 +1,164 @@ +// Copyright 2016 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package unittest + +import ( + "math" + "testing" + + "code.gitea.io/gitea/models/db" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "xorm.io/builder" +) + +// Code in this file is mainly used by unittest.CheckConsistencyFor, which is not in the unit test for various reasons. +// In the future if we can decouple CheckConsistencyFor into separate unit test code, then this file can be moved into unittest package too. + +// NonexistentID an ID that will never exist +const NonexistentID = int64(math.MaxInt64) + +type testCond struct { + query any + args []any +} + +type testOrderBy string + +// Cond create a condition with arguments for a test +func Cond(query any, args ...any) any { + return &testCond{query: query, args: args} +} + +// OrderBy creates "ORDER BY" a test query +func OrderBy(orderBy string) any { + return testOrderBy(orderBy) +} + +func whereOrderConditions(e db.Engine, conditions []any) db.Engine { + orderBy := "id" // query must have the "ORDER BY", otherwise the result is not deterministic + for _, condition := range conditions { + switch cond := condition.(type) { + case *testCond: + e = e.Where(cond.query, cond.args...) + case testOrderBy: + orderBy = string(cond) + default: + e = e.Where(cond) + } + } + return e.OrderBy(orderBy) +} + +// LoadBeanIfExists loads beans from fixture database if exist +func LoadBeanIfExists(bean any, conditions ...any) (bool, error) { + e := db.GetEngine(db.DefaultContext) + return whereOrderConditions(e, conditions).Get(bean) +} + +// BeanExists for testing, check if a bean exists +func BeanExists(t testing.TB, bean any, conditions ...any) bool { + exists, err := LoadBeanIfExists(bean, conditions...) + require.NoError(t, err) + return exists +} + +// AssertExistsAndLoadBean assert that a bean exists and load it from the test database +func AssertExistsAndLoadBean[T any](t testing.TB, bean T, conditions ...any) T { + exists, err := LoadBeanIfExists(bean, conditions...) + require.NoError(t, err) + assert.True(t, exists, + "Expected to find %+v (of type %T, with conditions %+v), but did not", + bean, bean, conditions) + return bean +} + +// AssertExistsAndLoadMap assert that a row exists and load it from the test database +func AssertExistsAndLoadMap(t testing.TB, table string, conditions ...any) map[string]string { + e := db.GetEngine(db.DefaultContext).Table(table) + res, err := whereOrderConditions(e, conditions).Query() + require.NoError(t, err) + assert.Len(t, res, 1, + "Expected to find one row in %s (with conditions %+v), but found %d", + table, conditions, len(res), + ) + + if len(res) == 1 { + rec := map[string]string{} + for k, v := range res[0] { + rec[k] = string(v) + } + return rec + } + return nil +} + +// GetCount get the count of a bean +func GetCount(t testing.TB, bean any, conditions ...any) int { + e := db.GetEngine(db.DefaultContext) + for _, condition := range conditions { + switch cond := condition.(type) { + case *testCond: + e = e.Where(cond.query, cond.args...) + default: + e = e.Where(cond) + } + } + count, err := e.Count(bean) + require.NoError(t, err) + return int(count) +} + +// AssertNotExistsBean assert that a bean does not exist in the test database +func AssertNotExistsBean(t testing.TB, bean any, conditions ...any) { + exists, err := LoadBeanIfExists(bean, conditions...) + require.NoError(t, err) + assert.False(t, exists) +} + +// AssertExistsIf asserts that a bean exists or does not exist, depending on +// what is expected. +func AssertExistsIf(t testing.TB, expected bool, bean any, conditions ...any) { + exists, err := LoadBeanIfExists(bean, conditions...) + require.NoError(t, err) + assert.Equal(t, expected, exists) +} + +// AssertSuccessfulInsert assert that beans is successfully inserted +func AssertSuccessfulInsert(t testing.TB, beans ...any) { + err := db.Insert(db.DefaultContext, beans...) + require.NoError(t, err) +} + +// AssertSuccessfulDelete assert that beans is successfully deleted +func AssertSuccessfulDelete(t require.TestingT, beans ...any) { + err := db.DeleteBeans(db.DefaultContext, beans...) + require.NoError(t, err) +} + +// AssertCount assert the count of a bean +func AssertCount(t testing.TB, bean, expected any) bool { + return assert.EqualValues(t, expected, GetCount(t, bean)) +} + +// AssertInt64InRange assert value is in range [low, high] +func AssertInt64InRange(t testing.TB, low, high, value int64) { + assert.True(t, value >= low && value <= high, + "Expected value in range [%d, %d], found %d", low, high, value) +} + +// GetCountByCond get the count of database entries matching bean +func GetCountByCond(t testing.TB, tableName string, cond builder.Cond) int64 { + e := db.GetEngine(db.DefaultContext) + count, err := e.Table(tableName).Where(cond).Count() + require.NoError(t, err) + return count +} + +// AssertCountByCond test the count of database entries matching bean +func AssertCountByCond(t testing.TB, tableName string, cond builder.Cond, expected int) bool { + return assert.EqualValues(t, expected, GetCountByCond(t, tableName, cond), + "Failed consistency test, the counted bean (of table %s) was %+v", tableName, cond) +} |