summaryrefslogtreecommitdiffstats
path: root/pkg/exprparser/interpreter.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/exprparser/interpreter.go')
-rw-r--r--pkg/exprparser/interpreter.go642
1 files changed, 642 insertions, 0 deletions
diff --git a/pkg/exprparser/interpreter.go b/pkg/exprparser/interpreter.go
new file mode 100644
index 0000000..29c5686
--- /dev/null
+++ b/pkg/exprparser/interpreter.go
@@ -0,0 +1,642 @@
+package exprparser
+
+import (
+ "encoding"
+ "fmt"
+ "math"
+ "reflect"
+ "strings"
+
+ "github.com/nektos/act/pkg/model"
+ "github.com/rhysd/actionlint"
+)
+
+type EvaluationEnvironment struct {
+ Github *model.GithubContext
+ Env map[string]string
+ Job *model.JobContext
+ Jobs *map[string]*model.WorkflowCallResult
+ Steps map[string]*model.StepResult
+ Runner map[string]interface{}
+ Secrets map[string]string
+ Vars map[string]string
+ Strategy map[string]interface{}
+ Matrix map[string]interface{}
+ Needs map[string]Needs
+ Inputs map[string]interface{}
+ HashFiles func([]reflect.Value) (interface{}, error)
+}
+
+type Needs struct {
+ Outputs map[string]string `json:"outputs"`
+ Result string `json:"result"`
+}
+
+type Config struct {
+ Run *model.Run
+ WorkingDir string
+ Context string
+}
+
+type DefaultStatusCheck int
+
+const (
+ DefaultStatusCheckNone DefaultStatusCheck = iota
+ DefaultStatusCheckSuccess
+ DefaultStatusCheckAlways
+ DefaultStatusCheckCanceled
+ DefaultStatusCheckFailure
+)
+
+func (dsc DefaultStatusCheck) String() string {
+ switch dsc {
+ case DefaultStatusCheckSuccess:
+ return "success"
+ case DefaultStatusCheckAlways:
+ return "always"
+ case DefaultStatusCheckCanceled:
+ return "cancelled"
+ case DefaultStatusCheckFailure:
+ return "failure"
+ }
+ return ""
+}
+
+type Interpreter interface {
+ Evaluate(input string, defaultStatusCheck DefaultStatusCheck) (interface{}, error)
+}
+
+type interperterImpl struct {
+ env *EvaluationEnvironment
+ config Config
+}
+
+func NewInterpeter(env *EvaluationEnvironment, config Config) Interpreter {
+ return &interperterImpl{
+ env: env,
+ config: config,
+ }
+}
+
+func (impl *interperterImpl) Evaluate(input string, defaultStatusCheck DefaultStatusCheck) (interface{}, error) {
+ input = strings.TrimPrefix(input, "${{")
+ if defaultStatusCheck != DefaultStatusCheckNone && input == "" {
+ input = "success()"
+ }
+ parser := actionlint.NewExprParser()
+ exprNode, err := parser.Parse(actionlint.NewExprLexer(input + "}}"))
+ if err != nil {
+ return nil, fmt.Errorf("Failed to parse: %s", err.Message)
+ }
+
+ if defaultStatusCheck != DefaultStatusCheckNone {
+ hasStatusCheckFunction := false
+ actionlint.VisitExprNode(exprNode, func(node, _ actionlint.ExprNode, entering bool) {
+ if funcCallNode, ok := node.(*actionlint.FuncCallNode); entering && ok {
+ switch strings.ToLower(funcCallNode.Callee) {
+ case "success", "always", "cancelled", "failure":
+ hasStatusCheckFunction = true
+ }
+ }
+ })
+
+ if !hasStatusCheckFunction {
+ exprNode = &actionlint.LogicalOpNode{
+ Kind: actionlint.LogicalOpNodeKindAnd,
+ Left: &actionlint.FuncCallNode{
+ Callee: defaultStatusCheck.String(),
+ Args: []actionlint.ExprNode{},
+ },
+ Right: exprNode,
+ }
+ }
+ }
+
+ result, err2 := impl.evaluateNode(exprNode)
+
+ return result, err2
+}
+
+func (impl *interperterImpl) evaluateNode(exprNode actionlint.ExprNode) (interface{}, error) {
+ switch node := exprNode.(type) {
+ case *actionlint.VariableNode:
+ return impl.evaluateVariable(node)
+ case *actionlint.BoolNode:
+ return node.Value, nil
+ case *actionlint.NullNode:
+ return nil, nil
+ case *actionlint.IntNode:
+ return node.Value, nil
+ case *actionlint.FloatNode:
+ return node.Value, nil
+ case *actionlint.StringNode:
+ return node.Value, nil
+ case *actionlint.IndexAccessNode:
+ return impl.evaluateIndexAccess(node)
+ case *actionlint.ObjectDerefNode:
+ return impl.evaluateObjectDeref(node)
+ case *actionlint.ArrayDerefNode:
+ return impl.evaluateArrayDeref(node)
+ case *actionlint.NotOpNode:
+ return impl.evaluateNot(node)
+ case *actionlint.CompareOpNode:
+ return impl.evaluateCompare(node)
+ case *actionlint.LogicalOpNode:
+ return impl.evaluateLogicalCompare(node)
+ case *actionlint.FuncCallNode:
+ return impl.evaluateFuncCall(node)
+ default:
+ return nil, fmt.Errorf("Fatal error! Unknown node type: %s node: %+v", reflect.TypeOf(exprNode), exprNode)
+ }
+}
+
+//nolint:gocyclo
+func (impl *interperterImpl) evaluateVariable(variableNode *actionlint.VariableNode) (interface{}, error) {
+ switch strings.ToLower(variableNode.Name) {
+ case "github":
+ return impl.env.Github, nil
+ case "gitea": // compatible with Gitea
+ return impl.env.Github, nil
+ case "forge":
+ return impl.env.Github, nil
+ case "env":
+ return impl.env.Env, nil
+ case "job":
+ return impl.env.Job, nil
+ case "jobs":
+ if impl.env.Jobs == nil {
+ return nil, fmt.Errorf("Unavailable context: jobs")
+ }
+ return impl.env.Jobs, nil
+ case "steps":
+ return impl.env.Steps, nil
+ case "runner":
+ return impl.env.Runner, nil
+ case "secrets":
+ return impl.env.Secrets, nil
+ case "vars":
+ return impl.env.Vars, nil
+ case "strategy":
+ return impl.env.Strategy, nil
+ case "matrix":
+ return impl.env.Matrix, nil
+ case "needs":
+ return impl.env.Needs, nil
+ case "inputs":
+ return impl.env.Inputs, nil
+ case "infinity":
+ return math.Inf(1), nil
+ case "nan":
+ return math.NaN(), nil
+ default:
+ return nil, fmt.Errorf("Unavailable context: %s", variableNode.Name)
+ }
+}
+
+func (impl *interperterImpl) evaluateIndexAccess(indexAccessNode *actionlint.IndexAccessNode) (interface{}, error) {
+ left, err := impl.evaluateNode(indexAccessNode.Operand)
+ if err != nil {
+ return nil, err
+ }
+
+ leftValue := reflect.ValueOf(left)
+
+ right, err := impl.evaluateNode(indexAccessNode.Index)
+ if err != nil {
+ return nil, err
+ }
+
+ rightValue := reflect.ValueOf(right)
+
+ switch rightValue.Kind() {
+ case reflect.String:
+ return impl.getPropertyValue(leftValue, rightValue.String())
+
+ case reflect.Int:
+ switch leftValue.Kind() {
+ case reflect.Slice:
+ if rightValue.Int() < 0 || rightValue.Int() >= int64(leftValue.Len()) {
+ return nil, nil
+ }
+ return leftValue.Index(int(rightValue.Int())).Interface(), nil
+ default:
+ return nil, nil
+ }
+
+ default:
+ return nil, nil
+ }
+}
+
+func (impl *interperterImpl) evaluateObjectDeref(objectDerefNode *actionlint.ObjectDerefNode) (interface{}, error) {
+ left, err := impl.evaluateNode(objectDerefNode.Receiver)
+ if err != nil {
+ return nil, err
+ }
+
+ return impl.getPropertyValue(reflect.ValueOf(left), objectDerefNode.Property)
+}
+
+func (impl *interperterImpl) evaluateArrayDeref(arrayDerefNode *actionlint.ArrayDerefNode) (interface{}, error) {
+ left, err := impl.evaluateNode(arrayDerefNode.Receiver)
+ if err != nil {
+ return nil, err
+ }
+
+ return impl.getSafeValue(reflect.ValueOf(left)), nil
+}
+
+func (impl *interperterImpl) getPropertyValue(left reflect.Value, property string) (value interface{}, err error) {
+ switch left.Kind() {
+ case reflect.Ptr:
+ return impl.getPropertyValue(left.Elem(), property)
+
+ case reflect.Struct:
+ leftType := left.Type()
+ for i := 0; i < leftType.NumField(); i++ {
+ jsonName := leftType.Field(i).Tag.Get("json")
+ if jsonName == property {
+ property = leftType.Field(i).Name
+ break
+ }
+ }
+
+ fieldValue := left.FieldByNameFunc(func(name string) bool {
+ return strings.EqualFold(name, property)
+ })
+
+ if fieldValue.Kind() == reflect.Invalid {
+ return "", nil
+ }
+
+ i := fieldValue.Interface()
+ // The type stepStatus int is an integer, but should be treated as string
+ if m, ok := i.(encoding.TextMarshaler); ok {
+ text, err := m.MarshalText()
+ if err != nil {
+ return nil, err
+ }
+ return string(text), nil
+ }
+ return i, nil
+
+ case reflect.Map:
+ iter := left.MapRange()
+
+ for iter.Next() {
+ key := iter.Key()
+
+ switch key.Kind() {
+ case reflect.String:
+ if strings.EqualFold(key.String(), property) {
+ return impl.getMapValue(iter.Value())
+ }
+
+ default:
+ return nil, fmt.Errorf("'%s' in map key not implemented", key.Kind())
+ }
+ }
+
+ return nil, nil
+
+ case reflect.Slice:
+ var values []interface{}
+
+ for i := 0; i < left.Len(); i++ {
+ value, err := impl.getPropertyValue(left.Index(i).Elem(), property)
+ if err != nil {
+ return nil, err
+ }
+
+ values = append(values, value)
+ }
+
+ return values, nil
+ }
+
+ return nil, nil
+}
+
+func (impl *interperterImpl) getMapValue(value reflect.Value) (interface{}, error) {
+ if value.Kind() == reflect.Ptr {
+ return impl.getMapValue(value.Elem())
+ }
+
+ return value.Interface(), nil
+}
+
+func (impl *interperterImpl) evaluateNot(notNode *actionlint.NotOpNode) (interface{}, error) {
+ operand, err := impl.evaluateNode(notNode.Operand)
+ if err != nil {
+ return nil, err
+ }
+
+ return !IsTruthy(operand), nil
+}
+
+func (impl *interperterImpl) evaluateCompare(compareNode *actionlint.CompareOpNode) (interface{}, error) {
+ left, err := impl.evaluateNode(compareNode.Left)
+ if err != nil {
+ return nil, err
+ }
+
+ right, err := impl.evaluateNode(compareNode.Right)
+ if err != nil {
+ return nil, err
+ }
+
+ leftValue := reflect.ValueOf(left)
+ rightValue := reflect.ValueOf(right)
+
+ return impl.compareValues(leftValue, rightValue, compareNode.Kind)
+}
+
+func (impl *interperterImpl) compareValues(leftValue reflect.Value, rightValue reflect.Value, kind actionlint.CompareOpNodeKind) (interface{}, error) {
+ if leftValue.Kind() != rightValue.Kind() {
+ if !impl.isNumber(leftValue) {
+ leftValue = impl.coerceToNumber(leftValue)
+ }
+ if !impl.isNumber(rightValue) {
+ rightValue = impl.coerceToNumber(rightValue)
+ }
+ }
+
+ switch leftValue.Kind() {
+ case reflect.Bool:
+ return impl.compareNumber(float64(impl.coerceToNumber(leftValue).Int()), float64(impl.coerceToNumber(rightValue).Int()), kind)
+ case reflect.String:
+ return impl.compareString(strings.ToLower(leftValue.String()), strings.ToLower(rightValue.String()), kind)
+
+ case reflect.Int:
+ if rightValue.Kind() == reflect.Float64 {
+ return impl.compareNumber(float64(leftValue.Int()), rightValue.Float(), kind)
+ }
+
+ return impl.compareNumber(float64(leftValue.Int()), float64(rightValue.Int()), kind)
+
+ case reflect.Float64:
+ if rightValue.Kind() == reflect.Int {
+ return impl.compareNumber(leftValue.Float(), float64(rightValue.Int()), kind)
+ }
+
+ return impl.compareNumber(leftValue.Float(), rightValue.Float(), kind)
+
+ case reflect.Invalid:
+ if rightValue.Kind() == reflect.Invalid {
+ return true, nil
+ }
+
+ // not possible situation - params are converted to the same type in code above
+ return nil, fmt.Errorf("Compare params of Invalid type: left: %+v, right: %+v", leftValue.Kind(), rightValue.Kind())
+
+ default:
+ return nil, fmt.Errorf("Compare not implemented for types: left: %+v, right: %+v", leftValue.Kind(), rightValue.Kind())
+ }
+}
+
+func (impl *interperterImpl) coerceToNumber(value reflect.Value) reflect.Value {
+ switch value.Kind() {
+ case reflect.Invalid:
+ return reflect.ValueOf(0)
+
+ case reflect.Bool:
+ switch value.Bool() {
+ case true:
+ return reflect.ValueOf(1)
+ case false:
+ return reflect.ValueOf(0)
+ }
+
+ case reflect.String:
+ if value.String() == "" {
+ return reflect.ValueOf(0)
+ }
+
+ // try to parse the string as a number
+ evaluated, err := impl.Evaluate(value.String(), DefaultStatusCheckNone)
+ if err != nil {
+ return reflect.ValueOf(math.NaN())
+ }
+
+ if value := reflect.ValueOf(evaluated); impl.isNumber(value) {
+ return value
+ }
+ }
+
+ return reflect.ValueOf(math.NaN())
+}
+
+func (impl *interperterImpl) coerceToString(value reflect.Value) reflect.Value {
+ switch value.Kind() {
+ case reflect.Invalid:
+ return reflect.ValueOf("")
+
+ case reflect.Bool:
+ switch value.Bool() {
+ case true:
+ return reflect.ValueOf("true")
+ case false:
+ return reflect.ValueOf("false")
+ }
+
+ case reflect.String:
+ return value
+
+ case reflect.Int:
+ return reflect.ValueOf(fmt.Sprint(value))
+
+ case reflect.Float64:
+ if math.IsInf(value.Float(), 1) {
+ return reflect.ValueOf("Infinity")
+ } else if math.IsInf(value.Float(), -1) {
+ return reflect.ValueOf("-Infinity")
+ }
+ return reflect.ValueOf(fmt.Sprintf("%.15G", value.Float()))
+
+ case reflect.Slice:
+ return reflect.ValueOf("Array")
+
+ case reflect.Map:
+ return reflect.ValueOf("Object")
+ }
+
+ return value
+}
+
+func (impl *interperterImpl) compareString(left string, right string, kind actionlint.CompareOpNodeKind) (bool, error) {
+ switch kind {
+ case actionlint.CompareOpNodeKindLess:
+ return left < right, nil
+ case actionlint.CompareOpNodeKindLessEq:
+ return left <= right, nil
+ case actionlint.CompareOpNodeKindGreater:
+ return left > right, nil
+ case actionlint.CompareOpNodeKindGreaterEq:
+ return left >= right, nil
+ case actionlint.CompareOpNodeKindEq:
+ return left == right, nil
+ case actionlint.CompareOpNodeKindNotEq:
+ return left != right, nil
+ default:
+ return false, fmt.Errorf("TODO: not implemented to compare '%+v'", kind)
+ }
+}
+
+func (impl *interperterImpl) compareNumber(left float64, right float64, kind actionlint.CompareOpNodeKind) (bool, error) {
+ switch kind {
+ case actionlint.CompareOpNodeKindLess:
+ return left < right, nil
+ case actionlint.CompareOpNodeKindLessEq:
+ return left <= right, nil
+ case actionlint.CompareOpNodeKindGreater:
+ return left > right, nil
+ case actionlint.CompareOpNodeKindGreaterEq:
+ return left >= right, nil
+ case actionlint.CompareOpNodeKindEq:
+ return left == right, nil
+ case actionlint.CompareOpNodeKindNotEq:
+ return left != right, nil
+ default:
+ return false, fmt.Errorf("TODO: not implemented to compare '%+v'", kind)
+ }
+}
+
+func IsTruthy(input interface{}) bool {
+ value := reflect.ValueOf(input)
+ switch value.Kind() {
+ case reflect.Bool:
+ return value.Bool()
+
+ case reflect.String:
+ return value.String() != ""
+
+ case reflect.Int:
+ return value.Int() != 0
+
+ case reflect.Float64:
+ if math.IsNaN(value.Float()) {
+ return false
+ }
+
+ return value.Float() != 0
+
+ case reflect.Map, reflect.Slice:
+ return true
+
+ default:
+ return false
+ }
+}
+
+func (impl *interperterImpl) isNumber(value reflect.Value) bool {
+ switch value.Kind() {
+ case reflect.Int, reflect.Float64:
+ return true
+ default:
+ return false
+ }
+}
+
+func (impl *interperterImpl) getSafeValue(value reflect.Value) interface{} {
+ switch value.Kind() {
+ case reflect.Invalid:
+ return nil
+
+ case reflect.Float64:
+ if value.Float() == 0 {
+ return 0
+ }
+ }
+
+ return value.Interface()
+}
+
+func (impl *interperterImpl) evaluateLogicalCompare(compareNode *actionlint.LogicalOpNode) (interface{}, error) {
+ left, err := impl.evaluateNode(compareNode.Left)
+ if err != nil {
+ return nil, err
+ }
+
+ leftValue := reflect.ValueOf(left)
+
+ if IsTruthy(left) == (compareNode.Kind == actionlint.LogicalOpNodeKindOr) {
+ return impl.getSafeValue(leftValue), nil
+ }
+
+ right, err := impl.evaluateNode(compareNode.Right)
+ if err != nil {
+ return nil, err
+ }
+
+ rightValue := reflect.ValueOf(right)
+
+ switch compareNode.Kind {
+ case actionlint.LogicalOpNodeKindAnd:
+ return impl.getSafeValue(rightValue), nil
+ case actionlint.LogicalOpNodeKindOr:
+ return impl.getSafeValue(rightValue), nil
+ }
+
+ return nil, fmt.Errorf("Unable to compare incompatibles types '%s' and '%s'", leftValue.Kind(), rightValue.Kind())
+}
+
+//nolint:gocyclo
+func (impl *interperterImpl) evaluateFuncCall(funcCallNode *actionlint.FuncCallNode) (interface{}, error) {
+ args := make([]reflect.Value, 0)
+
+ for _, arg := range funcCallNode.Args {
+ value, err := impl.evaluateNode(arg)
+ if err != nil {
+ return nil, err
+ }
+
+ args = append(args, reflect.ValueOf(value))
+ }
+
+ switch strings.ToLower(funcCallNode.Callee) {
+ case "contains":
+ return impl.contains(args[0], args[1])
+ case "startswith":
+ return impl.startsWith(args[0], args[1])
+ case "endswith":
+ return impl.endsWith(args[0], args[1])
+ case "format":
+ return impl.format(args[0], args[1:]...)
+ case "join":
+ if len(args) == 1 {
+ return impl.join(args[0], reflect.ValueOf(","))
+ }
+ return impl.join(args[0], args[1])
+ case "tojson":
+ return impl.toJSON(args[0])
+ case "fromjson":
+ return impl.fromJSON(args[0])
+ case "hashfiles":
+ if impl.env.HashFiles != nil {
+ return impl.env.HashFiles(args)
+ }
+ return impl.hashFiles(args...)
+ case "always":
+ return impl.always()
+ case "success":
+ if impl.config.Context == "job" {
+ return impl.jobSuccess()
+ }
+ if impl.config.Context == "step" {
+ return impl.stepSuccess()
+ }
+ return nil, fmt.Errorf("Context '%s' must be one of 'job' or 'step'", impl.config.Context)
+ case "failure":
+ if impl.config.Context == "job" {
+ return impl.jobFailure()
+ }
+ if impl.config.Context == "step" {
+ return impl.stepFailure()
+ }
+ return nil, fmt.Errorf("Context '%s' must be one of 'job' or 'step'", impl.config.Context)
+ case "cancelled":
+ return impl.cancelled()
+ default:
+ return nil, fmt.Errorf("TODO: '%s' not implemented", funcCallNode.Callee)
+ }
+}