summaryrefslogtreecommitdiffstats
path: root/internal/app/poll/poller.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/app/poll/poller.go')
-rw-r--r--internal/app/poll/poller.go167
1 files changed, 167 insertions, 0 deletions
diff --git a/internal/app/poll/poller.go b/internal/app/poll/poller.go
new file mode 100644
index 0000000..cc89fa5
--- /dev/null
+++ b/internal/app/poll/poller.go
@@ -0,0 +1,167 @@
+// Copyright 2023 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package poll
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "sync"
+ "sync/atomic"
+
+ runnerv1 "code.gitea.io/actions-proto-go/runner/v1"
+ "connectrpc.com/connect"
+ log "github.com/sirupsen/logrus"
+ "golang.org/x/time/rate"
+
+ "gitea.com/gitea/act_runner/internal/app/run"
+ "gitea.com/gitea/act_runner/internal/pkg/client"
+ "gitea.com/gitea/act_runner/internal/pkg/config"
+)
+
+const PollerID = "PollerID"
+
+type Poller interface {
+ Poll()
+ Shutdown(ctx context.Context) error
+}
+
+type poller struct {
+ client client.Client
+ runner run.RunnerInterface
+ cfg *config.Config
+ tasksVersion atomic.Int64 // tasksVersion used to store the version of the last task fetched from the Gitea.
+
+ pollingCtx context.Context
+ shutdownPolling context.CancelFunc
+
+ jobsCtx context.Context
+ shutdownJobs context.CancelFunc
+
+ done chan any
+}
+
+func New(cfg *config.Config, client client.Client, runner run.RunnerInterface) Poller {
+ return (&poller{}).init(cfg, client, runner)
+}
+
+func (p *poller) init(cfg *config.Config, client client.Client, runner run.RunnerInterface) Poller {
+ pollingCtx, shutdownPolling := context.WithCancel(context.Background())
+
+ jobsCtx, shutdownJobs := context.WithCancel(context.Background())
+
+ done := make(chan any)
+
+ p.client = client
+ p.runner = runner
+ p.cfg = cfg
+
+ p.pollingCtx = pollingCtx
+ p.shutdownPolling = shutdownPolling
+
+ p.jobsCtx = jobsCtx
+ p.shutdownJobs = shutdownJobs
+ p.done = done
+
+ return p
+}
+
+func (p *poller) Poll() {
+ limiter := rate.NewLimiter(rate.Every(p.cfg.Runner.FetchInterval), 1)
+ wg := &sync.WaitGroup{}
+ for i := 0; i < p.cfg.Runner.Capacity; i++ {
+ wg.Add(1)
+ go p.poll(i, wg, limiter)
+ }
+ wg.Wait()
+
+ // signal the poller is finished
+ close(p.done)
+}
+
+func (p *poller) Shutdown(ctx context.Context) error {
+ p.shutdownPolling()
+
+ select {
+ case <-p.done:
+ log.Trace("all jobs are complete")
+ return nil
+
+ case <-ctx.Done():
+ log.Trace("forcing the jobs to shutdown")
+ p.shutdownJobs()
+ <-p.done
+ log.Trace("all jobs have been shutdown")
+ return ctx.Err()
+ }
+}
+
+func (p *poller) poll(id int, wg *sync.WaitGroup, limiter *rate.Limiter) {
+ log.Infof("[poller %d] launched", id)
+ defer wg.Done()
+ for {
+ if err := limiter.Wait(p.pollingCtx); err != nil {
+ log.Infof("[poller %d] shutdown", id)
+ return
+ }
+ task, ok := p.fetchTask(p.pollingCtx)
+ if !ok {
+ continue
+ }
+ p.runTaskWithRecover(p.jobsCtx, task)
+ }
+}
+
+func (p *poller) runTaskWithRecover(ctx context.Context, task *runnerv1.Task) {
+ defer func() {
+ if r := recover(); r != nil {
+ err := fmt.Errorf("panic: %v", r)
+ log.WithError(err).Error("panic in runTaskWithRecover")
+ }
+ }()
+
+ if err := p.runner.Run(ctx, task); err != nil {
+ log.WithError(err).Error("failed to run task")
+ }
+}
+
+func (p *poller) fetchTask(ctx context.Context) (*runnerv1.Task, bool) {
+ reqCtx, cancel := context.WithTimeout(ctx, p.cfg.Runner.FetchTimeout)
+ defer cancel()
+
+ // Load the version value that was in the cache when the request was sent.
+ v := p.tasksVersion.Load()
+ resp, err := p.client.FetchTask(reqCtx, connect.NewRequest(&runnerv1.FetchTaskRequest{
+ TasksVersion: v,
+ }))
+ if errors.Is(err, context.DeadlineExceeded) {
+ log.Trace("deadline exceeded")
+ err = nil
+ }
+ if err != nil {
+ if errors.Is(err, context.Canceled) {
+ log.WithError(err).Debugf("shutdown, fetch task canceled")
+ } else {
+ log.WithError(err).Error("failed to fetch task")
+ }
+ return nil, false
+ }
+
+ if resp == nil || resp.Msg == nil {
+ return nil, false
+ }
+
+ if resp.Msg.TasksVersion > v {
+ p.tasksVersion.CompareAndSwap(v, resp.Msg.TasksVersion)
+ }
+
+ if resp.Msg.Task == nil {
+ return nil, false
+ }
+
+ // got a task, set `tasksVersion` to zero to focre query db in next request.
+ p.tasksVersion.CompareAndSwap(resp.Msg.TasksVersion, 0)
+
+ return resp.Msg.Task, true
+}