Skip to content

Commit

Permalink
add timeout handler and max rtt option
Browse files Browse the repository at this point in the history
you can set MaxRtt and OnTimeout func

Ontime func while be call when a request was not answered within a specified time

Signed-off-by: zhangenyao <[email protected]>
  • Loading branch information
zhangenyao authored and zey1996 committed Jun 21, 2023
1 parent 23b417c commit c225ef2
Show file tree
Hide file tree
Showing 5 changed files with 226 additions and 20 deletions.
5 changes: 5 additions & 0 deletions cmd/ping/ping.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ Examples:

func main() {
timeout := flag.Duration("t", time.Second*100000, "")
maxRtt := flag.Duration("mr", time.Second*3, "")
interval := flag.Duration("i", time.Second, "")
count := flag.Int("c", -1, "")
size := flag.Int("s", 24, "")
Expand Down Expand Up @@ -84,11 +85,15 @@ func main() {
fmt.Printf("round-trip min/avg/max/stddev = %v/%v/%v/%v\n",
stats.MinRtt, stats.AvgRtt, stats.MaxRtt, stats.StdDevRtt)
}
pinger.OnTimeOut = func(packet *probing.Packet) {
fmt.Println("timeout", packet.Addr, packet.Rtt, packet.TTL)
}

pinger.Count = *count
pinger.Size = *size
pinger.Interval = *interval
pinger.Timeout = *timeout
pinger.MaxRtt = *maxRtt
pinger.TTL = *ttl
pinger.SetPrivileged(*privileged)

Expand Down
63 changes: 47 additions & 16 deletions ping.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,24 +92,22 @@ var (
func New(addr string) *Pinger {
r := rand.New(rand.NewSource(getSeed()))
firstUUID := uuid.New()
var firstSequence = map[uuid.UUID]map[int]struct{}{}
firstSequence[firstUUID] = make(map[int]struct{})
return &Pinger{
Count: -1,
Interval: time.Second,
RecordRtts: true,
Size: timeSliceLength + trackerLength,
Timeout: time.Duration(math.MaxInt64),
MaxRtt: time.Duration(math.MaxInt64),

addr: addr,
done: make(chan interface{}),
id: r.Intn(math.MaxUint16),
trackerUUIDs: []uuid.UUID{firstUUID},
ipaddr: nil,
ipv4: false,
network: "ip",
protocol: "udp",
awaitingSequences: firstSequence,
awaitingSequences: newSeqMap(firstUUID),
TTL: 64,
logger: StdLogger{Logger: log.New(log.Writer(), log.Prefix(), log.Flags())},
}
Expand All @@ -129,6 +127,9 @@ type Pinger struct {
// Timeout specifies a timeout before ping exits, regardless of how many
// packets have been received.
Timeout time.Duration
// MaxRtt If no response is received after this time, OnTimeout is called
// important! This option is not guaranteed. and if we receive the packet that was timeout, the function OnDuplicateRecv will be called
MaxRtt time.Duration

// Count tells pinger to stop after sending (and receiving) Count echo
// packets. If this option is not specified, pinger will operate until
Expand Down Expand Up @@ -183,6 +184,8 @@ type Pinger struct {
// OnRecvError is called when an error occurs while Pinger attempts to receive a packet
OnRecvError func(error)

// OnTimeOut is called when a packet don't have received after MaxRtt.
OnTimeOut func(*Packet)
// Size of packet being sent
Size int

Expand All @@ -205,14 +208,11 @@ type Pinger struct {
// df when true sets the do-not-fragment bit in the outer IP or IPv6 header
df bool

// trackerUUIDs is the list of UUIDs being used for sending packets.
trackerUUIDs []uuid.UUID

ipv4 bool
id int
sequence int
// awaitingSequences are in-flight sequence numbers we keep track of to help remove duplicate receipts
awaitingSequences map[uuid.UUID]map[int]struct{}
awaitingSequences seqMap
// network is one of "ip", "ip4", or "ip6".
network string
// protocol is "icmp" or "udp".
Expand Down Expand Up @@ -520,20 +520,50 @@ func (p *Pinger) runLoop(

timeout := time.NewTicker(p.Timeout)
interval := time.NewTicker(p.Interval)
timeoutTimer := time.NewTimer(time.Duration(math.MaxInt64))
skip := false
defer func() {
interval.Stop()
timeout.Stop()
timeoutTimer.Stop()
}()

if err := p.sendICMP(conn); err != nil {
return err
}

for {
if !skip {
if !timeoutTimer.Stop() {
<-timeoutTimer.C
}
}
skip = false
first := p.awaitingSequences.peekFirst()
if first != nil {
timeoutTimer.Reset(time.Until(first.time.Add(p.MaxRtt)))
} else {
timeoutTimer.Reset(time.Duration(math.MaxInt64))
}

select {
case <-p.done:
return nil

case <-timeoutTimer.C:
skip = true
p.awaitingSequences.removeElem(first)
if p.OnTimeOut != nil {
inPkt := &Packet{
IPAddr: p.ipaddr,
Addr: p.addr,
Rtt: p.MaxRtt,
Seq: first.seq,
TTL: -1,
ID: p.id,
}
p.OnTimeOut(inPkt)
}
case <-timeout.C:
return nil

Expand Down Expand Up @@ -681,7 +711,7 @@ func (p *Pinger) getPacketUUID(pkt []byte) (*uuid.UUID, error) {
return nil, fmt.Errorf("error decoding tracking UUID: %w", err)
}

for _, item := range p.trackerUUIDs {
for _, item := range p.awaitingSequences.trackerUUIDs {
if item == packetUUID {
return &packetUUID, nil
}
Expand All @@ -691,7 +721,7 @@ func (p *Pinger) getPacketUUID(pkt []byte) (*uuid.UUID, error) {

// getCurrentTrackerUUID grabs the latest tracker UUID.
func (p *Pinger) getCurrentTrackerUUID() uuid.UUID {
return p.trackerUUIDs[len(p.trackerUUIDs)-1]
return p.awaitingSequences.trackerUUIDs[len(p.awaitingSequences.trackerUUIDs)-1]
}

func (p *Pinger) processPacket(recv *packet) error {
Expand Down Expand Up @@ -744,15 +774,16 @@ func (p *Pinger) processPacket(recv *packet) error {
inPkt.Rtt = receivedAt.Sub(timestamp)
inPkt.Seq = pkt.Seq
// If we've already received this sequence, ignore it.
if _, inflight := p.awaitingSequences[*pktUUID][pkt.Seq]; !inflight {
e, inflight := p.awaitingSequences.getElem(*pktUUID, pkt.Seq)
if !inflight {
p.PacketsRecvDuplicates++
if p.OnDuplicateRecv != nil {
p.OnDuplicateRecv(inPkt)
}
return nil
}
// remove it from the list of sequences we're waiting for so we don't get duplicates.
delete(p.awaitingSequences[*pktUUID], pkt.Seq)
p.awaitingSequences.removeElem(e)
p.updateStatistics(inPkt)
default:
// Very bad, not sure how this can happen
Expand All @@ -777,7 +808,8 @@ func (p *Pinger) sendICMP(conn packetConn) error {
if err != nil {
return fmt.Errorf("unable to marshal UUID binary: %w", err)
}
t := append(timeToBytes(time.Now()), uuidEncoded...)
now := time.Now()
t := append(timeToBytes(now), uuidEncoded...)
if remainSize := p.Size - timeSliceLength - trackerLength; remainSize > 0 {
t = append(t, bytes.Repeat([]byte{1}, remainSize)...)
}
Expand Down Expand Up @@ -829,13 +861,12 @@ func (p *Pinger) sendICMP(conn packetConn) error {
p.OnSend(outPkt)
}
// mark this sequence as in-flight
p.awaitingSequences[currentUUID][p.sequence] = struct{}{}
p.awaitingSequences.putElem(currentUUID, p.sequence, now)
p.PacketsSent++
p.sequence++
if p.sequence > 65535 {
newUUID := uuid.New()
p.trackerUUIDs = append(p.trackerUUIDs, newUUID)
p.awaitingSequences[newUUID] = make(map[int]struct{})
p.awaitingSequences.newSeqMap(newUUID)
p.sequence = 0
}
break
Expand Down
13 changes: 9 additions & 4 deletions ping_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ func TestProcessPacket(t *testing.T) {
if err != nil {
t.Fatalf("unable to marshal UUID binary: %s", err)
}
data := append(timeToBytes(time.Now()), uuidEncoded...)
now := time.Now()
data := append(timeToBytes(now), uuidEncoded...)
if remainSize := pinger.Size - timeSliceLength - trackerLength; remainSize > 0 {
data = append(data, bytes.Repeat([]byte{1}, remainSize)...)
}
Expand All @@ -39,7 +40,8 @@ func TestProcessPacket(t *testing.T) {
Seq: pinger.sequence,
Data: data,
}
pinger.awaitingSequences[currentUUID][pinger.sequence] = struct{}{}
pinger.awaitingSequences.putElem(currentUUID, pinger.sequence, now)
//pinger.awaitingSequences[currentUUID][pinger.sequence] = struct{}{}

msg := &icmp.Message{
Type: ipv4.ICMPTypeEchoReply,
Expand Down Expand Up @@ -598,7 +600,8 @@ func TestProcessPacket_IgnoresDuplicateSequence(t *testing.T) {
if err != nil {
t.Fatalf("unable to marshal UUID binary: %s", err)
}
data := append(timeToBytes(time.Now()), uuidEncoded...)
now := time.Now()
data := append(timeToBytes(now), uuidEncoded...)
if remainSize := pinger.Size - timeSliceLength - trackerLength; remainSize > 0 {
data = append(data, bytes.Repeat([]byte{1}, remainSize)...)
}
Expand All @@ -609,7 +612,9 @@ func TestProcessPacket_IgnoresDuplicateSequence(t *testing.T) {
Data: data,
}
// register the sequence as sent
pinger.awaitingSequences[currentUUID][0] = struct{}{}

pinger.awaitingSequences.putElem(currentUUID, 0, now)
//pinger.awaitingSequences[currentUUID][0] = struct{}{}

msg := &icmp.Message{
Type: ipv4.ICMPTypeEchoReply,
Expand Down
70 changes: 70 additions & 0 deletions seq_map.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package probing

import (
"github.com/google/uuid"
"time"
)

type seqMap struct {
trackerUUIDs []uuid.UUID
head *elem
tail *elem
seqMap map[uuid.UUID]map[int]*elem
}
type elem struct {
uuid uuid.UUID
seq int
time time.Time
prev *elem
next *elem
}

func newSeqMap(u uuid.UUID) seqMap {
s := seqMap{
head: &elem{},
tail: &elem{},
seqMap: map[uuid.UUID]map[int]*elem{},
}
s.trackerUUIDs = append(s.trackerUUIDs, u)
s.seqMap[u] = make(map[int]*elem)
s.head.next = s.tail
s.tail.prev = s.head
return s
}

func (s seqMap) newSeqMap(u uuid.UUID) {
s.trackerUUIDs = append(s.trackerUUIDs, u)
s.seqMap[u] = make(map[int]*elem)
s.head.next = s.tail
s.tail.prev = s.head
}

func (s seqMap) putElem(uuid uuid.UUID, seq int, now time.Time) {
e := &elem{
uuid: uuid,
seq: seq,
time: now,
prev: s.tail.prev,
next: s.tail,
}
s.tail.prev.next = e
s.tail.prev = e
s.seqMap[uuid][seq] = e
}
func (s seqMap) getElem(uuid uuid.UUID, seq int) (*elem, bool) {
e, ok := s.seqMap[uuid][seq]
return e, ok
}
func (s seqMap) removeElem(e *elem) {
e.prev.next = e.next
e.next.prev = e.prev
if m, ok := s.seqMap[e.uuid]; ok {
delete(m, e.seq)
}
}
func (s seqMap) peekFirst() *elem {
if s.head.next == s.tail {
return nil
}
return s.head.next
}
Loading

0 comments on commit c225ef2

Please sign in to comment.