diff --git a/core/state/database.go b/core/state/database.go index 3b1e627f28..918659ee69 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -107,12 +107,18 @@ type Trie interface { // in the trie with provided address. UpdateAccount(address common.Address, account *types.StateAccount, codeLen int) error + // UpdateAccountBatch attempts to update a list accounts in the batch manner. + UpdateAccountBatch(addresses []common.Address, accounts []*types.StateAccount, codeLengths []int) error + // UpdateStorage associates key with value in the trie. If value has length zero, // any existing value is deleted from the trie. The value bytes must not be modified // by the caller while they are stored in the trie. If a node was not found in the // database, a trie.MissingNodeError is returned. UpdateStorage(addr common.Address, key, value []byte) error + // UpdateStorageBatch attempts to update a list storages in the batch manner. + UpdateStorageBatch(_ common.Address, keys [][]byte, values [][]byte) error + // DeleteAccount abstracts an account deletion from the trie. DeleteAccount(address common.Address) error diff --git a/trie/bintrie/trie.go b/trie/bintrie/trie.go index 0d0c0e0e70..ac9b19a2e1 100644 --- a/trie/bintrie/trie.go +++ b/trie/bintrie/trie.go @@ -451,3 +451,11 @@ func (t *BinaryTrie) PrefetchStorage(addr common.Address, keys [][]byte) error { func (t *BinaryTrie) Witness() map[string][]byte { return t.tracer.Values() } + +func (t *BinaryTrie) UpdateStorageBatch(_ common.Address, keys [][]byte, values [][]byte) error { + panic("not implemented") +} + +func (t *BinaryTrie) UpdateAccountBatch(addresses []common.Address, accounts []*types.StateAccount, _ []int) error { + panic("not implemented") +} diff --git a/trie/secure_trie.go b/trie/secure_trie.go index 4d03ca45f0..f2176310d0 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -210,6 +210,29 @@ func (t *StateTrie) UpdateStorage(_ common.Address, key, value []byte) error { return nil } +// UpdateStorageBatch attempts to update a list storages in the batch manner. +func (t *StateTrie) UpdateStorageBatch(_ common.Address, keys [][]byte, values [][]byte) error { + var ( + hkeys = make([][]byte, 0, len(keys)) + evals = make([][]byte, 0, len(values)) + ) + for _, key := range keys { + hk := crypto.Keccak256(key) + if t.preimages != nil { + t.secKeyCache[common.Hash(hk)] = key + } + hkeys = append(hkeys, hk) + } + for _, val := range values { + data, err := rlp.EncodeToBytes(val) + if err != nil { + return err + } + evals = append(evals, data) + } + return t.trie.UpdateBatch(hkeys, evals) +} + // UpdateAccount will abstract the write of an account to the secure trie. func (t *StateTrie) UpdateAccount(address common.Address, acc *types.StateAccount, _ int) error { hk := crypto.Keccak256(address.Bytes()) @@ -226,6 +249,29 @@ func (t *StateTrie) UpdateAccount(address common.Address, acc *types.StateAccoun return nil } +// UpdateAccountBatch attempts to update a list accounts in the batch manner. +func (t *StateTrie) UpdateAccountBatch(addresses []common.Address, accounts []*types.StateAccount, _ []int) error { + var ( + hkeys = make([][]byte, 0, len(addresses)) + values = make([][]byte, 0, len(accounts)) + ) + for _, addr := range addresses { + hk := crypto.Keccak256(addr.Bytes()) + if t.preimages != nil { + t.secKeyCache[common.Hash(hk)] = addr.Bytes() + } + hkeys = append(hkeys, hk) + } + for _, acc := range accounts { + data, err := rlp.EncodeToBytes(acc) + if err != nil { + return err + } + values = append(values, data) + } + return t.trie.UpdateBatch(hkeys, values) +} + func (t *StateTrie) UpdateContractCode(_ common.Address, _ common.Hash, _ []byte) error { return nil } diff --git a/trie/tracer.go b/trie/tracer.go index 04122d1384..042fa468bf 100644 --- a/trie/tracer.go +++ b/trie/tracer.go @@ -33,12 +33,10 @@ import ( // while the latter is inserted/deleted in order to follow the rule of trie. // This tool can track all of them no matter the node is embedded in its // parent or not, but valueNode is never tracked. -// -// Note opTracer is not thread-safe, callers should be responsible for handling -// the concurrency issues by themselves. type opTracer struct { inserts map[string]struct{} deletes map[string]struct{} + lock sync.RWMutex } // newOpTracer initializes the tracer for capturing trie changes. @@ -53,6 +51,9 @@ func newOpTracer() *opTracer { // in the deletion set (resurrected node), then just wipe it from // the deletion set as it's "untouched". func (t *opTracer) onInsert(path []byte) { + t.lock.Lock() + defer t.lock.Unlock() + if _, present := t.deletes[string(path)]; present { delete(t.deletes, string(path)) return @@ -64,6 +65,9 @@ func (t *opTracer) onInsert(path []byte) { // in the addition set, then just wipe it from the addition set // as it's untouched. func (t *opTracer) onDelete(path []byte) { + t.lock.Lock() + defer t.lock.Unlock() + if _, present := t.inserts[string(path)]; present { delete(t.inserts, string(path)) return @@ -73,12 +77,18 @@ func (t *opTracer) onDelete(path []byte) { // reset clears the content tracked by tracer. func (t *opTracer) reset() { + t.lock.Lock() + defer t.lock.Unlock() + clear(t.inserts) clear(t.deletes) } // copy returns a deep copied tracer instance. func (t *opTracer) copy() *opTracer { + t.lock.RLock() + defer t.lock.RUnlock() + return &opTracer{ inserts: maps.Clone(t.inserts), deletes: maps.Clone(t.deletes), @@ -87,6 +97,9 @@ func (t *opTracer) copy() *opTracer { // deletedList returns a list of node paths which are deleted from the trie. func (t *opTracer) deletedList() [][]byte { + t.lock.RLock() + defer t.lock.RUnlock() + paths := make([][]byte, 0, len(t.deletes)) for path := range t.deletes { paths = append(paths, []byte(path)) diff --git a/trie/transitiontrie/transition.go b/trie/transitiontrie/transition.go index 3e5511be9e..d939e804e3 100644 --- a/trie/transitiontrie/transition.go +++ b/trie/transitiontrie/transition.go @@ -144,6 +144,19 @@ func (t *TransitionTrie) UpdateStorage(address common.Address, key []byte, value return t.overlay.UpdateStorage(address, key, v) } +// UpdateStorageBatch attempts to update a list storages in the batch manner. +func (t *TransitionTrie) UpdateStorageBatch(address common.Address, keys [][]byte, values [][]byte) error { + if len(keys) != len(values) { + return fmt.Errorf("keys and values length mismatch: %d != %d", len(keys), len(values)) + } + for i, key := range keys { + if err := t.UpdateStorage(address, key, values[i]); err != nil { + return err + } + } + return nil +} + // UpdateAccount abstract an account write to the trie. func (t *TransitionTrie) UpdateAccount(addr common.Address, account *types.StateAccount, codeLen int) error { // NOTE: before the rebase, this was saving the state root, so that OpenStorageTrie @@ -152,6 +165,22 @@ func (t *TransitionTrie) UpdateAccount(addr common.Address, account *types.State return t.overlay.UpdateAccount(addr, account, codeLen) } +// UpdateAccountBatch attempts to update a list accounts in the batch manner. +func (t *TransitionTrie) UpdateAccountBatch(addresses []common.Address, accounts []*types.StateAccount, codeLens []int) error { + if len(addresses) != len(accounts) { + return fmt.Errorf("address and accounts length mismatch: %d != %d", len(addresses), len(accounts)) + } + if len(addresses) != len(codeLens) { + return fmt.Errorf("address and code length mismatch: %d != %d", len(addresses), len(codeLens)) + } + for i, addr := range addresses { + if err := t.UpdateAccount(addr, accounts[i], codeLens[i]); err != nil { + return err + } + } + return nil +} + // DeleteStorage removes any existing value for key from the trie. If a node was not // found in the database, a trie.MissingNodeError is returned. func (t *TransitionTrie) DeleteStorage(addr common.Address, key []byte) error { diff --git a/trie/trie.go b/trie/trie.go index 1ef2c2f1a6..7e69a90823 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -480,6 +480,69 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error } } +// UpdateBatch updates a batch of entries concurrently. +func (t *Trie) UpdateBatch(keys [][]byte, values [][]byte) error { + // Short circuit if the trie is already committed and unusable. + if t.committed { + return ErrCommitted + } + if len(keys) != len(values) { + return fmt.Errorf("keys and values length mismatch: %d != %d", len(keys), len(values)) + } + // Insert the entries sequentially if there are not too many + // trie nodes in the trie. + fn, ok := t.root.(*fullNode) + if !ok || len(keys) < 4 { // TODO(rjl493456442) the parallelism threshold should be twisted + for i, key := range keys { + err := t.Update(key, values[i]) + if err != nil { + return err + } + } + return nil + } + var ( + ikeys = make(map[byte][][]byte) + ivals = make(map[byte][][]byte) + eg errgroup.Group + ) + for i, key := range keys { + hkey := keybytesToHex(key) + ikeys[hkey[0]] = append(ikeys[hkey[0]], hkey) + ivals[hkey[0]] = append(ivals[hkey[0]], values[i]) + } + if len(keys) > 0 { + fn.flags = t.newFlag() + } + for pos, ks := range ikeys { + eg.Go(func() error { + vs := ivals[pos] + for i, k := range ks { + if len(vs[i]) != 0 { + _, n, err := t.insert(fn.Children[pos], []byte{pos}, k[1:], valueNode(vs[i])) + if err != nil { + return err + } + fn.Children[pos] = n + } else { + _, n, err := t.delete(fn.Children[pos], []byte{pos}, k[1:]) + if err != nil { + return err + } + fn.Children[pos] = n + } + } + return nil + }) + } + if err := eg.Wait(); err != nil { + return err + } + t.unhashed += len(keys) + t.uncommitted += len(keys) + return nil +} + // MustDelete is a wrapper of Delete and will omit any encountered error but // just print out an error message. func (t *Trie) MustDelete(key []byte) { diff --git a/trie/trie_test.go b/trie/trie_test.go index 3661933e22..949f381f07 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -1580,3 +1580,57 @@ func BenchmarkTrieSeqPrefetch(b *testing.B) { } } } + +func TestUpdateBatch(t *testing.T) { + testUpdateBatch(t, []kv{ + {k: []byte("do"), v: []byte("verb")}, + {k: []byte("ether"), v: []byte("wookiedoo")}, + {k: []byte("horse"), v: []byte("stallion")}, + {k: []byte("shaman"), v: []byte("horse")}, + {k: []byte("doge"), v: []byte("coin")}, + {k: []byte("dog"), v: []byte("puppy")}, + }) + + var entries []kv + for i := 0; i < 256; i++ { + entries = append(entries, kv{k: testrand.Bytes(32), v: testrand.Bytes(32)}) + } + testUpdateBatch(t, entries) +} + +func testUpdateBatch(t *testing.T, entries []kv) { + var ( + base = NewEmpty(nil) + keys [][]byte + vals [][]byte + ) + for _, entry := range entries { + base.Update(entry.k, entry.v) + keys = append(keys, entry.k) + vals = append(vals, entry.v) + } + for i := 0; i < 10; i++ { + k, v := testrand.Bytes(32), testrand.Bytes(32) + base.Update(k, v) + keys = append(keys, k) + vals = append(vals, v) + } + + cmp := NewEmpty(nil) + if err := cmp.UpdateBatch(keys, vals); err != nil { + t.Fatalf("Failed to update batch, %v", err) + } + + // Traverse the original tree, the changes made on the copy one shouldn't + // affect the old one + for _, key := range keys { + v1, _ := base.Get(key) + v2, _ := cmp.Get(key) + if !bytes.Equal(v1, v2) { + t.Errorf("Unexpected data, key: %v, want: %v, got: %v", key, v1, v2) + } + } + if base.Hash() != cmp.Hash() { + t.Errorf("Hash mismatch: want %x, got %x", base.Hash(), cmp.Hash()) + } +}