Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NETPATH-346] Reduce impact on NAT Gateway #30789

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 59 additions & 9 deletions pkg/networkpath/traceroute/tcp/tcpv4.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,14 @@ func (t *TCPv4) TracerouteSequential() (*Results, error) {
//
// TODO: do this once for the probe and hang on to the
// listener until we decide to close the probe
addr, err := localAddrForHost(t.Target, t.DestPort)
addr, conn, err := localAddrForHost(t.Target, t.DestPort)
if err != nil {
if conn != nil {
conn.Close()
}
return nil, fmt.Errorf("failed to get local address for target: %w", err)
}
defer conn.Close()
t.srcIP = addr.IP
t.srcPort = addr.AddrPort().Port()

Expand Down Expand Up @@ -103,9 +107,11 @@ func (t *TCPv4) TracerouteSequential() (*Results, error) {
// hops should be of length # of hops
hops := make([]*Hop, 0, t.MaxTTL-t.MinTTL)

var hop *Hop
var ackSeqNum, ackAckNum uint32
for i := int(t.MinTTL); i <= int(t.MaxTTL); i++ {
seqNumber := rand.Uint32()
hop, err := t.sendAndReceive(rawIcmpConn, rawTCPConn, i, seqNumber, t.Timeout)
hop, ackSeqNum, ackAckNum, err = t.sendAndReceive(rawIcmpConn, rawTCPConn, i, seqNumber, t.Timeout)
if err != nil {
return nil, fmt.Errorf("failed to run traceroute: %w", err)
}
Expand All @@ -117,6 +123,13 @@ func (t *TCPv4) TracerouteSequential() (*Results, error) {
break
}
}
// TODO: should we use maxTTL always, or if we know the exact TTL
// should we use that instead? I think we should use the max in case
// the packet takes a different route. Perhaps even higher than the max?
err = t.sendTCPReset(rawTCPConn, int(t.MaxTTL), ackSeqNum, ackAckNum)
if err != nil {
log.Errorf("failed to send TCP RST: %s", err.Error())
}

return &Results{
Source: t.srcIP,
Expand All @@ -127,24 +140,34 @@ func (t *TCPv4) TracerouteSequential() (*Results, error) {
}, nil
}

func (t *TCPv4) sendAndReceive(rawIcmpConn *ipv4.RawConn, rawTCPConn *ipv4.RawConn, ttl int, seqNum uint32, timeout time.Duration) (*Hop, error) {
tcpHeader, tcpPacket, err := createRawTCPSyn(t.srcIP, t.srcPort, t.Target, t.DestPort, seqNum, ttl)
func (t *TCPv4) sendAndReceive(rawIcmpConn *ipv4.RawConn, rawTCPConn *ipv4.RawConn, ttl int, seqNum uint32, timeout time.Duration) (*Hop, uint32, uint32, error) {
// Create SYN packet
tcpLayer := &layers.TCP{
SrcPort: layers.TCPPort(t.srcPort),
DstPort: layers.TCPPort(t.DestPort),
Seq: seqNum,
Ack: 0,
SYN: true,
Window: 1024,
}

tcpHeader, tcpPacket, err := createRawTCPPkt(t.srcIP, t.Target, ttl, tcpLayer)
if err != nil {
log.Errorf("failed to create TCP packet with TTL: %d, error: %s", ttl, err.Error())
return nil, err
return nil, 0, 0, err
}

err = sendPacket(rawTCPConn, tcpHeader, tcpPacket)
if err != nil {
log.Errorf("failed to send TCP SYN: %s", err.Error())
return nil, err
return nil, 0, 0, err
}

start := time.Now() // TODO: is this the best place to start?
hopIP, hopPort, icmpType, end, err := listenPackets(rawIcmpConn, rawTCPConn, timeout, t.srcIP, t.srcPort, t.Target, t.DestPort, seqNum)
hopIP, hopPort, icmpType, ackSeqNum, ackAckNum, end, err := listenPackets(rawIcmpConn, rawTCPConn, timeout, t.srcIP, t.srcPort, t.Target, t.DestPort, seqNum)
if err != nil {
log.Errorf("failed to listen for packets: %s", err.Error())
return nil, err
return nil, 0, 0, err
}

rtt := time.Duration(0)
Expand All @@ -158,7 +181,34 @@ func (t *TCPv4) sendAndReceive(rawIcmpConn *ipv4.RawConn, rawTCPConn *ipv4.RawCo
ICMPType: icmpType,
RTT: rtt,
IsDest: hopIP.Equal(t.Target),
}, nil
}, ackSeqNum, ackAckNum, nil
}

func (t *TCPv4) sendTCPReset(rawTCPConn *ipv4.RawConn, ttl int, seqNum uint32, ackNum uint32) error {
// Create RST packet
tcpLayer := &layers.TCP{
SrcPort: layers.TCPPort(t.srcPort),
DstPort: layers.TCPPort(t.DestPort),
Seq: ackNum,
Ack: seqNum + 1,
RST: true,
ACK: true,
Window: 1024,
}

tcpHeader, tcpPacket, err := createRawTCPPkt(t.srcIP, t.Target, ttl, tcpLayer)
if err != nil {
log.Errorf("failed to create TCP packet with TTL: %d, error: %s", ttl, err.Error())
return err
}

err = sendPacket(rawTCPConn, tcpHeader, tcpPacket)
if err != nil {
log.Errorf("failed to send TCP RST: %s", err.Error())
return err
}

return nil
}

// Close doesn't to anything yet, but we should
Expand Down
54 changes: 24 additions & 30 deletions pkg/networkpath/traceroute/tcp/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,27 +67,28 @@ type (
}
)

func localAddrForHost(destIP net.IP, destPort uint16) (*net.UDPAddr, error) {
// localAddrForHost returns the local address and connection for the host
// the connection should be closed by the caller
func localAddrForHost(destIP net.IP, destPort uint16) (*net.UDPAddr, net.Conn, error) {
// this is a quick way to get the local address for connecting to the host
// using UDP as the network type to avoid actually creating a connection to
// the host, just get the OS to give us a local IP and local ephemeral port
conn, err := net.Dial("udp4", net.JoinHostPort(destIP.String(), strconv.Itoa(int(destPort))))
if err != nil {
return nil, err
return nil, nil, err
}
defer conn.Close()
localAddr := conn.LocalAddr()

localUDPAddr, ok := localAddr.(*net.UDPAddr)
if !ok {
return nil, fmt.Errorf("invalid address type for %s: want %T, got %T", localAddr, localUDPAddr, localAddr)
return nil, conn, fmt.Errorf("invalid address type for %s: want %T, got %T", localAddr, localUDPAddr, localAddr)
}

return localUDPAddr, nil
return localUDPAddr, conn, nil
}

// createRawTCPSyn creates a TCP packet with the specified parameters
func createRawTCPSyn(sourceIP net.IP, sourcePort uint16, destIP net.IP, destPort uint16, seqNum uint32, ttl int) (*ipv4.Header, []byte, error) {
// createRawTCPPkt creates a TCP packet with the specified parameters
func createRawTCPPkt(sourceIP net.IP, destIP net.IP, ttl int, tcpLayer *layers.TCP) (*ipv4.Header, []byte, error) {
ipLayer := &layers.IPv4{
Version: 4,
Length: 20,
Expand All @@ -98,15 +99,6 @@ func createRawTCPSyn(sourceIP net.IP, sourcePort uint16, destIP net.IP, destPort
SrcIP: sourceIP,
}

tcpLayer := &layers.TCP{
SrcPort: layers.TCPPort(sourcePort),
DstPort: layers.TCPPort(destPort),
Seq: seqNum,
Ack: 0,
SYN: true,
Window: 1024,
}

err := tcpLayer.SetNetworkLayerForChecksum(ipLayer)
if err != nil {
return nil, nil, fmt.Errorf("failed to create packet checksum: %w", err)
Expand Down Expand Up @@ -144,7 +136,7 @@ func sendPacket(rawConn rawConnWrapper, header *ipv4.Header, payload []byte) err
// receives a matching packet within the timeout, a blank response is returned.
// Once a matching packet is received by a listener, it will cause the other listener
// to be canceled, and data from the matching packet will be returned to the caller
func listenPackets(icmpConn rawConnWrapper, tcpConn rawConnWrapper, timeout time.Duration, localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, seqNum uint32) (net.IP, uint16, layers.ICMPv4TypeCode, time.Time, error) {
func listenPackets(icmpConn rawConnWrapper, tcpConn rawConnWrapper, timeout time.Duration, localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, seqNum uint32) (net.IP, uint16, layers.ICMPv4TypeCode, uint32, uint32, time.Time, error) {
var tcpErr error
var icmpErr error
var wg sync.WaitGroup
Expand All @@ -153,19 +145,21 @@ func listenPackets(icmpConn rawConnWrapper, tcpConn rawConnWrapper, timeout time
var icmpCode layers.ICMPv4TypeCode
var tcpFinished time.Time
var icmpFinished time.Time
var ackSeqNum uint32
var ackAckNum uint32
var port uint16
wg.Add(2)
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
go func() {
defer wg.Done()
defer cancel()
tcpIP, port, _, tcpFinished, tcpErr = handlePackets(ctx, tcpConn, "tcp", localIP, localPort, remoteIP, remotePort, seqNum)
tcpIP, port, _, ackSeqNum, ackAckNum, tcpFinished, tcpErr = handlePackets(ctx, tcpConn, "tcp", localIP, localPort, remoteIP, remotePort, seqNum)
}()
go func() {
defer wg.Done()
defer cancel()
icmpIP, _, icmpCode, icmpFinished, icmpErr = handlePackets(ctx, icmpConn, "icmp", localIP, localPort, remoteIP, remotePort, seqNum)
icmpIP, _, icmpCode, _, _, icmpFinished, icmpErr = handlePackets(ctx, icmpConn, "icmp", localIP, localPort, remoteIP, remotePort, seqNum)
}()
wg.Wait()

Expand All @@ -174,7 +168,7 @@ func listenPackets(icmpConn rawConnWrapper, tcpConn rawConnWrapper, timeout time
_, icmpCanceled := icmpErr.(canceledError)
if icmpCanceled && tcpCanceled {
log.Trace("timed out waiting for responses")
return net.IP{}, 0, 0, time.Time{}, nil
return net.IP{}, 0, 0, 0, 0, time.Time{}, nil
}
if tcpErr != nil {
log.Errorf("TCP listener error: %s", tcpErr.Error())
Expand All @@ -183,35 +177,35 @@ func listenPackets(icmpConn rawConnWrapper, tcpConn rawConnWrapper, timeout time
log.Errorf("ICMP listener error: %s", icmpErr.Error())
}

return net.IP{}, 0, 0, time.Time{}, multierr.Append(fmt.Errorf("tcp error: %w", tcpErr), fmt.Errorf("icmp error: %w", icmpErr))
return net.IP{}, 0, 0, 0, 0, time.Time{}, multierr.Append(fmt.Errorf("tcp error: %w", tcpErr), fmt.Errorf("icmp error: %w", icmpErr))
}

// if there was an error for TCP, but not
// ICMP, return the ICMP response
if tcpErr != nil {
return icmpIP, port, icmpCode, icmpFinished, nil
return icmpIP, port, icmpCode, 0, 0, icmpFinished, nil
}

// return the TCP response
return tcpIP, port, 0, tcpFinished, nil
return tcpIP, port, 0, ackSeqNum, ackAckNum, tcpFinished, nil
}

// handlePackets in its current implementation should listen for the first matching
// packet on the connection and then return. If no packet is received within the
// timeout or if the listener is canceled, it should return a canceledError
func handlePackets(ctx context.Context, conn rawConnWrapper, listener string, localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, seqNum uint32) (net.IP, uint16, layers.ICMPv4TypeCode, time.Time, error) {
func handlePackets(ctx context.Context, conn rawConnWrapper, listener string, localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, seqNum uint32) (net.IP, uint16, layers.ICMPv4TypeCode, uint32, uint32, time.Time, error) {
buf := make([]byte, 1024)
tp := newTCPParser()
for {
select {
case <-ctx.Done():
return net.IP{}, 0, 0, time.Time{}, canceledError("listener canceled")
return net.IP{}, 0, 0, 0, 0, time.Time{}, canceledError("listener canceled")
default:
}
now := time.Now()
err := conn.SetReadDeadline(now.Add(time.Millisecond * 100))
if err != nil {
return net.IP{}, 0, 0, time.Time{}, fmt.Errorf("failed to read: %w", err)
return net.IP{}, 0, 0, 0, 0, time.Time{}, fmt.Errorf("failed to read: %w", err)
}
header, packet, _, err := conn.ReadFrom(buf)
if err != nil {
Expand All @@ -220,7 +214,7 @@ func handlePackets(ctx context.Context, conn rawConnWrapper, listener string, lo
continue
}
}
return net.IP{}, 0, 0, time.Time{}, err
return net.IP{}, 0, 0, 0, 0, time.Time{}, err
}
// once we have a packet, take a timestamp to know when
// the response was received, if it matches, we will
Expand All @@ -235,7 +229,7 @@ func handlePackets(ctx context.Context, conn rawConnWrapper, listener string, lo
continue
}
if icmpMatch(localIP, localPort, remoteIP, remotePort, seqNum, icmpResponse) {
return icmpResponse.SrcIP, 0, icmpResponse.TypeCode, received, nil
return icmpResponse.SrcIP, 0, icmpResponse.TypeCode, 0, 0, received, nil
}
} else if listener == "tcp" {
tcpResp, err := tp.parseTCP(header, packet)
Expand All @@ -244,10 +238,10 @@ func handlePackets(ctx context.Context, conn rawConnWrapper, listener string, lo
continue
}
if tcpMatch(localIP, localPort, remoteIP, remotePort, seqNum, tcpResp) {
return tcpResp.SrcIP, uint16(tcpResp.TCPResponse.SrcPort), 0, received, nil
return tcpResp.SrcIP, uint16(tcpResp.TCPResponse.SrcPort), 0, tcpResp.TCPResponse.Seq, tcpResp.TCPResponse.Ack, received, nil
}
} else {
return net.IP{}, 0, 0, received, fmt.Errorf("unsupported listener type")
return net.IP{}, 0, 0, 0, 0, received, fmt.Errorf("unsupported listener type")
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/networkpath/traceroute/tcp/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ func Test_handlePackets(t *testing.T) {
t.Run(test.description, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), test.ctxTimeout)
defer cancel()
actualIP, actualPort, actualTypeCode, _, err := handlePackets(ctx, test.conn, test.listener, test.localIP, test.localPort, test.remoteIP, test.remotePort, test.seqNum)
actualIP, actualPort, actualTypeCode, _, _, _, err := handlePackets(ctx, test.conn, test.listener, test.localIP, test.localPort, test.remoteIP, test.remotePort, test.seqNum)
if test.errMsg != "" {
require.Error(t, err)
assert.True(t, strings.Contains(err.Error(), test.errMsg))
Expand Down
Loading