diff --git a/conn.go b/conn.go index 87a3e90..9610312 100644 --- a/conn.go +++ b/conn.go @@ -37,6 +37,8 @@ type Conn struct { queries []*query ifaces map[int]netInterface + echoQueryInAnswer bool + closed chan interface{} } @@ -102,9 +104,10 @@ func Server( log := loggerFactory.NewLogger("mdns") c := &Conn{ - queryInterval: defaultQueryInterval, - log: log, - closed: make(chan interface{}), + queryInterval: defaultQueryInterval, + log: log, + closed: make(chan interface{}), + echoQueryInAnswer: true, } c.name = config.Name if c.name == "" { @@ -710,7 +713,7 @@ func (c *Conn) writeToSocket( } } -func createAnswer(id uint16, q dnsmessage.Question, addr netip.Addr) (dnsmessage.Message, error) { +func (c *Conn) createAnswer(id uint16, q dnsmessage.Question, addr netip.Addr) (dnsmessage.Message, error) { packedName, err := dnsmessage.NewName(q.Name.String()) if err != nil { return dnsmessage.Message{}, err @@ -722,7 +725,6 @@ func createAnswer(id uint16, q dnsmessage.Question, addr netip.Addr) (dnsmessage Response: true, Authoritative: true, }, - Questions: []dnsmessage.Question{q}, Answers: []dnsmessage.Resource{ { Header: dnsmessage.ResourceHeader{ @@ -734,6 +736,10 @@ func createAnswer(id uint16, q dnsmessage.Question, addr netip.Addr) (dnsmessage }, } + if c.echoQueryInAnswer { + msg.Questions = []dnsmessage.Question{q} + } + if addr.Is4() { ipBuf, err := ipv4ToBytes(addr) if err != nil { @@ -759,7 +765,7 @@ func createAnswer(id uint16, q dnsmessage.Question, addr netip.Addr) (dnsmessage } func (c *Conn) sendAnswer(queryID uint16, q dnsmessage.Question, ifIndex int, result netip.Addr, dst *net.UDPAddr) { - answer, err := createAnswer(queryID, q, result) + answer, err := c.createAnswer(queryID, q, result) if err != nil { c.log.Warnf("[%s] failed to create mDNS answer %v", c.name, err) return diff --git a/conn_test.go b/conn_test.go index 583be91..21c8ae4 100644 --- a/conn_test.go +++ b/conn_test.go @@ -718,3 +718,58 @@ func TestIPToBytes(t *testing.T) { t.Fatalf("Expected(%v) and Actual(%v) IP don't match", expectedIP, actualAddr6) } } + +// Test for our client side handling cases where the server may or may +// not have included the echoed query with their answer. +func testAnswerHandlingWithQueryEchoed(t *testing.T, echoQuery bool) { + lim := test.TimeOut(time.Second * 10) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + aSock := createListener4(t) + bSock := createListener4(t) + + aServer, err := Server(ipv4.NewPacketConn(aSock), nil, &Config{ + LocalNames: []string{"pion-mdns-1.local", "pion-mdns-2.local"}, + }) + check(err, t) + + aServer.echoQueryInAnswer = echoQuery + + bServer, err := Server(ipv4.NewPacketConn(bSock), nil, &Config{}) + check(err, t) + + _, addr, err := bServer.QueryAddr(context.TODO(), "pion-mdns-1.local") + check(err, t) + if addr.String() == localAddress { + t.Fatalf("unexpected local address: %v", addr) + } + checkIPv4(addr, t) + + _, addr, err = bServer.QueryAddr(context.TODO(), "pion-mdns-2.local") + check(err, t) + if addr.String() == localAddress { + t.Fatalf("unexpected local address: %v", addr) + } + checkIPv4(addr, t) + + check(aServer.Close(), t) + check(bServer.Close(), t) + + if len(aServer.queries) > 0 { + t.Fatalf("Queries not cleaned up after aServer close") + } + if len(bServer.queries) > 0 { + t.Fatalf("Queries not cleaned up after bServer close") + } +} + +func TestAnswerHandlingWithQueryEchoed(t *testing.T) { + testAnswerHandlingWithQueryEchoed(t, true) +} + +func TestAnswerHandlingWithoutQueryEchoed(t *testing.T) { + testAnswerHandlingWithQueryEchoed(t, false) +}