mirror of
https://github.com/ethereum/go-ethereum.git
synced 2026-06-19 21:31:37 +00:00
nomt/merkle: add Phase 7 parallel workers for trie updates
Parallelize the PageWalker trie update across multiple goroutines by partitioning sorted operations by the root page's 64 child subtrees (first 6 bits of each key path). Each worker runs an independent PageWalker constrained to child pages below the root (using parentPage mechanism), producing ChildPageRoots. After all workers complete, a root walker places the child roots using AdvanceAndPlaceNode and concludes with the final trie root. Workers operate on disjoint page subtrees so no synchronization is needed during computation — only sync.WaitGroup for goroutine join. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
53fd00926f
commit
cb3e13d93d
3 changed files with 527 additions and 46 deletions
|
|
@ -10,6 +10,7 @@ import (
|
|||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
|
|
@ -27,6 +28,10 @@ const (
|
|||
type Config struct {
|
||||
// HTCapacity is the number of hash table buckets. Must be a power of 2.
|
||||
HTCapacity uint64
|
||||
|
||||
// NumWorkers is the number of parallel goroutines for trie updates.
|
||||
// Defaults to runtime.NumCPU() if zero.
|
||||
NumWorkers int
|
||||
}
|
||||
|
||||
// DefaultConfig returns a default configuration.
|
||||
|
|
@ -38,11 +43,12 @@ func DefaultConfig() Config {
|
|||
|
||||
// DB is the NOMT trie database.
|
||||
type DB struct {
|
||||
dataDir string
|
||||
bb *bitbox.DB
|
||||
root core.Node
|
||||
syncSeqn uint32
|
||||
mu sync.RWMutex
|
||||
dataDir string
|
||||
bb *bitbox.DB
|
||||
root core.Node
|
||||
syncSeqn uint32
|
||||
numWorkers int
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// Open opens or creates a NOMT trie database at the given directory.
|
||||
|
|
@ -75,10 +81,16 @@ func Open(dataDir string, config Config) (*DB, error) {
|
|||
}
|
||||
}
|
||||
|
||||
numWorkers := config.NumWorkers
|
||||
if numWorkers <= 0 {
|
||||
numWorkers = runtime.NumCPU()
|
||||
}
|
||||
|
||||
db := &DB{
|
||||
dataDir: dataDir,
|
||||
bb: bb,
|
||||
root: core.Terminator,
|
||||
dataDir: dataDir,
|
||||
bb: bb,
|
||||
root: core.Terminator,
|
||||
numWorkers: numWorkers,
|
||||
}
|
||||
|
||||
// Run WAL recovery.
|
||||
|
|
@ -136,52 +148,21 @@ func (db *DB) Update(ops []core.LeafOp) (core.Node, error) {
|
|||
return ops[i].Key != ops[j].Key && keyLess(&ops[i].Key, &ops[j].Key)
|
||||
})
|
||||
|
||||
// Build a BitboxPageSet that loads pages from disk.
|
||||
pageSet := newBitboxPageSet(db.bb)
|
||||
|
||||
// For a simple implementation, treat the entire trie as a single
|
||||
// terminal at the root and replace it.
|
||||
// Convert to KeyValue (filter out deletes).
|
||||
kvs := make([]core.KeyValue, 0, len(ops))
|
||||
for _, op := range ops {
|
||||
if op.Value != nil {
|
||||
kvs = append(kvs, core.KeyValue{Key: op.Key, Value: *op.Value})
|
||||
}
|
||||
}
|
||||
|
||||
walker := merkle.NewPageWalker(db.root, nil)
|
||||
|
||||
if len(kvs) > 0 || len(ops) > 0 {
|
||||
// Simple approach: single advance at root with all operations.
|
||||
// For an incremental update on a non-empty trie, we'd need to
|
||||
// seek to terminals first. This simplified version rebuilds.
|
||||
pos := core.NewTriePosition()
|
||||
pos.Down(false) // advance to position [0] for left subtree
|
||||
|
||||
// Split ops into left (0-prefix) and right (1-prefix).
|
||||
leftKVs := make([]core.KeyValue, 0, len(kvs))
|
||||
rightKVs := make([]core.KeyValue, 0, len(kvs))
|
||||
for _, kv := range kvs {
|
||||
if kv.Key[0]&0x80 == 0 {
|
||||
leftKVs = append(leftKVs, kv)
|
||||
} else {
|
||||
rightKVs = append(rightKVs, kv)
|
||||
}
|
||||
}
|
||||
|
||||
leftPos := core.NewTriePosition()
|
||||
leftPos.Down(false)
|
||||
if len(leftKVs) > 0 {
|
||||
walker.AdvanceAndReplace(pageSet, leftPos, leftKVs)
|
||||
}
|
||||
|
||||
rightPos := core.NewTriePosition()
|
||||
rightPos.Down(true)
|
||||
if len(rightKVs) > 0 {
|
||||
walker.AdvanceAndReplace(pageSet, rightPos, rightKVs)
|
||||
}
|
||||
if len(kvs) == 0 {
|
||||
return db.root, nil
|
||||
}
|
||||
|
||||
out := walker.Conclude()
|
||||
pageSetFactory := func() merkle.PageSet {
|
||||
return newBitboxPageSet(db.bb)
|
||||
}
|
||||
out := merkle.ParallelUpdate(db.root, kvs, db.numWorkers, pageSetFactory)
|
||||
|
||||
// Persist updated pages.
|
||||
walPath := filepath.Join(db.dataDir, walFileName)
|
||||
|
|
|
|||
242
nomt/merkle/worker.go
Normal file
242
nomt/merkle/worker.go
Normal file
|
|
@ -0,0 +1,242 @@
|
|||
package merkle
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
"github.com/ethereum/go-ethereum/nomt/core"
|
||||
)
|
||||
|
||||
// childBucket groups key-value operations for a single root page child index.
|
||||
type childBucket struct {
|
||||
childIndex uint8
|
||||
kvs []core.KeyValue
|
||||
}
|
||||
|
||||
// workerTask describes the work assigned to a single worker goroutine.
|
||||
type workerTask struct {
|
||||
children []childBucket
|
||||
}
|
||||
|
||||
// workerResult holds the output produced by a single worker.
|
||||
type workerResult struct {
|
||||
childPageRoots []childPageRoot
|
||||
pages []UpdatedPage
|
||||
err error
|
||||
}
|
||||
|
||||
// ParallelUpdate applies sorted key-value operations to the trie using
|
||||
// multiple worker goroutines. Each worker processes a disjoint set of root
|
||||
// page child subtrees (partitioned by the first 6 bits of each key path).
|
||||
//
|
||||
// If numWorkers <= 1 or the batch is small, falls back to single-threaded.
|
||||
// The pageSetFactory is called once per worker to create independent PageSets.
|
||||
func ParallelUpdate(
|
||||
root core.Node,
|
||||
kvs []core.KeyValue,
|
||||
numWorkers int,
|
||||
pageSetFactory func() PageSet,
|
||||
) Output {
|
||||
if numWorkers <= 0 {
|
||||
numWorkers = runtime.NumCPU()
|
||||
}
|
||||
if len(kvs) == 0 {
|
||||
return Output{Root: root}
|
||||
}
|
||||
if numWorkers <= 1 || len(kvs) < 64 {
|
||||
return singleThreadedUpdate(root, kvs, pageSetFactory())
|
||||
}
|
||||
|
||||
// Step 1: Partition by child index (first 6 bits).
|
||||
buckets := partitionByChildIndex(kvs)
|
||||
|
||||
// Step 2: Assign to workers.
|
||||
tasks := assignToWorkers(buckets, numWorkers)
|
||||
if len(tasks) == 0 {
|
||||
return Output{Root: root}
|
||||
}
|
||||
if len(tasks) == 1 {
|
||||
return singleThreadedUpdate(root, kvs, pageSetFactory())
|
||||
}
|
||||
|
||||
// Step 3: Launch workers.
|
||||
results := make([]workerResult, len(tasks))
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(tasks))
|
||||
|
||||
for i, task := range tasks {
|
||||
go func(idx int, t workerTask) {
|
||||
defer wg.Done()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
results[idx] = workerResult{
|
||||
err: fmt.Errorf("worker %d panicked: %v", idx, r),
|
||||
}
|
||||
}
|
||||
}()
|
||||
ps := pageSetFactory()
|
||||
results[idx] = runWorker(root, t, ps)
|
||||
}(i, task)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Step 4: Check for errors.
|
||||
for _, r := range results {
|
||||
if r.err != nil {
|
||||
panic(r.err)
|
||||
}
|
||||
}
|
||||
|
||||
// Step 5: Collect all child page roots and updated pages.
|
||||
// ChildPageRoots are already left-to-right ordered because tasks are
|
||||
// assigned in ascending child index order and each worker processes
|
||||
// children left-to-right.
|
||||
var allChildRoots []childPageRoot
|
||||
var allPages []UpdatedPage
|
||||
for _, r := range results {
|
||||
allChildRoots = append(allChildRoots, r.childPageRoots...)
|
||||
allPages = append(allPages, r.pages...)
|
||||
}
|
||||
|
||||
// Step 6: Root walker places all child roots.
|
||||
rootPS := pageSetFactory()
|
||||
rootWalker := NewPageWalker(root, nil)
|
||||
for _, cpr := range allChildRoots {
|
||||
rootWalker.AdvanceAndPlaceNode(rootPS, cpr.Position, cpr.Node)
|
||||
}
|
||||
rootOut := rootWalker.Conclude()
|
||||
|
||||
// Step 7: Merge all pages.
|
||||
allPages = append(allPages, rootOut.Pages...)
|
||||
rootOut.Pages = allPages
|
||||
return rootOut
|
||||
}
|
||||
|
||||
// singleThreadedUpdate runs the trie update with a single PageWalker.
|
||||
// This is the fallback for small batches or single-worker configurations.
|
||||
func singleThreadedUpdate(
|
||||
root core.Node,
|
||||
kvs []core.KeyValue,
|
||||
pageSet PageSet,
|
||||
) Output {
|
||||
walker := NewPageWalker(root, nil)
|
||||
|
||||
var leftKVs, rightKVs []core.KeyValue
|
||||
for i := range kvs {
|
||||
if kvs[i].Key[0]&0x80 == 0 {
|
||||
leftKVs = append(leftKVs, kvs[i])
|
||||
} else {
|
||||
rightKVs = append(rightKVs, kvs[i])
|
||||
}
|
||||
}
|
||||
|
||||
if len(leftKVs) > 0 {
|
||||
leftPos := core.NewTriePosition()
|
||||
leftPos.Down(false)
|
||||
walker.AdvanceAndReplace(pageSet, leftPos, leftKVs)
|
||||
}
|
||||
if len(rightKVs) > 0 {
|
||||
rightPos := core.NewTriePosition()
|
||||
rightPos.Down(true)
|
||||
walker.AdvanceAndReplace(pageSet, rightPos, rightKVs)
|
||||
}
|
||||
|
||||
return walker.Conclude()
|
||||
}
|
||||
|
||||
// partitionByChildIndex buckets sorted KVs by the first 6 bits of each key
|
||||
// path (the root page's child index: 0-63).
|
||||
func partitionByChildIndex(kvs []core.KeyValue) [64][]core.KeyValue {
|
||||
var buckets [64][]core.KeyValue
|
||||
for i := range kvs {
|
||||
childIdx := kvs[i].Key[0] >> 2
|
||||
buckets[childIdx] = append(buckets[childIdx], kvs[i])
|
||||
}
|
||||
return buckets
|
||||
}
|
||||
|
||||
// assignToWorkers distributes non-empty child buckets across numWorkers
|
||||
// contiguous ranges.
|
||||
func assignToWorkers(
|
||||
buckets [64][]core.KeyValue,
|
||||
numWorkers int,
|
||||
) []workerTask {
|
||||
var nonEmpty []childBucket
|
||||
for i, kvs := range buckets {
|
||||
if len(kvs) > 0 {
|
||||
nonEmpty = append(nonEmpty, childBucket{
|
||||
childIndex: uint8(i),
|
||||
kvs: kvs,
|
||||
})
|
||||
}
|
||||
}
|
||||
if len(nonEmpty) == 0 {
|
||||
return nil
|
||||
}
|
||||
if numWorkers > len(nonEmpty) {
|
||||
numWorkers = len(nonEmpty)
|
||||
}
|
||||
|
||||
tasks := make([]workerTask, numWorkers)
|
||||
perWorker := len(nonEmpty) / numWorkers
|
||||
remainder := len(nonEmpty) % numWorkers
|
||||
|
||||
idx := 0
|
||||
for w := range numWorkers {
|
||||
count := perWorker
|
||||
if w < remainder {
|
||||
count++
|
||||
}
|
||||
tasks[w] = workerTask{children: nonEmpty[idx : idx+count]}
|
||||
idx += count
|
||||
}
|
||||
return tasks
|
||||
}
|
||||
|
||||
// runWorker processes a worker's assigned child subtrees using a PageWalker
|
||||
// constrained to pages below the root page.
|
||||
func runWorker(
|
||||
root core.Node,
|
||||
task workerTask,
|
||||
pageSet PageSet,
|
||||
) workerResult {
|
||||
rootPageID := core.RootPageID()
|
||||
walker := NewPageWalker(root, &rootPageID)
|
||||
|
||||
for _, child := range task.children {
|
||||
var leftKVs, rightKVs []core.KeyValue
|
||||
for i := range child.kvs {
|
||||
if (child.kvs[i].Key[0]>>1)&1 == 0 {
|
||||
leftKVs = append(leftKVs, child.kvs[i])
|
||||
} else {
|
||||
rightKVs = append(rightKVs, child.kvs[i])
|
||||
}
|
||||
}
|
||||
|
||||
if len(leftKVs) > 0 {
|
||||
walker.AdvanceAndReplace(pageSet, childPosition(child.childIndex, false), leftKVs)
|
||||
}
|
||||
if len(rightKVs) > 0 {
|
||||
walker.AdvanceAndReplace(pageSet, childPosition(child.childIndex, true), rightKVs)
|
||||
}
|
||||
}
|
||||
|
||||
out := walker.Conclude()
|
||||
return workerResult{
|
||||
childPageRoots: out.ChildPageRoots,
|
||||
pages: out.Pages,
|
||||
}
|
||||
}
|
||||
|
||||
// childPosition creates a TriePosition at depth 7: 6 bits encoding the root
|
||||
// page's child index (MSB first) plus one additional bit for left/right within
|
||||
// the child page.
|
||||
func childPosition(childIndex uint8, rightBit bool) core.TriePosition {
|
||||
pos := core.NewTriePosition()
|
||||
for b := 5; b >= 0; b-- {
|
||||
pos.Down((childIndex>>b)&1 == 1)
|
||||
}
|
||||
pos.Down(rightBit)
|
||||
return pos
|
||||
}
|
||||
258
nomt/merkle/worker_test.go
Normal file
258
nomt/merkle/worker_test.go
Normal file
|
|
@ -0,0 +1,258 @@
|
|||
package merkle
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/ethereum/go-ethereum/nomt/core"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// --- Unit tests for helpers ---
|
||||
|
||||
func TestPartitionByChildIndex(t *testing.T) {
|
||||
// Key 0x00... → child 0, key 0xFC... → child 63.
|
||||
kvs := []core.KeyValue{
|
||||
{Key: makeKVKey(0x00), Value: makeKVVal(1)},
|
||||
{Key: makeKVKey(0x04), Value: makeKVVal(2)}, // 0x04 >> 2 = 1
|
||||
{Key: makeKVKey(0xFC), Value: makeKVVal(3)}, // 0xFC >> 2 = 63
|
||||
}
|
||||
buckets := partitionByChildIndex(kvs)
|
||||
|
||||
assert.Len(t, buckets[0], 1)
|
||||
assert.Len(t, buckets[1], 1)
|
||||
assert.Len(t, buckets[63], 1)
|
||||
|
||||
// All other buckets should be empty.
|
||||
nonEmpty := 0
|
||||
for _, b := range buckets {
|
||||
if len(b) > 0 {
|
||||
nonEmpty++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 3, nonEmpty)
|
||||
}
|
||||
|
||||
func TestChildPosition(t *testing.T) {
|
||||
// Child 0, left: 7 bits all false → depth 7.
|
||||
pos := childPosition(0, false)
|
||||
assert.Equal(t, uint16(7), pos.Depth())
|
||||
for i := range 7 {
|
||||
assert.False(t, pos.Bit(i), "bit %d should be 0", i)
|
||||
}
|
||||
|
||||
// Child 0, right: 6 false + 1 true → depth 7.
|
||||
pos = childPosition(0, true)
|
||||
assert.Equal(t, uint16(7), pos.Depth())
|
||||
for i := range 6 {
|
||||
assert.False(t, pos.Bit(i), "bit %d should be 0", i)
|
||||
}
|
||||
assert.True(t, pos.Bit(6))
|
||||
|
||||
// Child 63 (0b111111), left: 6 true + 1 false → depth 7.
|
||||
pos = childPosition(63, false)
|
||||
assert.Equal(t, uint16(7), pos.Depth())
|
||||
for i := range 6 {
|
||||
assert.True(t, pos.Bit(i), "bit %d should be 1", i)
|
||||
}
|
||||
assert.False(t, pos.Bit(6))
|
||||
|
||||
// Child 63 (0b111111), right: 7 true → depth 7.
|
||||
pos = childPosition(63, true)
|
||||
assert.Equal(t, uint16(7), pos.Depth())
|
||||
for i := range 7 {
|
||||
assert.True(t, pos.Bit(i), "bit %d should be 1", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssignToWorkers(t *testing.T) {
|
||||
// 3 non-empty buckets, 2 workers.
|
||||
var buckets [64][]core.KeyValue
|
||||
buckets[0] = []core.KeyValue{{Key: makeKVKey(0x00), Value: makeKVVal(1)}}
|
||||
buckets[10] = []core.KeyValue{{Key: makeKVKey(0x28), Value: makeKVVal(2)}} // 0x28>>2=10
|
||||
buckets[63] = []core.KeyValue{{Key: makeKVKey(0xFC), Value: makeKVVal(3)}}
|
||||
|
||||
tasks := assignToWorkers(buckets, 2)
|
||||
require.Len(t, tasks, 2)
|
||||
// 3 items / 2 workers: first gets 2, second gets 1.
|
||||
assert.Len(t, tasks[0].children, 2)
|
||||
assert.Len(t, tasks[1].children, 1)
|
||||
assert.Equal(t, uint8(0), tasks[0].children[0].childIndex)
|
||||
assert.Equal(t, uint8(10), tasks[0].children[1].childIndex)
|
||||
assert.Equal(t, uint8(63), tasks[1].children[0].childIndex)
|
||||
}
|
||||
|
||||
func TestAssignToWorkersMoreWorkersThanChildren(t *testing.T) {
|
||||
var buckets [64][]core.KeyValue
|
||||
buckets[5] = []core.KeyValue{{Key: makeKVKey(0x14), Value: makeKVVal(1)}} // 0x14>>2=5
|
||||
buckets[6] = []core.KeyValue{{Key: makeKVKey(0x18), Value: makeKVVal(2)}} // 0x18>>2=6
|
||||
|
||||
tasks := assignToWorkers(buckets, 8)
|
||||
// Only 2 non-empty, so cap to 2 workers.
|
||||
require.Len(t, tasks, 2)
|
||||
assert.Len(t, tasks[0].children, 1)
|
||||
assert.Len(t, tasks[1].children, 1)
|
||||
}
|
||||
|
||||
// --- Integration tests ---
|
||||
|
||||
// permissivePageSet wraps MemoryPageSet to return fresh pages for missing
|
||||
// entries (matching bitboxPageSet behavior). This is needed because the
|
||||
// parallel workers descend into child pages that may not exist yet.
|
||||
type permissivePageSet struct {
|
||||
*MemoryPageSet
|
||||
}
|
||||
|
||||
func (ps *permissivePageSet) Get(pageID core.PageID) (*core.RawPage, PageOrigin, bool) {
|
||||
page, origin, ok := ps.MemoryPageSet.Get(pageID)
|
||||
if !ok {
|
||||
fresh := new(core.RawPage)
|
||||
return fresh, PageOrigin{Kind: PageOriginFresh}, true
|
||||
}
|
||||
return page, origin, true
|
||||
}
|
||||
|
||||
func memoryPageSetFactory() PageSet {
|
||||
return &permissivePageSet{NewMemoryPageSet(true)}
|
||||
}
|
||||
|
||||
func TestParallelUpdateEmpty(t *testing.T) {
|
||||
out := ParallelUpdate(core.Terminator, nil, 4, memoryPageSetFactory)
|
||||
assert.Equal(t, core.Terminator, out.Root)
|
||||
}
|
||||
|
||||
func TestParallelUpdateSingleKey(t *testing.T) {
|
||||
kv := core.KeyValue{Key: makeKVKey(0x50), Value: makeKVVal(1)}
|
||||
kvs := []core.KeyValue{kv}
|
||||
|
||||
out := ParallelUpdate(core.Terminator, kvs, 4, memoryPageSetFactory)
|
||||
expected := expectedRoot(kvs)
|
||||
assert.Equal(t, expected, out.Root)
|
||||
}
|
||||
|
||||
func TestParallelUpdateTwoKeysDifferentChildren(t *testing.T) {
|
||||
// 0x00 → child 0, 0x80 → child 32.
|
||||
kvs := []core.KeyValue{
|
||||
{Key: makeKVKey(0x00), Value: makeKVVal(1)},
|
||||
{Key: makeKVKey(0x80), Value: makeKVVal(2)},
|
||||
}
|
||||
|
||||
out := ParallelUpdate(core.Terminator, kvs, 4, memoryPageSetFactory)
|
||||
expected := expectedRoot(kvs)
|
||||
assert.Equal(t, expected, out.Root)
|
||||
}
|
||||
|
||||
func TestParallelUpdateSparseChildren(t *testing.T) {
|
||||
// Only children 0 and 63 have ops.
|
||||
kvs := []core.KeyValue{
|
||||
{Key: makeKVKey(0x00), Value: makeKVVal(1)},
|
||||
{Key: makeKVKey(0xFC), Value: makeKVVal(2)},
|
||||
}
|
||||
|
||||
out := ParallelUpdate(core.Terminator, kvs, 4, memoryPageSetFactory)
|
||||
expected := expectedRoot(kvs)
|
||||
assert.Equal(t, expected, out.Root)
|
||||
}
|
||||
|
||||
func TestParallelUpdateSingleChild(t *testing.T) {
|
||||
// All keys land in child 0 (first 6 bits = 000000).
|
||||
kvs := []core.KeyValue{
|
||||
{Key: makeKVKey(0x00), Value: makeKVVal(1)},
|
||||
{Key: makeKVKey(0x01), Value: makeKVVal(2)},
|
||||
{Key: makeKVKey(0x02), Value: makeKVVal(3)},
|
||||
{Key: makeKVKey(0x03), Value: makeKVVal(4)},
|
||||
}
|
||||
sort.Slice(kvs, func(i, j int) bool { return kvLess(&kvs[i], &kvs[j]) })
|
||||
|
||||
out := ParallelUpdate(core.Terminator, kvs, 4, memoryPageSetFactory)
|
||||
expected := expectedRoot(kvs)
|
||||
assert.Equal(t, expected, out.Root)
|
||||
}
|
||||
|
||||
func TestParallelUpdateFallbackSmallBatch(t *testing.T) {
|
||||
// Less than 64 ops → single-threaded fallback.
|
||||
kvs := randomKVs(10, 42)
|
||||
out := ParallelUpdate(core.Terminator, kvs, 8, memoryPageSetFactory)
|
||||
expected := expectedRoot(kvs)
|
||||
assert.Equal(t, expected, out.Root)
|
||||
}
|
||||
|
||||
func TestParallelUpdateDeterministic(t *testing.T) {
|
||||
kvs := randomKVs(200, 99)
|
||||
|
||||
r1 := ParallelUpdate(core.Terminator, kvs, 4, memoryPageSetFactory).Root
|
||||
r2 := ParallelUpdate(core.Terminator, kvs, 4, memoryPageSetFactory).Root
|
||||
assert.Equal(t, r1, r2, "same inputs should produce same root")
|
||||
}
|
||||
|
||||
func TestParallelUpdateMatchesSingleThreaded(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
numKVs int
|
||||
workers int
|
||||
}{
|
||||
{"1kv_2w", 1, 2},
|
||||
{"10kv_2w", 10, 2},
|
||||
{"100kv_2w", 100, 2},
|
||||
{"100kv_4w", 100, 4},
|
||||
{"100kv_8w", 100, 8},
|
||||
{"500kv_4w", 500, 4},
|
||||
{"1000kv_8w", 1000, 8},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
kvs := randomKVs(tc.numKVs, 12345)
|
||||
|
||||
single := singleThreadedUpdate(
|
||||
core.Terminator, kvs, NewMemoryPageSet(true),
|
||||
)
|
||||
parallel := ParallelUpdate(
|
||||
core.Terminator, kvs, tc.workers, memoryPageSetFactory,
|
||||
)
|
||||
|
||||
assert.Equal(t, single.Root, parallel.Root,
|
||||
"parallel root should match single-threaded root")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
func randomKVs(n int, seed int64) []core.KeyValue {
|
||||
rng := rand.New(rand.NewSource(seed))
|
||||
kvs := make([]core.KeyValue, n)
|
||||
seen := make(map[core.KeyPath]bool, n)
|
||||
|
||||
for i := range n {
|
||||
for {
|
||||
var kp core.KeyPath
|
||||
rng.Read(kp[:])
|
||||
if seen[kp] {
|
||||
continue
|
||||
}
|
||||
seen[kp] = true
|
||||
var vh core.ValueHash
|
||||
rng.Read(vh[:])
|
||||
kvs[i] = core.KeyValue{Key: kp, Value: vh}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(kvs, func(i, j int) bool { return kvLess(&kvs[i], &kvs[j]) })
|
||||
return kvs
|
||||
}
|
||||
|
||||
func kvLess(a, b *core.KeyValue) bool {
|
||||
for i := range a.Key {
|
||||
if a.Key[i] < b.Key[i] {
|
||||
return true
|
||||
}
|
||||
if a.Key[i] > b.Key[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
Loading…
Reference in a new issue