summaryrefslogtreecommitdiffstats
path: root/modules/cache
diff options
context:
space:
mode:
Diffstat (limited to 'modules/cache')
-rw-r--r--modules/cache/cache.go184
-rw-r--r--modules/cache/cache_redis.go161
-rw-r--r--modules/cache/cache_test.go150
-rw-r--r--modules/cache/cache_twoqueue.go208
-rw-r--r--modules/cache/context.go181
-rw-r--r--modules/cache/context_test.go79
6 files changed, 963 insertions, 0 deletions
diff --git a/modules/cache/cache.go b/modules/cache/cache.go
new file mode 100644
index 0000000..2148e02
--- /dev/null
+++ b/modules/cache/cache.go
@@ -0,0 +1,184 @@
+// Copyright 2017 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package cache
+
+import (
+ "fmt"
+ "strconv"
+ "time"
+
+ "code.gitea.io/gitea/modules/setting"
+
+ mc "code.forgejo.org/go-chi/cache"
+
+ _ "code.forgejo.org/go-chi/cache/memcache" // memcache plugin for cache
+)
+
+var conn mc.Cache
+
+func newCache(cacheConfig setting.Cache) (mc.Cache, error) {
+ return mc.NewCacher(mc.Options{
+ Adapter: cacheConfig.Adapter,
+ AdapterConfig: cacheConfig.Conn,
+ Interval: cacheConfig.Interval,
+ })
+}
+
+// Init start cache service
+func Init() error {
+ var err error
+
+ if conn == nil {
+ if conn, err = newCache(setting.CacheService.Cache); err != nil {
+ return err
+ }
+ if err = conn.Ping(); err != nil {
+ return err
+ }
+ }
+
+ return err
+}
+
+const (
+ testCacheKey = "DefaultCache.TestKey"
+ SlowCacheThreshold = 100 * time.Microsecond
+)
+
+func Test() (time.Duration, error) {
+ if conn == nil {
+ return 0, fmt.Errorf("default cache not initialized")
+ }
+
+ testData := fmt.Sprintf("%x", make([]byte, 500))
+
+ start := time.Now()
+
+ if err := conn.Delete(testCacheKey); err != nil {
+ return 0, fmt.Errorf("expect cache to delete data based on key if exist but got: %w", err)
+ }
+ if err := conn.Put(testCacheKey, testData, 10); err != nil {
+ return 0, fmt.Errorf("expect cache to store data but got: %w", err)
+ }
+ testVal := conn.Get(testCacheKey)
+ if testVal == nil {
+ return 0, fmt.Errorf("expect cache hit but got none")
+ }
+ if testVal != testData {
+ return 0, fmt.Errorf("expect cache to return same value as stored but got other")
+ }
+
+ return time.Since(start), nil
+}
+
+// GetCache returns the currently configured cache
+func GetCache() mc.Cache {
+ return conn
+}
+
+// GetString returns the key value from cache with callback when no key exists in cache
+func GetString(key string, getFunc func() (string, error)) (string, error) {
+ if conn == nil || setting.CacheService.TTL == 0 {
+ return getFunc()
+ }
+
+ cached := conn.Get(key)
+
+ if cached == nil {
+ value, err := getFunc()
+ if err != nil {
+ return value, err
+ }
+ return value, conn.Put(key, value, setting.CacheService.TTLSeconds())
+ }
+
+ if value, ok := cached.(string); ok {
+ return value, nil
+ }
+
+ if stringer, ok := cached.(fmt.Stringer); ok {
+ return stringer.String(), nil
+ }
+
+ return fmt.Sprintf("%s", cached), nil
+}
+
+// GetInt returns key value from cache with callback when no key exists in cache
+func GetInt(key string, getFunc func() (int, error)) (int, error) {
+ if conn == nil || setting.CacheService.TTL == 0 {
+ return getFunc()
+ }
+
+ cached := conn.Get(key)
+
+ if cached == nil {
+ value, err := getFunc()
+ if err != nil {
+ return value, err
+ }
+
+ return value, conn.Put(key, value, setting.CacheService.TTLSeconds())
+ }
+
+ switch v := cached.(type) {
+ case int:
+ return v, nil
+ case string:
+ value, err := strconv.Atoi(v)
+ if err != nil {
+ return 0, err
+ }
+ return value, nil
+ default:
+ value, err := getFunc()
+ if err != nil {
+ return value, err
+ }
+ return value, conn.Put(key, value, setting.CacheService.TTLSeconds())
+ }
+}
+
+// GetInt64 returns key value from cache with callback when no key exists in cache
+func GetInt64(key string, getFunc func() (int64, error)) (int64, error) {
+ if conn == nil || setting.CacheService.TTL == 0 {
+ return getFunc()
+ }
+
+ cached := conn.Get(key)
+
+ if cached == nil {
+ value, err := getFunc()
+ if err != nil {
+ return value, err
+ }
+
+ return value, conn.Put(key, value, setting.CacheService.TTLSeconds())
+ }
+
+ switch v := conn.Get(key).(type) {
+ case int64:
+ return v, nil
+ case string:
+ value, err := strconv.ParseInt(v, 10, 64)
+ if err != nil {
+ return 0, err
+ }
+ return value, nil
+ default:
+ value, err := getFunc()
+ if err != nil {
+ return value, err
+ }
+
+ return value, conn.Put(key, value, setting.CacheService.TTLSeconds())
+ }
+}
+
+// Remove key from cache
+func Remove(key string) {
+ if conn == nil {
+ return
+ }
+ _ = conn.Delete(key)
+}
diff --git a/modules/cache/cache_redis.go b/modules/cache/cache_redis.go
new file mode 100644
index 0000000..4c243b2
--- /dev/null
+++ b/modules/cache/cache_redis.go
@@ -0,0 +1,161 @@
+// Copyright 2020 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package cache
+
+import (
+ "fmt"
+ "strconv"
+ "time"
+
+ "code.gitea.io/gitea/modules/graceful"
+ "code.gitea.io/gitea/modules/nosql"
+
+ "code.forgejo.org/go-chi/cache"
+)
+
+// RedisCacher represents a redis cache adapter implementation.
+type RedisCacher struct {
+ c nosql.RedisClient
+ prefix string
+ hsetName string
+ occupyMode bool
+}
+
+// toStr convert string/int/int64 interface to string. it's only used by the RedisCacher.Put internally
+func toStr(v any) string {
+ if v == nil {
+ return ""
+ }
+ switch v := v.(type) {
+ case string:
+ return v
+ case []byte:
+ return string(v)
+ case int:
+ return strconv.FormatInt(int64(v), 10)
+ case int64:
+ return strconv.FormatInt(v, 10)
+ default:
+ return fmt.Sprint(v) // as what the old com.ToStr does in most cases
+ }
+}
+
+// Put puts value (string type) into cache with key and expire time.
+// If expired is 0, it lives forever.
+func (c *RedisCacher) Put(key string, val any, expire int64) error {
+ // this function is not well-designed, it only puts string values into cache
+ key = c.prefix + key
+ if expire == 0 {
+ if err := c.c.Set(graceful.GetManager().HammerContext(), key, toStr(val), 0).Err(); err != nil {
+ return err
+ }
+ } else {
+ dur := time.Duration(expire) * time.Second
+ if err := c.c.Set(graceful.GetManager().HammerContext(), key, toStr(val), dur).Err(); err != nil {
+ return err
+ }
+ }
+
+ if c.occupyMode {
+ return nil
+ }
+ return c.c.HSet(graceful.GetManager().HammerContext(), c.hsetName, key, "0").Err()
+}
+
+// Get gets cached value by given key.
+func (c *RedisCacher) Get(key string) any {
+ val, err := c.c.Get(graceful.GetManager().HammerContext(), c.prefix+key).Result()
+ if err != nil {
+ return nil
+ }
+ return val
+}
+
+// Delete deletes cached value by given key.
+func (c *RedisCacher) Delete(key string) error {
+ key = c.prefix + key
+ if err := c.c.Del(graceful.GetManager().HammerContext(), key).Err(); err != nil {
+ return err
+ }
+
+ if c.occupyMode {
+ return nil
+ }
+ return c.c.HDel(graceful.GetManager().HammerContext(), c.hsetName, key).Err()
+}
+
+// Incr increases cached int-type value by given key as a counter.
+func (c *RedisCacher) Incr(key string) error {
+ if !c.IsExist(key) {
+ return fmt.Errorf("key '%s' not exist", key)
+ }
+ return c.c.Incr(graceful.GetManager().HammerContext(), c.prefix+key).Err()
+}
+
+// Decr decreases cached int-type value by given key as a counter.
+func (c *RedisCacher) Decr(key string) error {
+ if !c.IsExist(key) {
+ return fmt.Errorf("key '%s' not exist", key)
+ }
+ return c.c.Decr(graceful.GetManager().HammerContext(), c.prefix+key).Err()
+}
+
+// IsExist returns true if cached value exists.
+func (c *RedisCacher) IsExist(key string) bool {
+ if c.c.Exists(graceful.GetManager().HammerContext(), c.prefix+key).Val() == 1 {
+ return true
+ }
+
+ if !c.occupyMode {
+ c.c.HDel(graceful.GetManager().HammerContext(), c.hsetName, c.prefix+key)
+ }
+ return false
+}
+
+// Flush deletes all cached data.
+func (c *RedisCacher) Flush() error {
+ if c.occupyMode {
+ return c.c.FlushDB(graceful.GetManager().HammerContext()).Err()
+ }
+
+ keys, err := c.c.HKeys(graceful.GetManager().HammerContext(), c.hsetName).Result()
+ if err != nil {
+ return err
+ }
+ if err = c.c.Del(graceful.GetManager().HammerContext(), keys...).Err(); err != nil {
+ return err
+ }
+ return c.c.Del(graceful.GetManager().HammerContext(), c.hsetName).Err()
+}
+
+// StartAndGC starts GC routine based on config string settings.
+// AdapterConfig: network=tcp,addr=:6379,password=macaron,db=0,pool_size=100,idle_timeout=180,hset_name=MacaronCache,prefix=cache:
+func (c *RedisCacher) StartAndGC(opts cache.Options) error {
+ c.hsetName = "MacaronCache"
+ c.occupyMode = opts.OccupyMode
+
+ uri := nosql.ToRedisURI(opts.AdapterConfig)
+
+ c.c = nosql.GetManager().GetRedisClient(uri.String())
+
+ for k, v := range uri.Query() {
+ switch k {
+ case "hset_name":
+ c.hsetName = v[0]
+ case "prefix":
+ c.prefix = v[0]
+ }
+ }
+
+ return c.c.Ping(graceful.GetManager().HammerContext()).Err()
+}
+
+// Ping tests if the cache is alive.
+func (c *RedisCacher) Ping() error {
+ return c.c.Ping(graceful.GetManager().HammerContext()).Err()
+}
+
+func init() {
+ cache.Register("redis", &RedisCacher{})
+}
diff --git a/modules/cache/cache_test.go b/modules/cache/cache_test.go
new file mode 100644
index 0000000..8bc986f
--- /dev/null
+++ b/modules/cache/cache_test.go
@@ -0,0 +1,150 @@
+// Copyright 2021 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package cache
+
+import (
+ "fmt"
+ "testing"
+ "time"
+
+ "code.gitea.io/gitea/modules/setting"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func createTestCache() {
+ conn, _ = newCache(setting.Cache{
+ Adapter: "memory",
+ TTL: time.Minute,
+ })
+ setting.CacheService.TTL = 24 * time.Hour
+}
+
+func TestNewContext(t *testing.T) {
+ require.NoError(t, Init())
+
+ setting.CacheService.Cache = setting.Cache{Adapter: "redis", Conn: "some random string"}
+ con, err := newCache(setting.Cache{
+ Adapter: "rand",
+ Conn: "false conf",
+ Interval: 100,
+ })
+ require.Error(t, err)
+ assert.Nil(t, con)
+}
+
+func TestGetCache(t *testing.T) {
+ createTestCache()
+
+ assert.NotNil(t, GetCache())
+}
+
+func TestGetString(t *testing.T) {
+ createTestCache()
+
+ data, err := GetString("key", func() (string, error) {
+ return "", fmt.Errorf("some error")
+ })
+ require.Error(t, err)
+ assert.Equal(t, "", data)
+
+ data, err = GetString("key", func() (string, error) {
+ return "", nil
+ })
+ require.NoError(t, err)
+ assert.Equal(t, "", data)
+
+ data, err = GetString("key", func() (string, error) {
+ return "some data", nil
+ })
+ require.NoError(t, err)
+ assert.Equal(t, "", data)
+ Remove("key")
+
+ data, err = GetString("key", func() (string, error) {
+ return "some data", nil
+ })
+ require.NoError(t, err)
+ assert.Equal(t, "some data", data)
+
+ data, err = GetString("key", func() (string, error) {
+ return "", fmt.Errorf("some error")
+ })
+ require.NoError(t, err)
+ assert.Equal(t, "some data", data)
+ Remove("key")
+}
+
+func TestGetInt(t *testing.T) {
+ createTestCache()
+
+ data, err := GetInt("key", func() (int, error) {
+ return 0, fmt.Errorf("some error")
+ })
+ require.Error(t, err)
+ assert.Equal(t, 0, data)
+
+ data, err = GetInt("key", func() (int, error) {
+ return 0, nil
+ })
+ require.NoError(t, err)
+ assert.Equal(t, 0, data)
+
+ data, err = GetInt("key", func() (int, error) {
+ return 100, nil
+ })
+ require.NoError(t, err)
+ assert.Equal(t, 0, data)
+ Remove("key")
+
+ data, err = GetInt("key", func() (int, error) {
+ return 100, nil
+ })
+ require.NoError(t, err)
+ assert.Equal(t, 100, data)
+
+ data, err = GetInt("key", func() (int, error) {
+ return 0, fmt.Errorf("some error")
+ })
+ require.NoError(t, err)
+ assert.Equal(t, 100, data)
+ Remove("key")
+}
+
+func TestGetInt64(t *testing.T) {
+ createTestCache()
+
+ data, err := GetInt64("key", func() (int64, error) {
+ return 0, fmt.Errorf("some error")
+ })
+ require.Error(t, err)
+ assert.EqualValues(t, 0, data)
+
+ data, err = GetInt64("key", func() (int64, error) {
+ return 0, nil
+ })
+ require.NoError(t, err)
+ assert.EqualValues(t, 0, data)
+
+ data, err = GetInt64("key", func() (int64, error) {
+ return 100, nil
+ })
+ require.NoError(t, err)
+ assert.EqualValues(t, 0, data)
+ Remove("key")
+
+ data, err = GetInt64("key", func() (int64, error) {
+ return 100, nil
+ })
+ require.NoError(t, err)
+ assert.EqualValues(t, 100, data)
+
+ data, err = GetInt64("key", func() (int64, error) {
+ return 0, fmt.Errorf("some error")
+ })
+ require.NoError(t, err)
+ assert.EqualValues(t, 100, data)
+ Remove("key")
+}
diff --git a/modules/cache/cache_twoqueue.go b/modules/cache/cache_twoqueue.go
new file mode 100644
index 0000000..c15ed52
--- /dev/null
+++ b/modules/cache/cache_twoqueue.go
@@ -0,0 +1,208 @@
+// Copyright 2021 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package cache
+
+import (
+ "strconv"
+ "sync"
+ "time"
+
+ "code.gitea.io/gitea/modules/json"
+
+ mc "code.forgejo.org/go-chi/cache"
+ lru "github.com/hashicorp/golang-lru/v2"
+)
+
+// TwoQueueCache represents a LRU 2Q cache adapter implementation
+type TwoQueueCache struct {
+ lock sync.Mutex
+ cache *lru.TwoQueueCache[string, any]
+ interval int
+}
+
+// TwoQueueCacheConfig describes the configuration for TwoQueueCache
+type TwoQueueCacheConfig struct {
+ Size int `ini:"SIZE" json:"size"`
+ RecentRatio float64 `ini:"RECENT_RATIO" json:"recent_ratio"`
+ GhostRatio float64 `ini:"GHOST_RATIO" json:"ghost_ratio"`
+}
+
+// MemoryItem represents a memory cache item.
+type MemoryItem struct {
+ Val any
+ Created int64
+ Timeout int64
+}
+
+func (item *MemoryItem) hasExpired() bool {
+ return item.Timeout > 0 &&
+ (time.Now().Unix()-item.Created) >= item.Timeout
+}
+
+var _ mc.Cache = &TwoQueueCache{}
+
+// Put puts value into cache with key and expire time.
+func (c *TwoQueueCache) Put(key string, val any, timeout int64) error {
+ item := &MemoryItem{
+ Val: val,
+ Created: time.Now().Unix(),
+ Timeout: timeout,
+ }
+ c.lock.Lock()
+ defer c.lock.Unlock()
+ c.cache.Add(key, item)
+ return nil
+}
+
+// Get gets cached value by given key.
+func (c *TwoQueueCache) Get(key string) any {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+ cached, ok := c.cache.Get(key)
+ if !ok {
+ return nil
+ }
+ item, ok := cached.(*MemoryItem)
+
+ if !ok || item.hasExpired() {
+ c.cache.Remove(key)
+ return nil
+ }
+
+ return item.Val
+}
+
+// Delete deletes cached value by given key.
+func (c *TwoQueueCache) Delete(key string) error {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+ c.cache.Remove(key)
+ return nil
+}
+
+// Incr increases cached int-type value by given key as a counter.
+func (c *TwoQueueCache) Incr(key string) error {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+ cached, ok := c.cache.Get(key)
+ if !ok {
+ return nil
+ }
+ item, ok := cached.(*MemoryItem)
+
+ if !ok || item.hasExpired() {
+ c.cache.Remove(key)
+ return nil
+ }
+
+ var err error
+ item.Val, err = mc.Incr(item.Val)
+ return err
+}
+
+// Decr decreases cached int-type value by given key as a counter.
+func (c *TwoQueueCache) Decr(key string) error {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+ cached, ok := c.cache.Get(key)
+ if !ok {
+ return nil
+ }
+ item, ok := cached.(*MemoryItem)
+
+ if !ok || item.hasExpired() {
+ c.cache.Remove(key)
+ return nil
+ }
+
+ var err error
+ item.Val, err = mc.Decr(item.Val)
+ return err
+}
+
+// IsExist returns true if cached value exists.
+func (c *TwoQueueCache) IsExist(key string) bool {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+ cached, ok := c.cache.Peek(key)
+ if !ok {
+ return false
+ }
+ item, ok := cached.(*MemoryItem)
+ if !ok || item.hasExpired() {
+ c.cache.Remove(key)
+ return false
+ }
+
+ return true
+}
+
+// Flush deletes all cached data.
+func (c *TwoQueueCache) Flush() error {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+ c.cache.Purge()
+ return nil
+}
+
+func (c *TwoQueueCache) checkAndInvalidate(key string) {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+ cached, ok := c.cache.Peek(key)
+ if !ok {
+ return
+ }
+ item, ok := cached.(*MemoryItem)
+ if !ok || item.hasExpired() {
+ c.cache.Remove(key)
+ }
+}
+
+func (c *TwoQueueCache) startGC() {
+ if c.interval < 0 {
+ return
+ }
+ for _, key := range c.cache.Keys() {
+ c.checkAndInvalidate(key)
+ }
+ time.AfterFunc(time.Duration(c.interval)*time.Second, c.startGC)
+}
+
+// StartAndGC starts GC routine based on config string settings.
+func (c *TwoQueueCache) StartAndGC(opts mc.Options) error {
+ var err error
+ size := 50000
+ if opts.AdapterConfig != "" {
+ size, err = strconv.Atoi(opts.AdapterConfig)
+ }
+ if err != nil {
+ if !json.Valid([]byte(opts.AdapterConfig)) {
+ return err
+ }
+
+ cfg := &TwoQueueCacheConfig{
+ Size: 50000,
+ RecentRatio: lru.Default2QRecentRatio,
+ GhostRatio: lru.Default2QGhostEntries,
+ }
+ _ = json.Unmarshal([]byte(opts.AdapterConfig), cfg)
+ c.cache, err = lru.New2QParams[string, any](cfg.Size, cfg.RecentRatio, cfg.GhostRatio)
+ } else {
+ c.cache, err = lru.New2Q[string, any](size)
+ }
+ c.interval = opts.Interval
+ if c.interval > 0 {
+ go c.startGC()
+ }
+ return err
+}
+
+// Ping tests if the cache is alive.
+func (c *TwoQueueCache) Ping() error {
+ return mc.GenericPing(c)
+}
+
+func init() {
+ mc.Register("twoqueue", &TwoQueueCache{})
+}
diff --git a/modules/cache/context.go b/modules/cache/context.go
new file mode 100644
index 0000000..f9bdf52
--- /dev/null
+++ b/modules/cache/context.go
@@ -0,0 +1,181 @@
+// Copyright 2022 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package cache
+
+import (
+ "context"
+ "sync"
+ "time"
+
+ "code.gitea.io/gitea/modules/log"
+)
+
+// cacheContext is a context that can be used to cache data in a request level context
+// This is useful for caching data that is expensive to calculate and is likely to be
+// used multiple times in a request.
+type cacheContext struct {
+ data map[any]map[any]any
+ lock sync.RWMutex
+ created time.Time
+ discard bool
+}
+
+func (cc *cacheContext) Get(tp, key any) any {
+ cc.lock.RLock()
+ defer cc.lock.RUnlock()
+ return cc.data[tp][key]
+}
+
+func (cc *cacheContext) Put(tp, key, value any) {
+ cc.lock.Lock()
+ defer cc.lock.Unlock()
+
+ if cc.discard {
+ return
+ }
+
+ d := cc.data[tp]
+ if d == nil {
+ d = make(map[any]any)
+ cc.data[tp] = d
+ }
+ d[key] = value
+}
+
+func (cc *cacheContext) Delete(tp, key any) {
+ cc.lock.Lock()
+ defer cc.lock.Unlock()
+ delete(cc.data[tp], key)
+}
+
+func (cc *cacheContext) Discard() {
+ cc.lock.Lock()
+ defer cc.lock.Unlock()
+ cc.data = nil
+ cc.discard = true
+}
+
+func (cc *cacheContext) isDiscard() bool {
+ cc.lock.RLock()
+ defer cc.lock.RUnlock()
+ return cc.discard
+}
+
+// cacheContextLifetime is the max lifetime of cacheContext.
+// Since cacheContext is used to cache data in a request level context, 5 minutes is enough.
+// If a cacheContext is used more than 5 minutes, it's probably misuse.
+const cacheContextLifetime = 5 * time.Minute
+
+var timeNow = time.Now
+
+func (cc *cacheContext) Expired() bool {
+ return timeNow().Sub(cc.created) > cacheContextLifetime
+}
+
+type cacheContextType = struct{ useless struct{} }
+
+var cacheContextKey = cacheContextType{useless: struct{}{}}
+
+/*
+Since there are both WithCacheContext and WithNoCacheContext,
+it may be confusing when there is nesting.
+
+Some cases to explain the design:
+
+When:
+- A, B or C means a cache context.
+- A', B' or C' means a discard cache context.
+- ctx means context.Backgrand().
+- A(ctx) means a cache context with ctx as the parent context.
+- B(A(ctx)) means a cache context with A(ctx) as the parent context.
+- With is alias of WithCacheContext.
+- WithNo is alias of WithNoCacheContext.
+
+So:
+- With(ctx) -> A(ctx)
+- With(With(ctx)) -> A(ctx), not B(A(ctx)), always reuse parent cache context if possible.
+- With(With(With(ctx))) -> A(ctx), not C(B(A(ctx))), ditto.
+- WithNo(ctx) -> ctx, not A'(ctx), don't create new cache context if we don't have to.
+- WithNo(With(ctx)) -> A'(ctx)
+- WithNo(WithNo(With(ctx))) -> A'(ctx), not B'(A'(ctx)), don't create new cache context if we don't have to.
+- With(WithNo(With(ctx))) -> B(A'(ctx)), not A(ctx), never reuse a discard cache context.
+- WithNo(With(WithNo(With(ctx)))) -> B'(A'(ctx))
+- With(WithNo(With(WithNo(With(ctx))))) -> C(B'(A'(ctx))), so there's always only one not-discard cache context.
+*/
+
+func WithCacheContext(ctx context.Context) context.Context {
+ if c, ok := ctx.Value(cacheContextKey).(*cacheContext); ok {
+ if !c.isDiscard() {
+ // reuse parent context
+ return ctx
+ }
+ }
+ return context.WithValue(ctx, cacheContextKey, &cacheContext{
+ data: make(map[any]map[any]any),
+ created: timeNow(),
+ })
+}
+
+func WithNoCacheContext(ctx context.Context) context.Context {
+ if c, ok := ctx.Value(cacheContextKey).(*cacheContext); ok {
+ // The caller want to run long-life tasks, but the parent context is a cache context.
+ // So we should disable and clean the cache data, or it will be kept in memory for a long time.
+ c.Discard()
+ return ctx
+ }
+
+ return ctx
+}
+
+func GetContextData(ctx context.Context, tp, key any) any {
+ if c, ok := ctx.Value(cacheContextKey).(*cacheContext); ok {
+ if c.Expired() {
+ // The warning means that the cache context is misused for long-life task,
+ // it can be resolved with WithNoCacheContext(ctx).
+ log.Warn("cache context is expired, is highly likely to be misused for long-life tasks: %v", c)
+ return nil
+ }
+ return c.Get(tp, key)
+ }
+ return nil
+}
+
+func SetContextData(ctx context.Context, tp, key, value any) {
+ if c, ok := ctx.Value(cacheContextKey).(*cacheContext); ok {
+ if c.Expired() {
+ // The warning means that the cache context is misused for long-life task,
+ // it can be resolved with WithNoCacheContext(ctx).
+ log.Warn("cache context is expired, is highly likely to be misused for long-life tasks: %v", c)
+ return
+ }
+ c.Put(tp, key, value)
+ return
+ }
+}
+
+func RemoveContextData(ctx context.Context, tp, key any) {
+ if c, ok := ctx.Value(cacheContextKey).(*cacheContext); ok {
+ if c.Expired() {
+ // The warning means that the cache context is misused for long-life task,
+ // it can be resolved with WithNoCacheContext(ctx).
+ log.Warn("cache context is expired, is highly likely to be misused for long-life tasks: %v", c)
+ return
+ }
+ c.Delete(tp, key)
+ }
+}
+
+// GetWithContextCache returns the cache value of the given key in the given context.
+func GetWithContextCache[T any](ctx context.Context, cacheGroupKey string, cacheTargetID any, f func() (T, error)) (T, error) {
+ v := GetContextData(ctx, cacheGroupKey, cacheTargetID)
+ if vv, ok := v.(T); ok {
+ return vv, nil
+ }
+ t, err := f()
+ if err != nil {
+ return t, err
+ }
+ SetContextData(ctx, cacheGroupKey, cacheTargetID, t)
+ return t, nil
+}
diff --git a/modules/cache/context_test.go b/modules/cache/context_test.go
new file mode 100644
index 0000000..072c394
--- /dev/null
+++ b/modules/cache/context_test.go
@@ -0,0 +1,79 @@
+// Copyright 2022 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package cache
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestWithCacheContext(t *testing.T) {
+ ctx := WithCacheContext(context.Background())
+
+ v := GetContextData(ctx, "empty_field", "my_config1")
+ assert.Nil(t, v)
+
+ const field = "system_setting"
+ v = GetContextData(ctx, field, "my_config1")
+ assert.Nil(t, v)
+ SetContextData(ctx, field, "my_config1", 1)
+ v = GetContextData(ctx, field, "my_config1")
+ assert.NotNil(t, v)
+ assert.EqualValues(t, 1, v.(int))
+
+ RemoveContextData(ctx, field, "my_config1")
+ RemoveContextData(ctx, field, "my_config2") // remove a non-exist key
+
+ v = GetContextData(ctx, field, "my_config1")
+ assert.Nil(t, v)
+
+ vInt, err := GetWithContextCache(ctx, field, "my_config1", func() (int, error) {
+ return 1, nil
+ })
+ require.NoError(t, err)
+ assert.EqualValues(t, 1, vInt)
+
+ v = GetContextData(ctx, field, "my_config1")
+ assert.EqualValues(t, 1, v)
+
+ now := timeNow
+ defer func() {
+ timeNow = now
+ }()
+ timeNow = func() time.Time {
+ return now().Add(5 * time.Minute)
+ }
+ v = GetContextData(ctx, field, "my_config1")
+ assert.Nil(t, v)
+}
+
+func TestWithNoCacheContext(t *testing.T) {
+ ctx := context.Background()
+
+ const field = "system_setting"
+
+ v := GetContextData(ctx, field, "my_config1")
+ assert.Nil(t, v)
+ SetContextData(ctx, field, "my_config1", 1)
+ v = GetContextData(ctx, field, "my_config1")
+ assert.Nil(t, v) // still no cache
+
+ ctx = WithCacheContext(ctx)
+ v = GetContextData(ctx, field, "my_config1")
+ assert.Nil(t, v)
+ SetContextData(ctx, field, "my_config1", 1)
+ v = GetContextData(ctx, field, "my_config1")
+ assert.NotNil(t, v)
+
+ ctx = WithNoCacheContext(ctx)
+ v = GetContextData(ctx, field, "my_config1")
+ assert.Nil(t, v)
+ SetContextData(ctx, field, "my_config1", 1)
+ v = GetContextData(ctx, field, "my_config1")
+ assert.Nil(t, v) // still no cache
+}