From 83330c144a9e1878de945b73ad67c1794097881c Mon Sep 17 00:00:00 2001 From: jsvisa Date: Tue, 2 Sep 2025 15:33:05 +0800 Subject: [PATCH 1/9] cmd/geth: snapshot traverse-state by a start account Signed-off-by: jsvisa --- cmd/geth/snapshot.go | 35 ++++++++++++++++++++++++++--------- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/cmd/geth/snapshot.go b/cmd/geth/snapshot.go index 994cb149ce..032afb6d84 100644 --- a/cmd/geth/snapshot.go +++ b/cmd/geth/snapshot.go @@ -103,15 +103,18 @@ information about the specified address. { Name: "traverse-state", Usage: "Traverse the state with given root hash and perform quick verification", - ArgsUsage: "", + ArgsUsage: " [accountHash|accountAddress]", Action: traverseState, Flags: slices.Concat(utils.NetworkFlags, utils.DatabaseFlags), Description: ` -geth snapshot traverse-state +geth snapshot traverse-state [accountHash|accountAddress] will traverse the whole state from the given state root and will abort if any referenced trie node or contract code is missing. This command can be used for state integrity verification. The default checking target is the HEAD state. +If accountHash or accountAddress is provided, traversal will start from that specific account. +The format is auto-detected: 40/42 chars for address, 64/66 chars for hash. + It's also usable without snapshot enabled. `, }, @@ -290,25 +293,38 @@ func traverseState(ctx *cli.Context) error { log.Error("Failed to load head block") return errors.New("no head block") } - if ctx.NArg() > 1 { + if ctx.NArg() > 2 { log.Error("Too many arguments given") return errors.New("too many arguments") } var ( - root common.Hash - err error + root common.Hash + startKey []byte + err error ) - if ctx.NArg() == 1 { + if ctx.NArg() >= 1 { root, err = parseRoot(ctx.Args().First()) if err != nil { log.Error("Failed to resolve state root", "err", err) return err } - log.Info("Start traversing the state", "root", root) } else { root = headBlock.Root() - log.Info("Start traversing the state", "root", root, "number", headBlock.NumberU64()) } + + if ctx.NArg() == 2 { + arg := ctx.Args().Get(1) + switch len(arg) { + case 40, 42: + startKey = crypto.Keccak256Hash(common.HexToAddress(arg).Bytes()).Bytes() + case 64, 66: + startKey = common.HexToHash(arg).Bytes() + default: + return errors.New("invalid account format: must be 40/42 chars for address or 64/66 chars for hash") + } + } + + log.Info("Start traversing the state", "root", root.Hex(), "startKey", common.Bytes2Hex(startKey)) t, err := trie.NewStateTrie(trie.StateTrieID(root), triedb) if err != nil { log.Error("Failed to open trie", "root", root, "err", err) @@ -321,7 +337,8 @@ func traverseState(ctx *cli.Context) error { lastReport time.Time start = time.Now() ) - acctIt, err := t.NodeIterator(nil) + + acctIt, err := t.NodeIterator(startKey) if err != nil { log.Error("Failed to open iterator", "root", root, "err", err) return err From cd54a41336c7969e2f04afc9563d4ab70f536d6a Mon Sep 17 00:00:00 2001 From: jsvisa Date: Tue, 2 Sep 2025 15:52:00 +0800 Subject: [PATCH 2/9] same for traverse-rawstate Signed-off-by: jsvisa --- cmd/geth/snapshot.go | 96 ++++++++++++++++++++++---------------------- 1 file changed, 47 insertions(+), 49 deletions(-) diff --git a/cmd/geth/snapshot.go b/cmd/geth/snapshot.go index 032afb6d84..8dd51355f6 100644 --- a/cmd/geth/snapshot.go +++ b/cmd/geth/snapshot.go @@ -112,7 +112,7 @@ will traverse the whole state from the given state root and will abort if any referenced trie node or contract code is missing. This command can be used for state integrity verification. The default checking target is the HEAD state. -If accountHash or accountAddress is provided, traversal will start from that specific account. +If accountHash or accountAddress is provided, traversal will start from that specific account and continue through all subsequent accounts. The format is auto-detected: 40/42 chars for address, 64/66 chars for hash. It's also usable without snapshot enabled. @@ -121,7 +121,7 @@ It's also usable without snapshot enabled. { Name: "traverse-rawstate", Usage: "Traverse the state with given root hash and perform detailed verification", - ArgsUsage: "", + ArgsUsage: " [accountHash|accountAddress]", Action: traverseRawState, Flags: slices.Concat(utils.NetworkFlags, utils.DatabaseFlags), Description: ` @@ -131,6 +131,9 @@ trie node or contract code is missing. This command can be used for state integr verification. The default checking target is the HEAD state. It's basically identical to traverse-state, but the check granularity is smaller. +If accountHash or accountAddress is provided, traversal will start from that specific account and continue through all subsequent accounts. +The format is auto-detected: 40/42 chars for address, 64/66 chars for hash. + It's also usable without snapshot enabled. `, }, @@ -293,35 +296,13 @@ func traverseState(ctx *cli.Context) error { log.Error("Failed to load head block") return errors.New("no head block") } - if ctx.NArg() > 2 { - log.Error("Too many arguments given") - return errors.New("too many arguments") - } - var ( - root common.Hash - startKey []byte - err error - ) - if ctx.NArg() >= 1 { - root, err = parseRoot(ctx.Args().First()) - if err != nil { - log.Error("Failed to resolve state root", "err", err) - return err - } - } else { - root = headBlock.Root() - } - if ctx.NArg() == 2 { - arg := ctx.Args().Get(1) - switch len(arg) { - case 40, 42: - startKey = crypto.Keccak256Hash(common.HexToAddress(arg).Bytes()).Bytes() - case 64, 66: - startKey = common.HexToHash(arg).Bytes() - default: - return errors.New("invalid account format: must be 40/42 chars for address or 64/66 chars for hash") - } + root, startKey, err := parseTraverseArgs(ctx) + if err != nil { + return err + } + if root == (common.Hash{}) { + root = headBlock.Root() } log.Info("Start traversing the state", "root", root.Hex(), "startKey", common.Bytes2Hex(startKey)) @@ -397,6 +378,34 @@ func traverseState(ctx *cli.Context) error { return nil } +func parseTraverseArgs(ctx *cli.Context) (root common.Hash, startKey []byte, err error) { + if ctx.NArg() > 2 { + err = errors.New("too many arguments") + return + } + + if ctx.NArg() >= 1 { + root, err = parseRoot(ctx.Args().First()) + if err != nil { + return + } + } + + if ctx.NArg() == 2 { + arg := ctx.Args().Get(1) + switch len(arg) { + case 40, 42: + startKey = crypto.Keccak256Hash(common.HexToAddress(arg).Bytes()).Bytes() + case 64, 66: + startKey = common.HexToHash(arg).Bytes() + default: + err = errors.New("invalid account format: must be 40/42 chars for address or 64/66 chars for hash") + return + } + } + return root, startKey, nil +} + // traverseRawState is a helper function used for pruning verification. // Basically it just iterates the trie, ensure all nodes and associated // contract codes are present. It's basically identical to traverseState @@ -416,25 +425,14 @@ func traverseRawState(ctx *cli.Context) error { log.Error("Failed to load head block") return errors.New("no head block") } - if ctx.NArg() > 1 { - log.Error("Too many arguments given") - return errors.New("too many arguments") - } - var ( - root common.Hash - err error - ) - if ctx.NArg() == 1 { - root, err = parseRoot(ctx.Args().First()) - if err != nil { - log.Error("Failed to resolve state root", "err", err) - return err - } - log.Info("Start traversing the state", "root", root) - } else { - root = headBlock.Root() - log.Info("Start traversing the state", "root", root, "number", headBlock.NumberU64()) + + root, startKey, err := parseTraverseArgs(ctx) + if err != nil { + log.Error("Failed to parse arguments", "err", err) + return err } + + log.Info("Start traversing the state", "root", root.Hex(), "startKey", common.Bytes2Hex(startKey)) t, err := trie.NewStateTrie(trie.StateTrieID(root), triedb) if err != nil { log.Error("Failed to open trie", "root", root, "err", err) @@ -450,7 +448,7 @@ func traverseRawState(ctx *cli.Context) error { hasher = crypto.NewKeccakState() got = make([]byte, 32) ) - accIter, err := t.NodeIterator(nil) + accIter, err := t.NodeIterator(startKey) if err != nil { log.Error("Failed to open iterator", "root", root, "err", err) return err From f7074e170ce8d11b78f74d3acca180b2a43e8f64 Mon Sep 17 00:00:00 2001 From: jsvisa Date: Mon, 8 Sep 2025 04:01:06 +0000 Subject: [PATCH 3/9] traverse with --account Signed-off-by: jsvisa --- cmd/geth/snapshot.go | 590 ++++++++++++++++++++++++++++--------------- cmd/utils/flags.go | 16 ++ 2 files changed, 398 insertions(+), 208 deletions(-) diff --git a/cmd/geth/snapshot.go b/cmd/geth/snapshot.go index 8dd51355f6..e24bba082f 100644 --- a/cmd/geth/snapshot.go +++ b/cmd/geth/snapshot.go @@ -103,17 +103,24 @@ information about the specified address. { Name: "traverse-state", Usage: "Traverse the state with given root hash and perform quick verification", - ArgsUsage: " [accountHash|accountAddress]", + ArgsUsage: "", Action: traverseState, - Flags: slices.Concat(utils.NetworkFlags, utils.DatabaseFlags), + Flags: slices.Concat(utils.TraverseStateFlags, utils.NetworkFlags, utils.DatabaseFlags), Description: ` -geth snapshot traverse-state [accountHash|accountAddress] -will traverse the whole state from the given state root and will abort if any -referenced trie node or contract code is missing. This command can be used for -state integrity verification. The default checking target is the HEAD state. +geth snapshot traverse-state [--account ] [--start ] [--limit ] -If accountHash or accountAddress is provided, traversal will start from that specific account and continue through all subsequent accounts. -The format is auto-detected: 40/42 chars for address, 64/66 chars for hash. +1. Traverse the whole state from the given state root: +- --start: starting account key (64/66 chars hex) [optional] +- --limit: ending account key (64/66 chars hex) [optional] + +2. Traverse a specific account's storage: +- --account: account address (40/42 chars) or hash (64/66 chars) [required] +- --start: starting storage key (64/66 chars hex) [optional] +- --limit: ending storage key (64/66 chars hex) [optional] + +The default checking state root is the HEAD state if not specified. +The command will abort if any referenced trie node or contract code is missing. +This can be used for state integrity verification. The default target is HEAD state. It's also usable without snapshot enabled. `, @@ -121,18 +128,26 @@ It's also usable without snapshot enabled. { Name: "traverse-rawstate", Usage: "Traverse the state with given root hash and perform detailed verification", - ArgsUsage: " [accountHash|accountAddress]", + ArgsUsage: "", Action: traverseRawState, - Flags: slices.Concat(utils.NetworkFlags, utils.DatabaseFlags), + Flags: slices.Concat(utils.TraverseStateFlags, utils.NetworkFlags, utils.DatabaseFlags), Description: ` -geth snapshot traverse-rawstate -will traverse the whole state from the given root and will abort if any referenced -trie node or contract code is missing. This command can be used for state integrity -verification. The default checking target is the HEAD state. It's basically identical -to traverse-state, but the check granularity is smaller. +geth snapshot traverse-rawstate [--account ] [--start ] [--limit ] -If accountHash or accountAddress is provided, traversal will start from that specific account and continue through all subsequent accounts. -The format is auto-detected: 40/42 chars for address, 64/66 chars for hash. +Similar to traverse-state but with more detailed verification at the trie node level. + +1. Traverse the whole state from the given state root: +- --start: starting account key (64/66 chars hex) [optional] +- --limit: ending account key (64/66 chars hex) [optional] + +2. Traverse a specific account's storage: +- --account: account address (40/42 chars) or hash (64/66 chars) [required] +- --start: starting storage key (64/66 chars hex) [optional] +- --limit: ending storage key (64/66 chars hex) [optional] + +The default checking state root is the HEAD state if not specified. +The command will abort if any referenced trie node or contract code is missing. +This can be used for state integrity verification. The default target is HEAD state. It's also usable without snapshot enabled. `, @@ -297,113 +312,188 @@ func traverseState(ctx *cli.Context) error { return errors.New("no head block") } - root, startKey, err := parseTraverseArgs(ctx) + config, err := parseTraverseArgs(ctx) if err != nil { return err } - if root == (common.Hash{}) { - root = headBlock.Root() + if config.root == (common.Hash{}) { + config.root = headBlock.Root() } - log.Info("Start traversing the state", "root", root.Hex(), "startKey", common.Bytes2Hex(startKey)) - t, err := trie.NewStateTrie(trie.StateTrieID(root), triedb) + t, err := trie.NewStateTrie(trie.StateTrieID(config.root), triedb) if err != nil { - log.Error("Failed to open trie", "root", root, "err", err) + log.Error("Failed to open trie", "root", config.root, "err", err) return err } + var ( - accounts int - slots int - codes int - lastReport time.Time - start = time.Now() + accounts int + slots int + codes int + start = time.Now() ) - acctIt, err := t.NodeIterator(startKey) - if err != nil { - log.Error("Failed to open iterator", "root", root, "err", err) - return err - } - accIter := trie.NewIterator(acctIt) - for accIter.Next() { - accounts += 1 - var acc types.StateAccount - if err := rlp.DecodeBytes(accIter.Value, &acc); err != nil { - log.Error("Invalid account encountered during traversal", "err", err) + go func() { + timer := time.NewTicker(time.Second * 8) + defer timer.Stop() + for range timer.C { + log.Info("Traversing state", "accounts", accounts, "slots", slots, "codes", codes, "elapsed", common.PrettyDuration(time.Since(start))) + } + }() + + if config.isAccount { + log.Info("Start traversing storage trie", "root", config.root.Hex(), "account", config.account.Hex(), "startKey", common.Bytes2Hex(config.startKey), "limitKey", common.Bytes2Hex(config.limitKey)) + + acc, err := t.GetAccountByHash(config.account) + if err != nil { + log.Error("Get account failed", "account", config.account.Hex(), "err", err) return err } - if acc.Root != types.EmptyRootHash { - id := trie.StorageTrieID(root, common.BytesToHash(accIter.Key), acc.Root) - storageTrie, err := trie.NewStateTrie(id, triedb) - if err != nil { - log.Error("Failed to open storage trie", "root", acc.Root, "err", err) - return err - } - storageIt, err := storageTrie.NodeIterator(nil) - if err != nil { - log.Error("Failed to open storage iterator", "root", acc.Root, "err", err) - return err - } - storageIter := trie.NewIterator(storageIt) - for storageIter.Next() { - slots += 1 - if time.Since(lastReport) > time.Second*8 { - log.Info("Traversing state", "accounts", accounts, "slots", slots, "codes", codes, "elapsed", common.PrettyDuration(time.Since(start))) - lastReport = time.Now() + if acc.Root == types.EmptyRootHash { + log.Info("Account has no storage") + return nil + } + + id := trie.StorageTrieID(config.root, config.account, acc.Root) + storageTrie, err := trie.NewStateTrie(id, triedb) + if err != nil { + log.Error("Failed to open storage trie", "root", acc.Root, "err", err) + return err + } + + storageIt, err := storageTrie.NodeIterator(config.startKey) + if err != nil { + log.Error("Failed to open storage iterator", "root", acc.Root, "err", err) + return err + } + + storageIter := trie.NewIterator(storageIt) + for storageIter.Next() { + if config.limitKey != nil && bytes.Compare(storageIter.Key, config.limitKey) >= 0 { + break + } + + slots += 1 + log.Debug("Storage slot", "key", common.Bytes2Hex(storageIter.Key), "value", common.Bytes2Hex(storageIter.Value)) + } + if storageIter.Err != nil { + log.Error("Failed to traverse storage trie", "root", acc.Root, "err", storageIter.Err) + return storageIter.Err + } + + log.Info("Storage traversal complete", "slots", slots, "elapsed", common.PrettyDuration(time.Since(start))) + return nil + } else { + log.Info("Start traversing state trie", "root", config.root.Hex(), "startKey", common.Bytes2Hex(config.startKey), "limitKey", common.Bytes2Hex(config.limitKey)) + + acctIt, err := t.NodeIterator(config.startKey) + if err != nil { + log.Error("Failed to open iterator", "root", config.root, "err", err) + return err + } + accIter := trie.NewIterator(acctIt) + for accIter.Next() { + if config.limitKey != nil && bytes.Compare(accIter.Key, config.limitKey) >= 0 { + break + } + + accounts += 1 + var acc types.StateAccount + if err := rlp.DecodeBytes(accIter.Value, &acc); err != nil { + log.Error("Invalid account encountered during traversal", "err", err) + return err + } + if acc.Root != types.EmptyRootHash { + id := trie.StorageTrieID(config.root, common.BytesToHash(accIter.Key), acc.Root) + storageTrie, err := trie.NewStateTrie(id, triedb) + if err != nil { + log.Error("Failed to open storage trie", "root", acc.Root, "err", err) + return err + } + storageIt, err := storageTrie.NodeIterator(nil) + if err != nil { + log.Error("Failed to open storage iterator", "root", acc.Root, "err", err) + return err + } + storageIter := trie.NewIterator(storageIt) + for storageIter.Next() { + slots += 1 + } + if storageIter.Err != nil { + log.Error("Failed to traverse storage trie", "root", acc.Root, "err", storageIter.Err) + return storageIter.Err } } - if storageIter.Err != nil { - log.Error("Failed to traverse storage trie", "root", acc.Root, "err", storageIter.Err) - return storageIter.Err + if !bytes.Equal(acc.CodeHash, types.EmptyCodeHash.Bytes()) { + if !rawdb.HasCode(chaindb, common.BytesToHash(acc.CodeHash)) { + log.Error("Code is missing", "hash", common.BytesToHash(acc.CodeHash)) + return errors.New("missing code") + } + codes += 1 } } - if !bytes.Equal(acc.CodeHash, types.EmptyCodeHash.Bytes()) { - if !rawdb.HasCode(chaindb, common.BytesToHash(acc.CodeHash)) { - log.Error("Code is missing", "hash", common.BytesToHash(acc.CodeHash)) - return errors.New("missing code") - } - codes += 1 - } - if time.Since(lastReport) > time.Second*8 { - log.Info("Traversing state", "accounts", accounts, "slots", slots, "codes", codes, "elapsed", common.PrettyDuration(time.Since(start))) - lastReport = time.Now() + if accIter.Err != nil { + log.Error("Failed to traverse state trie", "root", config.root, "err", accIter.Err) + return accIter.Err } + log.Info("State traversal complete", "accounts", accounts, "slots", slots, "codes", codes, "elapsed", common.PrettyDuration(time.Since(start))) + return nil } - if accIter.Err != nil { - log.Error("Failed to traverse state trie", "root", root, "err", accIter.Err) - return accIter.Err - } - log.Info("State is complete", "accounts", accounts, "slots", slots, "codes", codes, "elapsed", common.PrettyDuration(time.Since(start))) - return nil } -func parseTraverseArgs(ctx *cli.Context) (root common.Hash, startKey []byte, err error) { - if ctx.NArg() > 2 { - err = errors.New("too many arguments") - return +type traverseConfig struct { + root common.Hash + startKey []byte + limitKey []byte + account common.Hash + isAccount bool +} + +func parseTraverseArgs(ctx *cli.Context) (*traverseConfig, error) { + if ctx.NArg() > 1 { + return nil, errors.New("too many arguments, only is required") } - if ctx.NArg() >= 1 { - root, err = parseRoot(ctx.Args().First()) + config := &traverseConfig{} + var err error + + if ctx.NArg() == 1 { + config.root, err = parseRoot(ctx.Args().First()) if err != nil { - return + return nil, err } } - if ctx.NArg() == 2 { - arg := ctx.Args().Get(1) - switch len(arg) { + if accountFlag := ctx.String("account"); accountFlag != "" { + config.isAccount = true + switch len(accountFlag) { case 40, 42: - startKey = crypto.Keccak256Hash(common.HexToAddress(arg).Bytes()).Bytes() + config.account = crypto.Keccak256Hash(common.HexToAddress(accountFlag).Bytes()) case 64, 66: - startKey = common.HexToHash(arg).Bytes() + config.account = common.HexToHash(accountFlag) default: - err = errors.New("invalid account format: must be 40/42 chars for address or 64/66 chars for hash") - return + return nil, errors.New("account must be 40/42 chars for address or 64/66 chars for hash") } } - return root, startKey, nil + + if startFlag := ctx.String("start"); startFlag != "" { + if len(startFlag) == 64 || len(startFlag) == 66 { + config.startKey = common.HexToHash(startFlag).Bytes() + } else { + return nil, errors.New("start key must be 64/66 chars hex") + } + } + + if limitFlag := ctx.String("limit"); limitFlag != "" { + if len(limitFlag) == 64 || len(limitFlag) == 66 { + config.limitKey = common.HexToHash(limitFlag).Bytes() + } else { + return nil, errors.New("limit key must be 64/66 chars hex") + } + } + + return config, nil } // traverseRawState is a helper function used for pruning verification. @@ -426,132 +516,216 @@ func traverseRawState(ctx *cli.Context) error { return errors.New("no head block") } - root, startKey, err := parseTraverseArgs(ctx) + config, err := parseTraverseArgs(ctx) if err != nil { log.Error("Failed to parse arguments", "err", err) return err } + if config.root == (common.Hash{}) { + config.root = headBlock.Root() + } + t, err := trie.NewStateTrie(trie.StateTrieID(config.root), triedb) + if err != nil { + log.Error("Failed to open trie", "root", config.root, "err", err) + return err + } - log.Info("Start traversing the state", "root", root.Hex(), "startKey", common.Bytes2Hex(startKey)) - t, err := trie.NewStateTrie(trie.StateTrieID(root), triedb) - if err != nil { - log.Error("Failed to open trie", "root", root, "err", err) - return err - } var ( - nodes int - accounts int - slots int - codes int - lastReport time.Time - start = time.Now() - hasher = crypto.NewKeccakState() - got = make([]byte, 32) + accounts int + nodes int + slots int + codes int + start = time.Now() + hasher = crypto.NewKeccakState() + got = make([]byte, 32) ) - accIter, err := t.NodeIterator(startKey) - if err != nil { - log.Error("Failed to open iterator", "root", root, "err", err) - return err - } - reader, err := triedb.NodeReader(root) - if err != nil { - log.Error("State is non-existent", "root", root) + + go func() { + timer := time.NewTicker(time.Second * 8) + defer timer.Stop() + for range timer.C { + log.Info("Traversing rawstate", "nodes", nodes, "accounts", accounts, "slots", slots, "codes", codes, "elapsed", common.PrettyDuration(time.Since(start))) + } + }() + + if config.isAccount { + log.Info("Start traversing storage trie (raw)", "root", config.root.Hex(), "account", config.account.Hex(), "startKey", common.Bytes2Hex(config.startKey), "limitKey", common.Bytes2Hex(config.limitKey)) + + acc, err := t.GetAccountByHash(config.account) + if err != nil { + log.Error("Get account failed", "account", config.account.Hex(), "err", err) + return err + } + + if acc.Root == types.EmptyRootHash { + log.Info("Account has no storage") + return nil + } + + // Traverse the storage trie with detailed verification + id := trie.StorageTrieID(config.root, config.account, acc.Root) + storageTrie, err := trie.NewStateTrie(id, triedb) + if err != nil { + log.Error("Failed to open storage trie", "root", acc.Root, "err", err) + return err + } + + storageIter, err := storageTrie.NodeIterator(config.startKey) + if err != nil { + log.Error("Failed to open storage iterator", "root", acc.Root, "err", err) + return err + } + + reader, err := triedb.NodeReader(config.root) + if err != nil { + log.Error("State is non-existent", "root", config.root) + return nil + } + + for storageIter.Next(true) { + nodes += 1 + node := storageIter.Hash() + + // Check the presence for non-empty hash node(embedded node doesn't + // have their own hash). + if node != (common.Hash{}) { + blob, _ := reader.Node(config.account, storageIter.Path(), node) + if len(blob) == 0 { + log.Error("Missing trie node(storage)", "hash", node) + return errors.New("missing storage") + } + hasher.Reset() + hasher.Write(blob) + hasher.Read(got) + if !bytes.Equal(got, node.Bytes()) { + log.Error("Invalid trie node(storage)", "hash", node.Hex(), "value", blob) + return errors.New("invalid storage node") + } + } + + // Bump the counter if it's leaf node. + if storageIter.Leaf() { + // Check if we've exceeded the limit key for storage + if config.limitKey != nil && bytes.Compare(storageIter.LeafKey(), config.limitKey) >= 0 { + break + } + + slots += 1 + log.Debug("Storage slot", "key", common.Bytes2Hex(storageIter.LeafKey()), "value", common.Bytes2Hex(storageIter.LeafBlob())) + } + } + if storageIter.Error() != nil { + log.Error("Failed to traverse storage trie", "root", acc.Root, "err", storageIter.Error()) + return storageIter.Error() + } + + log.Info("Storage traversal complete (raw)", "nodes", nodes, "slots", slots, "elapsed", common.PrettyDuration(time.Since(start))) + return nil + } else { + log.Info("Start traversing the state trie (raw)", "root", config.root.Hex(), "startKey", common.Bytes2Hex(config.startKey), "limitKey", common.Bytes2Hex(config.limitKey)) + + accIter, err := t.NodeIterator(config.startKey) + if err != nil { + log.Error("Failed to open iterator", "root", config.root, "err", err) + return err + } + reader, err := triedb.NodeReader(config.root) + if err != nil { + log.Error("State is non-existent", "root", config.root) + return nil + } + for accIter.Next(true) { + nodes += 1 + node := accIter.Hash() + + // Check the present for non-empty hash node(embedded node doesn't + // have their own hash). + if node != (common.Hash{}) { + blob, _ := reader.Node(common.Hash{}, accIter.Path(), node) + if len(blob) == 0 { + log.Error("Missing trie node(account)", "hash", node) + return errors.New("missing account") + } + hasher.Reset() + hasher.Write(blob) + hasher.Read(got) + if !bytes.Equal(got, node.Bytes()) { + log.Error("Invalid trie node(account)", "hash", node.Hex(), "value", blob) + return errors.New("invalid account node") + } + } + // If it's a leaf node, yes we are touching an account, + // dig into the storage trie further. + if accIter.Leaf() { + // Check if we've exceeded the limit key for accounts + if config.limitKey != nil && bytes.Compare(accIter.LeafKey(), config.limitKey) >= 0 { + break + } + + accounts += 1 + var acc types.StateAccount + if err := rlp.DecodeBytes(accIter.LeafBlob(), &acc); err != nil { + log.Error("Invalid account encountered during traversal", "err", err) + return errors.New("invalid account") + } + if acc.Root != types.EmptyRootHash { + id := trie.StorageTrieID(config.root, common.BytesToHash(accIter.LeafKey()), acc.Root) + storageTrie, err := trie.NewStateTrie(id, triedb) + if err != nil { + log.Error("Failed to open storage trie", "root", acc.Root, "err", err) + return errors.New("missing storage trie") + } + storageIter, err := storageTrie.NodeIterator(nil) + if err != nil { + log.Error("Failed to open storage iterator", "root", acc.Root, "err", err) + return err + } + for storageIter.Next(true) { + nodes += 1 + node := storageIter.Hash() + + // Check the presence for non-empty hash node(embedded node doesn't + // have their own hash). + if node != (common.Hash{}) { + blob, _ := reader.Node(common.BytesToHash(accIter.LeafKey()), storageIter.Path(), node) + if len(blob) == 0 { + log.Error("Missing trie node(storage)", "hash", node) + return errors.New("missing storage") + } + hasher.Reset() + hasher.Write(blob) + hasher.Read(got) + if !bytes.Equal(got, node.Bytes()) { + log.Error("Invalid trie node(storage)", "hash", node.Hex(), "value", blob) + return errors.New("invalid storage node") + } + } + // Bump the counter if it's leaf node. + if storageIter.Leaf() { + slots += 1 + } + } + if storageIter.Error() != nil { + log.Error("Failed to traverse storage trie", "root", acc.Root, "err", storageIter.Error()) + return storageIter.Error() + } + } + if !bytes.Equal(acc.CodeHash, types.EmptyCodeHash.Bytes()) { + if !rawdb.HasCode(chaindb, common.BytesToHash(acc.CodeHash)) { + log.Error("Code is missing", "account", common.BytesToHash(accIter.LeafKey())) + return errors.New("missing code") + } + codes += 1 + } + } + } + if accIter.Error() != nil { + log.Error("Failed to traverse state trie", "root", config.root, "err", accIter.Error()) + return accIter.Error() + } + log.Info("State traversal complete (raw)", "nodes", nodes, "accounts", accounts, "slots", slots, "codes", codes, "elapsed", common.PrettyDuration(time.Since(start))) return nil } - for accIter.Next(true) { - nodes += 1 - node := accIter.Hash() - - // Check the present for non-empty hash node(embedded node doesn't - // have their own hash). - if node != (common.Hash{}) { - blob, _ := reader.Node(common.Hash{}, accIter.Path(), node) - if len(blob) == 0 { - log.Error("Missing trie node(account)", "hash", node) - return errors.New("missing account") - } - hasher.Reset() - hasher.Write(blob) - hasher.Read(got) - if !bytes.Equal(got, node.Bytes()) { - log.Error("Invalid trie node(account)", "hash", node.Hex(), "value", blob) - return errors.New("invalid account node") - } - } - // If it's a leaf node, yes we are touching an account, - // dig into the storage trie further. - if accIter.Leaf() { - accounts += 1 - var acc types.StateAccount - if err := rlp.DecodeBytes(accIter.LeafBlob(), &acc); err != nil { - log.Error("Invalid account encountered during traversal", "err", err) - return errors.New("invalid account") - } - if acc.Root != types.EmptyRootHash { - id := trie.StorageTrieID(root, common.BytesToHash(accIter.LeafKey()), acc.Root) - storageTrie, err := trie.NewStateTrie(id, triedb) - if err != nil { - log.Error("Failed to open storage trie", "root", acc.Root, "err", err) - return errors.New("missing storage trie") - } - storageIter, err := storageTrie.NodeIterator(nil) - if err != nil { - log.Error("Failed to open storage iterator", "root", acc.Root, "err", err) - return err - } - for storageIter.Next(true) { - nodes += 1 - node := storageIter.Hash() - - // Check the presence for non-empty hash node(embedded node doesn't - // have their own hash). - if node != (common.Hash{}) { - blob, _ := reader.Node(common.BytesToHash(accIter.LeafKey()), storageIter.Path(), node) - if len(blob) == 0 { - log.Error("Missing trie node(storage)", "hash", node) - return errors.New("missing storage") - } - hasher.Reset() - hasher.Write(blob) - hasher.Read(got) - if !bytes.Equal(got, node.Bytes()) { - log.Error("Invalid trie node(storage)", "hash", node.Hex(), "value", blob) - return errors.New("invalid storage node") - } - } - // Bump the counter if it's leaf node. - if storageIter.Leaf() { - slots += 1 - } - if time.Since(lastReport) > time.Second*8 { - log.Info("Traversing state", "nodes", nodes, "accounts", accounts, "slots", slots, "codes", codes, "elapsed", common.PrettyDuration(time.Since(start))) - lastReport = time.Now() - } - } - if storageIter.Error() != nil { - log.Error("Failed to traverse storage trie", "root", acc.Root, "err", storageIter.Error()) - return storageIter.Error() - } - } - if !bytes.Equal(acc.CodeHash, types.EmptyCodeHash.Bytes()) { - if !rawdb.HasCode(chaindb, common.BytesToHash(acc.CodeHash)) { - log.Error("Code is missing", "account", common.BytesToHash(accIter.LeafKey())) - return errors.New("missing code") - } - codes += 1 - } - if time.Since(lastReport) > time.Second*8 { - log.Info("Traversing state", "nodes", nodes, "accounts", accounts, "slots", slots, "codes", codes, "elapsed", common.PrettyDuration(time.Since(start))) - lastReport = time.Now() - } - } - } - if accIter.Error() != nil { - log.Error("Failed to traverse state trie", "root", root, "err", accIter.Error()) - return accIter.Error() - } - log.Info("State is complete", "nodes", nodes, "accounts", accounts, "slots", slots, "codes", codes, "elapsed", common.PrettyDuration(time.Since(start))) - return nil } func parseRoot(input string) (common.Hash, error) { diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index a134ea4308..5a5fdf88f9 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -994,6 +994,22 @@ var ( StateSchemeFlag, HttpHeaderFlag, } + + // TraverseStateFlags is the flag group of all state traversal flags. + TraverseStateFlags = []cli.Flag{ + &cli.StringFlag{ + Name: "account", + Usage: "Account address or hash to traverse storage for (enables account mode)", + }, + &cli.StringFlag{ + Name: "start", + Usage: "Starting key (account key for state mode, storage key for account mode)", + }, + &cli.StringFlag{ + Name: "limit", + Usage: "Ending key (account key for state mode, storage key for account mode)", + }, + } ) // default account to prefund when running Geth in dev mode From dc19cae10e1b7fb034cfa68b090c07013ef165e9 Mon Sep 17 00:00:00 2001 From: jsvisa Date: Tue, 9 Sep 2025 09:29:29 +0000 Subject: [PATCH 4/9] in parallel Signed-off-by: jsvisa --- cmd/geth/snapshot.go | 217 +++++++++++++++++++++++++++++++++---------- 1 file changed, 166 insertions(+), 51 deletions(-) diff --git a/cmd/geth/snapshot.go b/cmd/geth/snapshot.go index e24bba082f..1a8573e118 100644 --- a/cmd/geth/snapshot.go +++ b/cmd/geth/snapshot.go @@ -18,13 +18,17 @@ package main import ( "bytes" + "context" "encoding/json" "errors" "fmt" "os" "slices" + "sync/atomic" "time" + "golang.org/x/sync/errgroup" + "github.com/ethereum/go-ethereum/cmd/utils" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/rawdb" @@ -33,9 +37,11 @@ import ( "github.com/ethereum/go-ethereum/core/state/snapshot" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" + "github.com/ethereum/go-ethereum/triedb" "github.com/urfave/cli/v2" ) @@ -327,9 +333,9 @@ func traverseState(ctx *cli.Context) error { } var ( - accounts int - slots int - codes int + accounts atomic.Uint64 + slots atomic.Uint64 + codes atomic.Uint64 start = time.Now() ) @@ -337,7 +343,7 @@ func traverseState(ctx *cli.Context) error { timer := time.NewTicker(time.Second * 8) defer timer.Stop() for range timer.C { - log.Info("Traversing state", "accounts", accounts, "slots", slots, "codes", codes, "elapsed", common.PrettyDuration(time.Since(start))) + log.Info("Traversing state", "accounts", accounts.Load(), "slots", slots.Load(), "codes", codes.Load(), "elapsed", common.PrettyDuration(time.Since(start))) } }() @@ -374,7 +380,7 @@ func traverseState(ctx *cli.Context) error { break } - slots += 1 + slots.Add(1) log.Debug("Storage slot", "key", common.Bytes2Hex(storageIter.Key), "value", common.Bytes2Hex(storageIter.Value)) } if storageIter.Err != nil { @@ -382,64 +388,173 @@ func traverseState(ctx *cli.Context) error { return storageIter.Err } - log.Info("Storage traversal complete", "slots", slots, "elapsed", common.PrettyDuration(time.Since(start))) + log.Info("Storage traversal complete", "slots", slots.Load(), "elapsed", common.PrettyDuration(time.Since(start))) return nil } else { log.Info("Start traversing state trie", "root", config.root.Hex(), "startKey", common.Bytes2Hex(config.startKey), "limitKey", common.Bytes2Hex(config.limitKey)) - acctIt, err := t.NodeIterator(config.startKey) - if err != nil { - log.Error("Failed to open iterator", "root", config.root, "err", err) - return err - } - accIter := trie.NewIterator(acctIt) - for accIter.Next() { - if config.limitKey != nil && bytes.Compare(accIter.Key, config.limitKey) >= 0 { - break + return traverseStateParallel(t, triedb, chaindb, config, &accounts, &slots, &codes, start) + } +} + +// traverseStateParallel parallelizes state traversal by dividing work across 16 trie branches +func traverseStateParallel(t *trie.StateTrie, triedb *triedb.Database, chaindb ethdb.Database, config *traverseConfig, accounts, slots, codes *atomic.Uint64, start time.Time) error { + ctx := context.Background() + g, ctx := errgroup.WithContext(ctx) + + for i := 0; i < 16; i++ { + nibble := byte(i) + g.Go(func() error { + startKey := config.startKey + limitKey := config.limitKey + + branchStartKey := make([]byte, len(startKey)+1) + branchLimitKey := make([]byte, len(startKey)+1) + + if len(startKey) > 0 { + copy(branchStartKey, startKey) + copy(branchLimitKey, startKey) } - accounts += 1 - var acc types.StateAccount - if err := rlp.DecodeBytes(accIter.Value, &acc); err != nil { - log.Error("Invalid account encountered during traversal", "err", err) + branchStartKey[len(startKey)] = nibble << 4 + branchLimitKey[len(startKey)] = (nibble + 1) << 4 + + if limitKey != nil && bytes.Compare(branchStartKey, limitKey) >= 0 { + return nil + } + if limitKey != nil && bytes.Compare(branchLimitKey, limitKey) > 0 { + branchLimitKey = make([]byte, len(limitKey)) + copy(branchLimitKey, limitKey) + } + + return traverseBranch(ctx, t, triedb, chaindb, config.root, branchStartKey, branchLimitKey, accounts, slots, codes) + }) + } + + if err := g.Wait(); err != nil { + return err + } + + log.Info("State traversal complete", "accounts", accounts.Load(), "slots", slots.Load(), "codes", codes.Load(), "elapsed", common.PrettyDuration(time.Since(start))) + return nil +} + +// traverseBranch traverses a specific branch of the state trie +func traverseBranch(ctx context.Context, t *trie.StateTrie, triedb *triedb.Database, chaindb ethdb.Database, root common.Hash, startKey, limitKey []byte, accounts, slots, codes *atomic.Uint64) error { + acctIt, err := t.NodeIterator(startKey) + if err != nil { + return err + } + + accIter := trie.NewIterator(acctIt) + for accIter.Next() { + // Check if context was cancelled + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + if limitKey != nil && bytes.Compare(accIter.Key, limitKey) >= 0 { + break + } + + accounts.Add(1) + + var acc types.StateAccount + if err := rlp.DecodeBytes(accIter.Value, &acc); err != nil { + log.Error("Invalid account encountered during traversal", "err", err) + return err + } + + if acc.Root != types.EmptyRootHash { + id := trie.StorageTrieID(root, common.BytesToHash(accIter.Key), acc.Root) + storageTrie, err := trie.NewStateTrie(id, triedb) + if err != nil { + log.Error("Failed to open storage trie", "root", acc.Root, "err", err) return err } - if acc.Root != types.EmptyRootHash { - id := trie.StorageTrieID(config.root, common.BytesToHash(accIter.Key), acc.Root) - storageTrie, err := trie.NewStateTrie(id, triedb) - if err != nil { - log.Error("Failed to open storage trie", "root", acc.Root, "err", err) - return err - } - storageIt, err := storageTrie.NodeIterator(nil) - if err != nil { - log.Error("Failed to open storage iterator", "root", acc.Root, "err", err) - return err - } - storageIter := trie.NewIterator(storageIt) - for storageIter.Next() { - slots += 1 - } - if storageIter.Err != nil { - log.Error("Failed to traverse storage trie", "root", acc.Root, "err", storageIter.Err) - return storageIter.Err - } + + localSlots, err := traverseStorageParallel(ctx, storageTrie) + if err != nil { + log.Error("Failed to traverse storage trie", "root", acc.Root, "err", err) + return err } - if !bytes.Equal(acc.CodeHash, types.EmptyCodeHash.Bytes()) { - if !rawdb.HasCode(chaindb, common.BytesToHash(acc.CodeHash)) { - log.Error("Code is missing", "hash", common.BytesToHash(acc.CodeHash)) - return errors.New("missing code") - } - codes += 1 + slots.Add(localSlots) + } + + if !bytes.Equal(acc.CodeHash, types.EmptyCodeHash.Bytes()) { + if !rawdb.HasCode(chaindb, common.BytesToHash(acc.CodeHash)) { + log.Error("Code is missing", "hash", common.BytesToHash(acc.CodeHash)) + return errors.New("missing code") } + codes.Add(1) } - if accIter.Err != nil { - log.Error("Failed to traverse state trie", "root", config.root, "err", accIter.Err) - return accIter.Err - } - log.Info("State traversal complete", "accounts", accounts, "slots", slots, "codes", codes, "elapsed", common.PrettyDuration(time.Since(start))) - return nil } + + if accIter.Err != nil { + return accIter.Err + } + + return nil +} + +// traverseStorageParallel parallelizes storage trie traversal by dividing work across 16 trie branches +func traverseStorageParallel(ctx context.Context, storageTrie *trie.StateTrie) (uint64, error) { + g, ctx := errgroup.WithContext(ctx) + totalSlots := atomic.Uint64{} + + for i := 0; i < 16; i++ { + nibble := byte(i) + g.Go(func() error { + branchStartKey := []byte{nibble << 4} + branchLimitKey := []byte{(nibble + 1) << 4} + + localSlots, err := traverseStorageBranch(ctx, storageTrie, branchStartKey, branchLimitKey) + if err != nil { + return err + } + totalSlots.Add(localSlots) + return nil + }) + } + + if err := g.Wait(); err != nil { + return 0, err + } + + return totalSlots.Load(), nil +} + +// traverseStorageBranch traverses a specific branch of the storage trie +func traverseStorageBranch(ctx context.Context, storageTrie *trie.StateTrie, startKey, limitKey []byte) (uint64, error) { + storageIt, err := storageTrie.NodeIterator(startKey) + if err != nil { + return 0, err + } + + storageIter := trie.NewIterator(storageIt) + slots := uint64(0) + + for storageIter.Next() { + // Check if context was cancelled + select { + case <-ctx.Done(): + return 0, ctx.Err() + default: + } + + if bytes.Compare(storageIter.Key, limitKey) >= 0 { + break + } + slots++ + } + + if storageIter.Err != nil { + return 0, storageIter.Err + } + + return slots, nil } type traverseConfig struct { From d22a7d833adca3d4163592f1e11fa1311d2f50fc Mon Sep 17 00:00:00 2001 From: jsvisa Date: Wed, 10 Sep 2025 04:51:47 +0000 Subject: [PATCH 5/9] in parallel Signed-off-by: jsvisa --- cmd/geth/snapshot.go | 122 +++++++++++++++++++++++++------------------ 1 file changed, 70 insertions(+), 52 deletions(-) diff --git a/cmd/geth/snapshot.go b/cmd/geth/snapshot.go index 1a8573e118..edcb67cc9a 100644 --- a/cmd/geth/snapshot.go +++ b/cmd/geth/snapshot.go @@ -339,11 +339,19 @@ func traverseState(ctx *cli.Context) error { start = time.Now() ) + cctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { timer := time.NewTicker(time.Second * 8) defer timer.Stop() - for range timer.C { - log.Info("Traversing state", "accounts", accounts.Load(), "slots", slots.Load(), "codes", codes.Load(), "elapsed", common.PrettyDuration(time.Since(start))) + for { + select { + case <-timer.C: + log.Info("Traversing state", "accounts", accounts.Load(), "slots", slots.Load(), "codes", codes.Load(), "elapsed", common.PrettyDuration(time.Since(start))) + case <-cctx.Done(): + return + } } }() @@ -389,58 +397,57 @@ func traverseState(ctx *cli.Context) error { } log.Info("Storage traversal complete", "slots", slots.Load(), "elapsed", common.PrettyDuration(time.Since(start))) - return nil } else { log.Info("Start traversing state trie", "root", config.root.Hex(), "startKey", common.Bytes2Hex(config.startKey), "limitKey", common.Bytes2Hex(config.limitKey)) - return traverseStateParallel(t, triedb, chaindb, config, &accounts, &slots, &codes, start) + eg, ctx := errgroup.WithContext(cctx) + + // Parallel processing with boundary checks + for i := 0; i < 16; i++ { + nibble := byte(i) + eg.Go(func() error { + // Skip branches that are entirely before startKey + if len(config.startKey) > 0 { + startNibble := config.startKey[0] >> 4 + if nibble < startNibble { + return nil // Skip this branch + } + } + + // Skip branches that are entirely after limitKey + if len(config.limitKey) > 0 { + limitNibble := config.limitKey[0] >> 4 + if nibble > limitNibble { + return nil // Skip this branch + } + } + + var ( + startKey = []byte{nibble << 4} + limitKey []byte + ) + if nibble < 15 { + limitKey = []byte{(nibble + 1) << 4} + } else { + // Last branch (0xf*) has no limit + limitKey = nil + } + + return traverseBranch(ctx, t, triedb, chaindb, config, startKey, limitKey, &accounts, &slots, &codes) + }) + } + + if err := eg.Wait(); err != nil { + return err + } + + log.Info("State traversal complete", "accounts", accounts.Load(), "slots", slots.Load(), "codes", codes.Load(), "elapsed", common.PrettyDuration(time.Since(start))) } -} - -// traverseStateParallel parallelizes state traversal by dividing work across 16 trie branches -func traverseStateParallel(t *trie.StateTrie, triedb *triedb.Database, chaindb ethdb.Database, config *traverseConfig, accounts, slots, codes *atomic.Uint64, start time.Time) error { - ctx := context.Background() - g, ctx := errgroup.WithContext(ctx) - - for i := 0; i < 16; i++ { - nibble := byte(i) - g.Go(func() error { - startKey := config.startKey - limitKey := config.limitKey - - branchStartKey := make([]byte, len(startKey)+1) - branchLimitKey := make([]byte, len(startKey)+1) - - if len(startKey) > 0 { - copy(branchStartKey, startKey) - copy(branchLimitKey, startKey) - } - - branchStartKey[len(startKey)] = nibble << 4 - branchLimitKey[len(startKey)] = (nibble + 1) << 4 - - if limitKey != nil && bytes.Compare(branchStartKey, limitKey) >= 0 { - return nil - } - if limitKey != nil && bytes.Compare(branchLimitKey, limitKey) > 0 { - branchLimitKey = make([]byte, len(limitKey)) - copy(branchLimitKey, limitKey) - } - - return traverseBranch(ctx, t, triedb, chaindb, config.root, branchStartKey, branchLimitKey, accounts, slots, codes) - }) - } - - if err := g.Wait(); err != nil { - return err - } - - log.Info("State traversal complete", "accounts", accounts.Load(), "slots", slots.Load(), "codes", codes.Load(), "elapsed", common.PrettyDuration(time.Since(start))) return nil } // traverseBranch traverses a specific branch of the state trie -func traverseBranch(ctx context.Context, t *trie.StateTrie, triedb *triedb.Database, chaindb ethdb.Database, root common.Hash, startKey, limitKey []byte, accounts, slots, codes *atomic.Uint64) error { +func traverseBranch(ctx context.Context, t *trie.StateTrie, triedb *triedb.Database, chaindb ethdb.Database, config *traverseConfig, startKey, limitKey []byte, accounts, slots, codes *atomic.Uint64) error { acctIt, err := t.NodeIterator(startKey) if err != nil { return err @@ -448,13 +455,16 @@ func traverseBranch(ctx context.Context, t *trie.StateTrie, triedb *triedb.Datab accIter := trie.NewIterator(acctIt) for accIter.Next() { - // Check if context was cancelled select { case <-ctx.Done(): return ctx.Err() default: } + if config.startKey != nil && bytes.Compare(accIter.Key, config.startKey) < 0 { + continue + } + if limitKey != nil && bytes.Compare(accIter.Key, limitKey) >= 0 { break } @@ -468,7 +478,7 @@ func traverseBranch(ctx context.Context, t *trie.StateTrie, triedb *triedb.Datab } if acc.Root != types.EmptyRootHash { - id := trie.StorageTrieID(root, common.BytesToHash(accIter.Key), acc.Root) + id := trie.StorageTrieID(config.root, common.BytesToHash(accIter.Key), acc.Root) storageTrie, err := trie.NewStateTrie(id, triedb) if err != nil { log.Error("Failed to open storage trie", "root", acc.Root, "err", err) @@ -507,10 +517,18 @@ func traverseStorageParallel(ctx context.Context, storageTrie *trie.StateTrie) ( for i := 0; i < 16; i++ { nibble := byte(i) g.Go(func() error { - branchStartKey := []byte{nibble << 4} - branchLimitKey := []byte{(nibble + 1) << 4} + var ( + startKey = []byte{nibble << 4} + limitKey []byte + ) + if nibble < 15 { + limitKey = []byte{(nibble + 1) << 4} + } else { + // For the last branch (0xf*), no limit key (traverse to end) + limitKey = nil + } - localSlots, err := traverseStorageBranch(ctx, storageTrie, branchStartKey, branchLimitKey) + localSlots, err := traverseStorageBranch(ctx, storageTrie, startKey, limitKey) if err != nil { return err } @@ -544,7 +562,7 @@ func traverseStorageBranch(ctx context.Context, storageTrie *trie.StateTrie, sta default: } - if bytes.Compare(storageIter.Key, limitKey) >= 0 { + if limitKey != nil && bytes.Compare(storageIter.Key, limitKey) >= 0 { break } slots++ From bd840a1c18eb0852377bfea064c5b950c7002a15 Mon Sep 17 00:00:00 2001 From: jsvisa Date: Thu, 11 Sep 2025 02:34:04 +0000 Subject: [PATCH 6/9] use callbacks Signed-off-by: jsvisa --- cmd/geth/snapshot.go | 919 ++++++++++++++++++++++++------------------- 1 file changed, 510 insertions(+), 409 deletions(-) diff --git a/cmd/geth/snapshot.go b/cmd/geth/snapshot.go index edcb67cc9a..0f4f952b7d 100644 --- a/cmd/geth/snapshot.go +++ b/cmd/geth/snapshot.go @@ -39,9 +39,11 @@ import ( "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/node" "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/triedb" + "github.com/ethereum/go-ethereum/triedb/database" "github.com/urfave/cli/v2" ) @@ -299,47 +301,17 @@ func checkDanglingStorage(ctx *cli.Context) error { return snapshot.CheckDanglingStorage(db) } -// traverseState is a helper function used for pruning verification. -// Basically it just iterates the trie, ensure all nodes and associated -// contract codes are present. func traverseState(ctx *cli.Context) error { - stack, _ := makeConfigNode(ctx) - defer stack.Close() - - chaindb := utils.MakeChainDatabase(ctx, stack, true) - defer chaindb.Close() - - triedb := utils.MakeTrieDatabase(ctx, stack, chaindb, false, true, false) - defer triedb.Close() - - headBlock := rawdb.ReadHeadBlock(chaindb) - if headBlock == nil { - log.Error("Failed to load head block") - return errors.New("no head block") - } - - config, err := parseTraverseArgs(ctx) + ts, err := setupTraversal(ctx) if err != nil { return err } - if config.root == (common.Hash{}) { - config.root = headBlock.Root() - } - - t, err := trie.NewStateTrie(trie.StateTrieID(config.root), triedb) - if err != nil { - log.Error("Failed to open trie", "root", config.root, "err", err) - return err - } + defer ts.Close() var ( - accounts atomic.Uint64 - slots atomic.Uint64 - codes atomic.Uint64 - start = time.Now() + counters = &traverseCounters{start: time.Now()} + cctx, cancel = context.WithCancel(context.Background()) ) - - cctx, cancel := context.WithCancel(context.Background()) defer cancel() go func() { @@ -348,231 +320,170 @@ func traverseState(ctx *cli.Context) error { for { select { case <-timer.C: - log.Info("Traversing state", "accounts", accounts.Load(), "slots", slots.Load(), "codes", codes.Load(), "elapsed", common.PrettyDuration(time.Since(start))) - case <-cctx.Done(): + log.Info("Traversing state", "accounts", counters.accounts.Load(), "slots", counters.slots.Load(), "codes", counters.codes.Load(), "elapsed", common.PrettyDuration(time.Since(counters.start))) + case <-ctx.Done(): return } } }() - if config.isAccount { - log.Info("Start traversing storage trie", "root", config.root.Hex(), "account", config.account.Hex(), "startKey", common.Bytes2Hex(config.startKey), "limitKey", common.Bytes2Hex(config.limitKey)) - - acc, err := t.GetAccountByHash(config.account) - if err != nil { - log.Error("Get account failed", "account", config.account.Hex(), "err", err) - return err - } - - if acc.Root == types.EmptyRootHash { - log.Info("Account has no storage") - return nil - } - - id := trie.StorageTrieID(config.root, config.account, acc.Root) - storageTrie, err := trie.NewStateTrie(id, triedb) - if err != nil { - log.Error("Failed to open storage trie", "root", acc.Root, "err", err) - return err - } - - storageIt, err := storageTrie.NodeIterator(config.startKey) - if err != nil { - log.Error("Failed to open storage iterator", "root", acc.Root, "err", err) - return err - } - - storageIter := trie.NewIterator(storageIt) - for storageIter.Next() { - if config.limitKey != nil && bytes.Compare(storageIter.Key, config.limitKey) >= 0 { - break - } - - slots.Add(1) - log.Debug("Storage slot", "key", common.Bytes2Hex(storageIter.Key), "value", common.Bytes2Hex(storageIter.Value)) - } - if storageIter.Err != nil { - log.Error("Failed to traverse storage trie", "root", acc.Root, "err", storageIter.Err) - return storageIter.Err - } - - log.Info("Storage traversal complete", "slots", slots.Load(), "elapsed", common.PrettyDuration(time.Since(start))) + if ts.config.isAccount { + return ts.traverseAccount(cctx, counters, false) } else { - log.Info("Start traversing state trie", "root", config.root.Hex(), "startKey", common.Bytes2Hex(config.startKey), "limitKey", common.Bytes2Hex(config.limitKey)) - - eg, ctx := errgroup.WithContext(cctx) - - // Parallel processing with boundary checks - for i := 0; i < 16; i++ { - nibble := byte(i) - eg.Go(func() error { - // Skip branches that are entirely before startKey - if len(config.startKey) > 0 { - startNibble := config.startKey[0] >> 4 - if nibble < startNibble { - return nil // Skip this branch - } - } - - // Skip branches that are entirely after limitKey - if len(config.limitKey) > 0 { - limitNibble := config.limitKey[0] >> 4 - if nibble > limitNibble { - return nil // Skip this branch - } - } - - var ( - startKey = []byte{nibble << 4} - limitKey []byte - ) - if nibble < 15 { - limitKey = []byte{(nibble + 1) << 4} - } else { - // Last branch (0xf*) has no limit - limitKey = nil - } - - return traverseBranch(ctx, t, triedb, chaindb, config, startKey, limitKey, &accounts, &slots, &codes) - }) - } - - if err := eg.Wait(); err != nil { - return err - } - - log.Info("State traversal complete", "accounts", accounts.Load(), "slots", slots.Load(), "codes", codes.Load(), "elapsed", common.PrettyDuration(time.Since(start))) + return ts.traverseState(cctx, counters, false) } - return nil } -// traverseBranch traverses a specific branch of the state trie -func traverseBranch(ctx context.Context, t *trie.StateTrie, triedb *triedb.Database, chaindb ethdb.Database, config *traverseConfig, startKey, limitKey []byte, accounts, slots, codes *atomic.Uint64) error { - acctIt, err := t.NodeIterator(startKey) - if err != nil { - return err - } - - accIter := trie.NewIterator(acctIt) - for accIter.Next() { - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - - if config.startKey != nil && bytes.Compare(accIter.Key, config.startKey) < 0 { - continue - } - - if limitKey != nil && bytes.Compare(accIter.Key, limitKey) >= 0 { - break - } - - accounts.Add(1) - - var acc types.StateAccount - if err := rlp.DecodeBytes(accIter.Value, &acc); err != nil { - log.Error("Invalid account encountered during traversal", "err", err) - return err - } - - if acc.Root != types.EmptyRootHash { - id := trie.StorageTrieID(config.root, common.BytesToHash(accIter.Key), acc.Root) - storageTrie, err := trie.NewStateTrie(id, triedb) - if err != nil { - log.Error("Failed to open storage trie", "root", acc.Root, "err", err) - return err - } - - localSlots, err := traverseStorageParallel(ctx, storageTrie) - if err != nil { - log.Error("Failed to traverse storage trie", "root", acc.Root, "err", err) - return err - } - slots.Add(localSlots) - } - - if !bytes.Equal(acc.CodeHash, types.EmptyCodeHash.Bytes()) { - if !rawdb.HasCode(chaindb, common.BytesToHash(acc.CodeHash)) { - log.Error("Code is missing", "hash", common.BytesToHash(acc.CodeHash)) - return errors.New("missing code") - } - codes.Add(1) - } - } - - if accIter.Err != nil { - return accIter.Err - } - - return nil +// StorageCallbacks defines the callbacks for different storage traversal modes +type StorageCallbacks struct { + // Called for each storage node (only for raw mode) + OnStorageNode func(node common.Hash, path []byte) error } -// traverseStorageParallel parallelizes storage trie traversal by dividing work across 16 trie branches -func traverseStorageParallel(ctx context.Context, storageTrie *trie.StateTrie) (uint64, error) { - g, ctx := errgroup.WithContext(ctx) - totalSlots := atomic.Uint64{} +// createSimpleStorageCallbacks creates callbacks for simple storage traversal +func createSimpleStorageCallbacks() *StorageCallbacks { + return &StorageCallbacks{ + OnStorageNode: nil, // Simple mode doesn't process nodes + } +} + +// createRawStorageCallbacks creates callbacks for raw storage traversal with verification +func createRawStorageCallbacks(reader database.NodeReader, accountHash common.Hash) *StorageCallbacks { + return &StorageCallbacks{ + OnStorageNode: func(node common.Hash, path []byte) error { + if node != (common.Hash{}) { + blob, _ := reader.Node(accountHash, path, node) + if len(blob) == 0 { + log.Error("Missing trie node(storage)", "hash", node) + return errors.New("missing storage") + } + if !bytes.Equal(crypto.Keccak256(blob), node.Bytes()) { + log.Error("Invalid trie node(storage)", "hash", node.Hex(), "value", blob) + return errors.New("invalid storage node") + } + } + return nil + }, + } +} + +// traverseStorageParallelWithCallbacks parallelizes storage trie traversal using callbacks +func traverseStorageParallelWithCallbacks(ctx context.Context, storageTrie *trie.StateTrie, startKey, limitKey []byte, callbacks *StorageCallbacks, raw bool) (slots, nodes uint64, err error) { + var ( + eg, cctx = errgroup.WithContext(ctx) + slotsAtomic atomic.Uint64 + nodesAtomic atomic.Uint64 + ) for i := 0; i < 16; i++ { nibble := byte(i) - g.Go(func() error { + eg.Go(func() error { + // Calculate this branch's natural boundaries var ( - startKey = []byte{nibble << 4} - limitKey []byte + branchStart = []byte{nibble << 4} + branchLimit []byte ) if nibble < 15 { - limitKey = []byte{(nibble + 1) << 4} - } else { - // For the last branch (0xf*), no limit key (traverse to end) - limitKey = nil + branchLimit = []byte{(nibble + 1) << 4} } - localSlots, err := traverseStorageBranch(ctx, storageTrie, startKey, limitKey) + // Skip branches that are entirely before startKey + if startKey != nil && branchLimit != nil && bytes.Compare(branchLimit, startKey) <= 0 { + return nil + } + + // Skip branches that are entirely after limitKey + if limitKey != nil && bytes.Compare(branchStart, limitKey) >= 0 { + return nil + } + + // Use the more restrictive start boundary + if startKey != nil && bytes.Compare(branchStart, startKey) < 0 { + branchStart = startKey + } + if limitKey != nil && (branchLimit == nil || bytes.Compare(branchLimit, limitKey) > 0) { + branchLimit = limitKey + } + + // Skip if branch range is empty + if branchLimit != nil && bytes.Compare(branchStart, branchLimit) >= 0 { + return nil + } + + localSlots, localNodes, err := traverseStorageBranchWithCallbacks(cctx, storageTrie, branchStart, branchLimit, callbacks, raw) if err != nil { return err } - totalSlots.Add(localSlots) + slotsAtomic.Add(localSlots) + nodesAtomic.Add(localNodes) return nil }) } - if err := g.Wait(); err != nil { - return 0, err + if err := eg.Wait(); err != nil { + return 0, 0, err } - return totalSlots.Load(), nil + return slotsAtomic.Load(), nodesAtomic.Load(), nil } -// traverseStorageBranch traverses a specific branch of the storage trie -func traverseStorageBranch(ctx context.Context, storageTrie *trie.StateTrie, startKey, limitKey []byte) (uint64, error) { - storageIt, err := storageTrie.NodeIterator(startKey) +// traverseStorageBranchWithCallbacks traverses a specific range of the storage trie using callbacks +func traverseStorageBranchWithCallbacks(ctx context.Context, storageTrie *trie.StateTrie, startKey, limitKey []byte, callbacks *StorageCallbacks, raw bool) (slots, nodes uint64, err error) { + nodeIter, err := storageTrie.NodeIterator(startKey) if err != nil { - return 0, err + return 0, 0, err } - storageIter := trie.NewIterator(storageIt) - slots := uint64(0) + if raw { + // Raw traversal with detailed node checking + for nodeIter.Next(true) { + select { + case <-ctx.Done(): + return 0, 0, ctx.Err() + default: + } - for storageIter.Next() { - // Check if context was cancelled - select { - case <-ctx.Done(): - return 0, ctx.Err() - default: + nodes++ + if callbacks != nil && callbacks.OnStorageNode != nil { + if err := callbacks.OnStorageNode(nodeIter.Hash(), nodeIter.Path()); err != nil { + return 0, 0, err + } + } + + if nodeIter.Leaf() { + if limitKey != nil && bytes.Compare(nodeIter.LeafKey(), limitKey) >= 0 { + break + } + slots++ + } + } + } else { + // Simple traversal - just iterate through leaf nodes + storageIter := trie.NewIterator(nodeIter) + for storageIter.Next() { + select { + case <-ctx.Done(): + return 0, 0, ctx.Err() + default: + } + + if limitKey != nil && bytes.Compare(storageIter.Key, limitKey) >= 0 { + break + } + + slots++ } - if limitKey != nil && bytes.Compare(storageIter.Key, limitKey) >= 0 { - break + if storageIter.Err != nil { + return 0, 0, storageIter.Err } - slots++ } - if storageIter.Err != nil { - return 0, storageIter.Err + if err := nodeIter.Error(); err != nil { + return 0, 0, err } - return slots, nil + return slots, nodes, nil } type traverseConfig struct { @@ -583,6 +494,181 @@ type traverseConfig struct { isAccount bool } +type traverseSetup struct { + stack *node.Node + chaindb ethdb.Database + triedb *triedb.Database + trie *trie.StateTrie + config *traverseConfig +} + +type traverseCounters struct { + accounts atomic.Uint64 + slots atomic.Uint64 + codes atomic.Uint64 + nodes atomic.Uint64 + start time.Time +} + +func setupTraversal(ctx *cli.Context) (*traverseSetup, error) { + stack, _ := makeConfigNode(ctx) + + chaindb := utils.MakeChainDatabase(ctx, stack, true) + triedb := utils.MakeTrieDatabase(ctx, stack, chaindb, false, true, false) + + headBlock := rawdb.ReadHeadBlock(chaindb) + if headBlock == nil { + log.Error("Failed to load head block") + return nil, errors.New("no head block") + } + + config, err := parseTraverseArgs(ctx) + if err != nil { + return nil, err + } + if config.root == (common.Hash{}) { + config.root = headBlock.Root() + } + + t, err := trie.NewStateTrie(trie.StateTrieID(config.root), triedb) + if err != nil { + log.Error("Failed to open trie", "root", config.root, "err", err) + return nil, err + } + + return &traverseSetup{ + stack: stack, + chaindb: chaindb, + triedb: triedb, + trie: t, + config: config, + }, nil +} + +func (ts *traverseSetup) Close() { + ts.triedb.Close() + ts.chaindb.Close() + ts.stack.Close() +} + +func (ts *traverseSetup) traverseAccount(ctx context.Context, counters *traverseCounters, raw bool) error { + log.Info("Start traversing storage trie", "root", ts.config.root.Hex(), "account", ts.config.account.Hex(), "startKey", common.Bytes2Hex(ts.config.startKey), "limitKey", common.Bytes2Hex(ts.config.limitKey)) + + acc, err := ts.trie.GetAccountByHash(ts.config.account) + if err != nil { + log.Error("Get account failed", "account", ts.config.account.Hex(), "err", err) + return err + } + + if acc.Root == types.EmptyRootHash { + log.Info("Account has no storage") + return nil + } + + id := trie.StorageTrieID(ts.config.root, ts.config.account, acc.Root) + storageTrie, err := trie.NewStateTrie(id, ts.triedb) + if err != nil { + log.Error("Failed to open storage trie", "root", acc.Root, "err", err) + return err + } + + var callbacks *StorageCallbacks + if raw { + reader, err := ts.triedb.NodeReader(ts.config.root) + if err != nil { + log.Error("State is non-existent", "root", ts.config.root) + return nil + } + callbacks = createRawStorageCallbacks(reader, ts.config.account) + } else { + callbacks = createSimpleStorageCallbacks() + } + + slots, nodes, err := traverseStorageParallelWithCallbacks(ctx, storageTrie, ts.config.startKey, ts.config.limitKey, callbacks, raw) + if err != nil { + log.Error("Failed to traverse storage trie", "root", acc.Root, "err", err) + return err + } + + counters.slots.Add(slots) + counters.nodes.Add(nodes) + + if raw { + log.Info("Storage traversal complete (raw)", "nodes", counters.nodes.Load(), "slots", counters.slots.Load(), "elapsed", common.PrettyDuration(time.Since(counters.start))) + } else { + log.Info("Storage traversal complete", "slots", counters.slots.Load(), "elapsed", common.PrettyDuration(time.Since(counters.start))) + } + return nil +} + +func (ts *traverseSetup) traverseState(ctx context.Context, counters *traverseCounters, raw bool) error { + log.Info("Start traversing state trie", "root", ts.config.root.Hex(), "startKey", common.Bytes2Hex(ts.config.startKey), "limitKey", common.Bytes2Hex(ts.config.limitKey)) + + eg, ctx := errgroup.WithContext(ctx) + var reader database.NodeReader + if raw { + var err error + reader, err = ts.triedb.NodeReader(ts.config.root) + if err != nil { + log.Error("State is non-existent", "root", ts.config.root) + return nil + } + } + + for i := 0; i < 16; i++ { + nibble := byte(i) + eg.Go(func() error { + var ( + startKey = []byte{nibble << 4} + limitKey []byte + ) + if nibble < 15 { + limitKey = []byte{(nibble + 1) << 4} + } + + if ts.config != nil { + // Skip branches that are entirely before startKey + if limitKey != nil && bytes.Compare(limitKey, ts.config.startKey) <= 0 { + return nil + } + + // Skip branches that are entirely after limitKey + if ts.config.limitKey != nil && bytes.Compare(startKey, ts.config.limitKey) >= 0 { + return nil + } + + if ts.config.startKey != nil && bytes.Compare(startKey, ts.config.startKey) < 0 { + startKey = ts.config.startKey + } + if ts.config.limitKey != nil && (limitKey == nil || bytes.Compare(limitKey, ts.config.limitKey) > 0) { + limitKey = ts.config.limitKey + } + } + + if limitKey != nil && bytes.Compare(startKey, limitKey) >= 0 { + return nil + } + + // Create appropriate callbacks based on traversal mode + var callbacks *TraverseCallbacks + if raw { + callbacks = ts.createRawCallbacks(counters, reader) + } else { + callbacks = ts.createSimpleCallbacks(counters) + } + + return ts.traverseStateBranchWithCallbacks(ctx, startKey, limitKey, callbacks, raw) + }) + } + + if err := eg.Wait(); err != nil { + return err + } + + log.Info("State traversal complete", "accounts", counters.accounts.Load(), "slots", counters.slots.Load(), "codes", counters.codes.Load(), "elapsed", common.PrettyDuration(time.Since(counters.start))) + return nil +} + func parseTraverseArgs(ctx *cli.Context) (*traverseConfig, error) { if ctx.NArg() > 1 { return nil, errors.New("too many arguments, only is required") @@ -629,235 +715,250 @@ func parseTraverseArgs(ctx *cli.Context) (*traverseConfig, error) { return config, nil } -// traverseRawState is a helper function used for pruning verification. -// Basically it just iterates the trie, ensure all nodes and associated -// contract codes are present. It's basically identical to traverseState -// but it will check each trie node. -func traverseRawState(ctx *cli.Context) error { - stack, _ := makeConfigNode(ctx) - defer stack.Close() +// TraverseCallbacks defines the callbacks for different traversal modes +type TraverseCallbacks struct { + // Called for each trie account node (only for raw mode) + OnNode func(node common.Hash, path []byte) error + // Called for each account + OnAccount func(acc types.StateAccount, accountHash common.Hash) error + // Called for each storage slot + OnStorage func(ctx context.Context, storageTrie *trie.StateTrie, accountHash common.Hash) (slots, nodes uint64, err error) + // Called for each code + OnCode func(codeHash []byte, accountHash common.Hash) error +} - chaindb := utils.MakeChainDatabase(ctx, stack, true) - defer chaindb.Close() - - triedb := utils.MakeTrieDatabase(ctx, stack, chaindb, false, true, false) - defer triedb.Close() - - headBlock := rawdb.ReadHeadBlock(chaindb) - if headBlock == nil { - log.Error("Failed to load head block") - return errors.New("no head block") - } - - config, err := parseTraverseArgs(ctx) - if err != nil { - log.Error("Failed to parse arguments", "err", err) - return err - } - if config.root == (common.Hash{}) { - config.root = headBlock.Root() - } - t, err := trie.NewStateTrie(trie.StateTrieID(config.root), triedb) - if err != nil { - log.Error("Failed to open trie", "root", config.root, "err", err) - return err - } - - var ( - accounts int - nodes int - slots int - codes int - start = time.Now() - hasher = crypto.NewKeccakState() - got = make([]byte, 32) - ) - - go func() { - timer := time.NewTicker(time.Second * 8) - defer timer.Stop() - for range timer.C { - log.Info("Traversing rawstate", "nodes", nodes, "accounts", accounts, "slots", slots, "codes", codes, "elapsed", common.PrettyDuration(time.Since(start))) - } - }() - - if config.isAccount { - log.Info("Start traversing storage trie (raw)", "root", config.root.Hex(), "account", config.account.Hex(), "startKey", common.Bytes2Hex(config.startKey), "limitKey", common.Bytes2Hex(config.limitKey)) - - acc, err := t.GetAccountByHash(config.account) - if err != nil { - log.Error("Get account failed", "account", config.account.Hex(), "err", err) - return err - } - - if acc.Root == types.EmptyRootHash { - log.Info("Account has no storage") +// createSimpleCallbacks creates callbacks for simple traversal mode +func (ts *traverseSetup) createSimpleCallbacks(counters *traverseCounters) *TraverseCallbacks { + return &TraverseCallbacks{ + OnAccount: func(acc types.StateAccount, accountHash common.Hash) error { + counters.accounts.Add(1) return nil - } - - // Traverse the storage trie with detailed verification - id := trie.StorageTrieID(config.root, config.account, acc.Root) - storageTrie, err := trie.NewStateTrie(id, triedb) - if err != nil { - log.Error("Failed to open storage trie", "root", acc.Root, "err", err) - return err - } - - storageIter, err := storageTrie.NodeIterator(config.startKey) - if err != nil { - log.Error("Failed to open storage iterator", "root", acc.Root, "err", err) - return err - } - - reader, err := triedb.NodeReader(config.root) - if err != nil { - log.Error("State is non-existent", "root", config.root) - return nil - } - - for storageIter.Next(true) { - nodes += 1 - node := storageIter.Hash() - - // Check the presence for non-empty hash node(embedded node doesn't - // have their own hash). - if node != (common.Hash{}) { - blob, _ := reader.Node(config.account, storageIter.Path(), node) - if len(blob) == 0 { - log.Error("Missing trie node(storage)", "hash", node) - return errors.New("missing storage") - } - hasher.Reset() - hasher.Write(blob) - hasher.Read(got) - if !bytes.Equal(got, node.Bytes()) { - log.Error("Invalid trie node(storage)", "hash", node.Hex(), "value", blob) - return errors.New("invalid storage node") - } + }, + OnStorage: func(ctx context.Context, storageTrie *trie.StateTrie, accountHash common.Hash) (slots, nodes uint64, err error) { + callbacks := createSimpleStorageCallbacks() + s, n, err := traverseStorageParallelWithCallbacks(ctx, storageTrie, nil, nil, callbacks, false) + if err != nil { + return 0, 0, err } - - // Bump the counter if it's leaf node. - if storageIter.Leaf() { - // Check if we've exceeded the limit key for storage - if config.limitKey != nil && bytes.Compare(storageIter.LeafKey(), config.limitKey) >= 0 { - break - } - - slots += 1 - log.Debug("Storage slot", "key", common.Bytes2Hex(storageIter.LeafKey()), "value", common.Bytes2Hex(storageIter.LeafBlob())) + counters.slots.Add(s) + return s, n, nil + }, + OnCode: func(codeHash []byte, accountHash common.Hash) error { + if !rawdb.HasCode(ts.chaindb, common.BytesToHash(codeHash)) { + log.Error("Code is missing", "hash", common.BytesToHash(codeHash)) + return errors.New("missing code") } - } - if storageIter.Error() != nil { - log.Error("Failed to traverse storage trie", "root", acc.Root, "err", storageIter.Error()) - return storageIter.Error() - } - - log.Info("Storage traversal complete (raw)", "nodes", nodes, "slots", slots, "elapsed", common.PrettyDuration(time.Since(start))) - return nil - } else { - log.Info("Start traversing the state trie (raw)", "root", config.root.Hex(), "startKey", common.Bytes2Hex(config.startKey), "limitKey", common.Bytes2Hex(config.limitKey)) - - accIter, err := t.NodeIterator(config.startKey) - if err != nil { - log.Error("Failed to open iterator", "root", config.root, "err", err) - return err - } - reader, err := triedb.NodeReader(config.root) - if err != nil { - log.Error("State is non-existent", "root", config.root) + counters.codes.Add(1) return nil - } - for accIter.Next(true) { - nodes += 1 - node := accIter.Hash() + }, + } +} - // Check the present for non-empty hash node(embedded node doesn't - // have their own hash). +// createRawCallbacks creates callbacks for raw traversal mode with detailed verification +func (ts *traverseSetup) createRawCallbacks(counters *traverseCounters, reader database.NodeReader) *TraverseCallbacks { + return &TraverseCallbacks{ + OnNode: func(node common.Hash, path []byte) error { + counters.nodes.Add(1) if node != (common.Hash{}) { - blob, _ := reader.Node(common.Hash{}, accIter.Path(), node) + blob, _ := reader.Node(common.Hash{}, path, node) if len(blob) == 0 { log.Error("Missing trie node(account)", "hash", node) return errors.New("missing account") } - hasher.Reset() - hasher.Write(blob) - hasher.Read(got) - if !bytes.Equal(got, node.Bytes()) { + if !bytes.Equal(crypto.Keccak256(blob), node.Bytes()) { log.Error("Invalid trie node(account)", "hash", node.Hex(), "value", blob) return errors.New("invalid account node") } } - // If it's a leaf node, yes we are touching an account, - // dig into the storage trie further. + return nil + }, + OnAccount: func(acc types.StateAccount, accountHash common.Hash) error { + counters.accounts.Add(1) + return nil + }, + OnStorage: func(ctx context.Context, storageTrie *trie.StateTrie, accountHash common.Hash) (slots, nodes uint64, err error) { + callbacks := createRawStorageCallbacks(reader, accountHash) + s, n, err := traverseStorageParallelWithCallbacks(ctx, storageTrie, nil, nil, callbacks, true) + if err != nil { + return 0, 0, err + } + counters.slots.Add(s) + counters.nodes.Add(n) + return s, n, nil + }, + OnCode: func(codeHash []byte, accountHash common.Hash) error { + if !rawdb.HasCode(ts.chaindb, common.BytesToHash(codeHash)) { + log.Error("Code is missing", "account", accountHash) + return errors.New("missing code") + } + counters.codes.Add(1) + return nil + }, + } +} + +// traverseStateBranchWithCallbacks provides common branch traversal logic using callbacks +func (ts *traverseSetup) traverseStateBranchWithCallbacks(ctx context.Context, startKey, limitKey []byte, callbacks *TraverseCallbacks, raw bool) error { + accIter, err := ts.trie.NodeIterator(startKey) + if err != nil { + return err + } + + if raw { + // Raw traversal with detailed node checking + for accIter.Next(true) { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + if callbacks.OnNode != nil { + if err := callbacks.OnNode(accIter.Hash(), accIter.Path()); err != nil { + return err + } + } + + // If it's a leaf node, process the account if accIter.Leaf() { - // Check if we've exceeded the limit key for accounts - if config.limitKey != nil && bytes.Compare(accIter.LeafKey(), config.limitKey) >= 0 { + if limitKey != nil && bytes.Compare(accIter.LeafKey(), limitKey) >= 0 { break } - accounts += 1 var acc types.StateAccount if err := rlp.DecodeBytes(accIter.LeafBlob(), &acc); err != nil { log.Error("Invalid account encountered during traversal", "err", err) return errors.New("invalid account") } + + accountHash := common.BytesToHash(accIter.LeafKey()) + + if err := callbacks.OnAccount(acc, accountHash); err != nil { + return err + } + + // Process storage if present if acc.Root != types.EmptyRootHash { - id := trie.StorageTrieID(config.root, common.BytesToHash(accIter.LeafKey()), acc.Root) - storageTrie, err := trie.NewStateTrie(id, triedb) + id := trie.StorageTrieID(ts.config.root, accountHash, acc.Root) + storageTrie, err := trie.NewStateTrie(id, ts.triedb) if err != nil { log.Error("Failed to open storage trie", "root", acc.Root, "err", err) - return errors.New("missing storage trie") - } - storageIter, err := storageTrie.NodeIterator(nil) - if err != nil { - log.Error("Failed to open storage iterator", "root", acc.Root, "err", err) return err } - for storageIter.Next(true) { - nodes += 1 - node := storageIter.Hash() - // Check the presence for non-empty hash node(embedded node doesn't - // have their own hash). - if node != (common.Hash{}) { - blob, _ := reader.Node(common.BytesToHash(accIter.LeafKey()), storageIter.Path(), node) - if len(blob) == 0 { - log.Error("Missing trie node(storage)", "hash", node) - return errors.New("missing storage") - } - hasher.Reset() - hasher.Write(blob) - hasher.Read(got) - if !bytes.Equal(got, node.Bytes()) { - log.Error("Invalid trie node(storage)", "hash", node.Hex(), "value", blob) - return errors.New("invalid storage node") - } - } - // Bump the counter if it's leaf node. - if storageIter.Leaf() { - slots += 1 - } - } - if storageIter.Error() != nil { - log.Error("Failed to traverse storage trie", "root", acc.Root, "err", storageIter.Error()) - return storageIter.Error() + if _, _, err := callbacks.OnStorage(ctx, storageTrie, accountHash); err != nil { + return err } } + if !bytes.Equal(acc.CodeHash, types.EmptyCodeHash.Bytes()) { - if !rawdb.HasCode(chaindb, common.BytesToHash(acc.CodeHash)) { - log.Error("Code is missing", "account", common.BytesToHash(accIter.LeafKey())) - return errors.New("missing code") + if err := callbacks.OnCode(acc.CodeHash, accountHash); err != nil { + return err } - codes += 1 } } } - if accIter.Error() != nil { - log.Error("Failed to traverse state trie", "root", config.root, "err", accIter.Error()) - return accIter.Error() + } else { + // Simple traversal - just iterate through leaf nodes + acctIt, err := ts.trie.NodeIterator(startKey) + if err != nil { + return err } - log.Info("State traversal complete (raw)", "nodes", nodes, "accounts", accounts, "slots", slots, "codes", codes, "elapsed", common.PrettyDuration(time.Since(start))) - return nil + + accIter := trie.NewIterator(acctIt) + for accIter.Next() { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + // Check if we've reached the limit for this branch + if limitKey != nil && bytes.Compare(accIter.Key, limitKey) >= 0 { + break + } + + var acc types.StateAccount + if err := rlp.DecodeBytes(accIter.Value, &acc); err != nil { + log.Error("Invalid account encountered during traversal", "err", err) + return err + } + + accountHash := common.BytesToHash(accIter.Key) + + // Process account + if err := callbacks.OnAccount(acc, accountHash); err != nil { + return err + } + + // Process storage if present + if acc.Root != types.EmptyRootHash { + id := trie.StorageTrieID(ts.config.root, accountHash, acc.Root) + storageTrie, err := trie.NewStateTrie(id, ts.triedb) + if err != nil { + log.Error("Failed to open storage trie", "root", acc.Root, "err", err) + return err + } + + if _, _, err := callbacks.OnStorage(ctx, storageTrie, accountHash); err != nil { + return err + } + } + + // Process code if present + if !bytes.Equal(acc.CodeHash, types.EmptyCodeHash.Bytes()) { + if err := callbacks.OnCode(acc.CodeHash, accountHash); err != nil { + return err + } + } + } + + if accIter.Err != nil { + return accIter.Err + } + } + + if accIter.Error() != nil { + return accIter.Error() + } + + return nil +} + +// traverseRawState is a helper function used for pruning verification. +// Basically it just iterates the trie, ensure all nodes and associated +// contract codes are present. It's basically identical to traverseState +// but it will check each trie node. +func traverseRawState(ctx *cli.Context) error { + ts, err := setupTraversal(ctx) + if err != nil { + return err + } + defer ts.Close() + + var ( + counters = &traverseCounters{start: time.Now()} + cctx, cancel = context.WithCancel(context.Background()) + ) + defer cancel() + + go func() { + timer := time.NewTicker(time.Second * 8) + defer timer.Stop() + for { + select { + case <-timer.C: + log.Info("Traversing rawstate", "nodes", counters.nodes.Load(), "accounts", counters.accounts.Load(), "slots", counters.slots.Load(), "codes", counters.codes.Load(), "elapsed", common.PrettyDuration(time.Since(counters.start))) + case <-cctx.Done(): + return + } + } + }() + + if ts.config.isAccount { + return ts.traverseAccount(cctx, counters, true) + } else { + return ts.traverseState(cctx, counters, true) } } From 20a3a7ded2439929993bbeaa4e5cfdde97b0eb56 Mon Sep 17 00:00:00 2001 From: jsvisa Date: Thu, 11 Sep 2025 03:28:16 +0000 Subject: [PATCH 7/9] refine Signed-off-by: jsvisa --- cmd/geth/snapshot.go | 189 +++++++++++++++++-------------------------- 1 file changed, 75 insertions(+), 114 deletions(-) diff --git a/cmd/geth/snapshot.go b/cmd/geth/snapshot.go index 0f4f952b7d..a0222dda22 100644 --- a/cmd/geth/snapshot.go +++ b/cmd/geth/snapshot.go @@ -334,45 +334,32 @@ func traverseState(ctx *cli.Context) error { } } -// StorageCallbacks defines the callbacks for different storage traversal modes -type StorageCallbacks struct { - // Called for each storage node (only for raw mode) - OnStorageNode func(node common.Hash, path []byte) error -} +type OnStorageNodeHook func(node common.Hash, path []byte) error -// createSimpleStorageCallbacks creates callbacks for simple storage traversal -func createSimpleStorageCallbacks() *StorageCallbacks { - return &StorageCallbacks{ - OnStorageNode: nil, // Simple mode doesn't process nodes - } -} - -// createRawStorageCallbacks creates callbacks for raw storage traversal with verification -func createRawStorageCallbacks(reader database.NodeReader, accountHash common.Hash) *StorageCallbacks { - return &StorageCallbacks{ - OnStorageNode: func(node common.Hash, path []byte) error { - if node != (common.Hash{}) { - blob, _ := reader.Node(accountHash, path, node) - if len(blob) == 0 { - log.Error("Missing trie node(storage)", "hash", node) - return errors.New("missing storage") - } - if !bytes.Equal(crypto.Keccak256(blob), node.Bytes()) { - log.Error("Invalid trie node(storage)", "hash", node.Hex(), "value", blob) - return errors.New("invalid storage node") - } +// createRawStorageHook creates hooks for raw storage traversal with verification +func createRawStorageHook(reader database.NodeReader, accountHash common.Hash) OnStorageNodeHook { + return func(node common.Hash, path []byte) error { + if node != (common.Hash{}) { + blob, _ := reader.Node(accountHash, path, node) + if len(blob) == 0 { + log.Error("Missing trie node(storage)", "hash", node) + return errors.New("missing storage") } - return nil - }, + if !bytes.Equal(crypto.Keccak256(blob), node.Bytes()) { + log.Error("Invalid trie node(storage)", "hash", node.Hex(), "value", blob) + return errors.New("invalid storage node") + } + } + return nil } } -// traverseStorageParallelWithCallbacks parallelizes storage trie traversal using callbacks -func traverseStorageParallelWithCallbacks(ctx context.Context, storageTrie *trie.StateTrie, startKey, limitKey []byte, callbacks *StorageCallbacks, raw bool) (slots, nodes uint64, err error) { +// traverseStorage parallelizes storage trie traversal +func traverseStorage(ctx context.Context, storageTrie *trie.StateTrie, startKey, limitKey []byte, raw bool, hook OnStorageNodeHook) (uint64, uint64, error) { var ( - eg, cctx = errgroup.WithContext(ctx) - slotsAtomic atomic.Uint64 - nodesAtomic atomic.Uint64 + eg, cctx = errgroup.WithContext(ctx) + slots atomic.Uint64 + nodes atomic.Uint64 ) for i := 0; i < 16; i++ { @@ -410,12 +397,12 @@ func traverseStorageParallelWithCallbacks(ctx context.Context, storageTrie *trie return nil } - localSlots, localNodes, err := traverseStorageBranchWithCallbacks(cctx, storageTrie, branchStart, branchLimit, callbacks, raw) + localSlots, localNodes, err := traverseStorageBranchWithHooks(cctx, storageTrie, branchStart, branchLimit, raw, hook) if err != nil { return err } - slotsAtomic.Add(localSlots) - nodesAtomic.Add(localNodes) + slots.Add(localSlots) + nodes.Add(localNodes) return nil }) } @@ -424,11 +411,11 @@ func traverseStorageParallelWithCallbacks(ctx context.Context, storageTrie *trie return 0, 0, err } - return slotsAtomic.Load(), nodesAtomic.Load(), nil + return slots.Load(), nodes.Load(), nil } -// traverseStorageBranchWithCallbacks traverses a specific range of the storage trie using callbacks -func traverseStorageBranchWithCallbacks(ctx context.Context, storageTrie *trie.StateTrie, startKey, limitKey []byte, callbacks *StorageCallbacks, raw bool) (slots, nodes uint64, err error) { +// traverseStorageBranchWithHooks traverses a specific range of the storage trie using hooks +func traverseStorageBranchWithHooks(ctx context.Context, storageTrie *trie.StateTrie, startKey, limitKey []byte, raw bool, hook OnStorageNodeHook) (slots, nodes uint64, err error) { nodeIter, err := storageTrie.NodeIterator(startKey) if err != nil { return 0, 0, err @@ -444,8 +431,8 @@ func traverseStorageBranchWithCallbacks(ctx context.Context, storageTrie *trie.S } nodes++ - if callbacks != nil && callbacks.OnStorageNode != nil { - if err := callbacks.OnStorageNode(nodeIter.Hash(), nodeIter.Path()); err != nil { + if hook != nil { + if err := hook(nodeIter.Hash(), nodeIter.Path()); err != nil { return 0, 0, err } } @@ -470,7 +457,6 @@ func traverseStorageBranchWithCallbacks(ctx context.Context, storageTrie *trie.S if limitKey != nil && bytes.Compare(storageIter.Key, limitKey) >= 0 { break } - slots++ } @@ -572,19 +558,17 @@ func (ts *traverseSetup) traverseAccount(ctx context.Context, counters *traverse return err } - var callbacks *StorageCallbacks + var hook OnStorageNodeHook if raw { reader, err := ts.triedb.NodeReader(ts.config.root) if err != nil { log.Error("State is non-existent", "root", ts.config.root) return nil } - callbacks = createRawStorageCallbacks(reader, ts.config.account) - } else { - callbacks = createSimpleStorageCallbacks() + hook = createRawStorageHook(reader, ts.config.account) } - slots, nodes, err := traverseStorageParallelWithCallbacks(ctx, storageTrie, ts.config.startKey, ts.config.limitKey, callbacks, raw) + slots, nodes, err := traverseStorage(ctx, storageTrie, ts.config.startKey, ts.config.limitKey, raw, hook) if err != nil { log.Error("Failed to traverse storage trie", "root", acc.Root, "err", err) return err @@ -649,15 +633,14 @@ func (ts *traverseSetup) traverseState(ctx context.Context, counters *traverseCo return nil } - // Create appropriate callbacks based on traversal mode - var callbacks *TraverseCallbacks + var hooks *TraverseHooks if raw { - callbacks = ts.createRawCallbacks(counters, reader) + hooks = ts.createRawHooks(reader) } else { - callbacks = ts.createSimpleCallbacks(counters) + hooks = ts.createSimpleHooks() } - return ts.traverseStateBranchWithCallbacks(ctx, startKey, limitKey, callbacks, raw) + return ts.traverseStateBranchWithHooks(ctx, startKey, limitKey, raw, counters, hooks) }) } @@ -715,50 +698,36 @@ func parseTraverseArgs(ctx *cli.Context) (*traverseConfig, error) { return config, nil } -// TraverseCallbacks defines the callbacks for different traversal modes -type TraverseCallbacks struct { +// TraverseHooks defines the hooks for different traversal modes +type TraverseHooks struct { // Called for each trie account node (only for raw mode) - OnNode func(node common.Hash, path []byte) error - // Called for each account - OnAccount func(acc types.StateAccount, accountHash common.Hash) error - // Called for each storage slot - OnStorage func(ctx context.Context, storageTrie *trie.StateTrie, accountHash common.Hash) (slots, nodes uint64, err error) + OnAccountNode func(node common.Hash, path []byte) error + // Called for each storage trie + OnStorageTrie func(ctx context.Context, storageTrie *trie.StateTrie, accountHash common.Hash) (slots, nodes uint64, err error) // Called for each code OnCode func(codeHash []byte, accountHash common.Hash) error } -// createSimpleCallbacks creates callbacks for simple traversal mode -func (ts *traverseSetup) createSimpleCallbacks(counters *traverseCounters) *TraverseCallbacks { - return &TraverseCallbacks{ - OnAccount: func(acc types.StateAccount, accountHash common.Hash) error { - counters.accounts.Add(1) - return nil - }, - OnStorage: func(ctx context.Context, storageTrie *trie.StateTrie, accountHash common.Hash) (slots, nodes uint64, err error) { - callbacks := createSimpleStorageCallbacks() - s, n, err := traverseStorageParallelWithCallbacks(ctx, storageTrie, nil, nil, callbacks, false) - if err != nil { - return 0, 0, err - } - counters.slots.Add(s) - return s, n, nil +// createSimpleHooks creates hooks for simple traversal mode +func (ts *traverseSetup) createSimpleHooks() *TraverseHooks { + return &TraverseHooks{ + OnStorageTrie: func(ctx context.Context, storageTrie *trie.StateTrie, accountHash common.Hash) (slots, nodes uint64, err error) { + return traverseStorage(ctx, storageTrie, nil, nil, false, nil) }, OnCode: func(codeHash []byte, accountHash common.Hash) error { if !rawdb.HasCode(ts.chaindb, common.BytesToHash(codeHash)) { log.Error("Code is missing", "hash", common.BytesToHash(codeHash)) return errors.New("missing code") } - counters.codes.Add(1) return nil }, } } -// createRawCallbacks creates callbacks for raw traversal mode with detailed verification -func (ts *traverseSetup) createRawCallbacks(counters *traverseCounters, reader database.NodeReader) *TraverseCallbacks { - return &TraverseCallbacks{ - OnNode: func(node common.Hash, path []byte) error { - counters.nodes.Add(1) +// createRawHooks creates hooks for raw traversal mode with detailed verification +func (ts *traverseSetup) createRawHooks(reader database.NodeReader) *TraverseHooks { + return &TraverseHooks{ + OnAccountNode: func(node common.Hash, path []byte) error { if node != (common.Hash{}) { blob, _ := reader.Node(common.Hash{}, path, node) if len(blob) == 0 { @@ -772,33 +741,22 @@ func (ts *traverseSetup) createRawCallbacks(counters *traverseCounters, reader d } return nil }, - OnAccount: func(acc types.StateAccount, accountHash common.Hash) error { - counters.accounts.Add(1) - return nil - }, - OnStorage: func(ctx context.Context, storageTrie *trie.StateTrie, accountHash common.Hash) (slots, nodes uint64, err error) { - callbacks := createRawStorageCallbacks(reader, accountHash) - s, n, err := traverseStorageParallelWithCallbacks(ctx, storageTrie, nil, nil, callbacks, true) - if err != nil { - return 0, 0, err - } - counters.slots.Add(s) - counters.nodes.Add(n) - return s, n, nil + OnStorageTrie: func(ctx context.Context, storageTrie *trie.StateTrie, accountHash common.Hash) (slots, nodes uint64, err error) { + hook := createRawStorageHook(reader, accountHash) + return traverseStorage(ctx, storageTrie, nil, nil, true, hook) }, OnCode: func(codeHash []byte, accountHash common.Hash) error { if !rawdb.HasCode(ts.chaindb, common.BytesToHash(codeHash)) { log.Error("Code is missing", "account", accountHash) return errors.New("missing code") } - counters.codes.Add(1) return nil }, } } -// traverseStateBranchWithCallbacks provides common branch traversal logic using callbacks -func (ts *traverseSetup) traverseStateBranchWithCallbacks(ctx context.Context, startKey, limitKey []byte, callbacks *TraverseCallbacks, raw bool) error { +// traverseStateBranchWithHooks provides common branch traversal logic using hooks +func (ts *traverseSetup) traverseStateBranchWithHooks(ctx context.Context, startKey, limitKey []byte, raw bool, counters *traverseCounters, hooks *TraverseHooks) error { accIter, err := ts.trie.NodeIterator(startKey) if err != nil { return err @@ -813,8 +771,9 @@ func (ts *traverseSetup) traverseStateBranchWithCallbacks(ctx context.Context, s default: } - if callbacks.OnNode != nil { - if err := callbacks.OnNode(accIter.Hash(), accIter.Path()); err != nil { + counters.nodes.Add(1) + if hooks != nil && hooks.OnAccountNode != nil { + if err := hooks.OnAccountNode(accIter.Hash(), accIter.Path()); err != nil { return err } } @@ -824,6 +783,7 @@ func (ts *traverseSetup) traverseStateBranchWithCallbacks(ctx context.Context, s if limitKey != nil && bytes.Compare(accIter.LeafKey(), limitKey) >= 0 { break } + counters.accounts.Add(1) var acc types.StateAccount if err := rlp.DecodeBytes(accIter.LeafBlob(), &acc); err != nil { @@ -833,11 +793,6 @@ func (ts *traverseSetup) traverseStateBranchWithCallbacks(ctx context.Context, s accountHash := common.BytesToHash(accIter.LeafKey()) - if err := callbacks.OnAccount(acc, accountHash); err != nil { - return err - } - - // Process storage if present if acc.Root != types.EmptyRootHash { id := trie.StorageTrieID(ts.config.root, accountHash, acc.Root) storageTrie, err := trie.NewStateTrie(id, ts.triedb) @@ -846,15 +801,21 @@ func (ts *traverseSetup) traverseStateBranchWithCallbacks(ctx context.Context, s return err } - if _, _, err := callbacks.OnStorage(ctx, storageTrie, accountHash); err != nil { - return err + if hooks != nil && hooks.OnStorageTrie != nil { + slots, nodes, err := hooks.OnStorageTrie(ctx, storageTrie, accountHash) + if err != nil { + return err + } + counters.slots.Add(slots) + counters.nodes.Add(nodes) } } if !bytes.Equal(acc.CodeHash, types.EmptyCodeHash.Bytes()) { - if err := callbacks.OnCode(acc.CodeHash, accountHash); err != nil { + if err := hooks.OnCode(acc.CodeHash, accountHash); err != nil { return err } + counters.codes.Add(1) } } } @@ -873,11 +834,12 @@ func (ts *traverseSetup) traverseStateBranchWithCallbacks(ctx context.Context, s default: } - // Check if we've reached the limit for this branch if limitKey != nil && bytes.Compare(accIter.Key, limitKey) >= 0 { break } + counters.accounts.Add(1) + var acc types.StateAccount if err := rlp.DecodeBytes(accIter.Value, &acc); err != nil { log.Error("Invalid account encountered during traversal", "err", err) @@ -886,11 +848,6 @@ func (ts *traverseSetup) traverseStateBranchWithCallbacks(ctx context.Context, s accountHash := common.BytesToHash(accIter.Key) - // Process account - if err := callbacks.OnAccount(acc, accountHash); err != nil { - return err - } - // Process storage if present if acc.Root != types.EmptyRootHash { id := trie.StorageTrieID(ts.config.root, accountHash, acc.Root) @@ -900,16 +857,20 @@ func (ts *traverseSetup) traverseStateBranchWithCallbacks(ctx context.Context, s return err } - if _, _, err := callbacks.OnStorage(ctx, storageTrie, accountHash); err != nil { + slots, nodes, err := hooks.OnStorageTrie(ctx, storageTrie, accountHash) + if err != nil { return err } + counters.slots.Add(slots) + counters.nodes.Add(nodes) } // Process code if present if !bytes.Equal(acc.CodeHash, types.EmptyCodeHash.Bytes()) { - if err := callbacks.OnCode(acc.CodeHash, accountHash); err != nil { + if err := hooks.OnCode(acc.CodeHash, accountHash); err != nil { return err } + counters.codes.Add(1) } } From 43f74243e55e17159510a9bb2296d9c4b36db871 Mon Sep 17 00:00:00 2001 From: jsvisa Date: Thu, 11 Sep 2025 04:10:33 +0000 Subject: [PATCH 8/9] check config.start key Signed-off-by: jsvisa --- cmd/geth/snapshot.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/geth/snapshot.go b/cmd/geth/snapshot.go index a0222dda22..53151f215e 100644 --- a/cmd/geth/snapshot.go +++ b/cmd/geth/snapshot.go @@ -612,7 +612,7 @@ func (ts *traverseSetup) traverseState(ctx context.Context, counters *traverseCo if ts.config != nil { // Skip branches that are entirely before startKey - if limitKey != nil && bytes.Compare(limitKey, ts.config.startKey) <= 0 { + if ts.config.startKey != nil && limitKey != nil && bytes.Compare(limitKey, ts.config.startKey) <= 0 { return nil } From d477a72666558c55c59195a7d9745f4f732e819b Mon Sep 17 00:00:00 2001 From: jsvisa Date: Thu, 11 Sep 2025 04:34:17 +0000 Subject: [PATCH 9/9] simplify Signed-off-by: jsvisa --- cmd/geth/snapshot.go | 173 +++++++++++++++++-------------------------- 1 file changed, 67 insertions(+), 106 deletions(-) diff --git a/cmd/geth/snapshot.go b/cmd/geth/snapshot.go index 53151f215e..d162535d0a 100644 --- a/cmd/geth/snapshot.go +++ b/cmd/geth/snapshot.go @@ -336,24 +336,6 @@ func traverseState(ctx *cli.Context) error { type OnStorageNodeHook func(node common.Hash, path []byte) error -// createRawStorageHook creates hooks for raw storage traversal with verification -func createRawStorageHook(reader database.NodeReader, accountHash common.Hash) OnStorageNodeHook { - return func(node common.Hash, path []byte) error { - if node != (common.Hash{}) { - blob, _ := reader.Node(accountHash, path, node) - if len(blob) == 0 { - log.Error("Missing trie node(storage)", "hash", node) - return errors.New("missing storage") - } - if !bytes.Equal(crypto.Keccak256(blob), node.Bytes()) { - log.Error("Invalid trie node(storage)", "hash", node.Hex(), "value", blob) - return errors.New("invalid storage node") - } - } - return nil - } -} - // traverseStorage parallelizes storage trie traversal func traverseStorage(ctx context.Context, storageTrie *trie.StateTrie, startKey, limitKey []byte, raw bool, hook OnStorageNodeHook) (uint64, uint64, error) { var ( @@ -397,7 +379,7 @@ func traverseStorage(ctx context.Context, storageTrie *trie.StateTrie, startKey, return nil } - localSlots, localNodes, err := traverseStorageBranchWithHooks(cctx, storageTrie, branchStart, branchLimit, raw, hook) + localSlots, localNodes, err := traverseStorageBranch(cctx, storageTrie, branchStart, branchLimit, raw, hook) if err != nil { return err } @@ -414,8 +396,8 @@ func traverseStorage(ctx context.Context, storageTrie *trie.StateTrie, startKey, return slots.Load(), nodes.Load(), nil } -// traverseStorageBranchWithHooks traverses a specific range of the storage trie using hooks -func traverseStorageBranchWithHooks(ctx context.Context, storageTrie *trie.StateTrie, startKey, limitKey []byte, raw bool, hook OnStorageNodeHook) (slots, nodes uint64, err error) { +// traverseStorageBranch traverses a specific range of the storage trie using hooks +func traverseStorageBranch(ctx context.Context, storageTrie *trie.StateTrie, startKey, limitKey []byte, raw bool, hook OnStorageNodeHook) (slots, nodes uint64, err error) { nodeIter, err := storageTrie.NodeIterator(startKey) if err != nil { return 0, 0, err @@ -565,7 +547,11 @@ func (ts *traverseSetup) traverseAccount(ctx context.Context, counters *traverse log.Error("State is non-existent", "root", ts.config.root) return nil } - hook = createRawStorageHook(reader, ts.config.account) + // Create storage hook inline - same pattern as in TraverseHooks + accountHash := ts.config.account + hook = func(node common.Hash, path []byte) error { + return verifyTrieNode(reader, accountHash, node, path) + } } slots, nodes, err := traverseStorage(ctx, storageTrie, ts.config.startKey, ts.config.limitKey, raw, hook) @@ -640,7 +626,7 @@ func (ts *traverseSetup) traverseState(ctx context.Context, counters *traverseCo hooks = ts.createSimpleHooks() } - return ts.traverseStateBranchWithHooks(ctx, startKey, limitKey, raw, counters, hooks) + return ts.traverseStateBranch(ctx, startKey, limitKey, raw, counters, hooks) }) } @@ -702,10 +688,8 @@ func parseTraverseArgs(ctx *cli.Context) (*traverseConfig, error) { type TraverseHooks struct { // Called for each trie account node (only for raw mode) OnAccountNode func(node common.Hash, path []byte) error - // Called for each storage trie + // Called for each storage trie - now self-contained with all needed context OnStorageTrie func(ctx context.Context, storageTrie *trie.StateTrie, accountHash common.Hash) (slots, nodes uint64, err error) - // Called for each code - OnCode func(codeHash []byte, accountHash common.Hash) error } // createSimpleHooks creates hooks for simple traversal mode @@ -714,54 +698,78 @@ func (ts *traverseSetup) createSimpleHooks() *TraverseHooks { OnStorageTrie: func(ctx context.Context, storageTrie *trie.StateTrie, accountHash common.Hash) (slots, nodes uint64, err error) { return traverseStorage(ctx, storageTrie, nil, nil, false, nil) }, - OnCode: func(codeHash []byte, accountHash common.Hash) error { - if !rawdb.HasCode(ts.chaindb, common.BytesToHash(codeHash)) { - log.Error("Code is missing", "hash", common.BytesToHash(codeHash)) - return errors.New("missing code") - } - return nil - }, } } +func verifyTrieNode(reader database.NodeReader, accountHash common.Hash, node common.Hash, path []byte) error { + if node != (common.Hash{}) { + blob, _ := reader.Node(accountHash, path, node) + if len(blob) == 0 { + log.Error("Missing trie node", "hash", node) + return errors.New("missing node") + } + if !bytes.Equal(crypto.Keccak256(blob), node.Bytes()) { + log.Error("Invalid trie node", "hash", node.Hex(), "value", blob) + return errors.New("invalid node") + } + } + return nil +} + // createRawHooks creates hooks for raw traversal mode with detailed verification func (ts *traverseSetup) createRawHooks(reader database.NodeReader) *TraverseHooks { return &TraverseHooks{ OnAccountNode: func(node common.Hash, path []byte) error { - if node != (common.Hash{}) { - blob, _ := reader.Node(common.Hash{}, path, node) - if len(blob) == 0 { - log.Error("Missing trie node(account)", "hash", node) - return errors.New("missing account") - } - if !bytes.Equal(crypto.Keccak256(blob), node.Bytes()) { - log.Error("Invalid trie node(account)", "hash", node.Hex(), "value", blob) - return errors.New("invalid account node") - } - } - return nil + return verifyTrieNode(reader, common.Hash{}, node, path) }, + OnStorageTrie: func(ctx context.Context, storageTrie *trie.StateTrie, accountHash common.Hash) (slots, nodes uint64, err error) { - hook := createRawStorageHook(reader, accountHash) - return traverseStorage(ctx, storageTrie, nil, nil, true, hook) - }, - OnCode: func(codeHash []byte, accountHash common.Hash) error { - if !rawdb.HasCode(ts.chaindb, common.BytesToHash(codeHash)) { - log.Error("Code is missing", "account", accountHash) - return errors.New("missing code") + hook := func(node common.Hash, path []byte) error { + return verifyTrieNode(reader, accountHash, node, path) } - return nil + return traverseStorage(ctx, storageTrie, nil, nil, true, hook) }, } } -// traverseStateBranchWithHooks provides common branch traversal logic using hooks -func (ts *traverseSetup) traverseStateBranchWithHooks(ctx context.Context, startKey, limitKey []byte, raw bool, counters *traverseCounters, hooks *TraverseHooks) error { +// traverseStateBranch provides common branch traversal logic using hooks +func (ts *traverseSetup) traverseStateBranch(ctx context.Context, startKey, limitKey []byte, raw bool, counters *traverseCounters, hooks *TraverseHooks) error { accIter, err := ts.trie.NodeIterator(startKey) if err != nil { return err } + processAccount := func(accKey []byte, acc *types.StateAccount) error { + accountHash := common.BytesToHash(accKey) + + // Process storage if present + if acc.Root != types.EmptyRootHash { + id := trie.StorageTrieID(ts.config.root, accountHash, acc.Root) + storageTrie, err := trie.NewStateTrie(id, ts.triedb) + if err != nil { + log.Error("Failed to open storage trie", "root", acc.Root, "err", err) + return err + } + + slots, nodes, err := hooks.OnStorageTrie(ctx, storageTrie, accountHash) + if err != nil { + return err + } + counters.slots.Add(slots) + counters.nodes.Add(nodes) + } + + // Process code if present + if !bytes.Equal(acc.CodeHash, types.EmptyCodeHash.Bytes()) { + if codeHash := common.BytesToHash(acc.CodeHash); !rawdb.HasCode(ts.chaindb, codeHash) { + log.Error("Code is missing", "hash", codeHash) + return errors.New("missing code") + } + counters.codes.Add(1) + } + return nil + } + if raw { // Raw traversal with detailed node checking for accIter.Next(true) { @@ -791,31 +799,8 @@ func (ts *traverseSetup) traverseStateBranchWithHooks(ctx context.Context, start return errors.New("invalid account") } - accountHash := common.BytesToHash(accIter.LeafKey()) - - if acc.Root != types.EmptyRootHash { - id := trie.StorageTrieID(ts.config.root, accountHash, acc.Root) - storageTrie, err := trie.NewStateTrie(id, ts.triedb) - if err != nil { - log.Error("Failed to open storage trie", "root", acc.Root, "err", err) - return err - } - - if hooks != nil && hooks.OnStorageTrie != nil { - slots, nodes, err := hooks.OnStorageTrie(ctx, storageTrie, accountHash) - if err != nil { - return err - } - counters.slots.Add(slots) - counters.nodes.Add(nodes) - } - } - - if !bytes.Equal(acc.CodeHash, types.EmptyCodeHash.Bytes()) { - if err := hooks.OnCode(acc.CodeHash, accountHash); err != nil { - return err - } - counters.codes.Add(1) + if err := processAccount(accIter.LeafKey(), &acc); err != nil { + return err } } } @@ -845,32 +830,8 @@ func (ts *traverseSetup) traverseStateBranchWithHooks(ctx context.Context, start log.Error("Invalid account encountered during traversal", "err", err) return err } - - accountHash := common.BytesToHash(accIter.Key) - - // Process storage if present - if acc.Root != types.EmptyRootHash { - id := trie.StorageTrieID(ts.config.root, accountHash, acc.Root) - storageTrie, err := trie.NewStateTrie(id, ts.triedb) - if err != nil { - log.Error("Failed to open storage trie", "root", acc.Root, "err", err) - return err - } - - slots, nodes, err := hooks.OnStorageTrie(ctx, storageTrie, accountHash) - if err != nil { - return err - } - counters.slots.Add(slots) - counters.nodes.Add(nodes) - } - - // Process code if present - if !bytes.Equal(acc.CodeHash, types.EmptyCodeHash.Bytes()) { - if err := hooks.OnCode(acc.CodeHash, accountHash); err != nil { - return err - } - counters.codes.Add(1) + if err := processAccount(accIter.Key, &acc); err != nil { + return err } }