// Copyright 2025 go-ethereum Authors // This file is part of the go-ethereum library. // // The go-ethereum library is free software: you can redistribute it and/or modify // it under the terms of the GNU Lesser General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // The go-ethereum library is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Lesser General Public License for more details. // // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see . package bintrie import ( "bytes" "testing" "github.com/ethereum/go-ethereum/common" ) // TestSerializeDeserializeInternalNode tests serialization and deserialization of InternalNode // with the grouped subtree format. A single InternalNode with HashedNode children serializes // as a depth-8 group where the children appear at their first leaf positions. func TestSerializeDeserializeInternalNode(t *testing.T) { // Create an internal node with two hashed children leftHash := common.HexToHash("0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef") rightHash := common.HexToHash("0xfedcba0987654321fedcba0987654321fedcba0987654321fedcba0987654321") node := &InternalNode{ depth: 0, // Use depth 0 (byte-aligned) for this test left: HashedNode(leftHash), right: HashedNode(rightHash), } // Serialize the node with default group depth of 8 serialized := SerializeNode(node, MaxGroupDepth) // Check the serialized format: type byte + group depth byte + 32 byte bitmap + N*32 byte hashes if serialized[0] != nodeTypeInternal { t.Errorf("Expected type byte to be %d, got %d", nodeTypeInternal, serialized[0]) } if serialized[1] != MaxGroupDepth { t.Errorf("Expected group depth to be %d, got %d", MaxGroupDepth, serialized[1]) } // Expected length: 1 (type) + 1 (group depth) + 32 (bitmap) + 2*32 (two hashes) = 98 bytes bitmapSize := BitmapSizeForDepth(MaxGroupDepth) expectedLen := NodeTypeBytes + 1 + bitmapSize + 2*HashSize if len(serialized) != expectedLen { t.Errorf("Expected serialized length to be %d, got %d", expectedLen, len(serialized)) } // The left child (HashedNode) terminates at remainingDepth=7, so it's placed at position 0<<7 = 0 // The right child (HashedNode) terminates at remainingDepth=7, so it's placed at position 1<<7 = 128 bitmap := serialized[2 : 2+bitmapSize] if bitmap[0]&0x80 == 0 { // bit 0 (MSB of byte 0) t.Error("Expected bit 0 to be set in bitmap (left child)") } if bitmap[16]&0x80 == 0 { // bit 128 (MSB of byte 16) t.Error("Expected bit 128 to be set in bitmap (right child)") } // Deserialize the node deserialized, err := DeserializeNode(serialized, 0) if err != nil { t.Fatalf("Failed to deserialize node: %v", err) } // With grouped format, deserialization creates a tree of InternalNodes down to the hashes. // The root should be an InternalNode, and we should be able to navigate down 8 levels // to find the HashedNode children. internalNode, ok := deserialized.(*InternalNode) if !ok { t.Fatalf("Expected InternalNode, got %T", deserialized) } // Check the depth if internalNode.depth != 0 { t.Errorf("Expected depth 0, got %d", internalNode.depth) } // Navigate to position 0 (8 left turns) to find the left hash node0 := navigateToLeaf(internalNode, 0, 8) if node0.Hash() != leftHash { t.Errorf("Left hash mismatch: expected %x, got %x", leftHash, node0.Hash()) } // Navigate to position 128 (right, then 7 lefts) to find the right hash node128 := navigateToLeaf(internalNode, 128, 8) if node128.Hash() != rightHash { t.Errorf("Right hash mismatch: expected %x, got %x", rightHash, node128.Hash()) } } // navigateToLeaf navigates to a specific position in the tree (used by grouped serialization tests) func navigateToLeaf(node BinaryNode, position, depth int) BinaryNode { for d := 0; d < depth; d++ { in, ok := node.(*InternalNode) if !ok { return node } // Check bit at position (depth-1-d) to determine left or right bit := (position >> (depth - 1 - d)) & 1 if bit == 0 { node = in.left } else { node = in.right } } return node } // TestSerializeDeserializeStemNode tests serialization and deserialization of StemNode func TestSerializeDeserializeStemNode(t *testing.T) { // Create a stem node with some values stem := make([]byte, StemSize) for i := range stem { stem[i] = byte(i) } var values [StemNodeWidth][]byte // Add some values at different indices values[0] = common.HexToHash("0x0101010101010101010101010101010101010101010101010101010101010101").Bytes() values[10] = common.HexToHash("0x0202020202020202020202020202020202020202020202020202020202020202").Bytes() values[255] = common.HexToHash("0x0303030303030303030303030303030303030303030303030303030303030303").Bytes() node := &StemNode{ Stem: stem, Values: values[:], depth: 10, } // Serialize the node (groupDepth doesn't affect StemNode serialization) serialized := SerializeNode(node, MaxGroupDepth) // Check the serialized format if serialized[0] != nodeTypeStem { t.Errorf("Expected type byte to be %d, got %d", nodeTypeStem, serialized[0]) } // Check the stem is correctly serialized if !bytes.Equal(serialized[1:1+StemSize], stem) { t.Errorf("Stem mismatch in serialized data") } // Deserialize the node deserialized, err := DeserializeNode(serialized, 10) if err != nil { t.Fatalf("Failed to deserialize node: %v", err) } // Check that it's a stem node stemNode, ok := deserialized.(*StemNode) if !ok { t.Fatalf("Expected StemNode, got %T", deserialized) } // Check the stem if !bytes.Equal(stemNode.Stem, stem) { t.Errorf("Stem mismatch after deserialization") } // Check the values if !bytes.Equal(stemNode.Values[0], values[0]) { t.Errorf("Value at index 0 mismatch") } if !bytes.Equal(stemNode.Values[10], values[10]) { t.Errorf("Value at index 10 mismatch") } if !bytes.Equal(stemNode.Values[255], values[255]) { t.Errorf("Value at index 255 mismatch") } // Check that other values are nil for i := range StemNodeWidth { if i == 0 || i == 10 || i == 255 { continue } if stemNode.Values[i] != nil { t.Errorf("Expected nil value at index %d, got %x", i, stemNode.Values[i]) } } } // TestDeserializeEmptyNode tests deserialization of empty node func TestDeserializeEmptyNode(t *testing.T) { // Empty byte slice should deserialize to Empty node deserialized, err := DeserializeNode([]byte{}, 0) if err != nil { t.Fatalf("Failed to deserialize empty node: %v", err) } _, ok := deserialized.(Empty) if !ok { t.Fatalf("Expected Empty node, got %T", deserialized) } } // TestDeserializeInvalidType tests deserialization with invalid type byte func TestDeserializeInvalidType(t *testing.T) { // Create invalid serialized data with unknown type byte invalidData := []byte{99, 0, 0, 0} // Type byte 99 is invalid _, err := DeserializeNode(invalidData, 0) if err == nil { t.Fatal("Expected error for invalid type byte, got nil") } } // TestDeserializeInvalidLength tests deserialization with invalid data length func TestDeserializeInvalidLength(t *testing.T) { // InternalNode with valid type byte and group depth but too short for bitmap invalidData := []byte{nodeTypeInternal, 8, 0, 0} // Too short for bitmap (needs 32 bytes) _, err := DeserializeNode(invalidData, 0) if err == nil { t.Fatal("Expected error for invalid data length, got nil") } if err.Error() != "invalid serialized node length" { t.Errorf("Expected 'invalid serialized node length' error, got: %v", err) } } // TestKeyToPath tests the keyToPath function func TestKeyToPath(t *testing.T) { tests := []struct { name string depth int key []byte expected []byte wantErr bool }{ { name: "depth 0", depth: 0, key: []byte{0x80}, // 10000000 in binary expected: []byte{1}, wantErr: false, }, { name: "depth 7", depth: 7, key: []byte{0xFF}, // 11111111 in binary expected: []byte{1, 1, 1, 1, 1, 1, 1, 1}, wantErr: false, }, { name: "depth crossing byte boundary", depth: 10, key: []byte{0xFF, 0x00}, // 11111111 00000000 in binary expected: []byte{1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0}, wantErr: false, }, { name: "max valid depth", depth: StemSize * 8, key: make([]byte, HashSize), expected: make([]byte, StemSize*8+1), wantErr: false, }, { name: "depth too large", depth: StemSize*8 + 1, key: make([]byte, HashSize), wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { path, err := keyToPath(tt.depth, tt.key) if tt.wantErr { if err == nil { t.Errorf("Expected error for depth %d, got nil", tt.depth) } return } if err != nil { t.Errorf("Unexpected error: %v", err) return } if !bytes.Equal(path, tt.expected) { t.Errorf("Path mismatch: expected %v, got %v", tt.expected, path) } }) } }