From bd840a1c18eb0852377bfea064c5b950c7002a15 Mon Sep 17 00:00:00 2001 From: jsvisa Date: Thu, 11 Sep 2025 02:34:04 +0000 Subject: [PATCH] use callbacks Signed-off-by: jsvisa --- cmd/geth/snapshot.go | 919 ++++++++++++++++++++++++------------------- 1 file changed, 510 insertions(+), 409 deletions(-) diff --git a/cmd/geth/snapshot.go b/cmd/geth/snapshot.go index edcb67cc9a..0f4f952b7d 100644 --- a/cmd/geth/snapshot.go +++ b/cmd/geth/snapshot.go @@ -39,9 +39,11 @@ import ( "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/node" "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/triedb" + "github.com/ethereum/go-ethereum/triedb/database" "github.com/urfave/cli/v2" ) @@ -299,47 +301,17 @@ func checkDanglingStorage(ctx *cli.Context) error { return snapshot.CheckDanglingStorage(db) } -// traverseState is a helper function used for pruning verification. -// Basically it just iterates the trie, ensure all nodes and associated -// contract codes are present. func traverseState(ctx *cli.Context) error { - stack, _ := makeConfigNode(ctx) - defer stack.Close() - - chaindb := utils.MakeChainDatabase(ctx, stack, true) - defer chaindb.Close() - - triedb := utils.MakeTrieDatabase(ctx, stack, chaindb, false, true, false) - defer triedb.Close() - - headBlock := rawdb.ReadHeadBlock(chaindb) - if headBlock == nil { - log.Error("Failed to load head block") - return errors.New("no head block") - } - - config, err := parseTraverseArgs(ctx) + ts, err := setupTraversal(ctx) if err != nil { return err } - if config.root == (common.Hash{}) { - config.root = headBlock.Root() - } - - t, err := trie.NewStateTrie(trie.StateTrieID(config.root), triedb) - if err != nil { - log.Error("Failed to open trie", "root", config.root, "err", err) - return err - } + defer ts.Close() var ( - accounts atomic.Uint64 - slots atomic.Uint64 - codes atomic.Uint64 - start = time.Now() + counters = &traverseCounters{start: time.Now()} + cctx, cancel = context.WithCancel(context.Background()) ) - - cctx, cancel := context.WithCancel(context.Background()) defer cancel() go func() { @@ -348,231 +320,170 @@ func traverseState(ctx *cli.Context) error { for { select { case <-timer.C: - log.Info("Traversing state", "accounts", accounts.Load(), "slots", slots.Load(), "codes", codes.Load(), "elapsed", common.PrettyDuration(time.Since(start))) - case <-cctx.Done(): + log.Info("Traversing state", "accounts", counters.accounts.Load(), "slots", counters.slots.Load(), "codes", counters.codes.Load(), "elapsed", common.PrettyDuration(time.Since(counters.start))) + case <-ctx.Done(): return } } }() - if config.isAccount { - log.Info("Start traversing storage trie", "root", config.root.Hex(), "account", config.account.Hex(), "startKey", common.Bytes2Hex(config.startKey), "limitKey", common.Bytes2Hex(config.limitKey)) - - acc, err := t.GetAccountByHash(config.account) - if err != nil { - log.Error("Get account failed", "account", config.account.Hex(), "err", err) - return err - } - - if acc.Root == types.EmptyRootHash { - log.Info("Account has no storage") - return nil - } - - id := trie.StorageTrieID(config.root, config.account, acc.Root) - storageTrie, err := trie.NewStateTrie(id, triedb) - if err != nil { - log.Error("Failed to open storage trie", "root", acc.Root, "err", err) - return err - } - - storageIt, err := storageTrie.NodeIterator(config.startKey) - if err != nil { - log.Error("Failed to open storage iterator", "root", acc.Root, "err", err) - return err - } - - storageIter := trie.NewIterator(storageIt) - for storageIter.Next() { - if config.limitKey != nil && bytes.Compare(storageIter.Key, config.limitKey) >= 0 { - break - } - - slots.Add(1) - log.Debug("Storage slot", "key", common.Bytes2Hex(storageIter.Key), "value", common.Bytes2Hex(storageIter.Value)) - } - if storageIter.Err != nil { - log.Error("Failed to traverse storage trie", "root", acc.Root, "err", storageIter.Err) - return storageIter.Err - } - - log.Info("Storage traversal complete", "slots", slots.Load(), "elapsed", common.PrettyDuration(time.Since(start))) + if ts.config.isAccount { + return ts.traverseAccount(cctx, counters, false) } else { - log.Info("Start traversing state trie", "root", config.root.Hex(), "startKey", common.Bytes2Hex(config.startKey), "limitKey", common.Bytes2Hex(config.limitKey)) - - eg, ctx := errgroup.WithContext(cctx) - - // Parallel processing with boundary checks - for i := 0; i < 16; i++ { - nibble := byte(i) - eg.Go(func() error { - // Skip branches that are entirely before startKey - if len(config.startKey) > 0 { - startNibble := config.startKey[0] >> 4 - if nibble < startNibble { - return nil // Skip this branch - } - } - - // Skip branches that are entirely after limitKey - if len(config.limitKey) > 0 { - limitNibble := config.limitKey[0] >> 4 - if nibble > limitNibble { - return nil // Skip this branch - } - } - - var ( - startKey = []byte{nibble << 4} - limitKey []byte - ) - if nibble < 15 { - limitKey = []byte{(nibble + 1) << 4} - } else { - // Last branch (0xf*) has no limit - limitKey = nil - } - - return traverseBranch(ctx, t, triedb, chaindb, config, startKey, limitKey, &accounts, &slots, &codes) - }) - } - - if err := eg.Wait(); err != nil { - return err - } - - log.Info("State traversal complete", "accounts", accounts.Load(), "slots", slots.Load(), "codes", codes.Load(), "elapsed", common.PrettyDuration(time.Since(start))) + return ts.traverseState(cctx, counters, false) } - return nil } -// traverseBranch traverses a specific branch of the state trie -func traverseBranch(ctx context.Context, t *trie.StateTrie, triedb *triedb.Database, chaindb ethdb.Database, config *traverseConfig, startKey, limitKey []byte, accounts, slots, codes *atomic.Uint64) error { - acctIt, err := t.NodeIterator(startKey) - if err != nil { - return err - } - - accIter := trie.NewIterator(acctIt) - for accIter.Next() { - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - - if config.startKey != nil && bytes.Compare(accIter.Key, config.startKey) < 0 { - continue - } - - if limitKey != nil && bytes.Compare(accIter.Key, limitKey) >= 0 { - break - } - - accounts.Add(1) - - var acc types.StateAccount - if err := rlp.DecodeBytes(accIter.Value, &acc); err != nil { - log.Error("Invalid account encountered during traversal", "err", err) - return err - } - - if acc.Root != types.EmptyRootHash { - id := trie.StorageTrieID(config.root, common.BytesToHash(accIter.Key), acc.Root) - storageTrie, err := trie.NewStateTrie(id, triedb) - if err != nil { - log.Error("Failed to open storage trie", "root", acc.Root, "err", err) - return err - } - - localSlots, err := traverseStorageParallel(ctx, storageTrie) - if err != nil { - log.Error("Failed to traverse storage trie", "root", acc.Root, "err", err) - return err - } - slots.Add(localSlots) - } - - if !bytes.Equal(acc.CodeHash, types.EmptyCodeHash.Bytes()) { - if !rawdb.HasCode(chaindb, common.BytesToHash(acc.CodeHash)) { - log.Error("Code is missing", "hash", common.BytesToHash(acc.CodeHash)) - return errors.New("missing code") - } - codes.Add(1) - } - } - - if accIter.Err != nil { - return accIter.Err - } - - return nil +// StorageCallbacks defines the callbacks for different storage traversal modes +type StorageCallbacks struct { + // Called for each storage node (only for raw mode) + OnStorageNode func(node common.Hash, path []byte) error } -// traverseStorageParallel parallelizes storage trie traversal by dividing work across 16 trie branches -func traverseStorageParallel(ctx context.Context, storageTrie *trie.StateTrie) (uint64, error) { - g, ctx := errgroup.WithContext(ctx) - totalSlots := atomic.Uint64{} +// createSimpleStorageCallbacks creates callbacks for simple storage traversal +func createSimpleStorageCallbacks() *StorageCallbacks { + return &StorageCallbacks{ + OnStorageNode: nil, // Simple mode doesn't process nodes + } +} + +// createRawStorageCallbacks creates callbacks for raw storage traversal with verification +func createRawStorageCallbacks(reader database.NodeReader, accountHash common.Hash) *StorageCallbacks { + return &StorageCallbacks{ + OnStorageNode: func(node common.Hash, path []byte) error { + if node != (common.Hash{}) { + blob, _ := reader.Node(accountHash, path, node) + if len(blob) == 0 { + log.Error("Missing trie node(storage)", "hash", node) + return errors.New("missing storage") + } + if !bytes.Equal(crypto.Keccak256(blob), node.Bytes()) { + log.Error("Invalid trie node(storage)", "hash", node.Hex(), "value", blob) + return errors.New("invalid storage node") + } + } + return nil + }, + } +} + +// traverseStorageParallelWithCallbacks parallelizes storage trie traversal using callbacks +func traverseStorageParallelWithCallbacks(ctx context.Context, storageTrie *trie.StateTrie, startKey, limitKey []byte, callbacks *StorageCallbacks, raw bool) (slots, nodes uint64, err error) { + var ( + eg, cctx = errgroup.WithContext(ctx) + slotsAtomic atomic.Uint64 + nodesAtomic atomic.Uint64 + ) for i := 0; i < 16; i++ { nibble := byte(i) - g.Go(func() error { + eg.Go(func() error { + // Calculate this branch's natural boundaries var ( - startKey = []byte{nibble << 4} - limitKey []byte + branchStart = []byte{nibble << 4} + branchLimit []byte ) if nibble < 15 { - limitKey = []byte{(nibble + 1) << 4} - } else { - // For the last branch (0xf*), no limit key (traverse to end) - limitKey = nil + branchLimit = []byte{(nibble + 1) << 4} } - localSlots, err := traverseStorageBranch(ctx, storageTrie, startKey, limitKey) + // Skip branches that are entirely before startKey + if startKey != nil && branchLimit != nil && bytes.Compare(branchLimit, startKey) <= 0 { + return nil + } + + // Skip branches that are entirely after limitKey + if limitKey != nil && bytes.Compare(branchStart, limitKey) >= 0 { + return nil + } + + // Use the more restrictive start boundary + if startKey != nil && bytes.Compare(branchStart, startKey) < 0 { + branchStart = startKey + } + if limitKey != nil && (branchLimit == nil || bytes.Compare(branchLimit, limitKey) > 0) { + branchLimit = limitKey + } + + // Skip if branch range is empty + if branchLimit != nil && bytes.Compare(branchStart, branchLimit) >= 0 { + return nil + } + + localSlots, localNodes, err := traverseStorageBranchWithCallbacks(cctx, storageTrie, branchStart, branchLimit, callbacks, raw) if err != nil { return err } - totalSlots.Add(localSlots) + slotsAtomic.Add(localSlots) + nodesAtomic.Add(localNodes) return nil }) } - if err := g.Wait(); err != nil { - return 0, err + if err := eg.Wait(); err != nil { + return 0, 0, err } - return totalSlots.Load(), nil + return slotsAtomic.Load(), nodesAtomic.Load(), nil } -// traverseStorageBranch traverses a specific branch of the storage trie -func traverseStorageBranch(ctx context.Context, storageTrie *trie.StateTrie, startKey, limitKey []byte) (uint64, error) { - storageIt, err := storageTrie.NodeIterator(startKey) +// traverseStorageBranchWithCallbacks traverses a specific range of the storage trie using callbacks +func traverseStorageBranchWithCallbacks(ctx context.Context, storageTrie *trie.StateTrie, startKey, limitKey []byte, callbacks *StorageCallbacks, raw bool) (slots, nodes uint64, err error) { + nodeIter, err := storageTrie.NodeIterator(startKey) if err != nil { - return 0, err + return 0, 0, err } - storageIter := trie.NewIterator(storageIt) - slots := uint64(0) + if raw { + // Raw traversal with detailed node checking + for nodeIter.Next(true) { + select { + case <-ctx.Done(): + return 0, 0, ctx.Err() + default: + } - for storageIter.Next() { - // Check if context was cancelled - select { - case <-ctx.Done(): - return 0, ctx.Err() - default: + nodes++ + if callbacks != nil && callbacks.OnStorageNode != nil { + if err := callbacks.OnStorageNode(nodeIter.Hash(), nodeIter.Path()); err != nil { + return 0, 0, err + } + } + + if nodeIter.Leaf() { + if limitKey != nil && bytes.Compare(nodeIter.LeafKey(), limitKey) >= 0 { + break + } + slots++ + } + } + } else { + // Simple traversal - just iterate through leaf nodes + storageIter := trie.NewIterator(nodeIter) + for storageIter.Next() { + select { + case <-ctx.Done(): + return 0, 0, ctx.Err() + default: + } + + if limitKey != nil && bytes.Compare(storageIter.Key, limitKey) >= 0 { + break + } + + slots++ } - if limitKey != nil && bytes.Compare(storageIter.Key, limitKey) >= 0 { - break + if storageIter.Err != nil { + return 0, 0, storageIter.Err } - slots++ } - if storageIter.Err != nil { - return 0, storageIter.Err + if err := nodeIter.Error(); err != nil { + return 0, 0, err } - return slots, nil + return slots, nodes, nil } type traverseConfig struct { @@ -583,6 +494,181 @@ type traverseConfig struct { isAccount bool } +type traverseSetup struct { + stack *node.Node + chaindb ethdb.Database + triedb *triedb.Database + trie *trie.StateTrie + config *traverseConfig +} + +type traverseCounters struct { + accounts atomic.Uint64 + slots atomic.Uint64 + codes atomic.Uint64 + nodes atomic.Uint64 + start time.Time +} + +func setupTraversal(ctx *cli.Context) (*traverseSetup, error) { + stack, _ := makeConfigNode(ctx) + + chaindb := utils.MakeChainDatabase(ctx, stack, true) + triedb := utils.MakeTrieDatabase(ctx, stack, chaindb, false, true, false) + + headBlock := rawdb.ReadHeadBlock(chaindb) + if headBlock == nil { + log.Error("Failed to load head block") + return nil, errors.New("no head block") + } + + config, err := parseTraverseArgs(ctx) + if err != nil { + return nil, err + } + if config.root == (common.Hash{}) { + config.root = headBlock.Root() + } + + t, err := trie.NewStateTrie(trie.StateTrieID(config.root), triedb) + if err != nil { + log.Error("Failed to open trie", "root", config.root, "err", err) + return nil, err + } + + return &traverseSetup{ + stack: stack, + chaindb: chaindb, + triedb: triedb, + trie: t, + config: config, + }, nil +} + +func (ts *traverseSetup) Close() { + ts.triedb.Close() + ts.chaindb.Close() + ts.stack.Close() +} + +func (ts *traverseSetup) traverseAccount(ctx context.Context, counters *traverseCounters, raw bool) error { + log.Info("Start traversing storage trie", "root", ts.config.root.Hex(), "account", ts.config.account.Hex(), "startKey", common.Bytes2Hex(ts.config.startKey), "limitKey", common.Bytes2Hex(ts.config.limitKey)) + + acc, err := ts.trie.GetAccountByHash(ts.config.account) + if err != nil { + log.Error("Get account failed", "account", ts.config.account.Hex(), "err", err) + return err + } + + if acc.Root == types.EmptyRootHash { + log.Info("Account has no storage") + return nil + } + + id := trie.StorageTrieID(ts.config.root, ts.config.account, acc.Root) + storageTrie, err := trie.NewStateTrie(id, ts.triedb) + if err != nil { + log.Error("Failed to open storage trie", "root", acc.Root, "err", err) + return err + } + + var callbacks *StorageCallbacks + if raw { + reader, err := ts.triedb.NodeReader(ts.config.root) + if err != nil { + log.Error("State is non-existent", "root", ts.config.root) + return nil + } + callbacks = createRawStorageCallbacks(reader, ts.config.account) + } else { + callbacks = createSimpleStorageCallbacks() + } + + slots, nodes, err := traverseStorageParallelWithCallbacks(ctx, storageTrie, ts.config.startKey, ts.config.limitKey, callbacks, raw) + if err != nil { + log.Error("Failed to traverse storage trie", "root", acc.Root, "err", err) + return err + } + + counters.slots.Add(slots) + counters.nodes.Add(nodes) + + if raw { + log.Info("Storage traversal complete (raw)", "nodes", counters.nodes.Load(), "slots", counters.slots.Load(), "elapsed", common.PrettyDuration(time.Since(counters.start))) + } else { + log.Info("Storage traversal complete", "slots", counters.slots.Load(), "elapsed", common.PrettyDuration(time.Since(counters.start))) + } + return nil +} + +func (ts *traverseSetup) traverseState(ctx context.Context, counters *traverseCounters, raw bool) error { + log.Info("Start traversing state trie", "root", ts.config.root.Hex(), "startKey", common.Bytes2Hex(ts.config.startKey), "limitKey", common.Bytes2Hex(ts.config.limitKey)) + + eg, ctx := errgroup.WithContext(ctx) + var reader database.NodeReader + if raw { + var err error + reader, err = ts.triedb.NodeReader(ts.config.root) + if err != nil { + log.Error("State is non-existent", "root", ts.config.root) + return nil + } + } + + for i := 0; i < 16; i++ { + nibble := byte(i) + eg.Go(func() error { + var ( + startKey = []byte{nibble << 4} + limitKey []byte + ) + if nibble < 15 { + limitKey = []byte{(nibble + 1) << 4} + } + + if ts.config != nil { + // Skip branches that are entirely before startKey + if limitKey != nil && bytes.Compare(limitKey, ts.config.startKey) <= 0 { + return nil + } + + // Skip branches that are entirely after limitKey + if ts.config.limitKey != nil && bytes.Compare(startKey, ts.config.limitKey) >= 0 { + return nil + } + + if ts.config.startKey != nil && bytes.Compare(startKey, ts.config.startKey) < 0 { + startKey = ts.config.startKey + } + if ts.config.limitKey != nil && (limitKey == nil || bytes.Compare(limitKey, ts.config.limitKey) > 0) { + limitKey = ts.config.limitKey + } + } + + if limitKey != nil && bytes.Compare(startKey, limitKey) >= 0 { + return nil + } + + // Create appropriate callbacks based on traversal mode + var callbacks *TraverseCallbacks + if raw { + callbacks = ts.createRawCallbacks(counters, reader) + } else { + callbacks = ts.createSimpleCallbacks(counters) + } + + return ts.traverseStateBranchWithCallbacks(ctx, startKey, limitKey, callbacks, raw) + }) + } + + if err := eg.Wait(); err != nil { + return err + } + + log.Info("State traversal complete", "accounts", counters.accounts.Load(), "slots", counters.slots.Load(), "codes", counters.codes.Load(), "elapsed", common.PrettyDuration(time.Since(counters.start))) + return nil +} + func parseTraverseArgs(ctx *cli.Context) (*traverseConfig, error) { if ctx.NArg() > 1 { return nil, errors.New("too many arguments, only is required") @@ -629,235 +715,250 @@ func parseTraverseArgs(ctx *cli.Context) (*traverseConfig, error) { return config, nil } -// traverseRawState is a helper function used for pruning verification. -// Basically it just iterates the trie, ensure all nodes and associated -// contract codes are present. It's basically identical to traverseState -// but it will check each trie node. -func traverseRawState(ctx *cli.Context) error { - stack, _ := makeConfigNode(ctx) - defer stack.Close() +// TraverseCallbacks defines the callbacks for different traversal modes +type TraverseCallbacks struct { + // Called for each trie account node (only for raw mode) + OnNode func(node common.Hash, path []byte) error + // Called for each account + OnAccount func(acc types.StateAccount, accountHash common.Hash) error + // Called for each storage slot + OnStorage func(ctx context.Context, storageTrie *trie.StateTrie, accountHash common.Hash) (slots, nodes uint64, err error) + // Called for each code + OnCode func(codeHash []byte, accountHash common.Hash) error +} - chaindb := utils.MakeChainDatabase(ctx, stack, true) - defer chaindb.Close() - - triedb := utils.MakeTrieDatabase(ctx, stack, chaindb, false, true, false) - defer triedb.Close() - - headBlock := rawdb.ReadHeadBlock(chaindb) - if headBlock == nil { - log.Error("Failed to load head block") - return errors.New("no head block") - } - - config, err := parseTraverseArgs(ctx) - if err != nil { - log.Error("Failed to parse arguments", "err", err) - return err - } - if config.root == (common.Hash{}) { - config.root = headBlock.Root() - } - t, err := trie.NewStateTrie(trie.StateTrieID(config.root), triedb) - if err != nil { - log.Error("Failed to open trie", "root", config.root, "err", err) - return err - } - - var ( - accounts int - nodes int - slots int - codes int - start = time.Now() - hasher = crypto.NewKeccakState() - got = make([]byte, 32) - ) - - go func() { - timer := time.NewTicker(time.Second * 8) - defer timer.Stop() - for range timer.C { - log.Info("Traversing rawstate", "nodes", nodes, "accounts", accounts, "slots", slots, "codes", codes, "elapsed", common.PrettyDuration(time.Since(start))) - } - }() - - if config.isAccount { - log.Info("Start traversing storage trie (raw)", "root", config.root.Hex(), "account", config.account.Hex(), "startKey", common.Bytes2Hex(config.startKey), "limitKey", common.Bytes2Hex(config.limitKey)) - - acc, err := t.GetAccountByHash(config.account) - if err != nil { - log.Error("Get account failed", "account", config.account.Hex(), "err", err) - return err - } - - if acc.Root == types.EmptyRootHash { - log.Info("Account has no storage") +// createSimpleCallbacks creates callbacks for simple traversal mode +func (ts *traverseSetup) createSimpleCallbacks(counters *traverseCounters) *TraverseCallbacks { + return &TraverseCallbacks{ + OnAccount: func(acc types.StateAccount, accountHash common.Hash) error { + counters.accounts.Add(1) return nil - } - - // Traverse the storage trie with detailed verification - id := trie.StorageTrieID(config.root, config.account, acc.Root) - storageTrie, err := trie.NewStateTrie(id, triedb) - if err != nil { - log.Error("Failed to open storage trie", "root", acc.Root, "err", err) - return err - } - - storageIter, err := storageTrie.NodeIterator(config.startKey) - if err != nil { - log.Error("Failed to open storage iterator", "root", acc.Root, "err", err) - return err - } - - reader, err := triedb.NodeReader(config.root) - if err != nil { - log.Error("State is non-existent", "root", config.root) - return nil - } - - for storageIter.Next(true) { - nodes += 1 - node := storageIter.Hash() - - // Check the presence for non-empty hash node(embedded node doesn't - // have their own hash). - if node != (common.Hash{}) { - blob, _ := reader.Node(config.account, storageIter.Path(), node) - if len(blob) == 0 { - log.Error("Missing trie node(storage)", "hash", node) - return errors.New("missing storage") - } - hasher.Reset() - hasher.Write(blob) - hasher.Read(got) - if !bytes.Equal(got, node.Bytes()) { - log.Error("Invalid trie node(storage)", "hash", node.Hex(), "value", blob) - return errors.New("invalid storage node") - } + }, + OnStorage: func(ctx context.Context, storageTrie *trie.StateTrie, accountHash common.Hash) (slots, nodes uint64, err error) { + callbacks := createSimpleStorageCallbacks() + s, n, err := traverseStorageParallelWithCallbacks(ctx, storageTrie, nil, nil, callbacks, false) + if err != nil { + return 0, 0, err } - - // Bump the counter if it's leaf node. - if storageIter.Leaf() { - // Check if we've exceeded the limit key for storage - if config.limitKey != nil && bytes.Compare(storageIter.LeafKey(), config.limitKey) >= 0 { - break - } - - slots += 1 - log.Debug("Storage slot", "key", common.Bytes2Hex(storageIter.LeafKey()), "value", common.Bytes2Hex(storageIter.LeafBlob())) + counters.slots.Add(s) + return s, n, nil + }, + OnCode: func(codeHash []byte, accountHash common.Hash) error { + if !rawdb.HasCode(ts.chaindb, common.BytesToHash(codeHash)) { + log.Error("Code is missing", "hash", common.BytesToHash(codeHash)) + return errors.New("missing code") } - } - if storageIter.Error() != nil { - log.Error("Failed to traverse storage trie", "root", acc.Root, "err", storageIter.Error()) - return storageIter.Error() - } - - log.Info("Storage traversal complete (raw)", "nodes", nodes, "slots", slots, "elapsed", common.PrettyDuration(time.Since(start))) - return nil - } else { - log.Info("Start traversing the state trie (raw)", "root", config.root.Hex(), "startKey", common.Bytes2Hex(config.startKey), "limitKey", common.Bytes2Hex(config.limitKey)) - - accIter, err := t.NodeIterator(config.startKey) - if err != nil { - log.Error("Failed to open iterator", "root", config.root, "err", err) - return err - } - reader, err := triedb.NodeReader(config.root) - if err != nil { - log.Error("State is non-existent", "root", config.root) + counters.codes.Add(1) return nil - } - for accIter.Next(true) { - nodes += 1 - node := accIter.Hash() + }, + } +} - // Check the present for non-empty hash node(embedded node doesn't - // have their own hash). +// createRawCallbacks creates callbacks for raw traversal mode with detailed verification +func (ts *traverseSetup) createRawCallbacks(counters *traverseCounters, reader database.NodeReader) *TraverseCallbacks { + return &TraverseCallbacks{ + OnNode: func(node common.Hash, path []byte) error { + counters.nodes.Add(1) if node != (common.Hash{}) { - blob, _ := reader.Node(common.Hash{}, accIter.Path(), node) + blob, _ := reader.Node(common.Hash{}, path, node) if len(blob) == 0 { log.Error("Missing trie node(account)", "hash", node) return errors.New("missing account") } - hasher.Reset() - hasher.Write(blob) - hasher.Read(got) - if !bytes.Equal(got, node.Bytes()) { + if !bytes.Equal(crypto.Keccak256(blob), node.Bytes()) { log.Error("Invalid trie node(account)", "hash", node.Hex(), "value", blob) return errors.New("invalid account node") } } - // If it's a leaf node, yes we are touching an account, - // dig into the storage trie further. + return nil + }, + OnAccount: func(acc types.StateAccount, accountHash common.Hash) error { + counters.accounts.Add(1) + return nil + }, + OnStorage: func(ctx context.Context, storageTrie *trie.StateTrie, accountHash common.Hash) (slots, nodes uint64, err error) { + callbacks := createRawStorageCallbacks(reader, accountHash) + s, n, err := traverseStorageParallelWithCallbacks(ctx, storageTrie, nil, nil, callbacks, true) + if err != nil { + return 0, 0, err + } + counters.slots.Add(s) + counters.nodes.Add(n) + return s, n, nil + }, + OnCode: func(codeHash []byte, accountHash common.Hash) error { + if !rawdb.HasCode(ts.chaindb, common.BytesToHash(codeHash)) { + log.Error("Code is missing", "account", accountHash) + return errors.New("missing code") + } + counters.codes.Add(1) + return nil + }, + } +} + +// traverseStateBranchWithCallbacks provides common branch traversal logic using callbacks +func (ts *traverseSetup) traverseStateBranchWithCallbacks(ctx context.Context, startKey, limitKey []byte, callbacks *TraverseCallbacks, raw bool) error { + accIter, err := ts.trie.NodeIterator(startKey) + if err != nil { + return err + } + + if raw { + // Raw traversal with detailed node checking + for accIter.Next(true) { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + if callbacks.OnNode != nil { + if err := callbacks.OnNode(accIter.Hash(), accIter.Path()); err != nil { + return err + } + } + + // If it's a leaf node, process the account if accIter.Leaf() { - // Check if we've exceeded the limit key for accounts - if config.limitKey != nil && bytes.Compare(accIter.LeafKey(), config.limitKey) >= 0 { + if limitKey != nil && bytes.Compare(accIter.LeafKey(), limitKey) >= 0 { break } - accounts += 1 var acc types.StateAccount if err := rlp.DecodeBytes(accIter.LeafBlob(), &acc); err != nil { log.Error("Invalid account encountered during traversal", "err", err) return errors.New("invalid account") } + + accountHash := common.BytesToHash(accIter.LeafKey()) + + if err := callbacks.OnAccount(acc, accountHash); err != nil { + return err + } + + // Process storage if present if acc.Root != types.EmptyRootHash { - id := trie.StorageTrieID(config.root, common.BytesToHash(accIter.LeafKey()), acc.Root) - storageTrie, err := trie.NewStateTrie(id, triedb) + id := trie.StorageTrieID(ts.config.root, accountHash, acc.Root) + storageTrie, err := trie.NewStateTrie(id, ts.triedb) if err != nil { log.Error("Failed to open storage trie", "root", acc.Root, "err", err) - return errors.New("missing storage trie") - } - storageIter, err := storageTrie.NodeIterator(nil) - if err != nil { - log.Error("Failed to open storage iterator", "root", acc.Root, "err", err) return err } - for storageIter.Next(true) { - nodes += 1 - node := storageIter.Hash() - // Check the presence for non-empty hash node(embedded node doesn't - // have their own hash). - if node != (common.Hash{}) { - blob, _ := reader.Node(common.BytesToHash(accIter.LeafKey()), storageIter.Path(), node) - if len(blob) == 0 { - log.Error("Missing trie node(storage)", "hash", node) - return errors.New("missing storage") - } - hasher.Reset() - hasher.Write(blob) - hasher.Read(got) - if !bytes.Equal(got, node.Bytes()) { - log.Error("Invalid trie node(storage)", "hash", node.Hex(), "value", blob) - return errors.New("invalid storage node") - } - } - // Bump the counter if it's leaf node. - if storageIter.Leaf() { - slots += 1 - } - } - if storageIter.Error() != nil { - log.Error("Failed to traverse storage trie", "root", acc.Root, "err", storageIter.Error()) - return storageIter.Error() + if _, _, err := callbacks.OnStorage(ctx, storageTrie, accountHash); err != nil { + return err } } + if !bytes.Equal(acc.CodeHash, types.EmptyCodeHash.Bytes()) { - if !rawdb.HasCode(chaindb, common.BytesToHash(acc.CodeHash)) { - log.Error("Code is missing", "account", common.BytesToHash(accIter.LeafKey())) - return errors.New("missing code") + if err := callbacks.OnCode(acc.CodeHash, accountHash); err != nil { + return err } - codes += 1 } } } - if accIter.Error() != nil { - log.Error("Failed to traverse state trie", "root", config.root, "err", accIter.Error()) - return accIter.Error() + } else { + // Simple traversal - just iterate through leaf nodes + acctIt, err := ts.trie.NodeIterator(startKey) + if err != nil { + return err } - log.Info("State traversal complete (raw)", "nodes", nodes, "accounts", accounts, "slots", slots, "codes", codes, "elapsed", common.PrettyDuration(time.Since(start))) - return nil + + accIter := trie.NewIterator(acctIt) + for accIter.Next() { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + // Check if we've reached the limit for this branch + if limitKey != nil && bytes.Compare(accIter.Key, limitKey) >= 0 { + break + } + + var acc types.StateAccount + if err := rlp.DecodeBytes(accIter.Value, &acc); err != nil { + log.Error("Invalid account encountered during traversal", "err", err) + return err + } + + accountHash := common.BytesToHash(accIter.Key) + + // Process account + if err := callbacks.OnAccount(acc, accountHash); err != nil { + return err + } + + // Process storage if present + if acc.Root != types.EmptyRootHash { + id := trie.StorageTrieID(ts.config.root, accountHash, acc.Root) + storageTrie, err := trie.NewStateTrie(id, ts.triedb) + if err != nil { + log.Error("Failed to open storage trie", "root", acc.Root, "err", err) + return err + } + + if _, _, err := callbacks.OnStorage(ctx, storageTrie, accountHash); err != nil { + return err + } + } + + // Process code if present + if !bytes.Equal(acc.CodeHash, types.EmptyCodeHash.Bytes()) { + if err := callbacks.OnCode(acc.CodeHash, accountHash); err != nil { + return err + } + } + } + + if accIter.Err != nil { + return accIter.Err + } + } + + if accIter.Error() != nil { + return accIter.Error() + } + + return nil +} + +// traverseRawState is a helper function used for pruning verification. +// Basically it just iterates the trie, ensure all nodes and associated +// contract codes are present. It's basically identical to traverseState +// but it will check each trie node. +func traverseRawState(ctx *cli.Context) error { + ts, err := setupTraversal(ctx) + if err != nil { + return err + } + defer ts.Close() + + var ( + counters = &traverseCounters{start: time.Now()} + cctx, cancel = context.WithCancel(context.Background()) + ) + defer cancel() + + go func() { + timer := time.NewTicker(time.Second * 8) + defer timer.Stop() + for { + select { + case <-timer.C: + log.Info("Traversing rawstate", "nodes", counters.nodes.Load(), "accounts", counters.accounts.Load(), "slots", counters.slots.Load(), "codes", counters.codes.Load(), "elapsed", common.PrettyDuration(time.Since(counters.start))) + case <-cctx.Done(): + return + } + } + }() + + if ts.config.isAccount { + return ts.traverseAccount(cctx, counters, true) + } else { + return ts.traverseState(cctx, counters, true) } }