diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 0481be5..b45f167 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -8,7 +8,7 @@ jobs: name: Build strategy: matrix: - go-version: [1.21.x, 1.22.x] + go-version: [1.23.0-rc.1] platform: [ubuntu-latest, macos-latest, windows-latest] runs-on: ${{ matrix.platform }} steps: diff --git a/decoder.go b/decoder.go index db941fc..5864873 100644 --- a/decoder.go +++ b/decoder.go @@ -30,8 +30,8 @@ const ( _Slice // We don't use the next two. They are placeholders. See the spec // for more details. - _Container //nolint: deadcode, varcheck // above - _Marker //nolint: deadcode, varcheck // above + _Container //nolint:deadcode,varcheck // above + _Marker //nolint:deadcode,varcheck // above _Bool _Float32 ) diff --git a/example_test.go b/example_test.go index c3aee66..d61e69d 100644 --- a/example_test.go +++ b/example_test.go @@ -67,16 +67,12 @@ func ExampleReader_Networks() { Domain string `maxminddb:"connection_type"` }{} - networks := db.Networks() - for networks.Next() { - subnet, err := networks.Network(&record) + for result := range db.Networks() { + err := result.Decode(&record) if err != nil { log.Panic(err) } - fmt.Printf("%s: %s\n", subnet.String(), record.Domain) - } - if networks.Err() != nil { - log.Panic(networks.Err()) + fmt.Printf("%s: %s\n", result.Network(), record.Domain) } // Output: // 1.0.0.0/24: Cable/DSL @@ -123,16 +119,12 @@ func ExampleReader_NetworksWithin() { log.Panic(err) } - networks := db.NetworksWithin(prefix) - for networks.Next() { - subnet, err := networks.Network(&record) + for result := range db.NetworksWithin(prefix) { + err := result.Decode(&record) if err != nil { log.Panic(err) } - fmt.Printf("%s: %s\n", subnet.String(), record.Domain) - } - if networks.Err() != nil { - log.Panic(networks.Err()) + fmt.Printf("%s: %s\n", result.Network(), record.Domain) } // Output: diff --git a/go.mod b/go.mod index 11b1b10..fc7d1c9 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/oschwald/maxminddb-golang/v2 -go 1.21 +go 1.23 require ( github.com/stretchr/testify v1.9.0 diff --git a/reader.go b/reader.go index c26f28d..e9f76c3 100644 --- a/reader.go +++ b/reader.go @@ -221,14 +221,6 @@ func (r *Reader) traverseTree(ip netip.Addr, node uint, stopBit int) (uint, int) return node, i } -func (r *Reader) retrieveData(pointer uint, result any) error { - offset, err := r.resolveDataPointer(pointer) - if err != nil { - return err - } - return Result{decoder: r.decoder, offset: uint(offset)}.Decode(result) -} - func (r *Reader) resolveDataPointer(pointer uint) (uintptr, error) { resolved := uintptr(pointer - r.Metadata.NodeCount - dataSectionSeparatorSize) diff --git a/traverse.go b/traverse.go index 67a5148..9e4dc3d 100644 --- a/traverse.go +++ b/traverse.go @@ -3,6 +3,9 @@ package maxminddb import ( "fmt" "net/netip" + + // comment to prevent gofumpt from randomly moving iter. + "iter" ) // Internal structure used to keep track of nodes we still need to visit. @@ -12,12 +15,7 @@ type netNode struct { pointer uint } -// Networks represents a set of subnets that we are iterating over. -type Networks struct { - err error - reader *Reader - nodes []netNode - lastNode netNode +type networkOptions struct { includeAliasedNetworks bool } @@ -27,12 +25,12 @@ var ( ) // NetworksOption are options for Networks and NetworksWithin. -type NetworksOption func(*Networks) +type NetworksOption func(*networkOptions) // IncludeAliasedNetworks is an option for Networks and NetworksWithin // that makes them iterate over aliases of the IPv4 subtree in an IPv6 // database, e.g., ::ffff:0:0/96, 2001::/32, and 2002::/16. -func IncludeAliasedNetworks(networks *Networks) { +func IncludeAliasedNetworks(networks *networkOptions) { networks.includeAliasedNetworks = true } @@ -43,15 +41,11 @@ func IncludeAliasedNetworks(networks *Networks) { // in an IPv6 database. This iterator will only iterate over these once by // default. To iterate over all the IPv4 network locations, use the // IncludeAliasedNetworks option. -func (r *Reader) Networks(options ...NetworksOption) *Networks { - var networks *Networks +func (r *Reader) Networks(options ...NetworksOption) iter.Seq[Result] { if r.Metadata.IPVersion == 6 { - networks = r.NetworksWithin(allIPv6, options...) - } else { - networks = r.NetworksWithin(allIPv4, options...) + return r.NetworksWithin(allIPv6, options...) } - - return networks + return r.NetworksWithin(allIPv4, options...) } // NetworksWithin returns an iterator that can be used to traverse all networks @@ -64,126 +58,116 @@ func (r *Reader) Networks(options ...NetworksOption) *Networks { // // If the provided prefix is contained within a network in the database, the // iterator will iterate over exactly one network, the containing network. -func (r *Reader) NetworksWithin(prefix netip.Prefix, options ...NetworksOption) *Networks { - if r.Metadata.IPVersion == 4 && prefix.Addr().Is6() { - return &Networks{ - err: fmt.Errorf( - "error getting networks with '%s': you attempted to use an IPv6 network in an IPv4-only database", - prefix, - ), +func (r *Reader) NetworksWithin(prefix netip.Prefix, options ...NetworksOption) iter.Seq[Result] { + return func(yield func(Result) bool) { + if r.Metadata.IPVersion == 4 && prefix.Addr().Is6() { + yield(Result{ + err: fmt.Errorf( + "error getting networks with '%s': you attempted to use an IPv6 network in an IPv4-only database", + prefix, + ), + }) + return } - } - - networks := &Networks{reader: r} - for _, option := range options { - option(networks) - } - - ip := prefix.Addr() - netIP := ip - stopBit := prefix.Bits() - if ip.Is4() { - netIP = v4ToV16(ip) - stopBit += 96 - } - - pointer, bit := r.traverseTree(ip, 0, stopBit) - prefix, err := netIP.Prefix(bit) - if err != nil { - networks.err = fmt.Errorf("prefixing %s with %d", netIP, bit) - } + n := &networkOptions{} + for _, option := range options { + option(n) + } - networks.nodes = []netNode{ - { - ip: prefix.Addr(), - bit: uint(bit), - pointer: pointer, - }, - } + ip := prefix.Addr() + netIP := ip + stopBit := prefix.Bits() + if ip.Is4() { + netIP = v4ToV16(ip) + stopBit += 96 + } - return networks -} + pointer, bit := r.traverseTree(ip, 0, stopBit) -// Next prepares the next network for reading with the Network method. It -// returns true if there is another network to be processed and false if there -// are no more networks or if there is an error. -func (n *Networks) Next() bool { - if n.err != nil { - return false - } - for len(n.nodes) > 0 { - node := n.nodes[len(n.nodes)-1] - n.nodes = n.nodes[:len(n.nodes)-1] - - for node.pointer != n.reader.Metadata.NodeCount { - // This skips IPv4 aliases without hardcoding the networks that the writer - // currently aliases. - if !n.includeAliasedNetworks && n.reader.ipv4Start != 0 && - node.pointer == n.reader.ipv4Start && !isInIPv4Subtree(node.ip) { - break - } + prefix, err := netIP.Prefix(bit) + if err != nil { + yield(Result{ + ip: ip, + prefixLen: uint8(bit), + err: fmt.Errorf("prefixing %s with %d", netIP, bit), + }) + } - if node.pointer > n.reader.Metadata.NodeCount { - n.lastNode = node - return true - } - ipRight := node.ip.As16() - if len(ipRight) <= int(node.bit>>3) { - displayAddr := node.ip - displayBits := node.bit - if isInIPv4Subtree(node.ip) { - displayAddr = v6ToV4(displayAddr) - displayBits -= 96 + nodes := make([]netNode, 0, 64) + nodes = append(nodes, + netNode{ + ip: prefix.Addr(), + bit: uint(bit), + pointer: pointer, + }, + ) + + for len(nodes) > 0 { + node := nodes[len(nodes)-1] + nodes = nodes[:len(nodes)-1] + + for node.pointer != r.Metadata.NodeCount { + // This skips IPv4 aliases without hardcoding the networks that the writer + // currently aliases. + if !n.includeAliasedNetworks && r.ipv4Start != 0 && + node.pointer == r.ipv4Start && !isInIPv4Subtree(node.ip) { + break } - n.err = newInvalidDatabaseError( - "invalid search tree at %s/%d", displayAddr, displayBits) - return false - } - ipRight[node.bit>>3] |= 1 << (7 - (node.bit % 8)) + if node.pointer > r.Metadata.NodeCount { + ip := node.ip + if isInIPv4Subtree(ip) { + ip = v6ToV4(ip) + } + + offset, err := r.resolveDataPointer(node.pointer) + ok := yield(Result{ + decoder: r.decoder, + ip: ip, + offset: uint(offset), + prefixLen: uint8(node.bit), + err: err, + }) + if !ok { + return + } + break + } + ipRight := node.ip.As16() + if len(ipRight) <= int(node.bit>>3) { + displayAddr := node.ip + if isInIPv4Subtree(node.ip) { + displayAddr = v6ToV4(displayAddr) + } + + res := Result{ + ip: displayAddr, + prefixLen: uint8(node.bit), + } + res.err = newInvalidDatabaseError( + "invalid search tree at %s", res.Network()) + + yield(res) + + return + } + ipRight[node.bit>>3] |= 1 << (7 - (node.bit % 8)) - offset := node.pointer * n.reader.nodeOffsetMult - rightPointer := n.reader.nodeReader.readRight(offset) + offset := node.pointer * r.nodeOffsetMult + rightPointer := r.nodeReader.readRight(offset) - node.bit++ - n.nodes = append(n.nodes, netNode{ - pointer: rightPointer, - ip: netip.AddrFrom16(ipRight), - bit: node.bit, - }) + node.bit++ + nodes = append(nodes, netNode{ + pointer: rightPointer, + ip: netip.AddrFrom16(ipRight), + bit: node.bit, + }) - node.pointer = n.reader.nodeReader.readLeft(offset) + node.pointer = r.nodeReader.readLeft(offset) + } } } - - return false -} - -// Network returns the current network or an error if there is a problem -// decoding the data for the network. It takes a pointer to a result value to -// decode the network's data into. -func (n *Networks) Network(result any) (netip.Prefix, error) { - if n.err != nil { - return netip.Prefix{}, n.err - } - if err := n.reader.retrieveData(n.lastNode.pointer, result); err != nil { - return netip.Prefix{}, err - } - - ip := n.lastNode.ip - prefixLength := int(n.lastNode.bit) - if isInIPv4Subtree(ip) { - ip = v6ToV4(ip) - prefixLength -= 96 - } - - return netip.PrefixFrom(ip, prefixLength), nil -} - -// Err returns an error, if any, that was encountered during iteration. -func (n *Networks) Err() error { - return n.err } var ipv4SubtreeBoundary = netip.MustParseAddr("::255.255.255.255").Next() diff --git a/traverse_test.go b/traverse_test.go index 382659a..554c2b2 100644 --- a/traverse_test.go +++ b/traverse_test.go @@ -20,18 +20,18 @@ func TestNetworks(t *testing.T) { reader, err := Open(fileName) require.NoError(t, err, "unexpected error while opening database: %v", err) - n := reader.Networks() - for n.Next() { + for result := range reader.Networks() { record := struct { IP string `maxminddb:"ip"` }{} - network, err := n.Network(&record) + err := result.Decode(&record) require.NoError(t, err) + + network := result.Network() assert.Equal(t, record.IP, network.Addr().String(), "expected %s got %s", record.IP, network.Addr().String(), ) } - require.NoError(t, n.Err()) require.NoError(t, reader.Close()) } } @@ -41,13 +41,14 @@ func TestNetworksWithInvalidSearchTree(t *testing.T) { reader, err := Open(testFile("MaxMind-DB-test-broken-search-tree-24.mmdb")) require.NoError(t, err, "unexpected error while opening database: %v", err) - n := reader.Networks() - for n.Next() { + for result := range reader.Networks() { var record any - _, err := n.Network(&record) - require.NoError(t, err) + err = result.Decode(&record) + if err != nil { + break + } } - require.EqualError(t, n.Err(), "invalid search tree at 128.128.128.128/32") + require.EqualError(t, err, "invalid search tree at 128.128.128.128/32") require.NoError(t, reader.Close()) } @@ -285,20 +286,18 @@ func TestNetworksWithin(t *testing.T) { require.NoError(t, err) require.NoError(t, err) - n := reader.NetworksWithin(network, v.Options...) var innerIPs []string - for n.Next() { + for result := range reader.NetworksWithin(network, v.Options...) { record := struct { IP string `maxminddb:"ip"` }{} - network, err := n.Network(&record) + err := result.Decode(&record) require.NoError(t, err) - innerIPs = append(innerIPs, network.String()) + innerIPs = append(innerIPs, result.Network().String()) } assert.Equal(t, v.Expected, innerIPs) - require.NoError(t, n.Err()) require.NoError(t, reader.Close()) }) @@ -326,21 +325,35 @@ func TestGeoIPNetworksWithin(t *testing.T) { prefix, err := netip.ParsePrefix(v.Network) require.NoError(t, err) - n := reader.NetworksWithin(prefix) var innerIPs []string - for n.Next() { + for result := range reader.NetworksWithin(prefix) { record := struct { IP string `maxminddb:"ip"` }{} - network, err := n.Network(&record) + err := result.Decode(&record) require.NoError(t, err) - innerIPs = append(innerIPs, network.String()) + innerIPs = append(innerIPs, result.Network().String()) } assert.Equal(t, v.Expected, innerIPs) - require.NoError(t, n.Err()) require.NoError(t, reader.Close()) } } + +func BenchmarkNetworks(b *testing.B) { + db, err := Open(testFile("GeoIP2-Country-Test.mmdb")) + require.NoError(b, err) + + for i := 0; i < b.N; i++ { + for r := range db.Networks() { + var rec struct{} + err = r.Decode(&rec) + if err != nil { + b.Error(err) + } + } + } + require.NoError(b, db.Close(), "error on close") +} diff --git a/verifier.go b/verifier.go index b14b3e4..335cb1b 100644 --- a/verifier.go +++ b/verifier.go @@ -102,16 +102,11 @@ func (v *verifier) verifyDatabase() error { func (v *verifier) verifySearchTree() (map[uint]bool, error) { offsets := make(map[uint]bool) - it := v.reader.Networks() - for it.Next() { - offset, err := v.reader.resolveDataPointer(it.lastNode.pointer) - if err != nil { + for result := range v.reader.Networks() { + if err := result.Err(); err != nil { return nil, err } - offsets[uint(offset)] = true - } - if err := it.Err(); err != nil { - return nil, err + offsets[result.offset] = true } return offsets, nil }