From 2851f7b8c75e93ba6c2e4899837013ca65640077 Mon Sep 17 00:00:00 2001 From: CPerezz Date: Tue, 7 Apr 2026 14:09:07 +0200 Subject: [PATCH] trie/bintrie: implement binaryNodeIterator.seek() MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The bintrie node iterator previously discarded its `start` parameter, forcing every iteration to begin at the root. This makes resumable generators (snapshot/flat-state population) impossible — any interruption restarts from scratch. Implement seek(start []byte) by walking down the trie following start's bit path, building the iterator stack as we go. When the chosen path dead-ends (Empty, missing child, or a stem strictly less than start), backtrack through the existing stack to find the next in-order subtree and descend to its leftmost leaf. Also wire BinaryTrie.NodeIterator(startKey) to actually pass startKey through (was hardcoded to nil). Tests cover: empty start (no-op), exact key match, between-keys, into empty subtree, past end, within-stem offsets, resume simulation, and deep tree. --- trie/bintrie/iterator.go | 332 +++++++++++++++++++++++++++++++++- trie/bintrie/iterator_test.go | 236 ++++++++++++++++++++++++ trie/bintrie/trie.go | 5 +- 3 files changed, 569 insertions(+), 4 deletions(-) diff --git a/trie/bintrie/iterator.go b/trie/bintrie/iterator.go index 048d37f766..989f49244b 100644 --- a/trie/bintrie/iterator.go +++ b/trie/bintrie/iterator.go @@ -17,7 +17,9 @@ package bintrie import ( + "bytes" "errors" + "fmt" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/trie" @@ -38,15 +40,341 @@ type binaryNodeIterator struct { stack []binaryNodeIteratorState } -func newBinaryNodeIterator(t *BinaryTrie, _ []byte) (trie.NodeIterator, error) { +func newBinaryNodeIterator(t *BinaryTrie, start []byte) (trie.NodeIterator, error) { if t.Hash() == zero { return &binaryNodeIterator{trie: t, lastErr: errIteratorEnd}, nil } it := &binaryNodeIterator{trie: t, current: t.root} - // it.err = it.seek(start) + if len(start) > 0 { + if err := it.seek(start); err != nil { + return nil, err + } + } return it, nil } +// seek positions the iterator so that the next call to Next(true) advances to +// the first leaf with key >= start. It walks down the trie following start's +// bit path, building the iterator stack along the way. When the chosen path +// dead-ends (Empty, missing child, or a stem strictly less than start), the +// implementation backtracks through the existing stack to find the next +// in-order subtree and descends to its leftmost leaf. +// +// A nil/empty start is a no-op; iteration begins at the trie root as usual. +// +// This is required for resumable bintrie generators (snapshot generation, +// pathdb flat-state population) so that an interrupted run can pick up where +// it left off after a crash or graceful shutdown. +func (it *binaryNodeIterator) seek(start []byte) error { + if len(start) == 0 { + return nil + } + // Pad start to a 32-byte key (the trie's natural key length). + var key [32]byte + copy(key[:], start) + + // Reset state + it.stack = it.stack[:0] + it.current = nil + it.lastErr = nil + + root := it.trie.root + if root == nil { + it.lastErr = errIteratorEnd + return nil + } + if _, isEmpty := root.(Empty); isEmpty { + it.lastErr = errIteratorEnd + return nil + } + + // Resolve the root if it's a HashedNode + resolved, err := it.resolveIfHashed(root, nil, 0) + if err != nil { + return err + } + if resolved == nil { + it.lastErr = errIteratorEnd + return nil + } + if resolved != root { + it.trie.root = resolved + root = resolved + } + + return it.seekDescend(root, key[:]) +} + +// seekDescend walks down from `node` following key's bit path. For each +// InternalNode encountered, it pushes the node onto the stack with Index set +// to the bit it descended into (0 for left, 1 for right) and recurses into +// the chosen child. On a StemNode it positions at the appropriate value +// offset and returns. On a dead end (Empty, nil, stem < key), it delegates +// to seekBacktrack to find the next valid subtree. +func (it *binaryNodeIterator) seekDescend(node BinaryNode, key []byte) error { + for { + switch n := node.(type) { + case *InternalNode: + depth := n.depth + if depth >= 31*8 { + return errors.New("seek: internal node too deep") + } + bit := key[depth/8] >> (7 - uint(depth%8)) & 1 + + // Push this internal node with Index = chosen bit. The Next() + // loop interprets Index as "the side currently being explored", + // so this is consistent with normal iteration state. + it.stack = append(it.stack, binaryNodeIteratorState{Node: n, Index: int(bit)}) + it.current = n + + var child BinaryNode + if bit == 0 { + child = n.left + } else { + child = n.right + } + if child == nil { + return it.seekBacktrack() + } + if _, isEmpty := child.(Empty); isEmpty { + return it.seekBacktrack() + } + // Resolve a hashed child using the current key as the path source. + resolved, err := it.resolveIfHashed(child, key, depth+1) + if err != nil { + return err + } + if resolved == nil { + return it.seekBacktrack() + } + if resolved != child { + if bit == 0 { + n.left = resolved + } else { + n.right = resolved + } + } + node = resolved + + case *StemNode: + cmp := bytes.Compare(n.Stem, key[:StemSize]) + if cmp < 0 { + // Stem is strictly before our target. Don't push it; backtrack + // to find the next subtree to the right. + return it.seekBacktrack() + } + startOffset := 0 + if cmp == 0 { + startOffset = int(key[StemSize]) + } + it.stack = append(it.stack, binaryNodeIteratorState{Node: n, Index: startOffset}) + it.current = n + return nil + + default: + return fmt.Errorf("seek: unexpected node type %T", node) + } + } +} + +// seekBacktrack walks the existing stack backward looking for the first +// InternalNode whose right subtree hasn't been considered yet. If found, it +// flips that node's Index to 1 and descends into the leftmost leaf of the +// right subtree. If no such ancestor exists, it sets errIteratorEnd. +func (it *binaryNodeIterator) seekBacktrack() error { + for len(it.stack) > 0 { + top := &it.stack[len(it.stack)-1] + n, ok := top.Node.(*InternalNode) + if !ok { + // Not an InternalNode (e.g., a StemNode pushed elsewhere). Pop and + // continue. seekDescend never pushes non-internal nodes before + // returning, so this is a defensive fallback. + it.stack = it.stack[:len(it.stack)-1] + continue + } + if top.Index == 0 { + // We were positioned in the left subtree. Try the right sibling. + top.Index = 1 + right := n.right + if right == nil { + it.stack = it.stack[:len(it.stack)-1] + continue + } + if _, isEmpty := right.(Empty); isEmpty { + it.stack = it.stack[:len(it.stack)-1] + continue + } + // Resolve the right child if it's hashed. Use a synthetic path + // where the bit at this depth is 1 (we're descending right). + resolved, err := it.resolveRightChild(n) + if err != nil { + return err + } + if resolved == nil { + it.stack = it.stack[:len(it.stack)-1] + continue + } + if resolved != right { + n.right = resolved + right = resolved + } + it.current = right + return it.seekLeftmost(right) + } + // Index == 1: we were already in the right subtree. Both subtrees of + // this internal node have been considered. Pop and try higher. + it.stack = it.stack[:len(it.stack)-1] + } + it.lastErr = errIteratorEnd + return nil +} + +// seekLeftmost descends into the leftmost leaf of the subtree rooted at +// `node`, pushing internal nodes onto the stack with Index = 0 (left first). +// It positions the iterator at a StemNode with Index = 0, ready to scan +// values from offset 0. +func (it *binaryNodeIterator) seekLeftmost(node BinaryNode) error { + for { + switch n := node.(type) { + case *InternalNode: + it.stack = append(it.stack, binaryNodeIteratorState{Node: n, Index: 0}) + it.current = n + + child := n.left + pickedRight := false + if child == nil { + child = n.right + pickedRight = true + } + if child != nil { + if _, isEmpty := child.(Empty); isEmpty { + if !pickedRight { + child = n.right + pickedRight = true + } + if child != nil { + if _, isEmpty2 := child.(Empty); isEmpty2 { + child = nil + } + } + } + } + if child == nil { + // Both children are empty/nil — degenerate. Pop and let seek + // backtrack handle it. (This shouldn't normally happen for a + // well-formed trie because internal nodes always have at least + // two non-empty children at construction time.) + it.stack = it.stack[:len(it.stack)-1] + return it.seekBacktrack() + } + if pickedRight { + it.stack[len(it.stack)-1].Index = 1 + } + // Resolve hashed child + resolved, err := it.resolveIfHashed(child, nil, n.depth+1) + if err != nil { + return err + } + if resolved == nil { + // Resolution failed; treat as empty and try the other side. + if pickedRight { + // Already tried right; nothing left. + it.stack = it.stack[:len(it.stack)-1] + return it.seekBacktrack() + } + // Try right + right := n.right + if right == nil { + it.stack = it.stack[:len(it.stack)-1] + return it.seekBacktrack() + } + if _, isEmpty := right.(Empty); isEmpty { + it.stack = it.stack[:len(it.stack)-1] + return it.seekBacktrack() + } + it.stack[len(it.stack)-1].Index = 1 + resolved, err = it.resolveIfHashed(right, nil, n.depth+1) + if err != nil { + return err + } + if resolved == nil { + it.stack = it.stack[:len(it.stack)-1] + return it.seekBacktrack() + } + n.right = resolved + node = resolved + continue + } + if resolved != child { + if pickedRight { + n.right = resolved + } else { + n.left = resolved + } + } + node = resolved + + case *StemNode: + it.stack = append(it.stack, binaryNodeIteratorState{Node: n, Index: 0}) + it.current = n + return nil + + default: + return fmt.Errorf("seekLeftmost: unexpected node type %T", node) + } + } +} + +// resolveIfHashed checks whether the given node is a HashedNode and, if so, +// uses the trie's nodeResolver to load and deserialize the underlying node. +// Returns the resolved node or the original if no resolution was needed. +// Returns (nil, nil) if the resolver returned no data (e.g., zero hash). +// +// keyForPath supplies the bit path used to address the node; for the root +// this is unused (path is empty). depth is the depth of the node being +// resolved, used for the deserialized node's internal depth field. +func (it *binaryNodeIterator) resolveIfHashed(node BinaryNode, keyForPath []byte, depth int) (BinaryNode, error) { + hn, ok := node.(HashedNode) + if !ok { + return node, nil + } + var path []byte + if depth > 0 && keyForPath != nil { + var err error + path, err = keyToPath(depth-1, keyForPath) + if err != nil { + return nil, err + } + } + data, err := it.trie.nodeResolver(path, common.Hash(hn)) + if err != nil { + return nil, err + } + if data == nil { + return nil, nil + } + resolved, err := DeserializeNodeWithHash(data, depth, common.Hash(hn)) + if err != nil { + return nil, err + } + return resolved, nil +} + +// resolveRightChild resolves the right child of an InternalNode using a +// synthetic path that ends in bit=1. This is used by seekBacktrack when +// flipping from left to right exploration. +func (it *binaryNodeIterator) resolveRightChild(parent *InternalNode) (BinaryNode, error) { + right := parent.right + if _, ok := right.(HashedNode); !ok { + return right, nil + } + // Build a 32-byte key whose bit at parent.depth is 1; rest doesn't matter + // for the path computation. + var key [32]byte + key[parent.depth/8] |= 1 << (7 - uint(parent.depth%8)) + return it.resolveIfHashed(right, key[:], parent.depth+1) +} + // Next moves the iterator to the next node. If the parameter is false, any child // nodes will be skipped. func (it *binaryNodeIterator) Next(descend bool) bool { diff --git a/trie/bintrie/iterator_test.go b/trie/bintrie/iterator_test.go index 3e717c07ba..6ee03525be 100644 --- a/trie/bintrie/iterator_test.go +++ b/trie/bintrie/iterator_test.go @@ -18,6 +18,7 @@ package bintrie import ( "bytes" + "slices" "testing" "github.com/ethereum/go-ethereum/common" @@ -206,6 +207,241 @@ func TestIteratorDeepTree(t *testing.T) { } } +// collectLeaves iterates the trie and returns all (key, value) pairs visited. +func collectLeaves(t *testing.T, tr *BinaryTrie, start []byte) [][2][]byte { + t.Helper() + it, err := newBinaryNodeIterator(tr, start) + if err != nil { + t.Fatal(err) + } + var out [][2][]byte + for it.Next(true) { + if it.Leaf() { + k := slices.Clone(it.LeafKey()) + v := slices.Clone(it.LeafBlob()) + out = append(out, [2][]byte{k, v}) + } + } + if it.Error() != nil { + t.Fatalf("iterator error: %v", it.Error()) + } + return out +} + +// TestSeekEmptyStart verifies that seek with a nil/empty start behaves like +// a fresh iterator (no skipping). +func TestSeekEmptyStart(t *testing.T) { + tr := makeTrie(t, [][2]common.Hash{ + {common.HexToHash("0000000000000000000000000000000000000000000000000000000000000001"), oneKey}, + {common.HexToHash("8000000000000000000000000000000000000000000000000000000000000001"), oneKey}, + }) + // Both nil and empty slice should iterate everything. + if got := len(collectLeaves(t, tr, nil)); got != 2 { + t.Fatalf("nil start: expected 2 leaves, got %d", got) + } + if got := len(collectLeaves(t, tr, []byte{})); got != 2 { + t.Fatalf("empty start: expected 2 leaves, got %d", got) + } +} + +// TestSeekToExactKey verifies that seeking to an existing leaf key positions +// the iterator at that exact leaf. +func TestSeekToExactKey(t *testing.T) { + keys := [][2]common.Hash{ + {common.HexToHash("0000000000000000000000000000000000000000000000000000000000000001"), oneKey}, + {common.HexToHash("0000000000000000000000000000000000000000000000000000000000000002"), twoKey}, + {common.HexToHash("8000000000000000000000000000000000000000000000000000000000000001"), oneKey}, + } + tr := makeTrie(t, keys) + + // Seek to the second key. We expect to see [key2, key3]. + start := keys[1][0] + got := collectLeaves(t, tr, start[:]) + if len(got) != 2 { + t.Fatalf("expected 2 leaves after seek to %x, got %d", start, len(got)) + } + if !bytes.Equal(got[0][0], keys[1][0][:]) { + t.Fatalf("first leaf after seek: got %x, want %x", got[0][0], keys[1][0]) + } + if !bytes.Equal(got[1][0], keys[2][0][:]) { + t.Fatalf("second leaf after seek: got %x, want %x", got[1][0], keys[2][0]) + } +} + +// TestSeekToBetweenKeys verifies that seeking to a key that doesn't exist +// positions the iterator at the next existing key (in-order). +func TestSeekToBetweenKeys(t *testing.T) { + keys := [][2]common.Hash{ + {common.HexToHash("0000000000000000000000000000000000000000000000000000000000000001"), oneKey}, + {common.HexToHash("0000000000000000000000000000000000000000000000000000000000000005"), twoKey}, + {common.HexToHash("8000000000000000000000000000000000000000000000000000000000000001"), oneKey}, + } + tr := makeTrie(t, keys) + + // Seek to a key between key0 and key1: should land at key1. + between := common.HexToHash("0000000000000000000000000000000000000000000000000000000000000003") + got := collectLeaves(t, tr, between[:]) + if len(got) != 2 { + t.Fatalf("expected 2 leaves after seek between, got %d", len(got)) + } + if !bytes.Equal(got[0][0], keys[1][0][:]) { + t.Fatalf("first leaf: got %x, want %x", got[0][0], keys[1][0]) + } + if !bytes.Equal(got[1][0], keys[2][0][:]) { + t.Fatalf("second leaf: got %x, want %x", got[1][0], keys[2][0]) + } +} + +// TestSeekIntoEmptySubtree verifies that seeking into a subtree where the +// chosen path is empty correctly backtracks to the next populated subtree. +func TestSeekIntoEmptySubtree(t *testing.T) { + // Build a trie with stems split across the bit-0 and bit-1 subtrees. + keys := [][2]common.Hash{ + {common.HexToHash("0000000000000000000000000000000000000000000000000000000000000001"), oneKey}, + {common.HexToHash("8000000000000000000000000000000000000000000000000000000000000001"), twoKey}, + } + tr := makeTrie(t, keys) + + // Seek to a key in a subtree that's entirely missing (e.g., 0x40...). + // The high bit is 0, so we'd descend left, but the left subtree only has + // keys with the FIRST bit being 0 — and the seek bit pattern would walk + // into a position that has no leaves at or after it on the left side, + // requiring backtrack to the right subtree. + missing := common.HexToHash("4000000000000000000000000000000000000000000000000000000000000001") + got := collectLeaves(t, tr, missing[:]) + // Should land at key1 (the right subtree leaf). + if len(got) != 1 { + t.Fatalf("expected 1 leaf after seek into missing subtree, got %d", len(got)) + } + if !bytes.Equal(got[0][0], keys[1][0][:]) { + t.Fatalf("leaf: got %x, want %x", got[0][0], keys[1][0]) + } +} + +// TestSeekPastEnd verifies that seeking past the last key returns no leaves. +func TestSeekPastEnd(t *testing.T) { + keys := [][2]common.Hash{ + {common.HexToHash("0000000000000000000000000000000000000000000000000000000000000001"), oneKey}, + {common.HexToHash("0000000000000000000000000000000000000000000000000000000000000002"), oneKey}, + } + tr := makeTrie(t, keys) + + // Seek past the maximum key. + beyond := common.HexToHash("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") + got := collectLeaves(t, tr, beyond[:]) + if len(got) != 0 { + t.Fatalf("expected 0 leaves after seek past end, got %d: %x", len(got), got) + } +} + +// TestSeekWithinSameStem verifies that seeking within a single stem (multiple +// values at different offsets) positions correctly at the requested offset. +func TestSeekWithinSameStem(t *testing.T) { + // All three keys share the same stem; only the last byte differs. + keys := [][2]common.Hash{ + {common.HexToHash("0000000000000000000000000000000000000000000000000000000000000001"), oneKey}, + {common.HexToHash("0000000000000000000000000000000000000000000000000000000000000005"), twoKey}, + {common.HexToHash("00000000000000000000000000000000000000000000000000000000000000ff"), oneKey}, + } + tr := makeTrie(t, keys) + + // Seek to offset 5: should yield keys 1 (offset 5) and 2 (offset 0xff). + start := common.HexToHash("0000000000000000000000000000000000000000000000000000000000000005") + got := collectLeaves(t, tr, start[:]) + if len(got) != 2 { + t.Fatalf("expected 2 leaves, got %d", len(got)) + } + if got[0][0][31] != 0x05 { + t.Fatalf("first leaf offset: got 0x%02x, want 0x05", got[0][0][31]) + } + if got[1][0][31] != 0xff { + t.Fatalf("second leaf offset: got 0x%02x, want 0xff", got[1][0][31]) + } + + // Seek to offset 6 (between 5 and 0xff): should yield only key 2. + start[31] = 0x06 + got = collectLeaves(t, tr, start[:]) + if len(got) != 1 { + t.Fatalf("expected 1 leaf after seek to offset 6, got %d", len(got)) + } + if got[0][0][31] != 0xff { + t.Fatalf("leaf offset: got 0x%02x, want 0xff", got[0][0][31]) + } +} + +// TestSeekResumeSimulation simulates a generator interruption: iterate halfway, +// extract the last leaf key, build a fresh iterator, seek to the next key, and +// verify that the resumed iteration produces the remaining leaves. +func TestSeekResumeSimulation(t *testing.T) { + // Construct a deterministic set of keys. + var keys [][2]common.Hash + for i := range 16 { + var k common.Hash + k[0] = byte(i << 4) // distribute across the high nibble + k[31] = 0x01 + keys = append(keys, [2]common.Hash{k, oneKey}) + } + tr := makeTrie(t, keys) + + // First pass: collect all leaves. + all := collectLeaves(t, tr, nil) + if len(all) != 16 { + t.Fatalf("first pass: expected 16 leaves, got %d", len(all)) + } + + // Stop after the 7th leaf and resume. + stopIdx := 7 + lastKey := all[stopIdx][0] + + // Resume: seek to the byte AFTER lastKey (we use lastKey + 1 in the last + // byte; for our keys this is sufficient because each key's last byte is + // 0x01 and we want to go to the NEXT stem). + resumeKey := slices.Clone(lastKey) + // Increment the last byte; if it overflows, that's fine for these keys + // because all our last bytes are 0x01. + resumeKey[31]++ + // But actually we want to start AT lastKey + 1, which for our keys means + // we want the NEXT stem. Since each stem has only one value at offset 0x01 + // and we want everything strictly after lastKey, set offset to 0x02. + got := collectLeaves(t, tr, resumeKey) + if len(got) != len(all)-stopIdx-1 { + t.Fatalf("resume: expected %d leaves, got %d", len(all)-stopIdx-1, len(got)) + } + for i, leaf := range got { + want := all[stopIdx+1+i] + if !bytes.Equal(leaf[0], want[0]) { + t.Fatalf("resume leaf %d: got %x, want %x", i, leaf[0], want[0]) + } + } +} + +// TestSeekDeepTree verifies seek works on a tree with a long shared prefix. +func TestSeekDeepTree(t *testing.T) { + keys := [][2]common.Hash{ + {common.HexToHash("0000000000C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0"), oneKey}, + {common.HexToHash("0000000000E00000000000000000000000000000000000000000000000000000"), twoKey}, + } + tr := makeTrie(t, keys) + + // Seek to the first key exactly. + got := collectLeaves(t, tr, keys[0][0][:]) + if len(got) != 2 { + t.Fatalf("seek to first: expected 2 leaves, got %d", len(got)) + } + if !bytes.Equal(got[0][0], keys[0][0][:]) { + t.Fatalf("first leaf: got %x, want %x", got[0][0], keys[0][0]) + } + + // Seek to the second key exactly. + got = collectLeaves(t, tr, keys[1][0][:]) + if len(got) != 1 { + t.Fatalf("seek to second: expected 1 leaf, got %d", len(got)) + } + if !bytes.Equal(got[0][0], keys[1][0][:]) { + t.Fatalf("leaf: got %x, want %x", got[0][0], keys[1][0]) + } +} + // TestIteratorNodeCount verifies the total number of Next(true) calls // for a known tree structure. func TestIteratorNodeCount(t *testing.T) { diff --git a/trie/bintrie/trie.go b/trie/bintrie/trie.go index b1e3c991c0..14c1a46c2b 100644 --- a/trie/bintrie/trie.go +++ b/trie/bintrie/trie.go @@ -352,9 +352,10 @@ func (t *BinaryTrie) Commit(_ bool) (common.Hash, *trienode.NodeSet) { } // NodeIterator returns an iterator that returns nodes of the trie. Iteration -// starts at the key after the given start key. +// starts at the first leaf with key >= startKey. A nil/empty startKey iterates +// the whole trie. func (t *BinaryTrie) NodeIterator(startKey []byte) (trie.NodeIterator, error) { - return newBinaryNodeIterator(t, nil) + return newBinaryNodeIterator(t, startKey) } // Prove constructs a Merkle proof for key. The result contains all encoded nodes