perf: improve state reader with error handling and committed flag #27428 (#1166)

- Add error returns to Database.Reader() and NodeIterator() methods
- Introduce committed flag to prevent usage of tries after commit
- Update callers to handle new error signatures
- Add MustNodeIterator() helper for backward compatibility

Co-authored-by: rjl493456442 <garyrong0905@gmail.com>
This commit is contained in:
Daniel Liu 2026-02-03 23:25:53 +08:00 committed by GitHub
parent 76a53321df
commit 9cf795c908
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 369 additions and 124 deletions

View file

@ -186,9 +186,38 @@ func (t *XDCXTrie) Copy() *XDCXTrie {
// NodeIterator returns an iterator that returns nodes of the underlying trie. Iteration
// starts at the key after the given start key.
func (t *XDCXTrie) NodeIterator(start []byte) trie.NodeIterator {
return t.trie.NodeIterator(start)
trieIt, err := t.trie.NodeIterator(start)
if err != nil {
log.Error(fmt.Sprintf("Unhandled trie error: %v", err))
return errNodeIterator{err: err}
}
return trieIt
}
// errNodeIterator is a safe, non-nil iterator that reports an error and yields no nodes.
// It prevents nil dereferences when callers don't check for a nil iterator.
type errNodeIterator struct {
err error
}
func (it errNodeIterator) Next(bool) bool { return false }
func (it errNodeIterator) Error() error { return it.err }
func (it errNodeIterator) Hash() common.Hash {
return common.Hash{}
}
func (it errNodeIterator) Parent() common.Hash {
return common.Hash{}
}
func (it errNodeIterator) Path() []byte { return nil }
func (it errNodeIterator) NodeBlob() []byte { return nil }
func (it errNodeIterator) Leaf() bool { return false }
func (it errNodeIterator) LeafKey() []byte { return nil }
func (it errNodeIterator) LeafBlob() []byte { return nil }
func (it errNodeIterator) LeafProof() [][]byte {
return nil
}
func (it errNodeIterator) AddResolver(trie.NodeResolver) {}
// hashKey returns the hash of key as an ephemeral buffer.
// The caller must not hold onto the return value because it will become
// invalid on the next call to hashKey or secKey.

View file

@ -182,9 +182,38 @@ func (t *XDCXTrie) Copy() *XDCXTrie {
// NodeIterator returns an iterator that returns nodes of the underlying trie. Iteration
// starts at the key after the given start key.
func (t *XDCXTrie) NodeIterator(start []byte) trie.NodeIterator {
return t.trie.NodeIterator(start)
trieIt, err := t.trie.NodeIterator(start)
if err != nil {
log.Error(fmt.Sprintf("Unhandled trie error: %v", err))
return errNodeIterator{err: err}
}
return trieIt
}
// errNodeIterator is a safe, non-nil iterator that reports an error and yields no nodes.
// It prevents nil dereferences when callers don't check for a nil iterator.
type errNodeIterator struct {
err error
}
func (it errNodeIterator) Next(bool) bool { return false }
func (it errNodeIterator) Error() error { return it.err }
func (it errNodeIterator) Hash() common.Hash {
return common.Hash{}
}
func (it errNodeIterator) Parent() common.Hash {
return common.Hash{}
}
func (it errNodeIterator) Path() []byte { return nil }
func (it errNodeIterator) NodeBlob() []byte { return nil }
func (it errNodeIterator) Leaf() bool { return false }
func (it errNodeIterator) LeafKey() []byte { return nil }
func (it errNodeIterator) LeafBlob() []byte { return nil }
func (it errNodeIterator) LeafProof() [][]byte {
return nil
}
func (it errNodeIterator) AddResolver(trie.NodeResolver) {}
// hashKey returns the hash of key as an ephemeral buffer.
// The caller must not hold onto the return value because it will become
// invalid on the next call to hashKey or secKey.

View file

@ -390,8 +390,12 @@ func dbDumpTrie(ctx *cli.Context) error {
if err != nil {
return err
}
trieIt, err := theTrie.NodeIterator(start)
if err != nil {
return err
}
var count int64
it := trie.NewIterator(theTrie.NodeIterator(start))
it := trie.NewIterator(trieIt)
for it.Next() {
if max > 0 && count == max {
fmt.Printf("Exiting after %d values\n", count)

View file

@ -113,8 +113,9 @@ type Trie interface {
Commit(collectLeaf bool) (common.Hash, *trienode.NodeSet)
// NodeIterator returns an iterator that returns nodes of the trie. Iteration
// starts at the key after the given start key.
NodeIterator(startKey []byte) trie.NodeIterator
// starts at the key after the given start key. An error will be returned
// if fails to create node iterator.
NodeIterator(startKey []byte) (trie.NodeIterator, error)
// Prove constructs a Merkle proof for key. The result contains all encoded nodes
// on the path to the value at key. The value itself is also included in the last

View file

@ -139,7 +139,11 @@ func (s *StateDB) DumpToCollector(c DumpCollector, conf *DumpConfig) (nextKey []
log.Info("Trie dumping started", "root", s.trie.Hash())
c.OnRoot(s.trie.Hash())
it := trie.NewIterator(s.trie.NodeIterator(conf.Start))
trieIt, err := s.trie.NodeIterator(conf.Start)
if err != nil {
return nil
}
it := trie.NewIterator(trieIt)
for it.Next() {
var data types.StateAccount
if err := rlp.DecodeBytes(it.Value, &data); err != nil {
@ -177,7 +181,12 @@ func (s *StateDB) DumpToCollector(c DumpCollector, conf *DumpConfig) (nextKey []
log.Error("Failed to load storage trie", "err", err)
continue
}
storageIt := trie.NewIterator(tr.NodeIterator(nil))
trieIt, err := tr.NodeIterator(nil)
if err != nil {
log.Error("Failed to create trie iterator", "err", err)
continue
}
storageIt := trie.NewIterator(trieIt)
for storageIt.Next() {
_, content, _, err := rlp.Split(storageIt.Value)
if err != nil {

View file

@ -74,8 +74,12 @@ func (it *nodeIterator) step() error {
return nil
}
// Initialize the iterator if we've just started
var err error
if it.stateIt == nil {
it.stateIt = it.state.trie.NodeIterator(nil)
it.stateIt, err = it.state.trie.NodeIterator(nil)
if err != nil {
return err
}
}
// If we had data nodes previously, we surely have at least state nodes
if it.dataIt != nil {
@ -113,7 +117,10 @@ func (it *nodeIterator) step() error {
if err != nil {
return err
}
it.dataIt = dataTrie.NodeIterator(nil)
it.dataIt, err = dataTrie.NodeIterator(nil)
if err != nil {
return err
}
if !it.dataIt.Next(true) {
it.dataIt = nil
}

View file

@ -43,7 +43,8 @@ func newStateTest() *stateTest {
func TestDump(t *testing.T) {
db := rawdb.NewMemoryDatabase()
sdb, _ := New(types.EmptyRootHash, NewDatabaseWithConfig(db, &trie.Config{Preimages: true}))
tdb := NewDatabaseWithConfig(db, &trie.Config{Preimages: true})
sdb, _ := New(types.EmptyRootHash, tdb)
s := &stateTest{db: db, state: sdb}
// generate a few entries
@ -57,9 +58,10 @@ func TestDump(t *testing.T) {
// write some of them to the trie
s.state.updateStateObject(obj1)
s.state.updateStateObject(obj2)
s.state.Commit(0, false)
root, _ := s.state.Commit(0, false)
// check that DumpToCollector contains the state objects that are in trie
s.state, _ = New(root, tdb)
got := string(s.state.Dump(nil))
want := `{
"root": "71edff0130dd2385947095001c73d9e28d862fc286fca2b922ca6f6f3cddfdd2",
@ -95,7 +97,8 @@ func TestDump(t *testing.T) {
func TestIterativeDump(t *testing.T) {
db := rawdb.NewMemoryDatabase()
sdb, _ := New(types.EmptyRootHash, NewDatabaseWithConfig(db, &trie.Config{Preimages: true}))
tdb := NewDatabaseWithConfig(db, &trie.Config{Preimages: true})
sdb, _ := New(types.EmptyRootHash, tdb)
s := &stateTest{db: db, state: sdb}
// generate a few entries
@ -111,7 +114,8 @@ func TestIterativeDump(t *testing.T) {
// write some of them to the trie
s.state.updateStateObject(obj1)
s.state.updateStateObject(obj2)
s.state.Commit(0, false)
root, _ := s.state.Commit(0, false)
s.state, _ = New(root, tdb)
b := &bytes.Buffer{}
s.state.IterativeDump(nil, json.NewEncoder(b))

View file

@ -45,8 +45,14 @@ type revision struct {
// StateDB structs within the ethereum protocol are used to store anything
// within the merkle trie. StateDBs take care of caching and storing
// nested states. It's the general query interface to retrieve:
//
// * Contracts
// * Accounts
//
// Once the state is committed, tries cached in stateDB (including account
// trie, storage tries) will no longer be functional. A new state instance
// must be created with new root and updated database for accessing post-
// commit states.
type StateDB struct {
db Database
trie Trie
@ -678,7 +684,11 @@ func (s *StateDB) ForEachStorage(addr common.Address, cb func(key, value common.
if err != nil {
return err
}
it := trie.NewIterator(tr.NodeIterator(nil))
trieIt, err := tr.NodeIterator(nil)
if err != nil {
return err
}
it := trie.NewIterator(trieIt)
for it.Next() {
key := common.BytesToHash(s.trie.GetKey(it.Key))

View file

@ -19,6 +19,7 @@ package state
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"math"
"math/big"
@ -33,6 +34,7 @@ import (
"github.com/XinFinOrg/XDPoSChain/core/rawdb"
"github.com/XinFinOrg/XDPoSChain/core/tracing"
"github.com/XinFinOrg/XDPoSChain/core/types"
"github.com/XinFinOrg/XDPoSChain/trie"
)
// Tests that updating a state trie does not leak any database writes prior to
@ -514,6 +516,51 @@ func TestTouchDelete(t *testing.T) {
}
}
// TestCommitCopy tests the copy from a committed state is not functional.
func TestCommitCopy(t *testing.T) {
state, _ := New(types.EmptyRootHash, NewDatabase(rawdb.NewMemoryDatabase()))
// Create an account and check if the retrieved balance is correct
addr := common.HexToAddress("0xaffeaffeaffeaffeaffeaffeaffeaffeaffeaffe")
skey := common.HexToHash("aaa")
sval := common.HexToHash("bbb")
state.SetBalance(addr, big.NewInt(42), tracing.BalanceChangeUnspecified) // Change the account trie
state.SetCode(addr, []byte("hello")) // Change an external metadata
state.SetState(addr, skey, sval) // Change the storage trie
if balance := state.GetBalance(addr); balance.Cmp(big.NewInt(42)) != 0 {
t.Fatalf("initial balance mismatch: have %v, want %v", balance, 42)
}
if code := state.GetCode(addr); !bytes.Equal(code, []byte("hello")) {
t.Fatalf("initial code mismatch: have %x, want %x", code, []byte("hello"))
}
if val := state.GetState(addr, skey); val != sval {
t.Fatalf("initial non-committed storage slot mismatch: have %x, want %x", val, sval)
}
if val := state.GetCommittedState(addr, skey); val != (common.Hash{}) {
t.Fatalf("initial committed storage slot mismatch: have %x, want %x", val, common.Hash{})
}
// Copy the committed state database, the copied one is not functional.
state.Commit(0, true)
copied := state.Copy()
if balance := copied.GetBalance(addr); balance.Cmp(big.NewInt(0)) != 0 {
t.Fatalf("unexpected balance: have %v", balance)
}
if code := copied.GetCode(addr); code != nil {
t.Fatalf("unexpected code: have %x", code)
}
if val := copied.GetState(addr, skey); val != (common.Hash{}) {
t.Fatalf("unexpected storage slot: have %x", val)
}
if val := copied.GetCommittedState(addr, skey); val != (common.Hash{}) {
t.Fatalf("unexpected storage slot: have %x", val)
}
if !errors.Is(copied.Error(), trie.ErrCommitted) {
t.Fatalf("unexpected state error, %v", copied.Error())
}
}
func TestStateDBAccessList(t *testing.T) {
// Some helpers
addr := func(a string) common.Address {
@ -777,7 +824,8 @@ func TestCopyOfCopy(t *testing.T) {
//
// See https://github.com/ethereum/go-ethereum/issues/20106.
func TestCopyCommitCopy(t *testing.T) {
state, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()))
tdb := NewDatabase(rawdb.NewMemoryDatabase())
state, _ := New(types.EmptyRootHash, tdb)
// Create an account and check if the retrieved balance is correct
addr := common.HexToAddress("0xaffeaffeaffeaffeaffeaffeaffeaffeaffeaffe")
@ -814,20 +862,6 @@ func TestCopyCommitCopy(t *testing.T) {
if val := copyOne.GetCommittedState(addr, skey); val != (common.Hash{}) {
t.Fatalf("first copy pre-commit committed storage slot mismatch: have %x, want %x", val, common.Hash{})
}
copyOne.Commit(0, false)
if balance := copyOne.GetBalance(addr); balance.Cmp(big.NewInt(42)) != 0 {
t.Fatalf("first copy post-commit balance mismatch: have %v, want %v", balance, 42)
}
if code := copyOne.GetCode(addr); !bytes.Equal(code, []byte("hello")) {
t.Fatalf("first copy post-commit code mismatch: have %x, want %x", code, []byte("hello"))
}
if val := copyOne.GetState(addr, skey); val != sval {
t.Fatalf("first copy post-commit non-committed storage slot mismatch: have %x, want %x", val, sval)
}
if val := copyOne.GetCommittedState(addr, skey); val != sval {
t.Fatalf("first copy post-commit committed storage slot mismatch: have %x, want %x", val, sval)
}
// Copy the copy and check the balance once more
copyTwo := copyOne.Copy()
if balance := copyTwo.GetBalance(addr); balance.Cmp(big.NewInt(42)) != 0 {
@ -839,8 +873,23 @@ func TestCopyCommitCopy(t *testing.T) {
if val := copyTwo.GetState(addr, skey); val != sval {
t.Fatalf("second copy non-committed storage slot mismatch: have %x, want %x", val, sval)
}
if val := copyTwo.GetCommittedState(addr, skey); val != sval {
t.Fatalf("second copy post-commit committed storage slot mismatch: have %x, want %x", val, sval)
if val := copyTwo.GetCommittedState(addr, skey); val != (common.Hash{}) {
t.Fatalf("second copy committed storage slot mismatch: have %x, want %x", val, sval)
}
// Commit state, ensure states can be loaded from disk
root, _ := state.Commit(0, false)
state, _ = New(root, tdb)
if balance := state.GetBalance(addr); balance.Cmp(big.NewInt(42)) != 0 {
t.Fatalf("state post-commit balance mismatch: have %v, want %v", balance, 42)
}
if code := state.GetCode(addr); !bytes.Equal(code, []byte("hello")) {
t.Fatalf("state post-commit code mismatch: have %x, want %x", code, []byte("hello"))
}
if val := state.GetState(addr, skey); val != sval {
t.Fatalf("state post-commit non-committed storage slot mismatch: have %x, want %x", val, sval)
}
if val := state.GetCommittedState(addr, skey); val != sval {
t.Fatalf("state post-commit committed storage slot mismatch: have %x, want %x", val, sval)
}
}
@ -849,7 +898,7 @@ func TestCopyCommitCopy(t *testing.T) {
//
// See https://github.com/ethereum/go-ethereum/issues/20106.
func TestCopyCopyCommitCopy(t *testing.T) {
state, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()))
state, _ := New(types.EmptyRootHash, NewDatabase(rawdb.NewMemoryDatabase()))
// Create an account and check if the retrieved balance is correct
addr := common.HexToAddress("0xaffeaffeaffeaffeaffeaffeaffeaffeaffeaffe")
@ -900,19 +949,6 @@ func TestCopyCopyCommitCopy(t *testing.T) {
if val := copyTwo.GetCommittedState(addr, skey); val != (common.Hash{}) {
t.Fatalf("second copy pre-commit committed storage slot mismatch: have %x, want %x", val, common.Hash{})
}
copyTwo.Commit(0, false)
if balance := copyTwo.GetBalance(addr); balance.Cmp(big.NewInt(42)) != 0 {
t.Fatalf("second copy post-commit balance mismatch: have %v, want %v", balance, 42)
}
if code := copyTwo.GetCode(addr); !bytes.Equal(code, []byte("hello")) {
t.Fatalf("second copy post-commit code mismatch: have %x, want %x", code, []byte("hello"))
}
if val := copyTwo.GetState(addr, skey); val != sval {
t.Fatalf("second copy post-commit non-committed storage slot mismatch: have %x, want %x", val, sval)
}
if val := copyTwo.GetCommittedState(addr, skey); val != sval {
t.Fatalf("second copy post-commit committed storage slot mismatch: have %x, want %x", val, sval)
}
// Copy the copy-copy and check the balance once more
copyThree := copyTwo.Copy()
if balance := copyThree.GetBalance(addr); balance.Cmp(big.NewInt(42)) != 0 {
@ -924,7 +960,7 @@ func TestCopyCopyCommitCopy(t *testing.T) {
if val := copyThree.GetState(addr, skey); val != sval {
t.Fatalf("third copy non-committed storage slot mismatch: have %x, want %x", val, sval)
}
if val := copyThree.GetCommittedState(addr, skey); val != sval {
if val := copyThree.GetCommittedState(addr, skey); val != (common.Hash{}) {
t.Fatalf("third copy committed storage slot mismatch: have %x, want %x", val, sval)
}
}

View file

@ -109,7 +109,7 @@ func checkTrieConsistency(db ethdb.Database, root common.Hash) error {
if err != nil {
return err
}
it := trie.NodeIterator(nil)
it := trie.MustNodeIterator(nil)
for it.Next(true) {
}
return it.Error()
@ -567,6 +567,10 @@ func TestIncompleteStateSync(t *testing.T) {
addedPaths []string
addedHashes []common.Hash
)
reader, err := srcDb.TrieDB().Reader(srcRoot)
if err != nil {
t.Fatalf("state is not available %x", srcRoot)
}
nodeQueue := make(map[string]stateElement)
codeQueue := make(map[common.Hash]struct{})
paths, nodes, codes := sched.Missing(1)
@ -604,7 +608,7 @@ func TestIncompleteStateSync(t *testing.T) {
results := make([]trie.NodeSyncResult, 0, len(nodeQueue))
for path, element := range nodeQueue {
owner, inner := trie.ResolvePath([]byte(element.path))
data, err := srcDb.TrieDB().Reader(srcRoot).Node(owner, inner, element.hash)
data, err := reader.Node(owner, inner, element.hash)
if err != nil {
t.Fatalf("failed to retrieve node data for %x", element.hash)
}

View file

@ -211,7 +211,11 @@ func (api *DebugAPI) StorageRangeAt(ctx context.Context, blockHash common.Hash,
}
func storageRangeAt(st state.Trie, start []byte, maxResult int) (StorageRangeResult, error) {
it := trie.NewIterator(st.NodeIterator(start))
trieIt, err := st.NodeIterator(start)
if err != nil {
return StorageRangeResult{}, err
}
it := trie.NewIterator(trieIt)
result := StorageRangeResult{Storage: storageMap{}}
for i := 0; i < maxResult && it.Next(); i++ {
_, content, _, err := rlp.Split(it.Value)
@ -303,7 +307,15 @@ func (api *DebugAPI) getModifiedAccounts(startBlock, endBlock *types.Block) ([]c
return nil, err
}
diff, _ := trie.NewDifferenceIterator(oldTrie.NodeIterator([]byte{}), newTrie.NodeIterator([]byte{}))
oldIt, err := oldTrie.NodeIterator([]byte{})
if err != nil {
return nil, err
}
newIt, err := newTrie.NodeIterator([]byte{})
if err != nil {
return nil, err
}
diff, _ := trie.NewDifferenceIterator(oldIt, newIt)
iter := trie.NewIterator(diff)
var dirty []common.Address

View file

@ -36,7 +36,7 @@ import (
var dumper = spew.ConfigState{Indent: " "}
func accountRangeTest(t *testing.T, trie *state.Trie, statedb *state.StateDB, start common.Hash, requestedNum int, expectedNum int) state.IteratorDump {
func accountRangeTest(t *testing.T, statedb *state.StateDB, start common.Hash, requestedNum int, expectedNum int) state.IteratorDump {
result := statedb.IteratorDump(&state.DumpConfig{
SkipCode: true,
SkipStorage: true,
@ -80,17 +80,17 @@ func TestAccountRange(t *testing.T) {
m[addr] = true
}
}
sdb.Commit(0, true)
root := sdb.IntermediateRoot(true)
root, err := sdb.Commit(0, true)
sdb, _ = state.New(root, statedb)
trie, err := statedb.OpenTrie(root)
_, err = statedb.OpenTrie(root)
if err != nil {
t.Fatal(err)
}
accountRangeTest(t, &trie, sdb, common.Hash{}, AccountRangeMaxResults/2, AccountRangeMaxResults/2)
accountRangeTest(t, sdb, common.Hash{}, AccountRangeMaxResults/2, AccountRangeMaxResults/2)
// test pagination
firstResult := accountRangeTest(t, &trie, sdb, common.Hash{}, AccountRangeMaxResults, AccountRangeMaxResults)
secondResult := accountRangeTest(t, &trie, sdb, common.BytesToHash(firstResult.Next), AccountRangeMaxResults, AccountRangeMaxResults)
firstResult := accountRangeTest(t, sdb, common.Hash{}, AccountRangeMaxResults, AccountRangeMaxResults)
secondResult := accountRangeTest(t, sdb, common.BytesToHash(firstResult.Next), AccountRangeMaxResults, AccountRangeMaxResults)
hList := make([]common.Hash, 0)
for addr1 := range firstResult.Accounts {
@ -108,7 +108,7 @@ func TestAccountRange(t *testing.T) {
// set and get an even split between the first and second sets.
slices.SortFunc(hList, common.Hash.Cmp)
middleH := hList[AccountRangeMaxResults/2]
middleResult := accountRangeTest(t, &trie, sdb, middleH, AccountRangeMaxResults, AccountRangeMaxResults)
middleResult := accountRangeTest(t, sdb, middleH, AccountRangeMaxResults, AccountRangeMaxResults)
missing, infirst, insecond := 0, 0, 0
for h := range middleResult.Accounts {
if _, ok := firstResult.Accounts[h]; ok {

View file

@ -202,6 +202,11 @@ func (t *StateTest) Run(subtest StateSubtest, vmconfig vm.Config) (*state.StateD
if root != common.Hash(post.Root) {
return statedb, fmt.Errorf("post state root mismatch: got %x, want %x", root, post.Root)
}
// Re-init the post-state instance for further operation
statedb, err = state.New(root, statedb.Database())
if err != nil {
return nil, err
}
return statedb, nil
}

View file

@ -133,11 +133,11 @@ func (db *Database) InsertPreimage(secKeyCache map[string][]byte) {
// Reader returns a reader for accessing all trie nodes with provided state root.
// Nil is returned in case the state is not available.
func (db *Database) Reader(blockRoot common.Hash) Reader {
func (db *Database) Reader(blockRoot common.Hash) (Reader, error) {
if hdb, ok := db.backend.(*hashdb.Database); ok {
return hdb.Reader(blockRoot)
}
return nil
return nil, errors.New("no support hashdb.Database")
}
// Update performs a state transition by committing dirty nodes contained in the

View file

@ -17,11 +17,17 @@
package trie
import (
"errors"
"fmt"
"github.com/XinFinOrg/XDPoSChain/common"
)
// ErrCommitted is returned when a already committed trie is requested for usage.
// The potential usages can be `Get`, `Update`, `Delete`, `NodeIterator`, `Prove`
// and so on.
var ErrCommitted = errors.New("trie is already committed")
// MissingNodeError is returned by the trie functions (Get, Update, Delete)
// in the case where a trie node is not present in the local database. It contains
// information necessary for retrieving the missing node.

View file

@ -34,7 +34,7 @@ import (
func TestEmptyIterator(t *testing.T) {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
iter := trie.NodeIterator(nil)
iter := trie.MustNodeIterator(nil)
seen := make(map[string]struct{})
for iter.Next(true) {
@ -67,7 +67,7 @@ func TestIterator(t *testing.T) {
trie, _ = New(TrieID(root), db)
found := make(map[string]string)
it := NewIterator(trie.NodeIterator(nil))
it := NewIterator(trie.MustNodeIterator(nil))
for it.Next() {
found[string(it.Key)] = string(it.Value)
}
@ -101,7 +101,7 @@ func TestIteratorLargeData(t *testing.T) {
vals[string(value2.k)] = value2
}
it := NewIterator(trie.NodeIterator(nil))
it := NewIterator(trie.MustNodeIterator(nil))
for it.Next() {
vals[string(it.Key)].t = true
}
@ -139,7 +139,7 @@ func testNodeIteratorCoverage(t *testing.T, scheme string) {
// Gather all the node hashes found by the iterator
var elements = make(map[common.Hash]iterationElement)
for it := trie.NodeIterator(nil); it.Next(true); {
for it := trie.MustNodeIterator(nil); it.Next(true); {
if it.Hash() != (common.Hash{}) {
elements[it.Hash()] = iterationElement{
hash: it.Hash(),
@ -149,8 +149,12 @@ func testNodeIteratorCoverage(t *testing.T, scheme string) {
}
}
// Cross check the hashes and the database itself
reader, err := nodeDb.Reader(trie.Hash())
if err != nil {
t.Fatalf("state is not available %x", trie.Hash())
}
for _, element := range elements {
if blob, err := nodeDb.Reader(trie.Hash()).Node(common.Hash{}, element.path, element.hash); err != nil {
if blob, err := reader.Node(common.Hash{}, element.path, element.hash); err != nil {
t.Errorf("failed to retrieve reported node %x: %v", element.hash, err)
} else if !bytes.Equal(blob, element.blob) {
t.Errorf("node blob is different, want %v got %v", element.blob, blob)
@ -210,19 +214,19 @@ func TestIteratorSeek(t *testing.T) {
}
// Seek to the middle.
it := NewIterator(trie.NodeIterator([]byte("fab")))
it := NewIterator(trie.MustNodeIterator([]byte("fab")))
if err := checkIteratorOrder(testdata1[4:], it); err != nil {
t.Fatal(err)
}
// Seek to a non-existent key.
it = NewIterator(trie.NodeIterator([]byte("barc")))
it = NewIterator(trie.MustNodeIterator([]byte("barc")))
if err := checkIteratorOrder(testdata1[1:], it); err != nil {
t.Fatal(err)
}
// Seek beyond the end.
it = NewIterator(trie.NodeIterator([]byte("z")))
it = NewIterator(trie.MustNodeIterator([]byte("z")))
if err := checkIteratorOrder(nil, it); err != nil {
t.Fatal(err)
}
@ -264,7 +268,7 @@ func TestDifferenceIterator(t *testing.T) {
trieb, _ = New(TrieID(rootB), dbb)
found := make(map[string]string)
di, _ := NewDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil))
di, _ := NewDifferenceIterator(triea.MustNodeIterator(nil), trieb.MustNodeIterator(nil))
it := NewIterator(di)
for it.Next() {
found[string(it.Key)] = string(it.Value)
@ -305,7 +309,7 @@ func TestUnionIterator(t *testing.T) {
dbb.Update(rootB, types.EmptyRootHash, 0, trienode.NewWithNodeSet(nodesB))
trieb, _ = New(TrieID(rootB), dbb)
di, _ := NewUnionIterator([]NodeIterator{triea.NodeIterator(nil), trieb.NodeIterator(nil)})
di, _ := NewUnionIterator([]NodeIterator{triea.MustNodeIterator(nil), trieb.MustNodeIterator(nil)})
it := NewIterator(di)
all := []struct{ k, v string }{
@ -344,7 +348,7 @@ func TestIteratorNoDups(t *testing.T) {
for _, val := range testdata1 {
tr.MustUpdate([]byte(val.k), []byte(val.v))
}
checkIteratorNoDups(t, tr.NodeIterator(nil), nil)
checkIteratorNoDups(t, tr.MustNodeIterator(nil), nil)
}
// This test checks that nodeIterator.Next can be retried after inserting missing trie nodes.
@ -369,7 +373,7 @@ func testIteratorContinueAfterError(t *testing.T, memonly bool, scheme string) {
tdb.Commit(root, false)
}
tr, _ = New(TrieID(root), tdb)
wantNodeCount := checkIteratorNoDups(t, tr.NodeIterator(nil), nil)
wantNodeCount := checkIteratorNoDups(t, tr.MustNodeIterator(nil), nil)
var (
paths [][]byte
@ -428,7 +432,7 @@ func testIteratorContinueAfterError(t *testing.T, memonly bool, scheme string) {
}
// Iterate until the error is hit.
seen := make(map[string]bool)
it := tr.NodeIterator(nil)
it := tr.MustNodeIterator(nil)
checkIteratorNoDups(t, it, seen)
missing, ok := it.Error().(*MissingNodeError)
if !ok || missing.NodeHash != rhash {
@ -496,7 +500,7 @@ func testIteratorContinueAfterSeekError(t *testing.T, memonly bool, scheme strin
}
// Create a new iterator that seeks to "bars". Seeking can't proceed because
// the node is missing.
it := tr.NodeIterator([]byte("bars"))
it := tr.MustNodeIterator([]byte("bars"))
missing, ok := it.Error().(*MissingNodeError)
if !ok {
t.Fatal("want MissingNodeError, got", it.Error())
@ -598,7 +602,10 @@ func makeLargeTestTrie() (*Database, *StateTrie, *loggingDb) {
}
root, nodes := trie.Commit(false)
triedb.Update(root, types.EmptyRootHash, 0, trienode.NewWithNodeSet(nodes))
triedb.Commit(root, false)
// Return the generated trie
trie, _ = NewStateTrie(TrieID(root), triedb)
return triedb, trie, logDb
}
@ -610,8 +617,8 @@ func TestNodeIteratorLargeTrie(t *testing.T) {
// Do a seek operation
trie.NodeIterator(common.FromHex("0x77667766776677766778855885885885"))
// master: 24 get operations
// this pr: 5 get operations
if have, want := logDb.getCount, uint64(5); have != want {
// this pr: 6 get operations
if have, want := logDb.getCount, uint64(6); have != want {
t.Fatalf("Too many lookups during seek, have %d want %d", have, want)
}
}
@ -642,7 +649,7 @@ func testIteratorNodeBlob(t *testing.T, scheme string) {
var found = make(map[common.Hash][]byte)
trie, _ = New(TrieID(root), triedb)
it := trie.NodeIterator(nil)
it := trie.MustNodeIterator(nil)
for it.Next(true) {
if it.Hash() == (common.Hash{}) {
continue

View file

@ -34,6 +34,10 @@ import (
// nodes of the longest existing prefix of the key (at least the root node), ending
// with the node that proves the absence of the key.
func (t *Trie) Prove(key []byte, proofDb ethdb.KeyValueWriter) error {
// Short circuit if the trie is already committed and not usable.
if t.committed {
return ErrCommitted
}
// Collect all nodes on the path to key.
var (
prefix []byte

View file

@ -63,7 +63,7 @@ func makeProvers(trie *Trie) []func(key []byte) *memorydb.Database {
// Create a leaf iterator based Merkle prover
provers = append(provers, func(key []byte) *memorydb.Database {
proof := memorydb.New()
if it := NewIterator(trie.NodeIterator(key)); it.Next() && bytes.Equal(key, it.Key) {
if it := NewIterator(trie.MustNodeIterator(key)); it.Next() && bytes.Equal(key, it.Key) {
for _, p := range it.Prove() {
proof.Put(crypto.Keccak256(p), p)
}

View file

@ -250,12 +250,18 @@ func (t *StateTrie) Copy() *StateTrie {
}
}
// NodeIterator returns an iterator that returns nodes of the underlying trie. Iteration
// starts at the key after the given start key.
func (t *StateTrie) NodeIterator(start []byte) NodeIterator {
// NodeIterator returns an iterator that returns nodes of the underlying trie.
// Iteration starts at the key after the given start key.
func (t *StateTrie) NodeIterator(start []byte) (NodeIterator, error) {
return t.trie.NodeIterator(start)
}
// MustNodeIterator is a wrapper of NodeIterator and will omit any encountered
// error but just print out an error message.
func (t *StateTrie) MustNodeIterator(start []byte) NodeIterator {
return t.trie.MustNodeIterator(start)
}
// hashKey returns the hash of key as an ephemeral buffer.
// The caller must not hold onto the return value because it will become
// invalid on the next call to hashKey or secKey.

View file

@ -94,7 +94,7 @@ func checkTrieConsistency(db ethdb.Database, scheme string, root common.Hash) er
if err != nil {
return nil // Consider a non existent state consistent
}
it := trie.NodeIterator(nil)
it := trie.MustNodeIterator(nil)
for it.Next(true) {
}
return it.Error()
@ -159,12 +159,16 @@ func testIterativeSync(t *testing.T, count int, bypath bool, scheme string) {
syncPath: NewSyncPath([]byte(paths[i])),
})
}
reader, err := srcDb.Reader(srcTrie.Hash())
if err != nil {
t.Fatalf("State is not available %x", srcTrie.Hash())
}
for len(elements) > 0 {
results := make([]NodeSyncResult, len(elements))
if !bypath {
for i, element := range elements {
owner, inner := ResolvePath([]byte(element.path))
data, err := srcDb.Reader(srcTrie.Hash()).Node(owner, inner, element.hash)
data, err := reader.Node(owner, inner, element.hash)
if err != nil {
t.Fatalf("failed to retrieve node data for hash %x: %v", element.hash, err)
}
@ -230,12 +234,16 @@ func testIterativeDelayedSync(t *testing.T, scheme string) {
syncPath: NewSyncPath([]byte(paths[i])),
})
}
reader, err := srcDb.Reader(srcTrie.Hash())
if err != nil {
t.Fatalf("State is not available %x", srcTrie.Hash())
}
for len(elements) > 0 {
// Sync only half of the scheduled nodes
results := make([]NodeSyncResult, len(elements)/2+1)
for i, element := range elements[:len(results)] {
owner, inner := ResolvePath([]byte(element.path))
data, err := srcDb.Reader(srcTrie.Hash()).Node(owner, inner, element.hash)
data, err := reader.Node(owner, inner, element.hash)
if err != nil {
t.Fatalf("failed to retrieve node data for %x: %v", element.hash, err)
}
@ -295,12 +303,16 @@ func testIterativeRandomSync(t *testing.T, count int, scheme string) {
syncPath: NewSyncPath([]byte(paths[i])),
}
}
reader, err := srcDb.Reader(srcTrie.Hash())
if err != nil {
t.Fatalf("State is not available %x", srcTrie.Hash())
}
for len(queue) > 0 {
// Fetch all the queued nodes in a random order
results := make([]NodeSyncResult, 0, len(queue))
for path, element := range queue {
owner, inner := ResolvePath([]byte(element.path))
data, err := srcDb.Reader(srcTrie.Hash()).Node(owner, inner, element.hash)
data, err := reader.Node(owner, inner, element.hash)
if err != nil {
t.Fatalf("failed to retrieve node data for %x: %v", element.hash, err)
}
@ -358,12 +370,16 @@ func testIterativeRandomDelayedSync(t *testing.T, scheme string) {
syncPath: NewSyncPath([]byte(path)),
}
}
reader, err := srcDb.Reader(srcTrie.Hash())
if err != nil {
t.Fatalf("State is not available %x", srcTrie.Hash())
}
for len(queue) > 0 {
// Sync only half of the scheduled nodes, even those in random order
results := make([]NodeSyncResult, 0, len(queue)/2+1)
for path, element := range queue {
owner, inner := ResolvePath([]byte(element.path))
data, err := srcDb.Reader(srcTrie.Hash()).Node(owner, inner, element.hash)
data, err := reader.Node(owner, inner, element.hash)
if err != nil {
t.Fatalf("failed to retrieve node data for %x: %v", element.hash, err)
}
@ -426,13 +442,16 @@ func testDuplicateAvoidanceSync(t *testing.T, scheme string) {
syncPath: NewSyncPath([]byte(paths[i])),
})
}
reader, err := srcDb.Reader(srcTrie.Hash())
if err != nil {
t.Fatalf("State is not available %x", srcTrie.Hash())
}
requested := make(map[common.Hash]struct{})
for len(elements) > 0 {
results := make([]NodeSyncResult, len(elements))
for i, element := range elements {
owner, inner := ResolvePath([]byte(element.path))
data, err := srcDb.Reader(srcTrie.Hash()).Node(owner, inner, element.hash)
data, err := reader.Node(owner, inner, element.hash)
if err != nil {
t.Fatalf("failed to retrieve node data for %x: %v", element.hash, err)
}
@ -501,12 +520,16 @@ func testIncompleteSync(t *testing.T, scheme string) {
syncPath: NewSyncPath([]byte(paths[i])),
})
}
reader, err := srcDb.Reader(srcTrie.Hash())
if err != nil {
t.Fatalf("State is not available %x", srcTrie.Hash())
}
for len(elements) > 0 {
// Fetch a batch of trie nodes
results := make([]NodeSyncResult, len(elements))
for i, element := range elements {
owner, inner := ResolvePath([]byte(element.path))
data, err := srcDb.Reader(srcTrie.Hash()).Node(owner, inner, element.hash)
data, err := reader.Node(owner, inner, element.hash)
if err != nil {
t.Fatalf("failed to retrieve node data for %x: %v", element.hash, err)
}
@ -585,12 +608,15 @@ func testSyncOrdering(t *testing.T, scheme string) {
})
reqs = append(reqs, NewSyncPath([]byte(paths[i])))
}
reader, err := srcDb.Reader(srcTrie.Hash())
if err != nil {
t.Fatalf("State is not available %x", srcTrie.Hash())
}
for len(elements) > 0 {
results := make([]NodeSyncResult, len(elements))
for i, element := range elements {
owner, inner := ResolvePath([]byte(element.path))
data, err := srcDb.Reader(srcTrie.Hash()).Node(owner, inner, element.hash)
data, err := reader.Node(owner, inner, element.hash)
if err != nil {
t.Fatalf("failed to retrieve node data for %x: %v", element.hash, err)
}
@ -649,11 +675,15 @@ func syncWith(t *testing.T, root common.Hash, db ethdb.Database, srcDb *Database
syncPath: NewSyncPath([]byte(paths[i])),
})
}
reader, err := srcDb.Reader(root)
if err != nil {
t.Fatalf("State is not available %x", root)
}
for len(elements) > 0 {
results := make([]NodeSyncResult, len(elements))
for i, element := range elements {
owner, inner := ResolvePath([]byte(element.path))
data, err := srcDb.Reader(root).Node(owner, inner, element.hash)
data, err := reader.Node(owner, inner, element.hash)
if err != nil {
t.Fatalf("failed to retrieve node data for hash %x: %v", element.hash, err)
}

View file

@ -226,14 +226,14 @@ func TestAccessListLeak(t *testing.T) {
}{
{
func(tr *Trie) {
it := tr.NodeIterator(nil)
it := tr.MustNodeIterator(nil)
for it.Next(true) {
}
},
},
{
func(tr *Trie) {
it := NewIterator(tr.NodeIterator(nil))
it := NewIterator(tr.MustNodeIterator(nil))
for it.Next() {
}
},
@ -300,7 +300,7 @@ func compareSet(setA, setB map[string]struct{}) bool {
func forNodes(tr *Trie) map[string][]byte {
var (
it = tr.NodeIterator(nil)
it = tr.MustNodeIterator(nil)
nodes = make(map[string][]byte)
)
for it.Next(true) {
@ -319,7 +319,7 @@ func iterNodes(db *Database, root common.Hash) map[string][]byte {
func forHashedNodes(tr *Trie) map[string][]byte {
var (
it = tr.NodeIterator(nil)
it = tr.MustNodeIterator(nil)
nodes = make(map[string][]byte)
)
for it.Next(true) {

View file

@ -39,6 +39,10 @@ type Trie struct {
root node
owner common.Hash
// Flag whether the commit operation is already performed. If so the
// trie is not usable(latest state is invisible).
committed bool
// Keep track of the number leaves which have been inserted since the last
// hashing operation. This number will not directly map to the number of
// actually unhashed nodes.
@ -88,12 +92,13 @@ func (t *Trie) UpdateDb(root common.Hash, parent common.Hash, block uint64, node
// Copy returns a copy of Trie.
func (t *Trie) Copy() *Trie {
return &Trie{
root: t.root,
owner: t.owner,
unhashed: t.unhashed,
db: t.db,
reader: t.reader,
tracer: t.tracer.copy(),
root: t.root,
owner: t.owner,
committed: t.committed,
unhashed: t.unhashed,
db: t.db,
reader: t.reader,
tracer: t.tracer.copy(),
}
}
@ -130,10 +135,24 @@ func NewEmpty(db *Database) *Trie {
return tr
}
// MustNodeIterator is a wrapper of NodeIterator and will omit any encountered
// error but just print out an error message.
func (t *Trie) MustNodeIterator(start []byte) NodeIterator {
it, err := t.NodeIterator(start)
if err != nil {
log.Error("Unhandled trie error in Trie.NodeIterator", "err", err)
}
return it
}
// NodeIterator returns an iterator that returns nodes of the trie. Iteration starts at
// the key after the given start key.
func (t *Trie) NodeIterator(start []byte) NodeIterator {
return newNodeIterator(t, start)
func (t *Trie) NodeIterator(start []byte) (NodeIterator, error) {
// Short circuit if the trie is already committed and not usable.
if t.committed {
return nil, ErrCommitted
}
return newNodeIterator(t, start), nil
}
// MustGet is a wrapper of Get and will omit any encountered error but just
@ -152,6 +171,10 @@ func (t *Trie) MustGet(key []byte) []byte {
// If the requested node is not present in trie, no error will be returned.
// If the trie is corrupted, a MissingNodeError is returned.
func (t *Trie) Get(key []byte) ([]byte, error) {
// Short circuit if the trie is already committed and not usable.
if t.committed {
return nil, ErrCommitted
}
value, newroot, didResolve, err := t.get(t.root, keybytesToHex(key), 0)
if err == nil && didResolve {
t.root = newroot
@ -211,6 +234,10 @@ func (t *Trie) MustGetNode(path []byte) ([]byte, int) {
// If the requested node is not present in trie, no error will be returned.
// If the trie is corrupted, a MissingNodeError is returned.
func (t *Trie) GetNode(path []byte) ([]byte, int, error) {
// Short circuit if the trie is already committed and not usable.
if t.committed {
return nil, 0, ErrCommitted
}
item, newroot, resolved, err := t.getNode(t.root, compactToHex(path), 0)
if err != nil {
return nil, resolved, err
@ -468,6 +495,10 @@ func (t *Trie) MustUpdate(key, value []byte) {
// If the requested node is not present in trie, no error will be returned.
// If the trie is corrupted, a MissingNodeError is returned.
func (t *Trie) Update(key, value []byte) error {
// Short circuit if the trie is already committed and not usable.
if t.committed {
return ErrCommitted
}
return t.update(key, value)
}
@ -582,6 +613,10 @@ func (t *Trie) MustDelete(key []byte) {
// If the requested node is not present in trie, no error will be returned.
// If the trie is corrupted, a MissingNodeError is returned.
func (t *Trie) Delete(key []byte) error {
// Short circuit if the trie is already committed and not usable.
if t.committed {
return ErrCommitted
}
t.unhashed++
k := keybytesToHex(key)
_, n, err := t.delete(t.root, nil, k)
@ -769,7 +804,9 @@ func (t *Trie) Hash() common.Hash {
// be created with new root and updated trie database for following usage
func (t *Trie) Commit(collectLeaf bool) (common.Hash, *trienode.NodeSet) {
defer t.tracer.reset()
defer func() {
t.committed = true
}()
nodes := trienode.NewNodeSet(t.owner)
t.tracer.markDeletions(nodes)
@ -816,4 +853,5 @@ func (t *Trie) Reset() {
t.owner = common.Hash{}
t.unhashed = 0
t.tracer.reset()
t.committed = false
}

View file

@ -17,9 +17,9 @@
package trie
import (
"fmt"
"github.com/XinFinOrg/XDPoSChain/common"
"github.com/XinFinOrg/XDPoSChain/core/types"
"github.com/XinFinOrg/XDPoSChain/log"
)
// Reader wraps the Node method of a backing trie store.
@ -30,13 +30,6 @@ type Reader interface {
Node(owner common.Hash, path []byte, hash common.Hash) ([]byte, error)
}
// NodeReader wraps all the necessary functions for accessing trie node.
type NodeReader interface {
// Reader returns a reader for accessing all trie nodes with provided
// state root. Nil is returned in case the state is not available.
Reader(root common.Hash) Reader
}
// trieReader is a wrapper of the underlying node reader. It's not safe
// for concurrent usage.
type trieReader struct {
@ -46,10 +39,16 @@ type trieReader struct {
}
// newTrieReader initializes the trie reader with the given node reader.
func newTrieReader(stateRoot, owner common.Hash, db NodeReader) (*trieReader, error) {
reader := db.Reader(stateRoot)
if reader == nil {
return nil, fmt.Errorf("state not found #%x", stateRoot)
func newTrieReader(stateRoot, owner common.Hash, db *Database) (*trieReader, error) {
if stateRoot == (common.Hash{}) || stateRoot == types.EmptyRootHash {
if stateRoot == (common.Hash{}) {
log.Error("Zero state root hash!")
}
return &trieReader{owner: owner}, nil
}
reader, err := db.Reader(stateRoot)
if err != nil {
return nil, &MissingNodeError{Owner: owner, NodeHash: stateRoot, err: err}
}
return &trieReader{owner: owner, reader: reader}, nil
}

View file

@ -540,7 +540,7 @@ func runRandTest(rt randTest) error {
origin = root
case opItercheckhash:
checktr := NewEmpty(triedb)
it := NewIterator(tr.NodeIterator(nil))
it := NewIterator(tr.MustNodeIterator(nil))
for it.Next() {
checktr.MustUpdate(it.Key, it.Value)
}
@ -549,8 +549,8 @@ func runRandTest(rt randTest) error {
}
case opNodeDiff:
var (
origIter = origTrie.NodeIterator(nil)
curIter = tr.NodeIterator(nil)
origIter = origTrie.MustNodeIterator(nil)
curIter = tr.MustNodeIterator(nil)
origSeen = make(map[string]struct{})
curSeen = make(map[string]struct{})
)
@ -729,7 +729,7 @@ func TestTinyTrie(t *testing.T) {
t.Errorf("3: got %x, exp %x", root, exp)
}
checktr := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
it := NewIterator(trie.NodeIterator(nil))
it := NewIterator(trie.MustNodeIterator(nil))
for it.Next() {
checktr.MustUpdate(it.Key, it.Value)
}

View file

@ -18,6 +18,7 @@ package hashdb
import (
"errors"
"fmt"
"reflect"
"sync"
"time"
@ -621,8 +622,12 @@ func (db *Database) Scheme() string {
}
// Reader retrieves a node reader belonging to the given state root.
func (db *Database) Reader(root common.Hash) *reader {
return &reader{db: db}
// An error will be returned if the requested state is not available.
func (db *Database) Reader(root common.Hash) (*reader, error) {
if _, err := db.Node(root); err != nil {
return nil, fmt.Errorf("state %#x is not available, %v", root, err)
}
return &reader{db: db}, nil
}
// reader is a state reader of Database which implements the Reader interface.