Skip to content

Commit

Permalink
Support dig based unicast requests
Browse files Browse the repository at this point in the history
e.g. dig -p 5353 @224.0.0.251 pion-test.local
  • Loading branch information
edaniels committed Feb 9, 2024
1 parent 23c33ce commit 70b6f53
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 18 deletions.
44 changes: 36 additions & 8 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package mdns

import (
"bytes"
"context"
"errors"
"fmt"
Expand Down Expand Up @@ -265,24 +266,36 @@ func Server(
if err := multicastPktConnV4.SetControlMessage(ipv4.FlagInterface, true); err != nil {
c.log.Warnf("failed to SetControlMessage(ipv4.FlagInterface) on multicast IPv4 PacketConn %v", err)
}
if err := multicastPktConnV4.SetControlMessage(ipv4.FlagDst, true); err != nil {
c.log.Warnf("failed to SetControlMessage(ipv4.FlagDst) on multicast IPv4 PacketConn %v", err)
}

Check warning on line 271 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L270-L271

Added lines #L270 - L271 were not covered by tests
c.multicastPktConnV4 = ipPacketConn4{multicastPktConnV4, log}
}
if multicastPktConnV6 != nil {
if err := multicastPktConnV6.SetControlMessage(ipv6.FlagInterface, true); err != nil {
c.log.Warnf("failed to SetControlMessage(ipv6.FlagInterface) on multicast IPv6 PacketConn %v", err)
}
if err := multicastPktConnV6.SetControlMessage(ipv6.FlagDst, true); err != nil {
c.log.Warnf("failed to SetControlMessage(ipv6.FlagInterface) on multicast IPv6 PacketConn %v", err)
}

Check warning on line 280 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L279-L280

Added lines #L279 - L280 were not covered by tests
c.multicastPktConnV6 = ipPacketConn6{multicastPktConnV6, log}
}
if unicastPktConnV4 != nil {
if err := unicastPktConnV4.SetControlMessage(ipv4.FlagInterface, true); err != nil {
c.log.Warnf("failed to SetControlMessage(ipv4.FlagInterface) on unicast IPv4 PacketConn %v", err)
}
if err := unicastPktConnV4.SetControlMessage(ipv4.FlagDst, true); err != nil {
c.log.Warnf("failed to SetControlMessage(ipv4.FlagInterface) on unicast IPv4 PacketConn %v", err)
}

Check warning on line 289 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L288-L289

Added lines #L288 - L289 were not covered by tests
c.unicastPktConnV4 = ipPacketConn4{unicastPktConnV4, log}
}
if unicastPktConnV6 != nil {
if err := unicastPktConnV6.SetControlMessage(ipv6.FlagInterface, true); err != nil {
c.log.Warnf("failed to SetControlMessage(ipv6.FlagInterface) on unicast IPv6 PacketConn %v", err)
}
if err := unicastPktConnV6.SetControlMessage(ipv6.FlagDst, true); err != nil {
c.log.Warnf("failed to SetControlMessage(ipv6.FlagInterface) on unicast IPv6 PacketConn %v", err)
}

Check warning on line 298 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L297-L298

Added lines #L297 - L298 were not covered by tests
c.unicastPktConnV6 = ipPacketConn6{unicastPktConnV6, log}
}

Expand Down Expand Up @@ -624,14 +637,15 @@ func (c *Conn) writeToSocket(ifIndex int, b []byte, hasLoopbackData bool, wType
}
}

func createAnswer(name string, addr net.IP) (dnsmessage.Message, error) {
func createAnswer(id uint16, name string, addr net.IP) (dnsmessage.Message, error) {
packedName, err := dnsmessage.NewName(name)
if err != nil {
return dnsmessage.Message{}, err
}

msg := dnsmessage.Message{
Header: dnsmessage.Header{
ID: id,
Response: true,
Authoritative: true,
},
Expand Down Expand Up @@ -669,8 +683,8 @@ func createAnswer(name string, addr net.IP) (dnsmessage.Message, error) {
return msg, nil
}

func (c *Conn) sendAnswer(name string, ifIndex int, result net.IP, dst *net.UDPAddr) {
answer, err := createAnswer(name, result)
func (c *Conn) sendAnswer(queryID uint16, name string, ifIndex int, result net.IP, dst *net.UDPAddr) {
answer, err := createAnswer(queryID, name, result)
if err != nil {
c.log.Warnf("failed to create mDNS answer %v", err)
return
Expand All @@ -687,6 +701,7 @@ func (c *Conn) sendAnswer(name string, ifIndex int, result net.IP, dst *net.UDPA

type ipControlMessage struct {
IfIndex int
Dst net.IP
}

type ipPacketConn interface {
Expand All @@ -705,7 +720,7 @@ func (c ipPacketConn4) ReadFrom(b []byte) (n int, cm *ipControlMessage, src net.
if err != nil || cm4 == nil {
return n, nil, src, err
}
return n, &ipControlMessage{IfIndex: cm4.IfIndex}, src, err
return n, &ipControlMessage{IfIndex: cm4.IfIndex, Dst: cm4.Dst}, src, err
}

func (c ipPacketConn4) WriteTo(b []byte, via *net.Interface, cm *ipControlMessage, dst net.Addr) (n int, err error) {
Expand Down Expand Up @@ -736,7 +751,7 @@ func (c ipPacketConn6) ReadFrom(b []byte) (n int, cm *ipControlMessage, src net.
if err != nil || cm6 == nil {
return n, nil, src, err
}
return n, &ipControlMessage{IfIndex: cm6.IfIndex}, src, err
return n, &ipControlMessage{IfIndex: cm6.IfIndex, Dst: cm6.Dst}, src, err
}

func (c ipPacketConn6) WriteTo(b []byte, via *net.Interface, cm *ipControlMessage, dst net.Addr) (n int, err error) {
Expand Down Expand Up @@ -773,8 +788,10 @@ func (c *Conn) readLoop(name string, pktConn ipPacketConn, inboundBufferSize int
c.log.Debugf("got read on %s from %s", name, src)

var ifIndex int
var pktDst net.IP
if cm != nil {
ifIndex = cm.IfIndex
pktDst = cm.Dst
}
srcAddr, ok := src.(*net.UDPAddr)
if !ok {
Expand All @@ -783,7 +800,8 @@ func (c *Conn) readLoop(name string, pktConn ipPacketConn, inboundBufferSize int
}

func() {
if _, err := p.Start(b[:n]); err != nil {
header, err := p.Start(b[:n])
if err != nil {
c.log.Warnf("failed to parse mDNS packet %v", err)
return
}
Expand All @@ -801,7 +819,17 @@ func (c *Conn) readLoop(name string, pktConn ipPacketConn, inboundBufferSize int
continue
}

shouldUnicastResponse := (q.Class & (1 << 15)) != 0
// https://datatracker.ietf.org/doc/html/rfc6762#section-6
// The destination UDP port in all Multicast DNS responses MUST be 5353,
// and the destination address MUST be the mDNS IPv4 link-local
// multicast address 224.0.0.251 or its IPv6 equivalent FF02::FB, except
// when generating a reply to a query that explicitly requested a
// unicast response
shouldUnicastResponse :=

Check failure on line 828 in conn.go

View workflow job for this annotation

GitHub Actions / lint / Go

File is not `gofumpt`-ed (gofumpt)
(q.Class&(1<<15)) != 0 || // via the unicast-response bit
srcAddr.Port != 5353 || // by virtue of being a legacy query (Section 6.7), or
(len(pktDst) != 0 && !(bytes.Equal(pktDst, c.dstAddr4.IP) || // by virtue of being a direct unicast query

Check failure on line 831 in conn.go

View workflow job for this annotation

GitHub Actions / lint / Go

SA1021: use net.IP.Equal to compare net.IPs, not bytes.Equal (staticcheck)
bytes.Equal(pktDst, c.dstAddr6.IP)))

Check failure on line 832 in conn.go

View workflow job for this annotation

GitHub Actions / lint / Go

SA1021: use net.IP.Equal to compare net.IPs, not bytes.Equal (staticcheck)
var dst *net.UDPAddr
if shouldUnicastResponse {
dst = srcAddr
Expand Down Expand Up @@ -883,7 +911,7 @@ func (c *Conn) readLoop(name string, pktConn ipPacketConn, inboundBufferSize int
continue
}
}
c.sendAnswer(q.Name.String(), ifIndex, localAddress, dst)
c.sendAnswer(header.ID, q.Name.String(), ifIndex, localAddress, dst)
}
}
}
Expand Down
12 changes: 2 additions & 10 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -659,23 +659,15 @@ func TestResourceParsing(t *testing.T) {
name := "test-server."

t.Run("A Record", func(t *testing.T) {
answer, err := createAnswer(name, net.IP{127, 0, 0, 1})
answer, err := createAnswer(1, name, net.IP{127, 0, 0, 1})
if err != nil {
t.Fatal(err)
}
lookForIP(answer, []byte{127, 0, 0, 1}, t)
})

t.Run("AAAA Record", func(t *testing.T) {
// because it's compatible
answer, err := createAnswer(name, net.ParseIP("127.0.0.1"))
if err != nil {
t.Fatal(err)
}
// this is wrong...?
lookForIP(answer, []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 127, 0, 0, 1}, t)

answer, err = createAnswer(name, net.ParseIP("::1"))
answer, err := createAnswer(1, name, net.ParseIP("::1"))
if err != nil {
t.Fatal(err)
}
Expand Down

0 comments on commit 70b6f53

Please sign in to comment.