From 4d0762c7f0720f0624963e0f708910a4f997e9c7 Mon Sep 17 00:00:00 2001 From: Eric Daniels Date: Mon, 5 Feb 2024 16:24:16 -0500 Subject: [PATCH] Use loopback address in responses when configured Fixes https://github.com/pion/webrtc/issues/2625 --- config.go | 6 ++++ conn.go | 32 +++++++++++++------ conn_test.go | 90 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 118 insertions(+), 10 deletions(-) diff --git a/config.go b/config.go index 936b0b6..34fc8fe 100644 --- a/config.go +++ b/config.go @@ -32,4 +32,10 @@ type Config struct { LocalAddress net.IP LoggerFactory logging.LoggerFactory + + // IncludeLoopback will include loopback interfaces to be eligble for queries and answers. + IncludeLoopback bool + + // Interfaces will override the interfaces used for queries and answers. + Interfaces []net.Interface } diff --git a/conn.go b/conn.go index dea7428..7bd3c03 100644 --- a/conn.go +++ b/conn.go @@ -63,16 +63,23 @@ func Server(conn *ipv4.PacketConn, config *Config) (*Conn, error) { return nil, errNilConfig } - ifaces, err := net.Interfaces() - if err != nil { - return nil, err + ifaces := config.Interfaces + if ifaces == nil { + var err error + ifaces, err = net.Interfaces() + if err != nil { + return nil, err + } } inboundBufferSize := 0 joinErrCount := 0 ifacesToUse := make([]net.Interface, 0, len(ifaces)) for i, ifc := range ifaces { - if err = conn.JoinGroup(&ifaces[i], &net.UDPAddr{IP: net.IPv4(224, 0, 0, 251)}); err != nil { + if !config.IncludeLoopback && ifc.Flags&net.FlagLoopback == net.FlagLoopback { + continue + } + if err := conn.JoinGroup(&ifaces[i], &net.UDPAddr{IP: net.IPv4(224, 0, 0, 251)}); err != nil { joinErrCount++ continue } @@ -178,6 +185,11 @@ func (c *Conn) Query(ctx context.Context, name string) (dnsmessage.ResourceHeade case <-c.closed: return dnsmessage.ResourceHeader{}, nil, errConnectionClosed case res := <-queryChan: + // Given https://datatracker.ietf.org/doc/html/draft-ietf-mmusic-mdns-ice-candidates#section-3.2.2-2 + // An ICE agent SHOULD ignore candidates where the hostname resolution returns more than one IP address. + // + // We will take the first we receive which could result in a race between two suitable addresses where + // one is better than the other (e.g. localhost vs LAN). return res.answer, res.addr, nil case <-ctx.Done(): return dnsmessage.ResourceHeader{}, nil, errContextElapsed @@ -242,14 +254,14 @@ func (c *Conn) sendQuestion(name string) { c.writeToSocket(0, rawQuery, false) } -func (c *Conn) writeToSocket(ifIndex int, b []byte, onlyLooback bool) { +func (c *Conn) writeToSocket(ifIndex int, b []byte, srcIfcIsLoopback bool) { if ifIndex != 0 { ifc, err := net.InterfaceByIndex(ifIndex) if err != nil { c.log.Warnf("Failed to get interface interface for %d: %v", ifIndex, err) return } - if onlyLooback && ifc.Flags&net.FlagLoopback == 0 { + if srcIfcIsLoopback && ifc.Flags&net.FlagLoopback == 0 { // avoid accidentally tricking the destination that itself is the same as us c.log.Warnf("Interface is not loopback %d", ifIndex) return @@ -264,7 +276,7 @@ func (c *Conn) writeToSocket(ifIndex int, b []byte, onlyLooback bool) { return } for ifcIdx := range c.ifaces { - if onlyLooback && c.ifaces[ifcIdx].Flags&net.FlagLoopback == 0 { + if srcIfcIsLoopback && c.ifaces[ifcIdx].Flags&net.FlagLoopback == 0 { // avoid accidentally tricking the destination that itself is the same as us continue } @@ -278,7 +290,7 @@ func (c *Conn) writeToSocket(ifIndex int, b []byte, onlyLooback bool) { } } -func (c *Conn) sendAnswer(name string, ifIndex int, dst net.IP) { +func (c *Conn) sendAnswer(name string, ifIndex int, addr net.IP) { packedName, err := dnsmessage.NewName(name) if err != nil { c.log.Warnf("Failed to construct mDNS packet %v", err) @@ -299,7 +311,7 @@ func (c *Conn) sendAnswer(name string, ifIndex int, dst net.IP) { TTL: responseTTL, }, Body: &dnsmessage.AResource{ - A: ipToBytes(dst), + A: ipToBytes(addr), }, }, }, @@ -311,7 +323,7 @@ func (c *Conn) sendAnswer(name string, ifIndex int, dst net.IP) { return } - c.writeToSocket(ifIndex, rawAnswer, dst.IsLoopback()) + c.writeToSocket(ifIndex, rawAnswer, addr.IsLoopback()) } func (c *Conn) start(inboundBufferSize int, config *Config) { //nolint gocognit diff --git a/conn_test.go b/conn_test.go index 5564b5d..6112620 100644 --- a/conn_test.go +++ b/conn_test.go @@ -67,6 +67,21 @@ func TestValidCommunication(t *testing.T) { t.Fatalf("unexpected local address: %v", addr) } + // test against regression from https://github.com/pion/mdns/commit/608f20b + // where by properly sending mDNS responses to all interfaces, we significantly + // increased the chance that we send a loopback response to a Query that is + // unwillingly to use loopback addresses (the default in pion/ice). + for i := 0; i < 100; i++ { + _, addr, err = bServer.Query(context.TODO(), "pion-mdns-2.local") + check(err, t) + if addr.String() == localAddress { + t.Fatalf("unexpected local address: %v", addr) + } + if addr.String() == "127.0.0.1" { + t.Fatal("unexpected loopback") + } + } + check(aServer.Close(), t) check(bServer.Close(), t) } @@ -95,6 +110,81 @@ func TestValidCommunicationWithAddressConfig(t *testing.T) { check(aServer.Close(), t) } +func TestValidCommunicationWithLoopbackAddressConfig(t *testing.T) { + lim := test.TimeOut(time.Second * 10) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + aSock := createListener(t) + + loopbackIP := net.ParseIP("127.0.0.1") + + aServer, err := Server(ipv4.NewPacketConn(aSock), &Config{ + LocalNames: []string{"pion-mdns-1.local", "pion-mdns-2.local"}, + LocalAddress: loopbackIP, + IncludeLoopback: true, // the test would fail if this was false + }) + check(err, t) + + _, addr, err := aServer.Query(context.TODO(), "pion-mdns-1.local") + check(err, t) + if addr.String() != loopbackIP.String() { + t.Fatalf("address mismatch: expected %s, but got %v\n", localAddress, addr) + } + + check(aServer.Close(), t) +} + +func TestValidCommunicationWithLoopbackInterface(t *testing.T) { + lim := test.TimeOut(time.Second * 10) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + aSock := createListener(t) + + ifaces, err := net.Interfaces() + check(err, t) + ifacesToUse := make([]net.Interface, 0, len(ifaces)) + for _, ifc := range ifaces { + if ifc.Flags&net.FlagLoopback != net.FlagLoopback { + continue + } + ifcCopy := ifc + ifacesToUse = append(ifacesToUse, ifcCopy) + } + + // the following checks are unlikely to fail since most places where this code runs + // will have a loopback + if len(ifacesToUse) == 0 { + t.Skip("expected at least one loopback interface, but got none") + } + expectedAddrs, err := ifacesToUse[0].Addrs() + check(err, t) + expectedAddr, ok := expectedAddrs[0].(*net.IPNet) + if !ok { + t.Fatalf("expected *net.IPNet address for loopback but got %T", expectedAddrs[0]) + } + + aServer, err := Server(ipv4.NewPacketConn(aSock), &Config{ + LocalNames: []string{"pion-mdns-1.local", "pion-mdns-2.local"}, + IncludeLoopback: true, // the test would fail if this was false + Interfaces: ifacesToUse, + }) + check(err, t) + + _, addr, err := aServer.Query(context.TODO(), "pion-mdns-1.local") + check(err, t) + if addr.String() != expectedAddr.IP.String() { + t.Fatalf("address mismatch: expected %s, but got %v\n", expectedAddr.IP.String(), addr) + } + + check(aServer.Close(), t) +} + func TestMultipleClose(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop()