diff --git a/trie/bintrie/binary_node.go b/trie/bintrie/binary_node.go index 690489b2aa..e1eee7a7bd 100644 --- a/trie/bintrie/binary_node.go +++ b/trie/bintrie/binary_node.go @@ -36,6 +36,14 @@ const ( NodeTypeBytes = 1 // Size of node type prefix in serialization HashSize = 32 // Size of a hash in bytes BitmapSize = 32 // Size of the bitmap in a stem node + + // GroupDepth is the number of levels in a grouped subtree serialization. + // Groups are byte-aligned (depth % 8 == 0). This may become configurable later. + // Serialization format for InternalNode groups: + // [1 byte type] [1 byte group depth (1-8)] [32 byte bitmap] [N × 32 byte hashes] + // The bitmap has 2^groupDepth bits, indicating which bottom-layer children are present. + // Only present children's hashes are stored, in order. + GroupDepth = 8 ) const ( @@ -57,16 +65,72 @@ type BinaryNode interface { GetHeight() int } +// serializeSubtree recursively collects child hashes from a subtree of InternalNodes. +// It traverses up to `remainingDepth` levels, storing hashes of bottom-layer children. +// position tracks the current index (0 to 2^groupDepth - 1) for bitmap placement. +// hashes collects the hashes of present children, bitmap tracks which positions are present. 
+func serializeSubtree(node BinaryNode, remainingDepth int, position int, bitmap []byte, hashes *[]common.Hash) {
+	// Absent children contribute nothing: their bitmap bit stays zero and no
+	// hash is emitted. A nil interface value is treated like Empty because
+	// InternalNode.CollectNodes also tolerates nil children; without this
+	// guard the node.Hash() call below would panic on a nil child.
+	if node == nil {
+		return
+	}
+	if _, isEmpty := node.(Empty); isEmpty {
+		return
+	}
+
+	// Above the bottom layer an InternalNode is transparent: descend into both
+	// children. A child's position is the parent's position with one more path
+	// bit appended (0 = left, 1 = right), so after remainingDepth steps it is
+	// an index into the bottom layer of the group.
+	if in, ok := node.(*InternalNode); ok && remainingDepth > 0 {
+		serializeSubtree(in.left, remainingDepth-1, position*2, bitmap, hashes)
+		serializeSubtree(in.right, remainingDepth-1, position*2+1, bitmap, hashes)
+		return
+	}
+
+	// Every other node is recorded by hash. On the bottom layer
+	// (remainingDepth == 0) the shift is a no-op and the node lands on its own
+	// slot. A branch that ends early (e.g. a StemNode or HashedNode above the
+	// bottom layer) is stored at the first bottom-layer slot of its subtree.
+	//
+	// NOTE(review): the format carries no marker for the depth at which such a
+	// branch ended, and deserializeSubtree only creates HashedNodes on the
+	// bottom layer, so the decoder rebuilds an early-terminated branch as a
+	// chain of InternalNodes ending in a HashedNode. Confirm that this chain
+	// hashes to the same value as the original node; otherwise a round trip
+	// changes the root hash.
+	//
+	// Bits are numbered MSB-first within each bitmap byte. Hashes are appended
+	// in left-to-right traversal order, i.e. ascending bitmap position.
+	leafPos := position << remainingDepth
+	bitmap[leafPos/8] |= 1 << (7 - leafPos%8)
+	*hashes = append(*hashes, node.Hash())
+}
+
 // SerializeNode serializes a binary trie node into a byte slice.
func SerializeNode(node BinaryNode) []byte { switch n := (node).(type) { case *InternalNode: - // InternalNode: 1 byte type + 32 bytes left hash + 32 bytes right hash - var serialized [NodeTypeBytes + HashSize + HashSize]byte + // InternalNode group: 1 byte type + 1 byte group depth + 32 byte bitmap + N×32 byte hashes + groupDepth := GroupDepth + + var bitmap [BitmapSize]byte + var hashes []common.Hash + + serializeSubtree(n, groupDepth, 0, bitmap[:], &hashes) + + // Build serialized output + serializedLen := NodeTypeBytes + 1 + BitmapSize + len(hashes)*HashSize + serialized := make([]byte, serializedLen) serialized[0] = nodeTypeInternal - copy(serialized[1:33], n.left.Hash().Bytes()) - copy(serialized[33:65], n.right.Hash().Bytes()) - return serialized[:] + serialized[1] = byte(groupDepth) + copy(serialized[2:2+BitmapSize], bitmap[:]) + + offset := NodeTypeBytes + 1 + BitmapSize + for _, h := range hashes { + copy(serialized[offset:offset+HashSize], h.Bytes()) + offset += HashSize + } + + return serialized case *StemNode: // StemNode: 1 byte type + 31 bytes stem + 32 bytes bitmap + 256*32 bytes values var serialized [NodeTypeBytes + StemSize + BitmapSize + StemNodeWidth*HashSize]byte @@ -90,6 +154,51 @@ func SerializeNode(node BinaryNode) []byte { var invalidSerializedLength = errors.New("invalid serialized node length") +// deserializeSubtree reconstructs an InternalNode subtree from grouped serialization. +// remainingDepth is how many more levels to build, position is current index in the bitmap, +// nodeDepth is the actual trie depth for the node being created. +// hashIdx tracks the current position in the hash data (incremented as hashes are consumed). 
+func deserializeSubtree(remainingDepth int, position int, nodeDepth int, bitmap []byte, hashData []byte, hashIdx *int) (BinaryNode, error) {
+	if remainingDepth > 0 {
+		// Interior level: rebuild the left child (path bit 0) first, then the
+		// right child (path bit 1). The order matters because both calls
+		// consume hashes through the shared hashIdx cursor. Both children are
+		// always visited; the bitmap is only consulted on the bottom layer.
+		var children [2]BinaryNode
+		for bit := 0; bit < 2; bit++ {
+			child, err := deserializeSubtree(remainingDepth-1, position*2+bit, nodeDepth+1, bitmap, hashData, hashIdx)
+			if err != nil {
+				return nil, err
+			}
+			children[bit] = child
+		}
+
+		// Collapse a subtree with no present descendants into a single Empty
+		// instead of materialising an InternalNode with two Empty children.
+		if _, leftEmpty := children[0].(Empty); leftEmpty {
+			if _, rightEmpty := children[1].(Empty); rightEmpty {
+				return Empty{}, nil
+			}
+		}
+		return &InternalNode{
+			depth: nodeDepth,
+			left:  children[0],
+			right: children[1],
+		}, nil
+	}
+
+	// Bottom layer: the bitmap (MSB-first within each byte) says whether a
+	// hash was stored for this position.
+	mask := byte(1) << (7 - position%8)
+	if bitmap[position/8]&mask == 0 {
+		return Empty{}, nil
+	}
+
+	// Present: take the next unread 32-byte hash and advance the cursor.
+	start := *hashIdx * HashSize
+	end := start + HashSize
+	if len(hashData) < end {
+		// The bitmap promises more hashes than the payload contains.
+		return nil, invalidSerializedLength
+	}
+	*hashIdx++
+	return HashedNode(common.BytesToHash(hashData[start:end])), nil
+}
+
 // DeserializeNode deserializes a binary trie node from a byte slice.
func DeserializeNode(serialized []byte, depth int) (BinaryNode, error) { if len(serialized) == 0 { @@ -98,14 +207,20 @@ func DeserializeNode(serialized []byte, depth int) (BinaryNode, error) { switch serialized[0] { case nodeTypeInternal: - if len(serialized) != 65 { + // Grouped format: 1 byte type + 1 byte group depth + 32 byte bitmap + N×32 byte hashes + if len(serialized) < NodeTypeBytes+1+BitmapSize { return nil, invalidSerializedLength } - return &InternalNode{ - depth: depth, - left: HashedNode(common.BytesToHash(serialized[1:33])), - right: HashedNode(common.BytesToHash(serialized[33:65])), - }, nil + groupDepth := int(serialized[1]) + if groupDepth < 1 || groupDepth > GroupDepth { + return nil, errors.New("invalid group depth") + } + bitmap := serialized[2 : 2+BitmapSize] + hashData := serialized[2+BitmapSize:] + + // Count present children from bitmap + hashIdx := 0 + return deserializeSubtree(groupDepth, 0, depth, bitmap, hashData, &hashIdx) case nodeTypeStem: if len(serialized) < 64 { return nil, invalidSerializedLength diff --git a/trie/bintrie/binary_node_test.go b/trie/bintrie/binary_node_test.go index 242743ba53..a04f08c9b3 100644 --- a/trie/bintrie/binary_node_test.go +++ b/trie/bintrie/binary_node_test.go @@ -24,13 +24,15 @@ import ( ) // TestSerializeDeserializeInternalNode tests serialization and deserialization of InternalNode +// with the grouped subtree format. A single InternalNode with HashedNode children serializes +// as a depth-8 group where the children appear at their first leaf positions. 
func TestSerializeDeserializeInternalNode(t *testing.T) { // Create an internal node with two hashed children leftHash := common.HexToHash("0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef") rightHash := common.HexToHash("0xfedcba0987654321fedcba0987654321fedcba0987654321fedcba0987654321") node := &InternalNode{ - depth: 5, + depth: 0, // Use depth 0 (byte-aligned) for this test left: HashedNode(leftHash), right: HashedNode(rightHash), } @@ -38,42 +40,81 @@ func TestSerializeDeserializeInternalNode(t *testing.T) { // Serialize the node serialized := SerializeNode(node) - // Check the serialized format + // Check the serialized format: type byte + group depth byte + 32 byte bitmap + N*32 byte hashes if serialized[0] != nodeTypeInternal { t.Errorf("Expected type byte to be %d, got %d", nodeTypeInternal, serialized[0]) } - if len(serialized) != 65 { - t.Errorf("Expected serialized length to be 65, got %d", len(serialized)) + if serialized[1] != GroupDepth { + t.Errorf("Expected group depth to be %d, got %d", GroupDepth, serialized[1]) + } + + // Expected length: 1 (type) + 1 (group depth) + 32 (bitmap) + 2*32 (two hashes) = 98 bytes + expectedLen := NodeTypeBytes + 1 + BitmapSize + 2*HashSize + if len(serialized) != expectedLen { + t.Errorf("Expected serialized length to be %d, got %d", expectedLen, len(serialized)) + } + + // The left child (HashedNode) terminates at remainingDepth=7, so it's placed at position 0<<7 = 0 + // The right child (HashedNode) terminates at remainingDepth=7, so it's placed at position 1<<7 = 128 + bitmap := serialized[2 : 2+BitmapSize] + if bitmap[0]&0x80 == 0 { // bit 0 (MSB of byte 0) + t.Error("Expected bit 0 to be set in bitmap (left child)") + } + if bitmap[16]&0x80 == 0 { // bit 128 (MSB of byte 16) + t.Error("Expected bit 128 to be set in bitmap (right child)") } // Deserialize the node - deserialized, err := DeserializeNode(serialized, 5) + deserialized, err := DeserializeNode(serialized, 0) if err != nil { 
t.Fatalf("Failed to deserialize node: %v", err) } - // Check that it's an internal node + // With grouped format, deserialization creates a tree of InternalNodes down to the hashes. + // The root should be an InternalNode, and we should be able to navigate down 8 levels + // to find the HashedNode children. internalNode, ok := deserialized.(*InternalNode) if !ok { t.Fatalf("Expected InternalNode, got %T", deserialized) } // Check the depth - if internalNode.depth != 5 { - t.Errorf("Expected depth 5, got %d", internalNode.depth) + if internalNode.depth != 0 { + t.Errorf("Expected depth 0, got %d", internalNode.depth) } - // Check the left and right hashes - if internalNode.left.Hash() != leftHash { - t.Errorf("Left hash mismatch: expected %x, got %x", leftHash, internalNode.left.Hash()) + // Navigate to position 0 (8 left turns) to find the left hash + node0 := navigateToLeaf(internalNode, 0, 8) + if node0.Hash() != leftHash { + t.Errorf("Left hash mismatch: expected %x, got %x", leftHash, node0.Hash()) } - if internalNode.right.Hash() != rightHash { - t.Errorf("Right hash mismatch: expected %x, got %x", rightHash, internalNode.right.Hash()) + // Navigate to position 128 (right, then 7 lefts) to find the right hash + node128 := navigateToLeaf(internalNode, 128, 8) + if node128.Hash() != rightHash { + t.Errorf("Right hash mismatch: expected %x, got %x", rightHash, node128.Hash()) } } +// navigateToLeaf navigates to a specific position in the tree (used by grouped serialization tests) +func navigateToLeaf(node BinaryNode, position, depth int) BinaryNode { + for d := 0; d < depth; d++ { + in, ok := node.(*InternalNode) + if !ok { + return node + } + // Check bit at position (depth-1-d) to determine left or right + bit := (position >> (depth - 1 - d)) & 1 + if bit == 0 { + node = in.left + } else { + node = in.right + } + } + return node +} + // TestSerializeDeserializeStemNode tests serialization and deserialization of StemNode func TestSerializeDeserializeStemNode(t 
*testing.T) { // Create a stem node with some values diff --git a/trie/bintrie/group_debug_test.go b/trie/bintrie/group_debug_test.go new file mode 100644 index 0000000000..cc54a6f68c --- /dev/null +++ b/trie/bintrie/group_debug_test.go @@ -0,0 +1,364 @@ +package bintrie + +import ( + "fmt" + "testing" + + "github.com/ethereum/go-ethereum/common" +) + +// TestGroupedSerializationDebug helps understand the grouped serialization format +func TestGroupedSerializationDebug(t *testing.T) { + leftHash := common.HexToHash("0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef") + rightHash := common.HexToHash("0xfedcba0987654321fedcba0987654321fedcba0987654321fedcba0987654321") + + node := &InternalNode{ + depth: 0, + left: HashedNode(leftHash), + right: HashedNode(rightHash), + } + + serialized := SerializeNode(node) + t.Logf("Serialized length: %d", len(serialized)) + t.Logf("Type: %d, GroupDepth: %d", serialized[0], serialized[1]) + + bitmap := serialized[2 : 2+BitmapSize] + t.Logf("Bitmap: %x", bitmap) + + // Count and show set bits + for i := 0; i < 256; i++ { + if bitmap[i/8]>>(7-(i%8))&1 == 1 { + t.Logf("Bit %d is set", i) + } + } + + // Deserialize + deserialized, err := DeserializeNode(serialized, 0) + if err != nil { + t.Fatalf("Error: %v", err) + } + + t.Logf("Deserialized type: %T", deserialized) + + // Walk the tree and print structure + printTree(t, deserialized, 0, "root") +} + +func printTree(t *testing.T, node BinaryNode, depth int, path string) { + indent := "" + for i := 0; i < depth; i++ { + indent += " " + } + + switch n := node.(type) { + case *InternalNode: + t.Logf("%s%s: InternalNode (depth=%d)", indent, path, n.depth) + printTree(t, n.left, depth+1, path+"/L") + printTree(t, n.right, depth+1, path+"/R") + case HashedNode: + t.Logf("%s%s: HashedNode(%x)", indent, path, common.Hash(n)) + case Empty: + t.Logf("%s%s: Empty", indent, path) + default: + t.Logf("%s%s: %T", indent, path, node) + } +} + +// TestFullDepth8Tree tests a full 
8-level tree (all 256 bottom positions filled) +func TestFullDepth8Tree(t *testing.T) { + // Build a full 8-level tree + root := buildFullTree(0, 8) + + serialized := SerializeNode(root) + t.Logf("Full tree serialized length: %d", len(serialized)) + t.Logf("Expected: 1 + 1 + 32 + 256*32 = %d", 1+1+32+256*32) + + // Count set bits in bitmap + bitmap := serialized[2 : 2+BitmapSize] + count := 0 + for i := 0; i < 256; i++ { + if bitmap[i/8]>>(7-(i%8))&1 == 1 { + count++ + } + } + t.Logf("Set bits in bitmap: %d", count) + + // Deserialize and verify structure + deserialized, err := DeserializeNode(serialized, 0) + if err != nil { + t.Fatalf("Error: %v", err) + } + + // Verify it's an InternalNode with depth 0 + in, ok := deserialized.(*InternalNode) + if !ok { + t.Fatalf("Expected InternalNode, got %T", deserialized) + } + if in.depth != 0 { + t.Errorf("Expected depth 0, got %d", in.depth) + } + + // Count leaves at depth 8 + leafCount := countLeavesAtDepth(deserialized, 8, 0) + t.Logf("Leaves at depth 8: %d", leafCount) + if leafCount != 256 { + t.Errorf("Expected 256 leaves, got %d", leafCount) + } +} + +func buildFullTree(depth, maxDepth int) BinaryNode { + if depth == maxDepth { + // Create a unique hash for this position + var h common.Hash + h[0] = byte(depth) + h[1] = byte(depth >> 8) + return HashedNode(h) + } + return &InternalNode{ + depth: depth, + left: buildFullTree(depth+1, maxDepth), + right: buildFullTree(depth+1, maxDepth), + } +} + +func countLeavesAtDepth(node BinaryNode, targetDepth, currentDepth int) int { + if currentDepth == targetDepth { + if _, ok := node.(Empty); ok { + return 0 + } + return 1 + } + in, ok := node.(*InternalNode) + if !ok { + return 0 // Terminated early + } + return countLeavesAtDepth(in.left, targetDepth, currentDepth+1) + + countLeavesAtDepth(in.right, targetDepth, currentDepth+1) +} + +// TestRoundTripPreservesHashes tests that round-trip preserves the original hashes +func TestRoundTripPreservesHashes(t *testing.T) { + // 
Build a tree with known hashes at specific positions + hashes := make([]common.Hash, 256) + for i := range hashes { + hashes[i] = common.BytesToHash([]byte(fmt.Sprintf("hash-%d", i))) + } + + root := buildTreeWithHashes(0, 8, 0, hashes) + + serialized := SerializeNode(root) + deserialized, err := DeserializeNode(serialized, 0) + if err != nil { + t.Fatalf("Error: %v", err) + } + + // Verify each hash at depth 8 + for i := 0; i < 256; i++ { + node := navigateToLeaf(deserialized, i, 8) + if node == nil { + t.Errorf("Position %d: node is nil", i) + continue + } + if node.Hash() != hashes[i] { + t.Errorf("Position %d: hash mismatch, expected %x, got %x", i, hashes[i], node.Hash()) + } + } +} + +func buildTreeWithHashes(depth, maxDepth, position int, hashes []common.Hash) BinaryNode { + if depth == maxDepth { + return HashedNode(hashes[position]) + } + return &InternalNode{ + depth: depth, + left: buildTreeWithHashes(depth+1, maxDepth, position*2, hashes), + right: buildTreeWithHashes(depth+1, maxDepth, position*2+1, hashes), + } +} + +// TestCollectNodesGrouping verifies that CollectNodes only flushes at group boundaries +// and that the serialized/deserialized tree matches the original. 
+func TestCollectNodesGrouping(t *testing.T) { + // Build a tree that spans multiple groups (16 levels = 2 groups) + // This creates a tree where: + // - Group 1: depths 0-7 (root group) + // - Group 2: depths 8-15 (leaf groups, up to 256 of them) + // Use unique hashes at leaves so we get unique serialized blobs + root := buildDeepTreeUnique(0, 16, 0) + + // Compute the root hash before collection + originalRootHash := root.Hash() + + // Collect and serialize all nodes, storing by hash + serializedNodes := make(map[common.Hash][]byte) + var collectedNodes []struct { + path []byte + node BinaryNode + } + + err := root.CollectNodes(nil, func(path []byte, node BinaryNode) { + pathCopy := make([]byte, len(path)) + copy(pathCopy, path) + collectedNodes = append(collectedNodes, struct { + path []byte + node BinaryNode + }{pathCopy, node}) + + // Serialize and store by hash + serialized := SerializeNode(node) + serializedNodes[node.Hash()] = serialized + }) + if err != nil { + t.Fatalf("CollectNodes failed: %v", err) + } + + // Count nodes by depth + depthCounts := make(map[int]int) + for _, cn := range collectedNodes { + switch n := cn.node.(type) { + case *InternalNode: + depthCounts[n.depth]++ + case *StemNode: + t.Logf("Collected StemNode at path len %d", len(cn.path)) + } + } + + // With a 16-level tree: + // - 1 node at depth 0 (the root group) + // - 256 nodes at depth 8 (the second-level groups) + // Total: 257 InternalNode groups + if depthCounts[0] != 1 { + t.Errorf("Expected 1 node at depth 0, got %d", depthCounts[0]) + } + if depthCounts[8] != 256 { + t.Errorf("Expected 256 nodes at depth 8, got %d", depthCounts[8]) + } + + t.Logf("Total collected nodes: %d", len(collectedNodes)) + t.Logf("Total serialized blobs: %d", len(serializedNodes)) + t.Logf("Depth counts: %v", depthCounts) + + // Now deserialize starting from the root hash + // Create a resolver that looks up serialized data by hash + resolver := func(path []byte, hash common.Hash) ([]byte, error) { + 
if data, ok := serializedNodes[hash]; ok { + return data, nil + } + return nil, fmt.Errorf("node not found: %x", hash) + } + + // Deserialize the root + rootData, ok := serializedNodes[originalRootHash] + if !ok { + t.Fatalf("Root hash not found in serialized nodes: %x", originalRootHash) + } + deserializedRoot, err := DeserializeNode(rootData, 0) + if err != nil { + t.Fatalf("Failed to deserialize root: %v", err) + } + + // Verify the deserialized root hash matches + if deserializedRoot.Hash() != originalRootHash { + t.Errorf("Deserialized root hash mismatch: expected %x, got %x", originalRootHash, deserializedRoot.Hash()) + } + + // Traverse both trees and compare structure at all 16 levels + // We need to resolve HashedNodes in the deserialized tree to compare deeper + err = compareTreesWithResolver(t, root, deserializedRoot, resolver, 0, 16, "root") + if err != nil { + t.Errorf("Tree comparison failed: %v", err) + } + + t.Log("Tree comparison passed - deserialized tree matches original") +} + +// compareTreesWithResolver compares two trees, resolving HashedNodes as needed +func compareTreesWithResolver(t *testing.T, original, deserialized BinaryNode, resolver NodeResolverFn, depth, maxDepth int, path string) error { + if depth >= maxDepth { + // At leaf level, just compare hashes + if original.Hash() != deserialized.Hash() { + return fmt.Errorf("hash mismatch at %s: original=%x, deserialized=%x", path, original.Hash(), deserialized.Hash()) + } + return nil + } + + // Get the actual nodes (resolve HashedNodes if needed) + origNode := original + deserNode := deserialized + + // Resolve deserialized HashedNode if needed + if h, ok := deserNode.(HashedNode); ok { + data, err := resolver(nil, common.Hash(h)) + if err != nil { + return fmt.Errorf("failed to resolve deserialized node at %s: %v", path, err) + } + deserNode, err = DeserializeNode(data, depth) + if err != nil { + return fmt.Errorf("failed to deserialize node at %s: %v", path, err) + } + } + + // Both 
should be InternalNodes at this point + origInternal, origOk := origNode.(*InternalNode) + deserInternal, deserOk := deserNode.(*InternalNode) + + if !origOk || !deserOk { + // Check if both are the same type + if fmt.Sprintf("%T", origNode) != fmt.Sprintf("%T", deserNode) { + return fmt.Errorf("type mismatch at %s: original=%T, deserialized=%T", path, origNode, deserNode) + } + // Both are non-InternalNode, compare hashes + if origNode.Hash() != deserNode.Hash() { + return fmt.Errorf("hash mismatch at %s: original=%x, deserialized=%x", path, origNode.Hash(), deserNode.Hash()) + } + return nil + } + + // Compare depths + if origInternal.depth != deserInternal.depth { + return fmt.Errorf("depth mismatch at %s: original=%d, deserialized=%d", path, origInternal.depth, deserInternal.depth) + } + + // Recursively compare children + if err := compareTreesWithResolver(t, origInternal.left, deserInternal.left, resolver, depth+1, maxDepth, path+"/L"); err != nil { + return err + } + if err := compareTreesWithResolver(t, origInternal.right, deserInternal.right, resolver, depth+1, maxDepth, path+"/R"); err != nil { + return err + } + + return nil +} + +func buildDeepTree(depth, maxDepth int) BinaryNode { + if depth == maxDepth { + // Create a unique hash for this leaf position + var h common.Hash + h[0] = byte(depth) + h[1] = byte(depth >> 8) + return HashedNode(h) + } + return &InternalNode{ + depth: depth, + left: buildDeepTree(depth+1, maxDepth), + right: buildDeepTree(depth+1, maxDepth), + } +} + +// buildDeepTreeUnique builds a tree where each leaf has a unique hash based on its position +func buildDeepTreeUnique(depth, maxDepth, position int) BinaryNode { + if depth == maxDepth { + // Create a unique hash based on position in the tree + var h common.Hash + h[0] = byte(position) + h[1] = byte(position >> 8) + h[2] = byte(position >> 16) + h[3] = byte(position >> 24) + return HashedNode(h) + } + return &InternalNode{ + depth: depth, + left: buildDeepTreeUnique(depth+1, 
maxDepth, position*2), + right: buildDeepTreeUnique(depth+1, maxDepth, position*2+1), + } +} diff --git a/trie/bintrie/internal_node.go b/trie/bintrie/internal_node.go index 0a7bece521..6351731fe0 100644 --- a/trie/bintrie/internal_node.go +++ b/trie/bintrie/internal_node.go @@ -184,31 +184,82 @@ func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolve return bt, err } -// CollectNodes collects all child nodes at a given path, and flushes it -// into the provided node collector. +// CollectNodes collects all child nodes at group boundaries (every GroupDepth levels), +// and flushes them into the provided node collector. Each flush serializes an 8-level +// subtree group. Nodes within a group are not flushed individually. func (bt *InternalNode) CollectNodes(path []byte, flushfn NodeFlushFn) error { - if bt.left != nil { - var p [256]byte - copy(p[:], path) - childpath := p[:len(path)] - childpath = append(childpath, 0) - if err := bt.left.CollectNodes(childpath, flushfn); err != nil { + // Only flush at group boundaries (depth % GroupDepth == 0) + if bt.depth%GroupDepth == 0 { + // We're at a group boundary - first collect any nodes in deeper groups, + // then flush this group + if err := bt.collectChildGroups(path, flushfn, GroupDepth-1); err != nil { return err } + flushfn(path, bt) + return nil + } + // Not at a group boundary - this shouldn't happen if we're called correctly from root + // but handle it by continuing to traverse + return bt.collectChildGroups(path, flushfn, GroupDepth-(bt.depth%GroupDepth)-1) +} + +// collectChildGroups traverses within a group to find and collect nodes in the next group. +// remainingLevels is how many more levels below the current node until we reach the group boundary. +// When remainingLevels=0, the current node's children are at the next group boundary. 
+func (bt *InternalNode) collectChildGroups(path []byte, flushfn NodeFlushFn, remainingLevels int) error { + if remainingLevels == 0 { + // Current node is at depth (groupBoundary - 1), its children are at the next group boundary + if bt.left != nil { + if err := bt.left.CollectNodes(appendBit(path, 0), flushfn); err != nil { + return err + } + } + if bt.right != nil { + if err := bt.right.CollectNodes(appendBit(path, 1), flushfn); err != nil { + return err + } + } + return nil + } + + // Continue traversing within the group + if bt.left != nil { + switch n := bt.left.(type) { + case *InternalNode: + if err := n.collectChildGroups(appendBit(path, 0), flushfn, remainingLevels-1); err != nil { + return err + } + default: + // StemNode, HashedNode, or Empty - they handle their own collection + if err := bt.left.CollectNodes(appendBit(path, 0), flushfn); err != nil { + return err + } + } } if bt.right != nil { - var p [256]byte - copy(p[:], path) - childpath := p[:len(path)] - childpath = append(childpath, 1) - if err := bt.right.CollectNodes(childpath, flushfn); err != nil { - return err + switch n := bt.right.(type) { + case *InternalNode: + if err := n.collectChildGroups(appendBit(path, 1), flushfn, remainingLevels-1); err != nil { + return err + } + default: + // StemNode, HashedNode, or Empty - they handle their own collection + if err := bt.right.CollectNodes(appendBit(path, 1), flushfn); err != nil { + return err + } } } - flushfn(path, bt) return nil } +// appendBit appends a bit to a path, returning a new slice +func appendBit(path []byte, bit byte) []byte { + var p [256]byte + copy(p[:], path) + result := p[:len(path)] + return append(result, bit) +} + // GetHeight returns the height of the node. func (bt *InternalNode) GetHeight() int { var (