diff --git a/http-proxy/main.go b/http-proxy/main.go index 4cba0694..f63a1ba3 100644 --- a/http-proxy/main.go +++ b/http-proxy/main.go @@ -185,7 +185,7 @@ var ( track = flag.String("track", "", "The track this proxy is running on") - dnsServer = flag.String("dns-server", "172.16.0.53:53", "Optional DNS server to use for DNS lookups (in place of system resolver)") + dnsServers = flag.String("dns-servers", "", "Optional DNS servers (comma separated) to use for DNS lookups (in place of system resolver)") ) const ( @@ -470,7 +470,7 @@ func main() { BroflakeAddr: *broflakeAddr, BroflakeCert: os.Getenv("BROFLAKE_CERT"), BroflakeKey: os.Getenv("BROFLAKE_KEY"), - DNSServer: *dnsServer, + DNSServers: strings.Split(*dnsServers, ","), } if *maxmindLicenseKey != "" { log.Debug("Will use Maxmind for geolocating clients") diff --git a/http_proxy.go b/http_proxy.go index 0d90eaf1..d57efbfc 100644 --- a/http_proxy.go +++ b/http_proxy.go @@ -184,7 +184,7 @@ type Proxy struct { BroflakeCert string BroflakeKey string - DNSServer string + DNSServers []string throttleConfig throttle.Config instrument instrument.Instrument } @@ -279,35 +279,6 @@ func (p *Proxy) ListenAndServe(ctx context.Context) error { // Throttle connections when signaled srv.AddListenerWrappers(lanternlisteners.NewBitrateListener, bwReporting.wrapper) - if p.DNSServer != "" { - log.Debugf("Will resolve DNS using %v", p.DNSServer) - host, port, err := net.SplitHostPort(p.DNSServer) - if err != nil { - log.Fatalf("invalid dns-server address %v: %v", p.DNSServer, err) - } - r := &net.Resolver{ - PreferGo: true, - Dial: func(ctx context.Context, network, address string) (net.Conn, error) { - log.Debug("Dialing custom resolver") - return netx.DialContext(ctx, host, port) - }, - } - netx.OverrideResolveIPs(func(host string) ([]net.IP, error) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - addrs, err := r.LookupIPAddr(ctx, host) - if err != nil { - return nil, err - } - ips := make([]net.IP, 0, len(addrs)) - for _, addr := range addrs { - ips = append(ips, addr.IP) - } - return ips, nil - }) - } - allListeners := make([]net.Listener, 0) listenerProtocols := make([]string, 0) addListenerIfNecessary := func(proto, addr string, fn listenerBuilderFN) error { @@ -639,17 +610,60 @@ func (p *Proxy) createFilterChain(bl *blacklist.Blacklist) (filters.Chain, proxy } filterChain = filterChain.Append(instrumentedProxyPingFilter) - // Google anomaly detection can be triggered very often over IPv6. - // Prefer IPv4 to mitigate, see issue #97 - _dialer := preferIPV4Dialer(timeoutToDialOriginSite) + var resolvers []*net.Resolver + if len(p.DNSServers) == 0 { + resolvers = append(resolvers, &net.Resolver{}) + } else { + log.Debugf("Will resolve DNS using %v", p.DNSServers) + for _, _dnsServer := range p.DNSServers { + dnsServer := _dnsServer + r := &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + return netx.DialContext(ctx, "udp", dnsServer) + }, + } + resolvers = append(resolvers, r) + } + } + dialer := func(ctx context.Context, network, addr string) (net.Conn, error) { // resolve separately so that we can track the DNS resolution time - resolvedAddr, resolveErr := netx.Resolve(network, addr) - if resolveErr != nil { - return nil, resolveErr + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, errors.New("invalid address %v: %v", addr, err) + } + ip := net.ParseIP(host) + var resolveErr error + if ip == nil { + resolveLoop: + for _, r := range resolvers { + // Note - 5 seconds is the default Linux DNS timeout + rctx, cancel := context.WithTimeout(ctx, 5*time.Second) + var ips []net.IPAddr + ips, resolveErr = r.LookupIPAddr(rctx, host) + cancel() + if resolveErr == nil { + // Google anomaly detection can be triggered very often over IPv6. + // Prefer IPv4 to mitigate, see issue #97 + for _, candidate := range ips { + if candidate.IP.To4() != nil { + ip = candidate.IP + break resolveLoop + } + } + } + } + } + if ip == nil { + return nil, errors.New("unable to resolve host %v, last resolution error: %v", host, resolveErr) } - conn, dialErr := _dialer(ctx, network, resolvedAddr.String()) + resolvedAddr := fmt.Sprintf("%s:%s", ip, port) + d := &net.Dialer{ + Deadline: time.Now().Add(timeoutToDialOriginSite), + } + conn, dialErr := d.DialContext(ctx, network, resolvedAddr) if dialErr != nil { return nil, dialErr } diff --git a/throttle_integration_test.go b/throttle_integration_test.go index f3d909d4..3dc9308d 100644 --- a/throttle_integration_test.go +++ b/throttle_integration_test.go @@ -93,6 +93,7 @@ func doTestThrottling(t *testing.T, pro bool, serverAddr string, redisIsUp bool, TestingLocal: true, GoogleSearchRegex: "bequiet", GoogleCaptchaRegex: "bequiet", + DNSServers: []string{"127.0.0.1:2435", "8.8.8.8:53"}, // first one is a bogus DNS server } go func() { assert.NoError(t, proxy.ListenAndServe(context.Background()))