From 20a3a7ded2439929993bbeaa4e5cfdde97b0eb56 Mon Sep 17 00:00:00 2001 From: jsvisa Date: Thu, 11 Sep 2025 03:28:16 +0000 Subject: [PATCH] refine Signed-off-by: jsvisa --- cmd/geth/snapshot.go | 189 +++++++++++++++++-------------------------- 1 file changed, 75 insertions(+), 114 deletions(-) diff --git a/cmd/geth/snapshot.go b/cmd/geth/snapshot.go index 0f4f952b7d..a0222dda22 100644 --- a/cmd/geth/snapshot.go +++ b/cmd/geth/snapshot.go @@ -334,45 +334,32 @@ func traverseState(ctx *cli.Context) error { } } -// StorageCallbacks defines the callbacks for different storage traversal modes -type StorageCallbacks struct { - // Called for each storage node (only for raw mode) - OnStorageNode func(node common.Hash, path []byte) error -} +type OnStorageNodeHook func(node common.Hash, path []byte) error -// createSimpleStorageCallbacks creates callbacks for simple storage traversal -func createSimpleStorageCallbacks() *StorageCallbacks { - return &StorageCallbacks{ - OnStorageNode: nil, // Simple mode doesn't process nodes - } -} - -// createRawStorageCallbacks creates callbacks for raw storage traversal with verification -func createRawStorageCallbacks(reader database.NodeReader, accountHash common.Hash) *StorageCallbacks { - return &StorageCallbacks{ - OnStorageNode: func(node common.Hash, path []byte) error { - if node != (common.Hash{}) { - blob, _ := reader.Node(accountHash, path, node) - if len(blob) == 0 { - log.Error("Missing trie node(storage)", "hash", node) - return errors.New("missing storage") - } - if !bytes.Equal(crypto.Keccak256(blob), node.Bytes()) { - log.Error("Invalid trie node(storage)", "hash", node.Hex(), "value", blob) - return errors.New("invalid storage node") - } +// createRawStorageHook creates hooks for raw storage traversal with verification +func createRawStorageHook(reader database.NodeReader, accountHash common.Hash) OnStorageNodeHook { + return func(node common.Hash, path []byte) error { + if node != (common.Hash{}) { + blob, _ := reader.Node(accountHash, path, node) + if len(blob) == 0 { + log.Error("Missing trie node(storage)", "hash", node) + return errors.New("missing storage") } - return nil - }, + if !bytes.Equal(crypto.Keccak256(blob), node.Bytes()) { + log.Error("Invalid trie node(storage)", "hash", node.Hex(), "value", blob) + return errors.New("invalid storage node") + } + } + return nil } } -// traverseStorageParallelWithCallbacks parallelizes storage trie traversal using callbacks -func traverseStorageParallelWithCallbacks(ctx context.Context, storageTrie *trie.StateTrie, startKey, limitKey []byte, callbacks *StorageCallbacks, raw bool) (slots, nodes uint64, err error) { +// traverseStorage parallelizes storage trie traversal +func traverseStorage(ctx context.Context, storageTrie *trie.StateTrie, startKey, limitKey []byte, raw bool, hook OnStorageNodeHook) (uint64, uint64, error) { var ( - eg, cctx = errgroup.WithContext(ctx) - slotsAtomic atomic.Uint64 - nodesAtomic atomic.Uint64 + eg, cctx = errgroup.WithContext(ctx) + slots atomic.Uint64 + nodes atomic.Uint64 ) for i := 0; i < 16; i++ { @@ -410,12 +397,12 @@ func traverseStorageParallelWithCallbacks(ctx context.Context, storageTrie *trie return nil } - localSlots, localNodes, err := traverseStorageBranchWithCallbacks(cctx, storageTrie, branchStart, branchLimit, callbacks, raw) + localSlots, localNodes, err := traverseStorageBranchWithHooks(cctx, storageTrie, branchStart, branchLimit, raw, hook) if err != nil { return err } - slotsAtomic.Add(localSlots) - nodesAtomic.Add(localNodes) + slots.Add(localSlots) + nodes.Add(localNodes) return nil }) } @@ -424,11 +411,11 @@ func traverseStorageParallelWithCallbacks(ctx context.Context, storageTrie *trie return 0, 0, err } - return slotsAtomic.Load(), nodesAtomic.Load(), nil + return slots.Load(), nodes.Load(), nil } -// traverseStorageBranchWithCallbacks traverses a specific range of the storage trie using callbacks -func traverseStorageBranchWithCallbacks(ctx context.Context, storageTrie *trie.StateTrie, startKey, limitKey []byte, callbacks *StorageCallbacks, raw bool) (slots, nodes uint64, err error) { +// traverseStorageBranchWithHooks traverses a specific range of the storage trie using hooks +func traverseStorageBranchWithHooks(ctx context.Context, storageTrie *trie.StateTrie, startKey, limitKey []byte, raw bool, hook OnStorageNodeHook) (slots, nodes uint64, err error) { nodeIter, err := storageTrie.NodeIterator(startKey) if err != nil { return 0, 0, err @@ -444,8 +431,8 @@ func traverseStorageBranchWithCallbacks(ctx context.Context, storageTrie *trie.S } nodes++ - if callbacks != nil && callbacks.OnStorageNode != nil { - if err := callbacks.OnStorageNode(nodeIter.Hash(), nodeIter.Path()); err != nil { + if hook != nil { + if err := hook(nodeIter.Hash(), nodeIter.Path()); err != nil { return 0, 0, err } } @@ -470,7 +457,6 @@ func traverseStorageBranchWithCallbacks(ctx context.Context, storageTrie *trie.S if limitKey != nil && bytes.Compare(storageIter.Key, limitKey) >= 0 { break } - slots++ } @@ -572,19 +558,17 @@ func (ts *traverseSetup) traverseAccount(ctx context.Context, counters *traverse return err } - var callbacks *StorageCallbacks + var hook OnStorageNodeHook if raw { reader, err := ts.triedb.NodeReader(ts.config.root) if err != nil { log.Error("State is non-existent", "root", ts.config.root) return nil } - callbacks = createRawStorageCallbacks(reader, ts.config.account) - } else { - callbacks = createSimpleStorageCallbacks() + hook = createRawStorageHook(reader, ts.config.account) } - slots, nodes, err := traverseStorageParallelWithCallbacks(ctx, storageTrie, ts.config.startKey, ts.config.limitKey, callbacks, raw) + slots, nodes, err := traverseStorage(ctx, storageTrie, ts.config.startKey, ts.config.limitKey, raw, hook) if err != nil { log.Error("Failed to traverse storage trie", "root", acc.Root, "err", err) return err @@ -649,15 +633,14 @@ func (ts *traverseSetup) traverseState(ctx context.Context, counters *traverseCo return nil } - // Create appropriate callbacks based on traversal mode - var callbacks *TraverseCallbacks + var hooks *TraverseHooks if raw { - callbacks = ts.createRawCallbacks(counters, reader) + hooks = ts.createRawHooks(reader) } else { - callbacks = ts.createSimpleCallbacks(counters) + hooks = ts.createSimpleHooks() } - return ts.traverseStateBranchWithCallbacks(ctx, startKey, limitKey, callbacks, raw) + return ts.traverseStateBranchWithHooks(ctx, startKey, limitKey, raw, counters, hooks) }) } @@ -715,50 +698,36 @@ func parseTraverseArgs(ctx *cli.Context) (*traverseConfig, error) { return config, nil } -// TraverseCallbacks defines the callbacks for different traversal modes -type TraverseCallbacks struct { +// TraverseHooks defines the hooks for different traversal modes +type TraverseHooks struct { // Called for each trie account node (only for raw mode) - OnNode func(node common.Hash, path []byte) error - // Called for each account - OnAccount func(acc types.StateAccount, accountHash common.Hash) error - // Called for each storage slot - OnStorage func(ctx context.Context, storageTrie *trie.StateTrie, accountHash common.Hash) (slots, nodes uint64, err error) + OnAccountNode func(node common.Hash, path []byte) error + // Called for each storage trie + OnStorageTrie func(ctx context.Context, storageTrie *trie.StateTrie, accountHash common.Hash) (slots, nodes uint64, err error) // Called for each code OnCode func(codeHash []byte, accountHash common.Hash) error } -// createSimpleCallbacks creates callbacks for simple traversal mode -func (ts *traverseSetup) createSimpleCallbacks(counters *traverseCounters) *TraverseCallbacks { - return &TraverseCallbacks{ - OnAccount: func(acc types.StateAccount, accountHash common.Hash) error { - counters.accounts.Add(1) - return nil - }, - OnStorage: func(ctx context.Context, storageTrie *trie.StateTrie, accountHash common.Hash) (slots, nodes uint64, err error) { - callbacks := createSimpleStorageCallbacks() - s, n, err := traverseStorageParallelWithCallbacks(ctx, storageTrie, nil, nil, callbacks, false) - if err != nil { - return 0, 0, err - } - counters.slots.Add(s) - return s, n, nil +// createSimpleHooks creates hooks for simple traversal mode +func (ts *traverseSetup) createSimpleHooks() *TraverseHooks { + return &TraverseHooks{ + OnStorageTrie: func(ctx context.Context, storageTrie *trie.StateTrie, accountHash common.Hash) (slots, nodes uint64, err error) { + return traverseStorage(ctx, storageTrie, nil, nil, false, nil) }, OnCode: func(codeHash []byte, accountHash common.Hash) error { if !rawdb.HasCode(ts.chaindb, common.BytesToHash(codeHash)) { log.Error("Code is missing", "hash", common.BytesToHash(codeHash)) return errors.New("missing code") } - counters.codes.Add(1) return nil }, } } -// createRawCallbacks creates callbacks for raw traversal mode with detailed verification -func (ts *traverseSetup) createRawCallbacks(counters *traverseCounters, reader database.NodeReader) *TraverseCallbacks { - return &TraverseCallbacks{ - OnNode: func(node common.Hash, path []byte) error { - counters.nodes.Add(1) +// createRawHooks creates hooks for raw traversal mode with detailed verification +func (ts *traverseSetup) createRawHooks(reader database.NodeReader) *TraverseHooks { + return &TraverseHooks{ + OnAccountNode: func(node common.Hash, path []byte) error { if node != (common.Hash{}) { blob, _ := reader.Node(common.Hash{}, path, node) if len(blob) == 0 { @@ -772,33 +741,22 @@ func (ts *traverseSetup) createRawCallbacks(counters *traverseCounters, reader d } return nil }, - OnAccount: func(acc types.StateAccount, accountHash common.Hash) error { - counters.accounts.Add(1) - return nil - }, - OnStorage: func(ctx context.Context, storageTrie *trie.StateTrie, accountHash common.Hash) (slots, nodes uint64, err error) { - callbacks := createRawStorageCallbacks(reader, accountHash) - s, n, err := traverseStorageParallelWithCallbacks(ctx, storageTrie, nil, nil, callbacks, true) - if err != nil { - return 0, 0, err - } - counters.slots.Add(s) - counters.nodes.Add(n) - return s, n, nil + OnStorageTrie: func(ctx context.Context, storageTrie *trie.StateTrie, accountHash common.Hash) (slots, nodes uint64, err error) { + hook := createRawStorageHook(reader, accountHash) + return traverseStorage(ctx, storageTrie, nil, nil, true, hook) }, OnCode: func(codeHash []byte, accountHash common.Hash) error { if !rawdb.HasCode(ts.chaindb, common.BytesToHash(codeHash)) { log.Error("Code is missing", "account", accountHash) return errors.New("missing code") } - counters.codes.Add(1) return nil }, } } -// traverseStateBranchWithCallbacks provides common branch traversal logic using callbacks -func (ts *traverseSetup) traverseStateBranchWithCallbacks(ctx context.Context, startKey, limitKey []byte, callbacks *TraverseCallbacks, raw bool) error { +// traverseStateBranchWithHooks provides common branch traversal logic using hooks +func (ts *traverseSetup) traverseStateBranchWithHooks(ctx context.Context, startKey, limitKey []byte, raw bool, counters *traverseCounters, hooks *TraverseHooks) error { accIter, err := ts.trie.NodeIterator(startKey) if err != nil { return err @@ -813,8 +771,9 @@ func (ts *traverseSetup) traverseStateBranchWithCallbacks(ctx context.Context, s default: } - if callbacks.OnNode != nil { - if err := callbacks.OnNode(accIter.Hash(), accIter.Path()); err != nil { + counters.nodes.Add(1) + if hooks != nil && hooks.OnAccountNode != nil { + if err := hooks.OnAccountNode(accIter.Hash(), accIter.Path()); err != nil { return err } } @@ -824,6 +783,7 @@ func (ts *traverseSetup) traverseStateBranchWithCallbacks(ctx context.Context, s if limitKey != nil && bytes.Compare(accIter.LeafKey(), limitKey) >= 0 { break } + counters.accounts.Add(1) var acc types.StateAccount if err := rlp.DecodeBytes(accIter.LeafBlob(), &acc); err != nil { @@ -833,11 +793,6 @@ func (ts *traverseSetup) traverseStateBranchWithCallbacks(ctx context.Context, s accountHash := common.BytesToHash(accIter.LeafKey()) - if err := callbacks.OnAccount(acc, accountHash); err != nil { - return err - } - - // Process storage if present if acc.Root != types.EmptyRootHash { id := trie.StorageTrieID(ts.config.root, accountHash, acc.Root) storageTrie, err := trie.NewStateTrie(id, ts.triedb) @@ -846,15 +801,21 @@ func (ts *traverseSetup) traverseStateBranchWithCallbacks(ctx context.Context, s return err } - if _, _, err := callbacks.OnStorage(ctx, storageTrie, accountHash); err != nil { - return err + if hooks != nil && hooks.OnStorageTrie != nil { + slots, nodes, err := hooks.OnStorageTrie(ctx, storageTrie, accountHash) + if err != nil { + return err + } + counters.slots.Add(slots) + counters.nodes.Add(nodes) } } if !bytes.Equal(acc.CodeHash, types.EmptyCodeHash.Bytes()) { - if err := callbacks.OnCode(acc.CodeHash, accountHash); err != nil { + if err := hooks.OnCode(acc.CodeHash, accountHash); err != nil { return err } + counters.codes.Add(1) } } } @@ -873,11 +834,12 @@ func (ts *traverseSetup) traverseStateBranchWithCallbacks(ctx context.Context, s default: } - // Check if we've reached the limit for this branch if limitKey != nil && bytes.Compare(accIter.Key, limitKey) >= 0 { break } + counters.accounts.Add(1) + var acc types.StateAccount if err := rlp.DecodeBytes(accIter.Value, &acc); err != nil { log.Error("Invalid account encountered during traversal", "err", err) @@ -886,11 +848,6 @@ func (ts *traverseSetup) traverseStateBranchWithCallbacks(ctx context.Context, s accountHash := common.BytesToHash(accIter.Key) - // Process account - if err := callbacks.OnAccount(acc, accountHash); err != nil { - return err - } - // Process storage if present if acc.Root != types.EmptyRootHash { id := trie.StorageTrieID(ts.config.root, accountHash, acc.Root) @@ -900,16 +857,20 @@ func (ts *traverseSetup) traverseStateBranchWithCallbacks(ctx context.Context, s return err } - if _, _, err := callbacks.OnStorage(ctx, storageTrie, accountHash); err != nil { + slots, nodes, err := hooks.OnStorageTrie(ctx, storageTrie, accountHash) + if err != nil { return err } + counters.slots.Add(slots) + counters.nodes.Add(nodes) } // Process code if present if !bytes.Equal(acc.CodeHash, types.EmptyCodeHash.Bytes()) { - if err := callbacks.OnCode(acc.CodeHash, accountHash); err != nil { + if err := hooks.OnCode(acc.CodeHash, accountHash); err != nil { return err } + counters.codes.Add(1) } }