trie: add sub-trie iterator support (#32520)

- Adds `NodeIteratorWithPrefix()` method to support iterating only nodes
within a specific key prefix
- Adds `NodeIteratorWithRange()` method to support iterating only nodes
within a specific key range

Current `NodeIterator` always traverses the entire remaining trie from a
start position. For non-ethereum applications using the trie implementation, 
there's no way to limit iteration to just a subtree with a specific prefix.

  **Usage:**

  ```go
  // Only iterate nodes with prefix "key1"
  iter, err := trie.NodeIteratorWithPrefix([]byte("key1"))
  ```

Testing: Comprehensive test suite covering edge cases and boundary conditions.

Closes #32484

---------

Co-authored-by: gballet <guillaume.ballet@gmail.com>
Co-authored-by: Gary Rong <garyrong0905@gmail.com>
This commit is contained in:
Samuel Arogbonlo 2025-09-17 15:07:02 +01:00 committed by GitHub
parent 21769f3474
commit fda09c7b1b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 652 additions and 0 deletions

View file

@ -836,3 +836,91 @@ func (it *unionIterator) Error() error {
}
return nil
}
// subTreeIterator wraps nodeIterator to traverse a trie within a predefined
// start and limit range.
type subtreeIterator struct {
NodeIterator
stopPath []byte // Precomputed hex path for stopKey (without terminator), nil means no limit
exhausted bool // Flag whether the iterator has been exhausted
}
// newSubtreeIterator creates an iterator that only traverses nodes within a subtree
// defined by the given startKey and stopKey. This supports general range iteration
// where startKey is inclusive and stopKey is exclusive.
//
// The iterator will only visit nodes whose keys k satisfy: startKey <= k < stopKey,
// where comparisons are performed in lexicographic order of byte keys (internally
// implemented via hex-nibble path comparisons for efficiency).
//
// If startKey is nil, iteration starts from the beginning. If stopKey is nil,
// iteration continues to the end of the trie.
func newSubtreeIterator(trie *Trie, startKey, stopKey []byte) (NodeIterator, error) {
it, err := trie.NodeIterator(startKey)
if err != nil {
return nil, err
}
if startKey == nil && stopKey == nil {
return it, nil
}
// Precompute nibble paths for efficient comparison
var stopPath []byte
if stopKey != nil {
stopPath = keybytesToHex(stopKey)
if hasTerm(stopPath) {
stopPath = stopPath[:len(stopPath)-1]
}
}
return &subtreeIterator{
NodeIterator: it,
stopPath: stopPath,
}, nil
}
// nextKey returns the next possible key after the given prefix.
// For example, "abc" -> "abd", "ab\xff" -> "ac", etc.
func nextKey(prefix []byte) []byte {
if len(prefix) == 0 {
return nil
}
// Make a copy to avoid modifying the original
next := make([]byte, len(prefix))
copy(next, prefix)
// Increment the last byte that isn't 0xff
for i := len(next) - 1; i >= 0; i-- {
if next[i] < 0xff {
next[i]++
return next
}
// If it's 0xff, we need to carry over
// Trim trailing 0xff bytes
next = next[:i]
}
// If all bytes were 0xff, return nil (no upper bound)
return nil
}
// newPrefixIterator creates an iterator that only traverses nodes with the given prefix.
// This ensures that only keys starting with the prefix are visited.
func newPrefixIterator(trie *Trie, prefix []byte) (NodeIterator, error) {
return newSubtreeIterator(trie, prefix, nextKey(prefix))
}
// Next moves the iterator to the next node. If the parameter is false, any child
// nodes will be skipped.
func (it *subtreeIterator) Next(descend bool) bool {
if it.exhausted {
return false
}
if !it.NodeIterator.Next(descend) {
it.exhausted = true
return false
}
if it.stopPath != nil && reachedPath(it.NodeIterator.Path(), it.stopPath) {
it.exhausted = true
return false
}
return true
}

View file

@ -19,7 +19,9 @@ package trie
import (
"bytes"
"fmt"
"maps"
"math/rand"
"slices"
"testing"
"github.com/ethereum/go-ethereum/common"
@ -624,6 +626,540 @@ func isTrieNode(scheme string, key, val []byte) (bool, []byte, common.Hash) {
return true, path, hash
}
func TestSubtreeIterator(t *testing.T) {
var (
db = newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme)
tr = NewEmpty(db)
)
vals := []struct{ k, v string }{
{"do", "verb"},
{"dog", "puppy"},
{"doge", "coin"},
{"dog\xff", "value6"},
{"dog\xff\xff", "value7"},
{"horse", "stallion"},
{"house", "building"},
{"houses", "multiple"},
{"xyz", "value"},
{"xyz\xff", "value"},
{"xyz\xff\xff", "value"},
}
all := make(map[string]string)
for _, val := range vals {
all[val.k] = val.v
tr.MustUpdate([]byte(val.k), []byte(val.v))
}
root, nodes := tr.Commit(false)
db.Update(root, types.EmptyRootHash, trienode.NewWithNodeSet(nodes))
allNodes := make(map[string][]byte)
tr, _ = New(TrieID(root), db)
it, err := tr.NodeIterator(nil)
if err != nil {
t.Fatal(err)
}
for it.Next(true) {
allNodes[string(it.Path())] = it.NodeBlob()
}
allKeys := slices.Collect(maps.Keys(all))
suites := []struct {
start []byte
end []byte
expected []string
}{
// entire key range
{
start: nil,
end: nil,
expected: allKeys,
},
{
start: nil,
end: bytes.Repeat([]byte{0xff}, 32),
expected: allKeys,
},
{
start: bytes.Repeat([]byte{0x0}, 32),
end: bytes.Repeat([]byte{0xff}, 32),
expected: allKeys,
},
// key range with start
{
start: []byte("do"),
end: nil,
expected: allKeys,
},
{
start: []byte("doe"),
end: nil,
expected: allKeys[1:],
},
{
start: []byte("dog"),
end: nil,
expected: allKeys[1:],
},
{
start: []byte("doge"),
end: nil,
expected: allKeys[2:],
},
{
start: []byte("dog\xff"),
end: nil,
expected: allKeys[3:],
},
{
start: []byte("dog\xff\xff"),
end: nil,
expected: allKeys[4:],
},
{
start: []byte("dog\xff\xff\xff"),
end: nil,
expected: allKeys[5:],
},
// key range with limit
{
start: nil,
end: []byte("xyz"),
expected: allKeys[:len(allKeys)-3],
},
{
start: nil,
end: []byte("xyz\xff"),
expected: allKeys[:len(allKeys)-2],
},
{
start: nil,
end: []byte("xyz\xff\xff"),
expected: allKeys[:len(allKeys)-1],
},
{
start: nil,
end: []byte("xyz\xff\xff\xff"),
expected: allKeys,
},
}
for _, suite := range suites {
// We need to re-open the trie from the committed state
tr, _ = New(TrieID(root), db)
it, err := newSubtreeIterator(tr, suite.start, suite.end)
if err != nil {
t.Fatal(err)
}
found := make(map[string]string)
for it.Next(true) {
if it.Leaf() {
found[string(it.LeafKey())] = string(it.LeafBlob())
}
}
if len(found) != len(suite.expected) {
t.Errorf("wrong number of values: got %d, want %d", len(found), len(suite.expected))
}
for k, v := range found {
if all[k] != v {
t.Errorf("wrong value for %s: got %s, want %s", k, found[k], all[k])
}
}
expectedNodes := make(map[string][]byte)
for path, blob := range allNodes {
if suite.start != nil {
hexStart := keybytesToHex(suite.start)
hexStart = hexStart[:len(hexStart)-1]
if !reachedPath([]byte(path), hexStart) {
continue
}
}
if suite.end != nil {
hexEnd := keybytesToHex(suite.end)
hexEnd = hexEnd[:len(hexEnd)-1]
if reachedPath([]byte(path), hexEnd) {
continue
}
}
expectedNodes[path] = bytes.Clone(blob)
}
// Compare the result yield from the subtree iterator
var (
subCount int
subIt, _ = newSubtreeIterator(tr, suite.start, suite.end)
)
for subIt.Next(true) {
blob, ok := expectedNodes[string(subIt.Path())]
if !ok {
t.Errorf("Unexpected node iterated, path: %v", subIt.Path())
}
subCount++
if !bytes.Equal(blob, subIt.NodeBlob()) {
t.Errorf("Unexpected node blob, path: %v, want: %v, got: %v", subIt.Path(), blob, subIt.NodeBlob())
}
}
if subCount != len(expectedNodes) {
t.Errorf("Unexpected node being iterated, want: %d, got: %d", len(expectedNodes), subCount)
}
}
}
func TestPrefixIterator(t *testing.T) {
// Create a new trie
trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
// Insert test data
testData := map[string]string{
"key1": "value1",
"key2": "value2",
"key10": "value10",
"key11": "value11",
"different": "value_different",
}
for key, value := range testData {
trie.Update([]byte(key), []byte(value))
}
// Test prefix iteration for "key1" prefix
prefix := []byte("key1")
iter, err := trie.NodeIteratorWithPrefix(prefix)
if err != nil {
t.Fatalf("Failed to create prefix iterator: %v", err)
}
var foundKeys [][]byte
for iter.Next(true) {
if iter.Leaf() {
foundKeys = append(foundKeys, iter.LeafKey())
}
}
if err := iter.Error(); err != nil {
t.Fatalf("Iterator error: %v", err)
}
// Verify only keys starting with "key1" were found
expectedCount := 3 // "key1", "key10", "key11"
if len(foundKeys) != expectedCount {
t.Errorf("Expected %d keys, found %d", expectedCount, len(foundKeys))
}
for _, key := range foundKeys {
keyStr := string(key)
if !bytes.HasPrefix(key, prefix) {
t.Errorf("Found key %s doesn't have prefix %s", keyStr, string(prefix))
}
}
}
func TestPrefixIteratorVsFullIterator(t *testing.T) {
// Create a new trie with more structured data
trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
// Insert structured test data
testData := map[string]string{
"aaa": "value_aaa",
"aab": "value_aab",
"aba": "value_aba",
"bbb": "value_bbb",
}
for key, value := range testData {
trie.Update([]byte(key), []byte(value))
}
// Test that prefix iterator stops at boundary
prefix := []byte("aa")
prefixIter, err := trie.NodeIteratorWithPrefix(prefix)
if err != nil {
t.Fatalf("Failed to create prefix iterator: %v", err)
}
var prefixKeys [][]byte
for prefixIter.Next(true) {
if prefixIter.Leaf() {
prefixKeys = append(prefixKeys, prefixIter.LeafKey())
}
}
// Should only find "aaa" and "aab", not "aba" or "bbb"
if len(prefixKeys) != 2 {
t.Errorf("Expected 2 keys with prefix 'aa', found %d", len(prefixKeys))
}
// Verify no keys outside prefix were found
for _, key := range prefixKeys {
if !bytes.HasPrefix(key, prefix) {
t.Errorf("Prefix iterator returned key %s outside prefix %s", string(key), string(prefix))
}
}
}
func TestEmptyPrefixIterator(t *testing.T) {
// Test with empty trie
trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
iter, err := trie.NodeIteratorWithPrefix([]byte("nonexistent"))
if err != nil {
t.Fatalf("Failed to create iterator: %v", err)
}
if iter.Next(true) {
t.Error("Expected no results from empty trie")
}
}
// TestPrefixIteratorEdgeCases tests various edge cases for prefix iteration
func TestPrefixIteratorEdgeCases(t *testing.T) {
// Create a trie with test data
trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
testData := map[string]string{
"abc": "value1",
"abcd": "value2",
"abce": "value3",
"abd": "value4",
"dog": "value5",
"dog\xff": "value6", // Test with 0xff byte
"dog\xff\xff": "value7", // Multiple 0xff bytes
}
for key, value := range testData {
trie.Update([]byte(key), []byte(value))
}
// Test 1: Prefix not present in trie
t.Run("NonexistentPrefix", func(t *testing.T) {
iter, err := trie.NodeIteratorWithPrefix([]byte("xyz"))
if err != nil {
t.Fatalf("Failed to create iterator: %v", err)
}
count := 0
for iter.Next(true) {
if iter.Leaf() {
count++
}
}
if count != 0 {
t.Errorf("Expected 0 results for nonexistent prefix, got %d", count)
}
})
// Test 2: Prefix exactly equals an existing key
t.Run("ExactKeyPrefix", func(t *testing.T) {
iter, err := trie.NodeIteratorWithPrefix([]byte("abc"))
if err != nil {
t.Fatalf("Failed to create iterator: %v", err)
}
found := make(map[string]bool)
for iter.Next(true) {
if iter.Leaf() {
found[string(iter.LeafKey())] = true
}
}
// Should find "abc", "abcd", "abce" but not "abd"
if !found["abc"] || !found["abcd"] || !found["abce"] {
t.Errorf("Missing expected keys: got %v", found)
}
if found["abd"] {
t.Errorf("Found unexpected key 'abd' with prefix 'abc'")
}
})
// Test 3: Prefix with trailing 0xff
t.Run("TrailingFFPrefix", func(t *testing.T) {
iter, err := trie.NodeIteratorWithPrefix([]byte("dog\xff"))
if err != nil {
t.Fatalf("Failed to create iterator: %v", err)
}
found := make(map[string]bool)
for iter.Next(true) {
if iter.Leaf() {
found[string(iter.LeafKey())] = true
}
}
// Should find "dog\xff" and "dog\xff\xff"
if !found["dog\xff"] || !found["dog\xff\xff"] {
t.Errorf("Missing expected keys with 0xff: got %v", found)
}
if found["dog"] {
t.Errorf("Found unexpected key 'dog' with prefix 'dog\\xff'")
}
})
// Test 4: All 0xff case (edge case for nextKey)
t.Run("AllFFPrefix", func(t *testing.T) {
// Add a key with all 0xff bytes
allFF := []byte{0xff, 0xff}
trie.Update(allFF, []byte("all_ff_value"))
trie.Update(append(allFF, 0x00), []byte("all_ff_plus"))
iter, err := trie.NodeIteratorWithPrefix(allFF)
if err != nil {
t.Fatalf("Failed to create iterator: %v", err)
}
count := 0
for iter.Next(true) {
if iter.Leaf() {
count++
}
}
// Should find at least the allFF key itself
if count != 2 {
t.Errorf("Expected at least 1 result for all-0xff prefix, got %d", count)
}
})
// Test 5: Empty prefix (should iterate entire trie)
t.Run("EmptyPrefix", func(t *testing.T) {
iter, err := trie.NodeIteratorWithPrefix([]byte{})
if err != nil {
t.Fatalf("Failed to create iterator: %v", err)
}
count := 0
for iter.Next(true) {
if iter.Leaf() {
count++
}
}
// Should find all keys in the trie
expectedCount := len(testData) + 2 // +2 for the extra keys added in test 4
if count != expectedCount {
t.Errorf("Expected %d results for empty prefix, got %d", expectedCount, count)
}
})
}
// TestGeneralRangeIteration tests NewSubtreeIterator with arbitrary start/stop ranges
func TestGeneralRangeIteration(t *testing.T) {
// Create a trie with test data
trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
testData := map[string]string{
"apple": "fruit1",
"apricot": "fruit2",
"banana": "fruit3",
"cherry": "fruit4",
"date": "fruit5",
"fig": "fruit6",
"grape": "fruit7",
}
for key, value := range testData {
trie.Update([]byte(key), []byte(value))
}
// Test range iteration from "banana" to "fig" (exclusive)
t.Run("RangeIteration", func(t *testing.T) {
iter, _ := newSubtreeIterator(trie, []byte("banana"), []byte("fig"))
found := make(map[string]bool)
for iter.Next(true) {
if iter.Leaf() {
found[string(iter.LeafKey())] = true
}
}
// Should find "banana", "cherry", "date" but not "fig"
if !found["banana"] || !found["cherry"] || !found["date"] {
t.Errorf("Missing expected keys in range: got %v", found)
}
if found["apple"] || found["apricot"] || found["fig"] || found["grape"] {
t.Errorf("Found unexpected keys outside range: got %v", found)
}
})
// Test with nil stopKey (iterate to end)
t.Run("NilStopKey", func(t *testing.T) {
iter, _ := newSubtreeIterator(trie, []byte("date"), nil)
found := make(map[string]bool)
for iter.Next(true) {
if iter.Leaf() {
found[string(iter.LeafKey())] = true
}
}
// Should find "date", "fig", "grape"
if !found["date"] || !found["fig"] || !found["grape"] {
t.Errorf("Missing expected keys from 'date' to end: got %v", found)
}
if found["apple"] || found["banana"] || found["cherry"] {
t.Errorf("Found unexpected keys before 'date': got %v", found)
}
})
// Test with nil startKey (iterate from beginning)
t.Run("NilStartKey", func(t *testing.T) {
iter, _ := newSubtreeIterator(trie, nil, []byte("cherry"))
found := make(map[string]bool)
for iter.Next(true) {
if iter.Leaf() {
found[string(iter.LeafKey())] = true
}
}
// Should find "apple", "apricot", "banana" but not "cherry" or later
if !found["apple"] || !found["apricot"] || !found["banana"] {
t.Errorf("Missing expected keys before 'cherry': got %v", found)
}
if found["cherry"] || found["date"] || found["fig"] || found["grape"] {
t.Errorf("Found unexpected keys at or after 'cherry': got %v", found)
}
})
}
// TestPrefixIteratorWithDescend tests prefix iteration with descend=false
func TestPrefixIteratorWithDescend(t *testing.T) {
// Create a trie with nested structure
trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
testData := map[string]string{
"a": "value_a",
"a/b": "value_ab",
"a/b/c": "value_abc",
"a/b/d": "value_abd",
"a/e": "value_ae",
"b": "value_b",
}
for key, value := range testData {
trie.Update([]byte(key), []byte(value))
}
// Test skipping subtrees with descend=false
t.Run("SkipSubtrees", func(t *testing.T) {
iter, err := trie.NodeIteratorWithPrefix([]byte("a"))
if err != nil {
t.Fatalf("Failed to create iterator: %v", err)
}
// Count nodes at each level
nodesVisited := 0
leafsFound := make(map[string]bool)
// First call with descend=true to enter the "a" subtree
if !iter.Next(true) {
t.Fatal("Expected to find at least one node")
}
nodesVisited++
// Continue iteration, sometimes with descend=false
descendPattern := []bool{false, true, false, true, true}
for i := 0; iter.Next(descendPattern[i%len(descendPattern)]); i++ {
nodesVisited++
if iter.Leaf() {
leafsFound[string(iter.LeafKey())] = true
}
}
// We should still respect the prefix boundary even when skipping
prefix := []byte("a")
for key := range leafsFound {
if !bytes.HasPrefix([]byte(key), prefix) {
t.Errorf("Found key outside prefix when using descend=false: %s", key)
}
}
// Should not have found "b" even if we skip some subtrees
if leafsFound["b"] {
t.Error("Iterator leaked outside prefix boundary with descend=false")
}
})
}
func BenchmarkIterator(b *testing.B) {
diskDb, srcDb, tr, _ := makeTestTrie(rawdb.HashScheme)
root := tr.Hash()

View file

@ -134,6 +134,34 @@ func (t *Trie) NodeIterator(start []byte) (NodeIterator, error) {
return newNodeIterator(t, start), nil
}
// NodeIteratorWithPrefix returns an iterator that returns nodes of the trie
// whose leaf keys start with the given prefix. Iteration includes all keys
// where prefix <= k < nextKey(prefix), effectively returning only keys that
// have the prefix. The iteration stops once it would encounter a key that
// doesn't start with the prefix.
//
// For example, with prefix "dog", the iterator will return "dog", "dogcat",
// "dogfish" but not "dot" or "fog". An empty prefix iterates the entire trie.
func (t *Trie) NodeIteratorWithPrefix(prefix []byte) (NodeIterator, error) {
// Short circuit if the trie is already committed and not usable.
if t.committed {
return nil, ErrCommitted
}
// Use the dedicated prefix iterator which handles prefix checking correctly
return newPrefixIterator(t, prefix)
}
// NodeIteratorWithRange returns an iterator over trie nodes whose leaf keys
// fall within the specified range. It includes all keys where start <= k < end.
// Iteration stops once a key beyond the end boundary is encountered.
func (t *Trie) NodeIteratorWithRange(start, end []byte) (NodeIterator, error) {
// Short circuit if the trie is already committed and not usable.
if t.committed {
return nil, ErrCommitted
}
return newSubtreeIterator(t, start, end)
}
// MustGet is a wrapper of Get and will omit any encountered error but just
// print out an error message.
func (t *Trie) MustGet(key []byte) []byte {