From 58f802bc1789b86e4de5709d53a49c7d22061d49 Mon Sep 17 00:00:00 2001 From: zhangenyao Date: Mon, 5 Jun 2023 18:07:29 +0800 Subject: [PATCH] add timeout handler and max rtt option you can set MaxRtt and OnTimeout func Ontime func while be call when a request was not answered within a specified time Signed-off-by: zhangenyao --- cmd/ping/ping.go | 5 +++ ping.go | 66 ++++++++++++++++++++++++----------- ping_test.go | 13 ++++--- seq_map.go | 76 ++++++++++++++++++++++++++++++++++++++++ seq_map_test.go | 91 ++++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 227 insertions(+), 24 deletions(-) create mode 100644 seq_map.go create mode 100644 seq_map_test.go diff --git a/cmd/ping/ping.go b/cmd/ping/ping.go index 645f2e2..b5600da 100644 --- a/cmd/ping/ping.go +++ b/cmd/ping/ping.go @@ -38,6 +38,7 @@ Examples: func main() { timeout := flag.Duration("t", time.Second*100000, "") + maxRtt := flag.Duration("mr", time.Second*3, "") interval := flag.Duration("i", time.Second, "") count := flag.Int("c", -1, "") size := flag.Int("s", 24, "") @@ -84,11 +85,15 @@ func main() { fmt.Printf("round-trip min/avg/max/stddev = %v/%v/%v/%v\n", stats.MinRtt, stats.AvgRtt, stats.MaxRtt, stats.StdDevRtt) } + pinger.OnTimeOut = func(packet *probing.Packet) { + fmt.Println("timeout", packet.Addr, packet.Rtt, packet.TTL) + } pinger.Count = *count pinger.Size = *size pinger.Interval = *interval pinger.Timeout = *timeout + pinger.MaxRtt = *maxRtt pinger.TTL = *ttl pinger.SetPrivileged(*privileged) diff --git a/ping.go b/ping.go index 2859bed..0ba378f 100644 --- a/ping.go +++ b/ping.go @@ -92,24 +92,22 @@ var ( func New(addr string) *Pinger { r := rand.New(rand.NewSource(getSeed())) firstUUID := uuid.New() - var firstSequence = map[uuid.UUID]map[int]struct{}{} - firstSequence[firstUUID] = make(map[int]struct{}) return &Pinger{ Count: -1, Interval: time.Second, RecordRtts: true, Size: timeSliceLength + trackerLength, Timeout: time.Duration(math.MaxInt64), + MaxRtt: time.Duration(math.MaxInt64), addr: addr, done: make(chan interface{}), id: r.Intn(math.MaxUint16), - trackerUUIDs: []uuid.UUID{firstUUID}, ipaddr: nil, ipv4: false, network: "ip", protocol: "udp", - awaitingSequences: firstSequence, + awaitingSequences: newSeqMap(firstUUID), TTL: 64, logger: StdLogger{Logger: log.New(log.Writer(), log.Prefix(), log.Flags())}, } @@ -129,6 +127,9 @@ type Pinger struct { // Timeout specifies a timeout before ping exits, regardless of how many // packets have been received. Timeout time.Duration + // MaxRtt If no response is received after this time, OnTimeout is called + // important! This option is not guaranteed. and if we receive the packet that was timeout, the function OnDuplicateRecv will be called + MaxRtt time.Duration // Count tells pinger to stop after sending (and receiving) Count echo // packets. If this option is not specified, pinger will operate until @@ -183,6 +184,8 @@ type Pinger struct { // OnRecvError is called when an error occurs while Pinger attempts to receive a packet OnRecvError func(error) + // OnTimeOut is called when a packet don't have received after MaxRtt. + OnTimeOut func(*Packet) // Size of packet being sent Size int @@ -205,14 +208,11 @@ type Pinger struct { // df when true sets the do-not-fragment bit in the outer IP or IPv6 header df bool - // trackerUUIDs is the list of UUIDs being used for sending packets. - trackerUUIDs []uuid.UUID - ipv4 bool id int sequence int // awaitingSequences are in-flight sequence numbers we keep track of to help remove duplicate receipts - awaitingSequences map[uuid.UUID]map[int]struct{} + awaitingSequences seqMap // network is one of "ip", "ip4", or "ip6". network string // protocol is "icmp" or "udp". @@ -520,9 +520,12 @@ func (p *Pinger) runLoop( timeout := time.NewTicker(p.Timeout) interval := time.NewTicker(p.Interval) + packetTimeout := time.NewTimer(time.Duration(math.MaxInt64)) + skip := false defer func() { interval.Stop() timeout.Stop() + packetTimeout.Stop() }() if err := p.sendICMP(conn); err != nil { @@ -530,10 +533,35 @@ func (p *Pinger) runLoop( } for { + if !skip && !packetTimeout.Stop() { + <-packetTimeout.C + } + skip = false + first := p.awaitingSequences.peekFirst() + if first != nil { + packetTimeout.Reset(time.Until(first.time.Add(p.MaxRtt))) + } else { + packetTimeout.Reset(time.Duration(math.MaxInt64)) + } + select { case <-p.done: return nil + case <-packetTimeout.C: + skip = true + p.awaitingSequences.removeElem(first) + if p.OnTimeOut != nil { + inPkt := &Packet{ + IPAddr: p.ipaddr, + Addr: p.addr, + Rtt: p.MaxRtt, + Seq: first.seq, + TTL: -1, + ID: p.id, + } + p.OnTimeOut(inPkt) + } case <-timeout.C: return nil @@ -680,18 +708,15 @@ func (p *Pinger) getPacketUUID(pkt []byte) (*uuid.UUID, error) { if err != nil { return nil, fmt.Errorf("error decoding tracking UUID: %w", err) } - - for _, item := range p.trackerUUIDs { - if item == packetUUID { - return &packetUUID, nil - } + if p.awaitingSequences.checkUUIDExist(packetUUID) { + return &packetUUID, nil } return nil, nil } // getCurrentTrackerUUID grabs the latest tracker UUID. func (p *Pinger) getCurrentTrackerUUID() uuid.UUID { - return p.trackerUUIDs[len(p.trackerUUIDs)-1] + return p.awaitingSequences.getCurUUID() } func (p *Pinger) processPacket(recv *packet) error { @@ -744,7 +769,8 @@ func (p *Pinger) processPacket(recv *packet) error { inPkt.Rtt = receivedAt.Sub(timestamp) inPkt.Seq = pkt.Seq // If we've already received this sequence, ignore it. - if _, inflight := p.awaitingSequences[*pktUUID][pkt.Seq]; !inflight { + e, inflight := p.awaitingSequences.getElem(*pktUUID, pkt.Seq) + if !inflight { p.PacketsRecvDuplicates++ if p.OnDuplicateRecv != nil { p.OnDuplicateRecv(inPkt) @@ -752,7 +778,7 @@ func (p *Pinger) processPacket(recv *packet) error { return nil } // remove it from the list of sequences we're waiting for so we don't get duplicates. - delete(p.awaitingSequences[*pktUUID], pkt.Seq) + p.awaitingSequences.removeElem(e) p.updateStatistics(inPkt) default: // Very bad, not sure how this can happen @@ -777,7 +803,8 @@ func (p *Pinger) sendICMP(conn packetConn) error { if err != nil { return fmt.Errorf("unable to marshal UUID binary: %w", err) } - t := append(timeToBytes(time.Now()), uuidEncoded...) + now := time.Now() + t := append(timeToBytes(now), uuidEncoded...) if remainSize := p.Size - timeSliceLength - trackerLength; remainSize > 0 { t = append(t, bytes.Repeat([]byte{1}, remainSize)...) } @@ -829,13 +856,12 @@ func (p *Pinger) sendICMP(conn packetConn) error { p.OnSend(outPkt) } // mark this sequence as in-flight - p.awaitingSequences[currentUUID][p.sequence] = struct{}{} + p.awaitingSequences.putElem(currentUUID, p.sequence, now) p.PacketsSent++ p.sequence++ if p.sequence > 65535 { newUUID := uuid.New() - p.trackerUUIDs = append(p.trackerUUIDs, newUUID) - p.awaitingSequences[newUUID] = make(map[int]struct{}) + p.awaitingSequences.newSeqMap(newUUID) p.sequence = 0 } break diff --git a/ping_test.go b/ping_test.go index 7f8c7e9..4e0fce6 100644 --- a/ping_test.go +++ b/ping_test.go @@ -29,7 +29,8 @@ func TestProcessPacket(t *testing.T) { if err != nil { t.Fatalf("unable to marshal UUID binary: %s", err) } - data := append(timeToBytes(time.Now()), uuidEncoded...) + now := time.Now() + data := append(timeToBytes(now), uuidEncoded...) if remainSize := pinger.Size - timeSliceLength - trackerLength; remainSize > 0 { data = append(data, bytes.Repeat([]byte{1}, remainSize)...) } @@ -39,7 +40,8 @@ func TestProcessPacket(t *testing.T) { Seq: pinger.sequence, Data: data, } - pinger.awaitingSequences[currentUUID][pinger.sequence] = struct{}{} + pinger.awaitingSequences.putElem(currentUUID, pinger.sequence, now) + //pinger.awaitingSequences[currentUUID][pinger.sequence] = struct{}{} msg := &icmp.Message{ Type: ipv4.ICMPTypeEchoReply, @@ -598,7 +600,8 @@ func TestProcessPacket_IgnoresDuplicateSequence(t *testing.T) { if err != nil { t.Fatalf("unable to marshal UUID binary: %s", err) } - data := append(timeToBytes(time.Now()), uuidEncoded...) + now := time.Now() + data := append(timeToBytes(now), uuidEncoded...) if remainSize := pinger.Size - timeSliceLength - trackerLength; remainSize > 0 { data = append(data, bytes.Repeat([]byte{1}, remainSize)...) } @@ -609,7 +612,9 @@ func TestProcessPacket_IgnoresDuplicateSequence(t *testing.T) { Data: data, } // register the sequence as sent - pinger.awaitingSequences[currentUUID][0] = struct{}{} + + pinger.awaitingSequences.putElem(currentUUID, 0, now) + //pinger.awaitingSequences[currentUUID][0] = struct{}{} msg := &icmp.Message{ Type: ipv4.ICMPTypeEchoReply, diff --git a/seq_map.go b/seq_map.go new file mode 100644 index 0000000..be8d8c5 --- /dev/null +++ b/seq_map.go @@ -0,0 +1,76 @@ +package probing + +import ( + "github.com/google/uuid" + "time" +) + +type seqMap struct { + curUUID uuid.UUID + head *elem + tail *elem + uuidMap map[uuid.UUID]map[int]*elem +} +type elem struct { + uuid uuid.UUID + seq int + time time.Time + prev *elem + next *elem +} + +func newSeqMap(u uuid.UUID) seqMap { + s := seqMap{ + curUUID: u, + head: &elem{}, + tail: &elem{}, + uuidMap: map[uuid.UUID]map[int]*elem{}, + } + s.uuidMap[u] = make(map[int]*elem) + s.head.next = s.tail + s.tail.prev = s.head + return s +} + +func (s seqMap) newSeqMap(u uuid.UUID) { + s.curUUID = u + s.uuidMap[u] = make(map[int]*elem) +} + +func (s seqMap) putElem(uuid uuid.UUID, seq int, now time.Time) { + e := &elem{ + uuid: uuid, + seq: seq, + time: now, + prev: s.tail.prev, + next: s.tail, + } + s.tail.prev.next = e + s.tail.prev = e + s.uuidMap[uuid][seq] = e +} +func (s seqMap) getElem(uuid uuid.UUID, seq int) (*elem, bool) { + e, ok := s.uuidMap[uuid][seq] + return e, ok +} +func (s seqMap) removeElem(e *elem) { + e.prev.next = e.next + e.next.prev = e.prev + if m, ok := s.uuidMap[e.uuid]; ok { + delete(m, e.seq) + } +} +func (s seqMap) peekFirst() *elem { + if s.head.next == s.tail { + return nil + } + return s.head.next +} +func (s seqMap) getCurUUID() uuid.UUID { + return s.curUUID +} + +func (s seqMap) checkUUIDExist(u uuid.UUID) bool { + _, ok := s.uuidMap[u] + return ok +} diff --git a/seq_map_test.go b/seq_map_test.go new file mode 100644 index 0000000..efd4c22 --- /dev/null +++ b/seq_map_test.go @@ -0,0 +1,91 @@ +package probing + +import ( + "github.com/google/uuid" + "testing" + "time" +) + +func TestSeqMap(t *testing.T) { + u := uuid.New() + s := newSeqMap(u) + t.Run("newSeqMap", func(t *testing.T) { + u2 := uuid.New() + s.newSeqMap(u2) + + if _, ok := s.uuidMap[u2]; !ok { + t.Errorf("Expected seqMap to contain UUID %s", u2.String()) + } + }) + + t.Run("putElem", func(t *testing.T) { + seq := 1 + s.putElem(u, seq, time.Now()) + + // Check that seqMap contains the correct elements + if _, ok := s.uuidMap[u][seq]; !ok { + t.Errorf("Expected seqMap[%s][%d] to exist", u.String(), seq) + } + + if s.peekFirst().seq != seq { + t.Errorf("Expected tail.prev.seq to be %d, got %d", seq, s.tail.prev.seq) + } + }) + + t.Run("getElem", func(t *testing.T) { + seq := 1 + elem, ok := s.getElem(u, seq) + + if !ok { + t.Errorf("Expected getElem to return true for existing element") + } + + if elem.seq != seq { + t.Errorf("Expected element's seq to be %d, got %d", seq, elem.seq) + } + }) + + t.Run("removeElem", func(t *testing.T) { + seq := 1 + elem, ok := s.getElem(u, seq) + if !ok { + t.Fatalf("Expected getElem to return true for existing element") + } + + s.removeElem(elem) + + if _, ok := s.uuidMap[u][seq]; ok { + t.Errorf("Expected seqMap[%s][%d] to be removed", u.String(), seq) + } + }) + + // test peekFirst + t.Run("peekFirst", func(t *testing.T) { + seq := 2 + s.putElem(u, seq, time.Now()) + + elem := s.peekFirst() + + // Check that peekFirst returns the first element of the linked list + if elem.seq != seq { + t.Errorf("Expected peekFirst to return element with seq %d, got %d", seq, elem.seq) + } + }) +} + +func TestSeqMap2(t *testing.T) { + u := uuid.New() + s := newSeqMap(u) + for i := 0; i < 100; i++ { + s.putElem(u, i, time.Now()) + } + for i := 0; i < 100; i++ { + e, ok := s.getElem(u, i) + AssertTrue(t, ok && e.seq == i) + } + for i := 0; i < 20; i++ { + first := s.peekFirst() + AssertTrue(t, first.seq == i) + s.removeElem(first) + } +}