From 667f281872f0b618b4d56c62c91e3e7a605dd26c Mon Sep 17 00:00:00 2001
From: Vinny <yehonatan.weinberger@supertenant.com>
Date: Tue, 20 Feb 2024 21:29:56 +0000
Subject: [PATCH] fix(query): add with totals case (issue #382)

---
 query.go      |  2 +-
 query_test.go | 31 +++++++++++++++++++++++++++++++
 2 files changed, 32 insertions(+), 1 deletion(-)

diff --git a/query.go b/query.go
index f8aabbe2..f3307e07 100644
--- a/query.go
+++ b/query.go
@@ -720,7 +720,7 @@ func (c *Client) Do(ctx context.Context, q Query) (err error) {
 				return errors.Wrap(err, "packet")
 			}
 			switch code {
-			case proto.ServerCodeData:
+			case proto.ServerCodeData, proto.ServerCodeTotals:
 				if err := c.decodeBlock(ctx, decodeOptions{
 					Handler:      onResult,
 					Result:       q.Result,
diff --git a/query_test.go b/query_test.go
index 6bf626c2..a752bc3c 100644
--- a/query_test.go
+++ b/query_test.go
@@ -27,6 +27,37 @@ func requireEqual[T any](t *testing.T, a, b proto.ColumnOf[T]) {
 	}
 }
 
+func TestWithTotals(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+	conn := Conn(t)
+	var n proto.ColUInt64
+	var c proto.ColUInt64
+
+	var data []uint64
+	query := Query{
+		Body: `
+			SELECT
+				number AS n,
+				COUNT() AS c
+			FROM (
+				SELECT number FROM system.numbers LIMIT 100
+			) GROUP BY n WITH TOTALS
+		`,
+		Result: proto.Results{
+			{Name: "n", Data: &n},
+			{Name: "c", Data: &c},
+		},
+		OnResult: func(ctx context.Context, b proto.Block) error {
+			data = append(data, c...)
+			return nil
+		},
+	}
+	require.NoError(t, conn.Do(ctx, query))
+	require.Equal(t, 101, len(data))
+	require.Equal(t, uint64(100), data[100])
+}
+
 func TestDateTimeOverflow(t *testing.T) {
 	t.Parallel()
 	ctx := context.Background()