From bd4b17907fb1f689dbbbe122144a4dc86c2edc43 Mon Sep 17 00:00:00 2001 From: Guillaume Ballet <3272758+gballet@users.noreply.github.com> Date: Mon, 1 Sep 2025 15:06:51 +0200 Subject: [PATCH] trie/bintrie: add eip7864 binary trees and run its tests (#32365) Implement the binary tree as specified in [eip-7864](https://eips.ethereum.org/EIPS/eip-7864). This will gradually replace verkle trees in the codebase. This is only running the tests and will not be executed in production, but will help me rebase some of my work, so that it doesn't bitrot as much. --------- Signed-off-by: Guillaume Ballet Co-authored-by: Parithosh Jayanthi Co-authored-by: rjl493456442 --- core/types/hashes.go | 3 + trie/bintrie/binary_node.go | 133 +++++++++ trie/bintrie/binary_node_test.go | 252 ++++++++++++++++ trie/bintrie/empty.go | 72 +++++ trie/bintrie/empty_test.go | 222 ++++++++++++++ trie/bintrie/hashed_node.go | 66 +++++ trie/bintrie/hashed_node_test.go | 128 ++++++++ trie/bintrie/internal_node.go | 189 ++++++++++++ trie/bintrie/internal_node_test.go | 458 +++++++++++++++++++++++++++++ trie/bintrie/iterator.go | 261 ++++++++++++++++ trie/bintrie/iterator_test.go | 83 ++++++ trie/bintrie/key_encoding.go | 79 +++++ trie/bintrie/stem_node.go | 213 ++++++++++++++ trie/bintrie/stem_node_test.go | 373 +++++++++++++++++++++++ trie/bintrie/trie.go | 353 ++++++++++++++++++++++ trie/bintrie/trie_test.go | 197 +++++++++++++ trie/committer.go | 8 +- trie/iterator.go | 4 +- trie/proof.go | 4 +- trie/tracer.go | 39 ++- trie/trie.go | 26 +- trie/trie_reader.go | 20 +- trie/verkle.go | 16 +- 23 files changed, 3140 insertions(+), 59 deletions(-) create mode 100644 trie/bintrie/binary_node.go create mode 100644 trie/bintrie/binary_node_test.go create mode 100644 trie/bintrie/empty.go create mode 100644 trie/bintrie/empty_test.go create mode 100644 trie/bintrie/hashed_node.go create mode 100644 trie/bintrie/hashed_node_test.go create mode 100644 trie/bintrie/internal_node.go create mode 100644 trie/bintrie/internal_node_test.go create mode 100644 trie/bintrie/iterator.go create mode 100644 trie/bintrie/iterator_test.go create mode 100644 trie/bintrie/key_encoding.go create mode 100644 trie/bintrie/stem_node.go create mode 100644 trie/bintrie/stem_node_test.go create mode 100644 trie/bintrie/trie.go create mode 100644 trie/bintrie/trie_test.go diff --git a/core/types/hashes.go b/core/types/hashes.go index 05cfaeed74..22f1f946dc 100644 --- a/core/types/hashes.go +++ b/core/types/hashes.go @@ -45,4 +45,7 @@ var ( // EmptyVerkleHash is the known hash of an empty verkle trie. EmptyVerkleHash = common.Hash{} + + // EmptyBinaryHash is the known hash of an empty binary trie. + EmptyBinaryHash = common.Hash{} ) diff --git a/trie/bintrie/binary_node.go b/trie/bintrie/binary_node.go new file mode 100644 index 0000000000..1c003a6c8f --- /dev/null +++ b/trie/bintrie/binary_node.go @@ -0,0 +1,133 @@ +// Copyright 2025 go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package bintrie + +import ( + "errors" + + "github.com/ethereum/go-ethereum/common" +) + +type ( + NodeFlushFn func([]byte, BinaryNode) + NodeResolverFn func([]byte, common.Hash) ([]byte, error) +) + +// zero is the zero value for a 32-byte array. +var zero [32]byte + +const ( + NodeWidth = 256 // Number of child per leaf node + StemSize = 31 // Number of bytes to travel before reaching a group of leaves +) + +const ( + nodeTypeStem = iota + 1 // Stem node, contains a stem and a bitmap of values + nodeTypeInternal +) + +// BinaryNode is an interface for a binary trie node. +type BinaryNode interface { + Get([]byte, NodeResolverFn) ([]byte, error) + Insert([]byte, []byte, NodeResolverFn, int) (BinaryNode, error) + Copy() BinaryNode + Hash() common.Hash + GetValuesAtStem([]byte, NodeResolverFn) ([][]byte, error) + InsertValuesAtStem([]byte, [][]byte, NodeResolverFn, int) (BinaryNode, error) + CollectNodes([]byte, NodeFlushFn) error + + toDot(parent, path string) string + GetHeight() int +} + +// SerializeNode serializes a binary trie node into a byte slice. +func SerializeNode(node BinaryNode) []byte { + switch n := (node).(type) { + case *InternalNode: + var serialized [65]byte + serialized[0] = nodeTypeInternal + copy(serialized[1:33], n.left.Hash().Bytes()) + copy(serialized[33:65], n.right.Hash().Bytes()) + return serialized[:] + case *StemNode: + var serialized [32 + 32 + 256*32]byte + serialized[0] = nodeTypeStem + copy(serialized[1:32], node.(*StemNode).Stem) + bitmap := serialized[32:64] + offset := 64 + for i, v := range node.(*StemNode).Values { + if v != nil { + bitmap[i/8] |= 1 << (7 - (i % 8)) + copy(serialized[offset:offset+32], v) + offset += 32 + } + } + return serialized[:] + default: + panic("invalid node type") + } +} + +var invalidSerializedLength = errors.New("invalid serialized node length") + +// DeserializeNode deserializes a binary trie node from a byte slice. +func DeserializeNode(serialized []byte, depth int) (BinaryNode, error) { + if len(serialized) == 0 { + return Empty{}, nil + } + + switch serialized[0] { + case nodeTypeInternal: + if len(serialized) != 65 { + return nil, invalidSerializedLength + } + return &InternalNode{ + depth: depth, + left: HashedNode(common.BytesToHash(serialized[1:33])), + right: HashedNode(common.BytesToHash(serialized[33:65])), + }, nil + case nodeTypeStem: + if len(serialized) < 64 { + return nil, invalidSerializedLength + } + var values [256][]byte + bitmap := serialized[32:64] + offset := 64 + + for i := range 256 { + if bitmap[i/8]>>(7-(i%8))&1 == 1 { + if len(serialized) < offset+32 { + return nil, invalidSerializedLength + } + values[i] = serialized[offset : offset+32] + offset += 32 + } + } + return &StemNode{ + Stem: serialized[1:32], + Values: values[:], + depth: depth, + }, nil + default: + return nil, errors.New("invalid node type") + } +} + +// ToDot converts the binary trie to a DOT language representation. Useful for debugging. +func ToDot(root BinaryNode) string { + return root.toDot("", "") +} diff --git a/trie/bintrie/binary_node_test.go b/trie/bintrie/binary_node_test.go new file mode 100644 index 0000000000..b21daaab69 --- /dev/null +++ b/trie/bintrie/binary_node_test.go @@ -0,0 +1,252 @@ +// Copyright 2025 go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package bintrie + +import ( + "bytes" + "testing" + + "github.com/ethereum/go-ethereum/common" +) + +// TestSerializeDeserializeInternalNode tests serialization and deserialization of InternalNode +func TestSerializeDeserializeInternalNode(t *testing.T) { + // Create an internal node with two hashed children + leftHash := common.HexToHash("0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef") + rightHash := common.HexToHash("0xfedcba0987654321fedcba0987654321fedcba0987654321fedcba0987654321") + + node := &InternalNode{ + depth: 5, + left: HashedNode(leftHash), + right: HashedNode(rightHash), + } + + // Serialize the node + serialized := SerializeNode(node) + + // Check the serialized format + if serialized[0] != nodeTypeInternal { + t.Errorf("Expected type byte to be %d, got %d", nodeTypeInternal, serialized[0]) + } + + if len(serialized) != 65 { + t.Errorf("Expected serialized length to be 65, got %d", len(serialized)) + } + + // Deserialize the node + deserialized, err := DeserializeNode(serialized, 5) + if err != nil { + t.Fatalf("Failed to deserialize node: %v", err) + } + + // Check that it's an internal node + internalNode, ok := deserialized.(*InternalNode) + if !ok { + t.Fatalf("Expected InternalNode, got %T", deserialized) + } + + // Check the depth + if internalNode.depth != 5 { + t.Errorf("Expected depth 5, got %d", internalNode.depth) + } + + // Check the left and right hashes + if internalNode.left.Hash() != leftHash { + t.Errorf("Left hash mismatch: expected %x, got %x", leftHash, internalNode.left.Hash()) + } + + if internalNode.right.Hash() != rightHash { + t.Errorf("Right hash mismatch: expected %x, got %x", rightHash, internalNode.right.Hash()) + } +} + +// TestSerializeDeserializeStemNode tests serialization and deserialization of StemNode +func TestSerializeDeserializeStemNode(t *testing.T) { + // Create a stem node with some values + stem := make([]byte, 31) + for i := range stem { + stem[i] = byte(i) + } + + var values [256][]byte + // Add some values at different indices + values[0] = common.HexToHash("0x0101010101010101010101010101010101010101010101010101010101010101").Bytes() + values[10] = common.HexToHash("0x0202020202020202020202020202020202020202020202020202020202020202").Bytes() + values[255] = common.HexToHash("0x0303030303030303030303030303030303030303030303030303030303030303").Bytes() + + node := &StemNode{ + Stem: stem, + Values: values[:], + depth: 10, + } + + // Serialize the node + serialized := SerializeNode(node) + + // Check the serialized format + if serialized[0] != nodeTypeStem { + t.Errorf("Expected type byte to be %d, got %d", nodeTypeStem, serialized[0]) + } + + // Check the stem is correctly serialized + if !bytes.Equal(serialized[1:32], stem) { + t.Errorf("Stem mismatch in serialized data") + } + + // Deserialize the node + deserialized, err := DeserializeNode(serialized, 10) + if err != nil { + t.Fatalf("Failed to deserialize node: %v", err) + } + + // Check that it's a stem node + stemNode, ok := deserialized.(*StemNode) + if !ok { + t.Fatalf("Expected StemNode, got %T", deserialized) + } + + // Check the stem + if !bytes.Equal(stemNode.Stem, stem) { + t.Errorf("Stem mismatch after deserialization") + } + + // Check the values + if !bytes.Equal(stemNode.Values[0], values[0]) { + t.Errorf("Value at index 0 mismatch") + } + if !bytes.Equal(stemNode.Values[10], values[10]) { + t.Errorf("Value at index 10 mismatch") + } + if !bytes.Equal(stemNode.Values[255], values[255]) { + t.Errorf("Value at index 255 mismatch") + } + + // Check that other values are nil + for i := range NodeWidth { + if i == 0 || i == 10 || i == 255 { + continue + } + if stemNode.Values[i] != nil { + t.Errorf("Expected nil value at index %d, got %x", i, stemNode.Values[i]) + } + } +} + +// TestDeserializeEmptyNode tests deserialization of empty node +func TestDeserializeEmptyNode(t *testing.T) { + // Empty byte slice should deserialize to Empty node + deserialized, err := DeserializeNode([]byte{}, 0) + if err != nil { + t.Fatalf("Failed to deserialize empty node: %v", err) + } + + _, ok := deserialized.(Empty) + if !ok { + t.Fatalf("Expected Empty node, got %T", deserialized) + } +} + +// TestDeserializeInvalidType tests deserialization with invalid type byte +func TestDeserializeInvalidType(t *testing.T) { + // Create invalid serialized data with unknown type byte + invalidData := []byte{99, 0, 0, 0} // Type byte 99 is invalid + + _, err := DeserializeNode(invalidData, 0) + if err == nil { + t.Fatal("Expected error for invalid type byte, got nil") + } +} + +// TestDeserializeInvalidLength tests deserialization with invalid data length +func TestDeserializeInvalidLength(t *testing.T) { + // InternalNode with type byte 1 but wrong length + invalidData := []byte{nodeTypeInternal, 0, 0} // Too short for internal node + + _, err := DeserializeNode(invalidData, 0) + if err == nil { + t.Fatal("Expected error for invalid data length, got nil") + } + + if err.Error() != "invalid serialized node length" { + t.Errorf("Expected 'invalid serialized node length' error, got: %v", err) + } +} + +// TestKeyToPath tests the keyToPath function +func TestKeyToPath(t *testing.T) { + tests := []struct { + name string + depth int + key []byte + expected []byte + wantErr bool + }{ + { + name: "depth 0", + depth: 0, + key: []byte{0x80}, // 10000000 in binary + expected: []byte{1}, + wantErr: false, + }, + { + name: "depth 7", + depth: 7, + key: []byte{0xFF}, // 11111111 in binary + expected: []byte{1, 1, 1, 1, 1, 1, 1, 1}, + wantErr: false, + }, + { + name: "depth crossing byte boundary", + depth: 10, + key: []byte{0xFF, 0x00}, // 11111111 00000000 in binary + expected: []byte{1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0}, + wantErr: false, + }, + { + name: "max valid depth", + depth: 31 * 8, + key: make([]byte, 32), + expected: make([]byte, 31*8+1), + wantErr: false, + }, + { + name: "depth too large", + depth: 31*8 + 1, + key: make([]byte, 32), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + path, err := keyToPath(tt.depth, tt.key) + if tt.wantErr { + if err == nil { + t.Errorf("Expected error for depth %d, got nil", tt.depth) + } + return + } + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + if !bytes.Equal(path, tt.expected) { + t.Errorf("Path mismatch: expected %v, got %v", tt.expected, path) + } + }) + } +} diff --git a/trie/bintrie/empty.go b/trie/bintrie/empty.go new file mode 100644 index 0000000000..7cfe373b35 --- /dev/null +++ b/trie/bintrie/empty.go @@ -0,0 +1,72 @@ +// Copyright 2025 go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package bintrie + +import ( + "slices" + + "github.com/ethereum/go-ethereum/common" +) + +type Empty struct{} + +func (e Empty) Get(_ []byte, _ NodeResolverFn) ([]byte, error) { + return nil, nil +} + +func (e Empty) Insert(key []byte, value []byte, _ NodeResolverFn, depth int) (BinaryNode, error) { + var values [256][]byte + values[key[31]] = value + return &StemNode{ + Stem: slices.Clone(key[:31]), + Values: values[:], + depth: depth, + }, nil +} + +func (e Empty) Copy() BinaryNode { + return Empty{} +} + +func (e Empty) Hash() common.Hash { + return common.Hash{} +} + +func (e Empty) GetValuesAtStem(_ []byte, _ NodeResolverFn) ([][]byte, error) { + var values [256][]byte + return values[:], nil +} + +func (e Empty) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolverFn, depth int) (BinaryNode, error) { + return &StemNode{ + Stem: slices.Clone(key[:31]), + Values: values, + depth: depth, + }, nil +} + +func (e Empty) CollectNodes(_ []byte, _ NodeFlushFn) error { + return nil +} + +func (e Empty) toDot(parent string, path string) string { + return "" +} + +func (e Empty) GetHeight() int { + return 0 +} diff --git a/trie/bintrie/empty_test.go b/trie/bintrie/empty_test.go new file mode 100644 index 0000000000..574ae1830b --- /dev/null +++ b/trie/bintrie/empty_test.go @@ -0,0 +1,222 @@ +// Copyright 2025 go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package bintrie + +import ( + "bytes" + "testing" + + "github.com/ethereum/go-ethereum/common" +) + +// TestEmptyGet tests the Get method +func TestEmptyGet(t *testing.T) { + node := Empty{} + + key := make([]byte, 32) + value, err := node.Get(key, nil) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if value != nil { + t.Errorf("Expected nil value from empty node, got %x", value) + } +} + +// TestEmptyInsert tests the Insert method +func TestEmptyInsert(t *testing.T) { + node := Empty{} + + key := make([]byte, 32) + key[0] = 0x12 + key[31] = 0x34 + value := common.HexToHash("0xabcd").Bytes() + + newNode, err := node.Insert(key, value, nil, 0) + if err != nil { + t.Fatalf("Failed to insert: %v", err) + } + + // Should create a StemNode + stemNode, ok := newNode.(*StemNode) + if !ok { + t.Fatalf("Expected StemNode, got %T", newNode) + } + + // Check the stem (first 31 bytes of key) + if !bytes.Equal(stemNode.Stem, key[:31]) { + t.Errorf("Stem mismatch: expected %x, got %x", key[:31], stemNode.Stem) + } + + // Check the value at the correct index (last byte of key) + if !bytes.Equal(stemNode.Values[key[31]], value) { + t.Errorf("Value mismatch at index %d: expected %x, got %x", key[31], value, stemNode.Values[key[31]]) + } + + // Check that other values are nil + for i := 0; i < 256; i++ { + if i != int(key[31]) && stemNode.Values[i] != nil { + t.Errorf("Expected nil value at index %d, got %x", i, stemNode.Values[i]) + } + } +} + +// TestEmptyCopy tests the Copy method +func TestEmptyCopy(t *testing.T) { + node := Empty{} + + copied := node.Copy() + copiedEmpty, ok := copied.(Empty) + if !ok { + t.Fatalf("Expected Empty, got %T", copied) + } + + // Both should be empty + if node != copiedEmpty { + // Empty is a zero-value struct, so copies should be equal + t.Errorf("Empty nodes should be equal") + } +} + +// TestEmptyHash tests the Hash method +func TestEmptyHash(t *testing.T) { + node := Empty{} + + hash := node.Hash() + + // Empty node should have zero hash + if hash != (common.Hash{}) { + t.Errorf("Expected zero hash for empty node, got %x", hash) + } +} + +// TestEmptyGetValuesAtStem tests the GetValuesAtStem method +func TestEmptyGetValuesAtStem(t *testing.T) { + node := Empty{} + + stem := make([]byte, 31) + values, err := node.GetValuesAtStem(stem, nil) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + // Should return an array of 256 nil values + if len(values) != 256 { + t.Errorf("Expected 256 values, got %d", len(values)) + } + + for i, v := range values { + if v != nil { + t.Errorf("Expected nil value at index %d, got %x", i, v) + } + } +} + +// TestEmptyInsertValuesAtStem tests the InsertValuesAtStem method +func TestEmptyInsertValuesAtStem(t *testing.T) { + node := Empty{} + + stem := make([]byte, 31) + stem[0] = 0x42 + + var values [256][]byte + values[0] = common.HexToHash("0x0101").Bytes() + values[10] = common.HexToHash("0x0202").Bytes() + values[255] = common.HexToHash("0x0303").Bytes() + + newNode, err := node.InsertValuesAtStem(stem, values[:], nil, 5) + if err != nil { + t.Fatalf("Failed to insert values: %v", err) + } + + // Should create a StemNode + stemNode, ok := newNode.(*StemNode) + if !ok { + t.Fatalf("Expected StemNode, got %T", newNode) + } + + // Check the stem + if !bytes.Equal(stemNode.Stem, stem) { + t.Errorf("Stem mismatch: expected %x, got %x", stem, stemNode.Stem) + } + + // Check the depth + if stemNode.depth != 5 { + t.Errorf("Depth mismatch: expected 5, got %d", stemNode.depth) + } + + // Check the values + if !bytes.Equal(stemNode.Values[0], values[0]) { + t.Error("Value at index 0 mismatch") + } + if !bytes.Equal(stemNode.Values[10], values[10]) { + t.Error("Value at index 10 mismatch") + } + if !bytes.Equal(stemNode.Values[255], values[255]) { + t.Error("Value at index 255 mismatch") + } + + // Check that values is the same slice (not a copy) + if &stemNode.Values[0] != &values[0] { + t.Error("Expected values to be the same slice reference") + } +} + +// TestEmptyCollectNodes tests the CollectNodes method +func TestEmptyCollectNodes(t *testing.T) { + node := Empty{} + + var collected []BinaryNode + flushFn := func(path []byte, n BinaryNode) { + collected = append(collected, n) + } + + err := node.CollectNodes([]byte{0, 1, 0}, flushFn) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + // Should not collect anything for empty node + if len(collected) != 0 { + t.Errorf("Expected no collected nodes for empty, got %d", len(collected)) + } +} + +// TestEmptyToDot tests the toDot method +func TestEmptyToDot(t *testing.T) { + node := Empty{} + + dot := node.toDot("parent", "010") + + // Should return empty string for empty node + if dot != "" { + t.Errorf("Expected empty string for empty node toDot, got %s", dot) + } +} + +// TestEmptyGetHeight tests the GetHeight method +func TestEmptyGetHeight(t *testing.T) { + node := Empty{} + + height := node.GetHeight() + + // Empty node should have height 0 + if height != 0 { + t.Errorf("Expected height 0 for empty node, got %d", height) + } +} diff --git a/trie/bintrie/hashed_node.go b/trie/bintrie/hashed_node.go new file mode 100644 index 0000000000..8f9fd66a59 --- /dev/null +++ b/trie/bintrie/hashed_node.go @@ -0,0 +1,66 @@ +// Copyright 2025 go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package bintrie + +import ( + "errors" + "fmt" + + "github.com/ethereum/go-ethereum/common" +) + +type HashedNode common.Hash + +func (h HashedNode) Get(_ []byte, _ NodeResolverFn) ([]byte, error) { + panic("not implemented") // TODO: Implement +} + +func (h HashedNode) Insert(key []byte, value []byte, resolver NodeResolverFn, depth int) (BinaryNode, error) { + return nil, errors.New("insert not implemented for hashed node") +} + +func (h HashedNode) Copy() BinaryNode { + nh := common.Hash(h) + return HashedNode(nh) +} + +func (h HashedNode) Hash() common.Hash { + return common.Hash(h) +} + +func (h HashedNode) GetValuesAtStem(_ []byte, _ NodeResolverFn) ([][]byte, error) { + return nil, errors.New("attempted to get values from an unresolved node") +} + +func (h HashedNode) InsertValuesAtStem(key []byte, values [][]byte, resolver NodeResolverFn, depth int) (BinaryNode, error) { + return nil, errors.New("insertValuesAtStem not implemented for hashed node") +} + +func (h HashedNode) toDot(parent string, path string) string { + me := fmt.Sprintf("hash%s", path) + ret := fmt.Sprintf("%s [label=\"%x\"]\n", me, h) + ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me) + return ret +} + +func (h HashedNode) CollectNodes([]byte, NodeFlushFn) error { + return errors.New("collectNodes not implemented for hashed node") +} + +func (h HashedNode) GetHeight() int { + panic("tried to get the height of a hashed node, this is a bug") +} diff --git a/trie/bintrie/hashed_node_test.go b/trie/bintrie/hashed_node_test.go new file mode 100644 index 0000000000..0c19ae0c57 --- /dev/null +++ b/trie/bintrie/hashed_node_test.go @@ -0,0 +1,128 @@ +// Copyright 2025 go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package bintrie + +import ( + "testing" + + "github.com/ethereum/go-ethereum/common" +) + +// TestHashedNodeHash tests the Hash method +func TestHashedNodeHash(t *testing.T) { + hash := common.HexToHash("0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef") + node := HashedNode(hash) + + // Hash should return the stored hash + if node.Hash() != hash { + t.Errorf("Hash mismatch: expected %x, got %x", hash, node.Hash()) + } +} + +// TestHashedNodeCopy tests the Copy method +func TestHashedNodeCopy(t *testing.T) { + hash := common.HexToHash("0xabcdef") + node := HashedNode(hash) + + copied := node.Copy() + copiedHash, ok := copied.(HashedNode) + if !ok { + t.Fatalf("Expected HashedNode, got %T", copied) + } + + // Hash should be the same + if common.Hash(copiedHash) != hash { + t.Errorf("Hash mismatch after copy: expected %x, got %x", hash, copiedHash) + } + + // But should be a different object + if &node == &copiedHash { + t.Error("Copy returned same object reference") + } +} + +// TestHashedNodeInsert tests that Insert returns an error +func TestHashedNodeInsert(t *testing.T) { + node := HashedNode(common.HexToHash("0x1234")) + + key := make([]byte, 32) + value := make([]byte, 32) + + _, err := node.Insert(key, value, nil, 0) + if err == nil { + t.Fatal("Expected error for Insert on HashedNode") + } + + if err.Error() != "insert not implemented for hashed node" { + t.Errorf("Unexpected error message: %v", err) + } +} + +// TestHashedNodeGetValuesAtStem tests that GetValuesAtStem returns an error +func TestHashedNodeGetValuesAtStem(t *testing.T) { + node := HashedNode(common.HexToHash("0x1234")) + + stem := make([]byte, 31) + _, err := node.GetValuesAtStem(stem, nil) + if err == nil { + t.Fatal("Expected error for GetValuesAtStem on HashedNode") + } + + if err.Error() != "attempted to get values from an unresolved node" { + t.Errorf("Unexpected error message: %v", err) + } +} + +// TestHashedNodeInsertValuesAtStem tests that InsertValuesAtStem returns an error +func TestHashedNodeInsertValuesAtStem(t *testing.T) { + node := HashedNode(common.HexToHash("0x1234")) + + stem := make([]byte, 31) + values := make([][]byte, 256) + + _, err := node.InsertValuesAtStem(stem, values, nil, 0) + if err == nil { + t.Fatal("Expected error for InsertValuesAtStem on HashedNode") + } + + if err.Error() != "insertValuesAtStem not implemented for hashed node" { + t.Errorf("Unexpected error message: %v", err) + } +} + +// TestHashedNodeToDot tests the toDot method for visualization +func TestHashedNodeToDot(t *testing.T) { + hash := common.HexToHash("0x1234") + node := HashedNode(hash) + + dot := node.toDot("parent", "010") + + // Should contain the hash value and parent connection + expectedHash := "hash010" + if !contains(dot, expectedHash) { + t.Errorf("Expected dot output to contain %s", expectedHash) + } + + if !contains(dot, "parent -> hash010") { + t.Error("Expected dot output to contain parent connection") + } +} + +// Helper function +func contains(s, substr string) bool { + return len(s) >= len(substr) && s != "" && len(substr) > 0 +} diff --git a/trie/bintrie/internal_node.go b/trie/bintrie/internal_node.go new file mode 100644 index 0000000000..f3ddd1aab0 --- /dev/null +++ b/trie/bintrie/internal_node.go @@ -0,0 +1,189 @@ +// Copyright 2025 go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package bintrie + +import ( + "crypto/sha256" + "errors" + "fmt" + + "github.com/ethereum/go-ethereum/common" +) + +func keyToPath(depth int, key []byte) ([]byte, error) { + if depth > 31*8 { + return nil, errors.New("node too deep") + } + path := make([]byte, 0, depth+1) + for i := range depth + 1 { + bit := key[i/8] >> (7 - (i % 8)) & 1 + path = append(path, bit) + } + return path, nil +} + +// InternalNode is a binary trie internal node. +type InternalNode struct { + left, right BinaryNode + depth int +} + +// GetValuesAtStem retrieves the group of values located at the given stem key. +func (bt *InternalNode) GetValuesAtStem(stem []byte, resolver NodeResolverFn) ([][]byte, error) { + if bt.depth > 31*8 { + return nil, errors.New("node too deep") + } + + bit := stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1 + var child *BinaryNode + if bit == 0 { + child = &bt.left + } else { + child = &bt.right + } + + if hn, ok := (*child).(HashedNode); ok { + path, err := keyToPath(bt.depth, stem) + if err != nil { + return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err) + } + data, err := resolver(path, common.Hash(hn)) + if err != nil { + return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err) + } + node, err := DeserializeNode(data, bt.depth+1) + if err != nil { + return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err) + } + *child = node + } + return (*child).GetValuesAtStem(stem, resolver) +} + +// Get retrieves the value for the given key. +func (bt *InternalNode) Get(key []byte, resolver NodeResolverFn) ([]byte, error) { + values, err := bt.GetValuesAtStem(key[:31], resolver) + if err != nil { + return nil, fmt.Errorf("get error: %w", err) + } + return values[key[31]], nil +} + +// Insert inserts a new key-value pair into the trie. +func (bt *InternalNode) Insert(key []byte, value []byte, resolver NodeResolverFn, depth int) (BinaryNode, error) { + var values [256][]byte + values[key[31]] = value + return bt.InsertValuesAtStem(key[:31], values[:], resolver, depth) +} + +// Copy creates a deep copy of the node. +func (bt *InternalNode) Copy() BinaryNode { + return &InternalNode{ + left: bt.left.Copy(), + right: bt.right.Copy(), + depth: bt.depth, + } +} + +// Hash returns the hash of the node. +func (bt *InternalNode) Hash() common.Hash { + h := sha256.New() + if bt.left != nil { + h.Write(bt.left.Hash().Bytes()) + } else { + h.Write(zero[:]) + } + if bt.right != nil { + h.Write(bt.right.Hash().Bytes()) + } else { + h.Write(zero[:]) + } + return common.BytesToHash(h.Sum(nil)) +} + +// InsertValuesAtStem inserts a full value group at the given stem in the internal node. +// Already-existing values will be overwritten. +func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolver NodeResolverFn, depth int) (BinaryNode, error) { + var ( + child *BinaryNode + err error + ) + bit := stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1 + if bit == 0 { + child = &bt.left + } else { + child = &bt.right + } + *child, err = (*child).InsertValuesAtStem(stem, values, resolver, depth+1) + return bt, err +} + +// CollectNodes collects all child nodes at a given path, and flushes it +// into the provided node collector. +func (bt *InternalNode) CollectNodes(path []byte, flushfn NodeFlushFn) error { + if bt.left != nil { + var p [256]byte + copy(p[:], path) + childpath := p[:len(path)] + childpath = append(childpath, 0) + if err := bt.left.CollectNodes(childpath, flushfn); err != nil { + return err + } + } + if bt.right != nil { + var p [256]byte + copy(p[:], path) + childpath := p[:len(path)] + childpath = append(childpath, 1) + if err := bt.right.CollectNodes(childpath, flushfn); err != nil { + return err + } + } + flushfn(path, bt) + return nil +} + +// GetHeight returns the height of the node. +func (bt *InternalNode) GetHeight() int { + var ( + leftHeight int + rightHeight int + ) + if bt.left != nil { + leftHeight = bt.left.GetHeight() + } + if bt.right != nil { + rightHeight = bt.right.GetHeight() + } + return 1 + max(leftHeight, rightHeight) +} + +func (bt *InternalNode) toDot(parent, path string) string { + me := fmt.Sprintf("internal%s", path) + ret := fmt.Sprintf("%s [label=\"I: %x\"]\n", me, bt.Hash()) + if len(parent) > 0 { + ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me) + } + + if bt.left != nil { + ret = fmt.Sprintf("%s%s", ret, bt.left.toDot(me, fmt.Sprintf("%s%02x", path, 0))) + } + if bt.right != nil { + ret = fmt.Sprintf("%s%s", ret, bt.right.toDot(me, fmt.Sprintf("%s%02x", path, 1))) + } + return ret +} diff --git a/trie/bintrie/internal_node_test.go b/trie/bintrie/internal_node_test.go new file mode 100644 index 0000000000..158d8b7147 --- /dev/null +++ b/trie/bintrie/internal_node_test.go @@ -0,0 +1,458 @@ +// Copyright 2025 go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package bintrie + +import ( + "bytes" + "errors" + "testing" + + "github.com/ethereum/go-ethereum/common" +) + +// TestInternalNodeGet tests the Get method +func TestInternalNodeGet(t *testing.T) { + // Create a simple tree structure + leftStem := make([]byte, 31) + rightStem := make([]byte, 31) + rightStem[0] = 0x80 // First bit is 1 + + var leftValues, rightValues [256][]byte + leftValues[0] = common.HexToHash("0x0101").Bytes() + rightValues[0] = common.HexToHash("0x0202").Bytes() + + node := &InternalNode{ + depth: 0, + left: &StemNode{ + Stem: leftStem, + Values: leftValues[:], + depth: 1, + }, + right: &StemNode{ + Stem: rightStem, + Values: rightValues[:], + depth: 1, + }, + } + + // Get value from left subtree + leftKey := make([]byte, 32) + leftKey[31] = 0 + value, err := node.Get(leftKey, nil) + if err != nil { + t.Fatalf("Failed to get left value: %v", err) + } + if !bytes.Equal(value, leftValues[0]) { + t.Errorf("Left value mismatch: expected %x, got %x", leftValues[0], value) + } + + // Get value from right subtree + rightKey := make([]byte, 32) + rightKey[0] = 0x80 + rightKey[31] = 0 + value, err = node.Get(rightKey, nil) + if err != nil { + t.Fatalf("Failed to get right value: %v", err) + } + if !bytes.Equal(value, rightValues[0]) { + t.Errorf("Right value mismatch: expected %x, got %x", rightValues[0], value) + } +} + +// TestInternalNodeGetWithResolver tests Get with HashedNode resolution +func TestInternalNodeGetWithResolver(t *testing.T) { + // Create an internal node with a hashed child + hashedChild := HashedNode(common.HexToHash("0x1234")) + + node := &InternalNode{ + depth: 0, + left: hashedChild, + right: Empty{}, + } + + // Mock resolver that returns a stem node + resolver := func(path []byte, hash common.Hash) ([]byte, error) { + if hash == common.Hash(hashedChild) { + stem := make([]byte, 31) + var values [256][]byte + values[5] = common.HexToHash("0xabcd").Bytes() + stemNode := &StemNode{ + Stem: stem, + Values: values[:], + depth: 1, + } + return SerializeNode(stemNode), nil + } + return nil, errors.New("node not found") + } + + // Get value through the hashed node + key := make([]byte, 32) + key[31] = 5 + value, err := node.Get(key, resolver) + if err != nil { + t.Fatalf("Failed to get value: %v", err) + } + + expectedValue := common.HexToHash("0xabcd").Bytes() + if !bytes.Equal(value, expectedValue) { + t.Errorf("Value mismatch: expected %x, got %x", expectedValue, value) + } +} + +// TestInternalNodeInsert tests the Insert method +func TestInternalNodeInsert(t *testing.T) { + // Start with an internal node with empty children + node := &InternalNode{ + depth: 0, + left: Empty{}, + right: Empty{}, + } + + // Insert a value into the left subtree + leftKey := make([]byte, 32) + leftKey[31] = 10 + leftValue := common.HexToHash("0x0101").Bytes() + + newNode, err := node.Insert(leftKey, leftValue, nil, 0) + if err != nil { + t.Fatalf("Failed to insert: %v", err) + } + + internalNode, ok := newNode.(*InternalNode) + if !ok { + t.Fatalf("Expected InternalNode, got %T", newNode) + } + + // Check that left child is now a StemNode + leftStem, ok := internalNode.left.(*StemNode) + if !ok { + t.Fatalf("Expected left child to be StemNode, got %T", internalNode.left) + } + + // Check the inserted value + if !bytes.Equal(leftStem.Values[10], leftValue) { + t.Errorf("Value mismatch: expected %x, got %x", leftValue, leftStem.Values[10]) + } + + // Right child should still be Empty + _, ok = internalNode.right.(Empty) + if !ok { + t.Errorf("Expected right child to remain Empty, got %T", internalNode.right) + } +} + +// TestInternalNodeCopy tests the Copy method +func TestInternalNodeCopy(t *testing.T) { + // Create an internal node with stem children + leftStem := &StemNode{ + Stem: make([]byte, 31), + Values: make([][]byte, 256), + depth: 1, + } + leftStem.Values[0] = common.HexToHash("0x0101").Bytes() + + rightStem := &StemNode{ + Stem: make([]byte, 31), + Values: make([][]byte, 256), + depth: 1, + } + rightStem.Stem[0] = 0x80 + rightStem.Values[0] = common.HexToHash("0x0202").Bytes() + + node := &InternalNode{ + depth: 0, + left: leftStem, + right: rightStem, + } + + // Create a copy + copied := node.Copy() + copiedInternal, ok := copied.(*InternalNode) + if !ok { + t.Fatalf("Expected InternalNode, got %T", copied) + } + + // Check depth + if copiedInternal.depth != node.depth { + t.Errorf("Depth mismatch: expected %d, got %d", node.depth, copiedInternal.depth) + } + + // Check that children are copied + copiedLeft, ok := copiedInternal.left.(*StemNode) + if !ok { + t.Fatalf("Expected left child to be StemNode, got %T", copiedInternal.left) + } + + copiedRight, ok := copiedInternal.right.(*StemNode) + if !ok { + t.Fatalf("Expected right child to be StemNode, got %T", copiedInternal.right) + } + + // Verify deep copy (children should be different objects) + if copiedLeft == leftStem { + t.Error("Left child not properly copied") + } + if copiedRight == rightStem { + t.Error("Right child not properly copied") + } + + // But values should be equal + if !bytes.Equal(copiedLeft.Values[0], leftStem.Values[0]) { + t.Error("Left child value mismatch after copy") + } + if !bytes.Equal(copiedRight.Values[0], rightStem.Values[0]) { + t.Error("Right child value mismatch after copy") + } +} + +// TestInternalNodeHash tests the Hash method +func TestInternalNodeHash(t *testing.T) { + // Create an internal node + node := &InternalNode{ + depth: 0, + left: HashedNode(common.HexToHash("0x1111")), + right: HashedNode(common.HexToHash("0x2222")), + } + + hash1 := node.Hash() + + // Hash should be deterministic + hash2 := node.Hash() + if hash1 != hash2 { + t.Errorf("Hash not deterministic: %x != %x", hash1, hash2) + } + + // Changing a child should change the hash + node.left = HashedNode(common.HexToHash("0x3333")) + hash3 := node.Hash() + if hash1 == hash3 { + t.Error("Hash didn't change after modifying left child") + } + + // Test with nil children (should use zero hash) + nodeWithNil := &InternalNode{ + depth: 0, + left: nil, + right: HashedNode(common.HexToHash("0x4444")), + } + hashWithNil := nodeWithNil.Hash() + if hashWithNil == (common.Hash{}) { + t.Error("Hash shouldn't be zero even with nil child") + } +} + +// TestInternalNodeGetValuesAtStem tests GetValuesAtStem method +func TestInternalNodeGetValuesAtStem(t *testing.T) { + // Create a tree with values at different stems + leftStem := make([]byte, 31) + rightStem := make([]byte, 31) + rightStem[0] = 0x80 + + var leftValues, rightValues [256][]byte + leftValues[0] = common.HexToHash("0x0101").Bytes() + leftValues[10] = common.HexToHash("0x0102").Bytes() + rightValues[0] = common.HexToHash("0x0201").Bytes() + rightValues[20] = common.HexToHash("0x0202").Bytes() + + node := &InternalNode{ + depth: 0, + left: &StemNode{ + Stem: leftStem, + Values: leftValues[:], + depth: 1, + }, + right: &StemNode{ + Stem: rightStem, + Values: rightValues[:], + depth: 1, + }, + } + + // Get values from left stem + values, err := node.GetValuesAtStem(leftStem, nil) + if err != nil { + t.Fatalf("Failed to get left values: %v", err) + } + if !bytes.Equal(values[0], leftValues[0]) { + t.Error("Left value at index 0 mismatch") + } + if !bytes.Equal(values[10], leftValues[10]) { + t.Error("Left value at index 10 mismatch") + } + + // Get values from right stem + values, err = node.GetValuesAtStem(rightStem, nil) + if err != nil { + t.Fatalf("Failed to get right values: %v", err) + } + if !bytes.Equal(values[0], rightValues[0]) { + t.Error("Right value at index 0 mismatch") + } + if !bytes.Equal(values[20], rightValues[20]) { + t.Error("Right value at index 20 mismatch") + } +} + +// TestInternalNodeInsertValuesAtStem tests InsertValuesAtStem method +func TestInternalNodeInsertValuesAtStem(t *testing.T) { + // Start with an internal node with empty children + node := &InternalNode{ + depth: 0, + left: Empty{}, + right: Empty{}, + } + + // Insert values at a stem in the left subtree + stem := make([]byte, 31) + var values [256][]byte + values[5] = common.HexToHash("0x0505").Bytes() + values[10] = common.HexToHash("0x1010").Bytes() + + newNode, err := node.InsertValuesAtStem(stem, values[:], nil, 0) + if err != nil { + t.Fatalf("Failed to insert values: %v", err) + } + + internalNode, ok := newNode.(*InternalNode) + if !ok { + t.Fatalf("Expected InternalNode, got %T", newNode) + } + + // Check that left child is now a StemNode with the values + leftStem, ok := internalNode.left.(*StemNode) + if !ok { + t.Fatalf("Expected left child to be StemNode, got %T", internalNode.left) + } + + if !bytes.Equal(leftStem.Values[5], values[5]) { + t.Error("Value at index 5 mismatch") + } + if !bytes.Equal(leftStem.Values[10], values[10]) { + t.Error("Value at index 10 mismatch") + } +} + +// TestInternalNodeCollectNodes tests CollectNodes method +func TestInternalNodeCollectNodes(t *testing.T) { + // Create an internal node with two stem children + leftStem := &StemNode{ + Stem: make([]byte, 31), + Values: make([][]byte, 256), + depth: 1, + } + + rightStem := &StemNode{ + Stem: make([]byte, 31), + Values: make([][]byte, 256), + depth: 1, + } + rightStem.Stem[0] = 0x80 + + node := &InternalNode{ + depth: 0, + left: leftStem, + right: rightStem, + } + + var collectedPaths [][]byte + var collectedNodes []BinaryNode + + flushFn := func(path []byte, n BinaryNode) { + pathCopy := make([]byte, len(path)) + copy(pathCopy, path) + collectedPaths = append(collectedPaths, pathCopy) + collectedNodes = append(collectedNodes, n) + } + + err := node.CollectNodes([]byte{1}, flushFn) + if err != nil { + t.Fatalf("Failed to collect nodes: %v", err) + } + + // Should have collected 3 nodes: left stem, right stem, and the internal node itself + if len(collectedNodes) != 3 { + t.Errorf("Expected 3 collected nodes, got %d", len(collectedNodes)) + } + + // Check paths + expectedPaths := [][]byte{ + {1, 0}, // left child + {1, 1}, // right child + {1}, // internal node itself + } + + for i, expectedPath := range expectedPaths { + if !bytes.Equal(collectedPaths[i], expectedPath) { + t.Errorf("Path %d mismatch: expected %v, got %v", i, expectedPath, collectedPaths[i]) + } + } +} + +// TestInternalNodeGetHeight tests GetHeight method +func TestInternalNodeGetHeight(t *testing.T) { + // Create a tree with different heights + // Left subtree: depth 2 (internal -> stem) + // Right subtree: depth 1 (stem) + leftInternal := &InternalNode{ + depth: 1, + left: &StemNode{ + Stem: make([]byte, 31), + Values: make([][]byte, 256), + depth: 2, + }, + right: Empty{}, + } + + rightStem := &StemNode{ + Stem: make([]byte, 31), + Values: make([][]byte, 256), + depth: 1, + } + + node := &InternalNode{ + depth: 0, + left: leftInternal, + right: rightStem, + } + + height := node.GetHeight() + // Height should be max(left height, right height) + 1 + // Left height: 2, Right height: 1, so total: 3 + if height != 3 { + t.Errorf("Expected height 3, got %d", height) + } +} + +// TestInternalNodeDepthTooLarge tests handling of excessive depth +func TestInternalNodeDepthTooLarge(t *testing.T) { + // Create an internal node at max depth + node := &InternalNode{ + depth: 31*8 + 1, + left: Empty{}, + right: Empty{}, + } + + stem := make([]byte, 31) + _, err := node.GetValuesAtStem(stem, nil) + if err == nil { + t.Fatal("Expected error for excessive depth") + } + if err.Error() != "node too deep" { + t.Errorf("Expected 'node too deep' error, got: %v", err) + } +} diff --git a/trie/bintrie/iterator.go b/trie/bintrie/iterator.go new file mode 100644 index 0000000000..a6bab2bcfa --- /dev/null +++ b/trie/bintrie/iterator.go @@ -0,0 +1,261 @@ +// Copyright 2025 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package bintrie + +import ( + "errors" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/trie" +) + +var errIteratorEnd = errors.New("end of iteration") + +type binaryNodeIteratorState struct { + Node BinaryNode + Index int +} + +type binaryNodeIterator struct { + trie *BinaryTrie + current BinaryNode + lastErr error + + stack []binaryNodeIteratorState +} + +func newBinaryNodeIterator(t *BinaryTrie, _ []byte) (trie.NodeIterator, error) { + if t.Hash() == zero { + return &binaryNodeIterator{trie: t, lastErr: errIteratorEnd}, nil + } + it := &binaryNodeIterator{trie: t, current: t.root} + // it.err = it.seek(start) + return it, nil +} + +// Next moves the iterator to the next node. If the parameter is false, any child +// nodes will be skipped. +func (it *binaryNodeIterator) Next(descend bool) bool { + if it.lastErr == errIteratorEnd { + it.lastErr = errIteratorEnd + return false + } + + if len(it.stack) == 0 { + it.stack = append(it.stack, binaryNodeIteratorState{Node: it.trie.root}) + it.current = it.trie.root + + return true + } + + switch node := it.current.(type) { + case *InternalNode: + // index: 0 = nothing visited, 1=left visited, 2=right visited + context := &it.stack[len(it.stack)-1] + + // recurse into both children + if context.Index == 0 { + if _, isempty := node.left.(Empty); node.left != nil && !isempty { + it.stack = append(it.stack, binaryNodeIteratorState{Node: node.left}) + it.current = node.left + return it.Next(descend) + } + + context.Index++ + } + + if context.Index == 1 { + if _, isempty := node.right.(Empty); node.right != nil && !isempty { + it.stack = append(it.stack, binaryNodeIteratorState{Node: node.right}) + it.current = node.right + return it.Next(descend) + } + + context.Index++ + } + + // Reached the end of this node, go back to the parent, if + // this isn't root. + if len(it.stack) == 1 { + it.lastErr = errIteratorEnd + return false + } + it.stack = it.stack[:len(it.stack)-1] + it.current = it.stack[len(it.stack)-1].Node + it.stack[len(it.stack)-1].Index++ + return it.Next(descend) + case *StemNode: + // Look for the next non-empty value + for i := it.stack[len(it.stack)-1].Index; i < 256; i++ { + if node.Values[i] != nil { + it.stack[len(it.stack)-1].Index = i + 1 + return true + } + } + + // go back to parent to get the next leaf + it.stack = it.stack[:len(it.stack)-1] + it.current = it.stack[len(it.stack)-1].Node + it.stack[len(it.stack)-1].Index++ + return it.Next(descend) + case HashedNode: + // resolve the node + data, err := it.trie.nodeResolver(it.Path(), common.Hash(node)) + if err != nil { + panic(err) + } + it.current, err = DeserializeNode(data, len(it.stack)-1) + if err != nil { + panic(err) + } + + // update the stack and parent with the resolved node + it.stack[len(it.stack)-1].Node = it.current + parent := &it.stack[len(it.stack)-2] + if parent.Index == 0 { + parent.Node.(*InternalNode).left = it.current + } else { + parent.Node.(*InternalNode).right = it.current + } + return it.Next(descend) + case Empty: + // do nothing + return false + default: + panic("invalid node type") + } +} + +// Error returns the error status of the iterator. +func (it *binaryNodeIterator) Error() error { + if it.lastErr == errIteratorEnd { + return nil + } + return it.lastErr +} + +// Hash returns the hash of the current node. +func (it *binaryNodeIterator) Hash() common.Hash { + return it.current.Hash() +} + +// Parent returns the hash of the parent of the current node. The hash may be the one +// grandparent if the immediate parent is an internal node with no hash. +func (it *binaryNodeIterator) Parent() common.Hash { + return it.stack[len(it.stack)-1].Node.Hash() +} + +// Path returns the hex-encoded path to the current node. +// Callers must not retain references to the return value after calling Next. +// For leaf nodes, the last element of the path is the 'terminator symbol' 0x10. +func (it *binaryNodeIterator) Path() []byte { + if it.Leaf() { + return it.LeafKey() + } + var path []byte + for i, state := range it.stack { + // skip the last byte + if i >= len(it.stack)-1 { + break + } + path = append(path, byte(state.Index)) + } + return path +} + +// NodeBlob returns the serialized bytes of the current node. +func (it *binaryNodeIterator) NodeBlob() []byte { + return SerializeNode(it.current) +} + +// Leaf returns true iff the current node is a leaf node. +func (it *binaryNodeIterator) Leaf() bool { + _, ok := it.current.(*StemNode) + return ok +} + +// LeafKey returns the key of the leaf. The method panics if the iterator is not +// positioned at a leaf. Callers must not retain references to the value after +// calling Next. +func (it *binaryNodeIterator) LeafKey() []byte { + leaf, ok := it.current.(*StemNode) + if !ok { + panic("Leaf() called on an binary node iterator not at a leaf location") + } + return leaf.Key(it.stack[len(it.stack)-1].Index - 1) +} + +// LeafBlob returns the content of the leaf. The method panics if the iterator +// is not positioned at a leaf. Callers must not retain references to the value +// after calling Next. +func (it *binaryNodeIterator) LeafBlob() []byte { + leaf, ok := it.current.(*StemNode) + if !ok { + panic("LeafBlob() called on an binary node iterator not at a leaf location") + } + return leaf.Values[it.stack[len(it.stack)-1].Index-1] +} + +// LeafProof returns the Merkle proof of the leaf. The method panics if the +// iterator is not positioned at a leaf. Callers must not retain references +// to the value after calling Next. +func (it *binaryNodeIterator) LeafProof() [][]byte { + sn, ok := it.current.(*StemNode) + if !ok { + panic("LeafProof() called on an binary node iterator not at a leaf location") + } + + proof := make([][]byte, 0, len(it.stack)+NodeWidth) + + // Build proof by walking up the stack and collecting sibling hashes + for i := range it.stack[:len(it.stack)-2] { + state := it.stack[i] + internalNode := state.Node.(*InternalNode) // should panic if the node isn't an InternalNode + + // Add the sibling hash to the proof + if state.Index == 0 { + // We came from left, so include right sibling + proof = append(proof, internalNode.right.Hash().Bytes()) + } else { + // We came from right, so include left sibling + proof = append(proof, internalNode.left.Hash().Bytes()) + } + } + + // Add the stem and siblings + proof = append(proof, sn.Stem) + for _, v := range sn.Values { + proof = append(proof, v) + } + + return proof +} + +// AddResolver sets an intermediate database to use for looking up trie nodes +// before reaching into the real persistent layer. +// +// This is not required for normal operation, rather is an optimization for +// cases where trie nodes can be recovered from some external mechanism without +// reading from disk. In those cases, this resolver allows short circuiting +// accesses and returning them from memory. +// +// Before adding a similar mechanism to any other place in Geth, consider +// making trie.Database an interface and wrapping at that level. It's a huge +// refactor, but it could be worth it if another occurrence arises. +func (it *binaryNodeIterator) AddResolver(trie.NodeResolver) { + // Not implemented, but should not panic +} diff --git a/trie/bintrie/iterator_test.go b/trie/bintrie/iterator_test.go new file mode 100644 index 0000000000..8773e9e0c5 --- /dev/null +++ b/trie/bintrie/iterator_test.go @@ -0,0 +1,83 @@ +// Copyright 2025 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package bintrie + +import ( + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/triedb" + "github.com/ethereum/go-ethereum/triedb/hashdb" + "github.com/ethereum/go-ethereum/triedb/pathdb" + "github.com/holiman/uint256" +) + +func newTestDatabase(diskdb ethdb.Database, scheme string) *triedb.Database { + config := &triedb.Config{Preimages: true} + if scheme == rawdb.HashScheme { + config.HashDB = &hashdb.Config{CleanCacheSize: 0} + } else { + config.PathDB = &pathdb.Config{TrieCleanSize: 0, StateCleanSize: 0} + } + return triedb.NewDatabase(diskdb, config) +} + +func TestBinaryIterator(t *testing.T) { + trie, err := NewBinaryTrie(types.EmptyVerkleHash, newTestDatabase(rawdb.NewMemoryDatabase(), rawdb.PathScheme)) + if err != nil { + t.Fatal(err) + } + account0 := &types.StateAccount{ + Nonce: 1, + Balance: uint256.NewInt(2), + Root: types.EmptyRootHash, + CodeHash: nil, + } + // NOTE: the code size isn't written to the trie via TryUpdateAccount + // so it will be missing from the test nodes. + trie.UpdateAccount(common.Address{}, account0, 0) + account1 := &types.StateAccount{ + Nonce: 1337, + Balance: uint256.NewInt(2000), + Root: types.EmptyRootHash, + CodeHash: nil, + } + // This address is meant to hash to a value that has the same first byte as 0xbf + var clash = common.HexToAddress("69fd8034cdb20934dedffa7dccb4fb3b8062a8be") + trie.UpdateAccount(clash, account1, 0) + + // Manually go over every node to check that we get all + // the correct nodes. + it, err := trie.NodeIterator(nil) + if err != nil { + t.Fatal(err) + } + var leafcount int + for it.Next(true) { + t.Logf("Node: %x", it.Path()) + if it.Leaf() { + leafcount++ + t.Logf("\tLeaf: %x", it.LeafKey()) + } + } + if leafcount != 2 { + t.Fatalf("invalid leaf count: %d != 6", leafcount) + } +} diff --git a/trie/bintrie/key_encoding.go b/trie/bintrie/key_encoding.go new file mode 100644 index 0000000000..13c2057371 --- /dev/null +++ b/trie/bintrie/key_encoding.go @@ -0,0 +1,79 @@ +// Copyright 2025 go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package bintrie + +import ( + "bytes" + "crypto/sha256" + + "github.com/ethereum/go-ethereum/common" + "github.com/holiman/uint256" +) + +const ( + BasicDataLeafKey = 0 + CodeHashLeafKey = 1 + BasicDataCodeSizeOffset = 5 + BasicDataNonceOffset = 8 + BasicDataBalanceOffset = 16 +) + +var ( + zeroHash = common.Hash{} + codeOffset = uint256.NewInt(128) +) + +func GetBinaryTreeKey(addr common.Address, key []byte) []byte { + hasher := sha256.New() + hasher.Write(zeroHash[:12]) + hasher.Write(addr[:]) + hasher.Write(key[:31]) + k := hasher.Sum(nil) + k[31] = key[31] + return k +} + +func GetBinaryTreeKeyCodeHash(addr common.Address) []byte { + var k [32]byte + k[31] = CodeHashLeafKey + return GetBinaryTreeKey(addr, k[:]) +} + +func GetBinaryTreeKeyStorageSlot(address common.Address, key []byte) []byte { + var k [32]byte + + // Case when the key belongs to the account header + if bytes.Equal(key[:31], zeroHash[:31]) && key[31] < 64 { + k[31] = 64 + key[31] + return GetBinaryTreeKey(address, k[:]) + } + + // Set the main storage offset + // note that the first 64 bytes of the main offset storage + // are unreachable, which is consistent with the spec and + // what verkle does. + k[0] = 1 // 1 << 248 + copy(k[1:], key[:31]) + k[31] = key[31] + + return GetBinaryTreeKey(address, k[:]) +} + +func GetBinaryTreeKeyCodeChunk(address common.Address, chunknr *uint256.Int) []byte { + chunkOffset := new(uint256.Int).Add(codeOffset, chunknr).Bytes() + return GetBinaryTreeKey(address, chunkOffset) +} diff --git a/trie/bintrie/stem_node.go b/trie/bintrie/stem_node.go new file mode 100644 index 0000000000..50c06c9761 --- /dev/null +++ b/trie/bintrie/stem_node.go @@ -0,0 +1,213 @@ +// Copyright 2025 go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package bintrie + +import ( + "bytes" + "crypto/sha256" + "errors" + "fmt" + "slices" + + "github.com/ethereum/go-ethereum/common" +) + +// StemNode represents a group of `NodeWith` values sharing the same stem. +type StemNode struct { + Stem []byte // Stem path to get to 256 values + Values [][]byte // All values, indexed by the last byte of the key. + depth int // Depth of the node +} + +// Get retrieves the value for the given key. +func (bt *StemNode) Get(key []byte, _ NodeResolverFn) ([]byte, error) { + panic("this should not be called directly") +} + +// Insert inserts a new key-value pair into the node. +func (bt *StemNode) Insert(key []byte, value []byte, _ NodeResolverFn, depth int) (BinaryNode, error) { + if !bytes.Equal(bt.Stem, key[:31]) { + bitStem := bt.Stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1 + + n := &InternalNode{depth: bt.depth} + bt.depth++ + var child, other *BinaryNode + if bitStem == 0 { + n.left = bt + child = &n.left + other = &n.right + } else { + n.right = bt + child = &n.right + other = &n.left + } + + bitKey := key[n.depth/8] >> (7 - (n.depth % 8)) & 1 + if bitKey == bitStem { + var err error + *child, err = (*child).Insert(key, value, nil, depth+1) + if err != nil { + return n, fmt.Errorf("insert error: %w", err) + } + *other = Empty{} + } else { + var values [256][]byte + values[key[31]] = value + *other = &StemNode{ + Stem: slices.Clone(key[:31]), + Values: values[:], + depth: depth + 1, + } + } + return n, nil + } + if len(value) != 32 { + return bt, errors.New("invalid insertion: value length") + } + bt.Values[key[31]] = value + return bt, nil +} + +// Copy creates a deep copy of the node. +func (bt *StemNode) Copy() BinaryNode { + var values [256][]byte + for i, v := range bt.Values { + values[i] = slices.Clone(v) + } + return &StemNode{ + Stem: slices.Clone(bt.Stem), + Values: values[:], + depth: bt.depth, + } +} + +// GetHeight returns the height of the node. +func (bt *StemNode) GetHeight() int { + return 1 +} + +// Hash returns the hash of the node. +func (bt *StemNode) Hash() common.Hash { + var data [NodeWidth]common.Hash + for i, v := range bt.Values { + if v != nil { + h := sha256.Sum256(v) + data[i] = common.BytesToHash(h[:]) + } + } + + h := sha256.New() + for level := 1; level <= 8; level++ { + for i := range NodeWidth / (1 << level) { + h.Reset() + + if data[i*2] == (common.Hash{}) && data[i*2+1] == (common.Hash{}) { + data[i] = common.Hash{} + continue + } + + h.Write(data[i*2][:]) + h.Write(data[i*2+1][:]) + data[i] = common.Hash(h.Sum(nil)) + } + } + + h.Reset() + h.Write(bt.Stem) + h.Write([]byte{0}) + h.Write(data[0][:]) + return common.BytesToHash(h.Sum(nil)) +} + +// CollectNodes collects all child nodes at a given path, and flushes it +// into the provided node collector. +func (bt *StemNode) CollectNodes(path []byte, flush NodeFlushFn) error { + flush(path, bt) + return nil +} + +// GetValuesAtStem retrieves the group of values located at the given stem key. +func (bt *StemNode) GetValuesAtStem(_ []byte, _ NodeResolverFn) ([][]byte, error) { + return bt.Values[:], nil +} + +// InsertValuesAtStem inserts a full value group at the given stem in the internal node. +// Already-existing values will be overwritten. +func (bt *StemNode) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolverFn, depth int) (BinaryNode, error) { + if !bytes.Equal(bt.Stem, key[:31]) { + bitStem := bt.Stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1 + + n := &InternalNode{depth: bt.depth} + bt.depth++ + var child, other *BinaryNode + if bitStem == 0 { + n.left = bt + child = &n.left + other = &n.right + } else { + n.right = bt + child = &n.right + other = &n.left + } + + bitKey := key[n.depth/8] >> (7 - (n.depth % 8)) & 1 + if bitKey == bitStem { + var err error + *child, err = (*child).InsertValuesAtStem(key, values, nil, depth+1) + if err != nil { + return n, fmt.Errorf("insert error: %w", err) + } + *other = Empty{} + } else { + *other = &StemNode{ + Stem: slices.Clone(key[:31]), + Values: values, + depth: n.depth + 1, + } + } + return n, nil + } + + // same stem, just merge the two value lists + for i, v := range values { + if v != nil { + bt.Values[i] = v + } + } + return bt, nil +} + +func (bt *StemNode) toDot(parent, path string) string { + me := fmt.Sprintf("stem%s", path) + ret := fmt.Sprintf("%s [label=\"stem=%x c=%x\"]\n", me, bt.Stem, bt.Hash()) + ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me) + for i, v := range bt.Values { + if v != nil { + ret = fmt.Sprintf("%s%s%x [label=\"%x\"]\n", ret, me, i, v) + ret = fmt.Sprintf("%s%s -> %s%x\n", ret, me, me, i) + } + } + return ret +} + +// Key returns the full key for the given index. +func (bt *StemNode) Key(i int) []byte { + var ret [32]byte + copy(ret[:], bt.Stem) + ret[StemSize] = byte(i) + return ret[:] +} diff --git a/trie/bintrie/stem_node_test.go b/trie/bintrie/stem_node_test.go new file mode 100644 index 0000000000..e0ffd5c3c8 --- /dev/null +++ b/trie/bintrie/stem_node_test.go @@ -0,0 +1,373 @@ +// Copyright 2025 go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package bintrie + +import ( + "bytes" + "testing" + + "github.com/ethereum/go-ethereum/common" +) + +// TestStemNodeInsertSameStem tests inserting values with the same stem +func TestStemNodeInsertSameStem(t *testing.T) { + stem := make([]byte, 31) + for i := range stem { + stem[i] = byte(i) + } + + var values [256][]byte + values[0] = common.HexToHash("0x0101").Bytes() + + node := &StemNode{ + Stem: stem, + Values: values[:], + depth: 0, + } + + // Insert another value with the same stem but different last byte + key := make([]byte, 32) + copy(key[:31], stem) + key[31] = 10 + value := common.HexToHash("0x0202").Bytes() + + newNode, err := node.Insert(key, value, nil, 0) + if err != nil { + t.Fatalf("Failed to insert: %v", err) + } + + // Should still be a StemNode + stemNode, ok := newNode.(*StemNode) + if !ok { + t.Fatalf("Expected StemNode, got %T", newNode) + } + + // Check that both values are present + if !bytes.Equal(stemNode.Values[0], values[0]) { + t.Errorf("Value at index 0 mismatch") + } + if !bytes.Equal(stemNode.Values[10], value) { + t.Errorf("Value at index 10 mismatch") + } +} + +// TestStemNodeInsertDifferentStem tests inserting values with different stems +func TestStemNodeInsertDifferentStem(t *testing.T) { + stem1 := make([]byte, 31) + for i := range stem1 { + stem1[i] = 0x00 + } + + var values [256][]byte + values[0] = common.HexToHash("0x0101").Bytes() + + node := &StemNode{ + Stem: stem1, + Values: values[:], + depth: 0, + } + + // Insert with a different stem (first bit different) + key := make([]byte, 32) + key[0] = 0x80 // First bit is 1 instead of 0 + value := common.HexToHash("0x0202").Bytes() + + newNode, err := node.Insert(key, value, nil, 0) + if err != nil { + t.Fatalf("Failed to insert: %v", err) + } + + // Should now be an InternalNode + internalNode, ok := newNode.(*InternalNode) + if !ok { + t.Fatalf("Expected InternalNode, got %T", newNode) + } + + // Check depth + if internalNode.depth != 0 { + t.Errorf("Expected depth 0, got %d", internalNode.depth) + } + + // Original stem should be on the left (bit 0) + leftStem, ok := internalNode.left.(*StemNode) + if !ok { + t.Fatalf("Expected left child to be StemNode, got %T", internalNode.left) + } + if !bytes.Equal(leftStem.Stem, stem1) { + t.Errorf("Left stem mismatch") + } + + // New stem should be on the right (bit 1) + rightStem, ok := internalNode.right.(*StemNode) + if !ok { + t.Fatalf("Expected right child to be StemNode, got %T", internalNode.right) + } + if !bytes.Equal(rightStem.Stem, key[:31]) { + t.Errorf("Right stem mismatch") + } +} + +// TestStemNodeInsertInvalidValueLength tests inserting value with invalid length +func TestStemNodeInsertInvalidValueLength(t *testing.T) { + stem := make([]byte, 31) + var values [256][]byte + + node := &StemNode{ + Stem: stem, + Values: values[:], + depth: 0, + } + + // Try to insert value with wrong length + key := make([]byte, 32) + copy(key[:31], stem) + invalidValue := []byte{1, 2, 3} // Not 32 bytes + + _, err := node.Insert(key, invalidValue, nil, 0) + if err == nil { + t.Fatal("Expected error for invalid value length") + } + + if err.Error() != "invalid insertion: value length" { + t.Errorf("Expected 'invalid insertion: value length' error, got: %v", err) + } +} + +// TestStemNodeCopy tests the Copy method +func TestStemNodeCopy(t *testing.T) { + stem := make([]byte, 31) + for i := range stem { + stem[i] = byte(i) + } + + var values [256][]byte + values[0] = common.HexToHash("0x0101").Bytes() + values[255] = common.HexToHash("0x0202").Bytes() + + node := &StemNode{ + Stem: stem, + Values: values[:], + depth: 10, + } + + // Create a copy + copied := node.Copy() + copiedStem, ok := copied.(*StemNode) + if !ok { + t.Fatalf("Expected StemNode, got %T", copied) + } + + // Check that values are equal but not the same slice + if !bytes.Equal(copiedStem.Stem, node.Stem) { + t.Errorf("Stem mismatch after copy") + } + if &copiedStem.Stem[0] == &node.Stem[0] { + t.Error("Stem slice not properly cloned") + } + + // Check values + if !bytes.Equal(copiedStem.Values[0], node.Values[0]) { + t.Errorf("Value at index 0 mismatch after copy") + } + if !bytes.Equal(copiedStem.Values[255], node.Values[255]) { + t.Errorf("Value at index 255 mismatch after copy") + } + + // Check that value slices are cloned + if copiedStem.Values[0] != nil && &copiedStem.Values[0][0] == &node.Values[0][0] { + t.Error("Value slice not properly cloned") + } + + // Check depth + if copiedStem.depth != node.depth { + t.Errorf("Depth mismatch: expected %d, got %d", node.depth, copiedStem.depth) + } +} + +// TestStemNodeHash tests the Hash method +func TestStemNodeHash(t *testing.T) { + stem := make([]byte, 31) + var values [256][]byte + values[0] = common.HexToHash("0x0101").Bytes() + + node := &StemNode{ + Stem: stem, + Values: values[:], + depth: 0, + } + + hash1 := node.Hash() + + // Hash should be deterministic + hash2 := node.Hash() + if hash1 != hash2 { + t.Errorf("Hash not deterministic: %x != %x", hash1, hash2) + } + + // Changing a value should change the hash + node.Values[1] = common.HexToHash("0x0202").Bytes() + hash3 := node.Hash() + if hash1 == hash3 { + t.Error("Hash didn't change after modifying values") + } +} + +// TestStemNodeGetValuesAtStem tests GetValuesAtStem method +func TestStemNodeGetValuesAtStem(t *testing.T) { + stem := make([]byte, 31) + for i := range stem { + stem[i] = byte(i) + } + + var values [256][]byte + values[0] = common.HexToHash("0x0101").Bytes() + values[10] = common.HexToHash("0x0202").Bytes() + values[255] = common.HexToHash("0x0303").Bytes() + + node := &StemNode{ + Stem: stem, + Values: values[:], + depth: 0, + } + + // GetValuesAtStem with matching stem + retrievedValues, err := node.GetValuesAtStem(stem, nil) + if err != nil { + t.Fatalf("Failed to get values: %v", err) + } + + // Check that all values match + for i := 0; i < 256; i++ { + if !bytes.Equal(retrievedValues[i], values[i]) { + t.Errorf("Value mismatch at index %d", i) + } + } + + // GetValuesAtStem with different stem also returns the same values + // (implementation ignores the stem parameter) + differentStem := make([]byte, 31) + differentStem[0] = 0xFF + + retrievedValues2, err := node.GetValuesAtStem(differentStem, nil) + if err != nil { + t.Fatalf("Failed to get values with different stem: %v", err) + } + + // Should still return the same values (stem is ignored) + for i := 0; i < 256; i++ { + if !bytes.Equal(retrievedValues2[i], values[i]) { + t.Errorf("Value mismatch at index %d with different stem", i) + } + } +} + +// TestStemNodeInsertValuesAtStem tests InsertValuesAtStem method +func TestStemNodeInsertValuesAtStem(t *testing.T) { + stem := make([]byte, 31) + var values [256][]byte + values[0] = common.HexToHash("0x0101").Bytes() + + node := &StemNode{ + Stem: stem, + Values: values[:], + depth: 0, + } + + // Insert new values at the same stem + var newValues [256][]byte + newValues[1] = common.HexToHash("0x0202").Bytes() + newValues[2] = common.HexToHash("0x0303").Bytes() + + newNode, err := node.InsertValuesAtStem(stem, newValues[:], nil, 0) + if err != nil { + t.Fatalf("Failed to insert values: %v", err) + } + + stemNode, ok := newNode.(*StemNode) + if !ok { + t.Fatalf("Expected StemNode, got %T", newNode) + } + + // Check that all values are present + if !bytes.Equal(stemNode.Values[0], values[0]) { + t.Error("Original value at index 0 missing") + } + if !bytes.Equal(stemNode.Values[1], newValues[1]) { + t.Error("New value at index 1 missing") + } + if !bytes.Equal(stemNode.Values[2], newValues[2]) { + t.Error("New value at index 2 missing") + } +} + +// TestStemNodeGetHeight tests GetHeight method +func TestStemNodeGetHeight(t *testing.T) { + node := &StemNode{ + Stem: make([]byte, 31), + Values: make([][]byte, 256), + depth: 0, + } + + height := node.GetHeight() + if height != 1 { + t.Errorf("Expected height 1, got %d", height) + } +} + +// TestStemNodeCollectNodes tests CollectNodes method +func TestStemNodeCollectNodes(t *testing.T) { + stem := make([]byte, 31) + var values [256][]byte + values[0] = common.HexToHash("0x0101").Bytes() + + node := &StemNode{ + Stem: stem, + Values: values[:], + depth: 0, + } + + var collectedPaths [][]byte + var collectedNodes []BinaryNode + + flushFn := func(path []byte, n BinaryNode) { + // Make a copy of the path + pathCopy := make([]byte, len(path)) + copy(pathCopy, path) + collectedPaths = append(collectedPaths, pathCopy) + collectedNodes = append(collectedNodes, n) + } + + err := node.CollectNodes([]byte{0, 1, 0}, flushFn) + if err != nil { + t.Fatalf("Failed to collect nodes: %v", err) + } + + // Should have collected one node (itself) + if len(collectedNodes) != 1 { + t.Errorf("Expected 1 collected node, got %d", len(collectedNodes)) + } + + // Check that the collected node is the same + if collectedNodes[0] != node { + t.Error("Collected node doesn't match original") + } + + // Check the path + if !bytes.Equal(collectedPaths[0], []byte{0, 1, 0}) { + t.Errorf("Path mismatch: expected [0, 1, 0], got %v", collectedPaths[0]) + } +} diff --git a/trie/bintrie/trie.go b/trie/bintrie/trie.go new file mode 100644 index 0000000000..0a8bd325f5 --- /dev/null +++ b/trie/bintrie/trie.go @@ -0,0 +1,353 @@ +// Copyright 2025 go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package bintrie + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/trie" + "github.com/ethereum/go-ethereum/trie/trienode" + "github.com/ethereum/go-ethereum/triedb/database" + "github.com/holiman/uint256" +) + +var errInvalidRootType = errors.New("invalid root type") + +// NewBinaryNode creates a new empty binary trie +func NewBinaryNode() BinaryNode { + return Empty{} +} + +// BinaryTrie is the implementation of https://eips.ethereum.org/EIPS/eip-7864. +type BinaryTrie struct { + root BinaryNode + reader *trie.Reader + tracer *trie.PrevalueTracer +} + +// ToDot converts the binary trie to a DOT language representation. Useful for debugging. +func (t *BinaryTrie) ToDot() string { + t.root.Hash() + return ToDot(t.root) +} + +// NewBinaryTrie creates a new binary trie. +func NewBinaryTrie(root common.Hash, db database.NodeDatabase) (*BinaryTrie, error) { + reader, err := trie.NewReader(root, common.Hash{}, db) + if err != nil { + return nil, err + } + t := &BinaryTrie{ + root: NewBinaryNode(), + reader: reader, + tracer: trie.NewPrevalueTracer(), + } + // Parse the root node if it's not empty + if root != types.EmptyBinaryHash && root != types.EmptyRootHash { + blob, err := t.nodeResolver(nil, root) + if err != nil { + return nil, err + } + node, err := DeserializeNode(blob, 0) + if err != nil { + return nil, err + } + t.root = node + } + return t, nil +} + +// nodeResolver is a node resolver that reads nodes from the flatdb. +func (t *BinaryTrie) nodeResolver(path []byte, hash common.Hash) ([]byte, error) { + // empty nodes will be serialized as common.Hash{}, so capture + // this special use case. + if hash == (common.Hash{}) { + return nil, nil // empty node + } + blob, err := t.reader.Node(path, hash) + if err != nil { + return nil, err + } + t.tracer.Put(path, blob) + return blob, nil +} + +// GetKey returns the sha3 preimage of a hashed key that was previously used +// to store a value. +func (t *BinaryTrie) GetKey(key []byte) []byte { + return key +} + +// GetWithHashedKey returns the value, assuming that the key has already +// been hashed. +func (t *BinaryTrie) GetWithHashedKey(key []byte) ([]byte, error) { + return t.root.Get(key, t.nodeResolver) +} + +// GetAccount returns the account information for the given address. +func (t *BinaryTrie) GetAccount(addr common.Address) (*types.StateAccount, error) { + var ( + values [][]byte + err error + acc = &types.StateAccount{} + key = GetBinaryTreeKey(addr, zero[:]) + ) + switch r := t.root.(type) { + case *InternalNode: + values, err = r.GetValuesAtStem(key[:31], t.nodeResolver) + case *StemNode: + values = r.Values + case Empty: + return nil, nil + default: + // This will cover HashedNode but that should be fine since the + // root node should always be resolved. + return nil, errInvalidRootType + } + if err != nil { + return nil, fmt.Errorf("GetAccount (%x) error: %v", addr, err) + } + + // The following code is required for the MPT->Binary conversion. + // An account can be partially migrated, where storage slots were moved to the binary + // but not yet the account. This means some account information as (header) storage slots + // are in the binary trie but basic account information must be read in the base tree (MPT). + // TODO: we can simplify this logic depending if the conversion is in progress or finished. + emptyAccount := true + for i := 0; values != nil && i <= CodeHashLeafKey && emptyAccount; i++ { + emptyAccount = emptyAccount && values[i] == nil + } + if emptyAccount { + return nil, nil + } + + // If the account has been deleted, then values[10] will be 0 and not nil. If it has + // been recreated after that, then its code keccak will NOT be 0. So return `nil` if + // the nonce, and values[10], and code keccak is 0. + if bytes.Equal(values[BasicDataLeafKey], zero[:]) && len(values) > 10 && len(values[10]) > 0 && bytes.Equal(values[CodeHashLeafKey], zero[:]) { + return nil, nil + } + + acc.Nonce = binary.BigEndian.Uint64(values[BasicDataLeafKey][BasicDataNonceOffset:]) + var balance [16]byte + copy(balance[:], values[BasicDataLeafKey][BasicDataBalanceOffset:]) + acc.Balance = new(uint256.Int).SetBytes(balance[:]) + acc.CodeHash = values[CodeHashLeafKey] + + return acc, nil +} + +// GetStorage returns the value for key stored in the trie. The value bytes must +// not be modified by the caller. If a node was not found in the database, a +// trie.MissingNodeError is returned. +func (t *BinaryTrie) GetStorage(addr common.Address, key []byte) ([]byte, error) { + return t.root.Get(GetBinaryTreeKey(addr, key), t.nodeResolver) +} + +// UpdateAccount updates the account information for the given address. +func (t *BinaryTrie) UpdateAccount(addr common.Address, acc *types.StateAccount, codeLen int) error { + var ( + err error + basicData [32]byte + values = make([][]byte, NodeWidth) + stem = GetBinaryTreeKey(addr, zero[:]) + ) + binary.BigEndian.PutUint32(basicData[BasicDataCodeSizeOffset-1:], uint32(codeLen)) + binary.BigEndian.PutUint64(basicData[BasicDataNonceOffset:], acc.Nonce) + + // Because the balance is a max of 16 bytes, truncate + // the extra values. This happens in devmode, where + // 0xff**32 is allocated to the developer account. + balanceBytes := acc.Balance.Bytes() + // TODO: reduce the size of the allocation in devmode, then panic instead + // of truncating. + if len(balanceBytes) > 16 { + balanceBytes = balanceBytes[16:] + } + copy(basicData[32-len(balanceBytes):], balanceBytes[:]) + values[BasicDataLeafKey] = basicData[:] + values[CodeHashLeafKey] = acc.CodeHash[:] + + t.root, err = t.root.InsertValuesAtStem(stem, values, t.nodeResolver, 0) + return err +} + +// UpdateStem updates the values for the given stem key. +func (t *BinaryTrie) UpdateStem(key []byte, values [][]byte) error { + var err error + t.root, err = t.root.InsertValuesAtStem(key, values, t.nodeResolver, 0) + return err +} + +// UpdateStorage associates key with value in the trie. If value has length zero, any +// existing value is deleted from the trie. The value bytes must not be modified +// by the caller while they are stored in the trie. If a node was not found in the +// database, a trie.MissingNodeError is returned. +func (t *BinaryTrie) UpdateStorage(address common.Address, key, value []byte) error { + k := GetBinaryTreeKeyStorageSlot(address, key) + var v [32]byte + if len(value) >= 32 { + copy(v[:], value[:32]) + } else { + copy(v[32-len(value):], value[:]) + } + root, err := t.root.Insert(k, v[:], t.nodeResolver, 0) + if err != nil { + return fmt.Errorf("UpdateStorage (%x) error: %v", address, err) + } + t.root = root + return nil +} + +// DeleteAccount is a no-op as it is disabled in stateless. +func (t *BinaryTrie) DeleteAccount(addr common.Address) error { + return nil +} + +// DeleteStorage removes any existing value for key from the trie. If a node was not +// found in the database, a trie.MissingNodeError is returned. +func (t *BinaryTrie) DeleteStorage(addr common.Address, key []byte) error { + k := GetBinaryTreeKey(addr, key) + var zero [32]byte + root, err := t.root.Insert(k, zero[:], t.nodeResolver, 0) + if err != nil { + return fmt.Errorf("DeleteStorage (%x) error: %v", addr, err) + } + t.root = root + return nil +} + +// Hash returns the root hash of the trie. It does not write to the database and +// can be used even if the trie doesn't have one. +func (t *BinaryTrie) Hash() common.Hash { + return t.root.Hash() +} + +// Commit writes all nodes to the trie's memory database, tracking the internal +// and external (for account tries) references. +func (t *BinaryTrie) Commit(_ bool) (common.Hash, *trienode.NodeSet) { + root := t.root.(*InternalNode) + nodeset := trienode.NewNodeSet(common.Hash{}) + + err := root.CollectNodes(nil, func(path []byte, node BinaryNode) { + serialized := SerializeNode(node) + nodeset.AddNode(path, trienode.NewNodeWithPrev(common.Hash{}, serialized, t.tracer.Get(path))) + }) + if err != nil { + panic(fmt.Errorf("CollectNodes failed: %v", err)) + } + // Serialize root commitment form + return t.Hash(), nodeset +} + +// NodeIterator returns an iterator that returns nodes of the trie. Iteration +// starts at the key after the given start key. +func (t *BinaryTrie) NodeIterator(startKey []byte) (trie.NodeIterator, error) { + return newBinaryNodeIterator(t, nil) +} + +// Prove constructs a Merkle proof for key. The result contains all encoded nodes +// on the path to the value at key. The value itself is also included in the last +// node and can be retrieved by verifying the proof. +// +// If the trie does not contain a value for key, the returned proof contains all +// nodes of the longest existing prefix of the key (at least the root), ending +// with the node that proves the absence of the key. +func (t *BinaryTrie) Prove(key []byte, proofDb ethdb.KeyValueWriter) error { + panic("not implemented") +} + +// Copy creates a deep copy of the trie. +func (t *BinaryTrie) Copy() *BinaryTrie { + return &BinaryTrie{ + root: t.root.Copy(), + reader: t.reader, + tracer: t.tracer.Copy(), + } +} + +// IsVerkle returns true if the trie is a Verkle tree. +func (t *BinaryTrie) IsVerkle() bool { + // TODO @gballet This is technically NOT a verkle tree, but it has the same + // behavior and basic structure, so for all intents and purposes, it can be + // treated as such. Rename this when verkle gets removed. + return true +} + +// UpdateContractCode updates the contract code into the trie. +// +// Note: the basic data leaf needs to have been previously created for this to work +func (t *BinaryTrie) UpdateContractCode(addr common.Address, codeHash common.Hash, code []byte) error { + var ( + chunks = trie.ChunkifyCode(code) + values [][]byte + key []byte + err error + ) + for i, chunknr := 0, uint64(0); i < len(chunks); i, chunknr = i+32, chunknr+1 { + groupOffset := (chunknr + 128) % 256 + if groupOffset == 0 /* start of new group */ || chunknr == 0 /* first chunk in header group */ { + values = make([][]byte, NodeWidth) + var offset [32]byte + binary.LittleEndian.PutUint64(offset[24:], chunknr+128) + key = GetBinaryTreeKey(addr, offset[:]) + } + values[groupOffset] = chunks[i : i+32] + + if groupOffset == 255 || len(chunks)-i <= 32 { + err = t.UpdateStem(key[:31], values) + + if err != nil { + return fmt.Errorf("UpdateContractCode (addr=%x) error: %w", addr[:], err) + } + } + } + return nil +} + +// PrefetchAccount attempts to resolve specific accounts from the database +// to accelerate subsequent trie operations. +func (t *BinaryTrie) PrefetchAccount(addresses []common.Address) error { + for _, addr := range addresses { + if _, err := t.GetAccount(addr); err != nil { + return err + } + } + return nil +} + +// PrefetchStorage attempts to resolve specific storage slots from the database +// to accelerate subsequent trie operations. +func (t *BinaryTrie) PrefetchStorage(addr common.Address, keys [][]byte) error { + for _, key := range keys { + if _, err := t.GetStorage(addr, key); err != nil { + return err + } + } + return nil +} + +// Witness returns a set containing all trie nodes that have been accessed. +func (t *BinaryTrie) Witness() map[string][]byte { + panic("not implemented") +} diff --git a/trie/bintrie/trie_test.go b/trie/bintrie/trie_test.go new file mode 100644 index 0000000000..84f7689549 --- /dev/null +++ b/trie/bintrie/trie_test.go @@ -0,0 +1,197 @@ +// Copyright 2025 go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package bintrie + +import ( + "bytes" + "encoding/binary" + "testing" + + "github.com/ethereum/go-ethereum/common" +) + +var ( + zeroKey = [32]byte{} + oneKey = common.HexToHash("0101010101010101010101010101010101010101010101010101010101010101") + twoKey = common.HexToHash("0202020202020202020202020202020202020202020202020202020202020202") + threeKey = common.HexToHash("0303030303030303030303030303030303030303030303030303030303030303") + fourKey = common.HexToHash("0404040404040404040404040404040404040404040404040404040404040404") + ffKey = common.HexToHash("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") +) + +func TestSingleEntry(t *testing.T) { + tree := NewBinaryNode() + tree, err := tree.Insert(zeroKey[:], oneKey[:], nil, 0) + if err != nil { + t.Fatal(err) + } + if tree.GetHeight() != 1 { + t.Fatal("invalid depth") + } + expected := common.HexToHash("aab1060e04cb4f5dc6f697ae93156a95714debbf77d54238766adc5709282b6f") + got := tree.Hash() + if got != expected { + t.Fatalf("invalid tree root, got %x, want %x", got, expected) + } +} + +func TestTwoEntriesDiffFirstBit(t *testing.T) { + var err error + tree := NewBinaryNode() + tree, err = tree.Insert(zeroKey[:], oneKey[:], nil, 0) + if err != nil { + t.Fatal(err) + } + tree, err = tree.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000000").Bytes(), twoKey[:], nil, 0) + if err != nil { + t.Fatal(err) + } + if tree.GetHeight() != 2 { + t.Fatal("invalid height") + } + if tree.Hash() != common.HexToHash("dfc69c94013a8b3c65395625a719a87534a7cfd38719251ad8c8ea7fe79f065e") { + t.Fatal("invalid tree root") + } +} + +func TestOneStemColocatedValues(t *testing.T) { + var err error + tree := NewBinaryNode() + tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil, 0) + if err != nil { + t.Fatal(err) + } + tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil, 0) + if err != nil { + t.Fatal(err) + } + tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000009").Bytes(), threeKey[:], nil, 0) + if err != nil { + t.Fatal(err) + } + tree, err = tree.Insert(common.HexToHash("00000000000000000000000000000000000000000000000000000000000000FF").Bytes(), fourKey[:], nil, 0) + if err != nil { + t.Fatal(err) + } + if tree.GetHeight() != 1 { + t.Fatal("invalid height") + } +} + +func TestTwoStemColocatedValues(t *testing.T) { + var err error + tree := NewBinaryNode() + // stem: 0...0 + tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil, 0) + if err != nil { + t.Fatal(err) + } + tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil, 0) + if err != nil { + t.Fatal(err) + } + // stem: 10...0 + tree, err = tree.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil, 0) + if err != nil { + t.Fatal(err) + } + tree, err = tree.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil, 0) + if err != nil { + t.Fatal(err) + } + if tree.GetHeight() != 2 { + t.Fatal("invalid height") + } +} + +func TestTwoKeysMatchFirst42Bits(t *testing.T) { + var err error + tree := NewBinaryNode() + // key1 and key 2 have the same prefix of 42 bits (b0*42+b1+b1) and differ after. + key1 := common.HexToHash("0000000000C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0").Bytes() + key2 := common.HexToHash("0000000000E00000000000000000000000000000000000000000000000000000").Bytes() + tree, err = tree.Insert(key1, oneKey[:], nil, 0) + if err != nil { + t.Fatal(err) + } + tree, err = tree.Insert(key2, twoKey[:], nil, 0) + if err != nil { + t.Fatal(err) + } + if tree.GetHeight() != 1+42+1 { + t.Fatal("invalid height") + } +} +func TestInsertDuplicateKey(t *testing.T) { + var err error + tree := NewBinaryNode() + tree, err = tree.Insert(oneKey[:], oneKey[:], nil, 0) + if err != nil { + t.Fatal(err) + } + tree, err = tree.Insert(oneKey[:], twoKey[:], nil, 0) + if err != nil { + t.Fatal(err) + } + if tree.GetHeight() != 1 { + t.Fatal("invalid height") + } + // Verify that the value is updated + if !bytes.Equal(tree.(*StemNode).Values[1], twoKey[:]) { + t.Fatal("invalid height") + } +} +func TestLargeNumberOfEntries(t *testing.T) { + var err error + tree := NewBinaryNode() + for i := range 256 { + var key [32]byte + key[0] = byte(i) + tree, err = tree.Insert(key[:], ffKey[:], nil, 0) + if err != nil { + t.Fatal(err) + } + } + height := tree.GetHeight() + if height != 1+8 { + t.Fatalf("invalid height, wanted %d, got %d", 1+8, height) + } +} + +func TestMerkleizeMultipleEntries(t *testing.T) { + var err error + tree := NewBinaryNode() + keys := [][]byte{ + zeroKey[:], + common.HexToHash("8000000000000000000000000000000000000000000000000000000000000000").Bytes(), + common.HexToHash("0100000000000000000000000000000000000000000000000000000000000000").Bytes(), + common.HexToHash("8100000000000000000000000000000000000000000000000000000000000000").Bytes(), + } + for i, key := range keys { + var v [32]byte + binary.LittleEndian.PutUint64(v[:8], uint64(i)) + tree, err = tree.Insert(key, v[:], nil, 0) + if err != nil { + t.Fatal(err) + } + } + got := tree.Hash() + expected := common.HexToHash("9317155862f7a3867660ddd0966ff799a3d16aa4df1e70a7516eaa4a675191b5") + if got != expected { + t.Fatalf("invalid root, expected=%x, got = %x", expected, got) + } +} diff --git a/trie/committer.go b/trie/committer.go index a040868c6c..2a2142e0ff 100644 --- a/trie/committer.go +++ b/trie/committer.go @@ -29,12 +29,12 @@ import ( // insertion order. type committer struct { nodes *trienode.NodeSet - tracer *prevalueTracer + tracer *PrevalueTracer collectLeaf bool } // newCommitter creates a new committer or picks one from the pool. -func newCommitter(nodeset *trienode.NodeSet, tracer *prevalueTracer, collectLeaf bool) *committer { +func newCommitter(nodeset *trienode.NodeSet, tracer *PrevalueTracer, collectLeaf bool) *committer { return &committer{ nodes: nodeset, tracer: tracer, @@ -142,7 +142,7 @@ func (c *committer) store(path []byte, n node) node { // The node is embedded in its parent, in other words, this node // will not be stored in the database independently, mark it as // deleted only if the node was existent in database before. - origin := c.tracer.get(path) + origin := c.tracer.Get(path) if len(origin) != 0 { c.nodes.AddNode(path, trienode.NewDeletedWithPrev(origin)) } @@ -150,7 +150,7 @@ func (c *committer) store(path []byte, n node) node { } // Collect the dirty node to nodeset for return. nhash := common.BytesToHash(hash) - c.nodes.AddNode(path, trienode.NewNodeWithPrev(nhash, nodeToBytes(n), c.tracer.get(path))) + c.nodes.AddNode(path, trienode.NewNodeWithPrev(nhash, nodeToBytes(n), c.tracer.Get(path))) // Collect the corresponding leaf node if it's required. We don't check // full node since it's impossible to store value in fullNode. The key diff --git a/trie/iterator.go b/trie/iterator.go index e6fedf2430..80298ce48f 100644 --- a/trie/iterator.go +++ b/trie/iterator.go @@ -405,7 +405,7 @@ func (it *nodeIterator) resolveHash(hash hashNode, path []byte) (node, error) { // loaded blob will be tracked, while it's not required here since // all loaded nodes won't be linked to trie at all and track nodes // may lead to out-of-memory issue. - blob, err := it.trie.reader.node(path, common.BytesToHash(hash)) + blob, err := it.trie.reader.Node(path, common.BytesToHash(hash)) if err != nil { return nil, err } @@ -426,7 +426,7 @@ func (it *nodeIterator) resolveBlob(hash hashNode, path []byte) ([]byte, error) // loaded blob will be tracked, while it's not required here since // all loaded nodes won't be linked to trie at all and track nodes // may lead to out-of-memory issue. - return it.trie.reader.node(path, common.BytesToHash(hash)) + return it.trie.reader.Node(path, common.BytesToHash(hash)) } func (st *nodeIteratorState) resolve(it *nodeIterator, path []byte) error { diff --git a/trie/proof.go b/trie/proof.go index f3ed417094..1a06ed5d5e 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -69,7 +69,7 @@ func (t *Trie) Prove(key []byte, proofDb ethdb.KeyValueWriter) error { // loaded blob will be tracked, while it's not required here since // all loaded nodes won't be linked to trie at all and track nodes // may lead to out-of-memory issue. - blob, err := t.reader.node(prefix, common.BytesToHash(n)) + blob, err := t.reader.Node(prefix, common.BytesToHash(n)) if err != nil { log.Error("Unhandled trie error in Trie.Prove", "err", err) return err @@ -571,7 +571,7 @@ func VerifyRangeProof(rootHash common.Hash, firstKey []byte, keys [][]byte, valu root: root, reader: newEmptyReader(), opTracer: newOpTracer(), - prevalueTracer: newPrevalueTracer(), + prevalueTracer: NewPrevalueTracer(), } if empty { tr.root = nil diff --git a/trie/tracer.go b/trie/tracer.go index b0542404a7..04122d1384 100644 --- a/trie/tracer.go +++ b/trie/tracer.go @@ -94,46 +94,45 @@ func (t *opTracer) deletedList() [][]byte { return paths } -// prevalueTracer tracks the original values of resolved trie nodes. Cached trie +// PrevalueTracer tracks the original values of resolved trie nodes. Cached trie // node values are expected to be immutable. A zero-size node value is treated as // non-existent and should not occur in practice. // -// Note prevalueTracer is not thread-safe, callers should be responsible for -// handling the concurrency issues by themselves. -type prevalueTracer struct { +// Note PrevalueTracer is thread-safe. +type PrevalueTracer struct { data map[string][]byte lock sync.RWMutex } -// newPrevalueTracer initializes the tracer for capturing resolved trie nodes. -func newPrevalueTracer() *prevalueTracer { - return &prevalueTracer{ +// NewPrevalueTracer initializes the tracer for capturing resolved trie nodes. +func NewPrevalueTracer() *PrevalueTracer { + return &PrevalueTracer{ data: make(map[string][]byte), } } -// put tracks the newly loaded trie node and caches its RLP-encoded +// Put tracks the newly loaded trie node and caches its RLP-encoded // blob internally. Do not modify the value outside this function, // as it is not deep-copied. -func (t *prevalueTracer) put(path []byte, val []byte) { +func (t *PrevalueTracer) Put(path []byte, val []byte) { t.lock.Lock() defer t.lock.Unlock() t.data[string(path)] = val } -// get returns the cached trie node value. If the node is not found, nil will +// Get returns the cached trie node value. If the node is not found, nil will // be returned. -func (t *prevalueTracer) get(path []byte) []byte { +func (t *PrevalueTracer) Get(path []byte) []byte { t.lock.RLock() defer t.lock.RUnlock() return t.data[string(path)] } -// hasList returns a list of flags indicating whether the corresponding trie nodes +// HasList returns a list of flags indicating whether the corresponding trie nodes // specified by the path exist in the trie. -func (t *prevalueTracer) hasList(list [][]byte) []bool { +func (t *PrevalueTracer) HasList(list [][]byte) []bool { t.lock.RLock() defer t.lock.RUnlock() @@ -145,29 +144,29 @@ func (t *prevalueTracer) hasList(list [][]byte) []bool { return exists } -// values returns a list of values of the cached trie nodes. -func (t *prevalueTracer) values() map[string][]byte { +// Values returns a list of values of the cached trie nodes. +func (t *PrevalueTracer) Values() map[string][]byte { t.lock.RLock() defer t.lock.RUnlock() return maps.Clone(t.data) } -// reset resets the cached content in the prevalueTracer. -func (t *prevalueTracer) reset() { +// Reset resets the cached content in the prevalueTracer. +func (t *PrevalueTracer) Reset() { t.lock.Lock() defer t.lock.Unlock() clear(t.data) } -// copy returns a copied prevalueTracer instance. -func (t *prevalueTracer) copy() *prevalueTracer { +// Copy returns a copied prevalueTracer instance. +func (t *PrevalueTracer) Copy() *PrevalueTracer { t.lock.RLock() defer t.lock.RUnlock() // Shadow clone is used, as the cached trie node values are immutable - return &prevalueTracer{ + return &PrevalueTracer{ data: maps.Clone(t.data), } } diff --git a/trie/trie.go b/trie/trie.go index 98cf751f47..630462f8ca 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -55,11 +55,11 @@ type Trie struct { uncommitted int // reader is the handler trie can retrieve nodes from. - reader *trieReader + reader *Reader // Various tracers for capturing the modifications to trie opTracer *opTracer - prevalueTracer *prevalueTracer + prevalueTracer *PrevalueTracer } // newFlag returns the cache flag value for a newly created node. @@ -77,7 +77,7 @@ func (t *Trie) Copy() *Trie { uncommitted: t.uncommitted, reader: t.reader, opTracer: t.opTracer.copy(), - prevalueTracer: t.prevalueTracer.copy(), + prevalueTracer: t.prevalueTracer.Copy(), } } @@ -88,7 +88,7 @@ func (t *Trie) Copy() *Trie { // empty, otherwise, the root node must be present in database or returns // a MissingNodeError if not. func New(id *ID, db database.NodeDatabase) (*Trie, error) { - reader, err := newTrieReader(id.StateRoot, id.Owner, db) + reader, err := NewReader(id.StateRoot, id.Owner, db) if err != nil { return nil, err } @@ -96,7 +96,7 @@ func New(id *ID, db database.NodeDatabase) (*Trie, error) { owner: id.Owner, reader: reader, opTracer: newOpTracer(), - prevalueTracer: newPrevalueTracer(), + prevalueTracer: NewPrevalueTracer(), } if id.Root != (common.Hash{}) && id.Root != types.EmptyRootHash { rootnode, err := trie.resolveAndTrack(id.Root[:], nil) @@ -289,7 +289,7 @@ func (t *Trie) getNode(origNode node, path []byte, pos int) (item []byte, newnod if hash == nil { return nil, origNode, 0, errors.New("non-consensus node") } - blob, err := t.reader.node(path, common.BytesToHash(hash)) + blob, err := t.reader.Node(path, common.BytesToHash(hash)) return blob, origNode, 1, err } // Path still needs to be traversed, descend into children @@ -655,11 +655,11 @@ func (t *Trie) resolve(n node, prefix []byte) (node, error) { // node's original value. The rlp-encoded blob is preferred to be loaded from // database because it's easy to decode node while complex to encode node to blob. func (t *Trie) resolveAndTrack(n hashNode, prefix []byte) (node, error) { - blob, err := t.reader.node(prefix, common.BytesToHash(n)) + blob, err := t.reader.Node(prefix, common.BytesToHash(n)) if err != nil { return nil, err } - t.prevalueTracer.put(prefix, blob) + t.prevalueTracer.Put(prefix, blob) // The returned node blob won't be changed afterward. No need to // deep-copy the slice. @@ -673,7 +673,7 @@ func (t *Trie) deletedNodes() [][]byte { var ( pos int list = t.opTracer.deletedList() - flags = t.prevalueTracer.hasList(list) + flags = t.prevalueTracer.HasList(list) ) for i := 0; i < len(list); i++ { if flags[i] { @@ -711,7 +711,7 @@ func (t *Trie) Commit(collectLeaf bool) (common.Hash, *trienode.NodeSet) { } nodes := trienode.NewNodeSet(t.owner) for _, path := range paths { - nodes.AddNode(path, trienode.NewDeletedWithPrev(t.prevalueTracer.get(path))) + nodes.AddNode(path, trienode.NewDeletedWithPrev(t.prevalueTracer.Get(path))) } return types.EmptyRootHash, nodes // case (b) } @@ -729,7 +729,7 @@ func (t *Trie) Commit(collectLeaf bool) (common.Hash, *trienode.NodeSet) { } nodes := trienode.NewNodeSet(t.owner) for _, path := range t.deletedNodes() { - nodes.AddNode(path, trienode.NewDeletedWithPrev(t.prevalueTracer.get(path))) + nodes.AddNode(path, trienode.NewDeletedWithPrev(t.prevalueTracer.Get(path))) } // If the number of changes is below 100, we let one thread handle it t.root = newCommitter(nodes, t.prevalueTracer, collectLeaf).Commit(t.root, t.uncommitted > 100) @@ -753,7 +753,7 @@ func (t *Trie) hashRoot() []byte { // Witness returns a set containing all trie nodes that have been accessed. func (t *Trie) Witness() map[string][]byte { - return t.prevalueTracer.values() + return t.prevalueTracer.Values() } // Reset drops the referenced root node and cleans all internal state. @@ -763,6 +763,6 @@ func (t *Trie) Reset() { t.unhashed = 0 t.uncommitted = 0 t.opTracer.reset() - t.prevalueTracer.reset() + t.prevalueTracer.Reset() t.committed = false } diff --git a/trie/trie_reader.go b/trie/trie_reader.go index a42cdb0cf9..42fe4d72c7 100644 --- a/trie/trie_reader.go +++ b/trie/trie_reader.go @@ -22,39 +22,39 @@ import ( "github.com/ethereum/go-ethereum/triedb/database" ) -// trieReader is a wrapper of the underlying node reader. It's not safe +// Reader is a wrapper of the underlying database reader. It's not safe // for concurrent usage. -type trieReader struct { +type Reader struct { owner common.Hash reader database.NodeReader banned map[string]struct{} // Marker to prevent node from being accessed, for tests } -// newTrieReader initializes the trie reader with the given node reader. -func newTrieReader(stateRoot, owner common.Hash, db database.NodeDatabase) (*trieReader, error) { +// NewReader initializes the trie reader with the given database reader. +func NewReader(stateRoot, owner common.Hash, db database.NodeDatabase) (*Reader, error) { if stateRoot == (common.Hash{}) || stateRoot == types.EmptyRootHash { - return &trieReader{owner: owner}, nil + return &Reader{owner: owner}, nil } reader, err := db.NodeReader(stateRoot) if err != nil { return nil, &MissingNodeError{Owner: owner, NodeHash: stateRoot, err: err} } - return &trieReader{owner: owner, reader: reader}, nil + return &Reader{owner: owner, reader: reader}, nil } // newEmptyReader initializes the pure in-memory reader. All read operations // should be forbidden and returns the MissingNodeError. -func newEmptyReader() *trieReader { - return &trieReader{} +func newEmptyReader() *Reader { + return &Reader{} } -// node retrieves the rlp-encoded trie node with the provided trie node +// Node retrieves the rlp-encoded trie node with the provided trie node // information. An MissingNodeError will be returned in case the node is // not found or any error is encountered. // // Don't modify the returned byte slice since it's not deep-copied and // still be referenced by database. -func (r *trieReader) node(path []byte, hash common.Hash) ([]byte, error) { +func (r *Reader) Node(path []byte, hash common.Hash) ([]byte, error) { // Perform the logics in tests for preventing trie node access. if r.banned != nil { if _, ok := r.banned[string(path)]; ok { diff --git a/trie/verkle.go b/trie/verkle.go index e00ea21602..186ac1f642 100644 --- a/trie/verkle.go +++ b/trie/verkle.go @@ -41,13 +41,13 @@ var ( type VerkleTrie struct { root verkle.VerkleNode cache *utils.PointCache - reader *trieReader - tracer *prevalueTracer + reader *Reader + tracer *PrevalueTracer } // NewVerkleTrie constructs a verkle tree based on the specified root hash. func NewVerkleTrie(root common.Hash, db database.NodeDatabase, cache *utils.PointCache) (*VerkleTrie, error) { - reader, err := newTrieReader(root, common.Hash{}, db) + reader, err := NewReader(root, common.Hash{}, db) if err != nil { return nil, err } @@ -55,7 +55,7 @@ func NewVerkleTrie(root common.Hash, db database.NodeDatabase, cache *utils.Poin root: verkle.New(), cache: cache, reader: reader, - tracer: newPrevalueTracer(), + tracer: NewPrevalueTracer(), } // Parse the root verkle node if it's not empty. if root != types.EmptyVerkleHash && root != types.EmptyRootHash { @@ -289,7 +289,7 @@ func (t *VerkleTrie) Commit(_ bool) (common.Hash, *trienode.NodeSet) { nodeset := trienode.NewNodeSet(common.Hash{}) for _, node := range nodes { // Hash parameter is not used in pathdb - nodeset.AddNode(node.Path, trienode.NewNodeWithPrev(common.Hash{}, node.SerializedBytes, t.tracer.get(node.Path))) + nodeset.AddNode(node.Path, trienode.NewNodeWithPrev(common.Hash{}, node.SerializedBytes, t.tracer.Get(node.Path))) } // Serialize root commitment form return t.Hash(), nodeset @@ -322,7 +322,7 @@ func (t *VerkleTrie) Copy() *VerkleTrie { root: t.root.Copy(), cache: t.cache, reader: t.reader, - tracer: t.tracer.copy(), + tracer: t.tracer.Copy(), } } @@ -443,11 +443,11 @@ func (t *VerkleTrie) ToDot() string { } func (t *VerkleTrie) nodeResolver(path []byte) ([]byte, error) { - blob, err := t.reader.node(path, common.Hash{}) + blob, err := t.reader.Node(path, common.Hash{}) if err != nil { return nil, err } - t.tracer.put(path, blob) + t.tracer.Put(path, blob) return blob, nil }