Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
fortuna committed Dec 28, 2023
1 parent 945ed28 commit 68acbf4
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 80 deletions.
37 changes: 37 additions & 0 deletions dns/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright 2023 Jigsaw Operations LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/*
Package dns provides utilities to interact with the Domain Name System (DNS).
The [Domain Name System] (DNS) is responsible for mapping domain names to IP addresses.
Because domain resolution gatekeeps connections and is predominantly done in plaintext, it is commonly used
for network-level filtering.
The main concept in this library is that of a [Resolver], which allows code to query the DNS. Different implementations are provided
to perform DNS resolution over multiple transports:
- DNS-over-UDP: the standard mechanism of querying resolvers. Communication is done in plaintext, using port 53.
- [DNS-over-TCP]: alternative to UDP when responses are large. Communication is done in plaintext, using port 53.
- [DNS-over-TLS] (DoT): uses the TCP protocol, but over a connection encrypted with TLS. Is uses port 853, which
makes it very easy to block using the port number, as no other protocol is assigned to that port.
- [DNS-over-HTTPS] (DoH): uses HTTP exchanges for querying the resolver and communicates over a connection encrypted with TLS. It uses
port 443. That makes the DoH traffic undistinguishable from web traffic, making it harder to block.
[Domain Name System]: https://datatracker.ietf.org/doc/html/rfc1034
[DNS-over-TCP]: https://datatracker.ietf.org/doc/html/rfc7766
[DNS-over-TLS]: https://datatracker.ietf.org/doc/html/rfc7858
[DNS-over-HTTPS]: https://datatracker.ietf.org/doc/html/rfc8484
*/
package dns
162 changes: 82 additions & 80 deletions dns/roundtrip.go → dns/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,17 @@ import (
"golang.org/x/net/dns/dnsmessage"
)

// RoundTripper is an interface representing the ability to execute a
// single DNS transaction, obtaining the Response for a given Request.
// Resolver can query the DNS with a question, and obtain a DNS message as response.
// This abstraction helps hide the underlying transport protocol.
type RoundTripper interface {
RoundTrip(ctx context.Context, q dnsmessage.Question) (*dnsmessage.Message, error)
type Resolver interface {
Query(ctx context.Context, q dnsmessage.Question) (*dnsmessage.Message, error)
}

// FuncRoundTripper is a [RoundTripper] that uses the given function for the round trip.
type FuncRoundTripper func(ctx context.Context, q dnsmessage.Question) (*dnsmessage.Message, error)
// FuncResolver is a [Resolver] that uses the given function to query DNS.
type FuncResolver func(ctx context.Context, q dnsmessage.Question) (*dnsmessage.Message, error)

// RoundTrip implements the [RoundTripper] interface.
func (f FuncRoundTripper) RoundTrip(ctx context.Context, q dnsmessage.Question) (*dnsmessage.Message, error) {
// Query implements the [Resolver] interface.
func (f FuncResolver) Query(ctx context.Context, q dnsmessage.Question) (*dnsmessage.Message, error) {
return f(ctx, q)
}

Expand All @@ -60,22 +59,55 @@ func NewQuestion(domain string, qtype dnsmessage.Type) (*dnsmessage.Question, er
}, nil
}

const maxMsgSize = 65535
// Maximum DNS packet size.
// Value taken from https://dnsflagday.net/2020/.
const maxDNSPacketSize = 1232

// Creates a DNS request using the id and question and appends the bytes to buf.
func appendRequest(id uint16, q dnsmessage.Question, buf []byte) ([]byte, error) {
b := dnsmessage.NewBuilder(buf, dnsmessage.Header{ID: id, RecursionDesired: true})
if err := b.StartQuestions(); err != nil {
return nil, fmt.Errorf("failed to start questions: %w", err)
}
if err := b.Question(q); err != nil {
return nil, fmt.Errorf("failed to add question: %w", err)
}
if err := b.StartAdditionals(); err != nil {
return nil, fmt.Errorf("failed to start additionals: %w", err)
}

var rh dnsmessage.ResourceHeader
// Set the maximum payload size we support, as per https://datatracker.ietf.org/doc/html/rfc6891#section-4.3
if err := rh.SetEDNS0(maxDNSPacketSize, dnsmessage.RCodeSuccess, false); err != nil {
return nil, fmt.Errorf("failed to set EDNS(0) parameters: %w", err)
}
if err := b.OPTResource(rh, dnsmessage.OPTResource{}); err != nil {
return nil, fmt.Errorf("failed to add OPT RR: %w", err)
}

buf, err := b.Finish()
if err != nil {
return nil, fmt.Errorf("failed to serialize message: %w", err)
}
return buf, nil
}

// Fold case as clarified in https://datatracker.ietf.org/doc/html/rfc4343#section-3.
func foldCase(char byte) byte {
if 'a' <= char && char <= 'z' {
return char - 'a' + 'A'
}
return char
}

// equalASCIIName compares DNS name as specified in https://datatracker.ietf.org/doc/html/rfc1035#section-3.1 and
// https://datatracker.ietf.org/doc/html/rfc4343#section-3.
func equalASCIIName(x, y dnsmessage.Name) bool {
if x.Length != y.Length {
return false
}
for i := 0; i < int(x.Length); i++ {
a := x.Data[i]
b := y.Data[i]
if 'A' <= a && a <= 'Z' {
a += 0x20
}
if 'A' <= b && b <= 'Z' {
b += 0x20
}
if a != b {
if foldCase(x.Data[i]) != foldCase(y.Data[i]) {
return false
}
}
Expand Down Expand Up @@ -104,12 +136,14 @@ func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage
return nil
}

// Implements a DNS exchange over a stream protocol. It frames the messages by prepending them with a 2-byte length prefix.
func dnsStreamRoundtrip(conn io.ReadWriter, q dnsmessage.Question) (*dnsmessage.Message, error) {
const maxMsgSize = 65535

// queryStream implements a DNS query over a stream protocol. It frames the messages by prepending them with a 2-byte length prefix.
func queryStream(conn io.ReadWriter, q dnsmessage.Question) (*dnsmessage.Message, error) {
id := uint16(rand.Uint32())
buf, err := appendRequest(id, q, make([]byte, 2, 514))
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to create request: %w", err)
}
if len(buf) > maxMsgSize {
return nil, fmt.Errorf("message too large: %v bytes", len(buf))
Expand All @@ -121,7 +155,7 @@ func dnsStreamRoundtrip(conn io.ReadWriter, q dnsmessage.Question) (*dnsmessage.
}
var msgLen uint16
if err := binary.Read(conn, binary.BigEndian, &msgLen); err != nil {
return nil, fmt.Errorf("failed to read message length: %v", err)
return nil, fmt.Errorf("failed to read message length: %w", err)
}
if int(msgLen) <= cap(buf) {
buf = buf[:msgLen]
Expand All @@ -141,12 +175,12 @@ func dnsStreamRoundtrip(conn io.ReadWriter, q dnsmessage.Question) (*dnsmessage.
return &msg, nil
}

// Implements a DNS exchange over a datagram protocol.
func dnsPacketRoundtrip(conn io.ReadWriter, q dnsmessage.Question) (*dnsmessage.Message, error) {
// queryDatagram implements a DNS query over a datagram protocol.
func queryDatagram(conn io.ReadWriter, q dnsmessage.Question) (*dnsmessage.Message, error) {
id := uint16(rand.Uint32())
buf, err := appendRequest(id, q, make([]byte, 0, 512))
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to create request: %w", err)
}
if len(buf) > maxMsgSize {
return nil, fmt.Errorf("message too large: %v bytes", len(buf))
Expand Down Expand Up @@ -176,109 +210,77 @@ func dnsPacketRoundtrip(conn io.ReadWriter, q dnsmessage.Question) (*dnsmessage.
}
}

// Maximum DNS packet size.
// Value taken from https://dnsflagday.net/2020/.
const maxDNSPacketSize = 1232

// Creates a DNS request using the id and question and appends the bytes to buf.
func appendRequest(id uint16, q dnsmessage.Question, buf []byte) ([]byte, error) {
b := dnsmessage.NewBuilder(buf, dnsmessage.Header{ID: id, RecursionDesired: true})
if err := b.StartQuestions(); err != nil {
return nil, err
}
if err := b.Question(q); err != nil {
return nil, err
}

// Accept packets up to maxDNSPacketSize. RFC 6891.
if err := b.StartAdditionals(); err != nil {
return nil, err
}
var rh dnsmessage.ResourceHeader
if err := rh.SetEDNS0(maxDNSPacketSize, dnsmessage.RCodeSuccess, false); err != nil {
return nil, err
}
if err := b.OPTResource(rh, dnsmessage.OPTResource{}); err != nil {
return nil, err
}

buf, err := b.Finish()
if err != nil {
return nil, err
}
return buf, nil
}

// NewTCPRoundTripper creates a [RoundTripper] that implements the [DNS-over-TCP] protocol, using a [transport.StreamDialer] for transport.
// NewTCPResolver creates a [Resolver] that implements the [DNS-over-TCP] protocol, using a [transport.StreamDialer] for transport.
// It creates a new connection to the resolver for every request.
//
// [DNS-over-TCP]: https://datatracker.ietf.org/doc/html/rfc1035#section-4.2.2
func NewTCPRoundTripper(sd transport.StreamDialer, resolverAddr string) RoundTripper {
func NewTCPResolver(sd transport.StreamDialer, resolverAddr string) Resolver {
// See https://cs.opensource.google/go/go/+/master:src/net/dnsclient_unix.go;l=127;drc=6146a73d279d73b6138191929d2f1fad22188f51
// TODO: Consider handling Authenticated Data.
return FuncRoundTripper(func(ctx context.Context, q dnsmessage.Question) (*dnsmessage.Message, error) {
return FuncResolver(func(ctx context.Context, q dnsmessage.Question) (*dnsmessage.Message, error) {
conn, err := sd.Dial(ctx, resolverAddr)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to dial resolver: %w", err)
}
// TODO: consider keeping the connection open for performance.
// Need to think about security implications.
defer conn.Close()
if deadline, ok := ctx.Deadline(); ok {
conn.SetDeadline(deadline)
}
return dnsStreamRoundtrip(conn, q)
return queryStream(conn, q)
})
}

// NewUDPRoundTripper creates a [RoundTripper] that implements the DNS-over-UDP protocol, using a [transport.PacketDialer] for transport.
// NewUDPResolver creates a [Resolver] that implements the DNS-over-UDP protocol, using a [transport.PacketDialer] for transport.
// It creates a new connection to the resolver for every request.
//
// [DNS-over-UDP]: https://datatracker.ietf.org/doc/html/rfc1035#section-4.2.1
func NewUDPRoundTripper(pd transport.PacketDialer, resolverAddr string) RoundTripper {
func NewUDPResolver(pd transport.PacketDialer, resolverAddr string) Resolver {
// See https://cs.opensource.google/go/go/+/master:src/net/dnsclient_unix.go;l=100;drc=6146a73d279d73b6138191929d2f1fad22188f51
return FuncRoundTripper(func(ctx context.Context, q dnsmessage.Question) (*dnsmessage.Message, error) {
return FuncResolver(func(ctx context.Context, q dnsmessage.Question) (*dnsmessage.Message, error) {
conn, err := pd.Dial(ctx, resolverAddr)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to dial resolver: %w", err)
}
// TODO: reuse connection, as per https://datatracker.ietf.org/doc/html/rfc7766#section-6.2.1.
defer conn.Close()
if deadline, ok := ctx.Deadline(); ok {
conn.SetDeadline(deadline)
}
return dnsPacketRoundtrip(conn, q)
return queryDatagram(conn, q)
})
}

// NewTLSRoundTripper creates a [RoundTripper] that implements the [DNS-over-TLS] protocol, using a [transport.StreamDialer]
// NewTLSResolver creates a [Resolver] that implements the [DNS-over-TLS] protocol, using a [transport.StreamDialer]
// to connect to the resolverAddr the the resolverName as the TLS server name.
// It creates a new connection to the resolver for every request.
//
// [DNS-over-TLS]: https://datatracker.ietf.org/doc/html/rfc7858
func NewTLSRoundTripper(sd transport.StreamDialer, resolverAddr string, resolverName string) RoundTripper {
return FuncRoundTripper(func(ctx context.Context, q dnsmessage.Question) (*dnsmessage.Message, error) {
func NewTLSResolver(sd transport.StreamDialer, resolverAddr string, resolverName string) Resolver {
return FuncResolver(func(ctx context.Context, q dnsmessage.Question) (*dnsmessage.Message, error) {
baseConn, err := sd.Dial(ctx, resolverAddr)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to dial resolver: %w", err)
}
tlsConn := tls.Client(baseConn, &tls.Config{
ServerName: resolverName,
})
// TODO: keep connection open. Need to handle concurrent requests.
// TODO: reuse connection, as per https://datatracker.ietf.org/doc/html/rfc7766#section-6.2.1.
defer tlsConn.Close()
if deadline, ok := ctx.Deadline(); ok {
tlsConn.SetDeadline(deadline)
}
return dnsStreamRoundtrip(tlsConn, q)
return queryStream(tlsConn, q)
})
}

// NewHTTPSRoundTripper creates a [RoundTripper] that implements the [DNS-over-HTTPS] protocol, using a [transport.StreamDialer]
// NewHTTPSResolver creates a [Resolver] that implements the [DNS-over-HTTPS] protocol, using a [transport.StreamDialer]
// to connect to the resolverAddr the url as the DoH template URI.
// It uses an internal HTTP client that reuses connections when possible.
//
// [DNS-over-HTTPS]: https://datatracker.ietf.org/doc/html/rfc8484
func NewHTTPSRoundTripper(sd transport.StreamDialer, resolverAddr string, url string) RoundTripper {
func NewHTTPSResolver(sd transport.StreamDialer, resolverAddr string, url string) Resolver {
dialContext := func(ctx context.Context, network, addr string) (net.Conn, error) {
if !strings.HasPrefix(network, "tcp") {
// TODO: Support UDP for QUIC.
Expand All @@ -295,21 +297,21 @@ func NewHTTPSRoundTripper(sd transport.StreamDialer, resolverAddr string, url st
ResponseHeaderTimeout: 20 * time.Second, // Same value as Android DNS-over-TLS
},
}
return FuncRoundTripper(func(ctx context.Context, q dnsmessage.Question) (*dnsmessage.Message, error) {
return FuncResolver(func(ctx context.Context, q dnsmessage.Question) (*dnsmessage.Message, error) {
buf, err := appendRequest(0, q, make([]byte, 0, 512))
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to create DNS request: %w", err)
}
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(buf))
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
}
const mimetype = "application/dns-message"
httpReq.Header.Add("Accept", mimetype)
httpReq.Header.Add("Content-Type", mimetype)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to get HTTP response: %w", err)
}
defer httpResp.Body.Close()
if httpResp.StatusCode != http.StatusOK {
Expand Down

0 comments on commit 68acbf4

Please sign in to comment.