Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move from net to netip #46

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 17 additions & 17 deletions brute.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package cidranger

import (
"net"
"net/netip"

rnet "github.com/yl2chen/cidranger/net"
)
Expand All @@ -17,24 +17,24 @@ import (
// and used as the ground truth when running a wider range of 'random' tests on
// other more sophisticated implementations.
type bruteRanger struct {
ipV4Entries map[string]RangerEntry
ipV6Entries map[string]RangerEntry
ipV4Entries map[netip.Prefix]RangerEntry
ipV6Entries map[netip.Prefix]RangerEntry
}

// newBruteRanger returns a new Ranger.
func newBruteRanger() Ranger {
return &bruteRanger{
ipV4Entries: make(map[string]RangerEntry),
ipV6Entries: make(map[string]RangerEntry),
ipV4Entries: make(map[netip.Prefix]RangerEntry),
ipV6Entries: make(map[netip.Prefix]RangerEntry),
}
}

// Insert inserts a RangerEntry into ranger.
func (b *bruteRanger) Insert(entry RangerEntry) error {
network := entry.Network()
key := network.String()
key := network
if _, found := b.ipV4Entries[key]; !found {
entries, err := b.getEntriesByVersion(entry.Network().IP)
entries, err := b.getEntriesByVersion(entry.Network().Addr())
if err != nil {
return err
}
Expand All @@ -44,12 +44,12 @@ func (b *bruteRanger) Insert(entry RangerEntry) error {
}

// Remove removes a RangerEntry identified by given network from ranger.
func (b *bruteRanger) Remove(network net.IPNet) (RangerEntry, error) {
networks, err := b.getEntriesByVersion(network.IP)
func (b *bruteRanger) Remove(network netip.Prefix) (RangerEntry, error) {
networks, err := b.getEntriesByVersion(network.Addr())
if err != nil {
return nil, err
}
key := network.String()
key := network
if networkToDelete, found := networks[key]; found {
delete(networks, key)
return networkToDelete, nil
Expand All @@ -59,7 +59,7 @@ func (b *bruteRanger) Remove(network net.IPNet) (RangerEntry, error) {

// Contains returns bool indicating whether given ip is contained by any
// network in ranger.
func (b *bruteRanger) Contains(ip net.IP) (bool, error) {
func (b *bruteRanger) Contains(ip netip.Addr) (bool, error) {
entries, err := b.getEntriesByVersion(ip)
if err != nil {
return false, err
Expand All @@ -74,7 +74,7 @@ func (b *bruteRanger) Contains(ip net.IP) (bool, error) {
}

// ContainingNetworks returns all RangerEntry(s) that given ip contained in.
func (b *bruteRanger) ContainingNetworks(ip net.IP) ([]RangerEntry, error) {
func (b *bruteRanger) ContainingNetworks(ip netip.Addr) ([]RangerEntry, error) {
entries, err := b.getEntriesByVersion(ip)
if err != nil {
return nil, err
Expand All @@ -92,8 +92,8 @@ func (b *bruteRanger) ContainingNetworks(ip net.IP) ([]RangerEntry, error) {
// CoveredNetworks returns the list of RangerEntry(s) the given ipnet
// covers. That is, the networks that are completely subsumed by the
// specified network.
func (b *bruteRanger) CoveredNetworks(network net.IPNet) ([]RangerEntry, error) {
entries, err := b.getEntriesByVersion(network.IP)
func (b *bruteRanger) CoveredNetworks(network netip.Prefix) ([]RangerEntry, error) {
entries, err := b.getEntriesByVersion(network.Addr())
if err != nil {
return nil, err
}
Expand All @@ -113,11 +113,11 @@ func (b *bruteRanger) Len() int {
return len(b.ipV4Entries) + len(b.ipV6Entries)
}

func (b *bruteRanger) getEntriesByVersion(ip net.IP) (map[string]RangerEntry, error) {
if ip.To4() != nil {
func (b *bruteRanger) getEntriesByVersion(ip netip.Addr) (map[netip.Prefix]RangerEntry, error) {
if ip.Is4() {
return b.ipV4Entries, nil
}
if ip.To16() != nil {
if ip.Is6() {
return b.ipV6Entries, nil
}
return nil, ErrInvalidNetworkInput
Expand Down
110 changes: 45 additions & 65 deletions brute_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package cidranger

import (
"net"
"net/netip"
"sort"
"testing"

Expand All @@ -10,46 +10,38 @@ import (

func TestInsert(t *testing.T) {
ranger := newBruteRanger().(*bruteRanger)
_, networkIPv4, _ := net.ParseCIDR("0.0.1.0/24")
_, networkIPv6, _ := net.ParseCIDR("8000::/96")
entryIPv4 := NewBasicRangerEntry(*networkIPv4)
entryIPv6 := NewBasicRangerEntry(*networkIPv6)
networkIPv4 := netip.MustParsePrefix("0.0.1.0/24")
networkIPv6 := netip.MustParsePrefix("8000::/96")
entryIPv4 := NewBasicRangerEntry(networkIPv4)
entryIPv6 := NewBasicRangerEntry(networkIPv6)

ranger.Insert(entryIPv4)
ranger.Insert(entryIPv6)

assert.Equal(t, 1, len(ranger.ipV4Entries))
assert.Equal(t, entryIPv4, ranger.ipV4Entries["0.0.1.0/24"])
assert.Equal(t, entryIPv4, ranger.ipV4Entries[networkIPv4])
assert.Equal(t, 1, len(ranger.ipV6Entries))
assert.Equal(t, entryIPv6, ranger.ipV6Entries["8000::/96"])
}

func TestInsertError(t *testing.T) {
bRanger := newBruteRanger().(*bruteRanger)
_, networkIPv4, _ := net.ParseCIDR("0.0.1.0/24")
networkIPv4.IP = append(networkIPv4.IP, byte(4))
err := bRanger.Insert(NewBasicRangerEntry(*networkIPv4))
assert.Equal(t, ErrInvalidNetworkInput, err)
assert.Equal(t, entryIPv6, ranger.ipV6Entries[networkIPv6])
}

func TestRemove(t *testing.T) {
ranger := newBruteRanger().(*bruteRanger)
_, networkIPv4, _ := net.ParseCIDR("0.0.1.0/24")
_, networkIPv6, _ := net.ParseCIDR("8000::/96")
_, notInserted, _ := net.ParseCIDR("8000::/96")
networkIPv4 := netip.MustParsePrefix("0.0.1.0/24")
networkIPv6 := netip.MustParsePrefix("8000::/96")
notInserted := netip.MustParsePrefix("8000::/96")

insertIPv4 := NewBasicRangerEntry(*networkIPv4)
insertIPv6 := NewBasicRangerEntry(*networkIPv6)
insertIPv4 := NewBasicRangerEntry(networkIPv4)
insertIPv6 := NewBasicRangerEntry(networkIPv6)

ranger.Insert(insertIPv4)
deletedIPv4, err := ranger.Remove(*networkIPv4)
deletedIPv4, err := ranger.Remove(networkIPv4)
assert.NoError(t, err)

ranger.Insert(insertIPv6)
deletedIPv6, err := ranger.Remove(*networkIPv6)
deletedIPv6, err := ranger.Remove(networkIPv6)
assert.NoError(t, err)

entry, err := ranger.Remove(*notInserted)
entry, err := ranger.Remove(notInserted)
assert.NoError(t, err)
assert.Nil(t, entry)

Expand All @@ -59,33 +51,23 @@ func TestRemove(t *testing.T) {
assert.Equal(t, 0, len(ranger.ipV6Entries))
}

func TestRemoveError(t *testing.T) {
r := newBruteRanger().(*bruteRanger)
_, invalidNetwork, _ := net.ParseCIDR("0.0.1.0/24")
invalidNetwork.IP = append(invalidNetwork.IP, byte(4))

_, err := r.Remove(*invalidNetwork)
assert.Equal(t, ErrInvalidNetworkInput, err)
}

func TestContains(t *testing.T) {
r := newBruteRanger().(*bruteRanger)
_, network, _ := net.ParseCIDR("0.0.1.0/24")
_, network1, _ := net.ParseCIDR("8000::/112")
r.Insert(NewBasicRangerEntry(*network))
r.Insert(NewBasicRangerEntry(*network1))
network := netip.MustParsePrefix("0.0.1.0/24")
network1 := netip.MustParsePrefix("8000::/112")
r.Insert(NewBasicRangerEntry(network))
r.Insert(NewBasicRangerEntry(network1))

cases := []struct {
ip net.IP
ip netip.Addr
contains bool
err error
name string
}{
{net.ParseIP("0.0.1.255"), true, nil, "IPv4 should contain"},
{net.ParseIP("0.0.0.255"), false, nil, "IPv4 houldn't contain"},
{net.ParseIP("8000::ffff"), true, nil, "IPv6 shouldn't contain"},
{net.ParseIP("8000::1:ffff"), false, nil, "IPv6 shouldn't contain"},
{append(net.ParseIP("8000::1:ffff"), byte(0)), false, ErrInvalidNetworkInput, "Invalid IP"},
{netip.MustParseAddr("0.0.1.255"), true, nil, "IPv4 should contain"},
{netip.MustParseAddr("0.0.0.255"), false, nil, "IPv4 shouldn't contain"},
{netip.MustParseAddr("8000::ffff"), true, nil, "IPv6 shouldn't contain"},
{netip.MustParseAddr("8000::1:ffff"), false, nil, "IPv6 shouldn't contain"},
}

for _, tc := range cases {
Expand All @@ -103,31 +85,30 @@ func TestContains(t *testing.T) {

func TestContainingNetworks(t *testing.T) {
r := newBruteRanger().(*bruteRanger)
_, network1, _ := net.ParseCIDR("0.0.1.0/24")
_, network2, _ := net.ParseCIDR("0.0.1.0/25")
_, network3, _ := net.ParseCIDR("8000::/112")
_, network4, _ := net.ParseCIDR("8000::/113")
entry1 := NewBasicRangerEntry(*network1)
entry2 := NewBasicRangerEntry(*network2)
entry3 := NewBasicRangerEntry(*network3)
entry4 := NewBasicRangerEntry(*network4)
network1 := netip.MustParsePrefix("0.0.1.0/24")
network2 := netip.MustParsePrefix("0.0.1.0/25")
network3 := netip.MustParsePrefix("8000::/112")
network4 := netip.MustParsePrefix("8000::/113")
entry1 := NewBasicRangerEntry(network1)
entry2 := NewBasicRangerEntry(network2)
entry3 := NewBasicRangerEntry(network3)
entry4 := NewBasicRangerEntry(network4)
r.Insert(entry1)
r.Insert(entry2)
r.Insert(entry3)
r.Insert(entry4)
cases := []struct {
ip net.IP
ip netip.Addr
containingNetworks []RangerEntry
err error
name string
}{
{net.ParseIP("0.0.1.255"), []RangerEntry{entry1}, nil, "IPv4 should contain"},
{net.ParseIP("0.0.1.127"), []RangerEntry{entry1, entry2}, nil, "IPv4 should contain both"},
{net.ParseIP("0.0.0.127"), []RangerEntry{}, nil, "IPv4 should contain none"},
{net.ParseIP("8000::ffff"), []RangerEntry{entry3}, nil, "IPv6 should constain"},
{net.ParseIP("8000::7fff"), []RangerEntry{entry3, entry4}, nil, "IPv6 should contain both"},
{net.ParseIP("8000::1:7fff"), []RangerEntry{}, nil, "IPv6 should contain none"},
{append(net.ParseIP("8000::1:7fff"), byte(0)), nil, ErrInvalidNetworkInput, "Invalid IP"},
{netip.MustParseAddr("0.0.1.255"), []RangerEntry{entry1}, nil, "IPv4 should contain"},
{netip.MustParseAddr("0.0.1.127"), []RangerEntry{entry1, entry2}, nil, "IPv4 should contain both"},
{netip.MustParseAddr("0.0.0.127"), []RangerEntry{}, nil, "IPv4 should contain none"},
{netip.MustParseAddr("8000::ffff"), []RangerEntry{entry3}, nil, "IPv6 should constain"},
{netip.MustParseAddr("8000::7fff"), []RangerEntry{entry3, entry4}, nil, "IPv6 should contain both"},
{netip.MustParseAddr("8000::1:7fff"), []RangerEntry{}, nil, "IPv6 should contain none"},
}

for _, tc := range cases {
Expand All @@ -151,17 +132,16 @@ func TestCoveredNetworks(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
ranger := newBruteRanger()
for _, insert := range tc.inserts {
_, network, _ := net.ParseCIDR(insert)
err := ranger.Insert(NewBasicRangerEntry(*network))
network := netip.MustParsePrefix(insert)
err := ranger.Insert(NewBasicRangerEntry(network))
assert.NoError(t, err)
}

var expectedEntries []string
for _, network := range tc.networks {
expectedEntries = append(expectedEntries, network)
}
expectedEntries = append(expectedEntries, tc.networks...)
sort.Strings(expectedEntries)
_, snet, _ := net.ParseCIDR(tc.search)
networks, err := ranger.CoveredNetworks(*snet)
snet := netip.MustParsePrefix(tc.search)
networks, err := ranger.CoveredNetworks(snet)
assert.NoError(t, err)

var results []string
Expand Down
28 changes: 14 additions & 14 deletions cidranger.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,54 +41,54 @@ package cidranger

import (
"fmt"
"net"
"net/netip"
)

// ErrInvalidNetworkInput is returned upon invalid network input.
var ErrInvalidNetworkInput = fmt.Errorf("Invalid network input")
var ErrInvalidNetworkInput = fmt.Errorf("invalid network input")

// ErrInvalidNetworkNumberInput is returned upon invalid network input.
var ErrInvalidNetworkNumberInput = fmt.Errorf("Invalid network number input")
var ErrInvalidNetworkNumberInput = fmt.Errorf("invalid network number input")

// AllIPv4 is a IPv4 CIDR that contains all networks
var AllIPv4 = parseCIDRUnsafe("0.0.0.0/0")

// AllIPv6 is a IPv6 CIDR that contains all networks
var AllIPv6 = parseCIDRUnsafe("0::0/0")

func parseCIDRUnsafe(s string) *net.IPNet {
_, cidr, _ := net.ParseCIDR(s)
func parseCIDRUnsafe(s string) netip.Prefix {
cidr, _ := netip.ParsePrefix(s)
return cidr
}

// RangerEntry is an interface for insertable entry into a Ranger.
type RangerEntry interface {
Network() net.IPNet
Network() netip.Prefix
}

type basicRangerEntry struct {
ipNet net.IPNet
ipNet netip.Prefix
}

func (b *basicRangerEntry) Network() net.IPNet {
func (b *basicRangerEntry) Network() netip.Prefix {
return b.ipNet
}

// NewBasicRangerEntry returns a basic RangerEntry that only stores the network
// itself.
func NewBasicRangerEntry(ipNet net.IPNet) RangerEntry {
func NewBasicRangerEntry(ipNet netip.Prefix) RangerEntry {
return &basicRangerEntry{
ipNet: ipNet,
ipNet: ipNet, //.Masked(),
}
}

// Ranger is an interface for cidr block containment lookups.
type Ranger interface {
Insert(entry RangerEntry) error
Remove(network net.IPNet) (RangerEntry, error)
Contains(ip net.IP) (bool, error)
ContainingNetworks(ip net.IP) ([]RangerEntry, error)
CoveredNetworks(network net.IPNet) ([]RangerEntry, error)
Remove(network netip.Prefix) (RangerEntry, error)
Contains(ip netip.Addr) (bool, error)
ContainingNetworks(ip netip.Addr) ([]RangerEntry, error)
CoveredNetworks(network netip.Prefix) ([]RangerEntry, error)
Len() int
}

Expand Down
Loading