diff --git a/config.go b/config.go index 936b0b6..34fc8fe 100644 --- a/config.go +++ b/config.go @@ -32,4 +32,10 @@ type Config struct { LocalAddress net.IP LoggerFactory logging.LoggerFactory + + // IncludeLoopback will include loopback interfaces to be eligble for queries and answers. + IncludeLoopback bool + + // Interfaces will override the interfaces used for queries and answers. + Interfaces []net.Interface } diff --git a/conn.go b/conn.go index dea7428..6f94dbe 100644 --- a/conn.go +++ b/conn.go @@ -63,16 +63,23 @@ func Server(conn *ipv4.PacketConn, config *Config) (*Conn, error) { return nil, errNilConfig } - ifaces, err := net.Interfaces() - if err != nil { - return nil, err + ifaces := config.Interfaces + if ifaces == nil { + var err error + ifaces, err = net.Interfaces() + if err != nil { + return nil, err + } } inboundBufferSize := 0 joinErrCount := 0 ifacesToUse := make([]net.Interface, 0, len(ifaces)) for i, ifc := range ifaces { - if err = conn.JoinGroup(&ifaces[i], &net.UDPAddr{IP: net.IPv4(224, 0, 0, 251)}); err != nil { + if !config.IncludeLoopback && ifc.Flags&net.FlagLoopback == net.FlagLoopback { + continue + } + if err := conn.JoinGroup(&ifaces[i], &net.UDPAddr{IP: net.IPv4(224, 0, 0, 251)}); err != nil { joinErrCount++ continue } @@ -127,6 +134,14 @@ func Server(conn *ipv4.PacketConn, config *Config) (*Conn, error) { c.log.Warnf("Failed to SetControlMessage on PacketConn %v", err) } + if config.IncludeLoopback { + // this is an efficient way for us to send ourselves a message faster instead of it going + // further out into the network stack. + // if err := conn.SetMulticastLoopback(true); err != nil { + // c.log.Warnf("Failed to SetMulticastLoopback(true) on PacketConn %v; this may cause inefficient network path communications", err) + // } + } + // https://www.rfc-editor.org/rfc/rfc6762.html#section-17 // Multicast DNS messages carried by UDP may be up to the IP MTU of the // physical interface, less the space required for the IP header (20 @@ -178,6 +193,11 @@ func (c *Conn) Query(ctx context.Context, name string) (dnsmessage.ResourceHeade case <-c.closed: return dnsmessage.ResourceHeader{}, nil, errConnectionClosed case res := <-queryChan: + // Given https://datatracker.ietf.org/doc/html/draft-ietf-mmusic-mdns-ice-candidates#section-3.2.2-2 + // An ICE agent SHOULD ignore candidates where the hostname resolution returns more than one IP address. + // + // We will take the first we receive which could result in a race between two suitable addresses where + // one is better than the other (e.g. localhost vs LAN). return res.answer, res.addr, nil case <-ctx.Done(): return dnsmessage.ResourceHeader{}, nil, errContextElapsed @@ -242,14 +262,14 @@ func (c *Conn) sendQuestion(name string) { c.writeToSocket(0, rawQuery, false) } -func (c *Conn) writeToSocket(ifIndex int, b []byte, onlyLooback bool) { +func (c *Conn) writeToSocket(ifIndex int, b []byte, srcIfcIsLoopback bool) { if ifIndex != 0 { ifc, err := net.InterfaceByIndex(ifIndex) if err != nil { - c.log.Warnf("Failed to get interface interface for %d: %v", ifIndex, err) + c.log.Warnf("Failed to get interface for %d: %v", ifIndex, err) return } - if onlyLooback && ifc.Flags&net.FlagLoopback == 0 { + if srcIfcIsLoopback && ifc.Flags&net.FlagLoopback == 0 { // avoid accidentally tricking the destination that itself is the same as us c.log.Warnf("Interface is not loopback %d", ifIndex) return @@ -264,7 +284,7 @@ func (c *Conn) writeToSocket(ifIndex int, b []byte, onlyLooback bool) { return } for ifcIdx := range c.ifaces { - if onlyLooback && c.ifaces[ifcIdx].Flags&net.FlagLoopback == 0 { + if srcIfcIsLoopback && c.ifaces[ifcIdx].Flags&net.FlagLoopback == 0 { // avoid accidentally tricking the destination that itself is the same as us continue } @@ -278,7 +298,7 @@ func (c *Conn) writeToSocket(ifIndex int, b []byte, onlyLooback bool) { } } -func (c *Conn) sendAnswer(name string, ifIndex int, dst net.IP) { +func (c *Conn) sendAnswer(name string, ifIndex int, addr net.IP) { packedName, err := dnsmessage.NewName(name) if err != nil { c.log.Warnf("Failed to construct mDNS packet %v", err) @@ -299,7 +319,7 @@ func (c *Conn) sendAnswer(name string, ifIndex int, dst net.IP) { TTL: responseTTL, }, Body: &dnsmessage.AResource{ - A: ipToBytes(dst), + A: ipToBytes(addr), }, }, }, @@ -311,7 +331,7 @@ func (c *Conn) sendAnswer(name string, ifIndex int, dst net.IP) { return } - c.writeToSocket(ifIndex, rawAnswer, dst.IsLoopback()) + c.writeToSocket(ifIndex, rawAnswer, addr.IsLoopback()) } func (c *Conn) start(inboundBufferSize int, config *Config) { //nolint gocognit @@ -361,10 +381,43 @@ func (c *Conn) start(inboundBufferSize int, config *Config) { //nolint gocognit if config.LocalAddress != nil { c.sendAnswer(q.Name.String(), ifIndex, config.LocalAddress) } else { - localAddress, err := interfaceForRemote(src.String()) - if err != nil { - c.log.Warnf("Failed to get local interface to communicate with %s: %v", src.String(), err) - continue + var localAddress net.IP + + // prefer the address of the interface if we know its index, but otherwise + // derive it from the address we read from. We do this because even if + // multicast loopback is in use, there are still cases where the IP packet + // will contain the wrong source IP. + // For example, we can have a packet that has: + // Source: 192.168.65.3 + // Destination: 224.0.0.251 + // Interface Index: 1 + // Interface Addresses @ 1: [127.0.0.1/8 ::1/128] + if ifIndex == 0 { + localAddress, err = interfaceForRemote(src.String()) + if err != nil { + c.log.Warnf("Failed to get local interface to communicate with %s: %v", src.String(), err) + continue + } + } else { + ifc, err := net.InterfaceByIndex(ifIndex) + if err != nil { + c.log.Warnf("Failed to get interface for %d: %v", ifIndex, err) + continue + } + addrs, err := ifc.Addrs() + if err != nil { + c.log.Warnf("Failed to get addresses for interface %d: %v", ifIndex, err) + continue + } + if len(addrs) == 0 { + c.log.Warnf("Expected more than one address for interface %d", ifIndex) + continue + } + ipAddr, ok := addrs[0].(*net.IPNet) + if !ok { + c.log.Warnf("expected *net.IPNet address for interface but got %T", ipAddr) + } + localAddress = ipAddr.IP } c.sendAnswer(q.Name.String(), ifIndex, localAddress) diff --git a/conn_test.go b/conn_test.go index 5564b5d..c869f77 100644 --- a/conn_test.go +++ b/conn_test.go @@ -67,6 +67,21 @@ func TestValidCommunication(t *testing.T) { t.Fatalf("unexpected local address: %v", addr) } + // test against regression from https://github.com/pion/mdns/commit/608f20b + // where by properly sending mDNS responses to all interfaces, we significantly + // increased the chance that we send a loopback response to a Query that is + // unwillingly to use loopback addresses (the default in pion/ice). + for i := 0; i < 100; i++ { + _, addr, err = bServer.Query(context.TODO(), "pion-mdns-2.local") + check(err, t) + if addr.String() == localAddress { + t.Fatalf("unexpected local address: %v", addr) + } + if addr.String() == "127.0.0.1" { + t.Fatal("unexpected loopback") + } + } + check(aServer.Close(), t) check(bServer.Close(), t) } @@ -95,6 +110,93 @@ func TestValidCommunicationWithAddressConfig(t *testing.T) { check(aServer.Close(), t) } +func TestValidCommunicationWithLoopbackAddressConfig(t *testing.T) { + lim := test.TimeOut(time.Second * 10) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + aSock := createListener(t) + + loopbackIP := net.ParseIP("127.0.0.1") + + aServer, err := Server(ipv4.NewPacketConn(aSock), &Config{ + LocalNames: []string{"pion-mdns-1.local", "pion-mdns-2.local"}, + LocalAddress: loopbackIP, + IncludeLoopback: true, // the test would fail if this was false + }) + check(err, t) + + _, addr, err := aServer.Query(context.TODO(), "pion-mdns-1.local") + check(err, t) + if addr.String() != loopbackIP.String() { + t.Fatalf("address mismatch: expected %s, but got %v\n", localAddress, addr) + } + + check(aServer.Close(), t) +} + +func TestValidCommunicationWithLoopbackInterface(t *testing.T) { + lim := test.TimeOut(time.Second * 10) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + aSock := createListener(t) + + ifaces, err := net.Interfaces() + check(err, t) + ifacesToUse := make([]net.Interface, 0, len(ifaces)) + for _, ifc := range ifaces { + if ifc.Flags&net.FlagLoopback != net.FlagLoopback { + continue + } + ifcCopy := ifc + ifacesToUse = append(ifacesToUse, ifcCopy) + } + + // the following checks are unlikely to fail since most places where this code runs + // will have a loopback + if len(ifacesToUse) == 0 { + t.Skip("expected at least one loopback interface, but got none") + } + + aServer, err := Server(ipv4.NewPacketConn(aSock), &Config{ + LocalNames: []string{"pion-mdns-1.local", "pion-mdns-2.local"}, + IncludeLoopback: true, // the test would fail if this was false + Interfaces: ifacesToUse, + }) + check(err, t) + + _, addr, err := aServer.Query(context.TODO(), "pion-mdns-1.local") + check(err, t) + var found bool + for _, iface := range ifacesToUse { + addrs, err := iface.Addrs() + check(err, t) + for _, ifaceAddr := range addrs { + ipAddr, ok := ifaceAddr.(*net.IPNet) + if !ok { + t.Fatalf("expected *net.IPNet address for loopback but got %T", addr) + } + if addr.String() == ipAddr.IP.String() { + found = true + break + } + } + if found { + break + } + } + if !found { + t.Fatalf("address mismatch: expected loopback address, but got %v\n", addr) + } + + check(aServer.Close(), t) +} + func TestMultipleClose(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop()