diff --git a/trie/bintrie/binary_node.go b/trie/bintrie/binary_node.go
index 8905a82285..fbfd4a6ca7 100644
--- a/trie/bintrie/binary_node.go
+++ b/trie/bintrie/binary_node.go
@@ -16,140 +16,43 @@
package bintrie
-import (
- "errors"
-
- "github.com/ethereum/go-ethereum/common"
-)
-
-type (
- NodeFlushFn func([]byte, BinaryNode)
- NodeResolverFn func([]byte, common.Hash) ([]byte, error)
-)
+import "github.com/ethereum/go-ethereum/common"
// zero is the zero value for a 32-byte array.
var zero [32]byte
const (
- StemNodeWidth = 256 // Number of child per leaf node
- StemSize = 31 // Number of bytes to travel before reaching a group of leaves
- NodeTypeBytes = 1 // Size of node type prefix in serialization
- HashSize = 32 // Size of a hash in bytes
- BitmapSize = 32 // Size of the bitmap in a stem node
+ StemNodeWidth = 256 // Number of children per leaf node
+ StemSize = 31 // Number of bytes to travel before reaching a group of leaves
+ NodeTypeBytes = 1 // Size of node type prefix in serialization
+ HashSize = 32 // Size of a hash in bytes
+ StemBitmapSize = 32 // Size of the bitmap in a stem node (256 values = 32 bytes)
+
+ // MaxGroupDepth is the maximum allowed group depth for InternalNode serialization.
+ MaxGroupDepth = 8
)
+// BitmapSizeForDepth returns the bitmap size in bytes for a given group depth.
+func BitmapSizeForDepth(groupDepth int) int {
+ if groupDepth <= 3 {
+ return 1
+ }
+ return 1 << (groupDepth - 3)
+}
+
const (
- nodeTypeStem = iota + 1 // Stem node, contains a stem and a bitmap of values
+ nodeTypeStem = iota + 1
nodeTypeInternal
)
-// BinaryNode is an interface for a binary trie node.
-type BinaryNode interface {
- Get([]byte, NodeResolverFn) ([]byte, error)
- Insert([]byte, []byte, NodeResolverFn, int) (BinaryNode, error)
- Copy() BinaryNode
- Hash() common.Hash
- GetValuesAtStem([]byte, NodeResolverFn) ([][]byte, error)
- InsertValuesAtStem([]byte, [][]byte, NodeResolverFn, int) (BinaryNode, error)
- CollectNodes([]byte, NodeFlushFn) error
-
- toDot(parent, path string) string
- GetHeight() int
-}
-
-// SerializeNode serializes a binary trie node into a byte slice.
-func SerializeNode(node BinaryNode) []byte {
- switch n := (node).(type) {
- case *InternalNode:
- // InternalNode: 1 byte type + 32 bytes left hash + 32 bytes right hash
- var serialized [NodeTypeBytes + HashSize + HashSize]byte
- serialized[0] = nodeTypeInternal
- copy(serialized[1:33], n.left.Hash().Bytes())
- copy(serialized[33:65], n.right.Hash().Bytes())
- return serialized[:]
- case *StemNode:
- // StemNode: 1 byte type + 31 bytes stem + 32 bytes bitmap + 256*32 bytes values
- var serialized [NodeTypeBytes + StemSize + BitmapSize + StemNodeWidth*HashSize]byte
- serialized[0] = nodeTypeStem
- copy(serialized[NodeTypeBytes:NodeTypeBytes+StemSize], n.Stem)
- bitmap := serialized[NodeTypeBytes+StemSize : NodeTypeBytes+StemSize+BitmapSize]
- offset := NodeTypeBytes + StemSize + BitmapSize
- for i, v := range n.Values {
- if v != nil {
- bitmap[i/8] |= 1 << (7 - (i % 8))
- copy(serialized[offset:offset+HashSize], v)
- offset += HashSize
- }
- }
- // Only return the actual data, not the entire array
- return serialized[:offset]
- default:
- panic("invalid node type")
+// DeserializeAndHash deserializes a node from bytes and returns its hash.
+// This is a convenience function for external callers that need to compute
+// the hash of a serialized node without maintaining a NodeStore.
+func DeserializeAndHash(blob []byte, depth int) (common.Hash, error) {
+ s := NewNodeStore()
+ ref, err := s.DeserializeNode(blob, depth)
+ if err != nil {
+ return common.Hash{}, err
}
-}
-
-var invalidSerializedLength = errors.New("invalid serialized node length")
-
-// DeserializeNode deserializes a binary trie node from a byte slice. The
-// hash will be recomputed from the deserialized data.
-func DeserializeNode(serialized []byte, depth int) (BinaryNode, error) {
- return deserializeNode(serialized, depth, common.Hash{}, true, true)
-}
-
-// DeserializeNodeWithHash deserializes a binary trie node from a byte slice, using the provided hash.
-func DeserializeNodeWithHash(serialized []byte, depth int, hn common.Hash) (BinaryNode, error) {
- return deserializeNode(serialized, depth, hn, false, false)
-}
-
-func deserializeNode(serialized []byte, depth int, hn common.Hash, mustRecompute, dirty bool) (BinaryNode, error) {
- if len(serialized) == 0 {
- return Empty{}, nil
- }
-
- switch serialized[0] {
- case nodeTypeInternal:
- if len(serialized) != 65 {
- return nil, invalidSerializedLength
- }
- return &InternalNode{
- depth: depth,
- left: HashedNode(common.BytesToHash(serialized[1:33])),
- right: HashedNode(common.BytesToHash(serialized[33:65])),
- hash: hn,
- mustRecompute: mustRecompute,
- dirty: dirty,
- }, nil
- case nodeTypeStem:
- if len(serialized) < 64 {
- return nil, invalidSerializedLength
- }
- var values [StemNodeWidth][]byte
- bitmap := serialized[NodeTypeBytes+StemSize : NodeTypeBytes+StemSize+BitmapSize]
- offset := NodeTypeBytes + StemSize + BitmapSize
-
- for i := range StemNodeWidth {
- if bitmap[i/8]>>(7-(i%8))&1 == 1 {
- if len(serialized) < offset+HashSize {
- return nil, invalidSerializedLength
- }
- values[i] = serialized[offset : offset+HashSize]
- offset += HashSize
- }
- }
- return &StemNode{
- Stem: serialized[NodeTypeBytes : NodeTypeBytes+StemSize],
- Values: values[:],
- depth: depth,
- hash: hn,
- mustRecompute: mustRecompute,
- dirty: dirty,
- }, nil
- default:
- return nil, errors.New("invalid node type")
- }
-}
-
-// ToDot converts the binary trie to a DOT language representation. Useful for debugging.
-func ToDot(root BinaryNode) string {
- return root.toDot("", "")
+ return s.ComputeHash(ref), nil
}
diff --git a/trie/bintrie/binary_node_test.go b/trie/bintrie/binary_node_test.go
index 242743ba53..f8a00bc9f1 100644
--- a/trie/bintrie/binary_node_test.go
+++ b/trie/bintrie/binary_node_test.go
@@ -24,78 +24,118 @@ import (
)
// TestSerializeDeserializeInternalNode tests serialization and deserialization of InternalNode
+// with the grouped subtree format through NodeStore.
func TestSerializeDeserializeInternalNode(t *testing.T) {
- // Create an internal node with two hashed children
leftHash := common.HexToHash("0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef")
rightHash := common.HexToHash("0xfedcba0987654321fedcba0987654321fedcba0987654321fedcba0987654321")
- node := &InternalNode{
- depth: 5,
- left: HashedNode(leftHash),
- right: HashedNode(rightHash),
- }
+ s := NewNodeStore()
+ leftRef := s.newHashedRef(leftHash)
+ rightRef := s.newHashedRef(rightHash)
- // Serialize the node
- serialized := SerializeNode(node)
+ rootRef := s.newInternalRef(0)
+ rootNode := s.getInternal(rootRef.Index())
+ rootNode.left = leftRef
+ rootNode.right = rightRef
+ s.SetRoot(rootRef)
+
+ // Serialize the node with default group depth of 8
+ serialized := s.SerializeNode(rootRef, MaxGroupDepth)
// Check the serialized format
if serialized[0] != nodeTypeInternal {
t.Errorf("Expected type byte to be %d, got %d", nodeTypeInternal, serialized[0])
}
-
- if len(serialized) != 65 {
- t.Errorf("Expected serialized length to be 65, got %d", len(serialized))
+ if serialized[1] != MaxGroupDepth {
+ t.Errorf("Expected group depth to be %d, got %d", MaxGroupDepth, serialized[1])
}
- // Deserialize the node
- deserialized, err := DeserializeNode(serialized, 5)
+ bitmapSize := BitmapSizeForDepth(MaxGroupDepth)
+ expectedLen := 1 + 1 + bitmapSize + 2*HashSize
+ if len(serialized) != expectedLen {
+ t.Errorf("Expected serialized length to be %d, got %d", expectedLen, len(serialized))
+ }
+
+ // Check bitmap bits
+ bitmap := serialized[2 : 2+bitmapSize]
+ if bitmap[0]&0x80 == 0 {
+ t.Error("Expected bit 0 to be set in bitmap (left child)")
+ }
+ if bitmap[16]&0x80 == 0 {
+ t.Error("Expected bit 128 to be set in bitmap (right child)")
+ }
+
+ // Deserialize into a new store
+ ds := NewNodeStore()
+ deserialized, err := ds.DeserializeNode(serialized, 0)
if err != nil {
t.Fatalf("Failed to deserialize node: %v", err)
}
- // Check that it's an internal node
- internalNode, ok := deserialized.(*InternalNode)
- if !ok {
- t.Fatalf("Expected InternalNode, got %T", deserialized)
+ // Root should be an InternalNode
+ if deserialized.Kind() != KindInternal {
+ t.Fatalf("Expected KindInternal, got kind %d", deserialized.Kind())
}
- // Check the depth
- if internalNode.depth != 5 {
- t.Errorf("Expected depth 5, got %d", internalNode.depth)
+ internalNode := ds.getInternal(deserialized.Index())
+ if internalNode.depth != 0 {
+ t.Errorf("Expected depth 0, got %d", internalNode.depth)
}
- // Check the left and right hashes
- if internalNode.left.Hash() != leftHash {
- t.Errorf("Left hash mismatch: expected %x, got %x", leftHash, internalNode.left.Hash())
+ // Navigate to position 0 (8 left turns) to find the left hash
+ node0 := navigateToLeafRef(ds, deserialized, 0, 8)
+ if ds.ComputeHash(node0) != leftHash {
+ t.Errorf("Left hash mismatch: expected %x, got %x", leftHash, ds.ComputeHash(node0))
}
- if internalNode.right.Hash() != rightHash {
- t.Errorf("Right hash mismatch: expected %x, got %x", rightHash, internalNode.right.Hash())
+ // Navigate to position 128 (right, then 7 lefts) to find the right hash
+ node128 := navigateToLeafRef(ds, deserialized, 128, 8)
+ if ds.ComputeHash(node128) != rightHash {
+ t.Errorf("Right hash mismatch: expected %x, got %x", rightHash, ds.ComputeHash(node128))
}
}
-// TestSerializeDeserializeStemNode tests serialization and deserialization of StemNode
+// navigateToLeafRef navigates to a specific position in the tree using NodeStore.
+func navigateToLeafRef(s *NodeStore, ref NodeRef, position, depth int) NodeRef {
+ cur := ref
+ for d := 0; d < depth; d++ {
+ if cur.Kind() != KindInternal {
+ return cur
+ }
+ in := s.getInternal(cur.Index())
+ bit := (position >> (depth - 1 - d)) & 1
+ if bit == 0 {
+ cur = in.left
+ } else {
+ cur = in.right
+ }
+ }
+ return cur
+}
+
+// TestSerializeDeserializeStemNode tests serialization and deserialization of StemNode through NodeStore.
func TestSerializeDeserializeStemNode(t *testing.T) {
- // Create a stem node with some values
stem := make([]byte, StemSize)
for i := range stem {
stem[i] = byte(i)
}
var values [StemNodeWidth][]byte
- // Add some values at different indices
values[0] = common.HexToHash("0x0101010101010101010101010101010101010101010101010101010101010101").Bytes()
values[10] = common.HexToHash("0x0202020202020202020202020202020202020202020202020202020202020202").Bytes()
values[255] = common.HexToHash("0x0303030303030303030303030303030303030303030303030303030303030303").Bytes()
- node := &StemNode{
- Stem: stem,
- Values: values[:],
- depth: 10,
+ s := NewNodeStore()
+ ref := s.newStemRef(stem, 10)
+ sn := s.getStem(ref.Index())
+ for i, v := range values {
+ if v != nil {
+ sn.setValue(byte(i), v)
+ }
}
// Serialize the node
- serialized := SerializeNode(node)
+ serialized := s.SerializeNode(ref, MaxGroupDepth)
// Check the serialized format
if serialized[0] != nodeTypeStem {
@@ -107,31 +147,32 @@ func TestSerializeDeserializeStemNode(t *testing.T) {
t.Errorf("Stem mismatch in serialized data")
}
- // Deserialize the node
- deserialized, err := DeserializeNode(serialized, 10)
+ // Deserialize into a new store
+ ds := NewNodeStore()
+ deserializedRef, err := ds.DeserializeNode(serialized, 10)
if err != nil {
t.Fatalf("Failed to deserialize node: %v", err)
}
- // Check that it's a stem node
- stemNode, ok := deserialized.(*StemNode)
- if !ok {
- t.Fatalf("Expected StemNode, got %T", deserialized)
+ if deserializedRef.Kind() != KindStem {
+ t.Fatalf("Expected KindStem, got kind %d", deserializedRef.Kind())
}
+ stemNode := ds.getStem(deserializedRef.Index())
+
// Check the stem
- if !bytes.Equal(stemNode.Stem, stem) {
+ if !bytes.Equal(stemNode.Stem[:], stem) {
t.Errorf("Stem mismatch after deserialization")
}
// Check the values
- if !bytes.Equal(stemNode.Values[0], values[0]) {
+ if !bytes.Equal(stemNode.getValue(0), values[0]) {
t.Errorf("Value at index 0 mismatch")
}
- if !bytes.Equal(stemNode.Values[10], values[10]) {
+ if !bytes.Equal(stemNode.getValue(10), values[10]) {
t.Errorf("Value at index 10 mismatch")
}
- if !bytes.Equal(stemNode.Values[255], values[255]) {
+ if !bytes.Equal(stemNode.getValue(255), values[255]) {
t.Errorf("Value at index 255 mismatch")
}
@@ -140,43 +181,43 @@ func TestSerializeDeserializeStemNode(t *testing.T) {
if i == 0 || i == 10 || i == 255 {
continue
}
- if stemNode.Values[i] != nil {
- t.Errorf("Expected nil value at index %d, got %x", i, stemNode.Values[i])
+ if stemNode.hasValue(byte(i)) {
+ t.Errorf("Expected no value at index %d, got %x", i, stemNode.getValue(byte(i)))
}
}
}
-// TestDeserializeEmptyNode tests deserialization of empty node
+// TestDeserializeEmptyNode tests deserialization of empty node.
func TestDeserializeEmptyNode(t *testing.T) {
- // Empty byte slice should deserialize to Empty node
- deserialized, err := DeserializeNode([]byte{}, 0)
+ s := NewNodeStore()
+ deserialized, err := s.DeserializeNode([]byte{}, 0)
if err != nil {
t.Fatalf("Failed to deserialize empty node: %v", err)
}
- _, ok := deserialized.(Empty)
- if !ok {
- t.Fatalf("Expected Empty node, got %T", deserialized)
+ if !deserialized.IsEmpty() {
+ t.Fatalf("Expected EmptyRef, got kind %d", deserialized.Kind())
}
}
-// TestDeserializeInvalidType tests deserialization with invalid type byte
+// TestDeserializeInvalidType tests deserialization with invalid type byte.
func TestDeserializeInvalidType(t *testing.T) {
- // Create invalid serialized data with unknown type byte
+ s := NewNodeStore()
invalidData := []byte{99, 0, 0, 0} // Type byte 99 is invalid
- _, err := DeserializeNode(invalidData, 0)
+ _, err := s.DeserializeNode(invalidData, 0)
if err == nil {
t.Fatal("Expected error for invalid type byte, got nil")
}
}
-// TestDeserializeInvalidLength tests deserialization with invalid data length
+// TestDeserializeInvalidLength tests deserialization with invalid data length.
func TestDeserializeInvalidLength(t *testing.T) {
- // InternalNode with type byte 1 but wrong length
- invalidData := []byte{nodeTypeInternal, 0, 0} // Too short for internal node
+ s := NewNodeStore()
+ // InternalNode with valid type byte and group depth but too short for bitmap
+ invalidData := []byte{nodeTypeInternal, 8, 0, 0} // Too short for bitmap (needs 32 bytes)
- _, err := DeserializeNode(invalidData, 0)
+ _, err := s.DeserializeNode(invalidData, 0)
if err == nil {
t.Fatal("Expected error for invalid data length, got nil")
}
@@ -186,7 +227,7 @@ func TestDeserializeInvalidLength(t *testing.T) {
}
}
-// TestKeyToPath tests the keyToPath function
+// TestKeyToPath tests the keyToPath function.
func TestKeyToPath(t *testing.T) {
tests := []struct {
name string
@@ -218,14 +259,14 @@ func TestKeyToPath(t *testing.T) {
},
{
name: "max valid depth",
- depth: StemSize * 8,
+ depth: StemSize*8 - 1,
key: make([]byte, HashSize),
- expected: make([]byte, StemSize*8+1),
+ expected: make([]byte, StemSize*8),
wantErr: false,
},
{
name: "depth too large",
- depth: StemSize*8 + 1,
+ depth: StemSize * 8,
key: make([]byte, HashSize),
wantErr: true,
},
diff --git a/trie/bintrie/empty.go b/trie/bintrie/empty.go
deleted file mode 100644
index c47e284dac..0000000000
--- a/trie/bintrie/empty.go
+++ /dev/null
@@ -1,76 +0,0 @@
-// Copyright 2025 go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see .
-
-package bintrie
-
-import (
- "slices"
-
- "github.com/ethereum/go-ethereum/common"
-)
-
-type Empty struct{}
-
-func (e Empty) Get(_ []byte, _ NodeResolverFn) ([]byte, error) {
- return nil, nil
-}
-
-func (e Empty) Insert(key []byte, value []byte, _ NodeResolverFn, depth int) (BinaryNode, error) {
- var values [256][]byte
- values[key[31]] = value
- return &StemNode{
- Stem: slices.Clone(key[:31]),
- Values: values[:],
- depth: depth,
- mustRecompute: true,
- dirty: true,
- }, nil
-}
-
-func (e Empty) Copy() BinaryNode {
- return Empty{}
-}
-
-func (e Empty) Hash() common.Hash {
- return common.Hash{}
-}
-
-func (e Empty) GetValuesAtStem(_ []byte, _ NodeResolverFn) ([][]byte, error) {
- var values [256][]byte
- return values[:], nil
-}
-
-func (e Empty) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolverFn, depth int) (BinaryNode, error) {
- return &StemNode{
- Stem: slices.Clone(key[:31]),
- Values: values,
- depth: depth,
- mustRecompute: true,
- dirty: true,
- }, nil
-}
-
-func (e Empty) CollectNodes(_ []byte, _ NodeFlushFn) error {
- return nil
-}
-
-func (e Empty) toDot(parent string, path string) string {
- return ""
-}
-
-func (e Empty) GetHeight() int {
- return 0
-}
diff --git a/trie/bintrie/empty_test.go b/trie/bintrie/empty_test.go
deleted file mode 100644
index 4da1ed15a0..0000000000
--- a/trie/bintrie/empty_test.go
+++ /dev/null
@@ -1,264 +0,0 @@
-// Copyright 2025 go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see .
-
-package bintrie
-
-import (
- "bytes"
- "testing"
-
- "github.com/ethereum/go-ethereum/common"
-)
-
-// TestEmptyGet tests the Get method
-func TestEmptyGet(t *testing.T) {
- node := Empty{}
-
- key := make([]byte, 32)
- value, err := node.Get(key, nil)
- if err != nil {
- t.Fatalf("Unexpected error: %v", err)
- }
-
- if value != nil {
- t.Errorf("Expected nil value from empty node, got %x", value)
- }
-}
-
-// TestEmptyInsert tests the Insert method
-func TestEmptyInsert(t *testing.T) {
- node := Empty{}
-
- key := make([]byte, 32)
- key[0] = 0x12
- key[31] = 0x34
- value := common.HexToHash("0xabcd").Bytes()
-
- newNode, err := node.Insert(key, value, nil, 0)
- if err != nil {
- t.Fatalf("Failed to insert: %v", err)
- }
-
- // Should create a StemNode
- stemNode, ok := newNode.(*StemNode)
- if !ok {
- t.Fatalf("Expected StemNode, got %T", newNode)
- }
-
- // Check the stem (first 31 bytes of key)
- if !bytes.Equal(stemNode.Stem, key[:31]) {
- t.Errorf("Stem mismatch: expected %x, got %x", key[:31], stemNode.Stem)
- }
-
- // Check the value at the correct index (last byte of key)
- if !bytes.Equal(stemNode.Values[key[31]], value) {
- t.Errorf("Value mismatch at index %d: expected %x, got %x", key[31], value, stemNode.Values[key[31]])
- }
-
- // Check that other values are nil
- for i := 0; i < 256; i++ {
- if i != int(key[31]) && stemNode.Values[i] != nil {
- t.Errorf("Expected nil value at index %d, got %x", i, stemNode.Values[i])
- }
- }
-}
-
-// TestEmptyCopy tests the Copy method
-func TestEmptyCopy(t *testing.T) {
- node := Empty{}
-
- copied := node.Copy()
- copiedEmpty, ok := copied.(Empty)
- if !ok {
- t.Fatalf("Expected Empty, got %T", copied)
- }
-
- // Both should be empty
- if node != copiedEmpty {
- // Empty is a zero-value struct, so copies should be equal
- t.Errorf("Empty nodes should be equal")
- }
-}
-
-// TestEmptyHash tests the Hash method
-func TestEmptyHash(t *testing.T) {
- node := Empty{}
-
- hash := node.Hash()
-
- // Empty node should have zero hash
- if hash != (common.Hash{}) {
- t.Errorf("Expected zero hash for empty node, got %x", hash)
- }
-}
-
-// TestEmptyGetValuesAtStem tests the GetValuesAtStem method
-func TestEmptyGetValuesAtStem(t *testing.T) {
- node := Empty{}
-
- stem := make([]byte, 31)
- values, err := node.GetValuesAtStem(stem, nil)
- if err != nil {
- t.Fatalf("Unexpected error: %v", err)
- }
-
- // Should return an array of 256 nil values
- if len(values) != 256 {
- t.Errorf("Expected 256 values, got %d", len(values))
- }
-
- for i, v := range values {
- if v != nil {
- t.Errorf("Expected nil value at index %d, got %x", i, v)
- }
- }
-}
-
-// TestEmptyInsertValuesAtStem tests the InsertValuesAtStem method
-func TestEmptyInsertValuesAtStem(t *testing.T) {
- node := Empty{}
-
- stem := make([]byte, 31)
- stem[0] = 0x42
-
- var values [256][]byte
- values[0] = common.HexToHash("0x0101").Bytes()
- values[10] = common.HexToHash("0x0202").Bytes()
- values[255] = common.HexToHash("0x0303").Bytes()
-
- newNode, err := node.InsertValuesAtStem(stem, values[:], nil, 5)
- if err != nil {
- t.Fatalf("Failed to insert values: %v", err)
- }
-
- // Should create a StemNode
- stemNode, ok := newNode.(*StemNode)
- if !ok {
- t.Fatalf("Expected StemNode, got %T", newNode)
- }
-
- // Check the stem
- if !bytes.Equal(stemNode.Stem, stem) {
- t.Errorf("Stem mismatch: expected %x, got %x", stem, stemNode.Stem)
- }
-
- // Check the depth
- if stemNode.depth != 5 {
- t.Errorf("Depth mismatch: expected 5, got %d", stemNode.depth)
- }
-
- // Check the values
- if !bytes.Equal(stemNode.Values[0], values[0]) {
- t.Error("Value at index 0 mismatch")
- }
- if !bytes.Equal(stemNode.Values[10], values[10]) {
- t.Error("Value at index 10 mismatch")
- }
- if !bytes.Equal(stemNode.Values[255], values[255]) {
- t.Error("Value at index 255 mismatch")
- }
-
- // Check that values is the same slice (not a copy)
- if &stemNode.Values[0] != &values[0] {
- t.Error("Expected values to be the same slice reference")
- }
-}
-
-// TestEmptyCollectNodes tests the CollectNodes method
-func TestEmptyCollectNodes(t *testing.T) {
- node := Empty{}
-
- var collected []BinaryNode
- flushFn := func(path []byte, n BinaryNode) {
- collected = append(collected, n)
- }
-
- err := node.CollectNodes([]byte{0, 1, 0}, flushFn)
- if err != nil {
- t.Fatalf("Unexpected error: %v", err)
- }
-
- // Should not collect anything for empty node
- if len(collected) != 0 {
- t.Errorf("Expected no collected nodes for empty, got %d", len(collected))
- }
-}
-
-// TestEmptyToDot tests the toDot method
-func TestEmptyToDot(t *testing.T) {
- node := Empty{}
-
- dot := node.toDot("parent", "010")
-
- // Should return empty string for empty node
- if dot != "" {
- t.Errorf("Expected empty string for empty node toDot, got %s", dot)
- }
-}
-
-// TestEmptyGetHeight tests the GetHeight method
-func TestEmptyGetHeight(t *testing.T) {
- node := Empty{}
-
- height := node.GetHeight()
-
- // Empty node should have height 0
- if height != 0 {
- t.Errorf("Expected height 0 for empty node, got %d", height)
- }
-}
-
-// TestEmptyInsertMarksDirty verifies that a StemNode produced by Empty.Insert
-// is marked dirty. Without this, CollectNodes would skip the freshly created
-// stem and its blob would never reach disk, producing "missing trie node"
-// errors on subsequent reads.
-func TestEmptyInsertMarksDirty(t *testing.T) {
- key := make([]byte, 32)
- key[0] = 0xaa
- val := make([]byte, 32)
- val[0] = 0xbb
- n, err := Empty{}.Insert(key, val, nil, 0)
- if err != nil {
- t.Fatalf("Insert: %v", err)
- }
- sn, ok := n.(*StemNode)
- if !ok {
- t.Fatalf("expected *StemNode, got %T", n)
- }
- if !sn.dirty {
- t.Fatalf("stem produced by Empty.Insert must have dirty=true")
- }
-}
-
-// TestEmptyInsertValuesAtStemMarksDirty is the analogous guard for the
-// bulk-insert entry point. Fresh stems created here must be dirty.
-func TestEmptyInsertValuesAtStemMarksDirty(t *testing.T) {
- key := make([]byte, 32)
- key[0] = 0xcc
- values := make([][]byte, 256)
- values[0] = make([]byte, 32)
- n, err := Empty{}.InsertValuesAtStem(key, values, nil, 3)
- if err != nil {
- t.Fatalf("InsertValuesAtStem: %v", err)
- }
- sn, ok := n.(*StemNode)
- if !ok {
- t.Fatalf("expected *StemNode, got %T", n)
- }
- if !sn.dirty {
- t.Fatalf("stem produced by Empty.InsertValuesAtStem must have dirty=true")
- }
-}
diff --git a/trie/bintrie/hashed_node.go b/trie/bintrie/hashed_node.go
index e44c6d1e8a..9d349beaf8 100644
--- a/trie/bintrie/hashed_node.go
+++ b/trie/bintrie/hashed_node.go
@@ -16,75 +16,9 @@
package bintrie
-import (
- "errors"
- "fmt"
+import "github.com/ethereum/go-ethereum/common"
- "github.com/ethereum/go-ethereum/common"
-)
-
-type HashedNode common.Hash
-
-func (h HashedNode) Get(_ []byte, _ NodeResolverFn) ([]byte, error) {
- panic("not implemented") // TODO: Implement
-}
-
-func (h HashedNode) Insert(key []byte, value []byte, resolver NodeResolverFn, depth int) (BinaryNode, error) {
- return nil, errors.New("insert not implemented for hashed node")
-}
-
-func (h HashedNode) Copy() BinaryNode {
- nh := common.Hash(h)
- return HashedNode(nh)
-}
-
-func (h HashedNode) Hash() common.Hash {
- return common.Hash(h)
-}
-
-func (h HashedNode) GetValuesAtStem(_ []byte, _ NodeResolverFn) ([][]byte, error) {
- return nil, errors.New("attempted to get values from an unresolved node")
-}
-
-func (h HashedNode) InsertValuesAtStem(stem []byte, values [][]byte, resolver NodeResolverFn, depth int) (BinaryNode, error) {
- // Step 1: Generate the path for this node's position in the tree
- path, err := keyToPath(depth, stem)
- if err != nil {
- return nil, fmt.Errorf("InsertValuesAtStem path generation error: %w", err)
- }
-
- if resolver == nil {
- return nil, errors.New("InsertValuesAtStem resolve error: resolver is nil")
- }
-
- // Step 2: Resolve the hashed node to get the actual node data
- data, err := resolver(path, common.Hash(h))
- if err != nil {
- return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
- }
-
- // Step 3: Deserialize the resolved data into a concrete node
- node, err := DeserializeNodeWithHash(data, depth, common.Hash(h))
- if err != nil {
- return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
- }
-
- // Step 4: Call InsertValuesAtStem on the resolved concrete node
- return node.InsertValuesAtStem(stem, values, resolver, depth)
-}
-
-func (h HashedNode) toDot(parent string, path string) string {
- me := fmt.Sprintf("hash%s", path)
- ret := fmt.Sprintf("%s [label=\"%x\"]\n", me, h)
- ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me)
- return ret
-}
-
-func (h HashedNode) CollectNodes([]byte, NodeFlushFn) error {
- // HashedNodes are already persisted in the database and don't need to be collected.
- return nil
-}
-
-func (h HashedNode) GetHeight() int {
- panic("tried to get the height of a hashed node, this is a bug")
+// HashedNode represents an unresolved node that only stores its hash.
+type HashedNode struct {
+ hash common.Hash
}
diff --git a/trie/bintrie/hashed_node_test.go b/trie/bintrie/hashed_node_test.go
index f9e6984888..b723f7f256 100644
--- a/trie/bintrie/hashed_node_test.go
+++ b/trie/bintrie/hashed_node_test.go
@@ -18,180 +18,137 @@ package bintrie
import (
"bytes"
+ "errors"
"testing"
"github.com/ethereum/go-ethereum/common"
)
-// TestHashedNodeHash tests the Hash method
+// TestHashedNodeHash tests the Hash method via NodeStore.
func TestHashedNodeHash(t *testing.T) {
hash := common.HexToHash("0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef")
- node := HashedNode(hash)
+ s := NewNodeStore()
+ ref := s.newHashedRef(hash)
- // Hash should return the stored hash
- if node.Hash() != hash {
- t.Errorf("Hash mismatch: expected %x, got %x", hash, node.Hash())
+ if s.ComputeHash(ref) != hash {
+ t.Errorf("Hash mismatch: expected %x, got %x", hash, s.ComputeHash(ref))
}
}
-// TestHashedNodeCopy tests the Copy method
+// TestHashedNodeCopy tests the Copy method via NodeStore.
func TestHashedNodeCopy(t *testing.T) {
hash := common.HexToHash("0xabcdef")
- node := HashedNode(hash)
+ s := NewNodeStore()
+ ref := s.newHashedRef(hash)
+ s.SetRoot(ref)
- copied := node.Copy()
- copiedHash, ok := copied.(HashedNode)
- if !ok {
- t.Fatalf("Expected HashedNode, got %T", copied)
- }
+ ns := s.Copy()
+ copiedHash := ns.ComputeHash(ns.Root())
- // Hash should be the same
- if common.Hash(copiedHash) != hash {
+ if copiedHash != hash {
t.Errorf("Hash mismatch after copy: expected %x, got %x", hash, copiedHash)
}
-
- // But should be a different object
- if &node == &copiedHash {
- t.Error("Copy returned same object reference")
- }
}
-// TestHashedNodeInsert tests that Insert returns an error
-func TestHashedNodeInsert(t *testing.T) {
- node := HashedNode(common.HexToHash("0x1234"))
-
- key := make([]byte, HashSize)
- value := make([]byte, HashSize)
-
- _, err := node.Insert(key, value, nil, 0)
- if err == nil {
- t.Fatal("Expected error for Insert on HashedNode")
- }
-
- if err.Error() != "insert not implemented for hashed node" {
- t.Errorf("Unexpected error message: %v", err)
- }
-}
-
-// TestHashedNodeGetValuesAtStem tests that GetValuesAtStem returns an error
-func TestHashedNodeGetValuesAtStem(t *testing.T) {
- node := HashedNode(common.HexToHash("0x1234"))
-
- stem := make([]byte, StemSize)
- _, err := node.GetValuesAtStem(stem, nil)
- if err == nil {
- t.Fatal("Expected error for GetValuesAtStem on HashedNode")
- }
-
- if err.Error() != "attempted to get values from an unresolved node" {
- t.Errorf("Unexpected error message: %v", err)
- }
-}
-
-// TestHashedNodeInsertValuesAtStem tests that InsertValuesAtStem returns an error
+// TestHashedNodeInsertValuesAtStem tests InsertValuesAtStem resolution via NodeStore.
func TestHashedNodeInsertValuesAtStem(t *testing.T) {
- node := HashedNode(common.HexToHash("0x1234"))
+ // Test 1: nil resolver should return an error
+ s := NewNodeStore()
+ hashedRef := s.newHashedRef(common.HexToHash("0x1234"))
+ s.SetRoot(hashedRef)
stem := make([]byte, StemSize)
values := make([][]byte, StemNodeWidth)
- // Test 1: nil resolver should return an error
- _, err := node.InsertValuesAtStem(stem, values, nil, 0)
+ err := s.InsertValuesAtStem(stem, values, nil)
if err == nil {
- t.Fatal("Expected error for InsertValuesAtStem on HashedNode with nil resolver")
- }
-
- if err.Error() != "InsertValuesAtStem resolve error: resolver is nil" {
- t.Errorf("Unexpected error message: %v", err)
+ t.Fatal("Expected error for InsertValuesAtStem with nil resolver")
}
// Test 2: mock resolver returning invalid data should return deserialization error
mockResolver := func(path []byte, hash common.Hash) ([]byte, error) {
- // Return invalid/nonsense data that cannot be deserialized
return []byte{0xff, 0xff, 0xff}, nil
}
- _, err = node.InsertValuesAtStem(stem, values, mockResolver, 0)
- if err == nil {
- t.Fatal("Expected error for InsertValuesAtStem on HashedNode with invalid resolver data")
- }
+ s2 := NewNodeStore()
+ hashedRef2 := s2.newHashedRef(common.HexToHash("0x1234"))
+ s2.SetRoot(hashedRef2)
- expectedPrefix := "InsertValuesAtStem node deserialization error:"
- if len(err.Error()) < len(expectedPrefix) || err.Error()[:len(expectedPrefix)] != expectedPrefix {
- t.Errorf("Expected deserialization error, got: %v", err)
+ err = s2.InsertValuesAtStem(stem, values, mockResolver)
+ if err == nil {
+ t.Fatal("Expected error for InsertValuesAtStem with invalid resolver data")
}
// Test 3: mock resolver returning valid serialized node should succeed
stem = make([]byte, StemSize)
stem[0] = 0xaa
- var originalValues [StemNodeWidth][]byte
+ originalValues := make([][]byte, StemNodeWidth)
originalValues[0] = common.HexToHash("0x1111111111111111111111111111111111111111111111111111111111111111").Bytes()
originalValues[1] = common.HexToHash("0x2222222222222222222222222222222222222222222222222222222222222222").Bytes()
- originalNode := &StemNode{
- Stem: stem,
- Values: originalValues[:],
- depth: 0,
+ // Build the serialized node
+ rs := NewNodeStore()
+ ref := rs.newStemRef(stem, 0)
+ sn := rs.getStem(ref.Index())
+ for i, v := range originalValues {
+ if v != nil {
+ sn.setValue(byte(i), v)
+ }
}
+ serialized := rs.SerializeNode(ref, MaxGroupDepth)
- // Serialize the node
- serialized := SerializeNode(originalNode)
-
- // Create a mock resolver that returns the serialized node
validResolver := func(path []byte, hash common.Hash) ([]byte, error) {
return serialized, nil
}
- var newValues [StemNodeWidth][]byte
+ s3 := NewNodeStore()
+ hashedRef3 := s3.newHashedRef(common.HexToHash("0x1234"))
+ s3.SetRoot(hashedRef3)
+
+ newValues := make([][]byte, StemNodeWidth)
newValues[2] = common.HexToHash("0x3333333333333333333333333333333333333333333333333333333333333333").Bytes()
- resolvedNode, err := node.InsertValuesAtStem(stem, newValues[:], validResolver, 0)
+ err = s3.InsertValuesAtStem(stem, newValues, validResolver)
if err != nil {
t.Fatalf("Expected successful resolution and insertion, got error: %v", err)
}
- resultStem, ok := resolvedNode.(*StemNode)
- if !ok {
- t.Fatalf("Expected resolved node to be *StemNode, got %T", resolvedNode)
+ // Verify original values are preserved
+ retrieved, err := s3.GetValuesAtStem(stem, nil)
+ if err != nil {
+ t.Fatal(err)
}
-
- if !bytes.Equal(resultStem.Stem, stem) {
- t.Errorf("Stem mismatch: expected %x, got %x", stem, resultStem.Stem)
+ if !bytes.Equal(retrieved[0], originalValues[0]) {
+ t.Errorf("Original value at index 0 not preserved")
}
-
- // Verify the original values are preserved
- if !bytes.Equal(resultStem.Values[0], originalValues[0]) {
- t.Errorf("Original value at index 0 not preserved: expected %x, got %x", originalValues[0], resultStem.Values[0])
+ if !bytes.Equal(retrieved[1], originalValues[1]) {
+ t.Errorf("Original value at index 1 not preserved")
}
- if !bytes.Equal(resultStem.Values[1], originalValues[1]) {
- t.Errorf("Original value at index 1 not preserved: expected %x, got %x", originalValues[1], resultStem.Values[1])
- }
-
- // Verify the new value was inserted
- if !bytes.Equal(resultStem.Values[2], newValues[2]) {
- t.Errorf("New value at index 2 not inserted correctly: expected %x, got %x", newValues[2], resultStem.Values[2])
+ if !bytes.Equal(retrieved[2], newValues[2]) {
+ t.Errorf("New value at index 2 not inserted correctly")
}
}
-// TestHashedNodeToDot tests the toDot method for visualization
-func TestHashedNodeToDot(t *testing.T) {
- hash := common.HexToHash("0x1234")
- node := HashedNode(hash)
+// TestHashedNodeGetError tests that getting through an unresolved HashedNode root returns error.
+func TestHashedNodeGetError(t *testing.T) {
+ s := NewNodeStore()
+ // Create root as hashed, then try to resolve through InternalNode parent
+ rootRef := s.newInternalRef(0)
+ rootNode := s.getInternal(rootRef.Index())
+ hashedLeft := s.newHashedRef(common.HexToHash("0x1234"))
+ rootNode.left = hashedLeft
+ rootNode.right = EmptyRef
+ s.SetRoot(rootRef)
- dot := node.toDot("parent", "010")
+ key := make([]byte, 32) // goes left
+ key[31] = 5
- // Should contain the hash value and parent connection
- expectedHash := "hash010"
- if !contains(dot, expectedHash) {
- t.Errorf("Expected dot output to contain %s", expectedHash)
+ resolver := func(path []byte, hash common.Hash) ([]byte, error) {
+ return nil, errors.New("node not found")
}
- if !contains(dot, "parent -> hash010") {
- t.Error("Expected dot output to contain parent connection")
+ _, err := s.Get(key, resolver)
+ if err == nil {
+ t.Fatal("Expected error when resolver fails")
}
}
-
-// Helper function
-func contains(s, substr string) bool {
- return len(s) >= len(substr) && s != "" && len(substr) > 0
-}
diff --git a/trie/bintrie/hasher.go b/trie/bintrie/hasher.go
index b81c145723..d56cba96c5 100644
--- a/trie/bintrie/hasher.go
+++ b/trie/bintrie/hasher.go
@@ -37,3 +37,11 @@ func newSha256() hash.Hash {
func returnSha256(h hash.Hash) {
sha256Pool.Put(h)
}
+
+// sha256Sum256 computes a sha256 digest and returns it as a common.Hash.
+func sha256Sum256(data []byte) [32]byte {
+ return sha256.Sum256(data)
+}
+
+// parallelHashDepth controls below which depth hashing is parallelised.
+const parallelHashDepth = 4
diff --git a/trie/bintrie/internal_node.go b/trie/bintrie/internal_node.go
index 811f65bcd8..51beecefd5 100644
--- a/trie/bintrie/internal_node.go
+++ b/trie/bintrie/internal_node.go
@@ -17,35 +17,13 @@
package bintrie
import (
- "crypto/sha256"
"errors"
- "fmt"
- "math/bits"
- "runtime"
- "sync"
"github.com/ethereum/go-ethereum/common"
)
-// parallelDepth returns the tree depth below which Hash() spawns goroutines.
-func parallelDepth() int {
- return min(bits.Len(uint(runtime.NumCPU())), 8)
-}
-
-// isDirty reports whether a BinaryNode child needs rehashing.
-func isDirty(n BinaryNode) bool {
- switch v := n.(type) {
- case *InternalNode:
- return v.mustRecompute
- case *StemNode:
- return v.mustRecompute
- default:
- return false
- }
-}
-
func keyToPath(depth int, key []byte) ([]byte, error) {
- if depth > 31*8 {
+ if depth >= 31*8 {
return nil, errors.New("node too deep")
}
path := make([]byte, 0, depth+1)
@@ -56,252 +34,20 @@ func keyToPath(depth int, key []byte) ([]byte, error) {
return path, nil
}
+// makeKeyPath is a simplified version of keyToPath that doesn't return an error.
+func makeKeyPath(depth int, key []byte) []byte {
+ path := make([]byte, 0, depth+1)
+ for i := range depth + 1 {
+ bit := key[i/8] >> (7 - (i % 8)) & 1
+ path = append(path, bit)
+ }
+ return path
+}
+
// InternalNode is a binary trie internal node.
type InternalNode struct {
- left, right BinaryNode
- depth int
-
+ left, right NodeRef
+ depth uint8
mustRecompute bool // true if the hash needs to be recomputed
- dirty bool // true if the node's on-disk blob is stale (needs flush)
hash common.Hash // cached hash when mustRecompute == false
}
-
-// GetValuesAtStem retrieves the group of values located at the given stem key.
-func (bt *InternalNode) GetValuesAtStem(stem []byte, resolver NodeResolverFn) ([][]byte, error) {
- if bt.depth > 31*8 {
- return nil, errors.New("node too deep")
- }
-
- bit := stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1
- if bit == 0 {
- if hn, ok := bt.left.(HashedNode); ok {
- path, err := keyToPath(bt.depth, stem)
- if err != nil {
- return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
- }
- data, err := resolver(path, common.Hash(hn))
- if err != nil {
- return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
- }
- node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn))
- if err != nil {
- return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err)
- }
- bt.left = node
- }
- return bt.left.GetValuesAtStem(stem, resolver)
- }
-
- if hn, ok := bt.right.(HashedNode); ok {
- path, err := keyToPath(bt.depth, stem)
- if err != nil {
- return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
- }
- data, err := resolver(path, common.Hash(hn))
- if err != nil {
- return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
- }
- node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn))
- if err != nil {
- return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err)
- }
- bt.right = node
- }
- return bt.right.GetValuesAtStem(stem, resolver)
-}
-
-// Get retrieves the value for the given key.
-func (bt *InternalNode) Get(key []byte, resolver NodeResolverFn) ([]byte, error) {
- values, err := bt.GetValuesAtStem(key[:31], resolver)
- if err != nil {
- return nil, fmt.Errorf("get error: %w", err)
- }
- if values == nil {
- return nil, nil
- }
- return values[key[31]], nil
-}
-
-// Insert inserts a new key-value pair into the trie.
-func (bt *InternalNode) Insert(key []byte, value []byte, resolver NodeResolverFn, depth int) (BinaryNode, error) {
- var values [256][]byte
- values[key[31]] = value
- return bt.InsertValuesAtStem(key[:31], values[:], resolver, depth)
-}
-
-// Copy creates a deep copy of the node.
-func (bt *InternalNode) Copy() BinaryNode {
- return &InternalNode{
- left: bt.left.Copy(),
- right: bt.right.Copy(),
- depth: bt.depth,
- mustRecompute: bt.mustRecompute,
- dirty: bt.dirty,
- hash: bt.hash,
- }
-}
-
-// Hash returns the hash of the node.
-func (bt *InternalNode) Hash() common.Hash {
- if !bt.mustRecompute {
- return bt.hash
- }
-
- // At shallow depths, parallelize when both children need rehashing:
- // hash left subtree in a goroutine, right subtree inline, then combine.
- // Skip goroutine overhead when only one child is dirty (common case
- // for narrow state updates that touch a single path through the trie).
- if bt.depth < parallelDepth() && isDirty(bt.left) && isDirty(bt.right) {
- var input [64]byte
- var lh common.Hash
- var wg sync.WaitGroup
- wg.Add(1)
- go func() {
- defer wg.Done()
- lh = bt.left.Hash()
- }()
- rh := bt.right.Hash()
- copy(input[32:], rh[:])
- wg.Wait()
- copy(input[:32], lh[:])
- bt.hash = sha256.Sum256(input[:])
- bt.mustRecompute = false
- return bt.hash
- }
-
- // Deeper nodes: sequential using pooled hasher (goroutine overhead > hash cost)
- h := newSha256()
- defer returnSha256(h)
- if bt.left != nil {
- h.Write(bt.left.Hash().Bytes())
- } else {
- h.Write(zero[:])
- }
- if bt.right != nil {
- h.Write(bt.right.Hash().Bytes())
- } else {
- h.Write(zero[:])
- }
- bt.hash = common.BytesToHash(h.Sum(nil))
- bt.mustRecompute = false
- return bt.hash
-}
-
-// InsertValuesAtStem inserts a full value group at the given stem in the internal node.
-// Already-existing values will be overwritten.
-func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolver NodeResolverFn, depth int) (BinaryNode, error) {
- var err error
- bit := stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1
- if bit == 0 {
- if bt.left == nil {
- bt.left = Empty{}
- }
-
- if hn, ok := bt.left.(HashedNode); ok {
- path, err := keyToPath(bt.depth, stem)
- if err != nil {
- return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
- }
- data, err := resolver(path, common.Hash(hn))
- if err != nil {
- return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
- }
- node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn))
- if err != nil {
- return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
- }
- bt.left = node
- }
-
- bt.left, err = bt.left.InsertValuesAtStem(stem, values, resolver, depth+1)
- bt.mustRecompute = true
- bt.dirty = true
- return bt, err
- }
-
- if bt.right == nil {
- bt.right = Empty{}
- }
-
- if hn, ok := bt.right.(HashedNode); ok {
- path, err := keyToPath(bt.depth, stem)
- if err != nil {
- return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
- }
- data, err := resolver(path, common.Hash(hn))
- if err != nil {
- return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
- }
- node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn))
- if err != nil {
- return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
- }
- bt.right = node
- }
-
- bt.right, err = bt.right.InsertValuesAtStem(stem, values, resolver, depth+1)
- bt.mustRecompute = true
- bt.dirty = true
- return bt, err
-}
-
-// CollectNodes collects all child nodes at a given path, and flushes it
-// into the provided node collector. Clean subtrees (dirty == false) are
-// skipped.
-func (bt *InternalNode) CollectNodes(path []byte, flushfn NodeFlushFn) error {
- if !bt.dirty {
- return nil
- }
- if bt.left != nil {
- var p [256]byte
- copy(p[:], path)
- childpath := p[:len(path)]
- childpath = append(childpath, 0)
- if err := bt.left.CollectNodes(childpath, flushfn); err != nil {
- return err
- }
- }
- if bt.right != nil {
- var p [256]byte
- copy(p[:], path)
- childpath := p[:len(path)]
- childpath = append(childpath, 1)
- if err := bt.right.CollectNodes(childpath, flushfn); err != nil {
- return err
- }
- }
- flushfn(path, bt)
- bt.dirty = false
- return nil
-}
-
-// GetHeight returns the height of the node.
-func (bt *InternalNode) GetHeight() int {
- var (
- leftHeight int
- rightHeight int
- )
- if bt.left != nil {
- leftHeight = bt.left.GetHeight()
- }
- if bt.right != nil {
- rightHeight = bt.right.GetHeight()
- }
- return 1 + max(leftHeight, rightHeight)
-}
-
-func (bt *InternalNode) toDot(parent, path string) string {
- me := fmt.Sprintf("internal%s", path)
- ret := fmt.Sprintf("%s [label=\"I: %x\"]\n", me, bt.Hash())
- if len(parent) > 0 {
- ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me)
- }
-
- if bt.left != nil {
- ret = fmt.Sprintf("%s%s", ret, bt.left.toDot(me, fmt.Sprintf("%s%02x", path, 0)))
- }
- if bt.right != nil {
- ret = fmt.Sprintf("%s%s", ret, bt.right.toDot(me, fmt.Sprintf("%s%02x", path, 1)))
- }
- return ret
-}
diff --git a/trie/bintrie/internal_node_test.go b/trie/bintrie/internal_node_test.go
index ddcec8085d..7ae94b1b8c 100644
--- a/trie/bintrie/internal_node_test.go
+++ b/trie/bintrie/internal_node_test.go
@@ -24,35 +24,33 @@ import (
"github.com/ethereum/go-ethereum/common"
)
-// TestInternalNodeGet tests the Get method
+// TestInternalNodeGet tests the Get method via NodeStore.
func TestInternalNodeGet(t *testing.T) {
- // Create a simple tree structure
+ s := NewNodeStore()
+
leftStem := make([]byte, 31)
rightStem := make([]byte, 31)
- rightStem[0] = 0x80 // First bit is 1
+ rightStem[0] = 0x80
- var leftValues, rightValues [256][]byte
+ leftValues := make([][]byte, 256)
leftValues[0] = common.HexToHash("0x0101").Bytes()
+ rightValues := make([][]byte, 256)
rightValues[0] = common.HexToHash("0x0202").Bytes()
- node := &InternalNode{
- depth: 0,
- left: &StemNode{
- Stem: leftStem,
- Values: leftValues[:],
- depth: 1,
- },
- right: &StemNode{
- Stem: rightStem,
- Values: rightValues[:],
- depth: 1,
- },
+ // Build tree: root -> left stem, right stem
+ // Insert left stem values
+ s.SetRoot(EmptyRef)
+ if err := s.InsertValuesAtStem(leftStem, leftValues, nil); err != nil {
+ t.Fatal(err)
+ }
+ if err := s.InsertValuesAtStem(rightStem, rightValues, nil); err != nil {
+ t.Fatal(err)
}
// Get value from left subtree
leftKey := make([]byte, 32)
leftKey[31] = 0
- value, err := node.Get(leftKey, nil)
+ value, err := s.Get(leftKey, nil)
if err != nil {
t.Fatalf("Failed to get left value: %v", err)
}
@@ -64,7 +62,7 @@ func TestInternalNodeGet(t *testing.T) {
rightKey := make([]byte, 32)
rightKey[0] = 0x80
rightKey[31] = 0
- value, err = node.Get(rightKey, nil)
+ value, err = s.Get(rightKey, nil)
if err != nil {
t.Fatalf("Failed to get right value: %v", err)
}
@@ -73,29 +71,26 @@ func TestInternalNodeGet(t *testing.T) {
}
}
-// TestInternalNodeGetWithResolver tests Get with HashedNode resolution
+// TestInternalNodeGetWithResolver tests Get with HashedNode resolution via NodeStore.
func TestInternalNodeGetWithResolver(t *testing.T) {
- // Create an internal node with a hashed child
- hashedChild := HashedNode(common.HexToHash("0x1234"))
-
- node := &InternalNode{
- depth: 0,
- left: hashedChild,
- right: Empty{},
- }
+ // Create a store with an internal node containing a hashed child
+ s := NewNodeStore()
+ hashedChild := s.newHashedRef(common.HexToHash("0x1234"))
+ rootRef := s.newInternalRef(0)
+ rootNode := s.getInternal(rootRef.Index())
+ rootNode.left = hashedChild
+ rootNode.right = EmptyRef
+ s.SetRoot(rootRef)
// Mock resolver that returns a stem node
resolver := func(path []byte, hash common.Hash) ([]byte, error) {
- if hash == common.Hash(hashedChild) {
+ if hash == common.HexToHash("0x1234") {
+ rs := NewNodeStore()
stem := make([]byte, 31)
- var values [256][]byte
- values[5] = common.HexToHash("0xabcd").Bytes()
- stemNode := &StemNode{
- Stem: stem,
- Values: values[:],
- depth: 1,
- }
- return SerializeNode(stemNode), nil
+ ref := rs.newStemRef(stem, 1)
+ sn := rs.getStem(ref.Index())
+ sn.setValue(5, common.HexToHash("0xabcd").Bytes())
+ return rs.SerializeNode(ref, MaxGroupDepth), nil
}
return nil, errors.New("node not found")
}
@@ -103,7 +98,7 @@ func TestInternalNodeGetWithResolver(t *testing.T) {
// Get value through the hashed node
key := make([]byte, 32)
key[31] = 5
- value, err := node.Get(key, resolver)
+ value, err := s.Get(key, resolver)
if err != nil {
t.Fatalf("Failed to get value: %v", err)
}
@@ -114,179 +109,113 @@ func TestInternalNodeGetWithResolver(t *testing.T) {
}
}
-// TestInternalNodeInsert tests the Insert method
+// TestInternalNodeInsert tests the Insert method via NodeStore.
func TestInternalNodeInsert(t *testing.T) {
- // Start with an internal node with empty children
- node := &InternalNode{
- depth: 0,
- left: Empty{},
- right: Empty{},
- }
+ s := NewNodeStore()
- // Insert a value into the left subtree
leftKey := make([]byte, 32)
leftKey[31] = 10
leftValue := common.HexToHash("0x0101").Bytes()
- newNode, err := node.Insert(leftKey, leftValue, nil, 0)
- if err != nil {
+ if err := s.Insert(leftKey, leftValue, nil); err != nil {
t.Fatalf("Failed to insert: %v", err)
}
- internalNode, ok := newNode.(*InternalNode)
- if !ok {
- t.Fatalf("Expected InternalNode, got %T", newNode)
+ // Verify the value was stored
+ value, err := s.Get(leftKey, nil)
+ if err != nil {
+ t.Fatalf("Failed to get: %v", err)
}
-
- // Check that left child is now a StemNode
- leftStem, ok := internalNode.left.(*StemNode)
- if !ok {
- t.Fatalf("Expected left child to be StemNode, got %T", internalNode.left)
- }
-
- // Check the inserted value
- if !bytes.Equal(leftStem.Values[10], leftValue) {
- t.Errorf("Value mismatch: expected %x, got %x", leftValue, leftStem.Values[10])
- }
-
- // Right child should still be Empty
- _, ok = internalNode.right.(Empty)
- if !ok {
- t.Errorf("Expected right child to remain Empty, got %T", internalNode.right)
+ if !bytes.Equal(value, leftValue) {
+ t.Errorf("Value mismatch: expected %x, got %x", leftValue, value)
}
}
-// TestInternalNodeCopy tests the Copy method
+// TestInternalNodeCopy tests the Copy method via NodeStore.
func TestInternalNodeCopy(t *testing.T) {
- // Create an internal node with stem children
- leftStem := &StemNode{
- Stem: make([]byte, 31),
- Values: make([][]byte, 256),
- depth: 1,
- }
- leftStem.Values[0] = common.HexToHash("0x0101").Bytes()
+ s := NewNodeStore()
- rightStem := &StemNode{
- Stem: make([]byte, 31),
- Values: make([][]byte, 256),
- depth: 1,
- }
- rightStem.Stem[0] = 0x80
- rightStem.Values[0] = common.HexToHash("0x0202").Bytes()
+ leftKey := make([]byte, 32)
+ leftKey[31] = 0
+ leftValue := common.HexToHash("0x0101").Bytes()
- node := &InternalNode{
- depth: 0,
- left: leftStem,
- right: rightStem,
+ rightKey := make([]byte, 32)
+ rightKey[0] = 0x80
+ rightKey[31] = 0
+ rightValue := common.HexToHash("0x0202").Bytes()
+
+ if err := s.Insert(leftKey, leftValue, nil); err != nil {
+ t.Fatal(err)
+ }
+ if err := s.Insert(rightKey, rightValue, nil); err != nil {
+ t.Fatal(err)
}
- // Create a copy
- copied := node.Copy()
- copiedInternal, ok := copied.(*InternalNode)
- if !ok {
- t.Fatalf("Expected InternalNode, got %T", copied)
- }
+ ns := s.Copy()
- // Check depth
- if copiedInternal.depth != node.depth {
- t.Errorf("Depth mismatch: expected %d, got %d", node.depth, copiedInternal.depth)
- }
-
- // Check that children are copied
- copiedLeft, ok := copiedInternal.left.(*StemNode)
- if !ok {
- t.Fatalf("Expected left child to be StemNode, got %T", copiedInternal.left)
- }
-
- copiedRight, ok := copiedInternal.right.(*StemNode)
- if !ok {
- t.Fatalf("Expected right child to be StemNode, got %T", copiedInternal.right)
- }
-
- // Verify deep copy (children should be different objects)
- if copiedLeft == leftStem {
- t.Error("Left child not properly copied")
- }
- if copiedRight == rightStem {
- t.Error("Right child not properly copied")
- }
-
- // But values should be equal
- if !bytes.Equal(copiedLeft.Values[0], leftStem.Values[0]) {
+ // Values should be equal
+ v1, _ := ns.Get(leftKey, nil)
+ if !bytes.Equal(v1, leftValue) {
t.Error("Left child value mismatch after copy")
}
- if !bytes.Equal(copiedRight.Values[0], rightStem.Values[0]) {
+ v2, _ := ns.Get(rightKey, nil)
+ if !bytes.Equal(v2, rightValue) {
t.Error("Right child value mismatch after copy")
}
}
-// TestInternalNodeHash tests the Hash method
+// TestInternalNodeHash tests the Hash method via NodeStore.
func TestInternalNodeHash(t *testing.T) {
- // Create an internal node
- node := &InternalNode{
- depth: 0,
- left: HashedNode(common.HexToHash("0x1111")),
- right: HashedNode(common.HexToHash("0x2222")),
- }
+ s := NewNodeStore()
+ leftRef := s.newHashedRef(common.HexToHash("0x1111"))
+ rightRef := s.newHashedRef(common.HexToHash("0x2222"))
+ rootRef := s.newInternalRef(0)
+ rootNode := s.getInternal(rootRef.Index())
+ rootNode.left = leftRef
+ rootNode.right = rightRef
+ s.SetRoot(rootRef)
- hash1 := node.Hash()
+ hash1 := s.ComputeHash(rootRef)
// Hash should be deterministic
- hash2 := node.Hash()
+ hash2 := s.ComputeHash(rootRef)
if hash1 != hash2 {
t.Errorf("Hash not deterministic: %x != %x", hash1, hash2)
}
// Changing a child should change the hash
- node.left = HashedNode(common.HexToHash("0x3333"))
- node.mustRecompute = true
- hash3 := node.Hash()
+ rootNode.left = s.newHashedRef(common.HexToHash("0x3333"))
+ rootNode.mustRecompute = true
+ hash3 := s.ComputeHash(rootRef)
if hash1 == hash3 {
t.Error("Hash didn't change after modifying left child")
}
-
- // Test with nil children (should use zero hash)
- nodeWithNil := &InternalNode{
- depth: 0,
- left: nil,
- right: HashedNode(common.HexToHash("0x4444")),
- mustRecompute: true,
- }
- hashWithNil := nodeWithNil.Hash()
- if hashWithNil == (common.Hash{}) {
- t.Error("Hash shouldn't be zero even with nil child")
- }
}
-// TestInternalNodeGetValuesAtStem tests GetValuesAtStem method
+// TestInternalNodeGetValuesAtStem tests GetValuesAtStem method via NodeStore.
func TestInternalNodeGetValuesAtStem(t *testing.T) {
- // Create a tree with values at different stems
+ s := NewNodeStore()
+
leftStem := make([]byte, 31)
rightStem := make([]byte, 31)
rightStem[0] = 0x80
- var leftValues, rightValues [256][]byte
+ leftValues := make([][]byte, 256)
leftValues[0] = common.HexToHash("0x0101").Bytes()
leftValues[10] = common.HexToHash("0x0102").Bytes()
+ rightValues := make([][]byte, 256)
rightValues[0] = common.HexToHash("0x0201").Bytes()
rightValues[20] = common.HexToHash("0x0202").Bytes()
- node := &InternalNode{
- depth: 0,
- left: &StemNode{
- Stem: leftStem,
- Values: leftValues[:],
- depth: 1,
- },
- right: &StemNode{
- Stem: rightStem,
- Values: rightValues[:],
- depth: 1,
- },
+ if err := s.InsertValuesAtStem(leftStem, leftValues, nil); err != nil {
+ t.Fatal(err)
+ }
+ if err := s.InsertValuesAtStem(rightStem, rightValues, nil); err != nil {
+ t.Fatal(err)
}
// Get values from left stem
- values, err := node.GetValuesAtStem(leftStem, nil)
+ values, err := s.GetValuesAtStem(leftStem, nil)
if err != nil {
t.Fatalf("Failed to get left values: %v", err)
}
@@ -298,7 +227,7 @@ func TestInternalNodeGetValuesAtStem(t *testing.T) {
}
// Get values from right stem
- values, err = node.GetValuesAtStem(rightStem, nil)
+ values, err = s.GetValuesAtStem(rightStem, nil)
if err != nil {
t.Fatalf("Failed to get right values: %v", err)
}
@@ -310,201 +239,103 @@ func TestInternalNodeGetValuesAtStem(t *testing.T) {
}
}
-// TestInternalNodeInsertValuesAtStem tests InsertValuesAtStem method
+// TestInternalNodeInsertValuesAtStem tests InsertValuesAtStem method via NodeStore.
func TestInternalNodeInsertValuesAtStem(t *testing.T) {
- // Start with an internal node with empty children
- node := &InternalNode{
- depth: 0,
- left: Empty{},
- right: Empty{},
- }
+ s := NewNodeStore()
- // Insert values at a stem in the left subtree
stem := make([]byte, 31)
- var values [256][]byte
+ values := make([][]byte, 256)
values[5] = common.HexToHash("0x0505").Bytes()
values[10] = common.HexToHash("0x1010").Bytes()
- newNode, err := node.InsertValuesAtStem(stem, values[:], nil, 0)
- if err != nil {
+ if err := s.InsertValuesAtStem(stem, values, nil); err != nil {
t.Fatalf("Failed to insert values: %v", err)
}
- internalNode, ok := newNode.(*InternalNode)
- if !ok {
- t.Fatalf("Expected InternalNode, got %T", newNode)
+ // Check that the values are stored
+ retrieved, err := s.GetValuesAtStem(stem, nil)
+ if err != nil {
+ t.Fatalf("Failed to get values: %v", err)
}
-
- // Check that left child is now a StemNode with the values
- leftStem, ok := internalNode.left.(*StemNode)
- if !ok {
- t.Fatalf("Expected left child to be StemNode, got %T", internalNode.left)
- }
-
- if !bytes.Equal(leftStem.Values[5], values[5]) {
+ if !bytes.Equal(retrieved[5], values[5]) {
t.Error("Value at index 5 mismatch")
}
- if !bytes.Equal(leftStem.Values[10], values[10]) {
+ if !bytes.Equal(retrieved[10], values[10]) {
t.Error("Value at index 10 mismatch")
}
}
-// TestInternalNodeCollectNodes tests CollectNodes method
+// TestInternalNodeCollectNodes tests CollectNodes method via NodeStore.
func TestInternalNodeCollectNodes(t *testing.T) {
- // Create an internal node with two stem children. All three are
- // marked dirty to mirror production semantics — see CollectNodes.
- leftStem := &StemNode{
- Stem: make([]byte, 31),
- Values: make([][]byte, 256),
- depth: 1,
- dirty: true,
- }
+ s := NewNodeStore()
- rightStem := &StemNode{
- Stem: make([]byte, 31),
- Values: make([][]byte, 256),
- depth: 1,
- dirty: true,
- }
- rightStem.Stem[0] = 0x80
+ leftStem := make([]byte, 31)
+ rightStem := make([]byte, 31)
+ rightStem[0] = 0x80
- node := &InternalNode{
- depth: 0,
- left: leftStem,
- right: rightStem,
- dirty: true,
+ leftValues := make([][]byte, 256)
+ rightValues := make([][]byte, 256)
+
+ if err := s.InsertValuesAtStem(leftStem, leftValues, nil); err != nil {
+ t.Fatal(err)
+ }
+ if err := s.InsertValuesAtStem(rightStem, rightValues, nil); err != nil {
+ t.Fatal(err)
}
var collectedPaths [][]byte
- var collectedNodes []BinaryNode
-
- flushFn := func(path []byte, n BinaryNode) {
+ flushFn := func(path []byte, hash common.Hash, serialized []byte) {
pathCopy := make([]byte, len(path))
copy(pathCopy, path)
collectedPaths = append(collectedPaths, pathCopy)
- collectedNodes = append(collectedNodes, n)
}
- err := node.CollectNodes([]byte{1}, flushFn)
+ err := s.CollectNodes(s.Root(), []byte{1}, flushFn, MaxGroupDepth)
if err != nil {
t.Fatalf("Failed to collect nodes: %v", err)
}
// Should have collected 3 nodes: left stem, right stem, and the internal node itself
- if len(collectedNodes) != 3 {
- t.Errorf("Expected 3 collected nodes, got %d", len(collectedNodes))
- }
-
- // Check paths
- expectedPaths := [][]byte{
- {1, 0}, // left child
- {1, 1}, // right child
- {1}, // internal node itself
- }
-
- for i, expectedPath := range expectedPaths {
- if !bytes.Equal(collectedPaths[i], expectedPath) {
- t.Errorf("Path %d mismatch: expected %v, got %v", i, expectedPath, collectedPaths[i])
- }
+ if len(collectedPaths) != 3 {
+ t.Errorf("Expected 3 collected nodes, got %d", len(collectedPaths))
}
}
-// TestInternalNodeCollectNodesSkipsClean verifies clean subtrees are not
-// flushed. A dirty internal node over clean children only flushes itself;
-// a fully clean tree flushes nothing.
-func TestInternalNodeCollectNodesSkipsClean(t *testing.T) {
- leftStem := &StemNode{
- Stem: make([]byte, 31),
- Values: make([][]byte, 256),
- depth: 1,
- }
- rightStem := &StemNode{
- Stem: make([]byte, 31),
- Values: make([][]byte, 256),
- depth: 1,
- }
- rightStem.Stem[0] = 0x80
-
- dirtyParent := &InternalNode{
- depth: 0,
- left: leftStem,
- right: rightStem,
- dirty: true,
- }
-
- var collected []BinaryNode
- flushFn := func(_ []byte, n BinaryNode) { collected = append(collected, n) }
-
- if err := dirtyParent.CollectNodes([]byte{1}, flushFn); err != nil {
- t.Fatalf("CollectNodes: %v", err)
- }
- if len(collected) != 1 || collected[0] != dirtyParent {
- t.Fatalf("expected only the dirty parent to be flushed, got %d nodes", len(collected))
- }
- if dirtyParent.dirty {
- t.Errorf("parent dirty flag should be cleared after flush")
- }
-
- // Second call on the same tree should be a no-op: everything is clean.
- collected = nil
- if err := dirtyParent.CollectNodes([]byte{1}, flushFn); err != nil {
- t.Fatalf("CollectNodes (second call): %v", err)
- }
- if len(collected) != 0 {
- t.Errorf("expected no nodes to be flushed on clean tree, got %d", len(collected))
- }
-}
-
-// TestInternalNodeGetHeight tests GetHeight method
+// TestInternalNodeGetHeight tests GetHeight method via NodeStore.
func TestInternalNodeGetHeight(t *testing.T) {
- // Create a tree with different heights
- // Left subtree: depth 2 (internal -> stem)
- // Right subtree: depth 1 (stem)
- leftInternal := &InternalNode{
- depth: 1,
- left: &StemNode{
- Stem: make([]byte, 31),
- Values: make([][]byte, 256),
- depth: 2,
- },
- right: Empty{},
+ s := NewNodeStore()
+
+ // Insert values that create a deeper tree
+ stem1 := make([]byte, 31) // left
+ stem2 := make([]byte, 31)
+ stem2[0] = 0x40 // 01... -> goes left at depth 0, right at depth 1
+
+ values1 := make([][]byte, 256)
+ values1[0] = common.HexToHash("0x01").Bytes()
+ values2 := make([][]byte, 256)
+ values2[0] = common.HexToHash("0x02").Bytes()
+
+ if err := s.InsertValuesAtStem(stem1, values1, nil); err != nil {
+ t.Fatal(err)
+ }
+ if err := s.InsertValuesAtStem(stem2, values2, nil); err != nil {
+ t.Fatal(err)
}
- rightStem := &StemNode{
- Stem: make([]byte, 31),
- Values: make([][]byte, 256),
- depth: 1,
- }
-
- node := &InternalNode{
- depth: 0,
- left: leftInternal,
- right: rightStem,
- }
-
- height := node.GetHeight()
- // Height should be max(left height, right height) + 1
- // Left height: 2, Right height: 1, so total: 3
- if height != 3 {
- t.Errorf("Expected height 3, got %d", height)
+ height := s.GetHeight(s.Root())
+ if height < 2 {
+ t.Errorf("Expected height >= 2, got %d", height)
}
}
-// TestInternalNodeDepthTooLarge tests handling of excessive depth
+// TestInternalNodeDepthTooLarge tests handling of excessive depth via NodeStore.
func TestInternalNodeDepthTooLarge(t *testing.T) {
- // Create an internal node at max depth
- node := &InternalNode{
- depth: 31*8 + 1,
- left: Empty{},
- right: Empty{},
- }
-
- stem := make([]byte, 31)
- _, err := node.GetValuesAtStem(stem, nil)
- if err == nil {
- t.Fatal("Expected error for excessive depth")
- }
- if err.Error() != "node too deep" {
- t.Errorf("Expected 'node too deep' error, got: %v", err)
- }
+ s := NewNodeStore()
+ // Creating an internal node beyond max depth should panic
+ defer func() {
+ if r := recover(); r == nil {
+ t.Fatal("Expected panic for excessive depth")
+ }
+ }()
+ s.newInternalRef(31*8 + 1)
}
diff --git a/trie/bintrie/iterator.go b/trie/bintrie/iterator.go
index 048d37f766..4fe927376f 100644
--- a/trie/bintrie/iterator.go
+++ b/trie/bintrie/iterator.go
@@ -26,13 +26,14 @@ import (
var errIteratorEnd = errors.New("end of iteration")
type binaryNodeIteratorState struct {
- Node BinaryNode
+ Node NodeRef
Index int
}
type binaryNodeIterator struct {
trie *BinaryTrie
- current BinaryNode
+ store *NodeStore
+ current NodeRef
lastErr error
stack []binaryNodeIteratorState
@@ -40,56 +41,59 @@ type binaryNodeIterator struct {
func newBinaryNodeIterator(t *BinaryTrie, _ []byte) (trie.NodeIterator, error) {
if t.Hash() == zero {
- return &binaryNodeIterator{trie: t, lastErr: errIteratorEnd}, nil
+ return &binaryNodeIterator{trie: t, store: t.store, lastErr: errIteratorEnd}, nil
}
- it := &binaryNodeIterator{trie: t, current: t.root}
- // it.err = it.seek(start)
+ it := &binaryNodeIterator{trie: t, store: t.store, current: t.store.Root()}
return it, nil
}
-// Next moves the iterator to the next node. If the parameter is false, any child
-// nodes will be skipped.
+// Next moves the iterator to the next node.
func (it *binaryNodeIterator) Next(descend bool) bool {
if it.lastErr == errIteratorEnd {
- it.lastErr = errIteratorEnd
return false
}
if len(it.stack) == 0 {
- it.stack = append(it.stack, binaryNodeIteratorState{Node: it.trie.root})
- it.current = it.trie.root
-
+ it.stack = append(it.stack, binaryNodeIteratorState{Node: it.trie.store.Root()})
+ it.current = it.trie.store.Root()
return true
}
- switch node := it.current.(type) {
- case *InternalNode:
- // index: 0 = nothing visited, 1=left visited, 2=right visited
+ switch it.current.Kind() {
+ case KindInternal:
+ node := it.store.getInternal(it.current.Index())
context := &it.stack[len(it.stack)-1]
- // recurse into both children
+ if !descend {
+ // Skip children: pop this node and advance parent
+ if len(it.stack) == 1 {
+ it.lastErr = errIteratorEnd
+ return false
+ }
+ it.stack = it.stack[:len(it.stack)-1]
+ it.current = it.stack[len(it.stack)-1].Node
+ it.stack[len(it.stack)-1].Index++
+ return it.Next(true)
+ }
+
if context.Index == 0 {
- if _, isempty := node.left.(Empty); node.left != nil && !isempty {
+ if !node.left.IsEmpty() {
it.stack = append(it.stack, binaryNodeIteratorState{Node: node.left})
it.current = node.left
return it.Next(descend)
}
-
context.Index++
}
if context.Index == 1 {
- if _, isempty := node.right.(Empty); node.right != nil && !isempty {
+ if !node.right.IsEmpty() {
it.stack = append(it.stack, binaryNodeIteratorState{Node: node.right})
it.current = node.right
return it.Next(descend)
}
-
context.Index++
}
- // Reached the end of this node, go back to the parent, if
- // this isn't root.
if len(it.stack) == 1 {
it.lastErr = errIteratorEnd
return false
@@ -98,17 +102,16 @@ func (it *binaryNodeIterator) Next(descend bool) bool {
it.current = it.stack[len(it.stack)-1].Node
it.stack[len(it.stack)-1].Index++
return it.Next(descend)
- case *StemNode:
- // Look for the next non-empty value
+
+ case KindStem:
+ sn := it.store.getStem(it.current.Index())
for i := it.stack[len(it.stack)-1].Index; i < 256; i++ {
- if node.Values[i] != nil {
+ if sn.hasValue(byte(i)) {
it.stack[len(it.stack)-1].Index = i + 1
return true
}
}
- // go back to parent to get the next leaf
- // Check if we're at the root before popping
if len(it.stack) == 1 {
it.lastErr = errIteratorEnd
return false
@@ -117,45 +120,39 @@ func (it *binaryNodeIterator) Next(descend bool) bool {
it.current = it.stack[len(it.stack)-1].Node
it.stack[len(it.stack)-1].Index++
return it.Next(descend)
- case HashedNode:
- // resolve the node
- resolverPath := it.Path()
- data, err := it.trie.nodeResolver(resolverPath, common.Hash(node))
- if err != nil {
- panic(err)
- }
- if data == nil {
- // Empty/nil node — treat as Empty, backtrack
- it.current = Empty{}
- it.stack[len(it.stack)-1].Node = it.current
- return it.Next(descend)
- }
- it.current, err = DeserializeNodeWithHash(data, len(it.stack)-1, common.Hash(node))
- if err != nil {
- panic(err)
- }
- // update the stack and parent with the resolved node
- it.stack[len(it.stack)-1].Node = it.current
- if len(it.stack) >= 2 {
- parent := &it.stack[len(it.stack)-2]
- if parent.Index == 0 {
- parent.Node.(*InternalNode).left = it.current
- } else {
- parent.Node.(*InternalNode).right = it.current
- }
- }
- return it.Next(descend)
- case Empty:
- // Empty node - go back to parent and continue
- if len(it.stack) <= 1 {
- it.lastErr = errIteratorEnd
+ case KindHashed:
+ if len(it.stack) < 2 {
+ it.lastErr = errors.New("cannot resolve hashed root during iteration")
return false
}
- it.stack = it.stack[:len(it.stack)-1]
- it.current = it.stack[len(it.stack)-1].Node
- it.stack[len(it.stack)-1].Index++
+ hn := it.store.getHashed(it.current.Index())
+ data, err := it.trie.nodeResolver(it.Path(), hn.hash)
+ if err != nil {
+ it.lastErr = err
+ return false
+ }
+ resolved, err := it.store.DeserializeNodeWithHash(data, len(it.stack)-1, hn.hash)
+ if err != nil {
+ it.lastErr = err
+ return false
+ }
+
+ // Update the stack and parent with the resolved node
+ it.current = resolved
+ it.stack[len(it.stack)-1].Node = resolved
+ parent := &it.stack[len(it.stack)-2]
+ parentNode := it.store.getInternal(parent.Node.Index())
+ if parent.Index == 0 {
+ parentNode.left = resolved
+ } else {
+ parentNode.right = resolved
+ }
return it.Next(descend)
+
+ case KindEmpty:
+ return false
+
default:
panic("invalid node type")
}
@@ -171,25 +168,21 @@ func (it *binaryNodeIterator) Error() error {
// Hash returns the hash of the current node.
func (it *binaryNodeIterator) Hash() common.Hash {
- return it.current.Hash()
+ return it.store.ComputeHash(it.current)
}
-// Parent returns the hash of the parent of the current node. The hash may be the one
-// grandparent if the immediate parent is an internal node with no hash.
+// Parent returns the hash of the parent of the current node.
func (it *binaryNodeIterator) Parent() common.Hash {
- return it.stack[len(it.stack)-1].Node.Hash()
+ return it.store.ComputeHash(it.stack[len(it.stack)-1].Node)
}
// Path returns the hex-encoded path to the current node.
-// Callers must not retain references to the return value after calling Next.
-// For leaf nodes, the last element of the path is the 'terminator symbol' 0x10.
func (it *binaryNodeIterator) Path() []byte {
if it.Leaf() {
return it.LeafKey()
}
var path []byte
for i, state := range it.stack {
- // skip the last byte
if i >= len(it.stack)-1 {
break
}
@@ -200,105 +193,78 @@ func (it *binaryNodeIterator) Path() []byte {
// NodeBlob returns the serialized bytes of the current node.
func (it *binaryNodeIterator) NodeBlob() []byte {
- return SerializeNode(it.current)
+ return it.store.SerializeNode(it.current, MaxGroupDepth)
}
// Leaf returns true iff the current node is a leaf node.
-// In a Binary Trie, a StemNode contains up to 256 leaf values.
-// The iterator is only considered to be "at a leaf" when it's positioned
-// at a specific non-nil value within the StemNode, not just at the StemNode itself.
func (it *binaryNodeIterator) Leaf() bool {
- sn, ok := it.current.(*StemNode)
- if !ok {
+ if it.current.Kind() != KindStem {
return false
}
- // Check if we have a valid stack position
if len(it.stack) == 0 {
return false
}
- // The Index in the stack state points to the NEXT position after the current value.
- // So if Index is 0, we haven't started iterating through the values yet.
- // If Index is 5, we're currently at value[4] (the 5th value, 0-indexed).
idx := it.stack[len(it.stack)-1].Index
if idx == 0 || idx > 256 {
return false
}
- // Check if there's actually a value at the current position
+ sn := it.store.getStem(it.current.Index())
currentValueIndex := idx - 1
- return sn.Values[currentValueIndex] != nil
+ return sn.hasValue(byte(currentValueIndex))
}
-// LeafKey returns the key of the leaf. The method panics if the iterator is not
-// positioned at a leaf. Callers must not retain references to the value after
-// calling Next.
+// LeafKey returns the key of the leaf.
func (it *binaryNodeIterator) LeafKey() []byte {
- leaf, ok := it.current.(*StemNode)
- if !ok {
+ if it.current.Kind() != KindStem {
panic("Leaf() called on an binary node iterator not at a leaf location")
}
- return leaf.Key(it.stack[len(it.stack)-1].Index - 1)
+ sn := it.store.getStem(it.current.Index())
+ return sn.Key(it.stack[len(it.stack)-1].Index - 1)
}
-// LeafBlob returns the content of the leaf. The method panics if the iterator
-// is not positioned at a leaf. Callers must not retain references to the value
-// after calling Next.
+// LeafBlob returns the content of the leaf.
func (it *binaryNodeIterator) LeafBlob() []byte {
- leaf, ok := it.current.(*StemNode)
- if !ok {
+ if it.current.Kind() != KindStem {
panic("LeafBlob() called on an binary node iterator not at a leaf location")
}
- return leaf.Values[it.stack[len(it.stack)-1].Index-1]
+ sn := it.store.getStem(it.current.Index())
+ return sn.getValue(byte(it.stack[len(it.stack)-1].Index - 1))
}
-// LeafProof returns the Merkle proof of the leaf. The method panics if the
-// iterator is not positioned at a leaf. Callers must not retain references
-// to the value after calling Next.
+// LeafProof returns the Merkle proof of the leaf.
func (it *binaryNodeIterator) LeafProof() [][]byte {
- sn, ok := it.current.(*StemNode)
- if !ok {
+ if it.current.Kind() != KindStem {
panic("LeafProof() called on an binary node iterator not at a leaf location")
}
+ sn := it.store.getStem(it.current.Index())
proof := make([][]byte, 0, len(it.stack)+StemNodeWidth)
- // Build proof by walking up the stack and collecting sibling hashes
for i := range it.stack[:len(it.stack)-2] {
state := it.stack[i]
- internalNode := state.Node.(*InternalNode) // should panic if the node isn't an InternalNode
+ internalNode := it.store.getInternal(state.Node.Index())
- // Add the sibling hash to the proof
if state.Index == 0 {
- // We came from left, so include right sibling
- proof = append(proof, internalNode.right.Hash().Bytes())
+ rh := it.store.ComputeHash(internalNode.right)
+ proof = append(proof, rh.Bytes())
} else {
- // We came from right, so include left sibling
- proof = append(proof, internalNode.left.Hash().Bytes())
+ lh := it.store.ComputeHash(internalNode.left)
+ proof = append(proof, lh.Bytes())
}
}
// Add the stem and siblings
- proof = append(proof, sn.Stem)
- for _, v := range sn.Values {
- proof = append(proof, v)
- }
+ proof = append(proof, sn.Stem[:])
+ proof = append(proof, sn.allValues()...)
return proof
}
// AddResolver sets an intermediate database to use for looking up trie nodes
// before reaching into the real persistent layer.
-//
-// This is not required for normal operation, rather is an optimization for
-// cases where trie nodes can be recovered from some external mechanism without
-// reading from disk. In those cases, this resolver allows short circuiting
-// accesses and returning them from memory.
-//
-// Before adding a similar mechanism to any other place in Geth, consider
-// making trie.Database an interface and wrapping at that level. It's a huge
-// refactor, but it could be worth it if another occurrence arises.
func (it *binaryNodeIterator) AddResolver(trie.NodeResolver) {
// Not implemented, but should not panic
}
diff --git a/trie/bintrie/iterator_test.go b/trie/bintrie/iterator_test.go
index 3e717c07ba..954e3b8400 100644
--- a/trie/bintrie/iterator_test.go
+++ b/trie/bintrie/iterator_test.go
@@ -27,14 +27,14 @@ import (
// makeTrie creates a BinaryTrie populated with the given key-value pairs.
func makeTrie(t *testing.T, entries [][2]common.Hash) *BinaryTrie {
t.Helper()
+ store := NewNodeStore()
tr := &BinaryTrie{
- root: NewBinaryNode(),
- tracer: trie.NewPrevalueTracer(),
+ store: store,
+ tracer: trie.NewPrevalueTracer(),
+ groupDepth: MaxGroupDepth,
}
for _, kv := range entries {
- var err error
- tr.root, err = tr.root.Insert(kv[0][:], kv[1][:], nil, 0)
- if err != nil {
+ if err := store.Insert(kv[0][:], kv[1][:], nil); err != nil {
t.Fatal(err)
}
}
@@ -64,8 +64,9 @@ func countLeaves(t *testing.T, tr *BinaryTrie) int {
// no nodes and reports no error.
func TestIteratorEmptyTrie(t *testing.T) {
tr := &BinaryTrie{
- root: Empty{},
- tracer: trie.NewPrevalueTracer(),
+ store: NewNodeStore(),
+ tracer: trie.NewPrevalueTracer(),
+ groupDepth: MaxGroupDepth,
}
it, err := newBinaryNodeIterator(tr, nil)
if err != nil {
@@ -145,8 +146,8 @@ func TestIteratorEmptyNodeBacktrack(t *testing.T) {
{common.HexToHash("8000000000000000000000000000000000000000000000000000000000000001"), oneKey},
})
- if _, ok := tr.root.(*InternalNode); !ok {
- t.Fatalf("expected InternalNode root, got %T", tr.root)
+ if tr.store.Root().Kind() != KindInternal {
+ t.Fatalf("expected InternalNode root, got kind %d", tr.store.Root().Kind())
}
if leaves := countLeaves(t, tr); leaves != 2 {
t.Fatalf("expected 2 leaves, got %d (Empty backtrack bug?)", leaves)
@@ -162,18 +163,31 @@ func TestIteratorHashedNodeNilData(t *testing.T) {
{common.HexToHash("8000000000000000000000000000000000000000000000000000000000000001"), oneKey},
})
- root, ok := tr.root.(*InternalNode)
- if !ok {
- t.Fatalf("expected InternalNode root, got %T", tr.root)
+ root := tr.store.Root()
+ if root.Kind() != KindInternal {
+ t.Fatalf("expected InternalNode root, got kind %d", root.Kind())
}
+ rootNode := tr.store.getInternal(root.Index())
// Replace right child with a zero-hash HashedNode. nodeResolver
// short-circuits on common.Hash{} and returns (nil, nil), which
// triggers the nil-data guard in the iterator.
- root.right = HashedNode(common.Hash{})
+ rootNode.right = tr.store.newHashedRef(common.Hash{})
// Should not panic; the zero-hash right child should be treated as Empty.
- if leaves := countLeaves(t, tr); leaves != 1 {
+ // Since the hashed node can't be resolved (nil data -> empty deserialization),
+ // only the left leaf should be counted.
+ it, err := newBinaryNodeIterator(tr, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ leaves := 0
+ for it.Next(true) {
+ if it.Leaf() {
+ leaves++
+ }
+ }
+ if leaves != 1 {
t.Fatalf("expected 1 leaf (zero-hash right node skipped), got %d", leaves)
}
}
diff --git a/trie/bintrie/node_ref.go b/trie/bintrie/node_ref.go
new file mode 100644
index 0000000000..da6d24a711
--- /dev/null
+++ b/trie/bintrie/node_ref.go
@@ -0,0 +1,57 @@
+// Copyright 2025 go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package bintrie
+
+// NodeKind identifies the type of a trie node stored in a NodeRef.
+type NodeKind uint8
+
+const (
+ KindEmpty NodeKind = iota // no node
+ KindInternal // internal binary branching node
+ KindStem // leaf group containing up to 256 values
+ KindHashed // unresolved node (hash only)
+ KindInvalid // sentinel for validation
+)
+
+// NodeRef is a compact, GC-invisible reference to a node in a NodeStore.
+// It packs a 2-bit type tag (bits 31-30) and a 30-bit index (bits 29-0)
+// into a single uint32. Because NodeRef contains no Go pointers, slices
+// of structs containing NodeRef fields are allocated in noscan spans —
+// the garbage collector never examines them.
+type NodeRef uint32
+
+const (
+ kindShift uint32 = 30
+ indexMask uint32 = (1 << kindShift) - 1
+
+ // EmptyRef is the zero-value NodeRef, representing an empty node.
+ EmptyRef NodeRef = 0
+)
+
+// MakeRef creates a NodeRef from a kind and pool index.
+func MakeRef(kind NodeKind, idx uint32) NodeRef {
+ return NodeRef(uint32(kind)<> kindShift) }
+
+// Index returns the pool index within the node's typed pool.
+func (r NodeRef) Index() uint32 { return uint32(r) & indexMask }
+
+// IsEmpty returns true if this ref represents an empty node.
+func (r NodeRef) IsEmpty() bool { return r == EmptyRef }
diff --git a/trie/bintrie/node_store.go b/trie/bintrie/node_store.go
new file mode 100644
index 0000000000..79670a4331
--- /dev/null
+++ b/trie/bintrie/node_store.go
@@ -0,0 +1,239 @@
+// Copyright 2025 go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package bintrie
+
+import "github.com/ethereum/go-ethereum/common"
+
+// storeChunkSize is the number of nodes per chunk in each typed pool.
+// Using fixed-size array chunks ensures that pointers to nodes within
+// existing chunks remain valid when new chunks are added (no reallocation
+// of the backing data, only the outer pointer slice grows).
+const storeChunkSize = 4096
+
+// NodeStore is a GC-friendly arena for binary trie nodes.
+//
+// Instead of allocating each node as a separate heap object with interface
+// pointers (which the GC must scan), NodeStore packs nodes into typed chunked
+// pools. InternalNode and HashedNode contain ZERO Go pointers, so their pool
+// backing arrays are allocated in noscan spans — the GC skips them entirely.
+// StemNode has one pointer (valueData []byte) per node.
+//
+// For a trie with 25K InternalNodes, this reduces GC-scanned pointer-words
+// from ~125K (with interface-based nodes) to ~25K (just StemNode valueData),
+// an ~80% reduction. At mainnet scale (millions of nodes), this prevents
+// multi-second GC pauses.
+type NodeStore struct {
+ // InternalNode pool — NOSCAN: InternalNode contains zero Go pointers.
+ // Children are NodeRef (uint32), hash is [32]byte.
+ internalChunks []*[storeChunkSize]InternalNode
+ internalCount uint32
+
+ // StemNode pool — each StemNode has one pointer (valueData []byte).
+ // Still much better than the old design where each InternalNode had
+ // two BinaryNode interface pointers (4 pointer-words each).
+ stemChunks []*[storeChunkSize]StemNode
+ stemCount uint32
+
+ // HashedNode pool — NOSCAN: HashedNode is just [32]byte.
+ hashedChunks []*[storeChunkSize]HashedNode
+ hashedCount uint32
+
+ root NodeRef
+
+ // Free lists for recycling deleted node slots.
+ freeInternals []uint32
+ freeStems []uint32
+ freeHashed []uint32
+}
+
+// NewNodeStore creates a new empty NodeStore.
+func NewNodeStore() *NodeStore {
+ return &NodeStore{root: EmptyRef}
+}
+
+// Root returns the root NodeRef.
+func (s *NodeStore) Root() NodeRef { return s.root }
+
+// SetRoot sets the root NodeRef.
+func (s *NodeStore) SetRoot(ref NodeRef) { s.root = ref }
+
+// --- InternalNode allocation ---
+
+func (s *NodeStore) allocInternal() uint32 {
+ if n := len(s.freeInternals); n > 0 {
+ idx := s.freeInternals[n-1]
+ s.freeInternals = s.freeInternals[:n-1]
+ *s.getInternal(idx) = InternalNode{}
+ return idx
+ }
+ idx := s.internalCount
+ chunkIdx := idx / storeChunkSize
+ if uint32(len(s.internalChunks)) <= chunkIdx {
+ s.internalChunks = append(s.internalChunks, new([storeChunkSize]InternalNode))
+ }
+ s.internalCount++
+ if s.internalCount > indexMask {
+ panic("internal node pool overflow")
+ }
+ return idx
+}
+
+func (s *NodeStore) getInternal(idx uint32) *InternalNode {
+ return &s.internalChunks[idx/storeChunkSize][idx%storeChunkSize]
+}
+
+// newInternalRef allocates an InternalNode and returns its NodeRef.
+func (s *NodeStore) newInternalRef(depth int) NodeRef {
+ if depth > 248 {
+ panic("node depth exceeds maximum binary trie depth")
+ }
+ idx := s.allocInternal()
+ n := s.getInternal(idx)
+ n.depth = uint8(depth)
+ n.mustRecompute = true
+ return MakeRef(KindInternal, idx)
+}
+
+// --- StemNode allocation ---
+
+func (s *NodeStore) allocStem() uint32 {
+ if n := len(s.freeStems); n > 0 {
+ idx := s.freeStems[n-1]
+ s.freeStems = s.freeStems[:n-1]
+ *s.getStem(idx) = StemNode{}
+ return idx
+ }
+ idx := s.stemCount
+ chunkIdx := idx / storeChunkSize
+ if uint32(len(s.stemChunks)) <= chunkIdx {
+ s.stemChunks = append(s.stemChunks, new([storeChunkSize]StemNode))
+ }
+ s.stemCount++
+ if s.stemCount > indexMask {
+ panic("internal node pool overflow")
+ }
+ return idx
+}
+
+func (s *NodeStore) getStem(idx uint32) *StemNode {
+ return &s.stemChunks[idx/storeChunkSize][idx%storeChunkSize]
+}
+
+// newStemRef allocates a StemNode with the given stem/depth and returns its NodeRef.
+func (s *NodeStore) newStemRef(stem []byte, depth int) NodeRef {
+ if depth > 248 {
+ panic("node depth exceeds maximum binary trie depth")
+ }
+ idx := s.allocStem()
+ sn := s.getStem(idx)
+ copy(sn.Stem[:], stem[:StemSize])
+ sn.depth = uint8(depth)
+ sn.mustRecompute = true
+ return MakeRef(KindStem, idx)
+}
+
+// --- HashedNode allocation ---
+
+func (s *NodeStore) allocHashed() uint32 {
+ if n := len(s.freeHashed); n > 0 {
+ idx := s.freeHashed[n-1]
+ s.freeHashed = s.freeHashed[:n-1]
+ *s.getHashed(idx) = HashedNode{}
+ return idx
+ }
+ idx := s.hashedCount
+ chunkIdx := idx / storeChunkSize
+ if uint32(len(s.hashedChunks)) <= chunkIdx {
+ s.hashedChunks = append(s.hashedChunks, new([storeChunkSize]HashedNode))
+ }
+ s.hashedCount++
+ if s.hashedCount > indexMask {
+ panic("internal node pool overflow")
+ }
+ return idx
+}
+
+func (s *NodeStore) getHashed(idx uint32) *HashedNode {
+ return &s.hashedChunks[idx/storeChunkSize][idx%storeChunkSize]
+}
+
+func (s *NodeStore) freeHashedNode(idx uint32) {
+ s.freeHashed = append(s.freeHashed, idx)
+}
+
+// newHashedRef allocates a HashedNode and returns its NodeRef.
+func (s *NodeStore) newHashedRef(hash common.Hash) NodeRef {
+ idx := s.allocHashed()
+ *s.getHashed(idx) = HashedNode{hash: hash}
+ return MakeRef(KindHashed, idx)
+}
+
+// Copy creates a deep copy of the NodeStore and all its nodes.
+func (s *NodeStore) Copy() *NodeStore {
+ ns := &NodeStore{
+ root: s.root,
+ internalCount: s.internalCount,
+ stemCount: s.stemCount,
+ hashedCount: s.hashedCount,
+ }
+
+ // Deep copy internal chunks
+ ns.internalChunks = make([]*[storeChunkSize]InternalNode, len(s.internalChunks))
+ for i, chunk := range s.internalChunks {
+ cp := *chunk
+ ns.internalChunks[i] = &cp
+ }
+
+ // Deep copy stem chunks (need to deep copy valueData)
+ ns.stemChunks = make([]*[storeChunkSize]StemNode, len(s.stemChunks))
+ for i, chunk := range s.stemChunks {
+ cp := *chunk
+ ns.stemChunks[i] = &cp
+ }
+ // Deep copy pointer fields for each active stem
+ for i := uint32(0); i < s.stemCount; i++ {
+ src := s.getStem(i)
+ dst := ns.getStem(i)
+ if len(src.valueData) > 0 {
+ dst.valueData = make([]byte, len(src.valueData))
+ copy(dst.valueData, src.valueData)
+ }
+ }
+
+ // Deep copy hashed chunks
+ ns.hashedChunks = make([]*[storeChunkSize]HashedNode, len(s.hashedChunks))
+ for i, chunk := range s.hashedChunks {
+ cp := *chunk
+ ns.hashedChunks[i] = &cp
+ }
+
+ // Copy free lists
+ if len(s.freeInternals) > 0 {
+ ns.freeInternals = make([]uint32, len(s.freeInternals))
+ copy(ns.freeInternals, s.freeInternals)
+ }
+ if len(s.freeStems) > 0 {
+ ns.freeStems = make([]uint32, len(s.freeStems))
+ copy(ns.freeStems, s.freeStems)
+ }
+ if len(s.freeHashed) > 0 {
+ ns.freeHashed = make([]uint32, len(s.freeHashed))
+ copy(ns.freeHashed, s.freeHashed)
+ }
+
+ return ns
+}
diff --git a/trie/bintrie/stem_node.go b/trie/bintrie/stem_node.go
index 0ceae6b062..9f1ac6d2e0 100644
--- a/trie/bintrie/stem_node.go
+++ b/trie/bintrie/stem_node.go
@@ -17,119 +17,135 @@
package bintrie
import (
- "bytes"
- "errors"
- "fmt"
- "slices"
+ "math/bits"
"github.com/ethereum/go-ethereum/common"
)
-// StemNode represents a group of `NodeWith` values sharing the same stem.
+// StemNode represents a group of `StemNodeWidth` values sharing the same stem.
+// It uses a packed representation: bitmap indicates which of the 256 positions
+// have values, and valueData stores the values contiguously in bitmap order.
type StemNode struct {
- Stem []byte // Stem path to get to StemNodeWidth values
- Values [][]byte // All values, indexed by the last byte of the key.
- depth int // Depth of the node
+ Stem [StemSize]byte // Stem path to get to StemNodeWidth values
+ bitmap [StemBitmapSize]byte // bitmap indicating which positions have values
+ valueData []byte // packed value data (count * HashSize bytes)
+ count uint16 // number of values present
+ depth uint8 // Depth of the node
+ shared bool // true if valueData is shared with serialized input
mustRecompute bool // true if the hash needs to be recomputed
- dirty bool // true if the node's on-disk blob is stale (needs flush)
hash common.Hash // cached hash when mustRecompute == false
}
-// Get retrieves the value for the given key.
-func (bt *StemNode) Get(key []byte, _ NodeResolverFn) ([]byte, error) {
- if !bytes.Equal(bt.Stem, key[:StemSize]) {
- return nil, nil
+// posInData returns the index within valueData for the given suffix.
+// Returns -1 if the suffix is not present.
+func (sn *StemNode) posInData(suffix byte) int {
+ idx := int(suffix)
+ if sn.bitmap[idx/8]>>(7-(idx%8))&1 == 0 {
+ return -1
}
- return bt.Values[key[StemSize]], nil
+ // Count the bits set before this position to determine the offset
+ pos := 0
+ byteIdx := idx / 8
+ for i := 0; i < byteIdx; i++ {
+ pos += bits.OnesCount8(sn.bitmap[i])
+ }
+ // Count bits in the partial byte
+ mask := byte(0xFF) << (8 - (idx % 8))
+ pos += bits.OnesCount8(sn.bitmap[byteIdx] & mask)
+ return pos
}
-// Insert inserts a new key-value pair into the node.
-func (bt *StemNode) Insert(key []byte, value []byte, _ NodeResolverFn, depth int) (BinaryNode, error) {
- if !bytes.Equal(bt.Stem, key[:StemSize]) {
- bitStem := bt.Stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1
+// getValue returns the value at the given suffix, or nil if not present.
+func (sn *StemNode) getValue(suffix byte) []byte {
+ pos := sn.posInData(suffix)
+ if pos < 0 {
+ return nil
+ }
+ start := pos * HashSize
+ return sn.valueData[start : start+HashSize]
+}
- n := &InternalNode{depth: bt.depth, mustRecompute: true, dirty: true}
- bt.depth++
- // bt is re-parented under n and sits at a new path — rewrite its blob.
- bt.mustRecompute = true
- bt.dirty = true
- var child, other *BinaryNode
- if bitStem == 0 {
- n.left = bt
- child = &n.left
- other = &n.right
- } else {
- n.right = bt
- child = &n.right
- other = &n.left
+// hasValue returns true if the given suffix has a value.
+func (sn *StemNode) hasValue(suffix byte) bool {
+ idx := int(suffix)
+ return sn.bitmap[idx/8]>>(7-(idx%8))&1 == 1
+}
+
+// allValues returns all 256 values (nil for absent positions).
+func (sn *StemNode) allValues() [][]byte {
+ values := make([][]byte, StemNodeWidth)
+ dataIdx := 0
+ for i := range StemNodeWidth {
+ if sn.bitmap[i/8]>>(7-(i%8))&1 == 1 {
+ values[i] = sn.valueData[dataIdx*HashSize : (dataIdx+1)*HashSize]
+ dataIdx++
}
-
- bitKey := key[n.depth/8] >> (7 - (n.depth % 8)) & 1
- if bitKey == bitStem {
- var err error
- *child, err = (*child).Insert(key, value, nil, depth+1)
- if err != nil {
- return n, fmt.Errorf("insert error: %w", err)
- }
- *other = Empty{}
- } else {
- var values [StemNodeWidth][]byte
- values[key[StemSize]] = value
- *other = &StemNode{
- Stem: slices.Clone(key[:StemSize]),
- Values: values[:],
- depth: depth + 1,
- mustRecompute: true,
- dirty: true,
- }
- }
- return n, nil
}
- if len(value) != HashSize {
- return bt, errors.New("invalid insertion: value length")
- }
- bt.Values[key[StemSize]] = value
- bt.mustRecompute = true
- bt.dirty = true
- return bt, nil
+ return values
}
-// Copy creates a deep copy of the node.
-func (bt *StemNode) Copy() BinaryNode {
- var values [StemNodeWidth][]byte
- for i, v := range bt.Values {
- values[i] = slices.Clone(v)
- }
- return &StemNode{
- Stem: slices.Clone(bt.Stem),
- Values: values[:],
- depth: bt.depth,
- hash: bt.hash,
- mustRecompute: bt.mustRecompute,
- dirty: bt.dirty,
+// ensureWritable makes the valueData writable (copies if shared with serialized input).
+func (sn *StemNode) ensureWritable() {
+ if sn.shared || cap(sn.valueData)-len(sn.valueData) < HashSize {
+ newData := make([]byte, len(sn.valueData), len(sn.valueData)+HashSize*4)
+ copy(newData, sn.valueData)
+ sn.valueData = newData
+ sn.shared = false
}
}
-// GetHeight returns the height of the node.
-func (bt *StemNode) GetHeight() int {
- return 1
+// setValue sets or inserts a value at the given suffix.
+func (sn *StemNode) setValue(suffix byte, value []byte) {
+ idx := int(suffix)
+ pos := sn.posInData(suffix)
+ if pos >= 0 {
+ // Overwrite existing value
+ copy(sn.valueData[pos*HashSize:], value[:HashSize])
+ return
+ }
+ // New value: insert into bitmap and valueData at the correct position.
+ sn.bitmap[idx/8] |= 1 << (7 - (idx % 8))
+ sn.count++
+
+ // Find the correct position in valueData (count bits before this position).
+ insertPos := 0
+ byteIdx := idx / 8
+ for i := 0; i < byteIdx; i++ {
+ insertPos += bits.OnesCount8(sn.bitmap[i])
+ }
+ mask := byte(0xFF) << (8 - (idx % 8))
+ insertPos += bits.OnesCount8(sn.bitmap[byteIdx] & mask)
+
+ // Insert value at the correct position in valueData.
+ insertOffset := insertPos * HashSize
+ // Grow the slice
+ sn.valueData = append(sn.valueData, make([]byte, HashSize)...)
+ // Shift data after insertion point
+ copy(sn.valueData[insertOffset+HashSize:], sn.valueData[insertOffset:len(sn.valueData)-HashSize])
+ // Copy the new value
+ copy(sn.valueData[insertOffset:], value[:HashSize])
}
// Hash returns the hash of the node.
-func (bt *StemNode) Hash() common.Hash {
- if !bt.mustRecompute {
- return bt.hash
+func (sn *StemNode) Hash() common.Hash {
+ if !sn.mustRecompute {
+ return sn.hash
}
var data [StemNodeWidth]common.Hash
h := newSha256()
defer returnSha256(h)
- for i, v := range bt.Values {
- if v != nil {
+
+ // Hash each present value
+ dataIdx := 0
+ for i := range StemNodeWidth {
+ if sn.bitmap[i/8]>>(7-(i%8))&1 == 1 {
+ v := sn.valueData[dataIdx*HashSize : (dataIdx+1)*HashSize]
h.Reset()
h.Write(v)
h.Sum(data[i][:0])
+ dataIdx++
}
}
h.Reset()
@@ -150,103 +166,18 @@ func (bt *StemNode) Hash() common.Hash {
}
h.Reset()
- h.Write(bt.Stem)
+ h.Write(sn.Stem[:])
h.Write([]byte{0})
h.Write(data[0][:])
- bt.hash = common.BytesToHash(h.Sum(nil))
- bt.mustRecompute = false
- return bt.hash
-}
-
-// CollectNodes flushes the stem via the collector when dirty; clean stems
-// are skipped.
-func (bt *StemNode) CollectNodes(path []byte, flush NodeFlushFn) error {
- if !bt.dirty {
- return nil
- }
- flush(path, bt)
- bt.dirty = false
- return nil
-}
-
-// GetValuesAtStem retrieves the group of values located at the given stem key.
-func (bt *StemNode) GetValuesAtStem(stem []byte, _ NodeResolverFn) ([][]byte, error) {
- if !bytes.Equal(bt.Stem, stem) {
- return nil, nil
- }
- return bt.Values[:], nil
-}
-
-// InsertValuesAtStem inserts a full value group at the given stem in the internal node.
-// Already-existing values will be overwritten.
-func (bt *StemNode) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolverFn, depth int) (BinaryNode, error) {
- if !bytes.Equal(bt.Stem, key[:StemSize]) {
- bitStem := bt.Stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1
-
- n := &InternalNode{depth: bt.depth, mustRecompute: true, dirty: true}
- bt.depth++
- // bt is re-parented under n and sits at a new path — rewrite its blob.
- bt.mustRecompute = true
- bt.dirty = true
- var child, other *BinaryNode
- if bitStem == 0 {
- n.left = bt
- child = &n.left
- other = &n.right
- } else {
- n.right = bt
- child = &n.right
- other = &n.left
- }
-
- bitKey := key[n.depth/8] >> (7 - (n.depth % 8)) & 1
- if bitKey == bitStem {
- var err error
- *child, err = (*child).InsertValuesAtStem(key, values, nil, depth+1)
- if err != nil {
- return n, fmt.Errorf("insert error: %w", err)
- }
- *other = Empty{}
- } else {
- *other = &StemNode{
- Stem: slices.Clone(key[:StemSize]),
- Values: values,
- depth: n.depth + 1,
- mustRecompute: true,
- dirty: true,
- }
- }
- return n, nil
- }
-
- // same stem, just merge the two value lists
- for i, v := range values {
- if v != nil {
- bt.Values[i] = v
- bt.mustRecompute = true
- bt.dirty = true
- }
- }
- return bt, nil
-}
-
-func (bt *StemNode) toDot(parent, path string) string {
- me := fmt.Sprintf("stem%s", path)
- ret := fmt.Sprintf("%s [label=\"stem=%x c=%x\"]\n", me, bt.Stem, bt.Hash())
- ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me)
- for i, v := range bt.Values {
- if v != nil {
- ret = fmt.Sprintf("%s%s%x [label=\"%x\"]\n", ret, me, i, v)
- ret = fmt.Sprintf("%s%s -> %s%x\n", ret, me, me, i)
- }
- }
- return ret
+ sn.hash = common.BytesToHash(h.Sum(nil))
+ sn.mustRecompute = false
+ return sn.hash
}
// Key returns the full key for the given index.
-func (bt *StemNode) Key(i int) []byte {
+func (sn *StemNode) Key(i int) []byte {
var ret [HashSize]byte
- copy(ret[:], bt.Stem)
+ copy(ret[:], sn.Stem[:])
ret[StemSize] = byte(i)
return ret[:]
}
diff --git a/trie/bintrie/stem_node_test.go b/trie/bintrie/stem_node_test.go
index 2743e7ce9b..9ad77ff438 100644
--- a/trie/bintrie/stem_node_test.go
+++ b/trie/bintrie/stem_node_test.go
@@ -23,165 +23,99 @@ import (
"github.com/ethereum/go-ethereum/common"
)
-// TestStemNodeGet tests the Get method for matching stem, non-matching stem,
-// and nil-value suffix scenarios.
-func TestStemNodeGet(t *testing.T) {
- stem := make([]byte, StemSize)
- stem[0] = 0xAB
- var values [StemNodeWidth][]byte
- values[5] = common.HexToHash("0xdeadbeef").Bytes()
-
- node := &StemNode{Stem: stem, Values: values[:], depth: 0}
-
- // Matching stem, populated suffix → returns value.
- key := make([]byte, HashSize)
- copy(key[:StemSize], stem)
- key[StemSize] = 5
- got, err := node.Get(key, nil)
- if err != nil {
- t.Fatalf("Get error: %v", err)
- }
- if !bytes.Equal(got, values[5]) {
- t.Fatalf("Get = %x, want %x", got, values[5])
- }
-
- // Matching stem, empty suffix → returns nil (slot not set).
- key[StemSize] = 99
- got, err = node.Get(key, nil)
- if err != nil {
- t.Fatalf("Get error: %v", err)
- }
- if got != nil {
- t.Fatalf("Get(empty suffix) = %x, want nil", got)
- }
-
- // Non-matching stem → returns nil, nil.
- otherKey := make([]byte, HashSize)
- otherKey[0] = 0xFF
- got, err = node.Get(otherKey, nil)
- if err != nil {
- t.Fatalf("Get error: %v", err)
- }
- if got != nil {
- t.Fatalf("Get(wrong stem) = %x, want nil", got)
- }
-}
-
-// TestStemNodeInsertSameStem tests inserting values with the same stem
+// TestStemNodeInsertSameStem tests inserting values with the same stem via NodeStore.
func TestStemNodeInsertSameStem(t *testing.T) {
+ s := NewNodeStore()
+
stem := make([]byte, 31)
for i := range stem {
stem[i] = byte(i)
}
- var values [256][]byte
- values[0] = common.HexToHash("0x0101").Bytes()
-
- node := &StemNode{
- Stem: stem,
- Values: values[:],
- depth: 0,
+ // Insert first value
+ key1 := make([]byte, 32)
+ copy(key1[:31], stem)
+ key1[31] = 0
+ value1 := common.HexToHash("0x0101").Bytes()
+ if err := s.Insert(key1, value1, nil); err != nil {
+ t.Fatal(err)
}
// Insert another value with the same stem but different last byte
- key := make([]byte, 32)
- copy(key[:31], stem)
- key[31] = 10
- value := common.HexToHash("0x0202").Bytes()
-
- newNode, err := node.Insert(key, value, nil, 0)
- if err != nil {
- t.Fatalf("Failed to insert: %v", err)
+ key2 := make([]byte, 32)
+ copy(key2[:31], stem)
+ key2[31] = 10
+ value2 := common.HexToHash("0x0202").Bytes()
+ if err := s.Insert(key2, value2, nil); err != nil {
+ t.Fatal(err)
}
- // Should still be a StemNode
- stemNode, ok := newNode.(*StemNode)
- if !ok {
- t.Fatalf("Expected StemNode, got %T", newNode)
+ // Root should still be a StemNode
+ if s.Root().Kind() != KindStem {
+ t.Fatalf("Expected KindStem root, got kind %d", s.Root().Kind())
}
// Check that both values are present
- if !bytes.Equal(stemNode.Values[0], values[0]) {
+ v1, _ := s.Get(key1, nil)
+ if !bytes.Equal(v1, value1) {
t.Errorf("Value at index 0 mismatch")
}
- if !bytes.Equal(stemNode.Values[10], value) {
+ v2, _ := s.Get(key2, nil)
+ if !bytes.Equal(v2, value2) {
t.Errorf("Value at index 10 mismatch")
}
}
-// TestStemNodeInsertDifferentStem tests inserting values with different stems
+// TestStemNodeInsertDifferentStem tests inserting values with different stems via NodeStore.
func TestStemNodeInsertDifferentStem(t *testing.T) {
- stem1 := make([]byte, 31)
- for i := range stem1 {
- stem1[i] = 0x00
- }
+ s := NewNodeStore()
- var values [256][]byte
- values[0] = common.HexToHash("0x0101").Bytes()
-
- node := &StemNode{
- Stem: stem1,
- Values: values[:],
- depth: 0,
+ // Insert first value with stem of all zeros
+ key1 := make([]byte, 32)
+ key1[31] = 0
+ value1 := common.HexToHash("0x0101").Bytes()
+ if err := s.Insert(key1, value1, nil); err != nil {
+ t.Fatal(err)
}
// Insert with a different stem (first bit different)
- key := make([]byte, 32)
- key[0] = 0x80 // First bit is 1 instead of 0
- value := common.HexToHash("0x0202").Bytes()
-
- newNode, err := node.Insert(key, value, nil, 0)
- if err != nil {
- t.Fatalf("Failed to insert: %v", err)
+ key2 := make([]byte, 32)
+ key2[0] = 0x80 // First bit is 1 instead of 0
+ value2 := common.HexToHash("0x0202").Bytes()
+ if err := s.Insert(key2, value2, nil); err != nil {
+ t.Fatal(err)
}
// Should now be an InternalNode
- internalNode, ok := newNode.(*InternalNode)
- if !ok {
- t.Fatalf("Expected InternalNode, got %T", newNode)
+ if s.Root().Kind() != KindInternal {
+ t.Fatalf("Expected KindInternal root, got kind %d", s.Root().Kind())
}
// Check depth
- if internalNode.depth != 0 {
- t.Errorf("Expected depth 0, got %d", internalNode.depth)
+ rootNode := s.getInternal(s.Root().Index())
+ if rootNode.depth != 0 {
+ t.Errorf("Expected depth 0, got %d", rootNode.depth)
}
- // Original stem should be on the left (bit 0)
- leftStem, ok := internalNode.left.(*StemNode)
- if !ok {
- t.Fatalf("Expected left child to be StemNode, got %T", internalNode.left)
+ // Verify both values are retrievable
+ v1, _ := s.Get(key1, nil)
+ if !bytes.Equal(v1, value1) {
+ t.Error("Value 1 mismatch")
}
- if !bytes.Equal(leftStem.Stem, stem1) {
- t.Errorf("Left stem mismatch")
- }
-
- // New stem should be on the right (bit 1)
- rightStem, ok := internalNode.right.(*StemNode)
- if !ok {
- t.Fatalf("Expected right child to be StemNode, got %T", internalNode.right)
- }
- if !bytes.Equal(rightStem.Stem, key[:31]) {
- t.Errorf("Right stem mismatch")
+ v2, _ := s.Get(key2, nil)
+ if !bytes.Equal(v2, value2) {
+ t.Error("Value 2 mismatch")
}
}
-// TestStemNodeInsertInvalidValueLength tests inserting value with invalid length
+// TestStemNodeInsertInvalidValueLength tests inserting value with invalid length via NodeStore.
func TestStemNodeInsertInvalidValueLength(t *testing.T) {
- stem := make([]byte, 31)
- var values [256][]byte
+ s := NewNodeStore()
- node := &StemNode{
- Stem: stem,
- Values: values[:],
- depth: 0,
- }
-
- // Try to insert value with wrong length
key := make([]byte, 32)
- copy(key[:31], stem)
invalidValue := []byte{1, 2, 3} // Not 32 bytes
- _, err := node.Insert(key, invalidValue, nil, 0)
+ err := s.Insert(key, invalidValue, nil)
if err == nil {
t.Fatal("Expected error for invalid value length")
}
@@ -191,221 +125,209 @@ func TestStemNodeInsertInvalidValueLength(t *testing.T) {
}
}
-// TestStemNodeCopy tests the Copy method
+// TestStemNodeCopy tests the Copy method via NodeStore.
func TestStemNodeCopy(t *testing.T) {
- stem := make([]byte, 31)
- for i := range stem {
- stem[i] = byte(i)
+ s := NewNodeStore()
+
+ key1 := make([]byte, 32)
+ for i := range 31 {
+ key1[i] = byte(i)
+ }
+ key1[31] = 0
+ value1 := common.HexToHash("0x0101").Bytes()
+
+ key2 := make([]byte, 32)
+ copy(key2[:31], key1[:31])
+ key2[31] = 255
+ value2 := common.HexToHash("0x0202").Bytes()
+
+ if err := s.Insert(key1, value1, nil); err != nil {
+ t.Fatal(err)
+ }
+ if err := s.Insert(key2, value2, nil); err != nil {
+ t.Fatal(err)
}
- var values [256][]byte
- values[0] = common.HexToHash("0x0101").Bytes()
- values[255] = common.HexToHash("0x0202").Bytes()
+ ns := s.Copy()
- node := &StemNode{
- Stem: stem,
- Values: values[:],
- depth: 10,
- }
-
- // Create a copy
- copied := node.Copy()
- copiedStem, ok := copied.(*StemNode)
- if !ok {
- t.Fatalf("Expected StemNode, got %T", copied)
- }
-
- // Check that values are equal but not the same slice
- if !bytes.Equal(copiedStem.Stem, node.Stem) {
- t.Errorf("Stem mismatch after copy")
- }
- if &copiedStem.Stem[0] == &node.Stem[0] {
- t.Error("Stem slice not properly cloned")
- }
-
- // Check values
- if !bytes.Equal(copiedStem.Values[0], node.Values[0]) {
+ // Check that values are equal
+ v1, _ := ns.Get(key1, nil)
+ if !bytes.Equal(v1, value1) {
t.Errorf("Value at index 0 mismatch after copy")
}
- if !bytes.Equal(copiedStem.Values[255], node.Values[255]) {
+ v2, _ := ns.Get(key2, nil)
+ if !bytes.Equal(v2, value2) {
t.Errorf("Value at index 255 mismatch after copy")
}
-
- // Check that value slices are cloned
- if copiedStem.Values[0] != nil && &copiedStem.Values[0][0] == &node.Values[0][0] {
- t.Error("Value slice not properly cloned")
- }
-
- // Check depth
- if copiedStem.depth != node.depth {
- t.Errorf("Depth mismatch: expected %d, got %d", node.depth, copiedStem.depth)
- }
}
-// TestStemNodeHash tests the Hash method
+// TestStemNodeHash tests the Hash method.
func TestStemNodeHash(t *testing.T) {
- stem := make([]byte, 31)
- var values [256][]byte
- values[0] = common.HexToHash("0x0101").Bytes()
+ s := NewNodeStore()
- node := &StemNode{
- Stem: stem,
- Values: values[:],
- depth: 0,
+ key := make([]byte, 32)
+ key[31] = 0
+ value := common.HexToHash("0x0101").Bytes()
+ if err := s.Insert(key, value, nil); err != nil {
+ t.Fatal(err)
}
- hash1 := node.Hash()
+ hash1 := s.ComputeHash(s.Root())
// Hash should be deterministic
- hash2 := node.Hash()
+ hash2 := s.ComputeHash(s.Root())
if hash1 != hash2 {
t.Errorf("Hash not deterministic: %x != %x", hash1, hash2)
}
// Changing a value should change the hash
- node.Values[1] = common.HexToHash("0x0202").Bytes()
- node.mustRecompute = true
- hash3 := node.Hash()
+ key2 := make([]byte, 32)
+ key2[31] = 1
+ value2 := common.HexToHash("0x0202").Bytes()
+ if err := s.Insert(key2, value2, nil); err != nil {
+ t.Fatal(err)
+ }
+ hash3 := s.ComputeHash(s.Root())
if hash1 == hash3 {
t.Error("Hash didn't change after modifying values")
}
}
-// TestStemNodeGetValuesAtStem tests GetValuesAtStem method
+// TestStemNodeGetValuesAtStem tests GetValuesAtStem method via NodeStore.
func TestStemNodeGetValuesAtStem(t *testing.T) {
+ s := NewNodeStore()
+
stem := make([]byte, 31)
for i := range stem {
stem[i] = byte(i)
}
- var values [256][]byte
+ values := make([][]byte, 256)
values[0] = common.HexToHash("0x0101").Bytes()
values[10] = common.HexToHash("0x0202").Bytes()
values[255] = common.HexToHash("0x0303").Bytes()
- node := &StemNode{
- Stem: stem,
- Values: values[:],
- depth: 0,
+ if err := s.InsertValuesAtStem(stem, values, nil); err != nil {
+ t.Fatal(err)
}
// GetValuesAtStem with matching stem
- retrievedValues, err := node.GetValuesAtStem(stem, nil)
+ retrievedValues, err := s.GetValuesAtStem(stem, nil)
if err != nil {
t.Fatalf("Failed to get values: %v", err)
}
- // Check that all values match
- for i := range 256 {
- if !bytes.Equal(retrievedValues[i], values[i]) {
- t.Errorf("Value mismatch at index %d", i)
- }
+ if !bytes.Equal(retrievedValues[0], values[0]) {
+ t.Error("Value at index 0 mismatch")
+ }
+ if !bytes.Equal(retrievedValues[10], values[10]) {
+ t.Error("Value at index 10 mismatch")
+ }
+ if !bytes.Equal(retrievedValues[255], values[255]) {
+ t.Error("Value at index 255 mismatch")
}
- // GetValuesAtStem with different stem should return nil
+ // GetValuesAtStem with different stem should return nil values
differentStem := make([]byte, 31)
differentStem[0] = 0xFF
- shouldBeNil, err := node.GetValuesAtStem(differentStem, nil)
+ shouldBeEmpty, err := s.GetValuesAtStem(differentStem, nil)
if err != nil {
t.Fatalf("Failed to get values with different stem: %v", err)
}
- if shouldBeNil != nil {
- t.Error("Expected nil for different stem, got non-nil")
+ allNil := true
+ for _, v := range shouldBeEmpty {
+ if v != nil {
+ allNil = false
+ break
+ }
+ }
+ if !allNil {
+ t.Error("Expected all nil values for different stem")
}
}
-// TestStemNodeInsertValuesAtStem tests InsertValuesAtStem method
+// TestStemNodeInsertValuesAtStem tests InsertValuesAtStem method via NodeStore.
func TestStemNodeInsertValuesAtStem(t *testing.T) {
+ s := NewNodeStore()
+
stem := make([]byte, 31)
- var values [256][]byte
+ values := make([][]byte, 256)
values[0] = common.HexToHash("0x0101").Bytes()
- node := &StemNode{
- Stem: stem,
- Values: values[:],
- depth: 0,
+ if err := s.InsertValuesAtStem(stem, values, nil); err != nil {
+ t.Fatal(err)
}
// Insert new values at the same stem
- var newValues [256][]byte
+ newValues := make([][]byte, 256)
newValues[1] = common.HexToHash("0x0202").Bytes()
newValues[2] = common.HexToHash("0x0303").Bytes()
- newNode, err := node.InsertValuesAtStem(stem, newValues[:], nil, 0)
- if err != nil {
- t.Fatalf("Failed to insert values: %v", err)
- }
-
- stemNode, ok := newNode.(*StemNode)
- if !ok {
- t.Fatalf("Expected StemNode, got %T", newNode)
+ if err := s.InsertValuesAtStem(stem, newValues, nil); err != nil {
+ t.Fatal(err)
}
// Check that all values are present
- if !bytes.Equal(stemNode.Values[0], values[0]) {
+ retrieved, err := s.GetValuesAtStem(stem, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(retrieved[0], values[0]) {
t.Error("Original value at index 0 missing")
}
- if !bytes.Equal(stemNode.Values[1], newValues[1]) {
+ if !bytes.Equal(retrieved[1], newValues[1]) {
t.Error("New value at index 1 missing")
}
- if !bytes.Equal(stemNode.Values[2], newValues[2]) {
+ if !bytes.Equal(retrieved[2], newValues[2]) {
t.Error("New value at index 2 missing")
}
}
-// TestStemNodeGetHeight tests GetHeight method
+// TestStemNodeGetHeight tests GetHeight method via NodeStore.
func TestStemNodeGetHeight(t *testing.T) {
- node := &StemNode{
- Stem: make([]byte, 31),
- Values: make([][]byte, 256),
- depth: 0,
+ s := NewNodeStore()
+
+ key := make([]byte, 32)
+ value := common.HexToHash("0x01").Bytes()
+ if err := s.Insert(key, value, nil); err != nil {
+ t.Fatal(err)
}
- height := node.GetHeight()
+ height := s.GetHeight(s.Root())
if height != 1 {
t.Errorf("Expected height 1, got %d", height)
}
}
-// TestStemNodeCollectNodes tests CollectNodes method
+// TestStemNodeCollectNodes tests CollectNodes method via NodeStore.
func TestStemNodeCollectNodes(t *testing.T) {
+ s := NewNodeStore()
+
stem := make([]byte, 31)
- var values [256][]byte
+ values := make([][]byte, 256)
values[0] = common.HexToHash("0x0101").Bytes()
- node := &StemNode{
- Stem: stem,
- Values: values[:],
- depth: 0,
- dirty: true,
+ if err := s.InsertValuesAtStem(stem, values, nil); err != nil {
+ t.Fatal(err)
}
var collectedPaths [][]byte
- var collectedNodes []BinaryNode
-
- flushFn := func(path []byte, n BinaryNode) {
- // Make a copy of the path
+ flushFn := func(path []byte, hash common.Hash, serialized []byte) {
pathCopy := make([]byte, len(path))
copy(pathCopy, path)
collectedPaths = append(collectedPaths, pathCopy)
- collectedNodes = append(collectedNodes, n)
}
- err := node.CollectNodes([]byte{0, 1, 0}, flushFn)
+ err := s.CollectNodes(s.Root(), []byte{0, 1, 0}, flushFn, MaxGroupDepth)
if err != nil {
t.Fatalf("Failed to collect nodes: %v", err)
}
// Should have collected one node (itself)
- if len(collectedNodes) != 1 {
- t.Errorf("Expected 1 collected node, got %d", len(collectedNodes))
- }
-
- // Check that the collected node is the same
- if collectedNodes[0] != node {
- t.Error("Collected node doesn't match original")
+ if len(collectedPaths) != 1 {
+ t.Errorf("Expected 1 collected node, got %d", len(collectedPaths))
}
// Check the path
@@ -413,44 +335,3 @@ func TestStemNodeCollectNodes(t *testing.T) {
t.Errorf("Path mismatch: expected [0, 1, 0], got %v", collectedPaths[0])
}
}
-
-// TestStemNodeCollectNodesSkipsClean verifies that a clean stem is not
-// flushed, and that flushing a dirty stem clears its dirty flag so that
-// a subsequent CollectNodes on the same node is a no-op.
-func TestStemNodeCollectNodesSkipsClean(t *testing.T) {
- stem := make([]byte, 31)
- node := &StemNode{
- Stem: stem,
- Values: make([][]byte, 256),
- depth: 0,
- }
-
- var collected []BinaryNode
- flushFn := func(_ []byte, n BinaryNode) { collected = append(collected, n) }
-
- if err := node.CollectNodes([]byte{0}, flushFn); err != nil {
- t.Fatalf("CollectNodes on clean stem: %v", err)
- }
- if len(collected) != 0 {
- t.Fatalf("expected clean stem not to be flushed, got %d", len(collected))
- }
-
- node.dirty = true
- if err := node.CollectNodes([]byte{0}, flushFn); err != nil {
- t.Fatalf("CollectNodes on dirty stem: %v", err)
- }
- if len(collected) != 1 {
- t.Fatalf("expected dirty stem to be flushed once, got %d", len(collected))
- }
- if node.dirty {
- t.Errorf("stem dirty flag should be cleared after flush")
- }
-
- collected = nil
- if err := node.CollectNodes([]byte{0}, flushFn); err != nil {
- t.Fatalf("CollectNodes after flush: %v", err)
- }
- if len(collected) != 0 {
- t.Errorf("expected no flush on clean stem, got %d", len(collected))
- }
-}
diff --git a/trie/bintrie/store_commit.go b/trie/bintrie/store_commit.go
new file mode 100644
index 0000000000..3f71a71e38
--- /dev/null
+++ b/trie/bintrie/store_commit.go
@@ -0,0 +1,477 @@
+// Copyright 2025 go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package bintrie
+
+import (
+ "errors"
+ "fmt"
+ "math/bits"
+ "sync"
+
+ "github.com/ethereum/go-ethereum/common"
+)
+
+// NodeFlushFn is called during commit to flush serialized nodes.
+type NodeFlushFn func(path []byte, hash common.Hash, serialized []byte)
+
+// Hash computes and returns the root hash.
+func (s *NodeStore) Hash() common.Hash {
+ return s.ComputeHash(s.root)
+}
+
+// ComputeHash computes the hash of the node referenced by ref.
+func (s *NodeStore) ComputeHash(ref NodeRef) common.Hash {
+ switch ref.Kind() {
+ case KindInternal:
+ return s.hashInternal(ref.Index())
+ case KindStem:
+ return s.getStem(ref.Index()).Hash()
+ case KindHashed:
+ return s.getHashed(ref.Index()).hash
+ case KindEmpty:
+ return common.Hash{}
+ default:
+ return common.Hash{}
+ }
+}
+
+// hashInternal computes the hash of an InternalNode. At shallow depths
+// (< parallelHashDepth), the left subtree is hashed in a goroutine while
+// the right subtree is hashed inline. This is safe because left and right
+// subtrees are disjoint in a well-formed tree — no node appears in both.
+// ComputeHash must not be called concurrently with mutations to the NodeStore.
+func (s *NodeStore) hashInternal(idx uint32) common.Hash {
+ node := s.getInternal(idx)
+ if !node.mustRecompute {
+ return node.hash
+ }
+
+ if node.depth < parallelHashDepth {
+ var input [64]byte
+ var lh common.Hash
+ var wg sync.WaitGroup
+ if !node.left.IsEmpty() {
+ wg.Add(1)
+ go func() {
+ lh = s.ComputeHash(node.left)
+ wg.Done()
+ }()
+ }
+ if !node.right.IsEmpty() {
+ rh := s.ComputeHash(node.right)
+ copy(input[32:], rh[:])
+ }
+ wg.Wait()
+ copy(input[:32], lh[:])
+ node.hash = sha256Sum256(input[:])
+ node.mustRecompute = false
+ return node.hash
+ }
+
+ var input [64]byte
+ if !node.left.IsEmpty() {
+ lh := s.ComputeHash(node.left)
+ copy(input[:32], lh[:])
+ }
+ if !node.right.IsEmpty() {
+ rh := s.ComputeHash(node.right)
+ copy(input[32:], rh[:])
+ }
+ node.hash = sha256Sum256(input[:])
+ node.mustRecompute = false
+ return node.hash
+}
+
+// --- Serialization ---
+
+// countSubtreeChildren counts non-empty children at the bottom layer of a subtree.
+func (s *NodeStore) countSubtreeChildren(ref NodeRef, remainingDepth int) int {
+ if remainingDepth == 0 {
+ if ref.IsEmpty() {
+ return 0
+ }
+ return 1
+ }
+ if ref.Kind() == KindInternal {
+ node := s.getInternal(ref.Index())
+ return s.countSubtreeChildren(node.left, remainingDepth-1) + s.countSubtreeChildren(node.right, remainingDepth-1)
+ }
+ if ref.IsEmpty() {
+ return 0
+ }
+ return 1
+}
+
+// serializeSubtreeDirect writes child hashes directly into the serialized buffer.
+func (s *NodeStore) serializeSubtreeDirect(ref NodeRef, remainingDepth int, position int, absoluteDepth int, bitmap []byte, out []byte, hashOffset *int) {
+ if remainingDepth == 0 {
+ if ref.IsEmpty() {
+ return
+ }
+ bitmap[position/8] |= 1 << (7 - (position % 8))
+ h := s.ComputeHash(ref)
+ copy(out[*hashOffset:*hashOffset+HashSize], h[:])
+ *hashOffset += HashSize
+ return
+ }
+
+ if ref.Kind() == KindInternal {
+ node := s.getInternal(ref.Index())
+ leftPos := position * 2
+ rightPos := position*2 + 1
+ s.serializeSubtreeDirect(node.left, remainingDepth-1, leftPos, absoluteDepth+1, bitmap, out, hashOffset)
+ s.serializeSubtreeDirect(node.right, remainingDepth-1, rightPos, absoluteDepth+1, bitmap, out, hashOffset)
+ return
+ }
+
+ if ref.IsEmpty() {
+ return
+ }
+
+ // Leaf (StemNode or HashedNode) at a non-zero remaining depth: compute its position.
+ leafPos := position
+ if ref.Kind() == KindStem {
+ sn := s.getStem(ref.Index())
+ for d := 0; d < remainingDepth; d++ {
+ bit := sn.Stem[(absoluteDepth+d)/8] >> (7 - ((absoluteDepth + d) % 8)) & 1
+ leafPos = leafPos*2 + int(bit)
+ }
+ } else {
+ leafPos = position << remainingDepth
+ }
+ bitmap[leafPos/8] |= 1 << (7 - (leafPos % 8))
+ h := s.ComputeHash(ref)
+ copy(out[*hashOffset:*hashOffset+HashSize], h[:])
+ *hashOffset += HashSize
+}
+
+// SerializeNode serializes a node referenced by ref.
+func (s *NodeStore) SerializeNode(ref NodeRef, groupDepth int) []byte {
+ if groupDepth < 1 || groupDepth > MaxGroupDepth {
+ panic("groupDepth must be between 1 and 8")
+ }
+
+ switch ref.Kind() {
+ case KindInternal:
+ node := s.getInternal(ref.Index())
+ bitmapSize := BitmapSizeForDepth(groupDepth)
+ hashCount := s.countSubtreeChildren(ref, groupDepth)
+ serializedLen := NodeTypeBytes + 1 + bitmapSize + hashCount*HashSize
+ serialized := make([]byte, serializedLen)
+ serialized[0] = nodeTypeInternal
+ serialized[1] = byte(groupDepth)
+ bitmap := serialized[2 : 2+bitmapSize]
+ hashOffset := 2 + bitmapSize
+ s.serializeSubtreeDirect(ref, groupDepth, 0, int(node.depth), bitmap, serialized, &hashOffset)
+ return serialized
+
+ case KindStem:
+ sn := s.getStem(ref.Index())
+ serializedLen := NodeTypeBytes + StemSize + StemBitmapSize + len(sn.valueData)
+ serialized := make([]byte, serializedLen)
+ serialized[0] = nodeTypeStem
+ copy(serialized[NodeTypeBytes:NodeTypeBytes+StemSize], sn.Stem[:])
+ copy(serialized[NodeTypeBytes+StemSize:NodeTypeBytes+StemSize+StemBitmapSize], sn.bitmap[:])
+ copy(serialized[NodeTypeBytes+StemSize+StemBitmapSize:], sn.valueData)
+ return serialized
+
+ default:
+ panic(fmt.Sprintf("SerializeNode: unexpected node kind %d", ref.Kind()))
+ }
+}
+
+// --- Deserialization ---
+
+var errInvalidSerializedLength = errors.New("invalid serialized node length")
+
+// DeserializeNode deserializes a node from bytes, recomputing its hash.
+func (s *NodeStore) DeserializeNode(serialized []byte, depth int) (NodeRef, error) {
+ return s.deserializeNode(serialized, depth, common.Hash{}, true)
+}
+
+// DeserializeNodeWithHash deserializes a node, using the provided hash.
+func (s *NodeStore) DeserializeNodeWithHash(serialized []byte, depth int, hn common.Hash) (NodeRef, error) {
+ return s.deserializeNode(serialized, depth, hn, false)
+}
+
+func (s *NodeStore) deserializeNode(serialized []byte, depth int, hn common.Hash, mustRecompute bool) (NodeRef, error) {
+ if len(serialized) == 0 {
+ return EmptyRef, nil
+ }
+
+ switch serialized[0] {
+ case nodeTypeInternal:
+ if len(serialized) < NodeTypeBytes+1 {
+ return EmptyRef, errInvalidSerializedLength
+ }
+ groupDepth := int(serialized[1])
+ if groupDepth < 1 || groupDepth > MaxGroupDepth {
+ return EmptyRef, errors.New("invalid group depth")
+ }
+ bitmapSize := BitmapSizeForDepth(groupDepth)
+ if len(serialized) < NodeTypeBytes+1+bitmapSize {
+ return EmptyRef, errInvalidSerializedLength
+ }
+ bitmap := serialized[2 : 2+bitmapSize]
+ hashData := serialized[2+bitmapSize:]
+
+ hashIdx := 0
+ ref, err := s.deserializeSubtree(groupDepth, 0, depth, bitmap, hashData, &hashIdx, mustRecompute)
+ if err != nil {
+ return EmptyRef, err
+ }
+ if ref.Kind() == KindInternal && !mustRecompute {
+ s.getInternal(ref.Index()).hash = hn
+ }
+ return ref, nil
+
+ case nodeTypeStem:
+ if len(serialized) < 64 {
+ return EmptyRef, errInvalidSerializedLength
+ }
+ stemIdx := s.allocStem()
+ sn := s.getStem(stemIdx)
+ copy(sn.Stem[:], serialized[NodeTypeBytes:NodeTypeBytes+StemSize])
+ copy(sn.bitmap[:], serialized[NodeTypeBytes+StemSize:NodeTypeBytes+StemSize+StemBitmapSize])
+
+ var count uint16
+ for i := range StemBitmapSize {
+ count += uint16(bits.OnesCount8(sn.bitmap[i]))
+ }
+ sn.count = count
+ dataStart := NodeTypeBytes + StemSize + StemBitmapSize
+ dataEnd := dataStart + int(count)*HashSize
+ if len(serialized) < dataEnd {
+ return EmptyRef, errInvalidSerializedLength
+ }
+ // Zero-copy sub-slice of serialized data
+ sn.valueData = serialized[dataStart:dataEnd]
+ sn.shared = true
+ sn.depth = uint8(depth)
+ sn.hash = hn
+ sn.mustRecompute = mustRecompute
+ return MakeRef(KindStem, stemIdx), nil
+
+ default:
+ return EmptyRef, errors.New("invalid node type")
+ }
+}
+
+func (s *NodeStore) deserializeSubtree(remainingDepth int, position int, nodeDepth int, bitmap []byte, hashData []byte, hashIdx *int, mustRecompute bool) (NodeRef, error) {
+ if remainingDepth == 0 {
+ if bitmap[position/8]>>(7-(position%8))&1 == 1 {
+ if len(hashData) < (*hashIdx+1)*HashSize {
+ return EmptyRef, errInvalidSerializedLength
+ }
+ hash := common.BytesToHash(hashData[*hashIdx*HashSize : (*hashIdx+1)*HashSize])
+ *hashIdx++
+ return s.newHashedRef(hash), nil
+ }
+ return EmptyRef, nil
+ }
+
+ leftPos := position * 2
+ rightPos := position*2 + 1
+
+ left, err := s.deserializeSubtree(remainingDepth-1, leftPos, nodeDepth+1, bitmap, hashData, hashIdx, true)
+ if err != nil {
+ return EmptyRef, err
+ }
+ right, err := s.deserializeSubtree(remainingDepth-1, rightPos, nodeDepth+1, bitmap, hashData, hashIdx, true)
+ if err != nil {
+ return EmptyRef, err
+ }
+
+ if left.IsEmpty() && right.IsEmpty() {
+ return EmptyRef, nil
+ }
+
+ if nodeDepth > 248 {
+ panic("node depth exceeds maximum binary trie depth")
+ }
+ idx := s.allocInternal()
+ n := s.getInternal(idx)
+ n.depth = uint8(nodeDepth)
+ n.left = left
+ n.right = right
+ n.mustRecompute = mustRecompute
+ return MakeRef(KindInternal, idx), nil
+}
+
+// --- CollectNodes (Commit) ---
+
+// CollectNodes traverses the trie, flushing nodes at group boundaries.
+func (s *NodeStore) CollectNodes(ref NodeRef, path []byte, flushfn NodeFlushFn, groupDepth int) error {
+ if groupDepth < 1 || groupDepth > MaxGroupDepth {
+ return errors.New("groupDepth must be between 1 and 8")
+ }
+ buf := make([]byte, len(path), len(path)+MaxGroupDepth+1)
+ copy(buf, path)
+ return s.collectNodesBuf(ref, buf, flushfn, groupDepth)
+}
+
+func (s *NodeStore) collectNodesBuf(ref NodeRef, buf []byte, flushfn NodeFlushFn, groupDepth int) error {
+ switch ref.Kind() {
+ case KindInternal:
+ node := s.getInternal(ref.Index())
+ if int(node.depth)%groupDepth == 0 {
+ if err := s.collectChildGroupsBuf(ref, buf, flushfn, groupDepth, groupDepth-1); err != nil {
+ return err
+ }
+ serialized := s.SerializeNode(ref, groupDepth)
+ flushfn(buf, s.ComputeHash(ref), serialized)
+ return nil
+ }
+ return s.collectChildGroupsBuf(ref, buf, flushfn, groupDepth, groupDepth-(int(node.depth)%groupDepth)-1)
+
+ case KindStem:
+ serialized := s.SerializeNode(ref, groupDepth)
+ flushfn(buf, s.ComputeHash(ref), serialized)
+ return nil
+
+ case KindHashed, KindEmpty:
+ return nil
+
+ default:
+ return fmt.Errorf("collectNodesBuf: unexpected kind %d", ref.Kind())
+ }
+}
+
+func (s *NodeStore) collectChildGroupsBuf(ref NodeRef, buf []byte, flushfn NodeFlushFn, groupDepth int, remainingLevels int) error {
+ if ref.Kind() != KindInternal {
+ return nil
+ }
+ node := s.getInternal(ref.Index())
+ saved := len(buf)
+ childDepth := int(node.depth) + 1
+
+ if remainingLevels == 0 {
+ if !node.left.IsEmpty() {
+ buf = append(buf, 0)
+ if err := s.collectNodesBuf(node.left, buf, flushfn, groupDepth); err != nil {
+ return err
+ }
+ buf = buf[:saved]
+ }
+ if !node.right.IsEmpty() {
+ buf = append(buf, 1)
+ if err := s.collectNodesBuf(node.right, buf, flushfn, groupDepth); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+
+ // Left child
+ if !node.left.IsEmpty() {
+ if node.left.Kind() == KindInternal {
+ buf = append(buf, 0)
+ if err := s.collectChildGroupsBuf(node.left, buf, flushfn, groupDepth, remainingLevels-1); err != nil {
+ return err
+ }
+ buf = buf[:saved]
+ } else {
+ buf = append(buf, 0)
+ buf = s.extendPathBuf(buf, node.left, remainingLevels, childDepth)
+ if err := s.collectNodesBuf(node.left, buf, flushfn, groupDepth); err != nil {
+ return err
+ }
+ buf = buf[:saved]
+ }
+ }
+
+ // Right child
+ if !node.right.IsEmpty() {
+ if node.right.Kind() == KindInternal {
+ buf = append(buf, 1)
+ if err := s.collectChildGroupsBuf(node.right, buf, flushfn, groupDepth, remainingLevels-1); err != nil {
+ return err
+ }
+ } else {
+ buf = append(buf, 1)
+ buf = s.extendPathBuf(buf, node.right, remainingLevels, childDepth)
+ if err := s.collectNodesBuf(node.right, buf, flushfn, groupDepth); err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}
+
+// extendPathBuf extends the path buffer to the group's leaf boundary.
+func (s *NodeStore) extendPathBuf(buf []byte, ref NodeRef, remainingLevels int, absoluteDepth int) []byte {
+ if remainingLevels <= 0 {
+ return buf
+ }
+ if ref.Kind() == KindStem {
+ sn := s.getStem(ref.Index())
+ for d := 0; d < remainingLevels; d++ {
+ bit := sn.Stem[(absoluteDepth+d)/8] >> (7 - ((absoluteDepth + d) % 8)) & 1
+ buf = append(buf, bit)
+ }
+ } else {
+ for d := 0; d < remainingLevels; d++ {
+ buf = append(buf, 0)
+ }
+ }
+ return buf
+}
+
+// ToDot generates a DOT representation for debugging.
+func (s *NodeStore) ToDot(ref NodeRef, parent, path string) string {
+ switch ref.Kind() {
+ case KindInternal:
+ node := s.getInternal(ref.Index())
+ me := fmt.Sprintf("internal%s", path)
+ ret := fmt.Sprintf("%s [label=\"I: %x\"]\n", me, s.ComputeHash(ref))
+ if len(parent) > 0 {
+ ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me)
+ }
+ if !node.left.IsEmpty() {
+ ret += s.ToDot(node.left, me, fmt.Sprintf("%s%02x", path, 0))
+ }
+ if !node.right.IsEmpty() {
+ ret += s.ToDot(node.right, me, fmt.Sprintf("%s%02x", path, 1))
+ }
+ return ret
+ case KindStem:
+ sn := s.getStem(ref.Index())
+ me := fmt.Sprintf("stem%s", path)
+ ret := fmt.Sprintf("%s [label=\"stem=%x c=%x\"]\n", me, sn.Stem, sn.Hash())
+ ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me)
+ idx := 0
+ for i := range StemNodeWidth {
+ if sn.bitmap[i/8]>>(7-i%8)&1 != 1 {
+ continue
+ }
+ v := sn.valueData[idx*HashSize : (idx+1)*HashSize]
+ idx++
+ ret += fmt.Sprintf("%s%x [label=\"%x\"]\n", me, i, v)
+ ret += fmt.Sprintf("%s -> %s%x\n", me, me, i)
+ }
+ return ret
+ case KindHashed:
+ hn := s.getHashed(ref.Index())
+ me := fmt.Sprintf("hash%s", path)
+ ret := fmt.Sprintf("%s [label=\"%x\"]\n", me, hn.hash)
+ ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me)
+ return ret
+ default:
+ return ""
+ }
+}
diff --git a/trie/bintrie/store_ops.go b/trie/bintrie/store_ops.go
new file mode 100644
index 0000000000..72b2c2389b
--- /dev/null
+++ b/trie/bintrie/store_ops.go
@@ -0,0 +1,556 @@
+// Copyright 2025 go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package bintrie
+
+import (
+ "errors"
+ "fmt"
+
+ "github.com/ethereum/go-ethereum/common"
+)
+
+// NodeResolverFn resolves a hashed node from the database.
+type NodeResolverFn func([]byte, common.Hash) ([]byte, error)
+
+// GetSingle retrieves a single value at stem[suffix] from the trie root.
+func (s *NodeStore) GetSingle(stem []byte, suffix byte, resolver NodeResolverFn) ([]byte, error) {
+ return s.getSingle(s.root, stem, suffix, resolver)
+}
+
+// getSingle retrieves a single value using iterative traversal.
+func (s *NodeStore) getSingle(ref NodeRef, stem []byte, suffix byte, resolver NodeResolverFn) ([]byte, error) {
+ cur := ref
+ // Track parent for HashedNode resolution (update parent's child ref).
+ var parentIdx uint32
+ var parentIsLeft bool
+ hasParent := false
+
+ for {
+ switch cur.Kind() {
+ case KindInternal:
+ node := s.getInternal(cur.Index())
+ if node.depth >= 31*8 {
+ return nil, errors.New("node too deep")
+ }
+ bit := stem[node.depth/8] >> (7 - (node.depth % 8)) & 1
+ parentIdx = cur.Index()
+ hasParent = true
+ if bit == 0 {
+ parentIsLeft = true
+ cur = node.left
+ } else {
+ parentIsLeft = false
+ cur = node.right
+ }
+
+ case KindStem:
+ sn := s.getStem(cur.Index())
+ if sn.Stem != [StemSize]byte(stem[:StemSize]) {
+ return nil, nil
+ }
+ return sn.getValue(suffix), nil
+
+ case KindHashed:
+ if !hasParent {
+ return nil, errors.New("getSingle: hashed node at root")
+ }
+ hn := s.getHashed(cur.Index())
+ parentNode := s.getInternal(parentIdx)
+ path := makeKeyPath(int(parentNode.depth), stem)
+ data, err := resolver(path, hn.hash)
+ if err != nil {
+ return nil, fmt.Errorf("getSingle resolve error: %w", err)
+ }
+ resolved, err := s.DeserializeNodeWithHash(data, int(parentNode.depth)+1, hn.hash)
+ if err != nil {
+ return nil, fmt.Errorf("getSingle deserialization error: %w", err)
+ }
+ // Update parent's child ref
+ s.freeHashedNode(cur.Index())
+ if parentIsLeft {
+ parentNode.left = resolved
+ } else {
+ parentNode.right = resolved
+ }
+ cur = resolved
+
+ case KindEmpty:
+ return nil, nil
+
+ default:
+ return nil, fmt.Errorf("getSingle: unexpected node kind %d", cur.Kind())
+ }
+ }
+}
+
+// GetValuesAtStem retrieves all 256 values at a stem.
+func (s *NodeStore) GetValuesAtStem(stem []byte, resolver NodeResolverFn) ([][]byte, error) {
+ return s.getValuesAtStem(s.root, stem, resolver)
+}
+
+// getValuesAtStem uses iterative traversal to find the StemNode.
+func (s *NodeStore) getValuesAtStem(ref NodeRef, stem []byte, resolver NodeResolverFn) ([][]byte, error) {
+ cur := ref
+ var parentIdx uint32
+ var parentIsLeft bool
+ hasParent := false
+
+ for {
+ switch cur.Kind() {
+ case KindInternal:
+ node := s.getInternal(cur.Index())
+ if node.depth >= 31*8 {
+ return nil, errors.New("node too deep")
+ }
+ bit := stem[node.depth/8] >> (7 - (node.depth % 8)) & 1
+ parentIdx = cur.Index()
+ hasParent = true
+ if bit == 0 {
+ parentIsLeft = true
+ cur = node.left
+ } else {
+ parentIsLeft = false
+ cur = node.right
+ }
+
+ case KindStem:
+ sn := s.getStem(cur.Index())
+ if sn.Stem != [StemSize]byte(stem[:StemSize]) {
+ return nil, nil
+ }
+ return sn.allValues(), nil
+
+ case KindHashed:
+ if !hasParent {
+ return nil, errors.New("getValuesAtStem: hashed node at root")
+ }
+ hn := s.getHashed(cur.Index())
+ parentNode := s.getInternal(parentIdx)
+ path := makeKeyPath(int(parentNode.depth), stem)
+ data, err := resolver(path, hn.hash)
+ if err != nil {
+ return nil, fmt.Errorf("getValuesAtStem resolve error: %w", err)
+ }
+ resolved, err := s.DeserializeNodeWithHash(data, int(parentNode.depth)+1, hn.hash)
+ if err != nil {
+ return nil, fmt.Errorf("getValuesAtStem deserialization error: %w", err)
+ }
+ s.freeHashedNode(cur.Index())
+ if parentIsLeft {
+ parentNode.left = resolved
+ } else {
+ parentNode.right = resolved
+ }
+ cur = resolved
+
+ case KindEmpty:
+ var values [StemNodeWidth][]byte
+ return values[:], nil
+
+ default:
+ return nil, fmt.Errorf("getValuesAtStem: unexpected node kind %d", cur.Kind())
+ }
+ }
+}
+
+// InsertSingle inserts a single value at stem[suffix] into the trie.
+func (s *NodeStore) InsertSingle(stem []byte, suffix byte, value []byte, resolver NodeResolverFn) error {
+ if len(value) != HashSize {
+ return errors.New("invalid insertion: value length")
+ }
+
+ // Handle root-is-empty case
+ if s.root.IsEmpty() {
+ ref := s.newStemRef(stem, 0)
+ sn := s.getStem(ref.Index())
+ sn.setValue(suffix, value)
+ s.root = ref
+ return nil
+ }
+
+ // Handle root-is-stem case
+ if s.root.Kind() == KindStem {
+ sn := s.getStem(s.root.Index())
+ if sn.Stem == [StemSize]byte(stem[:StemSize]) {
+ sn.ensureWritable()
+ sn.setValue(suffix, value)
+ sn.mustRecompute = true
+ return nil
+ }
+ // Different stem — promote root to internal node via split
+ newRoot := s.splitStemInsert(s.root, stem, suffix, value, int(sn.depth))
+ s.root = newRoot
+ return nil
+ }
+
+ // Root is an InternalNode — iterative descent
+ return s.insertSingleInternal(stem, suffix, value, resolver)
+}
+
+// insertSingleInternal handles insertion when root is an InternalNode.
+func (s *NodeStore) insertSingleInternal(stem []byte, suffix byte, value []byte, resolver NodeResolverFn) error {
+ type pathEntry struct {
+ internalIdx uint32
+ isLeft bool
+ }
+ var pathStack [256]pathEntry // stack-allocated, max depth 248
+ pathLen := 0
+
+ cur := s.root
+
+ for {
+ switch cur.Kind() {
+ case KindInternal:
+ node := s.getInternal(cur.Index())
+ node.mustRecompute = true
+ bit := stem[node.depth/8] >> (7 - (node.depth % 8)) & 1
+ pathStack[pathLen] = pathEntry{internalIdx: cur.Index(), isLeft: bit == 0}
+ pathLen++
+ if bit == 0 {
+ cur = node.left
+ } else {
+ cur = node.right
+ }
+
+ case KindStem:
+ sn := s.getStem(cur.Index())
+ if sn.Stem == [StemSize]byte(stem[:StemSize]) {
+ sn.ensureWritable()
+ sn.setValue(suffix, value)
+ sn.mustRecompute = true
+ return nil
+ }
+ // Different stem — split
+ parentDepth := int(s.getInternal(pathStack[pathLen-1].internalIdx).depth) + 1
+ newRef := s.splitStemInsert(cur, stem, suffix, value, parentDepth)
+ p := pathStack[pathLen-1]
+ parent := s.getInternal(p.internalIdx)
+ if p.isLeft {
+ parent.left = newRef
+ } else {
+ parent.right = newRef
+ }
+ return nil
+
+ case KindHashed:
+ if pathLen == 0 {
+ return errors.New("insertSingle: hashed node at root")
+ }
+ p := pathStack[pathLen-1]
+ parentNode := s.getInternal(p.internalIdx)
+ hn := s.getHashed(cur.Index())
+ path := makeKeyPath(int(parentNode.depth), stem)
+ data, err := resolver(path, hn.hash)
+ if err != nil {
+ return fmt.Errorf("insertSingle resolve error: %w", err)
+ }
+ resolved, err := s.DeserializeNodeWithHash(data, int(parentNode.depth)+1, hn.hash)
+ if err != nil {
+ return fmt.Errorf("insertSingle deserialization error: %w", err)
+ }
+ s.freeHashedNode(cur.Index())
+ if p.isLeft {
+ parentNode.left = resolved
+ } else {
+ parentNode.right = resolved
+ }
+ cur = resolved
+
+ case KindEmpty:
+ parentDepth := int(s.getInternal(pathStack[pathLen-1].internalIdx).depth) + 1
+ ref := s.newStemRef(stem, parentDepth)
+ sn := s.getStem(ref.Index())
+ sn.setValue(suffix, value)
+ p := pathStack[pathLen-1]
+ parent := s.getInternal(p.internalIdx)
+ if p.isLeft {
+ parent.left = ref
+ } else {
+ parent.right = ref
+ }
+ return nil
+
+ default:
+ return fmt.Errorf("insertSingle: unexpected node kind %d", cur.Kind())
+ }
+ }
+}
+
+// splitStemInsert handles the case where we need to split a StemNode
+// into a chain of InternalNodes because the new key has a different stem.
+func (s *NodeStore) splitStemInsert(existingRef NodeRef, newStem []byte, suffix byte, value []byte, depth int) NodeRef {
+ existing := s.getStem(existingRef.Index())
+ existingDepth := depth
+
+ var firstRef NodeRef
+ var lastInternalIdx uint32
+ var lastIsLeft bool
+ first := true
+
+ for {
+ bitExisting := existing.Stem[existingDepth/8] >> (7 - (existingDepth % 8)) & 1
+ bitNew := newStem[existingDepth/8] >> (7 - (existingDepth % 8)) & 1
+
+ newInternalIdx := s.allocInternal()
+ newInternal := s.getInternal(newInternalIdx)
+ newInternal.depth = uint8(existingDepth)
+ newInternal.mustRecompute = true
+ newRef := MakeRef(KindInternal, newInternalIdx)
+
+ if first {
+ firstRef = newRef
+ first = false
+ } else {
+ parent := s.getInternal(lastInternalIdx)
+ if lastIsLeft {
+ parent.left = newRef
+ } else {
+ parent.right = newRef
+ }
+ }
+
+ if bitExisting != bitNew {
+ // Divergence point
+ existing.depth = uint8(existingDepth + 1)
+
+ newStemIdx := s.allocStem()
+ newSn := s.getStem(newStemIdx)
+ copy(newSn.Stem[:], newStem[:StemSize])
+ newSn.depth = uint8(existingDepth + 1)
+ newSn.mustRecompute = true
+ newSn.setValue(suffix, value)
+ newStemRef := MakeRef(KindStem, newStemIdx)
+
+ if bitExisting == 0 {
+ newInternal.left = existingRef
+ newInternal.right = newStemRef
+ } else {
+ newInternal.left = newStemRef
+ newInternal.right = existingRef
+ }
+ return firstRef
+ }
+
+ // Same bit — continue splitting
+ lastInternalIdx = newInternalIdx
+ lastIsLeft = (bitExisting == 0)
+ existingDepth++
+ }
+}
+
+// InsertValuesAtStem inserts a full group of values at the given stem.
+func (s *NodeStore) InsertValuesAtStem(stem []byte, values [][]byte, resolver NodeResolverFn) error {
+ newRoot, err := s.insertValuesAtStem(s.root, stem, values, resolver, 0)
+ if err != nil {
+ return err
+ }
+ s.root = newRoot
+ return nil
+}
+
+// insertValuesAtStem recursively inserts values at a stem.
+func (s *NodeStore) insertValuesAtStem(ref NodeRef, stem []byte, values [][]byte, resolver NodeResolverFn, depth int) (NodeRef, error) {
+ switch ref.Kind() {
+ case KindInternal:
+ node := s.getInternal(ref.Index())
+ bit := stem[node.depth/8] >> (7 - (node.depth % 8)) & 1
+ if bit == 0 {
+ if node.left.Kind() == KindHashed {
+ hn := s.getHashed(node.left.Index())
+ path := makeKeyPath(int(node.depth), stem)
+ data, err := resolver(path, hn.hash)
+ if err != nil {
+ return ref, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
+ }
+ resolved, err := s.DeserializeNodeWithHash(data, int(node.depth)+1, hn.hash)
+ if err != nil {
+ return ref, fmt.Errorf("InsertValuesAtStem deserialization error: %w", err)
+ }
+ s.freeHashedNode(node.left.Index())
+ node.left = resolved
+ }
+ newChild, err := s.insertValuesAtStem(node.left, stem, values, resolver, depth+1)
+ if err != nil {
+ return ref, err
+ }
+ node.left = newChild
+ } else {
+ if node.right.Kind() == KindHashed {
+ hn := s.getHashed(node.right.Index())
+ path := makeKeyPath(int(node.depth), stem)
+ data, err := resolver(path, hn.hash)
+ if err != nil {
+ return ref, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
+ }
+ resolved, err := s.DeserializeNodeWithHash(data, int(node.depth)+1, hn.hash)
+ if err != nil {
+ return ref, fmt.Errorf("InsertValuesAtStem deserialization error: %w", err)
+ }
+ s.freeHashedNode(node.right.Index())
+ node.right = resolved
+ }
+ newChild, err := s.insertValuesAtStem(node.right, stem, values, resolver, depth+1)
+ if err != nil {
+ return ref, err
+ }
+ node.right = newChild
+ }
+ node.mustRecompute = true
+ return ref, nil
+
+ case KindStem:
+ sn := s.getStem(ref.Index())
+ if sn.Stem == [StemSize]byte(stem[:StemSize]) {
+ // Same stem — merge values
+ sn.ensureWritable()
+ for i, v := range values {
+ if v != nil {
+ sn.setValue(byte(i), v)
+ sn.mustRecompute = true
+ }
+ }
+ return ref, nil
+ }
+ // Different stem — split
+ return s.splitStemValuesInsert(ref, stem, values, resolver, depth)
+
+ case KindHashed:
+ hn := s.getHashed(ref.Index())
+ path, err := keyToPath(depth, stem)
+ if err != nil {
+ return ref, fmt.Errorf("InsertValuesAtStem path error: %w", err)
+ }
+ if resolver == nil {
+ return ref, errors.New("InsertValuesAtStem: resolver is nil")
+ }
+ data, err := resolver(path, hn.hash)
+ if err != nil {
+ return ref, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
+ }
+ resolved, err := s.DeserializeNodeWithHash(data, depth, hn.hash)
+ if err != nil {
+ return ref, fmt.Errorf("InsertValuesAtStem deserialization error: %w", err)
+ }
+ s.freeHashedNode(ref.Index())
+ return s.insertValuesAtStem(resolved, stem, values, resolver, depth)
+
+ case KindEmpty:
+ // Create new StemNode
+ stemIdx := s.allocStem()
+ sn := s.getStem(stemIdx)
+ copy(sn.Stem[:], stem[:StemSize])
+ sn.depth = uint8(depth)
+ sn.mustRecompute = true
+ for i, v := range values {
+ if v != nil {
+ sn.count++
+ sn.bitmap[i/8] |= 1 << (7 - (i % 8))
+ sn.valueData = append(sn.valueData, v[:HashSize]...)
+ }
+ }
+ return MakeRef(KindStem, stemIdx), nil
+
+ default:
+ return ref, fmt.Errorf("insertValuesAtStem: unexpected kind %d", ref.Kind())
+ }
+}
+
+// splitStemValuesInsert handles splitting a StemNode when inserting values with a different stem.
+func (s *NodeStore) splitStemValuesInsert(existingRef NodeRef, newStem []byte, values [][]byte, resolver NodeResolverFn, depth int) (NodeRef, error) {
+ existing := s.getStem(existingRef.Index())
+
+ bitStem := existing.Stem[existing.depth/8] >> (7 - (existing.depth % 8)) & 1
+ nRef := s.newInternalRef(int(existing.depth))
+ nNode := s.getInternal(nRef.Index())
+ existing.depth++
+
+ bitKey := newStem[nNode.depth/8] >> (7 - (nNode.depth % 8)) & 1
+ if bitKey == bitStem {
+ // Same direction — need deeper split
+ var child NodeRef
+ if bitStem == 0 {
+ nNode.left = existingRef
+ child = nNode.left
+ } else {
+ nNode.right = existingRef
+ child = nNode.right
+ }
+ newChild, err := s.insertValuesAtStem(child, newStem, values, resolver, depth+1)
+ if err != nil {
+ return nRef, err
+ }
+ if bitStem == 0 {
+ nNode.left = newChild
+ nNode.right = EmptyRef
+ } else {
+ nNode.right = newChild
+ nNode.left = EmptyRef
+ }
+ } else {
+ // Divergence — create new StemNode for the new values
+ newStemIdx := s.allocStem()
+ newSn := s.getStem(newStemIdx)
+ copy(newSn.Stem[:], newStem[:StemSize])
+ newSn.depth = nNode.depth + 1
+ newSn.mustRecompute = true
+ for i, v := range values {
+ if v != nil {
+ newSn.setValue(byte(i), v)
+ }
+ }
+ newStemRef := MakeRef(KindStem, newStemIdx)
+
+ if bitStem == 0 {
+ nNode.left = existingRef
+ nNode.right = newStemRef
+ } else {
+ nNode.left = newStemRef
+ nNode.right = existingRef
+ }
+ }
+ return nRef, nil
+}
+
+// Insert inserts a key-value pair using the full 32-byte key.
+func (s *NodeStore) Insert(key []byte, value []byte, resolver NodeResolverFn) error {
+ return s.InsertSingle(key[:StemSize], key[StemSize], value, resolver)
+}
+
+// Get retrieves the value for the given 32-byte key.
+func (s *NodeStore) Get(key []byte, resolver NodeResolverFn) ([]byte, error) {
+ return s.GetSingle(key[:StemSize], key[StemSize], resolver)
+}
+
+// GetHeight returns the height of the trie rooted at ref.
+func (s *NodeStore) GetHeight(ref NodeRef) int {
+ switch ref.Kind() {
+ case KindInternal:
+ node := s.getInternal(ref.Index())
+ lh := s.GetHeight(node.left)
+ rh := s.GetHeight(node.right)
+ if lh > rh {
+ return 1 + lh
+ }
+ return 1 + rh
+ case KindStem:
+ return 1
+ case KindEmpty:
+ return 0
+ default:
+ return 0
+ }
+}
diff --git a/trie/bintrie/trie.go b/trie/bintrie/trie.go
index 23d014eb33..39fa4dac30 100644
--- a/trie/bintrie/trie.go
+++ b/trie/bintrie/trie.go
@@ -19,7 +19,6 @@ package bintrie
import (
"bytes"
"encoding/binary"
- "errors"
"fmt"
"github.com/ethereum/go-ethereum/common"
@@ -31,8 +30,6 @@ import (
"github.com/holiman/uint256"
)
-var errInvalidRootType = errors.New("invalid root type")
-
// ChunkedCode represents a sequence of HashSize-byte chunks of code (StemSize bytes of which
// are actual code, and NodeTypeBytes byte is the pushdata offset).
type ChunkedCode []byte
@@ -108,22 +105,18 @@ func ChunkifyCode(code []byte) ChunkedCode {
return chunks
}
-// NewBinaryNode creates a new empty binary trie
-func NewBinaryNode() BinaryNode {
- return Empty{}
-}
-
// BinaryTrie is the implementation of https://eips.ethereum.org/EIPS/eip-7864.
type BinaryTrie struct {
- root BinaryNode
- reader *trie.Reader
- tracer *trie.PrevalueTracer
+ store *NodeStore
+ reader *trie.Reader
+ tracer *trie.PrevalueTracer
+ groupDepth int
}
// ToDot converts the binary trie to a DOT language representation. Useful for debugging.
func (t *BinaryTrie) ToDot() string {
- t.root.Hash()
- return ToDot(t.root)
+ t.store.ComputeHash(t.store.Root())
+ return t.store.ToDot(t.store.Root(), "", "")
}
// NewBinaryTrie creates a new binary trie.
@@ -133,9 +126,10 @@ func NewBinaryTrie(root common.Hash, db database.NodeDatabase) (*BinaryTrie, err
return nil, err
}
t := &BinaryTrie{
- root: NewBinaryNode(),
- reader: reader,
- tracer: trie.NewPrevalueTracer(),
+ store: NewNodeStore(),
+ reader: reader,
+ tracer: trie.NewPrevalueTracer(),
+ groupDepth: MaxGroupDepth,
}
// Parse the root node if it's not empty
if root != types.EmptyBinaryHash && root != types.EmptyRootHash {
@@ -143,11 +137,11 @@ func NewBinaryTrie(root common.Hash, db database.NodeDatabase) (*BinaryTrie, err
if err != nil {
return nil, err
}
- node, err := DeserializeNodeWithHash(blob, 0, root)
+ ref, err := t.store.DeserializeNodeWithHash(blob, 0, root)
if err != nil {
return nil, err
}
- t.root = node
+ t.store.SetRoot(ref)
}
return t, nil
}
@@ -176,29 +170,18 @@ func (t *BinaryTrie) GetKey(key []byte) []byte {
// GetWithHashedKey returns the value, assuming that the key has already
// been hashed.
func (t *BinaryTrie) GetWithHashedKey(key []byte) ([]byte, error) {
- return t.root.Get(key, t.nodeResolver)
+ return t.store.Get(key, t.nodeResolver)
}
// GetAccount returns the account information for the given address.
func (t *BinaryTrie) GetAccount(addr common.Address) (*types.StateAccount, error) {
var (
- values [][]byte
err error
acc = &types.StateAccount{}
key = GetBinaryTreeKey(addr, zero[:])
)
- switch r := t.root.(type) {
- case *InternalNode:
- values, err = r.GetValuesAtStem(key[:StemSize], t.nodeResolver)
- case *StemNode:
- values, err = r.GetValuesAtStem(key[:StemSize], t.nodeResolver)
- case Empty:
- return nil, nil
- default:
- // This will cover HashedNode but that should be fine since the
- // root node should always be resolved.
- return nil, errInvalidRootType
- }
+
+ values, err := t.store.GetValuesAtStem(key[:StemSize], t.nodeResolver)
if err != nil {
return nil, fmt.Errorf("GetAccount (%x) error: %v", addr, err)
}
@@ -219,7 +202,7 @@ func (t *BinaryTrie) GetAccount(addr common.Address) (*types.StateAccount, error
// If the account has been deleted, BasicData and CodeHash will both be
// 32-byte zero blobs (not nil). If the account is recreated afterwards,
// UpdateAccount overwrites BasicData and CodeHash with non-zero values,
- // so this branch won't activate..
+ // so this branch won't activate.
if bytes.Equal(values[BasicDataLeafKey], zero[:]) &&
bytes.Equal(values[CodeHashLeafKey], zero[:]) {
return nil, nil
@@ -238,13 +221,12 @@ func (t *BinaryTrie) GetAccount(addr common.Address) (*types.StateAccount, error
// not be modified by the caller. If a node was not found in the database, a
// trie.MissingNodeError is returned.
func (t *BinaryTrie) GetStorage(addr common.Address, key []byte) ([]byte, error) {
- return t.root.Get(GetBinaryTreeKeyStorageSlot(addr, key), t.nodeResolver)
+ return t.store.Get(GetBinaryTreeKeyStorageSlot(addr, key), t.nodeResolver)
}
// UpdateAccount updates the account information for the given address.
func (t *BinaryTrie) UpdateAccount(addr common.Address, acc *types.StateAccount, codeLen int) error {
var (
- err error
basicData [HashSize]byte
values = make([][]byte, StemNodeWidth)
stem = GetBinaryTreeKey(addr, zero[:])
@@ -265,15 +247,12 @@ func (t *BinaryTrie) UpdateAccount(addr common.Address, acc *types.StateAccount,
values[BasicDataLeafKey] = basicData[:]
values[CodeHashLeafKey] = acc.CodeHash[:]
- t.root, err = t.root.InsertValuesAtStem(stem, values, t.nodeResolver, 0)
- return err
+ return t.store.InsertValuesAtStem(stem, values, t.nodeResolver)
}
// UpdateStem updates the values for the given stem key.
func (t *BinaryTrie) UpdateStem(key []byte, values [][]byte) error {
- var err error
- t.root, err = t.root.InsertValuesAtStem(key, values, t.nodeResolver, 0)
- return err
+ return t.store.InsertValuesAtStem(key, values, t.nodeResolver)
}
// UpdateStorage associates key with value in the trie. If value has length zero, any
@@ -288,11 +267,10 @@ func (t *BinaryTrie) UpdateStorage(address common.Address, key, value []byte) er
} else {
copy(v[HashSize-len(value):], value[:])
}
- root, err := t.root.Insert(k, v[:], t.nodeResolver, 0)
+ err := t.store.Insert(k, v[:], t.nodeResolver)
if err != nil {
return fmt.Errorf("UpdateStorage (%x) error: %v", address, err)
}
- t.root = root
return nil
}
@@ -307,12 +285,7 @@ func (t *BinaryTrie) DeleteAccount(addr common.Address) error {
values[BasicDataLeafKey] = zero[:]
values[CodeHashLeafKey] = zero[:]
- root, err := t.root.InsertValuesAtStem(stem, values, t.nodeResolver, 0)
- if err != nil {
- return fmt.Errorf("DeleteAccount (%x) error: %v", addr, err)
- }
- t.root = root
- return nil
+ return t.store.InsertValuesAtStem(stem, values, t.nodeResolver)
}
// DeleteStorage removes any existing value for key from the trie. If a node was not
@@ -320,18 +293,17 @@ func (t *BinaryTrie) DeleteAccount(addr common.Address) error {
func (t *BinaryTrie) DeleteStorage(addr common.Address, key []byte) error {
k := GetBinaryTreeKeyStorageSlot(addr, key)
var zero [HashSize]byte
- root, err := t.root.Insert(k, zero[:], t.nodeResolver, 0)
+ err := t.store.Insert(k, zero[:], t.nodeResolver)
if err != nil {
return fmt.Errorf("DeleteStorage (%x) error: %v", addr, err)
}
- t.root = root
return nil
}
// Hash returns the root hash of the trie. It does not write to the database and
// can be used even if the trie doesn't have one.
func (t *BinaryTrie) Hash() common.Hash {
- return t.root.Hash()
+ return t.store.ComputeHash(t.store.Root())
}
// Commit writes all nodes to the trie's memory database, tracking the internal
@@ -339,15 +311,17 @@ func (t *BinaryTrie) Hash() common.Hash {
func (t *BinaryTrie) Commit(_ bool) (common.Hash, *trienode.NodeSet) {
nodeset := trienode.NewNodeSet(common.Hash{})
- // The root can be any type of BinaryNode (InternalNode, StemNode, etc.)
- err := t.root.CollectNodes(nil, func(path []byte, node BinaryNode) {
- serialized := SerializeNode(node)
- nodeset.AddNode(path, trienode.NewNodeWithPrev(node.Hash(), serialized, t.tracer.Get(path)))
- })
+ groupDepth := t.groupDepth
+ if groupDepth == 0 {
+ groupDepth = MaxGroupDepth
+ }
+
+ err := t.store.CollectNodes(t.store.Root(), nil, func(path []byte, hash common.Hash, serialized []byte) {
+ nodeset.AddNode(path, trienode.NewNodeWithPrev(hash, serialized, t.tracer.Get(path)))
+ }, groupDepth)
if err != nil {
panic(fmt.Errorf("CollectNodes failed: %v", err))
}
- // Serialize root commitment form
return t.Hash(), nodeset
}
@@ -371,9 +345,10 @@ func (t *BinaryTrie) Prove(key []byte, proofDb ethdb.KeyValueWriter) error {
// Copy creates a deep copy of the trie.
func (t *BinaryTrie) Copy() *BinaryTrie {
return &BinaryTrie{
- root: t.root.Copy(),
- reader: t.reader,
- tracer: t.tracer.Copy(),
+ store: t.store.Copy(),
+ reader: t.reader,
+ tracer: t.tracer.Copy(),
+ groupDepth: t.groupDepth,
}
}
@@ -407,7 +382,6 @@ func (t *BinaryTrie) UpdateContractCode(addr common.Address, codeHash common.Has
if groupOffset == StemNodeWidth-1 || len(chunks)-i <= HashSize {
err = t.UpdateStem(key[:StemSize], values)
-
if err != nil {
return fmt.Errorf("UpdateContractCode (addr=%x) error: %w", addr[:], err)
}
diff --git a/trie/bintrie/trie_test.go b/trie/bintrie/trie_test.go
index c436501464..b06bb4841f 100644
--- a/trie/bintrie/trie_test.go
+++ b/trie/bintrie/trie_test.go
@@ -37,147 +37,130 @@ var (
)
func TestSingleEntry(t *testing.T) {
- tree := NewBinaryNode()
- tree, err := tree.Insert(zeroKey[:], oneKey[:], nil, 0)
- if err != nil {
+ s := NewNodeStore()
+ if err := s.Insert(zeroKey[:], oneKey[:], nil); err != nil {
t.Fatal(err)
}
- if tree.GetHeight() != 1 {
+ if s.GetHeight(s.Root()) != 1 {
t.Fatal("invalid depth")
}
expected := common.HexToHash("aab1060e04cb4f5dc6f697ae93156a95714debbf77d54238766adc5709282b6f")
- got := tree.Hash()
+ got := s.Hash()
if got != expected {
t.Fatalf("invalid tree root, got %x, want %x", got, expected)
}
}
func TestTwoEntriesDiffFirstBit(t *testing.T) {
- var err error
- tree := NewBinaryNode()
- tree, err = tree.Insert(zeroKey[:], oneKey[:], nil, 0)
- if err != nil {
+ s := NewNodeStore()
+ if err := s.Insert(zeroKey[:], oneKey[:], nil); err != nil {
t.Fatal(err)
}
- tree, err = tree.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000000").Bytes(), twoKey[:], nil, 0)
- if err != nil {
+ if err := s.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000000").Bytes(), twoKey[:], nil); err != nil {
t.Fatal(err)
}
- if tree.GetHeight() != 2 {
+ if s.GetHeight(s.Root()) != 2 {
t.Fatal("invalid height")
}
- if tree.Hash() != common.HexToHash("dfc69c94013a8b3c65395625a719a87534a7cfd38719251ad8c8ea7fe79f065e") {
+ if s.Hash() != common.HexToHash("dfc69c94013a8b3c65395625a719a87534a7cfd38719251ad8c8ea7fe79f065e") {
t.Fatal("invalid tree root")
}
}
func TestOneStemColocatedValues(t *testing.T) {
- var err error
- tree := NewBinaryNode()
- tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil, 0)
- if err != nil {
+ s := NewNodeStore()
+ if err := s.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil); err != nil {
t.Fatal(err)
}
- tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil, 0)
- if err != nil {
+ if err := s.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil); err != nil {
t.Fatal(err)
}
- tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000009").Bytes(), threeKey[:], nil, 0)
- if err != nil {
+ if err := s.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000009").Bytes(), threeKey[:], nil); err != nil {
t.Fatal(err)
}
- tree, err = tree.Insert(common.HexToHash("00000000000000000000000000000000000000000000000000000000000000FF").Bytes(), fourKey[:], nil, 0)
- if err != nil {
+ if err := s.Insert(common.HexToHash("00000000000000000000000000000000000000000000000000000000000000FF").Bytes(), fourKey[:], nil); err != nil {
t.Fatal(err)
}
- if tree.GetHeight() != 1 {
+ if s.GetHeight(s.Root()) != 1 {
t.Fatal("invalid height")
}
}
func TestTwoStemColocatedValues(t *testing.T) {
- var err error
- tree := NewBinaryNode()
+ s := NewNodeStore()
// stem: 0...0
- tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil, 0)
- if err != nil {
+ if err := s.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil); err != nil {
t.Fatal(err)
}
- tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil, 0)
- if err != nil {
+ if err := s.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil); err != nil {
t.Fatal(err)
}
// stem: 10...0
- tree, err = tree.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil, 0)
- if err != nil {
+ if err := s.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil); err != nil {
t.Fatal(err)
}
- tree, err = tree.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil, 0)
- if err != nil {
+ if err := s.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil); err != nil {
t.Fatal(err)
}
- if tree.GetHeight() != 2 {
+ if s.GetHeight(s.Root()) != 2 {
t.Fatal("invalid height")
}
}
func TestTwoKeysMatchFirst42Bits(t *testing.T) {
- var err error
- tree := NewBinaryNode()
+ s := NewNodeStore()
// key1 and key 2 have the same prefix of 42 bits (b0*42+b1+b1) and differ after.
key1 := common.HexToHash("0000000000C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0").Bytes()
key2 := common.HexToHash("0000000000E00000000000000000000000000000000000000000000000000000").Bytes()
- tree, err = tree.Insert(key1, oneKey[:], nil, 0)
- if err != nil {
+ if err := s.Insert(key1, oneKey[:], nil); err != nil {
t.Fatal(err)
}
- tree, err = tree.Insert(key2, twoKey[:], nil, 0)
- if err != nil {
+ if err := s.Insert(key2, twoKey[:], nil); err != nil {
t.Fatal(err)
}
- if tree.GetHeight() != 1+42+1 {
+ if s.GetHeight(s.Root()) != 1+42+1 {
t.Fatal("invalid height")
}
}
+
func TestInsertDuplicateKey(t *testing.T) {
- var err error
- tree := NewBinaryNode()
- tree, err = tree.Insert(oneKey[:], oneKey[:], nil, 0)
- if err != nil {
+ s := NewNodeStore()
+ if err := s.Insert(oneKey[:], oneKey[:], nil); err != nil {
t.Fatal(err)
}
- tree, err = tree.Insert(oneKey[:], twoKey[:], nil, 0)
- if err != nil {
+ if err := s.Insert(oneKey[:], twoKey[:], nil); err != nil {
t.Fatal(err)
}
- if tree.GetHeight() != 1 {
+ if s.GetHeight(s.Root()) != 1 {
t.Fatal("invalid height")
}
// Verify that the value is updated
- if !bytes.Equal(tree.(*StemNode).Values[1], twoKey[:]) {
- t.Fatal("invalid height")
+ v, err := s.Get(oneKey[:], nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(v, twoKey[:]) {
+ t.Fatal("value not updated")
}
}
+
func TestLargeNumberOfEntries(t *testing.T) {
- var err error
- tree := NewBinaryNode()
+ s := NewNodeStore()
for i := range StemNodeWidth {
var key [HashSize]byte
key[0] = byte(i)
- tree, err = tree.Insert(key[:], ffKey[:], nil, 0)
- if err != nil {
+ if err := s.Insert(key[:], ffKey[:], nil); err != nil {
t.Fatal(err)
}
}
- height := tree.GetHeight()
+ height := s.GetHeight(s.Root())
if height != 1+8 {
t.Fatalf("invalid height, wanted %d, got %d", 1+8, height)
}
}
func TestMerkleizeMultipleEntries(t *testing.T) {
- var err error
- tree := NewBinaryNode()
+ s := NewNodeStore()
keys := [][]byte{
zeroKey[:],
common.HexToHash("8000000000000000000000000000000000000000000000000000000000000000").Bytes(),
@@ -187,12 +170,11 @@ func TestMerkleizeMultipleEntries(t *testing.T) {
for i, key := range keys {
var v [HashSize]byte
binary.LittleEndian.PutUint64(v[:8], uint64(i))
- tree, err = tree.Insert(key, v[:], nil, 0)
- if err != nil {
+ if err := s.Insert(key, v[:], nil); err != nil {
t.Fatal(err)
}
}
- got := tree.Hash()
+ got := s.Hash()
expected := common.HexToHash("9317155862f7a3867660ddd0966ff799a3d16aa4df1e70a7516eaa4a675191b5")
if got != expected {
t.Fatalf("invalid root, expected=%x, got = %x", expected, got)
@@ -206,8 +188,9 @@ func TestMerkleizeMultipleEntries(t *testing.T) {
func TestStorageRoundTrip(t *testing.T) {
tracer := trie.NewPrevalueTracer()
tr := &BinaryTrie{
- root: NewBinaryNode(),
- tracer: tracer,
+ store: NewNodeStore(),
+ tracer: tracer,
+ groupDepth: MaxGroupDepth,
}
addr := common.HexToAddress("0x1234567890abcdef1234567890abcdef12345678")
@@ -274,8 +257,9 @@ func TestStorageRoundTrip(t *testing.T) {
func newEmptyTestTrie(t *testing.T) *BinaryTrie {
t.Helper()
return &BinaryTrie{
- root: NewBinaryNode(),
- tracer: trie.NewPrevalueTracer(),
+ store: NewNodeStore(),
+ tracer: trie.NewPrevalueTracer(),
+ groupDepth: MaxGroupDepth,
}
}
@@ -599,8 +583,9 @@ func TestBinaryTrieWitness(t *testing.T) {
tracer := trie.NewPrevalueTracer()
tr := &BinaryTrie{
- root: NewBinaryNode(),
- tracer: tracer,
+ store: NewNodeStore(),
+ tracer: tracer,
+ groupDepth: MaxGroupDepth,
}
if w := tr.Witness(); len(w) != 0 {
t.Fatal("expected empty witness for fresh trie")
@@ -626,8 +611,9 @@ func TestBinaryTrieWitness(t *testing.T) {
func testAccount(t *testing.T, addr common.Address, nonce uint64, balance uint64) *BinaryTrie {
t.Helper()
tr := &BinaryTrie{
- root: NewBinaryNode(),
- tracer: trie.NewPrevalueTracer(),
+ store: NewNodeStore(),
+ tracer: trie.NewPrevalueTracer(),
+ groupDepth: MaxGroupDepth,
}
acc := &types.StateAccount{
Nonce: nonce,
@@ -649,8 +635,8 @@ func TestGetAccountNonMembershipStemRoot(t *testing.T) {
tr := testAccount(t, addr, 42, 100)
// Verify root is a StemNode (single stem inserted).
- if _, ok := tr.root.(*StemNode); !ok {
- t.Fatalf("expected StemNode root, got %T", tr.root)
+ if tr.store.Root().Kind() != KindStem {
+ t.Fatalf("expected StemNode root, got kind %d", tr.store.Root().Kind())
}
// Query a completely different address — must return nil.
@@ -680,8 +666,9 @@ func TestGetAccountNonMembershipStemRoot(t *testing.T) {
// address returns nil when the trie root is an InternalNode (multi-account trie).
func TestGetAccountNonMembershipInternalRoot(t *testing.T) {
tr := &BinaryTrie{
- root: NewBinaryNode(),
- tracer: trie.NewPrevalueTracer(),
+ store: NewNodeStore(),
+ tracer: trie.NewPrevalueTracer(),
+ groupDepth: MaxGroupDepth,
}
// Insert two accounts whose binary tree keys have different first bits
@@ -700,8 +687,8 @@ func TestGetAccountNonMembershipInternalRoot(t *testing.T) {
}
// Verify root is an InternalNode.
- if _, ok := tr.root.(*InternalNode); !ok {
- t.Fatalf("expected InternalNode root, got %T", tr.root)
+ if tr.store.Root().Kind() != KindInternal {
+ t.Fatalf("expected InternalNode root, got kind %d", tr.store.Root().Kind())
}
// Query a non-existent address — must return nil.
@@ -723,8 +710,8 @@ func TestGetStorageNonMembershipStemRoot(t *testing.T) {
tr := testAccount(t, addr, 1, 100)
// Verify root is a StemNode.
- if _, ok := tr.root.(*StemNode); !ok {
- t.Fatalf("expected StemNode root, got %T", tr.root)
+ if tr.store.Root().Kind() != KindStem {
+ t.Fatalf("expected StemNode root, got kind %d", tr.store.Root().Kind())
}
// Query storage for a different address — must return nil, not panic.
@@ -743,8 +730,9 @@ func TestGetStorageNonMembershipStemRoot(t *testing.T) {
// non-existent address returns nil when the root is an InternalNode.
func TestGetStorageNonMembershipInternalRoot(t *testing.T) {
tr := &BinaryTrie{
- root: NewBinaryNode(),
- tracer: trie.NewPrevalueTracer(),
+ store: NewNodeStore(),
+ tracer: trie.NewPrevalueTracer(),
+ groupDepth: MaxGroupDepth,
}
addr := common.HexToAddress("0x1234567890abcdef1234567890abcdef12345678")
@@ -765,8 +753,8 @@ func TestGetStorageNonMembershipInternalRoot(t *testing.T) {
t.Fatalf("UpdateStorage error: %v", err)
}
- if _, ok := tr.root.(*InternalNode); !ok {
- t.Fatalf("expected InternalNode root, got %T", tr.root)
+ if tr.store.Root().Kind() != KindInternal {
+ t.Fatalf("expected InternalNode root, got kind %d", tr.store.Root().Kind())
}
// Query storage for a non-existent address — must return nil.
diff --git a/triedb/pathdb/database.go b/triedb/pathdb/database.go
index 04c76cfd53..98975d7fa5 100644
--- a/triedb/pathdb/database.go
+++ b/triedb/pathdb/database.go
@@ -102,11 +102,7 @@ func binaryNodeHasher(blob []byte) (common.Hash, error) {
if len(blob) == 0 {
return types.EmptyBinaryHash, nil
}
- n, err := bintrie.DeserializeNode(blob, 0)
- if err != nil {
- return common.Hash{}, err
- }
- return n.Hash(), nil
+ return bintrie.DeserializeAndHash(blob, 0)
}
// Database is a multiple-layered structure for maintaining in-memory states