From cde859b1f539a328d5dc54273dbe9c0c987ca695 Mon Sep 17 00:00:00 2001 From: atomirex Date: Sun, 24 Nov 2024 11:23:20 -0500 Subject: [PATCH 1/2] Sending the question back with the answer --- conn.go | 11 ++++++----- conn_test.go | 5 +++-- examples/query/main.go | 2 +- examples/server/main.go | 2 +- examples/server/publish_ip/main.go | 2 +- go.mod | 2 +- 6 files changed, 13 insertions(+), 11 deletions(-) diff --git a/conn.go b/conn.go index d300163..26ac6da 100644 --- a/conn.go +++ b/conn.go @@ -710,8 +710,8 @@ func (c *Conn) writeToSocket( } } -func createAnswer(id uint16, name string, addr netip.Addr) (dnsmessage.Message, error) { - packedName, err := dnsmessage.NewName(name) +func createAnswer(id uint16, q dnsmessage.Question, addr netip.Addr) (dnsmessage.Message, error) { + packedName, err := dnsmessage.NewName(q.Name.String()) if err != nil { return dnsmessage.Message{}, err } @@ -722,6 +722,7 @@ func createAnswer(id uint16, name string, addr netip.Addr) (dnsmessage.Message, Response: true, Authoritative: true, }, + Questions: []dnsmessage.Question{q}, Answers: []dnsmessage.Resource{ { Header: dnsmessage.ResourceHeader{ @@ -757,8 +758,8 @@ func createAnswer(id uint16, name string, addr netip.Addr) (dnsmessage.Message, return msg, nil } -func (c *Conn) sendAnswer(queryID uint16, name string, ifIndex int, result netip.Addr, dst *net.UDPAddr) { - answer, err := createAnswer(queryID, name, result) +func (c *Conn) sendAnswer(queryID uint16, q dnsmessage.Question, ifIndex int, result netip.Addr, dst *net.UDPAddr) { + answer, err := createAnswer(queryID, q, result) if err != nil { c.log.Warnf("[%s] failed to create mDNS answer %v", c.name, err) return @@ -1043,7 +1044,7 @@ func (c *Conn) readLoop(name string, pktConn ipPacketConn, inboundBufferSize int continue } c.log.Debugf("[%s] sending response for %s on ifc %d of %s to %s", c.name, q.Name, ifIndex, *localAddress, dst) - c.sendAnswer(header.ID, q.Name.String(), ifIndex, *localAddress, dst) + c.sendAnswer(header.ID, q, ifIndex, *localAddress, dst) } } } diff --git a/conn_test.go b/conn_test.go index 2ae4e18..01ea52b 100644 --- a/conn_test.go +++ b/conn_test.go @@ -704,9 +704,10 @@ func TestResourceParsing(t *testing.T) { } name := "test-server." + q := dnsmessage.Question{Name: dnsmessage.MustNewName(name)} t.Run("A Record", func(t *testing.T) { - answer, err := createAnswer(1, name, mustAddr(net.IP{127, 0, 0, 1})) + answer, err := createAnswer(1, q, mustAddr(net.IP{127, 0, 0, 1})) if err != nil { t.Fatal(err) } @@ -714,7 +715,7 @@ func TestResourceParsing(t *testing.T) { }) t.Run("AAAA Record", func(t *testing.T) { - answer, err := createAnswer(1, name, netip.MustParseAddr("::1")) + answer, err := createAnswer(1, q, netip.MustParseAddr("::1")) if err != nil { t.Fatal(err) } diff --git a/examples/query/main.go b/examples/query/main.go index 3539d05..56b200e 100644 --- a/examples/query/main.go +++ b/examples/query/main.go @@ -10,7 +10,7 @@ import ( "net" "os" - "github.com/pion/mdns/v2" + "github.com/atomirex/mdns" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" ) diff --git a/examples/server/main.go b/examples/server/main.go index ce9b20b..e1394ea 100644 --- a/examples/server/main.go +++ b/examples/server/main.go @@ -7,7 +7,7 @@ package main import ( "net" - "github.com/pion/mdns/v2" + "github.com/atomirex/mdns" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" ) diff --git a/examples/server/publish_ip/main.go b/examples/server/publish_ip/main.go index 968991f..8e26e41 100644 --- a/examples/server/publish_ip/main.go +++ b/examples/server/publish_ip/main.go @@ -9,7 +9,7 @@ import ( "flag" "net" - "github.com/pion/mdns/v2" + "github.com/atomirex/mdns" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" ) diff --git a/go.mod b/go.mod index 67211ff..0d2cfe3 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/pion/mdns/v2 +module github.com/atomirex/mdns go 1.20 From 39a9a802b6f42f41f6994c41580816c02520899a Mon Sep 17 00:00:00 2001 From: atomirex Date: Tue, 26 Nov 2024 09:07:50 -0500 Subject: [PATCH 2/2] Removes dnsmessage.Parser usage For code flow simplicity when handling data that may or may not be structured precisely as expected it's easier to move to unpacking the whole message and handling that as is. --- config.go | 3 + conn.go | 430 ++++++++++++++--------------- conn_test.go | 108 ++++++-- examples/query/main.go | 2 +- examples/server/main.go | 2 +- examples/server/publish_ip/main.go | 2 +- go.mod | 2 +- 7 files changed, 304 insertions(+), 245 deletions(-) diff --git a/config.go b/config.go index 4659e06..8274984 100644 --- a/config.go +++ b/config.go @@ -46,4 +46,7 @@ type Config struct { // Interfaces will override the interfaces used for queries and answers. Interfaces []net.Interface + + // Override the default behavior of echoing the query with the answer + DoNotEchoQueryWithAnswer bool } diff --git a/conn.go b/conn.go index 26ac6da..c7759d5 100644 --- a/conn.go +++ b/conn.go @@ -65,10 +65,13 @@ const ( ) var ( - errNoPositiveMTUFound = errors.New("no positive MTU found") - errNoPacketConn = errors.New("must supply at least a multicast IPv4 or IPv6 PacketConn") - errNoUsableInterfaces = errors.New("no usable interfaces found for mDNS") - errFailedToClose = errors.New("failed to close mDNS Conn") + errNoPositiveMTUFound = errors.New("no positive MTU found") + errNoPacketConn = errors.New("must supply at least a multicast IPv4 or IPv6 PacketConn") + errNoUsableInterfaces = errors.New("no usable interfaces found for mDNS") + errFailedToClose = errors.New("failed to close mDNS Conn") + errFailedToDecodeAddrFromAResource = errors.New("failed to decode netip.Addr from A type Resource") + errFailedToDecodeAddrFromAAAAResource = errors.New("failed to decode netip.Addr from AAAA type Resource") + errUnhandledAnswerHeaderType = errors.New("header for Answer had unhandled type") ) type netInterface struct { @@ -710,7 +713,7 @@ func (c *Conn) writeToSocket( } } -func createAnswer(id uint16, q dnsmessage.Question, addr netip.Addr) (dnsmessage.Message, error) { +func createAnswer(id uint16, q dnsmessage.Question, addr netip.Addr, config *Config) (dnsmessage.Message, error) { packedName, err := dnsmessage.NewName(q.Name.String()) if err != nil { return dnsmessage.Message{}, err @@ -722,7 +725,6 @@ func createAnswer(id uint16, q dnsmessage.Question, addr netip.Addr) (dnsmessage Response: true, Authoritative: true, }, - Questions: []dnsmessage.Question{q}, Answers: []dnsmessage.Resource{ { Header: dnsmessage.ResourceHeader{ @@ -734,6 +736,12 @@ func createAnswer(id uint16, q dnsmessage.Question, addr netip.Addr) (dnsmessage }, } + // This is a negative because we want to default to echoing the query with an answer + // The main use of turning it off is in testing + if !config.DoNotEchoQueryWithAnswer { + msg.Questions = []dnsmessage.Question{q} + } + if addr.Is4() { ipBuf, err := ipv4ToBytes(addr) if err != nil { @@ -758,8 +766,8 @@ func createAnswer(id uint16, q dnsmessage.Question, addr netip.Addr) (dnsmessage return msg, nil } -func (c *Conn) sendAnswer(queryID uint16, q dnsmessage.Question, ifIndex int, result netip.Addr, dst *net.UDPAddr) { - answer, err := createAnswer(queryID, q, result) +func (c *Conn) sendAnswer(queryID uint16, q dnsmessage.Question, ifIndex int, result netip.Addr, dst *net.UDPAddr, config *Config) { + answer, err := createAnswer(queryID, q, result, config) if err != nil { c.log.Warnf("[%s] failed to create mDNS answer %v", c.name, err) return @@ -858,7 +866,6 @@ func (c ipPacketConn6) Close() error { func (c *Conn) readLoop(name string, pktConn ipPacketConn, inboundBufferSize int, config *Config) { //nolint:gocognit b := make([]byte, inboundBufferSize) - p := dnsmessage.Parser{} for { n, cm, src, err := pktConn.ReadFrom(b) @@ -886,226 +893,214 @@ func (c *Conn) readLoop(name string, pktConn ipPacketConn, inboundBufferSize int } func() { - header, err := p.Start(b[:n]) + var msg dnsmessage.Message + err := msg.Unpack(b[:n]) if err != nil { c.log.Warnf("[%s] failed to parse mDNS packet %v", c.name, err) return } - for i := 0; i <= maxMessageRecords; i++ { - q, err := p.Question() - if errors.Is(err, dnsmessage.ErrSectionDone) { - break - } else if err != nil { - c.log.Warnf("[%s] failed to parse mDNS packet %v", c.name, err) - return - } + // Questions are often echoed with answers, therefore + // If we have more questions than answers it is a question we might need to respond to + if len(msg.Questions) > len(msg.Answers) { + for _, q := range msg.Questions { + if q.Type != dnsmessage.TypeA && q.Type != dnsmessage.TypeAAAA { + continue + } - if q.Type != dnsmessage.TypeA && q.Type != dnsmessage.TypeAAAA { - continue - } + // https://datatracker.ietf.org/doc/html/rfc6762#section-6 + // The destination UDP port in all Multicast DNS responses MUST be 5353, + // and the destination address MUST be the mDNS IPv4 link-local + // multicast address 224.0.0.251 or its IPv6 equivalent FF02::FB, except + // when generating a reply to a query that explicitly requested a + // unicast response + shouldUnicastResponse := (q.Class&(1<<15)) != 0 || // via the unicast-response bit + srcAddr.Port != 5353 || // by virtue of being a legacy query (Section 6.7), or + (len(pktDst) != 0 && !(pktDst.Equal(c.dstAddr4.IP) || // by virtue of being a direct unicast query + pktDst.Equal(c.dstAddr6.IP))) + var dst *net.UDPAddr + if shouldUnicastResponse { + dst = srcAddr + } - // https://datatracker.ietf.org/doc/html/rfc6762#section-6 - // The destination UDP port in all Multicast DNS responses MUST be 5353, - // and the destination address MUST be the mDNS IPv4 link-local - // multicast address 224.0.0.251 or its IPv6 equivalent FF02::FB, except - // when generating a reply to a query that explicitly requested a - // unicast response - shouldUnicastResponse := (q.Class&(1<<15)) != 0 || // via the unicast-response bit - srcAddr.Port != 5353 || // by virtue of being a legacy query (Section 6.7), or - (len(pktDst) != 0 && !(pktDst.Equal(c.dstAddr4.IP) || // by virtue of being a direct unicast query - pktDst.Equal(c.dstAddr6.IP))) - var dst *net.UDPAddr - if shouldUnicastResponse { - dst = srcAddr - } + queryWantsV4 := q.Type == dnsmessage.TypeA - queryWantsV4 := q.Type == dnsmessage.TypeA - - for _, localName := range c.localNames { - if localName == q.Name.String() { - var localAddress *netip.Addr - if config.LocalAddress != nil { - // this means the LocalAddress does not support link-local since - // we have no zone to set here. - ipAddr, ok := netip.AddrFromSlice(config.LocalAddress) - if !ok { - c.log.Warnf("[%s] failed to convert config.LocalAddress '%s' to netip.Addr", c.name, config.LocalAddress) - continue - } - if c.multicastPktConnV4 != nil { - // don't want mapping since we also support IPv4/A - ipAddr = ipAddr.Unmap() - } - localAddress = &ipAddr - } else { - // prefer the address of the interface if we know its index, but otherwise - // derive it from the address we read from. We do this because even if - // multicast loopback is in use or we send from a loopback interface, - // there are still cases where the IP packet will contain the wrong - // source IP (e.g. a LAN interface). - // For example, we can have a packet that has: - // Source: 192.168.65.3 - // Destination: 224.0.0.251 - // Interface Index: 1 - // Interface Addresses @ 1: [127.0.0.1/8 ::1/128] - if ifIndex != -1 { - ifc, ok := c.ifaces[ifIndex] + for _, localName := range c.localNames { + if localName == q.Name.String() { + var localAddress *netip.Addr + if config.LocalAddress != nil { + // this means the LocalAddress does not support link-local since + // we have no zone to set here. + ipAddr, ok := netip.AddrFromSlice(config.LocalAddress) if !ok { - c.log.Warnf("[%s] no interface for %d", c.name, ifIndex) - return + c.log.Warnf("[%s] failed to convert config.LocalAddress '%s' to netip.Addr", c.name, config.LocalAddress) + continue } - var selectedAddrs []netip.Addr - for _, addr := range ifc.ipAddrs { - addrCopy := addr - - // match up respective IP types based on question - if queryWantsV4 { - if addrCopy.Is4In6() { - // we may allow 4-in-6, but the question wants an A record - addrCopy = addrCopy.Unmap() - } - if !addrCopy.Is4() { - continue - } - } else { // queryWantsV6 - if !addrCopy.Is6() { - continue + if c.multicastPktConnV4 != nil { + // don't want mapping since we also support IPv4/A + ipAddr = ipAddr.Unmap() + } + localAddress = &ipAddr + } else { + // prefer the address of the interface if we know its index, but otherwise + // derive it from the address we read from. We do this because even if + // multicast loopback is in use or we send from a loopback interface, + // there are still cases where the IP packet will contain the wrong + // source IP (e.g. a LAN interface). + // For example, we can have a packet that has: + // Source: 192.168.65.3 + // Destination: 224.0.0.251 + // Interface Index: 1 + // Interface Addresses @ 1: [127.0.0.1/8 ::1/128] + if ifIndex != -1 { + ifc, ok := c.ifaces[ifIndex] + if !ok { + c.log.Warnf("[%s] no interface for %d", c.name, ifIndex) + return + } + var selectedAddrs []netip.Addr + for _, addr := range ifc.ipAddrs { + addrCopy := addr + + // match up respective IP types based on question + if queryWantsV4 { + if addrCopy.Is4In6() { + // we may allow 4-in-6, but the question wants an A record + addrCopy = addrCopy.Unmap() + } + if !addrCopy.Is4() { + continue + } + } else { // queryWantsV6 + if !addrCopy.Is6() { + continue + } + if !isSupportedIPv6(addrCopy, c.multicastPktConnV4 == nil) { + c.log.Debugf("[%s] interface %d address not a supported IPv6 address %s", c.name, ifIndex, &addrCopy) + continue + } } - if !isSupportedIPv6(addrCopy, c.multicastPktConnV4 == nil) { - c.log.Debugf("[%s] interface %d address not a supported IPv6 address %s", c.name, ifIndex, &addrCopy) - continue + + selectedAddrs = append(selectedAddrs, addrCopy) + } + if len(selectedAddrs) == 0 { + c.log.Debugf("[%s] failed to find suitable IP for interface %d; deriving address from source address c.name,instead", c.name, ifIndex) + } else { + // choose the best match + var choice *netip.Addr + for _, option := range selectedAddrs { + optCopy := option + if option.Is4() { + // select first + choice = &optCopy + break + } + // we're okay with 4In6 for now but ideally we get a an actual IPv6. + // Maybe in the future we never want this but it does look like Docker + // can route IPv4 over IPv6. + if choice == nil { + choice = &optCopy + } else if !optCopy.Is4In6() { + choice = &optCopy + } + if !optCopy.Is4In6() { + break + } + // otherwise keep searching for an actual IPv6 } + localAddress = choice } - - selectedAddrs = append(selectedAddrs, addrCopy) } - if len(selectedAddrs) == 0 { - c.log.Debugf("[%s] failed to find suitable IP for interface %d; deriving address from source address c.name,instead", c.name, ifIndex) - } else { - // choose the best match - var choice *netip.Addr - for _, option := range selectedAddrs { - optCopy := option - if option.Is4() { - // select first - choice = &optCopy - break - } - // we're okay with 4In6 for now but ideally we get a an actual IPv6. - // Maybe in the future we never want this but it does look like Docker - // can route IPv4 over IPv6. - if choice == nil { - choice = &optCopy - } else if !optCopy.Is4In6() { - choice = &optCopy - } - if !optCopy.Is4In6() { - break - } - // otherwise keep searching for an actual IPv6 + if ifIndex == -1 || localAddress == nil { + localAddress, err = interfaceForRemote(src.String()) + if err != nil { + c.log.Warnf("[%s] failed to get local interface to communicate with %s: %v", c.name, src.String(), err) + continue } - localAddress = choice } } - if ifIndex == -1 || localAddress == nil { - localAddress, err = interfaceForRemote(src.String()) - if err != nil { - c.log.Warnf("[%s] failed to get local interface to communicate with %s: %v", c.name, src.String(), err) + if queryWantsV4 { + if !localAddress.Is4() { + c.log.Debugf("[%s] have IPv6 address %s to respond with but question is for A not c.name,AAAA", c.name, localAddress) + continue + } + } else { + if !localAddress.Is6() { + c.log.Debugf("[%s] have IPv4 address %s to respond with but question is for AAAA not c.name,A", c.name, localAddress) + continue + } + if !isSupportedIPv6(*localAddress, c.multicastPktConnV4 == nil) { + c.log.Debugf("[%s] got local interface address but not a supported IPv6 address %v", c.name, localAddress) continue } } - } - if queryWantsV4 { - if !localAddress.Is4() { - c.log.Debugf("[%s] have IPv6 address %s to respond with but question is for A not c.name,AAAA", c.name, localAddress) - continue - } - } else { - if !localAddress.Is6() { - c.log.Debugf("[%s] have IPv4 address %s to respond with but question is for AAAA not c.name,A", c.name, localAddress) - continue - } - if !isSupportedIPv6(*localAddress, c.multicastPktConnV4 == nil) { - c.log.Debugf("[%s] got local interface address but not a supported IPv6 address %v", c.name, localAddress) + + if dst != nil && len(dst.IP) == net.IPv4len && + localAddress.Is6() && + localAddress.Zone() != "" && + (localAddress.IsLinkLocalUnicast() || localAddress.IsLinkLocalMulticast()) { + // This case happens when multicast v4 picks up an AAAA question that has a zone + // in the address. Since we cannot send this zone over DNS (it's meaningless), + // the other side can only infer this via the response interface on the other + // side (some IPv6 interface). + c.log.Debugf("[%s] refusing to send link-local address %s to an IPv4 destination %s", c.name, localAddress, dst) continue } + c.log.Debugf("[%s] sending response for %s on ifc %d of %s to %s", c.name, q.Name, ifIndex, *localAddress, dst) + c.sendAnswer(msg.Header.ID, q, ifIndex, *localAddress, dst, config) } - - if dst != nil && len(dst.IP) == net.IPv4len && - localAddress.Is6() && - localAddress.Zone() != "" && - (localAddress.IsLinkLocalUnicast() || localAddress.IsLinkLocalMulticast()) { - // This case happens when multicast v4 picks up an AAAA question that has a zone - // in the address. Since we cannot send this zone over DNS (it's meaningless), - // the other side can only infer this via the response interface on the other - // side (some IPv6 interface). - c.log.Debugf("[%s] refusing to send link-local address %s to an IPv4 destination %s", c.name, localAddress, dst) - continue - } - c.log.Debugf("[%s] sending response for %s on ifc %d of %s to %s", c.name, q.Name, ifIndex, *localAddress, dst) - c.sendAnswer(header.ID, q, ifIndex, *localAddress, dst) } } - } - - for i := 0; i <= maxMessageRecords; i++ { - a, err := p.AnswerHeader() - if errors.Is(err, dnsmessage.ErrSectionDone) { - return - } - if err != nil { - c.log.Warnf("[%s] failed to parse mDNS packet %v", c.name, err) - return - } - - if a.Type != dnsmessage.TypeA && a.Type != dnsmessage.TypeAAAA { - continue - } + } else { + for _, a := range msg.Answers { + if a.Header.Type != dnsmessage.TypeA && a.Header.Type != dnsmessage.TypeAAAA { + continue + } - c.mu.Lock() - queries := make([]*query, len(c.queries)) - copy(queries, c.queries) - c.mu.Unlock() - - var answered []*query - for _, query := range queries { - queryCopy := query - if queryCopy.nameWithSuffix == a.Name.String() { - addr, err := addrFromAnswerHeader(a, p) - if err != nil { - c.log.Warnf("[%s] failed to parse mDNS answer %v", c.name, err) - return - } + c.mu.Lock() + queries := make([]*query, len(c.queries)) + copy(queries, c.queries) + c.mu.Unlock() + + var answered []*query + for _, query := range queries { + queryCopy := query + if queryCopy.nameWithSuffix == a.Header.Name.String() { + addr, err := addrFromAnswer(a) + if err != nil { + c.log.Warnf("[%s] failed to parse mDNS answer %v", c.name, err) + return + } - resultAddr := *addr - // DNS records don't contain IPv6 zones. - // We're trusting that since we're on the same link, that we will only - // be sent link-local addresses from that source's interface's address. - // If it's not present, we're out of luck since we cannot rely on the - // interface zone to be the same as the source's. - resultAddr = addrWithOptionalZone(resultAddr, srcAddr.Zone) - - select { - case queryCopy.queryResultChan <- queryResult{a, resultAddr}: - answered = append(answered, queryCopy) - default: + resultAddr := *addr + // DNS records don't contain IPv6 zones. + // We're trusting that since we're on the same link, that we will only + // be sent link-local addresses from that source's interface's address. + // If it's not present, we're out of luck since we cannot rely on the + // interface zone to be the same as the source's. + resultAddr = addrWithOptionalZone(resultAddr, srcAddr.Zone) + + select { + case queryCopy.queryResultChan <- queryResult{a.Header, resultAddr}: + answered = append(answered, queryCopy) + default: + } } } - } - c.mu.Lock() - for queryIdx := len(c.queries) - 1; queryIdx >= 0; queryIdx-- { - for answerIdx := len(answered) - 1; answerIdx >= 0; answerIdx-- { - if c.queries[queryIdx] == answered[answerIdx] { - c.queries = append(c.queries[:queryIdx], c.queries[queryIdx+1:]...) - answered = append(answered[:answerIdx], answered[answerIdx+1:]...) - queryIdx-- - break + c.mu.Lock() + for queryIdx := len(c.queries) - 1; queryIdx >= 0; queryIdx-- { + for answerIdx := len(answered) - 1; answerIdx >= 0; answerIdx-- { + if c.queries[queryIdx] == answered[answerIdx] { + c.queries = append(c.queries[:queryIdx], c.queries[queryIdx+1:]...) + answered = append(answered[:answerIdx], answered[answerIdx+1:]...) + queryIdx-- + break + } } } + c.mu.Unlock() } - c.mu.Unlock() } }() } @@ -1171,31 +1166,30 @@ func (c *Conn) start(started chan<- struct{}, inboundBufferSize int, config *Con } } -func addrFromAnswerHeader(a dnsmessage.ResourceHeader, p dnsmessage.Parser) (addr *netip.Addr, err error) { - if a.Type == dnsmessage.TypeA { - resource, err := p.AResource() - if err != nil { - return nil, err - } - ipAddr, ok := netip.AddrFromSlice(resource.A[:]) - if !ok { - return nil, fmt.Errorf("failed to convert A record: %w", ipToAddrError{resource.A[:]}) - } - ipAddr = ipAddr.Unmap() // do not want 4-in-6 - addr = &ipAddr - } else { - resource, err := p.AAAAResource() - if err != nil { - return nil, err +func addrFromAnswer(answer dnsmessage.Resource) (*netip.Addr, error) { + switch answer.Header.Type { + case dnsmessage.TypeA: + if a, ok := answer.Body.(*dnsmessage.AResource); ok { + addr, ok := netip.AddrFromSlice(a.A[:]) + if ok { + addr = addr.Unmap() // do not want 4-in-6 + return &addr, nil + } } - ipAddr, ok := netip.AddrFromSlice(resource.AAAA[:]) - if !ok { - return nil, fmt.Errorf("failed to convert AAAA record: %w", ipToAddrError{resource.AAAA[:]}) + + return nil, errFailedToDecodeAddrFromAResource + case dnsmessage.TypeAAAA: + if a, ok := answer.Body.(*dnsmessage.AAAAResource); ok { + addr, ok := netip.AddrFromSlice(a.AAAA[:]) + if ok { + return &addr, nil + } } - addr = &ipAddr - } - return + return nil, errFailedToDecodeAddrFromAAAAResource + default: + return nil, errUnhandledAnswerHeaderType + } } func isSupportedIPv6(addr netip.Addr, ipv6Only bool) bool { diff --git a/conn_test.go b/conn_test.go index 01ea52b..c8dd02b 100644 --- a/conn_test.go +++ b/conn_test.go @@ -66,7 +66,7 @@ func createListener6(t *testing.T) *net.UDPConn { } func TestValidCommunication(t *testing.T) { - lim := test.TimeOut(time.Second * 10) + lim := test.TimeOut(time.Second * 30) defer lim.Stop() report := test.CheckRoutines(t) @@ -672,30 +672,21 @@ func TestQueryRespectClose(t *testing.T) { } } -func TestResourceParsing(t *testing.T) { +func testResourceParsing(t *testing.T, echoQuery bool) { lookForIP := func(msg dnsmessage.Message, expectedIP []byte, t *testing.T) { - buf, err := msg.Pack() + actualAddr, err := addrFromAnswer(msg.Answers[0]) if err != nil { t.Fatal(err) } - var p dnsmessage.Parser - if _, err = p.Start(buf); err != nil { - t.Fatal(err) - } - - if err = p.SkipAllQuestions(); err != nil { - t.Fatal(err) - } - - h, err := p.AnswerHeader() - if err != nil { - t.Fatal(err) - } - - actualAddr, err := addrFromAnswerHeader(h, p) - if err != nil { - t.Fatal(err) + if echoQuery { + if len(msg.Questions) == 0 { + t.Fatal("Echoed query not included in answer") + } + } else { + if len(msg.Questions) > 0 { + t.Fatal("Echoed query erroneously included in answer") + } } if !bytes.Equal(actualAddr.AsSlice(), expectedIP) { @@ -704,10 +695,16 @@ func TestResourceParsing(t *testing.T) { } name := "test-server." - q := dnsmessage.Question{Name: dnsmessage.MustNewName(name)} + + config := &Config{ + DoNotEchoQueryWithAnswer: !echoQuery, + } t.Run("A Record", func(t *testing.T) { - answer, err := createAnswer(1, q, mustAddr(net.IP{127, 0, 0, 1})) + answer, err := createAnswer(1, dnsmessage.Question{ + Name: dnsmessage.MustNewName(name), + Type: dnsmessage.TypeA, + }, mustAddr(net.IP{127, 0, 0, 1}), config) if err != nil { t.Fatal(err) } @@ -715,7 +712,10 @@ func TestResourceParsing(t *testing.T) { }) t.Run("AAAA Record", func(t *testing.T) { - answer, err := createAnswer(1, q, netip.MustParseAddr("::1")) + answer, err := createAnswer(1, dnsmessage.Question{ + Name: dnsmessage.MustNewName(name), + Type: dnsmessage.TypeAAAA, + }, netip.MustParseAddr("::1"), config) if err != nil { t.Fatal(err) } @@ -723,6 +723,14 @@ func TestResourceParsing(t *testing.T) { }) } +func TestResourceParsingWithEchoedQuery(t *testing.T) { + testResourceParsing(t, true) +} + +func TestResourceParsingWithoutEchoedQuery(t *testing.T) { + testResourceParsing(t, false) +} + func mustAddr(ip net.IP) netip.Addr { addr, ok := netip.AddrFromSlice(ip) if !ok { @@ -777,3 +785,57 @@ func TestIPToBytes(t *testing.T) { t.Fatalf("Expected(%v) and Actual(%v) IP don't match", expectedIP, actualAddr6) } } + +// Test for our client side handling cases where the server may or may +// not have included the echoed query with their answer. +func testAnswerHandlingWithQueryEchoed(t *testing.T, echoQuery bool) { + lim := test.TimeOut(time.Second * 10) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + aSock := createListener4(t) + bSock := createListener4(t) + + aServer, err := Server(ipv4.NewPacketConn(aSock), nil, &Config{ + LocalNames: []string{"pion-mdns-1.local", "pion-mdns-2.local"}, + DoNotEchoQueryWithAnswer: !echoQuery, + }) + check(err, t) + + bServer, err := Server(ipv4.NewPacketConn(bSock), nil, &Config{}) + check(err, t) + + _, addr, err := bServer.QueryAddr(context.TODO(), "pion-mdns-1.local") + check(err, t) + if addr.String() == localAddress { + t.Fatalf("unexpected local address: %v", addr) + } + checkIPv4(addr, t) + + _, addr, err = bServer.QueryAddr(context.TODO(), "pion-mdns-2.local") + check(err, t) + if addr.String() == localAddress { + t.Fatalf("unexpected local address: %v", addr) + } + checkIPv4(addr, t) + + check(aServer.Close(), t) + check(bServer.Close(), t) + + if len(aServer.queries) > 0 { + t.Fatalf("Queries not cleaned up after aServer close") + } + if len(bServer.queries) > 0 { + t.Fatalf("Queries not cleaned up after bServer close") + } +} + +func TestAnswerHandlingWithQueryEchoed(t *testing.T) { + testAnswerHandlingWithQueryEchoed(t, true) +} + +func TestAnswerHandlingWithoutQueryEchoed(t *testing.T) { + testAnswerHandlingWithQueryEchoed(t, false) +} diff --git a/examples/query/main.go b/examples/query/main.go index 56b200e..3539d05 100644 --- a/examples/query/main.go +++ b/examples/query/main.go @@ -10,7 +10,7 @@ import ( "net" "os" - "github.com/atomirex/mdns" + "github.com/pion/mdns/v2" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" ) diff --git a/examples/server/main.go b/examples/server/main.go index e1394ea..ce9b20b 100644 --- a/examples/server/main.go +++ b/examples/server/main.go @@ -7,7 +7,7 @@ package main import ( "net" - "github.com/atomirex/mdns" + "github.com/pion/mdns/v2" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" ) diff --git a/examples/server/publish_ip/main.go b/examples/server/publish_ip/main.go index 8e26e41..968991f 100644 --- a/examples/server/publish_ip/main.go +++ b/examples/server/publish_ip/main.go @@ -9,7 +9,7 @@ import ( "flag" "net" - "github.com/atomirex/mdns" + "github.com/pion/mdns/v2" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" ) diff --git a/go.mod b/go.mod index 0d2cfe3..67211ff 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/atomirex/mdns +module github.com/pion/mdns/v2 go 1.20