Skip to content

Commit

Permalink
Switch to net/netip
Browse files Browse the repository at this point in the history
The preformance of this is approximately the same as the net.IP version,
except for the methods that return a network. For those, there is a
slight improvement.
  • Loading branch information
oschwald committed Jun 30, 2024
1 parent 616cde2 commit 1e6c614
Show file tree
Hide file tree
Showing 7 changed files with 207 additions and 206 deletions.
4 changes: 2 additions & 2 deletions deserializer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package maxminddb

import (
"math/big"
"net"
"net/netip"
"testing"

"github.com/stretchr/testify/require"
Expand All @@ -13,7 +13,7 @@ func TestDecodingToDeserializer(t *testing.T) {
require.NoError(t, err, "unexpected error while opening database: %v", err)

dser := testDeserializer{}
err = reader.Lookup(net.ParseIP("::1.1.1.0"), &dser)
err = reader.Lookup(netip.MustParseAddr("::1.1.1.0"), &dser)
require.NoError(t, err, "unexpected error while doing lookup: %v", err)

checkDecodingToInterface(t, dser.rv)
Expand Down
16 changes: 8 additions & 8 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ package maxminddb_test
import (
"fmt"
"log"
"net"
"net/netip"

"github.com/oschwald/maxminddb-golang"
"github.com/oschwald/maxminddb-golang/v2"
)

// This example shows how to decode to a struct.
Expand All @@ -16,15 +16,15 @@ func ExampleReader_Lookup_struct() {
}
defer db.Close()

ip := net.ParseIP("81.2.69.142")
addr := netip.MustParseAddr("81.2.69.142")

var record struct {
Country struct {
ISOCode string `maxminddb:"iso_code"`
} `maxminddb:"country"`
} // Or any appropriate struct

err = db.Lookup(ip, &record)
err = db.Lookup(addr, &record)
if err != nil {
log.Panic(err)
}
Expand All @@ -41,10 +41,10 @@ func ExampleReader_Lookup_interface() {
}
defer db.Close()

ip := net.ParseIP("81.2.69.142")
addr := netip.MustParseAddr("81.2.69.142")

var record any
err = db.Lookup(ip, &record)
err = db.Lookup(addr, &record)
if err != nil {
log.Panic(err)
}
Expand Down Expand Up @@ -118,12 +118,12 @@ func ExampleReader_NetworksWithin() {
Domain string `maxminddb:"connection_type"`
}{}

_, network, err := net.ParseCIDR("1.0.0.0/8")
prefix, err := netip.ParsePrefix("1.0.0.0/8")
if err != nil {
log.Panic(err)
}

networks := db.NetworksWithin(network, maxminddb.SkipAliasedNetworks)
networks := db.NetworksWithin(prefix, maxminddb.SkipAliasedNetworks)
for networks.Next() {
subnet, err := networks.Network(&record)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
module github.com/oschwald/maxminddb-golang
module github.com/oschwald/maxminddb-golang/v2

go 1.21

Expand Down
94 changes: 47 additions & 47 deletions reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"bytes"
"errors"
"fmt"
"net"
"net/netip"
"reflect"
)

Expand Down Expand Up @@ -110,6 +110,7 @@ func FromBytes(buffer []byte) (*Reader, error) {

func (r *Reader) setIPv4Start() {
if r.Metadata.IPVersion != 6 {
r.ipv4StartBitDepth = 96
return
}

Expand All @@ -130,7 +131,7 @@ func (r *Reader) setIPv4Start() {
// because of type differences, an UnmarshalTypeError is returned. If the
// database is invalid or otherwise cannot be read, an InvalidDatabaseError
// is returned.
func (r *Reader) Lookup(ip net.IP, result any) error {
func (r *Reader) Lookup(ip netip.Addr, result any) error {
if r.buffer == nil {
return errors.New("cannot call Lookup on a closed database")
}
Expand All @@ -142,7 +143,7 @@ func (r *Reader) Lookup(ip net.IP, result any) error {
}

// LookupNetwork retrieves the database record for ip and stores it in the
// value pointed to by result. The network returned is the network associated
// value pointed to by result. The prefix returned is the network associated
// with the data record in the database. The ok return value indicates whether
// the database contained a record for the ip.
//
Expand All @@ -151,28 +152,29 @@ func (r *Reader) Lookup(ip net.IP, result any) error {
// UnmarshalTypeError is returned. If the database is invalid or otherwise
// cannot be read, an InvalidDatabaseError is returned.
func (r *Reader) LookupNetwork(
ip net.IP,
ip netip.Addr,
result any,
) (network *net.IPNet, ok bool, err error) {
) (prefix netip.Prefix, ok bool, err error) {
if r.buffer == nil {
return nil, false, errors.New("cannot call Lookup on a closed database")
return netip.Prefix{}, false, errors.New("cannot call Lookup on a closed database")
}
pointer, prefixLength, ip, err := r.lookupPointer(ip)
// We return this error below as we want to return the prefix it is for

network = r.cidr(ip, prefixLength)
if pointer == 0 || err != nil {
return network, false, err
prefix, errP := r.cidr(ip, prefixLength)
if pointer == 0 || err != nil || errP != nil {
return prefix, false, errors.Join(err, errP)
}

return network, true, r.retrieveData(pointer, result)
return prefix, true, r.retrieveData(pointer, result)
}

// LookupOffset maps an argument net.IP to a corresponding record offset in the
// database. NotFound is returned if no such record is found, and a record may
// otherwise be extracted by passing the returned offset to Decode. LookupOffset
// is an advanced API, which exists to provide clients with a means to cache
// previously-decoded records.
func (r *Reader) LookupOffset(ip net.IP) (uintptr, error) {
func (r *Reader) LookupOffset(ip netip.Addr) (uintptr, error) {
if r.buffer == nil {
return 0, errors.New("cannot call LookupOffset on a closed database")
}
Expand All @@ -183,22 +185,28 @@ func (r *Reader) LookupOffset(ip net.IP) (uintptr, error) {
return r.resolveDataPointer(pointer)
}

func (r *Reader) cidr(ip net.IP, prefixLength int) *net.IPNet {
// This is necessary as the node that the IPv4 start is at may
// be at a bit depth that is less that 96, i.e., ipv4Start points
// to a leaf node. For instance, if a record was inserted at ::/8,
// the ipv4Start would point directly at the leaf node for the
// record and would have a bit depth of 8. This would not happen
// with databases currently distributed by MaxMind as all of them
// have an IPv4 subtree that is greater than a single node.
if r.Metadata.IPVersion == 6 &&
len(ip) == net.IPv4len &&
r.ipv4StartBitDepth != 96 {
return &net.IPNet{IP: net.ParseIP("::"), Mask: net.CIDRMask(r.ipv4StartBitDepth, 128)}
var zeroIP = netip.MustParseAddr("::")

func (r *Reader) cidr(ip netip.Addr, prefixLength int) (netip.Prefix, error) {
if ip.Is4() {
// This is necessary as the node that the IPv4 start is at may
// be at a bit depth that is less that 96, i.e., ipv4Start points
// to a leaf node. For instance, if a record was inserted at ::/8,
// the ipv4Start would point directly at the leaf node for the
// record and would have a bit depth of 8. This would not happen
// with databases currently distributed by MaxMind as all of them
// have an IPv4 subtree that is greater than a single node.
if r.Metadata.IPVersion == 6 && r.ipv4StartBitDepth != 96 {
return netip.PrefixFrom(zeroIP, r.ipv4StartBitDepth), nil
}
prefixLength -= 96
}

mask := net.CIDRMask(prefixLength, len(ip)*8)
return &net.IPNet{IP: ip.Mask(mask), Mask: mask}
prefix, err := ip.Prefix(prefixLength)
if err != nil {
return netip.Prefix{}, fmt.Errorf("creating prefix from %s/%d: %w", ip, prefixLength, err)
}
return prefix, nil
}

// Decode the record at |offset| into |result|. The result value pointed to
Expand Down Expand Up @@ -239,29 +247,15 @@ func (r *Reader) decode(offset uintptr, result any) error {
return err
}

func (r *Reader) lookupPointer(ip net.IP) (uint, int, net.IP, error) {
if ip == nil {
return 0, 0, nil, errors.New("IP passed to Lookup cannot be nil")
}

ipV4Address := ip.To4()
if ipV4Address != nil {
ip = ipV4Address
}
if len(ip) == 16 && r.Metadata.IPVersion == 4 {
func (r *Reader) lookupPointer(ip netip.Addr) (uint, int, netip.Addr, error) {
if r.Metadata.IPVersion == 4 && ip.Is6() {
return 0, 0, ip, fmt.Errorf(
"error looking up '%s': you attempted to look up an IPv6 address in an IPv4-only database",
ip.String(),
)
}

bitCount := uint(len(ip) * 8)

var node uint
if bitCount == 32 {
node = r.ipv4Start
}
node, prefixLength := r.traverseTree(ip, node, bitCount)
node, prefixLength := r.traverseTree(ip, 0, 128)

nodeCount := r.Metadata.NodeCount
if node == nodeCount {
Expand All @@ -274,12 +268,18 @@ func (r *Reader) lookupPointer(ip net.IP) (uint, int, net.IP, error) {
return 0, prefixLength, ip, newInvalidDatabaseError("invalid node in search tree")
}

func (r *Reader) traverseTree(ip net.IP, node, bitCount uint) (uint, int) {
func (r *Reader) traverseTree(ip netip.Addr, node uint, stopBit int) (uint, int) {
i := 0
if ip.Is4() {
i = r.ipv4StartBitDepth
node = r.ipv4Start
}
nodeCount := r.Metadata.NodeCount

i := uint(0)
for ; i < bitCount && node < nodeCount; i++ {
bit := uint(1) & (uint(ip[i>>3]) >> (7 - (i % 8)))
ip16 := ip.As16()

for ; i < stopBit && node < nodeCount; i++ {
bit := uint(1) & (uint(ip16[i>>3]) >> (7 - (i % 8)))

offset := node * r.nodeOffsetMult
if bit == 0 {
Expand All @@ -289,7 +289,7 @@ func (r *Reader) traverseTree(ip net.IP, node, bitCount uint) (uint, int) {
}
}

return node, int(i)
return node, i
}

func (r *Reader) retrieveData(pointer uint, result any) error {
Expand Down
Loading

0 comments on commit 1e6c614

Please sign in to comment.