Skip to content

Commit

Permalink
Refactor packet tracking.
Browse files Browse the repository at this point in the history
Signed-off-by: SuperQ <[email protected]>
  • Loading branch information
SuperQ committed Apr 4, 2023
1 parent 20aa09d commit 5e92641
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 49 deletions.
114 changes: 114 additions & 0 deletions packet_tracking.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package probing

import (
"sync"
"time"

"github.com/google/uuid"
)

type PacketTracker struct {
currentUUID uuid.UUID
packets map[uuid.UUID]PacketSequence
sequence int
nextSequence int
timeout time.Duration
timeoutCh chan *inFlightPacket

mutex sync.RWMutex
}

type PacketSequence struct {
packets map[int]inFlightPacket
}

func (ps PacketSequence) NewInflightPacket(sequence int) {
ps.packets[sequence] = inFlightPacket{}
}

func (ps PacketSequence) GetPacket(sequence int) (inFlightPacket, bool) {
packet, ok := ps.packets[sequence]
return packet, ok
}

func (ps PacketSequence) RemovePacket(sequence int) {
delete(ps.packets, sequence)
}

type inFlightPacket struct {
timeoutTimer *time.Timer
}

func newPacketTracker(t time.Duration) *PacketTracker {
firstUUID := uuid.New()
var firstSequence = map[uuid.UUID]map[int]struct{}{}
firstSequence[firstUUID] = make(map[int]struct{})

return &PacketTracker{
packets: map[uuid.UUID]PacketSequence{},
sequence: 0,
timeout: t,
}
}

func (t *PacketTracker) AddPacket() int {
t.mutex.Lock()
defer t.mutex.Unlock()

if t.nextSequence > 65535 {
newUUID := uuid.New()
t.packets[newUUID] = PacketSequence{}
t.currentUUID = newUUID
t.nextSequence = 0
}

t.sequence = t.nextSequence
t.packets[t.currentUUID].NewInflightPacket(t.sequence)
// if t.timeout > 0 {
// t.packets[t.currentUUID][t.sequence].timeoutTimer = time.Timer(t.timeout)
// }
t.nextSequence++
return t.sequence
}

// DeletePacket removes a packet from the tracker.
func (t *PacketTracker) DeletePacket(u uuid.UUID, seq int) {
t.mutex.Lock()
defer t.mutex.Unlock()

if t.hasPacket(u, seq) {
// if _, ok := t.packets[u].GetPacket(seq) ; ok != nil {
// t.packets[u][seq].timeoutTimer.Stop()
// }
t.packets[u].RemovePacket(seq)
}
}

func (t *PacketTracker) hasPacket(u uuid.UUID, seq int) bool {
inflight, ok := t.packets[u]
if ok == false {
return ok
}
_, ok = inflight.GetPacket(seq)
return ok
}

// HasPacket checks the tracker to see if it's currently tracking a packet.
func (t *PacketTracker) HasPacket(u uuid.UUID, seq int) bool {
t.mutex.RLock()
defer t.mutex.Unlock()

return t.hasPacket(u, seq)
}

func (t *PacketTracker) HasUUID(u uuid.UUID) bool {
_, hasUUID := t.packets[u]
return hasUUID
}

func (t *PacketTracker) CurrentUUID() uuid.UUID {
// t.mutex.RLock()
// defer t.mutex.Unlock()

return t.currentUUID
}
64 changes: 24 additions & 40 deletions ping.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,27 +86,22 @@ var (
// New returns a new Pinger struct pointer.
func New(addr string) *Pinger {
r := rand.New(rand.NewSource(getSeed()))
firstUUID := uuid.New()
var firstSequence = map[uuid.UUID]map[int]struct{}{}
firstSequence[firstUUID] = make(map[int]struct{})
return &Pinger{
Count: -1,
Interval: time.Second,
RecordRtts: true,
Size: timeSliceLength + trackerLength,
Timeout: time.Duration(math.MaxInt64),

addr: addr,
done: make(chan interface{}),
id: r.Intn(math.MaxUint16),
trackerUUIDs: []uuid.UUID{firstUUID},
ipaddr: nil,
ipv4: false,
network: "ip",
protocol: "udp",
awaitingSequences: firstSequence,
TTL: 64,
logger: StdLogger{Logger: log.New(log.Writer(), log.Prefix(), log.Flags())},
addr: addr,
done: make(chan interface{}),
id: r.Intn(math.MaxUint16),
ipaddr: nil,
ipv4: false,
network: "ip",
protocol: "udp",
TTL: 64,
logger: StdLogger{Logger: log.New(log.Writer(), log.Prefix(), log.Flags())},
}
}

Expand Down Expand Up @@ -142,6 +137,9 @@ type Pinger struct {
// Number of duplicate packets received
PacketsRecvDuplicates int

// Per-packet timeout
PacketTimeout time.Duration

// Round trip time statistics
minRtt time.Duration
maxRtt time.Duration
Expand Down Expand Up @@ -188,14 +186,11 @@ type Pinger struct {
ipaddr *net.IPAddr
addr string

// trackerUUIDs is the list of UUIDs being used for sending packets.
trackerUUIDs []uuid.UUID

ipv4 bool
id int
sequence int
// awaitingSequences are in-flight sequence numbers we keep track of to help remove duplicate receipts
awaitingSequences map[uuid.UUID]map[int]struct{}
// tracker is a PacketTrackrer of UUIDs and sequence numbers.
tracker *PacketTracker
// network is one of "ip", "ip4", or "ip6".
network string
// protocol is "icmp" or "udp".
Expand Down Expand Up @@ -412,6 +407,9 @@ func (p *Pinger) Run() error {
if err != nil {
return err
}

p.tracker = newPacketTracker(p.PacketTimeout)

if conn, err = p.listen(); err != nil {
return err
}
Expand Down Expand Up @@ -614,19 +612,12 @@ func (p *Pinger) getPacketUUID(pkt []byte) (*uuid.UUID, error) {
return nil, fmt.Errorf("error decoding tracking UUID: %w", err)
}

for _, item := range p.trackerUUIDs {
if item == packetUUID {
return &packetUUID, nil
}
if p.tracker.HasUUID(packetUUID) {
return &packetUUID, nil
}
return nil, nil
}

// getCurrentTrackerUUID grabs the latest tracker UUID.
func (p *Pinger) getCurrentTrackerUUID() uuid.UUID {
return p.trackerUUIDs[len(p.trackerUUIDs)-1]
}

func (p *Pinger) processPacket(recv *packet) error {
receivedAt := time.Now()
var proto int
Expand Down Expand Up @@ -675,15 +666,15 @@ func (p *Pinger) processPacket(recv *packet) error {
inPkt.Rtt = receivedAt.Sub(timestamp)
inPkt.Seq = pkt.Seq
// If we've already received this sequence, ignore it.
if _, inflight := p.awaitingSequences[*pktUUID][pkt.Seq]; !inflight {
if !p.tracker.HasPacket(*pktUUID, pkt.Seq) {
p.PacketsRecvDuplicates++
if p.OnDuplicateRecv != nil {
p.OnDuplicateRecv(inPkt)
}
return nil
}
// remove it from the list of sequences we're waiting for so we don't get duplicates.
delete(p.awaitingSequences[*pktUUID], pkt.Seq)
// Remove it from the list of sequences we're waiting for so we don't get duplicates.
p.tracker.DeletePacket(*pktUUID, pkt.Seq)
p.updateStatistics(inPkt)
default:
// Very bad, not sure how this can happen
Expand All @@ -704,7 +695,7 @@ func (p *Pinger) sendICMP(conn packetConn) error {
dst = &net.UDPAddr{IP: p.ipaddr.IP, Zone: p.ipaddr.Zone}
}

currentUUID := p.getCurrentTrackerUUID()
currentUUID := p.tracker.CurrentUUID()
uuidEncoded, err := currentUUID.MarshalBinary()
if err != nil {
return fmt.Errorf("unable to marshal UUID binary: %w", err)
Expand Down Expand Up @@ -752,15 +743,8 @@ func (p *Pinger) sendICMP(conn packetConn) error {
handler(outPkt)
}
// mark this sequence as in-flight
p.awaitingSequences[currentUUID][p.sequence] = struct{}{}
p.sequence = p.tracker.AddPacket()
p.PacketsSent++
p.sequence++
if p.sequence > 65535 {
newUUID := uuid.New()
p.trackerUUIDs = append(p.trackerUUIDs, newUUID)
p.awaitingSequences[newUUID] = make(map[int]struct{})
p.sequence = 0
}
break
}

Expand Down
19 changes: 10 additions & 9 deletions ping_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func TestProcessPacket(t *testing.T) {
shouldBe1++
}

currentUUID := pinger.getCurrentTrackerUUID()
currentUUID := pinger.tracker.CurrentUUID()
uuidEncoded, err := currentUUID.MarshalBinary()
if err != nil {
t.Fatalf("unable to marshal UUID binary: %s", err)
Expand All @@ -37,7 +37,7 @@ func TestProcessPacket(t *testing.T) {
Seq: pinger.sequence,
Data: data,
}
pinger.awaitingSequences[currentUUID][pinger.sequence] = struct{}{}
pinger.tracker.AddPacket()

msg := &icmp.Message{
Type: ipv4.ICMPTypeEchoReply,
Expand Down Expand Up @@ -66,7 +66,7 @@ func TestProcessPacket_IgnoreNonEchoReplies(t *testing.T) {
shouldBe0++
}

currentUUID, err := pinger.getCurrentTrackerUUID().MarshalBinary()
currentUUID, err := pinger.tracker.CurrentUUID().MarshalBinary()
if err != nil {
t.Fatalf("unable to marshal UUID binary: %s", err)
}
Expand Down Expand Up @@ -109,7 +109,7 @@ func TestProcessPacket_IDMismatch(t *testing.T) {
shouldBe0++
}

currentUUID, err := pinger.getCurrentTrackerUUID().MarshalBinary()
currentUUID, err := pinger.tracker.CurrentUUID().MarshalBinary()
if err != nil {
t.Fatalf("unable to marshal UUID binary: %s", err)
}
Expand Down Expand Up @@ -189,7 +189,7 @@ func TestProcessPacket_LargePacket(t *testing.T) {
pinger := makeTestPinger()
pinger.Size = 4096

currentUUID, err := pinger.getCurrentTrackerUUID().MarshalBinary()
currentUUID, err := pinger.tracker.CurrentUUID().MarshalBinary()
if err != nil {
t.Fatalf("unable to marshal UUID binary: %s", err)
}
Expand Down Expand Up @@ -484,6 +484,7 @@ func makeTestPinger() *Pinger {
pinger.protocol = "icmp"
pinger.id = 123
pinger.Size = 0
pinger.tracker = newPacketTracker(time.Second * 5)

return pinger
}
Expand Down Expand Up @@ -542,7 +543,7 @@ func BenchmarkProcessPacket(b *testing.B) {
pinger.protocol = "ip4:icmp"
pinger.id = 123

currentUUID, err := pinger.getCurrentTrackerUUID().MarshalBinary()
currentUUID, err := pinger.tracker.CurrentUUID().MarshalBinary()
if err != nil {
b.Fatalf("unable to marshal UUID binary: %s", err)
}
Expand Down Expand Up @@ -591,7 +592,7 @@ func TestProcessPacket_IgnoresDuplicateSequence(t *testing.T) {
dups++
}

currentUUID := pinger.getCurrentTrackerUUID()
currentUUID := pinger.tracker.CurrentUUID()
uuidEncoded, err := currentUUID.MarshalBinary()
if err != nil {
t.Fatalf("unable to marshal UUID binary: %s", err)
Expand All @@ -606,8 +607,8 @@ func TestProcessPacket_IgnoresDuplicateSequence(t *testing.T) {
Seq: 0,
Data: data,
}
// register the sequence as sent
pinger.awaitingSequences[currentUUID][0] = struct{}{}
// Register the sequence as sent.
pinger.tracker.AddPacket()

msg := &icmp.Message{
Type: ipv4.ICMPTypeEchoReply,
Expand Down

0 comments on commit 5e92641

Please sign in to comment.