Skip to content

Commit

Permalink
[sample_flights] Fix total in top_hits response (#1337)
Browse files Browse the repository at this point in the history
Makes the last skipped `sample_flights` test pass.
In other aggregations we have `parent_count` for that. Here I don't add
a new column to select, just extract from the existing ones, as it's
always there and I find it much simpler to fix it this way. We can
change that later if needed.
  • Loading branch information
trzysiek authored Mar 3, 2025
1 parent 97eec91 commit 5b7ed7f
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 23 deletions.
35 changes: 32 additions & 3 deletions platform/model/metrics_aggregations/top_hits.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"context"
"github.com/QuesmaOrg/quesma/platform/logger"
"github.com/QuesmaOrg/quesma/platform/model"
"github.com/QuesmaOrg/quesma/platform/util"
"strings"
)

type TopHits struct {
Expand All @@ -23,7 +25,6 @@ func (query *TopHits) AggregationType() model.AggregationType {
return model.MetricsAggregation
}

// TODO: implement correct
func (query *TopHits) TranslateSqlResponseToJson(rows []model.QueryResultRow) model.JsonMap {
var topElems []any
if len(rows) > 0 && 0 >= len(rows[0].Cols) {
Expand All @@ -39,7 +40,13 @@ func (query *TopHits) TranslateSqlResponseToJson(rows []model.QueryResultRow) mo
continue
}

valuesForHits := row.Cols
var valuesForHits []model.QueryResultCol
if query.isCount(row.Cols[0]) {
valuesForHits = row.Cols[1:]
} else {
valuesForHits = row.Cols
}

sourceMap := model.JsonMap{}

for _, col := range valuesForHits {
Expand All @@ -63,13 +70,19 @@ func (query *TopHits) TranslateSqlResponseToJson(rows []model.QueryResultRow) mo
if len(topElems) == 0 {
maxScore = nil
}

var total int
if len(rows) > 0 {
total = query.getCount(&rows[0])
}

return model.JsonMap{
"hits": model.JsonMap{
"hits": topElems,
"max_score": maxScore, // placeholder
"total": model.JsonMap{ // could be better
"relation": "eq", // TODO: wrong, but let's pass test, it should ge geq
"value": len(topElems),
"value": total,
},
},
}
Expand All @@ -78,3 +91,19 @@ func (query *TopHits) TranslateSqlResponseToJson(rows []model.QueryResultRow) mo
// String returns the name of this aggregation type: "top_hits".
func (query *TopHits) String() string {
	const aggregationName = "top_hits"
	return aggregationName
}

// getCount returns the value of the first column of row, converted to int.
//
// It returns 0 when the row has no columns. It also returns 0, after logging
// a throttled warning, when the first column's value cannot be extracted as
// an int64.
// NOTE(review): this reads Cols[0] unconditionally, so it presumably relies
// on the count column being placed first by the caller; confirm against
// the code that builds the rows.
func (query *TopHits) getCount(row *model.QueryResultRow) int {
	if len(row.Cols) == 0 {
		return 0
	}

	count, ok := util.ExtractInt64Maybe(row.Cols[0].ExtractValue())
	if !ok {
		logger.WarnWithCtxAndThrottling(query.ctx, "top_hits", "count", "could not extract count from top_hits, row: %v", row)
		return 0
	}
	return int(count)
}

// isCount reports whether col is treated as a count column, which is decided
// purely by its name ending in "count".
func (query *TopHits) isCount(col model.QueryResultCol) bool {
	const countSuffix = "count"
	return strings.HasSuffix(col.ColName, countSuffix)
}
3 changes: 3 additions & 0 deletions platform/parsers/elastic_query_dsl/pancake_json_rendering.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ func (p *pancakeJSONRenderer) selectTopHitsRows(topAggr *pancakeModelMetricAggre
}
newCols = append(newCols, col)
}
} else if topAggr.isColumnParentCount(col.ColName) {
// top_hits needs parent count, when it's available
newCols = append(newCols, col)
}
}
result = append(result, model.QueryResultRow{Index: row.Index, Cols: newCols})
Expand Down
19 changes: 19 additions & 0 deletions platform/parsers/elastic_query_dsl/pancake_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,25 @@ func (p pancakeModelMetricAggregation) InternalNameForCol(id int) string {
return fmt.Sprintf("%s%d", p.InternalNamePrefix(), id)
}

// Patterns used by isColumnParentCount. They are compiled once at package
// initialisation rather than on every call, because the method is evaluated
// for individual result columns.
var (
	// topHitsInternalNameRegex recognises "top_hits__[AGG_PATH]__[name]".
	// Capturing group 1 is AGG_PATH.
	topHitsInternalNameRegex = regexp.MustCompile("top_hits__([a-zA-Z0-9_]+)__[a-zA-Z0-9_]+")

	// parentCountColumnRegex recognises exactly "aggr__[AGG_PATH]__count".
	// Capturing group 1 is AGG_PATH. It is anchored on both ends so that
	// longer column names which merely contain "__count" (for example
	// "aggr__[AGG_PATH]__count__key_0", produced by an aggregation that is
	// itself named "count") are not mistaken for a parent count column.
	parentCountColumnRegex = regexp.MustCompile("^aggr__([a-zA-Z0-9_]+)__count$")
)

// isColumnParentCount checks if `internalNameMaybeParent` is a parent count column for this metric aggregation.
// Only tested/works for `top_hits`, not needed anywhere else.
//
// It returns true only when:
//
//	p.internalName == "top_hits__[AGG_PATH]__[name]"
//	AND internalNameMaybeParent == "aggr__[AGG_PATH]__count"
//
// where AGG_PATH is the same in both names.
func (p pancakeModelMetricAggregation) isColumnParentCount(internalNameMaybeParent string) bool {
	if !topHitsInternalNameRegex.MatchString(p.internalName) {
		return false
	}

	// FindStringSubmatch returns nil when there is no match, so the length
	// checks below also reject names that do not fit the expected patterns.
	matchThisAggr := topHitsInternalNameRegex.FindStringSubmatch(p.InternalNamePrefix())
	matchMaybeParent := parentCountColumnRegex.FindStringSubmatch(internalNameMaybeParent)

	// [1] is the first capturing group in the regex (called AGG_PATH above).
	return len(matchThisAggr) == 2 && len(matchMaybeParent) == 2 && matchThisAggr[1] == matchMaybeParent[1]
}

func (p pancakeModelBucketAggregation) ShallowClone() pancakeModelBucketAggregation {
return pancakeModelBucketAggregation{
name: p.name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,6 @@ func TestPancakeQueryGeneration(t *testing.T) {

for i, test := range allAggregationTests() {
t.Run(test.TestName+"("+strconv.Itoa(i)+")", func(t *testing.T) {
// sample_flights
if test.TestName == "TODO Airport Connections (Hover Over Airport)(file:kibana-sample-data-flights,nr:14)" {
t.Skip("Fixing right now")
}
// sample_ecommerce
if test.TestName == "TODO Top products this/last week(file:kibana-sample-data-ecommerce,nr:9)" {
t.Skip("works IRL, need to update test's schema. It's already WIP https://github.com/QuesmaOrg/quesma/pull/1255. Let's wait for merge.")
Expand Down
88 changes: 72 additions & 16 deletions platform/testdata/kibana_sample_data_flights.go
Original file line number Diff line number Diff line change
Expand Up @@ -2735,8 +2735,8 @@ var KibanaSampleDataFlights = []AggregationTestCase{
"_score": 1.0,
"_source": {
"DestLocation": {
"lat": "-34.8222",
"lon": "-58.5358"
"lat": -34.8222,
"lon": -58.5358
}
}
}
Expand All @@ -2761,8 +2761,8 @@ var KibanaSampleDataFlights = []AggregationTestCase{
"_score": 1.0,
"_source": {
"DestLocation": {
"lat": "-0.129166667",
"lon": "-78.3575"
"lat": -0.129166667,
"lon": -78.3575
}
}
}
Expand Down Expand Up @@ -2793,8 +2793,8 @@ var KibanaSampleDataFlights = []AggregationTestCase{
"_source": {
"Origin": "Mariscal Sucre International Airport",
"OriginLocation": {
"lat": "-0.129166667",
"lon": "-78.3575"
"lat": -0.129166667,
"lon": -78.3575
}
}
}
Expand All @@ -2820,8 +2820,8 @@ var KibanaSampleDataFlights = []AggregationTestCase{
"_score": 1.0,
"_source": {
"DestLocation": {
"lat": "45.47060013",
"lon": "-73.74079895"
"lat": 45.47060013,
"lon": -73.74079895
}
}
}
Expand All @@ -2846,8 +2846,8 @@ var KibanaSampleDataFlights = []AggregationTestCase{
"_score": 1.0,
"_source": {
"DestLocation": {
"lat": "-34.8222",
"lon": "-58.5358"
"lat": -34.8222,
"lon": -58.5358
}
}
}
Expand Down Expand Up @@ -2878,8 +2878,8 @@ var KibanaSampleDataFlights = []AggregationTestCase{
"_source": {
"Origin": "Ministro Pistarini International Airport",
"OriginLocation": {
"lat": "-34.8222",
"lon": "-58.5358"
"lat": -34.8222,
"lon": -58.5358
}
}
}
Expand All @@ -2894,15 +2894,15 @@ var KibanaSampleDataFlights = []AggregationTestCase{
}
],
"doc_count_error_upper_bound": 0,
"sum_other_doc_count": 12474
"sum_other_doc_count": 1460
}
},
"hits": {
"hits": [],
"max_score": null,
"total": {
"relation": "eq",
"value": 13014
"value": 2000
}
},
"timed_out": false,
Expand All @@ -2912,15 +2912,71 @@ var KibanaSampleDataFlights = []AggregationTestCase{
}`,
ExpectedPancakeResults: []model.QueryResultRow{
{Cols: []model.QueryResultCol{
model.NewQueryResultCol("aggr__origins__parent_count", int64(283)),
model.NewQueryResultCol("aggr__origins__parent_count", int64(2000)),
model.NewQueryResultCol("aggr__origins__key_0", "UIO"),
model.NewQueryResultCol("aggr__origins__count", int64(283)),
model.NewQueryResultCol("aggr__origins__distinations__parent_count", int64(283)),
model.NewQueryResultCol("aggr__origins__distinations__key_0", "EZE"),
model.NewQueryResultCol("aggr__origins__distinations__count", int64(21)),
model.NewQueryResultCol("top_hits__origins__distinations__destLocation_col_0", "[-34.8222, -58.5358]"),
model.NewQueryResultCol("top_hits__origins__distinations__destLocation_col_0", model.JsonMap{"lat": -34.8222, "lon": -58.5358}),
model.NewQueryResultCol("top_hits_rank", int64(1)),
}},
{Cols: []model.QueryResultCol{
model.NewQueryResultCol("aggr__origins__parent_count", int64(2000)),
model.NewQueryResultCol("aggr__origins__key_0", "UIO"),
model.NewQueryResultCol("aggr__origins__count", int64(283)),
model.NewQueryResultCol("aggr__origins__distinations__parent_count", int64(283)),
model.NewQueryResultCol("aggr__origins__distinations__key_0", "UIO"),
model.NewQueryResultCol("aggr__origins__distinations__count", int64(12)),
model.NewQueryResultCol("top_hits__origins__distinations__destLocation_col_0", model.JsonMap{"lat": -0.129167, "lon": -78.3575}),
model.NewQueryResultCol("top_hits_rank", int64(1)),
}},
{Cols: []model.QueryResultCol{
model.NewQueryResultCol("aggr__origins__parent_count", int64(2000)),
model.NewQueryResultCol("aggr__origins__key_0", "EZE"),
model.NewQueryResultCol("aggr__origins__count", int64(257)),
model.NewQueryResultCol("aggr__origins__distinations__parent_count", int64(257)),
model.NewQueryResultCol("aggr__origins__distinations__key_0", "YUL"),
model.NewQueryResultCol("aggr__origins__distinations__count", int64(11)),
model.NewQueryResultCol("top_hits__origins__distinations__destLocation_col_0", model.JsonMap{"lat": 45.470600, "lon": -73.740799}),
model.NewQueryResultCol("top_hits_rank", int64(1)),
}},
{Cols: []model.QueryResultCol{
model.NewQueryResultCol("aggr__origins__parent_count", int64(2000)),
model.NewQueryResultCol("aggr__origins__key_0", "EZE"),
model.NewQueryResultCol("aggr__origins__count", int64(257)),
model.NewQueryResultCol("aggr__origins__distinations__parent_count", int64(257)),
model.NewQueryResultCol("aggr__origins__distinations__key_0", "EZE"),
model.NewQueryResultCol("aggr__origins__distinations__count", int64(10)),
model.NewQueryResultCol("top_hits__origins__distinations__destLocation_col_0", model.JsonMap{"lat": -34.822200, "lon": -58.535800}),
model.NewQueryResultCol("top_hits_rank", int64(1)),
}},
},
ExpectedAdditionalPancakeResults: [][]model.QueryResultRow{
{
{Cols: []model.QueryResultCol{
model.NewQueryResultCol("aggr__origins__parent_count", int64(2000)),
model.NewQueryResultCol("aggr__origins__key_0", "UIO"),
model.NewQueryResultCol("aggr__origins__count", int64(283)),
model.NewQueryResultCol("aggr__origins__distinations__parent_count", int64(283)),
model.NewQueryResultCol("aggr__origins__distinations__key_0", "EZE"),
model.NewQueryResultCol("aggr__origins__distinations__count", int64(21)),
model.NewQueryResultCol("top_hits__origins__originLocation_col_0", model.JsonMap{"lat": -0.129167, "lon": -78.3575}),
model.NewQueryResultCol("top_hits__origins__originLocation_col_1", "Mariscal Sucre International Airport"),
model.NewQueryResultCol("top_hits_rank", int64(1)),
}},
{Cols: []model.QueryResultCol{
model.NewQueryResultCol("aggr__origins__parent_count", int64(2000)),
model.NewQueryResultCol("aggr__origins__key_0", "EZE"),
model.NewQueryResultCol("aggr__origins__count", int64(257)),
model.NewQueryResultCol("aggr__origins__distinations__parent_count", int64(257)),
model.NewQueryResultCol("aggr__origins__distinations__key_0", "YUL"),
model.NewQueryResultCol("aggr__origins__distinations__count", int64(11)),
model.NewQueryResultCol("top_hits__origins__originLocation_col_0", model.JsonMap{"lat": -34.822200, "lon": -58.535800}),
model.NewQueryResultCol("top_hits__origins__originLocation_col_1", "Ministro Pistarini International Airport"),
model.NewQueryResultCol("top_hits_rank", int64(1)),
}},
},
},
ExpectedPancakeSQL: `
WITH quesma_top_hits_group_table AS (
Expand Down

0 comments on commit 5b7ed7f

Please sign in to comment.