Skip to content

Commit

Permalink
Use loopback address in responses when configured
Browse files Browse the repository at this point in the history
  • Loading branch information
edaniels committed Feb 5, 2024
1 parent 1d4f9bc commit 1daf3b5
Show file tree
Hide file tree
Showing 5 changed files with 180 additions and 19 deletions.
6 changes: 6 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,10 @@ type Config struct {
LocalAddress net.IP

LoggerFactory logging.LoggerFactory

// IncludeLoopback will include loopback interfaces to be eligble for queries and answers.
IncludeLoopback bool

// Interfaces will override the interfaces used for queries and answers.
Interfaces []net.Interface
}
84 changes: 69 additions & 15 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,23 @@ func Server(conn *ipv4.PacketConn, config *Config) (*Conn, error) {
return nil, errNilConfig
}

ifaces, err := net.Interfaces()
if err != nil {
return nil, err
ifaces := config.Interfaces
if ifaces == nil {
var err error
ifaces, err = net.Interfaces()
if err != nil {
return nil, err
}

Check warning on line 72 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L71-L72

Added lines #L71 - L72 were not covered by tests
}

inboundBufferSize := 0
joinErrCount := 0
ifacesToUse := make([]net.Interface, 0, len(ifaces))
for i, ifc := range ifaces {
if err = conn.JoinGroup(&ifaces[i], &net.UDPAddr{IP: net.IPv4(224, 0, 0, 251)}); err != nil {
if !config.IncludeLoopback && ifc.Flags&net.FlagLoopback == net.FlagLoopback {
continue
}
if err := conn.JoinGroup(&ifaces[i], &net.UDPAddr{IP: net.IPv4(224, 0, 0, 251)}); err != nil {
joinErrCount++
continue
}
Expand Down Expand Up @@ -127,6 +134,14 @@ func Server(conn *ipv4.PacketConn, config *Config) (*Conn, error) {
c.log.Warnf("Failed to SetControlMessage on PacketConn %v", err)
}

if config.IncludeLoopback {
// this is an efficient way for us to send ourselves a message faster instead of it going
// further out into the network stack.
if err := conn.SetMulticastLoopback(true); err != nil {
c.log.Warnf("Failed to SetMulticastLoopback(true) on PacketConn %v; this may cause inefficient network path communications", err)
}

Check warning on line 142 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L141-L142

Added lines #L141 - L142 were not covered by tests
}

// https://www.rfc-editor.org/rfc/rfc6762.html#section-17
// Multicast DNS messages carried by UDP may be up to the IP MTU of the
// physical interface, less the space required for the IP header (20
Expand Down Expand Up @@ -178,6 +193,11 @@ func (c *Conn) Query(ctx context.Context, name string) (dnsmessage.ResourceHeade
case <-c.closed:
return dnsmessage.ResourceHeader{}, nil, errConnectionClosed
case res := <-queryChan:
// Given https://datatracker.ietf.org/doc/html/draft-ietf-mmusic-mdns-ice-candidates#section-3.2.2-2
// An ICE agent SHOULD ignore candidates where the hostname resolution returns more than one IP address.
//
// We will take the first we receive which could result in a race between two suitable addresses where
// one is better than the other (e.g. localhost vs LAN).
return res.answer, res.addr, nil
case <-ctx.Done():
return dnsmessage.ResourceHeader{}, nil, errContextElapsed
Expand Down Expand Up @@ -242,14 +262,14 @@ func (c *Conn) sendQuestion(name string) {
c.writeToSocket(0, rawQuery, false)
}

func (c *Conn) writeToSocket(ifIndex int, b []byte, onlyLooback bool) {
func (c *Conn) writeToSocket(ifIndex int, b []byte, srcIfcIsLoopback bool) {
if ifIndex != 0 {
ifc, err := net.InterfaceByIndex(ifIndex)
if err != nil {
c.log.Warnf("Failed to get interface interface for %d: %v", ifIndex, err)
c.log.Warnf("Failed to get interface for %d: %v", ifIndex, err)

Check warning on line 269 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L269

Added line #L269 was not covered by tests
return
}
if onlyLooback && ifc.Flags&net.FlagLoopback == 0 {
if srcIfcIsLoopback && ifc.Flags&net.FlagLoopback == 0 {
// avoid accidentally tricking the destination that itself is the same as us
c.log.Warnf("Interface is not loopback %d", ifIndex)
return
Expand All @@ -264,7 +284,7 @@ func (c *Conn) writeToSocket(ifIndex int, b []byte, onlyLooback bool) {
return
}
for ifcIdx := range c.ifaces {
if onlyLooback && c.ifaces[ifcIdx].Flags&net.FlagLoopback == 0 {
if srcIfcIsLoopback && c.ifaces[ifcIdx].Flags&net.FlagLoopback == 0 {
// avoid accidentally tricking the destination that itself is the same as us
continue
}
Expand All @@ -278,7 +298,7 @@ func (c *Conn) writeToSocket(ifIndex int, b []byte, onlyLooback bool) {
}
}

func (c *Conn) sendAnswer(name string, ifIndex int, dst net.IP) {
func (c *Conn) sendAnswer(name string, ifIndex int, addr net.IP) {
packedName, err := dnsmessage.NewName(name)
if err != nil {
c.log.Warnf("Failed to construct mDNS packet %v", err)
Expand All @@ -299,7 +319,7 @@ func (c *Conn) sendAnswer(name string, ifIndex int, dst net.IP) {
TTL: responseTTL,
},
Body: &dnsmessage.AResource{
A: ipToBytes(dst),
A: ipToBytes(addr),
},
},
},
Expand All @@ -311,7 +331,7 @@ func (c *Conn) sendAnswer(name string, ifIndex int, dst net.IP) {
return
}

c.writeToSocket(ifIndex, rawAnswer, dst.IsLoopback())
c.writeToSocket(ifIndex, rawAnswer, addr.IsLoopback())
}

func (c *Conn) start(inboundBufferSize int, config *Config) { //nolint gocognit
Expand Down Expand Up @@ -361,10 +381,44 @@ func (c *Conn) start(inboundBufferSize int, config *Config) { //nolint gocognit
if config.LocalAddress != nil {
c.sendAnswer(q.Name.String(), ifIndex, config.LocalAddress)
} else {
localAddress, err := interfaceForRemote(src.String())
if err != nil {
c.log.Warnf("Failed to get local interface to communicate with %s: %v", src.String(), err)
continue
var localAddress net.IP

// prefer the address of the interface if we know its index, but otherwise
// derive it from the address we read from. We do this because even if
// multicast loopback is in use or we send from a loopback interface,
// there are still cases where the IP packet will contain the wrong
// source IP (e.g. a LAN interface).
// For example, we can have a packet that has:
// Source: 192.168.65.3
// Destination: 224.0.0.251
// Interface Index: 1
// Interface Addresses @ 1: [127.0.0.1/8 ::1/128]
if ifIndex == 0 {
localAddress, err = interfaceForRemote(src.String())
if err != nil {
c.log.Warnf("Failed to get local interface to communicate with %s: %v", src.String(), err)
continue

Check warning on line 400 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L397-L400

Added lines #L397 - L400 were not covered by tests
}
} else {
ifc, err := net.InterfaceByIndex(ifIndex)
if err != nil {
c.log.Warnf("Failed to get interface for %d: %v", ifIndex, err)
continue

Check warning on line 406 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L405-L406

Added lines #L405 - L406 were not covered by tests
}
addrs, err := ifc.Addrs()
if err != nil {
c.log.Warnf("Failed to get addresses for interface %d: %v", ifIndex, err)
continue

Check warning on line 411 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L410-L411

Added lines #L410 - L411 were not covered by tests
}
if len(addrs) == 0 {
c.log.Warnf("Expected more than one address for interface %d", ifIndex)
continue

Check warning on line 415 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L414-L415

Added lines #L414 - L415 were not covered by tests
}
ipAddr, ok := addrs[0].(*net.IPNet)
if !ok {
c.log.Warnf("expected *net.IPNet address for interface but got %T", ipAddr)
}

Check warning on line 420 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L419-L420

Added lines #L419 - L420 were not covered by tests
localAddress = ipAddr.IP
}

c.sendAnswer(q.Name.String(), ifIndex, localAddress)
Expand Down
102 changes: 102 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,21 @@ func TestValidCommunication(t *testing.T) {
t.Fatalf("unexpected local address: %v", addr)
}

// test against regression from https://github.com/pion/mdns/commit/608f20b
// where by properly sending mDNS responses to all interfaces, we significantly
// increased the chance that we send a loopback response to a Query that is
// unwillingly to use loopback addresses (the default in pion/ice).
for i := 0; i < 100; i++ {
_, addr, err = bServer.Query(context.TODO(), "pion-mdns-2.local")
check(err, t)
if addr.String() == localAddress {
t.Fatalf("unexpected local address: %v", addr)
}
if addr.String() == "127.0.0.1" {
t.Fatal("unexpected loopback")
}
}

check(aServer.Close(), t)
check(bServer.Close(), t)
}
Expand Down Expand Up @@ -95,6 +110,93 @@ func TestValidCommunicationWithAddressConfig(t *testing.T) {
check(aServer.Close(), t)
}

func TestValidCommunicationWithLoopbackAddressConfig(t *testing.T) {
lim := test.TimeOut(time.Second * 10)
defer lim.Stop()

report := test.CheckRoutines(t)
defer report()

aSock := createListener(t)

loopbackIP := net.ParseIP("127.0.0.1")

aServer, err := Server(ipv4.NewPacketConn(aSock), &Config{
LocalNames: []string{"pion-mdns-1.local", "pion-mdns-2.local"},
LocalAddress: loopbackIP,
IncludeLoopback: true, // the test would fail if this was false
})
check(err, t)

_, addr, err := aServer.Query(context.TODO(), "pion-mdns-1.local")
check(err, t)
if addr.String() != loopbackIP.String() {
t.Fatalf("address mismatch: expected %s, but got %v\n", localAddress, addr)
}

check(aServer.Close(), t)
}

func TestValidCommunicationWithLoopbackInterface(t *testing.T) {
lim := test.TimeOut(time.Second * 10)
defer lim.Stop()

report := test.CheckRoutines(t)
defer report()

aSock := createListener(t)

ifaces, err := net.Interfaces()
check(err, t)
ifacesToUse := make([]net.Interface, 0, len(ifaces))
for _, ifc := range ifaces {
if ifc.Flags&net.FlagLoopback != net.FlagLoopback {
continue
}
ifcCopy := ifc
ifacesToUse = append(ifacesToUse, ifcCopy)
}

// the following checks are unlikely to fail since most places where this code runs
// will have a loopback
if len(ifacesToUse) == 0 {
t.Skip("expected at least one loopback interface, but got none")
}

aServer, err := Server(ipv4.NewPacketConn(aSock), &Config{
LocalNames: []string{"pion-mdns-1.local", "pion-mdns-2.local"},
IncludeLoopback: true, // the test would fail if this was false
Interfaces: ifacesToUse,
})
check(err, t)

_, addr, err := aServer.Query(context.TODO(), "pion-mdns-1.local")
check(err, t)
var found bool
for _, iface := range ifacesToUse {
addrs, err := iface.Addrs()
check(err, t)
for _, ifaceAddr := range addrs {
ipAddr, ok := ifaceAddr.(*net.IPNet)
if !ok {
t.Fatalf("expected *net.IPNet address for loopback but got %T", addr)
}
if addr.String() == ipAddr.IP.String() {
found = true
break
}
}
if found {
break
}
}
if !found {
t.Fatalf("address mismatch: expected loopback address, but got %v\n", addr)
}

check(aServer.Close(), t)
}

func TestMultipleClose(t *testing.T) {
lim := test.TimeOut(time.Second * 10)
defer lim.Stop()
Expand Down
4 changes: 3 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
module github.com/pion/mdns

go 1.12
go 1.19

require (
github.com/pion/logging v0.2.2
github.com/pion/transport/v3 v3.0.1
golang.org/x/net v0.20.0
)

require golang.org/x/sys v0.16.0 // indirect
3 changes: 0 additions & 3 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5t
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
Expand Down Expand Up @@ -44,14 +43,12 @@ golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuX
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU=
golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
Expand Down

0 comments on commit 1daf3b5

Please sign in to comment.