Skip to content

Commit

Permalink
Enhance
Browse files Browse the repository at this point in the history
  • Loading branch information
fortuna committed Oct 24, 2023
1 parent 611c66f commit dbcd07d
Showing 1 changed file with 154 additions and 54 deletions.
208 changes: 154 additions & 54 deletions x/examples/find-strategy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package main

import (
"context"
"crypto/tls"
"flag"
"fmt"
"io"
Expand All @@ -32,21 +33,16 @@ import (

var debugLog log.Logger = *log.New(io.Discard, "", 0)

type dialFunc func(context.Context, string) (net.Conn, error)

func makeDialFunc[C net.Conn](d func(context.Context, string) (C, error)) dialFunc {
return func(ctx context.Context, addr string) (net.Conn, error) {
return d(ctx, addr)
}
}
type connectFunc func(context.Context) (net.Conn, error)

type dnsClient struct {
dial dialFunc
client dns.Client
id string
connect connectFunc
client dns.Client
}

func (c *dnsClient) ExchangeContext(ctx context.Context, m *dns.Msg, addr string) (r *dns.Msg, rtt time.Duration, err error) {
conn, err := c.dial(ctx, addr)
func (c *dnsClient) ExchangeContext(ctx context.Context, m *dns.Msg) (r *dns.Msg, rtt time.Duration, err error) {
conn, err := c.connect(ctx)
if err != nil {
return nil, 0, err
}
Expand All @@ -56,20 +52,48 @@ func (c *dnsClient) ExchangeContext(ctx context.Context, m *dns.Msg, addr string
return c.client.ExchangeWithConnContext(ctx, m, dnsConn)
}

func makeTCPClient(sd transport.StreamDialer) *dnsClient {
func newTCPClient(sd transport.StreamDialer, resolverAddr string) *dnsClient {
return &dnsClient{
dial: makeDialFunc(sd.Dial),
id: fmt.Sprintf("tcp:%v", resolverAddr),
connect: func(ctx context.Context) (net.Conn, error) {
return sd.Dial(ctx, resolverAddr)
},
client: dns.Client{Net: "tcp"},
}
}

func makeUDPClient(pd transport.PacketDialer) *dnsClient {
func newUDPClient(pd transport.PacketDialer, resolverAddr string) *dnsClient {
return &dnsClient{
dial: makeDialFunc(pd.Dial),
id: fmt.Sprintf("udp:%v", resolverAddr),
connect: func(ctx context.Context) (net.Conn, error) {
return pd.Dial(ctx, resolverAddr)
},
client: dns.Client{Net: ""},
}
}

func newTLSClient(sd transport.StreamDialer, resolverAddr string) (*dnsClient, error) {
host, _, err := net.SplitHostPort(resolverAddr)
if err != nil {
return nil, fmt.Errorf("could not parse resolver address: %v", err)
}
return &dnsClient{
id: fmt.Sprintf("tls:%v", resolverAddr),
connect: func(ctx context.Context) (net.Conn, error) {
conn, err := sd.Dial(ctx, resolverAddr)
if err != nil {
return nil, err
}
tlsConn := tls.Client(conn, &tls.Config{
ServerName: host,
// NextProtos: []string{"dot"},
})
return tlsConn, err
},
client: dns.Client{Net: "tcp"},
}, nil
}

func mixCase(domain string) string {
var mixed []rune
for i, r := range domain {
Expand All @@ -82,39 +106,46 @@ func mixCase(domain string) string {
return string(mixed)
}

func main() {
verboseFlag := flag.Bool("v", false, "Enable debug output")
// typeFlag := flag.String("type", "A", "The type of the query (A, AAAA, CNAME, NS or TXT).")
// resolverFlag := flag.String("resolver", "", "The address of the recursive DNS resolver to use in host:port format. If the port is missing, it's assumed to be 53")
transportFlag := flag.String("transport", "", "The transport for the connection to the recursive DNS resolver")
// tcpFlag := flag.Bool("tcp", false, "Force TCP when querying the DNS resolver")
domainFlag := flag.String("domain", "", "The test domain to find strategies")

flag.Parse()
if *verboseFlag {
debugLog = *log.New(os.Stderr, "[DEBUG] ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
}

clients := []*dnsClient{}
packetDialer, err := config.NewPacketDialer(*transportFlag)
func getARootNameserver() (string, error) {
nsList, err := net.LookupNS(".")
if err != nil {
log.Fatalf("Could not create packet dialer: %v", err)
return "", fmt.Errorf("could not get list of root nameservers: %v", err)
}
clients = append(clients, makeUDPClient(packetDialer))
streamDialer, err := config.NewStreamDialer(*transportFlag)
if err != nil {
log.Fatalf("Could not create stream dialer: %v", err)
if len(nsList) == 0 {
return "", fmt.Errorf("empty list of root nameservers")
}
clients = append(clients, makeTCPClient(streamDialer))
return nsList[0].Host, nil
}

nsList, err := net.LookupNS(".")
// type clientTester struct {
// testDomain string
// resolvedNS string
// }

// func (t *clientTester) TestClient(client *dnsClient) (int, error) {
// score := 0
// // A query to root domain should return no answer.
// var request dns.Msg
// requestDomain := mixCase(t.testDomain)
// request.SetQuestion(requestDomain, dns.TypeA)
// response, _, err := client.ExchangeContext(context.Background(), &request, t.resolvedNS)
// if err != nil {
// score -= 1
// fmt.Printf("; status=error: %v\n", err)
// }
// debugLog.Printf(";Response: %v", response)
// if len(response.Answer) > 0 {
// score -= 1
// fmt.Printf("; status=unexpected answer: %v\n", response.Answer)
// // TODO: use as blocking fingerprint.
// }
// }

func fingerprint(pd transport.PacketDialer, sd transport.StreamDialer, testDomain string) {
rootNS, err := getARootNameserver()
if err != nil {
log.Fatalf("Could not get list of root nameservers: %v", err)
}
if len(nsList) == 0 {
log.Fatalf("Empty list of root nameservers")
log.Fatalf("Failed to find root nameserver: %v", err)
}
rootNS := nsList[0].Host
debugLog.Printf("Root nameserver is %v", rootNS)

allNSIPs, err := net.LookupIP(rootNS)
Expand All @@ -135,27 +166,24 @@ func main() {
}
}

testDomain := dns.Fqdn(*domainFlag)

for _, rootNSIP := range ips {
resolvedNS := net.JoinHostPort(rootNSIP.String(), "53")
for _, client := range clients {
switch client.client.Net {
for _, proto := range []string{"udp", "tcp"} {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
var client *dnsClient
switch proto {
case "tcp":
fmt.Print("proto=tcp")
client = newTCPClient(sd, resolvedNS)
default:
fmt.Print("proto=udp")
}
if rootNSIP.To4() != nil {
fmt.Print("4")
} else {
fmt.Print("6")
client = newUDPClient(pd, resolvedNS)
}

var request dns.Msg
requestDomain := mixCase(testDomain)
request.SetQuestion(requestDomain, dns.TypeA)
response, _, err := client.ExchangeContext(context.Background(), &request, resolvedNS)
response, _, err := client.ExchangeContext(ctx, &request)
fmt.Print(client.id)
if err != nil {
fmt.Printf("; status=error: %v\n", err)
continue
Expand All @@ -173,6 +201,78 @@ func main() {
fmt.Print("; status=ok\n")
}
}
}

func main() {
verboseFlag := flag.Bool("v", false, "Enable debug output")
// typeFlag := flag.String("type", "A", "The type of the query (A, AAAA, CNAME, NS or TXT).")
// resolverFlag := flag.String("resolver", "", "The address of the recursive DNS resolver to use in host:port format. If the port is missing, it's assumed to be 53")
transportFlag := flag.String("transport", "", "The transport for the connection to the recursive DNS resolver")
// tcpFlag := flag.Bool("tcp", false, "Force TCP when querying the DNS resolver")
domainFlag := flag.String("domain", "", "The test domain to find strategies")

flag.Parse()
if *verboseFlag {
debugLog = *log.New(os.Stderr, "[DEBUG] ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
}

if *domainFlag == "" {
log.Fatal("Must specify flag --domain")
}
testDomain := dns.Fqdn(*domainFlag)

packetDialer, err := config.NewPacketDialer(*transportFlag)
if err != nil {
log.Fatalf("Could not create packet dialer: %v", err)
}
streamDialer, err := config.NewStreamDialer(*transportFlag)
if err != nil {
log.Fatalf("Could not create stream dialer: %v", err)
}
fmt.Println("Fingerprinting")
fingerprint(packetDialer, streamDialer, testDomain)

clients := []*dnsClient{}
for _, resolverAddr := range []string{"8.8.8.8:53", "[2001:4860:4860::8888]:53", "1.1.1.1:53", "9.9.9.9:53", "9.9.9.9:9953", "208.67.222.222:53", "208.67.222.222:443"} {
clients = append(clients, newUDPClient(packetDialer, resolverAddr))
clients = append(clients, newTCPClient(streamDialer, resolverAddr))
}
for _, resolverAddr := range []string{"dns.google:853", "8.8.8.8:853", "[2001:4860:4860::8888]:853", "1.1.1.1:853", "9.9.9.9:853", "wikimedia-dns.org:853"} {
client, err := newTLSClient(streamDialer, resolverAddr)
if err != nil {
log.Fatalf("Failed to create TLS client: %v", err)
}
clients = append(clients, client)
}

fmt.Println("Finding strategies")
for _, client := range clients {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
fmt.Printf("client=%v", client.id)
var request dns.Msg
requestDomain := mixCase(testDomain)
request.SetQuestion(requestDomain, dns.TypeA)
response, _, err := client.ExchangeContext(ctx, &request)
if err != nil {
fmt.Printf("; status=error: %v\n", err)
continue
}
debugLog.Printf(";Response: %v", response)
if len(response.Answer) == 0 {
fmt.Printf("; status=no answers")
continue
}
if response.Answer[0].Header().Name != requestDomain {
fmt.Printf("; status=domain mismatch: %v\n", response.Answer[0])
continue
}
// if response.Answer[0].Header().Name != requestDomain {
// fmt.Print("; status=case mismatch\n")
// continue
// }
fmt.Print("; status=ok\n")
}

// TODO:
// Go over list of public resolvers, restricted to working categories.
Expand Down

0 comments on commit dbcd07d

Please sign in to comment.