From dd136858f1ea40ad3c94191d647487fa4f31926c Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Fri, 18 Oct 2024 20:33:49 +0200 Subject: Adding upstream version 9.0.0. Signed-off-by: Daniel Baumann --- .../oauth2_application.yaml | 25 + models/auth/access_token.go | 236 +++++++ models/auth/access_token_scope.go | 350 +++++++++++ models/auth/access_token_scope_test.go | 90 +++ models/auth/access_token_test.go | 133 ++++ models/auth/auth_token.go | 96 +++ models/auth/main_test.go | 20 + models/auth/oauth2.go | 676 +++++++++++++++++++++ models/auth/oauth2_list.go | 32 + models/auth/oauth2_test.go | 299 +++++++++ models/auth/session.go | 120 ++++ models/auth/session_test.go | 143 +++++ models/auth/source.go | 412 +++++++++++++ models/auth/source_test.go | 61 ++ models/auth/twofactor.go | 166 +++++ models/auth/webauthn.go | 209 +++++++ models/auth/webauthn_test.go | 78 +++ 17 files changed, 3146 insertions(+) create mode 100644 models/auth/TestOrphanedOAuth2Applications/oauth2_application.yaml create mode 100644 models/auth/access_token.go create mode 100644 models/auth/access_token_scope.go create mode 100644 models/auth/access_token_scope_test.go create mode 100644 models/auth/access_token_test.go create mode 100644 models/auth/auth_token.go create mode 100644 models/auth/main_test.go create mode 100644 models/auth/oauth2.go create mode 100644 models/auth/oauth2_list.go create mode 100644 models/auth/oauth2_test.go create mode 100644 models/auth/session.go create mode 100644 models/auth/session_test.go create mode 100644 models/auth/source.go create mode 100644 models/auth/source_test.go create mode 100644 models/auth/twofactor.go create mode 100644 models/auth/webauthn.go create mode 100644 models/auth/webauthn_test.go (limited to 'models/auth') diff --git a/models/auth/TestOrphanedOAuth2Applications/oauth2_application.yaml b/models/auth/TestOrphanedOAuth2Applications/oauth2_application.yaml new file mode 100644 index 0000000..b188770 --- /dev/null +++ b/models/auth/TestOrphanedOAuth2Applications/oauth2_application.yaml @@ -0,0 +1,25 @@ +- + id: 1000 + uid: 0 + name: "Git Credential Manager" + client_id: "e90ee53c-94e2-48ac-9358-a874fb9e0662" + redirect_uris: '["http://127.0.0.1", "https://127.0.0.1"]' + created_unix: 1712358091 + updated_unix: 1712358091 +- + id: 1001 + uid: 0 + name: "git-credential-oauth" + client_id: "a4792ccc-144e-407e-86c9-5e7d8d9c3269" + redirect_uris: '["http://127.0.0.1", "https://127.0.0.1"]' + created_unix: 1712358091 + updated_unix: 1712358091 + +- + id: 1002 + uid: 1234567890 + name: "Should be removed" + client_id: "deadc0de-badd-dd11-fee1-deaddecafbad" + redirect_uris: '["http://127.0.0.1", "https://127.0.0.1"]' + created_unix: 1712358091 + updated_unix: 1712358091 diff --git a/models/auth/access_token.go b/models/auth/access_token.go new file mode 100644 index 0000000..63331b4 --- /dev/null +++ b/models/auth/access_token.go @@ -0,0 +1,236 @@ +// Copyright 2014 The Gogs Authors. All rights reserved. +// Copyright 2019 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package auth + +import ( + "context" + "crypto/subtle" + "encoding/hex" + "fmt" + "time" + + "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/modules/setting" + "code.gitea.io/gitea/modules/timeutil" + "code.gitea.io/gitea/modules/util" + + lru "github.com/hashicorp/golang-lru/v2" + "xorm.io/builder" +) + +// ErrAccessTokenNotExist represents a "AccessTokenNotExist" kind of error. +type ErrAccessTokenNotExist struct { + Token string +} + +// IsErrAccessTokenNotExist checks if an error is a ErrAccessTokenNotExist. +func IsErrAccessTokenNotExist(err error) bool { + _, ok := err.(ErrAccessTokenNotExist) + return ok +} + +func (err ErrAccessTokenNotExist) Error() string { + return fmt.Sprintf("access token does not exist [sha: %s]", err.Token) +} + +func (err ErrAccessTokenNotExist) Unwrap() error { + return util.ErrNotExist +} + +// ErrAccessTokenEmpty represents a "AccessTokenEmpty" kind of error. +type ErrAccessTokenEmpty struct{} + +// IsErrAccessTokenEmpty checks if an error is a ErrAccessTokenEmpty. +func IsErrAccessTokenEmpty(err error) bool { + _, ok := err.(ErrAccessTokenEmpty) + return ok +} + +func (err ErrAccessTokenEmpty) Error() string { + return "access token is empty" +} + +func (err ErrAccessTokenEmpty) Unwrap() error { + return util.ErrInvalidArgument +} + +var successfulAccessTokenCache *lru.Cache[string, any] + +// AccessToken represents a personal access token. +type AccessToken struct { + ID int64 `xorm:"pk autoincr"` + UID int64 `xorm:"INDEX"` + Name string + Token string `xorm:"-"` + TokenHash string `xorm:"UNIQUE"` // sha256 of token + TokenSalt string + TokenLastEight string `xorm:"INDEX token_last_eight"` + Scope AccessTokenScope + + CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"` + UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"` + HasRecentActivity bool `xorm:"-"` + HasUsed bool `xorm:"-"` +} + +// AfterLoad is invoked from XORM after setting the values of all fields of this object. +func (t *AccessToken) AfterLoad() { + t.HasUsed = t.UpdatedUnix > t.CreatedUnix + t.HasRecentActivity = t.UpdatedUnix.AddDuration(7*24*time.Hour) > timeutil.TimeStampNow() +} + +func init() { + db.RegisterModel(new(AccessToken), func() error { + if setting.SuccessfulTokensCacheSize > 0 { + var err error + successfulAccessTokenCache, err = lru.New[string, any](setting.SuccessfulTokensCacheSize) + if err != nil { + return fmt.Errorf("unable to allocate AccessToken cache: %w", err) + } + } else { + successfulAccessTokenCache = nil + } + return nil + }) +} + +// NewAccessToken creates new access token. +func NewAccessToken(ctx context.Context, t *AccessToken) error { + salt, err := util.CryptoRandomString(10) + if err != nil { + return err + } + token, err := util.CryptoRandomBytes(20) + if err != nil { + return err + } + t.TokenSalt = salt + t.Token = hex.EncodeToString(token) + t.TokenHash = HashToken(t.Token, t.TokenSalt) + t.TokenLastEight = t.Token[len(t.Token)-8:] + _, err = db.GetEngine(ctx).Insert(t) + return err +} + +// DisplayPublicOnly whether to display this as a public-only token. +func (t *AccessToken) DisplayPublicOnly() bool { + publicOnly, err := t.Scope.PublicOnly() + if err != nil { + return false + } + return publicOnly +} + +func getAccessTokenIDFromCache(token string) int64 { + if successfulAccessTokenCache == nil { + return 0 + } + tInterface, ok := successfulAccessTokenCache.Get(token) + if !ok { + return 0 + } + t, ok := tInterface.(int64) + if !ok { + return 0 + } + return t +} + +// GetAccessTokenBySHA returns access token by given token value +func GetAccessTokenBySHA(ctx context.Context, token string) (*AccessToken, error) { + if token == "" { + return nil, ErrAccessTokenEmpty{} + } + // A token is defined as being SHA1 sum these are 40 hexadecimal bytes long + if len(token) != 40 { + return nil, ErrAccessTokenNotExist{token} + } + for _, x := range []byte(token) { + if x < '0' || (x > '9' && x < 'a') || x > 'f' { + return nil, ErrAccessTokenNotExist{token} + } + } + + lastEight := token[len(token)-8:] + + if id := getAccessTokenIDFromCache(token); id > 0 { + accessToken := &AccessToken{ + TokenLastEight: lastEight, + } + // Re-get the token from the db in case it has been deleted in the intervening period + has, err := db.GetEngine(ctx).ID(id).Get(accessToken) + if err != nil { + return nil, err + } + if has { + return accessToken, nil + } + successfulAccessTokenCache.Remove(token) + } + + var tokens []AccessToken + err := db.GetEngine(ctx).Table(&AccessToken{}).Where("token_last_eight = ?", lastEight).Find(&tokens) + if err != nil { + return nil, err + } else if len(tokens) == 0 { + return nil, ErrAccessTokenNotExist{token} + } + + for _, t := range tokens { + tempHash := HashToken(token, t.TokenSalt) + if subtle.ConstantTimeCompare([]byte(t.TokenHash), []byte(tempHash)) == 1 { + if successfulAccessTokenCache != nil { + successfulAccessTokenCache.Add(token, t.ID) + } + return &t, nil + } + } + return nil, ErrAccessTokenNotExist{token} +} + +// AccessTokenByNameExists checks if a token name has been used already by a user. +func AccessTokenByNameExists(ctx context.Context, token *AccessToken) (bool, error) { + return db.GetEngine(ctx).Table("access_token").Where("name = ?", token.Name).And("uid = ?", token.UID).Exist() +} + +// ListAccessTokensOptions contain filter options +type ListAccessTokensOptions struct { + db.ListOptions + Name string + UserID int64 +} + +func (opts ListAccessTokensOptions) ToConds() builder.Cond { + cond := builder.NewCond() + // user id is required, otherwise it will return all result which maybe a possible bug + cond = cond.And(builder.Eq{"uid": opts.UserID}) + if len(opts.Name) > 0 { + cond = cond.And(builder.Eq{"name": opts.Name}) + } + return cond +} + +func (opts ListAccessTokensOptions) ToOrders() string { + return "created_unix DESC" +} + +// UpdateAccessToken updates information of access token. +func UpdateAccessToken(ctx context.Context, t *AccessToken) error { + _, err := db.GetEngine(ctx).ID(t.ID).AllCols().Update(t) + return err +} + +// DeleteAccessTokenByID deletes access token by given ID. +func DeleteAccessTokenByID(ctx context.Context, id, userID int64) error { + cnt, err := db.GetEngine(ctx).ID(id).Delete(&AccessToken{ + UID: userID, + }) + if err != nil { + return err + } else if cnt != 1 { + return ErrAccessTokenNotExist{} + } + return nil +} diff --git a/models/auth/access_token_scope.go b/models/auth/access_token_scope.go new file mode 100644 index 0000000..003ca5c --- /dev/null +++ b/models/auth/access_token_scope.go @@ -0,0 +1,350 @@ +// Copyright 2022 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package auth + +import ( + "fmt" + "strings" + + "code.gitea.io/gitea/models/perm" +) + +// AccessTokenScopeCategory represents the scope category for an access token +type AccessTokenScopeCategory int + +const ( + AccessTokenScopeCategoryActivityPub = iota + AccessTokenScopeCategoryAdmin + AccessTokenScopeCategoryMisc // WARN: this is now just a placeholder, don't remove it which will change the following values + AccessTokenScopeCategoryNotification + AccessTokenScopeCategoryOrganization + AccessTokenScopeCategoryPackage + AccessTokenScopeCategoryIssue + AccessTokenScopeCategoryRepository + AccessTokenScopeCategoryUser +) + +// AllAccessTokenScopeCategories contains all access token scope categories +var AllAccessTokenScopeCategories = []AccessTokenScopeCategory{ + AccessTokenScopeCategoryActivityPub, + AccessTokenScopeCategoryAdmin, + AccessTokenScopeCategoryMisc, + AccessTokenScopeCategoryNotification, + AccessTokenScopeCategoryOrganization, + AccessTokenScopeCategoryPackage, + AccessTokenScopeCategoryIssue, + AccessTokenScopeCategoryRepository, + AccessTokenScopeCategoryUser, +} + +// AccessTokenScopeLevel represents the access levels without a given scope category +type AccessTokenScopeLevel int + +const ( + NoAccess AccessTokenScopeLevel = iota + Read + Write +) + +// AccessTokenScope represents the scope for an access token. +type AccessTokenScope string + +// for all categories, write implies read +const ( + AccessTokenScopeAll AccessTokenScope = "all" + AccessTokenScopePublicOnly AccessTokenScope = "public-only" // limited to public orgs/repos + + AccessTokenScopeReadActivityPub AccessTokenScope = "read:activitypub" + AccessTokenScopeWriteActivityPub AccessTokenScope = "write:activitypub" + + AccessTokenScopeReadAdmin AccessTokenScope = "read:admin" + AccessTokenScopeWriteAdmin AccessTokenScope = "write:admin" + + AccessTokenScopeReadMisc AccessTokenScope = "read:misc" + AccessTokenScopeWriteMisc AccessTokenScope = "write:misc" + + AccessTokenScopeReadNotification AccessTokenScope = "read:notification" + AccessTokenScopeWriteNotification AccessTokenScope = "write:notification" + + AccessTokenScopeReadOrganization AccessTokenScope = "read:organization" + AccessTokenScopeWriteOrganization AccessTokenScope = "write:organization" + + AccessTokenScopeReadPackage AccessTokenScope = "read:package" + AccessTokenScopeWritePackage AccessTokenScope = "write:package" + + AccessTokenScopeReadIssue AccessTokenScope = "read:issue" + AccessTokenScopeWriteIssue AccessTokenScope = "write:issue" + + AccessTokenScopeReadRepository AccessTokenScope = "read:repository" + AccessTokenScopeWriteRepository AccessTokenScope = "write:repository" + + AccessTokenScopeReadUser AccessTokenScope = "read:user" + AccessTokenScopeWriteUser AccessTokenScope = "write:user" +) + +// accessTokenScopeBitmap represents a bitmap of access token scopes. +type accessTokenScopeBitmap uint64 + +// Bitmap of each scope, including the child scopes. +const ( + // AccessTokenScopeAllBits is the bitmap of all access token scopes + accessTokenScopeAllBits accessTokenScopeBitmap = accessTokenScopeWriteActivityPubBits | + accessTokenScopeWriteAdminBits | accessTokenScopeWriteMiscBits | accessTokenScopeWriteNotificationBits | + accessTokenScopeWriteOrganizationBits | accessTokenScopeWritePackageBits | accessTokenScopeWriteIssueBits | + accessTokenScopeWriteRepositoryBits | accessTokenScopeWriteUserBits + + accessTokenScopePublicOnlyBits accessTokenScopeBitmap = 1 << iota + + accessTokenScopeReadActivityPubBits accessTokenScopeBitmap = 1 << iota + accessTokenScopeWriteActivityPubBits accessTokenScopeBitmap = 1< 64 scopes, + // refactoring the whole implementation in this file (and only this file) is needed. +) + +// allAccessTokenScopes contains all access token scopes. +// The order is important: parent scope must precede child scopes. +var allAccessTokenScopes = []AccessTokenScope{ + AccessTokenScopePublicOnly, + AccessTokenScopeWriteActivityPub, AccessTokenScopeReadActivityPub, + AccessTokenScopeWriteAdmin, AccessTokenScopeReadAdmin, + AccessTokenScopeWriteMisc, AccessTokenScopeReadMisc, + AccessTokenScopeWriteNotification, AccessTokenScopeReadNotification, + AccessTokenScopeWriteOrganization, AccessTokenScopeReadOrganization, + AccessTokenScopeWritePackage, AccessTokenScopeReadPackage, + AccessTokenScopeWriteIssue, AccessTokenScopeReadIssue, + AccessTokenScopeWriteRepository, AccessTokenScopeReadRepository, + AccessTokenScopeWriteUser, AccessTokenScopeReadUser, +} + +// allAccessTokenScopeBits contains all access token scopes. +var allAccessTokenScopeBits = map[AccessTokenScope]accessTokenScopeBitmap{ + AccessTokenScopeAll: accessTokenScopeAllBits, + AccessTokenScopePublicOnly: accessTokenScopePublicOnlyBits, + AccessTokenScopeReadActivityPub: accessTokenScopeReadActivityPubBits, + AccessTokenScopeWriteActivityPub: accessTokenScopeWriteActivityPubBits, + AccessTokenScopeReadAdmin: accessTokenScopeReadAdminBits, + AccessTokenScopeWriteAdmin: accessTokenScopeWriteAdminBits, + AccessTokenScopeReadMisc: accessTokenScopeReadMiscBits, + AccessTokenScopeWriteMisc: accessTokenScopeWriteMiscBits, + AccessTokenScopeReadNotification: accessTokenScopeReadNotificationBits, + AccessTokenScopeWriteNotification: accessTokenScopeWriteNotificationBits, + AccessTokenScopeReadOrganization: accessTokenScopeReadOrganizationBits, + AccessTokenScopeWriteOrganization: accessTokenScopeWriteOrganizationBits, + AccessTokenScopeReadPackage: accessTokenScopeReadPackageBits, + AccessTokenScopeWritePackage: accessTokenScopeWritePackageBits, + AccessTokenScopeReadIssue: accessTokenScopeReadIssueBits, + AccessTokenScopeWriteIssue: accessTokenScopeWriteIssueBits, + AccessTokenScopeReadRepository: accessTokenScopeReadRepositoryBits, + AccessTokenScopeWriteRepository: accessTokenScopeWriteRepositoryBits, + AccessTokenScopeReadUser: accessTokenScopeReadUserBits, + AccessTokenScopeWriteUser: accessTokenScopeWriteUserBits, +} + +// readAccessTokenScopes maps a scope category to the read permission scope +var accessTokenScopes = map[AccessTokenScopeLevel]map[AccessTokenScopeCategory]AccessTokenScope{ + Read: { + AccessTokenScopeCategoryActivityPub: AccessTokenScopeReadActivityPub, + AccessTokenScopeCategoryAdmin: AccessTokenScopeReadAdmin, + AccessTokenScopeCategoryMisc: AccessTokenScopeReadMisc, + AccessTokenScopeCategoryNotification: AccessTokenScopeReadNotification, + AccessTokenScopeCategoryOrganization: AccessTokenScopeReadOrganization, + AccessTokenScopeCategoryPackage: AccessTokenScopeReadPackage, + AccessTokenScopeCategoryIssue: AccessTokenScopeReadIssue, + AccessTokenScopeCategoryRepository: AccessTokenScopeReadRepository, + AccessTokenScopeCategoryUser: AccessTokenScopeReadUser, + }, + Write: { + AccessTokenScopeCategoryActivityPub: AccessTokenScopeWriteActivityPub, + AccessTokenScopeCategoryAdmin: AccessTokenScopeWriteAdmin, + AccessTokenScopeCategoryMisc: AccessTokenScopeWriteMisc, + AccessTokenScopeCategoryNotification: AccessTokenScopeWriteNotification, + AccessTokenScopeCategoryOrganization: AccessTokenScopeWriteOrganization, + AccessTokenScopeCategoryPackage: AccessTokenScopeWritePackage, + AccessTokenScopeCategoryIssue: AccessTokenScopeWriteIssue, + AccessTokenScopeCategoryRepository: AccessTokenScopeWriteRepository, + AccessTokenScopeCategoryUser: AccessTokenScopeWriteUser, + }, +} + +// GetRequiredScopes gets the specific scopes for a given level and categories +func GetRequiredScopes(level AccessTokenScopeLevel, scopeCategories ...AccessTokenScopeCategory) []AccessTokenScope { + scopes := make([]AccessTokenScope, 0, len(scopeCategories)) + for _, cat := range scopeCategories { + scopes = append(scopes, accessTokenScopes[level][cat]) + } + return scopes +} + +// ContainsCategory checks if a list of categories contains a specific category +func ContainsCategory(categories []AccessTokenScopeCategory, category AccessTokenScopeCategory) bool { + for _, c := range categories { + if c == category { + return true + } + } + return false +} + +// GetScopeLevelFromAccessMode converts permission access mode to scope level +func GetScopeLevelFromAccessMode(mode perm.AccessMode) AccessTokenScopeLevel { + switch mode { + case perm.AccessModeNone: + return NoAccess + case perm.AccessModeRead: + return Read + case perm.AccessModeWrite: + return Write + case perm.AccessModeAdmin: + return Write + case perm.AccessModeOwner: + return Write + default: + return NoAccess + } +} + +// parse the scope string into a bitmap, thus removing possible duplicates. +func (s AccessTokenScope) parse() (accessTokenScopeBitmap, error) { + var bitmap accessTokenScopeBitmap + + // The following is the more performant equivalent of 'for _, v := range strings.Split(remainingScope, ",")' as this is hot code + remainingScopes := string(s) + for len(remainingScopes) > 0 { + i := strings.IndexByte(remainingScopes, ',') + var v string + if i < 0 { + v = remainingScopes + remainingScopes = "" + } else if i+1 >= len(remainingScopes) { + v = remainingScopes[:i] + remainingScopes = "" + } else { + v = remainingScopes[:i] + remainingScopes = remainingScopes[i+1:] + } + singleScope := AccessTokenScope(v) + if singleScope == "" || singleScope == "sudo" { + continue + } + if singleScope == AccessTokenScopeAll { + bitmap |= accessTokenScopeAllBits + continue + } + + bits, ok := allAccessTokenScopeBits[singleScope] + if !ok { + return 0, fmt.Errorf("invalid access token scope: %s", singleScope) + } + bitmap |= bits + } + + return bitmap, nil +} + +// StringSlice returns the AccessTokenScope as a []string +func (s AccessTokenScope) StringSlice() []string { + return strings.Split(string(s), ",") +} + +// Normalize returns a normalized scope string without any duplicates. +func (s AccessTokenScope) Normalize() (AccessTokenScope, error) { + bitmap, err := s.parse() + if err != nil { + return "", err + } + + return bitmap.toScope(), nil +} + +// PublicOnly checks if this token scope is limited to public resources +func (s AccessTokenScope) PublicOnly() (bool, error) { + bitmap, err := s.parse() + if err != nil { + return false, err + } + + return bitmap.hasScope(AccessTokenScopePublicOnly) +} + +// HasScope returns true if the string has the given scope +func (s AccessTokenScope) HasScope(scopes ...AccessTokenScope) (bool, error) { + bitmap, err := s.parse() + if err != nil { + return false, err + } + + for _, s := range scopes { + if has, err := bitmap.hasScope(s); !has || err != nil { + return has, err + } + } + + return true, nil +} + +// hasScope returns true if the string has the given scope +func (bitmap accessTokenScopeBitmap) hasScope(scope AccessTokenScope) (bool, error) { + expectedBits, ok := allAccessTokenScopeBits[scope] + if !ok { + return false, fmt.Errorf("invalid access token scope: %s", scope) + } + + return bitmap&expectedBits == expectedBits, nil +} + +// toScope returns a normalized scope string without any duplicates. +func (bitmap accessTokenScopeBitmap) toScope() AccessTokenScope { + var scopes []string + + // iterate over all scopes, and reconstruct the bitmap + // if the reconstructed bitmap doesn't change, then the scope is already included + var reconstruct accessTokenScopeBitmap + + for _, singleScope := range allAccessTokenScopes { + // no need for error checking here, since we know the scope is valid + if ok, _ := bitmap.hasScope(singleScope); ok { + current := reconstruct | allAccessTokenScopeBits[singleScope] + if current == reconstruct { + continue + } + + reconstruct = current + scopes = append(scopes, string(singleScope)) + } + } + + scope := AccessTokenScope(strings.Join(scopes, ",")) + scope = AccessTokenScope(strings.ReplaceAll( + string(scope), + "write:activitypub,write:admin,write:misc,write:notification,write:organization,write:package,write:issue,write:repository,write:user", + "all", + )) + return scope +} diff --git a/models/auth/access_token_scope_test.go b/models/auth/access_token_scope_test.go new file mode 100644 index 0000000..d11c5e6 --- /dev/null +++ b/models/auth/access_token_scope_test.go @@ -0,0 +1,90 @@ +// Copyright 2022 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package auth + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +type scopeTestNormalize struct { + in AccessTokenScope + out AccessTokenScope + err error +} + +func TestAccessTokenScope_Normalize(t *testing.T) { + tests := []scopeTestNormalize{ + {"", "", nil}, + {"write:misc,write:notification,read:package,write:notification,public-only", "public-only,write:misc,write:notification,read:package", nil}, + {"all,sudo", "all", nil}, + {"write:activitypub,write:admin,write:misc,write:notification,write:organization,write:package,write:issue,write:repository,write:user", "all", nil}, + {"write:activitypub,write:admin,write:misc,write:notification,write:organization,write:package,write:issue,write:repository,write:user,public-only", "public-only,all", nil}, + } + + for _, scope := range []string{"activitypub", "admin", "misc", "notification", "organization", "package", "issue", "repository", "user"} { + tests = append(tests, + scopeTestNormalize{AccessTokenScope(fmt.Sprintf("read:%s", scope)), AccessTokenScope(fmt.Sprintf("read:%s", scope)), nil}, + scopeTestNormalize{AccessTokenScope(fmt.Sprintf("write:%s", scope)), AccessTokenScope(fmt.Sprintf("write:%s", scope)), nil}, + scopeTestNormalize{AccessTokenScope(fmt.Sprintf("write:%[1]s,read:%[1]s", scope)), AccessTokenScope(fmt.Sprintf("write:%s", scope)), nil}, + scopeTestNormalize{AccessTokenScope(fmt.Sprintf("read:%[1]s,write:%[1]s", scope)), AccessTokenScope(fmt.Sprintf("write:%s", scope)), nil}, + scopeTestNormalize{AccessTokenScope(fmt.Sprintf("read:%[1]s,write:%[1]s,write:%[1]s", scope)), AccessTokenScope(fmt.Sprintf("write:%s", scope)), nil}, + ) + } + + for _, test := range tests { + t.Run(string(test.in), func(t *testing.T) { + scope, err := test.in.Normalize() + assert.Equal(t, test.out, scope) + assert.Equal(t, test.err, err) + }) + } +} + +type scopeTestHasScope struct { + in AccessTokenScope + scope AccessTokenScope + out bool + err error +} + +func TestAccessTokenScope_HasScope(t *testing.T) { + tests := []scopeTestHasScope{ + {"read:admin", "write:package", false, nil}, + {"all", "write:package", true, nil}, + {"write:package", "all", false, nil}, + {"public-only", "read:issue", false, nil}, + } + + for _, scope := range []string{"activitypub", "admin", "misc", "notification", "organization", "package", "issue", "repository", "user"} { + tests = append(tests, + scopeTestHasScope{ + AccessTokenScope(fmt.Sprintf("read:%s", scope)), + AccessTokenScope(fmt.Sprintf("read:%s", scope)), true, nil, + }, + scopeTestHasScope{ + AccessTokenScope(fmt.Sprintf("write:%s", scope)), + AccessTokenScope(fmt.Sprintf("write:%s", scope)), true, nil, + }, + scopeTestHasScope{ + AccessTokenScope(fmt.Sprintf("write:%s", scope)), + AccessTokenScope(fmt.Sprintf("read:%s", scope)), true, nil, + }, + scopeTestHasScope{ + AccessTokenScope(fmt.Sprintf("read:%s", scope)), + AccessTokenScope(fmt.Sprintf("write:%s", scope)), false, nil, + }, + ) + } + + for _, test := range tests { + t.Run(string(test.in), func(t *testing.T) { + hasScope, err := test.in.HasScope(test.scope) + assert.Equal(t, test.out, hasScope) + assert.Equal(t, test.err, err) + }) + } +} diff --git a/models/auth/access_token_test.go b/models/auth/access_token_test.go new file mode 100644 index 0000000..e6ea487 --- /dev/null +++ b/models/auth/access_token_test.go @@ -0,0 +1,133 @@ +// Copyright 2016 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package auth_test + +import ( + "testing" + + auth_model "code.gitea.io/gitea/models/auth" + "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/models/unittest" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewAccessToken(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + token := &auth_model.AccessToken{ + UID: 3, + Name: "Token C", + } + require.NoError(t, auth_model.NewAccessToken(db.DefaultContext, token)) + unittest.AssertExistsAndLoadBean(t, token) + + invalidToken := &auth_model.AccessToken{ + ID: token.ID, // duplicate + UID: 2, + Name: "Token F", + } + require.Error(t, auth_model.NewAccessToken(db.DefaultContext, invalidToken)) +} + +func TestAccessTokenByNameExists(t *testing.T) { + name := "Token Gitea" + + require.NoError(t, unittest.PrepareTestDatabase()) + token := &auth_model.AccessToken{ + UID: 3, + Name: name, + } + + // Check to make sure it doesn't exists already + exist, err := auth_model.AccessTokenByNameExists(db.DefaultContext, token) + require.NoError(t, err) + assert.False(t, exist) + + // Save it to the database + require.NoError(t, auth_model.NewAccessToken(db.DefaultContext, token)) + unittest.AssertExistsAndLoadBean(t, token) + + // This token must be found by name in the DB now + exist, err = auth_model.AccessTokenByNameExists(db.DefaultContext, token) + require.NoError(t, err) + assert.True(t, exist) + + user4Token := &auth_model.AccessToken{ + UID: 4, + Name: name, + } + + // Name matches but different user ID, this shouldn't exists in the + // database + exist, err = auth_model.AccessTokenByNameExists(db.DefaultContext, user4Token) + require.NoError(t, err) + assert.False(t, exist) +} + +func TestGetAccessTokenBySHA(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + token, err := auth_model.GetAccessTokenBySHA(db.DefaultContext, "d2c6c1ba3890b309189a8e618c72a162e4efbf36") + require.NoError(t, err) + assert.Equal(t, int64(1), token.UID) + assert.Equal(t, "Token A", token.Name) + assert.Equal(t, "2b3668e11cb82d3af8c6e4524fc7841297668f5008d1626f0ad3417e9fa39af84c268248b78c481daa7e5dc437784003494f", token.TokenHash) + assert.Equal(t, "e4efbf36", token.TokenLastEight) + + _, err = auth_model.GetAccessTokenBySHA(db.DefaultContext, "notahash") + require.Error(t, err) + assert.True(t, auth_model.IsErrAccessTokenNotExist(err)) + + _, err = auth_model.GetAccessTokenBySHA(db.DefaultContext, "") + require.Error(t, err) + assert.True(t, auth_model.IsErrAccessTokenEmpty(err)) +} + +func TestListAccessTokens(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + tokens, err := db.Find[auth_model.AccessToken](db.DefaultContext, auth_model.ListAccessTokensOptions{UserID: 1}) + require.NoError(t, err) + if assert.Len(t, tokens, 2) { + assert.Equal(t, int64(1), tokens[0].UID) + assert.Equal(t, int64(1), tokens[1].UID) + assert.Contains(t, []string{tokens[0].Name, tokens[1].Name}, "Token A") + assert.Contains(t, []string{tokens[0].Name, tokens[1].Name}, "Token B") + } + + tokens, err = db.Find[auth_model.AccessToken](db.DefaultContext, auth_model.ListAccessTokensOptions{UserID: 2}) + require.NoError(t, err) + if assert.Len(t, tokens, 1) { + assert.Equal(t, int64(2), tokens[0].UID) + assert.Equal(t, "Token A", tokens[0].Name) + } + + tokens, err = db.Find[auth_model.AccessToken](db.DefaultContext, auth_model.ListAccessTokensOptions{UserID: 100}) + require.NoError(t, err) + assert.Empty(t, tokens) +} + +func TestUpdateAccessToken(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + token, err := auth_model.GetAccessTokenBySHA(db.DefaultContext, "4c6f36e6cf498e2a448662f915d932c09c5a146c") + require.NoError(t, err) + token.Name = "Token Z" + + require.NoError(t, auth_model.UpdateAccessToken(db.DefaultContext, token)) + unittest.AssertExistsAndLoadBean(t, token) +} + +func TestDeleteAccessTokenByID(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + + token, err := auth_model.GetAccessTokenBySHA(db.DefaultContext, "4c6f36e6cf498e2a448662f915d932c09c5a146c") + require.NoError(t, err) + assert.Equal(t, int64(1), token.UID) + + require.NoError(t, auth_model.DeleteAccessTokenByID(db.DefaultContext, token.ID, 1)) + unittest.AssertNotExistsBean(t, token) + + err = auth_model.DeleteAccessTokenByID(db.DefaultContext, 100, 100) + require.Error(t, err) + assert.True(t, auth_model.IsErrAccessTokenNotExist(err)) +} diff --git a/models/auth/auth_token.go b/models/auth/auth_token.go new file mode 100644 index 0000000..2c3ca90 --- /dev/null +++ b/models/auth/auth_token.go @@ -0,0 +1,96 @@ +// Copyright 2023 The Forgejo Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package auth + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "time" + + "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/modules/timeutil" + "code.gitea.io/gitea/modules/util" +) + +// AuthorizationToken represents a authorization token to a user. +type AuthorizationToken struct { + ID int64 `xorm:"pk autoincr"` + UID int64 `xorm:"INDEX"` + LookupKey string `xorm:"INDEX UNIQUE"` + HashedValidator string + Expiry timeutil.TimeStamp +} + +// TableName provides the real table name. +func (AuthorizationToken) TableName() string { + return "forgejo_auth_token" +} + +func init() { + db.RegisterModel(new(AuthorizationToken)) +} + +// IsExpired returns if the authorization token is expired. +func (authToken *AuthorizationToken) IsExpired() bool { + return authToken.Expiry.AsLocalTime().Before(time.Now()) +} + +// GenerateAuthToken generates a new authentication token for the given user. +// It returns the lookup key and validator values that should be passed to the +// user via a long-term cookie. +func GenerateAuthToken(ctx context.Context, userID int64, expiry timeutil.TimeStamp) (lookupKey, validator string, err error) { + // Request 64 random bytes. The first 32 bytes will be used for the lookupKey + // and the other 32 bytes will be used for the validator. + rBytes, err := util.CryptoRandomBytes(64) + if err != nil { + return "", "", err + } + hexEncoded := hex.EncodeToString(rBytes) + validator, lookupKey = hexEncoded[64:], hexEncoded[:64] + + _, err = db.GetEngine(ctx).Insert(&AuthorizationToken{ + UID: userID, + Expiry: expiry, + LookupKey: lookupKey, + HashedValidator: HashValidator(rBytes[32:]), + }) + return lookupKey, validator, err +} + +// FindAuthToken will find a authorization token via the lookup key. +func FindAuthToken(ctx context.Context, lookupKey string) (*AuthorizationToken, error) { + var authToken AuthorizationToken + has, err := db.GetEngine(ctx).Where("lookup_key = ?", lookupKey).Get(&authToken) + if err != nil { + return nil, err + } else if !has { + return nil, fmt.Errorf("lookup key %q: %w", lookupKey, util.ErrNotExist) + } + return &authToken, nil +} + +// DeleteAuthToken will delete the authorization token. +func DeleteAuthToken(ctx context.Context, authToken *AuthorizationToken) error { + _, err := db.DeleteByBean(ctx, authToken) + return err +} + +// DeleteAuthTokenByUser will delete all authorization tokens for the user. +func DeleteAuthTokenByUser(ctx context.Context, userID int64) error { + if userID == 0 { + return nil + } + + _, err := db.DeleteByBean(ctx, &AuthorizationToken{UID: userID}) + return err +} + +// HashValidator will return a hexified hashed version of the validator. +func HashValidator(validator []byte) string { + h := sha256.New() + h.Write(validator) + return hex.EncodeToString(h.Sum(nil)) +} diff --git a/models/auth/main_test.go b/models/auth/main_test.go new file mode 100644 index 0000000..d772ea6 --- /dev/null +++ b/models/auth/main_test.go @@ -0,0 +1,20 @@ +// Copyright 2020 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package auth_test + +import ( + "testing" + + "code.gitea.io/gitea/models/unittest" + + _ "code.gitea.io/gitea/models" + _ "code.gitea.io/gitea/models/actions" + _ "code.gitea.io/gitea/models/activities" + _ "code.gitea.io/gitea/models/auth" + _ "code.gitea.io/gitea/models/perm/access" +) + +func TestMain(m *testing.M) { + unittest.MainTest(m) +} diff --git a/models/auth/oauth2.go b/models/auth/oauth2.go new file mode 100644 index 0000000..125d64b --- /dev/null +++ b/models/auth/oauth2.go @@ -0,0 +1,676 @@ +// Copyright 2019 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package auth + +import ( + "context" + "crypto/sha256" + "encoding/base32" + "encoding/base64" + "errors" + "fmt" + "net" + "net/url" + "strings" + + "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/modules/container" + "code.gitea.io/gitea/modules/setting" + "code.gitea.io/gitea/modules/timeutil" + "code.gitea.io/gitea/modules/util" + + uuid "github.com/google/uuid" + "golang.org/x/crypto/bcrypt" + "xorm.io/builder" + "xorm.io/xorm" +) + +// OAuth2Application represents an OAuth2 client (RFC 6749) +type OAuth2Application struct { + ID int64 `xorm:"pk autoincr"` + UID int64 `xorm:"INDEX"` + Name string + ClientID string `xorm:"unique"` + ClientSecret string + // OAuth defines both Confidential and Public client types + // https://datatracker.ietf.org/doc/html/rfc6749#section-2.1 + // "Authorization servers MUST record the client type in the client registration details" + // https://datatracker.ietf.org/doc/html/rfc8252#section-8.4 + ConfidentialClient bool `xorm:"NOT NULL DEFAULT TRUE"` + RedirectURIs []string `xorm:"redirect_uris JSON TEXT"` + CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"` + UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"` +} + +func init() { + db.RegisterModel(new(OAuth2Application)) + db.RegisterModel(new(OAuth2AuthorizationCode)) + db.RegisterModel(new(OAuth2Grant)) +} + +type BuiltinOAuth2Application struct { + ConfigName string + DisplayName string + RedirectURIs []string +} + +func BuiltinApplications() map[string]*BuiltinOAuth2Application { + m := make(map[string]*BuiltinOAuth2Application) + m["a4792ccc-144e-407e-86c9-5e7d8d9c3269"] = &BuiltinOAuth2Application{ + ConfigName: "git-credential-oauth", + DisplayName: "git-credential-oauth", + RedirectURIs: []string{"http://127.0.0.1", "https://127.0.0.1"}, + } + m["e90ee53c-94e2-48ac-9358-a874fb9e0662"] = &BuiltinOAuth2Application{ + ConfigName: "git-credential-manager", + DisplayName: "Git Credential Manager", + RedirectURIs: []string{"http://127.0.0.1", "https://127.0.0.1"}, + } + m["d57cb8c4-630c-4168-8324-ec79935e18d4"] = &BuiltinOAuth2Application{ + ConfigName: "tea", + DisplayName: "tea", + RedirectURIs: []string{"http://127.0.0.1", "https://127.0.0.1"}, + } + return m +} + +func BuiltinApplicationsClientIDs() (clientIDs []string) { + for clientID := range BuiltinApplications() { + clientIDs = append(clientIDs, clientID) + } + return clientIDs +} + +func Init(ctx context.Context) error { + builtinApps := BuiltinApplications() + var builtinAllClientIDs []string + for clientID := range builtinApps { + builtinAllClientIDs = append(builtinAllClientIDs, clientID) + } + + var registeredApps []*OAuth2Application + if err := db.GetEngine(ctx).In("client_id", builtinAllClientIDs).Find(®isteredApps); err != nil { + return err + } + + clientIDsToAdd := container.Set[string]{} + for _, configName := range setting.OAuth2.DefaultApplications { + found := false + for clientID, builtinApp := range builtinApps { + if builtinApp.ConfigName == configName { + clientIDsToAdd.Add(clientID) // add all user-configured apps to the "add" list + found = true + } + } + if !found { + return fmt.Errorf("unknown oauth2 application: %q", configName) + } + } + clientIDsToDelete := container.Set[string]{} + for _, app := range registeredApps { + if !clientIDsToAdd.Contains(app.ClientID) { + clientIDsToDelete.Add(app.ClientID) // if a registered app is not in the "add" list, it should be deleted + } + } + for _, app := range registeredApps { + clientIDsToAdd.Remove(app.ClientID) // no need to re-add existing (registered) apps, so remove them from the set + } + + for _, app := range registeredApps { + if clientIDsToDelete.Contains(app.ClientID) { + if err := deleteOAuth2Application(ctx, app.ID, 0); err != nil { + return err + } + } + } + for clientID := range clientIDsToAdd { + builtinApp := builtinApps[clientID] + if err := db.Insert(ctx, &OAuth2Application{ + Name: builtinApp.DisplayName, + ClientID: clientID, + RedirectURIs: builtinApp.RedirectURIs, + }); err != nil { + return err + } + } + + return nil +} + +// TableName sets the table name to `oauth2_application` +func (app *OAuth2Application) TableName() string { + return "oauth2_application" +} + +// ContainsRedirectURI checks if redirectURI is allowed for app +func (app *OAuth2Application) ContainsRedirectURI(redirectURI string) bool { + // OAuth2 requires the redirect URI to be an exact match, no dynamic parts are allowed. + // https://stackoverflow.com/questions/55524480/should-dynamic-query-parameters-be-present-in-the-redirection-uri-for-an-oauth2 + // https://www.rfc-editor.org/rfc/rfc6819#section-5.2.3.3 + // https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest + // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-security-topics-12#section-3.1 + contains := func(s string) bool { + s = strings.TrimSuffix(strings.ToLower(s), "/") + for _, u := range app.RedirectURIs { + if strings.TrimSuffix(strings.ToLower(u), "/") == s { + return true + } + } + return false + } + if !app.ConfidentialClient { + uri, err := url.Parse(redirectURI) + // ignore port for http loopback uris following https://datatracker.ietf.org/doc/html/rfc8252#section-7.3 + if err == nil && uri.Scheme == "http" && uri.Port() != "" { + ip := net.ParseIP(uri.Hostname()) + if ip != nil && ip.IsLoopback() { + // strip port + uri.Host = uri.Hostname() + if contains(uri.String()) { + return true + } + } + } + } + return contains(redirectURI) +} + +// Base32 characters, but lowercased. +const lowerBase32Chars = "abcdefghijklmnopqrstuvwxyz234567" + +// base32 encoder that uses lowered characters without padding. +var base32Lower = base32.NewEncoding(lowerBase32Chars).WithPadding(base32.NoPadding) + +// GenerateClientSecret will generate the client secret and returns the plaintext and saves the hash at the database +func (app *OAuth2Application) GenerateClientSecret(ctx context.Context) (string, error) { + rBytes, err := util.CryptoRandomBytes(32) + if err != nil { + return "", err + } + // Add a prefix to the base32, this is in order to make it easier + // for code scanners to grab sensitive tokens. + clientSecret := "gto_" + base32Lower.EncodeToString(rBytes) + + hashedSecret, err := bcrypt.GenerateFromPassword([]byte(clientSecret), bcrypt.DefaultCost) + if err != nil { + return "", err + } + app.ClientSecret = string(hashedSecret) + if _, err := db.GetEngine(ctx).ID(app.ID).Cols("client_secret").Update(app); err != nil { + return "", err + } + return clientSecret, nil +} + +// ValidateClientSecret validates the given secret by the hash saved in database +func (app *OAuth2Application) ValidateClientSecret(secret []byte) bool { + return bcrypt.CompareHashAndPassword([]byte(app.ClientSecret), secret) == nil +} + +// GetGrantByUserID returns a OAuth2Grant by its user and application ID +func (app *OAuth2Application) GetGrantByUserID(ctx context.Context, userID int64) (grant *OAuth2Grant, err error) { + grant = new(OAuth2Grant) + if has, err := db.GetEngine(ctx).Where("user_id = ? AND application_id = ?", userID, app.ID).Get(grant); err != nil { + return nil, err + } else if !has { + return nil, nil + } + return grant, nil +} + +// CreateGrant generates a grant for an user +func (app *OAuth2Application) CreateGrant(ctx context.Context, userID int64, scope string) (*OAuth2Grant, error) { + grant := &OAuth2Grant{ + ApplicationID: app.ID, + UserID: userID, + Scope: scope, + } + err := db.Insert(ctx, grant) + if err != nil { + return nil, err + } + return grant, nil +} + +// GetOAuth2ApplicationByClientID returns the oauth2 application with the given client_id. Returns an error if not found. +func GetOAuth2ApplicationByClientID(ctx context.Context, clientID string) (app *OAuth2Application, err error) { + app = new(OAuth2Application) + has, err := db.GetEngine(ctx).Where("client_id = ?", clientID).Get(app) + if !has { + return nil, ErrOAuthClientIDInvalid{ClientID: clientID} + } + return app, err +} + +// GetOAuth2ApplicationByID returns the oauth2 application with the given id. Returns an error if not found. +func GetOAuth2ApplicationByID(ctx context.Context, id int64) (app *OAuth2Application, err error) { + app = new(OAuth2Application) + has, err := db.GetEngine(ctx).ID(id).Get(app) + if err != nil { + return nil, err + } + if !has { + return nil, ErrOAuthApplicationNotFound{ID: id} + } + return app, nil +} + +// CreateOAuth2ApplicationOptions holds options to create an oauth2 application +type CreateOAuth2ApplicationOptions struct { + Name string + UserID int64 + ConfidentialClient bool + RedirectURIs []string +} + +// CreateOAuth2Application inserts a new oauth2 application +func CreateOAuth2Application(ctx context.Context, opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) { + clientID := uuid.New().String() + app := &OAuth2Application{ + UID: opts.UserID, + Name: opts.Name, + ClientID: clientID, + RedirectURIs: opts.RedirectURIs, + ConfidentialClient: opts.ConfidentialClient, + } + if err := db.Insert(ctx, app); err != nil { + return nil, err + } + return app, nil +} + +// UpdateOAuth2ApplicationOptions holds options to update an oauth2 application +type UpdateOAuth2ApplicationOptions struct { + ID int64 + Name string + UserID int64 + ConfidentialClient bool + RedirectURIs []string +} + +// UpdateOAuth2Application updates an oauth2 application +func UpdateOAuth2Application(ctx context.Context, opts UpdateOAuth2ApplicationOptions) (*OAuth2Application, error) { + ctx, committer, err := db.TxContext(ctx) + if err != nil { + return nil, err + } + defer committer.Close() + + app, err := GetOAuth2ApplicationByID(ctx, opts.ID) + if err != nil { + return nil, err + } + if app.UID != opts.UserID { + return nil, errors.New("UID mismatch") + } + builtinApps := BuiltinApplications() + if _, builtin := builtinApps[app.ClientID]; builtin { + return nil, fmt.Errorf("failed to edit OAuth2 application: application is locked: %s", app.ClientID) + } + + app.Name = opts.Name + app.RedirectURIs = opts.RedirectURIs + app.ConfidentialClient = opts.ConfidentialClient + + if err = updateOAuth2Application(ctx, app); err != nil { + return nil, err + } + app.ClientSecret = "" + + return app, committer.Commit() +} + +func updateOAuth2Application(ctx context.Context, app *OAuth2Application) error { + if _, err := db.GetEngine(ctx).ID(app.ID).UseBool("confidential_client").Update(app); err != nil { + return err + } + return nil +} + +func deleteOAuth2Application(ctx context.Context, id, userid int64) error { + sess := db.GetEngine(ctx) + // the userid could be 0 if the app is instance-wide + if deleted, err := sess.Where(builder.Eq{"id": id, "uid": userid}).Delete(&OAuth2Application{}); err != nil { + return err + } else if deleted == 0 { + return ErrOAuthApplicationNotFound{ID: id} + } + codes := make([]*OAuth2AuthorizationCode, 0) + // delete correlating auth codes + if err := sess.Join("INNER", "oauth2_grant", + "oauth2_authorization_code.grant_id = oauth2_grant.id AND oauth2_grant.application_id = ?", id).Find(&codes); err != nil { + return err + } + codeIDs := make([]int64, 0, len(codes)) + for _, grant := range codes { + codeIDs = append(codeIDs, grant.ID) + } + + if _, err := sess.In("id", codeIDs).Delete(new(OAuth2AuthorizationCode)); err != nil { + return err + } + + if _, err := sess.Where("application_id = ?", id).Delete(new(OAuth2Grant)); err != nil { + return err + } + return nil +} + +// DeleteOAuth2Application deletes the application with the given id and the grants and auth codes related to it. It checks if the userid was the creator of the app. +func DeleteOAuth2Application(ctx context.Context, id, userid int64) error { + ctx, committer, err := db.TxContext(ctx) + if err != nil { + return err + } + defer committer.Close() + app, err := GetOAuth2ApplicationByID(ctx, id) + if err != nil { + return err + } + builtinApps := BuiltinApplications() + if _, builtin := builtinApps[app.ClientID]; builtin { + return fmt.Errorf("failed to delete OAuth2 application: application is locked: %s", app.ClientID) + } + if err := deleteOAuth2Application(ctx, id, userid); err != nil { + return err + } + return committer.Commit() +} + +////////////////////////////////////////////////////// + +// OAuth2AuthorizationCode is a code to obtain an access token in combination with the client secret once. It has a limited lifetime. +type OAuth2AuthorizationCode struct { + ID int64 `xorm:"pk autoincr"` + Grant *OAuth2Grant `xorm:"-"` + GrantID int64 + Code string `xorm:"INDEX unique"` + CodeChallenge string + CodeChallengeMethod string + RedirectURI string + ValidUntil timeutil.TimeStamp `xorm:"index"` +} + +// TableName sets the table name to `oauth2_authorization_code` +func (code *OAuth2AuthorizationCode) TableName() string { + return "oauth2_authorization_code" +} + +// GenerateRedirectURI generates a redirect URI for a successful authorization request. State will be used if not empty. +func (code *OAuth2AuthorizationCode) GenerateRedirectURI(state string) (*url.URL, error) { + redirect, err := url.Parse(code.RedirectURI) + if err != nil { + return nil, err + } + q := redirect.Query() + if state != "" { + q.Set("state", state) + } + q.Set("code", code.Code) + redirect.RawQuery = q.Encode() + return redirect, err +} + +// Invalidate deletes the auth code from the database to invalidate this code +func (code *OAuth2AuthorizationCode) Invalidate(ctx context.Context) error { + _, err := db.GetEngine(ctx).ID(code.ID).NoAutoCondition().Delete(code) + return err +} + +// ValidateCodeChallenge validates the given verifier against the saved code challenge. This is part of the PKCE implementation. +func (code *OAuth2AuthorizationCode) ValidateCodeChallenge(verifier string) bool { + switch code.CodeChallengeMethod { + case "S256": + // base64url(SHA256(verifier)) see https://tools.ietf.org/html/rfc7636#section-4.6 + h := sha256.Sum256([]byte(verifier)) + hashedVerifier := base64.RawURLEncoding.EncodeToString(h[:]) + return hashedVerifier == code.CodeChallenge + case "plain": + return verifier == code.CodeChallenge + case "": + return true + default: + // unsupported method -> return false + return false + } +} + +// GetOAuth2AuthorizationByCode returns an authorization by its code +func GetOAuth2AuthorizationByCode(ctx context.Context, code string) (auth *OAuth2AuthorizationCode, err error) { + auth = new(OAuth2AuthorizationCode) + if has, err := db.GetEngine(ctx).Where("code = ?", code).Get(auth); err != nil { + return nil, err + } else if !has { + return nil, nil + } + auth.Grant = new(OAuth2Grant) + if has, err := db.GetEngine(ctx).ID(auth.GrantID).Get(auth.Grant); err != nil { + return nil, err + } else if !has { + return nil, nil + } + return auth, nil +} + +////////////////////////////////////////////////////// + +// OAuth2Grant represents the permission of an user for a specific application to access resources +type OAuth2Grant struct { + ID int64 `xorm:"pk autoincr"` + UserID int64 `xorm:"INDEX unique(user_application)"` + Application *OAuth2Application `xorm:"-"` + ApplicationID int64 `xorm:"INDEX unique(user_application)"` + Counter int64 `xorm:"NOT NULL DEFAULT 1"` + Scope string `xorm:"TEXT"` + Nonce string `xorm:"TEXT"` + CreatedUnix timeutil.TimeStamp `xorm:"created"` + UpdatedUnix timeutil.TimeStamp `xorm:"updated"` +} + +// TableName sets the table name to `oauth2_grant` +func (grant *OAuth2Grant) TableName() string { + return "oauth2_grant" +} + +// GenerateNewAuthorizationCode generates a new authorization code for a grant and saves it to the database +func (grant *OAuth2Grant) GenerateNewAuthorizationCode(ctx context.Context, redirectURI, codeChallenge, codeChallengeMethod string) (code *OAuth2AuthorizationCode, err error) { + rBytes, err := util.CryptoRandomBytes(32) + if err != nil { + return &OAuth2AuthorizationCode{}, err + } + // Add a prefix to the base32, this is in order to make it easier + // for code scanners to grab sensitive tokens. + codeSecret := "gta_" + base32Lower.EncodeToString(rBytes) + + code = &OAuth2AuthorizationCode{ + Grant: grant, + GrantID: grant.ID, + RedirectURI: redirectURI, + Code: codeSecret, + CodeChallenge: codeChallenge, + CodeChallengeMethod: codeChallengeMethod, + } + if err := db.Insert(ctx, code); err != nil { + return nil, err + } + return code, nil +} + +// IncreaseCounter increases the counter and updates the grant +func (grant *OAuth2Grant) IncreaseCounter(ctx context.Context) error { + _, err := db.GetEngine(ctx).ID(grant.ID).Incr("counter").Update(new(OAuth2Grant)) + if err != nil { + return err + } + updatedGrant, err := GetOAuth2GrantByID(ctx, grant.ID) + if err != nil { + return err + } + grant.Counter = updatedGrant.Counter + return nil +} + +// ScopeContains returns true if the grant scope contains the specified scope +func (grant *OAuth2Grant) ScopeContains(scope string) bool { + for _, currentScope := range strings.Split(grant.Scope, " ") { + if scope == currentScope { + return true + } + } + return false +} + +// SetNonce updates the current nonce value of a grant +func (grant *OAuth2Grant) SetNonce(ctx context.Context, nonce string) error { + grant.Nonce = nonce + _, err := db.GetEngine(ctx).ID(grant.ID).Cols("nonce").Update(grant) + if err != nil { + return err + } + return nil +} + +// GetOAuth2GrantByID returns the grant with the given ID +func GetOAuth2GrantByID(ctx context.Context, id int64) (grant *OAuth2Grant, err error) { + grant = new(OAuth2Grant) + if has, err := db.GetEngine(ctx).ID(id).Get(grant); err != nil { + return nil, err + } else if !has { + return nil, nil + } + return grant, err +} + +// GetOAuth2GrantsByUserID lists all grants of a certain user +func GetOAuth2GrantsByUserID(ctx context.Context, uid int64) ([]*OAuth2Grant, error) { + type joinedOAuth2Grant struct { + Grant *OAuth2Grant `xorm:"extends"` + Application *OAuth2Application `xorm:"extends"` + } + var results *xorm.Rows + var err error + if results, err = db.GetEngine(ctx). + Table("oauth2_grant"). + Where("user_id = ?", uid). + Join("INNER", "oauth2_application", "application_id = oauth2_application.id"). + Rows(new(joinedOAuth2Grant)); err != nil { + return nil, err + } + defer results.Close() + grants := make([]*OAuth2Grant, 0) + for results.Next() { + joinedGrant := new(joinedOAuth2Grant) + if err := results.Scan(joinedGrant); err != nil { + return nil, err + } + joinedGrant.Grant.Application = joinedGrant.Application + grants = append(grants, joinedGrant.Grant) + } + return grants, nil +} + +// RevokeOAuth2Grant deletes the grant with grantID and userID +func RevokeOAuth2Grant(ctx context.Context, grantID, userID int64) error { + _, err := db.GetEngine(ctx).Where(builder.Eq{"id": grantID, "user_id": userID}).Delete(&OAuth2Grant{}) + return err +} + +// ErrOAuthClientIDInvalid will be thrown if client id cannot be found +type ErrOAuthClientIDInvalid struct { + ClientID string +} + +// IsErrOauthClientIDInvalid checks if an error is a ErrOAuthClientIDInvalid. +func IsErrOauthClientIDInvalid(err error) bool { + _, ok := err.(ErrOAuthClientIDInvalid) + return ok +} + +// Error returns the error message +func (err ErrOAuthClientIDInvalid) Error() string { + return fmt.Sprintf("Client ID invalid [Client ID: %s]", err.ClientID) +} + +// Unwrap unwraps this as a ErrNotExist err +func (err ErrOAuthClientIDInvalid) Unwrap() error { + return util.ErrNotExist +} + +// ErrOAuthApplicationNotFound will be thrown if id cannot be found +type ErrOAuthApplicationNotFound struct { + ID int64 +} + +// IsErrOAuthApplicationNotFound checks if an error is a ErrReviewNotExist. +func IsErrOAuthApplicationNotFound(err error) bool { + _, ok := err.(ErrOAuthApplicationNotFound) + return ok +} + +// Error returns the error message +func (err ErrOAuthApplicationNotFound) Error() string { + return fmt.Sprintf("OAuth application not found [ID: %d]", err.ID) +} + +// Unwrap unwraps this as a ErrNotExist err +func (err ErrOAuthApplicationNotFound) Unwrap() error { + return util.ErrNotExist +} + +// GetActiveOAuth2SourceByName returns a OAuth2 AuthSource based on the given name +func GetActiveOAuth2SourceByName(ctx context.Context, name string) (*Source, error) { + authSource := new(Source) + has, err := db.GetEngine(ctx).Where("name = ? and type = ? and is_active = ?", name, OAuth2, true).Get(authSource) + if err != nil { + return nil, err + } + + if !has { + return nil, fmt.Errorf("oauth2 source not found, name: %q", name) + } + + return authSource, nil +} + +func DeleteOAuth2RelictsByUserID(ctx context.Context, userID int64) error { + deleteCond := builder.Select("id").From("oauth2_grant").Where(builder.Eq{"oauth2_grant.user_id": userID}) + + if _, err := db.GetEngine(ctx).In("grant_id", deleteCond). + Delete(&OAuth2AuthorizationCode{}); err != nil { + return err + } + + if err := db.DeleteBeans(ctx, + &OAuth2Application{UID: userID}, + &OAuth2Grant{UserID: userID}, + ); err != nil { + return fmt.Errorf("DeleteBeans: %w", err) + } + + return nil +} + +// CountOrphanedOAuth2Applications returns the amount of orphaned OAuth2 applications. +func CountOrphanedOAuth2Applications(ctx context.Context) (int64, error) { + return db.GetEngine(ctx). + Table("`oauth2_application`"). + Join("LEFT", "`user`", "`oauth2_application`.`uid` = `user`.`id`"). + Where(builder.IsNull{"`user`.id"}). + Where(builder.NotIn("`oauth2_application`.`client_id`", BuiltinApplicationsClientIDs())). + Select("COUNT(`oauth2_application`.`id`)"). + Count() +} + +// DeleteOrphanedOAuth2Applications deletes orphaned OAuth2 applications. +func DeleteOrphanedOAuth2Applications(ctx context.Context) (int64, error) { + subQuery := builder.Select("`oauth2_application`.id"). + From("`oauth2_application`"). + Join("LEFT", "`user`", "`oauth2_application`.`uid` = `user`.`id`"). + Where(builder.IsNull{"`user`.id"}). + Where(builder.NotIn("`oauth2_application`.`client_id`", BuiltinApplicationsClientIDs())) + + b := builder.Delete(builder.In("id", subQuery)).From("`oauth2_application`") + _, err := db.GetEngine(ctx).Exec(b) + return -1, err +} diff --git a/models/auth/oauth2_list.go b/models/auth/oauth2_list.go new file mode 100644 index 0000000..c55f10b --- /dev/null +++ b/models/auth/oauth2_list.go @@ -0,0 +1,32 @@ +// Copyright 2023 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package auth + +import ( + "code.gitea.io/gitea/models/db" + + "xorm.io/builder" +) + +type FindOAuth2ApplicationsOptions struct { + db.ListOptions + // OwnerID is the user id or org id of the owner of the application + OwnerID int64 + // find global applications, if true, then OwnerID will be igonred + IsGlobal bool +} + +func (opts FindOAuth2ApplicationsOptions) ToConds() builder.Cond { + conds := builder.NewCond() + if opts.IsGlobal { + conds = conds.And(builder.Eq{"uid": 0}) + } else if opts.OwnerID != 0 { + conds = conds.And(builder.Eq{"uid": opts.OwnerID}) + } + return conds +} + +func (opts FindOAuth2ApplicationsOptions) ToOrders() string { + return "id DESC" +} diff --git a/models/auth/oauth2_test.go b/models/auth/oauth2_test.go new file mode 100644 index 0000000..94b506e --- /dev/null +++ b/models/auth/oauth2_test.go @@ -0,0 +1,299 @@ +// Copyright 2019 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package auth_test + +import ( + "path/filepath" + "slices" + "testing" + + auth_model "code.gitea.io/gitea/models/auth" + "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/models/unittest" + "code.gitea.io/gitea/modules/setting" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestOAuth2Application_GenerateClientSecret(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + app := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{ID: 1}) + secret, err := app.GenerateClientSecret(db.DefaultContext) + require.NoError(t, err) + assert.NotEmpty(t, secret) + unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{ID: 1, ClientSecret: app.ClientSecret}) +} + +func BenchmarkOAuth2Application_GenerateClientSecret(b *testing.B) { + require.NoError(b, unittest.PrepareTestDatabase()) + app := unittest.AssertExistsAndLoadBean(b, &auth_model.OAuth2Application{ID: 1}) + for i := 0; i < b.N; i++ { + _, _ = app.GenerateClientSecret(db.DefaultContext) + } +} + +func TestOAuth2Application_ContainsRedirectURI(t *testing.T) { + app := &auth_model.OAuth2Application{ + RedirectURIs: []string{"a", "b", "c"}, + } + assert.True(t, app.ContainsRedirectURI("a")) + assert.True(t, app.ContainsRedirectURI("b")) + assert.True(t, app.ContainsRedirectURI("c")) + assert.False(t, app.ContainsRedirectURI("d")) +} + +func TestOAuth2Application_ContainsRedirectURI_WithPort(t *testing.T) { + app := &auth_model.OAuth2Application{ + RedirectURIs: []string{"http://127.0.0.1/", "http://::1/", "http://192.168.0.1/", "http://intranet/", "https://127.0.0.1/"}, + ConfidentialClient: false, + } + + // http loopback uris should ignore port + // https://datatracker.ietf.org/doc/html/rfc8252#section-7.3 + assert.True(t, app.ContainsRedirectURI("http://127.0.0.1:3456/")) + assert.True(t, app.ContainsRedirectURI("http://127.0.0.1/")) + assert.True(t, app.ContainsRedirectURI("http://[::1]:3456/")) + + // not http + assert.False(t, app.ContainsRedirectURI("https://127.0.0.1:3456/")) + // not loopback + assert.False(t, app.ContainsRedirectURI("http://192.168.0.1:9954/")) + assert.False(t, app.ContainsRedirectURI("http://intranet:3456/")) + // unparsable + assert.False(t, app.ContainsRedirectURI(":")) +} + +func TestOAuth2Application_ContainsRedirect_Slash(t *testing.T) { + app := &auth_model.OAuth2Application{RedirectURIs: []string{"http://127.0.0.1"}} + assert.True(t, app.ContainsRedirectURI("http://127.0.0.1")) + assert.True(t, app.ContainsRedirectURI("http://127.0.0.1/")) + assert.False(t, app.ContainsRedirectURI("http://127.0.0.1/other")) + + app = &auth_model.OAuth2Application{RedirectURIs: []string{"http://127.0.0.1/"}} + assert.True(t, app.ContainsRedirectURI("http://127.0.0.1")) + assert.True(t, app.ContainsRedirectURI("http://127.0.0.1/")) + assert.False(t, app.ContainsRedirectURI("http://127.0.0.1/other")) +} + +func TestOAuth2Application_ValidateClientSecret(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + app := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{ID: 1}) + secret, err := app.GenerateClientSecret(db.DefaultContext) + require.NoError(t, err) + assert.True(t, app.ValidateClientSecret([]byte(secret))) + assert.False(t, app.ValidateClientSecret([]byte("fewijfowejgfiowjeoifew"))) +} + +func TestGetOAuth2ApplicationByClientID(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + app, err := auth_model.GetOAuth2ApplicationByClientID(db.DefaultContext, "da7da3ba-9a13-4167-856f-3899de0b0138") + require.NoError(t, err) + assert.Equal(t, "da7da3ba-9a13-4167-856f-3899de0b0138", app.ClientID) + + app, err = auth_model.GetOAuth2ApplicationByClientID(db.DefaultContext, "invalid client id") + require.Error(t, err) + assert.Nil(t, app) +} + +func TestCreateOAuth2Application(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + app, err := auth_model.CreateOAuth2Application(db.DefaultContext, auth_model.CreateOAuth2ApplicationOptions{Name: "newapp", UserID: 1}) + require.NoError(t, err) + assert.Equal(t, "newapp", app.Name) + assert.Len(t, app.ClientID, 36) + unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{Name: "newapp"}) +} + +func TestOAuth2Application_TableName(t *testing.T) { + assert.Equal(t, "oauth2_application", new(auth_model.OAuth2Application).TableName()) +} + +func TestOAuth2Application_GetGrantByUserID(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + app := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{ID: 1}) + grant, err := app.GetGrantByUserID(db.DefaultContext, 1) + require.NoError(t, err) + assert.Equal(t, int64(1), grant.UserID) + + grant, err = app.GetGrantByUserID(db.DefaultContext, 34923458) + require.NoError(t, err) + assert.Nil(t, grant) +} + +func TestOAuth2Application_CreateGrant(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + app := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{ID: 1}) + grant, err := app.CreateGrant(db.DefaultContext, 2, "") + require.NoError(t, err) + assert.NotNil(t, grant) + assert.Equal(t, int64(2), grant.UserID) + assert.Equal(t, int64(1), grant.ApplicationID) + assert.Equal(t, "", grant.Scope) +} + +//////////////////// Grant + +func TestGetOAuth2GrantByID(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + grant, err := auth_model.GetOAuth2GrantByID(db.DefaultContext, 1) + require.NoError(t, err) + assert.Equal(t, int64(1), grant.ID) + + grant, err = auth_model.GetOAuth2GrantByID(db.DefaultContext, 34923458) + require.NoError(t, err) + assert.Nil(t, grant) +} + +func TestOAuth2Grant_IncreaseCounter(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + grant := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Grant{ID: 1, Counter: 1}) + require.NoError(t, grant.IncreaseCounter(db.DefaultContext)) + assert.Equal(t, int64(2), grant.Counter) + unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Grant{ID: 1, Counter: 2}) +} + +func TestOAuth2Grant_ScopeContains(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + grant := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Grant{ID: 1, Scope: "openid profile"}) + assert.True(t, grant.ScopeContains("openid")) + assert.True(t, grant.ScopeContains("profile")) + assert.False(t, grant.ScopeContains("profil")) + assert.False(t, grant.ScopeContains("profile2")) +} + +func TestOAuth2Grant_GenerateNewAuthorizationCode(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + grant := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Grant{ID: 1}) + code, err := grant.GenerateNewAuthorizationCode(db.DefaultContext, "https://example2.com/callback", "CjvyTLSdR47G5zYenDA-eDWW4lRrO8yvjcWwbD_deOg", "S256") + require.NoError(t, err) + assert.NotNil(t, code) + assert.Greater(t, len(code.Code), 32) // secret length > 32 +} + +func TestOAuth2Grant_TableName(t *testing.T) { + assert.Equal(t, "oauth2_grant", new(auth_model.OAuth2Grant).TableName()) +} + +func TestGetOAuth2GrantsByUserID(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + result, err := auth_model.GetOAuth2GrantsByUserID(db.DefaultContext, 1) + require.NoError(t, err) + assert.Len(t, result, 1) + assert.Equal(t, int64(1), result[0].ID) + assert.Equal(t, result[0].ApplicationID, result[0].Application.ID) + + result, err = auth_model.GetOAuth2GrantsByUserID(db.DefaultContext, 34134) + require.NoError(t, err) + assert.Empty(t, result) +} + +func TestRevokeOAuth2Grant(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + require.NoError(t, auth_model.RevokeOAuth2Grant(db.DefaultContext, 1, 1)) + unittest.AssertNotExistsBean(t, &auth_model.OAuth2Grant{ID: 1, UserID: 1}) +} + +//////////////////// Authorization Code + +func TestGetOAuth2AuthorizationByCode(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + code, err := auth_model.GetOAuth2AuthorizationByCode(db.DefaultContext, "authcode") + require.NoError(t, err) + assert.NotNil(t, code) + assert.Equal(t, "authcode", code.Code) + assert.Equal(t, int64(1), code.ID) + + code, err = auth_model.GetOAuth2AuthorizationByCode(db.DefaultContext, "does not exist") + require.NoError(t, err) + assert.Nil(t, code) +} + +func TestOAuth2AuthorizationCode_ValidateCodeChallenge(t *testing.T) { + // test plain + code := &auth_model.OAuth2AuthorizationCode{ + CodeChallengeMethod: "plain", + CodeChallenge: "test123", + } + assert.True(t, code.ValidateCodeChallenge("test123")) + assert.False(t, code.ValidateCodeChallenge("ierwgjoergjio")) + + // test S256 + code = &auth_model.OAuth2AuthorizationCode{ + CodeChallengeMethod: "S256", + CodeChallenge: "CjvyTLSdR47G5zYenDA-eDWW4lRrO8yvjcWwbD_deOg", + } + assert.True(t, code.ValidateCodeChallenge("N1Zo9-8Rfwhkt68r1r29ty8YwIraXR8eh_1Qwxg7yQXsonBt")) + assert.False(t, code.ValidateCodeChallenge("wiogjerogorewngoenrgoiuenorg")) + + // test unknown + code = &auth_model.OAuth2AuthorizationCode{ + CodeChallengeMethod: "monkey", + CodeChallenge: "foiwgjioriogeiogjerger", + } + assert.False(t, code.ValidateCodeChallenge("foiwgjioriogeiogjerger")) + + // test no code challenge + code = &auth_model.OAuth2AuthorizationCode{ + CodeChallengeMethod: "", + CodeChallenge: "foierjiogerogerg", + } + assert.True(t, code.ValidateCodeChallenge("")) +} + +func TestOAuth2AuthorizationCode_GenerateRedirectURI(t *testing.T) { + code := &auth_model.OAuth2AuthorizationCode{ + RedirectURI: "https://example.com/callback", + Code: "thecode", + } + + redirect, err := code.GenerateRedirectURI("thestate") + require.NoError(t, err) + assert.Equal(t, "https://example.com/callback?code=thecode&state=thestate", redirect.String()) + + redirect, err = code.GenerateRedirectURI("") + require.NoError(t, err) + assert.Equal(t, "https://example.com/callback?code=thecode", redirect.String()) +} + +func TestOAuth2AuthorizationCode_Invalidate(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + code := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2AuthorizationCode{Code: "authcode"}) + require.NoError(t, code.Invalidate(db.DefaultContext)) + unittest.AssertNotExistsBean(t, &auth_model.OAuth2AuthorizationCode{Code: "authcode"}) +} + +func TestOAuth2AuthorizationCode_TableName(t *testing.T) { + assert.Equal(t, "oauth2_authorization_code", new(auth_model.OAuth2AuthorizationCode).TableName()) +} + +func TestBuiltinApplicationsClientIDs(t *testing.T) { + clientIDs := auth_model.BuiltinApplicationsClientIDs() + slices.Sort(clientIDs) + assert.EqualValues(t, []string{"a4792ccc-144e-407e-86c9-5e7d8d9c3269", "d57cb8c4-630c-4168-8324-ec79935e18d4", "e90ee53c-94e2-48ac-9358-a874fb9e0662"}, clientIDs) +} + +func TestOrphanedOAuth2Applications(t *testing.T) { + defer unittest.OverrideFixtures( + unittest.FixturesOptions{ + Dir: filepath.Join(setting.AppWorkPath, "models/fixtures/"), + Base: setting.AppWorkPath, + Dirs: []string{"models/auth/TestOrphanedOAuth2Applications/"}, + }, + )() + require.NoError(t, unittest.PrepareTestDatabase()) + + count, err := auth_model.CountOrphanedOAuth2Applications(db.DefaultContext) + require.NoError(t, err) + assert.EqualValues(t, 1, count) + unittest.AssertExistsIf(t, true, &auth_model.OAuth2Application{ID: 1002}) + + _, err = auth_model.DeleteOrphanedOAuth2Applications(db.DefaultContext) + require.NoError(t, err) + + count, err = auth_model.CountOrphanedOAuth2Applications(db.DefaultContext) + require.NoError(t, err) + assert.EqualValues(t, 0, count) + unittest.AssertExistsIf(t, false, &auth_model.OAuth2Application{ID: 1002}) +} diff --git a/models/auth/session.go b/models/auth/session.go new file mode 100644 index 0000000..75a205f --- /dev/null +++ b/models/auth/session.go @@ -0,0 +1,120 @@ +// Copyright 2020 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package auth + +import ( + "context" + "fmt" + + "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/modules/timeutil" + + "xorm.io/builder" +) + +// Session represents a session compatible for go-chi session +type Session struct { + Key string `xorm:"pk CHAR(16)"` // has to be Key to match with go-chi/session + Data []byte `xorm:"BLOB"` // on MySQL this has a maximum size of 64Kb - this may need to be increased + Expiry timeutil.TimeStamp // has to be Expiry to match with go-chi/session +} + +func init() { + db.RegisterModel(new(Session)) +} + +// UpdateSession updates the session with provided id +func UpdateSession(ctx context.Context, key string, data []byte) error { + _, err := db.GetEngine(ctx).ID(key).Update(&Session{ + Data: data, + Expiry: timeutil.TimeStampNow(), + }) + return err +} + +// ReadSession reads the data for the provided session +func ReadSession(ctx context.Context, key string) (*Session, error) { + ctx, committer, err := db.TxContext(ctx) + if err != nil { + return nil, err + } + defer committer.Close() + + session, exist, err := db.Get[Session](ctx, builder.Eq{"`key`": key}) + if err != nil { + return nil, err + } else if !exist { + session = &Session{ + Key: key, + Expiry: timeutil.TimeStampNow(), + } + if err := db.Insert(ctx, session); err != nil { + return nil, err + } + } + + return session, committer.Commit() +} + +// ExistSession checks if a session exists +func ExistSession(ctx context.Context, key string) (bool, error) { + return db.Exist[Session](ctx, builder.Eq{"`key`": key}) +} + +// DestroySession destroys a session +func DestroySession(ctx context.Context, key string) error { + _, err := db.GetEngine(ctx).Delete(&Session{ + Key: key, + }) + return err +} + +// RegenerateSession regenerates a session from the old id +func RegenerateSession(ctx context.Context, oldKey, newKey string) (*Session, error) { + ctx, committer, err := db.TxContext(ctx) + if err != nil { + return nil, err + } + defer committer.Close() + + if has, err := db.Exist[Session](ctx, builder.Eq{"`key`": newKey}); err != nil { + return nil, err + } else if has { + return nil, fmt.Errorf("session Key: %s already exists", newKey) + } + + if has, err := db.Exist[Session](ctx, builder.Eq{"`key`": oldKey}); err != nil { + return nil, err + } else if !has { + if err := db.Insert(ctx, &Session{ + Key: oldKey, + Expiry: timeutil.TimeStampNow(), + }); err != nil { + return nil, err + } + } + + if _, err := db.Exec(ctx, "UPDATE "+db.TableName(&Session{})+" SET `key` = ? WHERE `key`=?", newKey, oldKey); err != nil { + return nil, err + } + + s, _, err := db.Get[Session](ctx, builder.Eq{"`key`": newKey}) + if err != nil { + // is not exist, it should be impossible + return nil, err + } + + return s, committer.Commit() +} + +// CountSessions returns the number of sessions +func CountSessions(ctx context.Context) (int64, error) { + return db.GetEngine(ctx).Count(&Session{}) +} + +// CleanupSessions cleans up expired sessions +func CleanupSessions(ctx context.Context, maxLifetime int64) error { + _, err := db.GetEngine(ctx).Where("expiry <= ?", timeutil.TimeStampNow().Add(-maxLifetime)).Delete(&Session{}) + return err +} diff --git a/models/auth/session_test.go b/models/auth/session_test.go new file mode 100644 index 0000000..3b57239 --- /dev/null +++ b/models/auth/session_test.go @@ -0,0 +1,143 @@ +// Copyright 2023 The Forgejo Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package auth_test + +import ( + "testing" + "time" + + "code.gitea.io/gitea/models/auth" + "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/models/unittest" + "code.gitea.io/gitea/modules/timeutil" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAuthSession(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + defer timeutil.MockUnset() + + key := "I-Like-Free-Software" + + t.Run("Create Session", func(t *testing.T) { + // Ensure it doesn't exist. + ok, err := auth.ExistSession(db.DefaultContext, key) + require.NoError(t, err) + assert.False(t, ok) + + preCount, err := auth.CountSessions(db.DefaultContext) + require.NoError(t, err) + + now := time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC) + timeutil.MockSet(now) + + // New session is created. + sess, err := auth.ReadSession(db.DefaultContext, key) + require.NoError(t, err) + assert.EqualValues(t, key, sess.Key) + assert.Empty(t, sess.Data) + assert.EqualValues(t, now.Unix(), sess.Expiry) + + // Ensure it exists. + ok, err = auth.ExistSession(db.DefaultContext, key) + require.NoError(t, err) + assert.True(t, ok) + + // Ensure the session is taken into account for count.. + postCount, err := auth.CountSessions(db.DefaultContext) + require.NoError(t, err) + assert.Greater(t, postCount, preCount) + }) + + t.Run("Update session", func(t *testing.T) { + data := []byte{0xba, 0xdd, 0xc0, 0xde} + now := time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC) + timeutil.MockSet(now) + + // Update session. + err := auth.UpdateSession(db.DefaultContext, key, data) + require.NoError(t, err) + + timeutil.MockSet(time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)) + + // Read updated session. + // Ensure data is updated and expiry is set from the update session call. + sess, err := auth.ReadSession(db.DefaultContext, key) + require.NoError(t, err) + assert.EqualValues(t, key, sess.Key) + assert.EqualValues(t, data, sess.Data) + assert.EqualValues(t, now.Unix(), sess.Expiry) + + timeutil.MockSet(now) + }) + + t.Run("Delete session", func(t *testing.T) { + // Ensure it't exist. + ok, err := auth.ExistSession(db.DefaultContext, key) + require.NoError(t, err) + assert.True(t, ok) + + preCount, err := auth.CountSessions(db.DefaultContext) + require.NoError(t, err) + + err = auth.DestroySession(db.DefaultContext, key) + require.NoError(t, err) + + // Ensure it doesn't exists. + ok, err = auth.ExistSession(db.DefaultContext, key) + require.NoError(t, err) + assert.False(t, ok) + + // Ensure the session is taken into account for count.. + postCount, err := auth.CountSessions(db.DefaultContext) + require.NoError(t, err) + assert.Less(t, postCount, preCount) + }) + + t.Run("Cleanup sessions", func(t *testing.T) { + timeutil.MockSet(time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC)) + + _, err := auth.ReadSession(db.DefaultContext, "sess-1") + require.NoError(t, err) + + // One minute later. + timeutil.MockSet(time.Date(2023, 1, 1, 0, 1, 0, 0, time.UTC)) + _, err = auth.ReadSession(db.DefaultContext, "sess-2") + require.NoError(t, err) + + // 5 minutes, shouldn't clean up anything. + err = auth.CleanupSessions(db.DefaultContext, 5*60) + require.NoError(t, err) + + ok, err := auth.ExistSession(db.DefaultContext, "sess-1") + require.NoError(t, err) + assert.True(t, ok) + + ok, err = auth.ExistSession(db.DefaultContext, "sess-2") + require.NoError(t, err) + assert.True(t, ok) + + // 1 minute, should clean up sess-1. + err = auth.CleanupSessions(db.DefaultContext, 60) + require.NoError(t, err) + + ok, err = auth.ExistSession(db.DefaultContext, "sess-1") + require.NoError(t, err) + assert.False(t, ok) + + ok, err = auth.ExistSession(db.DefaultContext, "sess-2") + require.NoError(t, err) + assert.True(t, ok) + + // Now, should clean up sess-2. + err = auth.CleanupSessions(db.DefaultContext, 0) + require.NoError(t, err) + + ok, err = auth.ExistSession(db.DefaultContext, "sess-2") + require.NoError(t, err) + assert.False(t, ok) + }) +} diff --git a/models/auth/source.go b/models/auth/source.go new file mode 100644 index 0000000..8f7c2a8 --- /dev/null +++ b/models/auth/source.go @@ -0,0 +1,412 @@ +// Copyright 2014 The Gogs Authors. All rights reserved. +// Copyright 2019 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package auth + +import ( + "context" + "fmt" + "reflect" + + "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/optional" + "code.gitea.io/gitea/modules/timeutil" + "code.gitea.io/gitea/modules/util" + + "xorm.io/builder" + "xorm.io/xorm" + "xorm.io/xorm/convert" +) + +// Type represents an login type. +type Type int + +// Note: new type must append to the end of list to maintain compatibility. +const ( + NoType Type = iota + Plain // 1 + LDAP // 2 + SMTP // 3 + PAM // 4 + DLDAP // 5 + OAuth2 // 6 + SSPI // 7 + Remote // 8 +) + +// String returns the string name of the LoginType +func (typ Type) String() string { + return Names[typ] +} + +// Int returns the int value of the LoginType +func (typ Type) Int() int { + return int(typ) +} + +// Names contains the name of LoginType values. +var Names = map[Type]string{ + LDAP: "LDAP (via BindDN)", + DLDAP: "LDAP (simple auth)", // Via direct bind + SMTP: "SMTP", + PAM: "PAM", + OAuth2: "OAuth2", + SSPI: "SPNEGO with SSPI", + Remote: "Remote", +} + +// Config represents login config as far as the db is concerned +type Config interface { + convert.Conversion +} + +// SkipVerifiable configurations provide a IsSkipVerify to check if SkipVerify is set +type SkipVerifiable interface { + IsSkipVerify() bool +} + +// HasTLSer configurations provide a HasTLS to check if TLS can be enabled +type HasTLSer interface { + HasTLS() bool +} + +// UseTLSer configurations provide a HasTLS to check if TLS is enabled +type UseTLSer interface { + UseTLS() bool +} + +// SSHKeyProvider configurations provide ProvidesSSHKeys to check if they provide SSHKeys +type SSHKeyProvider interface { + ProvidesSSHKeys() bool +} + +// RegisterableSource configurations provide RegisterSource which needs to be run on creation +type RegisterableSource interface { + RegisterSource() error + UnregisterSource() error +} + +var registeredConfigs = map[Type]func() Config{} + +// RegisterTypeConfig register a config for a provided type +func RegisterTypeConfig(typ Type, exemplar Config) { + if reflect.TypeOf(exemplar).Kind() == reflect.Ptr { + // Pointer: + registeredConfigs[typ] = func() Config { + return reflect.New(reflect.ValueOf(exemplar).Elem().Type()).Interface().(Config) + } + return + } + + // Not a Pointer + registeredConfigs[typ] = func() Config { + return reflect.New(reflect.TypeOf(exemplar)).Elem().Interface().(Config) + } +} + +// SourceSettable configurations can have their authSource set on them +type SourceSettable interface { + SetAuthSource(*Source) +} + +// Source represents an external way for authorizing users. +type Source struct { + ID int64 `xorm:"pk autoincr"` + Type Type + Name string `xorm:"UNIQUE"` + IsActive bool `xorm:"INDEX NOT NULL DEFAULT false"` + IsSyncEnabled bool `xorm:"INDEX NOT NULL DEFAULT false"` + Cfg convert.Conversion `xorm:"TEXT"` + + CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"` + UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"` +} + +// TableName xorm will read the table name from this method +func (Source) TableName() string { + return "login_source" +} + +func init() { + db.RegisterModel(new(Source)) +} + +// BeforeSet is invoked from XORM before setting the value of a field of this object. +func (source *Source) BeforeSet(colName string, val xorm.Cell) { + if colName == "type" { + typ := Type(db.Cell2Int64(val)) + constructor, ok := registeredConfigs[typ] + if !ok { + return + } + source.Cfg = constructor() + if settable, ok := source.Cfg.(SourceSettable); ok { + settable.SetAuthSource(source) + } + } +} + +// TypeName return name of this login source type. +func (source *Source) TypeName() string { + return Names[source.Type] +} + +// IsLDAP returns true of this source is of the LDAP type. +func (source *Source) IsLDAP() bool { + return source.Type == LDAP +} + +// IsDLDAP returns true of this source is of the DLDAP type. +func (source *Source) IsDLDAP() bool { + return source.Type == DLDAP +} + +// IsSMTP returns true of this source is of the SMTP type. +func (source *Source) IsSMTP() bool { + return source.Type == SMTP +} + +// IsPAM returns true of this source is of the PAM type. +func (source *Source) IsPAM() bool { + return source.Type == PAM +} + +// IsOAuth2 returns true of this source is of the OAuth2 type. +func (source *Source) IsOAuth2() bool { + return source.Type == OAuth2 +} + +// IsSSPI returns true of this source is of the SSPI type. +func (source *Source) IsSSPI() bool { + return source.Type == SSPI +} + +func (source *Source) IsRemote() bool { + return source.Type == Remote +} + +// HasTLS returns true of this source supports TLS. +func (source *Source) HasTLS() bool { + hasTLSer, ok := source.Cfg.(HasTLSer) + return ok && hasTLSer.HasTLS() +} + +// UseTLS returns true of this source is configured to use TLS. +func (source *Source) UseTLS() bool { + useTLSer, ok := source.Cfg.(UseTLSer) + return ok && useTLSer.UseTLS() +} + +// SkipVerify returns true if this source is configured to skip SSL +// verification. +func (source *Source) SkipVerify() bool { + skipVerifiable, ok := source.Cfg.(SkipVerifiable) + return ok && skipVerifiable.IsSkipVerify() +} + +// CreateSource inserts a AuthSource in the DB if not already +// existing with the given name. +func CreateSource(ctx context.Context, source *Source) error { + has, err := db.GetEngine(ctx).Where("name=?", source.Name).Exist(new(Source)) + if err != nil { + return err + } else if has { + return ErrSourceAlreadyExist{source.Name} + } + // Synchronization is only available with LDAP for now + if !source.IsLDAP() && !source.IsOAuth2() { + source.IsSyncEnabled = false + } + + _, err = db.GetEngine(ctx).Insert(source) + if err != nil { + return err + } + + if !source.IsActive { + return nil + } + + if settable, ok := source.Cfg.(SourceSettable); ok { + settable.SetAuthSource(source) + } + + registerableSource, ok := source.Cfg.(RegisterableSource) + if !ok { + return nil + } + + err = registerableSource.RegisterSource() + if err != nil { + // remove the AuthSource in case of errors while registering configuration + if _, err := db.GetEngine(ctx).ID(source.ID).Delete(new(Source)); err != nil { + log.Error("CreateSource: Error while wrapOpenIDConnectInitializeError: %v", err) + } + } + return err +} + +type FindSourcesOptions struct { + db.ListOptions + IsActive optional.Option[bool] + LoginType Type +} + +func (opts FindSourcesOptions) ToConds() builder.Cond { + conds := builder.NewCond() + if opts.IsActive.Has() { + conds = conds.And(builder.Eq{"is_active": opts.IsActive.Value()}) + } + if opts.LoginType != NoType { + conds = conds.And(builder.Eq{"`type`": opts.LoginType}) + } + return conds +} + +// IsSSPIEnabled returns true if there is at least one activated login +// source of type LoginSSPI +func IsSSPIEnabled(ctx context.Context) bool { + exist, err := db.Exist[Source](ctx, FindSourcesOptions{ + IsActive: optional.Some(true), + LoginType: SSPI, + }.ToConds()) + if err != nil { + log.Error("IsSSPIEnabled: failed to query active SSPI sources: %v", err) + return false + } + return exist +} + +// GetSourceByID returns login source by given ID. +func GetSourceByID(ctx context.Context, id int64) (*Source, error) { + source := new(Source) + if id == 0 { + source.Cfg = registeredConfigs[NoType]() + // Set this source to active + // FIXME: allow disabling of db based password authentication in future + source.IsActive = true + return source, nil + } + + has, err := db.GetEngine(ctx).ID(id).Get(source) + if err != nil { + return nil, err + } else if !has { + return nil, ErrSourceNotExist{id} + } + return source, nil +} + +func GetSourceByName(ctx context.Context, name string) (*Source, error) { + source := &Source{} + has, err := db.GetEngine(ctx).Where("name = ?", name).Get(source) + if err != nil { + return nil, err + } else if !has { + return nil, ErrSourceNotExist{} + } + return source, nil +} + +// UpdateSource updates a Source record in DB. +func UpdateSource(ctx context.Context, source *Source) error { + var originalSource *Source + if source.IsOAuth2() { + // keep track of the original values so we can restore in case of errors while registering OAuth2 providers + var err error + if originalSource, err = GetSourceByID(ctx, source.ID); err != nil { + return err + } + } + + has, err := db.GetEngine(ctx).Where("name=? AND id!=?", source.Name, source.ID).Exist(new(Source)) + if err != nil { + return err + } else if has { + return ErrSourceAlreadyExist{source.Name} + } + + _, err = db.GetEngine(ctx).ID(source.ID).AllCols().Update(source) + if err != nil { + return err + } + + if !source.IsActive { + return nil + } + + if settable, ok := source.Cfg.(SourceSettable); ok { + settable.SetAuthSource(source) + } + + registerableSource, ok := source.Cfg.(RegisterableSource) + if !ok { + return nil + } + + err = registerableSource.RegisterSource() + if err != nil { + // restore original values since we cannot update the provider it self + if _, err := db.GetEngine(ctx).ID(source.ID).AllCols().Update(originalSource); err != nil { + log.Error("UpdateSource: Error while wrapOpenIDConnectInitializeError: %v", err) + } + } + return err +} + +// ErrSourceNotExist represents a "SourceNotExist" kind of error. +type ErrSourceNotExist struct { + ID int64 +} + +// IsErrSourceNotExist checks if an error is a ErrSourceNotExist. +func IsErrSourceNotExist(err error) bool { + _, ok := err.(ErrSourceNotExist) + return ok +} + +func (err ErrSourceNotExist) Error() string { + return fmt.Sprintf("login source does not exist [id: %d]", err.ID) +} + +// Unwrap unwraps this as a ErrNotExist err +func (err ErrSourceNotExist) Unwrap() error { + return util.ErrNotExist +} + +// ErrSourceAlreadyExist represents a "SourceAlreadyExist" kind of error. +type ErrSourceAlreadyExist struct { + Name string +} + +// IsErrSourceAlreadyExist checks if an error is a ErrSourceAlreadyExist. +func IsErrSourceAlreadyExist(err error) bool { + _, ok := err.(ErrSourceAlreadyExist) + return ok +} + +func (err ErrSourceAlreadyExist) Error() string { + return fmt.Sprintf("login source already exists [name: %s]", err.Name) +} + +// Unwrap unwraps this as a ErrExist err +func (err ErrSourceAlreadyExist) Unwrap() error { + return util.ErrAlreadyExist +} + +// ErrSourceInUse represents a "SourceInUse" kind of error. +type ErrSourceInUse struct { + ID int64 +} + +// IsErrSourceInUse checks if an error is a ErrSourceInUse. +func IsErrSourceInUse(err error) bool { + _, ok := err.(ErrSourceInUse) + return ok +} + +func (err ErrSourceInUse) Error() string { + return fmt.Sprintf("login source is still used by some users [id: %d]", err.ID) +} diff --git a/models/auth/source_test.go b/models/auth/source_test.go new file mode 100644 index 0000000..522fecc --- /dev/null +++ b/models/auth/source_test.go @@ -0,0 +1,61 @@ +// Copyright 2019 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package auth_test + +import ( + "strings" + "testing" + + auth_model "code.gitea.io/gitea/models/auth" + "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/models/unittest" + "code.gitea.io/gitea/modules/json" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "xorm.io/xorm/schemas" +) + +type TestSource struct { + Provider string + ClientID string + ClientSecret string + OpenIDConnectAutoDiscoveryURL string + IconURL string +} + +// FromDB fills up a LDAPConfig from serialized format. +func (source *TestSource) FromDB(bs []byte) error { + return json.Unmarshal(bs, &source) +} + +// ToDB exports a LDAPConfig to a serialized format. +func (source *TestSource) ToDB() ([]byte, error) { + return json.Marshal(source) +} + +func TestDumpAuthSource(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + + authSourceSchema, err := db.TableInfo(new(auth_model.Source)) + require.NoError(t, err) + + auth_model.RegisterTypeConfig(auth_model.OAuth2, new(TestSource)) + + auth_model.CreateSource(db.DefaultContext, &auth_model.Source{ + Type: auth_model.OAuth2, + Name: "TestSource", + IsActive: false, + Cfg: &TestSource{ + Provider: "ConvertibleSourceName", + ClientID: "42", + }, + }) + + sb := new(strings.Builder) + + db.DumpTables([]*schemas.Table{authSourceSchema}, sb) + + assert.Contains(t, sb.String(), `"Provider":"ConvertibleSourceName"`) +} diff --git a/models/auth/twofactor.go b/models/auth/twofactor.go new file mode 100644 index 0000000..d0c341a --- /dev/null +++ b/models/auth/twofactor.go @@ -0,0 +1,166 @@ +// Copyright 2017 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package auth + +import ( + "context" + "crypto/md5" + "crypto/sha256" + "crypto/subtle" + "encoding/base32" + "encoding/base64" + "encoding/hex" + "fmt" + + "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/modules/secret" + "code.gitea.io/gitea/modules/setting" + "code.gitea.io/gitea/modules/timeutil" + "code.gitea.io/gitea/modules/util" + + "github.com/pquerna/otp/totp" + "golang.org/x/crypto/pbkdf2" +) + +// +// Two-factor authentication +// + +// ErrTwoFactorNotEnrolled indicates that a user is not enrolled in two-factor authentication. +type ErrTwoFactorNotEnrolled struct { + UID int64 +} + +// IsErrTwoFactorNotEnrolled checks if an error is a ErrTwoFactorNotEnrolled. +func IsErrTwoFactorNotEnrolled(err error) bool { + _, ok := err.(ErrTwoFactorNotEnrolled) + return ok +} + +func (err ErrTwoFactorNotEnrolled) Error() string { + return fmt.Sprintf("user not enrolled in 2FA [uid: %d]", err.UID) +} + +// Unwrap unwraps this as a ErrNotExist err +func (err ErrTwoFactorNotEnrolled) Unwrap() error { + return util.ErrNotExist +} + +// TwoFactor represents a two-factor authentication token. +type TwoFactor struct { + ID int64 `xorm:"pk autoincr"` + UID int64 `xorm:"UNIQUE"` + Secret string + ScratchSalt string + ScratchHash string + LastUsedPasscode string `xorm:"VARCHAR(10)"` + CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"` + UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"` +} + +func init() { + db.RegisterModel(new(TwoFactor)) +} + +// GenerateScratchToken recreates the scratch token the user is using. +func (t *TwoFactor) GenerateScratchToken() (string, error) { + tokenBytes, err := util.CryptoRandomBytes(6) + if err != nil { + return "", err + } + // these chars are specially chosen, avoid ambiguous chars like `0`, `O`, `1`, `I`. + const base32Chars = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789" + token := base32.NewEncoding(base32Chars).WithPadding(base32.NoPadding).EncodeToString(tokenBytes) + t.ScratchSalt, _ = util.CryptoRandomString(10) + t.ScratchHash = HashToken(token, t.ScratchSalt) + return token, nil +} + +// HashToken return the hashable salt +func HashToken(token, salt string) string { + tempHash := pbkdf2.Key([]byte(token), []byte(salt), 10000, 50, sha256.New) + return hex.EncodeToString(tempHash) +} + +// VerifyScratchToken verifies if the specified scratch token is valid. +func (t *TwoFactor) VerifyScratchToken(token string) bool { + if len(token) == 0 { + return false + } + tempHash := HashToken(token, t.ScratchSalt) + return subtle.ConstantTimeCompare([]byte(t.ScratchHash), []byte(tempHash)) == 1 +} + +func (t *TwoFactor) getEncryptionKey() []byte { + k := md5.Sum([]byte(setting.SecretKey)) + return k[:] +} + +// SetSecret sets the 2FA secret. +func (t *TwoFactor) SetSecret(secretString string) error { + secretBytes, err := secret.AesEncrypt(t.getEncryptionKey(), []byte(secretString)) + if err != nil { + return err + } + t.Secret = base64.StdEncoding.EncodeToString(secretBytes) + return nil +} + +// ValidateTOTP validates the provided passcode. +func (t *TwoFactor) ValidateTOTP(passcode string) (bool, error) { + decodedStoredSecret, err := base64.StdEncoding.DecodeString(t.Secret) + if err != nil { + return false, err + } + secretBytes, err := secret.AesDecrypt(t.getEncryptionKey(), decodedStoredSecret) + if err != nil { + return false, err + } + secretStr := string(secretBytes) + return totp.Validate(passcode, secretStr), nil +} + +// NewTwoFactor creates a new two-factor authentication token. +func NewTwoFactor(ctx context.Context, t *TwoFactor) error { + _, err := db.GetEngine(ctx).Insert(t) + return err +} + +// UpdateTwoFactor updates a two-factor authentication token. +func UpdateTwoFactor(ctx context.Context, t *TwoFactor) error { + _, err := db.GetEngine(ctx).ID(t.ID).AllCols().Update(t) + return err +} + +// GetTwoFactorByUID returns the two-factor authentication token associated with +// the user, if any. +func GetTwoFactorByUID(ctx context.Context, uid int64) (*TwoFactor, error) { + twofa := &TwoFactor{} + has, err := db.GetEngine(ctx).Where("uid=?", uid).Get(twofa) + if err != nil { + return nil, err + } else if !has { + return nil, ErrTwoFactorNotEnrolled{uid} + } + return twofa, nil +} + +// HasTwoFactorByUID returns the two-factor authentication token associated with +// the user, if any. +func HasTwoFactorByUID(ctx context.Context, uid int64) (bool, error) { + return db.GetEngine(ctx).Where("uid=?", uid).Exist(&TwoFactor{}) +} + +// DeleteTwoFactorByID deletes two-factor authentication token by given ID. +func DeleteTwoFactorByID(ctx context.Context, id, userID int64) error { + cnt, err := db.GetEngine(ctx).ID(id).Delete(&TwoFactor{ + UID: userID, + }) + if err != nil { + return err + } else if cnt != 1 { + return ErrTwoFactorNotEnrolled{userID} + } + return nil +} diff --git a/models/auth/webauthn.go b/models/auth/webauthn.go new file mode 100644 index 0000000..aa13cf6 --- /dev/null +++ b/models/auth/webauthn.go @@ -0,0 +1,209 @@ +// Copyright 2020 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package auth + +import ( + "context" + "fmt" + "strings" + + "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/modules/timeutil" + "code.gitea.io/gitea/modules/util" + + "github.com/go-webauthn/webauthn/webauthn" +) + +// ErrWebAuthnCredentialNotExist represents a "ErrWebAuthnCRedentialNotExist" kind of error. +type ErrWebAuthnCredentialNotExist struct { + ID int64 + CredentialID []byte +} + +func (err ErrWebAuthnCredentialNotExist) Error() string { + if len(err.CredentialID) == 0 { + return fmt.Sprintf("WebAuthn credential does not exist [id: %d]", err.ID) + } + return fmt.Sprintf("WebAuthn credential does not exist [credential_id: %x]", err.CredentialID) +} + +// Unwrap unwraps this as a ErrNotExist err +func (err ErrWebAuthnCredentialNotExist) Unwrap() error { + return util.ErrNotExist +} + +// IsErrWebAuthnCredentialNotExist checks if an error is a ErrWebAuthnCredentialNotExist. +func IsErrWebAuthnCredentialNotExist(err error) bool { + _, ok := err.(ErrWebAuthnCredentialNotExist) + return ok +} + +// WebAuthnCredential represents the WebAuthn credential data for a public-key +// credential conformant to WebAuthn Level 3 +type WebAuthnCredential struct { + ID int64 `xorm:"pk autoincr"` + Name string + LowerName string `xorm:"unique(s)"` + UserID int64 `xorm:"INDEX unique(s)"` + CredentialID []byte `xorm:"INDEX VARBINARY(1024)"` + PublicKey []byte + AttestationType string + AAGUID []byte + SignCount uint32 `xorm:"BIGINT"` + CloneWarning bool + BackupEligible bool `XORM:"NOT NULL DEFAULT false"` + BackupState bool `XORM:"NOT NULL DEFAULT false"` + // If legacy is set to true, backup_eligible and backup_state isn't set. + Legacy bool `XORM:"NOT NULL DEFAULT true"` + CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"` + UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"` +} + +func init() { + db.RegisterModel(new(WebAuthnCredential)) +} + +// TableName returns a better table name for WebAuthnCredential +func (cred WebAuthnCredential) TableName() string { + return "webauthn_credential" +} + +// UpdateSignCount will update the database value of SignCount +func (cred *WebAuthnCredential) UpdateSignCount(ctx context.Context) error { + _, err := db.GetEngine(ctx).ID(cred.ID).Cols("sign_count").Update(cred) + return err +} + +// UpdateFromLegacy update the values that aren't present on legacy credentials. +func (cred *WebAuthnCredential) UpdateFromLegacy(ctx context.Context) error { + _, err := db.GetEngine(ctx).ID(cred.ID).Cols("legacy", "backup_eligible", "backup_state").Update(cred) + return err +} + +// BeforeInsert will be invoked by XORM before updating a record +func (cred *WebAuthnCredential) BeforeInsert() { + cred.LowerName = strings.ToLower(cred.Name) +} + +// BeforeUpdate will be invoked by XORM before updating a record +func (cred *WebAuthnCredential) BeforeUpdate() { + cred.LowerName = strings.ToLower(cred.Name) +} + +// AfterLoad is invoked from XORM after setting the values of all fields of this object. +func (cred *WebAuthnCredential) AfterLoad() { + cred.LowerName = strings.ToLower(cred.Name) +} + +// WebAuthnCredentialList is a list of *WebAuthnCredential +type WebAuthnCredentialList []*WebAuthnCredential + +// ToCredentials will convert all WebAuthnCredentials to webauthn.Credentials +func (list WebAuthnCredentialList) ToCredentials() []webauthn.Credential { + creds := make([]webauthn.Credential, 0, len(list)) + for _, cred := range list { + creds = append(creds, webauthn.Credential{ + ID: cred.CredentialID, + PublicKey: cred.PublicKey, + AttestationType: cred.AttestationType, + Flags: webauthn.CredentialFlags{ + BackupEligible: cred.BackupEligible, + BackupState: cred.BackupState, + }, + Authenticator: webauthn.Authenticator{ + AAGUID: cred.AAGUID, + SignCount: cred.SignCount, + CloneWarning: cred.CloneWarning, + }, + }) + } + return creds +} + +// GetWebAuthnCredentialsByUID returns all WebAuthn credentials of the given user +func GetWebAuthnCredentialsByUID(ctx context.Context, uid int64) (WebAuthnCredentialList, error) { + creds := make(WebAuthnCredentialList, 0) + return creds, db.GetEngine(ctx).Where("user_id = ?", uid).Find(&creds) +} + +// ExistsWebAuthnCredentialsForUID returns if the given user has credentials +func ExistsWebAuthnCredentialsForUID(ctx context.Context, uid int64) (bool, error) { + return db.GetEngine(ctx).Where("user_id = ?", uid).Exist(&WebAuthnCredential{}) +} + +// GetWebAuthnCredentialByName returns WebAuthn credential by id +func GetWebAuthnCredentialByName(ctx context.Context, uid int64, name string) (*WebAuthnCredential, error) { + cred := new(WebAuthnCredential) + if found, err := db.GetEngine(ctx).Where("user_id = ? AND lower_name = ?", uid, strings.ToLower(name)).Get(cred); err != nil { + return nil, err + } else if !found { + return nil, ErrWebAuthnCredentialNotExist{} + } + return cred, nil +} + +// GetWebAuthnCredentialByID returns WebAuthn credential by id +func GetWebAuthnCredentialByID(ctx context.Context, id int64) (*WebAuthnCredential, error) { + cred := new(WebAuthnCredential) + if found, err := db.GetEngine(ctx).ID(id).Get(cred); err != nil { + return nil, err + } else if !found { + return nil, ErrWebAuthnCredentialNotExist{ID: id} + } + return cred, nil +} + +// HasWebAuthnRegistrationsByUID returns whether a given user has WebAuthn registrations +func HasWebAuthnRegistrationsByUID(ctx context.Context, uid int64) (bool, error) { + return db.GetEngine(ctx).Where("user_id = ?", uid).Exist(&WebAuthnCredential{}) +} + +// GetWebAuthnCredentialByCredID returns WebAuthn credential by credential ID +func GetWebAuthnCredentialByCredID(ctx context.Context, userID int64, credID []byte) (*WebAuthnCredential, error) { + cred := new(WebAuthnCredential) + if found, err := db.GetEngine(ctx).Where("user_id = ? AND credential_id = ?", userID, credID).Get(cred); err != nil { + return nil, err + } else if !found { + return nil, ErrWebAuthnCredentialNotExist{CredentialID: credID} + } + return cred, nil +} + +// CreateCredential will create a new WebAuthnCredential from the given Credential +func CreateCredential(ctx context.Context, userID int64, name string, cred *webauthn.Credential) (*WebAuthnCredential, error) { + c := &WebAuthnCredential{ + UserID: userID, + Name: name, + CredentialID: cred.ID, + PublicKey: cred.PublicKey, + AttestationType: cred.AttestationType, + AAGUID: cred.Authenticator.AAGUID, + SignCount: cred.Authenticator.SignCount, + CloneWarning: false, + BackupEligible: cred.Flags.BackupEligible, + BackupState: cred.Flags.BackupState, + Legacy: false, + } + + if err := db.Insert(ctx, c); err != nil { + return nil, err + } + return c, nil +} + +// DeleteCredential will delete WebAuthnCredential +func DeleteCredential(ctx context.Context, id, userID int64) (bool, error) { + had, err := db.GetEngine(ctx).ID(id).Where("user_id = ?", userID).Delete(&WebAuthnCredential{}) + return had > 0, err +} + +// WebAuthnCredentials implementns the webauthn.User interface +func WebAuthnCredentials(ctx context.Context, userID int64) ([]webauthn.Credential, error) { + dbCreds, err := GetWebAuthnCredentialsByUID(ctx, userID) + if err != nil { + return nil, err + } + + return dbCreds.ToCredentials(), nil +} diff --git a/models/auth/webauthn_test.go b/models/auth/webauthn_test.go new file mode 100644 index 0000000..e1cd652 --- /dev/null +++ b/models/auth/webauthn_test.go @@ -0,0 +1,78 @@ +// Copyright 2020 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package auth_test + +import ( + "testing" + + auth_model "code.gitea.io/gitea/models/auth" + "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/models/unittest" + + "github.com/go-webauthn/webauthn/webauthn" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetWebAuthnCredentialByID(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + + res, err := auth_model.GetWebAuthnCredentialByID(db.DefaultContext, 1) + require.NoError(t, err) + assert.Equal(t, "WebAuthn credential", res.Name) + + _, err = auth_model.GetWebAuthnCredentialByID(db.DefaultContext, 342432) + require.Error(t, err) + assert.True(t, auth_model.IsErrWebAuthnCredentialNotExist(err)) +} + +func TestGetWebAuthnCredentialsByUID(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + + res, err := auth_model.GetWebAuthnCredentialsByUID(db.DefaultContext, 32) + require.NoError(t, err) + assert.Len(t, res, 1) + assert.Equal(t, "WebAuthn credential", res[0].Name) +} + +func TestWebAuthnCredential_TableName(t *testing.T) { + assert.Equal(t, "webauthn_credential", auth_model.WebAuthnCredential{}.TableName()) +} + +func TestWebAuthnCredential_UpdateSignCount(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + cred := unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{ID: 1}) + cred.SignCount = 1 + require.NoError(t, cred.UpdateSignCount(db.DefaultContext)) + unittest.AssertExistsIf(t, true, &auth_model.WebAuthnCredential{ID: 1, SignCount: 1}) +} + +func TestWebAuthnCredential_UpdateLargeCounter(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + cred := unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{ID: 1}) + cred.SignCount = 0xffffffff + require.NoError(t, cred.UpdateSignCount(db.DefaultContext)) + unittest.AssertExistsIf(t, true, &auth_model.WebAuthnCredential{ID: 1, SignCount: 0xffffffff}) +} + +func TestWebAuthenCredential_UpdateFromLegacy(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + cred := unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{ID: 1, Legacy: true}) + cred.Legacy = false + cred.BackupEligible = true + cred.BackupState = true + require.NoError(t, cred.UpdateFromLegacy(db.DefaultContext)) + unittest.AssertExistsIf(t, true, &auth_model.WebAuthnCredential{ID: 1, BackupEligible: true, BackupState: true}, "legacy = false") +} + +func TestCreateCredential(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + + res, err := auth_model.CreateCredential(db.DefaultContext, 1, "WebAuthn Created Credential", &webauthn.Credential{ID: []byte("Test"), Flags: webauthn.CredentialFlags{BackupEligible: true, BackupState: true}}) + require.NoError(t, err) + assert.Equal(t, "WebAuthn Created Credential", res.Name) + assert.Equal(t, []byte("Test"), res.CredentialID) + + unittest.AssertExistsIf(t, true, &auth_model.WebAuthnCredential{Name: "WebAuthn Created Credential", UserID: 1, BackupEligible: true, BackupState: true}, "legacy = false") +} -- cgit v1.2.3