From 36c89ad8256892398bbb7b6218daba39c8e1a1df Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Tue, 19 Dec 2023 14:02:20 -0500 Subject: [PATCH] Test and fix queryDatagram --- dns/resolver.go | 10 ++-- dns/resolver_test.go | 131 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 132 insertions(+), 9 deletions(-) diff --git a/dns/resolver.go b/dns/resolver.go index 7178f8aa..6db14989 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -161,17 +161,19 @@ func queryDatagram(conn io.ReadWriter, q dnsmessage.Question) (*dnsmessage.Messa } else { buf = make([]byte, maxDNSPacketSize) } + var returnErr error for { n, err := conn.Read(buf) if err != nil { - return nil, fmt.Errorf("failed to read message: %w", err) + return nil, errors.Join(returnErr, fmt.Errorf("failed to read message: %w", err)) } - buf = buf[:n] var msg dnsmessage.Message - if err = msg.Unpack(buf); err != nil { - return nil, fmt.Errorf("failed to unpack DNS response: %w", err) + if err := msg.Unpack(buf[:n]); err != nil { + returnErr = errors.Join(returnErr, err) + continue } if err := checkResponse(id, q, msg.Header, msg.Questions); err != nil { + returnErr = errors.Join(returnErr, err) continue } return &msg, nil diff --git a/dns/resolver_test.go b/dns/resolver_test.go index 14d0fe79..d5238306 100644 --- a/dns/resolver_test.go +++ b/dns/resolver_test.go @@ -16,7 +16,9 @@ package dns import ( "encoding/binary" + "io" "math/rand" + "net" "strings" "testing" @@ -73,16 +75,18 @@ func Test_appendRequest(t *testing.T) { require.NoError(t, err) id := uint16(1234) - buf, err := appendRequest(id, *q, []byte{}) + offset := 2 + buf, err := appendRequest(id, *q, make([]byte, offset)) require.NoError(t, err) + require.Equal(t, make([]byte, offset), buf[:offset]) - // 12 bytes header + 5 question + 11 EDNS(0) OPT RR - require.Equal(t, 28, len(buf)) + // offset + 12 bytes header + 5 question + 11 EDNS(0) OPT RR + require.Equal(t, offset+28, len(buf)) - require.Equal(t, id, binary.BigEndian.Uint16(buf)) + require.Equal(t, id, binary.BigEndian.Uint16(buf[offset:])) var request dnsmessage.Message - err = request.Unpack(buf) + err = request.Unpack(buf[offset:]) require.NoError(t, err) require.Equal(t, id, request.ID) require.Equal(t, 1, len(request.Questions)) @@ -173,3 +177,120 @@ func Test_checkResponse(t *testing.T) { require.Error(t, err) }) } + +type queryResult struct { + msg *dnsmessage.Message + err error +} + +func testDatagramExchange(t *testing.T, server func(request dnsmessage.Message, conn net.Conn, clientDone <-chan queryResult)) { + front, back := net.Pipe() + q, err := NewQuestion("example.com.", dnsmessage.TypeAAAA) + require.NoError(t, err) + clientDone := make(chan queryResult) + go func() { + msg, err := queryDatagram(front, *q) + clientDone <- queryResult{msg, err} + }() + // Read request. + buf := make([]byte, 512) + n, err := back.Read(buf) + require.NoError(t, err) + buf = buf[:n] + // Verify request. + var reqMsg dnsmessage.Message + reqMsg.Unpack(buf) + reqID := reqMsg.ID + expectedBuf, err := appendRequest(reqID, *q, make([]byte, 0, 512)) + require.NoError(t, err) + require.Equal(t, expectedBuf, buf) + + server(reqMsg, back, clientDone) +} + +func Test_queryDatagram(t *testing.T) { + t.Run("Success", func(t *testing.T) { + testDatagramExchange(t, func(req dnsmessage.Message, conn net.Conn, clientDone <-chan queryResult) { + // Send bogus response. + _, err := conn.Write([]byte{0, 0}) + require.NoError(t, err) + + // Prepare response message. + q := req.Questions[0] + var resp dnsmessage.Message + resp.ID = req.ID + resp.Header.Response = true + resp.Questions = []dnsmessage.Question{q} + resp.Answers = []dnsmessage.Resource{{ + Header: dnsmessage.ResourceHeader{Name: q.Name, Type: q.Type, Class: q.Class, TTL: 100}, + Body: &dnsmessage.AAAAResource{AAAA: [16]byte(net.IPv6loopback)}, + }} + resp.Authorities = []dnsmessage.Resource{} + resp.Additionals = []dnsmessage.Resource{} + + // Send message with invalid ID first. + badMsg := resp + badMsg.ID = req.ID + 1 + buf, err := (&badMsg).Pack() + require.NoError(t, err) + _, err = conn.Write(buf) + require.NoError(t, err) + + // Send valid response. + buf, err = (&resp).Pack() + require.NoError(t, err) + _, err = conn.Write(buf) + require.NoError(t, err) + + // Wait for queryDatagram. + result := <-clientDone + require.NoError(t, result.err) + require.NotNil(t, result.msg) + require.Equal(t, resp, *result.msg) + }) + }) + t.Run("BadResponse", func(t *testing.T) { + testDatagramExchange(t, func(req dnsmessage.Message, conn net.Conn, clientDone <-chan queryResult) { + // Send response. + _, err := conn.Write([]byte{0}) + require.NoError(t, err) + + // Close writer. + conn.Close() + + // Wait for queryDatagram. + result := <-clientDone + require.Equal(t, 2, len(result.err.(interface{ Unwrap() []error }).Unwrap())) + require.ErrorIs(t, result.err, io.EOF) + }) + }) + t.Run("FailedClientWrite", func(t *testing.T) { + front, back := net.Pipe() + back.Close() + q, err := NewQuestion("example.com.", dnsmessage.TypeAAAA) + require.NoError(t, err) + clientDone := make(chan queryResult) + go func() { + msg, err := queryDatagram(front, *q) + clientDone <- queryResult{msg, err} + }() + // Wait for queryDatagram. + result := <-clientDone + require.Error(t, result.err) + }) + t.Run("FailedClientRead", func(t *testing.T) { + front, back := net.Pipe() + q, err := NewQuestion("example.com.", dnsmessage.TypeAAAA) + require.NoError(t, err) + clientDone := make(chan queryResult) + go func() { + msg, err := queryDatagram(front, *q) + clientDone <- queryResult{msg, err} + }() + back.Read(make([]byte, 521)) + back.Close() + // Wait for queryDatagram. + result := <-clientDone + require.Error(t, result.err) + }) +}