Diffstat (limited to 'internal/app/poll')
 -rw-r--r-- internal/app/poll/poller.go      | 167
 -rw-r--r-- internal/app/poll/poller_test.go | 263
 2 files changed, 430 insertions(+), 0 deletions(-)
diff --git a/internal/app/poll/poller.go b/internal/app/poll/poller.go
new file mode 100644
index 0000000..cc89fa5
--- /dev/null
+++ b/internal/app/poll/poller.go
@@ -0,0 +1,167 @@
+// Copyright 2023 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package poll
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "sync"
+ "sync/atomic"
+
+ runnerv1 "code.gitea.io/actions-proto-go/runner/v1"
+ "connectrpc.com/connect"
+ log "github.com/sirupsen/logrus"
+ "golang.org/x/time/rate"
+
+ "gitea.com/gitea/act_runner/internal/app/run"
+ "gitea.com/gitea/act_runner/internal/pkg/client"
+ "gitea.com/gitea/act_runner/internal/pkg/config"
+)
+
+const PollerID = "PollerID"
+
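+// Poller fetches tasks from Gitea and hands them to a runner. Poll blocks
+// until polling is shut down, so a typical caller runs it in a goroutine
+// and later asks for a graceful stop:
+//
+//	p := poll.New(cfg, cli, runner)
+//	go p.Poll()
+//	// ... on shutdown signal ...
+//	_ = p.Shutdown(ctx)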
+type Poller interface {
+ Poll()
+ Shutdown(ctx context.Context) error
+}
+
+type poller struct {
+ client client.Client
+ runner run.RunnerInterface
+ cfg *config.Config
+ tasksVersion atomic.Int64 // version of the last task fetched from Gitea; see fetchTask
+
+ pollingCtx context.Context
+ shutdownPolling context.CancelFunc
+
+ jobsCtx context.Context
+ shutdownJobs context.CancelFunc
+
+ done chan any
+}
+
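+// New returns a Poller that fetches tasks with the given client and executes
+// them with the given runner, honoring the capacity and intervals in cfg.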
+func New(cfg *config.Config, client client.Client, runner run.RunnerInterface) Poller {
+ return (&poller{}).init(cfg, client, runner)
+}
+
+func (p *poller) init(cfg *config.Config, client client.Client, runner run.RunnerInterface) Poller {
+ pollingCtx, shutdownPolling := context.WithCancel(context.Background())
+
+ jobsCtx, shutdownJobs := context.WithCancel(context.Background())
+
+ done := make(chan any)
+
+ p.client = client
+ p.runner = runner
+ p.cfg = cfg
+
+ p.pollingCtx = pollingCtx
+ p.shutdownPolling = shutdownPolling
+
+ p.jobsCtx = jobsCtx
+ p.shutdownJobs = shutdownJobs
+ p.done = done
+
+ return p
+}
+
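+// Poll launches cfg.Runner.Capacity worker goroutines that share a single
+// rate limiter, so at most one task is fetched per FetchInterval overall.
+// It blocks until every worker has exited, then closes the done channel.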
+func (p *poller) Poll() {
+ limiter := rate.NewLimiter(rate.Every(p.cfg.Runner.FetchInterval), 1)
+ wg := &sync.WaitGroup{}
+ for i := 0; i < p.cfg.Runner.Capacity; i++ {
+ wg.Add(1)
+ go p.poll(i, wg, limiter)
+ }
+ wg.Wait()
+
+ // signal that the poller has finished
+ close(p.done)
+}
+
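+// Shutdown stops fetching new tasks and waits for in-flight jobs to finish.
+// If ctx expires first, the jobs context is canceled to force them to stop,
+// and the context's error is returned.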
+func (p *poller) Shutdown(ctx context.Context) error {
+ p.shutdownPolling()
+
+ select {
+ case <-p.done:
+ log.Trace("all jobs are complete")
+ return nil
+
+ case <-ctx.Done():
+ log.Trace("forcing the jobs to shutdown")
+ p.shutdownJobs()
+ <-p.done
+ log.Trace("all jobs have been shutdown")
+ return ctx.Err()
+ }
+}
+
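+// poll is the loop of a single worker: wait on the shared limiter, fetch one
+// task, run it to completion, and repeat until the polling context is canceled.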
+func (p *poller) poll(id int, wg *sync.WaitGroup, limiter *rate.Limiter) {
+ log.Infof("[poller %d] launched", id)
+ defer wg.Done()
+ for {
+ if err := limiter.Wait(p.pollingCtx); err != nil {
+ log.Infof("[poller %d] shutdown", id)
+ return
+ }
+ task, ok := p.fetchTask(p.pollingCtx)
+ if !ok {
+ continue
+ }
+ p.runTaskWithRecover(p.jobsCtx, task)
+ }
+}
+
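+// runTaskWithRecover runs one task and turns a panic inside the runner into
+// a logged error, so a misbehaving task cannot kill its worker goroutine.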
+func (p *poller) runTaskWithRecover(ctx context.Context, task *runnerv1.Task) {
+ defer func() {
+ if r := recover(); r != nil {
+ err := fmt.Errorf("panic: %v", r)
+ log.WithError(err).Error("panic in runTaskWithRecover")
+ }
+ }()
+
+ if err := p.runner.Run(ctx, task); err != nil {
+ log.WithError(err).Error("failed to run task")
+ }
+}
+
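+// fetchTask requests a task from the server, bounded by cfg.Runner.FetchTimeout.
+// It returns ok=false when no task is available or the request failed.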
+func (p *poller) fetchTask(ctx context.Context) (*runnerv1.Task, bool) {
+ reqCtx, cancel := context.WithTimeout(ctx, p.cfg.Runner.FetchTimeout)
+ defer cancel()
+
+ // Remember the version in effect when the request is sent, so the CAS below
+ // only succeeds if no other worker updated it in the meantime.
+ v := p.tasksVersion.Load()
+ resp, err := p.client.FetchTask(reqCtx, connect.NewRequest(&runnerv1.FetchTaskRequest{
+ TasksVersion: v,
+ }))
+ if errors.Is(err, context.DeadlineExceeded) {
+ log.Trace("deadline exceeded")
+ err = nil
+ }
+ if err != nil {
+ if errors.Is(err, context.Canceled) {
+ log.WithError(err).Debug("fetch task canceled during shutdown")
+ } else {
+ log.WithError(err).Error("failed to fetch task")
+ }
+ return nil, false
+ }
+
+ if resp == nil || resp.Msg == nil {
+ return nil, false
+ }
+
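+ // Record the newer version reported by the server; the CAS fails, and the
+ // update is skipped, if another worker changed tasksVersion in the meantime.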
+ if resp.Msg.TasksVersion > v {
+ p.tasksVersion.CompareAndSwap(v, resp.Msg.TasksVersion)
+ }
+
+ if resp.Msg.Task == nil {
+ return nil, false
+ }
+
+ // Got a task; reset tasksVersion to zero to force a database query on the next request.
+ p.tasksVersion.CompareAndSwap(resp.Msg.TasksVersion, 0)
+
+ return resp.Msg.Task, true
+}
diff --git a/internal/app/poll/poller_test.go b/internal/app/poll/poller_test.go
new file mode 100644
index 0000000..04b1a84
--- /dev/null
+++ b/internal/app/poll/poller_test.go
@@ -0,0 +1,263 @@
+// Copyright The Forgejo Authors.
+// SPDX-License-Identifier: MIT
+
+package poll
+
+import (
+ "context"
+ "fmt"
+ "testing"
+ "time"
+
+ "connectrpc.com/connect"
+
+ "code.gitea.io/actions-proto-go/ping/v1/pingv1connect"
+ runnerv1 "code.gitea.io/actions-proto-go/runner/v1"
+ "code.gitea.io/actions-proto-go/runner/v1/runnerv1connect"
+ "gitea.com/gitea/act_runner/internal/pkg/config"
+
+ log "github.com/sirupsen/logrus"
+ "github.com/stretchr/testify/assert"
+)
+
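+// mockPoller embeds poller so tests can reach its unexported fields and methods.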
+type mockPoller struct {
+ poller
+}
+
+func (o *mockPoller) Poll() {
+ o.poller.Poll()
+}
+
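+// mockClient stubs FetchTask; its fields select, in order of precedence, a
+// simulated slow response, a canceled request, an error, or a response that
+// carries no task.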
+type mockClient struct {
+ pingv1connect.PingServiceClient
+ runnerv1connect.RunnerServiceClient
+
+ sleep time.Duration
+ cancel bool
+ err error
+ noTask bool
+}
+
+func (o mockClient) Address() string {
+ return ""
+}
+
+func (o mockClient) Insecure() bool {
+ return true
+}
+
+func (o *mockClient) FetchTask(ctx context.Context, req *connect.Request[runnerv1.FetchTaskRequest]) (*connect.Response[runnerv1.FetchTaskResponse], error) {
+ if o.sleep > 0 {
+ select {
+ case <-ctx.Done():
+ log.Trace("fetch task done")
+ return nil, context.DeadlineExceeded
+ case <-time.After(o.sleep):
+ log.Trace("slept")
+ return nil, fmt.Errorf("unexpected")
+ }
+ }
+ if o.cancel {
+ return nil, context.Canceled
+ }
+ if o.err != nil {
+ return nil, o.err
+ }
+ task := &runnerv1.Task{}
+ if o.noTask {
+ task = nil
+ o.noTask = false
+ }
+
+ return connect.NewResponse(&runnerv1.FetchTaskResponse{
+ Task: task,
+ TasksVersion: int64(1),
+ }), nil
+}
+
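+// mockRunner reports each lifecycle event ("runner starts", "runner panics",
+// ...) on its log channel so tests can assert on the order of events.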
+type mockRunner struct {
+ cfg *config.Runner
+ log chan string
+ panics bool
+ err error
+}
+
+func (o *mockRunner) Run(ctx context.Context, task *runnerv1.Task) error {
+ o.log <- "runner starts"
+ if o.panics {
+ log.Trace("panics")
+ o.log <- "runner panics"
+ o.panics = false
+ panic("whatever")
+ }
+ if o.err != nil {
+ log.Trace("error")
+ o.log <- "runner error"
+ err := o.err
+ o.err = nil
+ return err
+ }
+ select {
+ case <-ctx.Done():
+ log.Trace("shutdown")
+ o.log <- "runner shutdown"
+ return nil
+ case <-time.After(o.cfg.Timeout):
+ log.Trace("after")
+ o.log <- "runner timeout"
+ return nil
+ }
+}
+
+func setTrace(t *testing.T) {
+ t.Helper()
+ log.SetReportCaller(true)
+ log.SetLevel(log.TraceLevel)
+}
+
+func TestPoller_New(t *testing.T) {
+ p := New(&config.Config{}, &mockClient{}, &mockRunner{})
+ assert.NotNil(t, p)
+}
+
+func TestPoller_Runner(t *testing.T) {
+ setTrace(t)
+ for _, testCase := range []struct {
+ name string
+ timeout time.Duration
+ noTask bool
+ panics bool
+ err error
+ expected string
+ contextTimeout time.Duration
+ }{
+ {
+ name: "Simple",
+ timeout: 10 * time.Second,
+ expected: "runner shutdown",
+ },
+ {
+ name: "Panics",
+ timeout: 10 * time.Second,
+ panics: true,
+ expected: "runner panics",
+ },
+ {
+ name: "Error",
+ timeout: 10 * time.Second,
+ err: fmt.Errorf("ERROR"),
+ expected: "runner error",
+ },
+ {
+ name: "PollTaskError",
+ timeout: 10 * time.Second,
+ noTask: true,
+ expected: "runner shutdown",
+ },
+ {
+ name: "ShutdownTimeout",
+ timeout: 1 * time.Second,
+ contextTimeout: 1 * time.Minute,
+ expected: "runner timeout",
+ },
+ } {
+ t.Run(testCase.name, func(t *testing.T) {
+ runnerLog := make(chan string, 3)
+ configRunner := config.Runner{
+ FetchInterval: 1,
+ Capacity: 1,
+ Timeout: testCase.timeout,
+ }
+ p := &mockPoller{}
+ p.init(
+ &config.Config{
+ Runner: configRunner,
+ },
+ &mockClient{
+ noTask: testCase.noTask,
+ },
+ &mockRunner{
+ cfg: &configRunner,
+ log: runnerLog,
+ panics: testCase.panics,
+ err: testCase.err,
+ })
+ go p.Poll()
+ assert.Equal(t, "runner starts", <-runnerLog)
+ var ctx context.Context
+ var cancel context.CancelFunc
+ if testCase.contextTimeout > 0 {
+ ctx, cancel = context.WithTimeout(context.Background(), testCase.contextTimeout)
+ defer cancel()
+ } else {
+ ctx, cancel = context.WithCancel(context.Background())
+ cancel()
+ }
+ p.Shutdown(ctx)
+ <-p.done
+ assert.Equal(t, testCase.expected, <-runnerLog)
+ })
+ }
+}
+
+func TestPoller_Fetch(t *testing.T) {
+ setTrace(t)
+ for _, testCase := range []struct {
+ name string
+ noTask bool
+ sleep time.Duration
+ err error
+ cancel bool
+ success bool
+ }{
+ {
+ name: "Success",
+ success: true,
+ },
+ {
+ name: "Timeout",
+ sleep: 100 * time.Millisecond,
+ },
+ {
+ name: "Canceled",
+ cancel: true,
+ },
+ {
+ name: "NoTask",
+ noTask: true,
+ },
+ {
+ name: "Error",
+ err: fmt.Errorf("random error"),
+ },
+ } {
+ t.Run(testCase.name, func(t *testing.T) {
+ configRunner := config.Runner{
+ FetchTimeout: 1 * time.Millisecond,
+ }
+ p := &mockPoller{}
+ p.init(
+ &config.Config{
+ Runner: configRunner,
+ },
+ &mockClient{
+ sleep: testCase.sleep,
+ cancel: testCase.cancel,
+ noTask: testCase.noTask,
+ err: testCase.err,
+ },
+ &mockRunner{},
+ )
+ task, ok := p.fetchTask(context.Background())
+ if testCase.success {
+ assert.True(t, ok)
+ assert.NotNil(t, task)
+ } else {
+ assert.False(t, ok)
+ assert.Nil(t, task)
+ }
+ })
+ }
+}