Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add timeout handler and max rtt option #49

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions cmd/ping/ping.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ Examples:

func main() {
timeout := flag.Duration("t", time.Second*100000, "")
maxRtt := flag.Duration("mr", time.Second*3, "")
interval := flag.Duration("i", time.Second, "")
count := flag.Int("c", -1, "")
size := flag.Int("s", 24, "")
Expand Down Expand Up @@ -84,11 +85,15 @@ func main() {
fmt.Printf("round-trip min/avg/max/stddev = %v/%v/%v/%v\n",
stats.MinRtt, stats.AvgRtt, stats.MaxRtt, stats.StdDevRtt)
}
pinger.OnTimeOut = func(packet *probing.Packet) {
fmt.Println("timeout", packet.Addr, packet.Rtt, packet.TTL)
}

pinger.Count = *count
pinger.Size = *size
pinger.Interval = *interval
pinger.Timeout = *timeout
pinger.MaxRtt = *maxRtt
pinger.TTL = *ttl
pinger.SetPrivileged(*privileged)

Expand Down
66 changes: 46 additions & 20 deletions ping.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,24 +92,22 @@ var (
func New(addr string) *Pinger {
r := rand.New(rand.NewSource(getSeed()))
firstUUID := uuid.New()
var firstSequence = map[uuid.UUID]map[int]struct{}{}
firstSequence[firstUUID] = make(map[int]struct{})
return &Pinger{
Count: -1,
Interval: time.Second,
RecordRtts: true,
Size: timeSliceLength + trackerLength,
Timeout: time.Duration(math.MaxInt64),
MaxRtt: time.Duration(math.MaxInt64),

addr: addr,
done: make(chan interface{}),
id: r.Intn(math.MaxUint16),
trackerUUIDs: []uuid.UUID{firstUUID},
ipaddr: nil,
ipv4: false,
network: "ip",
protocol: "udp",
awaitingSequences: firstSequence,
awaitingSequences: newSeqMap(firstUUID),
TTL: 64,
logger: StdLogger{Logger: log.New(log.Writer(), log.Prefix(), log.Flags())},
}
Expand All @@ -129,6 +127,9 @@ type Pinger struct {
// Timeout specifies a timeout before ping exits, regardless of how many
// packets have been received.
Timeout time.Duration
// MaxRtt If no response is received after this time, OnTimeout is called
// important! This option is not guaranteed. and if we receive the packet that was timeout, the function OnDuplicateRecv will be called
MaxRtt time.Duration

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure either way: Should we capitalize RTT as in common style guidelines, or keep it as Rtt for consistency with RecordRtts?


// Count tells pinger to stop after sending (and receiving) Count echo
// packets. If this option is not specified, pinger will operate until
Expand Down Expand Up @@ -183,6 +184,8 @@ type Pinger struct {
// OnRecvError is called when an error occurs while Pinger attempts to receive a packet
OnRecvError func(error)

// OnTimeOut is called when a packet don't have received after MaxRtt.
OnTimeOut func(*Packet)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find this naming confusing: There's a Timeout field, but OnTimeOut is for MaxRtt.

The spelling is also inconsistent with the comment on MaxRtt – I believe Timeout is more correct.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed. There's a total runtime timeout, and this introduces a packet timeout.

Perhaps OnPacketTimeout and PacketTimeout instead of MaxRtt.

// Size of packet being sent
Size int

Expand All @@ -205,14 +208,11 @@ type Pinger struct {
// df when true sets the do-not-fragment bit in the outer IP or IPv6 header
df bool

// trackerUUIDs is the list of UUIDs being used for sending packets.
trackerUUIDs []uuid.UUID

ipv4 bool
id int
sequence int
// awaitingSequences are in-flight sequence numbers we keep track of to help remove duplicate receipts
awaitingSequences map[uuid.UUID]map[int]struct{}
awaitingSequences seqMap
// network is one of "ip", "ip4", or "ip6".
network string
// protocol is "icmp" or "udp".
Expand Down Expand Up @@ -520,20 +520,48 @@ func (p *Pinger) runLoop(

timeout := time.NewTicker(p.Timeout)
interval := time.NewTicker(p.Interval)
packetTimeout := time.NewTimer(time.Duration(math.MaxInt64))

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I associate timeouts with contexts, would those be suitable here? Would it make a difference to how we would handle them?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're trying to be very careful about how we use stdlib functions here. We want to avoid any cases where managing the timeout would spawn a goroutine. If you had an interval of 100ms, and a timeout of 15s, that would be creating and destroying a lot of goroutines in the background.

Maybe a context call would be efficient enough? It would be useful to benchmark.

skip := false
defer func() {
interval.Stop()
timeout.Stop()
packetTimeout.Stop()
}()

if err := p.sendICMP(conn); err != nil {
return err
}

for {
if !skip && !packetTimeout.Stop() {
<-packetTimeout.C
}
skip = false
first := p.awaitingSequences.peekFirst()
if first != nil {
packetTimeout.Reset(time.Until(first.time.Add(p.MaxRtt)))
} else {
packetTimeout.Reset(time.Duration(math.MaxInt64))
}

select {
case <-p.done:
return nil

case <-packetTimeout.C:
skip = true
p.awaitingSequences.removeElem(first)
if p.OnTimeOut != nil {
inPkt := &Packet{
IPAddr: p.ipaddr,
Addr: p.addr,
Rtt: p.MaxRtt,
Seq: first.seq,
TTL: -1,
ID: p.id,
}
p.OnTimeOut(inPkt)
}
case <-timeout.C:
return nil

Expand Down Expand Up @@ -680,18 +708,15 @@ func (p *Pinger) getPacketUUID(pkt []byte) (*uuid.UUID, error) {
if err != nil {
return nil, fmt.Errorf("error decoding tracking UUID: %w", err)
}

for _, item := range p.trackerUUIDs {
if item == packetUUID {
return &packetUUID, nil
}
if p.awaitingSequences.checkUUIDExist(packetUUID) {
return &packetUUID, nil
}
return nil, nil
}

// getCurrentTrackerUUID grabs the latest tracker UUID.
func (p *Pinger) getCurrentTrackerUUID() uuid.UUID {
return p.trackerUUIDs[len(p.trackerUUIDs)-1]
return p.awaitingSequences.getCurUUID()
}

func (p *Pinger) processPacket(recv *packet) error {
Expand Down Expand Up @@ -744,15 +769,16 @@ func (p *Pinger) processPacket(recv *packet) error {
inPkt.Rtt = receivedAt.Sub(timestamp)
inPkt.Seq = pkt.Seq
// If we've already received this sequence, ignore it.
if _, inflight := p.awaitingSequences[*pktUUID][pkt.Seq]; !inflight {
e, inflight := p.awaitingSequences.getElem(*pktUUID, pkt.Seq)
if !inflight {
p.PacketsRecvDuplicates++
if p.OnDuplicateRecv != nil {
p.OnDuplicateRecv(inPkt)
}
return nil
}
// remove it from the list of sequences we're waiting for so we don't get duplicates.
delete(p.awaitingSequences[*pktUUID], pkt.Seq)
p.awaitingSequences.removeElem(e)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What protects awaitingSequences from concurrent access?

p.updateStatistics(inPkt)
default:
// Very bad, not sure how this can happen
Expand All @@ -777,7 +803,8 @@ func (p *Pinger) sendICMP(conn packetConn) error {
if err != nil {
return fmt.Errorf("unable to marshal UUID binary: %w", err)
}
t := append(timeToBytes(time.Now()), uuidEncoded...)
now := time.Now()
t := append(timeToBytes(now), uuidEncoded...)
if remainSize := p.Size - timeSliceLength - trackerLength; remainSize > 0 {
t = append(t, bytes.Repeat([]byte{1}, remainSize)...)
}
Expand Down Expand Up @@ -829,13 +856,12 @@ func (p *Pinger) sendICMP(conn packetConn) error {
p.OnSend(outPkt)
}
// mark this sequence as in-flight
p.awaitingSequences[currentUUID][p.sequence] = struct{}{}
p.awaitingSequences.putElem(currentUUID, p.sequence, now)
p.PacketsSent++
p.sequence++
if p.sequence > 65535 {
newUUID := uuid.New()
p.trackerUUIDs = append(p.trackerUUIDs, newUUID)
p.awaitingSequences[newUUID] = make(map[int]struct{})
p.awaitingSequences.newSeqMap(newUUID)
p.sequence = 0
}
break
Expand Down
13 changes: 9 additions & 4 deletions ping_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ func TestProcessPacket(t *testing.T) {
if err != nil {
t.Fatalf("unable to marshal UUID binary: %s", err)
}
data := append(timeToBytes(time.Now()), uuidEncoded...)
now := time.Now()
data := append(timeToBytes(now), uuidEncoded...)
if remainSize := pinger.Size - timeSliceLength - trackerLength; remainSize > 0 {
data = append(data, bytes.Repeat([]byte{1}, remainSize)...)
}
Expand All @@ -39,7 +40,8 @@ func TestProcessPacket(t *testing.T) {
Seq: pinger.sequence,
Data: data,
}
pinger.awaitingSequences[currentUUID][pinger.sequence] = struct{}{}
pinger.awaitingSequences.putElem(currentUUID, pinger.sequence, now)
//pinger.awaitingSequences[currentUUID][pinger.sequence] = struct{}{}

msg := &icmp.Message{
Type: ipv4.ICMPTypeEchoReply,
Expand Down Expand Up @@ -598,7 +600,8 @@ func TestProcessPacket_IgnoresDuplicateSequence(t *testing.T) {
if err != nil {
t.Fatalf("unable to marshal UUID binary: %s", err)
}
data := append(timeToBytes(time.Now()), uuidEncoded...)
now := time.Now()
data := append(timeToBytes(now), uuidEncoded...)
if remainSize := pinger.Size - timeSliceLength - trackerLength; remainSize > 0 {
data = append(data, bytes.Repeat([]byte{1}, remainSize)...)
}
Expand All @@ -609,7 +612,9 @@ func TestProcessPacket_IgnoresDuplicateSequence(t *testing.T) {
Data: data,
}
// register the sequence as sent
pinger.awaitingSequences[currentUUID][0] = struct{}{}

pinger.awaitingSequences.putElem(currentUUID, 0, now)
//pinger.awaitingSequences[currentUUID][0] = struct{}{}

msg := &icmp.Message{
Type: ipv4.ICMPTypeEchoReply,
Expand Down
76 changes: 76 additions & 0 deletions seq_map.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package probing

import (
"github.com/google/uuid"
"time"
)

type seqMap struct {
curUUID uuid.UUID
head *elem
tail *elem
uuidMap map[uuid.UUID]map[int]*elem
}
type elem struct {
uuid uuid.UUID
seq int
time time.Time
prev *elem
next *elem
}

func newSeqMap(u uuid.UUID) seqMap {
s := seqMap{
curUUID: u,
head: &elem{},
tail: &elem{},
uuidMap: map[uuid.UUID]map[int]*elem{},
}
s.uuidMap[u] = make(map[int]*elem)
s.head.next = s.tail
s.tail.prev = s.head
return s
}

func (s seqMap) newSeqMap(u uuid.UUID) {
s.curUUID = u
s.uuidMap[u] = make(map[int]*elem)
}

func (s seqMap) putElem(uuid uuid.UUID, seq int, now time.Time) {
e := &elem{
uuid: uuid,
seq: seq,
time: now,
prev: s.tail.prev,
next: s.tail,
}
s.tail.prev.next = e
s.tail.prev = e
s.uuidMap[uuid][seq] = e
}
func (s seqMap) getElem(uuid uuid.UUID, seq int) (*elem, bool) {
e, ok := s.uuidMap[uuid][seq]
return e, ok
}
func (s seqMap) removeElem(e *elem) {
e.prev.next = e.next
e.next.prev = e.prev
if m, ok := s.uuidMap[e.uuid]; ok {
delete(m, e.seq)
}
}
func (s seqMap) peekFirst() *elem {
if s.head.next == s.tail {
return nil
}
return s.head.next
}
func (s seqMap) getCurUUID() uuid.UUID {
return s.curUUID
}

func (s seqMap) checkUUIDExist(u uuid.UUID) bool {
_, ok := s.uuidMap[u]
return ok
}
Loading