diff --git a/config.go b/config.go index 50a7a57..4659e06 100644 --- a/config.go +++ b/config.go @@ -24,6 +24,9 @@ const ( // Config is used to configure a mDNS client or server. type Config struct { + // Name is the name of the client/server used for logging purposes. + Name string + // QueryInterval controls how often we sends Queries until we // get a response for the requested name QueryInterval time.Duration diff --git a/conn.go b/conn.go index f25ac55..c736814 100644 --- a/conn.go +++ b/conn.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "net" + "net/netip" "sync" "time" @@ -19,8 +20,9 @@ import ( // Conn represents a mDNS Server type Conn struct { - mu sync.RWMutex - log logging.LeveledLogger + mu sync.RWMutex + name string + log logging.LeveledLogger multicastPktConnV4 ipPacketConn multicastPktConnV6 ipPacketConn @@ -45,7 +47,7 @@ type query struct { type queryResult struct { answer dnsmessage.ResourceHeader - addr net.Addr + addr netip.Addr } const ( @@ -71,13 +73,18 @@ var ( type netInterface struct { net.Interface - ips []net.IP + ipAddrs []netip.Addr supportsV4 bool supportsV6 bool } // Server establishes a mDNS connection over an existing conn. -// Either one or both of the multicast packet conns should be provided, +// Either one or both of the multicast packet conns should be provided. +// The presence of each IP type of PacketConn will dictate what kinds +// of questions are sent for queries. That is, if an ipv6.PacketConn is +// provided, then AAAA questions will be sent. A questions will only be +// sent if an ipv4.PacketConn is also provided. In the future, we may +// add a QueryAddr method that allows specifying this more clearly. // //nolint:gocognit func Server( @@ -94,6 +101,16 @@ func Server( } log := loggerFactory.NewLogger("mdns") + c := &Conn{ + queryInterval: defaultQueryInterval, + log: log, + closed: make(chan interface{}), + } + c.name = config.Name + if c.name == "" { + c.name = fmt.Sprintf("%p", &c) + } + if multicastPktConnV4 == nil && multicastPktConnV6 == nil { return nil, errNoPacketConn } @@ -116,7 +133,7 @@ func Server( unicastConnV4, err := net.ListenUDP("udp4", addr4) if err != nil { - log.Warnf("failed to listen on unicast IPv4 %s: %s; will not be able to receive unicast responses on IPv4", addr4, err) + log.Warnf("[%s] failed to listen on unicast IPv4 %s: %s; will not be able to receive unicast responses on IPv4", c.name, addr4, err) } else { unicastPktConnV4 = ipv4.NewPacketConn(unicastConnV4) } @@ -131,7 +148,7 @@ func Server( unicastConnV6, err := net.ListenUDP("udp6", addr6) if err != nil { - log.Warnf("failed to listen on unicast IPv6 %s: %s; will not be able to receive unicast responses on IPv6", addr6, err) + log.Warnf("[%s] failed to listen on unicast IPv6 %s: %s; will not be able to receive unicast responses on IPv6", c.name, addr6, err) } else { unicastPktConnV6 = ipv6.NewPacketConn(unicastConnV6) } @@ -161,23 +178,36 @@ func Server( continue } var supportsV4, supportsV6 bool - ifcIPs := make([]net.IP, 0, len(addrs)) + ifcIPAddrs := make([]netip.Addr, 0, len(addrs)) for _, addr := range addrs { - var ip net.IP + var ipToConv net.IP switch addr := addr.(type) { case *net.IPNet: - ip = addr.IP + ipToConv = addr.IP case *net.IPAddr: - ip = addr.IP + ipToConv = addr.IP default: continue } - if ip.To4() == nil { + + ipAddr, ok := netip.AddrFromSlice(ipToConv) + if !ok { + continue + } + if multicastPktConnV4 != nil { + // don't want mapping since we also support IPv4/A + ipAddr = ipAddr.Unmap() + } + ipAddr = addrWithOptionalZone(ipAddr, ifc.Name) + + if ipAddr.Is6() && !ipAddr.Is4In6() { supportsV6 = true } else { + // we'll claim we support v4 but defer if we send it or not + // based on IPv4-to-IPv6 mapping rules later (search for Is4In6 below) supportsV4 = true } - ifcIPs = append(ifcIPs, ip) + ifcIPAddrs = append(ifcIPAddrs, ipAddr) } if !(supportsV4 || supportsV6) { continue @@ -201,23 +231,13 @@ func Server( ifacesToUse[ifc.Index] = netInterface{ Interface: ifc, - ips: ifcIPs, + ipAddrs: ifcIPAddrs, supportsV4: supportsV4, supportsV6: supportsV6, } if ifc.MTU > inboundBufferSize { inboundBufferSize = ifc.MTU } - if supportsV4 && unicastPktConnV4 != nil { - if err := unicastPktConnV4.JoinGroup(&ifc, multicastGroupAddr4); err != nil { - log.Debugf("failed to JoinGroup on unicast IPv4 connection for interface %d: %v", ifc.Index, err) - } - } - if supportsV6 && unicastPktConnV6 != nil { - if err := unicastPktConnV6.JoinGroup(&ifc, multicastGroupAddr6); err != nil { - log.Debugf("failed to JoinGroup on unicast IPv6 connection for interface %d: %v", ifc.Index, err) - } - } } if len(ifacesToUse) == 0 { @@ -248,54 +268,50 @@ func Server( localNames = append(localNames, l+".") } - c := &Conn{ - queryInterval: defaultQueryInterval, - dstAddr4: dstAddr4, - dstAddr6: dstAddr6, - localNames: localNames, - ifaces: ifacesToUse, - log: log, - closed: make(chan interface{}), - } + c.dstAddr4 = dstAddr4 + c.dstAddr6 = dstAddr6 + c.localNames = localNames + c.ifaces = ifacesToUse + if config.QueryInterval != 0 { c.queryInterval = config.QueryInterval } if multicastPktConnV4 != nil { if err := multicastPktConnV4.SetControlMessage(ipv4.FlagInterface, true); err != nil { - c.log.Warnf("failed to SetControlMessage(ipv4.FlagInterface) on multicast IPv4 PacketConn %v", err) + c.log.Warnf("[%s] failed to SetControlMessage(ipv4.FlagInterface) on multicast IPv4 PacketConn %v", c.name, err) } if err := multicastPktConnV4.SetControlMessage(ipv4.FlagDst, true); err != nil { - c.log.Warnf("failed to SetControlMessage(ipv4.FlagDst) on multicast IPv4 PacketConn %v", err) + c.log.Warnf("[%s] failed to SetControlMessage(ipv4.FlagDst) on multicast IPv4 PacketConn %v", c.name, err) } - c.multicastPktConnV4 = ipPacketConn4{multicastPktConnV4, log} + c.multicastPktConnV4 = ipPacketConn4{c.name, multicastPktConnV4, log} } if multicastPktConnV6 != nil { if err := multicastPktConnV6.SetControlMessage(ipv6.FlagInterface, true); err != nil { - c.log.Warnf("failed to SetControlMessage(ipv6.FlagInterface) on multicast IPv6 PacketConn %v", err) + c.log.Warnf("[%s] failed to SetControlMessage(ipv6.FlagInterface) on multicast IPv6 PacketConn %v", c.name, err) } if err := multicastPktConnV6.SetControlMessage(ipv6.FlagDst, true); err != nil { - c.log.Warnf("failed to SetControlMessage(ipv6.FlagInterface) on multicast IPv6 PacketConn %v", err) + c.log.Warnf("[%s] failed to SetControlMessage(ipv6.FlagInterface) on multicast IPv6 PacketConn %v", c.name, err) } - c.multicastPktConnV6 = ipPacketConn6{multicastPktConnV6, log} + c.multicastPktConnV6 = ipPacketConn6{c.name, multicastPktConnV6, log} } if unicastPktConnV4 != nil { if err := unicastPktConnV4.SetControlMessage(ipv4.FlagInterface, true); err != nil { - c.log.Warnf("failed to SetControlMessage(ipv4.FlagInterface) on unicast IPv4 PacketConn %v", err) + c.log.Warnf("[%s] failed to SetControlMessage(ipv4.FlagInterface) on unicast IPv4 PacketConn %v", c.name, err) } if err := unicastPktConnV4.SetControlMessage(ipv4.FlagDst, true); err != nil { - c.log.Warnf("failed to SetControlMessage(ipv4.FlagInterface) on unicast IPv4 PacketConn %v", err) + c.log.Warnf("[%s] failed to SetControlMessage(ipv4.FlagInterface) on unicast IPv4 PacketConn %v", c.name, err) } - c.unicastPktConnV4 = ipPacketConn4{unicastPktConnV4, log} + c.unicastPktConnV4 = ipPacketConn4{c.name, unicastPktConnV4, log} } if unicastPktConnV6 != nil { if err := unicastPktConnV6.SetControlMessage(ipv6.FlagInterface, true); err != nil { - c.log.Warnf("failed to SetControlMessage(ipv6.FlagInterface) on unicast IPv6 PacketConn %v", err) + c.log.Warnf("[%s] failed to SetControlMessage(ipv6.FlagInterface) on unicast IPv6 PacketConn %v", c.name, err) } if err := unicastPktConnV6.SetControlMessage(ipv6.FlagDst, true); err != nil { - c.log.Warnf("failed to SetControlMessage(ipv6.FlagInterface) on unicast IPv6 PacketConn %v", err) + c.log.Warnf("[%s] failed to SetControlMessage(ipv6.FlagInterface) on unicast IPv6 PacketConn %v", c.name, err) } - c.unicastPktConnV6 = ipPacketConn6{unicastPktConnV6, log} + c.unicastPktConnV6 = ipPacketConn6{c.name, unicastPktConnV6, log} } if config.IncludeLoopback { @@ -303,22 +319,22 @@ func Server( // further out into the network stack. if multicastPktConnV4 != nil { if err := multicastPktConnV4.SetMulticastLoopback(true); err != nil { - c.log.Warnf("failed to SetMulticastLoopback(true) on multicast IPv4 PacketConn %v; this may cause inefficient network path communications", err) + c.log.Warnf("[%s] failed to SetMulticastLoopback(true) on multicast IPv4 PacketConn %v; this may cause inefficient network path c.name,communications", err) } } if multicastPktConnV6 != nil { if err := multicastPktConnV6.SetMulticastLoopback(true); err != nil { - c.log.Warnf("failed to SetMulticastLoopback(true) on multicast IPv6 PacketConn %v; this may cause inefficient network path communications", err) + c.log.Warnf("[%s] failed to SetMulticastLoopback(true) on multicast IPv6 PacketConn %v; this may cause inefficient network path c.name,communications", err) } } if unicastPktConnV4 != nil { if err := unicastPktConnV4.SetMulticastLoopback(true); err != nil { - c.log.Warnf("failed to SetMulticastLoopback(true) on unicast IPv4 PacketConn %v; this may cause inefficient network path communications", err) + c.log.Warnf("[%s] failed to SetMulticastLoopback(true) on unicast IPv4 PacketConn %v; this may cause inefficient network path c.name,communications", err) } } if unicastPktConnV6 != nil { if err := unicastPktConnV6.SetMulticastLoopback(true); err != nil { - c.log.Warnf("failed to SetMulticastLoopback(true) on unicast IPv6 PacketConn %v; this may cause inefficient network path communications", err) + c.log.Warnf("[%s] failed to SetMulticastLoopback(true) on unicast IPv6 PacketConn %v; this may cause inefficient network path c.name,communications", err) } } } @@ -374,17 +390,32 @@ func (c *Conn) Close() error { rtrn := errFailedToClose for _, err := range errs { - rtrn = fmt.Errorf("%w\n%s", err, rtrn.Error()) + rtrn = fmt.Errorf("%w\n%w", err, rtrn) } return rtrn } // Query sends mDNS Queries for the following name until // either the Context is canceled/expires or we get a result +// +// Deprecated: Use QueryAddr instead as it supports the easier to use netip.Addr. func (c *Conn) Query(ctx context.Context, name string) (dnsmessage.ResourceHeader, net.Addr, error) { + header, addr, err := c.QueryAddr(ctx, name) + if err != nil { + return header, nil, err + } + return header, &net.IPAddr{ + IP: addr.AsSlice(), + Zone: addr.Zone(), + }, nil +} + +// QueryAddr sends mDNS Queries for the following name until +// either the Context is canceled/expires or we get a result +func (c *Conn) QueryAddr(ctx context.Context, name string) (dnsmessage.ResourceHeader, netip.Addr, error) { select { case <-c.closed: - return dnsmessage.ResourceHeader{}, nil, errConnectionClosed + return dnsmessage.ResourceHeader{}, netip.Addr{}, errConnectionClosed default: } @@ -415,7 +446,7 @@ func (c *Conn) Query(ctx context.Context, name string) (dnsmessage.ResourceHeade case <-ticker.C: c.sendQuestion(nameWithSuffix) case <-c.closed: - return dnsmessage.ResourceHeader{}, nil, errConnectionClosed + return dnsmessage.ResourceHeader{}, netip.Addr{}, errConnectionClosed case res := <-queryChan: // Given https://datatracker.ietf.org/doc/html/draft-ietf-mmusic-mdns-ice-candidates#section-3.2.2-2 // An ICE agent SHOULD ignore candidates where the hostname resolution returns more than one IP address. @@ -424,45 +455,62 @@ func (c *Conn) Query(ctx context.Context, name string) (dnsmessage.ResourceHeade // one is better than the other (e.g. localhost vs LAN). return res.answer, res.addr, nil case <-ctx.Done(): - return dnsmessage.ResourceHeader{}, nil, errContextElapsed + return dnsmessage.ResourceHeader{}, netip.Addr{}, errContextElapsed } } } type ipToBytesError struct { - ip net.IP + addr netip.Addr expectedType string } func (err ipToBytesError) Error() string { - return fmt.Sprintf("ip (%s) is not %s", err.ip, err.expectedType) + return fmt.Sprintf("ip (%s) is not %s", err.addr, err.expectedType) } -func ipv4ToBytes(ip net.IP) ([4]byte, error) { - rawIP := ip.To4() - if rawIP == nil { - return [4]byte{}, ipToBytesError{ip, "IPv4"} +// assumes ipv4-to-ipv6 mapping has been checked +func ipv4ToBytes(ipAddr netip.Addr) ([4]byte, error) { + if !ipAddr.Is4() { + return [4]byte{}, ipToBytesError{ipAddr, "IPv4"} + } + + md, err := ipAddr.MarshalBinary() + if err != nil { + return [4]byte{}, err } // net.IPs are stored in big endian / network byte order var out [4]byte - copy(out[:], rawIP[:]) + copy(out[:], md) return out, nil } -func ipv6ToBytes(ip net.IP) ([16]byte, error) { - rawIP := ip.To16() - if rawIP == nil { - return [16]byte{}, ipToBytesError{ip, "IPv6"} +// assumes ipv4-to-ipv6 mapping has been checked +func ipv6ToBytes(ipAddr netip.Addr) ([16]byte, error) { + if !ipAddr.Is6() { + return [16]byte{}, ipToBytesError{ipAddr, "IPv6"} + } + md, err := ipAddr.MarshalBinary() + if err != nil { + return [16]byte{}, err } // net.IPs are stored in big endian / network byte order var out [16]byte - copy(out[:], rawIP[:]) + copy(out[:], md) return out, nil } -func interfaceForRemote(remote string) (net.IP, error) { +type ipToAddrError struct { + ip []byte +} + +func (err ipToAddrError) Error() string { + return fmt.Sprintf("failed to convert ip address '%s' to netip.Addr", err.ip) +} + +func interfaceForRemote(remote string) (*netip.Addr, error) { conn, err := net.Dial("udp", remote) if err != nil { return nil, err @@ -477,7 +525,12 @@ func interfaceForRemote(remote string) (net.IP, error) { return nil, err } - return localAddr.IP, nil + ipAddr, ok := netip.AddrFromSlice(localAddr.IP) + if !ok { + return nil, ipToAddrError{localAddr.IP} + } + ipAddr = addrWithOptionalZone(ipAddr, localAddr.Zone) + return &ipAddr, nil } type writeType byte @@ -490,7 +543,7 @@ const ( func (c *Conn) sendQuestion(name string) { packedName, err := dnsmessage.NewName(name) if err != nil { - c.log.Warnf("failed to construct mDNS packet %v", err) + c.log.Warnf("[%s] failed to construct mDNS packet %v", c.name, err) return } @@ -533,14 +586,22 @@ func (c *Conn) sendQuestion(name string) { rawQuery, err := msg.Pack() if err != nil { - c.log.Warnf("failed to construct mDNS packet %v", err) + c.log.Warnf("[%s] failed to construct mDNS packet %v", c.name, err) return } - c.writeToSocket(0, rawQuery, false, writeTypeQuestion, nil) + c.writeToSocket(0, rawQuery, false, false, writeTypeQuestion, nil) } -func (c *Conn) writeToSocket(ifIndex int, b []byte, hasLoopbackData bool, wType writeType, unicastDst *net.UDPAddr) { //nolint:gocognit +//nolint:gocognit +func (c *Conn) writeToSocket( + ifIndex int, + b []byte, + hasLoopbackData bool, + hasIPv6Zone bool, + wType writeType, + unicastDst *net.UDPAddr, +) { var dst4, dst6 net.Addr if wType == writeTypeAnswer { if unicastDst == nil { @@ -557,31 +618,35 @@ func (c *Conn) writeToSocket(ifIndex int, b []byte, hasLoopbackData bool, wType if ifIndex != 0 { if wType == writeTypeQuestion { - c.log.Errorf("Unexpected question using specific interface index %d; dropping question", ifIndex) + c.log.Errorf("[%s] Unexpected question using specific interface index %d; dropping question", c.name, ifIndex) return } ifc, ok := c.ifaces[ifIndex] if !ok { - c.log.Warnf("no interface for %d", ifIndex) + c.log.Warnf("[%s] no interface for %d", c.name, ifIndex) return } if hasLoopbackData && ifc.Flags&net.FlagLoopback == 0 { // avoid accidentally tricking the destination that itself is the same as us - c.log.Debugf("interface is not loopback %d", ifIndex) + c.log.Debugf("[%s] interface is not loopback %d", c.name, ifIndex) return } - c.log.Debugf("writing answer to IPv4: %v, IPv6: %v", dst4, dst6) + c.log.Debugf("[%s] writing answer to IPv4: %v, IPv6: %v", c.name, dst4, dst6) if ifc.supportsV4 && c.multicastPktConnV4 != nil && dst4 != nil { - if _, err := c.multicastPktConnV4.WriteTo(b, &ifc.Interface, nil, dst4); err != nil { - c.log.Warnf("failed to send mDNS packet on IPv4 interface %d: %v", ifIndex, err) + if !hasIPv6Zone { + if _, err := c.multicastPktConnV4.WriteTo(b, &ifc.Interface, nil, dst4); err != nil { + c.log.Warnf("[%s] failed to send mDNS packet on IPv4 interface %d: %v", c.name, ifIndex, err) + } + } else { + c.log.Debugf("[%s] refusing to send mDNS packet with IPv6 zone over IPv4", c.name) } } if ifc.supportsV6 && c.multicastPktConnV6 != nil && dst6 != nil { if _, err := c.multicastPktConnV6.WriteTo(b, &ifc.Interface, nil, dst6); err != nil { - c.log.Warnf("failed to send mDNS packet on IPv6 interface %d: %v", ifIndex, err) + c.log.Warnf("[%s] failed to send mDNS packet on IPv6 interface %d: %v", c.name, ifIndex, err) } } @@ -590,7 +655,7 @@ func (c *Conn) writeToSocket(ifIndex int, b []byte, hasLoopbackData bool, wType for ifcIdx := range c.ifaces { ifc := c.ifaces[ifcIdx] if hasLoopbackData { - c.log.Debug("Refusing to send loopback data with non-specific interface") + c.log.Debugf("[%s] Refusing to send loopback data with non-specific interface", c.name) continue } @@ -600,48 +665,52 @@ func (c *Conn) writeToSocket(ifIndex int, b []byte, hasLoopbackData bool, wType // conn here, we'd be writing from a specific multicast address which won't be able to receive unicast // traffic (it only works when listening on 0.0.0.0/[::]). if c.unicastPktConnV4 == nil && c.unicastPktConnV6 == nil { - c.log.Debugf("writing question to multicast IPv4/6 %s", c.dstAddr4) + c.log.Debugf("[%s] writing question to multicast IPv4/6 %s", c.name, c.dstAddr4) if ifc.supportsV4 && c.multicastPktConnV4 != nil { if _, err := c.multicastPktConnV4.WriteTo(b, &ifc.Interface, nil, c.dstAddr4); err != nil { - c.log.Warnf("failed to send mDNS packet (multicast) on IPv4 interface %d: %v", ifc.Index, err) + c.log.Warnf("[%s] failed to send mDNS packet (multicast) on IPv4 interface %d: %v", c.name, ifc.Index, err) } } if ifc.supportsV6 && c.multicastPktConnV6 != nil { if _, err := c.multicastPktConnV6.WriteTo(b, &ifc.Interface, nil, c.dstAddr6); err != nil { - c.log.Warnf("failed to send mDNS packet (multicast) on IPv6 interface %d: %v", ifc.Index, err) + c.log.Warnf("[%s] failed to send mDNS packet (multicast) on IPv6 interface %d: %v", c.name, ifc.Index, err) } } } if ifc.supportsV4 && c.unicastPktConnV4 != nil { - c.log.Debugf("writing question to unicast IPv4 %s", c.dstAddr4) + c.log.Debugf("[%s] writing question to unicast IPv4 %s", c.name, c.dstAddr4) if _, err := c.unicastPktConnV4.WriteTo(b, &ifc.Interface, nil, c.dstAddr4); err != nil { - c.log.Warnf("failed to send mDNS packet (unicast) on interface %d: %v", ifc.Index, err) + c.log.Warnf("[%s] failed to send mDNS packet (unicast) on interface %d: %v", c.name, ifc.Index, err) } } if ifc.supportsV6 && c.unicastPktConnV6 != nil { - c.log.Debugf("writing question to unicast IPv6 %s", c.dstAddr6) + c.log.Debugf("[%s] writing question to unicast IPv6 %s", c.name, c.dstAddr6) if _, err := c.unicastPktConnV6.WriteTo(b, &ifc.Interface, nil, c.dstAddr6); err != nil { - c.log.Warnf("failed to send mDNS packet (unicast) on interface %d: %v", ifc.Index, err) + c.log.Warnf("[%s] failed to send mDNS packet (unicast) on interface %d: %v", c.name, ifc.Index, err) } } } else { - c.log.Debugf("writing answer to IPv4: %s, IPv6: %s", dst4, dst6) + c.log.Debugf("[%s] writing answer to IPv4: %s, IPv6: %s", c.name, dst4, dst6) if ifc.supportsV4 && c.multicastPktConnV4 != nil && dst4 != nil { - if _, err := c.multicastPktConnV4.WriteTo(b, &ifc.Interface, nil, dst4); err != nil { - c.log.Warnf("failed to send mDNS packet (multicast) on IPv4 interface %d: %v", ifIndex, err) + if !hasIPv6Zone { + if _, err := c.multicastPktConnV4.WriteTo(b, &ifc.Interface, nil, dst4); err != nil { + c.log.Warnf("[%s] failed to send mDNS packet (multicast) on IPv4 interface %d: %v", c.name, ifIndex, err) + } + } else { + c.log.Debugf("[%s] refusing to send mDNS packet with IPv6 zone over IPv4", c.name) } } if ifc.supportsV6 && c.multicastPktConnV6 != nil && dst6 != nil { if _, err := c.multicastPktConnV6.WriteTo(b, &ifc.Interface, nil, dst6); err != nil { - c.log.Warnf("failed to send mDNS packet (multicast) on IPv6 interface %d: %v", ifIndex, err) + c.log.Warnf("[%s] failed to send mDNS packet (multicast) on IPv6 interface %d: %v", c.name, ifIndex, err) } } } } } -func createAnswer(id uint16, name string, addr net.IP) (dnsmessage.Message, error) { +func createAnswer(id uint16, name string, addr netip.Addr) (dnsmessage.Message, error) { packedName, err := dnsmessage.NewName(name) if err != nil { return dnsmessage.Message{}, err @@ -664,7 +733,7 @@ func createAnswer(id uint16, name string, addr net.IP) (dnsmessage.Message, erro }, } - if len(addr) == net.IPv4len { + if addr.Is4() { ipBuf, err := ipv4ToBytes(addr) if err != nil { return dnsmessage.Message{}, err @@ -673,7 +742,8 @@ func createAnswer(id uint16, name string, addr net.IP) (dnsmessage.Message, erro msg.Answers[0].Body = &dnsmessage.AResource{ A: ipBuf, } - } else if len(addr) == net.IPv6len { + } else if addr.Is6() { + // we will lose the zone here, but the receiver can reconstruct it ipBuf, err := ipv6ToBytes(addr) if err != nil { return dnsmessage.Message{}, err @@ -687,20 +757,27 @@ func createAnswer(id uint16, name string, addr net.IP) (dnsmessage.Message, erro return msg, nil } -func (c *Conn) sendAnswer(queryID uint16, name string, ifIndex int, result net.IP, dst *net.UDPAddr) { +func (c *Conn) sendAnswer(queryID uint16, name string, ifIndex int, result netip.Addr, dst *net.UDPAddr) { answer, err := createAnswer(queryID, name, result) if err != nil { - c.log.Warnf("failed to create mDNS answer %v", err) + c.log.Warnf("[%s] failed to create mDNS answer %v", c.name, err) return } rawAnswer, err := answer.Pack() if err != nil { - c.log.Warnf("failed to construct mDNS packet %v", err) + c.log.Warnf("[%s] failed to construct mDNS packet %v", c.name, err) return } - c.writeToSocket(ifIndex, rawAnswer, result.IsLoopback(), writeTypeAnswer, dst) + c.writeToSocket( + ifIndex, + rawAnswer, + result.IsLoopback(), + result.Is6() && result.Zone() != "", + writeTypeAnswer, + dst, + ) } type ipControlMessage struct { @@ -715,6 +792,7 @@ type ipPacketConn interface { } type ipPacketConn4 struct { + name string conn *ipv4.PacketConn log logging.LeveledLogger } @@ -735,7 +813,7 @@ func (c ipPacketConn4) WriteTo(b []byte, via *net.Interface, cm *ipControlMessag } } if err := c.conn.SetMulticastInterface(via); err != nil { - c.log.Warnf("failed to set multicast interface for %d: %v", via.Index, err) + c.log.Warnf("[%s] failed to set multicast interface for %d: %v", c.name, via.Index, err) return 0, err } return c.conn.WriteTo(b, cm4, dst) @@ -746,6 +824,7 @@ func (c ipPacketConn4) Close() error { } type ipPacketConn6 struct { + name string conn *ipv6.PacketConn log logging.LeveledLogger } @@ -766,7 +845,7 @@ func (c ipPacketConn6) WriteTo(b []byte, via *net.Interface, cm *ipControlMessag } } if err := c.conn.SetMulticastInterface(via); err != nil { - c.log.Warnf("failed to set multicast interface for %d: %v", via.Index, err) + c.log.Warnf("[%s] failed to set multicast interface for %d: %v", c.name, via.Index, err) return 0, err } return c.conn.WriteTo(b, cm6, dst) @@ -786,27 +865,29 @@ func (c *Conn) readLoop(name string, pktConn ipPacketConn, inboundBufferSize int if errors.Is(err, net.ErrClosed) { return } - c.log.Warnf("failed to ReadFrom %q %v", src, err) + c.log.Warnf("[%s] failed to ReadFrom %q %v", c.name, src, err) continue } - c.log.Debugf("got read on %s from %s", name, src) + c.log.Debugf("[%s] got read on %s from %s", c.name, name, src) var ifIndex int var pktDst net.IP if cm != nil { ifIndex = cm.IfIndex pktDst = cm.Dst + } else { + ifIndex = -1 } srcAddr, ok := src.(*net.UDPAddr) if !ok { - c.log.Warnf("expected source address %s to be UDP but got %", src, src) + c.log.Warnf("[%s] expected source address %s to be UDP but got %", c.name, src, src) continue } func() { header, err := p.Start(b[:n]) if err != nil { - c.log.Warnf("failed to parse mDNS packet %v", err) + c.log.Warnf("[%s] failed to parse mDNS packet %v", c.name, err) return } @@ -815,7 +896,7 @@ func (c *Conn) readLoop(name string, pktConn ipPacketConn, inboundBufferSize int if errors.Is(err, dnsmessage.ErrSectionDone) { break } else if err != nil { - c.log.Warnf("failed to parse mDNS packet %v", err) + c.log.Warnf("[%s] failed to parse mDNS packet %v", c.name, err) return } @@ -842,9 +923,20 @@ func (c *Conn) readLoop(name string, pktConn ipPacketConn, inboundBufferSize int for _, localName := range c.localNames { if localName == q.Name.String() { - var localAddress net.IP + var localAddress *netip.Addr if config.LocalAddress != nil { - localAddress = config.LocalAddress + // this means the LocalAddress does not support link-local since + // we have no zone to set here. + ipAddr, ok := netip.AddrFromSlice(config.LocalAddress) + if !ok { + c.log.Warnf("[%s] failed to convert config.LocalAddress '%s' to netip.Addr", c.name, config.LocalAddress) + continue + } + if c.multicastPktConnV4 != nil { + // don't want mapping since we also support IPv4/A + ipAddr = ipAddr.Unmap() + } + localAddress = &ipAddr } else { // prefer the address of the interface if we know its index, but otherwise // derive it from the address we read from. We do this because even if @@ -856,65 +948,80 @@ func (c *Conn) readLoop(name string, pktConn ipPacketConn, inboundBufferSize int // Destination: 224.0.0.251 // Interface Index: 1 // Interface Addresses @ 1: [127.0.0.1/8 ::1/128] - if ifIndex != 0 { + if ifIndex != -1 { ifc, ok := c.ifaces[ifIndex] if !ok { - c.log.Warnf("no interface for %d", ifIndex) + c.log.Warnf("[%s] no interface for %d", c.name, ifIndex) return } - var selectedIP net.IP - for _, ip := range ifc.ips { - ipCopy := ip + var selectedAddr *netip.Addr + for _, addr := range ifc.ipAddrs { + addrCopy := addr // match up respective IP types based on question if queryWantsV4 { - if ipv4 := ipCopy.To4(); ipv4 == nil { + if addrCopy.Is4In6() { + // we may allow 4-in-6, but the question wants an A record + addrCopy = addrCopy.Unmap() + } + if !addrCopy.Is4() { continue } } else { // queryWantsV6 - if ipv6 := ipCopy.To16(); ipv6 == nil { + if !addrCopy.Is6() { continue - } else if !isSupportedIPv6(ipv6, c.multicastPktConnV4 == nil) { - c.log.Debugf("interface %d address not a supported IPv6 address %s", ifIndex, ipCopy) + } + if !isSupportedIPv6(addrCopy, c.multicastPktConnV4 == nil) { + c.log.Debugf("[%s] interface %d address not a supported IPv6 address %s", ifIndex, c.name, &addrCopy) continue } } - selectedIP = ipCopy + selectedAddr = &addrCopy break } - if selectedIP == nil { - c.log.Debugf("failed to find suitable IP for interface %d; deriving address from source address instead", ifIndex) + if selectedAddr == nil { + c.log.Debugf("[%s] failed to find suitable IP for interface %d; deriving address from source address c.name,instead", ifIndex) } else { - localAddress = selectedIP + localAddress = selectedAddr } } - if ifIndex == 0 || localAddress == nil { + if ifIndex == -1 || localAddress == nil { localAddress, err = interfaceForRemote(src.String()) if err != nil { - c.log.Warnf("failed to get local interface to communicate with %s: %v", src.String(), err) + c.log.Warnf("[%s] failed to get local interface to communicate with %s: %v", c.name, src.String(), err) continue } } } if queryWantsV4 { - localAddress = localAddress.To4() - if localAddress == nil { - c.log.Debugf("have IPv6 address %s to respond with but not question is for A not AAAA", localAddress) + if !localAddress.Is4() { + c.log.Debugf("[%s] have IPv6 address %s to respond with but not question is for A not c.name,AAAA", localAddress) continue } } else { - localAddress = localAddress.To16() - if localAddress == nil { - c.log.Debugf("have IPv4 address %s to respond with but not question is for AAAA not A", localAddress) + if !localAddress.Is6() { + c.log.Debugf("[%s] have IPv4 address %s to respond with but not question is for AAAA not c.name,A", localAddress) continue } - if !isSupportedIPv6(localAddress, c.multicastPktConnV4 == nil) { - c.log.Debugf("got local interface address but not a supported IPv6 address %s", localAddress) + if !isSupportedIPv6(*localAddress, c.multicastPktConnV4 == nil) { + c.log.Debugf("[%s] got local interface address but not a supported IPv6 address %v", c.name, localAddress) continue } } - c.sendAnswer(header.ID, q.Name.String(), ifIndex, localAddress, dst) + + if dst != nil && len(dst.IP) == net.IPv4len && + localAddress.Is6() && + localAddress.Zone() != "" && + (localAddress.IsLinkLocalUnicast() || localAddress.IsLinkLocalMulticast()) { + // This case happens when multicast v4 picks up an AAAA question that has a zone + // in the address. Since we cannot send this zone over DNS (it's meaningless), + // the other side can only infer this via the response interface on the other + // side (some IPv6 interface). + c.log.Debugf("[%s] refusing to send link-local address %s to an IPv4 destination %s", c.name, localAddress, dst) + continue + } + c.sendAnswer(header.ID, q.Name.String(), ifIndex, *localAddress, dst) } } } @@ -925,7 +1032,7 @@ func (c *Conn) readLoop(name string, pktConn ipPacketConn, inboundBufferSize int return } if err != nil { - c.log.Warnf("failed to parse mDNS packet %v", err) + c.log.Warnf("[%s] failed to parse mDNS packet %v", c.name, err) return } @@ -942,16 +1049,22 @@ func (c *Conn) readLoop(name string, pktConn ipPacketConn, inboundBufferSize int for _, query := range queries { queryCopy := query if queryCopy.nameWithSuffix == a.Name.String() { - ip, err := ipFromAnswerHeader(a, p) + addr, err := addrFromAnswerHeader(a, p) if err != nil { - c.log.Warnf("failed to parse mDNS answer %v", err) + c.log.Warnf("[%s] failed to parse mDNS answer %v", c.name, err) return } + resultAddr := *addr + // DNS records don't contain IPv6 zones. + // We're trusting that since we're on the same link, that we will only + // be sent link-local addresses from that source's interface's address. + // If it's not present, we're out of luck since we cannot rely on the + // interface zone to be the same as the source's. + resultAddr = addrWithOptionalZone(resultAddr, srcAddr.Zone) + select { - case queryCopy.queryResultChan <- queryResult{a, &net.IPAddr{ - IP: ip, - }}: + case queryCopy.queryResultChan <- queryResult{a, resultAddr}: answered = append(answered, queryCopy) default: } @@ -1035,38 +1148,51 @@ func (c *Conn) start(started chan<- struct{}, inboundBufferSize int, config *Con } } -func ipFromAnswerHeader(a dnsmessage.ResourceHeader, p dnsmessage.Parser) (ip []byte, err error) { +func addrFromAnswerHeader(a dnsmessage.ResourceHeader, p dnsmessage.Parser) (addr *netip.Addr, err error) { if a.Type == dnsmessage.TypeA { resource, err := p.AResource() if err != nil { return nil, err } - ip = resource.A[:] + ipAddr, ok := netip.AddrFromSlice(resource.A[:]) + if !ok { + return nil, fmt.Errorf("failed to convert A record: %w", ipToAddrError{resource.A[:]}) + } + ipAddr = ipAddr.Unmap() // do not want 4-in-6 + addr = &ipAddr } else { resource, err := p.AAAAResource() if err != nil { return nil, err } - ip = resource.AAAA[:] + ipAddr, ok := netip.AddrFromSlice(resource.AAAA[:]) + if !ok { + return nil, fmt.Errorf("failed to convert AAAA record: %w", ipToAddrError{resource.AAAA[:]}) + } + addr = &ipAddr } return } -func isSupportedIPv6(ip net.IP, ipv6Only bool) bool { - if len(ip) != net.IPv6len || - // IPv4-mapped IPv6 addresses cannot be connected to - (!ipv6Only && isZeros(ip[0:10]) && ip[10] == 0xff && ip[11] == 0xff) { +func isSupportedIPv6(addr netip.Addr, ipv6Only bool) bool { + if !addr.Is6() { + return false + } + // IPv4-mapped-IPv6 addresses cannot be connected to unless + // unmapped. + if !ipv6Only && addr.Is4In6() { return false } return true } -func isZeros(ip net.IP) bool { - for i := 0; i < len(ip); i++ { - if ip[i] != 0 { - return false - } +func addrWithOptionalZone(addr netip.Addr, zone string) netip.Addr { + if zone == "" { + return addr } - return true + if addr.Is6() && (addr.IsLinkLocalUnicast() || addr.IsLinkLocalMulticast()) { + return addr.WithZone(zone) + } + return addr } diff --git a/conn_test.go b/conn_test.go index c01a568..a282be6 100644 --- a/conn_test.go +++ b/conn_test.go @@ -11,6 +11,7 @@ import ( "context" "errors" "net" + "net/netip" "runtime" "testing" "time" @@ -30,37 +31,17 @@ func check(err error, t *testing.T) { } } -func checkIPv4(addr net.Addr, t *testing.T) { +func checkIPv4(addr netip.Addr, t *testing.T) { t.Helper() - var ip net.IP - switch addr := addr.(type) { - case *net.IPNet: - ip = addr.IP - case *net.IPAddr: - ip = addr.IP - default: - t.Fatalf("Failed to determine address type %T", addr) - } - - if ip.To4() == nil { - t.Fatalf("expected IPv4 for answer but got %s", ip) + if !addr.Is4() { + t.Fatalf("expected IPv4 for answer but got %s", addr) } } -func checkIPv6(addr net.Addr, t *testing.T) { +func checkIPv6(addr netip.Addr, t *testing.T) { t.Helper() - var ip net.IP - switch addr := addr.(type) { - case *net.IPNet: - ip = addr.IP - case *net.IPAddr: - ip = addr.IP - default: - t.Fatalf("Failed to determine address type %T", addr) - } - - if ip.To16() == nil { - t.Fatalf("expected IPv6 for answer but got %s", ip) + if !addr.Is6() { + t.Fatalf("expected IPv6 for answer but got %s", addr) } } @@ -102,14 +83,14 @@ func TestValidCommunication(t *testing.T) { bServer, err := Server(ipv4.NewPacketConn(bSock), nil, &Config{}) check(err, t) - _, addr, err := bServer.Query(context.TODO(), "pion-mdns-1.local") + _, addr, err := bServer.QueryAddr(context.TODO(), "pion-mdns-1.local") check(err, t) if addr.String() == localAddress { t.Fatalf("unexpected local address: %v", addr) } checkIPv4(addr, t) - _, addr, err = bServer.Query(context.TODO(), "pion-mdns-2.local") + _, addr, err = bServer.QueryAddr(context.TODO(), "pion-mdns-2.local") check(err, t) if addr.String() == localAddress { t.Fatalf("unexpected local address: %v", addr) @@ -121,7 +102,7 @@ func TestValidCommunication(t *testing.T) { // increased the chance that we send a loopback response to a Query that is // unwillingly to use loopback addresses (the default in pion/ice). for i := 0; i < 100; i++ { - _, addr, err = bServer.Query(context.TODO(), "pion-mdns-2.local") + _, addr, err = bServer.QueryAddr(context.TODO(), "pion-mdns-2.local") check(err, t) if addr.String() == localAddress { t.Fatalf("unexpected local address: %v", addr) @@ -158,7 +139,7 @@ func TestValidCommunicationWithAddressConfig(t *testing.T) { }) check(err, t) - _, addr, err := aServer.Query(context.TODO(), "pion-mdns-1.local") + _, addr, err := aServer.QueryAddr(context.TODO(), "pion-mdns-1.local") check(err, t) if addr.String() != localAddress { t.Fatalf("address mismatch: expected %s, but got %v\n", localAddress, addr) @@ -188,7 +169,7 @@ func TestValidCommunicationWithLoopbackAddressConfig(t *testing.T) { }) check(err, t) - _, addr, err := aServer.Query(context.TODO(), "pion-mdns-1.local") + _, addr, err := aServer.QueryAddr(context.TODO(), "pion-mdns-1.local") check(err, t) if addr.String() != loopbackIP.String() { t.Fatalf("address mismatch: expected %s, but got %v\n", localAddress, addr) @@ -230,7 +211,7 @@ func TestValidCommunicationWithLoopbackInterface(t *testing.T) { }) check(err, t) - _, addr, err := aServer.Query(context.TODO(), "pion-mdns-1.local") + _, addr, err := aServer.QueryAddr(context.TODO(), "pion-mdns-1.local") check(err, t) var found bool for _, iface := range ifacesToUse { @@ -285,7 +266,7 @@ func TestValidCommunicationIPv6(t *testing.T) { bServer, err := Server(nil, ipv6.NewPacketConn(bSock), &Config{}) check(err, t) - header, addr, err := bServer.Query(context.TODO(), "pion-mdns-1.local") + header, addr, err := bServer.QueryAddr(context.TODO(), "pion-mdns-1.local") check(err, t) if header.Type != dnsmessage.TypeAAAA { t.Fatalf("expected AAAA but got %s", header.Type) @@ -295,8 +276,15 @@ func TestValidCommunicationIPv6(t *testing.T) { t.Fatalf("unexpected local address: %v", addr) } checkIPv6(addr, t) + if addr.Is4In6() { + // probably within docker + t.Logf("address %s is an IPv4-to-IPv6 mapped address even though the stack is IPv6", addr) + } + if !addr.Is4In6() && addr.Zone() == "" { + t.Fatalf("expected IPv6 to have zone but got %s", addr) + } - header, addr, err = bServer.Query(context.TODO(), "pion-mdns-2.local") + header, addr, err = bServer.QueryAddr(context.TODO(), "pion-mdns-2.local") check(err, t) if header.Type != dnsmessage.TypeAAAA { t.Fatalf("expected AAAA but got %s", header.Type) @@ -306,6 +294,9 @@ func TestValidCommunicationIPv6(t *testing.T) { t.Fatalf("unexpected local address: %v", addr) } checkIPv6(addr, t) + if !addr.Is4In6() && addr.Zone() == "" { + t.Fatalf("expected IPv6 to have zone but got %s", addr) + } check(aServer.Close(), t) check(bServer.Close(), t) @@ -342,14 +333,14 @@ func TestValidCommunicationIPv46(t *testing.T) { bServer, err := Server(ipv4.NewPacketConn(bSock4), ipv6.NewPacketConn(bSock6), &Config{}) check(err, t) - _, addr, err := bServer.Query(context.TODO(), "pion-mdns-1.local") + _, addr, err := bServer.QueryAddr(context.TODO(), "pion-mdns-1.local") check(err, t) if addr.String() == localAddress { t.Fatalf("unexpected local address: %v", addr) } - _, addr, err = bServer.Query(context.TODO(), "pion-mdns-2.local") + _, addr, err = bServer.QueryAddr(context.TODO(), "pion-mdns-2.local") check(err, t) if addr.String() == localAddress { t.Fatalf("unexpected local address: %v", addr) @@ -380,20 +371,31 @@ func TestValidCommunicationIPv46Mixed(t *testing.T) { aSock4 := createListener4(t) bSock6 := createListener6(t) + // we can always send from a 6-only server to a 4-only server but not always + // the other way around because the IPv4-only server will only listen + // on multicast for IPv4 questions, so it will never see an IPv6 originated + // question that contains required information to respond (the zone, if link-local). + // Therefore, the IPv4 server will refuse answering AAAA responses over + // unicast/multicast IPv4 if the answer is an IPv6 link-local address. This is basically + // the majority of cases unless a LocalAddress is set on the Config. aServer, err := Server(ipv4.NewPacketConn(aSock4), nil, &Config{ - LocalNames: []string{"pion-mdns-1.local"}, + Name: "aServer", }) check(err, t) - bServer, err := Server(nil, ipv6.NewPacketConn(bSock6), &Config{}) + bServer, err := Server(nil, ipv6.NewPacketConn(bSock6), &Config{ + Name: "bServer", + LocalNames: []string{"pion-mdns-1.local"}, + }) check(err, t) - header, addr, err := bServer.Query(context.TODO(), "pion-mdns-1.local") + header, addr, err := aServer.QueryAddr(context.TODO(), "pion-mdns-1.local") + check(err, t) - if header.Type != dnsmessage.TypeAAAA { - t.Fatalf("expected AAAA but got %s", header.Type) + if header.Type != dnsmessage.TypeA { + t.Fatalf("expected A but got %s", header.Type) } - checkIPv6(addr, t) + checkIPv4(addr, t) check(aServer.Close(), t) check(bServer.Close(), t) @@ -434,7 +436,7 @@ func TestValidCommunicationIPv46MixedLocalAddress(t *testing.T) { // we want ipv6 but all we can offer is an ipv4 mapped address, so it should fail until we support // allowing this explicitly via configuration on the aServer side - if _, _, err := bServer.Query(ctx, "pion-mdns-1.local"); !errors.Is(err, errContextElapsed) { + if _, _, err := bServer.QueryAddr(ctx, "pion-mdns-1.local"); !errors.Is(err, errContextElapsed) { t.Fatalf("Query expired but returned unexpected error %v", err) } @@ -472,12 +474,16 @@ func TestValidCommunicationIPv66MixedLocalAddress(t *testing.T) { bServer, err := Server(nil, ipv6.NewPacketConn(bSock6), &Config{}) check(err, t) - header, addr, err := bServer.Query(context.TODO(), "pion-mdns-1.local") + header, addr, err := bServer.QueryAddr(context.TODO(), "pion-mdns-1.local") check(err, t) if header.Type != dnsmessage.TypeAAAA { t.Fatalf("expected AAAA but got %s", header.Type) } - if addr.String() != localAddress { + if !addr.Is4In6() { + t.Fatalf("expected address to be ipv4-to-ipv6 mapped: %v", addr) + } + // now unmap just for this check + if addr.Unmap().String() != localAddress { t.Fatalf("unexpected local address: %v", addr) } checkIPv6(addr, t) @@ -515,14 +521,14 @@ func TestValidCommunicationIPv64Mixed(t *testing.T) { bServer, err := Server(ipv4.NewPacketConn(bSock4), nil, &Config{}) check(err, t) - _, addr, err := bServer.Query(context.TODO(), "pion-mdns-1.local") + _, addr, err := bServer.QueryAddr(context.TODO(), "pion-mdns-1.local") check(err, t) if addr.String() == localAddress { t.Fatalf("unexpected local address: %v", addr) } - header, addr, err := bServer.Query(context.TODO(), "pion-mdns-2.local") + header, addr, err := bServer.QueryAddr(context.TODO(), "pion-mdns-2.local") check(err, t) if header.Type != dnsmessage.TypeA { t.Fatalf("expected A but got %s", header.Type) @@ -577,7 +583,7 @@ func TestQueryRespectTimeout(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) defer cancel() - if _, _, err = server.Query(ctx, "invalid-host"); !errors.Is(err, errContextElapsed) { + if _, _, err = server.QueryAddr(ctx, "invalid-host"); !errors.Is(err, errContextElapsed) { t.Fatalf("Query expired but returned unexpected error %v", err) } @@ -607,11 +613,11 @@ func TestQueryRespectClose(t *testing.T) { check(server.Close(), t) }() - if _, _, err = server.Query(context.TODO(), "invalid-host"); !errors.Is(err, errConnectionClosed) { + if _, _, err = server.QueryAddr(context.TODO(), "invalid-host"); !errors.Is(err, errConnectionClosed) { t.Fatalf("Query on closed server but returned unexpected error %v", err) } - if _, _, err = server.Query(context.TODO(), "invalid-host"); !errors.Is(err, errConnectionClosed) { + if _, _, err = server.QueryAddr(context.TODO(), "invalid-host"); !errors.Is(err, errConnectionClosed) { t.Fatalf("Query on closed server but returned unexpected error %v", err) } @@ -641,20 +647,20 @@ func TestResourceParsing(t *testing.T) { t.Fatal(err) } - actualIP, err := ipFromAnswerHeader(h, p) + actualAddr, err := addrFromAnswerHeader(h, p) if err != nil { t.Fatal(err) } - if !bytes.Equal(actualIP, expectedIP) { - t.Fatalf("Expected(%v) and Actual(%v) IP don't match", expectedIP, actualIP) + if !bytes.Equal(actualAddr.AsSlice(), expectedIP) { + t.Fatalf("Expected(%v) and Actual(%v) IP don't match", expectedIP, actualAddr) } } name := "test-server." t.Run("A Record", func(t *testing.T) { - answer, err := createAnswer(1, name, net.IP{127, 0, 0, 1}) + answer, err := createAnswer(1, name, mustAddr(net.IP{127, 0, 0, 1})) if err != nil { t.Fatal(err) } @@ -662,7 +668,7 @@ func TestResourceParsing(t *testing.T) { }) t.Run("AAAA Record", func(t *testing.T) { - answer, err := createAnswer(1, name, net.ParseIP("::1")) + answer, err := createAnswer(1, name, netip.MustParseAddr("::1")) if err != nil { t.Fatal(err) } @@ -670,45 +676,57 @@ func TestResourceParsing(t *testing.T) { }) } +func mustAddr(ip net.IP) netip.Addr { + addr, ok := netip.AddrFromSlice(ip) + if !ok { + panic(ipToAddrError{ip}) + } + return addr +} + func TestIPToBytes(t *testing.T) { expectedIP := []byte{127, 0, 0, 1} - actualIP4, err := ipv4ToBytes(net.ParseIP("127.0.0.1")) + actualAddr4, err := ipv4ToBytes(netip.MustParseAddr("127.0.0.1")) if err != nil { t.Fatal(err) } - if !bytes.Equal(actualIP4[:], expectedIP) { - t.Fatalf("Expected(%v) and Actual(%v) IP don't match", expectedIP, actualIP4) + if !bytes.Equal(actualAddr4[:], expectedIP) { + t.Fatalf("Expected(%v) and Actual(%v) IP don't match", expectedIP, actualAddr4) } expectedIP = []byte{0, 0, 0, 1} - actualIP4, err = ipv4ToBytes(net.ParseIP("0.0.0.1")) + actualAddr4, err = ipv4ToBytes(netip.MustParseAddr("0.0.0.1")) if err != nil { t.Fatal(err) } - if !bytes.Equal(actualIP4[:], expectedIP) { - t.Fatalf("Expected(%v) and Actual(%v) IP don't match", expectedIP, actualIP4) + if !bytes.Equal(actualAddr4[:], expectedIP) { + t.Fatalf("Expected(%v) and Actual(%v) IP don't match", expectedIP, actualAddr4) } expectedIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} - actualIP6, err := ipv6ToBytes(net.ParseIP("::1")) + actualAddr6, err := ipv6ToBytes(netip.MustParseAddr("::1")) if err != nil { t.Fatal(err) } - if !bytes.Equal(actualIP6[:], expectedIP) { - t.Fatalf("Expected(%v) and Actual(%v) IP don't match", expectedIP, actualIP6) + if !bytes.Equal(actualAddr6[:], expectedIP) { + t.Fatalf("Expected(%v) and Actual(%v) IP don't match", expectedIP, actualAddr6) } - _, err = ipv4ToBytes(net.ParseIP("::1")) + _, err = ipv4ToBytes(netip.MustParseAddr("::1")) if err == nil { t.Fatal("expected ::1 to not be output to IPv4 bytes") } expectedIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 127, 0, 0, 1} - actualIP6, err = ipv6ToBytes(net.ParseIP("127.0.0.1")) + addr, ok := netip.AddrFromSlice(net.ParseIP("127.0.0.1")) + if !ok { + t.Fatal("expected to be able to convert IP to netip.Addr") + } + actualAddr6, err = ipv6ToBytes(addr) if err != nil { t.Fatal(err) } - if !bytes.Equal(actualIP6[:], expectedIP) { - t.Fatalf("Expected(%v) and Actual(%v) IP don't match", expectedIP, actualIP6) + if !bytes.Equal(actualAddr6[:], expectedIP) { + t.Fatalf("Expected(%v) and Actual(%v) IP don't match", expectedIP, actualAddr6) } } diff --git a/examples/query/main.go b/examples/query/main.go index 919b80d..3539d05 100644 --- a/examples/query/main.go +++ b/examples/query/main.go @@ -8,6 +8,7 @@ import ( "context" "fmt" "net" + "os" "github.com/pion/mdns/v2" "golang.org/x/net/ipv4" @@ -15,31 +16,59 @@ import ( ) func main() { - addr4, err := net.ResolveUDPAddr("udp4", mdns.DefaultAddressIPv4) - if err != nil { - panic(err) + var useV4, useV6 bool + if len(os.Args) > 1 { + switch os.Args[1] { + case "-v4only": + useV4 = true + useV6 = false + case "-v6only": + useV4 = false + useV6 = true + default: + useV4 = true + useV6 = true + } + } else { + useV4 = true + useV6 = true } - addr6, err := net.ResolveUDPAddr("udp6", mdns.DefaultAddressIPv6) - if err != nil { - panic(err) - } + var packetConnV4 *ipv4.PacketConn + if useV4 { + addr4, err := net.ResolveUDPAddr("udp4", mdns.DefaultAddressIPv4) + if err != nil { + panic(err) + } - l4, err := net.ListenUDP("udp4", addr4) - if err != nil { - panic(err) + l4, err := net.ListenUDP("udp4", addr4) + if err != nil { + panic(err) + } + + packetConnV4 = ipv4.NewPacketConn(l4) } - l6, err := net.ListenUDP("udp6", addr6) - if err != nil { - panic(err) + var packetConnV6 *ipv6.PacketConn + if useV6 { + addr6, err := net.ResolveUDPAddr("udp6", mdns.DefaultAddressIPv6) + if err != nil { + panic(err) + } + + l6, err := net.ListenUDP("udp6", addr6) + if err != nil { + panic(err) + } + + packetConnV6 = ipv6.NewPacketConn(l6) } - server, err := mdns.Server(ipv4.NewPacketConn(l4), ipv6.NewPacketConn(l6), &mdns.Config{}) + server, err := mdns.Server(packetConnV4, packetConnV6, &mdns.Config{}) if err != nil { panic(err) } - answer, src, err := server.Query(context.TODO(), "pion-test.local") + answer, src, err := server.QueryAddr(context.TODO(), "pion-test.local") fmt.Println(answer) fmt.Println(src) fmt.Println(err)