Skip to content

Commit

Permalink
Use loopback address in responses when configured
Browse files Browse the repository at this point in the history
  • Loading branch information
edaniels committed Feb 5, 2024
1 parent 1d4f9bc commit 4d0762c
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 10 deletions.
6 changes: 6 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,10 @@ type Config struct {
LocalAddress net.IP

LoggerFactory logging.LoggerFactory

// IncludeLoopback will include loopback interfaces to be eligble for queries and answers.
IncludeLoopback bool

// Interfaces will override the interfaces used for queries and answers.
Interfaces []net.Interface
}
32 changes: 22 additions & 10 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,23 @@ func Server(conn *ipv4.PacketConn, config *Config) (*Conn, error) {
return nil, errNilConfig
}

ifaces, err := net.Interfaces()
if err != nil {
return nil, err
ifaces := config.Interfaces
if ifaces == nil {
var err error
ifaces, err = net.Interfaces()
if err != nil {
return nil, err
}
}

inboundBufferSize := 0
joinErrCount := 0
ifacesToUse := make([]net.Interface, 0, len(ifaces))
for i, ifc := range ifaces {
if err = conn.JoinGroup(&ifaces[i], &net.UDPAddr{IP: net.IPv4(224, 0, 0, 251)}); err != nil {
if !config.IncludeLoopback && ifc.Flags&net.FlagLoopback == net.FlagLoopback {
continue
}
if err := conn.JoinGroup(&ifaces[i], &net.UDPAddr{IP: net.IPv4(224, 0, 0, 251)}); err != nil {
joinErrCount++
continue
}
Expand Down Expand Up @@ -178,6 +185,11 @@ func (c *Conn) Query(ctx context.Context, name string) (dnsmessage.ResourceHeade
case <-c.closed:
return dnsmessage.ResourceHeader{}, nil, errConnectionClosed
case res := <-queryChan:
// Given https://datatracker.ietf.org/doc/html/draft-ietf-mmusic-mdns-ice-candidates#section-3.2.2-2
// An ICE agent SHOULD ignore candidates where the hostname resolution returns more than one IP address.
//
// We will take the first we receive which could result in a race between two suitable addresses where
// one is better than the other (e.g. localhost vs LAN).
return res.answer, res.addr, nil
case <-ctx.Done():
return dnsmessage.ResourceHeader{}, nil, errContextElapsed
Expand Down Expand Up @@ -242,14 +254,14 @@ func (c *Conn) sendQuestion(name string) {
c.writeToSocket(0, rawQuery, false)
}

func (c *Conn) writeToSocket(ifIndex int, b []byte, onlyLooback bool) {
func (c *Conn) writeToSocket(ifIndex int, b []byte, srcIfcIsLoopback bool) {
if ifIndex != 0 {
ifc, err := net.InterfaceByIndex(ifIndex)
if err != nil {
c.log.Warnf("Failed to get interface interface for %d: %v", ifIndex, err)
return
}
if onlyLooback && ifc.Flags&net.FlagLoopback == 0 {
if srcIfcIsLoopback && ifc.Flags&net.FlagLoopback == 0 {
// avoid accidentally tricking the destination that itself is the same as us
c.log.Warnf("Interface is not loopback %d", ifIndex)
return
Expand All @@ -264,7 +276,7 @@ func (c *Conn) writeToSocket(ifIndex int, b []byte, onlyLooback bool) {
return
}
for ifcIdx := range c.ifaces {
if onlyLooback && c.ifaces[ifcIdx].Flags&net.FlagLoopback == 0 {
if srcIfcIsLoopback && c.ifaces[ifcIdx].Flags&net.FlagLoopback == 0 {
// avoid accidentally tricking the destination that itself is the same as us
continue
}
Expand All @@ -278,7 +290,7 @@ func (c *Conn) writeToSocket(ifIndex int, b []byte, onlyLooback bool) {
}
}

func (c *Conn) sendAnswer(name string, ifIndex int, dst net.IP) {
func (c *Conn) sendAnswer(name string, ifIndex int, addr net.IP) {
packedName, err := dnsmessage.NewName(name)
if err != nil {
c.log.Warnf("Failed to construct mDNS packet %v", err)
Expand All @@ -299,7 +311,7 @@ func (c *Conn) sendAnswer(name string, ifIndex int, dst net.IP) {
TTL: responseTTL,
},
Body: &dnsmessage.AResource{
A: ipToBytes(dst),
A: ipToBytes(addr),
},
},
},
Expand All @@ -311,7 +323,7 @@ func (c *Conn) sendAnswer(name string, ifIndex int, dst net.IP) {
return
}

c.writeToSocket(ifIndex, rawAnswer, dst.IsLoopback())
c.writeToSocket(ifIndex, rawAnswer, addr.IsLoopback())
}

func (c *Conn) start(inboundBufferSize int, config *Config) { //nolint gocognit
Expand Down
90 changes: 90 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,21 @@ func TestValidCommunication(t *testing.T) {
t.Fatalf("unexpected local address: %v", addr)
}

// test against regression from https://github.com/pion/mdns/commit/608f20b
// where by properly sending mDNS responses to all interfaces, we significantly
// increased the chance that we send a loopback response to a Query that is
// unwillingly to use loopback addresses (the default in pion/ice).
for i := 0; i < 100; i++ {
_, addr, err = bServer.Query(context.TODO(), "pion-mdns-2.local")
check(err, t)
if addr.String() == localAddress {
t.Fatalf("unexpected local address: %v", addr)
}
if addr.String() == "127.0.0.1" {
t.Fatal("unexpected loopback")
}
}

check(aServer.Close(), t)
check(bServer.Close(), t)
}
Expand Down Expand Up @@ -95,6 +110,81 @@ func TestValidCommunicationWithAddressConfig(t *testing.T) {
check(aServer.Close(), t)
}

func TestValidCommunicationWithLoopbackAddressConfig(t *testing.T) {
lim := test.TimeOut(time.Second * 10)
defer lim.Stop()

report := test.CheckRoutines(t)
defer report()

aSock := createListener(t)

loopbackIP := net.ParseIP("127.0.0.1")

aServer, err := Server(ipv4.NewPacketConn(aSock), &Config{
LocalNames: []string{"pion-mdns-1.local", "pion-mdns-2.local"},
LocalAddress: loopbackIP,
IncludeLoopback: true, // the test would fail if this was false
})
check(err, t)

_, addr, err := aServer.Query(context.TODO(), "pion-mdns-1.local")
check(err, t)
if addr.String() != loopbackIP.String() {
t.Fatalf("address mismatch: expected %s, but got %v\n", localAddress, addr)
}

check(aServer.Close(), t)
}

func TestValidCommunicationWithLoopbackInterface(t *testing.T) {
lim := test.TimeOut(time.Second * 10)
defer lim.Stop()

report := test.CheckRoutines(t)
defer report()

aSock := createListener(t)

ifaces, err := net.Interfaces()
check(err, t)
ifacesToUse := make([]net.Interface, 0, len(ifaces))
for _, ifc := range ifaces {
if ifc.Flags&net.FlagLoopback != net.FlagLoopback {
continue
}
ifcCopy := ifc
ifacesToUse = append(ifacesToUse, ifcCopy)
}

// the following checks are unlikely to fail since most places where this code runs
// will have a loopback
if len(ifacesToUse) == 0 {
t.Skip("expected at least one loopback interface, but got none")
}
expectedAddrs, err := ifacesToUse[0].Addrs()
check(err, t)
expectedAddr, ok := expectedAddrs[0].(*net.IPNet)
if !ok {
t.Fatalf("expected *net.IPNet address for loopback but got %T", expectedAddrs[0])
}

aServer, err := Server(ipv4.NewPacketConn(aSock), &Config{
LocalNames: []string{"pion-mdns-1.local", "pion-mdns-2.local"},
IncludeLoopback: true, // the test would fail if this was false
Interfaces: ifacesToUse,
})
check(err, t)

_, addr, err := aServer.Query(context.TODO(), "pion-mdns-1.local")
check(err, t)
if addr.String() != expectedAddr.IP.String() {
t.Fatalf("address mismatch: expected %s, but got %v\n", expectedAddr.IP.String(), addr)
}

check(aServer.Close(), t)
}

func TestMultipleClose(t *testing.T) {
lim := test.TimeOut(time.Second * 10)
defer lim.Stop()
Expand Down

0 comments on commit 4d0762c

Please sign in to comment.