diff --git a/trie/iterator.go b/trie/iterator.go
index 80298ce48f..3d3191ffba 100644
--- a/trie/iterator.go
+++ b/trie/iterator.go
@@ -836,3 +836,91 @@ func (it *unionIterator) Error() error {
 	}
 	return nil
 }
+
+// subtreeIterator wraps a NodeIterator to traverse a trie within a predefined
+// start and limit range.
+type subtreeIterator struct {
+	NodeIterator
+
+	stopPath  []byte // Precomputed hex path for stopKey (without terminator), nil means no limit
+	exhausted bool   // Flag whether the iterator has been exhausted
+}
+
+// newSubtreeIterator creates an iterator that only traverses nodes within a subtree
+// defined by the given startKey and stopKey. This supports general range iteration
+// where startKey is inclusive and stopKey is exclusive.
+//
+// The iterator will only visit nodes whose keys k satisfy: startKey <= k < stopKey,
+// where comparisons are performed in lexicographic order of byte keys (internally
+// implemented via hex-nibble path comparisons for efficiency).
+//
+// If startKey is nil, iteration starts from the beginning. If stopKey is nil,
+// iteration continues to the end of the trie.
+func newSubtreeIterator(trie *Trie, startKey, stopKey []byte) (NodeIterator, error) {
+	it, err := trie.NodeIterator(startKey)
+	if err != nil {
+		return nil, err
+	}
+	if startKey == nil && stopKey == nil {
+		return it, nil
+	}
+	// Precompute nibble paths for efficient comparison
+	var stopPath []byte
+	if stopKey != nil {
+		stopPath = keybytesToHex(stopKey)
+		if hasTerm(stopPath) {
+			stopPath = stopPath[:len(stopPath)-1]
+		}
+	}
+	return &subtreeIterator{
+		NodeIterator: it,
+		stopPath:     stopPath,
+	}, nil
+}
+
+// nextKey returns the next possible key after the given prefix.
+// For example, "abc" -> "abd", "ab\xff" -> "ac", etc.
+func nextKey(prefix []byte) []byte {
+	if len(prefix) == 0 {
+		return nil
+	}
+	// Make a copy to avoid modifying the original
+	next := make([]byte, len(prefix))
+	copy(next, prefix)
+
+	// Increment the last byte that isn't 0xff
+	for i := len(next) - 1; i >= 0; i-- {
+		if next[i] < 0xff {
+			next[i]++
+			return next
+		}
+		// If it's 0xff, we need to carry over
+		// Trim trailing 0xff bytes
+		next = next[:i]
+	}
+	// If all bytes were 0xff, return nil (no upper bound)
+	return nil
+}
+
+// newPrefixIterator creates an iterator that only traverses nodes with the given prefix.
+// This ensures that only keys starting with the prefix are visited.
+func newPrefixIterator(trie *Trie, prefix []byte) (NodeIterator, error) {
+	return newSubtreeIterator(trie, prefix, nextKey(prefix))
+}
+
+// Next moves the iterator to the next node. If the parameter is false, any child
+// nodes will be skipped.
+func (it *subtreeIterator) Next(descend bool) bool {
+	if it.exhausted {
+		return false
+	}
+	if !it.NodeIterator.Next(descend) {
+		it.exhausted = true
+		return false
+	}
+	if it.stopPath != nil && reachedPath(it.NodeIterator.Path(), it.stopPath) {
+		it.exhausted = true
+		return false
+	}
+	return true
+}
diff --git a/trie/iterator_test.go b/trie/iterator_test.go
index 74a1aa378c..f1451cef90 100644
--- a/trie/iterator_test.go
+++ b/trie/iterator_test.go
@@ -19,7 +19,9 @@ package trie
 import (
 	"bytes"
 	"fmt"
+	"maps"
 	"math/rand"
+	"slices"
 	"testing"
 
 	"github.com/ethereum/go-ethereum/common"
@@ -624,6 +626,550 @@ func isTrieNode(scheme string, key, val []byte) (bool, []byte, common.Hash) {
 	return true, path, hash
 }
 
+func TestSubtreeIterator(t *testing.T) {
+	var (
+		db = newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme)
+		tr = NewEmpty(db)
+	)
+	vals := []struct{ k, v string }{
+		{"do", "verb"},
+		{"dog", "puppy"},
+		{"doge", "coin"},
+		{"dog\xff", "value6"},
+		{"dog\xff\xff", "value7"},
+		{"horse", "stallion"},
+		{"house", "building"},
+		{"houses", "multiple"},
+		{"xyz", "value"},
+		{"xyz\xff", "value"},
+		{"xyz\xff\xff", "value"},
+	}
+	all := make(map[string]string)
+	for _, val := range vals {
+		all[val.k] = val.v
+		tr.MustUpdate([]byte(val.k), []byte(val.v))
+	}
+	root, nodes := tr.Commit(false)
+	db.Update(root, types.EmptyRootHash, trienode.NewWithNodeSet(nodes))
+
+	allNodes := make(map[string][]byte)
+	tr, _ = New(TrieID(root), db)
+	it, err := tr.NodeIterator(nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	for it.Next(true) {
+		allNodes[string(it.Path())] = it.NodeBlob()
+	}
+	allKeys := slices.Sorted(maps.Keys(all))
+
+	suites := []struct {
+		start    []byte
+		end      []byte
+		expected []string
+	}{
+		// entire key range
+		{
+			start:    nil,
+			end:      nil,
+			expected: allKeys,
+		},
+		{
+			start:    nil,
+			end:      bytes.Repeat([]byte{0xff}, 32),
+			expected: allKeys,
+		},
+		{
+			start:    bytes.Repeat([]byte{0x0}, 32),
+			end:      bytes.Repeat([]byte{0xff}, 32),
+			expected: allKeys,
+		},
+		// key range with start
+		{
+			start:    []byte("do"),
+			end:      nil,
+			expected: allKeys,
+		},
+		{
+			start:    []byte("doe"),
+			end:      nil,
+			expected: allKeys[1:],
+		},
+		{
+			start:    []byte("dog"),
+			end:      nil,
+			expected: allKeys[1:],
+		},
+		{
+			start:    []byte("doge"),
+			end:      nil,
+			expected: allKeys[2:],
+		},
+		{
+			start:    []byte("dog\xff"),
+			end:      nil,
+			expected: allKeys[3:],
+		},
+		{
+			start:    []byte("dog\xff\xff"),
+			end:      nil,
+			expected: allKeys[4:],
+		},
+		{
+			start:    []byte("dog\xff\xff\xff"),
+			end:      nil,
+			expected: allKeys[5:],
+		},
+		// key range with limit
+		{
+			start:    nil,
+			end:      []byte("xyz"),
+			expected: allKeys[:len(allKeys)-3],
+		},
+		{
+			start:    nil,
+			end:      []byte("xyz\xff"),
+			expected: allKeys[:len(allKeys)-2],
+		},
+		{
+			start:    nil,
+			end:      []byte("xyz\xff\xff"),
+			expected: allKeys[:len(allKeys)-1],
+		},
+		{
+			start:    nil,
+			end:      []byte("xyz\xff\xff\xff"),
+			expected: allKeys,
+		},
+	}
+	for _, suite := range suites {
+		// We need to re-open the trie from the committed state
+		tr, _ = New(TrieID(root), db)
+		it, err := newSubtreeIterator(tr, suite.start, suite.end)
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		found := make(map[string]string)
+		for it.Next(true) {
+			if it.Leaf() {
+				found[string(it.LeafKey())] = string(it.LeafBlob())
+			}
+		}
+		if len(found) != len(suite.expected) {
+			t.Errorf("wrong number of values: got %d, want %d", len(found), len(suite.expected))
+		}
+		for k, v := range found {
+			if all[k] != v {
+				t.Errorf("wrong value for %s: got %s, want %s", k, found[k], all[k])
+			}
+		}
+
+		expectedNodes := make(map[string][]byte)
+		for path, blob := range allNodes {
+			if suite.start != nil {
+				hexStart := keybytesToHex(suite.start)
+				hexStart = hexStart[:len(hexStart)-1]
+				if !reachedPath([]byte(path), hexStart) {
+					continue
+				}
+			}
+			if suite.end != nil {
+				hexEnd := keybytesToHex(suite.end)
+				hexEnd = hexEnd[:len(hexEnd)-1]
+				if reachedPath([]byte(path), hexEnd) {
+					continue
+				}
+			}
+			expectedNodes[path] = bytes.Clone(blob)
+		}
+
+		// Compare the result yielded from the subtree iterator
+		subIt, err := newSubtreeIterator(tr, suite.start, suite.end)
+		if err != nil {
+			t.Fatal(err)
+		}
+		var subCount int
+		for subIt.Next(true) {
+			blob, ok := expectedNodes[string(subIt.Path())]
+			if !ok {
+				t.Errorf("Unexpected node iterated, path: %v", subIt.Path())
+			}
+			subCount++
+
+			if !bytes.Equal(blob, subIt.NodeBlob()) {
+				t.Errorf("Unexpected node blob, path: %v, want: %v, got: %v", subIt.Path(), blob, subIt.NodeBlob())
+			}
+		}
+		if subCount != len(expectedNodes) {
+			t.Errorf("Unexpected node being iterated, want: %d, got: %d", len(expectedNodes), subCount)
+		}
+	}
+}
+
+func TestPrefixIterator(t *testing.T) {
+	// Create a new trie
+	trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
+
+	// Insert test data
+	testData := map[string]string{
+		"key1":      "value1",
+		"key2":      "value2",
+		"key10":     "value10",
+		"key11":     "value11",
+		"different": "value_different",
+	}
+
+	for key, value := range testData {
+		trie.MustUpdate([]byte(key), []byte(value))
+	}
+
+	// Test prefix iteration for "key1" prefix
+	prefix := []byte("key1")
+	iter, err := trie.NodeIteratorWithPrefix(prefix)
+	if err != nil {
+		t.Fatalf("Failed to create prefix iterator: %v", err)
+	}
+
+	var foundKeys [][]byte
+	for iter.Next(true) {
+		if iter.Leaf() {
+			foundKeys = append(foundKeys, iter.LeafKey())
+		}
+	}
+
+	if err := iter.Error(); err != nil {
+		t.Fatalf("Iterator error: %v", err)
+	}
+
+	// Verify only keys starting with "key1" were found
+	expectedCount := 3 // "key1", "key10", "key11"
+	if len(foundKeys) != expectedCount {
+		t.Errorf("Expected %d keys, found %d", expectedCount, len(foundKeys))
+	}
+
+	for _, key := range foundKeys {
+		keyStr := string(key)
+		if !bytes.HasPrefix(key, prefix) {
+			t.Errorf("Found key %s doesn't have prefix %s", keyStr, string(prefix))
+		}
+	}
+}
+
+func TestPrefixIteratorVsFullIterator(t *testing.T) {
+	// Create a new trie with more structured data
+	trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
+
+	// Insert structured test data
+	testData := map[string]string{
+		"aaa": "value_aaa",
+		"aab": "value_aab",
+		"aba": "value_aba",
+		"bbb": "value_bbb",
+	}
+
+	for key, value := range testData {
+		trie.MustUpdate([]byte(key), []byte(value))
+	}
+
+	// Test that prefix iterator stops at boundary
+	prefix := []byte("aa")
+	prefixIter, err := trie.NodeIteratorWithPrefix(prefix)
+	if err != nil {
+		t.Fatalf("Failed to create prefix iterator: %v", err)
+	}
+
+	var prefixKeys [][]byte
+	for prefixIter.Next(true) {
+		if prefixIter.Leaf() {
+			prefixKeys = append(prefixKeys, prefixIter.LeafKey())
+		}
+	}
+
+	// Should only find "aaa" and "aab", not "aba" or "bbb"
+	if len(prefixKeys) != 2 {
+		t.Errorf("Expected 2 keys with prefix 'aa', found %d", len(prefixKeys))
+	}
+
+	// Verify no keys outside prefix were found
+	for _, key := range prefixKeys {
+		if !bytes.HasPrefix(key, prefix) {
+			t.Errorf("Prefix iterator returned key %s outside prefix %s", string(key), string(prefix))
+		}
+	}
+}
+
+func TestEmptyPrefixIterator(t *testing.T) {
+	// Test with empty trie
+	trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
+
+	iter, err := trie.NodeIteratorWithPrefix([]byte("nonexistent"))
+	if err != nil {
+		t.Fatalf("Failed to create iterator: %v", err)
+	}
+
+	if iter.Next(true) {
+		t.Error("Expected no results from empty trie")
+	}
+}
+
+// TestPrefixIteratorEdgeCases tests various edge cases for prefix iteration
+func TestPrefixIteratorEdgeCases(t *testing.T) {
+	// Create a trie with test data
+	trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
+	testData := map[string]string{
+		"abc":         "value1",
+		"abcd":        "value2",
+		"abce":        "value3",
+		"abd":         "value4",
+		"dog":         "value5",
+		"dog\xff":     "value6", // Test with 0xff byte
+		"dog\xff\xff": "value7", // Multiple 0xff bytes
+	}
+	for key, value := range testData {
+		trie.MustUpdate([]byte(key), []byte(value))
+	}
+
+	// Test 1: Prefix not present in trie
+	t.Run("NonexistentPrefix", func(t *testing.T) {
+		iter, err := trie.NodeIteratorWithPrefix([]byte("xyz"))
+		if err != nil {
+			t.Fatalf("Failed to create iterator: %v", err)
+		}
+		count := 0
+		for iter.Next(true) {
+			if iter.Leaf() {
+				count++
+			}
+		}
+		if count != 0 {
+			t.Errorf("Expected 0 results for nonexistent prefix, got %d", count)
+		}
+	})
+
+	// Test 2: Prefix exactly equals an existing key
+	t.Run("ExactKeyPrefix", func(t *testing.T) {
+		iter, err := trie.NodeIteratorWithPrefix([]byte("abc"))
+		if err != nil {
+			t.Fatalf("Failed to create iterator: %v", err)
+		}
+		found := make(map[string]bool)
+		for iter.Next(true) {
+			if iter.Leaf() {
+				found[string(iter.LeafKey())] = true
+			}
+		}
+		// Should find "abc", "abcd", "abce" but not "abd"
+		if !found["abc"] || !found["abcd"] || !found["abce"] {
+			t.Errorf("Missing expected keys: got %v", found)
+		}
+		if found["abd"] {
+			t.Errorf("Found unexpected key 'abd' with prefix 'abc'")
+		}
+	})
+
+	// Test 3: Prefix with trailing 0xff
+	t.Run("TrailingFFPrefix", func(t *testing.T) {
+		iter, err := trie.NodeIteratorWithPrefix([]byte("dog\xff"))
+		if err != nil {
+			t.Fatalf("Failed to create iterator: %v", err)
+		}
+		found := make(map[string]bool)
+		for iter.Next(true) {
+			if iter.Leaf() {
+				found[string(iter.LeafKey())] = true
+			}
+		}
+		// Should find "dog\xff" and "dog\xff\xff"
+		if !found["dog\xff"] || !found["dog\xff\xff"] {
+			t.Errorf("Missing expected keys with 0xff: got %v", found)
+		}
+		if found["dog"] {
+			t.Errorf("Found unexpected key 'dog' with prefix 'dog\\xff'")
+		}
+	})
+
+	// Test 4: All 0xff case (edge case for nextKey)
+	t.Run("AllFFPrefix", func(t *testing.T) {
+		// Add a key with all 0xff bytes
+		allFF := []byte{0xff, 0xff}
+		trie.MustUpdate(allFF, []byte("all_ff_value"))
+		trie.MustUpdate(append(allFF, 0x00), []byte("all_ff_plus"))
+
+		iter, err := trie.NodeIteratorWithPrefix(allFF)
+		if err != nil {
+			t.Fatalf("Failed to create iterator: %v", err)
+		}
+		count := 0
+		for iter.Next(true) {
+			if iter.Leaf() {
+				count++
+			}
+		}
+		// Should find both the allFF key and its 0x00-suffixed extension
+		if count != 2 {
+			t.Errorf("Expected 2 results for all-0xff prefix, got %d", count)
+		}
+	})
+
+	// Test 5: Empty prefix (should iterate entire trie)
+	t.Run("EmptyPrefix", func(t *testing.T) {
+		iter, err := trie.NodeIteratorWithPrefix([]byte{})
+		if err != nil {
+			t.Fatalf("Failed to create iterator: %v", err)
+		}
+		count := 0
+		for iter.Next(true) {
+			if iter.Leaf() {
+				count++
+			}
+		}
+		// Should find all keys in the trie
+		expectedCount := len(testData) + 2 // +2 for the extra keys added in test 4
+		if count != expectedCount {
+			t.Errorf("Expected %d results for empty prefix, got %d", expectedCount, count)
+		}
+	})
+}
+
+// TestGeneralRangeIteration tests newSubtreeIterator with arbitrary start/stop ranges
+func TestGeneralRangeIteration(t *testing.T) {
+	// Create a trie with test data
+	trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
+	testData := map[string]string{
+		"apple":   "fruit1",
+		"apricot": "fruit2",
+		"banana":  "fruit3",
+		"cherry":  "fruit4",
+		"date":    "fruit5",
+		"fig":     "fruit6",
+		"grape":   "fruit7",
+	}
+	for key, value := range testData {
+		trie.MustUpdate([]byte(key), []byte(value))
+	}
+
+	// Test range iteration from "banana" to "fig" (exclusive)
+	t.Run("RangeIteration", func(t *testing.T) {
+		iter, err := newSubtreeIterator(trie, []byte("banana"), []byte("fig"))
+		if err != nil {
+			t.Fatalf("Failed to create iterator: %v", err)
+		}
+		found := make(map[string]bool)
+		for iter.Next(true) {
+			if iter.Leaf() {
+				found[string(iter.LeafKey())] = true
+			}
+		}
+		// Should find "banana", "cherry", "date" but not "fig"
+		if !found["banana"] || !found["cherry"] || !found["date"] {
+			t.Errorf("Missing expected keys in range: got %v", found)
+		}
+		if found["apple"] || found["apricot"] || found["fig"] || found["grape"] {
+			t.Errorf("Found unexpected keys outside range: got %v", found)
+		}
+	})
+
+	// Test with nil stopKey (iterate to end)
+	t.Run("NilStopKey", func(t *testing.T) {
+		iter, err := newSubtreeIterator(trie, []byte("date"), nil)
+		if err != nil {
+			t.Fatalf("Failed to create iterator: %v", err)
+		}
+		found := make(map[string]bool)
+		for iter.Next(true) {
+			if iter.Leaf() {
+				found[string(iter.LeafKey())] = true
+			}
+		}
+		// Should find "date", "fig", "grape"
+		if !found["date"] || !found["fig"] || !found["grape"] {
+			t.Errorf("Missing expected keys from 'date' to end: got %v", found)
+		}
+		if found["apple"] || found["banana"] || found["cherry"] {
+			t.Errorf("Found unexpected keys before 'date': got %v", found)
+		}
+	})
+
+	// Test with nil startKey (iterate from beginning)
+	t.Run("NilStartKey", func(t *testing.T) {
+		iter, err := newSubtreeIterator(trie, nil, []byte("cherry"))
+		if err != nil {
+			t.Fatalf("Failed to create iterator: %v", err)
+		}
+		found := make(map[string]bool)
+		for iter.Next(true) {
+			if iter.Leaf() {
+				found[string(iter.LeafKey())] = true
+			}
+		}
+		// Should find "apple", "apricot", "banana" but not "cherry" or later
+		if !found["apple"] || !found["apricot"] || !found["banana"] {
+			t.Errorf("Missing expected keys before 'cherry': got %v", found)
+		}
+		if found["cherry"] || found["date"] || found["fig"] || found["grape"] {
+			t.Errorf("Found unexpected keys at or after 'cherry': got %v", found)
+		}
+	})
+}
+
+// TestPrefixIteratorWithDescend tests prefix iteration with descend=false
+func TestPrefixIteratorWithDescend(t *testing.T) {
+	// Create a trie with nested structure
+	trie := NewEmpty(newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.HashScheme))
+	testData := map[string]string{
+		"a":     "value_a",
+		"a/b":   "value_ab",
+		"a/b/c": "value_abc",
+		"a/b/d": "value_abd",
+		"a/e":   "value_ae",
+		"b":     "value_b",
+	}
+	for key, value := range testData {
+		trie.MustUpdate([]byte(key), []byte(value))
+	}
+
+	// Test skipping subtrees with descend=false
+	t.Run("SkipSubtrees", func(t *testing.T) {
+		iter, err := trie.NodeIteratorWithPrefix([]byte("a"))
+		if err != nil {
+			t.Fatalf("Failed to create iterator: %v", err)
+		}
+
+		// Count nodes at each level
+		nodesVisited := 0
+		leafsFound := make(map[string]bool)
+
+		// First call with descend=true to enter the "a" subtree
+		if !iter.Next(true) {
+			t.Fatal("Expected to find at least one node")
+		}
+		nodesVisited++
+
+		// Continue iteration, sometimes with descend=false
+		descendPattern := []bool{false, true, false, true, true}
+		for i := 0; iter.Next(descendPattern[i%len(descendPattern)]); i++ {
+			nodesVisited++
+			if iter.Leaf() {
+				leafsFound[string(iter.LeafKey())] = true
+			}
+		}
+
+		// We should still respect the prefix boundary even when skipping
+		prefix := []byte("a")
+		for key := range leafsFound {
+			if !bytes.HasPrefix([]byte(key), prefix) {
+				t.Errorf("Found key outside prefix when using descend=false: %s", key)
+			}
+		}
+
+		// Should not have found "b" even if we skip some subtrees
+		if leafsFound["b"] {
+			t.Error("Iterator leaked outside prefix boundary with descend=false")
+		}
+	})
+}
+
 func BenchmarkIterator(b *testing.B) {
 	diskDb, srcDb, tr, _ := makeTestTrie(rawdb.HashScheme)
 	root := tr.Hash()
diff --git a/trie/trie.go b/trie/trie.go
index 630462f8ca..36cc732ee8 100644
--- a/trie/trie.go
+++ b/trie/trie.go
@@ -134,6 +134,34 @@ func (t *Trie) NodeIterator(start []byte) (NodeIterator, error) {
 	return newNodeIterator(t, start), nil
 }
 
+// NodeIteratorWithPrefix returns an iterator that returns nodes of the trie
+// whose leaf keys start with the given prefix. Iteration includes all keys
+// where prefix <= k < nextKey(prefix), effectively returning only keys that
+// have the prefix. The iteration stops once it would encounter a key that
+// doesn't start with the prefix.
+//
+// For example, with prefix "dog", the iterator will return "dog", "dogcat",
+// "dogfish" but not "dot" or "fog". An empty prefix iterates the entire trie.
+func (t *Trie) NodeIteratorWithPrefix(prefix []byte) (NodeIterator, error) {
+	// Short circuit if the trie is already committed and not usable.
+	if t.committed {
+		return nil, ErrCommitted
+	}
+	// Use the dedicated prefix iterator which handles prefix checking correctly
+	return newPrefixIterator(t, prefix)
+}
+
+// NodeIteratorWithRange returns an iterator over trie nodes whose leaf keys
+// fall within the specified range. It includes all keys where start <= k < end.
+// Iteration stops once a key beyond the end boundary is encountered.
+func (t *Trie) NodeIteratorWithRange(start, end []byte) (NodeIterator, error) {
+	// Short circuit if the trie is already committed and not usable.
+	if t.committed {
+		return nil, ErrCommitted
+	}
+	return newSubtreeIterator(t, start, end)
+}
+
 // MustGet is a wrapper of Get and will omit any encountered error but just
 // print out an error message.
 func (t *Trie) MustGet(key []byte) []byte {