summaryrefslogtreecommitdiffstats
path: root/models/unit
diff options
context:
space:
mode:
authorDaniel Baumann <daniel@debian.org>2024-10-18 20:33:49 +0200
committerDaniel Baumann <daniel@debian.org>2024-12-12 23:57:56 +0100
commite68b9d00a6e05b3a941f63ffb696f91e554ac5ec (patch)
tree97775d6c13b0f416af55314eb6a89ef792474615 /models/unit
parentInitial commit. (diff)
downloadforgejo-e68b9d00a6e05b3a941f63ffb696f91e554ac5ec.tar.xz
forgejo-e68b9d00a6e05b3a941f63ffb696f91e554ac5ec.zip
Adding upstream version 9.0.3.
Signed-off-by: Daniel Baumann <daniel@debian.org>
Diffstat (limited to '')
-rw-r--r--models/unit/unit.go437
-rw-r--r--models/unit/unit_test.go96
-rw-r--r--models/unittest/consistency.go192
-rw-r--r--models/unittest/fixtures.go144
-rw-r--r--models/unittest/fscopy.go102
-rw-r--r--models/unittest/mock_http.go115
-rw-r--r--models/unittest/reflection.go40
-rw-r--r--models/unittest/testdb.go267
-rw-r--r--models/unittest/unit_tests.go164
9 files changed, 1557 insertions, 0 deletions
diff --git a/models/unit/unit.go b/models/unit/unit.go
new file mode 100644
index 0000000..5a8b911
--- /dev/null
+++ b/models/unit/unit.go
@@ -0,0 +1,437 @@
+// Copyright 2017 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package unit
+
+import (
+ "errors"
+ "fmt"
+ "strings"
+ "sync/atomic"
+
+ "code.gitea.io/gitea/models/perm"
+ "code.gitea.io/gitea/modules/container"
+ "code.gitea.io/gitea/modules/log"
+ "code.gitea.io/gitea/modules/setting"
+)
+
+// Type is Unit's Type
+type Type int
+
+// Enumerate all the unit types. The numeric values are persisted (e.g. in
+// repo_unit rows), so new types must only ever be appended at the end.
+const (
+	TypeInvalid         Type = iota // 0 invalid
+	TypeCode                        // 1 code
+	TypeIssues                      // 2 issues
+	TypePullRequests                // 3 PRs
+	TypeReleases                    // 4 Releases
+	TypeWiki                        // 5 Wiki
+	TypeExternalWiki                // 6 ExternalWiki
+	TypeExternalTracker             // 7 ExternalTracker
+	TypeProjects                    // 8 Projects
+	TypePackages                    // 9 Packages
+	TypeActions                     // 10 Actions
+)
+
+// Value returns integer value for unit type, i.e. the raw enum value
+// as stored in the database.
+func (u Type) Value() int {
+	return int(u)
+}
+
+// String returns the developer-facing name of the unit type, or a
+// placeholder containing the raw value for unrecognized types.
+func (u Type) String() string {
+	names := map[Type]string{
+		TypeCode:            "TypeCode",
+		TypeIssues:          "TypeIssues",
+		TypePullRequests:    "TypePullRequests",
+		TypeReleases:        "TypeReleases",
+		TypeWiki:            "TypeWiki",
+		TypeExternalWiki:    "TypeExternalWiki",
+		TypeExternalTracker: "TypeExternalTracker",
+		TypeProjects:        "TypeProjects",
+		TypePackages:        "TypePackages",
+		TypeActions:         "TypeActions",
+	}
+	if name, ok := names[u]; ok {
+		return name
+	}
+	return fmt.Sprintf("Unknown Type %d", u)
+}
+
+// LogString returns a log-friendly representation such as
+// "<UnitType:2:TypeIssues>" (implements log.Stringer-style formatting).
+func (u Type) LogString() string {
+	return fmt.Sprintf("<UnitType:%d:%s>", u, u.String())
+}
+
+var (
+	// AllRepoUnitTypes contains all the unit types
+	AllRepoUnitTypes = []Type{
+		TypeCode,
+		TypeIssues,
+		TypePullRequests,
+		TypeReleases,
+		TypeWiki,
+		TypeExternalWiki,
+		TypeExternalTracker,
+		TypeProjects,
+		TypePackages,
+		TypeActions,
+	}
+
+	// DefaultRepoUnits contains the default unit types enabled for new repositories
+	DefaultRepoUnits = []Type{
+		TypeCode,
+		TypeIssues,
+		TypePullRequests,
+		TypeReleases,
+		TypeWiki,
+		TypeProjects,
+		TypePackages,
+		TypeActions,
+	}
+
+	// DefaultForkRepoUnits contains the default unit types for forks
+	DefaultForkRepoUnits = []Type{
+		TypeCode,
+		TypePullRequests,
+	}
+
+	// NotAllowedDefaultRepoUnits contains units that can't be default
+	// (they require per-repository configuration, e.g. an external URL)
+	NotAllowedDefaultRepoUnits = []Type{
+		TypeExternalWiki,
+		TypeExternalTracker,
+	}
+
+	disabledRepoUnitsAtomic atomic.Pointer[[]Type] // the units that have been globally disabled
+
+	// AllowedRepoUnitGroups contains the units that have been globally enabled,
+	// with mutually exclusive units grouped together.
+	AllowedRepoUnitGroups = [][]Type{}
+)
+
+// DisabledRepoUnitsGet returns the globally disabled units, it is a quick patch to fix data-race during testing.
+// Because the queue worker might read when a test is mocking the value. FIXME: refactor to a clear solution later.
+func DisabledRepoUnitsGet() []Type {
+	if p := disabledRepoUnitsAtomic.Load(); p != nil {
+		return *p
+	}
+	return nil
+}
+
+// DisabledRepoUnitsSet atomically replaces the globally disabled units;
+// counterpart of DisabledRepoUnitsGet (see the data-race note there).
+func DisabledRepoUnitsSet(v []Type) {
+	disabledRepoUnitsAtomic.Store(&v)
+}
+
+// validateDefaultRepoUnits returns the valid set of default repository units.
+// If settingDefaultUnits is non-empty it is used as the base list (dropping
+// entries that are not allowed to be defaults); otherwise defaultUnits is
+// used. Globally disabled units are always filtered out. The result is a
+// fresh slice, so neither input slice is ever mutated.
+func validateDefaultRepoUnits(defaultUnits, settingDefaultUnits []Type) []Type {
+	base := defaultUnits
+
+	// Use setting if not empty, skipping units that can't be a default.
+	if len(settingDefaultUnits) > 0 {
+		base = make([]Type, 0, len(settingDefaultUnits))
+		for _, settingUnit := range settingDefaultUnits {
+			if !settingUnit.CanBeDefault() {
+				log.Warn("Not allowed as default unit: %s", settingUnit.String())
+				continue
+			}
+			base = append(base, settingUnit)
+		}
+	}
+
+	// Remove disabled units by building a filtered copy. The previous
+	// in-place `append(units[:i], units[i+1:]...)` inside a range loop
+	// skipped the element shifted into position i and mutated the caller's
+	// backing array (e.g. the package-level DefaultRepoUnits).
+	disabled := make(map[Type]struct{})
+	for _, d := range DisabledRepoUnitsGet() {
+		disabled[d] = struct{}{}
+	}
+	units := make([]Type, 0, len(base))
+	for _, unit := range base {
+		if _, ok := disabled[unit]; !ok {
+			units = append(units, unit)
+		}
+	}
+
+	return units
+}
+
+// LoadUnitConfig load units from settings
+//
+// It populates the package-level DisabledRepoUnits, DefaultRepoUnits,
+// DefaultForkRepoUnits and AllowedRepoUnitGroups from setting.Repository.
+// Invalid unit keys in the settings are logged and skipped. It returns an
+// error if, after filtering, no default (or default fork) units remain.
+func LoadUnitConfig() error {
+	disabledRepoUnits, invalidKeys := FindUnitTypes(setting.Repository.DisabledRepoUnits...)
+	if len(invalidKeys) > 0 {
+		log.Warn("Invalid keys in disabled repo units: %s", strings.Join(invalidKeys, ", "))
+	}
+	DisabledRepoUnitsSet(disabledRepoUnits)
+
+	setDefaultRepoUnits, invalidKeys := FindUnitTypes(setting.Repository.DefaultRepoUnits...)
+	if len(invalidKeys) > 0 {
+		log.Warn("Invalid keys in default repo units: %s", strings.Join(invalidKeys, ", "))
+	}
+	DefaultRepoUnits = validateDefaultRepoUnits(DefaultRepoUnits, setDefaultRepoUnits)
+	if len(DefaultRepoUnits) == 0 {
+		return errors.New("no default repository units found")
+	}
+	setDefaultForkRepoUnits, invalidKeys := FindUnitTypes(setting.Repository.DefaultForkRepoUnits...)
+	if len(invalidKeys) > 0 {
+		log.Warn("Invalid keys in default fork repo units: %s", strings.Join(invalidKeys, ", "))
+	}
+	DefaultForkRepoUnits = validateDefaultRepoUnits(DefaultForkRepoUnits, setDefaultForkRepoUnits)
+	if len(DefaultForkRepoUnits) == 0 {
+		return errors.New("no default fork repository units found")
+	}
+
+	// Collect the allowed repo unit groups. Mutually exclusive units are
+	// grouped together.
+	AllowedRepoUnitGroups = [][]Type{}
+	for _, unit := range []Type{
+		TypeCode,
+		TypePullRequests,
+		TypeProjects,
+		TypePackages,
+		TypeActions,
+	} {
+		// If unit is globally disabled, ignore it.
+		if unit.UnitGlobalDisabled() {
+			continue
+		}
+
+		// If it is allowed, add it to the group list.
+		AllowedRepoUnitGroups = append(AllowedRepoUnitGroups, []Type{unit})
+	}
+
+	// A mutually exclusive pair (e.g. internal vs. external issue tracker)
+	// forms a single group with whichever members are not globally disabled.
+	addMutuallyExclusiveGroup := func(unit1, unit2 Type) {
+		var list []Type
+
+		if !unit1.UnitGlobalDisabled() {
+			list = append(list, unit1)
+		}
+
+		if !unit2.UnitGlobalDisabled() {
+			list = append(list, unit2)
+		}
+
+		if len(list) > 0 {
+			AllowedRepoUnitGroups = append(AllowedRepoUnitGroups, list)
+		}
+	}
+
+	addMutuallyExclusiveGroup(TypeIssues, TypeExternalTracker)
+	addMutuallyExclusiveGroup(TypeWiki, TypeExternalWiki)
+
+	return nil
+}
+
+// UnitGlobalDisabled checks if unit type is global disabled
+func (u Type) UnitGlobalDisabled() bool {
+	disabled := DisabledRepoUnitsGet()
+	for i := range disabled {
+		if disabled[i] == u {
+			return true
+		}
+	}
+	return false
+}
+
+// CanBeDefault checks if the unit type can be a default repo unit
+// (external tracker/wiki units cannot, as they need per-repo configuration).
+func (u *Type) CanBeDefault() bool {
+	current := *u
+	for _, notAllowed := range NotAllowedDefaultRepoUnits {
+		if current == notAllowed {
+			return false
+		}
+	}
+	return true
+}
+
+// Unit is a section of one repository
+type Unit struct {
+	Type          Type            // unit type, see the Type constants
+	Name          string          // short machine name, e.g. "issues"
+	NameKey       string          // translation key for the unit name
+	URI           string          // repo-relative URI of the unit's landing page
+	DescKey       string          // translation key for the unit description
+	Idx           int             // sort order, compared by IsLessThan
+	MaxAccessMode perm.AccessMode // The max access mode of the unit. i.e. Read means this unit can only be read.
+}
+
+// IsLessThan compares order of two units; external tracker/wiki units
+// always sort after every non-external unit, otherwise Idx decides.
+func (u Unit) IsLessThan(unit Unit) bool {
+	isExternal := func(t Type) bool {
+		return t == TypeExternalTracker || t == TypeExternalWiki
+	}
+	if isExternal(u.Type) && !isExternal(unit.Type) {
+		return false
+	}
+	return u.Idx < unit.Idx
+}
+
+// MaxPerm returns the max perms of this unit: external units can only be
+// read, everything else can be administered.
+func (u Unit) MaxPerm() perm.AccessMode {
+	switch u.Type {
+	case TypeExternalTracker, TypeExternalWiki:
+		return perm.AccessModeRead
+	default:
+		return perm.AccessModeAdmin
+	}
+}
+
+// Enumerate all the units. Field names are spelled out so a future change
+// to the Unit struct layout cannot silently mis-assign a value.
+var (
+	UnitCode = Unit{
+		Type:          TypeCode,
+		Name:          "code",
+		NameKey:       "repo.code",
+		URI:           "/",
+		DescKey:       "repo.code.desc",
+		Idx:           0,
+		MaxAccessMode: perm.AccessModeOwner,
+	}
+
+	UnitIssues = Unit{
+		Type:          TypeIssues,
+		Name:          "issues",
+		NameKey:       "repo.issues",
+		URI:           "/issues",
+		DescKey:       "repo.issues.desc",
+		Idx:           1,
+		MaxAccessMode: perm.AccessModeOwner,
+	}
+
+	UnitExternalTracker = Unit{
+		Type:          TypeExternalTracker,
+		Name:          "ext_issues",
+		NameKey:       "repo.ext_issues",
+		URI:           "/issues",
+		DescKey:       "repo.ext_issues.desc",
+		Idx:           1,
+		MaxAccessMode: perm.AccessModeRead,
+	}
+
+	UnitPullRequests = Unit{
+		Type:          TypePullRequests,
+		Name:          "pulls",
+		NameKey:       "repo.pulls",
+		URI:           "/pulls",
+		DescKey:       "repo.pulls.desc",
+		Idx:           2,
+		MaxAccessMode: perm.AccessModeOwner,
+	}
+
+	UnitReleases = Unit{
+		Type:          TypeReleases,
+		Name:          "releases",
+		NameKey:       "repo.releases",
+		URI:           "/releases",
+		DescKey:       "repo.releases.desc",
+		Idx:           3,
+		MaxAccessMode: perm.AccessModeOwner,
+	}
+
+	UnitWiki = Unit{
+		Type:          TypeWiki,
+		Name:          "wiki",
+		NameKey:       "repo.wiki",
+		URI:           "/wiki",
+		DescKey:       "repo.wiki.desc",
+		Idx:           4,
+		MaxAccessMode: perm.AccessModeOwner,
+	}
+
+	UnitExternalWiki = Unit{
+		Type:          TypeExternalWiki,
+		Name:          "ext_wiki",
+		NameKey:       "repo.ext_wiki",
+		URI:           "/wiki",
+		DescKey:       "repo.ext_wiki.desc",
+		Idx:           4,
+		MaxAccessMode: perm.AccessModeRead,
+	}
+
+	UnitProjects = Unit{
+		Type:          TypeProjects,
+		Name:          "projects",
+		NameKey:       "repo.projects",
+		URI:           "/projects",
+		DescKey:       "repo.projects.desc",
+		Idx:           5,
+		MaxAccessMode: perm.AccessModeOwner,
+	}
+
+	UnitPackages = Unit{
+		Type:          TypePackages,
+		Name:          "packages",
+		NameKey:       "repo.packages",
+		URI:           "/packages",
+		DescKey:       "packages.desc",
+		Idx:           6,
+		MaxAccessMode: perm.AccessModeRead,
+	}
+
+	UnitActions = Unit{
+		Type:          TypeActions,
+		Name:          "actions",
+		NameKey:       "repo.actions",
+		URI:           "/actions",
+		DescKey:       "actions.unit.desc",
+		Idx:           7,
+		MaxAccessMode: perm.AccessModeOwner,
+	}
+
+	// Units contains all the units
+	Units = map[Type]Unit{
+		TypeCode:            UnitCode,
+		TypeIssues:          UnitIssues,
+		TypeExternalTracker: UnitExternalTracker,
+		TypePullRequests:    UnitPullRequests,
+		TypeReleases:        UnitReleases,
+		TypeWiki:            UnitWiki,
+		TypeExternalWiki:    UnitExternalWiki,
+		TypeProjects:        UnitProjects,
+		TypePackages:        UnitPackages,
+		TypeActions:         UnitActions,
+	}
+)
+
+// FindUnitTypes give the unit key names and return valid unique units and invalid keys
+func FindUnitTypes(nameKeys ...string) (res []Type, invalidKeys []string) {
+	seen := make(container.Set[Type])
+	for _, key := range nameKeys {
+		switch t := TypeFromKey(key); t {
+		case TypeInvalid:
+			invalidKeys = append(invalidKeys, key)
+		default:
+			// Set.Add reports whether the value was newly inserted,
+			// which de-duplicates repeated keys while keeping order.
+			if seen.Add(t) {
+				res = append(res, t)
+			}
+		}
+	}
+	return res, invalidKeys
+}
+
+// TypeFromKey give the unit key name and return unit
+// (TypeInvalid when no unit matches; the comparison is case-insensitive).
+func TypeFromKey(nameKey string) Type {
+	for unitType, unit := range Units {
+		if strings.EqualFold(unit.NameKey, nameKey) {
+			return unitType
+		}
+	}
+	return TypeInvalid
+}
+
+// AllUnitKeyNames returns all unit key names
+// (order is unspecified because Units is a map).
+func AllUnitKeyNames() []string {
+	names := make([]string, 0, len(Units))
+	for _, unit := range Units {
+		names = append(names, unit.NameKey)
+	}
+	return names
+}
+
+// MinUnitAccessMode returns the minimal permission of the permission map
+func MinUnitAccessMode(unitsMap map[Type]perm.AccessMode) perm.AccessMode {
+	res := perm.AccessModeNone
+	for t, mode := range unitsMap {
+		// Don't allow `TypeExternal{Tracker,Wiki}` to influence this as they can only be set to READ perms.
+		if t == TypeExternalTracker || t == TypeExternalWiki {
+			continue
+		}
+
+		// get the minimal permission greater than AccessModeNone, unless all are AccessModeNone
+		if mode > perm.AccessModeNone && (res == perm.AccessModeNone || mode < res) {
+			res = mode
+		}
+	}
+	return res
+}
diff --git a/models/unit/unit_test.go b/models/unit/unit_test.go
new file mode 100644
index 0000000..a739677
--- /dev/null
+++ b/models/unit/unit_test.go
@@ -0,0 +1,96 @@
+// Copyright 2023 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package unit
+
+import (
+ "testing"
+
+ "code.gitea.io/gitea/modules/setting"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// TestLoadUnitConfig exercises LoadUnitConfig against regular, invalid,
+// duplicated and empty settings values.
+func TestLoadUnitConfig(t *testing.T) {
+	// saveState snapshots the package-level unit lists and the relevant
+	// setting.Repository fields, and registers a cleanup that restores them
+	// when the subtest finishes, so subtests cannot leak state into each
+	// other. This replaces four identical copies of save/restore defers.
+	saveState := func(t *testing.T) {
+		disabledRepoUnits := DisabledRepoUnitsGet()
+		defaultRepoUnits := DefaultRepoUnits
+		defaultForkRepoUnits := DefaultForkRepoUnits
+		settingDisabled := setting.Repository.DisabledRepoUnits
+		settingDefault := setting.Repository.DefaultRepoUnits
+		settingFork := setting.Repository.DefaultForkRepoUnits
+		t.Cleanup(func() {
+			DisabledRepoUnitsSet(disabledRepoUnits)
+			DefaultRepoUnits = defaultRepoUnits
+			DefaultForkRepoUnits = defaultForkRepoUnits
+			setting.Repository.DisabledRepoUnits = settingDisabled
+			setting.Repository.DefaultRepoUnits = settingDefault
+			setting.Repository.DefaultForkRepoUnits = settingFork
+		})
+	}
+
+	t.Run("regular", func(t *testing.T) {
+		saveState(t)
+
+		setting.Repository.DisabledRepoUnits = []string{"repo.issues"}
+		setting.Repository.DefaultRepoUnits = []string{"repo.code", "repo.releases", "repo.issues", "repo.pulls"}
+		setting.Repository.DefaultForkRepoUnits = []string{"repo.releases"}
+		require.NoError(t, LoadUnitConfig())
+		assert.Equal(t, []Type{TypeIssues}, DisabledRepoUnitsGet())
+		assert.Equal(t, []Type{TypeCode, TypeReleases, TypePullRequests}, DefaultRepoUnits)
+		assert.Equal(t, []Type{TypeReleases}, DefaultForkRepoUnits)
+	})
+	t.Run("invalid", func(t *testing.T) {
+		saveState(t)
+
+		// Unknown unit keys must be skipped with a warning, not abort loading.
+		setting.Repository.DisabledRepoUnits = []string{"repo.issues", "invalid.1"}
+		setting.Repository.DefaultRepoUnits = []string{"repo.code", "invalid.2", "repo.releases", "repo.issues", "repo.pulls"}
+		setting.Repository.DefaultForkRepoUnits = []string{"invalid.3", "repo.releases"}
+		require.NoError(t, LoadUnitConfig())
+		assert.Equal(t, []Type{TypeIssues}, DisabledRepoUnitsGet())
+		assert.Equal(t, []Type{TypeCode, TypeReleases, TypePullRequests}, DefaultRepoUnits)
+		assert.Equal(t, []Type{TypeReleases}, DefaultForkRepoUnits)
+	})
+	t.Run("duplicate", func(t *testing.T) {
+		saveState(t)
+
+		// Duplicated unit keys must be de-duplicated by FindUnitTypes.
+		setting.Repository.DisabledRepoUnits = []string{"repo.issues", "repo.issues"}
+		setting.Repository.DefaultRepoUnits = []string{"repo.code", "repo.releases", "repo.issues", "repo.pulls", "repo.code"}
+		setting.Repository.DefaultForkRepoUnits = []string{"repo.releases", "repo.releases"}
+		require.NoError(t, LoadUnitConfig())
+		assert.Equal(t, []Type{TypeIssues}, DisabledRepoUnitsGet())
+		assert.Equal(t, []Type{TypeCode, TypeReleases, TypePullRequests}, DefaultRepoUnits)
+		assert.Equal(t, []Type{TypeReleases}, DefaultForkRepoUnits)
+	})
+	t.Run("empty_default", func(t *testing.T) {
+		saveState(t)
+
+		// An empty DEFAULT_REPO_UNITS setting falls back to the built-in
+		// defaults, minus the globally disabled units.
+		setting.Repository.DisabledRepoUnits = []string{"repo.issues", "repo.issues"}
+		setting.Repository.DefaultRepoUnits = []string{}
+		setting.Repository.DefaultForkRepoUnits = []string{"repo.releases", "repo.releases"}
+		require.NoError(t, LoadUnitConfig())
+		assert.Equal(t, []Type{TypeIssues}, DisabledRepoUnitsGet())
+		assert.ElementsMatch(t, []Type{TypeCode, TypePullRequests, TypeReleases, TypeWiki, TypePackages, TypeProjects, TypeActions}, DefaultRepoUnits)
+		assert.Equal(t, []Type{TypeReleases}, DefaultForkRepoUnits)
+	})
+}
diff --git a/models/unittest/consistency.go b/models/unittest/consistency.go
new file mode 100644
index 0000000..4e26de7
--- /dev/null
+++ b/models/unittest/consistency.go
@@ -0,0 +1,192 @@
+// Copyright 2021 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package unittest
+
+import (
+ "reflect"
+ "strconv"
+ "strings"
+ "testing"
+
+ "code.gitea.io/gitea/models/db"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "xorm.io/builder"
+)
+
+const (
+	// these const values are copied from `models` package to prevent from cycle-import
+	modelsUserTypeOrganization = 1
+	modelsRepoWatchModeDont    = 2
+	modelsCommentTypeComment   = 0
+)
+
+// consistencyCheckMap maps a database table name to the checker registered
+// for it in init() below; checkForConsistency dispatches through it.
+var consistencyCheckMap = make(map[string]func(t *testing.T, bean any))
+
+// CheckConsistencyFor test that all matching database entries are consistent
+func CheckConsistencyFor(t *testing.T, beansToCheck ...any) {
+	for _, bean := range beansToCheck {
+		// Build a *[]T via reflection so xorm can Find into it regardless
+		// of the concrete bean type.
+		sliceType := reflect.SliceOf(reflect.TypeOf(bean))
+		sliceValue := reflect.MakeSlice(sliceType, 0, 10)
+
+		ptrToSliceValue := reflect.New(sliceType)
+		ptrToSliceValue.Elem().Set(sliceValue)
+
+		require.NoError(t, db.GetEngine(db.DefaultContext).Table(bean).Find(ptrToSliceValue.Interface()))
+		sliceValue = ptrToSliceValue.Elem()
+
+		// Run the registered consistency checker on every row found.
+		for i := 0; i < sliceValue.Len(); i++ {
+			entity := sliceValue.Index(i).Interface()
+			checkForConsistency(t, entity)
+		}
+	}
+}
+
+// checkForConsistency dispatches bean to the consistency checker registered
+// for its database table, failing the test for unknown bean types.
+func checkForConsistency(t *testing.T, bean any) {
+	tb, err := db.TableInfo(bean)
+	require.NoError(t, err)
+	f := consistencyCheckMap[tb.Name]
+	if f == nil {
+		// Pass the format string through msgAndArgs so testify formats it;
+		// as the failureMessage argument the `%#v` verb was printed
+		// verbatim instead of expanding the bean.
+		assert.FailNow(t, "unknown bean type", "%#v", bean)
+	}
+	f(t, bean)
+}
+
+// init registers one consistency checker per table. All checkers receive the
+// bean through reflectionWrap because the real model structs live in packages
+// that would create an import cycle with this one.
+func init() {
+	// Lenient parsers for values read back from raw DB rows; parse failures
+	// intentionally yield the zero value.
+	parseBool := func(v string) bool {
+		b, _ := strconv.ParseBool(v)
+		return b
+	}
+	parseInt := func(v string) int {
+		i, _ := strconv.Atoi(v)
+		return i
+	}
+
+	// user: counter columns must match the actual rows in related tables.
+	checkForUserConsistency := func(t *testing.T, bean any) {
+		user := reflectionWrap(bean)
+		AssertCountByCond(t, "repository", builder.Eq{"owner_id": user.int("ID")}, user.int("NumRepos"))
+		AssertCountByCond(t, "star", builder.Eq{"uid": user.int("ID")}, user.int("NumStars"))
+		AssertCountByCond(t, "org_user", builder.Eq{"org_id": user.int("ID")}, user.int("NumMembers"))
+		AssertCountByCond(t, "team", builder.Eq{"org_id": user.int("ID")}, user.int("NumTeams"))
+		AssertCountByCond(t, "follow", builder.Eq{"user_id": user.int("ID")}, user.int("NumFollowing"))
+		AssertCountByCond(t, "follow", builder.Eq{"follow_id": user.int("ID")}, user.int("NumFollowers"))
+		// Only organizations may have members/teams.
+		if user.int("Type") != modelsUserTypeOrganization {
+			assert.EqualValues(t, 0, user.int("NumMembers"), "Unexpected number of members for user id: %d", user.int("ID"))
+			assert.EqualValues(t, 0, user.int("NumTeams"), "Unexpected number of teams for user id: %d", user.int("ID"))
+		}
+	}
+
+	// repository: cached counters and the fork linkage must be consistent.
+	checkForRepoConsistency := func(t *testing.T, bean any) {
+		repo := reflectionWrap(bean)
+		assert.Equal(t, repo.str("LowerName"), strings.ToLower(repo.str("Name")), "repo: %+v", repo)
+		AssertCountByCond(t, "star", builder.Eq{"repo_id": repo.int("ID")}, repo.int("NumStars"))
+		AssertCountByCond(t, "milestone", builder.Eq{"repo_id": repo.int("ID")}, repo.int("NumMilestones"))
+		AssertCountByCond(t, "repository", builder.Eq{"fork_id": repo.int("ID")}, repo.int("NumForks"))
+		if repo.bool("IsFork") {
+			AssertExistsAndLoadMap(t, "repository", builder.Eq{"id": repo.int("ForkID")})
+		}
+
+		// "don't watch" rows exist in the watch table but must not count.
+		actual := GetCountByCond(t, "watch", builder.Eq{"repo_id": repo.int("ID")}.
+			And(builder.Neq{"mode": modelsRepoWatchModeDont}))
+		assert.EqualValues(t, repo.int("NumWatches"), actual,
+			"Unexpected number of watches for repo id: %d", repo.int("ID"))
+
+		actual = GetCountByCond(t, "issue", builder.Eq{"is_pull": false, "repo_id": repo.int("ID")})
+		assert.EqualValues(t, repo.int("NumIssues"), actual,
+			"Unexpected number of issues for repo id: %d", repo.int("ID"))
+
+		actual = GetCountByCond(t, "issue", builder.Eq{"is_pull": false, "is_closed": true, "repo_id": repo.int("ID")})
+		assert.EqualValues(t, repo.int("NumClosedIssues"), actual,
+			"Unexpected number of closed issues for repo id: %d", repo.int("ID"))
+
+		actual = GetCountByCond(t, "issue", builder.Eq{"is_pull": true, "repo_id": repo.int("ID")})
+		assert.EqualValues(t, repo.int("NumPulls"), actual,
+			"Unexpected number of pulls for repo id: %d", repo.int("ID"))
+
+		actual = GetCountByCond(t, "issue", builder.Eq{"is_pull": true, "is_closed": true, "repo_id": repo.int("ID")})
+		assert.EqualValues(t, repo.int("NumClosedPulls"), actual,
+			"Unexpected number of closed pulls for repo id: %d", repo.int("ID"))
+
+		actual = GetCountByCond(t, "milestone", builder.Eq{"is_closed": true, "repo_id": repo.int("ID")})
+		assert.EqualValues(t, repo.int("NumClosedMilestones"), actual,
+			"Unexpected number of closed milestones for repo id: %d", repo.int("ID"))
+	}
+
+	// issue: comment count and, for PRs, the shared index with pull_request.
+	checkForIssueConsistency := func(t *testing.T, bean any) {
+		issue := reflectionWrap(bean)
+		typeComment := modelsCommentTypeComment
+		actual := GetCountByCond(t, "comment", builder.Eq{"`type`": typeComment, "issue_id": issue.int("ID")})
+		assert.EqualValues(t, issue.int("NumComments"), actual, "Unexpected number of comments for issue id: %d", issue.int("ID"))
+		if issue.bool("IsPull") {
+			prRow := AssertExistsAndLoadMap(t, "pull_request", builder.Eq{"issue_id": issue.int("ID")})
+			assert.EqualValues(t, parseInt(prRow["index"]), issue.int("Index"), "Unexpected index for issue id: %d", issue.int("ID"))
+		}
+	}
+
+	// pull_request: must point at an issue row flagged as a pull.
+	checkForPullRequestConsistency := func(t *testing.T, bean any) {
+		pr := reflectionWrap(bean)
+		issueRow := AssertExistsAndLoadMap(t, "issue", builder.Eq{"id": pr.int("IssueID")})
+		assert.True(t, parseBool(issueRow["is_pull"]))
+		assert.EqualValues(t, parseInt(issueRow["index"]), pr.int("Index"), "Unexpected index for pull request id: %d", pr.int("ID"))
+	}
+
+	// milestone: issue counters and the derived completeness percentage.
+	checkForMilestoneConsistency := func(t *testing.T, bean any) {
+		milestone := reflectionWrap(bean)
+		AssertCountByCond(t, "issue", builder.Eq{"milestone_id": milestone.int("ID")}, milestone.int("NumIssues"))
+
+		actual := GetCountByCond(t, "issue", builder.Eq{"is_closed": true, "milestone_id": milestone.int("ID")})
+		assert.EqualValues(t, milestone.int("NumClosedIssues"), actual, "Unexpected number of closed issues for milestone id: %d", milestone.int("ID"))
+
+		completeness := 0
+		if milestone.int("NumIssues") > 0 {
+			completeness = milestone.int("NumClosedIssues") * 100 / milestone.int("NumIssues")
+		}
+		assert.Equal(t, completeness, milestone.int("Completeness"))
+	}
+
+	// label: issue counts derived from the issue_label join table.
+	checkForLabelConsistency := func(t *testing.T, bean any) {
+		label := reflectionWrap(bean)
+		issueLabels, err := db.GetEngine(db.DefaultContext).Table("issue_label").
+			Where(builder.Eq{"label_id": label.int("ID")}).
+			Query()
+		require.NoError(t, err)
+
+		assert.Len(t, issueLabels, label.int("NumIssues"), "Unexpected number of issue for label id: %d", label.int("ID"))
+
+		issueIDs := make([]int, len(issueLabels))
+		for i, issueLabel := range issueLabels {
+			issueIDs[i], _ = strconv.Atoi(string(issueLabel["issue_id"]))
+		}
+
+		expected := int64(0)
+		if len(issueIDs) > 0 {
+			expected = GetCountByCond(t, "issue", builder.In("id", issueIDs).And(builder.Eq{"is_closed": true}))
+		}
+		assert.EqualValues(t, expected, label.int("NumClosedIssues"), "Unexpected number of closed issues for label id: %d", label.int("ID"))
+	}
+
+	// team: member and repository counters.
+	checkForTeamConsistency := func(t *testing.T, bean any) {
+		team := reflectionWrap(bean)
+		AssertCountByCond(t, "team_user", builder.Eq{"team_id": team.int("ID")}, team.int("NumMembers"))
+		AssertCountByCond(t, "team_repo", builder.Eq{"team_id": team.int("ID")}, team.int("NumRepos"))
+	}
+
+	// action: privacy flag must mirror the referenced repository.
+	checkForActionConsistency := func(t *testing.T, bean any) {
+		action := reflectionWrap(bean)
+		if action.int("RepoID") != 1700 { // dangling intentional
+			repoRow := AssertExistsAndLoadMap(t, "repository", builder.Eq{"id": action.int("RepoID")})
+			assert.Equal(t, parseBool(repoRow["is_private"]), action.bool("IsPrivate"), "Unexpected is_private field for action id: %d", action.int("ID"))
+		}
+	}
+
+	consistencyCheckMap["user"] = checkForUserConsistency
+	consistencyCheckMap["repository"] = checkForRepoConsistency
+	consistencyCheckMap["issue"] = checkForIssueConsistency
+	consistencyCheckMap["pull_request"] = checkForPullRequestConsistency
+	consistencyCheckMap["milestone"] = checkForMilestoneConsistency
+	consistencyCheckMap["label"] = checkForLabelConsistency
+	consistencyCheckMap["team"] = checkForTeamConsistency
+	consistencyCheckMap["action"] = checkForActionConsistency
+}
diff --git a/models/unittest/fixtures.go b/models/unittest/fixtures.go
new file mode 100644
index 0000000..63b26a0
--- /dev/null
+++ b/models/unittest/fixtures.go
@@ -0,0 +1,144 @@
+// Copyright 2021 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+//nolint:forbidigo
+package unittest
+
+import (
+ "fmt"
+ "os"
+ "path/filepath"
+ "time"
+
+ "code.gitea.io/gitea/models/db"
+ "code.gitea.io/gitea/modules/auth/password/hash"
+ "code.gitea.io/gitea/modules/setting"
+
+ "github.com/go-testfixtures/testfixtures/v3"
+ "xorm.io/xorm"
+ "xorm.io/xorm/schemas"
+)
+
+// fixturesLoader is the active testfixtures loader, configured by
+// InitFixtures and consumed by LoadFixtures / OverrideFixtures.
+var fixturesLoader *testfixtures.Loader
+
+// GetXORMEngine gets the XORM engine
+func GetXORMEngine(engine ...*xorm.Engine) (x *xorm.Engine) {
+	// An explicitly supplied engine wins; otherwise fall back to the
+	// engine behind the default db context.
+	if len(engine) == 1 {
+		return engine[0]
+	}
+	return db.DefaultContext.(*db.Context).Engine().(*xorm.Engine)
+}
+
+// OverrideFixtures replaces the active fixtures loader with one built from
+// opts and returns a function that restores the previous loader (typically
+// deferred by the caller). It panics if the new loader cannot be built.
+func OverrideFixtures(opts FixturesOptions, engine ...*xorm.Engine) func() {
+	old := fixturesLoader
+	if err := InitFixtures(opts, engine...); err != nil {
+		panic(err)
+	}
+	return func() {
+		fixturesLoader = old
+	}
+}
+
+// InitFixtures initialize test fixtures for a test database. The fixture
+// files come either from opts.Dir/opts.Files or from opts.Dirs relative to
+// opts.Base. It also switches the password hasher to the "dummy" algorithm
+// used by the fixture data. Returns an error for unsupported databases.
+func InitFixtures(opts FixturesOptions, engine ...*xorm.Engine) (err error) {
+	e := GetXORMEngine(engine...)
+	var fixtureOptionFiles func(*testfixtures.Loader) error
+	if opts.Dir != "" {
+		fixtureOptionFiles = testfixtures.Directory(opts.Dir)
+	} else {
+		fixtureOptionFiles = testfixtures.Files(opts.Files...)
+	}
+	var fixtureOptionDirs []func(*testfixtures.Loader) error
+	if opts.Dirs != nil {
+		for _, dir := range opts.Dirs {
+			fixtureOptionDirs = append(fixtureOptionDirs, testfixtures.Directory(filepath.Join(opts.Base, dir)))
+		}
+	}
+	var dialect string
+	switch e.Dialect().URI().DBType {
+	case schemas.POSTGRES:
+		dialect = "postgres"
+	case schemas.MYSQL:
+		dialect = "mysql"
+	case schemas.SQLITE:
+		dialect = "sqlite3"
+	default:
+		// Return an error instead of os.Exit(1): exiting from a helper
+		// kills the whole test binary without running deferred teardown,
+		// and callers already check InitFixtures' error.
+		fmt.Fprintln(os.Stderr, "Unsupported RDBMS for integration tests")
+		return fmt.Errorf("unsupported RDBMS %v for integration tests", e.Dialect().URI().DBType)
+	}
+	loaderOptions := []func(loader *testfixtures.Loader) error{
+		testfixtures.Database(e.DB().DB),
+		testfixtures.Dialect(dialect),
+		testfixtures.DangerousSkipTestDatabaseCheck(),
+		fixtureOptionFiles,
+	}
+	loaderOptions = append(loaderOptions, fixtureOptionDirs...)
+
+	// Postgres sequences are reset manually in LoadFixtures instead.
+	if e.Dialect().URI().DBType == schemas.POSTGRES {
+		loaderOptions = append(loaderOptions, testfixtures.SkipResetSequences())
+	}
+
+	fixturesLoader, err = testfixtures.New(loaderOptions...)
+	if err != nil {
+		return err
+	}
+
+	// register the dummy hash algorithm function used in the test fixtures
+	_ = hash.Register("dummy", hash.NewDummyHasher)
+
+	setting.PasswordHashAlgo, _ = hash.SetDefaultPasswordHashAlgorithm("dummy")
+
+	return err
+}
+
+// LoadFixtures load fixtures for a test database
+//
+// It retries the load a few times to ride out transient transaction
+// conflicts, then (on Postgres) regenerates every sequence so that
+// auto-increment columns continue after the highest fixture id.
+func LoadFixtures(engine ...*xorm.Engine) error {
+	e := GetXORMEngine(engine...)
+	var err error
+	// (doubt) database transaction conflicts could occur and result in ROLLBACK? just try for a few times.
+	for i := 0; i < 5; i++ {
+		if err = fixturesLoader.Load(); err == nil {
+			break
+		}
+		time.Sleep(200 * time.Millisecond)
+	}
+	if err != nil {
+		fmt.Printf("LoadFixtures failed after retries: %v\n", err)
+	}
+	// Now if we're running postgres we need to tell it to update the sequences
+	// (InitFixtures passed SkipResetSequences, so it is done by hand here).
+	if e.Dialect().URI().DBType == schemas.POSTGRES {
+		// This query *generates* one `SELECT SETVAL(...)` statement per
+		// sequence; each generated statement is then executed below.
+		results, err := e.QueryString(`SELECT 'SELECT SETVAL(' ||
+		quote_literal(quote_ident(PGT.schemaname) || '.' || quote_ident(S.relname)) ||
+		', COALESCE(MAX(' ||quote_ident(C.attname)|| '), 1) ) FROM ' ||
+		quote_ident(PGT.schemaname)|| '.'||quote_ident(T.relname)|| ';'
+	FROM pg_class AS S,
+	     pg_depend AS D,
+	     pg_class AS T,
+	     pg_attribute AS C,
+	     pg_tables AS PGT
+	WHERE S.relkind = 'S'
+	    AND S.oid = D.objid
+	    AND D.refobjid = T.oid
+	    AND D.refobjid = C.attrelid
+	    AND D.refobjsubid = C.attnum
+	    AND T.relname = PGT.tablename
+	ORDER BY S.relname;`)
+		if err != nil {
+			fmt.Printf("Failed to generate sequence update: %v\n", err)
+			return err
+		}
+		for _, r := range results {
+			for _, value := range r {
+				_, err = e.Exec(value)
+				if err != nil {
+					fmt.Printf("Failed to update sequence: %s Error: %v\n", value, err)
+					return err
+				}
+			}
+		}
+	}
+	// Re-register the dummy hasher: fixtures may be reloaded after settings
+	// were reset by the test under way.
+	_ = hash.Register("dummy", hash.NewDummyHasher)
+	setting.PasswordHashAlgo, _ = hash.SetDefaultPasswordHashAlgorithm("dummy")
+
+	return err
+}
diff --git a/models/unittest/fscopy.go b/models/unittest/fscopy.go
new file mode 100644
index 0000000..74b12d5
--- /dev/null
+++ b/models/unittest/fscopy.go
@@ -0,0 +1,102 @@
+// Copyright 2022 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package unittest
+
+import (
+ "errors"
+ "io"
+ "os"
+ "path"
+ "strings"
+
+ "code.gitea.io/gitea/modules/util"
+)
+
+// Copy copies file from source to target path, preserving the mode and
+// modification time. Symbolic links are recreated (not followed).
+func Copy(src, dest string) error {
+	// Gather file information to set back later.
+	si, err := os.Lstat(src)
+	if err != nil {
+		return err
+	}
+
+	// Handle symbolic link.
+	if si.Mode()&os.ModeSymlink != 0 {
+		target, err := os.Readlink(src)
+		if err != nil {
+			return err
+		}
+		// NOTE: os.Chmod and os.Chtimes don't recognize symbolic link,
+		// which will lead "no such file or directory" error.
+		return os.Symlink(target, dest)
+	}
+
+	sr, err := os.Open(src)
+	if err != nil {
+		return err
+	}
+	defer sr.Close()
+
+	dw, err := os.Create(dest)
+	if err != nil {
+		return err
+	}
+
+	if _, err = io.Copy(dw, sr); err != nil {
+		dw.Close()
+		return err
+	}
+	// Close explicitly and check the error: a deferred Close would silently
+	// drop flush/write errors, and metadata below should be applied to a
+	// fully written file.
+	if err = dw.Close(); err != nil {
+		return err
+	}
+
+	// Set back file information.
+	if err = os.Chtimes(dest, si.ModTime(), si.ModTime()); err != nil {
+		return err
+	}
+	return os.Chmod(dest, si.Mode())
+}
+
+// CopyDir copy files recursively from source to target directory.
+//
+// The filter accepts a function that process the path info.
+// and should return true for need to filter.
+//
+// It returns error when error occurs in underlying functions.
+func CopyDir(srcPath, destPath string, filters ...func(filePath string) bool) error {
+	// Check if target directory exists: refuse to copy over anything.
+	if _, err := os.Stat(destPath); !errors.Is(err, os.ErrNotExist) {
+		return util.NewAlreadyExistErrorf("file or directory already exists: %s", destPath)
+	}
+
+	err := os.MkdirAll(destPath, os.ModePerm)
+	if err != nil {
+		return err
+	}
+
+	// Gather directory info.
+	infos, err := util.StatDir(srcPath, true)
+	if err != nil {
+		return err
+	}
+
+	// Only the first filter is used; extra filters are ignored.
+	var filter func(filePath string) bool
+	if len(filters) > 0 {
+		filter = filters[0]
+	}
+
+	for _, info := range infos {
+		if filter != nil && filter(info) {
+			continue
+		}
+
+		// Entries ending in "/" are directories (per util.StatDir's
+		// convention — TODO confirm); everything else is copied as a file.
+		curPath := path.Join(destPath, info)
+		if strings.HasSuffix(info, "/") {
+			err = os.MkdirAll(curPath, os.ModePerm)
+		} else {
+			err = Copy(path.Join(srcPath, info), curPath)
+		}
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
diff --git a/models/unittest/mock_http.go b/models/unittest/mock_http.go
new file mode 100644
index 0000000..aea2489
--- /dev/null
+++ b/models/unittest/mock_http.go
@@ -0,0 +1,115 @@
+// Copyright 2017 The Forgejo Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package unittest
+
+import (
+ "bufio"
+ "fmt"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "os"
+ "slices"
+ "strings"
+ "testing"
+
+ "code.gitea.io/gitea/modules/log"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// Mocks HTTP responses of a third-party service (such as GitHub, GitLab…)
+// This has two modes:
+// - live mode: the requests made to the mock HTTP server are transmitted to the live
+// service, and responses are saved as test data files
+// - test mode: the responses to requests to the mock HTTP server are read from the
+// test data files
+func NewMockWebServer(t *testing.T, liveServerBaseURL, testDataDir string, liveMode bool) *httptest.Server {
+ mockServerBaseURL := ""
+ ignoredHeaders := []string{"cf-ray", "server", "date", "report-to", "nel", "x-request-id"}
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ path := NormalizedFullPath(r.URL)
+ log.Info("Mock HTTP Server: got request for path %s", r.URL.Path)
+ // TODO check request method (support POST?)
+ fixturePath := fmt.Sprintf("%s/%s_%s", testDataDir, r.Method, url.PathEscape(path))
+ if liveMode {
+ liveURL := fmt.Sprintf("%s%s", liveServerBaseURL, path)
+
+ request, err := http.NewRequest(r.Method, liveURL, nil)
+ require.NoError(t, err, "constructing an HTTP request to %s failed", liveURL)
+ for headerName, headerValues := range r.Header {
+ // do not pass on the encoding: let the Transport of the HTTP client handle that for us
+ if strings.ToLower(headerName) != "accept-encoding" {
+ for _, headerValue := range headerValues {
+ request.Header.Add(headerName, headerValue)
+ }
+ }
+ }
+
+ response, err := http.DefaultClient.Do(request)
+ require.NoError(t, err, "HTTP request to %s failed: %s", liveURL)
+ assert.Less(t, response.StatusCode, 400, "unexpected status code for %s", liveURL)
+
+ fixture, err := os.Create(fixturePath)
+ require.NoError(t, err, "failed to open the fixture file %s for writing", fixturePath)
+ defer fixture.Close()
+ fixtureWriter := bufio.NewWriter(fixture)
+
+ for headerName, headerValues := range response.Header {
+ for _, headerValue := range headerValues {
+ if !slices.Contains(ignoredHeaders, strings.ToLower(headerName)) {
+ _, err := fixtureWriter.WriteString(fmt.Sprintf("%s: %s\n", headerName, headerValue))
+ require.NoError(t, err, "writing the header of the HTTP response to the fixture file failed")
+ }
+ }
+ }
+ _, err = fixtureWriter.WriteString("\n")
+ require.NoError(t, err, "writing the header of the HTTP response to the fixture file failed")
+ fixtureWriter.Flush()
+
+ log.Info("Mock HTTP Server: writing response to %s", fixturePath)
+ _, err = io.Copy(fixture, response.Body)
+ require.NoError(t, err, "writing the body of the HTTP response to %s failed", liveURL)
+
+ err = fixture.Sync()
+ require.NoError(t, err, "writing the body of the HTTP response to the fixture file failed")
+ }
+
+ fixture, err := os.ReadFile(fixturePath)
+ require.NoError(t, err, "missing mock HTTP response: "+fixturePath)
+
+ w.WriteHeader(http.StatusOK)
+
+ // replace any mention of the live HTTP service by the mocked host
+ stringFixture := strings.ReplaceAll(string(fixture), liveServerBaseURL, mockServerBaseURL)
+ // parse back the fixture file into a series of HTTP headers followed by response body
+ lines := strings.Split(stringFixture, "\n")
+ for idx, line := range lines {
+ colonIndex := strings.Index(line, ": ")
+ if colonIndex != -1 {
+ w.Header().Set(line[0:colonIndex], line[colonIndex+2:])
+ } else {
+ // we reached the end of the headers (empty line), so what follows is the body
+ responseBody := strings.Join(lines[idx+1:], "\n")
+ _, err := w.Write([]byte(responseBody))
+ require.NoError(t, err, "writing the body of the HTTP response failed")
+ break
+ }
+ }
+ }))
+ mockServerBaseURL = server.URL
+ return server
+}
+
+func NormalizedFullPath(url *url.URL) string {
+ // TODO normalize path (remove trailing slash?)
+ // TODO normalize RawQuery (order query parameters?)
+ if len(url.Query()) == 0 {
+ return url.EscapedPath()
+ }
+ return fmt.Sprintf("%s?%s", url.EscapedPath(), url.RawQuery)
+}
diff --git a/models/unittest/reflection.go b/models/unittest/reflection.go
new file mode 100644
index 0000000..141fc66
--- /dev/null
+++ b/models/unittest/reflection.go
@@ -0,0 +1,40 @@
+// Copyright 2021 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package unittest
+
+import (
+ "log"
+ "reflect"
+)
+
// fieldByName resolves the struct field named field on v, dereferencing
// a pointer first when needed. It panics when no such field exists.
func fieldByName(v reflect.Value, field string) reflect.Value {
	target := v
	if target.Kind() == reflect.Ptr {
		target = target.Elem()
	}
	result := target.FieldByName(field)
	if !result.IsValid() {
		log.Panicf("can not read %s for %v", field, target)
	}
	return result
}

// reflectionValue provides typed accessors for the fields of a wrapped value.
type reflectionValue struct {
	v reflect.Value
}

// reflectionWrap wraps an arbitrary value for reflective field access.
func reflectionWrap(v any) *reflectionValue {
	return &reflectionValue{v: reflect.ValueOf(v)}
}

// int returns the named field as an int.
func (rv *reflectionValue) int(field string) int {
	f := fieldByName(rv.v, field)
	return int(f.Int())
}

// str returns the named field as a string.
func (rv *reflectionValue) str(field string) string {
	f := fieldByName(rv.v, field)
	return f.String()
}

// bool returns the named field as a bool.
func (rv *reflectionValue) bool(field string) bool {
	f := fieldByName(rv.v, field)
	return f.Bool()
}
diff --git a/models/unittest/testdb.go b/models/unittest/testdb.go
new file mode 100644
index 0000000..94a3253
--- /dev/null
+++ b/models/unittest/testdb.go
@@ -0,0 +1,267 @@
+// Copyright 2021 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package unittest
+
+import (
+ "context"
+ "fmt"
+ "log"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+
+ "code.gitea.io/gitea/models/db"
+ "code.gitea.io/gitea/models/system"
+ "code.gitea.io/gitea/modules/auth/password/hash"
+ "code.gitea.io/gitea/modules/base"
+ "code.gitea.io/gitea/modules/git"
+ "code.gitea.io/gitea/modules/setting"
+ "code.gitea.io/gitea/modules/setting/config"
+ "code.gitea.io/gitea/modules/storage"
+ "code.gitea.io/gitea/modules/util"
+
+ "github.com/stretchr/testify/require"
+ "xorm.io/xorm"
+ "xorm.io/xorm/names"
+)
+
// giteaRoot is the absolute path of the repository checkout and fixturesDir
// the directory containing the YAML fixture files; both are set by MainTest.
var (
	giteaRoot   string
	fixturesDir string
)
+
// FixturesDir returns the fixture directory (models/fixtures under the
// repository root, as computed by MainTest).
func FixturesDir() string {
	return fixturesDir
}
+
// fatalTestError prints a formatted message to stderr and aborts the whole
// test process with a non-zero exit code.
func fatalTestError(fmtStr string, args ...any) {
	fmt.Fprintf(os.Stderr, fmtStr, args...)
	os.Exit(1)
}
+
// InitSettings initializes config provider and load common settings for tests
func InitSettings() {
	if setting.CustomConf == "" {
		// Use a throwaway config file so tests never pick up a developer's
		// local app.ini; remove any leftover from a previous run.
		setting.CustomConf = filepath.Join(setting.CustomPath, "conf/app-unittest-tmp.ini")
		_ = os.Remove(setting.CustomConf)
	}
	setting.InitCfgProvider(setting.CustomConf)
	setting.LoadCommonSettings()

	if err := setting.PrepareAppDataPath(); err != nil {
		log.Fatalf("Can not prepare APP_DATA_PATH: %v", err)
	}
	// register the dummy hash algorithm function used in the test fixtures
	_ = hash.Register("dummy", hash.NewDummyHasher)

	setting.PasswordHashAlgo, _ = hash.SetDefaultPasswordHashAlgorithm("dummy")
}
+
// TestOptions represents test options
type TestOptions struct {
	// FixtureFiles lists fixture file names (relative to the fixture
	// directory) to load; when empty, the whole directory is used.
	FixtureFiles []string
	SetUp func() error // SetUp will be executed before all tests in this package
	TearDown func() error // TearDown will be executed after all tests in this package
}
+
// MainTest a reusable TestMain(..) function for unit tests that need to use a
// test database. Creates the test database, and sets necessary settings.
func MainTest(m *testing.M, testOpts ...*TestOptions) {
	// Walk up from the working directory until the repository root (the
	// directory containing "go.mod") is found.
	searchDir, _ := os.Getwd()
	for searchDir != "" {
		if _, err := os.Stat(filepath.Join(searchDir, "go.mod")); err == nil {
			break // The "go.mod" should be the one for Gitea repository
		}
		if dir := filepath.Dir(searchDir); dir == searchDir {
			searchDir = "" // reaches the root of filesystem
		} else {
			searchDir = dir
		}
	}
	if searchDir == "" {
		panic("The tests should run in a Gitea repository, there should be a 'go.mod' in the root")
	}

	giteaRoot = searchDir
	setting.CustomPath = filepath.Join(giteaRoot, "custom")
	InitSettings()

	fixturesDir = filepath.Join(giteaRoot, "models", "fixtures")
	// Load either the whole fixture directory, or only the files listed in
	// the first TestOptions entry.
	var opts FixturesOptions
	if len(testOpts) == 0 || len(testOpts[0].FixtureFiles) == 0 {
		opts.Dir = fixturesDir
	} else {
		for _, f := range testOpts[0].FixtureFiles {
			if len(f) != 0 {
				opts.Files = append(opts.Files, filepath.Join(fixturesDir, f))
			}
		}
	}

	if err := CreateTestEngine(opts); err != nil {
		fatalTestError("Error creating test engine: %v\n", err)
	}

	// Fixed settings that much of the test code implicitly relies on.
	setting.AppURL = "https://try.gitea.io/"
	setting.RunUser = "runuser"
	setting.SSH.User = "sshuser"
	setting.SSH.BuiltinServerUser = "builtinuser"
	setting.SSH.Port = 3000
	setting.SSH.Domain = "try.gitea.io"
	setting.Database.Type = "sqlite3"
	setting.Repository.DefaultBranch = "master" // many test code still assume that default branch is called "master"
	repoRootPath, err := os.MkdirTemp(os.TempDir(), "repos")
	if err != nil {
		fatalTestError("TempDir: %v\n", err)
	}
	setting.RepoRootPath = repoRootPath
	appDataPath, err := os.MkdirTemp(os.TempDir(), "appdata")
	if err != nil {
		fatalTestError("TempDir: %v\n", err)
	}
	setting.AppDataPath = appDataPath
	setting.AppWorkPath = giteaRoot
	setting.StaticRootPath = giteaRoot
	setting.GravatarSource = "https://secure.gravatar.com/avatar/"

	// Point every storage backend at a directory below the temporary app data path.
	setting.Attachment.Storage.Path = filepath.Join(setting.AppDataPath, "attachments")

	setting.LFS.Storage.Path = filepath.Join(setting.AppDataPath, "lfs")

	setting.Avatar.Storage.Path = filepath.Join(setting.AppDataPath, "avatars")

	setting.RepoAvatar.Storage.Path = filepath.Join(setting.AppDataPath, "repo-avatars")

	setting.RepoArchive.Storage.Path = filepath.Join(setting.AppDataPath, "repo-archive")

	setting.Packages.Storage.Path = filepath.Join(setting.AppDataPath, "packages")

	setting.Actions.LogStorage.Path = filepath.Join(setting.AppDataPath, "actions_log")

	setting.Git.HomePath = filepath.Join(setting.AppDataPath, "home")

	setting.IncomingEmail.ReplyToAddress = "incoming+%{token}@localhost"

	config.SetDynGetter(system.NewDatabaseDynKeyGetter())

	if err = storage.Init(); err != nil {
		fatalTestError("storage.Init: %v\n", err)
	}
	// CopyDir refuses to copy over an existing directory, so remove the
	// freshly created temporary directory before seeding it.
	if err = util.RemoveAll(repoRootPath); err != nil {
		fatalTestError("util.RemoveAll: %v\n", err)
	}
	if err = CopyDir(filepath.Join(giteaRoot, "tests", "gitea-repositories-meta"), setting.RepoRootPath); err != nil {
		fatalTestError("util.CopyDir: %v\n", err)
	}

	if err = git.InitFull(context.Background()); err != nil {
		fatalTestError("git.Init: %v\n", err)
	}
	// Recreate the git directories that are not stored in the meta repositories.
	ownerDirs, err := os.ReadDir(setting.RepoRootPath)
	if err != nil {
		fatalTestError("unable to read the new repo root: %v\n", err)
	}
	for _, ownerDir := range ownerDirs {
		if !ownerDir.Type().IsDir() {
			continue
		}
		repoDirs, err := os.ReadDir(filepath.Join(setting.RepoRootPath, ownerDir.Name()))
		if err != nil {
			fatalTestError("unable to read the new repo root: %v\n", err)
		}
		for _, repoDir := range repoDirs {
			_ = os.MkdirAll(filepath.Join(setting.RepoRootPath, ownerDir.Name(), repoDir.Name(), "objects", "pack"), 0o755)
			_ = os.MkdirAll(filepath.Join(setting.RepoRootPath, ownerDir.Name(), repoDir.Name(), "objects", "info"), 0o755)
			_ = os.MkdirAll(filepath.Join(setting.RepoRootPath, ownerDir.Name(), repoDir.Name(), "refs", "heads"), 0o755)
			// NOTE(review): "tag" (not "tags") — confirm this is intended.
			_ = os.MkdirAll(filepath.Join(setting.RepoRootPath, ownerDir.Name(), repoDir.Name(), "refs", "tag"), 0o755)
		}
	}

	if len(testOpts) > 0 && testOpts[0].SetUp != nil {
		if err := testOpts[0].SetUp(); err != nil {
			fatalTestError("set up failed: %v\n", err)
		}
	}

	exitStatus := m.Run()

	if len(testOpts) > 0 && testOpts[0].TearDown != nil {
		if err := testOpts[0].TearDown(); err != nil {
			fatalTestError("tear down failed: %v\n", err)
		}
	}

	// Clean up the temporary directories before exiting with the test status.
	if err = util.RemoveAll(repoRootPath); err != nil {
		fatalTestError("util.RemoveAll: %v\n", err)
	}
	if err = util.RemoveAll(appDataPath); err != nil {
		fatalTestError("util.RemoveAll: %v\n", err)
	}
	os.Exit(exitStatus)
}
+
// FixturesOptions fixtures needs to be loaded options
type FixturesOptions struct {
	// Dir is the directory to load every fixture file from (used by
	// MainTest when no explicit file list is given).
	Dir string
	// Files is an explicit list of fixture file paths to load.
	Files []string
	// Dirs lists extra fixture directories — presumably resolved against
	// Base by InitFixtures; verify against that implementation.
	Dirs []string
	// Base — see note on Dirs above; semantics defined by InitFixtures.
	Base string
}
+
// CreateTestEngine creates a memory database and loads the fixture data from fixturesDir
func CreateTestEngine(opts FixturesOptions) error {
	// "cache=shared" keeps the in-memory database alive across connections.
	x, err := xorm.NewEngine("sqlite3", "file::memory:?cache=shared&_txlock=immediate")
	if err != nil {
		if strings.Contains(err.Error(), "unknown driver") {
			return fmt.Errorf(`sqlite3 requires: import _ "github.com/mattn/go-sqlite3" or -tags sqlite,sqlite_unlock_notify%s%w`, "\n", err)
		}
		return err
	}
	x.SetMapper(names.GonicMapper{})
	db.SetDefaultEngine(context.Background(), x)

	if err = db.SyncAllTables(); err != nil {
		return err
	}
	// Opt-in SQL statement logging for debugging test queries.
	switch os.Getenv("GITEA_UNIT_TESTS_LOG_SQL") {
	case "true", "1":
		x.ShowSQL(true)
	}

	return InitFixtures(opts)
}
+
// PrepareTestDatabase loads the test fixtures into the test database,
// resetting it to a known state before a test runs.
func PrepareTestDatabase() error {
	return LoadFixtures()
}
+
// PrepareTestEnv prepares the environment for unit tests. Can only be called
// by tests that use the above MainTest(..) function.
func PrepareTestEnv(t testing.TB) {
	require.NoError(t, PrepareTestDatabase())
	// Reset the repository root to a pristine copy of the meta repositories.
	require.NoError(t, util.RemoveAll(setting.RepoRootPath))
	metaPath := filepath.Join(giteaRoot, "tests", "gitea-repositories-meta")
	require.NoError(t, CopyDir(metaPath, setting.RepoRootPath))
	// Recreate the git directories that are not stored in the meta
	// repositories (same layout as in MainTest).
	ownerDirs, err := os.ReadDir(setting.RepoRootPath)
	require.NoError(t, err)
	for _, ownerDir := range ownerDirs {
		if !ownerDir.Type().IsDir() {
			continue
		}
		repoDirs, err := os.ReadDir(filepath.Join(setting.RepoRootPath, ownerDir.Name()))
		require.NoError(t, err)
		for _, repoDir := range repoDirs {
			_ = os.MkdirAll(filepath.Join(setting.RepoRootPath, ownerDir.Name(), repoDir.Name(), "objects", "pack"), 0o755)
			_ = os.MkdirAll(filepath.Join(setting.RepoRootPath, ownerDir.Name(), repoDir.Name(), "objects", "info"), 0o755)
			_ = os.MkdirAll(filepath.Join(setting.RepoRootPath, ownerDir.Name(), repoDir.Name(), "refs", "heads"), 0o755)
			// NOTE(review): "tag" (not "tags") — confirm this is intended.
			_ = os.MkdirAll(filepath.Join(setting.RepoRootPath, ownerDir.Name(), repoDir.Name(), "refs", "tag"), 0o755)
		}
	}

	base.SetupGiteaRoot() // Makes sure GITEA_ROOT is set
}
diff --git a/models/unittest/unit_tests.go b/models/unittest/unit_tests.go
new file mode 100644
index 0000000..157c676
--- /dev/null
+++ b/models/unittest/unit_tests.go
@@ -0,0 +1,164 @@
+// Copyright 2016 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package unittest
+
+import (
+ "math"
+ "testing"
+
+ "code.gitea.io/gitea/models/db"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "xorm.io/builder"
+)
+
+// Code in this file is mainly used by unittest.CheckConsistencyFor, which is not in the unit test for various reasons.
+// In the future if we can decouple CheckConsistencyFor into separate unit test code, then this file can be moved into unittest package too.
+
// NonexistentID is an ID (math.MaxInt64) that will never exist in the
// fixture database.
const NonexistentID = int64(math.MaxInt64)
+
// testCond bundles a query fragment with its bind arguments for a test.
type testCond struct {
	query any
	args  []any
}

// testOrderBy is an "ORDER BY" clause for a test query.
type testOrderBy string

// Cond create a condition with arguments for a test
func Cond(query any, args ...any) any {
	return &testCond{
		query: query,
		args:  args,
	}
}

// OrderBy creates "ORDER BY" a test query
func OrderBy(orderBy string) any {
	return testOrderBy(orderBy)
}
+
+func whereOrderConditions(e db.Engine, conditions []any) db.Engine {
+ orderBy := "id" // query must have the "ORDER BY", otherwise the result is not deterministic
+ for _, condition := range conditions {
+ switch cond := condition.(type) {
+ case *testCond:
+ e = e.Where(cond.query, cond.args...)
+ case testOrderBy:
+ orderBy = string(cond)
+ default:
+ e = e.Where(cond)
+ }
+ }
+ return e.OrderBy(orderBy)
+}
+
+// LoadBeanIfExists loads beans from fixture database if exist
+func LoadBeanIfExists(bean any, conditions ...any) (bool, error) {
+ e := db.GetEngine(db.DefaultContext)
+ return whereOrderConditions(e, conditions).Get(bean)
+}
+
+// BeanExists for testing, check if a bean exists
+func BeanExists(t testing.TB, bean any, conditions ...any) bool {
+ exists, err := LoadBeanIfExists(bean, conditions...)
+ require.NoError(t, err)
+ return exists
+}
+
+// AssertExistsAndLoadBean assert that a bean exists and load it from the test database
+func AssertExistsAndLoadBean[T any](t testing.TB, bean T, conditions ...any) T {
+ exists, err := LoadBeanIfExists(bean, conditions...)
+ require.NoError(t, err)
+ assert.True(t, exists,
+ "Expected to find %+v (of type %T, with conditions %+v), but did not",
+ bean, bean, conditions)
+ return bean
+}
+
+// AssertExistsAndLoadMap assert that a row exists and load it from the test database
+func AssertExistsAndLoadMap(t testing.TB, table string, conditions ...any) map[string]string {
+ e := db.GetEngine(db.DefaultContext).Table(table)
+ res, err := whereOrderConditions(e, conditions).Query()
+ require.NoError(t, err)
+ assert.Len(t, res, 1,
+ "Expected to find one row in %s (with conditions %+v), but found %d",
+ table, conditions, len(res),
+ )
+
+ if len(res) == 1 {
+ rec := map[string]string{}
+ for k, v := range res[0] {
+ rec[k] = string(v)
+ }
+ return rec
+ }
+ return nil
+}
+
+// GetCount get the count of a bean
+func GetCount(t testing.TB, bean any, conditions ...any) int {
+ e := db.GetEngine(db.DefaultContext)
+ for _, condition := range conditions {
+ switch cond := condition.(type) {
+ case *testCond:
+ e = e.Where(cond.query, cond.args...)
+ default:
+ e = e.Where(cond)
+ }
+ }
+ count, err := e.Count(bean)
+ require.NoError(t, err)
+ return int(count)
+}
+
+// AssertNotExistsBean assert that a bean does not exist in the test database
+func AssertNotExistsBean(t testing.TB, bean any, conditions ...any) {
+ exists, err := LoadBeanIfExists(bean, conditions...)
+ require.NoError(t, err)
+ assert.False(t, exists)
+}
+
+// AssertExistsIf asserts that a bean exists or does not exist, depending on
+// what is expected.
+func AssertExistsIf(t testing.TB, expected bool, bean any, conditions ...any) {
+ exists, err := LoadBeanIfExists(bean, conditions...)
+ require.NoError(t, err)
+ assert.Equal(t, expected, exists)
+}
+
+// AssertSuccessfulInsert assert that beans is successfully inserted
+func AssertSuccessfulInsert(t testing.TB, beans ...any) {
+ err := db.Insert(db.DefaultContext, beans...)
+ require.NoError(t, err)
+}
+
+// AssertSuccessfulDelete assert that beans is successfully deleted
+func AssertSuccessfulDelete(t require.TestingT, beans ...any) {
+ err := db.DeleteBeans(db.DefaultContext, beans...)
+ require.NoError(t, err)
+}
+
+// AssertCount assert the count of a bean
+func AssertCount(t testing.TB, bean, expected any) bool {
+ return assert.EqualValues(t, expected, GetCount(t, bean))
+}
+
+// AssertInt64InRange assert value is in range [low, high]
+func AssertInt64InRange(t testing.TB, low, high, value int64) {
+ assert.True(t, value >= low && value <= high,
+ "Expected value in range [%d, %d], found %d", low, high, value)
+}
+
+// GetCountByCond get the count of database entries matching bean
+func GetCountByCond(t testing.TB, tableName string, cond builder.Cond) int64 {
+ e := db.GetEngine(db.DefaultContext)
+ count, err := e.Table(tableName).Where(cond).Count()
+ require.NoError(t, err)
+ return count
+}
+
+// AssertCountByCond test the count of database entries matching bean
+func AssertCountByCond(t testing.TB, tableName string, cond builder.Cond, expected int) bool {
+ return assert.EqualValues(t, expected, GetCountByCond(t, tableName, cond),
+ "Failed consistency test, the counted bean (of table %s) was %+v", tableName, cond)
+}