diff --git a/eth/api_debug.go b/eth/api_debug.go index d5e4dda140..188dee11aa 100644 --- a/eth/api_debug.go +++ b/eth/api_debug.go @@ -271,26 +271,26 @@ func storageRangeAt(statedb *state.StateDB, root common.Hash, address common.Add // // With one parameter, returns the list of accounts modified in the specified block. func (api *DebugAPI) GetModifiedAccountsByNumber(startNum uint64, endNum *uint64) ([]common.Address, error) { - var startBlock, endBlock *types.Block + var startHeader, endHeader *types.Header - startBlock = api.eth.blockchain.GetBlockByNumber(startNum) - if startBlock == nil { + startHeader = api.eth.blockchain.GetHeaderByNumber(startNum) + if startHeader == nil { return nil, fmt.Errorf("start block %x not found", startNum) } if endNum == nil { - endBlock = startBlock - startBlock = api.eth.blockchain.GetBlockByHash(startBlock.ParentHash()) - if startBlock == nil { - return nil, fmt.Errorf("block %x has no parent", endBlock.Number()) + endHeader = startHeader + startHeader = api.eth.blockchain.GetHeaderByHash(startHeader.ParentHash) + if startHeader == nil { + return nil, fmt.Errorf("block %x has no parent", endHeader.Number) } } else { - endBlock = api.eth.blockchain.GetBlockByNumber(*endNum) - if endBlock == nil { + endHeader = api.eth.blockchain.GetHeaderByNumber(*endNum) + if endHeader == nil { return nil, fmt.Errorf("end block %d not found", *endNum) } } - return api.getModifiedAccounts(startBlock, endBlock) + return api.getModifiedAccounts(startHeader, endHeader) } // GetModifiedAccountsByHash returns all accounts that have changed between the @@ -299,38 +299,38 @@ func (api *DebugAPI) GetModifiedAccountsByNumber(startNum uint64, endNum *uint64 // // With one parameter, returns the list of accounts modified in the specified block. func (api *DebugAPI) GetModifiedAccountsByHash(startHash common.Hash, endHash *common.Hash) ([]common.Address, error) { - var startBlock, endBlock *types.Block - startBlock = api.eth.blockchain.GetBlockByHash(startHash) - if startBlock == nil { + var startHeader, endHeader *types.Header + startHeader = api.eth.blockchain.GetHeaderByHash(startHash) + if startHeader == nil { return nil, fmt.Errorf("start block %x not found", startHash) } if endHash == nil { - endBlock = startBlock - startBlock = api.eth.blockchain.GetBlockByHash(startBlock.ParentHash()) - if startBlock == nil { - return nil, fmt.Errorf("block %x has no parent", endBlock.Number()) + endHeader = startHeader + startHeader = api.eth.blockchain.GetHeaderByHash(startHeader.ParentHash) + if startHeader == nil { + return nil, fmt.Errorf("block %x has no parent", endHeader.Number) } } else { - endBlock = api.eth.blockchain.GetBlockByHash(*endHash) - if endBlock == nil { + endHeader = api.eth.blockchain.GetHeaderByHash(*endHash) + if endHeader == nil { return nil, fmt.Errorf("end block %x not found", *endHash) } } - return api.getModifiedAccounts(startBlock, endBlock) + return api.getModifiedAccounts(startHeader, endHeader) } -func (api *DebugAPI) getModifiedAccounts(startBlock, endBlock *types.Block) ([]common.Address, error) { - if startBlock.Number().Uint64() >= endBlock.Number().Uint64() { - return nil, fmt.Errorf("start block height (%d) must be less than end block height (%d)", startBlock.Number().Uint64(), endBlock.Number().Uint64()) +func (api *DebugAPI) getModifiedAccounts(startHeader, endHeader *types.Header) ([]common.Address, error) { + if startHeader.Number.Uint64() >= endHeader.Number.Uint64() { + return nil, fmt.Errorf("start block height (%d) must be less than end block height (%d)", startHeader.Number.Uint64(), endHeader.Number.Uint64()) } triedb := api.eth.BlockChain().TrieDB() - oldTrie, err := trie.NewStateTrie(trie.StateTrieID(startBlock.Root()), triedb) + oldTrie, err := trie.NewStateTrie(trie.StateTrieID(startHeader.Root), triedb) if err != nil { return nil, err } - newTrie, err := trie.NewStateTrie(trie.StateTrieID(endBlock.Root()), triedb) + newTrie, err := trie.NewStateTrie(trie.StateTrieID(endHeader.Root), triedb) if err != nil { return nil, err } diff --git a/eth/api_debug_test.go b/eth/api_debug_test.go index 02b85f69fd..90fd22498a 100644 --- a/eth/api_debug_test.go +++ b/eth/api_debug_test.go @@ -18,25 +18,74 @@ package eth import ( "bytes" + "crypto/ecdsa" "fmt" + "math/big" "reflect" "slices" "strings" "testing" + "time" "github.com/davecgh/go-spew/spew" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/consensus/ethash" + "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/tracing" "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/triedb" "github.com/holiman/uint256" + "github.com/stretchr/testify/assert" ) var dumper = spew.ConfigState{Indent: " "} +type Account struct { + key *ecdsa.PrivateKey + addr common.Address +} + +func newAccounts(n int) (accounts []Account) { + for i := 0; i < n; i++ { + key, _ := crypto.GenerateKey() + addr := crypto.PubkeyToAddress(key.PublicKey) + accounts = append(accounts, Account{key: key, addr: addr}) + } + slices.SortFunc(accounts, func(a, b Account) int { return a.addr.Cmp(b.addr) }) + return accounts +} + +// newTestBlockChain creates a new test blockchain. OBS: After test is done, teardown must be +// invoked in order to release associated resources. +func newTestBlockChain(t *testing.T, n int, gspec *core.Genesis, generator func(i int, b *core.BlockGen)) *core.BlockChain { + engine := ethash.NewFaker() + // Generate blocks for testing + _, blocks, _ := core.GenerateChainWithGenesis(gspec, engine, n, generator) + + // Import the canonical chain + cacheConfig := &core.CacheConfig{ + TrieCleanLimit: 256, + TrieDirtyLimit: 256, + TrieTimeLimit: 5 * time.Minute, + SnapshotLimit: 0, + Preimages: true, + TrieDirtyDisabled: true, // Archive mode + } + chain, err := core.NewBlockChain(rawdb.NewMemoryDatabase(), cacheConfig, gspec, nil, engine, vm.Config{}, nil) + if err != nil { + t.Fatalf("failed to create tester chain: %v", err) + } + if n, err := chain.InsertChain(blocks); err != nil { + t.Fatalf("block %d: failed to insert into chain: %v", n, err) + } + return chain +} + func accountRangeTest(t *testing.T, trie *state.Trie, statedb *state.StateDB, start common.Hash, requestedNum int, expectedNum int) state.Dump { result := statedb.RawDump(&state.DumpConfig{ SkipCode: true, @@ -224,3 +273,66 @@ func TestStorageRangeAt(t *testing.T) { } } } + +func TestGetModifiedAccounts(t *testing.T) { + t.Parallel() + + // Initialize test accounts + accounts := newAccounts(4) + genesis := &core.Genesis{ + Config: params.TestChainConfig, + Alloc: types.GenesisAlloc{ + accounts[0].addr: {Balance: big.NewInt(params.Ether)}, + accounts[1].addr: {Balance: big.NewInt(params.Ether)}, + accounts[2].addr: {Balance: big.NewInt(params.Ether)}, + accounts[3].addr: {Balance: big.NewInt(params.Ether)}, + }, + } + genBlocks := 1 + signer := types.HomesteadSigner{} + blockChain := newTestBlockChain(t, genBlocks, genesis, func(_ int, b *core.BlockGen) { + // Transfer from account[0] to account[1] + // value: 1000 wei + // fee: 0 wei + for _, account := range accounts[:3] { + tx, _ := types.SignTx(types.NewTx(&types.LegacyTx{ + Nonce: 0, + To: &accounts[3].addr, + Value: big.NewInt(1000), + Gas: params.TxGas, + GasPrice: b.BaseFee(), + Data: nil}), + signer, account.key) + b.AddTx(tx) + } + }) + defer blockChain.Stop() + + // Create a debug API instance. + api := NewDebugAPI(&Ethereum{blockchain: blockChain}) + + // Test GetModifiedAccountsByNumber + t.Run("GetModifiedAccountsByNumber", func(t *testing.T) { + addrs, err := api.GetModifiedAccountsByNumber(uint64(genBlocks), nil) + assert.NoError(t, err) + assert.Len(t, addrs, len(accounts)+1) // +1 for the coinbase + for _, account := range accounts { + if !slices.Contains(addrs, account.addr) { + t.Fatalf("account %s not found in modified accounts", account.addr.Hex()) + } + } + }) + + // Test GetModifiedAccountsByHash + t.Run("GetModifiedAccountsByHash", func(t *testing.T) { + header := blockChain.GetHeaderByNumber(uint64(genBlocks)) + addrs, err := api.GetModifiedAccountsByHash(header.Hash(), nil) + assert.NoError(t, err) + assert.Len(t, addrs, len(accounts)+1) // +1 for the coinbase + for _, account := range accounts { + if !slices.Contains(addrs, account.addr) { + t.Fatalf("account %s not found in modified accounts", account.addr.Hex()) + } + } + }) +}