From df90c43ae2132e82a87bed3f7d00a032d463851d Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Mon, 18 Dec 2023 18:16:12 -0500 Subject: [PATCH] Finish library and add tests --- .github/workflows/test.yml | 4 +- dns/doc.go | 14 +- dns/resolver.go | 244 +++++++++------ dns/resolver_net_test.go | 76 +++++ dns/resolver_test.go | 454 ++++++++++++++++++++++++++++ transport/stream.go | 8 + transport/tls/stream_dialer.go | 45 ++- transport/tls/stream_dialer_test.go | 29 +- 8 files changed, 719 insertions(+), 155 deletions(-) create mode 100644 dns/resolver_net_test.go create mode 100644 dns/resolver_test.go diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d0a4021e..b0648654 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -35,7 +35,7 @@ jobs: run: go build -C x -o bin/ -v ./... - name: Test SDK - run: go test -v -race -bench '.' ./... -benchtime=100ms + run: go test -v -race -bench '.' ./... -benchtime=100ms -tags nettest - name: Test X - run: go test -C x -v -race -bench '.' ./... -benchtime=100ms + run: go test -C x -v -race -bench '.' ./... -benchtime=100ms -tags nettest diff --git a/dns/doc.go b/dns/doc.go index 8c036845..96e5726b 100644 --- a/dns/doc.go +++ b/dns/doc.go @@ -16,20 +16,24 @@ Package dns provides utilities to interact with the Domain Name System (DNS). The [Domain Name System] (DNS) is responsible for mapping domain names to IP addresses. -Because domain resolution gatekeeps connections and is predominantly done in plaintext, it is commonly used -for network-level filtering. +Because domain resolution gatekeeps connections and is predominantly done in plaintext, it is [commonly used +for network-level filtering]. + +# Transports The main concept in this library is that of a [Resolver], which allows code to query the DNS. Different implementations are provided -to perform DNS resolution over multiple transports: +to perform DNS resolution over different transports: - - DNS-over-UDP: the standard mechanism of querying resolvers. Communication is done in plaintext, using port 53. - - [DNS-over-TCP]: alternative to UDP when responses are large. Communication is done in plaintext, using port 53. + - [DNS-over-UDP]: the standard mechanism of querying resolvers. Communication is done in plaintext, using port 53. + - [DNS-over-TCP]: alternative to UDP that allows for more reliable delivery and larger responses, but requires establishing a connection. Communication is done in plaintext, using port 53. - [DNS-over-TLS] (DoT): uses the TCP protocol, but over a connection encrypted with TLS. Is uses port 853, which makes it very easy to block using the port number, as no other protocol is assigned to that port. - [DNS-over-HTTPS] (DoH): uses HTTP exchanges for querying the resolver and communicates over a connection encrypted with TLS. It uses port 443. That makes the DoH traffic undistinguishable from web traffic, making it harder to block. [Domain Name System]: https://datatracker.ietf.org/doc/html/rfc1034 +[commonly used for network-level filtering]: https://datatracker.ietf.org/doc/html/rfc9505#section-5.1.1 +[DNS-over-UDP]: https://datatracker.ietf.org/doc/html/rfc1035#section-4.2.1 [DNS-over-TCP]: https://datatracker.ietf.org/doc/html/rfc7766 [DNS-over-TLS]: https://datatracker.ietf.org/doc/html/rfc7858 [DNS-over-HTTPS]: https://datatracker.ietf.org/doc/html/rfc8484 diff --git a/dns/resolver.go b/dns/resolver.go index aa25100b..759fcde7 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -17,7 +17,6 @@ package dns import ( "bytes" "context" - "crypto/tls" "encoding/binary" "errors" "fmt" @@ -29,9 +28,29 @@ import ( "time" "github.com/Jigsaw-Code/outline-sdk/transport" + "github.com/Jigsaw-Code/outline-sdk/transport/tls" "golang.org/x/net/dns/dnsmessage" ) +var ( + ErrBadRequest = errors.New("request input is bad") + ErrDial = errors.New("dial DNS resolver failed") + ErrSend = errors.New("send DNS message failed") + ErrReceive = errors.New("receive DNS message failed") + ErrBadResponse = errors.New("response message is invalid") +) + +type nestedError struct { + is error + wrapped error +} + +func (e *nestedError) Is(target error) bool { return target == e.is } + +func (e *nestedError) Unwrap() error { return e.wrapped } + +func (e *nestedError) Error() string { return e.is.Error() + ": " + e.wrapped.Error() } + // Resolver can query the DNS with a question, and obtain a DNS message as response. // This abstraction helps hide the underlying transport protocol. type Resolver interface { @@ -47,8 +66,13 @@ func (f FuncResolver) Query(ctx context.Context, q dnsmessage.Question) (*dnsmes } // NewQuestion is a convenience function to create a [dnsmessage.Question]. +// The input domain is interpreted as fully-qualified. If the end "." is missing, it's added. func NewQuestion(domain string, qtype dnsmessage.Type) (*dnsmessage.Question, error) { - name, err := dnsmessage.NewName(domain) + fullDomain := domain + if len(domain) == 0 || domain[len(domain)-1] != '.' { + fullDomain += "." + } + name, err := dnsmessage.NewName(fullDomain) if err != nil { return nil, fmt.Errorf("cannot parse domain name: %w", err) } @@ -67,27 +91,27 @@ const maxDNSPacketSize = 1232 func appendRequest(id uint16, q dnsmessage.Question, buf []byte) ([]byte, error) { b := dnsmessage.NewBuilder(buf, dnsmessage.Header{ID: id, RecursionDesired: true}) if err := b.StartQuestions(); err != nil { - return nil, fmt.Errorf("failed to start questions: %w", err) + return nil, fmt.Errorf("start questions failed: %w", err) } if err := b.Question(q); err != nil { - return nil, fmt.Errorf("failed to add question: %w", err) + return nil, fmt.Errorf("add question failed: %w", err) } if err := b.StartAdditionals(); err != nil { - return nil, fmt.Errorf("failed to start additionals: %w", err) + return nil, fmt.Errorf("start additionals failed: %w", err) } var rh dnsmessage.ResourceHeader // Set the maximum payload size we support, as per https://datatracker.ietf.org/doc/html/rfc6891#section-4.3 if err := rh.SetEDNS0(maxDNSPacketSize, dnsmessage.RCodeSuccess, false); err != nil { - return nil, fmt.Errorf("failed to set EDNS(0) parameters: %w", err) + return nil, fmt.Errorf("set EDNS(0) failed: %w", err) } if err := b.OPTResource(rh, dnsmessage.OPTResource{}); err != nil { - return nil, fmt.Errorf("failed to add OPT RR: %w", err) + return nil, fmt.Errorf("add OPT RR failed: %w", err) } buf, err := b.Finish() if err != nil { - return nil, fmt.Errorf("failed to serialize message: %w", err) + return nil, fmt.Errorf("message serialization failed: %w", err) } return buf, nil } @@ -126,7 +150,7 @@ func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage // https://datatracker.ietf.org/doc/html/rfc5452#section-4.2 if len(respQs) == 0 { - return errors.New("no questions in response") + return errors.New("response had no questions") } respQ := respQs[0] if reqQues.Type != respQ.Type || reqQues.Class != respQ.Class || !equalASCIIName(reqQues.Name, respQ.Name) { @@ -138,24 +162,58 @@ func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage const maxMsgSize = 65535 +// queryDatagram implements a DNS query over a datagram protocol. +func queryDatagram(conn io.ReadWriter, q dnsmessage.Question) (*dnsmessage.Message, error) { + // Reference: https://cs.opensource.google/go/go/+/master:src/net/dnsclient_unix.go?q=func:dnsPacketRoundTrip&ss=go%2Fgo + id := uint16(rand.Uint32()) + buf, err := appendRequest(id, q, make([]byte, 0, maxDNSPacketSize)) + if err != nil { + return nil, &nestedError{ErrBadRequest, fmt.Errorf("append request failed: %w", err)} + } + if _, err := conn.Write(buf); err != nil { + return nil, &nestedError{ErrSend, err} + } + buf = buf[:cap(buf)] + var returnErr error + for { + n, err := conn.Read(buf) + if err != nil { + return nil, &nestedError{ErrReceive, errors.Join(returnErr, fmt.Errorf("read message failed: %w", err))} + } + var msg dnsmessage.Message + if err := msg.Unpack(buf[:n]); err != nil { + returnErr = errors.Join(returnErr, err) + continue + } + if err := checkResponse(id, q, msg.Header, msg.Questions); err != nil { + returnErr = errors.Join(returnErr, err) + continue + } + return &msg, nil + } +} + // queryStream implements a DNS query over a stream protocol. It frames the messages by prepending them with a 2-byte length prefix. func queryStream(conn io.ReadWriter, q dnsmessage.Question) (*dnsmessage.Message, error) { + // Reference: https://cs.opensource.google/go/go/+/master:src/net/dnsclient_unix.go?q=func:dnsStreamRoundTrip&ss=go%2Fgo id := uint16(rand.Uint32()) buf, err := appendRequest(id, q, make([]byte, 2, 514)) if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) + return nil, &nestedError{ErrBadRequest, fmt.Errorf("append request failed: %w", err)} } if len(buf) > maxMsgSize { - return nil, fmt.Errorf("message too large: %v bytes", len(buf)) + return nil, &nestedError{ErrBadRequest, fmt.Errorf("message too large: %v bytes", len(buf))} } binary.BigEndian.PutUint16(buf[:2], uint16(len(buf)-2)) + // TODO: Consider writer.ReadFrom(net.Buffers) in case the writer is a TCPConn. if _, err := conn.Write(buf); err != nil { - return nil, fmt.Errorf("failed to write message: %w", err) + return nil, &nestedError{ErrSend, err} } + var msgLen uint16 if err := binary.Read(conn, binary.BigEndian, &msgLen); err != nil { - return nil, fmt.Errorf("failed to read message length: %w", err) + return nil, &nestedError{ErrReceive, fmt.Errorf("read message length failed: %w", err)} } if int(msgLen) <= cap(buf) { buf = buf[:msgLen] @@ -163,87 +221,42 @@ func queryStream(conn io.ReadWriter, q dnsmessage.Question) (*dnsmessage.Message buf = make([]byte, msgLen) } if _, err = io.ReadFull(conn, buf); err != nil { - return nil, fmt.Errorf("failed to read message: %w", err) + return nil, &nestedError{ErrReceive, fmt.Errorf("read message failed: %w", err)} } + var msg dnsmessage.Message if err = msg.Unpack(buf); err != nil { - return nil, fmt.Errorf("failed to unpack DNS response: %w", err) + return nil, &nestedError{ErrBadResponse, fmt.Errorf("response failed to unpack: %w", err)} } if err := checkResponse(id, q, msg.Header, msg.Questions); err != nil { - return nil, fmt.Errorf("invalid response: %w", err) + return nil, &nestedError{ErrBadResponse, err} } return &msg, nil } -// queryDatagram implements a DNS query over a datagram protocol. -func queryDatagram(conn io.ReadWriter, q dnsmessage.Question) (*dnsmessage.Message, error) { - id := uint16(rand.Uint32()) - buf, err := appendRequest(id, q, make([]byte, 0, 512)) +func ensurePort(address string, defaultPort string) string { + host, port, err := net.SplitHostPort(address) if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) + // Failed to parse as host:port. Assume address is a host. + return net.JoinHostPort(address, defaultPort) } - if len(buf) > maxMsgSize { - return nil, fmt.Errorf("message too large: %v bytes", len(buf)) - } - if _, err := conn.Write(buf); err != nil { - return nil, fmt.Errorf("failed to write message: %w", err) - } - if cap(buf) >= maxDNSPacketSize { - buf = buf[:maxDNSPacketSize] - } else { - buf = make([]byte, maxDNSPacketSize) - } - for { - n, err := conn.Read(buf) - if err != nil { - return nil, fmt.Errorf("failed to read message: %w", err) - } - buf = buf[:n] - var msg dnsmessage.Message - if err = msg.Unpack(buf); err != nil { - return nil, fmt.Errorf("failed to unpack DNS response: %w", err) - } - if err := checkResponse(id, q, msg.Header, msg.Questions); err != nil { - continue - } - return &msg, nil + if port == "" { + return net.JoinHostPort(host, defaultPort) } -} - -// NewTCPResolver creates a [Resolver] that implements the [DNS-over-TCP] protocol, using a [transport.StreamDialer] for transport. -// It creates a new connection to the resolver for every request. -// -// [DNS-over-TCP]: https://datatracker.ietf.org/doc/html/rfc1035#section-4.2.2 -func NewTCPResolver(sd transport.StreamDialer, resolverAddr string) Resolver { - // See https://cs.opensource.google/go/go/+/master:src/net/dnsclient_unix.go;l=127;drc=6146a73d279d73b6138191929d2f1fad22188f51 - // TODO: Consider handling Authenticated Data. - return FuncResolver(func(ctx context.Context, q dnsmessage.Question) (*dnsmessage.Message, error) { - conn, err := sd.Dial(ctx, resolverAddr) - if err != nil { - return nil, fmt.Errorf("failed to dial resolver: %w", err) - } - // TODO: consider keeping the connection open for performance. - // Need to think about security implications. - defer conn.Close() - if deadline, ok := ctx.Deadline(); ok { - conn.SetDeadline(deadline) - } - return queryStream(conn, q) - }) + return address } // NewUDPResolver creates a [Resolver] that implements the DNS-over-UDP protocol, using a [transport.PacketDialer] for transport. -// It creates a new connection to the resolver for every request. +// It uses a different port for every request. // // [DNS-over-UDP]: https://datatracker.ietf.org/doc/html/rfc1035#section-4.2.1 func NewUDPResolver(pd transport.PacketDialer, resolverAddr string) Resolver { - // See https://cs.opensource.google/go/go/+/master:src/net/dnsclient_unix.go;l=100;drc=6146a73d279d73b6138191929d2f1fad22188f51 + resolverAddr = ensurePort(resolverAddr, "53") return FuncResolver(func(ctx context.Context, q dnsmessage.Question) (*dnsmessage.Message, error) { conn, err := pd.Dial(ctx, resolverAddr) if err != nil { - return nil, fmt.Errorf("failed to dial resolver: %w", err) + return nil, &nestedError{ErrDial, err} } - // TODO: reuse connection, as per https://datatracker.ietf.org/doc/html/rfc7766#section-6.2.1. defer conn.Close() if deadline, ok := ctx.Deadline(); ok { conn.SetDeadline(deadline) @@ -252,42 +265,74 @@ func NewUDPResolver(pd transport.PacketDialer, resolverAddr string) Resolver { }) } +type streamResolver struct { + NewConn func(context.Context) (transport.StreamConn, error) +} + +func (r *streamResolver) Query(ctx context.Context, q dnsmessage.Question) (*dnsmessage.Message, error) { + conn, err := r.NewConn(ctx) + if err != nil { + return nil, &nestedError{ErrDial, err} + } + // TODO: reuse connection, as per https://datatracker.ietf.org/doc/html/rfc7766#section-6.2.1. + defer conn.Close() + if deadline, ok := ctx.Deadline(); ok { + conn.SetDeadline(deadline) + } + return queryStream(conn, q) +} + +// NewTCPResolver creates a [Resolver] that implements the [DNS-over-TCP] protocol, using a [transport.StreamDialer] for transport. +// It creates a new connection to the resolver for every request. +// +// [DNS-over-TCP]: https://datatracker.ietf.org/doc/html/rfc1035#section-4.2.2 +func NewTCPResolver(sd transport.StreamDialer, resolverAddr string) Resolver { + // TODO: Consider handling Authenticated Data. + resolverAddr = ensurePort(resolverAddr, "53") + return &streamResolver{ + NewConn: func(ctx context.Context) (transport.StreamConn, error) { + return sd.Dial(ctx, resolverAddr) + }, + } +} + // NewTLSResolver creates a [Resolver] that implements the [DNS-over-TLS] protocol, using a [transport.StreamDialer] -// to connect to the resolverAddr the the resolverName as the TLS server name. +// to connect to the resolverAddr, and the resolverName as the TLS server name. // It creates a new connection to the resolver for every request. // // [DNS-over-TLS]: https://datatracker.ietf.org/doc/html/rfc7858 func NewTLSResolver(sd transport.StreamDialer, resolverAddr string, resolverName string) Resolver { - return FuncResolver(func(ctx context.Context, q dnsmessage.Question) (*dnsmessage.Message, error) { - baseConn, err := sd.Dial(ctx, resolverAddr) - if err != nil { - return nil, fmt.Errorf("failed to dial resolver: %w", err) - } - tlsConn := tls.Client(baseConn, &tls.Config{ - ServerName: resolverName, - }) - // TODO: reuse connection, as per https://datatracker.ietf.org/doc/html/rfc7766#section-6.2.1. - defer tlsConn.Close() - if deadline, ok := ctx.Deadline(); ok { - tlsConn.SetDeadline(deadline) - } - return queryStream(tlsConn, q) - }) + resolverAddr = ensurePort(resolverAddr, "853") + return &streamResolver{ + NewConn: func(ctx context.Context) (transport.StreamConn, error) { + baseConn, err := sd.Dial(ctx, resolverAddr) + if err != nil { + return nil, err + } + return tls.WrapConn(ctx, baseConn, resolverName) + }, + } } // NewHTTPSResolver creates a [Resolver] that implements the [DNS-over-HTTPS] protocol, using a [transport.StreamDialer] -// to connect to the resolverAddr the url as the DoH template URI. +// to connect to the resolverAddr, and the url as the DoH template URI. // It uses an internal HTTP client that reuses connections when possible. // // [DNS-over-HTTPS]: https://datatracker.ietf.org/doc/html/rfc8484 func NewHTTPSResolver(sd transport.StreamDialer, resolverAddr string, url string) Resolver { + resolverAddr = ensurePort(resolverAddr, "443") dialContext := func(ctx context.Context, network, addr string) (net.Conn, error) { if !strings.HasPrefix(network, "tcp") { // TODO: Support UDP for QUIC. return nil, fmt.Errorf("protocol not supported: %v", network) } - return sd.Dial(ctx, resolverAddr) + conn, err := sd.Dial(ctx, resolverAddr) + if err != nil { + return nil, &nestedError{ErrDial, err} + } + return conn, nil } + // TODO: add mechanism to close idle connections. // Copied from Intra: https://github.com/Jigsaw-Code/Intra/blob/d3554846a1146ae695e28a8ed6dd07f0cd310c5a/Android/tun2socks/intra/doh/doh.go#L213-L219 httpClient := http.Client{ Transport: &http.Transport{ @@ -298,35 +343,40 @@ func NewHTTPSResolver(sd transport.StreamDialer, resolverAddr string, url string }, } return FuncResolver(func(ctx context.Context, q dnsmessage.Question) (*dnsmessage.Message, error) { + // Prepare request. buf, err := appendRequest(0, q, make([]byte, 0, 512)) if err != nil { - return nil, fmt.Errorf("failed to create DNS request: %w", err) + return nil, &nestedError{ErrBadRequest, fmt.Errorf("append request failed: %w", err)} } httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(buf)) if err != nil { - return nil, fmt.Errorf("failed to create HTTP request: %w", err) + return nil, &nestedError{ErrBadRequest, fmt.Errorf("create HTTP request failed: %w", err)} } const mimetype = "application/dns-message" httpReq.Header.Add("Accept", mimetype) httpReq.Header.Add("Content-Type", mimetype) + + // Send request and get response. httpResp, err := httpClient.Do(httpReq) if err != nil { - return nil, fmt.Errorf("failed to get HTTP response: %w", err) + return nil, &nestedError{ErrReceive, fmt.Errorf("failed to get HTTP response: %w", err)} } defer httpResp.Body.Close() if httpResp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("got HTTP status %v", httpResp.StatusCode) + return nil, &nestedError{ErrReceive, fmt.Errorf("got HTTP status %v", httpResp.StatusCode)} } response, err := io.ReadAll(httpResp.Body) if err != nil { - return nil, fmt.Errorf("failed to read response: %w", err) + return nil, &nestedError{ErrReceive, fmt.Errorf("failed to read response: %w", err)} } + + // Process response. var msg dnsmessage.Message if err = msg.Unpack(response); err != nil { - return nil, fmt.Errorf("failed to unpack DNS response: %w", err) + return nil, &nestedError{ErrBadResponse, fmt.Errorf("failed to unpack DNS response: %w", err)} } if err := checkResponse(0, q, msg.Header, msg.Questions); err != nil { - return nil, fmt.Errorf("invalid response: %w", err) + return nil, &nestedError{ErrBadResponse, err} } return &msg, nil }) diff --git a/dns/resolver_net_test.go b/dns/resolver_net_test.go new file mode 100644 index 00000000..5ef23ed2 --- /dev/null +++ b/dns/resolver_net_test.go @@ -0,0 +1,76 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// go:build nettest + +package dns + +import ( + "context" + "testing" + + "github.com/Jigsaw-Code/outline-sdk/transport" + "github.com/stretchr/testify/require" + "golang.org/x/net/dns/dnsmessage" +) + +// TODO: Make tests not depend on the network. +func newTestContext(t *testing.T) context.Context { + if deadline, ok := t.Deadline(); ok { + ctx, cancel := context.WithDeadline(context.Background(), deadline) + t.Cleanup(cancel) + return ctx + } + return context.Background() +} + +func TestNewUDPResolver(t *testing.T) { + ctx := newTestContext(t) + resolver := NewUDPResolver(&transport.UDPPacketDialer{}, "8.8.8.8") + q, err := NewQuestion("getoutline.org.", dnsmessage.TypeAAAA) + require.NoError(t, err) + resp, err := resolver.Query(ctx, *q) + require.NoError(t, err) + require.GreaterOrEqual(t, len(resp.Answers), 1) +} + +func TestNewTCPResolver(t *testing.T) { + ctx := newTestContext(t) + resolver := NewTCPResolver(&transport.TCPStreamDialer{}, "8.8.8.8") + q, err := NewQuestion("getoutline.org.", dnsmessage.TypeAAAA) + require.NoError(t, err) + resp, err := resolver.Query(ctx, *q) + require.NoError(t, err) + require.GreaterOrEqual(t, len(resp.Answers), 1) +} + +func TestNewTLSResolver(t *testing.T) { + ctx := newTestContext(t) + resolver := NewTLSResolver(&transport.TCPStreamDialer{}, "8.8.8.8", "8.8.8.8") + q, err := NewQuestion("getoutline.org.", dnsmessage.TypeAAAA) + require.NoError(t, err) + resp, err := resolver.Query(ctx, *q) + require.NoError(t, err) + require.GreaterOrEqual(t, len(resp.Answers), 1) +} + +func TestNewHTTPSResolver(t *testing.T) { + ctx := newTestContext(t) + resolver := NewHTTPSResolver(&transport.TCPStreamDialer{}, "8.8.8.8", "https://8.8.8.8/dns-query") + q, err := NewQuestion("getoutline.org.", dnsmessage.TypeAAAA) + require.NoError(t, err) + resp, err := resolver.Query(ctx, *q) + require.NoError(t, err) + require.GreaterOrEqual(t, len(resp.Answers), 1) +} diff --git a/dns/resolver_test.go b/dns/resolver_test.go new file mode 100644 index 00000000..29881560 --- /dev/null +++ b/dns/resolver_test.go @@ -0,0 +1,454 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dns + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "math/rand" + "net" + "strings" + "testing" + + "github.com/stretchr/testify/require" + "golang.org/x/net/dns/dnsmessage" +) + +func TestNewQuestionTypes(t *testing.T) { + testDomain := "example.com." + qname, err := dnsmessage.NewName(testDomain) + require.NoError(t, err) + for _, qtype := range []dnsmessage.Type{dnsmessage.TypeA, dnsmessage.TypeAAAA, dnsmessage.TypeCNAME} { + t.Run(qtype.String(), func(t *testing.T) { + q, err := NewQuestion(testDomain, qtype) + require.NoError(t, err) + require.Equal(t, qname, q.Name) + require.Equal(t, qtype, q.Type) + require.Equal(t, dnsmessage.ClassINET, q.Class) + }) + } +} + +func TestNewQuestionNotFQDN(t *testing.T) { + testDomain := "example.com" + q, err := NewQuestion(testDomain, dnsmessage.TypeAAAA) + require.NoError(t, err) + require.Equal(t, dnsmessage.MustNewName("example.com."), q.Name) +} + +func TestNewQuestionRoot(t *testing.T) { + testDomain := "." + qname, err := dnsmessage.NewName(testDomain) + require.NoError(t, err) + q, err := NewQuestion(testDomain, dnsmessage.TypeAAAA) + require.NoError(t, err) + require.Equal(t, qname, q.Name) +} + +func TestNewQuestionEmpty(t *testing.T) { + testDomain := "" + q, err := NewQuestion(testDomain, dnsmessage.TypeAAAA) + require.NoError(t, err) + require.Equal(t, dnsmessage.MustNewName("."), q.Name) +} + +func TestNewQuestionLongName(t *testing.T) { + testDomain := strings.Repeat("a.", 200) + _, err := NewQuestion(testDomain, dnsmessage.TypeAAAA) + require.Error(t, err) +} + +func Test_appendRequest(t *testing.T) { + q, err := NewQuestion(".", dnsmessage.TypeAAAA) + require.NoError(t, err) + + id := uint16(1234) + offset := 2 + buf, err := appendRequest(id, *q, make([]byte, offset)) + require.NoError(t, err) + require.Equal(t, make([]byte, offset), buf[:offset]) + + // offset + 12 bytes header + 5 question + 11 EDNS(0) OPT RR + require.Equal(t, offset+28, len(buf)) + + require.Equal(t, id, binary.BigEndian.Uint16(buf[offset:])) + + var request dnsmessage.Message + err = request.Unpack(buf[offset:]) + require.NoError(t, err) + require.Equal(t, id, request.ID) + require.Equal(t, 1, len(request.Questions)) + require.Equal(t, *q, request.Questions[0]) + require.Equal(t, 0, len(request.Answers)) + require.Equal(t, 0, len(request.Authorities)) + // ENDS(0) OPT resource record. + require.Equal(t, 1, len(request.Additionals)) + optRR := dnsmessage.Resource{ + Header: dnsmessage.ResourceHeader{ + Name: dnsmessage.MustNewName("."), + Type: dnsmessage.TypeOPT, + Class: maxDNSPacketSize, + TTL: 0, + Length: 0, + }, + Body: &dnsmessage.OPTResource{}, + } + require.Equal(t, optRR, request.Additionals[0]) +} + +func Test_foldCase(t *testing.T) { + require.Equal(t, byte('Y'), foldCase('Y')) + require.Equal(t, byte('Y'), foldCase('y')) + // Only fold ASCII + require.Equal(t, byte('ý'), foldCase('ý')) + require.Equal(t, byte('-'), foldCase('-')) +} + +func Test_equalASCIIName(t *testing.T) { + require.True(t, equalASCIIName(dnsmessage.MustNewName("My-Example.Com"), dnsmessage.MustNewName("mY-eXAMPLE.cOM"))) + require.False(t, equalASCIIName(dnsmessage.MustNewName("example.com"), dnsmessage.MustNewName("example.net"))) + require.False(t, equalASCIIName(dnsmessage.MustNewName("example.com"), dnsmessage.MustNewName("example.com.br"))) + require.False(t, equalASCIIName(dnsmessage.MustNewName("example.com"), dnsmessage.MustNewName("myexample.com"))) +} + +func Test_checkResponse(t *testing.T) { + reqID := uint16(rand.Uint32()) + reqQ := dnsmessage.Question{ + Name: dnsmessage.MustNewName("example.com."), + Type: dnsmessage.TypeAAAA, + Class: dnsmessage.ClassINET, + } + expectedHdr := dnsmessage.Header{ID: reqID, Response: true} + expectedQs := []dnsmessage.Question{reqQ} + t.Run("Match", func(t *testing.T) { + err := checkResponse(reqID, reqQ, expectedHdr, expectedQs) + require.NoError(t, err) + }) + t.Run("CaseInsensitive", func(t *testing.T) { + mixedQ := reqQ + mixedQ.Name = dnsmessage.MustNewName("Example.Com.") + err := checkResponse(reqID, reqQ, expectedHdr, []dnsmessage.Question{mixedQ}) + require.NoError(t, err) + }) + t.Run("NotResponse", func(t *testing.T) { + badHdr := expectedHdr + badHdr.Response = false + err := checkResponse(reqID, reqQ, badHdr, expectedQs) + require.Error(t, err) + }) + t.Run("BadID", func(t *testing.T) { + badHdr := expectedHdr + badHdr.ID = reqID + 1 + err := checkResponse(reqID, reqQ, badHdr, expectedQs) + require.Error(t, err) + }) + t.Run("NoQuestions", func(t *testing.T) { + err := checkResponse(reqID, reqQ, expectedHdr, []dnsmessage.Question{}) + require.Error(t, err) + }) + t.Run("BadQuestionType", func(t *testing.T) { + badQ := reqQ + badQ.Type = dnsmessage.TypeA + err := checkResponse(reqID, reqQ, expectedHdr, []dnsmessage.Question{badQ}) + require.Error(t, err) + }) + t.Run("BadQuestionClass", func(t *testing.T) { + badQ := reqQ + badQ.Class = dnsmessage.ClassCHAOS + err := checkResponse(reqID, reqQ, expectedHdr, []dnsmessage.Question{badQ}) + require.Error(t, err) + }) + t.Run("BadQuestionName", func(t *testing.T) { + badQ := reqQ + badQ.Name = dnsmessage.MustNewName("notexample.invalid.") + err := checkResponse(reqID, reqQ, expectedHdr, []dnsmessage.Question{badQ}) + require.Error(t, err) + }) +} + +func newMessageResponse(req dnsmessage.Message, answer dnsmessage.ResourceBody, ttl uint32) (dnsmessage.Message, error) { + var resp dnsmessage.Message + if len(req.Questions) != 1 { + return resp, fmt.Errorf("Invalid number of questions %v", len(req.Questions)) + } + q := req.Questions[0] + resp.ID = req.ID + resp.Header.Response = true + resp.Questions = []dnsmessage.Question{q} + resp.Answers = []dnsmessage.Resource{{ + Header: dnsmessage.ResourceHeader{Name: q.Name, Type: q.Type, Class: q.Class, TTL: ttl}, + Body: answer, + }} + resp.Authorities = []dnsmessage.Resource{} + resp.Additionals = []dnsmessage.Resource{} + return resp, nil +} + +type queryResult struct { + msg *dnsmessage.Message + err error +} + +func testDatagramExchange(t *testing.T, server func(request dnsmessage.Message, conn net.Conn)) (*dnsmessage.Message, error) { + front, back := net.Pipe() + q, err := NewQuestion("example.com.", dnsmessage.TypeAAAA) + require.NoError(t, err) + clientDone := make(chan queryResult) + go func() { + msg, err := queryDatagram(front, *q) + clientDone <- queryResult{msg, err} + }() + // Read request. + buf := make([]byte, 512) + n, err := back.Read(buf) + require.NoError(t, err) + buf = buf[:n] + // Verify request. + var reqMsg dnsmessage.Message + reqMsg.Unpack(buf) + reqID := reqMsg.ID + expectedBuf, err := appendRequest(reqID, *q, make([]byte, 0, 512)) + require.NoError(t, err) + require.Equal(t, expectedBuf, buf) + + server(reqMsg, back) + + result := <-clientDone + return result.msg, result.err +} + +func Test_queryDatagram(t *testing.T) { + t.Run("Success", func(t *testing.T) { + var respSent dnsmessage.Message + respRcvd, err := testDatagramExchange(t, func(req dnsmessage.Message, conn net.Conn) { + // Send bogus response. + _, err := conn.Write([]byte{0, 0}) + require.NoError(t, err) + + // Prepare response message. + respSent, err = newMessageResponse(req, &dnsmessage.AAAAResource{AAAA: [16]byte(net.IPv6loopback)}, 100) + require.NoError(t, err) + + // Send message with invalid ID first. + badMsg := respSent + badMsg.ID = req.ID + 1 + buf, err := (&badMsg).Pack() + require.NoError(t, err) + _, err = conn.Write(buf) + require.NoError(t, err) + + // Send valid response. + buf, err = (&respSent).Pack() + require.NoError(t, err) + _, err = conn.Write(buf) + require.NoError(t, err) + }) + require.NoError(t, err) + require.NotNil(t, respRcvd) + require.Equal(t, respSent, *respRcvd) + }) + t.Run("BadResponse", func(t *testing.T) { + _, err := testDatagramExchange(t, func(req dnsmessage.Message, conn net.Conn) { + // Send bad response. + _, err := conn.Write([]byte{0}) + require.NoError(t, err) + // Close writer. + conn.Close() + }) + require.ErrorIs(t, err, ErrReceive) + require.Equal(t, 2, len(errors.Unwrap(err).(interface{ Unwrap() []error }).Unwrap())) + require.ErrorIs(t, err, io.EOF) + }) + t.Run("FailedClientWrite", func(t *testing.T) { + front, back := net.Pipe() + back.Close() + q, err := NewQuestion("example.com.", dnsmessage.TypeAAAA) + require.NoError(t, err) + clientDone := make(chan queryResult) + go func() { + msg, err := queryDatagram(front, *q) + clientDone <- queryResult{msg, err} + }() + // Wait for queryDatagram. + result := <-clientDone + require.ErrorIs(t, result.err, ErrSend) + require.ErrorIs(t, result.err, io.ErrClosedPipe) + }) + t.Run("FailedClientRead", func(t *testing.T) { + front, back := net.Pipe() + q, err := NewQuestion("example.com.", dnsmessage.TypeAAAA) + require.NoError(t, err) + clientDone := make(chan queryResult) + go func() { + msg, err := queryDatagram(front, *q) + clientDone <- queryResult{msg, err} + }() + back.Read(make([]byte, 521)) + back.Close() + // Wait for queryDatagram. + result := <-clientDone + require.ErrorIs(t, result.err, ErrReceive) + require.ErrorIs(t, result.err, io.EOF) + }) +} + +func testStreamExchange(t *testing.T, server func(request dnsmessage.Message, conn net.Conn)) (*dnsmessage.Message, error) { + front, back := net.Pipe() + q, err := NewQuestion("example.com.", dnsmessage.TypeAAAA) + require.NoError(t, err) + clientDone := make(chan queryResult) + go func() { + msg, err := queryStream(front, *q) + clientDone <- queryResult{msg, err} + }() + // Read request. + var msgLen uint16 + require.NoError(t, binary.Read(back, binary.BigEndian, &msgLen)) + buf := make([]byte, msgLen) + n, err := back.Read(buf) + require.NoError(t, err) + buf = buf[:n] + // Verify request. + var reqMsg dnsmessage.Message + reqMsg.Unpack(buf) + reqID := reqMsg.ID + expectedBuf, err := appendRequest(reqID, *q, make([]byte, 0, 512)) + require.NoError(t, err) + require.Equal(t, expectedBuf, buf) + + server(reqMsg, back) + + result := <-clientDone + return result.msg, result.err +} + +func Test_queryStream(t *testing.T) { + t.Run("Success", func(t *testing.T) { + var respSent dnsmessage.Message + respRcvd, err := testStreamExchange(t, func(req dnsmessage.Message, conn net.Conn) { + var err error + // Prepare response message. + respSent, err = newMessageResponse(req, &dnsmessage.AAAAResource{AAAA: [16]byte(net.IPv6loopback)}, 100) + require.NoError(t, err) + + // Send response. + buf, err := (&respSent).Pack() + require.NoError(t, err) + require.NoError(t, binary.Write(conn, binary.BigEndian, uint16(len(buf)))) + _, err = conn.Write(buf) + require.NoError(t, err) + }) + require.NoError(t, err) + require.NotNil(t, respRcvd) + require.Equal(t, respSent, *respRcvd) + }) + t.Run("ShortRead", func(t *testing.T) { + _, err := testStreamExchange(t, func(req dnsmessage.Message, conn net.Conn) { + // Send response. + _, err := conn.Write([]byte{0}) + require.NoError(t, err) + + // Close writer. + conn.Close() + }) + require.ErrorIs(t, err, ErrReceive) + require.ErrorIs(t, err, io.ErrUnexpectedEOF) + }) + t.Run("ShortMessage", func(t *testing.T) { + _, err := testStreamExchange(t, func(req dnsmessage.Message, conn net.Conn) { + // Send response. + _, err := conn.Write([]byte{0, 100, 0}) + require.NoError(t, err) + // Close writer. + conn.Close() + }) + require.ErrorIs(t, err, ErrReceive) + require.ErrorIs(t, err, io.ErrUnexpectedEOF) + }) + t.Run("BadMessageFormat", func(t *testing.T) { + _, err := testStreamExchange(t, func(req dnsmessage.Message, conn net.Conn) { + // Send response. + _, err := conn.Write([]byte{0, 2, 0, 0}) + require.NoError(t, err) + + // Close writer. + conn.Close() + }) + require.ErrorIs(t, err, ErrBadResponse) + }) + t.Run("BadMessageContent", func(t *testing.T) { + _, err := testStreamExchange(t, func(req dnsmessage.Message, conn net.Conn) { + // Make response with no answer and invalid ID. + resp := req + resp.ID = req.ID + 1 + resp.Response = true + buf, err := resp.AppendPack(make([]byte, 2, 514)) + require.NoError(t, err) + binary.BigEndian.PutUint16(buf, uint16(len(buf)-2)) + // Send response. + _, err = conn.Write(buf) + require.NoError(t, err) + + // Close writer. + conn.Close() + }) + require.ErrorIs(t, err, ErrBadResponse) + }) + t.Run("FailedClientWrite", func(t *testing.T) { + front, back := net.Pipe() + back.Close() + q, err := NewQuestion("example.com.", dnsmessage.TypeAAAA) + require.NoError(t, err) + clientDone := make(chan queryResult) + go func() { + msg, err := queryStream(front, *q) + clientDone <- queryResult{msg, err} + }() + // Wait for client. + result := <-clientDone + require.ErrorIs(t, result.err, ErrSend) + require.ErrorIs(t, result.err, io.ErrClosedPipe) + }) + t.Run("FailedClientRead", func(t *testing.T) { + front, back := net.Pipe() + q, err := NewQuestion("example.com.", dnsmessage.TypeAAAA) + require.NoError(t, err) + clientDone := make(chan queryResult) + go func() { + msg, err := queryStream(front, *q) + clientDone <- queryResult{msg, err} + }() + back.Read(make([]byte, 521)) + back.Close() + // Wait for queryDatagram. + result := <-clientDone + require.ErrorIs(t, result.err, ErrReceive) + require.ErrorIs(t, result.err, io.EOF) + }) +} + +func Test_ensurePort(t *testing.T) { + require.Equal(t, "example.com:8080", ensurePort("example.com:8080", "80")) + require.Equal(t, "example.com:443", ensurePort("example.com", "443")) + require.Equal(t, "example.com:443", ensurePort("example.com:", "443")) + require.Equal(t, "8.8.8.8:8080", ensurePort("8.8.8.8:8080", "443")) + require.Equal(t, "8.8.8.8:443", ensurePort("8.8.8.8", "443")) + require.Equal(t, "8.8.8.8:443", ensurePort("8.8.8.8:", "443")) + require.Equal(t, "[2001:4860:4860::8888]:8080", ensurePort("[2001:4860:4860::8888]:8080", "443")) + require.Equal(t, "[2001:4860:4860::8888]:443", ensurePort("2001:4860:4860::8888", "443")) + require.Equal(t, "[2001:4860:4860::8888]:443", ensurePort("[2001:4860:4860::8888]:", "443")) +} diff --git a/transport/stream.go b/transport/stream.go index 5d279308..285275d8 100644 --- a/transport/stream.go +++ b/transport/stream.go @@ -132,3 +132,11 @@ func (d *TCPStreamDialer) Dial(ctx context.Context, addr string) (StreamConn, er } return conn.(*net.TCPConn), nil } + +// FuncStreamDialer is a [StreamDialer] that uses the given function to dial. +type FuncStreamDialer func(ctx context.Context, addr string) (StreamConn, error) + +// Query implements the [Resolver] interface. +func (f FuncStreamDialer) Dial(ctx context.Context, addr string) (StreamConn, error) { + return f(ctx, addr) +} diff --git a/transport/tls/stream_dialer.go b/transport/tls/stream_dialer.go index af677ba0..383341f2 100644 --- a/transport/tls/stream_dialer.go +++ b/transport/tls/stream_dialer.go @@ -68,7 +68,11 @@ func (d *StreamDialer) Dial(ctx context.Context, remoteAddr string) (transport.S if err != nil { return nil, err } - conn, err := WrapConn(ctx, innerConn, remoteAddr, d.options...) + host, _, err := net.SplitHostPort(remoteAddr) + if err != nil { + return nil, fmt.Errorf("invalid address: %w", err) + } + conn, err := WrapConn(ctx, innerConn, host, d.options...) if err != nil { innerConn.Close() return nil, err @@ -120,25 +124,17 @@ func (cfg *ClientConfig) toStdConfig() *tls.Config { } // ClientOption allows configuring the parameters to be used for a client TLS connection. -type ClientOption func(host string, port int, config *ClientConfig) +type ClientOption func(serverName string, config *ClientConfig) // WrapConn wraps a [transport.StreamConn] in a TLS connection. -func WrapConn(ctx context.Context, conn transport.StreamConn, remoteAdr string, options ...ClientOption) (transport.StreamConn, error) { - host, portStr, err := net.SplitHostPort(remoteAdr) - if err != nil { - return nil, fmt.Errorf("could not parse remote address: %w", err) - } - host = normalizeHost(host) - port, err := net.DefaultResolver.LookupPort(ctx, "tcp", portStr) - if err != nil { - return nil, fmt.Errorf("could not resolve port: %w", err) - } - cfg := ClientConfig{ServerName: host, CertificateName: host} +func WrapConn(ctx context.Context, conn transport.StreamConn, serverName string, options ...ClientOption) (transport.StreamConn, error) { + cfg := ClientConfig{ServerName: serverName, CertificateName: serverName} + normName := normalizeHost(serverName) for _, option := range options { - option(host, port, &cfg) + option(normName, &cfg) } tlsConn := tls.Client(conn, cfg.toStdConfig()) - err = tlsConn.HandshakeContext(ctx) + err := tlsConn.HandshakeContext(ctx) if err != nil { return nil, err } @@ -151,22 +147,19 @@ func WrapConn(ctx context.Context, conn transport.StreamConn, remoteAdr string, // // [Server Name Indication]: https://datatracker.ietf.org/doc/html/rfc6066#section-3 func WithSNI(hostName string) ClientOption { - return func(_ string, _ int, config *ClientConfig) { + return func(_ string, config *ClientConfig) { config.ServerName = hostName } } -// IfHostPort applies the given option if the host and port matches the dialed one. -func IfHostPort(matchHost string, matchPort int, option ClientOption) ClientOption { +// IfHost applies the given option if the host matches the dialed one. +func IfHost(matchHost string, option ClientOption) ClientOption { matchHost = normalizeHost(matchHost) - return func(host string, port int, config *ClientConfig) { + return func(host string, config *ClientConfig) { if matchHost != "" && matchHost != host { return } - if matchPort != 0 && matchPort != port { - return - } - option(host, port, config) + option(host, config) } } @@ -176,14 +169,14 @@ func IfHostPort(matchHost string, matchPort int, option ClientOption) ClientOpti // [Application-Layer Protocol Negotiation]: https://datatracker.ietf.org/doc/html/rfc7301 // [IANA's registry]: https://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml#alpn-protocol-ids func WithALPN(protocolNameList []string) ClientOption { - return func(_ string, _ int, config *ClientConfig) { + return func(_ string, config *ClientConfig) { config.NextProtos = protocolNameList } } // WithSessionCache sets the [tls.ClientSessionCache] to enable session resumption of TLS connections. func WithSessionCache(sessionCache tls.ClientSessionCache) ClientOption { - return func(_ string, _ int, config *ClientConfig) { + return func(_ string, config *ClientConfig) { config.SessionCache = sessionCache } } @@ -191,7 +184,7 @@ func WithSessionCache(sessionCache tls.ClientSessionCache) ClientOption { // WithCertificateName sets the hostname to be used for the certificate cerification. // If absent, defaults to the dialed hostname. func WithCertificateName(hostname string) ClientOption { - return func(_ string, _ int, config *ClientConfig) { + return func(_ string, config *ClientConfig) { config.CertificateName = hostname } } diff --git a/transport/tls/stream_dialer_test.go b/transport/tls/stream_dialer_test.go index d32611ce..33fc1708 100644 --- a/transport/tls/stream_dialer_test.go +++ b/transport/tls/stream_dialer_test.go @@ -95,8 +95,8 @@ func TestAllCustom(t *testing.T) { func TestHostSelector(t *testing.T) { sd, err := NewStreamDialer(&transport.TCPStreamDialer{}, - IfHostPort("dns.google", 0, WithSNI("decoy.example.com")), - IfHostPort("www.youtube.com", 0, WithSNI("notyoutube.com")), + IfHost("dns.google", WithSNI("decoy.example.com")), + IfHost("www.youtube.com", WithSNI("notyoutube.com")), ) require.NoError(t, err) @@ -113,35 +113,14 @@ func TestHostSelector(t *testing.T) { conn.Close() } -func TestPortSelector(t *testing.T) { - sd, err := NewStreamDialer(&transport.TCPStreamDialer{}, - IfHostPort("", 443, WithALPN([]string{"http/1.1"})), - IfHostPort("www.google.com", 443, WithALPN([]string{"h2"})), - IfHostPort("", 853, WithALPN([]string{"dot"})), - ) - require.NoError(t, err) - - conn, err := sd.Dial(context.Background(), "dns.google:443") - require.NoError(t, err) - tlsConn := conn.(streamConn) - require.Equal(t, "http/1.1", tlsConn.ConnectionState().NegotiatedProtocol) - conn.Close() - - conn, err = sd.Dial(context.Background(), "www.google.com:443") - require.NoError(t, err) - tlsConn = conn.(streamConn) - require.Equal(t, "h2", tlsConn.ConnectionState().NegotiatedProtocol) - conn.Close() -} - func TestWithSNI(t *testing.T) { var cfg ClientConfig - WithSNI("example.com")("", 0, &cfg) + WithSNI("example.com")("", &cfg) require.Equal(t, "example.com", cfg.ServerName) } func TestWithALPN(t *testing.T) { var cfg ClientConfig - WithALPN([]string{"h2", "http/1.1"})("", 0, &cfg) + WithALPN([]string{"h2", "http/1.1"})("", &cfg) require.Equal(t, []string{"h2", "http/1.1"}, cfg.NextProtos) }