Skip to content

Commit

Permalink
moved setting up the tracer context to application
Browse files Browse the repository at this point in the history
  • Loading branch information
amircybersec committed Aug 29, 2024
1 parent 247ca52 commit d81fad4
Showing 1 changed file with 152 additions and 3 deletions.
155 changes: 152 additions & 3 deletions x/examples/test-connectivity/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package main

import (
"context"
ctls "crypto/tls"
"encoding/json"
"errors"
"flag"
Expand All @@ -24,16 +25,20 @@ import (
"log"
"net"
"net/http"
"net/http/httptrace"
"net/url"
"os"
"path"
"strings"
"time"

"github.com/Jigsaw-Code/outline-sdk/dns"
"github.com/Jigsaw-Code/outline-sdk/transport/socks5"
"github.com/Jigsaw-Code/outline-sdk/transport/tls"
"github.com/Jigsaw-Code/outline-sdk/x/config"
"github.com/Jigsaw-Code/outline-sdk/x/connectivity"
"github.com/Jigsaw-Code/outline-sdk/x/report"
"golang.org/x/net/dns/dnsmessage"
)

var debugLog log.Logger = *log.New(io.Discard, "", 0)
Expand Down Expand Up @@ -100,7 +105,7 @@ func init() {

func main() {
verboseFlag := flag.Bool("v", false, "Enable debug output")
testTypeFlag := flag.String("test-type", "do53-tcp,do53-udp,doh,dot,http", "Type of test to run")
testTypeFlag := flag.String("test-type", "do53-tcp,do53-udp,doh,dot,http,http3", "Type of test to run")
transportFlag := flag.String("transport", "", "Transport config")
domainFlag := flag.String("domain", "example.com.", "Domain name to resolve in the DNS test and to fetch in the HTTP test")
methodFlag := flag.String("method", "GET", "HTTP method to use in the HTTP test")
Expand Down Expand Up @@ -165,7 +170,7 @@ func main() {
jsonEncoder := json.NewEncoder(os.Stdout)
jsonEncoder.SetEscapeHTML(false)
configToDialer := config.NewDefaultConfigToDialer()
ctx := connectivity.SetupConnectivityTrace(context.Background())
ctx := SetupConnectivityTrace(context.Background())
for _, resolverHost := range strings.Split(*resolverFlag, ",") {
resolverHost := strings.TrimSpace(resolverHost)
var result *connectivity.ConnectivityError
Expand Down Expand Up @@ -234,8 +239,18 @@ func main() {
if err != nil {
log.Fatalf("Connectivity test failed to run: %v", err)
}
case "http3":
Protocol = "udp"
packetDialer, err := configToDialer.NewPacketDialer(*transportFlag)
if err != nil {
log.Fatalf("Failed to create PacketDialer: %v", err)
}
result, err = connectivity.TestPacketConnectivitywithHTTP3(ctx, packetDialer, *domainFlag, *timeoutFlag, *methodFlag)
if err != nil {
log.Fatalf("Connectivity test failed to run: %v", err)
}
default:
log.Fatalf(`Invalid Test Type %v. Must be "tcp" or "udp"`, testType)
log.Fatalf(`Invalid Test Type %v.`, testType)
}
testDuration := time.Since(startTime)
if result == nil {
Expand Down Expand Up @@ -267,3 +282,137 @@ func main() {
}
}
}

func SetupConnectivityTrace(ctx context.Context) context.Context {
t := &dns.DNSClientTrace{
QuestionSent: func(question dnsmessage.Question) {
fmt.Println("DNS query started for", question.Name.String())
},
ResponsDone: func(question dnsmessage.Question, msg *dnsmessage.Message, err error) {
if err != nil {
fmt.Printf("DNS query for %s failed: %v\n", question.Name.String(), err)
} else {
// Prepare to collect IP addresses
var ips []string

// Iterate over the answer section
for _, answer := range msg.Answers {
switch rr := answer.Body.(type) {
case *dnsmessage.AResource:
// Handle IPv4 addresses - convert [4]byte to IP string
ipv4 := net.IP(rr.A[:]) // Convert [4]byte to net.IP
ips = append(ips, ipv4.String())
case *dnsmessage.AAAAResource:
// Handle IPv6 addresses - convert [16]byte to IP string
ipv6 := net.IP(rr.AAAA[:]) // Convert [16]byte to net.IP
ips = append(ips, ipv6.String())
}
}

// Print all resolved IP addresses
if len(ips) > 0 {
fmt.Printf("Resolved IPs for %s: %v\n", question.Name.String(), ips)
} else {
fmt.Printf("No IPs found for %s\n", question.Name.String())
}
}
},
ConnectDone: func(network, addr string, err error) {
if err != nil {
fmt.Printf("%v Connection to %s failed: %v\n", network, addr, err)
} else {
fmt.Printf("%v Connection to %s succeeded\n", network, addr)
}
},
WroteDone: func(err error) {
if err != nil {
fmt.Printf("Write failed: %v\n", err)
} else {
fmt.Println("Write succeeded")
}
},
ReadDone: func(err error) {
if err != nil {
fmt.Printf("Read failed: %v\n", err)
} else {
fmt.Println("Read succeeded")
}
},
}

// Variables to store the timestamps
var startTLS time.Time

ht := &httptrace.ClientTrace{
DNSStart: func(info httptrace.DNSStartInfo) {
fmt.Printf("DNS start: %v\n", info)
},
DNSDone: func(info httptrace.DNSDoneInfo) {
fmt.Printf("DNS done: %v\n", info)
},
ConnectStart: func(network, addr string) {
fmt.Printf("Connect start: %v %v\n", network, addr)
},
ConnectDone: func(network, addr string, err error) {
fmt.Printf("Connect done: %v %v %v\n", network, addr, err)
},
GotFirstResponseByte: func() {
fmt.Println("Got first response byte")
},
WroteHeaderField: func(key string, value []string) {
fmt.Printf("Wrote header field: %v %v\n", key, value)
},
WroteHeaders: func() {
fmt.Println("Wrote headers")
},
WroteRequest: func(info httptrace.WroteRequestInfo) {
fmt.Printf("Wrote request: %v\n", info)
},
TLSHandshakeStart: func() {
startTLS = time.Now()
},
TLSHandshakeDone: func(state ctls.ConnectionState, err error) {
if err != nil {
fmt.Printf("TLS handshake failed: %v\n", err)
}
fmt.Printf("SNI: %v\n", state.ServerName)
fmt.Printf("TLS version: %v\n", state.Version)
fmt.Printf("ALPN: %v\n", state.NegotiatedProtocol)
fmt.Printf("TLS handshake took %v seconds.\n", time.Since(startTLS).Seconds())
},
}

tlsTrace := &tls.TLSClientTrace{
TLSHandshakeStart: func() {
fmt.Println("TLS handshake started")
startTLS = time.Now()
},
TLSHandshakeDone: func(state ctls.ConnectionState, err error) {
if err != nil {
fmt.Printf("TLS handshake failed: %v\n", err)
}
fmt.Printf("SNI: %v\n", state.ServerName)
fmt.Printf("TLS version: %v\n", state.Version)
fmt.Printf("ALPN: %v\n", state.NegotiatedProtocol)
fmt.Printf("TLS handshake took %v seconds.\n", time.Since(startTLS).Seconds())
},
}

socksTrace := &socks5.SOCKS5ClientTrace{
RequestStarted: func(cmd byte, dstAddr string) {
fmt.Printf("SOCKS5 request started: cmd: %v address: %v\n", cmd, dstAddr)
},
RequestDone: func(network string, bindAddr string, err error) {
if err != nil {
fmt.Printf("SOCKS5 request failed! network: %v, bindAddr: %v, error: %v \n", network, bindAddr, err)
}
fmt.Printf("SOCKS5 request succeeded! network: %v, bindAddr: %v \n", network, bindAddr)
},
}

ctx = httptrace.WithClientTrace(ctx, ht)
ctx = dns.WithDNSClientTrace(ctx, t)
ctx = tls.WithTLSClientTrace(ctx, tlsTrace)
ctx = socks5.WithSOCKS5ClientTrace(ctx, socksTrace)
return ctx
}

0 comments on commit d81fad4

Please sign in to comment.