summaryrefslogtreecommitdiffstats
path: root/modules/nosql
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--modules/nosql/leveldb.go24
-rw-r--r--modules/nosql/manager.go116
-rw-r--r--modules/nosql/manager_leveldb.go214
-rw-r--r--modules/nosql/manager_redis.go258
-rw-r--r--modules/nosql/manager_redis_test.go81
-rw-r--r--modules/nosql/redis.go100
-rw-r--r--modules/nosql/redis_test.go34
7 files changed, 827 insertions, 0 deletions
diff --git a/modules/nosql/leveldb.go b/modules/nosql/leveldb.go
new file mode 100644
index 0000000..aac5b21
--- /dev/null
+++ b/modules/nosql/leveldb.go
@@ -0,0 +1,24 @@
+// Copyright 2020 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package nosql
+
+import "net/url"
+
+// ToLevelDBURI converts old style connections to a LevelDBURI
+//
+// A LevelDBURI matches the pattern:
+//
+// leveldb://path[?[option=value]*]
+//
+// We have previously just provided the path but this prevent other options
+func ToLevelDBURI(connection string) *url.URL {
+ uri, err := url.Parse(connection)
+ if err == nil && uri.Scheme == "leveldb" {
+ return uri
+ }
+ uri, _ = url.Parse("leveldb://common")
+ uri.Host = ""
+ uri.Path = connection
+ return uri
+}
diff --git a/modules/nosql/manager.go b/modules/nosql/manager.go
new file mode 100644
index 0000000..0ba2158
--- /dev/null
+++ b/modules/nosql/manager.go
@@ -0,0 +1,116 @@
+// Copyright 2020 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package nosql
+
+import (
+ "context"
+ "strconv"
+ "sync"
+ "time"
+
+ "code.gitea.io/gitea/modules/process"
+
+ "github.com/redis/go-redis/v9"
+ "github.com/syndtr/goleveldb/leveldb"
+)
+
+var manager *Manager
+
+// Manager is the nosql connection manager
+type Manager struct {
+ ctx context.Context
+ finished context.CancelFunc
+ mutex sync.Mutex
+
+ RedisConnections map[string]*redisClientHolder
+ LevelDBConnections map[string]*levelDBHolder
+}
+
+// RedisClient is a subset of redis.UniversalClient, it exposes less methods
+// to avoid generating machine code for unused methods. New method definitions
+// should be copied from the definitions in the Redis library github.com/redis/go-redis.
+type RedisClient interface {
+ // redis.GenericCmdable
+ Del(ctx context.Context, keys ...string) *redis.IntCmd
+ Exists(ctx context.Context, keys ...string) *redis.IntCmd
+
+ // redis.ListCmdable
+ RPush(ctx context.Context, key string, values ...any) *redis.IntCmd
+ LPop(ctx context.Context, key string) *redis.StringCmd
+ LLen(ctx context.Context, key string) *redis.IntCmd
+
+ // redis.StringCmdable
+ Decr(ctx context.Context, key string) *redis.IntCmd
+ Incr(ctx context.Context, key string) *redis.IntCmd
+ Set(ctx context.Context, key string, value any, expiration time.Duration) *redis.StatusCmd
+ Get(ctx context.Context, key string) *redis.StringCmd
+
+ // redis.HashCmdable
+ HSet(ctx context.Context, key string, values ...any) *redis.IntCmd
+ HDel(ctx context.Context, key string, fields ...string) *redis.IntCmd
+ HKeys(ctx context.Context, key string) *redis.StringSliceCmd
+
+ // redis.SetCmdable
+ SAdd(ctx context.Context, key string, members ...any) *redis.IntCmd
+ SRem(ctx context.Context, key string, members ...any) *redis.IntCmd
+ SIsMember(ctx context.Context, key string, member any) *redis.BoolCmd
+
+ // redis.Cmdable
+ DBSize(ctx context.Context) *redis.IntCmd
+ FlushDB(ctx context.Context) *redis.StatusCmd
+ Ping(ctx context.Context) *redis.StatusCmd
+
+ // redis.UniversalClient
+ Close() error
+}
+
+type redisClientHolder struct {
+ RedisClient
+ name []string
+ count int64
+}
+
+func (r *redisClientHolder) Close() error {
+ return manager.CloseRedisClient(r.name[0])
+}
+
+type levelDBHolder struct {
+ name []string
+ count int64
+ db *leveldb.DB
+}
+
+func init() {
+ _ = GetManager()
+}
+
+// GetManager returns a Manager and initializes one as singleton is there's none yet
+func GetManager() *Manager {
+ if manager == nil {
+ ctx, _, finished := process.GetManager().AddTypedContext(context.Background(), "Service: NoSQL", process.SystemProcessType, false)
+ manager = &Manager{
+ ctx: ctx,
+ finished: finished,
+ RedisConnections: make(map[string]*redisClientHolder),
+ LevelDBConnections: make(map[string]*levelDBHolder),
+ }
+ }
+ return manager
+}
+
+func valToTimeDuration(vs []string) (result time.Duration) {
+ var err error
+ for _, v := range vs {
+ result, err = time.ParseDuration(v)
+ if err != nil {
+ var val int
+ val, err = strconv.Atoi(v)
+ result = time.Duration(val)
+ }
+ if err == nil {
+ return result
+ }
+ }
+ return result
+}
diff --git a/modules/nosql/manager_leveldb.go b/modules/nosql/manager_leveldb.go
new file mode 100644
index 0000000..4d2c90d
--- /dev/null
+++ b/modules/nosql/manager_leveldb.go
@@ -0,0 +1,214 @@
+// Copyright 2020 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package nosql
+
+import (
+ "fmt"
+ "path"
+ "runtime/pprof"
+ "strconv"
+ "strings"
+
+ "code.gitea.io/gitea/modules/log"
+
+ "github.com/syndtr/goleveldb/leveldb"
+ "github.com/syndtr/goleveldb/leveldb/errors"
+ "github.com/syndtr/goleveldb/leveldb/opt"
+)
+
+// CloseLevelDB closes a levelDB
+func (m *Manager) CloseLevelDB(connection string) error {
+ m.mutex.Lock()
+ defer m.mutex.Unlock()
+ db, ok := m.LevelDBConnections[connection]
+ if !ok {
+ // Try the full URI
+ uri := ToLevelDBURI(connection)
+ db, ok = m.LevelDBConnections[uri.String()]
+
+ if !ok {
+ // Try the datadir directly
+ dataDir := path.Join(uri.Host, uri.Path)
+
+ db, ok = m.LevelDBConnections[dataDir]
+ }
+ }
+ if !ok {
+ return nil
+ }
+
+ db.count--
+ if db.count > 0 {
+ return nil
+ }
+
+ for _, name := range db.name {
+ delete(m.LevelDBConnections, name)
+ }
+ return db.db.Close()
+}
+
+// GetLevelDB gets a levelDB for a particular connection
+func (m *Manager) GetLevelDB(connection string) (db *leveldb.DB, err error) {
+ // Because we want associate any goroutines created by this call to the main nosqldb context we need to
+ // wrap this in a goroutine labelled with the nosqldb context
+ done := make(chan struct{})
+ var recovered any
+ go func() {
+ defer func() {
+ recovered = recover()
+ if recovered != nil {
+ log.Critical("PANIC during GetLevelDB: %v\nStacktrace: %s", recovered, log.Stack(2))
+ }
+ close(done)
+ }()
+ pprof.SetGoroutineLabels(m.ctx)
+
+ db, err = m.getLevelDB(connection)
+ }()
+ <-done
+ if recovered != nil {
+ panic(recovered)
+ }
+ return db, err
+}
+
+func (m *Manager) getLevelDB(connection string) (*leveldb.DB, error) {
+ // Convert the provided connection description to the common format
+ uri := ToLevelDBURI(connection)
+
+ // Get the datadir
+ dataDir := path.Join(uri.Host, uri.Path)
+
+ m.mutex.Lock()
+ defer m.mutex.Unlock()
+ db, ok := m.LevelDBConnections[connection]
+ if ok {
+ db.count++
+
+ return db.db, nil
+ }
+
+ db, ok = m.LevelDBConnections[uri.String()]
+ if ok {
+ db.count++
+
+ return db.db, nil
+ }
+
+ // if there is already a connection to this leveldb reuse that
+ // NOTE: if there differing options then only the first leveldb connection will be used
+ db, ok = m.LevelDBConnections[dataDir]
+ if ok {
+ db.count++
+ log.Warn("Duplicate connection to level db: %s with different connection strings. Initial connection: %s. This connection: %s", dataDir, db.name[0], connection)
+ db.name = append(db.name, connection)
+ m.LevelDBConnections[connection] = db
+ return db.db, nil
+ }
+ db = &levelDBHolder{
+ name: []string{connection, uri.String(), dataDir},
+ }
+
+ opts := &opt.Options{}
+ for k, v := range uri.Query() {
+ switch replacer.Replace(strings.ToLower(k)) {
+ case "blockcachecapacity":
+ opts.BlockCacheCapacity, _ = strconv.Atoi(v[0])
+ case "blockcacheevictremoved":
+ opts.BlockCacheEvictRemoved, _ = strconv.ParseBool(v[0])
+ case "blockrestartinterval":
+ opts.BlockRestartInterval, _ = strconv.Atoi(v[0])
+ case "blocksize":
+ opts.BlockSize, _ = strconv.Atoi(v[0])
+ case "compactionexpandlimitfactor":
+ opts.CompactionExpandLimitFactor, _ = strconv.Atoi(v[0])
+ case "compactiongpoverlapsfactor":
+ opts.CompactionGPOverlapsFactor, _ = strconv.Atoi(v[0])
+ case "compactionl0trigger":
+ opts.CompactionL0Trigger, _ = strconv.Atoi(v[0])
+ case "compactionsourcelimitfactor":
+ opts.CompactionSourceLimitFactor, _ = strconv.Atoi(v[0])
+ case "compactiontablesize":
+ opts.CompactionTableSize, _ = strconv.Atoi(v[0])
+ case "compactiontablesizemultiplier":
+ opts.CompactionTableSizeMultiplier, _ = strconv.ParseFloat(v[0], 64)
+ case "compactiontablesizemultiplierperlevel":
+ for _, val := range v {
+ f, _ := strconv.ParseFloat(val, 64)
+ opts.CompactionTableSizeMultiplierPerLevel = append(opts.CompactionTableSizeMultiplierPerLevel, f)
+ }
+ case "compactiontotalsize":
+ opts.CompactionTotalSize, _ = strconv.Atoi(v[0])
+ case "compactiontotalsizemultiplier":
+ opts.CompactionTotalSizeMultiplier, _ = strconv.ParseFloat(v[0], 64)
+ case "compactiontotalsizemultiplierperlevel":
+ for _, val := range v {
+ f, _ := strconv.ParseFloat(val, 64)
+ opts.CompactionTotalSizeMultiplierPerLevel = append(opts.CompactionTotalSizeMultiplierPerLevel, f)
+ }
+ case "compression":
+ val, _ := strconv.Atoi(v[0])
+ opts.Compression = opt.Compression(val)
+ case "disablebufferpool":
+ opts.DisableBufferPool, _ = strconv.ParseBool(v[0])
+ case "disableblockcache":
+ opts.DisableBlockCache, _ = strconv.ParseBool(v[0])
+ case "disablecompactionbackoff":
+ opts.DisableCompactionBackoff, _ = strconv.ParseBool(v[0])
+ case "disablelargebatchtransaction":
+ opts.DisableLargeBatchTransaction, _ = strconv.ParseBool(v[0])
+ case "errorifexist":
+ opts.ErrorIfExist, _ = strconv.ParseBool(v[0])
+ case "errorifmissing":
+ opts.ErrorIfMissing, _ = strconv.ParseBool(v[0])
+ case "iteratorsamplingrate":
+ opts.IteratorSamplingRate, _ = strconv.Atoi(v[0])
+ case "nosync":
+ opts.NoSync, _ = strconv.ParseBool(v[0])
+ case "nowritemerge":
+ opts.NoWriteMerge, _ = strconv.ParseBool(v[0])
+ case "openfilescachecapacity":
+ opts.OpenFilesCacheCapacity, _ = strconv.Atoi(v[0])
+ case "readonly":
+ opts.ReadOnly, _ = strconv.ParseBool(v[0])
+ case "strict":
+ val, _ := strconv.Atoi(v[0])
+ opts.Strict = opt.Strict(val)
+ case "writebuffer":
+ opts.WriteBuffer, _ = strconv.Atoi(v[0])
+ case "writel0pausetrigger":
+ opts.WriteL0PauseTrigger, _ = strconv.Atoi(v[0])
+ case "writel0slowdowntrigger":
+ opts.WriteL0SlowdownTrigger, _ = strconv.Atoi(v[0])
+ case "clientname":
+ db.name = append(db.name, v[0])
+ }
+ }
+
+ var err error
+ db.db, err = leveldb.OpenFile(dataDir, opts)
+ if err != nil {
+ if !errors.IsCorrupted(err) {
+ if strings.Contains(err.Error(), "resource temporarily unavailable") {
+ err = fmt.Errorf("unable to lock level db at %s: %w", dataDir, err)
+ return nil, err
+ }
+
+ err = fmt.Errorf("unable to open level db at %s: %w", dataDir, err)
+ return nil, err
+ }
+ db.db, err = leveldb.RecoverFile(dataDir, opts)
+ }
+
+ if err != nil {
+ return nil, err
+ }
+
+ for _, name := range db.name {
+ m.LevelDBConnections[name] = db
+ }
+ db.count++
+ return db.db, nil
+}
diff --git a/modules/nosql/manager_redis.go b/modules/nosql/manager_redis.go
new file mode 100644
index 0000000..79a533b
--- /dev/null
+++ b/modules/nosql/manager_redis.go
@@ -0,0 +1,258 @@
+// Copyright 2020 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package nosql
+
+import (
+ "crypto/tls"
+ "net/url"
+ "path"
+ "runtime/pprof"
+ "strconv"
+ "strings"
+
+ "code.gitea.io/gitea/modules/log"
+
+ "github.com/redis/go-redis/v9"
+)
+
+var replacer = strings.NewReplacer("_", "", "-", "")
+
+// CloseRedisClient closes a redis client
+func (m *Manager) CloseRedisClient(connection string) error {
+ m.mutex.Lock()
+ defer m.mutex.Unlock()
+ client, ok := m.RedisConnections[connection]
+ if !ok {
+ connection = ToRedisURI(connection).String()
+ client, ok = m.RedisConnections[connection]
+ }
+ if !ok {
+ return nil
+ }
+
+ client.count--
+ if client.count > 0 {
+ return nil
+ }
+
+ for _, name := range client.name {
+ delete(m.RedisConnections, name)
+ }
+ return client.RedisClient.Close()
+}
+
+// GetRedisClient gets a redis client for a particular connection
+func (m *Manager) GetRedisClient(connection string) (client RedisClient) {
+ // Because we want associate any goroutines created by this call to the main nosqldb context we need to
+ // wrap this in a goroutine labelled with the nosqldb context
+ done := make(chan struct{})
+ var recovered any
+ go func() {
+ defer func() {
+ recovered = recover()
+ if recovered != nil {
+ log.Critical("PANIC during GetRedisClient: %v\nStacktrace: %s", recovered, log.Stack(2))
+ }
+ close(done)
+ }()
+ pprof.SetGoroutineLabels(m.ctx)
+
+ client = m.getRedisClient(connection)
+ }()
+ <-done
+ if recovered != nil {
+ panic(recovered)
+ }
+ return client
+}
+
+func (m *Manager) getRedisClient(connection string) RedisClient {
+ m.mutex.Lock()
+ defer m.mutex.Unlock()
+ client, ok := m.RedisConnections[connection]
+ if ok {
+ client.count++
+ return client
+ }
+
+ uri := ToRedisURI(connection)
+ client, ok = m.RedisConnections[uri.String()]
+ if ok {
+ client.count++
+ return client
+ }
+ client = &redisClientHolder{
+ name: []string{connection, uri.String()},
+ }
+
+ opts := getRedisOptions(uri)
+ tlsConfig := getRedisTLSOptions(uri)
+
+ clientName := uri.Query().Get("clientname")
+
+ if len(clientName) > 0 {
+ client.name = append(client.name, clientName)
+ }
+
+ switch uri.Scheme {
+ case "redis+sentinels":
+ fallthrough
+ case "rediss+sentinel":
+ opts.TLSConfig = tlsConfig
+ fallthrough
+ case "redis+sentinel":
+ client.RedisClient = redis.NewFailoverClient(opts.Failover())
+ case "redis+clusters":
+ fallthrough
+ case "rediss+cluster":
+ opts.TLSConfig = tlsConfig
+ fallthrough
+ case "redis+cluster":
+ client.RedisClient = redis.NewClusterClient(opts.Cluster())
+ case "redis+socket":
+ simpleOpts := opts.Simple()
+ simpleOpts.Network = "unix"
+ simpleOpts.Addr = path.Join(uri.Host, uri.Path)
+ client.RedisClient = redis.NewClient(simpleOpts)
+ case "rediss":
+ opts.TLSConfig = tlsConfig
+ fallthrough
+ case "redis":
+ client.RedisClient = redis.NewClient(opts.Simple())
+ default:
+ return nil
+ }
+
+ for _, name := range client.name {
+ m.RedisConnections[name] = client
+ }
+
+ client.count++
+
+ return client
+}
+
+// getRedisOptions pulls various configuration options based on the RedisUri format and converts them to go-redis's
+// UniversalOptions fields. This function explicitly excludes fields related to TLS configuration, which is
+// conditionally attached to this options struct before being converted to the specific type for the redis scheme being
+// used, and only in scenarios where TLS is applicable (e.g. rediss://, redis+clusters://).
+func getRedisOptions(uri *url.URL) *redis.UniversalOptions {
+ opts := &redis.UniversalOptions{}
+
+ // Handle username/password
+ if password, ok := uri.User.Password(); ok {
+ opts.Password = password
+ // Username does not appear to be handled by redis.Options
+ opts.Username = uri.User.Username()
+ } else if uri.User.Username() != "" {
+ // assume this is the password
+ opts.Password = uri.User.Username()
+ }
+
+ // Now handle the uri query sets
+ for k, v := range uri.Query() {
+ switch replacer.Replace(strings.ToLower(k)) {
+ case "addr":
+ opts.Addrs = append(opts.Addrs, v...)
+ case "addrs":
+ opts.Addrs = append(opts.Addrs, strings.Split(v[0], ",")...)
+ case "username":
+ opts.Username = v[0]
+ case "password":
+ opts.Password = v[0]
+ case "database":
+ fallthrough
+ case "db":
+ opts.DB, _ = strconv.Atoi(v[0])
+ case "maxretries":
+ opts.MaxRetries, _ = strconv.Atoi(v[0])
+ case "minretrybackoff":
+ opts.MinRetryBackoff = valToTimeDuration(v)
+ case "maxretrybackoff":
+ opts.MaxRetryBackoff = valToTimeDuration(v)
+ case "timeout":
+ timeout := valToTimeDuration(v)
+ if timeout != 0 {
+ if opts.DialTimeout == 0 {
+ opts.DialTimeout = timeout
+ }
+ if opts.ReadTimeout == 0 {
+ opts.ReadTimeout = timeout
+ }
+ }
+ case "dialtimeout":
+ opts.DialTimeout = valToTimeDuration(v)
+ case "readtimeout":
+ opts.ReadTimeout = valToTimeDuration(v)
+ case "writetimeout":
+ opts.WriteTimeout = valToTimeDuration(v)
+ case "poolsize":
+ opts.PoolSize, _ = strconv.Atoi(v[0])
+ case "minidleconns":
+ opts.MinIdleConns, _ = strconv.Atoi(v[0])
+ case "pooltimeout":
+ opts.PoolTimeout = valToTimeDuration(v)
+ case "maxredirects":
+ opts.MaxRedirects, _ = strconv.Atoi(v[0])
+ case "readonly":
+ opts.ReadOnly, _ = strconv.ParseBool(v[0])
+ case "routebylatency":
+ opts.RouteByLatency, _ = strconv.ParseBool(v[0])
+ case "routerandomly":
+ opts.RouteRandomly, _ = strconv.ParseBool(v[0])
+ case "sentinelmasterid":
+ fallthrough
+ case "mastername":
+ opts.MasterName = v[0]
+ case "sentinelusername":
+ opts.SentinelUsername = v[0]
+ case "sentinelpassword":
+ opts.SentinelPassword = v[0]
+ }
+ }
+
+ if uri.Host != "" {
+ opts.Addrs = append(opts.Addrs, strings.Split(uri.Host, ",")...)
+ }
+
+ // A redis connection string uses the path section of the URI in two different ways. In a TCP-based connection, the
+ // path will be a database index to automatically have the client SELECT. In a Unix socket connection, it will be the
+ // file path. We only want to try to coerce this to the database index when we're not expecting a file path so that
+ // the error log stays clean.
+ if uri.Path != "" && uri.Scheme != "redis+socket" {
+ if db, err := strconv.Atoi(uri.Path[1:]); err == nil {
+ opts.DB = db
+ } else {
+ log.Error("Provided database identifier '%s' is not a valid integer. Forgejo will ignore this option.", uri.Path)
+ }
+ }
+
+ return opts
+}
+
+// getRedisTlsOptions parses RedisUri TLS configuration parameters and converts them to the go TLS configuration
+// equivalent fields.
+func getRedisTLSOptions(uri *url.URL) *tls.Config {
+ tlsConfig := &tls.Config{}
+
+ skipverify := uri.Query().Get("skipverify")
+
+ if len(skipverify) > 0 {
+ skipverify, err := strconv.ParseBool(skipverify)
+ if err == nil {
+ tlsConfig.InsecureSkipVerify = skipverify
+ }
+ }
+
+ insecureskipverify := uri.Query().Get("insecureskipverify")
+
+ if len(insecureskipverify) > 0 {
+ insecureskipverify, err := strconv.ParseBool(insecureskipverify)
+ if err == nil {
+ tlsConfig.InsecureSkipVerify = insecureskipverify
+ }
+ }
+
+ return tlsConfig
+}
diff --git a/modules/nosql/manager_redis_test.go b/modules/nosql/manager_redis_test.go
new file mode 100644
index 0000000..d979ea0
--- /dev/null
+++ b/modules/nosql/manager_redis_test.go
@@ -0,0 +1,81 @@
+// Copyright 2022 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package nosql
+
+import (
+ "net/url"
+ "testing"
+)
+
+func TestRedisUsernameOpt(t *testing.T) {
+ uri, _ := url.Parse("redis://redis:password@myredis/0")
+ opts := getRedisOptions(uri)
+
+ if opts.Username != "redis" {
+ t.Fail()
+ }
+}
+
+func TestRedisPasswordOpt(t *testing.T) {
+ uri, _ := url.Parse("redis://redis:password@myredis/0")
+ opts := getRedisOptions(uri)
+
+ if opts.Password != "password" {
+ t.Fail()
+ }
+}
+
+func TestSkipVerifyOpt(t *testing.T) {
+ uri, _ := url.Parse("rediss://myredis/0?skipverify=true")
+ tlsConfig := getRedisTLSOptions(uri)
+
+ if !tlsConfig.InsecureSkipVerify {
+ t.Fail()
+ }
+}
+
+func TestInsecureSkipVerifyOpt(t *testing.T) {
+ uri, _ := url.Parse("rediss://myredis/0?insecureskipverify=true")
+ tlsConfig := getRedisTLSOptions(uri)
+
+ if !tlsConfig.InsecureSkipVerify {
+ t.Fail()
+ }
+}
+
+func TestRedisSentinelUsernameOpt(t *testing.T) {
+ uri, _ := url.Parse("redis+sentinel://redis:password@myredis/0?sentinelusername=suser&sentinelpassword=spass")
+ opts := getRedisOptions(uri).Failover()
+
+ if opts.SentinelUsername != "suser" {
+ t.Fail()
+ }
+}
+
+func TestRedisSentinelPasswordOpt(t *testing.T) {
+ uri, _ := url.Parse("redis+sentinel://redis:password@myredis/0?sentinelusername=suser&sentinelpassword=spass")
+ opts := getRedisOptions(uri).Failover()
+
+ if opts.SentinelPassword != "spass" {
+ t.Fail()
+ }
+}
+
+func TestRedisDatabaseIndexTcp(t *testing.T) {
+ uri, _ := url.Parse("redis://redis:password@myredis/12")
+ opts := getRedisOptions(uri)
+
+ if opts.DB != 12 {
+ t.Fail()
+ }
+}
+
+func TestRedisDatabaseIndexUnix(t *testing.T) {
+ uri, _ := url.Parse("redis+socket:///var/run/redis.sock?database=12")
+ opts := getRedisOptions(uri)
+
+ if opts.DB != 12 {
+ t.Fail()
+ }
+}
diff --git a/modules/nosql/redis.go b/modules/nosql/redis.go
new file mode 100644
index 0000000..52e8ff9
--- /dev/null
+++ b/modules/nosql/redis.go
@@ -0,0 +1,100 @@
+// Copyright 2020 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package nosql
+
+import (
+ "net/url"
+ "strconv"
+ "strings"
+)
+
+// The file contains common redis connection functions
+
+// ToRedisURI converts old style connections to a RedisURI
+//
+// A RedisURI matches the pattern:
+//
+// redis://[username:password@]host[:port][/database][?[option=value]*]
+// rediss://[username:password@]host[:port][/database][?[option=value]*]
+// redis+socket://[username:password@]path[/database][?[option=value]*]
+// redis+sentinel://[password@]host1 [: port1][, host2 [:port2]][, hostN [:portN]][/ database][?[option=value]*]
+// redis+cluster://[password@]host1 [: port1][, host2 [:port2]][, hostN [:portN]][/ database][?[option=value]*]
+//
+// We have previously used a URI like:
+// addrs=127.0.0.1:6379 db=0
+// network=tcp,addr=127.0.0.1:6379,password=macaron,db=0,pool_size=100,idle_timeout=180
+//
+// We need to convert this old style to the new style
+func ToRedisURI(connection string) *url.URL {
+ uri, err := url.Parse(connection)
+ if err == nil && strings.HasPrefix(uri.Scheme, "redis") {
+ // OK we're going to assume that this is a reasonable redis URI
+ return uri
+ }
+
+ // Let's set a nice default
+ uri, _ = url.Parse("redis://127.0.0.1:6379/0")
+ network := "tcp"
+ query := uri.Query()
+
+ // OK so there are two types: Space delimited and Comma delimited
+ // Let's assume that we have a space delimited string - as this is the most common
+ fields := strings.Fields(connection)
+ if len(fields) == 1 {
+ // It's a comma delimited string, then...
+ fields = strings.Split(connection, ",")
+ }
+ for _, f := range fields {
+ items := strings.SplitN(f, "=", 2)
+ if len(items) < 2 {
+ continue
+ }
+ switch strings.ToLower(items[0]) {
+ case "network":
+ if items[1] == "unix" {
+ uri.Scheme = "redis+socket"
+ }
+ network = items[1]
+ case "addrs":
+ uri.Host = items[1]
+ // now we need to handle the clustering
+ if strings.Contains(items[1], ",") && network == "tcp" {
+ uri.Scheme = "redis+cluster"
+ }
+ case "addr":
+ uri.Host = items[1]
+ case "password":
+ uri.User = url.UserPassword(uri.User.Username(), items[1])
+ case "username":
+ password, set := uri.User.Password()
+ if !set {
+ uri.User = url.User(items[1])
+ } else {
+ uri.User = url.UserPassword(items[1], password)
+ }
+ case "db":
+ uri.Path = "/" + items[1]
+ case "idle_timeout":
+ _, err := strconv.Atoi(items[1])
+ if err == nil {
+ query.Add("idle_timeout", items[1]+"s")
+ } else {
+ query.Add("idle_timeout", items[1])
+ }
+ default:
+ // Other options become query params
+ query.Add(items[0], items[1])
+ }
+ }
+
+ // Finally we need to fix up the Host if we have a unix port
+ if uri.Scheme == "redis+socket" {
+ query.Set("db", uri.Path)
+ uri.Path = uri.Host
+ uri.Host = ""
+ }
+ uri.RawQuery = query.Encode()
+
+ return uri
+}
diff --git a/modules/nosql/redis_test.go b/modules/nosql/redis_test.go
new file mode 100644
index 0000000..43652e3
--- /dev/null
+++ b/modules/nosql/redis_test.go
@@ -0,0 +1,34 @@
+// Copyright 2020 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package nosql
+
+import (
+ "testing"
+)
+
+func TestToRedisURI(t *testing.T) {
+ tests := []struct {
+ name string
+ connection string
+ want string
+ }{
+ {
+ name: "old_default",
+ connection: "addrs=127.0.0.1:6379 db=0",
+ want: "redis://127.0.0.1:6379/0",
+ },
+ {
+ name: "old_macaron_session_default",
+ connection: "network=tcp,addr=127.0.0.1:6379,password=macaron,db=0,pool_size=100,idle_timeout=180",
+ want: "redis://:macaron@127.0.0.1:6379/0?idle_timeout=180s&pool_size=100",
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := ToRedisURI(tt.connection); got == nil || got.String() != tt.want {
+ t.Errorf(`ToRedisURI(%q) = %s, want %s`, tt.connection, got.String(), tt.want)
+ }
+ })
+ }
+}