From d6c2ecb0874647d8ce7d610986ec91b5f120fe5d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mikk=20Margus=20M=C3=B6ll?= Date: Sat, 27 Aug 2022 20:10:05 +0300 Subject: [PATCH 1/3] move from net to netip --- brute.go | 34 ++++++++--------- cidranger.go | 28 +++++++------- net/ip.go | 101 +++++++++++++++++++++++++++++---------------------- trie.go | 30 ++++++++------- version.go | 23 ++++++------ 5 files changed, 116 insertions(+), 100 deletions(-) diff --git a/brute.go b/brute.go index 37a68be..3147f4f 100644 --- a/brute.go +++ b/brute.go @@ -1,7 +1,7 @@ package cidranger import ( - "net" + "net/netip" rnet "github.com/yl2chen/cidranger/net" ) @@ -17,24 +17,24 @@ import ( // and used as the ground truth when running a wider range of 'random' tests on // other more sophisticated implementations. type bruteRanger struct { - ipV4Entries map[string]RangerEntry - ipV6Entries map[string]RangerEntry + ipV4Entries map[netip.Prefix]RangerEntry + ipV6Entries map[netip.Prefix]RangerEntry } // newBruteRanger returns a new Ranger. func newBruteRanger() Ranger { return &bruteRanger{ - ipV4Entries: make(map[string]RangerEntry), - ipV6Entries: make(map[string]RangerEntry), + ipV4Entries: make(map[netip.Prefix]RangerEntry), + ipV6Entries: make(map[netip.Prefix]RangerEntry), } } // Insert inserts a RangerEntry into ranger. func (b *bruteRanger) Insert(entry RangerEntry) error { network := entry.Network() - key := network.String() + key := network if _, found := b.ipV4Entries[key]; !found { - entries, err := b.getEntriesByVersion(entry.Network().IP) + entries, err := b.getEntriesByVersion(entry.Network().Addr()) if err != nil { return err } @@ -44,12 +44,12 @@ func (b *bruteRanger) Insert(entry RangerEntry) error { } // Remove removes a RangerEntry identified by given network from ranger. -func (b *bruteRanger) Remove(network net.IPNet) (RangerEntry, error) { - networks, err := b.getEntriesByVersion(network.IP) +func (b *bruteRanger) Remove(network netip.Prefix) (RangerEntry, error) { + networks, err := b.getEntriesByVersion(network.Addr()) if err != nil { return nil, err } - key := network.String() + key := network if networkToDelete, found := networks[key]; found { delete(networks, key) return networkToDelete, nil @@ -59,7 +59,7 @@ func (b *bruteRanger) Remove(network net.IPNet) (RangerEntry, error) { // Contains returns bool indicating whether given ip is contained by any // network in ranger. -func (b *bruteRanger) Contains(ip net.IP) (bool, error) { +func (b *bruteRanger) Contains(ip netip.Addr) (bool, error) { entries, err := b.getEntriesByVersion(ip) if err != nil { return false, err @@ -74,7 +74,7 @@ func (b *bruteRanger) Contains(ip net.IP) (bool, error) { } // ContainingNetworks returns all RangerEntry(s) that given ip contained in. -func (b *bruteRanger) ContainingNetworks(ip net.IP) ([]RangerEntry, error) { +func (b *bruteRanger) ContainingNetworks(ip netip.Addr) ([]RangerEntry, error) { entries, err := b.getEntriesByVersion(ip) if err != nil { return nil, err @@ -92,8 +92,8 @@ func (b *bruteRanger) ContainingNetworks(ip net.IP) ([]RangerEntry, error) { // CoveredNetworks returns the list of RangerEntry(s) the given ipnet // covers. That is, the networks that are completely subsumed by the // specified network. -func (b *bruteRanger) CoveredNetworks(network net.IPNet) ([]RangerEntry, error) { - entries, err := b.getEntriesByVersion(network.IP) +func (b *bruteRanger) CoveredNetworks(network netip.Prefix) ([]RangerEntry, error) { + entries, err := b.getEntriesByVersion(network.Addr()) if err != nil { return nil, err } @@ -113,11 +113,11 @@ func (b *bruteRanger) Len() int { return len(b.ipV4Entries) + len(b.ipV6Entries) } -func (b *bruteRanger) getEntriesByVersion(ip net.IP) (map[string]RangerEntry, error) { - if ip.To4() != nil { +func (b *bruteRanger) getEntriesByVersion(ip netip.Addr) (map[netip.Prefix]RangerEntry, error) { + if ip.Is4() { return b.ipV4Entries, nil } - if ip.To16() != nil { + if ip.Is6() { return b.ipV6Entries, nil } return nil, ErrInvalidNetworkInput diff --git a/cidranger.go b/cidranger.go index 2e8f118..f5aa859 100644 --- a/cidranger.go +++ b/cidranger.go @@ -41,14 +41,14 @@ package cidranger import ( "fmt" - "net" + "net/netip" ) // ErrInvalidNetworkInput is returned upon invalid network input. -var ErrInvalidNetworkInput = fmt.Errorf("Invalid network input") +var ErrInvalidNetworkInput = fmt.Errorf("invalid network input") // ErrInvalidNetworkNumberInput is returned upon invalid network input. -var ErrInvalidNetworkNumberInput = fmt.Errorf("Invalid network number input") +var ErrInvalidNetworkNumberInput = fmt.Errorf("invalid network number input") // AllIPv4 is a IPv4 CIDR that contains all networks var AllIPv4 = parseCIDRUnsafe("0.0.0.0/0") @@ -56,39 +56,39 @@ var AllIPv4 = parseCIDRUnsafe("0.0.0.0/0") // AllIPv6 is a IPv6 CIDR that contains all networks var AllIPv6 = parseCIDRUnsafe("0::0/0") -func parseCIDRUnsafe(s string) *net.IPNet { - _, cidr, _ := net.ParseCIDR(s) +func parseCIDRUnsafe(s string) netip.Prefix { + cidr, _ := netip.ParsePrefix(s) return cidr } // RangerEntry is an interface for insertable entry into a Ranger. type RangerEntry interface { - Network() net.IPNet + Network() netip.Prefix } type basicRangerEntry struct { - ipNet net.IPNet + ipNet netip.Prefix } -func (b *basicRangerEntry) Network() net.IPNet { +func (b *basicRangerEntry) Network() netip.Prefix { return b.ipNet } // NewBasicRangerEntry returns a basic RangerEntry that only stores the network // itself. -func NewBasicRangerEntry(ipNet net.IPNet) RangerEntry { +func NewBasicRangerEntry(ipNet netip.Prefix) RangerEntry { return &basicRangerEntry{ - ipNet: ipNet, + ipNet: ipNet, //.Masked(), } } // Ranger is an interface for cidr block containment lookups. type Ranger interface { Insert(entry RangerEntry) error - Remove(network net.IPNet) (RangerEntry, error) - Contains(ip net.IP) (bool, error) - ContainingNetworks(ip net.IP) ([]RangerEntry, error) - CoveredNetworks(network net.IPNet) ([]RangerEntry, error) + Remove(network netip.Prefix) (RangerEntry, error) + Contains(ip netip.Addr) (bool, error) + ContainingNetworks(ip netip.Addr) ([]RangerEntry, error) + CoveredNetworks(network netip.Prefix) ([]RangerEntry, error) Len() int } diff --git a/net/ip.go b/net/ip.go index 75cb356..033c1fe 100644 --- a/net/ip.go +++ b/net/ip.go @@ -4,11 +4,10 @@ Package net provides utility functions for working with IPs (net.IP). package net import ( - "bytes" "encoding/binary" "fmt" "math" - "net" + "net/netip" ) // IPVersion is version of IP address. @@ -34,7 +33,7 @@ var ErrVersionMismatch = fmt.Errorf("Network input version mismatch") // ErrNoGreatestCommonBit is an error returned when no greatest common bit // exists for the cidr ranges. -var ErrNoGreatestCommonBit = fmt.Errorf("No greatest common bit") +var ErrNoGreatestCommonBit = fmt.Errorf("no greatest common bit") // NetworkNumber represents an IP address using uint32 as internal storage. // IPv4 usings 1 uint32, while IPv6 uses 4 uint32. @@ -42,23 +41,20 @@ type NetworkNumber []uint32 // NewNetworkNumber returns a equivalent NetworkNumber to given IP address, // return nil if ip is neither IPv4 nor IPv6. -func NewNetworkNumber(ip net.IP) NetworkNumber { - if ip == nil { - return nil - } - coercedIP := ip.To4() - parts := 1 - if coercedIP == nil { - coercedIP = ip.To16() +func NewNetworkNumber(ip netip.Addr) NetworkNumber { + var parts int + if ip.Is4() { + parts = 1 + } else if ip.Is6() { parts = 4 - } - if coercedIP == nil { + } else { return nil } + nn := make(NetworkNumber, parts) + sl := ip.AsSlice() for i := 0; i < parts; i++ { - idx := i * net.IPv4len - nn[i] = binary.BigEndian.Uint32(coercedIP[idx : idx+net.IPv4len]) + nn[i] = binary.BigEndian.Uint32(sl[i*4 : (i+1)*4]) } return nn } @@ -80,15 +76,12 @@ func (n NetworkNumber) ToV6() NetworkNumber { } // ToIP returns equivalent net.IP. -func (n NetworkNumber) ToIP() net.IP { - ip := make(net.IP, len(n)*BytePerUint32) +func (n NetworkNumber) ToIP() netip.Addr { + sl := make([]byte, len(n)*BytePerUint32) for i := 0; i < len(n); i++ { - idx := i * net.IPv4len - binary.BigEndian.PutUint32(ip[idx:idx+net.IPv4len], n[i]) - } - if len(ip) == net.IPv4len { - ip = net.IPv4(ip[0], ip[1], ip[2], ip[3]) + binary.BigEndian.PutUint32(sl[i*4:(i+1)*4], n[i]) } + ip, _ := netip.AddrFromSlice(sl) return ip } @@ -171,27 +164,44 @@ func (n NetworkNumber) LeastCommonBitPosition(n1 NetworkNumber) (uint, error) { // Network represents a block of network numbers, also known as CIDR. type Network struct { - net.IPNet + IPNet netip.Prefix Number NetworkNumber Mask NetworkNumberMask } // NewNetwork returns Network built using given net.IPNet. -func NewNetwork(ipNet net.IPNet) Network { +func NewNetwork(ipNet netip.Prefix) Network { return Network{ - IPNet: ipNet, - Number: NewNetworkNumber(ipNet.IP), - Mask: NetworkNumberMask(NewNetworkNumber(net.IP(ipNet.Mask))), + IPNet: ipNet, //.Masked(), + Number: NewNetworkNumber(ipNet.Addr()), + Mask: bitsToMask(ipNet.Bits(), ipNet.Addr().BitLen()), + } +} + +func bitsToMask(ones, bits int) NetworkNumberMask { + parts := bits / BitsPerUint32 + sl := make([]uint32, parts) + for i := 0; i < parts; i++ { + if ones == 0 { + break + } + var maskBits uint32 + if ones >= 32 { + maskBits = 0xffff_ffff + ones -= 32 + } else { + maskBits = ((1 << ones) - 1) << (32 - ones) + ones = 0 + } + sl[i] = maskBits } + + return NetworkNumberMask(sl) } // Masked returns a new network conforming to new mask. func (n Network) Masked(ones int) Network { - mask := net.CIDRMask(ones, len(n.Number)*BitsPerUint32) - return NewNetwork(net.IPNet{ - IP: n.IP.Mask(mask), - Mask: mask, - }) + return NewNetwork(netip.PrefixFrom(n.IPNet.Addr(), ones).Masked()) } // Contains returns true if NetworkNumber is in range of Network, false @@ -214,30 +224,33 @@ func (n Network) Covers(o Network) bool { if len(n.Number) != len(o.Number) { return false } - nMaskSize, _ := n.IPNet.Mask.Size() - oMaskSize, _ := o.IPNet.Mask.Size() + nMaskSize := n.IPNet.Bits() + oMaskSize := o.IPNet.Bits() return n.Contains(o.Number) && nMaskSize <= oMaskSize } // LeastCommonBitPosition returns the smallest position of the preceding common // bits of the 2 networks, and returns an error ErrNoGreatestCommonBit // if the two network number diverges from the first bit. -func (n Network) LeastCommonBitPosition(n1 Network) (uint, error) { - maskSize, _ := n.IPNet.Mask.Size() - if maskSize1, _ := n1.IPNet.Mask.Size(); maskSize1 < maskSize { +func (n Network) LeastCommonBitPosition(n1 Network) (max uint, err error) { + maskSize := n.IPNet.Bits() + if maskSize1 := n1.IPNet.Bits(); maskSize1 < maskSize { maskSize = maskSize1 } - maskPosition := len(n1.Number)*BitsPerUint32 - maskSize - lcb, err := n.Number.LeastCommonBitPosition(n1.Number) - if err != nil { + + if max, err = n.Number.LeastCommonBitPosition(n1.Number); err != nil { return 0, err } - return uint(math.Max(float64(maskPosition), float64(lcb))), nil + if maskPosition := uint(len(n1.Number)*BitsPerUint32 - maskSize); maskPosition > max { + max = maskPosition + } + + return max, nil } // Equal is the equality test for 2 networks. func (n Network) Equal(n1 Network) bool { - return bytes.Equal(n.IPNet.IP, n1.IPNet.IP) && bytes.Equal(n.IPNet.Mask, n1.IPNet.Mask) + return n.IPNet == n1.IPNet } func (n Network) String() string { @@ -263,11 +276,11 @@ func (m NetworkNumberMask) Mask(n NetworkNumber) (NetworkNumber, error) { } // NextIP returns the next sequential ip. -func NextIP(ip net.IP) net.IP { +func NextIP(ip netip.Addr) netip.Addr { return NewNetworkNumber(ip).Next().ToIP() } // PreviousIP returns the previous sequential ip. -func PreviousIP(ip net.IP) net.IP { +func PreviousIP(ip netip.Addr) netip.Addr { return NewNetworkNumber(ip).Previous().ToIP() } diff --git a/trie.go b/trie.go index 31976bd..cbf556e 100644 --- a/trie.go +++ b/trie.go @@ -2,7 +2,7 @@ package cidranger import ( "fmt" - "net" + "net/netip" "strings" rnet "github.com/yl2chen/cidranger/net" @@ -48,21 +48,22 @@ type prefixTrie struct { // newPrefixTree creates a new prefixTrie. func newPrefixTree(version rnet.IPVersion) Ranger { - _, rootNet, _ := net.ParseCIDR("0.0.0.0/0") + rootStr := "0.0.0.0/0" if version == rnet.IPv6 { - _, rootNet, _ = net.ParseCIDR("0::0/0") + rootStr = "0::0/0" } + rootNet := netip.MustParsePrefix(rootStr) return &prefixTrie{ - children: make([]*prefixTrie, 2, 2), + children: make([]*prefixTrie, 2), numBitsSkipped: 0, numBitsHandled: 1, - network: rnet.NewNetwork(*rootNet), + network: rnet.NewNetwork(rootNet), } } func newPathprefixTrie(network rnet.Network, numBitsSkipped uint) *prefixTrie { path := &prefixTrie{ - children: make([]*prefixTrie, 2, 2), + children: make([]*prefixTrie, 2), numBitsSkipped: numBitsSkipped, numBitsHandled: 1, network: network.Masked(int(numBitsSkipped)), @@ -71,7 +72,7 @@ func newPathprefixTrie(network rnet.Network, numBitsSkipped uint) *prefixTrie { } func newEntryTrie(network rnet.Network, entry RangerEntry) *prefixTrie { - ones, _ := network.IPNet.Mask.Size() + ones := network.IPNet.Bits() leaf := newPathprefixTrie(network, uint(ones)) leaf.entry = entry return leaf @@ -80,7 +81,7 @@ func newEntryTrie(network rnet.Network, entry RangerEntry) *prefixTrie { // Insert inserts a RangerEntry into prefix trie. func (p *prefixTrie) Insert(entry RangerEntry) error { network := entry.Network() - sizeIncreased, err := p.insert(rnet.NewNetwork(network), entry) + sizeIncreased, err := p.insert(rnet.NewNetwork(network.Masked()), entry) if sizeIncreased { p.size++ } @@ -88,8 +89,8 @@ func (p *prefixTrie) Insert(entry RangerEntry) error { } // Remove removes RangerEntry identified by given network from trie. -func (p *prefixTrie) Remove(network net.IPNet) (RangerEntry, error) { - entry, err := p.remove(rnet.NewNetwork(network)) +func (p *prefixTrie) Remove(network netip.Prefix) (RangerEntry, error) { + entry, err := p.remove(rnet.NewNetwork(network.Masked())) if entry != nil { p.size-- } @@ -98,7 +99,7 @@ func (p *prefixTrie) Remove(network net.IPNet) (RangerEntry, error) { // Contains returns boolean indicating whether given ip is contained in any // of the inserted networks. -func (p *prefixTrie) Contains(ip net.IP) (bool, error) { +func (p *prefixTrie) Contains(ip netip.Addr) (bool, error) { nn := rnet.NewNetworkNumber(ip) if nn == nil { return false, ErrInvalidNetworkNumberInput @@ -108,7 +109,7 @@ func (p *prefixTrie) Contains(ip net.IP) (bool, error) { // ContainingNetworks returns the list of RangerEntry(s) the given ip is // contained in in ascending prefix order. -func (p *prefixTrie) ContainingNetworks(ip net.IP) ([]RangerEntry, error) { +func (p *prefixTrie) ContainingNetworks(ip netip.Addr) ([]RangerEntry, error) { nn := rnet.NewNetworkNumber(ip) if nn == nil { return nil, ErrInvalidNetworkNumberInput @@ -119,7 +120,7 @@ func (p *prefixTrie) ContainingNetworks(ip net.IP) ([]RangerEntry, error) { // CoveredNetworks returns the list of RangerEntry(s) the given ipnet // covers. That is, the networks that are completely subsumed by the // specified network. -func (p *prefixTrie) CoveredNetworks(network net.IPNet) ([]RangerEntry, error) { +func (p *prefixTrie) CoveredNetworks(network netip.Prefix) ([]RangerEntry, error) { net := rnet.NewNetwork(network) return p.coveredNetworks(net) } @@ -239,6 +240,9 @@ func (p *prefixTrie) insert(network rnet.Network, entry RangerEntry) (bool, erro // Check whether it is necessary to insert additional path prefix between current trie and existing child, // in the case that inserted network diverges on its path to existing child. lcb, err := network.LeastCommonBitPosition(existingChild.network) + if err != nil { + return false, err + } divergingBitPos := int(lcb) - 1 if divergingBitPos > existingChild.targetBitPosition() { pathPrefix := newPathprefixTrie(network, p.totalNumberOfBits()-lcb) diff --git a/version.go b/version.go index 2c3fe2b..76106cb 100644 --- a/version.go +++ b/version.go @@ -1,7 +1,7 @@ package cidranger import ( - "net" + "net/netip" rnet "github.com/yl2chen/cidranger/net" ) @@ -22,22 +22,22 @@ func newVersionedRanger(factory rangerFactory) Ranger { func (v *versionedRanger) Insert(entry RangerEntry) error { network := entry.Network() - ranger, err := v.getRangerForIP(network.IP) + ranger, err := v.getRangerForIP(network.Addr()) if err != nil { return err } return ranger.Insert(entry) } -func (v *versionedRanger) Remove(network net.IPNet) (RangerEntry, error) { - ranger, err := v.getRangerForIP(network.IP) +func (v *versionedRanger) Remove(network netip.Prefix) (RangerEntry, error) { + ranger, err := v.getRangerForIP(network.Addr()) if err != nil { return nil, err } return ranger.Remove(network) } -func (v *versionedRanger) Contains(ip net.IP) (bool, error) { +func (v *versionedRanger) Contains(ip netip.Addr) (bool, error) { ranger, err := v.getRangerForIP(ip) if err != nil { return false, err @@ -45,7 +45,7 @@ func (v *versionedRanger) Contains(ip net.IP) (bool, error) { return ranger.Contains(ip) } -func (v *versionedRanger) ContainingNetworks(ip net.IP) ([]RangerEntry, error) { +func (v *versionedRanger) ContainingNetworks(ip netip.Addr) ([]RangerEntry, error) { ranger, err := v.getRangerForIP(ip) if err != nil { return nil, err @@ -53,8 +53,8 @@ func (v *versionedRanger) ContainingNetworks(ip net.IP) ([]RangerEntry, error) { return ranger.ContainingNetworks(ip) } -func (v *versionedRanger) CoveredNetworks(network net.IPNet) ([]RangerEntry, error) { - ranger, err := v.getRangerForIP(network.IP) +func (v *versionedRanger) CoveredNetworks(network netip.Prefix) ([]RangerEntry, error) { + ranger, err := v.getRangerForIP(network.Addr()) if err != nil { return nil, err } @@ -66,11 +66,10 @@ func (v *versionedRanger) Len() int { return v.ipV4Ranger.Len() + v.ipV6Ranger.Len() } -func (v *versionedRanger) getRangerForIP(ip net.IP) (Ranger, error) { - if ip.To4() != nil { +func (v *versionedRanger) getRangerForIP(ip netip.Addr) (Ranger, error) { + if ip.Is4() { return v.ipV4Ranger, nil - } - if ip.To16() != nil { + } else if ip.Is6() { return v.ipV6Ranger, nil } return nil, ErrInvalidNetworkNumberInput From 56effe9d6d076c922ef5940eb6e4016426759681 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mikk=20Margus=20M=C3=B6ll?= Date: Sat, 27 Aug 2022 20:10:23 +0300 Subject: [PATCH 2/3] update tests and example binary --- brute_test.go | 110 +++++++++++++------------------ cidranger_test.go | 89 +++++++++++++------------ example/custom-ranger-asn.go | 20 +++--- net/ip_test.go | 122 +++++++++++++++++------------------ trie_test.go | 116 ++++++++++++++++----------------- 5 files changed, 217 insertions(+), 240 deletions(-) diff --git a/brute_test.go b/brute_test.go index 71ee637..6da5a92 100644 --- a/brute_test.go +++ b/brute_test.go @@ -1,7 +1,7 @@ package cidranger import ( - "net" + "net/netip" "sort" "testing" @@ -10,46 +10,38 @@ import ( func TestInsert(t *testing.T) { ranger := newBruteRanger().(*bruteRanger) - _, networkIPv4, _ := net.ParseCIDR("0.0.1.0/24") - _, networkIPv6, _ := net.ParseCIDR("8000::/96") - entryIPv4 := NewBasicRangerEntry(*networkIPv4) - entryIPv6 := NewBasicRangerEntry(*networkIPv6) + networkIPv4 := netip.MustParsePrefix("0.0.1.0/24") + networkIPv6 := netip.MustParsePrefix("8000::/96") + entryIPv4 := NewBasicRangerEntry(networkIPv4) + entryIPv6 := NewBasicRangerEntry(networkIPv6) ranger.Insert(entryIPv4) ranger.Insert(entryIPv6) assert.Equal(t, 1, len(ranger.ipV4Entries)) - assert.Equal(t, entryIPv4, ranger.ipV4Entries["0.0.1.0/24"]) + assert.Equal(t, entryIPv4, ranger.ipV4Entries[networkIPv4]) assert.Equal(t, 1, len(ranger.ipV6Entries)) - assert.Equal(t, entryIPv6, ranger.ipV6Entries["8000::/96"]) -} - -func TestInsertError(t *testing.T) { - bRanger := newBruteRanger().(*bruteRanger) - _, networkIPv4, _ := net.ParseCIDR("0.0.1.0/24") - networkIPv4.IP = append(networkIPv4.IP, byte(4)) - err := bRanger.Insert(NewBasicRangerEntry(*networkIPv4)) - assert.Equal(t, ErrInvalidNetworkInput, err) + assert.Equal(t, entryIPv6, ranger.ipV6Entries[networkIPv6]) } func TestRemove(t *testing.T) { ranger := newBruteRanger().(*bruteRanger) - _, networkIPv4, _ := net.ParseCIDR("0.0.1.0/24") - _, networkIPv6, _ := net.ParseCIDR("8000::/96") - _, notInserted, _ := net.ParseCIDR("8000::/96") + networkIPv4 := netip.MustParsePrefix("0.0.1.0/24") + networkIPv6 := netip.MustParsePrefix("8000::/96") + notInserted := netip.MustParsePrefix("8000::/96") - insertIPv4 := NewBasicRangerEntry(*networkIPv4) - insertIPv6 := NewBasicRangerEntry(*networkIPv6) + insertIPv4 := NewBasicRangerEntry(networkIPv4) + insertIPv6 := NewBasicRangerEntry(networkIPv6) ranger.Insert(insertIPv4) - deletedIPv4, err := ranger.Remove(*networkIPv4) + deletedIPv4, err := ranger.Remove(networkIPv4) assert.NoError(t, err) ranger.Insert(insertIPv6) - deletedIPv6, err := ranger.Remove(*networkIPv6) + deletedIPv6, err := ranger.Remove(networkIPv6) assert.NoError(t, err) - entry, err := ranger.Remove(*notInserted) + entry, err := ranger.Remove(notInserted) assert.NoError(t, err) assert.Nil(t, entry) @@ -59,33 +51,23 @@ func TestRemove(t *testing.T) { assert.Equal(t, 0, len(ranger.ipV6Entries)) } -func TestRemoveError(t *testing.T) { - r := newBruteRanger().(*bruteRanger) - _, invalidNetwork, _ := net.ParseCIDR("0.0.1.0/24") - invalidNetwork.IP = append(invalidNetwork.IP, byte(4)) - - _, err := r.Remove(*invalidNetwork) - assert.Equal(t, ErrInvalidNetworkInput, err) -} - func TestContains(t *testing.T) { r := newBruteRanger().(*bruteRanger) - _, network, _ := net.ParseCIDR("0.0.1.0/24") - _, network1, _ := net.ParseCIDR("8000::/112") - r.Insert(NewBasicRangerEntry(*network)) - r.Insert(NewBasicRangerEntry(*network1)) + network := netip.MustParsePrefix("0.0.1.0/24") + network1 := netip.MustParsePrefix("8000::/112") + r.Insert(NewBasicRangerEntry(network)) + r.Insert(NewBasicRangerEntry(network1)) cases := []struct { - ip net.IP + ip netip.Addr contains bool err error name string }{ - {net.ParseIP("0.0.1.255"), true, nil, "IPv4 should contain"}, - {net.ParseIP("0.0.0.255"), false, nil, "IPv4 houldn't contain"}, - {net.ParseIP("8000::ffff"), true, nil, "IPv6 shouldn't contain"}, - {net.ParseIP("8000::1:ffff"), false, nil, "IPv6 shouldn't contain"}, - {append(net.ParseIP("8000::1:ffff"), byte(0)), false, ErrInvalidNetworkInput, "Invalid IP"}, + {netip.MustParseAddr("0.0.1.255"), true, nil, "IPv4 should contain"}, + {netip.MustParseAddr("0.0.0.255"), false, nil, "IPv4 shouldn't contain"}, + {netip.MustParseAddr("8000::ffff"), true, nil, "IPv6 shouldn't contain"}, + {netip.MustParseAddr("8000::1:ffff"), false, nil, "IPv6 shouldn't contain"}, } for _, tc := range cases { @@ -103,31 +85,30 @@ func TestContains(t *testing.T) { func TestContainingNetworks(t *testing.T) { r := newBruteRanger().(*bruteRanger) - _, network1, _ := net.ParseCIDR("0.0.1.0/24") - _, network2, _ := net.ParseCIDR("0.0.1.0/25") - _, network3, _ := net.ParseCIDR("8000::/112") - _, network4, _ := net.ParseCIDR("8000::/113") - entry1 := NewBasicRangerEntry(*network1) - entry2 := NewBasicRangerEntry(*network2) - entry3 := NewBasicRangerEntry(*network3) - entry4 := NewBasicRangerEntry(*network4) + network1 := netip.MustParsePrefix("0.0.1.0/24") + network2 := netip.MustParsePrefix("0.0.1.0/25") + network3 := netip.MustParsePrefix("8000::/112") + network4 := netip.MustParsePrefix("8000::/113") + entry1 := NewBasicRangerEntry(network1) + entry2 := NewBasicRangerEntry(network2) + entry3 := NewBasicRangerEntry(network3) + entry4 := NewBasicRangerEntry(network4) r.Insert(entry1) r.Insert(entry2) r.Insert(entry3) r.Insert(entry4) cases := []struct { - ip net.IP + ip netip.Addr containingNetworks []RangerEntry err error name string }{ - {net.ParseIP("0.0.1.255"), []RangerEntry{entry1}, nil, "IPv4 should contain"}, - {net.ParseIP("0.0.1.127"), []RangerEntry{entry1, entry2}, nil, "IPv4 should contain both"}, - {net.ParseIP("0.0.0.127"), []RangerEntry{}, nil, "IPv4 should contain none"}, - {net.ParseIP("8000::ffff"), []RangerEntry{entry3}, nil, "IPv6 should constain"}, - {net.ParseIP("8000::7fff"), []RangerEntry{entry3, entry4}, nil, "IPv6 should contain both"}, - {net.ParseIP("8000::1:7fff"), []RangerEntry{}, nil, "IPv6 should contain none"}, - {append(net.ParseIP("8000::1:7fff"), byte(0)), nil, ErrInvalidNetworkInput, "Invalid IP"}, + {netip.MustParseAddr("0.0.1.255"), []RangerEntry{entry1}, nil, "IPv4 should contain"}, + {netip.MustParseAddr("0.0.1.127"), []RangerEntry{entry1, entry2}, nil, "IPv4 should contain both"}, + {netip.MustParseAddr("0.0.0.127"), []RangerEntry{}, nil, "IPv4 should contain none"}, + {netip.MustParseAddr("8000::ffff"), []RangerEntry{entry3}, nil, "IPv6 should constain"}, + {netip.MustParseAddr("8000::7fff"), []RangerEntry{entry3, entry4}, nil, "IPv6 should contain both"}, + {netip.MustParseAddr("8000::1:7fff"), []RangerEntry{}, nil, "IPv6 should contain none"}, } for _, tc := range cases { @@ -151,17 +132,16 @@ func TestCoveredNetworks(t *testing.T) { t.Run(tc.name, func(t *testing.T) { ranger := newBruteRanger() for _, insert := range tc.inserts { - _, network, _ := net.ParseCIDR(insert) - err := ranger.Insert(NewBasicRangerEntry(*network)) + network := netip.MustParsePrefix(insert) + err := ranger.Insert(NewBasicRangerEntry(network)) assert.NoError(t, err) } + var expectedEntries []string - for _, network := range tc.networks { - expectedEntries = append(expectedEntries, network) - } + expectedEntries = append(expectedEntries, tc.networks...) sort.Strings(expectedEntries) - _, snet, _ := net.ParseCIDR(tc.search) - networks, err := ranger.CoveredNetworks(*snet) + snet := netip.MustParsePrefix(tc.search) + networks, err := ranger.CoveredNetworks(snet) assert.NoError(t, err) var results []string diff --git a/cidranger_test.go b/cidranger_test.go index c1c741e..a47b38d 100644 --- a/cidranger_test.go +++ b/cidranger_test.go @@ -2,9 +2,9 @@ package cidranger import ( "encoding/json" - "io/ioutil" "math/rand" - "net" + "net/netip" + "os" "testing" "time" @@ -58,10 +58,11 @@ func testContainsAgainstBase(t *testing.T, iterations int, ipGen ipGenerator) { for i := 0; i < iterations; i++ { nn := ipGen() - expected, err := baseRanger.Contains(nn.ToIP()) + ip := nn.ToIP() + expected, err := baseRanger.Contains(ip) assert.NoError(t, err) for _, ranger := range rangers { - actual, err := ranger.Contains(nn.ToIP()) + actual, err := ranger.Contains(ip) assert.NoError(t, err) assert.Equal(t, expected, actual) } @@ -127,59 +128,59 @@ func testCoversNetworksAgainstBase(t *testing.T, iterations int, netGen networkG */ func BenchmarkPCTrieHitIPv4UsingAWSRanges(b *testing.B) { - benchmarkContainsUsingAWSRanges(b, net.ParseIP("52.95.110.1"), NewPCTrieRanger()) + benchmarkContainsUsingAWSRanges(b, netip.MustParseAddr("52.95.110.1"), NewPCTrieRanger()) } func BenchmarkBruteRangerHitIPv4UsingAWSRanges(b *testing.B) { - benchmarkContainsUsingAWSRanges(b, net.ParseIP("52.95.110.1"), newBruteRanger()) + benchmarkContainsUsingAWSRanges(b, netip.MustParseAddr("52.95.110.1"), newBruteRanger()) } func BenchmarkPCTrieHitIPv6UsingAWSRanges(b *testing.B) { - benchmarkContainsUsingAWSRanges(b, net.ParseIP("2620:107:300f::36b7:ff81"), NewPCTrieRanger()) + benchmarkContainsUsingAWSRanges(b, netip.MustParseAddr("2620:107:300f::36b7:ff81"), NewPCTrieRanger()) } func BenchmarkBruteRangerHitIPv6UsingAWSRanges(b *testing.B) { - benchmarkContainsUsingAWSRanges(b, net.ParseIP("2620:107:300f::36b7:ff81"), newBruteRanger()) + benchmarkContainsUsingAWSRanges(b, netip.MustParseAddr("2620:107:300f::36b7:ff81"), newBruteRanger()) } func BenchmarkPCTrieMissIPv4UsingAWSRanges(b *testing.B) { - benchmarkContainsUsingAWSRanges(b, net.ParseIP("123.123.123.123"), NewPCTrieRanger()) + benchmarkContainsUsingAWSRanges(b, netip.MustParseAddr("123.123.123.123"), NewPCTrieRanger()) } func BenchmarkBruteRangerMissIPv4UsingAWSRanges(b *testing.B) { - benchmarkContainsUsingAWSRanges(b, net.ParseIP("123.123.123.123"), newBruteRanger()) + benchmarkContainsUsingAWSRanges(b, netip.MustParseAddr("123.123.123.123"), newBruteRanger()) } func BenchmarkPCTrieHMissIPv6UsingAWSRanges(b *testing.B) { - benchmarkContainsUsingAWSRanges(b, net.ParseIP("2620::ffff"), NewPCTrieRanger()) + benchmarkContainsUsingAWSRanges(b, netip.MustParseAddr("2620::ffff"), NewPCTrieRanger()) } func BenchmarkBruteRangerMissIPv6UsingAWSRanges(b *testing.B) { - benchmarkContainsUsingAWSRanges(b, net.ParseIP("2620::ffff"), newBruteRanger()) + benchmarkContainsUsingAWSRanges(b, netip.MustParseAddr("2620::ffff"), newBruteRanger()) } func BenchmarkPCTrieHitContainingNetworksIPv4UsingAWSRanges(b *testing.B) { - benchmarkContainingNetworksUsingAWSRanges(b, net.ParseIP("52.95.110.1"), NewPCTrieRanger()) + benchmarkContainingNetworksUsingAWSRanges(b, netip.MustParseAddr("52.95.110.1"), NewPCTrieRanger()) } func BenchmarkBruteRangerHitContainingNetworksIPv4UsingAWSRanges(b *testing.B) { - benchmarkContainingNetworksUsingAWSRanges(b, net.ParseIP("52.95.110.1"), newBruteRanger()) + benchmarkContainingNetworksUsingAWSRanges(b, netip.MustParseAddr("52.95.110.1"), newBruteRanger()) } func BenchmarkPCTrieHitContainingNetworksIPv6UsingAWSRanges(b *testing.B) { - benchmarkContainingNetworksUsingAWSRanges(b, net.ParseIP("2620:107:300f::36b7:ff81"), NewPCTrieRanger()) + benchmarkContainingNetworksUsingAWSRanges(b, netip.MustParseAddr("2620:107:300f::36b7:ff81"), NewPCTrieRanger()) } func BenchmarkBruteRangerHitContainingNetworksIPv6UsingAWSRanges(b *testing.B) { - benchmarkContainingNetworksUsingAWSRanges(b, net.ParseIP("2620:107:300f::36b7:ff81"), newBruteRanger()) + benchmarkContainingNetworksUsingAWSRanges(b, netip.MustParseAddr("2620:107:300f::36b7:ff81"), newBruteRanger()) } func BenchmarkPCTrieMissContainingNetworksIPv4UsingAWSRanges(b *testing.B) { - benchmarkContainingNetworksUsingAWSRanges(b, net.ParseIP("123.123.123.123"), NewPCTrieRanger()) + benchmarkContainingNetworksUsingAWSRanges(b, netip.MustParseAddr("123.123.123.123"), NewPCTrieRanger()) } func BenchmarkBruteRangerMissContainingNetworksIPv4UsingAWSRanges(b *testing.B) { - benchmarkContainingNetworksUsingAWSRanges(b, net.ParseIP("123.123.123.123"), newBruteRanger()) + benchmarkContainingNetworksUsingAWSRanges(b, netip.MustParseAddr("123.123.123.123"), newBruteRanger()) } func BenchmarkPCTrieHMissContainingNetworksIPv6UsingAWSRanges(b *testing.B) { - benchmarkContainingNetworksUsingAWSRanges(b, net.ParseIP("2620::ffff"), NewPCTrieRanger()) + benchmarkContainingNetworksUsingAWSRanges(b, netip.MustParseAddr("2620::ffff"), NewPCTrieRanger()) } func BenchmarkBruteRangerMissContainingNetworksIPv6UsingAWSRanges(b *testing.B) { - benchmarkContainingNetworksUsingAWSRanges(b, net.ParseIP("2620::ffff"), newBruteRanger()) + benchmarkContainingNetworksUsingAWSRanges(b, netip.MustParseAddr("2620::ffff"), newBruteRanger()) } func BenchmarkNewPathprefixTriev4(b *testing.B) { @@ -190,14 +191,14 @@ func BenchmarkNewPathprefixTriev6(b *testing.B) { benchmarkNewPathprefixTrie(b, "8000::/24") } -func benchmarkContainsUsingAWSRanges(tb testing.TB, nn net.IP, ranger Ranger) { +func benchmarkContainsUsingAWSRanges(tb testing.TB, nn netip.Addr, ranger Ranger) { configureRangerWithAWSRanges(tb, ranger) for n := 0; n < tb.(*testing.B).N; n++ { ranger.Contains(nn) } } -func benchmarkContainingNetworksUsingAWSRanges(tb testing.TB, nn net.IP, ranger Ranger) { +func benchmarkContainingNetworksUsingAWSRanges(tb testing.TB, nn netip.Addr, ranger Ranger) { configureRangerWithAWSRanges(tb, ranger) for n := 0; n < tb.(*testing.B).N; n++ { ranger.ContainingNetworks(nn) @@ -205,10 +206,10 @@ func benchmarkContainingNetworksUsingAWSRanges(tb testing.TB, nn net.IP, ranger } func benchmarkNewPathprefixTrie(b *testing.B, net1 string) { - _, ipNet1, _ := net.ParseCIDR(net1) - ones, _ := ipNet1.Mask.Size() + ipNet1 := netip.MustParsePrefix(net1) + ones := ipNet1.Bits() - n1 := rnet.NewNetwork(*ipNet1) + n1 := rnet.NewNetwork(ipNet1) uOnes := uint(ones) b.ResetTimer() @@ -228,16 +229,20 @@ type ipGenerator func() rnet.NetworkNumber func randIPv4Gen() rnet.NetworkNumber { return rnet.NetworkNumber{rand.Uint32()} } -func randIPv6Gen() rnet.NetworkNumber { - return rnet.NetworkNumber{rand.Uint32(), rand.Uint32(), rand.Uint32(), rand.Uint32()} -} + func curatedAWSIPv6Gen() rnet.NetworkNumber { randIdx := rand.Intn(len(ipV6AWSRangesIPNets)) // Randomly generate an IP somewhat near the range. network := ipV6AWSRangesIPNets[randIdx] - nn := rnet.NewNetworkNumber(network.IP) - ones, bits := network.Mask.Size() + nn := rnet.NewNetworkNumber(network.Addr()) + + bits := 32 + addr := network.Addr() + if addr.Is6() { + bits = 128 + } + ones := network.Bits() zeros := bits - ones nnPartIdx := zeros / rnet.BitsPerUint32 nn[nnPartIdx] = rand.Uint32() @@ -246,9 +251,9 @@ func curatedAWSIPv6Gen() rnet.NetworkNumber { type networkGenerator func() rnet.Network -func randomIPNetGenFactory(pool []*net.IPNet) networkGenerator { +func randomIPNetGenFactory(pool []netip.Prefix) networkGenerator { return func() rnet.Network { - return rnet.NewNetwork(*pool[rand.Intn(len(pool))]) + return rnet.NewNetwork(pool[rand.Intn(len(pool))]) } } @@ -270,11 +275,11 @@ type IPv6Prefix struct { } var awsRanges *AWSRanges -var ipV4AWSRangesIPNets []*net.IPNet -var ipV6AWSRangesIPNets []*net.IPNet +var ipV4AWSRangesIPNets []netip.Prefix +var ipV6AWSRangesIPNets []netip.Prefix func loadAWSRanges() *AWSRanges { - file, err := ioutil.ReadFile("./testdata/aws_ip_ranges.json") + file, err := os.ReadFile("./testdata/aws_ip_ranges.json") if err != nil { panic(err) } @@ -288,25 +293,23 @@ func loadAWSRanges() *AWSRanges { func configureRangerWithAWSRanges(tb testing.TB, ranger Ranger) { for _, prefix := range awsRanges.Prefixes { - _, network, err := net.ParseCIDR(prefix.IPPrefix) - assert.NoError(tb, err) - ranger.Insert(NewBasicRangerEntry(*network)) + network := netip.MustParsePrefix(prefix.IPPrefix) + ranger.Insert(NewBasicRangerEntry(network)) } for _, prefix := range awsRanges.IPv6Prefixes { - _, network, err := net.ParseCIDR(prefix.IPPrefix) - assert.NoError(tb, err) - ranger.Insert(NewBasicRangerEntry(*network)) + network := netip.MustParsePrefix(prefix.IPPrefix) + ranger.Insert(NewBasicRangerEntry(network)) } } func init() { awsRanges = loadAWSRanges() for _, prefix := range awsRanges.IPv6Prefixes { - _, network, _ := net.ParseCIDR(prefix.IPPrefix) + network := netip.MustParsePrefix(prefix.IPPrefix) ipV6AWSRangesIPNets = append(ipV6AWSRangesIPNets, network) } for _, prefix := range awsRanges.Prefixes { - _, network, _ := net.ParseCIDR(prefix.IPPrefix) + network := netip.MustParsePrefix(prefix.IPPrefix) ipV4AWSRangesIPNets = append(ipV4AWSRangesIPNets, network) } rand.Seed(time.Now().Unix()) diff --git a/example/custom-ranger-asn.go b/example/custom-ranger-asn.go index 7b5f858..46eab50 100644 --- a/example/custom-ranger-asn.go +++ b/example/custom-ranger-asn.go @@ -9,7 +9,7 @@ package main import ( "fmt" - "net" + "net/netip" "os" "github.com/yl2chen/cidranger" @@ -17,12 +17,12 @@ import ( // custom structure that conforms to RangerEntry interface type customRangerEntry struct { - ipNet net.IPNet + ipNet netip.Prefix asn string } // get function for network -func (b *customRangerEntry) Network() net.IPNet { +func (b *customRangerEntry) Network() netip.Prefix { return b.ipNet } @@ -37,7 +37,7 @@ func (b *customRangerEntry) Asn() string { } // create customRangerEntry object using net and asn -func newCustomRangerEntry(ipNet net.IPNet, asn string) cidranger.RangerEntry { +func newCustomRangerEntry(ipNet netip.Prefix, asn string) cidranger.RangerEntry { return &customRangerEntry{ ipNet: ipNet, asn: asn, @@ -51,14 +51,14 @@ func main() { ranger := cidranger.NewPCTrieRanger() // Load sample data using our custom function - _, network, _ := net.ParseCIDR("192.168.1.0/24") - ranger.Insert(newCustomRangerEntry(*network, "0001")) + network := netip.MustParsePrefix("192.168.1.0/24") + ranger.Insert(newCustomRangerEntry(network, "0001")) - _, network, _ = net.ParseCIDR("128.168.1.0/24") - ranger.Insert(newCustomRangerEntry(*network, "0002")) + network = netip.MustParsePrefix("128.168.1.0/24") + ranger.Insert(newCustomRangerEntry(network, "0002")) // Check if IP is contained within ranger - contains, err := ranger.Contains(net.ParseIP("128.168.1.7")) + contains, err := ranger.Contains(netip.MustParseAddr("128.168.1.7")) if err != nil { fmt.Println("ranger.Contains()", err.Error()) os.Exit(1) @@ -67,7 +67,7 @@ func main() { // request networks containing this IP ip := "192.168.1.42" - entries, err := ranger.ContainingNetworks(net.ParseIP(ip)) + entries, err := ranger.ContainingNetworks(netip.MustParseAddr(ip)) if err != nil { fmt.Println("ranger.ContainingNetworks()", err.Error()) os.Exit(1) diff --git a/net/ip_test.go b/net/ip_test.go index 1e915df..a1eb636 100644 --- a/net/ip_test.go +++ b/net/ip_test.go @@ -2,7 +2,7 @@ package net import ( "math" - "net" + "net/netip" "testing" "github.com/stretchr/testify/assert" @@ -10,16 +10,14 @@ import ( func TestNewNetworkNumber(t *testing.T) { cases := []struct { - ip net.IP + ip netip.Addr nn NetworkNumber name string }{ - {nil, nil, "nil input"}, - {net.IP([]byte{1, 1, 1, 1, 1}), nil, "bad input"}, - {net.ParseIP("128.0.0.0"), NetworkNumber([]uint32{2147483648}), "IPv4"}, + {netip.MustParseAddr("128.0.0.0"), NetworkNumber([]uint32{0x80_00_00_00}), "IPv4"}, { - net.ParseIP("2001:0db8::ff00:0042:8329"), - NetworkNumber([]uint32{536939960, 0, 65280, 4358953}), + netip.MustParseAddr("2001:0db8::ff00:0042:8329"), + NetworkNumber([]uint32{0x2001_0db8, 0x0000_0000, 0x0000_ff00, 0x0042_8329}), "IPv6", }, } @@ -55,10 +53,10 @@ func TestNetworkNumberBit(t *testing.T) { ones map[uint]bool name string }{ - {NewNetworkNumber(net.ParseIP("128.0.0.0")), map[uint]bool{31: true}, "128.0.0.0"}, - {NewNetworkNumber(net.ParseIP("1.1.1.1")), map[uint]bool{0: true, 8: true, 16: true, 24: true}, "1.1.1.1"}, - {NewNetworkNumber(net.ParseIP("8000::")), map[uint]bool{127: true}, "8000::"}, - {NewNetworkNumber(net.ParseIP("8000::8000")), map[uint]bool{127: true, 15: true}, "8000::8000"}, + {NewNetworkNumber(netip.MustParseAddr("128.0.0.0")), map[uint]bool{31: true}, "128.0.0.0"}, + {NewNetworkNumber(netip.MustParseAddr("1.1.1.1")), map[uint]bool{0: true, 8: true, 16: true, 24: true}, "1.1.1.1"}, + {NewNetworkNumber(netip.MustParseAddr("8000::")), map[uint]bool{127: true}, "8000::"}, + {NewNetworkNumber(netip.MustParseAddr("8000::8000")), map[uint]bool{127: true, 15: true}, "8000::8000"}, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { @@ -82,12 +80,12 @@ func TestNetworkNumberBitError(t *testing.T) { err error name string }{ - {NewNetworkNumber(net.ParseIP("128.0.0.0")), 0, nil, "IPv4 index in bound"}, - {NewNetworkNumber(net.ParseIP("128.0.0.0")), 31, nil, "IPv4 index in bound"}, - {NewNetworkNumber(net.ParseIP("128.0.0.0")), 32, ErrInvalidBitPosition, "IPv4 index out of bounds"}, - {NewNetworkNumber(net.ParseIP("8000::")), 0, nil, "IPv6 index in bound"}, - {NewNetworkNumber(net.ParseIP("8000::")), 127, nil, "IPv6 index in bound"}, - {NewNetworkNumber(net.ParseIP("8000::")), 128, ErrInvalidBitPosition, "IPv6 index out of bounds"}, + {NewNetworkNumber(netip.MustParseAddr("128.0.0.0")), 0, nil, "IPv4 index in bound"}, + {NewNetworkNumber(netip.MustParseAddr("128.0.0.0")), 31, nil, "IPv4 index in bound"}, + {NewNetworkNumber(netip.MustParseAddr("128.0.0.0")), 32, ErrInvalidBitPosition, "IPv4 index out of bounds"}, + {NewNetworkNumber(netip.MustParseAddr("8000::")), 0, nil, "IPv6 index in bound"}, + {NewNetworkNumber(netip.MustParseAddr("8000::")), 127, nil, "IPv6 index in bound"}, + {NewNetworkNumber(netip.MustParseAddr("8000::")), 128, ErrInvalidBitPosition, "IPv6 index out of bounds"}, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { @@ -133,8 +131,8 @@ func TestNetworkNumberNext(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - ip := NewNetworkNumber(net.ParseIP(tc.ip)) - expected := NewNetworkNumber(net.ParseIP(tc.next)) + ip := NewNetworkNumber(netip.MustParseAddr(tc.ip)) + expected := NewNetworkNumber(netip.MustParseAddr(tc.next)) assert.Equal(t, expected, ip.Next()) }) } @@ -156,8 +154,8 @@ func TestNeworkNumberPrevious(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - ip := NewNetworkNumber(net.ParseIP(tc.ip)) - expected := NewNetworkNumber(net.ParseIP(tc.previous)) + ip := NewNetworkNumber(netip.MustParseAddr(tc.ip)) + expected := NewNetworkNumber(netip.MustParseAddr(tc.previous)) assert.Equal(t, expected, ip.Previous()) }) } @@ -217,10 +215,10 @@ func TestLeastCommonBitPositionForNetworks(t *testing.T) { } func TestNewNetwork(t *testing.T) { - _, ipNet, _ := net.ParseCIDR("192.128.0.0/24") - n := NewNetwork(*ipNet) + ipNet := netip.MustParsePrefix("192.128.0.0/24") + n := NewNetwork(ipNet) - assert.Equal(t, *ipNet, n.IPNet) + assert.Equal(t, ipNet, n.IPNet) assert.Equal(t, NetworkNumber{3229614080}, n.Number) assert.Equal(t, NetworkNumberMask{math.MaxUint32 - uint32(math.MaxUint8)}, n.Mask) } @@ -241,10 +239,10 @@ func TestNetworkMasked(t *testing.T) { {"8000:ffff::/96", 16, "8000::/16"}, } for _, testcase := range cases { - _, network, _ := net.ParseCIDR(testcase.network) - _, expected, _ := net.ParseCIDR(testcase.maskedNetwork) - n1 := NewNetwork(*network) - e1 := NewNetwork(*expected) + network := netip.MustParsePrefix(testcase.network) + expected := netip.MustParsePrefix(testcase.maskedNetwork) + n1 := NewNetwork(network) + e1 := NewNetwork(expected) assert.True(t, e1.String() == n1.Masked(testcase.mask).String()) } } @@ -263,9 +261,9 @@ func TestNetworkEqual(t *testing.T) { } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - _, ipNet1, _ := net.ParseCIDR(tc.n1) - _, ipNet2, _ := net.ParseCIDR(tc.n2) - assert.Equal(t, tc.equal, NewNetwork(*ipNet1).Equal(NewNetwork(*ipNet2))) + ipNet1 := netip.MustParsePrefix(tc.n1) + ipNet2 := netip.MustParsePrefix(tc.n2) + assert.Equal(t, tc.equal, NewNetwork(ipNet1).Equal(NewNetwork(ipNet2))) }) } } @@ -282,10 +280,10 @@ func TestNetworkContains(t *testing.T) { } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - _, net1, _ := net.ParseCIDR(tc.network) - network := NewNetwork(*net1) - ip := NewNetworkNumber(net.ParseIP(tc.firstIP)) - lastIP := NewNetworkNumber(net.ParseIP(tc.lastIP)) + net1 := netip.MustParsePrefix(tc.network) + network := NewNetwork(net1) + ip := NewNetworkNumber(netip.MustParseAddr(tc.firstIP)) + lastIP := NewNetworkNumber(netip.MustParseAddr(tc.lastIP)) assert.False(t, network.Contains(ip.Previous())) assert.False(t, network.Contains(lastIP.Next())) for ; !ip.Equal(lastIP.Next()); ip = ip.Next() { @@ -306,9 +304,9 @@ func TestNetworkContainsVersionMismatch(t *testing.T) { } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - _, net1, _ := net.ParseCIDR(tc.network) - network := NewNetwork(*net1) - assert.False(t, network.Contains(NewNetworkNumber(net.ParseIP(tc.ip)))) + net1 := netip.MustParsePrefix(tc.network) + network := NewNetwork(net1) + assert.False(t, network.Contains(NewNetworkNumber(netip.MustParseAddr(tc.ip)))) }) } } @@ -331,10 +329,10 @@ func TestNetworkCovers(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - _, n, _ := net.ParseCIDR(tc.network) - network := NewNetwork(*n) - _, n, _ = net.ParseCIDR(tc.covers) - covers := NewNetwork(*n) + n := netip.MustParsePrefix(tc.network) + network := NewNetwork(n) + n = netip.MustParsePrefix(tc.covers) + covers := NewNetwork(n) assert.Equal(t, tc.result, network.Covers(covers)) }) } @@ -358,12 +356,10 @@ func TestNetworkLeastCommonBitPosition(t *testing.T) { {"ffff::0/24", "0::1/24", 0, ErrNoGreatestCommonBit, "IPv6 diverge at 1st pos"}, } for _, c := range cases { - _, cidr1, err := net.ParseCIDR(c.cidr1) - assert.NoError(t, err) - _, cidr2, err := net.ParseCIDR(c.cidr2) - assert.NoError(t, err) - n1 := NewNetwork(*cidr1) - pos, err := n1.LeastCommonBitPosition(NewNetwork(*cidr2)) + cidr1 := netip.MustParsePrefix(c.cidr1) + cidr2 := netip.MustParsePrefix(c.cidr2) + n1 := NewNetwork(cidr1) + pos, err := n1.LeastCommonBitPosition(NewNetwork(cidr2)) if c.expectedErr != nil { assert.Equal(t, c.expectedErr, err) } else { @@ -413,7 +409,7 @@ func TestNextIP(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - assert.Equal(t, net.ParseIP(tc.next), NextIP(net.ParseIP(tc.ip))) + assert.Equal(t, netip.MustParseAddr(tc.next), NextIP(netip.MustParseAddr(tc.ip))) }) } } @@ -434,15 +430,15 @@ func TestPreviousIP(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - assert.Equal(t, net.ParseIP(tc.next), PreviousIP(net.ParseIP(tc.ip))) + assert.Equal(t, netip.MustParseAddr(tc.next), PreviousIP(netip.MustParseAddr(tc.ip))) }) } } /* - ********************************* - Benchmarking ip manipulations. - ********************************* +********************************* +Benchmarking ip manipulations. +********************************* */ func BenchmarkNetworkNumberBitIPv4(b *testing.B) { benchmarkNetworkNumberBit(b, "52.95.110.1", 6) @@ -476,34 +472,34 @@ func BenchmarkNetworkEqualIPv6(b *testing.B) { } func benchmarkNetworkNumberBit(b *testing.B, ip string, pos uint) { - nn := NewNetworkNumber(net.ParseIP(ip)) + nn := NewNetworkNumber(netip.MustParseAddr(ip)) for n := 0; n < b.N; n++ { nn.Bit(pos) } } func benchmarkNetworkNumberEqual(b *testing.B, ip1 string, ip2 string) { - nn1 := NewNetworkNumber(net.ParseIP(ip1)) - nn2 := NewNetworkNumber(net.ParseIP(ip2)) + nn1 := NewNetworkNumber(netip.MustParseAddr(ip1)) + nn2 := NewNetworkNumber(netip.MustParseAddr(ip2)) for n := 0; n < b.N; n++ { nn1.Equal(nn2) } } func benchmarkNetworkContains(b *testing.B, cidr string, ip string) { - nn := NewNetworkNumber(net.ParseIP(ip)) - _, ipNet, _ := net.ParseCIDR(cidr) - network := NewNetwork(*ipNet) + nn := NewNetworkNumber(netip.MustParseAddr(ip)) + ipNet := netip.MustParsePrefix(cidr) + network := NewNetwork(ipNet) for n := 0; n < b.N; n++ { network.Contains(nn) } } func benchmarkNetworkEqual(b *testing.B, net1 string, net2 string) { - _, ipNet1, _ := net.ParseCIDR(net1) - _, ipNet2, _ := net.ParseCIDR(net2) - n1 := NewNetwork(*ipNet1) - n2 := NewNetwork(*ipNet2) + ipNet1 := netip.MustParsePrefix(net1) + ipNet2 := netip.MustParsePrefix(net2) + n1 := NewNetwork(ipNet1) + n2 := NewNetwork(ipNet2) for n := 0; n < b.N; n++ { n1.Equal(n2) } diff --git a/trie_test.go b/trie_test.go index 04f2900..10e5807 100644 --- a/trie_test.go +++ b/trie_test.go @@ -2,8 +2,9 @@ package cidranger import ( "encoding/binary" + "math" "math/rand" - "net" + "net/netip" "runtime" "testing" "time" @@ -12,7 +13,7 @@ import ( rnet "github.com/yl2chen/cidranger/net" ) -func getAllByVersion(version rnet.IPVersion) *net.IPNet { +func getAllByVersion(version rnet.IPVersion) netip.Prefix { if version == rnet.IPv6 { return AllIPv6 } @@ -74,21 +75,21 @@ func TestPrefixTrieInsert(t *testing.T) { t.Run(tc.name, func(t *testing.T) { trie := newPrefixTree(tc.version).(*prefixTrie) for _, insert := range tc.inserts { - _, network, _ := net.ParseCIDR(insert) - err := trie.Insert(NewBasicRangerEntry(*network)) + network := netip.MustParsePrefix(insert) + err := trie.Insert(NewBasicRangerEntry(network)) assert.NoError(t, err) } assert.Equal(t, len(tc.expectedNetworksInDepthOrder), trie.Len(), "trie size should match") - allNetworks, err := trie.CoveredNetworks(*getAllByVersion(tc.version)) + allNetworks, err := trie.CoveredNetworks(getAllByVersion(tc.version)) assert.Nil(t, err) assert.Equal(t, len(allNetworks), trie.Len(), "trie size should match") walk := trie.walkDepth() for _, network := range tc.expectedNetworksInDepthOrder { - _, ipnet, _ := net.ParseCIDR(network) - expected := NewBasicRangerEntry(*ipnet) + ipnet := netip.MustParsePrefix(network) + expected := NewBasicRangerEntry(ipnet) actual := <-walk assert.Equal(t, expected, actual) } @@ -105,8 +106,8 @@ func TestPrefixTrieString(t *testing.T) { inserts := []string{"192.168.0.1/24", "192.168.1.1/24", "192.168.1.1/30"} trie := newPrefixTree(rnet.IPv4).(*prefixTrie) for _, insert := range inserts { - _, network, _ := net.ParseCIDR(insert) - trie.Insert(NewBasicRangerEntry(*network)) + network := netip.MustParsePrefix(insert) + trie.Insert(NewBasicRangerEntry(network)) } expected := `0.0.0.0/0 (target_pos:31:has_entry:false) | 1--> 192.168.0.0/23 (target_pos:8:has_entry:false) @@ -206,17 +207,17 @@ func TestPrefixTrieRemove(t *testing.T) { t.Run(tc.name, func(t *testing.T) { trie := newPrefixTree(tc.version).(*prefixTrie) for _, insert := range tc.inserts { - _, network, _ := net.ParseCIDR(insert) - err := trie.Insert(NewBasicRangerEntry(*network)) + network := netip.MustParsePrefix(insert) + err := trie.Insert(NewBasicRangerEntry(network)) assert.NoError(t, err) } for i, remove := range tc.removes { - _, network, _ := net.ParseCIDR(remove) - removed, err := trie.Remove(*network) + network := netip.MustParsePrefix(remove) + removed, err := trie.Remove(network) assert.NoError(t, err) if str := tc.expectedRemoves[i]; str != "" { - _, ipnet, _ := net.ParseCIDR(str) - expected := NewBasicRangerEntry(*ipnet) + ipnet := netip.MustParsePrefix(str) + expected := NewBasicRangerEntry(ipnet) assert.Equal(t, expected, removed) } else { assert.Nil(t, removed) @@ -225,14 +226,14 @@ func TestPrefixTrieRemove(t *testing.T) { assert.Equal(t, len(tc.expectedNetworksInDepthOrder), trie.Len(), "trie size should match after revmoval") - allNetworks, err := trie.CoveredNetworks(*getAllByVersion(tc.version)) + allNetworks, err := trie.CoveredNetworks(getAllByVersion(tc.version)) assert.Nil(t, err) assert.Equal(t, len(allNetworks), trie.Len(), "trie size should match") walk := trie.walkDepth() for _, network := range tc.expectedNetworksInDepthOrder { - _, ipnet, _ := net.ParseCIDR(network) - expected := NewBasicRangerEntry(*ipnet) + ipnet := netip.MustParsePrefix(network) + expected := NewBasicRangerEntry(ipnet) actual := <-walk assert.Equal(t, expected, actual) } @@ -251,21 +252,21 @@ func TestToReplicateIssue(t *testing.T) { cases := []struct { version rnet.IPVersion inserts []string - ip net.IP + ip netip.Addr networks []string name string }{ { rnet.IPv4, []string{"192.168.0.1/32"}, - net.ParseIP("192.168.0.1"), + netip.MustParseAddr("192.168.0.1"), []string{"192.168.0.1/32"}, "basic containing network for /32 mask", }, { rnet.IPv6, []string{"a::1/128"}, - net.ParseIP("a::1"), + netip.MustParseAddr("a::1"), []string{"a::1/128"}, "basic containing network for /128 mask", }, @@ -274,14 +275,14 @@ func TestToReplicateIssue(t *testing.T) { t.Run(tc.name, func(t *testing.T) { trie := newPrefixTree(tc.version) for _, insert := range tc.inserts { - _, network, _ := net.ParseCIDR(insert) - err := trie.Insert(NewBasicRangerEntry(*network)) + network := netip.MustParsePrefix(insert) + err := trie.Insert(NewBasicRangerEntry(network)) assert.NoError(t, err) } expectedEntries := []RangerEntry{} for _, network := range tc.networks { - _, net, _ := net.ParseCIDR(network) - expectedEntries = append(expectedEntries, NewBasicRangerEntry(*net)) + net := netip.MustParsePrefix(network) + expectedEntries = append(expectedEntries, NewBasicRangerEntry(net)) } contains, err := trie.Contains(tc.ip) assert.NoError(t, err) @@ -294,8 +295,8 @@ func TestToReplicateIssue(t *testing.T) { } type expectedIPRange struct { - start net.IP - end net.IP + start netip.Addr + end netip.Addr } func TestPrefixTrieContains(t *testing.T) { @@ -309,7 +310,7 @@ func TestPrefixTrieContains(t *testing.T) { rnet.IPv4, []string{"192.168.0.0/24"}, []expectedIPRange{ - {net.ParseIP("192.168.0.0"), net.ParseIP("192.168.1.0")}, + {netip.MustParseAddr("192.168.0.0"), netip.MustParseAddr("192.168.1.0")}, }, "basic contains", }, @@ -317,8 +318,8 @@ func TestPrefixTrieContains(t *testing.T) { rnet.IPv4, []string{"192.168.0.0/24", "128.168.0.0/24"}, []expectedIPRange{ - {net.ParseIP("192.168.0.0"), net.ParseIP("192.168.1.0")}, - {net.ParseIP("128.168.0.0"), net.ParseIP("128.168.1.0")}, + {netip.MustParseAddr("192.168.0.0"), netip.MustParseAddr("192.168.1.0")}, + {netip.MustParseAddr("128.168.0.0"), netip.MustParseAddr("128.168.1.0")}, }, "multiple ranges contains", }, @@ -328,15 +329,14 @@ func TestPrefixTrieContains(t *testing.T) { t.Run(tc.name, func(t *testing.T) { trie := newPrefixTree(tc.version) for _, insert := range tc.inserts { - _, network, _ := net.ParseCIDR(insert) - err := trie.Insert(NewBasicRangerEntry(*network)) + network := netip.MustParsePrefix(insert) + err := trie.Insert(NewBasicRangerEntry(network)) assert.NoError(t, err) } for _, expectedIPRange := range tc.expectedIPs { var contains bool var err error - start := expectedIPRange.start - for ; !expectedIPRange.end.Equal(start); start = rnet.NextIP(start) { + for start := expectedIPRange.start; expectedIPRange.end != start; start = rnet.NextIP(start) { contains, err = trie.Contains(start) assert.NoError(t, err) assert.True(t, contains) @@ -358,21 +358,21 @@ func TestPrefixTrieContainingNetworks(t *testing.T) { cases := []struct { version rnet.IPVersion inserts []string - ip net.IP + ip netip.Addr networks []string name string }{ { rnet.IPv4, []string{"192.168.0.0/24"}, - net.ParseIP("192.168.0.1"), + netip.MustParseAddr("192.168.0.1"), []string{"192.168.0.0/24"}, "basic containing networks", }, { rnet.IPv4, []string{"192.168.0.0/24", "192.168.0.0/25"}, - net.ParseIP("192.168.0.1"), + netip.MustParseAddr("192.168.0.1"), []string{"192.168.0.0/24", "192.168.0.0/25"}, "inclusive networks", }, @@ -381,14 +381,14 @@ func TestPrefixTrieContainingNetworks(t *testing.T) { t.Run(tc.name, func(t *testing.T) { trie := newPrefixTree(tc.version) for _, insert := range tc.inserts { - _, network, _ := net.ParseCIDR(insert) - err := trie.Insert(NewBasicRangerEntry(*network)) + network := netip.MustParsePrefix(insert) + err := trie.Insert(NewBasicRangerEntry(network)) assert.NoError(t, err) } expectedEntries := []RangerEntry{} for _, network := range tc.networks { - _, net, _ := net.ParseCIDR(network) - expectedEntries = append(expectedEntries, NewBasicRangerEntry(*net)) + net := netip.MustParsePrefix(network) + expectedEntries = append(expectedEntries, NewBasicRangerEntry(net)) } networks, err := trie.ContainingNetworks(tc.ip) assert.NoError(t, err) @@ -474,18 +474,18 @@ func TestPrefixTrieCoveredNetworks(t *testing.T) { t.Run(tc.name, func(t *testing.T) { trie := newPrefixTree(tc.version) for _, insert := range tc.inserts { - _, network, _ := net.ParseCIDR(insert) - err := trie.Insert(NewBasicRangerEntry(*network)) + network := netip.MustParsePrefix(insert) + err := trie.Insert(NewBasicRangerEntry(network)) assert.NoError(t, err) } var expectedEntries []RangerEntry for _, network := range tc.networks { - _, net, _ := net.ParseCIDR(network) + net := netip.MustParsePrefix(network) expectedEntries = append(expectedEntries, - NewBasicRangerEntry(*net)) + NewBasicRangerEntry(net)) } - _, snet, _ := net.ParseCIDR(tc.search) - networks, err := trie.CoveredNetworks(*snet) + snet := netip.MustParsePrefix(tc.search) + networks, err := trie.CoveredNetworks(snet) assert.NoError(t, err) assert.Equal(t, expectedEntries, networks) }) @@ -517,13 +517,13 @@ func TestTrieMemUsage(t *testing.T) { assert.Less(t, 0, trie.Len(), "Len should > 0") assert.LessOrEqualf(t, trie.Len(), numIPs, "Len should <= %d", numIPs) - allNetworks, err := trie.CoveredNetworks(*getAllByVersion(rnet.IPv4)) + allNetworks, err := trie.CoveredNetworks(getAllByVersion(rnet.IPv4)) assert.Nil(t, err) assert.Equal(t, len(allNetworks), trie.Len(), "trie size should match") // Remove networks. - _, all, _ := net.ParseCIDR("0.0.0.0/0") - ll, _ := trie.CoveredNetworks(*all) + all := netip.MustParsePrefix("0.0.0.0/0") + ll, _ := trie.CoveredNetworks(all) for i := 0; i < len(ll); i++ { trie.Remove(ll[i].Network()) } @@ -546,22 +546,20 @@ func TestTrieMemUsage(t *testing.T) { assert.LessOrEqual(t, float64(baseLineHeap), float64(totalHeapAllocOverRuns/uint64(runs))*thresh) } -func GenLeafIPNet(ip net.IP) net.IPNet { - return net.IPNet{ - IP: ip, - Mask: net.CIDRMask(32, 32), - } +func GenLeafIPNet(ip netip.Addr) netip.Prefix { + return netip.PrefixFrom(ip, 32) } // GenIPV4 generates an IPV4 address -func GenIPV4() net.IP { +func GenIPV4() netip.Addr { rand.Seed(time.Now().UnixNano()) nn := rand.Uint32() - if nn < 4294967295 { + if nn < math.MaxUint32 { nn++ } - ip := make(net.IP, 4) - binary.BigEndian.PutUint32(ip, uint32(nn)) + sl := make([]byte, 4) + binary.BigEndian.PutUint32(sl, uint32(nn)) + ip, _ := netip.AddrFromSlice(sl) return ip } From 93e64309d02687cc6f2fe58b2c2f1d9b6cbad6ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mikk=20Margus=20M=C3=B6ll?= Date: Sat, 27 Aug 2022 20:10:33 +0300 Subject: [PATCH 3/3] update module file --- go.mod | 9 ++++++--- go.sum | 4 ---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/go.mod b/go.mod index a35ea91..3f1987d 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,11 @@ module github.com/yl2chen/cidranger -go 1.13 +go 1.18 + +require github.com/stretchr/testify v1.6.1 require ( - github.com/stretchr/testify v1.6.1 - gopkg.in/yaml.v2 v2.2.2 // indirect + github.com/davecgh/go-spew v1.1.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect ) diff --git a/go.sum b/go.sum index d063842..afe7890 100644 --- a/go.sum +++ b/go.sum @@ -3,13 +3,9 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=