diff options
Diffstat (limited to 'pkg/exprparser/interpreter.go')
-rw-r--r-- | pkg/exprparser/interpreter.go | 642 |
1 files changed, 642 insertions, 0 deletions
diff --git a/pkg/exprparser/interpreter.go b/pkg/exprparser/interpreter.go new file mode 100644 index 0000000..29c5686 --- /dev/null +++ b/pkg/exprparser/interpreter.go @@ -0,0 +1,642 @@ +package exprparser + +import ( + "encoding" + "fmt" + "math" + "reflect" + "strings" + + "github.com/nektos/act/pkg/model" + "github.com/rhysd/actionlint" +) + +type EvaluationEnvironment struct { + Github *model.GithubContext + Env map[string]string + Job *model.JobContext + Jobs *map[string]*model.WorkflowCallResult + Steps map[string]*model.StepResult + Runner map[string]interface{} + Secrets map[string]string + Vars map[string]string + Strategy map[string]interface{} + Matrix map[string]interface{} + Needs map[string]Needs + Inputs map[string]interface{} + HashFiles func([]reflect.Value) (interface{}, error) +} + +type Needs struct { + Outputs map[string]string `json:"outputs"` + Result string `json:"result"` +} + +type Config struct { + Run *model.Run + WorkingDir string + Context string +} + +type DefaultStatusCheck int + +const ( + DefaultStatusCheckNone DefaultStatusCheck = iota + DefaultStatusCheckSuccess + DefaultStatusCheckAlways + DefaultStatusCheckCanceled + DefaultStatusCheckFailure +) + +func (dsc DefaultStatusCheck) String() string { + switch dsc { + case DefaultStatusCheckSuccess: + return "success" + case DefaultStatusCheckAlways: + return "always" + case DefaultStatusCheckCanceled: + return "cancelled" + case DefaultStatusCheckFailure: + return "failure" + } + return "" +} + +type Interpreter interface { + Evaluate(input string, defaultStatusCheck DefaultStatusCheck) (interface{}, error) +} + +type interperterImpl struct { + env *EvaluationEnvironment + config Config +} + +func NewInterpeter(env *EvaluationEnvironment, config Config) Interpreter { + return &interperterImpl{ + env: env, + config: config, + } +} + +func (impl *interperterImpl) Evaluate(input string, defaultStatusCheck DefaultStatusCheck) (interface{}, error) { + input = strings.TrimPrefix(input, "${{") + if defaultStatusCheck != DefaultStatusCheckNone && input == "" { + input = "success()" + } + parser := actionlint.NewExprParser() + exprNode, err := parser.Parse(actionlint.NewExprLexer(input + "}}")) + if err != nil { + return nil, fmt.Errorf("Failed to parse: %s", err.Message) + } + + if defaultStatusCheck != DefaultStatusCheckNone { + hasStatusCheckFunction := false + actionlint.VisitExprNode(exprNode, func(node, _ actionlint.ExprNode, entering bool) { + if funcCallNode, ok := node.(*actionlint.FuncCallNode); entering && ok { + switch strings.ToLower(funcCallNode.Callee) { + case "success", "always", "cancelled", "failure": + hasStatusCheckFunction = true + } + } + }) + + if !hasStatusCheckFunction { + exprNode = &actionlint.LogicalOpNode{ + Kind: actionlint.LogicalOpNodeKindAnd, + Left: &actionlint.FuncCallNode{ + Callee: defaultStatusCheck.String(), + Args: []actionlint.ExprNode{}, + }, + Right: exprNode, + } + } + } + + result, err2 := impl.evaluateNode(exprNode) + + return result, err2 +} + +func (impl *interperterImpl) evaluateNode(exprNode actionlint.ExprNode) (interface{}, error) { + switch node := exprNode.(type) { + case *actionlint.VariableNode: + return impl.evaluateVariable(node) + case *actionlint.BoolNode: + return node.Value, nil + case *actionlint.NullNode: + return nil, nil + case *actionlint.IntNode: + return node.Value, nil + case *actionlint.FloatNode: + return node.Value, nil + case *actionlint.StringNode: + return node.Value, nil + case *actionlint.IndexAccessNode: + return impl.evaluateIndexAccess(node) + case *actionlint.ObjectDerefNode: + return impl.evaluateObjectDeref(node) + case *actionlint.ArrayDerefNode: + return impl.evaluateArrayDeref(node) + case *actionlint.NotOpNode: + return impl.evaluateNot(node) + case *actionlint.CompareOpNode: + return impl.evaluateCompare(node) + case *actionlint.LogicalOpNode: + return impl.evaluateLogicalCompare(node) + case *actionlint.FuncCallNode: + return impl.evaluateFuncCall(node) + default: + return nil, fmt.Errorf("Fatal error! Unknown node type: %s node: %+v", reflect.TypeOf(exprNode), exprNode) + } +} + +//nolint:gocyclo +func (impl *interperterImpl) evaluateVariable(variableNode *actionlint.VariableNode) (interface{}, error) { + switch strings.ToLower(variableNode.Name) { + case "github": + return impl.env.Github, nil + case "gitea": // compatible with Gitea + return impl.env.Github, nil + case "forge": + return impl.env.Github, nil + case "env": + return impl.env.Env, nil + case "job": + return impl.env.Job, nil + case "jobs": + if impl.env.Jobs == nil { + return nil, fmt.Errorf("Unavailable context: jobs") + } + return impl.env.Jobs, nil + case "steps": + return impl.env.Steps, nil + case "runner": + return impl.env.Runner, nil + case "secrets": + return impl.env.Secrets, nil + case "vars": + return impl.env.Vars, nil + case "strategy": + return impl.env.Strategy, nil + case "matrix": + return impl.env.Matrix, nil + case "needs": + return impl.env.Needs, nil + case "inputs": + return impl.env.Inputs, nil + case "infinity": + return math.Inf(1), nil + case "nan": + return math.NaN(), nil + default: + return nil, fmt.Errorf("Unavailable context: %s", variableNode.Name) + } +} + +func (impl *interperterImpl) evaluateIndexAccess(indexAccessNode *actionlint.IndexAccessNode) (interface{}, error) { + left, err := impl.evaluateNode(indexAccessNode.Operand) + if err != nil { + return nil, err + } + + leftValue := reflect.ValueOf(left) + + right, err := impl.evaluateNode(indexAccessNode.Index) + if err != nil { + return nil, err + } + + rightValue := reflect.ValueOf(right) + + switch rightValue.Kind() { + case reflect.String: + return impl.getPropertyValue(leftValue, rightValue.String()) + + case reflect.Int: + switch leftValue.Kind() { + case reflect.Slice: + if rightValue.Int() < 0 || rightValue.Int() >= int64(leftValue.Len()) { + return nil, nil + } + return leftValue.Index(int(rightValue.Int())).Interface(), nil + default: + return nil, nil + } + + default: + return nil, nil + } +} + +func (impl *interperterImpl) evaluateObjectDeref(objectDerefNode *actionlint.ObjectDerefNode) (interface{}, error) { + left, err := impl.evaluateNode(objectDerefNode.Receiver) + if err != nil { + return nil, err + } + + return impl.getPropertyValue(reflect.ValueOf(left), objectDerefNode.Property) +} + +func (impl *interperterImpl) evaluateArrayDeref(arrayDerefNode *actionlint.ArrayDerefNode) (interface{}, error) { + left, err := impl.evaluateNode(arrayDerefNode.Receiver) + if err != nil { + return nil, err + } + + return impl.getSafeValue(reflect.ValueOf(left)), nil +} + +func (impl *interperterImpl) getPropertyValue(left reflect.Value, property string) (value interface{}, err error) { + switch left.Kind() { + case reflect.Ptr: + return impl.getPropertyValue(left.Elem(), property) + + case reflect.Struct: + leftType := left.Type() + for i := 0; i < leftType.NumField(); i++ { + jsonName := leftType.Field(i).Tag.Get("json") + if jsonName == property { + property = leftType.Field(i).Name + break + } + } + + fieldValue := left.FieldByNameFunc(func(name string) bool { + return strings.EqualFold(name, property) + }) + + if fieldValue.Kind() == reflect.Invalid { + return "", nil + } + + i := fieldValue.Interface() + // The type stepStatus int is an integer, but should be treated as string + if m, ok := i.(encoding.TextMarshaler); ok { + text, err := m.MarshalText() + if err != nil { + return nil, err + } + return string(text), nil + } + return i, nil + + case reflect.Map: + iter := left.MapRange() + + for iter.Next() { + key := iter.Key() + + switch key.Kind() { + case reflect.String: + if strings.EqualFold(key.String(), property) { + return impl.getMapValue(iter.Value()) + } + + default: + return nil, fmt.Errorf("'%s' in map key not implemented", key.Kind()) + } + } + + return nil, nil + + case reflect.Slice: + var values []interface{} + + for i := 0; i < left.Len(); i++ { + value, err := impl.getPropertyValue(left.Index(i).Elem(), property) + if err != nil { + return nil, err + } + + values = append(values, value) + } + + return values, nil + } + + return nil, nil +} + +func (impl *interperterImpl) getMapValue(value reflect.Value) (interface{}, error) { + if value.Kind() == reflect.Ptr { + return impl.getMapValue(value.Elem()) + } + + return value.Interface(), nil +} + +func (impl *interperterImpl) evaluateNot(notNode *actionlint.NotOpNode) (interface{}, error) { + operand, err := impl.evaluateNode(notNode.Operand) + if err != nil { + return nil, err + } + + return !IsTruthy(operand), nil +} + +func (impl *interperterImpl) evaluateCompare(compareNode *actionlint.CompareOpNode) (interface{}, error) { + left, err := impl.evaluateNode(compareNode.Left) + if err != nil { + return nil, err + } + + right, err := impl.evaluateNode(compareNode.Right) + if err != nil { + return nil, err + } + + leftValue := reflect.ValueOf(left) + rightValue := reflect.ValueOf(right) + + return impl.compareValues(leftValue, rightValue, compareNode.Kind) +} + +func (impl *interperterImpl) compareValues(leftValue reflect.Value, rightValue reflect.Value, kind actionlint.CompareOpNodeKind) (interface{}, error) { + if leftValue.Kind() != rightValue.Kind() { + if !impl.isNumber(leftValue) { + leftValue = impl.coerceToNumber(leftValue) + } + if !impl.isNumber(rightValue) { + rightValue = impl.coerceToNumber(rightValue) + } + } + + switch leftValue.Kind() { + case reflect.Bool: + return impl.compareNumber(float64(impl.coerceToNumber(leftValue).Int()), float64(impl.coerceToNumber(rightValue).Int()), kind) + case reflect.String: + return impl.compareString(strings.ToLower(leftValue.String()), strings.ToLower(rightValue.String()), kind) + + case reflect.Int: + if rightValue.Kind() == reflect.Float64 { + return impl.compareNumber(float64(leftValue.Int()), rightValue.Float(), kind) + } + + return impl.compareNumber(float64(leftValue.Int()), float64(rightValue.Int()), kind) + + case reflect.Float64: + if rightValue.Kind() == reflect.Int { + return impl.compareNumber(leftValue.Float(), float64(rightValue.Int()), kind) + } + + return impl.compareNumber(leftValue.Float(), rightValue.Float(), kind) + + case reflect.Invalid: + if rightValue.Kind() == reflect.Invalid { + return true, nil + } + + // not possible situation - params are converted to the same type in code above + return nil, fmt.Errorf("Compare params of Invalid type: left: %+v, right: %+v", leftValue.Kind(), rightValue.Kind()) + + default: + return nil, fmt.Errorf("Compare not implemented for types: left: %+v, right: %+v", leftValue.Kind(), rightValue.Kind()) + } +} + +func (impl *interperterImpl) coerceToNumber(value reflect.Value) reflect.Value { + switch value.Kind() { + case reflect.Invalid: + return reflect.ValueOf(0) + + case reflect.Bool: + switch value.Bool() { + case true: + return reflect.ValueOf(1) + case false: + return reflect.ValueOf(0) + } + + case reflect.String: + if value.String() == "" { + return reflect.ValueOf(0) + } + + // try to parse the string as a number + evaluated, err := impl.Evaluate(value.String(), DefaultStatusCheckNone) + if err != nil { + return reflect.ValueOf(math.NaN()) + } + + if value := reflect.ValueOf(evaluated); impl.isNumber(value) { + return value + } + } + + return reflect.ValueOf(math.NaN()) +} + +func (impl *interperterImpl) coerceToString(value reflect.Value) reflect.Value { + switch value.Kind() { + case reflect.Invalid: + return reflect.ValueOf("") + + case reflect.Bool: + switch value.Bool() { + case true: + return reflect.ValueOf("true") + case false: + return reflect.ValueOf("false") + } + + case reflect.String: + return value + + case reflect.Int: + return reflect.ValueOf(fmt.Sprint(value)) + + case reflect.Float64: + if math.IsInf(value.Float(), 1) { + return reflect.ValueOf("Infinity") + } else if math.IsInf(value.Float(), -1) { + return reflect.ValueOf("-Infinity") + } + return reflect.ValueOf(fmt.Sprintf("%.15G", value.Float())) + + case reflect.Slice: + return reflect.ValueOf("Array") + + case reflect.Map: + return reflect.ValueOf("Object") + } + + return value +} + +func (impl *interperterImpl) compareString(left string, right string, kind actionlint.CompareOpNodeKind) (bool, error) { + switch kind { + case actionlint.CompareOpNodeKindLess: + return left < right, nil + case actionlint.CompareOpNodeKindLessEq: + return left <= right, nil + case actionlint.CompareOpNodeKindGreater: + return left > right, nil + case actionlint.CompareOpNodeKindGreaterEq: + return left >= right, nil + case actionlint.CompareOpNodeKindEq: + return left == right, nil + case actionlint.CompareOpNodeKindNotEq: + return left != right, nil + default: + return false, fmt.Errorf("TODO: not implemented to compare '%+v'", kind) + } +} + +func (impl *interperterImpl) compareNumber(left float64, right float64, kind actionlint.CompareOpNodeKind) (bool, error) { + switch kind { + case actionlint.CompareOpNodeKindLess: + return left < right, nil + case actionlint.CompareOpNodeKindLessEq: + return left <= right, nil + case actionlint.CompareOpNodeKindGreater: + return left > right, nil + case actionlint.CompareOpNodeKindGreaterEq: + return left >= right, nil + case actionlint.CompareOpNodeKindEq: + return left == right, nil + case actionlint.CompareOpNodeKindNotEq: + return left != right, nil + default: + return false, fmt.Errorf("TODO: not implemented to compare '%+v'", kind) + } +} + +func IsTruthy(input interface{}) bool { + value := reflect.ValueOf(input) + switch value.Kind() { + case reflect.Bool: + return value.Bool() + + case reflect.String: + return value.String() != "" + + case reflect.Int: + return value.Int() != 0 + + case reflect.Float64: + if math.IsNaN(value.Float()) { + return false + } + + return value.Float() != 0 + + case reflect.Map, reflect.Slice: + return true + + default: + return false + } +} + +func (impl *interperterImpl) isNumber(value reflect.Value) bool { + switch value.Kind() { + case reflect.Int, reflect.Float64: + return true + default: + return false + } +} + +func (impl *interperterImpl) getSafeValue(value reflect.Value) interface{} { + switch value.Kind() { + case reflect.Invalid: + return nil + + case reflect.Float64: + if value.Float() == 0 { + return 0 + } + } + + return value.Interface() +} + +func (impl *interperterImpl) evaluateLogicalCompare(compareNode *actionlint.LogicalOpNode) (interface{}, error) { + left, err := impl.evaluateNode(compareNode.Left) + if err != nil { + return nil, err + } + + leftValue := reflect.ValueOf(left) + + if IsTruthy(left) == (compareNode.Kind == actionlint.LogicalOpNodeKindOr) { + return impl.getSafeValue(leftValue), nil + } + + right, err := impl.evaluateNode(compareNode.Right) + if err != nil { + return nil, err + } + + rightValue := reflect.ValueOf(right) + + switch compareNode.Kind { + case actionlint.LogicalOpNodeKindAnd: + return impl.getSafeValue(rightValue), nil + case actionlint.LogicalOpNodeKindOr: + return impl.getSafeValue(rightValue), nil + } + + return nil, fmt.Errorf("Unable to compare incompatibles types '%s' and '%s'", leftValue.Kind(), rightValue.Kind()) +} + +//nolint:gocyclo +func (impl *interperterImpl) evaluateFuncCall(funcCallNode *actionlint.FuncCallNode) (interface{}, error) { + args := make([]reflect.Value, 0) + + for _, arg := range funcCallNode.Args { + value, err := impl.evaluateNode(arg) + if err != nil { + return nil, err + } + + args = append(args, reflect.ValueOf(value)) + } + + switch strings.ToLower(funcCallNode.Callee) { + case "contains": + return impl.contains(args[0], args[1]) + case "startswith": + return impl.startsWith(args[0], args[1]) + case "endswith": + return impl.endsWith(args[0], args[1]) + case "format": + return impl.format(args[0], args[1:]...) + case "join": + if len(args) == 1 { + return impl.join(args[0], reflect.ValueOf(",")) + } + return impl.join(args[0], args[1]) + case "tojson": + return impl.toJSON(args[0]) + case "fromjson": + return impl.fromJSON(args[0]) + case "hashfiles": + if impl.env.HashFiles != nil { + return impl.env.HashFiles(args) + } + return impl.hashFiles(args...) + case "always": + return impl.always() + case "success": + if impl.config.Context == "job" { + return impl.jobSuccess() + } + if impl.config.Context == "step" { + return impl.stepSuccess() + } + return nil, fmt.Errorf("Context '%s' must be one of 'job' or 'step'", impl.config.Context) + case "failure": + if impl.config.Context == "job" { + return impl.jobFailure() + } + if impl.config.Context == "step" { + return impl.stepFailure() + } + return nil, fmt.Errorf("Context '%s' must be one of 'job' or 'step'", impl.config.Context) + case "cancelled": + return impl.cancelled() + default: + return nil, fmt.Errorf("TODO: '%s' not implemented", funcCallNode.Callee) + } +} |