Skip to content

Commit

Permalink
WHERE clause refactor (#176)
Browse files Browse the repository at this point in the history
  • Loading branch information
mieciu authored May 23, 2024
1 parent c58a44f commit 5cb86d5
Show file tree
Hide file tree
Showing 7 changed files with 420 additions and 36 deletions.
23 changes: 19 additions & 4 deletions quesma/model/simple_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package model
import (
"context"
"mitmproxy/quesma/logger"
"mitmproxy/quesma/queryparser/where_clause"
)

type SimpleQuery struct {
Expand All @@ -21,9 +22,11 @@ func NewSimpleQueryWithFieldName(sql Statement, canParse bool, fieldName string)
}

// Statement is one WHERE-clause condition built by the query parser. It carries
// the condition as raw SQL text (Stmt) and as a structured
// where_clause.Statement (WhereStatement).
// NOTE(review): Stmt, IsCompound and FieldName each appear twice below; this
// looks like the removed and added lines of a rendered diff shown together -
// confirm against the real file before compiling.
type Statement struct {
Stmt string
IsCompound bool // "a" -> not compound, "a AND b" -> compound. Used to avoid adding unnecessary brackets (not always, but usually)
FieldName string
// The string-based fields are on their way out: the code is moving to
// WhereStatement, which should also remove the need for IsCompound and FieldName.
Stmt string // legacy raw SQL text of the condition; slated for deprecation
WhereStatement where_clause.Statement // structured form of the condition; the intended replacement for Stmt
IsCompound bool // "a" -> not compound, "a AND b" -> compound. Used to avoid adding unnecessary brackets (not always, but usually)
FieldName string // presumably the column the condition refers to - verify against callers
}

func NewSimpleStatement(stmt string) Statement {
Expand Down Expand Up @@ -52,6 +55,13 @@ func Or(orStmts []Statement) Statement {
// sep = "AND" or "OR"
func combineStatements(stmts []Statement, sep string) Statement {
stmts = FilterNonEmpty(stmts)
var newWhereStatement where_clause.Statement
if len(stmts) > 0 {
newWhereStatement = stmts[0].WhereStatement
for _, stmt := range stmts[1:] {
newWhereStatement = where_clause.NewInfixOp(newWhereStatement, sep, stmt.WhereStatement)
}
}
if len(stmts) > 1 {
stmts = quoteWithBracketsIfCompound(stmts)
var fieldName string
Expand All @@ -65,7 +75,12 @@ func combineStatements(stmts []Statement, sep string) Statement {
fieldName = stmt.FieldName
}
}
return NewCompoundStatement(sql, fieldName)
return Statement{
WhereStatement: newWhereStatement,
Stmt: sql,
IsCompound: true,
FieldName: fieldName,
}
}
if len(stmts) == 1 {
return stmts[0]
Expand Down
11 changes: 9 additions & 2 deletions quesma/model/simple_query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,23 @@ func TestOrAndAnd(t *testing.T) {
t.Run("AND "+strconv.Itoa(i), func(t *testing.T) {
b := make([]Statement, len(tt.stmts))
copy(b, tt.stmts)
assert.Equal(t, tt.want, And(b))
tt.want.WhereStatement = nil
finalAnd := And(b)
finalAnd.WhereStatement = nil
assert.Equal(t, tt.want, finalAnd)
})
}
for i, tt := range tests {
t.Run("OR "+strconv.Itoa(i), func(t *testing.T) {
tt.want.WhereStatement = nil
tt.want.Stmt = strings.ReplaceAll(tt.want.Stmt, "AND", "OR")
for i := range tt.stmts {
tt.stmts[i].Stmt = strings.ReplaceAll(tt.stmts[i].Stmt, "AND", "OR")
}
assert.Equal(t, tt.want, Or(tt.stmts))
tt.want.WhereStatement = nil
finalOr := Or(tt.stmts)
finalOr.WhereStatement = nil
assert.Equal(t, tt.want, finalOr)
})
}
}
117 changes: 87 additions & 30 deletions quesma/queryparser/query_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package queryparser

import (
"encoding/json"
"mitmproxy/quesma/queryparser/where_clause"

"fmt"
"github.com/k0kubun/pp"
"mitmproxy/quesma/clickhouse"
Expand Down Expand Up @@ -240,7 +242,9 @@ func (cw *ClickhouseQueryTranslator) ParseAutocomplete(indexFilter *QueryMap, fi
like = "LIKE"
}
cw.AddTokenToHighlight(*prefix)
stmts = append(stmts, model.NewSimpleStatement(fieldName+" "+like+" '"+*prefix+"%'"))
simpleStat := model.NewSimpleStatement(fieldName + " " + like + " '" + *prefix + "%'")
simpleStat.WhereStatement = where_clause.NewInfixOp(where_clause.NewColumnRef(fieldName), like, where_clause.NewLiteral("'"+*prefix+"%'"))
stmts = append(stmts, simpleStat)
}
return model.NewSimpleQuery(model.And(stmts), canParse)
}
Expand Down Expand Up @@ -333,19 +337,24 @@ func (cw *ClickhouseQueryTranslator) parseIds(queryMap QueryMap) model.SimpleQue
}
}

var statement string
var statement model.Statement
if v, ok := cw.Table.Cols[timestampColumnName]; ok {
switch v.Type.String() {
case clickhouse.DateTime64.String():
statement = fmt.Sprintf("toUnixTimestamp64Milli(%s) IN (%s)", strconv.Quote(timestampColumnName), ids)
statement = model.NewSimpleStatement(fmt.Sprintf("toUnixTimestamp64Milli(%s) IN (%s)", strconv.Quote(timestampColumnName), ids))
statement.WhereStatement = where_clause.NewInfixOp(where_clause.NewFunction("toUnixTimestamp64Milli", []where_clause.Statement{where_clause.NewColumnRef(timestampColumnName)}...), "IN", where_clause.NewLiteral("(["+strings.Join(ids, ",")+"])"))
case clickhouse.DateTime.String():
statement = fmt.Sprintf("toUnixTimestamp(%s) *1000 IN (%s)", strconv.Quote(timestampColumnName), ids)
statement = model.NewSimpleStatement(fmt.Sprintf("toUnixTimestamp(%s) * 1000 IN (%s)", strconv.Quote(timestampColumnName), ids))
statement.WhereStatement = where_clause.NewInfixOp(where_clause.NewInfixOp(
where_clause.NewFunction("toUnixTimestamp", []where_clause.Statement{where_clause.NewColumnRef(timestampColumnName)}...),
"*",
where_clause.NewLiteral("1000")), "IN", where_clause.NewLiteral("("+strings.Join(ids, ",")+")"))
default:
logger.Warn().Msgf("timestamp field of unsupported type %s", v.Type.String())
return model.NewSimpleQuery(model.NewSimpleStatement(""), true)
}
}
return model.NewSimpleQuery(model.NewSimpleStatement(statement), true)
return model.NewSimpleQuery(statement, true)
}

// Parses each model.SimpleQuery separately, returns list of translated SQLs
Expand Down Expand Up @@ -426,6 +435,7 @@ func (cw *ClickhouseQueryTranslator) parseBool(queryMap QueryMap) model.SimpleQu
canParse = canParse && canParseThis
if len(sqlNots) > 0 {
orSql := model.Or(sqlNots)
orSql.WhereStatement = where_clause.NewPrefixOp("NOT", []where_clause.Statement{orSql.WhereStatement})
if orSql.IsCompound {
orSql.Stmt = "NOT (" + orSql.Stmt + ")"
orSql.IsCompound = false // NOT (compound) is again simple
Expand All @@ -444,9 +454,13 @@ func (cw *ClickhouseQueryTranslator) parseTerm(queryMap QueryMap) model.SimpleQu
cw.AddTokenToHighlight(v)
if k == "_index" { // index is a table name, already taken from URI and moved to FROM clause
logger.Warn().Msgf("term %s=%v in query body, ignoring in result SQL", k, v)
return model.NewSimpleQuery(model.NewSimpleStatement(" 0=0 /* "+strconv.Quote(k)+"="+sprint(v)+" */ "), true)
simpleStat := model.NewSimpleStatement(" 0=0 /* " + strconv.Quote(k) + "=" + sprint(v) + " */ ")
simpleStat.WhereStatement = where_clause.NewInfixOp(where_clause.NewLiteral("0"), "=", where_clause.NewLiteral("0 /* "+k+"="+sprint(v)+" */"))
return model.NewSimpleQuery(simpleStat, true)
}
return model.NewSimpleQuery(model.NewSimpleStatement(strconv.Quote(k)+"="+sprint(v)), true)
simpleStat := model.NewSimpleStatement(strconv.Quote(k) + "=" + sprint(v))
simpleStat.WhereStatement = where_clause.NewInfixOp(where_clause.NewColumnRef(k), "=", where_clause.NewLiteral(sprint(v)))
return model.NewSimpleQuery(simpleStat, true)
}
}
logger.WarnWithCtx(cw.Ctx).Msgf("we expect only 1 term, got: %d. value: %v", len(queryMap), queryMap)
Expand Down Expand Up @@ -474,7 +488,9 @@ func (cw *ClickhouseQueryTranslator) parseTerms(queryMap QueryMap) model.SimpleQ
orStmts := make([]model.Statement, len(vAsArray))
for i, v := range vAsArray {
cw.AddTokenToHighlight(v)
orStmts[i] = model.NewSimpleStatement(strconv.Quote(k) + "=" + sprint(v))
simpleStat := model.NewSimpleStatement(strconv.Quote(k) + "=" + sprint(v))
simpleStat.WhereStatement = where_clause.NewInfixOp(where_clause.NewColumnRef(k), "=", where_clause.NewLiteral(sprint(v)))
orStmts[i] = simpleStat
}
return model.NewSimpleQuery(model.Or(orStmts), true)
}
Expand Down Expand Up @@ -527,7 +543,9 @@ func (cw *ClickhouseQueryTranslator) parseMatch(queryMap QueryMap, matchPhrase b
computedIdMatchingQuery := cw.parseIds(QueryMap{"values": []interface{}{subQuery}})
statements = append(statements, computedIdMatchingQuery.Sql)
} else {
statements = append(statements, model.NewSimpleStatement(strconv.Quote(fieldName)+" iLIKE "+"'%"+subQuery+"%'"))
simpleStat := model.NewSimpleStatement(strconv.Quote(fieldName) + " iLIKE " + "'%" + subQuery + "%'")
simpleStat.WhereStatement = where_clause.NewInfixOp(where_clause.NewColumnRef(fieldName), "iLIKE", where_clause.NewLiteral("'%"+subQuery+"%'"))
statements = append(statements, simpleStat)
}
}
return model.NewSimpleQuery(model.Or(statements), true)
Expand Down Expand Up @@ -557,19 +575,21 @@ func (cw *ClickhouseQueryTranslator) parseMultiMatch(queryMap QueryMap) model.Si
} else {
fields = cw.Table.GetFulltextFields()
}
alwaysFalseStmt := model.AlwaysFalseStatement
alwaysFalseStmt.WhereStatement = where_clause.NewLiteral("false")
if len(fields) == 0 {
return model.NewSimpleQuery(model.AlwaysFalseStatement, true)
return model.NewSimpleQuery(alwaysFalseStmt, true)
}

query, ok := queryMap["query"]
if !ok {
logger.WarnWithCtx(cw.Ctx).Msgf("no query in multi_match query: %v", queryMap)
return model.NewSimpleQuery(model.AlwaysFalseStatement, false)
return model.NewSimpleQuery(alwaysFalseStmt, false)
}
queryAsString, ok := query.(string)
if !ok {
logger.WarnWithCtx(cw.Ctx).Msgf("invalid query type: %T, value: %v", query, query)
return model.NewSimpleQuery(model.AlwaysFalseStatement, false)
return model.NewSimpleQuery(alwaysFalseStmt, false)
}
var subQueries []string
wereDone := false
Expand All @@ -595,7 +615,9 @@ func (cw *ClickhouseQueryTranslator) parseMultiMatch(queryMap QueryMap) model.Si
i := 0
for _, field := range fields {
for _, subQ := range subQueries {
sqls[i] = model.NewSimpleStatement(strconv.Quote(field) + " iLIKE '%" + subQ + "%'")
simpleStat := model.NewSimpleStatement(strconv.Quote(field) + " iLIKE '%" + subQ + "%'")
simpleStat.WhereStatement = where_clause.NewInfixOp(where_clause.NewColumnRef(field), "iLIKE", where_clause.NewLiteral("'%"+subQ+"%'"))
sqls[i] = simpleStat
i++
}
}
Expand All @@ -614,11 +636,15 @@ func (cw *ClickhouseQueryTranslator) parsePrefix(queryMap QueryMap) model.Simple
switch vCasted := v.(type) {
case string:
cw.AddTokenToHighlight(vCasted)
return model.NewSimpleQuery(model.NewSimpleStatement(strconv.Quote(fieldName)+" iLIKE '"+vCasted+"%'"), true)
simpleStat := model.NewSimpleStatement(strconv.Quote(fieldName) + " iLIKE '" + vCasted + "%'")
simpleStat.WhereStatement = where_clause.NewInfixOp(where_clause.NewColumnRef(fieldName), "iLIKE", where_clause.NewLiteral("'"+vCasted+"%'"))
return model.NewSimpleQuery(simpleStat, true)
case QueryMap:
token := vCasted["value"].(string)
cw.AddTokenToHighlight(token)
return model.NewSimpleQuery(model.NewSimpleStatement(strconv.Quote(fieldName)+" iLIKE '"+token+"%'"), true)
simpleStat := model.NewSimpleStatement(strconv.Quote(fieldName) + " iLIKE '" + token + "%'")
simpleStat.WhereStatement = where_clause.NewInfixOp(where_clause.NewColumnRef(fieldName), "iLIKE", where_clause.NewLiteral("'"+token+"%'"))
return model.NewSimpleQuery(simpleStat, true)
default:
logger.WarnWithCtx(cw.Ctx).Msgf("unsupported prefix type: %T, value: %v", v, v)
return model.NewSimpleQuery(model.NewSimpleStatement("unsupported prefix type"), false)
Expand All @@ -645,8 +671,10 @@ func (cw *ClickhouseQueryTranslator) parseWildcard(queryMap QueryMap) model.Simp
if value, ok := vAsMap["value"]; ok {
if valueAsString, ok := value.(string); ok {
cw.AddTokenToHighlight(valueAsString)
return model.NewSimpleQuery(model.NewSimpleStatement(strconv.Quote(fieldName)+" iLIKE '"+
strings.ReplaceAll(valueAsString, "*", "%")+"'"), true)
simpleStat := model.NewSimpleStatement(strconv.Quote(fieldName) + " iLIKE '" +
strings.ReplaceAll(valueAsString, "*", "%") + "'")
simpleStat.WhereStatement = where_clause.NewInfixOp(where_clause.NewColumnRef(fieldName), "iLIKE", where_clause.NewLiteral("'"+strings.ReplaceAll(valueAsString, "*", "%")+"'"))
return model.NewSimpleQuery(simpleStat, true)
} else {
logger.WarnWithCtx(cw.Ctx).Msgf("invalid value type: %T, value: %v", value, value)
return model.NewSimpleQuery(model.NewSimpleStatement("invalid value type"), false)
Expand Down Expand Up @@ -749,10 +777,13 @@ func (cw *ClickhouseQueryTranslator) parseRange(queryMap QueryMap) model.SimpleQ
}

for op, v := range v.(QueryMap) {
var fieldToPrint, timeFormatFuncName string
var valueToCompare where_clause.Statement
fieldType := cw.Table.GetDateTimeType(cw.Ctx, field)
vToPrint := sprint(v)
var fieldToPrint string
valueToCompare = where_clause.NewLiteral(vToPrint)
if !isDatetimeInDefaultFormat {
timeFormatFuncName = "toUnixTimestamp64Milli"
fieldToPrint = "toUnixTimestamp64Milli(" + strconv.Quote(field) + ")"
} else {
fieldToPrint = strconv.Quote(field)
Expand All @@ -762,7 +793,9 @@ func (cw *ClickhouseQueryTranslator) parseRange(queryMap QueryMap) model.SimpleQ
// if it's a date, we need to parse it to Clickhouse's DateTime format
// how to check if it does not contain date math expression?
if _, err := time.Parse(time.RFC3339Nano, dateTime); err == nil {
vToPrint = cw.parseDateTimeString(cw.Table, field, dateTime)
vToPrint, timeFormatFuncName = cw.parseDateTimeString(cw.Table, field, dateTime)
// TODO Investigate the quotation below
valueToCompare = where_clause.NewFunction(timeFormatFuncName, where_clause.NewLiteral(fmt.Sprintf("'%s'", dateTime)))
} else if op == "gte" || op == "lte" || op == "gt" || op == "lt" {
vToPrint, err = cw.parseDateMathExpression(vToPrint)
if err != nil {
Expand All @@ -772,6 +805,7 @@ func (cw *ClickhouseQueryTranslator) parseRange(queryMap QueryMap) model.SimpleQ
}
} else if v == nil {
vToPrint = "NULL"
valueToCompare = where_clause.NewLiteral("NULL")
}
case clickhouse.Invalid: // assumes it is number that does not need formatting
if len(vToPrint) > 2 && vToPrint[0] == '\'' && vToPrint[len(vToPrint)-1] == '\'' {
Expand All @@ -786,6 +820,7 @@ func (cw *ClickhouseQueryTranslator) parseRange(queryMap QueryMap) model.SimpleQ
} else {
logger.WarnWithCtx(cw.Ctx).Msgf("we use range with unknown literal %s, field %s", vToPrint, field)
}
valueToCompare = where_clause.NewLiteral(vToPrint)
}
default:
logger.WarnWithCtx(cw.Ctx).Msgf("invalid DateTime type for field: %s, parsed dateTime value: %s", field, vToPrint)
Expand All @@ -794,13 +829,21 @@ func (cw *ClickhouseQueryTranslator) parseRange(queryMap QueryMap) model.SimpleQ

switch op {
case "gte":
stmts = append(stmts, model.NewSimpleStatement(fieldToPrint+">="+vToPrint))
simpleStat := model.NewSimpleStatement(fieldToPrint + ">=" + vToPrint)
simpleStat.WhereStatement = where_clause.NewInfixOp(where_clause.NewColumnRef(field), ">=", valueToCompare)
stmts = append(stmts, simpleStat)
case "lte":
stmts = append(stmts, model.NewSimpleStatement(fieldToPrint+"<="+vToPrint))
simpleStat := model.NewSimpleStatement(fieldToPrint + "<=" + vToPrint)
simpleStat.WhereStatement = where_clause.NewInfixOp(where_clause.NewColumnRef(field), "<=", valueToCompare)
stmts = append(stmts, simpleStat)
case "gt":
stmts = append(stmts, model.NewSimpleStatement(fieldToPrint+">"+vToPrint))
simpleStat := model.NewSimpleStatement(fieldToPrint + ">" + vToPrint)
simpleStat.WhereStatement = where_clause.NewInfixOp(where_clause.NewColumnRef(field), ">", valueToCompare)
stmts = append(stmts, simpleStat)
case "lt":
stmts = append(stmts, model.NewSimpleStatement(fieldToPrint+"<"+vToPrint))
simpleStat := model.NewSimpleStatement(fieldToPrint + "<" + vToPrint)
simpleStat.WhereStatement = where_clause.NewInfixOp(where_clause.NewColumnRef(field), "<", valueToCompare)
stmts = append(stmts, simpleStat)
case "format":
// ignored
default:
Expand All @@ -816,16 +859,16 @@ func (cw *ClickhouseQueryTranslator) parseRange(queryMap QueryMap) model.SimpleQ
}

// parseDateTimeString builds the Clickhouse expression that parses dateTime for
// the given field. The parsing function is chosen from the field's DateTime type
// as reported by table.GetDateTimeType:
//   - clickhouse.DateTime64 -> parseDateTime64BestEffort
//   - clickhouse.DateTime   -> parseDateTimeBestEffort
//
// The two-value form returns the full call text ("<func>('<dateTime>')") followed
// by the bare function name. For any other type it logs an error and returns
// empty strings.
// NOTE(review): both the one-value and the two-value signature/return lines are
// present below, which looks like old and new diff lines rendered together -
// confirm which set belongs in the real file.
// NOTE(review): dateTime is concatenated into the expression unescaped; the
// visible caller only passes values that time.Parse accepted as RFC3339Nano -
// verify any other callers.
func (cw *ClickhouseQueryTranslator) parseDateTimeString(table *clickhouse.Table, field, dateTime string) string {
func (cw *ClickhouseQueryTranslator) parseDateTimeString(table *clickhouse.Table, field, dateTime string) (string, string) {
typ := table.GetDateTimeType(cw.Ctx, field)
switch typ {
case clickhouse.DateTime64:
return "parseDateTime64BestEffort('" + dateTime + "')"
return "parseDateTime64BestEffort('" + dateTime + "')", "parseDateTime64BestEffort"
case clickhouse.DateTime:
return "parseDateTimeBestEffort('" + dateTime + "')"
return "parseDateTimeBestEffort('" + dateTime + "')", "parseDateTimeBestEffort"
default:
// NOTE(review): %T prints the Go type name of typ rather than its value;
// %v may have been intended - confirm.
logger.Error().Msgf("invalid DateTime type: %T for field: %s, parsed dateTime value: %s", typ, field, dateTime)
return ""
return "", ""
}
}

Expand All @@ -846,19 +889,33 @@ func (cw *ClickhouseQueryTranslator) parseExists(queryMap QueryMap) model.Simple

switch cw.Table.GetFieldInfo(cw.Ctx, fieldName) {
case clickhouse.ExistsAndIsBaseType:
sql = model.NewSimpleStatement(fieldNameQuoted + " IS NOT NULL")
simpleStatement := model.NewSimpleStatement(fieldNameQuoted + " IS NOT NULL")
simpleStatement.WhereStatement = where_clause.NewInfixOp(where_clause.NewColumnRef(fieldNameQuoted), "IS", where_clause.NewLiteral("NOT NULL"))
statement := simpleStatement
sql = statement
case clickhouse.ExistsAndIsArray:
sql = model.NewSimpleStatement(fieldNameQuoted + ".size0 = 0")
statement := model.NewSimpleStatement(fieldNameQuoted + ".size0 = 0")
statement.WhereStatement = where_clause.NewInfixOp(where_clause.NewNestedProperty(
*where_clause.NewColumnRef(fieldNameQuoted),
*where_clause.NewLiteral("size0"),
), "=", where_clause.NewLiteral("0"))
sql = statement
case clickhouse.NotExists:
attrs := cw.Table.GetAttributesList()
stmts := make([]model.Statement, len(attrs))
for i, a := range attrs {
stmts[i] = model.NewCompoundStatementNoFieldName(
compoundStatementNoFieldName := model.NewCompoundStatementNoFieldName(
fmt.Sprintf("has(%s,%s) AND %s[indexOf(%s,%s)] IS NOT NULL",
strconv.Quote(a.KeysArrayName), fieldNameQuoted, strconv.Quote(a.ValuesArrayName),
strconv.Quote(a.KeysArrayName), fieldNameQuoted,
),
)
compoundStatementNoFieldName.WhereStatement = nil
hasFunc := where_clause.NewFunction("has", []where_clause.Statement{where_clause.NewColumnRef(a.KeysArrayName), where_clause.NewColumnRef(fieldName)}...)
arrayAccess := where_clause.NewArrayAccess(*where_clause.NewColumnRef(a.ValuesArrayName), where_clause.NewFunction("indexOf", []where_clause.Statement{where_clause.NewColumnRef(a.KeysArrayName), where_clause.NewLiteral(fieldNameQuoted)}...))
isNotNull := where_clause.NewInfixOp(arrayAccess, "IS", where_clause.NewLiteral("NOT NULL"))
compoundStatementNoFieldName.WhereStatement = where_clause.NewInfixOp(hasFunc, "AND", isNotNull)
stmts[i] = compoundStatementNoFieldName
}
sql = model.Or(stmts)
default:
Expand Down
Loading

0 comments on commit 5cb86d5

Please sign in to comment.