diff --git a/trie/secure_trie.go b/trie/secure_trie.go index 45d5fd63e7..7852e16619 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -19,6 +19,7 @@ package trie import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie/trienode" "github.com/ethereum/go-ethereum/triedb/database" @@ -63,8 +64,7 @@ type StateTrie struct { trie Trie db database.NodeDatabase preimages preimageStore - hashKeyBuf [common.HashLength]byte - secKeyCache map[string][]byte + secKeyCache map[common.Hash][]byte secKeyCacheOwner *StateTrie // Pointer to self, replace the key cache on mismatch } @@ -97,7 +97,7 @@ func NewStateTrie(id *ID, db database.NodeDatabase) (*StateTrie, error) { // This function will omit any encountered error but just // print out an error message. func (t *StateTrie) MustGet(key []byte) []byte { - return t.trie.MustGet(t.hashKey(key)) + return t.trie.MustGet(crypto.Keccak256(key)) } // GetStorage attempts to retrieve a storage slot with provided account address @@ -105,7 +105,7 @@ func (t *StateTrie) MustGet(key []byte) []byte { // If the specified storage slot is not in the trie, nil will be returned. // If a trie node is not found in the database, a MissingNodeError is returned. func (t *StateTrie) GetStorage(_ common.Address, key []byte) ([]byte, error) { - enc, err := t.trie.Get(t.hashKey(key)) + enc, err := t.trie.Get(crypto.Keccak256(key)) if err != nil || len(enc) == 0 { return nil, err } @@ -117,7 +117,7 @@ func (t *StateTrie) GetStorage(_ common.Address, key []byte) ([]byte, error) { // If the specified account is not in the trie, nil will be returned. // If a trie node is not found in the database, a MissingNodeError is returned. func (t *StateTrie) GetAccount(address common.Address) (*types.StateAccount, error) { - res, err := t.trie.Get(t.hashKey(address.Bytes())) + res, err := t.trie.Get(crypto.Keccak256(address.Bytes())) if res == nil || err != nil { return nil, err } @@ -157,9 +157,9 @@ func (t *StateTrie) GetNode(path []byte) ([]byte, int, error) { // This function will omit any encountered error but just print out an // error message. func (t *StateTrie) MustUpdate(key, value []byte) { - hk := t.hashKey(key) + hk := crypto.Keccak256(key) t.trie.MustUpdate(hk, value) - t.getSecKeyCache()[string(hk)] = common.CopyBytes(key) + t.getSecKeyCache()[common.Hash(hk)] = common.CopyBytes(key) } // UpdateStorage associates key with value in the trie. Subsequent calls to @@ -171,19 +171,19 @@ func (t *StateTrie) MustUpdate(key, value []byte) { // // If a node is not found in the database, a MissingNodeError is returned. func (t *StateTrie) UpdateStorage(_ common.Address, key, value []byte) error { - hk := t.hashKey(key) + hk := crypto.Keccak256(key) v, _ := rlp.EncodeToBytes(value) err := t.trie.Update(hk, v) if err != nil { return err } - t.getSecKeyCache()[string(hk)] = common.CopyBytes(key) + t.getSecKeyCache()[common.Hash(hk)] = common.CopyBytes(key) return nil } // UpdateAccount will abstract the write of an account to the secure trie. func (t *StateTrie) UpdateAccount(address common.Address, acc *types.StateAccount, _ int) error { - hk := t.hashKey(address.Bytes()) + hk := crypto.Keccak256(address.Bytes()) data, err := rlp.EncodeToBytes(acc) if err != nil { return err @@ -191,7 +191,7 @@ func (t *StateTrie) UpdateAccount(address common.Address, acc *types.StateAccoun if err := t.trie.Update(hk, data); err != nil { return err } - t.getSecKeyCache()[string(hk)] = address.Bytes() + t.getSecKeyCache()[common.Hash(hk)] = address.Bytes() return nil } @@ -202,8 +202,8 @@ func (t *StateTrie) UpdateContractCode(_ common.Address, _ common.Hash, _ []byte // MustDelete removes any existing value for key from the trie. This function // will omit any encountered error but just print out an error message. func (t *StateTrie) MustDelete(key []byte) { - hk := t.hashKey(key) - delete(t.getSecKeyCache(), string(hk)) + hk := crypto.Keccak256(key) + delete(t.getSecKeyCache(), common.Hash(hk)) t.trie.MustDelete(hk) } @@ -211,22 +211,22 @@ func (t *StateTrie) MustDelete(key []byte) { // If the specified trie node is not in the trie, nothing will be changed. // If a node is not found in the database, a MissingNodeError is returned. func (t *StateTrie) DeleteStorage(_ common.Address, key []byte) error { - hk := t.hashKey(key) - delete(t.getSecKeyCache(), string(hk)) + hk := crypto.Keccak256(key) + delete(t.getSecKeyCache(), common.Hash(hk)) return t.trie.Delete(hk) } // DeleteAccount abstracts an account deletion from the trie. func (t *StateTrie) DeleteAccount(address common.Address) error { - hk := t.hashKey(address.Bytes()) - delete(t.getSecKeyCache(), string(hk)) + hk := crypto.Keccak256(address.Bytes()) + delete(t.getSecKeyCache(), common.Hash(hk)) return t.trie.Delete(hk) } // GetKey returns the sha3 preimage of a hashed key that was // previously used to store a value. func (t *StateTrie) GetKey(shaKey []byte) []byte { - if key, ok := t.getSecKeyCache()[string(shaKey)]; ok { + if key, ok := t.getSecKeyCache()[common.BytesToHash(shaKey)]; ok { return key } if t.preimages == nil { @@ -251,13 +251,9 @@ func (t *StateTrie) Commit(collectLeaf bool) (common.Hash, *trienode.NodeSet) { // Write all the pre-images to the actual disk database if len(t.getSecKeyCache()) > 0 { if t.preimages != nil { - preimages := make(map[common.Hash][]byte, len(t.secKeyCache)) - for hk, key := range t.secKeyCache { - preimages[common.BytesToHash([]byte(hk))] = key - } - t.preimages.InsertPreimage(preimages) + t.preimages.InsertPreimage(t.secKeyCache) } - t.secKeyCache = make(map[string][]byte) + t.secKeyCache = make(map[common.Hash][]byte) } // Commit the trie and return its modified nodeset. return t.trie.Commit(collectLeaf) @@ -291,25 +287,13 @@ func (t *StateTrie) MustNodeIterator(start []byte) NodeIterator { return t.trie.MustNodeIterator(start) } -// hashKey returns the hash of key as an ephemeral buffer. -// The caller must not hold onto the return value because it will become -// invalid on the next call to hashKey or secKey. -func (t *StateTrie) hashKey(key []byte) []byte { - h := newHasher(false) - h.sha.Reset() - h.sha.Write(key) - h.sha.Read(t.hashKeyBuf[:]) - returnHasherToPool(h) - return t.hashKeyBuf[:] -} - // getSecKeyCache returns the current secure key cache, creating a new one if // ownership changed (i.e. the current secure trie is a copy of another owning // the actual cache). -func (t *StateTrie) getSecKeyCache() map[string][]byte { +func (t *StateTrie) getSecKeyCache() map[common.Hash][]byte { if t != t.secKeyCacheOwner { t.secKeyCacheOwner = t - t.secKeyCache = make(map[string][]byte) + t.secKeyCache = make(map[common.Hash][]byte) } return t.secKeyCache }