summaryrefslogtreecommitdiffstats
path: root/pkg/exprparser
diff options
context:
space:
mode:
authorEarl Warren <contact@earl-warren.org>2024-10-31 15:58:39 +0100
committerEarl Warren <contact@earl-warren.org>2024-10-31 16:02:17 +0100
commitb1b9df5ef40e778f50a6f7cd08f4964c69258924 (patch)
treeae800a783a1c2c2f7c8ea7f4dd5dceacef6c2bdb /pkg/exprparser
parentMerge pull request 'fix: debug is leaking host container and network names' (... (diff)
downloadforgejo-act-b1b9df5ef40e778f50a6f7cd08f4964c69258924.tar.xz
forgejo-act-b1b9df5ef40e778f50a6f7cd08f4964c69258924.zip
fix: return an error when the argument count is wrong
Closes forgejo/runner#307
Diffstat (limited to 'pkg/exprparser')
-rw-r--r--pkg/exprparser/functions_test.go21
-rw-r--r--pkg/exprparser/interpreter.go35
2 files changed, 56 insertions, 0 deletions
diff --git a/pkg/exprparser/functions_test.go b/pkg/exprparser/functions_test.go
index ea51a2b..c90b326 100644
--- a/pkg/exprparser/functions_test.go
+++ b/pkg/exprparser/functions_test.go
@@ -43,6 +43,9 @@ func TestFunctionContains(t *testing.T) {
assert.Equal(t, tt.expected, output)
})
}
+
+ _, err := NewInterpeter(env, Config{}).Evaluate("contains('one')", DefaultStatusCheckNone)
+ assert.Error(t, err)
}
func TestFunctionStartsWith(t *testing.T) {
@@ -72,6 +75,9 @@ func TestFunctionStartsWith(t *testing.T) {
assert.Equal(t, tt.expected, output)
})
}
+
+ _, err := NewInterpeter(env, Config{}).Evaluate("startsWith('one')", DefaultStatusCheckNone)
+ assert.Error(t, err)
}
func TestFunctionEndsWith(t *testing.T) {
@@ -101,6 +107,9 @@ func TestFunctionEndsWith(t *testing.T) {
assert.Equal(t, tt.expected, output)
})
}
+
+ _, err := NewInterpeter(env, Config{}).Evaluate("endsWith('one')", DefaultStatusCheckNone)
+ assert.Error(t, err)
}
func TestFunctionJoin(t *testing.T) {
@@ -128,6 +137,9 @@ func TestFunctionJoin(t *testing.T) {
assert.Equal(t, tt.expected, output)
})
}
+
+ _, err := NewInterpeter(env, Config{}).Evaluate("join()", DefaultStatusCheckNone)
+ assert.Error(t, err)
}
func TestFunctionToJSON(t *testing.T) {
@@ -154,6 +166,9 @@ func TestFunctionToJSON(t *testing.T) {
assert.Equal(t, tt.expected, output)
})
}
+
+ _, err := NewInterpeter(env, Config{}).Evaluate("tojson()", DefaultStatusCheckNone)
+ assert.Error(t, err)
}
func TestFunctionFromJSON(t *testing.T) {
@@ -177,6 +192,9 @@ func TestFunctionFromJSON(t *testing.T) {
assert.Equal(t, tt.expected, output)
})
}
+
+ _, err := NewInterpeter(env, Config{}).Evaluate("fromjson()", DefaultStatusCheckNone)
+ assert.Error(t, err)
}
func TestFunctionHashFiles(t *testing.T) {
@@ -248,4 +266,7 @@ func TestFunctionFormat(t *testing.T) {
}
})
}
+
+ _, err := NewInterpeter(env, Config{}).Evaluate("format()", DefaultStatusCheckNone)
+ assert.Error(t, err)
}
diff --git a/pkg/exprparser/interpreter.go b/pkg/exprparser/interpreter.go
index 29c5686..021e5c9 100644
--- a/pkg/exprparser/interpreter.go
+++ b/pkg/exprparser/interpreter.go
@@ -593,23 +593,58 @@ func (impl *interperterImpl) evaluateFuncCall(funcCallNode *actionlint.FuncCallN
args = append(args, reflect.ValueOf(value))
}
+ argCountCheck := func(argCount int) error {
+ if len(args) != argCount {
+ return fmt.Errorf("'%s' expected %d arguments but got %d instead", funcCallNode.Callee, argCount, len(args))
+ }
+ return nil
+ }
+
+ argAtLeastCheck := func(atLeast int) error {
+ if len(args) < atLeast {
+ return fmt.Errorf("'%s' expected at least %d arguments but got %d instead", funcCallNode.Callee, atLeast, len(args))
+ }
+ return nil
+ }
+
switch strings.ToLower(funcCallNode.Callee) {
case "contains":
+ if err := argCountCheck(2); err != nil {
+ return nil, err
+ }
return impl.contains(args[0], args[1])
case "startswith":
+ if err := argCountCheck(2); err != nil {
+ return nil, err
+ }
return impl.startsWith(args[0], args[1])
case "endswith":
+ if err := argCountCheck(2); err != nil {
+ return nil, err
+ }
return impl.endsWith(args[0], args[1])
case "format":
+ if err := argAtLeastCheck(1); err != nil {
+ return nil, err
+ }
return impl.format(args[0], args[1:]...)
case "join":
+ if err := argAtLeastCheck(1); err != nil {
+ return nil, err
+ }
if len(args) == 1 {
return impl.join(args[0], reflect.ValueOf(","))
}
return impl.join(args[0], args[1])
case "tojson":
+ if err := argCountCheck(1); err != nil {
+ return nil, err
+ }
return impl.toJSON(args[0])
case "fromjson":
+ if err := argCountCheck(1); err != nil {
+ return nil, err
+ }
return impl.fromJSON(args[0])
case "hashfiles":
if impl.env.HashFiles != nil {