diff --git a/conn.go b/conn.go index ed51e55..1f14607 100644 --- a/conn.go +++ b/conn.go @@ -26,7 +26,7 @@ type Conn struct { queryInterval time.Duration localNames []string - queries []query + queries []*query ifaces []net.Interface closed chan interface{} @@ -121,7 +121,7 @@ func Server(conn *ipv4.PacketConn, config *Config) (*Conn, error) { c := &Conn{ queryInterval: defaultQueryInterval, - queries: []query{}, + queries: []*query{}, socket: conn, dstAddr: dstAddr, localNames: localNames, @@ -181,11 +181,22 @@ func (c *Conn) Query(ctx context.Context, name string) (dnsmessage.ResourceHeade nameWithSuffix := name + "." queryChan := make(chan queryResult, 1) + query := &query{nameWithSuffix, queryChan} c.mu.Lock() - c.queries = append(c.queries, query{nameWithSuffix, queryChan}) - ticker := time.NewTicker(c.queryInterval) + c.queries = append(c.queries, query) c.mu.Unlock() + defer func() { + c.mu.Lock() + defer c.mu.Unlock() + for i := len(c.queries) - 1; i >= 0; i-- { + if c.queries[i] == query { + c.queries = append(c.queries[:i], c.queries[i+1:]...) + } + } + }() + + ticker := time.NewTicker(c.queryInterval) defer ticker.Stop() c.sendQuestion(nameWithSuffix) diff --git a/conn_test.go b/conn_test.go index 69f141e..faa6ad4 100644 --- a/conn_test.go +++ b/conn_test.go @@ -103,6 +103,13 @@ func TestValidCommunication(t *testing.T) { check(aServer.Close(), t) check(bServer.Close(), t) + + if len(aServer.queries) > 0 { + t.Fatalf("Queries not cleaned up after aServer close") + } + if len(bServer.queries) > 0 { + t.Fatalf("Queries not cleaned up after bServer close") + } } func TestValidCommunicationWithAddressConfig(t *testing.T) { @@ -127,6 +134,9 @@ func TestValidCommunicationWithAddressConfig(t *testing.T) { } check(aServer.Close(), t) + if len(aServer.queries) > 0 { + t.Fatalf("Queries not cleaned up after aServer close") + } } func TestValidCommunicationWithLoopbackAddressConfig(t *testing.T) { @@ -230,6 +240,10 @@ func TestMultipleClose(t *testing.T) { check(server.Close(), t) check(server.Close(), t) + + if len(server.queries) > 0 { + t.Fatalf("Queries not cleaned up after server close") + } } func TestQueryRespectTimeout(t *testing.T) { @@ -254,6 +268,10 @@ func TestQueryRespectTimeout(t *testing.T) { if closeErr := server.Close(); closeErr != nil { t.Fatal(closeErr) } + + if len(server.queries) > 0 { + t.Fatalf("Queries not cleaned up after context expiration") + } } func TestQueryRespectClose(t *testing.T) { @@ -280,6 +298,10 @@ func TestQueryRespectClose(t *testing.T) { if _, _, err = server.Query(context.TODO(), "invalid-host"); !errors.Is(err, errConnectionClosed) { t.Fatalf("Query on closed server but returned unexpected error %v", err) } + + if len(server.queries) > 0 { + t.Fatalf("Queries not cleaned up after query") + } } func TestResourceParsing(t *testing.T) {