From dc19cae10e1b7fb034cfa68b090c07013ef165e9 Mon Sep 17 00:00:00 2001 From: jsvisa Date: Tue, 9 Sep 2025 09:29:29 +0000 Subject: [PATCH] in parallel Signed-off-by: jsvisa --- cmd/geth/snapshot.go | 217 +++++++++++++++++++++++++++++++++---------- 1 file changed, 166 insertions(+), 51 deletions(-) diff --git a/cmd/geth/snapshot.go b/cmd/geth/snapshot.go index e24bba082f..1a8573e118 100644 --- a/cmd/geth/snapshot.go +++ b/cmd/geth/snapshot.go @@ -18,13 +18,17 @@ package main import ( "bytes" + "context" "encoding/json" "errors" "fmt" "os" "slices" + "sync/atomic" "time" + "golang.org/x/sync/errgroup" + "github.com/ethereum/go-ethereum/cmd/utils" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/rawdb" @@ -33,9 +37,11 @@ import ( "github.com/ethereum/go-ethereum/core/state/snapshot" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" + "github.com/ethereum/go-ethereum/triedb" "github.com/urfave/cli/v2" ) @@ -327,9 +333,9 @@ func traverseState(ctx *cli.Context) error { } var ( - accounts int - slots int - codes int + accounts atomic.Uint64 + slots atomic.Uint64 + codes atomic.Uint64 start = time.Now() ) @@ -337,7 +343,7 @@ func traverseState(ctx *cli.Context) error { timer := time.NewTicker(time.Second * 8) defer timer.Stop() for range timer.C { - log.Info("Traversing state", "accounts", accounts, "slots", slots, "codes", codes, "elapsed", common.PrettyDuration(time.Since(start))) + log.Info("Traversing state", "accounts", accounts.Load(), "slots", slots.Load(), "codes", codes.Load(), "elapsed", common.PrettyDuration(time.Since(start))) } }() @@ -374,7 +380,7 @@ func traverseState(ctx *cli.Context) error { break } - slots += 1 + slots.Add(1) log.Debug("Storage slot", "key", common.Bytes2Hex(storageIter.Key), "value", common.Bytes2Hex(storageIter.Value)) } if storageIter.Err != nil { @@ -382,64 +388,173 @@ func traverseState(ctx *cli.Context) error { return storageIter.Err } - log.Info("Storage traversal complete", "slots", slots, "elapsed", common.PrettyDuration(time.Since(start))) + log.Info("Storage traversal complete", "slots", slots.Load(), "elapsed", common.PrettyDuration(time.Since(start))) return nil } else { log.Info("Start traversing state trie", "root", config.root.Hex(), "startKey", common.Bytes2Hex(config.startKey), "limitKey", common.Bytes2Hex(config.limitKey)) - acctIt, err := t.NodeIterator(config.startKey) - if err != nil { - log.Error("Failed to open iterator", "root", config.root, "err", err) - return err - } - accIter := trie.NewIterator(acctIt) - for accIter.Next() { - if config.limitKey != nil && bytes.Compare(accIter.Key, config.limitKey) >= 0 { - break + return traverseStateParallel(t, triedb, chaindb, config, &accounts, &slots, &codes, start) + } +} + +// traverseStateParallel parallelizes state traversal by dividing work across 16 trie branches +func traverseStateParallel(t *trie.StateTrie, triedb *triedb.Database, chaindb ethdb.Database, config *traverseConfig, accounts, slots, codes *atomic.Uint64, start time.Time) error { + ctx := context.Background() + g, ctx := errgroup.WithContext(ctx) + + for i := 0; i < 16; i++ { + nibble := byte(i) + g.Go(func() error { + startKey := config.startKey + limitKey := config.limitKey + + branchStartKey := make([]byte, len(startKey)+1) + branchLimitKey := make([]byte, len(startKey)+1) + + if len(startKey) > 0 { + copy(branchStartKey, startKey) + copy(branchLimitKey, startKey) } - accounts += 1 - var acc types.StateAccount - if err := rlp.DecodeBytes(accIter.Value, &acc); err != nil { - log.Error("Invalid account encountered during traversal", "err", err) + branchStartKey[len(startKey)] = nibble << 4 + branchLimitKey[len(startKey)] = (nibble + 1) << 4 + + if limitKey != nil && bytes.Compare(branchStartKey, limitKey) >= 0 { + return nil + } + if limitKey != nil && bytes.Compare(branchLimitKey, limitKey) > 0 { + branchLimitKey = make([]byte, len(limitKey)) + copy(branchLimitKey, limitKey) + } + + return traverseBranch(ctx, t, triedb, chaindb, config.root, branchStartKey, branchLimitKey, accounts, slots, codes) + }) + } + + if err := g.Wait(); err != nil { + return err + } + + log.Info("State traversal complete", "accounts", accounts.Load(), "slots", slots.Load(), "codes", codes.Load(), "elapsed", common.PrettyDuration(time.Since(start))) + return nil +} + +// traverseBranch traverses a specific branch of the state trie +func traverseBranch(ctx context.Context, t *trie.StateTrie, triedb *triedb.Database, chaindb ethdb.Database, root common.Hash, startKey, limitKey []byte, accounts, slots, codes *atomic.Uint64) error { + acctIt, err := t.NodeIterator(startKey) + if err != nil { + return err + } + + accIter := trie.NewIterator(acctIt) + for accIter.Next() { + // Check if context was cancelled + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + if limitKey != nil && bytes.Compare(accIter.Key, limitKey) >= 0 { + break + } + + accounts.Add(1) + + var acc types.StateAccount + if err := rlp.DecodeBytes(accIter.Value, &acc); err != nil { + log.Error("Invalid account encountered during traversal", "err", err) + return err + } + + if acc.Root != types.EmptyRootHash { + id := trie.StorageTrieID(root, common.BytesToHash(accIter.Key), acc.Root) + storageTrie, err := trie.NewStateTrie(id, triedb) + if err != nil { + log.Error("Failed to open storage trie", "root", acc.Root, "err", err) return err } - if acc.Root != types.EmptyRootHash { - id := trie.StorageTrieID(config.root, common.BytesToHash(accIter.Key), acc.Root) - storageTrie, err := trie.NewStateTrie(id, triedb) - if err != nil { - log.Error("Failed to open storage trie", "root", acc.Root, "err", err) - return err - } - storageIt, err := storageTrie.NodeIterator(nil) - if err != nil { - log.Error("Failed to open storage iterator", "root", acc.Root, "err", err) - return err - } - storageIter := trie.NewIterator(storageIt) - for storageIter.Next() { - slots += 1 - } - if storageIter.Err != nil { - log.Error("Failed to traverse storage trie", "root", acc.Root, "err", storageIter.Err) - return storageIter.Err - } + + localSlots, err := traverseStorageParallel(ctx, storageTrie) + if err != nil { + log.Error("Failed to traverse storage trie", "root", acc.Root, "err", err) + return err } - if !bytes.Equal(acc.CodeHash, types.EmptyCodeHash.Bytes()) { - if !rawdb.HasCode(chaindb, common.BytesToHash(acc.CodeHash)) { - log.Error("Code is missing", "hash", common.BytesToHash(acc.CodeHash)) - return errors.New("missing code") - } - codes += 1 + slots.Add(localSlots) + } + + if !bytes.Equal(acc.CodeHash, types.EmptyCodeHash.Bytes()) { + if !rawdb.HasCode(chaindb, common.BytesToHash(acc.CodeHash)) { + log.Error("Code is missing", "hash", common.BytesToHash(acc.CodeHash)) + return errors.New("missing code") } + codes.Add(1) } - if accIter.Err != nil { - log.Error("Failed to traverse state trie", "root", config.root, "err", accIter.Err) - return accIter.Err - } - log.Info("State traversal complete", "accounts", accounts, "slots", slots, "codes", codes, "elapsed", common.PrettyDuration(time.Since(start))) - return nil } + + if accIter.Err != nil { + return accIter.Err + } + + return nil +} + +// traverseStorageParallel parallelizes storage trie traversal by dividing work across 16 trie branches +func traverseStorageParallel(ctx context.Context, storageTrie *trie.StateTrie) (uint64, error) { + g, ctx := errgroup.WithContext(ctx) + totalSlots := atomic.Uint64{} + + for i := 0; i < 16; i++ { + nibble := byte(i) + g.Go(func() error { + branchStartKey := []byte{nibble << 4} + branchLimitKey := []byte{(nibble + 1) << 4} + + localSlots, err := traverseStorageBranch(ctx, storageTrie, branchStartKey, branchLimitKey) + if err != nil { + return err + } + totalSlots.Add(localSlots) + return nil + }) + } + + if err := g.Wait(); err != nil { + return 0, err + } + + return totalSlots.Load(), nil +} + +// traverseStorageBranch traverses a specific branch of the storage trie +func traverseStorageBranch(ctx context.Context, storageTrie *trie.StateTrie, startKey, limitKey []byte) (uint64, error) { + storageIt, err := storageTrie.NodeIterator(startKey) + if err != nil { + return 0, err + } + + storageIter := trie.NewIterator(storageIt) + slots := uint64(0) + + for storageIter.Next() { + // Check if context was cancelled + select { + case <-ctx.Done(): + return 0, ctx.Err() + default: + } + + if bytes.Compare(storageIter.Key, limitKey) >= 0 { + break + } + slots++ + } + + if storageIter.Err != nil { + return 0, storageIter.Err + } + + return slots, nil } type traverseConfig struct {