From bcea6b1c709aa3f009e7dcacc2fd009b1637eeab Mon Sep 17 00:00:00 2001 From: Yevhen Vydolob Date: Mon, 5 Aug 2024 11:14:44 +0300 Subject: [PATCH] Replace multiple "resolver.*" fn cals with single "dns.Exchange()" fn. This highly simplify resolving DNS code. Signed-off-by: Yevhen Vydolob Co-authored-by: Christophe Fergeau --- pkg/services/dns/dns.go | 221 +++++++++---------------- pkg/services/dns/dns_config_unix.go | 22 +++ pkg/services/dns/dns_config_windows.go | 26 +++ pkg/services/dns/dns_test.go | 104 ++++++++++++ test/basic_test.go | 2 +- 5 files changed, 235 insertions(+), 140 deletions(-) create mode 100644 pkg/services/dns/dns_config_unix.go create mode 100644 pkg/services/dns/dns_config_windows.go diff --git a/pkg/services/dns/dns.go b/pkg/services/dns/dns.go index 1cfe0b6cb..ca83f524c 100644 --- a/pkg/services/dns/dns.go +++ b/pkg/services/dns/dns.go @@ -1,7 +1,6 @@ package dns import ( - "context" "encoding/json" "fmt" "net" @@ -15,15 +14,43 @@ import ( ) type dnsHandler struct { - zones []types.Zone - zonesLock sync.RWMutex + zones []types.Zone + zonesLock sync.RWMutex + dnsClient *dns.Client + nameserver string +} + +func newDNSHandler(zones []types.Zone) (*dnsHandler, error) { + + dnsClient, nameserver, err := readAndCreateClient() + if err != nil { + return nil, err + } + + return &dnsHandler{ + zones: zones, + dnsClient: dnsClient, + nameserver: nameserver, + }, nil + +} + +func readAndCreateClient() (*dns.Client, string, error) { + + nameserver, port, err := GetDNSHostAndPort() + if err != nil { + return nil, "", err + } + + nameserver = net.JoinHostPort(nameserver, port) + + client := new(dns.Client) + + return client, nameserver, nil } func (h *dnsHandler) handle(w dns.ResponseWriter, r *dns.Msg, responseMessageSize int) { - m := new(dns.Msg) - m.SetReply(r) - m.RecursionAvailable = true - h.addAnswers(m) + m := h.addAnswers(r) edns0 := r.IsEdns0() if edns0 != nil { responseMessageSize = int(edns0.UDPSize()) @@ -35,40 +62,31 @@ func (h *dnsHandler) handle(w dns.ResponseWriter, r *dns.Msg, responseMessageSiz } func (h *dnsHandler) handleTCP(w dns.ResponseWriter, r *dns.Msg) { + // needs to be handled in a better way, handleTCP/handleUDP can run concurrently so this change is racy + // h.dnsClient.Net = "tcp" h.handle(w, r, dns.MaxMsgSize) } func (h *dnsHandler) handleUDP(w dns.ResponseWriter, r *dns.Msg) { + // needs to be handled in a better way, handleTCP/handleUDP can run concurrently so this change is racy + // h.dnsClient.Net = "udp" h.handle(w, r, dns.MinMsgSize) } -func (h *dnsHandler) addAnswers(m *dns.Msg) { +func (h *dnsHandler) addLocalAnswers(m *dns.Msg, q dns.Question) bool { h.zonesLock.RLock() defer h.zonesLock.RUnlock() - for _, q := range m.Question { - for _, zone := range h.zones { - zoneSuffix := fmt.Sprintf(".%s", zone.Name) - if strings.HasSuffix(q.Name, zoneSuffix) { - if q.Qtype != dns.TypeA { - return - } - for _, record := range zone.Records { - withoutZone := strings.TrimSuffix(q.Name, zoneSuffix) - if (record.Name != "" && record.Name == withoutZone) || - (record.Regexp != nil && record.Regexp.MatchString(withoutZone)) { - m.Answer = append(m.Answer, &dns.A{ - Hdr: dns.RR_Header{ - Name: q.Name, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - Ttl: 0, - }, - A: record.IP, - }) - return - } - } - if !zone.DefaultIP.Equal(net.IP("")) { + + for _, zone := range h.zones { + zoneSuffix := fmt.Sprintf(".%s", zone.Name) + if strings.HasSuffix(q.Name, zoneSuffix) { + if q.Qtype != dns.TypeA { + return false + } + for _, record := range zone.Records { + withoutZone := strings.TrimSuffix(q.Name, zoneSuffix) + if (record.Name != "" && record.Name == withoutZone) || + (record.Regexp != nil && record.Regexp.MatchString(withoutZone)) { m.Answer = append(m.Answer, &dns.A{ Hdr: dns.RR_Header{ Name: q.Name, @@ -76,29 +94,12 @@ func (h *dnsHandler) addAnswers(m *dns.Msg) { Class: dns.ClassINET, Ttl: 0, }, - A: zone.DefaultIP, + A: record.IP, }) - return + return true } - m.Rcode = dns.RcodeNameError - return - } - } - - resolver := net.Resolver{ - PreferGo: false, - } - switch q.Qtype { - case dns.TypeA: - ips, err := resolver.LookupIPAddr(context.TODO(), q.Name) - if err != nil { - m.Rcode = dns.RcodeNameError - return } - for _, ip := range ips { - if len(ip.IP.To4()) != net.IPv4len { - continue - } + if !zone.DefaultIP.Equal(net.IP("")) { m.Answer = append(m.Answer, &dns.A{ Hdr: dns.RR_Header{ Name: q.Name, @@ -106,96 +107,35 @@ func (h *dnsHandler) addAnswers(m *dns.Msg) { Class: dns.ClassINET, Ttl: 0, }, - A: ip.IP.To4(), + A: zone.DefaultIP, }) + return true } - case dns.TypeCNAME: - cname, err := resolver.LookupCNAME(context.TODO(), q.Name) - if err != nil { - m.Rcode = dns.RcodeNameError - return - } - m.Answer = append(m.Answer, &dns.CNAME{ - Hdr: dns.RR_Header{ - Name: q.Name, - Rrtype: dns.TypeCNAME, - Class: dns.ClassINET, - Ttl: 0, - }, - Target: cname, - }) - case dns.TypeMX: - records, err := resolver.LookupMX(context.TODO(), q.Name) - if err != nil { - m.Rcode = dns.RcodeNameError - return - } - for _, mx := range records { - m.Answer = append(m.Answer, &dns.MX{ - Hdr: dns.RR_Header{ - Name: q.Name, - Rrtype: dns.TypeMX, - Class: dns.ClassINET, - Ttl: 0, - }, - Mx: mx.Host, - Preference: mx.Pref, - }) - } - case dns.TypeNS: - records, err := resolver.LookupNS(context.TODO(), q.Name) - if err != nil { - m.Rcode = dns.RcodeNameError - return - } - for _, ns := range records { - m.Answer = append(m.Answer, &dns.NS{ - Hdr: dns.RR_Header{ - Name: q.Name, - Rrtype: dns.TypeNS, - Class: dns.ClassINET, - Ttl: 0, - }, - Ns: ns.Host, - }) - } - case dns.TypeSRV: - _, records, err := resolver.LookupSRV(context.TODO(), "", "", q.Name) - if err != nil { - m.Rcode = dns.RcodeNameError - return - } - for _, srv := range records { - m.Answer = append(m.Answer, &dns.SRV{ - Hdr: dns.RR_Header{ - Name: q.Name, - Rrtype: dns.TypeSRV, - Class: dns.ClassINET, - Ttl: 0, - }, - Port: srv.Port, - Priority: srv.Priority, - Target: srv.Target, - Weight: srv.Weight, - }) - } - case dns.TypeTXT: - records, err := resolver.LookupTXT(context.TODO(), q.Name) - if err != nil { - m.Rcode = dns.RcodeNameError - return - } - m.Answer = append(m.Answer, &dns.TXT{ - Hdr: dns.RR_Header{ - Name: q.Name, - Rrtype: dns.TypeTXT, - Class: dns.ClassINET, - Ttl: 0, - }, - Txt: records, - }) + m.Rcode = dns.RcodeNameError + return true + } + } + return false +} + +func (h *dnsHandler) addAnswers(r *dns.Msg) *dns.Msg { + m := new(dns.Msg) + m.SetReply(r) + m.RecursionAvailable = true + for _, q := range m.Question { + if done := h.addLocalAnswers(m, q); done { + return m } + } + + r, _, err := h.dnsClient.Exchange(r, h.nameserver) + if err != nil { + m.Rcode = dns.RcodeNameError + return m + } + + return r } type Server struct { @@ -205,7 +145,10 @@ type Server struct { } func New(udpConn net.PacketConn, tcpLn net.Listener, zones []types.Zone) (*Server, error) { - handler := &dnsHandler{zones: zones} + handler, err := newDNSHandler(zones) + if err != nil { + return nil, err + } return &Server{udpConn: udpConn, tcpLn: tcpLn, handler: handler}, nil } diff --git a/pkg/services/dns/dns_config_unix.go b/pkg/services/dns/dns_config_unix.go new file mode 100644 index 000000000..9d4cc1b7c --- /dev/null +++ b/pkg/services/dns/dns_config_unix.go @@ -0,0 +1,22 @@ +//go:build !windows + +package dns + +import ( + "fmt" + "os" + + "github.com/miekg/dns" +) + +func GetDNSHostAndPort() (string, string, error) { + conf, err := dns.ClientConfigFromFile("/etc/resolv.conf") + if err != nil { + fmt.Fprintln(os.Stderr, err) + return "", "", err + } + // TODO: use all configured nameservers, instead just first one + nameserver := conf.Servers[0] + + return nameserver, conf.Port, nil +} diff --git a/pkg/services/dns/dns_config_windows.go b/pkg/services/dns/dns_config_windows.go new file mode 100644 index 000000000..d30f7b451 --- /dev/null +++ b/pkg/services/dns/dns_config_windows.go @@ -0,0 +1,26 @@ +//go:build windows + +package dns + +import ( + "net/netip" + "strconv" + + qdmDns "github.com/qdm12/dns/v2/pkg/nameserver" +) + +func GetDNSHostAndPort() (string, string, error) { + nameservers := qdmDns.GetDNSServers() + + var nameserver netip.AddrPort + for _, n := range nameservers { + // return first non ipv6 nameserver + if n.Addr().Is4() { + nameserver = n + break + } + } + + return nameserver.Addr().String(), strconv.Itoa(int(nameserver.Port())), nil + +} diff --git a/pkg/services/dns/dns_test.go b/pkg/services/dns/dns_test.go index f01488d23..e342288b4 100644 --- a/pkg/services/dns/dns_test.go +++ b/pkg/services/dns/dns_test.go @@ -1,12 +1,17 @@ package dns import ( + "context" "net" "testing" + "time" "github.com/containers/gvisor-tap-vsock/pkg/types" + "github.com/miekg/dns" "github.com/onsi/ginkgo" "github.com/onsi/gomega" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" ) func TestSuite(t *testing.T) { @@ -191,4 +196,103 @@ var _ = ginkgo.Describe("dns add test", func() { }, })) }) + + ginkgo.It("Should pass DNS requests to default system DNS server", func() { + m := &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Authoritative: false, + AuthenticatedData: false, + CheckingDisabled: false, + RecursionDesired: true, + Opcode: 0, + }, + Question: make([]dns.Question, 1), + } + + m.Question[0] = dns.Question{ + Name: "redhat.com.", + Qtype: 1, + Qclass: 1, + } + + r := server.handler.addAnswers(m) + + gomega.Expect(r.Answer[0].Header().Name).To(gomega.Equal("redhat.com.")) + gomega.Expect(r.Answer[0].String()).To(gomega.SatisfyAny(gomega.ContainSubstring("34.235.198.240"), gomega.ContainSubstring("52.200.142.250"))) + }) }) + +type ARecord struct { + name string + expectedIPs []string +} + +func TestDNS(t *testing.T) { + log.Infof("starting test DNS servers") + nameserver, cleanup, err := startDNSServer() + require.NoError(t, err) + defer cleanup() + time.Sleep(100 * time.Millisecond) + log.Infof("test DNS servers started") + + r := &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, _ string) (net.Conn, error) { + d := net.Dialer{ + Timeout: time.Millisecond * time.Duration(10000), + } + log.Infof("dialing %s %s", network, nameserver) + + return d.DialContext(ctx, network, nameserver) + }, + } + redhatdotcom := ARecord{ + name: "redhat.com", + expectedIPs: []string{"52.200.142.250"}, + } + record := redhatdotcom + { + log.Infof("looking up %s", record.name) + ipGvisor, err := r.LookupHost(context.Background(), record.name) + require.NoError(t, err) + require.Subset(t, ipGvisor, record.expectedIPs) + log.Infof("ip gvisor: %+v", ipGvisor) + + ipGo, err := net.LookupHost(record.name) + require.NoError(t, err) + log.Infof("ip go: %+v", ipGo) + require.Subset(t, ipGvisor, ipGo) + } +} + +func startDNSServer() (string, func(), error) { + udpConn, err := net.ListenPacket("udp", "127.0.0.1:5354") + if err != nil { + return "", nil, err + } + + tcpLn, err := net.Listen("tcp", "127.0.0.1:5354") + if err != nil { + return "", nil, err + } + + server, err := New(udpConn, tcpLn, nil) + if err != nil { + return "", nil, err + } + + go func() { + if err := server.Serve(); err != nil { + log.Errorf("serve UDP error: %T %s", err, err) + } + }() + go func() { + if err := server.ServeTCP(); err != nil { + log.Errorf("serve TCP error: %T %s", err, err) + } + }() + return "127.0.0.1:5354", func() { + udpConn.Close() + tcpLn.Close() + }, nil +} diff --git a/test/basic_test.go b/test/basic_test.go index a3db89e7e..514f66599 100644 --- a/test/basic_test.go +++ b/test/basic_test.go @@ -61,7 +61,7 @@ var _ = ginkgo.Describe("dns", func() { ginkgo.It("should resolve MX record for wikipedia.org", func() { out, err := sshExec("nslookup -query=mx wikipedia.org") gomega.Expect(err).ShouldNot(gomega.HaveOccurred()) - gomega.Expect(string(out)).To(gomega.ContainSubstring("wikipedia.org mail exchanger = 10 mx1001.wikimedia.org.")) + gomega.Expect(string(out)).To(gomega.ContainSubstring("wikipedia.org mail exchanger = 10 mx-1001.wikimedia.org.")) }) ginkgo.It("should resolve NS record for wikipedia.org", func() {