Skip to content

Commit

Permalink
Add system resolver
Browse files Browse the repository at this point in the history
  • Loading branch information
fortuna committed Dec 13, 2023
1 parent d270870 commit 23753f3
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 22 deletions.
2 changes: 2 additions & 0 deletions internal/dns/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import (
"net"
)

// TODO: Is this the right interface? We need to process the IPs as soon as we get them, so
// we need a mechanism to stream the IPs.
type Resolver interface {
LookupIP(ctx context.Context, domain string) ([]net.IP, error)
}
Expand Down
2 changes: 2 additions & 0 deletions x/examples/smart-proxy/config.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
{
"dns": [
{"system": {}},

{"https": {"name": "2620:fe::fe"}, "//": "Quad9"},
{"https": {"name": "9.9.9.9"}},
{"https": {"name": "149.112.112.112"}},
Expand Down
46 changes: 24 additions & 22 deletions x/smart/finder.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,11 @@ func (f *StrategyFinder) testDNSClient(ctx context.Context, dnsRT dns.RoundTripp
ctx, cancel := context.WithTimeout(ctx, f.TestTimeout)
defer cancel()

// We special case the system resolver, since we can't get a dns.RoundTripper.
if dnsRT == nil {
return evaluateNetResolver(ctx, new(net.Resolver), testDomain)
}

var request miekgdns.Msg
requestDomain := mixCase(testDomain)
request.SetQuestion(requestDomain, miekgdns.TypeA)
Expand Down Expand Up @@ -341,7 +346,9 @@ type configJSON struct {
}

func (f *StrategyFinder) newDNSRoundTriperFromEntry(entry dnsEntryJSON) (dns.RoundTripper, error) {
if cfg := entry.HTTPS; cfg != nil {
if entry.System != nil {
return nil, nil
} else if cfg := entry.HTTPS; cfg != nil {
if cfg.Name == "" {
return nil, fmt.Errorf("https entry has empty server name")
}
Expand Down Expand Up @@ -382,6 +389,7 @@ func (f *StrategyFinder) newDNSRoundTriperFromEntry(entry dnsEntryJSON) (dns.Rou
}

func (f *StrategyFinder) findDNS(testDomain string, dnsConfig []dnsEntryJSON) (dns.Resolver, error) {
var found bool
var selectedResolver dns.Resolver
var doneMu sync.Mutex
baseContext, stopFind := context.WithCancelCause(context.Background())
Expand All @@ -401,7 +409,7 @@ func (f *StrategyFinder) findDNS(testDomain string, dnsConfig []dnsEntryJSON) (d
id := string(idBytes)
dnsRT, err := f.newDNSRoundTriperFromEntry(entry)
if err != nil {
stopFind(fmt.Errorf("failed to process entry %v: %w", id, err))
stopFind(fmt.Errorf("failed to process entry %v: %w", ei, err))
return
}
f.log("starting dns %v\n", id)
Expand All @@ -419,33 +427,20 @@ func (f *StrategyFinder) findDNS(testDomain string, dnsConfig []dnsEntryJSON) (d
return
}
f.log(", status=ok ✅\n")
selectedResolver = &miekgDNSResolver{MiekgRoundTripper{dnsRT}}
found = true
if dnsRT != nil {
selectedResolver = &miekgDNSResolver{MiekgRoundTripper{dnsRT}}
}
stopFind(nil)
}
// Try system resolver
// if _, ok := streamDialer.(*transport.TCPStreamDialer); ok {
// f.log("client=system")
// sysResolver := new(net.Resolver)
// ctx, cancel := context.WithTimeout(context.Background(), f.TestTimeout)
// ips, err := evaluateNetResolver(ctx, sysResolver, testDomain)
// cancel()
// if err != nil {
// f.log("; status=%v ❌\n", err)
// } else {
// f.log("; status=ok (%v) ✅\n", ips)
// workingClients = append(workingClients, dns.NewStreamDialer(dns.FuncResolver(func(ctx context.Context, domain string) ([]net.IP, error) {
// return sysResolver.LookupIP(ctx, "ip", domain)
// }), streamDialer))
// }
// }
}(ei, entry)
continue
case <-baseContext.Done():
}
break
}
<-baseContext.Done()
if selectedResolver != nil {
if found {
return selectedResolver, nil
}
err := context.Cause(baseContext)
Expand Down Expand Up @@ -519,8 +514,15 @@ func (f *StrategyFinder) NewDialer(ctx context.Context, testDomain string, confi
if err != nil {
return nil, err
}

dnsDialer := dns.NewStreamDialer(resolver, f.StreamDialer)
var dnsDialer transport.StreamDialer
if resolver == nil {
if _, ok := f.StreamDialer.(*transport.TCPStreamDialer); !ok {
return nil, fmt.Errorf("cannot use system resolver with base dialer of type %T", f.StreamDialer)
}
dnsDialer = f.StreamDialer
} else {
dnsDialer = dns.NewStreamDialer(resolver, f.StreamDialer)
}
return f.findTLS(testDomain, dnsDialer, parsedConfig.TLS)
}

Expand Down

0 comments on commit 23753f3

Please sign in to comment.