diff --git a/core/state/dump.go b/core/state/dump.go index 7e7b494238..30b8dfed3b 100644 --- a/core/state/dump.go +++ b/core/state/dump.go @@ -167,7 +167,12 @@ func (s *StateDB) DumpToCollector(c DumpCollector, conf *DumpConfig) (nextKey [] } if !conf.SkipStorage { account.Storage = make(map[common.Hash]string) - storageIt := trie.NewIterator(obj.getTrie(s.db).NodeIterator(nil)) + tr, err := obj.getTrie(s.db) + if err != nil { + log.Error("Failed to load storage trie", "err", err) + continue + } + storageIt := trie.NewIterator(tr.NodeIterator(nil)) for storageIt.Next() { _, content, _, err := rlp.Split(storageIt.Value) if err != nil { diff --git a/core/state/state_object.go b/core/state/state_object.go index 8accaed408..fd8c15dbf6 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -152,16 +152,18 @@ func (s *stateObject) touch() { } } -func (s *stateObject) getTrie(db Database) Trie { +// getTrie returns the associated storage trie. The trie will be opened +// if it's not loaded previously. An error will be returned if trie can't +// be loaded. +func (s *stateObject) getTrie(db Database) (Trie, error) { if s.trie == nil { - var err error - s.trie, err = db.OpenStorageTrie(s.addrHash, s.data.Root) + tr, err := db.OpenStorageTrie(s.addrHash, s.data.Root) if err != nil { - s.trie, _ = db.OpenStorageTrie(s.addrHash, types.EmptyRootHash) - s.setError(fmt.Errorf("can't create storage trie: %v", err)) + return nil, err } + s.trie = tr } - return s.trie + return s.trie, nil } // GetState retrieves a value from the account storage trie. @@ -194,7 +196,12 @@ func (s *stateObject) GetCommittedState(db Database, key common.Hash) common.Has // Track the amount of time wasted on reading the storage trie start := time.Now() // Otherwise load the value from the database - enc, err := s.getTrie(db).TryGet(key.Bytes()) + tr, err := s.getTrie(db) + if err != nil { + s.setError(err) + return common.Hash{} + } + enc, err := tr.TryGet(key.Bytes()) s.db.StorageReads += time.Since(start) if err != nil { s.setError(err) @@ -267,17 +274,22 @@ func (s *stateObject) finalise() { } // updateTrie writes cached storage modifications into the object's storage trie. -// It will return nil if the trie has not been loaded and no changes have been made -func (s *stateObject) updateTrie(db Database) Trie { +// It will return nil if the trie has not been loaded and no changes have been +// made. An error will be returned if the trie can't be loaded/updated correctly. +func (s *stateObject) updateTrie(db Database) (Trie, error) { // Make sure all dirty slots are finalized into the pending storage area s.finalise() if len(s.pendingStorage) == 0 { - return s.trie + return s.trie, nil } // Track the amount of time wasted on updating the storage trie defer func(start time.Time) { s.db.StorageUpdates += time.Since(start) }(time.Now()) + tr, err := s.getTrie(db) + if err != nil { + s.setError(err) + return nil, err + } // Insert all the pending updates into the trie - tr := s.getTrie(db) for key, value := range s.pendingStorage { // Skip noop changes, persist actual changes if value == s.originStorage[key] { @@ -286,50 +298,66 @@ func (s *stateObject) updateTrie(db Database) Trie { s.originStorage[key] = value if (value == common.Hash{}) { - s.setError(tr.TryDelete(key[:])) + if err := tr.TryDelete(key[:]); err != nil { + s.setError(err) + return nil, err + } s.db.StorageDeleted += 1 - continue + } else { + // Encoding []byte cannot fail, ok to ignore the error. + v, _ := rlp.EncodeToBytes(common.TrimLeftZeroes(value[:])) + if err := tr.TryUpdate(key[:], v); err != nil { + s.setError(err) + return nil, err + } + s.db.StorageUpdated += 1 } - // Encoding []byte cannot fail, ok to ignore the error. - v, _ := rlp.EncodeToBytes(common.TrimLeftZeroes(value[:])) - s.setError(tr.TryUpdate(key[:], v)) - s.db.StorageUpdated += 1 } if len(s.pendingStorage) > 0 { s.pendingStorage = make(Storage) } - return tr + return tr, nil } -// UpdateRoot sets the trie root to the current root hash of +// UpdateRoot sets the trie root to the current root hash of. An error +// will be returned if trie root hash is not computed correctly. func (s *stateObject) updateRoot(db Database) { + tr, err := s.updateTrie(db) + if err != nil { + s.setError(fmt.Errorf("updateRoot (%x) error: %w", s.address, err)) + return + } // If nothing changed, don't bother with hashing anything - if s.updateTrie(db) == nil { + if tr == nil { return } // Track the amount of time wasted on hashing the storage trie defer func(start time.Time) { s.db.StorageHashes += time.Since(start) }(time.Now()) - - s.data.Root = s.trie.Hash() + s.data.Root = tr.Hash() } // CommitTrie the storage trie of the object to dwb. // This updates the trie root. func (s *stateObject) commitTrie(db Database) (int, error) { // If nothing changed, don't bother with hashing anything - if s.updateTrie(db) == nil { - return 0, nil + tr, err := s.updateTrie(db) + if err != nil { + return 0, err } if s.dbErr != nil { return 0, s.dbErr } + // If nothing changed, don't bother with hashing anything + if tr == nil { + return 0, nil + } // Track the amount of time wasted on committing the storage trie defer func(start time.Time) { s.db.StorageCommits += time.Since(start) }(time.Now()) - root, committed, err := s.trie.Commit(nil) + root, nodes, err := tr.Commit(nil) if err == nil { s.data.Root = root } - return committed, err + return nodes, err } // AddBalance adds amount to s's balance. diff --git a/core/state/statedb.go b/core/state/statedb.go index 9a86ef0e26..2179eddade 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -336,15 +336,18 @@ func (s *StateDB) Database() Database { return s.db } -// StorageTrie returns the storage trie of an account. -// The return value is a copy and is nil for non-existent accounts. -func (s *StateDB) StorageTrie(addr common.Address) Trie { +// StorageTrie returns the storage trie of an account. The return value is a copy +// and is nil for non-existent accounts. An error will be returned if storage trie +// is existent but can't be loaded correctly. +func (s *StateDB) StorageTrie(addr common.Address) (Trie, error) { stateObject := s.getStateObject(addr) if stateObject == nil { - return nil + return nil, nil } cpy := stateObject.deepCopy(s) - cpy.updateTrie(s.db) + if _, err := cpy.updateTrie(s.db); err != nil { + return nil, err + } return cpy.getTrie(s.db) } @@ -604,7 +607,11 @@ func (s *StateDB) ForEachStorage(addr common.Address, cb func(key, value common. if so == nil { return nil } - it := trie.NewIterator(so.getTrie(s.db).NodeIterator(nil)) + tr, err := so.getTrie(s.db) + if err != nil { + return err + } + it := trie.NewIterator(tr.NodeIterator(nil)) for it.Next() { key := common.BytesToHash(s.trie.GetKey(it.Key)) diff --git a/eth/api.go b/eth/api.go index 91fa756f71..ba8d7df3f8 100644 --- a/eth/api.go +++ b/eth/api.go @@ -408,7 +408,10 @@ func (api *DebugAPI) StorageRangeAt(ctx context.Context, blockHash common.Hash, if err != nil { return StorageRangeResult{}, err } - st := statedb.StorageTrie(contractAddress) + st, err := statedb.StorageTrie(contractAddress) + if err != nil { + return StorageRangeResult{}, err + } if st == nil { return StorageRangeResult{}, fmt.Errorf("account %x doesn't exist", contractAddress) } diff --git a/eth/api_test.go b/eth/api_test.go index 37e7997cec..f271ee28ba 100644 --- a/eth/api_test.go +++ b/eth/api_test.go @@ -205,7 +205,11 @@ func TestStorageRangeAt(t *testing.T) { }, } for _, test := range tests { - result, err := storageRangeAt(state.StorageTrie(addr), test.start, test.limit) + tr, err := state.StorageTrie(addr) + if err != nil { + t.Error(err) + } + result, err := storageRangeAt(tr, test.start, test.limit) if err != nil { t.Error(err) }