diff --git a/config.go b/config.go index 4659e06..8274984 100644 --- a/config.go +++ b/config.go @@ -46,4 +46,7 @@ type Config struct { // Interfaces will override the interfaces used for queries and answers. Interfaces []net.Interface + + // Override the default behavior of echoing the query with the answer + DoNotEchoQueryWithAnswer bool } diff --git a/conn.go b/conn.go index d300163..c7759d5 100644 --- a/conn.go +++ b/conn.go @@ -65,10 +65,13 @@ const ( ) var ( - errNoPositiveMTUFound = errors.New("no positive MTU found") - errNoPacketConn = errors.New("must supply at least a multicast IPv4 or IPv6 PacketConn") - errNoUsableInterfaces = errors.New("no usable interfaces found for mDNS") - errFailedToClose = errors.New("failed to close mDNS Conn") + errNoPositiveMTUFound = errors.New("no positive MTU found") + errNoPacketConn = errors.New("must supply at least a multicast IPv4 or IPv6 PacketConn") + errNoUsableInterfaces = errors.New("no usable interfaces found for mDNS") + errFailedToClose = errors.New("failed to close mDNS Conn") + errFailedToDecodeAddrFromAResource = errors.New("failed to decode netip.Addr from A type Resource") + errFailedToDecodeAddrFromAAAAResource = errors.New("failed to decode netip.Addr from AAAA type Resource") + errUnhandledAnswerHeaderType = errors.New("header for Answer had unhandled type") ) type netInterface struct { @@ -710,8 +713,8 @@ func (c *Conn) writeToSocket( } } -func createAnswer(id uint16, name string, addr netip.Addr) (dnsmessage.Message, error) { - packedName, err := dnsmessage.NewName(name) +func createAnswer(id uint16, q dnsmessage.Question, addr netip.Addr, config *Config) (dnsmessage.Message, error) { + packedName, err := dnsmessage.NewName(q.Name.String()) if err != nil { return dnsmessage.Message{}, err } @@ -733,6 +736,12 @@ func createAnswer(id uint16, name string, addr netip.Addr) (dnsmessage.Message, }, } + // This is a negative because we want to default to echoing the query with an answer + // The main use of turning it off is in testing + if !config.DoNotEchoQueryWithAnswer { + msg.Questions = []dnsmessage.Question{q} + } + if addr.Is4() { ipBuf, err := ipv4ToBytes(addr) if err != nil { @@ -757,8 +766,8 @@ func createAnswer(id uint16, name string, addr netip.Addr) (dnsmessage.Message, return msg, nil } -func (c *Conn) sendAnswer(queryID uint16, name string, ifIndex int, result netip.Addr, dst *net.UDPAddr) { - answer, err := createAnswer(queryID, name, result) +func (c *Conn) sendAnswer(queryID uint16, q dnsmessage.Question, ifIndex int, result netip.Addr, dst *net.UDPAddr, config *Config) { + answer, err := createAnswer(queryID, q, result, config) if err != nil { c.log.Warnf("[%s] failed to create mDNS answer %v", c.name, err) return @@ -857,7 +866,6 @@ func (c ipPacketConn6) Close() error { func (c *Conn) readLoop(name string, pktConn ipPacketConn, inboundBufferSize int, config *Config) { //nolint:gocognit b := make([]byte, inboundBufferSize) - p := dnsmessage.Parser{} for { n, cm, src, err := pktConn.ReadFrom(b) @@ -885,226 +893,214 @@ func (c *Conn) readLoop(name string, pktConn ipPacketConn, inboundBufferSize int } func() { - header, err := p.Start(b[:n]) + var msg dnsmessage.Message + err := msg.Unpack(b[:n]) if err != nil { c.log.Warnf("[%s] failed to parse mDNS packet %v", c.name, err) return } - for i := 0; i <= maxMessageRecords; i++ { - q, err := p.Question() - if errors.Is(err, dnsmessage.ErrSectionDone) { - break - } else if err != nil { - c.log.Warnf("[%s] failed to parse mDNS packet %v", c.name, err) - return - } + // Questions are often echoed with answers, therefore + // If we have more questions than answers it is a question we might need to respond to + if len(msg.Questions) > len(msg.Answers) { + for _, q := range msg.Questions { + if q.Type != dnsmessage.TypeA && q.Type != dnsmessage.TypeAAAA { + continue + } - if q.Type != dnsmessage.TypeA && q.Type != dnsmessage.TypeAAAA { - continue - } + // https://datatracker.ietf.org/doc/html/rfc6762#section-6 + // The destination UDP port in all Multicast DNS responses MUST be 5353, + // and the destination address MUST be the mDNS IPv4 link-local + // multicast address 224.0.0.251 or its IPv6 equivalent FF02::FB, except + // when generating a reply to a query that explicitly requested a + // unicast response + shouldUnicastResponse := (q.Class&(1<<15)) != 0 || // via the unicast-response bit + srcAddr.Port != 5353 || // by virtue of being a legacy query (Section 6.7), or + (len(pktDst) != 0 && !(pktDst.Equal(c.dstAddr4.IP) || // by virtue of being a direct unicast query + pktDst.Equal(c.dstAddr6.IP))) + var dst *net.UDPAddr + if shouldUnicastResponse { + dst = srcAddr + } - // https://datatracker.ietf.org/doc/html/rfc6762#section-6 - // The destination UDP port in all Multicast DNS responses MUST be 5353, - // and the destination address MUST be the mDNS IPv4 link-local - // multicast address 224.0.0.251 or its IPv6 equivalent FF02::FB, except - // when generating a reply to a query that explicitly requested a - // unicast response - shouldUnicastResponse := (q.Class&(1<<15)) != 0 || // via the unicast-response bit - srcAddr.Port != 5353 || // by virtue of being a legacy query (Section 6.7), or - (len(pktDst) != 0 && !(pktDst.Equal(c.dstAddr4.IP) || // by virtue of being a direct unicast query - pktDst.Equal(c.dstAddr6.IP))) - var dst *net.UDPAddr - if shouldUnicastResponse { - dst = srcAddr - } + queryWantsV4 := q.Type == dnsmessage.TypeA - queryWantsV4 := q.Type == dnsmessage.TypeA - - for _, localName := range c.localNames { - if localName == q.Name.String() { - var localAddress *netip.Addr - if config.LocalAddress != nil { - // this means the LocalAddress does not support link-local since - // we have no zone to set here. - ipAddr, ok := netip.AddrFromSlice(config.LocalAddress) - if !ok { - c.log.Warnf("[%s] failed to convert config.LocalAddress '%s' to netip.Addr", c.name, config.LocalAddress) - continue - } - if c.multicastPktConnV4 != nil { - // don't want mapping since we also support IPv4/A - ipAddr = ipAddr.Unmap() - } - localAddress = &ipAddr - } else { - // prefer the address of the interface if we know its index, but otherwise - // derive it from the address we read from. We do this because even if - // multicast loopback is in use or we send from a loopback interface, - // there are still cases where the IP packet will contain the wrong - // source IP (e.g. a LAN interface). - // For example, we can have a packet that has: - // Source: 192.168.65.3 - // Destination: 224.0.0.251 - // Interface Index: 1 - // Interface Addresses @ 1: [127.0.0.1/8 ::1/128] - if ifIndex != -1 { - ifc, ok := c.ifaces[ifIndex] + for _, localName := range c.localNames { + if localName == q.Name.String() { + var localAddress *netip.Addr + if config.LocalAddress != nil { + // this means the LocalAddress does not support link-local since + // we have no zone to set here. + ipAddr, ok := netip.AddrFromSlice(config.LocalAddress) if !ok { - c.log.Warnf("[%s] no interface for %d", c.name, ifIndex) - return + c.log.Warnf("[%s] failed to convert config.LocalAddress '%s' to netip.Addr", c.name, config.LocalAddress) + continue } - var selectedAddrs []netip.Addr - for _, addr := range ifc.ipAddrs { - addrCopy := addr - - // match up respective IP types based on question - if queryWantsV4 { - if addrCopy.Is4In6() { - // we may allow 4-in-6, but the question wants an A record - addrCopy = addrCopy.Unmap() - } - if !addrCopy.Is4() { - continue - } - } else { // queryWantsV6 - if !addrCopy.Is6() { - continue + if c.multicastPktConnV4 != nil { + // don't want mapping since we also support IPv4/A + ipAddr = ipAddr.Unmap() + } + localAddress = &ipAddr + } else { + // prefer the address of the interface if we know its index, but otherwise + // derive it from the address we read from. We do this because even if + // multicast loopback is in use or we send from a loopback interface, + // there are still cases where the IP packet will contain the wrong + // source IP (e.g. a LAN interface). + // For example, we can have a packet that has: + // Source: 192.168.65.3 + // Destination: 224.0.0.251 + // Interface Index: 1 + // Interface Addresses @ 1: [127.0.0.1/8 ::1/128] + if ifIndex != -1 { + ifc, ok := c.ifaces[ifIndex] + if !ok { + c.log.Warnf("[%s] no interface for %d", c.name, ifIndex) + return + } + var selectedAddrs []netip.Addr + for _, addr := range ifc.ipAddrs { + addrCopy := addr + + // match up respective IP types based on question + if queryWantsV4 { + if addrCopy.Is4In6() { + // we may allow 4-in-6, but the question wants an A record + addrCopy = addrCopy.Unmap() + } + if !addrCopy.Is4() { + continue + } + } else { // queryWantsV6 + if !addrCopy.Is6() { + continue + } + if !isSupportedIPv6(addrCopy, c.multicastPktConnV4 == nil) { + c.log.Debugf("[%s] interface %d address not a supported IPv6 address %s", c.name, ifIndex, &addrCopy) + continue + } } - if !isSupportedIPv6(addrCopy, c.multicastPktConnV4 == nil) { - c.log.Debugf("[%s] interface %d address not a supported IPv6 address %s", c.name, ifIndex, &addrCopy) - continue + + selectedAddrs = append(selectedAddrs, addrCopy) + } + if len(selectedAddrs) == 0 { + c.log.Debugf("[%s] failed to find suitable IP for interface %d; deriving address from source address c.name,instead", c.name, ifIndex) + } else { + // choose the best match + var choice *netip.Addr + for _, option := range selectedAddrs { + optCopy := option + if option.Is4() { + // select first + choice = &optCopy + break + } + // we're okay with 4In6 for now but ideally we get a an actual IPv6. + // Maybe in the future we never want this but it does look like Docker + // can route IPv4 over IPv6. + if choice == nil { + choice = &optCopy + } else if !optCopy.Is4In6() { + choice = &optCopy + } + if !optCopy.Is4In6() { + break + } + // otherwise keep searching for an actual IPv6 } + localAddress = choice } - - selectedAddrs = append(selectedAddrs, addrCopy) } - if len(selectedAddrs) == 0 { - c.log.Debugf("[%s] failed to find suitable IP for interface %d; deriving address from source address c.name,instead", c.name, ifIndex) - } else { - // choose the best match - var choice *netip.Addr - for _, option := range selectedAddrs { - optCopy := option - if option.Is4() { - // select first - choice = &optCopy - break - } - // we're okay with 4In6 for now but ideally we get a an actual IPv6. - // Maybe in the future we never want this but it does look like Docker - // can route IPv4 over IPv6. - if choice == nil { - choice = &optCopy - } else if !optCopy.Is4In6() { - choice = &optCopy - } - if !optCopy.Is4In6() { - break - } - // otherwise keep searching for an actual IPv6 + if ifIndex == -1 || localAddress == nil { + localAddress, err = interfaceForRemote(src.String()) + if err != nil { + c.log.Warnf("[%s] failed to get local interface to communicate with %s: %v", c.name, src.String(), err) + continue } - localAddress = choice } } - if ifIndex == -1 || localAddress == nil { - localAddress, err = interfaceForRemote(src.String()) - if err != nil { - c.log.Warnf("[%s] failed to get local interface to communicate with %s: %v", c.name, src.String(), err) + if queryWantsV4 { + if !localAddress.Is4() { + c.log.Debugf("[%s] have IPv6 address %s to respond with but question is for A not c.name,AAAA", c.name, localAddress) + continue + } + } else { + if !localAddress.Is6() { + c.log.Debugf("[%s] have IPv4 address %s to respond with but question is for AAAA not c.name,A", c.name, localAddress) + continue + } + if !isSupportedIPv6(*localAddress, c.multicastPktConnV4 == nil) { + c.log.Debugf("[%s] got local interface address but not a supported IPv6 address %v", c.name, localAddress) continue } } - } - if queryWantsV4 { - if !localAddress.Is4() { - c.log.Debugf("[%s] have IPv6 address %s to respond with but question is for A not c.name,AAAA", c.name, localAddress) - continue - } - } else { - if !localAddress.Is6() { - c.log.Debugf("[%s] have IPv4 address %s to respond with but question is for AAAA not c.name,A", c.name, localAddress) - continue - } - if !isSupportedIPv6(*localAddress, c.multicastPktConnV4 == nil) { - c.log.Debugf("[%s] got local interface address but not a supported IPv6 address %v", c.name, localAddress) + + if dst != nil && len(dst.IP) == net.IPv4len && + localAddress.Is6() && + localAddress.Zone() != "" && + (localAddress.IsLinkLocalUnicast() || localAddress.IsLinkLocalMulticast()) { + // This case happens when multicast v4 picks up an AAAA question that has a zone + // in the address. Since we cannot send this zone over DNS (it's meaningless), + // the other side can only infer this via the response interface on the other + // side (some IPv6 interface). + c.log.Debugf("[%s] refusing to send link-local address %s to an IPv4 destination %s", c.name, localAddress, dst) continue } + c.log.Debugf("[%s] sending response for %s on ifc %d of %s to %s", c.name, q.Name, ifIndex, *localAddress, dst) + c.sendAnswer(msg.Header.ID, q, ifIndex, *localAddress, dst, config) } - - if dst != nil && len(dst.IP) == net.IPv4len && - localAddress.Is6() && - localAddress.Zone() != "" && - (localAddress.IsLinkLocalUnicast() || localAddress.IsLinkLocalMulticast()) { - // This case happens when multicast v4 picks up an AAAA question that has a zone - // in the address. Since we cannot send this zone over DNS (it's meaningless), - // the other side can only infer this via the response interface on the other - // side (some IPv6 interface). - c.log.Debugf("[%s] refusing to send link-local address %s to an IPv4 destination %s", c.name, localAddress, dst) - continue - } - c.log.Debugf("[%s] sending response for %s on ifc %d of %s to %s", c.name, q.Name, ifIndex, *localAddress, dst) - c.sendAnswer(header.ID, q.Name.String(), ifIndex, *localAddress, dst) } } - } - - for i := 0; i <= maxMessageRecords; i++ { - a, err := p.AnswerHeader() - if errors.Is(err, dnsmessage.ErrSectionDone) { - return - } - if err != nil { - c.log.Warnf("[%s] failed to parse mDNS packet %v", c.name, err) - return - } - - if a.Type != dnsmessage.TypeA && a.Type != dnsmessage.TypeAAAA { - continue - } + } else { + for _, a := range msg.Answers { + if a.Header.Type != dnsmessage.TypeA && a.Header.Type != dnsmessage.TypeAAAA { + continue + } - c.mu.Lock() - queries := make([]*query, len(c.queries)) - copy(queries, c.queries) - c.mu.Unlock() - - var answered []*query - for _, query := range queries { - queryCopy := query - if queryCopy.nameWithSuffix == a.Name.String() { - addr, err := addrFromAnswerHeader(a, p) - if err != nil { - c.log.Warnf("[%s] failed to parse mDNS answer %v", c.name, err) - return - } + c.mu.Lock() + queries := make([]*query, len(c.queries)) + copy(queries, c.queries) + c.mu.Unlock() + + var answered []*query + for _, query := range queries { + queryCopy := query + if queryCopy.nameWithSuffix == a.Header.Name.String() { + addr, err := addrFromAnswer(a) + if err != nil { + c.log.Warnf("[%s] failed to parse mDNS answer %v", c.name, err) + return + } - resultAddr := *addr - // DNS records don't contain IPv6 zones. - // We're trusting that since we're on the same link, that we will only - // be sent link-local addresses from that source's interface's address. - // If it's not present, we're out of luck since we cannot rely on the - // interface zone to be the same as the source's. - resultAddr = addrWithOptionalZone(resultAddr, srcAddr.Zone) - - select { - case queryCopy.queryResultChan <- queryResult{a, resultAddr}: - answered = append(answered, queryCopy) - default: + resultAddr := *addr + // DNS records don't contain IPv6 zones. + // We're trusting that since we're on the same link, that we will only + // be sent link-local addresses from that source's interface's address. + // If it's not present, we're out of luck since we cannot rely on the + // interface zone to be the same as the source's. + resultAddr = addrWithOptionalZone(resultAddr, srcAddr.Zone) + + select { + case queryCopy.queryResultChan <- queryResult{a.Header, resultAddr}: + answered = append(answered, queryCopy) + default: + } } } - } - c.mu.Lock() - for queryIdx := len(c.queries) - 1; queryIdx >= 0; queryIdx-- { - for answerIdx := len(answered) - 1; answerIdx >= 0; answerIdx-- { - if c.queries[queryIdx] == answered[answerIdx] { - c.queries = append(c.queries[:queryIdx], c.queries[queryIdx+1:]...) - answered = append(answered[:answerIdx], answered[answerIdx+1:]...) - queryIdx-- - break + c.mu.Lock() + for queryIdx := len(c.queries) - 1; queryIdx >= 0; queryIdx-- { + for answerIdx := len(answered) - 1; answerIdx >= 0; answerIdx-- { + if c.queries[queryIdx] == answered[answerIdx] { + c.queries = append(c.queries[:queryIdx], c.queries[queryIdx+1:]...) + answered = append(answered[:answerIdx], answered[answerIdx+1:]...) + queryIdx-- + break + } } } + c.mu.Unlock() } - c.mu.Unlock() } }() } @@ -1170,31 +1166,30 @@ func (c *Conn) start(started chan<- struct{}, inboundBufferSize int, config *Con } } -func addrFromAnswerHeader(a dnsmessage.ResourceHeader, p dnsmessage.Parser) (addr *netip.Addr, err error) { - if a.Type == dnsmessage.TypeA { - resource, err := p.AResource() - if err != nil { - return nil, err - } - ipAddr, ok := netip.AddrFromSlice(resource.A[:]) - if !ok { - return nil, fmt.Errorf("failed to convert A record: %w", ipToAddrError{resource.A[:]}) - } - ipAddr = ipAddr.Unmap() // do not want 4-in-6 - addr = &ipAddr - } else { - resource, err := p.AAAAResource() - if err != nil { - return nil, err +func addrFromAnswer(answer dnsmessage.Resource) (*netip.Addr, error) { + switch answer.Header.Type { + case dnsmessage.TypeA: + if a, ok := answer.Body.(*dnsmessage.AResource); ok { + addr, ok := netip.AddrFromSlice(a.A[:]) + if ok { + addr = addr.Unmap() // do not want 4-in-6 + return &addr, nil + } } - ipAddr, ok := netip.AddrFromSlice(resource.AAAA[:]) - if !ok { - return nil, fmt.Errorf("failed to convert AAAA record: %w", ipToAddrError{resource.AAAA[:]}) + + return nil, errFailedToDecodeAddrFromAResource + case dnsmessage.TypeAAAA: + if a, ok := answer.Body.(*dnsmessage.AAAAResource); ok { + addr, ok := netip.AddrFromSlice(a.AAAA[:]) + if ok { + return &addr, nil + } } - addr = &ipAddr - } - return + return nil, errFailedToDecodeAddrFromAAAAResource + default: + return nil, errUnhandledAnswerHeaderType + } } func isSupportedIPv6(addr netip.Addr, ipv6Only bool) bool { diff --git a/conn_test.go b/conn_test.go index 2ae4e18..c8dd02b 100644 --- a/conn_test.go +++ b/conn_test.go @@ -66,7 +66,7 @@ func createListener6(t *testing.T) *net.UDPConn { } func TestValidCommunication(t *testing.T) { - lim := test.TimeOut(time.Second * 10) + lim := test.TimeOut(time.Second * 30) defer lim.Stop() report := test.CheckRoutines(t) @@ -672,30 +672,21 @@ func TestQueryRespectClose(t *testing.T) { } } -func TestResourceParsing(t *testing.T) { +func testResourceParsing(t *testing.T, echoQuery bool) { lookForIP := func(msg dnsmessage.Message, expectedIP []byte, t *testing.T) { - buf, err := msg.Pack() + actualAddr, err := addrFromAnswer(msg.Answers[0]) if err != nil { t.Fatal(err) } - var p dnsmessage.Parser - if _, err = p.Start(buf); err != nil { - t.Fatal(err) - } - - if err = p.SkipAllQuestions(); err != nil { - t.Fatal(err) - } - - h, err := p.AnswerHeader() - if err != nil { - t.Fatal(err) - } - - actualAddr, err := addrFromAnswerHeader(h, p) - if err != nil { - t.Fatal(err) + if echoQuery { + if len(msg.Questions) == 0 { + t.Fatal("Echoed query not included in answer") + } + } else { + if len(msg.Questions) > 0 { + t.Fatal("Echoed query erroneously included in answer") + } } if !bytes.Equal(actualAddr.AsSlice(), expectedIP) { @@ -705,8 +696,15 @@ func TestResourceParsing(t *testing.T) { name := "test-server." + config := &Config{ + DoNotEchoQueryWithAnswer: !echoQuery, + } + t.Run("A Record", func(t *testing.T) { - answer, err := createAnswer(1, name, mustAddr(net.IP{127, 0, 0, 1})) + answer, err := createAnswer(1, dnsmessage.Question{ + Name: dnsmessage.MustNewName(name), + Type: dnsmessage.TypeA, + }, mustAddr(net.IP{127, 0, 0, 1}), config) if err != nil { t.Fatal(err) } @@ -714,7 +712,10 @@ func TestResourceParsing(t *testing.T) { }) t.Run("AAAA Record", func(t *testing.T) { - answer, err := createAnswer(1, name, netip.MustParseAddr("::1")) + answer, err := createAnswer(1, dnsmessage.Question{ + Name: dnsmessage.MustNewName(name), + Type: dnsmessage.TypeAAAA, + }, netip.MustParseAddr("::1"), config) if err != nil { t.Fatal(err) } @@ -722,6 +723,14 @@ func TestResourceParsing(t *testing.T) { }) } +func TestResourceParsingWithEchoedQuery(t *testing.T) { + testResourceParsing(t, true) +} + +func TestResourceParsingWithoutEchoedQuery(t *testing.T) { + testResourceParsing(t, false) +} + func mustAddr(ip net.IP) netip.Addr { addr, ok := netip.AddrFromSlice(ip) if !ok { @@ -776,3 +785,57 @@ func TestIPToBytes(t *testing.T) { t.Fatalf("Expected(%v) and Actual(%v) IP don't match", expectedIP, actualAddr6) } } + +// Test for our client side handling cases where the server may or may +// not have included the echoed query with their answer. +func testAnswerHandlingWithQueryEchoed(t *testing.T, echoQuery bool) { + lim := test.TimeOut(time.Second * 10) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + aSock := createListener4(t) + bSock := createListener4(t) + + aServer, err := Server(ipv4.NewPacketConn(aSock), nil, &Config{ + LocalNames: []string{"pion-mdns-1.local", "pion-mdns-2.local"}, + DoNotEchoQueryWithAnswer: !echoQuery, + }) + check(err, t) + + bServer, err := Server(ipv4.NewPacketConn(bSock), nil, &Config{}) + check(err, t) + + _, addr, err := bServer.QueryAddr(context.TODO(), "pion-mdns-1.local") + check(err, t) + if addr.String() == localAddress { + t.Fatalf("unexpected local address: %v", addr) + } + checkIPv4(addr, t) + + _, addr, err = bServer.QueryAddr(context.TODO(), "pion-mdns-2.local") + check(err, t) + if addr.String() == localAddress { + t.Fatalf("unexpected local address: %v", addr) + } + checkIPv4(addr, t) + + check(aServer.Close(), t) + check(bServer.Close(), t) + + if len(aServer.queries) > 0 { + t.Fatalf("Queries not cleaned up after aServer close") + } + if len(bServer.queries) > 0 { + t.Fatalf("Queries not cleaned up after bServer close") + } +} + +func TestAnswerHandlingWithQueryEchoed(t *testing.T) { + testAnswerHandlingWithQueryEchoed(t, true) +} + +func TestAnswerHandlingWithoutQueryEchoed(t *testing.T) { + testAnswerHandlingWithQueryEchoed(t, false) +}