Skip to content

Commit

Permalink
refactor: seperate context logic into a new func
Browse files Browse the repository at this point in the history
  • Loading branch information
amircybersec committed Sep 6, 2024
1 parent 01d1495 commit 7c69351
Showing 1 changed file with 72 additions and 76 deletions.
148 changes: 72 additions & 76 deletions x/examples/test-connectivity/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,12 @@ type dnsReport struct {
}

type tcpReport struct {
Hostname string `json:"hostname"`
IP string `json:"ip"`
Port string `json:"port"`
Error string `json:"error"`
Hostname string `json:"hostname"`
IP string `json:"ip"`
Port string `json:"port"`
Error string `json:"error"`
Time time.Time `json:"time"`
DurationMs int64 `json:"duration_ms"`
}

type errorJSON struct {
Expand Down Expand Up @@ -121,6 +123,55 @@ func init() {
}
}

func getReportFromTrace(ctx context.Context, r *connectivityReport, hostname string) context.Context {
var dnsStart, connectStart time.Time
ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
DNSStart: func(di httptrace.DNSStartInfo) {
dnsStart = time.Now()
},
DNSDone: func(di httptrace.DNSDoneInfo) {
report := dnsReport{
QueryName: hostname,
Time: dnsStart.UTC().Truncate(time.Second),
DurationMs: time.Since(dnsStart).Milliseconds(),
}
if di.Err != nil {
report.Error = di.Err.Error()
}
for _, ip := range di.Addrs {
report.AnswerIPs = append(report.AnswerIPs, ip.IP.String())
}
// TODO(fortuna): Use a Mutex.
r.DNSQueries = append(r.DNSQueries, report)
},
ConnectStart: func(network, addr string) {
connectStart = time.Now()
},
ConnectDone: func(network, addr string, connErr error) {
ip, port, err := net.SplitHostPort(addr)
if err != nil {
return
}
if network == "tcp" {
report := tcpReport{
Hostname: hostname,
IP: ip,
Port: port,
Time: connectStart.UTC().Truncate(time.Second),
DurationMs: time.Since(connectStart).Milliseconds(),
}
if connErr != nil {
report.Error = connErr.Error()
}
// TODO(fortuna): Use a Mutex.
r.TCPConnections = append(r.TCPConnections, report)
}
},
})

return ctx
}

func main() {
verboseFlag := flag.Bool("v", false, "Enable debug output")
transportFlag := flag.String("transport", "", "Transport config")
Expand Down Expand Up @@ -188,10 +239,13 @@ func main() {
resolverHost := strings.TrimSpace(resolverHost)
resolverAddress := net.JoinHostPort(resolverHost, "53")
for _, proto := range strings.Split(*protoFlag, ",") {
r := &connectivityReport{
Test: testReport{},
DNSQueries: []dnsReport{},
TCPConnections: []tcpReport{},
}
proto = strings.TrimSpace(proto)
var resolver dns.Resolver
dnsReports := make([]dnsReport, 0)
tcpReports := make([]tcpReport, 0)
switch proto {
case "tcp":
configToDialer := config.NewDefaultConfigToDialer()
Expand All @@ -200,43 +254,7 @@ func main() {
if err != nil {
return nil, err
}
var dnsStart time.Time
ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
DNSStart: func(di httptrace.DNSStartInfo) {
dnsStart = time.Now()
},
DNSDone: func(di httptrace.DNSDoneInfo) {
report := dnsReport{
QueryName: hostname,
Time: dnsStart.UTC().Truncate(time.Second),
DurationMs: time.Since(dnsStart).Milliseconds(),
}
if di.Err != nil {
report.Error = di.Err.Error()
}
for _, ip := range di.Addrs {
report.AnswerIPs = append(report.AnswerIPs, ip.IP.String())
}
// TODO(fortuna): Use a Mutex.
dnsReports = append(dnsReports, report)
},
ConnectDone: func(network, addr string, connErr error) {
ip, port, err := net.SplitHostPort(addr)
if err != nil {
return
}
report := tcpReport{
Hostname: hostname,
IP: ip,
Port: port,
}
if connErr != nil {
report.Error = connErr.Error()
}
// TODO(fortuna): Use a Mutex.
tcpReports = append(tcpReports, report)
},
})
ctx = setupTraceContext(ctx, r, hostname)
return (&transport.TCPDialer{}).DialStream(ctx, addr)
})
streamDialer, err := configToDialer.NewStreamDialer(*transportFlag)
Expand All @@ -252,27 +270,7 @@ func main() {
if err != nil {
return nil, err
}
var dnsStart time.Time
ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
DNSStart: func(di httptrace.DNSStartInfo) {
dnsStart = time.Now()
},
DNSDone: func(di httptrace.DNSDoneInfo) {
report := dnsReport{
QueryName: hostname,
Time: dnsStart.UTC().Truncate(time.Second),
DurationMs: time.Since(dnsStart).Milliseconds(),
}
if di.Err != nil {
report.Error = di.Err.Error()
}
for _, ip := range di.Addrs {
report.AnswerIPs = append(report.AnswerIPs, ip.IP.String())
}
// TODO(fortuna): Use a Mutex.
dnsReports = append(dnsReports, report)
},
})
ctx = setupTraceContext(ctx, r, hostname)
return (&transport.UDPDialer{}).DialPacket(ctx, addr)
})
packetDialer, err := configToDialer.NewPacketDialer(*transportFlag)
Expand All @@ -297,19 +295,17 @@ func main() {
if err != nil {
log.Fatalf("Failed to sanitize config: %v", err)
}
var r report.Report = connectivityReport{
Test: testReport{
Resolver: resolverAddress,
Proto: proto,
Time: startTime.UTC().Truncate(time.Second),
// TODO(fortuna): Add sanitized config:
Transport: sanitizedConfig,
DurationMs: testDuration.Milliseconds(),
Error: makeErrorRecord(result),
},
DNSQueries: dnsReports,
TCPConnections: tcpReports,

r.Test = testReport{
Resolver: resolverAddress,
Proto: proto,
Time: startTime.UTC().Truncate(time.Second),
// TODO(fortuna): Add sanitized config:
Transport: sanitizedConfig,
DurationMs: testDuration.Milliseconds(),
Error: makeErrorRecord(result),
}

if reportCollector != nil {
err = reportCollector.Collect(context.Background(), r)
if err != nil {
Expand Down

0 comments on commit 7c69351

Please sign in to comment.