diff options
Diffstat (limited to 'services/auth/source')
42 files changed, 3794 insertions, 0 deletions
diff --git a/services/auth/source/db/assert_interface_test.go b/services/auth/source/db/assert_interface_test.go new file mode 100644 index 0000000..62387c7 --- /dev/null +++ b/services/auth/source/db/assert_interface_test.go @@ -0,0 +1,20 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package db_test + +import ( + auth_model "code.gitea.io/gitea/models/auth" + "code.gitea.io/gitea/services/auth" + "code.gitea.io/gitea/services/auth/source/db" +) + +// This test file exists to assert that our Source exposes the interfaces that we expect +// It tightly binds the interfaces and implementation without breaking go import cycles + +type sourceInterface interface { + auth.PasswordAuthenticator + auth_model.Config +} + +var _ (sourceInterface) = &db.Source{} diff --git a/services/auth/source/db/authenticate.go b/services/auth/source/db/authenticate.go new file mode 100644 index 0000000..8160141 --- /dev/null +++ b/services/auth/source/db/authenticate.go @@ -0,0 +1,87 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package db + +import ( + "context" + "fmt" + + user_model "code.gitea.io/gitea/models/user" + "code.gitea.io/gitea/modules/setting" + "code.gitea.io/gitea/modules/util" +) + +// ErrUserPasswordNotSet represents a "ErrUserPasswordNotSet" kind of error. +type ErrUserPasswordNotSet struct { + UID int64 + Name string +} + +func (err ErrUserPasswordNotSet) Error() string { + return fmt.Sprintf("user's password isn't set [uid: %d, name: %s]", err.UID, err.Name) +} + +// Unwrap unwraps this error as a ErrInvalidArgument error +func (err ErrUserPasswordNotSet) Unwrap() error { + return util.ErrInvalidArgument +} + +// ErrUserPasswordInvalid represents a "ErrUserPasswordInvalid" kind of error. +type ErrUserPasswordInvalid struct { + UID int64 + Name string +} + +func (err ErrUserPasswordInvalid) Error() string { + return fmt.Sprintf("user's password is invalid [uid: %d, name: %s]", err.UID, err.Name) +} + +// Unwrap unwraps this error as a ErrInvalidArgument error +func (err ErrUserPasswordInvalid) Unwrap() error { + return util.ErrInvalidArgument +} + +// Authenticate authenticates the provided user against the DB +func Authenticate(ctx context.Context, user *user_model.User, login, password string) (*user_model.User, error) { + if user == nil { + return nil, user_model.ErrUserNotExist{Name: login} + } + + if !user.IsPasswordSet() { + return nil, ErrUserPasswordNotSet{UID: user.ID, Name: user.Name} + } else if !user.ValidatePassword(password) { + return nil, ErrUserPasswordInvalid{UID: user.ID, Name: user.Name} + } + + // Update password hash if server password hash algorithm have changed + // Or update the password when the salt length doesn't match the current + // recommended salt length, this in order to migrate user's salts to a more secure salt. + if user.PasswdHashAlgo != setting.PasswordHashAlgo || len(user.Salt) != user_model.SaltByteLength*2 { + if err := user.SetPassword(password); err != nil { + return nil, err + } + if err := user_model.UpdateUserCols(ctx, user, "passwd", "passwd_hash_algo", "salt"); err != nil { + return nil, err + } + } + + // WARN: DON'T check user.IsActive, that will be checked on reqSign so that + // user could be hinted to resend confirm email. + if user.ProhibitLogin { + return nil, user_model.ErrUserProhibitLogin{ + UID: user.ID, + Name: user.Name, + } + } + + // attempting to login as a non-user account + if user.Type != user_model.UserTypeIndividual { + return nil, user_model.ErrUserProhibitLogin{ + UID: user.ID, + Name: user.Name, + } + } + + return user, nil +} diff --git a/services/auth/source/db/source.go b/services/auth/source/db/source.go new file mode 100644 index 0000000..bb2270c --- /dev/null +++ b/services/auth/source/db/source.go @@ -0,0 +1,35 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package db + +import ( + "context" + + "code.gitea.io/gitea/models/auth" + user_model "code.gitea.io/gitea/models/user" +) + +// Source is a password authentication service +type Source struct{} + +// FromDB fills up an OAuth2Config from serialized format. +func (source *Source) FromDB(bs []byte) error { + return nil +} + +// ToDB exports the config to a byte slice to be saved into database (this method is just dummy and does nothing for DB source) +func (source *Source) ToDB() ([]byte, error) { + return nil, nil +} + +// Authenticate queries if login/password is valid against the PAM, +// and create a local user if success when enabled. +func (source *Source) Authenticate(ctx context.Context, user *user_model.User, login, password string) (*user_model.User, error) { + return Authenticate(ctx, user, login, password) +} + +func init() { + auth.RegisterTypeConfig(auth.NoType, &Source{}) + auth.RegisterTypeConfig(auth.Plain, &Source{}) +} diff --git a/services/auth/source/ldap/README.md b/services/auth/source/ldap/README.md new file mode 100644 index 0000000..34c8117 --- /dev/null +++ b/services/auth/source/ldap/README.md @@ -0,0 +1,131 @@ +# Gitea LDAP Authentication Module + +## About + +This authentication module attempts to authorize and authenticate a user +against an LDAP server. It provides two methods of authentication: LDAP via +BindDN, and LDAP simple authentication. + +LDAP via BindDN functions like most LDAP authentication systems. First, it +queries the LDAP server using a Bind DN and searches for the user that is +attempting to sign in. If the user is found, the module attempts to bind to the +server using the user's supplied credentials. If this succeeds, the user has +been authenticated, and his account information is retrieved and passed to the +Gogs login infrastructure. + +LDAP simple authentication does not utilize a Bind DN. Instead, it binds +directly with the LDAP server using the user's supplied credentials. If the bind +succeeds and no filter rules out the user, the user is authenticated. + +LDAP via BindDN is recommended for most users. By using a Bind DN, the server +can perform authorization by restricting which entries the Bind DN account can +read. Further, using a Bind DN with reduced permissions can reduce security risk +in the face of application bugs. + +## Usage + +To use this module, add an LDAP authentication source via the Authentications +section in the admin panel. Both the LDAP via BindDN and the simple auth LDAP +share the following fields: + +* Authorization Name **(required)** + * A name to assign to the new method of authorization. + +* Host **(required)** + * The address where the LDAP server can be reached. + * Example: mydomain.com + +* Port **(required)** + * The port to use when connecting to the server. + * Example: 636 + +* Enable TLS Encryption (optional) + * Whether to use TLS when connecting to the LDAP server. + +* Admin Filter (optional) + * An LDAP filter specifying if a user should be given administrator + privileges. If a user accounts passes the filter, the user will be + privileged as an administrator. + * Example: (objectClass=adminAccount) + +* First name attribute (optional) + * The attribute of the user's LDAP record containing the user's first name. + This will be used to populate their account information. + * Example: givenName + +* Surname attribute (optional) + * The attribute of the user's LDAP record containing the user's surname This + will be used to populate their account information. + * Example: sn + +* E-mail attribute **(required)** + * The attribute of the user's LDAP record containing the user's email + address. This will be used to populate their account information. + * Example: mail + +**LDAP via BindDN** adds the following fields: + +* Bind DN (optional) + * The DN to bind to the LDAP server with when searching for the user. This + may be left blank to perform an anonymous search. + * Example: cn=Search,dc=mydomain,dc=com + +* Bind Password (optional) + * The password for the Bind DN specified above, if any. _Note: The password + is stored in plaintext at the server. As such, ensure that your Bind DN + has as few privileges as possible._ + +* User Search Base **(required)** + * The LDAP base at which user accounts will be searched for. + * Example: ou=Users,dc=mydomain,dc=com + +* User Filter **(required)** + * An LDAP filter declaring how to find the user record that is attempting to + authenticate. The '%[1]s' matching parameter will be substituted with the + user's username. + * Example: (&(objectClass=posixAccount)(|(uid=%[1]s)(mail=%[1]s))) + +**LDAP using simple auth** adds the following fields: + +* User DN **(required)** + * A template to use as the user's DN. The `%s` matching parameter will be + substituted with the user's username. + * Example: cn=%s,ou=Users,dc=mydomain,dc=com + * Example: uid=%s,ou=Users,dc=mydomain,dc=com + +* User Search Base (optional) + * The LDAP base at which user accounts will be searched for. + * Example: ou=Users,dc=mydomain,dc=com + +* User Filter **(required)** + * An LDAP filter declaring when a user should be allowed to log in. The `%[1]s` + matching parameter will be substituted with the user's username. + * Example: (&(objectClass=posixAccount)(|(cn=%[1]s)(mail=%[1]s))) + * Example: (&(objectClass=posixAccount)(|(uid=%[1]s)(mail=%[1]s))) + +**Verify group membership in LDAP** uses the following fields: + +* Group Search Base (optional) + * The LDAP DN used for groups. + * Example: ou=group,dc=mydomain,dc=com + +* Group Name Filter (optional) + * An LDAP filter declaring how to find valid groups in the above DN. + * Example: (|(cn=gitea_users)(cn=admins)) + +* User Attribute in Group (optional) + * The user attribute that is used to reference a user in the group object. + * Example: uid if the group objects contains a member: bender and the user object contains a uid: bender. + * Example: dn if the group object contains a member: uid=bender,ou=users,dc=planetexpress,dc=com. + +* Group Attribute for User (optional) + * The attribute of the group object that lists/contains the group members. + * Example: memberUid or member + +* Team group map (optional) + * Automatically add users to Organization teams, depending on LDAP group memberships. + * Note: this function only adds users to teams, it never removes users. + * Example: {"cn=MyGroup,cn=groups,dc=example,dc=org": {"MyGiteaOrganization": ["MyGiteaTeam1", "MyGiteaTeam2", ...], ...}, ...} + +* Team group map removal (optional) + * If set to true, users will be removed from teams if they are not members of the corresponding group. diff --git a/services/auth/source/ldap/assert_interface_test.go b/services/auth/source/ldap/assert_interface_test.go new file mode 100644 index 0000000..3334768 --- /dev/null +++ b/services/auth/source/ldap/assert_interface_test.go @@ -0,0 +1,27 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package ldap_test + +import ( + auth_model "code.gitea.io/gitea/models/auth" + "code.gitea.io/gitea/services/auth" + "code.gitea.io/gitea/services/auth/source/ldap" +) + +// This test file exists to assert that our Source exposes the interfaces that we expect +// It tightly binds the interfaces and implementation without breaking go import cycles + +type sourceInterface interface { + auth.PasswordAuthenticator + auth.SynchronizableSource + auth.LocalTwoFASkipper + auth_model.SSHKeyProvider + auth_model.Config + auth_model.SkipVerifiable + auth_model.HasTLSer + auth_model.UseTLSer + auth_model.SourceSettable +} + +var _ (sourceInterface) = &ldap.Source{} diff --git a/services/auth/source/ldap/security_protocol.go b/services/auth/source/ldap/security_protocol.go new file mode 100644 index 0000000..af83ce1 --- /dev/null +++ b/services/auth/source/ldap/security_protocol.go @@ -0,0 +1,31 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package ldap + +// SecurityProtocol protocol type +type SecurityProtocol int + +// Note: new type must be added at the end of list to maintain compatibility. +const ( + SecurityProtocolUnencrypted SecurityProtocol = iota + SecurityProtocolLDAPS + SecurityProtocolStartTLS +) + +// String returns the name of the SecurityProtocol +func (s SecurityProtocol) String() string { + return SecurityProtocolNames[s] +} + +// Int returns the int value of the SecurityProtocol +func (s SecurityProtocol) Int() int { + return int(s) +} + +// SecurityProtocolNames contains the name of SecurityProtocol values. +var SecurityProtocolNames = map[SecurityProtocol]string{ + SecurityProtocolUnencrypted: "Unencrypted", + SecurityProtocolLDAPS: "LDAPS", + SecurityProtocolStartTLS: "StartTLS", +} diff --git a/services/auth/source/ldap/source.go b/services/auth/source/ldap/source.go new file mode 100644 index 0000000..ba407b3 --- /dev/null +++ b/services/auth/source/ldap/source.go @@ -0,0 +1,122 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package ldap + +import ( + "strings" + + "code.gitea.io/gitea/models/auth" + "code.gitea.io/gitea/modules/json" + "code.gitea.io/gitea/modules/secret" + "code.gitea.io/gitea/modules/setting" +) + +// .____ ________ _____ __________ +// | | \______ \ / _ \\______ \ +// | | | | \ / /_\ \| ___/ +// | |___ | ` \/ | \ | +// |_______ \/_______ /\____|__ /____| +// \/ \/ \/ + +// Package ldap provide functions & structure to query a LDAP ldap directory +// For now, it's mainly tested again an MS Active Directory service, see README.md for more information + +// Source Basic LDAP authentication service +type Source struct { + Name string // canonical name (ie. corporate.ad) + Host string // LDAP host + Port int // port number + SecurityProtocol SecurityProtocol + SkipVerify bool + BindDN string // DN to bind with + BindPasswordEncrypt string // Encrypted Bind BN password + BindPassword string // Bind DN password + UserBase string // Base search path for users + UserDN string // Template for the DN of the user for simple auth + DefaultDomainName string // DomainName used if none are in the field, default "localhost.local" + AttributeUsername string // Username attribute + AttributeName string // First name attribute + AttributeSurname string // Surname attribute + AttributeMail string // E-mail attribute + AttributesInBind bool // fetch attributes in bind context (not user) + AttributeSSHPublicKey string // LDAP SSH Public Key attribute + AttributeAvatar string + SearchPageSize uint32 // Search with paging page size + Filter string // Query filter to validate entry + AdminFilter string // Query filter to check if user is admin + RestrictedFilter string // Query filter to check if user is restricted + Enabled bool // if this source is disabled + AllowDeactivateAll bool // Allow an empty search response to deactivate all users from this source + GroupsEnabled bool // if the group checking is enabled + GroupDN string // Group Search Base + GroupFilter string // Group Name Filter + GroupMemberUID string // Group Attribute containing array of UserUID + GroupTeamMap string // Map LDAP groups to teams + GroupTeamMapRemoval bool // Remove user from teams which are synchronized and user is not a member of the corresponding LDAP group + UserUID string // User Attribute listed in Group + SkipLocalTwoFA bool `json:",omitempty"` // Skip Local 2fa for users authenticated with this source + + // reference to the authSource + authSource *auth.Source +} + +// FromDB fills up a LDAPConfig from serialized format. +func (source *Source) FromDB(bs []byte) error { + err := json.UnmarshalHandleDoubleEncode(bs, &source) + if err != nil { + return err + } + if source.BindPasswordEncrypt != "" { + source.BindPassword, err = secret.DecryptSecret(setting.SecretKey, source.BindPasswordEncrypt) + source.BindPasswordEncrypt = "" + } + return err +} + +// ToDB exports a LDAPConfig to a serialized format. +func (source *Source) ToDB() ([]byte, error) { + var err error + source.BindPasswordEncrypt, err = secret.EncryptSecret(setting.SecretKey, source.BindPassword) + if err != nil { + return nil, err + } + source.BindPassword = "" + return json.Marshal(source) +} + +// SecurityProtocolName returns the name of configured security +// protocol. +func (source *Source) SecurityProtocolName() string { + return SecurityProtocolNames[source.SecurityProtocol] +} + +// IsSkipVerify returns if SkipVerify is set +func (source *Source) IsSkipVerify() bool { + return source.SkipVerify +} + +// HasTLS returns if HasTLS +func (source *Source) HasTLS() bool { + return source.SecurityProtocol > SecurityProtocolUnencrypted +} + +// UseTLS returns if UseTLS +func (source *Source) UseTLS() bool { + return source.SecurityProtocol != SecurityProtocolUnencrypted +} + +// ProvidesSSHKeys returns if this source provides SSH Keys +func (source *Source) ProvidesSSHKeys() bool { + return len(strings.TrimSpace(source.AttributeSSHPublicKey)) > 0 +} + +// SetAuthSource sets the related AuthSource +func (source *Source) SetAuthSource(authSource *auth.Source) { + source.authSource = authSource +} + +func init() { + auth.RegisterTypeConfig(auth.LDAP, &Source{}) + auth.RegisterTypeConfig(auth.DLDAP, &Source{}) +} diff --git a/services/auth/source/ldap/source_authenticate.go b/services/auth/source/ldap/source_authenticate.go new file mode 100644 index 0000000..68ecd16 --- /dev/null +++ b/services/auth/source/ldap/source_authenticate.go @@ -0,0 +1,124 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package ldap + +import ( + "context" + "fmt" + "strings" + + asymkey_model "code.gitea.io/gitea/models/asymkey" + "code.gitea.io/gitea/models/auth" + user_model "code.gitea.io/gitea/models/user" + auth_module "code.gitea.io/gitea/modules/auth" + "code.gitea.io/gitea/modules/optional" + source_service "code.gitea.io/gitea/services/auth/source" + user_service "code.gitea.io/gitea/services/user" +) + +// Authenticate queries if login/password is valid against the LDAP directory pool, +// and create a local user if success when enabled. +func (source *Source) Authenticate(ctx context.Context, user *user_model.User, userName, password string) (*user_model.User, error) { + loginName := userName + if user != nil { + loginName = user.LoginName + } + sr := source.SearchEntry(loginName, password, source.authSource.Type == auth.DLDAP) + if sr == nil { + // User not in LDAP, do nothing + return nil, user_model.ErrUserNotExist{Name: loginName} + } + // Fallback. + if len(sr.Username) == 0 { + sr.Username = userName + } + if len(sr.Mail) == 0 { + sr.Mail = fmt.Sprintf("%s@localhost.local", sr.Username) + } + isAttributeSSHPublicKeySet := len(strings.TrimSpace(source.AttributeSSHPublicKey)) > 0 + + // Update User admin flag if exist + if isExist, err := user_model.IsUserExist(ctx, 0, sr.Username); err != nil { + return nil, err + } else if isExist { + if user == nil { + user, err = user_model.GetUserByName(ctx, sr.Username) + if err != nil { + return nil, err + } + } + if user != nil && !user.ProhibitLogin { + opts := &user_service.UpdateOptions{} + if len(source.AdminFilter) > 0 && user.IsAdmin != sr.IsAdmin { + // Change existing admin flag only if AdminFilter option is set + opts.IsAdmin = optional.Some(sr.IsAdmin) + } + if !sr.IsAdmin && len(source.RestrictedFilter) > 0 && user.IsRestricted != sr.IsRestricted { + // Change existing restricted flag only if RestrictedFilter option is set + opts.IsRestricted = optional.Some(sr.IsRestricted) + } + if opts.IsAdmin.Has() || opts.IsRestricted.Has() { + if err := user_service.UpdateUser(ctx, user, opts); err != nil { + return nil, err + } + } + } + } + + if user != nil { + if isAttributeSSHPublicKeySet && asymkey_model.SynchronizePublicKeys(ctx, user, source.authSource, sr.SSHPublicKey) { + if err := asymkey_model.RewriteAllPublicKeys(ctx); err != nil { + return user, err + } + } + } else { + user = &user_model.User{ + LowerName: strings.ToLower(sr.Username), + Name: sr.Username, + FullName: composeFullName(sr.Name, sr.Surname, sr.Username), + Email: sr.Mail, + LoginType: source.authSource.Type, + LoginSource: source.authSource.ID, + LoginName: userName, + IsAdmin: sr.IsAdmin, + } + overwriteDefault := &user_model.CreateUserOverwriteOptions{ + IsRestricted: optional.Some(sr.IsRestricted), + IsActive: optional.Some(true), + } + + err := user_model.CreateUser(ctx, user, overwriteDefault) + if err != nil { + return user, err + } + + if isAttributeSSHPublicKeySet && asymkey_model.AddPublicKeysBySource(ctx, user, source.authSource, sr.SSHPublicKey) { + if err := asymkey_model.RewriteAllPublicKeys(ctx); err != nil { + return user, err + } + } + if len(source.AttributeAvatar) > 0 { + if err := user_service.UploadAvatar(ctx, user, sr.Avatar); err != nil { + return user, err + } + } + } + + if source.GroupsEnabled && (source.GroupTeamMap != "" || source.GroupTeamMapRemoval) { + groupTeamMapping, err := auth_module.UnmarshalGroupTeamMapping(source.GroupTeamMap) + if err != nil { + return user, err + } + if err := source_service.SyncGroupsToTeams(ctx, user, sr.Groups, groupTeamMapping, source.GroupTeamMapRemoval); err != nil { + return user, err + } + } + + return user, nil +} + +// IsSkipLocalTwoFA returns if this source should skip local 2fa for password authentication +func (source *Source) IsSkipLocalTwoFA() bool { + return source.SkipLocalTwoFA +} diff --git a/services/auth/source/ldap/source_search.go b/services/auth/source/ldap/source_search.go new file mode 100644 index 0000000..2a61386 --- /dev/null +++ b/services/auth/source/ldap/source_search.go @@ -0,0 +1,516 @@ +// Copyright 2014 The Gogs Authors. All rights reserved. +// Copyright 2020 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package ldap + +import ( + "crypto/tls" + "fmt" + "net" + "strconv" + "strings" + + "code.gitea.io/gitea/modules/container" + "code.gitea.io/gitea/modules/log" + + "github.com/go-ldap/ldap/v3" +) + +// SearchResult : user data +type SearchResult struct { + Username string // Username + Name string // Name + Surname string // Surname + Mail string // E-mail address + SSHPublicKey []string // SSH Public Key + IsAdmin bool // if user is administrator + IsRestricted bool // if user is restricted + LowerName string // LowerName + Avatar []byte + Groups container.Set[string] +} + +func (source *Source) sanitizedUserQuery(username string) (string, bool) { + // See http://tools.ietf.org/search/rfc4515 + badCharacters := "\x00()*\\" + if strings.ContainsAny(username, badCharacters) { + log.Debug("'%s' contains invalid query characters. Aborting.", username) + return "", false + } + + return fmt.Sprintf(source.Filter, username), true +} + +func (source *Source) sanitizedUserDN(username string) (string, bool) { + // See http://tools.ietf.org/search/rfc4514: "special characters" + badCharacters := "\x00()*\\,='\"#+;<>" + if strings.ContainsAny(username, badCharacters) { + log.Debug("'%s' contains invalid DN characters. Aborting.", username) + return "", false + } + + return fmt.Sprintf(source.UserDN, username), true +} + +func (source *Source) sanitizedGroupFilter(group string) (string, bool) { + // See http://tools.ietf.org/search/rfc4515 + badCharacters := "\x00*\\" + if strings.ContainsAny(group, badCharacters) { + log.Trace("Group filter invalid query characters: %s", group) + return "", false + } + + return group, true +} + +func (source *Source) sanitizedGroupDN(groupDn string) (string, bool) { + // See http://tools.ietf.org/search/rfc4514: "special characters" + badCharacters := "\x00()*\\'\"#+;<>" + if strings.ContainsAny(groupDn, badCharacters) || strings.HasPrefix(groupDn, " ") || strings.HasSuffix(groupDn, " ") { + log.Trace("Group DN contains invalid query characters: %s", groupDn) + return "", false + } + + return groupDn, true +} + +func (source *Source) findUserDN(l *ldap.Conn, name string) (string, bool) { + log.Trace("Search for LDAP user: %s", name) + + // A search for the user. + userFilter, ok := source.sanitizedUserQuery(name) + if !ok { + return "", false + } + + log.Trace("Searching for DN using filter %s and base %s", userFilter, source.UserBase) + search := ldap.NewSearchRequest( + source.UserBase, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, + false, userFilter, []string{}, nil) + + // Ensure we found a user + sr, err := l.Search(search) + if err != nil || len(sr.Entries) < 1 { + log.Debug("Failed search using filter[%s]: %v", userFilter, err) + return "", false + } else if len(sr.Entries) > 1 { + log.Debug("Filter '%s' returned more than one user.", userFilter) + return "", false + } + + userDN := sr.Entries[0].DN + if userDN == "" { + log.Error("LDAP search was successful, but found no DN!") + return "", false + } + + return userDN, true +} + +func dial(source *Source) (*ldap.Conn, error) { + log.Trace("Dialing LDAP with security protocol (%v) without verifying: %v", source.SecurityProtocol, source.SkipVerify) + + tlsConfig := &tls.Config{ + ServerName: source.Host, + InsecureSkipVerify: source.SkipVerify, + } + + if source.SecurityProtocol == SecurityProtocolLDAPS { + return ldap.DialTLS("tcp", net.JoinHostPort(source.Host, strconv.Itoa(source.Port)), tlsConfig) + } + + conn, err := ldap.Dial("tcp", net.JoinHostPort(source.Host, strconv.Itoa(source.Port))) + if err != nil { + return nil, fmt.Errorf("error during Dial: %w", err) + } + + if source.SecurityProtocol == SecurityProtocolStartTLS { + if err = conn.StartTLS(tlsConfig); err != nil { + conn.Close() + return nil, fmt.Errorf("error during StartTLS: %w", err) + } + } + + return conn, nil +} + +func bindUser(l *ldap.Conn, userDN, passwd string) error { + log.Trace("Binding with userDN: %s", userDN) + err := l.Bind(userDN, passwd) + if err != nil { + log.Debug("LDAP auth. failed for %s, reason: %v", userDN, err) + return err + } + log.Trace("Bound successfully with userDN: %s", userDN) + return err +} + +func checkAdmin(l *ldap.Conn, ls *Source, userDN string) bool { + if len(ls.AdminFilter) == 0 { + return false + } + log.Trace("Checking admin with filter %s and base %s", ls.AdminFilter, userDN) + search := ldap.NewSearchRequest( + userDN, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, ls.AdminFilter, + []string{ls.AttributeName}, + nil) + + sr, err := l.Search(search) + + if err != nil { + log.Error("LDAP Admin Search with filter %s for %s failed unexpectedly! (%v)", ls.AdminFilter, userDN, err) + } else if len(sr.Entries) < 1 { + log.Trace("LDAP Admin Search found no matching entries.") + } else { + return true + } + return false +} + +func checkRestricted(l *ldap.Conn, ls *Source, userDN string) bool { + if len(ls.RestrictedFilter) == 0 { + return false + } + if ls.RestrictedFilter == "*" { + return true + } + log.Trace("Checking restricted with filter %s and base %s", ls.RestrictedFilter, userDN) + search := ldap.NewSearchRequest( + userDN, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, ls.RestrictedFilter, + []string{ls.AttributeName}, + nil) + + sr, err := l.Search(search) + + if err != nil { + log.Error("LDAP Restrictred Search with filter %s for %s failed unexpectedly! (%v)", ls.RestrictedFilter, userDN, err) + } else if len(sr.Entries) < 1 { + log.Trace("LDAP Restricted Search found no matching entries.") + } else { + return true + } + return false +} + +// List all group memberships of a user +func (source *Source) listLdapGroupMemberships(l *ldap.Conn, uid string, applyGroupFilter bool) container.Set[string] { + ldapGroups := make(container.Set[string]) + + groupFilter, ok := source.sanitizedGroupFilter(source.GroupFilter) + if !ok { + return ldapGroups + } + + groupDN, ok := source.sanitizedGroupDN(source.GroupDN) + if !ok { + return ldapGroups + } + + var searchFilter string + if applyGroupFilter && groupFilter != "" { + searchFilter = fmt.Sprintf("(&(%s)(%s=%s))", groupFilter, source.GroupMemberUID, ldap.EscapeFilter(uid)) + } else { + searchFilter = fmt.Sprintf("(%s=%s)", source.GroupMemberUID, ldap.EscapeFilter(uid)) + } + result, err := l.Search(ldap.NewSearchRequest( + groupDN, + ldap.ScopeWholeSubtree, + ldap.NeverDerefAliases, + 0, + 0, + false, + searchFilter, + []string{}, + nil, + )) + if err != nil { + log.Error("Failed group search in LDAP with filter [%s]: %v", searchFilter, err) + return ldapGroups + } + + for _, entry := range result.Entries { + if entry.DN == "" { + log.Error("LDAP search was successful, but found no DN!") + continue + } + ldapGroups.Add(entry.DN) + } + + return ldapGroups +} + +func (source *Source) getUserAttributeListedInGroup(entry *ldap.Entry) string { + if strings.ToLower(source.UserUID) == "dn" { + return entry.DN + } + + return entry.GetAttributeValue(source.UserUID) +} + +// SearchEntry : search an LDAP source if an entry (name, passwd) is valid and in the specific filter +func (source *Source) SearchEntry(name, passwd string, directBind bool) *SearchResult { + // See https://tools.ietf.org/search/rfc4513#section-5.1.2 + if len(passwd) == 0 { + log.Debug("Auth. failed for %s, password cannot be empty", name) + return nil + } + l, err := dial(source) + if err != nil { + log.Error("LDAP Connect error, %s:%v", source.Host, err) + source.Enabled = false + return nil + } + defer l.Close() + + var userDN string + if directBind { + log.Trace("LDAP will bind directly via UserDN template: %s", source.UserDN) + + var ok bool + userDN, ok = source.sanitizedUserDN(name) + + if !ok { + return nil + } + + err = bindUser(l, userDN, passwd) + if err != nil { + return nil + } + + if source.UserBase != "" { + // not everyone has a CN compatible with input name so we need to find + // the real userDN in that case + + userDN, ok = source.findUserDN(l, name) + if !ok { + return nil + } + } + } else { + log.Trace("LDAP will use BindDN.") + + var found bool + + if source.BindDN != "" && source.BindPassword != "" { + err := l.Bind(source.BindDN, source.BindPassword) + if err != nil { + log.Debug("Failed to bind as BindDN[%s]: %v", source.BindDN, err) + return nil + } + log.Trace("Bound as BindDN %s", source.BindDN) + } else { + log.Trace("Proceeding with anonymous LDAP search.") + } + + userDN, found = source.findUserDN(l, name) + if !found { + return nil + } + } + + if !source.AttributesInBind { + // binds user (checking password) before looking-up attributes in user context + err = bindUser(l, userDN, passwd) + if err != nil { + return nil + } + } + + userFilter, ok := source.sanitizedUserQuery(name) + if !ok { + return nil + } + + isAttributeSSHPublicKeySet := len(strings.TrimSpace(source.AttributeSSHPublicKey)) > 0 + isAtributeAvatarSet := len(strings.TrimSpace(source.AttributeAvatar)) > 0 + + attribs := []string{source.AttributeUsername, source.AttributeName, source.AttributeSurname, source.AttributeMail} + if len(strings.TrimSpace(source.UserUID)) > 0 { + attribs = append(attribs, source.UserUID) + } + if isAttributeSSHPublicKeySet { + attribs = append(attribs, source.AttributeSSHPublicKey) + } + if isAtributeAvatarSet { + attribs = append(attribs, source.AttributeAvatar) + } + + log.Trace("Fetching attributes '%v', '%v', '%v', '%v', '%v', '%v', '%v' with filter '%s' and base '%s'", source.AttributeUsername, source.AttributeName, source.AttributeSurname, source.AttributeMail, source.AttributeSSHPublicKey, source.AttributeAvatar, source.UserUID, userFilter, userDN) + search := ldap.NewSearchRequest( + userDN, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, userFilter, + attribs, nil) + + sr, err := l.Search(search) + if err != nil { + log.Error("LDAP Search failed unexpectedly! (%v)", err) + return nil + } else if len(sr.Entries) < 1 { + if directBind { + log.Trace("User filter inhibited user login.") + } else { + log.Trace("LDAP Search found no matching entries.") + } + + return nil + } + + var sshPublicKey []string + var Avatar []byte + + username := sr.Entries[0].GetAttributeValue(source.AttributeUsername) + firstname := sr.Entries[0].GetAttributeValue(source.AttributeName) + surname := sr.Entries[0].GetAttributeValue(source.AttributeSurname) + mail := sr.Entries[0].GetAttributeValue(source.AttributeMail) + + if isAttributeSSHPublicKeySet { + sshPublicKey = sr.Entries[0].GetAttributeValues(source.AttributeSSHPublicKey) + } + + isAdmin := checkAdmin(l, source, userDN) + + var isRestricted bool + if !isAdmin { + isRestricted = checkRestricted(l, source, userDN) + } + + if isAtributeAvatarSet { + Avatar = sr.Entries[0].GetRawAttributeValue(source.AttributeAvatar) + } + + // Check group membership + var usersLdapGroups container.Set[string] + if source.GroupsEnabled { + userAttributeListedInGroup := source.getUserAttributeListedInGroup(sr.Entries[0]) + usersLdapGroups = source.listLdapGroupMemberships(l, userAttributeListedInGroup, true) + + if source.GroupFilter != "" && len(usersLdapGroups) == 0 { + return nil + } + } + + if !directBind && source.AttributesInBind { + // binds user (checking password) after looking-up attributes in BindDN context + err = bindUser(l, userDN, passwd) + if err != nil { + return nil + } + } + + return &SearchResult{ + LowerName: strings.ToLower(username), + Username: username, + Name: firstname, + Surname: surname, + Mail: mail, + SSHPublicKey: sshPublicKey, + IsAdmin: isAdmin, + IsRestricted: isRestricted, + Avatar: Avatar, + Groups: usersLdapGroups, + } +} + +// UsePagedSearch returns if need to use paged search +func (source *Source) UsePagedSearch() bool { + return source.SearchPageSize > 0 +} + +// SearchEntries : search an LDAP source for all users matching userFilter +func (source *Source) SearchEntries() ([]*SearchResult, error) { + l, err := dial(source) + if err != nil { + log.Error("LDAP Connect error, %s:%v", source.Host, err) + source.Enabled = false + return nil, err + } + defer l.Close() + + if source.BindDN != "" && source.BindPassword != "" { + err := l.Bind(source.BindDN, source.BindPassword) + if err != nil { + log.Debug("Failed to bind as BindDN[%s]: %v", source.BindDN, err) + return nil, err + } + log.Trace("Bound as BindDN %s", source.BindDN) + } else { + log.Trace("Proceeding with anonymous LDAP search.") + } + + userFilter := fmt.Sprintf(source.Filter, "*") + + isAttributeSSHPublicKeySet := len(strings.TrimSpace(source.AttributeSSHPublicKey)) > 0 + isAtributeAvatarSet := len(strings.TrimSpace(source.AttributeAvatar)) > 0 + + attribs := []string{source.AttributeUsername, source.AttributeName, source.AttributeSurname, source.AttributeMail, source.UserUID} + if isAttributeSSHPublicKeySet { + attribs = append(attribs, source.AttributeSSHPublicKey) + } + if isAtributeAvatarSet { + attribs = append(attribs, source.AttributeAvatar) + } + + log.Trace("Fetching attributes '%v', '%v', '%v', '%v', '%v', '%v' with filter %s and base %s", source.AttributeUsername, source.AttributeName, source.AttributeSurname, source.AttributeMail, source.AttributeSSHPublicKey, source.AttributeAvatar, userFilter, source.UserBase) + search := ldap.NewSearchRequest( + source.UserBase, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, userFilter, + attribs, nil) + + var sr *ldap.SearchResult + if source.UsePagedSearch() { + sr, err = l.SearchWithPaging(search, source.SearchPageSize) + } else { + sr, err = l.Search(search) + } + if err != nil { + log.Error("LDAP Search failed unexpectedly! (%v)", err) + return nil, err + } + + result := make([]*SearchResult, 0, len(sr.Entries)) + + for _, v := range sr.Entries { + var usersLdapGroups container.Set[string] + if source.GroupsEnabled { + userAttributeListedInGroup := source.getUserAttributeListedInGroup(v) + + if source.GroupFilter != "" { + usersLdapGroups = source.listLdapGroupMemberships(l, userAttributeListedInGroup, true) + if len(usersLdapGroups) == 0 { + continue + } + } + + if source.GroupTeamMap != "" || source.GroupTeamMapRemoval { + usersLdapGroups = source.listLdapGroupMemberships(l, userAttributeListedInGroup, false) + } + } + + user := &SearchResult{ + Username: v.GetAttributeValue(source.AttributeUsername), + Name: v.GetAttributeValue(source.AttributeName), + Surname: v.GetAttributeValue(source.AttributeSurname), + Mail: v.GetAttributeValue(source.AttributeMail), + IsAdmin: checkAdmin(l, source, v.DN), + Groups: usersLdapGroups, + } + + if !user.IsAdmin { + user.IsRestricted = checkRestricted(l, source, v.DN) + } + + if isAttributeSSHPublicKeySet { + user.SSHPublicKey = v.GetAttributeValues(source.AttributeSSHPublicKey) + } + + if isAtributeAvatarSet { + user.Avatar = v.GetRawAttributeValue(source.AttributeAvatar) + } + + user.LowerName = strings.ToLower(user.Username) + + result = append(result, user) + } + + return result, nil +} diff --git a/services/auth/source/ldap/source_sync.go b/services/auth/source/ldap/source_sync.go new file mode 100644 index 0000000..1f70eda --- /dev/null +++ b/services/auth/source/ldap/source_sync.go @@ -0,0 +1,232 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package ldap + +import ( + "context" + "fmt" + "strings" + + asymkey_model "code.gitea.io/gitea/models/asymkey" + "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/models/organization" + user_model "code.gitea.io/gitea/models/user" + auth_module "code.gitea.io/gitea/modules/auth" + "code.gitea.io/gitea/modules/container" + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/optional" + source_service "code.gitea.io/gitea/services/auth/source" + user_service "code.gitea.io/gitea/services/user" +) + +// Sync causes this ldap source to synchronize its users with the db +func (source *Source) Sync(ctx context.Context, updateExisting bool) error { + log.Trace("Doing: SyncExternalUsers[%s]", source.authSource.Name) + + isAttributeSSHPublicKeySet := len(strings.TrimSpace(source.AttributeSSHPublicKey)) > 0 + var sshKeysNeedUpdate bool + + // Find all users with this login type - FIXME: Should this be an iterator? + users, err := user_model.GetUsersBySource(ctx, source.authSource) + if err != nil { + log.Error("SyncExternalUsers: %v", err) + return err + } + select { + case <-ctx.Done(): + log.Warn("SyncExternalUsers: Cancelled before update of %s", source.authSource.Name) + return db.ErrCancelledf("Before update of %s", source.authSource.Name) + default: + } + + usernameUsers := make(map[string]*user_model.User, len(users)) + mailUsers := make(map[string]*user_model.User, len(users)) + keepActiveUsers := make(container.Set[int64]) + + for _, u := range users { + usernameUsers[u.LowerName] = u + mailUsers[strings.ToLower(u.Email)] = u + } + + sr, err := source.SearchEntries() + if err != nil { + log.Error("SyncExternalUsers LDAP source failure [%s], skipped", source.authSource.Name) + return nil + } + + if len(sr) == 0 { + if !source.AllowDeactivateAll { + log.Error("LDAP search found no entries but did not report an error. Refusing to deactivate all users") + return nil + } + log.Warn("LDAP search found no entries but did not report an error. All users will be deactivated as per settings") + } + + orgCache := make(map[string]*organization.Organization) + teamCache := make(map[string]*organization.Team) + + groupTeamMapping, err := auth_module.UnmarshalGroupTeamMapping(source.GroupTeamMap) + if err != nil { + return err + } + + for _, su := range sr { + select { + case <-ctx.Done(): + log.Warn("SyncExternalUsers: Cancelled at update of %s before completed update of users", source.authSource.Name) + // Rewrite authorized_keys file if LDAP Public SSH Key attribute is set and any key was added or removed + if sshKeysNeedUpdate { + err = asymkey_model.RewriteAllPublicKeys(ctx) + if err != nil { + log.Error("RewriteAllPublicKeys: %v", err) + } + } + return db.ErrCancelledf("During update of %s before completed update of users", source.authSource.Name) + default: + } + if len(su.Username) == 0 && len(su.Mail) == 0 { + continue + } + + var usr *user_model.User + if len(su.Username) > 0 { + usr = usernameUsers[su.LowerName] + } + if usr == nil && len(su.Mail) > 0 { + usr = mailUsers[strings.ToLower(su.Mail)] + } + + if usr != nil { + keepActiveUsers.Add(usr.ID) + } else if len(su.Username) == 0 { + // we cannot create the user if su.Username is empty + continue + } + + if len(su.Mail) == 0 { + domainName := source.DefaultDomainName + if len(domainName) == 0 { + domainName = "localhost.local" + } + su.Mail = fmt.Sprintf("%s@%s", su.Username, domainName) + } + + fullName := composeFullName(su.Name, su.Surname, su.Username) + // If no existing user found, create one + if usr == nil { + log.Trace("SyncExternalUsers[%s]: Creating user %s", source.authSource.Name, su.Username) + + usr = &user_model.User{ + LowerName: su.LowerName, + Name: su.Username, + FullName: fullName, + LoginType: source.authSource.Type, + LoginSource: source.authSource.ID, + LoginName: su.Username, + Email: su.Mail, + IsAdmin: su.IsAdmin, + } + overwriteDefault := &user_model.CreateUserOverwriteOptions{ + IsRestricted: optional.Some(su.IsRestricted), + IsActive: optional.Some(true), + } + + err = user_model.CreateUser(ctx, usr, overwriteDefault) + if err != nil { + log.Error("SyncExternalUsers[%s]: Error creating user %s: %v", source.authSource.Name, su.Username, err) + } + + if err == nil && isAttributeSSHPublicKeySet { + log.Trace("SyncExternalUsers[%s]: Adding LDAP Public SSH Keys for user %s", source.authSource.Name, usr.Name) + if asymkey_model.AddPublicKeysBySource(ctx, usr, source.authSource, su.SSHPublicKey) { + sshKeysNeedUpdate = true + } + } + + if err == nil && len(source.AttributeAvatar) > 0 { + _ = user_service.UploadAvatar(ctx, usr, su.Avatar) + } + } else if updateExisting { + // Synchronize SSH Public Key if that attribute is set + if isAttributeSSHPublicKeySet && asymkey_model.SynchronizePublicKeys(ctx, usr, source.authSource, su.SSHPublicKey) { + sshKeysNeedUpdate = true + } + + // Check if user data has changed + if (len(source.AdminFilter) > 0 && usr.IsAdmin != su.IsAdmin) || + (len(source.RestrictedFilter) > 0 && usr.IsRestricted != su.IsRestricted) || + !strings.EqualFold(usr.Email, su.Mail) || + usr.FullName != fullName || + !usr.IsActive { + log.Trace("SyncExternalUsers[%s]: Updating user %s", source.authSource.Name, usr.Name) + + opts := &user_service.UpdateOptions{ + FullName: optional.Some(fullName), + IsActive: optional.Some(true), + } + if source.AdminFilter != "" { + opts.IsAdmin = optional.Some(su.IsAdmin) + } + // Change existing restricted flag only if RestrictedFilter option is set + if !su.IsAdmin && source.RestrictedFilter != "" { + opts.IsRestricted = optional.Some(su.IsRestricted) + } + + if err := user_service.UpdateUser(ctx, usr, opts); err != nil { + log.Error("SyncExternalUsers[%s]: Error updating user %s: %v", source.authSource.Name, usr.Name, err) + } + + if err := user_service.ReplacePrimaryEmailAddress(ctx, usr, su.Mail); err != nil { + log.Error("SyncExternalUsers[%s]: Error updating user %s primary email %s: %v", source.authSource.Name, usr.Name, su.Mail, err) + } + } + + if usr.IsUploadAvatarChanged(su.Avatar) { + if err == nil && len(source.AttributeAvatar) > 0 { + _ = user_service.UploadAvatar(ctx, usr, su.Avatar) + } + } + } + // Synchronize LDAP groups with organization and team memberships + if source.GroupsEnabled && (source.GroupTeamMap != "" || source.GroupTeamMapRemoval) { + if err := source_service.SyncGroupsToTeamsCached(ctx, usr, su.Groups, groupTeamMapping, source.GroupTeamMapRemoval, orgCache, teamCache); err != nil { + log.Error("SyncGroupsToTeamsCached: %v", err) + } + } + } + + // Rewrite authorized_keys file if LDAP Public SSH Key attribute is set and any key was added or removed + if sshKeysNeedUpdate { + err = asymkey_model.RewriteAllPublicKeys(ctx) + if err != nil { + log.Error("RewriteAllPublicKeys: %v", err) + } + } + + select { + case <-ctx.Done(): + log.Warn("SyncExternalUsers: Cancelled during update of %s before delete users", source.authSource.Name) + return db.ErrCancelledf("During update of %s before delete users", source.authSource.Name) + default: + } + + // Deactivate users not present in LDAP + if updateExisting { + for _, usr := range users { + if keepActiveUsers.Contains(usr.ID) { + continue + } + + log.Trace("SyncExternalUsers[%s]: Deactivating user %s", source.authSource.Name, usr.Name) + + opts := &user_service.UpdateOptions{ + IsActive: optional.Some(false), + } + if err := user_service.UpdateUser(ctx, usr, opts); err != nil { + log.Error("SyncExternalUsers[%s]: Error deactivating user %s: %v", source.authSource.Name, usr.Name, err) + } + } + } + return nil +} diff --git a/services/auth/source/ldap/util.go b/services/auth/source/ldap/util.go new file mode 100644 index 0000000..bd11e2d --- /dev/null +++ b/services/auth/source/ldap/util.go @@ -0,0 +1,18 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package ldap + +// composeFullName composes a firstname surname or username +func composeFullName(firstname, surname, username string) string { + switch { + case len(firstname) == 0 && len(surname) == 0: + return username + case len(firstname) == 0: + return surname + case len(surname) == 0: + return firstname + default: + return firstname + " " + surname + } +} diff --git a/services/auth/source/oauth2/assert_interface_test.go b/services/auth/source/oauth2/assert_interface_test.go new file mode 100644 index 0000000..56fe0e4 --- /dev/null +++ b/services/auth/source/oauth2/assert_interface_test.go @@ -0,0 +1,22 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package oauth2_test + +import ( + auth_model "code.gitea.io/gitea/models/auth" + "code.gitea.io/gitea/services/auth" + "code.gitea.io/gitea/services/auth/source/oauth2" +) + +// This test file exists to assert that our Source exposes the interfaces that we expect +// It tightly binds the interfaces and implementation without breaking go import cycles + +type sourceInterface interface { + auth_model.Config + auth_model.SourceSettable + auth_model.RegisterableSource + auth.PasswordAuthenticator +} + +var _ (sourceInterface) = &oauth2.Source{} diff --git a/services/auth/source/oauth2/init.go b/services/auth/source/oauth2/init.go new file mode 100644 index 0000000..5c25681 --- /dev/null +++ b/services/auth/source/oauth2/init.go @@ -0,0 +1,86 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package oauth2 + +import ( + "context" + "encoding/gob" + "net/http" + "sync" + + "code.gitea.io/gitea/models/auth" + "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/optional" + "code.gitea.io/gitea/modules/setting" + + "github.com/google/uuid" + "github.com/gorilla/sessions" + "github.com/markbates/goth/gothic" +) + +var gothRWMutex = sync.RWMutex{} + +// UsersStoreKey is the key for the store +const UsersStoreKey = "gitea-oauth2-sessions" + +// ProviderHeaderKey is the HTTP header key +const ProviderHeaderKey = "gitea-oauth2-provider" + +// Init initializes the oauth source +func Init(ctx context.Context) error { + if err := InitSigningKey(); err != nil { + return err + } + + // Lock our mutex + gothRWMutex.Lock() + + gob.Register(&sessions.Session{}) + + gothic.Store = &SessionsStore{ + maxLength: int64(setting.OAuth2.MaxTokenLength), + } + + gothic.SetState = func(req *http.Request) string { + return uuid.New().String() + } + + gothic.GetProviderName = func(req *http.Request) (string, error) { + return req.Header.Get(ProviderHeaderKey), nil + } + + // Unlock our mutex + gothRWMutex.Unlock() + + return initOAuth2Sources(ctx) +} + +// ResetOAuth2 clears existing OAuth2 providers and loads them from DB +func ResetOAuth2(ctx context.Context) error { + ClearProviders() + return initOAuth2Sources(ctx) +} + +// initOAuth2Sources is used to load and register all active OAuth2 providers +func initOAuth2Sources(ctx context.Context) error { + authSources, err := db.Find[auth.Source](ctx, auth.FindSourcesOptions{ + IsActive: optional.Some(true), + LoginType: auth.OAuth2, + }) + if err != nil { + return err + } + for _, source := range authSources { + oauth2Source, ok := source.Cfg.(*Source) + if !ok { + continue + } + err := oauth2Source.RegisterSource() + if err != nil { + log.Critical("Unable to register source: %s due to Error: %v.", source.Name, err) + } + } + return nil +} diff --git a/services/auth/source/oauth2/jwtsigningkey.go b/services/auth/source/oauth2/jwtsigningkey.go new file mode 100644 index 0000000..070fffe --- /dev/null +++ b/services/auth/source/oauth2/jwtsigningkey.go @@ -0,0 +1,404 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package oauth2 + +import ( + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/base64" + "encoding/pem" + "fmt" + "math/big" + "os" + "path/filepath" + "strings" + + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/setting" + "code.gitea.io/gitea/modules/util" + + "github.com/golang-jwt/jwt/v5" +) + +// ErrInvalidAlgorithmType represents an invalid algorithm error. +type ErrInvalidAlgorithmType struct { + Algorithm string +} + +func (err ErrInvalidAlgorithmType) Error() string { + return fmt.Sprintf("JWT signing algorithm is not supported: %s", err.Algorithm) +} + +// JWTSigningKey represents a algorithm/key pair to sign JWTs +type JWTSigningKey interface { + IsSymmetric() bool + SigningMethod() jwt.SigningMethod + SignKey() any + VerifyKey() any + ToJWK() (map[string]string, error) + PreProcessToken(*jwt.Token) +} + +type hmacSigningKey struct { + signingMethod jwt.SigningMethod + secret []byte +} + +func (key hmacSigningKey) IsSymmetric() bool { + return true +} + +func (key hmacSigningKey) SigningMethod() jwt.SigningMethod { + return key.signingMethod +} + +func (key hmacSigningKey) SignKey() any { + return key.secret +} + +func (key hmacSigningKey) VerifyKey() any { + return key.secret +} + +func (key hmacSigningKey) ToJWK() (map[string]string, error) { + return map[string]string{ + "kty": "oct", + "alg": key.SigningMethod().Alg(), + }, nil +} + +func (key hmacSigningKey) PreProcessToken(*jwt.Token) {} + +type rsaSingingKey struct { + signingMethod jwt.SigningMethod + key *rsa.PrivateKey + id string +} + +func newRSASingingKey(signingMethod jwt.SigningMethod, key *rsa.PrivateKey) (rsaSingingKey, error) { + kid, err := util.CreatePublicKeyFingerprint(key.Public().(*rsa.PublicKey)) + if err != nil { + return rsaSingingKey{}, err + } + + return rsaSingingKey{ + signingMethod, + key, + base64.RawURLEncoding.EncodeToString(kid), + }, nil +} + +func (key rsaSingingKey) IsSymmetric() bool { + return false +} + +func (key rsaSingingKey) SigningMethod() jwt.SigningMethod { + return key.signingMethod +} + +func (key rsaSingingKey) SignKey() any { + return key.key +} + +func (key rsaSingingKey) VerifyKey() any { + return key.key.Public() +} + +func (key rsaSingingKey) ToJWK() (map[string]string, error) { + pubKey := key.key.Public().(*rsa.PublicKey) + + return map[string]string{ + "kty": "RSA", + "alg": key.SigningMethod().Alg(), + "kid": key.id, + "e": base64.RawURLEncoding.EncodeToString(big.NewInt(int64(pubKey.E)).Bytes()), + "n": base64.RawURLEncoding.EncodeToString(pubKey.N.Bytes()), + }, nil +} + +func (key rsaSingingKey) PreProcessToken(token *jwt.Token) { + token.Header["kid"] = key.id +} + +type eddsaSigningKey struct { + signingMethod jwt.SigningMethod + key ed25519.PrivateKey + id string +} + +func newEdDSASingingKey(signingMethod jwt.SigningMethod, key ed25519.PrivateKey) (eddsaSigningKey, error) { + kid, err := util.CreatePublicKeyFingerprint(key.Public().(ed25519.PublicKey)) + if err != nil { + return eddsaSigningKey{}, err + } + + return eddsaSigningKey{ + signingMethod, + key, + base64.RawURLEncoding.EncodeToString(kid), + }, nil +} + +func (key eddsaSigningKey) IsSymmetric() bool { + return false +} + +func (key eddsaSigningKey) SigningMethod() jwt.SigningMethod { + return key.signingMethod +} + +func (key eddsaSigningKey) SignKey() any { + return key.key +} + +func (key eddsaSigningKey) VerifyKey() any { + return key.key.Public() +} + +func (key eddsaSigningKey) ToJWK() (map[string]string, error) { + pubKey := key.key.Public().(ed25519.PublicKey) + + return map[string]string{ + "alg": key.SigningMethod().Alg(), + "kid": key.id, + "kty": "OKP", + "crv": "Ed25519", + "x": base64.RawURLEncoding.EncodeToString(pubKey), + }, nil +} + +func (key eddsaSigningKey) PreProcessToken(token *jwt.Token) { + token.Header["kid"] = key.id +} + +type ecdsaSingingKey struct { + signingMethod jwt.SigningMethod + key *ecdsa.PrivateKey + id string +} + +func newECDSASingingKey(signingMethod jwt.SigningMethod, key *ecdsa.PrivateKey) (ecdsaSingingKey, error) { + kid, err := util.CreatePublicKeyFingerprint(key.Public().(*ecdsa.PublicKey)) + if err != nil { + return ecdsaSingingKey{}, err + } + + return ecdsaSingingKey{ + signingMethod, + key, + base64.RawURLEncoding.EncodeToString(kid), + }, nil +} + +func (key ecdsaSingingKey) IsSymmetric() bool { + return false +} + +func (key ecdsaSingingKey) SigningMethod() jwt.SigningMethod { + return key.signingMethod +} + +func (key ecdsaSingingKey) SignKey() any { + return key.key +} + +func (key ecdsaSingingKey) VerifyKey() any { + return key.key.Public() +} + +func (key ecdsaSingingKey) ToJWK() (map[string]string, error) { + pubKey := key.key.Public().(*ecdsa.PublicKey) + + return map[string]string{ + "kty": "EC", + "alg": key.SigningMethod().Alg(), + "kid": key.id, + "crv": pubKey.Params().Name, + "x": base64.RawURLEncoding.EncodeToString(pubKey.X.Bytes()), + "y": base64.RawURLEncoding.EncodeToString(pubKey.Y.Bytes()), + }, nil +} + +func (key ecdsaSingingKey) PreProcessToken(token *jwt.Token) { + token.Header["kid"] = key.id +} + +// CreateJWTSigningKey creates a signing key from an algorithm / key pair. +func CreateJWTSigningKey(algorithm string, key any) (JWTSigningKey, error) { + var signingMethod jwt.SigningMethod + switch algorithm { + case "HS256": + signingMethod = jwt.SigningMethodHS256 + case "HS384": + signingMethod = jwt.SigningMethodHS384 + case "HS512": + signingMethod = jwt.SigningMethodHS512 + + case "RS256": + signingMethod = jwt.SigningMethodRS256 + case "RS384": + signingMethod = jwt.SigningMethodRS384 + case "RS512": + signingMethod = jwt.SigningMethodRS512 + + case "ES256": + signingMethod = jwt.SigningMethodES256 + case "ES384": + signingMethod = jwt.SigningMethodES384 + case "ES512": + signingMethod = jwt.SigningMethodES512 + case "EdDSA": + signingMethod = jwt.SigningMethodEdDSA + default: + return nil, ErrInvalidAlgorithmType{algorithm} + } + + switch signingMethod.(type) { + case *jwt.SigningMethodEd25519: + privateKey, ok := key.(ed25519.PrivateKey) + if !ok { + return nil, jwt.ErrInvalidKeyType + } + return newEdDSASingingKey(signingMethod, privateKey) + case *jwt.SigningMethodECDSA: + privateKey, ok := key.(*ecdsa.PrivateKey) + if !ok { + return nil, jwt.ErrInvalidKeyType + } + return newECDSASingingKey(signingMethod, privateKey) + case *jwt.SigningMethodRSA: + privateKey, ok := key.(*rsa.PrivateKey) + if !ok { + return nil, jwt.ErrInvalidKeyType + } + return newRSASingingKey(signingMethod, privateKey) + default: + secret, ok := key.([]byte) + if !ok { + return nil, jwt.ErrInvalidKeyType + } + return hmacSigningKey{signingMethod, secret}, nil + } +} + +// DefaultSigningKey is the default signing key for JWTs. +var DefaultSigningKey JWTSigningKey + +// InitSigningKey creates the default signing key from settings or creates a random key. +func InitSigningKey() error { + var err error + var key any + + switch setting.OAuth2.JWTSigningAlgorithm { + case "HS256": + fallthrough + case "HS384": + fallthrough + case "HS512": + key = setting.GetGeneralTokenSigningSecret() + case "RS256": + fallthrough + case "RS384": + fallthrough + case "RS512": + fallthrough + case "ES256": + fallthrough + case "ES384": + fallthrough + case "ES512": + fallthrough + case "EdDSA": + key, err = loadOrCreateAsymmetricKey() + default: + return ErrInvalidAlgorithmType{setting.OAuth2.JWTSigningAlgorithm} + } + + if err != nil { + return fmt.Errorf("Error while loading or creating JWT key: %w", err) + } + + signingKey, err := CreateJWTSigningKey(setting.OAuth2.JWTSigningAlgorithm, key) + if err != nil { + return err + } + + DefaultSigningKey = signingKey + + return nil +} + +// loadOrCreateAsymmetricKey checks if the configured private key exists. +// If it does not exist a new random key gets generated and saved on the configured path. +func loadOrCreateAsymmetricKey() (any, error) { + keyPath := setting.OAuth2.JWTSigningPrivateKeyFile + + isExist, err := util.IsExist(keyPath) + if err != nil { + log.Fatal("Unable to check if %s exists. Error: %v", keyPath, err) + } + if !isExist { + err := func() error { + key, err := func() (any, error) { + switch { + case strings.HasPrefix(setting.OAuth2.JWTSigningAlgorithm, "RS"): + return rsa.GenerateKey(rand.Reader, 4096) + case setting.OAuth2.JWTSigningAlgorithm == "EdDSA": + _, pk, err := ed25519.GenerateKey(rand.Reader) + return pk, err + default: + return ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + } + }() + if err != nil { + return err + } + + bytes, err := x509.MarshalPKCS8PrivateKey(key) + if err != nil { + return err + } + + privateKeyPEM := &pem.Block{Type: "PRIVATE KEY", Bytes: bytes} + + if err := os.MkdirAll(filepath.Dir(keyPath), os.ModePerm); err != nil { + return err + } + + f, err := os.OpenFile(keyPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o600) + if err != nil { + return err + } + defer func() { + if err = f.Close(); err != nil { + log.Error("Close: %v", err) + } + }() + + return pem.Encode(f, privateKeyPEM) + }() + if err != nil { + log.Fatal("Error generating private key: %v", err) + return nil, err + } + } + + bytes, err := os.ReadFile(keyPath) + if err != nil { + return nil, err + } + + block, _ := pem.Decode(bytes) + if block == nil { + return nil, fmt.Errorf("no valid PEM data found in %s", keyPath) + } else if block.Type != "PRIVATE KEY" { + return nil, fmt.Errorf("expected PRIVATE KEY, got %s in %s", block.Type, keyPath) + } + + return x509.ParsePKCS8PrivateKey(block.Bytes) +} diff --git a/services/auth/source/oauth2/main_test.go b/services/auth/source/oauth2/main_test.go new file mode 100644 index 0000000..57c74fd --- /dev/null +++ b/services/auth/source/oauth2/main_test.go @@ -0,0 +1,14 @@ +// Copyright 2024 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package oauth2 + +import ( + "testing" + + "code.gitea.io/gitea/models/unittest" +) + +func TestMain(m *testing.M) { + unittest.MainTest(m, &unittest.TestOptions{}) +} diff --git a/services/auth/source/oauth2/providers.go b/services/auth/source/oauth2/providers.go new file mode 100644 index 0000000..f2c1bb4 --- /dev/null +++ b/services/auth/source/oauth2/providers.go @@ -0,0 +1,190 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package oauth2 + +import ( + "context" + "errors" + "fmt" + "html" + "html/template" + "net/url" + "sort" + + "code.gitea.io/gitea/models/auth" + "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/optional" + "code.gitea.io/gitea/modules/setting" + + "github.com/markbates/goth" +) + +// Provider is an interface for describing a single OAuth2 provider +type Provider interface { + Name() string + DisplayName() string + IconHTML(size int) template.HTML + CustomURLSettings() *CustomURLSettings +} + +// GothProviderCreator provides a function to create a goth.Provider +type GothProviderCreator interface { + CreateGothProvider(providerName, callbackURL string, source *Source) (goth.Provider, error) +} + +// GothProvider is an interface for describing a single OAuth2 provider +type GothProvider interface { + Provider + GothProviderCreator +} + +// AuthSourceProvider provides a provider for an AuthSource. Multiple auth sources could use the same registered GothProvider +// So each auth source should have its own DisplayName and IconHTML for display. +// The Name is the GothProvider's name, to help to find the GothProvider to sign in. +// The DisplayName is the auth source config's name, site admin set it on the admin page, the IconURL can also be set there. +type AuthSourceProvider struct { + GothProvider + sourceName, iconURL string +} + +func (p *AuthSourceProvider) Name() string { + return p.GothProvider.Name() +} + +func (p *AuthSourceProvider) DisplayName() string { + return p.sourceName +} + +func (p *AuthSourceProvider) IconHTML(size int) template.HTML { + if p.iconURL != "" { + img := fmt.Sprintf(`<img class="tw-object-contain tw-mr-2" width="%d" height="%d" src="%s" alt="%s">`, + size, + size, + html.EscapeString(p.iconURL), html.EscapeString(p.DisplayName()), + ) + return template.HTML(img) + } + return p.GothProvider.IconHTML(size) +} + +// Providers contains the map of registered OAuth2 providers in Gitea (based on goth) +// key is used to map the OAuth2Provider with the goth provider type (also in AuthSource.OAuth2Config.Provider) +// value is used to store display data +var gothProviders = map[string]GothProvider{} + +// RegisterGothProvider registers a GothProvider +func RegisterGothProvider(provider GothProvider) { + if _, has := gothProviders[provider.Name()]; has { + log.Fatal("Duplicate oauth2provider type provided: %s", provider.Name()) + } + gothProviders[provider.Name()] = provider +} + +// GetSupportedOAuth2Providers returns the map of unconfigured OAuth2 providers +// key is used as technical name (like in the callbackURL) +// values to display +func GetSupportedOAuth2Providers() []Provider { + providers := make([]Provider, 0, len(gothProviders)) + + for _, provider := range gothProviders { + providers = append(providers, provider) + } + sort.Slice(providers, func(i, j int) bool { + return providers[i].Name() < providers[j].Name() + }) + return providers +} + +func CreateProviderFromSource(source *auth.Source) (Provider, error) { + oauth2Cfg, ok := source.Cfg.(*Source) + if !ok { + return nil, fmt.Errorf("invalid OAuth2 source config: %v", oauth2Cfg) + } + gothProv := gothProviders[oauth2Cfg.Provider] + return &AuthSourceProvider{GothProvider: gothProv, sourceName: source.Name, iconURL: oauth2Cfg.IconURL}, nil +} + +// GetOAuth2Providers returns the list of configured OAuth2 providers +func GetOAuth2Providers(ctx context.Context, isActive optional.Option[bool]) ([]Provider, error) { + authSources, err := db.Find[auth.Source](ctx, auth.FindSourcesOptions{ + IsActive: isActive, + LoginType: auth.OAuth2, + }) + if err != nil { + return nil, err + } + + providers := make([]Provider, 0, len(authSources)) + for _, source := range authSources { + provider, err := CreateProviderFromSource(source) + if err != nil { + return nil, err + } + providers = append(providers, provider) + } + + sort.Slice(providers, func(i, j int) bool { + return providers[i].Name() < providers[j].Name() + }) + + return providers, nil +} + +// RegisterProviderWithGothic register a OAuth2 provider in goth lib +func RegisterProviderWithGothic(providerName string, source *Source) error { + provider, err := createProvider(providerName, source) + + if err == nil && provider != nil { + gothRWMutex.Lock() + defer gothRWMutex.Unlock() + + goth.UseProviders(provider) + } + + return err +} + +// RemoveProviderFromGothic removes the given OAuth2 provider from the goth lib +func RemoveProviderFromGothic(providerName string) { + gothRWMutex.Lock() + defer gothRWMutex.Unlock() + + delete(goth.GetProviders(), providerName) +} + +// ClearProviders clears all OAuth2 providers from the goth lib +func ClearProviders() { + gothRWMutex.Lock() + defer gothRWMutex.Unlock() + + goth.ClearProviders() +} + +var ErrAuthSourceNotActivated = errors.New("auth source is not activated") + +// used to create different types of goth providers +func createProvider(providerName string, source *Source) (goth.Provider, error) { + callbackURL := setting.AppURL + "user/oauth2/" + url.PathEscape(providerName) + "/callback" + + var provider goth.Provider + var err error + + p, ok := gothProviders[source.Provider] + if !ok { + return nil, ErrAuthSourceNotActivated + } + + provider, err = p.CreateGothProvider(providerName, callbackURL, source) + if err != nil { + return provider, err + } + + // always set the name if provider is created so we can support multiple setups of 1 provider + if provider != nil { + provider.SetName(providerName) + } + + return provider, err +} diff --git a/services/auth/source/oauth2/providers_base.go b/services/auth/source/oauth2/providers_base.go new file mode 100644 index 0000000..9d4ab10 --- /dev/null +++ b/services/auth/source/oauth2/providers_base.go @@ -0,0 +1,51 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package oauth2 + +import ( + "html/template" + + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/svg" +) + +// BaseProvider represents a common base for Provider +type BaseProvider struct { + name string + displayName string +} + +// Name provides the technical name for this provider +func (b *BaseProvider) Name() string { + return b.name +} + +// DisplayName returns the friendly name for this provider +func (b *BaseProvider) DisplayName() string { + return b.displayName +} + +// IconHTML returns icon HTML for this provider +func (b *BaseProvider) IconHTML(size int) template.HTML { + svgName := "gitea-" + b.name + switch b.name { + case "gplus": + svgName = "gitea-google" + case "github": + svgName = "octicon-mark-github" + } + svgHTML := svg.RenderHTML(svgName, size, "tw-mr-2") + if svgHTML == "" { + log.Error("No SVG icon for oauth2 provider %q", b.name) + svgHTML = svg.RenderHTML("gitea-openid", size, "tw-mr-2") + } + return svgHTML +} + +// CustomURLSettings returns the custom url settings for this provider +func (b *BaseProvider) CustomURLSettings() *CustomURLSettings { + return nil +} + +var _ Provider = &BaseProvider{} diff --git a/services/auth/source/oauth2/providers_custom.go b/services/auth/source/oauth2/providers_custom.go new file mode 100644 index 0000000..65cf538 --- /dev/null +++ b/services/auth/source/oauth2/providers_custom.go @@ -0,0 +1,123 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package oauth2 + +import ( + "code.gitea.io/gitea/modules/setting" + + "github.com/markbates/goth" + "github.com/markbates/goth/providers/azureadv2" + "github.com/markbates/goth/providers/gitea" + "github.com/markbates/goth/providers/github" + "github.com/markbates/goth/providers/gitlab" + "github.com/markbates/goth/providers/mastodon" + "github.com/markbates/goth/providers/nextcloud" +) + +// CustomProviderNewFn creates a goth.Provider using a custom url mapping +type CustomProviderNewFn func(clientID, secret, callbackURL string, custom *CustomURLMapping, scopes []string) (goth.Provider, error) + +// CustomProvider is a GothProvider that has CustomURL features +type CustomProvider struct { + BaseProvider + customURLSettings *CustomURLSettings + newFn CustomProviderNewFn +} + +// CustomURLSettings returns the CustomURLSettings for this provider +func (c *CustomProvider) CustomURLSettings() *CustomURLSettings { + return c.customURLSettings +} + +// CreateGothProvider creates a GothProvider from this Provider +func (c *CustomProvider) CreateGothProvider(providerName, callbackURL string, source *Source) (goth.Provider, error) { + custom := c.customURLSettings.OverrideWith(source.CustomURLMapping) + + return c.newFn(source.ClientID, source.ClientSecret, callbackURL, custom, source.Scopes) +} + +// NewCustomProvider is a constructor function for custom providers +func NewCustomProvider(name, displayName string, customURLSetting *CustomURLSettings, newFn CustomProviderNewFn) *CustomProvider { + return &CustomProvider{ + BaseProvider: BaseProvider{ + name: name, + displayName: displayName, + }, + customURLSettings: customURLSetting, + newFn: newFn, + } +} + +var _ GothProvider = &CustomProvider{} + +func init() { + RegisterGothProvider(NewCustomProvider( + "github", "GitHub", &CustomURLSettings{ + TokenURL: availableAttribute(github.TokenURL), + AuthURL: availableAttribute(github.AuthURL), + ProfileURL: availableAttribute(github.ProfileURL), + EmailURL: availableAttribute(github.EmailURL), + }, + func(clientID, secret, callbackURL string, custom *CustomURLMapping, scopes []string) (goth.Provider, error) { + if setting.OAuth2Client.EnableAutoRegistration { + scopes = append(scopes, "user:email") + } + return github.NewCustomisedURL(clientID, secret, callbackURL, custom.AuthURL, custom.TokenURL, custom.ProfileURL, custom.EmailURL, scopes...), nil + })) + + RegisterGothProvider(NewCustomProvider( + "gitlab", "GitLab", &CustomURLSettings{ + AuthURL: availableAttribute(gitlab.AuthURL), + TokenURL: availableAttribute(gitlab.TokenURL), + ProfileURL: availableAttribute(gitlab.ProfileURL), + }, func(clientID, secret, callbackURL string, custom *CustomURLMapping, scopes []string) (goth.Provider, error) { + scopes = append(scopes, "read_user") + return gitlab.NewCustomisedURL(clientID, secret, callbackURL, custom.AuthURL, custom.TokenURL, custom.ProfileURL, scopes...), nil + })) + + RegisterGothProvider(NewCustomProvider( + "gitea", "Gitea", &CustomURLSettings{ + TokenURL: requiredAttribute(gitea.TokenURL), + AuthURL: requiredAttribute(gitea.AuthURL), + ProfileURL: requiredAttribute(gitea.ProfileURL), + }, + func(clientID, secret, callbackURL string, custom *CustomURLMapping, scopes []string) (goth.Provider, error) { + return gitea.NewCustomisedURL(clientID, secret, callbackURL, custom.AuthURL, custom.TokenURL, custom.ProfileURL, scopes...), nil + })) + + RegisterGothProvider(NewCustomProvider( + "nextcloud", "Nextcloud", &CustomURLSettings{ + TokenURL: requiredAttribute(nextcloud.TokenURL), + AuthURL: requiredAttribute(nextcloud.AuthURL), + ProfileURL: requiredAttribute(nextcloud.ProfileURL), + }, + func(clientID, secret, callbackURL string, custom *CustomURLMapping, scopes []string) (goth.Provider, error) { + return nextcloud.NewCustomisedURL(clientID, secret, callbackURL, custom.AuthURL, custom.TokenURL, custom.ProfileURL, scopes...), nil + })) + + RegisterGothProvider(NewCustomProvider( + "mastodon", "Mastodon", &CustomURLSettings{ + AuthURL: requiredAttribute(mastodon.InstanceURL), + }, + func(clientID, secret, callbackURL string, custom *CustomURLMapping, scopes []string) (goth.Provider, error) { + return mastodon.NewCustomisedURL(clientID, secret, callbackURL, custom.AuthURL, scopes...), nil + })) + + RegisterGothProvider(NewCustomProvider( + "azureadv2", "Azure AD v2", &CustomURLSettings{ + Tenant: requiredAttribute("organizations"), + }, + func(clientID, secret, callbackURL string, custom *CustomURLMapping, scopes []string) (goth.Provider, error) { + azureScopes := make([]azureadv2.ScopeType, len(scopes)) + for i, scope := range scopes { + azureScopes[i] = azureadv2.ScopeType(scope) + } + + return azureadv2.New(clientID, secret, callbackURL, azureadv2.ProviderOptions{ + Tenant: azureadv2.TenantType(custom.Tenant), + Scopes: azureScopes, + }), nil + }, + )) +} diff --git a/services/auth/source/oauth2/providers_openid.go b/services/auth/source/oauth2/providers_openid.go new file mode 100644 index 0000000..285876d --- /dev/null +++ b/services/auth/source/oauth2/providers_openid.go @@ -0,0 +1,58 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package oauth2 + +import ( + "html/template" + + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/setting" + "code.gitea.io/gitea/modules/svg" + + "github.com/markbates/goth" + "github.com/markbates/goth/providers/openidConnect" +) + +// OpenIDProvider is a GothProvider for OpenID +type OpenIDProvider struct{} + +// Name provides the technical name for this provider +func (o *OpenIDProvider) Name() string { + return "openidConnect" +} + +// DisplayName returns the friendly name for this provider +func (o *OpenIDProvider) DisplayName() string { + return "OpenID Connect" +} + +// IconHTML returns icon HTML for this provider +func (o *OpenIDProvider) IconHTML(size int) template.HTML { + return svg.RenderHTML("gitea-openid", size, "tw-mr-2") +} + +// CreateGothProvider creates a GothProvider from this Provider +func (o *OpenIDProvider) CreateGothProvider(providerName, callbackURL string, source *Source) (goth.Provider, error) { + scopes := setting.OAuth2Client.OpenIDConnectScopes + if len(scopes) == 0 { + scopes = append(scopes, source.Scopes...) + } + + provider, err := openidConnect.New(source.ClientID, source.ClientSecret, callbackURL, source.OpenIDConnectAutoDiscoveryURL, scopes...) + if err != nil { + log.Warn("Failed to create OpenID Connect Provider with name '%s' with url '%s': %v", providerName, source.OpenIDConnectAutoDiscoveryURL, err) + } + return provider, err +} + +// CustomURLSettings returns the custom url settings for this provider +func (o *OpenIDProvider) CustomURLSettings() *CustomURLSettings { + return nil +} + +var _ GothProvider = &OpenIDProvider{} + +func init() { + RegisterGothProvider(&OpenIDProvider{}) +} diff --git a/services/auth/source/oauth2/providers_simple.go b/services/auth/source/oauth2/providers_simple.go new file mode 100644 index 0000000..e95323a --- /dev/null +++ b/services/auth/source/oauth2/providers_simple.go @@ -0,0 +1,109 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package oauth2 + +import ( + "code.gitea.io/gitea/modules/setting" + + "github.com/markbates/goth" + "github.com/markbates/goth/providers/azuread" + "github.com/markbates/goth/providers/bitbucket" + "github.com/markbates/goth/providers/discord" + "github.com/markbates/goth/providers/dropbox" + "github.com/markbates/goth/providers/facebook" + "github.com/markbates/goth/providers/google" + "github.com/markbates/goth/providers/microsoftonline" + "github.com/markbates/goth/providers/twitter" + "github.com/markbates/goth/providers/yandex" +) + +// SimpleProviderNewFn create goth.Providers without custom url features +type SimpleProviderNewFn func(clientKey, secret, callbackURL string, scopes ...string) goth.Provider + +// SimpleProvider is a GothProvider which does not have custom url features +type SimpleProvider struct { + BaseProvider + scopes []string + newFn SimpleProviderNewFn +} + +// CreateGothProvider creates a GothProvider from this Provider +func (c *SimpleProvider) CreateGothProvider(providerName, callbackURL string, source *Source) (goth.Provider, error) { + scopes := make([]string, len(c.scopes)+len(source.Scopes)) + copy(scopes, c.scopes) + copy(scopes[len(c.scopes):], source.Scopes) + return c.newFn(source.ClientID, source.ClientSecret, callbackURL, scopes...), nil +} + +// NewSimpleProvider is a constructor function for simple providers +func NewSimpleProvider(name, displayName string, scopes []string, newFn SimpleProviderNewFn) *SimpleProvider { + return &SimpleProvider{ + BaseProvider: BaseProvider{ + name: name, + displayName: displayName, + }, + scopes: scopes, + newFn: newFn, + } +} + +var _ GothProvider = &SimpleProvider{} + +func init() { + RegisterGothProvider( + NewSimpleProvider("bitbucket", "Bitbucket", []string{"account"}, + func(clientKey, secret, callbackURL string, scopes ...string) goth.Provider { + return bitbucket.New(clientKey, secret, callbackURL, scopes...) + })) + + RegisterGothProvider( + NewSimpleProvider("dropbox", "Dropbox", nil, + func(clientKey, secret, callbackURL string, scopes ...string) goth.Provider { + return dropbox.New(clientKey, secret, callbackURL, scopes...) + })) + + RegisterGothProvider(NewSimpleProvider("facebook", "Facebook", nil, + func(clientKey, secret, callbackURL string, scopes ...string) goth.Provider { + return facebook.New(clientKey, secret, callbackURL, scopes...) + })) + + // named gplus due to legacy gplus -> google migration (Google killed Google+). This ensures old connections still work + RegisterGothProvider(NewSimpleProvider("gplus", "Google", []string{"email"}, + func(clientKey, secret, callbackURL string, scopes ...string) goth.Provider { + if setting.OAuth2Client.UpdateAvatar || setting.OAuth2Client.EnableAutoRegistration { + scopes = append(scopes, "profile") + } + return google.New(clientKey, secret, callbackURL, scopes...) + })) + + RegisterGothProvider(NewSimpleProvider("twitter", "Twitter", nil, + func(clientKey, secret, callbackURL string, scopes ...string) goth.Provider { + return twitter.New(clientKey, secret, callbackURL) + })) + + RegisterGothProvider(NewSimpleProvider("discord", "Discord", []string{discord.ScopeIdentify, discord.ScopeEmail}, + func(clientKey, secret, callbackURL string, scopes ...string) goth.Provider { + return discord.New(clientKey, secret, callbackURL, scopes...) + })) + + // See https://tech.yandex.com/passport/doc/dg/reference/response-docpage/ + RegisterGothProvider(NewSimpleProvider("yandex", "Yandex", []string{"login:email", "login:info", "login:avatar"}, + func(clientKey, secret, callbackURL string, scopes ...string) goth.Provider { + return yandex.New(clientKey, secret, callbackURL, scopes...) + })) + + RegisterGothProvider(NewSimpleProvider( + "azuread", "Azure AD", nil, + func(clientID, secret, callbackURL string, scopes ...string) goth.Provider { + return azuread.New(clientID, secret, callbackURL, nil, scopes...) + }, + )) + + RegisterGothProvider(NewSimpleProvider( + "microsoftonline", "Microsoft Online", nil, + func(clientID, secret, callbackURL string, scopes ...string) goth.Provider { + return microsoftonline.New(clientID, secret, callbackURL, scopes...) + }, + )) +} diff --git a/services/auth/source/oauth2/providers_test.go b/services/auth/source/oauth2/providers_test.go new file mode 100644 index 0000000..353816c --- /dev/null +++ b/services/auth/source/oauth2/providers_test.go @@ -0,0 +1,62 @@ +// Copyright 2024 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package oauth2 + +import ( + "time" + + "github.com/markbates/goth" + "golang.org/x/oauth2" +) + +type fakeProvider struct{} + +func (p *fakeProvider) Name() string { + return "fake" +} + +func (p *fakeProvider) SetName(name string) {} + +func (p *fakeProvider) BeginAuth(state string) (goth.Session, error) { + return nil, nil +} + +func (p *fakeProvider) UnmarshalSession(string) (goth.Session, error) { + return nil, nil +} + +func (p *fakeProvider) FetchUser(goth.Session) (goth.User, error) { + return goth.User{}, nil +} + +func (p *fakeProvider) Debug(bool) { +} + +func (p *fakeProvider) RefreshToken(refreshToken string) (*oauth2.Token, error) { + switch refreshToken { + case "expired": + return nil, &oauth2.RetrieveError{ + ErrorCode: "invalid_grant", + } + default: + return &oauth2.Token{ + AccessToken: "token", + TokenType: "Bearer", + RefreshToken: "refresh", + Expiry: time.Now().Add(time.Hour), + }, nil + } +} + +func (p *fakeProvider) RefreshTokenAvailable() bool { + return true +} + +func init() { + RegisterGothProvider( + NewSimpleProvider("fake", "Fake", []string{"account"}, + func(clientKey, secret, callbackURL string, scopes ...string) goth.Provider { + return &fakeProvider{} + })) +} diff --git a/services/auth/source/oauth2/source.go b/services/auth/source/oauth2/source.go new file mode 100644 index 0000000..3454c9a --- /dev/null +++ b/services/auth/source/oauth2/source.go @@ -0,0 +1,51 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package oauth2 + +import ( + "code.gitea.io/gitea/models/auth" + "code.gitea.io/gitea/modules/json" +) + +// Source holds configuration for the OAuth2 login source. +type Source struct { + Provider string + ClientID string + ClientSecret string + OpenIDConnectAutoDiscoveryURL string + CustomURLMapping *CustomURLMapping + IconURL string + + Scopes []string + RequiredClaimName string + RequiredClaimValue string + GroupClaimName string + AdminGroup string + GroupTeamMap string + GroupTeamMapRemoval bool + RestrictedGroup string + SkipLocalTwoFA bool `json:",omitempty"` + + // reference to the authSource + authSource *auth.Source +} + +// FromDB fills up an OAuth2Config from serialized format. +func (source *Source) FromDB(bs []byte) error { + return json.UnmarshalHandleDoubleEncode(bs, &source) +} + +// ToDB exports an OAuth2Config to a serialized format. +func (source *Source) ToDB() ([]byte, error) { + return json.Marshal(source) +} + +// SetAuthSource sets the related AuthSource +func (source *Source) SetAuthSource(authSource *auth.Source) { + source.authSource = authSource +} + +func init() { + auth.RegisterTypeConfig(auth.OAuth2, &Source{}) +} diff --git a/services/auth/source/oauth2/source_authenticate.go b/services/auth/source/oauth2/source_authenticate.go new file mode 100644 index 0000000..bbda35d --- /dev/null +++ b/services/auth/source/oauth2/source_authenticate.go @@ -0,0 +1,19 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package oauth2 + +import ( + "context" + + user_model "code.gitea.io/gitea/models/user" + "code.gitea.io/gitea/services/auth/source/db" +) + +// Authenticate falls back to the db authenticator +func (source *Source) Authenticate(ctx context.Context, user *user_model.User, login, password string) (*user_model.User, error) { + return db.Authenticate(ctx, user, login, password) +} + +// NB: Oauth2 does not implement LocalTwoFASkipper for password authentication +// as its password authentication drops to db authentication diff --git a/services/auth/source/oauth2/source_callout.go b/services/auth/source/oauth2/source_callout.go new file mode 100644 index 0000000..f95a80f --- /dev/null +++ b/services/auth/source/oauth2/source_callout.go @@ -0,0 +1,68 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package oauth2 + +import ( + "net/http" + "net/url" + + "github.com/markbates/goth" + "github.com/markbates/goth/gothic" +) + +// Callout redirects request/response pair to authenticate against the provider +func (source *Source) Callout(request *http.Request, response http.ResponseWriter, codeChallengeS256 string) error { + // not sure if goth is thread safe (?) when using multiple providers + request.Header.Set(ProviderHeaderKey, source.authSource.Name) + + var querySuffix string + if codeChallengeS256 != "" { + querySuffix = "&" + url.Values{ + "code_challenge_method": []string{"S256"}, + "code_challenge": []string{codeChallengeS256}, + }.Encode() + } + + // don't use the default gothic begin handler to prevent issues when some error occurs + // normally the gothic library will write some custom stuff to the response instead of our own nice error page + // gothic.BeginAuthHandler(response, request) + + gothRWMutex.RLock() + defer gothRWMutex.RUnlock() + + url, err := gothic.GetAuthURL(response, request) + if err == nil { + // hacky way to set the code_challenge, but no better way until + // https://github.com/markbates/goth/issues/516 is resolved + http.Redirect(response, request, url+querySuffix, http.StatusTemporaryRedirect) + } + return err +} + +// Callback handles OAuth callback, resolve to a goth user and send back to original url +// this will trigger a new authentication request, but because we save it in the session we can use that +func (source *Source) Callback(request *http.Request, response http.ResponseWriter, codeVerifier string) (goth.User, error) { + // not sure if goth is thread safe (?) when using multiple providers + request.Header.Set(ProviderHeaderKey, source.authSource.Name) + + if codeVerifier != "" { + // hacky way to set the code_verifier... + // Will be picked up inside CompleteUserAuth: params := req.URL.Query() + // https://github.com/markbates/goth/pull/474/files + request = request.Clone(request.Context()) + q := request.URL.Query() + q.Add("code_verifier", codeVerifier) + request.URL.RawQuery = q.Encode() + } + + gothRWMutex.RLock() + defer gothRWMutex.RUnlock() + + user, err := gothic.CompleteUserAuth(response, request) + if err != nil { + return user, err + } + + return user, nil +} diff --git a/services/auth/source/oauth2/source_name.go b/services/auth/source/oauth2/source_name.go new file mode 100644 index 0000000..eee789e --- /dev/null +++ b/services/auth/source/oauth2/source_name.go @@ -0,0 +1,18 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package oauth2 + +// Name returns the provider name of this source +func (source *Source) Name() string { + return source.Provider +} + +// DisplayName returns the display name of this source +func (source *Source) DisplayName() string { + provider, has := gothProviders[source.Provider] + if !has { + return source.Provider + } + return provider.DisplayName() +} diff --git a/services/auth/source/oauth2/source_register.go b/services/auth/source/oauth2/source_register.go new file mode 100644 index 0000000..82a36ac --- /dev/null +++ b/services/auth/source/oauth2/source_register.go @@ -0,0 +1,50 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package oauth2 + +import ( + "fmt" +) + +// RegisterSource causes an OAuth2 configuration to be registered +func (source *Source) RegisterSource() error { + err := RegisterProviderWithGothic(source.authSource.Name, source) + return wrapOpenIDConnectInitializeError(err, source.authSource.Name, source) +} + +// UnregisterSource causes an OAuth2 configuration to be unregistered +func (source *Source) UnregisterSource() error { + RemoveProviderFromGothic(source.authSource.Name) + return nil +} + +// ErrOpenIDConnectInitialize represents a "OpenIDConnectInitialize" kind of error. +type ErrOpenIDConnectInitialize struct { + OpenIDConnectAutoDiscoveryURL string + ProviderName string + Cause error +} + +// IsErrOpenIDConnectInitialize checks if an error is a ExternalLoginUserAlreadyExist. +func IsErrOpenIDConnectInitialize(err error) bool { + _, ok := err.(ErrOpenIDConnectInitialize) + return ok +} + +func (err ErrOpenIDConnectInitialize) Error() string { + return fmt.Sprintf("Failed to initialize OpenID Connect Provider with name '%s' with url '%s': %v", err.ProviderName, err.OpenIDConnectAutoDiscoveryURL, err.Cause) +} + +func (err ErrOpenIDConnectInitialize) Unwrap() error { + return err.Cause +} + +// wrapOpenIDConnectInitializeError is used to wrap the error but this cannot be done in modules/auth/oauth2 +// inside oauth2: import cycle not allowed models -> modules/auth/oauth2 -> models +func wrapOpenIDConnectInitializeError(err error, providerName string, source *Source) error { + if err != nil && source.Provider == "openidConnect" { + err = ErrOpenIDConnectInitialize{ProviderName: providerName, OpenIDConnectAutoDiscoveryURL: source.OpenIDConnectAutoDiscoveryURL, Cause: err} + } + return err +} diff --git a/services/auth/source/oauth2/source_sync.go b/services/auth/source/oauth2/source_sync.go new file mode 100644 index 0000000..5e30313 --- /dev/null +++ b/services/auth/source/oauth2/source_sync.go @@ -0,0 +1,114 @@ +// Copyright 2024 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package oauth2 + +import ( + "context" + "time" + + "code.gitea.io/gitea/models/auth" + "code.gitea.io/gitea/models/db" + user_model "code.gitea.io/gitea/models/user" + "code.gitea.io/gitea/modules/log" + + "github.com/markbates/goth" + "golang.org/x/oauth2" +) + +// Sync causes this OAuth2 source to synchronize its users with the db. +func (source *Source) Sync(ctx context.Context, updateExisting bool) error { + log.Trace("Doing: SyncExternalUsers[%s] %d", source.authSource.Name, source.authSource.ID) + + if !updateExisting { + log.Info("SyncExternalUsers[%s] not running since updateExisting is false", source.authSource.Name) + return nil + } + + provider, err := createProvider(source.authSource.Name, source) + if err != nil { + return err + } + + if !provider.RefreshTokenAvailable() { + log.Trace("SyncExternalUsers[%s] provider doesn't support refresh tokens, can't synchronize", source.authSource.Name) + return nil + } + + opts := user_model.FindExternalUserOptions{ + HasRefreshToken: true, + Expired: true, + LoginSourceID: source.authSource.ID, + } + + return user_model.IterateExternalLogin(ctx, opts, func(ctx context.Context, u *user_model.ExternalLoginUser) error { + return source.refresh(ctx, provider, u) + }) +} + +func (source *Source) refresh(ctx context.Context, provider goth.Provider, u *user_model.ExternalLoginUser) error { + log.Trace("Syncing login_source_id=%d external_id=%s expiration=%s", u.LoginSourceID, u.ExternalID, u.ExpiresAt) + + shouldDisable := false + + token, err := provider.RefreshToken(u.RefreshToken) + if err != nil { + if err, ok := err.(*oauth2.RetrieveError); ok && err.ErrorCode == "invalid_grant" { + // this signals that the token is not valid and the user should be disabled + shouldDisable = true + } else { + return err + } + } + + user := &user_model.User{ + LoginName: u.ExternalID, + LoginType: auth.OAuth2, + LoginSource: u.LoginSourceID, + } + + hasUser, err := user_model.GetUser(ctx, user) + if err != nil { + return err + } + + // If the grant is no longer valid, disable the user and + // delete local tokens. If the OAuth2 provider still + // recognizes them as a valid user, they will be able to login + // via their provider and reactivate their account. + if shouldDisable { + log.Info("SyncExternalUsers[%s] disabling user %d", source.authSource.Name, user.ID) + + return db.WithTx(ctx, func(ctx context.Context) error { + if hasUser { + user.IsActive = false + err := user_model.UpdateUserCols(ctx, user, "is_active") + if err != nil { + return err + } + } + + // Delete stored tokens, since they are invalid. This + // also provents us from checking this in subsequent runs. + u.AccessToken = "" + u.RefreshToken = "" + u.ExpiresAt = time.Time{} + + return user_model.UpdateExternalUserByExternalID(ctx, u) + }) + } + + // Otherwise, update the tokens + u.AccessToken = token.AccessToken + u.ExpiresAt = token.Expiry + + // Some providers only update access tokens provide a new + // refresh token, so avoid updating it if it's empty + if token.RefreshToken != "" { + u.RefreshToken = token.RefreshToken + } + + err = user_model.UpdateExternalUserByExternalID(ctx, u) + + return err +} diff --git a/services/auth/source/oauth2/source_sync_test.go b/services/auth/source/oauth2/source_sync_test.go new file mode 100644 index 0000000..746df82 --- /dev/null +++ b/services/auth/source/oauth2/source_sync_test.go @@ -0,0 +1,101 @@ +// Copyright 2024 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package oauth2 + +import ( + "context" + "testing" + + "code.gitea.io/gitea/models/auth" + "code.gitea.io/gitea/models/unittest" + user_model "code.gitea.io/gitea/models/user" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSource(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + + source := &Source{ + Provider: "fake", + authSource: &auth.Source{ + ID: 12, + Type: auth.OAuth2, + Name: "fake", + IsActive: true, + IsSyncEnabled: true, + }, + } + + user := &user_model.User{ + LoginName: "external", + LoginType: auth.OAuth2, + LoginSource: source.authSource.ID, + Name: "test", + Email: "external@example.com", + } + + err := user_model.CreateUser(context.Background(), user, &user_model.CreateUserOverwriteOptions{}) + require.NoError(t, err) + + e := &user_model.ExternalLoginUser{ + ExternalID: "external", + UserID: user.ID, + LoginSourceID: user.LoginSource, + RefreshToken: "valid", + } + err = user_model.LinkExternalToUser(context.Background(), user, e) + require.NoError(t, err) + + provider, err := createProvider(source.authSource.Name, source) + require.NoError(t, err) + + t.Run("refresh", func(t *testing.T) { + t.Run("valid", func(t *testing.T) { + err := source.refresh(context.Background(), provider, e) + require.NoError(t, err) + + e := &user_model.ExternalLoginUser{ + ExternalID: e.ExternalID, + LoginSourceID: e.LoginSourceID, + } + + ok, err := user_model.GetExternalLogin(context.Background(), e) + require.NoError(t, err) + assert.True(t, ok) + assert.Equal(t, "refresh", e.RefreshToken) + assert.Equal(t, "token", e.AccessToken) + + u, err := user_model.GetUserByID(context.Background(), user.ID) + require.NoError(t, err) + assert.True(t, u.IsActive) + }) + + t.Run("expired", func(t *testing.T) { + err := source.refresh(context.Background(), provider, &user_model.ExternalLoginUser{ + ExternalID: "external", + UserID: user.ID, + LoginSourceID: user.LoginSource, + RefreshToken: "expired", + }) + require.NoError(t, err) + + e := &user_model.ExternalLoginUser{ + ExternalID: e.ExternalID, + LoginSourceID: e.LoginSourceID, + } + + ok, err := user_model.GetExternalLogin(context.Background(), e) + require.NoError(t, err) + assert.True(t, ok) + assert.Equal(t, "", e.RefreshToken) + assert.Equal(t, "", e.AccessToken) + + u, err := user_model.GetUserByID(context.Background(), user.ID) + require.NoError(t, err) + assert.False(t, u.IsActive) + }) + }) +} diff --git a/services/auth/source/oauth2/store.go b/services/auth/source/oauth2/store.go new file mode 100644 index 0000000..e031653 --- /dev/null +++ b/services/auth/source/oauth2/store.go @@ -0,0 +1,98 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package oauth2 + +import ( + "encoding/gob" + "fmt" + "net/http" + + "code.gitea.io/gitea/modules/log" + session_module "code.gitea.io/gitea/modules/session" + + chiSession "code.forgejo.org/go-chi/session" + "github.com/gorilla/sessions" +) + +// SessionsStore creates a gothic store from our session +type SessionsStore struct { + maxLength int64 +} + +// Get should return a cached session. +func (st *SessionsStore) Get(r *http.Request, name string) (*sessions.Session, error) { + return st.getOrNew(r, name, false) +} + +// New should create and return a new session. +// +// Note that New should never return a nil session, even in the case of +// an error if using the Registry infrastructure to cache the session. +func (st *SessionsStore) New(r *http.Request, name string) (*sessions.Session, error) { + return st.getOrNew(r, name, true) +} + +// getOrNew gets the session from the chi-session if it exists. Override permits the overriding of an unexpected object. +func (st *SessionsStore) getOrNew(r *http.Request, name string, override bool) (*sessions.Session, error) { + chiStore := chiSession.GetSession(r) + + session := sessions.NewSession(st, name) + + rawData := chiStore.Get(name) + if rawData != nil { + oldSession, ok := rawData.(*sessions.Session) + if ok { + session.ID = oldSession.ID + session.IsNew = oldSession.IsNew + session.Options = oldSession.Options + session.Values = oldSession.Values + + return session, nil + } else if !override { + log.Error("Unexpected object in session at name: %s: %v", name, rawData) + return nil, fmt.Errorf("unexpected object in session at name: %s", name) + } + } + + session.IsNew = override + session.ID = chiStore.ID() // Simply copy the session id from the chi store + + return session, chiStore.Set(name, session) +} + +// Save should persist session to the underlying store implementation. +func (st *SessionsStore) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error { + chiStore := chiSession.GetSession(r) + + if session.IsNew { + _, _ = session_module.RegenerateSession(w, r) + session.IsNew = false + } + + if err := chiStore.Set(session.Name(), session); err != nil { + return err + } + + if st.maxLength > 0 { + sizeWriter := &sizeWriter{} + + _ = gob.NewEncoder(sizeWriter).Encode(session) + if sizeWriter.size > st.maxLength { + return fmt.Errorf("encode session: Data too long: %d > %d", sizeWriter.size, st.maxLength) + } + } + + return chiStore.Release() +} + +type sizeWriter struct { + size int64 +} + +func (s *sizeWriter) Write(data []byte) (int, error) { + s.size += int64(len(data)) + return len(data), nil +} + +var _ (sessions.Store) = &SessionsStore{} diff --git a/services/auth/source/oauth2/token.go b/services/auth/source/oauth2/token.go new file mode 100644 index 0000000..3405619 --- /dev/null +++ b/services/auth/source/oauth2/token.go @@ -0,0 +1,100 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package oauth2 + +import ( + "fmt" + "time" + + "code.gitea.io/gitea/modules/timeutil" + + "github.com/golang-jwt/jwt/v5" +) + +// ___________ __ +// \__ ___/___ | | __ ____ ____ +// | | / _ \| |/ // __ \ / \ +// | |( <_> ) <\ ___/| | \ +// |____| \____/|__|_ \\___ >___| / +// \/ \/ \/ + +// Token represents an Oauth grant + +// TokenType represents the type of token for an oauth application +type TokenType int + +const ( + // TypeAccessToken is a token with short lifetime to access the api + TypeAccessToken TokenType = 0 + // TypeRefreshToken is token with long lifetime to refresh access tokens obtained by the client + TypeRefreshToken = iota +) + +// Token represents a JWT token used to authenticate a client +type Token struct { + GrantID int64 `json:"gnt"` + Type TokenType `json:"tt"` + Counter int64 `json:"cnt,omitempty"` + jwt.RegisteredClaims +} + +// ParseToken parses a signed jwt string +func ParseToken(jwtToken string, signingKey JWTSigningKey) (*Token, error) { + parsedToken, err := jwt.ParseWithClaims(jwtToken, &Token{}, func(token *jwt.Token) (any, error) { + if token.Method == nil || token.Method.Alg() != signingKey.SigningMethod().Alg() { + return nil, fmt.Errorf("unexpected signing algo: %v", token.Header["alg"]) + } + return signingKey.VerifyKey(), nil + }) + if err != nil { + return nil, err + } + if !parsedToken.Valid { + return nil, fmt.Errorf("invalid token") + } + var token *Token + var ok bool + if token, ok = parsedToken.Claims.(*Token); !ok || !parsedToken.Valid { + return nil, fmt.Errorf("invalid token") + } + return token, nil +} + +// SignToken signs the token with the JWT secret +func (token *Token) SignToken(signingKey JWTSigningKey) (string, error) { + token.IssuedAt = jwt.NewNumericDate(time.Now()) + jwtToken := jwt.NewWithClaims(signingKey.SigningMethod(), token) + signingKey.PreProcessToken(jwtToken) + return jwtToken.SignedString(signingKey.SignKey()) +} + +// OIDCToken represents an OpenID Connect id_token +type OIDCToken struct { + jwt.RegisteredClaims + Nonce string `json:"nonce,omitempty"` + + // Scope profile + Name string `json:"name,omitempty"` + PreferredUsername string `json:"preferred_username,omitempty"` + Profile string `json:"profile,omitempty"` + Picture string `json:"picture,omitempty"` + Website string `json:"website,omitempty"` + Locale string `json:"locale,omitempty"` + UpdatedAt timeutil.TimeStamp `json:"updated_at,omitempty"` + + // Scope email + Email string `json:"email,omitempty"` + EmailVerified bool `json:"email_verified,omitempty"` + + // Groups are generated by organization and team names + Groups []string `json:"groups,omitempty"` +} + +// SignToken signs an id_token with the (symmetric) client secret key +func (token *OIDCToken) SignToken(signingKey JWTSigningKey) (string, error) { + token.IssuedAt = jwt.NewNumericDate(time.Now()) + jwtToken := jwt.NewWithClaims(signingKey.SigningMethod(), token) + signingKey.PreProcessToken(jwtToken) + return jwtToken.SignedString(signingKey.SignKey()) +} diff --git a/services/auth/source/oauth2/urlmapping.go b/services/auth/source/oauth2/urlmapping.go new file mode 100644 index 0000000..d0442d5 --- /dev/null +++ b/services/auth/source/oauth2/urlmapping.go @@ -0,0 +1,77 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package oauth2 + +// CustomURLMapping describes the urls values to use when customizing OAuth2 provider URLs +type CustomURLMapping struct { + AuthURL string `json:",omitempty"` + TokenURL string `json:",omitempty"` + ProfileURL string `json:",omitempty"` + EmailURL string `json:",omitempty"` + Tenant string `json:",omitempty"` +} + +// CustomURLSettings describes the urls values and availability to use when customizing OAuth2 provider URLs +type CustomURLSettings struct { + AuthURL Attribute `json:",omitempty"` + TokenURL Attribute `json:",omitempty"` + ProfileURL Attribute `json:",omitempty"` + EmailURL Attribute `json:",omitempty"` + Tenant Attribute `json:",omitempty"` +} + +// Attribute describes the availability, and required status for a custom url configuration +type Attribute struct { + Value string + Available bool + Required bool +} + +func availableAttribute(value string) Attribute { + return Attribute{Value: value, Available: true} +} + +func requiredAttribute(value string) Attribute { + return Attribute{Value: value, Available: true, Required: true} +} + +// Required is true if any attribute is required +func (c *CustomURLSettings) Required() bool { + if c == nil { + return false + } + if c.AuthURL.Required || c.EmailURL.Required || c.ProfileURL.Required || c.TokenURL.Required || c.Tenant.Required { + return true + } + return false +} + +// OverrideWith copies the current customURLMapping and overrides it with values from the provided mapping +func (c *CustomURLSettings) OverrideWith(override *CustomURLMapping) *CustomURLMapping { + custom := &CustomURLMapping{ + AuthURL: c.AuthURL.Value, + TokenURL: c.TokenURL.Value, + ProfileURL: c.ProfileURL.Value, + EmailURL: c.EmailURL.Value, + Tenant: c.Tenant.Value, + } + if override != nil { + if len(override.AuthURL) > 0 && c.AuthURL.Available { + custom.AuthURL = override.AuthURL + } + if len(override.TokenURL) > 0 && c.TokenURL.Available { + custom.TokenURL = override.TokenURL + } + if len(override.ProfileURL) > 0 && c.ProfileURL.Available { + custom.ProfileURL = override.ProfileURL + } + if len(override.EmailURL) > 0 && c.EmailURL.Available { + custom.EmailURL = override.EmailURL + } + if len(override.Tenant) > 0 && c.Tenant.Available { + custom.Tenant = override.Tenant + } + } + return custom +} diff --git a/services/auth/source/pam/assert_interface_test.go b/services/auth/source/pam/assert_interface_test.go new file mode 100644 index 0000000..8e7648b --- /dev/null +++ b/services/auth/source/pam/assert_interface_test.go @@ -0,0 +1,21 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package pam_test + +import ( + auth_model "code.gitea.io/gitea/models/auth" + "code.gitea.io/gitea/services/auth" + "code.gitea.io/gitea/services/auth/source/pam" +) + +// This test file exists to assert that our Source exposes the interfaces that we expect +// It tightly binds the interfaces and implementation without breaking go import cycles + +type sourceInterface interface { + auth.PasswordAuthenticator + auth_model.Config + auth_model.SourceSettable +} + +var _ (sourceInterface) = &pam.Source{} diff --git a/services/auth/source/pam/source.go b/services/auth/source/pam/source.go new file mode 100644 index 0000000..96b182e --- /dev/null +++ b/services/auth/source/pam/source.go @@ -0,0 +1,45 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package pam + +import ( + "code.gitea.io/gitea/models/auth" + "code.gitea.io/gitea/modules/json" +) + +// __________ _____ _____ +// \______ \/ _ \ / \ +// | ___/ /_\ \ / \ / \ +// | | / | \/ Y \ +// |____| \____|__ /\____|__ / +// \/ \/ + +// Source holds configuration for the PAM login source. +type Source struct { + ServiceName string // pam service (e.g. system-auth) + EmailDomain string + SkipLocalTwoFA bool `json:",omitempty"` // Skip Local 2fa for users authenticated with this source + + // reference to the authSource + authSource *auth.Source +} + +// FromDB fills up a PAMConfig from serialized format. +func (source *Source) FromDB(bs []byte) error { + return json.UnmarshalHandleDoubleEncode(bs, &source) +} + +// ToDB exports a PAMConfig to a serialized format. +func (source *Source) ToDB() ([]byte, error) { + return json.Marshal(source) +} + +// SetAuthSource sets the related AuthSource +func (source *Source) SetAuthSource(authSource *auth.Source) { + source.authSource = authSource +} + +func init() { + auth.RegisterTypeConfig(auth.PAM, &Source{}) +} diff --git a/services/auth/source/pam/source_authenticate.go b/services/auth/source/pam/source_authenticate.go new file mode 100644 index 0000000..addd1bd --- /dev/null +++ b/services/auth/source/pam/source_authenticate.go @@ -0,0 +1,76 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package pam + +import ( + "context" + "fmt" + "strings" + + "code.gitea.io/gitea/models/auth" + user_model "code.gitea.io/gitea/models/user" + "code.gitea.io/gitea/modules/auth/pam" + "code.gitea.io/gitea/modules/optional" + "code.gitea.io/gitea/modules/setting" + + "github.com/google/uuid" +) + +// Authenticate queries if login/password is valid against the PAM, +// and create a local user if success when enabled. +func (source *Source) Authenticate(ctx context.Context, user *user_model.User, userName, password string) (*user_model.User, error) { + pamLogin, err := pam.Auth(source.ServiceName, userName, password) + if err != nil { + if strings.Contains(err.Error(), "Authentication failure") { + return nil, user_model.ErrUserNotExist{Name: userName} + } + return nil, err + } + + if user != nil { + return user, nil + } + + // Allow PAM sources with `@` in their name, like from Active Directory + username := pamLogin + email := pamLogin + idx := strings.Index(pamLogin, "@") + if idx > -1 { + username = pamLogin[:idx] + } + if user_model.ValidateEmail(email) != nil { + if source.EmailDomain != "" { + email = fmt.Sprintf("%s@%s", username, source.EmailDomain) + } else { + email = fmt.Sprintf("%s@%s", username, setting.Service.NoReplyAddress) + } + if user_model.ValidateEmail(email) != nil { + email = uuid.New().String() + "@localhost" + } + } + + user = &user_model.User{ + LowerName: strings.ToLower(username), + Name: username, + Email: email, + Passwd: password, + LoginType: auth.PAM, + LoginSource: source.authSource.ID, + LoginName: userName, // This is what the user typed in + } + overwriteDefault := &user_model.CreateUserOverwriteOptions{ + IsActive: optional.Some(true), + } + + if err := user_model.CreateUser(ctx, user, overwriteDefault); err != nil { + return user, err + } + + return user, nil +} + +// IsSkipLocalTwoFA returns if this source should skip local 2fa for password authentication +func (source *Source) IsSkipLocalTwoFA() bool { + return source.SkipLocalTwoFA +} diff --git a/services/auth/source/remote/source.go b/services/auth/source/remote/source.go new file mode 100644 index 0000000..4165858 --- /dev/null +++ b/services/auth/source/remote/source.go @@ -0,0 +1,33 @@ +// Copyright Earl Warren <contact@earl-warren.org> +// SPDX-License-Identifier: MIT + +package remote + +import ( + "code.gitea.io/gitea/models/auth" + "code.gitea.io/gitea/modules/json" +) + +type Source struct { + URL string + MatchingSource string + + // reference to the authSource + authSource *auth.Source +} + +func (source *Source) FromDB(bs []byte) error { + return json.UnmarshalHandleDoubleEncode(bs, &source) +} + +func (source *Source) ToDB() ([]byte, error) { + return json.Marshal(source) +} + +func (source *Source) SetAuthSource(authSource *auth.Source) { + source.authSource = authSource +} + +func init() { + auth.RegisterTypeConfig(auth.Remote, &Source{}) +} diff --git a/services/auth/source/smtp/assert_interface_test.go b/services/auth/source/smtp/assert_interface_test.go new file mode 100644 index 0000000..6c9cde6 --- /dev/null +++ b/services/auth/source/smtp/assert_interface_test.go @@ -0,0 +1,24 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package smtp_test + +import ( + auth_model "code.gitea.io/gitea/models/auth" + "code.gitea.io/gitea/services/auth" + "code.gitea.io/gitea/services/auth/source/smtp" +) + +// This test file exists to assert that our Source exposes the interfaces that we expect +// It tightly binds the interfaces and implementation without breaking go import cycles + +type sourceInterface interface { + auth.PasswordAuthenticator + auth_model.Config + auth_model.SkipVerifiable + auth_model.HasTLSer + auth_model.UseTLSer + auth_model.SourceSettable +} + +var _ (sourceInterface) = &smtp.Source{} diff --git a/services/auth/source/smtp/auth.go b/services/auth/source/smtp/auth.go new file mode 100644 index 0000000..6446fcd --- /dev/null +++ b/services/auth/source/smtp/auth.go @@ -0,0 +1,106 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package smtp + +import ( + "crypto/tls" + "errors" + "fmt" + "net" + "net/smtp" + "os" + "strconv" +) + +// _________ __________________________ +// / _____/ / \__ ___/\______ \ +// \_____ \ / \ / \| | | ___/ +// / \/ Y \ | | | +// /_______ /\____|__ /____| |____| +// \/ \/ + +type loginAuthenticator struct { + username, password string +} + +func (auth *loginAuthenticator) Start(server *smtp.ServerInfo) (string, []byte, error) { + return "LOGIN", []byte(auth.username), nil +} + +func (auth *loginAuthenticator) Next(fromServer []byte, more bool) ([]byte, error) { + if more { + switch string(fromServer) { + case "Username:": + return []byte(auth.username), nil + case "Password:": + return []byte(auth.password), nil + } + } + return nil, nil +} + +// SMTP authentication type names. +const ( + PlainAuthentication = "PLAIN" + LoginAuthentication = "LOGIN" + CRAMMD5Authentication = "CRAM-MD5" +) + +// Authenticators contains available SMTP authentication type names. +var Authenticators = []string{PlainAuthentication, LoginAuthentication, CRAMMD5Authentication} + +// ErrUnsupportedLoginType login source is unknown error +var ErrUnsupportedLoginType = errors.New("Login source is unknown") + +// Authenticate performs an SMTP authentication. +func Authenticate(a smtp.Auth, source *Source) error { + tlsConfig := &tls.Config{ + InsecureSkipVerify: source.SkipVerify, + ServerName: source.Host, + } + + conn, err := net.Dial("tcp", net.JoinHostPort(source.Host, strconv.Itoa(source.Port))) + if err != nil { + return err + } + defer conn.Close() + + if source.UseTLS() { + conn = tls.Client(conn, tlsConfig) + } + + client, err := smtp.NewClient(conn, source.Host) + if err != nil { + return fmt.Errorf("failed to create NewClient: %w", err) + } + defer client.Close() + + if !source.DisableHelo { + hostname := source.HeloHostname + if len(hostname) == 0 { + hostname, err = os.Hostname() + if err != nil { + return fmt.Errorf("failed to find Hostname: %w", err) + } + } + + if err = client.Hello(hostname); err != nil { + return fmt.Errorf("failed to send Helo: %w", err) + } + } + + // If not using SMTPS, always use STARTTLS if available + hasStartTLS, _ := client.Extension("STARTTLS") + if !source.UseTLS() && hasStartTLS { + if err = client.StartTLS(tlsConfig); err != nil { + return fmt.Errorf("failed to start StartTLS: %w", err) + } + } + + if ok, _ := client.Extension("AUTH"); ok { + return client.Auth(a) + } + + return ErrUnsupportedLoginType +} diff --git a/services/auth/source/smtp/source.go b/services/auth/source/smtp/source.go new file mode 100644 index 0000000..2a648e4 --- /dev/null +++ b/services/auth/source/smtp/source.go @@ -0,0 +1,66 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package smtp + +import ( + "code.gitea.io/gitea/models/auth" + "code.gitea.io/gitea/modules/json" +) + +// _________ __________________________ +// / _____/ / \__ ___/\______ \ +// \_____ \ / \ / \| | | ___/ +// / \/ Y \ | | | +// /_______ /\____|__ /____| |____| +// \/ \/ + +// Source holds configuration for the SMTP login source. +type Source struct { + Auth string + Host string + Port int + AllowedDomains string `xorm:"TEXT"` + ForceSMTPS bool + SkipVerify bool + HeloHostname string + DisableHelo bool + SkipLocalTwoFA bool `json:",omitempty"` + + // reference to the authSource + authSource *auth.Source +} + +// FromDB fills up an SMTPConfig from serialized format. +func (source *Source) FromDB(bs []byte) error { + return json.UnmarshalHandleDoubleEncode(bs, &source) +} + +// ToDB exports an SMTPConfig to a serialized format. +func (source *Source) ToDB() ([]byte, error) { + return json.Marshal(source) +} + +// IsSkipVerify returns if SkipVerify is set +func (source *Source) IsSkipVerify() bool { + return source.SkipVerify +} + +// HasTLS returns true for SMTP +func (source *Source) HasTLS() bool { + return true +} + +// UseTLS returns if TLS is set +func (source *Source) UseTLS() bool { + return source.ForceSMTPS || source.Port == 465 +} + +// SetAuthSource sets the related AuthSource +func (source *Source) SetAuthSource(authSource *auth.Source) { + source.authSource = authSource +} + +func init() { + auth.RegisterTypeConfig(auth.SMTP, &Source{}) +} diff --git a/services/auth/source/smtp/source_authenticate.go b/services/auth/source/smtp/source_authenticate.go new file mode 100644 index 0000000..1f0a61c --- /dev/null +++ b/services/auth/source/smtp/source_authenticate.go @@ -0,0 +1,92 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package smtp + +import ( + "context" + "errors" + "net/smtp" + "net/textproto" + "strings" + + auth_model "code.gitea.io/gitea/models/auth" + user_model "code.gitea.io/gitea/models/user" + "code.gitea.io/gitea/modules/optional" + "code.gitea.io/gitea/modules/util" +) + +// Authenticate queries if the provided login/password is authenticates against the SMTP server +// Users will be autoregistered as required +func (source *Source) Authenticate(ctx context.Context, user *user_model.User, userName, password string) (*user_model.User, error) { + // Verify allowed domains. + if len(source.AllowedDomains) > 0 { + idx := strings.Index(userName, "@") + if idx == -1 { + return nil, user_model.ErrUserNotExist{Name: userName} + } else if !util.SliceContainsString(strings.Split(source.AllowedDomains, ","), userName[idx+1:], true) { + return nil, user_model.ErrUserNotExist{Name: userName} + } + } + + var auth smtp.Auth + switch source.Auth { + case PlainAuthentication: + auth = smtp.PlainAuth("", userName, password, source.Host) + case LoginAuthentication: + auth = &loginAuthenticator{userName, password} + case CRAMMD5Authentication: + auth = smtp.CRAMMD5Auth(userName, password) + default: + return nil, errors.New("unsupported SMTP auth type") + } + + if err := Authenticate(auth, source); err != nil { + // Check standard error format first, + // then fallback to worse case. + tperr, ok := err.(*textproto.Error) + if (ok && tperr.Code == 535) || + strings.Contains(err.Error(), "Username and Password not accepted") { + return nil, user_model.ErrUserNotExist{Name: userName} + } + if (ok && tperr.Code == 534) || + strings.Contains(err.Error(), "Application-specific password required") { + return nil, user_model.ErrUserNotExist{Name: userName} + } + return nil, err + } + + if user != nil { + return user, nil + } + + username := userName + idx := strings.Index(userName, "@") + if idx > -1 { + username = userName[:idx] + } + + user = &user_model.User{ + LowerName: strings.ToLower(username), + Name: strings.ToLower(username), + Email: userName, + Passwd: password, + LoginType: auth_model.SMTP, + LoginSource: source.authSource.ID, + LoginName: userName, + } + overwriteDefault := &user_model.CreateUserOverwriteOptions{ + IsActive: optional.Some(true), + } + + if err := user_model.CreateUser(ctx, user, overwriteDefault); err != nil { + return user, err + } + + return user, nil +} + +// IsSkipLocalTwoFA returns if this source should skip local 2fa for password authentication +func (source *Source) IsSkipLocalTwoFA() bool { + return source.SkipLocalTwoFA +} diff --git a/services/auth/source/source_group_sync.go b/services/auth/source/source_group_sync.go new file mode 100644 index 0000000..3a2411e --- /dev/null +++ b/services/auth/source/source_group_sync.go @@ -0,0 +1,116 @@ +// Copyright 2022 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package source + +import ( + "context" + "fmt" + + "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/organization" + user_model "code.gitea.io/gitea/models/user" + "code.gitea.io/gitea/modules/container" + "code.gitea.io/gitea/modules/log" +) + +type syncType int + +const ( + syncAdd syncType = iota + syncRemove +) + +// SyncGroupsToTeams maps authentication source groups to organization and team memberships +func SyncGroupsToTeams(ctx context.Context, user *user_model.User, sourceUserGroups container.Set[string], sourceGroupTeamMapping map[string]map[string][]string, performRemoval bool) error { + orgCache := make(map[string]*organization.Organization) + teamCache := make(map[string]*organization.Team) + return SyncGroupsToTeamsCached(ctx, user, sourceUserGroups, sourceGroupTeamMapping, performRemoval, orgCache, teamCache) +} + +// SyncGroupsToTeamsCached maps authentication source groups to organization and team memberships +func SyncGroupsToTeamsCached(ctx context.Context, user *user_model.User, sourceUserGroups container.Set[string], sourceGroupTeamMapping map[string]map[string][]string, performRemoval bool, orgCache map[string]*organization.Organization, teamCache map[string]*organization.Team) error { + membershipsToAdd, membershipsToRemove := resolveMappedMemberships(sourceUserGroups, sourceGroupTeamMapping) + + if performRemoval { + if err := syncGroupsToTeamsCached(ctx, user, membershipsToRemove, syncRemove, orgCache, teamCache); err != nil { + return fmt.Errorf("could not sync[remove] user groups: %w", err) + } + } + + if err := syncGroupsToTeamsCached(ctx, user, membershipsToAdd, syncAdd, orgCache, teamCache); err != nil { + return fmt.Errorf("could not sync[add] user groups: %w", err) + } + + return nil +} + +func resolveMappedMemberships(sourceUserGroups container.Set[string], sourceGroupTeamMapping map[string]map[string][]string) (map[string][]string, map[string][]string) { + membershipsToAdd := map[string][]string{} + membershipsToRemove := map[string][]string{} + for group, memberships := range sourceGroupTeamMapping { + isUserInGroup := sourceUserGroups.Contains(group) + if isUserInGroup { + for org, teams := range memberships { + membershipsToAdd[org] = append(membershipsToAdd[org], teams...) + } + } else { + for org, teams := range memberships { + membershipsToRemove[org] = append(membershipsToRemove[org], teams...) + } + } + } + return membershipsToAdd, membershipsToRemove +} + +func syncGroupsToTeamsCached(ctx context.Context, user *user_model.User, orgTeamMap map[string][]string, action syncType, orgCache map[string]*organization.Organization, teamCache map[string]*organization.Team) error { + for orgName, teamNames := range orgTeamMap { + var err error + org, ok := orgCache[orgName] + if !ok { + org, err = organization.GetOrgByName(ctx, orgName) + if err != nil { + if organization.IsErrOrgNotExist(err) { + // organization must be created before group sync + log.Warn("group sync: Could not find organisation %s: %v", orgName, err) + continue + } + return err + } + orgCache[orgName] = org + } + for _, teamName := range teamNames { + team, ok := teamCache[orgName+teamName] + if !ok { + team, err = org.GetTeam(ctx, teamName) + if err != nil { + if organization.IsErrTeamNotExist(err) { + // team must be created before group sync + log.Warn("group sync: Could not find team %s: %v", teamName, err) + continue + } + return err + } + teamCache[orgName+teamName] = team + } + + isMember, err := organization.IsTeamMember(ctx, org.ID, team.ID, user.ID) + if err != nil { + return err + } + + if action == syncAdd && !isMember { + if err := models.AddTeamMember(ctx, team, user.ID); err != nil { + log.Error("group sync: Could not add user to team: %v", err) + return err + } + } else if action == syncRemove && isMember { + if err := models.RemoveTeamMember(ctx, team, user.ID); err != nil { + log.Error("group sync: Could not remove user from team: %v", err) + return err + } + } + } + } + return nil +} diff --git a/services/auth/source/sspi/assert_interface_test.go b/services/auth/source/sspi/assert_interface_test.go new file mode 100644 index 0000000..03d836d --- /dev/null +++ b/services/auth/source/sspi/assert_interface_test.go @@ -0,0 +1,18 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package sspi_test + +import ( + "code.gitea.io/gitea/models/auth" + "code.gitea.io/gitea/services/auth/source/sspi" +) + +// This test file exists to assert that our Source exposes the interfaces that we expect +// It tightly binds the interfaces and implementation without breaking go import cycles + +type sourceInterface interface { + auth.Config +} + +var _ (sourceInterface) = &sspi.Source{} diff --git a/services/auth/source/sspi/source.go b/services/auth/source/sspi/source.go new file mode 100644 index 0000000..bdd6ef4 --- /dev/null +++ b/services/auth/source/sspi/source.go @@ -0,0 +1,39 @@ +// Copyright 2021 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package sspi + +import ( + "code.gitea.io/gitea/models/auth" + "code.gitea.io/gitea/modules/json" +) + +// _________ ___________________.___ +// / _____// _____/\______ \ | +// \_____ \ \_____ \ | ___/ | +// / \/ \ | | | | +// /_______ /_______ / |____| |___| +// \/ \/ + +// Source holds configuration for SSPI single sign-on. +type Source struct { + AutoCreateUsers bool + AutoActivateUsers bool + StripDomainNames bool + SeparatorReplacement string + DefaultLanguage string +} + +// FromDB fills up an SSPIConfig from serialized format. +func (cfg *Source) FromDB(bs []byte) error { + return json.UnmarshalHandleDoubleEncode(bs, &cfg) +} + +// ToDB exports an SSPIConfig to a serialized format. +func (cfg *Source) ToDB() ([]byte, error) { + return json.Marshal(cfg) +} + +func init() { + auth.RegisterTypeConfig(auth.SSPI, &Source{}) +} |