From 7aebfb3c71491ebf8f0b8b48ac0e2a24e1aa374f Mon Sep 17 00:00:00 2001 From: weiihann Date: Thu, 12 Feb 2026 17:05:00 +0800 Subject: [PATCH] nomt/core: add Phase 1 core primitives for NOMT binary merkle trie Implement foundational types and algorithms for the NOMT storage engine: - node.go: Node/KeyPath/ValueHash types with MSB-based kind discrimination - hasher.go: Keccak256 hashing with leaf/internal MSB labeling - page.go: 4096-byte RawPage layout (126 nodes + elided children + pageID) - pageid.go: PageID encode/decode with shift-then-add encoding - triepos.go: TriePosition navigation (Down/Up/Sibling/PageID/NodeIndex) - pagediff.go: 128-bit PageDiff bitfield for tracking changed nodes - update.go: BuildTrie 3-pointer left-frontier algorithm, LeafOpsSpliced Co-Authored-By: Claude Opus 4.6 --- nomt/core/hasher.go | 46 +++++++ nomt/core/node.go | 71 ++++++++++ nomt/core/node_test.go | 150 ++++++++++++++++++++ nomt/core/page.go | 71 ++++++++++ nomt/core/page_test.go | 92 +++++++++++++ nomt/core/pagediff.go | 127 +++++++++++++++++ nomt/core/pagediff_test.go | 123 +++++++++++++++++ nomt/core/pageid.go | 276 +++++++++++++++++++++++++++++++++++++ nomt/core/pageid_test.go | 243 ++++++++++++++++++++++++++++++++ nomt/core/triepos.go | 253 ++++++++++++++++++++++++++++++++++ nomt/core/triepos_test.go | 170 +++++++++++++++++++++++ nomt/core/update.go | 275 ++++++++++++++++++++++++++++++++++++ nomt/core/update_test.go | 238 ++++++++++++++++++++++++++++++++ 13 files changed, 2135 insertions(+) create mode 100644 nomt/core/hasher.go create mode 100644 nomt/core/node.go create mode 100644 nomt/core/node_test.go create mode 100644 nomt/core/page.go create mode 100644 nomt/core/page_test.go create mode 100644 nomt/core/pagediff.go create mode 100644 nomt/core/pagediff_test.go create mode 100644 nomt/core/pageid.go create mode 100644 nomt/core/pageid_test.go create mode 100644 nomt/core/triepos.go create mode 100644 nomt/core/triepos_test.go create mode 100644 nomt/core/update.go create mode 100644 nomt/core/update_test.go diff --git a/nomt/core/hasher.go b/nomt/core/hasher.go new file mode 100644 index 0000000000..ce9850d333 --- /dev/null +++ b/nomt/core/hasher.go @@ -0,0 +1,46 @@ +package core + +import "golang.org/x/crypto/sha3" + +// HashLeaf computes the hash of a leaf node: keccak256(keyPath || valueHash) +// with the MSB of byte 0 set to 1. +func HashLeaf(data *LeafData) Node { + h := sha3.NewLegacyKeccak256() + h.Write(data.KeyPath[:]) + h.Write(data.ValueHash[:]) + var out Node + h.Sum(out[:0]) + setMSB(&out) + return out +} + +// HashInternal computes the hash of an internal node: keccak256(left || right) +// with the MSB of byte 0 cleared to 0. +func HashInternal(data *InternalData) Node { + h := sha3.NewLegacyKeccak256() + h.Write(data.Left[:]) + h.Write(data.Right[:]) + var out Node + h.Sum(out[:0]) + clearMSB(&out) + return out +} + +// HashValue computes keccak256 of an arbitrary-length value. +func HashValue(value []byte) ValueHash { + h := sha3.NewLegacyKeccak256() + h.Write(value) + var out ValueHash + h.Sum(out[:0]) + return out +} + +// setMSB sets the most significant bit (bit 7 of byte 0) to 1. +func setMSB(n *Node) { + n[0] |= 0x80 +} + +// clearMSB clears the most significant bit (bit 7 of byte 0) to 0. +func clearMSB(n *Node) { + n[0] &= 0x7F +} diff --git a/nomt/core/node.go b/nomt/core/node.go new file mode 100644 index 0000000000..aa4615996a --- /dev/null +++ b/nomt/core/node.go @@ -0,0 +1,71 @@ +// Package core defines the fundamental data structures for a NOMT binary +// merkle trie. All types are pure computation with no I/O dependencies. +package core + +// Node is a 256-bit hash representing a node in the binary merkle trie. +// The MSB of byte 0 discriminates leaves (MSB=1) from internal nodes (MSB=0). +// The all-zeros value is reserved as the Terminator. +type Node = [32]byte + +// KeyPath is the 256-bit lookup path for a key in the trie. +type KeyPath = [32]byte + +// ValueHash is the 256-bit hash of a value stored at a leaf. +type ValueHash = [32]byte + +// Terminator is the special node value denoting an empty sub-trie. +// When this appears at a location, no key with a matching path prefix has a value. +var Terminator Node + +// NodeKind discriminates the three kinds of trie nodes. +type NodeKind int + +const ( + // NodeTerminator indicates an empty sub-trie (all-zero node). + NodeTerminator NodeKind = iota + // NodeLeaf indicates a leaf node (MSB of byte 0 is 1). + NodeLeaf + // NodeInternal indicates an internal (branch) node (MSB of byte 0 is 0, non-zero). + NodeInternal +) + +// NodeKindOf returns the kind of the given node using MSB discrimination. +// +// If the MSB of byte 0 is set, it is a leaf. If the node is all zeros, +// it is a terminator. Otherwise it is an internal node. +func NodeKindOf(n *Node) NodeKind { + if n[0]>>7 == 1 { + return NodeLeaf + } + if *n == Terminator { + return NodeTerminator + } + return NodeInternal +} + +// IsTerminator reports whether the node is the all-zero terminator. +func IsTerminator(n *Node) bool { + return *n == Terminator +} + +// IsLeaf reports whether the node's MSB indicates a leaf. +func IsLeaf(n *Node) bool { + return n[0]>>7 == 1 +} + +// IsInternal reports whether the node is a non-terminator internal node. +func IsInternal(n *Node) bool { + return n[0]>>7 == 0 && *n != Terminator +} + +// InternalData holds the preimage of an internal (branch) node. +type InternalData struct { + Left Node + Right Node +} + +// LeafData holds the preimage of a leaf node. +type LeafData struct { + KeyPath KeyPath + ValueHash ValueHash +} diff --git a/nomt/core/node_test.go b/nomt/core/node_test.go new file mode 100644 index 0000000000..824969ad11 --- /dev/null +++ b/nomt/core/node_test.go @@ -0,0 +1,150 @@ +package core + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTerminatorIsZero(t *testing.T) { + var zero [32]byte + assert.Equal(t, zero, Terminator) +} + +func TestNodeKindOf(t *testing.T) { + tests := []struct { + name string + node Node + want NodeKind + }{ + { + name: "terminator", + node: Terminator, + want: NodeTerminator, + }, + { + name: "leaf with MSB set", + node: Node{0x80, 0x01, 0x02}, + want: NodeLeaf, + }, + { + name: "leaf with all bits set in first byte", + node: Node{0xFF, 0x01}, + want: NodeLeaf, + }, + { + name: "internal node", + node: Node{0x01, 0x02, 0x03}, + want: NodeInternal, + }, + { + name: "internal with MSB clear", + node: Node{0x7F, 0xFF, 0xFF}, + want: NodeInternal, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NodeKindOf(&tt.node) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestIsTerminator(t *testing.T) { + assert.True(t, IsTerminator(&Terminator)) + + nonZero := Node{0x01} + assert.False(t, IsTerminator(&nonZero)) +} + +func TestIsLeaf(t *testing.T) { + leaf := Node{0x80} + assert.True(t, IsLeaf(&leaf)) + + internal := Node{0x7F, 0xFF} + assert.False(t, IsLeaf(&internal)) +} + +func TestIsInternal(t *testing.T) { + internal := Node{0x01} + assert.True(t, IsInternal(&internal)) + + leaf := Node{0x80} + assert.False(t, IsInternal(&leaf)) + + assert.False(t, IsInternal(&Terminator)) +} + +func TestHashLeafSetsMSB(t *testing.T) { + data := &LeafData{ + KeyPath: KeyPath{0x01, 0x02, 0x03}, + ValueHash: ValueHash{0x04, 0x05, 0x06}, + } + result := HashLeaf(data) + require.True(t, IsLeaf(&result), "HashLeaf must produce a leaf node") + require.False(t, IsTerminator(&result)) +} + +func TestHashInternalClearsMSB(t *testing.T) { + data := &InternalData{ + Left: Node{0xFF, 0x01}, + Right: Node{0x80, 0x02}, + } + result := HashInternal(data) + require.True(t, IsInternal(&result), + "HashInternal must produce an internal node") + require.False(t, IsLeaf(&result)) +} + +func TestHashLeafDeterministic(t *testing.T) { + data := &LeafData{ + KeyPath: KeyPath{0xAB, 0xCD}, + ValueHash: ValueHash{0xEF, 0x01}, + } + h1 := HashLeaf(data) + h2 := HashLeaf(data) + assert.Equal(t, h1, h2, "same inputs must produce same hash") +} + +func TestHashInternalDeterministic(t *testing.T) { + data := &InternalData{ + Left: Node{0x11, 0x22}, + Right: Node{0x33, 0x44}, + } + h1 := HashInternal(data) + h2 := HashInternal(data) + assert.Equal(t, h1, h2, "same inputs must produce same hash") +} + +func TestHashLeafDiffersFromInternal(t *testing.T) { + // Using the same 64-byte preimage for both should produce different + // hashes due to MSB tagging (even if the raw keccak is the same, + // the MSB bit will differ). + var key KeyPath + var val ValueHash + for i := range key { + key[i] = byte(i) + } + for i := range val { + val[i] = byte(i + 32) + } + + leaf := HashLeaf(&LeafData{KeyPath: key, ValueHash: val}) + internal := HashInternal(&InternalData{Left: Node(key), Right: Node(val)}) + + // They share the same keccak input, but MSB tagging makes them differ. + assert.NotEqual(t, leaf, internal, + "leaf and internal hashes must differ due to MSB tagging") +} + +func TestHashValue(t *testing.T) { + v1 := HashValue([]byte("hello")) + v2 := HashValue([]byte("hello")) + v3 := HashValue([]byte("world")) + + assert.Equal(t, v1, v2, "same value must produce same hash") + assert.NotEqual(t, v1, v3, "different values must differ") +} diff --git a/nomt/core/page.go b/nomt/core/page.go new file mode 100644 index 0000000000..83beda5dd3 --- /dev/null +++ b/nomt/core/page.go @@ -0,0 +1,71 @@ +package core + +import "encoding/binary" + +// Page layout constants. +const ( + // PageDepth is the depth of the rootless sub-binary tree stored in a page. + PageDepth = 6 + + // NodesPerPage is the total number of nodes in one page: (2^(depth+1)) - 2 = 126. + NodesPerPage = (1 << (PageDepth + 1)) - 2 + + // NumChildren is the number of child pages each page can have: 2^depth = 64. + NumChildren = 1 << PageDepth + + // PageSize is the size of a raw page in bytes, aligned to SSD page size. + PageSize = 4096 + + // elidedChildrenOffset stores the 8-byte elided children bitfield. + // Layout: [nodes 4032] [padding 24] [elided 8] [pageID 32] = 4096 + elidedChildrenOffset = PageSize - 32 - 8 // 4056 + + // pageIDOffset stores the 32-byte encoded PageID. + pageIDOffset = PageSize - 32 // 4064 +) + +// RawPage is a 4096-byte page storing a rootless sub-tree of depth 6. +// +// Layout: +// +// [0..4032) 126 nodes × 32 bytes each, in level-order +// [4032..4056) 24 bytes padding +// [4056..4064) ElidedChildren bitfield (8 bytes, little-endian uint64) +// [4064..4096) PageID encoded (32 bytes) +type RawPage [PageSize]byte + +// NodeAt reads the 32-byte node at the given index (0-based level-order). +func (p *RawPage) NodeAt(index int) Node { + var n Node + off := index * 32 + copy(n[:], p[off:off+32]) + return n +} + +// SetNodeAt writes a 32-byte node at the given index. +func (p *RawPage) SetNodeAt(index int, n Node) { + off := index * 32 + copy(p[off:off+32], n[:]) +} + +// ElidedChildren reads the 8-byte elided children bitfield. +func (p *RawPage) ElidedChildren() uint64 { + return binary.LittleEndian.Uint64(p[elidedChildrenOffset:]) +} + +// SetElidedChildren writes the 8-byte elided children bitfield. +func (p *RawPage) SetElidedChildren(ec uint64) { + binary.LittleEndian.PutUint64(p[elidedChildrenOffset:], ec) +} + +// PageIDBytes reads the 32-byte encoded PageID from the page. +func (p *RawPage) PageIDBytes() [32]byte { + var id [32]byte + copy(id[:], p[pageIDOffset:pageIDOffset+32]) + return id +} + +// SetPageIDBytes writes the 32-byte encoded PageID into the page. +func (p *RawPage) SetPageIDBytes(id [32]byte) { + copy(p[pageIDOffset:pageIDOffset+32], id[:]) +} diff --git a/nomt/core/page_test.go b/nomt/core/page_test.go new file mode 100644 index 0000000000..c7fe2dde24 --- /dev/null +++ b/nomt/core/page_test.go @@ -0,0 +1,92 @@ +package core + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPageConstants(t *testing.T) { + assert.Equal(t, 6, PageDepth) + assert.Equal(t, 126, NodesPerPage) + assert.Equal(t, 64, NumChildren) + assert.Equal(t, 4096, PageSize) +} + +func TestPageNodeRoundTrip(t *testing.T) { + var page RawPage + node := Node{0xAB, 0xCD, 0xEF} + + tests := []struct { + name string + index int + }{ + {"first node", 0}, + {"middle node", 63}, + {"last node", NodesPerPage - 1}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + page.SetNodeAt(tt.index, node) + got := page.NodeAt(tt.index) + assert.Equal(t, node, got) + }) + } +} + +func TestPageElidedChildrenRoundTrip(t *testing.T) { + var page RawPage + + tests := []struct { + name string + ec uint64 + }{ + {"zero", 0}, + {"all set", ^uint64(0)}, + {"first bit", 1}, + {"last bit", 1 << 63}, + {"alternating", 0xAAAAAAAAAAAAAAAA}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + page.SetElidedChildren(tt.ec) + got := page.ElidedChildren() + assert.Equal(t, tt.ec, got) + }) + } +} + +func TestPageIDRoundTrip(t *testing.T) { + var page RawPage + id := [32]byte{0x01, 0x02, 0x03} + + page.SetPageIDBytes(id) + got := page.PageIDBytes() + assert.Equal(t, id, got) +} + +func TestPageRegionsDontOverlap(t *testing.T) { + var page RawPage + + // Write to last node + lastNode := Node{0xFF, 0xFF, 0xFF, 0xFF} + page.SetNodeAt(NodesPerPage-1, lastNode) + + // Write elided children + page.SetElidedChildren(0xDEADBEEFCAFEBABE) + + // Write PageID + var id [32]byte + for i := range id { + id[i] = byte(i) + } + page.SetPageIDBytes(id) + + // Verify none of them corrupted each other + require.Equal(t, lastNode, page.NodeAt(NodesPerPage-1)) + require.Equal(t, uint64(0xDEADBEEFCAFEBABE), page.ElidedChildren()) + require.Equal(t, id, page.PageIDBytes()) +} diff --git a/nomt/core/pagediff.go b/nomt/core/pagediff.go new file mode 100644 index 0000000000..238703514d --- /dev/null +++ b/nomt/core/pagediff.go @@ -0,0 +1,127 @@ +package core + +import ( + "encoding/binary" + "math/bits" +) + +// PageDiff tracks which nodes in a page have changed, using a 128-bit +// bitfield stored as two uint64 words. +// +// Bit 63 of the second word is the "cleared" flag, indicating the page +// was cleared entirely. Bit 62 of the second word is reserved. +type PageDiff struct { + words [2]uint64 +} + +const clearBit = uint64(1) << 63 + +// SetChanged marks the node at the given index (0-125) as changed. +// Also clears the "cleared" flag if set. +func (d *PageDiff) SetChanged(index int) { + // Always clear the "cleared" flag when setting a changed node. + d.words[1] &= ^clearBit + if index < 64 { + d.words[0] |= 1 << index + } else { + d.words[1] |= 1 << (index - 64) + } +} + +// IsChanged reports whether the node at the given index is marked changed. +func (d *PageDiff) IsChanged(index int) bool { + if index < 64 { + return d.words[0]&(1<> 56) + buf[25] = byte(word >> 48) + buf[26] = byte(word >> 40) + buf[27] = byte(word >> 32) + buf[28] = byte(word >> 24) + buf[29] = byte(word >> 16) + buf[30] = byte(word >> 8) + buf[31] = byte(word) + return buf + } + + // Slow path: use big.Int for deep pages. + // Same shift-then-add order as fast path. + val := new(big.Int) + for _, limb := range id.path { + val.Lsh(val, 6) + val.Add(val, big.NewInt(int64(limb)+1)) + } + + var buf [32]byte + b := val.Bytes() + copy(buf[32-len(b):], b) + return buf +} + +// DecodePageID decodes a PageID from its 256-bit representation. +func DecodePageID(bytes [32]byte) (PageID, error) { + val := new(big.Int).SetBytes(bytes[:]) + + if val.Cmp(highestEncoded42) > 0 { + return PageID{}, ErrInvalidPageIDBytes + } + + if val.Sign() == 0 { + return RootPageID(), nil + } + + bitLen := val.BitLen() + sextets := (bitLen + 5) / 6 + + path := make([]uint8, 0, sextets) + for i := 0; i < sextets-1; i++ { + val.Sub(val, bigOne) + x := new(big.Int).And(val, big63) + path = append(path, uint8(x.Uint64())) + val.Rsh(val, 6) + } + // Last sextet: only push if non-zero after subtracting 1. + if val.Sign() != 0 { + val.Sub(val, bigOne) + path = append(path, uint8(val.Uint64())) + } + + // Reverse to get most-significant first. + for i, j := 0, len(path)-1; i < j; i, j = i+1, j-1 { + path[i], path[j] = path[j], path[i] + } + + return PageID{path: path}, nil +} + +// ChildPageID returns the child PageID at the given child index (0-63). +func (id PageID) ChildPageID(childIndex uint8) (PageID, error) { + if childIndex > MaxChildIndex { + return PageID{}, ErrPageIDOverflow + } + if len(id.path) >= MaxPageDepth { + return PageID{}, ErrPageIDOverflow + } + p := make([]uint8, len(id.path)+1) + copy(p, id.path) + p[len(id.path)] = childIndex + return PageID{path: p}, nil +} + +// ParentPageID returns the parent PageID. If this is the root, returns root. +func (id PageID) ParentPageID() PageID { + if len(id.path) == 0 { + return RootPageID() + } + p := make([]uint8, len(id.path)-1) + copy(p, id.path[:len(id.path)-1]) + return PageID{path: p} +} + +// IsDescendantOf reports whether this page is a descendant of other. +func (id PageID) IsDescendantOf(other PageID) bool { + if len(id.path) < len(other.path) { + return false + } + for i := range other.path { + if id.path[i] != other.path[i] { + return false + } + } + return true +} + +// Equal reports whether two PageIDs are the same. +func (id PageID) Equal(other PageID) bool { + if len(id.path) != len(other.path) { + return false + } + for i := range id.path { + if id.path[i] != other.path[i] { + return false + } + } + return true +} + +// MinKeyPath returns the minimum key path that could land in this page. +func (id PageID) MinKeyPath() KeyPath { + var path KeyPath + for i, childIndex := range id.path { + setBitsInKeyPath(&path, i*6, childIndex) + } + // Remaining bits are already zero. + return path +} + +// MaxKeyPath returns the maximum key path that could land in this page. +func (id PageID) MaxKeyPath() KeyPath { + var path KeyPath + // Fill all with 1s first. + for i := range path { + path[i] = 0xFF + } + // Set the prefix bits from the page path. + for i, childIndex := range id.path { + setBitsInKeyPath(&path, i*6, childIndex) + } + return path +} + +// setBitsInKeyPath writes a 6-bit child index into the key path at the given +// bit offset. +func setBitsInKeyPath(path *KeyPath, bitOffset int, childIndex uint8) { + for b := 0; b < 6; b++ { + bit := (childIndex >> (5 - b)) & 1 + byteIdx := (bitOffset + b) / 8 + bitIdx := 7 - ((bitOffset + b) % 8) + if bit == 1 { + path[byteIdx] |= 1 << bitIdx + } else { + path[byteIdx] &^= 1 << bitIdx + } + } +} + +// PageIDsForKeyPath returns the sequence of PageIDs from root down to the +// deepest page containing the given key path. +func PageIDsForKeyPath(keyPath KeyPath) []PageID { + ids := make([]PageID, 0, MaxPageDepth+1) + current := RootPageID() + ids = append(ids, current) + + for depth := 0; depth < MaxPageDepth; depth++ { + bitStart := depth * 6 + childIndex := extractChildIndex(keyPath, bitStart) + child, err := current.ChildPageID(childIndex) + if err != nil { + break + } + ids = append(ids, child) + current = child + } + return ids +} + +// extractChildIndex extracts a 6-bit child index from the key path at the +// given bit offset. +func extractChildIndex(keyPath KeyPath, bitOffset int) uint8 { + var idx uint8 + for b := 0; b < 6; b++ { + byteIdx := (bitOffset + b) / 8 + bitIdx := 7 - ((bitOffset + b) % 8) + bit := (keyPath[byteIdx] >> bitIdx) & 1 + idx = (idx << 1) | bit + } + return idx +} diff --git a/nomt/core/pageid_test.go b/nomt/core/pageid_test.go new file mode 100644 index 0000000000..c22d86f82a --- /dev/null +++ b/nomt/core/pageid_test.go @@ -0,0 +1,243 @@ +package core + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRootPageIDEncodeDecode(t *testing.T) { + root := RootPageID() + encoded := root.Encode() + assert.Equal(t, [32]byte{}, encoded, "root encodes to all zeros") + + decoded, err := DecodePageID(encoded) + require.NoError(t, err) + assert.True(t, decoded.IsRoot()) + assert.Equal(t, 0, decoded.Depth()) +} + +func TestPageIDEncodeDecodeRoundTrip(t *testing.T) { + tests := []struct { + name string + path []uint8 + }{ + {"root", nil}, + {"child 0", []uint8{0}}, + {"child 6", []uint8{6}}, + {"child 63", []uint8{63}}, + {"depth 2", []uint8{6, 4}}, + {"depth 3", []uint8{6, 4, 63}}, + {"depth 9 (u64 boundary)", []uint8{1, 2, 3, 4, 5, 6, 7, 8, 9}}, + {"depth 10 (big.Int)", []uint8{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}}, + {"all zeros depth 5", []uint8{0, 0, 0, 0, 0}}, + {"all 63s depth 5", []uint8{63, 63, 63, 63, 63}}, + {"mixed deep", []uint8{0, 63, 0, 63, 0, 63, 0, 63, 0, 63, 0}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + id := NewPageID(tt.path) + encoded := id.Encode() + decoded, err := DecodePageID(encoded) + require.NoError(t, err) + assert.True(t, id.Equal(decoded), + "path=%v decoded=%v encoded=%x", + tt.path, decoded.Path(), encoded) + }) + } +} + +func TestPageIDKnownValues(t *testing.T) { + // Shift-then-add encoding: + // encode([6]) = 0*64 + (6+1) = 7 + id1 := NewPageID([]uint8{6}) + enc1 := id1.Encode() + assert.Equal(t, byte(7), enc1[31]) + for i := 0; i < 31; i++ { + assert.Equal(t, byte(0), enc1[i]) + } + + // encode([6, 4]) = 7*64 + (4+1) = 448 + 5 = 453 = 0x01C5 + id2 := NewPageID([]uint8{6, 4}) + enc2 := id2.Encode() + assert.Equal(t, byte(0x01), enc2[30]) + assert.Equal(t, byte(0xC5), enc2[31]) + + // encode([6, 4, 63]) = 453*64 + (63+1) = 28992 + 64 = 29056 = 0x7180 + id3 := NewPageID([]uint8{6, 4, 63}) + enc3 := id3.Encode() + assert.Equal(t, byte(0x71), enc3[30]) + assert.Equal(t, byte(0x80), enc3[31]) +} + +func TestChildAndParentPageID(t *testing.T) { + root := RootPageID() + + page1, err := root.ChildPageID(6) + require.NoError(t, err) + assert.Equal(t, []uint8{6}, page1.Path()) + assert.True(t, page1.ParentPageID().Equal(root)) + + page2, err := page1.ChildPageID(4) + require.NoError(t, err) + assert.Equal(t, []uint8{6, 4}, page2.Path()) + assert.True(t, page2.ParentPageID().Equal(page1)) + + page3, err := page2.ChildPageID(63) + require.NoError(t, err) + assert.Equal(t, []uint8{6, 4, 63}, page3.Path()) + assert.True(t, page3.ParentPageID().Equal(page2)) + + // Verify encode/decode matches child construction. + decoded1, err := DecodePageID(page1.Encode()) + require.NoError(t, err) + assert.True(t, page1.Equal(decoded1)) + + decoded2, err := DecodePageID(page2.Encode()) + require.NoError(t, err) + assert.True(t, page2.Equal(decoded2)) +} + +func TestPageIDOverflow(t *testing.T) { + current := RootPageID() + for range MaxPageDepth { + var err error + current, err = current.ChildPageID(0) + require.NoError(t, err) + } + assert.Equal(t, MaxPageDepth, current.Depth()) + + _, err := current.ChildPageID(0) + assert.ErrorIs(t, err, ErrPageIDOverflow) +} + +func TestPageIDOverflowMaxChild(t *testing.T) { + current := RootPageID() + for range MaxPageDepth { + var err error + current, err = current.ChildPageID(63) + require.NoError(t, err) + } + _, err := current.ChildPageID(0) + assert.ErrorIs(t, err, ErrPageIDOverflow) +} + +func TestInvalidPageIDBytes(t *testing.T) { + var bytes [32]byte + bytes[0] = 128 // bit 255 set + _, err := DecodePageID(bytes) + assert.ErrorIs(t, err, ErrInvalidPageIDBytes) +} + +func TestPageIDSiblingOrdering(t *testing.T) { + root := RootPageID() + var lastEnc [32]byte + for i := range uint8(NumChildren) { + child, err := root.ChildPageID(i) + require.NoError(t, err) + enc := child.Encode() + + assert.NotEqual(t, [32]byte{}, enc) + + if i > 0 { + assert.True(t, compareBE(enc, lastEnc) > 0, + "child %d should sort after child %d", i, i-1) + } + lastEnc = enc + } +} + +func TestPageIDIsDescendantOf(t *testing.T) { + root := RootPageID() + child, _ := root.ChildPageID(5) + grandchild, _ := child.ChildPageID(10) + + assert.True(t, child.IsDescendantOf(root)) + assert.True(t, grandchild.IsDescendantOf(root)) + assert.True(t, grandchild.IsDescendantOf(child)) + assert.False(t, root.IsDescendantOf(child)) + assert.True(t, root.IsDescendantOf(root)) +} + +func TestRootMinMaxKeyPath(t *testing.T) { + root := RootPageID() + assert.Equal(t, [32]byte{}, root.MinKeyPath()) + + var allOnes [32]byte + for i := range allOnes { + allOnes[i] = 0xFF + } + assert.Equal(t, allOnes, root.MaxKeyPath()) +} + +func TestPageMinMaxKeyPath(t *testing.T) { + root := RootPageID() + + // Child 0: first 6 bits are 000000, so min key is all zeros. + minPage, _ := root.ChildPageID(0) + assert.Equal(t, [32]byte{}, minPage.MinKeyPath()) + + // Child 63: first 6 bits are 111111 → 0xFC in first byte. + maxPage, _ := root.ChildPageID(63) + minKey := maxPage.MinKeyPath() + assert.Equal(t, byte(0xFC), minKey[0]) + for i := 1; i < 32; i++ { + assert.Equal(t, byte(0), minKey[i]) + } + + // Child 0: max key has 000000 prefix then all ones → 0x03 then 0xFF. + maxKey := minPage.MaxKeyPath() + assert.Equal(t, byte(0x03), maxKey[0]) + for i := 1; i < 32; i++ { + assert.Equal(t, byte(0xFF), maxKey[i]) + } +} + +func TestPageIDsForKeyPath(t *testing.T) { + // Key path: first 6 bits = 000001 (=1), next 6 bits = 000010 (=2) + var keyPath KeyPath + keyPath[0] = 0b00000100 // bits: 000001|00... + keyPath[1] = 0b00100000 // bits: ...0010|0000... + + ids := PageIDsForKeyPath(keyPath) + require.True(t, len(ids) >= 3) + + assert.True(t, ids[0].IsRoot()) + assert.Equal(t, []uint8{1}, ids[1].Path()) + assert.Equal(t, []uint8{1, 2}, ids[2].Path()) +} + +func TestMaxDepthEncodeDecodeRoundTrip(t *testing.T) { + // Build a max-depth page with all zeros. + path := make([]uint8, MaxPageDepth) + id := NewPageID(path) + enc := id.Encode() + dec, err := DecodePageID(enc) + require.NoError(t, err) + assert.True(t, id.Equal(dec)) + + // Build a max-depth page with all 63s. + for i := range path { + path[i] = 63 + } + id = NewPageID(path) + enc = id.Encode() + dec, err = DecodePageID(enc) + require.NoError(t, err) + assert.True(t, id.Equal(dec)) +} + +// compareBE compares two [32]byte big-endian values. +func compareBE(a, b [32]byte) int { + for i := range 32 { + if a[i] < b[i] { + return -1 + } + if a[i] > b[i] { + return 1 + } + } + return 0 +} diff --git a/nomt/core/triepos.go b/nomt/core/triepos.go new file mode 100644 index 0000000000..4e2eddf082 --- /dev/null +++ b/nomt/core/triepos.go @@ -0,0 +1,253 @@ +package core + +// TriePosition tracks a position within the paged binary trie, combining a +// key path prefix with a node index within the current page. +type TriePosition struct { + path [32]byte + depth uint16 + nodeIndex int +} + +// NewTriePosition creates a TriePosition at the root. +func NewTriePosition() TriePosition { + return TriePosition{} +} + +// TriePositionFromPathAndDepth creates a TriePosition at the given depth +// within the path. Panics if depth is 0. +func TriePositionFromPathAndDepth(path KeyPath, depth uint16) TriePosition { + if depth == 0 { + panic("triepos: depth must be non-zero") + } + if depth > 256 { + panic("triepos: depth out of range") + } + pagePath := lastPagePath(path[:], depth) + return TriePosition{ + path: path, + depth: depth, + nodeIndex: computeNodeIndex(pagePath), + } +} + +// IsRoot reports whether the position is at the root. +func (p *TriePosition) IsRoot() bool { + return p.depth == 0 +} + +// Depth returns the current depth in the trie. +func (p *TriePosition) Depth() uint16 { + return p.depth +} + +// Path returns the raw 32-byte path. +func (p *TriePosition) Path() [32]byte { + return p.path +} + +// Bit returns the bit at position n in the path (0 = MSB of byte 0). +func (p *TriePosition) Bit(n int) bool { + return (p.path[n/8]>>(7-n%8))&1 == 1 +} + +// NodeIndex returns the index of the current node within its page. +func (p *TriePosition) NodeIndex() int { + return p.nodeIndex +} + +// Down moves the position down by 1 bit (left if bit=false, right if bit=true). +func (p *TriePosition) Down(bit bool) { + if p.depth == 256 { + panic("triepos: can't descend past 256 bits") + } + if int(p.depth)%PageDepth == 0 { + // Entering a new page: node index resets. + if bit { + p.nodeIndex = 1 + } else { + p.nodeIndex = 0 + } + } else { + left := p.nodeIndex*2 + 2 + if bit { + p.nodeIndex = left + 1 + } else { + p.nodeIndex = left + } + } + setBit(&p.path, int(p.depth), bit) + p.depth++ +} + +// Up moves the position up by d bits. +func (p *TriePosition) Up(d uint16) { + if d > p.depth { + panic("triepos: can't move up past root") + } + newDepth := p.depth - d + if newDepth == 0 { + *p = NewTriePosition() + return + } + + prevPageDepth := (int(p.depth) + PageDepth - 1) / PageDepth + newPageDepth := (int(newDepth) + PageDepth - 1) / PageDepth + p.depth = newDepth + + if prevPageDepth == newPageDepth { + // Same page — walk up parent indices. + for range d { + p.nodeIndex = (p.nodeIndex - 2) / 2 + } + } else { + // Crossed a page boundary — recompute. + pagePath := lastPagePath(p.path[:], p.depth) + p.nodeIndex = computeNodeIndex(pagePath) + } +} + +// Sibling moves to the sibling node. Panics at root. +func (p *TriePosition) Sibling() { + if p.depth == 0 { + panic("triepos: can't sibling at root") + } + i := int(p.depth) - 1 + flipBit(&p.path, i) + p.nodeIndex = siblingIndex(p.nodeIndex) +} + +// PeekLastBit returns the last bit of the path. Panics at root. +func (p *TriePosition) PeekLastBit() bool { + if p.depth == 0 { + panic("triepos: can't peek at root") + } + return p.Bit(int(p.depth) - 1) +} + +// PageID returns the PageID for the page this position lands in. +// Returns nil if at the root. +func (p *TriePosition) PageID() *PageID { + if p.IsRoot() { + return nil + } + + pageID := RootPageID() + d := int(p.depth) + // Number of complete 6-bit chunks before the current partial chunk. + fullChunks := (d - 1) / PageDepth + for i := range fullChunks { + childIndex := extractChildIndex(p.path, i*PageDepth) + pageID, _ = pageID.ChildPageID(childIndex) + } + return &pageID +} + +// DepthInPage returns the number of bits traversed in the current page (1-6), +// or 0 if at the root. +func (p *TriePosition) DepthInPage() int { + if p.depth == 0 { + return 0 + } + d := int(p.depth) + return d - ((d-1)/PageDepth)*PageDepth +} + +// IsFirstLayerInPage reports whether this position is at the top of its page +// (node index 0 or 1). +func (p *TriePosition) IsFirstLayerInPage() bool { + return p.nodeIndex&^1 == 0 +} + +// ChildNodeIndices returns the left and right child node indices within the +// page. Panics if not at depth 1-5 within the page. +func (p *TriePosition) ChildNodeIndices() (left, right int) { + dip := p.DepthInPage() + if dip == 0 || dip > PageDepth-1 { + panic("triepos: child indices out of bounds") + } + left = p.nodeIndex*2 + 2 + right = left + 1 + return +} + +// ChildPageIndex returns the ChildPageIndex for the current node. +// Panics if not at the last layer of the page (indices 62-125). +func (p *TriePosition) ChildPageIndex() uint8 { + if p.nodeIndex < 62 { + panic("triepos: not at last layer") + } + return uint8(p.nodeIndex - 62) +} + +// SiblingIndex returns the index of the sibling node. +func (p *TriePosition) SiblingIndex() int { + return siblingIndex(p.nodeIndex) +} + +// --- internal helpers --- + +// computeNodeIndex converts a page-local bit path to a level-order node index. +// Formula: (2^depth - 2) + bits_as_uint, where depth is 1-6. +func computeNodeIndex(pagePath pageBitPath) int { + depth := pagePath.len + if depth == 0 { + return 0 + } + if depth > PageDepth { + depth = PageDepth + } + return (1 << depth) - 2 + pagePath.asUint(depth) +} + +// pageBitPath represents a sub-slice of bits within a path for node indexing. +type pageBitPath struct { + path []byte // the full 32-byte key path + offset int // bit offset where this page's path starts + len int // number of bits (1-6) +} + +// asUint interprets the first `n` bits as an unsigned integer (MSB first). +func (p pageBitPath) asUint(n int) int { + var val int + for i := range n { + byteIdx := (p.offset + i) / 8 + bitIdx := 7 - (p.offset+i)%8 + bit := int((p.path[byteIdx] >> bitIdx) & 1) + val = (val << 1) | bit + } + return val +} + +// lastPagePath extracts the relevant bit path for the current page. +func lastPagePath(path []byte, depth uint16) pageBitPath { + d := int(depth) + prevPageEnd := ((d - 1) / PageDepth) * PageDepth + return pageBitPath{ + path: path, + offset: prevPageEnd, + len: d - prevPageEnd, + } +} + +func setBit(path *[32]byte, idx int, val bool) { + byteIdx := idx / 8 + bitIdx := uint(7 - idx%8) + if val { + path[byteIdx] |= 1 << bitIdx + } else { + path[byteIdx] &^= 1 << bitIdx + } +} + +func flipBit(path *[32]byte, idx int) { + byteIdx := idx / 8 + bitIdx := uint(7 - idx%8) + path[byteIdx] ^= 1 << bitIdx +} + +func siblingIndex(nodeIndex int) int { + if nodeIndex%2 == 0 { + return nodeIndex + 1 + } + return nodeIndex - 1 +} diff --git a/nomt/core/triepos_test.go b/nomt/core/triepos_test.go new file mode 100644 index 0000000000..37800f0fe5 --- /dev/null +++ b/nomt/core/triepos_test.go @@ -0,0 +1,170 @@ +package core + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTriePositionNew(t *testing.T) { + p := NewTriePosition() + assert.True(t, p.IsRoot()) + assert.Equal(t, uint16(0), p.Depth()) + assert.Equal(t, 0, p.NodeIndex()) +} + +func TestTriePositionDown(t *testing.T) { + p := NewTriePosition() + + // Go left: node index 0 (first in new page). + p.Down(false) + assert.Equal(t, uint16(1), p.Depth()) + assert.Equal(t, 0, p.NodeIndex()) + assert.False(t, p.PeekLastBit()) + + // Go right: node index 0*2+2+1 = 3 (right child of node 0). + p.Down(true) + assert.Equal(t, uint16(2), p.Depth()) + assert.Equal(t, 3, p.NodeIndex()) + assert.True(t, p.PeekLastBit()) +} + +func TestTriePositionDownPageBoundary(t *testing.T) { + p := NewTriePosition() + // Descend 6 levels (fills one page). + for range 6 { + p.Down(false) + } + assert.Equal(t, uint16(6), p.Depth()) + // At depth 6: DepthInPage = 6 - ((6-1)/6)*6 = 6-0 = 6. + // This is the last layer (bottom) of the first page. + assert.Equal(t, 6, p.DepthInPage()) + + // Going one more enters a new page. + p.Down(false) + assert.Equal(t, uint16(7), p.Depth()) + assert.Equal(t, 1, p.DepthInPage()) + assert.Equal(t, 0, p.NodeIndex()) // first node in new page +} + +func TestTriePositionNodeIndex(t *testing.T) { + // Manual verification of node_index formula. + tests := []struct { + name string + bits []bool // bits to descend + expected int // expected node index + }{ + {"left", []bool{false}, 0}, + {"right", []bool{true}, 1}, + {"left-left", []bool{false, false}, 2}, + {"left-right", []bool{false, true}, 3}, + {"right-left", []bool{true, false}, 4}, + {"right-right", []bool{true, true}, 5}, + {"3 deep all left", []bool{false, false, false}, 6}, + {"3 deep LLR", []bool{false, false, true}, 7}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := NewTriePosition() + for _, bit := range tt.bits { + p.Down(bit) + } + assert.Equal(t, tt.expected, p.NodeIndex()) + }) + } +} + +func TestTriePositionUp(t *testing.T) { + p := NewTriePosition() + p.Down(false) // depth 1, index 0 + p.Down(true) // depth 2, index 3 + p.Down(false) // depth 3, index 8 + + p.Up(1) // back to depth 2, index 3 + assert.Equal(t, uint16(2), p.Depth()) + assert.Equal(t, 3, p.NodeIndex()) + + p.Up(2) // back to root + assert.True(t, p.IsRoot()) +} + +func TestTriePositionSibling(t *testing.T) { + p := NewTriePosition() + p.Down(false) // left child, index 0 + assert.Equal(t, 0, p.NodeIndex()) + assert.False(t, p.PeekLastBit()) + + p.Sibling() // now right child, index 1 + assert.Equal(t, 1, p.NodeIndex()) + assert.True(t, p.PeekLastBit()) + + p.Sibling() // back to left + assert.Equal(t, 0, p.NodeIndex()) +} + +func TestTriePositionDepthInPage(t *testing.T) { + tests := []struct { + depth uint16 + expected int + }{ + {0, 0}, + {1, 1}, + {6, 6}, + {7, 1}, // New page starts at depth 7. + {12, 6}, + {13, 1}, + } + + for _, tt := range tests { + p := NewTriePosition() + for range tt.depth { + p.Down(false) + } + assert.Equal(t, tt.expected, p.DepthInPage(), + "depth=%d", tt.depth) + } +} + +func TestTriePositionPageID(t *testing.T) { + p := NewTriePosition() + assert.Nil(t, p.PageID(), "root has no page ID") + + // Descend 1: still in root page. + p.Down(false) + pageID := p.PageID() + require.NotNil(t, pageID) + assert.True(t, pageID.IsRoot()) + + // Descend 6 more (total depth 7): now in child page. + for range 6 { + p.Down(false) + } + pageID = p.PageID() + require.NotNil(t, pageID) + assert.Equal(t, 1, pageID.Depth()) + assert.Equal(t, uint8(0), pageID.ChildIndexAt(0)) +} + +func TestTriePositionChildPageIndex(t *testing.T) { + p := NewTriePosition() + // Descend to depth 6 (bottom of root page) → all left. + for range 6 { + p.Down(false) + } + // At depth 6, node_index should be 62 (first of bottom layer). + assert.Equal(t, 62, p.NodeIndex()) + assert.Equal(t, uint8(0), p.ChildPageIndex()) +} + +func TestTriePositionMax255Depth(t *testing.T) { + p := NewTriePosition() + for range 255 { + p.Down(true) + } + assert.Equal(t, uint16(255), p.Depth()) + // One more descent should work (to 256). + p.Down(false) + assert.Equal(t, uint16(256), p.Depth()) +} diff --git a/nomt/core/update.go b/nomt/core/update.go new file mode 100644 index 0000000000..28911ab70c --- /dev/null +++ b/nomt/core/update.go @@ -0,0 +1,275 @@ +package core + +import "sort" + +// LeafOp represents a leaf operation: set or delete. +// A nil ValueHash pointer means delete. +type LeafOp struct { + Key KeyPath + Value *ValueHash +} + +// KeyValue is a resolved (key, value) pair for trie building. +type KeyValue struct { + Key KeyPath + Value ValueHash +} + +// WriteNodeKind enumerates the types of write commands from BuildTrie. +type WriteNodeKind int + +const ( + WriteNodeLeaf WriteNodeKind = iota + WriteNodeInternal + WriteNodeTerminator +) + +// WriteNode represents a node to be written during trie building. +type WriteNode struct { + Kind WriteNodeKind + Node Node + LeafData *LeafData // set for leaf writes + InternalData *InternalData // set for internal writes + + // Navigation: move up 1 before writing (true for internal nodes and + // non-first leaves). + GoUp bool + // Navigation: bits to descend after going up (only for leaf writes). + DownBits []bool +} + +// SharedBits counts the number of shared prefix bits between two key paths, +// starting after `skip` bits. +func SharedBits(a, b *KeyPath, skip int) int { + count := 0 + for i := skip; i < 256; i++ { + aBit := (a[i/8] >> (7 - i%8)) & 1 + bBit := (b[i/8] >> (7 - i%8)) & 1 + if aBit != bBit { + break + } + count++ + } + return count +} + +// LeafOpsSpliced creates a combined operation list from an existing leaf and +// new operations. If the existing leaf's key is not in ops, it is spliced in. +// Deletions (nil value) are filtered out. +func LeafOpsSpliced(existingLeaf *LeafData, ops []LeafOp) []KeyValue { + // Find splice position: where the existing leaf would be inserted. + spliceIndex := -1 + if existingLeaf != nil { + idx := sort.Search(len(ops), func(i int) bool { + return keyPathCmp(&ops[i].Key, &existingLeaf.KeyPath) >= 0 + }) + if idx >= len(ops) || ops[idx].Key != existingLeaf.KeyPath { + spliceIndex = idx + } + } + + result := make([]KeyValue, 0, len(ops)+1) + + if spliceIndex < 0 { + // No splicing needed — just filter out deletes. + for _, op := range ops { + if op.Value != nil { + result = append(result, KeyValue{op.Key, *op.Value}) + } + } + return result + } + + // Before splice point. + for _, op := range ops[:spliceIndex] { + if op.Value != nil { + result = append(result, KeyValue{op.Key, *op.Value}) + } + } + + // The existing leaf. + result = append(result, KeyValue{ + existingLeaf.KeyPath, + existingLeaf.ValueHash, + }) + + // After splice point. + for _, op := range ops[spliceIndex:] { + if op.Value != nil { + result = append(result, KeyValue{op.Key, *op.Value}) + } + } + + return result +} + +// BuildTrie builds a compact sub-trie from sorted (key, value) pairs. +// +// skip: the number of prefix bits already consumed (all ops share this prefix). +// ops: sorted (KeyPath, ValueHash) pairs. +// visit: callback invoked for each computed node, bottom-up. +// +// Returns the root node of the built sub-trie. +// +// The algorithm uses a 3-pointer sliding window (a, b, c) over the sorted +// ops to determine each leaf's depth based on shared bits with its neighbors. +// Internal nodes are computed by hashing up a left-frontier stack. +func BuildTrie(skip int, ops []KeyValue, visit func(WriteNode)) Node { + if len(ops) == 0 { + visit(WriteNode{Kind: WriteNodeTerminator, Node: Terminator}) + return Terminator + } + + if len(ops) == 1 { + ld := LeafData{ + KeyPath: ops[0].Key, + ValueHash: ops[0].Value, + } + h := HashLeaf(&ld) + visit(WriteNode{ + Kind: WriteNodeLeaf, + Node: h, + LeafData: &ld, + GoUp: false, + }) + return h + } + + // 3-pointer left-frontier algorithm. + type pendingSibling struct { + node Node + layer int + } + pendingSiblings := make([]pendingSibling, 0, 16) + + commonAfterPrefix := func(k1, k2 *KeyPath) int { + return SharedBits(k1, k2, skip) + } + + // Sliding window: a, b, c. + var aKey *KeyPath + var aVal *ValueHash + + for bIdx := 0; bIdx < len(ops); bIdx++ { + thisKey := &ops[bIdx].Key + thisVal := &ops[bIdx].Value + + var n1 *int + if aKey != nil { + v := commonAfterPrefix(aKey, thisKey) + n1 = &v + } + + var n2 *int + if bIdx+1 < len(ops) { + v := commonAfterPrefix(&ops[bIdx+1].Key, thisKey) + n2 = &v + } + + ld := LeafData{KeyPath: *thisKey, ValueHash: *thisVal} + leaf := HashLeaf(&ld) + + var leafDepth, hashUpLayers int + switch { + case n1 == nil && n2 == nil: + leafDepth = 0 + hashUpLayers = 0 + case n1 == nil && n2 != nil: + leafDepth = *n2 + 1 + hashUpLayers = 0 + case n1 != nil && n2 == nil: + leafDepth = *n1 + 1 + hashUpLayers = *n1 + 1 + default: + leafDepth = max(*n1, *n2) + 1 + hashUpLayers = 0 + if *n1 > *n2 { + hashUpLayers = *n1 - *n2 + } + } + + layer := leafDepth + lastNode := leaf + + // Compute down bits for the visitor. + downStart := skip + if n1 != nil { + downStart = skip + *n1 + } + leafEndBit := skip + leafDepth + + var downBits []bool + if leafEndBit > downStart { + downBits = make([]bool, leafEndBit-downStart) + for i := downStart; i < leafEndBit; i++ { + downBits[i-downStart] = bitAt(thisKey, i) + } + } + + visit(WriteNode{ + Kind: WriteNodeLeaf, + Node: leaf, + LeafData: &ld, + GoUp: n1 != nil, + DownBits: downBits, + }) + + // Hash upward. + for h := 0; h < hashUpLayers; h++ { + layer-- + bitIdx := skip + layer // the bit at this layer + bit := bitAt(thisKey, bitIdx) + + // Pop sibling from pending if it matches. + var sibling Node + if len(pendingSiblings) > 0 && + pendingSiblings[len(pendingSiblings)-1].layer == layer+1 { + sibling = pendingSiblings[len(pendingSiblings)-1].node + pendingSiblings = pendingSiblings[:len(pendingSiblings)-1] + } + + var id InternalData + if bit { + id = InternalData{Left: sibling, Right: lastNode} + } else { + id = InternalData{Left: lastNode, Right: sibling} + } + + lastNode = HashInternal(&id) + visit(WriteNode{ + Kind: WriteNodeInternal, + Node: lastNode, + InternalData: &id, + GoUp: true, + }) + } + + pendingSiblings = append(pendingSiblings, + pendingSibling{node: lastNode, layer: layer}) + + aKey = thisKey + aVal = thisVal + } + _ = aVal // used in the loop to track state + + if len(pendingSiblings) > 0 { + return pendingSiblings[len(pendingSiblings)-1].node + } + return Terminator +} + +func bitAt(key *KeyPath, idx int) bool { + return (key[idx/8]>>(7-idx%8))&1 == 1 +} + +func keyPathCmp(a, b *KeyPath) int { + for i := range a { + if a[i] < b[i] { + return -1 + } + if a[i] > b[i] { + return 1 + } + } + return 0 +} diff --git a/nomt/core/update_test.go b/nomt/core/update_test.go new file mode 100644 index 0000000000..a6725b5fc1 --- /dev/null +++ b/nomt/core/update_test.go @@ -0,0 +1,238 @@ +package core + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSharedBits(t *testing.T) { + tests := []struct { + name string + a, b KeyPath + skip int + expected int + }{ + {"identical", KeyPath{0xFF}, KeyPath{0xFF}, 0, 256}, + {"differ at bit 0", KeyPath{0x80}, KeyPath{0x00}, 0, 0}, + {"share 4 bits", KeyPath{0xF0}, KeyPath{0xF8}, 0, 4}, + {"with skip", KeyPath{0xF0}, KeyPath{0xF8}, 4, 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := SharedBits(&tt.a, &tt.b, tt.skip) + assert.Equal(t, tt.expected, got) + }) + } +} + +// makeKV creates a (key, value) pair where both key and value are filled with b. +func makeKV(b byte) KeyValue { + var key KeyPath + var val ValueHash + for i := range key { + key[i] = b + } + for i := range val { + val[i] = b + } + return KeyValue{Key: key, Value: val} +} + +func TestBuildTrieEmpty(t *testing.T) { + var visited []WriteNode + root := BuildTrie(0, nil, func(wn WriteNode) { + visited = append(visited, wn) + }) + + require.Len(t, visited, 1) + assert.Equal(t, WriteNodeTerminator, visited[0].Kind) + assert.Equal(t, Terminator, root) +} + +func TestBuildTrieSingleLeaf(t *testing.T) { + kv := makeKV(0xFF) + var visited []WriteNode + root := BuildTrie(0, []KeyValue{kv}, func(wn WriteNode) { + visited = append(visited, wn) + }) + + require.Len(t, visited, 1) + assert.Equal(t, WriteNodeLeaf, visited[0].Kind) + assert.False(t, visited[0].GoUp) + assert.True(t, IsLeaf(&root)) + + expected := HashLeaf(&LeafData{ + KeyPath: kv.Key, + ValueHash: kv.Value, + }) + assert.Equal(t, expected, root) +} + +func TestBuildTrieTwoLeaves(t *testing.T) { + // Keys: 0x00... and 0xFF... differ at bit 0. + kv0 := makeKV(0x00) + kvF := makeKV(0xFF) + + var visited []WriteNode + root := BuildTrie(0, []KeyValue{kv0, kvF}, func(wn WriteNode) { + visited = append(visited, wn) + }) + + // Should visit: leaf_0, leaf_F, internal(leaf_0, leaf_F). + require.Len(t, visited, 3) + assert.Equal(t, WriteNodeLeaf, visited[0].Kind) + assert.Equal(t, WriteNodeLeaf, visited[1].Kind) + assert.Equal(t, WriteNodeInternal, visited[2].Kind) + + leaf0 := HashLeaf(&LeafData{KeyPath: kv0.Key, ValueHash: kv0.Value}) + leafF := HashLeaf(&LeafData{KeyPath: kvF.Key, ValueHash: kvF.Value}) + expected := HashInternal(&InternalData{Left: leaf0, Right: leafF}) + + assert.Equal(t, expected, root) + assert.True(t, IsInternal(&root)) +} + +func TestBuildTrieThreeLeaves(t *testing.T) { + // Three keys sharing common prefixes. + // 0b00010001... = 0x11 + // 0b00010010... = 0x12 + // 0b00010100... = 0x14 + kv1 := makeKV(0x11) + kv2 := makeKV(0x12) + kv3 := makeKV(0x14) + + var visited []WriteNode + root := BuildTrie(0, []KeyValue{kv1, kv2, kv3}, func(wn WriteNode) { + visited = append(visited, wn) + }) + + assert.True(t, IsInternal(&root) || IsLeaf(&root), + "root should be non-terminator") + + // Verify determinism. + var visited2 []WriteNode + root2 := BuildTrie(0, []KeyValue{kv1, kv2, kv3}, func(wn WriteNode) { + visited2 = append(visited2, wn) + }) + assert.Equal(t, root, root2) + assert.Equal(t, len(visited), len(visited2)) +} + +func TestBuildTrieWithSkip(t *testing.T) { + // Keys all share prefix 0001 (4 bits): 0x11, 0x12, 0x14. + kv1 := makeKV(0x11) + kv2 := makeKV(0x12) + kv3 := makeKV(0x14) + + var visited []WriteNode + root := BuildTrie(4, []KeyValue{kv1, kv2, kv3}, func(wn WriteNode) { + visited = append(visited, wn) + }) + + // Should produce a non-trivial sub-trie. + assert.False(t, IsTerminator(&root)) + assert.True(t, len(visited) >= 3) +} + +func TestBuildTrieMultiValue(t *testing.T) { + // Matches the Rust multi_value test pattern. + // 0b00010000 = 0x10 + // 0b00100000 = 0x20 + // 0b01000000 = 0x40 + // 0b10100000 = 0xA0 + // 0b10110000 = 0xB0 + kvA := makeKV(0x10) + kvB := makeKV(0x20) + kvC := makeKV(0x40) + kvD := makeKV(0xA0) + kvE := makeKV(0xB0) + + var nodes []Node + root := BuildTrie(0, []KeyValue{kvA, kvB, kvC, kvD, kvE}, + func(wn WriteNode) { + nodes = append(nodes, wn.Node) + }) + + // Manually verify the trie structure. + leafA := HashLeaf(&LeafData{KeyPath: kvA.Key, ValueHash: kvA.Value}) + leafB := HashLeaf(&LeafData{KeyPath: kvB.Key, ValueHash: kvB.Value}) + leafC := HashLeaf(&LeafData{KeyPath: kvC.Key, ValueHash: kvC.Value}) + leafD := HashLeaf(&LeafData{KeyPath: kvD.Key, ValueHash: kvD.Value}) + leafE := HashLeaf(&LeafData{KeyPath: kvE.Key, ValueHash: kvE.Value}) + + branchAB := HashInternal(&InternalData{Left: leafA, Right: leafB}) + branchABC := HashInternal(&InternalData{Left: branchAB, Right: leafC}) + + branchDE1 := HashInternal(&InternalData{Left: leafD, Right: leafE}) + branchDE2 := HashInternal(&InternalData{Left: Terminator, Right: branchDE1}) + branchDE3 := HashInternal(&InternalData{Left: branchDE2, Right: Terminator}) + + expected := HashInternal(&InternalData{Left: branchABC, Right: branchDE3}) + + assert.Equal(t, expected, root) +} + +func TestLeafOpsSplicedNoExisting(t *testing.T) { + val := ValueHash{0x01} + ops := []LeafOp{ + {Key: KeyPath{0x10}, Value: &val}, + {Key: KeyPath{0x20}, Value: &val}, + } + + result := LeafOpsSpliced(nil, ops) + assert.Len(t, result, 2) +} + +func TestLeafOpsSplicedWithExistingLeaf(t *testing.T) { + val := ValueHash{0x01} + ops := []LeafOp{ + {Key: KeyPath{0x10}, Value: &val}, + {Key: KeyPath{0x30}, Value: &val}, + } + + existing := &LeafData{ + KeyPath: KeyPath{0x20}, + ValueHash: ValueHash{0x02}, + } + + result := LeafOpsSpliced(existing, ops) + assert.Len(t, result, 3) + assert.Equal(t, KeyPath{0x10}, result[0].Key) + assert.Equal(t, KeyPath{0x20}, result[1].Key) + assert.Equal(t, KeyPath{0x30}, result[2].Key) +} + +func TestLeafOpsSplicedDeleteFiltered(t *testing.T) { + val := ValueHash{0x01} + ops := []LeafOp{ + {Key: KeyPath{0x10}, Value: &val}, + {Key: KeyPath{0x20}, Value: nil}, // delete + {Key: KeyPath{0x30}, Value: &val}, + } + + result := LeafOpsSpliced(nil, ops) + assert.Len(t, result, 2) + assert.Equal(t, KeyPath{0x10}, result[0].Key) + assert.Equal(t, KeyPath{0x30}, result[1].Key) +} + +func TestLeafOpsSplicedExistingKeyInOps(t *testing.T) { + val := ValueHash{0x01} + newVal := ValueHash{0x99} + ops := []LeafOp{ + {Key: KeyPath{0x20}, Value: &newVal}, + } + + existing := &LeafData{ + KeyPath: KeyPath{0x20}, + ValueHash: val, + } + + // The existing leaf should NOT be spliced because its key is in ops. + result := LeafOpsSpliced(existing, ops) + assert.Len(t, result, 1) + assert.Equal(t, newVal, result[0].Value) +}