nomt/merkle: add Phase 7 parallel workers for trie updates

Parallelize the PageWalker trie update across multiple goroutines by
partitioning sorted operations by the root page's 64 child subtrees
(first 6 bits of each key path).

Each worker runs an independent PageWalker constrained to child pages
below the root (using parentPage mechanism), producing ChildPageRoots.
After all workers complete, a root walker places the child roots using
AdvanceAndPlaceNode and concludes with the final trie root.

Workers operate on disjoint page subtrees so no synchronization is
needed during computation — only sync.WaitGroup for goroutine join.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
weiihann 2026-02-12 18:37:42 +08:00
parent 53fd00926f
commit cb3e13d93d
3 changed files with 527 additions and 46 deletions

View file

@ -10,6 +10,7 @@ import (
"fmt"
"os"
"path/filepath"
"runtime"
"sort"
"sync"
@ -27,6 +28,10 @@ const (
type Config struct {
// HTCapacity is the number of hash table buckets. Must be a power of 2.
HTCapacity uint64
// NumWorkers is the number of parallel goroutines for trie updates.
// Defaults to runtime.NumCPU() if zero.
NumWorkers int
}
// DefaultConfig returns a default configuration.
@ -38,11 +43,12 @@ func DefaultConfig() Config {
// DB is the NOMT trie database.
type DB struct {
dataDir string
bb *bitbox.DB
root core.Node
syncSeqn uint32
mu sync.RWMutex
dataDir string
bb *bitbox.DB
root core.Node
syncSeqn uint32
numWorkers int
mu sync.RWMutex
}
// Open opens or creates a NOMT trie database at the given directory.
@ -75,10 +81,16 @@ func Open(dataDir string, config Config) (*DB, error) {
}
}
numWorkers := config.NumWorkers
if numWorkers <= 0 {
numWorkers = runtime.NumCPU()
}
db := &DB{
dataDir: dataDir,
bb: bb,
root: core.Terminator,
dataDir: dataDir,
bb: bb,
root: core.Terminator,
numWorkers: numWorkers,
}
// Run WAL recovery.
@ -136,52 +148,21 @@ func (db *DB) Update(ops []core.LeafOp) (core.Node, error) {
return ops[i].Key != ops[j].Key && keyLess(&ops[i].Key, &ops[j].Key)
})
// Build a BitboxPageSet that loads pages from disk.
pageSet := newBitboxPageSet(db.bb)
// For a simple implementation, treat the entire trie as a single
// terminal at the root and replace it.
// Convert to KeyValue (filter out deletes).
kvs := make([]core.KeyValue, 0, len(ops))
for _, op := range ops {
if op.Value != nil {
kvs = append(kvs, core.KeyValue{Key: op.Key, Value: *op.Value})
}
}
walker := merkle.NewPageWalker(db.root, nil)
if len(kvs) > 0 || len(ops) > 0 {
// Simple approach: single advance at root with all operations.
// For an incremental update on a non-empty trie, we'd need to
// seek to terminals first. This simplified version rebuilds.
pos := core.NewTriePosition()
pos.Down(false) // advance to position [0] for left subtree
// Split ops into left (0-prefix) and right (1-prefix).
leftKVs := make([]core.KeyValue, 0, len(kvs))
rightKVs := make([]core.KeyValue, 0, len(kvs))
for _, kv := range kvs {
if kv.Key[0]&0x80 == 0 {
leftKVs = append(leftKVs, kv)
} else {
rightKVs = append(rightKVs, kv)
}
}
leftPos := core.NewTriePosition()
leftPos.Down(false)
if len(leftKVs) > 0 {
walker.AdvanceAndReplace(pageSet, leftPos, leftKVs)
}
rightPos := core.NewTriePosition()
rightPos.Down(true)
if len(rightKVs) > 0 {
walker.AdvanceAndReplace(pageSet, rightPos, rightKVs)
}
if len(kvs) == 0 {
return db.root, nil
}
out := walker.Conclude()
pageSetFactory := func() merkle.PageSet {
return newBitboxPageSet(db.bb)
}
out := merkle.ParallelUpdate(db.root, kvs, db.numWorkers, pageSetFactory)
// Persist updated pages.
walPath := filepath.Join(db.dataDir, walFileName)

242
nomt/merkle/worker.go Normal file
View file

@ -0,0 +1,242 @@
package merkle
import (
"fmt"
"runtime"
"sync"
"github.com/ethereum/go-ethereum/nomt/core"
)
// childBucket groups the key-value operations that fall under a single child
// of the root page.
type childBucket struct {
// childIndex identifies the root page child this bucket belongs to. It is
// filled from an index into the 64-entry bucket array in assignToWorkers.
childIndex uint8
// kvs holds the operations routed to this child.
kvs []core.KeyValue
}
// workerTask describes the work assigned to a single worker goroutine.
type workerTask struct {
// children lists the child buckets handed to the worker; runWorker visits
// them in slice order.
children []childBucket
}
// workerResult holds the output produced by a single worker.
type workerResult struct {
// childPageRoots is copied from the ChildPageRoots field of the worker
// walker's concluded output.
childPageRoots []childPageRoot
// pages is copied from the Pages field of the worker walker's concluded
// output.
pages []UpdatedPage
// err is populated only when the worker goroutine panicked and the panic
// was recovered in ParallelUpdate; runWorker itself never sets it.
err error
}
// ParallelUpdate applies sorted key-value operations to the trie using
// multiple worker goroutines. Each worker processes a disjoint set of root
// page child subtrees (partitioned by the first 6 bits of each key path).
// Once every worker has finished, a root walker places the child page roots
// they reported and concludes to produce the final output.
//
// numWorkers <= 0 selects runtime.NumCPU(). The update runs on a single
// walker instead of goroutines when numWorkers is 1, when the batch holds
// fewer than 64 operations, or when all operations fall under one root page
// child. An empty batch returns Output{Root: root} without calling
// pageSetFactory.
//
// pageSetFactory is called once per worker, from inside that worker's
// goroutine, and once more for the root walker; it must therefore be safe to
// call concurrently.
//
// A panic raised inside a worker is recovered in that goroutine and re-raised
// on the calling goroutine after all workers have been joined.
func ParallelUpdate(
	root core.Node,
	kvs []core.KeyValue,
	numWorkers int,
	pageSetFactory func() PageSet,
) Output {
	// minParallelBatch is the batch size below which goroutines are not
	// spawned and a single walker handles the whole update.
	const minParallelBatch = 64

	if numWorkers <= 0 {
		numWorkers = runtime.NumCPU()
	}
	if len(kvs) == 0 {
		return Output{Root: root}
	}
	if numWorkers == 1 || len(kvs) < minParallelBatch {
		return singleThreadedUpdate(root, kvs, pageSetFactory())
	}

	// Steps 1-2: partition by child index (first 6 bits) and hand the
	// non-empty buckets out to workers.
	tasks := assignToWorkers(partitionByChildIndex(kvs), numWorkers)
	switch len(tasks) {
	case 0:
		return Output{Root: root}
	case 1:
		// Everything sits under one child; parallelism would gain nothing.
		return singleThreadedUpdate(root, kvs, pageSetFactory())
	}

	// Step 3: launch workers. Each goroutine writes only to its own slot of
	// results, so the slice needs no locking.
	results := make([]workerResult, len(tasks))
	var wg sync.WaitGroup
	wg.Add(len(tasks))
	for i, task := range tasks {
		go func(idx int, t workerTask) {
			// Registered first so it runs last, after the recover handler
			// below has had a chance to record the failure.
			defer wg.Done()
			defer func() {
				if r := recover(); r != nil {
					results[idx] = workerResult{
						err: fmt.Errorf("worker %d panicked: %v", idx, r),
					}
				}
			}()
			results[idx] = runWorker(root, t, pageSetFactory())
		}(i, task)
	}
	wg.Wait()

	// Step 4: surface the first worker failure on the caller's goroutine.
	for i := range results {
		if err := results[i].err; err != nil {
			panic(err)
		}
	}

	// Step 5: the root walker places every child page root. Walking results
	// in task order keeps the roots left-to-right, because tasks are assigned
	// in ascending child index order and each worker processes its children
	// left-to-right.
	rootPS := pageSetFactory()
	rootWalker := NewPageWalker(root, nil)
	workerPages := 0
	for i := range results {
		workerPages += len(results[i].pages)
		for _, cpr := range results[i].childPageRoots {
			rootWalker.AdvanceAndPlaceNode(rootPS, cpr.Position, cpr.Node)
		}
	}
	rootOut := rootWalker.Conclude()

	// Step 6: merge all pages, worker pages first and root walker pages last,
	// into a single buffer sized up front.
	merged := make([]UpdatedPage, 0, workerPages+len(rootOut.Pages))
	for i := range results {
		merged = append(merged, results[i].pages...)
	}
	merged = append(merged, rootOut.Pages...)
	rootOut.Pages = merged

	return rootOut
}
// singleThreadedUpdate performs the whole trie update with one PageWalker.
// ParallelUpdate uses it for small batches and single-worker configurations.
//
// Operations are split by the most significant bit of the first key byte and
// each non-empty half is replaced at the matching depth-1 position (left for
// a clear bit, right for a set bit), left half first.
func singleThreadedUpdate(
	root core.Node,
	kvs []core.KeyValue,
	pageSet PageSet,
) Output {
	walker := NewPageWalker(root, nil)

	var zeroSide, oneSide []core.KeyValue
	for _, kv := range kvs {
		// A first byte of 0x80 or above has its top bit set.
		if kv.Key[0] >= 0x80 {
			oneSide = append(oneSide, kv)
		} else {
			zeroSide = append(zeroSide, kv)
		}
	}

	// replaceAt advances the walker to the depth-1 position on the requested
	// side, skipping sides that received no operations.
	replaceAt := func(right bool, batch []core.KeyValue) {
		if len(batch) == 0 {
			return
		}
		pos := core.NewTriePosition()
		pos.Down(right)
		walker.AdvanceAndReplace(pageSet, pos, batch)
	}
	replaceAt(false, zeroSide)
	replaceAt(true, oneSide)

	return walker.Conclude()
}
// partitionByChildIndex distributes kvs over 64 buckets keyed by the first
// 6 bits of each key path (the root page's child index: 0-63). Within a
// bucket the operations keep the relative order they had in kvs.
func partitionByChildIndex(kvs []core.KeyValue) [64][]core.KeyValue {
	var out [64][]core.KeyValue
	for _, kv := range kvs {
		// Dropping the two low bits of the first byte leaves its top six
		// bits, a value in 0..63.
		slot := kv.Key[0] >> 2
		out[slot] = append(out[slot], kv)
	}
	return out
}
// assignToWorkers distributes the non-empty child buckets across at most
// numWorkers tasks as contiguous ranges in ascending child index order.
//
// The worker count is capped at the number of non-empty buckets and raised to
// at least 1, so a zero or negative numWorkers yields a single task instead
// of a divide-by-zero or negative-length panic. Buckets are split as evenly
// as possible by bucket count: the first (count % workers) tasks receive one
// extra bucket. It returns nil when every bucket is empty.
func assignToWorkers(
	buckets [64][]core.KeyValue,
	numWorkers int,
) []workerTask {
	nonEmpty := make([]childBucket, 0, len(buckets))
	for i, kvs := range buckets {
		if len(kvs) > 0 {
			nonEmpty = append(nonEmpty, childBucket{
				childIndex: uint8(i),
				kvs:        kvs,
			})
		}
	}
	if len(nonEmpty) == 0 {
		return nil
	}

	numWorkers = max(1, min(numWorkers, len(nonEmpty)))

	tasks := make([]workerTask, numWorkers)
	perWorker := len(nonEmpty) / numWorkers
	remainder := len(nonEmpty) % numWorkers

	start := 0
	for w := range tasks {
		count := perWorker
		if w < remainder {
			count++
		}
		end := start + count
		// Cap the capacity so an append on one task's slice can never
		// overwrite the buckets of the next task in the shared array.
		tasks[w] = workerTask{children: nonEmpty[start:end:end]}
		start = end
	}
	return tasks
}
// runWorker handles one worker's share of child subtrees with a PageWalker
// whose parent page is the root page.
//
// For every assigned child, in order, the operations are split on bit 1 of
// the first key byte (the bit right after the 6-bit child index) and each
// non-empty half is replaced at the corresponding depth-7 position, left
// before right. The result carries the walker's child page roots and updated
// pages; its err field is left unset.
func runWorker(
	root core.Node,
	task workerTask,
	pageSet PageSet,
) workerResult {
	parent := core.RootPageID()
	w := NewPageWalker(root, &parent)

	for _, bucket := range task.children {
		// halves[0] collects the left-going operations, halves[1] the right.
		var halves [2][]core.KeyValue
		for _, kv := range bucket.kvs {
			side := (kv.Key[0] >> 1) & 1
			halves[side] = append(halves[side], kv)
		}

		for side, batch := range halves {
			if len(batch) == 0 {
				continue
			}
			w.AdvanceAndReplace(pageSet, childPosition(bucket.childIndex, side == 1), batch)
		}
	}

	concluded := w.Conclude()
	return workerResult{
		childPageRoots: concluded.ChildPageRoots,
		pages:          concluded.Pages,
	}
}
// childPosition builds a TriePosition of depth 7. The first six steps spell
// out the low six bits of childIndex, most significant first, and the seventh
// step goes right when rightBit is true and left otherwise.
func childPosition(childIndex uint8, rightBit bool) core.TriePosition {
	pos := core.NewTriePosition()
	// Walk a single-bit mask from bit 5 down to bit 0; the mask shifts out to
	// zero after the sixth step.
	for mask := uint8(1 << 5); mask != 0; mask >>= 1 {
		pos.Down(childIndex&mask != 0)
	}
	pos.Down(rightBit)
	return pos
}

258
nomt/merkle/worker_test.go Normal file
View file

@ -0,0 +1,258 @@
package merkle
import (
"math/rand"
"sort"
"testing"
"github.com/ethereum/go-ethereum/nomt/core"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// --- Unit tests for helpers ---
// TestPartitionByChildIndex checks that keys are routed to the bucket given
// by the top six bits of their first byte and that no other bucket is filled.
func TestPartitionByChildIndex(t *testing.T) {
	// First byte → bucket: 0x00 → 0, 0x04 → 1, 0xFC → 63.
	input := []core.KeyValue{
		{Key: makeKVKey(0x00), Value: makeKVVal(1)},
		{Key: makeKVKey(0x04), Value: makeKVVal(2)},
		{Key: makeKVKey(0xFC), Value: makeKVVal(3)},
	}

	got := partitionByChildIndex(input)

	for _, idx := range []int{0, 1, 63} {
		assert.Len(t, got[idx], 1)
	}

	// Exactly the three buckets above may hold anything.
	populated := 0
	for i := range got {
		if len(got[i]) != 0 {
			populated++
		}
	}
	assert.Equal(t, 3, populated)
}
func TestChildPosition(t *testing.T) {
// Child 0, left: 7 bits all false → depth 7.
pos := childPosition(0, false)
assert.Equal(t, uint16(7), pos.Depth())
for i := range 7 {
assert.False(t, pos.Bit(i), "bit %d should be 0", i)
}
// Child 0, right: 6 false + 1 true → depth 7.
pos = childPosition(0, true)
assert.Equal(t, uint16(7), pos.Depth())
for i := range 6 {
assert.False(t, pos.Bit(i), "bit %d should be 0", i)
}
assert.True(t, pos.Bit(6))
// Child 63 (0b111111), left: 6 true + 1 false → depth 7.
pos = childPosition(63, false)
assert.Equal(t, uint16(7), pos.Depth())
for i := range 6 {
assert.True(t, pos.Bit(i), "bit %d should be 1", i)
}
assert.False(t, pos.Bit(6))
// Child 63 (0b111111), right: 7 true → depth 7.
pos = childPosition(63, true)
assert.Equal(t, uint16(7), pos.Depth())
for i := range 7 {
assert.True(t, pos.Bit(i), "bit %d should be 1", i)
}
}
// TestAssignToWorkers checks that three populated buckets shared between two
// workers are split 2/1 and stay in ascending child index order.
func TestAssignToWorkers(t *testing.T) {
	var buckets [64][]core.KeyValue
	buckets[0] = []core.KeyValue{{Key: makeKVKey(0x00), Value: makeKVVal(1)}}
	buckets[10] = []core.KeyValue{{Key: makeKVKey(0x28), Value: makeKVVal(2)}} // 0x28>>2=10
	buckets[63] = []core.KeyValue{{Key: makeKVKey(0xFC), Value: makeKVVal(3)}}

	tasks := assignToWorkers(buckets, 2)
	require.Len(t, tasks, 2)

	// Expected child indices per task: the odd bucket goes to the first task.
	want := [][]uint8{{0, 10}, {63}}
	for w := range want {
		assert.Len(t, tasks[w].children, len(want[w]))
	}
	for w, indices := range want {
		for c, idx := range indices {
			assert.Equal(t, idx, tasks[w].children[c].childIndex)
		}
	}
}
// TestAssignToWorkersMoreWorkersThanChildren checks that the task count is
// capped at the number of populated buckets when more workers are requested.
func TestAssignToWorkersMoreWorkersThanChildren(t *testing.T) {
	var buckets [64][]core.KeyValue
	buckets[5] = []core.KeyValue{{Key: makeKVKey(0x14), Value: makeKVVal(1)}} // 0x14>>2=5
	buckets[6] = []core.KeyValue{{Key: makeKVKey(0x18), Value: makeKVVal(2)}} // 0x18>>2=6

	// Eight workers requested, but only two buckets carry work.
	tasks := assignToWorkers(buckets, 8)
	require.Len(t, tasks, 2)
	for i := range tasks {
		assert.Len(t, tasks[i].children, 1)
	}
}
// --- Integration tests ---
// permissivePageSet wraps MemoryPageSet to return fresh pages for missing
// entries (matching bitboxPageSet behavior). This is needed because the
// parallel workers descend into child pages that may not exist yet.
//
// Only Get is overridden; every other method is promoted unchanged from the
// embedded MemoryPageSet.
type permissivePageSet struct {
*MemoryPageSet
}
// Get returns the page held by the embedded MemoryPageSet when it has one.
// Otherwise it hands back a newly allocated zero RawPage whose origin kind is
// PageOriginFresh. The boolean result is always true.
func (ps *permissivePageSet) Get(pageID core.PageID) (*core.RawPage, PageOrigin, bool) {
	if page, origin, found := ps.MemoryPageSet.Get(pageID); found {
		return page, origin, true
	}
	return new(core.RawPage), PageOrigin{Kind: PageOriginFresh}, true
}
// memoryPageSetFactory is the PageSet factory used by the tests: each call
// wraps a new MemoryPageSet in a permissivePageSet.
func memoryPageSetFactory() PageSet {
	inner := NewMemoryPageSet(true)
	return &permissivePageSet{MemoryPageSet: inner}
}
func TestParallelUpdateEmpty(t *testing.T) {
out := ParallelUpdate(core.Terminator, nil, 4, memoryPageSetFactory)
assert.Equal(t, core.Terminator, out.Root)
}
// TestParallelUpdateSingleKey checks the root produced for a one-element
// batch against expectedRoot.
//
// NOTE: with fewer than 64 operations ParallelUpdate takes its
// single-threaded fallback, so this does not reach the worker goroutines.
func TestParallelUpdateSingleKey(t *testing.T) {
	kvs := []core.KeyValue{
		{Key: makeKVKey(0x50), Value: makeKVVal(1)},
	}

	got := ParallelUpdate(core.Terminator, kvs, 4, memoryPageSetFactory)

	assert.Equal(t, expectedRoot(kvs), got.Root)
}
// TestParallelUpdateTwoKeysDifferentChildren checks the root for two keys
// that belong to different root page children.
//
// NOTE: with fewer than 64 operations ParallelUpdate takes its
// single-threaded fallback, so this does not reach the worker goroutines.
func TestParallelUpdateTwoKeysDifferentChildren(t *testing.T) {
	// First byte 0x00 → child 0, first byte 0x80 → child 32.
	low := core.KeyValue{Key: makeKVKey(0x00), Value: makeKVVal(1)}
	high := core.KeyValue{Key: makeKVKey(0x80), Value: makeKVVal(2)}
	kvs := []core.KeyValue{low, high}

	got := ParallelUpdate(core.Terminator, kvs, 4, memoryPageSetFactory)

	assert.Equal(t, expectedRoot(kvs), got.Root)
}
// TestParallelUpdateSparseChildren checks the root when only the first and
// the last root page child (0 and 63) receive operations.
//
// NOTE: with fewer than 64 operations ParallelUpdate takes its
// single-threaded fallback, so this does not reach the worker goroutines.
func TestParallelUpdateSparseChildren(t *testing.T) {
	first := core.KeyValue{Key: makeKVKey(0x00), Value: makeKVVal(1)}
	last := core.KeyValue{Key: makeKVKey(0xFC), Value: makeKVVal(2)}
	kvs := []core.KeyValue{first, last}

	got := ParallelUpdate(core.Terminator, kvs, 4, memoryPageSetFactory)

	assert.Equal(t, expectedRoot(kvs), got.Root)
}
// TestParallelUpdateSingleChild checks the root when every key shares the
// same first 6 bits (000000) and therefore lands in child 0.
//
// NOTE: with fewer than 64 operations ParallelUpdate takes its
// single-threaded fallback, so this does not reach the worker goroutines.
func TestParallelUpdateSingleChild(t *testing.T) {
	kvs := make([]core.KeyValue, 0, 4)
	kvs = append(kvs,
		core.KeyValue{Key: makeKVKey(0x00), Value: makeKVVal(1)},
		core.KeyValue{Key: makeKVKey(0x01), Value: makeKVVal(2)},
		core.KeyValue{Key: makeKVKey(0x02), Value: makeKVVal(3)},
		core.KeyValue{Key: makeKVKey(0x03), Value: makeKVVal(4)},
	)
	sort.Slice(kvs, func(a, b int) bool { return kvLess(&kvs[a], &kvs[b]) })

	got := ParallelUpdate(core.Terminator, kvs, 4, memoryPageSetFactory)

	assert.Equal(t, expectedRoot(kvs), got.Root)
}
// TestParallelUpdateFallbackSmallBatch checks the root for a batch below the
// 64-operation threshold, which ParallelUpdate handles single-threaded even
// though eight workers are requested.
func TestParallelUpdateFallbackSmallBatch(t *testing.T) {
	const (
		batchSize = 10
		seed      = 42
	)
	kvs := randomKVs(batchSize, seed)

	got := ParallelUpdate(core.Terminator, kvs, 8, memoryPageSetFactory)

	assert.Equal(t, expectedRoot(kvs), got.Root)
}
// TestParallelUpdateDeterministic runs the same 200-operation batch twice
// with four workers and checks that both runs yield the same root.
func TestParallelUpdateDeterministic(t *testing.T) {
	kvs := randomKVs(200, 99)

	var roots [2]core.Node
	for run := range roots {
		roots[run] = ParallelUpdate(core.Terminator, kvs, 4, memoryPageSetFactory).Root
	}

	assert.Equal(t, roots[0], roots[1], "same inputs should produce same root")
}
// TestParallelUpdateMatchesSingleThreaded compares, for several batch sizes
// and worker counts, the root from ParallelUpdate with the root from
// singleThreadedUpdate on the same randomly generated batch.
func TestParallelUpdateMatchesSingleThreaded(t *testing.T) {
	type scenario struct {
		name    string
		numKVs  int
		workers int
	}
	scenarios := []scenario{
		{name: "1kv_2w", numKVs: 1, workers: 2},
		{name: "10kv_2w", numKVs: 10, workers: 2},
		{name: "100kv_2w", numKVs: 100, workers: 2},
		{name: "100kv_4w", numKVs: 100, workers: 4},
		{name: "100kv_8w", numKVs: 100, workers: 8},
		{name: "500kv_4w", numKVs: 500, workers: 4},
		{name: "1000kv_8w", numKVs: 1000, workers: 8},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			// The seed is fixed so every scenario is reproducible.
			kvs := randomKVs(sc.numKVs, 12345)

			want := singleThreadedUpdate(core.Terminator, kvs, NewMemoryPageSet(true)).Root
			got := ParallelUpdate(core.Terminator, kvs, sc.workers, memoryPageSetFactory).Root

			assert.Equal(t, want, got,
				"parallel root should match single-threaded root")
		})
	}
}
// --- helpers ---
// randomKVs returns n key-value pairs with pairwise distinct keys, generated
// from a math/rand source seeded with seed and sorted by key with kvLess.
// The same (n, seed) always yields the same slice.
func randomKVs(n int, seed int64) []core.KeyValue {
	rng := rand.New(rand.NewSource(seed))
	taken := make(map[core.KeyPath]bool, n)
	kvs := make([]core.KeyValue, 0, n)

	for len(kvs) < n {
		var key core.KeyPath
		rng.Read(key[:])
		if taken[key] {
			// Duplicate key: draw again without consuming bytes for a value.
			continue
		}
		taken[key] = true

		var value core.ValueHash
		rng.Read(value[:])
		kvs = append(kvs, core.KeyValue{Key: key, Value: value})
	}

	sort.Slice(kvs, func(a, b int) bool { return kvLess(&kvs[a], &kvs[b]) })
	return kvs
}
// kvLess reports whether a's key sorts strictly before b's key when the two
// keys are compared byte by byte from the first byte. Equal keys yield false.
func kvLess(a, b *core.KeyValue) bool {
	for i, x := range a.Key {
		// The first differing byte decides the order.
		if y := b.Key[i]; x != y {
			return x < y
		}
	}
	return false
}