diff --git a/quesma/logger/log_with_throttling.go b/quesma/logger/log_with_throttling.go index 9d29309b3..4d920ca9f 100644 --- a/quesma/logger/log_with_throttling.go +++ b/quesma/logger/log_with_throttling.go @@ -8,6 +8,8 @@ import ( "time" ) +// throttleMap: (reason name -> last logged time) +// We log only once per throttleDuration for each reason name, so that we don't spam the logs. var throttleMap = util.NewSyncMap[string, time.Time]() const throttleDuration = 30 * time.Minute @@ -28,11 +30,11 @@ func WarnWithCtxAndThrottling(ctx context.Context, aggrName, paramName, format s // WarnWithThrottling - logs a warning message with throttling. // We only log once per throttleDuration for each warnName, so that we don't spam the logs. -func WarnWithThrottling(warnName, format string, v ...any) { - timestamp, ok := throttleMap.Load(warnName) +func WarnWithThrottling(reasonName, format string, v ...any) { + timestamp, ok := throttleMap.Load(reasonName) weThrottle := ok && time.Since(timestamp) < throttleDuration if !weThrottle { Warn().Msgf(format, v...) - throttleMap.Store(warnName, time.Now()) + throttleMap.Store(reasonName, time.Now()) } } diff --git a/quesma/model/base_visitor.go b/quesma/model/base_visitor.go index 71d68d04b..88b2cede6 100644 --- a/quesma/model/base_visitor.go +++ b/quesma/model/base_visitor.go @@ -45,11 +45,11 @@ func (v *BaseExprVisitor) VisitLiteral(e LiteralExpr) interface{} { return NewLiteral(e.Value) } -func (v *BaseExprVisitor) VisitTuple(e TupleExpr) interface{} { +func (v *BaseExprVisitor) VisitTuple(t TupleExpr) interface{} { if v.OverrideVisitTuple != nil { - return v.OverrideVisitTuple(v, e) + return v.OverrideVisitTuple(v, t) } - return NewTupleExpr(v.VisitChildren(e.Exprs)...) + return NewTupleExpr(v.VisitChildren(t.Exprs)...) } func (v *BaseExprVisitor) VisitInfix(e InfixExpr) interface{} { diff --git a/quesma/model/bucket_aggregations/terms.go b/quesma/model/bucket_aggregations/terms.go index e25cb0c0a..2c62ec6b2 100644 --- a/quesma/model/bucket_aggregations/terms.go +++ b/quesma/model/bucket_aggregations/terms.go @@ -12,6 +12,7 @@ import ( "reflect" ) +// TODO when adding include/exclude, check escaping of ' and \ in those fields type Terms struct { ctx context.Context significant bool // true <=> significant_terms, false <=> terms diff --git a/quesma/model/expr_string_renderer.go b/quesma/model/expr_string_renderer.go index f4c73fa6c..a93e2bf97 100644 --- a/quesma/model/expr_string_renderer.go +++ b/quesma/model/expr_string_renderer.go @@ -6,6 +6,7 @@ import ( "fmt" "quesma/logger" "quesma/quesma/types" + "quesma/util" "regexp" "sort" "strconv" @@ -66,7 +67,12 @@ func (v *renderer) VisitFunction(e FunctionExpr) interface{} { } func (v *renderer) VisitLiteral(l LiteralExpr) interface{} { - return fmt.Sprintf("%v", l.Value) + switch val := l.Value.(type) { + case string: + return escapeString(val) + default: + return fmt.Sprintf("%v", val) + } } func (v *renderer) VisitTuple(t TupleExpr) interface{} { @@ -349,3 +355,14 @@ func (v *renderer) VisitJoinExpr(j JoinExpr) interface{} { func (v *renderer) VisitCTE(c CTE) interface{} { return fmt.Sprintf("%s AS (%s) ", c.Name, AsString(c.SelectCommand)) } + +// escapeString escapes the given string so that it can be used in a SQL Clickhouse query. +// It escapes ' and \ characters: ' -> \', \ -> \\. +func escapeString(s string) string { + s = strings.ReplaceAll(s, `\`, `\\`) // \ should be escaped with no exceptions + if len(s) >= 2 && s[0] == '\'' && s[len(s)-1] == '\'' { + // don't escape the first and last ' + return util.SingleQuote(strings.ReplaceAll(s[1:len(s)-1], `'`, `\'`)) + } + return strings.ReplaceAll(s, `'`, `\'`) +} diff --git a/quesma/queryparser/query_parser.go b/quesma/queryparser/query_parser.go index e38f47de8..7a659de2c 100644 --- a/quesma/queryparser/query_parser.go +++ b/quesma/queryparser/query_parser.go @@ -513,12 +513,12 @@ func (cw *ClickhouseQueryTranslator) parseTerms(queryMap QueryMap) model.SimpleQ simpleStatement := model.NewInfixExpr(model.NewColumnRef(k), "=", model.NewLiteral(sprint(vAsArray[0]))) return model.NewSimpleQuery(simpleStatement, true) } - values := make([]string, len(vAsArray)) + values := make([]model.Expr, len(vAsArray)) for i, v := range vAsArray { - values[i] = sprint(v) + values[i] = model.NewLiteral(sprint(v)) } - combinedValues := "(" + strings.Join(values, ",") + ")" - compoundStatement := model.NewInfixExpr(model.NewColumnRef(k), "IN", model.NewLiteral(combinedValues)) + tuple := model.NewTupleExpr(values...) + compoundStatement := model.NewInfixExpr(model.NewColumnRef(k), "IN", tuple) return model.NewSimpleQuery(compoundStatement, true) } diff --git a/quesma/testdata/aggregation_requests_2.go b/quesma/testdata/aggregation_requests_2.go index 21acfba83..6ea38f072 100644 --- a/quesma/testdata/aggregation_requests_2.go +++ b/quesma/testdata/aggregation_requests_2.go @@ -4691,7 +4691,6 @@ var AggregationTests2 = []AggregationTestCase{ }, { // [70] TestName: "simplest terms with exclude (array of values)", - // TODO add ' somewhere in exclude after the merge! QueryRequestJson: ` { "aggs": { @@ -4699,7 +4698,7 @@ var AggregationTests2 = []AggregationTestCase{ "terms": { "field": "chess_goat", "size": 2, - "exclude": ["Carlsen", "Kasparov", "Fis._er*"] + "exclude": ["Carlsen", "Kasparov", "Fis._er'*"] } } }, @@ -4740,10 +4739,10 @@ var AggregationTests2 = []AggregationTestCase{ }, ExpectedPancakeSQL: ` SELECT sum(count(*)) OVER () AS "aggr__1__parent_count", - if("chess_goat" NOT IN tuple('Carlsen', 'Kasparov', 'Fis._er*'), "chess_goat", NULL) + if("chess_goat" NOT IN tuple('Carlsen', 'Kasparov', 'Fis._er\'*'), "chess_goat", NULL) AS "aggr__1__key_0", count(*) AS "aggr__1__count" FROM __quesma_table_name - GROUP BY if("chess_goat" NOT IN tuple('Carlsen', 'Kasparov', 'Fis._er*'), "chess_goat", NULL) AS "aggr__1__key_0" + GROUP BY if("chess_goat" NOT IN tuple('Carlsen', 'Kasparov', 'Fis._er\'*'), "chess_goat", NULL) AS "aggr__1__key_0" ORDER BY "aggr__1__count" DESC, "aggr__1__key_0" ASC LIMIT 3`, }, @@ -5290,4 +5289,75 @@ var AggregationTests2 = []AggregationTestCase{ ORDER BY "aggr__terms2__count" DESC, "aggr__terms2__key_0" ASC LIMIT 3`}, }, + { // [77] + TestName: `Escaping of ', \, \n, and \t in some example aggregations. No tests for other escape characters, e.g. \r or 'b. Add if needed.`, + QueryRequestJson: ` + { + "aggs": { + "avg": { + "avg": { + "field": "@timestamp's\\" + } + }, + "terms": { + "terms": { + "field": "agent.keyword", + "size": 1, + "missing": "quote ' and slash \\ Also \t \n" + } + } + }, + "size": 0 + }`, + ExpectedResponse: ` + { + "_shards": { + "failed": 0, + "skipped": 0, + "successful": 1, + "total": 1 + }, + "aggregations": { + "avg": { + "value": null + }, + "terms": { + "buckets": [ + { + "doc_count": 5362, + "key": "Mozilla/5.0 (X11; Linux x86_64; rv:6.0a1) Gecko/20110421 Firefox/6.0a1" + } + ], + "doc_count_error_upper_bound": 0, + "sum_other_doc_count": 8712 + } + }, + "hits": { + "hits": [], + "max_score": null + }, + "timed_out": false, + "took": 5 + }`, + ExpectedPancakeResults: []model.QueryResultRow{ + {Cols: []model.QueryResultCol{ + model.NewQueryResultCol("metric__avg_col_0", nil), + model.NewQueryResultCol("aggr__terms__parent_count", int64(14074)), + model.NewQueryResultCol("aggr__terms__key_0", "Mozilla/5.0 (X11; Linux x86_64; rv:6.0a1) Gecko/20110421 Firefox/6.0a1"), + model.NewQueryResultCol("aggr__terms__count", int64(5362)), + }}, + }, + ExpectedPancakeSQL: ` + SELECT avgOrNullMerge(avgOrNullState("@timestamp's\\")) OVER () AS + "metric__avg_col_0", sum(count(*)) OVER () AS "aggr__terms__parent_count", + COALESCE("agent", 'quote \' and slash \\ Also +') AS "aggr__terms__key_0", + count(*) AS "aggr__terms__count" + FROM __quesma_table_name + GROUP BY COALESCE("agent", 'quote \' and slash \\ Also +') AS + "aggr__terms__key_0" + ORDER BY "aggr__terms__count" DESC, "aggr__terms__key_0" ASC + LIMIT 1`, + }, } diff --git a/quesma/testdata/requests.go b/quesma/testdata/requests.go index a1fc2bd0f..ea694aab3 100644 --- a/quesma/testdata/requests.go +++ b/quesma/testdata/requests.go @@ -1006,7 +1006,7 @@ var TestsSearch = []SearchTestCase{ }, { "terms": { - "task.enabled": [true, 54] + "task.enabled": [true, 54, "abc", "abc's"] } } ] @@ -1014,10 +1014,10 @@ var TestsSearch = []SearchTestCase{ }, "track_total_hits": true }`, - []string{`("type"='task' AND "task.enabled" IN (true,54))`}, + []string{`("type"='task' AND "task.enabled" IN tuple(true, 54, 'abc', 'abc\'s'))`}, model.ListAllFields, []string{ - `SELECT "message" FROM ` + TableName + ` WHERE ("type"='task' AND "task.enabled" IN (true,54)) LIMIT 10`, + `SELECT "message" FROM ` + TableName + ` WHERE ("type"='task' AND "task.enabled" IN tuple(true, 54, 'abc', 'abc\\'s')) LIMIT 10`, `SELECT count(*) AS "column_0" FROM ` + TableName, }, []string{}, @@ -2196,13 +2196,13 @@ var TestsSearch = []SearchTestCase{ }, "track_total_hits": false }`, - []string{`("cliIP" IN ('2601:204:c503:c240:9c41:5531:ad94:4d90','50.116.43.98','75.246.0.64') AND ("@timestamp">=fromUnixTimestamp64Milli(1715817600000) AND "@timestamp"<=fromUnixTimestamp64Milli(1715990399000)))`}, + []string{`("cliIP" IN tuple('2601:204:c503:c240:9c41:5531:ad94:4d90', '50.116.43.98', '75.246.0.64') AND ("@timestamp">=fromUnixTimestamp64Milli(1715817600000) AND "@timestamp"<=fromUnixTimestamp64Milli(1715990399000)))`}, model.ListAllFields, //[]model.Query{withLimit(justSimplestWhere(`("cliIP" IN ('2601:204:c503:c240:9c41:5531:ad94:4d90','50.116.43.98','75.246.0.64') AND ("@timestamp">=parseDateTime64BestEffort('2024-05-16T00:00:00') AND "@timestamp"<=parseDateTime64BestEffort('2024-05-17T23:59:59')))`), 1)}, []string{ `SELECT "message" ` + `FROM ` + TableName + ` ` + - `WHERE ("cliIP" IN ('2601:204:c503:c240:9c41:5531:ad94:4d90','50.116.43.98','75.246.0.64') ` + + `WHERE ("cliIP" IN tuple('2601:204:c503:c240:9c41:5531:ad94:4d90', '50.116.43.98', '75.246.0.64') ` + `AND ("@timestamp">=fromUnixTimestamp64Milli(1715817600000) AND "@timestamp"<=fromUnixTimestamp64Milli(1715990399000))) ` + `LIMIT 1`, }, @@ -2254,12 +2254,14 @@ var TestsSearch = []SearchTestCase{ }, "track_total_hits": false }`, - []string{`"field" LIKE '%\___'`}, + // Escaping _ twice ("\\_") seemed wrong, but it actually works in Clickhouse! + // \\\\ means 2 escaped backslashes, actual returned string is "\\" + []string{`"field" LIKE '%\\___'`}, model.ListAllFields, []string{ `SELECT "message" ` + `FROM ` + TableName + ` ` + - `WHERE "field" LIKE '%\\___' ` + + `WHERE "field" LIKE '%\\\\___' ` + `LIMIT 10`, }, []string{}, @@ -2321,6 +2323,30 @@ var TestsSearch = []SearchTestCase{ []string{}, }, { // [40] + `Escaping of ', \, \t and \n`, + ` + { + "query": { + "bool": { + "filter": [ + { + "match_phrase": { + "message": "\nMen's Clothing \\ \t" + } + } + ] + } + }, + "track_total_hits": false + }`, + []string{`("message" __quesma_match ' +Men\'s Clothing \\ ')`}, + model.ListAllFields, + []string{`SELECT "message" FROM ` + TableName + ` WHERE "message" iLIKE '% +Men\\'s Clothing \\\\ %' LIMIT 10`}, + []string{}, + }, + { // [41] "ids, 0 values", `{ "query": { @@ -2340,7 +2366,7 @@ var TestsSearch = []SearchTestCase{ }, []string{}, }, - { // [41] + { // [42] "ids, 1 value", `{ "query": { @@ -2360,7 +2386,7 @@ var TestsSearch = []SearchTestCase{ }, []string{}, }, - { // [42] + { // [43] "ids, 2+ values", `{ "query": { diff --git a/quesma/util/utils.go b/quesma/util/utils.go index fdb377cfa..2f58ba916 100644 --- a/quesma/util/utils.go +++ b/quesma/util/utils.go @@ -737,6 +737,11 @@ func ExtractNumeric64(value any) float64 { return asFloat64 } +// SingleQuote is a simple helper function: str -> 'str' +func SingleQuote(value string) string { + return "'" + value + "'" +} + type sqlMockMismatchSql struct { expected string actual string @@ -847,7 +852,7 @@ func stringifyHelper(v interface{}, isInsideArray bool) string { // This functions returns a string from an interface{}. func Stringify(v interface{}) string { - isInsideArray := false + const isInsideArray = false return stringifyHelper(v, isInsideArray) }