From 950fc4a81d79894e0cca889a7a18a936155d50cb Mon Sep 17 00:00:00 2001 From: Simon Cousineau Date: Wed, 7 Feb 2024 04:36:49 +0000 Subject: [PATCH 1/3] Cleanup expired queries --- conn.go | 9 +++++++++ conn_test.go | 22 ++++++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/conn.go b/conn.go index dea7428..8100e27 100644 --- a/conn.go +++ b/conn.go @@ -169,6 +169,15 @@ func (c *Conn) Query(ctx context.Context, name string) (dnsmessage.ResourceHeade c.mu.Unlock() defer ticker.Stop() + defer func() { + c.mu.Lock() + defer c.mu.Unlock() + for i := len(c.queries) - 1; i >= 0; i-- { + if c.queries[i].nameWithSuffix == nameWithSuffix { + c.queries = append(c.queries[:i], c.queries[i+1:]...) + } + } + }() c.sendQuestion(nameWithSuffix) for { diff --git a/conn_test.go b/conn_test.go index 5564b5d..d480361 100644 --- a/conn_test.go +++ b/conn_test.go @@ -69,6 +69,13 @@ func TestValidCommunication(t *testing.T) { check(aServer.Close(), t) check(bServer.Close(), t) + + if len(aServer.queries) > 0 { + t.Fatalf("Queries not cleaned up after aServer close") + } + if len(bServer.queries) > 0 { + t.Fatalf("Queries not cleaned up after bServer close") + } } func TestValidCommunicationWithAddressConfig(t *testing.T) { @@ -93,6 +100,9 @@ func TestValidCommunicationWithAddressConfig(t *testing.T) { } check(aServer.Close(), t) + if len(aServer.queries) > 0 { + t.Fatalf("Queries not cleaned up after aServer close") + } } func TestMultipleClose(t *testing.T) { @@ -109,6 +119,10 @@ func TestMultipleClose(t *testing.T) { check(server.Close(), t) check(server.Close(), t) + + if len(server.queries) > 0 { + t.Fatalf("Queries not cleaned up after server close") + } } func TestQueryRespectTimeout(t *testing.T) { @@ -133,6 +147,10 @@ func TestQueryRespectTimeout(t *testing.T) { if closeErr := server.Close(); closeErr != nil { t.Fatal(closeErr) } + + if len(server.queries) > 0 { + t.Fatalf("Queries not cleaned up after context expiration") + } } func TestQueryRespectClose(t *testing.T) { @@ -159,6 +177,10 @@ func TestQueryRespectClose(t *testing.T) { if _, _, err = server.Query(context.TODO(), "invalid-host"); !errors.Is(err, errConnectionClosed) { t.Fatalf("Query on closed server but returned unexpected error %v", err) } + + if len(server.queries) > 0 { + t.Fatalf("Queries not cleaned up after query") + } } func TestResourceParsing(t *testing.T) { From 9ec78351db16c2c2a30f31e3bd3499b40da15338 Mon Sep 17 00:00:00 2001 From: Simon Cousineau Date: Wed, 7 Feb 2024 17:34:50 +0000 Subject: [PATCH 2/3] Cleanup queries by reference instead of name --- conn.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/conn.go b/conn.go index 8100e27..6fe40c8 100644 --- a/conn.go +++ b/conn.go @@ -26,7 +26,7 @@ type Conn struct { queryInterval time.Duration localNames []string - queries []query + queries []*query ifaces []net.Interface closed chan interface{} @@ -111,7 +111,7 @@ func Server(conn *ipv4.PacketConn, config *Config) (*Conn, error) { c := &Conn{ queryInterval: defaultQueryInterval, - queries: []query{}, + queries: []*query{}, socket: conn, dstAddr: dstAddr, localNames: localNames, @@ -163,8 +163,9 @@ func (c *Conn) Query(ctx context.Context, name string) (dnsmessage.ResourceHeade nameWithSuffix := name + "." queryChan := make(chan queryResult, 1) + query := &query{nameWithSuffix, queryChan} c.mu.Lock() - c.queries = append(c.queries, query{nameWithSuffix, queryChan}) + c.queries = append(c.queries, query) ticker := time.NewTicker(c.queryInterval) c.mu.Unlock() @@ -173,7 +174,7 @@ func (c *Conn) Query(ctx context.Context, name string) (dnsmessage.ResourceHeade c.mu.Lock() defer c.mu.Unlock() for i := len(c.queries) - 1; i >= 0; i-- { - if c.queries[i].nameWithSuffix == nameWithSuffix { + if c.queries[i] == query { c.queries = append(c.queries[:i], c.queries[i+1:]...) } } From d734c9640632b6417f60c5f351742a431797aff3 Mon Sep 17 00:00:00 2001 From: Simon Cousineau Date: Wed, 7 Feb 2024 17:36:42 +0000 Subject: [PATCH 3/3] Create ticker outside critical section --- conn.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/conn.go b/conn.go index 6fe40c8..6ca2b68 100644 --- a/conn.go +++ b/conn.go @@ -166,10 +166,8 @@ func (c *Conn) Query(ctx context.Context, name string) (dnsmessage.ResourceHeade query := &query{nameWithSuffix, queryChan} c.mu.Lock() c.queries = append(c.queries, query) - ticker := time.NewTicker(c.queryInterval) c.mu.Unlock() - defer ticker.Stop() defer func() { c.mu.Lock() defer c.mu.Unlock() @@ -180,6 +178,9 @@ func (c *Conn) Query(ctx context.Context, name string) (dnsmessage.ResourceHeade } }() + ticker := time.NewTicker(c.queryInterval) + defer ticker.Stop() + c.sendQuestion(nameWithSuffix) for { select {