diff --git a/cmd/geth/snapshot.go b/cmd/geth/snapshot.go index fc0658a59c..c177fb5ea2 100644 --- a/cmd/geth/snapshot.go +++ b/cmd/geth/snapshot.go @@ -36,6 +36,7 @@ import ( "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" + "github.com/ethereum/go-ethereum/triedb" "github.com/urfave/cli/v2" ) @@ -105,7 +106,9 @@ information about the specified address. Usage: "Traverse the state with given root hash and perform quick verification", ArgsUsage: "", Action: traverseState, - Flags: slices.Concat(utils.NetworkFlags, utils.DatabaseFlags), + Flags: slices.Concat([]cli.Flag{ + utils.AccountFlag, + }, utils.NetworkFlags, utils.DatabaseFlags), Description: ` geth snapshot traverse-state will traverse the whole state from the given state root and will abort if any @@ -113,6 +116,8 @@ referenced trie node or contract code is missing. This command can be used for state integrity verification. The default checking target is the HEAD state. It's also usable without snapshot enabled. + +If --account is specified, only the storage trie of that account is traversed. `, }, { @@ -120,7 +125,9 @@ It's also usable without snapshot enabled. Usage: "Traverse the state with given root hash and perform detailed verification", ArgsUsage: "", Action: traverseRawState, - Flags: slices.Concat(utils.NetworkFlags, utils.DatabaseFlags), + Flags: slices.Concat([]cli.Flag{ + utils.AccountFlag, + }, utils.NetworkFlags, utils.DatabaseFlags), Description: ` geth snapshot traverse-rawstate will traverse the whole state from the given root and will abort if any referenced @@ -129,6 +136,8 @@ verification. The default checking target is the HEAD state. It's basically iden to traverse-state, but the check granularity is smaller. It's also usable without snapshot enabled. + +If --account is specified, only the storage trie of that account is traversed. `, }, { @@ -272,6 +281,120 @@ func checkDanglingStorage(ctx *cli.Context) error { return snapshot.CheckDanglingStorage(db) } +// parseAccount parses the account flag value as either an address (20 bytes) +// or an account hash (32 bytes) and returns the hashed account key. +func parseAccount(input string) (common.Hash, error) { + switch len(input) { + case 40, 42: // address + return crypto.Keccak256Hash(common.HexToAddress(input).Bytes()), nil + case 64, 66: // hash + return common.HexToHash(input), nil + default: + return common.Hash{}, errors.New("malformed account address or hash") + } +} + +// lookupAccount resolves the account from the state trie using the given +// account hash. +func lookupAccount(accountHash common.Hash, tr *trie.Trie) (*types.StateAccount, error) { + accData, err := tr.Get(accountHash.Bytes()) + if err != nil { + return nil, fmt.Errorf("failed to get account %s: %w", accountHash, err) + } + if accData == nil { + return nil, fmt.Errorf("account not found: %s", accountHash) + } + var acc types.StateAccount + if err := rlp.DecodeBytes(accData, &acc); err != nil { + return nil, fmt.Errorf("invalid account data %s: %w", accountHash, err) + } + return &acc, nil +} + +func traverseStorage(id *trie.ID, db *triedb.Database, report bool, detail bool) error { + tr, err := trie.NewStateTrie(id, db) + if err != nil { + log.Error("Failed to open storage trie", "account", id.Owner, "root", id.Root, "err", err) + return err + } + var ( + slots int + nodes int + lastReport time.Time + start = time.Now() + ) + it, err := tr.NodeIterator(nil) + if err != nil { + log.Error("Failed to open storage iterator", "account", id.Owner, "root", id.Root, "err", err) + return err + } + logger := log.Debug + if report { + logger = log.Info + } + logger("Start traversing storage trie", "account", id.Owner, "storageRoot", id.Root) + + if !detail { + iter := trie.NewIterator(it) + for iter.Next() { + slots += 1 + if time.Since(lastReport) > time.Second*8 { + logger("Traversing storage", "account", id.Owner, "slots", slots, "elapsed", common.PrettyDuration(time.Since(start))) + lastReport = time.Now() + } + } + if iter.Err != nil { + log.Error("Failed to traverse storage trie", "root", id.Root, "err", iter.Err) + return iter.Err + } + logger("Storage is complete", "account", id.Owner, "slots", slots, "elapsed", common.PrettyDuration(time.Since(start))) + } else { + reader, err := db.NodeReader(id.StateRoot) + if err != nil { + log.Error("Failed to open state reader", "err", err) + return err + } + var ( + buffer = make([]byte, 32) + hasher = crypto.NewKeccakState() + ) + for it.Next(true) { + nodes += 1 + node := it.Hash() + + // Check the presence for non-empty hash node(embedded node doesn't + // have their own hash). + if node != (common.Hash{}) { + blob, _ := reader.Node(id.Owner, it.Path(), node) + if len(blob) == 0 { + log.Error("Missing trie node(storage)", "hash", node) + return errors.New("missing storage") + } + hasher.Reset() + hasher.Write(blob) + hasher.Read(buffer) + if !bytes.Equal(buffer, node.Bytes()) { + log.Error("Invalid trie node(storage)", "hash", node.Hex(), "value", blob) + return errors.New("invalid storage node") + } + } + if it.Leaf() { + slots += 1 + } + if time.Since(lastReport) > time.Second*8 { + logger("Traversing storage", "account", id.Owner, "nodes", nodes, "slots", slots, "elapsed", common.PrettyDuration(time.Since(start))) + lastReport = time.Now() + } + } + if err := it.Error(); err != nil { + log.Error("Failed to traverse storage trie", "root", id.Root, "err", err) + return err + } + logger("Storage is complete", "account", id.Owner, "nodes", nodes, "slots", slots, "elapsed", common.PrettyDuration(time.Since(start))) + } + return nil +} + // traverseState is a helper function used for pruning verification. // Basically it just iterates the trie, ensure all nodes and associated // contract codes are present. @@ -309,6 +432,30 @@ func traverseState(ctx *cli.Context) error { root = headBlock.Root() log.Info("Start traversing the state", "root", root, "number", headBlock.NumberU64()) } + // If --account is specified, only traverse the storage trie of that account. + if accountStr := ctx.String(utils.AccountFlag.Name); accountStr != "" { + accountHash, err := parseAccount(accountStr) + if err != nil { + log.Error("Failed to parse account", "err", err) + return err + } + // Use raw trie since the account key is already hashed. + t, err := trie.New(trie.StateTrieID(root), triedb) + if err != nil { + log.Error("Failed to open state trie", "root", root, "err", err) + return err + } + acc, err := lookupAccount(accountHash, t) + if err != nil { + log.Error("Failed to look up account", "hash", accountHash, "err", err) + return err + } + if acc.Root == types.EmptyRootHash { + log.Info("Account has no storage", "hash", accountHash) + return nil + } + return traverseStorage(trie.StorageTrieID(root, accountHash, acc.Root), triedb, true, false) + } t, err := trie.NewStateTrie(trie.StateTrieID(root), triedb) if err != nil { log.Error("Failed to open trie", "root", root, "err", err) @@ -335,30 +482,10 @@ func traverseState(ctx *cli.Context) error { return err } if acc.Root != types.EmptyRootHash { - id := trie.StorageTrieID(root, common.BytesToHash(accIter.Key), acc.Root) - storageTrie, err := trie.NewStateTrie(id, triedb) + err := traverseStorage(trie.StorageTrieID(root, common.BytesToHash(accIter.Key), acc.Root), triedb, false, false) if err != nil { - log.Error("Failed to open storage trie", "root", acc.Root, "err", err) return err } - storageIt, err := storageTrie.NodeIterator(nil) - if err != nil { - log.Error("Failed to open storage iterator", "root", acc.Root, "err", err) - return err - } - storageIter := trie.NewIterator(storageIt) - for storageIter.Next() { - slots += 1 - - if time.Since(lastReport) > time.Second*8 { - log.Info("Traversing state", "accounts", accounts, "slots", slots, "codes", codes, "elapsed", common.PrettyDuration(time.Since(start))) - lastReport = time.Now() - } - } - if storageIter.Err != nil { - log.Error("Failed to traverse storage trie", "root", acc.Root, "err", storageIter.Err) - return storageIter.Err - } } if !bytes.Equal(acc.CodeHash, types.EmptyCodeHash.Bytes()) { if !rawdb.HasCode(chaindb, common.BytesToHash(acc.CodeHash)) { @@ -418,6 +545,30 @@ func traverseRawState(ctx *cli.Context) error { root = headBlock.Root() log.Info("Start traversing the state", "root", root, "number", headBlock.NumberU64()) } + // If --account is specified, only traverse the storage trie of that account. + if accountStr := ctx.String(utils.AccountFlag.Name); accountStr != "" { + accountHash, err := parseAccount(accountStr) + if err != nil { + log.Error("Failed to parse account", "err", err) + return err + } + // Use raw trie since the account key is already hashed. + t, err := trie.New(trie.StateTrieID(root), triedb) + if err != nil { + log.Error("Failed to open state trie", "root", root, "err", err) + return err + } + acc, err := lookupAccount(accountHash, t) + if err != nil { + log.Error("Failed to look up account", "hash", accountHash, "err", err) + return err + } + if acc.Root == types.EmptyRootHash { + log.Info("Account has no storage", "hash", accountHash) + return nil + } + return traverseStorage(trie.StorageTrieID(root, accountHash, acc.Root), triedb, true, true) + } t, err := trie.NewStateTrie(trie.StateTrieID(root), triedb) if err != nil { log.Error("Failed to open trie", "root", root, "err", err) @@ -473,50 +624,10 @@ func traverseRawState(ctx *cli.Context) error { return errors.New("invalid account") } if acc.Root != types.EmptyRootHash { - id := trie.StorageTrieID(root, common.BytesToHash(accIter.LeafKey()), acc.Root) - storageTrie, err := trie.NewStateTrie(id, triedb) + err := traverseStorage(trie.StorageTrieID(root, common.BytesToHash(accIter.LeafKey()), acc.Root), triedb, false, true) if err != nil { - log.Error("Failed to open storage trie", "root", acc.Root, "err", err) - return errors.New("missing storage trie") - } - storageIter, err := storageTrie.NodeIterator(nil) - if err != nil { - log.Error("Failed to open storage iterator", "root", acc.Root, "err", err) return err } - for storageIter.Next(true) { - nodes += 1 - node := storageIter.Hash() - - // Check the presence for non-empty hash node(embedded node doesn't - // have their own hash). - if node != (common.Hash{}) { - blob, _ := reader.Node(common.BytesToHash(accIter.LeafKey()), storageIter.Path(), node) - if len(blob) == 0 { - log.Error("Missing trie node(storage)", "hash", node) - return errors.New("missing storage") - } - hasher.Reset() - hasher.Write(blob) - hasher.Read(got) - if !bytes.Equal(got, node.Bytes()) { - log.Error("Invalid trie node(storage)", "hash", node.Hex(), "value", blob) - return errors.New("invalid storage node") - } - } - // Bump the counter if it's leaf node. - if storageIter.Leaf() { - slots += 1 - } - if time.Since(lastReport) > time.Second*8 { - log.Info("Traversing state", "nodes", nodes, "accounts", accounts, "slots", slots, "codes", codes, "elapsed", common.PrettyDuration(time.Since(start))) - lastReport = time.Now() - } - } - if storageIter.Error() != nil { - log.Error("Failed to traverse storage trie", "root", acc.Root, "err", storageIter.Error()) - return storageIter.Error() - } } if !bytes.Equal(acc.CodeHash, types.EmptyCodeHash.Bytes()) { if !rawdb.HasCode(chaindb, common.BytesToHash(acc.CodeHash)) { diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index 792e0e55ab..3a0bcc6b05 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -218,6 +218,10 @@ var ( Usage: "Max number of elements (0 = no limit)", Value: 0, } + AccountFlag = &cli.StringFlag{ + Name: "account", + Usage: "Specifies the account address or hash to traverse a single storage trie", + } OutputFileFlag = &cli.StringFlag{ Name: "output", Usage: "Writes the result in json to the output",