From cb3e13d93d2e7d3374075131c68307510e5ff550 Mon Sep 17 00:00:00 2001 From: weiihann Date: Thu, 12 Feb 2026 18:37:42 +0800 Subject: [PATCH] nomt/merkle: add Phase 7 parallel workers for trie updates MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Parallelize the PageWalker trie update across multiple goroutines by partitioning sorted operations by the root page's 64 child subtrees (first 6 bits of each key path). Each worker runs an independent PageWalker constrained to child pages below the root (using parentPage mechanism), producing ChildPageRoots. After all workers complete, a root walker places the child roots using AdvanceAndPlaceNode and concludes with the final trie root. Workers operate on disjoint page subtrees so no synchronization is needed during computation — only sync.WaitGroup for goroutine join. Co-Authored-By: Claude Opus 4.6 --- nomt/db/db.go | 73 ++++------- nomt/merkle/worker.go | 242 ++++++++++++++++++++++++++++++++++ nomt/merkle/worker_test.go | 258 +++++++++++++++++++++++++++++++++++++ 3 files changed, 527 insertions(+), 46 deletions(-) create mode 100644 nomt/merkle/worker.go create mode 100644 nomt/merkle/worker_test.go diff --git a/nomt/db/db.go b/nomt/db/db.go index 2259b8efdf..4ead07b20f 100644 --- a/nomt/db/db.go +++ b/nomt/db/db.go @@ -10,6 +10,7 @@ import ( "fmt" "os" "path/filepath" + "runtime" "sort" "sync" @@ -27,6 +28,10 @@ const ( type Config struct { // HTCapacity is the number of hash table buckets. Must be a power of 2. HTCapacity uint64 + + // NumWorkers is the number of parallel goroutines for trie updates. + // Defaults to runtime.NumCPU() if zero. + NumWorkers int } // DefaultConfig returns a default configuration. @@ -38,11 +43,12 @@ func DefaultConfig() Config { // DB is the NOMT trie database. type DB struct { - dataDir string - bb *bitbox.DB - root core.Node - syncSeqn uint32 - mu sync.RWMutex + dataDir string + bb *bitbox.DB + root core.Node + syncSeqn uint32 + numWorkers int + mu sync.RWMutex } // Open opens or creates a NOMT trie database at the given directory. @@ -75,10 +81,16 @@ func Open(dataDir string, config Config) (*DB, error) { } } + numWorkers := config.NumWorkers + if numWorkers <= 0 { + numWorkers = runtime.NumCPU() + } + db := &DB{ - dataDir: dataDir, - bb: bb, - root: core.Terminator, + dataDir: dataDir, + bb: bb, + root: core.Terminator, + numWorkers: numWorkers, } // Run WAL recovery. @@ -136,52 +148,21 @@ func (db *DB) Update(ops []core.LeafOp) (core.Node, error) { return ops[i].Key != ops[j].Key && keyLess(&ops[i].Key, &ops[j].Key) }) - // Build a BitboxPageSet that loads pages from disk. - pageSet := newBitboxPageSet(db.bb) - - // For a simple implementation, treat the entire trie as a single - // terminal at the root and replace it. + // Convert to KeyValue (filter out deletes). kvs := make([]core.KeyValue, 0, len(ops)) for _, op := range ops { if op.Value != nil { kvs = append(kvs, core.KeyValue{Key: op.Key, Value: *op.Value}) } } - - walker := merkle.NewPageWalker(db.root, nil) - - if len(kvs) > 0 || len(ops) > 0 { - // Simple approach: single advance at root with all operations. - // For an incremental update on a non-empty trie, we'd need to - // seek to terminals first. This simplified version rebuilds. - pos := core.NewTriePosition() - pos.Down(false) // advance to position [0] for left subtree - - // Split ops into left (0-prefix) and right (1-prefix). - leftKVs := make([]core.KeyValue, 0, len(kvs)) - rightKVs := make([]core.KeyValue, 0, len(kvs)) - for _, kv := range kvs { - if kv.Key[0]&0x80 == 0 { - leftKVs = append(leftKVs, kv) - } else { - rightKVs = append(rightKVs, kv) - } - } - - leftPos := core.NewTriePosition() - leftPos.Down(false) - if len(leftKVs) > 0 { - walker.AdvanceAndReplace(pageSet, leftPos, leftKVs) - } - - rightPos := core.NewTriePosition() - rightPos.Down(true) - if len(rightKVs) > 0 { - walker.AdvanceAndReplace(pageSet, rightPos, rightKVs) - } + if len(kvs) == 0 { + return db.root, nil } - out := walker.Conclude() + pageSetFactory := func() merkle.PageSet { + return newBitboxPageSet(db.bb) + } + out := merkle.ParallelUpdate(db.root, kvs, db.numWorkers, pageSetFactory) // Persist updated pages. walPath := filepath.Join(db.dataDir, walFileName) diff --git a/nomt/merkle/worker.go b/nomt/merkle/worker.go new file mode 100644 index 0000000000..33a30841d0 --- /dev/null +++ b/nomt/merkle/worker.go @@ -0,0 +1,242 @@ +package merkle + +import ( + "fmt" + "runtime" + "sync" + + "github.com/ethereum/go-ethereum/nomt/core" +) + +// childBucket groups key-value operations for a single root page child index. +type childBucket struct { + childIndex uint8 + kvs []core.KeyValue +} + +// workerTask describes the work assigned to a single worker goroutine. +type workerTask struct { + children []childBucket +} + +// workerResult holds the output produced by a single worker. +type workerResult struct { + childPageRoots []childPageRoot + pages []UpdatedPage + err error +} + +// ParallelUpdate applies sorted key-value operations to the trie using +// multiple worker goroutines. Each worker processes a disjoint set of root +// page child subtrees (partitioned by the first 6 bits of each key path). +// +// If numWorkers <= 1 or the batch is small, falls back to single-threaded. +// The pageSetFactory is called once per worker to create independent PageSets. +func ParallelUpdate( + root core.Node, + kvs []core.KeyValue, + numWorkers int, + pageSetFactory func() PageSet, +) Output { + if numWorkers <= 0 { + numWorkers = runtime.NumCPU() + } + if len(kvs) == 0 { + return Output{Root: root} + } + if numWorkers <= 1 || len(kvs) < 64 { + return singleThreadedUpdate(root, kvs, pageSetFactory()) + } + + // Step 1: Partition by child index (first 6 bits). + buckets := partitionByChildIndex(kvs) + + // Step 2: Assign to workers. + tasks := assignToWorkers(buckets, numWorkers) + if len(tasks) == 0 { + return Output{Root: root} + } + if len(tasks) == 1 { + return singleThreadedUpdate(root, kvs, pageSetFactory()) + } + + // Step 3: Launch workers. + results := make([]workerResult, len(tasks)) + var wg sync.WaitGroup + wg.Add(len(tasks)) + + for i, task := range tasks { + go func(idx int, t workerTask) { + defer wg.Done() + defer func() { + if r := recover(); r != nil { + results[idx] = workerResult{ + err: fmt.Errorf("worker %d panicked: %v", idx, r), + } + } + }() + ps := pageSetFactory() + results[idx] = runWorker(root, t, ps) + }(i, task) + } + wg.Wait() + + // Step 4: Check for errors. + for _, r := range results { + if r.err != nil { + panic(r.err) + } + } + + // Step 5: Collect all child page roots and updated pages. + // ChildPageRoots are already left-to-right ordered because tasks are + // assigned in ascending child index order and each worker processes + // children left-to-right. + var allChildRoots []childPageRoot + var allPages []UpdatedPage + for _, r := range results { + allChildRoots = append(allChildRoots, r.childPageRoots...) + allPages = append(allPages, r.pages...) + } + + // Step 6: Root walker places all child roots. + rootPS := pageSetFactory() + rootWalker := NewPageWalker(root, nil) + for _, cpr := range allChildRoots { + rootWalker.AdvanceAndPlaceNode(rootPS, cpr.Position, cpr.Node) + } + rootOut := rootWalker.Conclude() + + // Step 7: Merge all pages. + allPages = append(allPages, rootOut.Pages...) + rootOut.Pages = allPages + return rootOut +} + +// singleThreadedUpdate runs the trie update with a single PageWalker. +// This is the fallback for small batches or single-worker configurations. +func singleThreadedUpdate( + root core.Node, + kvs []core.KeyValue, + pageSet PageSet, +) Output { + walker := NewPageWalker(root, nil) + + var leftKVs, rightKVs []core.KeyValue + for i := range kvs { + if kvs[i].Key[0]&0x80 == 0 { + leftKVs = append(leftKVs, kvs[i]) + } else { + rightKVs = append(rightKVs, kvs[i]) + } + } + + if len(leftKVs) > 0 { + leftPos := core.NewTriePosition() + leftPos.Down(false) + walker.AdvanceAndReplace(pageSet, leftPos, leftKVs) + } + if len(rightKVs) > 0 { + rightPos := core.NewTriePosition() + rightPos.Down(true) + walker.AdvanceAndReplace(pageSet, rightPos, rightKVs) + } + + return walker.Conclude() +} + +// partitionByChildIndex buckets sorted KVs by the first 6 bits of each key +// path (the root page's child index: 0-63). +func partitionByChildIndex(kvs []core.KeyValue) [64][]core.KeyValue { + var buckets [64][]core.KeyValue + for i := range kvs { + childIdx := kvs[i].Key[0] >> 2 + buckets[childIdx] = append(buckets[childIdx], kvs[i]) + } + return buckets +} + +// assignToWorkers distributes non-empty child buckets across numWorkers +// contiguous ranges. +func assignToWorkers( + buckets [64][]core.KeyValue, + numWorkers int, +) []workerTask { + var nonEmpty []childBucket + for i, kvs := range buckets { + if len(kvs) > 0 { + nonEmpty = append(nonEmpty, childBucket{ + childIndex: uint8(i), + kvs: kvs, + }) + } + } + if len(nonEmpty) == 0 { + return nil + } + if numWorkers > len(nonEmpty) { + numWorkers = len(nonEmpty) + } + + tasks := make([]workerTask, numWorkers) + perWorker := len(nonEmpty) / numWorkers + remainder := len(nonEmpty) % numWorkers + + idx := 0 + for w := range numWorkers { + count := perWorker + if w < remainder { + count++ + } + tasks[w] = workerTask{children: nonEmpty[idx : idx+count]} + idx += count + } + return tasks +} + +// runWorker processes a worker's assigned child subtrees using a PageWalker +// constrained to pages below the root page. +func runWorker( + root core.Node, + task workerTask, + pageSet PageSet, +) workerResult { + rootPageID := core.RootPageID() + walker := NewPageWalker(root, &rootPageID) + + for _, child := range task.children { + var leftKVs, rightKVs []core.KeyValue + for i := range child.kvs { + if (child.kvs[i].Key[0]>>1)&1 == 0 { + leftKVs = append(leftKVs, child.kvs[i]) + } else { + rightKVs = append(rightKVs, child.kvs[i]) + } + } + + if len(leftKVs) > 0 { + walker.AdvanceAndReplace(pageSet, childPosition(child.childIndex, false), leftKVs) + } + if len(rightKVs) > 0 { + walker.AdvanceAndReplace(pageSet, childPosition(child.childIndex, true), rightKVs) + } + } + + out := walker.Conclude() + return workerResult{ + childPageRoots: out.ChildPageRoots, + pages: out.Pages, + } +} + +// childPosition creates a TriePosition at depth 7: 6 bits encoding the root +// page's child index (MSB first) plus one additional bit for left/right within +// the child page. +func childPosition(childIndex uint8, rightBit bool) core.TriePosition { + pos := core.NewTriePosition() + for b := 5; b >= 0; b-- { + pos.Down((childIndex>>b)&1 == 1) + } + pos.Down(rightBit) + return pos +} diff --git a/nomt/merkle/worker_test.go b/nomt/merkle/worker_test.go new file mode 100644 index 0000000000..20e7adfb27 --- /dev/null +++ b/nomt/merkle/worker_test.go @@ -0,0 +1,258 @@ +package merkle + +import ( + "math/rand" + "sort" + "testing" + + "github.com/ethereum/go-ethereum/nomt/core" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --- Unit tests for helpers --- + +func TestPartitionByChildIndex(t *testing.T) { + // Key 0x00... → child 0, key 0xFC... → child 63. + kvs := []core.KeyValue{ + {Key: makeKVKey(0x00), Value: makeKVVal(1)}, + {Key: makeKVKey(0x04), Value: makeKVVal(2)}, // 0x04 >> 2 = 1 + {Key: makeKVKey(0xFC), Value: makeKVVal(3)}, // 0xFC >> 2 = 63 + } + buckets := partitionByChildIndex(kvs) + + assert.Len(t, buckets[0], 1) + assert.Len(t, buckets[1], 1) + assert.Len(t, buckets[63], 1) + + // All other buckets should be empty. + nonEmpty := 0 + for _, b := range buckets { + if len(b) > 0 { + nonEmpty++ + } + } + assert.Equal(t, 3, nonEmpty) +} + +func TestChildPosition(t *testing.T) { + // Child 0, left: 7 bits all false → depth 7. + pos := childPosition(0, false) + assert.Equal(t, uint16(7), pos.Depth()) + for i := range 7 { + assert.False(t, pos.Bit(i), "bit %d should be 0", i) + } + + // Child 0, right: 6 false + 1 true → depth 7. + pos = childPosition(0, true) + assert.Equal(t, uint16(7), pos.Depth()) + for i := range 6 { + assert.False(t, pos.Bit(i), "bit %d should be 0", i) + } + assert.True(t, pos.Bit(6)) + + // Child 63 (0b111111), left: 6 true + 1 false → depth 7. + pos = childPosition(63, false) + assert.Equal(t, uint16(7), pos.Depth()) + for i := range 6 { + assert.True(t, pos.Bit(i), "bit %d should be 1", i) + } + assert.False(t, pos.Bit(6)) + + // Child 63 (0b111111), right: 7 true → depth 7. + pos = childPosition(63, true) + assert.Equal(t, uint16(7), pos.Depth()) + for i := range 7 { + assert.True(t, pos.Bit(i), "bit %d should be 1", i) + } +} + +func TestAssignToWorkers(t *testing.T) { + // 3 non-empty buckets, 2 workers. + var buckets [64][]core.KeyValue + buckets[0] = []core.KeyValue{{Key: makeKVKey(0x00), Value: makeKVVal(1)}} + buckets[10] = []core.KeyValue{{Key: makeKVKey(0x28), Value: makeKVVal(2)}} // 0x28>>2=10 + buckets[63] = []core.KeyValue{{Key: makeKVKey(0xFC), Value: makeKVVal(3)}} + + tasks := assignToWorkers(buckets, 2) + require.Len(t, tasks, 2) + // 3 items / 2 workers: first gets 2, second gets 1. + assert.Len(t, tasks[0].children, 2) + assert.Len(t, tasks[1].children, 1) + assert.Equal(t, uint8(0), tasks[0].children[0].childIndex) + assert.Equal(t, uint8(10), tasks[0].children[1].childIndex) + assert.Equal(t, uint8(63), tasks[1].children[0].childIndex) +} + +func TestAssignToWorkersMoreWorkersThanChildren(t *testing.T) { + var buckets [64][]core.KeyValue + buckets[5] = []core.KeyValue{{Key: makeKVKey(0x14), Value: makeKVVal(1)}} // 0x14>>2=5 + buckets[6] = []core.KeyValue{{Key: makeKVKey(0x18), Value: makeKVVal(2)}} // 0x18>>2=6 + + tasks := assignToWorkers(buckets, 8) + // Only 2 non-empty, so cap to 2 workers. + require.Len(t, tasks, 2) + assert.Len(t, tasks[0].children, 1) + assert.Len(t, tasks[1].children, 1) +} + +// --- Integration tests --- + +// permissivePageSet wraps MemoryPageSet to return fresh pages for missing +// entries (matching bitboxPageSet behavior). This is needed because the +// parallel workers descend into child pages that may not exist yet. +type permissivePageSet struct { + *MemoryPageSet +} + +func (ps *permissivePageSet) Get(pageID core.PageID) (*core.RawPage, PageOrigin, bool) { + page, origin, ok := ps.MemoryPageSet.Get(pageID) + if !ok { + fresh := new(core.RawPage) + return fresh, PageOrigin{Kind: PageOriginFresh}, true + } + return page, origin, true +} + +func memoryPageSetFactory() PageSet { + return &permissivePageSet{NewMemoryPageSet(true)} +} + +func TestParallelUpdateEmpty(t *testing.T) { + out := ParallelUpdate(core.Terminator, nil, 4, memoryPageSetFactory) + assert.Equal(t, core.Terminator, out.Root) +} + +func TestParallelUpdateSingleKey(t *testing.T) { + kv := core.KeyValue{Key: makeKVKey(0x50), Value: makeKVVal(1)} + kvs := []core.KeyValue{kv} + + out := ParallelUpdate(core.Terminator, kvs, 4, memoryPageSetFactory) + expected := expectedRoot(kvs) + assert.Equal(t, expected, out.Root) +} + +func TestParallelUpdateTwoKeysDifferentChildren(t *testing.T) { + // 0x00 → child 0, 0x80 → child 32. + kvs := []core.KeyValue{ + {Key: makeKVKey(0x00), Value: makeKVVal(1)}, + {Key: makeKVKey(0x80), Value: makeKVVal(2)}, + } + + out := ParallelUpdate(core.Terminator, kvs, 4, memoryPageSetFactory) + expected := expectedRoot(kvs) + assert.Equal(t, expected, out.Root) +} + +func TestParallelUpdateSparseChildren(t *testing.T) { + // Only children 0 and 63 have ops. + kvs := []core.KeyValue{ + {Key: makeKVKey(0x00), Value: makeKVVal(1)}, + {Key: makeKVKey(0xFC), Value: makeKVVal(2)}, + } + + out := ParallelUpdate(core.Terminator, kvs, 4, memoryPageSetFactory) + expected := expectedRoot(kvs) + assert.Equal(t, expected, out.Root) +} + +func TestParallelUpdateSingleChild(t *testing.T) { + // All keys land in child 0 (first 6 bits = 000000). + kvs := []core.KeyValue{ + {Key: makeKVKey(0x00), Value: makeKVVal(1)}, + {Key: makeKVKey(0x01), Value: makeKVVal(2)}, + {Key: makeKVKey(0x02), Value: makeKVVal(3)}, + {Key: makeKVKey(0x03), Value: makeKVVal(4)}, + } + sort.Slice(kvs, func(i, j int) bool { return kvLess(&kvs[i], &kvs[j]) }) + + out := ParallelUpdate(core.Terminator, kvs, 4, memoryPageSetFactory) + expected := expectedRoot(kvs) + assert.Equal(t, expected, out.Root) +} + +func TestParallelUpdateFallbackSmallBatch(t *testing.T) { + // Less than 64 ops → single-threaded fallback. + kvs := randomKVs(10, 42) + out := ParallelUpdate(core.Terminator, kvs, 8, memoryPageSetFactory) + expected := expectedRoot(kvs) + assert.Equal(t, expected, out.Root) +} + +func TestParallelUpdateDeterministic(t *testing.T) { + kvs := randomKVs(200, 99) + + r1 := ParallelUpdate(core.Terminator, kvs, 4, memoryPageSetFactory).Root + r2 := ParallelUpdate(core.Terminator, kvs, 4, memoryPageSetFactory).Root + assert.Equal(t, r1, r2, "same inputs should produce same root") +} + +func TestParallelUpdateMatchesSingleThreaded(t *testing.T) { + tests := []struct { + name string + numKVs int + workers int + }{ + {"1kv_2w", 1, 2}, + {"10kv_2w", 10, 2}, + {"100kv_2w", 100, 2}, + {"100kv_4w", 100, 4}, + {"100kv_8w", 100, 8}, + {"500kv_4w", 500, 4}, + {"1000kv_8w", 1000, 8}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + kvs := randomKVs(tc.numKVs, 12345) + + single := singleThreadedUpdate( + core.Terminator, kvs, NewMemoryPageSet(true), + ) + parallel := ParallelUpdate( + core.Terminator, kvs, tc.workers, memoryPageSetFactory, + ) + + assert.Equal(t, single.Root, parallel.Root, + "parallel root should match single-threaded root") + }) + } +} + +// --- helpers --- + +func randomKVs(n int, seed int64) []core.KeyValue { + rng := rand.New(rand.NewSource(seed)) + kvs := make([]core.KeyValue, n) + seen := make(map[core.KeyPath]bool, n) + + for i := range n { + for { + var kp core.KeyPath + rng.Read(kp[:]) + if seen[kp] { + continue + } + seen[kp] = true + var vh core.ValueHash + rng.Read(vh[:]) + kvs[i] = core.KeyValue{Key: kp, Value: vh} + break + } + } + + sort.Slice(kvs, func(i, j int) bool { return kvLess(&kvs[i], &kvs[j]) }) + return kvs +} + +func kvLess(a, b *core.KeyValue) bool { + for i := range a.Key { + if a.Key[i] < b.Key[i] { + return true + } + if a.Key[i] > b.Key[i] { + return false + } + } + return false +}