diff --git a/x/smart/finder.go b/x/smart/finder.go index 9a2fdced..646e9130 100644 --- a/x/smart/finder.go +++ b/x/smart/finder.go @@ -36,7 +36,6 @@ import ( ) // TODO: -// - Fix for China // - Add DNS caching // - Parallelize TLS // - Add debug logging to proxy handler @@ -240,7 +239,7 @@ func evaluateNetResolver(ctx context.Context, resolver *net.Resolver, testDomain if err != nil { return nil, fmt.Errorf("could not get cname: %w", err) } - ips, err := net.LookupIP(requestDomain) + ips, err := resolver.LookupIP(ctx, "ip", requestDomain) if err != nil { return nil, fmt.Errorf("failed to lookup IPs: %w", err) } @@ -256,7 +255,10 @@ func evaluateNetResolver(ctx context.Context, resolver *net.Resolver, testDomain return ips, nil } -func evaluateResponse(response *miekgdns.Msg, requestDomain string) ([]net.IP, error) { +func evaluateAResponse(response *miekgdns.Msg, requestDomain string) ([]net.IP, error) { + if response.Rcode != miekgdns.RcodeSuccess { + return nil, fmt.Errorf("rcode is not success: %v", response.Rcode) + } var ips []net.IP if len(response.Answer) == 0 { return ips, errors.New("no answers") // -1 @@ -287,6 +289,37 @@ func evaluateResponse(response *miekgdns.Msg, requestDomain string) ([]net.IP, e return ips, nil } +func evaluateCNAMEResponse(response *miekgdns.Msg, requestDomain string) error { + if response.Rcode != miekgdns.RcodeSuccess { + return fmt.Errorf("rcode is not success: %v", response.Rcode) + } + if len(response.Answer) == 0 { + var numSOA int + for _, answer := range response.Ns { + if _, ok := answer.(*miekgdns.SOA); ok { + numSOA++ + } + } + if numSOA != 1 { + return fmt.Errorf("SOA records is %v, expected 1", numSOA) + } + return nil + } + var cname string + for _, answer := range response.Answer { + if rr, ok := answer.(*miekgdns.CNAME); ok { + if cname != "" { + return fmt.Errorf("found too many CNAMEs: %v %v", cname, rr.Target) + } + cname = rr.Target + } + } + if cname == "" { + return errors.New("no CNAME in answers") + } + return nil +} + type StrategyFinder struct { TestTimeout time.Duration LogWriter io.Writer @@ -312,14 +345,28 @@ func (f *StrategyFinder) testDNSClient(ctx context.Context, dnsRT dns.RoundTripp return evaluateNetResolver(ctx, new(net.Resolver), testDomain) } - var request miekgdns.Msg requestDomain := mixCase(testDomain) + var request miekgdns.Msg + + request = miekgdns.Msg{} request.SetQuestion(requestDomain, miekgdns.TypeA) response, err := (&MiekgRoundTripper{dnsRT}).RoundTripMsg(ctx, &request) if err != nil { - return nil, fmt.Errorf("request failed: %w", err) + return nil, fmt.Errorf("request for A query failed: %w", err) + } + ips, err := evaluateAResponse(response, requestDomain) + if err != nil { + return ips, err + } + + request = miekgdns.Msg{} + request.SetQuestion(requestDomain, miekgdns.TypeCNAME) + response, err = (&MiekgRoundTripper{dnsRT}).RoundTripMsg(ctx, &request) + if err != nil { + return nil, fmt.Errorf("request for CNAME query failed: %w", err) } - return evaluateResponse(response, requestDomain) + err = evaluateCNAMEResponse(response, requestDomain) + return ips, err } type httpsEntryJSON struct {