diff --git a/trie/bintrie/trie.go b/trie/bintrie/trie.go index b1e3c991c0..c610d3c5c3 100644 --- a/trie/bintrie/trie.go +++ b/trie/bintrie/trie.go @@ -21,6 +21,7 @@ import ( "encoding/binary" "errors" "fmt" + "sync" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" @@ -114,7 +115,16 @@ func NewBinaryNode() BinaryNode { } // BinaryTrie is the implementation of https://eips.ethereum.org/EIPS/eip-7864. +// +// BinaryTrie is safe for concurrent use by multiple goroutines. type BinaryTrie struct { + mu sync.Mutex // protects root and all tree mutation/reads + + // readerMu protects calls into the underlying trie.Reader. trie.Reader is + // explicitly not safe for concurrent use, and BinaryTrie.Copy/NodeIterator + // share the same reader pointer between snapshots. + readerMu *sync.Mutex + root BinaryNode reader *trie.Reader tracer *trie.PrevalueTracer @@ -122,6 +132,8 @@ type BinaryTrie struct { // ToDot converts the binary trie to a DOT language representation. Useful for debugging. func (t *BinaryTrie) ToDot() string { + t.mu.Lock() + defer t.mu.Unlock() t.root.Hash() return ToDot(t.root) } @@ -133,9 +145,10 @@ func NewBinaryTrie(root common.Hash, db database.NodeDatabase) (*BinaryTrie, err return nil, err } t := &BinaryTrie{ - root: NewBinaryNode(), - reader: reader, - tracer: trie.NewPrevalueTracer(), + root: NewBinaryNode(), + reader: reader, + tracer: trie.NewPrevalueTracer(), + readerMu: new(sync.Mutex), } // Parse the root node if it's not empty if root != types.EmptyBinaryHash && root != types.EmptyRootHash { @@ -159,6 +172,13 @@ func (t *BinaryTrie) nodeResolver(path []byte, hash common.Hash) ([]byte, error) if hash == (common.Hash{}) { return nil, nil // empty node } + if t.reader == nil { + return nil, fmt.Errorf("BinaryTrie nodeResolver requires a trie.Reader") + } + if t.readerMu != nil { + t.readerMu.Lock() + defer t.readerMu.Unlock() + } blob, err := t.reader.Node(path, hash) if err != nil { return nil, err @@ -176,11 +196,20 @@ func (t *BinaryTrie) GetKey(key []byte) []byte { // GetWithHashedKey returns the value, assuming that the key has already // been hashed. func (t *BinaryTrie) GetWithHashedKey(key []byte) ([]byte, error) { + t.mu.Lock() + defer t.mu.Unlock() return t.root.Get(key, t.nodeResolver) } // GetAccount returns the account information for the given address. func (t *BinaryTrie) GetAccount(addr common.Address) (*types.StateAccount, error) { + t.mu.Lock() + defer t.mu.Unlock() + return t.getAccount(addr) +} + +// getAccount is the lock-free internal implementation of GetAccount. +func (t *BinaryTrie) getAccount(addr common.Address) (*types.StateAccount, error) { var ( values [][]byte err error @@ -238,11 +267,21 @@ func (t *BinaryTrie) GetAccount(addr common.Address) (*types.StateAccount, error // not be modified by the caller. If a node was not found in the database, a // trie.MissingNodeError is returned. func (t *BinaryTrie) GetStorage(addr common.Address, key []byte) ([]byte, error) { + t.mu.Lock() + defer t.mu.Unlock() + return t.getStorage(addr, key) +} + +// getStorage is the lock-free internal implementation of GetStorage. +func (t *BinaryTrie) getStorage(addr common.Address, key []byte) ([]byte, error) { return t.root.Get(GetBinaryTreeKeyStorageSlot(addr, key), t.nodeResolver) } // UpdateAccount updates the account information for the given address. func (t *BinaryTrie) UpdateAccount(addr common.Address, acc *types.StateAccount, codeLen int) error { + t.mu.Lock() + defer t.mu.Unlock() + var ( err error basicData [HashSize]byte @@ -271,6 +310,13 @@ func (t *BinaryTrie) UpdateAccount(addr common.Address, acc *types.StateAccount, // UpdateStem updates the values for the given stem key. func (t *BinaryTrie) UpdateStem(key []byte, values [][]byte) error { + t.mu.Lock() + defer t.mu.Unlock() + return t.updateStem(key, values) +} + +// updateStem is the lock-free internal implementation of UpdateStem. +func (t *BinaryTrie) updateStem(key []byte, values [][]byte) error { var err error t.root, err = t.root.InsertValuesAtStem(key, values, t.nodeResolver, 0) return err @@ -281,6 +327,9 @@ func (t *BinaryTrie) UpdateStem(key []byte, values [][]byte) error { // by the caller while they are stored in the trie. If a node was not found in the // database, a trie.MissingNodeError is returned. func (t *BinaryTrie) UpdateStorage(address common.Address, key, value []byte) error { + t.mu.Lock() + defer t.mu.Unlock() + k := GetBinaryTreeKeyStorageSlot(address, key) var v [HashSize]byte if len(value) >= HashSize { @@ -299,6 +348,9 @@ func (t *BinaryTrie) UpdateStorage(address common.Address, key, value []byte) er // DeleteAccount erases an account by overwriting the account // descriptors with 0s. func (t *BinaryTrie) DeleteAccount(addr common.Address) error { + t.mu.Lock() + defer t.mu.Unlock() + var ( values = make([][]byte, StemNodeWidth) stem = GetBinaryTreeKey(addr, zero[:]) @@ -318,6 +370,9 @@ func (t *BinaryTrie) DeleteAccount(addr common.Address) error { // DeleteStorage removes any existing value for key from the trie. If a node was not // found in the database, a trie.MissingNodeError is returned. func (t *BinaryTrie) DeleteStorage(addr common.Address, key []byte) error { + t.mu.Lock() + defer t.mu.Unlock() + k := GetBinaryTreeKeyStorageSlot(addr, key) var zero [HashSize]byte root, err := t.root.Insert(k, zero[:], t.nodeResolver, 0) @@ -331,12 +386,17 @@ func (t *BinaryTrie) DeleteStorage(addr common.Address, key []byte) error { // Hash returns the root hash of the trie. It does not write to the database and // can be used even if the trie doesn't have one. func (t *BinaryTrie) Hash() common.Hash { + t.mu.Lock() + defer t.mu.Unlock() return t.root.Hash() } // Commit writes all nodes to the trie's memory database, tracking the internal // and external (for account tries) references. func (t *BinaryTrie) Commit(_ bool) (common.Hash, *trienode.NodeSet) { + t.mu.Lock() + defer t.mu.Unlock() + nodeset := trienode.NewNodeSet(common.Hash{}) // The root can be any type of BinaryNode (InternalNode, StemNode, etc.) @@ -348,13 +408,16 @@ func (t *BinaryTrie) Commit(_ bool) (common.Hash, *trienode.NodeSet) { panic(fmt.Errorf("CollectNodes failed: %v", err)) } // Serialize root commitment form - return t.Hash(), nodeset + // We already hold t.mu, so using the lock-free root hash computation. + return t.root.Hash(), nodeset } // NodeIterator returns an iterator that returns nodes of the trie. Iteration // starts at the key after the given start key. func (t *BinaryTrie) NodeIterator(startKey []byte) (trie.NodeIterator, error) { - return newBinaryNodeIterator(t, nil) + // Iterate on a snapshot to avoid racing with concurrent writers. + snap := t.Copy() + return newBinaryNodeIterator(snap, startKey) } // Prove constructs a Merkle proof for key. The result contains all encoded nodes @@ -368,12 +431,16 @@ func (t *BinaryTrie) Prove(key []byte, proofDb ethdb.KeyValueWriter) error { panic("not implemented") } -// Copy creates a deep copy of the trie. +// Copy creates a deep copy of the trie. The returned copy has its own mutex +// and is independent of the original. func (t *BinaryTrie) Copy() *BinaryTrie { + t.mu.Lock() + defer t.mu.Unlock() return &BinaryTrie{ - root: t.root.Copy(), - reader: t.reader, - tracer: t.tracer.Copy(), + root: t.root.Copy(), + reader: t.reader, + readerMu: t.readerMu, + tracer: t.tracer.Copy(), } } @@ -389,6 +456,9 @@ func (t *BinaryTrie) IsVerkle() bool { // // Note: the basic data leaf needs to have been previously created for this to work func (t *BinaryTrie) UpdateContractCode(addr common.Address, codeHash common.Hash, code []byte) error { + t.mu.Lock() + defer t.mu.Unlock() + var ( chunks = ChunkifyCode(code) values [][]byte @@ -406,9 +476,7 @@ func (t *BinaryTrie) UpdateContractCode(addr common.Address, codeHash common.Has values[groupOffset] = chunks[i : i+HashSize] if groupOffset == StemNodeWidth-1 || len(chunks)-i <= HashSize { - err = t.UpdateStem(key[:StemSize], values) - - if err != nil { + if err = t.updateStem(key[:StemSize], values); err != nil { return fmt.Errorf("UpdateContractCode (addr=%x) error: %w", addr[:], err) } } @@ -419,8 +487,10 @@ func (t *BinaryTrie) UpdateContractCode(addr common.Address, codeHash common.Has // PrefetchAccount attempts to resolve specific accounts from the database // to accelerate subsequent trie operations. func (t *BinaryTrie) PrefetchAccount(addresses []common.Address) error { + t.mu.Lock() + defer t.mu.Unlock() for _, addr := range addresses { - if _, err := t.GetAccount(addr); err != nil { + if _, err := t.getAccount(addr); err != nil { return err } } @@ -430,8 +500,10 @@ func (t *BinaryTrie) PrefetchAccount(addresses []common.Address) error { // PrefetchStorage attempts to resolve specific storage slots from the database // to accelerate subsequent trie operations. func (t *BinaryTrie) PrefetchStorage(addr common.Address, keys [][]byte) error { + t.mu.Lock() + defer t.mu.Unlock() for _, key := range keys { - if _, err := t.GetStorage(addr, key); err != nil { + if _, err := t.getStorage(addr, key); err != nil { return err } } @@ -440,5 +512,7 @@ func (t *BinaryTrie) PrefetchStorage(addr common.Address, keys [][]byte) error { // Witness returns a set containing all trie nodes that have been accessed. func (t *BinaryTrie) Witness() map[string][]byte { + t.mu.Lock() + defer t.mu.Unlock() return t.tracer.Values() } diff --git a/trie/bintrie/trie_test.go b/trie/bintrie/trie_test.go index 5b104ddde4..b0836632ba 100644 --- a/trie/bintrie/trie_test.go +++ b/trie/bintrie/trie_test.go @@ -19,6 +19,8 @@ package bintrie import ( "bytes" "encoding/binary" + "fmt" + "sync" "testing" "github.com/ethereum/go-ethereum/common" @@ -779,3 +781,130 @@ func TestGetStorageNonMembershipInternalRoot(t *testing.T) { t.Fatalf("expected nil/zero for non-existent storage, got %x", got) } } + +// TestConcurrentStorageUpdates exercises BinaryTrie under concurrent writes +// from multiple goroutines — the scenario that triggers a data race when +// IntermediateRoot parallelizes per-account updateTrie() calls on a single +// shared binary trie. This test must pass with -race enabled. +func TestConcurrentStorageUpdates(t *testing.T) { + tr := newEmptyTestTrie(t) + + // Create multiple accounts so the root becomes an InternalNode with + // subtrees that different goroutines will traverse concurrently. + addrs := []common.Address{ + common.HexToAddress("0x1111111111111111111111111111111111111111"), + common.HexToAddress("0x2222222222222222222222222222222222222222"), + common.HexToAddress("0x9999999999999999999999999999999999999999"), + common.HexToAddress("0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"), + } + for _, addr := range addrs { + acc := makeAccount(1, 100, common.HexToHash("c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470")) + if err := tr.UpdateAccount(addr, acc, 0); err != nil { + t.Fatalf("UpdateAccount: %v", err) + } + } + + // Spawn goroutines that concurrently call UpdateStorage and DeleteStorage + // on different addresses but the same shared trie. This mirrors the + // errgroup pattern in IntermediateRoot. + const slotsPerAddr = 20 + var wg sync.WaitGroup + errs := make(chan error, len(addrs)*slotsPerAddr) + + for _, addr := range addrs { + wg.Add(1) + go func(addr common.Address) { + defer wg.Done() + for i := 0; i < slotsPerAddr; i++ { + slot := common.HexToHash(fmt.Sprintf("00000000000000000000000000000000000000000000000000000000000000%02x", 0x80+i)) + val := common.TrimLeftZeroes(common.HexToHash(fmt.Sprintf("00000000000000000000000000000000000000000000000000000000%04x%04x", addr[0], i)).Bytes()) + if err := tr.UpdateStorage(addr, slot[:], val); err != nil { + errs <- fmt.Errorf("UpdateStorage(%x, slot %d): %v", addr, i, err) + return + } + } + }(addr) + } + wg.Wait() + close(errs) + + for err := range errs { + t.Fatal(err) + } + + // Verify all values are readable and correct. + for _, addr := range addrs { + for i := 0; i < slotsPerAddr; i++ { + slot := common.HexToHash(fmt.Sprintf("00000000000000000000000000000000000000000000000000000000000000%02x", 0x80+i)) + got, err := tr.GetStorage(addr, slot[:]) + if err != nil { + t.Fatalf("GetStorage(%x, slot %d): %v", addr, i, err) + } + if len(got) == 0 { + t.Fatalf("GetStorage(%x, slot %d): empty, expected value", addr, i) + } + } + } + + // Hash must not race with prior writes — call it to exercise the + // read path after concurrent mutations. + h := tr.Hash() + if h == (common.Hash{}) { + t.Fatal("trie hash is zero after concurrent updates") + } +} + +// TestConcurrentReadWrite exercises concurrent reads and writes on the same +// BinaryTrie, verifying no data race under the race detector. +func TestConcurrentReadWrite(t *testing.T) { + tr := newEmptyTestTrie(t) + + addr := common.HexToAddress("0x1111111111111111111111111111111111111111") + acc := makeAccount(1, 100, common.HexToHash("c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470")) + if err := tr.UpdateAccount(addr, acc, 0); err != nil { + t.Fatalf("UpdateAccount: %v", err) + } + + // Pre-populate some storage so reads have data to find. + for i := 0; i < 10; i++ { + slot := common.HexToHash(fmt.Sprintf("00000000000000000000000000000000000000000000000000000000000000%02x", 0x80+i)) + val := common.HexToHash(fmt.Sprintf("00000000000000000000000000000000000000000000000000000000dead%04x", i)).Bytes() + if err := tr.UpdateStorage(addr, slot[:], common.TrimLeftZeroes(val)); err != nil { + t.Fatalf("seed UpdateStorage: %v", err) + } + } + + var wg sync.WaitGroup + + // Writer goroutine. + wg.Add(1) + go func() { + defer wg.Done() + for i := 10; i < 30; i++ { + slot := common.HexToHash(fmt.Sprintf("00000000000000000000000000000000000000000000000000000000000000%02x", 0x80+i)) + val := common.HexToHash(fmt.Sprintf("00000000000000000000000000000000000000000000000000000000beef%04x", i)).Bytes() + tr.UpdateStorage(addr, slot[:], common.TrimLeftZeroes(val)) + } + }() + + // Reader goroutine. + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 10; i++ { + slot := common.HexToHash(fmt.Sprintf("00000000000000000000000000000000000000000000000000000000000000%02x", 0x80+i)) + tr.GetStorage(addr, slot[:]) + } + }() + + // Hash goroutine. + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 5; i++ { + tr.Hash() + } + }() + + wg.Wait() +}