trie/bintrie: implement binaryNodeIterator.seek()

The bintrie node iterator previously discarded its `start` parameter,
forcing every iteration to begin at the root. This made resumable
generators (snapshot/flat-state population) impossible: any
interruption forced a restart from scratch.

Implement seek(start []byte) by walking down the trie following start's
bit path, building the iterator stack as we go. When the chosen path
dead-ends (Empty, missing child, or a stem strictly less than start),
backtrack through the existing stack to find the next in-order subtree
and descend to its leftmost leaf.

Also wire BinaryTrie.NodeIterator(startKey) to actually pass startKey
through (was hardcoded to nil).

Tests cover: empty start (no-op), exact key match, between-keys,
into empty subtree, past end, within-stem offsets, resume simulation,
and deep tree.
This commit is contained in:
CPerezz 2026-04-07 14:09:07 +02:00
parent 64d185616c
commit 2851f7b8c7
No known key found for this signature in database
GPG key ID: 62045F34B97177DD
3 changed files with 569 additions and 4 deletions

View file

@ -17,7 +17,9 @@
package bintrie
import (
"bytes"
"errors"
"fmt"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/trie"
@ -38,15 +40,341 @@ type binaryNodeIterator struct {
stack []binaryNodeIteratorState
}
func newBinaryNodeIterator(t *BinaryTrie, _ []byte) (trie.NodeIterator, error) {
func newBinaryNodeIterator(t *BinaryTrie, start []byte) (trie.NodeIterator, error) {
if t.Hash() == zero {
return &binaryNodeIterator{trie: t, lastErr: errIteratorEnd}, nil
}
it := &binaryNodeIterator{trie: t, current: t.root}
// it.err = it.seek(start)
if len(start) > 0 {
if err := it.seek(start); err != nil {
return nil, err
}
}
return it, nil
}
// seek repositions the iterator for a traversal that should begin at the
// first leaf whose key is >= start.
//
// start is copied into a zero-initialised 32-byte buffer, so shorter inputs
// are right-padded with zeroes and longer inputs are truncated. Any previous
// position (stack, current node, last error) is discarded. The root is
// resolved if it is a HashedNode (the resolved node replaces it.trie.root),
// and the actual positioning is delegated to seekDescend.
//
// A nil/empty start is a no-op and leaves the iterator untouched. A nil or
// Empty root, or a root the resolver returns no data for, marks the iterator
// as finished (errIteratorEnd) and returns nil. Only resolver/deserialisation
// failures and seekDescend errors are reported through the return value.
func (it *binaryNodeIterator) seek(start []byte) error {
	if len(start) == 0 {
		return nil
	}
	var target [32]byte
	copy(target[:], start)

	// Forget whatever position the iterator held before.
	it.stack, it.current, it.lastErr = it.stack[:0], nil, nil

	root := it.trie.root
	_, rootIsEmpty := root.(Empty)
	if root == nil || rootIsEmpty {
		it.lastErr = errIteratorEnd
		return nil
	}
	// The root is addressed by the empty path, hence the nil key and depth 0.
	loaded, err := it.resolveIfHashed(root, nil, 0)
	switch {
	case err != nil:
		return err
	case loaded == nil:
		it.lastErr = errIteratorEnd
		return nil
	case loaded != root:
		// Keep the resolved root so it is not resolved again.
		it.trie.root = loaded
	}
	return it.seekDescend(loaded, target[:])
}
// seekDescend follows key's bits downward from node until it either lands on
// a stem or runs into a dead end.
//
//   - InternalNode: the bit of key at the node's depth selects the child
//     (0 = left, 1 = right). The node is pushed on the stack with Index set to
//     that bit. A nil or Empty child, or a hashed child the resolver has no
//     data for, is a dead end handled by seekBacktrack. A resolved hashed
//     child replaces the HashedNode in the parent.
//   - StemNode: if the stem sorts before key's first StemSize bytes the walk
//     backtracks; otherwise the stem is pushed with Index set to the value
//     offset to start from (key[StemSize] on an exact stem match, 0 when the
//     stem is greater than the target).
//
// Any other node type, or an internal node at depth >= 31*8, is an error.
func (it *binaryNodeIterator) seekDescend(node BinaryNode, key []byte) error {
	for {
		branch, isBranch := node.(*InternalNode)
		if !isBranch {
			stem, isStem := node.(*StemNode)
			if !isStem {
				return fmt.Errorf("seek: unexpected node type %T", node)
			}
			order := bytes.Compare(stem.Stem, key[:StemSize])
			if order < 0 {
				// Every value under this stem precedes the target, so the
				// stem is not pushed; look for the next subtree instead.
				return it.seekBacktrack()
			}
			offset := 0
			if order == 0 {
				offset = int(key[StemSize])
			}
			it.stack = append(it.stack, binaryNodeIteratorState{Node: stem, Index: offset})
			it.current = stem
			return nil
		}

		depth := branch.depth
		if depth >= 31*8 {
			return errors.New("seek: internal node too deep")
		}
		goRight := key[depth/8]>>(7-uint(depth%8))&1 == 1
		side, child := 0, branch.left
		if goRight {
			side, child = 1, branch.right
		}
		// Index records the side being explored. NOTE(review): this is
		// expected to match how Next() interprets Index for internal nodes;
		// Next is not visible here, confirm against it.
		it.stack = append(it.stack, binaryNodeIteratorState{Node: branch, Index: side})
		it.current = branch

		if _, empty := child.(Empty); child == nil || empty {
			return it.seekBacktrack()
		}
		// The child lies on key's own bit path, so key supplies its path.
		loaded, err := it.resolveIfHashed(child, key, depth+1)
		if err != nil {
			return err
		}
		if loaded == nil {
			return it.seekBacktrack()
		}
		if loaded != child {
			if goRight {
				branch.right = loaded
			} else {
				branch.left = loaded
			}
		}
		node = loaded
	}
}
// seekBacktrack unwinds the stack until it finds an InternalNode that was
// being explored on its left side (Index == 0) and still has a usable right
// child. That entry's Index is flipped to 1, the right child is resolved if it
// is hashed (and stored back into the parent), and the iterator is positioned
// on the leftmost leaf of that right subtree via seekLeftmost.
//
// Entries that cannot provide a right subtree are popped: non-internal
// entries, entries already on their right side, and entries whose right child
// is nil, Empty, or unresolvable. If the stack runs out, the iterator is
// marked as finished with errIteratorEnd and nil is returned.
func (it *binaryNodeIterator) seekBacktrack() error {
	for size := len(it.stack); size > 0; size = len(it.stack) {
		frame := &it.stack[size-1]
		parent, isInternal := frame.Node.(*InternalNode)
		if !isInternal || frame.Index != 0 {
			// Either a defensive case (seekDescend and seekLeftmost only call
			// in here with internal nodes on the stack), or both sides of
			// this node have already been considered.
			it.stack = it.stack[:size-1]
			continue
		}

		// Left side is exhausted; switch this entry to its right side.
		frame.Index = 1
		sibling := parent.right
		if _, empty := sibling.(Empty); sibling == nil || empty {
			it.stack = it.stack[:size-1]
			continue
		}
		loaded, err := it.resolveRightChild(parent)
		if err != nil {
			return err
		}
		if loaded == nil {
			it.stack = it.stack[:size-1]
			continue
		}
		if loaded != sibling {
			parent.right = loaded
		}
		it.current = loaded
		return it.seekLeftmost(loaded)
	}
	it.lastErr = errIteratorEnd
	return nil
}
// seekLeftmost positions the iterator on the leftmost leaf of the subtree
// rooted at node.
//
// Each InternalNode on the way down is pushed on the stack with Index set to
// the side that was actually entered: 0 when the left child is usable, 1 when
// the walk had to fall through to the right child because the left one is
// nil, Empty, or a hashed node the resolver returned no data for. If neither
// child is usable the node is popped again and seekBacktrack takes over. The
// walk ends on a StemNode, pushed with Index 0 so that values are scanned
// from offset 0.
//
// Hashed children are resolved with the real bit path of the child. The path
// prefix is rebuilt from the stack: every InternalNode entry's Index is the
// branch (0/1) taken at that node's depth, as set by seekDescend,
// seekBacktrack, and this function. Passing an empty path for a non-root node
// would address the wrong node in any resolver that looks nodes up by path.
//
// Errors from the resolver are returned as-is; an unexpected node type or an
// internal node whose depth does not fit a 32-byte key is reported as an
// error.
func (it *binaryNodeIterator) seekLeftmost(node BinaryNode) error {
	// Key prefix leading to node, reconstructed from the ancestors on the stack.
	var key [32]byte
	for _, frame := range it.stack {
		anc, ok := frame.Node.(*InternalNode)
		if !ok || frame.Index != 1 || anc.depth < 0 || anc.depth >= len(key)*8 {
			continue
		}
		key[anc.depth/8] |= 1 << (7 - uint(anc.depth%8))
	}

	for {
		switch n := node.(type) {
		case *InternalNode:
			if n.depth < 0 || n.depth >= len(key)*8 {
				return fmt.Errorf("seekLeftmost: internal node too deep (depth %d)", n.depth)
			}
			it.stack = append(it.stack, binaryNodeIteratorState{Node: n, Index: 0})
			it.current = n

			// Try left first, then right; take the first child that exists
			// and can be resolved.
			mask := byte(1) << (7 - uint(n.depth%8))
			var next BinaryNode
			for side, child := range [2]BinaryNode{n.left, n.right} {
				if child == nil {
					continue
				}
				if _, isEmpty := child.(Empty); isEmpty {
					continue
				}
				it.stack[len(it.stack)-1].Index = side
				if side == 1 {
					key[n.depth/8] |= mask
				} else {
					key[n.depth/8] &^= mask
				}
				resolved, err := it.resolveIfHashed(child, key[:], n.depth+1)
				if err != nil {
					return err
				}
				if resolved == nil {
					// Resolver had no data: treat as empty, try the other side.
					continue
				}
				if resolved != child {
					if side == 0 {
						n.left = resolved
					} else {
						n.right = resolved
					}
				}
				next = resolved
				break
			}
			if next == nil {
				// Neither child is usable. Drop this node and let the
				// backtracking logic find the next subtree.
				it.stack = it.stack[:len(it.stack)-1]
				return it.seekBacktrack()
			}
			node = next
		case *StemNode:
			it.stack = append(it.stack, binaryNodeIteratorState{Node: n, Index: 0})
			it.current = n
			return nil
		default:
			return fmt.Errorf("seekLeftmost: unexpected node type %T", node)
		}
	}
}
// resolveIfHashed loads the node behind a HashedNode placeholder.
//
// Nodes of any other type are returned unchanged. For a HashedNode the trie's
// nodeResolver is called with the node's path and hash, and the returned blob
// is decoded with DeserializeNodeWithHash at the given depth.
//
// The path is keyToPath(depth-1, keyForPath) when depth > 0 and keyForPath is
// non-nil; otherwise a nil path is passed to the resolver (this is what the
// root, at depth 0, uses).
//
// Returns (nil, nil) when the resolver yields no data, and (nil, err) when
// path construction, resolution, or deserialisation fails.
func (it *binaryNodeIterator) resolveIfHashed(node BinaryNode, keyForPath []byte, depth int) (BinaryNode, error) {
	placeholder, isHashed := node.(HashedNode)
	if !isHashed {
		return node, nil
	}
	hash := common.Hash(placeholder)

	var path []byte
	if keyForPath != nil && depth > 0 {
		p, err := keyToPath(depth-1, keyForPath)
		if err != nil {
			return nil, err
		}
		path = p
	}

	blob, err := it.trie.nodeResolver(path, hash)
	switch {
	case err != nil:
		return nil, err
	case blob == nil:
		return nil, nil
	}

	decoded, err := DeserializeNodeWithHash(blob, depth, hash)
	if err != nil {
		return nil, err
	}
	return decoded, nil
}
// resolveRightChild resolves the right child of parent if it is a HashedNode
// and returns it; a right child of any other type is returned unchanged.
// seekBacktrack uses it when switching a stack entry from its left side to
// its right side.
//
// The resolver path is derived from a key, so the key must carry the real
// prefix bits leading to parent, not just the final "go right" bit. The
// prefix is rebuilt from the iterator stack: every InternalNode entry above
// parent holds in Index the branch (0 = left, 1 = right) taken at that
// node's depth, as set by seekDescend, seekLeftmost, and seekBacktrack. The
// bit at parent.depth is then forced to 1.
//
// Returns an error if parent.depth does not fit a 32-byte key, plus any
// error from resolveIfHashed.
func (it *binaryNodeIterator) resolveRightChild(parent *InternalNode) (BinaryNode, error) {
	right := parent.right
	if _, ok := right.(HashedNode); !ok {
		return right, nil
	}
	var key [32]byte
	if parent.depth < 0 || parent.depth >= len(key)*8 {
		return nil, fmt.Errorf("seek: internal node too deep (depth %d)", parent.depth)
	}
	// Copy the branch decisions of parent's ancestors into the key prefix.
	for _, frame := range it.stack {
		anc, ok := frame.Node.(*InternalNode)
		if !ok || frame.Index != 1 || anc.depth < 0 || anc.depth >= parent.depth {
			continue
		}
		key[anc.depth/8] |= 1 << (7 - uint(anc.depth%8))
	}
	// Descend to the right at parent itself.
	key[parent.depth/8] |= 1 << (7 - uint(parent.depth%8))
	return it.resolveIfHashed(right, key[:], parent.depth+1)
}
// Next moves the iterator to the next node. If the parameter is false, any child
// nodes will be skipped.
func (it *binaryNodeIterator) Next(descend bool) bool {

View file

@ -18,6 +18,7 @@ package bintrie
import (
"bytes"
"slices"
"testing"
"github.com/ethereum/go-ethereum/common"
@ -206,6 +207,241 @@ func TestIteratorDeepTree(t *testing.T) {
}
}
// collectLeaves builds a node iterator over tr positioned by start, walks it
// to exhaustion with Next(true), and returns a cloned (key, value) pair for
// every position reported as a leaf, in visiting order. A failure to create
// the iterator or a non-nil iterator error aborts the test.
func collectLeaves(t *testing.T, tr *BinaryTrie, start []byte) [][2][]byte {
	t.Helper()

	iter, err := newBinaryNodeIterator(tr, start)
	if err != nil {
		t.Fatal(err)
	}
	var leaves [][2][]byte
	for iter.Next(true) {
		if !iter.Leaf() {
			continue
		}
		// Clone so the collected pairs do not alias iterator-owned memory.
		pair := [2][]byte{slices.Clone(iter.LeafKey()), slices.Clone(iter.LeafBlob())}
		leaves = append(leaves, pair)
	}
	if iter.Error() != nil {
		t.Fatalf("iterator error: %v", iter.Error())
	}
	return leaves
}
// TestSeekEmptyStart verifies that seek with a nil/empty start behaves like
// a fresh iterator (no skipping).
func TestSeekEmptyStart(t *testing.T) {
tr := makeTrie(t, [][2]common.Hash{
{common.HexToHash("0000000000000000000000000000000000000000000000000000000000000001"), oneKey},
{common.HexToHash("8000000000000000000000000000000000000000000000000000000000000001"), oneKey},
})
// Both nil and empty slice should iterate everything.
if got := len(collectLeaves(t, tr, nil)); got != 2 {
t.Fatalf("nil start: expected 2 leaves, got %d", got)
}
if got := len(collectLeaves(t, tr, []byte{})); got != 2 {
t.Fatalf("empty start: expected 2 leaves, got %d", got)
}
}
// TestSeekToExactKey checks that seeking to a key that is present in the
// trie returns that leaf first, followed by the leaves after it.
func TestSeekToExactKey(t *testing.T) {
	var (
		first  = common.HexToHash("0000000000000000000000000000000000000000000000000000000000000001")
		second = common.HexToHash("0000000000000000000000000000000000000000000000000000000000000002")
		third  = common.HexToHash("8000000000000000000000000000000000000000000000000000000000000001")
	)
	tr := makeTrie(t, [][2]common.Hash{
		{first, oneKey},
		{second, twoKey},
		{third, oneKey},
	})

	// Starting at the second key must skip the first one only.
	leaves := collectLeaves(t, tr, second[:])
	if len(leaves) != 2 {
		t.Fatalf("expected 2 leaves after seek to %x, got %d", second, len(leaves))
	}
	if head := leaves[0][0]; !bytes.Equal(head, second[:]) {
		t.Fatalf("first leaf after seek: got %x, want %x", head, second)
	}
	if next := leaves[1][0]; !bytes.Equal(next, third[:]) {
		t.Fatalf("second leaf after seek: got %x, want %x", next, third)
	}
}
// TestSeekToBetweenKeys checks that seeking to a key that is absent from the
// trie starts iteration at the next larger key that is present.
func TestSeekToBetweenKeys(t *testing.T) {
	var (
		below = common.HexToHash("0000000000000000000000000000000000000000000000000000000000000001")
		above = common.HexToHash("0000000000000000000000000000000000000000000000000000000000000005")
		far   = common.HexToHash("8000000000000000000000000000000000000000000000000000000000000001")
	)
	tr := makeTrie(t, [][2]common.Hash{
		{below, oneKey},
		{above, twoKey},
		{far, oneKey},
	})

	// ...03 sits between ...01 and ...05, so ...05 must come out first.
	between := common.HexToHash("0000000000000000000000000000000000000000000000000000000000000003")
	leaves := collectLeaves(t, tr, between[:])
	if len(leaves) != 2 {
		t.Fatalf("expected 2 leaves after seek between, got %d", len(leaves))
	}
	if head := leaves[0][0]; !bytes.Equal(head, above[:]) {
		t.Fatalf("first leaf: got %x, want %x", head, above)
	}
	if next := leaves[1][0]; !bytes.Equal(next, far[:]) {
		t.Fatalf("second leaf: got %x, want %x", next, far)
	}
}
// TestSeekIntoEmptySubtree checks that a seek whose bit path holds no leaf at
// or after the target backtracks into the next populated subtree.
//
// The trie holds one key starting with bit 0 (0x00...) and one starting with
// bit 1 (0x80...). The target 0x40... also starts with bit 0 but sorts after
// the only key on that side, so the sole expected result is the 0x80... key.
// NOTE(review): whether the seek hits an Empty child or a stem smaller than
// the target depends on the trie shape makeTrie produces — confirm there.
func TestSeekIntoEmptySubtree(t *testing.T) {
	var (
		leftKey  = common.HexToHash("0000000000000000000000000000000000000000000000000000000000000001")
		rightKey = common.HexToHash("8000000000000000000000000000000000000000000000000000000000000001")
	)
	tr := makeTrie(t, [][2]common.Hash{
		{leftKey, oneKey},
		{rightKey, twoKey},
	})

	missing := common.HexToHash("4000000000000000000000000000000000000000000000000000000000000001")
	leaves := collectLeaves(t, tr, missing[:])
	if count := len(leaves); count != 1 {
		t.Fatalf("expected 1 leaf after seek into missing subtree, got %d", count)
	}
	if only := leaves[0][0]; !bytes.Equal(only, rightKey[:]) {
		t.Fatalf("leaf: got %x, want %x", only, rightKey)
	}
}
// TestSeekPastEnd verifies that seeking past the last key returns no leaves.
func TestSeekPastEnd(t *testing.T) {
keys := [][2]common.Hash{
{common.HexToHash("0000000000000000000000000000000000000000000000000000000000000001"), oneKey},
{common.HexToHash("0000000000000000000000000000000000000000000000000000000000000002"), oneKey},
}
tr := makeTrie(t, keys)
// Seek past the maximum key.
beyond := common.HexToHash("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")
got := collectLeaves(t, tr, beyond[:])
if len(got) != 0 {
t.Fatalf("expected 0 leaves after seek past end, got %d: %x", len(got), got)
}
}
// TestSeekWithinSameStem verifies that seeking within a single stem (multiple
// values at different offsets) positions correctly at the requested offset.
func TestSeekWithinSameStem(t *testing.T) {
// All three keys share the same stem; only the last byte differs.
keys := [][2]common.Hash{
{common.HexToHash("0000000000000000000000000000000000000000000000000000000000000001"), oneKey},
{common.HexToHash("0000000000000000000000000000000000000000000000000000000000000005"), twoKey},
{common.HexToHash("00000000000000000000000000000000000000000000000000000000000000ff"), oneKey},
}
tr := makeTrie(t, keys)
// Seek to offset 5: should yield keys 1 (offset 5) and 2 (offset 0xff).
start := common.HexToHash("0000000000000000000000000000000000000000000000000000000000000005")
got := collectLeaves(t, tr, start[:])
if len(got) != 2 {
t.Fatalf("expected 2 leaves, got %d", len(got))
}
if got[0][0][31] != 0x05 {
t.Fatalf("first leaf offset: got 0x%02x, want 0x05", got[0][0][31])
}
if got[1][0][31] != 0xff {
t.Fatalf("second leaf offset: got 0x%02x, want 0xff", got[1][0][31])
}
// Seek to offset 6 (between 5 and 0xff): should yield only key 2.
start[31] = 0x06
got = collectLeaves(t, tr, start[:])
if len(got) != 1 {
t.Fatalf("expected 1 leaf after seek to offset 6, got %d", len(got))
}
if got[0][0][31] != 0xff {
t.Fatalf("leaf offset: got 0x%02x, want 0xff", got[0][0][31])
}
}
// TestSeekResumeSimulation mimics an interrupted generator: it records a full
// iteration, pretends processing stopped after the leaf at index stopIdx, and
// then checks that a fresh iterator started just past that leaf's key yields
// exactly the remaining leaves, in the same order.
func TestSeekResumeSimulation(t *testing.T) {
	const (
		total   = 16
		stopIdx = 7 // index of the last leaf handled before the "interruption"
	)
	// One key per high nibble of the first byte, each with last byte 0x01.
	keys := make([][2]common.Hash, 0, total)
	for i := range total {
		var k common.Hash
		k[0] = byte(i << 4)
		k[31] = 0x01
		keys = append(keys, [2]common.Hash{k, oneKey})
	}
	tr := makeTrie(t, keys)

	// Reference run over the whole trie.
	all := collectLeaves(t, tr, nil)
	if len(all) != 16 {
		t.Fatalf("first pass: expected 16 leaves, got %d", len(all))
	}

	// Resume strictly after the last handled key. Bumping the final byte from
	// 0x01 to 0x02 is enough here because every key ends in 0x01, so nothing
	// else can sit at or after that position with the same leading 31 bytes.
	resumeKey := slices.Clone(all[stopIdx][0])
	resumeKey[31]++

	remaining := all[stopIdx+1:]
	got := collectLeaves(t, tr, resumeKey)
	if len(got) != len(all)-stopIdx-1 {
		t.Fatalf("resume: expected %d leaves, got %d", len(all)-stopIdx-1, len(got))
	}
	for i, leaf := range got {
		want := remaining[i]
		if !bytes.Equal(leaf[0], want[0]) {
			t.Fatalf("resume leaf %d: got %x, want %x", i, leaf[0], want[0])
		}
	}
}
// TestSeekDeepTree checks seeking in a trie whose two keys share their first
// five bytes, so the keys only diverge well below the root.
func TestSeekDeepTree(t *testing.T) {
	var (
		shallow = common.HexToHash("0000000000C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0")
		deep    = common.HexToHash("0000000000E00000000000000000000000000000000000000000000000000000")
	)
	tr := makeTrie(t, [][2]common.Hash{
		{shallow, oneKey},
		{deep, twoKey},
	})

	// Starting at the smaller key returns both leaves, smaller one first.
	leaves := collectLeaves(t, tr, shallow[:])
	if len(leaves) != 2 {
		t.Fatalf("seek to first: expected 2 leaves, got %d", len(leaves))
	}
	if head := leaves[0][0]; !bytes.Equal(head, shallow[:]) {
		t.Fatalf("first leaf: got %x, want %x", head, shallow)
	}

	// Starting at the larger key returns only that leaf.
	leaves = collectLeaves(t, tr, deep[:])
	if len(leaves) != 1 {
		t.Fatalf("seek to second: expected 1 leaf, got %d", len(leaves))
	}
	if only := leaves[0][0]; !bytes.Equal(only, deep[:]) {
		t.Fatalf("leaf: got %x, want %x", only, deep)
	}
}
// TestIteratorNodeCount verifies the total number of Next(true) calls
// for a known tree structure.
func TestIteratorNodeCount(t *testing.T) {

View file

@ -352,9 +352,10 @@ func (t *BinaryTrie) Commit(_ bool) (common.Hash, *trienode.NodeSet) {
}
// NodeIterator returns an iterator that returns nodes of the trie. Iteration
// starts at the first leaf with key >= startKey. A nil/empty startKey iterates
// the whole trie.
//
// startKey is forwarded unchanged to newBinaryNodeIterator, whose error (if
// any) is returned as-is.
func (t *BinaryTrie) NodeIterator(startKey []byte) (trie.NodeIterator, error) {
	return newBinaryNodeIterator(t, startKey)
}
// Prove constructs a Merkle proof for key. The result contains all encoded nodes