trie/bintrie: group 2^N internal nodes into a single serialization unit

This commit is contained in:
Guillaume Ballet 2026-01-21 18:07:20 +01:00
parent 251b863107
commit 2a404c4cc2
No known key found for this signature in database
4 changed files with 610 additions and 39 deletions

View file

@ -36,6 +36,14 @@ const (
NodeTypeBytes = 1 // Size of node type prefix in serialization
HashSize = 32 // Size of a hash in bytes
BitmapSize = 32 // Size of the bitmap in a stem node
// GroupDepth is the number of levels in a grouped subtree serialization.
// Groups are byte-aligned (depth % 8 == 0). This may become configurable later.
// Serialization format for InternalNode groups:
// [1 byte type] [1 byte group depth (1-8)] [32 byte bitmap] [N × 32 byte hashes]
// The bitmap has 2^groupDepth bits, indicating which bottom-layer children are present.
// Only present children's hashes are stored, in order.
GroupDepth = 8
)
const (
@ -57,16 +65,72 @@ type BinaryNode interface {
GetHeight() int
}
// serializeSubtree recursively collects child hashes from a subtree of InternalNodes.
// It traverses up to `remainingDepth` levels, storing hashes of bottom-layer children.
// position tracks the current index (0 to 2^groupDepth - 1) for bitmap placement.
// hashes collects the hashes of present children, bitmap tracks which positions are present.
func serializeSubtree(node BinaryNode, remainingDepth int, position int, bitmap []byte, hashes *[]common.Hash) {
if remainingDepth == 0 {
// Bottom layer: store hash if not empty
switch node.(type) {
case Empty:
// Leave bitmap bit unset, don't add hash
return
default:
// StemNode, HashedNode, or InternalNode at boundary: store hash
bitmap[position/8] |= 1 << (7 - (position % 8))
*hashes = append(*hashes, node.Hash())
}
return
}
switch n := node.(type) {
case *InternalNode:
// Recurse into left (bit 0) and right (bit 1) children
leftPos := position * 2
rightPos := position*2 + 1
serializeSubtree(n.left, remainingDepth-1, leftPos, bitmap, hashes)
serializeSubtree(n.right, remainingDepth-1, rightPos, bitmap, hashes)
case Empty:
// Empty subtree: all positions in this subtree are empty (bits already 0)
return
default:
// StemNode or HashedNode before reaching bottom: store hash at current position
// This creates a variable-depth group where this branch terminates early.
// We need to mark this single position and all its would-be descendants as "this hash".
// For simplicity, we store the hash at the first leaf position of this subtree.
firstLeafPos := position << remainingDepth
bitmap[firstLeafPos/8] |= 1 << (7 - (firstLeafPos % 8))
*hashes = append(*hashes, node.Hash())
}
}
// SerializeNode serializes a binary trie node into a byte slice.
func SerializeNode(node BinaryNode) []byte {
switch n := (node).(type) {
case *InternalNode:
// InternalNode: 1 byte type + 32 bytes left hash + 32 bytes right hash
var serialized [NodeTypeBytes + HashSize + HashSize]byte
// InternalNode group: 1 byte type + 1 byte group depth + 32 byte bitmap + N×32 byte hashes
groupDepth := GroupDepth
var bitmap [BitmapSize]byte
var hashes []common.Hash
serializeSubtree(n, groupDepth, 0, bitmap[:], &hashes)
// Build serialized output
serializedLen := NodeTypeBytes + 1 + BitmapSize + len(hashes)*HashSize
serialized := make([]byte, serializedLen)
serialized[0] = nodeTypeInternal
copy(serialized[1:33], n.left.Hash().Bytes())
copy(serialized[33:65], n.right.Hash().Bytes())
return serialized[:]
serialized[1] = byte(groupDepth)
copy(serialized[2:2+BitmapSize], bitmap[:])
offset := NodeTypeBytes + 1 + BitmapSize
for _, h := range hashes {
copy(serialized[offset:offset+HashSize], h.Bytes())
offset += HashSize
}
return serialized
case *StemNode:
// StemNode: 1 byte type + 31 bytes stem + 32 bytes bitmap + 256*32 bytes values
var serialized [NodeTypeBytes + StemSize + BitmapSize + StemNodeWidth*HashSize]byte
@ -90,6 +154,51 @@ func SerializeNode(node BinaryNode) []byte {
var invalidSerializedLength = errors.New("invalid serialized node length")
// deserializeSubtree reconstructs an InternalNode subtree from grouped serialization.
// remainingDepth is how many more levels to build, position is current index in the bitmap,
// nodeDepth is the actual trie depth for the node being created.
// hashIdx tracks the current position in the hash data (incremented as hashes are consumed).
func deserializeSubtree(remainingDepth int, position int, nodeDepth int, bitmap []byte, hashData []byte, hashIdx *int) (BinaryNode, error) {
if remainingDepth == 0 {
// Bottom layer: check bitmap and return HashedNode or Empty
if bitmap[position/8]>>(7-(position%8))&1 == 1 {
if len(hashData) < (*hashIdx+1)*HashSize {
return nil, invalidSerializedLength
}
hash := common.BytesToHash(hashData[*hashIdx*HashSize : (*hashIdx+1)*HashSize])
*hashIdx++
return HashedNode(hash), nil
}
return Empty{}, nil
}
// Check if this entire subtree is empty by examining all relevant bitmap bits
leftPos := position * 2
rightPos := position*2 + 1
left, err := deserializeSubtree(remainingDepth-1, leftPos, nodeDepth+1, bitmap, hashData, hashIdx)
if err != nil {
return nil, err
}
right, err := deserializeSubtree(remainingDepth-1, rightPos, nodeDepth+1, bitmap, hashData, hashIdx)
if err != nil {
return nil, err
}
// If both children are empty, return Empty
_, leftEmpty := left.(Empty)
_, rightEmpty := right.(Empty)
if leftEmpty && rightEmpty {
return Empty{}, nil
}
return &InternalNode{
depth: nodeDepth,
left: left,
right: right,
}, nil
}
// DeserializeNode deserializes a binary trie node from a byte slice.
func DeserializeNode(serialized []byte, depth int) (BinaryNode, error) {
if len(serialized) == 0 {
@ -98,14 +207,20 @@ func DeserializeNode(serialized []byte, depth int) (BinaryNode, error) {
switch serialized[0] {
case nodeTypeInternal:
if len(serialized) != 65 {
// Grouped format: 1 byte type + 1 byte group depth + 32 byte bitmap + N×32 byte hashes
if len(serialized) < NodeTypeBytes+1+BitmapSize {
return nil, invalidSerializedLength
}
return &InternalNode{
depth: depth,
left: HashedNode(common.BytesToHash(serialized[1:33])),
right: HashedNode(common.BytesToHash(serialized[33:65])),
}, nil
groupDepth := int(serialized[1])
if groupDepth < 1 || groupDepth > GroupDepth {
return nil, errors.New("invalid group depth")
}
bitmap := serialized[2 : 2+BitmapSize]
hashData := serialized[2+BitmapSize:]
// Count present children from bitmap
hashIdx := 0
return deserializeSubtree(groupDepth, 0, depth, bitmap, hashData, &hashIdx)
case nodeTypeStem:
if len(serialized) < 64 {
return nil, invalidSerializedLength

View file

@ -24,13 +24,15 @@ import (
)
// TestSerializeDeserializeInternalNode tests serialization and deserialization of InternalNode
// with the grouped subtree format. A single InternalNode with HashedNode children serializes
// as a depth-8 group where the children appear at their first leaf positions.
func TestSerializeDeserializeInternalNode(t *testing.T) {
// Create an internal node with two hashed children
leftHash := common.HexToHash("0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef")
rightHash := common.HexToHash("0xfedcba0987654321fedcba0987654321fedcba0987654321fedcba0987654321")
node := &InternalNode{
depth: 5,
depth: 0, // Use depth 0 (byte-aligned) for this test
left: HashedNode(leftHash),
right: HashedNode(rightHash),
}
@ -38,42 +40,81 @@ func TestSerializeDeserializeInternalNode(t *testing.T) {
// Serialize the node
serialized := SerializeNode(node)
// Check the serialized format
// Check the serialized format: type byte + group depth byte + 32 byte bitmap + N*32 byte hashes
if serialized[0] != nodeTypeInternal {
t.Errorf("Expected type byte to be %d, got %d", nodeTypeInternal, serialized[0])
}
if len(serialized) != 65 {
t.Errorf("Expected serialized length to be 65, got %d", len(serialized))
if serialized[1] != GroupDepth {
t.Errorf("Expected group depth to be %d, got %d", GroupDepth, serialized[1])
}
// Expected length: 1 (type) + 1 (group depth) + 32 (bitmap) + 2*32 (two hashes) = 98 bytes
expectedLen := NodeTypeBytes + 1 + BitmapSize + 2*HashSize
if len(serialized) != expectedLen {
t.Errorf("Expected serialized length to be %d, got %d", expectedLen, len(serialized))
}
// The left child (HashedNode) terminates at remainingDepth=7, so it's placed at position 0<<7 = 0
// The right child (HashedNode) terminates at remainingDepth=7, so it's placed at position 1<<7 = 128
bitmap := serialized[2 : 2+BitmapSize]
if bitmap[0]&0x80 == 0 { // bit 0 (MSB of byte 0)
t.Error("Expected bit 0 to be set in bitmap (left child)")
}
if bitmap[16]&0x80 == 0 { // bit 128 (MSB of byte 16)
t.Error("Expected bit 128 to be set in bitmap (right child)")
}
// Deserialize the node
deserialized, err := DeserializeNode(serialized, 5)
deserialized, err := DeserializeNode(serialized, 0)
if err != nil {
t.Fatalf("Failed to deserialize node: %v", err)
}
// Check that it's an internal node
// With grouped format, deserialization creates a tree of InternalNodes down to the hashes.
// The root should be an InternalNode, and we should be able to navigate down 8 levels
// to find the HashedNode children.
internalNode, ok := deserialized.(*InternalNode)
if !ok {
t.Fatalf("Expected InternalNode, got %T", deserialized)
}
// Check the depth
if internalNode.depth != 5 {
t.Errorf("Expected depth 5, got %d", internalNode.depth)
if internalNode.depth != 0 {
t.Errorf("Expected depth 0, got %d", internalNode.depth)
}
// Check the left and right hashes
if internalNode.left.Hash() != leftHash {
t.Errorf("Left hash mismatch: expected %x, got %x", leftHash, internalNode.left.Hash())
// Navigate to position 0 (8 left turns) to find the left hash
node0 := navigateToLeaf(internalNode, 0, 8)
if node0.Hash() != leftHash {
t.Errorf("Left hash mismatch: expected %x, got %x", leftHash, node0.Hash())
}
if internalNode.right.Hash() != rightHash {
t.Errorf("Right hash mismatch: expected %x, got %x", rightHash, internalNode.right.Hash())
// Navigate to position 128 (right, then 7 lefts) to find the right hash
node128 := navigateToLeaf(internalNode, 128, 8)
if node128.Hash() != rightHash {
t.Errorf("Right hash mismatch: expected %x, got %x", rightHash, node128.Hash())
}
}
// navigateToLeaf navigates to a specific position in the tree (used by grouped serialization tests)
func navigateToLeaf(node BinaryNode, position, depth int) BinaryNode {
for d := 0; d < depth; d++ {
in, ok := node.(*InternalNode)
if !ok {
return node
}
// Check bit at position (depth-1-d) to determine left or right
bit := (position >> (depth - 1 - d)) & 1
if bit == 0 {
node = in.left
} else {
node = in.right
}
}
return node
}
// TestSerializeDeserializeStemNode tests serialization and deserialization of StemNode
func TestSerializeDeserializeStemNode(t *testing.T) {
// Create a stem node with some values

View file

@ -0,0 +1,364 @@
package bintrie
import (
"fmt"
"testing"
"github.com/ethereum/go-ethereum/common"
)
// TestGroupedSerializationDebug helps understand the grouped serialization format
func TestGroupedSerializationDebug(t *testing.T) {
leftHash := common.HexToHash("0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef")
rightHash := common.HexToHash("0xfedcba0987654321fedcba0987654321fedcba0987654321fedcba0987654321")
node := &InternalNode{
depth: 0,
left: HashedNode(leftHash),
right: HashedNode(rightHash),
}
serialized := SerializeNode(node)
t.Logf("Serialized length: %d", len(serialized))
t.Logf("Type: %d, GroupDepth: %d", serialized[0], serialized[1])
bitmap := serialized[2 : 2+BitmapSize]
t.Logf("Bitmap: %x", bitmap)
// Count and show set bits
for i := 0; i < 256; i++ {
if bitmap[i/8]>>(7-(i%8))&1 == 1 {
t.Logf("Bit %d is set", i)
}
}
// Deserialize
deserialized, err := DeserializeNode(serialized, 0)
if err != nil {
t.Fatalf("Error: %v", err)
}
t.Logf("Deserialized type: %T", deserialized)
// Walk the tree and print structure
printTree(t, deserialized, 0, "root")
}
func printTree(t *testing.T, node BinaryNode, depth int, path string) {
indent := ""
for i := 0; i < depth; i++ {
indent += " "
}
switch n := node.(type) {
case *InternalNode:
t.Logf("%s%s: InternalNode (depth=%d)", indent, path, n.depth)
printTree(t, n.left, depth+1, path+"/L")
printTree(t, n.right, depth+1, path+"/R")
case HashedNode:
t.Logf("%s%s: HashedNode(%x)", indent, path, common.Hash(n))
case Empty:
t.Logf("%s%s: Empty", indent, path)
default:
t.Logf("%s%s: %T", indent, path, node)
}
}
// TestFullDepth8Tree tests a full 8-level tree (all 256 bottom positions filled)
func TestFullDepth8Tree(t *testing.T) {
// Build a full 8-level tree
root := buildFullTree(0, 8)
serialized := SerializeNode(root)
t.Logf("Full tree serialized length: %d", len(serialized))
t.Logf("Expected: 1 + 1 + 32 + 256*32 = %d", 1+1+32+256*32)
// Count set bits in bitmap
bitmap := serialized[2 : 2+BitmapSize]
count := 0
for i := 0; i < 256; i++ {
if bitmap[i/8]>>(7-(i%8))&1 == 1 {
count++
}
}
t.Logf("Set bits in bitmap: %d", count)
// Deserialize and verify structure
deserialized, err := DeserializeNode(serialized, 0)
if err != nil {
t.Fatalf("Error: %v", err)
}
// Verify it's an InternalNode with depth 0
in, ok := deserialized.(*InternalNode)
if !ok {
t.Fatalf("Expected InternalNode, got %T", deserialized)
}
if in.depth != 0 {
t.Errorf("Expected depth 0, got %d", in.depth)
}
// Count leaves at depth 8
leafCount := countLeavesAtDepth(deserialized, 8, 0)
t.Logf("Leaves at depth 8: %d", leafCount)
if leafCount != 256 {
t.Errorf("Expected 256 leaves, got %d", leafCount)
}
}
func buildFullTree(depth, maxDepth int) BinaryNode {
if depth == maxDepth {
// Create a unique hash for this position
var h common.Hash
h[0] = byte(depth)
h[1] = byte(depth >> 8)
return HashedNode(h)
}
return &InternalNode{
depth: depth,
left: buildFullTree(depth+1, maxDepth),
right: buildFullTree(depth+1, maxDepth),
}
}
func countLeavesAtDepth(node BinaryNode, targetDepth, currentDepth int) int {
if currentDepth == targetDepth {
if _, ok := node.(Empty); ok {
return 0
}
return 1
}
in, ok := node.(*InternalNode)
if !ok {
return 0 // Terminated early
}
return countLeavesAtDepth(in.left, targetDepth, currentDepth+1) +
countLeavesAtDepth(in.right, targetDepth, currentDepth+1)
}
// TestRoundTripPreservesHashes tests that round-trip preserves the original hashes
func TestRoundTripPreservesHashes(t *testing.T) {
// Build a tree with known hashes at specific positions
hashes := make([]common.Hash, 256)
for i := range hashes {
hashes[i] = common.BytesToHash([]byte(fmt.Sprintf("hash-%d", i)))
}
root := buildTreeWithHashes(0, 8, 0, hashes)
serialized := SerializeNode(root)
deserialized, err := DeserializeNode(serialized, 0)
if err != nil {
t.Fatalf("Error: %v", err)
}
// Verify each hash at depth 8
for i := 0; i < 256; i++ {
node := navigateToLeaf(deserialized, i, 8)
if node == nil {
t.Errorf("Position %d: node is nil", i)
continue
}
if node.Hash() != hashes[i] {
t.Errorf("Position %d: hash mismatch, expected %x, got %x", i, hashes[i], node.Hash())
}
}
}
func buildTreeWithHashes(depth, maxDepth, position int, hashes []common.Hash) BinaryNode {
if depth == maxDepth {
return HashedNode(hashes[position])
}
return &InternalNode{
depth: depth,
left: buildTreeWithHashes(depth+1, maxDepth, position*2, hashes),
right: buildTreeWithHashes(depth+1, maxDepth, position*2+1, hashes),
}
}
// TestCollectNodesGrouping verifies that CollectNodes only flushes at group boundaries
// and that the serialized/deserialized tree matches the original.
func TestCollectNodesGrouping(t *testing.T) {
// Build a tree that spans multiple groups (16 levels = 2 groups)
// This creates a tree where:
// - Group 1: depths 0-7 (root group)
// - Group 2: depths 8-15 (leaf groups, up to 256 of them)
// Use unique hashes at leaves so we get unique serialized blobs
root := buildDeepTreeUnique(0, 16, 0)
// Compute the root hash before collection
originalRootHash := root.Hash()
// Collect and serialize all nodes, storing by hash
serializedNodes := make(map[common.Hash][]byte)
var collectedNodes []struct {
path []byte
node BinaryNode
}
err := root.CollectNodes(nil, func(path []byte, node BinaryNode) {
pathCopy := make([]byte, len(path))
copy(pathCopy, path)
collectedNodes = append(collectedNodes, struct {
path []byte
node BinaryNode
}{pathCopy, node})
// Serialize and store by hash
serialized := SerializeNode(node)
serializedNodes[node.Hash()] = serialized
})
if err != nil {
t.Fatalf("CollectNodes failed: %v", err)
}
// Count nodes by depth
depthCounts := make(map[int]int)
for _, cn := range collectedNodes {
switch n := cn.node.(type) {
case *InternalNode:
depthCounts[n.depth]++
case *StemNode:
t.Logf("Collected StemNode at path len %d", len(cn.path))
}
}
// With a 16-level tree:
// - 1 node at depth 0 (the root group)
// - 256 nodes at depth 8 (the second-level groups)
// Total: 257 InternalNode groups
if depthCounts[0] != 1 {
t.Errorf("Expected 1 node at depth 0, got %d", depthCounts[0])
}
if depthCounts[8] != 256 {
t.Errorf("Expected 256 nodes at depth 8, got %d", depthCounts[8])
}
t.Logf("Total collected nodes: %d", len(collectedNodes))
t.Logf("Total serialized blobs: %d", len(serializedNodes))
t.Logf("Depth counts: %v", depthCounts)
// Now deserialize starting from the root hash
// Create a resolver that looks up serialized data by hash
resolver := func(path []byte, hash common.Hash) ([]byte, error) {
if data, ok := serializedNodes[hash]; ok {
return data, nil
}
return nil, fmt.Errorf("node not found: %x", hash)
}
// Deserialize the root
rootData, ok := serializedNodes[originalRootHash]
if !ok {
t.Fatalf("Root hash not found in serialized nodes: %x", originalRootHash)
}
deserializedRoot, err := DeserializeNode(rootData, 0)
if err != nil {
t.Fatalf("Failed to deserialize root: %v", err)
}
// Verify the deserialized root hash matches
if deserializedRoot.Hash() != originalRootHash {
t.Errorf("Deserialized root hash mismatch: expected %x, got %x", originalRootHash, deserializedRoot.Hash())
}
// Traverse both trees and compare structure at all 16 levels
// We need to resolve HashedNodes in the deserialized tree to compare deeper
err = compareTreesWithResolver(t, root, deserializedRoot, resolver, 0, 16, "root")
if err != nil {
t.Errorf("Tree comparison failed: %v", err)
}
t.Log("Tree comparison passed - deserialized tree matches original")
}
// compareTreesWithResolver compares two trees, resolving HashedNodes as needed
func compareTreesWithResolver(t *testing.T, original, deserialized BinaryNode, resolver NodeResolverFn, depth, maxDepth int, path string) error {
if depth >= maxDepth {
// At leaf level, just compare hashes
if original.Hash() != deserialized.Hash() {
return fmt.Errorf("hash mismatch at %s: original=%x, deserialized=%x", path, original.Hash(), deserialized.Hash())
}
return nil
}
// Get the actual nodes (resolve HashedNodes if needed)
origNode := original
deserNode := deserialized
// Resolve deserialized HashedNode if needed
if h, ok := deserNode.(HashedNode); ok {
data, err := resolver(nil, common.Hash(h))
if err != nil {
return fmt.Errorf("failed to resolve deserialized node at %s: %v", path, err)
}
deserNode, err = DeserializeNode(data, depth)
if err != nil {
return fmt.Errorf("failed to deserialize node at %s: %v", path, err)
}
}
// Both should be InternalNodes at this point
origInternal, origOk := origNode.(*InternalNode)
deserInternal, deserOk := deserNode.(*InternalNode)
if !origOk || !deserOk {
// Check if both are the same type
if fmt.Sprintf("%T", origNode) != fmt.Sprintf("%T", deserNode) {
return fmt.Errorf("type mismatch at %s: original=%T, deserialized=%T", path, origNode, deserNode)
}
// Both are non-InternalNode, compare hashes
if origNode.Hash() != deserNode.Hash() {
return fmt.Errorf("hash mismatch at %s: original=%x, deserialized=%x", path, origNode.Hash(), deserNode.Hash())
}
return nil
}
// Compare depths
if origInternal.depth != deserInternal.depth {
return fmt.Errorf("depth mismatch at %s: original=%d, deserialized=%d", path, origInternal.depth, deserInternal.depth)
}
// Recursively compare children
if err := compareTreesWithResolver(t, origInternal.left, deserInternal.left, resolver, depth+1, maxDepth, path+"/L"); err != nil {
return err
}
if err := compareTreesWithResolver(t, origInternal.right, deserInternal.right, resolver, depth+1, maxDepth, path+"/R"); err != nil {
return err
}
return nil
}
func buildDeepTree(depth, maxDepth int) BinaryNode {
if depth == maxDepth {
// Create a unique hash for this leaf position
var h common.Hash
h[0] = byte(depth)
h[1] = byte(depth >> 8)
return HashedNode(h)
}
return &InternalNode{
depth: depth,
left: buildDeepTree(depth+1, maxDepth),
right: buildDeepTree(depth+1, maxDepth),
}
}
// buildDeepTreeUnique builds a tree where each leaf has a unique hash based on its position
func buildDeepTreeUnique(depth, maxDepth, position int) BinaryNode {
if depth == maxDepth {
// Create a unique hash based on position in the tree
var h common.Hash
h[0] = byte(position)
h[1] = byte(position >> 8)
h[2] = byte(position >> 16)
h[3] = byte(position >> 24)
return HashedNode(h)
}
return &InternalNode{
depth: depth,
left: buildDeepTreeUnique(depth+1, maxDepth, position*2),
right: buildDeepTreeUnique(depth+1, maxDepth, position*2+1),
}
}

View file

@ -184,31 +184,82 @@ func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolve
return bt, err
}
// CollectNodes collects all child nodes at a given path, and flushes it
// into the provided node collector.
// CollectNodes collects all child nodes at group boundaries (every GroupDepth levels),
// and flushes them into the provided node collector. Each flush serializes an 8-level
// subtree group. Nodes within a group are not flushed individually.
func (bt *InternalNode) CollectNodes(path []byte, flushfn NodeFlushFn) error {
if bt.left != nil {
var p [256]byte
copy(p[:], path)
childpath := p[:len(path)]
childpath = append(childpath, 0)
if err := bt.left.CollectNodes(childpath, flushfn); err != nil {
// Only flush at group boundaries (depth % GroupDepth == 0)
if bt.depth%GroupDepth == 0 {
// We're at a group boundary - first collect any nodes in deeper groups,
// then flush this group
if err := bt.collectChildGroups(path, flushfn, GroupDepth-1); err != nil {
return err
}
flushfn(path, bt)
return nil
}
// Not at a group boundary - this shouldn't happen if we're called correctly from root
// but handle it by continuing to traverse
return bt.collectChildGroups(path, flushfn, GroupDepth-(bt.depth%GroupDepth)-1)
}
// collectChildGroups traverses within a group to find and collect nodes in the next group.
// remainingLevels is how many more levels below the current node until we reach the group boundary.
// When remainingLevels=0, the current node's children are at the next group boundary.
func (bt *InternalNode) collectChildGroups(path []byte, flushfn NodeFlushFn, remainingLevels int) error {
if remainingLevels == 0 {
// Current node is at depth (groupBoundary - 1), its children are at the next group boundary
if bt.left != nil {
if err := bt.left.CollectNodes(appendBit(path, 0), flushfn); err != nil {
return err
}
}
if bt.right != nil {
if err := bt.right.CollectNodes(appendBit(path, 1), flushfn); err != nil {
return err
}
}
return nil
}
// Continue traversing within the group
if bt.left != nil {
switch n := bt.left.(type) {
case *InternalNode:
if err := n.collectChildGroups(appendBit(path, 0), flushfn, remainingLevels-1); err != nil {
return err
}
default:
// StemNode, HashedNode, or Empty - they handle their own collection
if err := bt.left.CollectNodes(appendBit(path, 0), flushfn); err != nil {
return err
}
}
}
if bt.right != nil {
var p [256]byte
copy(p[:], path)
childpath := p[:len(path)]
childpath = append(childpath, 1)
if err := bt.right.CollectNodes(childpath, flushfn); err != nil {
return err
switch n := bt.right.(type) {
case *InternalNode:
if err := n.collectChildGroups(appendBit(path, 1), flushfn, remainingLevels-1); err != nil {
return err
}
default:
// StemNode, HashedNode, or Empty - they handle their own collection
if err := bt.right.CollectNodes(appendBit(path, 1), flushfn); err != nil {
return err
}
}
}
flushfn(path, bt)
return nil
}
// appendBit appends a bit to a path, returning a new slice
func appendBit(path []byte, bit byte) []byte {
var p [256]byte
copy(p[:], path)
result := p[:len(path)]
return append(result, bit)
}
// GetHeight returns the height of the node.
func (bt *InternalNode) GetHeight() int {
var (