Diffstat (limited to 'internal/app/poll')
-rw-r--r-- | internal/app/poll/poller.go      | 167
-rw-r--r-- | internal/app/poll/poller_test.go | 263
2 files changed, 430 insertions, 0 deletions
diff --git a/internal/app/poll/poller.go b/internal/app/poll/poller.go
new file mode 100644
index 0000000..cc89fa5
--- /dev/null
+++ b/internal/app/poll/poller.go
@@ -0,0 +1,167 @@
+// Copyright 2023 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package poll
+
+import (
+    "context"
+    "errors"
+    "fmt"
+    "sync"
+    "sync/atomic"
+
+    runnerv1 "code.gitea.io/actions-proto-go/runner/v1"
+    "connectrpc.com/connect"
+    log "github.com/sirupsen/logrus"
+    "golang.org/x/time/rate"
+
+    "gitea.com/gitea/act_runner/internal/app/run"
+    "gitea.com/gitea/act_runner/internal/pkg/client"
+    "gitea.com/gitea/act_runner/internal/pkg/config"
+)
+
+const PollerID = "PollerID"
+
+type Poller interface {
+    Poll()
+    Shutdown(ctx context.Context) error
+}
+
+type poller struct {
+    client       client.Client
+    runner       run.RunnerInterface
+    cfg          *config.Config
+    tasksVersion atomic.Int64 // tasksVersion stores the version of the last task fetched from Gitea.
+
+    pollingCtx      context.Context
+    shutdownPolling context.CancelFunc
+
+    jobsCtx      context.Context
+    shutdownJobs context.CancelFunc
+
+    done chan any
+}
+
+func New(cfg *config.Config, client client.Client, runner run.RunnerInterface) Poller {
+    return (&poller{}).init(cfg, client, runner)
+}
+
+func (p *poller) init(cfg *config.Config, client client.Client, runner run.RunnerInterface) Poller {
+    pollingCtx, shutdownPolling := context.WithCancel(context.Background())
+
+    jobsCtx, shutdownJobs := context.WithCancel(context.Background())
+
+    done := make(chan any)
+
+    p.client = client
+    p.runner = runner
+    p.cfg = cfg
+
+    p.pollingCtx = pollingCtx
+    p.shutdownPolling = shutdownPolling
+
+    p.jobsCtx = jobsCtx
+    p.shutdownJobs = shutdownJobs
+    p.done = done
+
+    return p
+}
+
+func (p *poller) Poll() {
+    limiter := rate.NewLimiter(rate.Every(p.cfg.Runner.FetchInterval), 1)
+    wg := &sync.WaitGroup{}
+    for i := 0; i < p.cfg.Runner.Capacity; i++ {
+        wg.Add(1)
+        go p.poll(i, wg, limiter)
+    }
+    wg.Wait()
+
+    // signal the poller is finished
+    close(p.done)
+}
+
+func (p *poller) Shutdown(ctx context.Context) error {
+    p.shutdownPolling()
+
+    select {
+    case <-p.done:
+        log.Trace("all jobs are complete")
+        return nil
+
+    case <-ctx.Done():
+        log.Trace("forcing the jobs to shutdown")
+        p.shutdownJobs()
+        <-p.done
+        log.Trace("all jobs have been shutdown")
+        return ctx.Err()
+    }
+}
+
+func (p *poller) poll(id int, wg *sync.WaitGroup, limiter *rate.Limiter) {
+    log.Infof("[poller %d] launched", id)
+    defer wg.Done()
+    for {
+        if err := limiter.Wait(p.pollingCtx); err != nil {
+            log.Infof("[poller %d] shutdown", id)
+            return
+        }
+        task, ok := p.fetchTask(p.pollingCtx)
+        if !ok {
+            continue
+        }
+        p.runTaskWithRecover(p.jobsCtx, task)
+    }
+}
+
+func (p *poller) runTaskWithRecover(ctx context.Context, task *runnerv1.Task) {
+    defer func() {
+        if r := recover(); r != nil {
+            err := fmt.Errorf("panic: %v", r)
+            log.WithError(err).Error("panic in runTaskWithRecover")
+        }
+    }()
+
+    if err := p.runner.Run(ctx, task); err != nil {
+        log.WithError(err).Error("failed to run task")
+    }
+}
+
+func (p *poller) fetchTask(ctx context.Context) (*runnerv1.Task, bool) {
+    reqCtx, cancel := context.WithTimeout(ctx, p.cfg.Runner.FetchTimeout)
+    defer cancel()
+
+    // Load the version value that was in the cache when the request was sent.
+    v := p.tasksVersion.Load()
+    resp, err := p.client.FetchTask(reqCtx, connect.NewRequest(&runnerv1.FetchTaskRequest{
+        TasksVersion: v,
+    }))
+    if errors.Is(err, context.DeadlineExceeded) {
+        log.Trace("deadline exceeded")
+        err = nil
+    }
+    if err != nil {
+        if errors.Is(err, context.Canceled) {
+            log.WithError(err).Debugf("shutdown, fetch task canceled")
+        } else {
+            log.WithError(err).Error("failed to fetch task")
+        }
+        return nil, false
+    }
+
+    if resp == nil || resp.Msg == nil {
+        return nil, false
+    }
+
+    if resp.Msg.TasksVersion > v {
+        p.tasksVersion.CompareAndSwap(v, resp.Msg.TasksVersion)
+    }
+
+    if resp.Msg.Task == nil {
+        return nil, false
+    }
+
+    // Got a task: reset tasksVersion to zero to force a database query on the next request.
+    p.tasksVersion.CompareAndSwap(resp.Msg.TasksVersion, 0)
+
+    return resp.Msg.Task, true
+}
diff --git a/internal/app/poll/poller_test.go b/internal/app/poll/poller_test.go
new file mode 100644
index 0000000..04b1a84
--- /dev/null
+++ b/internal/app/poll/poller_test.go
@@ -0,0 +1,263 @@
+// Copyright The Forgejo Authors.
+// SPDX-License-Identifier: MIT
+
+package poll
+
+import (
+    "context"
+    "fmt"
+    "testing"
+    "time"
+
+    "connectrpc.com/connect"
+
+    "code.gitea.io/actions-proto-go/ping/v1/pingv1connect"
+    runnerv1 "code.gitea.io/actions-proto-go/runner/v1"
+    "code.gitea.io/actions-proto-go/runner/v1/runnerv1connect"
+    "gitea.com/gitea/act_runner/internal/pkg/config"
+
+    log "github.com/sirupsen/logrus"
+    "github.com/stretchr/testify/assert"
+)
+
+type mockPoller struct {
+    poller
+}
+
+func (o *mockPoller) Poll() {
+    o.poller.Poll()
+}
+
+type mockClient struct {
+    pingv1connect.PingServiceClient
+    runnerv1connect.RunnerServiceClient
+
+    sleep  time.Duration
+    cancel bool
+    err    error
+    noTask bool
+}
+
+func (o mockClient) Address() string {
+    return ""
+}
+
+func (o mockClient) Insecure() bool {
+    return true
+}
+
+func (o *mockClient) FetchTask(ctx context.Context, req *connect.Request[runnerv1.FetchTaskRequest]) (*connect.Response[runnerv1.FetchTaskResponse], error) {
+    if o.sleep > 0 {
+        select {
+        case <-ctx.Done():
+            log.Trace("fetch task done")
+            return nil, context.DeadlineExceeded
+        case <-time.After(o.sleep):
+            log.Trace("slept")
+            return nil, fmt.Errorf("unexpected")
+        }
+    }
+    if o.cancel {
+        return nil, context.Canceled
+    }
+    if o.err != nil {
+        return nil, o.err
+    }
+    task := &runnerv1.Task{}
+    if o.noTask {
+        task = nil
+        o.noTask = false
+    }
+
+    return connect.NewResponse(&runnerv1.FetchTaskResponse{
+        Task:         task,
+        TasksVersion: int64(1),
+    }), nil
+}
+
+type mockRunner struct {
+    cfg    *config.Runner
+    log    chan string
+    panics bool
+    err    error
+}
+
+func (o *mockRunner) Run(ctx context.Context, task *runnerv1.Task) error {
+    o.log <- "runner starts"
+    if o.panics {
+        log.Trace("panics")
+        o.log <- "runner panics"
+        o.panics = false
+        panic("whatever")
+    }
+    if o.err != nil {
+        log.Trace("error")
+        o.log <- "runner error"
+        err := o.err
+        o.err = nil
+        return err
+    }
+    for {
+        select {
+        case <-ctx.Done():
+            log.Trace("shutdown")
+            o.log <- "runner shutdown"
+            return nil
+        case <-time.After(o.cfg.Timeout):
+            log.Trace("after")
+            o.log <- "runner timeout"
+            return nil
+        }
+    }
+}
+
+func setTrace(t *testing.T) {
+    t.Helper()
+    log.SetReportCaller(true)
+    log.SetLevel(log.TraceLevel)
+}
+
+func TestPoller_New(t *testing.T) {
+    p := New(&config.Config{}, &mockClient{}, &mockRunner{})
+    assert.NotNil(t, p)
+}
+
+func TestPoller_Runner(t *testing.T) {
+    setTrace(t)
+    for _, testCase := range []struct {
+        name           string
+        timeout        time.Duration
+        noTask         bool
+        panics         bool
+        err            error
+        expected       string
+        contextTimeout time.Duration
+    }{
+        {
+            name:     "Simple",
+            timeout:  10 * time.Second,
+            expected: "runner shutdown",
+        },
+        {
+            name:     "Panics",
+            timeout:  10 * time.Second,
+            panics:   true,
+            expected: "runner panics",
+        },
+        {
+            name:     "Error",
+            timeout:  10 * time.Second,
+            err:      fmt.Errorf("ERROR"),
+            expected: "runner error",
+        },
+        {
+            name:     "PollTaskError",
+            timeout:  10 * time.Second,
+            noTask:   true,
+            expected: "runner shutdown",
+        },
+        {
+            name:           "ShutdownTimeout",
+            timeout:        1 * time.Second,
+            contextTimeout: 1 * time.Minute,
+            expected:       "runner timeout",
+        },
+    } {
+        t.Run(testCase.name, func(t *testing.T) {
+            runnerLog := make(chan string, 3)
+            configRunner := config.Runner{
+                FetchInterval: 1,
+                Capacity:      1,
+                Timeout:       testCase.timeout,
+            }
+            p := &mockPoller{}
+            p.init(
+                &config.Config{
+                    Runner: configRunner,
+                },
+                &mockClient{
+                    noTask: testCase.noTask,
+                },
+                &mockRunner{
+                    cfg:    &configRunner,
+                    log:    runnerLog,
+                    panics: testCase.panics,
+                    err:    testCase.err,
+                })
+            go p.Poll()
+            assert.Equal(t, "runner starts", <-runnerLog)
+            var ctx context.Context
+            var cancel context.CancelFunc
+            if testCase.contextTimeout > 0 {
+                ctx, cancel = context.WithTimeout(context.Background(), testCase.contextTimeout)
+                defer cancel()
+            } else {
+                ctx, cancel = context.WithCancel(context.Background())
+                cancel()
+            }
+            p.Shutdown(ctx)
+            <-p.done
+            assert.Equal(t, testCase.expected, <-runnerLog)
+        })
+    }
+}
+
+func TestPoller_Fetch(t *testing.T) {
+    setTrace(t)
+    for _, testCase := range []struct {
+        name    string
+        noTask  bool
+        sleep   time.Duration
+        err     error
+        cancel  bool
+        success bool
+    }{
+        {
+            name:    "Success",
+            success: true,
+        },
+        {
+            name:  "Timeout",
+            sleep: 100 * time.Millisecond,
+        },
+        {
+            name:   "Canceled",
+            cancel: true,
+        },
+        {
+            name:   "NoTask",
+            noTask: true,
+        },
+        {
+            name: "Error",
+            err:  fmt.Errorf("random error"),
+        },
+    } {
+        t.Run(testCase.name, func(t *testing.T) {
+            configRunner := config.Runner{
+                FetchTimeout: 1 * time.Millisecond,
+            }
+            p := &mockPoller{}
+            p.init(
+                &config.Config{
+                    Runner: configRunner,
+                },
+                &mockClient{
+                    sleep:  testCase.sleep,
+                    cancel: testCase.cancel,
+                    noTask: testCase.noTask,
+                    err:    testCase.err,
+                },
+                &mockRunner{},
+            )
+            task, ok := p.fetchTask(context.Background())
+            if testCase.success {
+                assert.True(t, ok)
+                assert.NotNil(t, task)
+            } else {
+                assert.False(t, ok)
+                assert.Nil(t, task)
+            }
+        })
+    }
+}
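For context on how the new package is meant to be driven, the sketch below shows one possible wiring of the Poller from a daemon: Poll() runs in its own goroutine, and Shutdown() is given a grace period before running jobs are forcibly canceled. This is an illustrative sketch only, not part of the commit: the package name daemon, the function runDaemon, and the 30-second grace period are assumptions, and the client.Client and run.RunnerInterface values are taken as already constructed elsewhere in act_runner.

// Hypothetical wiring sketch (not part of this commit): drive the Poller from
// a daemon with a graceful-shutdown grace period. cfg, cli and runner are
// assumed to be constructed elsewhere in act_runner.
package daemon

import (
    "context"
    "os"
    "os/signal"
    "syscall"
    "time"

    log "github.com/sirupsen/logrus"

    "gitea.com/gitea/act_runner/internal/app/poll"
    "gitea.com/gitea/act_runner/internal/app/run"
    "gitea.com/gitea/act_runner/internal/pkg/client"
    "gitea.com/gitea/act_runner/internal/pkg/config"
)

func runDaemon(cfg *config.Config, cli client.Client, runner run.RunnerInterface) error {
    p := poll.New(cfg, cli, runner)

    // Poll blocks until the polling context is canceled and every in-flight
    // job has returned, so it runs in its own goroutine.
    go p.Poll()

    // Block until the process is asked to stop.
    sig := make(chan os.Signal, 1)
    signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
    <-sig

    // First phase: Shutdown cancels the polling context so no new tasks are
    // fetched. Second phase: if the grace period expires, it cancels the jobs
    // context, waits for Poll to close its done channel, and returns ctx.Err().
    ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
    defer cancel()
    if err := p.Shutdown(ctx); err != nil {
        log.WithError(err).Warn("jobs were canceled after the grace period")
        return err
    }
    return nil
}

The split between pollingCtx and jobsCtx in poller.go is what makes this two-phase shutdown possible: canceling the polling context stops new fetches immediately, while running jobs keep their own context until the caller's deadline expires.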