diff --git a/beacon/light/request/server.go b/beacon/light/request/server.go index a06dec99ae..d39570b8e5 100644 --- a/beacon/light/request/server.go +++ b/beacon/light/request/server.go @@ -438,14 +438,11 @@ func (s *serverWithLimits) fail(desc string) { // failLocked calculates the dynamic failure delay and applies it. func (s *serverWithLimits) failLocked(desc string) { log.Debug("Server error", "description", desc) - s.failureDelay *= 2 now := s.clock.Now() if now > s.failureDelayEnd { s.failureDelay *= math.Pow(2, -float64(now-s.failureDelayEnd)/float64(maxFailureDelay)) } - if s.failureDelay < float64(minFailureDelay) { - s.failureDelay = float64(minFailureDelay) - } + s.failureDelay = max(min(s.failureDelay*2, float64(maxFailureDelay)), float64(minFailureDelay)) s.failureDelayEnd = now + mclock.AbsTime(s.failureDelay) s.delay(time.Duration(s.failureDelay)) } diff --git a/beacon/light/sync/update_sync.go b/beacon/light/sync/update_sync.go index 9549ee5992..d84a3d64da 100644 --- a/beacon/light/sync/update_sync.go +++ b/beacon/light/sync/update_sync.go @@ -62,7 +62,6 @@ const ( ssNeedParent // cp header slot %32 != 0, need parent to check epoch boundary ssParentRequested // cp parent header requested ssPrintStatus // has all necessary info, print log message if init still not successful - ssDone // log message printed, no more action required ) type serverState struct { @@ -180,7 +179,8 @@ func (s *CheckpointInit) Process(requester request.Requester, events []request.E default: log.Error("blsync: checkpoint not available, but reported as finalized; specified checkpoint hash might be too old", "server", server.Name()) } - s.serverState[server] = serverState{state: ssDone} + s.serverState[server] = serverState{state: ssDefault} + requester.Fail(server, "checkpoint init failed") } } diff --git a/cmd/geth/bintrie_convert.go b/cmd/geth/bintrie_convert.go new file mode 100644 index 0000000000..3730768697 --- /dev/null +++ b/cmd/geth/bintrie_convert.go @@ -0,0 +1,408 
@@ +// Copyright 2026 The go-ethereum Authors +// This file is part of go-ethereum. +// +// go-ethereum is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// go-ethereum is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with go-ethereum. If not, see <http://www.gnu.org/licenses/>. + +package main + +import ( + "errors" + "fmt" + "runtime" + "runtime/debug" + "slices" + "time" + + "github.com/ethereum/go-ethereum/cmd/utils" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/trie" + "github.com/ethereum/go-ethereum/trie/bintrie" + "github.com/ethereum/go-ethereum/trie/trienode" + "github.com/ethereum/go-ethereum/triedb" + "github.com/ethereum/go-ethereum/triedb/pathdb" + "github.com/urfave/cli/v2" +) + +var ( + deleteSourceFlag = &cli.BoolFlag{ + Name: "delete-source", + Usage: "Delete MPT trie nodes after conversion", + } + memoryLimitFlag = &cli.Uint64Flag{ + Name: "memory-limit", + Usage: "Max heap allocation in MB before forcing a commit cycle", + Value: 16384, + } + + bintrieCommand = &cli.Command{ + Name: "bintrie", + Usage: "A set of commands for binary trie operations", + Description: "", + Subcommands: []*cli.Command{ + { + Name: "convert", + Usage: "Convert MPT state to binary trie", + ArgsUsage: "[state-root]", + Action: convertToBinaryTrie, + Flags: slices.Concat([]cli.Flag{ + deleteSourceFlag, + 
memoryLimitFlag, + }, utils.NetworkFlags, utils.DatabaseFlags), + Description: ` +geth bintrie convert [--delete-source] [--memory-limit MB] [state-root] + +Reads all state from the Merkle Patricia Trie and writes it into a Binary Trie, +operating offline. Memory-safe via periodic commit-and-reload cycles. + +The optional state-root argument specifies which state root to convert. +If omitted, the head block's state root is used. + +Flags: + --delete-source Delete MPT trie nodes after successful conversion + --memory-limit Max heap allocation in MB before forcing a commit (default: 16384) +`, + }, + }, + } +) + +type conversionStats struct { + accounts uint64 + slots uint64 + codes uint64 + commits uint64 + start time.Time + lastReport time.Time + lastMemChk time.Time +} + +func (s *conversionStats) report(force bool) { + if !force && time.Since(s.lastReport) < 8*time.Second { + return + } + elapsed := time.Since(s.start).Seconds() + acctRate := float64(0) + if elapsed > 0 { + acctRate = float64(s.accounts) / elapsed + } + log.Info("Conversion progress", + "accounts", s.accounts, + "slots", s.slots, + "codes", s.codes, + "commits", s.commits, + "accounts/sec", fmt.Sprintf("%.0f", acctRate), + "elapsed", common.PrettyDuration(time.Since(s.start)), + ) + s.lastReport = time.Now() +} + +func convertToBinaryTrie(ctx *cli.Context) error { + if ctx.NArg() > 1 { + return errors.New("too many arguments") + } + stack, _ := makeConfigNode(ctx) + defer stack.Close() + + chaindb := utils.MakeChainDatabase(ctx, stack, false) + defer chaindb.Close() + + headBlock := rawdb.ReadHeadBlock(chaindb) + if headBlock == nil { + return errors.New("no head block found") + } + var ( + root common.Hash + err error + ) + if ctx.NArg() == 1 { + root, err = parseRoot(ctx.Args().First()) + if err != nil { + return fmt.Errorf("invalid state root: %w", err) + } + } else { + root = headBlock.Root() + } + log.Info("Starting MPT to binary trie conversion", "root", root, "block", 
headBlock.NumberU64()) + + srcTriedb := utils.MakeTrieDatabase(ctx, stack, chaindb, true, true, false) + defer srcTriedb.Close() + + destTriedb := triedb.NewDatabase(chaindb, &triedb.Config{ + IsVerkle: true, + PathDB: &pathdb.Config{ + JournalDirectory: stack.ResolvePath("triedb-bintrie"), + }, + }) + defer destTriedb.Close() + + binTrie, err := bintrie.NewBinaryTrie(types.EmptyBinaryHash, destTriedb) + if err != nil { + return fmt.Errorf("failed to create binary trie: %w", err) + } + memLimit := ctx.Uint64(memoryLimitFlag.Name) * 1024 * 1024 + + currentRoot, err := runConversionLoop(chaindb, srcTriedb, destTriedb, binTrie, root, memLimit) + if err != nil { + return err + } + log.Info("Conversion complete", "binaryRoot", currentRoot) + + if ctx.Bool(deleteSourceFlag.Name) { + log.Info("Deleting source MPT data") + if err := deleteMPTData(chaindb, srcTriedb, root); err != nil { + return fmt.Errorf("MPT deletion failed: %w", err) + } + log.Info("Source MPT data deleted") + } + return nil +} + +func runConversionLoop(chaindb ethdb.Database, srcTriedb *triedb.Database, destTriedb *triedb.Database, binTrie *bintrie.BinaryTrie, root common.Hash, memLimit uint64) (common.Hash, error) { + currentRoot := types.EmptyBinaryHash + stats := &conversionStats{ + start: time.Now(), + lastReport: time.Now(), + lastMemChk: time.Now(), + } + + srcTrie, err := trie.NewStateTrie(trie.StateTrieID(root), srcTriedb) + if err != nil { + return common.Hash{}, fmt.Errorf("failed to open source trie: %w", err) + } + acctIt, err := srcTrie.NodeIterator(nil) + if err != nil { + return common.Hash{}, fmt.Errorf("failed to create account iterator: %w", err) + } + accIter := trie.NewIterator(acctIt) + + for accIter.Next() { + var acc types.StateAccount + if err := rlp.DecodeBytes(accIter.Value, &acc); err != nil { + return common.Hash{}, fmt.Errorf("invalid account RLP: %w", err) + } + addrBytes := srcTrie.GetKey(accIter.Key) + if addrBytes == nil { + return common.Hash{}, fmt.Errorf("missing 
preimage for account hash %x (run with --cache.preimages)", accIter.Key) + } + addr := common.BytesToAddress(addrBytes) + + var code []byte + codeHash := common.BytesToHash(acc.CodeHash) + if codeHash != types.EmptyCodeHash { + code = rawdb.ReadCode(chaindb, codeHash) + if code == nil { + return common.Hash{}, fmt.Errorf("missing code for hash %x (account %x)", codeHash, addr) + } + stats.codes++ + } + + if err := binTrie.UpdateAccount(addr, &acc, len(code)); err != nil { + return common.Hash{}, fmt.Errorf("failed to update account %x: %w", addr, err) + } + if len(code) > 0 { + if err := binTrie.UpdateContractCode(addr, codeHash, code); err != nil { + return common.Hash{}, fmt.Errorf("failed to update code for %x: %w", addr, err) + } + } + + if acc.Root != types.EmptyRootHash { + addrHash := common.BytesToHash(accIter.Key) + storageTrie, err := trie.NewStateTrie(trie.StorageTrieID(root, addrHash, acc.Root), srcTriedb) + if err != nil { + return common.Hash{}, fmt.Errorf("failed to open storage trie for %x: %w", addr, err) + } + storageNodeIt, err := storageTrie.NodeIterator(nil) + if err != nil { + return common.Hash{}, fmt.Errorf("failed to create storage iterator for %x: %w", addr, err) + } + storageIter := trie.NewIterator(storageNodeIt) + + slotCount := uint64(0) + for storageIter.Next() { + slotKey := storageTrie.GetKey(storageIter.Key) + if slotKey == nil { + return common.Hash{}, fmt.Errorf("missing preimage for storage key %x (account %x)", storageIter.Key, addr) + } + _, content, _, err := rlp.Split(storageIter.Value) + if err != nil { + return common.Hash{}, fmt.Errorf("invalid storage RLP for key %x (account %x): %w", slotKey, addr, err) + } + if err := binTrie.UpdateStorage(addr, slotKey, content); err != nil { + return common.Hash{}, fmt.Errorf("failed to update storage %x/%x: %w", addr, slotKey, err) + } + stats.slots++ + slotCount++ + + if slotCount%10000 == 0 { + binTrie, currentRoot, err = maybeCommit(binTrie, currentRoot, destTriedb, memLimit, 
stats) + if err != nil { + return common.Hash{}, err + } + } + } + if storageIter.Err != nil { + return common.Hash{}, fmt.Errorf("storage iteration error for %x: %w", addr, storageIter.Err) + } + } + stats.accounts++ + stats.report(false) + + if stats.accounts%1000 == 0 { + binTrie, currentRoot, err = maybeCommit(binTrie, currentRoot, destTriedb, memLimit, stats) + if err != nil { + return common.Hash{}, err + } + } + } + if accIter.Err != nil { + return common.Hash{}, fmt.Errorf("account iteration error: %w", accIter.Err) + } + + _, currentRoot, err = commitBinaryTrie(binTrie, currentRoot, destTriedb) + if err != nil { + return common.Hash{}, fmt.Errorf("final commit failed: %w", err) + } + stats.commits++ + stats.report(true) + return currentRoot, nil +} + +func maybeCommit(bt *bintrie.BinaryTrie, currentRoot common.Hash, destDB *triedb.Database, memLimit uint64, stats *conversionStats) (*bintrie.BinaryTrie, common.Hash, error) { + if time.Since(stats.lastMemChk) < 5*time.Second { + return bt, currentRoot, nil + } + stats.lastMemChk = time.Now() + + var m runtime.MemStats + runtime.ReadMemStats(&m) + if m.Alloc < memLimit { + return bt, currentRoot, nil + } + log.Info("Memory limit reached, committing", "alloc", common.StorageSize(m.Alloc), "limit", common.StorageSize(memLimit)) + + bt, currentRoot, err := commitBinaryTrie(bt, currentRoot, destDB) + if err != nil { + return nil, common.Hash{}, err + } + stats.commits++ + stats.report(true) + return bt, currentRoot, nil +} + +func commitBinaryTrie(bt *bintrie.BinaryTrie, currentRoot common.Hash, destDB *triedb.Database) (*bintrie.BinaryTrie, common.Hash, error) { + newRoot, nodeSet := bt.Commit(false) + if nodeSet != nil { + merged := trienode.NewWithNodeSet(nodeSet) + if err := destDB.Update(newRoot, currentRoot, 0, merged, triedb.NewStateSet()); err != nil { + return nil, common.Hash{}, fmt.Errorf("triedb update failed: %w", err) + } + if err := destDB.Commit(newRoot, false); err != nil { + return nil, 
common.Hash{}, fmt.Errorf("triedb commit failed: %w", err) + } + } + runtime.GC() + debug.FreeOSMemory() + + bt, err := bintrie.NewBinaryTrie(newRoot, destDB) + if err != nil { + return nil, common.Hash{}, fmt.Errorf("failed to reload binary trie: %w", err) + } + return bt, newRoot, nil +} + +func deleteMPTData(chaindb ethdb.Database, srcTriedb *triedb.Database, root common.Hash) error { + isPathDB := srcTriedb.Scheme() == rawdb.PathScheme + + srcTrie, err := trie.NewStateTrie(trie.StateTrieID(root), srcTriedb) + if err != nil { + return fmt.Errorf("failed to open source trie for deletion: %w", err) + } + acctIt, err := srcTrie.NodeIterator(nil) + if err != nil { + return fmt.Errorf("failed to create account iterator for deletion: %w", err) + } + batch := chaindb.NewBatch() + deleted := 0 + + for acctIt.Next(true) { + if isPathDB { + rawdb.DeleteAccountTrieNode(batch, acctIt.Path()) + } else { + node := acctIt.Hash() + if node != (common.Hash{}) { + rawdb.DeleteLegacyTrieNode(batch, node) + } + } + deleted++ + + if acctIt.Leaf() { + var acc types.StateAccount + if err := rlp.DecodeBytes(acctIt.LeafBlob(), &acc); err != nil { + return fmt.Errorf("invalid account during deletion: %w", err) + } + if acc.Root != types.EmptyRootHash { + addrHash := common.BytesToHash(acctIt.LeafKey()) + storageTrie, err := trie.NewStateTrie(trie.StorageTrieID(root, addrHash, acc.Root), srcTriedb) + if err != nil { + return fmt.Errorf("failed to open storage trie for deletion: %w", err) + } + storageIt, err := storageTrie.NodeIterator(nil) + if err != nil { + return fmt.Errorf("failed to create storage iterator for deletion: %w", err) + } + for storageIt.Next(true) { + if isPathDB { + rawdb.DeleteStorageTrieNode(batch, addrHash, storageIt.Path()) + } else { + node := storageIt.Hash() + if node != (common.Hash{}) { + rawdb.DeleteLegacyTrieNode(batch, node) + } + } + deleted++ + if batch.ValueSize() >= ethdb.IdealBatchSize { + if err := batch.Write(); err != nil { + return 
fmt.Errorf("batch write failed: %w", err) + } + batch.Reset() + } + } + if storageIt.Error() != nil { + return fmt.Errorf("storage deletion iterator error: %w", storageIt.Error()) + } + } + } + if batch.ValueSize() >= ethdb.IdealBatchSize { + if err := batch.Write(); err != nil { + return fmt.Errorf("batch write failed: %w", err) + } + batch.Reset() + } + } + if acctIt.Error() != nil { + return fmt.Errorf("account deletion iterator error: %w", acctIt.Error()) + } + if batch.ValueSize() > 0 { + if err := batch.Write(); err != nil { + return fmt.Errorf("final batch write failed: %w", err) + } + } + log.Info("MPT deletion complete", "nodesDeleted", deleted) + return nil +} diff --git a/cmd/geth/bintrie_convert_test.go b/cmd/geth/bintrie_convert_test.go new file mode 100644 index 0000000000..9b95f6a70f --- /dev/null +++ b/cmd/geth/bintrie_convert_test.go @@ -0,0 +1,229 @@ +// Copyright 2026 The go-ethereum Authors +// This file is part of go-ethereum. +// +// go-ethereum is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// go-ethereum is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with go-ethereum. If not, see <http://www.gnu.org/licenses/>. 
+ +package main + +import ( + "math" + "math/big" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/trie/bintrie" + "github.com/ethereum/go-ethereum/triedb" + "github.com/ethereum/go-ethereum/triedb/pathdb" + "github.com/holiman/uint256" +) + +func TestBintrieConvert(t *testing.T) { + var ( + addr1 = common.HexToAddress("0x1111111111111111111111111111111111111111") + addr2 = common.HexToAddress("0x2222222222222222222222222222222222222222") + slotKey1 = common.HexToHash("0x01") + slotKey2 = common.HexToHash("0x02") + slotVal1 = common.HexToHash("0xdeadbeef") + slotVal2 = common.HexToHash("0xcafebabe") + code = []byte{0x60, 0x42, 0x60, 0x00, 0x52, 0x60, 0x20, 0x60, 0x00, 0xf3} + ) + + chaindb := rawdb.NewMemoryDatabase() + + srcTriedb := triedb.NewDatabase(chaindb, &triedb.Config{ + Preimages: true, + PathDB: pathdb.Defaults, + }) + + gspec := &core.Genesis{ + Config: params.TestChainConfig, + BaseFee: big.NewInt(params.InitialBaseFee), + Alloc: types.GenesisAlloc{ + addr1: { + Balance: big.NewInt(1000000), + Nonce: 5, + }, + addr2: { + Balance: big.NewInt(2000000), + Nonce: 10, + Code: code, + Storage: map[common.Hash]common.Hash{ + slotKey1: slotVal1, + slotKey2: slotVal2, + }, + }, + }, + } + + genesisBlock := gspec.MustCommit(chaindb, srcTriedb) + root := genesisBlock.Root() + t.Logf("Genesis root: %x", root) + srcTriedb.Close() + + srcTriedb2 := triedb.NewDatabase(chaindb, &triedb.Config{ + Preimages: true, + PathDB: &pathdb.Config{ReadOnly: true}, + }) + defer srcTriedb2.Close() + + destTriedb := triedb.NewDatabase(chaindb, &triedb.Config{ + IsVerkle: true, + PathDB: pathdb.Defaults, + }) + defer destTriedb.Close() + + bt, err := bintrie.NewBinaryTrie(types.EmptyBinaryHash, destTriedb) + if err != nil { + t.Fatalf("failed to create binary trie: 
%v", err) + } + + currentRoot, err := runConversionLoop(chaindb, srcTriedb2, destTriedb, bt, root, math.MaxUint64) + if err != nil { + t.Fatalf("conversion failed: %v", err) + } + t.Logf("Binary trie root: %x", currentRoot) + + bt2, err := bintrie.NewBinaryTrie(currentRoot, destTriedb) + if err != nil { + t.Fatalf("failed to reload binary trie: %v", err) + } + + acc1, err := bt2.GetAccount(addr1) + if err != nil { + t.Fatalf("failed to get account1: %v", err) + } + if acc1 == nil { + t.Fatal("account1 not found in binary trie") + } + if acc1.Nonce != 5 { + t.Errorf("account1 nonce: got %d, want 5", acc1.Nonce) + } + wantBal1 := uint256.NewInt(1000000) + if acc1.Balance.Cmp(wantBal1) != 0 { + t.Errorf("account1 balance: got %s, want %s", acc1.Balance, wantBal1) + } + + acc2, err := bt2.GetAccount(addr2) + if err != nil { + t.Fatalf("failed to get account2: %v", err) + } + if acc2 == nil { + t.Fatal("account2 not found in binary trie") + } + if acc2.Nonce != 10 { + t.Errorf("account2 nonce: got %d, want 10", acc2.Nonce) + } + wantBal2 := uint256.NewInt(2000000) + if acc2.Balance.Cmp(wantBal2) != 0 { + t.Errorf("account2 balance: got %s, want %s", acc2.Balance, wantBal2) + } + + treeKey1 := bintrie.GetBinaryTreeKeyStorageSlot(addr2, slotKey1[:]) + val1, err := bt2.GetWithHashedKey(treeKey1) + if err != nil { + t.Fatalf("failed to get storage slot1: %v", err) + } + if len(val1) == 0 { + t.Fatal("storage slot1 not found") + } + got1 := common.BytesToHash(val1) + if got1 != slotVal1 { + t.Errorf("storage slot1: got %x, want %x", got1, slotVal1) + } + + treeKey2 := bintrie.GetBinaryTreeKeyStorageSlot(addr2, slotKey2[:]) + val2, err := bt2.GetWithHashedKey(treeKey2) + if err != nil { + t.Fatalf("failed to get storage slot2: %v", err) + } + if len(val2) == 0 { + t.Fatal("storage slot2 not found") + } + got2 := common.BytesToHash(val2) + if got2 != slotVal2 { + t.Errorf("storage slot2: got %x, want %x", got2, slotVal2) + } +} + +func TestBintrieConvertDeleteSource(t 
*testing.T) { + addr1 := common.HexToAddress("0x3333333333333333333333333333333333333333") + + chaindb := rawdb.NewMemoryDatabase() + + srcTriedb := triedb.NewDatabase(chaindb, &triedb.Config{ + Preimages: true, + PathDB: pathdb.Defaults, + }) + + gspec := &core.Genesis{ + Config: params.TestChainConfig, + BaseFee: big.NewInt(params.InitialBaseFee), + Alloc: types.GenesisAlloc{ + addr1: { + Balance: big.NewInt(1000000), + }, + }, + } + + genesisBlock := gspec.MustCommit(chaindb, srcTriedb) + root := genesisBlock.Root() + srcTriedb.Close() + + srcTriedb2 := triedb.NewDatabase(chaindb, &triedb.Config{ + Preimages: true, + PathDB: &pathdb.Config{ReadOnly: true}, + }) + + destTriedb := triedb.NewDatabase(chaindb, &triedb.Config{ + IsVerkle: true, + PathDB: pathdb.Defaults, + }) + + bt, err := bintrie.NewBinaryTrie(types.EmptyBinaryHash, destTriedb) + if err != nil { + t.Fatalf("failed to create binary trie: %v", err) + } + + newRoot, err := runConversionLoop(chaindb, srcTriedb2, destTriedb, bt, root, math.MaxUint64) + if err != nil { + t.Fatalf("conversion failed: %v", err) + } + + if err := deleteMPTData(chaindb, srcTriedb2, root); err != nil { + t.Fatalf("deletion failed: %v", err) + } + srcTriedb2.Close() + + bt2, err := bintrie.NewBinaryTrie(newRoot, destTriedb) + if err != nil { + t.Fatalf("failed to reload binary trie after deletion: %v", err) + } + + acc, err := bt2.GetAccount(addr1) + if err != nil { + t.Fatalf("failed to get account after deletion: %v", err) + } + if acc == nil { + t.Fatal("account not found after MPT deletion") + } + wantBal := uint256.NewInt(1000000) + if acc.Balance.Cmp(wantBal) != 0 { + t.Errorf("balance after deletion: got %s, want %s", acc.Balance, wantBal) + } + destTriedb.Close() +} diff --git a/cmd/geth/main.go b/cmd/geth/main.go index b72cbb9885..e196ac8688 100644 --- a/cmd/geth/main.go +++ b/cmd/geth/main.go @@ -260,6 +260,8 @@ func init() { utils.ShowDeprecated, // See snapshot.go snapshotCommand, + // See bintrie_convert.go + 
bintrieCommand, } if logTestCommand != nil { app.Commands = append(app.Commands, logTestCommand) diff --git a/consensus/beacon/consensus.go b/consensus/beacon/consensus.go index 25f4f9d2b2..c4a284d485 100644 --- a/consensus/beacon/consensus.go +++ b/consensus/beacon/consensus.go @@ -275,12 +275,22 @@ func (beacon *Beacon) verifyHeader(chain consensus.ChainHeaderReader, header, pa } } + // Verify the existence / non-existence of Amsterdam-specific header fields amsterdam := chain.Config().IsAmsterdam(header.Number, header.Time) - if amsterdam && header.SlotNumber == nil { - return errors.New("header is missing slotNumber") - } - if !amsterdam && header.SlotNumber != nil { - return fmt.Errorf("invalid slotNumber: have %d, expected nil", *header.SlotNumber) + if amsterdam { + if header.BlockAccessListHash == nil { + return errors.New("header is missing block access list hash") + } + if header.SlotNumber == nil { + return errors.New("header is missing slotNumber") + } + } else { + if header.BlockAccessListHash != nil { + return fmt.Errorf("invalid block access list hash: have %x, expected nil", *header.BlockAccessListHash) + } + if header.SlotNumber != nil { + return fmt.Errorf("invalid slotNumber: have %d, expected nil", *header.SlotNumber) + } } return nil } diff --git a/core/blockchain_reader.go b/core/blockchain_reader.go index f1b40d0d0c..3614702d1a 100644 --- a/core/blockchain_reader.go +++ b/core/blockchain_reader.go @@ -296,6 +296,14 @@ func (bc *BlockChain) GetReceiptsRLP(hash common.Hash) rlp.RawValue { return rawdb.ReadReceiptsRLP(bc.db, hash, number) } +func (bc *BlockChain) GetAccessListRLP(hash common.Hash) rlp.RawValue { + number, ok := rawdb.ReadHeaderNumber(bc.db, hash) + if !ok { + return nil + } + return rawdb.ReadAccessListRLP(bc.db, hash, number) +} + // GetUnclesInChain retrieves all the uncles from a given block backwards until // a specific distance is reached. 
func (bc *BlockChain) GetUnclesInChain(block *types.Block, length int) []*types.Header { @@ -468,7 +476,7 @@ func (bc *BlockChain) TxIndexProgress() (TxIndexProgress, error) { } // StateIndexProgress returns the historical state indexing progress. -func (bc *BlockChain) StateIndexProgress() (uint64, error) { +func (bc *BlockChain) StateIndexProgress() (uint64, uint64, error) { return bc.triedb.IndexProgress() } diff --git a/core/filtermaps/math_test.go b/core/filtermaps/math_test.go index a4c1609059..0cd0046a7d 100644 --- a/core/filtermaps/math_test.go +++ b/core/filtermaps/math_test.go @@ -41,9 +41,7 @@ func TestSingleMatch(t *testing.T) { t.Fatalf("Invalid length of matches (got %d, expected 1)", len(matches)) } if matches[0] != lvIndex { - if len(matches) != 1 { - t.Fatalf("Incorrect match returned (got %d, expected %d)", matches[0], lvIndex) - } + t.Fatalf("Incorrect match returned (got %d, expected %d)", matches[0], lvIndex) } } } diff --git a/core/rawdb/accessors_chain.go b/core/rawdb/accessors_chain.go index 6ae64fb2fd..0582e842c3 100644 --- a/core/rawdb/accessors_chain.go +++ b/core/rawdb/accessors_chain.go @@ -26,6 +26,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/consensus/misc/eip4844" "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/core/types/bal" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/log" @@ -605,6 +606,55 @@ func DeleteReceipts(db ethdb.KeyValueWriter, hash common.Hash, number uint64) { } } +// HasAccessList verifies the existence of a block access list for a block. +func HasAccessList(db ethdb.Reader, hash common.Hash, number uint64) bool { + has, _ := db.Has(accessListKey(number, hash)) + return has +} + +// ReadAccessListRLP retrieves the RLP-encoded block access list for a block from KV. 
+func ReadAccessListRLP(db ethdb.Reader, hash common.Hash, number uint64) rlp.RawValue { + data, _ := db.Get(accessListKey(number, hash)) + return data +} + +// ReadAccessList retrieves and decodes the block access list for a block. +func ReadAccessList(db ethdb.Reader, hash common.Hash, number uint64) *bal.BlockAccessList { + data := ReadAccessListRLP(db, hash, number) + if len(data) == 0 { + return nil + } + b := new(bal.BlockAccessList) + if err := rlp.DecodeBytes(data, b); err != nil { + log.Error("Invalid BAL RLP", "hash", hash, "err", err) + return nil + } + return b +} + +// WriteAccessList RLP-encodes and stores a block access list in the active KV store. +func WriteAccessList(db ethdb.KeyValueWriter, hash common.Hash, number uint64, b *bal.BlockAccessList) { + bytes, err := rlp.EncodeToBytes(b) + if err != nil { + log.Crit("Failed to encode BAL", "err", err) + } + WriteAccessListRLP(db, hash, number, bytes) +} + +// WriteAccessListRLP stores a pre-encoded block access list in the active KV store. +func WriteAccessListRLP(db ethdb.KeyValueWriter, hash common.Hash, number uint64, encoded rlp.RawValue) { + if err := db.Put(accessListKey(number, hash), encoded); err != nil { + log.Crit("Failed to store BAL", "err", err) + } +} + +// DeleteAccessList removes a block access list from the active KV store. +func DeleteAccessList(db ethdb.KeyValueWriter, hash common.Hash, number uint64) { + if err := db.Delete(accessListKey(number, hash)); err != nil { + log.Crit("Failed to delete BAL", "err", err) + } +} + // ReceiptLogs is a barebone version of ReceiptForStorage which only keeps // the list of logs. When decoding a stored receipt into this object we // avoid creating the bloom filter. 
@@ -659,13 +709,25 @@ func ReadBlock(db ethdb.Reader, hash common.Hash, number uint64) *types.Block { if body == nil { return nil } - return types.NewBlockWithHeader(header).WithBody(*body) + block := types.NewBlockWithHeader(header).WithBody(*body) + + // Best-effort assembly of the block access list from the database. + if header.BlockAccessListHash != nil { + al := ReadAccessList(db, hash, number) + block = block.WithAccessListUnsafe(al) + } + return block } // WriteBlock serializes a block into the database, header and body separately. func WriteBlock(db ethdb.KeyValueWriter, block *types.Block) { - WriteBody(db, block.Hash(), block.NumberU64(), block.Body()) + hash, number := block.Hash(), block.NumberU64() + WriteBody(db, hash, number, block.Body()) WriteHeader(db, block.Header()) + + if accessList := block.AccessList(); accessList != nil { + WriteAccessList(db, hash, number, accessList) + } } // WriteAncientBlocks writes entire block data into ancient store and returns the total written size. diff --git a/core/rawdb/accessors_chain_test.go b/core/rawdb/accessors_chain_test.go index 280fc21e8f..c35f56ee07 100644 --- a/core/rawdb/accessors_chain_test.go +++ b/core/rawdb/accessors_chain_test.go @@ -27,10 +27,12 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/core/types/bal" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto/keccak" "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/rlp" + "github.com/holiman/uint256" ) // Tests block header storage and retrieval operations. 
@@ -899,3 +901,78 @@ func TestHeadersRLPStorage(t *testing.T) { checkSequence(1, 1) // Only block 1 checkSequence(1, 2) // Genesis + block 1 } + +func makeTestBAL(t *testing.T) (rlp.RawValue, *bal.BlockAccessList) { + t.Helper() + + cb := bal.NewConstructionBlockAccessList() + addr := common.HexToAddress("0x1111111111111111111111111111111111111111") + cb.AccountRead(addr) + cb.StorageRead(addr, common.BytesToHash([]byte{0x01})) + cb.StorageWrite(0, addr, common.BytesToHash([]byte{0x02}), common.BytesToHash([]byte{0xaa})) + cb.BalanceChange(0, addr, uint256.NewInt(100)) + cb.NonceChange(addr, 0, 1) + + var buf bytes.Buffer + if err := cb.EncodeRLP(&buf); err != nil { + t.Fatalf("failed to encode BAL: %v", err) + } + encoded := buf.Bytes() + + var decoded bal.BlockAccessList + if err := rlp.DecodeBytes(encoded, &decoded); err != nil { + t.Fatalf("failed to decode BAL: %v", err) + } + return encoded, &decoded +} + +// TestBALStorage tests write/read/delete of BALs in the KV store. +func TestBALStorage(t *testing.T) { + db := NewMemoryDatabase() + + hash := common.BytesToHash([]byte{0x03, 0x14}) + number := uint64(42) + + // Check that no BAL exists in a new database. + if HasAccessList(db, hash, number) { + t.Fatal("BAL found in new database") + } + if b := ReadAccessList(db, hash, number); b != nil { + t.Fatalf("non existent BAL returned: %v", b) + } + + // Write a BAL and verify it can be read back. 
+ encoded, testBAL := makeTestBAL(t) + WriteAccessList(db, hash, number, testBAL) + + if !HasAccessList(db, hash, number) { + t.Fatal("HasAccessList returned false after write") + } + if blob := ReadAccessListRLP(db, hash, number); len(blob) == 0 { + t.Fatal("ReadAccessListRLP returned empty after write") + } + if b := ReadAccessList(db, hash, number); b == nil { + t.Fatal("ReadAccessList returned nil after write") + } else if b.Hash() != testBAL.Hash() { + t.Fatalf("BAL hash mismatch: got %x, want %x", b.Hash(), testBAL.Hash()) + } + + // Also test WriteAccessListRLP with pre-encoded data. + hash2 := common.BytesToHash([]byte{0x03, 0x15}) + WriteAccessListRLP(db, hash2, number, encoded) + if b := ReadAccessList(db, hash2, number); b == nil { + t.Fatal("ReadAccessList returned nil after WriteAccessListRLP") + } else if b.Hash() != testBAL.Hash() { + t.Fatalf("BAL hash mismatch after WriteAccessListRLP: got %x, want %x", b.Hash(), testBAL.Hash()) + } + + // Delete the BAL and verify it's gone. 
+ DeleteAccessList(db, hash, number) + + if HasAccessList(db, hash, number) { + t.Fatal("HasAccessList returned true after delete") + } + if b := ReadAccessList(db, hash, number); b != nil { + t.Fatalf("deleted BAL returned: %v", b) + } +} diff --git a/core/rawdb/database.go b/core/rawdb/database.go index 945fd9097d..39e1a64e5a 100644 --- a/core/rawdb/database.go +++ b/core/rawdb/database.go @@ -413,6 +413,7 @@ func InspectDatabase(db ethdb.Database, keyPrefix, keyStart []byte) error { tds stat numHashPairings stat hashNumPairings stat + blockAccessList stat legacyTries stat stateLookups stat accountTries stat @@ -484,6 +485,9 @@ func InspectDatabase(db ethdb.Database, keyPrefix, keyStart []byte) error { numHashPairings.add(size) case bytes.HasPrefix(key, headerNumberPrefix) && len(key) == (len(headerNumberPrefix)+common.HashLength): hashNumPairings.add(size) + case bytes.HasPrefix(key, accessListPrefix) && len(key) == len(accessListPrefix)+8+common.HashLength: + blockAccessList.add(size) + case IsLegacyTrieNode(key, it.Value()): legacyTries.add(size) case bytes.HasPrefix(key, stateIDPrefix) && len(key) == len(stateIDPrefix)+common.HashLength: @@ -625,6 +629,7 @@ func InspectDatabase(db ethdb.Database, keyPrefix, keyStart []byte) error { {"Key-Value store", "Difficulties (deprecated)", tds.sizeString(), tds.countString()}, {"Key-Value store", "Block number->hash", numHashPairings.sizeString(), numHashPairings.countString()}, {"Key-Value store", "Block hash->number", hashNumPairings.sizeString(), hashNumPairings.countString()}, + {"Key-Value store", "Block accessList", blockAccessList.sizeString(), blockAccessList.countString()}, {"Key-Value store", "Transaction index", txLookups.sizeString(), txLookups.countString()}, {"Key-Value store", "Log index filter-map rows", filterMapRows.sizeString(), filterMapRows.countString()}, {"Key-Value store", "Log index last-block-of-map", filterMapLastBlock.sizeString(), filterMapLastBlock.countString()}, diff --git 
a/core/rawdb/schema.go b/core/rawdb/schema.go index d9140c5fd6..54c76143b4 100644 --- a/core/rawdb/schema.go +++ b/core/rawdb/schema.go @@ -112,6 +112,7 @@ var ( blockBodyPrefix = []byte("b") // blockBodyPrefix + num (uint64 big endian) + hash -> block body blockReceiptsPrefix = []byte("r") // blockReceiptsPrefix + num (uint64 big endian) + hash -> block receipts + accessListPrefix = []byte("j") // accessListPrefix + num (uint64 big endian) + hash -> block access list txLookupPrefix = []byte("l") // txLookupPrefix + hash -> transaction/receipt lookup metadata bloomBitsPrefix = []byte("B") // bloomBitsPrefix + bit (uint16 big endian) + section (uint64 big endian) + hash -> bloom bits @@ -214,6 +215,11 @@ func blockReceiptsKey(number uint64, hash common.Hash) []byte { return append(append(blockReceiptsPrefix, encodeBlockNumber(number)...), hash.Bytes()...) } +// accessListKey = accessListPrefix + num (uint64 big endian) + hash +func accessListKey(number uint64, hash common.Hash) []byte { + return append(append(accessListPrefix, encodeBlockNumber(number)...), hash.Bytes()...) +} + // txLookupKey = txLookupPrefix + hash func txLookupKey(hash common.Hash) []byte { return append(txLookupPrefix, hash.Bytes()...) diff --git a/core/state/database.go b/core/state/database.go index 002ce57fbc..c603e3ad7a 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -39,6 +39,10 @@ type Database interface { // Reader returns a state reader associated with the specified state root. Reader(root common.Hash) (Reader, error) + // Iteratee returns a state iteratee associated with the specified state root, + // through which the account iterator and storage iterator can be created. + Iteratee(root common.Hash) (Iteratee, error) + // OpenTrie opens the main account trie. OpenTrie(root common.Hash) (Trie, error) @@ -48,9 +52,6 @@ type Database interface { // TrieDB returns the underlying trie database for managing trie nodes. 
TrieDB() *triedb.Database - // Snapshot returns the underlying state snapshot. - Snapshot() *snapshot.Tree - // Commit flushes all pending writes and finalizes the state transition, // committing the changes to the underlying storage. It returns an error // if the commit fails. @@ -310,6 +311,12 @@ func (db *CachingDB) Commit(update *stateUpdate) error { return db.triedb.Update(update.root, update.originRoot, update.blockNumber, update.nodes, update.stateSet()) } +// Iteratee returns a state iteratee associated with the specified state root, +// through which the account iterator and storage iterator can be created. +func (db *CachingDB) Iteratee(root common.Hash) (Iteratee, error) { + return newStateIteratee(!db.triedb.IsVerkle(), root, db.triedb, db.snap) +} + // mustCopyTrie returns a deep-copied trie. func mustCopyTrie(t Trie) Trie { switch t := t.(type) { diff --git a/core/state/database_history.go b/core/state/database_history.go index c25c4eae4b..0dbb8cc546 100644 --- a/core/state/database_history.go +++ b/core/state/database_history.go @@ -22,7 +22,6 @@ import ( "sync" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/state/snapshot" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/rlp" @@ -289,14 +288,15 @@ func (db *HistoricDB) TrieDB() *triedb.Database { return db.triedb } -// Snapshot returns the underlying state snapshot. -func (db *HistoricDB) Snapshot() *snapshot.Tree { - return nil -} - // Commit flushes all pending writes and finalizes the state transition, // committing the changes to the underlying storage. It returns an error // if the commit fails. func (db *HistoricDB) Commit(update *stateUpdate) error { return errors.New("not implemented") } + +// Iteratee returns a state iteratee associated with the specified state root, +// through which the account iterator and storage iterator can be created. 
+func (db *HistoricDB) Iteratee(root common.Hash) (Iteratee, error) { + return nil, errors.New("not implemented") +} diff --git a/core/state/database_iterator.go b/core/state/database_iterator.go new file mode 100644 index 0000000000..8fad66a1e8 --- /dev/null +++ b/core/state/database_iterator.go @@ -0,0 +1,435 @@ +// Copyright 2025 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package state + +import ( + "errors" + "fmt" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/state/snapshot" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/trie" + "github.com/ethereum/go-ethereum/triedb" +) + +// Iterator is an iterator to step over all the accounts or the specific +// storage in the specific state. +type Iterator interface { + // Next steps the iterator forward one element. It returns false if the iterator + // is exhausted or if an error occurs. Any error encountered is retained and + // can be retrieved via Error(). + Next() bool + + // Error returns any failure that occurred during iteration, which might have + // caused a premature iteration exit. + Error() error + + // Hash returns the hash of the account or storage slot the iterator is + // currently at. 
+ Hash() common.Hash + + // Release releases associated resources. Release should always succeed and + // can be called multiple times without causing error. + Release() +} + +// AccountIterator is an iterator to step over all the accounts in the +// specific state. +type AccountIterator interface { + Iterator + + // Address returns the raw account address the iterator is currently at. + // An error will be returned if the preimage is not available. + Address() (common.Address, error) + + // Account returns the RLP encoded account the iterator is currently at. + // An error will be retained if the iterator becomes invalid. + Account() []byte +} + +// StorageIterator is an iterator to step over the specific storage in the +// specific state. +type StorageIterator interface { + Iterator + + // Key returns the raw storage slot key the iterator is currently at. + // An error will be returned if the preimage is not available. + Key() (common.Hash, error) + + // Slot returns the storage slot the iterator is currently at. An error will + // be retained if the iterator becomes invalid. + Slot() []byte +} + +// Iteratee wraps the NewIterator methods for traversing the accounts and +// storages of the specific state. +type Iteratee interface { + // NewAccountIterator creates an account iterator for the state specified by + // the given root. It begins at a specified starting position, corresponding + // to a particular initial key (or the next key if the specified one does + // not exist). + // + // The starting position here refers to the hash of the account address. + NewAccountIterator(start common.Hash) (AccountIterator, error) + + // NewStorageIterator creates a storage iterator for the state specified by + // the address hash. It begins at a specified starting position, corresponding + // to a particular initial key (or the next key if the specified one does + // not exist). + // + // The starting position here refers to the hash of the slot key. 
+ NewStorageIterator(addressHash common.Hash, start common.Hash) (StorageIterator, error) +} + +// PreimageReader wraps the function Preimage for accessing the preimage of +// a given hash. +type PreimageReader interface { + // Preimage returns the preimage of associated hash. + Preimage(hash common.Hash) []byte +} + +// flatAccountIterator is a wrapper around the underlying flat state iterator. +// Before returning data from the iterator, it performs an additional conversion +// to bridge the slim encoding with the full encoding format. +type flatAccountIterator struct { + err error + it snapshot.AccountIterator + preimage PreimageReader +} + +// newFlatAccountIterator constructs the account iterator with the provided +// flat state iterator. +func newFlatAccountIterator(it snapshot.AccountIterator, preimage PreimageReader) *flatAccountIterator { + return &flatAccountIterator{it: it, preimage: preimage} +} + +// Next steps the iterator forward one element. It returns false if the iterator +// is exhausted or if an error occurs. Any error encountered is retained and +// can be retrieved via Error(). +func (ai *flatAccountIterator) Next() bool { + if ai.err != nil { + return false + } + return ai.it.Next() +} + +// Error returns any failure that occurred during iteration, which might have +// caused a premature iteration exit. +func (ai *flatAccountIterator) Error() error { + if ai.err != nil { + return ai.err + } + return ai.it.Error() +} + +// Hash returns the hash of the account or storage slot the iterator is +// currently at. +func (ai *flatAccountIterator) Hash() common.Hash { + return ai.it.Hash() +} + +// Release releases associated resources. Release should always succeed and +// can be called multiple times without causing error. +func (ai *flatAccountIterator) Release() { + ai.it.Release() +} + +// Address returns the raw account address the iterator is currently at. +// An error will be returned if the preimage is not available. 
+func (ai *flatAccountIterator) Address() (common.Address, error) { + if ai.preimage == nil { + return common.Address{}, errors.New("account address is not available") + } + preimage := ai.preimage.Preimage(ai.Hash()) + if preimage == nil { + return common.Address{}, errors.New("account address is not available") + } + return common.BytesToAddress(preimage), nil +} + +// Account returns the account data the iterator is currently at. The account +// data is encoded as slim format from the underlying iterator, the conversion +// is required. +func (ai *flatAccountIterator) Account() []byte { + data, err := types.FullAccountRLP(ai.it.Account()) + if err != nil { + ai.err = err + return nil + } + return data +} + +// flatStorageIterator is a wrapper around the underlying flat state iterator. +type flatStorageIterator struct { + it snapshot.StorageIterator + preimage PreimageReader +} + +// newFlatStorageIterator constructs the storage iterator with the provided +// flat state iterator. +func newFlatStorageIterator(it snapshot.StorageIterator, preimage PreimageReader) *flatStorageIterator { + return &flatStorageIterator{it: it, preimage: preimage} +} + +// Next steps the iterator forward one element. It returns false if the iterator +// is exhausted or if an error occurs. Any error encountered is retained and +// can be retrieved via Error(). +func (si *flatStorageIterator) Next() bool { + return si.it.Next() +} + +// Error returns any failure that occurred during iteration, which might have +// caused a premature iteration exit. +func (si *flatStorageIterator) Error() error { + return si.it.Error() +} + +// Hash returns the hash of the account or storage slot the iterator is +// currently at. +func (si *flatStorageIterator) Hash() common.Hash { + return si.it.Hash() +} + +// Release releases associated resources. Release should always succeed and +// can be called multiple times without causing error. 
+func (si *flatStorageIterator) Release() { + si.it.Release() +} + +// Key returns the raw storage slot key the iterator is currently at. +// An error will be returned if the preimage is not available. +func (si *flatStorageIterator) Key() (common.Hash, error) { + if si.preimage == nil { + return common.Hash{}, errors.New("slot key is not available") + } + preimage := si.preimage.Preimage(si.Hash()) + if preimage == nil { + return common.Hash{}, errors.New("slot key is not available") + } + return common.BytesToHash(preimage), nil +} + +// Slot returns the storage slot data the iterator is currently at. +func (si *flatStorageIterator) Slot() []byte { + return si.it.Slot() +} + +// merkleIterator implements the Iterator interface, providing functions to traverse +// the accounts or storages with the manner of Merkle-Patricia-Trie. +type merkleIterator struct { + tr Trie + it *trie.Iterator + account bool +} + +// newMerkleTrieIterator constructs the iterator with the given trie and starting position. +func newMerkleTrieIterator(tr Trie, start common.Hash, account bool) (*merkleIterator, error) { + it, err := tr.NodeIterator(start.Bytes()) + if err != nil { + return nil, err + } + return &merkleIterator{ + tr: tr, + it: trie.NewIterator(it), + account: account, + }, nil +} + +// Next steps the iterator forward one element. It returns false if the iterator +// is exhausted or if an error occurs. Any error encountered is retained and +// can be retrieved via Error(). +func (ti *merkleIterator) Next() bool { + return ti.it.Next() +} + +// Error returns any failure that occurred during iteration, which might have +// caused a premature iteration exit. +func (ti *merkleIterator) Error() error { + return ti.it.Err +} + +// Hash returns the hash of the account or storage slot the iterator is +// currently at. +func (ti *merkleIterator) Hash() common.Hash { + return common.BytesToHash(ti.it.Key) +} + +// Release releases associated resources. 
Release should always succeed and +// can be called multiple times without causing error. +func (ti *merkleIterator) Release() {} + +// Address returns the raw account address the iterator is currently at. +// An error will be returned if the preimage is not available. +func (ti *merkleIterator) Address() (common.Address, error) { + if !ti.account { + return common.Address{}, errors.New("account address is not available") + } + preimage := ti.tr.GetKey(ti.it.Key) + if preimage == nil { + return common.Address{}, errors.New("account address is not available") + } + return common.BytesToAddress(preimage), nil +} + +// Account returns the account data the iterator is currently at. +func (ti *merkleIterator) Account() []byte { + if !ti.account { + return nil + } + return ti.it.Value +} + +// Key returns the raw storage slot key the iterator is currently at. +// An error will be returned if the preimage is not available. +func (ti *merkleIterator) Key() (common.Hash, error) { + if ti.account { + return common.Hash{}, errors.New("slot key is not available") + } + preimage := ti.tr.GetKey(ti.it.Key) + if preimage == nil { + return common.Hash{}, errors.New("slot key is not available") + } + return common.BytesToHash(preimage), nil +} + +// Slot returns the storage slot the iterator is currently at. +func (ti *merkleIterator) Slot() []byte { + if ti.account { + return nil + } + return ti.it.Value +} + +// stateIteratee implements Iteratee interface, providing the state traversal +// functionalities of a specific state. +type stateIteratee struct { + merkle bool + root common.Hash + triedb *triedb.Database + snap *snapshot.Tree +} + +func newStateIteratee(merkle bool, root common.Hash, triedb *triedb.Database, snap *snapshot.Tree) (*stateIteratee, error) { + return &stateIteratee{ + merkle: merkle, + root: root, + triedb: triedb, + snap: snap, + }, nil +} + +// NewAccountIterator creates an account iterator for the state specified by +// the given root. 
It begins at a specified starting position, corresponding +// to a particular initial key (or the next key if the specified one does +// not exist). +// +// The starting position here refers to the hash of the account address. +func (si *stateIteratee) NewAccountIterator(start common.Hash) (AccountIterator, error) { + // If the external snapshot is available (hash scheme), try to initialize + // the account iterator from there first. + if si.snap != nil { + it, err := si.snap.AccountIterator(si.root, start) + if err == nil { + return newFlatAccountIterator(it, si.triedb), nil + } + } + // If the external snapshot is not available, try to initialize the + // account iterator from the trie database (path scheme) + it, err := si.triedb.AccountIterator(si.root, start) + if err == nil { + return newFlatAccountIterator(it, si.triedb), nil + } + if !si.merkle { + return nil, fmt.Errorf("state %x is not available for account traversal", si.root) + } + // The snapshot is not usable so far, construct the account iterator from + // the trie as the fallback. It's not as efficient as the flat state iterator. + tr, err := trie.NewStateTrie(trie.StateTrieID(si.root), si.triedb) + if err != nil { + return nil, err + } + return newMerkleTrieIterator(tr, start, true) +} + +// NewStorageIterator creates a storage iterator for the state specified by +// the address hash. It begins at a specified starting position, corresponding +// to a particular initial key (or the next key if the specified one does not exist). +// +// The starting position here refers to the hash of the slot key. +func (si *stateIteratee) NewStorageIterator(addressHash common.Hash, start common.Hash) (StorageIterator, error) { + // If the external snapshot is available (hash scheme), try to initialize + // the storage iterator from there first. 
+ if si.snap != nil { + it, err := si.snap.StorageIterator(si.root, addressHash, start) + if err == nil { + return newFlatStorageIterator(it, si.triedb), nil + } + } + // If the external snapshot is not available, try to initialize the + // storage iterator from the trie database (path scheme) + it, err := si.triedb.StorageIterator(si.root, addressHash, start) + if err == nil { + return newFlatStorageIterator(it, si.triedb), nil + } + if !si.merkle { + return nil, fmt.Errorf("state %x is not available for storage traversal", si.root) + } + // The snapshot is not usable so far, construct the storage iterator from + // the trie as the fallback. It's not as efficient as the flat state iterator. + tr, err := trie.NewStateTrie(trie.StateTrieID(si.root), si.triedb) + if err != nil { + return nil, err + } + acct, err := tr.GetAccountByHash(addressHash) + if err != nil { + return nil, err + } + if acct == nil || acct.Root == types.EmptyRootHash { + return &exhaustedIterator{}, nil + } + storageTr, err := trie.NewStateTrie(trie.StorageTrieID(si.root, addressHash, acct.Root), si.triedb) + if err != nil { + return nil, err + } + return newMerkleTrieIterator(storageTr, start, false) +} + +type exhaustedIterator struct{} + +func (e exhaustedIterator) Next() bool { + return false +} + +func (e exhaustedIterator) Error() error { + return nil +} + +func (e exhaustedIterator) Hash() common.Hash { + return common.Hash{} +} + +func (e exhaustedIterator) Release() { +} + +func (e exhaustedIterator) Key() (common.Hash, error) { + return common.Hash{}, nil +} + +func (e exhaustedIterator) Slot() []byte { + return nil +} diff --git a/core/state/database_iterator_test.go b/core/state/database_iterator_test.go new file mode 100644 index 0000000000..87819e5526 --- /dev/null +++ b/core/state/database_iterator_test.go @@ -0,0 +1,262 @@ +// Copyright 2026 The go-ethereum Authors +// This file is part of the go-ethereum library. 
+// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package state + +import ( + "bytes" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/trie" +) + +// TestExhaustedIterator verifies the exhaustedIterator sentinel: Next is false, +// Error is nil, Hash/Key are zero, Slot is nil, and double Release is safe. +func TestExhaustedIterator(t *testing.T) { + var it exhaustedIterator + + if it.Next() { + t.Fatal("Next() returned true") + } + if err := it.Error(); err != nil { + t.Fatalf("Error() = %v, want nil", err) + } + if hash := it.Hash(); hash != (common.Hash{}) { + t.Fatalf("Hash() = %x, want zero", hash) + } + if key, err := it.Key(); key != (common.Hash{}) || err != nil { + t.Fatalf("Key() = %x, %v; want zero, nil", key, err) + } + if slot := it.Slot(); slot != nil { + t.Fatalf("Slot() = %x, want nil", slot) + } + it.Release() + it.Release() +} + +// TestAccountIterator tests the account iterator: correct count, ascending +// hash order, valid full-format RLP, data integrity, address preimage +// resolution, and seek behavior. 
+func TestAccountIterator(t *testing.T) { + testAccountIterator(t, rawdb.HashScheme) + testAccountIterator(t, rawdb.PathScheme) +} + +func testAccountIterator(t *testing.T, scheme string) { + _, sdb, ndb, root, accounts := makeTestState(scheme) + ndb.Commit(root, false) + + iteratee, err := sdb.Iteratee(root) + if err != nil { + t.Fatalf("(%s) failed to create iteratee: %v", scheme, err) + } + // Build lookups from address hash. + addrByHash := make(map[common.Hash]*testAccount) + for _, acc := range accounts { + addrByHash[crypto.Keccak256Hash(acc.address.Bytes())] = acc + } + + // --- Full iteration: count, ordering, RLP validity, data integrity, address resolution --- + acctIt, err := iteratee.NewAccountIterator(common.Hash{}) + if err != nil { + t.Fatalf("(%s) failed to create account iterator: %v", scheme, err) + } + var ( + hashes []common.Hash + prevHash common.Hash + ) + for acctIt.Next() { + hash := acctIt.Hash() + if hash == (common.Hash{}) { + t.Fatalf("(%s) zero hash at position %d", scheme, len(hashes)) + } + if len(hashes) > 0 && bytes.Compare(prevHash.Bytes(), hash.Bytes()) >= 0 { + t.Fatalf("(%s) hashes not ascending: %x >= %x", scheme, prevHash, hash) + } + prevHash = hash + hashes = append(hashes, hash) + + // Decode and verify account data. + blob := acctIt.Account() + if blob == nil { + t.Fatalf("(%s) nil account at %x", scheme, hash) + } + var decoded types.StateAccount + if err := rlp.DecodeBytes(blob, &decoded); err != nil { + t.Fatalf("(%s) bad RLP at %x: %v", scheme, hash, err) + } + acc := addrByHash[hash] + if decoded.Nonce != acc.nonce { + t.Fatalf("(%s) nonce %x: got %d, want %d", scheme, hash, decoded.Nonce, acc.nonce) + } + if decoded.Balance.Cmp(acc.balance) != 0 { + t.Fatalf("(%s) balance %x: got %v, want %v", scheme, hash, decoded.Balance, acc.balance) + } + // Verify address preimage resolution. 
+ addr, err := acctIt.Address() + if err != nil { + t.Fatalf("(%s) failed to address: %v", scheme, err) + } + if addr != acc.address { + t.Fatalf("(%s) Address() = %x, want %x", scheme, addr, acc.address) + } + } + acctIt.Release() + + if err := acctIt.Error(); err != nil { + t.Fatalf("(%s) iteration error: %v", scheme, err) + } + if len(hashes) != len(accounts) { + t.Fatalf("(%s) iterated %d accounts, want %d", scheme, len(hashes), len(accounts)) + } + + // --- Seek: starting from midpoint should skip earlier entries --- + mid := hashes[len(hashes)/2] + seekIt, err := iteratee.NewAccountIterator(mid) + if err != nil { + t.Fatalf("(%s) failed to create seeked iterator: %v", scheme, err) + } + seekCount := 0 + for seekIt.Next() { + if bytes.Compare(seekIt.Hash().Bytes(), mid.Bytes()) < 0 { + t.Fatalf("(%s) seeked iterator returned hash before start", scheme) + } + seekCount++ + } + seekIt.Release() + + if seekCount != len(hashes)/2 { + t.Fatalf("(%s) unexpected seeked count, %d != %d", scheme, seekCount, len(hashes)/2) + } +} + +// TestStorageIterator tests the storage iterator: correct slot counts against +// the trie, ascending hash order, non-nil slot data, key preimage resolution, +// seek behavior, and empty-storage accounts. 
+func TestStorageIterator(t *testing.T) { + testStorageIterator(t, rawdb.HashScheme) + testStorageIterator(t, rawdb.PathScheme) +} + +func testStorageIterator(t *testing.T, scheme string) { + _, sdb, ndb, root, accounts := makeTestState(scheme) + ndb.Commit(root, false) + + iteratee, err := sdb.Iteratee(root) + if err != nil { + t.Fatalf("(%s) failed to create iteratee: %v", scheme, err) + } + + // --- Slot count and ordering for every account --- + var withStorage common.Hash // remember an account that has storage for seek test + for _, acc := range accounts { + addrHash := crypto.Keccak256Hash(acc.address.Bytes()) + expected := countStorageSlots(t, scheme, sdb, root, addrHash) + + storageIt, err := iteratee.NewStorageIterator(addrHash, common.Hash{}) + if err != nil { + t.Fatalf("(%s) failed to create storage iterator for %x: %v", scheme, acc.address, err) + } + count := 0 + var prevHash common.Hash + for storageIt.Next() { + hash := storageIt.Hash() + if count > 0 && bytes.Compare(prevHash.Bytes(), hash.Bytes()) >= 0 { + t.Fatalf("(%s) storage hashes not ascending for %x", scheme, acc.address) + } + prevHash = hash + if storageIt.Slot() == nil { + t.Fatalf("(%s) nil slot at %x", scheme, hash) + } + // Check key preimage resolution on first slot. 
+ if _, err := storageIt.Key(); err != nil { + t.Fatalf("(%s) Key() failed to resolve", scheme) + } + count++ + } + if err := storageIt.Error(); err != nil { + t.Fatalf("(%s) storage iteration error for %x: %v", scheme, acc.address, err) + } + storageIt.Release() + + if count != expected { + t.Fatalf("(%s) account %x: %d slots, want %d", scheme, acc.address, count, expected) + } + if count > 0 { + withStorage = addrHash + } + } + + // --- Seek: starting from second slot should skip the first --- + if withStorage == (common.Hash{}) { + t.Fatalf("(%s) no account with storage found", scheme) + } + fullIt, err := iteratee.NewStorageIterator(withStorage, common.Hash{}) + if err != nil { + t.Fatalf("(%s) failed to create full storage iterator: %v", scheme, err) + } + var slotHashes []common.Hash + for fullIt.Next() { + slotHashes = append(slotHashes, fullIt.Hash()) + } + fullIt.Release() + + seekIt, err := iteratee.NewStorageIterator(withStorage, slotHashes[1]) + if err != nil { + t.Fatalf("(%s) failed to create seeked storage iterator: %v", scheme, err) + } + seekCount := 0 + for seekIt.Next() { + if bytes.Compare(seekIt.Hash().Bytes(), slotHashes[1].Bytes()) < 0 { + t.Fatalf("(%s) seeked storage iterator returned hash before start", scheme) + } + seekCount++ + } + seekIt.Release() + + if seekCount != len(slotHashes)-1 { + t.Fatalf("(%s) unexpected seeked storage count %d != %d", scheme, seekCount, len(slotHashes)-1) + } +} + +// countStorageSlots counts storage slots for an account by opening the +// storage trie directly. 
+func countStorageSlots(t *testing.T, scheme string, sdb Database, root common.Hash, addrHash common.Hash) int { + t.Helper() + accTrie, err := trie.NewStateTrie(trie.StateTrieID(root), sdb.TrieDB()) + if err != nil { + t.Fatalf("(%s) failed to open account trie: %v", scheme, err) + } + acct, err := accTrie.GetAccountByHash(addrHash) + if err != nil || acct == nil || acct.Root == types.EmptyRootHash { + return 0 + } + storageTrie, err := trie.NewStateTrie(trie.StorageTrieID(root, addrHash, acct.Root), sdb.TrieDB()) + if err != nil { + t.Fatalf("(%s) failed to open storage trie for %x: %v", scheme, addrHash, err) + } + it := trie.NewIterator(storageTrie.MustNodeIterator(nil)) + count := 0 + for it.Next() { + count++ + } + return count +} diff --git a/core/state/dump.go b/core/state/dump.go index 829d106ed3..71138143d9 100644 --- a/core/state/dump.go +++ b/core/state/dump.go @@ -27,7 +27,6 @@ import ( "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/rlp" - "github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/trie/bintrie" ) @@ -45,6 +44,7 @@ type DumpConfig struct { type DumpCollector interface { // OnRoot is called with the state root OnRoot(common.Hash) + // OnAccount is called once for each account in the trie OnAccount(*common.Address, DumpAccount) } @@ -65,9 +65,10 @@ type DumpAccount struct { type Dump struct { Root string `json:"root"` Accounts map[string]DumpAccount `json:"accounts"` + // Next can be set to represent that this dump is only partial, and Next // is where an iterator should be positioned in order to continue the dump. 
- Next []byte `json:"next,omitempty"` // nil if no more accounts + Next hexutil.Bytes `json:"next,omitempty"` // nil if no more accounts } // OnRoot implements DumpCollector interface @@ -114,9 +115,6 @@ func (d iterativeDump) OnRoot(root common.Hash) { // DumpToCollector iterates the state according to the given options and inserts // the items into a collector for aggregation or serialization. -// -// The state iterator is still trie-based and can be converted to snapshot-based -// once the state snapshot is fully integrated into database. TODO(rjl493456442). func (s *StateDB) DumpToCollector(c DumpCollector, conf *DumpConfig) (nextKey []byte) { // Sanitize the input to allow nil configs if conf == nil { @@ -131,20 +129,23 @@ func (s *StateDB) DumpToCollector(c DumpCollector, conf *DumpConfig) (nextKey [] log.Info("Trie dumping started", "root", s.originalRoot) c.OnRoot(s.originalRoot) - tr, err := s.db.OpenTrie(s.originalRoot) + iteratee, err := s.db.Iteratee(s.originalRoot) if err != nil { return nil } - trieIt, err := tr.NodeIterator(conf.Start) + var startHash common.Hash + if conf.Start != nil { + startHash = common.BytesToHash(conf.Start) + } + acctIt, err := iteratee.NewAccountIterator(startHash) if err != nil { - log.Error("Trie dumping error", "err", err) return nil } - it := trie.NewIterator(trieIt) + defer acctIt.Release() - for it.Next() { + for acctIt.Next() { var data types.StateAccount - if err := rlp.DecodeBytes(it.Value, &data); err != nil { + if err := rlp.DecodeBytes(acctIt.Account(), &data); err != nil { panic(err) } var ( @@ -153,63 +154,55 @@ func (s *StateDB) DumpToCollector(c DumpCollector, conf *DumpConfig) (nextKey [] Nonce: data.Nonce, Root: data.Root[:], CodeHash: data.CodeHash, - AddressHash: it.Key, + AddressHash: acctIt.Hash().Bytes(), } - address *common.Address - addr common.Address - addrBytes = tr.GetKey(it.Key) + address *common.Address ) - if addrBytes == nil { + addrBytes, err := acctIt.Address() + if err != nil { 
missingPreimages++ if conf.OnlyWithAddresses { continue } } else { - addr = common.BytesToAddress(addrBytes) - address = &addr + address = &addrBytes account.Address = address } - obj := newObject(s, addr, &data) + obj := newObject(s, addrBytes, &data) if !conf.SkipCode { account.Code = obj.Code() } if !conf.SkipStorage { account.Storage = make(map[common.Hash]string) - storageTr, err := s.db.OpenStorageTrie(s.originalRoot, addr, obj.Root(), tr) + storageIt, err := iteratee.NewStorageIterator(acctIt.Hash(), common.Hash{}) if err != nil { log.Error("Failed to load storage trie", "err", err) continue } - trieIt, err := storageTr.NodeIterator(nil) - if err != nil { - log.Error("Failed to create trie iterator", "err", err) - continue - } - storageIt := trie.NewIterator(trieIt) for storageIt.Next() { - _, content, _, err := rlp.Split(storageIt.Value) + _, content, _, err := rlp.Split(storageIt.Slot()) if err != nil { log.Error("Failed to decode the value returned by iterator", "error", err) continue } - key := storageTr.GetKey(storageIt.Key) - if key == nil { + key, err := storageIt.Key() + if err != nil { continue } - account.Storage[common.BytesToHash(key)] = common.Bytes2Hex(content) + account.Storage[key] = common.Bytes2Hex(content) } + storageIt.Release() } c.OnAccount(address, account) accounts++ if time.Since(logged) > 8*time.Second { - log.Info("Trie dumping in progress", "at", common.Bytes2Hex(it.Key), "accounts", accounts, - "elapsed", common.PrettyDuration(time.Since(start))) + log.Info("Trie dumping in progress", "at", acctIt.Hash().Hex(), "accounts", accounts, "elapsed", common.PrettyDuration(time.Since(start))) logged = time.Now() } if conf.Max > 0 && accounts >= conf.Max { - if it.Next() { - nextKey = it.Key + if acctIt.Next() { + nextKey = acctIt.Hash().Bytes() } break } @@ -217,9 +210,7 @@ func (s *StateDB) DumpToCollector(c DumpCollector, conf *DumpConfig) (nextKey [] if missingPreimages > 0 { log.Warn("Dump incomplete due to missing preimages", 
"missing", missingPreimages) } - log.Info("Trie dumping complete", "accounts", accounts, - "elapsed", common.PrettyDuration(time.Since(start))) - + log.Info("Trie dumping complete", "accounts", accounts, "elapsed", common.PrettyDuration(time.Since(start))) return nextKey } diff --git a/core/state/statedb.go b/core/state/statedb.go index 93dd7d6488..8b09ea89f6 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -28,7 +28,6 @@ import ( "time" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/state/snapshot" "github.com/ethereum/go-ethereum/core/stateless" "github.com/ethereum/go-ethereum/core/tracing" "github.com/ethereum/go-ethereum/core/types" @@ -746,7 +745,7 @@ type removedAccountWithBalance struct { balance *uint256.Int } -// EmitLogsForBurnAccounts emits the eth burn logs for accounts scheduled for +// LogsForBurnAccounts returns the eth burn logs for accounts scheduled for // removal which still have positive balance. The purpose of this function is // to handle a corner case of EIP-7708 where a self-destructed account might // still receive funds between sending/burning its previous balance and actual @@ -756,7 +755,7 @@ type removedAccountWithBalance struct { // // This function should only be invoked at the transaction boundary, specifically // before the Finalise. 
-func (s *StateDB) EmitLogsForBurnAccounts() { +func (s *StateDB) LogsForBurnAccounts() []*types.Log { var list []removedAccountWithBalance for addr := range s.journal.dirties { if obj, exist := s.stateObjects[addr]; exist && obj.selfDestructed && !obj.Balance().IsZero() { @@ -766,14 +765,17 @@ func (s *StateDB) EmitLogsForBurnAccounts() { }) } } - if list != nil { - sort.Slice(list, func(i, j int) bool { - return list[i].address.Cmp(list[j].address) < 0 - }) + if list == nil { + return nil } - for _, acct := range list { - s.AddLog(types.EthBurnLog(acct.address, acct.balance)) + sort.Slice(list, func(i, j int) bool { + return list[i].address.Cmp(list[j].address) < 0 + }) + logs := make([]*types.Log, len(list)) + for i, acct := range list { + logs[i] = types.EthBurnLog(acct.address, acct.balance) } + return logs } // Finalise finalises the state by removing the destructed objects and clears @@ -879,10 +881,12 @@ func (s *StateDB) IntermediateRoot(deleteEmptyObjects bool) common.Hash { if err := s.trie.UpdateStorage(addr, key[:], common.TrimLeftZeroes(value[:])); err != nil { s.setError(err) } + s.StorageUpdated.Add(1) } else { if err := s.trie.DeleteStorage(addr, key[:]); err != nil { s.setError(err) } + s.StorageDeleted.Add(1) } } } @@ -1037,31 +1041,32 @@ func (s *StateDB) clearJournalAndRefund() { s.refund = 0 } -// fastDeleteStorage is the function that efficiently deletes the storage trie -// of a specific account. It leverages the associated state snapshot for fast -// storage iteration and constructs trie node deletion markers by creating -// stack trie with iterated slots. 
-func (s *StateDB) fastDeleteStorage(snaps *snapshot.Tree, addrHash common.Hash, root common.Hash) (map[common.Hash][]byte, map[common.Hash][]byte, *trienode.NodeSet, error) { - iter, err := snaps.StorageIterator(s.originalRoot, addrHash, common.Hash{}) - if err != nil { - return nil, nil, nil, err - } - defer iter.Release() - +// deleteStorage is designed to delete the storage trie of a designated account. +func (s *StateDB) deleteStorage(addrHash common.Hash, root common.Hash) (map[common.Hash][]byte, map[common.Hash][]byte, *trienode.NodeSet, error) { var ( nodes = trienode.NewNodeSet(addrHash) // the set for trie node mutations (value is nil) storages = make(map[common.Hash][]byte) // the set for storage mutations (value is nil) storageOrigins = make(map[common.Hash][]byte) // the set for tracking the original value of slot ) + iteratee, err := s.db.Iteratee(s.originalRoot) + if err != nil { + return nil, nil, nil, err + } + it, err := iteratee.NewStorageIterator(addrHash, common.Hash{}) + if err != nil { + return nil, nil, nil, err + } + defer it.Release() + stack := trie.NewStackTrie(func(path []byte, hash common.Hash, blob []byte) { nodes.AddNode(path, trienode.NewDeletedWithPrev(blob)) }) - for iter.Next() { - slot := common.CopyBytes(iter.Slot()) - if err := iter.Error(); err != nil { // error might occur after Slot function + for it.Next() { + slot := common.CopyBytes(it.Slot()) + if err := it.Error(); err != nil { // error might occur after Slot function return nil, nil, nil, err } - key := iter.Hash() + key := it.Hash() storages[key] = nil storageOrigins[key] = slot @@ -1069,7 +1074,7 @@ func (s *StateDB) fastDeleteStorage(snaps *snapshot.Tree, addrHash common.Hash, return nil, nil, nil, err } } - if err := iter.Error(); err != nil { // error might occur during iteration + if err := it.Error(); err != nil { // error might occur during iteration return nil, nil, nil, err } if stack.Hash() != root { @@ -1078,68 +1083,6 @@ func (s *StateDB) 
fastDeleteStorage(snaps *snapshot.Tree, addrHash common.Hash, return storages, storageOrigins, nodes, nil } -// slowDeleteStorage serves as a less-efficient alternative to "fastDeleteStorage," -// employed when the associated state snapshot is not available. It iterates the -// storage slots along with all internal trie nodes via trie directly. -func (s *StateDB) slowDeleteStorage(addr common.Address, addrHash common.Hash, root common.Hash) (map[common.Hash][]byte, map[common.Hash][]byte, *trienode.NodeSet, error) { - tr, err := s.db.OpenStorageTrie(s.originalRoot, addr, root, s.trie) - if err != nil { - return nil, nil, nil, fmt.Errorf("failed to open storage trie, err: %w", err) - } - it, err := tr.NodeIterator(nil) - if err != nil { - return nil, nil, nil, fmt.Errorf("failed to open storage iterator, err: %w", err) - } - var ( - nodes = trienode.NewNodeSet(addrHash) // the set for trie node mutations (value is nil) - storages = make(map[common.Hash][]byte) // the set for storage mutations (value is nil) - storageOrigins = make(map[common.Hash][]byte) // the set for tracking the original value of slot - ) - for it.Next(true) { - if it.Leaf() { - key := common.BytesToHash(it.LeafKey()) - storages[key] = nil - storageOrigins[key] = common.CopyBytes(it.LeafBlob()) - continue - } - if it.Hash() == (common.Hash{}) { - continue - } - nodes.AddNode(it.Path(), trienode.NewDeletedWithPrev(it.NodeBlob())) - } - if err := it.Error(); err != nil { - return nil, nil, nil, err - } - return storages, storageOrigins, nodes, nil -} - -// deleteStorage is designed to delete the storage trie of a designated account. -// The function will make an attempt to utilize an efficient strategy if the -// associated state snapshot is reachable; otherwise, it will resort to a less -// efficient approach. 
-func (s *StateDB) deleteStorage(addr common.Address, addrHash common.Hash, root common.Hash) (map[common.Hash][]byte, map[common.Hash][]byte, *trienode.NodeSet, error) { - var ( - err error - nodes *trienode.NodeSet // the set for trie node mutations (value is nil) - storages map[common.Hash][]byte // the set for storage mutations (value is nil) - storageOrigins map[common.Hash][]byte // the set for tracking the original value of slot - ) - // The fast approach can be failed if the snapshot is not fully - // generated, or it's internally corrupted. Fallback to the slow - // one just in case. - snaps := s.db.Snapshot() - if snaps != nil { - storages, storageOrigins, nodes, err = s.fastDeleteStorage(snaps, addrHash, root) - } - if snaps == nil || err != nil { - storages, storageOrigins, nodes, err = s.slowDeleteStorage(addr, addrHash, root) - } - if err != nil { - return nil, nil, nil, err - } - return storages, storageOrigins, nodes, nil -} - // handleDestruction processes all destruction markers and deletes the account // and associated storage slots if necessary. There are four potential scenarios // as following: @@ -1190,7 +1133,7 @@ func (s *StateDB) handleDestruction(noStorageWiping bool) (map[common.Hash]*acco return nil, nil, fmt.Errorf("unexpected storage wiping, %x", addr) } // Remove storage slots belonging to the account. 
- storages, storagesOrigin, set, err := s.deleteStorage(addr, addrHash, prev.Root) + storages, storagesOrigin, set, err := s.deleteStorage(addrHash, prev.Root) if err != nil { return nil, nil, fmt.Errorf("failed to delete storage, err: %w", err) } diff --git a/core/state/statedb_hooked.go b/core/state/statedb_hooked.go index 8c217fba48..52cf98d19b 100644 --- a/core/state/statedb_hooked.go +++ b/core/state/statedb_hooked.go @@ -229,8 +229,8 @@ func (s *hookedStateDB) AddLog(log *types.Log) { } } -func (s *hookedStateDB) EmitLogsForBurnAccounts() { - s.inner.EmitLogsForBurnAccounts() +func (s *hookedStateDB) LogsForBurnAccounts() []*types.Log { + return s.inner.LogsForBurnAccounts() } func (s *hookedStateDB) Finalise(deleteEmptyObjects bool) { diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go index 8d1f93ca1b..d29b262eea 100644 --- a/core/state/statedb_test.go +++ b/core/state/statedb_test.go @@ -1296,12 +1296,12 @@ func TestDeleteStorage(t *testing.T) { obj := fastState.getOrNewStateObject(addr) storageRoot := obj.data.Root - _, _, fastNodes, err := fastState.deleteStorage(addr, crypto.Keccak256Hash(addr[:]), storageRoot) + _, _, fastNodes, err := fastState.deleteStorage(crypto.Keccak256Hash(addr[:]), storageRoot) if err != nil { t.Fatal(err) } - _, _, slowNodes, err := slowState.deleteStorage(addr, crypto.Keccak256Hash(addr[:]), storageRoot) + _, _, slowNodes, err := slowState.deleteStorage(crypto.Keccak256Hash(addr[:]), storageRoot) if err != nil { t.Fatal(err) } diff --git a/core/state_processor.go b/core/state_processor.go index 85f106d58c..bbb1341299 100644 --- a/core/state_processor.go +++ b/core/state_processor.go @@ -260,6 +260,9 @@ func ProcessBeaconBlockRoot(beaconRoot common.Hash, evm *vm.EVM) { evm.SetTxContext(NewEVMTxContext(msg)) evm.StateDB.AddAddressToAccessList(params.BeaconRootsAddress) _, _, _ = evm.Call(msg.From, *msg.To, msg.Data, 30_000_000, common.U2560) + if evm.StateDB.AccessEvents() != nil { + 
evm.StateDB.AccessEvents().Merge(evm.AccessEvents) + } evm.StateDB.Finalise(true) } @@ -323,6 +326,9 @@ func processRequestsSystemCall(requests *[][]byte, evm *vm.EVM, requestType byte evm.SetTxContext(NewEVMTxContext(msg)) evm.StateDB.AddAddressToAccessList(addr) ret, _, err := evm.Call(msg.From, *msg.To, msg.Data, 30_000_000, common.U2560) + if evm.StateDB.AccessEvents() != nil { + evm.StateDB.AccessEvents().Merge(evm.AccessEvents) + } evm.StateDB.Finalise(true) if err != nil { return fmt.Errorf("system call failed to execute: %v", err) diff --git a/core/state_transition.go b/core/state_transition.go index 52375bedaa..bd7e5daeff 100644 --- a/core/state_transition.go +++ b/core/state_transition.go @@ -584,7 +584,9 @@ func (st *stateTransition) execute() (*ExecutionResult, error) { } } if rules.IsAmsterdam { - st.evm.StateDB.EmitLogsForBurnAccounts() + for _, log := range st.evm.StateDB.LogsForBurnAccounts() { + st.evm.StateDB.AddLog(log) + } } return &ExecutionResult{ UsedGas: st.gasUsed(), diff --git a/core/stateless/encoding.go b/core/stateless/encoding.go index d559178892..1b20c4cb2a 100644 --- a/core/stateless/encoding.go +++ b/core/stateless/encoding.go @@ -17,6 +17,7 @@ package stateless import ( + "errors" "io" "github.com/ethereum/go-ethereum/common/hexutil" @@ -42,6 +43,9 @@ func (w *Witness) ToExtWitness() *ExtWitness { // FromExtWitness converts the consensus witness format into our internal one. 
func (w *Witness) FromExtWitness(ext *ExtWitness) error { + if len(ext.Headers) == 0 { + return errors.New("witness must contain at least one header") + } w.Headers = ext.Headers w.Codes = make(map[string]struct{}, len(ext.Codes)) diff --git a/core/types/bal/bal_encoding.go b/core/types/bal/bal_encoding.go index 1b1406ea32..6d52c17c83 100644 --- a/core/types/bal/bal_encoding.go +++ b/core/types/bal/bal_encoding.go @@ -350,9 +350,12 @@ func (e *BlockAccessList) PrettyPrint() string { } // Copy returns a deep copy of the access list -func (e *BlockAccessList) Copy() (res BlockAccessList) { - for _, accountAccess := range e.Accesses { - res.Accesses = append(res.Accesses, accountAccess.Copy()) +func (e *BlockAccessList) Copy() *BlockAccessList { + cpy := &BlockAccessList{ + Accesses: make([]AccountAccess, 0, len(e.Accesses)), } - return + for _, accountAccess := range e.Accesses { + cpy.Accesses = append(cpy.Accesses, accountAccess.Copy()) + } + return cpy } diff --git a/core/types/bal/bal_test.go b/core/types/bal/bal_test.go index 52c0de825e..58ba639ff0 100644 --- a/core/types/bal/bal_test.go +++ b/core/types/bal/bal_test.go @@ -190,8 +190,8 @@ func makeTestAccountAccess(sort bool) AccountAccess { } } -func makeTestBAL(sort bool) BlockAccessList { - list := BlockAccessList{} +func makeTestBAL(sort bool) *BlockAccessList { + list := &BlockAccessList{} for i := 0; i < 5; i++ { list.Accesses = append(list.Accesses, makeTestAccountAccess(sort)) } diff --git a/core/types/block.go b/core/types/block.go index d092351b58..ea576ed232 100644 --- a/core/types/block.go +++ b/core/types/block.go @@ -30,6 +30,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethereum/go-ethereum/core/types/bal" "github.com/ethereum/go-ethereum/rlp" ) @@ -99,6 +100,9 @@ type Header struct { // RequestsHash was added by EIP-7685 and is ignored in legacy headers. 
RequestsHash *common.Hash `json:"requestsHash" rlp:"optional"` + // BlockAccessListHash was added by EIP-7928 and is ignored in legacy headers. + BlockAccessListHash *common.Hash `json:"balHash" rlp:"optional"` + // SlotNumber was added by EIP-7843 and is ignored in legacy headers. SlotNumber *uint64 `json:"slotNumber" rlp:"optional"` } @@ -204,6 +208,7 @@ type Block struct { uncles []*Header transactions Transactions withdrawals Withdrawals + accessList *bal.BlockAccessList // caches hash atomic.Pointer[common.Hash] @@ -320,6 +325,10 @@ func CopyHeader(h *Header) *Header { cpy.RequestsHash = new(common.Hash) *cpy.RequestsHash = *h.RequestsHash } + if h.BlockAccessListHash != nil { + cpy.BlockAccessListHash = new(common.Hash) + *cpy.BlockAccessListHash = *h.BlockAccessListHash + } if h.SlotNumber != nil { cpy.SlotNumber = new(uint64) *cpy.SlotNumber = *h.SlotNumber @@ -358,9 +367,10 @@ func (b *Block) Body() *Body { // Accessors for body data. These do not return a copy because the content // of the body slices does not affect the cached hash/size in block. 
-func (b *Block) Uncles() []*Header { return b.uncles } -func (b *Block) Transactions() Transactions { return b.transactions } -func (b *Block) Withdrawals() Withdrawals { return b.withdrawals } +func (b *Block) Uncles() []*Header { return b.uncles } +func (b *Block) Transactions() Transactions { return b.transactions } +func (b *Block) Withdrawals() Withdrawals { return b.withdrawals } +func (b *Block) AccessList() *bal.BlockAccessList { return b.accessList } func (b *Block) Transaction(hash common.Hash) *Transaction { for _, transaction := range b.transactions { @@ -495,6 +505,7 @@ func (b *Block) WithSeal(header *Header) *Block { transactions: b.transactions, uncles: b.uncles, withdrawals: b.withdrawals, + accessList: b.accessList, } } @@ -506,6 +517,7 @@ func (b *Block) WithBody(body Body) *Block { transactions: slices.Clone(body.Transactions), uncles: make([]*Header, len(body.Uncles)), withdrawals: slices.Clone(body.Withdrawals), + accessList: b.accessList, } for i := range body.Uncles { block.uncles[i] = CopyHeader(body.Uncles[i]) @@ -513,6 +525,24 @@ func (b *Block) WithBody(body Body) *Block { return block } +// WithAccessList returns a copy of the block with the given access list embedded. +func (b *Block) WithAccessList(accessList *bal.BlockAccessList) *Block { + return b.WithAccessListUnsafe(accessList.Copy()) +} + +// WithAccessListUnsafe returns a copy of the block with the given access list +// embedded. Note that the access list is not deep-copied; use WithAccessList +// if the provided list may be modified by other actors. +func (b *Block) WithAccessListUnsafe(accessList *bal.BlockAccessList) *Block { + return &Block{ + header: b.header, + transactions: b.transactions, + uncles: b.uncles, + withdrawals: b.withdrawals, + accessList: accessList, + } +} + // Hash returns the keccak256 hash of b's header. // The hash is computed on the first call and cached thereafter. 
func (b *Block) Hash() common.Hash { diff --git a/core/types/gen_header_json.go b/core/types/gen_header_json.go index 16fb03f612..2e2f1cdca5 100644 --- a/core/types/gen_header_json.go +++ b/core/types/gen_header_json.go @@ -16,29 +16,30 @@ var _ = (*headerMarshaling)(nil) // MarshalJSON marshals as JSON. func (h Header) MarshalJSON() ([]byte, error) { type Header struct { - ParentHash common.Hash `json:"parentHash" gencodec:"required"` - UncleHash common.Hash `json:"sha3Uncles" gencodec:"required"` - Coinbase common.Address `json:"miner"` - Root common.Hash `json:"stateRoot" gencodec:"required"` - TxHash common.Hash `json:"transactionsRoot" gencodec:"required"` - ReceiptHash common.Hash `json:"receiptsRoot" gencodec:"required"` - Bloom Bloom `json:"logsBloom" gencodec:"required"` - Difficulty *hexutil.Big `json:"difficulty" gencodec:"required"` - Number *hexutil.Big `json:"number" gencodec:"required"` - GasLimit hexutil.Uint64 `json:"gasLimit" gencodec:"required"` - GasUsed hexutil.Uint64 `json:"gasUsed" gencodec:"required"` - Time hexutil.Uint64 `json:"timestamp" gencodec:"required"` - Extra hexutil.Bytes `json:"extraData" gencodec:"required"` - MixDigest common.Hash `json:"mixHash"` - Nonce BlockNonce `json:"nonce"` - BaseFee *hexutil.Big `json:"baseFeePerGas" rlp:"optional"` - WithdrawalsHash *common.Hash `json:"withdrawalsRoot" rlp:"optional"` - BlobGasUsed *hexutil.Uint64 `json:"blobGasUsed" rlp:"optional"` - ExcessBlobGas *hexutil.Uint64 `json:"excessBlobGas" rlp:"optional"` - ParentBeaconRoot *common.Hash `json:"parentBeaconBlockRoot" rlp:"optional"` - RequestsHash *common.Hash `json:"requestsHash" rlp:"optional"` - SlotNumber *hexutil.Uint64 `json:"slotNumber" rlp:"optional"` - Hash common.Hash `json:"hash"` + ParentHash common.Hash `json:"parentHash" gencodec:"required"` + UncleHash common.Hash `json:"sha3Uncles" gencodec:"required"` + Coinbase common.Address `json:"miner"` + Root common.Hash `json:"stateRoot" gencodec:"required"` + TxHash common.Hash 
`json:"transactionsRoot" gencodec:"required"` + ReceiptHash common.Hash `json:"receiptsRoot" gencodec:"required"` + Bloom Bloom `json:"logsBloom" gencodec:"required"` + Difficulty *hexutil.Big `json:"difficulty" gencodec:"required"` + Number *hexutil.Big `json:"number" gencodec:"required"` + GasLimit hexutil.Uint64 `json:"gasLimit" gencodec:"required"` + GasUsed hexutil.Uint64 `json:"gasUsed" gencodec:"required"` + Time hexutil.Uint64 `json:"timestamp" gencodec:"required"` + Extra hexutil.Bytes `json:"extraData" gencodec:"required"` + MixDigest common.Hash `json:"mixHash"` + Nonce BlockNonce `json:"nonce"` + BaseFee *hexutil.Big `json:"baseFeePerGas" rlp:"optional"` + WithdrawalsHash *common.Hash `json:"withdrawalsRoot" rlp:"optional"` + BlobGasUsed *hexutil.Uint64 `json:"blobGasUsed" rlp:"optional"` + ExcessBlobGas *hexutil.Uint64 `json:"excessBlobGas" rlp:"optional"` + ParentBeaconRoot *common.Hash `json:"parentBeaconBlockRoot" rlp:"optional"` + RequestsHash *common.Hash `json:"requestsHash" rlp:"optional"` + BlockAccessListHash *common.Hash `json:"balHash" rlp:"optional"` + SlotNumber *hexutil.Uint64 `json:"slotNumber" rlp:"optional"` + Hash common.Hash `json:"hash"` } var enc Header enc.ParentHash = h.ParentHash @@ -62,6 +63,7 @@ func (h Header) MarshalJSON() ([]byte, error) { enc.ExcessBlobGas = (*hexutil.Uint64)(h.ExcessBlobGas) enc.ParentBeaconRoot = h.ParentBeaconRoot enc.RequestsHash = h.RequestsHash + enc.BlockAccessListHash = h.BlockAccessListHash enc.SlotNumber = (*hexutil.Uint64)(h.SlotNumber) enc.Hash = h.Hash() return json.Marshal(&enc) @@ -70,28 +72,29 @@ func (h Header) MarshalJSON() ([]byte, error) { // UnmarshalJSON unmarshals from JSON. 
func (h *Header) UnmarshalJSON(input []byte) error { type Header struct { - ParentHash *common.Hash `json:"parentHash" gencodec:"required"` - UncleHash *common.Hash `json:"sha3Uncles" gencodec:"required"` - Coinbase *common.Address `json:"miner"` - Root *common.Hash `json:"stateRoot" gencodec:"required"` - TxHash *common.Hash `json:"transactionsRoot" gencodec:"required"` - ReceiptHash *common.Hash `json:"receiptsRoot" gencodec:"required"` - Bloom *Bloom `json:"logsBloom" gencodec:"required"` - Difficulty *hexutil.Big `json:"difficulty" gencodec:"required"` - Number *hexutil.Big `json:"number" gencodec:"required"` - GasLimit *hexutil.Uint64 `json:"gasLimit" gencodec:"required"` - GasUsed *hexutil.Uint64 `json:"gasUsed" gencodec:"required"` - Time *hexutil.Uint64 `json:"timestamp" gencodec:"required"` - Extra *hexutil.Bytes `json:"extraData" gencodec:"required"` - MixDigest *common.Hash `json:"mixHash"` - Nonce *BlockNonce `json:"nonce"` - BaseFee *hexutil.Big `json:"baseFeePerGas" rlp:"optional"` - WithdrawalsHash *common.Hash `json:"withdrawalsRoot" rlp:"optional"` - BlobGasUsed *hexutil.Uint64 `json:"blobGasUsed" rlp:"optional"` - ExcessBlobGas *hexutil.Uint64 `json:"excessBlobGas" rlp:"optional"` - ParentBeaconRoot *common.Hash `json:"parentBeaconBlockRoot" rlp:"optional"` - RequestsHash *common.Hash `json:"requestsHash" rlp:"optional"` - SlotNumber *hexutil.Uint64 `json:"slotNumber" rlp:"optional"` + ParentHash *common.Hash `json:"parentHash" gencodec:"required"` + UncleHash *common.Hash `json:"sha3Uncles" gencodec:"required"` + Coinbase *common.Address `json:"miner"` + Root *common.Hash `json:"stateRoot" gencodec:"required"` + TxHash *common.Hash `json:"transactionsRoot" gencodec:"required"` + ReceiptHash *common.Hash `json:"receiptsRoot" gencodec:"required"` + Bloom *Bloom `json:"logsBloom" gencodec:"required"` + Difficulty *hexutil.Big `json:"difficulty" gencodec:"required"` + Number *hexutil.Big `json:"number" gencodec:"required"` + GasLimit *hexutil.Uint64 
`json:"gasLimit" gencodec:"required"` + GasUsed *hexutil.Uint64 `json:"gasUsed" gencodec:"required"` + Time *hexutil.Uint64 `json:"timestamp" gencodec:"required"` + Extra *hexutil.Bytes `json:"extraData" gencodec:"required"` + MixDigest *common.Hash `json:"mixHash"` + Nonce *BlockNonce `json:"nonce"` + BaseFee *hexutil.Big `json:"baseFeePerGas" rlp:"optional"` + WithdrawalsHash *common.Hash `json:"withdrawalsRoot" rlp:"optional"` + BlobGasUsed *hexutil.Uint64 `json:"blobGasUsed" rlp:"optional"` + ExcessBlobGas *hexutil.Uint64 `json:"excessBlobGas" rlp:"optional"` + ParentBeaconRoot *common.Hash `json:"parentBeaconBlockRoot" rlp:"optional"` + RequestsHash *common.Hash `json:"requestsHash" rlp:"optional"` + BlockAccessListHash *common.Hash `json:"balHash" rlp:"optional"` + SlotNumber *hexutil.Uint64 `json:"slotNumber" rlp:"optional"` } var dec Header if err := json.Unmarshal(input, &dec); err != nil { @@ -172,6 +175,9 @@ func (h *Header) UnmarshalJSON(input []byte) error { if dec.RequestsHash != nil { h.RequestsHash = dec.RequestsHash } + if dec.BlockAccessListHash != nil { + h.BlockAccessListHash = dec.BlockAccessListHash + } if dec.SlotNumber != nil { h.SlotNumber = (*uint64)(dec.SlotNumber) } diff --git a/core/types/gen_header_rlp.go b/core/types/gen_header_rlp.go index cfbd57ab8a..3b7eb2c926 100644 --- a/core/types/gen_header_rlp.go +++ b/core/types/gen_header_rlp.go @@ -43,8 +43,9 @@ func (obj *Header) EncodeRLP(_w io.Writer) error { _tmp4 := obj.ExcessBlobGas != nil _tmp5 := obj.ParentBeaconRoot != nil _tmp6 := obj.RequestsHash != nil - _tmp7 := obj.SlotNumber != nil - if _tmp1 || _tmp2 || _tmp3 || _tmp4 || _tmp5 || _tmp6 || _tmp7 { + _tmp7 := obj.BlockAccessListHash != nil + _tmp8 := obj.SlotNumber != nil + if _tmp1 || _tmp2 || _tmp3 || _tmp4 || _tmp5 || _tmp6 || _tmp7 || _tmp8 { if obj.BaseFee == nil { w.Write(rlp.EmptyString) } else { @@ -54,42 +55,49 @@ func (obj *Header) EncodeRLP(_w io.Writer) error { w.WriteBigInt(obj.BaseFee) } } - if _tmp2 || _tmp3 || 
_tmp4 || _tmp5 || _tmp6 || _tmp7 { + if _tmp2 || _tmp3 || _tmp4 || _tmp5 || _tmp6 || _tmp7 || _tmp8 { if obj.WithdrawalsHash == nil { w.Write([]byte{0x80}) } else { w.WriteBytes(obj.WithdrawalsHash[:]) } } - if _tmp3 || _tmp4 || _tmp5 || _tmp6 || _tmp7 { + if _tmp3 || _tmp4 || _tmp5 || _tmp6 || _tmp7 || _tmp8 { if obj.BlobGasUsed == nil { w.Write([]byte{0x80}) } else { w.WriteUint64((*obj.BlobGasUsed)) } } - if _tmp4 || _tmp5 || _tmp6 || _tmp7 { + if _tmp4 || _tmp5 || _tmp6 || _tmp7 || _tmp8 { if obj.ExcessBlobGas == nil { w.Write([]byte{0x80}) } else { w.WriteUint64((*obj.ExcessBlobGas)) } } - if _tmp5 || _tmp6 || _tmp7 { + if _tmp5 || _tmp6 || _tmp7 || _tmp8 { if obj.ParentBeaconRoot == nil { w.Write([]byte{0x80}) } else { w.WriteBytes(obj.ParentBeaconRoot[:]) } } - if _tmp6 || _tmp7 { + if _tmp6 || _tmp7 || _tmp8 { if obj.RequestsHash == nil { w.Write([]byte{0x80}) } else { w.WriteBytes(obj.RequestsHash[:]) } } - if _tmp7 { + if _tmp7 || _tmp8 { + if obj.BlockAccessListHash == nil { + w.Write([]byte{0x80}) + } else { + w.WriteBytes(obj.BlockAccessListHash[:]) + } + } + if _tmp8 { if obj.SlotNumber == nil { w.Write([]byte{0x80}) } else { diff --git a/core/vm/contract.go b/core/vm/contract.go index 165ca833f8..3b5695e21a 100644 --- a/core/vm/contract.go +++ b/core/vm/contract.go @@ -42,7 +42,7 @@ type Contract struct { IsDeployment bool IsSystemCall bool - Gas uint64 + Gas GasCosts value *uint256.Int } @@ -56,7 +56,7 @@ func NewContract(caller common.Address, address common.Address, value *uint256.I caller: caller, address: address, jumpDests: jumpDests, - Gas: gas, + Gas: GasCosts{RegularGas: gas}, value: value, } } @@ -127,13 +127,13 @@ func (c *Contract) Caller() common.Address { // UseGas attempts the use gas and subtracts it and returns true on success func (c *Contract) UseGas(gas uint64, logger *tracing.Hooks, reason tracing.GasChangeReason) (ok bool) { - if c.Gas < gas { + if c.Gas.RegularGas < gas { return false } if logger != nil && logger.OnGasChange != 
nil && reason != tracing.GasChangeIgnored { - logger.OnGasChange(c.Gas, c.Gas-gas, reason) + logger.OnGasChange(c.Gas.RegularGas, c.Gas.RegularGas-gas, reason) } - c.Gas -= gas + c.Gas.RegularGas -= gas return true } @@ -143,9 +143,9 @@ func (c *Contract) RefundGas(gas uint64, logger *tracing.Hooks, reason tracing.G return } if logger != nil && logger.OnGasChange != nil && reason != tracing.GasChangeIgnored { - logger.OnGasChange(c.Gas, c.Gas+gas, reason) + logger.OnGasChange(c.Gas.RegularGas, c.Gas.RegularGas+gas, reason) } - c.Gas += gas + c.Gas.RegularGas += gas } // Address returns the contracts address diff --git a/core/vm/eips.go b/core/vm/eips.go index 3ccd9aaaf0..8f4ca3ae41 100644 --- a/core/vm/eips.go +++ b/core/vm/eips.go @@ -381,7 +381,7 @@ func opExtCodeCopyEIP4762(pc *uint64, evm *EVM, scope *ScopeContext) ([]byte, er addr := common.Address(a.Bytes20()) code := evm.StateDB.GetCode(addr) paddedCodeCopy, copyOffset, nonPaddedCopyLength := getDataAndAdjustedBounds(code, uint64CodeOffset, length.Uint64()) - consumed, wanted := evm.AccessEvents.CodeChunksRangeGas(addr, copyOffset, nonPaddedCopyLength, uint64(len(code)), false, scope.Contract.Gas) + consumed, wanted := evm.AccessEvents.CodeChunksRangeGas(addr, copyOffset, nonPaddedCopyLength, uint64(len(code)), false, scope.Contract.Gas.RegularGas) scope.Contract.UseGas(consumed, evm.Config.Tracer, tracing.GasChangeUnspecified) if consumed < wanted { return nil, ErrOutOfGas @@ -407,7 +407,7 @@ func opPush1EIP4762(pc *uint64, evm *EVM, scope *ScopeContext) ([]byte, error) { // touch next chunk if PUSH1 is at the boundary. if so, *pc has // advanced past this boundary. 
contractAddr := scope.Contract.Address() - consumed, wanted := evm.AccessEvents.CodeChunksRangeGas(contractAddr, *pc+1, uint64(1), uint64(len(scope.Contract.Code)), false, scope.Contract.Gas) + consumed, wanted := evm.AccessEvents.CodeChunksRangeGas(contractAddr, *pc+1, uint64(1), uint64(len(scope.Contract.Code)), false, scope.Contract.Gas.RegularGas) scope.Contract.UseGas(wanted, evm.Config.Tracer, tracing.GasChangeUnspecified) if consumed < wanted { return nil, ErrOutOfGas @@ -435,7 +435,7 @@ func makePushEIP4762(size uint64, pushByteSize int) executionFunc { if !scope.Contract.IsDeployment && !scope.Contract.IsSystemCall { contractAddr := scope.Contract.Address() - consumed, wanted := evm.AccessEvents.CodeChunksRangeGas(contractAddr, uint64(start), uint64(pushByteSize), uint64(len(scope.Contract.Code)), false, scope.Contract.Gas) + consumed, wanted := evm.AccessEvents.CodeChunksRangeGas(contractAddr, uint64(start), uint64(pushByteSize), uint64(len(scope.Contract.Code)), false, scope.Contract.Gas.RegularGas) scope.Contract.UseGas(consumed, evm.Config.Tracer, tracing.GasChangeUnspecified) if consumed < wanted { return nil, ErrOutOfGas diff --git a/core/vm/evm.go b/core/vm/evm.go index 36494de2a8..4df2627486 100644 --- a/core/vm/evm.go +++ b/core/vm/evm.go @@ -303,7 +303,7 @@ func (evm *EVM) Call(caller common.Address, addr common.Address, input []byte, g contract.IsSystemCall = isSystemCall(caller) contract.SetCallCode(evm.resolveCodeHash(addr), code) ret, err = evm.Run(contract, input, false) - gas = contract.Gas + gas = contract.Gas.RegularGas } } // When an error was returned by the EVM or when setting the creation code @@ -365,7 +365,7 @@ func (evm *EVM) CallCode(caller common.Address, addr common.Address, input []byt contract := NewContract(caller, caller, value, gas, evm.jumpDests) contract.SetCallCode(evm.resolveCodeHash(addr), evm.resolveCode(addr)) ret, err = evm.Run(contract, input, false) - gas = contract.Gas + gas = contract.Gas.RegularGas } if err != 
nil { evm.StateDB.RevertToSnapshot(snapshot) @@ -413,7 +413,7 @@ func (evm *EVM) DelegateCall(originCaller common.Address, caller common.Address, contract := NewContract(originCaller, caller, value, gas, evm.jumpDests) contract.SetCallCode(evm.resolveCodeHash(addr), evm.resolveCode(addr)) ret, err = evm.Run(contract, input, false) - gas = contract.Gas + gas = contract.Gas.RegularGas } if err != nil { evm.StateDB.RevertToSnapshot(snapshot) @@ -472,7 +472,7 @@ func (evm *EVM) StaticCall(caller common.Address, addr common.Address, input []b // above we revert to the snapshot and consume any gas remaining. Additionally // when we're in Homestead this also counts for code storage gas errors. ret, err = evm.Run(contract, input, true) - gas = contract.Gas + gas = contract.Gas.RegularGas } if err != nil { evm.StateDB.RevertToSnapshot(snapshot) @@ -583,10 +583,10 @@ func (evm *EVM) create(caller common.Address, code []byte, gas uint64, value *ui if err != nil && (evm.chainRules.IsHomestead || err != ErrCodeStoreOutOfGas) { evm.StateDB.RevertToSnapshot(snapshot) if err != ErrExecutionReverted { - contract.UseGas(contract.Gas, evm.Config.Tracer, tracing.GasChangeCallFailedExecution) + contract.UseGas(contract.Gas.RegularGas, evm.Config.Tracer, tracing.GasChangeCallFailedExecution) } } - return ret, address, contract.Gas, err + return ret, address, contract.Gas.RegularGas, err } // initNewContract runs a new contract's creation code, performs checks on the @@ -613,7 +613,7 @@ func (evm *EVM) initNewContract(contract *Contract, address common.Address) ([]b return ret, ErrCodeStoreOutOfGas } } else { - consumed, wanted := evm.AccessEvents.CodeChunksRangeGas(address, 0, uint64(len(ret)), uint64(len(ret)), true, contract.Gas) + consumed, wanted := evm.AccessEvents.CodeChunksRangeGas(address, 0, uint64(len(ret)), uint64(len(ret)), true, contract.Gas.RegularGas) contract.UseGas(consumed, evm.Config.Tracer, tracing.GasChangeWitnessCodeChunk) if len(ret) > 0 && (consumed < wanted) { 
return ret, ErrCodeStoreOutOfGas diff --git a/core/vm/gas_table.go b/core/vm/gas_table.go index f075a99468..b3259b2ec7 100644 --- a/core/vm/gas_table.go +++ b/core/vm/gas_table.go @@ -64,26 +64,26 @@ func memoryGasCost(mem *Memory, newMemSize uint64) (uint64, error) { // EXTCODECOPY (stack position 3) // RETURNDATACOPY (stack position 2) func memoryCopierGas(stackpos int) gasFunc { - return func(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { + return func(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (GasCosts, error) { // Gas for expanding the memory gas, err := memoryGasCost(mem, memorySize) if err != nil { - return 0, err + return GasCosts{}, err } // And gas for copying data, charged per word at param.CopyGas words, overflow := stack.Back(stackpos).Uint64WithOverflow() if overflow { - return 0, ErrGasUintOverflow + return GasCosts{}, ErrGasUintOverflow } if words, overflow = math.SafeMul(toWordSize(words), params.CopyGas); overflow { - return 0, ErrGasUintOverflow + return GasCosts{}, ErrGasUintOverflow } if gas, overflow = math.SafeAdd(gas, words); overflow { - return 0, ErrGasUintOverflow + return GasCosts{}, ErrGasUintOverflow } - return gas, nil + return GasCosts{RegularGas: gas}, nil } } @@ -95,9 +95,9 @@ var ( gasReturnDataCopy = memoryCopierGas(2) ) -func gasSStore(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { +func gasSStore(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (GasCosts, error) { if evm.readOnly { - return 0, ErrWriteProtection + return GasCosts{}, ErrWriteProtection } var ( y, x = stack.Back(1), stack.Back(0) @@ -114,12 +114,12 @@ func gasSStore(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySi // 3. 
From a non-zero to a non-zero (CHANGE) switch { case current == (common.Hash{}) && y.Sign() != 0: // 0 => non 0 - return params.SstoreSetGas, nil + return GasCosts{RegularGas: params.SstoreSetGas}, nil case current != (common.Hash{}) && y.Sign() == 0: // non 0 => 0 evm.StateDB.AddRefund(params.SstoreRefundGas) - return params.SstoreClearGas, nil + return GasCosts{RegularGas: params.SstoreClearGas}, nil default: // non 0 => non 0 (or 0 => 0) - return params.SstoreResetGas, nil + return GasCosts{RegularGas: params.SstoreResetGas}, nil } } @@ -139,16 +139,16 @@ func gasSStore(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySi // (2.2.2.2.) Otherwise, add 4800 gas to refund counter. value := common.Hash(y.Bytes32()) if current == value { // noop (1) - return params.NetSstoreNoopGas, nil + return GasCosts{RegularGas: params.NetSstoreNoopGas}, nil } if original == current { if original == (common.Hash{}) { // create slot (2.1.1) - return params.NetSstoreInitGas, nil + return GasCosts{RegularGas: params.NetSstoreInitGas}, nil } if value == (common.Hash{}) { // delete slot (2.1.2b) evm.StateDB.AddRefund(params.NetSstoreClearRefund) } - return params.NetSstoreCleanGas, nil // write existing slot (2.1.2) + return GasCosts{RegularGas: params.NetSstoreCleanGas}, nil // write existing slot (2.1.2) } if original != (common.Hash{}) { if current == (common.Hash{}) { // recreate slot (2.2.1.1) @@ -164,7 +164,7 @@ func gasSStore(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySi evm.StateDB.AddRefund(params.NetSstoreResetRefund) } } - return params.NetSstoreDirtyGas, nil + return GasCosts{RegularGas: params.NetSstoreDirtyGas}, nil } // Here come the EIP2200 rules: @@ -182,13 +182,13 @@ func gasSStore(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySi // (2.2.2.) If original value equals new value (this storage slot is reset): // (2.2.2.1.) If original value is 0, add SSTORE_SET_GAS - SLOAD_GAS to refund counter. // (2.2.2.2.) 
Otherwise, add SSTORE_RESET_GAS - SLOAD_GAS gas to refund counter. -func gasSStoreEIP2200(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { +func gasSStoreEIP2200(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (GasCosts, error) { if evm.readOnly { - return 0, ErrWriteProtection + return GasCosts{}, ErrWriteProtection } // If we fail the minimum gas availability invariant, fail (0) - if contract.Gas <= params.SstoreSentryGasEIP2200 { - return 0, errors.New("not enough gas for reentrancy sentry") + if contract.Gas.RegularGas <= params.SstoreSentryGasEIP2200 { + return GasCosts{}, errors.New("not enough gas for reentrancy sentry") } // Gas sentry honoured, do the actual gas calculation based on the stored value var ( @@ -198,16 +198,16 @@ func gasSStoreEIP2200(evm *EVM, contract *Contract, stack *Stack, mem *Memory, m value := common.Hash(y.Bytes32()) if current == value { // noop (1) - return params.SloadGasEIP2200, nil + return GasCosts{RegularGas: params.SloadGasEIP2200}, nil } if original == current { if original == (common.Hash{}) { // create slot (2.1.1) - return params.SstoreSetGasEIP2200, nil + return GasCosts{RegularGas: params.SstoreSetGasEIP2200}, nil } if value == (common.Hash{}) { // delete slot (2.1.2b) evm.StateDB.AddRefund(params.SstoreClearsScheduleRefundEIP2200) } - return params.SstoreResetGasEIP2200, nil // write existing slot (2.1.2) + return GasCosts{RegularGas: params.SstoreResetGasEIP2200}, nil // write existing slot (2.1.2) } if original != (common.Hash{}) { if current == (common.Hash{}) { // recreate slot (2.2.1.1) @@ -223,62 +223,66 @@ func gasSStoreEIP2200(evm *EVM, contract *Contract, stack *Stack, mem *Memory, m evm.StateDB.AddRefund(params.SstoreResetGasEIP2200 - params.SloadGasEIP2200) } } - return params.SloadGasEIP2200, nil // dirty update (2.2) + return GasCosts{RegularGas: params.SloadGasEIP2200}, nil // dirty update (2.2) } func makeGasLog(n uint64) gasFunc { 
- return func(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { + return func(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (GasCosts, error) { requestedSize, overflow := stack.Back(1).Uint64WithOverflow() if overflow { - return 0, ErrGasUintOverflow + return GasCosts{}, ErrGasUintOverflow } gas, err := memoryGasCost(mem, memorySize) if err != nil { - return 0, err + return GasCosts{}, err } if gas, overflow = math.SafeAdd(gas, params.LogGas); overflow { - return 0, ErrGasUintOverflow + return GasCosts{}, ErrGasUintOverflow } if gas, overflow = math.SafeAdd(gas, n*params.LogTopicGas); overflow { - return 0, ErrGasUintOverflow + return GasCosts{}, ErrGasUintOverflow } var memorySizeGas uint64 if memorySizeGas, overflow = math.SafeMul(requestedSize, params.LogDataGas); overflow { - return 0, ErrGasUintOverflow + return GasCosts{}, ErrGasUintOverflow } if gas, overflow = math.SafeAdd(gas, memorySizeGas); overflow { - return 0, ErrGasUintOverflow + return GasCosts{}, ErrGasUintOverflow } - return gas, nil + return GasCosts{RegularGas: gas}, nil } } -func gasKeccak256(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { +func gasKeccak256(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (GasCosts, error) { gas, err := memoryGasCost(mem, memorySize) if err != nil { - return 0, err + return GasCosts{}, err } wordGas, overflow := stack.Back(1).Uint64WithOverflow() if overflow { - return 0, ErrGasUintOverflow + return GasCosts{}, ErrGasUintOverflow } if wordGas, overflow = math.SafeMul(toWordSize(wordGas), params.Keccak256WordGas); overflow { - return 0, ErrGasUintOverflow + return GasCosts{}, ErrGasUintOverflow } if gas, overflow = math.SafeAdd(gas, wordGas); overflow { - return 0, ErrGasUintOverflow + return GasCosts{}, ErrGasUintOverflow } - return gas, nil + return GasCosts{RegularGas: gas}, nil } // pureMemoryGascost is used by 
several operations, which aside from their // static cost have a dynamic cost which is solely based on the memory // expansion -func pureMemoryGascost(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { - return memoryGasCost(mem, memorySize) +func pureMemoryGascost(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (GasCosts, error) { + gas, err := memoryGasCost(mem, memorySize) + if err != nil { + return GasCosts{}, err + } + return GasCosts{RegularGas: gas}, nil } var ( @@ -290,64 +294,64 @@ var ( gasCreate = pureMemoryGascost ) -func gasCreate2(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { +func gasCreate2(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (GasCosts, error) { gas, err := memoryGasCost(mem, memorySize) if err != nil { - return 0, err + return GasCosts{}, err } wordGas, overflow := stack.Back(2).Uint64WithOverflow() if overflow { - return 0, ErrGasUintOverflow + return GasCosts{}, ErrGasUintOverflow } if wordGas, overflow = math.SafeMul(toWordSize(wordGas), params.Keccak256WordGas); overflow { - return 0, ErrGasUintOverflow + return GasCosts{}, ErrGasUintOverflow } if gas, overflow = math.SafeAdd(gas, wordGas); overflow { - return 0, ErrGasUintOverflow + return GasCosts{}, ErrGasUintOverflow } - return gas, nil + return GasCosts{RegularGas: gas}, nil } -func gasCreateEip3860(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { +func gasCreateEip3860(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (GasCosts, error) { gas, err := memoryGasCost(mem, memorySize) if err != nil { - return 0, err + return GasCosts{}, err } size, overflow := stack.Back(2).Uint64WithOverflow() if overflow { - return 0, ErrGasUintOverflow + return GasCosts{}, ErrGasUintOverflow } if err := CheckMaxInitCodeSize(&evm.chainRules, size); err != nil { - return 0, err + 
return GasCosts{}, err } // Since size <= the protocol-defined maximum initcode size limit, these multiplication cannot overflow moreGas := params.InitCodeWordGas * ((size + 31) / 32) if gas, overflow = math.SafeAdd(gas, moreGas); overflow { - return 0, ErrGasUintOverflow + return GasCosts{}, ErrGasUintOverflow } - return gas, nil + return GasCosts{RegularGas: gas}, nil } -func gasCreate2Eip3860(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { +func gasCreate2Eip3860(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (GasCosts, error) { gas, err := memoryGasCost(mem, memorySize) if err != nil { - return 0, err + return GasCosts{}, err } size, overflow := stack.Back(2).Uint64WithOverflow() if overflow { - return 0, ErrGasUintOverflow + return GasCosts{}, ErrGasUintOverflow } if err := CheckMaxInitCodeSize(&evm.chainRules, size); err != nil { - return 0, err + return GasCosts{}, err } // Since size <= the protocol-defined maximum initcode size limit, these multiplication cannot overflow moreGas := (params.InitCodeWordGas + params.Keccak256WordGas) * ((size + 31) / 32) if gas, overflow = math.SafeAdd(gas, moreGas); overflow { - return 0, ErrGasUintOverflow + return GasCosts{}, ErrGasUintOverflow } - return gas, nil + return GasCosts{RegularGas: gas}, nil } -func gasExpFrontier(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { +func gasExpFrontier(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (GasCosts, error) { expByteLen := uint64((stack.data[stack.len()-2].BitLen() + 7) / 8) var ( @@ -355,12 +359,12 @@ func gasExpFrontier(evm *EVM, contract *Contract, stack *Stack, mem *Memory, mem overflow bool ) if gas, overflow = math.SafeAdd(gas, params.ExpGas); overflow { - return 0, ErrGasUintOverflow + return GasCosts{}, ErrGasUintOverflow } - return gas, nil + return GasCosts{RegularGas: gas}, nil } -func gasExpEIP158(evm *EVM, 
contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { +func gasExpEIP158(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (GasCosts, error) { expByteLen := uint64((stack.data[stack.len()-2].BitLen() + 7) / 8) var ( @@ -368,9 +372,9 @@ func gasExpEIP158(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memor overflow bool ) if gas, overflow = math.SafeAdd(gas, params.ExpGas); overflow { - return 0, ErrGasUintOverflow + return GasCosts{}, ErrGasUintOverflow } - return gas, nil + return GasCosts{RegularGas: gas}, nil } var ( @@ -381,36 +385,36 @@ var ( ) func makeCallVariantGasCost(intrinsicFunc gasFunc) gasFunc { - return func(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { + return func(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (GasCosts, error) { intrinsic, err := intrinsicFunc(evm, contract, stack, mem, memorySize) if err != nil { - return 0, err + return GasCosts{}, err } - evm.callGasTemp, err = callGas(evm.chainRules.IsEIP150, contract.Gas, intrinsic, stack.Back(0)) + evm.callGasTemp, err = callGas(evm.chainRules.IsEIP150, contract.Gas.RegularGas, intrinsic.RegularGas, stack.Back(0)) if err != nil { - return 0, err + return GasCosts{}, err } - gas, overflow := math.SafeAdd(intrinsic, evm.callGasTemp) + gas, overflow := math.SafeAdd(intrinsic.RegularGas, evm.callGasTemp) if overflow { - return 0, ErrGasUintOverflow + return GasCosts{}, ErrGasUintOverflow } - return gas, nil + return GasCosts{RegularGas: gas}, nil } } -func gasCallIntrinsic(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { +func gasCallIntrinsic(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (GasCosts, error) { var ( gas uint64 transfersValue = !stack.Back(2).IsZero() address = common.Address(stack.Back(1).Bytes20()) ) if evm.readOnly && transfersValue { - return 0, 
ErrWriteProtection + return GasCosts{}, ErrWriteProtection } // Stateless check memoryGas, err := memoryGasCost(mem, memorySize) if err != nil { - return 0, err + return GasCosts{}, err } var transferGas uint64 if transfersValue && !evm.chainRules.IsEIP4762 { @@ -418,12 +422,12 @@ func gasCallIntrinsic(evm *EVM, contract *Contract, stack *Stack, mem *Memory, m } var overflow bool if gas, overflow = math.SafeAdd(memoryGas, transferGas); overflow { - return 0, ErrGasUintOverflow + return GasCosts{}, ErrGasUintOverflow } // Terminate the gas measurement if the leftover gas is not sufficient, // it can effectively prevent accessing the states in the following steps. - if contract.Gas < gas { - return 0, ErrOutOfGas + if contract.Gas.RegularGas < gas { + return GasCosts{}, ErrOutOfGas } // Stateful check var stateGas uint64 @@ -435,15 +439,15 @@ func gasCallIntrinsic(evm *EVM, contract *Contract, stack *Stack, mem *Memory, m stateGas += params.CallNewAccountGas } if gas, overflow = math.SafeAdd(gas, stateGas); overflow { - return 0, ErrGasUintOverflow + return GasCosts{}, ErrGasUintOverflow } - return gas, nil + return GasCosts{RegularGas: gas}, nil } -func gasCallCodeIntrinsic(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { +func gasCallCodeIntrinsic(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (GasCosts, error) { memoryGas, err := memoryGasCost(mem, memorySize) if err != nil { - return 0, err + return GasCosts{}, err } var ( gas uint64 @@ -453,22 +457,30 @@ func gasCallCodeIntrinsic(evm *EVM, contract *Contract, stack *Stack, mem *Memor gas += params.CallValueTransferGas } if gas, overflow = math.SafeAdd(gas, memoryGas); overflow { - return 0, ErrGasUintOverflow + return GasCosts{}, ErrGasUintOverflow } - return gas, nil + return GasCosts{RegularGas: gas}, nil } -func gasDelegateCallIntrinsic(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { - 
return memoryGasCost(mem, memorySize) +func gasDelegateCallIntrinsic(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (GasCosts, error) { + gas, err := memoryGasCost(mem, memorySize) + if err != nil { + return GasCosts{}, err + } + return GasCosts{RegularGas: gas}, nil } -func gasStaticCallIntrinsic(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { - return memoryGasCost(mem, memorySize) +func gasStaticCallIntrinsic(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (GasCosts, error) { + gas, err := memoryGasCost(mem, memorySize) + if err != nil { + return GasCosts{}, err + } + return GasCosts{RegularGas: gas}, nil } -func gasSelfdestruct(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { +func gasSelfdestruct(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (GasCosts, error) { if evm.readOnly { - return 0, ErrWriteProtection + return GasCosts{}, ErrWriteProtection } var gas uint64 @@ -490,5 +502,5 @@ func gasSelfdestruct(evm *EVM, contract *Contract, stack *Stack, mem *Memory, me if !evm.StateDB.HasSelfDestructed(contract.Address()) { evm.StateDB.AddRefund(params.SelfdestructRefundGas) } - return gas, nil + return GasCosts{RegularGas: gas}, nil } diff --git a/core/vm/gascosts.go b/core/vm/gascosts.go new file mode 100644 index 0000000000..ba6746758b --- /dev/null +++ b/core/vm/gascosts.go @@ -0,0 +1,36 @@ +// Copyright 2026 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. 
+// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package vm + +import "fmt" + +// GasCosts denotes a vector of gas costs in the +// multidimensional metering paradigm. +type GasCosts struct { + RegularGas uint64 + StateGas uint64 +} + +// Sum returns the total gas (regular + state). +func (g GasCosts) Sum() uint64 { + return g.RegularGas + g.StateGas +} + +// String returns a visual representation of the gas vector. +func (g GasCosts) String() string { + return fmt.Sprintf("<%v,%v>", g.RegularGas, g.StateGas) +} diff --git a/core/vm/instructions.go b/core/vm/instructions.go index a5fa11e307..74400732ac 100644 --- a/core/vm/instructions.go +++ b/core/vm/instructions.go @@ -566,7 +566,7 @@ func opMsize(pc *uint64, evm *EVM, scope *ScopeContext) ([]byte, error) { } func opGas(pc *uint64, evm *EVM, scope *ScopeContext) ([]byte, error) { - scope.Stack.push(new(uint256.Int).SetUint64(scope.Contract.Gas)) + scope.Stack.push(new(uint256.Int).SetUint64(scope.Contract.Gas.RegularGas)) return nil, nil } @@ -658,7 +658,7 @@ func opCreate(pc *uint64, evm *EVM, scope *ScopeContext) ([]byte, error) { value = scope.Stack.pop() offset, size = scope.Stack.pop(), scope.Stack.pop() input = scope.Memory.GetCopy(offset.Uint64(), size.Uint64()) - gas = scope.Contract.Gas + gas = scope.Contract.Gas.RegularGas ) if evm.chainRules.IsEIP150 { gas -= gas / 64 @@ -702,7 +702,7 @@ func opCreate2(pc *uint64, evm *EVM, scope *ScopeContext) ([]byte, error) { offset, size = scope.Stack.pop(), scope.Stack.pop() salt = scope.Stack.pop() input = scope.Memory.GetCopy(offset.Uint64(), size.Uint64()) - gas = scope.Contract.Gas + gas = 
scope.Contract.Gas.RegularGas ) // Apply EIP150 diff --git a/core/vm/instructions_test.go b/core/vm/instructions_test.go index 4c6d093d2e..1f69eea3da 100644 --- a/core/vm/instructions_test.go +++ b/core/vm/instructions_test.go @@ -879,7 +879,7 @@ func TestOpMCopy(t *testing.T) { if dynamicCost, err := gasMcopy(evm, nil, stack, mem, memorySize); err != nil { t.Error(err) } else { - haveGas = GasFastestStep + dynamicCost + haveGas = GasFastestStep + dynamicCost.RegularGas } // Expand mem if memorySize > 0 { diff --git a/core/vm/interface.go b/core/vm/interface.go index 6a93846ac5..d7c4340e06 100644 --- a/core/vm/interface.go +++ b/core/vm/interface.go @@ -87,7 +87,7 @@ type StateDB interface { Snapshot() int AddLog(*types.Log) - EmitLogsForBurnAccounts() + LogsForBurnAccounts() []*types.Log AddPreimage(common.Hash, []byte) Witness() *stateless.Witness diff --git a/core/vm/interpreter.go b/core/vm/interpreter.go index 620c069fc8..b507595fab 100644 --- a/core/vm/interpreter.go +++ b/core/vm/interpreter.go @@ -166,14 +166,14 @@ func (evm *EVM) Run(contract *Contract, input []byte, readOnly bool) (ret []byte for { if debug { // Capture pre-execution values for tracing. - logged, pcCopy, gasCopy = false, pc, contract.Gas + logged, pcCopy, gasCopy = false, pc, contract.Gas.RegularGas } if isEIP4762 && !contract.IsDeployment && !contract.IsSystemCall { // if the PC ends up in a new "chunk" of verkleized code, charge the // associated costs. 
contractAddr := contract.Address() - consumed, wanted := evm.TxContext.AccessEvents.CodeChunksRangeGas(contractAddr, pc, 1, uint64(len(contract.Code)), false, contract.Gas) + consumed, wanted := evm.TxContext.AccessEvents.CodeChunksRangeGas(contractAddr, pc, 1, uint64(len(contract.Code)), false, contract.Gas.RegularGas) contract.UseGas(consumed, evm.Config.Tracer, tracing.GasChangeWitnessCodeChunk) if consumed < wanted { return nil, ErrOutOfGas @@ -192,10 +192,10 @@ func (evm *EVM) Run(contract *Contract, input []byte, readOnly bool) (ret []byte return nil, &ErrStackOverflow{stackLen: sLen, limit: operation.maxStack} } // for tracing: this gas consumption event is emitted below in the debug section. - if contract.Gas < cost { + if contract.Gas.RegularGas < cost { return nil, ErrOutOfGas } else { - contract.Gas -= cost + contract.Gas.RegularGas -= cost } // All ops with a dynamic memory usage also has a dynamic gas cost. @@ -218,17 +218,17 @@ func (evm *EVM) Run(contract *Contract, input []byte, readOnly bool) (ret []byte } // Consume the gas and return an error if not enough gas is available. // cost is explicitly set so that the capture state defer method can get the proper cost - var dynamicCost uint64 + var dynamicCost GasCosts dynamicCost, err = operation.dynamicGas(evm, contract, stack, mem, memorySize) - cost += dynamicCost // for tracing + cost += dynamicCost.RegularGas // for tracing if err != nil { return nil, fmt.Errorf("%w: %v", ErrOutOfGas, err) } // for tracing: this gas consumption event is emitted below in the debug section. 
- if contract.Gas < dynamicCost { + if contract.Gas.RegularGas < dynamicCost.RegularGas { return nil, ErrOutOfGas } else { - contract.Gas -= dynamicCost + contract.Gas.RegularGas -= dynamicCost.RegularGas } } diff --git a/core/vm/jump_table.go b/core/vm/jump_table.go index a2e2c91194..82fc43ec13 100644 --- a/core/vm/jump_table.go +++ b/core/vm/jump_table.go @@ -24,7 +24,7 @@ import ( type ( executionFunc func(pc *uint64, evm *EVM, callContext *ScopeContext) ([]byte, error) - gasFunc func(*EVM, *Contract, *Stack, *Memory, uint64) (uint64, error) // last parameter is the requested memory size as a uint64 + gasFunc func(*EVM, *Contract, *Stack, *Memory, uint64) (GasCosts, error) // last parameter is the requested memory size as a uint64 // memorySizeFunc returns the required size, and whether the operation overflowed a uint64 memorySizeFunc func(*Stack) (size uint64, overflow bool) ) diff --git a/core/vm/operations_acl.go b/core/vm/operations_acl.go index addd2b162f..154c261cae 100644 --- a/core/vm/operations_acl.go +++ b/core/vm/operations_acl.go @@ -27,13 +27,13 @@ import ( ) func makeGasSStoreFunc(clearingRefund uint64) gasFunc { - return func(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { + return func(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (GasCosts, error) { if evm.readOnly { - return 0, ErrWriteProtection + return GasCosts{}, ErrWriteProtection } // If we fail the minimum gas availability invariant, fail (0) - if contract.Gas <= params.SstoreSentryGasEIP2200 { - return 0, errors.New("not enough gas for reentrancy sentry") + if contract.Gas.RegularGas <= params.SstoreSentryGasEIP2200 { + return GasCosts{}, errors.New("not enough gas for reentrancy sentry") } // Gas sentry honoured, do the actual gas calculation based on the stored value var ( @@ -53,18 +53,18 @@ func makeGasSStoreFunc(clearingRefund uint64) gasFunc { if current == value { // noop (1) // EIP 2200 original clause: 
// return params.SloadGasEIP2200, nil - return cost + params.WarmStorageReadCostEIP2929, nil // SLOAD_GAS + return GasCosts{RegularGas: cost + params.WarmStorageReadCostEIP2929}, nil // SLOAD_GAS } if original == current { if original == (common.Hash{}) { // create slot (2.1.1) - return cost + params.SstoreSetGasEIP2200, nil + return GasCosts{RegularGas: cost + params.SstoreSetGasEIP2200}, nil } if value == (common.Hash{}) { // delete slot (2.1.2b) evm.StateDB.AddRefund(clearingRefund) } // EIP-2200 original clause: // return params.SstoreResetGasEIP2200, nil // write existing slot (2.1.2) - return cost + (params.SstoreResetGasEIP2200 - params.ColdSloadCostEIP2929), nil // write existing slot (2.1.2) + return GasCosts{RegularGas: cost + (params.SstoreResetGasEIP2200 - params.ColdSloadCostEIP2929)}, nil // write existing slot (2.1.2) } if original != (common.Hash{}) { if current == (common.Hash{}) { // recreate slot (2.2.1.1) @@ -89,7 +89,7 @@ func makeGasSStoreFunc(clearingRefund uint64) gasFunc { } // EIP-2200 original clause: //return params.SloadGasEIP2200, nil // dirty update (2.2) - return cost + params.WarmStorageReadCostEIP2929, nil // dirty update (2.2) + return GasCosts{RegularGas: cost + params.WarmStorageReadCostEIP2929}, nil // dirty update (2.2) } } @@ -98,7 +98,7 @@ func makeGasSStoreFunc(clearingRefund uint64) gasFunc { // whose storage is being read) is not yet in accessed_storage_keys, // charge 2100 gas and add the pair to accessed_storage_keys. // If the pair is already in accessed_storage_keys, charge 100 gas. 
-func gasSLoadEIP2929(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { +func gasSLoadEIP2929(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (GasCosts, error) { loc := stack.peek() slot := common.Hash(loc.Bytes32()) // Check slot presence in the access list @@ -106,9 +106,9 @@ func gasSLoadEIP2929(evm *EVM, contract *Contract, stack *Stack, mem *Memory, me // If the caller cannot afford the cost, this change will be rolled back // If he does afford it, we can skip checking the same thing later on, during execution evm.StateDB.AddSlotToAccessList(contract.Address(), slot) - return params.ColdSloadCostEIP2929, nil + return GasCosts{RegularGas: params.ColdSloadCostEIP2929}, nil } - return params.WarmStorageReadCostEIP2929, nil + return GasCosts{RegularGas: params.WarmStorageReadCostEIP2929}, nil } // gasExtCodeCopyEIP2929 implements extcodecopy according to EIP-2929 @@ -116,12 +116,13 @@ func gasSLoadEIP2929(evm *EVM, contract *Contract, stack *Stack, mem *Memory, me // > If the target is not in accessed_addresses, // > charge COLD_ACCOUNT_ACCESS_COST gas, and add the address to accessed_addresses. // > Otherwise, charge WARM_STORAGE_READ_COST gas. 
-func gasExtCodeCopyEIP2929(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { +func gasExtCodeCopyEIP2929(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (GasCosts, error) { // memory expansion first (dynamic part of pre-2929 implementation) - gas, err := gasExtCodeCopy(evm, contract, stack, mem, memorySize) + gasCost, err := gasExtCodeCopy(evm, contract, stack, mem, memorySize) if err != nil { - return 0, err + return GasCosts{}, err } + gas := gasCost.RegularGas addr := common.Address(stack.peek().Bytes20()) // Check slot presence in the access list if !evm.StateDB.AddressInAccessList(addr) { @@ -129,11 +130,11 @@ func gasExtCodeCopyEIP2929(evm *EVM, contract *Contract, stack *Stack, mem *Memo var overflow bool // We charge (cold-warm), since 'warm' is already charged as constantGas if gas, overflow = math.SafeAdd(gas, params.ColdAccountAccessCostEIP2929-params.WarmStorageReadCostEIP2929); overflow { - return 0, ErrGasUintOverflow + return GasCosts{}, ErrGasUintOverflow } - return gas, nil + return GasCosts{RegularGas: gas}, nil } - return gas, nil + return GasCosts{RegularGas: gas}, nil } // gasEip2929AccountCheck checks whether the first stack item (as address) is present in the access list. 
@@ -143,20 +144,20 @@ func gasExtCodeCopyEIP2929(evm *EVM, contract *Contract, stack *Stack, mem *Memo // - extcodehash, // - extcodesize, // - (ext) balance -func gasEip2929AccountCheck(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { +func gasEip2929AccountCheck(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (GasCosts, error) { addr := common.Address(stack.peek().Bytes20()) // Check slot presence in the access list if !evm.StateDB.AddressInAccessList(addr) { // If the caller cannot afford the cost, this change will be rolled back evm.StateDB.AddAddressToAccessList(addr) // The warm storage read cost is already charged as constantGas - return params.ColdAccountAccessCostEIP2929 - params.WarmStorageReadCostEIP2929, nil + return GasCosts{RegularGas: params.ColdAccountAccessCostEIP2929 - params.WarmStorageReadCostEIP2929}, nil } - return 0, nil + return GasCosts{}, nil } func makeCallVariantGasCallEIP2929(oldCalculator gasFunc, addressPosition int) gasFunc { - return func(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { + return func(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (GasCosts, error) { addr := common.Address(stack.Back(addressPosition).Bytes20()) // Check slot presence in the access list warmAccess := evm.StateDB.AddressInAccessList(addr) @@ -168,7 +169,7 @@ func makeCallVariantGasCallEIP2929(oldCalculator gasFunc, addressPosition int) g // Charge the remaining difference here already, to correctly calculate available // gas for call if !contract.UseGas(coldCost, evm.Config.Tracer, tracing.GasChangeCallStorageColdAccess) { - return 0, ErrOutOfGas + return GasCosts{}, ErrOutOfGas } } // Now call the old calculator, which takes into account @@ -176,21 +177,22 @@ func makeCallVariantGasCallEIP2929(oldCalculator gasFunc, addressPosition int) g // - transfer value // - memory expansion // - 63/64ths rule - gas, 
err := oldCalculator(evm, contract, stack, mem, memorySize) + gasCost, err := oldCalculator(evm, contract, stack, mem, memorySize) if warmAccess || err != nil { - return gas, err + return gasCost, err } // In case of a cold access, we temporarily add the cold charge back, and also // add it to the returned gas. By adding it to the return, it will be charged // outside of this function, as part of the dynamic gas, and that will make it // also become correctly reported to tracers. - contract.Gas += coldCost + contract.Gas.RegularGas += coldCost + gas := gasCost.RegularGas var overflow bool if gas, overflow = math.SafeAdd(gas, coldCost); overflow { - return 0, ErrGasUintOverflow + return GasCosts{}, ErrGasUintOverflow } - return gas, nil + return GasCosts{RegularGas: gas}, nil } } @@ -224,13 +226,13 @@ var ( // makeSelfdestructGasFn can create the selfdestruct dynamic gas function for EIP-2929 and EIP-3529 func makeSelfdestructGasFn(refundsEnabled bool) gasFunc { - gasFunc := func(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { + gasFunc := func(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (GasCosts, error) { var ( gas uint64 address = common.Address(stack.peek().Bytes20()) ) if evm.readOnly { - return 0, ErrWriteProtection + return GasCosts{}, ErrWriteProtection } if !evm.StateDB.AddressInAccessList(address) { // If the caller cannot afford the cost, this change will be rolled back @@ -239,8 +241,8 @@ func makeSelfdestructGasFn(refundsEnabled bool) gasFunc { // Terminate the gas measurement if the leftover gas is not sufficient, // it can effectively prevent accessing the states in the following steps - if contract.Gas < gas { - return 0, ErrOutOfGas + if contract.Gas.RegularGas < gas { + return GasCosts{}, ErrOutOfGas } } // if empty and transfers value @@ -250,7 +252,7 @@ func makeSelfdestructGasFn(refundsEnabled bool) gasFunc { if refundsEnabled && 
!evm.StateDB.HasSelfDestructed(contract.Address()) { evm.StateDB.AddRefund(params.SelfdestructRefundGas) } - return gas, nil + return GasCosts{RegularGas: gas}, nil } return gasFunc } @@ -262,20 +264,20 @@ var ( gasCallCodeEIP7702 = makeCallVariantGasCallEIP7702(gasCallCodeIntrinsic) ) -func gasCallEIP7702(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { +func gasCallEIP7702(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (GasCosts, error) { // Return early if this call attempts to transfer value in a static context. // Although it's checked in `gasCall`, EIP-7702 loads the target's code before // to determine if it is resolving a delegation. This could incorrectly record // the target in the block access list (BAL) if the call later fails. transfersValue := !stack.Back(2).IsZero() if evm.readOnly && transfersValue { - return 0, ErrWriteProtection + return GasCosts{}, ErrWriteProtection } return innerGasCallEIP7702(evm, contract, stack, mem, memorySize) } func makeCallVariantGasCallEIP7702(intrinsicFunc gasFunc) gasFunc { - return func(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { + return func(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (GasCosts, error) { var ( eip2929Cost uint64 eip7702Cost uint64 @@ -294,7 +296,7 @@ func makeCallVariantGasCallEIP7702(intrinsicFunc gasFunc) gasFunc { // Charge the remaining difference here already, to correctly calculate // available gas for call if !contract.UseGas(eip2929Cost, evm.Config.Tracer, tracing.GasChangeCallStorageColdAccess) { - return 0, ErrOutOfGas + return GasCosts{}, ErrOutOfGas } } @@ -305,13 +307,13 @@ func makeCallVariantGasCallEIP7702(intrinsicFunc gasFunc) gasFunc { // - create new account intrinsicCost, err := intrinsicFunc(evm, contract, stack, mem, memorySize) if err != nil { - return 0, err + return GasCosts{}, err } // Terminate the gas measurement if 
the leftover gas is not sufficient, // it can effectively prevent accessing the states in the following steps. // It's an essential safeguard before any stateful check. - if contract.Gas < intrinsicCost { - return 0, ErrOutOfGas + if contract.Gas.RegularGas < intrinsicCost.RegularGas { + return GasCosts{}, ErrOutOfGas } // Check if code is a delegation and if so, charge for resolution. @@ -323,20 +325,20 @@ func makeCallVariantGasCallEIP7702(intrinsicFunc gasFunc) gasFunc { eip7702Cost = params.ColdAccountAccessCostEIP2929 } if !contract.UseGas(eip7702Cost, evm.Config.Tracer, tracing.GasChangeCallStorageColdAccess) { - return 0, ErrOutOfGas + return GasCosts{}, ErrOutOfGas } } // Calculate the gas budget for the nested call. The costs defined by // EIP-2929 and EIP-7702 have already been applied. - evm.callGasTemp, err = callGas(evm.chainRules.IsEIP150, contract.Gas, intrinsicCost, stack.Back(0)) + evm.callGasTemp, err = callGas(evm.chainRules.IsEIP150, contract.Gas.RegularGas, intrinsicCost.RegularGas, stack.Back(0)) if err != nil { - return 0, err + return GasCosts{}, err } // Temporarily add the gas charge back to the contract and return value. By // adding it to the return, it will be charged outside of this function, as // part of the dynamic gas. This will ensure it is correctly reported to // tracers. - contract.Gas += eip2929Cost + eip7702Cost + contract.Gas.RegularGas += eip2929Cost + eip7702Cost // Aggregate the gas costs from all components, including EIP-2929, EIP-7702, // the CALL opcode itself, and the cost incurred by nested calls. 
@@ -345,14 +347,14 @@ func makeCallVariantGasCallEIP7702(intrinsicFunc gasFunc) gasFunc { totalCost uint64 ) if totalCost, overflow = math.SafeAdd(eip2929Cost, eip7702Cost); overflow { - return 0, ErrGasUintOverflow + return GasCosts{}, ErrGasUintOverflow } - if totalCost, overflow = math.SafeAdd(totalCost, intrinsicCost); overflow { - return 0, ErrGasUintOverflow + if totalCost, overflow = math.SafeAdd(totalCost, intrinsicCost.RegularGas); overflow { + return GasCosts{}, ErrGasUintOverflow } if totalCost, overflow = math.SafeAdd(totalCost, evm.callGasTemp); overflow { - return 0, ErrGasUintOverflow + return GasCosts{}, ErrGasUintOverflow } - return totalCost, nil + return GasCosts{RegularGas: totalCost}, nil } } diff --git a/core/vm/operations_verkle.go b/core/vm/operations_verkle.go index 30f9957775..d57f2c4dcf 100644 --- a/core/vm/operations_verkle.go +++ b/core/vm/operations_verkle.go @@ -24,37 +24,37 @@ import ( "github.com/ethereum/go-ethereum/params" ) -func gasSStore4762(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { - return evm.AccessEvents.SlotGas(contract.Address(), stack.peek().Bytes32(), true, contract.Gas, true), nil +func gasSStore4762(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (GasCosts, error) { + return GasCosts{RegularGas: evm.AccessEvents.SlotGas(contract.Address(), stack.peek().Bytes32(), true, contract.Gas.RegularGas, true)}, nil } -func gasSLoad4762(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { - return evm.AccessEvents.SlotGas(contract.Address(), stack.peek().Bytes32(), false, contract.Gas, true), nil +func gasSLoad4762(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (GasCosts, error) { + return GasCosts{RegularGas: evm.AccessEvents.SlotGas(contract.Address(), stack.peek().Bytes32(), false, contract.Gas.RegularGas, true)}, nil } -func gasBalance4762(evm *EVM, contract *Contract, stack 
*Stack, mem *Memory, memorySize uint64) (uint64, error) { +func gasBalance4762(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (GasCosts, error) { address := stack.peek().Bytes20() - return evm.AccessEvents.BasicDataGas(address, false, contract.Gas, true), nil + return GasCosts{RegularGas: evm.AccessEvents.BasicDataGas(address, false, contract.Gas.RegularGas, true)}, nil } -func gasExtCodeSize4762(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { +func gasExtCodeSize4762(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (GasCosts, error) { address := stack.peek().Bytes20() if _, isPrecompile := evm.precompile(address); isPrecompile { - return 0, nil + return GasCosts{}, nil } - return evm.AccessEvents.BasicDataGas(address, false, contract.Gas, true), nil + return GasCosts{RegularGas: evm.AccessEvents.BasicDataGas(address, false, contract.Gas.RegularGas, true)}, nil } -func gasExtCodeHash4762(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { +func gasExtCodeHash4762(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (GasCosts, error) { address := stack.peek().Bytes20() if _, isPrecompile := evm.precompile(address); isPrecompile { - return 0, nil + return GasCosts{}, nil } - return evm.AccessEvents.CodeHashGas(address, false, contract.Gas, true), nil + return GasCosts{RegularGas: evm.AccessEvents.CodeHashGas(address, false, contract.Gas.RegularGas, true)}, nil } func makeCallVariantGasEIP4762(oldCalculator gasFunc, withTransferCosts bool) gasFunc { - return func(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { + return func(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (GasCosts, error) { var ( target = common.Address(stack.Back(1).Bytes20()) witnessGas uint64 @@ -65,9 +65,9 @@ func makeCallVariantGasEIP4762(oldCalculator 
gasFunc, withTransferCosts bool) ga // If value is transferred, it is charged before 1/64th // is subtracted from the available gas pool. if withTransferCosts && !stack.Back(2).IsZero() { - wantedValueTransferWitnessGas := evm.AccessEvents.ValueTransferGas(contract.Address(), target, contract.Gas) - if wantedValueTransferWitnessGas > contract.Gas { - return wantedValueTransferWitnessGas, nil + wantedValueTransferWitnessGas := evm.AccessEvents.ValueTransferGas(contract.Address(), target, contract.Gas.RegularGas) + if wantedValueTransferWitnessGas > contract.Gas.RegularGas { + return GasCosts{RegularGas: wantedValueTransferWitnessGas}, nil } witnessGas = wantedValueTransferWitnessGas } else if isPrecompile || isSystemContract { @@ -78,25 +78,26 @@ func makeCallVariantGasEIP4762(oldCalculator gasFunc, withTransferCosts bool) ga // (so before we get to this point) // But the message call is part of the subcall, for which only 63/64th // of the gas should be available. - wantedMessageCallWitnessGas := evm.AccessEvents.MessageCallGas(target, contract.Gas-witnessGas) + wantedMessageCallWitnessGas := evm.AccessEvents.MessageCallGas(target, contract.Gas.RegularGas-witnessGas) var overflow bool if witnessGas, overflow = math.SafeAdd(witnessGas, wantedMessageCallWitnessGas); overflow { - return 0, ErrGasUintOverflow + return GasCosts{}, ErrGasUintOverflow } - if witnessGas > contract.Gas { - return witnessGas, nil + if witnessGas > contract.Gas.RegularGas { + return GasCosts{RegularGas: witnessGas}, nil } } - contract.Gas -= witnessGas + contract.Gas.RegularGas -= witnessGas // if the operation fails, adds witness gas to the gas before returning the error - gas, err := oldCalculator(evm, contract, stack, mem, memorySize) - contract.Gas += witnessGas // restore witness gas so that it can be charged at the callsite + gasCost, err := oldCalculator(evm, contract, stack, mem, memorySize) + contract.Gas.RegularGas += witnessGas // restore witness gas so that it can be charged at 
the callsite + gas := gasCost.RegularGas var overflow bool if gas, overflow = math.SafeAdd(gas, witnessGas); overflow { - return 0, ErrGasUintOverflow + return GasCosts{}, ErrGasUintOverflow } - return gas, err + return GasCosts{RegularGas: gas}, err } } @@ -107,18 +108,18 @@ var ( gasDelegateCallEIP4762 = makeCallVariantGasEIP4762(gasDelegateCall, false) ) -func gasSelfdestructEIP4762(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { +func gasSelfdestructEIP4762(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (GasCosts, error) { beneficiaryAddr := common.Address(stack.peek().Bytes20()) if _, isPrecompile := evm.precompile(beneficiaryAddr); isPrecompile { - return 0, nil + return GasCosts{}, nil } if contract.IsSystemCall { - return 0, nil + return GasCosts{}, nil } contractAddr := contract.Address() - wanted := evm.AccessEvents.BasicDataGas(contractAddr, false, contract.Gas, false) - if wanted > contract.Gas { - return wanted, nil + wanted := evm.AccessEvents.BasicDataGas(contractAddr, false, contract.Gas.RegularGas, false) + if wanted > contract.Gas.RegularGas { + return GasCosts{RegularGas: wanted}, nil } statelessGas := wanted balanceIsZero := evm.StateDB.GetBalance(contractAddr).Sign() == 0 @@ -126,44 +127,45 @@ func gasSelfdestructEIP4762(evm *EVM, contract *Contract, stack *Stack, mem *Mem isSystemContract := beneficiaryAddr == params.HistoryStorageAddress if (isPrecompile || isSystemContract) && balanceIsZero { - return statelessGas, nil + return GasCosts{RegularGas: statelessGas}, nil } if contractAddr != beneficiaryAddr { - wanted := evm.AccessEvents.BasicDataGas(beneficiaryAddr, false, contract.Gas-statelessGas, false) - if wanted > contract.Gas-statelessGas { - return statelessGas + wanted, nil + wanted := evm.AccessEvents.BasicDataGas(beneficiaryAddr, false, contract.Gas.RegularGas-statelessGas, false) + if wanted > contract.Gas.RegularGas-statelessGas { + return 
GasCosts{RegularGas: statelessGas + wanted}, nil } statelessGas += wanted } // Charge write costs if it transfers value if !balanceIsZero { - wanted := evm.AccessEvents.BasicDataGas(contractAddr, true, contract.Gas-statelessGas, false) - if wanted > contract.Gas-statelessGas { - return statelessGas + wanted, nil + wanted := evm.AccessEvents.BasicDataGas(contractAddr, true, contract.Gas.RegularGas-statelessGas, false) + if wanted > contract.Gas.RegularGas-statelessGas { + return GasCosts{RegularGas: statelessGas + wanted}, nil } statelessGas += wanted if contractAddr != beneficiaryAddr { if evm.StateDB.Exist(beneficiaryAddr) { - wanted = evm.AccessEvents.BasicDataGas(beneficiaryAddr, true, contract.Gas-statelessGas, false) + wanted = evm.AccessEvents.BasicDataGas(beneficiaryAddr, true, contract.Gas.RegularGas-statelessGas, false) } else { - wanted = evm.AccessEvents.AddAccount(beneficiaryAddr, true, contract.Gas-statelessGas) + wanted = evm.AccessEvents.AddAccount(beneficiaryAddr, true, contract.Gas.RegularGas-statelessGas) } - if wanted > contract.Gas-statelessGas { - return statelessGas + wanted, nil + if wanted > contract.Gas.RegularGas-statelessGas { + return GasCosts{RegularGas: statelessGas + wanted}, nil } statelessGas += wanted } } - return statelessGas, nil + return GasCosts{RegularGas: statelessGas}, nil } -func gasCodeCopyEip4762(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { - gas, err := gasCodeCopy(evm, contract, stack, mem, memorySize) +func gasCodeCopyEip4762(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (GasCosts, error) { + gasCost, err := gasCodeCopy(evm, contract, stack, mem, memorySize) if err != nil { - return 0, err + return GasCosts{}, err } + gas := gasCost.RegularGas if !contract.IsDeployment && !contract.IsSystemCall { var ( codeOffset = stack.Back(1) @@ -175,31 +177,32 @@ func gasCodeCopyEip4762(evm *EVM, contract *Contract, stack *Stack, mem *Memory, } _, 
copyOffset, nonPaddedCopyLength := getDataAndAdjustedBounds(contract.Code, uint64CodeOffset, length.Uint64()) - _, wanted := evm.AccessEvents.CodeChunksRangeGas(contract.Address(), copyOffset, nonPaddedCopyLength, uint64(len(contract.Code)), false, contract.Gas-gas) + _, wanted := evm.AccessEvents.CodeChunksRangeGas(contract.Address(), copyOffset, nonPaddedCopyLength, uint64(len(contract.Code)), false, contract.Gas.RegularGas-gas) gas += wanted } - return gas, nil + return GasCosts{RegularGas: gas}, nil } -func gasExtCodeCopyEIP4762(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { +func gasExtCodeCopyEIP4762(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (GasCosts, error) { // memory expansion first (dynamic part of pre-2929 implementation) - gas, err := gasExtCodeCopy(evm, contract, stack, mem, memorySize) + gasCost, err := gasExtCodeCopy(evm, contract, stack, mem, memorySize) if err != nil { - return 0, err + return GasCosts{}, err } + gas := gasCost.RegularGas addr := common.Address(stack.peek().Bytes20()) _, isPrecompile := evm.precompile(addr) if isPrecompile || addr == params.HistoryStorageAddress { var overflow bool if gas, overflow = math.SafeAdd(gas, params.WarmStorageReadCostEIP2929); overflow { - return 0, ErrGasUintOverflow + return GasCosts{}, ErrGasUintOverflow } - return gas, nil + return GasCosts{RegularGas: gas}, nil } - wgas := evm.AccessEvents.BasicDataGas(addr, false, contract.Gas-gas, true) + wgas := evm.AccessEvents.BasicDataGas(addr, false, contract.Gas.RegularGas-gas, true) var overflow bool if gas, overflow = math.SafeAdd(gas, wgas); overflow { - return 0, ErrGasUintOverflow + return GasCosts{}, ErrGasUintOverflow } - return gas, nil + return GasCosts{RegularGas: gas}, nil } diff --git a/eth/api_backend.go b/eth/api_backend.go index 726d8316a0..a4e976b1b8 100644 --- a/eth/api_backend.go +++ b/eth/api_backend.go @@ -414,9 +414,10 @@ func (b *EthAPIBackend) 
SyncProgress(ctx context.Context) ethereum.SyncProgress prog.TxIndexFinishedBlocks = txProg.Indexed prog.TxIndexRemainingBlocks = txProg.Remaining } - remain, err := b.eth.blockchain.StateIndexProgress() + stateRemain, trienodeRemain, err := b.eth.blockchain.StateIndexProgress() if err == nil { - prog.StateIndexRemaining = remain + prog.StateIndexRemaining = stateRemain + prog.TrienodeIndexRemaining = trienodeRemain } return prog } @@ -484,12 +485,12 @@ func (b *EthAPIBackend) CurrentHeader() *types.Header { return b.eth.blockchain.CurrentHeader() } -func (b *EthAPIBackend) StateAtBlock(ctx context.Context, block *types.Block, reexec uint64, base *state.StateDB, readOnly bool, preferDisk bool) (*state.StateDB, tracers.StateReleaseFunc, error) { - return b.eth.stateAtBlock(ctx, block, reexec, base, readOnly, preferDisk) +func (b *EthAPIBackend) StateAtBlock(ctx context.Context, block *types.Block, base *state.StateDB, readOnly bool, preferDisk bool) (*state.StateDB, tracers.StateReleaseFunc, error) { + return b.eth.stateAtBlock(ctx, block, base, readOnly, preferDisk) } -func (b *EthAPIBackend) StateAtTransaction(ctx context.Context, block *types.Block, txIndex int, reexec uint64) (*types.Transaction, vm.BlockContext, *state.StateDB, tracers.StateReleaseFunc, error) { - return b.eth.stateAtTransaction(ctx, block, txIndex, reexec) +func (b *EthAPIBackend) StateAtTransaction(ctx context.Context, block *types.Block, txIndex int) (*types.Transaction, vm.BlockContext, *state.StateDB, tracers.StateReleaseFunc, error) { + return b.eth.stateAtTransaction(ctx, block, txIndex) } func (b *EthAPIBackend) RPCTxSyncDefaultTimeout() time.Duration { diff --git a/eth/api_debug.go b/eth/api_debug.go index b8267902b2..5dd535e672 100644 --- a/eth/api_debug.go +++ b/eth/api_debug.go @@ -222,7 +222,7 @@ func (api *DebugAPI) StorageRangeAt(ctx context.Context, blockNrOrHash rpc.Block if block == nil { return StorageRangeResult{}, fmt.Errorf("block %v not found", blockNrOrHash) } - _, _, 
statedb, release, err := api.eth.stateAtTransaction(ctx, block, txIndex, 0) + _, _, statedb, release, err := api.eth.stateAtTransaction(ctx, block, txIndex) if err != nil { return StorageRangeResult{}, err } @@ -236,6 +236,8 @@ func storageRangeAt(statedb *state.StateDB, root common.Hash, address common.Add if storageRoot == types.EmptyRootHash || storageRoot == (common.Hash{}) { return StorageRangeResult{}, nil // empty storage } + // TODO(rjl493456442) it's problematic for traversing the state with in-memory + // state mutations, specifically txIndex != 0. id := trie.StorageTrieID(root, crypto.Keccak256Hash(address.Bytes()), storageRoot) tr, err := trie.NewStateTrie(id, statedb.Database().TrieDB()) if err != nil { diff --git a/eth/downloader/api.go b/eth/downloader/api.go index f97371de5f..1fea35775e 100644 --- a/eth/downloader/api.go +++ b/eth/downloader/api.go @@ -81,9 +81,10 @@ func (api *DownloaderAPI) eventLoop() { prog.TxIndexFinishedBlocks = txProg.Indexed prog.TxIndexRemainingBlocks = txProg.Remaining } - remain, err := api.chain.StateIndexProgress() + stateRemain, trienodeRemain, err := api.chain.StateIndexProgress() if err == nil { - prog.StateIndexRemaining = remain + prog.StateIndexRemaining = stateRemain + prog.TrienodeIndexRemaining = trienodeRemain } return prog } diff --git a/eth/downloader/downloader_test.go b/eth/downloader/downloader_test.go index 01a994dbfd..9280d455fb 100644 --- a/eth/downloader/downloader_test.go +++ b/eth/downloader/downloader_test.go @@ -370,7 +370,7 @@ func (dlp *downloadTesterPeer) RequestTrieNodes(id uint64, root common.Hash, cou Paths: encPaths, Bytes: uint64(bytes), } - nodes, _ := snap.ServiceGetTrieNodesQuery(dlp.chain, req, time.Now()) + nodes, _ := snap.ServiceGetTrieNodesQuery(dlp.chain, req) go dlp.dl.downloader.SnapSyncer.OnTrieNodes(dlp, id, nodes) return nil } diff --git a/eth/filters/filter.go b/eth/filters/filter.go index 04e11f0475..f31b9568cd 100644 --- a/eth/filters/filter.go +++ b/eth/filters/filter.go 
@@ -19,7 +19,6 @@ package filters import ( "context" "errors" - "fmt" "math" "math/big" "slices" @@ -147,7 +146,7 @@ func (f *Filter) Logs(ctx context.Context) ([]*types.Log, error) { return nil, err } if f.rangeLimit != 0 && (end-begin) > f.rangeLimit { - return nil, fmt.Errorf("exceed maximum block range: %d", f.rangeLimit) + return nil, invalidParamsErr("exceed maximum block range %d", f.rangeLimit) } return f.rangeLogs(ctx, begin, end) } diff --git a/eth/filters/filter_test.go b/eth/filters/filter_test.go index e7b1b08046..c133438c64 100644 --- a/eth/filters/filter_test.go +++ b/eth/filters/filter_test.go @@ -19,6 +19,7 @@ package filters import ( "context" "encoding/json" + "errors" "math/big" "strings" "testing" @@ -634,7 +635,19 @@ func TestRangeLimit(t *testing.T) { // Set rangeLimit to 5, but request a range of 9 (end - begin = 9, from 0 to 9) filter := sys.NewRangeFilter(0, 9, nil, nil, 5) _, err = filter.Logs(context.Background()) - if err == nil || !strings.Contains(err.Error(), "exceed maximum block range") { - t.Fatalf("expected range limit error, got %v", err) + if err == nil { + t.Fatal("expected range limit error, got nil") + } + + var re rpc.Error + if errors.As(err, &re) { + if re.ErrorCode() != -32602 { + t.Fatalf("expected error code -32602, got %d", re.ErrorCode()) + } + if re.Error() != "exceed maximum block range 5" { + t.Fatalf("expected error message 'exceed maximum block range 5', got %q", re.Error()) + } + } else { + t.Fatalf("expected rpc error, got %v", err) } } diff --git a/eth/gasestimator/gasestimator.go b/eth/gasestimator/gasestimator.go index 80aeb3d3b2..ad6491fd93 100644 --- a/eth/gasestimator/gasestimator.go +++ b/eth/gasestimator/gasestimator.go @@ -27,7 +27,6 @@ import ( "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/vm" - "github.com/ethereum/go-ethereum/internal/ethapi/override" "github.com/ethereum/go-ethereum/log" 
"github.com/ethereum/go-ethereum/params" ) @@ -38,11 +37,12 @@ import ( // these together, it would be excessively hard to test. Splitting the parts out // allows testing without needing a proper live chain. type Options struct { - Config *params.ChainConfig // Chain configuration for hard fork selection - Chain core.ChainContext // Chain context to access past block hashes - Header *types.Header // Header defining the block context to execute in - State *state.StateDB // Pre-state on top of which to estimate the gas - BlockOverrides *override.BlockOverrides // Block overrides to apply during the estimation + Config *params.ChainConfig // Chain configuration for hard fork selection + Chain core.ChainContext // Chain context to access past block hashes + Header *types.Header // Header defining the block context to execute in + State *state.StateDB // Pre-state on top of which to estimate the gas + + BlobBaseFee *big.Int // BlobBaseFee optionally overrides the blob base fee in the execution context. 
ErrorRatio float64 // Allowed overestimation ratio for faster estimation termination } @@ -64,16 +64,7 @@ func Estimate(ctx context.Context, call *core.Message, opts *Options, gasCap uin // Cap the maximum gas allowance according to EIP-7825 if the estimation targets Osaka if hi > params.MaxTxGas { - blockNumber, blockTime := opts.Header.Number, opts.Header.Time - if opts.BlockOverrides != nil { - if opts.BlockOverrides.Number != nil { - blockNumber = opts.BlockOverrides.Number.ToInt() - } - if opts.BlockOverrides.Time != nil { - blockTime = uint64(*opts.BlockOverrides.Time) - } - } - if opts.Config.IsOsaka(blockNumber, blockTime) { + if opts.Config.IsOsaka(opts.Header.Number, opts.Header.Time) { hi = params.MaxTxGas } } @@ -241,10 +232,8 @@ func run(ctx context.Context, call *core.Message, opts *Options) (*core.Executio evmContext = core.NewEVMBlockContext(opts.Header, opts.Chain, nil) dirtyState = opts.State.Copy() ) - if opts.BlockOverrides != nil { - if err := opts.BlockOverrides.Apply(&evmContext); err != nil { - return nil, err - } + if opts.BlobBaseFee != nil { + evmContext.BlobBaseFee = new(big.Int).Set(opts.BlobBaseFee) } // Lower the basefee to 0 to avoid breaking EVM // invariants (basefee < feecap). 
diff --git a/eth/protocols/eth/handler.go b/eth/protocols/eth/handler.go index 59512f5be7..f7d25bd8ca 100644 --- a/eth/protocols/eth/handler.go +++ b/eth/protocols/eth/handler.go @@ -167,7 +167,6 @@ func Handle(backend Backend, peer *Peer) error { type msgHandler func(backend Backend, msg Decoder, peer *Peer) error type Decoder interface { Decode(val interface{}) error - Time() time.Time } var eth69 = map[uint64]msgHandler{ diff --git a/eth/protocols/snap/handler.go b/eth/protocols/snap/handler.go index 071a0419fb..26545f2960 100644 --- a/eth/protocols/snap/handler.go +++ b/eth/protocols/snap/handler.go @@ -17,25 +17,14 @@ package snap import ( - "bytes" "fmt" "time" - "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" - "github.com/ethereum/go-ethereum/core/rawdb" - "github.com/ethereum/go-ethereum/core/state/snapshot" - "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enr" - "github.com/ethereum/go-ethereum/p2p/tracker" - "github.com/ethereum/go-ethereum/rlp" - "github.com/ethereum/go-ethereum/trie" - "github.com/ethereum/go-ethereum/trie/trienode" - "github.com/ethereum/go-ethereum/triedb/database" ) const ( @@ -55,6 +44,10 @@ const ( // number is there to limit the number of disk lookups. maxTrieNodeLookups = 1024 + // maxAccessListLookups is the maximum number of BALs to serve. This number + // is there to limit the number of disk lookups. + maxAccessListLookups = 1024 + // maxTrieNodeTimeSpent is the maximum time we should spend on looking up trie nodes. // If we spend too much time, then it's a fairly high chance of timing out // at the remote side, which means all the work is in vain. 
@@ -123,6 +116,34 @@ func Handle(backend Backend, peer *Peer) error { } } +type msgHandler func(backend Backend, msg Decoder, peer *Peer) error +type Decoder interface { + Decode(val interface{}) error +} + +var snap1 = map[uint64]msgHandler{ + GetAccountRangeMsg: handleGetAccountRange, + AccountRangeMsg: handleAccountRange, + GetStorageRangesMsg: handleGetStorageRanges, + StorageRangesMsg: handleStorageRanges, + GetByteCodesMsg: handleGetByteCodes, + ByteCodesMsg: handleByteCodes, + GetTrieNodesMsg: handleGetTrienodes, + TrieNodesMsg: handleTrieNodes, +} + +// nolint:unused +var snap2 = map[uint64]msgHandler{ + GetAccountRangeMsg: handleGetAccountRange, + AccountRangeMsg: handleAccountRange, + GetStorageRangesMsg: handleGetStorageRanges, + StorageRangesMsg: handleStorageRanges, + GetByteCodesMsg: handleGetByteCodes, + ByteCodesMsg: handleByteCodes, + GetAccessListsMsg: handleGetAccessLists, + // AccessListsMsg: TODO +} + // HandleMessage is invoked whenever an inbound message is received from a // remote peer on the `snap` protocol. The remote connection is torn down upon // returning any error. 
@@ -136,8 +157,19 @@ func HandleMessage(backend Backend, peer *Peer) error { return fmt.Errorf("%w: %v > %v", errMsgTooLarge, msg.Size, maxMessageSize) } defer msg.Discard() - start := time.Now() + + var handlers map[uint64]msgHandler + switch peer.version { + case SNAP1: + handlers = snap1 + //case SNAP2: + // handlers = snap2 + default: + return fmt.Errorf("unknown snap protocol version: %v", peer.version) + } + // Track the amount of time it takes to serve the request and run the handler + start := time.Now() if metrics.Enabled() { h := fmt.Sprintf("%s/%s/%d/%#02x", p2p.HandleHistName, ProtocolName, peer.Version(), msg.Code) defer func(start time.Time) { @@ -149,520 +181,11 @@ func HandleMessage(backend Backend, peer *Peer) error { metrics.GetOrRegisterHistogramLazy(h, nil, sampler).Update(time.Since(start).Microseconds()) }(start) } - // Handle the message depending on its contents - switch { - case msg.Code == GetAccountRangeMsg: - var req GetAccountRangePacket - if err := msg.Decode(&req); err != nil { - return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) - } - // Service the request, potentially returning nothing in case of errors - accounts, proofs := ServiceGetAccountRangeQuery(backend.Chain(), &req) - // Send back anything accumulated (or empty in case of errors) - return p2p.Send(peer.rw, AccountRangeMsg, &AccountRangePacket{ - ID: req.ID, - Accounts: accounts, - Proof: proofs, - }) - - case msg.Code == AccountRangeMsg: - res := new(accountRangeInput) - if err := msg.Decode(res); err != nil { - return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) - } - - // Check response validity. - if len := res.Proof.Len(); len > 128 { - return fmt.Errorf("AccountRange: invalid proof (length %d)", len) - } - tresp := tracker.Response{ID: res.ID, MsgCode: AccountRangeMsg, Size: len(res.Accounts.Content())} - if err := peer.tracker.Fulfil(tresp); err != nil { - return err - } - - // Decode. 
- accounts, err := res.Accounts.Items() - if err != nil { - return fmt.Errorf("AccountRange: invalid accounts list: %v", err) - } - proof, err := res.Proof.Items() - if err != nil { - return fmt.Errorf("AccountRange: invalid proof: %v", err) - } - - // Ensure the range is monotonically increasing - for i := 1; i < len(accounts); i++ { - if bytes.Compare(accounts[i-1].Hash[:], accounts[i].Hash[:]) >= 0 { - return fmt.Errorf("accounts not monotonically increasing: #%d [%x] vs #%d [%x]", i-1, accounts[i-1].Hash[:], i, accounts[i].Hash[:]) - } - } - - return backend.Handle(peer, &AccountRangePacket{res.ID, accounts, proof}) - - case msg.Code == GetStorageRangesMsg: - var req GetStorageRangesPacket - if err := msg.Decode(&req); err != nil { - return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) - } - // Service the request, potentially returning nothing in case of errors - slots, proofs := ServiceGetStorageRangesQuery(backend.Chain(), &req) - - // Send back anything accumulated (or empty in case of errors) - return p2p.Send(peer.rw, StorageRangesMsg, &StorageRangesPacket{ - ID: req.ID, - Slots: slots, - Proof: proofs, - }) - - case msg.Code == StorageRangesMsg: - res := new(storageRangesInput) - if err := msg.Decode(res); err != nil { - return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) - } - - // Check response validity. - if len := res.Proof.Len(); len > 128 { - return fmt.Errorf("StorageRangesMsg: invalid proof (length %d)", len) - } - tresp := tracker.Response{ID: res.ID, MsgCode: StorageRangesMsg, Size: len(res.Slots.Content())} - if err := peer.tracker.Fulfil(tresp); err != nil { - return fmt.Errorf("StorageRangesMsg: %w", err) - } - - // Decode. 
- slotLists, err := res.Slots.Items() - if err != nil { - return fmt.Errorf("AccountRange: invalid accounts list: %v", err) - } - proof, err := res.Proof.Items() - if err != nil { - return fmt.Errorf("AccountRange: invalid proof: %v", err) - } - - // Ensure the ranges are monotonically increasing - for i, slots := range slotLists { - for j := 1; j < len(slots); j++ { - if bytes.Compare(slots[j-1].Hash[:], slots[j].Hash[:]) >= 0 { - return fmt.Errorf("storage slots not monotonically increasing for account #%d: #%d [%x] vs #%d [%x]", i, j-1, slots[j-1].Hash[:], j, slots[j].Hash[:]) - } - } - } - - return backend.Handle(peer, &StorageRangesPacket{res.ID, slotLists, proof}) - - case msg.Code == GetByteCodesMsg: - var req GetByteCodesPacket - if err := msg.Decode(&req); err != nil { - return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) - } - // Service the request, potentially returning nothing in case of errors - codes := ServiceGetByteCodesQuery(backend.Chain(), &req) - - // Send back anything accumulated (or empty in case of errors) - return p2p.Send(peer.rw, ByteCodesMsg, &ByteCodesPacket{ - ID: req.ID, - Codes: codes, - }) - - case msg.Code == ByteCodesMsg: - res := new(byteCodesInput) - if err := msg.Decode(res); err != nil { - return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) - } - - length := res.Codes.Len() - tresp := tracker.Response{ID: res.ID, MsgCode: ByteCodesMsg, Size: length} - if err := peer.tracker.Fulfil(tresp); err != nil { - return fmt.Errorf("ByteCodes: %w", err) - } - - codes, err := res.Codes.Items() - if err != nil { - return fmt.Errorf("ByteCodes: %w", err) - } - - return backend.Handle(peer, &ByteCodesPacket{res.ID, codes}) - - case msg.Code == GetTrieNodesMsg: - var req GetTrieNodesPacket - if err := msg.Decode(&req); err != nil { - return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) - } - // Service the request, potentially returning nothing in case of errors - nodes, err := 
ServiceGetTrieNodesQuery(backend.Chain(), &req, start) - if err != nil { - return err - } - // Send back anything accumulated (or empty in case of errors) - return p2p.Send(peer.rw, TrieNodesMsg, &TrieNodesPacket{ - ID: req.ID, - Nodes: nodes, - }) - - case msg.Code == TrieNodesMsg: - res := new(trieNodesInput) - if err := msg.Decode(res); err != nil { - return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) - } - - tresp := tracker.Response{ID: res.ID, MsgCode: TrieNodesMsg, Size: res.Nodes.Len()} - if err := peer.tracker.Fulfil(tresp); err != nil { - return fmt.Errorf("TrieNodes: %w", err) - } - nodes, err := res.Nodes.Items() - if err != nil { - return fmt.Errorf("TrieNodes: %w", err) - } - - return backend.Handle(peer, &TrieNodesPacket{res.ID, nodes}) - - default: - return fmt.Errorf("%w: %v", errInvalidMsgCode, msg.Code) + if handler := handlers[msg.Code]; handler != nil { + return handler(backend, msg, peer) } -} - -// ServiceGetAccountRangeQuery assembles the response to an account range query. -// It is exposed to allow external packages to test protocol behavior. -func ServiceGetAccountRangeQuery(chain *core.BlockChain, req *GetAccountRangePacket) ([]*AccountData, [][]byte) { - if req.Bytes > softResponseLimit { - req.Bytes = softResponseLimit - } - // Retrieve the requested state and bail out if non existent - tr, err := trie.New(trie.StateTrieID(req.Root), chain.TrieDB()) - if err != nil { - return nil, nil - } - // Temporary solution: using the snapshot interface for both cases. - // This can be removed once the hash scheme is deprecated. - var it snapshot.AccountIterator - if chain.TrieDB().Scheme() == rawdb.HashScheme { - // The snapshot is assumed to be available in hash mode if - // the SNAP protocol is enabled. 
- it, err = chain.Snapshots().AccountIterator(req.Root, req.Origin) - } else { - it, err = chain.TrieDB().AccountIterator(req.Root, req.Origin) - } - if err != nil { - return nil, nil - } - // Iterate over the requested range and pile accounts up - var ( - accounts []*AccountData - size uint64 - last common.Hash - ) - for it.Next() { - hash, account := it.Hash(), common.CopyBytes(it.Account()) - - // Track the returned interval for the Merkle proofs - last = hash - - // Assemble the reply item - size += uint64(common.HashLength + len(account)) - accounts = append(accounts, &AccountData{ - Hash: hash, - Body: account, - }) - // If we've exceeded the request threshold, abort - if bytes.Compare(hash[:], req.Limit[:]) >= 0 { - break - } - if size > req.Bytes { - break - } - } - it.Release() - - // Generate the Merkle proofs for the first and last account - proof := trienode.NewProofSet() - if err := tr.Prove(req.Origin[:], proof); err != nil { - log.Warn("Failed to prove account range", "origin", req.Origin, "err", err) - return nil, nil - } - if last != (common.Hash{}) { - if err := tr.Prove(last[:], proof); err != nil { - log.Warn("Failed to prove account range", "last", last, "err", err) - return nil, nil - } - } - return accounts, proof.List() -} - -func ServiceGetStorageRangesQuery(chain *core.BlockChain, req *GetStorageRangesPacket) ([][]*StorageData, [][]byte) { - if req.Bytes > softResponseLimit { - req.Bytes = softResponseLimit - } - // TODO(karalabe): Do we want to enforce > 0 accounts and 1 account if origin is set? 
- // TODO(karalabe): - Logging locally is not ideal as remote faults annoy the local user - // TODO(karalabe): - Dropping the remote peer is less flexible wrt client bugs (slow is better than non-functional) - - // Calculate the hard limit at which to abort, even if mid storage trie - hardLimit := uint64(float64(req.Bytes) * (1 + stateLookupSlack)) - - // Retrieve storage ranges until the packet limit is reached - var ( - slots [][]*StorageData - proofs [][]byte - size uint64 - ) - for _, account := range req.Accounts { - // If we've exceeded the requested data limit, abort without opening - // a new storage range (that we'd need to prove due to exceeded size) - if size >= req.Bytes { - break - } - // The first account might start from a different origin and end sooner - var origin common.Hash - if len(req.Origin) > 0 { - origin, req.Origin = common.BytesToHash(req.Origin), nil - } - var limit = common.MaxHash - if len(req.Limit) > 0 { - limit, req.Limit = common.BytesToHash(req.Limit), nil - } - // Retrieve the requested state and bail out if non existent - var ( - err error - it snapshot.StorageIterator - ) - // Temporary solution: using the snapshot interface for both cases. - // This can be removed once the hash scheme is deprecated. - if chain.TrieDB().Scheme() == rawdb.HashScheme { - // The snapshot is assumed to be available in hash mode if - // the SNAP protocol is enabled. 
- it, err = chain.Snapshots().StorageIterator(req.Root, account, origin) - } else { - it, err = chain.TrieDB().StorageIterator(req.Root, account, origin) - } - if err != nil { - return nil, nil - } - // Iterate over the requested range and pile slots up - var ( - storage []*StorageData - last common.Hash - abort bool - ) - for it.Next() { - if size >= hardLimit { - abort = true - break - } - hash, slot := it.Hash(), common.CopyBytes(it.Slot()) - - // Track the returned interval for the Merkle proofs - last = hash - - // Assemble the reply item - size += uint64(common.HashLength + len(slot)) - storage = append(storage, &StorageData{ - Hash: hash, - Body: slot, - }) - // If we've exceeded the request threshold, abort - if bytes.Compare(hash[:], limit[:]) >= 0 { - break - } - } - if len(storage) > 0 { - slots = append(slots, storage) - } - it.Release() - - // Generate the Merkle proofs for the first and last storage slot, but - // only if the response was capped. If the entire storage trie included - // in the response, no need for any proofs. 
- if origin != (common.Hash{}) || (abort && len(storage) > 0) { - // Request started at a non-zero hash or was capped prematurely, add - // the endpoint Merkle proofs - accTrie, err := trie.NewStateTrie(trie.StateTrieID(req.Root), chain.TrieDB()) - if err != nil { - return nil, nil - } - acc, err := accTrie.GetAccountByHash(account) - if err != nil || acc == nil { - return nil, nil - } - id := trie.StorageTrieID(req.Root, account, acc.Root) - stTrie, err := trie.NewStateTrie(id, chain.TrieDB()) - if err != nil { - return nil, nil - } - proof := trienode.NewProofSet() - if err := stTrie.Prove(origin[:], proof); err != nil { - log.Warn("Failed to prove storage range", "origin", req.Origin, "err", err) - return nil, nil - } - if last != (common.Hash{}) { - if err := stTrie.Prove(last[:], proof); err != nil { - log.Warn("Failed to prove storage range", "last", last, "err", err) - return nil, nil - } - } - proofs = append(proofs, proof.List()...) - // Proof terminates the reply as proofs are only added if a node - // refuses to serve more data (exception when a contract fetch is - // finishing, but that's that). - break - } - } - return slots, proofs -} - -// ServiceGetByteCodesQuery assembles the response to a byte codes query. -// It is exposed to allow external packages to test protocol behavior. 
-func ServiceGetByteCodesQuery(chain *core.BlockChain, req *GetByteCodesPacket) [][]byte { - if req.Bytes > softResponseLimit { - req.Bytes = softResponseLimit - } - if len(req.Hashes) > maxCodeLookups { - req.Hashes = req.Hashes[:maxCodeLookups] - } - // Retrieve bytecodes until the packet size limit is reached - var ( - codes [][]byte - bytes uint64 - ) - for _, hash := range req.Hashes { - if hash == types.EmptyCodeHash { - // Peers should not request the empty code, but if they do, at - // least sent them back a correct response without db lookups - codes = append(codes, []byte{}) - } else if blob := chain.ContractCodeWithPrefix(hash); len(blob) > 0 { - codes = append(codes, blob) - bytes += uint64(len(blob)) - } - if bytes > req.Bytes { - break - } - } - return codes -} - -// ServiceGetTrieNodesQuery assembles the response to a trie nodes query. -// It is exposed to allow external packages to test protocol behavior. -func ServiceGetTrieNodesQuery(chain *core.BlockChain, req *GetTrieNodesPacket, start time.Time) ([][]byte, error) { - if req.Bytes > softResponseLimit { - req.Bytes = softResponseLimit - } - // Make sure we have the state associated with the request - triedb := chain.TrieDB() - - accTrie, err := trie.NewStateTrie(trie.StateTrieID(req.Root), triedb) - if err != nil { - // We don't have the requested state available, bail out - return nil, nil - } - // The 'reader' might be nil, in which case we cannot serve storage slots - // via snapshot. 
- var reader database.StateReader - if chain.Snapshots() != nil { - reader = chain.Snapshots().Snapshot(req.Root) - } - if reader == nil { - reader, _ = triedb.StateReader(req.Root) - } - - // Retrieve trie nodes until the packet size limit is reached - var ( - outerIt = req.Paths.ContentIterator() - nodes [][]byte - bytes uint64 - loads int // Trie hash expansions to count database reads - ) - for outerIt.Next() { - innerIt, err := rlp.NewListIterator(outerIt.Value()) - if err != nil { - return nodes, err - } - - switch innerIt.Count() { - case 0: - // Ensure we penalize invalid requests - return nil, fmt.Errorf("%w: zero-item pathset requested", errBadRequest) - - case 1: - // If we're only retrieving an account trie node, fetch it directly - accKey := nextBytes(&innerIt) - if accKey == nil { - return nodes, fmt.Errorf("%w: invalid account node request", errBadRequest) - } - blob, resolved, err := accTrie.GetNode(accKey) - loads += resolved // always account database reads, even for failures - if err != nil { - break - } - nodes = append(nodes, blob) - bytes += uint64(len(blob)) - - default: - // Storage slots requested, open the storage trie and retrieve from there - accKey := nextBytes(&innerIt) - if accKey == nil { - return nodes, fmt.Errorf("%w: invalid account storage request", errBadRequest) - } - var stRoot common.Hash - if reader == nil { - // We don't have the requested state snapshotted yet (or it is stale), - // but can look up the account via the trie instead. 
- account, err := accTrie.GetAccountByHash(common.BytesToHash(accKey)) - loads += 8 // We don't know the exact cost of lookup, this is an estimate - if err != nil || account == nil { - break - } - stRoot = account.Root - } else { - account, err := reader.Account(common.BytesToHash(accKey)) - loads++ // always account database reads, even for failures - if err != nil || account == nil { - break - } - stRoot = common.BytesToHash(account.Root) - } - - id := trie.StorageTrieID(req.Root, common.BytesToHash(accKey), stRoot) - stTrie, err := trie.NewStateTrie(id, triedb) - loads++ // always account database reads, even for failures - if err != nil { - break - } - for innerIt.Next() { - path, _, err := rlp.SplitString(innerIt.Value()) - if err != nil { - return nil, fmt.Errorf("%w: invalid storage key: %v", errBadRequest, err) - } - blob, resolved, err := stTrie.GetNode(path) - loads += resolved // always account database reads, even for failures - if err != nil { - break - } - nodes = append(nodes, blob) - bytes += uint64(len(blob)) - - // Sanity check limits to avoid DoS on the store trie loads - if bytes > req.Bytes || loads > maxTrieNodeLookups || time.Since(start) > maxTrieNodeTimeSpent { - break - } - } - } - // Abort request processing if we've exceeded our limits - if bytes > req.Bytes || loads > maxTrieNodeLookups || time.Since(start) > maxTrieNodeTimeSpent { - break - } - } - return nodes, nil -} - -func nextBytes(it *rlp.Iterator) []byte { - if !it.Next() { - return nil - } - content, _, err := rlp.SplitString(it.Value()) - if err != nil { - return nil - } - return content + return fmt.Errorf("%w: %v", errInvalidMsgCode, msg.Code) } // NodeInfo represents a short summary of the `snap` sub-protocol metadata diff --git a/eth/protocols/snap/handler_fuzzing_test.go b/eth/protocols/snap/handler_fuzzing_test.go index 4930ae9ae6..a52da0aac5 100644 --- a/eth/protocols/snap/handler_fuzzing_test.go +++ b/eth/protocols/snap/handler_fuzzing_test.go @@ -60,6 +60,12 @@ func 
FuzzTrieNodes(f *testing.F) { }) } +func FuzzAccessLists(f *testing.F) { + f.Fuzz(func(t *testing.T, data []byte) { + doFuzz(data, &GetAccessListsPacket{}, GetAccessListsMsg) + }) +} + func doFuzz(input []byte, obj interface{}, code int) { bc := getChain() defer bc.Stop() diff --git a/eth/protocols/snap/handler_test.go b/eth/protocols/snap/handler_test.go new file mode 100644 index 0000000000..3f6a43a059 --- /dev/null +++ b/eth/protocols/snap/handler_test.go @@ -0,0 +1,314 @@ +// Copyright 2026 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . 
+ +package snap + +import ( + "bytes" + "encoding/binary" + "reflect" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/consensus/beacon" + "github.com/ethereum/go-ethereum/consensus/ethash" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/types/bal" + "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/rlp" +) + +func makeTestBAL(minSize int) *bal.BlockAccessList { + n := minSize/33 + 1 // 33 bytes per storage read slot in RLP + access := bal.AccountAccess{ + Address: common.HexToAddress("0x01"), + StorageReads: make([][32]byte, n), + } + for i := range access.StorageReads { + binary.BigEndian.PutUint64(access.StorageReads[i][24:], uint64(i)) + } + return &bal.BlockAccessList{Accesses: []bal.AccountAccess{access}} +} + +// getChainWithBALs creates a minimal test chain with BALs stored for each block. +// It returns the chain, block hashes, and the stored BAL data. 
+func getChainWithBALs(nBlocks int, balSize int) (*core.BlockChain, []common.Hash, []rlp.RawValue) { + gspec := &core.Genesis{ + Config: params.MergedTestChainConfig, + } + db := rawdb.NewMemoryDatabase() + engine := beacon.New(ethash.NewFaker()) + _, blocks, _ := core.GenerateChainWithGenesis(gspec, engine, nBlocks, func(i int, gen *core.BlockGen) {}) + options := &core.BlockChainConfig{ + StateScheme: rawdb.PathScheme, + TrieTimeLimit: 5 * time.Minute, + NoPrefetch: true, + } + bc, err := core.NewBlockChain(db, gspec, engine, options) + if err != nil { + panic(err) + } + if _, err := bc.InsertChain(blocks); err != nil { + panic(err) + } + + // Store BALs for each block + var ( + hashes []common.Hash + bals []rlp.RawValue + ) + for _, block := range blocks { + hash := block.Hash() + number := block.NumberU64() + + // Encode the same synthetic BAL (sized by balSize) for every block + bytes, err := rlp.EncodeToBytes(makeTestBAL(balSize)) + if err != nil { + panic(err) + } + rawdb.WriteAccessListRLP(db, hash, number, bytes) + hashes = append(hashes, hash) + bals = append(bals, bytes) + } + return bc, hashes, bals +} + +// TestServiceGetAccessListsQuery verifies that known block hashes return the +// correct BALs with positional correspondence. 
+func TestServiceGetAccessListsQuery(t *testing.T) { + t.Parallel() + bc, hashes, bals := getChainWithBALs(5, 100) + defer bc.Stop() + req := &GetAccessListsPacket{ + ID: 1, + Hashes: hashes, + Bytes: softResponseLimit, + } + result := ServiceGetAccessListsQuery(bc, req) + + // Verify the results + if result.Len() != len(hashes) { + t.Fatalf("expected %d results, got %d", len(hashes), result.Len()) + } + var ( + index int + it = result.ContentIterator() + ) + for it.Next() { + if !bytes.Equal(it.Value(), bals[index]) { + t.Errorf("BAL %d mismatch: got %x, want %x", index, it.Value(), bals[index]) + } + index++ + } +} + +// TestServiceGetAccessListsQueryEmpty verifies that unknown block hashes return +// nil placeholders and that mixed known/unknown hashes preserve alignment. +func TestServiceGetAccessListsQueryEmpty(t *testing.T) { + t.Parallel() + bc, hashes, bals := getChainWithBALs(3, 100) + defer bc.Stop() + unknown := common.HexToHash("0xdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef") + mixed := []common.Hash{hashes[0], unknown, hashes[1], unknown, hashes[2]} + req := &GetAccessListsPacket{ + ID: 2, + Hashes: mixed, + Bytes: softResponseLimit, + } + result := ServiceGetAccessListsQuery(bc, req) + + // Verify length + if result.Len() != len(mixed) { + t.Fatalf("expected %d results, got %d", len(mixed), result.Len()) + } + + // Check positional correspondence + var expectVal = []rlp.RawValue{ + bals[0], rlp.EmptyString, bals[1], rlp.EmptyString, bals[2], + } + var ( + index int + it = result.ContentIterator() + ) + for it.Next() { + if !bytes.Equal(it.Value(), expectVal[index]) { + t.Errorf("BAL %d mismatch: got %x, want %x", index, it.Value(), expectVal[index]) + } + index++ + } +} + +// TestServiceGetAccessListsQueryCap verifies that requests exceeding +// maxAccessListLookups are capped. 
+func TestServiceGetAccessListsQueryCap(t *testing.T) { + t.Parallel() + + bc, _, _ := getChainWithBALs(2, 100) + defer bc.Stop() + + // Create a request with more hashes than the cap + hashes := make([]common.Hash, maxAccessListLookups+100) + for i := range hashes { + hashes[i] = common.BytesToHash([]byte{byte(i), byte(i >> 8)}) + } + req := &GetAccessListsPacket{ + ID: 3, + Hashes: hashes, + Bytes: softResponseLimit, + } + result := ServiceGetAccessListsQuery(bc, req) + + // Can't get more than maxAccessListLookups results + if result.Len() > maxAccessListLookups { + t.Fatalf("expected at most %d results, got %d", maxAccessListLookups, result.Len()) + } +} + +// TestServiceGetAccessListsQueryByteLimit verifies that the response stops +// once the byte limit is exceeded. The handler appends the entry that crosses +// the limit before breaking, so the total size will exceed the limit by at +// most one BAL. +func TestServiceGetAccessListsQueryByteLimit(t *testing.T) { + t.Parallel() + + // The handler will return 3/5 entries (3MB total) then break. 
+ balSize := 1024 * 1024 + nBlocks := 5 + bc, hashes, _ := getChainWithBALs(nBlocks, balSize) + defer bc.Stop() + req := &GetAccessListsPacket{ + ID: 0, + Hashes: hashes, + Bytes: softResponseLimit, + } + result := ServiceGetAccessListsQuery(bc, req) + + // Should have stopped before returning all blocks + if result.Len() >= nBlocks { + t.Fatalf("expected fewer than %d results due to byte limit, got %d", nBlocks, result.Len()) + } + + // Should have returned at least one + if result.Len() == 0 { + t.Fatal("expected at least one result") + } + + // The total size should exceed the limit (the entry that crosses it is included) + if result.Size() <= softResponseLimit { + t.Errorf("total response size %d should exceed soft limit %d (includes one entry past limit)", result.Size(), softResponseLimit) + } +} + +// TestGetAccessListResponseDecoding verifies that an AccessListsPacket +// round-trips through RLP encode/decode, preserving positional +// correspondence and correctly representing absent BALs as empty strings. +func TestGetAccessListResponseDecoding(t *testing.T) { + t.Parallel() + + // Build two real BALs of different sizes. + bal1 := makeTestBAL(100) + bal2 := makeTestBAL(200) + bytes1, _ := rlp.EncodeToBytes(bal1) + bytes2, _ := rlp.EncodeToBytes(bal2) + + tests := []struct { + name string + items []rlp.RawValue // rlp.EmptyString entry = unavailable BAL + counts int // expected decoded length + }{ + { + name: "all present", + items: []rlp.RawValue{bytes1, bytes2}, + counts: 2, + }, + { + name: "all absent", + items: []rlp.RawValue{rlp.EmptyString, rlp.EmptyString, rlp.EmptyString}, + counts: 3, + }, + { + name: "mixed present and absent", + items: []rlp.RawValue{bytes1, rlp.EmptyString, bytes2, rlp.EmptyString}, + counts: 4, + }, + { + name: "empty response", + items: []rlp.RawValue{}, + counts: 0, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Build the packet using AppendRaw. 
+ var orig AccessListsPacket + orig.ID = 42 + for _, item := range tt.items { + if err := orig.AccessLists.AppendRaw(item); err != nil { + t.Fatalf("AppendRaw failed: %v", err) + } + } + + // Encode -> Decode round-trip. + enc, err := rlp.EncodeToBytes(&orig) + if err != nil { + t.Fatalf("encode failed: %v", err) + } + var dec AccessListsPacket + if err := rlp.DecodeBytes(enc, &dec); err != nil { + t.Fatalf("decode failed: %v", err) + } + + // Verify ID preserved. + if dec.ID != orig.ID { + t.Fatalf("ID mismatch: got %d, want %d", dec.ID, orig.ID) + } + + // Verify element count. + if dec.AccessLists.Len() != tt.counts { + t.Fatalf("length mismatch: got %d, want %d", dec.AccessLists.Len(), tt.counts) + } + + // Verify each element positionally. + it := dec.AccessLists.ContentIterator() + for i, want := range tt.items { + if !it.Next() { + t.Fatalf("iterator exhausted at index %d", i) + } + got := it.Value() + if !bytes.Equal(got, want) { + t.Errorf("element %d: got %x, want %x", i, got, want) + } + if !bytes.Equal(got, rlp.EmptyString) { + obj := new(bal.BlockAccessList) + if err := rlp.DecodeBytes(got, obj); err != nil { + t.Fatalf("decode failed: %v", err) + } + if bytes.Equal(got, bytes1) && !reflect.DeepEqual(obj, bal1) { + t.Fatalf("decode failed: got %x, want %x", obj, bal1) + } + if bytes.Equal(got, bytes2) && !reflect.DeepEqual(obj, bal2) { + t.Fatalf("decode failed: got %x, want %x", obj, bal2) + } + } + } + if it.Next() { + t.Error("iterator has extra elements after expected end") + } + }) + } +} diff --git a/eth/protocols/snap/handlers.go b/eth/protocols/snap/handlers.go new file mode 100644 index 0000000000..5a5733bdb4 --- /dev/null +++ b/eth/protocols/snap/handlers.go @@ -0,0 +1,600 @@ +// Copyright 2026 The go-ethereum Authors +// This file is part of the go-ethereum library. 
+// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see + +package snap + +import ( + "bytes" + "fmt" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/state/snapshot" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/tracker" + "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/trie" + "github.com/ethereum/go-ethereum/trie/trienode" + "github.com/ethereum/go-ethereum/triedb/database" +) + +func handleGetAccountRange(backend Backend, msg Decoder, peer *Peer) error { + var req GetAccountRangePacket + if err := msg.Decode(&req); err != nil { + return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) + } + // Service the request, potentially returning nothing in case of errors + accounts, proofs := ServiceGetAccountRangeQuery(backend.Chain(), &req) + + // Send back anything accumulated (or empty in case of errors) + return p2p.Send(peer.rw, AccountRangeMsg, &AccountRangePacket{ + ID: req.ID, + Accounts: accounts, + Proof: proofs, + }) +} + +// ServiceGetAccountRangeQuery assembles the response to an account range query. 
+// It is exposed to allow external packages to test protocol behavior. +func ServiceGetAccountRangeQuery(chain *core.BlockChain, req *GetAccountRangePacket) ([]*AccountData, [][]byte) { + if req.Bytes > softResponseLimit { + req.Bytes = softResponseLimit + } + // Retrieve the requested state and bail out if non existent + tr, err := trie.New(trie.StateTrieID(req.Root), chain.TrieDB()) + if err != nil { + return nil, nil + } + // Temporary solution: using the snapshot interface for both cases. + // This can be removed once the hash scheme is deprecated. + var it snapshot.AccountIterator + if chain.TrieDB().Scheme() == rawdb.HashScheme { + // The snapshot is assumed to be available in hash mode if + // the SNAP protocol is enabled. + it, err = chain.Snapshots().AccountIterator(req.Root, req.Origin) + } else { + it, err = chain.TrieDB().AccountIterator(req.Root, req.Origin) + } + if err != nil { + return nil, nil + } + // Iterate over the requested range and pile accounts up + var ( + accounts []*AccountData + size uint64 + last common.Hash + ) + for it.Next() { + hash, account := it.Hash(), common.CopyBytes(it.Account()) + + // Track the returned interval for the Merkle proofs + last = hash + + // Assemble the reply item + size += uint64(common.HashLength + len(account)) + accounts = append(accounts, &AccountData{ + Hash: hash, + Body: account, + }) + // If we've exceeded the request threshold, abort + if bytes.Compare(hash[:], req.Limit[:]) >= 0 { + break + } + if size > req.Bytes { + break + } + } + it.Release() + + // Generate the Merkle proofs for the first and last account + proof := trienode.NewProofSet() + if err := tr.Prove(req.Origin[:], proof); err != nil { + log.Warn("Failed to prove account range", "origin", req.Origin, "err", err) + return nil, nil + } + if last != (common.Hash{}) { + if err := tr.Prove(last[:], proof); err != nil { + log.Warn("Failed to prove account range", "last", last, "err", err) + return nil, nil + } + } + return accounts, 
proof.List() +} + +func handleAccountRange(backend Backend, msg Decoder, peer *Peer) error { + res := new(accountRangeInput) + if err := msg.Decode(res); err != nil { + return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) + } + + // Check response validity. + if len := res.Proof.Len(); len > 128 { + return fmt.Errorf("AccountRange: invalid proof (length %d)", len) + } + tresp := tracker.Response{ID: res.ID, MsgCode: AccountRangeMsg, Size: len(res.Accounts.Content())} + if err := peer.tracker.Fulfil(tresp); err != nil { + return err + } + + // Decode. + accounts, err := res.Accounts.Items() + if err != nil { + return fmt.Errorf("AccountRange: invalid accounts list: %v", err) + } + proof, err := res.Proof.Items() + if err != nil { + return fmt.Errorf("AccountRange: invalid proof: %v", err) + } + + // Ensure the range is monotonically increasing + for i := 1; i < len(accounts); i++ { + if bytes.Compare(accounts[i-1].Hash[:], accounts[i].Hash[:]) >= 0 { + return fmt.Errorf("accounts not monotonically increasing: #%d [%x] vs #%d [%x]", i-1, accounts[i-1].Hash[:], i, accounts[i].Hash[:]) + } + } + + return backend.Handle(peer, &AccountRangePacket{res.ID, accounts, proof}) +} + +func handleGetStorageRanges(backend Backend, msg Decoder, peer *Peer) error { + var req GetStorageRangesPacket + if err := msg.Decode(&req); err != nil { + return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) + } + // Service the request, potentially returning nothing in case of errors + slots, proofs := ServiceGetStorageRangesQuery(backend.Chain(), &req) + + // Send back anything accumulated (or empty in case of errors) + return p2p.Send(peer.rw, StorageRangesMsg, &StorageRangesPacket{ + ID: req.ID, + Slots: slots, + Proof: proofs, + }) +} + +func ServiceGetStorageRangesQuery(chain *core.BlockChain, req *GetStorageRangesPacket) ([][]*StorageData, [][]byte) { + if req.Bytes > softResponseLimit { + req.Bytes = softResponseLimit + } + // TODO(karalabe): Do we want to enforce > 0 
accounts and 1 account if origin is set? + // TODO(karalabe): - Logging locally is not ideal as remote faults annoy the local user + // TODO(karalabe): - Dropping the remote peer is less flexible wrt client bugs (slow is better than non-functional) + + // Calculate the hard limit at which to abort, even if mid storage trie + hardLimit := uint64(float64(req.Bytes) * (1 + stateLookupSlack)) + + // Retrieve storage ranges until the packet limit is reached + var ( + slots [][]*StorageData + proofs [][]byte + size uint64 + ) + for _, account := range req.Accounts { + // If we've exceeded the requested data limit, abort without opening + // a new storage range (that we'd need to prove due to exceeded size) + if size >= req.Bytes { + break + } + // The first account might start from a different origin and end sooner + var origin common.Hash + if len(req.Origin) > 0 { + origin, req.Origin = common.BytesToHash(req.Origin), nil + } + var limit = common.MaxHash + if len(req.Limit) > 0 { + limit, req.Limit = common.BytesToHash(req.Limit), nil + } + // Retrieve the requested state and bail out if non existent + var ( + err error + it snapshot.StorageIterator + ) + // Temporary solution: using the snapshot interface for both cases. + // This can be removed once the hash scheme is deprecated. + if chain.TrieDB().Scheme() == rawdb.HashScheme { + // The snapshot is assumed to be available in hash mode if + // the SNAP protocol is enabled. 
+ it, err = chain.Snapshots().StorageIterator(req.Root, account, origin) + } else { + it, err = chain.TrieDB().StorageIterator(req.Root, account, origin) + } + if err != nil { + return nil, nil + } + // Iterate over the requested range and pile slots up + var ( + storage []*StorageData + last common.Hash + abort bool + ) + for it.Next() { + if size >= hardLimit { + abort = true + break + } + hash, slot := it.Hash(), common.CopyBytes(it.Slot()) + + // Track the returned interval for the Merkle proofs + last = hash + + // Assemble the reply item + size += uint64(common.HashLength + len(slot)) + storage = append(storage, &StorageData{ + Hash: hash, + Body: slot, + }) + // If we've exceeded the request threshold, abort + if bytes.Compare(hash[:], limit[:]) >= 0 { + break + } + } + if len(storage) > 0 { + slots = append(slots, storage) + } + it.Release() + + // Generate the Merkle proofs for the first and last storage slot, but + // only if the response was capped. If the entire storage trie is included + // in the response, no need for any proofs. 
+ if origin != (common.Hash{}) || (abort && len(storage) > 0) { + // Request started at a non-zero hash or was capped prematurely, add + // the endpoint Merkle proofs + accTrie, err := trie.NewStateTrie(trie.StateTrieID(req.Root), chain.TrieDB()) + if err != nil { + return nil, nil + } + acc, err := accTrie.GetAccountByHash(account) + if err != nil || acc == nil { + return nil, nil + } + id := trie.StorageTrieID(req.Root, account, acc.Root) + stTrie, err := trie.NewStateTrie(id, chain.TrieDB()) + if err != nil { + return nil, nil + } + proof := trienode.NewProofSet() + if err := stTrie.Prove(origin[:], proof); err != nil { + log.Warn("Failed to prove storage range", "origin", req.Origin, "err", err) + return nil, nil + } + if last != (common.Hash{}) { + if err := stTrie.Prove(last[:], proof); err != nil { + log.Warn("Failed to prove storage range", "last", last, "err", err) + return nil, nil + } + } + proofs = append(proofs, proof.List()...) + // Proof terminates the reply as proofs are only added if a node + // refuses to serve more data (exception when a contract fetch is + // finishing, but that's that). + break + } + } + return slots, proofs +} + +func handleStorageRanges(backend Backend, msg Decoder, peer *Peer) error { + res := new(storageRangesInput) + if err := msg.Decode(res); err != nil { + return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) + } + + // Check response validity. + if len := res.Proof.Len(); len > 128 { + return fmt.Errorf("StorageRangesMsg: invalid proof (length %d)", len) + } + tresp := tracker.Response{ID: res.ID, MsgCode: StorageRangesMsg, Size: len(res.Slots.Content())} + if err := peer.tracker.Fulfil(tresp); err != nil { + return fmt.Errorf("StorageRangesMsg: %w", err) + } + + // Decode. 
+ slotLists, err := res.Slots.Items() + if err != nil { + return fmt.Errorf("StorageRangesMsg: invalid slots list: %v", err) + } + proof, err := res.Proof.Items() + if err != nil { + return fmt.Errorf("StorageRangesMsg: invalid proof: %v", err) + } + + // Ensure the ranges are monotonically increasing + for i, slots := range slotLists { + for j := 1; j < len(slots); j++ { + if bytes.Compare(slots[j-1].Hash[:], slots[j].Hash[:]) >= 0 { + return fmt.Errorf("storage slots not monotonically increasing for account #%d: #%d [%x] vs #%d [%x]", i, j-1, slots[j-1].Hash[:], j, slots[j].Hash[:]) + } + } + } + + return backend.Handle(peer, &StorageRangesPacket{res.ID, slotLists, proof}) +} + +func handleGetByteCodes(backend Backend, msg Decoder, peer *Peer) error { + var req GetByteCodesPacket + if err := msg.Decode(&req); err != nil { + return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) + } + // Service the request, potentially returning nothing in case of errors + codes := ServiceGetByteCodesQuery(backend.Chain(), &req) + + // Send back anything accumulated (or empty in case of errors) + return p2p.Send(peer.rw, ByteCodesMsg, &ByteCodesPacket{ + ID: req.ID, + Codes: codes, + }) +} + +// ServiceGetByteCodesQuery assembles the response to a byte codes query. +// It is exposed to allow external packages to test protocol behavior. 
+func ServiceGetByteCodesQuery(chain *core.BlockChain, req *GetByteCodesPacket) [][]byte { + if req.Bytes > softResponseLimit { + req.Bytes = softResponseLimit + } + if len(req.Hashes) > maxCodeLookups { + req.Hashes = req.Hashes[:maxCodeLookups] + } + // Retrieve bytecodes until the packet size limit is reached + var ( + codes [][]byte + bytes uint64 + ) + for _, hash := range req.Hashes { + if hash == types.EmptyCodeHash { + // Peers should not request the empty code, but if they do, at + // least send them back a correct response without db lookups + codes = append(codes, []byte{}) + } else if blob := chain.ContractCodeWithPrefix(hash); len(blob) > 0 { + codes = append(codes, blob) + bytes += uint64(len(blob)) + } + if bytes > req.Bytes { + break + } + } + return codes +} + +func handleByteCodes(backend Backend, msg Decoder, peer *Peer) error { + res := new(byteCodesInput) + if err := msg.Decode(res); err != nil { + return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) + } + + length := res.Codes.Len() + tresp := tracker.Response{ID: res.ID, MsgCode: ByteCodesMsg, Size: length} + if err := peer.tracker.Fulfil(tresp); err != nil { + return fmt.Errorf("ByteCodes: %w", err) + } + + codes, err := res.Codes.Items() + if err != nil { + return fmt.Errorf("ByteCodes: %w", err) + } + + return backend.Handle(peer, &ByteCodesPacket{res.ID, codes}) +} + +func handleGetTrienodes(backend Backend, msg Decoder, peer *Peer) error { + var req GetTrieNodesPacket + if err := msg.Decode(&req); err != nil { + return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) + } + // Service the request, potentially returning nothing in case of errors + nodes, err := ServiceGetTrieNodesQuery(backend.Chain(), &req) + if err != nil { + return err + } + // Send back anything accumulated (or empty in case of errors) + return p2p.Send(peer.rw, TrieNodesMsg, &TrieNodesPacket{ + ID: req.ID, + Nodes: nodes, + }) +} + +func nextBytes(it *rlp.Iterator) []byte { + if !it.Next() { + return nil 
+ } + content, _, err := rlp.SplitString(it.Value()) + if err != nil { + return nil + } + return content +} + +// ServiceGetTrieNodesQuery assembles the response to a trie nodes query. +// It is exposed to allow external packages to test protocol behavior. +func ServiceGetTrieNodesQuery(chain *core.BlockChain, req *GetTrieNodesPacket) ([][]byte, error) { + start := time.Now() + if req.Bytes > softResponseLimit { + req.Bytes = softResponseLimit + } + // Make sure we have the state associated with the request + triedb := chain.TrieDB() + + accTrie, err := trie.NewStateTrie(trie.StateTrieID(req.Root), triedb) + if err != nil { + // We don't have the requested state available, bail out + return nil, nil + } + // The 'reader' might be nil, in which case we cannot serve storage slots + // via snapshot. + var reader database.StateReader + if chain.Snapshots() != nil { + reader = chain.Snapshots().Snapshot(req.Root) + } + if reader == nil { + reader, _ = triedb.StateReader(req.Root) + } + + // Retrieve trie nodes until the packet size limit is reached + var ( + outerIt = req.Paths.ContentIterator() + nodes [][]byte + bytes uint64 + loads int // Trie hash expansions to count database reads + ) + for outerIt.Next() { + innerIt, err := rlp.NewListIterator(outerIt.Value()) + if err != nil { + return nodes, err + } + + switch innerIt.Count() { + case 0: + // Ensure we penalize invalid requests + return nil, fmt.Errorf("%w: zero-item pathset requested", errBadRequest) + + case 1: + // If we're only retrieving an account trie node, fetch it directly + accKey := nextBytes(&innerIt) + if accKey == nil { + return nodes, fmt.Errorf("%w: invalid account node request", errBadRequest) + } + blob, resolved, err := accTrie.GetNode(accKey) + loads += resolved // always account database reads, even for failures + if err != nil { + break + } + nodes = append(nodes, blob) + bytes += uint64(len(blob)) + + default: + // Storage slots requested, open the storage trie and retrieve from there + 
accKey := nextBytes(&innerIt) + if accKey == nil { + return nodes, fmt.Errorf("%w: invalid account storage request", errBadRequest) + } + var stRoot common.Hash + if reader == nil { + // We don't have the requested state snapshotted yet (or it is stale), + // but can look up the account via the trie instead. + account, err := accTrie.GetAccountByHash(common.BytesToHash(accKey)) + loads += 8 // We don't know the exact cost of lookup, this is an estimate + if err != nil || account == nil { + break + } + stRoot = account.Root + } else { + account, err := reader.Account(common.BytesToHash(accKey)) + loads++ // always account database reads, even for failures + if err != nil || account == nil { + break + } + stRoot = common.BytesToHash(account.Root) + } + + id := trie.StorageTrieID(req.Root, common.BytesToHash(accKey), stRoot) + stTrie, err := trie.NewStateTrie(id, triedb) + loads++ // always account database reads, even for failures + if err != nil { + break + } + for innerIt.Next() { + path, _, err := rlp.SplitString(innerIt.Value()) + if err != nil { + return nil, fmt.Errorf("%w: invalid storage key: %v", errBadRequest, err) + } + blob, resolved, err := stTrie.GetNode(path) + loads += resolved // always account database reads, even for failures + if err != nil { + break + } + nodes = append(nodes, blob) + bytes += uint64(len(blob)) + + // Sanity check limits to avoid DoS on the store trie loads + if bytes > req.Bytes || loads > maxTrieNodeLookups || time.Since(start) > maxTrieNodeTimeSpent { + break + } + } + } + // Abort request processing if we've exceeded our limits + if bytes > req.Bytes || loads > maxTrieNodeLookups || time.Since(start) > maxTrieNodeTimeSpent { + break + } + } + return nodes, nil +} + +func handleTrieNodes(backend Backend, msg Decoder, peer *Peer) error { + res := new(trieNodesInput) + if err := msg.Decode(res); err != nil { + return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) + } + + tresp := tracker.Response{ID: res.ID, MsgCode: 
TrieNodesMsg, Size: res.Nodes.Len()} + if err := peer.tracker.Fulfil(tresp); err != nil { + return fmt.Errorf("TrieNodes: %w", err) + } + nodes, err := res.Nodes.Items() + if err != nil { + return fmt.Errorf("TrieNodes: %w", err) + } + + return backend.Handle(peer, &TrieNodesPacket{res.ID, nodes}) +} + +// nolint:unused +func handleGetAccessLists(backend Backend, msg Decoder, peer *Peer) error { + var req GetAccessListsPacket + if err := msg.Decode(&req); err != nil { + return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) + } + return p2p.Send(peer.rw, AccessListsMsg, &AccessListsPacket{ + ID: req.ID, + AccessLists: ServiceGetAccessListsQuery(backend.Chain(), &req), + }) +} + +// ServiceGetAccessListsQuery assembles the response to an access list query. +// It is exposed to allow external packages to test protocol behavior. +func ServiceGetAccessListsQuery(chain *core.BlockChain, req *GetAccessListsPacket) rlp.RawList[rlp.RawValue] { + if req.Bytes > softResponseLimit { + req.Bytes = softResponseLimit + } + // Cap the number of lookups + if len(req.Hashes) > maxAccessListLookups { + req.Hashes = req.Hashes[:maxAccessListLookups] + } + var ( + err error + bytes uint64 + response = rlp.RawList[rlp.RawValue]{} + ) + for _, hash := range req.Hashes { + if bal := chain.GetAccessListRLP(hash); len(bal) > 0 { + err = response.AppendRaw(bal) + bytes += uint64(len(bal)) + } else { + // Either the block is unknown or the BAL doesn't exist + err = response.AppendRaw(rlp.EmptyString) + bytes += 1 + } + if err != nil { + break + } + if bytes > req.Bytes { + break + } + } + return response +} diff --git a/eth/protocols/snap/protocol.go b/eth/protocols/snap/protocol.go index 25fe25822b..685f468da3 100644 --- a/eth/protocols/snap/protocol.go +++ b/eth/protocols/snap/protocol.go @@ -28,6 +28,7 @@ import ( // Constants to match up protocol versions and messages const ( SNAP1 = 1 + //SNAP2 = 2 ) // ProtocolName is the official short name of the `snap` protocol used during @@ 
-40,7 +41,7 @@ var ProtocolVersions = []uint{SNAP1} // protocolLengths are the number of implemented message corresponding to // different protocol versions. -var protocolLengths = map[uint]uint64{SNAP1: 8} +var protocolLengths = map[uint]uint64{ /*SNAP2: 10,*/ SNAP1: 8} // maxMessageSize is the maximum cap on the size of a protocol message. const maxMessageSize = 10 * 1024 * 1024 @@ -54,6 +55,8 @@ const ( ByteCodesMsg = 0x05 GetTrieNodesMsg = 0x06 TrieNodesMsg = 0x07 + GetAccessListsMsg = 0x08 + AccessListsMsg = 0x09 ) var ( @@ -215,6 +218,21 @@ type TrieNodesPacket struct { Nodes [][]byte // Requested state trie nodes } +// GetAccessListsPacket requests BALs for a set of block hashes. +type GetAccessListsPacket struct { + ID uint64 // Request ID to match up responses with + Hashes []common.Hash // Block hashes to retrieve BALs for + Bytes uint64 // Soft limit at which to stop returning data +} + +// AccessListsPacket is the response to GetAccessListsPacket. +// Each entry corresponds to the requested hash at the same index. +// Empty entries indicate the BAL is unavailable. 
+type AccessListsPacket struct { + ID uint64 // ID of the request this is a response for + AccessLists rlp.RawList[rlp.RawValue] // Requested BALs +} + func (*GetAccountRangePacket) Name() string { return "GetAccountRange" } func (*GetAccountRangePacket) Kind() byte { return GetAccountRangeMsg } @@ -238,3 +256,9 @@ func (*GetTrieNodesPacket) Kind() byte { return GetTrieNodesMsg } func (*TrieNodesPacket) Name() string { return "TrieNodes" } func (*TrieNodesPacket) Kind() byte { return TrieNodesMsg } + +func (*GetAccessListsPacket) Name() string { return "GetAccessLists" } +func (*GetAccessListsPacket) Kind() byte { return GetAccessListsMsg } + +func (*AccessListsPacket) Name() string { return "AccessLists" } +func (*AccessListsPacket) Kind() byte { return AccessListsMsg } diff --git a/eth/state_accessor.go b/eth/state_accessor.go index 871f2c9269..04aac321cb 100644 --- a/eth/state_accessor.go +++ b/eth/state_accessor.go @@ -38,7 +38,11 @@ import ( // for releasing state. var noopReleaser = tracers.StateReleaseFunc(func() {}) -func (eth *Ethereum) hashState(ctx context.Context, block *types.Block, reexec uint64, base *state.StateDB, readOnly bool, preferDisk bool) (statedb *state.StateDB, release tracers.StateReleaseFunc, err error) { +// reexecLimit is the maximum number of ancestor blocks to walk back when +// attempting to reconstruct missing historical state for hash-scheme nodes. 
+const reexecLimit = uint64(128) + +func (eth *Ethereum) hashState(ctx context.Context, block *types.Block, base *state.StateDB, readOnly bool, preferDisk bool) (statedb *state.StateDB, release tracers.StateReleaseFunc, err error) { var ( current *types.Block database state.Database @@ -99,7 +103,7 @@ func (eth *Ethereum) hashState(ctx context.Context, block *types.Block, reexec u } } // Database does not have the state for the given block, try to regenerate - for i := uint64(0); i < reexec; i++ { + for i := uint64(0); i < reexecLimit; i++ { if err := ctx.Err(); err != nil { return nil, nil, err } @@ -120,7 +124,7 @@ func (eth *Ethereum) hashState(ctx context.Context, block *types.Block, reexec u if err != nil { switch err.(type) { case *trie.MissingNodeError: - return nil, nil, fmt.Errorf("required historical state unavailable (reexec=%d)", reexec) + return nil, nil, fmt.Errorf("required historical state unavailable (reexec=%d)", reexecLimit) default: return nil, nil, err } @@ -190,10 +194,9 @@ func (eth *Ethereum) pathState(block *types.Block) (*state.StateDB, func(), erro } // stateAtBlock retrieves the state database associated with a certain block. -// If no state is locally available for the given block, a number of blocks -// are attempted to be reexecuted to generate the desired state. The optional -// base layer statedb can be provided which is regarded as the statedb of the -// parent block. +// If no state is locally available for the given block, up to reexecLimit ancestor +// blocks are reexecuted to generate the desired state. The optional base layer +// statedb can be provided which is regarded as the statedb of the parent block. // // An additional release function will be returned if the requested state is // available. 
Release is expected to be invoked when the returned state is no @@ -202,7 +205,6 @@ func (eth *Ethereum) pathState(block *types.Block) (*state.StateDB, func(), erro // // Parameters: // - block: The block for which we want the state(state = block.Root) -// - reexec: The maximum number of blocks to reprocess trying to obtain the desired state // - base: If the caller is tracing multiple blocks, the caller can provide the parent // state continuously from the callsite. // - readOnly: If true, then the live 'blockchain' state database is used. No mutation should @@ -211,9 +213,9 @@ func (eth *Ethereum) pathState(block *types.Block) (*state.StateDB, func(), erro // - preferDisk: This arg can be used by the caller to signal that even though the 'base' is // provided, it would be preferable to start from a fresh state, if we have it // on disk. -func (eth *Ethereum) stateAtBlock(ctx context.Context, block *types.Block, reexec uint64, base *state.StateDB, readOnly bool, preferDisk bool) (statedb *state.StateDB, release tracers.StateReleaseFunc, err error) { +func (eth *Ethereum) stateAtBlock(ctx context.Context, block *types.Block, base *state.StateDB, readOnly bool, preferDisk bool) (statedb *state.StateDB, release tracers.StateReleaseFunc, err error) { if eth.blockchain.TrieDB().Scheme() == rawdb.HashScheme { - return eth.hashState(ctx, block, reexec, base, readOnly, preferDisk) + return eth.hashState(ctx, block, base, readOnly, preferDisk) } return eth.pathState(block) } @@ -225,7 +227,7 @@ func (eth *Ethereum) stateAtBlock(ctx context.Context, block *types.Block, reexe // function will return the state of block after the pre-block operations have // been completed (e.g. updating system contracts), but before post-block // operations are completed (e.g. processing withdrawals). 
-func (eth *Ethereum) stateAtTransaction(ctx context.Context, block *types.Block, txIndex int, reexec uint64) (*types.Transaction, vm.BlockContext, *state.StateDB, tracers.StateReleaseFunc, error) { +func (eth *Ethereum) stateAtTransaction(ctx context.Context, block *types.Block, txIndex int) (*types.Transaction, vm.BlockContext, *state.StateDB, tracers.StateReleaseFunc, error) { // Short circuit if it's genesis block. if block.NumberU64() == 0 { return nil, vm.BlockContext{}, nil, nil, errors.New("no transaction in genesis") @@ -237,7 +239,7 @@ func (eth *Ethereum) stateAtTransaction(ctx context.Context, block *types.Block, } // Lookup the statedb of parent block from the live database, // otherwise regenerate it on the flight. - statedb, release, err := eth.stateAtBlock(ctx, parent, reexec, nil, true, false) + statedb, release, err := eth.stateAtBlock(ctx, parent, nil, true, false) if err != nil { return nil, vm.BlockContext{}, nil, nil, err } diff --git a/eth/tracers/api.go b/eth/tracers/api.go index eed404622e..53a09087e4 100644 --- a/eth/tracers/api.go +++ b/eth/tracers/api.go @@ -51,11 +51,6 @@ const ( // by default before being forcefully aborted. defaultTraceTimeout = 5 * time.Second - // defaultTraceReexec is the number of blocks the tracer is willing to go back - // and reexecute to produce missing historical state necessary to run a specific - // trace. - defaultTraceReexec = uint64(128) - // defaultTracechainMemLimit is the size of the triedb, at which traceChain // switches over and tries to use a disk-backed database instead of building // on top of memory. 
@@ -89,8 +84,8 @@ type Backend interface { ChainConfig() *params.ChainConfig Engine() consensus.Engine ChainDb() ethdb.Database - StateAtBlock(ctx context.Context, block *types.Block, reexec uint64, base *state.StateDB, readOnly bool, preferDisk bool) (*state.StateDB, StateReleaseFunc, error) - StateAtTransaction(ctx context.Context, block *types.Block, txIndex int, reexec uint64) (*types.Transaction, vm.BlockContext, *state.StateDB, StateReleaseFunc, error) + StateAtBlock(ctx context.Context, block *types.Block, base *state.StateDB, readOnly bool, preferDisk bool) (*state.StateDB, StateReleaseFunc, error) + StateAtTransaction(ctx context.Context, block *types.Block, txIndex int) (*types.Transaction, vm.BlockContext, *state.StateDB, StateReleaseFunc, error) } // API is the collection of tracing APIs exposed over the private debugging endpoint. @@ -156,7 +151,6 @@ type TraceConfig struct { *logger.Config Tracer *string Timeout *string - Reexec *uint64 // Config specific to given tracer. Note struct logger // config are historically embedded in main object. TracerConfig json.RawMessage @@ -174,7 +168,6 @@ type TraceCallConfig struct { // StdTraceConfig holds extra parameters to standard-json trace functions. type StdTraceConfig struct { logger.Config - Reexec *uint64 TxHash common.Hash } @@ -245,10 +238,6 @@ func (api *API) TraceChain(ctx context.Context, start, end rpc.BlockNumber, conf // transaction, dependent on the requested tracer. // The tracing procedure should be aborted in case the closed signal is received. 
func (api *API) traceChain(start, end *types.Block, config *TraceConfig, closed <-chan error) chan *blockTraceResult { - reexec := defaultTraceReexec - if config != nil && config.Reexec != nil { - reexec = *config.Reexec - } blocks := int(end.NumberU64() - start.NumberU64()) threads := runtime.NumCPU() if threads > blocks { @@ -374,7 +363,7 @@ func (api *API) traceChain(start, end *types.Block, config *TraceConfig, closed s1, s2, s3 := statedb.Database().TrieDB().Size() preferDisk = s1+s2+s3 > defaultTracechainMemLimit } - statedb, release, err = api.backend.StateAtBlock(ctx, block, reexec, statedb, false, preferDisk) + statedb, release, err = api.backend.StateAtBlock(ctx, block, statedb, false, preferDisk) if err != nil { failed = err break @@ -522,11 +511,7 @@ func (api *API) IntermediateRoots(ctx context.Context, hash common.Hash, config if err != nil { return nil, err } - reexec := defaultTraceReexec - if config != nil && config.Reexec != nil { - reexec = *config.Reexec - } - statedb, release, err := api.backend.StateAtBlock(ctx, parent, reexec, nil, true, false) + statedb, release, err := api.backend.StateAtBlock(ctx, parent, nil, true, false) if err != nil { return nil, err } @@ -591,11 +576,7 @@ func (api *API) traceBlock(ctx context.Context, block *types.Block, config *Trac if err != nil { return nil, err } - reexec := defaultTraceReexec - if config != nil && config.Reexec != nil { - reexec = *config.Reexec - } - statedb, release, err := api.backend.StateAtBlock(ctx, parent, reexec, nil, true, false) + statedb, release, err := api.backend.StateAtBlock(ctx, parent, nil, true, false) if err != nil { return nil, err } @@ -743,11 +724,7 @@ func (api *API) standardTraceBlockToFile(ctx context.Context, block *types.Block if err != nil { return nil, err } - reexec := defaultTraceReexec - if config != nil && config.Reexec != nil { - reexec = *config.Reexec - } - statedb, release, err := api.backend.StateAtBlock(ctx, parent, reexec, nil, true, false) + statedb, 
release, err := api.backend.StateAtBlock(ctx, parent, nil, true, false) if err != nil { return nil, err } @@ -877,15 +854,11 @@ func (api *API) TraceTransaction(ctx context.Context, hash common.Hash, config * if blockNumber == 0 { return nil, errors.New("genesis is not traceable") } - reexec := defaultTraceReexec - if config != nil && config.Reexec != nil { - reexec = *config.Reexec - } block, err := api.blockByNumberAndHash(ctx, rpc.BlockNumber(blockNumber), blockHash) if err != nil { return nil, err } - tx, vmctx, statedb, release, err := api.backend.StateAtTransaction(ctx, block, int(index), reexec) + tx, vmctx, statedb, release, err := api.backend.StateAtTransaction(ctx, block, int(index)) if err != nil { return nil, err } @@ -939,15 +912,10 @@ func (api *API) TraceCall(ctx context.Context, args ethapi.TransactionArgs, bloc return nil, err } // try to recompute the state - reexec := defaultTraceReexec - if config != nil && config.Reexec != nil { - reexec = *config.Reexec - } - if config != nil && config.TxIndex != nil { - _, _, statedb, release, err = api.backend.StateAtTransaction(ctx, block, int(*config.TxIndex), reexec) + _, _, statedb, release, err = api.backend.StateAtTransaction(ctx, block, int(*config.TxIndex)) } else { - statedb, release, err = api.backend.StateAtBlock(ctx, block, reexec, nil, true, false) + statedb, release, err = api.backend.StateAtBlock(ctx, block, nil, true, false) } if err != nil { return nil, err diff --git a/eth/tracers/api_test.go b/eth/tracers/api_test.go index 1d5024ad08..ecf3c99c8f 100644 --- a/eth/tracers/api_test.go +++ b/eth/tracers/api_test.go @@ -151,7 +151,7 @@ func (b *testBackend) teardown() { b.chain.Stop() } -func (b *testBackend) StateAtBlock(ctx context.Context, block *types.Block, reexec uint64, base *state.StateDB, readOnly bool, preferDisk bool) (*state.StateDB, StateReleaseFunc, error) { +func (b *testBackend) StateAtBlock(ctx context.Context, block *types.Block, base *state.StateDB, readOnly bool, preferDisk 
bool) (*state.StateDB, StateReleaseFunc, error) { statedb, err := b.chain.StateAt(block.Root()) if err != nil { return nil, nil, errStateNotFound @@ -167,12 +167,12 @@ func (b *testBackend) StateAtBlock(ctx context.Context, block *types.Block, reex return statedb, release, nil } -func (b *testBackend) StateAtTransaction(ctx context.Context, block *types.Block, txIndex int, reexec uint64) (*types.Transaction, vm.BlockContext, *state.StateDB, StateReleaseFunc, error) { +func (b *testBackend) StateAtTransaction(ctx context.Context, block *types.Block, txIndex int) (*types.Transaction, vm.BlockContext, *state.StateDB, StateReleaseFunc, error) { parent := b.chain.GetBlock(block.ParentHash(), block.NumberU64()-1) if parent == nil { return nil, vm.BlockContext{}, nil, nil, errBlockNotFound } - statedb, release, err := b.StateAtBlock(ctx, parent, reexec, nil, true, false) + statedb, release, err := b.StateAtBlock(ctx, parent, nil, true, false) if err != nil { return nil, vm.BlockContext{}, nil, nil, errStateNotFound } @@ -202,6 +202,18 @@ type stateTracer struct { Storage map[common.Address]map[common.Hash]common.Hash } +type tracedOpcodeLog struct { + Op string `json:"op"` + Refund *uint64 `json:"refund,omitempty"` + Storage map[string]string `json:"storage,omitempty"` +} + +type tracedOpcodeResult struct { + Failed bool `json:"failed"` + ReturnValue string `json:"returnValue"` + StructLogs []tracedOpcodeLog `json:"structLogs"` +} + func newStateTracer(ctx *Context, cfg json.RawMessage, chainCfg *params.ChainConfig) (*Tracer, error) { t := &stateTracer{ Balance: make(map[common.Address]*hexutil.Big), @@ -1058,6 +1070,176 @@ func TestTracingWithOverrides(t *testing.T) { } } +func TestTraceTransactionRefundAndStorageSnapshots(t *testing.T) { + t.Parallel() + + accounts := newAccounts(1) + contract := common.HexToAddress("0x00000000000000000000000000000000deadbeef") + slot0 := common.BigToHash(big.NewInt(0)) + txSigner := types.HomesteadSigner{} + genesis := &core.Genesis{ + 
Config: params.TestChainConfig, + Alloc: types.GenesisAlloc{ + accounts[0].addr: {Balance: big.NewInt(params.Ether)}, + contract: { + Nonce: 1, + Code: []byte{ + byte(vm.PUSH1), 0x00, + byte(vm.SLOAD), + byte(vm.POP), + byte(vm.PUSH1), 0x00, + byte(vm.PUSH1), 0x00, + byte(vm.SSTORE), + byte(vm.STOP), + }, + Storage: map[common.Hash]common.Hash{ + slot0: common.BigToHash(big.NewInt(1)), + }, + }, + }, + } + var target common.Hash + backend := newTestBackend(t, 1, genesis, func(i int, b *core.BlockGen) { + tx, _ := types.SignTx(types.NewTx(&types.LegacyTx{ + Nonce: 0, + To: &contract, + Value: big.NewInt(0), + Gas: 100000, + GasPrice: b.BaseFee(), + }), txSigner, accounts[0].key) + b.AddTx(tx) + target = tx.Hash() + }) + defer backend.teardown() + + api := NewAPI(backend) + result, err := api.TraceTransaction(context.Background(), target, nil) + if err != nil { + t.Fatalf("failed to trace refunding transaction: %v", err) + } + var traced tracedOpcodeResult + if err := json.Unmarshal(result.(json.RawMessage), &traced); err != nil { + t.Fatalf("failed to unmarshal trace result: %v", err) + } + if traced.Failed { + t.Fatal("expected refunding transaction to succeed") + } + if traced.ReturnValue != "0x" { + t.Fatalf("unexpected return value: have %s want 0x", traced.ReturnValue) + } + slotHex := slot0.Hex() + oneHex := common.BigToHash(big.NewInt(1)).Hex() + zeroHex := common.Hash{}.Hex() + var ( + foundSloadSnapshot bool + foundSstoreSnapshot bool + foundRefund bool + ) + for _, log := range traced.StructLogs { + switch log.Op { + case "SLOAD": + if got := log.Storage[slotHex]; got == oneHex { + foundSloadSnapshot = true + } + case "SSTORE": + if got := log.Storage[slotHex]; got == zeroHex { + foundSstoreSnapshot = true + } + } + if log.Refund != nil && *log.Refund > 0 { + foundRefund = true + } + } + if !foundSloadSnapshot { + t.Fatal("expected SLOAD snapshot to include the pre-existing non-zero storage value") + } + if !foundSstoreSnapshot { + t.Fatal("expected SSTORE 
snapshot to include the post-write zeroed storage value") + } + if !foundRefund { + t.Fatal("expected at least one structLog entry with a non-zero refund field") + } +} + +func TestTraceTransactionFailureReturnValues(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + code []byte + wantReturnValue string + }{ + { + name: "revert preserves return data", + code: []byte{ + byte(vm.PUSH1), 0x2a, + byte(vm.PUSH1), 0x00, + byte(vm.MSTORE), + byte(vm.PUSH1), 0x20, + byte(vm.PUSH1), 0x00, + byte(vm.REVERT), + }, + wantReturnValue: "0x000000000000000000000000000000000000000000000000000000000000002a", + }, + { + name: "hard failure clears return data", + code: []byte{ + byte(vm.INVALID), + }, + wantReturnValue: "0x", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + accounts := newAccounts(1) + contract := common.HexToAddress("0x00000000000000000000000000000000deadbeef") + txSigner := types.HomesteadSigner{} + genesis := &core.Genesis{ + Config: params.TestChainConfig, + Alloc: types.GenesisAlloc{ + accounts[0].addr: {Balance: big.NewInt(params.Ether)}, + contract: { + Nonce: 1, + Code: tc.code, + }, + }, + } + var target common.Hash + backend := newTestBackend(t, 1, genesis, func(i int, b *core.BlockGen) { + tx, _ := types.SignTx(types.NewTx(&types.LegacyTx{ + Nonce: 0, + To: &contract, + Value: big.NewInt(0), + Gas: 100000, + GasPrice: b.BaseFee(), + }), txSigner, accounts[0].key) + b.AddTx(tx) + target = tx.Hash() + }) + defer backend.teardown() + + api := NewAPI(backend) + result, err := api.TraceTransaction(context.Background(), target, nil) + if err != nil { + t.Fatalf("failed to trace transaction: %v", err) + } + var traced tracedOpcodeResult + if err := json.Unmarshal(result.(json.RawMessage), &traced); err != nil { + t.Fatalf("failed to unmarshal trace result: %v", err) + } + if !traced.Failed { + t.Fatal("expected traced transaction to fail") + } + if traced.ReturnValue != tc.wantReturnValue { + t.Fatalf("unexpected 
returnValue: have %s want %s", traced.ReturnValue, tc.wantReturnValue) + } + if len(traced.StructLogs) == 0 { + t.Fatal("expected failing trace to still include structLogs") + } + }) + } +} + type Account struct { key *ecdsa.PrivateKey addr common.Address diff --git a/eth/tracers/js/tracer_test.go b/eth/tracers/js/tracer_test.go index b21e104abc..694debcf98 100644 --- a/eth/tracers/js/tracer_test.go +++ b/eth/tracers/js/tracer_test.go @@ -66,9 +66,9 @@ func runTrace(tracer *tracers.Tracer, vmctx *vmContext, chaincfg *params.ChainCo tracer.OnTxStart(evm.GetVMContext(), types.NewTx(&types.LegacyTx{Gas: gasLimit, GasPrice: vmctx.txCtx.GasPrice.ToBig()}), contract.Caller()) tracer.OnEnter(0, byte(vm.CALL), contract.Caller(), contract.Address(), []byte{}, startGas, value.ToBig()) ret, err := evm.Run(contract, []byte{}, false) - tracer.OnExit(0, ret, startGas-contract.Gas, err, true) + tracer.OnExit(0, ret, startGas-contract.Gas.RegularGas, err, true) // Rest gas assumes no refund - tracer.OnTxEnd(&types.Receipt{GasUsed: gasLimit - contract.Gas}, nil) + tracer.OnTxEnd(&types.Receipt{GasUsed: gasLimit - contract.Gas.RegularGas}, nil) if err != nil { return nil, err } diff --git a/eth/tracers/logger/logger.go b/eth/tracers/logger/logger.go index 67e07f78d0..7f2b2aecf2 100644 --- a/eth/tracers/logger/logger.go +++ b/eth/tracers/logger/logger.go @@ -148,7 +148,7 @@ type structLogLegacy struct { Gas uint64 `json:"gas"` GasCost uint64 `json:"gasCost"` Depth int `json:"depth"` - Error string `json:"error,omitempty"` + Error string `json:"error,omitempty,omitzero"` Stack *[]string `json:"stack,omitempty"` ReturnData string `json:"returnData,omitempty"` Memory *[]string `json:"memory,omitempty"` @@ -156,6 +156,15 @@ type structLogLegacy struct { RefundCounter uint64 `json:"refund,omitempty"` } +func formatMemoryWord(chunk []byte) string { + if len(chunk) == 32 { + return hexutil.Encode(chunk) + } + var word [32]byte + copy(word[:], chunk) + return hexutil.Encode(word[:]) +} + // 
toLegacyJSON converts the structLog to legacy json-encoded legacy form. func (s *StructLog) toLegacyJSON() json.RawMessage { msg := structLogLegacy{ @@ -175,7 +184,7 @@ func (s *StructLog) toLegacyJSON() json.RawMessage { msg.Stack = &stack } if len(s.ReturnData) > 0 { - msg.ReturnData = hexutil.Bytes(s.ReturnData).String() + msg.ReturnData = hexutil.Encode(s.ReturnData) } if len(s.Memory) > 0 { memory := make([]string, 0, (len(s.Memory)+31)/32) @@ -184,14 +193,14 @@ func (s *StructLog) toLegacyJSON() json.RawMessage { if end > len(s.Memory) { end = len(s.Memory) } - memory = append(memory, fmt.Sprintf("%x", s.Memory[i:end])) + memory = append(memory, formatMemoryWord(s.Memory[i:end])) } msg.Memory = &memory } if len(s.Storage) > 0 { storage := make(map[string]string) for i, storageValue := range s.Storage { - storage[fmt.Sprintf("%x", i)] = fmt.Sprintf("%x", storageValue) + storage[i.Hex()] = storageValue.Hex() } msg.Storage = &storage } diff --git a/eth/tracers/logger/logger_test.go b/eth/tracers/logger/logger_test.go index acc3069e70..554a37aff1 100644 --- a/eth/tracers/logger/logger_test.go +++ b/eth/tracers/logger/logger_test.go @@ -96,3 +96,46 @@ func TestStructLogMarshalingOmitEmpty(t *testing.T) { }) } } + +func TestStructLogLegacyJSONSpecFormatting(t *testing.T) { + tests := []struct { + name string + log *StructLog + want string + }{ + { + name: "omits empty error and pads memory/storage", + log: &StructLog{ + Pc: 7, + Op: vm.SSTORE, + Gas: 100, + GasCost: 20, + Memory: []byte{0xaa, 0xbb}, + Storage: map[common.Hash]common.Hash{common.BigToHash(big.NewInt(1)): common.BigToHash(big.NewInt(2))}, + Depth: 1, + ReturnData: []byte{0x12, 0x34}, + }, + want: 
`{"pc":7,"op":"SSTORE","gas":100,"gasCost":20,"depth":1,"returnData":"0x1234","memory":["0xaabb000000000000000000000000000000000000000000000000000000000000"],"storage":{"0x0000000000000000000000000000000000000000000000000000000000000001":"0x0000000000000000000000000000000000000000000000000000000000000002"}}`, + }, + { + name: "includes error only when present", + log: &StructLog{ + Pc: 1, + Op: vm.STOP, + Gas: 2, + GasCost: 3, + Depth: 1, + Err: errors.New("boom"), + }, + want: `{"pc":1,"op":"STOP","gas":2,"gasCost":3,"depth":1,"error":"boom"}`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + have := string(tt.log.toLegacyJSON()) + if have != tt.want { + t.Fatalf("mismatched results\n\thave: %v\n\twant: %v", have, tt.want) + } + }) + } +} diff --git a/ethclient/ethclient.go b/ethclient/ethclient.go index bc4eaad6fa..412f8955ba 100644 --- a/ethclient/ethclient.go +++ b/ethclient/ethclient.go @@ -498,7 +498,11 @@ func (ec *Client) SubscribeFilterLogs(ctx context.Context, q ethereum.FilterQuer func toFilterArg(q ethereum.FilterQuery) (interface{}, error) { arg := map[string]interface{}{} - if q.Addresses != nil { + // Only include "address" when there are actual address filters. + // An empty slice is treated the same as nil (no filter), and omitting + // the field avoids sending "address":[] to nodes that reject empty arrays + // (e.g. Hedera, some non-Geth implementations). 
+ if len(q.Addresses) > 0 { arg["address"] = q.Addresses } if q.Topics != nil { @@ -838,6 +842,7 @@ type rpcProgress struct { TxIndexFinishedBlocks hexutil.Uint64 TxIndexRemainingBlocks hexutil.Uint64 StateIndexRemaining hexutil.Uint64 + TrienodeIndexRemaining hexutil.Uint64 } func (p *rpcProgress) toSyncProgress() *ethereum.SyncProgress { @@ -865,6 +870,7 @@ func (p *rpcProgress) toSyncProgress() *ethereum.SyncProgress { TxIndexFinishedBlocks: uint64(p.TxIndexFinishedBlocks), TxIndexRemainingBlocks: uint64(p.TxIndexRemainingBlocks), StateIndexRemaining: uint64(p.StateIndexRemaining), + TrienodeIndexRemaining: uint64(p.TrienodeIndexRemaining), } } diff --git a/ethclient/types_test.go b/ethclient/types_test.go index dcb9a579b7..8820b11162 100644 --- a/ethclient/types_test.go +++ b/ethclient/types_test.go @@ -53,6 +53,22 @@ func TestToFilterArg(t *testing.T) { }, nil, }, + { + // empty Addresses slice must be treated same as nil: + // the "address" field must be omitted so that non-Geth nodes + // (e.g. Hedera) do not reject the request with an error. + "with empty addresses slice", + ethereum.FilterQuery{ + Addresses: []common.Address{}, + FromBlock: big.NewInt(1), + ToBlock: big.NewInt(2), + }, + map[string]interface{}{ + "fromBlock": "0x1", + "toBlock": "0x2", + }, + nil, + }, { "without BlockHash", ethereum.FilterQuery{ diff --git a/graphql/graphql.go b/graphql/graphql.go index f25bfd127a..dadc91fac0 100644 --- a/graphql/graphql.go +++ b/graphql/graphql.go @@ -1531,6 +1531,9 @@ func (s *SyncState) TxIndexRemainingBlocks() hexutil.Uint64 { func (s *SyncState) StateIndexRemaining() hexutil.Uint64 { return hexutil.Uint64(s.progress.StateIndexRemaining) } +func (s *SyncState) TrienodeIndexRemaining() hexutil.Uint64 { + return hexutil.Uint64(s.progress.TrienodeIndexRemaining) +} // Syncing returns false in case the node is currently not syncing with the network. It can be up-to-date or has not // yet received the latest block headers from its peers. 
In case it is synchronizing: diff --git a/interfaces.go b/interfaces.go index 21d42c6d34..8b3dbe3a42 100644 --- a/interfaces.go +++ b/interfaces.go @@ -139,8 +139,9 @@ type SyncProgress struct { TxIndexFinishedBlocks uint64 // Number of blocks whose transactions are already indexed TxIndexRemainingBlocks uint64 // Number of blocks whose transactions are not indexed yet - // "historical state indexing" fields - StateIndexRemaining uint64 // Number of states remain unindexed + // "historical data indexing" fields + StateIndexRemaining uint64 // Number of states remain unindexed + TrienodeIndexRemaining uint64 // Number of trienodes remain unindexed } // Done returns the indicator if the initial sync is finished or not. @@ -148,7 +149,7 @@ func (prog SyncProgress) Done() bool { if prog.CurrentBlock < prog.HighestBlock { return false } - return prog.TxIndexRemainingBlocks == 0 && prog.StateIndexRemaining == 0 + return prog.TxIndexRemainingBlocks == 0 && prog.StateIndexRemaining == 0 && prog.TrienodeIndexRemaining == 0 } // ChainSyncReader wraps access to the node's current sync status. 
If there's no diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go index 4f217d0578..149e12c5b8 100644 --- a/internal/ethapi/api.go +++ b/internal/ethapi/api.go @@ -180,6 +180,7 @@ func (api *EthereumAPI) Syncing(ctx context.Context) (interface{}, error) { "txIndexFinishedBlocks": hexutil.Uint64(progress.TxIndexFinishedBlocks), "txIndexRemainingBlocks": hexutil.Uint64(progress.TxIndexRemainingBlocks), "stateIndexRemaining": hexutil.Uint64(progress.StateIndexRemaining), + "trienodeIndexRemaining": hexutil.Uint64(progress.TrienodeIndexRemaining), }, nil } @@ -897,6 +898,7 @@ func DoEstimateGas(ctx context.Context, b Backend, args TransactionArgs, blockNr if err := blockOverrides.Apply(&blockCtx); err != nil { return 0, err } + header = blockOverrides.MakeHeader(header) } rules := b.ChainConfig().Rules(blockCtx.BlockNumber, blockCtx.Random != nil, blockCtx.Time) precompiles := vm.ActivePrecompiledContracts(rules) @@ -904,13 +906,17 @@ func DoEstimateGas(ctx context.Context, b Backend, args TransactionArgs, blockNr return 0, err } // Construct the gas estimator option from the user input + var blobBaseFee *big.Int + if blockOverrides != nil && blockOverrides.BlobBaseFee != nil { + blobBaseFee = blockOverrides.BlobBaseFee.ToInt() + } opts := &gasestimator.Options{ - Config: b.ChainConfig(), - Chain: NewChainContext(ctx, b), - Header: header, - BlockOverrides: blockOverrides, - State: state, - ErrorRatio: estimateGasErrorRatio, + Config: b.ChainConfig(), + Chain: NewChainContext(ctx, b), + Header: header, + State: state, + BlobBaseFee: blobBaseFee, + ErrorRatio: estimateGasErrorRatio, } // Set any required transaction default, but make sure the gas cap itself is not messed with // if it was not specified in the original argument list. 
diff --git a/internal/ethapi/api_test.go b/internal/ethapi/api_test.go index 62e9979d3d..b010eeaa08 100644 --- a/internal/ethapi/api_test.go +++ b/internal/ethapi/api_test.go @@ -780,6 +780,17 @@ func TestEstimateGas(t *testing.T) { expectErr: core.ErrInsufficientFunds, want: 21000, }, + // block override gas limit should bound estimation search space. + { + blockNumber: rpc.LatestBlockNumber, + call: TransactionArgs{ + From: &accounts[0].addr, + Input: hex2Bytes("6080604052348015600f57600080fd5b50483a1015601c57600080fd5b60003a111560315760004811603057600080fd5b5b603f80603e6000396000f3fe6080604052600080fdfea264697066735822122060729c2cee02b10748fae5200f1c9da4661963354973d9154c13a8e9ce9dee1564736f6c63430008130033"), + Gas: func() *hexutil.Uint64 { v := hexutil.Uint64(0); return &v }(), + }, + blockOverrides: override.BlockOverrides{GasLimit: func() *hexutil.Uint64 { v := hexutil.Uint64(50000); return &v }()}, + expectErr: errors.New("gas required exceeds allowance (50000)"), + }, // empty create { blockNumber: rpc.LatestBlockNumber, @@ -861,6 +872,19 @@ func TestEstimateGas(t *testing.T) { }, want: 21000, }, + // blob base fee block override should be applied during estimation. 
+ { + blockNumber: rpc.LatestBlockNumber, + call: TransactionArgs{ + From: &accounts[0].addr, + To: &accounts[1].addr, + Value: (*hexutil.Big)(big.NewInt(1)), + BlobHashes: []common.Hash{{0x01, 0x22}}, + BlobFeeCap: (*hexutil.Big)(big.NewInt(1)), + }, + blockOverrides: override.BlockOverrides{BlobBaseFee: (*hexutil.Big)(big.NewInt(2))}, + expectErr: core.ErrBlobFeeCapTooLow, + }, // // SPDX-License-Identifier: GPL-3.0 //pragma solidity >=0.8.2 <0.9.0; // @@ -1014,7 +1038,7 @@ func TestCall(t *testing.T) { Balance: big.NewInt(params.Ether), Nonce: 1, Storage: map[common.Hash]common.Hash{ - common.Hash{}: common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000001"), + {}: common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000001"), }, }, }, @@ -3795,7 +3819,7 @@ func TestCreateAccessListWithStateOverrides(t *testing.T) { Balance: (*hexutil.Big)(big.NewInt(1000000000000000000)), Nonce: &nonce, State: map[common.Hash]common.Hash{ - common.Hash{}: common.HexToHash("0x000000000000000000000000000000000000000000000000000000000000002a"), + {}: common.HexToHash("0x000000000000000000000000000000000000000000000000000000000000002a"), }, }, } diff --git a/tests/init.go b/tests/init.go index d10b47986c..f115e427a5 100644 --- a/tests/init.go +++ b/tests/init.go @@ -720,6 +720,43 @@ var Forks = map[string]*params.ChainConfig{ BPO4: params.DefaultBPO4BlobConfig, }, }, + "Amsterdam": { + ChainID: big.NewInt(1), + HomesteadBlock: big.NewInt(0), + EIP150Block: big.NewInt(0), + EIP155Block: big.NewInt(0), + EIP158Block: big.NewInt(0), + ByzantiumBlock: big.NewInt(0), + ConstantinopleBlock: big.NewInt(0), + PetersburgBlock: big.NewInt(0), + IstanbulBlock: big.NewInt(0), + MuirGlacierBlock: big.NewInt(0), + BerlinBlock: big.NewInt(0), + LondonBlock: big.NewInt(0), + ArrowGlacierBlock: big.NewInt(0), + MergeNetsplitBlock: big.NewInt(0), + TerminalTotalDifficulty: big.NewInt(0), + ShanghaiTime: u64(0), + CancunTime: u64(0), + PragueTime: 
u64(0), + OsakaTime: u64(0), + BPO1Time: u64(0), + BPO2Time: u64(0), + BPO3Time: u64(0), + BPO4Time: u64(0), + AmsterdamTime: u64(0), + DepositContractAddress: params.MainnetChainConfig.DepositContractAddress, + BlobScheduleConfig: &params.BlobScheduleConfig{ + Cancun: params.DefaultCancunBlobConfig, + Prague: params.DefaultPragueBlobConfig, + Osaka: params.DefaultOsakaBlobConfig, + BPO1: bpo1BlobConfig, + BPO2: bpo2BlobConfig, + BPO3: params.DefaultBPO3BlobConfig, + BPO4: params.DefaultBPO4BlobConfig, + Amsterdam: params.DefaultBPO4BlobConfig, // TODO update when defined + }, + }, "Verkle": { ChainID: big.NewInt(1), HomesteadBlock: big.NewInt(0), diff --git a/trie/bintrie/stem_node.go b/trie/bintrie/stem_node.go index 3f69261d62..e5729e6182 100644 --- a/trie/bintrie/stem_node.go +++ b/trie/bintrie/stem_node.go @@ -37,7 +37,10 @@ type StemNode struct { // Get retrieves the value for the given key. func (bt *StemNode) Get(key []byte, _ NodeResolverFn) ([]byte, error) { - panic("this should not be called directly") + if !bytes.Equal(bt.Stem, key[:StemSize]) { + return nil, nil + } + return bt.Values[key[StemSize]], nil } // Insert inserts a new key-value pair into the node. diff --git a/trie/bintrie/stem_node_test.go b/trie/bintrie/stem_node_test.go index 92c1b49e02..310c553d39 100644 --- a/trie/bintrie/stem_node_test.go +++ b/trie/bintrie/stem_node_test.go @@ -23,6 +23,50 @@ import ( "github.com/ethereum/go-ethereum/common" ) +// TestStemNodeGet tests the Get method for matching stem, non-matching stem, +// and nil-value suffix scenarios. +func TestStemNodeGet(t *testing.T) { + stem := make([]byte, StemSize) + stem[0] = 0xAB + var values [StemNodeWidth][]byte + values[5] = common.HexToHash("0xdeadbeef").Bytes() + + node := &StemNode{Stem: stem, Values: values[:], depth: 0} + + // Matching stem, populated suffix → returns value. 
+ key := make([]byte, HashSize) + copy(key[:StemSize], stem) + key[StemSize] = 5 + got, err := node.Get(key, nil) + if err != nil { + t.Fatalf("Get error: %v", err) + } + if !bytes.Equal(got, values[5]) { + t.Fatalf("Get = %x, want %x", got, values[5]) + } + + // Matching stem, empty suffix → returns nil (slot not set). + key[StemSize] = 99 + got, err = node.Get(key, nil) + if err != nil { + t.Fatalf("Get error: %v", err) + } + if got != nil { + t.Fatalf("Get(empty suffix) = %x, want nil", got) + } + + // Non-matching stem → returns nil, nil. + otherKey := make([]byte, HashSize) + otherKey[0] = 0xFF + got, err = node.Get(otherKey, nil) + if err != nil { + t.Fatalf("Get error: %v", err) + } + if got != nil { + t.Fatalf("Get(wrong stem) = %x, want nil", got) + } +} + // TestStemNodeInsertSameStem tests inserting values with the same stem func TestStemNodeInsertSameStem(t *testing.T) { stem := make([]byte, 31) diff --git a/trie/bintrie/trie.go b/trie/bintrie/trie.go index 6c29239a87..b1e3c991c0 100644 --- a/trie/bintrie/trie.go +++ b/trie/bintrie/trie.go @@ -191,7 +191,7 @@ func (t *BinaryTrie) GetAccount(addr common.Address) (*types.StateAccount, error case *InternalNode: values, err = r.GetValuesAtStem(key[:StemSize], t.nodeResolver) case *StemNode: - values = r.Values + values, err = r.GetValuesAtStem(key[:StemSize], t.nodeResolver) case Empty: return nil, nil default: @@ -216,10 +216,12 @@ func (t *BinaryTrie) GetAccount(addr common.Address) (*types.StateAccount, error return nil, nil } - // If the account has been deleted, then values[10] will be 0 and not nil. If it has - // been recreated after that, then its code keccak will NOT be 0. So return `nil` if - // the nonce, and values[10], and code keccak is 0. - if bytes.Equal(values[BasicDataLeafKey], zero[:]) && len(values) > 10 && len(values[10]) > 0 && bytes.Equal(values[CodeHashLeafKey], zero[:]) { + // If the account has been deleted, BasicData and CodeHash will both be + // 32-byte zero blobs (not nil). 
If the account is recreated afterwards, + // UpdateAccount overwrites BasicData and CodeHash with non-zero values, + // so this branch won't activate. + if bytes.Equal(values[BasicDataLeafKey], zero[:]) && + bytes.Equal(values[CodeHashLeafKey], zero[:]) { return nil, nil } @@ -294,8 +296,22 @@ func (t *BinaryTrie) UpdateStorage(address common.Address, key, value []byte) er return nil } -// DeleteAccount is a no-op as it is disabled in stateless. +// DeleteAccount erases an account by overwriting the account +// descriptors with 0s. func (t *BinaryTrie) DeleteAccount(addr common.Address) error { + var ( + values = make([][]byte, StemNodeWidth) + stem = GetBinaryTreeKey(addr, zero[:]) + ) + // Clear BasicData (nonce, balance, code size) and CodeHash. + values[BasicDataLeafKey] = zero[:] + values[CodeHashLeafKey] = zero[:] + + root, err := t.root.InsertValuesAtStem(stem, values, t.nodeResolver, 0) + if err != nil { + return fmt.Errorf("DeleteAccount (%x) error: %v", addr, err) + } + t.root = root return nil } diff --git a/trie/bintrie/trie_test.go b/trie/bintrie/trie_test.go index 256fd218e2..5b104ddde4 100644 --- a/trie/bintrie/trie_test.go +++ b/trie/bintrie/trie_test.go @@ -267,6 +267,334 @@ func TestStorageRoundTrip(t *testing.T) { } } +// newEmptyTestTrie creates a fresh BinaryTrie with an empty root and a +// default prevalue tracer. Use this for tests that populate the trie +// incrementally via Update*; for tests that want a pre-populated trie with +// a fixed entry set, use makeTrie (in iterator_test.go) instead. +func newEmptyTestTrie(t *testing.T) *BinaryTrie { + t.Helper() + return &BinaryTrie{ + root: NewBinaryNode(), + tracer: trie.NewPrevalueTracer(), + } +} + +// makeAccount constructs a StateAccount with the given fields. The Root is +// zeroed out because the bintrie has no per-account storage root. 
+func makeAccount(nonce uint64, balance uint64, codeHash common.Hash) *types.StateAccount { + return &types.StateAccount{ + Nonce: nonce, + Balance: uint256.NewInt(balance), + CodeHash: codeHash.Bytes(), + } +} + +// TestDeleteAccountRoundTrip verifies the basic delete path: create an +// account, read it back, delete it, confirm subsequent reads return nil. +// Regression test for the no-op DeleteAccount bug where the deletion was +// silently ignored and the old values remained in the trie. +func TestDeleteAccountRoundTrip(t *testing.T) { + tr := newEmptyTestTrie(t) + addr := common.HexToAddress("0x1234567890abcdef1234567890abcdef12345678") + codeHash := common.HexToHash("c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470") + + // Create: write account, verify round-trip. + acc := makeAccount(42, 1000, codeHash) + if err := tr.UpdateAccount(addr, acc, 0); err != nil { + t.Fatalf("UpdateAccount: %v", err) + } + got, err := tr.GetAccount(addr) + if err != nil { + t.Fatalf("GetAccount: %v", err) + } + if got == nil { + t.Fatal("GetAccount returned nil after UpdateAccount") + } + if got.Nonce != 42 { + t.Fatalf("Nonce: got %d, want 42", got.Nonce) + } + if got.Balance.Uint64() != 1000 { + t.Fatalf("Balance: got %s, want 1000", got.Balance) + } + if !bytes.Equal(got.CodeHash, codeHash[:]) { + t.Fatalf("CodeHash: got %x, want %x", got.CodeHash, codeHash) + } + + // Delete: verify GetAccount returns nil afterwards. + if err := tr.DeleteAccount(addr); err != nil { + t.Fatalf("DeleteAccount: %v", err) + } + got, err = tr.GetAccount(addr) + if err != nil { + t.Fatalf("GetAccount after delete: %v", err) + } + if got != nil { + t.Fatalf("GetAccount after delete: got %+v, want nil", got) + } +} + +// TestDeleteAccountOnMissingAccount verifies that deleting an account that +// was never created does not error and subsequent reads still return nil. 
+func TestDeleteAccountOnMissingAccount(t *testing.T) { + tr := newEmptyTestTrie(t) + addr := common.HexToAddress("0x1234567890abcdef1234567890abcdef12345678") + + // Delete without any prior create. Should not panic or error on an + // empty root, and GetAccount should still return nil. + if err := tr.DeleteAccount(addr); err != nil { + t.Fatalf("DeleteAccount on empty trie: %v", err) + } + got, err := tr.GetAccount(addr) + if err != nil { + t.Fatalf("GetAccount after delete on empty trie: %v", err) + } + if got != nil { + t.Fatalf("GetAccount on deleted missing account: got %+v, want nil", got) + } +} + +// TestDeleteAccountPreservesOtherAccounts verifies that deleting one account +// does not affect accounts at different stems. +func TestDeleteAccountPreservesOtherAccounts(t *testing.T) { + tr := newEmptyTestTrie(t) + addrA := common.HexToAddress("0x1234567890abcdef1234567890abcdef12345678") + addrB := common.HexToAddress("0xabcdef1234567890abcdef1234567890abcdef12") + codeHashA := common.HexToHash("c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470") + codeHashB := common.HexToHash("f0f1f2f3f4f5f6f7f8f9fafbfcfdfeff0102030405060708090a0b0c0d0e0f10") + + // Create two distinct accounts. + if err := tr.UpdateAccount(addrA, makeAccount(1, 100, codeHashA), 0); err != nil { + t.Fatalf("UpdateAccount(A): %v", err) + } + if err := tr.UpdateAccount(addrB, makeAccount(2, 200, codeHashB), 0); err != nil { + t.Fatalf("UpdateAccount(B): %v", err) + } + + // Delete A. + if err := tr.DeleteAccount(addrA); err != nil { + t.Fatalf("DeleteAccount(A): %v", err) + } + + // A should be gone. + if got, err := tr.GetAccount(addrA); err != nil { + t.Fatalf("GetAccount(A): %v", err) + } else if got != nil { + t.Fatalf("GetAccount(A) after delete: got %+v, want nil", got) + } + + // B should still be readable with its original values. 
+ got, err := tr.GetAccount(addrB) + if err != nil { + t.Fatalf("GetAccount(B): %v", err) + } + if got == nil { + t.Fatal("GetAccount(B) returned nil after unrelated delete") + } + if got.Nonce != 2 { + t.Fatalf("Account B Nonce: got %d, want 2", got.Nonce) + } + if got.Balance.Uint64() != 200 { + t.Fatalf("Account B Balance: got %s, want 200", got.Balance) + } + if !bytes.Equal(got.CodeHash, codeHashB[:]) { + t.Fatalf("Account B CodeHash: got %x, want %x", got.CodeHash, codeHashB) + } +} + +// TestDeleteAccountThenRecreate verifies that an account can be deleted and +// then recreated with different values; the second read must return the new +// values, not the stale ones from before deletion. +func TestDeleteAccountThenRecreate(t *testing.T) { + tr := newEmptyTestTrie(t) + addr := common.HexToAddress("0x1234567890abcdef1234567890abcdef12345678") + codeHash1 := common.HexToHash("1111111111111111111111111111111111111111111111111111111111111111") + codeHash2 := common.HexToHash("2222222222222222222222222222222222222222222222222222222222222222") + + // Create. + if err := tr.UpdateAccount(addr, makeAccount(1, 100, codeHash1), 0); err != nil { + t.Fatalf("UpdateAccount #1: %v", err) + } + // Delete. + if err := tr.DeleteAccount(addr); err != nil { + t.Fatalf("DeleteAccount: %v", err) + } + // Recreate with new values. + if err := tr.UpdateAccount(addr, makeAccount(7, 9999, codeHash2), 0); err != nil { + t.Fatalf("UpdateAccount #2: %v", err) + } + // Read: must observe the new values, not the originals. 
+ got, err := tr.GetAccount(addr) + if err != nil { + t.Fatalf("GetAccount: %v", err) + } + if got == nil { + t.Fatal("GetAccount returned nil after recreate") + } + if got.Nonce != 7 { + t.Fatalf("Nonce: got %d, want 7", got.Nonce) + } + if got.Balance.Uint64() != 9999 { + t.Fatalf("Balance: got %s, want 9999", got.Balance) + } + if !bytes.Equal(got.CodeHash, codeHash2[:]) { + t.Fatalf("CodeHash: got %x, want %x", got.CodeHash, codeHash2) + } +} + +// TestDeleteAccountDoesNotAffectMainStorage verifies that DeleteAccount only +// clears the account's BasicData and CodeHash, leaving main storage slots +// untouched. Main storage slots live at different stems entirely (their +// keys route through the non-header branch in GetBinaryTreeKeyStorageSlot), +// so this test exercises the inter-stem isolation. Header-range storage +// slots share the same stem and are covered separately by +// TestDeleteAccountPreservesHeaderStorage. +// +// Wiping storage on self-destruct is a separate concern handled at the +// StateDB level. +func TestDeleteAccountDoesNotAffectMainStorage(t *testing.T) { + tr := newEmptyTestTrie(t) + addr := common.HexToAddress("0x1234567890abcdef1234567890abcdef12345678") + codeHash := common.HexToHash("c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470") + + // Create account. + if err := tr.UpdateAccount(addr, makeAccount(1, 100, codeHash), 0); err != nil { + t.Fatalf("UpdateAccount: %v", err) + } + // Write a main storage slot — i.e. key[31] >= 64 or key[:31] != 0 — so + // it lives at a different stem from the account header. + slot := common.HexToHash("0000000000000000000000000000000000000000000000000000000000000080") + value := common.TrimLeftZeroes(common.HexToHash("00000000000000000000000000000000000000000000000000000000deadbeef").Bytes()) + if err := tr.UpdateStorage(addr, slot[:], value); err != nil { + t.Fatalf("UpdateStorage: %v", err) + } + + // Delete the account. 
+ if err := tr.DeleteAccount(addr); err != nil { + t.Fatalf("DeleteAccount: %v", err) + } + + // Account should be absent. + got, err := tr.GetAccount(addr) + if err != nil { + t.Fatalf("GetAccount after delete: %v", err) + } + if got != nil { + t.Fatalf("GetAccount after delete: got %+v, want nil", got) + } + + // Main storage slot should still be readable — DeleteAccount must not + // have touched it. + stored, err := tr.GetStorage(addr, slot[:]) + if err != nil { + t.Fatalf("GetStorage after DeleteAccount: %v", err) + } + if len(stored) == 0 { + t.Fatal("main storage slot was wiped by DeleteAccount, expected it to survive") + } + var expected [HashSize]byte + copy(expected[HashSize-len(value):], value) + if !bytes.Equal(stored, expected[:]) { + t.Fatalf("main storage slot: got %x, want %x", stored, expected) + } +} + +// TestDeleteAccountPreservesHeaderStorage verifies that DeleteAccount does +// not clobber header-range storage slots (key[31] < 64), which live at the +// SAME stem as BasicData/CodeHash but at offsets 64-127. The safety here +// relies on StemNode.InsertValuesAtStem treating nil entries in the values +// slice as "do not overwrite"; this test pins that invariant so a future +// change cannot silently corrupt slots 0-63 of any contract. +func TestDeleteAccountPreservesHeaderStorage(t *testing.T) { + tr := newEmptyTestTrie(t) + addr := common.HexToAddress("0x1234567890abcdef1234567890abcdef12345678") + codeHash := common.HexToHash("c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470") + + // Create account. + if err := tr.UpdateAccount(addr, makeAccount(1, 100, codeHash), 0); err != nil { + t.Fatalf("UpdateAccount: %v", err) + } + + // Create a second, unrelated account so the root promotes from StemNode + // to InternalNode. BinaryTrie.GetStorage walks via root.Get, which is + // only implemented on InternalNode/Empty — calling it with a StemNode + // root panics. 
The existing main-storage test gets away with this because + // the main-storage slot lands on a separate stem and forces the same + // promotion implicitly; here we want a same-stem header slot, so the + // promotion has to come from a second account. + other := common.HexToAddress("0xabcdef1234567890abcdef1234567890abcdef12") + if err := tr.UpdateAccount(other, makeAccount(0, 0, common.Hash{}), 0); err != nil { + t.Fatalf("UpdateAccount(other): %v", err) + } + + // Write a header-range storage slot — key[:31] == 0 and key[31] < 64 + // — which routes through the header branch in GetBinaryTreeKeyStorageSlot + // and lands on the same stem as BasicData/CodeHash. + var slot [HashSize]byte + slot[31] = 5 + value := []byte{0xde, 0xad, 0xbe, 0xef} + if err := tr.UpdateStorage(addr, slot[:], value); err != nil { + t.Fatalf("UpdateStorage: %v", err) + } + + // Delete the account. + if err := tr.DeleteAccount(addr); err != nil { + t.Fatalf("DeleteAccount: %v", err) + } + + // Account metadata should be gone. + got, err := tr.GetAccount(addr) + if err != nil { + t.Fatalf("GetAccount after delete: %v", err) + } + if got != nil { + t.Fatalf("GetAccount after delete: got %+v, want nil", got) + } + + // Header storage slot must survive — DeleteAccount only writes offsets + // BasicDataLeafKey and CodeHashLeafKey, leaving + // the header-storage offsets (64-127) untouched. 
+ stored, err := tr.GetStorage(addr, slot[:]) + if err != nil { + t.Fatalf("GetStorage after DeleteAccount: %v", err) + } + if len(stored) == 0 { + t.Fatal("header storage slot was wiped by DeleteAccount, expected it to survive") + } + var expected [HashSize]byte + copy(expected[HashSize-len(value):], value) + if !bytes.Equal(stored, expected[:]) { + t.Fatalf("header storage slot: got %x, want %x", stored, expected) + } +} + +func TestDeleteAccountHashIsDeterministic(t *testing.T) { + addr := common.HexToAddress("0x1234567890abcdef1234567890abcdef12345678") + codeHash := common.HexToHash("c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470") + acc := makeAccount(42, 1000, codeHash) + + run := func() common.Hash { + tr := newEmptyTestTrie(t) + if err := tr.UpdateAccount(addr, acc, 0); err != nil { + t.Fatalf("UpdateAccount: %v", err) + } + if err := tr.DeleteAccount(addr); err != nil { + t.Fatalf("DeleteAccount: %v", err) + } + return tr.Hash() + } + + first := run() + second := run() + if first != second { + t.Fatalf("non-deterministic root after Update+Delete: first=%x second=%x", first, second) + } + + empty := newEmptyTestTrie(t).Hash() + if first == empty { + t.Fatalf("post-delete root unexpectedly equals empty-trie root %x", empty) + } +} + func TestBinaryTrieWitness(t *testing.T) { tracer := trie.NewPrevalueTracer() @@ -292,3 +620,162 @@ func TestBinaryTrieWitness(t *testing.T) { t.Fatal("unexpected witness value for path2") } } + +// testAccount is a helper that creates a BinaryTrie with a tracer and +// inserts a single account, returning the trie. 
+func testAccount(t *testing.T, addr common.Address, nonce uint64, balance uint64) *BinaryTrie { + t.Helper() + tr := &BinaryTrie{ + root: NewBinaryNode(), + tracer: trie.NewPrevalueTracer(), + } + acc := &types.StateAccount{ + Nonce: nonce, + Balance: uint256.NewInt(balance), + CodeHash: types.EmptyCodeHash[:], + } + if err := tr.UpdateAccount(addr, acc, 0); err != nil { + t.Fatalf("UpdateAccount error: %v", err) + } + return tr +} + +// TestGetAccountNonMembershipStemRoot verifies that querying a non-existent +// address returns nil when the trie root is a StemNode (single-account trie). +// This is a regression test: previously the StemNode branch in GetAccount +// returned the root's values without verifying the stem. +func TestGetAccountNonMembershipStemRoot(t *testing.T) { + addr := common.HexToAddress("0x1111111111111111111111111111111111111111") + tr := testAccount(t, addr, 42, 100) + + // Verify root is a StemNode (single stem inserted). + if _, ok := tr.root.(*StemNode); !ok { + t.Fatalf("expected StemNode root, got %T", tr.root) + } + + // Query a completely different address — must return nil. + other := common.HexToAddress("0x2222222222222222222222222222222222222222") + got, err := tr.GetAccount(other) + if err != nil { + t.Fatalf("GetAccount error: %v", err) + } + if got != nil { + t.Fatalf("expected nil for non-existent account, got nonce=%d balance=%s", got.Nonce, got.Balance) + } + + // Original account must still be retrievable. + got, err = tr.GetAccount(addr) + if err != nil { + t.Fatalf("GetAccount(original) error: %v", err) + } + if got == nil { + t.Fatal("expected original account, got nil") + } + if got.Nonce != 42 { + t.Fatalf("expected nonce=42, got %d", got.Nonce) + } +} + +// TestGetAccountNonMembershipInternalRoot verifies that querying a non-existent +// address returns nil when the trie root is an InternalNode (multi-account trie). 
+func TestGetAccountNonMembershipInternalRoot(t *testing.T) {
+	tr := &BinaryTrie{
+		root:   NewBinaryNode(),
+		tracer: trie.NewPrevalueTracer(),
+	}
+
+	// Insert two accounts whose binary tree keys have different first bits
+	// so the root splits into an InternalNode.
+	addr1 := common.HexToAddress("0x1111111111111111111111111111111111111111")
+	addr2 := common.HexToAddress("0x9999999999999999999999999999999999999999")
+	for _, addr := range []common.Address{addr1, addr2} {
+		acc := &types.StateAccount{
+			Nonce:    1,
+			Balance:  uint256.NewInt(1),
+			CodeHash: types.EmptyCodeHash[:],
+		}
+		if err := tr.UpdateAccount(addr, acc, 0); err != nil {
+			t.Fatalf("UpdateAccount error: %v", err)
+		}
+	}
+
+	// Verify root is an InternalNode. If the two keys happened to share a
+	// stem this precondition fails and the test aborts here.
+	if _, ok := tr.root.(*InternalNode); !ok {
+		t.Fatalf("expected InternalNode root, got %T", tr.root)
+	}
+
+	// Query a non-existent address — must return nil.
+	other := common.HexToAddress("0x5555555555555555555555555555555555555555")
+	got, err := tr.GetAccount(other)
+	if err != nil {
+		t.Fatalf("GetAccount error: %v", err)
+	}
+	if got != nil {
+		t.Fatalf("expected nil for non-existent account, got nonce=%d", got.Nonce)
+	}
+}
+
+// TestGetStorageNonMembershipStemRoot verifies that querying storage for a
+// non-existent address returns nil when the root is a StemNode. This is a
+// regression test: previously StemNode.Get panicked unconditionally.
+func TestGetStorageNonMembershipStemRoot(t *testing.T) {
+	addr := common.HexToAddress("0x1111111111111111111111111111111111111111")
+	tr := testAccount(t, addr, 1, 100)
+
+	// Verify root is a StemNode.
+	if _, ok := tr.root.(*StemNode); !ok {
+		t.Fatalf("expected StemNode root, got %T", tr.root)
+	}
+
+	// Query storage for a different address — must return nil, not panic.
+	other := common.HexToAddress("0x2222222222222222222222222222222222222222")
+	slot := common.HexToHash("0x01")
+	got, err := tr.GetStorage(other, slot[:])
+	if err != nil {
+		t.Fatalf("GetStorage error: %v", err)
+	}
+	// Both an empty result and a value equal to `zero` are accepted.
+	// NOTE(review): `zero` is declared elsewhere in the package; presumably
+	// an all-zero value — confirm.
+	if len(got) > 0 && !bytes.Equal(got, zero[:]) {
+		t.Fatalf("expected nil/zero for non-existent storage, got %x", got)
+	}
+}
+
+// TestGetStorageNonMembershipInternalRoot verifies that querying storage for a
+// non-existent address returns nil when the root is an InternalNode.
+func TestGetStorageNonMembershipInternalRoot(t *testing.T) {
+	tr := &BinaryTrie{
+		root:   NewBinaryNode(),
+		tracer: trie.NewPrevalueTracer(),
+	}
+
+	addr := common.HexToAddress("0x1234567890abcdef1234567890abcdef12345678")
+	acc := &types.StateAccount{
+		Nonce:    1,
+		Balance:  uint256.NewInt(1000),
+		CodeHash: types.EmptyCodeHash[:],
+	}
+	if err := tr.UpdateAccount(addr, acc, 0); err != nil {
+		t.Fatalf("UpdateAccount error: %v", err)
+	}
+
+	// Add a storage slot so the root becomes an InternalNode (storage
+	// slots use a different stem than account data).
+	slot := common.HexToHash("0xFF")
+	val := common.TrimLeftZeroes(common.HexToHash("0xdeadbeef").Bytes())
+	if err := tr.UpdateStorage(addr, slot[:], val); err != nil {
+		t.Fatalf("UpdateStorage error: %v", err)
+	}
+
+	if _, ok := tr.root.(*InternalNode); !ok {
+		t.Fatalf("expected InternalNode root, got %T", tr.root)
+	}
+
+	// Query storage for a non-existent address — must return nil.
+	// The same slot key is reused, so only the address differs from the
+	// entry written above.
+	other := common.HexToAddress("0x9999999999999999999999999999999999999999")
+	got, err := tr.GetStorage(other, slot[:])
+	if err != nil {
+		t.Fatalf("GetStorage error: %v", err)
+	}
+	// Both an empty result and a value equal to `zero` are accepted.
+	if len(got) > 0 && !bytes.Equal(got, zero[:]) {
+		t.Fatalf("expected nil/zero for non-existent storage, got %x", got)
+	}
+}
diff --git a/triedb/database.go b/triedb/database.go
index e7e47bb91a..c1abe93462 100644
--- a/triedb/database.go
+++ b/triedb/database.go
@@ -367,10 +367,10 @@ func (db *Database) StorageIterator(root common.Hash, account common.Hash, seek
 
 // IndexProgress returns the indexing progress made so far. It provides the
 // number of states that remain unindexed.
-func (db *Database) IndexProgress() (uint64, error) {
+func (db *Database) IndexProgress() (uint64, uint64, error) {
 	pdb, ok := db.backend.(*pathdb.Database)
 	if !ok {
-		return 0, errors.New("not supported")
+		return 0, 0, errors.New("not supported")
 	}
 	return pdb.IndexProgress()
 }
diff --git a/triedb/generate.go b/triedb/generate.go
new file mode 100644
index 0000000000..259e139848
--- /dev/null
+++ b/triedb/generate.go
@@ -0,0 +1,108 @@
+// Copyright 2026 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package triedb
+
+import (
+	"fmt"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/core/rawdb"
+	"github.com/ethereum/go-ethereum/ethdb"
+	"github.com/ethereum/go-ethereum/triedb/internal"
+)
+
+// kvAccountIterator adapts a raw ethdb.Iterator over the account snapshot
+// key space so that it can be used as an internal.AccountIterator.
+type kvAccountIterator struct {
+	it   ethdb.Iterator
+	hash common.Hash
+}
+
+// newKVAccountIterator opens an iterator over rawdb.SnapshotAccountPrefix and
+// wraps it with rawdb.NewKeyLengthIterator using the prefix+hash key length.
+func newKVAccountIterator(db ethdb.Iteratee) *kvAccountIterator {
+	var (
+		raw    = db.NewIterator(rawdb.SnapshotAccountPrefix, nil)
+		keyLen = len(rawdb.SnapshotAccountPrefix) + common.HashLength
+	)
+	return &kvAccountIterator{it: rawdb.NewKeyLengthIterator(raw, keyLen)}
+}
+
+// Next advances the underlying iterator. On success the bytes following the
+// snapshot prefix of the current key are cached as the account hash.
+func (it *kvAccountIterator) Next() bool {
+	if it.it.Next() {
+		copy(it.hash[:], it.it.Key()[len(rawdb.SnapshotAccountPrefix):])
+		return true
+	}
+	return false
+}
+
+// Hash returns the account hash cached by the last successful Next.
+func (it *kvAccountIterator) Hash() common.Hash {
+	return it.hash
+}
+
+// Account returns the value of the underlying iterator at its current position.
+func (it *kvAccountIterator) Account() []byte {
+	return it.it.Value()
+}
+
+// Error forwards the error state of the underlying iterator.
+func (it *kvAccountIterator) Error() error {
+	return it.it.Error()
+}
+
+// Release releases the underlying iterator.
+func (it *kvAccountIterator) Release() {
+	it.it.Release()
+}
+
+// kvStorageIterator wraps an ethdb.Iterator to iterate over storage snapshot
+// entries for a specific account, implementing internal.StorageIterator.
+type kvStorageIterator struct {
+	it   ethdb.Iterator
+	hash common.Hash
+}
+
+// newKVStorageIterator opens a storage snapshot iterator for accountHash via
+// rawdb.IterateStorageSnapshots and wraps it.
+func newKVStorageIterator(db ethdb.Iteratee, accountHash common.Hash) *kvStorageIterator {
+	return &kvStorageIterator{
+		it: rawdb.IterateStorageSnapshots(db, accountHash),
+	}
+}
+
+// Next advances the underlying iterator. On success the key bytes that follow
+// the storage snapshot prefix and the account hash are cached as the slot hash.
+func (it *kvStorageIterator) Next() bool {
+	if it.it.Next() {
+		offset := len(rawdb.SnapshotStoragePrefix) + common.HashLength
+		copy(it.hash[:], it.it.Key()[offset:])
+		return true
+	}
+	return false
+}
+
+// Hash returns the slot hash cached by the last successful Next.
+func (it *kvStorageIterator) Hash() common.Hash {
+	return it.hash
+}
+
+// Slot returns the value of the underlying iterator at its current position.
+func (it *kvStorageIterator) Slot() []byte {
+	return it.it.Value()
+}
+
+// Error forwards the error state of the underlying iterator.
+func (it *kvStorageIterator) Error() error {
+	return it.it.Error()
+}
+
+// Release releases the underlying iterator.
+func (it *kvStorageIterator) Release() {
+	it.it.Release()
+}
+
+// GenerateTrie rebuilds all tries (storage + account) from flat snapshot data
+// in the database. It reads account and storage snapshots from the KV store,
+// builds tries using StackTrie with streaming node writes, and verifies the
+// computed state root matches the expected root.
+func GenerateTrie(db ethdb.Database, scheme string, root common.Hash) error {
+	acctIt := newKVAccountIterator(db)
+	defer acctIt.Release()
+
+	// The zero owner hash selects account-trie mode in GenerateTrieRoot. The
+	// callback below is invoked per account to rebuild its storage trie, and
+	// GenerateTrieRoot compares the returned hash with the account's Root.
+	// The callback's codeHash argument is not used here.
+	got, err := internal.GenerateTrieRoot(db, scheme, acctIt, common.Hash{}, internal.StackTrieGenerate, func(dst ethdb.KeyValueWriter, accountHash, codeHash common.Hash, stat *internal.GenerateStats) (common.Hash, error) {
+		storageIt := newKVStorageIterator(db, accountHash)
+		defer storageIt.Release()
+
+		// Storage tries are generated without a leaf callback and without
+		// their own progress reporter; stats are shared with the outer run.
+		hash, err := internal.GenerateTrieRoot(dst, scheme, storageIt, accountHash, internal.StackTrieGenerate, nil, stat, false)
+		if err != nil {
+			return common.Hash{}, err
+		}
+		return hash, nil
+	}, internal.NewGenerateStats(), true)
+	if err != nil {
+		return err
+	}
+	// Reject the result if the rebuilt account trie does not hash to root.
+	if got != root {
+		return fmt.Errorf("state root mismatch: got %x, want %x", got, root)
+	}
+	return nil
+}
diff --git a/triedb/generate_test.go b/triedb/generate_test.go
new file mode 100644
index 0000000000..42bccd9aa3
--- /dev/null
+++ b/triedb/generate_test.go
@@ -0,0 +1,178 @@
+// Copyright 2026 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package triedb
+
+import (
+	"bytes"
+	"sort"
+	"testing"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/core/rawdb"
+	"github.com/ethereum/go-ethereum/core/types"
+	"github.com/ethereum/go-ethereum/rlp"
+	"github.com/ethereum/go-ethereum/trie"
+	"github.com/holiman/uint256"
+)
+
+// testAccount is a helper for building test state with deterministic ordering.
+type testAccount struct {
+	hash    common.Hash        // key of the account in the account trie / snapshot
+	account types.StateAccount // full account; RLP-encoded as the trie value
+	storage []testSlot         // must be sorted by hash
+}
+
+// testSlot is a single storage entry: the slot key hash and its raw value.
+type testSlot struct {
+	hash  common.Hash
+	value []byte
+}
+
+// buildExpectedRoot computes the state root from sorted test accounts using
+// StackTrie (which requires sorted key insertion).
+//
+// The accounts slice is sorted in place by hash, so the caller's slice is
+// reordered as a side effect. An RLP encoding failure aborts the test.
+func buildExpectedRoot(t *testing.T, accounts []testAccount) common.Hash {
+	t.Helper()
+	// Sort accounts by hash
+	sort.Slice(accounts, func(i, j int) bool {
+		return bytes.Compare(accounts[i].hash[:], accounts[j].hash[:]) < 0
+	})
+	acctTrie := trie.NewStackTrie(nil)
+	for i := range accounts {
+		data, err := rlp.EncodeToBytes(&accounts[i].account)
+		if err != nil {
+			t.Fatal(err)
+		}
+		acctTrie.Update(accounts[i].hash[:], data)
+	}
+	return acctTrie.Hash()
+}
+
+// computeStorageRoot computes the storage trie root from sorted slots.
+func computeStorageRoot(slots []testSlot) common.Hash {
+	// Sort in place by slot hash before feeding the StackTrie; the caller's
+	// slice is reordered as a side effect.
+	sort.Slice(slots, func(i, j int) bool {
+		return bytes.Compare(slots[i].hash[:], slots[j].hash[:]) < 0
+	})
+	st := trie.NewStackTrie(nil)
+	for _, s := range slots {
+		st.Update(s.hash[:], s.value)
+	}
+	return st.Hash()
+}
+
+// TestGenerateTrieEmpty checks that GenerateTrie on a database without any
+// snapshot entries succeeds when the expected root is types.EmptyRootHash.
+func TestGenerateTrieEmpty(t *testing.T) {
+	db := rawdb.NewMemoryDatabase()
+	if err := GenerateTrie(db, rawdb.HashScheme, types.EmptyRootHash); err != nil {
+		t.Fatalf("GenerateTrie on empty state failed: %v", err)
+	}
+}
+
+// TestGenerateTrieAccountsOnly checks GenerateTrie against two accounts that
+// have no storage (Root is types.EmptyRootHash).
+func TestGenerateTrieAccountsOnly(t *testing.T) {
+	db := rawdb.NewMemoryDatabase()
+
+	accounts := []testAccount{
+		{
+			hash: common.HexToHash("0x01"),
+			account: types.StateAccount{
+				Nonce:    1,
+				Balance:  uint256.NewInt(100),
+				Root:     types.EmptyRootHash,
+				CodeHash: types.EmptyCodeHash.Bytes(),
+			},
+		},
+		{
+			hash: common.HexToHash("0x02"),
+			account: types.StateAccount{
+				Nonce:    2,
+				Balance:  uint256.NewInt(200),
+				Root:     types.EmptyRootHash,
+				CodeHash: types.EmptyCodeHash.Bytes(),
+			},
+		},
+	}
+	// Snapshots hold the slim account encoding; the expected root below is
+	// computed from the full RLP encoding.
+	for _, a := range accounts {
+		rawdb.WriteAccountSnapshot(db, a.hash, types.SlimAccountRLP(a.account))
+	}
+	root := buildExpectedRoot(t, accounts)
+
+	if err := GenerateTrie(db, rawdb.HashScheme, root); err != nil {
+		t.Fatalf("GenerateTrie failed: %v", err)
+	}
+}
+
+// TestGenerateTrieWithStorage checks GenerateTrie when one account owns two
+// storage slots and a second account has none.
+func TestGenerateTrieWithStorage(t *testing.T) {
+	db := rawdb.NewMemoryDatabase()
+
+	slots := []testSlot{
+		{hash: common.HexToHash("0xaa"), value: []byte{0x01, 0x02, 0x03}},
+		{hash: common.HexToHash("0xbb"), value: []byte{0x04, 0x05, 0x06}},
+	}
+	storageRoot := computeStorageRoot(slots)
+
+	accounts := []testAccount{
+		{
+			hash: common.HexToHash("0x01"),
+			account: types.StateAccount{
+				Nonce:    1,
+				Balance:  uint256.NewInt(100),
+				Root:     storageRoot,
+				CodeHash: types.EmptyCodeHash.Bytes(),
+			},
+			storage: slots,
+		},
+		{
+			hash: common.HexToHash("0x02"),
+			account: types.StateAccount{
+				Nonce:    0,
+				Balance:  uint256.NewInt(50),
+				Root:     types.EmptyRootHash,
+				CodeHash: types.EmptyCodeHash.Bytes(),
+			},
+		},
+	}
+	// Write account snapshots
+	for _, a := range accounts {
+		rawdb.WriteAccountSnapshot(db, a.hash, types.SlimAccountRLP(a.account))
+	}
+	// Write storage snapshots
+	for _, a := range accounts {
+		for _, s := range a.storage {
+			rawdb.WriteStorageSnapshot(db, a.hash, s.hash, s.value)
+		}
+	}
+	root := buildExpectedRoot(t, accounts)
+
+	if err := GenerateTrie(db, rawdb.HashScheme, root); err != nil {
+		t.Fatalf("GenerateTrie failed: %v", err)
+	}
+}
+
+// TestGenerateTrieRootMismatch checks that GenerateTrie returns an error when
+// the expected root differs from the root rebuilt from the snapshot data.
+func TestGenerateTrieRootMismatch(t *testing.T) {
+	db := rawdb.NewMemoryDatabase()
+
+	acct := types.StateAccount{
+		Nonce:    1,
+		Balance:  uint256.NewInt(100),
+		Root:     types.EmptyRootHash,
+		CodeHash: types.EmptyCodeHash.Bytes(),
+	}
+	rawdb.WriteAccountSnapshot(db, common.HexToHash("0x01"), types.SlimAccountRLP(acct))
+
+	wrongRoot := common.HexToHash("0xdeadbeef")
+	err := GenerateTrie(db, rawdb.HashScheme, wrongRoot)
+	if err == nil {
+		t.Fatal("expected error for root mismatch, got nil")
+	}
+}
diff --git a/triedb/internal/conversion.go b/triedb/internal/conversion.go
new file mode 100644
index 0000000000..b331b63e21
--- /dev/null
+++ b/triedb/internal/conversion.go
@@ -0,0 +1,363 @@
+// Copyright 2026 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+// Package internal contains shared trie generation utilities used by both
+// triedb and triedb/pathdb. All code is ported from
+// core/state/snapshot/conversion.go (with exported names) unless noted.
+package internal
+
+import (
+	"encoding/binary"
+	"fmt"
+	"math"
+	"runtime"
+	"sync"
+	"time"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/core/rawdb"
+	"github.com/ethereum/go-ethereum/core/types"
+	"github.com/ethereum/go-ethereum/ethdb"
+	"github.com/ethereum/go-ethereum/log"
+	"github.com/ethereum/go-ethereum/rlp"
+	"github.com/ethereum/go-ethereum/trie"
+)
+
+// Iterator is an iterator to step over all the accounts or the specific
+// storage in a snapshot which may or may not be composed of multiple layers.
+type Iterator interface {
+	// Next steps the iterator forward one element. It returns false both when
+	// the iterator is exhausted and when iteration failed for some reason
+	// (e.g. root being iterated becomes stale and garbage collected); use
+	// Error to tell the two cases apart.
+	Next() bool
+
+	// Error returns any failure that occurred during iteration, which might have
+	// caused a premature iteration exit (e.g. snapshot stack becoming stale).
+	Error() error
+
+	// Hash returns the hash of the account or storage slot the iterator is
+	// currently at.
+	Hash() common.Hash
+
+	// Release releases associated resources. Release should always succeed and
+	// can be called multiple times without causing error.
+	Release()
+}
+
+// AccountIterator is an iterator to step over all the accounts in a snapshot,
+// which may or may not be composed of multiple layers.
+type AccountIterator interface {
+	Iterator
+
+	// Account returns the RLP encoded slim account the iterator is currently at.
+	// The method has no error result; if the iterator becomes invalid the
+	// failure is reported through Error.
+	Account() []byte
+}
+
+// StorageIterator is an iterator to step over the specific storage in a snapshot,
+// which may or may not be composed of multiple layers.
+type StorageIterator interface {
+	Iterator
+
+	// Slot returns the storage slot the iterator is currently at. The method
+	// has no error result; if the iterator becomes invalid the failure is
+	// reported through Error.
+	Slot() []byte
+}
+
+// TrieKV represents a trie key-value pair.
+type TrieKV struct {
+	Key   common.Hash
+	Value []byte
+}
+
+type (
+	// TrieGeneratorFn is the interface of trie generation which can
+	// be implemented by different trie algorithm. An implementation consumes
+	// leaves from in until the channel is closed and then delivers the
+	// resulting root hash on out.
+	TrieGeneratorFn func(db ethdb.KeyValueWriter, scheme string, owner common.Hash, in chan (TrieKV), out chan (common.Hash))
+
+	// LeafCallbackFn is the callback invoked at the leaves of the trie,
+	// returns the subtrie root with the specified subtrie identifier.
+	LeafCallbackFn func(db ethdb.KeyValueWriter, accountHash, codeHash common.Hash, stat *GenerateStats) (common.Hash, error)
+)
+
+// GenerateStats is a collection of statistics gathered by the trie generator
+// for logging purposes. All fields are guarded by lock.
+type GenerateStats struct {
+	head  common.Hash // Hash of the account most recently reported via ProgressAccounts
+	start time.Time   // Creation time of the stats, used for elapsed/ETA reporting
+
+	accounts uint64 // Number of accounts done (including those being crawled)
+	slots    uint64 // Number of storage slots done (including those being crawled)
+
+	slotsStart map[common.Hash]time.Time   // Start time for account slot crawling
+	slotsHead  map[common.Hash]common.Hash // Slot head for accounts being crawled
+
+	lock sync.RWMutex
+}
+
+// NewGenerateStats creates a new generator stats. The start time is set to the
+// moment of creation.
+func NewGenerateStats() *GenerateStats {
+	return &GenerateStats{
+		slotsStart: make(map[common.Hash]time.Time),
+		slotsHead:  make(map[common.Hash]common.Hash),
+		start:      time.Now(),
+	}
+}
+
+// ProgressAccounts updates the generator stats for the account range. It adds
+// done to the account counter and records account as the current head.
+func (stat *GenerateStats) ProgressAccounts(account common.Hash, done uint64) {
+	stat.lock.Lock()
+	defer stat.lock.Unlock()
+
+	stat.accounts += done
+	stat.head = account
+}
+
+// FinishAccounts updates the generator stats for the finished account range.
+func (stat *GenerateStats) FinishAccounts(done uint64) {
+	stat.lock.Lock()
+	defer stat.lock.Unlock()
+
+	stat.accounts += done
+}
+
+// ProgressContract updates the generator stats for a specific in-progress contract.
+// It adds done to the slot counter, records slot as the contract's current head
+// and remembers the time of the first report for that contract.
+func (stat *GenerateStats) ProgressContract(account common.Hash, slot common.Hash, done uint64) {
+	stat.lock.Lock()
+	defer stat.lock.Unlock()
+
+	stat.slots += done
+	stat.slotsHead[account] = slot
+	if _, ok := stat.slotsStart[account]; !ok {
+		stat.slotsStart[account] = time.Now()
+	}
+}
+
+// FinishContract updates the generator stats for a specific just-finished contract.
+// It adds done to the slot counter and drops the contract's in-progress markers.
+func (stat *GenerateStats) FinishContract(account common.Hash, done uint64) {
+	stat.lock.Lock()
+	defer stat.lock.Unlock()
+
+	stat.slots += done
+	delete(stat.slotsHead, account)
+	delete(stat.slotsStart, account)
+}
+
+// Report prints the cumulative progress statistic smartly. An ETA is included
+// only once at least one account was counted and the estimate derived from the
+// first eight bytes of the current head hash is non-zero.
+func (stat *GenerateStats) Report() {
+	stat.lock.RLock()
+	defer stat.lock.RUnlock()
+
+	ctx := []interface{}{
+		"accounts", stat.accounts,
+		"slots", stat.slots,
+		"elapsed", common.PrettyDuration(time.Since(stat.start)),
+	}
+	if stat.accounts > 0 {
+		// The first 8 bytes of the head hash serve as the position within
+		// the uint64 key space.
+		if done := binary.BigEndian.Uint64(stat.head[:8]) / stat.accounts; done > 0 {
+			var (
+				left = (math.MaxUint64 - binary.BigEndian.Uint64(stat.head[:8])) / stat.accounts
+				eta  = common.CalculateETA(done, left, time.Since(stat.start))
+			)
+			// If there are large contract crawls in progress, estimate their finish time
+			for acc, head := range stat.slotsHead {
+				start := stat.slotsStart[acc]
+				if done := binary.BigEndian.Uint64(head[:8]); done > 0 {
+					left := math.MaxUint64 - binary.BigEndian.Uint64(head[:8])
+
+					// Override the ETA if larger than the largest until now
+					if slotETA := common.CalculateETA(done, left, time.Since(start)); eta < slotETA {
+						eta = slotETA
+					}
+				}
+			}
+			ctx = append(ctx, []interface{}{
+				"eta", common.PrettyDuration(eta),
+			}...)
+		}
+	}
+	log.Info("Iterating state snapshot", ctx...)
+}
+
+// ReportDone prints the last log when the whole generation is finished. The
+// slot counter is only logged when it is non-zero.
+func (stat *GenerateStats) ReportDone() {
+	stat.lock.RLock()
+	defer stat.lock.RUnlock()
+
+	var ctx []interface{}
+	ctx = append(ctx, []interface{}{"accounts", stat.accounts}...)
+	if stat.slots != 0 {
+		ctx = append(ctx, []interface{}{"slots", stat.slots}...)
+	}
+	ctx = append(ctx, []interface{}{"elapsed", common.PrettyDuration(time.Since(stat.start))}...)
+	log.Info("Iterated snapshot", ctx...)
+}
+
+// RunReport periodically prints the progress information. It reports once
+// immediately and then every eight seconds until a value arrives on stop; a
+// true value additionally triggers the final ReportDone log.
+func RunReport(stats *GenerateStats, stop chan bool) {
+	timer := time.NewTimer(0)
+	defer timer.Stop()
+
+	for {
+		select {
+		case <-timer.C:
+			stats.Report()
+			timer.Reset(time.Second * 8)
+		case success := <-stop:
+			if success {
+				stats.ReportDone()
+			}
+			return
+		}
+	}
+}
+
+// GenerateTrieRoot generates the trie hash based on the snapshot iterator.
+// It can be used for generating account trie, storage trie or even the
+// whole state which connects the accounts and the corresponding storages.
+//
+// A zero account hash selects account mode, in which it must be an
+// AccountIterator; any other value selects storage mode, in which it must be
+// a StorageIterator. Leaves are streamed to generatorFn, which produces the
+// returned root. In account mode a non-nil leafCallback is run concurrently
+// for every account (bounded by runtime.NumCPU) and its result is compared
+// with the account's storage root. stats may be nil; report enables the
+// periodic progress logger when stats is non-nil.
+//
+// The first error encountered (account decoding, leaf callback, subroot
+// mismatch or iterator failure) is returned together with whatever root the
+// generator produced from the leaves fed so far.
+func GenerateTrieRoot(db ethdb.KeyValueWriter, scheme string, it Iterator, account common.Hash, generatorFn TrieGeneratorFn, leafCallback LeafCallbackFn, stats *GenerateStats, report bool) (common.Hash, error) {
+	var (
+		in      = make(chan TrieKV)         // chan to pass leaves
+		out     = make(chan common.Hash, 1) // chan to collect result
+		stoplog = make(chan bool, 1)        // 1-size buffer, works when logging is not enabled
+		wg      sync.WaitGroup
+	)
+	// Spin up a go-routine for trie hash re-generation
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		generatorFn(db, scheme, account, in, out)
+	}()
+	// Spin up a go-routine for progress logging
+	if report && stats != nil {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			RunReport(stats, stoplog)
+		}()
+	}
+	// Create a semaphore to assign tasks and collect results through. We'll pre-
+	// fill it with nils, thus using the same channel for both limiting concurrent
+	// processing and gathering results.
+	threads := runtime.NumCPU()
+	results := make(chan error, threads)
+	for i := 0; i < threads; i++ {
+		results <- nil // fill the semaphore
+	}
+	// stop is a helper function to shutdown the background threads
+	// and return the re-generated trie hash. It drains exactly `threads`
+	// tokens from results, so every token taken from the semaphore must
+	// eventually be handed back before or while stop runs.
+	stop := func(fail error) (common.Hash, error) {
+		close(in)
+		result := <-out
+		for i := 0; i < threads; i++ {
+			if err := <-results; err != nil && fail == nil {
+				fail = err
+			}
+		}
+		stoplog <- fail == nil
+
+		wg.Wait()
+		return result, fail
+	}
+	var (
+		logged    = time.Now()
+		processed = uint64(0)
+		leaf      TrieKV
+	)
+	// Start to feed leaves
+	for it.Next() {
+		if account == (common.Hash{}) {
+			var (
+				err      error
+				fullData []byte
+			)
+			if leafCallback == nil {
+				fullData, err = types.FullAccountRLP(it.(AccountIterator).Account())
+				if err != nil {
+					return stop(err)
+				}
+			} else {
+				// Wait until the semaphore allows us to continue, aborting if
+				// a sub-task failed
+				if err := <-results; err != nil {
+					results <- nil // stop will drain the results, add a noop back for this error we just consumed
+					return stop(err)
+				}
+				// Fetch the next account and process it concurrently
+				account, err := types.FullAccount(it.(AccountIterator).Account())
+				if err != nil {
+					// No sub-task was spawned for the token taken above, so
+					// hand it back here; otherwise stop would block forever
+					// waiting for a result that never arrives.
+					results <- nil
+					return stop(err)
+				}
+				go func(hash common.Hash) {
+					subroot, err := leafCallback(db, hash, common.BytesToHash(account.CodeHash), stats)
+					if err != nil {
+						results <- err
+						return
+					}
+					if account.Root != subroot {
+						results <- fmt.Errorf("invalid subroot(path %x), want %x, have %x", hash, account.Root, subroot)
+						return
+					}
+					results <- nil
+				}(it.Hash())
+				// The sub-task spawned above returns the token, so the error
+				// path below does not need to hand one back.
+				fullData, err = rlp.EncodeToBytes(account)
+				if err != nil {
+					return stop(err)
+				}
+			}
+			leaf = TrieKV{it.Hash(), fullData}
+		} else {
+			leaf = TrieKV{it.Hash(), common.CopyBytes(it.(StorageIterator).Slot())}
+		}
+		in <- leaf
+
+		// Accumulate the generation statistic if it's required.
+		processed++
+		if time.Since(logged) > 3*time.Second && stats != nil {
+			if account == (common.Hash{}) {
+				stats.ProgressAccounts(it.Hash(), processed)
+			} else {
+				stats.ProgressContract(account, it.Hash(), processed)
+			}
+			logged, processed = time.Now(), 0
+		}
+	}
+	// Next also returns false when iteration failed; surface that failure
+	// instead of reporting a root computed from a truncated leaf set.
+	if err := it.Error(); err != nil {
+		return stop(err)
+	}
+	// Commit the last part statistic.
+	if processed > 0 && stats != nil {
+		if account == (common.Hash{}) {
+			stats.FinishAccounts(processed)
+		} else {
+			stats.FinishContract(account, processed)
+		}
+	}
+	return stop(nil)
+}
+
+// StackTrieGenerate is the trie generation function that creates a StackTrie
+// and persists nodes via rawdb.WriteTrieNode. When db is nil no node callback
+// is installed and only the root hash is computed. Leaves are consumed from in
+// until it is closed, after which the root hash is sent on out.
+func StackTrieGenerate(db ethdb.KeyValueWriter, scheme string, owner common.Hash, in chan TrieKV, out chan common.Hash) {
+	var onTrieNode trie.OnTrieNode
+	if db != nil {
+		onTrieNode = func(path []byte, hash common.Hash, blob []byte) {
+			rawdb.WriteTrieNode(db, owner, path, hash, blob, scheme)
+		}
+	}
+	t := trie.NewStackTrie(onTrieNode)
+	for leaf := range in {
+		t.Update(leaf.Key[:], leaf.Value)
+	}
+	out <- t.Hash()
+}
diff --git a/triedb/pathdb/database.go b/triedb/pathdb/database.go
index 86a42c69f4..a61d302b1d 100644
--- a/triedb/pathdb/database.go
+++ b/triedb/pathdb/database.go
@@ -626,11 +626,26 @@ func (db *Database) HistoryRange() (uint64, uint64, error) {
 
 // IndexProgress returns the indexing progress made so far. It provides the
 // number of states that remain unindexed.
-func (db *Database) IndexProgress() (uint64, error) {
-	if db.stateIndexer == nil {
-		return 0, nil
+func (db *Database) IndexProgress() (uint64, uint64, error) {
+	var (
+		stateProgress uint64
+		trieProgress  uint64
+	)
+	if db.stateIndexer != nil {
+		prog, err := db.stateIndexer.progress()
+		if err != nil {
+			return 0, 0, err
+		}
+		stateProgress = prog
 	}
-	return db.stateIndexer.progress()
+	if db.trienodeIndexer != nil {
+		prog, err := db.trienodeIndexer.progress()
+		if err != nil {
+			return 0, 0, err
+		}
+		trieProgress = prog
+	}
+	return stateProgress, trieProgress, nil
 }
 
 // AccountIterator creates a new account iterator for the specified root hash and
diff --git a/triedb/pathdb/database_test.go b/triedb/pathdb/database_test.go
index 8ece83cad7..e70a3ec2a2 100644
--- a/triedb/pathdb/database_test.go
+++ b/triedb/pathdb/database_test.go
@@ -987,7 +987,7 @@ func TestDatabaseIndexRecovery(t *testing.T) {
 			t.Fatalf("Unexpected state history found, %d", i)
 		}
 	}
-	remain, err := env.db.IndexProgress()
+	remain, _, err := env.db.IndexProgress()
 	if err != nil {
 		t.Fatalf("Failed to obtain the progress, %v", err)
 	}
@@ -1001,7 +1001,7 @@ func TestDatabaseIndexRecovery(t *testing.T) {
 			panic(fmt.Errorf("failed to update state changes, err: %w", err))
 		}
 	}
-	remain, err = env.db.IndexProgress()
+	remain, _, err = env.db.IndexProgress()
 	if err != nil {
 		t.Fatalf("Failed to obtain the progress, %v", err)
 	}
diff --git a/triedb/pathdb/disklayer.go b/triedb/pathdb/disklayer.go
index 5bad19b4f5..50c7279d0e 100644
--- a/triedb/pathdb/disklayer.go
+++ b/triedb/pathdb/disklayer.go
@@ -408,6 +408,11 @@ func (dl *diskLayer) writeHistory(typ historyType, diff *diffLayer) (bool, error
 	if err != nil {
 		return false, err
 	}
+	// Notify the index pruner about the new tail so that stale index
+	// blocks referencing the pruned histories can be cleaned up.
+	if indexer != nil && pruned > 0 {
+		indexer.prune(newFirst)
+	}
 	log.Debug("Pruned history", "type", typ, "items", pruned, "tailid", newFirst)
 	return false, nil
 }
diff --git a/triedb/pathdb/history_index_pruner.go b/triedb/pathdb/history_index_pruner.go
new file mode 100644
index 0000000000..c9be3618e8
--- /dev/null
+++ b/triedb/pathdb/history_index_pruner.go
@@ -0,0 +1,385 @@
+// Copyright 2025 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package pathdb
+
+import (
+	"encoding/binary"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/core/rawdb"
+	"github.com/ethereum/go-ethereum/ethdb"
+	"github.com/ethereum/go-ethereum/log"
+)
+
+const (
+	// indexPruningThreshold defines the number of pruned histories that must
+	// accumulate before triggering index pruning. This helps avoid scheduling
+	// index pruning too frequently.
+	indexPruningThreshold = 90000
+
+	// iteratorReopenInterval is how long the iterator is kept open before
+	// being released and re-opened. Long-lived iterators hold a read snapshot
+	// that blocks LSM compaction; periodically re-opening avoids stalling the
+	// compactor during a large scan.
+	iteratorReopenInterval = 30 * time.Second
+)
+
+// indexPruner is responsible for pruning stale index data from the tail side
+// when old history objects are removed. It runs as a background goroutine and
+// processes pruning signals whenever the history tail advances.
+//
+// The pruning operates at the block level: for each state element's index
+// metadata, leading index blocks whose maximum history ID falls below the
+// new tail are removed entirely. This avoids the need to decode individual
+// block contents and is efficient because index blocks store monotonically
+// increasing history IDs.
+type indexPruner struct {
+	disk    ethdb.KeyValueStore
+	typ     historyType
+	tail    atomic.Uint64 // Tail below which index entries can be pruned
+	lastRun uint64        // The tail in the last pruning run
+	trigger chan struct{} // Non-blocking signal that tail has advanced
+	closed  chan struct{} // Closed by close() to terminate the background goroutine
+	wg      sync.WaitGroup
+	log     log.Logger
+
+	pauseReq chan chan struct{} // Pause request; caller sends ack channel, pruner closes it when paused
+	resumeCh chan struct{}      // Resume signal sent by caller after indexSingle/unindexSingle completes
+}
+
+// newIndexPruner creates and starts a new index pruner for the given history type.
+// The trigger channel has a buffer of one so that prune never blocks; the pause
+// and resume channels are unbuffered. The background goroutine started here is
+// stopped and awaited by close.
+func newIndexPruner(disk ethdb.KeyValueStore, typ historyType) *indexPruner {
+	p := &indexPruner{
+		disk:     disk,
+		typ:      typ,
+		trigger:  make(chan struct{}, 1),
+		closed:   make(chan struct{}),
+		log:      log.New("type", typ.String()),
+		pauseReq: make(chan chan struct{}),
+		resumeCh: make(chan struct{}),
+	}
+	p.wg.Add(1)
+	go p.run()
+	return p
+}
+
+// prune signals the pruner that the history tail has advanced to the given ID.
+// All index entries referencing history IDs below newTail can be removed.
+func (p *indexPruner) prune(newTail uint64) {
+	// Only update if the tail is actually advancing. The CAS loop retries
+	// when a concurrent caller changed the tail in the meantime.
+	for {
+		old := p.tail.Load()
+		if newTail <= old {
+			return
+		}
+		if p.tail.CompareAndSwap(old, newTail) {
+			break
+		}
+	}
+	// Non-blocking signal: if a trigger is already pending it is not
+	// duplicated; the run loop reads the latest tail when it wakes up.
+	select {
+	case p.trigger <- struct{}{}:
+	default:
+	}
+}
+
+// pause requests the pruner to flush all pending writes and pause. It blocks
+// until the pruner has acknowledged the pause. This must be paired with a
+// subsequent call to resume. If the pruner is already closed, pause returns
+// immediately.
+func (p *indexPruner) pause() {
+	ack := make(chan struct{})
+	select {
+	case p.pauseReq <- ack:
+		<-ack // wait for the pruner to flush and acknowledge
+	case <-p.closed:
+	}
+}
+
+// resume unblocks a previously paused pruner, allowing it to continue
+// processing. If the pruner is already closed, resume returns immediately.
+func (p *indexPruner) resume() {
+	select {
+	case p.resumeCh <- struct{}{}:
+	case <-p.closed:
+	}
+}
+
+// close shuts down the pruner and waits for it to finish. If the closed
+// channel is already closed, it returns immediately without waiting.
+func (p *indexPruner) close() {
+	select {
+	case <-p.closed:
+		return
+	default:
+		close(p.closed)
+		p.wg.Wait()
+	}
+}
+
+// run is the main loop of the pruner. On each trigger it reads the current
+// tail and, once the tail has advanced by at least indexPruningThreshold
+// since the last successful run, performs a pruning pass via process. Pause
+// requests received while idle are acknowledged immediately.
+func (p *indexPruner) run() {
+	defer p.wg.Done()
+
+	for {
+		select {
+		case <-p.trigger:
+			tail := p.tail.Load()
+			if tail < p.lastRun || tail-p.lastRun < indexPruningThreshold {
+				continue
+			}
+			// lastRun only advances on success, so a failed pass is
+			// retried on the next trigger.
+			if err := p.process(tail); err != nil {
+				p.log.Error("Failed to prune index", "tail", tail, "err", err)
+			} else {
+				p.lastRun = tail
+			}
+
+		case ack := <-p.pauseReq:
+			// Pruner is idle, acknowledge immediately and wait for resume.
+			close(ack)
+			select {
+			case <-p.resumeCh:
+			case <-p.closed:
+				return
+			}
+
+		case <-p.closed:
+			return
+		}
+	}
+}
+
+// process iterates all index metadata entries for the history type and prunes
+// leading blocks whose max history ID is below the given tail.
+//
+// State histories are scanned under two metadata prefixes (accounts first,
+// then storage slots); trienode histories under one. The first failing scan
+// aborts the run and its error is returned. The per-type pruning timer is
+// only updated when every scan succeeded.
+func (p *indexPruner) process(tail uint64) error {
+	var (
+		pruned int
+		start  = time.Now()
+	)
+	switch p.typ {
+	case typeStateHistory:
+		n, err := p.prunePrefix(rawdb.StateHistoryAccountMetadataPrefix, typeAccount, tail)
+		if err != nil {
+			return err
+		}
+		pruned += n
+
+		n, err = p.prunePrefix(rawdb.StateHistoryStorageMetadataPrefix, typeStorage, tail)
+		if err != nil {
+			return err
+		}
+		pruned += n
+		statePruneHistoryIndexTimer.UpdateSince(start)
+
+	case typeTrienodeHistory:
+		n, err := p.prunePrefix(rawdb.TrienodeHistoryMetadataPrefix, typeTrienode, tail)
+		if err != nil {
+			return err
+		}
+		pruned = n
+		trienodePruneHistoryIndexTimer.UpdateSince(start)
+
+	default:
+		panic("unknown history type")
+	}
+	if pruned > 0 {
+		p.log.Info("Pruned stale index blocks", "pruned", pruned, "tail", tail, "elapsed", common.PrettyDuration(time.Since(start)))
+	}
+	return nil
+}
+
+// prunePrefix scans all metadata entries under the given prefix and prunes
+// leading index blocks below the tail. The iterator is periodically released
+// and re-opened to avoid holding a read snapshot that blocks LSM compaction.
+//
+// It returns the number of index blocks removed. An iterator failure or a
+// failed batch write aborts the scan with a zero count and the error. If the
+// pruner is closed mid-scan, the pending batch is flushed and the partial
+// count is returned. Entries that fail to parse are logged and skipped.
+func (p *indexPruner) prunePrefix(prefix []byte, elemType elementType, tail uint64) (int, error) {
+	var (
+		pruned int
+		opened = time.Now()
+		it     = p.disk.NewIterator(prefix, nil)
+		batch  = p.disk.NewBatchWithSize(ethdb.IdealBatchSize)
+	)
+	for {
+		// Terminate if the iterator is exhausted. Next also returns false
+		// when iteration failed, so the error must be checked explicitly;
+		// otherwise a truncated scan would be reported as a complete one.
+		if !it.Next() {
+			err := it.Error()
+			it.Release()
+			if err != nil {
+				return 0, err
+			}
+			break
+		}
+		// Check termination or pause request
+		select {
+		case <-p.closed:
+			// Terminate the process if indexer is closed
+			it.Release()
+			if batch.ValueSize() > 0 {
+				return pruned, batch.Write()
+			}
+			return pruned, nil
+
+		case ack := <-p.pauseReq:
+			// Save the current position so that after resume the
+			// iterator can be re-opened from where it left off. The
+			// current entry has not been processed yet, so iteration
+			// must restart at this very key.
+			start := common.CopyBytes(it.Key()[len(prefix):])
+			it.Release()
+
+			// Flush all pending writes before acknowledging the pause.
+			var flushErr error
+			if batch.ValueSize() > 0 {
+				if err := batch.Write(); err != nil {
+					flushErr = err
+				}
+				batch.Reset()
+			}
+			close(ack)
+
+			// Block until resumed or closed. Always wait here even if
+			// the flush failed — returning early would cause resume()
+			// to deadlock since nobody would receive on resumeCh.
+			select {
+			case <-p.resumeCh:
+				if flushErr != nil {
+					return 0, flushErr
+				}
+				// Re-open the iterator from the saved position so the
+				// pruner sees the current database state (including any
+				// writes made by indexer during the pause).
+				it = p.disk.NewIterator(prefix, start)
+				opened = time.Now()
+				continue
+			case <-p.closed:
+				return pruned, flushErr
+			}
+
+		default:
+			// Keep processing
+		}
+
+		// Prune the index data block
+		key, value := it.Key(), it.Value()
+		ident, bsize := p.identFromKey(key, prefix, elemType)
+		n, err := p.pruneEntry(batch, ident, value, bsize, tail)
+		if err != nil {
+			p.log.Warn("Failed to prune index entry", "ident", ident, "err", err)
+			continue
+		}
+		pruned += n
+
+		// Flush the batch if there are too many accumulated
+		if batch.ValueSize() >= ethdb.IdealBatchSize {
+			if err := batch.Write(); err != nil {
+				it.Release()
+				return 0, err
+			}
+			batch.Reset()
+		}
+
+		// Periodically release the iterator so the LSM compactor
+		// is not blocked by the read snapshot we hold.
+		if time.Since(opened) >= iteratorReopenInterval {
+			opened = time.Now()
+
+			// The current entry has already been handled, so resume
+			// strictly after it: key+0x00 is the smallest key greater
+			// than key. Seeking to the key itself would visit the entry
+			// again and, while its update is still sitting in the
+			// unflushed batch, prune and count it a second time.
+			start := append(common.CopyBytes(it.Key()[len(prefix):]), 0x00)
+			it.Release()
+			it = p.disk.NewIterator(prefix, start)
+		}
+	}
+	if batch.ValueSize() > 0 {
+		if err := batch.Write(); err != nil {
+			return 0, err
+		}
+	}
+	return pruned, nil
+}
+
+// identFromKey reconstructs the stateIdent and bitmapSize from a metadata key.
+func (p *indexPruner) identFromKey(key []byte, prefix []byte, elemType elementType) (stateIdent, int) {
+	suffix := key[len(prefix):]
+
+	if elemType == typeAccount {
+		// Layout: prefix + addressHash(32)
+		return newAccountIdent(common.BytesToHash(suffix[:32])), 0
+	}
+	if elemType == typeStorage {
+		// Layout: prefix + addressHash(32) + storageHash(32)
+		owner := common.BytesToHash(suffix[:32])
+		slot := common.BytesToHash(suffix[32:64])
+		return newStorageIdent(owner, slot), 0
+	}
+	if elemType == typeTrienode {
+		// Layout: prefix + addressHash(32) + path(variable); only trienode
+		// entries carry a non-zero bitmap size.
+		id := newTrienodeIdent(common.BytesToHash(suffix[:32]), string(suffix[32:]))
+		return id, id.bloomSize()
+	}
+	panic("unknown element type")
+}
+
+// pruneEntry inspects one metadata entry and schedules, in the given batch,
+// the removal of every leading index block whose max history ID is strictly
+// below tail. The metadata is rewritten with the surviving descriptors, or
+// deleted when none survive. It returns how many blocks were dropped.
+func (p *indexPruner) pruneEntry(batch ethdb.Batch, ident stateIdent, blob []byte, bsize int, tail uint64) (int, error) {
+	// Cheap rejection: the metadata starts with the big-endian max history
+	// ID of the first block. If that block survives, so does everything
+	// after it and the full parse can be skipped.
+	if len(blob) >= 8 {
+		if firstMax := binary.BigEndian.Uint64(blob[:8]); firstMax >= tail {
+			return 0, nil
+		}
+	}
+	descs, err := parseIndex(blob, bsize)
+	if err != nil {
+		return 0, err
+	}
+	// Descriptors are ordered by history ID, so the stale ones form a prefix.
+	stale := 0
+	for stale < len(descs) && descs[stale].max < tail {
+		stale++
+	}
+	if stale == 0 {
+		return 0, nil
+	}
+	// Drop the data of every stale block.
+	for _, desc := range descs[:stale] {
+		deleteStateIndexBlock(ident, batch, desc.id)
+	}
+	kept := descs[stale:]
+	if len(kept) == 0 {
+		// Nothing survives: the metadata entry goes away as well.
+		deleteStateIndex(ident, batch)
+		return stale, nil
+	}
+	// Re-encode the surviving descriptors as the new metadata.
+	encoded := make([]byte, 0, (indexBlockDescSize+bsize)*len(kept))
+	for _, desc := range kept {
+		encoded = append(encoded, desc.encode()...)
+	}
+	writeStateIndex(ident, batch, encoded)
+	return stale, nil
+}
diff --git a/triedb/pathdb/history_index_pruner_test.go b/triedb/pathdb/history_index_pruner_test.go
new file mode 100644
index 0000000000..b3094de3e6
--- /dev/null
+++ b/triedb/pathdb/history_index_pruner_test.go
@@ -0,0 +1,355 @@
+// Copyright 2025 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package pathdb
+
+import (
+	"math"
+	"testing"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/core/rawdb"
+	"github.com/ethereum/go-ethereum/ethdb"
+	"github.com/ethereum/go-ethereum/log"
+)
+
+// writeMultiBlockIndex writes 10000 consecutive history IDs, beginning at
+// startID (or 1 if startID is zero), into the index of the given element and
+// returns the resulting block descriptors in order. The test is failed if
+// the index cannot be written or does not span at least two blocks, since
+// callers rely on pruning the first block while keeping later ones.
+func writeMultiBlockIndex(t *testing.T, db ethdb.Database, ident stateIdent, bitmapSize int, startID uint64) []*indexBlockDesc {
+	t.Helper()
+
+	if startID == 0 {
+		startID = 1
+	}
+	iw, err := newIndexWriter(db, ident, 0, bitmapSize)
+	if err != nil {
+		t.Fatalf("Failed to create index writer: %v", err)
+	}
+	for i := 0; i < 10000; i++ {
+		if err := iw.append(startID+uint64(i), randomExt(bitmapSize, 5)); err != nil {
+			t.Fatalf("Failed to append element %d: %v", i, err)
+		}
+	}
+	batch := db.NewBatch()
+	iw.finish(batch)
+	if err := batch.Write(); err != nil {
+		t.Fatalf("Failed to write batch: %v", err)
+	}
+
+	blob := readStateIndex(ident, db)
+	descList, err := parseIndex(blob, bitmapSize)
+	if err != nil {
+		t.Fatalf("Failed to parse index: %v", err)
+	}
+	// Fail clearly here rather than with an index-out-of-range panic in
+	// the caller when it reaches for the second block.
+	if len(descList) < 2 {
+		t.Fatalf("Expected multiple index blocks, got %d", len(descList))
+	}
+	return descList
+}
+
+// TestPruneEntryBasic verifies that pruneEntry correctly removes leading index
+// blocks whose max is below the given tail.
+func TestPruneEntryBasic(t *testing.T) {
+	var (
+		db    = rawdb.NewMemoryDatabase()
+		ident = newAccountIdent(common.Hash{0xa})
+	)
+	before := writeMultiBlockIndex(t, db, ident, 0, 1)
+
+	pruner := newIndexPruner(db, typeStateHistory)
+	defer pruner.close()
+
+	// Use a tail one past the first block's max, which the test expects to
+	// make the first block, and only that block, prunable.
+	if err := pruner.process(before[0].max + 1); err != nil {
+		t.Fatalf("Failed to process pruning: %v", err)
+	}
+
+	// The metadata must survive, minus its first descriptor.
+	blob := readStateIndex(ident, db)
+	if len(blob) == 0 {
+		t.Fatal("Index metadata should not be empty after partial prune")
+	}
+	after, err := parseIndex(blob, 0)
+	if err != nil {
+		t.Fatalf("Failed to parse index after prune: %v", err)
+	}
+	if want := len(before) - 1; len(after) != want {
+		t.Fatalf("Expected %d blocks remaining, got %d", want, len(after))
+	}
+	// What used to be the second block now leads the list.
+	if after[0].id != before[1].id {
+		t.Fatalf("Expected first remaining block id %d, got %d", before[1].id, after[0].id)
+	}
+
+	// The pruned block's payload must be gone from the database.
+	if data := readStateIndexBlock(ident, db, before[0].id); len(data) != 0 {
+		t.Fatal("Pruned block data should have been deleted")
+	}
+
+	// Every surviving descriptor must still point at stored data.
+	for _, desc := range after {
+		if data := readStateIndexBlock(ident, db, desc.id); len(data) == 0 {
+			t.Fatalf("Block %d data should still exist", desc.id)
+		}
+	}
+}
+
+// TestPruneEntryBasicTrienode is the same as TestPruneEntryBasic but for
+// trienode index entries with a non-zero bitmapSize.
+func TestPruneEntryBasicTrienode(t *testing.T) {
+	var (
+		db    = rawdb.NewMemoryDatabase()
+		ident = newTrienodeIdent(common.Hash{0xa}, string([]byte{0x0, 0x0, 0x0}))
+		bsize = ident.bloomSize()
+	)
+	before := writeMultiBlockIndex(t, db, ident, bsize, 1)
+
+	pruner := newIndexPruner(db, typeTrienodeHistory)
+	defer pruner.close()
+
+	// A tail one past the first block's max targets the first block only.
+	if err := pruner.process(before[0].max + 1); err != nil {
+		t.Fatalf("Failed to process pruning: %v", err)
+	}
+
+	after, err := parseIndex(readStateIndex(ident, db), bsize)
+	if err != nil {
+		t.Fatalf("Failed to parse index after prune: %v", err)
+	}
+	if want := len(before) - 1; len(after) != want {
+		t.Fatalf("Expected %d blocks remaining, got %d", want, len(after))
+	}
+	if after[0].id != before[1].id {
+		t.Fatalf("Expected first remaining block id %d, got %d", before[1].id, after[0].id)
+	}
+	// The first block's payload must no longer be stored.
+	if data := readStateIndexBlock(ident, db, before[0].id); len(data) != 0 {
+		t.Fatal("Pruned block data should have been deleted")
+	}
+}
+
+// TestPruneEntryComplete verifies that when all blocks are pruned, the metadata
+// entry is also deleted.
+func TestPruneEntryComplete(t *testing.T) {
+	db := rawdb.NewMemoryDatabase()
+	ident := newAccountIdent(common.Hash{0xb})
+	iw, err := newIndexWriter(db, ident, 0, 0)
+	if err != nil {
+		t.Fatalf("Failed to create index writer: %v", err)
+	}
+	// Index history IDs 1..10.
+	for i := 1; i <= 10; i++ {
+		if err := iw.append(uint64(i), nil); err != nil {
+			t.Fatalf("Failed to append: %v", err)
+		}
+	}
+	batch := db.NewBatch()
+	iw.finish(batch)
+	if err := batch.Write(); err != nil {
+		t.Fatalf("Failed to write: %v", err)
+	}
+
+	pruner := newIndexPruner(db, typeStateHistory)
+	defer pruner.close()
+
+	// Prune with tail above all elements
+	if err := pruner.process(11); err != nil {
+		t.Fatalf("Failed to process: %v", err)
+	}
+
+	// Metadata entry should be deleted
+	blob := readStateIndex(ident, db)
+	if len(blob) != 0 {
+		t.Fatal("Index metadata should be empty after full prune")
+	}
+}
+
+// TestPruneNoop verifies that pruning does nothing when the tail is below all
+// block maximums.
+func TestPruneNoop(t *testing.T) {
+	db := rawdb.NewMemoryDatabase()
+	ident := newAccountIdent(common.Hash{0xc})
+	iw, err := newIndexWriter(db, ident, 0, 0)
+	if err != nil {
+		t.Fatalf("Failed to create index writer: %v", err)
+	}
+	// Index history IDs 100..200, all above the tail used below.
+	for i := 100; i <= 200; i++ {
+		if err := iw.append(uint64(i), nil); err != nil {
+			t.Fatalf("Failed to append: %v", err)
+		}
+	}
+	batch := db.NewBatch()
+	iw.finish(batch)
+	if err := batch.Write(); err != nil {
+		t.Fatalf("Failed to write: %v", err)
+	}
+
+	blob := readStateIndex(ident, db)
+	origLen := len(blob)
+
+	pruner := newIndexPruner(db, typeStateHistory)
+	defer pruner.close()
+
+	if err := pruner.process(50); err != nil {
+		t.Fatalf("Failed to process: %v", err)
+	}
+
+	// Nothing should have changed
+	blob = readStateIndex(ident, db)
+	if len(blob) != origLen {
+		t.Fatalf("Expected no change, original len %d, got %d", origLen, len(blob))
+	}
+}
+
+// TestPrunePreservesReadability verifies that after pruning, the remaining
+// index data is still readable and returns correct results.
+func TestPrunePreservesReadability(t *testing.T) {
+	db := rawdb.NewMemoryDatabase()
+	ident := newAccountIdent(common.Hash{0xe})
+	descList := writeMultiBlockIndex(t, db, ident, 0, 1)
+	firstBlockMax := descList[0].max
+
+	pruner := newIndexPruner(db, typeStateHistory)
+	defer pruner.close()
+
+	if err := pruner.process(firstBlockMax + 1); err != nil {
+		t.Fatalf("Failed to process: %v", err)
+	}
+
+	// Read the remaining index and verify lookups still work
+	ir, err := newIndexReader(db, ident, 0)
+	if err != nil {
+		t.Fatalf("Failed to create reader: %v", err)
+	}
+
+	// Looking for something greater than firstBlockMax should still work
+	result, err := ir.readGreaterThan(firstBlockMax)
+	if err != nil {
+		t.Fatalf("Failed to read: %v", err)
+	}
+	if result != firstBlockMax+1 {
+		t.Fatalf("Expected %d, got %d", firstBlockMax+1, result)
+	}
+
+	// Looking beyond the last indexed element (IDs end at 10000) should
+	// return MaxUint64
+	result, err = ir.readGreaterThan(20000)
+	if err != nil {
+		t.Fatalf("Failed to read: %v", err)
+	}
+	if result != math.MaxUint64 {
+		t.Fatalf("Expected MaxUint64, got %d", result)
+	}
+}
+
+// TestPrunePauseResume verifies the pause/resume mechanism:
+//   - The pruner pauses mid-iteration and flushes its batch
+//   - Data written while the pruner is paused (simulating indexSingle) is
+//     visible after resume via a fresh iterator
+//   - Pruning still completes correctly after resume
+func TestPrunePauseResume(t *testing.T) {
+	db := rawdb.NewMemoryDatabase()
+
+	// Create many accounts with multi-block indexes so the pruner is still
+	// iterating when the pause request arrives.
+	var firstBlockMax uint64
+	for i := 0; i < 200; i++ {
+		hash := common.Hash{byte(i)}
+		ident := newAccountIdent(hash)
+		descList := writeMultiBlockIndex(t, db, ident, 0, 1)
+		if i == 0 {
+			firstBlockMax = descList[0].max
+		}
+	}
+	// Target account at the end of the key space — the pruner should not
+	// have visited it yet when the pause is acknowledged.
+	targetIdent := newAccountIdent(common.Hash{0xff})
+	targetDescList := writeMultiBlockIndex(t, db, targetIdent, 0, 1)
+
+	tail := firstBlockMax + 1
+
+	// Construct the pruner without starting run(). We call process()
+	// directly to exercise the mid-iteration pause path deterministically.
+	pruner := &indexPruner{
+		disk:     db,
+		typ:      typeStateHistory,
+		log:      log.New("type", "account"),
+		closed:   make(chan struct{}),
+		pauseReq: make(chan chan struct{}, 1), // buffered so we can pre-deposit
+		resumeCh: make(chan struct{}),
+	}
+
+	// Pre-deposit a pause request before process() starts. Because
+	// pauseReq is buffered, this succeeds immediately. When prunePrefix's
+	// select checks the channel on an early iteration, it will find the
+	// pending request and pause — no scheduling race is possible.
+	ack := make(chan struct{})
+	pruner.pauseReq <- ack
+
+	// Run process() in the background.
+	errCh := make(chan error, 1)
+	go func() {
+		errCh <- pruner.process(tail)
+	}()
+
+	// Block until the pruner has flushed pending writes and acknowledged.
+	// If process() returns first, the pause request was never observed;
+	// fail right away instead of hanging forever on the ack channel.
+	select {
+	case <-ack:
+	case err := <-errCh:
+		t.Fatalf("process() returned before pausing: %v", err)
+	}
+
+	// While paused, append a new element to the target account's index,
+	// simulating what indexSingle would do during the pause window.
+	lastMax := targetDescList[len(targetDescList)-1].max
+	newID := lastMax + 10000
+	iw, err := newIndexWriter(db, targetIdent, lastMax, 0)
+	if err != nil {
+		t.Fatalf("Failed to create index writer: %v", err)
+	}
+	if err := iw.append(newID, nil); err != nil {
+		t.Fatalf("Failed to append: %v", err)
+	}
+	batch := db.NewBatch()
+	iw.finish(batch)
+	if err := batch.Write(); err != nil {
+		t.Fatalf("Failed to write batch: %v", err)
+	}
+
+	// Resume the pruner.
+	pruner.resume()
+
+	// Wait for process() to complete.
+	if err := <-errCh; err != nil {
+		t.Fatalf("process() failed: %v", err)
+	}
+
+	// Verify: the entry written during the pause must still be accessible.
+	// If the pruner used a stale iterator snapshot, it would overwrite the
+	// target's metadata and lose the new entry.
+	ir, err := newIndexReader(db, targetIdent, 0)
+	if err != nil {
+		t.Fatalf("Failed to create index reader: %v", err)
+	}
+	result, err := ir.readGreaterThan(newID - 1)
+	if err != nil {
+		t.Fatalf("Failed to read: %v", err)
+	}
+	if result != newID {
+		t.Fatalf("Entry written during pause was lost: want %d, got %d", newID, result)
+	}
+
+	// Verify: pruning actually occurred on an early account.
+	earlyIdent := newAccountIdent(common.Hash{0x00})
+	earlyBlob := readStateIndex(earlyIdent, db)
+	if len(earlyBlob) == 0 {
+		t.Fatal("Early account index should not be completely empty")
+	}
+	earlyRemaining, err := parseIndex(earlyBlob, 0)
+	if err != nil {
+		t.Fatalf("Failed to parse early account index: %v", err)
+	}
+	// The first block (id=0) should have been pruned.
+	if earlyRemaining[0].id == 0 {
+		t.Fatal("First block of early account should have been pruned")
+	}
+}
diff --git a/triedb/pathdb/history_indexer.go b/triedb/pathdb/history_indexer.go
index c9bf3e87f1..9b215b917f 100644
--- a/triedb/pathdb/history_indexer.go
+++ b/triedb/pathdb/history_indexer.go
@@ -719,6 +719,7 @@ func (i *indexIniter) recover() bool {
 // state history.
type historyIndexer struct { initer *indexIniter + pruner *indexPruner typ historyType disk ethdb.KeyValueStore freezer ethdb.AncientStore @@ -774,6 +775,7 @@ func newHistoryIndexer(disk ethdb.Database, freezer ethdb.AncientStore, lastHist checkVersion(disk, typ) return &historyIndexer{ initer: newIndexIniter(disk, freezer, typ, lastHistoryID, noWait), + pruner: newIndexPruner(disk, typ), typ: typ, disk: disk, freezer: freezer, @@ -782,6 +784,7 @@ func newHistoryIndexer(disk ethdb.Database, freezer ethdb.AncientStore, lastHist func (i *historyIndexer) close() { i.initer.close() + i.pruner.close() } // inited returns a flag indicating whether the existing state histories @@ -802,6 +805,8 @@ func (i *historyIndexer) extend(historyID uint64) error { case <-i.initer.closed: return errors.New("indexer is closed") case <-i.initer.done: + i.pruner.pause() + defer i.pruner.resume() return indexSingle(historyID, i.disk, i.freezer, i.typ) case i.initer.interrupt <- signal: return <-signal.result @@ -819,12 +824,27 @@ func (i *historyIndexer) shorten(historyID uint64) error { case <-i.initer.closed: return errors.New("indexer is closed") case <-i.initer.done: + i.pruner.pause() + defer i.pruner.resume() return unindexSingle(historyID, i.disk, i.freezer, i.typ) case i.initer.interrupt <- signal: return <-signal.result } } +// prune signals the pruner that the history tail has advanced to the given ID, +// so that stale index blocks referencing pruned histories can be removed. +func (i *historyIndexer) prune(newTail uint64) { + select { + case <-i.initer.closed: + log.Debug("Ignored the pruning signal", "reason", "closed") + case <-i.initer.done: + i.pruner.prune(newTail) + default: + log.Debug("Ignored the pruning signal", "reason", "busy") + } +} + // progress returns the indexing progress made so far. It provides the number // of states that remain unindexed. 
func (i *historyIndexer) progress() (uint64, error) { diff --git a/triedb/pathdb/iterator.go b/triedb/pathdb/iterator.go index 8ca8247206..2d333dfa1b 100644 --- a/triedb/pathdb/iterator.go +++ b/triedb/pathdb/iterator.go @@ -24,48 +24,15 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/triedb/internal" ) -// Iterator is an iterator to step over all the accounts or the specific -// storage in a snapshot which may or may not be composed of multiple layers. -type Iterator interface { - // Next steps the iterator forward one element, returning false if exhausted, - // or an error if iteration failed for some reason (e.g. root being iterated - // becomes stale and garbage collected). - Next() bool - - // Error returns any failure that occurred during iteration, which might have - // caused a premature iteration exit (e.g. layer stack becoming stale). - Error() error - - // Hash returns the hash of the account or storage slot the iterator is - // currently at. - Hash() common.Hash - - // Release releases associated resources. Release should always succeed and - // can be called multiple times without causing error. - Release() -} - -// AccountIterator is an iterator to step over all the accounts in a snapshot, -// which may or may not be composed of multiple layers. -type AccountIterator interface { - Iterator - - // Account returns the RLP encoded slim account the iterator is currently at. - // An error will be returned if the iterator becomes invalid - Account() []byte -} - -// StorageIterator is an iterator to step over the specific storage in a snapshot, -// which may or may not be composed of multiple layers. -type StorageIterator interface { - Iterator - - // Slot returns the storage slot the iterator is currently at. 
An error will - // be returned if the iterator becomes invalid - Slot() []byte -} +// Type aliases for the iterator interfaces defined in triedb/internal. +type ( + Iterator = internal.Iterator + AccountIterator = internal.AccountIterator + StorageIterator = internal.StorageIterator +) type ( // loadAccount is the function to retrieve the account from the associated diff --git a/triedb/pathdb/layertree.go b/triedb/pathdb/layertree.go index ec45257db5..b20e40bd05 100644 --- a/triedb/pathdb/layertree.go +++ b/triedb/pathdb/layertree.go @@ -151,6 +151,15 @@ func (tree *layerTree) add(root common.Hash, parentRoot common.Hash, block uint6 if root == parentRoot { return errors.New("layer cycle") } + // If a layer with this root already exists, skip the insertion. Fork blocks + // can produce the same state root as the canonical block (same parent, same + // coinbase, zero txs); overwriting tree.layers[root] would corrupt the parent + // chain for any child layers already built on top of the existing one, and + // appending a duplicate root to the lookup indices causes accountTip/storageTip + // to resolve the wrong layer. 
+ if tree.get(root) != nil { + return nil + } parent := tree.get(parentRoot) if parent == nil { return fmt.Errorf("triedb parent [%#x] layer missing", parentRoot) @@ -310,8 +319,8 @@ func (tree *layerTree) lookupAccount(accountHash common.Hash, state common.Hash) tree.lock.RLock() defer tree.lock.RUnlock() - tip := tree.lookup.accountTip(accountHash, state, tree.base.root) - if tip == (common.Hash{}) { + tip, ok := tree.lookup.accountTip(accountHash, state, tree.base.root) + if !ok { return nil, fmt.Errorf("[%#x] %w", state, errSnapshotStale) } l := tree.layers[tip] @@ -328,8 +337,8 @@ func (tree *layerTree) lookupStorage(accountHash common.Hash, slotHash common.Ha tree.lock.RLock() defer tree.lock.RUnlock() - tip := tree.lookup.storageTip(accountHash, slotHash, state, tree.base.root) - if tip == (common.Hash{}) { + tip, ok := tree.lookup.storageTip(accountHash, slotHash, state, tree.base.root) + if !ok { return nil, fmt.Errorf("[%#x] %w", state, errSnapshotStale) } l := tree.layers[tip] diff --git a/triedb/pathdb/layertree_test.go b/triedb/pathdb/layertree_test.go index a74c6eb045..82eb182990 100644 --- a/triedb/pathdb/layertree_test.go +++ b/triedb/pathdb/layertree_test.go @@ -575,6 +575,40 @@ func TestDescendant(t *testing.T) { } } +func TestDuplicateRootLookup(t *testing.T) { + // Chain: + // C1->C2->C3 (HEAD) + tr := newTestLayerTree() // base = 0x1 + tr.add(common.Hash{0x2}, common.Hash{0x1}, 1, NewNodeSetWithOrigin(nil, nil), + NewStateSetWithOrigin(randomAccountSet("0xa"), randomStorageSet([]string{"0xa"}, [][]string{{"0x1"}}, nil), nil, nil, false)) + tr.add(common.Hash{0x3}, common.Hash{0x2}, 2, NewNodeSetWithOrigin(nil, nil), + NewStateSetWithOrigin(randomAccountSet("0xa"), randomStorageSet([]string{"0xa"}, [][]string{{"0x1"}}, nil), nil, nil, false)) + + // A fork block with the same state root as C2; inserting it must not + // pollute the lookup history for the canonical descendant C3. 
+ tr.add(common.Hash{0x2}, common.Hash{0x1}, 1, NewNodeSetWithOrigin(nil, nil), + NewStateSetWithOrigin(randomAccountSet("0xa"), randomStorageSet([]string{"0xa"}, [][]string{{"0x1"}}, nil), nil, nil, false)) + if n := tr.len(); n != 3 { + t.Fatalf("duplicate root insert changed layer count, got %d, want 3", n) + } + + l, err := tr.lookupAccount(common.HexToHash("0xa"), common.Hash{0x3}) + if err != nil { + t.Fatalf("account lookup failed: %v", err) + } + if l.rootHash() != (common.Hash{0x3}) { + t.Errorf("unexpected account tip, want %x, got %x", common.Hash{0x3}, l.rootHash()) + } + + l, err = tr.lookupStorage(common.HexToHash("0xa"), common.HexToHash("0x1"), common.Hash{0x3}) + if err != nil { + t.Fatalf("storage lookup failed: %v", err) + } + if l.rootHash() != (common.Hash{0x3}) { + t.Errorf("unexpected storage tip, want %x, got %x", common.Hash{0x3}, l.rootHash()) + } +} + func TestAccountLookup(t *testing.T) { // Chain: // C1->C2->C3->C4 (HEAD) @@ -882,3 +916,118 @@ func TestStorageLookup(t *testing.T) { } } } + +// TestLookupZeroBaseRootFallback is a regression test for a sentinel +// collision in accountTip/storageTip: before the fix they returned +// common.Hash{} as both the "stale" marker and the disk-layer fallback +// when the disk root itself happened to be zero. lookupAccount/Storage +// then misreported a legitimate fallback as errSnapshotStale. +// +// On the merkle path the collision was invisible because the empty +// merkle trie hashes to types.EmptyRootHash (a concrete non-zero +// keccak), so the disk layer's root was never the zero hash in +// practice. The bug only surfaces once the disk layer root can +// legitimately be zero (for example a fresh verkle/bintrie database +// where the empty binary trie hashes to EmptyVerkleHash == +// common.Hash{}). +// +// The test constructs a layer tree whose base layer's root IS the zero +// hash, stacks diff layers on top, and exercises four cases: +// +// 1. 
Look up an account NEVER written → should fall through to the +// disk layer and return (diskLayer, nil). Before the fix this +// returned errSnapshotStale because the fallback hash collided +// with the sentinel. +// 2. Symmetric case for lookupStorage. +// 3. Look up an account written in a diff layer → should return that +// diff layer (the normal happy path is unaffected by the fix). +// 4. Look up any key at a state root that isn't part of the tree +// (neither the disk root nor a descendant of it) → MUST still +// return errSnapshotStale. This pins the "other half" of the +// contract so a future refactor that always returns ok=true would +// fail here. +func TestLookupZeroBaseRootFallback(t *testing.T) { + // Build a layer tree whose disk-layer root is common.Hash{} — + // mirrors the bintrie/verkle configuration where the empty trie + // hashes to EmptyVerkleHash. newTestLayerTree can't be reused + // because it hard-codes common.Hash{0x1}. + db := New(rawdb.NewMemoryDatabase(), nil, false) + base := newDiskLayer(common.Hash{}, 0, db, nil, nil, newBuffer(0, nil, nil, 0), nil) + tr := newLayerTree(base) + + // Stack two diff layers on the zero-rooted disk layer, each + // touching a known account and slot so we have something for the + // happy-path lookups to find later. + if err := tr.add( + common.Hash{0x2}, common.Hash{}, + 1, + NewNodeSetWithOrigin(nil, nil), + NewStateSetWithOrigin( + randomAccountSet("0xa"), + randomStorageSet([]string{"0xa"}, [][]string{{"0x1"}}, nil), + nil, nil, false), + ); err != nil { + t.Fatalf("add first diff layer: %v", err) + } + if err := tr.add( + common.Hash{0x3}, common.Hash{0x2}, + 2, + NewNodeSetWithOrigin(nil, nil), + NewStateSetWithOrigin( + randomAccountSet("0xb"), + nil, nil, nil, false), + ); err != nil { + t.Fatalf("add second diff layer: %v", err) + } + + // Case 1: unknown account queried at the head. 
The lookup must + // fall through the diff layers, hit the disk-layer fallback at + // base=common.Hash{}, and return the disk layer with no error — + // NOT errSnapshotStale. + l, err := tr.lookupAccount(common.HexToHash("0xdead"), common.Hash{0x3}) + if err != nil { + t.Fatalf("lookupAccount on zero-base disk layer: unexpected error %v", err) + } + if l.rootHash() != (common.Hash{}) { + t.Errorf("expected fall-through to disk layer (root=0), got %x", l.rootHash()) + } + + // Case 2: symmetric check for storage. Slot 0x99 was never written, + // so the lookup must fall through to the disk layer just like + // Case 1. + l, err = tr.lookupStorage( + common.HexToHash("0xdead"), common.HexToHash("0x99"), common.Hash{0x3}) + if err != nil { + t.Fatalf("lookupStorage on zero-base disk layer: unexpected error %v", err) + } + if l.rootHash() != (common.Hash{}) { + t.Errorf("expected fall-through to disk layer (root=0), got %x", l.rootHash()) + } + + // Case 3: happy path. Account 0xa was written at diff layer 0x2. + // The lookup must return that layer, proving the fix didn't break + // the normal resolution path. + l, err = tr.lookupAccount(common.HexToHash("0xa"), common.Hash{0x3}) + if err != nil { + t.Fatalf("lookupAccount(known): %v", err) + } + if l.rootHash() != (common.Hash{0x2}) { + t.Errorf("known account tip: want %x, got %x", + common.Hash{0x2}, l.rootHash()) + } + + // Case 4: truly stale state root. This pins the other half of the + // contract — the boolean must actually signal not-found for an + // unknown state, otherwise a refactor that always returned + // ok=true would still pass cases 1–3. 
+ _, err = tr.lookupAccount(common.HexToHash("0xa"), common.HexToHash("0xdeadbeef")) + if !errors.Is(err, errSnapshotStale) { + t.Errorf("lookupAccount(stale state): want errSnapshotStale, got %v", err) + } + _, err = tr.lookupStorage( + common.HexToHash("0xa"), common.HexToHash("0x1"), + common.HexToHash("0xdeadbeef")) + if !errors.Is(err, errSnapshotStale) { + t.Errorf("lookupStorage(stale state): want errSnapshotStale, got %v", err) + } +} diff --git a/triedb/pathdb/lookup.go b/triedb/pathdb/lookup.go index 719546f410..9b300ec871 100644 --- a/triedb/pathdb/lookup.go +++ b/triedb/pathdb/lookup.go @@ -92,12 +92,16 @@ func newLookup(head layer, descendant func(state common.Hash, ancestor common.Ha // stateID or is a descendant of it. // // If found, the account data corresponding to the supplied stateID resides -// in that layer. Otherwise, two scenarios are possible: +// in the layer identified by the returned hash (ok=true). Otherwise, +// (common.Hash{}, false) is returned to signal that the supplied stateID is +// stale. // -// (a) the account remains unmodified from the current disk layer up to the state -// layer specified by the stateID: fallback to the disk layer for data retrieval, -// (b) or the layer specified by the stateID is stale: reject the data retrieval. -func (l *lookup) accountTip(accountHash common.Hash, stateID common.Hash, base common.Hash) common.Hash { +// Note the returned hash may itself be common.Hash{} when the disk layer's +// root is zero — as is the case for a fresh verkle/bintrie database whose +// empty trie hashes to EmptyVerkleHash. Callers must therefore consult the +// boolean rather than comparing the returned hash against common.Hash{} +// directly. +func (l *lookup) accountTip(accountHash common.Hash, stateID common.Hash, base common.Hash) (common.Hash, bool) { // Traverse the mutation history from latest to oldest one. 
Several // scenarios are possible: // @@ -123,31 +127,26 @@ func (l *lookup) accountTip(accountHash common.Hash, stateID common.Hash, base c // containing the modified data. Otherwise, the current state may be ahead // of the requested one or belong to a different branch. if list[i] == stateID || l.descendant(stateID, list[i]) { - return list[i] + return list[i], true } } // No layer matching the stateID or its descendants was found. Use the // current disk layer as a fallback. if base == stateID || l.descendant(stateID, base) { - return base + return base, true } // The layer associated with 'stateID' is not the descendant of the current // disk layer, it's already stale, return nothing. - return common.Hash{} + return common.Hash{}, false } // storageTip traverses the layer list associated with the given account and // slot hash in reverse order to locate the first entry that either matches // the specified stateID or is a descendant of it. // -// If found, the storage data corresponding to the supplied stateID resides -// in that layer. Otherwise, two scenarios are possible: -// -// (a) the storage slot remains unmodified from the current disk layer up to -// the state layer specified by the stateID: fallback to the disk layer for -// data retrieval, (b) or the layer specified by the stateID is stale: reject -// the data retrieval. -func (l *lookup) storageTip(accountHash common.Hash, slotHash common.Hash, stateID common.Hash, base common.Hash) common.Hash { +// See accountTip for the returned-hash / ok convention — the same +// bintrie-zero-root caveat applies here. 
+func (l *lookup) storageTip(accountHash common.Hash, slotHash common.Hash, stateID common.Hash, base common.Hash) (common.Hash, bool) { list := l.storages[storageKey(accountHash, slotHash)] for i := len(list) - 1; i >= 0; i-- { // If the current state matches the stateID, or the requested state is a @@ -155,17 +154,17 @@ func (l *lookup) storageTip(accountHash common.Hash, slotHash common.Hash, state // containing the modified data. Otherwise, the current state may be ahead // of the requested one or belong to a different branch. if list[i] == stateID || l.descendant(stateID, list[i]) { - return list[i] + return list[i], true } } // No layer matching the stateID or its descendants was found. Use the // current disk layer as a fallback. if base == stateID || l.descendant(stateID, base) { - return base + return base, true } // The layer associated with 'stateID' is not the descendant of the current // disk layer, it's already stale, return nothing. - return common.Hash{} + return common.Hash{}, false } // addLayer traverses the state data retained in the specified diff layer and diff --git a/triedb/pathdb/metrics.go b/triedb/pathdb/metrics.go index a0a626f9b5..e01dfdfb86 100644 --- a/triedb/pathdb/metrics.go +++ b/triedb/pathdb/metrics.go @@ -77,10 +77,12 @@ var ( trienodeHistoryDataBytesMeter = metrics.NewRegisteredMeter("pathdb/history/trienode/bytes/data", nil) trienodeHistoryIndexBytesMeter = metrics.NewRegisteredMeter("pathdb/history/trienode/bytes/index", nil) - stateIndexHistoryTimer = metrics.NewRegisteredResettingTimer("pathdb/history/state/index/time", nil) - stateUnindexHistoryTimer = metrics.NewRegisteredResettingTimer("pathdb/history/state/unindex/time", nil) - trienodeIndexHistoryTimer = metrics.NewRegisteredResettingTimer("pathdb/history/trienode/index/time", nil) - trienodeUnindexHistoryTimer = metrics.NewRegisteredResettingTimer("pathdb/history/trienode/unindex/time", nil) + stateIndexHistoryTimer = 
metrics.NewRegisteredResettingTimer("pathdb/history/state/index/time", nil) + stateUnindexHistoryTimer = metrics.NewRegisteredResettingTimer("pathdb/history/state/unindex/time", nil) + statePruneHistoryIndexTimer = metrics.NewRegisteredResettingTimer("pathdb/history/state/prune/time", nil) + trienodeIndexHistoryTimer = metrics.NewRegisteredResettingTimer("pathdb/history/trienode/index/time", nil) + trienodeUnindexHistoryTimer = metrics.NewRegisteredResettingTimer("pathdb/history/trienode/unindex/time", nil) + trienodePruneHistoryIndexTimer = metrics.NewRegisteredResettingTimer("pathdb/history/trienode/prune/time", nil) lookupAddLayerTimer = metrics.NewRegisteredResettingTimer("pathdb/lookup/add/time", nil) lookupRemoveLayerTimer = metrics.NewRegisteredResettingTimer("pathdb/lookup/remove/time", nil) diff --git a/triedb/pathdb/verifier.go b/triedb/pathdb/verifier.go index a69b10f4f3..c53590f2fd 100644 --- a/triedb/pathdb/verifier.go +++ b/triedb/pathdb/verifier.go @@ -17,36 +17,15 @@ package pathdb import ( - "encoding/binary" "errors" "fmt" - "math" - "runtime" - "sync" - "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/log" - "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/trie" -) - -// trieKV represents a trie key-value pair -type trieKV struct { - key common.Hash - value []byte -} - -type ( - // trieHasherFn is the interface of trie hasher which can be implemented - // by different trie algorithm. - trieHasherFn func(in chan trieKV, out chan common.Hash) - - // leafCallbackFn is the callback invoked at the leaves of the trie, - // returns the subtrie root with the specified subtrie identifier. 
- leafCallbackFn func(accountHash, codeHash common.Hash, stat *generateStats) (common.Hash, error) + "github.com/ethereum/go-ethereum/triedb/internal" ) // VerifyState traverses the flat states specified by the given state root and @@ -58,7 +37,7 @@ func (db *Database) VerifyState(root common.Hash) error { } defer acctIt.Release() - got, err := generateTrieRoot(acctIt, common.Hash{}, stackTrieHasher, func(accountHash, codeHash common.Hash, stat *generateStats) (common.Hash, error) { + got, err := internal.GenerateTrieRoot(nil, "", acctIt, common.Hash{}, stackTrieHasher, func(_ ethdb.KeyValueWriter, accountHash, codeHash common.Hash, stat *internal.GenerateStats) (common.Hash, error) { // Migrate the code first, commit the contract code into the tmp db. if codeHash != types.EmptyCodeHash { code := rawdb.ReadCode(db.diskdb, codeHash) @@ -73,12 +52,12 @@ func (db *Database) VerifyState(root common.Hash) error { } defer storageIt.Release() - hash, err := generateTrieRoot(storageIt, accountHash, stackTrieHasher, nil, stat, false) + hash, err := internal.GenerateTrieRoot(nil, "", storageIt, accountHash, stackTrieHasher, nil, stat, false) if err != nil { return common.Hash{}, err } return hash, nil - }, newGenerateStats(), true) + }, internal.NewGenerateStats(), true) if err != nil { return err @@ -89,264 +68,10 @@ func (db *Database) VerifyState(root common.Hash) error { return nil } -// generateStats is a collection of statistics gathered by the trie generator -// for logging purposes. -type generateStats struct { - head common.Hash - start time.Time - - accounts uint64 // Number of accounts done (including those being crawled) - slots uint64 // Number of storage slots done (including those being crawled) - - slotsStart map[common.Hash]time.Time // Start time for account slot crawling - slotsHead map[common.Hash]common.Hash // Slot head for accounts being crawled - - lock sync.RWMutex -} - -// newGenerateStats creates a new generator stats. 
-func newGenerateStats() *generateStats { - return &generateStats{ - slotsStart: make(map[common.Hash]time.Time), - slotsHead: make(map[common.Hash]common.Hash), - start: time.Now(), - } -} - -// progressAccounts updates the generator stats for the account range. -func (stat *generateStats) progressAccounts(account common.Hash, done uint64) { - stat.lock.Lock() - defer stat.lock.Unlock() - - stat.accounts += done - stat.head = account -} - -// finishAccounts updates the generator stats for the finished account range. -func (stat *generateStats) finishAccounts(done uint64) { - stat.lock.Lock() - defer stat.lock.Unlock() - - stat.accounts += done -} - -// progressContract updates the generator stats for a specific in-progress contract. -func (stat *generateStats) progressContract(account common.Hash, slot common.Hash, done uint64) { - stat.lock.Lock() - defer stat.lock.Unlock() - - stat.slots += done - stat.slotsHead[account] = slot - if _, ok := stat.slotsStart[account]; !ok { - stat.slotsStart[account] = time.Now() - } -} - -// finishContract updates the generator stats for a specific just-finished contract. -func (stat *generateStats) finishContract(account common.Hash, done uint64) { - stat.lock.Lock() - defer stat.lock.Unlock() - - stat.slots += done - delete(stat.slotsHead, account) - delete(stat.slotsStart, account) -} - -// report prints the cumulative progress statistic smartly. 
-func (stat *generateStats) report() { - stat.lock.RLock() - defer stat.lock.RUnlock() - - ctx := []interface{}{ - "accounts", stat.accounts, - "slots", stat.slots, - "elapsed", common.PrettyDuration(time.Since(stat.start)), - } - if stat.accounts > 0 { - // If there's progress on the account trie, estimate the time to finish crawling it - if done := binary.BigEndian.Uint64(stat.head[:8]) / stat.accounts; done > 0 { - var ( - left = (math.MaxUint64 - binary.BigEndian.Uint64(stat.head[:8])) / stat.accounts - eta = common.CalculateETA(done, left, time.Since(stat.start)) - ) - // If there are large contract crawls in progress, estimate their finish time - for acc, head := range stat.slotsHead { - start := stat.slotsStart[acc] - if done := binary.BigEndian.Uint64(head[:8]); done > 0 { - left := math.MaxUint64 - binary.BigEndian.Uint64(head[:8]) - - // Override the ETA if larger than the largest until now - if slotETA := common.CalculateETA(done, left, time.Since(start)); eta < slotETA { - eta = slotETA - } - } - } - ctx = append(ctx, []interface{}{ - "eta", common.PrettyDuration(eta), - }...) - } - } - log.Info("Iterating state snapshot", ctx...) -} - -// reportDone prints the last log when the whole generation is finished. -func (stat *generateStats) reportDone() { - stat.lock.RLock() - defer stat.lock.RUnlock() - - var ctx []interface{} - ctx = append(ctx, []interface{}{"accounts", stat.accounts}...) - if stat.slots != 0 { - ctx = append(ctx, []interface{}{"slots", stat.slots}...) - } - ctx = append(ctx, []interface{}{"elapsed", common.PrettyDuration(time.Since(stat.start))}...) - log.Info("Iterated snapshot", ctx...) -} - -// runReport periodically prints the progress information. 
-func runReport(stats *generateStats, stop chan bool) { - timer := time.NewTimer(0) - defer timer.Stop() - - for { - select { - case <-timer.C: - stats.report() - timer.Reset(time.Second * 8) - case success := <-stop: - if success { - stats.reportDone() - } - return - } - } -} - -// generateTrieRoot generates the trie hash based on the snapshot iterator. -// It can be used for generating account trie, storage trie or even the -// whole state which connects the accounts and the corresponding storages. -func generateTrieRoot(it Iterator, account common.Hash, generatorFn trieHasherFn, leafCallback leafCallbackFn, stats *generateStats, report bool) (common.Hash, error) { - var ( - in = make(chan trieKV) // chan to pass leaves - out = make(chan common.Hash, 1) // chan to collect result - stoplog = make(chan bool, 1) // 1-size buffer, works when logging is not enabled - wg sync.WaitGroup - ) - // Spin up a go-routine for trie hash re-generation - wg.Add(1) - go func() { - defer wg.Done() - generatorFn(in, out) - }() - // Spin up a go-routine for progress logging - if report && stats != nil { - wg.Add(1) - go func() { - defer wg.Done() - runReport(stats, stoplog) - }() - } - // Create a semaphore to assign tasks and collect results through. We'll pre- - // fill it with nils, thus using the same channel for both limiting concurrent - // processing and gathering results. - threads := runtime.NumCPU() - results := make(chan error, threads) - for i := 0; i < threads; i++ { - results <- nil // fill the semaphore - } - // stop is a helper function to shutdown the background threads - // and return the re-generated trie hash. 
- stop := func(fail error) (common.Hash, error) { - close(in) - result := <-out - for i := 0; i < threads; i++ { - if err := <-results; err != nil && fail == nil { - fail = err - } - } - stoplog <- fail == nil - - wg.Wait() - return result, fail - } - var ( - logged = time.Now() - processed = uint64(0) - leaf trieKV - ) - // Start to feed leaves - for it.Next() { - if account == (common.Hash{}) { - var ( - err error - fullData []byte - ) - if leafCallback == nil { - fullData, err = types.FullAccountRLP(it.(AccountIterator).Account()) - if err != nil { - return stop(err) - } - } else { - // Wait until the semaphore allows us to continue, aborting if - // a sub-task failed - if err := <-results; err != nil { - results <- nil // stop will drain the results, add a noop back for this error we just consumed - return stop(err) - } - // Fetch the next account and process it concurrently - account, err := types.FullAccount(it.(AccountIterator).Account()) - if err != nil { - return stop(err) - } - go func(hash common.Hash) { - subroot, err := leafCallback(hash, common.BytesToHash(account.CodeHash), stats) - if err != nil { - results <- err - return - } - if account.Root != subroot { - results <- fmt.Errorf("invalid subroot(path %x), want %x, have %x", hash, account.Root, subroot) - return - } - results <- nil - }(it.Hash()) - fullData, err = rlp.EncodeToBytes(account) - if err != nil { - return stop(err) - } - } - leaf = trieKV{it.Hash(), fullData} - } else { - leaf = trieKV{it.Hash(), common.CopyBytes(it.(StorageIterator).Slot())} - } - in <- leaf - - // Accumulate the generation statistic if it's required. - processed++ - if time.Since(logged) > 3*time.Second && stats != nil { - if account == (common.Hash{}) { - stats.progressAccounts(it.Hash(), processed) - } else { - stats.progressContract(account, it.Hash(), processed) - } - logged, processed = time.Now(), 0 - } - } - // Commit the last part statistic. 
- if processed > 0 && stats != nil { - if account == (common.Hash{}) { - stats.finishAccounts(processed) - } else { - stats.finishContract(account, processed) - } - } - return stop(nil) -} - -func stackTrieHasher(in chan trieKV, out chan common.Hash) { +func stackTrieHasher(_ ethdb.KeyValueWriter, _ string, _ common.Hash, in chan internal.TrieKV, out chan common.Hash) { t := trie.NewStackTrie(nil) for leaf := range in { - t.Update(leaf.key[:], leaf.value) + t.Update(leaf.Key[:], leaf.Value) } out <- t.Hash() }