diff --git a/.mockery.yaml b/.mockery.yaml index 819823e..b1d8755 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -4,6 +4,9 @@ packages: github.com/lukasjarosch/skipper: interfaces: ValueReferenceSource: + github.com/lukasjarosch/skipper/expression: + interfaces: + PathValueProvider: github.com/lukasjarosch/skipper/reference: interfaces: ValueSource: diff --git a/expression.go b/expression.go new file mode 100644 index 0000000..52c6c1f --- /dev/null +++ b/expression.go @@ -0,0 +1,3 @@ +package skipper + +type ExpressionManager struct{} diff --git a/expression/exec.go b/expression/exec.go new file mode 100644 index 0000000..d9149a5 --- /dev/null +++ b/expression/exec.go @@ -0,0 +1,267 @@ +package expression + +import ( + "fmt" + "reflect" + "strings" + + "github.com/lukasjarosch/skipper/data" +) + +type state struct { + node Node // the current node + expression *ExpressionNode + valueProvider PathValueProvider + variableMap map[string]any + funcMap map[string]any +} + +var ( + zero reflect.Value + + ErrUndefinedVariable = fmt.Errorf("undefined variable") + ErrFunctionNotDefined = fmt.Errorf("function not defined") + ErrCallInvalidArgumentCount = fmt.Errorf("invalid argument count") + ErrNotAFunc = fmt.Errorf("not a function") + ErrBadFuncSignature = fmt.Errorf("bad function signature") +) + +// UsedVariables returns a list of all variable names which are used within the expression. +func UsedVariables(expr *ExpressionNode) (variableNames []string) { + variablesInPath := func(path *PathNode) (vars []string) { + for _, segNode := range path.Segments { + switch node := segNode.(type) { + case *IdentifierNode: + continue + case *VariableNode: + vars = append(vars, node.Name) + } + } + return + } + var variablesInCall func(*CallNode) []string + variablesInCall = func(call *CallNode) (vars []string) { + for _, argNode := range call.Arguments { + switch node := argNode.(type) { + case *VariableNode: + vars = append(vars, node.Name) + case *PathNode: + vars = append(vars, variablesInPath(node)...) + case *CallNode: + vars = append(vars, variablesInCall(node)...) + } + } + if call.AlternativeExpr != nil { + vars = append(vars, UsedVariables(call.AlternativeExpr)...) + } + return + } + + switch node := expr.Child.(type) { + case *PathNode: + variableNames = append(variableNames, variablesInPath(node)...) + case *CallNode: + variableNames = append(variableNames, variablesInCall(node)...) + case *VariableNode: + variableNames = append(variableNames, node.Name) + } + + return +} + +type PathValueProvider interface { + GetPath(data.Path) (interface{}, error) +} + +func Execute(expr *ExpressionNode, valueProvider PathValueProvider, variableMap map[string]any, funcMap map[string]any) (val reflect.Value, err error) { + state := &state{ + expression: expr, + valueProvider: valueProvider, + variableMap: variableMap, + funcMap: funcMap, + } + defer errRecover(&err) + val, err = state.walkExpression(state.expression) + return +} + +// at marks the node as current node +func (s *state) at(node Node) { + s.node = node +} + +// errRecover is the handler that turns panics into returns from the top +// level of Parse. +func errRecover(errp *error) { + e := recover() + if e != nil { + *errp = fmt.Errorf("%v", e) + } +} + +func (s *state) error(err error) { + s.errorf("error: %w", err) +} + +func (s *state) errorf(format string, args ...any) { + if s.node != nil { + panic(fmt.Errorf("%w\n%s", fmt.Errorf(format, args...), s.expression.ErrorContext(s.node))) + } + + panic(fmt.Errorf(format, args...)) +} + +func (s *state) walkExpression(node *ExpressionNode) (reflect.Value, error) { + s.at(node) + switch node := node.Child.(type) { + case *PathNode: + return s.evalPath(node) + case *VariableNode: + return s.evalVariable(node) + case *CallNode: + return s.evalCall(node) + } + + return reflect.ValueOf(nil), fmt.Errorf("unimplemented") +} + +func (s *state) evalPath(path *PathNode) (reflect.Value, error) { + s.at(path) + + segments := []string{} + + for _, seg := range path.Segments { + switch segment := seg.(type) { + case *IdentifierNode: + segments = append(segments, segment.Value) + case *VariableNode: + val, err := s.evalVariable(segment) + if err != nil { + panic(err) // TODO: implement + } + segments = append(segments, val.String()) + default: + panic("NOT IMPLEMENTED") + } + } + + val, err := s.valueProvider.GetPath(data.NewPathVar(segments...)) + if err != nil { + s.error(err) + } + + return reflect.ValueOf(val), nil +} + +func (s *state) evalVariable(variable *VariableNode) (reflect.Value, error) { + s.at(variable) + value, exists := s.variableMap[variable.Name] + if !exists { + return zero, fmt.Errorf("%w: %s", ErrUndefinedVariable, variable.Name) + } + return reflect.ValueOf(value), nil +} + +func (s *state) evalCall(call *CallNode) (reflect.Value, error) { + s.at(call) + + ident := call.Identifier.Value + fn, ok := findFunction(ident, s) + if !ok { + return zero, fmt.Errorf("%w: %s", ErrFunctionNotDefined, ident) + } + + // make sure we're actually dealing with a func + typ := fn.Type() + if typ.Kind() != reflect.Func { + return zero, fmt.Errorf("%w: %s", ErrNotAFunc, ident) + } + + // number of argument nodes must match the parameter count of the func + numInArgs := len(call.Arguments) + if numInArgs != typ.NumIn() { + return zero, fmt.Errorf("%w for %s: want %d got %d", ErrCallInvalidArgumentCount, ident, typ.NumIn(), numInArgs) + } + + // assert that the function return values are valid + if !goodFunc(typ) { + outStr := []string{} + for i := 0; i < typ.NumOut(); i++ { + outStr = append(outStr, typ.Out(i).String()) + } + return zero, fmt.Errorf("%w %s: does not meet the requirements: %s() (%s)", ErrBadFuncSignature, ident, ident, strings.Join(outStr, ", ")) + } + + // make argument list + argv := make([]reflect.Value, numInArgs) + for i := 0; i < numInArgs; i++ { + argv[i] = s.evalCallArg(typ.In(i), call.Arguments[i]) + } + + // in case of an existing AlternativeExpr, perform the call and execute the AlternativeExpr in case of an error + if call.AlternativeExpr != nil { + val, err := safeCall(ident, fn, argv) + if err != nil { + return Execute(call.AlternativeExpr, s.valueProvider, s.variableMap, s.funcMap) + } + return val, nil + } + + // TODO: handle variadic functions + + return safeCall(ident, fn, argv) +} + +func (s *state) evalCallArg(typ reflect.Type, node Node) reflect.Value { + s.at(node) + switch typ.Kind() { + case reflect.String: + switch n := node.(type) { + case *VariableNode: + val, err := s.evalVariable(n) + if err != nil { + s.error(err) + } + return val + case *CallNode: + val, err := s.evalCall(n) + if err != nil { + s.error(err) + } + return val + default: + return s.evalString(typ, node) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + // TODO: handle variable and call + return s.evalInteger(typ, node) + } + + // TODO: handle floats + + return zero +} + +func (s *state) evalString(typ reflect.Type, n Node) reflect.Value { + s.at(n) + + if n, ok := n.(*StringNode); ok { + value := reflect.New(typ).Elem() + value.SetString(n.Value) + return value + } + s.errorf("expected string, found %s", n) + panic("not reached") +} + +func (s *state) evalInteger(typ reflect.Type, n Node) reflect.Value { + s.at(n) + + if n, ok := n.(*NumberNode); ok && n.IsInt { + value := reflect.New(typ).Elem() + value.SetInt(n.Int64) + return value + } + s.errorf("expected integer; found %s", n) + panic("not reached") +} diff --git a/expression/exec_test.go b/expression/exec_test.go new file mode 100644 index 0000000..3338a43 --- /dev/null +++ b/expression/exec_test.go @@ -0,0 +1,216 @@ +package expression_test + +import ( + "fmt" + "os" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/lukasjarosch/skipper/data" + "github.com/lukasjarosch/skipper/expression" + mock "github.com/lukasjarosch/skipper/mocks/expression" +) + +var ( + envVar = "WHY_DID_YOU_SET_THIS" + envVarValue = "lol" +) + +func init() { + os.Setenv(envVar, envVarValue) +} + +func TestExecuteExpression(t *testing.T) { + tests := []struct { + name string + input string + pathValues map[string]interface{} + variableMap map[string]interface{} + funcMap map[string]interface{} + expected interface{} + errExpected error + }{ + { + name: "single path without variables", + input: `${foo:bar:baz}`, + pathValues: map[string]interface{}{ + "foo.bar.baz": "HELLO", + }, + expected: "HELLO", + }, + { + name: "single path with single variable", + input: `${foo:bar:$target_name}`, + variableMap: map[string]interface{}{ + "target_name": "develop", + }, + pathValues: map[string]interface{}{ + "foo.bar.develop": "HELLO", + }, + expected: "HELLO", + }, + { + name: "single path with multiple variables", + input: `${foo:$name:$target_name}`, + variableMap: map[string]interface{}{ + "name": "bar", + "target_name": "develop", + }, + pathValues: map[string]interface{}{ + "foo.bar.develop": "HELLO", + }, + expected: "HELLO", + }, + { + name: "standalone variable expression", + input: `${target_name}`, + variableMap: map[string]interface{}{ + "target_name": "develop", + }, + expected: "develop", + }, + { + name: "standalone inline variable expression", + input: `${$target_name}`, + variableMap: map[string]interface{}{ + "target_name": "develop", + }, + expected: "develop", + }, + { + name: "undefined variable", + input: `${foo_bar}`, + variableMap: map[string]interface{}{}, + pathValues: map[string]interface{}{}, + funcMap: map[string]interface{}{}, + errExpected: expression.ErrUndefinedVariable, + }, + { + name: "undefined function", + input: `${say_hello()}`, + variableMap: map[string]interface{}{}, + pathValues: map[string]interface{}{}, + funcMap: map[string]interface{}{}, + errExpected: expression.ErrFunctionNotDefined, + }, + { + name: "call user func with too many args", + input: `${say_hello("foo", "bar")}`, + variableMap: map[string]interface{}{}, + pathValues: map[string]interface{}{}, + funcMap: map[string]interface{}{ + "say_hello": func() string { return "Hello there" }, + }, + errExpected: expression.ErrCallInvalidArgumentCount, + }, + { + name: "call user func with wrong return types", + input: `${say_hello()}`, + variableMap: map[string]interface{}{}, + pathValues: map[string]interface{}{}, + funcMap: map[string]interface{}{ + "say_hello": func() (string, int) { return "", 0 }, + }, + errExpected: expression.ErrBadFuncSignature, + }, + { + name: "not a function in funcMap", + input: `${say_hello()}`, + variableMap: map[string]interface{}{}, + pathValues: map[string]interface{}{}, + funcMap: map[string]interface{}{ + "say_hello": "i am invalid", + }, + errExpected: expression.ErrNotAFunc, + }, + { + name: "call user defined function with no args", + input: `${say_hello()}`, + variableMap: map[string]interface{}{}, + pathValues: map[string]interface{}{}, + funcMap: map[string]interface{}{ + "say_hello": func() string { return "hello" }, + }, + expected: "hello", + }, + { + name: "call user defined function with string argument", + input: `${say_hello("john")}`, + variableMap: map[string]interface{}{}, + pathValues: map[string]interface{}{}, + funcMap: map[string]interface{}{ + "say_hello": func(name string) string { return fmt.Sprintf("hello, %s", name) }, + }, + expected: "hello, john", + }, + { + name: "call user defined function with invalid argument type", + input: `${say_hello("1337")}`, + variableMap: map[string]interface{}{}, + pathValues: map[string]interface{}{}, + funcMap: map[string]interface{}{ + "say_hello": func(count int) string { + return fmt.Sprintf("hello, %d", count) + }, + }, + errExpected: fmt.Errorf("expected integer; found String"), + }, + { + name: "get_env builtin with variable not set", + input: `${get_env("THIS_CANNOT_POSSIBLY_BE_SET_COME_ON")}`, + variableMap: map[string]interface{}{}, + pathValues: map[string]interface{}{}, + errExpected: fmt.Errorf("environment variable not set"), + }, + { + name: "get_env builtin with variable set", + input: `${get_env("WHY_DID_YOU_SET_THIS")}`, // set via init() + variableMap: map[string]interface{}{}, + pathValues: map[string]interface{}{}, + expected: envVarValue, + }, + { + name: "get_env builtin with variable not set but set_env alternative expression", + input: `${get_env("THIS_CANNOT_POSSIBLY_BE_SET_COME_ON") || set_env("THIS_CANNOT_POSSIBLY_BE_SET_COME_ON", "lol")}`, + expected: envVarValue, + }, + { + name: "set_env with variable arguments", + input: `${set_env($MY_VAR, $value)}`, + variableMap: map[string]interface{}{ + "MY_VAR": "THIS_CANNOT_POSSIBLY_BE_SET_COME_ON_NOW_REALLY", + "value": envVarValue, + }, + pathValues: map[string]interface{}{}, + expected: envVarValue, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + expressions, err := expression.Parse(tt.input) + assert.NoError(t, err) + assert.NotEmpty(t, expressions) + + pathValueProvider := mock.NewMockPathValueProvider(t) + + for _, expr := range expressions { + + for path, val := range tt.pathValues { + pathValueProvider.EXPECT().GetPath(data.NewPath(path)).Return(val, nil) + } + + val, err := expression.Execute(expr, pathValueProvider, tt.variableMap, tt.funcMap) + + if tt.errExpected != nil { + assert.ErrorContains(t, err, tt.errExpected.Error()) + return + } + + assert.NoError(t, err) + assert.Equal(t, tt.expected, val.String()) + } + }) + } +} diff --git a/expression/func.go b/expression/func.go new file mode 100644 index 0000000..3fd07c5 --- /dev/null +++ b/expression/func.go @@ -0,0 +1,110 @@ +package expression + +import ( + "fmt" + "os" + "reflect" + "sync" +) + +type FuncMap map[string]any + +var errorType = reflect.TypeFor[error]() + +var builtins = FuncMap{ + "get_env": get_env, + "set_env": set_env, +} + +var builtinFuncsOnce struct { + sync.Once + v map[string]reflect.Value +} + +func builtinFuncs() map[string]reflect.Value { + builtinFuncsOnce.Do(func() { + builtinFuncsOnce.v = createValueFuncs(builtins) + }) + return builtinFuncsOnce.v +} + +// createValueFuncs turns a FuncMap into a map[string]reflect.Value +func createValueFuncs(funcMap FuncMap) map[string]reflect.Value { + m := make(map[string]reflect.Value) + addValueFuncs(m, funcMap) + return m +} + +// addValueFuncs adds to values the functions in funcs, converting them to reflect.Values. +func addValueFuncs(out map[string]reflect.Value, in FuncMap) { + for name, fn := range in { + v := reflect.ValueOf(fn) + if v.Kind() != reflect.Func { + panic("value for " + name + " not a function") + } + if !goodFunc(v.Type()) { + panic(fmt.Errorf("can't install method/function %q with %d results", name, v.Type().NumOut())) + } + out[name] = v + } +} + +// goodFunc returns true when the given function has either one return value +// or two, whereas the second must be of type 'error'. +// All other function signatures are not good and false is returned. +func goodFunc(fn reflect.Type) bool { + switch { + case fn.NumOut() == 1: + return true + case fn.NumOut() == 2 && fn.Out(1) == errorType: + return true + } + + return false +} + +func findFunction(name string, s *state) (reflect.Value, bool) { + if fn := s.funcMap[name]; reflect.ValueOf(fn).IsValid() { + return reflect.ValueOf(fn), true + } + if fn := builtinFuncs()[name]; fn.IsValid() { + return fn, true + } + + return zero, false +} + +func safeCall(ident string, fun reflect.Value, args []reflect.Value) (val reflect.Value, err error) { + defer func() { + if r := recover(); r != nil { + if e, ok := r.(error); ok { + err = e + } else { + err = fmt.Errorf("recovered panic in %s: %v", ident, r) + } + } + }() + + ret := fun.Call(args) + if len(ret) == 2 && !ret[1].IsNil() { + return ret[0], fmt.Errorf("%s: %w", ident, ret[1].Interface().(error)) + } + return ret[0], nil +} + +// get_env will lookup the given name as environment variable and return its value. +// If the variable does not exist, an error is returned. +// If the variable exists, but is empty, the empty value is returned. +func get_env(name string) (string, error) { + val, exists := os.LookupEnv(name) + if !exists { + return "", fmt.Errorf("environment variable not set: %s", name) + } + return val, nil +} + +// set_env will attempt to set an environment variable with the given name and value. +// It will return the set value and an error (if any). +func set_env(name string, value string) (string, error) { + return value, os.Setenv(name, value) +} diff --git a/expression/lex.go b/expression/lex.go new file mode 100644 index 0000000..bbd8945 --- /dev/null +++ b/expression/lex.go @@ -0,0 +1,370 @@ +package expression + +import ( + "fmt" + "regexp" + "strings" + "unicode" + "unicode/utf8" +) + +type TokenType int + +const ( + tError TokenType = iota + tEOF + tIdent + tLeftDelim + tRightDelim + tLeftParen + tRightParen + tPathSep // : + tDoublePipe // || + tDollar + tComma + + tString + tNumber +) + +func TokenString(t TokenType) string { + switch t { + case tEOF: + return "EOF" + case tError: + return "ERROR" + case tString: + return "STRING" + case tIdent: + return "IDENTIFIER" + case tPathSep: + return "PATH_SEPARATOR" + case tLeftDelim: + return "LEFT_DELIMITER" + case tRightDelim: + return "RIGHT_DELIMITER" + case tLeftParen: + return "LEFT_PARENTHESES" + case tRightParen: + return "RIGHT_PARENTHESES" + case tDoublePipe: + return "DOUBLE_PIPE" + case tDollar: + return "DOLLAR" + case tComma: + return "COMMA" + case tNumber: + return "NUMBER" + default: + return "UNKNOWN" + } +} + +const eof = -1 + +type Token struct { + Pos int + Type TokenType + Value string +} + +type stateFn func(*lexer) stateFn + +type lexer struct { + input string + start int + pos int + width int + tokens chan Token + token Token + parenDepth int // nesting depth of '( )' expressions + exprDepth int // nesting depth of delimited expressions (e.g. ${foo:${bar}}) +} + +func lex(input string) *lexer { + l := &lexer{ + input: input, + tokens: make(chan Token, 3), + } + go l.run() + return l +} + +// run starts the lexer +func (l *lexer) run() { + for state := lexText; state != nil; { + state = state(l) + } + close(l.tokens) +} + +// nextToken returns the next Token from the input. +// Called by the parser, not the lexer! +func (l *lexer) nextToken() Token { + for { + select { + case token, ok := <-l.tokens: + if !ok { + return Token{Type: tEOF, Value: "EOF"} + } + return token + default: + } + } +} + +func (l *lexer) emit(t TokenType) { + value := l.current() + l.tokens <- Token{ + Pos: l.pos, + Value: value, + Type: t, + } + l.updatePos() +} + +func (l *lexer) next() rune { + if l.pos >= len(l.input) { + l.width = 0 + return eof + } + r, w := utf8.DecodeRuneInString(l.input[l.pos:]) + l.width = w + l.pos += l.width + return r +} + +// accept consumes the next rune if it's in the valid set +func (l *lexer) accept(valid string) bool { + if strings.IndexRune(valid, l.next()) >= 0 { + return true + } + l.backup() + return false +} + +// acceptRun consumes a run of runes from the valid set +func (l *lexer) acceptRun(valid string) { + for strings.IndexRune(valid, l.next()) >= 0 { + } + l.backup() +} + +// acceptRegexRun consumes all runes which match the given regex +func (l *lexer) acceptRegexRun(valid *regexp.Regexp) { + for valid.MatchString(string(l.next())) { + } + l.backup() +} + +func (l *lexer) current() string { + return l.input[l.start:l.pos] +} + +func (l *lexer) updatePos() { + l.start = l.pos +} + +func (l *lexer) ignore() { + l.start = l.pos +} + +func (l *lexer) backup() { + l.pos -= l.width +} + +func (l *lexer) peek() rune { + r := l.next() + l.backup() + return r +} + +func (l *lexer) atLeftDelim() bool { + return strings.HasPrefix(l.input[l.pos:], leftDelim) +} + +func (l *lexer) atRightDelim() bool { + return strings.HasPrefix(l.input[l.pos:], rightDelim) +} + +// isSpace reports whether r is a space character. +func isSpace(r rune) bool { + return r == ' ' || r == '\t' || r == '\r' || r == '\n' +} + +// isAlphaNumeric reports whether r is an alphabetic, digit, or underscore. +func isAlphaNumeric(r rune) bool { + return r == '_' || unicode.IsLetter(r) || unicode.IsDigit(r) +} + +// ----- state transition functions + +const ( + leftDelim = "${" + rightDelim = "}" +) + +// errorf emits an error token and returns nil +func (l *lexer) errorf(format string, args ...interface{}) stateFn { + l.tokens <- Token{ + Type: tError, + Pos: l.pos, + Value: fmt.Sprintf(format, args...), + } + return nil +} + +func lexText(l *lexer) stateFn { + for { + if l.atLeftDelim() { + if l.pos > l.start { + l.ignore() // ignore any preceding text + } + l.pos += len(leftDelim) // skip leftDelim + l.emit(tLeftDelim) + l.exprDepth++ + return lexExpression + } + + if l.next() == eof { + break + } + } + l.ignore() // drop any text + l.emit(tEOF) + return nil +} + +func lexExpression(l *lexer) stateFn { + if l.atRightDelim() { + l.pos += len(rightDelim) // skip rightDelim + l.emit(tRightDelim) + l.exprDepth-- + if l.exprDepth == 0 { + + // if there are still unclosed function calls, the expression cannot be ending here + if l.parenDepth > 0 { + return l.errorf("missing right parentheses") + } + + return lexText + } + if l.exprDepth < 0 { + return l.errorf("unexpected right delimiter %s", rightDelim) + } + return lexExpression + } + + switch r := l.next(); { + case r == '+' || r == '-' || ('0' <= r && r <= '9'): + l.backup() + return lexNumber + case isAlphaNumeric(r): + l.backup() + return lexIdentifier + case r == ':': + l.emit(tPathSep) + return lexExpression + case r == '$': + // nested expression? + if l.peek() == '{' { + l.next() + l.emit(tLeftDelim) + l.exprDepth++ + return lexExpression + } + // dollars indicate a variable + l.emit(tDollar) + return lexExpression + + // start param list + case r == '(': + l.parenDepth++ + l.emit(tLeftParen) + return lexExpression + + // end param list + case r == ')': + l.parenDepth-- + l.emit(tRightParen) + return lexExpression + + // quoted strings + case r == '\'': + l.ignore() + return lexQuotedString('\'') + case r == '"': + l.ignore() + return lexQuotedString('"') + + // commas within parameter lists + case r == ',': + l.emit(tComma) + return lexExpression + + // drop any spaces within the expression + case isSpace(r): + l.ignore() + return lexExpression + + // alternate expressions + case r == '|': + if l.peek() == '|' { + l.next() + l.emit(tDoublePipe) + return lexExpression + } + return l.errorf("invalid token %#U", r) + + // input cannot end before the rightDelim is found + case r == eof: + return l.errorf("unclosed expression, expected %s, got %s", rightDelim, TokenString(tEOF)) + + // fail on any other rune + default: + return l.errorf("unrecognized rune in expression: %#U", r) + } +} + +func lexIdentifier(l *lexer) stateFn { + l.acceptRegexRun(regexp.MustCompile(`\w+`)) + l.emit(tIdent) + return lexExpression +} + +func lexQuotedString(quote rune) stateFn { + return func(l *lexer) stateFn { + Loop: + for { + switch l.next() { + case eof, '\n': + return l.errorf("unterminated single quoted string") + case quote: + l.backup() + break Loop + } + } + + l.emit(tString) + + // skip quote and return + l.next() + l.ignore() + return lexExpression + } +} + +// lexNumber scans a number. +// This is very basic as there is no support for octal, hex, imaginary, etc. +// The only supported numbers are integers and floats including signs +func lexNumber(l *lexer) stateFn { + // optional: leading sign + l.accept("+-") + digits := "0123456789" + l.acceptRun(digits) + if l.accept(".") { + l.acceptRun(digits) + } + l.emit(tNumber) + return lexExpression +} diff --git a/expression/node.go b/expression/node.go new file mode 100644 index 0000000..b291e82 --- /dev/null +++ b/expression/node.go @@ -0,0 +1,263 @@ +package expression + +import ( + "fmt" + "strconv" + "strings" +) + +type Node interface { + Type() NodeType + Position() Pos + Text() string +} + +// NodeType identifies the type of a parse tree node. +type NodeType int + +// Pos represents a byte position in the original input text +type Pos int + +func (p Pos) Position() Pos { + return p +} + +// Type returns itself and provides an easy default implementation +// for embedding in a Node. Embedded in all non-trivial Nodes. +func (t NodeType) Type() NodeType { + return t +} + +func (t NodeType) Text() string { + return "" +} + +func (t NodeType) String() string { + switch t { + case NodeExpression: + return "Expression" + case NodeList: + return "List" + case NodePath: + return "Path" + case NodeIdentifier: + return "Identifier" + case NodeVariable: + return "Variable" + case NodeCall: + return "Call" + case NodeString: + return "String" + case NodeNumber: + return "Number" + default: + return "UNKNOWN NODE TYPE" + } +} + +const ( + NodeExpression NodeType = iota + NodeList + NodePath + NodeIdentifier + NodeVariable + NodeCall + NodeString + NodeNumber +) + +// ListNode holds a sequence of nodes. +type ListNode struct { + NodeType + Pos + Nodes []Node // The element nodes in lexical order. +} + +func (t *Tree) newList(pos Pos) *ListNode { + return &ListNode{NodeType: NodeList, Pos: pos} +} + +func (l *ListNode) append(n Node) { + l.Nodes = append(l.Nodes, n) +} + +type ExpressionNode struct { + NodeType + Pos + Child Node +} + +func (n ExpressionNode) Text() string { + return fmt.Sprintf("${%s}", n.Child.Text()) +} + +func (n ExpressionNode) ErrorContext(node Node) string { + underline := func(a, b string) string { + return strings.Repeat(" ", strings.Index(a, b)) + strings.Repeat("^", len(b)) + } + + context := fmt.Sprintln("Context:") + // context += fmt.Sprintln("|") + context += fmt.Sprintln("|", n.Text()) + context += fmt.Sprintln("|", underline(n.Text(), node.Text()), "-- HERE") + + return context +} + +func (t *Tree) newExpression(pos Pos, child Node) *ExpressionNode { + return &ExpressionNode{Pos: pos, NodeType: NodeExpression, Child: child} +} + +type VariableNode struct { + NodeType + Pos + Name string +} + +func (t *Tree) newVariable(pos Pos, name string) *VariableNode { + return &VariableNode{Pos: pos, Name: name, NodeType: NodeVariable} +} + +func (v *VariableNode) Text() string { + return "$" + v.Name +} + +type CallNode struct { + NodeType + Pos + Identifier *IdentifierNode + Arguments []Node + AlternativeExpr *ExpressionNode +} + +func (t *Tree) newCall(pos Pos, ident *IdentifierNode) *CallNode { + return &CallNode{Pos: pos, Identifier: ident, NodeType: NodeCall} +} + +func (n *CallNode) Text() string { + args := []string{} + for _, a := range n.Arguments { + args = append(args, a.Text()) + } + + if n.AlternativeExpr != nil { + return fmt.Sprintf("%s(%s) || %s", n.Identifier.Text(), strings.Join(args, ", "), n.AlternativeExpr.Text()) + } + + return fmt.Sprintf("%s(%s)", n.Identifier.Text(), strings.Join(args, ", ")) +} + +func (n *CallNode) appendArgument(arg Node) { + n.Arguments = append(n.Arguments, arg) +} + +type PathNode struct { + NodeType + Pos + Segments []Node // path segments from left to right, without separators +} + +func (t *Tree) newPath(pos Pos) *PathNode { + return &PathNode{Pos: pos, NodeType: NodePath} +} + +func (n *PathNode) appendSegment(node Node) { + n.Segments = append(n.Segments, node) +} + +func (n *PathNode) Text() string { + segments := []string{} + for _, seg := range n.Segments { + segments = append(segments, seg.Text()) + } + return strings.Join(segments, ":") +} + +type IdentifierNode struct { + NodeType + Pos + Value string +} + +func (t *Tree) newIdentifier(pos Pos, value string) *IdentifierNode { + return &IdentifierNode{Pos: pos, NodeType: NodeIdentifier, Value: value} +} + +func (i *IdentifierNode) Text() string { + return i.Value +} + +type StringNode struct { + NodeType + Pos + Value string +} + +func (t *Tree) newString(pos Pos, value string) *StringNode { + return &StringNode{Pos: pos, NodeType: NodeString, Value: value} +} + +func (s *StringNode) Text() string { + return fmt.Sprintf("\"%s\"", s.Value) +} + +type NumberNode struct { + NodeType + Pos + IsInt bool // Number has an integral value. + IsUint bool // Number has an unsigned integral value. + IsFloat bool // Number has a floating-point value. + Int64 int64 // The signed integer value. + Uint64 uint64 // The unsigned integer value. + Float64 float64 // The floating-point value. + Value string +} + +func (n *NumberNode) Text() string { + return n.Value +} + +func (t *Tree) newNumber(pos Pos, value string) (*NumberNode, error) { + n := &NumberNode{Pos: pos, Value: value, NodeType: NodeNumber} + + u, err := strconv.ParseUint(value, 0, 64) + if err == nil { + n.IsUint = true + n.Uint64 = u + } + i, err := strconv.ParseInt(value, 0, 64) + if err == nil { + n.IsInt = true + n.Int64 = i + } + + if n.IsInt { + n.IsFloat = true + n.Float64 = float64(n.Int64) + } else if n.IsUint { + n.IsFloat = true + n.Float64 = float64(n.Uint64) + } else { + f, err := strconv.ParseFloat(value, 64) + if err == nil { + n.IsFloat = true + n.Float64 = f + + // a float may also be a valid integer + if !n.IsInt && float64(int64(f)) == f { + n.IsInt = true + n.Int64 = int64(f) + } + if !n.IsUint && float64(uint64(f)) == f { + n.IsUint = true + n.Uint64 = uint64(f) + } + } + } + + if !n.IsInt && !n.IsUint && !n.IsFloat { + return nil, fmt.Errorf("illegal number syntax: %q", value) + } + + return n, nil +} diff --git a/expression/parse.go b/expression/parse.go new file mode 100644 index 0000000..301669a --- /dev/null +++ b/expression/parse.go @@ -0,0 +1,380 @@ +package expression + +import ( + "fmt" + "strings" +) + +type Tree struct { + root *ListNode + lex *lexer + input string + token [3]Token // lookahead buffer + inExpression bool + peekCount int +} + +func Parse(text string) ([]*ExpressionNode, error) { + t := &Tree{} + return t.Parse(text) +} + +func (t *Tree) Parse(input string) ([]*ExpressionNode, error) { + t.input = input + t.lex = lex(input) + t.parse() + + expr := []*ExpressionNode{} + for _, node := range t.root.Nodes { + expr = append(expr, node.(*ExpressionNode)) + } + + return expr, nil +} + +// expect consumes the next token and guarantees it has the required type. +func (t *Tree) expect(expected TokenType, context string) Token { + token := t.next() + if token.Type != expected { + t.unexpected(token, context) + } + return token +} + +// unexpected complains about the token and terminates processing. +func (t *Tree) unexpected(token Token, context string) { + if token.Type == tError { + t.errorfWithContext(token, "%s in %s: %s", TokenString(token.Type), context, token.Value) + } + t.errorfWithContext(token, "unexpected %s in %s", TokenString(token.Type), context) +} + +// errorf formats the error and terminates processing. +func (t *Tree) errorf(format string, args ...any) { + t.root = nil + format = fmt.Sprintf("parse: %s at %d: %s", t.token[0].Value, t.token[0].Pos, format) + panic(fmt.Errorf(format, args...)) +} + +// errorfWithContext calls 'errorf' but adds failure context for the user based on the token. +func (t *Tree) errorfWithContext(tok Token, format string, args ...interface{}) { + line := func(pos int) (string, int) { + // in case the input is multiline, extract just the line we're in + if strings.Contains(t.input, "\n") { + beforeNewLine := strings.LastIndex(t.input[:pos], "\n") + 1 + afterNewLine := strings.Index(t.input[pos:], "\n") + pos + return strings.TrimSpace(t.input[beforeNewLine:afterNewLine]), pos - beforeNewLine - 1 + } + + return t.input, pos + } + + context := "\nContext:" + context += "\n|" + + contextLine, newPos := line(tok.Pos) + context += fmt.Sprintf("\n| %s\n", contextLine) + context += fmt.Sprintf("| %s^--HERE\n", strings.Repeat(" ", newPos-1)) + + format += "\n%s" + args = append(args, context) + t.errorf(format, args...) +} + +// peek returns the next token without consuming it +func (t *Tree) peek() Token { + if t.peekCount > 0 { + return t.token[t.peekCount-1] + } + t.peekCount = 1 + t.token[0] = t.lex.nextToken() + return t.token[0] +} + +// backup backs the input stream up one token. +func (t *Tree) backup() { + t.peekCount++ +} + +// backup2 backs the input stream up two tokens. +// The zeroth token is already there. +func (t *Tree) backup2(t1 Token) { + t.token[1] = t1 + t.peekCount = 2 +} + +// next returns the next token. +func (t *Tree) next() Token { + if t.peekCount > 0 { + t.peekCount-- + } else { + t.token[0] = t.lex.nextToken() + } + return t.token[t.peekCount] +} + +// parse starts parsing the input +// +// grammar ::= '${' expression '}' +func (t *Tree) parse() { + t.root = t.newList(Pos(t.peek().Pos)) + for t.peek().Type != tEOF { + + // consume the next token if its a left delimiter + // otherwise backup + if tok := t.next(); tok.Type == tLeftDelim { + n := t.parseExpression() + if n != nil { + t.root.append(n) + } + continue + } + t.backup() + + // In case we're already inside an expression, go back in. + if t.inExpression { + t.parseExpression() + continue + } + } +} + +// parseExpression +// +// expression ::= standalone_variable | inline_variable | path | call +// +// NOTE: The left delimiter is already consumed at this point. +func (t *Tree) parseExpression() (expr *ExpressionNode) { + t.inExpression = true + + switch tok := t.peek(); tok.Type { + // inline variable (with dollar prefix) + case tDollar: + return t.newExpression(Pos(tok.Pos), t.parseInlineVariable()) + case tIdent: + ident := t.next() // swallow identifier to peek at the next token + + switch tok := t.peek(); tok.Type { + // identifier followed by '(' -> Call + case tLeftParen: + t.backup2(ident) // restore identifier + return t.newExpression(Pos(tok.Pos), t.parseCall()) + + // identifier followed by tPathSep -> Path + case tPathSep: + t.backup2(ident) // restore identifier + return t.newExpression(Pos(tok.Pos), t.parsePath()) + + // only an identifier, then a right-delimiter -> standalone_variable + case tRightDelim: + t.backup2(ident) + return t.newExpression(Pos(ident.Pos), t.parseStandaloneVariable()) + + default: + t.errorfWithContext(tok, "unexpected %s after identifier", TokenString(tok.Type)) + } + + // expression ends + case tRightDelim: + t.next() + t.inExpression = false + return nil // nothing to return, all expression nodes were already emitted above + + case tError: + t.errorfWithContext(tok, "lexer error") + default: + t.unexpected(tok, "parseExpression") + } + + return +} + +// parseInlineVariable +// +// inline_variable ::= '$' standalone_variable +func (t *Tree) parseInlineVariable() *VariableNode { + t.expect(tDollar, "parseVariable") + + return t.parseStandaloneVariable() +} + +// parseStandaloneVariable +// +// standalone_variable ::= identifier +func (t *Tree) parseStandaloneVariable() *VariableNode { + tok := t.next() + if tok.Type != tIdent { + t.unexpected(tok, "parseVariable") + } + + return t.newVariable(Pos(tok.Pos), tok.Value) +} + +// parseCall +// +// call ::= identifier "(" argument_list ")" alternative_expression +// alternative_expression ::= "||" expression +func (t *Tree) parseCall() Node { + ident := t.parseIdentifier() + call := t.newCall(ident.Pos, ident) + + t.expect(tLeftParen, "parseCall") + + // argument list surrounded by parentheses + for _, arg := range t.parseCallArgumentList() { + call.appendArgument(arg) + } + + // closing parentheses + t.expect(tRightParen, "parseCall") + + // alternative exepression + if t.peek().Type == tDoublePipe { + t.expect(tDoublePipe, "parseCall") + call.AlternativeExpr = t.parseExpression() + } + + return call +} + +// parseCallArgumentList +// +// argument_list ::= {(quoted_string | inline_variable | path | call)} {argument_list_tail} +// argument_list_tail ::= {"," (quoted_string | inline_variable | path | call)} +func (t *Tree) parseCallArgumentList() (args []Node) { + for t.peek().Type != tRightParen { + tok := t.peek() + + switch tok.Type { + + // variable argument + case tDollar: + args = append(args, t.parseInlineVariable()) + continue + + // path or call argument + case tIdent: + ident := t.next() + + switch tok := t.peek(); tok.Type { + // path is argument + case tPathSep: + t.backup2(ident) + args = append(args, t.parsePath()) + + // call is argument + case tLeftParen: + t.backup2(ident) + args = append(args, t.parseCall()) + + default: + t.errorfWithContext(tok, "expected path separator or left parentheses after identifier, got %s", TokenString(tok.Type)) + } + continue + + // quoted string argument + case tString: + args = append(args, t.parseString()) + continue + case tNumber: + args = append(args, t.parseNumber()) + continue + case tError: + t.errorfWithContext(tok, "lexer error") + } + + // if there are more args, there must be a comma + // anything other is a syntax error + if tok.Type != tRightParen { + switch tok.Type { + case tComma: + t.next() // consume comma, and continue parsing args + continue + default: + t.errorfWithContext(t.peek(), "unexpected %s in argument list", TokenString(t.peek().Type)) + } + } + } + + return +} + +// parseIdentifier expects the next token to be tIdent. +func (t *Tree) parseIdentifier() *IdentifierNode { + tok := t.next() + if tok.Type != tIdent { + t.unexpected(tok, "parseIdentifier") + } + + return t.newIdentifier(Pos(tok.Pos), tok.Value) +} + +// parseString expects the next token to be tString +func (t *Tree) parseString() *StringNode { + tok := t.next() + if tok.Type != tString { + t.unexpected(tok, "parseString") + } + return t.newString(Pos(tok.Pos), tok.Value) +} + +func (t *Tree) parseNumber() *NumberNode { + tok := t.next() + if tok.Type != tNumber { + t.unexpected(tok, "parseNumber") + } + + node, err := t.newNumber(Pos(tok.Pos), tok.Value) + if err != nil { + t.errorfWithContext(tok, err.Error()) + return nil + } + + return node +} + +// parsePath +// +// path ::= identifier, path_tail +// path_tail ::= {":" (identifier | inline_variable)}+ +// +// Thus, a path MUST start with an identifier and must have at lest +// one path_tail segment. +func (t *Tree) parsePath() Node { + // a path can have only 256 path segments + const maxLength = 256 + + path := t.newPath(Pos(t.peek().Pos)) + + // Every second (uneven) token must be a path identifier + // There are maxLength-1 identifiers in a path of maxLength + // If this loop terminates, the path is too long. + for i := 0; i <= (maxLength + (maxLength - 1)); i++ { + // The first segment of a path must be an identifier. + if i == 0 { + path.appendSegment(t.parseIdentifier()) + continue + } + + // Intermediate segments may be either identifiers or variables + switch tok := t.peek(); tok.Type { + case tIdent: + path.appendSegment(t.parseIdentifier()) + continue + case tDollar: + path.appendSegment(t.parseInlineVariable()) + continue + case tRightDelim, tRightParen, tComma: + return path + } + + // Every second token must be a separator + if i%2 == 1 { + t.expect(tPathSep, "parsePath") + continue + } + } + + t.errorf("path is too long, max length is %d segments", maxLength) + return nil +} diff --git a/expression/parse_test.go b/expression/parse_test.go new file mode 100644 index 0000000..b9a37b4 --- /dev/null +++ b/expression/parse_test.go @@ -0,0 +1,49 @@ +package expression_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/lukasjarosch/skipper/expression" +) + +func TestParse(t *testing.T) { + tests := []struct { + name string + input string + expressions []*expression.ExpressionNode + }{ + { + name: "single path expr with only identifiers", + input: `${foo:bar:baz}`, + expressions: []*expression.ExpressionNode{ + { + Child: &expression.PathNode{ + Segments: []expression.Node{ + &expression.IdentifierNode{ + Value: "foo", + }, + &expression.IdentifierNode{ + Value: "bar", + }, + &expression.IdentifierNode{ + Value: "baz", + }, + }, + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + expressions, err := expression.Parse(tt.input) + assert.NoError(t, err) + assert.ElementsMatch(t, tt.expressions, expressions) + + // TODO: write own ElementsMatch which only compares values and not positions + }) + } +} diff --git a/mocks/expression/mock_PathValueProvider.go b/mocks/expression/mock_PathValueProvider.go new file mode 100644 index 0000000..1b52459 --- /dev/null +++ b/mocks/expression/mock_PathValueProvider.go @@ -0,0 +1,90 @@ +// Code generated by mockery v2.34.2. DO NOT EDIT. + +package expression + +import ( + data "github.com/lukasjarosch/skipper/data" + + mock "github.com/stretchr/testify/mock" +) + +// MockPathValueProvider is an autogenerated mock type for the PathValueProvider type +type MockPathValueProvider struct { + mock.Mock +} + +type MockPathValueProvider_Expecter struct { + mock *mock.Mock +} + +func (_m *MockPathValueProvider) EXPECT() *MockPathValueProvider_Expecter { + return &MockPathValueProvider_Expecter{mock: &_m.Mock} +} + +// GetPath provides a mock function with given fields: _a0 +func (_m *MockPathValueProvider) GetPath(_a0 data.Path) (interface{}, error) { + ret := _m.Called(_a0) + + var r0 interface{} + var r1 error + if rf, ok := ret.Get(0).(func(data.Path) (interface{}, error)); ok { + return rf(_a0) + } + if rf, ok := ret.Get(0).(func(data.Path) interface{}); ok { + r0 = rf(_a0) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(interface{}) + } + } + + if rf, ok := ret.Get(1).(func(data.Path) error); ok { + r1 = rf(_a0) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockPathValueProvider_GetPath_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetPath' +type MockPathValueProvider_GetPath_Call struct { + *mock.Call +} + +// GetPath is a helper method to define mock.On call +// - _a0 data.Path +func (_e *MockPathValueProvider_Expecter) GetPath(_a0 interface{}) *MockPathValueProvider_GetPath_Call { + return &MockPathValueProvider_GetPath_Call{Call: _e.mock.On("GetPath", _a0)} +} + +func (_c *MockPathValueProvider_GetPath_Call) Run(run func(_a0 data.Path)) *MockPathValueProvider_GetPath_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(data.Path)) + }) + return _c +} + +func (_c *MockPathValueProvider_GetPath_Call) Return(_a0 interface{}, _a1 error) *MockPathValueProvider_GetPath_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockPathValueProvider_GetPath_Call) RunAndReturn(run func(data.Path) (interface{}, error)) *MockPathValueProvider_GetPath_Call { + _c.Call.Return(run) + return _c +} + +// NewMockPathValueProvider creates a new instance of MockPathValueProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockPathValueProvider(t interface { + mock.TestingT + Cleanup(func()) +}) *MockPathValueProvider { + mock := &MockPathValueProvider{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/mocks/skipper/mock_ValueReferenceSource.go b/mocks/skipper/mock_ValueReferenceSource.go index b3de1ac..66d2728 100644 --- a/mocks/skipper/mock_ValueReferenceSource.go +++ b/mocks/skipper/mock_ValueReferenceSource.go @@ -162,39 +162,6 @@ func (_c *MockValueReferenceSource_RegisterPostSetHook_Call) RunAndReturn(run fu return _c } -// RegisterPreSetHook provides a mock function with given fields: _a0 -func (_m *MockValueReferenceSource) RegisterPreSetHook(_a0 skipper.SetHookFunc) { - _m.Called(_a0) -} - -// MockValueReferenceSource_RegisterPreSetHook_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RegisterPreSetHook' -type MockValueReferenceSource_RegisterPreSetHook_Call struct { - *mock.Call -} - -// RegisterPreSetHook is a helper method to define mock.On call -// - _a0 skipper.SetHookFunc -func (_e *MockValueReferenceSource_Expecter) RegisterPreSetHook(_a0 interface{}) *MockValueReferenceSource_RegisterPreSetHook_Call { - return &MockValueReferenceSource_RegisterPreSetHook_Call{Call: _e.mock.On("RegisterPreSetHook", _a0)} -} - -func (_c *MockValueReferenceSource_RegisterPreSetHook_Call) Run(run func(_a0 skipper.SetHookFunc)) *MockValueReferenceSource_RegisterPreSetHook_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(skipper.SetHookFunc)) - }) - return _c -} - -func (_c *MockValueReferenceSource_RegisterPreSetHook_Call) Return() *MockValueReferenceSource_RegisterPreSetHook_Call { - _c.Call.Return() - return _c -} - -func (_c *MockValueReferenceSource_RegisterPreSetHook_Call) RunAndReturn(run func(skipper.SetHookFunc)) *MockValueReferenceSource_RegisterPreSetHook_Call { - _c.Call.Return(run) - return _c -} - // SetPath provides a mock function with given fields: _a0, _a1 func (_m *MockValueReferenceSource) SetPath(_a0 data.Path, _a1 interface{}) error { ret := _m.Called(_a0, _a1)