summaryrefslogtreecommitdiffstats
path: root/models/auth
diff options
context:
space:
mode:
authorDaniel Baumann <daniel@debian.org>2024-10-18 20:33:49 +0200
committerDaniel Baumann <daniel@debian.org>2024-10-18 20:33:49 +0200
commitdd136858f1ea40ad3c94191d647487fa4f31926c (patch)
tree58fec94a7b2a12510c9664b21793f1ed560c6518 /models/auth
parentInitial commit. (diff)
downloadforgejo-dd136858f1ea40ad3c94191d647487fa4f31926c.tar.xz
forgejo-dd136858f1ea40ad3c94191d647487fa4f31926c.zip
Adding upstream version 9.0.0.upstream/9.0.0upstreamdebian
Signed-off-by: Daniel Baumann <daniel@debian.org>
Diffstat (limited to 'models/auth')
-rw-r--r--models/auth/TestOrphanedOAuth2Applications/oauth2_application.yaml25
-rw-r--r--models/auth/access_token.go236
-rw-r--r--models/auth/access_token_scope.go350
-rw-r--r--models/auth/access_token_scope_test.go90
-rw-r--r--models/auth/access_token_test.go133
-rw-r--r--models/auth/auth_token.go96
-rw-r--r--models/auth/main_test.go20
-rw-r--r--models/auth/oauth2.go676
-rw-r--r--models/auth/oauth2_list.go32
-rw-r--r--models/auth/oauth2_test.go299
-rw-r--r--models/auth/session.go120
-rw-r--r--models/auth/session_test.go143
-rw-r--r--models/auth/source.go412
-rw-r--r--models/auth/source_test.go61
-rw-r--r--models/auth/twofactor.go166
-rw-r--r--models/auth/webauthn.go209
-rw-r--r--models/auth/webauthn_test.go78
17 files changed, 3146 insertions, 0 deletions
diff --git a/models/auth/TestOrphanedOAuth2Applications/oauth2_application.yaml b/models/auth/TestOrphanedOAuth2Applications/oauth2_application.yaml
new file mode 100644
index 0000000..b188770
--- /dev/null
+++ b/models/auth/TestOrphanedOAuth2Applications/oauth2_application.yaml
@@ -0,0 +1,25 @@
+-
+ id: 1000
+ uid: 0
+ name: "Git Credential Manager"
+ client_id: "e90ee53c-94e2-48ac-9358-a874fb9e0662"
+ redirect_uris: '["http://127.0.0.1", "https://127.0.0.1"]'
+ created_unix: 1712358091
+ updated_unix: 1712358091
+-
+ id: 1001
+ uid: 0
+ name: "git-credential-oauth"
+ client_id: "a4792ccc-144e-407e-86c9-5e7d8d9c3269"
+ redirect_uris: '["http://127.0.0.1", "https://127.0.0.1"]'
+ created_unix: 1712358091
+ updated_unix: 1712358091
+
+-
+ id: 1002
+ uid: 1234567890
+ name: "Should be removed"
+ client_id: "deadc0de-badd-dd11-fee1-deaddecafbad"
+ redirect_uris: '["http://127.0.0.1", "https://127.0.0.1"]'
+ created_unix: 1712358091
+ updated_unix: 1712358091
diff --git a/models/auth/access_token.go b/models/auth/access_token.go
new file mode 100644
index 0000000..63331b4
--- /dev/null
+++ b/models/auth/access_token.go
@@ -0,0 +1,236 @@
+// Copyright 2014 The Gogs Authors. All rights reserved.
+// Copyright 2019 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package auth
+
+import (
+ "context"
+ "crypto/subtle"
+ "encoding/hex"
+ "fmt"
+ "time"
+
+ "code.gitea.io/gitea/models/db"
+ "code.gitea.io/gitea/modules/setting"
+ "code.gitea.io/gitea/modules/timeutil"
+ "code.gitea.io/gitea/modules/util"
+
+ lru "github.com/hashicorp/golang-lru/v2"
+ "xorm.io/builder"
+)
+
+// ErrAccessTokenNotExist represents a "AccessTokenNotExist" kind of error.
+type ErrAccessTokenNotExist struct {
+ Token string
+}
+
+// IsErrAccessTokenNotExist checks if an error is a ErrAccessTokenNotExist.
+func IsErrAccessTokenNotExist(err error) bool {
+ _, ok := err.(ErrAccessTokenNotExist)
+ return ok
+}
+
+func (err ErrAccessTokenNotExist) Error() string {
+ return fmt.Sprintf("access token does not exist [sha: %s]", err.Token)
+}
+
+func (err ErrAccessTokenNotExist) Unwrap() error {
+ return util.ErrNotExist
+}
+
+// ErrAccessTokenEmpty represents a "AccessTokenEmpty" kind of error.
+type ErrAccessTokenEmpty struct{}
+
+// IsErrAccessTokenEmpty checks if an error is a ErrAccessTokenEmpty.
+func IsErrAccessTokenEmpty(err error) bool {
+ _, ok := err.(ErrAccessTokenEmpty)
+ return ok
+}
+
+func (err ErrAccessTokenEmpty) Error() string {
+ return "access token is empty"
+}
+
+func (err ErrAccessTokenEmpty) Unwrap() error {
+ return util.ErrInvalidArgument
+}
+
+var successfulAccessTokenCache *lru.Cache[string, any]
+
+// AccessToken represents a personal access token.
+type AccessToken struct {
+ ID int64 `xorm:"pk autoincr"`
+ UID int64 `xorm:"INDEX"`
+ Name string
+ Token string `xorm:"-"`
+ TokenHash string `xorm:"UNIQUE"` // sha256 of token
+ TokenSalt string
+ TokenLastEight string `xorm:"INDEX token_last_eight"`
+ Scope AccessTokenScope
+
+ CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
+ UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
+ HasRecentActivity bool `xorm:"-"`
+ HasUsed bool `xorm:"-"`
+}
+
+// AfterLoad is invoked from XORM after setting the values of all fields of this object.
+func (t *AccessToken) AfterLoad() {
+ t.HasUsed = t.UpdatedUnix > t.CreatedUnix
+ t.HasRecentActivity = t.UpdatedUnix.AddDuration(7*24*time.Hour) > timeutil.TimeStampNow()
+}
+
+func init() {
+ db.RegisterModel(new(AccessToken), func() error {
+ if setting.SuccessfulTokensCacheSize > 0 {
+ var err error
+ successfulAccessTokenCache, err = lru.New[string, any](setting.SuccessfulTokensCacheSize)
+ if err != nil {
+ return fmt.Errorf("unable to allocate AccessToken cache: %w", err)
+ }
+ } else {
+ successfulAccessTokenCache = nil
+ }
+ return nil
+ })
+}
+
+// NewAccessToken creates new access token.
+func NewAccessToken(ctx context.Context, t *AccessToken) error {
+ salt, err := util.CryptoRandomString(10)
+ if err != nil {
+ return err
+ }
+ token, err := util.CryptoRandomBytes(20)
+ if err != nil {
+ return err
+ }
+ t.TokenSalt = salt
+ t.Token = hex.EncodeToString(token)
+ t.TokenHash = HashToken(t.Token, t.TokenSalt)
+ t.TokenLastEight = t.Token[len(t.Token)-8:]
+ _, err = db.GetEngine(ctx).Insert(t)
+ return err
+}
+
+// DisplayPublicOnly whether to display this as a public-only token.
+func (t *AccessToken) DisplayPublicOnly() bool {
+ publicOnly, err := t.Scope.PublicOnly()
+ if err != nil {
+ return false
+ }
+ return publicOnly
+}
+
+func getAccessTokenIDFromCache(token string) int64 {
+ if successfulAccessTokenCache == nil {
+ return 0
+ }
+ tInterface, ok := successfulAccessTokenCache.Get(token)
+ if !ok {
+ return 0
+ }
+ t, ok := tInterface.(int64)
+ if !ok {
+ return 0
+ }
+ return t
+}
+
+// GetAccessTokenBySHA returns access token by given token value
+func GetAccessTokenBySHA(ctx context.Context, token string) (*AccessToken, error) {
+ if token == "" {
+ return nil, ErrAccessTokenEmpty{}
+ }
+ // A token is defined as being SHA1 sum these are 40 hexadecimal bytes long
+ if len(token) != 40 {
+ return nil, ErrAccessTokenNotExist{token}
+ }
+ for _, x := range []byte(token) {
+ if x < '0' || (x > '9' && x < 'a') || x > 'f' {
+ return nil, ErrAccessTokenNotExist{token}
+ }
+ }
+
+ lastEight := token[len(token)-8:]
+
+ if id := getAccessTokenIDFromCache(token); id > 0 {
+ accessToken := &AccessToken{
+ TokenLastEight: lastEight,
+ }
+ // Re-get the token from the db in case it has been deleted in the intervening period
+ has, err := db.GetEngine(ctx).ID(id).Get(accessToken)
+ if err != nil {
+ return nil, err
+ }
+ if has {
+ return accessToken, nil
+ }
+ successfulAccessTokenCache.Remove(token)
+ }
+
+ var tokens []AccessToken
+ err := db.GetEngine(ctx).Table(&AccessToken{}).Where("token_last_eight = ?", lastEight).Find(&tokens)
+ if err != nil {
+ return nil, err
+ } else if len(tokens) == 0 {
+ return nil, ErrAccessTokenNotExist{token}
+ }
+
+ for _, t := range tokens {
+ tempHash := HashToken(token, t.TokenSalt)
+ if subtle.ConstantTimeCompare([]byte(t.TokenHash), []byte(tempHash)) == 1 {
+ if successfulAccessTokenCache != nil {
+ successfulAccessTokenCache.Add(token, t.ID)
+ }
+ return &t, nil
+ }
+ }
+ return nil, ErrAccessTokenNotExist{token}
+}
+
+// AccessTokenByNameExists checks if a token name has been used already by a user.
+func AccessTokenByNameExists(ctx context.Context, token *AccessToken) (bool, error) {
+ return db.GetEngine(ctx).Table("access_token").Where("name = ?", token.Name).And("uid = ?", token.UID).Exist()
+}
+
+// ListAccessTokensOptions contain filter options
+type ListAccessTokensOptions struct {
+ db.ListOptions
+ Name string
+ UserID int64
+}
+
+func (opts ListAccessTokensOptions) ToConds() builder.Cond {
+ cond := builder.NewCond()
+ // user id is required, otherwise it will return all result which maybe a possible bug
+ cond = cond.And(builder.Eq{"uid": opts.UserID})
+ if len(opts.Name) > 0 {
+ cond = cond.And(builder.Eq{"name": opts.Name})
+ }
+ return cond
+}
+
+func (opts ListAccessTokensOptions) ToOrders() string {
+ return "created_unix DESC"
+}
+
+// UpdateAccessToken updates information of access token.
+func UpdateAccessToken(ctx context.Context, t *AccessToken) error {
+ _, err := db.GetEngine(ctx).ID(t.ID).AllCols().Update(t)
+ return err
+}
+
+// DeleteAccessTokenByID deletes access token by given ID.
+func DeleteAccessTokenByID(ctx context.Context, id, userID int64) error {
+ cnt, err := db.GetEngine(ctx).ID(id).Delete(&AccessToken{
+ UID: userID,
+ })
+ if err != nil {
+ return err
+ } else if cnt != 1 {
+ return ErrAccessTokenNotExist{}
+ }
+ return nil
+}
diff --git a/models/auth/access_token_scope.go b/models/auth/access_token_scope.go
new file mode 100644
index 0000000..003ca5c
--- /dev/null
+++ b/models/auth/access_token_scope.go
@@ -0,0 +1,350 @@
+// Copyright 2022 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package auth
+
+import (
+ "fmt"
+ "strings"
+
+ "code.gitea.io/gitea/models/perm"
+)
+
+// AccessTokenScopeCategory represents the scope category for an access token
+type AccessTokenScopeCategory int
+
+const (
+ AccessTokenScopeCategoryActivityPub = iota
+ AccessTokenScopeCategoryAdmin
+ AccessTokenScopeCategoryMisc // WARN: this is now just a placeholder, don't remove it which will change the following values
+ AccessTokenScopeCategoryNotification
+ AccessTokenScopeCategoryOrganization
+ AccessTokenScopeCategoryPackage
+ AccessTokenScopeCategoryIssue
+ AccessTokenScopeCategoryRepository
+ AccessTokenScopeCategoryUser
+)
+
+// AllAccessTokenScopeCategories contains all access token scope categories
+var AllAccessTokenScopeCategories = []AccessTokenScopeCategory{
+ AccessTokenScopeCategoryActivityPub,
+ AccessTokenScopeCategoryAdmin,
+ AccessTokenScopeCategoryMisc,
+ AccessTokenScopeCategoryNotification,
+ AccessTokenScopeCategoryOrganization,
+ AccessTokenScopeCategoryPackage,
+ AccessTokenScopeCategoryIssue,
+ AccessTokenScopeCategoryRepository,
+ AccessTokenScopeCategoryUser,
+}
+
+// AccessTokenScopeLevel represents the access levels without a given scope category
+type AccessTokenScopeLevel int
+
+const (
+ NoAccess AccessTokenScopeLevel = iota
+ Read
+ Write
+)
+
+// AccessTokenScope represents the scope for an access token.
+type AccessTokenScope string
+
+// for all categories, write implies read
+const (
+ AccessTokenScopeAll AccessTokenScope = "all"
+ AccessTokenScopePublicOnly AccessTokenScope = "public-only" // limited to public orgs/repos
+
+ AccessTokenScopeReadActivityPub AccessTokenScope = "read:activitypub"
+ AccessTokenScopeWriteActivityPub AccessTokenScope = "write:activitypub"
+
+ AccessTokenScopeReadAdmin AccessTokenScope = "read:admin"
+ AccessTokenScopeWriteAdmin AccessTokenScope = "write:admin"
+
+ AccessTokenScopeReadMisc AccessTokenScope = "read:misc"
+ AccessTokenScopeWriteMisc AccessTokenScope = "write:misc"
+
+ AccessTokenScopeReadNotification AccessTokenScope = "read:notification"
+ AccessTokenScopeWriteNotification AccessTokenScope = "write:notification"
+
+ AccessTokenScopeReadOrganization AccessTokenScope = "read:organization"
+ AccessTokenScopeWriteOrganization AccessTokenScope = "write:organization"
+
+ AccessTokenScopeReadPackage AccessTokenScope = "read:package"
+ AccessTokenScopeWritePackage AccessTokenScope = "write:package"
+
+ AccessTokenScopeReadIssue AccessTokenScope = "read:issue"
+ AccessTokenScopeWriteIssue AccessTokenScope = "write:issue"
+
+ AccessTokenScopeReadRepository AccessTokenScope = "read:repository"
+ AccessTokenScopeWriteRepository AccessTokenScope = "write:repository"
+
+ AccessTokenScopeReadUser AccessTokenScope = "read:user"
+ AccessTokenScopeWriteUser AccessTokenScope = "write:user"
+)
+
+// accessTokenScopeBitmap represents a bitmap of access token scopes.
+type accessTokenScopeBitmap uint64
+
+// Bitmap of each scope, including the child scopes.
+const (
+ // AccessTokenScopeAllBits is the bitmap of all access token scopes
+ accessTokenScopeAllBits accessTokenScopeBitmap = accessTokenScopeWriteActivityPubBits |
+ accessTokenScopeWriteAdminBits | accessTokenScopeWriteMiscBits | accessTokenScopeWriteNotificationBits |
+ accessTokenScopeWriteOrganizationBits | accessTokenScopeWritePackageBits | accessTokenScopeWriteIssueBits |
+ accessTokenScopeWriteRepositoryBits | accessTokenScopeWriteUserBits
+
+ accessTokenScopePublicOnlyBits accessTokenScopeBitmap = 1 << iota
+
+ accessTokenScopeReadActivityPubBits accessTokenScopeBitmap = 1 << iota
+ accessTokenScopeWriteActivityPubBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadActivityPubBits
+
+ accessTokenScopeReadAdminBits accessTokenScopeBitmap = 1 << iota
+ accessTokenScopeWriteAdminBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadAdminBits
+
+ accessTokenScopeReadMiscBits accessTokenScopeBitmap = 1 << iota
+ accessTokenScopeWriteMiscBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadMiscBits
+
+ accessTokenScopeReadNotificationBits accessTokenScopeBitmap = 1 << iota
+ accessTokenScopeWriteNotificationBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadNotificationBits
+
+ accessTokenScopeReadOrganizationBits accessTokenScopeBitmap = 1 << iota
+ accessTokenScopeWriteOrganizationBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadOrganizationBits
+
+ accessTokenScopeReadPackageBits accessTokenScopeBitmap = 1 << iota
+ accessTokenScopeWritePackageBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadPackageBits
+
+ accessTokenScopeReadIssueBits accessTokenScopeBitmap = 1 << iota
+ accessTokenScopeWriteIssueBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadIssueBits
+
+ accessTokenScopeReadRepositoryBits accessTokenScopeBitmap = 1 << iota
+ accessTokenScopeWriteRepositoryBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadRepositoryBits
+
+ accessTokenScopeReadUserBits accessTokenScopeBitmap = 1 << iota
+ accessTokenScopeWriteUserBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadUserBits
+
+ // The current implementation only supports up to 64 token scopes.
+ // If we need to support > 64 scopes,
+ // refactoring the whole implementation in this file (and only this file) is needed.
+)
+
+// allAccessTokenScopes contains all access token scopes.
+// The order is important: parent scope must precede child scopes.
+var allAccessTokenScopes = []AccessTokenScope{
+ AccessTokenScopePublicOnly,
+ AccessTokenScopeWriteActivityPub, AccessTokenScopeReadActivityPub,
+ AccessTokenScopeWriteAdmin, AccessTokenScopeReadAdmin,
+ AccessTokenScopeWriteMisc, AccessTokenScopeReadMisc,
+ AccessTokenScopeWriteNotification, AccessTokenScopeReadNotification,
+ AccessTokenScopeWriteOrganization, AccessTokenScopeReadOrganization,
+ AccessTokenScopeWritePackage, AccessTokenScopeReadPackage,
+ AccessTokenScopeWriteIssue, AccessTokenScopeReadIssue,
+ AccessTokenScopeWriteRepository, AccessTokenScopeReadRepository,
+ AccessTokenScopeWriteUser, AccessTokenScopeReadUser,
+}
+
+// allAccessTokenScopeBits contains all access token scopes.
+var allAccessTokenScopeBits = map[AccessTokenScope]accessTokenScopeBitmap{
+ AccessTokenScopeAll: accessTokenScopeAllBits,
+ AccessTokenScopePublicOnly: accessTokenScopePublicOnlyBits,
+ AccessTokenScopeReadActivityPub: accessTokenScopeReadActivityPubBits,
+ AccessTokenScopeWriteActivityPub: accessTokenScopeWriteActivityPubBits,
+ AccessTokenScopeReadAdmin: accessTokenScopeReadAdminBits,
+ AccessTokenScopeWriteAdmin: accessTokenScopeWriteAdminBits,
+ AccessTokenScopeReadMisc: accessTokenScopeReadMiscBits,
+ AccessTokenScopeWriteMisc: accessTokenScopeWriteMiscBits,
+ AccessTokenScopeReadNotification: accessTokenScopeReadNotificationBits,
+ AccessTokenScopeWriteNotification: accessTokenScopeWriteNotificationBits,
+ AccessTokenScopeReadOrganization: accessTokenScopeReadOrganizationBits,
+ AccessTokenScopeWriteOrganization: accessTokenScopeWriteOrganizationBits,
+ AccessTokenScopeReadPackage: accessTokenScopeReadPackageBits,
+ AccessTokenScopeWritePackage: accessTokenScopeWritePackageBits,
+ AccessTokenScopeReadIssue: accessTokenScopeReadIssueBits,
+ AccessTokenScopeWriteIssue: accessTokenScopeWriteIssueBits,
+ AccessTokenScopeReadRepository: accessTokenScopeReadRepositoryBits,
+ AccessTokenScopeWriteRepository: accessTokenScopeWriteRepositoryBits,
+ AccessTokenScopeReadUser: accessTokenScopeReadUserBits,
+ AccessTokenScopeWriteUser: accessTokenScopeWriteUserBits,
+}
+
+// readAccessTokenScopes maps a scope category to the read permission scope
+var accessTokenScopes = map[AccessTokenScopeLevel]map[AccessTokenScopeCategory]AccessTokenScope{
+ Read: {
+ AccessTokenScopeCategoryActivityPub: AccessTokenScopeReadActivityPub,
+ AccessTokenScopeCategoryAdmin: AccessTokenScopeReadAdmin,
+ AccessTokenScopeCategoryMisc: AccessTokenScopeReadMisc,
+ AccessTokenScopeCategoryNotification: AccessTokenScopeReadNotification,
+ AccessTokenScopeCategoryOrganization: AccessTokenScopeReadOrganization,
+ AccessTokenScopeCategoryPackage: AccessTokenScopeReadPackage,
+ AccessTokenScopeCategoryIssue: AccessTokenScopeReadIssue,
+ AccessTokenScopeCategoryRepository: AccessTokenScopeReadRepository,
+ AccessTokenScopeCategoryUser: AccessTokenScopeReadUser,
+ },
+ Write: {
+ AccessTokenScopeCategoryActivityPub: AccessTokenScopeWriteActivityPub,
+ AccessTokenScopeCategoryAdmin: AccessTokenScopeWriteAdmin,
+ AccessTokenScopeCategoryMisc: AccessTokenScopeWriteMisc,
+ AccessTokenScopeCategoryNotification: AccessTokenScopeWriteNotification,
+ AccessTokenScopeCategoryOrganization: AccessTokenScopeWriteOrganization,
+ AccessTokenScopeCategoryPackage: AccessTokenScopeWritePackage,
+ AccessTokenScopeCategoryIssue: AccessTokenScopeWriteIssue,
+ AccessTokenScopeCategoryRepository: AccessTokenScopeWriteRepository,
+ AccessTokenScopeCategoryUser: AccessTokenScopeWriteUser,
+ },
+}
+
+// GetRequiredScopes gets the specific scopes for a given level and categories
+func GetRequiredScopes(level AccessTokenScopeLevel, scopeCategories ...AccessTokenScopeCategory) []AccessTokenScope {
+ scopes := make([]AccessTokenScope, 0, len(scopeCategories))
+ for _, cat := range scopeCategories {
+ scopes = append(scopes, accessTokenScopes[level][cat])
+ }
+ return scopes
+}
+
+// ContainsCategory checks if a list of categories contains a specific category
+func ContainsCategory(categories []AccessTokenScopeCategory, category AccessTokenScopeCategory) bool {
+ for _, c := range categories {
+ if c == category {
+ return true
+ }
+ }
+ return false
+}
+
+// GetScopeLevelFromAccessMode converts permission access mode to scope level
+func GetScopeLevelFromAccessMode(mode perm.AccessMode) AccessTokenScopeLevel {
+ switch mode {
+ case perm.AccessModeNone:
+ return NoAccess
+ case perm.AccessModeRead:
+ return Read
+ case perm.AccessModeWrite:
+ return Write
+ case perm.AccessModeAdmin:
+ return Write
+ case perm.AccessModeOwner:
+ return Write
+ default:
+ return NoAccess
+ }
+}
+
+// parse the scope string into a bitmap, thus removing possible duplicates.
+func (s AccessTokenScope) parse() (accessTokenScopeBitmap, error) {
+ var bitmap accessTokenScopeBitmap
+
+ // The following is the more performant equivalent of 'for _, v := range strings.Split(remainingScope, ",")' as this is hot code
+ remainingScopes := string(s)
+ for len(remainingScopes) > 0 {
+ i := strings.IndexByte(remainingScopes, ',')
+ var v string
+ if i < 0 {
+ v = remainingScopes
+ remainingScopes = ""
+ } else if i+1 >= len(remainingScopes) {
+ v = remainingScopes[:i]
+ remainingScopes = ""
+ } else {
+ v = remainingScopes[:i]
+ remainingScopes = remainingScopes[i+1:]
+ }
+ singleScope := AccessTokenScope(v)
+ if singleScope == "" || singleScope == "sudo" {
+ continue
+ }
+ if singleScope == AccessTokenScopeAll {
+ bitmap |= accessTokenScopeAllBits
+ continue
+ }
+
+ bits, ok := allAccessTokenScopeBits[singleScope]
+ if !ok {
+ return 0, fmt.Errorf("invalid access token scope: %s", singleScope)
+ }
+ bitmap |= bits
+ }
+
+ return bitmap, nil
+}
+
+// StringSlice returns the AccessTokenScope as a []string
+func (s AccessTokenScope) StringSlice() []string {
+ return strings.Split(string(s), ",")
+}
+
+// Normalize returns a normalized scope string without any duplicates.
+func (s AccessTokenScope) Normalize() (AccessTokenScope, error) {
+ bitmap, err := s.parse()
+ if err != nil {
+ return "", err
+ }
+
+ return bitmap.toScope(), nil
+}
+
+// PublicOnly checks if this token scope is limited to public resources
+func (s AccessTokenScope) PublicOnly() (bool, error) {
+ bitmap, err := s.parse()
+ if err != nil {
+ return false, err
+ }
+
+ return bitmap.hasScope(AccessTokenScopePublicOnly)
+}
+
+// HasScope returns true if the string has the given scope
+func (s AccessTokenScope) HasScope(scopes ...AccessTokenScope) (bool, error) {
+ bitmap, err := s.parse()
+ if err != nil {
+ return false, err
+ }
+
+ for _, s := range scopes {
+ if has, err := bitmap.hasScope(s); !has || err != nil {
+ return has, err
+ }
+ }
+
+ return true, nil
+}
+
+// hasScope returns true if the string has the given scope
+func (bitmap accessTokenScopeBitmap) hasScope(scope AccessTokenScope) (bool, error) {
+ expectedBits, ok := allAccessTokenScopeBits[scope]
+ if !ok {
+ return false, fmt.Errorf("invalid access token scope: %s", scope)
+ }
+
+ return bitmap&expectedBits == expectedBits, nil
+}
+
+// toScope returns a normalized scope string without any duplicates.
+func (bitmap accessTokenScopeBitmap) toScope() AccessTokenScope {
+ var scopes []string
+
+ // iterate over all scopes, and reconstruct the bitmap
+ // if the reconstructed bitmap doesn't change, then the scope is already included
+ var reconstruct accessTokenScopeBitmap
+
+ for _, singleScope := range allAccessTokenScopes {
+ // no need for error checking here, since we know the scope is valid
+ if ok, _ := bitmap.hasScope(singleScope); ok {
+ current := reconstruct | allAccessTokenScopeBits[singleScope]
+ if current == reconstruct {
+ continue
+ }
+
+ reconstruct = current
+ scopes = append(scopes, string(singleScope))
+ }
+ }
+
+ scope := AccessTokenScope(strings.Join(scopes, ","))
+ scope = AccessTokenScope(strings.ReplaceAll(
+ string(scope),
+ "write:activitypub,write:admin,write:misc,write:notification,write:organization,write:package,write:issue,write:repository,write:user",
+ "all",
+ ))
+ return scope
+}
diff --git a/models/auth/access_token_scope_test.go b/models/auth/access_token_scope_test.go
new file mode 100644
index 0000000..d11c5e6
--- /dev/null
+++ b/models/auth/access_token_scope_test.go
@@ -0,0 +1,90 @@
+// Copyright 2022 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package auth
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+type scopeTestNormalize struct {
+ in AccessTokenScope
+ out AccessTokenScope
+ err error
+}
+
+func TestAccessTokenScope_Normalize(t *testing.T) {
+ tests := []scopeTestNormalize{
+ {"", "", nil},
+ {"write:misc,write:notification,read:package,write:notification,public-only", "public-only,write:misc,write:notification,read:package", nil},
+ {"all,sudo", "all", nil},
+ {"write:activitypub,write:admin,write:misc,write:notification,write:organization,write:package,write:issue,write:repository,write:user", "all", nil},
+ {"write:activitypub,write:admin,write:misc,write:notification,write:organization,write:package,write:issue,write:repository,write:user,public-only", "public-only,all", nil},
+ }
+
+ for _, scope := range []string{"activitypub", "admin", "misc", "notification", "organization", "package", "issue", "repository", "user"} {
+ tests = append(tests,
+ scopeTestNormalize{AccessTokenScope(fmt.Sprintf("read:%s", scope)), AccessTokenScope(fmt.Sprintf("read:%s", scope)), nil},
+ scopeTestNormalize{AccessTokenScope(fmt.Sprintf("write:%s", scope)), AccessTokenScope(fmt.Sprintf("write:%s", scope)), nil},
+ scopeTestNormalize{AccessTokenScope(fmt.Sprintf("write:%[1]s,read:%[1]s", scope)), AccessTokenScope(fmt.Sprintf("write:%s", scope)), nil},
+ scopeTestNormalize{AccessTokenScope(fmt.Sprintf("read:%[1]s,write:%[1]s", scope)), AccessTokenScope(fmt.Sprintf("write:%s", scope)), nil},
+ scopeTestNormalize{AccessTokenScope(fmt.Sprintf("read:%[1]s,write:%[1]s,write:%[1]s", scope)), AccessTokenScope(fmt.Sprintf("write:%s", scope)), nil},
+ )
+ }
+
+ for _, test := range tests {
+ t.Run(string(test.in), func(t *testing.T) {
+ scope, err := test.in.Normalize()
+ assert.Equal(t, test.out, scope)
+ assert.Equal(t, test.err, err)
+ })
+ }
+}
+
+type scopeTestHasScope struct {
+ in AccessTokenScope
+ scope AccessTokenScope
+ out bool
+ err error
+}
+
+func TestAccessTokenScope_HasScope(t *testing.T) {
+ tests := []scopeTestHasScope{
+ {"read:admin", "write:package", false, nil},
+ {"all", "write:package", true, nil},
+ {"write:package", "all", false, nil},
+ {"public-only", "read:issue", false, nil},
+ }
+
+ for _, scope := range []string{"activitypub", "admin", "misc", "notification", "organization", "package", "issue", "repository", "user"} {
+ tests = append(tests,
+ scopeTestHasScope{
+ AccessTokenScope(fmt.Sprintf("read:%s", scope)),
+ AccessTokenScope(fmt.Sprintf("read:%s", scope)), true, nil,
+ },
+ scopeTestHasScope{
+ AccessTokenScope(fmt.Sprintf("write:%s", scope)),
+ AccessTokenScope(fmt.Sprintf("write:%s", scope)), true, nil,
+ },
+ scopeTestHasScope{
+ AccessTokenScope(fmt.Sprintf("write:%s", scope)),
+ AccessTokenScope(fmt.Sprintf("read:%s", scope)), true, nil,
+ },
+ scopeTestHasScope{
+ AccessTokenScope(fmt.Sprintf("read:%s", scope)),
+ AccessTokenScope(fmt.Sprintf("write:%s", scope)), false, nil,
+ },
+ )
+ }
+
+ for _, test := range tests {
+ t.Run(string(test.in), func(t *testing.T) {
+ hasScope, err := test.in.HasScope(test.scope)
+ assert.Equal(t, test.out, hasScope)
+ assert.Equal(t, test.err, err)
+ })
+ }
+}
diff --git a/models/auth/access_token_test.go b/models/auth/access_token_test.go
new file mode 100644
index 0000000..e6ea487
--- /dev/null
+++ b/models/auth/access_token_test.go
@@ -0,0 +1,133 @@
+// Copyright 2016 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package auth_test
+
+import (
+ "testing"
+
+ auth_model "code.gitea.io/gitea/models/auth"
+ "code.gitea.io/gitea/models/db"
+ "code.gitea.io/gitea/models/unittest"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewAccessToken(t *testing.T) {
+ require.NoError(t, unittest.PrepareTestDatabase())
+ token := &auth_model.AccessToken{
+ UID: 3,
+ Name: "Token C",
+ }
+ require.NoError(t, auth_model.NewAccessToken(db.DefaultContext, token))
+ unittest.AssertExistsAndLoadBean(t, token)
+
+ invalidToken := &auth_model.AccessToken{
+ ID: token.ID, // duplicate
+ UID: 2,
+ Name: "Token F",
+ }
+ require.Error(t, auth_model.NewAccessToken(db.DefaultContext, invalidToken))
+}
+
+func TestAccessTokenByNameExists(t *testing.T) {
+ name := "Token Gitea"
+
+ require.NoError(t, unittest.PrepareTestDatabase())
+ token := &auth_model.AccessToken{
+ UID: 3,
+ Name: name,
+ }
+
+ // Check to make sure it doesn't exists already
+ exist, err := auth_model.AccessTokenByNameExists(db.DefaultContext, token)
+ require.NoError(t, err)
+ assert.False(t, exist)
+
+ // Save it to the database
+ require.NoError(t, auth_model.NewAccessToken(db.DefaultContext, token))
+ unittest.AssertExistsAndLoadBean(t, token)
+
+ // This token must be found by name in the DB now
+ exist, err = auth_model.AccessTokenByNameExists(db.DefaultContext, token)
+ require.NoError(t, err)
+ assert.True(t, exist)
+
+ user4Token := &auth_model.AccessToken{
+ UID: 4,
+ Name: name,
+ }
+
+ // Name matches but different user ID, this shouldn't exists in the
+ // database
+ exist, err = auth_model.AccessTokenByNameExists(db.DefaultContext, user4Token)
+ require.NoError(t, err)
+ assert.False(t, exist)
+}
+
+func TestGetAccessTokenBySHA(t *testing.T) {
+ require.NoError(t, unittest.PrepareTestDatabase())
+ token, err := auth_model.GetAccessTokenBySHA(db.DefaultContext, "d2c6c1ba3890b309189a8e618c72a162e4efbf36")
+ require.NoError(t, err)
+ assert.Equal(t, int64(1), token.UID)
+ assert.Equal(t, "Token A", token.Name)
+ assert.Equal(t, "2b3668e11cb82d3af8c6e4524fc7841297668f5008d1626f0ad3417e9fa39af84c268248b78c481daa7e5dc437784003494f", token.TokenHash)
+ assert.Equal(t, "e4efbf36", token.TokenLastEight)
+
+ _, err = auth_model.GetAccessTokenBySHA(db.DefaultContext, "notahash")
+ require.Error(t, err)
+ assert.True(t, auth_model.IsErrAccessTokenNotExist(err))
+
+ _, err = auth_model.GetAccessTokenBySHA(db.DefaultContext, "")
+ require.Error(t, err)
+ assert.True(t, auth_model.IsErrAccessTokenEmpty(err))
+}
+
+func TestListAccessTokens(t *testing.T) {
+ require.NoError(t, unittest.PrepareTestDatabase())
+ tokens, err := db.Find[auth_model.AccessToken](db.DefaultContext, auth_model.ListAccessTokensOptions{UserID: 1})
+ require.NoError(t, err)
+ if assert.Len(t, tokens, 2) {
+ assert.Equal(t, int64(1), tokens[0].UID)
+ assert.Equal(t, int64(1), tokens[1].UID)
+ assert.Contains(t, []string{tokens[0].Name, tokens[1].Name}, "Token A")
+ assert.Contains(t, []string{tokens[0].Name, tokens[1].Name}, "Token B")
+ }
+
+ tokens, err = db.Find[auth_model.AccessToken](db.DefaultContext, auth_model.ListAccessTokensOptions{UserID: 2})
+ require.NoError(t, err)
+ if assert.Len(t, tokens, 1) {
+ assert.Equal(t, int64(2), tokens[0].UID)
+ assert.Equal(t, "Token A", tokens[0].Name)
+ }
+
+ tokens, err = db.Find[auth_model.AccessToken](db.DefaultContext, auth_model.ListAccessTokensOptions{UserID: 100})
+ require.NoError(t, err)
+ assert.Empty(t, tokens)
+}
+
+func TestUpdateAccessToken(t *testing.T) {
+ require.NoError(t, unittest.PrepareTestDatabase())
+ token, err := auth_model.GetAccessTokenBySHA(db.DefaultContext, "4c6f36e6cf498e2a448662f915d932c09c5a146c")
+ require.NoError(t, err)
+ token.Name = "Token Z"
+
+ require.NoError(t, auth_model.UpdateAccessToken(db.DefaultContext, token))
+ unittest.AssertExistsAndLoadBean(t, token)
+}
+
+func TestDeleteAccessTokenByID(t *testing.T) {
+ require.NoError(t, unittest.PrepareTestDatabase())
+
+ token, err := auth_model.GetAccessTokenBySHA(db.DefaultContext, "4c6f36e6cf498e2a448662f915d932c09c5a146c")
+ require.NoError(t, err)
+ assert.Equal(t, int64(1), token.UID)
+
+ require.NoError(t, auth_model.DeleteAccessTokenByID(db.DefaultContext, token.ID, 1))
+ unittest.AssertNotExistsBean(t, token)
+
+ err = auth_model.DeleteAccessTokenByID(db.DefaultContext, 100, 100)
+ require.Error(t, err)
+ assert.True(t, auth_model.IsErrAccessTokenNotExist(err))
+}
diff --git a/models/auth/auth_token.go b/models/auth/auth_token.go
new file mode 100644
index 0000000..2c3ca90
--- /dev/null
+++ b/models/auth/auth_token.go
@@ -0,0 +1,96 @@
+// Copyright 2023 The Forgejo Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package auth
+
+import (
+ "context"
+ "crypto/sha256"
+ "encoding/hex"
+ "fmt"
+ "time"
+
+ "code.gitea.io/gitea/models/db"
+ "code.gitea.io/gitea/modules/timeutil"
+ "code.gitea.io/gitea/modules/util"
+)
+
+// AuthorizationToken represents a authorization token to a user.
+type AuthorizationToken struct {
+ ID int64 `xorm:"pk autoincr"`
+ UID int64 `xorm:"INDEX"`
+ LookupKey string `xorm:"INDEX UNIQUE"`
+ HashedValidator string
+ Expiry timeutil.TimeStamp
+}
+
+// TableName provides the real table name.
+func (AuthorizationToken) TableName() string {
+ return "forgejo_auth_token"
+}
+
+func init() {
+ db.RegisterModel(new(AuthorizationToken))
+}
+
+// IsExpired returns if the authorization token is expired.
+func (authToken *AuthorizationToken) IsExpired() bool {
+ return authToken.Expiry.AsLocalTime().Before(time.Now())
+}
+
+// GenerateAuthToken generates a new authentication token for the given user.
+// It returns the lookup key and validator values that should be passed to the
+// user via a long-term cookie.
+func GenerateAuthToken(ctx context.Context, userID int64, expiry timeutil.TimeStamp) (lookupKey, validator string, err error) {
+ // Request 64 random bytes. The first 32 bytes will be used for the lookupKey
+ // and the other 32 bytes will be used for the validator.
+ rBytes, err := util.CryptoRandomBytes(64)
+ if err != nil {
+ return "", "", err
+ }
+ hexEncoded := hex.EncodeToString(rBytes)
+ validator, lookupKey = hexEncoded[64:], hexEncoded[:64]
+
+ _, err = db.GetEngine(ctx).Insert(&AuthorizationToken{
+ UID: userID,
+ Expiry: expiry,
+ LookupKey: lookupKey,
+ HashedValidator: HashValidator(rBytes[32:]),
+ })
+ return lookupKey, validator, err
+}
+
+// FindAuthToken will find a authorization token via the lookup key.
+func FindAuthToken(ctx context.Context, lookupKey string) (*AuthorizationToken, error) {
+ var authToken AuthorizationToken
+ has, err := db.GetEngine(ctx).Where("lookup_key = ?", lookupKey).Get(&authToken)
+ if err != nil {
+ return nil, err
+ } else if !has {
+ return nil, fmt.Errorf("lookup key %q: %w", lookupKey, util.ErrNotExist)
+ }
+ return &authToken, nil
+}
+
+// DeleteAuthToken will delete the authorization token.
+func DeleteAuthToken(ctx context.Context, authToken *AuthorizationToken) error {
+ _, err := db.DeleteByBean(ctx, authToken)
+ return err
+}
+
+// DeleteAuthTokenByUser will delete all authorization tokens for the user.
+func DeleteAuthTokenByUser(ctx context.Context, userID int64) error {
+ if userID == 0 {
+ return nil
+ }
+
+ _, err := db.DeleteByBean(ctx, &AuthorizationToken{UID: userID})
+ return err
+}
+
+// HashValidator will return a hexified hashed version of the validator.
+func HashValidator(validator []byte) string {
+ h := sha256.New()
+ h.Write(validator)
+ return hex.EncodeToString(h.Sum(nil))
+}
diff --git a/models/auth/main_test.go b/models/auth/main_test.go
new file mode 100644
index 0000000..d772ea6
--- /dev/null
+++ b/models/auth/main_test.go
@@ -0,0 +1,20 @@
+// Copyright 2020 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package auth_test
+
+import (
+ "testing"
+
+ "code.gitea.io/gitea/models/unittest"
+
+ _ "code.gitea.io/gitea/models"
+ _ "code.gitea.io/gitea/models/actions"
+ _ "code.gitea.io/gitea/models/activities"
+ _ "code.gitea.io/gitea/models/auth"
+ _ "code.gitea.io/gitea/models/perm/access"
+)
+
+func TestMain(m *testing.M) {
+ unittest.MainTest(m)
+}
diff --git a/models/auth/oauth2.go b/models/auth/oauth2.go
new file mode 100644
index 0000000..125d64b
--- /dev/null
+++ b/models/auth/oauth2.go
@@ -0,0 +1,676 @@
+// Copyright 2019 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package auth
+
+import (
+ "context"
+ "crypto/sha256"
+ "encoding/base32"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "net"
+ "net/url"
+ "strings"
+
+ "code.gitea.io/gitea/models/db"
+ "code.gitea.io/gitea/modules/container"
+ "code.gitea.io/gitea/modules/setting"
+ "code.gitea.io/gitea/modules/timeutil"
+ "code.gitea.io/gitea/modules/util"
+
+ uuid "github.com/google/uuid"
+ "golang.org/x/crypto/bcrypt"
+ "xorm.io/builder"
+ "xorm.io/xorm"
+)
+
+// OAuth2Application represents an OAuth2 client (RFC 6749)
+type OAuth2Application struct {
+ ID int64 `xorm:"pk autoincr"`
+ UID int64 `xorm:"INDEX"`
+ Name string
+ ClientID string `xorm:"unique"`
+ ClientSecret string
+ // OAuth defines both Confidential and Public client types
+ // https://datatracker.ietf.org/doc/html/rfc6749#section-2.1
+ // "Authorization servers MUST record the client type in the client registration details"
+ // https://datatracker.ietf.org/doc/html/rfc8252#section-8.4
+ ConfidentialClient bool `xorm:"NOT NULL DEFAULT TRUE"`
+ RedirectURIs []string `xorm:"redirect_uris JSON TEXT"`
+ CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
+ UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
+}
+
+func init() {
+ db.RegisterModel(new(OAuth2Application))
+ db.RegisterModel(new(OAuth2AuthorizationCode))
+ db.RegisterModel(new(OAuth2Grant))
+}
+
+type BuiltinOAuth2Application struct {
+ ConfigName string
+ DisplayName string
+ RedirectURIs []string
+}
+
+func BuiltinApplications() map[string]*BuiltinOAuth2Application {
+ m := make(map[string]*BuiltinOAuth2Application)
+ m["a4792ccc-144e-407e-86c9-5e7d8d9c3269"] = &BuiltinOAuth2Application{
+ ConfigName: "git-credential-oauth",
+ DisplayName: "git-credential-oauth",
+ RedirectURIs: []string{"http://127.0.0.1", "https://127.0.0.1"},
+ }
+ m["e90ee53c-94e2-48ac-9358-a874fb9e0662"] = &BuiltinOAuth2Application{
+ ConfigName: "git-credential-manager",
+ DisplayName: "Git Credential Manager",
+ RedirectURIs: []string{"http://127.0.0.1", "https://127.0.0.1"},
+ }
+ m["d57cb8c4-630c-4168-8324-ec79935e18d4"] = &BuiltinOAuth2Application{
+ ConfigName: "tea",
+ DisplayName: "tea",
+ RedirectURIs: []string{"http://127.0.0.1", "https://127.0.0.1"},
+ }
+ return m
+}
+
+func BuiltinApplicationsClientIDs() (clientIDs []string) {
+ for clientID := range BuiltinApplications() {
+ clientIDs = append(clientIDs, clientID)
+ }
+ return clientIDs
+}
+
+func Init(ctx context.Context) error {
+ builtinApps := BuiltinApplications()
+ var builtinAllClientIDs []string
+ for clientID := range builtinApps {
+ builtinAllClientIDs = append(builtinAllClientIDs, clientID)
+ }
+
+ var registeredApps []*OAuth2Application
+ if err := db.GetEngine(ctx).In("client_id", builtinAllClientIDs).Find(&registeredApps); err != nil {
+ return err
+ }
+
+ clientIDsToAdd := container.Set[string]{}
+ for _, configName := range setting.OAuth2.DefaultApplications {
+ found := false
+ for clientID, builtinApp := range builtinApps {
+ if builtinApp.ConfigName == configName {
+ clientIDsToAdd.Add(clientID) // add all user-configured apps to the "add" list
+ found = true
+ }
+ }
+ if !found {
+ return fmt.Errorf("unknown oauth2 application: %q", configName)
+ }
+ }
+ clientIDsToDelete := container.Set[string]{}
+ for _, app := range registeredApps {
+ if !clientIDsToAdd.Contains(app.ClientID) {
+ clientIDsToDelete.Add(app.ClientID) // if a registered app is not in the "add" list, it should be deleted
+ }
+ }
+ for _, app := range registeredApps {
+ clientIDsToAdd.Remove(app.ClientID) // no need to re-add existing (registered) apps, so remove them from the set
+ }
+
+ for _, app := range registeredApps {
+ if clientIDsToDelete.Contains(app.ClientID) {
+ if err := deleteOAuth2Application(ctx, app.ID, 0); err != nil {
+ return err
+ }
+ }
+ }
+ for clientID := range clientIDsToAdd {
+ builtinApp := builtinApps[clientID]
+ if err := db.Insert(ctx, &OAuth2Application{
+ Name: builtinApp.DisplayName,
+ ClientID: clientID,
+ RedirectURIs: builtinApp.RedirectURIs,
+ }); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// TableName sets the table name to `oauth2_application`
+func (app *OAuth2Application) TableName() string {
+ return "oauth2_application"
+}
+
+// ContainsRedirectURI checks if redirectURI is allowed for app
+func (app *OAuth2Application) ContainsRedirectURI(redirectURI string) bool {
+ // OAuth2 requires the redirect URI to be an exact match, no dynamic parts are allowed.
+ // https://stackoverflow.com/questions/55524480/should-dynamic-query-parameters-be-present-in-the-redirection-uri-for-an-oauth2
+ // https://www.rfc-editor.org/rfc/rfc6819#section-5.2.3.3
+ // https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
+ // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-security-topics-12#section-3.1
+ contains := func(s string) bool {
+ s = strings.TrimSuffix(strings.ToLower(s), "/")
+ for _, u := range app.RedirectURIs {
+ if strings.TrimSuffix(strings.ToLower(u), "/") == s {
+ return true
+ }
+ }
+ return false
+ }
+ if !app.ConfidentialClient {
+ uri, err := url.Parse(redirectURI)
+ // ignore port for http loopback uris following https://datatracker.ietf.org/doc/html/rfc8252#section-7.3
+ if err == nil && uri.Scheme == "http" && uri.Port() != "" {
+ ip := net.ParseIP(uri.Hostname())
+ if ip != nil && ip.IsLoopback() {
+ // strip port
+ uri.Host = uri.Hostname()
+ if contains(uri.String()) {
+ return true
+ }
+ }
+ }
+ }
+ return contains(redirectURI)
+}
+
+// Base32 characters, but lowercased.
+const lowerBase32Chars = "abcdefghijklmnopqrstuvwxyz234567"
+
+// base32 encoder that uses lowered characters without padding.
+var base32Lower = base32.NewEncoding(lowerBase32Chars).WithPadding(base32.NoPadding)
+
+// GenerateClientSecret will generate the client secret and returns the plaintext and saves the hash at the database
+func (app *OAuth2Application) GenerateClientSecret(ctx context.Context) (string, error) {
+ rBytes, err := util.CryptoRandomBytes(32)
+ if err != nil {
+ return "", err
+ }
+ // Add a prefix to the base32, this is in order to make it easier
+ // for code scanners to grab sensitive tokens.
+ clientSecret := "gto_" + base32Lower.EncodeToString(rBytes)
+
+ hashedSecret, err := bcrypt.GenerateFromPassword([]byte(clientSecret), bcrypt.DefaultCost)
+ if err != nil {
+ return "", err
+ }
+ app.ClientSecret = string(hashedSecret)
+ if _, err := db.GetEngine(ctx).ID(app.ID).Cols("client_secret").Update(app); err != nil {
+ return "", err
+ }
+ return clientSecret, nil
+}
+
+// ValidateClientSecret validates the given secret by the hash saved in database
+func (app *OAuth2Application) ValidateClientSecret(secret []byte) bool {
+ return bcrypt.CompareHashAndPassword([]byte(app.ClientSecret), secret) == nil
+}
+
+// GetGrantByUserID returns a OAuth2Grant by its user and application ID
+func (app *OAuth2Application) GetGrantByUserID(ctx context.Context, userID int64) (grant *OAuth2Grant, err error) {
+ grant = new(OAuth2Grant)
+ if has, err := db.GetEngine(ctx).Where("user_id = ? AND application_id = ?", userID, app.ID).Get(grant); err != nil {
+ return nil, err
+ } else if !has {
+ return nil, nil
+ }
+ return grant, nil
+}
+
+// CreateGrant generates a grant for an user
+func (app *OAuth2Application) CreateGrant(ctx context.Context, userID int64, scope string) (*OAuth2Grant, error) {
+ grant := &OAuth2Grant{
+ ApplicationID: app.ID,
+ UserID: userID,
+ Scope: scope,
+ }
+ err := db.Insert(ctx, grant)
+ if err != nil {
+ return nil, err
+ }
+ return grant, nil
+}
+
+// GetOAuth2ApplicationByClientID returns the oauth2 application with the given client_id. Returns an error if not found.
+func GetOAuth2ApplicationByClientID(ctx context.Context, clientID string) (app *OAuth2Application, err error) {
+ app = new(OAuth2Application)
+ has, err := db.GetEngine(ctx).Where("client_id = ?", clientID).Get(app)
+ if !has {
+ return nil, ErrOAuthClientIDInvalid{ClientID: clientID}
+ }
+ return app, err
+}
+
+// GetOAuth2ApplicationByID returns the oauth2 application with the given id. Returns an error if not found.
+func GetOAuth2ApplicationByID(ctx context.Context, id int64) (app *OAuth2Application, err error) {
+ app = new(OAuth2Application)
+ has, err := db.GetEngine(ctx).ID(id).Get(app)
+ if err != nil {
+ return nil, err
+ }
+ if !has {
+ return nil, ErrOAuthApplicationNotFound{ID: id}
+ }
+ return app, nil
+}
+
+// CreateOAuth2ApplicationOptions holds options to create an oauth2 application
+type CreateOAuth2ApplicationOptions struct {
+ Name string
+ UserID int64
+ ConfidentialClient bool
+ RedirectURIs []string
+}
+
+// CreateOAuth2Application inserts a new oauth2 application
+func CreateOAuth2Application(ctx context.Context, opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) {
+ clientID := uuid.New().String()
+ app := &OAuth2Application{
+ UID: opts.UserID,
+ Name: opts.Name,
+ ClientID: clientID,
+ RedirectURIs: opts.RedirectURIs,
+ ConfidentialClient: opts.ConfidentialClient,
+ }
+ if err := db.Insert(ctx, app); err != nil {
+ return nil, err
+ }
+ return app, nil
+}
+
+// UpdateOAuth2ApplicationOptions holds options to update an oauth2 application
+type UpdateOAuth2ApplicationOptions struct {
+ ID int64
+ Name string
+ UserID int64
+ ConfidentialClient bool
+ RedirectURIs []string
+}
+
+// UpdateOAuth2Application updates an oauth2 application
+func UpdateOAuth2Application(ctx context.Context, opts UpdateOAuth2ApplicationOptions) (*OAuth2Application, error) {
+ ctx, committer, err := db.TxContext(ctx)
+ if err != nil {
+ return nil, err
+ }
+ defer committer.Close()
+
+ app, err := GetOAuth2ApplicationByID(ctx, opts.ID)
+ if err != nil {
+ return nil, err
+ }
+ if app.UID != opts.UserID {
+ return nil, errors.New("UID mismatch")
+ }
+ builtinApps := BuiltinApplications()
+ if _, builtin := builtinApps[app.ClientID]; builtin {
+ return nil, fmt.Errorf("failed to edit OAuth2 application: application is locked: %s", app.ClientID)
+ }
+
+ app.Name = opts.Name
+ app.RedirectURIs = opts.RedirectURIs
+ app.ConfidentialClient = opts.ConfidentialClient
+
+ if err = updateOAuth2Application(ctx, app); err != nil {
+ return nil, err
+ }
+ app.ClientSecret = ""
+
+ return app, committer.Commit()
+}
+
+func updateOAuth2Application(ctx context.Context, app *OAuth2Application) error {
+ if _, err := db.GetEngine(ctx).ID(app.ID).UseBool("confidential_client").Update(app); err != nil {
+ return err
+ }
+ return nil
+}
+
+func deleteOAuth2Application(ctx context.Context, id, userid int64) error {
+ sess := db.GetEngine(ctx)
+ // the userid could be 0 if the app is instance-wide
+ if deleted, err := sess.Where(builder.Eq{"id": id, "uid": userid}).Delete(&OAuth2Application{}); err != nil {
+ return err
+ } else if deleted == 0 {
+ return ErrOAuthApplicationNotFound{ID: id}
+ }
+ codes := make([]*OAuth2AuthorizationCode, 0)
+ // delete correlating auth codes
+ if err := sess.Join("INNER", "oauth2_grant",
+ "oauth2_authorization_code.grant_id = oauth2_grant.id AND oauth2_grant.application_id = ?", id).Find(&codes); err != nil {
+ return err
+ }
+ codeIDs := make([]int64, 0, len(codes))
+ for _, grant := range codes {
+ codeIDs = append(codeIDs, grant.ID)
+ }
+
+ if _, err := sess.In("id", codeIDs).Delete(new(OAuth2AuthorizationCode)); err != nil {
+ return err
+ }
+
+ if _, err := sess.Where("application_id = ?", id).Delete(new(OAuth2Grant)); err != nil {
+ return err
+ }
+ return nil
+}
+
+// DeleteOAuth2Application deletes the application with the given id and the grants and auth codes related to it. It checks if the userid was the creator of the app.
+func DeleteOAuth2Application(ctx context.Context, id, userid int64) error {
+ ctx, committer, err := db.TxContext(ctx)
+ if err != nil {
+ return err
+ }
+ defer committer.Close()
+ app, err := GetOAuth2ApplicationByID(ctx, id)
+ if err != nil {
+ return err
+ }
+ builtinApps := BuiltinApplications()
+ if _, builtin := builtinApps[app.ClientID]; builtin {
+ return fmt.Errorf("failed to delete OAuth2 application: application is locked: %s", app.ClientID)
+ }
+ if err := deleteOAuth2Application(ctx, id, userid); err != nil {
+ return err
+ }
+ return committer.Commit()
+}
+
+//////////////////////////////////////////////////////
+
+// OAuth2AuthorizationCode is a code to obtain an access token in combination with the client secret once. It has a limited lifetime.
+type OAuth2AuthorizationCode struct {
+ ID int64 `xorm:"pk autoincr"`
+ Grant *OAuth2Grant `xorm:"-"`
+ GrantID int64
+ Code string `xorm:"INDEX unique"`
+ CodeChallenge string
+ CodeChallengeMethod string
+ RedirectURI string
+ ValidUntil timeutil.TimeStamp `xorm:"index"`
+}
+
+// TableName sets the table name to `oauth2_authorization_code`
+func (code *OAuth2AuthorizationCode) TableName() string {
+ return "oauth2_authorization_code"
+}
+
+// GenerateRedirectURI generates a redirect URI for a successful authorization request. State will be used if not empty.
+func (code *OAuth2AuthorizationCode) GenerateRedirectURI(state string) (*url.URL, error) {
+ redirect, err := url.Parse(code.RedirectURI)
+ if err != nil {
+ return nil, err
+ }
+ q := redirect.Query()
+ if state != "" {
+ q.Set("state", state)
+ }
+ q.Set("code", code.Code)
+ redirect.RawQuery = q.Encode()
+ return redirect, err
+}
+
+// Invalidate deletes the auth code from the database to invalidate this code
+func (code *OAuth2AuthorizationCode) Invalidate(ctx context.Context) error {
+ _, err := db.GetEngine(ctx).ID(code.ID).NoAutoCondition().Delete(code)
+ return err
+}
+
+// ValidateCodeChallenge validates the given verifier against the saved code challenge. This is part of the PKCE implementation.
+func (code *OAuth2AuthorizationCode) ValidateCodeChallenge(verifier string) bool {
+ switch code.CodeChallengeMethod {
+ case "S256":
+ // base64url(SHA256(verifier)) see https://tools.ietf.org/html/rfc7636#section-4.6
+ h := sha256.Sum256([]byte(verifier))
+ hashedVerifier := base64.RawURLEncoding.EncodeToString(h[:])
+ return hashedVerifier == code.CodeChallenge
+ case "plain":
+ return verifier == code.CodeChallenge
+ case "":
+ return true
+ default:
+ // unsupported method -> return false
+ return false
+ }
+}
+
+// GetOAuth2AuthorizationByCode returns an authorization by its code
+func GetOAuth2AuthorizationByCode(ctx context.Context, code string) (auth *OAuth2AuthorizationCode, err error) {
+ auth = new(OAuth2AuthorizationCode)
+ if has, err := db.GetEngine(ctx).Where("code = ?", code).Get(auth); err != nil {
+ return nil, err
+ } else if !has {
+ return nil, nil
+ }
+ auth.Grant = new(OAuth2Grant)
+ if has, err := db.GetEngine(ctx).ID(auth.GrantID).Get(auth.Grant); err != nil {
+ return nil, err
+ } else if !has {
+ return nil, nil
+ }
+ return auth, nil
+}
+
+//////////////////////////////////////////////////////
+
+// OAuth2Grant represents the permission of an user for a specific application to access resources
+type OAuth2Grant struct {
+ ID int64 `xorm:"pk autoincr"`
+ UserID int64 `xorm:"INDEX unique(user_application)"`
+ Application *OAuth2Application `xorm:"-"`
+ ApplicationID int64 `xorm:"INDEX unique(user_application)"`
+ Counter int64 `xorm:"NOT NULL DEFAULT 1"`
+ Scope string `xorm:"TEXT"`
+ Nonce string `xorm:"TEXT"`
+ CreatedUnix timeutil.TimeStamp `xorm:"created"`
+ UpdatedUnix timeutil.TimeStamp `xorm:"updated"`
+}
+
+// TableName sets the table name to `oauth2_grant`
+func (grant *OAuth2Grant) TableName() string {
+ return "oauth2_grant"
+}
+
+// GenerateNewAuthorizationCode generates a new authorization code for a grant and saves it to the database
+func (grant *OAuth2Grant) GenerateNewAuthorizationCode(ctx context.Context, redirectURI, codeChallenge, codeChallengeMethod string) (code *OAuth2AuthorizationCode, err error) {
+ rBytes, err := util.CryptoRandomBytes(32)
+ if err != nil {
+ return &OAuth2AuthorizationCode{}, err
+ }
+ // Add a prefix to the base32, this is in order to make it easier
+ // for code scanners to grab sensitive tokens.
+ codeSecret := "gta_" + base32Lower.EncodeToString(rBytes)
+
+ code = &OAuth2AuthorizationCode{
+ Grant: grant,
+ GrantID: grant.ID,
+ RedirectURI: redirectURI,
+ Code: codeSecret,
+ CodeChallenge: codeChallenge,
+ CodeChallengeMethod: codeChallengeMethod,
+ }
+ if err := db.Insert(ctx, code); err != nil {
+ return nil, err
+ }
+ return code, nil
+}
+
+// IncreaseCounter increases the counter and updates the grant
+func (grant *OAuth2Grant) IncreaseCounter(ctx context.Context) error {
+ _, err := db.GetEngine(ctx).ID(grant.ID).Incr("counter").Update(new(OAuth2Grant))
+ if err != nil {
+ return err
+ }
+ updatedGrant, err := GetOAuth2GrantByID(ctx, grant.ID)
+ if err != nil {
+ return err
+ }
+ grant.Counter = updatedGrant.Counter
+ return nil
+}
+
+// ScopeContains returns true if the grant scope contains the specified scope
+func (grant *OAuth2Grant) ScopeContains(scope string) bool {
+ for _, currentScope := range strings.Split(grant.Scope, " ") {
+ if scope == currentScope {
+ return true
+ }
+ }
+ return false
+}
+
+// SetNonce updates the current nonce value of a grant
+func (grant *OAuth2Grant) SetNonce(ctx context.Context, nonce string) error {
+ grant.Nonce = nonce
+ _, err := db.GetEngine(ctx).ID(grant.ID).Cols("nonce").Update(grant)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+// GetOAuth2GrantByID returns the grant with the given ID
+func GetOAuth2GrantByID(ctx context.Context, id int64) (grant *OAuth2Grant, err error) {
+ grant = new(OAuth2Grant)
+ if has, err := db.GetEngine(ctx).ID(id).Get(grant); err != nil {
+ return nil, err
+ } else if !has {
+ return nil, nil
+ }
+ return grant, err
+}
+
+// GetOAuth2GrantsByUserID lists all grants of a certain user
+func GetOAuth2GrantsByUserID(ctx context.Context, uid int64) ([]*OAuth2Grant, error) {
+ type joinedOAuth2Grant struct {
+ Grant *OAuth2Grant `xorm:"extends"`
+ Application *OAuth2Application `xorm:"extends"`
+ }
+ var results *xorm.Rows
+ var err error
+ if results, err = db.GetEngine(ctx).
+ Table("oauth2_grant").
+ Where("user_id = ?", uid).
+ Join("INNER", "oauth2_application", "application_id = oauth2_application.id").
+ Rows(new(joinedOAuth2Grant)); err != nil {
+ return nil, err
+ }
+ defer results.Close()
+ grants := make([]*OAuth2Grant, 0)
+ for results.Next() {
+ joinedGrant := new(joinedOAuth2Grant)
+ if err := results.Scan(joinedGrant); err != nil {
+ return nil, err
+ }
+ joinedGrant.Grant.Application = joinedGrant.Application
+ grants = append(grants, joinedGrant.Grant)
+ }
+ return grants, nil
+}
+
+// RevokeOAuth2Grant deletes the grant with grantID and userID
+func RevokeOAuth2Grant(ctx context.Context, grantID, userID int64) error {
+ _, err := db.GetEngine(ctx).Where(builder.Eq{"id": grantID, "user_id": userID}).Delete(&OAuth2Grant{})
+ return err
+}
+
+// ErrOAuthClientIDInvalid will be thrown if client id cannot be found
+type ErrOAuthClientIDInvalid struct {
+ ClientID string
+}
+
+// IsErrOauthClientIDInvalid checks if an error is a ErrOAuthClientIDInvalid.
+func IsErrOauthClientIDInvalid(err error) bool {
+ _, ok := err.(ErrOAuthClientIDInvalid)
+ return ok
+}
+
+// Error returns the error message
+func (err ErrOAuthClientIDInvalid) Error() string {
+ return fmt.Sprintf("Client ID invalid [Client ID: %s]", err.ClientID)
+}
+
+// Unwrap unwraps this as a ErrNotExist err
+func (err ErrOAuthClientIDInvalid) Unwrap() error {
+ return util.ErrNotExist
+}
+
+// ErrOAuthApplicationNotFound will be thrown if id cannot be found
+type ErrOAuthApplicationNotFound struct {
+ ID int64
+}
+
+// IsErrOAuthApplicationNotFound checks if an error is a ErrReviewNotExist.
+func IsErrOAuthApplicationNotFound(err error) bool {
+ _, ok := err.(ErrOAuthApplicationNotFound)
+ return ok
+}
+
+// Error returns the error message
+func (err ErrOAuthApplicationNotFound) Error() string {
+ return fmt.Sprintf("OAuth application not found [ID: %d]", err.ID)
+}
+
+// Unwrap unwraps this as a ErrNotExist err
+func (err ErrOAuthApplicationNotFound) Unwrap() error {
+ return util.ErrNotExist
+}
+
+// GetActiveOAuth2SourceByName returns a OAuth2 AuthSource based on the given name
+func GetActiveOAuth2SourceByName(ctx context.Context, name string) (*Source, error) {
+ authSource := new(Source)
+ has, err := db.GetEngine(ctx).Where("name = ? and type = ? and is_active = ?", name, OAuth2, true).Get(authSource)
+ if err != nil {
+ return nil, err
+ }
+
+ if !has {
+ return nil, fmt.Errorf("oauth2 source not found, name: %q", name)
+ }
+
+ return authSource, nil
+}
+
+func DeleteOAuth2RelictsByUserID(ctx context.Context, userID int64) error {
+ deleteCond := builder.Select("id").From("oauth2_grant").Where(builder.Eq{"oauth2_grant.user_id": userID})
+
+ if _, err := db.GetEngine(ctx).In("grant_id", deleteCond).
+ Delete(&OAuth2AuthorizationCode{}); err != nil {
+ return err
+ }
+
+ if err := db.DeleteBeans(ctx,
+ &OAuth2Application{UID: userID},
+ &OAuth2Grant{UserID: userID},
+ ); err != nil {
+ return fmt.Errorf("DeleteBeans: %w", err)
+ }
+
+ return nil
+}
+
+// CountOrphanedOAuth2Applications returns the amount of orphaned OAuth2 applications.
+func CountOrphanedOAuth2Applications(ctx context.Context) (int64, error) {
+ return db.GetEngine(ctx).
+ Table("`oauth2_application`").
+ Join("LEFT", "`user`", "`oauth2_application`.`uid` = `user`.`id`").
+ Where(builder.IsNull{"`user`.id"}).
+ Where(builder.NotIn("`oauth2_application`.`client_id`", BuiltinApplicationsClientIDs())).
+ Select("COUNT(`oauth2_application`.`id`)").
+ Count()
+}
+
+// DeleteOrphanedOAuth2Applications deletes orphaned OAuth2 applications.
+func DeleteOrphanedOAuth2Applications(ctx context.Context) (int64, error) {
+ subQuery := builder.Select("`oauth2_application`.id").
+ From("`oauth2_application`").
+ Join("LEFT", "`user`", "`oauth2_application`.`uid` = `user`.`id`").
+ Where(builder.IsNull{"`user`.id"}).
+ Where(builder.NotIn("`oauth2_application`.`client_id`", BuiltinApplicationsClientIDs()))
+
+ b := builder.Delete(builder.In("id", subQuery)).From("`oauth2_application`")
+ _, err := db.GetEngine(ctx).Exec(b)
+ return -1, err
+}
diff --git a/models/auth/oauth2_list.go b/models/auth/oauth2_list.go
new file mode 100644
index 0000000..c55f10b
--- /dev/null
+++ b/models/auth/oauth2_list.go
@@ -0,0 +1,32 @@
+// Copyright 2023 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package auth
+
+import (
+ "code.gitea.io/gitea/models/db"
+
+ "xorm.io/builder"
+)
+
+type FindOAuth2ApplicationsOptions struct {
+ db.ListOptions
+ // OwnerID is the user id or org id of the owner of the application
+ OwnerID int64
+ // find global applications, if true, then OwnerID will be igonred
+ IsGlobal bool
+}
+
+func (opts FindOAuth2ApplicationsOptions) ToConds() builder.Cond {
+ conds := builder.NewCond()
+ if opts.IsGlobal {
+ conds = conds.And(builder.Eq{"uid": 0})
+ } else if opts.OwnerID != 0 {
+ conds = conds.And(builder.Eq{"uid": opts.OwnerID})
+ }
+ return conds
+}
+
+func (opts FindOAuth2ApplicationsOptions) ToOrders() string {
+ return "id DESC"
+}
diff --git a/models/auth/oauth2_test.go b/models/auth/oauth2_test.go
new file mode 100644
index 0000000..94b506e
--- /dev/null
+++ b/models/auth/oauth2_test.go
@@ -0,0 +1,299 @@
+// Copyright 2019 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package auth_test
+
+import (
+ "path/filepath"
+ "slices"
+ "testing"
+
+ auth_model "code.gitea.io/gitea/models/auth"
+ "code.gitea.io/gitea/models/db"
+ "code.gitea.io/gitea/models/unittest"
+ "code.gitea.io/gitea/modules/setting"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestOAuth2Application_GenerateClientSecret(t *testing.T) {
+ require.NoError(t, unittest.PrepareTestDatabase())
+ app := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{ID: 1})
+ secret, err := app.GenerateClientSecret(db.DefaultContext)
+ require.NoError(t, err)
+ assert.NotEmpty(t, secret)
+ unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{ID: 1, ClientSecret: app.ClientSecret})
+}
+
+func BenchmarkOAuth2Application_GenerateClientSecret(b *testing.B) {
+ require.NoError(b, unittest.PrepareTestDatabase())
+ app := unittest.AssertExistsAndLoadBean(b, &auth_model.OAuth2Application{ID: 1})
+ for i := 0; i < b.N; i++ {
+ _, _ = app.GenerateClientSecret(db.DefaultContext)
+ }
+}
+
+func TestOAuth2Application_ContainsRedirectURI(t *testing.T) {
+ app := &auth_model.OAuth2Application{
+ RedirectURIs: []string{"a", "b", "c"},
+ }
+ assert.True(t, app.ContainsRedirectURI("a"))
+ assert.True(t, app.ContainsRedirectURI("b"))
+ assert.True(t, app.ContainsRedirectURI("c"))
+ assert.False(t, app.ContainsRedirectURI("d"))
+}
+
+func TestOAuth2Application_ContainsRedirectURI_WithPort(t *testing.T) {
+ app := &auth_model.OAuth2Application{
+ RedirectURIs: []string{"http://127.0.0.1/", "http://::1/", "http://192.168.0.1/", "http://intranet/", "https://127.0.0.1/"},
+ ConfidentialClient: false,
+ }
+
+ // http loopback uris should ignore port
+ // https://datatracker.ietf.org/doc/html/rfc8252#section-7.3
+ assert.True(t, app.ContainsRedirectURI("http://127.0.0.1:3456/"))
+ assert.True(t, app.ContainsRedirectURI("http://127.0.0.1/"))
+ assert.True(t, app.ContainsRedirectURI("http://[::1]:3456/"))
+
+ // not http
+ assert.False(t, app.ContainsRedirectURI("https://127.0.0.1:3456/"))
+ // not loopback
+ assert.False(t, app.ContainsRedirectURI("http://192.168.0.1:9954/"))
+ assert.False(t, app.ContainsRedirectURI("http://intranet:3456/"))
+ // unparsable
+ assert.False(t, app.ContainsRedirectURI(":"))
+}
+
+func TestOAuth2Application_ContainsRedirect_Slash(t *testing.T) {
+ app := &auth_model.OAuth2Application{RedirectURIs: []string{"http://127.0.0.1"}}
+ assert.True(t, app.ContainsRedirectURI("http://127.0.0.1"))
+ assert.True(t, app.ContainsRedirectURI("http://127.0.0.1/"))
+ assert.False(t, app.ContainsRedirectURI("http://127.0.0.1/other"))
+
+ app = &auth_model.OAuth2Application{RedirectURIs: []string{"http://127.0.0.1/"}}
+ assert.True(t, app.ContainsRedirectURI("http://127.0.0.1"))
+ assert.True(t, app.ContainsRedirectURI("http://127.0.0.1/"))
+ assert.False(t, app.ContainsRedirectURI("http://127.0.0.1/other"))
+}
+
+func TestOAuth2Application_ValidateClientSecret(t *testing.T) {
+ require.NoError(t, unittest.PrepareTestDatabase())
+ app := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{ID: 1})
+ secret, err := app.GenerateClientSecret(db.DefaultContext)
+ require.NoError(t, err)
+ assert.True(t, app.ValidateClientSecret([]byte(secret)))
+ assert.False(t, app.ValidateClientSecret([]byte("fewijfowejgfiowjeoifew")))
+}
+
+func TestGetOAuth2ApplicationByClientID(t *testing.T) {
+ require.NoError(t, unittest.PrepareTestDatabase())
+ app, err := auth_model.GetOAuth2ApplicationByClientID(db.DefaultContext, "da7da3ba-9a13-4167-856f-3899de0b0138")
+ require.NoError(t, err)
+ assert.Equal(t, "da7da3ba-9a13-4167-856f-3899de0b0138", app.ClientID)
+
+ app, err = auth_model.GetOAuth2ApplicationByClientID(db.DefaultContext, "invalid client id")
+ require.Error(t, err)
+ assert.Nil(t, app)
+}
+
+func TestCreateOAuth2Application(t *testing.T) {
+ require.NoError(t, unittest.PrepareTestDatabase())
+ app, err := auth_model.CreateOAuth2Application(db.DefaultContext, auth_model.CreateOAuth2ApplicationOptions{Name: "newapp", UserID: 1})
+ require.NoError(t, err)
+ assert.Equal(t, "newapp", app.Name)
+ assert.Len(t, app.ClientID, 36)
+ unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{Name: "newapp"})
+}
+
+func TestOAuth2Application_TableName(t *testing.T) {
+ assert.Equal(t, "oauth2_application", new(auth_model.OAuth2Application).TableName())
+}
+
+func TestOAuth2Application_GetGrantByUserID(t *testing.T) {
+ require.NoError(t, unittest.PrepareTestDatabase())
+ app := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{ID: 1})
+ grant, err := app.GetGrantByUserID(db.DefaultContext, 1)
+ require.NoError(t, err)
+ assert.Equal(t, int64(1), grant.UserID)
+
+ grant, err = app.GetGrantByUserID(db.DefaultContext, 34923458)
+ require.NoError(t, err)
+ assert.Nil(t, grant)
+}
+
+func TestOAuth2Application_CreateGrant(t *testing.T) {
+ require.NoError(t, unittest.PrepareTestDatabase())
+ app := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{ID: 1})
+ grant, err := app.CreateGrant(db.DefaultContext, 2, "")
+ require.NoError(t, err)
+ assert.NotNil(t, grant)
+ assert.Equal(t, int64(2), grant.UserID)
+ assert.Equal(t, int64(1), grant.ApplicationID)
+ assert.Equal(t, "", grant.Scope)
+}
+
+//////////////////// Grant
+
+func TestGetOAuth2GrantByID(t *testing.T) {
+ require.NoError(t, unittest.PrepareTestDatabase())
+ grant, err := auth_model.GetOAuth2GrantByID(db.DefaultContext, 1)
+ require.NoError(t, err)
+ assert.Equal(t, int64(1), grant.ID)
+
+ grant, err = auth_model.GetOAuth2GrantByID(db.DefaultContext, 34923458)
+ require.NoError(t, err)
+ assert.Nil(t, grant)
+}
+
+func TestOAuth2Grant_IncreaseCounter(t *testing.T) {
+ require.NoError(t, unittest.PrepareTestDatabase())
+ grant := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Grant{ID: 1, Counter: 1})
+ require.NoError(t, grant.IncreaseCounter(db.DefaultContext))
+ assert.Equal(t, int64(2), grant.Counter)
+ unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Grant{ID: 1, Counter: 2})
+}
+
+func TestOAuth2Grant_ScopeContains(t *testing.T) {
+ require.NoError(t, unittest.PrepareTestDatabase())
+ grant := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Grant{ID: 1, Scope: "openid profile"})
+ assert.True(t, grant.ScopeContains("openid"))
+ assert.True(t, grant.ScopeContains("profile"))
+ assert.False(t, grant.ScopeContains("profil"))
+ assert.False(t, grant.ScopeContains("profile2"))
+}
+
+func TestOAuth2Grant_GenerateNewAuthorizationCode(t *testing.T) {
+ require.NoError(t, unittest.PrepareTestDatabase())
+ grant := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Grant{ID: 1})
+ code, err := grant.GenerateNewAuthorizationCode(db.DefaultContext, "https://example2.com/callback", "CjvyTLSdR47G5zYenDA-eDWW4lRrO8yvjcWwbD_deOg", "S256")
+ require.NoError(t, err)
+ assert.NotNil(t, code)
+ assert.Greater(t, len(code.Code), 32) // secret length > 32
+}
+
+func TestOAuth2Grant_TableName(t *testing.T) {
+ assert.Equal(t, "oauth2_grant", new(auth_model.OAuth2Grant).TableName())
+}
+
+func TestGetOAuth2GrantsByUserID(t *testing.T) {
+ require.NoError(t, unittest.PrepareTestDatabase())
+ result, err := auth_model.GetOAuth2GrantsByUserID(db.DefaultContext, 1)
+ require.NoError(t, err)
+ assert.Len(t, result, 1)
+ assert.Equal(t, int64(1), result[0].ID)
+ assert.Equal(t, result[0].ApplicationID, result[0].Application.ID)
+
+ result, err = auth_model.GetOAuth2GrantsByUserID(db.DefaultContext, 34134)
+ require.NoError(t, err)
+ assert.Empty(t, result)
+}
+
+func TestRevokeOAuth2Grant(t *testing.T) {
+ require.NoError(t, unittest.PrepareTestDatabase())
+ require.NoError(t, auth_model.RevokeOAuth2Grant(db.DefaultContext, 1, 1))
+ unittest.AssertNotExistsBean(t, &auth_model.OAuth2Grant{ID: 1, UserID: 1})
+}
+
+//////////////////// Authorization Code
+
+func TestGetOAuth2AuthorizationByCode(t *testing.T) {
+ require.NoError(t, unittest.PrepareTestDatabase())
+ code, err := auth_model.GetOAuth2AuthorizationByCode(db.DefaultContext, "authcode")
+ require.NoError(t, err)
+ assert.NotNil(t, code)
+ assert.Equal(t, "authcode", code.Code)
+ assert.Equal(t, int64(1), code.ID)
+
+ code, err = auth_model.GetOAuth2AuthorizationByCode(db.DefaultContext, "does not exist")
+ require.NoError(t, err)
+ assert.Nil(t, code)
+}
+
+func TestOAuth2AuthorizationCode_ValidateCodeChallenge(t *testing.T) {
+ // test plain
+ code := &auth_model.OAuth2AuthorizationCode{
+ CodeChallengeMethod: "plain",
+ CodeChallenge: "test123",
+ }
+ assert.True(t, code.ValidateCodeChallenge("test123"))
+ assert.False(t, code.ValidateCodeChallenge("ierwgjoergjio"))
+
+ // test S256
+ code = &auth_model.OAuth2AuthorizationCode{
+ CodeChallengeMethod: "S256",
+ CodeChallenge: "CjvyTLSdR47G5zYenDA-eDWW4lRrO8yvjcWwbD_deOg",
+ }
+ assert.True(t, code.ValidateCodeChallenge("N1Zo9-8Rfwhkt68r1r29ty8YwIraXR8eh_1Qwxg7yQXsonBt"))
+ assert.False(t, code.ValidateCodeChallenge("wiogjerogorewngoenrgoiuenorg"))
+
+ // test unknown
+ code = &auth_model.OAuth2AuthorizationCode{
+ CodeChallengeMethod: "monkey",
+ CodeChallenge: "foiwgjioriogeiogjerger",
+ }
+ assert.False(t, code.ValidateCodeChallenge("foiwgjioriogeiogjerger"))
+
+ // test no code challenge
+ code = &auth_model.OAuth2AuthorizationCode{
+ CodeChallengeMethod: "",
+ CodeChallenge: "foierjiogerogerg",
+ }
+ assert.True(t, code.ValidateCodeChallenge(""))
+}
+
+func TestOAuth2AuthorizationCode_GenerateRedirectURI(t *testing.T) {
+ code := &auth_model.OAuth2AuthorizationCode{
+ RedirectURI: "https://example.com/callback",
+ Code: "thecode",
+ }
+
+ redirect, err := code.GenerateRedirectURI("thestate")
+ require.NoError(t, err)
+ assert.Equal(t, "https://example.com/callback?code=thecode&state=thestate", redirect.String())
+
+ redirect, err = code.GenerateRedirectURI("")
+ require.NoError(t, err)
+ assert.Equal(t, "https://example.com/callback?code=thecode", redirect.String())
+}
+
+func TestOAuth2AuthorizationCode_Invalidate(t *testing.T) {
+ require.NoError(t, unittest.PrepareTestDatabase())
+ code := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2AuthorizationCode{Code: "authcode"})
+ require.NoError(t, code.Invalidate(db.DefaultContext))
+ unittest.AssertNotExistsBean(t, &auth_model.OAuth2AuthorizationCode{Code: "authcode"})
+}
+
+func TestOAuth2AuthorizationCode_TableName(t *testing.T) {
+ assert.Equal(t, "oauth2_authorization_code", new(auth_model.OAuth2AuthorizationCode).TableName())
+}
+
+func TestBuiltinApplicationsClientIDs(t *testing.T) {
+ clientIDs := auth_model.BuiltinApplicationsClientIDs()
+ slices.Sort(clientIDs)
+ assert.EqualValues(t, []string{"a4792ccc-144e-407e-86c9-5e7d8d9c3269", "d57cb8c4-630c-4168-8324-ec79935e18d4", "e90ee53c-94e2-48ac-9358-a874fb9e0662"}, clientIDs)
+}
+
+func TestOrphanedOAuth2Applications(t *testing.T) {
+ defer unittest.OverrideFixtures(
+ unittest.FixturesOptions{
+ Dir: filepath.Join(setting.AppWorkPath, "models/fixtures/"),
+ Base: setting.AppWorkPath,
+ Dirs: []string{"models/auth/TestOrphanedOAuth2Applications/"},
+ },
+ )()
+ require.NoError(t, unittest.PrepareTestDatabase())
+
+ count, err := auth_model.CountOrphanedOAuth2Applications(db.DefaultContext)
+ require.NoError(t, err)
+ assert.EqualValues(t, 1, count)
+ unittest.AssertExistsIf(t, true, &auth_model.OAuth2Application{ID: 1002})
+
+ _, err = auth_model.DeleteOrphanedOAuth2Applications(db.DefaultContext)
+ require.NoError(t, err)
+
+ count, err = auth_model.CountOrphanedOAuth2Applications(db.DefaultContext)
+ require.NoError(t, err)
+ assert.EqualValues(t, 0, count)
+ unittest.AssertExistsIf(t, false, &auth_model.OAuth2Application{ID: 1002})
+}
diff --git a/models/auth/session.go b/models/auth/session.go
new file mode 100644
index 0000000..75a205f
--- /dev/null
+++ b/models/auth/session.go
@@ -0,0 +1,120 @@
+// Copyright 2020 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package auth
+
+import (
+ "context"
+ "fmt"
+
+ "code.gitea.io/gitea/models/db"
+ "code.gitea.io/gitea/modules/timeutil"
+
+ "xorm.io/builder"
+)
+
+// Session represents a session compatible for go-chi session
+type Session struct {
+ Key string `xorm:"pk CHAR(16)"` // has to be Key to match with go-chi/session
+ Data []byte `xorm:"BLOB"` // on MySQL this has a maximum size of 64Kb - this may need to be increased
+ Expiry timeutil.TimeStamp // has to be Expiry to match with go-chi/session
+}
+
+func init() {
+ db.RegisterModel(new(Session))
+}
+
+// UpdateSession updates the session with provided id
+func UpdateSession(ctx context.Context, key string, data []byte) error {
+ _, err := db.GetEngine(ctx).ID(key).Update(&Session{
+ Data: data,
+ Expiry: timeutil.TimeStampNow(),
+ })
+ return err
+}
+
+// ReadSession reads the data for the provided session
+func ReadSession(ctx context.Context, key string) (*Session, error) {
+ ctx, committer, err := db.TxContext(ctx)
+ if err != nil {
+ return nil, err
+ }
+ defer committer.Close()
+
+ session, exist, err := db.Get[Session](ctx, builder.Eq{"`key`": key})
+ if err != nil {
+ return nil, err
+ } else if !exist {
+ session = &Session{
+ Key: key,
+ Expiry: timeutil.TimeStampNow(),
+ }
+ if err := db.Insert(ctx, session); err != nil {
+ return nil, err
+ }
+ }
+
+ return session, committer.Commit()
+}
+
+// ExistSession checks if a session exists
+func ExistSession(ctx context.Context, key string) (bool, error) {
+ return db.Exist[Session](ctx, builder.Eq{"`key`": key})
+}
+
+// DestroySession destroys a session
+func DestroySession(ctx context.Context, key string) error {
+ _, err := db.GetEngine(ctx).Delete(&Session{
+ Key: key,
+ })
+ return err
+}
+
+// RegenerateSession regenerates a session from the old id
+func RegenerateSession(ctx context.Context, oldKey, newKey string) (*Session, error) {
+ ctx, committer, err := db.TxContext(ctx)
+ if err != nil {
+ return nil, err
+ }
+ defer committer.Close()
+
+ if has, err := db.Exist[Session](ctx, builder.Eq{"`key`": newKey}); err != nil {
+ return nil, err
+ } else if has {
+ return nil, fmt.Errorf("session Key: %s already exists", newKey)
+ }
+
+ if has, err := db.Exist[Session](ctx, builder.Eq{"`key`": oldKey}); err != nil {
+ return nil, err
+ } else if !has {
+ if err := db.Insert(ctx, &Session{
+ Key: oldKey,
+ Expiry: timeutil.TimeStampNow(),
+ }); err != nil {
+ return nil, err
+ }
+ }
+
+ if _, err := db.Exec(ctx, "UPDATE "+db.TableName(&Session{})+" SET `key` = ? WHERE `key`=?", newKey, oldKey); err != nil {
+ return nil, err
+ }
+
+ s, _, err := db.Get[Session](ctx, builder.Eq{"`key`": newKey})
+ if err != nil {
+ // is not exist, it should be impossible
+ return nil, err
+ }
+
+ return s, committer.Commit()
+}
+
+// CountSessions returns the number of sessions
+func CountSessions(ctx context.Context) (int64, error) {
+ return db.GetEngine(ctx).Count(&Session{})
+}
+
+// CleanupSessions cleans up expired sessions
+func CleanupSessions(ctx context.Context, maxLifetime int64) error {
+ _, err := db.GetEngine(ctx).Where("expiry <= ?", timeutil.TimeStampNow().Add(-maxLifetime)).Delete(&Session{})
+ return err
+}
diff --git a/models/auth/session_test.go b/models/auth/session_test.go
new file mode 100644
index 0000000..3b57239
--- /dev/null
+++ b/models/auth/session_test.go
@@ -0,0 +1,143 @@
+// Copyright 2023 The Forgejo Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package auth_test
+
+import (
+ "testing"
+ "time"
+
+ "code.gitea.io/gitea/models/auth"
+ "code.gitea.io/gitea/models/db"
+ "code.gitea.io/gitea/models/unittest"
+ "code.gitea.io/gitea/modules/timeutil"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestAuthSession(t *testing.T) {
+ require.NoError(t, unittest.PrepareTestDatabase())
+ defer timeutil.MockUnset()
+
+ key := "I-Like-Free-Software"
+
+ t.Run("Create Session", func(t *testing.T) {
+ // Ensure it doesn't exist.
+ ok, err := auth.ExistSession(db.DefaultContext, key)
+ require.NoError(t, err)
+ assert.False(t, ok)
+
+ preCount, err := auth.CountSessions(db.DefaultContext)
+ require.NoError(t, err)
+
+ now := time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)
+ timeutil.MockSet(now)
+
+ // New session is created.
+ sess, err := auth.ReadSession(db.DefaultContext, key)
+ require.NoError(t, err)
+ assert.EqualValues(t, key, sess.Key)
+ assert.Empty(t, sess.Data)
+ assert.EqualValues(t, now.Unix(), sess.Expiry)
+
+ // Ensure it exists.
+ ok, err = auth.ExistSession(db.DefaultContext, key)
+ require.NoError(t, err)
+ assert.True(t, ok)
+
+ // Ensure the session is taken into account for count..
+ postCount, err := auth.CountSessions(db.DefaultContext)
+ require.NoError(t, err)
+ assert.Greater(t, postCount, preCount)
+ })
+
+ t.Run("Update session", func(t *testing.T) {
+ data := []byte{0xba, 0xdd, 0xc0, 0xde}
+ now := time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC)
+ timeutil.MockSet(now)
+
+ // Update session.
+ err := auth.UpdateSession(db.DefaultContext, key, data)
+ require.NoError(t, err)
+
+ timeutil.MockSet(time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC))
+
+ // Read updated session.
+ // Ensure data is updated and expiry is set from the update session call.
+ sess, err := auth.ReadSession(db.DefaultContext, key)
+ require.NoError(t, err)
+ assert.EqualValues(t, key, sess.Key)
+ assert.EqualValues(t, data, sess.Data)
+ assert.EqualValues(t, now.Unix(), sess.Expiry)
+
+ timeutil.MockSet(now)
+ })
+
+ t.Run("Delete session", func(t *testing.T) {
+ // Ensure it't exist.
+ ok, err := auth.ExistSession(db.DefaultContext, key)
+ require.NoError(t, err)
+ assert.True(t, ok)
+
+ preCount, err := auth.CountSessions(db.DefaultContext)
+ require.NoError(t, err)
+
+ err = auth.DestroySession(db.DefaultContext, key)
+ require.NoError(t, err)
+
+ // Ensure it doesn't exists.
+ ok, err = auth.ExistSession(db.DefaultContext, key)
+ require.NoError(t, err)
+ assert.False(t, ok)
+
+ // Ensure the session is taken into account for count..
+ postCount, err := auth.CountSessions(db.DefaultContext)
+ require.NoError(t, err)
+ assert.Less(t, postCount, preCount)
+ })
+
+ t.Run("Cleanup sessions", func(t *testing.T) {
+ timeutil.MockSet(time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC))
+
+ _, err := auth.ReadSession(db.DefaultContext, "sess-1")
+ require.NoError(t, err)
+
+ // One minute later.
+ timeutil.MockSet(time.Date(2023, 1, 1, 0, 1, 0, 0, time.UTC))
+ _, err = auth.ReadSession(db.DefaultContext, "sess-2")
+ require.NoError(t, err)
+
+ // 5 minutes, shouldn't clean up anything.
+ err = auth.CleanupSessions(db.DefaultContext, 5*60)
+ require.NoError(t, err)
+
+ ok, err := auth.ExistSession(db.DefaultContext, "sess-1")
+ require.NoError(t, err)
+ assert.True(t, ok)
+
+ ok, err = auth.ExistSession(db.DefaultContext, "sess-2")
+ require.NoError(t, err)
+ assert.True(t, ok)
+
+ // 1 minute, should clean up sess-1.
+ err = auth.CleanupSessions(db.DefaultContext, 60)
+ require.NoError(t, err)
+
+ ok, err = auth.ExistSession(db.DefaultContext, "sess-1")
+ require.NoError(t, err)
+ assert.False(t, ok)
+
+ ok, err = auth.ExistSession(db.DefaultContext, "sess-2")
+ require.NoError(t, err)
+ assert.True(t, ok)
+
+ // Now, should clean up sess-2.
+ err = auth.CleanupSessions(db.DefaultContext, 0)
+ require.NoError(t, err)
+
+ ok, err = auth.ExistSession(db.DefaultContext, "sess-2")
+ require.NoError(t, err)
+ assert.False(t, ok)
+ })
+}
diff --git a/models/auth/source.go b/models/auth/source.go
new file mode 100644
index 0000000..8f7c2a8
--- /dev/null
+++ b/models/auth/source.go
@@ -0,0 +1,412 @@
+// Copyright 2014 The Gogs Authors. All rights reserved.
+// Copyright 2019 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package auth
+
+import (
+ "context"
+ "fmt"
+ "reflect"
+
+ "code.gitea.io/gitea/models/db"
+ "code.gitea.io/gitea/modules/log"
+ "code.gitea.io/gitea/modules/optional"
+ "code.gitea.io/gitea/modules/timeutil"
+ "code.gitea.io/gitea/modules/util"
+
+ "xorm.io/builder"
+ "xorm.io/xorm"
+ "xorm.io/xorm/convert"
+)
+
+// Type represents an login type.
+type Type int
+
+// Note: new type must append to the end of list to maintain compatibility.
+const (
+ NoType Type = iota
+ Plain // 1
+ LDAP // 2
+ SMTP // 3
+ PAM // 4
+ DLDAP // 5
+ OAuth2 // 6
+ SSPI // 7
+ Remote // 8
+)
+
+// String returns the string name of the LoginType
+func (typ Type) String() string {
+ return Names[typ]
+}
+
+// Int returns the int value of the LoginType
+func (typ Type) Int() int {
+ return int(typ)
+}
+
+// Names contains the name of LoginType values.
+var Names = map[Type]string{
+ LDAP: "LDAP (via BindDN)",
+ DLDAP: "LDAP (simple auth)", // Via direct bind
+ SMTP: "SMTP",
+ PAM: "PAM",
+ OAuth2: "OAuth2",
+ SSPI: "SPNEGO with SSPI",
+ Remote: "Remote",
+}
+
+// Config represents login config as far as the db is concerned
+type Config interface {
+ convert.Conversion
+}
+
+// SkipVerifiable configurations provide a IsSkipVerify to check if SkipVerify is set
+type SkipVerifiable interface {
+ IsSkipVerify() bool
+}
+
+// HasTLSer configurations provide a HasTLS to check if TLS can be enabled
+type HasTLSer interface {
+ HasTLS() bool
+}
+
+// UseTLSer configurations provide a HasTLS to check if TLS is enabled
+type UseTLSer interface {
+ UseTLS() bool
+}
+
+// SSHKeyProvider configurations provide ProvidesSSHKeys to check if they provide SSHKeys
+type SSHKeyProvider interface {
+ ProvidesSSHKeys() bool
+}
+
+// RegisterableSource configurations provide RegisterSource which needs to be run on creation
+type RegisterableSource interface {
+ RegisterSource() error
+ UnregisterSource() error
+}
+
+var registeredConfigs = map[Type]func() Config{}
+
+// RegisterTypeConfig register a config for a provided type
+func RegisterTypeConfig(typ Type, exemplar Config) {
+ if reflect.TypeOf(exemplar).Kind() == reflect.Ptr {
+ // Pointer:
+ registeredConfigs[typ] = func() Config {
+ return reflect.New(reflect.ValueOf(exemplar).Elem().Type()).Interface().(Config)
+ }
+ return
+ }
+
+ // Not a Pointer
+ registeredConfigs[typ] = func() Config {
+ return reflect.New(reflect.TypeOf(exemplar)).Elem().Interface().(Config)
+ }
+}
+
+// SourceSettable configurations can have their authSource set on them
+type SourceSettable interface {
+ SetAuthSource(*Source)
+}
+
+// Source represents an external way for authorizing users.
+type Source struct {
+ ID int64 `xorm:"pk autoincr"`
+ Type Type
+ Name string `xorm:"UNIQUE"`
+ IsActive bool `xorm:"INDEX NOT NULL DEFAULT false"`
+ IsSyncEnabled bool `xorm:"INDEX NOT NULL DEFAULT false"`
+ Cfg convert.Conversion `xorm:"TEXT"`
+
+ CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
+ UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
+}
+
+// TableName xorm will read the table name from this method
+func (Source) TableName() string {
+ return "login_source"
+}
+
+func init() {
+ db.RegisterModel(new(Source))
+}
+
+// BeforeSet is invoked from XORM before setting the value of a field of this object.
+func (source *Source) BeforeSet(colName string, val xorm.Cell) {
+ if colName == "type" {
+ typ := Type(db.Cell2Int64(val))
+ constructor, ok := registeredConfigs[typ]
+ if !ok {
+ return
+ }
+ source.Cfg = constructor()
+ if settable, ok := source.Cfg.(SourceSettable); ok {
+ settable.SetAuthSource(source)
+ }
+ }
+}
+
+// TypeName return name of this login source type.
+func (source *Source) TypeName() string {
+ return Names[source.Type]
+}
+
+// IsLDAP returns true of this source is of the LDAP type.
+func (source *Source) IsLDAP() bool {
+ return source.Type == LDAP
+}
+
+// IsDLDAP returns true of this source is of the DLDAP type.
+func (source *Source) IsDLDAP() bool {
+ return source.Type == DLDAP
+}
+
+// IsSMTP returns true of this source is of the SMTP type.
+func (source *Source) IsSMTP() bool {
+ return source.Type == SMTP
+}
+
+// IsPAM returns true of this source is of the PAM type.
+func (source *Source) IsPAM() bool {
+ return source.Type == PAM
+}
+
+// IsOAuth2 returns true of this source is of the OAuth2 type.
+func (source *Source) IsOAuth2() bool {
+ return source.Type == OAuth2
+}
+
+// IsSSPI returns true of this source is of the SSPI type.
+func (source *Source) IsSSPI() bool {
+ return source.Type == SSPI
+}
+
+func (source *Source) IsRemote() bool {
+ return source.Type == Remote
+}
+
+// HasTLS returns true of this source supports TLS.
+func (source *Source) HasTLS() bool {
+ hasTLSer, ok := source.Cfg.(HasTLSer)
+ return ok && hasTLSer.HasTLS()
+}
+
+// UseTLS returns true of this source is configured to use TLS.
+func (source *Source) UseTLS() bool {
+ useTLSer, ok := source.Cfg.(UseTLSer)
+ return ok && useTLSer.UseTLS()
+}
+
+// SkipVerify returns true if this source is configured to skip SSL
+// verification.
+func (source *Source) SkipVerify() bool {
+ skipVerifiable, ok := source.Cfg.(SkipVerifiable)
+ return ok && skipVerifiable.IsSkipVerify()
+}
+
+// CreateSource inserts a AuthSource in the DB if not already
+// existing with the given name.
+func CreateSource(ctx context.Context, source *Source) error {
+ has, err := db.GetEngine(ctx).Where("name=?", source.Name).Exist(new(Source))
+ if err != nil {
+ return err
+ } else if has {
+ return ErrSourceAlreadyExist{source.Name}
+ }
+ // Synchronization is only available with LDAP for now
+ if !source.IsLDAP() && !source.IsOAuth2() {
+ source.IsSyncEnabled = false
+ }
+
+ _, err = db.GetEngine(ctx).Insert(source)
+ if err != nil {
+ return err
+ }
+
+ if !source.IsActive {
+ return nil
+ }
+
+ if settable, ok := source.Cfg.(SourceSettable); ok {
+ settable.SetAuthSource(source)
+ }
+
+ registerableSource, ok := source.Cfg.(RegisterableSource)
+ if !ok {
+ return nil
+ }
+
+ err = registerableSource.RegisterSource()
+ if err != nil {
+ // remove the AuthSource in case of errors while registering configuration
+ if _, err := db.GetEngine(ctx).ID(source.ID).Delete(new(Source)); err != nil {
+ log.Error("CreateSource: Error while wrapOpenIDConnectInitializeError: %v", err)
+ }
+ }
+ return err
+}
+
+type FindSourcesOptions struct {
+ db.ListOptions
+ IsActive optional.Option[bool]
+ LoginType Type
+}
+
+func (opts FindSourcesOptions) ToConds() builder.Cond {
+ conds := builder.NewCond()
+ if opts.IsActive.Has() {
+ conds = conds.And(builder.Eq{"is_active": opts.IsActive.Value()})
+ }
+ if opts.LoginType != NoType {
+ conds = conds.And(builder.Eq{"`type`": opts.LoginType})
+ }
+ return conds
+}
+
+// IsSSPIEnabled returns true if there is at least one activated login
+// source of type LoginSSPI
+func IsSSPIEnabled(ctx context.Context) bool {
+ exist, err := db.Exist[Source](ctx, FindSourcesOptions{
+ IsActive: optional.Some(true),
+ LoginType: SSPI,
+ }.ToConds())
+ if err != nil {
+ log.Error("IsSSPIEnabled: failed to query active SSPI sources: %v", err)
+ return false
+ }
+ return exist
+}
+
+// GetSourceByID returns login source by given ID.
+func GetSourceByID(ctx context.Context, id int64) (*Source, error) {
+ source := new(Source)
+ if id == 0 {
+ source.Cfg = registeredConfigs[NoType]()
+ // Set this source to active
+ // FIXME: allow disabling of db based password authentication in future
+ source.IsActive = true
+ return source, nil
+ }
+
+ has, err := db.GetEngine(ctx).ID(id).Get(source)
+ if err != nil {
+ return nil, err
+ } else if !has {
+ return nil, ErrSourceNotExist{id}
+ }
+ return source, nil
+}
+
+func GetSourceByName(ctx context.Context, name string) (*Source, error) {
+ source := &Source{}
+ has, err := db.GetEngine(ctx).Where("name = ?", name).Get(source)
+ if err != nil {
+ return nil, err
+ } else if !has {
+ return nil, ErrSourceNotExist{}
+ }
+ return source, nil
+}
+
+// UpdateSource updates a Source record in DB.
+func UpdateSource(ctx context.Context, source *Source) error {
+ var originalSource *Source
+ if source.IsOAuth2() {
+ // keep track of the original values so we can restore in case of errors while registering OAuth2 providers
+ var err error
+ if originalSource, err = GetSourceByID(ctx, source.ID); err != nil {
+ return err
+ }
+ }
+
+ has, err := db.GetEngine(ctx).Where("name=? AND id!=?", source.Name, source.ID).Exist(new(Source))
+ if err != nil {
+ return err
+ } else if has {
+ return ErrSourceAlreadyExist{source.Name}
+ }
+
+ _, err = db.GetEngine(ctx).ID(source.ID).AllCols().Update(source)
+ if err != nil {
+ return err
+ }
+
+ if !source.IsActive {
+ return nil
+ }
+
+ if settable, ok := source.Cfg.(SourceSettable); ok {
+ settable.SetAuthSource(source)
+ }
+
+ registerableSource, ok := source.Cfg.(RegisterableSource)
+ if !ok {
+ return nil
+ }
+
+ err = registerableSource.RegisterSource()
+ if err != nil {
+ // restore original values since we cannot update the provider it self
+ if _, err := db.GetEngine(ctx).ID(source.ID).AllCols().Update(originalSource); err != nil {
+ log.Error("UpdateSource: Error while wrapOpenIDConnectInitializeError: %v", err)
+ }
+ }
+ return err
+}
+
+// ErrSourceNotExist represents a "SourceNotExist" kind of error.
+type ErrSourceNotExist struct {
+ ID int64
+}
+
+// IsErrSourceNotExist checks if an error is a ErrSourceNotExist.
+func IsErrSourceNotExist(err error) bool {
+ _, ok := err.(ErrSourceNotExist)
+ return ok
+}
+
+func (err ErrSourceNotExist) Error() string {
+ return fmt.Sprintf("login source does not exist [id: %d]", err.ID)
+}
+
+// Unwrap unwraps this as a ErrNotExist err
+func (err ErrSourceNotExist) Unwrap() error {
+ return util.ErrNotExist
+}
+
+// ErrSourceAlreadyExist represents a "SourceAlreadyExist" kind of error.
+type ErrSourceAlreadyExist struct {
+ Name string
+}
+
+// IsErrSourceAlreadyExist checks if an error is a ErrSourceAlreadyExist.
+func IsErrSourceAlreadyExist(err error) bool {
+ _, ok := err.(ErrSourceAlreadyExist)
+ return ok
+}
+
+func (err ErrSourceAlreadyExist) Error() string {
+ return fmt.Sprintf("login source already exists [name: %s]", err.Name)
+}
+
+// Unwrap unwraps this as a ErrExist err
+func (err ErrSourceAlreadyExist) Unwrap() error {
+ return util.ErrAlreadyExist
+}
+
+// ErrSourceInUse represents a "SourceInUse" kind of error.
+type ErrSourceInUse struct {
+ ID int64
+}
+
+// IsErrSourceInUse checks if an error is a ErrSourceInUse.
+func IsErrSourceInUse(err error) bool {
+ _, ok := err.(ErrSourceInUse)
+ return ok
+}
+
+func (err ErrSourceInUse) Error() string {
+ return fmt.Sprintf("login source is still used by some users [id: %d]", err.ID)
+}
diff --git a/models/auth/source_test.go b/models/auth/source_test.go
new file mode 100644
index 0000000..522fecc
--- /dev/null
+++ b/models/auth/source_test.go
@@ -0,0 +1,61 @@
+// Copyright 2019 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package auth_test
+
+import (
+ "strings"
+ "testing"
+
+ auth_model "code.gitea.io/gitea/models/auth"
+ "code.gitea.io/gitea/models/db"
+ "code.gitea.io/gitea/models/unittest"
+ "code.gitea.io/gitea/modules/json"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "xorm.io/xorm/schemas"
+)
+
+type TestSource struct {
+ Provider string
+ ClientID string
+ ClientSecret string
+ OpenIDConnectAutoDiscoveryURL string
+ IconURL string
+}
+
+// FromDB fills up a LDAPConfig from serialized format.
+func (source *TestSource) FromDB(bs []byte) error {
+ return json.Unmarshal(bs, &source)
+}
+
+// ToDB exports a LDAPConfig to a serialized format.
+func (source *TestSource) ToDB() ([]byte, error) {
+ return json.Marshal(source)
+}
+
+func TestDumpAuthSource(t *testing.T) {
+ require.NoError(t, unittest.PrepareTestDatabase())
+
+ authSourceSchema, err := db.TableInfo(new(auth_model.Source))
+ require.NoError(t, err)
+
+ auth_model.RegisterTypeConfig(auth_model.OAuth2, new(TestSource))
+
+ auth_model.CreateSource(db.DefaultContext, &auth_model.Source{
+ Type: auth_model.OAuth2,
+ Name: "TestSource",
+ IsActive: false,
+ Cfg: &TestSource{
+ Provider: "ConvertibleSourceName",
+ ClientID: "42",
+ },
+ })
+
+ sb := new(strings.Builder)
+
+ db.DumpTables([]*schemas.Table{authSourceSchema}, sb)
+
+ assert.Contains(t, sb.String(), `"Provider":"ConvertibleSourceName"`)
+}
diff --git a/models/auth/twofactor.go b/models/auth/twofactor.go
new file mode 100644
index 0000000..d0c341a
--- /dev/null
+++ b/models/auth/twofactor.go
@@ -0,0 +1,166 @@
+// Copyright 2017 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package auth
+
+import (
+ "context"
+ "crypto/md5"
+ "crypto/sha256"
+ "crypto/subtle"
+ "encoding/base32"
+ "encoding/base64"
+ "encoding/hex"
+ "fmt"
+
+ "code.gitea.io/gitea/models/db"
+ "code.gitea.io/gitea/modules/secret"
+ "code.gitea.io/gitea/modules/setting"
+ "code.gitea.io/gitea/modules/timeutil"
+ "code.gitea.io/gitea/modules/util"
+
+ "github.com/pquerna/otp/totp"
+ "golang.org/x/crypto/pbkdf2"
+)
+
+//
+// Two-factor authentication
+//
+
+// ErrTwoFactorNotEnrolled indicates that a user is not enrolled in two-factor authentication.
+type ErrTwoFactorNotEnrolled struct {
+ UID int64
+}
+
+// IsErrTwoFactorNotEnrolled checks if an error is a ErrTwoFactorNotEnrolled.
+func IsErrTwoFactorNotEnrolled(err error) bool {
+ _, ok := err.(ErrTwoFactorNotEnrolled)
+ return ok
+}
+
+func (err ErrTwoFactorNotEnrolled) Error() string {
+ return fmt.Sprintf("user not enrolled in 2FA [uid: %d]", err.UID)
+}
+
+// Unwrap unwraps this as a ErrNotExist err
+func (err ErrTwoFactorNotEnrolled) Unwrap() error {
+ return util.ErrNotExist
+}
+
+// TwoFactor represents a two-factor authentication token.
+type TwoFactor struct {
+ ID int64 `xorm:"pk autoincr"`
+ UID int64 `xorm:"UNIQUE"`
+ Secret string
+ ScratchSalt string
+ ScratchHash string
+ LastUsedPasscode string `xorm:"VARCHAR(10)"`
+ CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
+ UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
+}
+
+func init() {
+ db.RegisterModel(new(TwoFactor))
+}
+
+// GenerateScratchToken recreates the scratch token the user is using.
+func (t *TwoFactor) GenerateScratchToken() (string, error) {
+ tokenBytes, err := util.CryptoRandomBytes(6)
+ if err != nil {
+ return "", err
+ }
+ // these chars are specially chosen, avoid ambiguous chars like `0`, `O`, `1`, `I`.
+ const base32Chars = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"
+ token := base32.NewEncoding(base32Chars).WithPadding(base32.NoPadding).EncodeToString(tokenBytes)
+ t.ScratchSalt, _ = util.CryptoRandomString(10)
+ t.ScratchHash = HashToken(token, t.ScratchSalt)
+ return token, nil
+}
+
+// HashToken return the hashable salt
+func HashToken(token, salt string) string {
+ tempHash := pbkdf2.Key([]byte(token), []byte(salt), 10000, 50, sha256.New)
+ return hex.EncodeToString(tempHash)
+}
+
+// VerifyScratchToken verifies if the specified scratch token is valid.
+func (t *TwoFactor) VerifyScratchToken(token string) bool {
+ if len(token) == 0 {
+ return false
+ }
+ tempHash := HashToken(token, t.ScratchSalt)
+ return subtle.ConstantTimeCompare([]byte(t.ScratchHash), []byte(tempHash)) == 1
+}
+
+func (t *TwoFactor) getEncryptionKey() []byte {
+ k := md5.Sum([]byte(setting.SecretKey))
+ return k[:]
+}
+
+// SetSecret sets the 2FA secret.
+func (t *TwoFactor) SetSecret(secretString string) error {
+ secretBytes, err := secret.AesEncrypt(t.getEncryptionKey(), []byte(secretString))
+ if err != nil {
+ return err
+ }
+ t.Secret = base64.StdEncoding.EncodeToString(secretBytes)
+ return nil
+}
+
+// ValidateTOTP validates the provided passcode.
+func (t *TwoFactor) ValidateTOTP(passcode string) (bool, error) {
+ decodedStoredSecret, err := base64.StdEncoding.DecodeString(t.Secret)
+ if err != nil {
+ return false, err
+ }
+ secretBytes, err := secret.AesDecrypt(t.getEncryptionKey(), decodedStoredSecret)
+ if err != nil {
+ return false, err
+ }
+ secretStr := string(secretBytes)
+ return totp.Validate(passcode, secretStr), nil
+}
+
+// NewTwoFactor creates a new two-factor authentication token.
+func NewTwoFactor(ctx context.Context, t *TwoFactor) error {
+ _, err := db.GetEngine(ctx).Insert(t)
+ return err
+}
+
+// UpdateTwoFactor updates a two-factor authentication token.
+func UpdateTwoFactor(ctx context.Context, t *TwoFactor) error {
+ _, err := db.GetEngine(ctx).ID(t.ID).AllCols().Update(t)
+ return err
+}
+
+// GetTwoFactorByUID returns the two-factor authentication token associated with
+// the user, if any.
+func GetTwoFactorByUID(ctx context.Context, uid int64) (*TwoFactor, error) {
+ twofa := &TwoFactor{}
+ has, err := db.GetEngine(ctx).Where("uid=?", uid).Get(twofa)
+ if err != nil {
+ return nil, err
+ } else if !has {
+ return nil, ErrTwoFactorNotEnrolled{uid}
+ }
+ return twofa, nil
+}
+
+// HasTwoFactorByUID returns the two-factor authentication token associated with
+// the user, if any.
+func HasTwoFactorByUID(ctx context.Context, uid int64) (bool, error) {
+ return db.GetEngine(ctx).Where("uid=?", uid).Exist(&TwoFactor{})
+}
+
+// DeleteTwoFactorByID deletes two-factor authentication token by given ID.
+func DeleteTwoFactorByID(ctx context.Context, id, userID int64) error {
+ cnt, err := db.GetEngine(ctx).ID(id).Delete(&TwoFactor{
+ UID: userID,
+ })
+ if err != nil {
+ return err
+ } else if cnt != 1 {
+ return ErrTwoFactorNotEnrolled{userID}
+ }
+ return nil
+}
diff --git a/models/auth/webauthn.go b/models/auth/webauthn.go
new file mode 100644
index 0000000..aa13cf6
--- /dev/null
+++ b/models/auth/webauthn.go
@@ -0,0 +1,209 @@
+// Copyright 2020 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package auth
+
+import (
+ "context"
+ "fmt"
+ "strings"
+
+ "code.gitea.io/gitea/models/db"
+ "code.gitea.io/gitea/modules/timeutil"
+ "code.gitea.io/gitea/modules/util"
+
+ "github.com/go-webauthn/webauthn/webauthn"
+)
+
+// ErrWebAuthnCredentialNotExist represents a "ErrWebAuthnCRedentialNotExist" kind of error.
+type ErrWebAuthnCredentialNotExist struct {
+ ID int64
+ CredentialID []byte
+}
+
+func (err ErrWebAuthnCredentialNotExist) Error() string {
+ if len(err.CredentialID) == 0 {
+ return fmt.Sprintf("WebAuthn credential does not exist [id: %d]", err.ID)
+ }
+ return fmt.Sprintf("WebAuthn credential does not exist [credential_id: %x]", err.CredentialID)
+}
+
+// Unwrap unwraps this as a ErrNotExist err
+func (err ErrWebAuthnCredentialNotExist) Unwrap() error {
+ return util.ErrNotExist
+}
+
+// IsErrWebAuthnCredentialNotExist checks if an error is a ErrWebAuthnCredentialNotExist.
+func IsErrWebAuthnCredentialNotExist(err error) bool {
+ _, ok := err.(ErrWebAuthnCredentialNotExist)
+ return ok
+}
+
+// WebAuthnCredential represents the WebAuthn credential data for a public-key
+// credential conformant to WebAuthn Level 3
+type WebAuthnCredential struct {
+ ID int64 `xorm:"pk autoincr"`
+ Name string
+ LowerName string `xorm:"unique(s)"`
+ UserID int64 `xorm:"INDEX unique(s)"`
+ CredentialID []byte `xorm:"INDEX VARBINARY(1024)"`
+ PublicKey []byte
+ AttestationType string
+ AAGUID []byte
+ SignCount uint32 `xorm:"BIGINT"`
+ CloneWarning bool
+ BackupEligible bool `XORM:"NOT NULL DEFAULT false"`
+ BackupState bool `XORM:"NOT NULL DEFAULT false"`
+ // If legacy is set to true, backup_eligible and backup_state isn't set.
+ Legacy bool `XORM:"NOT NULL DEFAULT true"`
+ CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
+ UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
+}
+
+func init() {
+ db.RegisterModel(new(WebAuthnCredential))
+}
+
+// TableName returns a better table name for WebAuthnCredential
+func (cred WebAuthnCredential) TableName() string {
+ return "webauthn_credential"
+}
+
+// UpdateSignCount will update the database value of SignCount
+func (cred *WebAuthnCredential) UpdateSignCount(ctx context.Context) error {
+ _, err := db.GetEngine(ctx).ID(cred.ID).Cols("sign_count").Update(cred)
+ return err
+}
+
+// UpdateFromLegacy update the values that aren't present on legacy credentials.
+func (cred *WebAuthnCredential) UpdateFromLegacy(ctx context.Context) error {
+ _, err := db.GetEngine(ctx).ID(cred.ID).Cols("legacy", "backup_eligible", "backup_state").Update(cred)
+ return err
+}
+
+// BeforeInsert will be invoked by XORM before updating a record
+func (cred *WebAuthnCredential) BeforeInsert() {
+ cred.LowerName = strings.ToLower(cred.Name)
+}
+
+// BeforeUpdate will be invoked by XORM before updating a record
+func (cred *WebAuthnCredential) BeforeUpdate() {
+ cred.LowerName = strings.ToLower(cred.Name)
+}
+
+// AfterLoad is invoked from XORM after setting the values of all fields of this object.
+func (cred *WebAuthnCredential) AfterLoad() {
+ cred.LowerName = strings.ToLower(cred.Name)
+}
+
+// WebAuthnCredentialList is a list of *WebAuthnCredential
+type WebAuthnCredentialList []*WebAuthnCredential
+
+// ToCredentials will convert all WebAuthnCredentials to webauthn.Credentials
+func (list WebAuthnCredentialList) ToCredentials() []webauthn.Credential {
+ creds := make([]webauthn.Credential, 0, len(list))
+ for _, cred := range list {
+ creds = append(creds, webauthn.Credential{
+ ID: cred.CredentialID,
+ PublicKey: cred.PublicKey,
+ AttestationType: cred.AttestationType,
+ Flags: webauthn.CredentialFlags{
+ BackupEligible: cred.BackupEligible,
+ BackupState: cred.BackupState,
+ },
+ Authenticator: webauthn.Authenticator{
+ AAGUID: cred.AAGUID,
+ SignCount: cred.SignCount,
+ CloneWarning: cred.CloneWarning,
+ },
+ })
+ }
+ return creds
+}
+
+// GetWebAuthnCredentialsByUID returns all WebAuthn credentials of the given user
+func GetWebAuthnCredentialsByUID(ctx context.Context, uid int64) (WebAuthnCredentialList, error) {
+ creds := make(WebAuthnCredentialList, 0)
+ return creds, db.GetEngine(ctx).Where("user_id = ?", uid).Find(&creds)
+}
+
+// ExistsWebAuthnCredentialsForUID returns if the given user has credentials
+func ExistsWebAuthnCredentialsForUID(ctx context.Context, uid int64) (bool, error) {
+ return db.GetEngine(ctx).Where("user_id = ?", uid).Exist(&WebAuthnCredential{})
+}
+
+// GetWebAuthnCredentialByName returns WebAuthn credential by id
+func GetWebAuthnCredentialByName(ctx context.Context, uid int64, name string) (*WebAuthnCredential, error) {
+ cred := new(WebAuthnCredential)
+ if found, err := db.GetEngine(ctx).Where("user_id = ? AND lower_name = ?", uid, strings.ToLower(name)).Get(cred); err != nil {
+ return nil, err
+ } else if !found {
+ return nil, ErrWebAuthnCredentialNotExist{}
+ }
+ return cred, nil
+}
+
+// GetWebAuthnCredentialByID returns WebAuthn credential by id
+func GetWebAuthnCredentialByID(ctx context.Context, id int64) (*WebAuthnCredential, error) {
+ cred := new(WebAuthnCredential)
+ if found, err := db.GetEngine(ctx).ID(id).Get(cred); err != nil {
+ return nil, err
+ } else if !found {
+ return nil, ErrWebAuthnCredentialNotExist{ID: id}
+ }
+ return cred, nil
+}
+
+// HasWebAuthnRegistrationsByUID returns whether a given user has WebAuthn registrations
+func HasWebAuthnRegistrationsByUID(ctx context.Context, uid int64) (bool, error) {
+ return db.GetEngine(ctx).Where("user_id = ?", uid).Exist(&WebAuthnCredential{})
+}
+
+// GetWebAuthnCredentialByCredID returns WebAuthn credential by credential ID
+func GetWebAuthnCredentialByCredID(ctx context.Context, userID int64, credID []byte) (*WebAuthnCredential, error) {
+ cred := new(WebAuthnCredential)
+ if found, err := db.GetEngine(ctx).Where("user_id = ? AND credential_id = ?", userID, credID).Get(cred); err != nil {
+ return nil, err
+ } else if !found {
+ return nil, ErrWebAuthnCredentialNotExist{CredentialID: credID}
+ }
+ return cred, nil
+}
+
+// CreateCredential will create a new WebAuthnCredential from the given Credential
+func CreateCredential(ctx context.Context, userID int64, name string, cred *webauthn.Credential) (*WebAuthnCredential, error) {
+ c := &WebAuthnCredential{
+ UserID: userID,
+ Name: name,
+ CredentialID: cred.ID,
+ PublicKey: cred.PublicKey,
+ AttestationType: cred.AttestationType,
+ AAGUID: cred.Authenticator.AAGUID,
+ SignCount: cred.Authenticator.SignCount,
+ CloneWarning: false,
+ BackupEligible: cred.Flags.BackupEligible,
+ BackupState: cred.Flags.BackupState,
+ Legacy: false,
+ }
+
+ if err := db.Insert(ctx, c); err != nil {
+ return nil, err
+ }
+ return c, nil
+}
+
+// DeleteCredential will delete WebAuthnCredential
+func DeleteCredential(ctx context.Context, id, userID int64) (bool, error) {
+ had, err := db.GetEngine(ctx).ID(id).Where("user_id = ?", userID).Delete(&WebAuthnCredential{})
+ return had > 0, err
+}
+
+// WebAuthnCredentials implementns the webauthn.User interface
+func WebAuthnCredentials(ctx context.Context, userID int64) ([]webauthn.Credential, error) {
+ dbCreds, err := GetWebAuthnCredentialsByUID(ctx, userID)
+ if err != nil {
+ return nil, err
+ }
+
+ return dbCreds.ToCredentials(), nil
+}
diff --git a/models/auth/webauthn_test.go b/models/auth/webauthn_test.go
new file mode 100644
index 0000000..e1cd652
--- /dev/null
+++ b/models/auth/webauthn_test.go
@@ -0,0 +1,78 @@
+// Copyright 2020 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package auth_test
+
+import (
+ "testing"
+
+ auth_model "code.gitea.io/gitea/models/auth"
+ "code.gitea.io/gitea/models/db"
+ "code.gitea.io/gitea/models/unittest"
+
+ "github.com/go-webauthn/webauthn/webauthn"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestGetWebAuthnCredentialByID(t *testing.T) {
+ require.NoError(t, unittest.PrepareTestDatabase())
+
+ res, err := auth_model.GetWebAuthnCredentialByID(db.DefaultContext, 1)
+ require.NoError(t, err)
+ assert.Equal(t, "WebAuthn credential", res.Name)
+
+ _, err = auth_model.GetWebAuthnCredentialByID(db.DefaultContext, 342432)
+ require.Error(t, err)
+ assert.True(t, auth_model.IsErrWebAuthnCredentialNotExist(err))
+}
+
+func TestGetWebAuthnCredentialsByUID(t *testing.T) {
+ require.NoError(t, unittest.PrepareTestDatabase())
+
+ res, err := auth_model.GetWebAuthnCredentialsByUID(db.DefaultContext, 32)
+ require.NoError(t, err)
+ assert.Len(t, res, 1)
+ assert.Equal(t, "WebAuthn credential", res[0].Name)
+}
+
+func TestWebAuthnCredential_TableName(t *testing.T) {
+ assert.Equal(t, "webauthn_credential", auth_model.WebAuthnCredential{}.TableName())
+}
+
+func TestWebAuthnCredential_UpdateSignCount(t *testing.T) {
+ require.NoError(t, unittest.PrepareTestDatabase())
+ cred := unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{ID: 1})
+ cred.SignCount = 1
+ require.NoError(t, cred.UpdateSignCount(db.DefaultContext))
+ unittest.AssertExistsIf(t, true, &auth_model.WebAuthnCredential{ID: 1, SignCount: 1})
+}
+
+func TestWebAuthnCredential_UpdateLargeCounter(t *testing.T) {
+ require.NoError(t, unittest.PrepareTestDatabase())
+ cred := unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{ID: 1})
+ cred.SignCount = 0xffffffff
+ require.NoError(t, cred.UpdateSignCount(db.DefaultContext))
+ unittest.AssertExistsIf(t, true, &auth_model.WebAuthnCredential{ID: 1, SignCount: 0xffffffff})
+}
+
+func TestWebAuthenCredential_UpdateFromLegacy(t *testing.T) {
+ require.NoError(t, unittest.PrepareTestDatabase())
+ cred := unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{ID: 1, Legacy: true})
+ cred.Legacy = false
+ cred.BackupEligible = true
+ cred.BackupState = true
+ require.NoError(t, cred.UpdateFromLegacy(db.DefaultContext))
+ unittest.AssertExistsIf(t, true, &auth_model.WebAuthnCredential{ID: 1, BackupEligible: true, BackupState: true}, "legacy = false")
+}
+
+func TestCreateCredential(t *testing.T) {
+ require.NoError(t, unittest.PrepareTestDatabase())
+
+ res, err := auth_model.CreateCredential(db.DefaultContext, 1, "WebAuthn Created Credential", &webauthn.Credential{ID: []byte("Test"), Flags: webauthn.CredentialFlags{BackupEligible: true, BackupState: true}})
+ require.NoError(t, err)
+ assert.Equal(t, "WebAuthn Created Credential", res.Name)
+ assert.Equal(t, []byte("Test"), res.CredentialID)
+
+ unittest.AssertExistsIf(t, true, &auth_model.WebAuthnCredential{Name: "WebAuthn Created Credential", UserID: 1, BackupEligible: true, BackupState: true}, "legacy = false")
+}