From 5cb86d5f4f7e0a2814fb0220ac32891c20651b68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Przemys=C5=82aw=20Hejman?= Date: Thu, 23 May 2024 08:46:17 +0200 Subject: [PATCH] `WHERE` clause refactor (#176) --- quesma/model/simple_query.go | 23 ++- quesma/model/simple_query_test.go | 11 +- quesma/queryparser/query_parser.go | 117 +++++++++++---- quesma/queryparser/query_parser_test.go | 38 +++++ .../where_clause/string_renderer.go | 70 +++++++++ .../queryparser/where_clause/used_fields.go | 64 +++++++++ .../queryparser/where_clause/where_clause.go | 133 ++++++++++++++++++ 7 files changed, 420 insertions(+), 36 deletions(-) create mode 100644 quesma/queryparser/where_clause/string_renderer.go create mode 100644 quesma/queryparser/where_clause/used_fields.go create mode 100644 quesma/queryparser/where_clause/where_clause.go diff --git a/quesma/model/simple_query.go b/quesma/model/simple_query.go index 97c990414..e4e26cedc 100644 --- a/quesma/model/simple_query.go +++ b/quesma/model/simple_query.go @@ -3,6 +3,7 @@ package model import ( "context" "mitmproxy/quesma/logger" + "mitmproxy/quesma/queryparser/where_clause" ) type SimpleQuery struct { @@ -21,9 +22,11 @@ func NewSimpleQueryWithFieldName(sql Statement, canParse bool, fieldName string) } type Statement struct { - Stmt string - IsCompound bool // "a" -> not compound, "a AND b" -> compound. Used to not make unnecessary brackets (not always, but usually) - FieldName string + // deprecated - we're moving to the new WhereStatement which should also remove the need for IsCompound and FieldName + Stmt string // Old, clunky and soon to be deprecated version + WhereStatement where_clause.Statement // New, better and bold version + IsCompound bool // "a" -> not compound, "a AND b" -> compound. Used to not make unnecessary brackets (not always, but usually) + FieldName string } func NewSimpleStatement(stmt string) Statement { @@ -52,6 +55,13 @@ func Or(orStmts []Statement) Statement { // sep = "AND" or "OR" func combineStatements(stmts []Statement, sep string) Statement { stmts = FilterNonEmpty(stmts) + var newWhereStatement where_clause.Statement + if len(stmts) > 0 { + newWhereStatement = stmts[0].WhereStatement + for _, stmt := range stmts[1:] { + newWhereStatement = where_clause.NewInfixOp(newWhereStatement, sep, stmt.WhereStatement) + } + } if len(stmts) > 1 { stmts = quoteWithBracketsIfCompound(stmts) var fieldName string @@ -65,7 +75,12 @@ func combineStatements(stmts []Statement, sep string) Statement { fieldName = stmt.FieldName } } - return NewCompoundStatement(sql, fieldName) + return Statement{ + WhereStatement: newWhereStatement, + Stmt: sql, + IsCompound: true, + FieldName: fieldName, + } } if len(stmts) == 1 { return stmts[0] diff --git a/quesma/model/simple_query_test.go b/quesma/model/simple_query_test.go index 9351d2c00..4fe4f0f3f 100644 --- a/quesma/model/simple_query_test.go +++ b/quesma/model/simple_query_test.go @@ -63,16 +63,23 @@ func TestOrAndAnd(t *testing.T) { t.Run("AND "+strconv.Itoa(i), func(t *testing.T) { b := make([]Statement, len(tt.stmts)) copy(b, tt.stmts) - assert.Equal(t, tt.want, And(b)) + tt.want.WhereStatement = nil + finalAnd := And(b) + finalAnd.WhereStatement = nil + assert.Equal(t, tt.want, finalAnd) }) } for i, tt := range tests { t.Run("OR "+strconv.Itoa(i), func(t *testing.T) { + tt.want.WhereStatement = nil tt.want.Stmt = strings.ReplaceAll(tt.want.Stmt, "AND", "OR") for i := range tt.stmts { tt.stmts[i].Stmt = strings.ReplaceAll(tt.stmts[i].Stmt, "AND", "OR") } - assert.Equal(t, tt.want, Or(tt.stmts)) + tt.want.WhereStatement = nil + finalOr := Or(tt.stmts) + finalOr.WhereStatement = nil + assert.Equal(t, tt.want, finalOr) }) } } diff --git a/quesma/queryparser/query_parser.go b/quesma/queryparser/query_parser.go index 22d5569c5..f9d3c4d60 100644 --- a/quesma/queryparser/query_parser.go +++ b/quesma/queryparser/query_parser.go @@ -2,6 +2,8 @@ package queryparser import ( "encoding/json" + "mitmproxy/quesma/queryparser/where_clause" + "fmt" "github.com/k0kubun/pp" "mitmproxy/quesma/clickhouse" @@ -240,7 +242,9 @@ func (cw *ClickhouseQueryTranslator) ParseAutocomplete(indexFilter *QueryMap, fi like = "LIKE" } cw.AddTokenToHighlight(*prefix) - stmts = append(stmts, model.NewSimpleStatement(fieldName+" "+like+" '"+*prefix+"%'")) + simpleStat := model.NewSimpleStatement(fieldName + " " + like + " '" + *prefix + "%'") + simpleStat.WhereStatement = where_clause.NewInfixOp(where_clause.NewColumnRef(fieldName), like, where_clause.NewLiteral("'"+*prefix+"%'")) + stmts = append(stmts, simpleStat) } return model.NewSimpleQuery(model.And(stmts), canParse) } @@ -333,19 +337,24 @@ func (cw *ClickhouseQueryTranslator) parseIds(queryMap QueryMap) model.SimpleQue } } - var statement string + var statement model.Statement if v, ok := cw.Table.Cols[timestampColumnName]; ok { switch v.Type.String() { case clickhouse.DateTime64.String(): - statement = fmt.Sprintf("toUnixTimestamp64Milli(%s) IN (%s)", strconv.Quote(timestampColumnName), ids) + statement = model.NewSimpleStatement(fmt.Sprintf("toUnixTimestamp64Milli(%s) IN (%s)", strconv.Quote(timestampColumnName), ids)) + statement.WhereStatement = where_clause.NewInfixOp(where_clause.NewFunction("toUnixTimestamp64Milli", []where_clause.Statement{where_clause.NewColumnRef(timestampColumnName)}...), "IN", where_clause.NewLiteral("(["+strings.Join(ids, ",")+"])")) case clickhouse.DateTime.String(): - statement = fmt.Sprintf("toUnixTimestamp(%s) *1000 IN (%s)", strconv.Quote(timestampColumnName), ids) + statement = model.NewSimpleStatement(fmt.Sprintf("toUnixTimestamp(%s) * 1000 IN (%s)", strconv.Quote(timestampColumnName), ids)) + statement.WhereStatement = where_clause.NewInfixOp(where_clause.NewInfixOp( + where_clause.NewFunction("toUnixTimestamp", []where_clause.Statement{where_clause.NewColumnRef(timestampColumnName)}...), + "*", + where_clause.NewLiteral("1000")), "IN", where_clause.NewLiteral("("+strings.Join(ids, ",")+")")) default: logger.Warn().Msgf("timestamp field of unsupported type %s", v.Type.String()) return model.NewSimpleQuery(model.NewSimpleStatement(""), true) } } - return model.NewSimpleQuery(model.NewSimpleStatement(statement), true) + return model.NewSimpleQuery(statement, true) } // Parses each model.SimpleQuery separately, returns list of translated SQLs @@ -426,6 +435,7 @@ func (cw *ClickhouseQueryTranslator) parseBool(queryMap QueryMap) model.SimpleQu canParse = canParse && canParseThis if len(sqlNots) > 0 { orSql := model.Or(sqlNots) + orSql.WhereStatement = where_clause.NewPrefixOp("NOT", []where_clause.Statement{orSql.WhereStatement}) if orSql.IsCompound { orSql.Stmt = "NOT (" + orSql.Stmt + ")" orSql.IsCompound = false // NOT (compound) is again simple @@ -444,9 +454,13 @@ func (cw *ClickhouseQueryTranslator) parseTerm(queryMap QueryMap) model.SimpleQu cw.AddTokenToHighlight(v) if k == "_index" { // index is a table name, already taken from URI and moved to FROM clause logger.Warn().Msgf("term %s=%v in query body, ignoring in result SQL", k, v) - return model.NewSimpleQuery(model.NewSimpleStatement(" 0=0 /* "+strconv.Quote(k)+"="+sprint(v)+" */ "), true) + simpleStat := model.NewSimpleStatement(" 0=0 /* " + strconv.Quote(k) + "=" + sprint(v) + " */ ") + simpleStat.WhereStatement = where_clause.NewInfixOp(where_clause.NewLiteral("0"), "=", where_clause.NewLiteral("0 /* "+k+"="+sprint(v)+" */")) + return model.NewSimpleQuery(simpleStat, true) } - return model.NewSimpleQuery(model.NewSimpleStatement(strconv.Quote(k)+"="+sprint(v)), true) + simpleStat := model.NewSimpleStatement(strconv.Quote(k) + "=" + sprint(v)) + simpleStat.WhereStatement = where_clause.NewInfixOp(where_clause.NewColumnRef(k), "=", where_clause.NewLiteral(sprint(v))) + return model.NewSimpleQuery(simpleStat, true) } } logger.WarnWithCtx(cw.Ctx).Msgf("we expect only 1 term, got: %d. value: %v", len(queryMap), queryMap) @@ -474,7 +488,9 @@ func (cw *ClickhouseQueryTranslator) parseTerms(queryMap QueryMap) model.SimpleQ orStmts := make([]model.Statement, len(vAsArray)) for i, v := range vAsArray { cw.AddTokenToHighlight(v) - orStmts[i] = model.NewSimpleStatement(strconv.Quote(k) + "=" + sprint(v)) + simpleStat := model.NewSimpleStatement(strconv.Quote(k) + "=" + sprint(v)) + simpleStat.WhereStatement = where_clause.NewInfixOp(where_clause.NewColumnRef(k), "=", where_clause.NewLiteral(sprint(v))) + orStmts[i] = simpleStat } return model.NewSimpleQuery(model.Or(orStmts), true) } @@ -527,7 +543,9 @@ func (cw *ClickhouseQueryTranslator) parseMatch(queryMap QueryMap, matchPhrase b computedIdMatchingQuery := cw.parseIds(QueryMap{"values": []interface{}{subQuery}}) statements = append(statements, computedIdMatchingQuery.Sql) } else { - statements = append(statements, model.NewSimpleStatement(strconv.Quote(fieldName)+" iLIKE "+"'%"+subQuery+"%'")) + simpleStat := model.NewSimpleStatement(strconv.Quote(fieldName) + " iLIKE " + "'%" + subQuery + "%'") + simpleStat.WhereStatement = where_clause.NewInfixOp(where_clause.NewColumnRef(fieldName), "iLIKE", where_clause.NewLiteral("'%"+subQuery+"%'")) + statements = append(statements, simpleStat) } } return model.NewSimpleQuery(model.Or(statements), true) @@ -557,19 +575,21 @@ func (cw *ClickhouseQueryTranslator) parseMultiMatch(queryMap QueryMap) model.Si } else { fields = cw.Table.GetFulltextFields() } + alwaysFalseStmt := model.AlwaysFalseStatement + alwaysFalseStmt.WhereStatement = where_clause.NewLiteral("false") if len(fields) == 0 { - return model.NewSimpleQuery(model.AlwaysFalseStatement, true) + return model.NewSimpleQuery(alwaysFalseStmt, true) } query, ok := queryMap["query"] if !ok { logger.WarnWithCtx(cw.Ctx).Msgf("no query in multi_match query: %v", queryMap) - return model.NewSimpleQuery(model.AlwaysFalseStatement, false) + return model.NewSimpleQuery(alwaysFalseStmt, false) } queryAsString, ok := query.(string) if !ok { logger.WarnWithCtx(cw.Ctx).Msgf("invalid query type: %T, value: %v", query, query) - return model.NewSimpleQuery(model.AlwaysFalseStatement, false) + return model.NewSimpleQuery(alwaysFalseStmt, false) } var subQueries []string wereDone := false @@ -595,7 +615,9 @@ func (cw *ClickhouseQueryTranslator) parseMultiMatch(queryMap QueryMap) model.Si i := 0 for _, field := range fields { for _, subQ := range subQueries { - sqls[i] = model.NewSimpleStatement(strconv.Quote(field) + " iLIKE '%" + subQ + "%'") + simpleStat := model.NewSimpleStatement(strconv.Quote(field) + " iLIKE '%" + subQ + "%'") + simpleStat.WhereStatement = where_clause.NewInfixOp(where_clause.NewColumnRef(field), "iLIKE", where_clause.NewLiteral("'%"+subQ+"%'")) + sqls[i] = simpleStat i++ } } @@ -614,11 +636,15 @@ func (cw *ClickhouseQueryTranslator) parsePrefix(queryMap QueryMap) model.Simple switch vCasted := v.(type) { case string: cw.AddTokenToHighlight(vCasted) - return model.NewSimpleQuery(model.NewSimpleStatement(strconv.Quote(fieldName)+" iLIKE '"+vCasted+"%'"), true) + simpleStat := model.NewSimpleStatement(strconv.Quote(fieldName) + " iLIKE '" + vCasted + "%'") + simpleStat.WhereStatement = where_clause.NewInfixOp(where_clause.NewColumnRef(fieldName), "iLIKE", where_clause.NewLiteral("'"+vCasted+"%'")) + return model.NewSimpleQuery(simpleStat, true) case QueryMap: token := vCasted["value"].(string) cw.AddTokenToHighlight(token) - return model.NewSimpleQuery(model.NewSimpleStatement(strconv.Quote(fieldName)+" iLIKE '"+token+"%'"), true) + simpleStat := model.NewSimpleStatement(strconv.Quote(fieldName) + " iLIKE '" + token + "%'") + simpleStat.WhereStatement = where_clause.NewInfixOp(where_clause.NewColumnRef(fieldName), "iLIKE", where_clause.NewLiteral("'"+token+"%'")) + return model.NewSimpleQuery(simpleStat, true) default: logger.WarnWithCtx(cw.Ctx).Msgf("unsupported prefix type: %T, value: %v", v, v) return model.NewSimpleQuery(model.NewSimpleStatement("unsupported prefix type"), false) @@ -645,8 +671,10 @@ func (cw *ClickhouseQueryTranslator) parseWildcard(queryMap QueryMap) model.Simp if value, ok := vAsMap["value"]; ok { if valueAsString, ok := value.(string); ok { cw.AddTokenToHighlight(valueAsString) - return model.NewSimpleQuery(model.NewSimpleStatement(strconv.Quote(fieldName)+" iLIKE '"+ - strings.ReplaceAll(valueAsString, "*", "%")+"'"), true) + simpleStat := model.NewSimpleStatement(strconv.Quote(fieldName) + " iLIKE '" + + strings.ReplaceAll(valueAsString, "*", "%") + "'") + simpleStat.WhereStatement = where_clause.NewInfixOp(where_clause.NewColumnRef(fieldName), "iLIKE", where_clause.NewLiteral("'"+strings.ReplaceAll(valueAsString, "*", "%")+"'")) + return model.NewSimpleQuery(simpleStat, true) } else { logger.WarnWithCtx(cw.Ctx).Msgf("invalid value type: %T, value: %v", value, value) return model.NewSimpleQuery(model.NewSimpleStatement("invalid value type"), false) @@ -749,10 +777,13 @@ func (cw *ClickhouseQueryTranslator) parseRange(queryMap QueryMap) model.SimpleQ } for op, v := range v.(QueryMap) { + var fieldToPrint, timeFormatFuncName string + var valueToCompare where_clause.Statement fieldType := cw.Table.GetDateTimeType(cw.Ctx, field) vToPrint := sprint(v) - var fieldToPrint string + valueToCompare = where_clause.NewLiteral(vToPrint) if !isDatetimeInDefaultFormat { + timeFormatFuncName = "toUnixTimestamp64Milli" fieldToPrint = "toUnixTimestamp64Milli(" + strconv.Quote(field) + ")" } else { fieldToPrint = strconv.Quote(field) @@ -762,7 +793,9 @@ func (cw *ClickhouseQueryTranslator) parseRange(queryMap QueryMap) model.SimpleQ // if it's a date, we need to parse it to Clickhouse's DateTime format // how to check if it does not contain date math expression? if _, err := time.Parse(time.RFC3339Nano, dateTime); err == nil { - vToPrint = cw.parseDateTimeString(cw.Table, field, dateTime) + vToPrint, timeFormatFuncName = cw.parseDateTimeString(cw.Table, field, dateTime) + // TODO Investigate the quotation below + valueToCompare = where_clause.NewFunction(timeFormatFuncName, where_clause.NewLiteral(fmt.Sprintf("'%s'", dateTime))) } else if op == "gte" || op == "lte" || op == "gt" || op == "lt" { vToPrint, err = cw.parseDateMathExpression(vToPrint) if err != nil { @@ -772,6 +805,7 @@ func (cw *ClickhouseQueryTranslator) parseRange(queryMap QueryMap) model.SimpleQ } } else if v == nil { vToPrint = "NULL" + valueToCompare = where_clause.NewLiteral("NULL") } case clickhouse.Invalid: // assumes it is number that does not need formatting if len(vToPrint) > 2 && vToPrint[0] == '\'' && vToPrint[len(vToPrint)-1] == '\'' { @@ -786,6 +820,7 @@ func (cw *ClickhouseQueryTranslator) parseRange(queryMap QueryMap) model.SimpleQ } else { logger.WarnWithCtx(cw.Ctx).Msgf("we use range with unknown literal %s, field %s", vToPrint, field) } + valueToCompare = where_clause.NewLiteral(vToPrint) } default: logger.WarnWithCtx(cw.Ctx).Msgf("invalid DateTime type for field: %s, parsed dateTime value: %s", field, vToPrint) @@ -794,13 +829,21 @@ func (cw *ClickhouseQueryTranslator) parseRange(queryMap QueryMap) model.SimpleQ switch op { case "gte": - stmts = append(stmts, model.NewSimpleStatement(fieldToPrint+">="+vToPrint)) + simpleStat := model.NewSimpleStatement(fieldToPrint + ">=" + vToPrint) + simpleStat.WhereStatement = where_clause.NewInfixOp(where_clause.NewColumnRef(field), ">=", valueToCompare) + stmts = append(stmts, simpleStat) case "lte": - stmts = append(stmts, model.NewSimpleStatement(fieldToPrint+"<="+vToPrint)) + simpleStat := model.NewSimpleStatement(fieldToPrint + "<=" + vToPrint) + simpleStat.WhereStatement = where_clause.NewInfixOp(where_clause.NewColumnRef(field), "<=", valueToCompare) + stmts = append(stmts, simpleStat) case "gt": - stmts = append(stmts, model.NewSimpleStatement(fieldToPrint+">"+vToPrint)) + simpleStat := model.NewSimpleStatement(fieldToPrint + ">" + vToPrint) + simpleStat.WhereStatement = where_clause.NewInfixOp(where_clause.NewColumnRef(field), ">", valueToCompare) + stmts = append(stmts, simpleStat) case "lt": - stmts = append(stmts, model.NewSimpleStatement(fieldToPrint+"<"+vToPrint)) + simpleStat := model.NewSimpleStatement(fieldToPrint + "<" + vToPrint) + simpleStat.WhereStatement = where_clause.NewInfixOp(where_clause.NewColumnRef(field), "<", valueToCompare) + stmts = append(stmts, simpleStat) case "format": // ignored default: @@ -816,16 +859,16 @@ func (cw *ClickhouseQueryTranslator) parseRange(queryMap QueryMap) model.SimpleQ } // parseDateTimeString returns string used to parse DateTime in Clickhouse (depends on column type) -func (cw *ClickhouseQueryTranslator) parseDateTimeString(table *clickhouse.Table, field, dateTime string) string { +func (cw *ClickhouseQueryTranslator) parseDateTimeString(table *clickhouse.Table, field, dateTime string) (string, string) { typ := table.GetDateTimeType(cw.Ctx, field) switch typ { case clickhouse.DateTime64: - return "parseDateTime64BestEffort('" + dateTime + "')" + return "parseDateTime64BestEffort('" + dateTime + "')", "parseDateTime64BestEffort" case clickhouse.DateTime: - return "parseDateTimeBestEffort('" + dateTime + "')" + return "parseDateTimeBestEffort('" + dateTime + "')", "parseDateTimeBestEffort" default: logger.Error().Msgf("invalid DateTime type: %T for field: %s, parsed dateTime value: %s", typ, field, dateTime) - return "" + return "", "" } } @@ -846,19 +889,33 @@ func (cw *ClickhouseQueryTranslator) parseExists(queryMap QueryMap) model.Simple switch cw.Table.GetFieldInfo(cw.Ctx, fieldName) { case clickhouse.ExistsAndIsBaseType: - sql = model.NewSimpleStatement(fieldNameQuoted + " IS NOT NULL") + simpleStatement := model.NewSimpleStatement(fieldNameQuoted + " IS NOT NULL") + simpleStatement.WhereStatement = where_clause.NewInfixOp(where_clause.NewColumnRef(fieldNameQuoted), "IS", where_clause.NewLiteral("NOT NULL")) + statement := simpleStatement + sql = statement case clickhouse.ExistsAndIsArray: - sql = model.NewSimpleStatement(fieldNameQuoted + ".size0 = 0") + statement := model.NewSimpleStatement(fieldNameQuoted + ".size0 = 0") + statement.WhereStatement = where_clause.NewInfixOp(where_clause.NewNestedProperty( + *where_clause.NewColumnRef(fieldNameQuoted), + *where_clause.NewLiteral("size0"), + ), "=", where_clause.NewLiteral("0")) + sql = statement case clickhouse.NotExists: attrs := cw.Table.GetAttributesList() stmts := make([]model.Statement, len(attrs)) for i, a := range attrs { - stmts[i] = model.NewCompoundStatementNoFieldName( + compoundStatementNoFieldName := model.NewCompoundStatementNoFieldName( fmt.Sprintf("has(%s,%s) AND %s[indexOf(%s,%s)] IS NOT NULL", strconv.Quote(a.KeysArrayName), fieldNameQuoted, strconv.Quote(a.ValuesArrayName), strconv.Quote(a.KeysArrayName), fieldNameQuoted, ), ) + compoundStatementNoFieldName.WhereStatement = nil + hasFunc := where_clause.NewFunction("has", []where_clause.Statement{where_clause.NewColumnRef(a.KeysArrayName), where_clause.NewColumnRef(fieldName)}...) + arrayAccess := where_clause.NewArrayAccess(*where_clause.NewColumnRef(a.ValuesArrayName), where_clause.NewFunction("indexOf", []where_clause.Statement{where_clause.NewColumnRef(a.KeysArrayName), where_clause.NewLiteral(fieldNameQuoted)}...)) + isNotNull := where_clause.NewInfixOp(arrayAccess, "IS", where_clause.NewLiteral("NOT NULL")) + compoundStatementNoFieldName.WhereStatement = where_clause.NewInfixOp(hasFunc, "AND", isNotNull) + stmts[i] = compoundStatementNoFieldName } sql = model.Or(stmts) default: diff --git a/quesma/queryparser/query_parser_test.go b/quesma/queryparser/query_parser_test.go index 992359a94..4d2b165ef 100644 --- a/quesma/queryparser/query_parser_test.go +++ b/quesma/queryparser/query_parser_test.go @@ -2,9 +2,11 @@ package queryparser import ( "context" + "fmt" "mitmproxy/quesma/clickhouse" "mitmproxy/quesma/concurrent" "mitmproxy/quesma/model" + "mitmproxy/quesma/queryparser/where_clause" "mitmproxy/quesma/quesma/config" "mitmproxy/quesma/telemetry" "mitmproxy/quesma/testdata" @@ -15,6 +17,8 @@ import ( "github.com/stretchr/testify/assert" ) +var whereStatementRenderer = &where_clause.StringRenderer{} + // TODO: // 1. 14th test, "Query string". "(message LIKE '%%%' OR message LIKE '%logged%')", is it really // what should be? According to docs, I think so... Maybe test in Kibana? @@ -54,6 +58,20 @@ func TestQueryParserStringAttrConfig(t *testing.T) { assert.Equal(t, tt.WantedQueryType, queryInfo.Typ, "equals to wanted query type") query := cw.BuildNRowsQuery("*", simpleQuery, model.DefaultSizeListQuery) assert.Contains(t, tt.WantedQuery, *query) + // Test the new WhereStatement + if simpleQuery.Sql.WhereStatement != nil { + oldStmtWithoutParentheses := strings.ReplaceAll(simpleQuery.Sql.Stmt, "(", "") + oldStmtWithoutParentheses = strings.ReplaceAll(oldStmtWithoutParentheses, ")", "") + + newWhereStmt := simpleQuery.Sql.WhereStatement.Accept(whereStatementRenderer) + newStmtWithoutParentheses := strings.ReplaceAll(newWhereStmt.(string), "(", "") + newStmtWithoutParentheses = strings.ReplaceAll(newStmtWithoutParentheses, ")", "") + + assert.Equal(t, newStmtWithoutParentheses, oldStmtWithoutParentheses) + } + // the old where statement should be empty then... + // BUT have some Lucene fields to figure out ... + //assert.Equal(t, simpleQuery.Sql.Stmt, "") }) } } @@ -81,6 +99,19 @@ func TestQueryParserNoFullTextFields(t *testing.T) { assert.Equal(t, tt.WantedQueryType, queryInfo.Typ, "equals to wanted query type") query := cw.BuildNRowsQuery("*", simpleQuery, model.DefaultSizeListQuery) assert.Contains(t, tt.WantedQuery, *query) + // Test the new WhereStatement + if simpleQuery.Sql.WhereStatement != nil { + oldStmtWithoutParentheses := strings.ReplaceAll(simpleQuery.Sql.Stmt, "(", "") + oldStmtWithoutParentheses = strings.ReplaceAll(oldStmtWithoutParentheses, ")", "") + + newWhereStmt := simpleQuery.Sql.WhereStatement.Accept(whereStatementRenderer) + newStmtWithoutParentheses := strings.ReplaceAll(newWhereStmt.(string), "(", "") + newStmtWithoutParentheses = strings.ReplaceAll(newStmtWithoutParentheses, ")", "") + + assert.Equal(t, newStmtWithoutParentheses, oldStmtWithoutParentheses) + } else { // the old where statement should be empty then... + assert.Equal(t, simpleQuery.Sql.Stmt, "") + } }) } } @@ -106,6 +137,13 @@ func TestQueryParserNoAttrsConfig(t *testing.T) { assert.Equal(t, tt.WantedQueryType, queryInfo.Typ) query := cw.BuildNRowsQuery("*", simpleQuery, model.DefaultSizeListQuery) + if simpleQuery.Sql.WhereStatement != nil { + ss := simpleQuery.Sql.WhereStatement.Accept(whereStatementRenderer) + assert.Equal(t, simpleQuery.Sql.Stmt, ss.(string)) + } else { + oldOne := simpleQuery.Sql.Stmt + fmt.Printf("No new where statement but old one is [%s]", oldOne) + } assert.Contains(t, tt.WantedQuery, *query) }) } diff --git a/quesma/queryparser/where_clause/string_renderer.go b/quesma/queryparser/where_clause/string_renderer.go new file mode 100644 index 000000000..248183c92 --- /dev/null +++ b/quesma/queryparser/where_clause/string_renderer.go @@ -0,0 +1,70 @@ +package where_clause + +import ( + "fmt" + "strconv" + "strings" +) + +// StringRenderer is a visitor that renders the WHERE statement as a string +type StringRenderer struct { +} + +func (v *StringRenderer) VisitLiteral(e *Literal) interface{} { + return e.Name +} + +func (v *StringRenderer) VisitInfixOp(e *InfixOp) interface{} { + var lhs, rhs interface{} // TODO FOR NOW LITTLE PARANOID BUT HELPS ME NOT SEE MANY PANICS WHEN TESTING + if e.Left != nil { + lhs = e.Left.Accept(v) + } else { + lhs = "< LHS NIL >" + } + if e.Right != nil { + rhs = e.Right.Accept(v) + } else { + rhs = "< RHS NIL >" + } + // This might look like a strange heuristics to but is aligned with the way we are currently generating the statement + // I think in the future every infix op should be in braces. + if e.Op == "OR" { + return fmt.Sprintf("(%v %v %v)", lhs, e.Op, rhs) + } else if e.Op == "AND" || strings.Contains(e.Op, "LIKE") || e.Op == "IS" || e.Op == "IN" { + return fmt.Sprintf("%v %v %v", lhs, e.Op, rhs) + } else { + return fmt.Sprintf("%v%v%v", lhs, e.Op, rhs) + } +} + +func (v *StringRenderer) VisitPrefixOp(e *PrefixOp) interface{} { + args := make([]string, len(e.Args)) + for i, arg := range e.Args { + args[i] = arg.Accept(v).(string) + } + + argsAsString := strings.Join(args, ", ") + return fmt.Sprintf("%v (%v)", e.Op, argsAsString) +} + +func (v *StringRenderer) VisitFunction(e *Function) interface{} { + args := make([]string, len(e.Args)) + for i, arg := range e.Args { + args[i] = arg.Accept(v).(string) + } + + argsAsString := strings.Join(args, ",") + return fmt.Sprintf("%v(%v)", e.Name.Accept(v), argsAsString) +} + +func (v *StringRenderer) VisitColumnRef(e *ColumnRef) interface{} { + return strconv.Quote(e.ColumnName) +} + +func (v *StringRenderer) VisitNestedProperty(e *NestedProperty) interface{} { + return fmt.Sprintf("%v.%v", e.ColumnRef.Accept(v), e.PropertyName.Accept(v)) +} + +func (v *StringRenderer) VisitArrayAccess(e *ArrayAccess) interface{} { + return fmt.Sprintf("%v[%v]", e.ColumnRef.Accept(v), e.Index.Accept(v)) +} diff --git a/quesma/queryparser/where_clause/used_fields.go b/quesma/queryparser/where_clause/used_fields.go new file mode 100644 index 000000000..8cbe498f6 --- /dev/null +++ b/quesma/queryparser/where_clause/used_fields.go @@ -0,0 +1,64 @@ +package where_clause + +import "strings" + +// UsedFieldsVisitor is a visitor that fetches all fields (columns) used in a given where clause +type UsedFieldsVisitor struct { + Columns []*ColumnRef +} + +// Sheer beauty: +// colFetch := &where_clause.UsedFieldsVisitor{} +// parsedQuery.Sql.WhereStatement.Accept(colFetch) +// cc := colFetch.GetColumnsUsed() + +func (v *UsedFieldsVisitor) GetColumnsUsed() []*ColumnRef { + return v.Columns +} + +func (v *UsedFieldsVisitor) PrintColumnsUsed() string { + var columns []string + for _, col := range v.Columns { + columns = append(columns, col.ColumnName) + } + return strings.Join(columns, ", ") +} + +func (v *UsedFieldsVisitor) VisitLiteral(e *Literal) interface{} { + return nil +} + +func (v *UsedFieldsVisitor) VisitInfixOp(e *InfixOp) interface{} { + e.Left.Accept(v) + e.Right.Accept(v) + return nil +} + +func (v *UsedFieldsVisitor) VisitPrefixOp(e *PrefixOp) interface{} { + for _, arg := range e.Args { + arg.Accept(v) + } + return nil +} + +func (v *UsedFieldsVisitor) VisitFunction(e *Function) interface{} { + for _, arg := range e.Args { + arg.Accept(v) + } + return nil +} + +func (v *UsedFieldsVisitor) VisitColumnRef(e *ColumnRef) interface{} { + v.Columns = append(v.Columns, e) + return nil +} + +func (v *UsedFieldsVisitor) VisitNestedProperty(e *NestedProperty) interface{} { + v.Columns = append(v.Columns, &e.ColumnRef) + return nil +} + +func (v *UsedFieldsVisitor) VisitArrayAccess(e *ArrayAccess) interface{} { + v.Columns = append(v.Columns, &e.ColumnRef) + return nil +} diff --git a/quesma/queryparser/where_clause/where_clause.go b/quesma/queryparser/where_clause/where_clause.go new file mode 100644 index 000000000..e62d1e1a4 --- /dev/null +++ b/quesma/queryparser/where_clause/where_clause.go @@ -0,0 +1,133 @@ +package where_clause + +import "fmt" + +// Statement is main structure for WHERE clause +type Statement interface { + Accept(v StatementVisitor) interface{} +} + +// ColumnRef is a reference to a column in a table, we can enrich it with more information (e.g. type used) as we go +type ColumnRef struct { + ColumnName string +} + +func NewColumnRef(name string) *ColumnRef { + return &ColumnRef{ColumnName: name} +} + +func (e *ColumnRef) Accept(v StatementVisitor) interface{} { + return v.VisitColumnRef(e) +} + +type Literal struct { + Name string +} + +// NestedProperty for nested objects, e.g. `columnName.propertyName` +type NestedProperty struct { + ColumnRef ColumnRef + PropertyName Literal +} + +func NewNestedProperty(columnRef ColumnRef, propertyName Literal) *NestedProperty { + return &NestedProperty{ColumnRef: columnRef, PropertyName: propertyName} +} + +func (e *NestedProperty) Accept(v StatementVisitor) interface{} { return v.VisitNestedProperty(e) } + +// ArrayAccess for array accessing, e.g. `columnName[0]` +type ArrayAccess struct { + ColumnRef ColumnRef + Index Statement +} + +func NewArrayAccess(columnRef ColumnRef, index Statement) *ArrayAccess { + return &ArrayAccess{ColumnRef: columnRef, Index: index} +} + +func (e *ArrayAccess) Accept(v StatementVisitor) interface{} { return v.VisitArrayAccess(e) } + +func NewLiteral(name string) *Literal { + return &Literal{Name: name} +} + +func (e *Literal) String() string { + return fmt.Sprintf("(Literal %v)", e.Name) +} + +func (e *Literal) Accept(v StatementVisitor) interface{} { + return v.VisitLiteral(e) +} + +type InfixOp struct { + Left Statement + Op string + Right Statement +} + +func NewInfixOp(left Statement, op string, right Statement) *InfixOp { + return &InfixOp{ + Left: left, + Op: op, + Right: right, + } +} + +func (e *InfixOp) String() string { + return fmt.Sprintf("(infix '%v' %v %v)", e.Op, e.Left, e.Right) +} + +func (e *InfixOp) Accept(v StatementVisitor) interface{} { + return v.VisitInfixOp(e) +} + +type PrefixOp struct { + Op string + Args []Statement +} + +func NewPrefixOp(op string, args []Statement) *PrefixOp { + return &PrefixOp{ + Op: op, + Args: args, + } +} + +func (e *PrefixOp) String() string { + return fmt.Sprintf("(prefix '%v' %v)", e.Op, e.Args) +} + +func (e *PrefixOp) Accept(v StatementVisitor) interface{} { + return v.VisitPrefixOp(e) +} + +type Function struct { + Name Literal + Args []Statement +} + +func NewFunction(name string, args ...Statement) *Function { + return &Function{ + Name: Literal{Name: name}, + Args: args, + } +} + +func (e *Function) String() string { + return fmt.Sprintf("(function %v %v)", e.Name, e.Args) +} + +func (e *Function) Accept(v StatementVisitor) interface{} { + return v.VisitFunction(e) +} + +type StatementVisitor interface { + VisitLiteral(e *Literal) interface{} + VisitInfixOp(e *InfixOp) interface{} + VisitPrefixOp(e *PrefixOp) interface{} + VisitFunction(e *Function) interface{} + VisitColumnRef(e *ColumnRef) interface{} + VisitNestedProperty(e *NestedProperty) interface{} + VisitArrayAccess(e *ArrayAccess) interface{} +}