diff options
author | Earl Warren <contact@earl-warren.org> | 2024-10-31 15:58:39 +0100 |
---|---|---|
committer | Earl Warren <contact@earl-warren.org> | 2024-10-31 16:02:17 +0100 |
commit | b1b9df5ef40e778f50a6f7cd08f4964c69258924 (patch) | |
tree | ae800a783a1c2c2f7c8ea7f4dd5dceacef6c2bdb /pkg/exprparser | |
parent | Merge pull request 'fix: debug is leaking host container and network names' (... (diff) | |
download | forgejo-act-b1b9df5ef40e778f50a6f7cd08f4964c69258924.tar.xz forgejo-act-b1b9df5ef40e778f50a6f7cd08f4964c69258924.zip |
fix: return an error when the argument count is wrong
Closes forgejo/runner#307
Diffstat (limited to 'pkg/exprparser')
-rw-r--r-- | pkg/exprparser/functions_test.go | 21 | ||||
-rw-r--r-- | pkg/exprparser/interpreter.go | 35 |
2 files changed, 56 insertions, 0 deletions
diff --git a/pkg/exprparser/functions_test.go b/pkg/exprparser/functions_test.go index ea51a2b..c90b326 100644 --- a/pkg/exprparser/functions_test.go +++ b/pkg/exprparser/functions_test.go @@ -43,6 +43,9 @@ func TestFunctionContains(t *testing.T) { assert.Equal(t, tt.expected, output) }) } + + _, err := NewInterpeter(env, Config{}).Evaluate("contains('one')", DefaultStatusCheckNone) + assert.Error(t, err) } func TestFunctionStartsWith(t *testing.T) { @@ -72,6 +75,9 @@ func TestFunctionStartsWith(t *testing.T) { assert.Equal(t, tt.expected, output) }) } + + _, err := NewInterpeter(env, Config{}).Evaluate("startsWith('one')", DefaultStatusCheckNone) + assert.Error(t, err) } func TestFunctionEndsWith(t *testing.T) { @@ -101,6 +107,9 @@ func TestFunctionEndsWith(t *testing.T) { assert.Equal(t, tt.expected, output) }) } + + _, err := NewInterpeter(env, Config{}).Evaluate("endsWith('one')", DefaultStatusCheckNone) + assert.Error(t, err) } func TestFunctionJoin(t *testing.T) { @@ -128,6 +137,9 @@ func TestFunctionJoin(t *testing.T) { assert.Equal(t, tt.expected, output) }) } + + _, err := NewInterpeter(env, Config{}).Evaluate("join()", DefaultStatusCheckNone) + assert.Error(t, err) } func TestFunctionToJSON(t *testing.T) { @@ -154,6 +166,9 @@ func TestFunctionToJSON(t *testing.T) { assert.Equal(t, tt.expected, output) }) } + + _, err := NewInterpeter(env, Config{}).Evaluate("tojson()", DefaultStatusCheckNone) + assert.Error(t, err) } func TestFunctionFromJSON(t *testing.T) { @@ -177,6 +192,9 @@ func TestFunctionFromJSON(t *testing.T) { assert.Equal(t, tt.expected, output) }) } + + _, err := NewInterpeter(env, Config{}).Evaluate("fromjson()", DefaultStatusCheckNone) + assert.Error(t, err) } func TestFunctionHashFiles(t *testing.T) { @@ -248,4 +266,7 @@ func TestFunctionFormat(t *testing.T) { } }) } + + _, err := NewInterpeter(env, Config{}).Evaluate("format()", DefaultStatusCheckNone) + assert.Error(t, err) } diff --git a/pkg/exprparser/interpreter.go b/pkg/exprparser/interpreter.go index 29c5686..021e5c9 100644 --- a/pkg/exprparser/interpreter.go +++ b/pkg/exprparser/interpreter.go @@ -593,23 +593,58 @@ func (impl *interperterImpl) evaluateFuncCall(funcCallNode *actionlint.FuncCallN args = append(args, reflect.ValueOf(value)) } + argCountCheck := func(argCount int) error { + if len(args) != argCount { + return fmt.Errorf("'%s' expected %d arguments but got %d instead", funcCallNode.Callee, argCount, len(args)) + } + return nil + } + + argAtLeastCheck := func(atLeast int) error { + if len(args) < atLeast { + return fmt.Errorf("'%s' expected at least %d arguments but got %d instead", funcCallNode.Callee, atLeast, len(args)) + } + return nil + } + switch strings.ToLower(funcCallNode.Callee) { case "contains": + if err := argCountCheck(2); err != nil { + return nil, err + } return impl.contains(args[0], args[1]) case "startswith": + if err := argCountCheck(2); err != nil { + return nil, err + } return impl.startsWith(args[0], args[1]) case "endswith": + if err := argCountCheck(2); err != nil { + return nil, err + } return impl.endsWith(args[0], args[1]) case "format": + if err := argAtLeastCheck(1); err != nil { + return nil, err + } return impl.format(args[0], args[1:]...) case "join": + if err := argAtLeastCheck(1); err != nil { + return nil, err + } if len(args) == 1 { return impl.join(args[0], reflect.ValueOf(",")) } return impl.join(args[0], args[1]) case "tojson": + if err := argCountCheck(1); err != nil { + return nil, err + } return impl.toJSON(args[0]) case "fromjson": + if err := argCountCheck(1); err != nil { + return nil, err + } return impl.fromJSON(args[0]) case "hashfiles": if impl.env.HashFiles != nil { |