summaryrefslogtreecommitdiffstats
path: root/pkg/artifactcache
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/artifactcache')
-rw-r--r--pkg/artifactcache/doc.go8
-rw-r--r--pkg/artifactcache/handler.go536
-rw-r--r--pkg/artifactcache/handler_test.go471
-rw-r--r--pkg/artifactcache/model.go44
-rw-r--r--pkg/artifactcache/storage.go130
-rw-r--r--pkg/artifactcache/testdata/example/example.yaml30
6 files changed, 1219 insertions, 0 deletions
diff --git a/pkg/artifactcache/doc.go b/pkg/artifactcache/doc.go
new file mode 100644
index 0000000..13d2644
--- /dev/null
+++ b/pkg/artifactcache/doc.go
@@ -0,0 +1,8 @@
+// Package artifactcache provides a cache handler for the runner.
+//
+// Inspired by https://github.com/sp-ricard-valverde/github-act-cache-server
+//
+// TODO: Authorization
+// TODO: Restrictions for accessing a cache, see https://docs.github.com/en/actions/using-workflows/caching-dependencies-to-speed-up-workflows#restrictions-for-accessing-a-cache
+// TODO: Force deleting cache entries, see https://docs.github.com/en/actions/using-workflows/caching-dependencies-to-speed-up-workflows#force-deleting-cache-entries
+package artifactcache
diff --git a/pkg/artifactcache/handler.go b/pkg/artifactcache/handler.go
new file mode 100644
index 0000000..3178260
--- /dev/null
+++ b/pkg/artifactcache/handler.go
@@ -0,0 +1,536 @@
+package artifactcache
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "os"
+ "path/filepath"
+ "regexp"
+ "strconv"
+ "strings"
+ "sync/atomic"
+ "time"
+
+ "github.com/julienschmidt/httprouter"
+ "github.com/sirupsen/logrus"
+ "github.com/timshannon/bolthold"
+ "go.etcd.io/bbolt"
+
+ "github.com/nektos/act/pkg/common"
+)
+
const (
	// urlBase is the path prefix shared by all cache API routes.
	urlBase = "/_apis/artifactcache"
)

// Handler is an HTTP server that implements the GitHub Actions cache API
// for the runner. It persists metadata in a bolt database and the cached
// archives themselves via Storage.
type Handler struct {
	dir      string             // root directory holding bolt.db and the cache file tree
	storage  *Storage           // file storage for cache archives
	router   *httprouter.Router
	listener net.Listener
	server   *http.Server
	logger   logrus.FieldLogger

	gcing int32 // non-zero while a GC pass is running; TODO: use atomic.Bool when we can use Go 1.19
	gcAt  time.Time // when the last GC pass actually ran; passes are rate-limited to one per hour

	outboundIP string // host used to build ExternalURL for clients
}
+
+func StartHandler(dir, outboundIP string, port uint16, logger logrus.FieldLogger) (*Handler, error) {
+ h := &Handler{}
+
+ if logger == nil {
+ discard := logrus.New()
+ discard.Out = io.Discard
+ logger = discard
+ }
+ logger = logger.WithField("module", "artifactcache")
+ h.logger = logger
+
+ if dir == "" {
+ home, err := os.UserHomeDir()
+ if err != nil {
+ return nil, err
+ }
+ dir = filepath.Join(home, ".cache", "actcache")
+ }
+ if err := os.MkdirAll(dir, 0o755); err != nil {
+ return nil, err
+ }
+
+ h.dir = dir
+
+ storage, err := NewStorage(filepath.Join(dir, "cache"))
+ if err != nil {
+ return nil, err
+ }
+ h.storage = storage
+
+ if outboundIP != "" {
+ h.outboundIP = outboundIP
+ } else if ip := common.GetOutboundIP(); ip == nil {
+ return nil, fmt.Errorf("unable to determine outbound IP address")
+ } else {
+ h.outboundIP = ip.String()
+ }
+
+ router := httprouter.New()
+ router.GET(urlBase+"/cache", h.middleware(h.find))
+ router.POST(urlBase+"/caches", h.middleware(h.reserve))
+ router.PATCH(urlBase+"/caches/:id", h.middleware(h.upload))
+ router.POST(urlBase+"/caches/:id", h.middleware(h.commit))
+ router.GET(urlBase+"/artifacts/:id", h.middleware(h.get))
+ router.POST(urlBase+"/clean", h.middleware(h.clean))
+
+ h.router = router
+
+ h.gcCache()
+
+ listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) // listen on all interfaces
+ if err != nil {
+ return nil, err
+ }
+ server := &http.Server{
+ ReadHeaderTimeout: 2 * time.Second,
+ Handler: router,
+ }
+ go func() {
+ if err := server.Serve(listener); err != nil && errors.Is(err, net.ErrClosed) {
+ logger.Errorf("http serve: %v", err)
+ }
+ }()
+ h.listener = listener
+ h.server = server
+
+ return h, nil
+}
+
+func (h *Handler) ExternalURL() string {
+ // TODO: make the external url configurable if necessary
+ return fmt.Sprintf("http://%s:%d",
+ h.outboundIP,
+ h.listener.Addr().(*net.TCPAddr).Port)
+}
+
+func (h *Handler) Close() error {
+ if h == nil {
+ return nil
+ }
+ var retErr error
+ if h.server != nil {
+ err := h.server.Close()
+ if err != nil {
+ retErr = err
+ }
+ h.server = nil
+ }
+ if h.listener != nil {
+ err := h.listener.Close()
+ if errors.Is(err, net.ErrClosed) {
+ err = nil
+ }
+ if err != nil {
+ retErr = err
+ }
+ h.listener = nil
+ }
+ return retErr
+}
+
// openDB opens the bolthold metadata database in the handler's directory.
// The caller must close the returned store. Records are encoded as JSON,
// and the open timeout guards against another handle still holding the
// bolt file lock.
func (h *Handler) openDB() (*bolthold.Store, error) {
	return bolthold.Open(filepath.Join(h.dir, "bolt.db"), 0o644, &bolthold.Options{
		Encoder: json.Marshal,
		Decoder: json.Unmarshal,
		Options: &bbolt.Options{
			Timeout:      5 * time.Second,
			NoGrowSync:   bbolt.DefaultOptions.NoGrowSync,
			FreelistType: bbolt.DefaultOptions.FreelistType,
		},
	})
}
+
// GET /_apis/artifactcache/cache
//
// find looks up a usable cache for the comma-separated "keys" and "version"
// query parameters. The first key must match exactly, the remaining keys
// are prefix matches (see findCache). Responds 204 on a miss, or 200 with
// the archive download location on a hit.
func (h *Handler) find(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
	keys := strings.Split(r.URL.Query().Get("keys"), ",")
	// cache keys are case insensitive
	for i, key := range keys {
		keys[i] = strings.ToLower(key)
	}
	version := r.URL.Query().Get("version")

	db, err := h.openDB()
	if err != nil {
		h.responseJSON(w, r, 500, err)
		return
	}
	defer db.Close()

	cache, err := h.findCache(db, keys, version)
	if err != nil {
		h.responseJSON(w, r, 500, err)
		return
	}
	if cache == nil {
		h.responseJSON(w, r, 204)
		return
	}

	// The db record can outlive the file on disk; drop the stale record
	// and report a miss instead of handing out a dead download URL.
	if ok, err := h.storage.Exist(cache.ID); err != nil {
		h.responseJSON(w, r, 500, err)
		return
	} else if !ok {
		_ = db.Delete(cache.ID, cache)
		h.responseJSON(w, r, 204)
		return
	}
	h.responseJSON(w, r, 200, map[string]any{
		"result":          "hit",
		"archiveLocation": fmt.Sprintf("%s%s/artifacts/%d", h.ExternalURL(), urlBase, cache.ID),
		"cacheKey":        cache.Key,
	})
}
+
// POST /_apis/artifactcache/caches
//
// reserve creates a new, incomplete cache entry for the key/version in the
// request body and returns its id as {"cacheId": n}. Responds 400 if an
// entry with the same key/version already exists.
func (h *Handler) reserve(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
	api := &Request{}
	if err := json.NewDecoder(r.Body).Decode(api); err != nil {
		h.responseJSON(w, r, 400, err)
		return
	}
	// cache keys are case insensitive
	api.Key = strings.ToLower(api.Key)

	cache := api.ToCache()
	cache.FillKeyVersionHash()
	db, err := h.openDB()
	if err != nil {
		h.responseJSON(w, r, 500, err)
		return
	}
	defer db.Close()
	// A hit on the unique key/version hash means this cache was already
	// reserved; only ErrNotFound lets the reservation proceed.
	if err := db.FindOne(cache, bolthold.Where("KeyVersionHash").Eq(cache.KeyVersionHash)); err != nil {
		if !errors.Is(err, bolthold.ErrNotFound) {
			h.responseJSON(w, r, 500, err)
			return
		}
	} else {
		h.responseJSON(w, r, 400, fmt.Errorf("already exist"))
		return
	}

	now := time.Now().Unix()
	cache.CreatedAt = now
	cache.UsedAt = now
	// Insert assigns the auto-increment sequence number to cache.ID.
	if err := db.Insert(bolthold.NextSequence(), cache); err != nil {
		h.responseJSON(w, r, 500, err)
		return
	}
	// write back id to db
	if err := db.Update(cache.ID, cache); err != nil {
		h.responseJSON(w, r, 500, err)
		return
	}
	h.responseJSON(w, r, 200, map[string]any{
		"cacheId": cache.ID,
	})
}
+
+// PATCH /_apis/artifactcache/caches/:id
+func (h *Handler) upload(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
+ id, err := strconv.ParseInt(params.ByName("id"), 10, 64)
+ if err != nil {
+ h.responseJSON(w, r, 400, err)
+ return
+ }
+
+ cache := &Cache{}
+ db, err := h.openDB()
+ if err != nil {
+ h.responseJSON(w, r, 500, err)
+ return
+ }
+ defer db.Close()
+ if err := db.Get(id, cache); err != nil {
+ if errors.Is(err, bolthold.ErrNotFound) {
+ h.responseJSON(w, r, 400, fmt.Errorf("cache %d: not reserved", id))
+ return
+ }
+ h.responseJSON(w, r, 500, err)
+ return
+ }
+
+ if cache.Complete {
+ h.responseJSON(w, r, 400, fmt.Errorf("cache %v %q: already complete", cache.ID, cache.Key))
+ return
+ }
+ db.Close()
+ start, _, err := parseContentRange(r.Header.Get("Content-Range"))
+ if err != nil {
+ h.responseJSON(w, r, 400, err)
+ return
+ }
+ if err := h.storage.Write(cache.ID, start, r.Body); err != nil {
+ h.responseJSON(w, r, 500, err)
+ }
+ h.useCache(id)
+ h.responseJSON(w, r, 200)
+}
+
// POST /_apis/artifactcache/caches/:id
//
// commit finalizes a reserved cache: it concatenates the uploaded chunks
// into the final archive file, verifies the size against the reserved
// size when one was given, and marks the entry complete.
func (h *Handler) commit(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
	id, err := strconv.ParseInt(params.ByName("id"), 10, 64)
	if err != nil {
		h.responseJSON(w, r, 400, err)
		return
	}

	cache := &Cache{}
	db, err := h.openDB()
	if err != nil {
		h.responseJSON(w, r, 500, err)
		return
	}
	defer db.Close()
	if err := db.Get(id, cache); err != nil {
		if errors.Is(err, bolthold.ErrNotFound) {
			h.responseJSON(w, r, 400, fmt.Errorf("cache %d: not reserved", id))
			return
		}
		h.responseJSON(w, r, 500, err)
		return
	}

	if cache.Complete {
		h.responseJSON(w, r, 400, fmt.Errorf("cache %v %q: already complete", cache.ID, cache.Key))
		return
	}

	// Release the bolt file lock while Commit does the slow file merge;
	// the db is reopened afterwards to record the result.
	db.Close()

	size, err := h.storage.Commit(cache.ID, cache.Size)
	if err != nil {
		h.responseJSON(w, r, 500, err)
		return
	}
	// write real size back to cache, it may be different from the current value when the request doesn't specify it.
	cache.Size = size

	db, err = h.openDB()
	if err != nil {
		h.responseJSON(w, r, 500, err)
		return
	}
	defer db.Close()

	cache.Complete = true
	if err := db.Update(cache.ID, cache); err != nil {
		h.responseJSON(w, r, 500, err)
		return
	}

	h.responseJSON(w, r, 200)
}
+
// GET /_apis/artifactcache/artifacts/:id
//
// get serves the committed archive file for the given cache id and bumps
// its last-used timestamp. Serving a missing file yields 404 via
// http.ServeFile.
func (h *Handler) get(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
	id, err := strconv.ParseInt(params.ByName("id"), 10, 64)
	if err != nil {
		h.responseJSON(w, r, 400, err)
		return
	}
	h.useCache(id)
	h.storage.Serve(w, r, uint64(id))
}
+
// POST /_apis/artifactcache/clean
//
// clean is currently a no-op that always responds 200: force deleting
// cache entries is not supported yet; entries are only removed by gcCache.
func (h *Handler) clean(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
	// TODO: support force deleting cache entries,
	// see: https://docs.github.com/en/actions/using-workflows/caching-dependencies-to-speed-up-workflows#force-deleting-cache-entries

	h.responseJSON(w, r, 200)
}
+
// middleware wraps a route handler with request logging and kicks off an
// asynchronous garbage-collection pass after each request (gcCache itself
// rate-limits actual work to once per hour).
func (h *Handler) middleware(handler httprouter.Handle) httprouter.Handle {
	return func(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
		h.logger.Debugf("%s %s", r.Method, r.RequestURI)
		handler(w, r, params)
		go h.gcCache()
	}
}
+
// findCache resolves keys against the database: the first key must match a
// complete cache exactly (by key/version hash); each remaining key is
// treated as a key prefix, matched against complete caches of the same
// version, newest first. If nothing matches, it returns (nil, nil) instead
// of an error.
func (h *Handler) findCache(db *bolthold.Store, keys []string, version string) (*Cache, error) {
	if len(keys) == 0 {
		return nil, nil
	}
	key := keys[0] // the first key is for exact match.

	cache := &Cache{
		Key:     key,
		Version: version,
	}
	cache.FillKeyVersionHash()

	if err := db.FindOne(cache, bolthold.Where("KeyVersionHash").Eq(cache.KeyVersionHash)); err != nil {
		if !errors.Is(err, bolthold.ErrNotFound) {
			return nil, err
		}
	} else if cache.Complete {
		return cache, nil
	}
	// sentinel error used to break out of ForEach early; never returned to callers
	stop := fmt.Errorf("stop")

	for _, prefix := range keys[1:] {
		found := false
		prefixPattern := fmt.Sprintf("^%s", regexp.QuoteMeta(prefix))
		re, err := regexp.Compile(prefixPattern)
		if err != nil {
			continue
		}
		// Iterate newest-first so the most recently created matching cache wins.
		if err := db.ForEach(bolthold.Where("Key").RegExp(re).And("Version").Eq(version).SortBy("CreatedAt").Reverse(), func(v *Cache) error {
			if !strings.HasPrefix(v.Key, prefix) {
				return stop
			}
			if v.Complete {
				cache = v
				found = true
				return stop
			}
			return nil
		}); err != nil {
			if !errors.Is(err, stop) {
				return nil, err
			}
		}
		if found {
			return cache, nil
		}
	}
	return nil, nil
}
+
// useCache bumps the cache's UsedAt timestamp so garbage collection keeps
// it alive. Errors are deliberately ignored: failing to update the
// timestamp is harmless best-effort bookkeeping.
func (h *Handler) useCache(id int64) {
	db, err := h.openDB()
	if err != nil {
		return
	}
	defer db.Close()
	cache := &Cache{}
	if err := db.Get(id, cache); err != nil {
		return
	}
	cache.UsedAt = time.Now().Unix()
	_ = db.Update(cache.ID, cache)
}
+
+func (h *Handler) gcCache() {
+ if atomic.LoadInt32(&h.gcing) != 0 {
+ return
+ }
+ if !atomic.CompareAndSwapInt32(&h.gcing, 0, 1) {
+ return
+ }
+ defer atomic.StoreInt32(&h.gcing, 0)
+
+ if time.Since(h.gcAt) < time.Hour {
+ h.logger.Debugf("skip gc: %v", h.gcAt.String())
+ return
+ }
+ h.gcAt = time.Now()
+ h.logger.Debugf("gc: %v", h.gcAt.String())
+
+ const (
+ keepUsed = 30 * 24 * time.Hour
+ keepUnused = 7 * 24 * time.Hour
+ keepTemp = 5 * time.Minute
+ )
+
+ db, err := h.openDB()
+ if err != nil {
+ return
+ }
+ defer db.Close()
+
+ var caches []*Cache
+ if err := db.Find(&caches, bolthold.Where("UsedAt").Lt(time.Now().Add(-keepTemp).Unix())); err != nil {
+ h.logger.Warnf("find caches: %v", err)
+ } else {
+ for _, cache := range caches {
+ if cache.Complete {
+ continue
+ }
+ h.storage.Remove(cache.ID)
+ if err := db.Delete(cache.ID, cache); err != nil {
+ h.logger.Warnf("delete cache: %v", err)
+ continue
+ }
+ h.logger.Infof("deleted cache: %+v", cache)
+ }
+ }
+
+ caches = caches[:0]
+ if err := db.Find(&caches, bolthold.Where("UsedAt").Lt(time.Now().Add(-keepUnused).Unix())); err != nil {
+ h.logger.Warnf("find caches: %v", err)
+ } else {
+ for _, cache := range caches {
+ h.storage.Remove(cache.ID)
+ if err := db.Delete(cache.ID, cache); err != nil {
+ h.logger.Warnf("delete cache: %v", err)
+ continue
+ }
+ h.logger.Infof("deleted cache: %+v", cache)
+ }
+ }
+
+ caches = caches[:0]
+ if err := db.Find(&caches, bolthold.Where("CreatedAt").Lt(time.Now().Add(-keepUsed).Unix())); err != nil {
+ h.logger.Warnf("find caches: %v", err)
+ } else {
+ for _, cache := range caches {
+ h.storage.Remove(cache.ID)
+ if err := db.Delete(cache.ID, cache); err != nil {
+ h.logger.Warnf("delete cache: %v", err)
+ continue
+ }
+ h.logger.Infof("deleted cache: %+v", cache)
+ }
+ }
+}
+
// responseJSON writes a JSON response with the given status code.
// With no payload (or a nil one) it writes "{}"; an error payload is
// logged and rendered as {"error": "..."}; anything else is marshalled
// as-is. Marshal errors are ignored since all payloads here are
// marshal-safe maps, structs, or strings.
func (h *Handler) responseJSON(w http.ResponseWriter, r *http.Request, code int, v ...any) {
	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	var data []byte
	if len(v) == 0 || v[0] == nil {
		data, _ = json.Marshal(struct{}{})
	} else if err, ok := v[0].(error); ok {
		h.logger.Errorf("%v %v: %v", r.Method, r.RequestURI, err)
		data, _ = json.Marshal(map[string]any{
			"error": err.Error(),
		})
	} else {
		data, _ = json.Marshal(v[0])
	}
	w.WriteHeader(code)
	_, _ = w.Write(data)
}
+
// parseContentRange extracts the start and stop byte offsets from a
// Content-Range header value.
func parseContentRange(s string) (int64, int64, error) {
	// support the format like "bytes 11-22/*" only
	rangePart, _, _ := strings.Cut(strings.TrimPrefix(s, "bytes "), "/")
	startStr, stopStr, _ := strings.Cut(rangePart, "-")

	start, err := strconv.ParseInt(startStr, 10, 64)
	if err != nil {
		return 0, 0, fmt.Errorf("parse %q: %w", rangePart, err)
	}
	stop, err := strconv.ParseInt(stopStr, 10, 64)
	if err != nil {
		return 0, 0, fmt.Errorf("parse %q: %w", rangePart, err)
	}
	return start, stop, nil
}
diff --git a/pkg/artifactcache/handler_test.go b/pkg/artifactcache/handler_test.go
new file mode 100644
index 0000000..35ec753
--- /dev/null
+++ b/pkg/artifactcache/handler_test.go
@@ -0,0 +1,471 @@
+package artifactcache
+
+import (
+ "bytes"
+ "crypto/rand"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "path/filepath"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "go.etcd.io/bbolt"
+)
+
+func TestHandler(t *testing.T) {
+ dir := filepath.Join(t.TempDir(), "artifactcache")
+ handler, err := StartHandler(dir, "", 0, nil)
+ require.NoError(t, err)
+
+ base := fmt.Sprintf("%s%s", handler.ExternalURL(), urlBase)
+
+ defer func() {
+ t.Run("inpect db", func(t *testing.T) {
+ db, err := handler.openDB()
+ require.NoError(t, err)
+ defer db.Close()
+ require.NoError(t, db.Bolt().View(func(tx *bbolt.Tx) error {
+ return tx.Bucket([]byte("Cache")).ForEach(func(k, v []byte) error {
+ t.Logf("%s: %s", k, v)
+ return nil
+ })
+ }))
+ })
+ t.Run("close", func(t *testing.T) {
+ require.NoError(t, handler.Close())
+ assert.Nil(t, handler.server)
+ assert.Nil(t, handler.listener)
+ _, err := http.Post(fmt.Sprintf("%s/caches/%d", base, 1), "", nil)
+ assert.Error(t, err)
+ })
+ }()
+
+ t.Run("get not exist", func(t *testing.T) {
+ key := strings.ToLower(t.Name())
+ version := "c19da02a2bd7e77277f1ac29ab45c09b7d46a4ee758284e26bb3045ad11d9d20"
+ resp, err := http.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version))
+ require.NoError(t, err)
+ require.Equal(t, 204, resp.StatusCode)
+ })
+
+ t.Run("reserve and upload", func(t *testing.T) {
+ key := strings.ToLower(t.Name())
+ version := "c19da02a2bd7e77277f1ac29ab45c09b7d46a4ee758284e26bb3045ad11d9d20"
+ content := make([]byte, 100)
+ _, err := rand.Read(content)
+ require.NoError(t, err)
+ uploadCacheNormally(t, base, key, version, content)
+ })
+
+ t.Run("clean", func(t *testing.T) {
+ resp, err := http.Post(fmt.Sprintf("%s/clean", base), "", nil)
+ require.NoError(t, err)
+ assert.Equal(t, 200, resp.StatusCode)
+ })
+
+ t.Run("reserve with bad request", func(t *testing.T) {
+ body := []byte(`invalid json`)
+ require.NoError(t, err)
+ resp, err := http.Post(fmt.Sprintf("%s/caches", base), "application/json", bytes.NewReader(body))
+ require.NoError(t, err)
+ assert.Equal(t, 400, resp.StatusCode)
+ })
+
+ t.Run("duplicate reserve", func(t *testing.T) {
+ key := strings.ToLower(t.Name())
+ version := "c19da02a2bd7e77277f1ac29ab45c09b7d46a4ee758284e26bb3045ad11d9d20"
+ {
+ body, err := json.Marshal(&Request{
+ Key: key,
+ Version: version,
+ Size: 100,
+ })
+ require.NoError(t, err)
+ resp, err := http.Post(fmt.Sprintf("%s/caches", base), "application/json", bytes.NewReader(body))
+ require.NoError(t, err)
+ assert.Equal(t, 200, resp.StatusCode)
+
+ got := struct {
+ CacheID uint64 `json:"cacheId"`
+ }{}
+ require.NoError(t, json.NewDecoder(resp.Body).Decode(&got))
+ }
+ {
+ body, err := json.Marshal(&Request{
+ Key: key,
+ Version: version,
+ Size: 100,
+ })
+ require.NoError(t, err)
+ resp, err := http.Post(fmt.Sprintf("%s/caches", base), "application/json", bytes.NewReader(body))
+ require.NoError(t, err)
+ assert.Equal(t, 400, resp.StatusCode)
+ }
+ })
+
+ t.Run("upload with bad id", func(t *testing.T) {
+ req, err := http.NewRequest(http.MethodPatch,
+ fmt.Sprintf("%s/caches/invalid_id", base), bytes.NewReader(nil))
+ require.NoError(t, err)
+ req.Header.Set("Content-Type", "application/octet-stream")
+ req.Header.Set("Content-Range", "bytes 0-99/*")
+ resp, err := http.DefaultClient.Do(req)
+ require.NoError(t, err)
+ assert.Equal(t, 400, resp.StatusCode)
+ })
+
+ t.Run("upload without reserve", func(t *testing.T) {
+ req, err := http.NewRequest(http.MethodPatch,
+ fmt.Sprintf("%s/caches/%d", base, 1000), bytes.NewReader(nil))
+ require.NoError(t, err)
+ req.Header.Set("Content-Type", "application/octet-stream")
+ req.Header.Set("Content-Range", "bytes 0-99/*")
+ resp, err := http.DefaultClient.Do(req)
+ require.NoError(t, err)
+ assert.Equal(t, 400, resp.StatusCode)
+ })
+
+ t.Run("upload with complete", func(t *testing.T) {
+ key := strings.ToLower(t.Name())
+ version := "c19da02a2bd7e77277f1ac29ab45c09b7d46a4ee758284e26bb3045ad11d9d20"
+ var id uint64
+ content := make([]byte, 100)
+ _, err := rand.Read(content)
+ require.NoError(t, err)
+ {
+ body, err := json.Marshal(&Request{
+ Key: key,
+ Version: version,
+ Size: 100,
+ })
+ require.NoError(t, err)
+ resp, err := http.Post(fmt.Sprintf("%s/caches", base), "application/json", bytes.NewReader(body))
+ require.NoError(t, err)
+ assert.Equal(t, 200, resp.StatusCode)
+
+ got := struct {
+ CacheID uint64 `json:"cacheId"`
+ }{}
+ require.NoError(t, json.NewDecoder(resp.Body).Decode(&got))
+ id = got.CacheID
+ }
+ {
+ req, err := http.NewRequest(http.MethodPatch,
+ fmt.Sprintf("%s/caches/%d", base, id), bytes.NewReader(content))
+ require.NoError(t, err)
+ req.Header.Set("Content-Type", "application/octet-stream")
+ req.Header.Set("Content-Range", "bytes 0-99/*")
+ resp, err := http.DefaultClient.Do(req)
+ require.NoError(t, err)
+ assert.Equal(t, 200, resp.StatusCode)
+ }
+ {
+ resp, err := http.Post(fmt.Sprintf("%s/caches/%d", base, id), "", nil)
+ require.NoError(t, err)
+ assert.Equal(t, 200, resp.StatusCode)
+ }
+ {
+ req, err := http.NewRequest(http.MethodPatch,
+ fmt.Sprintf("%s/caches/%d", base, id), bytes.NewReader(content))
+ require.NoError(t, err)
+ req.Header.Set("Content-Type", "application/octet-stream")
+ req.Header.Set("Content-Range", "bytes 0-99/*")
+ resp, err := http.DefaultClient.Do(req)
+ require.NoError(t, err)
+ assert.Equal(t, 400, resp.StatusCode)
+ }
+ })
+
+ t.Run("upload with invalid range", func(t *testing.T) {
+ key := strings.ToLower(t.Name())
+ version := "c19da02a2bd7e77277f1ac29ab45c09b7d46a4ee758284e26bb3045ad11d9d20"
+ var id uint64
+ content := make([]byte, 100)
+ _, err := rand.Read(content)
+ require.NoError(t, err)
+ {
+ body, err := json.Marshal(&Request{
+ Key: key,
+ Version: version,
+ Size: 100,
+ })
+ require.NoError(t, err)
+ resp, err := http.Post(fmt.Sprintf("%s/caches", base), "application/json", bytes.NewReader(body))
+ require.NoError(t, err)
+ assert.Equal(t, 200, resp.StatusCode)
+
+ got := struct {
+ CacheID uint64 `json:"cacheId"`
+ }{}
+ require.NoError(t, json.NewDecoder(resp.Body).Decode(&got))
+ id = got.CacheID
+ }
+ {
+ req, err := http.NewRequest(http.MethodPatch,
+ fmt.Sprintf("%s/caches/%d", base, id), bytes.NewReader(content))
+ require.NoError(t, err)
+ req.Header.Set("Content-Type", "application/octet-stream")
+ req.Header.Set("Content-Range", "bytes xx-99/*")
+ resp, err := http.DefaultClient.Do(req)
+ require.NoError(t, err)
+ assert.Equal(t, 400, resp.StatusCode)
+ }
+ })
+
+ t.Run("commit with bad id", func(t *testing.T) {
+ {
+ resp, err := http.Post(fmt.Sprintf("%s/caches/invalid_id", base), "", nil)
+ require.NoError(t, err)
+ assert.Equal(t, 400, resp.StatusCode)
+ }
+ })
+
+ t.Run("commit with not exist id", func(t *testing.T) {
+ {
+ resp, err := http.Post(fmt.Sprintf("%s/caches/%d", base, 100), "", nil)
+ require.NoError(t, err)
+ assert.Equal(t, 400, resp.StatusCode)
+ }
+ })
+
+ t.Run("duplicate commit", func(t *testing.T) {
+ key := strings.ToLower(t.Name())
+ version := "c19da02a2bd7e77277f1ac29ab45c09b7d46a4ee758284e26bb3045ad11d9d20"
+ var id uint64
+ content := make([]byte, 100)
+ _, err := rand.Read(content)
+ require.NoError(t, err)
+ {
+ body, err := json.Marshal(&Request{
+ Key: key,
+ Version: version,
+ Size: 100,
+ })
+ require.NoError(t, err)
+ resp, err := http.Post(fmt.Sprintf("%s/caches", base), "application/json", bytes.NewReader(body))
+ require.NoError(t, err)
+ assert.Equal(t, 200, resp.StatusCode)
+
+ got := struct {
+ CacheID uint64 `json:"cacheId"`
+ }{}
+ require.NoError(t, json.NewDecoder(resp.Body).Decode(&got))
+ id = got.CacheID
+ }
+ {
+ req, err := http.NewRequest(http.MethodPatch,
+ fmt.Sprintf("%s/caches/%d", base, id), bytes.NewReader(content))
+ require.NoError(t, err)
+ req.Header.Set("Content-Type", "application/octet-stream")
+ req.Header.Set("Content-Range", "bytes 0-99/*")
+ resp, err := http.DefaultClient.Do(req)
+ require.NoError(t, err)
+ assert.Equal(t, 200, resp.StatusCode)
+ }
+ {
+ resp, err := http.Post(fmt.Sprintf("%s/caches/%d", base, id), "", nil)
+ require.NoError(t, err)
+ assert.Equal(t, 200, resp.StatusCode)
+ }
+ {
+ resp, err := http.Post(fmt.Sprintf("%s/caches/%d", base, id), "", nil)
+ require.NoError(t, err)
+ assert.Equal(t, 400, resp.StatusCode)
+ }
+ })
+
+ t.Run("commit early", func(t *testing.T) {
+ key := strings.ToLower(t.Name())
+ version := "c19da02a2bd7e77277f1ac29ab45c09b7d46a4ee758284e26bb3045ad11d9d20"
+ var id uint64
+ content := make([]byte, 100)
+ _, err := rand.Read(content)
+ require.NoError(t, err)
+ {
+ body, err := json.Marshal(&Request{
+ Key: key,
+ Version: version,
+ Size: 100,
+ })
+ require.NoError(t, err)
+ resp, err := http.Post(fmt.Sprintf("%s/caches", base), "application/json", bytes.NewReader(body))
+ require.NoError(t, err)
+ assert.Equal(t, 200, resp.StatusCode)
+
+ got := struct {
+ CacheID uint64 `json:"cacheId"`
+ }{}
+ require.NoError(t, json.NewDecoder(resp.Body).Decode(&got))
+ id = got.CacheID
+ }
+ {
+ req, err := http.NewRequest(http.MethodPatch,
+ fmt.Sprintf("%s/caches/%d", base, id), bytes.NewReader(content[:50]))
+ require.NoError(t, err)
+ req.Header.Set("Content-Type", "application/octet-stream")
+ req.Header.Set("Content-Range", "bytes 0-59/*")
+ resp, err := http.DefaultClient.Do(req)
+ require.NoError(t, err)
+ assert.Equal(t, 200, resp.StatusCode)
+ }
+ {
+ resp, err := http.Post(fmt.Sprintf("%s/caches/%d", base, id), "", nil)
+ require.NoError(t, err)
+ assert.Equal(t, 500, resp.StatusCode)
+ }
+ })
+
+ t.Run("get with bad id", func(t *testing.T) {
+ resp, err := http.Get(fmt.Sprintf("%s/artifacts/invalid_id", base))
+ require.NoError(t, err)
+ require.Equal(t, 400, resp.StatusCode)
+ })
+
+ t.Run("get with not exist id", func(t *testing.T) {
+ resp, err := http.Get(fmt.Sprintf("%s/artifacts/%d", base, 100))
+ require.NoError(t, err)
+ require.Equal(t, 404, resp.StatusCode)
+ })
+
+ t.Run("get with not exist id", func(t *testing.T) {
+ resp, err := http.Get(fmt.Sprintf("%s/artifacts/%d", base, 100))
+ require.NoError(t, err)
+ require.Equal(t, 404, resp.StatusCode)
+ })
+
+ t.Run("get with multiple keys", func(t *testing.T) {
+ version := "c19da02a2bd7e77277f1ac29ab45c09b7d46a4ee758284e26bb3045ad11d9d20"
+ key := strings.ToLower(t.Name())
+ keys := [3]string{
+ key + "_a",
+ key + "_a_b",
+ key + "_a_b_c",
+ }
+ contents := [3][]byte{
+ make([]byte, 100),
+ make([]byte, 200),
+ make([]byte, 300),
+ }
+ for i := range contents {
+ _, err := rand.Read(contents[i])
+ require.NoError(t, err)
+ uploadCacheNormally(t, base, keys[i], version, contents[i])
+ }
+
+ reqKeys := strings.Join([]string{
+ key + "_a_b_x",
+ key + "_a_b",
+ key + "_a",
+ }, ",")
+ var archiveLocation string
+ {
+ resp, err := http.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, reqKeys, version))
+ require.NoError(t, err)
+ require.Equal(t, 200, resp.StatusCode)
+ got := struct {
+ Result string `json:"result"`
+ ArchiveLocation string `json:"archiveLocation"`
+ CacheKey string `json:"cacheKey"`
+ }{}
+ require.NoError(t, json.NewDecoder(resp.Body).Decode(&got))
+ assert.Equal(t, "hit", got.Result)
+ assert.Equal(t, keys[1], got.CacheKey)
+ archiveLocation = got.ArchiveLocation
+ }
+ {
+ resp, err := http.Get(archiveLocation) //nolint:gosec
+ require.NoError(t, err)
+ require.Equal(t, 200, resp.StatusCode)
+ got, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+ assert.Equal(t, contents[1], got)
+ }
+ })
+
+ t.Run("case insensitive", func(t *testing.T) {
+ version := "c19da02a2bd7e77277f1ac29ab45c09b7d46a4ee758284e26bb3045ad11d9d20"
+ key := strings.ToLower(t.Name())
+ content := make([]byte, 100)
+ _, err := rand.Read(content)
+ require.NoError(t, err)
+ uploadCacheNormally(t, base, key+"_ABC", version, content)
+
+ {
+ reqKey := key + "_aBc"
+ resp, err := http.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, reqKey, version))
+ require.NoError(t, err)
+ require.Equal(t, 200, resp.StatusCode)
+ got := struct {
+ Result string `json:"result"`
+ ArchiveLocation string `json:"archiveLocation"`
+ CacheKey string `json:"cacheKey"`
+ }{}
+ require.NoError(t, json.NewDecoder(resp.Body).Decode(&got))
+ assert.Equal(t, "hit", got.Result)
+ assert.Equal(t, key+"_abc", got.CacheKey)
+ }
+ })
+}
+
+func uploadCacheNormally(t *testing.T, base, key, version string, content []byte) {
+ var id uint64
+ {
+ body, err := json.Marshal(&Request{
+ Key: key,
+ Version: version,
+ Size: int64(len(content)),
+ })
+ require.NoError(t, err)
+ resp, err := http.Post(fmt.Sprintf("%s/caches", base), "application/json", bytes.NewReader(body))
+ require.NoError(t, err)
+ assert.Equal(t, 200, resp.StatusCode)
+
+ got := struct {
+ CacheID uint64 `json:"cacheId"`
+ }{}
+ require.NoError(t, json.NewDecoder(resp.Body).Decode(&got))
+ id = got.CacheID
+ }
+ {
+ req, err := http.NewRequest(http.MethodPatch,
+ fmt.Sprintf("%s/caches/%d", base, id), bytes.NewReader(content))
+ require.NoError(t, err)
+ req.Header.Set("Content-Type", "application/octet-stream")
+ req.Header.Set("Content-Range", "bytes 0-99/*")
+ resp, err := http.DefaultClient.Do(req)
+ require.NoError(t, err)
+ assert.Equal(t, 200, resp.StatusCode)
+ }
+ {
+ resp, err := http.Post(fmt.Sprintf("%s/caches/%d", base, id), "", nil)
+ require.NoError(t, err)
+ assert.Equal(t, 200, resp.StatusCode)
+ }
+ var archiveLocation string
+ {
+ resp, err := http.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version))
+ require.NoError(t, err)
+ require.Equal(t, 200, resp.StatusCode)
+ got := struct {
+ Result string `json:"result"`
+ ArchiveLocation string `json:"archiveLocation"`
+ CacheKey string `json:"cacheKey"`
+ }{}
+ require.NoError(t, json.NewDecoder(resp.Body).Decode(&got))
+ assert.Equal(t, "hit", got.Result)
+ assert.Equal(t, strings.ToLower(key), got.CacheKey)
+ archiveLocation = got.ArchiveLocation
+ }
+ {
+ resp, err := http.Get(archiveLocation) //nolint:gosec
+ require.NoError(t, err)
+ require.Equal(t, 200, resp.StatusCode)
+ got, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+ assert.Equal(t, content, got)
+ }
+}
diff --git a/pkg/artifactcache/model.go b/pkg/artifactcache/model.go
new file mode 100644
index 0000000..32b8ce5
--- /dev/null
+++ b/pkg/artifactcache/model.go
@@ -0,0 +1,44 @@
+package artifactcache
+
+import (
+ "crypto/sha256"
+ "fmt"
+)
+
+type Request struct {
+ Key string `json:"key" `
+ Version string `json:"version"`
+ Size int64 `json:"cacheSize"`
+}
+
+func (c *Request) ToCache() *Cache {
+ if c == nil {
+ return nil
+ }
+ ret := &Cache{
+ Key: c.Key,
+ Version: c.Version,
+ Size: c.Size,
+ }
+ if c.Size == 0 {
+ // So the request comes from old versions of actions, like `actions/cache@v2`.
+ // It doesn't send cache size. Set it to -1 to indicate that.
+ ret.Size = -1
+ }
+ return ret
+}
+
// Cache is one cache entry as persisted in the bolt database.
type Cache struct {
	ID             uint64 `json:"id" boltholdKey:"ID"`                        // auto-assigned bolthold sequence key
	Key            string `json:"key" boltholdIndex:"Key"`                    // lowercased cache key
	Version        string `json:"version" boltholdIndex:"Version"`
	KeyVersionHash string `json:"keyVersionHash" boltholdUnique:"KeyVersionHash"` // sha256 of "key:version", see FillKeyVersionHash
	Size           int64  `json:"cacheSize"`                                  // archive size in bytes; -1 when unknown (old clients)
	Complete       bool   `json:"complete"`                                   // set once the upload has been committed
	UsedAt         int64  `json:"usedAt" boltholdIndex:"UsedAt"`              // unix seconds of last access
	CreatedAt      int64  `json:"createdAt" boltholdIndex:"CreatedAt"`        // unix seconds of reservation
}
+
+func (c *Cache) FillKeyVersionHash() {
+ c.KeyVersionHash = fmt.Sprintf("%x", sha256.Sum256([]byte(fmt.Sprintf("%s:%s", c.Key, c.Version))))
+}
diff --git a/pkg/artifactcache/storage.go b/pkg/artifactcache/storage.go
new file mode 100644
index 0000000..9a2609a
--- /dev/null
+++ b/pkg/artifactcache/storage.go
@@ -0,0 +1,130 @@
+package artifactcache
+
+import (
+ "fmt"
+ "io"
+ "net/http"
+ "os"
+ "path/filepath"
+)
+
// Storage stores cache archives on the local filesystem under a single
// root directory.
type Storage struct {
	rootDir string
}

// NewStorage creates the root directory if needed and returns a Storage
// backed by it.
func NewStorage(rootDir string) (*Storage, error) {
	if err := os.MkdirAll(rootDir, 0o755); err != nil {
		return nil, err
	}
	return &Storage{rootDir: rootDir}, nil
}
+
+func (s *Storage) Exist(id uint64) (bool, error) {
+ name := s.filename(id)
+ if _, err := os.Stat(name); os.IsNotExist(err) {
+ return false, nil
+ } else if err != nil {
+ return false, err
+ }
+ return true, nil
+}
+
// Write stores one uploaded chunk for id as a temp file named after its
// byte offset. Chunks are merged into the final archive by Commit.
// Re-uploading the same offset overwrites the previous chunk (os.Create
// truncates).
func (s *Storage) Write(id uint64, offset int64, reader io.Reader) error {
	name := s.tempName(id, offset)
	if err := os.MkdirAll(filepath.Dir(name), 0o755); err != nil {
		return err
	}
	file, err := os.Create(name)
	if err != nil {
		return err
	}
	defer file.Close()

	_, err = io.Copy(file, reader)
	return err
}
+
// Commit concatenates the uploaded chunk files for id, in offset order,
// into the final archive file and returns the number of bytes written.
// When size is non-negative the result must match it exactly; otherwise
// (old clients that don't send a size) the check is skipped. The temp
// chunk directory is removed in all cases.
func (s *Storage) Commit(id uint64, size int64) (int64, error) {
	defer func() {
		_ = os.RemoveAll(s.tempDir(id))
	}()

	name := s.filename(id)
	// tempNames relies on os.ReadDir's sorted, zero-padded hex names to
	// return chunks in ascending offset order.
	tempNames, err := s.tempNames(id)
	if err != nil {
		return 0, err
	}

	if err := os.MkdirAll(filepath.Dir(name), 0o755); err != nil {
		return 0, err
	}
	file, err := os.Create(name)
	if err != nil {
		return 0, err
	}
	defer file.Close()

	var written int64
	for _, v := range tempNames {
		f, err := os.Open(v)
		if err != nil {
			return 0, err
		}
		n, err := io.Copy(file, f)
		_ = f.Close()
		if err != nil {
			return 0, err
		}
		written += n
	}

	// If size is less than 0, it means the size is unknown.
	// We can't check the size of the file, just skip the check.
	// It happens when the request comes from old versions of actions, like `actions/cache@v2`.
	if size >= 0 && written != size {
		_ = file.Close()
		_ = os.Remove(name)
		return 0, fmt.Errorf("broken file: %v != %v", written, size)
	}

	return written, nil
}
+
// Serve streams the committed archive for id to the client.
// http.ServeFile handles range requests and responds 404 for a missing file.
func (s *Storage) Serve(w http.ResponseWriter, r *http.Request, id uint64) {
	name := s.filename(id)
	http.ServeFile(w, r, name)
}

// Remove deletes the committed archive and any leftover temp chunks for id.
// Errors are ignored: removal is best-effort cleanup.
func (s *Storage) Remove(id uint64) {
	_ = os.Remove(s.filename(id))
	_ = os.RemoveAll(s.tempDir(id))
}
+
// filename returns the path of the committed archive for id, sharded into
// subdirectories to keep directory sizes bounded.
// NOTE(review): the shard is id%0xff (255 buckets, never "ff"), not
// id&0xff / id%0x100 — presumably intentional, but worth confirming.
func (s *Storage) filename(id uint64) string {
	return filepath.Join(s.rootDir, fmt.Sprintf("%02x", id%0xff), fmt.Sprint(id))
}

// tempDir returns the directory holding not-yet-committed chunks for id.
func (s *Storage) tempDir(id uint64) string {
	return filepath.Join(s.rootDir, "tmp", fmt.Sprint(id))
}

// tempName returns the chunk file path for id at the given byte offset.
// The zero-padded hex name makes lexical order equal offset order, which
// Commit depends on.
func (s *Storage) tempName(id uint64, offset int64) string {
	return filepath.Join(s.tempDir(id), fmt.Sprintf("%016x", offset))
}
+
+func (s *Storage) tempNames(id uint64) ([]string, error) {
+ dir := s.tempDir(id)
+ files, err := os.ReadDir(dir)
+ if err != nil {
+ return nil, err
+ }
+ var names []string
+ for _, v := range files {
+ if !v.IsDir() {
+ names = append(names, filepath.Join(dir, v.Name()))
+ }
+ }
+ return names, nil
+}
diff --git a/pkg/artifactcache/testdata/example/example.yaml b/pkg/artifactcache/testdata/example/example.yaml
new file mode 100644
index 0000000..5332e72
--- /dev/null
+++ b/pkg/artifactcache/testdata/example/example.yaml
@@ -0,0 +1,30 @@
+# Copied from https://github.com/actions/cache#example-cache-workflow
+name: Caching Primes
+
+on: push
+
+jobs:
+ build:
+ runs-on: ubuntu-latest
+
+ steps:
+ - run: env
+
+ - uses: actions/checkout@v3
+
+ - name: Cache Primes
+ id: cache-primes
+ uses: actions/cache@v3
+ with:
+ path: prime-numbers
+ key: ${{ runner.os }}-primes-${{ github.run_id }}
+ restore-keys: |
+ ${{ runner.os }}-primes
+ ${{ runner.os }}
+
+ - name: Generate Prime Numbers
+ if: steps.cache-primes.outputs.cache-hit != 'true'
+ run: cat /proc/sys/kernel/random/uuid > prime-numbers
+
+ - name: Use Prime Numbers
+ run: cat prime-numbers