Skip to content

Commit

Permalink
use a more robust approach to proxy dns
Browse files Browse the repository at this point in the history
  • Loading branch information
ferama committed Oct 17, 2024
1 parent 6966bd3 commit 87b8e36
Showing 1 changed file with 43 additions and 47 deletions.
90 changes: 43 additions & 47 deletions pkg/sshc/dns_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package sshc
import (
"encoding/binary"
"fmt"
"net"
"sync"

"github.com/miekg/dns"
)
Expand Down Expand Up @@ -31,7 +31,14 @@ func NewDnsProxy(sshConn *SshConnection, conf *DnsProxyConf) *DnsProxy {
}

// resolveDomain sends a DNS query for a domain name over TCP
func (p *DnsProxy) resolveDomain(conn net.Conn, msg *dns.Msg) ([]byte, error) {
// to the underying ssh connection
func (p *DnsProxy) resolveDomain(msg *dns.Msg) ([]byte, error) {

conn, err := p.sshConn.Client.Dial("tcp", p.remoteDnsServer)
if err != nil {
return nil, fmt.Errorf("unable to connect to remote dns server: %v", err)
}

// Pack the DNS message (with the original transaction ID)
query, err := msg.Pack()
if err != nil {
Expand Down Expand Up @@ -61,29 +68,18 @@ func (p *DnsProxy) resolveDomain(conn net.Conn, msg *dns.Msg) ([]byte, error) {
return response, nil
}

func (p *DnsProxy) handleDNSQuery(udpConn *net.UDPConn, clientAddr *net.UDPAddr, query []byte) {
// Unpack the DNS message
msg := new(dns.Msg)
if err := msg.Unpack(query); err != nil {
log.Printf("failed to unpack DNS query: %v", err)
return
}

// Extract the domain name from the query
func (p *DnsProxy) handleDNSQuery(w dns.ResponseWriter, msg *dns.Msg) {
// check if we have a valid question
if len(msg.Question) == 0 {
log.Printf("invalid DNS query: no question section")
return
}

originalID := msg.Id // Preserve the original transaction ID
// Preserve the original transaction ID
originalID := msg.Id

conn, err := p.sshConn.Client.Dial("tcp", p.remoteDnsServer)
if err != nil {
log.Printf("unable to connect to remote dns server: %v", err)
return
}
// Resolve the domain through the proxy
dnsResponse, err := p.resolveDomain(conn, msg)
dnsResponse, err := p.resolveDomain(msg)
if err != nil {
log.Printf("failed to resolve domain: %v", err)
return
Expand All @@ -99,42 +95,42 @@ func (p *DnsProxy) handleDNSQuery(udpConn *net.UDPConn, clientAddr *net.UDPAddr,
// Set the original transaction ID back into the response
reply.Id = originalID

// Pack the modified response (with correct ID)
finalResponse, err := reply.Pack()
if err != nil {
log.Printf("failed to pack final DNS response: %v", err)
return
w.WriteMsg(reply)
}

func (p *DnsProxy) run(net string) {
server := &dns.Server{
Addr: p.proxyListenAddr,
Net: net,
Handler: dns.DefaultServeMux,
}

// Send the DNS response back to the client
if _, err := udpConn.WriteToUDP(finalResponse, clientAddr); err != nil {
log.Printf("failed to send DNS response: %v", err)
return
err := server.ListenAndServe()
defer server.Shutdown()
if err != nil {
log.Fatalf("failed to start server: %s\n ", err.Error())
}
}

func (p *DnsProxy) Start() error {
p.sshConn.ReadyWait()

addr, err := net.ResolveUDPAddr("udp", p.proxyListenAddr)
if err != nil {
return fmt.Errorf("failed to resolve UDP address: %v", err)
}
dns.HandleFunc(".", p.handleDNSQuery)

udpConn, err := net.ListenUDP("udp", addr)
if err != nil {
return fmt.Errorf("failed to listen on UDP port 53: %v", err)
}
defer udpConn.Close()
log.Printf("dns-proxy listening on UDP: %s. Using remote dns: %s", p.proxyListenAddr, p.remoteDnsServer)

// Handle incoming DNS queries
buf := make([]byte, 4096)
for {
n, clientAddr, err := udpConn.ReadFromUDP(buf)
if err != nil {
continue
}
go p.handleDNSQuery(udpConn, clientAddr, buf[:n])
}
var wg sync.WaitGroup
wg.Add(1)
go func() {
p.run("udp")
wg.Done()
}()

wg.Add(1)
go func() {
p.run("tcp")
wg.Done()
}()

log.Printf("dns-proxy listening on: %s. Using remote dns: %s", p.proxyListenAddr, p.remoteDnsServer)
wg.Wait()
return nil
}

0 comments on commit 87b8e36

Please sign in to comment.