From 88fd10529fd44f4977ee3f56ec3ac1e0a373fa2c Mon Sep 17 00:00:00 2001
From: weiihann
Date: Thu, 12 Feb 2026 17:10:58 +0800
Subject: [PATCH] nomt/merkle: add Phase 2 merkle engine (PageWalker, PageSet, ElidedChildren)

Implement the in-memory batch update engine for the NOMT binary merkle trie:

- elided.go: ElidedChildren 64-bit bitfield for tracking elided child pages
- pageset.go: PageSet interface + MemoryPageSet in-memory implementation
- pagewalker.go: PageWalker left-to-right walker with partial compaction
  - AdvanceAndReplace: replace terminal nodes with sub-tries
  - AdvanceAndPlaceNode: place pre-computed child page roots
  - Conclude: finalize walk and return new root + updated pages
  - compactUp/compactStep: hash upward with leaf/terminator compaction
- core/triepos.go: add SharedDepth method needed by PageWalker

Co-Authored-By: Claude Opus 4.6
---
 nomt/core/triepos.go | 14 +
 nomt/merkle/elided.go | 54 ++++
 nomt/merkle/elided_test.go | 54 ++++
 nomt/merkle/pageset.go | 113 +++++++
 nomt/merkle/pageset_test.go | 72 +++++
 nomt/merkle/pagewalker.go | 541 +++++++++++++++++++++++++++++++++
 nomt/merkle/pagewalker_test.go | 360 ++++++++++++++++++++++
 7 files changed, 1208 insertions(+)
 create mode 100644 nomt/merkle/elided.go
 create mode 100644 nomt/merkle/elided_test.go
 create mode 100644 nomt/merkle/pageset.go
 create mode 100644 nomt/merkle/pageset_test.go
 create mode 100644 nomt/merkle/pagewalker.go
 create mode 100644 nomt/merkle/pagewalker_test.go

diff --git a/nomt/core/triepos.go b/nomt/core/triepos.go
index 4e2eddf082..4a7d555e9f 100644
--- a/nomt/core/triepos.go
+++ b/nomt/core/triepos.go
@@ -184,6 +184,20 @@ func (p *TriePosition) SiblingIndex() int {
 	return siblingIndex(p.nodeIndex)
 }
 
+// SharedDepth returns the number of leading path bits shared between two
+// TriePositions, considering only bits up to the shorter depth.
+func (p *TriePosition) SharedDepth(other *TriePosition) int {
+	maxBits := min(int(p.depth), int(other.depth))
+	for i := range maxBits {
+		pBit := (p.path[i/8] >> (7 - i%8)) & 1
+		oBit := (other.path[i/8] >> (7 - i%8)) & 1
+		if pBit != oBit {
+			return i
+		}
+	}
+	return maxBits
+}
+
 // --- internal helpers ---
 
 // computeNodeIndex converts a page-local bit path to a level-order node index.
diff --git a/nomt/merkle/elided.go b/nomt/merkle/elided.go
new file mode 100644
index 0000000000..09383c5f73
--- /dev/null
+++ b/nomt/merkle/elided.go
@@ -0,0 +1,54 @@
+// Package merkle implements the in-memory batch update engine for the NOMT
+// binary merkle trie. It processes sorted key-value changes and produces
+// updated pages plus a new root hash.
+package merkle
+
+import "encoding/binary"
+
+// ElidedChildren is a 64-bit bitfield tracking which of a page's 64 child
+// pages are elided (not stored on disk and reconstructed on-the-fly).
+type ElidedChildren struct {
+	elided uint64
+}
+
+// NewElidedChildren returns an empty ElidedChildren with no children elided.
+func NewElidedChildren() ElidedChildren {
+	return ElidedChildren{}
+}
+
+// ElidedChildrenFromBytes decodes an ElidedChildren from its 8-byte
+// little-endian representation.
+func ElidedChildrenFromBytes(raw [8]byte) ElidedChildren {
+	return ElidedChildren{elided: binary.LittleEndian.Uint64(raw[:])}
+}
+
+// ElidedChildrenFromUint64 wraps a raw uint64 bitfield.
+func ElidedChildrenFromUint64(v uint64) ElidedChildren {
+	return ElidedChildren{elided: v}
+}
+
+// ToBytes encodes the ElidedChildren as 8 bytes (little-endian).
+func (e *ElidedChildren) ToBytes() [8]byte {
+	var buf [8]byte
+	binary.LittleEndian.PutUint64(buf[:], e.elided)
+	return buf
+}
+
+// SetElide sets or clears the elided bit for childIndex (0-63; larger is a no-op).
+func (e *ElidedChildren) SetElide(childIndex uint8, elide bool) {
+	if elide {
+		e.elided |= 1 << childIndex
+	} else {
+		e.elided &^= 1 << childIndex
+	}
+}
+
+// IsElided reports whether the child at the given index is elided.
+func (e *ElidedChildren) IsElided(childIndex uint8) bool {
+	return (e.elided>>childIndex)&1 == 1
+}
+
+// Raw returns the underlying uint64 bitfield.
+func (e *ElidedChildren) Raw() uint64 {
+	return e.elided
+}
diff --git a/nomt/merkle/elided_test.go b/nomt/merkle/elided_test.go
new file mode 100644
index 0000000000..65dc493aa9
--- /dev/null
+++ b/nomt/merkle/elided_test.go
@@ -0,0 +1,54 @@
+package merkle
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestElidedChildrenNew(t *testing.T) {
+	ec := NewElidedChildren()
+	for i := range 64 {
+		assert.False(t, ec.IsElided(uint8(i)), "child %d", i)
+	}
+	assert.Equal(t, uint64(0), ec.Raw())
+}
+
+func TestElidedChildrenSetAndCheck(t *testing.T) {
+	ec := NewElidedChildren()
+
+	ec.SetElide(0, true)
+	assert.True(t, ec.IsElided(0))
+	assert.False(t, ec.IsElided(1))
+
+	ec.SetElide(63, true)
+	assert.True(t, ec.IsElided(63))
+
+	ec.SetElide(0, false)
+	assert.False(t, ec.IsElided(0))
+	assert.True(t, ec.IsElided(63))
+}
+
+func TestElidedChildrenRoundTrip(t *testing.T) {
+	ec := NewElidedChildren()
+	ec.SetElide(5, true)
+	ec.SetElide(33, true)
+	ec.SetElide(62, true)
+
+	encoded := ec.ToBytes()
+	decoded := ElidedChildrenFromBytes(encoded)
+
+	assert.True(t, decoded.IsElided(5))
+	assert.True(t, decoded.IsElided(33))
+	assert.True(t, decoded.IsElided(62))
+	assert.False(t, decoded.IsElided(0))
+	assert.Equal(t, ec.Raw(), decoded.Raw())
+}
+
+func TestElidedChildrenFromUint64(t *testing.T) {
+	ec := ElidedChildrenFromUint64(0xFF)
+	for i := range 8 {
+		assert.True(t, ec.IsElided(uint8(i)), "child %d", i)
+	}
+	assert.False(t, ec.IsElided(8))
+}
diff --git a/nomt/merkle/pageset.go b/nomt/merkle/pageset.go
new file mode 100644
index 0000000000..6ab192fbe4
--- /dev/null
+++ b/nomt/merkle/pageset.go
@@ -0,0 +1,113 @@
+package merkle
+
+import (
+	"github.com/ethereum/go-ethereum/nomt/core"
+)
+
+// PageOriginKind discriminates the origin of a page in the PageSet.
+type PageOriginKind int
+
+const (
+	// PageOriginPersisted indicates the page was loaded from on-disk storage.
+	PageOriginPersisted PageOriginKind = iota
+	// PageOriginFresh indicates the page was freshly created (zeroed).
+	PageOriginFresh
+)
+
+// PageOrigin tracks where a page came from, used by the PageWalker to decide
+// how to handle page elision and diff tracking.
+type PageOrigin struct {
+	Kind PageOriginKind
+}
+
+// PageSet is the interface through which the PageWalker reads and creates
+// pages during trie updates.
+type PageSet interface {
+	// Get retrieves a page by its ID. Returns the page, its origin, and
+	// whether it was found.
+	Get(pageID core.PageID) (*core.RawPage, PageOrigin, bool)
+
+	// Contains reports whether a page exists in the set.
+	Contains(pageID core.PageID) bool
+
+	// Fresh creates a new zeroed page for the given ID.
+	Fresh(pageID core.PageID) *core.RawPage
+
+	// Insert adds or replaces a page in the set.
+	Insert(pageID core.PageID, page *core.RawPage, origin PageOrigin)
+}
+
+// MemoryPageSet is an in-memory PageSet implementation backed by a map.
+type MemoryPageSet struct {
+	pages map[string]memoryPageEntry
+}
+
+type memoryPageEntry struct {
+	page   *core.RawPage
+	origin PageOrigin
+}
+
+// NewMemoryPageSet creates a MemoryPageSet, optionally pre-populated with a
+// zeroed root page whose origin is PageOriginFresh.
+func NewMemoryPageSet(withRoot bool) *MemoryPageSet {
+	ps := &MemoryPageSet{
+		pages: make(map[string]memoryPageEntry, 16),
+	}
+	if withRoot {
+		root := core.RootPageID()
+		page := new(core.RawPage)
+		ps.pages[pageIDKey(root)] = memoryPageEntry{
+			page:   page,
+			origin: PageOrigin{Kind: PageOriginFresh},
+		}
+	}
+	return ps
+}
+
+// Get retrieves a page from the in-memory set.
+func (m *MemoryPageSet) Get(pageID core.PageID) (*core.RawPage, PageOrigin, bool) {
+	entry, ok := m.pages[pageIDKey(pageID)]
+	if !ok {
+		return nil, PageOrigin{}, false
+	}
+	// Return a copy so the walker can mutate freely.
+	pageCopy := new(core.RawPage)
+	*pageCopy = *entry.page
+	return pageCopy, entry.origin, true
+}
+
+// Contains reports whether the page exists.
+func (m *MemoryPageSet) Contains(pageID core.PageID) bool {
+	_, ok := m.pages[pageIDKey(pageID)]
+	return ok
+}
+
+// Fresh returns a new zeroed page; pageID is unused and nothing is stored.
+func (m *MemoryPageSet) Fresh(pageID core.PageID) *core.RawPage {
+	return new(core.RawPage)
+}
+
+// Insert stores the given page pointer in the set without copying the page.
+func (m *MemoryPageSet) Insert(
+	pageID core.PageID, page *core.RawPage, origin PageOrigin,
+) {
+	m.pages[pageIDKey(pageID)] = memoryPageEntry{page: page, origin: origin}
+}
+
+// Apply applies a list of UpdatedPages into the page set, making them
+// available for subsequent reads.
+func (m *MemoryPageSet) Apply(updates []UpdatedPage) {
+	for _, up := range updates {
+		pageCopy := new(core.RawPage)
+		*pageCopy = *up.Page
+		m.pages[pageIDKey(up.PageID)] = memoryPageEntry{
+			page:   pageCopy,
+			origin: PageOrigin{Kind: PageOriginPersisted},
+		}
+	}
+}
+
+func pageIDKey(id core.PageID) string {
+	encoded := id.Encode()
+	return string(encoded[:])
+}
diff --git a/nomt/merkle/pageset_test.go b/nomt/merkle/pageset_test.go
new file mode 100644
index 0000000000..cebb6480be
--- /dev/null
+++ b/nomt/merkle/pageset_test.go
@@ -0,0 +1,72 @@
+package merkle
+
+import (
+	"testing"
+
+	"github.com/ethereum/go-ethereum/nomt/core"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestMemoryPageSetRootInit(t *testing.T) {
+	ps := NewMemoryPageSet(true)
+	root := core.RootPageID()
+	assert.True(t, ps.Contains(root))
+
+	page, origin, ok := ps.Get(root)
+	require.True(t, ok)
+	assert.Equal(t, PageOriginFresh, origin.Kind)
+	assert.NotNil(t, page)
+}
+
+func TestMemoryPageSetInsertGet(t *testing.T) {
+	ps := NewMemoryPageSet(false)
+	root := core.RootPageID()
+	assert.False(t, ps.Contains(root))
+
+	page := new(core.RawPage)
+	page.SetNodeAt(0, core.Node{0x42})
+	ps.Insert(root, page, PageOrigin{Kind: PageOriginPersisted})
+
+	got, origin, ok := ps.Get(root)
+	require.True(t, ok)
+	assert.Equal(t, PageOriginPersisted, origin.Kind)
+	assert.Equal(t, core.Node{0x42}, got.NodeAt(0))
+}
+
+func TestMemoryPageSetGetReturnsCopy(t *testing.T) {
+	ps := NewMemoryPageSet(false)
+	root := core.RootPageID()
+	page := new(core.RawPage)
+	page.SetNodeAt(0, core.Node{0x01})
+	ps.Insert(root, page, PageOrigin{Kind: PageOriginFresh})
+
+	got, _, _ := ps.Get(root)
+	got.SetNodeAt(0, core.Node{0xFF}) // mutate the copy
+
+	original, _, _ := ps.Get(root)
+	assert.Equal(t, core.Node{0x01}, original.NodeAt(0),
+		"mutation should not affect the stored page")
+}
+
+func TestMemoryPageSetFresh(t *testing.T) {
+	ps := NewMemoryPageSet(false)
+	root := core.RootPageID()
+	page := ps.Fresh(root)
+	assert.NotNil(t, page)
+	assert.Equal(t, core.Terminator, page.NodeAt(0))
+}
+
+func TestMemoryPageSetChildPage(t *testing.T) {
+	ps := NewMemoryPageSet(false)
+	root := core.RootPageID()
+	child, err := root.ChildPageID(5)
+	require.NoError(t, err)
+
+	page := new(core.RawPage)
+	page.SetNodeAt(0, core.Node{0xAB})
+	ps.Insert(child, page, PageOrigin{Kind: PageOriginPersisted})
+
+	assert.True(t, ps.Contains(child))
+	assert.False(t, ps.Contains(root))
+}
diff --git a/nomt/merkle/pagewalker.go b/nomt/merkle/pagewalker.go
new file mode 100644
index 0000000000..7f10548744
--- /dev/null
+++ b/nomt/merkle/pagewalker.go
@@ -0,0 +1,541 @@
+package merkle
+
+import (
+	"github.com/ethereum/go-ethereum/nomt/core"
+)
+
+// UpdatedPage is a page that was modified during a trie update.
+type UpdatedPage struct {
+	PageID core.PageID
+	Page   *core.RawPage
+	Diff   core.PageDiff
+}
+
+// Output is the result of concluding a PageWalker.
+type Output struct {
+	// Root is the new root node hash (set when no parent page was supplied).
+	Root core.Node
+	// Pages is the list of all pages modified during the update.
+	Pages []UpdatedPage
+	// ChildPageRoots holds (position, node) pairs for nodes that should be
+	// placed in the parent page's bottom layer (set when a parent page was
+	// supplied).
+	ChildPageRoots []childPageRoot
+}
+
+type childPageRoot struct {
+	Position core.TriePosition
+	Node     core.Node
+}
+
+// stackPage is a page currently held in the walker's ascending page stack.
+type stackPage struct {
+	pageID core.PageID
+	page   *core.RawPage
+	diff   core.PageDiff
+	elided ElidedChildren
+	origin PageOrigin
+}
+
+// PageWalker performs left-to-right walking and updating of the page tree.
+//
+// Usage: create a PageWalker, make repeated calls to AdvanceAndReplace (and
+// optionally Advance / AdvanceAndPlaceNode), then call Conclude to get the
+// new root and all updated pages.
+type PageWalker struct {
+	lastPosition *core.TriePosition // nil before first advance
+	position     core.TriePosition
+	parentPage   *core.PageID
+	root         core.Node
+
+	stack        []stackPage
+	siblingStack []siblingEntry
+	prevNode     *core.Node
+
+	outputPages    []UpdatedPage
+	childPageRoots []childPageRoot
+}
+
+type siblingEntry struct {
+	node  core.Node
+	depth int
+}
+
+// NewPageWalker creates a new PageWalker starting from the given root.
+// If parentPage is non-nil, the walker is constrained to pages below the
+// parent, and compaction results are reported via ChildPageRoots, not Root.
+func NewPageWalker(root core.Node, parentPage *core.PageID) *PageWalker {
+	return &PageWalker{
+		position:     core.NewTriePosition(),
+		root:         root,
+		parentPage:   parentPage,
+		stack:        make([]stackPage, 0, 8),
+		siblingStack: make([]siblingEntry, 0, 16),
+		outputPages:  make([]UpdatedPage, 0, 16),
+	}
+}
+
+// AdvanceAndReplace advances to the given position and replaces the terminal
+// node there with a sub-trie built from the provided key-value pairs.
+//
+// The pairs must be sorted and must all share the prefix corresponding to
+// the position. An empty slice deletes the existing terminal node.
+//
+// Panics if the position is not greater than the previous position.
+func (w *PageWalker) AdvanceAndReplace(
+	pageSet PageSet,
+	newPos core.TriePosition,
+	ops []core.KeyValue,
+) {
+	if w.lastPosition != nil {
+		w.assertForward(&newPos)
+		w.compactUp(&newPos)
+	}
+	pos := newPos
+	w.lastPosition = &pos
+	w.buildStack(pageSet, newPos)
+	w.replaceTerminal(pageSet, ops)
+}
+
+// AdvanceAndPlaceNode advances to the given position and sets the given node.
+//
+// This is used to place child-page root nodes computed by parallel workers.
+func (w *PageWalker) AdvanceAndPlaceNode(
+	pageSet PageSet,
+	newPos core.TriePosition,
+	node core.Node,
+) {
+	if w.lastPosition != nil {
+		w.assertForward(&newPos)
+		w.compactUp(&newPos)
+	}
+	pos := newPos
+	w.lastPosition = &pos
+	w.buildStack(pageSet, newPos)
+	w.placeNode(node)
+}
+
+// Advance compacts toward newPos and records it without placing a node there.
+func (w *PageWalker) Advance(newPos core.TriePosition) {
+	if w.lastPosition != nil {
+		w.assertForward(&newPos)
+		w.compactUp(&newPos)
+	}
+	pos := newPos
+	w.lastPosition = &pos
+}
+
+// Conclude finalizes the walk and returns the new root node, the updated
+// pages, and any child page roots collected for a parent page.
+func (w *PageWalker) Conclude() Output {
+	w.compactUpToRoot()
+
+	out := Output{
+		Root:           w.root,
+		Pages:          w.outputPages,
+		ChildPageRoots: w.childPageRoots,
+	}
+	return out
+}
+
+// --- core operations ---
+
+func (w *PageWalker) placeNode(node core.Node) {
+	if w.position.IsRoot() {
+		prev := w.root
+		w.prevNode = &prev
+		w.root = node
+	} else {
+		prev := w.node()
+		w.prevNode = &prev
+		w.setNode(node)
+	}
+}
+
+func (w *PageWalker) replaceTerminal(pageSet PageSet, ops []core.KeyValue) {
+	var existingNode core.Node
+	if w.position.IsRoot() {
+		existingNode = w.root
+	} else {
+		existingNode = w.node()
+	}
+	w.prevNode = &existingNode
+
+	startDepth := w.position.Depth()
+
+	core.BuildTrie(int(w.position.Depth()), ops, func(wn core.WriteNode) {
+		node := wn.Node
+
+		// For internal nodes, clear garbage in the sibling slot if the
+		// sibling is a terminator.
+		if wn.Kind == core.WriteNodeInternal && wn.InternalData != nil {
+			lastBit := w.position.PeekLastBit()
+			var zeroSibling bool
+			if lastBit {
+				zeroSibling = core.IsTerminator(&wn.InternalData.Left)
+			} else {
+				zeroSibling = core.IsTerminator(&wn.InternalData.Right)
+			}
+			if zeroSibling {
+				w.setSibling(core.Terminator)
+			}
+		}
+
+		// Navigate: up then down.
+		if wn.GoUp && len(wn.DownBits) > 0 {
+			// Optimization: if the first down bit goes to the sibling,
+			// use Sibling() instead of going up + down.
+			if wn.DownBits[0] != w.position.PeekLastBit() {
+				w.position.Sibling()
+				w.downBits(pageSet, wn.DownBits[1:], true)
+			} else {
+				w.up()
+				w.downBits(pageSet, wn.DownBits, true)
+			}
+		} else if wn.GoUp {
+			w.up()
+		} else if len(wn.DownBits) > 0 {
+			// First bit is only fresh if we are at the start position and
+			// start is at end of page (or root). After that, definitely fresh.
+			fresh := w.position.DepthInPage() == core.PageDepth ||
+				w.position.IsRoot()
+			if w.position.Depth() <= startDepth {
+				w.downBits(pageSet, wn.DownBits[:1], fresh)
+				w.downBits(pageSet, wn.DownBits[1:], true)
+			} else {
+				w.downBits(pageSet, wn.DownBits, true)
+			}
+		}
+
+		if w.position.IsRoot() {
+			w.root = node
+		} else {
+			w.setNode(node)
+		}
+	})
+}
+
+// --- stack navigation ---
+
+// up moves one level toward the root. If crossing a page boundary, pops the
+// stack page and emits it as output.
+func (w *PageWalker) up() {
+	if w.position.DepthInPage() == 1 {
+		w.popStackPage()
+	}
+	w.position.Up(1)
+}
+
+// downBits descends along the given bit path, pushing new pages onto the
+// stack as needed.
+func (w *PageWalker) downBits(pageSet PageSet, bits []bool, fresh bool) {
+	for _, bit := range bits {
+		if w.position.IsRoot() {
+			rootID := core.RootPageID()
+			var page *core.RawPage
+			var origin PageOrigin
+			if fresh {
+				page = pageSet.Fresh(rootID)
+				origin = PageOrigin{Kind: PageOriginFresh}
+			} else {
+				var ok bool
+				page, origin, ok = pageSet.Get(rootID)
+				if !ok {
+					panic("pagewalker: root page not in page set")
+				}
+			}
+			w.pushPage(rootID, page, origin)
+		} else if w.position.DepthInPage() == core.PageDepth {
+			// Crossing into a child page.
+			parentSP := w.stack[len(w.stack)-1]
+			childIdx := w.position.ChildPageIndex()
+			childID, err := parentSP.pageID.ChildPageID(childIdx)
+			if err != nil {
+				panic("pagewalker: child page ID overflow")
+			}
+			var page *core.RawPage
+			var origin PageOrigin
+			if fresh {
+				page = pageSet.Fresh(childID)
+				origin = PageOrigin{Kind: PageOriginFresh}
+			} else {
+				var ok bool
+				page, origin, ok = pageSet.Get(childID)
+				if !ok {
+					panic("pagewalker: child page not in page set")
+				}
+			}
+			w.pushPage(childID, page, origin)
+		}
+		w.position.Down(bit)
+	}
+}
+
+func (w *PageWalker) pushPage(
+	pageID core.PageID, page *core.RawPage, origin PageOrigin,
+) {
+	w.stack = append(w.stack, stackPage{
+		pageID: pageID,
+		page:   page,
+		diff:   core.PageDiff{},
+		elided: ElidedChildrenFromUint64(page.ElidedChildren()),
+		origin: origin,
+	})
+}
+
+func (w *PageWalker) popStackPage() {
+	if len(w.stack) == 0 {
+		return
+	}
+	sp := w.stack[len(w.stack)-1]
+	w.stack = w.stack[:len(w.stack)-1]
+
+	// Store elided children back into non-root pages before emitting.
+	if !sp.pageID.IsRoot() {
+		sp.page.SetElidedChildren(sp.elided.Raw())
+	}
+
+	w.outputPages = append(w.outputPages, UpdatedPage{
+		PageID: sp.pageID,
+		Page:   sp.page,
+		Diff:   sp.diff,
+	})
+}
+
+// buildStack pushes pages onto the stack from the current position down to
+// the target position's page.
+func (w *PageWalker) buildStack(pageSet PageSet, position core.TriePosition) {
+	w.position = position
+	newPageID := position.PageID()
+	if newPageID == nil {
+		// Target is at the root — pop all remaining stack pages.
+		for len(w.stack) > 0 {
+			w.popStackPage()
+		}
+		return
+	}
+
+	// Determine which ancestor to push down from.
+	var target *core.PageID
+	if len(w.stack) > 0 {
+		t := w.stack[len(w.stack)-1].pageID
+		target = &t
+	} else if w.parentPage != nil {
+		target = w.parentPage
+	}
+
+	// Collect pages from newPageID up to target.
+	var toPush []core.PageID
+	cur := *newPageID
+	for {
+		if target != nil && cur.Equal(*target) {
+			break
+		}
+		toPush = append(toPush, cur)
+		if cur.IsRoot() {
+			break
+		}
+		cur = cur.ParentPageID()
+	}
+
+	// Push in ascending order (root-ward first).
+	for i := len(toPush) - 1; i >= 0; i-- {
+		pid := toPush[i]
+		page, origin, ok := pageSet.Get(pid)
+		if !ok {
+			panic("pagewalker: page not in page set during build_stack")
+		}
+		w.pushPage(pid, page, origin)
+	}
+}
+
+// --- compaction ---
+
+// compactUp hashes upward from the current position toward the shared
+// ancestor with the target position. This is the core "partial compaction"
+// that avoids redundant hashing.
+func (w *PageWalker) compactUp(targetPos *core.TriePosition) {
+	if len(w.stack) == 0 {
+		return
+	}
+
+	currentDepth := int(w.position.Depth())
+	sharedDepth := w.position.SharedDepth(targetPos)
+
+	// Prune sibling stack entries beyond shared depth.
+	keepLen := 0
+	for _, s := range w.siblingStack {
+		if s.depth <= sharedDepth {
+			keepLen++
+		} else {
+			break
+		}
+	}
+	w.siblingStack = w.siblingStack[:keepLen]
+
+	compactLayers := currentDepth - (sharedDepth + 1)
+	if compactLayers == 0 {
+		if w.prevNode != nil {
+			w.siblingStack = append(w.siblingStack, siblingEntry{
+				node:  *w.prevNode,
+				depth: currentDepth,
+			})
+			w.prevNode = nil
+		}
+	} else {
+		w.prevNode = nil
+	}
+
+	for i := range compactLayers {
+		nextNode := w.compactStep()
+		w.up()
+
+		if len(w.stack) == 0 {
+			if w.parentPage == nil {
+				w.root = nextNode
+			} else {
+				w.childPageRoots = append(w.childPageRoots, childPageRoot{
+					Position: w.position,
+					Node:     nextNode,
+				})
+			}
+			break
+		} else {
+			// Save the final relevant sibling.
+			if i == compactLayers-1 {
+				w.siblingStack = append(w.siblingStack, siblingEntry{
+					node:  w.node(),
+					depth: int(w.position.Depth()),
+				})
+			}
+			w.setNode(nextNode)
+		}
+	}
+}
+
+// compactUpToRoot is called by Conclude to hash all remaining layers to root.
+func (w *PageWalker) compactUpToRoot() {
+	if len(w.stack) == 0 {
+		return
+	}
+
+	w.siblingStack = w.siblingStack[:0]
+	compactLayers := int(w.position.Depth())
+
+	for range compactLayers {
+		nextNode := w.compactStep()
+		w.up()
+
+		if len(w.stack) == 0 {
+			if w.parentPage == nil {
+				w.root = nextNode
+			} else {
+				w.childPageRoots = append(w.childPageRoots, childPageRoot{
+					Position: w.position,
+					Node:     nextNode,
+				})
+			}
+			break
+		} else {
+			w.setNode(nextNode)
+		}
+	}
+}
+
+// compactStep performs one layer of compaction: reads the current node and
+// its sibling, then either compacts terminators/leaves upward or hashes
+// an internal node.
+func (w *PageWalker) compactStep() core.Node {
+	node := w.node()
+	sibling := w.siblingNode()
+	bit := w.position.PeekLastBit()
+
+	nodeKind := core.NodeKindOf(&node)
+	sibKind := core.NodeKindOf(&sibling)
+
+	switch {
+	case nodeKind == core.NodeTerminator && sibKind == core.NodeTerminator:
+		return core.Terminator
+
+	case nodeKind == core.NodeLeaf && sibKind == core.NodeTerminator:
+		// Compact: clear this node, move leaf up.
+		w.setNode(core.Terminator)
+		return node
+
+	case nodeKind == core.NodeTerminator && sibKind == core.NodeLeaf:
+		// Compact: move to the sibling slot, clear it, move leaf up.
+		w.position.Sibling()
+		w.setNode(core.Terminator)
+		return sibling
+
+	default:
+		// Internal: hash the two children together.
+		var id core.InternalData
+		if bit {
+			id = core.InternalData{Left: sibling, Right: node}
+		} else {
+			id = core.InternalData{Left: node, Right: sibling}
+		}
+		return core.HashInternal(&id)
+	}
+}
+
+// --- page node access ---
+
+// node reads the node at the current position from the top stack page.
+func (w *PageWalker) node() core.Node {
+	sp := &w.stack[len(w.stack)-1]
+	return sp.page.NodeAt(w.position.NodeIndex())
+}
+
+// siblingNode reads the sibling of the current position.
+func (w *PageWalker) siblingNode() core.Node {
+	sp := &w.stack[len(w.stack)-1]
+	return sp.page.NodeAt(w.position.SiblingIndex())
+}
+
+// setNode writes a node at the current position and records it in the diff.
+func (w *PageWalker) setNode(node core.Node) {
+	idx := w.position.NodeIndex()
+	sibNode := w.siblingNode()
+
+	sp := &w.stack[len(w.stack)-1]
+	sp.page.SetNodeAt(idx, node)
+
+	// If both the node and its sibling are terminators at the first layer,
+	// mark the page as cleared instead of changed.
+	if w.position.IsFirstLayerInPage() &&
+		core.IsTerminator(&node) &&
+		core.IsTerminator(&sibNode) {
+		sp.diff.SetCleared()
+	} else {
+		sp.diff.SetChanged(idx)
+	}
+}
+
+// setSibling writes a node at the sibling index and marks it changed in the diff.
+func (w *PageWalker) setSibling(node core.Node) {
+	sibIdx := w.position.SiblingIndex()
+	sp := &w.stack[len(w.stack)-1]
+	sp.page.SetNodeAt(sibIdx, node)
+	sp.diff.SetChanged(sibIdx)
+}
+
+// --- assertions ---
+
+func (w *PageWalker) assertForward(newPos *core.TriePosition) {
+	if w.lastPosition == nil {
+		return
+	}
+	newPath := newPos.Path()
+	lastPath := w.lastPosition.Path()
+	for i := range newPath {
+		if newPath[i] > lastPath[i] {
+			return
+		}
+		if newPath[i] < lastPath[i] {
+			panic("pagewalker: positions must advance left-to-right")
+		}
+	}
+	panic("pagewalker: positions must advance left-to-right (equal)")
+}
diff --git a/nomt/merkle/pagewalker_test.go b/nomt/merkle/pagewalker_test.go
new file mode 100644
index 0000000000..3ba0b53efd
--- /dev/null
+++ b/nomt/merkle/pagewalker_test.go
@@ -0,0 +1,360 @@
+package merkle
+
+import (
+	"testing"
+
+	"github.com/ethereum/go-ethereum/nomt/core"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// Helper: build a TriePosition by descending along the given bits.
+func triePos(bits ...bool) core.TriePosition {
+	p := core.NewTriePosition()
+	for _, b := range bits {
+		p.Down(b)
+	}
+	return p
+}
+
+// Helper: create a KeyPath with the given bits set at the MSB positions.
+func keyPath(bits ...bool) core.KeyPath {
+	var kp core.KeyPath
+	for i, b := range bits {
+		if b {
+			kp[i/8] |= 1 << (7 - i%8)
+		}
+	}
+	return kp
+}
+
+// Helper: create a ValueHash filled with a single byte.
+func val(v byte) core.ValueHash {
+	var vh core.ValueHash
+	for i := range vh {
+		vh[i] = v
+	}
+	return vh
+}
+
+// Helper: compute the expected root from a set of key-value pairs using
+// BuildTrie directly (the "oracle").
+func expectedRoot(kvs []core.KeyValue) core.Node {
+	return core.BuildTrie(0, kvs, func(_ core.WriteNode) {})
+}
+
+func TestPageWalkerEmptyTrie(t *testing.T) {
+	ps := NewMemoryPageSet(true)
+	walker := NewPageWalker(core.Terminator, nil)
+	out := walker.Conclude()
+	assert.Equal(t, core.Terminator, out.Root)
+	assert.Empty(t, out.Pages)
+	_ = ps
+}
+
+func TestPageWalkerSingleInsert(t *testing.T) {
+	ps := NewMemoryPageSet(true)
+	walker := NewPageWalker(core.Terminator, nil)
+
+	kp := keyPath(false, false)
+	v := val(1)
+	pos := triePos(false, false)
+
+	walker.AdvanceAndReplace(ps, pos, []core.KeyValue{{Key: kp, Value: v}})
+	out := walker.Conclude()
+
+	expected := expectedRoot([]core.KeyValue{{Key: kp, Value: v}})
+	assert.Equal(t, expected, out.Root)
+	assert.True(t, core.IsLeaf(&out.Root))
+}
+
+func TestPageWalkerTwoInsertsSameAdvance(t *testing.T) {
+	// Two keys that share a common prefix at position [0,0], then diverge.
+	ps := NewMemoryPageSet(true)
+	walker := NewPageWalker(core.Terminator, nil)
+
+	kp1 := keyPath(false, false, true, false) // 0010...
+	kp2 := keyPath(false, false, true, true)  // 0011...
+	v1, v2 := val(1), val(2)
+
+	pos := triePos(false, false)
+	walker.AdvanceAndReplace(ps, pos, []core.KeyValue{
+		{Key: kp1, Value: v1},
+		{Key: kp2, Value: v2},
+	})
+	out := walker.Conclude()
+
+	expected := expectedRoot([]core.KeyValue{
+		{Key: kp1, Value: v1},
+		{Key: kp2, Value: v2},
+	})
+	assert.Equal(t, expected, out.Root)
+	assert.True(t, core.IsInternal(&out.Root))
+}
+
+func TestPageWalkerTwoAdvances(t *testing.T) {
+	// First advance inserts at [0], second at [1].
+	ps := NewMemoryPageSet(true)
+	walker := NewPageWalker(core.Terminator, nil)
+
+	kp0 := keyPath(false, false) // 00...
+	kp1 := keyPath(true, false)  // 10...
+	v0, v1 := val(1), val(2)
+
+	walker.AdvanceAndReplace(ps, triePos(false), []core.KeyValue{
+		{Key: kp0, Value: v0},
+	})
+	walker.AdvanceAndReplace(ps, triePos(true), []core.KeyValue{
+		{Key: kp1, Value: v1},
+	})
+	out := walker.Conclude()
+
+	expected := expectedRoot([]core.KeyValue{
+		{Key: kp0, Value: v0},
+		{Key: kp1, Value: v1},
+	})
+	assert.Equal(t, expected, out.Root)
+}
+
+func TestPageWalkerMultipleAdvances(t *testing.T) {
+	// Match the Rust multi_value test pattern:
+	// 0b00010000 = 0x10, 0b00100000 = 0x20, 0b01000000 = 0x40,
+	// 0b10100000 = 0xA0, 0b10110000 = 0xB0
+	ps := NewMemoryPageSet(true)
+	walker := NewPageWalker(core.Terminator, nil)
+
+	kvA := core.KeyValue{Key: makeKVKey(0x10), Value: makeKVVal(0x10)}
+	kvB := core.KeyValue{Key: makeKVKey(0x20), Value: makeKVVal(0x20)}
+	kvC := core.KeyValue{Key: makeKVKey(0x40), Value: makeKVVal(0x40)}
+	kvD := core.KeyValue{Key: makeKVKey(0xA0), Value: makeKVVal(0xA0)}
+	kvE := core.KeyValue{Key: makeKVKey(0xB0), Value: makeKVVal(0xB0)}
+
+	allOps := []core.KeyValue{kvA, kvB, kvC, kvD, kvE}
+	expected := expectedRoot(allOps)
+
+	// Group by terminal: A and B share prefix 00, C is at 01, D and E
+	// share prefix 101.
+	// Terminal at [0,0]: A, B
+	walker.AdvanceAndReplace(ps, triePos(false, false),
+		[]core.KeyValue{kvA, kvB})
+	// Terminal at [0,1]: C
+	walker.AdvanceAndReplace(ps, triePos(false, true),
+		[]core.KeyValue{kvC})
+	// Terminal at [1]: D, E
+	walker.AdvanceAndReplace(ps, triePos(true),
+		[]core.KeyValue{kvD, kvE})
+
+	out := walker.Conclude()
+	assert.Equal(t, expected, out.Root)
+}
+
+func TestPageWalkerDeleteToTerminator(t *testing.T) {
+	// Insert a leaf, then delete it in a second walker pass.
+	ps := NewMemoryPageSet(true)
+
+	kp := keyPath(false)
+	v := val(1)
+
+	// First: insert a leaf.
+	walker1 := NewPageWalker(core.Terminator, nil)
+	walker1.AdvanceAndReplace(ps, triePos(false), []core.KeyValue{
+		{Key: kp, Value: v},
+	})
+	out1 := walker1.Conclude()
+	require.True(t, core.IsLeaf(&out1.Root))
+	ps.Apply(out1.Pages)
+
+	// Second: delete it (empty ops = terminator replacement).
+	walker2 := NewPageWalker(out1.Root, nil)
+	walker2.AdvanceAndReplace(ps, triePos(false), nil)
+	out2 := walker2.Conclude()
+	assert.Equal(t, core.Terminator, out2.Root)
+}
+
+func TestPageWalkerCompactionLeafUp(t *testing.T) {
+	// When one sibling becomes a terminator and the other is a leaf,
+	// the leaf should be compacted upward.
+	ps := NewMemoryPageSet(true)
+
+	kp0 := keyPath(false)
+	kp1 := keyPath(true)
+	v := val(1)
+
+	// Insert two leaves.
+	walker1 := NewPageWalker(core.Terminator, nil)
+	walker1.AdvanceAndReplace(ps, triePos(false), []core.KeyValue{
+		{Key: kp0, Value: v},
+	})
+	walker1.AdvanceAndReplace(ps, triePos(true), []core.KeyValue{
+		{Key: kp1, Value: v},
+	})
+	out1 := walker1.Conclude()
+	ps.Apply(out1.Pages)
+
+	// Delete the right leaf — left leaf should compact up to root.
+	walker2 := NewPageWalker(out1.Root, nil)
+	walker2.AdvanceAndReplace(ps, triePos(true), nil)
+	out2 := walker2.Conclude()
+
+	expectedLeaf := core.HashLeaf(&core.LeafData{KeyPath: kp0, ValueHash: v})
+	assert.Equal(t, expectedLeaf, out2.Root)
+}
+
+func TestPageWalkerOutputPages(t *testing.T) {
+	ps := NewMemoryPageSet(true)
+	walker := NewPageWalker(core.Terminator, nil)
+
+	kp := keyPath(false)
+	v := val(1)
+
+	walker.AdvanceAndReplace(ps, triePos(false), []core.KeyValue{
+		{Key: kp, Value: v},
+	})
+	out := walker.Conclude()
+
+	// Should output at least the root page.
+	require.NotEmpty(t, out.Pages)
+	assert.True(t, out.Pages[0].PageID.IsRoot())
+}
+
+func TestPageWalkerAdvanceBackwardsPanics(t *testing.T) {
+	ps := NewMemoryPageSet(true)
+	walker := NewPageWalker(core.Terminator, nil)
+
+	walker.AdvanceAndReplace(ps, triePos(true), []core.KeyValue{
+		{Key: keyPath(true), Value: val(1)},
+	})
+
+	assert.Panics(t, func() {
+		walker.AdvanceAndReplace(ps, triePos(false), []core.KeyValue{
+			{Key: keyPath(false), Value: val(2)},
+		})
+	})
+}
+
+func TestPageWalkerDeterministic(t *testing.T) {
+	// Same inputs should produce the same root.
+	kvs := []core.KeyValue{
+		{Key: makeKVKey(0x10), Value: makeKVVal(0x10)},
+		{Key: makeKVKey(0x50), Value: makeKVVal(0x50)},
+		{Key: makeKVKey(0xA0), Value: makeKVVal(0xA0)},
+	}
+
+	run := func() core.Node {
+		ps := NewMemoryPageSet(true)
+		w := NewPageWalker(core.Terminator, nil)
+		// Split the ops across the two depth-1 positions [0] and [1].
+		w.AdvanceAndReplace(ps, triePos(false), kvs[:2])
+		w.AdvanceAndReplace(ps, triePos(true), kvs[2:])
+		return w.Conclude().Root
+	}
+
+	r1 := run()
+	r2 := run()
+	assert.Equal(t, r1, r2)
+}
+
+func TestPageWalkerIncrementalUpdates(t *testing.T) {
+	// Build a trie, apply updates, verify the resulting root matches
+	// building from scratch.
+	ps := NewMemoryPageSet(true)
+
+	kp0 := makeKVKey(0x10)
+	kp1 := makeKVKey(0x80)
+	v1, v2 := makeKVVal(0x01), makeKVVal(0x02)
+
+	// Pass 1: insert two keys.
+	w1 := NewPageWalker(core.Terminator, nil)
+	w1.AdvanceAndReplace(ps, triePos(false), []core.KeyValue{
+		{Key: kp0, Value: v1},
+	})
+	w1.AdvanceAndReplace(ps, triePos(true), []core.KeyValue{
+		{Key: kp1, Value: v1},
+	})
+	out1 := w1.Conclude()
+	ps.Apply(out1.Pages)
+
+	// Pass 2: update the second key's value.
+	w2 := NewPageWalker(out1.Root, nil)
+	w2.AdvanceAndReplace(ps, triePos(true), []core.KeyValue{
+		{Key: kp1, Value: v2},
+	})
+	out2 := w2.Conclude()
+
+	// The expected root should match building the whole trie from scratch
+	// with the updated value.
+	expected := expectedRoot([]core.KeyValue{
+		{Key: kp0, Value: v1},
+		{Key: kp1, Value: v2},
+	})
+	assert.Equal(t, expected, out2.Root)
+}
+
+func TestPageWalkerAdvanceWithoutModify(t *testing.T) {
+	// Test the Advance (read-only) method.
+	ps := NewMemoryPageSet(true)
+	walker := NewPageWalker(core.Terminator, nil)
+
+	kp0 := keyPath(false, false, true, false) // 0010...
+	kp1 := keyPath(false, false, true, true)  // 0011...
+	kp2 := keyPath(true)                      // 1...
+	v := val(1)
+
+	walker.AdvanceAndReplace(ps, triePos(false, false), []core.KeyValue{
+		{Key: kp0, Value: v},
+		{Key: kp1, Value: v},
+	})
+
+	// Advance past [0,1] without modifying — the walker should still
+	// compact correctly.
+	walker.Advance(triePos(false, true))
+
+	walker.AdvanceAndReplace(ps, triePos(true), []core.KeyValue{
+		{Key: kp2, Value: v},
+	})
+
+	out := walker.Conclude()
+
+	expected := expectedRoot([]core.KeyValue{
+		{Key: kp0, Value: v},
+		{Key: kp1, Value: v},
+		{Key: kp2, Value: v},
+	})
+	assert.Equal(t, expected, out.Root)
+}
+
+func TestPageWalkerPageDiffs(t *testing.T) {
+	// Verify that output pages have non-empty diffs.
+	ps := NewMemoryPageSet(true)
+	walker := NewPageWalker(core.Terminator, nil)
+
+	kp := keyPath(false, false)
+	v := val(1)
+	walker.AdvanceAndReplace(ps, triePos(false, false), []core.KeyValue{
+		{Key: kp, Value: v},
+	})
+	out := walker.Conclude()
+
+	require.NotEmpty(t, out.Pages)
+	// The root page should have at least one changed node.
+	assert.True(t, out.Pages[0].Diff.Count() > 0,
+		"page diff should track changed nodes")
+}
+
+// --- helpers ---
+
+func makeKVKey(b byte) core.KeyPath {
+	var kp core.KeyPath
+	for i := range kp {
+		kp[i] = b
+	}
+	return kp
+}
+
+func makeKVVal(b byte) core.ValueHash {
+	var vh core.ValueHash
+	for i := range vh {
+		vh[i] = b
+	}
+	return vh
+}