From e9dc64b2d70487faea95c139a9fa768c68a2d90d Mon Sep 17 00:00:00 2001
From: Yevhen Vydolob <yvydolob@redhat.com>
Date: Fri, 15 Nov 2024 15:10:54 +0200
Subject: [PATCH] Use all configured ipv4 dns services

It brings all ipv4 dns and make dns resolve request one by one until it receives first answer

Signed-off-by: Yevhen Vydolob <yvydolob@redhat.com>
---
 pkg/services/dns/dns.go                | 42 ++++++++++++++------------
 pkg/services/dns/dns_config_unix.go    | 19 ++++++++----
 pkg/services/dns/dns_config_windows.go | 13 ++++----
 3 files changed, 42 insertions(+), 32 deletions(-)

diff --git a/pkg/services/dns/dns.go b/pkg/services/dns/dns.go
index aa9b862d..e50bbc52 100644
--- a/pkg/services/dns/dns.go
+++ b/pkg/services/dns/dns.go
@@ -16,17 +16,17 @@ import (
 )
 
 type dnsHandler struct {
-	zones      []types.Zone
-	zonesLock  sync.RWMutex
-	udpClient  *dns.Client
-	tcpClient  *dns.Client
-	hostsFile  *HostsFile
-	nameserver string
+	zones       []types.Zone
+	zonesLock   sync.RWMutex
+	udpClient   *dns.Client
+	tcpClient   *dns.Client
+	hostsFile   *HostsFile
+	nameservers []string
 }
 
 func newDNSHandler(zones []types.Zone) (*dnsHandler, error) {
 
-	nameserver, port, err := getDNSHostAndPort()
+	nameservers, err := getDNSHostAndPort()
 	if err != nil {
 		return nil, err
 	}
@@ -37,11 +37,11 @@ func newDNSHandler(zones []types.Zone) (*dnsHandler, error) {
 	}
 
 	return &dnsHandler{
-		zones:      zones,
-		tcpClient:  &dns.Client{Net: "tcp"},
-		udpClient:  &dns.Client{Net: "udp"},
-		nameserver: net.JoinHostPort(nameserver, port),
-		hostsFile:  hostsFile,
+		zones:       zones,
+		tcpClient:   &dns.Client{Net: "tcp"},
+		udpClient:   &dns.Client{Net: "udp"},
+		nameservers: nameservers,
+		hostsFile:   hostsFile,
 	}, nil
 
 }
@@ -145,15 +145,19 @@ func (h *dnsHandler) addAnswers(dnsClient *dns.Client, r *dns.Msg) *dns.Msg {
 			return m
 		}
 	}
-
-	r, _, err := dnsClient.Exchange(r, h.nameserver)
-	if err != nil {
-		log.Errorf("Error during DNS Exchange: %s", err)
-		m.Rcode = dns.RcodeNameError
-		return m
+	for _, nameserver := range h.nameservers {
+		msg := r.Copy()
+		r, _, err := dnsClient.Exchange(msg, nameserver)
+		// return first good answer
+		if err == nil {
+			return r
+		}
+		log.Debugf("Error during DNS Exchange: %s", err)
 	}
 
-	return r
+	// return the error if none of configured nameservers has right answer
+	m.Rcode = dns.RcodeNameError
+	return m
 }
 
 type Server struct {
diff --git a/pkg/services/dns/dns_config_unix.go b/pkg/services/dns/dns_config_unix.go
index 32716240..fb806a65 100644
--- a/pkg/services/dns/dns_config_unix.go
+++ b/pkg/services/dns/dns_config_unix.go
@@ -3,16 +3,23 @@
 package dns
 
 import (
+	"net"
+
 	"github.com/miekg/dns"
 )
 
-func getDNSHostAndPort() (string, string, error) {
+func getDNSHostAndPort() ([]string, error) {
 	conf, err := dns.ClientConfigFromFile("/etc/resolv.conf")
 	if err != nil {
-		return "", "", err
+		return []string{}, err
 	}
-	// TODO: use all configured nameservers, instead just first one
-	nameserver := conf.Servers[0]
-
-	return nameserver, conf.Port, nil
+	var hosts = make([]string, len(conf.Servers))
+	for _, server := range conf.Servers {
+		dnsIP := net.ParseIP(server)
+		// add only ipv4 dns addresses
+		if dnsIP != nil && dnsIP.To4() != nil {
+			hosts = append(hosts, net.JoinHostPort(server, conf.Port))
+		}
+	}
+	return hosts, nil
 }
diff --git a/pkg/services/dns/dns_config_windows.go b/pkg/services/dns/dns_config_windows.go
index 2644e065..f041326d 100644
--- a/pkg/services/dns/dns_config_windows.go
+++ b/pkg/services/dns/dns_config_windows.go
@@ -3,24 +3,23 @@
 package dns
 
 import (
-	"net/netip"
+	"net"
 	"strconv"
 
 	qdmDns "github.com/qdm12/dns/v2/pkg/nameserver"
 )
 
-func getDNSHostAndPort() (string, string, error) {
+func getDNSHostAndPort() ([]string, error) {
 	nameservers := qdmDns.GetDNSServers()
 
-	var nameserver netip.AddrPort
+	var dnsServers = make([]string, 5)
 	for _, n := range nameservers {
-		// return first non ipv6 nameserver
+		// return only ipv4 nameservers
 		if n.Addr().Is4() {
-			nameserver = n
-			break
+			dnsServers = append(dnsServers, net.JoinHostPort(n.Addr().String(), strconv.Itoa(int(n.Port()))))
 		}
 	}
 
-	return nameserver.Addr().String(), strconv.Itoa(int(nameserver.Port())), nil
+	return dnsServers, nil
 
 }