diff --git a/query.go b/query.go index f8aabbe2..f3307e07 100644 --- a/query.go +++ b/query.go @@ -720,7 +720,7 @@ func (c *Client) Do(ctx context.Context, q Query) (err error) { return errors.Wrap(err, "packet") } switch code { - case proto.ServerCodeData: + case proto.ServerCodeData, proto.ServerCodeTotals: if err := c.decodeBlock(ctx, decodeOptions{ Handler: onResult, Result: q.Result, diff --git a/query_test.go b/query_test.go index 6bf626c2..a752bc3c 100644 --- a/query_test.go +++ b/query_test.go @@ -27,6 +27,37 @@ func requireEqual[T any](t *testing.T, a, b proto.ColumnOf[T]) { } } +func TestWithTotals(t *testing.T) { + t.Parallel() + ctx := context.Background() + conn := Conn(t) + var n proto.ColUInt64 + var c proto.ColUInt64 + + var data []uint64 + query := Query{ + Body: ` + SELECT + number AS n, + COUNT() AS c + FROM ( + SELECT number FROM system.numbers LIMIT 100 + ) GROUP BY n WITH TOTALS + `, + Result: proto.Results{ + {Name: "n", Data: &n}, + {Name: "c", Data: &c}, + }, + OnResult: func(ctx context.Context, b proto.Block) error { + data = append(data, c...) + return nil + }, + } + require.NoError(t, conn.Do(ctx, query)) + require.Equal(t, 101, len(data)) + require.Equal(t, uint64(100), data[100]) +} + func TestDateTimeOverflow(t *testing.T) { t.Parallel() ctx := context.Background()