Skip to content

Commit

Permalink
Use loopback address in responses when configured
Browse files Browse the repository at this point in the history
  • Loading branch information
edaniels committed Feb 5, 2024
1 parent 1d4f9bc commit 829b3fe
Show file tree
Hide file tree
Showing 3 changed files with 176 additions and 15 deletions.
6 changes: 6 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,10 @@ type Config struct {
LocalAddress net.IP

LoggerFactory logging.LoggerFactory

// IncludeLoopback will include loopback interfaces to be eligble for queries and answers.
IncludeLoopback bool

// Interfaces will override the interfaces used for queries and answers.
Interfaces []net.Interface
}
83 changes: 68 additions & 15 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,23 @@ func Server(conn *ipv4.PacketConn, config *Config) (*Conn, error) {
return nil, errNilConfig
}

ifaces, err := net.Interfaces()
if err != nil {
return nil, err
ifaces := config.Interfaces
if ifaces == nil {
var err error
ifaces, err = net.Interfaces()
if err != nil {
return nil, err
}

Check warning on line 72 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L71-L72

Added lines #L71 - L72 were not covered by tests
}

inboundBufferSize := 0
joinErrCount := 0
ifacesToUse := make([]net.Interface, 0, len(ifaces))
for i, ifc := range ifaces {
if err = conn.JoinGroup(&ifaces[i], &net.UDPAddr{IP: net.IPv4(224, 0, 0, 251)}); err != nil {
if !config.IncludeLoopback && ifc.Flags&net.FlagLoopback == net.FlagLoopback {
continue
}
if err := conn.JoinGroup(&ifaces[i], &net.UDPAddr{IP: net.IPv4(224, 0, 0, 251)}); err != nil {
joinErrCount++
continue
}
Expand Down Expand Up @@ -127,6 +134,14 @@ func Server(conn *ipv4.PacketConn, config *Config) (*Conn, error) {
c.log.Warnf("Failed to SetControlMessage on PacketConn %v", err)
}

if config.IncludeLoopback {
// this is an efficient way for us to send ourselves a message faster instead of it going
// further out into the network stack.
if err := conn.SetMulticastLoopback(true); err != nil {
c.log.Warnf("Failed to SetMulticastLoopback(true) on PacketConn %v; this may cause inefficient network path communications", err)
}

Check warning on line 142 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L141-L142

Added lines #L141 - L142 were not covered by tests
}

// https://www.rfc-editor.org/rfc/rfc6762.html#section-17
// Multicast DNS messages carried by UDP may be up to the IP MTU of the
// physical interface, less the space required for the IP header (20
Expand Down Expand Up @@ -178,6 +193,11 @@ func (c *Conn) Query(ctx context.Context, name string) (dnsmessage.ResourceHeade
case <-c.closed:
return dnsmessage.ResourceHeader{}, nil, errConnectionClosed
case res := <-queryChan:
// Given https://datatracker.ietf.org/doc/html/draft-ietf-mmusic-mdns-ice-candidates#section-3.2.2-2
// An ICE agent SHOULD ignore candidates where the hostname resolution returns more than one IP address.
//
// We will take the first we receive which could result in a race between two suitable addresses where
// one is better than the other (e.g. localhost vs LAN).
return res.answer, res.addr, nil
case <-ctx.Done():
return dnsmessage.ResourceHeader{}, nil, errContextElapsed
Expand Down Expand Up @@ -242,14 +262,14 @@ func (c *Conn) sendQuestion(name string) {
c.writeToSocket(0, rawQuery, false)
}

func (c *Conn) writeToSocket(ifIndex int, b []byte, onlyLooback bool) {
func (c *Conn) writeToSocket(ifIndex int, b []byte, srcIfcIsLoopback bool) {
if ifIndex != 0 {
ifc, err := net.InterfaceByIndex(ifIndex)
if err != nil {
c.log.Warnf("Failed to get interface interface for %d: %v", ifIndex, err)
c.log.Warnf("Failed to get interface for %d: %v", ifIndex, err)

Check warning on line 269 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L269

Added line #L269 was not covered by tests
return
}
if onlyLooback && ifc.Flags&net.FlagLoopback == 0 {
if srcIfcIsLoopback && ifc.Flags&net.FlagLoopback == 0 {
// avoid accidentally tricking the destination that itself is the same as us
c.log.Warnf("Interface is not loopback %d", ifIndex)
return
Expand All @@ -264,7 +284,7 @@ func (c *Conn) writeToSocket(ifIndex int, b []byte, onlyLooback bool) {
return
}
for ifcIdx := range c.ifaces {
if onlyLooback && c.ifaces[ifcIdx].Flags&net.FlagLoopback == 0 {
if srcIfcIsLoopback && c.ifaces[ifcIdx].Flags&net.FlagLoopback == 0 {
// avoid accidentally tricking the destination that itself is the same as us
continue
}
Expand All @@ -278,7 +298,7 @@ func (c *Conn) writeToSocket(ifIndex int, b []byte, onlyLooback bool) {
}
}

func (c *Conn) sendAnswer(name string, ifIndex int, dst net.IP) {
func (c *Conn) sendAnswer(name string, ifIndex int, addr net.IP) {
packedName, err := dnsmessage.NewName(name)
if err != nil {
c.log.Warnf("Failed to construct mDNS packet %v", err)
Expand All @@ -299,7 +319,7 @@ func (c *Conn) sendAnswer(name string, ifIndex int, dst net.IP) {
TTL: responseTTL,
},
Body: &dnsmessage.AResource{
A: ipToBytes(dst),
A: ipToBytes(addr),
},
},
},
Expand All @@ -311,7 +331,7 @@ func (c *Conn) sendAnswer(name string, ifIndex int, dst net.IP) {
return
}

c.writeToSocket(ifIndex, rawAnswer, dst.IsLoopback())
c.writeToSocket(ifIndex, rawAnswer, addr.IsLoopback())
}

func (c *Conn) start(inboundBufferSize int, config *Config) { //nolint gocognit
Expand Down Expand Up @@ -361,10 +381,43 @@ func (c *Conn) start(inboundBufferSize int, config *Config) { //nolint gocognit
if config.LocalAddress != nil {
c.sendAnswer(q.Name.String(), ifIndex, config.LocalAddress)
} else {
localAddress, err := interfaceForRemote(src.String())
if err != nil {
c.log.Warnf("Failed to get local interface to communicate with %s: %v", src.String(), err)
continue
var localAddress net.IP

// prefer the address of the interface if we know its index, but otherwise
// derive it from the address we read from. We do this because even if
// multicast loopback is in use, there are still cases where the IP packet
// will contain the wrong source IP.
// For example, we can have a packet that has:
// Source: 192.168.65.3
// Destination: 224.0.0.251
// Interface Index: 1
// Interface Addresses @ 1: [127.0.0.1/8 ::1/128]
if ifIndex == 0 {
localAddress, err = interfaceForRemote(src.String())
if err != nil {
c.log.Warnf("Failed to get local interface to communicate with %s: %v", src.String(), err)
continue

Check warning on line 399 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L396-L399

Added lines #L396 - L399 were not covered by tests
}
} else {
ifc, err := net.InterfaceByIndex(ifIndex)
if err != nil {
c.log.Warnf("Failed to get interface for %d: %v", ifIndex, err)
continue

Check warning on line 405 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L404-L405

Added lines #L404 - L405 were not covered by tests
}
addrs, err := ifc.Addrs()
if err != nil {
c.log.Warnf("Failed to get addresses for interface %d: %v", ifIndex, err)
continue

Check warning on line 410 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L409-L410

Added lines #L409 - L410 were not covered by tests
}
if len(addrs) == 0 {
c.log.Warnf("Expected more than one address for interface %d", ifIndex)
continue

Check warning on line 414 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L413-L414

Added lines #L413 - L414 were not covered by tests
}
ipAddr, ok := addrs[0].(*net.IPNet)
if !ok {
c.log.Warnf("expected *net.IPNet address for interface but got %T", ipAddr)
}

Check warning on line 419 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L418-L419

Added lines #L418 - L419 were not covered by tests
localAddress = ipAddr.IP
}

c.sendAnswer(q.Name.String(), ifIndex, localAddress)
Expand Down
102 changes: 102 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,21 @@ func TestValidCommunication(t *testing.T) {
t.Fatalf("unexpected local address: %v", addr)
}

// test against regression from https://github.com/pion/mdns/commit/608f20b
// where by properly sending mDNS responses to all interfaces, we significantly
// increased the chance that we send a loopback response to a Query that is
// unwillingly to use loopback addresses (the default in pion/ice).
for i := 0; i < 100; i++ {
_, addr, err = bServer.Query(context.TODO(), "pion-mdns-2.local")
check(err, t)
if addr.String() == localAddress {
t.Fatalf("unexpected local address: %v", addr)
}
if addr.String() == "127.0.0.1" {
t.Fatal("unexpected loopback")
}
}

check(aServer.Close(), t)
check(bServer.Close(), t)
}
Expand Down Expand Up @@ -95,6 +110,93 @@ func TestValidCommunicationWithAddressConfig(t *testing.T) {
check(aServer.Close(), t)
}

func TestValidCommunicationWithLoopbackAddressConfig(t *testing.T) {
lim := test.TimeOut(time.Second * 10)
defer lim.Stop()

report := test.CheckRoutines(t)
defer report()

aSock := createListener(t)

loopbackIP := net.ParseIP("127.0.0.1")

aServer, err := Server(ipv4.NewPacketConn(aSock), &Config{
LocalNames: []string{"pion-mdns-1.local", "pion-mdns-2.local"},
LocalAddress: loopbackIP,
IncludeLoopback: true, // the test would fail if this was false
})
check(err, t)

_, addr, err := aServer.Query(context.TODO(), "pion-mdns-1.local")
check(err, t)
if addr.String() != loopbackIP.String() {
t.Fatalf("address mismatch: expected %s, but got %v\n", localAddress, addr)
}

check(aServer.Close(), t)
}

func TestValidCommunicationWithLoopbackInterface(t *testing.T) {
lim := test.TimeOut(time.Second * 10)
defer lim.Stop()

report := test.CheckRoutines(t)
defer report()

aSock := createListener(t)

ifaces, err := net.Interfaces()
check(err, t)
ifacesToUse := make([]net.Interface, 0, len(ifaces))
for _, ifc := range ifaces {
if ifc.Flags&net.FlagLoopback != net.FlagLoopback {
continue
}
ifcCopy := ifc
ifacesToUse = append(ifacesToUse, ifcCopy)
}

// the following checks are unlikely to fail since most places where this code runs
// will have a loopback
if len(ifacesToUse) == 0 {
t.Skip("expected at least one loopback interface, but got none")
}

aServer, err := Server(ipv4.NewPacketConn(aSock), &Config{
LocalNames: []string{"pion-mdns-1.local", "pion-mdns-2.local"},
IncludeLoopback: true, // the test would fail if this was false
Interfaces: ifacesToUse,
})
check(err, t)

_, addr, err := aServer.Query(context.TODO(), "pion-mdns-1.local")
check(err, t)
var found bool
for _, iface := range ifacesToUse {
addrs, err := iface.Addrs()
check(err, t)
for _, ifaceAddr := range addrs {
ipAddr, ok := ifaceAddr.(*net.IPNet)
if !ok {
t.Fatalf("expected *net.IPNet address for loopback but got %T", addr)
}
if addr.String() == ipAddr.IP.String() {
found = true
break
}
}
if found {
break
}
}
if !found {
t.Fatalf("address mismatch: expected loopback address, but got %v\n", addr)
}

check(aServer.Close(), t)
}

func TestMultipleClose(t *testing.T) {
lim := test.TimeOut(time.Second * 10)
defer lim.Stop()
Expand Down

0 comments on commit 829b3fe

Please sign in to comment.