summaryrefslogtreecommitdiffstats
path: root/internal/app/poll/poller_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/app/poll/poller_test.go')
-rw-r--r--internal/app/poll/poller_test.go263
1 files changed, 263 insertions, 0 deletions
diff --git a/internal/app/poll/poller_test.go b/internal/app/poll/poller_test.go
new file mode 100644
index 0000000..04b1a84
--- /dev/null
+++ b/internal/app/poll/poller_test.go
@@ -0,0 +1,263 @@
+// Copyright The Forgejo Authors.
+// SPDX-License-Identifier: MIT
+
+package poll
+
+import (
+ "context"
+ "fmt"
+ "testing"
+ "time"
+
+ "connectrpc.com/connect"
+
+ "code.gitea.io/actions-proto-go/ping/v1/pingv1connect"
+ runnerv1 "code.gitea.io/actions-proto-go/runner/v1"
+ "code.gitea.io/actions-proto-go/runner/v1/runnerv1connect"
+ "gitea.com/gitea/act_runner/internal/pkg/config"
+
+ log "github.com/sirupsen/logrus"
+ "github.com/stretchr/testify/assert"
+)
+
+type mockPoller struct {
+ poller
+}
+
+func (o *mockPoller) Poll() {
+ o.poller.Poll()
+}
+
+type mockClient struct {
+ pingv1connect.PingServiceClient
+ runnerv1connect.RunnerServiceClient
+
+ sleep time.Duration
+ cancel bool
+ err error
+ noTask bool
+}
+
+func (o mockClient) Address() string {
+ return ""
+}
+
+func (o mockClient) Insecure() bool {
+ return true
+}
+
+func (o *mockClient) FetchTask(ctx context.Context, req *connect.Request[runnerv1.FetchTaskRequest]) (*connect.Response[runnerv1.FetchTaskResponse], error) {
+ if o.sleep > 0 {
+ select {
+ case <-ctx.Done():
+ log.Trace("fetch task done")
+ return nil, context.DeadlineExceeded
+ case <-time.After(o.sleep):
+ log.Trace("slept")
+ return nil, fmt.Errorf("unexpected")
+ }
+ }
+ if o.cancel {
+ return nil, context.Canceled
+ }
+ if o.err != nil {
+ return nil, o.err
+ }
+ task := &runnerv1.Task{}
+ if o.noTask {
+ task = nil
+ o.noTask = false
+ }
+
+ return connect.NewResponse(&runnerv1.FetchTaskResponse{
+ Task: task,
+ TasksVersion: int64(1),
+ }), nil
+}
+
+type mockRunner struct {
+ cfg *config.Runner
+ log chan string
+ panics bool
+ err error
+}
+
+func (o *mockRunner) Run(ctx context.Context, task *runnerv1.Task) error {
+ o.log <- "runner starts"
+ if o.panics {
+ log.Trace("panics")
+ o.log <- "runner panics"
+ o.panics = false
+ panic("whatever")
+ }
+ if o.err != nil {
+ log.Trace("error")
+ o.log <- "runner error"
+ err := o.err
+ o.err = nil
+ return err
+ }
+ for {
+ select {
+ case <-ctx.Done():
+ log.Trace("shutdown")
+ o.log <- "runner shutdown"
+ return nil
+ case <-time.After(o.cfg.Timeout):
+ log.Trace("after")
+ o.log <- "runner timeout"
+ return nil
+ }
+ }
+}
+
+func setTrace(t *testing.T) {
+ t.Helper()
+ log.SetReportCaller(true)
+ log.SetLevel(log.TraceLevel)
+}
+
+func TestPoller_New(t *testing.T) {
+ p := New(&config.Config{}, &mockClient{}, &mockRunner{})
+ assert.NotNil(t, p)
+}
+
+func TestPoller_Runner(t *testing.T) {
+ setTrace(t)
+ for _, testCase := range []struct {
+ name string
+ timeout time.Duration
+ noTask bool
+ panics bool
+ err error
+ expected string
+ contextTimeout time.Duration
+ }{
+ {
+ name: "Simple",
+ timeout: 10 * time.Second,
+ expected: "runner shutdown",
+ },
+ {
+ name: "Panics",
+ timeout: 10 * time.Second,
+ panics: true,
+ expected: "runner panics",
+ },
+ {
+ name: "Error",
+ timeout: 10 * time.Second,
+ err: fmt.Errorf("ERROR"),
+ expected: "runner error",
+ },
+ {
+ name: "PollTaskError",
+ timeout: 10 * time.Second,
+ noTask: true,
+ expected: "runner shutdown",
+ },
+ {
+ name: "ShutdownTimeout",
+ timeout: 1 * time.Second,
+ contextTimeout: 1 * time.Minute,
+ expected: "runner timeout",
+ },
+ } {
+ t.Run(testCase.name, func(t *testing.T) {
+ runnerLog := make(chan string, 3)
+ configRunner := config.Runner{
+ FetchInterval: 1,
+ Capacity: 1,
+ Timeout: testCase.timeout,
+ }
+ p := &mockPoller{}
+ p.init(
+ &config.Config{
+ Runner: configRunner,
+ },
+ &mockClient{
+ noTask: testCase.noTask,
+ },
+ &mockRunner{
+ cfg: &configRunner,
+ log: runnerLog,
+ panics: testCase.panics,
+ err: testCase.err,
+ })
+ go p.Poll()
+ assert.Equal(t, "runner starts", <-runnerLog)
+ var ctx context.Context
+ var cancel context.CancelFunc
+ if testCase.contextTimeout > 0 {
+ ctx, cancel = context.WithTimeout(context.Background(), testCase.contextTimeout)
+ defer cancel()
+ } else {
+ ctx, cancel = context.WithCancel(context.Background())
+ cancel()
+ }
+ p.Shutdown(ctx)
+ <-p.done
+ assert.Equal(t, testCase.expected, <-runnerLog)
+ })
+ }
+}
+
+func TestPoller_Fetch(t *testing.T) {
+ setTrace(t)
+ for _, testCase := range []struct {
+ name string
+ noTask bool
+ sleep time.Duration
+ err error
+ cancel bool
+ success bool
+ }{
+ {
+ name: "Success",
+ success: true,
+ },
+ {
+ name: "Timeout",
+ sleep: 100 * time.Millisecond,
+ },
+ {
+ name: "Canceled",
+ cancel: true,
+ },
+ {
+ name: "NoTask",
+ noTask: true,
+ },
+ {
+ name: "Error",
+ err: fmt.Errorf("random error"),
+ },
+ } {
+ t.Run(testCase.name, func(t *testing.T) {
+ configRunner := config.Runner{
+ FetchTimeout: 1 * time.Millisecond,
+ }
+ p := &mockPoller{}
+ p.init(
+ &config.Config{
+ Runner: configRunner,
+ },
+ &mockClient{
+ sleep: testCase.sleep,
+ cancel: testCase.cancel,
+ noTask: testCase.noTask,
+ err: testCase.err,
+ },
+ &mockRunner{},
+ )
+ task, ok := p.fetchTask(context.Background())
+ if testCase.success {
+ assert.True(t, ok)
+ assert.NotNil(t, task)
+ } else {
+ assert.False(t, ok)
+ assert.Nil(t, task)
+ }
+ })
+ }
+}