Skip to content

Commit

Permalink
update: add udp report and duration to reports
Browse files Browse the repository at this point in the history
  • Loading branch information
amircybersec committed Nov 11, 2024
1 parent a582f04 commit d5d398b
Showing 1 changed file with 78 additions and 9 deletions.
87 changes: 78 additions & 9 deletions x/examples/test-connectivity/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ type connectivityReport struct {
Test testReport `json:"test"`
DNSQueries []dnsReport `json:"dns_queries,omitempty"`
TCPConnections []tcpReport `json:"tcp_connections,omitempty"`
UDPConnections []udpReport `json:"udp_connections,omitempty"`
}

type testReport struct {
Expand All @@ -68,10 +69,21 @@ type dnsReport struct {
}

type tcpReport struct {
Hostname string `json:"hostname"`
IP string `json:"ip"`
Port string `json:"port"`
Error string `json:"error"`
Hostname string `json:"hostname"`
IP string `json:"ip"`
Port string `json:"port"`
Error string `json:"error"`
Time time.Time `json:"time"`
Duration int64 `json:"duration_ms"`
}

type udpReport struct {
Hostname string `json:"hostname"`
IP string `json:"ip"`
Port string `json:"port"`
Error string `json:"error"`
Time time.Time `json:"time"`
Duration int64 `json:"duration_ms"`
}

type errorJSON struct {
Expand All @@ -80,7 +92,8 @@ type errorJSON struct {
// Posix error, when available
PosixError string `json:"posix_error,omitempty"`
// TODO: remove IP addresses
Msg string `json:"msg,omitempty"`
Msg string `json:"msg,omitempty"`
MsgFull string `json:"msg_full,omitempty"`
}

func makeErrorRecord(result *connectivity.ConnectivityError) *errorJSON {
Expand All @@ -91,6 +104,7 @@ func makeErrorRecord(result *connectivity.ConnectivityError) *errorJSON {
record.Op = result.Op
record.PosixError = result.PosixError
record.Msg = unwrapAll(result.Err).Error()
record.MsgFull = result.Err.Error()
return record
}

Expand Down Expand Up @@ -120,7 +134,9 @@ func init() {
}
func newTCPTraceDialer(
onDNS func(ctx context.Context, domain string) func(di httptrace.DNSDoneInfo),
onDial func(ctx context.Context, network, addr string, connErr error)) transport.StreamDialer {
onDial func(ctx context.Context, network, addr string, connErr error),
onDialStart func(ctx context.Context, network, addr string),
) transport.StreamDialer {
dialer := &transport.TCPDialer{}
var onDNSDone func(di httptrace.DNSDoneInfo)
return transport.FuncStreamDialer(func(ctx context.Context, addr string) (transport.StreamConn, error) {
Expand All @@ -134,6 +150,9 @@ func newTCPTraceDialer(
onDNSDone = nil
}
},
ConnectStart: func(network, addr string) {
onDialStart(ctx, network, addr)
},
ConnectDone: func(network, addr string, connErr error) {
onDial(ctx, network, addr, connErr)
},
Expand All @@ -143,7 +162,10 @@ func newTCPTraceDialer(
}

func newUDPTraceDialer(
onDNS func(ctx context.Context, domain string) func(di httptrace.DNSDoneInfo)) transport.PacketDialer {
onDNS func(ctx context.Context, domain string) func(di httptrace.DNSDoneInfo),
onDial func(ctx context.Context, network, addr string, connErr error),
onDialStart func(ctx context.Context, network, addr string),
) transport.PacketDialer {
dialer := &transport.UDPDialer{}
var onDNSDone func(di httptrace.DNSDoneInfo)
return transport.FuncPacketDialer(func(ctx context.Context, addr string) (net.Conn, error) {
Expand All @@ -157,6 +179,12 @@ func newUDPTraceDialer(
onDNSDone = nil
}
},
ConnectStart: func(network, addr string) {
onDialStart(ctx, network, addr)
},
ConnectDone: func(network, addr string, connErr error) {
onDial(ctx, network, addr, connErr)
},
})
return dialer.DialPacket(ctx, addr)
})
Expand Down Expand Up @@ -240,6 +268,8 @@ func main() {
var mu sync.Mutex
dnsReports := make([]dnsReport, 0)
tcpReports := make([]tcpReport, 0)
udpReports := make([]udpReport, 0)
var connectStart = make(map[string]time.Time)
configToDialer := configurl.NewDefaultConfigToDialer()
onDNS := func(ctx context.Context, domain string) func(di httptrace.DNSDoneInfo) {
dnsStart := time.Now()
Expand Down Expand Up @@ -274,6 +304,8 @@ func main() {
Hostname: hostname,
IP: ip,
Port: port,
Time: connectStart[network+"|"+addr].UTC().Truncate(time.Second),
Duration: time.Since(connectStart[network+"|"+addr]).Milliseconds(),
}
if connErr != nil {
report.Error = connErr.Error()
Expand All @@ -282,10 +314,46 @@ func main() {
tcpReports = append(tcpReports, report)
mu.Unlock()
}
return newTCPTraceDialer(onDNS, onDial).DialStream(ctx, addr)
onDialStart := func(ctx context.Context, network, addr string) {
mu.Lock()
connectStart[network+"|"+addr] = time.Now()
mu.Unlock()
}

return newTCPTraceDialer(onDNS, onDial, onDialStart).DialStream(ctx, addr)
})

configToDialer.BasePacketDialer = transport.FuncPacketDialer(func(ctx context.Context, addr string) (net.Conn, error) {
return newUDPTraceDialer(onDNS).DialPacket(ctx, addr)
hostname, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
onDialStart := func(ctx context.Context, network, addr string) {
mu.Lock()
connectStart[network+"|"+addr] = time.Now()
mu.Unlock()
}
onDial := func(ctx context.Context, network, addr string, connErr error) {
ip, port, err := net.SplitHostPort(addr)
if err != nil {
return
}
report := udpReport{
Hostname: hostname,
IP: ip,
Port: port,
Time: connectStart[network+"|"+addr].UTC().Truncate(time.Second),
Duration: time.Since(connectStart[network+"|"+addr]).Milliseconds(),
}
if connErr != nil {
report.Error = connErr.Error()
}
mu.Lock()
udpReports = append(udpReports, report)
mu.Unlock()
}

return newUDPTraceDialer(onDNS, onDial, onDialStart).DialPacket(ctx, addr)
})

switch proto {
Expand Down Expand Up @@ -336,6 +404,7 @@ func main() {
},
DNSQueries: dnsReports,
TCPConnections: tcpReports,
UDPConnections: udpReports,
}
if reportCollector != nil {
err = reportCollector.Collect(context.Background(), r)
Expand Down

0 comments on commit d5d398b

Please sign in to comment.