diff --git a/config.go b/config.go index 936b0b6..34fc8fe 100644 --- a/config.go +++ b/config.go @@ -32,4 +32,10 @@ type Config struct { LocalAddress net.IP LoggerFactory logging.LoggerFactory + + // IncludeLoopback will include loopback interfaces to be eligble for queries and answers. + IncludeLoopback bool + + // Interfaces will override the interfaces used for queries and answers. + Interfaces []net.Interface } diff --git a/conn.go b/conn.go index 6ca2b68..1f14607 100644 --- a/conn.go +++ b/conn.go @@ -6,7 +6,7 @@ package mdns import ( "context" "errors" - "math/big" + "fmt" "net" "sync" "time" @@ -57,22 +57,32 @@ const ( var errNoPositiveMTUFound = errors.New("no positive MTU found") -// Server establishes a mDNS connection over an existing conn +// Server establishes a mDNS connection over an existing conn. +// +// Currently, the server only supports listening on an IPv4 connection, but internally +// it supports answering with IPv6 AAAA records if this were ever to change. func Server(conn *ipv4.PacketConn, config *Config) (*Conn, error) { if config == nil { return nil, errNilConfig } - ifaces, err := net.Interfaces() - if err != nil { - return nil, err + ifaces := config.Interfaces + if ifaces == nil { + var err error + ifaces, err = net.Interfaces() + if err != nil { + return nil, err + } } inboundBufferSize := 0 joinErrCount := 0 ifacesToUse := make([]net.Interface, 0, len(ifaces)) for i, ifc := range ifaces { - if err = conn.JoinGroup(&ifaces[i], &net.UDPAddr{IP: net.IPv4(224, 0, 0, 251)}); err != nil { + if !config.IncludeLoopback && ifc.Flags&net.FlagLoopback == net.FlagLoopback { + continue + } + if err := conn.JoinGroup(&ifaces[i], &net.UDPAddr{IP: net.IPv4(224, 0, 0, 251)}); err != nil { joinErrCount++ continue } @@ -127,6 +137,14 @@ func Server(conn *ipv4.PacketConn, config *Config) (*Conn, error) { c.log.Warnf("Failed to SetControlMessage on PacketConn %v", err) } + if config.IncludeLoopback { + // this is an efficient way for us to send ourselves a message faster instead of it going + // further out into the network stack. + if err := conn.SetMulticastLoopback(true); err != nil { + c.log.Warnf("Failed to SetMulticastLoopback(true) on PacketConn %v; this may cause inefficient network path communications", err) + } + } + // https://www.rfc-editor.org/rfc/rfc6762.html#section-17 // Multicast DNS messages carried by UDP may be up to the IP MTU of the // physical interface, less the space required for the IP header (20 @@ -189,6 +207,11 @@ func (c *Conn) Query(ctx context.Context, name string) (dnsmessage.ResourceHeade case <-c.closed: return dnsmessage.ResourceHeader{}, nil, errConnectionClosed case res := <-queryChan: + // Given https://datatracker.ietf.org/doc/html/draft-ietf-mmusic-mdns-ice-candidates#section-3.2.2-2 + // An ICE agent SHOULD ignore candidates where the hostname resolution returns more than one IP address. + // + // We will take the first we receive which could result in a race between two suitable addresses where + // one is better than the other (e.g. localhost vs LAN). return res.answer, res.addr, nil case <-ctx.Done(): return dnsmessage.ResourceHeader{}, nil, errContextElapsed @@ -196,16 +219,37 @@ func (c *Conn) Query(ctx context.Context, name string) (dnsmessage.ResourceHeade } } -func ipToBytes(ip net.IP) (out [4]byte) { +type ipToBytesError struct { + ip net.IP + expectedType string +} + +func (err ipToBytesError) Error() string { + return fmt.Sprintf("ip (%s) is not %s", err.ip, err.expectedType) +} + +func ipv4ToBytes(ip net.IP) ([4]byte, error) { rawIP := ip.To4() if rawIP == nil { - return + return [4]byte{}, ipToBytesError{ip, "IPv4"} } - ipInt := big.NewInt(0) - ipInt.SetBytes(rawIP) - copy(out[:], ipInt.Bytes()) - return + // net.IPs are stored in big endian / network byte order + var out [4]byte + copy(out[:], rawIP[:]) + return out, nil +} + +func ipv6ToBytes(ip net.IP) ([16]byte, error) { + rawIP := ip.To16() + if rawIP == nil { + return [16]byte{}, ipToBytesError{ip, "IPv6"} + } + + // net.IPs are stored in big endian / network byte order + var out [16]byte + copy(out[:], rawIP[:]) + return out, nil } func interfaceForRemote(remote string) (net.IP, error) { @@ -253,14 +297,14 @@ func (c *Conn) sendQuestion(name string) { c.writeToSocket(0, rawQuery, false) } -func (c *Conn) writeToSocket(ifIndex int, b []byte, onlyLooback bool) { +func (c *Conn) writeToSocket(ifIndex int, b []byte, srcIfcIsLoopback bool) { if ifIndex != 0 { ifc, err := net.InterfaceByIndex(ifIndex) if err != nil { - c.log.Warnf("Failed to get interface interface for %d: %v", ifIndex, err) + c.log.Warnf("Failed to get interface for %d: %v", ifIndex, err) return } - if onlyLooback && ifc.Flags&net.FlagLoopback == 0 { + if srcIfcIsLoopback && ifc.Flags&net.FlagLoopback == 0 { // avoid accidentally tricking the destination that itself is the same as us c.log.Warnf("Interface is not loopback %d", ifIndex) return @@ -275,7 +319,7 @@ func (c *Conn) writeToSocket(ifIndex int, b []byte, onlyLooback bool) { return } for ifcIdx := range c.ifaces { - if onlyLooback && c.ifaces[ifcIdx].Flags&net.FlagLoopback == 0 { + if srcIfcIsLoopback && c.ifaces[ifcIdx].Flags&net.FlagLoopback == 0 { // avoid accidentally tricking the destination that itself is the same as us continue } @@ -289,11 +333,10 @@ func (c *Conn) writeToSocket(ifIndex int, b []byte, onlyLooback bool) { } } -func (c *Conn) sendAnswer(name string, ifIndex int, dst net.IP) { +func createAnswer(name string, addr net.IP) (dnsmessage.Message, error) { packedName, err := dnsmessage.NewName(name) if err != nil { - c.log.Warnf("Failed to construct mDNS packet %v", err) - return + return dnsmessage.Message{}, err } msg := dnsmessage.Message{ @@ -309,20 +352,45 @@ func (c *Conn) sendAnswer(name string, ifIndex int, dst net.IP) { Name: packedName, TTL: responseTTL, }, - Body: &dnsmessage.AResource{ - A: ipToBytes(dst), - }, }, }, } - rawAnswer, err := msg.Pack() + if ip4 := addr.To4(); ip4 != nil { + ipBuf, err := ipv4ToBytes(addr) + if err != nil { + return dnsmessage.Message{}, err + } + msg.Answers[0].Body = &dnsmessage.AResource{ + A: ipBuf, + } + } else { + ipBuf, err := ipv6ToBytes(addr) + if err != nil { + return dnsmessage.Message{}, err + } + msg.Answers[0].Body = &dnsmessage.AAAAResource{ + AAAA: ipBuf, + } + } + + return msg, nil +} + +func (c *Conn) sendAnswer(name string, ifIndex int, addr net.IP) { + answer, err := createAnswer(name, addr) + if err != nil { + c.log.Warnf("Failed to create mDNS answer %v", err) + return + } + + rawAnswer, err := answer.Pack() if err != nil { c.log.Warnf("Failed to construct mDNS packet %v", err) return } - c.writeToSocket(ifIndex, rawAnswer, dst.IsLoopback()) + c.writeToSocket(ifIndex, rawAnswer, addr.IsLoopback()) } func (c *Conn) start(inboundBufferSize int, config *Config) { //nolint gocognit @@ -348,6 +416,17 @@ func (c *Conn) start(inboundBufferSize int, config *Config) { //nolint gocognit if cm != nil { ifIndex = cm.IfIndex } + var srcIP net.IP + switch addr := src.(type) { + case *net.UDPAddr: + srcIP = addr.IP + case *net.TCPAddr: + srcIP = addr.IP + default: + c.log.Warnf("Failed to determine address type %T for source address %s", src, src) + continue + } + srcIsIPv4 := srcIP.To4() != nil func() { c.mu.RLock() @@ -372,10 +451,70 @@ func (c *Conn) start(inboundBufferSize int, config *Config) { //nolint gocognit if config.LocalAddress != nil { c.sendAnswer(q.Name.String(), ifIndex, config.LocalAddress) } else { - localAddress, err := interfaceForRemote(src.String()) - if err != nil { - c.log.Warnf("Failed to get local interface to communicate with %s: %v", src.String(), err) - continue + var localAddress net.IP + + // prefer the address of the interface if we know its index, but otherwise + // derive it from the address we read from. We do this because even if + // multicast loopback is in use or we send from a loopback interface, + // there are still cases where the IP packet will contain the wrong + // source IP (e.g. a LAN interface). + // For example, we can have a packet that has: + // Source: 192.168.65.3 + // Destination: 224.0.0.251 + // Interface Index: 1 + // Interface Addresses @ 1: [127.0.0.1/8 ::1/128] + if ifIndex != 0 { + ifc, netErr := net.InterfaceByIndex(ifIndex) + if netErr != nil { + c.log.Warnf("Failed to get interface for %d: %v", ifIndex, netErr) + continue + } + addrs, addrsErr := ifc.Addrs() + if addrsErr != nil { + c.log.Warnf("Failed to get addresses for interface %d: %v", ifIndex, addrsErr) + continue + } + if len(addrs) == 0 { + c.log.Warnf("Expected more than one address for interface %d", ifIndex) + continue + } + var selectedIP net.IP + for _, addr := range addrs { + var ip net.IP + switch addr := addr.(type) { + case *net.IPNet: + ip = addr.IP + case *net.IPAddr: + ip = addr.IP + default: + c.log.Warnf("Failed to determine address type %T from interface %d", addr, ifIndex) + continue + } + + // match up respective IP types + if ipv4 := ip.To4(); ipv4 == nil { + if srcIsIPv4 { + continue + } else if !isSupportedIPv6(ip) { + continue + } + } else if !srcIsIPv4 { + continue + } + selectedIP = ip + break + } + if selectedIP == nil { + c.log.Warnf("Failed to find suitable IP for interface %d; deriving address from source address instead", ifIndex) + } else { + localAddress = selectedIP + } + } else if ifIndex == 0 || localAddress == nil { + localAddress, err = interfaceForRemote(src.String()) + if err != nil { + c.log.Warnf("Failed to get local interface to communicate with %s: %v", src.String(), err) + continue + } } c.sendAnswer(q.Name.String(), ifIndex, localAddress) @@ -423,7 +562,7 @@ func ipFromAnswerHeader(a dnsmessage.ResourceHeader, p dnsmessage.Parser) (ip [] if err != nil { return nil, err } - ip = net.IP(resource.A[:]) + ip = resource.A[:] } else { resource, err := p.AAAAResource() if err != nil { @@ -434,3 +573,25 @@ func ipFromAnswerHeader(a dnsmessage.ResourceHeader, p dnsmessage.Parser) (ip [] return } + +// The conditions of invalidation written below are defined in +// https://tools.ietf.org/html/rfc8445#section-5.1.1.1 +func isSupportedIPv6(ip net.IP) bool { + if len(ip) != net.IPv6len || + isZeros(ip[0:12]) || // !(IPv4-compatible IPv6) + ip[0] == 0xfe && ip[1]&0xc0 == 0xc0 || // !(IPv6 site-local unicast) + ip.IsLinkLocalUnicast() || + ip.IsLinkLocalMulticast() { + return false + } + return true +} + +func isZeros(ip net.IP) bool { + for i := 0; i < len(ip); i++ { + if ip[i] != 0 { + return false + } + } + return true +} diff --git a/conn_test.go b/conn_test.go index d480361..faa6ad4 100644 --- a/conn_test.go +++ b/conn_test.go @@ -27,6 +27,22 @@ func check(err error, t *testing.T) { } } +func checkIPv4(addr net.Addr, t *testing.T) { + var ip net.IP + switch addr := addr.(type) { + case *net.IPNet: + ip = addr.IP + case *net.IPAddr: + ip = addr.IP + default: + t.Fatalf("Failed to determine address type %T", addr) + } + + if ip.To4() == nil { + t.Fatalf("expected IPv4 for answer but got %s", ip) + } +} + func createListener(t *testing.T) *net.UDPConn { addr, err := net.ResolveUDPAddr("udp", DefaultAddress) check(err, t) @@ -60,12 +76,30 @@ func TestValidCommunication(t *testing.T) { if addr.String() == localAddress { t.Fatalf("unexpected local address: %v", addr) } + checkIPv4(addr, t) _, addr, err = bServer.Query(context.TODO(), "pion-mdns-2.local") check(err, t) if addr.String() == localAddress { t.Fatalf("unexpected local address: %v", addr) } + checkIPv4(addr, t) + + // test against regression from https://github.com/pion/mdns/commit/608f20b + // where by properly sending mDNS responses to all interfaces, we significantly + // increased the chance that we send a loopback response to a Query that is + // unwillingly to use loopback addresses (the default in pion/ice). + for i := 0; i < 100; i++ { + _, addr, err = bServer.Query(context.TODO(), "pion-mdns-2.local") + check(err, t) + if addr.String() == localAddress { + t.Fatalf("unexpected local address: %v", addr) + } + if addr.String() == "127.0.0.1" { + t.Fatal("unexpected loopback") + } + checkIPv4(addr, t) + } check(aServer.Close(), t) check(bServer.Close(), t) @@ -105,6 +139,93 @@ func TestValidCommunicationWithAddressConfig(t *testing.T) { } } +func TestValidCommunicationWithLoopbackAddressConfig(t *testing.T) { + lim := test.TimeOut(time.Second * 10) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + aSock := createListener(t) + + loopbackIP := net.ParseIP("127.0.0.1") + + aServer, err := Server(ipv4.NewPacketConn(aSock), &Config{ + LocalNames: []string{"pion-mdns-1.local", "pion-mdns-2.local"}, + LocalAddress: loopbackIP, + IncludeLoopback: true, // the test would fail if this was false + }) + check(err, t) + + _, addr, err := aServer.Query(context.TODO(), "pion-mdns-1.local") + check(err, t) + if addr.String() != loopbackIP.String() { + t.Fatalf("address mismatch: expected %s, but got %v\n", localAddress, addr) + } + + check(aServer.Close(), t) +} + +func TestValidCommunicationWithLoopbackInterface(t *testing.T) { + lim := test.TimeOut(time.Second * 10) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + aSock := createListener(t) + + ifaces, err := net.Interfaces() + check(err, t) + ifacesToUse := make([]net.Interface, 0, len(ifaces)) + for _, ifc := range ifaces { + if ifc.Flags&net.FlagLoopback != net.FlagLoopback { + continue + } + ifcCopy := ifc + ifacesToUse = append(ifacesToUse, ifcCopy) + } + + // the following checks are unlikely to fail since most places where this code runs + // will have a loopback + if len(ifacesToUse) == 0 { + t.Skip("expected at least one loopback interface, but got none") + } + + aServer, err := Server(ipv4.NewPacketConn(aSock), &Config{ + LocalNames: []string{"pion-mdns-1.local", "pion-mdns-2.local"}, + IncludeLoopback: true, // the test would fail if this was false + Interfaces: ifacesToUse, + }) + check(err, t) + + _, addr, err := aServer.Query(context.TODO(), "pion-mdns-1.local") + check(err, t) + var found bool + for _, iface := range ifacesToUse { + addrs, err := iface.Addrs() + check(err, t) + for _, ifaceAddr := range addrs { + ipAddr, ok := ifaceAddr.(*net.IPNet) + if !ok { + t.Fatalf("expected *net.IPNet address for loopback but got %T", addr) + } + if addr.String() == ipAddr.IP.String() { + found = true + break + } + } + if found { + break + } + } + if !found { + t.Fatalf("address mismatch: expected loopback address, but got %v\n", addr) + } + + check(aServer.Close(), t) +} + func TestMultipleClose(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() @@ -214,40 +335,64 @@ func TestResourceParsing(t *testing.T) { } } - name, err := dnsmessage.NewName("test-server.") - if err != nil { - t.Fatal(err) - } + name := "test-server." t.Run("A Record", func(t *testing.T) { - lookForIP(dnsmessage.Message{ - Header: dnsmessage.Header{Response: true, Authoritative: true}, - Answers: []dnsmessage.Resource{ - { - Header: dnsmessage.ResourceHeader{ - Name: name, - Type: dnsmessage.TypeA, - Class: dnsmessage.ClassINET, - }, - Body: &dnsmessage.AResource{A: [4]byte{127, 0, 0, 1}}, - }, - }, - }, []byte{127, 0, 0, 1}) + answer, err := createAnswer(name, net.ParseIP("127.0.0.1")) + if err != nil { + t.Fatal(err) + } + lookForIP(answer, []byte{127, 0, 0, 1}) }) t.Run("AAAA Record", func(t *testing.T) { - lookForIP(dnsmessage.Message{ - Header: dnsmessage.Header{Response: true, Authoritative: true}, - Answers: []dnsmessage.Resource{ - { - Header: dnsmessage.ResourceHeader{ - Name: name, - Type: dnsmessage.TypeAAAA, - Class: dnsmessage.ClassINET, - }, - Body: &dnsmessage.AAAAResource{AAAA: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}}, - }, - }, - }, []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}) + answer, err := createAnswer(name, net.ParseIP("::1")) + if err != nil { + t.Fatal(err) + } + lookForIP(answer, []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}) }) } + +func TestIPToBytes(t *testing.T) { + expectedIP := []byte{127, 0, 0, 1} + actualIP4, err := ipv4ToBytes(net.ParseIP("127.0.0.1")) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(actualIP4[:], expectedIP) { + t.Fatalf("Expected(%v) and Actual(%v) IP don't match", expectedIP, actualIP4) + } + + expectedIP = []byte{0, 0, 0, 1} + actualIP4, err = ipv4ToBytes(net.ParseIP("0.0.0.1")) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(actualIP4[:], expectedIP) { + t.Fatalf("Expected(%v) and Actual(%v) IP don't match", expectedIP, actualIP4) + } + + expectedIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} + actualIP6, err := ipv6ToBytes(net.ParseIP("::1")) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(actualIP6[:], expectedIP) { + t.Fatalf("Expected(%v) and Actual(%v) IP don't match", expectedIP, actualIP6) + } + + _, err = ipv4ToBytes(net.ParseIP("::1")) + if err == nil { + t.Fatal("expected ::1 to not be output to IPv4 bytes") + } + + expectedIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 127, 0, 0, 1} + actualIP6, err = ipv6ToBytes(net.ParseIP("127.0.0.1")) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(actualIP6[:], expectedIP) { + t.Fatalf("Expected(%v) and Actual(%v) IP don't match", expectedIP, actualIP6) + } +} diff --git a/go.mod b/go.mod index d8b2977..9bf6a4f 100644 --- a/go.mod +++ b/go.mod @@ -1,9 +1,11 @@ module github.com/pion/mdns -go 1.12 +go 1.19 require ( github.com/pion/logging v0.2.2 github.com/pion/transport/v3 v3.0.1 golang.org/x/net v0.20.0 ) + +require golang.org/x/sys v0.16.0 // indirect diff --git a/go.sum b/go.sum index 1e22f4c..ba6294d 100644 --- a/go.sum +++ b/go.sum @@ -15,7 +15,6 @@ github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5t golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= -golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -44,14 +43,12 @@ golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuX golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= -golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=