diff --git a/conn.go b/conn.go index dea7428..6ca2b68 100644 --- a/conn.go +++ b/conn.go @@ -26,7 +26,7 @@ type Conn struct { queryInterval time.Duration localNames []string - queries []query + queries []*query ifaces []net.Interface closed chan interface{} @@ -111,7 +111,7 @@ func Server(conn *ipv4.PacketConn, config *Config) (*Conn, error) { c := &Conn{ queryInterval: defaultQueryInterval, - queries: []query{}, + queries: []*query{}, socket: conn, dstAddr: dstAddr, localNames: localNames, @@ -163,11 +163,22 @@ func (c *Conn) Query(ctx context.Context, name string) (dnsmessage.ResourceHeade nameWithSuffix := name + "." queryChan := make(chan queryResult, 1) + query := &query{nameWithSuffix, queryChan} c.mu.Lock() - c.queries = append(c.queries, query{nameWithSuffix, queryChan}) - ticker := time.NewTicker(c.queryInterval) + c.queries = append(c.queries, query) c.mu.Unlock() + defer func() { + c.mu.Lock() + defer c.mu.Unlock() + for i := len(c.queries) - 1; i >= 0; i-- { + if c.queries[i] == query { + c.queries = append(c.queries[:i], c.queries[i+1:]...) + } + } + }() + + ticker := time.NewTicker(c.queryInterval) defer ticker.Stop() c.sendQuestion(nameWithSuffix) diff --git a/conn_test.go b/conn_test.go index 5564b5d..d480361 100644 --- a/conn_test.go +++ b/conn_test.go @@ -69,6 +69,13 @@ func TestValidCommunication(t *testing.T) { check(aServer.Close(), t) check(bServer.Close(), t) + + if len(aServer.queries) > 0 { + t.Fatalf("Queries not cleaned up after aServer close") + } + if len(bServer.queries) > 0 { + t.Fatalf("Queries not cleaned up after bServer close") + } } func TestValidCommunicationWithAddressConfig(t *testing.T) { @@ -93,6 +100,9 @@ func TestValidCommunicationWithAddressConfig(t *testing.T) { } check(aServer.Close(), t) + if len(aServer.queries) > 0 { + t.Fatalf("Queries not cleaned up after aServer close") + } } func TestMultipleClose(t *testing.T) { @@ -109,6 +119,10 @@ func TestMultipleClose(t *testing.T) { check(server.Close(), t) check(server.Close(), t) + + if len(server.queries) > 0 { + t.Fatalf("Queries not cleaned up after server close") + } } func TestQueryRespectTimeout(t *testing.T) { @@ -133,6 +147,10 @@ func TestQueryRespectTimeout(t *testing.T) { if closeErr := server.Close(); closeErr != nil { t.Fatal(closeErr) } + + if len(server.queries) > 0 { + t.Fatalf("Queries not cleaned up after context expiration") + } } func TestQueryRespectClose(t *testing.T) { @@ -159,6 +177,10 @@ func TestQueryRespectClose(t *testing.T) { if _, _, err = server.Query(context.TODO(), "invalid-host"); !errors.Is(err, errConnectionClosed) { t.Fatalf("Query on closed server but returned unexpected error %v", err) } + + if len(server.queries) > 0 { + t.Fatalf("Queries not cleaned up after query") + } } func TestResourceParsing(t *testing.T) {