From 7c6935181e6ede88c16ead6745ed6fb742c800f0 Mon Sep 17 00:00:00 2001 From: amir gh Date: Thu, 5 Sep 2024 22:04:17 -0700 Subject: [PATCH] refactor: seperate context logic into a new func --- x/examples/test-connectivity/main.go | 148 +++++++++++++-------------- 1 file changed, 72 insertions(+), 76 deletions(-) diff --git a/x/examples/test-connectivity/main.go b/x/examples/test-connectivity/main.go index a841268a..981c84f2 100644 --- a/x/examples/test-connectivity/main.go +++ b/x/examples/test-connectivity/main.go @@ -70,10 +70,12 @@ type dnsReport struct { } type tcpReport struct { - Hostname string `json:"hostname"` - IP string `json:"ip"` - Port string `json:"port"` - Error string `json:"error"` + Hostname string `json:"hostname"` + IP string `json:"ip"` + Port string `json:"port"` + Error string `json:"error"` + Time time.Time `json:"time"` + DurationMs int64 `json:"duration_ms"` } type errorJSON struct { @@ -121,6 +123,55 @@ func init() { } } +func getReportFromTrace(ctx context.Context, r *connectivityReport, hostname string) context.Context { + var dnsStart, connectStart time.Time + ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{ + DNSStart: func(di httptrace.DNSStartInfo) { + dnsStart = time.Now() + }, + DNSDone: func(di httptrace.DNSDoneInfo) { + report := dnsReport{ + QueryName: hostname, + Time: dnsStart.UTC().Truncate(time.Second), + DurationMs: time.Since(dnsStart).Milliseconds(), + } + if di.Err != nil { + report.Error = di.Err.Error() + } + for _, ip := range di.Addrs { + report.AnswerIPs = append(report.AnswerIPs, ip.IP.String()) + } + // TODO(fortuna): Use a Mutex. + r.DNSQueries = append(r.DNSQueries, report) + }, + ConnectStart: func(network, addr string) { + connectStart = time.Now() + }, + ConnectDone: func(network, addr string, connErr error) { + ip, port, err := net.SplitHostPort(addr) + if err != nil { + return + } + if network == "tcp" { + report := tcpReport{ + Hostname: hostname, + IP: ip, + Port: port, + Time: connectStart.UTC().Truncate(time.Second), + DurationMs: time.Since(connectStart).Milliseconds(), + } + if connErr != nil { + report.Error = connErr.Error() + } + // TODO(fortuna): Use a Mutex. + r.TCPConnections = append(r.TCPConnections, report) + } + }, + }) + + return ctx +} + func main() { verboseFlag := flag.Bool("v", false, "Enable debug output") transportFlag := flag.String("transport", "", "Transport config") @@ -188,10 +239,13 @@ func main() { resolverHost := strings.TrimSpace(resolverHost) resolverAddress := net.JoinHostPort(resolverHost, "53") for _, proto := range strings.Split(*protoFlag, ",") { + r := &connectivityReport{ + Test: testReport{}, + DNSQueries: []dnsReport{}, + TCPConnections: []tcpReport{}, + } proto = strings.TrimSpace(proto) var resolver dns.Resolver - dnsReports := make([]dnsReport, 0) - tcpReports := make([]tcpReport, 0) switch proto { case "tcp": configToDialer := config.NewDefaultConfigToDialer() @@ -200,43 +254,7 @@ func main() { if err != nil { return nil, err } - var dnsStart time.Time - ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{ - DNSStart: func(di httptrace.DNSStartInfo) { - dnsStart = time.Now() - }, - DNSDone: func(di httptrace.DNSDoneInfo) { - report := dnsReport{ - QueryName: hostname, - Time: dnsStart.UTC().Truncate(time.Second), - DurationMs: time.Since(dnsStart).Milliseconds(), - } - if di.Err != nil { - report.Error = di.Err.Error() - } - for _, ip := range di.Addrs { - report.AnswerIPs = append(report.AnswerIPs, ip.IP.String()) - } - // TODO(fortuna): Use a Mutex. - dnsReports = append(dnsReports, report) - }, - ConnectDone: func(network, addr string, connErr error) { - ip, port, err := net.SplitHostPort(addr) - if err != nil { - return - } - report := tcpReport{ - Hostname: hostname, - IP: ip, - Port: port, - } - if connErr != nil { - report.Error = connErr.Error() - } - // TODO(fortuna): Use a Mutex. - tcpReports = append(tcpReports, report) - }, - }) + ctx = setupTraceContext(ctx, r, hostname) return (&transport.TCPDialer{}).DialStream(ctx, addr) }) streamDialer, err := configToDialer.NewStreamDialer(*transportFlag) @@ -252,27 +270,7 @@ func main() { if err != nil { return nil, err } - var dnsStart time.Time - ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{ - DNSStart: func(di httptrace.DNSStartInfo) { - dnsStart = time.Now() - }, - DNSDone: func(di httptrace.DNSDoneInfo) { - report := dnsReport{ - QueryName: hostname, - Time: dnsStart.UTC().Truncate(time.Second), - DurationMs: time.Since(dnsStart).Milliseconds(), - } - if di.Err != nil { - report.Error = di.Err.Error() - } - for _, ip := range di.Addrs { - report.AnswerIPs = append(report.AnswerIPs, ip.IP.String()) - } - // TODO(fortuna): Use a Mutex. - dnsReports = append(dnsReports, report) - }, - }) + ctx = setupTraceContext(ctx, r, hostname) return (&transport.UDPDialer{}).DialPacket(ctx, addr) }) packetDialer, err := configToDialer.NewPacketDialer(*transportFlag) @@ -297,19 +295,17 @@ func main() { if err != nil { log.Fatalf("Failed to sanitize config: %v", err) } - var r report.Report = connectivityReport{ - Test: testReport{ - Resolver: resolverAddress, - Proto: proto, - Time: startTime.UTC().Truncate(time.Second), - // TODO(fortuna): Add sanitized config: - Transport: sanitizedConfig, - DurationMs: testDuration.Milliseconds(), - Error: makeErrorRecord(result), - }, - DNSQueries: dnsReports, - TCPConnections: tcpReports, + + r.Test = testReport{ + Resolver: resolverAddress, + Proto: proto, + Time: startTime.UTC().Truncate(time.Second), + // TODO(fortuna): Add sanitized config: + Transport: sanitizedConfig, + DurationMs: testDuration.Milliseconds(), + Error: makeErrorRecord(result), } + if reportCollector != nil { err = reportCollector.Collect(context.Background(), r) if err != nil {