Skip to content

Commit

Permalink
Code review updates
Browse files Browse the repository at this point in the history
  • Loading branch information
oxtoacart committed Oct 11, 2023
1 parent 3b9fb26 commit 705d649
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 19 deletions.
47 changes: 29 additions & 18 deletions custom_dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net"
"sync"
"time"

"github.com/getlantern/errors"
Expand All @@ -17,7 +18,7 @@ const (

// Returns a dialer that uses custom DNS servers to resolve the host. It uses all DNS servers
// in parallel and uses the first response it gets.
func customDNSDialer(dnsServers []string, timeout time.Duration) (func(context.Context, string, string) (net.Conn, error), error) {
func customDNSDialer(dnsServers []string, timeoutToDialOrigin time.Duration) (func(context.Context, string, string) (net.Conn, error), error) {
resolvers := make([]*net.Resolver, 0, len(dnsServers))
if len(dnsServers) == 0 {
log.Debug("Will resolve DNS using system DNS servers")
Expand All @@ -29,13 +30,17 @@ func customDNSDialer(dnsServers []string, timeout time.Duration) (func(context.C
r := &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
return netx.DialContext(ctx, "udp", dnsServer)
return netx.DialContext(ctx, network, dnsServer)
},
}
resolvers = append(resolvers, r)
}
}

dialer := &net.Dialer{
Timeout: timeoutToDialOrigin,
}

dial := func(ctx context.Context, network, addr string) (net.Conn, error) {
// resolve separately so that we can track the DNS resolution time
host, port, err := net.SplitHostPort(addr)
Expand All @@ -49,29 +54,33 @@ func customDNSDialer(dnsServers []string, timeout time.Duration) (func(context.C
errs := make(chan error, len(resolvers))
rctx, cancel := context.WithTimeout(ctx, resolutionTimeout)
defer cancel()
var wg sync.WaitGroup
wg.Add(len(resolvers))
for _, r := range resolvers {
resolveInBackground(rctx, r, host, results, errs)
resolveInBackground(rctx, r, host, &wg, results, errs)
}
select {
case ip = <-results:
// got a result!
case <-time.After(resolutionTimeout):
var resolveErr error
errorCount := 0
deadline := time.After(resolutionTimeout)
resultLoop:
for {
select {
case resolveErr = <-errs:
// got an error
default:
// no error, we just timed out
case ip = <-results:
// got a result!
break resultLoop
case err := <-errs:
errorCount++
if errorCount == len(resolvers) {
// all resolvers failed, stop trying
return nil, errors.New("unable to resolve host %v, last resolution error: %v", host, err)
}
case <-deadline:
return nil, errors.New("unable to resolve host %v, resolution timed out", host)
}
return nil, errors.New("unable to resolve host %v, last resolution error: %v", host, resolveErr)
}
}

resolvedAddr := fmt.Sprintf("%s:%s", ip, port)
d := &net.Dialer{
Deadline: time.Now().Add(timeout),
}
conn, dialErr := d.DialContext(ctx, "tcp", resolvedAddr)
conn, dialErr := dialer.DialContext(ctx, "tcp", resolvedAddr)
if dialErr != nil {
return nil, dialErr
}
Expand All @@ -82,8 +91,10 @@ func customDNSDialer(dnsServers []string, timeout time.Duration) (func(context.C
return dial, nil
}

func resolveInBackground(ctx context.Context, r *net.Resolver, host string, results chan net.IP, errors chan error) {
func resolveInBackground(ctx context.Context, r *net.Resolver, host string, wg *sync.WaitGroup, results chan net.IP, errors chan error) {
go func() {
defer wg.Done()

ips, err := r.LookupIPAddr(ctx, host)
if err != nil {
errors <- err
Expand Down
2 changes: 1 addition & 1 deletion http-proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ var (

track = flag.String("track", "", "The track this proxy is running on")

dnsServers = flag.String("dns-servers", "", "Optional DNS servers (comma separated) to use for DNS lookups (in place of system resolver)")
dnsServers = flag.String("dns-servers", "172.16.0.53:53,8.8.8.8:53", "Optional DNS servers (comma separated) to use for DNS lookups (in place of system resolver)")
)

const (
Expand Down
26 changes: 26 additions & 0 deletions http_proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package proxy

import (
"bufio"
"context"
"crypto/tls"
"flag"
"fmt"
Expand All @@ -17,6 +18,7 @@ import (
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/getlantern/keyman"
"github.com/getlantern/measured"
Expand Down Expand Up @@ -793,3 +795,27 @@ func (mr *mockReporter) Report(ctx map[string]interface{}, stats *measured.Stats
}
mr.traffic[deviceID] = stats
}

func TestCustomDNSSuccess(t *testing.T) {
// Use a working and a non-working (4.4.4.4) DNS server to make sure that DNS resolution doesn't wait for
// the failing one to timeout
d, err := customDNSDialer([]string{"8.8.8.8:53", "4.4.4.4:53"}, 5*time.Second)
require.NoError(t, err)
start := time.Now()
conn, err := d(context.Background(), "tcp", "www.google.com:443")
require.NoError(t, err)
require.Less(t, time.Since(start), 2*time.Second)
conn.Close()
}

func TestCustomDNSFailure(t *testing.T) {
// Use a two working DNS servers to make sure that when all DNS servers fail to find
// a result, we return an error quickly.
d, err := customDNSDialer([]string{"8.8.8.8:53", "8.8.4.4:53"}, 5*time.Second)
require.NoError(t, err)
start := time.Now()
_, err = d(context.Background(), "tcp", "blubbaasdfsadfsadf.dude:443")
require.Error(t, err)
fmt.Println(err)
require.Less(t, time.Since(start), 2*time.Second)
}

0 comments on commit 705d649

Please sign in to comment.