trie/bintrie: add mutex protection for concurrent access

This commit is contained in:
Md.Sadiq 2026-04-16 15:55:35 +05:30
parent b1baab4427
commit 8f9e3630d8
No known key found for this signature in database
GPG key ID: BDFA0A451C77AD56
2 changed files with 217 additions and 14 deletions

View file

@ -21,6 +21,7 @@ import (
"encoding/binary"
"errors"
"fmt"
"sync"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
@ -114,7 +115,16 @@ func NewBinaryNode() BinaryNode {
}
// BinaryTrie is the implementation of https://eips.ethereum.org/EIPS/eip-7864.
//
// BinaryTrie is safe for concurrent use by multiple goroutines.
type BinaryTrie struct {
mu sync.Mutex // protects root and all tree mutation/reads
// readerMu protects calls into the underlying trie.Reader. trie.Reader is
// explicitly not safe for concurrent use, and BinaryTrie.Copy/NodeIterator
// share the same reader pointer between snapshots.
readerMu *sync.Mutex
root BinaryNode
reader *trie.Reader
tracer *trie.PrevalueTracer
@ -122,6 +132,8 @@ type BinaryTrie struct {
// ToDot converts the binary trie to a DOT language representation. Useful for debugging.
func (t *BinaryTrie) ToDot() string {
t.mu.Lock()
defer t.mu.Unlock()
t.root.Hash()
return ToDot(t.root)
}
@ -133,9 +145,10 @@ func NewBinaryTrie(root common.Hash, db database.NodeDatabase) (*BinaryTrie, err
return nil, err
}
t := &BinaryTrie{
root: NewBinaryNode(),
reader: reader,
tracer: trie.NewPrevalueTracer(),
root: NewBinaryNode(),
reader: reader,
tracer: trie.NewPrevalueTracer(),
readerMu: new(sync.Mutex),
}
// Parse the root node if it's not empty
if root != types.EmptyBinaryHash && root != types.EmptyRootHash {
@ -159,6 +172,13 @@ func (t *BinaryTrie) nodeResolver(path []byte, hash common.Hash) ([]byte, error)
if hash == (common.Hash{}) {
return nil, nil // empty node
}
if t.reader == nil {
return nil, fmt.Errorf("BinaryTrie nodeResolver requires a trie.Reader")
}
if t.readerMu != nil {
t.readerMu.Lock()
defer t.readerMu.Unlock()
}
blob, err := t.reader.Node(path, hash)
if err != nil {
return nil, err
@ -176,11 +196,20 @@ func (t *BinaryTrie) GetKey(key []byte) []byte {
// GetWithHashedKey returns the value, assuming that the key has already
// been hashed.
func (t *BinaryTrie) GetWithHashedKey(key []byte) ([]byte, error) {
t.mu.Lock()
defer t.mu.Unlock()
return t.root.Get(key, t.nodeResolver)
}
// GetAccount returns the account information for the given address.
func (t *BinaryTrie) GetAccount(addr common.Address) (*types.StateAccount, error) {
t.mu.Lock()
defer t.mu.Unlock()
return t.getAccount(addr)
}
// getAccount is the lock-free internal implementation of GetAccount.
func (t *BinaryTrie) getAccount(addr common.Address) (*types.StateAccount, error) {
var (
values [][]byte
err error
@ -238,11 +267,21 @@ func (t *BinaryTrie) GetAccount(addr common.Address) (*types.StateAccount, error
// not be modified by the caller. If a node was not found in the database, a
// trie.MissingNodeError is returned.
func (t *BinaryTrie) GetStorage(addr common.Address, key []byte) ([]byte, error) {
t.mu.Lock()
defer t.mu.Unlock()
return t.getStorage(addr, key)
}
// getStorage is the lock-free internal implementation of GetStorage.
func (t *BinaryTrie) getStorage(addr common.Address, key []byte) ([]byte, error) {
return t.root.Get(GetBinaryTreeKeyStorageSlot(addr, key), t.nodeResolver)
}
// UpdateAccount updates the account information for the given address.
func (t *BinaryTrie) UpdateAccount(addr common.Address, acc *types.StateAccount, codeLen int) error {
t.mu.Lock()
defer t.mu.Unlock()
var (
err error
basicData [HashSize]byte
@ -271,6 +310,13 @@ func (t *BinaryTrie) UpdateAccount(addr common.Address, acc *types.StateAccount,
// UpdateStem updates the values for the given stem key.
func (t *BinaryTrie) UpdateStem(key []byte, values [][]byte) error {
t.mu.Lock()
defer t.mu.Unlock()
return t.updateStem(key, values)
}
// updateStem is the lock-free internal implementation of UpdateStem.
func (t *BinaryTrie) updateStem(key []byte, values [][]byte) error {
var err error
t.root, err = t.root.InsertValuesAtStem(key, values, t.nodeResolver, 0)
return err
@ -281,6 +327,9 @@ func (t *BinaryTrie) UpdateStem(key []byte, values [][]byte) error {
// by the caller while they are stored in the trie. If a node was not found in the
// database, a trie.MissingNodeError is returned.
func (t *BinaryTrie) UpdateStorage(address common.Address, key, value []byte) error {
t.mu.Lock()
defer t.mu.Unlock()
k := GetBinaryTreeKeyStorageSlot(address, key)
var v [HashSize]byte
if len(value) >= HashSize {
@ -299,6 +348,9 @@ func (t *BinaryTrie) UpdateStorage(address common.Address, key, value []byte) er
// DeleteAccount erases an account by overwriting the account
// descriptors with 0s.
func (t *BinaryTrie) DeleteAccount(addr common.Address) error {
t.mu.Lock()
defer t.mu.Unlock()
var (
values = make([][]byte, StemNodeWidth)
stem = GetBinaryTreeKey(addr, zero[:])
@ -318,6 +370,9 @@ func (t *BinaryTrie) DeleteAccount(addr common.Address) error {
// DeleteStorage removes any existing value for key from the trie. If a node was not
// found in the database, a trie.MissingNodeError is returned.
func (t *BinaryTrie) DeleteStorage(addr common.Address, key []byte) error {
t.mu.Lock()
defer t.mu.Unlock()
k := GetBinaryTreeKeyStorageSlot(addr, key)
var zero [HashSize]byte
root, err := t.root.Insert(k, zero[:], t.nodeResolver, 0)
@ -331,12 +386,17 @@ func (t *BinaryTrie) DeleteStorage(addr common.Address, key []byte) error {
// Hash returns the root hash of the trie. It does not write to the database and
// can be used even if the trie doesn't have one.
func (t *BinaryTrie) Hash() common.Hash {
t.mu.Lock()
defer t.mu.Unlock()
return t.root.Hash()
}
// Commit writes all nodes to the trie's memory database, tracking the internal
// and external (for account tries) references.
func (t *BinaryTrie) Commit(_ bool) (common.Hash, *trienode.NodeSet) {
t.mu.Lock()
defer t.mu.Unlock()
nodeset := trienode.NewNodeSet(common.Hash{})
// The root can be any type of BinaryNode (InternalNode, StemNode, etc.)
@ -348,13 +408,16 @@ func (t *BinaryTrie) Commit(_ bool) (common.Hash, *trienode.NodeSet) {
panic(fmt.Errorf("CollectNodes failed: %v", err))
}
// Serialize root commitment form
return t.Hash(), nodeset
// We already hold t.mu, so using the lock-free root hash computation.
return t.root.Hash(), nodeset
}
// NodeIterator returns an iterator that returns nodes of the trie. Iteration
// starts at the key after the given start key.
func (t *BinaryTrie) NodeIterator(startKey []byte) (trie.NodeIterator, error) {
return newBinaryNodeIterator(t, nil)
// Iterate on a snapshot to avoid racing with concurrent writers.
snap := t.Copy()
return newBinaryNodeIterator(snap, startKey)
}
// Prove constructs a Merkle proof for key. The result contains all encoded nodes
@ -368,12 +431,16 @@ func (t *BinaryTrie) Prove(key []byte, proofDb ethdb.KeyValueWriter) error {
panic("not implemented")
}
// Copy creates a deep copy of the trie.
// Copy creates a deep copy of the trie. The returned copy has its own mutex
// and is independent of the original.
func (t *BinaryTrie) Copy() *BinaryTrie {
t.mu.Lock()
defer t.mu.Unlock()
return &BinaryTrie{
root: t.root.Copy(),
reader: t.reader,
tracer: t.tracer.Copy(),
root: t.root.Copy(),
reader: t.reader,
readerMu: t.readerMu,
tracer: t.tracer.Copy(),
}
}
@ -389,6 +456,9 @@ func (t *BinaryTrie) IsVerkle() bool {
//
// Note: the basic data leaf needs to have been previously created for this to work
func (t *BinaryTrie) UpdateContractCode(addr common.Address, codeHash common.Hash, code []byte) error {
t.mu.Lock()
defer t.mu.Unlock()
var (
chunks = ChunkifyCode(code)
values [][]byte
@ -406,9 +476,7 @@ func (t *BinaryTrie) UpdateContractCode(addr common.Address, codeHash common.Has
values[groupOffset] = chunks[i : i+HashSize]
if groupOffset == StemNodeWidth-1 || len(chunks)-i <= HashSize {
err = t.UpdateStem(key[:StemSize], values)
if err != nil {
if err = t.updateStem(key[:StemSize], values); err != nil {
return fmt.Errorf("UpdateContractCode (addr=%x) error: %w", addr[:], err)
}
}
@ -419,8 +487,10 @@ func (t *BinaryTrie) UpdateContractCode(addr common.Address, codeHash common.Has
// PrefetchAccount attempts to resolve specific accounts from the database
// to accelerate subsequent trie operations.
func (t *BinaryTrie) PrefetchAccount(addresses []common.Address) error {
t.mu.Lock()
defer t.mu.Unlock()
for _, addr := range addresses {
if _, err := t.GetAccount(addr); err != nil {
if _, err := t.getAccount(addr); err != nil {
return err
}
}
@ -430,8 +500,10 @@ func (t *BinaryTrie) PrefetchAccount(addresses []common.Address) error {
// PrefetchStorage attempts to resolve specific storage slots from the database
// to accelerate subsequent trie operations.
func (t *BinaryTrie) PrefetchStorage(addr common.Address, keys [][]byte) error {
t.mu.Lock()
defer t.mu.Unlock()
for _, key := range keys {
if _, err := t.GetStorage(addr, key); err != nil {
if _, err := t.getStorage(addr, key); err != nil {
return err
}
}
@ -440,5 +512,7 @@ func (t *BinaryTrie) PrefetchStorage(addr common.Address, keys [][]byte) error {
// Witness returns a set containing all trie nodes that have been accessed.
func (t *BinaryTrie) Witness() map[string][]byte {
t.mu.Lock()
defer t.mu.Unlock()
return t.tracer.Values()
}

View file

@ -19,6 +19,8 @@ package bintrie
import (
"bytes"
"encoding/binary"
"fmt"
"sync"
"testing"
"github.com/ethereum/go-ethereum/common"
@ -779,3 +781,130 @@ func TestGetStorageNonMembershipInternalRoot(t *testing.T) {
t.Fatalf("expected nil/zero for non-existent storage, got %x", got)
}
}
// TestConcurrentStorageUpdates exercises BinaryTrie under concurrent writes
// from multiple goroutines — the scenario that triggers a data race when
// IntermediateRoot parallelizes per-account updateTrie() calls on a single
// shared binary trie. This test must pass with -race enabled.
func TestConcurrentStorageUpdates(t *testing.T) {
tr := newEmptyTestTrie(t)
// Create multiple accounts so the root becomes an InternalNode with
// subtrees that different goroutines will traverse concurrently.
addrs := []common.Address{
common.HexToAddress("0x1111111111111111111111111111111111111111"),
common.HexToAddress("0x2222222222222222222222222222222222222222"),
common.HexToAddress("0x9999999999999999999999999999999999999999"),
common.HexToAddress("0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"),
}
for _, addr := range addrs {
acc := makeAccount(1, 100, common.HexToHash("c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470"))
if err := tr.UpdateAccount(addr, acc, 0); err != nil {
t.Fatalf("UpdateAccount: %v", err)
}
}
// Spawn goroutines that concurrently call UpdateStorage and DeleteStorage
// on different addresses but the same shared trie. This mirrors the
// errgroup pattern in IntermediateRoot.
const slotsPerAddr = 20
var wg sync.WaitGroup
errs := make(chan error, len(addrs)*slotsPerAddr)
for _, addr := range addrs {
wg.Add(1)
go func(addr common.Address) {
defer wg.Done()
for i := 0; i < slotsPerAddr; i++ {
slot := common.HexToHash(fmt.Sprintf("00000000000000000000000000000000000000000000000000000000000000%02x", 0x80+i))
val := common.TrimLeftZeroes(common.HexToHash(fmt.Sprintf("00000000000000000000000000000000000000000000000000000000%04x%04x", addr[0], i)).Bytes())
if err := tr.UpdateStorage(addr, slot[:], val); err != nil {
errs <- fmt.Errorf("UpdateStorage(%x, slot %d): %v", addr, i, err)
return
}
}
}(addr)
}
wg.Wait()
close(errs)
for err := range errs {
t.Fatal(err)
}
// Verify all values are readable and correct.
for _, addr := range addrs {
for i := 0; i < slotsPerAddr; i++ {
slot := common.HexToHash(fmt.Sprintf("00000000000000000000000000000000000000000000000000000000000000%02x", 0x80+i))
got, err := tr.GetStorage(addr, slot[:])
if err != nil {
t.Fatalf("GetStorage(%x, slot %d): %v", addr, i, err)
}
if len(got) == 0 {
t.Fatalf("GetStorage(%x, slot %d): empty, expected value", addr, i)
}
}
}
// Hash must not race with prior writes — call it to exercise the
// read path after concurrent mutations.
h := tr.Hash()
if h == (common.Hash{}) {
t.Fatal("trie hash is zero after concurrent updates")
}
}
// TestConcurrentReadWrite exercises concurrent reads and writes on the same
// BinaryTrie, verifying no data race under the race detector.
func TestConcurrentReadWrite(t *testing.T) {
tr := newEmptyTestTrie(t)
addr := common.HexToAddress("0x1111111111111111111111111111111111111111")
acc := makeAccount(1, 100, common.HexToHash("c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470"))
if err := tr.UpdateAccount(addr, acc, 0); err != nil {
t.Fatalf("UpdateAccount: %v", err)
}
// Pre-populate some storage so reads have data to find.
for i := 0; i < 10; i++ {
slot := common.HexToHash(fmt.Sprintf("00000000000000000000000000000000000000000000000000000000000000%02x", 0x80+i))
val := common.HexToHash(fmt.Sprintf("00000000000000000000000000000000000000000000000000000000dead%04x", i)).Bytes()
if err := tr.UpdateStorage(addr, slot[:], common.TrimLeftZeroes(val)); err != nil {
t.Fatalf("seed UpdateStorage: %v", err)
}
}
var wg sync.WaitGroup
// Writer goroutine.
wg.Add(1)
go func() {
defer wg.Done()
for i := 10; i < 30; i++ {
slot := common.HexToHash(fmt.Sprintf("00000000000000000000000000000000000000000000000000000000000000%02x", 0x80+i))
val := common.HexToHash(fmt.Sprintf("00000000000000000000000000000000000000000000000000000000beef%04x", i)).Bytes()
tr.UpdateStorage(addr, slot[:], common.TrimLeftZeroes(val))
}
}()
// Reader goroutine.
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 10; i++ {
slot := common.HexToHash(fmt.Sprintf("00000000000000000000000000000000000000000000000000000000000000%02x", 0x80+i))
tr.GetStorage(addr, slot[:])
}
}()
// Hash goroutine.
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 5; i++ {
tr.Hash()
}
}()
wg.Wait()
}