diff --git a/quesma/model/query.go b/quesma/model/query.go index c85df72ba..a9be3aa89 100644 --- a/quesma/model/query.go +++ b/quesma/model/query.go @@ -3,6 +3,7 @@ package model import ( + "quesma/painful" "quesma/schema" "time" ) @@ -85,9 +86,10 @@ type ( // RuntimeMapping is a mapping of a field to a runtime expression type RuntimeMapping struct { - Field string - Type string - Expr Expr + Field string + Type string + DatabaseExpression Expr + PostProcessExpression painful.Expr } const MainExecutionPlan = "main" diff --git a/quesma/painful/.gitattributes b/quesma/painful/.gitattributes new file mode 100644 index 000000000..3365ff753 --- /dev/null +++ b/quesma/painful/.gitattributes @@ -0,0 +1,5 @@ + +# Mark *.go files as generated +# https://github.com/github-linguist/linguist/blob/master/docs/overrides.md + +generated_parser.go linguist-generated diff --git a/quesma/painful/generated_parser.go b/quesma/painful/generated_parser.go new file mode 100644 index 000000000..d2fed5377 --- /dev/null +++ b/quesma/painful/generated_parser.go @@ -0,0 +1,2010 @@ +// Copyright Quesma, licensed under the Elastic License 2.0. +// SPDX-License-Identifier: Elastic-2.0 +// Code generated by pigeon; DO NOT EDIT. +package painful + +import ( + "bytes" + "errors" + "fmt" + "io" + "math" + "os" + "sort" + "strconv" + "strings" + "sync" + "unicode" + "unicode/utf8" +) + +var g = &grammar{ + rules: []*rule{ + { + name: "Expr", + pos: position{line: 8, col: 1, offset: 123}, + expr: &choiceExpr{ + pos: position{line: 8, col: 9, offset: 131}, + alternatives: []any{ + &labeledExpr{ + pos: position{line: 8, col: 9, offset: 131}, + label: "expr", + expr: &ruleRefExpr{ + pos: position{line: 8, col: 14, offset: 136}, + name: "OpExpr", + }, + }, + &labeledExpr{ + pos: position{line: 8, col: 23, offset: 145}, + label: "expr", + expr: &ruleRefExpr{ + pos: position{line: 8, col: 28, offset: 150}, + name: "MethodCall", + }, + }, + &labeledExpr{ + pos: position{line: 8, col: 41, offset: 163}, + label: "expr", + expr: &ruleRefExpr{ + pos: position{line: 8, col: 46, offset: 168}, + name: "Accessor", + }, + }, + &labeledExpr{ + pos: position{line: 8, col: 57, offset: 179}, + label: "expr", + expr: &ruleRefExpr{ + pos: position{line: 8, col: 62, offset: 184}, + name: "Doc", + }, + }, + &labeledExpr{ + pos: position{line: 8, col: 68, offset: 190}, + label: "expr", + expr: &ruleRefExpr{ + pos: position{line: 8, col: 73, offset: 195}, + name: "Emit", + }, + }, + &actionExpr{ + pos: position{line: 8, col: 81, offset: 203}, + run: (*parser).callonExpr12, + expr: &labeledExpr{ + pos: position{line: 8, col: 81, offset: 203}, + label: "expr", + expr: &ruleRefExpr{ + pos: position{line: 8, col: 86, offset: 208}, + name: "String", + }, + }, + }, + }, + }, + leader: true, + leftRecursive: true, + }, + { + name: "Emit", + pos: position{line: 12, col: 1, offset: 242}, + expr: &actionExpr{ + pos: position{line: 12, col: 8, offset: 249}, + run: (*parser).callonEmit1, + expr: &seqExpr{ + pos: position{line: 12, col: 8, offset: 249}, + exprs: []any{ + &litMatcher{ + pos: position{line: 12, col: 8, offset: 249}, + val: "emit", + ignoreCase: false, + want: "\"emit\"", + }, + &litMatcher{ + pos: position{line: 12, col: 15, offset: 256}, + val: "(", + ignoreCase: false, + want: "\"(\"", + }, + &ruleRefExpr{ + pos: position{line: 12, col: 19, offset: 260}, + name: "_", + }, + &labeledExpr{ + pos: position{line: 12, col: 21, offset: 262}, + label: "expr", + expr: &ruleRefExpr{ + pos: position{line: 12, col: 26, offset: 267}, + name: "Expr", + }, + }, 
+ &ruleRefExpr{ + pos: position{line: 12, col: 31, offset: 272}, + name: "_", + }, + &litMatcher{ + pos: position{line: 12, col: 33, offset: 274}, + val: ")", + ignoreCase: false, + want: "\")\"", + }, + }, + }, + }, + leader: false, + leftRecursive: false, + }, + { + name: "Doc", + pos: position{line: 22, col: 1, offset: 413}, + expr: &actionExpr{ + pos: position{line: 22, col: 7, offset: 419}, + run: (*parser).callonDoc1, + expr: &seqExpr{ + pos: position{line: 22, col: 7, offset: 419}, + exprs: []any{ + &litMatcher{ + pos: position{line: 22, col: 7, offset: 419}, + val: "doc", + ignoreCase: false, + want: "\"doc\"", + }, + &litMatcher{ + pos: position{line: 22, col: 13, offset: 425}, + val: "[", + ignoreCase: false, + want: "\"[\"", + }, + &labeledExpr{ + pos: position{line: 22, col: 17, offset: 429}, + label: "key", + expr: &ruleRefExpr{ + pos: position{line: 22, col: 21, offset: 433}, + name: "Expr", + }, + }, + &litMatcher{ + pos: position{line: 22, col: 27, offset: 439}, + val: "]", + ignoreCase: false, + want: "\"]\"", + }, + }, + }, + }, + leader: false, + leftRecursive: false, + }, + { + name: "Accessor", + pos: position{line: 32, col: 1, offset: 581}, + expr: &actionExpr{ + pos: position{line: 32, col: 12, offset: 592}, + run: (*parser).callonAccessor1, + expr: &seqExpr{ + pos: position{line: 32, col: 12, offset: 592}, + exprs: []any{ + &labeledExpr{ + pos: position{line: 32, col: 12, offset: 592}, + label: "expr", + expr: &ruleRefExpr{ + pos: position{line: 32, col: 17, offset: 597}, + name: "Expr", + }, + }, + &litMatcher{ + pos: position{line: 32, col: 22, offset: 602}, + val: ".", + ignoreCase: false, + want: "\".\"", + }, + &labeledExpr{ + pos: position{line: 32, col: 26, offset: 606}, + label: "field", + expr: &ruleRefExpr{ + pos: position{line: 32, col: 32, offset: 612}, + name: "Identifier", + }, + }, + }, + }, + }, + leader: false, + leftRecursive: true, + }, + { + name: "MethodCall", + pos: position{line: 47, col: 1, offset: 898}, + expr: &actionExpr{ + pos: position{line: 47, col: 14, offset: 911}, + run: (*parser).callonMethodCall1, + expr: &seqExpr{ + pos: position{line: 47, col: 14, offset: 911}, + exprs: []any{ + &labeledExpr{ + pos: position{line: 47, col: 14, offset: 911}, + label: "expr", + expr: &ruleRefExpr{ + pos: position{line: 47, col: 19, offset: 916}, + name: "Expr", + }, + }, + &litMatcher{ + pos: position{line: 47, col: 24, offset: 921}, + val: ".", + ignoreCase: false, + want: "\".\"", + }, + &labeledExpr{ + pos: position{line: 47, col: 28, offset: 925}, + label: "method", + expr: &ruleRefExpr{ + pos: position{line: 47, col: 35, offset: 932}, + name: "Identifier", + }, + }, + &litMatcher{ + pos: position{line: 47, col: 46, offset: 943}, + val: "(", + ignoreCase: false, + want: "\"(\"", + }, + &labeledExpr{ + pos: position{line: 47, col: 50, offset: 947}, + label: "args", + expr: &zeroOrMoreExpr{ + pos: position{line: 47, col: 55, offset: 952}, + expr: &ruleRefExpr{ + pos: position{line: 47, col: 55, offset: 952}, + name: "Expr", + }, + }, + }, + &zeroOrOneExpr{ + pos: position{line: 47, col: 61, offset: 958}, + expr: &litMatcher{ + pos: position{line: 47, col: 61, offset: 958}, + val: ",", + ignoreCase: false, + want: "\",\"", + }, + }, + &litMatcher{ + pos: position{line: 47, col: 67, offset: 964}, + val: ")", + ignoreCase: false, + want: "\")\"", + }, + }, + }, + }, + leader: false, + leftRecursive: true, + }, + { + name: "OpExpr", + pos: position{line: 91, col: 1, offset: 1894}, + expr: &actionExpr{ + pos: position{line: 91, col: 10, offset: 1903}, 
+ run: (*parser).callonOpExpr1, + expr: &seqExpr{ + pos: position{line: 91, col: 10, offset: 1903}, + exprs: []any{ + &labeledExpr{ + pos: position{line: 91, col: 10, offset: 1903}, + label: "left", + expr: &ruleRefExpr{ + pos: position{line: 91, col: 15, offset: 1908}, + name: "Expr", + }, + }, + &ruleRefExpr{ + pos: position{line: 91, col: 20, offset: 1913}, + name: "_", + }, + &labeledExpr{ + pos: position{line: 91, col: 23, offset: 1916}, + label: "op", + expr: &ruleRefExpr{ + pos: position{line: 91, col: 26, offset: 1919}, + name: "Op", + }, + }, + &ruleRefExpr{ + pos: position{line: 91, col: 29, offset: 1922}, + name: "_", + }, + &labeledExpr{ + pos: position{line: 91, col: 32, offset: 1925}, + label: "right", + expr: &ruleRefExpr{ + pos: position{line: 91, col: 38, offset: 1931}, + name: "Expr", + }, + }, + }, + }, + }, + leader: false, + leftRecursive: true, + }, + { + name: "Op", + pos: position{line: 110, col: 1, offset: 2300}, + expr: &actionExpr{ + pos: position{line: 110, col: 6, offset: 2305}, + run: (*parser).callonOp1, + expr: &labeledExpr{ + pos: position{line: 110, col: 6, offset: 2305}, + label: "op", + expr: &litMatcher{ + pos: position{line: 110, col: 9, offset: 2308}, + val: "+", + ignoreCase: false, + want: "\"+\"", + }, + }, + }, + leader: false, + leftRecursive: false, + }, + { + name: "String", + pos: position{line: 114, col: 1, offset: 2349}, + expr: &actionExpr{ + pos: position{line: 114, col: 10, offset: 2358}, + run: (*parser).callonString1, + expr: &seqExpr{ + pos: position{line: 114, col: 10, offset: 2358}, + exprs: []any{ + &litMatcher{ + pos: position{line: 114, col: 10, offset: 2358}, + val: "'", + ignoreCase: false, + want: "\"'\"", + }, + &labeledExpr{ + pos: position{line: 114, col: 15, offset: 2363}, + label: "s", + expr: &zeroOrMoreExpr{ + pos: position{line: 114, col: 17, offset: 2365}, + expr: &charClassMatcher{ + pos: position{line: 114, col: 17, offset: 2365}, + val: "[^']", + chars: []rune{'\''}, + ignoreCase: false, + inverted: true, + }, + }, + }, + &litMatcher{ + pos: position{line: 114, col: 23, offset: 2371}, + val: "'", + ignoreCase: false, + want: "\"'\"", + }, + }, + }, + }, + leader: false, + leftRecursive: false, + }, + { + name: "Identifier", + pos: position{line: 121, col: 1, offset: 2494}, + expr: &actionExpr{ + pos: position{line: 121, col: 14, offset: 2507}, + run: (*parser).callonIdentifier1, + expr: &labeledExpr{ + pos: position{line: 121, col: 14, offset: 2507}, + label: "id", + expr: &oneOrMoreExpr{ + pos: position{line: 121, col: 17, offset: 2510}, + expr: &charClassMatcher{ + pos: position{line: 121, col: 17, offset: 2510}, + val: "[a-zA-Z0-9_]", + chars: []rune{'_'}, + ranges: []rune{'a', 'z', 'A', 'Z', '0', '9'}, + ignoreCase: false, + inverted: false, + }, + }, + }, + }, + leader: false, + leftRecursive: false, + }, + { + name: "_", + displayName: "\"whitespace\"", + pos: position{line: 125, col: 1, offset: 2559}, + expr: &zeroOrMoreExpr{ + pos: position{line: 125, col: 19, offset: 2577}, + expr: &charClassMatcher{ + pos: position{line: 125, col: 19, offset: 2577}, + val: "[ \\n\\t\\r]", + chars: []rune{' ', '\n', '\t', '\r'}, + ignoreCase: false, + inverted: false, + }, + }, + leader: false, + leftRecursive: false, + }, + { + name: "EOF", + pos: position{line: 127, col: 1, offset: 2589}, + expr: ¬Expr{ + pos: position{line: 128, col: 5, offset: 2598}, + expr: &anyMatcher{ + line: 128, col: 6, offset: 2599, + }, + }, + leader: false, + leftRecursive: false, + }, + }, +} + +func (c *current) onExpr12(expr any) (any, error) 
{ + return expr, nil +} + +func (p *parser) callonExpr12() (any, error) { + stack := p.vstack[len(p.vstack)-1] + _ = stack + return p.cur.onExpr12(stack["expr"]) +} + +func (c *current) onEmit1(expr any) (any, error) { + + exprVal, err := ExpectExpr(expr) + if err != nil { + return nil, err + } + + return &EmitExpr{Expr: exprVal}, nil +} + +func (p *parser) callonEmit1() (any, error) { + stack := p.vstack[len(p.vstack)-1] + _ = stack + return p.cur.onEmit1(stack["expr"]) +} + +func (c *current) onDoc1(key any) (any, error) { + + exprVal, err := ExpectExpr(key) + if err != nil { + return nil, err + } + + return &DocExpr{FieldName: exprVal}, nil +} + +func (p *parser) callonDoc1() (any, error) { + stack := p.vstack[len(p.vstack)-1] + _ = stack + return p.cur.onDoc1(stack["key"]) +} + +func (c *current) onAccessor1(expr, field any) (any, error) { + + exprVal, err := ExpectExpr(expr) + if err != nil { + return nil, err + } + + strVal, err := ExpectString(field) + if err != nil { + return nil, err + } + + return &AccessorExpr{Position: c.pos.String(), Expr: exprVal, PropertyName: strVal}, nil +} + +func (p *parser) callonAccessor1() (any, error) { + stack := p.vstack[len(p.vstack)-1] + _ = stack + return p.cur.onAccessor1(stack["expr"], stack["field"]) +} + +func (c *current) onMethodCall1(expr, method, args any) (any, error) { + + exprVal, err := ExpectExpr(expr) + if err != nil { + return nil, err + } + + strVal, err := ExpectString(method) + if err != nil { + return nil, err + } + + var argsVal []Expr + + switch argsVals := args.(type) { + + case nil: + argsVal = []Expr{} + case []any: + + for _, arg := range argsVals { + argVal, err := ExpectExpr(arg) + if err != nil { + return nil, err + } + argsVal = append(argsVal, argVal) + } + + default: + return nil, fmt.Errorf("internal parser error. 
'%T' is not valid method argument", args) + } + + for _, arg := range argsVal { + argVal, err := ExpectExpr(arg) + if err != nil { + return nil, err + } + argsVal = append(argsVal, argVal) + } + + return &MethodCallExpr{Position: c.pos.String(), Expr: exprVal, MethodName: strVal, Args: argsVal}, nil +} + +func (p *parser) callonMethodCall1() (any, error) { + stack := p.vstack[len(p.vstack)-1] + _ = stack + return p.cur.onMethodCall1(stack["expr"], stack["method"], stack["args"]) +} + +func (c *current) onOpExpr1(left, op, right any) (any, error) { + leftVal, err := ExpectExpr(left) + if err != nil { + return nil, err + } + + rightVal, err := ExpectExpr(right) + if err != nil { + return nil, err + } + + opVal, err := ExpectString(op) + if err != nil { + return nil, err + } + + return &InfixOpExpr{Position: c.pos.String(), Left: leftVal, Op: opVal, Right: rightVal}, nil +} + +func (p *parser) callonOpExpr1() (any, error) { + stack := p.vstack[len(p.vstack)-1] + _ = stack + return p.cur.onOpExpr1(stack["left"], stack["op"], stack["right"]) +} + +func (c *current) onOp1(op any) (any, error) { + return string(c.text), nil +} + +func (p *parser) callonOp1() (any, error) { + stack := p.vstack[len(p.vstack)-1] + _ = stack + return p.cur.onOp1(stack["op"]) +} + +func (c *current) onString1(s any) (any, error) { + + strVal := string(c.text) + strVal = strings.Trim(strVal, "'") + return &LiteralExpr{Value: strVal}, nil +} + +func (p *parser) callonString1() (any, error) { + stack := p.vstack[len(p.vstack)-1] + _ = stack + return p.cur.onString1(stack["s"]) +} + +func (c *current) onIdentifier1(id any) (any, error) { + return string(c.text), nil +} + +func (p *parser) callonIdentifier1() (any, error) { + stack := p.vstack[len(p.vstack)-1] + _ = stack + return p.cur.onIdentifier1(stack["id"]) +} + +var ( + // errNoRule is returned when the grammar to parse has no rule. + errNoRule = errors.New("grammar has no rule") + + // errInvalidEntrypoint is returned when the specified entrypoint rule + // does not exit. + errInvalidEntrypoint = errors.New("invalid entrypoint") + + // errInvalidEncoding is returned when the source is not properly + // utf8-encoded. + errInvalidEncoding = errors.New("invalid encoding") + + // errMaxExprCnt is used to signal that the maximum number of + // expressions have been parsed. + errMaxExprCnt = errors.New("max number of expressions parsed") +) + +// Option is a function that can set an option on the parser. It returns +// the previous setting as an Option. +type Option func(*parser) Option + +// MaxExpressions creates an Option to stop parsing after the provided +// number of expressions have been parsed, if the value is 0 then the parser will +// parse for as many steps as needed (possibly an infinite number). +// +// The default for maxExprCnt is 0. +func MaxExpressions(maxExprCnt uint64) Option { + return func(p *parser) Option { + oldMaxExprCnt := p.maxExprCnt + p.maxExprCnt = maxExprCnt + return MaxExpressions(oldMaxExprCnt) + } +} + +// Entrypoint creates an Option to set the rule name to use as entrypoint. +// The rule name must have been specified in the -alternate-entrypoints +// if generating the parser with the -optimize-grammar flag, otherwise +// it may have been optimized out. Passing an empty string sets the +// entrypoint to the first rule in the grammar. +// +// The default is to start parsing at the first rule in the grammar. 
+func Entrypoint(ruleName string) Option { + return func(p *parser) Option { + oldEntrypoint := p.entrypoint + p.entrypoint = ruleName + if ruleName == "" { + p.entrypoint = g.rules[0].name + } + return Entrypoint(oldEntrypoint) + } +} + +// Statistics adds a user provided Stats struct to the parser to allow +// the user to process the results after the parsing has finished. +// Also the key for the "no match" counter is set. +// +// Example usage: +// +// input := "input" +// stats := Stats{} +// _, err := Parse("input-file", []byte(input), Statistics(&stats, "no match")) +// if err != nil { +// log.Panicln(err) +// } +// b, err := json.MarshalIndent(stats.ChoiceAltCnt, "", " ") +// if err != nil { +// log.Panicln(err) +// } +// fmt.Println(string(b)) +func Statistics(stats *Stats, choiceNoMatch string) Option { + return func(p *parser) Option { + oldStats := p.Stats + p.Stats = stats + oldChoiceNoMatch := p.choiceNoMatch + p.choiceNoMatch = choiceNoMatch + if p.Stats.ChoiceAltCnt == nil { + p.Stats.ChoiceAltCnt = make(map[string]map[string]int) + } + return Statistics(oldStats, oldChoiceNoMatch) + } +} + +// Debug creates an Option to set the debug flag to b. When set to true, +// debugging information is printed to stdout while parsing. +// +// The default is false. +func Debug(b bool) Option { + return func(p *parser) Option { + old := p.debug + p.debug = b + return Debug(old) + } +} + +// Memoize creates an Option to set the memoize flag to b. When set to true, +// the parser will cache all results so each expression is evaluated only +// once. This guarantees linear parsing time even for pathological cases, +// at the expense of more memory and slower times for typical cases. +// +// The default is false. +func Memoize(b bool) Option { + return func(p *parser) Option { + old := p.memoize + p.memoize = b + return Memoize(old) + } +} + +// AllowInvalidUTF8 creates an Option to allow invalid UTF-8 bytes. +// Every invalid UTF-8 byte is treated as a utf8.RuneError (U+FFFD) +// by character class matchers and is matched by the any matcher. +// The returned matched value, c.text and c.offset are NOT affected. +// +// The default is false. +func AllowInvalidUTF8(b bool) Option { + return func(p *parser) Option { + old := p.allowInvalidUTF8 + p.allowInvalidUTF8 = b + return AllowInvalidUTF8(old) + } +} + +// Recover creates an Option to set the recover flag to b. When set to +// true, this causes the parser to recover from panics and convert it +// to an error. Setting it to false can be useful while debugging to +// access the full stack trace. +// +// The default is true. +func Recover(b bool) Option { + return func(p *parser) Option { + old := p.recover + p.recover = b + return Recover(old) + } +} + +// GlobalStore creates an Option to set a key to a certain value in +// the globalStore. +func GlobalStore(key string, value any) Option { + return func(p *parser) Option { + old := p.cur.globalStore[key] + p.cur.globalStore[key] = value + return GlobalStore(key, old) + } +} + +// InitState creates an Option to set a key to a certain value in +// the global "state" store. +func InitState(key string, value any) Option { + return func(p *parser) Option { + old := p.cur.state[key] + p.cur.state[key] = value + return InitState(key, old) + } +} + +// ParseFile parses the file identified by filename. 
+func ParseFile(filename string, opts ...Option) (i any, err error) { // nolint: deadcode + f, err := os.Open(filename) + if err != nil { + return nil, err + } + defer func() { + if closeErr := f.Close(); closeErr != nil { + err = closeErr + } + }() + return ParseReader(filename, f, opts...) +} + +// ParseReader parses the data from r using filename as information in the +// error messages. +func ParseReader(filename string, r io.Reader, opts ...Option) (any, error) { // nolint: deadcode + b, err := io.ReadAll(r) + if err != nil { + return nil, err + } + + return Parse(filename, b, opts...) +} + +// Parse parses the data from b using filename as information in the +// error messages. +func Parse(filename string, b []byte, opts ...Option) (any, error) { + return newParser(filename, b, opts...).parse(g) +} + +// position records a position in the text. +type position struct { + line, col, offset int +} + +func (p position) String() string { + return strconv.Itoa(p.line) + ":" + strconv.Itoa(p.col) + " [" + strconv.Itoa(p.offset) + "]" +} + +// savepoint stores all state required to go back to this point in the +// parser. +type savepoint struct { + position + rn rune + w int +} + +type current struct { + pos position // start position of the match + text []byte // raw text of the match + + // state is a store for arbitrary key,value pairs that the user wants to be + // tied to the backtracking of the parser. + // This is always rolled back if a parsing rule fails. + state storeDict + + // globalStore is a general store for the user to store arbitrary key-value + // pairs that they need to manage and that they do not want tied to the + // backtracking of the parser. This is only modified by the user and never + // rolled back by the parser. It is always up to the user to keep this in a + // consistent state. + globalStore storeDict +} + +type storeDict map[string]any + +// the AST types... 
+ +// nolint: structcheck +type grammar struct { + pos position + rules []*rule +} + +// nolint: structcheck +type rule struct { + pos position + name string + displayName string + expr any + + leader bool + leftRecursive bool +} + +// nolint: structcheck +type choiceExpr struct { + pos position + alternatives []any +} + +// nolint: structcheck +type actionExpr struct { + pos position + expr any + run func(*parser) (any, error) +} + +// nolint: structcheck +type recoveryExpr struct { + pos position + expr any + recoverExpr any + failureLabel []string +} + +// nolint: structcheck +type seqExpr struct { + pos position + exprs []any +} + +// nolint: structcheck +type throwExpr struct { + pos position + label string +} + +// nolint: structcheck +type labeledExpr struct { + pos position + label string + expr any +} + +// nolint: structcheck +type expr struct { + pos position + expr any +} + +type ( + andExpr expr // nolint: structcheck + notExpr expr // nolint: structcheck + zeroOrOneExpr expr // nolint: structcheck + zeroOrMoreExpr expr // nolint: structcheck + oneOrMoreExpr expr // nolint: structcheck +) + +// nolint: structcheck +type ruleRefExpr struct { + pos position + name string +} + +// nolint: structcheck +type stateCodeExpr struct { + pos position + run func(*parser) error +} + +// nolint: structcheck +type andCodeExpr struct { + pos position + run func(*parser) (bool, error) +} + +// nolint: structcheck +type notCodeExpr struct { + pos position + run func(*parser) (bool, error) +} + +// nolint: structcheck +type litMatcher struct { + pos position + val string + ignoreCase bool + want string +} + +// nolint: structcheck +type charClassMatcher struct { + pos position + val string + basicLatinChars [128]bool + chars []rune + ranges []rune + classes []*unicode.RangeTable + ignoreCase bool + inverted bool +} + +type anyMatcher position // nolint: structcheck + +// errList cumulates the errors found by the parser. +type errList []error + +func (e *errList) add(err error) { + *e = append(*e, err) +} + +func (e errList) err() error { + if len(e) == 0 { + return nil + } + e.dedupe() + return e +} + +func (e *errList) dedupe() { + var cleaned []error + set := make(map[string]bool) + for _, err := range *e { + if msg := err.Error(); !set[msg] { + set[msg] = true + cleaned = append(cleaned, err) + } + } + *e = cleaned +} + +func (e errList) Error() string { + switch len(e) { + case 0: + return "" + case 1: + return e[0].Error() + default: + var buf bytes.Buffer + + for i, err := range e { + if i > 0 { + buf.WriteRune('\n') + } + buf.WriteString(err.Error()) + } + return buf.String() + } +} + +// parserError wraps an error with a prefix indicating the rule in which +// the error occurred. The original error is stored in the Inner field. +type parserError struct { + Inner error + pos position + prefix string + expected []string +} + +// Error returns the error message. +func (p *parserError) Error() string { + return p.prefix + ": " + p.Inner.Error() +} + +// newParser creates a parser with the specified input source and options. 
+func newParser(filename string, b []byte, opts ...Option) *parser { + stats := Stats{ + ChoiceAltCnt: make(map[string]map[string]int), + } + + p := &parser{ + filename: filename, + errs: new(errList), + data: b, + pt: savepoint{position: position{line: 1}}, + recover: true, + cur: current{ + state: make(storeDict), + globalStore: make(storeDict), + }, + maxFailPos: position{col: 1, line: 1}, + maxFailExpected: make([]string, 0, 20), + Stats: &stats, + // start rule is rule [0] unless an alternate entrypoint is specified + entrypoint: g.rules[0].name, + } + p.setOptions(opts) + + if p.maxExprCnt == 0 { + p.maxExprCnt = math.MaxUint64 + } + + return p +} + +// setOptions applies the options to the parser. +func (p *parser) setOptions(opts []Option) { + for _, opt := range opts { + opt(p) + } +} + +// nolint: structcheck,deadcode +type resultTuple struct { + v any + b bool + end savepoint +} + +// nolint: varcheck +const choiceNoMatch = -1 + +// Stats stores some statistics, gathered during parsing +type Stats struct { + // ExprCnt counts the number of expressions processed during parsing + // This value is compared to the maximum number of expressions allowed + // (set by the MaxExpressions option). + ExprCnt uint64 + + // ChoiceAltCnt is used to count for each ordered choice expression, + // which alternative is used how may times. + // These numbers allow to optimize the order of the ordered choice expression + // to increase the performance of the parser + // + // The outer key of ChoiceAltCnt is composed of the name of the rule as well + // as the line and the column of the ordered choice. + // The inner key of ChoiceAltCnt is the number (one-based) of the matching alternative. + // For each alternative the number of matches are counted. If an ordered choice does not + // match, a special counter is incremented. The name of this counter is set with + // the parser option Statistics. + // For an alternative to be included in ChoiceAltCnt, it has to match at least once. + ChoiceAltCnt map[string]map[string]int +} + +type ruleWithExpsStack struct { + rule *rule + estack []any +} + +// nolint: structcheck,maligned +type parser struct { + filename string + pt savepoint + cur current + + data []byte + errs *errList + + depth int + recover bool + debug bool + + memoize bool + // memoization table for the packrat algorithm: + // map[offset in source] map[expression or rule] {value, match} + memo map[int]map[any]resultTuple + + // rules table, maps the rule identifier to the rule node + rules map[string]*rule + // variables stack, map of label to value + vstack []map[string]any + // rule stack, allows identification of the current rule in errors + rstack []*rule + + // parse fail + maxFailPos position + maxFailExpected []string + maxFailInvertExpected bool + + // max number of expressions to be parsed + maxExprCnt uint64 + // entrypoint for the parser + entrypoint string + + allowInvalidUTF8 bool + + *Stats + + choiceNoMatch string + // recovery expression stack, keeps track of the currently available recovery expression, these are traversed in reverse + recoveryStack []map[string]any +} + +// push a variable set on the vstack. 
+func (p *parser) pushV() { + if cap(p.vstack) == len(p.vstack) { + // create new empty slot in the stack + p.vstack = append(p.vstack, nil) + } else { + // slice to 1 more + p.vstack = p.vstack[:len(p.vstack)+1] + } + + // get the last args set + m := p.vstack[len(p.vstack)-1] + if m != nil && len(m) == 0 { + // empty map, all good + return + } + + m = make(map[string]any) + p.vstack[len(p.vstack)-1] = m +} + +// pop a variable set from the vstack. +func (p *parser) popV() { + // if the map is not empty, clear it + m := p.vstack[len(p.vstack)-1] + if len(m) > 0 { + // GC that map + p.vstack[len(p.vstack)-1] = nil + } + p.vstack = p.vstack[:len(p.vstack)-1] +} + +// push a recovery expression with its labels to the recoveryStack +func (p *parser) pushRecovery(labels []string, expr any) { + if cap(p.recoveryStack) == len(p.recoveryStack) { + // create new empty slot in the stack + p.recoveryStack = append(p.recoveryStack, nil) + } else { + // slice to 1 more + p.recoveryStack = p.recoveryStack[:len(p.recoveryStack)+1] + } + + m := make(map[string]any, len(labels)) + for _, fl := range labels { + m[fl] = expr + } + p.recoveryStack[len(p.recoveryStack)-1] = m +} + +// pop a recovery expression from the recoveryStack +func (p *parser) popRecovery() { + // GC that map + p.recoveryStack[len(p.recoveryStack)-1] = nil + + p.recoveryStack = p.recoveryStack[:len(p.recoveryStack)-1] +} + +func (p *parser) print(prefix, s string) string { + if !p.debug { + return s + } + + fmt.Printf("%s %d:%d:%d: %s [%#U]\n", + prefix, p.pt.line, p.pt.col, p.pt.offset, s, p.pt.rn) + return s +} + +func (p *parser) printIndent(mark string, s string) string { + return p.print(strings.Repeat(" ", p.depth)+mark, s) +} + +func (p *parser) in(s string) string { + res := p.printIndent(">", s) + p.depth++ + return res +} + +func (p *parser) out(s string) string { + p.depth-- + return p.printIndent("<", s) +} + +func (p *parser) addErr(err error) { + p.addErrAt(err, p.pt.position, []string{}) +} + +func (p *parser) addErrAt(err error, pos position, expected []string) { + var buf bytes.Buffer + if p.filename != "" { + buf.WriteString(p.filename) + } + if buf.Len() > 0 { + buf.WriteString(":") + } + buf.WriteString(fmt.Sprintf("%d:%d (%d)", pos.line, pos.col, pos.offset)) + if len(p.rstack) > 0 { + if buf.Len() > 0 { + buf.WriteString(": ") + } + rule := p.rstack[len(p.rstack)-1] + if rule.displayName != "" { + buf.WriteString("rule " + rule.displayName) + } else { + buf.WriteString("rule " + rule.name) + } + } + pe := &parserError{Inner: err, pos: pos, prefix: buf.String(), expected: expected} + p.errs.add(pe) +} + +func (p *parser) failAt(fail bool, pos position, want string) { + // process fail if parsing fails and not inverted or parsing succeeds and invert is set + if fail == p.maxFailInvertExpected { + if pos.offset < p.maxFailPos.offset { + return + } + + if pos.offset > p.maxFailPos.offset { + p.maxFailPos = pos + p.maxFailExpected = p.maxFailExpected[:0] + } + + if p.maxFailInvertExpected { + want = "!" + want + } + p.maxFailExpected = append(p.maxFailExpected, want) + } +} + +// read advances the parser to the next rune. +func (p *parser) read() { + p.pt.offset += p.pt.w + rn, n := utf8.DecodeRune(p.data[p.pt.offset:]) + p.pt.rn = rn + p.pt.w = n + p.pt.col++ + if rn == '\n' { + p.pt.line++ + p.pt.col = 0 + } + + if rn == utf8.RuneError && n == 1 { // see utf8.DecodeRune + if !p.allowInvalidUTF8 { + p.addErr(errInvalidEncoding) + } + } +} + +// restore parser position to the savepoint pt. 
+func (p *parser) restore(pt savepoint) { + if p.debug { + defer p.out(p.in("restore")) + } + if pt.offset == p.pt.offset { + return + } + p.pt = pt +} + +// Cloner is implemented by any value that has a Clone method, which returns a +// copy of the value. This is mainly used for types which are not passed by +// value (e.g map, slice, chan) or structs that contain such types. +// +// This is used in conjunction with the global state feature to create proper +// copies of the state to allow the parser to properly restore the state in +// the case of backtracking. +type Cloner interface { + Clone() any +} + +var statePool = &sync.Pool{ + New: func() any { return make(storeDict) }, +} + +func (sd storeDict) Discard() { + for k := range sd { + delete(sd, k) + } + statePool.Put(sd) +} + +// clone and return parser current state. +func (p *parser) cloneState() storeDict { + if p.debug { + defer p.out(p.in("cloneState")) + } + + state := statePool.Get().(storeDict) + for k, v := range p.cur.state { + if c, ok := v.(Cloner); ok { + state[k] = c.Clone() + } else { + state[k] = v + } + } + return state +} + +// restore parser current state to the state storeDict. +// every restoreState should applied only one time for every cloned state +func (p *parser) restoreState(state storeDict) { + if p.debug { + defer p.out(p.in("restoreState")) + } + p.cur.state.Discard() + p.cur.state = state +} + +// get the slice of bytes from the savepoint start to the current position. +func (p *parser) sliceFrom(start savepoint) []byte { + return p.data[start.position.offset:p.pt.position.offset] +} + +func (p *parser) getMemoized(node any) (resultTuple, bool) { + if len(p.memo) == 0 { + return resultTuple{}, false + } + m := p.memo[p.pt.offset] + if len(m) == 0 { + return resultTuple{}, false + } + res, ok := m[node] + return res, ok +} + +func (p *parser) setMemoized(pt savepoint, node any, tuple resultTuple) { + if p.memo == nil { + p.memo = make(map[int]map[any]resultTuple) + } + m := p.memo[pt.offset] + if m == nil { + m = make(map[any]resultTuple) + p.memo[pt.offset] = m + } + m[node] = tuple +} + +func (p *parser) buildRulesTable(g *grammar) { + p.rules = make(map[string]*rule, len(g.rules)) + for _, r := range g.rules { + p.rules[r.name] = r + } +} + +// nolint: gocyclo +func (p *parser) parse(g *grammar) (val any, err error) { + if len(g.rules) == 0 { + p.addErr(errNoRule) + return nil, p.errs.err() + } + + // TODO : not super critical but this could be generated + p.buildRulesTable(g) + + if p.recover { + // panic can be used in action code to stop parsing immediately + // and return the panic as an error. + defer func() { + if e := recover(); e != nil { + if p.debug { + defer p.out(p.in("panic handler")) + } + val = nil + switch e := e.(type) { + case error: + p.addErr(e) + default: + p.addErr(fmt.Errorf("%v", e)) + } + err = p.errs.err() + } + }() + } + + startRule, ok := p.rules[p.entrypoint] + if !ok { + p.addErr(errInvalidEntrypoint) + return nil, p.errs.err() + } + + p.read() // advance to first rune + val, ok = p.parseRuleWrap(startRule) + if !ok { + if len(*p.errs) == 0 { + // If parsing fails, but no errors have been recorded, the expected values + // for the farthest parser position are returned as error. 
+ maxFailExpectedMap := make(map[string]struct{}, len(p.maxFailExpected)) + for _, v := range p.maxFailExpected { + maxFailExpectedMap[v] = struct{}{} + } + expected := make([]string, 0, len(maxFailExpectedMap)) + eof := false + if _, ok := maxFailExpectedMap["!."]; ok { + delete(maxFailExpectedMap, "!.") + eof = true + } + for k := range maxFailExpectedMap { + expected = append(expected, k) + } + sort.Strings(expected) + if eof { + expected = append(expected, "EOF") + } + p.addErrAt(errors.New("no match found, expected: "+listJoin(expected, ", ", "or")), p.maxFailPos, expected) + } + + return nil, p.errs.err() + } + return val, p.errs.err() +} + +func listJoin(list []string, sep string, lastSep string) string { + switch len(list) { + case 0: + return "" + case 1: + return list[0] + default: + return strings.Join(list[:len(list)-1], sep) + " " + lastSep + " " + list[len(list)-1] + } +} + +func (p *parser) parseRuleRecursiveLeader(rule *rule) (any, bool) { + result, ok := p.getMemoized(rule) + if ok { + p.restore(result.end) + return result.v, result.b + } + + if p.debug { + defer p.out(p.in("recursive " + rule.name)) + } + + var ( + depth = 0 + startMark = p.pt + lastResult = resultTuple{nil, false, startMark} + lastErrors = *p.errs + ) + + for { + lastState := p.cloneState() + p.setMemoized(startMark, rule, lastResult) + val, ok := p.parseRule(rule) + endMark := p.pt + if p.debug { + p.printIndent("RECURSIVE", fmt.Sprintf( + "Rule %s depth %d: %t -> %s", + rule.name, depth, ok, string(p.sliceFrom(startMark)))) + } + if (!ok) || (endMark.offset <= lastResult.end.offset && depth != 0) { + p.restoreState(lastState) + *p.errs = lastErrors + break + } + lastResult = resultTuple{val, ok, endMark} + lastErrors = *p.errs + p.restore(startMark) + depth++ + } + + p.restore(lastResult.end) + p.setMemoized(startMark, rule, lastResult) + return lastResult.v, lastResult.b +} + +func (p *parser) parseRuleRecursiveNoLeader(rule *rule) (any, bool) { + return p.parseRule(rule) +} + +func (p *parser) parseRuleMemoize(rule *rule) (any, bool) { + res, ok := p.getMemoized(rule) + if ok { + p.restore(res.end) + return res.v, res.b + } + + startMark := p.pt + val, ok := p.parseRule(rule) + p.setMemoized(startMark, rule, resultTuple{val, ok, p.pt}) + + return val, ok +} + +func (p *parser) parseRuleWrap(rule *rule) (any, bool) { + if p.debug { + defer p.out(p.in("parseRule " + rule.name)) + } + var ( + val any + ok bool + startMark = p.pt + ) + + if p.memoize || rule.leftRecursive { + if rule.leader { + val, ok = p.parseRuleRecursiveLeader(rule) + } else if p.memoize && !rule.leftRecursive { + val, ok = p.parseRuleMemoize(rule) + } else { + val, ok = p.parseRuleRecursiveNoLeader(rule) + } + } else { + val, ok = p.parseRule(rule) + } + + if ok && p.debug { + p.printIndent("MATCH", string(p.sliceFrom(startMark))) + } + return val, ok +} + +func (p *parser) parseRule(rule *rule) (any, bool) { + p.rstack = append(p.rstack, rule) + p.pushV() + val, ok := p.parseExprWrap(rule.expr) + p.popV() + p.rstack = p.rstack[:len(p.rstack)-1] + return val, ok +} + +func (p *parser) parseExprWrap(expr any) (any, bool) { + var pt savepoint + + isLeftRecursion := p.rstack[len(p.rstack)-1].leftRecursive + if p.memoize && !isLeftRecursion { + res, ok := p.getMemoized(expr) + if ok { + p.restore(res.end) + return res.v, res.b + } + pt = p.pt + } + + val, ok := p.parseExpr(expr) + + if p.memoize && !isLeftRecursion { + p.setMemoized(pt, expr, resultTuple{val, ok, p.pt}) + } + return val, ok +} + +// nolint: gocyclo +func (p *parser) 
parseExpr(expr any) (any, bool) { + p.ExprCnt++ + if p.ExprCnt > p.maxExprCnt { + panic(errMaxExprCnt) + } + + var val any + var ok bool + switch expr := expr.(type) { + case *actionExpr: + val, ok = p.parseActionExpr(expr) + case *andCodeExpr: + val, ok = p.parseAndCodeExpr(expr) + case *andExpr: + val, ok = p.parseAndExpr(expr) + case *anyMatcher: + val, ok = p.parseAnyMatcher(expr) + case *charClassMatcher: + val, ok = p.parseCharClassMatcher(expr) + case *choiceExpr: + val, ok = p.parseChoiceExpr(expr) + case *labeledExpr: + val, ok = p.parseLabeledExpr(expr) + case *litMatcher: + val, ok = p.parseLitMatcher(expr) + case *notCodeExpr: + val, ok = p.parseNotCodeExpr(expr) + case *notExpr: + val, ok = p.parseNotExpr(expr) + case *oneOrMoreExpr: + val, ok = p.parseOneOrMoreExpr(expr) + case *recoveryExpr: + val, ok = p.parseRecoveryExpr(expr) + case *ruleRefExpr: + val, ok = p.parseRuleRefExpr(expr) + case *seqExpr: + val, ok = p.parseSeqExpr(expr) + case *stateCodeExpr: + val, ok = p.parseStateCodeExpr(expr) + case *throwExpr: + val, ok = p.parseThrowExpr(expr) + case *zeroOrMoreExpr: + val, ok = p.parseZeroOrMoreExpr(expr) + case *zeroOrOneExpr: + val, ok = p.parseZeroOrOneExpr(expr) + default: + panic(fmt.Sprintf("unknown expression type %T", expr)) + } + return val, ok +} + +func (p *parser) parseActionExpr(act *actionExpr) (any, bool) { + if p.debug { + defer p.out(p.in("parseActionExpr")) + } + + start := p.pt + val, ok := p.parseExprWrap(act.expr) + if ok { + p.cur.pos = start.position + p.cur.text = p.sliceFrom(start) + state := p.cloneState() + actVal, err := act.run(p) + if err != nil { + p.addErrAt(err, start.position, []string{}) + } + p.restoreState(state) + + val = actVal + } + if ok && p.debug { + p.printIndent("MATCH", string(p.sliceFrom(start))) + } + return val, ok +} + +func (p *parser) parseAndCodeExpr(and *andCodeExpr) (any, bool) { + if p.debug { + defer p.out(p.in("parseAndCodeExpr")) + } + + state := p.cloneState() + + ok, err := and.run(p) + if err != nil { + p.addErr(err) + } + p.restoreState(state) + + return nil, ok +} + +func (p *parser) parseAndExpr(and *andExpr) (any, bool) { + if p.debug { + defer p.out(p.in("parseAndExpr")) + } + + pt := p.pt + state := p.cloneState() + p.pushV() + _, ok := p.parseExprWrap(and.expr) + p.popV() + p.restoreState(state) + p.restore(pt) + + return nil, ok +} + +func (p *parser) parseAnyMatcher(any *anyMatcher) (any, bool) { + if p.debug { + defer p.out(p.in("parseAnyMatcher")) + } + + if p.pt.rn == utf8.RuneError && p.pt.w == 0 { + // EOF - see utf8.DecodeRune + p.failAt(false, p.pt.position, ".") + return nil, false + } + start := p.pt + p.read() + p.failAt(true, start.position, ".") + return p.sliceFrom(start), true +} + +// nolint: gocyclo +func (p *parser) parseCharClassMatcher(chr *charClassMatcher) (any, bool) { + if p.debug { + defer p.out(p.in("parseCharClassMatcher")) + } + + cur := p.pt.rn + start := p.pt + + // can't match EOF + if cur == utf8.RuneError && p.pt.w == 0 { // see utf8.DecodeRune + p.failAt(false, start.position, chr.val) + return nil, false + } + + if chr.ignoreCase { + cur = unicode.ToLower(cur) + } + + // try to match in the list of available chars + for _, rn := range chr.chars { + if rn == cur { + if chr.inverted { + p.failAt(false, start.position, chr.val) + return nil, false + } + p.read() + p.failAt(true, start.position, chr.val) + return p.sliceFrom(start), true + } + } + + // try to match in the list of ranges + for i := 0; i < len(chr.ranges); i += 2 { + if cur >= chr.ranges[i] && cur <= 
chr.ranges[i+1] { + if chr.inverted { + p.failAt(false, start.position, chr.val) + return nil, false + } + p.read() + p.failAt(true, start.position, chr.val) + return p.sliceFrom(start), true + } + } + + // try to match in the list of Unicode classes + for _, cl := range chr.classes { + if unicode.Is(cl, cur) { + if chr.inverted { + p.failAt(false, start.position, chr.val) + return nil, false + } + p.read() + p.failAt(true, start.position, chr.val) + return p.sliceFrom(start), true + } + } + + if chr.inverted { + p.read() + p.failAt(true, start.position, chr.val) + return p.sliceFrom(start), true + } + p.failAt(false, start.position, chr.val) + return nil, false +} + +func (p *parser) incChoiceAltCnt(ch *choiceExpr, altI int) { + choiceIdent := fmt.Sprintf("%s %d:%d", p.rstack[len(p.rstack)-1].name, ch.pos.line, ch.pos.col) + m := p.ChoiceAltCnt[choiceIdent] + if m == nil { + m = make(map[string]int) + p.ChoiceAltCnt[choiceIdent] = m + } + // We increment altI by 1, so the keys do not start at 0 + alt := strconv.Itoa(altI + 1) + if altI == choiceNoMatch { + alt = p.choiceNoMatch + } + m[alt]++ +} + +func (p *parser) parseChoiceExpr(ch *choiceExpr) (any, bool) { + if p.debug { + defer p.out(p.in("parseChoiceExpr")) + } + + for altI, alt := range ch.alternatives { + // dummy assignment to prevent compile error if optimized + _ = altI + + state := p.cloneState() + + p.pushV() + val, ok := p.parseExprWrap(alt) + p.popV() + if ok { + p.incChoiceAltCnt(ch, altI) + return val, ok + } + p.restoreState(state) + } + p.incChoiceAltCnt(ch, choiceNoMatch) + return nil, false +} + +func (p *parser) parseLabeledExpr(lab *labeledExpr) (any, bool) { + if p.debug { + defer p.out(p.in("parseLabeledExpr")) + } + + p.pushV() + val, ok := p.parseExprWrap(lab.expr) + p.popV() + if ok && lab.label != "" { + m := p.vstack[len(p.vstack)-1] + m[lab.label] = val + } + return val, ok +} + +func (p *parser) parseLitMatcher(lit *litMatcher) (any, bool) { + if p.debug { + defer p.out(p.in("parseLitMatcher")) + } + + start := p.pt + for _, want := range lit.val { + cur := p.pt.rn + if lit.ignoreCase { + cur = unicode.ToLower(cur) + } + if cur != want { + p.failAt(false, start.position, lit.want) + p.restore(start) + return nil, false + } + p.read() + } + p.failAt(true, start.position, lit.want) + return p.sliceFrom(start), true +} + +func (p *parser) parseNotCodeExpr(not *notCodeExpr) (any, bool) { + if p.debug { + defer p.out(p.in("parseNotCodeExpr")) + } + + state := p.cloneState() + + ok, err := not.run(p) + if err != nil { + p.addErr(err) + } + p.restoreState(state) + + return nil, !ok +} + +func (p *parser) parseNotExpr(not *notExpr) (any, bool) { + if p.debug { + defer p.out(p.in("parseNotExpr")) + } + + pt := p.pt + state := p.cloneState() + p.pushV() + p.maxFailInvertExpected = !p.maxFailInvertExpected + _, ok := p.parseExprWrap(not.expr) + p.maxFailInvertExpected = !p.maxFailInvertExpected + p.popV() + p.restoreState(state) + p.restore(pt) + + return nil, !ok +} + +func (p *parser) parseOneOrMoreExpr(expr *oneOrMoreExpr) (any, bool) { + if p.debug { + defer p.out(p.in("parseOneOrMoreExpr")) + } + + var vals []any + + for { + p.pushV() + val, ok := p.parseExprWrap(expr.expr) + p.popV() + if !ok { + if len(vals) == 0 { + // did not match once, no match + return nil, false + } + return vals, true + } + vals = append(vals, val) + } +} + +func (p *parser) parseRecoveryExpr(recover *recoveryExpr) (any, bool) { + if p.debug { + defer p.out(p.in("parseRecoveryExpr (" + strings.Join(recover.failureLabel, ",") + ")")) + } 
+ + p.pushRecovery(recover.failureLabel, recover.recoverExpr) + val, ok := p.parseExprWrap(recover.expr) + p.popRecovery() + + return val, ok +} + +func (p *parser) parseRuleRefExpr(ref *ruleRefExpr) (any, bool) { + if p.debug { + defer p.out(p.in("parseRuleRefExpr " + ref.name)) + } + + if ref.name == "" { + panic(fmt.Sprintf("%s: invalid rule: missing name", ref.pos)) + } + + rule := p.rules[ref.name] + if rule == nil { + p.addErr(fmt.Errorf("undefined rule: %s", ref.name)) + return nil, false + } + return p.parseRuleWrap(rule) +} + +func (p *parser) parseSeqExpr(seq *seqExpr) (any, bool) { + if p.debug { + defer p.out(p.in("parseSeqExpr")) + } + + vals := make([]any, 0, len(seq.exprs)) + + pt := p.pt + state := p.cloneState() + for _, expr := range seq.exprs { + val, ok := p.parseExprWrap(expr) + if !ok { + p.restoreState(state) + p.restore(pt) + return nil, false + } + vals = append(vals, val) + } + return vals, true +} + +func (p *parser) parseStateCodeExpr(state *stateCodeExpr) (any, bool) { + if p.debug { + defer p.out(p.in("parseStateCodeExpr")) + } + + err := state.run(p) + if err != nil { + p.addErr(err) + } + return nil, true +} + +func (p *parser) parseThrowExpr(expr *throwExpr) (any, bool) { + if p.debug { + defer p.out(p.in("parseThrowExpr")) + } + + for i := len(p.recoveryStack) - 1; i >= 0; i-- { + if recoverExpr, ok := p.recoveryStack[i][expr.label]; ok { + if val, ok := p.parseExprWrap(recoverExpr); ok { + return val, ok + } + } + } + + return nil, false +} + +func (p *parser) parseZeroOrMoreExpr(expr *zeroOrMoreExpr) (any, bool) { + if p.debug { + defer p.out(p.in("parseZeroOrMoreExpr")) + } + + var vals []any + + for { + p.pushV() + val, ok := p.parseExprWrap(expr.expr) + p.popV() + if !ok { + return vals, true + } + vals = append(vals, val) + } +} + +func (p *parser) parseZeroOrOneExpr(expr *zeroOrOneExpr) (any, bool) { + if p.debug { + defer p.out(p.in("parseZeroOrOneExpr")) + } + + p.pushV() + val, _ := p.parseExprWrap(expr.expr) + p.popV() + // whether it matched or not, consider it a match + return val, true +} diff --git a/quesma/painful/model.go b/quesma/painful/model.go new file mode 100644 index 000000000..c4224201e --- /dev/null +++ b/quesma/painful/model.go @@ -0,0 +1,262 @@ +// Copyright Quesma, licensed under the Elastic License 2.0. 
+// SPDX-License-Identifier: Elastic-2.0
+package painful
+
+//go:generate pigeon -nolint -support-left-recursion -o generated_parser.go painless.peg
+
+import (
+    "fmt"
+    "time"
+)
+
+func ParsePainless(script string) (Expr, error) {
+
+    evalTree, err := Parse("", []byte(script))
+    if err != nil {
+        return nil, err
+    }
+
+    switch expr := evalTree.(type) {
+    case Expr:
+        return expr, nil
+
+    default:
+        return nil, fmt.Errorf("not a painless expression")
+    }
+}
+
+type Env struct {
+    Doc map[string]any
+
+    EmitValue any
+}
+
+type Expr interface {
+    Eval(env *Env) (any, error)
+}
+
+type LiteralExpr struct {
+    Value any
+}
+
+func (l *LiteralExpr) Eval(env *Env) (any, error) {
+    return l.Value, nil
+}
+
+type InfixOpExpr struct {
+    Position string
+    Left     Expr
+    Op       string
+    Right    Expr
+}
+
+func (i *InfixOpExpr) Eval(env *Env) (any, error) {
+
+    left, err := i.Left.Eval(env)
+    if err != nil {
+        return nil, err
+    }
+
+    right, err := i.Right.Eval(env)
+    if err != nil {
+        return nil, err
+    }
+
+    switch i.Op {
+
+    case "+":
+        // only concatenation is supported for now: both operands are rendered
+        // with %v and joined into a string
+        return fmt.Sprintf("%v%v", left, right), nil
+
+    default:
+        return nil, fmt.Errorf("%s: '%s' operator is not supported", i.Position, i.Op)
+    }
+}
+
+type ConditionalExpr struct {
+    Cond Expr
+    Then Expr
+    Else Expr
+}
+
+func (c *ConditionalExpr) Eval(env *Env) (any, error) {
+
+    cond, err := c.Cond.Eval(env)
+    if err != nil {
+        return nil, err
+    }
+
+    condBool, ok := cond.(bool)
+    if !ok {
+        return nil, fmt.Errorf("condition did not evaluate to a boolean, got %T", cond)
+    }
+
+    if condBool {
+        return c.Then.Eval(env)
+    }
+
+    return c.Else.Eval(env)
+}
+
+type DocExpr struct {
+    FieldName Expr
+}
+
+func (d *DocExpr) Eval(env *Env) (any, error) {
+
+    fieldName, err := d.FieldName.Eval(env)
+    if err != nil {
+        return nil, err
+    }
+
+    key := fmt.Sprintf("%v", fieldName)
+    return env.Doc[key], nil
+}
+
+type EmitExpr struct {
+    Expr Expr
+}
+
+func (e *EmitExpr) Eval(env *Env) (any, error) {
+
+    val, err := e.Expr.Eval(env)
+    if err != nil {
+        return nil, err
+    }
+
+    env.EmitValue = val
+
+    return val, nil
+}
+
+type AccessorExpr struct {
+    Position     string
+    Expr         Expr
+    PropertyName string
+}
+
+func (a *AccessorExpr) Eval(env *Env) (any, error) {
+
+    val, err := a.Expr.Eval(env)
+    if err != nil {
+        return nil, err
+    }
+
+    // the 'value' property is a special case:
+    // it is simply the current value of the expression
+    if a.PropertyName == "value" {
+        return val, nil
+    }
+
+    // for testing purposes
+    if a.PropertyName == "type" {
+        return fmt.Sprintf("%T", val), nil
+    }
+
+    return nil, fmt.Errorf("%s: '%s' property is not supported", a.Position, a.PropertyName)
+}
+
+type MethodCallExpr struct {
+    Position   string
+    Expr       Expr
+    MethodName string
+    Args       []Expr
+}
+
+func (m *MethodCallExpr) Eval(env *Env) (any, error) {
+
+    val, err := m.Expr.Eval(env)
+    if err != nil {
+        return nil, err
+    }
+
+    switch m.MethodName {
+
+    case "getHour":
+
+        typeVal, err := ExpectDate(val)
+        if err != nil {
+            return nil, fmt.Errorf("%s: method '%s' failed to coerce '%v' into a datetime: %v", m.Position, m.MethodName, val, err)
+        }
+
+        return typeVal.Hour(), nil
+
+    case "formatISO8601": // TODO: maybe pick an easier-to-remember name
+
+        typeVal, err := ExpectDate(val)
+        if err != nil {
+            return nil, fmt.Errorf("%s: method '%s' failed to coerce '%v' into a datetime: %v", m.Position, m.MethodName, val, err)
+        }
+
+        return typeVal.Format(time.RFC3339), nil
+
+    default:
+        return nil, fmt.Errorf("%s: '%s' method is not supported", m.Position, m.MethodName)
+    }
+}
+
+func ExpectExpr(potentialExpr any) (Expr, error) {
+
+    switch expr := potentialExpr.(type) {
+    case Expr:
+        return expr, nil
+    default:
+        return nil, fmt.Errorf("expected expression, got %T", potentialExpr)
+    }
+}
+
+func ExpectString(potentialExpr any) (string, error) {
+
+    switch str := potentialExpr.(type) {
+    case string:
+        return str, nil
+    default:
+        return "", fmt.Errorf("expected string, got %T", potentialExpr)
+    }
+}
+
+func ExpectDate(potentialExpr any) (time.Time, error) {
+
+    switch date := potentialExpr.(type) {
+    case time.Time:
+        return date, nil
+
+    case string:
+
+        formats := []string{
+            "Jan 2, 2006 @ 15:04:05.000 -0700 MST", // format used in the example provided by Kibana
+            "2006-01-02 15:04:05.000 -0700 MST",    // clickhouse format
+            time.Layout,
+            time.ANSIC,
+            time.UnixDate,
+            time.RubyDate,
+            time.RFC822,
+            time.RFC822Z,
+            time.RFC850,
+            time.RFC1123,
+            time.RFC1123Z,
+            time.RFC3339,
+            time.RFC3339Nano,
+        }
+
+        for _, format := range formats {
+            t, err := time.Parse(format, date)
+            if err == nil {
+                return t, nil
+            }
+        }
+
+        return time.Time{}, fmt.Errorf("failed to parse date: %s", date)
+    default:
+        return time.Time{}, fmt.Errorf("expected date, got %T", potentialExpr)
+    }
+}
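For orientation, here is a minimal usage sketch of the new package (not part of this diff; the field name and document are made up): the script is parsed once, then evaluated against a document through an Env, and the emitted value is read back from EmitValue. This is the same per-row pattern that RuntimeMapping.PostProcessExpression is meant to enable.

```go
package main

import (
	"fmt"

	"quesma/painful"
)

func main() {
	// parse the painless script once
	expr, err := painful.ParsePainless("emit(doc['@timestamp'].value.getHour())")
	if err != nil {
		panic(err)
	}

	// one Env per evaluated document; EmitValue carries the result of emit(...)
	env := &painful.Env{Doc: map[string]any{"@timestamp": "2022-09-22T12:16:59.98Z"}}
	if _, err := expr.Eval(env); err != nil {
		panic(err)
	}
	fmt.Println(env.EmitValue) // 12
}
```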
diff --git a/quesma/painful/painless.peg b/quesma/painful/painless.peg
new file mode 100644
index 000000000..7cc767182
--- /dev/null
+++ b/quesma/painful/painless.peg
@@ -0,0 +1,129 @@
+{
+// Copyright Quesma, licensed under the Elastic License 2.0.
+// SPDX-License-Identifier: Elastic-2.0
+package painful
+}
+
+
+Expr = expr:OpExpr / expr:MethodCall / expr:Accessor / expr:Doc / expr:Emit / expr:String {
+    return expr, nil
+}
+
+Emit = "emit" "(" _ expr:Expr _ ")" {
+
+    exprVal, err := ExpectExpr(expr)
+    if err != nil {
+        return nil, err
+    }
+
+    return &EmitExpr{Expr: exprVal}, nil
+}
+
+Doc = "doc" "[" key:Expr "]" {
+
+    exprVal, err := ExpectExpr(key)
+    if err != nil {
+        return nil, err
+    }
+
+    return &DocExpr{FieldName: exprVal}, nil
+}
+
+Accessor = expr:Expr "." field:Identifier {
+
+    exprVal, err := ExpectExpr(expr)
+    if err != nil {
+        return nil, err
+    }
+
+    strVal, err := ExpectString(field)
+    if err != nil {
+        return nil, err
+    }
+
+    return &AccessorExpr{Position: c.pos.String(), Expr: exprVal, PropertyName: strVal}, nil
+}
+
+MethodCall = expr:Expr "." method:Identifier "(" args:Expr* ','? ")" {
+
+    exprVal, err := ExpectExpr(expr)
+    if err != nil {
+        return nil, err
+    }
+
+    strVal, err := ExpectString(method)
+    if err != nil {
+        return nil, err
+    }
+
+    var argsVal []Expr
+
+    switch argsVals := args.(type) {
+
+    case nil:
+        argsVal = []Expr{}
+    case []any:
+
+        for _, arg := range argsVals {
+            argVal, err := ExpectExpr(arg)
+            if err != nil {
+                return nil, err
+            }
+            argsVal = append(argsVal, argVal)
+        }
+
+    default:
+        return nil, fmt.Errorf("internal parser error: '%T' is not a valid method argument", args)
+    }
+
+    return &MethodCallExpr{Position: c.pos.String(), Expr: exprVal, MethodName: strVal, Args: argsVal}, nil
+}
+
+OpExpr = left:Expr _ op:Op _ right:Expr {
+    leftVal, err := ExpectExpr(left)
+    if err != nil {
+        return nil, err
+    }
+
+    rightVal, err := ExpectExpr(right)
+    if err != nil {
+        return nil, err
+    }
+
+    opVal, err := ExpectString(op)
+    if err != nil {
+        return nil, err
+    }
+
+    return &InfixOpExpr{Position: c.pos.String(), Left: leftVal, Op: opVal, Right: rightVal}, nil
+}
+
+Op = op:"+" {
+    return string(c.text), nil
+}
+
+String = '\'' s:[^']* '\'' {
+
+    strVal := string(c.text)
+    strVal = strings.Trim(strVal, "'")
+    return &LiteralExpr{Value: strVal}, nil
+}
+
+Identifier = id:[a-zA-Z0-9_]+ {
+    return string(c.text), nil
+}
+
+_ "whitespace" <- [ \n\t\r]*
+
+EOF
+  = !.
+
diff --git a/quesma/painful/painless_test.go b/quesma/painful/painless_test.go
new file mode 100644
index 000000000..592928961
--- /dev/null
+++ b/quesma/painful/painless_test.go
@@ -0,0 +1,101 @@
+// Copyright Quesma, licensed under the Elastic License 2.0.
+// SPDX-License-Identifier: Elastic-2.0
+package painful
+
+import (
+    "reflect"
+    "testing"
+)
+
+func TestPainless(t *testing.T) {
+
+    tests := []struct {
+        name   string
+        input  map[string]any
+        script string
+        output any
+    }{
+        {
+            name: "simple emit",
+            input: map[string]any{
+                "field": 42,
+            },
+            script: "emit(doc['field'].value)",
+            output: 42,
+        },
+
+        {
+            name: "concat",
+            input: map[string]any{
+                "foo": "a",
+                "bar": "b",
+            },
+            script: "emit(doc['foo'].value + doc['bar'].value)",
+            output: "ab",
+        },
+
+        {
+            name:   "concat strings",
+            input:  map[string]any{},
+            script: "emit('a' + 'b')",
+            output: "ab",
+        },
+
+        {
+            name: "concat date literal and string",
+            input: map[string]any{
+                "@timestamp": "2022-09-22T12:16:59.985Z",
+                "uuid":       "1234",
+            },
+            script: "emit(doc['@timestamp'].value + '&' + doc['uuid'].value)",
+            output: "2022-09-22T12:16:59.985Z&1234",
+        },
+
+        {
+            name: "get hour from date",
+            input: map[string]any{
+                "@timestamp": "2022-09-22T12:16:59.98Z",
+            },
+            script: "emit(doc['@timestamp'].value.getHour())",
+            output: 12,
+        },
+
+        {
+            name: "format date with ISO",
+            input: map[string]any{
+                "@timestamp": "2022-09-22T12:16:59.98Z",
+            },
+            script: "emit(doc['@timestamp'].value.formatISO8601())",
+            output: "2022-09-22T12:16:59Z",
+        },
+    }
+
+    for _, tt := range tests {
+        t.Run(tt.name, func(t *testing.T) {
+            res, err := ParsePainless(tt.script)
+            if err != nil {
+                t.Fatal(err)
+            }
+
+            env := &Env{
+                Doc: tt.input,
+            }
+
+            switch expr := res.(type) {
+            case Expr:
+
+                _, err := expr.Eval(env)
+                if err != nil {
+                    t.Fatal(err)
+                }
+
+                if !reflect.DeepEqual(tt.output, env.EmitValue) {
+                    t.Errorf("expected %v, got %v", tt.output, env.EmitValue)
+                }
+
+            default:
+                t.Fatal("not an expression")
+            }
+        })
+    }
+}
diff --git a/quesma/painful/rest_api.go b/quesma/painful/rest_api.go
new file mode 100644
index 000000000..474819700
--- /dev/null
+++ b/quesma/painful/rest_api.go
@@ -0,0 +1,111 @@
+// Copyright Quesma, licensed under the Elastic License 2.0.
+// SPDX-License-Identifier: Elastic-2.0
+package painful
+
+import (
+    "net/http"
+    "quesma/quesma/types"
+)
+
+type ScriptRequest struct {
+    Context string `json:"context"`
+    Script  struct {
+        Source string `json:"source"`
+    } `json:"script"`
+
+    ContextSetup struct {
+        Document  types.JSON `json:"document"`
+        IndexName string     `json:"index"`
+    } `json:"context_setup"`
+}
+
+type ScriptResponse struct {
+    Result []any `json:"result"`
+}
+
+type ScriptErrorErrorElement struct {
+    Lang     string `json:"lang"`
+    Position struct {
+        End    int `json:"end"`
+        Offset int `json:"offset"`
+        Start  int `json:"start"`
+    } `json:"position"`
+    Reason      string                    `json:"reason"`
+    Script      string                    `json:"script"`
+    ScriptStack []string                  `json:"script_stack"`
+    Type        string                    `json:"type"`
+    RootCause   []ScriptErrorErrorElement `json:"root_cause"`
+}
+
+type ScriptErrorResponse struct {
+    Error struct {
+        CausedBy struct {
+            Reason string `json:"reason"`
+            Type   string `json:"type"`
+        } `json:"caused_by"`
+        Lang     string `json:"lang"`
+        Position struct {
+            End    int `json:"end"`
+            Offset int `json:"offset"`
+            Start  int `json:"start"`
+        } `json:"position"`
+        Reason      string                    `json:"reason"`
+        RootCause   []ScriptErrorErrorElement `json:"root_cause"`
+        Script      string                    `json:"script"`
+        ScriptStack []string                  `json:"script_stack"`
+        Type        string                    `json:"type"`
+    } `json:"error"`
+    Status int `json:"status"`
+}
+
+func RenderErrorResponse(script string, err error) ScriptErrorResponse {
+    res := ScriptErrorResponse{}
+
+    rootCause := ScriptErrorErrorElement{}
+    rootCause.Reason = err.Error()
+    rootCause.Type = "script_exception"
+    rootCause.Lang = "painless"
+    rootCause.Position.Start = 0
+    rootCause.Position.End = 0
+    rootCause.Position.Offset = 0
+    rootCause.Script = script
+    rootCause.ScriptStack = []string{script}
+
+    res.Error.CausedBy.Reason = err.Error()
+    res.Error.CausedBy.Type = "illegal_argument_exception"
+    res.Error.Lang = "painless"
+
+    res.Error.Position.Start = 0
+    res.Error.Position.End = 0
+    res.Error.Position.Offset = 0
+
+    res.Error.Type = "script_exception"
+    res.Error.Reason = "compile error"
+
+    res.Error.RootCause = []ScriptErrorErrorElement{rootCause}
+
+    res.Status = http.StatusBadRequest
+
+    return res
+}
+
+func (s ScriptRequest) Eval() (res ScriptResponse, err error) {
+    env := &Env{
+        Doc: s.ContextSetup.Document,
+    }
+
+    evalTree, err := ParsePainless(s.Script.Source)
+    if err != nil {
+        return res, err
+    }
+
+    _, err = evalTree.Eval(env)
+    if err != nil {
+        return res, err
+    }
+
+    res.Result = []any{env.EmitValue}
+
+    return res, nil
+}
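For context, a minimal sketch (not part of this diff; the request body is made up) of how the script endpoint types are meant to fit together: the body is unmarshalled into ScriptRequest, as the new matcher does, evaluated locally, and failures are rendered in the Elasticsearch-style error shape via RenderErrorResponse.

```go
package main

import (
	"fmt"

	"github.com/goccy/go-json"
	"quesma/painful"
)

func main() {
	// hypothetical _scripts/painless/execute-style body
	body := `{
		"script": {"source": "emit(doc['@timestamp'].value.getHour())"},
		"context_setup": {"index": "my-index", "document": {"@timestamp": "2022-09-22T12:16:59.98Z"}}
	}`

	var req painful.ScriptRequest
	if err := json.Unmarshal([]byte(body), &req); err != nil {
		panic(err)
	}

	res, err := req.Eval()
	if err != nil {
		// error shape matches what Elasticsearch clients expect for script failures
		fmt.Printf("%+v\n", painful.RenderErrorResponse(req.Script.Source, err))
		return
	}
	fmt.Println(res.Result) // [12]
}
```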
diff --git a/quesma/queryparser/query_parser.go b/quesma/queryparser/query_parser.go
index 8619c4304..33fc39081 100644
--- a/quesma/queryparser/query_parser.go
+++ b/quesma/queryparser/query_parser.go
@@ -71,7 +71,10 @@ func (cw *ClickhouseQueryTranslator) ParseQuery(body types.JSON) (*model.Executi
 		queries = append(queries, listQuery)
 	}
 
-	runtimeMappings := ParseRuntimeMappings(body) // we apply post query transformer for certain aggregation types
+	runtimeMappings, err := ParseRuntimeMappings(body) // we apply post query transformer for certain aggregation types
+	if err != nil {
+		return &model.ExecutionPlan{}, err
+	}
 
 	// we apply post query transformer for certain aggregation types
 	// this should be a part of the query parsing process
diff --git a/quesma/queryparser/runtime_mappings.go b/quesma/queryparser/runtime_mappings.go
index f7d6fe00a..697de6e91 100644
--- a/quesma/queryparser/runtime_mappings.go
+++ b/quesma/queryparser/runtime_mappings.go
@@ -4,10 +4,11 @@ package queryparser
 
 import (
 	"quesma/model"
+	"quesma/painful"
 	"quesma/quesma/types"
 )
 
-func ParseRuntimeMappings(body types.JSON) map[string]model.RuntimeMapping {
+func ParseRuntimeMappings(body types.JSON) (map[string]model.RuntimeMapping, error) {
 
 	result := make(map[string]model.RuntimeMapping)
 
@@ -27,28 +28,43 @@ func ParseRuntimeMappings(body types.JSON) map[string]model.RuntimeMapping {
 				if scriptAsMap, ok := script.(map[string]interface{}); ok {
 					if source, ok := scriptAsMap["source"]; ok {
 						if sourceAsString, ok := source.(string); ok {
-							mapping.Expr = ParseScript(sourceAsString)
+
+							dbExpr, postProcessExpr, err := ParseScript(sourceAsString)
+
+							if err != nil {
+								return nil, err
+							}
+
+							mapping.DatabaseExpression = dbExpr
+							mapping.PostProcessExpression = postProcessExpr
 						}
 					}
 				}
 			}
-			if mapping.Expr != nil {
+			if mapping.DatabaseExpression != nil {
 				result[k] = mapping
 			}
 		}
 	}
 	}
-	return result
+	return result, nil
 }
 
-func ParseScript(s string) model.Expr {
+func ParseScript(s string) (model.Expr, painful.Expr, error) {
 	// TODO: add a real parser here
 	if s == "emit(doc['timestamp'].value.getHour());" {
-		return model.NewFunction(model.DateHourFunction, model.NewColumnRef(model.TimestampFieldName))
+		return model.NewFunction(model.DateHourFunction, model.NewColumnRef(model.TimestampFieldName)), nil, nil
 	}
 
-	// harmless default
-	return model.NewLiteral("NULL")
+	expr, err := painful.ParsePainless(s)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	// TODO here we can transform the parsed expression to an SQL
+
+	// we return an empty SQL expression for given field, it'll make a column in the result set
+	return model.NewLiteral("NULL"), expr, nil
 }
diff --git a/quesma/quesma/matchers.go b/quesma/quesma/matchers.go
index 8a3f7d838..f2528fc9d 100644
--- a/quesma/quesma/matchers.go
+++ b/quesma/quesma/matchers.go
@@ -3,7 +3,9 @@
 package quesma
 
 import (
+	"github.com/goccy/go-json"
 	"quesma/logger"
+	"quesma/painful"
 	"quesma/quesma/config"
 	"quesma/quesma/types"
 	"quesma/table_resolver"
@@ -145,3 +147,28 @@ func matchAgainstKibanaInternal() quesma_api.RequestMatcher {
 		return quesma_api.MatchResult{Matched: matched}
 	})
 }
+
+func matchAgainstIndexNameInScriptRequestBody(tableResolver table_resolver.TableResolver) quesma_api.RequestMatcher {
+	return quesma_api.RequestMatcherFunc(func(req *quesma_api.Request) quesma_api.MatchResult {
+
+		var scriptRequest painful.ScriptRequest
+
+		err := json.Unmarshal([]byte(req.Body), &scriptRequest)
+		if err != nil {
+			return quesma_api.MatchResult{Matched: false}
+		}
+
+		decision := tableResolver.Resolve(quesma_api.QueryPipeline, scriptRequest.ContextSetup.IndexName)
+
+		if decision.Err != nil {
+			return quesma_api.MatchResult{Matched: false, Decision: decision}
+		}
+		for _, connector := range decision.UseConnectors {
+			if _, ok := connector.(*quesma_api.ConnectorDecisionClickhouse); ok {
+				return quesma_api.MatchResult{Matched: true, Decision: decision}
+			}
+		}
+
+		return quesma_api.MatchResult{Matched: false, Decision: nil}
+	})
+}
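For context, not part of the patch: a hedged sketch of the kind of _scripts/painless/_execute body this matcher inspects. The index name and script are made up; only the fields declared on ScriptRequest matter here.

package quesma

import (
	"fmt"

	"github.com/goccy/go-json"
	"quesma/painful"
)

func exampleScriptRequestBody() {
	// Illustrative payload as Kibana might send it when previewing a scripted field.
	body := `{
		"script": {"source": "emit(doc['timestamp'].value.getHour());"},
		"context_setup": {
			"index": "my_index",
			"document": {"timestamp": "2024-01-01T12:34:56Z"}
		}
	}`

	var sr painful.ScriptRequest
	if err := json.Unmarshal([]byte(body), &sr); err != nil {
		return
	}
	// The matcher only needs the index name to ask the table resolver
	// whether this index is served by the ClickHouse connector.
	fmt.Println(sr.ContextSetup.IndexName) // "my_index"
}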
diff --git a/quesma/quesma/router.go b/quesma/quesma/router.go
index 55a73e2b3..c19dab2c8 100644
--- a/quesma/quesma/router.go
+++ b/quesma/quesma/router.go
@@ -12,6 +12,7 @@ import (
 	"quesma/end_user_errors"
 	"quesma/ingest"
 	"quesma/logger"
+	"quesma/painful"
 	"quesma/queryparser"
 	"quesma/quesma/config"
 	"quesma/quesma/errors"
@@ -58,7 +59,43 @@ func ConfigureRouter(cfg *config.QuesmaConfiguration, sr schema.Registry, lm *cl
 	// So, if you add multiple handlers with the same path, the first one will be used, the rest will be redirected to the elastic cluster.
 	// This is current limitation of the router.
+	router.Register(routes.ExecutePainlessScriptPath, and(method("POST"), matchAgainstIndexNameInScriptRequestBody(tableResolver)), func(ctx context.Context, req *quesma_api.Request) (*quesma_api.Result, error) {
+
+		var scriptRequest painful.ScriptRequest
+
+		err := json.Unmarshal([]byte(req.Body), &scriptRequest)
+		if err != nil {
+			return nil, err
+		}
+
+		scriptResponse, err := scriptRequest.Eval()
+
+		if err != nil {
+			errorResponse := painful.RenderErrorResponse(scriptRequest.Script.Source, err)
+			responseBytes, err := json.Marshal(errorResponse)
+			if err != nil {
+				return nil, err
+			}
+
+			return &quesma_api.Result{
+				Body:       string(responseBytes),
+				StatusCode: errorResponse.Status,
+			}, nil
+		}
+
+		responseBytes, err := json.Marshal(scriptResponse)
+		if err != nil {
+			return nil, err
+		}
+
+		return &quesma_api.Result{
+			Body:       string(responseBytes),
+			StatusCode: http.StatusOK,
+		}, nil
+	})
+
 	router.Register(routes.ClusterHealthPath, method("GET"), func(_ context.Context, req *quesma_api.Request) (*quesma_api.Result, error) {
+
 		return elasticsearchQueryResult(`{"cluster_name": "quesma"}`, http.StatusOK), nil
 	})
diff --git a/quesma/quesma/router_v2.go b/quesma/quesma/router_v2.go
index b96bb6835..d62ab075e 100644
--- a/quesma/quesma/router_v2.go
+++ b/quesma/quesma/router_v2.go
@@ -11,6 +11,7 @@ import (
 	"quesma/elasticsearch"
 	"quesma/ingest"
 	"quesma/logger"
+	"quesma/painful"
 	"quesma/queryparser"
 	"quesma/quesma/config"
 	"quesma/quesma/errors"
@@ -44,6 +45,42 @@ func ConfigureIngestRouterV2(cfg *config.QuesmaConfiguration, ip *ingest.IngestP
 	for _, path := range elasticsearch.InternalPaths {
 		router.Register(path, quesma_api.Never(), func(ctx context.Context, req *quesma_api.Request) (*quesma_api.Result, error) { return nil, nil })
 	}
+
+	router.Register(routes.ExecutePainlessScriptPath, and(method("POST"), matchAgainstIndexNameInScriptRequestBody(tableResolver)), func(ctx context.Context, req *quesma_api.Request) (*quesma_api.Result, error) {
+
+		var scriptRequest painful.ScriptRequest
+
+		err := json.Unmarshal([]byte(req.Body), &scriptRequest)
+		if err != nil {
+			return nil, err
+		}
+
+		scriptResponse, err := scriptRequest.Eval()
+
+		if err != nil {
+			errorResponse := painful.RenderErrorResponse(scriptRequest.Script.Source, err)
+			responseBytes, err := json.Marshal(errorResponse)
+			if err != nil {
+				return nil, err
+			}
+
+			return &quesma_api.Result{
+				Body:       string(responseBytes),
+				StatusCode: errorResponse.Status,
+			}, nil
+		}
+
+		responseBytes, err := json.Marshal(scriptResponse)
+		if err != nil {
+			return nil, err
+		}
+
+		return &quesma_api.Result{
+			Body:       string(responseBytes),
+			StatusCode: http.StatusOK,
+		}, nil
+	})
+
 	router.Register(routes.BulkPath, and(method("POST", "PUT"), matchedAgainstBulkBody(cfg, tableResolver)), func(ctx context.Context, req *quesma_api.Request) (*quesma_api.Result, error) {
 		body, err := types.ExpectNDJSON(req.ParsedBody)
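For reference, not part of the patch: a sketch of the two payload shapes the handler above can produce — a success body marshalled from ScriptResponse, and the Elasticsearch-style error body built by RenderErrorResponse. The values are illustrative.

package quesma

import (
	"errors"
	"fmt"

	"github.com/goccy/go-json"
	"quesma/painful"
)

func exampleScriptResponses() {
	// Success: the handler marshals ScriptResponse, e.g. {"result":[12]}.
	ok, _ := json.Marshal(painful.ScriptResponse{Result: []any{12}})
	fmt.Println(string(ok))

	// Failure: a 400 body whose root cause carries the original script text.
	bad := painful.RenderErrorResponse("emit(doc['x'", errors.New("parse error"))
	body, _ := json.Marshal(bad)
	fmt.Println(bad.Status, string(body))
}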
diff --git a/quesma/quesma/schema_transformer.go b/quesma/quesma/schema_transformer.go
index 164deba1c..7e0f1b8f0 100644
--- a/quesma/quesma/schema_transformer.go
+++ b/quesma/quesma/schema_transformer.go
@@ -669,7 +669,7 @@ func (s *SchemaCheckPass) applyRuntimeMappings(indexSchema schema.Schema, query
 		switch c := col.(type) {
 		case model.ColumnRef:
 			if mapping, ok := query.RuntimeMappings[c.ColumnName]; ok {
-				cols[i] = model.NewAliasedExpr(mapping.Expr, c.ColumnName)
+				cols[i] = model.NewAliasedExpr(mapping.DatabaseExpression, c.ColumnName)
 			}
 		}
 	}
@@ -679,7 +679,7 @@ func (s *SchemaCheckPass) applyRuntimeMappings(indexSchema schema.Schema, query
 	visitor := model.NewBaseVisitor()
 	visitor.OverrideVisitColumnRef = func(b *model.BaseExprVisitor, e model.ColumnRef) interface{} {
 		if mapping, ok := query.RuntimeMappings[e.ColumnName]; ok {
-			return mapping.Expr
+			return mapping.DatabaseExpression
 		}
 		return e
 	}
diff --git a/quesma/quesma/search.go b/quesma/quesma/search.go
index 19a296481..ab0303130 100644
--- a/quesma/quesma/search.go
+++ b/quesma/quesma/search.go
@@ -15,6 +15,7 @@ import (
 	"quesma/logger"
 	"quesma/model"
 	"quesma/optimize"
+	"quesma/painful"
 	"quesma/queryparser"
 	"quesma/queryparser/query_util"
 	"quesma/quesma/async_search_storage"
@@ -846,11 +847,33 @@ func (q *QueryRunner) postProcessResults(plan *model.ExecutionPlan, results [][]
 	// maybe model.Schema should be part of ExecutionPlan instead of Query
 	indexSchema := plan.Queries[0].Schema
 
-	pipeline := []struct {
+	type pipelineElement struct {
 		name        string
 		transformer model.ResultTransformer
-	}{
-		{"replaceColumNamesWithFieldNames", &replaceColumNamesWithFieldNames{indexSchema: indexSchema}},
+	}
+
+	var pipeline []pipelineElement
+
+	pipeline = append(pipeline, pipelineElement{"replaceColumNamesWithFieldNames", &replaceColumNamesWithFieldNames{indexSchema: indexSchema}})
+
+	// we can take the first one because all queries have the same runtime mappings
+	if len(plan.Queries[0].RuntimeMappings) > 0 {
+
+		// this transformer must be called after replaceColumNamesWithFieldNames
+		// painless scripts rely on field names not column names
+
+		fieldScripts := make(map[string]painful.Expr)
+
+		for field, runtimeMapping := range plan.Queries[0].RuntimeMappings {
+			if runtimeMapping.PostProcessExpression != nil {
+				fieldScripts[field] = runtimeMapping.PostProcessExpression
+			}
+		}
+
+		if len(fieldScripts) > 0 {
+			pipeline = append(pipeline, pipelineElement{"applyPainlessScripts", &EvalPainlessScriptOnColumnsTransformer{FieldScripts: fieldScripts}})
+		}
+	}
 
 	var err error
diff --git a/quesma/quesma/transformations.go b/quesma/quesma/transformations.go
index 3557ffe76..e4b7f126a 100644
--- a/quesma/quesma/transformations.go
+++ b/quesma/quesma/transformations.go
@@ -4,6 +4,7 @@ package quesma
 
 import (
 	"quesma/model"
+	"quesma/painful"
 	"quesma/schema"
 )
@@ -41,3 +42,35 @@ func (t *replaceColumNamesWithFieldNames) Transform(result [][]model.QueryResult
 	}
 	return result, nil
 }
+
+type EvalPainlessScriptOnColumnsTransformer struct {
+	FieldScripts map[string]painful.Expr
+}
+
+func (t *EvalPainlessScriptOnColumnsTransformer) Transform(result [][]model.QueryResultRow) ([][]model.QueryResultRow, error) {
+
+	for _, rows := range result {
+		for _, row := range rows {
+			doc := make(map[string]any)
+			for j := range row.Cols {
+				doc[row.Cols[j].ColName] = row.Cols[j].Value
+			}
+
+			for j := range row.Cols {
+
+				if script, exists := t.FieldScripts[row.Cols[j].ColName]; exists {
+					env := &painful.Env{
+						Doc: doc,
+					}
+
+					_, err := script.Eval(env)
+					if err != nil {
+						return nil, err
+					}
+					row.Cols[j].Value = env.EmitValue
+				}
+			}
+		}
+	}
+	return result, nil
+}
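A minimal sketch, not part of the patch, of what this transformer does for a single row: the row is exposed to the script as a doc map keyed by field name, and the scripted column is overwritten with the emitted value. The script source and field names are hypothetical.

package quesma

import (
	"fmt"

	"quesma/painful"
)

func examplePostProcessOneRow() {
	// A runtime-mapped field that the SQL layer left as a NULL column.
	doc := map[string]any{
		"timestamp":   "2024-01-01T12:34:56Z",
		"hour_of_day": nil,
	}

	// In the real flow this expression comes from RuntimeMapping.PostProcessExpression.
	script, err := painful.ParsePainless("emit(doc['timestamp'].value.getHour());")
	if err != nil {
		return
	}

	env := &painful.Env{Doc: doc}
	if _, err := script.Eval(env); err != nil {
		return
	}
	doc["hour_of_day"] = env.EmitValue // what Transform writes back into the column
	fmt.Println(doc["hour_of_day"])
}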
"/:index/_bulk" - IndexMappingPath = "/:index/_mapping" - FieldCapsPath = "/:index/_field_caps" - TermsEnumPath = "/:index/_terms_enum" - EQLSearch = "/:index/_eql/search" - ResolveIndexPath = "/_resolve/index/:index" - ClusterHealthPath = "/_cluster/health" - BulkPath = "/_bulk" - AsyncSearchIdPrefix = "/_async_search/" - AsyncSearchIdPath = "/_async_search/:id" - AsyncSearchStatusPath = "/_async_search/status/:id" - KibanaInternalPrefix = "/.kibana_" - IndexPath = "/:index" + GlobalSearchPath = "/_search" + IndexSearchPath = "/:index/_search" + IndexAsyncSearchPath = "/:index/_async_search" + IndexCountPath = "/:index/_count" + IndexDocPath = "/:index/_doc" + IndexRefreshPath = "/:index/_refresh" + IndexBulkPath = "/:index/_bulk" + IndexMappingPath = "/:index/_mapping" + FieldCapsPath = "/:index/_field_caps" + TermsEnumPath = "/:index/_terms_enum" + EQLSearch = "/:index/_eql/search" + ResolveIndexPath = "/_resolve/index/:index" + ClusterHealthPath = "/_cluster/health" + BulkPath = "/_bulk" + AsyncSearchIdPrefix = "/_async_search/" + AsyncSearchIdPath = "/_async_search/:id" + AsyncSearchStatusPath = "/_async_search/status/:id" + KibanaInternalPrefix = "/.kibana_" + IndexPath = "/:index" + ExecutePainlessScriptPath = "/_scripts/painless/_execute" // This path is used on the Kibana side to evaluate painless scripts when adding a new scripted field. // Quesma internal paths