From f65586ccd3d3c887f89eb0ce85d3c234a1886c11 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Fri, 20 Sep 2024 12:41:12 -0400 Subject: [PATCH] feat: add DNS and TCP reports to test-connectivity (#272) --- x/examples/test-connectivity/main.go | 192 ++++++++++++++++++++++----- x/go.mod | 2 + x/go.sum | 2 + 3 files changed, 165 insertions(+), 31 deletions(-) diff --git a/x/examples/test-connectivity/main.go b/x/examples/test-connectivity/main.go index ee82e558..1edd5584 100644 --- a/x/examples/test-connectivity/main.go +++ b/x/examples/test-connectivity/main.go @@ -20,27 +20,33 @@ import ( "errors" "flag" "fmt" - "io" - "log" + "log/slog" "net" "net/http" + "net/http/httptrace" "net/url" "os" "path" "strings" + "sync" "time" "github.com/Jigsaw-Code/outline-sdk/dns" + "github.com/Jigsaw-Code/outline-sdk/transport" "github.com/Jigsaw-Code/outline-sdk/x/configurl" "github.com/Jigsaw-Code/outline-sdk/x/connectivity" "github.com/Jigsaw-Code/outline-sdk/x/report" + "github.com/lmittmann/tint" + "golang.org/x/term" ) -var debugLog log.Logger = *log.New(io.Discard, "", 0) - -// var errorLog log.Logger = *log.New(os.Stderr, "[ERROR] ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile) - type connectivityReport struct { + Test testReport `json:"test"` + DNSQueries []dnsReport `json:"dns_queries,omitempty"` + TCPConnections []tcpReport `json:"tcp_connections,omitempty"` +} + +type testReport struct { // Inputs Resolver string `json:"resolver"` Proto string `json:"proto"` @@ -53,6 +59,21 @@ type connectivityReport struct { Error *errorJSON `json:"error"` } +type dnsReport struct { + QueryName string `json:"query_name"` + Time time.Time `json:"time"` + DurationMs int64 `json:"duration_ms"` + AnswerIPs []string `json:"answer_ips"` + Error string `json:"error"` +} + +type tcpReport struct { + Hostname string `json:"hostname"` + IP string `json:"ip"` + Port string `json:"port"` + Error string `json:"error"` +} + type errorJSON struct { // TODO: add Shadowsocks/Transport error Op string `json:"op,omitempty"` @@ -84,7 +105,7 @@ func unwrapAll(err error) error { } func (r connectivityReport) IsSuccess() bool { - if r.Error == nil { + if r.Test.Error == nil { return true } else { return false @@ -97,6 +118,49 @@ func init() { flag.PrintDefaults() } } +func newTCPTraceDialer( + onDNS func(ctx context.Context, domain string) func(di httptrace.DNSDoneInfo), + onDial func(ctx context.Context, network, addr string, connErr error)) transport.StreamDialer { + dialer := &transport.TCPDialer{} + var onDNSDone func(di httptrace.DNSDoneInfo) + return transport.FuncStreamDialer(func(ctx context.Context, addr string) (transport.StreamConn, error) { + ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{ + DNSStart: func(di httptrace.DNSStartInfo) { + onDNSDone = onDNS(ctx, di.Host) + }, + DNSDone: func(di httptrace.DNSDoneInfo) { + if onDNSDone != nil { + onDNSDone(di) + onDNSDone = nil + } + }, + ConnectDone: func(network, addr string, connErr error) { + onDial(ctx, network, addr, connErr) + }, + }) + return dialer.DialStream(ctx, addr) + }) +} + +func newUDPTraceDialer( + onDNS func(ctx context.Context, domain string) func(di httptrace.DNSDoneInfo)) transport.PacketDialer { + dialer := &transport.UDPDialer{} + var onDNSDone func(di httptrace.DNSDoneInfo) + return transport.FuncPacketDialer(func(ctx context.Context, addr string) (net.Conn, error) { + ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{ + DNSStart: func(di httptrace.DNSStartInfo) { + onDNSDone = onDNS(ctx, di.Host) + }, + DNSDone: func(di httptrace.DNSDoneInfo) { + if onDNSDone != nil { + onDNSDone(di) + onDNSDone = nil + } + }, + }) + return dialer.DialPacket(ctx, addr) + }) +} func main() { verboseFlag := flag.Bool("v", false, "Enable debug output") @@ -110,28 +174,34 @@ func main() { flag.Parse() + logLevel := slog.LevelInfo + if *verboseFlag { + logLevel = slog.LevelDebug + } + slog.SetDefault(slog.New(tint.NewHandler( + os.Stderr, + &tint.Options{NoColor: !term.IsTerminal(int(os.Stderr.Fd())), Level: logLevel}, + ))) + // Perform custom range validation for sampling rate if *reportSuccessFlag < 0.0 || *reportSuccessFlag > 1.0 { - fmt.Println("Error: report-success-rate must be between 0 and 1.") + slog.Error("Error: report-success-rate must be between 0 and 1.", "report-success-rate", *reportSuccessFlag) flag.Usage() - return + os.Exit(1) } if *reportFailureFlag < 0.0 || *reportFailureFlag > 1.0 { - fmt.Println("Error: report-failure-rate must be between 0 and 1.") + slog.Error("Error: report-failure-rate must be between 0 and 1.", "report-failure-rate", *reportFailureFlag) flag.Usage() - return - } - - if *verboseFlag { - debugLog = *log.New(os.Stderr, "[DEBUG] ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile) + os.Exit(1) } var reportCollector report.Collector if *reportToFlag != "" { collectorURL, err := url.Parse(*reportToFlag) if err != nil { - debugLog.Printf("Failed to parse collector URL: %v", err) + slog.Error("Failed to parse collector URL", "url", err) + os.Exit(1) } remoteCollector := &report.RemoteCollector{ CollectorURL: collectorURL, @@ -161,56 +231,116 @@ func main() { success := false jsonEncoder := json.NewEncoder(os.Stdout) jsonEncoder.SetEscapeHTML(false) - configToDialer := configurl.NewDefaultConfigToDialer() for _, resolverHost := range strings.Split(*resolverFlag, ",") { resolverHost := strings.TrimSpace(resolverHost) resolverAddress := net.JoinHostPort(resolverHost, "53") for _, proto := range strings.Split(*protoFlag, ",") { proto = strings.TrimSpace(proto) var resolver dns.Resolver + var mu sync.Mutex + dnsReports := make([]dnsReport, 0) + tcpReports := make([]tcpReport, 0) + configToDialer := configurl.NewDefaultConfigToDialer() + onDNS := func(ctx context.Context, domain string) func(di httptrace.DNSDoneInfo) { + dnsStart := time.Now() + return func(di httptrace.DNSDoneInfo) { + report := dnsReport{ + QueryName: domain, + Time: dnsStart.UTC().Truncate(time.Second), + DurationMs: time.Since(dnsStart).Milliseconds(), + } + if di.Err != nil { + report.Error = di.Err.Error() + } + for _, ip := range di.Addrs { + report.AnswerIPs = append(report.AnswerIPs, ip.IP.String()) + } + mu.Lock() + dnsReports = append(dnsReports, report) + mu.Unlock() + } + } + configToDialer.BaseStreamDialer = transport.FuncStreamDialer(func(ctx context.Context, addr string) (transport.StreamConn, error) { + hostname, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + onDial := func(ctx context.Context, network, addr string, connErr error) { + ip, port, err := net.SplitHostPort(addr) + if err != nil { + return + } + report := tcpReport{ + Hostname: hostname, + IP: ip, + Port: port, + } + if connErr != nil { + report.Error = connErr.Error() + } + mu.Lock() + tcpReports = append(tcpReports, report) + mu.Unlock() + } + return newTCPTraceDialer(onDNS, onDial).DialStream(ctx, addr) + }) + configToDialer.BasePacketDialer = transport.FuncPacketDialer(func(ctx context.Context, addr string) (net.Conn, error) { + return newUDPTraceDialer(onDNS).DialPacket(ctx, addr) + }) + switch proto { case "tcp": streamDialer, err := configToDialer.NewStreamDialer(*transportFlag) if err != nil { - log.Fatalf("Failed to create StreamDialer: %v", err) + slog.Error("Failed to create StreamDialer", "error", err) + os.Exit(1) } resolver = dns.NewTCPResolver(streamDialer, resolverAddress) + case "udp": packetDialer, err := configToDialer.NewPacketDialer(*transportFlag) if err != nil { - log.Fatalf("Failed to create PacketDialer: %v", err) + slog.Error("Failed to create PacketDialer", "error", err) + os.Exit(1) } resolver = dns.NewUDPResolver(packetDialer, resolverAddress) default: - log.Fatalf(`Invalid proto %v. Must be "tcp" or "udp"`, proto) + slog.Error(`Invalid proto. Must be "tcp" or "udp"`, "proto", proto) + os.Exit(1) } startTime := time.Now() result, err := connectivity.TestConnectivityWithResolver(context.Background(), resolver, *domainFlag) if err != nil { - log.Fatalf("Connectivity test failed to run: %v", err) + slog.Error("Connectivity test failed to run", "error", err) + os.Exit(1) } testDuration := time.Since(startTime) if result == nil { success = true } - debugLog.Printf("Test %v %v result: %v", proto, resolverAddress, result) + slog.Debug("Test done", "proto", proto, "resolver", resolverAddress, "result", result) sanitizedConfig, err := configurl.SanitizeConfig(*transportFlag) if err != nil { - log.Fatalf("Failed to sanitize config: %v", err) + slog.Error("Failed to sanitize config", "error", err) + os.Exit(1) } var r report.Report = connectivityReport{ - Resolver: resolverAddress, - Proto: proto, - Time: startTime.UTC().Truncate(time.Second), - // TODO(fortuna): Add sanitized config: - Transport: sanitizedConfig, - DurationMs: testDuration.Milliseconds(), - Error: makeErrorRecord(result), + Test: testReport{ + Resolver: resolverAddress, + Proto: proto, + Time: startTime.UTC().Truncate(time.Second), + // TODO(fortuna): Add sanitized config: + Transport: sanitizedConfig, + DurationMs: testDuration.Milliseconds(), + Error: makeErrorRecord(result), + }, + DNSQueries: dnsReports, + TCPConnections: tcpReports, } if reportCollector != nil { err = reportCollector.Collect(context.Background(), r) if err != nil { - debugLog.Printf("Failed to collect report: %v\n", err) + slog.Warn("Failed to collect report", "error", err) } } } diff --git a/x/go.mod b/x/go.mod index 0b322d15..a1dfe72a 100644 --- a/x/go.mod +++ b/x/go.mod @@ -7,12 +7,14 @@ require ( // Use github.com/Psiphon-Labs/psiphon-tunnel-core@staging-client as per // https://github.com/Psiphon-Labs/psiphon-tunnel-core/?tab=readme-ov-file#using-psiphon-with-go-modules github.com/Psiphon-Labs/psiphon-tunnel-core v1.0.11-0.20240619172145-03cade11f647 + github.com/lmittmann/tint v1.0.5 github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 github.com/stretchr/testify v1.9.0 github.com/vishvananda/netlink v1.1.0 golang.org/x/mobile v0.0.0-20240520174638-fa72addaaa1b golang.org/x/net v0.25.0 golang.org/x/sys v0.20.0 + golang.org/x/term v0.20.0 ) require ( diff --git a/x/go.sum b/x/go.sum index 7c6f5f80..033308fb 100644 --- a/x/go.sum +++ b/x/go.sum @@ -104,6 +104,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/libp2p/go-reuseport v0.4.0 h1:nR5KU7hD0WxXCJbmw7r2rhRYruNRl2koHw8fQscQm2s= github.com/libp2p/go-reuseport v0.4.0/go.mod h1:ZtI03j/wO5hZVDFo2jKywN6bYKWLOy8Se6DrI2E1cLU= +github.com/lmittmann/tint v1.0.5 h1:NQclAutOfYsqs2F1Lenue6OoWCajs5wJcP3DfWVpePw= +github.com/lmittmann/tint v1.0.5/go.mod h1:HIS3gSy7qNwGCj+5oRjAutErFBl4BzdQP6cJZ0NfMwE= github.com/marusama/semaphore v0.0.0-20171214154724-565ffd8e868a h1:6SRny9FLB1eWasPyDUqBQnMi9NhXU01XIlB0ao89YoI= github.com/marusama/semaphore v0.0.0-20171214154724-565ffd8e868a/go.mod h1:TmeOqAKoDinfPfSohs14CO3VcEf7o+Bem6JiNe05yrQ= github.com/mdlayher/netlink v1.4.2-0.20210930205308-a81a8c23d40a h1:yk5OmRew64lWdeNanQ3l0hDgUt1E8MfipPhh/GO9Tuw=