diff --git a/cmd/geth/snapshot.go b/cmd/geth/snapshot.go index 53151f215e..d162535d0a 100644 --- a/cmd/geth/snapshot.go +++ b/cmd/geth/snapshot.go @@ -336,24 +336,6 @@ func traverseState(ctx *cli.Context) error { type OnStorageNodeHook func(node common.Hash, path []byte) error -// createRawStorageHook creates hooks for raw storage traversal with verification -func createRawStorageHook(reader database.NodeReader, accountHash common.Hash) OnStorageNodeHook { - return func(node common.Hash, path []byte) error { - if node != (common.Hash{}) { - blob, _ := reader.Node(accountHash, path, node) - if len(blob) == 0 { - log.Error("Missing trie node(storage)", "hash", node) - return errors.New("missing storage") - } - if !bytes.Equal(crypto.Keccak256(blob), node.Bytes()) { - log.Error("Invalid trie node(storage)", "hash", node.Hex(), "value", blob) - return errors.New("invalid storage node") - } - } - return nil - } -} - // traverseStorage parallelizes storage trie traversal func traverseStorage(ctx context.Context, storageTrie *trie.StateTrie, startKey, limitKey []byte, raw bool, hook OnStorageNodeHook) (uint64, uint64, error) { var ( @@ -397,7 +379,7 @@ func traverseStorage(ctx context.Context, storageTrie *trie.StateTrie, startKey, return nil } - localSlots, localNodes, err := traverseStorageBranchWithHooks(cctx, storageTrie, branchStart, branchLimit, raw, hook) + localSlots, localNodes, err := traverseStorageBranch(cctx, storageTrie, branchStart, branchLimit, raw, hook) if err != nil { return err } @@ -414,8 +396,8 @@ func traverseStorage(ctx context.Context, storageTrie *trie.StateTrie, startKey, return slots.Load(), nodes.Load(), nil } -// traverseStorageBranchWithHooks traverses a specific range of the storage trie using hooks -func traverseStorageBranchWithHooks(ctx context.Context, storageTrie *trie.StateTrie, startKey, limitKey []byte, raw bool, hook OnStorageNodeHook) (slots, nodes uint64, err error) { +// traverseStorageBranch traverses a specific range of the storage trie using hooks +func traverseStorageBranch(ctx context.Context, storageTrie *trie.StateTrie, startKey, limitKey []byte, raw bool, hook OnStorageNodeHook) (slots, nodes uint64, err error) { nodeIter, err := storageTrie.NodeIterator(startKey) if err != nil { return 0, 0, err @@ -565,7 +547,11 @@ func (ts *traverseSetup) traverseAccount(ctx context.Context, counters *traverse log.Error("State is non-existent", "root", ts.config.root) return nil } - hook = createRawStorageHook(reader, ts.config.account) + // Create storage hook inline - same pattern as in TraverseHooks + accountHash := ts.config.account + hook = func(node common.Hash, path []byte) error { + return verifyTrieNode(reader, accountHash, node, path) + } } slots, nodes, err := traverseStorage(ctx, storageTrie, ts.config.startKey, ts.config.limitKey, raw, hook) @@ -640,7 +626,7 @@ func (ts *traverseSetup) traverseState(ctx context.Context, counters *traverseCo hooks = ts.createSimpleHooks() } - return ts.traverseStateBranchWithHooks(ctx, startKey, limitKey, raw, counters, hooks) + return ts.traverseStateBranch(ctx, startKey, limitKey, raw, counters, hooks) }) } @@ -702,10 +688,8 @@ func parseTraverseArgs(ctx *cli.Context) (*traverseConfig, error) { type TraverseHooks struct { // Called for each trie account node (only for raw mode) OnAccountNode func(node common.Hash, path []byte) error - // Called for each storage trie + // Called for each storage trie - now self-contained with all needed context OnStorageTrie func(ctx context.Context, storageTrie *trie.StateTrie, accountHash common.Hash) (slots, nodes uint64, err error) - // Called for each code - OnCode func(codeHash []byte, accountHash common.Hash) error } // createSimpleHooks creates hooks for simple traversal mode @@ -714,54 +698,78 @@ func (ts *traverseSetup) createSimpleHooks() *TraverseHooks { OnStorageTrie: func(ctx context.Context, storageTrie *trie.StateTrie, accountHash common.Hash) (slots, nodes uint64, err error) { return traverseStorage(ctx, storageTrie, nil, nil, false, nil) }, - OnCode: func(codeHash []byte, accountHash common.Hash) error { - if !rawdb.HasCode(ts.chaindb, common.BytesToHash(codeHash)) { - log.Error("Code is missing", "hash", common.BytesToHash(codeHash)) - return errors.New("missing code") - } - return nil - }, } } +func verifyTrieNode(reader database.NodeReader, accountHash common.Hash, node common.Hash, path []byte) error { + if node != (common.Hash{}) { + blob, _ := reader.Node(accountHash, path, node) + if len(blob) == 0 { + log.Error("Missing trie node", "hash", node) + return errors.New("missing node") + } + if !bytes.Equal(crypto.Keccak256(blob), node.Bytes()) { + log.Error("Invalid trie node", "hash", node.Hex(), "value", blob) + return errors.New("invalid node") + } + } + return nil +} + // createRawHooks creates hooks for raw traversal mode with detailed verification func (ts *traverseSetup) createRawHooks(reader database.NodeReader) *TraverseHooks { return &TraverseHooks{ OnAccountNode: func(node common.Hash, path []byte) error { - if node != (common.Hash{}) { - blob, _ := reader.Node(common.Hash{}, path, node) - if len(blob) == 0 { - log.Error("Missing trie node(account)", "hash", node) - return errors.New("missing account") - } - if !bytes.Equal(crypto.Keccak256(blob), node.Bytes()) { - log.Error("Invalid trie node(account)", "hash", node.Hex(), "value", blob) - return errors.New("invalid account node") - } - } - return nil + return verifyTrieNode(reader, common.Hash{}, node, path) }, + OnStorageTrie: func(ctx context.Context, storageTrie *trie.StateTrie, accountHash common.Hash) (slots, nodes uint64, err error) { - hook := createRawStorageHook(reader, accountHash) - return traverseStorage(ctx, storageTrie, nil, nil, true, hook) - }, - OnCode: func(codeHash []byte, accountHash common.Hash) error { - if !rawdb.HasCode(ts.chaindb, common.BytesToHash(codeHash)) { - log.Error("Code is missing", "account", accountHash) - return errors.New("missing code") + hook := func(node common.Hash, path []byte) error { + return verifyTrieNode(reader, accountHash, node, path) } - return nil + return traverseStorage(ctx, storageTrie, nil, nil, true, hook) }, } } -// traverseStateBranchWithHooks provides common branch traversal logic using hooks -func (ts *traverseSetup) traverseStateBranchWithHooks(ctx context.Context, startKey, limitKey []byte, raw bool, counters *traverseCounters, hooks *TraverseHooks) error { +// traverseStateBranch provides common branch traversal logic using hooks +func (ts *traverseSetup) traverseStateBranch(ctx context.Context, startKey, limitKey []byte, raw bool, counters *traverseCounters, hooks *TraverseHooks) error { accIter, err := ts.trie.NodeIterator(startKey) if err != nil { return err } + processAccount := func(accKey []byte, acc *types.StateAccount) error { + accountHash := common.BytesToHash(accKey) + + // Process storage if present + if acc.Root != types.EmptyRootHash { + id := trie.StorageTrieID(ts.config.root, accountHash, acc.Root) + storageTrie, err := trie.NewStateTrie(id, ts.triedb) + if err != nil { + log.Error("Failed to open storage trie", "root", acc.Root, "err", err) + return err + } + + slots, nodes, err := hooks.OnStorageTrie(ctx, storageTrie, accountHash) + if err != nil { + return err + } + counters.slots.Add(slots) + counters.nodes.Add(nodes) + } + + // Process code if present + if !bytes.Equal(acc.CodeHash, types.EmptyCodeHash.Bytes()) { + if codeHash := common.BytesToHash(acc.CodeHash); !rawdb.HasCode(ts.chaindb, codeHash) { + log.Error("Code is missing", "hash", codeHash) + return errors.New("missing code") + } + counters.codes.Add(1) + } + return nil + } + if raw { // Raw traversal with detailed node checking for accIter.Next(true) { @@ -791,31 +799,8 @@ func (ts *traverseSetup) traverseStateBranchWithHooks(ctx context.Context, start return errors.New("invalid account") } - accountHash := common.BytesToHash(accIter.LeafKey()) - - if acc.Root != types.EmptyRootHash { - id := trie.StorageTrieID(ts.config.root, accountHash, acc.Root) - storageTrie, err := trie.NewStateTrie(id, ts.triedb) - if err != nil { - log.Error("Failed to open storage trie", "root", acc.Root, "err", err) - return err - } - - if hooks != nil && hooks.OnStorageTrie != nil { - slots, nodes, err := hooks.OnStorageTrie(ctx, storageTrie, accountHash) - if err != nil { - return err - } - counters.slots.Add(slots) - counters.nodes.Add(nodes) - } - } - - if !bytes.Equal(acc.CodeHash, types.EmptyCodeHash.Bytes()) { - if err := hooks.OnCode(acc.CodeHash, accountHash); err != nil { - return err - } - counters.codes.Add(1) + if err := processAccount(accIter.LeafKey(), &acc); err != nil { + return err } } } @@ -845,32 +830,8 @@ func (ts *traverseSetup) traverseStateBranchWithHooks(ctx context.Context, start log.Error("Invalid account encountered during traversal", "err", err) return err } - - accountHash := common.BytesToHash(accIter.Key) - - // Process storage if present - if acc.Root != types.EmptyRootHash { - id := trie.StorageTrieID(ts.config.root, accountHash, acc.Root) - storageTrie, err := trie.NewStateTrie(id, ts.triedb) - if err != nil { - log.Error("Failed to open storage trie", "root", acc.Root, "err", err) - return err - } - - slots, nodes, err := hooks.OnStorageTrie(ctx, storageTrie, accountHash) - if err != nil { - return err - } - counters.slots.Add(slots) - counters.nodes.Add(nodes) - } - - // Process code if present - if !bytes.Equal(acc.CodeHash, types.EmptyCodeHash.Bytes()) { - if err := hooks.OnCode(acc.CodeHash, accountHash); err != nil { - return err - } - counters.codes.Add(1) + if err := processAccount(accIter.Key, &acc); err != nil { + return err } }