mirror of
https://github.com/ethereum/go-ethereum.git
synced 2026-06-12 01:41:36 +00:00
trie/bintrie: add mutex protection for concurrent access
This commit is contained in:
parent
b1baab4427
commit
8f9e3630d8
2 changed files with 217 additions and 14 deletions
|
|
@ -21,6 +21,7 @@ import (
|
|||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
"github.com/ethereum/go-ethereum/core/types"
|
||||
|
|
@ -114,7 +115,16 @@ func NewBinaryNode() BinaryNode {
|
|||
}
|
||||
|
||||
// BinaryTrie is the implementation of https://eips.ethereum.org/EIPS/eip-7864.
|
||||
//
|
||||
// BinaryTrie is safe for concurrent use by multiple goroutines.
|
||||
type BinaryTrie struct {
|
||||
mu sync.Mutex // protects root and all tree mutation/reads
|
||||
|
||||
// readerMu protects calls into the underlying trie.Reader. trie.Reader is
|
||||
// explicitly not safe for concurrent use, and BinaryTrie.Copy/NodeIterator
|
||||
// share the same reader pointer between snapshots.
|
||||
readerMu *sync.Mutex
|
||||
|
||||
root BinaryNode
|
||||
reader *trie.Reader
|
||||
tracer *trie.PrevalueTracer
|
||||
|
|
@ -122,6 +132,8 @@ type BinaryTrie struct {
|
|||
|
||||
// ToDot converts the binary trie to a DOT language representation. Useful for debugging.
|
||||
func (t *BinaryTrie) ToDot() string {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
t.root.Hash()
|
||||
return ToDot(t.root)
|
||||
}
|
||||
|
|
@ -133,9 +145,10 @@ func NewBinaryTrie(root common.Hash, db database.NodeDatabase) (*BinaryTrie, err
|
|||
return nil, err
|
||||
}
|
||||
t := &BinaryTrie{
|
||||
root: NewBinaryNode(),
|
||||
reader: reader,
|
||||
tracer: trie.NewPrevalueTracer(),
|
||||
root: NewBinaryNode(),
|
||||
reader: reader,
|
||||
tracer: trie.NewPrevalueTracer(),
|
||||
readerMu: new(sync.Mutex),
|
||||
}
|
||||
// Parse the root node if it's not empty
|
||||
if root != types.EmptyBinaryHash && root != types.EmptyRootHash {
|
||||
|
|
@ -159,6 +172,13 @@ func (t *BinaryTrie) nodeResolver(path []byte, hash common.Hash) ([]byte, error)
|
|||
if hash == (common.Hash{}) {
|
||||
return nil, nil // empty node
|
||||
}
|
||||
if t.reader == nil {
|
||||
return nil, fmt.Errorf("BinaryTrie nodeResolver requires a trie.Reader")
|
||||
}
|
||||
if t.readerMu != nil {
|
||||
t.readerMu.Lock()
|
||||
defer t.readerMu.Unlock()
|
||||
}
|
||||
blob, err := t.reader.Node(path, hash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -176,11 +196,20 @@ func (t *BinaryTrie) GetKey(key []byte) []byte {
|
|||
// GetWithHashedKey returns the value, assuming that the key has already
|
||||
// been hashed.
|
||||
func (t *BinaryTrie) GetWithHashedKey(key []byte) ([]byte, error) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
return t.root.Get(key, t.nodeResolver)
|
||||
}
|
||||
|
||||
// GetAccount returns the account information for the given address.
|
||||
func (t *BinaryTrie) GetAccount(addr common.Address) (*types.StateAccount, error) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
return t.getAccount(addr)
|
||||
}
|
||||
|
||||
// getAccount is the lock-free internal implementation of GetAccount.
|
||||
func (t *BinaryTrie) getAccount(addr common.Address) (*types.StateAccount, error) {
|
||||
var (
|
||||
values [][]byte
|
||||
err error
|
||||
|
|
@ -238,11 +267,21 @@ func (t *BinaryTrie) GetAccount(addr common.Address) (*types.StateAccount, error
|
|||
// not be modified by the caller. If a node was not found in the database, a
|
||||
// trie.MissingNodeError is returned.
|
||||
func (t *BinaryTrie) GetStorage(addr common.Address, key []byte) ([]byte, error) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
return t.getStorage(addr, key)
|
||||
}
|
||||
|
||||
// getStorage is the lock-free internal implementation of GetStorage.
|
||||
func (t *BinaryTrie) getStorage(addr common.Address, key []byte) ([]byte, error) {
|
||||
return t.root.Get(GetBinaryTreeKeyStorageSlot(addr, key), t.nodeResolver)
|
||||
}
|
||||
|
||||
// UpdateAccount updates the account information for the given address.
|
||||
func (t *BinaryTrie) UpdateAccount(addr common.Address, acc *types.StateAccount, codeLen int) error {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
var (
|
||||
err error
|
||||
basicData [HashSize]byte
|
||||
|
|
@ -271,6 +310,13 @@ func (t *BinaryTrie) UpdateAccount(addr common.Address, acc *types.StateAccount,
|
|||
|
||||
// UpdateStem updates the values for the given stem key.
|
||||
func (t *BinaryTrie) UpdateStem(key []byte, values [][]byte) error {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
return t.updateStem(key, values)
|
||||
}
|
||||
|
||||
// updateStem is the lock-free internal implementation of UpdateStem.
|
||||
func (t *BinaryTrie) updateStem(key []byte, values [][]byte) error {
|
||||
var err error
|
||||
t.root, err = t.root.InsertValuesAtStem(key, values, t.nodeResolver, 0)
|
||||
return err
|
||||
|
|
@ -281,6 +327,9 @@ func (t *BinaryTrie) UpdateStem(key []byte, values [][]byte) error {
|
|||
// by the caller while they are stored in the trie. If a node was not found in the
|
||||
// database, a trie.MissingNodeError is returned.
|
||||
func (t *BinaryTrie) UpdateStorage(address common.Address, key, value []byte) error {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
k := GetBinaryTreeKeyStorageSlot(address, key)
|
||||
var v [HashSize]byte
|
||||
if len(value) >= HashSize {
|
||||
|
|
@ -299,6 +348,9 @@ func (t *BinaryTrie) UpdateStorage(address common.Address, key, value []byte) er
|
|||
// DeleteAccount erases an account by overwriting the account
|
||||
// descriptors with 0s.
|
||||
func (t *BinaryTrie) DeleteAccount(addr common.Address) error {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
var (
|
||||
values = make([][]byte, StemNodeWidth)
|
||||
stem = GetBinaryTreeKey(addr, zero[:])
|
||||
|
|
@ -318,6 +370,9 @@ func (t *BinaryTrie) DeleteAccount(addr common.Address) error {
|
|||
// DeleteStorage removes any existing value for key from the trie. If a node was not
|
||||
// found in the database, a trie.MissingNodeError is returned.
|
||||
func (t *BinaryTrie) DeleteStorage(addr common.Address, key []byte) error {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
k := GetBinaryTreeKeyStorageSlot(addr, key)
|
||||
var zero [HashSize]byte
|
||||
root, err := t.root.Insert(k, zero[:], t.nodeResolver, 0)
|
||||
|
|
@ -331,12 +386,17 @@ func (t *BinaryTrie) DeleteStorage(addr common.Address, key []byte) error {
|
|||
// Hash returns the root hash of the trie. It does not write to the database and
|
||||
// can be used even if the trie doesn't have one.
|
||||
func (t *BinaryTrie) Hash() common.Hash {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
return t.root.Hash()
|
||||
}
|
||||
|
||||
// Commit writes all nodes to the trie's memory database, tracking the internal
|
||||
// and external (for account tries) references.
|
||||
func (t *BinaryTrie) Commit(_ bool) (common.Hash, *trienode.NodeSet) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
nodeset := trienode.NewNodeSet(common.Hash{})
|
||||
|
||||
// The root can be any type of BinaryNode (InternalNode, StemNode, etc.)
|
||||
|
|
@ -348,13 +408,16 @@ func (t *BinaryTrie) Commit(_ bool) (common.Hash, *trienode.NodeSet) {
|
|||
panic(fmt.Errorf("CollectNodes failed: %v", err))
|
||||
}
|
||||
// Serialize root commitment form
|
||||
return t.Hash(), nodeset
|
||||
// We already hold t.mu, so using the lock-free root hash computation.
|
||||
return t.root.Hash(), nodeset
|
||||
}
|
||||
|
||||
// NodeIterator returns an iterator that returns nodes of the trie. Iteration
|
||||
// starts at the key after the given start key.
|
||||
func (t *BinaryTrie) NodeIterator(startKey []byte) (trie.NodeIterator, error) {
|
||||
return newBinaryNodeIterator(t, nil)
|
||||
// Iterate on a snapshot to avoid racing with concurrent writers.
|
||||
snap := t.Copy()
|
||||
return newBinaryNodeIterator(snap, startKey)
|
||||
}
|
||||
|
||||
// Prove constructs a Merkle proof for key. The result contains all encoded nodes
|
||||
|
|
@ -368,12 +431,16 @@ func (t *BinaryTrie) Prove(key []byte, proofDb ethdb.KeyValueWriter) error {
|
|||
panic("not implemented")
|
||||
}
|
||||
|
||||
// Copy creates a deep copy of the trie.
|
||||
// Copy creates a deep copy of the trie. The returned copy has its own mutex
|
||||
// and is independent of the original.
|
||||
func (t *BinaryTrie) Copy() *BinaryTrie {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
return &BinaryTrie{
|
||||
root: t.root.Copy(),
|
||||
reader: t.reader,
|
||||
tracer: t.tracer.Copy(),
|
||||
root: t.root.Copy(),
|
||||
reader: t.reader,
|
||||
readerMu: t.readerMu,
|
||||
tracer: t.tracer.Copy(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -389,6 +456,9 @@ func (t *BinaryTrie) IsVerkle() bool {
|
|||
//
|
||||
// Note: the basic data leaf needs to have been previously created for this to work
|
||||
func (t *BinaryTrie) UpdateContractCode(addr common.Address, codeHash common.Hash, code []byte) error {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
var (
|
||||
chunks = ChunkifyCode(code)
|
||||
values [][]byte
|
||||
|
|
@ -406,9 +476,7 @@ func (t *BinaryTrie) UpdateContractCode(addr common.Address, codeHash common.Has
|
|||
values[groupOffset] = chunks[i : i+HashSize]
|
||||
|
||||
if groupOffset == StemNodeWidth-1 || len(chunks)-i <= HashSize {
|
||||
err = t.UpdateStem(key[:StemSize], values)
|
||||
|
||||
if err != nil {
|
||||
if err = t.updateStem(key[:StemSize], values); err != nil {
|
||||
return fmt.Errorf("UpdateContractCode (addr=%x) error: %w", addr[:], err)
|
||||
}
|
||||
}
|
||||
|
|
@ -419,8 +487,10 @@ func (t *BinaryTrie) UpdateContractCode(addr common.Address, codeHash common.Has
|
|||
// PrefetchAccount attempts to resolve specific accounts from the database
|
||||
// to accelerate subsequent trie operations.
|
||||
func (t *BinaryTrie) PrefetchAccount(addresses []common.Address) error {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
for _, addr := range addresses {
|
||||
if _, err := t.GetAccount(addr); err != nil {
|
||||
if _, err := t.getAccount(addr); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
|
@ -430,8 +500,10 @@ func (t *BinaryTrie) PrefetchAccount(addresses []common.Address) error {
|
|||
// PrefetchStorage attempts to resolve specific storage slots from the database
|
||||
// to accelerate subsequent trie operations.
|
||||
func (t *BinaryTrie) PrefetchStorage(addr common.Address, keys [][]byte) error {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
for _, key := range keys {
|
||||
if _, err := t.GetStorage(addr, key); err != nil {
|
||||
if _, err := t.getStorage(addr, key); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
|
@ -440,5 +512,7 @@ func (t *BinaryTrie) PrefetchStorage(addr common.Address, keys [][]byte) error {
|
|||
|
||||
// Witness returns a set containing all trie nodes that have been accessed.
|
||||
func (t *BinaryTrie) Witness() map[string][]byte {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
return t.tracer.Values()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,6 +19,8 @@ package bintrie
|
|||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
|
|
@ -779,3 +781,130 @@ func TestGetStorageNonMembershipInternalRoot(t *testing.T) {
|
|||
t.Fatalf("expected nil/zero for non-existent storage, got %x", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcurrentStorageUpdates exercises BinaryTrie under concurrent writes
|
||||
// from multiple goroutines — the scenario that triggers a data race when
|
||||
// IntermediateRoot parallelizes per-account updateTrie() calls on a single
|
||||
// shared binary trie. This test must pass with -race enabled.
|
||||
func TestConcurrentStorageUpdates(t *testing.T) {
|
||||
tr := newEmptyTestTrie(t)
|
||||
|
||||
// Create multiple accounts so the root becomes an InternalNode with
|
||||
// subtrees that different goroutines will traverse concurrently.
|
||||
addrs := []common.Address{
|
||||
common.HexToAddress("0x1111111111111111111111111111111111111111"),
|
||||
common.HexToAddress("0x2222222222222222222222222222222222222222"),
|
||||
common.HexToAddress("0x9999999999999999999999999999999999999999"),
|
||||
common.HexToAddress("0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"),
|
||||
}
|
||||
for _, addr := range addrs {
|
||||
acc := makeAccount(1, 100, common.HexToHash("c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470"))
|
||||
if err := tr.UpdateAccount(addr, acc, 0); err != nil {
|
||||
t.Fatalf("UpdateAccount: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Spawn goroutines that concurrently call UpdateStorage and DeleteStorage
|
||||
// on different addresses but the same shared trie. This mirrors the
|
||||
// errgroup pattern in IntermediateRoot.
|
||||
const slotsPerAddr = 20
|
||||
var wg sync.WaitGroup
|
||||
errs := make(chan error, len(addrs)*slotsPerAddr)
|
||||
|
||||
for _, addr := range addrs {
|
||||
wg.Add(1)
|
||||
go func(addr common.Address) {
|
||||
defer wg.Done()
|
||||
for i := 0; i < slotsPerAddr; i++ {
|
||||
slot := common.HexToHash(fmt.Sprintf("00000000000000000000000000000000000000000000000000000000000000%02x", 0x80+i))
|
||||
val := common.TrimLeftZeroes(common.HexToHash(fmt.Sprintf("00000000000000000000000000000000000000000000000000000000%04x%04x", addr[0], i)).Bytes())
|
||||
if err := tr.UpdateStorage(addr, slot[:], val); err != nil {
|
||||
errs <- fmt.Errorf("UpdateStorage(%x, slot %d): %v", addr, i, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}(addr)
|
||||
}
|
||||
wg.Wait()
|
||||
close(errs)
|
||||
|
||||
for err := range errs {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Verify all values are readable and correct.
|
||||
for _, addr := range addrs {
|
||||
for i := 0; i < slotsPerAddr; i++ {
|
||||
slot := common.HexToHash(fmt.Sprintf("00000000000000000000000000000000000000000000000000000000000000%02x", 0x80+i))
|
||||
got, err := tr.GetStorage(addr, slot[:])
|
||||
if err != nil {
|
||||
t.Fatalf("GetStorage(%x, slot %d): %v", addr, i, err)
|
||||
}
|
||||
if len(got) == 0 {
|
||||
t.Fatalf("GetStorage(%x, slot %d): empty, expected value", addr, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Hash must not race with prior writes — call it to exercise the
|
||||
// read path after concurrent mutations.
|
||||
h := tr.Hash()
|
||||
if h == (common.Hash{}) {
|
||||
t.Fatal("trie hash is zero after concurrent updates")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcurrentReadWrite exercises concurrent reads and writes on the same
|
||||
// BinaryTrie, verifying no data race under the race detector.
|
||||
func TestConcurrentReadWrite(t *testing.T) {
|
||||
tr := newEmptyTestTrie(t)
|
||||
|
||||
addr := common.HexToAddress("0x1111111111111111111111111111111111111111")
|
||||
acc := makeAccount(1, 100, common.HexToHash("c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470"))
|
||||
if err := tr.UpdateAccount(addr, acc, 0); err != nil {
|
||||
t.Fatalf("UpdateAccount: %v", err)
|
||||
}
|
||||
|
||||
// Pre-populate some storage so reads have data to find.
|
||||
for i := 0; i < 10; i++ {
|
||||
slot := common.HexToHash(fmt.Sprintf("00000000000000000000000000000000000000000000000000000000000000%02x", 0x80+i))
|
||||
val := common.HexToHash(fmt.Sprintf("00000000000000000000000000000000000000000000000000000000dead%04x", i)).Bytes()
|
||||
if err := tr.UpdateStorage(addr, slot[:], common.TrimLeftZeroes(val)); err != nil {
|
||||
t.Fatalf("seed UpdateStorage: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Writer goroutine.
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 10; i < 30; i++ {
|
||||
slot := common.HexToHash(fmt.Sprintf("00000000000000000000000000000000000000000000000000000000000000%02x", 0x80+i))
|
||||
val := common.HexToHash(fmt.Sprintf("00000000000000000000000000000000000000000000000000000000beef%04x", i)).Bytes()
|
||||
tr.UpdateStorage(addr, slot[:], common.TrimLeftZeroes(val))
|
||||
}
|
||||
}()
|
||||
|
||||
// Reader goroutine.
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 10; i++ {
|
||||
slot := common.HexToHash(fmt.Sprintf("00000000000000000000000000000000000000000000000000000000000000%02x", 0x80+i))
|
||||
tr.GetStorage(addr, slot[:])
|
||||
}
|
||||
}()
|
||||
|
||||
// Hash goroutine.
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 5; i++ {
|
||||
tr.Hash()
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue