diff --git a/x/examples/find-strategy/main.go b/x/examples/find-strategy/main.go index fd749c4f..51b958d3 100644 --- a/x/examples/find-strategy/main.go +++ b/x/examples/find-strategy/main.go @@ -16,6 +16,7 @@ package main import ( "context" + "crypto/tls" "flag" "fmt" "io" @@ -32,21 +33,16 @@ import ( var debugLog log.Logger = *log.New(io.Discard, "", 0) -type dialFunc func(context.Context, string) (net.Conn, error) - -func makeDialFunc[C net.Conn](d func(context.Context, string) (C, error)) dialFunc { - return func(ctx context.Context, addr string) (net.Conn, error) { - return d(ctx, addr) - } -} +type connectFunc func(context.Context) (net.Conn, error) type dnsClient struct { - dial dialFunc - client dns.Client + id string + connect connectFunc + client dns.Client } -func (c *dnsClient) ExchangeContext(ctx context.Context, m *dns.Msg, addr string) (r *dns.Msg, rtt time.Duration, err error) { - conn, err := c.dial(ctx, addr) +func (c *dnsClient) ExchangeContext(ctx context.Context, m *dns.Msg) (r *dns.Msg, rtt time.Duration, err error) { + conn, err := c.connect(ctx) if err != nil { return nil, 0, err } @@ -56,20 +52,48 @@ func (c *dnsClient) ExchangeContext(ctx context.Context, m *dns.Msg, addr string return c.client.ExchangeWithConnContext(ctx, m, dnsConn) } -func makeTCPClient(sd transport.StreamDialer) *dnsClient { +func newTCPClient(sd transport.StreamDialer, resolverAddr string) *dnsClient { return &dnsClient{ - dial: makeDialFunc(sd.Dial), + id: fmt.Sprintf("tcp:%v", resolverAddr), + connect: func(ctx context.Context) (net.Conn, error) { + return sd.Dial(ctx, resolverAddr) + }, client: dns.Client{Net: "tcp"}, } } -func makeUDPClient(pd transport.PacketDialer) *dnsClient { +func newUDPClient(pd transport.PacketDialer, resolverAddr string) *dnsClient { return &dnsClient{ - dial: makeDialFunc(pd.Dial), + id: fmt.Sprintf("udp:%v", resolverAddr), + connect: func(ctx context.Context) (net.Conn, error) { + return pd.Dial(ctx, resolverAddr) + }, client: dns.Client{Net: ""}, } } +func newTLSClient(sd transport.StreamDialer, resolverAddr string) (*dnsClient, error) { + host, _, err := net.SplitHostPort(resolverAddr) + if err != nil { + return nil, fmt.Errorf("could not parse resolver address: %v", err) + } + return &dnsClient{ + id: fmt.Sprintf("tls:%v", resolverAddr), + connect: func(ctx context.Context) (net.Conn, error) { + conn, err := sd.Dial(ctx, resolverAddr) + if err != nil { + return nil, err + } + tlsConn := tls.Client(conn, &tls.Config{ + ServerName: host, + // NextProtos: []string{"dot"}, + }) + return tlsConn, err + }, + client: dns.Client{Net: "tcp"}, + }, nil +} + func mixCase(domain string) string { var mixed []rune for i, r := range domain { @@ -82,39 +106,46 @@ func mixCase(domain string) string { return string(mixed) } -func main() { - verboseFlag := flag.Bool("v", false, "Enable debug output") - // typeFlag := flag.String("type", "A", "The type of the query (A, AAAA, CNAME, NS or TXT).") - // resolverFlag := flag.String("resolver", "", "The address of the recursive DNS resolver to use in host:port format. If the port is missing, it's assumed to be 53") - transportFlag := flag.String("transport", "", "The transport for the connection to the recursive DNS resolver") - // tcpFlag := flag.Bool("tcp", false, "Force TCP when querying the DNS resolver") - domainFlag := flag.String("domain", "", "The test domain to find strategies") - - flag.Parse() - if *verboseFlag { - debugLog = *log.New(os.Stderr, "[DEBUG] ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile) - } - - clients := []*dnsClient{} - packetDialer, err := config.NewPacketDialer(*transportFlag) +func getARootNameserver() (string, error) { + nsList, err := net.LookupNS(".") if err != nil { - log.Fatalf("Could not create packet dialer: %v", err) + return "", fmt.Errorf("could not get list of root nameservers: %v", err) } - clients = append(clients, makeUDPClient(packetDialer)) - streamDialer, err := config.NewStreamDialer(*transportFlag) - if err != nil { - log.Fatalf("Could not create stream dialer: %v", err) + if len(nsList) == 0 { + return "", fmt.Errorf("empty list of root nameservers") } - clients = append(clients, makeTCPClient(streamDialer)) + return nsList[0].Host, nil +} - nsList, err := net.LookupNS(".") +// type clientTester struct { +// testDomain string +// resolvedNS string +// } + +// func (t *clientTester) TestClient(client *dnsClient) (int, error) { +// score := 0 +// // A query to root domain should return no answer. +// var request dns.Msg +// requestDomain := mixCase(t.testDomain) +// request.SetQuestion(requestDomain, dns.TypeA) +// response, _, err := client.ExchangeContext(context.Background(), &request, t.resolvedNS) +// if err != nil { +// score -= 1 +// fmt.Printf("; status=error: %v\n", err) +// } +// debugLog.Printf(";Response: %v", response) +// if len(response.Answer) > 0 { +// score -= 1 +// fmt.Printf("; status=unexpected answer: %v\n", response.Answer) +// // TODO: use as blocking fingerprint. +// } +// } + +func fingerprint(pd transport.PacketDialer, sd transport.StreamDialer, testDomain string) { + rootNS, err := getARootNameserver() if err != nil { - log.Fatalf("Could not get list of root nameservers: %v", err) - } - if len(nsList) == 0 { - log.Fatalf("Empty list of root nameservers") + log.Fatalf("Failed to find root nameserver: %v", err) } - rootNS := nsList[0].Host debugLog.Printf("Root nameserver is %v", rootNS) allNSIPs, err := net.LookupIP(rootNS) @@ -135,27 +166,24 @@ func main() { } } - testDomain := dns.Fqdn(*domainFlag) - for _, rootNSIP := range ips { resolvedNS := net.JoinHostPort(rootNSIP.String(), "53") - for _, client := range clients { - switch client.client.Net { + for _, proto := range []string{"udp", "tcp"} { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var client *dnsClient + switch proto { case "tcp": - fmt.Print("proto=tcp") + client = newTCPClient(sd, resolvedNS) default: - fmt.Print("proto=udp") - } - if rootNSIP.To4() != nil { - fmt.Print("4") - } else { - fmt.Print("6") + client = newUDPClient(pd, resolvedNS) } var request dns.Msg requestDomain := mixCase(testDomain) request.SetQuestion(requestDomain, dns.TypeA) - response, _, err := client.ExchangeContext(context.Background(), &request, resolvedNS) + response, _, err := client.ExchangeContext(ctx, &request) + fmt.Print(client.id) if err != nil { fmt.Printf("; status=error: %v\n", err) continue @@ -173,6 +201,78 @@ func main() { fmt.Print("; status=ok\n") } } +} + +func main() { + verboseFlag := flag.Bool("v", false, "Enable debug output") + // typeFlag := flag.String("type", "A", "The type of the query (A, AAAA, CNAME, NS or TXT).") + // resolverFlag := flag.String("resolver", "", "The address of the recursive DNS resolver to use in host:port format. If the port is missing, it's assumed to be 53") + transportFlag := flag.String("transport", "", "The transport for the connection to the recursive DNS resolver") + // tcpFlag := flag.Bool("tcp", false, "Force TCP when querying the DNS resolver") + domainFlag := flag.String("domain", "", "The test domain to find strategies") + + flag.Parse() + if *verboseFlag { + debugLog = *log.New(os.Stderr, "[DEBUG] ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile) + } + + if *domainFlag == "" { + log.Fatal("Must specify flag --domain") + } + testDomain := dns.Fqdn(*domainFlag) + + packetDialer, err := config.NewPacketDialer(*transportFlag) + if err != nil { + log.Fatalf("Could not create packet dialer: %v", err) + } + streamDialer, err := config.NewStreamDialer(*transportFlag) + if err != nil { + log.Fatalf("Could not create stream dialer: %v", err) + } + fmt.Println("Fingerprinting") + fingerprint(packetDialer, streamDialer, testDomain) + + clients := []*dnsClient{} + for _, resolverAddr := range []string{"8.8.8.8:53", "[2001:4860:4860::8888]:53", "1.1.1.1:53", "9.9.9.9:53", "9.9.9.9:9953", "208.67.222.222:53", "208.67.222.222:443"} { + clients = append(clients, newUDPClient(packetDialer, resolverAddr)) + clients = append(clients, newTCPClient(streamDialer, resolverAddr)) + } + for _, resolverAddr := range []string{"dns.google:853", "8.8.8.8:853", "[2001:4860:4860::8888]:853", "1.1.1.1:853", "9.9.9.9:853", "wikimedia-dns.org:853"} { + client, err := newTLSClient(streamDialer, resolverAddr) + if err != nil { + log.Fatalf("Failed to create TLS client: %v", err) + } + clients = append(clients, client) + } + + fmt.Println("Finding strategies") + for _, client := range clients { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + fmt.Printf("client=%v", client.id) + var request dns.Msg + requestDomain := mixCase(testDomain) + request.SetQuestion(requestDomain, dns.TypeA) + response, _, err := client.ExchangeContext(ctx, &request) + if err != nil { + fmt.Printf("; status=error: %v\n", err) + continue + } + debugLog.Printf(";Response: %v", response) + if len(response.Answer) == 0 { + fmt.Printf("; status=no answers") + continue + } + if response.Answer[0].Header().Name != requestDomain { + fmt.Printf("; status=domain mismatch: %v\n", response.Answer[0]) + continue + } + // if response.Answer[0].Header().Name != requestDomain { + // fmt.Print("; status=case mismatch\n") + // continue + // } + fmt.Print("; status=ok\n") + } // TODO: // Go over list of public resolvers, restricted to working categories.