Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for custom DNS server #422

Merged
merged 6 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type Config struct {
ConnectionsPerHost int `long:"connections-per-host" default:"1" description:"Number of times to connect to each host (results in more output)"`
ReadLimitPerHost int `long:"read-limit-per-host" default:"96" description:"Maximum total kilobytes to read for a single host (default 96kb)"`
Prometheus string `long:"prometheus" description:"Address to use for Prometheus server (e.g. localhost:8080). If empty, Prometheus is disabled."`
CustomDNS string `long:"dns" description:"Address of a custom DNS server for lookups. Default port is 53."`
Multiple MultipleCommand `command:"multiple" description:"Multiple module actions"`
inputFile *os.File
outputFile *os.File
Expand Down Expand Up @@ -128,6 +129,14 @@ func validateFrameworkConfiguration() {
if config.ReadLimitPerHost > 0 {
DefaultBytesReadLimit = config.ReadLimitPerHost * 1024
}

// Validate custom DNS
if config.CustomDNS != "" {
var err error
if config.CustomDNS, err = addDefaultPortToDNSServerName(config.CustomDNS); err != nil {
log.Fatalf("invalid DNS server address: %s", err)
}
}
}

// GetMetaFile returns the file to which metadata should be output
Expand Down
10 changes: 10 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,16 @@ func (d *Dialer) SetDefaults() *Dialer {
KeepAlive: d.Timeout,
DualStack: true,
}

// Use custom DNS as default if set
if config.CustomDNS != "" {
d.Dialer.Resolver = &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
return net.Dial(network, config.CustomDNS)
},
}
}
}
return d
}
Expand Down
36 changes: 32 additions & 4 deletions utility.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,24 @@ package zgrab2

import (
"errors"
"fmt"
"net"
"regexp"
"strconv"
"strings"

"time"

"github.com/zmap/zflags"
"github.com/sirupsen/logrus"
"runtime/debug"

"github.com/sirupsen/logrus"
flags "github.com/zmap/zflags"
)

var parser *flags.Parser

const defaultDNSPort = "53"

func init() {
parser = flags.NewParser(&config, flags.Default)
}
Expand Down Expand Up @@ -214,8 +218,9 @@ func IsTimeoutError(err error) bool {
// doing anything. Otherwise, it logs the stacktrace, the panic error, and the provided message
// before re-raising the original panic.
// Example:
// defer zgrab2.LogPanic("Error decoding body '%x'", body)
func LogPanic(format string, args...interface{}) {
//
developStorm marked this conversation as resolved.
Show resolved Hide resolved
// defer zgrab2.LogPanic("Error decoding body '%x'", body)
func LogPanic(format string, args ...interface{}) {
err := recover()
if err == nil {
return
Expand All @@ -224,3 +229,26 @@ func LogPanic(format string, args...interface{}) {
logrus.Errorf(format, args...)
panic(err)
}

// addDefaultPortToDNSServerName validates that the input DNS server address is correct and appends the default DNS port 53 if no port is specified
func addDefaultPortToDNSServerName(inAddr string) (string, error) {
// Try to split host and port to see if the port is already specified.
host, port, err := net.SplitHostPort(inAddr)
if err != nil {
// might mean there's no port specified
host = inAddr
}

// Validate the host part as an IP address.
ip := net.ParseIP(host)
if ip == nil {
return "", fmt.Errorf("invalid IP address")
}

// If the original input does not have a port, specify port 53 as the default
if port == "" {
port = defaultDNSPort
}

return net.JoinHostPort(ip.String(), port), nil
}
Loading