From f1ac99897c1ad2f75141349a0602b934aefa620a Mon Sep 17 00:00:00 2001 From: Eric Daniels Date: Mon, 5 Feb 2024 16:24:16 -0500 Subject: [PATCH] Support unicast query and answer - Addresses #155 - Fixes #93 --- conn.go | 395 +++++++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 335 insertions(+), 60 deletions(-) diff --git a/conn.go b/conn.go index 1f14607..8c26e5c 100644 --- a/conn.go +++ b/conn.go @@ -14,6 +14,7 @@ import ( "github.com/pion/logging" "golang.org/x/net/dns/dnsmessage" "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" ) // Conn represents a mDNS Server @@ -21,8 +22,12 @@ type Conn struct { mu sync.RWMutex log logging.LeveledLogger - socket *ipv4.PacketConn - dstAddr *net.UDPAddr + multicastPktConnV4 ipPacketConn + dstAddr4 *net.UDPAddr + dstAddr6 *net.UDPAddr + + unicastPktConnV4 ipPacketConn + unicastPktConnV6 ipPacketConn queryInterval time.Duration localNames []string @@ -44,7 +49,8 @@ type queryResult struct { const ( defaultQueryInterval = time.Second - destinationAddress = "224.0.0.251:5353" + destinationAddress4 = "224.0.0.251:5353" + destinationAddress6 = "[FF02::FB]:5353" maxMessageRecords = 3 responseTTL = 120 // maxPacketSize is the maximum size of a mdns packet. @@ -61,10 +67,15 @@ var errNoPositiveMTUFound = errors.New("no positive MTU found") // // Currently, the server only supports listening on an IPv4 connection, but internally // it supports answering with IPv6 AAAA records if this were ever to change. -func Server(conn *ipv4.PacketConn, config *Config) (*Conn, error) { +func Server(multicastPktConnV4 *ipv4.PacketConn, config *Config) (*Conn, error) { //nolint:gocognit if config == nil { return nil, errNilConfig } + loggerFactory := config.LoggerFactory + if loggerFactory == nil { + loggerFactory = logging.NewDefaultLoggerFactory() + } + log := loggerFactory.NewLogger("mdns") ifaces := config.Interfaces if ifaces == nil { @@ -75,22 +86,69 @@ func Server(conn *ipv4.PacketConn, config *Config) (*Conn, error) { } } + var unicastPktConnV4 *ipv4.PacketConn + { + addr4, err := net.ResolveUDPAddr("udp4", "0.0.0.0:0") + if err != nil { + return nil, err + } + + unicastConnV4, err := net.ListenUDP("udp4", addr4) + if err != nil { + log.Warnf("failed to lisetn on unicast IPv4 %s: %s; will not be able to receive unicast responses on IPv4", addr4, err) + } else { + unicastPktConnV4 = ipv4.NewPacketConn(unicastConnV4) + } + } + + var unicastPktConnV6 *ipv6.PacketConn + { + addr6, err := net.ResolveUDPAddr("udp6", "[::]:") + if err != nil { + return nil, err + } + + unicastConnV6, err := net.ListenUDP("udp6", addr6) + if err != nil { + log.Warnf("failed to lisetn on unicast IPv6 %s: %s; will not be able to receive unicast responses on IPv6", addr6, err) + } else { + unicastPktConnV6 = ipv6.NewPacketConn(unicastConnV6) + } + } + + mutlicastGroup4 := net.IPv4(224, 0, 0, 251) + multicastGroupAddr4 := &net.UDPAddr{IP: mutlicastGroup4} + + // FF02::FB + mutlicastGroup6 := net.IP{0xff, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xfb} + multicastGroupAddr6 := &net.UDPAddr{IP: mutlicastGroup6} + inboundBufferSize := 0 joinErrCount := 0 ifacesToUse := make([]net.Interface, 0, len(ifaces)) - for i, ifc := range ifaces { + for i := range ifaces { + ifc := ifaces[i] if !config.IncludeLoopback && ifc.Flags&net.FlagLoopback == net.FlagLoopback { continue } - if err := conn.JoinGroup(&ifaces[i], &net.UDPAddr{IP: net.IPv4(224, 0, 0, 251)}); err != nil { + if err := multicastPktConnV4.JoinGroup(&ifc, multicastGroupAddr4); err != nil { joinErrCount++ continue } - ifcCopy := ifc - ifacesToUse = append(ifacesToUse, ifcCopy) - if ifaces[i].MTU > inboundBufferSize { - inboundBufferSize = ifaces[i].MTU + ifacesToUse = append(ifacesToUse, ifc) + if ifc.MTU > inboundBufferSize { + inboundBufferSize = ifc.MTU + } + if unicastPktConnV4 != nil { + if err := unicastPktConnV4.JoinGroup(&ifc, multicastGroupAddr4); err != nil { + log.Warnf("Failed to JoinGroup on unicast IPv4 connection %v", err) + } + } + if unicastPktConnV6 != nil { + if err := unicastPktConnV6.JoinGroup(&ifc, multicastGroupAddr6); err != nil { + log.Warnf("Failed to JoinGroup on unicast IPv6 connection %v", err) + } } } @@ -104,14 +162,14 @@ func Server(conn *ipv4.PacketConn, config *Config) (*Conn, error) { return nil, errJoiningMulticastGroup } - dstAddr, err := net.ResolveUDPAddr("udp", destinationAddress) + dstAddr4, err := net.ResolveUDPAddr("udp4", destinationAddress4) if err != nil { return nil, err } - loggerFactory := config.LoggerFactory - if loggerFactory == nil { - loggerFactory = logging.NewDefaultLoggerFactory() + dstAddr6, err := net.ResolveUDPAddr("udp6", destinationAddress6) + if err != nil { + return nil, err } localNames := []string{} @@ -120,28 +178,50 @@ func Server(conn *ipv4.PacketConn, config *Config) (*Conn, error) { } c := &Conn{ - queryInterval: defaultQueryInterval, - queries: []*query{}, - socket: conn, - dstAddr: dstAddr, - localNames: localNames, - ifaces: ifacesToUse, - log: loggerFactory.NewLogger("mdns"), - closed: make(chan interface{}), + queryInterval: defaultQueryInterval, + multicastPktConnV4: ipPacketConn4{multicastPktConnV4, log}, + unicastPktConnV4: ipPacketConn4{unicastPktConnV4, log}, + unicastPktConnV6: ipPacketConn6{unicastPktConnV6, log}, + dstAddr4: dstAddr4, + dstAddr6: dstAddr6, + localNames: localNames, + ifaces: ifacesToUse, + log: log, + closed: make(chan interface{}), } if config.QueryInterval != 0 { c.queryInterval = config.QueryInterval } - if err := conn.SetControlMessage(ipv4.FlagInterface, true); err != nil { - c.log.Warnf("Failed to SetControlMessage on PacketConn %v", err) + if err := multicastPktConnV4.SetControlMessage(ipv4.FlagInterface, true); err != nil { + c.log.Warnf("Failed to SetControlMessage(ipv4.FlagInterface) on multicast IPv4 PacketConn %v", err) + } + if unicastPktConnV4 != nil { + if err := unicastPktConnV4.SetControlMessage(ipv4.FlagInterface, true); err != nil { + c.log.Warnf("Failed to SetControlMessage(ipv4.FlagInterface) on unicast IPv4 PacketConn %v", err) + } + } + if unicastPktConnV6 != nil { + if err := unicastPktConnV6.SetControlMessage(ipv6.FlagInterface, true); err != nil { + c.log.Warnf("Failed to SetControlMessage(ipv6.FlagInterface) on unicast IPv6 PacketConn %v", err) + } } if config.IncludeLoopback { // this is an efficient way for us to send ourselves a message faster instead of it going // further out into the network stack. - if err := conn.SetMulticastLoopback(true); err != nil { - c.log.Warnf("Failed to SetMulticastLoopback(true) on PacketConn %v; this may cause inefficient network path communications", err) + if err := multicastPktConnV4.SetMulticastLoopback(true); err != nil { + c.log.Warnf("Failed to SetMulticastLoopback(true) on multicast IPv4 PacketConn %v; this may cause inefficient network path communications", err) + } + if unicastPktConnV4 != nil { + if err := unicastPktConnV4.SetMulticastLoopback(true); err != nil { + c.log.Warnf("Failed to SetMulticastLoopback(true) on unicast IPv4 PacketConn %v; this may cause inefficient network path communications", err) + } + } + if unicastPktConnV6 != nil { + if err := unicastPktConnV6.SetMulticastLoopback(true); err != nil { + c.log.Warnf("Failed to SetMulticastLoopback(true) on unicast IPv6 PacketConn %v; this may cause inefficient network path communications", err) + } } } @@ -149,7 +229,9 @@ func Server(conn *ipv4.PacketConn, config *Config) (*Conn, error) { // Multicast DNS messages carried by UDP may be up to the IP MTU of the // physical interface, less the space required for the IP header (20 // bytes for IPv4; 40 bytes for IPv6) and the UDP header (8 bytes). - go c.start(inboundBufferSize-20-8, config) + started := make(chan struct{}) + go c.start(started, inboundBufferSize-20-8, config) + <-started return c, nil } @@ -161,10 +243,22 @@ func (c *Conn) Close() error { default: } - if err := c.socket.Close(); err != nil { + if err := c.multicastPktConnV4.Close(); err != nil { return err } + if c.unicastPktConnV4 != nil { + if err := c.unicastPktConnV4.Close(); err != nil { + return err + } + } + + if c.unicastPktConnV6 != nil { + if err := c.unicastPktConnV6.Close(); err != nil { + return err + } + } + <-c.closed return nil } @@ -270,6 +364,13 @@ func interfaceForRemote(remote string) (net.IP, error) { return localAddr.IP, nil } +type writeType byte + +const ( + writeTypeQuestion writeType = iota + writeTypeAnswer +) + func (c *Conn) sendQuestion(name string) { packedName, err := dnsmessage.NewName(name) if err != nil { @@ -277,12 +378,27 @@ func (c *Conn) sendQuestion(name string) { return } + // https://datatracker.ietf.org/doc/html/draft-ietf-rtcweb-mdns-ice-candidates-04#section-3.2.1 + // + // 2. Otherwise, resolve the candidate using mDNS. The ICE agent + // SHOULD set the unicast-response bit of the corresponding mDNS + // query message; this minimizes multicast traffic, as the response + // is probably only useful to the querying node. + // + // 18.12. Repurposing of Top Bit of qclass in Question Section + // + // In the Question Section of a Multicast DNS query, the top bit of the + // qclass field is used to indicate that unicast responses are preferred + // for this particular question. (See Section 5.4.) + // + // We'll follow this up sending on our unicast based packet connections so that we can + // get a unicast response back. msg := dnsmessage.Message{ Header: dnsmessage.Header{}, Questions: []dnsmessage.Question{ { Type: dnsmessage.TypeA, - Class: dnsmessage.ClassINET, + Class: dnsmessage.ClassINET | (1 << 15), Name: packedName, }, }, @@ -294,11 +410,21 @@ func (c *Conn) sendQuestion(name string) { return } - c.writeToSocket(0, rawQuery, false) + c.writeToSocket(0, rawQuery, false, writeTypeQuestion, nil) } -func (c *Conn) writeToSocket(ifIndex int, b []byte, srcIfcIsLoopback bool) { +func (c *Conn) writeToSocket(ifIndex int, b []byte, srcIfcIsLoopback bool, wType writeType, dst net.Addr) { //nolint:gocognit + if wType == writeTypeAnswer && dst == nil { + c.log.Error("Writing an answer must specify a destination address") + return + } + if ifIndex != 0 { + if wType == writeTypeQuestion { + c.log.Errorf("Unexpected question using specific interface index %d; dropping question", ifIndex) + return + } + ifc, err := net.InterfaceByIndex(ifIndex) if err != nil { c.log.Warnf("Failed to get interface for %d: %v", ifIndex, err) @@ -309,13 +435,14 @@ func (c *Conn) writeToSocket(ifIndex int, b []byte, srcIfcIsLoopback bool) { c.log.Warnf("Interface is not loopback %d", ifIndex) return } - if err := c.socket.SetMulticastInterface(ifc); err != nil { - c.log.Warnf("Failed to set multicast interface for %d: %v", ifIndex, err) - } else { - if _, err := c.socket.WriteTo(b, nil, c.dstAddr); err != nil { - c.log.Warnf("Failed to send mDNS packet on interface %d: %v", ifIndex, err) - } + + //nolint:godox + // TODO(https://github.com/pion/mdns/issues/69): ipv6 + c.log.Debugf("writing answer to %s", dst) + if _, err := c.multicastPktConnV4.WriteTo(b, ifc, nil, dst); err != nil { + c.log.Warnf("Failed to send mDNS packet on interface %d: %v", ifIndex, err) } + return } for ifcIdx := range c.ifaces { @@ -323,10 +450,35 @@ func (c *Conn) writeToSocket(ifIndex int, b []byte, srcIfcIsLoopback bool) { // avoid accidentally tricking the destination that itself is the same as us continue } - if err := c.socket.SetMulticastInterface(&c.ifaces[ifcIdx]); err != nil { - c.log.Warnf("Failed to set multicast interface for %d: %v", c.ifaces[ifcIdx].Index, err) + + if wType == writeTypeQuestion { + // we'll write via unicast if we can in case the responder chooses to respond to the address the request + // came from (i.e. not respecting unicast-response bit). If we were to use the multicast packet + // conn here, we'd be writing from a specific multicast address which won't be able to receive unicast + // traffic (it only works when listening on 0.0.0.0/[::]). + if c.unicastPktConnV4 == nil && c.unicastPktConnV6 == nil { + c.log.Debugf("writing question to multicast IPv4 %s", c.dstAddr4) + if _, err := c.multicastPktConnV4.WriteTo(b, &c.ifaces[ifcIdx], nil, c.dstAddr4); err != nil { + c.log.Warnf("Failed to send mDNS packet on interface %d: %v", c.ifaces[ifcIdx].Index, err) + } + } + if c.unicastPktConnV4 != nil { + c.log.Debugf("writing question to unicast IPv4 %s", c.dstAddr4) + if _, err := c.unicastPktConnV4.WriteTo(b, &c.ifaces[ifcIdx], nil, c.dstAddr4); err != nil { + c.log.Warnf("Failed to send mDNS packet on interface %d: %v", c.ifaces[ifcIdx].Index, err) + } + } + if c.unicastPktConnV6 != nil { + c.log.Debugf("writing question to unicast IPv6 %s", c.dstAddr6) + if _, err := c.unicastPktConnV6.WriteTo(b, &c.ifaces[ifcIdx], nil, c.dstAddr6); err != nil { + c.log.Warnf("Failed to send mDNS packet on interface %d: %v", c.ifaces[ifcIdx].Index, err) + } + } } else { - if _, err := c.socket.WriteTo(b, nil, c.dstAddr); err != nil { + //nolint:godox + // TODO(https://github.com/pion/mdns/issues/69): ipv6 + c.log.Debugf("writing answer to %s", dst) + if _, err := c.multicastPktConnV4.WriteTo(b, &c.ifaces[ifcIdx], nil, dst); err != nil { c.log.Warnf("Failed to send mDNS packet on interface %d: %v", c.ifaces[ifcIdx].Index, err) } } @@ -377,8 +529,8 @@ func createAnswer(name string, addr net.IP) (dnsmessage.Message, error) { return msg, nil } -func (c *Conn) sendAnswer(name string, ifIndex int, addr net.IP) { - answer, err := createAnswer(name, addr) +func (c *Conn) sendAnswer(name string, ifIndex int, result net.IP, dst net.Addr) { + answer, err := createAnswer(name, result) if err != nil { c.log.Warnf("Failed to create mDNS answer %v", err) return @@ -390,21 +542,89 @@ func (c *Conn) sendAnswer(name string, ifIndex int, addr net.IP) { return } - c.writeToSocket(ifIndex, rawAnswer, addr.IsLoopback()) + c.writeToSocket(ifIndex, rawAnswer, result.IsLoopback(), writeTypeAnswer, dst) } -func (c *Conn) start(inboundBufferSize int, config *Config) { //nolint gocognit - defer func() { - c.mu.Lock() - defer c.mu.Unlock() - close(c.closed) - }() +type ipControlMessage struct { + IfIndex int +} +type ipPacketConn interface { + ReadFrom(b []byte) (n int, cm *ipControlMessage, src net.Addr, err error) + WriteTo(b []byte, via *net.Interface, cm *ipControlMessage, dst net.Addr) (n int, err error) + Close() error +} + +type ipPacketConn4 struct { + conn *ipv4.PacketConn + log logging.LeveledLogger +} + +func (c ipPacketConn4) ReadFrom(b []byte) (n int, cm *ipControlMessage, src net.Addr, err error) { + n, cm4, src, err := c.conn.ReadFrom(b) + if err != nil || cm4 == nil { + return n, nil, src, err + } + return n, &ipControlMessage{IfIndex: cm4.IfIndex}, src, err +} + +func (c ipPacketConn4) WriteTo(b []byte, via *net.Interface, cm *ipControlMessage, dst net.Addr) (n int, err error) { + var cm4 *ipv4.ControlMessage + if cm != nil { + cm4 = &ipv4.ControlMessage{ + IfIndex: cm.IfIndex, + } + } + if err := c.conn.SetMulticastInterface(via); err != nil { + c.log.Warnf("Failed to set multicast interface for %d: %v", via.Index, err) + return 0, err + } + return c.conn.WriteTo(b, cm4, dst) +} + +func (c ipPacketConn4) Close() error { + return c.conn.Close() +} + +type ipPacketConn6 struct { + conn *ipv6.PacketConn + log logging.LeveledLogger +} + +func (c ipPacketConn6) ReadFrom(b []byte) (n int, cm *ipControlMessage, src net.Addr, err error) { + n, cm6, src, err := c.conn.ReadFrom(b) + if err != nil || cm6 == nil { + return n, nil, src, err + } + return n, &ipControlMessage{IfIndex: cm6.IfIndex}, src, err +} + +func (c ipPacketConn6) WriteTo(b []byte, via *net.Interface, cm *ipControlMessage, dst net.Addr) (n int, err error) { + var cm6 *ipv6.ControlMessage + if cm != nil { + cm6 = &ipv6.ControlMessage{ + IfIndex: cm.IfIndex, + } + } + if err := c.conn.SetMulticastInterface(via); err != nil { + c.log.Warnf("Failed to set multicast interface for %d: %v", via.Index, err) + return 0, err + } + return c.conn.WriteTo(b, cm6, dst) +} + +func (c ipPacketConn6) Close() error { + return c.conn.Close() +} + +func (c *Conn) readLoop(name string, pktConn ipPacketConn, inboundBufferSize int, config *Config) { //nolint:gocognit b := make([]byte, inboundBufferSize) p := dnsmessage.Parser{} for { - n, cm, src, err := c.socket.ReadFrom(b) + _ = b + b := make([]byte, 1500) + n, cm, src, err := pktConn.ReadFrom(b) if err != nil { if errors.Is(err, net.ErrClosed) { return @@ -412,20 +632,18 @@ func (c *Conn) start(inboundBufferSize int, config *Config) { //nolint gocognit c.log.Warnf("Failed to ReadFrom %q %v", src, err) continue } + c.log.Debugf("got read on %s from %s", name, src) + var ifIndex int if cm != nil { ifIndex = cm.IfIndex } - var srcIP net.IP - switch addr := src.(type) { - case *net.UDPAddr: - srcIP = addr.IP - case *net.TCPAddr: - srcIP = addr.IP - default: - c.log.Warnf("Failed to determine address type %T for source address %s", src, src) + srcAddr, ok := src.(*net.UDPAddr) + if !ok { + c.log.Warnf("Expected source address %s to be UDP but got %", src, src) continue } + srcIP := srcAddr.IP srcIsIPv4 := srcIP.To4() != nil func() { @@ -445,11 +663,18 @@ func (c *Conn) start(inboundBufferSize int, config *Config) { //nolint gocognit c.log.Warnf("Failed to parse mDNS packet %v", err) return } + shouldUnicastResponse := (q.Class & (1 << 15)) != 0 + //nolint:godox + // TODO(https://github.com/pion/mdns/issues/69): ipv6 here + dst := c.dstAddr4 + if shouldUnicastResponse { + dst = srcAddr + } for _, localName := range c.localNames { if localName == q.Name.String() { if config.LocalAddress != nil { - c.sendAnswer(q.Name.String(), ifIndex, config.LocalAddress) + c.sendAnswer(q.Name.String(), ifIndex, config.LocalAddress, dst) } else { var localAddress net.IP @@ -517,7 +742,7 @@ func (c *Conn) start(inboundBufferSize int, config *Config) { //nolint gocognit } } - c.sendAnswer(q.Name.String(), ifIndex, localAddress) + c.sendAnswer(q.Name.String(), ifIndex, localAddress, dst) } } } @@ -556,6 +781,56 @@ func (c *Conn) start(inboundBufferSize int, config *Config) { //nolint gocognit } } +func (c *Conn) start(started chan<- struct{}, inboundBufferSize int, config *Config) { + defer func() { + c.mu.Lock() + defer c.mu.Unlock() + close(c.closed) + }() + + var numReaders int + readerStarted := make(chan struct{}) + readerEnded := make(chan struct{}) + + if c.multicastPktConnV4 != nil { + numReaders++ + go func() { + defer func() { + readerEnded <- struct{}{} + }() + readerStarted <- struct{}{} + c.readLoop("multi4", c.multicastPktConnV4, inboundBufferSize, config) + }() + } + if c.unicastPktConnV4 != nil { + numReaders++ + go func() { + defer func() { + readerEnded <- struct{}{} + }() + readerStarted <- struct{}{} + c.readLoop("uni4", c.unicastPktConnV4, inboundBufferSize, config) + }() + } + if c.unicastPktConnV6 != nil { + numReaders++ + go func() { + defer func() { + readerEnded <- struct{}{} + }() + readerStarted <- struct{}{} + c.readLoop("uni6", c.unicastPktConnV6, inboundBufferSize, config) + }() + } + for i := 0; i < numReaders; i++ { + <-readerStarted + } + close(started) + for i := 0; i < numReaders; i++ { + <-readerEnded + } +} + func ipFromAnswerHeader(a dnsmessage.ResourceHeader, p dnsmessage.Parser) (ip []byte, err error) { if a.Type == dnsmessage.TypeA { resource, err := p.AResource()