Skip to content

Commit

Permalink
IAVL Iterator (#440)
Browse files Browse the repository at this point in the history
* move traverse into CPS

* add iterator

* add ascending

* fix iter.valid

* fix test

* gofmt

* simplify logicl, add comments

* rm using iterator inside iavl

* add detailed comments

* add comments per review

* modify to closure independent

* apply review, add docs, separate iterator.go

* fix lint

* fix fix lint

* add more docs

Co-authored-by: Marko <[email protected]>
  • Loading branch information
mconcat and tac0turtle authored Nov 9, 2021
1 parent 209a946 commit 849cd28
Show file tree
Hide file tree
Showing 4 changed files with 238 additions and 57 deletions.
4 changes: 2 additions & 2 deletions immutable_tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ func (t *ImmutableTree) IterateRange(start, end []byte, ascending bool, fn func(
if t.root == nil {
return false
}
return t.root.traverseInRange(t, start, end, ascending, false, 0, false, func(node *Node, _ uint8) bool {
return t.root.traverseInRange(t, start, end, ascending, false, false, func(node *Node) bool {
if node.height == 0 {
return fn(node.key, node.value)
}
Expand All @@ -196,7 +196,7 @@ func (t *ImmutableTree) IterateRangeInclusive(start, end []byte, ascending bool,
if t.root == nil {
return false
}
return t.root.traverseInRange(t, start, end, ascending, true, 0, false, func(node *Node, _ uint8) bool {
return t.root.traverseInRange(t, start, end, ascending, true, false, func(node *Node) bool {
if node.height == 0 {
return fn(node.key, node.value, node.version)
}
Expand Down
228 changes: 228 additions & 0 deletions iterator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
package iavl

// NOTE: This file favors int64 as opposed to int for size/counts.
// The Tree on the other hand favors int. This is intentional.

import (
"bytes"

dbm "github.com/tendermint/tm-db"
)

type traversal struct {
tree *ImmutableTree
start, end []byte // iteration domain
ascending bool // ascending traversal
inclusive bool // end key inclusiveness
post bool // postorder traversal
delayedNodes *delayedNodes // delayed nodes to be traversed
}

func (node *Node) newTraversal(tree *ImmutableTree, start, end []byte, ascending bool, inclusive bool, post bool) *traversal {
return &traversal{
tree: tree,
start: start,
end: end,
ascending: ascending,
inclusive: inclusive,
post: post,
delayedNodes: &delayedNodes{{node, true}}, // set initial traverse to the node
}
}

// delayedNode represents the delayed iteration on the nodes.
// When delayed is set to true, the delayedNode should be expanded, and their
// children should be traversed. When delayed is set to false, the delayedNode is
// already have expanded, and it could be immediately returned.
type delayedNode struct {
node *Node
delayed bool
}

type delayedNodes []delayedNode

func (nodes *delayedNodes) pop() (*Node, bool) {
node := (*nodes)[len(*nodes)-1]
*nodes = (*nodes)[:len(*nodes)-1]
return node.node, node.delayed
}

func (nodes *delayedNodes) push(node *Node, delayed bool) {
*nodes = append(*nodes, delayedNode{node, delayed})
}

func (nodes *delayedNodes) length() int {
return len(*nodes)
}

// `traversal` returns the delayed execution of recursive traversal on a tree.
//
// `traversal` will traverse the tree in a depth-first manner. To handle locating
// the next element, and to handle unwinding, the traversal maintains its future
// iteration under `delayedNodes`. At each call of `next()`, it will retrieve the
// next element from the `delayedNodes` and acts accordingly. The `next()` itself
// defines how to unwind the delayed nodes stack. The caller can either call the
// next traversal to proceed, or simply discard the `traversal` struct to stop iteration.
//
// At the each step of `next`, the `delayedNodes` can have one of the three states:
// 1. It has length of 0, meaning that their is no more traversable nodes.
// 2. It has length of 1, meaning that the traverse is being started from the initial node.
// 3. It has length of 2>=, meaning that there are delayed nodes to be traversed.
//
// When the `delayedNodes` are not empty, `next` retrieves the first `delayedNode` and initially check:
// 1. If it is not an delayed node (node.delayed == false) it immediately returns it.
//
// A. If the `node` is a branch node:
// 1. If the traversal is postorder, then append the current node to the t.delayedNodes,
// with `delayed` set to false. This makes the current node returned *after* all the children
// are traversed, without being expanded.
// 2. Append the traversable children nodes into the `delayedNodes`, with `delayed` set to true. This
// makes the children nodes to be traversed, and expanded with their respective children.
// 3. If the traversal is preorder, (with the children to be traversed already pushed to the
// `delayedNodes`), returns the current node.
// 4. Call `traversal.next()` to further traverse through the `delayedNodes`.
//
// B. If the `node` is a leaf node, it will be returned without expand, by the following process:
// 1. If the traversal is postorder, the current node will be append to the `delayedNodes` with `delayed`
// set to false, and immediately returned at the subsequent call of `traversal.next()` at the last line.
// 2. If the traversal is preorder, the current node will be returned.
func (t *traversal) next() *Node {
// End of traversal.
if t.delayedNodes.length() == 0 {
return nil
}

node, delayed := t.delayedNodes.pop()

// Already expanded, immediately return.
if !delayed || node == nil {
return node
}

afterStart := t.start == nil || bytes.Compare(t.start, node.key) < 0
startOrAfter := afterStart || bytes.Equal(t.start, node.key)
beforeEnd := t.end == nil || bytes.Compare(node.key, t.end) < 0
if t.inclusive {
beforeEnd = beforeEnd || bytes.Equal(node.key, t.end)
}

// case of postorder. A-1 and B-1
// Recursively process left sub-tree, then right-subtree, then node itself.
if t.post && (!node.isLeaf() || (startOrAfter && beforeEnd)) {
t.delayedNodes.push(node, false)
}

// case of branch node, traversing children. A-2.
if !node.isLeaf() {
// if node is a branch node and the order is ascending,
// We traverse through the left subtree, then the right subtree.
if t.ascending {
if beforeEnd {
// push the delayed traversal for the right nodes,
t.delayedNodes.push(node.getRightNode(t.tree), true)
}
if afterStart {
// push the delayed traversal for the left nodes,
t.delayedNodes.push(node.getLeftNode(t.tree), true)
}
} else {
// if node is a branch node and the order is not ascending
// We traverse through the right subtree, then the left subtree.
if afterStart {
// push the delayed traversal for the left nodes,
t.delayedNodes.push(node.getLeftNode(t.tree), true)
}
if beforeEnd {
// push the delayed traversal for the right nodes,
t.delayedNodes.push(node.getRightNode(t.tree), true)
}
}
}

// case of preorder traversal. A-3 and B-2.
// Process root then (recursively) processing left child, then process right child
if !t.post && (!node.isLeaf() || (startOrAfter && beforeEnd)) {
return node
}

// Keep traversing and expanding the remaning delayed nodes. A-4.
return t.next()
}

// Iterator is a dbm.Iterator for ImmutableTree
type Iterator struct {
start, end []byte

key, value []byte

valid bool

t *traversal
}

func (t *ImmutableTree) Iterator(start, end []byte, ascending bool) *Iterator {
iter := &Iterator{
start: start,
end: end,
valid: true,
t: t.root.newTraversal(t, start, end, ascending, false, false),
}

iter.Next()
return iter
}

var _ dbm.Iterator = &Iterator{}

// Domain implements dbm.Iterator.
func (iter *Iterator) Domain() ([]byte, []byte) {
return iter.start, iter.end
}

// Valid implements dbm.Iterator.
func (iter *Iterator) Valid() bool {
return iter.valid
}

// Key implements dbm.Iterator
func (iter *Iterator) Key() []byte {
return iter.key
}

// Value implements dbm.Iterator
func (iter *Iterator) Value() []byte {
return iter.value
}

// Next implements dbm.Iterator
func (iter *Iterator) Next() {
if iter.t == nil {
return
}

node := iter.t.next()
if node == nil {
iter.t = nil
iter.valid = false
return
}

if node.height == 0 {
iter.key, iter.value = node.key, node.value
return
}

iter.Next()
}

// Close implements dbm.Iterator
func (iter *Iterator) Close() error {
iter.t = nil
iter.valid = false
return nil
}

// Error implements dbm.Iterator
func (iter *Iterator) Error() error {
return nil
}
59 changes: 6 additions & 53 deletions node.go
Original file line number Diff line number Diff line change
Expand Up @@ -440,74 +440,27 @@ func (node *Node) calcBalance(t *ImmutableTree) int {

// traverse is a wrapper over traverseInRange when we want the whole tree
func (node *Node) traverse(t *ImmutableTree, ascending bool, cb func(*Node) bool) bool {
return node.traverseInRange(t, nil, nil, ascending, false, 0, false, func(node *Node, depth uint8) bool {
return node.traverseInRange(t, nil, nil, ascending, false, false, func(node *Node) bool {
return cb(node)
})
}

// traversePost is a wrapper over traverseInRange when we want the whole tree post-order
func (node *Node) traversePost(t *ImmutableTree, ascending bool, cb func(*Node) bool) bool {
return node.traverseInRange(t, nil, nil, ascending, false, 0, true, func(node *Node, depth uint8) bool {
return node.traverseInRange(t, nil, nil, ascending, false, true, func(node *Node) bool {
return cb(node)
})
}

func (node *Node) traverseInRange(t *ImmutableTree, start, end []byte, ascending bool, inclusive bool, depth uint8, post bool, cb func(*Node, uint8) bool) bool {
if node == nil {
return false
}
afterStart := start == nil || bytes.Compare(start, node.key) < 0
startOrAfter := start == nil || bytes.Compare(start, node.key) <= 0
beforeEnd := end == nil || bytes.Compare(node.key, end) < 0
if inclusive {
beforeEnd = end == nil || bytes.Compare(node.key, end) <= 0
}

// Run callback per inner/leaf node.
func (node *Node) traverseInRange(tree *ImmutableTree, start, end []byte, ascending bool, inclusive bool, post bool, cb func(*Node) bool) bool {
stop := false
if !post && (!node.isLeaf() || (startOrAfter && beforeEnd)) {
stop = cb(node, depth)
t := node.newTraversal(tree, start, end, ascending, inclusive, post)
for node2 := t.next(); node2 != nil; node2 = t.next() {
stop = cb(node2)
if stop {
return stop
}
}

if !node.isLeaf() {
if ascending {
// check lower nodes, then higher
if afterStart {
stop = node.getLeftNode(t).traverseInRange(t, start, end, ascending, inclusive, depth+1, post, cb)
}
if stop {
return stop
}
if beforeEnd {
stop = node.getRightNode(t).traverseInRange(t, start, end, ascending, inclusive, depth+1, post, cb)
}
} else {
// check the higher nodes first
if beforeEnd {
stop = node.getRightNode(t).traverseInRange(t, start, end, ascending, inclusive, depth+1, post, cb)
}
if stop {
return stop
}
if afterStart {
stop = node.getLeftNode(t).traverseInRange(t, start, end, ascending, inclusive, depth+1, post, cb)
}
}
}
if stop {
return stop
}

if post && (!node.isLeaf() || (startOrAfter && beforeEnd)) {
stop = cb(node, depth)
if stop {
return stop
}
}

return stop
}

Expand Down
4 changes: 2 additions & 2 deletions proof_range.go
Original file line number Diff line number Diff line change
Expand Up @@ -433,8 +433,8 @@ func (t *ImmutableTree) getRangeProof(keyStart, keyEnd []byte, limit int) (proof
var leafCount = 1 // from left above.
var pathCount = 0

t.root.traverseInRange(t, afterLeft, nil, true, false, 0, false,
func(node *Node, depth uint8) (stop bool) {
t.root.traverseInRange(t, afterLeft, nil, true, false, false,
func(node *Node) (stop bool) {

// Track when we diverge from path, or when we've exhausted path,
// since the first allPathToLeafs shouldn't include it.
Expand Down

0 comments on commit 849cd28

Please sign in to comment.