diff --git a/config.go b/config.go index 8b22ba24..f59bb3f8 100644 --- a/config.go +++ b/config.go @@ -24,6 +24,7 @@ type Config struct { ConnectionsPerHost int `long:"connections-per-host" default:"1" description:"Number of times to connect to each host (results in more output)"` ReadLimitPerHost int `long:"read-limit-per-host" default:"96" description:"Maximum total kilobytes to read for a single host (default 96kb)"` Prometheus string `long:"prometheus" description:"Address to use for Prometheus server (e.g. localhost:8080). If empty, Prometheus is disabled."` + CustomDNS string `long:"dns" description:"Address of a custom DNS server for lookups. Default port is 53."` Multiple MultipleCommand `command:"multiple" description:"Multiple module actions"` inputFile *os.File outputFile *os.File @@ -128,6 +129,14 @@ func validateFrameworkConfiguration() { if config.ReadLimitPerHost > 0 { DefaultBytesReadLimit = config.ReadLimitPerHost * 1024 } + + // Validate custom DNS + if config.CustomDNS != "" { + var err error + if config.CustomDNS, err = addDefaultPortToDNSServerName(config.CustomDNS); err != nil { + log.Fatalf("invalid DNS server address: %s", err) + } + } } // GetMetaFile returns the file to which metadata should be output diff --git a/conn.go b/conn.go index ef701c81..e4dc43e2 100644 --- a/conn.go +++ b/conn.go @@ -341,6 +341,16 @@ func (d *Dialer) SetDefaults() *Dialer { KeepAlive: d.Timeout, DualStack: true, } + + // Use custom DNS as default if set + if config.CustomDNS != "" { + d.Dialer.Resolver = &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial(network, config.CustomDNS) + }, + } + } } return d } diff --git a/utility.go b/utility.go index c2d88a7f..4d500858 100644 --- a/utility.go +++ b/utility.go @@ -2,6 +2,7 @@ package zgrab2 import ( "errors" + "fmt" "net" "regexp" "strconv" @@ -9,13 +10,16 @@ import ( "time" - "github.com/zmap/zflags" - "github.com/sirupsen/logrus" "runtime/debug" + + "github.com/sirupsen/logrus" + flags "github.com/zmap/zflags" ) var parser *flags.Parser +const defaultDNSPort = "53" + func init() { parser = flags.NewParser(&config, flags.Default) } @@ -214,8 +218,9 @@ func IsTimeoutError(err error) bool { // doing anything. Otherwise, it logs the stacktrace, the panic error, and the provided message // before re-raising the original panic. // Example: -// defer zgrab2.LogPanic("Error decoding body '%x'", body) -func LogPanic(format string, args...interface{}) { +// +// defer zgrab2.LogPanic("Error decoding body '%x'", body) +func LogPanic(format string, args ...interface{}) { err := recover() if err == nil { return @@ -224,3 +229,26 @@ func LogPanic(format string, args...interface{}) { logrus.Errorf(format, args...) panic(err) } + +// addDefaultPortToDNSServerName validates that the input DNS server address is correct and appends the default DNS port 53 if no port is specified +func addDefaultPortToDNSServerName(inAddr string) (string, error) { + // Try to split host and port to see if the port is already specified. + host, port, err := net.SplitHostPort(inAddr) + if err != nil { + // might mean there's no port specified + host = inAddr + } + + // Validate the host part as an IP address. + ip := net.ParseIP(host) + if ip == nil { + return "", fmt.Errorf("invalid IP address") + } + + // If the original input does not have a port, specify port 53 as the default + if port == "" { + port = defaultDNSPort + } + + return net.JoinHostPort(ip.String(), port), nil +}