diff --git a/conn.go b/conn.go index 3c0dd62..c7759d5 100644 --- a/conn.go +++ b/conn.go @@ -713,7 +713,7 @@ func (c *Conn) writeToSocket( } } -func (c *Conn) createAnswer(id uint16, q dnsmessage.Question, addr netip.Addr, config *Config) (dnsmessage.Message, error) { +func createAnswer(id uint16, q dnsmessage.Question, addr netip.Addr, config *Config) (dnsmessage.Message, error) { packedName, err := dnsmessage.NewName(q.Name.String()) if err != nil { return dnsmessage.Message{}, err @@ -767,7 +767,7 @@ func (c *Conn) createAnswer(id uint16, q dnsmessage.Question, addr netip.Addr, c } func (c *Conn) sendAnswer(queryID uint16, q dnsmessage.Question, ifIndex int, result netip.Addr, dst *net.UDPAddr, config *Config) { - answer, err := c.createAnswer(queryID, q, result, config) + answer, err := createAnswer(queryID, q, result, config) if err != nil { c.log.Warnf("[%s] failed to create mDNS answer %v", c.name, err) return @@ -1066,7 +1066,7 @@ func (c *Conn) readLoop(name string, pktConn ipPacketConn, inboundBufferSize int for _, query := range queries { queryCopy := query if queryCopy.nameWithSuffix == a.Header.Name.String() { - addr, err := addrFromAnswerHeader(a) + addr, err := addrFromAnswer(a) if err != nil { c.log.Warnf("[%s] failed to parse mDNS answer %v", c.name, err) return @@ -1166,7 +1166,7 @@ func (c *Conn) start(started chan<- struct{}, inboundBufferSize int, config *Con } } -func addrFromAnswerHeader(answer dnsmessage.Resource) (*netip.Addr, error) { +func addrFromAnswer(answer dnsmessage.Resource) (*netip.Addr, error) { switch answer.Header.Type { case dnsmessage.TypeA: if a, ok := answer.Body.(*dnsmessage.AResource); ok { diff --git a/conn_test.go b/conn_test.go index 995388f..c8dd02b 100644 --- a/conn_test.go +++ b/conn_test.go @@ -672,6 +672,73 @@ func TestQueryRespectClose(t *testing.T) { } } +func testResourceParsing(t *testing.T, echoQuery bool) { + lookForIP := func(msg dnsmessage.Message, expectedIP []byte, t *testing.T) { + actualAddr, err := addrFromAnswer(msg.Answers[0]) + if err != nil { + t.Fatal(err) + } + + if echoQuery { + if len(msg.Questions) == 0 { + t.Fatal("Echoed query not included in answer") + } + } else { + if len(msg.Questions) > 0 { + t.Fatal("Echoed query erroneously included in answer") + } + } + + if !bytes.Equal(actualAddr.AsSlice(), expectedIP) { + t.Fatalf("Expected(%v) and Actual(%v) IP don't match", expectedIP, actualAddr) + } + } + + name := "test-server." + + config := &Config{ + DoNotEchoQueryWithAnswer: !echoQuery, + } + + t.Run("A Record", func(t *testing.T) { + answer, err := createAnswer(1, dnsmessage.Question{ + Name: dnsmessage.MustNewName(name), + Type: dnsmessage.TypeA, + }, mustAddr(net.IP{127, 0, 0, 1}), config) + if err != nil { + t.Fatal(err) + } + lookForIP(answer, []byte{127, 0, 0, 1}, t) + }) + + t.Run("AAAA Record", func(t *testing.T) { + answer, err := createAnswer(1, dnsmessage.Question{ + Name: dnsmessage.MustNewName(name), + Type: dnsmessage.TypeAAAA, + }, netip.MustParseAddr("::1"), config) + if err != nil { + t.Fatal(err) + } + lookForIP(answer, []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, t) + }) +} + +func TestResourceParsingWithEchoedQuery(t *testing.T) { + testResourceParsing(t, true) +} + +func TestResourceParsingWithoutEchoedQuery(t *testing.T) { + testResourceParsing(t, false) +} + +func mustAddr(ip net.IP) netip.Addr { + addr, ok := netip.AddrFromSlice(ip) + if !ok { + panic(ipToAddrError{ip}) + } + return addr +} + func TestIPToBytes(t *testing.T) { expectedIP := []byte{127, 0, 0, 1} actualAddr4, err := ipv4ToBytes(netip.MustParseAddr("127.0.0.1"))