From dd136858f1ea40ad3c94191d647487fa4f31926c Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Fri, 18 Oct 2024 20:33:49 +0200 Subject: Adding upstream version 9.0.0. Signed-off-by: Daniel Baumann --- modules/nosql/leveldb.go | 24 ++++ modules/nosql/manager.go | 116 ++++++++++++++++ modules/nosql/manager_leveldb.go | 214 ++++++++++++++++++++++++++++++ modules/nosql/manager_redis.go | 258 ++++++++++++++++++++++++++++++++++++ modules/nosql/manager_redis_test.go | 81 +++++++++++ modules/nosql/redis.go | 100 ++++++++++++++ modules/nosql/redis_test.go | 34 +++++ 7 files changed, 827 insertions(+) create mode 100644 modules/nosql/leveldb.go create mode 100644 modules/nosql/manager.go create mode 100644 modules/nosql/manager_leveldb.go create mode 100644 modules/nosql/manager_redis.go create mode 100644 modules/nosql/manager_redis_test.go create mode 100644 modules/nosql/redis.go create mode 100644 modules/nosql/redis_test.go (limited to 'modules/nosql') diff --git a/modules/nosql/leveldb.go b/modules/nosql/leveldb.go new file mode 100644 index 0000000..aac5b21 --- /dev/null +++ b/modules/nosql/leveldb.go @@ -0,0 +1,24 @@ +// Copyright 2020 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package nosql + +import "net/url" + +// ToLevelDBURI converts old style connections to a LevelDBURI +// +// A LevelDBURI matches the pattern: +// +// leveldb://path[?[option=value]*] +// +// We have previously just provided the path but this prevent other options +func ToLevelDBURI(connection string) *url.URL { + uri, err := url.Parse(connection) + if err == nil && uri.Scheme == "leveldb" { + return uri + } + uri, _ = url.Parse("leveldb://common") + uri.Host = "" + uri.Path = connection + return uri +} diff --git a/modules/nosql/manager.go b/modules/nosql/manager.go new file mode 100644 index 0000000..0ba2158 --- /dev/null +++ b/modules/nosql/manager.go @@ -0,0 +1,116 @@ +// Copyright 2020 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package nosql + +import ( + "context" + "strconv" + "sync" + "time" + + "code.gitea.io/gitea/modules/process" + + "github.com/redis/go-redis/v9" + "github.com/syndtr/goleveldb/leveldb" +) + +var manager *Manager + +// Manager is the nosql connection manager +type Manager struct { + ctx context.Context + finished context.CancelFunc + mutex sync.Mutex + + RedisConnections map[string]*redisClientHolder + LevelDBConnections map[string]*levelDBHolder +} + +// RedisClient is a subset of redis.UniversalClient, it exposes less methods +// to avoid generating machine code for unused methods. New method definitions +// should be copied from the definitions in the Redis library github.com/redis/go-redis. +type RedisClient interface { + // redis.GenericCmdable + Del(ctx context.Context, keys ...string) *redis.IntCmd + Exists(ctx context.Context, keys ...string) *redis.IntCmd + + // redis.ListCmdable + RPush(ctx context.Context, key string, values ...any) *redis.IntCmd + LPop(ctx context.Context, key string) *redis.StringCmd + LLen(ctx context.Context, key string) *redis.IntCmd + + // redis.StringCmdable + Decr(ctx context.Context, key string) *redis.IntCmd + Incr(ctx context.Context, key string) *redis.IntCmd + Set(ctx context.Context, key string, value any, expiration time.Duration) *redis.StatusCmd + Get(ctx context.Context, key string) *redis.StringCmd + + // redis.HashCmdable + HSet(ctx context.Context, key string, values ...any) *redis.IntCmd + HDel(ctx context.Context, key string, fields ...string) *redis.IntCmd + HKeys(ctx context.Context, key string) *redis.StringSliceCmd + + // redis.SetCmdable + SAdd(ctx context.Context, key string, members ...any) *redis.IntCmd + SRem(ctx context.Context, key string, members ...any) *redis.IntCmd + SIsMember(ctx context.Context, key string, member any) *redis.BoolCmd + + // redis.Cmdable + DBSize(ctx context.Context) *redis.IntCmd + FlushDB(ctx context.Context) *redis.StatusCmd + Ping(ctx context.Context) *redis.StatusCmd + + // redis.UniversalClient + Close() error +} + +type redisClientHolder struct { + RedisClient + name []string + count int64 +} + +func (r *redisClientHolder) Close() error { + return manager.CloseRedisClient(r.name[0]) +} + +type levelDBHolder struct { + name []string + count int64 + db *leveldb.DB +} + +func init() { + _ = GetManager() +} + +// GetManager returns a Manager and initializes one as singleton is there's none yet +func GetManager() *Manager { + if manager == nil { + ctx, _, finished := process.GetManager().AddTypedContext(context.Background(), "Service: NoSQL", process.SystemProcessType, false) + manager = &Manager{ + ctx: ctx, + finished: finished, + RedisConnections: make(map[string]*redisClientHolder), + LevelDBConnections: make(map[string]*levelDBHolder), + } + } + return manager +} + +func valToTimeDuration(vs []string) (result time.Duration) { + var err error + for _, v := range vs { + result, err = time.ParseDuration(v) + if err != nil { + var val int + val, err = strconv.Atoi(v) + result = time.Duration(val) + } + if err == nil { + return result + } + } + return result +} diff --git a/modules/nosql/manager_leveldb.go b/modules/nosql/manager_leveldb.go new file mode 100644 index 0000000..4d2c90d --- /dev/null +++ b/modules/nosql/manager_leveldb.go @@ -0,0 +1,214 @@ +// Copyright 2020 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package nosql + +import ( + "fmt" + "path" + "runtime/pprof" + "strconv" + "strings" + + "code.gitea.io/gitea/modules/log" + + "github.com/syndtr/goleveldb/leveldb" + "github.com/syndtr/goleveldb/leveldb/errors" + "github.com/syndtr/goleveldb/leveldb/opt" +) + +// CloseLevelDB closes a levelDB +func (m *Manager) CloseLevelDB(connection string) error { + m.mutex.Lock() + defer m.mutex.Unlock() + db, ok := m.LevelDBConnections[connection] + if !ok { + // Try the full URI + uri := ToLevelDBURI(connection) + db, ok = m.LevelDBConnections[uri.String()] + + if !ok { + // Try the datadir directly + dataDir := path.Join(uri.Host, uri.Path) + + db, ok = m.LevelDBConnections[dataDir] + } + } + if !ok { + return nil + } + + db.count-- + if db.count > 0 { + return nil + } + + for _, name := range db.name { + delete(m.LevelDBConnections, name) + } + return db.db.Close() +} + +// GetLevelDB gets a levelDB for a particular connection +func (m *Manager) GetLevelDB(connection string) (db *leveldb.DB, err error) { + // Because we want associate any goroutines created by this call to the main nosqldb context we need to + // wrap this in a goroutine labelled with the nosqldb context + done := make(chan struct{}) + var recovered any + go func() { + defer func() { + recovered = recover() + if recovered != nil { + log.Critical("PANIC during GetLevelDB: %v\nStacktrace: %s", recovered, log.Stack(2)) + } + close(done) + }() + pprof.SetGoroutineLabels(m.ctx) + + db, err = m.getLevelDB(connection) + }() + <-done + if recovered != nil { + panic(recovered) + } + return db, err +} + +func (m *Manager) getLevelDB(connection string) (*leveldb.DB, error) { + // Convert the provided connection description to the common format + uri := ToLevelDBURI(connection) + + // Get the datadir + dataDir := path.Join(uri.Host, uri.Path) + + m.mutex.Lock() + defer m.mutex.Unlock() + db, ok := m.LevelDBConnections[connection] + if ok { + db.count++ + + return db.db, nil + } + + db, ok = m.LevelDBConnections[uri.String()] + if ok { + db.count++ + + return db.db, nil + } + + // if there is already a connection to this leveldb reuse that + // NOTE: if there differing options then only the first leveldb connection will be used + db, ok = m.LevelDBConnections[dataDir] + if ok { + db.count++ + log.Warn("Duplicate connection to level db: %s with different connection strings. Initial connection: %s. This connection: %s", dataDir, db.name[0], connection) + db.name = append(db.name, connection) + m.LevelDBConnections[connection] = db + return db.db, nil + } + db = &levelDBHolder{ + name: []string{connection, uri.String(), dataDir}, + } + + opts := &opt.Options{} + for k, v := range uri.Query() { + switch replacer.Replace(strings.ToLower(k)) { + case "blockcachecapacity": + opts.BlockCacheCapacity, _ = strconv.Atoi(v[0]) + case "blockcacheevictremoved": + opts.BlockCacheEvictRemoved, _ = strconv.ParseBool(v[0]) + case "blockrestartinterval": + opts.BlockRestartInterval, _ = strconv.Atoi(v[0]) + case "blocksize": + opts.BlockSize, _ = strconv.Atoi(v[0]) + case "compactionexpandlimitfactor": + opts.CompactionExpandLimitFactor, _ = strconv.Atoi(v[0]) + case "compactiongpoverlapsfactor": + opts.CompactionGPOverlapsFactor, _ = strconv.Atoi(v[0]) + case "compactionl0trigger": + opts.CompactionL0Trigger, _ = strconv.Atoi(v[0]) + case "compactionsourcelimitfactor": + opts.CompactionSourceLimitFactor, _ = strconv.Atoi(v[0]) + case "compactiontablesize": + opts.CompactionTableSize, _ = strconv.Atoi(v[0]) + case "compactiontablesizemultiplier": + opts.CompactionTableSizeMultiplier, _ = strconv.ParseFloat(v[0], 64) + case "compactiontablesizemultiplierperlevel": + for _, val := range v { + f, _ := strconv.ParseFloat(val, 64) + opts.CompactionTableSizeMultiplierPerLevel = append(opts.CompactionTableSizeMultiplierPerLevel, f) + } + case "compactiontotalsize": + opts.CompactionTotalSize, _ = strconv.Atoi(v[0]) + case "compactiontotalsizemultiplier": + opts.CompactionTotalSizeMultiplier, _ = strconv.ParseFloat(v[0], 64) + case "compactiontotalsizemultiplierperlevel": + for _, val := range v { + f, _ := strconv.ParseFloat(val, 64) + opts.CompactionTotalSizeMultiplierPerLevel = append(opts.CompactionTotalSizeMultiplierPerLevel, f) + } + case "compression": + val, _ := strconv.Atoi(v[0]) + opts.Compression = opt.Compression(val) + case "disablebufferpool": + opts.DisableBufferPool, _ = strconv.ParseBool(v[0]) + case "disableblockcache": + opts.DisableBlockCache, _ = strconv.ParseBool(v[0]) + case "disablecompactionbackoff": + opts.DisableCompactionBackoff, _ = strconv.ParseBool(v[0]) + case "disablelargebatchtransaction": + opts.DisableLargeBatchTransaction, _ = strconv.ParseBool(v[0]) + case "errorifexist": + opts.ErrorIfExist, _ = strconv.ParseBool(v[0]) + case "errorifmissing": + opts.ErrorIfMissing, _ = strconv.ParseBool(v[0]) + case "iteratorsamplingrate": + opts.IteratorSamplingRate, _ = strconv.Atoi(v[0]) + case "nosync": + opts.NoSync, _ = strconv.ParseBool(v[0]) + case "nowritemerge": + opts.NoWriteMerge, _ = strconv.ParseBool(v[0]) + case "openfilescachecapacity": + opts.OpenFilesCacheCapacity, _ = strconv.Atoi(v[0]) + case "readonly": + opts.ReadOnly, _ = strconv.ParseBool(v[0]) + case "strict": + val, _ := strconv.Atoi(v[0]) + opts.Strict = opt.Strict(val) + case "writebuffer": + opts.WriteBuffer, _ = strconv.Atoi(v[0]) + case "writel0pausetrigger": + opts.WriteL0PauseTrigger, _ = strconv.Atoi(v[0]) + case "writel0slowdowntrigger": + opts.WriteL0SlowdownTrigger, _ = strconv.Atoi(v[0]) + case "clientname": + db.name = append(db.name, v[0]) + } + } + + var err error + db.db, err = leveldb.OpenFile(dataDir, opts) + if err != nil { + if !errors.IsCorrupted(err) { + if strings.Contains(err.Error(), "resource temporarily unavailable") { + err = fmt.Errorf("unable to lock level db at %s: %w", dataDir, err) + return nil, err + } + + err = fmt.Errorf("unable to open level db at %s: %w", dataDir, err) + return nil, err + } + db.db, err = leveldb.RecoverFile(dataDir, opts) + } + + if err != nil { + return nil, err + } + + for _, name := range db.name { + m.LevelDBConnections[name] = db + } + db.count++ + return db.db, nil +} diff --git a/modules/nosql/manager_redis.go b/modules/nosql/manager_redis.go new file mode 100644 index 0000000..79a533b --- /dev/null +++ b/modules/nosql/manager_redis.go @@ -0,0 +1,258 @@ +// Copyright 2020 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package nosql + +import ( + "crypto/tls" + "net/url" + "path" + "runtime/pprof" + "strconv" + "strings" + + "code.gitea.io/gitea/modules/log" + + "github.com/redis/go-redis/v9" +) + +var replacer = strings.NewReplacer("_", "", "-", "") + +// CloseRedisClient closes a redis client +func (m *Manager) CloseRedisClient(connection string) error { + m.mutex.Lock() + defer m.mutex.Unlock() + client, ok := m.RedisConnections[connection] + if !ok { + connection = ToRedisURI(connection).String() + client, ok = m.RedisConnections[connection] + } + if !ok { + return nil + } + + client.count-- + if client.count > 0 { + return nil + } + + for _, name := range client.name { + delete(m.RedisConnections, name) + } + return client.RedisClient.Close() +} + +// GetRedisClient gets a redis client for a particular connection +func (m *Manager) GetRedisClient(connection string) (client RedisClient) { + // Because we want associate any goroutines created by this call to the main nosqldb context we need to + // wrap this in a goroutine labelled with the nosqldb context + done := make(chan struct{}) + var recovered any + go func() { + defer func() { + recovered = recover() + if recovered != nil { + log.Critical("PANIC during GetRedisClient: %v\nStacktrace: %s", recovered, log.Stack(2)) + } + close(done) + }() + pprof.SetGoroutineLabels(m.ctx) + + client = m.getRedisClient(connection) + }() + <-done + if recovered != nil { + panic(recovered) + } + return client +} + +func (m *Manager) getRedisClient(connection string) RedisClient { + m.mutex.Lock() + defer m.mutex.Unlock() + client, ok := m.RedisConnections[connection] + if ok { + client.count++ + return client + } + + uri := ToRedisURI(connection) + client, ok = m.RedisConnections[uri.String()] + if ok { + client.count++ + return client + } + client = &redisClientHolder{ + name: []string{connection, uri.String()}, + } + + opts := getRedisOptions(uri) + tlsConfig := getRedisTLSOptions(uri) + + clientName := uri.Query().Get("clientname") + + if len(clientName) > 0 { + client.name = append(client.name, clientName) + } + + switch uri.Scheme { + case "redis+sentinels": + fallthrough + case "rediss+sentinel": + opts.TLSConfig = tlsConfig + fallthrough + case "redis+sentinel": + client.RedisClient = redis.NewFailoverClient(opts.Failover()) + case "redis+clusters": + fallthrough + case "rediss+cluster": + opts.TLSConfig = tlsConfig + fallthrough + case "redis+cluster": + client.RedisClient = redis.NewClusterClient(opts.Cluster()) + case "redis+socket": + simpleOpts := opts.Simple() + simpleOpts.Network = "unix" + simpleOpts.Addr = path.Join(uri.Host, uri.Path) + client.RedisClient = redis.NewClient(simpleOpts) + case "rediss": + opts.TLSConfig = tlsConfig + fallthrough + case "redis": + client.RedisClient = redis.NewClient(opts.Simple()) + default: + return nil + } + + for _, name := range client.name { + m.RedisConnections[name] = client + } + + client.count++ + + return client +} + +// getRedisOptions pulls various configuration options based on the RedisUri format and converts them to go-redis's +// UniversalOptions fields. This function explicitly excludes fields related to TLS configuration, which is +// conditionally attached to this options struct before being converted to the specific type for the redis scheme being +// used, and only in scenarios where TLS is applicable (e.g. rediss://, redis+clusters://). +func getRedisOptions(uri *url.URL) *redis.UniversalOptions { + opts := &redis.UniversalOptions{} + + // Handle username/password + if password, ok := uri.User.Password(); ok { + opts.Password = password + // Username does not appear to be handled by redis.Options + opts.Username = uri.User.Username() + } else if uri.User.Username() != "" { + // assume this is the password + opts.Password = uri.User.Username() + } + + // Now handle the uri query sets + for k, v := range uri.Query() { + switch replacer.Replace(strings.ToLower(k)) { + case "addr": + opts.Addrs = append(opts.Addrs, v...) + case "addrs": + opts.Addrs = append(opts.Addrs, strings.Split(v[0], ",")...) + case "username": + opts.Username = v[0] + case "password": + opts.Password = v[0] + case "database": + fallthrough + case "db": + opts.DB, _ = strconv.Atoi(v[0]) + case "maxretries": + opts.MaxRetries, _ = strconv.Atoi(v[0]) + case "minretrybackoff": + opts.MinRetryBackoff = valToTimeDuration(v) + case "maxretrybackoff": + opts.MaxRetryBackoff = valToTimeDuration(v) + case "timeout": + timeout := valToTimeDuration(v) + if timeout != 0 { + if opts.DialTimeout == 0 { + opts.DialTimeout = timeout + } + if opts.ReadTimeout == 0 { + opts.ReadTimeout = timeout + } + } + case "dialtimeout": + opts.DialTimeout = valToTimeDuration(v) + case "readtimeout": + opts.ReadTimeout = valToTimeDuration(v) + case "writetimeout": + opts.WriteTimeout = valToTimeDuration(v) + case "poolsize": + opts.PoolSize, _ = strconv.Atoi(v[0]) + case "minidleconns": + opts.MinIdleConns, _ = strconv.Atoi(v[0]) + case "pooltimeout": + opts.PoolTimeout = valToTimeDuration(v) + case "maxredirects": + opts.MaxRedirects, _ = strconv.Atoi(v[0]) + case "readonly": + opts.ReadOnly, _ = strconv.ParseBool(v[0]) + case "routebylatency": + opts.RouteByLatency, _ = strconv.ParseBool(v[0]) + case "routerandomly": + opts.RouteRandomly, _ = strconv.ParseBool(v[0]) + case "sentinelmasterid": + fallthrough + case "mastername": + opts.MasterName = v[0] + case "sentinelusername": + opts.SentinelUsername = v[0] + case "sentinelpassword": + opts.SentinelPassword = v[0] + } + } + + if uri.Host != "" { + opts.Addrs = append(opts.Addrs, strings.Split(uri.Host, ",")...) + } + + // A redis connection string uses the path section of the URI in two different ways. In a TCP-based connection, the + // path will be a database index to automatically have the client SELECT. In a Unix socket connection, it will be the + // file path. We only want to try to coerce this to the database index when we're not expecting a file path so that + // the error log stays clean. + if uri.Path != "" && uri.Scheme != "redis+socket" { + if db, err := strconv.Atoi(uri.Path[1:]); err == nil { + opts.DB = db + } else { + log.Error("Provided database identifier '%s' is not a valid integer. Forgejo will ignore this option.", uri.Path) + } + } + + return opts +} + +// getRedisTlsOptions parses RedisUri TLS configuration parameters and converts them to the go TLS configuration +// equivalent fields. +func getRedisTLSOptions(uri *url.URL) *tls.Config { + tlsConfig := &tls.Config{} + + skipverify := uri.Query().Get("skipverify") + + if len(skipverify) > 0 { + skipverify, err := strconv.ParseBool(skipverify) + if err == nil { + tlsConfig.InsecureSkipVerify = skipverify + } + } + + insecureskipverify := uri.Query().Get("insecureskipverify") + + if len(insecureskipverify) > 0 { + insecureskipverify, err := strconv.ParseBool(insecureskipverify) + if err == nil { + tlsConfig.InsecureSkipVerify = insecureskipverify + } + } + + return tlsConfig +} diff --git a/modules/nosql/manager_redis_test.go b/modules/nosql/manager_redis_test.go new file mode 100644 index 0000000..d979ea0 --- /dev/null +++ b/modules/nosql/manager_redis_test.go @@ -0,0 +1,81 @@ +// Copyright 2022 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package nosql + +import ( + "net/url" + "testing" +) + +func TestRedisUsernameOpt(t *testing.T) { + uri, _ := url.Parse("redis://redis:password@myredis/0") + opts := getRedisOptions(uri) + + if opts.Username != "redis" { + t.Fail() + } +} + +func TestRedisPasswordOpt(t *testing.T) { + uri, _ := url.Parse("redis://redis:password@myredis/0") + opts := getRedisOptions(uri) + + if opts.Password != "password" { + t.Fail() + } +} + +func TestSkipVerifyOpt(t *testing.T) { + uri, _ := url.Parse("rediss://myredis/0?skipverify=true") + tlsConfig := getRedisTLSOptions(uri) + + if !tlsConfig.InsecureSkipVerify { + t.Fail() + } +} + +func TestInsecureSkipVerifyOpt(t *testing.T) { + uri, _ := url.Parse("rediss://myredis/0?insecureskipverify=true") + tlsConfig := getRedisTLSOptions(uri) + + if !tlsConfig.InsecureSkipVerify { + t.Fail() + } +} + +func TestRedisSentinelUsernameOpt(t *testing.T) { + uri, _ := url.Parse("redis+sentinel://redis:password@myredis/0?sentinelusername=suser&sentinelpassword=spass") + opts := getRedisOptions(uri).Failover() + + if opts.SentinelUsername != "suser" { + t.Fail() + } +} + +func TestRedisSentinelPasswordOpt(t *testing.T) { + uri, _ := url.Parse("redis+sentinel://redis:password@myredis/0?sentinelusername=suser&sentinelpassword=spass") + opts := getRedisOptions(uri).Failover() + + if opts.SentinelPassword != "spass" { + t.Fail() + } +} + +func TestRedisDatabaseIndexTcp(t *testing.T) { + uri, _ := url.Parse("redis://redis:password@myredis/12") + opts := getRedisOptions(uri) + + if opts.DB != 12 { + t.Fail() + } +} + +func TestRedisDatabaseIndexUnix(t *testing.T) { + uri, _ := url.Parse("redis+socket:///var/run/redis.sock?database=12") + opts := getRedisOptions(uri) + + if opts.DB != 12 { + t.Fail() + } +} diff --git a/modules/nosql/redis.go b/modules/nosql/redis.go new file mode 100644 index 0000000..52e8ff9 --- /dev/null +++ b/modules/nosql/redis.go @@ -0,0 +1,100 @@ +// Copyright 2020 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package nosql + +import ( + "net/url" + "strconv" + "strings" +) + +// The file contains common redis connection functions + +// ToRedisURI converts old style connections to a RedisURI +// +// A RedisURI matches the pattern: +// +// redis://[username:password@]host[:port][/database][?[option=value]*] +// rediss://[username:password@]host[:port][/database][?[option=value]*] +// redis+socket://[username:password@]path[/database][?[option=value]*] +// redis+sentinel://[password@]host1 [: port1][, host2 [:port2]][, hostN [:portN]][/ database][?[option=value]*] +// redis+cluster://[password@]host1 [: port1][, host2 [:port2]][, hostN [:portN]][/ database][?[option=value]*] +// +// We have previously used a URI like: +// addrs=127.0.0.1:6379 db=0 +// network=tcp,addr=127.0.0.1:6379,password=macaron,db=0,pool_size=100,idle_timeout=180 +// +// We need to convert this old style to the new style +func ToRedisURI(connection string) *url.URL { + uri, err := url.Parse(connection) + if err == nil && strings.HasPrefix(uri.Scheme, "redis") { + // OK we're going to assume that this is a reasonable redis URI + return uri + } + + // Let's set a nice default + uri, _ = url.Parse("redis://127.0.0.1:6379/0") + network := "tcp" + query := uri.Query() + + // OK so there are two types: Space delimited and Comma delimited + // Let's assume that we have a space delimited string - as this is the most common + fields := strings.Fields(connection) + if len(fields) == 1 { + // It's a comma delimited string, then... + fields = strings.Split(connection, ",") + } + for _, f := range fields { + items := strings.SplitN(f, "=", 2) + if len(items) < 2 { + continue + } + switch strings.ToLower(items[0]) { + case "network": + if items[1] == "unix" { + uri.Scheme = "redis+socket" + } + network = items[1] + case "addrs": + uri.Host = items[1] + // now we need to handle the clustering + if strings.Contains(items[1], ",") && network == "tcp" { + uri.Scheme = "redis+cluster" + } + case "addr": + uri.Host = items[1] + case "password": + uri.User = url.UserPassword(uri.User.Username(), items[1]) + case "username": + password, set := uri.User.Password() + if !set { + uri.User = url.User(items[1]) + } else { + uri.User = url.UserPassword(items[1], password) + } + case "db": + uri.Path = "/" + items[1] + case "idle_timeout": + _, err := strconv.Atoi(items[1]) + if err == nil { + query.Add("idle_timeout", items[1]+"s") + } else { + query.Add("idle_timeout", items[1]) + } + default: + // Other options become query params + query.Add(items[0], items[1]) + } + } + + // Finally we need to fix up the Host if we have a unix port + if uri.Scheme == "redis+socket" { + query.Set("db", uri.Path) + uri.Path = uri.Host + uri.Host = "" + } + uri.RawQuery = query.Encode() + + return uri +} diff --git a/modules/nosql/redis_test.go b/modules/nosql/redis_test.go new file mode 100644 index 0000000..43652e3 --- /dev/null +++ b/modules/nosql/redis_test.go @@ -0,0 +1,34 @@ +// Copyright 2020 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package nosql + +import ( + "testing" +) + +func TestToRedisURI(t *testing.T) { + tests := []struct { + name string + connection string + want string + }{ + { + name: "old_default", + connection: "addrs=127.0.0.1:6379 db=0", + want: "redis://127.0.0.1:6379/0", + }, + { + name: "old_macaron_session_default", + connection: "network=tcp,addr=127.0.0.1:6379,password=macaron,db=0,pool_size=100,idle_timeout=180", + want: "redis://:macaron@127.0.0.1:6379/0?idle_timeout=180s&pool_size=100", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := ToRedisURI(tt.connection); got == nil || got.String() != tt.want { + t.Errorf(`ToRedisURI(%q) = %s, want %s`, tt.connection, got.String(), tt.want) + } + }) + } +} -- cgit v1.2.3