triedb: reconcile stale storage roots in GenerateTrie, add cancel support (#34807)

Rewrites triedb.GenerateTrie as a single partitioned pass that
reconciles stale account.Root fields and rebuilds the trie at the same
time, with 16-way parallelism and crash resume baked in.

---------

Co-authored-by: Gary Rong <garyrong0905@gmail.com>
This commit is contained in:
Jonny Rhea 2026-06-03 02:08:09 -05:00 committed by GitHub
parent e514ede494
commit f4393173f2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 1968 additions and 122 deletions

View file

@ -22,11 +22,18 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"net/http"
_ "net/http/pprof"
"os" "os"
"os/signal"
"path/filepath"
"runtime"
"slices" "slices"
"sort" "sort"
"syscall"
"time" "time"
pebbleimpl "github.com/cockroachdb/pebble"
"github.com/ethereum/go-ethereum/cmd/utils" "github.com/ethereum/go-ethereum/cmd/utils"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core"
@ -36,6 +43,7 @@ import (
"github.com/ethereum/go-ethereum/core/state/snapshot" "github.com/ethereum/go-ethereum/core/state/snapshot"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb/pebble"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/trie"
@ -80,6 +88,33 @@ geth snapshot verify-state <state-root>
will traverse the whole accounts and storages set based on the specified will traverse the whole accounts and storages set based on the specified
snapshot and recalculate the root hash of state for verification. snapshot and recalculate the root hash of state for verification.
In other words, this command does the snapshot to trie conversion. In other words, this command does the snapshot to trie conversion.
`,
},
{
Name: "generate-trie",
Usage: "Benchmark triedb.GenerateTrie against a hard-linked checkpoint of the chaindata",
ArgsUsage: "[<root>]",
Action: benchGenerateTrie,
Flags: slices.Concat(utils.NetworkFlags, utils.DatabaseFlags, []cli.Flag{
&cli.StringFlag{
Name: "checkpoint",
Usage: "Directory for the pebble checkpoint (default: <chaindata-parent>/.gentrie-bench-<ts>)",
},
&cli.BoolFlag{
Name: "keep",
Usage: "Keep the checkpoint directory after the run (debugging)",
},
&cli.BoolFlag{
Name: "pprof",
Usage: "Serve pprof profiles on localhost:6060 (block + mutex profiles enabled)",
},
}),
Description: `
geth snapshot generate-trie [<root>]
Runs triedb.GenerateTrie against a hard-linked pebble checkpoint of the
chaindata. Checkpoint is removed on exit unless --keep is set. Defaults
to the snapshot root if <root> is not given.
`, `,
}, },
{ {
@ -289,6 +324,157 @@ func verifyState(ctx *cli.Context) error {
} }
} }
// benchGenerateTrie runs triedb.GenerateTrie against a hard-linked checkpoint
// of the chaindata, so that everything GenerateTrie writes lands in the
// checkpoint rather than in the source key-value store.
//
// The state root is taken from the single optional positional argument. When
// it is absent, the snapshot root is used, falling back to the head block root.
// The checkpoint directory is removed on return unless --keep is set, and an
// interrupt or SIGTERM cancels the run through the cancel channel handed to
// GenerateTrie. The error returned by GenerateTrie (if any) is returned after
// the summary has been printed.
func benchGenerateTrie(ctx *cli.Context) error {
	// The command accepts at most one root. Anything beyond that used to be
	// ignored silently, making the run fall back to the snapshot root.
	if ctx.NArg() > 1 {
		return errors.New("too many arguments given; expected at most one <root>")
	}
	stack, _ := makeConfigNode(ctx)
	defer stack.Close()

	if ctx.Bool("pprof") {
		// Bind to the loopback interface only, as advertised by the flag's
		// usage text. A bare ":6060" listens on every interface and would
		// expose the profiling endpoints to the network.
		const pprofAddr = "localhost:6060"

		runtime.SetBlockProfileRate(1)
		runtime.SetMutexProfileFraction(1)
		go func() {
			log.Info("pprof listening", "addr", pprofAddr)
			if err := http.ListenAndServe(pprofAddr, nil); err != nil {
				log.Warn("pprof server stopped", "err", err)
			}
		}()
	}

	// Resolve source chaindata path (handles network-specific subdirs).
	srcDir := stack.ResolvePath("chaindata")
	if fi, err := os.Stat(srcDir); err != nil {
		return fmt.Errorf("chaindata not found at %s: %w", srcDir, err)
	} else if !fi.IsDir() {
		return fmt.Errorf("%s is not a directory", srcDir)
	}

	// Default to snapshot root, not head: that's what GenerateTrie actually
	// reconstructs from flat state. On a fully-synced node they match.
	var root common.Hash
	if ctx.NArg() == 1 {
		r, err := parseRoot(ctx.Args().First())
		if err != nil {
			return fmt.Errorf("parse root: %w", err)
		}
		root = r
	} else {
		// The database is only needed to look up the default root and is
		// closed again before the checkpoint is taken.
		chaindb := utils.MakeChainDatabase(ctx, stack, true)
		snapRoot := rawdb.ReadSnapshotRoot(chaindb)
		head := rawdb.ReadHeadBlock(chaindb)
		chaindb.Close()

		switch {
		case snapRoot != (common.Hash{}):
			root = snapRoot
			log.Info("using snapshot root", "root", root)
		case head != nil:
			root = head.Root()
			log.Info("using head block root", "number", head.Number(), "root", root)
		default:
			return errors.New("no snapshot or head block found; pass <root> explicitly")
		}
	}

	// Default checkpoint sits next to chaindata so hard links work.
	ckpt := ctx.String("checkpoint")
	if ckpt == "" {
		ts := time.Now().Format("20060102-150405")
		ckpt = filepath.Join(filepath.Dir(srcDir), fmt.Sprintf(".gentrie-bench-%s", ts))
	}
	if _, err := os.Stat(ckpt); err == nil {
		return fmt.Errorf("checkpoint dir %s already exists; remove it or pass --checkpoint to a fresh path", ckpt)
	}

	log.Info("creating pebble checkpoint", "src", srcDir, "dst", ckpt)
	checkpointStart := time.Now()
	if err := makeCheckpoint(srcDir, ckpt); err != nil {
		return fmt.Errorf("checkpoint failed: %w", err)
	}
	log.Info("checkpoint created", "elapsed", time.Since(checkpointStart))

	// Clean up the checkpoint on exit, including Ctrl-C. The deferred call is
	// registered before the database is opened, so it runs after the deferred
	// chaindb.Close below.
	keep := ctx.Bool("keep")
	cleanup := func() {
		if keep {
			log.Info("keeping checkpoint", "path", ckpt)
			return
		}
		log.Info("removing checkpoint", "path", ckpt)
		if err := os.RemoveAll(ckpt); err != nil {
			log.Error("failed to remove checkpoint", "err", err)
		}
	}
	defer cleanup()

	// Translate the first interrupt/termination signal into a closed channel
	// that GenerateTrie receives as its cancel channel.
	cancelCh := make(chan struct{})
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
	defer signal.Stop(sigCh)
	go func() {
		<-sigCh
		log.Warn("interrupt received; cancelling GenerateTrie")
		close(cancelCh)
	}()

	// Open the checkpoint writable. Reuse source ancient. Checkpoint only
	// hard-links the pebble SSTs (not the freezer), and GenerateTrie never
	// writes to ancient, so sharing it is safe.
	srcAncient := stack.ResolveAncient("chaindata", "")
	kv, err := pebble.New(ckpt, 4096, 1024, "gentrie-bench", false)
	if err != nil {
		return fmt.Errorf("open checkpoint: %w", err)
	}
	chaindb, err := rawdb.Open(kv, rawdb.OpenOptions{
		Ancient:          srcAncient,
		MetricsNamespace: "gentrie-bench",
	})
	if err != nil {
		kv.Close()
		return fmt.Errorf("rawdb.Open checkpoint: %w", err)
	}
	defer chaindb.Close()

	// Pick up the trie scheme already in use (path or hash).
	triedbInst := utils.MakeTrieDatabase(ctx, stack, chaindb, false, true, false)
	scheme := triedbInst.Scheme()
	triedbInst.Close()

	log.Info("running GenerateTrie", "scheme", scheme, "root", root)
	runStart := time.Now()
	stats, err := triedb.GenerateTrie(chaindb, scheme, root, cancelCh)
	elapsed := time.Since(runStart)

	status := "root matched"
	if err != nil {
		status = fmt.Sprintf("failed (%s)", err)
		log.Error("GenerateTrie failed", "elapsed", elapsed, "err", err)
	}

	// NOTE(review): the summary reads stats even when err is non-nil; this
	// assumes GenerateTrie always returns usable stats — confirm.
	fmt.Printf("\n=== generate-trie benchmark ===\n")
	fmt.Printf("scheme: %s\n", scheme)
	fmt.Printf("root: %s\n", root.Hex())
	fmt.Printf("status: %s\n", status)
	fmt.Printf("accounts: %d (%d updated)\n", stats.Scanned, stats.Updated)
	fmt.Printf("wall time: %s\n", elapsed)
	return err
}
// makeCheckpoint takes a pebble checkpoint of the database at srcDir and
// writes it to dstDir. The source database is closed again before returning.
//
// The source is opened read-write so pebble can finalize its startup (WAL
// replay, fresh OPTIONS file) before checkpointing. Read-only mode skips that
// step, and Checkpoint then fails trying to hard-link the missing OPTIONS
// file. The read-write open does no more than a normal geth startup would.
func makeCheckpoint(srcDir, dstDir string) error {
	opts := &pebbleimpl.Options{}
	source, openErr := pebbleimpl.Open(srcDir, opts)
	if openErr != nil {
		return fmt.Errorf("open source pebble: %w", openErr)
	}
	defer source.Close()

	return source.Checkpoint(dstDir)
}
// checkDanglingStorage iterates the snap storage data, and verifies that all // checkDanglingStorage iterates the snap storage data, and verifies that all
// storage also has corresponding account data. // storage also has corresponding account data.
func checkDanglingStorage(ctx *cli.Context) error { func checkDanglingStorage(ctx *cli.Context) error {

View file

@ -208,3 +208,45 @@ func WriteSnapshotSyncStatus(db ethdb.KeyValueWriter, status []byte) {
log.Crit("Failed to store snapshot sync status", "err", err) log.Crit("Failed to store snapshot sync status", "err", err)
} }
} }
// ReadGenerateTriePartitionDone looks up the done marker of the given
// partition. The boolean reports whether a valid marker was found.
//
// The stored value starts with a status byte: 0x00 marks a finished but empty
// partition (nil blob), 0x01 marks a finished partition whose raw subtree root
// blob follows. A missing entry, a read error, an empty value or any other
// status byte is reported as "not done".
func ReadGenerateTriePartitionDone(db ethdb.KeyValueReader, partition byte) ([]byte, bool) {
	marker, err := db.Get(generateTriePartitionDoneKey(partition))
	if err != nil || len(marker) == 0 {
		return nil, false
	}
	status, payload := marker[0], marker[1:]
	if status == 0x01 {
		// Partition is done and the blob follows the status byte.
		return payload, true
	}
	// Only 0x00 (done, empty partition) is valid besides 0x01.
	return nil, status == 0x00
}
// WriteGenerateTriePartitionDone stores the done marker of the given
// partition. A nil blob is recorded as the single status byte 0x00; any other
// blob is recorded as the status byte 0x01 followed by the blob itself. A
// failing write is reported through log.Crit.
func WriteGenerateTriePartitionDone(db ethdb.KeyValueWriter, partition byte, blob []byte) {
	encoded := []byte{0x00}
	if blob != nil {
		encoded = make([]byte, 0, len(blob)+1)
		encoded = append(encoded, 0x01)
		encoded = append(encoded, blob...)
	}
	key := generateTriePartitionDoneKey(partition)
	if err := db.Put(key, encoded); err != nil {
		log.Crit("Failed to store generate-trie done marker", "err", err)
	}
}
// DeleteGenerateTriePartitionDone deletes the done marker stored for the given
// partition. A failing delete is reported through log.Crit.
func DeleteGenerateTriePartitionDone(db ethdb.KeyValueWriter, partition byte) {
	key := generateTriePartitionDoneKey(partition)
	err := db.Delete(key)
	if err == nil {
		return
	}
	log.Crit("Failed to remove generate-trie done marker", "err", err)
}

View file

@ -563,6 +563,8 @@ func InspectDatabase(db ethdb.Database, keyPrefix, keyStart []byte) error {
} }
// Metadata keys // Metadata keys
case bytes.HasPrefix(key, generateTriePartitionDonePrefix) && len(key) == len(generateTriePartitionDonePrefix)+1:
metadata.add(size)
case slices.ContainsFunc(knownMetadataKeys, func(x []byte) bool { return bytes.Equal(x, key) }): case slices.ContainsFunc(knownMetadataKeys, func(x []byte) bool { return bytes.Equal(x, key) }):
metadata.add(size) metadata.add(size)

View file

@ -104,6 +104,10 @@ var (
// snapSyncStatusFlagKey flags that status of snap sync. // snapSyncStatusFlagKey flags that status of snap sync.
snapSyncStatusFlagKey = []byte("SnapSyncStatus") snapSyncStatusFlagKey = []byte("SnapSyncStatus")
// generateTriePartitionDonePrefix stores the completion marker of each
// triedb.GenerateTrie partition: a status byte followed by the partition's
// subtree root blob, if any.
generateTriePartitionDonePrefix = []byte("gtd") // generateTriePartitionDonePrefix + partition byte -> subtree root hash
// Data item prefixes (use single byte to avoid mixing data types, avoid `i`, used for indexes). // Data item prefixes (use single byte to avoid mixing data types, avoid `i`, used for indexes).
headerPrefix = []byte("h") // headerPrefix + num (uint64 big endian) + hash -> header headerPrefix = []byte("h") // headerPrefix + num (uint64 big endian) + hash -> header
headerTDSuffix = []byte("t") // headerPrefix + num (uint64 big endian) + hash + headerTDSuffix -> td (deprecated) headerTDSuffix = []byte("t") // headerPrefix + num (uint64 big endian) + hash + headerTDSuffix -> td (deprecated)
@ -465,3 +469,8 @@ func trienodeHistoryIndexBlockKey(addressHash common.Hash, path []byte, blockID
func transitionStateKey(hash common.Hash) []byte { func transitionStateKey(hash common.Hash) []byte {
return append(VerkleTransitionStatePrefix, hash.Bytes()...) return append(VerkleTransitionStatePrefix, hash.Bytes()...)
} }
// generateTriePartitionDoneKey = generateTriePartitionDonePrefix + partition (single byte).
//
// The key is assembled in a freshly allocated buffer. Appending directly to
// the package-level prefix slice may write into its shared backing array when
// that array has spare capacity, in which case keys returned by earlier (or
// concurrent) calls would be silently overwritten.
func generateTriePartitionDoneKey(partition byte) []byte {
	key := make([]byte, 0, len(generateTriePartitionDonePrefix)+1)
	key = append(key, generateTriePartitionDonePrefix...)
	return append(key, partition)
}

View file

@ -23,6 +23,7 @@ import (
"strings" "strings"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
) )
@ -335,3 +336,79 @@ func wrapError(err error, ctx string) error {
func (err *decodeError) Error() string { func (err *decodeError) Error() string {
return fmt.Sprintf("%v (decode path: %s)", err.what, strings.Join(err.stack, "<-")) return fmt.Sprintf("%v (decode path: %s)", err.what, strings.Join(err.stack, "<-"))
} }
// MountPartitionRoot re-attaches the leading nibble n to the root of a
// partition subtree that was built with that nibble stripped (see
// PartialStackTrie). The returned node and hash form the canonical trie root
// for the situation where partition n is the only populated partition, i.e.
// the top-level branch collapses into the subtree itself.
//
// The handling depends on the shape of the subtree root blob:
//
//   - 17 elements (branch): a new extension node is built whose key is the
//     single nibble n and whose child is the keccak256 hash of the branch
//     blob. The branch itself stays at the path the partition wrote it to
//     ([n]), so isOrphaned is false.
//
//   - 2 elements (extension or leaf): the node's hex key [k...] becomes
//     [n, k...], keeping the leaf terminator if there is one, and the second
//     element is reused verbatim. The node persisted at path [n] is then no
//     longer referenced by the returned root and should be deleted by the
//     caller, so isOrphaned is true.
//
// Any other element count, or a decoding/encoding failure, yields an error
// together with zero values for the remaining results.
func MountPartitionRoot(blob []byte, n byte) (hash common.Hash, writeBlob []byte, isOrphaned bool, err error) {
	// abort packages an error with the zero values of the other results.
	abort := func(cause error) (common.Hash, []byte, bool, error) {
		return common.Hash{}, nil, false, cause
	}
	elems, err := decodeNodeElements(blob)
	if err != nil {
		return abort(fmt.Errorf("decode partition root: %w", err))
	}
	count := len(elems)

	if count == 17 {
		// Branch root: wrap it in an extension carrying the single nibble n,
		// referencing the branch by its 32-byte hash.
		extKey, encErr := rlp.EncodeToBytes(hexToCompact([]byte{n}))
		if encErr != nil {
			return abort(fmt.Errorf("encode extension key: %w", encErr))
		}
		childRef, encErr := rlp.EncodeToBytes(crypto.Keccak256(blob))
		if encErr != nil {
			return abort(fmt.Errorf("encode child ref: %w", encErr))
		}
		mounted, encErr := encodeNodeElements([][]byte{extKey, childRef})
		if encErr != nil {
			return abort(fmt.Errorf("encode extension node: %w", encErr))
		}
		return crypto.Keccak256Hash(mounted), mounted, false, nil
	}

	if count == 2 {
		// Short node (extension/leaf): prepend n to its hex key. compactToHex
		// retains the leaf terminator, so hexToCompact restores the right type.
		compactKey, _, splitErr := rlp.SplitString(elems[0])
		if splitErr != nil {
			return abort(fmt.Errorf("parse compact key: %w", splitErr))
		}
		nibbles := append([]byte{n}, compactToHex(compactKey)...)
		mountedKey, encErr := rlp.EncodeToBytes(hexToCompact(nibbles))
		if encErr != nil {
			return abort(fmt.Errorf("encode mounted key: %w", encErr))
		}
		mounted, encErr := encodeNodeElements([][]byte{mountedKey, elems[1]})
		if encErr != nil {
			return abort(fmt.Errorf("encode mounted node: %w", encErr))
		}
		return crypto.Keccak256Hash(mounted), mounted, true, nil
	}

	return abort(fmt.Errorf("unexpected partition root element count: %d", count))
}
// AssembleBranch encodes the given 17 children as a fullNode (branch) and
// returns the resulting RLP blob together with its keccak256 hash. The
// returned error is always nil.
func AssembleBranch(children [17][]byte) ([]byte, common.Hash, error) {
	buf := rlp.NewEncoderBuffer(nil)
	branch := fullnodeEncoder{Children: children}
	(&branch).encode(buf)

	// Take the encoded bytes out before the buffer is flushed.
	encoded := buf.ToBytes()
	buf.Flush()

	return encoded, crypto.Keccak256Hash(encoded), nil
}

View file

@ -85,6 +85,14 @@ func (t *StackTrie) Update(key, value []byte) error {
} }
t.grow(key) t.grow(key)
k := writeHexKey(t.kBuf, key) k := writeHexKey(t.kBuf, key)
return t.update(k, value)
}
// update inserts a (hex-key, value) pair into the stack trie. The key must be
// in hex (nibble) form without the terminator flag, and the value must be
// non-empty. It is shared by Update and the partition builder, which feeds a
// key with its leading nibble stripped.
func (t *StackTrie) update(k, value []byte) error {
if bytes.Compare(t.last, k) >= 0 { if bytes.Compare(t.last, k) >= 0 {
return errors.New("non-ascending key order") return errors.New("non-ascending key order")
} }

91
trie/stacktrie_partial.go Normal file
View file

@ -0,0 +1,91 @@
// Copyright 2026 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package trie
import (
"errors"
"fmt"
"github.com/ethereum/go-ethereum/common"
)
// PartialStackTrie builds the subtrie of a single first-nibble partition of a
// larger trie on top of a StackTrie. It is used for parallel trie generation,
// where the key space is split into 16 partitions by the first nibble and each
// partition is built independently before being mounted under a common root.
//
// Two adjustments make the produced subtrie line up with its position in the
// full trie:
//
//   - Keys are inserted with their leading nibble stripped. That nibble is
//     implied by the partition's slot in the parent branch, so duplicating it
//     inside the keys would corrupt every node hash below the root.
//
//   - Paths reported to onTrieNode are prefixed with the partition nibble, so
//     a node's path matches its absolute position in the full trie. This is
//     required by the path-based storage scheme, which keys nodes by path.
//
// The hashes themselves are independent of the absolute path, so prefixing the
// path does not change any node hash.
//
// All inserted keys must share the same leading nibble equal to `nibble`; the
// caller guarantees this by construction (e.g. by partitioning a hash range).
type PartialStackTrie struct {
	nibble  byte       // leading nibble shared by every key of this partition
	inner   *StackTrie // underlying stack trie, fed with nibble-stripped keys
	pathBuf []byte     // reusable buffer for the nibble-prefixed path
}
// NewPartialStackTrie returns a builder for the partition identified by the
// given leading nibble. If onTrieNode is non-nil, it is called for every node
// committed by the underlying StackTrie, with the node path prefixed by the
// partition nibble so that it reflects the node's absolute position.
func NewPartialStackTrie(nibble byte, onTrieNode OnTrieNode) *PartialStackTrie {
	p := new(PartialStackTrie)
	p.nibble = nibble
	p.inner = NewStackTrie(func(path []byte, hash common.Hash, blob []byte) {
		if onTrieNode != nil {
			// Build the prefixed path in the shared buffer. The buffer is
			// reused across calls, so the callback must consume it
			// synchronously.
			prefixed := append(p.pathBuf[:0], nibble)
			prefixed = append(prefixed, path...)
			p.pathBuf = prefixed
			onTrieNode(prefixed, hash, blob)
		}
	})
	return p
}
// Update inserts a (key, value) pair, stripping the key's leading nibble, which
// is implied by the partition. The key must begin with the partition nibble.
//
// An error is returned if the value is empty, if the key is empty (there is no
// leading nibble to strip), if the key's leading nibble differs from the
// partition nibble, or if the underlying stack trie rejects the insertion.
func (p *PartialStackTrie) Update(key, value []byte) error {
	if len(value) == 0 {
		return errors.New("trying to insert empty (deletion)")
	}
	// Without this guard the nibble check below would index an empty hex key
	// and panic.
	if len(key) == 0 {
		return errors.New("empty key has no partition nibble")
	}
	t := p.inner
	t.grow(key)
	k := writeHexKey(t.kBuf, key)
	if k[0] != p.nibble {
		// Both nibbles are printed in hex so the message is unambiguous.
		return fmt.Errorf("unexpected nibble %x, expected %x", k[0], p.nibble)
	}
	return t.update(k[1:], value)
}
// Hash computes the root hash of the partition subtrie by delegating to the
// underlying StackTrie, which was fed keys with the leading nibble stripped.
// The result is the reference the parent branch mounts in slot `nibble`.
func (p *PartialStackTrie) Hash() common.Hash {
	subtreeRoot := p.inner.Hash()
	return subtreeRoot
}

View file

@ -0,0 +1,288 @@
// Copyright 2026 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package trie
import (
"bytes"
"sort"
"strings"
"testing"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/rlp"
)
// mkKey expands a leading hex string into a 32-byte key by right-padding it
// with zero digits up to 64 hex characters (e.g. "3a" -> 0x3a000...0). The
// first nibble of the key is prefixHex[0].
func mkKey(prefixHex string) []byte {
	padding := strings.Repeat("0", 64-len(prefixHex))
	padded := prefixHex + padding
	return common.HexToHash(padded).Bytes()
}
// sortedPairs converts key prefixes into parallel slices of 32-byte keys and
// values, ordered by key as StackTrie requires. The value paired with the i-th
// prefix is the byte i+1 repeated 32 times.
func sortedPairs(prefixes []string) (keys, vals [][]byte) {
	type pair struct{ key, val []byte }

	pairs := make([]pair, 0, len(prefixes))
	for idx, prefix := range prefixes {
		pairs = append(pairs, pair{
			key: mkKey(prefix),
			val: bytes.Repeat([]byte{byte(idx + 1)}, 32),
		})
	}
	sort.Slice(pairs, func(a, b int) bool {
		return bytes.Compare(pairs[a].key, pairs[b].key) < 0
	})
	for i := range pairs {
		keys = append(keys, pairs[i].key)
		vals = append(vals, pairs[i].val)
	}
	return keys, vals
}
// partitionRoot builds partition n over the given keys and returns its subtree
// root blob (the node emitted at path [n]).
func partitionRoot(t *testing.T, n byte, keys, vals [][]byte) []byte {
t.Helper()
var root []byte
pst := NewPartialStackTrie(n, func(path []byte, _ common.Hash, blob []byte) {
if len(path) == 1 {
root = common.CopyBytes(blob)
}
})
for i := range keys {
if err := pst.Update(keys[i], vals[i]); err != nil {
t.Fatalf("partition update: %v", err)
}
}
pst.Hash()
return root
}
// nodeRec is a committed trie node as captured by collect.
type nodeRec struct {
	hash common.Hash // node hash as reported to the callback
	blob []byte      // private copy of the node blob
}
// collect runs the given updater with a recording callback and returns every
// node reported through it, keyed by the node path. Blobs are copied before
// being stored.
func collect(update func(onNode OnTrieNode)) map[string]nodeRec {
	recorded := map[string]nodeRec{}
	record := func(path []byte, hash common.Hash, blob []byte) {
		recorded[string(path)] = nodeRec{hash: hash, blob: common.CopyBytes(blob)}
	}
	update(record)
	return recorded
}
// nodeKind classifies a node blob as "branch" (17 elements), or as "leaf" or
// "extension" (2 elements, told apart by the terminator flag of the key). Any
// decoding failure or other element count fails the test.
func nodeKind(t *testing.T, blob []byte) string {
	t.Helper()

	elems, err := decodeNodeElements(blob)
	if err != nil {
		t.Fatalf("decode node: %v", err)
	}
	count := len(elems)
	if count == 17 {
		return "branch"
	}
	if count != 2 {
		t.Fatalf("unexpected element count %d", count)
		return ""
	}
	compact, _, err := rlp.SplitString(elems[0])
	if err != nil {
		t.Fatalf("split key: %v", err)
	}
	if hasTerm(compactToHex(compact)) {
		return "leaf"
	}
	return "extension"
}
// TestPartialStackTrieMatchesFullSubtree proves that, for every shape the
// partition subtree root can take, the nodes emitted by a PartialStackTrie for
// partition n are byte-for-byte identical (path, hash, blob) to the [n]-subtree
// of the full trie built from the same keys.
//
// Each case builds two node sets with collect: a reference set from a plain
// StackTrie over the partition keys plus one key from another partition, and a
// subject set from a PartialStackTrie over the partition keys alone. The
// reference nodes whose path starts with n must match the subject set exactly.
func TestPartialStackTrieMatchesFullSubtree(t *testing.T) {
	const n = byte(3)

	// A single key in another partition (first nibble 9 > 3, so it sorts last)
	// forces the full trie's root to be a branch, giving a clean [n]-subtree.
	otherKey := mkKey("9")
	otherVal := bytes.Repeat([]byte{0xff}, 32)

	cases := []struct {
		name     string
		keys     []string // partition-n key prefixes (first nibble must be 3)
		wantRoot string   // expected shape of the partition subtree root
	}{
		{"single-leaf", []string{"3abc"}, "leaf"},
		{"branch-root", []string{"30", "37", "3a"}, "branch"},
		{"extension-root", []string{"3110", "3115", "311a"}, "extension"},
		{"mixed", []string{"30", "3105", "310a", "3f00", "3f0f"}, "branch"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			keys, vals := sortedPairs(tc.keys)

			// Reference: full trie over the partition-n keys plus the other-partition key.
			full := collect(func(onNode OnTrieNode) {
				st := NewStackTrie(onNode)
				for i := range keys {
					if err := st.Update(keys[i], vals[i]); err != nil {
						t.Fatalf("full update: %v", err)
					}
				}
				if err := st.Update(otherKey, otherVal); err != nil {
					t.Fatalf("full update (other): %v", err)
				}
				st.Hash()
			})

			// Subject: PartialStackTrie over just the partition-n keys.
			var partRoot common.Hash
			part := collect(func(onNode OnTrieNode) {
				pst := NewPartialStackTrie(n, onNode)
				for i := range keys {
					if err := pst.Update(keys[i], vals[i]); err != nil {
						t.Fatalf("partial update: %v", err)
					}
				}
				partRoot = pst.Hash()
			})

			// The subtree root must live at path [n] in the full trie (i.e. it is
			// hash-referenced, not inlined) and its hash must match Hash().
			rootRec, ok := full[string([]byte{n})]
			if !ok {
				t.Fatalf("full trie has no node at path [%d]", n)
			}
			if rootRec.hash != partRoot {
				t.Fatalf("partition root %x != full subtree root %x", partRoot, rootRec.hash)
			}
			if got := nodeKind(t, rootRec.blob); got != tc.wantRoot {
				t.Fatalf("subtree root kind = %s, want %s", got, tc.wantRoot)
			}

			// Every full-trie node under [n] must equal the partition's node, and
			// the partition must emit no node outside [n].
			want := make(map[string]nodeRec)
			for p, rec := range full {
				if len(p) >= 1 && p[0] == n {
					want[p] = rec
				}
			}
			// Equal counts plus a match for every wanted path imply that the
			// partition emitted no additional nodes.
			if len(want) != len(part) {
				t.Fatalf("node count: full subtree=%d, partition=%d", len(want), len(part))
			}
			for p, rec := range want {
				got, ok := part[p]
				if !ok {
					t.Fatalf("partition missing node at path %x", []byte(p))
				}
				if got.hash != rec.hash || !bytes.Equal(got.blob, rec.blob) {
					t.Fatalf("node mismatch at path %x", []byte(p))
				}
			}
		})
	}
}
// TestPartialStackTrieWrongNibble verifies that a partition builder for nibble
// 3 refuses a key whose leading nibble is 4.
func TestPartialStackTrieWrongNibble(t *testing.T) {
	builder := NewPartialStackTrie(3, nil)
	foreignKey := mkKey("4abc")

	err := builder.Update(foreignKey, []byte{0x01})
	if err != nil {
		return
	}
	t.Fatal("expected error for key outside the partition, got nil")
}
// TestMountPartitionRoot verifies that mounting the leading nibble back onto a
// lone partition's subtree root yields the canonical trie root, for each root
// shape (leaf, extension, branch). The branch case is the one not reachable
// through the triedb single-partition tests.
func TestMountPartitionRoot(t *testing.T) {
	const n = byte(3)

	type mountCase struct {
		name         string
		keys         []string
		wantOrphaned bool
	}
	for _, tc := range []mountCase{
		{name: "leaf", keys: []string{"3abc"}, wantOrphaned: true},
		{name: "extension", keys: []string{"3110", "3115", "311a"}, wantOrphaned: true},
		{name: "branch", keys: []string{"30", "37", "3a"}, wantOrphaned: false},
	} {
		t.Run(tc.name, func(t *testing.T) {
			keys, vals := sortedPairs(tc.keys)

			// Canonical root: a plain trie over the same keys. They all share
			// nibble n, so there is no top-level branch to collapse.
			canonical := NewStackTrie(nil)
			for i, key := range keys {
				if err := canonical.Update(key, vals[i]); err != nil {
					t.Fatalf("ref update: %v", err)
				}
			}
			want := canonical.Hash()

			subtree := partitionRoot(t, n, keys, vals)
			got, mounted, orphaned, err := MountPartitionRoot(subtree, n)
			switch {
			case err != nil:
				t.Fatalf("MountPartitionRoot: %v", err)
			case orphaned != tc.wantOrphaned:
				t.Fatalf("isOrphaned = %v, want %v", orphaned, tc.wantOrphaned)
			case got != want:
				t.Fatalf("mounted root %x, want %x", got, want)
			case crypto.Keccak256Hash(mounted) != got:
				t.Fatalf("returned blob does not hash to the returned root")
			}
		})
	}
}
// TestAssembleBranch verifies that a top-level branch assembled from the
// hashes of two partition subtree roots has the same root hash as a plain
// trie built over the union of both partitions' keys.
func TestAssembleBranch(t *testing.T) {
	keys3, vals3 := sortedPairs([]string{"30", "37", "3a"})
	keys7, vals7 := sortedPairs([]string{"71", "75"})

	// Canonical root over both partitions (all "3..." sort before all "7...").
	canonical := NewStackTrie(nil)
	insert := func(keys, vals [][]byte) {
		for i, key := range keys {
			if err := canonical.Update(key, vals[i]); err != nil {
				t.Fatalf("ref update: %v", err)
			}
		}
	}
	insert(keys3, vals3)
	insert(keys7, vals7)
	want := canonical.Hash()

	// Slots 3 and 7 reference their partition's subtree root by hash.
	root3 := partitionRoot(t, 3, keys3, vals3)
	root7 := partitionRoot(t, 7, keys7, vals7)
	var children [17][]byte
	children[3] = crypto.Keccak256(root3)
	children[7] = crypto.Keccak256(root7)

	blob, got, err := AssembleBranch(children)
	switch {
	case err != nil:
		t.Fatalf("AssembleBranch: %v", err)
	case got != want:
		t.Fatalf("assembled root %x, want %x", got, want)
	case crypto.Keccak256Hash(blob) != got:
		t.Fatalf("returned blob does not hash to the returned root")
	}
}

View file

@ -17,92 +17,540 @@
package triedb package triedb
import ( import (
"bytes"
"context"
"encoding/binary"
"fmt" "fmt"
"math/big"
"sync/atomic"
"time"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/ethdb/memorydb"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie"
"github.com/ethereum/go-ethereum/triedb/internal" "github.com/ethereum/go-ethereum/triedb/internal"
"golang.org/x/sync/errgroup"
) )
// kvAccountIterator wraps an ethdb.Iterator to iterate over account snapshot // ErrCancelled is returned when GenerateTrie is aborted via its cancel
// entries in the database, implementing internal.AccountIterator. // channel before completing.
type kvAccountIterator struct { var ErrCancelled = internal.ErrCancelled
it ethdb.Iterator
hash common.Hash // GenerateStats reports per-run counters from GenerateTrie. Scanned is
// the number of accounts walked, Updated is how many had a stale Root
// field that was rewritten to match the recomputed storage root, and
// Deleted is the number of dangling storage slots removed.
type GenerateStats struct {
Scanned int64
Updated int64
Deleted int64
} }
func newKVAccountIterator(db ethdb.Iteratee) *kvAccountIterator { // numPartitions is the number of slices the account hash space is divided
it := rawdb.NewKeyLengthIterator( // into by GenerateTrie.
db.NewIterator(rawdb.SnapshotAccountPrefix, nil), const numPartitions = 16
len(rawdb.SnapshotAccountPrefix)+common.HashLength,
) // Each partition covers 1/16 of the account hash space. We track progress
return &kvAccountIterator{it: it} // by interpreting the top 8 bytes of an account hash as a uint64, so each
// partition spans 2^64 / 16 = 2^60. partitionFinished is stored in a
// partition's position when it completes.
const (
partitionRangeSize = uint64(1) << 60
partitionFinished = ^uint64(0)
)
// rangeIterators bundles the per-partition account and storage iterators.
type rangeIterators struct {
db ethdb.Database
acct *internal.HoldableIterator
stor *internal.HoldableIterator
} }
func (it *kvAccountIterator) Next() bool { func openRangeIterators(db ethdb.Database, start common.Hash) *rangeIterators {
if !it.it.Next() { return &rangeIterators{
return false db: db,
acct: openFlatIterator(db, rawdb.SnapshotAccountPrefix, start[:], common.HashLength),
stor: openFlatIterator(db, rawdb.SnapshotStoragePrefix, start[:], 2*common.HashLength),
} }
key := it.it.Key()
copy(it.hash[:], key[len(rawdb.SnapshotAccountPrefix):])
return true
} }
func (it *kvAccountIterator) Hash() common.Hash { return it.hash } // reopen releases both iterators and reopens them at their current
func (it *kvAccountIterator) Account() []byte { return it.it.Value() } // positions. Invoked after each batch flush so pebble compactions aren't
func (it *kvAccountIterator) Error() error { return it.it.Error() } // blocked by long-lived iterator snapshots. Follows the same pattern as
func (it *kvAccountIterator) Release() { it.it.Release() } // triedb/pathdb/context.go.
func (r *rangeIterators) reopen() {
// kvStorageIterator wraps an ethdb.Iterator to iterate over storage snapshot r.acct = reopenFlatIterator(r.db, r.acct, rawdb.SnapshotAccountPrefix, common.HashLength)
// entries for a specific account, implementing internal.StorageIterator. r.stor = reopenFlatIterator(r.db, r.stor, rawdb.SnapshotStoragePrefix, 2*common.HashLength)
type kvStorageIterator struct {
it ethdb.Iterator
hash common.Hash
} }
func newKVStorageIterator(db ethdb.Iteratee, accountHash common.Hash) *kvStorageIterator { func (r *rangeIterators) release() {
it := rawdb.IterateStorageSnapshots(db, accountHash) r.acct.Release()
return &kvStorageIterator{it: it} r.stor.Release()
} }
func (it *kvStorageIterator) Next() bool { // flushIfFull writes and resets the batch once it grows past IdealBatchSize,
if !it.it.Next() { // then reopens the iterators.
return false func (r *rangeIterators) flushIfFull(batch ethdb.Batch, where string) error {
if batch.ValueSize() <= ethdb.IdealBatchSize {
return nil
} }
key := it.it.Key() if err := batch.Write(); err != nil {
copy(it.hash[:], key[len(rawdb.SnapshotStoragePrefix)+common.HashLength:]) return fmt.Errorf("flush batch (%s): %w", where, err)
return true
}
func (it *kvStorageIterator) Hash() common.Hash { return it.hash }
func (it *kvStorageIterator) Slot() []byte { return it.it.Value() }
func (it *kvStorageIterator) Error() error { return it.it.Error() }
func (it *kvStorageIterator) Release() { it.it.Release() }
// GenerateTrie rebuilds all tries (storage + account) from flat snapshot data
// in the database. It reads account and storage snapshots from the KV store,
// builds tries using StackTrie with streaming node writes, and verifies the
// computed state root matches the expected root.
func GenerateTrie(db ethdb.Database, scheme string, root common.Hash) error {
acctIt := newKVAccountIterator(db)
defer acctIt.Release()
got, err := internal.GenerateTrieRoot(db, scheme, acctIt, common.Hash{}, internal.StackTrieGenerate, func(dst ethdb.KeyValueWriter, accountHash, codeHash common.Hash, stat *internal.GenerateStats) (common.Hash, error) {
storageIt := newKVStorageIterator(db, accountHash)
defer storageIt.Release()
hash, err := internal.GenerateTrieRoot(dst, scheme, storageIt, accountHash, internal.StackTrieGenerate, nil, stat, false)
if err != nil {
return common.Hash{}, err
}
return hash, nil
}, internal.NewGenerateStats(), true)
if err != nil {
return err
}
if got != root {
return fmt.Errorf("state root mismatch: got %x, want %x", got, root)
} }
batch.Reset()
r.reopen()
return nil return nil
} }
// openFlatIterator returns a HoldableIterator over the entries stored under
// prefix, beginning at start (which is expressed relative to prefix). The raw
// database iterator is wrapped in a key-length filter for keys of
// len(prefix)+suffixLen bytes before being made holdable.
func openFlatIterator(db ethdb.Database, prefix, start []byte, suffixLen int) *internal.HoldableIterator {
	keyLen := len(prefix) + suffixLen
	filtered := rawdb.NewKeyLengthIterator(db.NewIterator(prefix, start), keyLen)
	return internal.NewHoldableIterator(filtered)
}
// reopenFlatIterator swaps `old` for a fresh HoldableIterator positioned at
// the entry `old` would have served next.
//
// Three outcomes are possible:
//   - `old` has another entry: its key is copied, `old` is released and a new
//     iterator seeked to that key is returned.
//   - `old` is cleanly exhausted: it is released and an empty iterator is
//     returned.
//   - `old` stopped because of an error: it is returned as-is, unreleased.
//     Replacing it with an empty iterator would discard the failure, and the
//     caller's subsequent Error() check would report a clean end of data,
//     silently truncating the scan.
//
// In every case the caller owns, and must eventually release, the returned
// iterator.
func reopenFlatIterator(db ethdb.Database, old *internal.HoldableIterator, prefix []byte, suffixLen int) *internal.HoldableIterator {
	if !old.Next() {
		if old.Error() != nil {
			// Keep the failed iterator so the error stays observable.
			return old
		}
		old.Release()
		return internal.NewHoldableIterator(memorydb.New().NewIterator(nil, nil))
	}
	// pebble's Key() slice is invalidated by Release. Copy first so the new
	// iterator's lower bound isn't seeded from freed memory.
	next := common.CopyBytes(old.Key())
	old.Release()
	return openFlatIterator(db, prefix, next[len(prefix):], suffixLen)
}
// generatePartition walks accounts whose first nibble equals `partition`,
// reconciling each account's Root with its flat storage and building
// both per-account storage subtries and the partition's slice of the
// account trie. Returns the partition's stripped subtree root blob, or
// nil if the partition had no accounts at all.
//
// Parameters:
//   - ctx, cancel: either one aborts the walk. A closed cancel channel yields
//     ErrCancelled, a finished ctx yields ctx.Err(). Both are polled once per
//     account and once per storage slot.
//   - rangeStart, rangeEnd: inclusive bounds of the partition. The iterators
//     are opened at rangeStart; rangeEnd is enforced by comparing each key.
//   - scanned, updated, deleted: shared counters, incremented respectively per
//     account visited, per account whose Root was rewritten, and per dangling
//     storage slot removed.
//   - pos: receives the top 8 bytes (big-endian) of the account hash currently
//     being processed.
//
// All writes go through a single batch that is flushed whenever it grows past
// the ideal size (see flushIfFull) and once more before returning.
func generatePartition(ctx context.Context, cancel <-chan struct{}, db ethdb.Database, scheme string, partition byte, rangeStart, rangeEnd common.Hash, scanned, updated, deleted *atomic.Int64, pos *atomic.Uint64) ([]byte, error) {
iters := openRangeIterators(db, rangeStart)
defer iters.release()
batch := db.NewBatchWithSize(ethdb.IdealBatchSize)
// Account-trie builder for this partition. It is fed account keys with
// their leading nibble stripped and emits nodes at their absolute path
// (prefixed with the partition nibble), so they line up with the full
// trie without any post-hoc surgery.
//
// The subtree root is the only node emitted at path [partition]; we both
// persist it (so the top-level branch can reference it) and capture its
// bytes for assembleRoot, which needs them to either reference it or,
// in the single-partition case, fold the leading nibble back in.
var root []byte
acctTrie := trie.NewPartialStackTrie(partition, func(path []byte, hash common.Hash, blob []byte) {
if len(path) == 1 {
root = common.CopyBytes(blob)
}
rawdb.WriteTrieNode(batch, common.Hash{}, path, hash, blob, scheme)
})
// Iterate through all the accounts.
for iters.acct.Next() {
// Bail out promptly on cancellation.
select {
case <-cancel:
return nil, ErrCancelled
case <-ctx.Done():
return nil, ctx.Err()
default:
}
key := iters.acct.Key()
var accountHash common.Hash
copy(accountHash[:], key[len(rawdb.SnapshotAccountPrefix):])
// The iterator is only bounded from below, so stop at the first account
// that lies beyond this partition's (inclusive) upper bound.
if bytes.Compare(accountHash[:], rangeEnd[:]) > 0 {
break
}
scanned.Add(1)
// Publish the current position for progress reporting.
pos.Store(binary.BigEndian.Uint64(accountHash[:8]))
// Decode the account object
account, err := types.FullAccount(iters.acct.Value())
if err != nil {
return nil, fmt.Errorf("decode account %x: %w", accountHash, err)
}
// Build the account's storage trie from the flat storage snapshot.
// StackTrie's onTrieNode callback persists nodes as they finalize.
storageTrie := trie.NewStackTrie(func(path []byte, hash common.Hash, blob []byte) {
rawdb.WriteTrieNode(batch, accountHash, path, hash, blob, scheme)
})
// Compute the storage root by consuming matching slots from the
// shared storage iterator. The inner loop terminates on Hold()
// (slot belongs to a later account) or exhaustion.
//
// lastDanglingAccount remembers the most recent dangling account that
// was logged, so each one is reported once rather than once per slot.
lastDanglingAccount := make([]byte, common.HashLength)
for iters.stor.Next() {
// Re-check cancel.
select {
case <-cancel:
return nil, ErrCancelled
case <-ctx.Done():
return nil, ctx.Err()
default:
}
var (
sk = iters.stor.Key()
storageAccount = sk[len(rawdb.SnapshotStoragePrefix) : len(rawdb.SnapshotStoragePrefix)+common.HashLength]
cmp = bytes.Compare(storageAccount, accountHash[:])
)
// The slot belongs to an account whose hash is smaller than the one
// currently being processed. This should be theoretically impossible,
// so log it loudly and delete the dangling entry from the flat state.
if cmp < 0 {
if !bytes.Equal(lastDanglingAccount, storageAccount) {
copy(lastDanglingAccount, storageAccount)
log.Error("Unexpected storage entries for dangling account", "expected", accountHash, "got", common.BytesToHash(storageAccount))
}
deleted.Add(1)
slotHash := sk[len(rawdb.SnapshotStoragePrefix)+common.HashLength:]
rawdb.DeleteStorageSnapshot(batch, common.BytesToHash(storageAccount), common.BytesToHash(slotHash))
if err := iters.flushIfFull(batch, "dangling"); err != nil {
return nil, err
}
continue
}
// The slot belongs to a later account. We're done with the current
// account's slots, but we don't want to lose this slot. The slot might
// belong to the next iteration of the account for-loop (or a later one).
// Hold() the iterator so the next Next() call will re-serve this same
// entry instead of advancing past it.
if cmp > 0 {
iters.stor.Hold()
break
}
// The slot belongs to this account so we add it to the StackTrie.
slotHash := sk[len(rawdb.SnapshotStoragePrefix)+common.HashLength:]
if err := storageTrie.Update(slotHash, iters.stor.Value()); err != nil {
return nil, fmt.Errorf("storage stack trie update for %x: %w", accountHash, err)
}
if err := iters.flushIfFull(batch, "storage"); err != nil {
return nil, err
}
}
if err := iters.stor.Error(); err != nil {
return nil, fmt.Errorf("storage iterator: %w", err)
}
computed := storageTrie.Hash()
// If account.Root was stale, rewrite the flat-state entry. Then feed
// the account, now with the correct Root, into this partition's
// account trie.
if computed != account.Root {
account.Root = computed
rawdb.WriteAccountSnapshot(batch, accountHash, types.SlimAccountRLP(*account))
updated.Add(1)
}
fullAccount, err := rlp.EncodeToBytes(account)
if err != nil {
return nil, fmt.Errorf("encode account %x: %w", accountHash, err)
}
if err := acctTrie.Update(accountHash[:], fullAccount); err != nil {
return nil, fmt.Errorf("account stack trie update for %x: %w", accountHash, err)
}
if err := iters.flushIfFull(batch, "account"); err != nil {
return nil, err
}
}
if err := iters.acct.Error(); err != nil {
return nil, fmt.Errorf("account iterator: %w", err)
}
// The account iterator is exhausted (or has advanced past this partition),
// but the storage iterator may still hold slots whose account hash falls
// within this partition's range. Those slots belong to no existing account
// and should be cleared.
lastDanglingTail := make([]byte, common.HashLength)
for iters.stor.Next() {
select {
case <-cancel:
return nil, ErrCancelled
case <-ctx.Done():
return nil, ctx.Err()
default:
}
sk := iters.stor.Key()
acct := sk[len(rawdb.SnapshotStoragePrefix) : len(rawdb.SnapshotStoragePrefix)+common.HashLength]
// Slots past the partition's upper bound are left alone.
if bytes.Compare(acct, rangeEnd[:]) > 0 {
break
}
if !bytes.Equal(lastDanglingTail, acct) {
copy(lastDanglingTail, acct)
log.Error("Unexpected storage entries for dangling account", "addrhash", common.BytesToHash(acct))
}
deleted.Add(1)
slotHash := sk[len(rawdb.SnapshotStoragePrefix)+common.HashLength:]
rawdb.DeleteStorageSnapshot(batch, common.BytesToHash(acct), common.BytesToHash(slotHash))
if err := iters.flushIfFull(batch, "dangling tail"); err != nil {
return nil, err
}
}
if err := iters.stor.Error(); err != nil {
return nil, fmt.Errorf("storage iterator (dangling): %w", err)
}
// Finalize the partition's account trie. For a non-empty partition this
// emits the subtree root at path [partition], populating root. An empty
// partition never emits any node and leaves root at nil.
acctTrie.Hash()
if err := batch.Write(); err != nil {
return nil, fmt.Errorf("final partition batch write: %w", err)
}
return root, nil
}
// hashRanges splits the 256-bit hash space into `total` consecutive,
// inclusive [start, end] pairs of equal width. Whatever the integer division
// leaves over is folded into the final pair, which always ends at
// common.MaxHash, so no hash is left uncovered.
func hashRanges(total int) [][2]common.Hash {
	// span is the distance from a range's first hash to its inclusive last
	// hash: (2^256 / total) - 1.
	span := new(big.Int).Exp(common.Big2, common.Big256, nil)
	span.Div(span, big.NewInt(int64(total)))
	span.Sub(span, common.Big1)

	var (
		out   = make([][2]common.Hash, total)
		first common.Hash
	)
	for i := 0; i < total; i++ {
		last := common.MaxHash
		if i != total-1 {
			last = common.BigToHash(new(big.Int).Add(first.Big(), span))
		}
		out[i] = [2]common.Hash{first, last}
		// The following range begins right after this one ends.
		first = common.BigToHash(new(big.Int).Add(last.Big(), common.Big1))
	}
	return out
}
// GenerateTrie reconstructs every trie (the account trie and each account's
// storage trie) from the flat snapshot data held in db. The account hash
// space is cut into numPartitions slices and each slice that still needs
// work is handled by its own goroutine running generatePartition. When all
// of them have finished, the partition subtree roots are combined by
// assembleRoot and the resulting hash is compared against root.
//
// Resume: a partition that already carries a "done" marker is not re-run.
// Its subtree root blob is taken from the marker instead. Markers are
// written as each partition completes and are all removed once the
// assembled root has been verified.
//
// Closing cancel aborts the in-flight partitions; the error is whatever
// generatePartition reports for that case. The returned GenerateStats only
// reflects partitions executed by this call (skipped partitions contribute
// nothing) and is the zero value whenever an error is returned.
func GenerateTrie(db ethdb.Database, scheme string, root common.Hash, cancel <-chan struct{}) (GenerateStats, error) {
	var (
		begin    = time.Now()
		scanned  atomic.Int64
		updated  atomic.Int64
		deleted  atomic.Int64
		progress [numPartitions]atomic.Uint64

		// subtrees[i] is the subtree root node of partition i. It stays nil
		// for a partition that turned out to be empty.
		subtrees [numPartitions][]byte
	)
	// Periodic progress logging, stopped when this function returns.
	stop := make(chan struct{})
	go tickProgress(stop, begin, &scanned, &updated, &progress)
	defer close(stop)

	eg, ctx := errgroup.WithContext(context.Background())
	for i, bounds := range hashRanges(numPartitions) {
		partition := byte(i)
		first, last := bounds[0], bounds[1]

		// A marker left by an earlier run means this partition is complete;
		// reuse the blob it stored rather than walking the range again.
		if blob, ok := rawdb.ReadGenerateTriePartitionDone(db, partition); ok {
			subtrees[partition] = blob
			progress[partition].Store(partitionFinished)
			continue
		}
		eg.Go(func() error {
			began := time.Now()
			blob, err := generatePartition(ctx, cancel, db, scheme, partition, first, last, &scanned, &updated, &deleted, &progress[partition])
			if err != nil {
				return err
			}
			log.Info("Partition done", "partition", partition, "elapsed", common.PrettyDuration(time.Since(began)))
			progress[partition].Store(partitionFinished)
			subtrees[partition] = blob

			// generatePartition has already flushed its final batch, so the
			// marker only reaches disk after the partition's own writes.
			rawdb.WriteGenerateTriePartitionDone(db, partition, blob)
			return nil
		})
	}
	// Block until every scheduled partition has finished or one has failed.
	if err := eg.Wait(); err != nil {
		return GenerateStats{}, err
	}
	// Build the top-level root out of the partition blobs and make sure it
	// is the root the caller asked for.
	got, err := assembleRoot(db, scheme, subtrees)
	if err != nil {
		return GenerateStats{}, fmt.Errorf("assemble root: %w", err)
	}
	if got != root {
		return GenerateStats{}, fmt.Errorf("state root mismatch: got %x, want %x", got, root)
	}
	// Generation succeeded: drop all partition markers in a single batch.
	markers := db.NewBatch()
	for p := 0; p < numPartitions; p++ {
		rawdb.DeleteGenerateTriePartitionDone(markers, byte(p))
	}
	if err := markers.Write(); err != nil {
		return GenerateStats{}, fmt.Errorf("clear partition markers: %w", err)
	}
	stats := GenerateStats{
		Scanned: scanned.Load(),
		Updated: updated.Load(),
		Deleted: deleted.Load(),
	}
	log.Info("Generated state trie", "scanned", stats.Scanned, "updated", stats.Updated, "dangling-slots", stats.Deleted, "elapsed", common.PrettyDuration(time.Since(begin)))
	return stats, nil
}
// assembleRoot derives the state root from the per-partition subtree root
// blobs and persists the top-level node. A nil blob marks a partition without
// any account. Every non-nil blob is the node that generatePartition already
// wrote at path [i], so only the node above them is still missing. What
// happens depends on the number of populated partitions:
//
//   - none: the state is empty. types.EmptyRootHash is returned and nothing
//     is written.
//
//   - exactly one: no top-level branch exists. The lone subtree root gets its
//     leading nibble folded back in through trie.MountPartitionRoot and the
//     resulting node is written at the root path. When the fold reports the
//     node at [n] as orphaned, that node is deleted in the same batch.
//
//   - two or more: a 17-slot branch referencing each populated partition by
//     the hash of its blob is built with trie.AssembleBranch and written at
//     the root path.
func assembleRoot(db ethdb.Database, scheme string, partitionBlobs [numPartitions][]byte) (common.Hash, error) {
	var (
		// count is the number of partitions that produced a subtree root.
		count int
		// lone is the index of the last populated partition; it is only
		// consulted when count == 1.
		lone int
		// children holds the branch slots. Slot 16 is never set here.
		children [17][]byte
	)
	for idx, blob := range partitionBlobs {
		if blob == nil {
			continue
		}
		count++
		lone = idx
		children[idx] = crypto.Keccak256(blob)
	}
	switch count {
	case 0:
		// Empty state.
		return types.EmptyRootHash, nil

	case 1:
		// Single subtree: fold the partition nibble back into its root.
		blob := partitionBlobs[lone]
		rootHash, rootBlob, isOrphaned, err := trie.MountPartitionRoot(blob, byte(lone))
		if err != nil {
			return common.Hash{}, fmt.Errorf("mount partition %d: %w", lone, err)
		}
		batch := db.NewBatch()
		rawdb.WriteTrieNode(batch, common.Hash{}, nil, rootHash, rootBlob, scheme)
		if isOrphaned {
			// The folded root no longer references [lone], so the node that
			// generatePartition stored there must go to keep the on-disk
			// node set equal to the canonical trie.
			rawdb.DeleteTrieNode(batch, common.Hash{}, []byte{byte(lone)}, crypto.Keccak256Hash(blob), scheme)
		}
		return rootHash, batch.Write()

	default:
		// Two or more subtrees: mount them by hash in a branch node. The
		// subtree roots themselves are already persisted at path [i].
		rootBlob, rootHash, err := trie.AssembleBranch(children)
		if err != nil {
			return common.Hash{}, err
		}
		rawdb.WriteTrieNode(db, common.Hash{}, nil, rootHash, rootBlob, scheme)
		return rootHash, nil
	}
}
// tickProgress emits one aggregate progress log line every 30 seconds and
// returns as soon as done is closed. Each tick reads the shared counters and
// the per-partition positions, derives a completion fraction via
// progressFraction and, once that fraction exceeds 0.5%, extrapolates an ETA
// from the time elapsed since start.
//
// NOTE(review): partitions marked finished before any work was done in this
// run (the resume path in GenerateTrie) count as fully complete here, so the
// ETA of a resumed run is presumably optimistic.
func tickProgress(done <-chan struct{}, start time.Time, scanned, updated *atomic.Int64, progress *[numPartitions]atomic.Uint64) {
	const interval = 30 * time.Second

	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		// Wait for the next tick, or leave when asked to stop.
		select {
		case <-done:
			return
		case <-ticker.C:
		}
		var (
			elapsed  = time.Since(start)
			fraction = progressFraction(progress)
			eta      = "n/a"
		)
		if fraction > 0.005 {
			remaining := time.Duration(float64(elapsed) * (1.0/fraction - 1.0))
			eta = common.PrettyDuration(remaining).String()
		}
		log.Info("Generating trie",
			"progress", fmt.Sprintf("%.1f%%", fraction*100),
			"eta", eta,
			"scanned", scanned.Load(),
			"updated", updated.Load(),
			"elapsed", common.PrettyDuration(elapsed),
			"acct/s", uint64(float64(scanned.Load())/elapsed.Seconds()))
	}
}
// progressFraction condenses the per-partition positions into one completion
// estimate in [0, 1]. Each partition contributes:
//
//   - 1 when its position equals partitionFinished,
//   - 0 when its position is zero or not beyond the start of its own range,
//   - otherwise the share of its range covered so far, capped at 1.
//
// The result is the mean of those contributions. Account hashes are Keccak
// outputs and hence uniformly spread, so the position within the keyspace
// serves as a proxy for the amount of work done.
func progressFraction(progress *[numPartitions]atomic.Uint64) float64 {
	var sum float64
	for i := 0; i < numPartitions; i++ {
		pos := progress[i].Load()
		if pos == partitionFinished {
			sum += 1.0
			continue
		}
		// A zero position means "not started"; it can never exceed base.
		base := uint64(i) * partitionRangeSize
		if pos <= base {
			continue
		}
		covered := pos - base
		if covered > partitionRangeSize {
			covered = partitionRangeSize
		}
		sum += float64(covered) / float64(partitionRangeSize)
	}
	return sum / float64(numPartitions)
}

View file

@ -18,12 +18,17 @@ package triedb
import ( import (
"bytes" "bytes"
"context"
"math/big"
"sort" "sort"
"sync/atomic"
"testing" "testing"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/trie"
"github.com/holiman/uint256" "github.com/holiman/uint256"
@ -60,8 +65,8 @@ func buildExpectedRoot(t *testing.T, accounts []testAccount) common.Hash {
return acctTrie.Hash() return acctTrie.Hash()
} }
// computeStorageRoot computes the storage trie root from sorted slots. // computeStorageRootFromSlots computes the storage trie root from sorted slots.
func computeStorageRoot(slots []testSlot) common.Hash { func computeStorageRootFromSlots(slots []testSlot) common.Hash {
sort.Slice(slots, func(i, j int) bool { sort.Slice(slots, func(i, j int) bool {
return bytes.Compare(slots[i].hash[:], slots[j].hash[:]) < 0 return bytes.Compare(slots[i].hash[:], slots[j].hash[:]) < 0
}) })
@ -74,7 +79,7 @@ func computeStorageRoot(slots []testSlot) common.Hash {
func TestGenerateTrieEmpty(t *testing.T) { func TestGenerateTrieEmpty(t *testing.T) {
db := rawdb.NewMemoryDatabase() db := rawdb.NewMemoryDatabase()
if err := GenerateTrie(db, rawdb.HashScheme, types.EmptyRootHash); err != nil { if _, err := GenerateTrie(db, rawdb.HashScheme, types.EmptyRootHash, nil); err != nil {
t.Fatalf("GenerateTrie on empty state failed: %v", err) t.Fatalf("GenerateTrie on empty state failed: %v", err)
} }
} }
@ -107,19 +112,17 @@ func TestGenerateTrieAccountsOnly(t *testing.T) {
} }
root := buildExpectedRoot(t, accounts) root := buildExpectedRoot(t, accounts)
if err := GenerateTrie(db, rawdb.HashScheme, root); err != nil { if _, err := GenerateTrie(db, rawdb.HashScheme, root, nil); err != nil {
t.Fatalf("GenerateTrie failed: %v", err) t.Fatalf("GenerateTrie failed: %v", err)
} }
} }
func TestGenerateTrieWithStorage(t *testing.T) { func TestGenerateTrieWithStorage(t *testing.T) {
db := rawdb.NewMemoryDatabase()
slots := []testSlot{ slots := []testSlot{
{hash: common.HexToHash("0xaa"), value: []byte{0x01, 0x02, 0x03}}, {hash: common.HexToHash("0xaa"), value: []byte{0x01, 0x02, 0x03}},
{hash: common.HexToHash("0xbb"), value: []byte{0x04, 0x05, 0x06}}, {hash: common.HexToHash("0xbb"), value: []byte{0x04, 0x05, 0x06}},
} }
storageRoot := computeStorageRoot(slots) storageRoot := computeStorageRootFromSlots(slots)
accounts := []testAccount{ accounts := []testAccount{
{ {
@ -142,20 +145,24 @@ func TestGenerateTrieWithStorage(t *testing.T) {
}, },
}, },
} }
// Write account snapshots
for _, a := range accounts {
rawdb.WriteAccountSnapshot(db, a.hash, types.SlimAccountRLP(a.account))
}
// Write storage snapshots
for _, a := range accounts {
for _, s := range a.storage {
rawdb.WriteStorageSnapshot(db, a.hash, s.hash, s.value)
}
}
root := buildExpectedRoot(t, accounts) root := buildExpectedRoot(t, accounts)
if err := GenerateTrie(db, rawdb.HashScheme, root); err != nil { for _, scheme := range []string{rawdb.HashScheme, rawdb.PathScheme} {
t.Fatalf("GenerateTrie failed: %v", err) t.Run(scheme, func(t *testing.T) {
db := rawdb.NewMemoryDatabase()
for _, a := range accounts {
rawdb.WriteAccountSnapshot(db, a.hash, types.SlimAccountRLP(a.account))
for _, s := range a.storage {
rawdb.WriteStorageSnapshot(db, a.hash, s.hash, s.value)
}
}
if _, err := GenerateTrie(db, scheme, root, nil); err != nil {
t.Fatalf("GenerateTrie failed: %v", err)
}
if scheme == rawdb.PathScheme {
assertCanonicalNodes(t, db, accounts)
}
})
} }
} }
@ -171,8 +178,615 @@ func TestGenerateTrieRootMismatch(t *testing.T) {
rawdb.WriteAccountSnapshot(db, common.HexToHash("0x01"), types.SlimAccountRLP(acct)) rawdb.WriteAccountSnapshot(db, common.HexToHash("0x01"), types.SlimAccountRLP(acct))
wrongRoot := common.HexToHash("0xdeadbeef") wrongRoot := common.HexToHash("0xdeadbeef")
err := GenerateTrie(db, rawdb.HashScheme, wrongRoot) _, err := GenerateTrie(db, rawdb.HashScheme, wrongRoot, nil)
if err == nil { if err == nil {
t.Fatal("expected error for root mismatch, got nil") t.Fatal("expected error for root mismatch, got nil")
} }
} }
// TestGenerateTrieFixesStaleRoots writes flat state whose on-disk account
// Root fields are a mix of stale non-empty, stale empty and correct values,
// then checks that GenerateTrie arrives at the state root computed from the
// correct Roots and that it reports exactly the stale accounts as updated.
func TestGenerateTrieFixesStaleRoots(t *testing.T) {
	const n = 300
	accounts := make([]testAccount, 0, n)
	for i := 0; i < n; i++ {
		addr := common.BytesToAddress([]byte{byte(i >> 8), byte(i)})
		hash := crypto.Keccak256Hash(addr[:])
		acc := testAccount{
			hash: hash,
			account: types.StateAccount{
				Nonce:    uint64(i),
				Balance:  uint256.NewInt(uint64(i + 1)),
				Root:     types.EmptyRootHash,
				CodeHash: types.EmptyCodeHash.Bytes(),
			},
		}
		// Every third account has no storage; the rest get slots.
		if i%3 != 0 {
			acc.storage = []testSlot{
				{hash: common.BytesToHash([]byte{byte(i), 0xaa}), value: []byte{byte(i), 0x01}},
				{hash: common.BytesToHash([]byte{byte(i), 0xbb}), value: []byte{byte(i), 0x02}},
			}
			acc.account.Root = computeStorageRootFromSlots(acc.storage)
		}
		accounts = append(accounts, acc)
	}
	// Expected state root with all Roots correct.
	expectedRoot := buildExpectedRoot(t, accounts)

	for _, scheme := range []string{rawdb.HashScheme, rawdb.PathScheme} {
		t.Run(scheme, func(t *testing.T) {
			db := rawdb.NewMemoryDatabase()

			// Write flat state. Storage-bearing accounts rotate through three
			// on-disk Root states that GenerateTrie must all bring into
			// alignment:
			//   - stale non-empty Root
			//   - stale empty Root
			//   - correct Root
			//
			// The rotation is keyed on the number of storage-bearing accounts
			// seen so far rather than on the slice index: those accounts were
			// created with i%3 != 0, so an index-based rotation would only
			// reach every state if the slice happened to have been reordered.
			var (
				withStorage int
				wantUpdated int64
			)
			for i, a := range accounts {
				for _, s := range a.storage {
					rawdb.WriteStorageSnapshot(db, a.hash, s.hash, s.value)
				}
				onDisk := a.account
				if len(a.storage) > 0 {
					switch withStorage % 3 {
					case 0:
						onDisk.Root = common.BytesToHash([]byte{byte(i), 0xde, 0xad})
						wantUpdated++
					case 1:
						onDisk.Root = types.EmptyRootHash
						wantUpdated++
					}
					withStorage++
				}
				rawdb.WriteAccountSnapshot(db, a.hash, types.SlimAccountRLP(onDisk))
			}
			stats, err := GenerateTrie(db, scheme, expectedRoot, nil)
			if err != nil {
				t.Fatalf("GenerateTrie failed: %v", err)
			}
			// Only the accounts written with a stale Root may be rewritten.
			if stats.Updated != wantUpdated {
				t.Errorf("Updated counter = %d, want %d", stats.Updated, wantUpdated)
			}
			if scheme == rawdb.PathScheme {
				assertCanonicalNodes(t, db, accounts)
			}
		})
	}
}
// TestGenerateTrieCancel checks that GenerateTrie gives up with ErrCancelled
// when it is handed a cancel channel that is already closed.
func TestGenerateTrieCancel(t *testing.T) {
	t.Parallel()

	// Populate the flat state with a hundred storage-less accounts.
	db := rawdb.NewMemoryDatabase()
	for i := 0; i < 100; i++ {
		addr := common.BytesToAddress([]byte{byte(i)})
		account := types.StateAccount{
			Balance:  uint256.NewInt(1),
			Root:     types.EmptyRootHash,
			CodeHash: types.EmptyCodeHash[:],
		}
		rawdb.WriteAccountSnapshot(db, crypto.Keccak256Hash(addr[:]), types.SlimAccountRLP(account))
	}
	// Closed before the call, so cancellation is visible from the start.
	cancel := make(chan struct{})
	close(cancel)

	_, err := GenerateTrie(db, rawdb.HashScheme, common.Hash{}, cancel)
	if err != ErrCancelled {
		t.Fatalf("expected ErrCancelled, got %v", err)
	}
}
// TestGenerateTrieOrphanStorage exercises dangling-slot cleanup: flat storage
// entries for an accountHash that has no corresponding account snapshot must
// be deleted, regardless of whether they sit before, between, or after the
// live accounts within a partition. The state root must match and the
// Deleted counter must reflect every dangling entry.
func TestGenerateTrieOrphanStorage(t *testing.T) {
db := rawdb.NewMemoryDatabase()
// Two legitimate accounts in the same partition (first nibble 0x5) so
// orphans can be placed before, between, and after them in the shared
// per-partition storage iterator.
liveA := common.HexToHash("0x5300000000000000000000000000000000000000000000000000000000000000")
liveB := common.HexToHash("0x5900000000000000000000000000000000000000000000000000000000000000")
slotsA := []testSlot{{hash: common.HexToHash("0xaa"), value: []byte{0xa1}}}
slotsB := []testSlot{{hash: common.HexToHash("0xbb"), value: []byte{0xb1}}}
accounts := []testAccount{
{
hash: liveA,
account: types.StateAccount{
Nonce: 1,
Balance: uint256.NewInt(1),
Root: computeStorageRootFromSlots(slotsA),
CodeHash: types.EmptyCodeHash.Bytes(),
},
storage: slotsA,
},
{
hash: liveB,
account: types.StateAccount{
Nonce: 2,
Balance: uint256.NewInt(2),
Root: computeStorageRootFromSlots(slotsB),
CodeHash: types.EmptyCodeHash.Bytes(),
},
storage: slotsB,
},
}
for _, a := range accounts {
rawdb.WriteAccountSnapshot(db, a.hash, types.SlimAccountRLP(a.account))
for _, s := range a.storage {
rawdb.WriteStorageSnapshot(db, a.hash, s.hash, s.value)
}
}
// Dangling slots at three positions within partition 5:
// before liveA, between liveA and liveB, after liveB.
orphans := []struct {
account common.Hash
slots []testSlot
}{
{
account: common.HexToHash("0x5000000000000000000000000000000000000000000000000000000000000000"),
slots: []testSlot{
{hash: common.HexToHash("0x11"), value: []byte{0x01}},
{hash: common.HexToHash("0x22"), value: []byte{0x02}},
},
},
{
account: common.HexToHash("0x5600000000000000000000000000000000000000000000000000000000000000"),
slots: []testSlot{{hash: common.HexToHash("0x33"), value: []byte{0x03}}},
},
{
account: common.HexToHash("0x5d00000000000000000000000000000000000000000000000000000000000000"),
slots: []testSlot{
{hash: common.HexToHash("0x44"), value: []byte{0x04}},
{hash: common.HexToHash("0x55"), value: []byte{0x05}},
},
},
}
var totalOrphans int64
for _, o := range orphans {
for _, s := range o.slots {
rawdb.WriteStorageSnapshot(db, o.account, s.hash, s.value)
totalOrphans++
}
}
expectedRoot := buildExpectedRoot(t, accounts)
stats, err := GenerateTrie(db, rawdb.HashScheme, expectedRoot, nil)
if err != nil {
t.Fatalf("GenerateTrie with orphan storage failed: %v", err)
}
if stats.Deleted != totalOrphans {
t.Errorf("Deleted counter = %d, want %d", stats.Deleted, totalOrphans)
}
for _, o := range orphans {
for _, s := range o.slots {
if v := rawdb.ReadStorageSnapshot(db, o.account, s.hash); v != nil {
t.Errorf("dangling slot %x/%x not cleared, got %x", o.account, s.hash, v)
}
}
}
}
// TestGenerateTriePartialResume demonstrates that GenerateTrie really takes
// the resume path for partitions that carry a done marker: the flat state of
// those partitions is removed before the call, so the expected root can only
// be reached by reusing the blobs stored in the markers.
func TestGenerateTriePartialResume(t *testing.T) {
	// Accounts without storage keep the focus on resuming the account trie.
	const n = 200
	accounts := make([]testAccount, n)
	for i := range accounts {
		addr := common.BytesToAddress([]byte{byte(i >> 8), byte(i)})
		accounts[i] = testAccount{
			hash: crypto.Keccak256Hash(addr[:]),
			account: types.StateAccount{
				Nonce:    uint64(i),
				Balance:  uint256.NewInt(uint64(i + 1)),
				Root:     types.EmptyRootHash,
				CodeHash: types.EmptyCodeHash.Bytes(),
			},
		}
	}
	expectedRoot := buildExpectedRoot(t, accounts)

	// evenPartition reports whether index p denotes an even partition.
	evenPartition := func(p int) bool { return p%2 == 0 }

	for _, scheme := range []string{rawdb.HashScheme, rawdb.PathScheme} {
		t.Run(scheme, func(t *testing.T) {
			db := rawdb.NewMemoryDatabase()

			// Step 1: store the account snapshots.
			for _, acc := range accounts {
				rawdb.WriteAccountSnapshot(db, acc.hash, types.SlimAccountRLP(acc.account))
			}
			// Step 2: generate every partition once, which puts the trie
			// nodes on disk, and keep each partition's raw root blob.
			var (
				scanned, updated, deleted atomic.Int64

				blobs = make([][]byte, numPartitions)
			)
			for p, bounds := range hashRanges(numPartitions) {
				var pos atomic.Uint64
				blob, err := generatePartition(context.Background(), nil, db, scheme, byte(p), bounds[0], bounds[1], &scanned, &updated, &deleted, &pos)
				if err != nil {
					t.Fatalf("pre-run partition %d: %v", p, err)
				}
				blobs[p] = blob
			}
			// Step 3: mark only the even partitions as done.
			for p := 0; p < numPartitions; p++ {
				if evenPartition(p) {
					rawdb.WriteGenerateTriePartitionDone(db, byte(p), blobs[p])
				}
			}
			// Step 4: wipe the account snapshots of every even partition.
			// Regenerating such a partition would now see no accounts and
			// yield a nil blob, so only the resume path can still produce
			// the right root.
			numDeleted := 0
			for _, acc := range accounts {
				if !evenPartition(int(acc.hash[0] >> 4)) {
					continue
				}
				rawdb.DeleteAccountSnapshot(db, acc.hash)
				numDeleted++
			}
			if numDeleted == 0 {
				t.Fatal("test setup failure: no accounts fell in even partitions")
			}
			// Step 5: the actual run. It can only succeed if the markers
			// were honoured; otherwise the even partitions come back empty
			// and the root check inside GenerateTrie fails.
			_, err := GenerateTrie(db, scheme, expectedRoot, nil)
			if err != nil {
				t.Fatalf("partial-resume GenerateTrie failed: %v", err)
			}
			// A successful run has to remove all markers.
			for p := 0; p < numPartitions; p++ {
				_, ok := rawdb.ReadGenerateTriePartitionDone(db, byte(p))
				if ok {
					t.Errorf("partition %d marker not cleared after successful resume", p)
				}
			}
			if scheme == rawdb.PathScheme {
				assertCanonicalNodes(t, db, accounts)
			}
		})
	}
}
// TestHashRanges verifies, for several partition counts, that hashRanges
// returns the requested number of ranges, that they start at the zero hash
// and end at common.MaxHash, that no range has its start above its end, and
// that each range begins exactly one past the end of its predecessor.
func TestHashRanges(t *testing.T) {
	for _, total := range []int{1, 2, 16, 256} {
		ranges := hashRanges(total)
		if len(ranges) != total {
			t.Fatalf("total=%d: got %d ranges, want %d", total, len(ranges), total)
		}
		// Outer bounds of the whole cover.
		if first := ranges[0][0]; first != (common.Hash{}) {
			t.Errorf("total=%d: first range starts at %x, want zero", total, first)
		}
		if last := ranges[total-1][1]; last != common.MaxHash {
			t.Errorf("total=%d: last range ends at %x, want MaxHash", total, last)
		}
		for i := range ranges {
			lo, hi := ranges[i][0], ranges[i][1]
			if lo.Big().Cmp(hi.Big()) > 0 {
				t.Errorf("total=%d: range %d malformed: start %x > end %x", total, i, lo, hi)
			}
			if i > 0 {
				// Neighbouring ranges must touch without overlapping.
				gap := new(big.Int).Sub(lo.Big(), ranges[i-1][1].Big())
				if gap.Cmp(common.Big1) != 0 {
					t.Errorf("total=%d: range %d not contiguous with %d (gap=%s)", total, i, i-1, gap)
				}
			}
		}
	}
}
// TestGenerateTriePathSchemeNodeSet runs GenerateTrie under the path scheme and
// compares the account-trie nodes left on disk with the node set a canonical
// StackTrie produces. Comparing roots alone cannot reveal an orphaned node
// left behind by a single populated partition; comparing node sets can.
func TestGenerateTriePathSchemeNodeSet(t *testing.T) {
	// newAccount returns an account at the given hash with empty storage and
	// no code. That keeps the account trie the only trie being built, so the
	// canonical reference is simply a StackTrie over the accounts.
	newAccount := func(hashHex string) testAccount {
		state := types.StateAccount{
			Nonce:    1,
			Balance:  uint256.NewInt(1),
			Root:     types.EmptyRootHash,
			CodeHash: types.EmptyCodeHash.Bytes(),
		}
		return testAccount{hash: common.HexToHash(hashHex), account: state}
	}
	type scenario struct {
		name     string
		accounts []testAccount
	}
	scenarios := []scenario{
		{
			// A single populated partition with a leaf as its subtree root.
			// The node that partition wrote at [5] ends up unreferenced, so
			// GenerateTrie must remove it.
			name: "single account, leaf root",
			accounts: []testAccount{
				newAccount("0x5a00000000000000000000000000000000000000000000000000000000000000"),
			},
		},
		{
			// A single populated partition with an extension as its subtree
			// root. As with the leaf case, the node at [5] becomes
			// unreferenced and has to be removed.
			name: "two accounts sharing two nibbles, extension root",
			accounts: []testAccount{
				newAccount("0x5300000000000000000000000000000000000000000000000000000000000000"),
				newAccount("0x5320000000000000000000000000000000000000000000000000000000000000"),
			},
		},
		{
			// A single populated partition with a branch as its subtree root.
			// The new root keeps referencing [5], so no orphan exists.
			name: "two accounts diverging at second nibble, branch root",
			accounts: []testAccount{
				newAccount("0x5a00000000000000000000000000000000000000000000000000000000000000"),
				newAccount("0x5f00000000000000000000000000000000000000000000000000000000000000"),
			},
		},
		{
			// More than one populated partition. The top branch references
			// every [i], so no orphan exists.
			name: "accounts across multiple partitions",
			accounts: []testAccount{
				newAccount("0x1000000000000000000000000000000000000000000000000000000000000000"),
				newAccount("0x5a00000000000000000000000000000000000000000000000000000000000000"),
				newAccount("0x5f00000000000000000000000000000000000000000000000000000000000000"),
				newAccount("0xc000000000000000000000000000000000000000000000000000000000000000"),
			},
		},
	}
	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			// Seed the flat state GenerateTrie reads from.
			db := rawdb.NewMemoryDatabase()
			for i := range sc.accounts {
				entry := sc.accounts[i]
				rawdb.WriteAccountSnapshot(db, entry.hash, types.SlimAccountRLP(entry.account))
			}
			wantRoot := buildExpectedRoot(t, sc.accounts)
			_, err := GenerateTrie(db, rawdb.PathScheme, wantRoot, nil)
			if err != nil {
				t.Fatalf("GenerateTrie (path scheme) failed: %v", err)
			}
			assertCanonicalNodes(t, db, sc.accounts)
		})
	}
}
// assertCanonicalNodes verifies that the path-scheme trie nodes stored in db
// are exactly the canonical ones: the account-trie nodes emitted by a
// StackTrie fed with the accounts, plus, for every account that has slots,
// the storage-trie nodes emitted by a StackTrie fed with those slots. The
// accounts passed in must already hold their final Root values (i.e. after
// storage-root reconciliation).
func assertCanonicalNodes(t *testing.T, db ethdb.Database, accounts []testAccount) {
	t.Helper()

	// Work on a hash-ordered copy so the caller's slice is left untouched;
	// the account StackTrie below is fed in that order.
	ordered := append(make([]testAccount, 0, len(accounts)), accounts...)
	sort.Slice(ordered, func(i, j int) bool {
		return bytes.Compare(ordered[i].hash[:], ordered[j].hash[:]) < 0
	})

	// Collect the path of every node the reference account trie emits.
	accountPaths := make(map[string]struct{})
	accountTrie := trie.NewStackTrie(func(path []byte, _ common.Hash, _ []byte) {
		accountPaths[string(path)] = struct{}{}
	})
	for i := range ordered {
		entry := &ordered[i]
		enc, err := rlp.EncodeToBytes(&entry.account)
		if err != nil {
			t.Fatal(err)
		}
		if err = accountTrie.Update(entry.hash[:], enc); err != nil {
			t.Fatal(err)
		}
	}
	accountTrie.Hash()

	// Collect the key (accountHash ++ path) of every node the reference
	// storage tries emit, using a dedicated StackTrie per account with slots.
	storageKeys := make(map[string]struct{})
	for _, a := range accounts {
		if len(a.storage) == 0 {
			continue
		}
		ordered := append(make([]testSlot, 0, len(a.storage)), a.storage...)
		sort.Slice(ordered, func(i, j int) bool {
			return bytes.Compare(ordered[i].hash[:], ordered[j].hash[:]) < 0
		})
		ownerPrefix := string(a.hash[:])
		storageTrie := trie.NewStackTrie(func(path []byte, _ common.Hash, _ []byte) {
			storageKeys[ownerPrefix+string(path)] = struct{}{}
		})
		for i := range ordered {
			if err := storageTrie.Update(ordered[i].hash[:], ordered[i].value); err != nil {
				t.Fatal(err)
			}
		}
		storageTrie.Hash()
	}

	assertSameNodeSet(t, "account", diskNodeKeys(db, rawdb.TrieNodeAccountPrefix), accountPaths)
	assertSameNodeSet(t, "storage", diskNodeKeys(db, rawdb.TrieNodeStoragePrefix), storageKeys)
}
// diskNodeKeys collects every key in db that starts with prefix and returns
// them as a set with the prefix removed. For the account prefix the remainder
// is the hex path; for the storage prefix it is accountHash ++ hex path.
func diskNodeKeys(db ethdb.Database, prefix []byte) map[string]struct{} {
	iter := db.NewIterator(prefix, nil)
	defer iter.Release()

	skip := len(prefix)
	found := make(map[string]struct{})
	for iter.Next() {
		suffix := iter.Key()[skip:]
		found[string(suffix)] = struct{}{}
	}
	return found
}
// assertSameNodeSet marks the test as failed when the two key sets are not
// identical. Every key present only in got is reported as an extra on-disk
// node, and every key present only in want as a missing one; label names the
// trie kind in those messages.
func assertSameNodeSet(t *testing.T, label string, got, want map[string]struct{}) {
	t.Helper()
	// Keys on disk that the canonical set does not contain.
	for key := range got {
		if _, expected := want[key]; expected {
			continue
		}
		t.Errorf("%s-trie: extra node on disk at %x", label, key)
	}
	// Canonical keys that never made it to disk.
	for key := range want {
		if _, present := got[key]; present {
			continue
		}
		t.Errorf("%s-trie: missing node on disk at %x", label, key)
	}
}
// peakBatch wraps an ethdb.Batch and remembers, through the shared peak
// pointer, the largest ValueSize the batch had at the moment Write was called.
type peakBatch struct {
	ethdb.Batch
	peak *int
}

// Write samples the current batch size, raises the recorded peak if this
// sample exceeds it, and then delegates to the wrapped batch's Write.
func (b *peakBatch) Write() error {
	size := b.ValueSize()
	if size > *b.peak {
		*b.peak = size
	}
	return b.Batch.Write()
}
// peakBatchDB wraps a database so that every batch it creates is a peakBatch
// reporting into the same peak counter. Tests use it to see how large a write
// batch gets between flushes.
type peakBatchDB struct {
	ethdb.Database
	peak *int
}

// NewBatch creates a batch on the wrapped database and returns it wrapped in a
// peakBatch.
func (d peakBatchDB) NewBatch() ethdb.Batch {
	inner := d.Database.NewBatch()
	return &peakBatch{Batch: inner, peak: d.peak}
}

// NewBatchWithSize creates a pre-sized batch on the wrapped database and
// returns it wrapped in a peakBatch.
func (d peakBatchDB) NewBatchWithSize(size int) ethdb.Batch {
	inner := d.Database.NewBatchWithSize(size)
	return &peakBatch{Batch: inner, peak: d.peak}
}
// TestGenerateTrieBatchFlush pushes each batch-flush site of generatePartition
// beyond IdealBatchSize and asserts two things: the write batch stayed bounded
// (which means the flush fired) and no entry was dropped or skipped on the way.
func TestGenerateTrieBatchFlush(t *testing.T) {
	// entries is how many records every fixture puts into partition 0. It is
	// large enough for a single flush site to exceed IdealBatchSize several
	// times over.
	const entries = 5000

	// keyOf turns an int into a distinct hash whose leading nibble is 0, which
	// places it in partition 0. It serves as both account and slot hash.
	keyOf := func(i int) common.Hash {
		return common.BytesToHash([]byte{0x00, byte(i >> 16), byte(i >> 8), byte(i)})
	}
	// accountWithRoot builds a code-less account carrying the given storage root.
	accountWithRoot := func(root common.Hash) types.StateAccount {
		return types.StateAccount{
			Nonce:    1,
			Balance:  uint256.NewInt(1),
			Root:     root,
			CodeHash: types.EmptyCodeHash.Bytes(),
		}
	}
	// writeDanglingSlots stores slots under keyOf(1) without ever writing an
	// account snapshot for that owner.
	writeDanglingSlots := func(db ethdb.Database) {
		for i := 0; i < entries; i++ {
			rawdb.WriteStorageSnapshot(db, keyOf(1), keyOf(i), []byte{0x01})
		}
	}
	type fixture struct {
		name        string
		build       func(db ethdb.Database)
		wantScanned int64
		wantDeleted int64
	}
	fixtures := []fixture{
		{
			// The dangling owner (no account snapshot) sorts ahead of a live
			// account, so its slots get deleted inline (cmp < 0) while the
			// live account is being built.
			name: "inline dangling deletes",
			build: func(db ethdb.Database) {
				writeDanglingSlots(db)
				live := accountWithRoot(types.EmptyRootHash)
				rawdb.WriteAccountSnapshot(db, keyOf(0xffffff), types.SlimAccountRLP(live))
			},
			wantScanned: 1,
			wantDeleted: entries,
		},
		{
			// The dangling owner is alone, with no live account anywhere, so
			// all slots are removed by the tail loop that runs once the
			// account iterator is exhausted.
			name:        "tail dangling deletes",
			build:       writeDanglingSlots,
			wantScanned: 0,
			wantDeleted: entries,
		},
		{
			// A lone account whose storage trie overflows the batch by itself,
			// forcing the cmp == 0 storage path to flush mid-build. updated
			// remains 0 only when every slot survived the flush and the
			// iterator reopen.
			name: "single account, large storage",
			build: func(db ethdb.Database) {
				slots := make([]testSlot, entries)
				for i := range slots {
					slots[i] = testSlot{hash: keyOf(i), value: bytes.Repeat([]byte{byte(i)}, 32)}
				}
				owner := accountWithRoot(computeStorageRootFromSlots(slots))
				rawdb.WriteAccountSnapshot(db, keyOf(7), types.SlimAccountRLP(owner))
				for i := range slots {
					rawdb.WriteStorageSnapshot(db, keyOf(7), slots[i].hash, slots[i].value)
				}
			},
			wantScanned: 1,
			wantDeleted: 0,
		},
		{
			// Lots of accounts with empty storage, so that the account trie
			// overflows the batch by itself and the per-account flush is
			// exercised. An account that got skipped would be absent from
			// scanned.
			name: "many accounts",
			build: func(db ethdb.Database) {
				for i := 0; i < entries; i++ {
					empty := accountWithRoot(types.EmptyRootHash)
					rawdb.WriteAccountSnapshot(db, keyOf(i), types.SlimAccountRLP(empty))
				}
			},
			wantScanned: entries,
			wantDeleted: 0,
		},
	}
	for _, fx := range fixtures {
		t.Run(fx.name, func(t *testing.T) {
			db := rawdb.NewMemoryDatabase()
			fx.build(db)

			var (
				peak                      int
				scanned, updated, deleted atomic.Int64
				pos                       atomic.Uint64
			)
			// Route all batches through the peak-tracking wrapper and process
			// partition 0 only.
			tracked := peakBatchDB{Database: db, peak: &peak}
			first := hashRanges(numPartitions)[0]
			_, err := generatePartition(context.Background(), nil, tracked, rawdb.HashScheme, 0,
				first[0], first[1], &scanned, &updated, &deleted, &pos)
			if err != nil {
				t.Fatalf("generatePartition: %v", err)
			}
			if got := scanned.Load(); got != fx.wantScanned {
				t.Errorf("scanned = %d, want %d (an account was skipped?)", got, fx.wantScanned)
			}
			if got := deleted.Load(); got != fx.wantDeleted {
				t.Errorf("deleted = %d, want %d", got, fx.wantDeleted)
			}
			if got := updated.Load(); got != 0 {
				t.Errorf("updated = %d, want 0 (a storage slot was dropped across a flush?)", got)
			}
			// The batch has to remain bounded. If the flush at this site were
			// missing, the whole write set (much bigger than IdealBatchSize)
			// would pile up in a single batch.
			if limit := 2 * ethdb.IdealBatchSize; peak > limit {
				t.Errorf("peak batch size %d exceeded 2*IdealBatchSize (%d); flush did not fire", peak, limit)
			}
		})
	}
}

View file

@ -21,6 +21,7 @@ package internal
import ( import (
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"math" "math"
"runtime" "runtime"
@ -36,6 +37,10 @@ import (
"github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/trie"
) )
// ErrCancelled is returned by GenerateTrieRoot when the cancel channel is
// closed mid-run.
var ErrCancelled = errors.New("cancelled")
// Iterator is an iterator to step over all the accounts or the specific // Iterator is an iterator to step over all the accounts or the specific
// storage in a snapshot which may or may not be composed of multiple layers. // storage in a snapshot which may or may not be composed of multiple layers.
type Iterator interface { type Iterator interface {
@ -228,7 +233,7 @@ func RunReport(stats *GenerateStats, stop chan bool) {
// GenerateTrieRoot generates the trie hash based on the snapshot iterator. // GenerateTrieRoot generates the trie hash based on the snapshot iterator.
// It can be used for generating account trie, storage trie or even the // It can be used for generating account trie, storage trie or even the
// whole state which connects the accounts and the corresponding storages. // whole state which connects the accounts and the corresponding storages.
func GenerateTrieRoot(db ethdb.KeyValueWriter, scheme string, it Iterator, account common.Hash, generatorFn TrieGeneratorFn, leafCallback LeafCallbackFn, stats *GenerateStats, report bool) (common.Hash, error) { func GenerateTrieRoot(db ethdb.KeyValueWriter, scheme string, it Iterator, account common.Hash, generatorFn TrieGeneratorFn, leafCallback LeafCallbackFn, stats *GenerateStats, report bool, cancel <-chan struct{}) (common.Hash, error) {
var ( var (
in = make(chan TrieKV) // chan to pass leaves in = make(chan TrieKV) // chan to pass leaves
out = make(chan common.Hash, 1) // chan to collect result out = make(chan common.Hash, 1) // chan to collect result
@ -279,6 +284,14 @@ func GenerateTrieRoot(db ethdb.KeyValueWriter, scheme string, it Iterator, accou
) )
// Start to feed leaves // Start to feed leaves
for it.Next() { for it.Next() {
// Top-of-loop cancel check. Cheap non-blocking peek so a closed
// cancel channel is observed without waiting for the blocking
// operations below.
select {
case <-cancel:
return stop(ErrCancelled)
default:
}
if account == (common.Hash{}) { if account == (common.Hash{}) {
var ( var (
err error err error
@ -291,8 +304,14 @@ func GenerateTrieRoot(db ethdb.KeyValueWriter, scheme string, it Iterator, accou
} }
} else { } else {
// Wait until the semaphore allows us to continue, aborting if // Wait until the semaphore allows us to continue, aborting if
// a sub-task failed // a sub-task failed or the caller cancelled.
if err := <-results; err != nil { var err error
select {
case err = <-results:
case <-cancel:
return stop(ErrCancelled)
}
if err != nil {
results <- nil // stop will drain the results, add a noop back for this error we just consumed results <- nil // stop will drain the results, add a noop back for this error we just consumed
return stop(err) return stop(err)
} }
@ -322,7 +341,13 @@ func GenerateTrieRoot(db ethdb.KeyValueWriter, scheme string, it Iterator, accou
} else { } else {
leaf = TrieKV{it.Hash(), common.CopyBytes(it.(StorageIterator).Slot())} leaf = TrieKV{it.Hash(), common.CopyBytes(it.(StorageIterator).Slot())}
} }
in <- leaf // Escape on cancel so we don't deadlock if the generator goroutine is slow
// and the caller gave up.
select {
case in <- leaf:
case <-cancel:
return stop(ErrCancelled)
}
// Accumulate the generation statistic if it's required. // Accumulate the generation statistic if it's required.
processed++ processed++

View file

@ -0,0 +1,55 @@
// Copyright 2026 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package internal
import (
"errors"
"testing"
"github.com/ethereum/go-ethereum/common"
)
// fakeStorageIterator is a minimal StorageIterator that synthesises count
// slots on the fly rather than reading them from a snapshot. Slot keys are
// derived from a running counter and are unique and strictly ascending, which
// is the order stack-trie based generators expect their input in.
type fakeStorageIterator struct {
	count int // total number of slots to yield
	idx   int // number of slots yielded so far; identifies the current slot
}

// Next advances to the next synthetic slot. It returns false once count slots
// have been yielded.
func (it *fakeStorageIterator) Next() bool {
	if it.idx >= it.count {
		return false
	}
	it.idx++
	return true
}

// Error always returns nil; the fake iterator cannot fail.
func (it *fakeStorageIterator) Error() error { return nil }

// Hash returns the key of the current slot: the counter encoded big-endian in
// the four lowest bytes of the hash. Using the full counter (instead of only
// its lowest byte) keeps keys distinct and increasing when count exceeds 255.
func (it *fakeStorageIterator) Hash() common.Hash {
	return common.BytesToHash([]byte{
		byte(it.idx >> 24),
		byte(it.idx >> 16),
		byte(it.idx >> 8),
		byte(it.idx),
	})
}

// Slot returns a one-byte value for the current slot (the low byte of the
// counter).
func (it *fakeStorageIterator) Slot() []byte { return []byte{byte(it.idx)} }

// Release is a no-op; the fake iterator holds no resources.
func (it *fakeStorageIterator) Release() {}
// TestGenerateTrieRootCancel checks that GenerateTrieRoot gives up and reports
// ErrCancelled when it is handed a cancel channel that has been closed.
func TestGenerateTrieRootCancel(t *testing.T) {
	t.Parallel()

	// Close the channel up front so the cancellation is visible from the very
	// first check GenerateTrieRoot performs.
	cancelled := make(chan struct{})
	close(cancelled)

	var (
		slots = &fakeStorageIterator{count: 10_000}
		owner = common.HexToHash("0xaa")
	)
	if _, err := GenerateTrieRoot(nil, "", slots, owner, StackTrieGenerate, nil, nil, false, cancelled); !errors.Is(err, ErrCancelled) {
		t.Fatalf("expected ErrCancelled, got %v", err)
	}
}

View file

@ -14,31 +14,31 @@
// You should have received a copy of the GNU Lesser General Public License // You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package pathdb package internal
import ( import (
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
) )
// holdableIterator is a wrapper of underlying database iterator. It extends // HoldableIterator is a wrapper of underlying database iterator. It extends
// the basic iterator interface by adding Hold which can hold the element // the basic iterator interface by adding Hold which can hold the element
// locally where the iterator is currently located and serve it up next time. // locally where the iterator is currently located and serve it up next time.
type holdableIterator struct { type HoldableIterator struct {
it ethdb.Iterator it ethdb.Iterator
key []byte key []byte
val []byte val []byte
atHeld bool atHeld bool
} }
// newHoldableIterator initializes the holdableIterator with the given iterator. // NewHoldableIterator initializes the HoldableIterator with the given iterator.
func newHoldableIterator(it ethdb.Iterator) *holdableIterator { func NewHoldableIterator(it ethdb.Iterator) *HoldableIterator {
return &holdableIterator{it: it} return &HoldableIterator{it: it}
} }
// Hold holds the element locally where the iterator is currently located which // Hold holds the element locally where the iterator is currently located which
// can be served up next time. // can be served up next time.
func (it *holdableIterator) Hold() { func (it *HoldableIterator) Hold() {
if it.it.Key() == nil { if it.it.Key() == nil {
return // nothing to hold return // nothing to hold
} }
@ -49,7 +49,7 @@ func (it *holdableIterator) Hold() {
// Next moves the iterator to the next key/value pair. It returns whether the // Next moves the iterator to the next key/value pair. It returns whether the
// iterator is exhausted. // iterator is exhausted.
func (it *holdableIterator) Next() bool { func (it *HoldableIterator) Next() bool {
if !it.atHeld && it.key != nil { if !it.atHeld && it.key != nil {
it.atHeld = true it.atHeld = true
} else if it.atHeld { } else if it.atHeld {
@ -65,11 +65,11 @@ func (it *holdableIterator) Next() bool {
// Error returns any accumulated error. Exhausting all the key/value pairs // Error returns any accumulated error. Exhausting all the key/value pairs
// is not considered to be an error. // is not considered to be an error.
func (it *holdableIterator) Error() error { return it.it.Error() } func (it *HoldableIterator) Error() error { return it.it.Error() }
// Release releases associated resources. Release should always succeed and can // Release releases associated resources. Release should always succeed and can
// be called multiple times without causing error. // be called multiple times without causing error.
func (it *holdableIterator) Release() { func (it *HoldableIterator) Release() {
it.atHeld = false it.atHeld = false
it.key = nil it.key = nil
it.val = nil it.val = nil
@ -79,7 +79,7 @@ func (it *holdableIterator) Release() {
// Key returns the key of the current key/value pair, or nil if done. The caller // Key returns the key of the current key/value pair, or nil if done. The caller
// should not modify the contents of the returned slice, and its contents may // should not modify the contents of the returned slice, and its contents may
// change on the next call to Next. // change on the next call to Next.
func (it *holdableIterator) Key() []byte { func (it *HoldableIterator) Key() []byte {
if it.key != nil { if it.key != nil {
return it.key return it.key
} }
@ -89,7 +89,7 @@ func (it *holdableIterator) Key() []byte {
// Value returns the value of the current key/value pair, or nil if done. The // Value returns the value of the current key/value pair, or nil if done. The
// caller should not modify the contents of the returned slice, and its contents // caller should not modify the contents of the returned slice, and its contents
// may change on the next call to Next. // may change on the next call to Next.
func (it *holdableIterator) Value() []byte { func (it *HoldableIterator) Value() []byte {
if it.val != nil { if it.val != nil {
return it.val return it.val
} }

View file

@ -14,7 +14,7 @@
// You should have received a copy of the GNU Lesser General Public License // You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package pathdb package internal
import ( import (
"bytes" "bytes"
@ -39,7 +39,7 @@ func TestIteratorHold(t *testing.T) {
} }
} }
// Iterate over the database with the given configs and verify the results // Iterate over the database with the given configs and verify the results
it, idx := newHoldableIterator(db.NewIterator(nil, nil)), 0 it, idx := NewHoldableIterator(db.NewIterator(nil, nil)), 0
// Nothing should be affected for calling Discard on non-initialized iterator // Nothing should be affected for calling Discard on non-initialized iterator
it.Hold() it.Hold()
@ -108,20 +108,20 @@ func TestReopenIterator(t *testing.T) {
} }
db = rawdb.NewMemoryDatabase() db = rawdb.NewMemoryDatabase()
reopen = func(db ethdb.KeyValueStore, iter *holdableIterator) *holdableIterator { reopen = func(db ethdb.KeyValueStore, iter *HoldableIterator) *HoldableIterator {
if !iter.Next() { if !iter.Next() {
iter.Release() iter.Release()
return newHoldableIterator(memorydb.New().NewIterator(nil, nil)) return NewHoldableIterator(memorydb.New().NewIterator(nil, nil))
} }
next := iter.Key() next := iter.Key()
iter.Release() iter.Release()
return newHoldableIterator(db.NewIterator(rawdb.SnapshotAccountPrefix, next[1:])) return NewHoldableIterator(db.NewIterator(rawdb.SnapshotAccountPrefix, next[1:]))
} }
) )
for key, val := range content { for key, val := range content {
rawdb.WriteAccountSnapshot(db, key, []byte(val)) rawdb.WriteAccountSnapshot(db, key, []byte(val))
} }
checkVal := func(it *holdableIterator, index int) { checkVal := func(it *HoldableIterator, index int) {
if !bytes.Equal(it.Key(), append(rawdb.SnapshotAccountPrefix, order[index].Bytes()...)) { if !bytes.Equal(it.Key(), append(rawdb.SnapshotAccountPrefix, order[index].Bytes()...)) {
t.Fatalf("Unexpected data entry key, want %v got %v", order[index], it.Key()) t.Fatalf("Unexpected data entry key, want %v got %v", order[index], it.Key())
} }
@ -131,7 +131,7 @@ func TestReopenIterator(t *testing.T) {
} }
// Iterate over the database with the given configs and verify the results // Iterate over the database with the given configs and verify the results
dbIter := db.NewIterator(rawdb.SnapshotAccountPrefix, nil) dbIter := db.NewIterator(rawdb.SnapshotAccountPrefix, nil)
iter, idx := newHoldableIterator(rawdb.NewKeyLengthIterator(dbIter, 1+common.HashLength)), -1 iter, idx := NewHoldableIterator(rawdb.NewKeyLengthIterator(dbIter, 1+common.HashLength)), -1
idx++ idx++
iter.Next() iter.Next()

View file

@ -28,6 +28,7 @@ import (
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/ethdb/memorydb" "github.com/ethereum/go-ethereum/ethdb/memorydb"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/triedb/internal"
) )
const ( const (
@ -91,12 +92,12 @@ func (gs *generatorStats) log(msg string, root common.Hash, marker []byte) {
// current generation cycle. It must be recreated if the generation cycle is // current generation cycle. It must be recreated if the generation cycle is
// restarted. // restarted.
type generatorContext struct { type generatorContext struct {
root common.Hash // State root of the generation target root common.Hash // State root of the generation target
account *holdableIterator // Iterator of account snapshot data account *internal.HoldableIterator // Iterator of account snapshot data
storage *holdableIterator // Iterator of storage snapshot data storage *internal.HoldableIterator // Iterator of storage snapshot data
db ethdb.KeyValueStore // Key-value store containing the snapshot data db ethdb.KeyValueStore // Key-value store containing the snapshot data
batch ethdb.Batch // Database batch for writing data atomically batch ethdb.Batch // Database batch for writing data atomically
logged time.Time // The timestamp when last generation progress was displayed logged time.Time // The timestamp when last generation progress was displayed
} }
// newGeneratorContext initializes the context for generation. // newGeneratorContext initializes the context for generation.
@ -119,11 +120,11 @@ func newGeneratorContext(root common.Hash, marker []byte, db ethdb.KeyValueStore
func (ctx *generatorContext) openIterator(kind string, start []byte) { func (ctx *generatorContext) openIterator(kind string, start []byte) {
if kind == snapAccount { if kind == snapAccount {
iter := ctx.db.NewIterator(rawdb.SnapshotAccountPrefix, start) iter := ctx.db.NewIterator(rawdb.SnapshotAccountPrefix, start)
ctx.account = newHoldableIterator(rawdb.NewKeyLengthIterator(iter, 1+common.HashLength)) ctx.account = internal.NewHoldableIterator(rawdb.NewKeyLengthIterator(iter, 1+common.HashLength))
return return
} }
iter := ctx.db.NewIterator(rawdb.SnapshotStoragePrefix, start) iter := ctx.db.NewIterator(rawdb.SnapshotStoragePrefix, start)
ctx.storage = newHoldableIterator(rawdb.NewKeyLengthIterator(iter, 1+2*common.HashLength)) ctx.storage = internal.NewHoldableIterator(rawdb.NewKeyLengthIterator(iter, 1+2*common.HashLength))
} }
// reopenIterator releases the specified snapshot iterator and re-open it // reopenIterator releases the specified snapshot iterator and re-open it
@ -140,10 +141,10 @@ func (ctx *generatorContext) reopenIterator(kind string) {
// Iterator exhausted, release forever and create an already exhausted virtual iterator // Iterator exhausted, release forever and create an already exhausted virtual iterator
iter.Release() iter.Release()
if kind == snapAccount { if kind == snapAccount {
ctx.account = newHoldableIterator(memorydb.New().NewIterator(nil, nil)) ctx.account = internal.NewHoldableIterator(memorydb.New().NewIterator(nil, nil))
return return
} }
ctx.storage = newHoldableIterator(memorydb.New().NewIterator(nil, nil)) ctx.storage = internal.NewHoldableIterator(memorydb.New().NewIterator(nil, nil))
return return
} }
next := iter.Key() next := iter.Key()
@ -158,7 +159,7 @@ func (ctx *generatorContext) close() {
} }
// iterator returns the corresponding iterator specified by the kind. // iterator returns the corresponding iterator specified by the kind.
func (ctx *generatorContext) iterator(kind string) *holdableIterator { func (ctx *generatorContext) iterator(kind string) *internal.HoldableIterator {
if kind == snapAccount { if kind == snapAccount {
return ctx.account return ctx.account
} }

View file

@ -52,12 +52,12 @@ func (db *Database) VerifyState(root common.Hash) error {
} }
defer storageIt.Release() defer storageIt.Release()
hash, err := internal.GenerateTrieRoot(nil, "", storageIt, accountHash, stackTrieHasher, nil, stat, false) hash, err := internal.GenerateTrieRoot(nil, "", storageIt, accountHash, stackTrieHasher, nil, stat, false, nil)
if err != nil { if err != nil {
return common.Hash{}, err return common.Hash{}, err
} }
return hash, nil return hash, nil
}, internal.NewGenerateStats(), true) }, internal.NewGenerateStats(), true, nil)
if err != nil { if err != nil {
return err return err