diff --git a/config.go b/config.go index 34fc8fe..9af0f79 100644 --- a/config.go +++ b/config.go @@ -8,13 +8,24 @@ import ( "time" "github.com/pion/logging" + "golang.org/x/net/ipv6" ) const ( // DefaultAddress is the default used by mDNS // and in most cases should be the address that the - // net.Conn passed to Server is bound to + // ipv4.PacketConn passed to Server is bound to DefaultAddress = "224.0.0.0:5353" + + // DefaultAddressIPv4 is the default used by mDNS + // and in most cases should be the address that the + // ipv4.PacketConn passed to Server is bound to + DefaultAddressIPv4 = DefaultAddress + + // DefaultAddressIPv6 is the default IPv6 address used + // by mDNS and in most cases should be the address that + // the ipv6.PacketConn passed to Server is bound to + DefaultAddressIPv6 = "[FF02::]:5353" ) // Config is used to configure a mDNS client or server. @@ -38,4 +49,8 @@ type Config struct { // Interfaces will override the interfaces used for queries and answers. Interfaces []net.Interface + + // MulticastConnV6 is used to receive mDNS questions over IPv6. It can be used in conjunction + // with Server's ipv4.PacketConn or on its own if the Server ipv4 argument is nil. + MulticastConnV6 *ipv6.PacketConn } diff --git a/conn.go b/conn.go index a14613a..f929072 100644 --- a/conn.go +++ b/conn.go @@ -24,6 +24,7 @@ type Conn struct { log logging.LeveledLogger multicastPktConnV4 ipPacketConn + multicastPktConnV6 ipPacketConn dstAddr4 *net.UDPAddr dstAddr6 *net.UDPAddr @@ -33,7 +34,7 @@ type Conn struct { queryInterval time.Duration localNames []string queries []*query - ifaces []net.Interface + ifaces map[int]netInterface closed chan interface{} } @@ -62,7 +63,18 @@ const ( maxPacketSize = 9000 ) -var errNoPositiveMTUFound = errors.New("no positive MTU found") +var ( + errNoPositiveMTUFound = errors.New("no positive MTU found") + errNoPacketConn = errors.New("must supply at least a multicast IPv4 or IPv6 PacketConn") + errNoUsableInterfaces = errors.New("no usable interfaces found for mDNS") +) + +type netInterface struct { + net.Interface + ips []net.IP + supportsV4 bool + supportsV6 bool +} // Server establishes a mDNS connection over an existing conn. // @@ -78,6 +90,11 @@ func Server(multicastPktConnV4 *ipv4.PacketConn, config *Config) (*Conn, error) } log := loggerFactory.NewLogger("mdns") + if multicastPktConnV4 == nil && config.MulticastConnV6 == nil { + return nil, errNoPacketConn + } + multicastPktConnV6 := config.MulticastConnV6 + ifaces := config.Interfaces if ifaces == nil { var err error @@ -126,33 +143,83 @@ func Server(multicastPktConnV4 *ipv4.PacketConn, config *Config) (*Conn, error) inboundBufferSize := 0 joinErrCount := 0 - ifacesToUse := make([]net.Interface, 0, len(ifaces)) + ifacesToUse := make(map[int]netInterface, len(ifaces)) for i := range ifaces { ifc := ifaces[i] if !config.IncludeLoopback && ifc.Flags&net.FlagLoopback == net.FlagLoopback { continue } - if err := multicastPktConnV4.JoinGroup(&ifc, multicastGroupAddr4); err != nil { + if ifc.Flags&net.FlagUp == 0 { + continue + } + + addrs, err := ifc.Addrs() + if err != nil { + continue + } + var supportsV4, supportsV6 bool + ifcIPs := make([]net.IP, 0, len(addrs)) + for _, addr := range addrs { + var ip net.IP + switch addr := addr.(type) { + case *net.IPNet: + ip = addr.IP + case *net.IPAddr: + ip = addr.IP + default: + continue + } + if ip.To4() == nil { + supportsV6 = true + } else { + supportsV4 = true + } + ifcIPs = append(ifcIPs, ip) + } + if !(supportsV4 || supportsV6) { + continue + } + + var atLeastOneJoin bool + if supportsV4 && multicastPktConnV4 != nil { + if err := multicastPktConnV4.JoinGroup(&ifc, multicastGroupAddr4); err == nil { + atLeastOneJoin = true + } + } + if supportsV6 && multicastPktConnV6 != nil { + if err := multicastPktConnV6.JoinGroup(&ifc, multicastGroupAddr6); err == nil { + atLeastOneJoin = true + } + } + if !atLeastOneJoin { joinErrCount++ continue } - ifacesToUse = append(ifacesToUse, ifc) + ifacesToUse[ifc.Index] = netInterface{ + Interface: ifc, + ips: ifcIPs, + supportsV4: supportsV4, + supportsV6: supportsV6, + } if ifc.MTU > inboundBufferSize { inboundBufferSize = ifc.MTU } - if unicastPktConnV4 != nil { + if supportsV4 && unicastPktConnV4 != nil { if err := unicastPktConnV4.JoinGroup(&ifc, multicastGroupAddr4); err != nil { - log.Warnf("Failed to JoinGroup on unicast IPv4 connection %v", err) + log.Debugf("failed to JoinGroup on unicast IPv4 connection for interface %d: %v", ifc.Index, err) } } - if unicastPktConnV6 != nil { + if supportsV6 && unicastPktConnV6 != nil { if err := unicastPktConnV6.JoinGroup(&ifc, multicastGroupAddr6); err != nil { - log.Warnf("Failed to JoinGroup on unicast IPv6 connection %v", err) + log.Debugf("failed to JoinGroup on unicast IPv6 connection for interface %d: %v", ifc.Index, err) } } } + if len(ifacesToUse) == 0 { + return nil, errNoUsableInterfaces + } if inboundBufferSize == 0 { return nil, errNoPositiveMTUFound } @@ -173,55 +240,70 @@ func Server(multicastPktConnV4 *ipv4.PacketConn, config *Config) (*Conn, error) return nil, err } - localNames := []string{} + var localNames []string for _, l := range config.LocalNames { localNames = append(localNames, l+".") } c := &Conn{ - queryInterval: defaultQueryInterval, - multicastPktConnV4: ipPacketConn4{multicastPktConnV4, log}, - unicastPktConnV4: ipPacketConn4{unicastPktConnV4, log}, - unicastPktConnV6: ipPacketConn6{unicastPktConnV6, log}, - dstAddr4: dstAddr4, - dstAddr6: dstAddr6, - localNames: localNames, - ifaces: ifacesToUse, - log: log, - closed: make(chan interface{}), + queryInterval: defaultQueryInterval, + dstAddr4: dstAddr4, + dstAddr6: dstAddr6, + localNames: localNames, + ifaces: ifacesToUse, + log: log, + closed: make(chan interface{}), } if config.QueryInterval != 0 { c.queryInterval = config.QueryInterval } - if err := multicastPktConnV4.SetControlMessage(ipv4.FlagInterface, true); err != nil { - c.log.Warnf("Failed to SetControlMessage(ipv4.FlagInterface) on multicast IPv4 PacketConn %v", err) + if multicastPktConnV4 != nil { + if err := multicastPktConnV4.SetControlMessage(ipv4.FlagInterface, true); err != nil { + c.log.Warnf("failed to SetControlMessage(ipv4.FlagInterface) on multicast IPv4 PacketConn %v", err) + } + c.multicastPktConnV4 = ipPacketConn4{multicastPktConnV4, log} + } + if multicastPktConnV6 != nil { + if err := multicastPktConnV6.SetControlMessage(ipv6.FlagInterface, true); err != nil { + c.log.Warnf("failed to SetControlMessage(ipv6.FlagInterface) on multicast IPv6 PacketConn %v", err) + } + c.multicastPktConnV6 = ipPacketConn6{multicastPktConnV6, log} } if unicastPktConnV4 != nil { if err := unicastPktConnV4.SetControlMessage(ipv4.FlagInterface, true); err != nil { - c.log.Warnf("Failed to SetControlMessage(ipv4.FlagInterface) on unicast IPv4 PacketConn %v", err) + c.log.Warnf("failed to SetControlMessage(ipv4.FlagInterface) on unicast IPv4 PacketConn %v", err) } + c.unicastPktConnV4 = ipPacketConn4{unicastPktConnV4, log} } if unicastPktConnV6 != nil { if err := unicastPktConnV6.SetControlMessage(ipv6.FlagInterface, true); err != nil { - c.log.Warnf("Failed to SetControlMessage(ipv6.FlagInterface) on unicast IPv6 PacketConn %v", err) + c.log.Warnf("failed to SetControlMessage(ipv6.FlagInterface) on unicast IPv6 PacketConn %v", err) } + c.unicastPktConnV6 = ipPacketConn6{unicastPktConnV6, log} } if config.IncludeLoopback { // this is an efficient way for us to send ourselves a message faster instead of it going // further out into the network stack. - if err := multicastPktConnV4.SetMulticastLoopback(true); err != nil { - c.log.Warnf("Failed to SetMulticastLoopback(true) on multicast IPv4 PacketConn %v; this may cause inefficient network path communications", err) + if multicastPktConnV4 != nil { + if err := multicastPktConnV4.SetMulticastLoopback(true); err != nil { + c.log.Warnf("failed to SetMulticastLoopback(true) on multicast IPv4 PacketConn %v; this may cause inefficient network path communications", err) + } + } + if multicastPktConnV6 != nil { + if err := multicastPktConnV6.SetMulticastLoopback(true); err != nil { + c.log.Warnf("failed to SetMulticastLoopback(true) on multicast IPv6 PacketConn %v; this may cause inefficient network path communications", err) + } } if unicastPktConnV4 != nil { if err := unicastPktConnV4.SetMulticastLoopback(true); err != nil { - c.log.Warnf("Failed to SetMulticastLoopback(true) on unicast IPv4 PacketConn %v; this may cause inefficient network path communications", err) + c.log.Warnf("failed to SetMulticastLoopback(true) on unicast IPv4 PacketConn %v; this may cause inefficient network path communications", err) } } if unicastPktConnV6 != nil { if err := unicastPktConnV6.SetMulticastLoopback(true); err != nil { - c.log.Warnf("Failed to SetMulticastLoopback(true) on unicast IPv6 PacketConn %v; this may cause inefficient network path communications", err) + c.log.Warnf("failed to SetMulticastLoopback(true) on unicast IPv6 PacketConn %v; this may cause inefficient network path communications", err) } } } @@ -245,10 +327,17 @@ func (c *Conn) Close() error { } // Once on go1.20, can use errors.Join - var errs error - if err := c.multicastPktConnV4.Close(); err != nil { - errs = multierr.Combine(errs, err) + if c.multicastPktConnV4 != nil { + if err := c.multicastPktConnV4.Close(); err != nil { + errs = multierr.Combine(errs, err) + } + } + + if c.multicastPktConnV6 != nil { + if err := c.multicastPktConnV6.Close(); err != nil { + errs = multierr.Combine(errs, err) + } } if c.unicastPktConnV4 != nil { @@ -381,7 +470,7 @@ const ( func (c *Conn) sendQuestion(name string) { packedName, err := dnsmessage.NewName(name) if err != nil { - c.log.Warnf("Failed to construct mDNS packet %v", err) + c.log.Warnf("failed to construct mDNS packet %v", err) return } @@ -402,28 +491,48 @@ func (c *Conn) sendQuestion(name string) { // get a unicast response back. msg := dnsmessage.Message{ Header: dnsmessage.Header{}, - Questions: []dnsmessage.Question{ - { - Type: dnsmessage.TypeA, - Class: dnsmessage.ClassINET | (1 << 15), - Name: packedName, - }, - }, + } + + // limit what we ask for based on what IPv is available. In the future, + // this could be an option since there's no reason you cannot get an + // A record on an IPv6 sourced question and vice versa. + if c.multicastPktConnV4 != nil { + msg.Questions = append(msg.Questions, dnsmessage.Question{ + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET | (1 << 15), + Name: packedName, + }) + } + if c.multicastPktConnV6 != nil { + msg.Questions = append(msg.Questions, dnsmessage.Question{ + Type: dnsmessage.TypeAAAA, + Class: dnsmessage.ClassINET | (1 << 15), + Name: packedName, + }) } rawQuery, err := msg.Pack() if err != nil { - c.log.Warnf("Failed to construct mDNS packet %v", err) + c.log.Warnf("failed to construct mDNS packet %v", err) return } c.writeToSocket(0, rawQuery, false, writeTypeQuestion, nil) } -func (c *Conn) writeToSocket(ifIndex int, b []byte, srcIfcIsLoopback bool, wType writeType, dst net.Addr) { //nolint:gocognit - if wType == writeTypeAnswer && dst == nil { - c.log.Error("Writing an answer must specify a destination address") - return +func (c *Conn) writeToSocket(ifIndex int, b []byte, hasLoopbackData bool, wType writeType, unicastDst *net.UDPAddr) { //nolint:gocognit + var dst4, dst6 net.Addr + if wType == writeTypeAnswer { + if unicastDst == nil { + dst4 = c.dstAddr4 + dst6 = c.dstAddr6 + } else { + if unicastDst.IP.To4() == nil { + dst6 = unicastDst + } else { + dst4 = unicastDst + } + } } if ifIndex != 0 { @@ -432,29 +541,36 @@ func (c *Conn) writeToSocket(ifIndex int, b []byte, srcIfcIsLoopback bool, wType return } - ifc, err := net.InterfaceByIndex(ifIndex) - if err != nil { - c.log.Warnf("Failed to get interface for %d: %v", ifIndex, err) + ifc, ok := c.ifaces[ifIndex] + if !ok { + c.log.Warnf("no interface for %d", ifIndex) return } - if srcIfcIsLoopback && ifc.Flags&net.FlagLoopback == 0 { + if hasLoopbackData && ifc.Flags&net.FlagLoopback == 0 { // avoid accidentally tricking the destination that itself is the same as us - c.log.Warnf("Interface is not loopback %d", ifIndex) + c.log.Debugf("interface is not loopback %d", ifIndex) return } - //nolint:godox - // TODO(https://github.com/pion/mdns/issues/69): ipv6 - c.log.Debugf("writing answer to %s", dst) - if _, err := c.multicastPktConnV4.WriteTo(b, ifc, nil, dst); err != nil { - c.log.Warnf("Failed to send mDNS packet on interface %d: %v", ifIndex, err) + c.log.Debugf("writing answer to IPv4: %v, IPv6: %v", dst4, dst6) + + if ifc.supportsV4 && c.multicastPktConnV4 != nil && dst4 != nil { + if _, err := c.multicastPktConnV4.WriteTo(b, &ifc.Interface, nil, dst4); err != nil { + c.log.Warnf("failed to send mDNS packet on IPv4 interface %d: %v", ifIndex, err) + } + } + if ifc.supportsV6 && c.multicastPktConnV6 != nil && dst6 != nil { + if _, err := c.multicastPktConnV6.WriteTo(b, &ifc.Interface, nil, dst6); err != nil { + c.log.Warnf("failed to send mDNS packet on IPv6 interface %d: %v", ifIndex, err) + } } return } for ifcIdx := range c.ifaces { - if srcIfcIsLoopback && c.ifaces[ifcIdx].Flags&net.FlagLoopback == 0 { - // avoid accidentally tricking the destination that itself is the same as us + ifc := c.ifaces[ifcIdx] + if hasLoopbackData { + c.log.Debug("Refusing to send loopback data with non-specific interface") continue } @@ -464,29 +580,42 @@ func (c *Conn) writeToSocket(ifIndex int, b []byte, srcIfcIsLoopback bool, wType // conn here, we'd be writing from a specific multicast address which won't be able to receive unicast // traffic (it only works when listening on 0.0.0.0/[::]). if c.unicastPktConnV4 == nil && c.unicastPktConnV6 == nil { - c.log.Debugf("writing question to multicast IPv4 %s", c.dstAddr4) - if _, err := c.multicastPktConnV4.WriteTo(b, &c.ifaces[ifcIdx], nil, c.dstAddr4); err != nil { - c.log.Warnf("Failed to send mDNS packet on interface %d: %v", c.ifaces[ifcIdx].Index, err) + c.log.Debugf("writing question to multicast IPv4/6 %s", c.dstAddr4) + if ifc.supportsV4 && c.multicastPktConnV4 != nil { + if _, err := c.multicastPktConnV4.WriteTo(b, &ifc.Interface, nil, c.dstAddr4); err != nil { + c.log.Warnf("failed to send mDNS packet (multicast) on IPv4 interface %d: %v", ifc.Index, err) + } + } + if ifc.supportsV6 && c.multicastPktConnV6 != nil { + if _, err := c.multicastPktConnV6.WriteTo(b, &ifc.Interface, nil, c.dstAddr6); err != nil { + c.log.Warnf("failed to send mDNS packet (multicast) on IPv6 interface %d: %v", ifc.Index, err) + } } } - if c.unicastPktConnV4 != nil { + if ifc.supportsV4 && c.unicastPktConnV4 != nil { c.log.Debugf("writing question to unicast IPv4 %s", c.dstAddr4) - if _, err := c.unicastPktConnV4.WriteTo(b, &c.ifaces[ifcIdx], nil, c.dstAddr4); err != nil { - c.log.Warnf("Failed to send mDNS packet on interface %d: %v", c.ifaces[ifcIdx].Index, err) + if _, err := c.unicastPktConnV4.WriteTo(b, &ifc.Interface, nil, c.dstAddr4); err != nil { + c.log.Warnf("failed to send mDNS packet (unicast) on interface %d: %v", ifc.Index, err) } } - if c.unicastPktConnV6 != nil { + if ifc.supportsV6 && c.unicastPktConnV6 != nil { c.log.Debugf("writing question to unicast IPv6 %s", c.dstAddr6) - if _, err := c.unicastPktConnV6.WriteTo(b, &c.ifaces[ifcIdx], nil, c.dstAddr6); err != nil { - c.log.Warnf("Failed to send mDNS packet on interface %d: %v", c.ifaces[ifcIdx].Index, err) + if _, err := c.unicastPktConnV6.WriteTo(b, &ifc.Interface, nil, c.dstAddr6); err != nil { + c.log.Warnf("failed to send mDNS packet (unicast) on interface %d: %v", ifc.Index, err) } } } else { - //nolint:godox - // TODO(https://github.com/pion/mdns/issues/69): ipv6 - c.log.Debugf("writing answer to %s", dst) - if _, err := c.multicastPktConnV4.WriteTo(b, &c.ifaces[ifcIdx], nil, dst); err != nil { - c.log.Warnf("Failed to send mDNS packet on interface %d: %v", c.ifaces[ifcIdx].Index, err) + c.log.Debugf("writing answer to IPv4: %s, IPv6: %s", dst4, dst6) + + if ifc.supportsV4 && c.multicastPktConnV4 != nil && dst4 != nil { + if _, err := c.multicastPktConnV4.WriteTo(b, &ifc.Interface, nil, dst4); err != nil { + c.log.Warnf("failed to send mDNS packet (multicast) on IPv4 interface %d: %v", ifIndex, err) + } + } + if ifc.supportsV6 && c.multicastPktConnV6 != nil && dst6 != nil { + if _, err := c.multicastPktConnV6.WriteTo(b, &ifc.Interface, nil, dst6); err != nil { + c.log.Warnf("failed to send mDNS packet (multicast) on IPv6 interface %d: %v", ifIndex, err) + } } } } @@ -506,7 +635,6 @@ func createAnswer(name string, addr net.IP) (dnsmessage.Message, error) { Answers: []dnsmessage.Resource{ { Header: dnsmessage.ResourceHeader{ - Type: dnsmessage.TypeA, Class: dnsmessage.ClassINET, Name: packedName, TTL: responseTTL, @@ -520,6 +648,7 @@ func createAnswer(name string, addr net.IP) (dnsmessage.Message, error) { if err != nil { return dnsmessage.Message{}, err } + msg.Answers[0].Header.Type = dnsmessage.TypeA msg.Answers[0].Body = &dnsmessage.AResource{ A: ipBuf, } @@ -528,6 +657,7 @@ func createAnswer(name string, addr net.IP) (dnsmessage.Message, error) { if err != nil { return dnsmessage.Message{}, err } + msg.Answers[0].Header.Type = dnsmessage.TypeAAAA msg.Answers[0].Body = &dnsmessage.AAAAResource{ AAAA: ipBuf, } @@ -536,16 +666,16 @@ func createAnswer(name string, addr net.IP) (dnsmessage.Message, error) { return msg, nil } -func (c *Conn) sendAnswer(name string, ifIndex int, result net.IP, dst net.Addr) { +func (c *Conn) sendAnswer(name string, ifIndex int, result net.IP, dst *net.UDPAddr) { answer, err := createAnswer(name, result) if err != nil { - c.log.Warnf("Failed to create mDNS answer %v", err) + c.log.Warnf("failed to create mDNS answer %v", err) return } rawAnswer, err := answer.Pack() if err != nil { - c.log.Warnf("Failed to construct mDNS packet %v", err) + c.log.Warnf("failed to construct mDNS packet %v", err) return } @@ -583,7 +713,7 @@ func (c ipPacketConn4) WriteTo(b []byte, via *net.Interface, cm *ipControlMessag } } if err := c.conn.SetMulticastInterface(via); err != nil { - c.log.Warnf("Failed to set multicast interface for %d: %v", via.Index, err) + c.log.Warnf("failed to set multicast interface for %d: %v", via.Index, err) return 0, err } return c.conn.WriteTo(b, cm4, dst) @@ -614,7 +744,7 @@ func (c ipPacketConn6) WriteTo(b []byte, via *net.Interface, cm *ipControlMessag } } if err := c.conn.SetMulticastInterface(via); err != nil { - c.log.Warnf("Failed to set multicast interface for %d: %v", via.Index, err) + c.log.Warnf("failed to set multicast interface for %d: %v", via.Index, err) return 0, err } return c.conn.WriteTo(b, cm6, dst) @@ -634,7 +764,7 @@ func (c *Conn) readLoop(name string, pktConn ipPacketConn, inboundBufferSize int if errors.Is(err, net.ErrClosed) { return } - c.log.Warnf("Failed to ReadFrom %q %v", src, err) + c.log.Warnf("failed to ReadFrom %q %v", src, err) continue } c.log.Debugf("got read on %s from %s", name, src) @@ -645,18 +775,13 @@ func (c *Conn) readLoop(name string, pktConn ipPacketConn, inboundBufferSize int } srcAddr, ok := src.(*net.UDPAddr) if !ok { - c.log.Warnf("Expected source address %s to be UDP but got %", src, src) + c.log.Warnf("expected source address %s to be UDP but got %", src, src) continue } - srcIP := srcAddr.IP - srcIsIPv4 := srcIP.To4() != nil func() { - c.mu.RLock() - defer c.mu.RUnlock() - if _, err := p.Start(b[:n]); err != nil { - c.log.Warnf("Failed to parse mDNS packet %v", err) + c.log.Warnf("failed to parse mDNS packet %v", err) return } @@ -665,24 +790,28 @@ func (c *Conn) readLoop(name string, pktConn ipPacketConn, inboundBufferSize int if errors.Is(err, dnsmessage.ErrSectionDone) { break } else if err != nil { - c.log.Warnf("Failed to parse mDNS packet %v", err) + c.log.Warnf("failed to parse mDNS packet %v", err) return } + + if q.Type != dnsmessage.TypeA && q.Type != dnsmessage.TypeAAAA { + continue + } + shouldUnicastResponse := (q.Class & (1 << 15)) != 0 - //nolint:godox - // TODO(https://github.com/pion/mdns/issues/69): ipv6 here - dst := c.dstAddr4 + var dst *net.UDPAddr if shouldUnicastResponse { dst = srcAddr } + queryWantsV4 := q.Type == dnsmessage.TypeA + for _, localName := range c.localNames { if localName == q.Name.String() { + var localAddress net.IP if config.LocalAddress != nil { - c.sendAnswer(q.Name.String(), ifIndex, config.LocalAddress, dst) + localAddress = config.LocalAddress } else { - var localAddress net.IP - // prefer the address of the interface if we know its index, but otherwise // derive it from the address we read from. We do this because even if // multicast loopback is in use or we send from a loopback interface, @@ -694,61 +823,64 @@ func (c *Conn) readLoop(name string, pktConn ipPacketConn, inboundBufferSize int // Interface Index: 1 // Interface Addresses @ 1: [127.0.0.1/8 ::1/128] if ifIndex != 0 { - ifc, netErr := net.InterfaceByIndex(ifIndex) - if netErr != nil { - c.log.Warnf("Failed to get interface for %d: %v", ifIndex, netErr) - continue - } - addrs, addrsErr := ifc.Addrs() - if addrsErr != nil { - c.log.Warnf("Failed to get addresses for interface %d: %v", ifIndex, addrsErr) - continue - } - if len(addrs) == 0 { - c.log.Warnf("Expected more than one address for interface %d", ifIndex) - continue + ifc, ok := c.ifaces[ifIndex] + if !ok { + c.log.Warnf("no interface for %d", ifIndex) + return } var selectedIP net.IP - for _, addr := range addrs { - var ip net.IP - switch addr := addr.(type) { - case *net.IPNet: - ip = addr.IP - case *net.IPAddr: - ip = addr.IP - default: - c.log.Warnf("Failed to determine address type %T from interface %d", addr, ifIndex) - continue - } + for _, ip := range ifc.ips { + ipCopy := ip - // match up respective IP types - if ipv4 := ip.To4(); ipv4 == nil { - if srcIsIPv4 { + // match up respective IP types based on question + if queryWantsV4 { + if ipv4 := ipCopy.To4(); ipv4 == nil { + continue + } + } else { // queryWantsV6 + if ipv6 := ipCopy.To16(); ipv6 == nil { continue - } else if !isSupportedIPv6(ip) { + } else if !isSupportedIPv6(ipv6) { + c.log.Debugf("interface %d address not a supported IPv6 address %s", ifIndex, ipCopy) continue } - } else if !srcIsIPv4 { - continue } - selectedIP = ip + + selectedIP = ipCopy break } if selectedIP == nil { - c.log.Warnf("Failed to find suitable IP for interface %d; deriving address from source address instead", ifIndex) + c.log.Debugf("failed to find suitable IP for interface %d; deriving address from source address instead", ifIndex) } else { localAddress = selectedIP } - } else if ifIndex == 0 || localAddress == nil { + } + if ifIndex == 0 || localAddress == nil { localAddress, err = interfaceForRemote(src.String()) if err != nil { - c.log.Warnf("Failed to get local interface to communicate with %s: %v", src.String(), err) + c.log.Warnf("failed to get local interface to communicate with %s: %v", src.String(), err) continue } } - - c.sendAnswer(q.Name.String(), ifIndex, localAddress, dst) } + if queryWantsV4 { + localAddress = localAddress.To4() + if localAddress == nil { + c.log.Debugf("have IPv6 address %s to respond with but not question is for A not AAAA", localAddress) + continue + } + } else { + localAddress = localAddress.To16() + if localAddress == nil { + c.log.Debugf("have IPv4 address %s to respond with but not question is for AAAA not A", localAddress) + continue + } + if !isSupportedIPv6(localAddress) { + c.log.Debugf("got local interface address but not a supported IPv6 address %s", localAddress) + continue + } + } + c.sendAnswer(q.Name.String(), ifIndex, localAddress, dst) } } } @@ -759,7 +891,7 @@ func (c *Conn) readLoop(name string, pktConn ipPacketConn, inboundBufferSize int return } if err != nil { - c.log.Warnf("Failed to parse mDNS packet %v", err) + c.log.Warnf("failed to parse mDNS packet %v", err) return } @@ -767,20 +899,42 @@ func (c *Conn) readLoop(name string, pktConn ipPacketConn, inboundBufferSize int continue } - for i := len(c.queries) - 1; i >= 0; i-- { - if c.queries[i].nameWithSuffix == a.Name.String() { + c.mu.Lock() + queries := make([]*query, len(c.queries)) + copy(queries, c.queries) + c.mu.Unlock() + + var answered []*query + for _, query := range queries { + queryCopy := query + if queryCopy.nameWithSuffix == a.Name.String() { ip, err := ipFromAnswerHeader(a, p) if err != nil { - c.log.Warnf("Failed to parse mDNS answer %v", err) + c.log.Warnf("failed to parse mDNS answer %v", err) return } - c.queries[i].queryResultChan <- queryResult{a, &net.IPAddr{ + select { + case queryCopy.queryResultChan <- queryResult{a, &net.IPAddr{ IP: ip, - }} - c.queries = append(c.queries[:i], c.queries[i+1:]...) + }}: + answered = append(answered, queryCopy) + default: + } } } + + c.mu.Lock() + for queryIdx := len(c.queries) - 1; queryIdx >= 0; queryIdx-- { + for answerIdx := len(answered) - 1; answerIdx >= 0; answerIdx-- { + answer := answered[answerIdx] + if c.queries[queryIdx] == answer { + c.queries = append(c.queries[:queryIdx], c.queries[queryIdx+1:]...) + answered = append(answered[:answerIdx], answered[answerIdx+1:]...) + } + } + } + c.mu.Unlock() } }() } @@ -807,6 +961,16 @@ func (c *Conn) start(started chan<- struct{}, inboundBufferSize int, config *Con c.readLoop("multi4", c.multicastPktConnV4, inboundBufferSize, config) }() } + if c.multicastPktConnV6 != nil { + numReaders++ + go func() { + defer func() { + readerEnded <- struct{}{} + }() + readerStarted <- struct{}{} + c.readLoop("multi6", c.multicastPktConnV6, inboundBufferSize, config) + }() + } if c.unicastPktConnV4 != nil { numReaders++ go func() { diff --git a/conn_test.go b/conn_test.go index faa6ad4..ab38f4d 100644 --- a/conn_test.go +++ b/conn_test.go @@ -11,23 +11,27 @@ import ( "context" "errors" "net" + "runtime" "testing" "time" "github.com/pion/transport/v3/test" "golang.org/x/net/dns/dnsmessage" "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" ) const localAddress = "1.2.3.4" func check(err error, t *testing.T) { + t.Helper() if err != nil { t.Fatal(err) } } func checkIPv4(addr net.Addr, t *testing.T) { + t.Helper() var ip net.IP switch addr := addr.(type) { case *net.IPNet: @@ -43,8 +47,25 @@ func checkIPv4(addr net.Addr, t *testing.T) { } } -func createListener(t *testing.T) *net.UDPConn { - addr, err := net.ResolveUDPAddr("udp", DefaultAddress) +func checkIPv6(addr net.Addr, t *testing.T) { + t.Helper() + var ip net.IP + switch addr := addr.(type) { + case *net.IPNet: + ip = addr.IP + case *net.IPAddr: + ip = addr.IP + default: + t.Fatalf("Failed to determine address type %T", addr) + } + + if ip.To16() == nil { + t.Fatalf("expected IPv6 for answer but got %s", ip) + } +} + +func createListener4(t *testing.T) *net.UDPConn { + addr, err := net.ResolveUDPAddr("udp", DefaultAddressIPv4) check(err, t) sock, err := net.ListenUDP("udp4", addr) @@ -53,6 +74,16 @@ func createListener(t *testing.T) *net.UDPConn { return sock } +func createListener6(t *testing.T) *net.UDPConn { + addr, err := net.ResolveUDPAddr("udp", DefaultAddressIPv6) + check(err, t) + + sock, err := net.ListenUDP("udp6", addr) + check(err, t) + + return sock +} + func TestValidCommunication(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() @@ -60,8 +91,8 @@ func TestValidCommunication(t *testing.T) { report := test.CheckRoutines(t) defer report() - aSock := createListener(t) - bSock := createListener(t) + aSock := createListener4(t) + bSock := createListener4(t) aServer, err := Server(ipv4.NewPacketConn(aSock), &Config{ LocalNames: []string{"pion-mdns-1.local", "pion-mdns-2.local"}, @@ -119,7 +150,7 @@ func TestValidCommunicationWithAddressConfig(t *testing.T) { report := test.CheckRoutines(t) defer report() - aSock := createListener(t) + aSock := createListener4(t) aServer, err := Server(ipv4.NewPacketConn(aSock), &Config{ LocalNames: []string{"pion-mdns-1.local", "pion-mdns-2.local"}, @@ -146,7 +177,7 @@ func TestValidCommunicationWithLoopbackAddressConfig(t *testing.T) { report := test.CheckRoutines(t) defer report() - aSock := createListener(t) + aSock := createListener4(t) loopbackIP := net.ParseIP("127.0.0.1") @@ -173,7 +204,7 @@ func TestValidCommunicationWithLoopbackInterface(t *testing.T) { report := test.CheckRoutines(t) defer report() - aSock := createListener(t) + aSock := createListener4(t) ifaces, err := net.Interfaces() check(err, t) @@ -226,6 +257,209 @@ func TestValidCommunicationWithLoopbackInterface(t *testing.T) { check(aServer.Close(), t) } +func TestValidCommunicationIPv6(t *testing.T) { + if runtime.GOARCH == "386" { + t.Skip("IPv6 not supported on 386 for some reason") + } + lim := test.TimeOut(time.Second * 10) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + _, err := Server(nil, &Config{ + LocalNames: []string{"pion-mdns-1.local", "pion-mdns-2.local"}, + }) + if !errors.Is(err, errNoPacketConn) { + t.Fatalf("expected error if no PacketConn supplied to Server; got %v", err) + } + + aSock := createListener6(t) + bSock := createListener6(t) + + aServer, err := Server(nil, &Config{ + LocalNames: []string{"pion-mdns-1.local", "pion-mdns-2.local"}, + MulticastConnV6: ipv6.NewPacketConn(aSock), + }) + check(err, t) + + bServer, err := Server(nil, &Config{ + MulticastConnV6: ipv6.NewPacketConn(bSock), + }) + check(err, t) + + _, addr, err := bServer.Query(context.TODO(), "pion-mdns-1.local") + check(err, t) + + if addr.String() == localAddress { + t.Fatalf("unexpected local address: %v", addr) + } + checkIPv6(addr, t) + + _, addr, err = bServer.Query(context.TODO(), "pion-mdns-2.local") + check(err, t) + if addr.String() == localAddress { + t.Fatalf("unexpected local address: %v", addr) + } + checkIPv6(addr, t) + + check(aServer.Close(), t) + check(bServer.Close(), t) + + if len(aServer.queries) > 0 { + t.Fatalf("Queries not cleaned up after aServer close") + } + if len(bServer.queries) > 0 { + t.Fatalf("Queries not cleaned up after bServer close") + } +} + +func TestValidCommunicationIPv46(t *testing.T) { + if runtime.GOARCH == "386" { + t.Skip("IPv6 not supported on 386 for some reason") + } + + lim := test.TimeOut(time.Second * 10) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + aSock4 := createListener4(t) + bSock4 := createListener4(t) + aSock6 := createListener6(t) + bSock6 := createListener6(t) + + aServer, err := Server(ipv4.NewPacketConn(aSock4), &Config{ + LocalNames: []string{"pion-mdns-1.local", "pion-mdns-2.local"}, + MulticastConnV6: ipv6.NewPacketConn(aSock6), + }) + check(err, t) + + bServer, err := Server(ipv4.NewPacketConn(bSock4), &Config{ + MulticastConnV6: ipv6.NewPacketConn(bSock6), + }) + check(err, t) + + _, addr, err := bServer.Query(context.TODO(), "pion-mdns-1.local") + check(err, t) + + if addr.String() == localAddress { + t.Fatalf("unexpected local address: %v", addr) + } + + _, addr, err = bServer.Query(context.TODO(), "pion-mdns-2.local") + check(err, t) + if addr.String() == localAddress { + t.Fatalf("unexpected local address: %v", addr) + } + + check(aServer.Close(), t) + check(bServer.Close(), t) + + if len(aServer.queries) > 0 { + t.Fatalf("Queries not cleaned up after aServer close") + } + if len(bServer.queries) > 0 { + t.Fatalf("Queries not cleaned up after bServer close") + } +} + +func TestValidCommunicationIPv46Mixed(t *testing.T) { + if runtime.GOARCH == "386" { + t.Skip("IPv6 not supported on 386 for some reason") + } + + lim := test.TimeOut(time.Second * 10) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + aSock4 := createListener4(t) + bSock6 := createListener6(t) + + aServer, err := Server(ipv4.NewPacketConn(aSock4), &Config{ + LocalNames: []string{"pion-mdns-1.local", "pion-mdns-2.local"}, + }) + check(err, t) + + bServer, err := Server(nil, &Config{ + MulticastConnV6: ipv6.NewPacketConn(bSock6), + }) + check(err, t) + + _, addr, err := bServer.Query(context.TODO(), "pion-mdns-1.local") + check(err, t) + + if addr.String() == localAddress { + t.Fatalf("unexpected local address: %v", addr) + } + + _, addr, err = bServer.Query(context.TODO(), "pion-mdns-2.local") + check(err, t) + if addr.String() == localAddress { + t.Fatalf("unexpected local address: %v", addr) + } + + check(aServer.Close(), t) + check(bServer.Close(), t) + + if len(aServer.queries) > 0 { + t.Fatalf("Queries not cleaned up after aServer close") + } + if len(bServer.queries) > 0 { + t.Fatalf("Queries not cleaned up after bServer close") + } +} + +func TestValidCommunicationIPv64Mixed(t *testing.T) { + if runtime.GOARCH == "386" { + t.Skip("IPv6 not supported on 386 for some reason") + } + + lim := test.TimeOut(time.Second * 10) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + aSock6 := createListener6(t) + bSock4 := createListener4(t) + + aServer, err := Server(nil, &Config{ + LocalNames: []string{"pion-mdns-1.local", "pion-mdns-2.local"}, + MulticastConnV6: ipv6.NewPacketConn(aSock6), + }) + check(err, t) + + bServer, err := Server(ipv4.NewPacketConn(bSock4), &Config{}) + check(err, t) + + _, addr, err := bServer.Query(context.TODO(), "pion-mdns-1.local") + check(err, t) + + if addr.String() == localAddress { + t.Fatalf("unexpected local address: %v", addr) + } + + _, addr, err = bServer.Query(context.TODO(), "pion-mdns-2.local") + check(err, t) + if addr.String() == localAddress { + t.Fatalf("unexpected local address: %v", addr) + } + + check(aServer.Close(), t) + check(bServer.Close(), t) + + if len(aServer.queries) > 0 { + t.Fatalf("Queries not cleaned up after aServer close") + } + if len(bServer.queries) > 0 { + t.Fatalf("Queries not cleaned up after bServer close") + } +} + func TestMultipleClose(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() @@ -233,7 +467,7 @@ func TestMultipleClose(t *testing.T) { report := test.CheckRoutines(t) defer report() - aSock := createListener(t) + aSock := createListener4(t) server, err := Server(ipv4.NewPacketConn(aSock), &Config{}) check(err, t) @@ -253,7 +487,7 @@ func TestQueryRespectTimeout(t *testing.T) { report := test.CheckRoutines(t) defer report() - aSock := createListener(t) + aSock := createListener4(t) server, err := Server(ipv4.NewPacketConn(aSock), &Config{}) check(err, t) @@ -281,7 +515,7 @@ func TestQueryRespectClose(t *testing.T) { report := test.CheckRoutines(t) defer report() - aSock := createListener(t) + aSock := createListener4(t) server, err := Server(ipv4.NewPacketConn(aSock), &Config{}) check(err, t)