diff --git a/triedb/generate.go b/triedb/generate.go
index 259e139848..b61c4d2096 100644
--- a/triedb/generate.go
+++ b/triedb/generate.go
@@ -17,14 +17,38 @@
package triedb
import (
+ "bytes"
+ "context"
"fmt"
+ "math/big"
+ "runtime"
+ "sync"
+ "sync/atomic"
+ "time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
+ "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/ethdb"
+ "github.com/ethereum/go-ethereum/ethdb/memorydb"
+ "github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/trie"
"github.com/ethereum/go-ethereum/triedb/internal"
+ "golang.org/x/sync/errgroup"
)
+// ErrCancelled is returned when GenerateTrie is aborted via its cancel
+// channel before completing.
+var ErrCancelled = internal.ErrCancelled
+
+// updateStorageRootsProgressPrefix is the key prefix used to persist a
+// per-partition progress marker during updateStorageRoots.
+var updateStorageRootsProgressPrefix = []byte("triedb-updsr-")
+
+func updateStorageRootsProgressKey(partition int) []byte {
+ return append(updateStorageRootsProgressPrefix, byte(partition))
+}
+
// kvAccountIterator wraps an ethdb.Iterator to iterate over account snapshot
// entries in the database, implementing internal.AccountIterator.
type kvAccountIterator struct {
@@ -80,24 +104,233 @@ func (it *kvStorageIterator) Slot() []byte { return it.it.Value() }
func (it *kvStorageIterator) Error() error { return it.it.Error() }
func (it *kvStorageIterator) Release() { it.it.Release() }
-// GenerateTrie rebuilds all tries (storage + account) from flat snapshot data
-// in the database. It reads account and storage snapshots from the KV store,
-// builds tries using StackTrie with streaming node writes, and verifies the
-// computed state root matches the expected root.
-func GenerateTrie(db ethdb.Database, scheme string, root common.Hash) error {
+// rangeIterators bundles the per-partition account and storage iterators.
+type rangeIterators struct {
+ db ethdb.Database
+ acct *internal.HoldableIterator
+ stor *internal.HoldableIterator
+}
+
+func openRangeIterators(db ethdb.Database, start common.Hash) *rangeIterators {
+ return &rangeIterators{
+ db: db,
+ acct: openFlatIterator(db, rawdb.SnapshotAccountPrefix, start[:], common.HashLength),
+ stor: openFlatIterator(db, rawdb.SnapshotStoragePrefix, start[:], 2*common.HashLength),
+ }
+}
+
+// reopen releases both iterators and reopens them at their current
+// positions. Invoked after each batch flush so pebble compactions aren't
+// blocked by long-lived iterator snapshots. Follows the same pattern as
+// triedb/pathdb/context.go.
+func (r *rangeIterators) reopen() {
+ r.acct = reopenFlatIterator(r.db, r.acct, rawdb.SnapshotAccountPrefix, common.HashLength)
+ r.stor = reopenFlatIterator(r.db, r.stor, rawdb.SnapshotStoragePrefix, 2*common.HashLength)
+}
+
+func (r *rangeIterators) release() {
+ r.acct.Release()
+ r.stor.Release()
+}
+
+// openFlatIterator opens a length-filtered HoldableIterator over a snapshot
+// prefix, seeked to the given start key (relative to the prefix).
+func openFlatIterator(db ethdb.Database, prefix, start []byte, suffixLen int) *internal.HoldableIterator {
+ it := db.NewIterator(prefix, start)
+ return internal.NewHoldableIterator(rawdb.NewKeyLengthIterator(it, len(prefix)+suffixLen))
+}
+
+// reopenFlatIterator releases `old` and returns a new HoldableIterator
+// positioned at the same key, or an empty iterator if `old` is exhausted.
+func reopenFlatIterator(db ethdb.Database, old *internal.HoldableIterator, prefix []byte, suffixLen int) *internal.HoldableIterator {
+ if !old.Next() {
+ old.Release()
+ return internal.NewHoldableIterator(memorydb.New().NewIterator(nil, nil))
+ }
+ next := old.Key()
+ old.Release()
+ return openFlatIterator(db, prefix, next[len(prefix):], suffixLen)
+}
+
+// updateStorageRoots walks flat-state accounts and updates each account's
+// Root to match the storage root computed from its flat storage slots.
+func updateStorageRoots(db ethdb.Database, cancel <-chan struct{}) error {
+ start := time.Now()
+ threads := runtime.NumCPU()
+ var (
+ batchMu sync.Mutex
+ batch = db.NewBatch()
+ scanned atomic.Int64
+ updated atomic.Int64
+ )
+ eg, ctx := errgroup.WithContext(context.Background())
+
+ // Spawn one worker per hash-space partition. Each walker handles its
+ // [rangeStart, rangeEnd] slice independently. errgroup cancels ctx
+ // on the first error so peers exit.
+ for i, r := range hashRanges(threads) {
+ partition := i
+ rangeStart, rangeEnd := r[0], r[1]
+ eg.Go(func() error {
+ return updateStorageRootsInRange(ctx, cancel, db, partition, rangeStart, rangeEnd, &batchMu, batch, &scanned, &updated)
+ })
+ }
+ if err := eg.Wait(); err != nil {
+ return err
+ }
+
+ // Clean up the progress markers now that every partition has finished
+ // successfully.
+ for i := 0; i < threads; i++ {
+ batch.Delete(updateStorageRootsProgressKey(i))
+ }
+ if err := batch.Write(); err != nil {
+ return fmt.Errorf("final batch write: %w", err)
+ }
+ log.Info("Updated stale storage roots", "scanned", scanned.Load(), "updated", updated.Load(), "elapsed", common.PrettyDuration(time.Since(start)))
+ return nil
+}
+
+// updateStorageRootsInRange walks accounts whose hashes fall inside
+// [rangeStart, rangeEnd] and fixes each account's Root to match its flat
+// storage.
+func updateStorageRootsInRange(ctx context.Context, cancel <-chan struct{}, db ethdb.Database, partition int, rangeStart, rangeEnd common.Hash, batchMu *sync.Mutex, batch ethdb.Batch, scanned, updated *atomic.Int64) error {
+ iters := openRangeIterators(db, rangeStart)
+ defer iters.release()
+
+ // Iterate through all the accounts.
+ for iters.acct.Next() {
+ select {
+ case <-cancel:
+ return ErrCancelled
+ case <-ctx.Done():
+ return nil
+ default:
+ }
+ key := iters.acct.Key()
+ var accountHash common.Hash
+ copy(accountHash[:], key[len(rawdb.SnapshotAccountPrefix):])
+ if bytes.Compare(accountHash[:], rangeEnd[:]) > 0 {
+ return nil
+ }
+ scanned.Add(1)
+ account, err := types.FullAccount(iters.acct.Value())
+ if err != nil {
+ return fmt.Errorf("decode account %x: %w", accountHash, err)
+ }
+
+ // Compute the storage root by consuming matching slots from the
+ // shared storage iterator. The inner loop terminates on Hold()
+ // (slot belongs to a later account) or exhaustion.
+ t := trie.NewStackTrie(nil)
+ for iters.stor.Next() {
+ sk := iters.stor.Key()
+ storAcc := sk[len(rawdb.SnapshotStoragePrefix) : len(rawdb.SnapshotStoragePrefix)+common.HashLength]
+ cmp := bytes.Compare(storAcc, accountHash[:])
+
+ // The slot belongs to an account whose hash is before the one we're
+ // processing. This only happens if an account was deleted but its flat
+ // storage wasn't cleaned up. Skip the orphaned slot and advance.
+ if cmp < 0 {
+ continue
+ }
+
+ // The slot belongs to a later account. We're done with the current
+ // account's slots, but we don't want to lose this slot. The slot might
+ // belong to the next iteration of the account for-loop (or a later one).
+ // Hold() the iterator so the next Next() call will re-serve this same
+ // entry instead of advancing past it.
+ if cmp > 0 {
+ iters.stor.Hold()
+ break
+ }
+
+ // The slot belongs to this account so we add it to the StackTrie.
+ slotHash := sk[len(rawdb.SnapshotStoragePrefix)+common.HashLength:]
+ if err := t.Update(slotHash, iters.stor.Value()); err != nil {
+ return fmt.Errorf("stack trie update for %x: %w", accountHash, err)
+ }
+ }
+ if err := iters.stor.Error(); err != nil {
+ return fmt.Errorf("storage iterator: %w", err)
+ }
+ computed := t.Hash()
+
+ // Update the account, progress marker, and (possibly) the batch.
+ var (
+ flushed bool
+ flushErr error
+ )
+ batchMu.Lock()
+ if computed != account.Root {
+ account.Root = computed
+ rawdb.WriteAccountSnapshot(batch, accountHash, types.SlimAccountRLP(*account))
+ updated.Add(1)
+ }
+ batch.Put(updateStorageRootsProgressKey(partition), accountHash[:])
+ if batch.ValueSize() > ethdb.IdealBatchSize {
+ flushErr = batch.Write()
+ if flushErr == nil {
+ batch.Reset()
+ flushed = true
+ }
+ }
+ batchMu.Unlock()
+ if flushErr != nil {
+ return fmt.Errorf("flush batch: %w", flushErr)
+ }
+ if flushed {
+ iters.reopen()
+ }
+ }
+ return iters.acct.Error()
+}
+
+// hashRanges returns hash pairs [start, end] that evenly partition the
+// 256-bit hash space. The last partition absorbs the remainder so rounding
+// doesn't leave hashes uncovered.
+func hashRanges(total int) [][2]common.Hash {
+ step := new(big.Int).Sub(
+ new(big.Int).Div(
+ new(big.Int).Exp(common.Big2, common.Big256, nil),
+ big.NewInt(int64(total)),
+ ),
+ common.Big1,
+ )
+ ranges := make([][2]common.Hash, total)
+ var next common.Hash
+ for i := 0; i < total; i++ {
+ last := common.BigToHash(new(big.Int).Add(next.Big(), step))
+ if i == total-1 {
+ last = common.MaxHash
+ }
+ ranges[i] = [2]common.Hash{next, last}
+ next = common.BigToHash(new(big.Int).Add(last.Big(), common.Big1))
+ }
+ return ranges
+}
+
+// GenerateTrie rebuilds all tries (storage + account) from flat snapshot
+// data in the database. It first brings every account's Root into
+// agreement with its flat storage, then builds tries using StackTrie with
+// streaming node writes, and verifies that the computed state root matches
+// the expected root.
+func GenerateTrie(db ethdb.Database, scheme string, root common.Hash, cancel <-chan struct{}) error {
+ if err := updateStorageRoots(db, cancel); err != nil {
+ return err
+ }
acctIt := newKVAccountIterator(db)
defer acctIt.Release()
-
got, err := internal.GenerateTrieRoot(db, scheme, acctIt, common.Hash{}, internal.StackTrieGenerate, func(dst ethdb.KeyValueWriter, accountHash, codeHash common.Hash, stat *internal.GenerateStats) (common.Hash, error) {
storageIt := newKVStorageIterator(db, accountHash)
defer storageIt.Release()
- hash, err := internal.GenerateTrieRoot(dst, scheme, storageIt, accountHash, internal.StackTrieGenerate, nil, stat, false)
+ hash, err := internal.GenerateTrieRoot(dst, scheme, storageIt, accountHash, internal.StackTrieGenerate, nil, stat, false, cancel)
if err != nil {
return common.Hash{}, err
}
return hash, nil
- }, internal.NewGenerateStats(), true)
+ }, internal.NewGenerateStats(), true, cancel)
if err != nil {
return err
}
diff --git a/triedb/generate_test.go b/triedb/generate_test.go
index 42bccd9aa3..71562752a7 100644
--- a/triedb/generate_test.go
+++ b/triedb/generate_test.go
@@ -24,6 +24,7 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie"
"github.com/holiman/uint256"
@@ -60,8 +61,8 @@ func buildExpectedRoot(t *testing.T, accounts []testAccount) common.Hash {
return acctTrie.Hash()
}
-// computeStorageRoot computes the storage trie root from sorted slots.
-func computeStorageRoot(slots []testSlot) common.Hash {
+// computeStorageRootFromSlots computes the storage trie root from sorted slots.
+func computeStorageRootFromSlots(slots []testSlot) common.Hash {
sort.Slice(slots, func(i, j int) bool {
return bytes.Compare(slots[i].hash[:], slots[j].hash[:]) < 0
})
@@ -74,7 +75,7 @@ func computeStorageRoot(slots []testSlot) common.Hash {
func TestGenerateTrieEmpty(t *testing.T) {
db := rawdb.NewMemoryDatabase()
- if err := GenerateTrie(db, rawdb.HashScheme, types.EmptyRootHash); err != nil {
+ if err := GenerateTrie(db, rawdb.HashScheme, types.EmptyRootHash, nil); err != nil {
t.Fatalf("GenerateTrie on empty state failed: %v", err)
}
}
@@ -107,7 +108,7 @@ func TestGenerateTrieAccountsOnly(t *testing.T) {
}
root := buildExpectedRoot(t, accounts)
- if err := GenerateTrie(db, rawdb.HashScheme, root); err != nil {
+ if err := GenerateTrie(db, rawdb.HashScheme, root, nil); err != nil {
t.Fatalf("GenerateTrie failed: %v", err)
}
}
@@ -119,7 +120,7 @@ func TestGenerateTrieWithStorage(t *testing.T) {
{hash: common.HexToHash("0xaa"), value: []byte{0x01, 0x02, 0x03}},
{hash: common.HexToHash("0xbb"), value: []byte{0x04, 0x05, 0x06}},
}
- storageRoot := computeStorageRoot(slots)
+ storageRoot := computeStorageRootFromSlots(slots)
accounts := []testAccount{
{
@@ -154,7 +155,7 @@ func TestGenerateTrieWithStorage(t *testing.T) {
}
root := buildExpectedRoot(t, accounts)
- if err := GenerateTrie(db, rawdb.HashScheme, root); err != nil {
+ if err := GenerateTrie(db, rawdb.HashScheme, root, nil); err != nil {
t.Fatalf("GenerateTrie failed: %v", err)
}
}
@@ -171,8 +172,133 @@ func TestGenerateTrieRootMismatch(t *testing.T) {
rawdb.WriteAccountSnapshot(db, common.HexToHash("0x01"), types.SlimAccountRLP(acct))
wrongRoot := common.HexToHash("0xdeadbeef")
- err := GenerateTrie(db, rawdb.HashScheme, wrongRoot)
+ err := GenerateTrie(db, rawdb.HashScheme, wrongRoot, nil)
if err == nil {
t.Fatal("expected error for root mismatch, got nil")
}
}
+
+// TestGenerateTrieFixesStaleRoots writes flat state with a mix of stale,
+// empty, and correct account roots, then checks that GenerateTrie produces
+// the expected state root.
+func TestGenerateTrieFixesStaleRoots(t *testing.T) {
+ db := rawdb.NewMemoryDatabase()
+
+ const n = 300
+ accounts := make([]testAccount, 0, n)
+ for i := 0; i < n; i++ {
+ addr := common.BytesToAddress([]byte{byte(i >> 8), byte(i)})
+ hash := crypto.Keccak256Hash(addr[:])
+
+ acc := testAccount{
+ hash: hash,
+ account: types.StateAccount{
+ Nonce: uint64(i),
+ Balance: uint256.NewInt(uint64(i + 1)),
+ Root: types.EmptyRootHash,
+ CodeHash: types.EmptyCodeHash.Bytes(),
+ },
+ }
+ // Every third account has no storage; the rest get slots.
+ if i%3 != 0 {
+ acc.storage = []testSlot{
+ {hash: common.BytesToHash([]byte{byte(i), 0xaa}), value: []byte{byte(i), 0x01}},
+ {hash: common.BytesToHash([]byte{byte(i), 0xbb}), value: []byte{byte(i), 0x02}},
+ }
+ acc.account.Root = computeStorageRootFromSlots(acc.storage)
+ }
+ accounts = append(accounts, acc)
+ }
+ // Expected state root with all Roots correct.
+ expectedRoot := buildExpectedRoot(t, accounts)
+
+ // Write flat state. Storage-bearing accounts rotate through three on-disk
+ // Root states that GenerateTrie's pre-pass must all bring into alignment:
+ // - stale non-empty Root
+ // - stale empty Root
+ // - correct Root
+ for i, a := range accounts {
+ for _, s := range a.storage {
+ rawdb.WriteStorageSnapshot(db, a.hash, s.hash, s.value)
+ }
+ onDisk := a.account
+ if len(a.storage) > 0 {
+ switch i % 3 {
+ case 0:
+ onDisk.Root = common.BytesToHash([]byte{byte(i), 0xde, 0xad})
+ case 1:
+ onDisk.Root = types.EmptyRootHash
+ }
+ }
+ rawdb.WriteAccountSnapshot(db, a.hash, types.SlimAccountRLP(onDisk))
+ }
+
+ if err := GenerateTrie(db, rawdb.HashScheme, expectedRoot, nil); err != nil {
+ t.Fatalf("GenerateTrie failed: %v", err)
+ }
+}
+
+// TestUpdateStorageRootsCancel verifies updateStorageRoots respects the
+// cancel channel.
+func TestUpdateStorageRootsCancel(t *testing.T) {
+ t.Parallel()
+ db := rawdb.NewMemoryDatabase()
+
+ for i := 0; i < 100; i++ {
+ addr := common.BytesToAddress([]byte{byte(i)})
+ hash := crypto.Keccak256Hash(addr[:])
+ rawdb.WriteAccountSnapshot(db, hash, types.SlimAccountRLP(types.StateAccount{
+ Balance: uint256.NewInt(1),
+ Root: types.EmptyRootHash,
+ CodeHash: types.EmptyCodeHash[:],
+ }))
+ }
+
+ cancel := make(chan struct{})
+ close(cancel)
+ if err := updateStorageRoots(db, cancel); err != ErrCancelled {
+ t.Fatalf("expected ErrCancelled, got %v", err)
+ }
+}
+
+// TestGenerateTrieOrphanStorage exercises the orphan-slot skip path: flat
+// storage entries for an accountHash that has no corresponding account
+// snapshot. updateStorageRoots must skip these without including them in
+// any account's storage root.
+func TestGenerateTrieOrphanStorage(t *testing.T) {
+ db := rawdb.NewMemoryDatabase()
+
+ // One legitimate account with storage.
+ liveAccountHash := crypto.Keccak256Hash(common.HexToAddress("0x01").Bytes())
+ slots := []testSlot{
+ {hash: common.HexToHash("0xaa"), value: []byte{0x01}},
+ }
+ for _, s := range slots {
+ rawdb.WriteStorageSnapshot(db, liveAccountHash, s.hash, s.value)
+ }
+ acc := testAccount{
+ hash: liveAccountHash,
+ account: types.StateAccount{
+ Nonce: 1,
+ Balance: uint256.NewInt(1),
+ Root: computeStorageRootFromSlots(slots),
+ CodeHash: types.EmptyCodeHash.Bytes(),
+ },
+ storage: slots,
+ }
+ rawdb.WriteAccountSnapshot(db, acc.hash, types.SlimAccountRLP(acc.account))
+
+ // Orphan storage: entries for an accountHash smaller than liveAccountHash,
+ // with no account snapshot behind them. Must be ordered before liveAccountHash
+ // so the storage iterator encounters them first.
+ var orphanAccountHash common.Hash
+ copy(orphanAccountHash[:], liveAccountHash[:])
+ orphanAccountHash[0] = 0x00 // guarantees cmp < 0 against liveAccountHash
+ rawdb.WriteStorageSnapshot(db, orphanAccountHash, common.HexToHash("0xbb"), []byte{0x02})
+
+ expectedRoot := buildExpectedRoot(t, []testAccount{acc})
+
+ if err := GenerateTrie(db, rawdb.HashScheme, expectedRoot, nil); err != nil {
+ t.Fatalf("GenerateTrie with orphan storage failed: %v", err)
+ }
+}
diff --git a/triedb/internal/conversion.go b/triedb/internal/conversion.go
index b331b63e21..8ab6c74268 100644
--- a/triedb/internal/conversion.go
+++ b/triedb/internal/conversion.go
@@ -21,6 +21,7 @@ package internal
import (
"encoding/binary"
+ "errors"
"fmt"
"math"
"runtime"
@@ -36,6 +37,10 @@ import (
"github.com/ethereum/go-ethereum/trie"
)
+// ErrCancelled is returned by GenerateTrieRoot when the cancel channel is
+// closed mid-run.
+var ErrCancelled = errors.New("cancelled")
+
// Iterator is an iterator to step over all the accounts or the specific
// storage in a snapshot which may or may not be composed of multiple layers.
type Iterator interface {
@@ -228,7 +233,7 @@ func RunReport(stats *GenerateStats, stop chan bool) {
// GenerateTrieRoot generates the trie hash based on the snapshot iterator.
// It can be used for generating account trie, storage trie or even the
// whole state which connects the accounts and the corresponding storages.
-func GenerateTrieRoot(db ethdb.KeyValueWriter, scheme string, it Iterator, account common.Hash, generatorFn TrieGeneratorFn, leafCallback LeafCallbackFn, stats *GenerateStats, report bool) (common.Hash, error) {
+func GenerateTrieRoot(db ethdb.KeyValueWriter, scheme string, it Iterator, account common.Hash, generatorFn TrieGeneratorFn, leafCallback LeafCallbackFn, stats *GenerateStats, report bool, cancel <-chan struct{}) (common.Hash, error) {
var (
in = make(chan TrieKV) // chan to pass leaves
out = make(chan common.Hash, 1) // chan to collect result
@@ -279,6 +284,14 @@ func GenerateTrieRoot(db ethdb.KeyValueWriter, scheme string, it Iterator, accou
)
// Start to feed leaves
for it.Next() {
+ // Top-of-loop cancel check. Cheap non-blocking peek so a closed
+ // cancel channel is observed without waiting for the blocking
+ // operations below.
+ select {
+ case <-cancel:
+ return stop(ErrCancelled)
+ default:
+ }
if account == (common.Hash{}) {
var (
err error
@@ -291,8 +304,14 @@ func GenerateTrieRoot(db ethdb.KeyValueWriter, scheme string, it Iterator, accou
}
} else {
// Wait until the semaphore allows us to continue, aborting if
- // a sub-task failed
- if err := <-results; err != nil {
+ // a sub-task failed or the caller cancelled.
+ var err error
+ select {
+ case err = <-results:
+ case <-cancel:
+ return stop(ErrCancelled)
+ }
+ if err != nil {
results <- nil // stop will drain the results, add a noop back for this error we just consumed
return stop(err)
}
@@ -322,7 +341,13 @@ func GenerateTrieRoot(db ethdb.KeyValueWriter, scheme string, it Iterator, accou
} else {
leaf = TrieKV{it.Hash(), common.CopyBytes(it.(StorageIterator).Slot())}
}
- in <- leaf
+ // Escape on cancel so we don't deadlock if the generator goroutine is slow
+ // and the caller gave up.
+ select {
+ case in <- leaf:
+ case <-cancel:
+ return stop(ErrCancelled)
+ }
// Accumulate the generation statistic if it's required.
processed++
diff --git a/triedb/internal/conversion_test.go b/triedb/internal/conversion_test.go
new file mode 100644
index 0000000000..0651ea1d91
--- /dev/null
+++ b/triedb/internal/conversion_test.go
@@ -0,0 +1,55 @@
+// Copyright 2026 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package internal
+
+import (
+ "errors"
+ "testing"
+
+ "github.com/ethereum/go-ethereum/common"
+)
+
+// fakeStorageIterator is a StorageIterator over a fixed list of slots.
+type fakeStorageIterator struct {
+ count int
+ idx int
+}
+
+func (it *fakeStorageIterator) Next() bool {
+ if it.idx >= it.count {
+ return false
+ }
+ it.idx++
+ return true
+}
+func (it *fakeStorageIterator) Error() error { return nil }
+func (it *fakeStorageIterator) Hash() common.Hash { return common.BytesToHash([]byte{byte(it.idx)}) }
+func (it *fakeStorageIterator) Slot() []byte { return []byte{byte(it.idx)} }
+func (it *fakeStorageIterator) Release() {}
+
+// TestGenerateTrieRootCancel verifies that GenerateTrieRoot aborts with
+// ErrCancelled when the cancel channel is closed.
+func TestGenerateTrieRootCancel(t *testing.T) {
+ t.Parallel()
+ it := &fakeStorageIterator{count: 10_000}
+ cancel := make(chan struct{})
+ close(cancel)
+ _, err := GenerateTrieRoot(nil, "", it, common.HexToHash("0xaa"), StackTrieGenerate, nil, nil, false, cancel)
+ if !errors.Is(err, ErrCancelled) {
+ t.Fatalf("expected ErrCancelled, got %v", err)
+ }
+}
diff --git a/triedb/pathdb/holdable_iterator.go b/triedb/internal/holdable_iterator.go
similarity index 82%
rename from triedb/pathdb/holdable_iterator.go
rename to triedb/internal/holdable_iterator.go
index 1f8e6b7068..7b0535e461 100644
--- a/triedb/pathdb/holdable_iterator.go
+++ b/triedb/internal/holdable_iterator.go
@@ -14,31 +14,31 @@
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see .
-package pathdb
+package internal
import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/ethdb"
)
-// holdableIterator is a wrapper of underlying database iterator. It extends
+// HoldableIterator is a wrapper of underlying database iterator. It extends
// the basic iterator interface by adding Hold which can hold the element
// locally where the iterator is currently located and serve it up next time.
-type holdableIterator struct {
+type HoldableIterator struct {
it ethdb.Iterator
key []byte
val []byte
atHeld bool
}
-// newHoldableIterator initializes the holdableIterator with the given iterator.
-func newHoldableIterator(it ethdb.Iterator) *holdableIterator {
- return &holdableIterator{it: it}
+// NewHoldableIterator initializes the HoldableIterator with the given iterator.
+func NewHoldableIterator(it ethdb.Iterator) *HoldableIterator {
+ return &HoldableIterator{it: it}
}
// Hold holds the element locally where the iterator is currently located which
// can be served up next time.
-func (it *holdableIterator) Hold() {
+func (it *HoldableIterator) Hold() {
if it.it.Key() == nil {
return // nothing to hold
}
@@ -49,7 +49,7 @@ func (it *holdableIterator) Hold() {
// Next moves the iterator to the next key/value pair. It returns whether the
// iterator is exhausted.
-func (it *holdableIterator) Next() bool {
+func (it *HoldableIterator) Next() bool {
if !it.atHeld && it.key != nil {
it.atHeld = true
} else if it.atHeld {
@@ -65,11 +65,11 @@ func (it *holdableIterator) Next() bool {
// Error returns any accumulated error. Exhausting all the key/value pairs
// is not considered to be an error.
-func (it *holdableIterator) Error() error { return it.it.Error() }
+func (it *HoldableIterator) Error() error { return it.it.Error() }
// Release releases associated resources. Release should always succeed and can
// be called multiple times without causing error.
-func (it *holdableIterator) Release() {
+func (it *HoldableIterator) Release() {
it.atHeld = false
it.key = nil
it.val = nil
@@ -79,7 +79,7 @@ func (it *holdableIterator) Release() {
// Key returns the key of the current key/value pair, or nil if done. The caller
// should not modify the contents of the returned slice, and its contents may
// change on the next call to Next.
-func (it *holdableIterator) Key() []byte {
+func (it *HoldableIterator) Key() []byte {
if it.key != nil {
return it.key
}
@@ -89,7 +89,7 @@ func (it *holdableIterator) Key() []byte {
// Value returns the value of the current key/value pair, or nil if done. The
// caller should not modify the contents of the returned slice, and its contents
// may change on the next call to Next.
-func (it *holdableIterator) Value() []byte {
+func (it *HoldableIterator) Value() []byte {
if it.val != nil {
return it.val
}
diff --git a/triedb/pathdb/holdable_iterator_test.go b/triedb/internal/holdable_iterator_test.go
similarity index 92%
rename from triedb/pathdb/holdable_iterator_test.go
rename to triedb/internal/holdable_iterator_test.go
index 2abc92e154..af6d7a34d6 100644
--- a/triedb/pathdb/holdable_iterator_test.go
+++ b/triedb/internal/holdable_iterator_test.go
@@ -14,7 +14,7 @@
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see .
-package pathdb
+package internal
import (
"bytes"
@@ -39,7 +39,7 @@ func TestIteratorHold(t *testing.T) {
}
}
// Iterate over the database with the given configs and verify the results
- it, idx := newHoldableIterator(db.NewIterator(nil, nil)), 0
+ it, idx := NewHoldableIterator(db.NewIterator(nil, nil)), 0
// Nothing should be affected for calling Discard on non-initialized iterator
it.Hold()
@@ -108,20 +108,20 @@ func TestReopenIterator(t *testing.T) {
}
db = rawdb.NewMemoryDatabase()
- reopen = func(db ethdb.KeyValueStore, iter *holdableIterator) *holdableIterator {
+ reopen = func(db ethdb.KeyValueStore, iter *HoldableIterator) *HoldableIterator {
if !iter.Next() {
iter.Release()
- return newHoldableIterator(memorydb.New().NewIterator(nil, nil))
+ return NewHoldableIterator(memorydb.New().NewIterator(nil, nil))
}
next := iter.Key()
iter.Release()
- return newHoldableIterator(db.NewIterator(rawdb.SnapshotAccountPrefix, next[1:]))
+ return NewHoldableIterator(db.NewIterator(rawdb.SnapshotAccountPrefix, next[1:]))
}
)
for key, val := range content {
rawdb.WriteAccountSnapshot(db, key, []byte(val))
}
- checkVal := func(it *holdableIterator, index int) {
+ checkVal := func(it *HoldableIterator, index int) {
if !bytes.Equal(it.Key(), append(rawdb.SnapshotAccountPrefix, order[index].Bytes()...)) {
t.Fatalf("Unexpected data entry key, want %v got %v", order[index], it.Key())
}
@@ -131,7 +131,7 @@ func TestReopenIterator(t *testing.T) {
}
// Iterate over the database with the given configs and verify the results
dbIter := db.NewIterator(rawdb.SnapshotAccountPrefix, nil)
- iter, idx := newHoldableIterator(rawdb.NewKeyLengthIterator(dbIter, 1+common.HashLength)), -1
+ iter, idx := NewHoldableIterator(rawdb.NewKeyLengthIterator(dbIter, 1+common.HashLength)), -1
idx++
iter.Next()
diff --git a/triedb/pathdb/context.go b/triedb/pathdb/context.go
index a5704de81a..0554ee91bf 100644
--- a/triedb/pathdb/context.go
+++ b/triedb/pathdb/context.go
@@ -28,6 +28,7 @@ import (
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/ethdb/memorydb"
"github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/triedb/internal"
)
const (
@@ -91,12 +92,12 @@ func (gs *generatorStats) log(msg string, root common.Hash, marker []byte) {
// current generation cycle. It must be recreated if the generation cycle is
// restarted.
type generatorContext struct {
- root common.Hash // State root of the generation target
- account *holdableIterator // Iterator of account snapshot data
- storage *holdableIterator // Iterator of storage snapshot data
- db ethdb.KeyValueStore // Key-value store containing the snapshot data
- batch ethdb.Batch // Database batch for writing data atomically
- logged time.Time // The timestamp when last generation progress was displayed
+ root common.Hash // State root of the generation target
+ account *internal.HoldableIterator // Iterator of account snapshot data
+ storage *internal.HoldableIterator // Iterator of storage snapshot data
+ db ethdb.KeyValueStore // Key-value store containing the snapshot data
+ batch ethdb.Batch // Database batch for writing data atomically
+ logged time.Time // The timestamp when last generation progress was displayed
}
// newGeneratorContext initializes the context for generation.
@@ -119,11 +120,11 @@ func newGeneratorContext(root common.Hash, marker []byte, db ethdb.KeyValueStore
func (ctx *generatorContext) openIterator(kind string, start []byte) {
if kind == snapAccount {
iter := ctx.db.NewIterator(rawdb.SnapshotAccountPrefix, start)
- ctx.account = newHoldableIterator(rawdb.NewKeyLengthIterator(iter, 1+common.HashLength))
+ ctx.account = internal.NewHoldableIterator(rawdb.NewKeyLengthIterator(iter, 1+common.HashLength))
return
}
iter := ctx.db.NewIterator(rawdb.SnapshotStoragePrefix, start)
- ctx.storage = newHoldableIterator(rawdb.NewKeyLengthIterator(iter, 1+2*common.HashLength))
+ ctx.storage = internal.NewHoldableIterator(rawdb.NewKeyLengthIterator(iter, 1+2*common.HashLength))
}
// reopenIterator releases the specified snapshot iterator and re-open it
@@ -140,10 +141,10 @@ func (ctx *generatorContext) reopenIterator(kind string) {
// Iterator exhausted, release forever and create an already exhausted virtual iterator
iter.Release()
if kind == snapAccount {
- ctx.account = newHoldableIterator(memorydb.New().NewIterator(nil, nil))
+ ctx.account = internal.NewHoldableIterator(memorydb.New().NewIterator(nil, nil))
return
}
- ctx.storage = newHoldableIterator(memorydb.New().NewIterator(nil, nil))
+ ctx.storage = internal.NewHoldableIterator(memorydb.New().NewIterator(nil, nil))
return
}
next := iter.Key()
@@ -158,7 +159,7 @@ func (ctx *generatorContext) close() {
}
// iterator returns the corresponding iterator specified by the kind.
-func (ctx *generatorContext) iterator(kind string) *holdableIterator {
+func (ctx *generatorContext) iterator(kind string) *internal.HoldableIterator {
if kind == snapAccount {
return ctx.account
}
diff --git a/triedb/pathdb/verifier.go b/triedb/pathdb/verifier.go
index c53590f2fd..4284432c1e 100644
--- a/triedb/pathdb/verifier.go
+++ b/triedb/pathdb/verifier.go
@@ -52,12 +52,12 @@ func (db *Database) VerifyState(root common.Hash) error {
}
defer storageIt.Release()
- hash, err := internal.GenerateTrieRoot(nil, "", storageIt, accountHash, stackTrieHasher, nil, stat, false)
+ hash, err := internal.GenerateTrieRoot(nil, "", storageIt, accountHash, stackTrieHasher, nil, stat, false, nil)
if err != nil {
return common.Hash{}, err
}
return hash, nil
- }, internal.NewGenerateStats(), true)
+ }, internal.NewGenerateStats(), true, nil)
if err != nil {
return err