go-ethereum/trie/bintrie/binary_node_test.go
2026-01-23 11:29:29 +01:00

294 lines
9.1 KiB
Go

// Copyright 2025 go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package bintrie
import (
"bytes"
"testing"
"github.com/ethereum/go-ethereum/common"
)
// TestSerializeDeserializeInternalNode tests serialization and deserialization of InternalNode
// with the grouped subtree format. A single InternalNode with HashedNode children serializes
// as a depth-8 group where the children appear at their first leaf positions.
func TestSerializeDeserializeInternalNode(t *testing.T) {
	// Create an internal node with two hashed children.
	leftHash := common.HexToHash("0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef")
	rightHash := common.HexToHash("0xfedcba0987654321fedcba0987654321fedcba0987654321fedcba0987654321")
	node := &InternalNode{
		depth: 0, // Use depth 0 (byte-aligned) for this test
		left:  HashedNode(leftHash),
		right: HashedNode(rightHash),
	}

	// Serialize the node with the default group depth of 8.
	serialized := SerializeNode(node, MaxGroupDepth)

	// Layout: type byte + group depth byte + bitmap + N*32 byte hashes.
	// Make sure the fixed-size prefix is present before indexing into it, so a
	// truncated encoding fails this test instead of panicking (a panic aborts
	// every remaining test in the binary).
	const bitmapOffset = 2 // type byte + group depth byte
	bitmapSize := BitmapSizeForDepth(MaxGroupDepth)
	if len(serialized) < bitmapOffset+bitmapSize {
		t.Fatalf("Serialized node too short: got %d bytes, need at least %d", len(serialized), bitmapOffset+bitmapSize)
	}
	if serialized[0] != nodeTypeInternal {
		t.Errorf("Expected type byte to be %d, got %d", nodeTypeInternal, serialized[0])
	}
	if serialized[1] != MaxGroupDepth {
		t.Errorf("Expected group depth to be %d, got %d", MaxGroupDepth, serialized[1])
	}

	// Expected length: 1 (type) + 1 (group depth) + 32 (bitmap) + 2*32 (two hashes) = 98 bytes
	expectedLen := NodeTypeBytes + 1 + bitmapSize + 2*HashSize
	if len(serialized) != expectedLen {
		t.Errorf("Expected serialized length to be %d, got %d", expectedLen, len(serialized))
	}

	// The left child (HashedNode) terminates at remainingDepth=7, so it's placed at position 0<<7 = 0
	// The right child (HashedNode) terminates at remainingDepth=7, so it's placed at position 1<<7 = 128
	bitmap := serialized[bitmapOffset : bitmapOffset+bitmapSize]
	if bitmap[0]&0x80 == 0 { // bit 0 (MSB of byte 0)
		t.Error("Expected bit 0 to be set in bitmap (left child)")
	}
	if bitmap[16]&0x80 == 0 { // bit 128 (MSB of byte 16)
		t.Error("Expected bit 128 to be set in bitmap (right child)")
	}
	// One hash follows per set bit and exactly two hashes are expected, so no
	// bit other than the two child positions may be set.
	setBits := 0
	for _, b := range bitmap {
		for ; b != 0; b &= b - 1 { // clear the lowest set bit each step
			setBits++
		}
	}
	if setBits != 2 {
		t.Errorf("Expected exactly 2 bits set in bitmap, got %d", setBits)
	}

	// Deserialize the node.
	deserialized, err := DeserializeNode(serialized, 0)
	if err != nil {
		t.Fatalf("Failed to deserialize node: %v", err)
	}

	// With grouped format, deserialization creates a tree of InternalNodes down to the hashes.
	// The root should be an InternalNode, and we should be able to navigate down 8 levels
	// to find the HashedNode children.
	internalNode, ok := deserialized.(*InternalNode)
	if !ok {
		t.Fatalf("Expected InternalNode, got %T", deserialized)
	}
	if internalNode.depth != 0 {
		t.Errorf("Expected depth 0, got %d", internalNode.depth)
	}

	// Navigate to position 0 (8 left turns) to find the left hash. Guard
	// against a missing child so the Hash() call below cannot panic.
	node0 := navigateToLeaf(internalNode, 0, 8)
	if node0 == nil {
		t.Fatal("No node found at position 0 (left child)")
	}
	if node0.Hash() != leftHash {
		t.Errorf("Left hash mismatch: expected %x, got %x", leftHash, node0.Hash())
	}

	// Navigate to position 128 (right, then 7 lefts) to find the right hash.
	node128 := navigateToLeaf(internalNode, 128, 8)
	if node128 == nil {
		t.Fatal("No node found at position 128 (right child)")
	}
	if node128.Hash() != rightHash {
		t.Errorf("Right hash mismatch: expected %x, got %x", rightHash, node128.Hash())
	}
}
// navigateToLeaf walks down from node for at most depth levels. At each level
// it picks a child from the low-order depth bits of position, most significant
// first (0 = left, 1 = right). It stops early and returns the current node as
// soon as that node is not an *InternalNode. Used by the grouped-serialization
// tests.
func navigateToLeaf(node BinaryNode, position, depth int) BinaryNode {
	for shift := depth - 1; shift >= 0; shift-- {
		inner, isInternal := node.(*InternalNode)
		if !isInternal {
			break
		}
		goRight := (position>>shift)&1 == 1
		if goRight {
			node = inner.right
			continue
		}
		node = inner.left
	}
	return node
}
// TestSerializeDeserializeStemNode tests serialization and deserialization of StemNode.
// A stem with a few populated value slots must round-trip with the stem, the
// populated values, and the nil (unset) slots all preserved.
func TestSerializeDeserializeStemNode(t *testing.T) {
	// Create a stem node with some values.
	stem := make([]byte, StemSize)
	for i := range stem {
		stem[i] = byte(i)
	}
	var values [StemNodeWidth][]byte
	// Add some values at different indices.
	values[0] = common.HexToHash("0x0101010101010101010101010101010101010101010101010101010101010101").Bytes()
	values[10] = common.HexToHash("0x0202020202020202020202020202020202020202020202020202020202020202").Bytes()
	values[255] = common.HexToHash("0x0303030303030303030303030303030303030303030303030303030303030303").Bytes()
	node := &StemNode{
		Stem:   stem,
		Values: values[:],
		depth:  10,
	}

	// Serialize the node (groupDepth doesn't affect StemNode serialization).
	serialized := SerializeNode(node, MaxGroupDepth)

	// Check the type byte and stem prefix are present before indexing, so a
	// truncated encoding fails the test instead of panicking.
	if len(serialized) < 1+StemSize {
		t.Fatalf("Serialized stem node too short: got %d bytes, need at least %d", len(serialized), 1+StemSize)
	}
	if serialized[0] != nodeTypeStem {
		t.Errorf("Expected type byte to be %d, got %d", nodeTypeStem, serialized[0])
	}
	if !bytes.Equal(serialized[1:1+StemSize], stem) {
		t.Errorf("Stem mismatch in serialized data")
	}

	// Deserialize the node.
	deserialized, err := DeserializeNode(serialized, 10)
	if err != nil {
		t.Fatalf("Failed to deserialize node: %v", err)
	}
	stemNode, ok := deserialized.(*StemNode)
	if !ok {
		t.Fatalf("Expected StemNode, got %T", deserialized)
	}
	if !bytes.Equal(stemNode.Stem, stem) {
		t.Errorf("Stem mismatch after deserialization")
	}

	// The loop below indexes every slot, so check the width first.
	if len(stemNode.Values) != StemNodeWidth {
		t.Fatalf("Expected %d value slots, got %d", StemNodeWidth, len(stemNode.Values))
	}
	// Populated slots must round-trip exactly; every other slot must stay nil.
	for i := range StemNodeWidth {
		switch {
		case values[i] != nil:
			if !bytes.Equal(stemNode.Values[i], values[i]) {
				t.Errorf("Value at index %d mismatch", i)
			}
		case stemNode.Values[i] != nil:
			t.Errorf("Expected nil value at index %d, got %x", i, stemNode.Values[i])
		}
	}
}
// TestDeserializeEmptyNode checks that a zero-length input decodes to the
// Empty node without error.
func TestDeserializeEmptyNode(t *testing.T) {
	got, err := DeserializeNode([]byte{}, 0)
	if err != nil {
		t.Fatalf("Failed to deserialize empty node: %v", err)
	}
	if _, isEmpty := got.(Empty); !isEmpty {
		t.Fatalf("Expected Empty node, got %T", got)
	}
}
// TestDeserializeInvalidType checks that an input whose leading type byte does
// not name a known node type is rejected with an error.
func TestDeserializeInvalidType(t *testing.T) {
	const unknownType = 99 // not a valid node type tag
	if _, err := DeserializeNode([]byte{unknownType, 0, 0, 0}, 0); err == nil {
		t.Fatal("Expected error for invalid type byte, got nil")
	}
}
// TestDeserializeInvalidLength checks that an InternalNode encoding too short
// to hold its group bitmap is rejected with the length error.
func TestDeserializeInvalidLength(t *testing.T) {
	// Valid type byte and group depth (8), but only two payload bytes where the
	// bitmap needs 32.
	truncated := []byte{nodeTypeInternal, 8, 0, 0}
	_, err := DeserializeNode(truncated, 0)
	switch {
	case err == nil:
		t.Fatal("Expected error for invalid data length, got nil")
	case err.Error() != "invalid serialized node length":
		t.Errorf("Expected 'invalid serialized node length' error, got: %v", err)
	}
}
// TestKeyToPath tests the keyToPath function
func TestKeyToPath(t *testing.T) {
tests := []struct {
name string
depth int
key []byte
expected []byte
wantErr bool
}{
{
name: "depth 0",
depth: 0,
key: []byte{0x80}, // 10000000 in binary
expected: []byte{1},
wantErr: false,
},
{
name: "depth 7",
depth: 7,
key: []byte{0xFF}, // 11111111 in binary
expected: []byte{1, 1, 1, 1, 1, 1, 1, 1},
wantErr: false,
},
{
name: "depth crossing byte boundary",
depth: 10,
key: []byte{0xFF, 0x00}, // 11111111 00000000 in binary
expected: []byte{1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0},
wantErr: false,
},
{
name: "max valid depth",
depth: StemSize * 8,
key: make([]byte, HashSize),
expected: make([]byte, StemSize*8+1),
wantErr: false,
},
{
name: "depth too large",
depth: StemSize*8 + 1,
key: make([]byte, HashSize),
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
path, err := keyToPath(tt.depth, tt.key)
if tt.wantErr {
if err == nil {
t.Errorf("Expected error for depth %d, got nil", tt.depth)
}
return
}
if err != nil {
t.Errorf("Unexpected error: %v", err)
return
}
if !bytes.Equal(path, tt.expected) {
t.Errorf("Path mismatch: expected %v, got %v", tt.expected, path)
}
})
}
}