Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch Networks methods to iterators #146

Merged
merged 6 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ jobs:
name: Build
strategy:
matrix:
go-version: [1.21.x, 1.22.x]
go-version: [1.23.0-rc.1]
platform: [ubuntu-latest, macos-latest, windows-latest]
runs-on: ${{ matrix.platform }}
steps:
Expand Down
4 changes: 2 additions & 2 deletions decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ const (
_Slice
// We don't use the next two. They are placeholders. See the spec
// for more details.
_Container //nolint: deadcode, varcheck // above
_Marker //nolint: deadcode, varcheck // above
_Container //nolint:deadcode,varcheck // above
_Marker //nolint:deadcode,varcheck // above
_Bool
_Float32
)
Expand Down
20 changes: 6 additions & 14 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,12 @@ func ExampleReader_Networks() {
Domain string `maxminddb:"connection_type"`
}{}

networks := db.Networks()
for networks.Next() {
subnet, err := networks.Network(&record)
for result := range db.Networks() {
err := result.Decode(&record)
if err != nil {
log.Panic(err)
}
fmt.Printf("%s: %s\n", subnet.String(), record.Domain)
}
if networks.Err() != nil {
log.Panic(networks.Err())
fmt.Printf("%s: %s\n", result.Network(), record.Domain)
}
// Output:
// 1.0.0.0/24: Cable/DSL
Expand Down Expand Up @@ -123,16 +119,12 @@ func ExampleReader_NetworksWithin() {
log.Panic(err)
}

networks := db.NetworksWithin(prefix)
for networks.Next() {
subnet, err := networks.Network(&record)
for result := range db.NetworksWithin(prefix) {
err := result.Decode(&record)
if err != nil {
log.Panic(err)
}
fmt.Printf("%s: %s\n", subnet.String(), record.Domain)
}
if networks.Err() != nil {
log.Panic(networks.Err())
fmt.Printf("%s: %s\n", result.Network(), record.Domain)
}

// Output:
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/oschwald/maxminddb-golang/v2

go 1.21
go 1.23

require (
github.com/stretchr/testify v1.9.0
Expand Down
8 changes: 0 additions & 8 deletions reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,14 +221,6 @@ func (r *Reader) traverseTree(ip netip.Addr, node uint, stopBit int) (uint, int)
return node, i
}

func (r *Reader) retrieveData(pointer uint, result any) error {
offset, err := r.resolveDataPointer(pointer)
if err != nil {
return err
}
return Result{decoder: r.decoder, offset: uint(offset)}.Decode(result)
}

func (r *Reader) resolveDataPointer(pointer uint) (uintptr, error) {
resolved := uintptr(pointer - r.Metadata.NodeCount - dataSectionSeparatorSize)

Expand Down
228 changes: 106 additions & 122 deletions traverse.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package maxminddb
import (
"fmt"
"net/netip"

// comment to prevent gofumpt from randomly moving iter.
"iter"
)

// Internal structure used to keep track of nodes we still need to visit.
Expand All @@ -12,12 +15,7 @@ type netNode struct {
pointer uint
}

// Networks represents a set of subnets that we are iterating over.
type Networks struct {
err error
reader *Reader
nodes []netNode
lastNode netNode
type networkOptions struct {
includeAliasedNetworks bool
}

Expand All @@ -27,12 +25,12 @@ var (
)

// NetworksOption are options for Networks and NetworksWithin.
type NetworksOption func(*Networks)
type NetworksOption func(*networkOptions)

// IncludeAliasedNetworks is an option for Networks and NetworksWithin
// that makes them iterate over aliases of the IPv4 subtree in an IPv6
// database, e.g., ::ffff:0:0/96, 2001::/32, and 2002::/16.
func IncludeAliasedNetworks(networks *Networks) {
func IncludeAliasedNetworks(networks *networkOptions) {
networks.includeAliasedNetworks = true
}

Expand All @@ -43,15 +41,11 @@ func IncludeAliasedNetworks(networks *Networks) {
// in an IPv6 database. This iterator will only iterate over these once by
// default. To iterate over all the IPv4 network locations, use the
// IncludeAliasedNetworks option.
func (r *Reader) Networks(options ...NetworksOption) *Networks {
var networks *Networks
func (r *Reader) Networks(options ...NetworksOption) iter.Seq[Result] {
if r.Metadata.IPVersion == 6 {
networks = r.NetworksWithin(allIPv6, options...)
} else {
networks = r.NetworksWithin(allIPv4, options...)
return r.NetworksWithin(allIPv6, options...)
}

return networks
return r.NetworksWithin(allIPv4, options...)
}

// NetworksWithin returns an iterator that can be used to traverse all networks
Expand All @@ -64,126 +58,116 @@ func (r *Reader) Networks(options ...NetworksOption) *Networks {
//
// If the provided prefix is contained within a network in the database, the
// iterator will iterate over exactly one network, the containing network.
func (r *Reader) NetworksWithin(prefix netip.Prefix, options ...NetworksOption) *Networks {
if r.Metadata.IPVersion == 4 && prefix.Addr().Is6() {
return &Networks{
err: fmt.Errorf(
"error getting networks with '%s': you attempted to use an IPv6 network in an IPv4-only database",
prefix,
),
func (r *Reader) NetworksWithin(prefix netip.Prefix, options ...NetworksOption) iter.Seq[Result] {
return func(yield func(Result) bool) {
if r.Metadata.IPVersion == 4 && prefix.Addr().Is6() {
yield(Result{
err: fmt.Errorf(
"error getting networks with '%s': you attempted to use an IPv6 network in an IPv4-only database",
prefix,
),
})
return
}
}

networks := &Networks{reader: r}
for _, option := range options {
option(networks)
}

ip := prefix.Addr()
netIP := ip
stopBit := prefix.Bits()
if ip.Is4() {
netIP = v4ToV16(ip)
stopBit += 96
}

pointer, bit := r.traverseTree(ip, 0, stopBit)

prefix, err := netIP.Prefix(bit)
if err != nil {
networks.err = fmt.Errorf("prefixing %s with %d", netIP, bit)
}
n := &networkOptions{}
for _, option := range options {
option(n)
}

networks.nodes = []netNode{
{
ip: prefix.Addr(),
bit: uint(bit),
pointer: pointer,
},
}
ip := prefix.Addr()
netIP := ip
stopBit := prefix.Bits()
if ip.Is4() {
netIP = v4ToV16(ip)
stopBit += 96
}

return networks
}
pointer, bit := r.traverseTree(ip, 0, stopBit)

// Next prepares the next network for reading with the Network method. It
// returns true if there is another network to be processed and false if there
// are no more networks or if there is an error.
func (n *Networks) Next() bool {
if n.err != nil {
return false
}
for len(n.nodes) > 0 {
node := n.nodes[len(n.nodes)-1]
n.nodes = n.nodes[:len(n.nodes)-1]

for node.pointer != n.reader.Metadata.NodeCount {
// This skips IPv4 aliases without hardcoding the networks that the writer
// currently aliases.
if !n.includeAliasedNetworks && n.reader.ipv4Start != 0 &&
node.pointer == n.reader.ipv4Start && !isInIPv4Subtree(node.ip) {
break
}
prefix, err := netIP.Prefix(bit)
if err != nil {
yield(Result{
ip: ip,
prefixLen: uint8(bit),
err: fmt.Errorf("prefixing %s with %d", netIP, bit),
})
}

if node.pointer > n.reader.Metadata.NodeCount {
n.lastNode = node
return true
}
ipRight := node.ip.As16()
if len(ipRight) <= int(node.bit>>3) {
displayAddr := node.ip
displayBits := node.bit
if isInIPv4Subtree(node.ip) {
displayAddr = v6ToV4(displayAddr)
displayBits -= 96
nodes := make([]netNode, 0, 64)
nodes = append(nodes,
netNode{
ip: prefix.Addr(),
bit: uint(bit),
pointer: pointer,
},
)

for len(nodes) > 0 {
node := nodes[len(nodes)-1]
nodes = nodes[:len(nodes)-1]

for node.pointer != r.Metadata.NodeCount {
// This skips IPv4 aliases without hardcoding the networks that the writer
// currently aliases.
if !n.includeAliasedNetworks && r.ipv4Start != 0 &&
node.pointer == r.ipv4Start && !isInIPv4Subtree(node.ip) {
break
}

n.err = newInvalidDatabaseError(
"invalid search tree at %s/%d", displayAddr, displayBits)
return false
}
ipRight[node.bit>>3] |= 1 << (7 - (node.bit % 8))
if node.pointer > r.Metadata.NodeCount {
ip := node.ip
if isInIPv4Subtree(ip) {
ip = v6ToV4(ip)
}

offset, err := r.resolveDataPointer(node.pointer)
ok := yield(Result{
decoder: r.decoder,
ip: ip,
offset: uint(offset),
prefixLen: uint8(node.bit),
err: err,
})
if !ok {
return
}
break
}
ipRight := node.ip.As16()
if len(ipRight) <= int(node.bit>>3) {
displayAddr := node.ip
if isInIPv4Subtree(node.ip) {
displayAddr = v6ToV4(displayAddr)
}

res := Result{
ip: displayAddr,
prefixLen: uint8(node.bit),
}
res.err = newInvalidDatabaseError(
"invalid search tree at %s", res.Network())

yield(res)

return
}
ipRight[node.bit>>3] |= 1 << (7 - (node.bit % 8))

offset := node.pointer * n.reader.nodeOffsetMult
rightPointer := n.reader.nodeReader.readRight(offset)
offset := node.pointer * r.nodeOffsetMult
rightPointer := r.nodeReader.readRight(offset)

node.bit++
n.nodes = append(n.nodes, netNode{
pointer: rightPointer,
ip: netip.AddrFrom16(ipRight),
bit: node.bit,
})
node.bit++
nodes = append(nodes, netNode{
pointer: rightPointer,
ip: netip.AddrFrom16(ipRight),
bit: node.bit,
})

node.pointer = n.reader.nodeReader.readLeft(offset)
node.pointer = r.nodeReader.readLeft(offset)
}
}
}

return false
}

// Network returns the current network or an error if there is a problem
// decoding the data for the network. It takes a pointer to a result value to
// decode the network's data into.
func (n *Networks) Network(result any) (netip.Prefix, error) {
if n.err != nil {
return netip.Prefix{}, n.err
}
if err := n.reader.retrieveData(n.lastNode.pointer, result); err != nil {
return netip.Prefix{}, err
}

ip := n.lastNode.ip
prefixLength := int(n.lastNode.bit)
if isInIPv4Subtree(ip) {
ip = v6ToV4(ip)
prefixLength -= 96
}

return netip.PrefixFrom(ip, prefixLength), nil
}

// Err returns an error, if any, that was encountered during iteration.
func (n *Networks) Err() error {
return n.err
}

var ipv4SubtreeBoundary = netip.MustParseAddr("::255.255.255.255").Next()
Expand Down
Loading
Loading