From 225c67e5eba566ca68099a7b8deefea00fe478bf Mon Sep 17 00:00:00 2001 From: Simon Cousineau Date: Wed, 7 Feb 2024 04:36:49 +0000 Subject: [PATCH] Cleanup expired queries --- conn.go | 9 +++++++++ conn_test.go | 22 ++++++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/conn.go b/conn.go index dea7428..8100e27 100644 --- a/conn.go +++ b/conn.go @@ -169,6 +169,15 @@ func (c *Conn) Query(ctx context.Context, name string) (dnsmessage.ResourceHeade c.mu.Unlock() defer ticker.Stop() + defer func() { + c.mu.Lock() + defer c.mu.Unlock() + for i := len(c.queries) - 1; i >= 0; i-- { + if c.queries[i].nameWithSuffix == nameWithSuffix { + c.queries = append(c.queries[:i], c.queries[i+1:]...) + } + } + }() c.sendQuestion(nameWithSuffix) for { diff --git a/conn_test.go b/conn_test.go index 5564b5d..d480361 100644 --- a/conn_test.go +++ b/conn_test.go @@ -69,6 +69,13 @@ func TestValidCommunication(t *testing.T) { check(aServer.Close(), t) check(bServer.Close(), t) + + if len(aServer.queries) > 0 { + t.Fatalf("Queries not cleaned up after aServer close") + } + if len(bServer.queries) > 0 { + t.Fatalf("Queries not cleaned up after bServer close") + } } func TestValidCommunicationWithAddressConfig(t *testing.T) { @@ -93,6 +100,9 @@ func TestValidCommunicationWithAddressConfig(t *testing.T) { } check(aServer.Close(), t) + if len(aServer.queries) > 0 { + t.Fatalf("Queries not cleaned up after aServer close") + } } func TestMultipleClose(t *testing.T) { @@ -109,6 +119,10 @@ func TestMultipleClose(t *testing.T) { check(server.Close(), t) check(server.Close(), t) + + if len(server.queries) > 0 { + t.Fatalf("Queries not cleaned up after server close") + } } func TestQueryRespectTimeout(t *testing.T) { @@ -133,6 +147,10 @@ func TestQueryRespectTimeout(t *testing.T) { if closeErr := server.Close(); closeErr != nil { t.Fatal(closeErr) } + + if len(server.queries) > 0 { + t.Fatalf("Queries not cleaned up after context expiration") + } } func TestQueryRespectClose(t *testing.T) { @@ -159,6 +177,10 @@ func TestQueryRespectClose(t *testing.T) { if _, _, err = server.Query(context.TODO(), "invalid-host"); !errors.Is(err, errConnectionClosed) { t.Fatalf("Query on closed server but returned unexpected error %v", err) } + + if len(server.queries) > 0 { + t.Fatalf("Queries not cleaned up after query") + } } func TestResourceParsing(t *testing.T) {