diff --git a/trie/bintrie/binary_node.go b/trie/bintrie/binary_node.go
index 8905a82285..e7f57d45a2 100644
--- a/trie/bintrie/binary_node.go
+++ b/trie/bintrie/binary_node.go
@@ -16,140 +16,32 @@
package bintrie
-import (
- "errors"
-
- "github.com/ethereum/go-ethereum/common"
-)
-
-type (
- NodeFlushFn func([]byte, BinaryNode)
- NodeResolverFn func([]byte, common.Hash) ([]byte, error)
-)
+import "github.com/ethereum/go-ethereum/common"
// zero is the zero value for a 32-byte array.
var zero [32]byte
const (
- StemNodeWidth = 256 // Number of child per leaf node
- StemSize = 31 // Number of bytes to travel before reaching a group of leaves
- NodeTypeBytes = 1 // Size of node type prefix in serialization
- HashSize = 32 // Size of a hash in bytes
- BitmapSize = 32 // Size of the bitmap in a stem node
+ StemNodeWidth = 256 // Number of children per leaf node
+ StemSize = 31 // Number of bytes to travel before reaching a group of leaves
+ NodeTypeBytes = 1 // Size of node type prefix in serialization
+ HashSize = 32 // Size of a hash in bytes
+ StemBitmapSize = 32 // Size of the bitmap in a stem node (256 values = 32 bytes)
)
const (
- nodeTypeStem = iota + 1 // Stem node, contains a stem and a bitmap of values
+ nodeTypeStem = iota + 1
nodeTypeInternal
)
-// BinaryNode is an interface for a binary trie node.
-type BinaryNode interface {
- Get([]byte, NodeResolverFn) ([]byte, error)
- Insert([]byte, []byte, NodeResolverFn, int) (BinaryNode, error)
- Copy() BinaryNode
- Hash() common.Hash
- GetValuesAtStem([]byte, NodeResolverFn) ([][]byte, error)
- InsertValuesAtStem([]byte, [][]byte, NodeResolverFn, int) (BinaryNode, error)
- CollectNodes([]byte, NodeFlushFn) error
-
- toDot(parent, path string) string
- GetHeight() int
-}
-
-// SerializeNode serializes a binary trie node into a byte slice.
-func SerializeNode(node BinaryNode) []byte {
- switch n := (node).(type) {
- case *InternalNode:
- // InternalNode: 1 byte type + 32 bytes left hash + 32 bytes right hash
- var serialized [NodeTypeBytes + HashSize + HashSize]byte
- serialized[0] = nodeTypeInternal
- copy(serialized[1:33], n.left.Hash().Bytes())
- copy(serialized[33:65], n.right.Hash().Bytes())
- return serialized[:]
- case *StemNode:
- // StemNode: 1 byte type + 31 bytes stem + 32 bytes bitmap + 256*32 bytes values
- var serialized [NodeTypeBytes + StemSize + BitmapSize + StemNodeWidth*HashSize]byte
- serialized[0] = nodeTypeStem
- copy(serialized[NodeTypeBytes:NodeTypeBytes+StemSize], n.Stem)
- bitmap := serialized[NodeTypeBytes+StemSize : NodeTypeBytes+StemSize+BitmapSize]
- offset := NodeTypeBytes + StemSize + BitmapSize
- for i, v := range n.Values {
- if v != nil {
- bitmap[i/8] |= 1 << (7 - (i % 8))
- copy(serialized[offset:offset+HashSize], v)
- offset += HashSize
- }
- }
- // Only return the actual data, not the entire array
- return serialized[:offset]
- default:
- panic("invalid node type")
+// DeserializeAndHash deserializes a node from bytes and returns its hash.
+// This is a convenience function for external callers that need to compute
+// the hash of a serialized node without maintaining a nodeStore.
+func DeserializeAndHash(blob []byte, depth int) (common.Hash, error) {
+ s := newNodeStore()
+ ref, err := s.deserializeNode(blob, depth)
+ if err != nil {
+ return common.Hash{}, err
}
-}
-
-var invalidSerializedLength = errors.New("invalid serialized node length")
-
-// DeserializeNode deserializes a binary trie node from a byte slice. The
-// hash will be recomputed from the deserialized data.
-func DeserializeNode(serialized []byte, depth int) (BinaryNode, error) {
- return deserializeNode(serialized, depth, common.Hash{}, true, true)
-}
-
-// DeserializeNodeWithHash deserializes a binary trie node from a byte slice, using the provided hash.
-func DeserializeNodeWithHash(serialized []byte, depth int, hn common.Hash) (BinaryNode, error) {
- return deserializeNode(serialized, depth, hn, false, false)
-}
-
-func deserializeNode(serialized []byte, depth int, hn common.Hash, mustRecompute, dirty bool) (BinaryNode, error) {
- if len(serialized) == 0 {
- return Empty{}, nil
- }
-
- switch serialized[0] {
- case nodeTypeInternal:
- if len(serialized) != 65 {
- return nil, invalidSerializedLength
- }
- return &InternalNode{
- depth: depth,
- left: HashedNode(common.BytesToHash(serialized[1:33])),
- right: HashedNode(common.BytesToHash(serialized[33:65])),
- hash: hn,
- mustRecompute: mustRecompute,
- dirty: dirty,
- }, nil
- case nodeTypeStem:
- if len(serialized) < 64 {
- return nil, invalidSerializedLength
- }
- var values [StemNodeWidth][]byte
- bitmap := serialized[NodeTypeBytes+StemSize : NodeTypeBytes+StemSize+BitmapSize]
- offset := NodeTypeBytes + StemSize + BitmapSize
-
- for i := range StemNodeWidth {
- if bitmap[i/8]>>(7-(i%8))&1 == 1 {
- if len(serialized) < offset+HashSize {
- return nil, invalidSerializedLength
- }
- values[i] = serialized[offset : offset+HashSize]
- offset += HashSize
- }
- }
- return &StemNode{
- Stem: serialized[NodeTypeBytes : NodeTypeBytes+StemSize],
- Values: values[:],
- depth: depth,
- hash: hn,
- mustRecompute: mustRecompute,
- dirty: dirty,
- }, nil
- default:
- return nil, errors.New("invalid node type")
- }
-}
-
-// ToDot converts the binary trie to a DOT language representation. Useful for debugging.
-func ToDot(root BinaryNode) string {
- return root.toDot("", "")
+ return s.computeHash(ref), nil
}
diff --git a/trie/bintrie/binary_node_test.go b/trie/bintrie/binary_node_test.go
index 242743ba53..12ac199903 100644
--- a/trie/bintrie/binary_node_test.go
+++ b/trie/bintrie/binary_node_test.go
@@ -23,79 +23,100 @@ import (
"github.com/ethereum/go-ethereum/common"
)
-// TestSerializeDeserializeInternalNode tests serialization and deserialization of InternalNode
+// TestSerializeDeserializeInternalNode tests flat 65-byte serialization and
+// deserialization of InternalNode through nodeStore.
func TestSerializeDeserializeInternalNode(t *testing.T) {
- // Create an internal node with two hashed children
leftHash := common.HexToHash("0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef")
rightHash := common.HexToHash("0xfedcba0987654321fedcba0987654321fedcba0987654321fedcba0987654321")
- node := &InternalNode{
- depth: 5,
- left: HashedNode(leftHash),
- right: HashedNode(rightHash),
- }
+ s := newNodeStore()
+ leftRef := s.newHashedRef(leftHash)
+ rightRef := s.newHashedRef(rightHash)
- // Serialize the node
- serialized := SerializeNode(node)
+ rootRef := s.newInternalRef(0)
+ rootNode := s.getInternal(rootRef.Index())
+ rootNode.left = leftRef
+ rootNode.right = rightRef
+ s.root = rootRef
- // Check the serialized format
+ // Serialize the node — flat 65-byte format
+ serialized := s.serializeNode(rootRef)
+
+ // Check the serialized format: [type(1)][leftHash(32)][rightHash(32)]
if serialized[0] != nodeTypeInternal {
t.Errorf("Expected type byte to be %d, got %d", nodeTypeInternal, serialized[0])
}
- if len(serialized) != 65 {
- t.Errorf("Expected serialized length to be 65, got %d", len(serialized))
+ expectedLen := NodeTypeBytes + 2*HashSize // 1 + 64 = 65
+ if len(serialized) != expectedLen {
+ t.Errorf("Expected serialized length to be %d, got %d", expectedLen, len(serialized))
}
- // Deserialize the node
- deserialized, err := DeserializeNode(serialized, 5)
+ // Check that left and right hashes are embedded directly
+ if !bytes.Equal(serialized[NodeTypeBytes:NodeTypeBytes+HashSize], leftHash[:]) {
+ t.Error("Left hash not found at expected position")
+ }
+ if !bytes.Equal(serialized[NodeTypeBytes+HashSize:], rightHash[:]) {
+ t.Error("Right hash not found at expected position")
+ }
+
+ // Deserialize into a new store
+ ds := newNodeStore()
+ deserialized, err := ds.deserializeNode(serialized, 0)
if err != nil {
t.Fatalf("Failed to deserialize node: %v", err)
}
- // Check that it's an internal node
- internalNode, ok := deserialized.(*InternalNode)
- if !ok {
- t.Fatalf("Expected InternalNode, got %T", deserialized)
+ // Root should be an InternalNode
+ if deserialized.Kind() != kindInternal {
+ t.Fatalf("Expected kindInternal, got kind %d", deserialized.Kind())
}
- // Check the depth
- if internalNode.depth != 5 {
- t.Errorf("Expected depth 5, got %d", internalNode.depth)
+ internalNode := ds.getInternal(deserialized.Index())
+ if internalNode.depth != 0 {
+ t.Errorf("Expected depth 0, got %d", internalNode.depth)
}
- // Check the left and right hashes
- if internalNode.left.Hash() != leftHash {
- t.Errorf("Left hash mismatch: expected %x, got %x", leftHash, internalNode.left.Hash())
+ // Left child should be a HashedNode with the correct hash
+ if internalNode.left.Kind() != kindHashed {
+ t.Fatalf("Expected left child to be kindHashed, got %d", internalNode.left.Kind())
+ }
+ if ds.computeHash(internalNode.left) != leftHash {
+ t.Errorf("Left hash mismatch: expected %x, got %x", leftHash, ds.computeHash(internalNode.left))
}
- if internalNode.right.Hash() != rightHash {
- t.Errorf("Right hash mismatch: expected %x, got %x", rightHash, internalNode.right.Hash())
+ // Right child should be a HashedNode with the correct hash
+ if internalNode.right.Kind() != kindHashed {
+ t.Fatalf("Expected right child to be kindHashed, got %d", internalNode.right.Kind())
+ }
+ if ds.computeHash(internalNode.right) != rightHash {
+ t.Errorf("Right hash mismatch: expected %x, got %x", rightHash, ds.computeHash(internalNode.right))
}
}
-// TestSerializeDeserializeStemNode tests serialization and deserialization of StemNode
+// TestSerializeDeserializeStemNode tests serialization and deserialization of StemNode through nodeStore.
func TestSerializeDeserializeStemNode(t *testing.T) {
- // Create a stem node with some values
stem := make([]byte, StemSize)
for i := range stem {
stem[i] = byte(i)
}
var values [StemNodeWidth][]byte
- // Add some values at different indices
values[0] = common.HexToHash("0x0101010101010101010101010101010101010101010101010101010101010101").Bytes()
values[10] = common.HexToHash("0x0202020202020202020202020202020202020202020202020202020202020202").Bytes()
values[255] = common.HexToHash("0x0303030303030303030303030303030303030303030303030303030303030303").Bytes()
- node := &StemNode{
- Stem: stem,
- Values: values[:],
- depth: 10,
+ s := newNodeStore()
+ ref := s.newStemRef(stem, 10)
+ sn := s.getStem(ref.Index())
+ for i, v := range values {
+ if v != nil {
+ sn.setValue(byte(i), v)
+ }
}
// Serialize the node
- serialized := SerializeNode(node)
+ serialized := s.serializeNode(ref)
// Check the serialized format
if serialized[0] != nodeTypeStem {
@@ -107,31 +128,32 @@ func TestSerializeDeserializeStemNode(t *testing.T) {
t.Errorf("Stem mismatch in serialized data")
}
- // Deserialize the node
- deserialized, err := DeserializeNode(serialized, 10)
+ // Deserialize into a new store
+ ds := newNodeStore()
+ deserializedRef, err := ds.deserializeNode(serialized, 10)
if err != nil {
t.Fatalf("Failed to deserialize node: %v", err)
}
- // Check that it's a stem node
- stemNode, ok := deserialized.(*StemNode)
- if !ok {
- t.Fatalf("Expected StemNode, got %T", deserialized)
+ if deserializedRef.Kind() != kindStem {
+ t.Fatalf("Expected kindStem, got kind %d", deserializedRef.Kind())
}
+ stemNode := ds.getStem(deserializedRef.Index())
+
// Check the stem
- if !bytes.Equal(stemNode.Stem, stem) {
+ if !bytes.Equal(stemNode.Stem[:], stem) {
t.Errorf("Stem mismatch after deserialization")
}
// Check the values
- if !bytes.Equal(stemNode.Values[0], values[0]) {
+ if !bytes.Equal(stemNode.getValue(0), values[0]) {
t.Errorf("Value at index 0 mismatch")
}
- if !bytes.Equal(stemNode.Values[10], values[10]) {
+ if !bytes.Equal(stemNode.getValue(10), values[10]) {
t.Errorf("Value at index 10 mismatch")
}
- if !bytes.Equal(stemNode.Values[255], values[255]) {
+ if !bytes.Equal(stemNode.getValue(255), values[255]) {
t.Errorf("Value at index 255 mismatch")
}
@@ -140,43 +162,43 @@ func TestSerializeDeserializeStemNode(t *testing.T) {
if i == 0 || i == 10 || i == 255 {
continue
}
- if stemNode.Values[i] != nil {
- t.Errorf("Expected nil value at index %d, got %x", i, stemNode.Values[i])
+ if stemNode.hasValue(byte(i)) {
+ t.Errorf("Expected no value at index %d, got %x", i, stemNode.getValue(byte(i)))
}
}
}
-// TestDeserializeEmptyNode tests deserialization of empty node
+// TestDeserializeEmptyNode tests deserialization of empty node.
func TestDeserializeEmptyNode(t *testing.T) {
- // Empty byte slice should deserialize to Empty node
- deserialized, err := DeserializeNode([]byte{}, 0)
+ s := newNodeStore()
+ deserialized, err := s.deserializeNode([]byte{}, 0)
if err != nil {
t.Fatalf("Failed to deserialize empty node: %v", err)
}
- _, ok := deserialized.(Empty)
- if !ok {
- t.Fatalf("Expected Empty node, got %T", deserialized)
+ if !deserialized.IsEmpty() {
+ t.Fatalf("Expected emptyRef, got kind %d", deserialized.Kind())
}
}
-// TestDeserializeInvalidType tests deserialization with invalid type byte
+// TestDeserializeInvalidType tests deserialization with invalid type byte.
func TestDeserializeInvalidType(t *testing.T) {
- // Create invalid serialized data with unknown type byte
+ s := newNodeStore()
invalidData := []byte{99, 0, 0, 0} // Type byte 99 is invalid
- _, err := DeserializeNode(invalidData, 0)
+ _, err := s.deserializeNode(invalidData, 0)
if err == nil {
t.Fatal("Expected error for invalid type byte, got nil")
}
}
-// TestDeserializeInvalidLength tests deserialization with invalid data length
+// TestDeserializeInvalidLength tests deserialization with invalid data length.
func TestDeserializeInvalidLength(t *testing.T) {
- // InternalNode with type byte 1 but wrong length
- invalidData := []byte{nodeTypeInternal, 0, 0} // Too short for internal node
+ s := newNodeStore()
+ // InternalNode with valid type byte but wrong length (needs exactly 65 bytes)
+ invalidData := []byte{nodeTypeInternal, 0, 0, 0}
- _, err := DeserializeNode(invalidData, 0)
+ _, err := s.deserializeNode(invalidData, 0)
if err == nil {
t.Fatal("Expected error for invalid data length, got nil")
}
@@ -186,7 +208,7 @@ func TestDeserializeInvalidLength(t *testing.T) {
}
}
-// TestKeyToPath tests the keyToPath function
+// TestKeyToPath tests the keyToPath function.
func TestKeyToPath(t *testing.T) {
tests := []struct {
name string
@@ -218,14 +240,14 @@ func TestKeyToPath(t *testing.T) {
},
{
name: "max valid depth",
- depth: StemSize * 8,
+ depth: StemSize*8 - 1,
key: make([]byte, HashSize),
- expected: make([]byte, StemSize*8+1),
+ expected: make([]byte, StemSize*8),
wantErr: false,
},
{
name: "depth too large",
- depth: StemSize*8 + 1,
+ depth: StemSize * 8,
key: make([]byte, HashSize),
wantErr: true,
},
diff --git a/trie/bintrie/empty.go b/trie/bintrie/empty.go
deleted file mode 100644
index c47e284dac..0000000000
--- a/trie/bintrie/empty.go
+++ /dev/null
@@ -1,76 +0,0 @@
-// Copyright 2025 go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see .
-
-package bintrie
-
-import (
- "slices"
-
- "github.com/ethereum/go-ethereum/common"
-)
-
-type Empty struct{}
-
-func (e Empty) Get(_ []byte, _ NodeResolverFn) ([]byte, error) {
- return nil, nil
-}
-
-func (e Empty) Insert(key []byte, value []byte, _ NodeResolverFn, depth int) (BinaryNode, error) {
- var values [256][]byte
- values[key[31]] = value
- return &StemNode{
- Stem: slices.Clone(key[:31]),
- Values: values[:],
- depth: depth,
- mustRecompute: true,
- dirty: true,
- }, nil
-}
-
-func (e Empty) Copy() BinaryNode {
- return Empty{}
-}
-
-func (e Empty) Hash() common.Hash {
- return common.Hash{}
-}
-
-func (e Empty) GetValuesAtStem(_ []byte, _ NodeResolverFn) ([][]byte, error) {
- var values [256][]byte
- return values[:], nil
-}
-
-func (e Empty) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolverFn, depth int) (BinaryNode, error) {
- return &StemNode{
- Stem: slices.Clone(key[:31]),
- Values: values,
- depth: depth,
- mustRecompute: true,
- dirty: true,
- }, nil
-}
-
-func (e Empty) CollectNodes(_ []byte, _ NodeFlushFn) error {
- return nil
-}
-
-func (e Empty) toDot(parent string, path string) string {
- return ""
-}
-
-func (e Empty) GetHeight() int {
- return 0
-}
diff --git a/trie/bintrie/empty_test.go b/trie/bintrie/empty_test.go
deleted file mode 100644
index 4da1ed15a0..0000000000
--- a/trie/bintrie/empty_test.go
+++ /dev/null
@@ -1,264 +0,0 @@
-// Copyright 2025 go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see .
-
-package bintrie
-
-import (
- "bytes"
- "testing"
-
- "github.com/ethereum/go-ethereum/common"
-)
-
-// TestEmptyGet tests the Get method
-func TestEmptyGet(t *testing.T) {
- node := Empty{}
-
- key := make([]byte, 32)
- value, err := node.Get(key, nil)
- if err != nil {
- t.Fatalf("Unexpected error: %v", err)
- }
-
- if value != nil {
- t.Errorf("Expected nil value from empty node, got %x", value)
- }
-}
-
-// TestEmptyInsert tests the Insert method
-func TestEmptyInsert(t *testing.T) {
- node := Empty{}
-
- key := make([]byte, 32)
- key[0] = 0x12
- key[31] = 0x34
- value := common.HexToHash("0xabcd").Bytes()
-
- newNode, err := node.Insert(key, value, nil, 0)
- if err != nil {
- t.Fatalf("Failed to insert: %v", err)
- }
-
- // Should create a StemNode
- stemNode, ok := newNode.(*StemNode)
- if !ok {
- t.Fatalf("Expected StemNode, got %T", newNode)
- }
-
- // Check the stem (first 31 bytes of key)
- if !bytes.Equal(stemNode.Stem, key[:31]) {
- t.Errorf("Stem mismatch: expected %x, got %x", key[:31], stemNode.Stem)
- }
-
- // Check the value at the correct index (last byte of key)
- if !bytes.Equal(stemNode.Values[key[31]], value) {
- t.Errorf("Value mismatch at index %d: expected %x, got %x", key[31], value, stemNode.Values[key[31]])
- }
-
- // Check that other values are nil
- for i := 0; i < 256; i++ {
- if i != int(key[31]) && stemNode.Values[i] != nil {
- t.Errorf("Expected nil value at index %d, got %x", i, stemNode.Values[i])
- }
- }
-}
-
-// TestEmptyCopy tests the Copy method
-func TestEmptyCopy(t *testing.T) {
- node := Empty{}
-
- copied := node.Copy()
- copiedEmpty, ok := copied.(Empty)
- if !ok {
- t.Fatalf("Expected Empty, got %T", copied)
- }
-
- // Both should be empty
- if node != copiedEmpty {
- // Empty is a zero-value struct, so copies should be equal
- t.Errorf("Empty nodes should be equal")
- }
-}
-
-// TestEmptyHash tests the Hash method
-func TestEmptyHash(t *testing.T) {
- node := Empty{}
-
- hash := node.Hash()
-
- // Empty node should have zero hash
- if hash != (common.Hash{}) {
- t.Errorf("Expected zero hash for empty node, got %x", hash)
- }
-}
-
-// TestEmptyGetValuesAtStem tests the GetValuesAtStem method
-func TestEmptyGetValuesAtStem(t *testing.T) {
- node := Empty{}
-
- stem := make([]byte, 31)
- values, err := node.GetValuesAtStem(stem, nil)
- if err != nil {
- t.Fatalf("Unexpected error: %v", err)
- }
-
- // Should return an array of 256 nil values
- if len(values) != 256 {
- t.Errorf("Expected 256 values, got %d", len(values))
- }
-
- for i, v := range values {
- if v != nil {
- t.Errorf("Expected nil value at index %d, got %x", i, v)
- }
- }
-}
-
-// TestEmptyInsertValuesAtStem tests the InsertValuesAtStem method
-func TestEmptyInsertValuesAtStem(t *testing.T) {
- node := Empty{}
-
- stem := make([]byte, 31)
- stem[0] = 0x42
-
- var values [256][]byte
- values[0] = common.HexToHash("0x0101").Bytes()
- values[10] = common.HexToHash("0x0202").Bytes()
- values[255] = common.HexToHash("0x0303").Bytes()
-
- newNode, err := node.InsertValuesAtStem(stem, values[:], nil, 5)
- if err != nil {
- t.Fatalf("Failed to insert values: %v", err)
- }
-
- // Should create a StemNode
- stemNode, ok := newNode.(*StemNode)
- if !ok {
- t.Fatalf("Expected StemNode, got %T", newNode)
- }
-
- // Check the stem
- if !bytes.Equal(stemNode.Stem, stem) {
- t.Errorf("Stem mismatch: expected %x, got %x", stem, stemNode.Stem)
- }
-
- // Check the depth
- if stemNode.depth != 5 {
- t.Errorf("Depth mismatch: expected 5, got %d", stemNode.depth)
- }
-
- // Check the values
- if !bytes.Equal(stemNode.Values[0], values[0]) {
- t.Error("Value at index 0 mismatch")
- }
- if !bytes.Equal(stemNode.Values[10], values[10]) {
- t.Error("Value at index 10 mismatch")
- }
- if !bytes.Equal(stemNode.Values[255], values[255]) {
- t.Error("Value at index 255 mismatch")
- }
-
- // Check that values is the same slice (not a copy)
- if &stemNode.Values[0] != &values[0] {
- t.Error("Expected values to be the same slice reference")
- }
-}
-
-// TestEmptyCollectNodes tests the CollectNodes method
-func TestEmptyCollectNodes(t *testing.T) {
- node := Empty{}
-
- var collected []BinaryNode
- flushFn := func(path []byte, n BinaryNode) {
- collected = append(collected, n)
- }
-
- err := node.CollectNodes([]byte{0, 1, 0}, flushFn)
- if err != nil {
- t.Fatalf("Unexpected error: %v", err)
- }
-
- // Should not collect anything for empty node
- if len(collected) != 0 {
- t.Errorf("Expected no collected nodes for empty, got %d", len(collected))
- }
-}
-
-// TestEmptyToDot tests the toDot method
-func TestEmptyToDot(t *testing.T) {
- node := Empty{}
-
- dot := node.toDot("parent", "010")
-
- // Should return empty string for empty node
- if dot != "" {
- t.Errorf("Expected empty string for empty node toDot, got %s", dot)
- }
-}
-
-// TestEmptyGetHeight tests the GetHeight method
-func TestEmptyGetHeight(t *testing.T) {
- node := Empty{}
-
- height := node.GetHeight()
-
- // Empty node should have height 0
- if height != 0 {
- t.Errorf("Expected height 0 for empty node, got %d", height)
- }
-}
-
-// TestEmptyInsertMarksDirty verifies that a StemNode produced by Empty.Insert
-// is marked dirty. Without this, CollectNodes would skip the freshly created
-// stem and its blob would never reach disk, producing "missing trie node"
-// errors on subsequent reads.
-func TestEmptyInsertMarksDirty(t *testing.T) {
- key := make([]byte, 32)
- key[0] = 0xaa
- val := make([]byte, 32)
- val[0] = 0xbb
- n, err := Empty{}.Insert(key, val, nil, 0)
- if err != nil {
- t.Fatalf("Insert: %v", err)
- }
- sn, ok := n.(*StemNode)
- if !ok {
- t.Fatalf("expected *StemNode, got %T", n)
- }
- if !sn.dirty {
- t.Fatalf("stem produced by Empty.Insert must have dirty=true")
- }
-}
-
-// TestEmptyInsertValuesAtStemMarksDirty is the analogous guard for the
-// bulk-insert entry point. Fresh stems created here must be dirty.
-func TestEmptyInsertValuesAtStemMarksDirty(t *testing.T) {
- key := make([]byte, 32)
- key[0] = 0xcc
- values := make([][]byte, 256)
- values[0] = make([]byte, 32)
- n, err := Empty{}.InsertValuesAtStem(key, values, nil, 3)
- if err != nil {
- t.Fatalf("InsertValuesAtStem: %v", err)
- }
- sn, ok := n.(*StemNode)
- if !ok {
- t.Fatalf("expected *StemNode, got %T", n)
- }
- if !sn.dirty {
- t.Fatalf("stem produced by Empty.InsertValuesAtStem must have dirty=true")
- }
-}
diff --git a/trie/bintrie/hashed_node.go b/trie/bintrie/hashed_node.go
index e44c6d1e8a..b176df079b 100644
--- a/trie/bintrie/hashed_node.go
+++ b/trie/bintrie/hashed_node.go
@@ -16,75 +16,10 @@
package bintrie
-import (
- "errors"
- "fmt"
-
- "github.com/ethereum/go-ethereum/common"
-)
+import "github.com/ethereum/go-ethereum/common"
+// HashedNode is an unresolved node — only its hash is known.
type HashedNode common.Hash
-func (h HashedNode) Get(_ []byte, _ NodeResolverFn) ([]byte, error) {
- panic("not implemented") // TODO: Implement
-}
-
-func (h HashedNode) Insert(key []byte, value []byte, resolver NodeResolverFn, depth int) (BinaryNode, error) {
- return nil, errors.New("insert not implemented for hashed node")
-}
-
-func (h HashedNode) Copy() BinaryNode {
- nh := common.Hash(h)
- return HashedNode(nh)
-}
-
-func (h HashedNode) Hash() common.Hash {
- return common.Hash(h)
-}
-
-func (h HashedNode) GetValuesAtStem(_ []byte, _ NodeResolverFn) ([][]byte, error) {
- return nil, errors.New("attempted to get values from an unresolved node")
-}
-
-func (h HashedNode) InsertValuesAtStem(stem []byte, values [][]byte, resolver NodeResolverFn, depth int) (BinaryNode, error) {
- // Step 1: Generate the path for this node's position in the tree
- path, err := keyToPath(depth, stem)
- if err != nil {
- return nil, fmt.Errorf("InsertValuesAtStem path generation error: %w", err)
- }
-
- if resolver == nil {
- return nil, errors.New("InsertValuesAtStem resolve error: resolver is nil")
- }
-
- // Step 2: Resolve the hashed node to get the actual node data
- data, err := resolver(path, common.Hash(h))
- if err != nil {
- return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
- }
-
- // Step 3: Deserialize the resolved data into a concrete node
- node, err := DeserializeNodeWithHash(data, depth, common.Hash(h))
- if err != nil {
- return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
- }
-
- // Step 4: Call InsertValuesAtStem on the resolved concrete node
- return node.InsertValuesAtStem(stem, values, resolver, depth)
-}
-
-func (h HashedNode) toDot(parent string, path string) string {
- me := fmt.Sprintf("hash%s", path)
- ret := fmt.Sprintf("%s [label=\"%x\"]\n", me, h)
- ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me)
- return ret
-}
-
-func (h HashedNode) CollectNodes([]byte, NodeFlushFn) error {
- // HashedNodes are already persisted in the database and don't need to be collected.
- return nil
-}
-
-func (h HashedNode) GetHeight() int {
- panic("tried to get the height of a hashed node, this is a bug")
-}
+// Hash returns the node's hash.
+func (h HashedNode) Hash() common.Hash { return common.Hash(h) }
diff --git a/trie/bintrie/hashed_node_test.go b/trie/bintrie/hashed_node_test.go
index f9e6984888..ae77b7c570 100644
--- a/trie/bintrie/hashed_node_test.go
+++ b/trie/bintrie/hashed_node_test.go
@@ -18,180 +18,137 @@ package bintrie
import (
"bytes"
+ "errors"
"testing"
"github.com/ethereum/go-ethereum/common"
)
-// TestHashedNodeHash tests the Hash method
+// TestHashedNodeHash tests the Hash method via nodeStore.
func TestHashedNodeHash(t *testing.T) {
hash := common.HexToHash("0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef")
- node := HashedNode(hash)
+ s := newNodeStore()
+ ref := s.newHashedRef(hash)
- // Hash should return the stored hash
- if node.Hash() != hash {
- t.Errorf("Hash mismatch: expected %x, got %x", hash, node.Hash())
+ if s.computeHash(ref) != hash {
+ t.Errorf("Hash mismatch: expected %x, got %x", hash, s.computeHash(ref))
}
}
-// TestHashedNodeCopy tests the Copy method
+// TestHashedNodeCopy tests the Copy method via nodeStore.
func TestHashedNodeCopy(t *testing.T) {
hash := common.HexToHash("0xabcdef")
- node := HashedNode(hash)
+ s := newNodeStore()
+ ref := s.newHashedRef(hash)
+ s.root = ref
- copied := node.Copy()
- copiedHash, ok := copied.(HashedNode)
- if !ok {
- t.Fatalf("Expected HashedNode, got %T", copied)
- }
+ ns := s.Copy()
+ copiedHash := ns.computeHash(ns.root)
- // Hash should be the same
- if common.Hash(copiedHash) != hash {
+ if copiedHash != hash {
t.Errorf("Hash mismatch after copy: expected %x, got %x", hash, copiedHash)
}
-
- // But should be a different object
- if &node == &copiedHash {
- t.Error("Copy returned same object reference")
- }
}
-// TestHashedNodeInsert tests that Insert returns an error
-func TestHashedNodeInsert(t *testing.T) {
- node := HashedNode(common.HexToHash("0x1234"))
-
- key := make([]byte, HashSize)
- value := make([]byte, HashSize)
-
- _, err := node.Insert(key, value, nil, 0)
- if err == nil {
- t.Fatal("Expected error for Insert on HashedNode")
- }
-
- if err.Error() != "insert not implemented for hashed node" {
- t.Errorf("Unexpected error message: %v", err)
- }
-}
-
-// TestHashedNodeGetValuesAtStem tests that GetValuesAtStem returns an error
-func TestHashedNodeGetValuesAtStem(t *testing.T) {
- node := HashedNode(common.HexToHash("0x1234"))
-
- stem := make([]byte, StemSize)
- _, err := node.GetValuesAtStem(stem, nil)
- if err == nil {
- t.Fatal("Expected error for GetValuesAtStem on HashedNode")
- }
-
- if err.Error() != "attempted to get values from an unresolved node" {
- t.Errorf("Unexpected error message: %v", err)
- }
-}
-
-// TestHashedNodeInsertValuesAtStem tests that InsertValuesAtStem returns an error
+// TestHashedNodeInsertValuesAtStem tests InsertValuesAtStem resolution via nodeStore.
func TestHashedNodeInsertValuesAtStem(t *testing.T) {
- node := HashedNode(common.HexToHash("0x1234"))
+ // Test 1: nil resolver should return an error
+ s := newNodeStore()
+ hashedRef := s.newHashedRef(common.HexToHash("0x1234"))
+ s.root = hashedRef
stem := make([]byte, StemSize)
values := make([][]byte, StemNodeWidth)
- // Test 1: nil resolver should return an error
- _, err := node.InsertValuesAtStem(stem, values, nil, 0)
+ err := s.InsertValuesAtStem(stem, values, nil)
if err == nil {
- t.Fatal("Expected error for InsertValuesAtStem on HashedNode with nil resolver")
- }
-
- if err.Error() != "InsertValuesAtStem resolve error: resolver is nil" {
- t.Errorf("Unexpected error message: %v", err)
+ t.Fatal("Expected error for InsertValuesAtStem with nil resolver")
}
// Test 2: mock resolver returning invalid data should return deserialization error
mockResolver := func(path []byte, hash common.Hash) ([]byte, error) {
- // Return invalid/nonsense data that cannot be deserialized
return []byte{0xff, 0xff, 0xff}, nil
}
- _, err = node.InsertValuesAtStem(stem, values, mockResolver, 0)
- if err == nil {
- t.Fatal("Expected error for InsertValuesAtStem on HashedNode with invalid resolver data")
- }
+ s2 := newNodeStore()
+ hashedRef2 := s2.newHashedRef(common.HexToHash("0x1234"))
+ s2.root = hashedRef2
- expectedPrefix := "InsertValuesAtStem node deserialization error:"
- if len(err.Error()) < len(expectedPrefix) || err.Error()[:len(expectedPrefix)] != expectedPrefix {
- t.Errorf("Expected deserialization error, got: %v", err)
+ err = s2.InsertValuesAtStem(stem, values, mockResolver)
+ if err == nil {
+ t.Fatal("Expected error for InsertValuesAtStem with invalid resolver data")
}
// Test 3: mock resolver returning valid serialized node should succeed
stem = make([]byte, StemSize)
stem[0] = 0xaa
- var originalValues [StemNodeWidth][]byte
+ originalValues := make([][]byte, StemNodeWidth)
originalValues[0] = common.HexToHash("0x1111111111111111111111111111111111111111111111111111111111111111").Bytes()
originalValues[1] = common.HexToHash("0x2222222222222222222222222222222222222222222222222222222222222222").Bytes()
- originalNode := &StemNode{
- Stem: stem,
- Values: originalValues[:],
- depth: 0,
+ // Build the serialized node
+ rs := newNodeStore()
+ ref := rs.newStemRef(stem, 0)
+ sn := rs.getStem(ref.Index())
+ for i, v := range originalValues {
+ if v != nil {
+ sn.setValue(byte(i), v)
+ }
}
+ serialized := rs.serializeNode(ref)
- // Serialize the node
- serialized := SerializeNode(originalNode)
-
- // Create a mock resolver that returns the serialized node
validResolver := func(path []byte, hash common.Hash) ([]byte, error) {
return serialized, nil
}
- var newValues [StemNodeWidth][]byte
+ s3 := newNodeStore()
+ hashedRef3 := s3.newHashedRef(common.HexToHash("0x1234"))
+ s3.root = hashedRef3
+
+ newValues := make([][]byte, StemNodeWidth)
newValues[2] = common.HexToHash("0x3333333333333333333333333333333333333333333333333333333333333333").Bytes()
- resolvedNode, err := node.InsertValuesAtStem(stem, newValues[:], validResolver, 0)
+ err = s3.InsertValuesAtStem(stem, newValues, validResolver)
if err != nil {
t.Fatalf("Expected successful resolution and insertion, got error: %v", err)
}
- resultStem, ok := resolvedNode.(*StemNode)
- if !ok {
- t.Fatalf("Expected resolved node to be *StemNode, got %T", resolvedNode)
+ // Verify original values are preserved
+ retrieved, err := s3.GetValuesAtStem(stem, nil)
+ if err != nil {
+ t.Fatal(err)
}
-
- if !bytes.Equal(resultStem.Stem, stem) {
- t.Errorf("Stem mismatch: expected %x, got %x", stem, resultStem.Stem)
+ if !bytes.Equal(retrieved[0], originalValues[0]) {
+ t.Errorf("Original value at index 0 not preserved")
}
-
- // Verify the original values are preserved
- if !bytes.Equal(resultStem.Values[0], originalValues[0]) {
- t.Errorf("Original value at index 0 not preserved: expected %x, got %x", originalValues[0], resultStem.Values[0])
+ if !bytes.Equal(retrieved[1], originalValues[1]) {
+ t.Errorf("Original value at index 1 not preserved")
}
- if !bytes.Equal(resultStem.Values[1], originalValues[1]) {
- t.Errorf("Original value at index 1 not preserved: expected %x, got %x", originalValues[1], resultStem.Values[1])
- }
-
- // Verify the new value was inserted
- if !bytes.Equal(resultStem.Values[2], newValues[2]) {
- t.Errorf("New value at index 2 not inserted correctly: expected %x, got %x", newValues[2], resultStem.Values[2])
+ if !bytes.Equal(retrieved[2], newValues[2]) {
+ t.Errorf("New value at index 2 not inserted correctly")
}
}
-// TestHashedNodeToDot tests the toDot method for visualization
-func TestHashedNodeToDot(t *testing.T) {
- hash := common.HexToHash("0x1234")
- node := HashedNode(hash)
+// TestHashedNodeGetError tests that getting through an unresolved HashedNode root returns error.
+func TestHashedNodeGetError(t *testing.T) {
+ s := newNodeStore()
+ // Create root as hashed, then try to resolve through InternalNode parent
+ rootRef := s.newInternalRef(0)
+ rootNode := s.getInternal(rootRef.Index())
+ hashedLeft := s.newHashedRef(common.HexToHash("0x1234"))
+ rootNode.left = hashedLeft
+ rootNode.right = emptyRef
+ s.root = rootRef
- dot := node.toDot("parent", "010")
+ key := make([]byte, 32) // goes left
+ key[31] = 5
- // Should contain the hash value and parent connection
- expectedHash := "hash010"
- if !contains(dot, expectedHash) {
- t.Errorf("Expected dot output to contain %s", expectedHash)
+ resolver := func(path []byte, hash common.Hash) ([]byte, error) {
+ return nil, errors.New("node not found")
}
- if !contains(dot, "parent -> hash010") {
- t.Error("Expected dot output to contain parent connection")
+ _, err := s.Get(key, resolver)
+ if err == nil {
+ t.Fatal("Expected error when resolver fails")
}
}
-
-// Helper function
-func contains(s, substr string) bool {
- return len(s) >= len(substr) && s != "" && len(substr) > 0
-}
diff --git a/trie/bintrie/internal_node.go b/trie/bintrie/internal_node.go
index 811f65bcd8..b83cb92d87 100644
--- a/trie/bintrie/internal_node.go
+++ b/trie/bintrie/internal_node.go
@@ -17,35 +17,13 @@
package bintrie
import (
- "crypto/sha256"
"errors"
- "fmt"
- "math/bits"
- "runtime"
- "sync"
"github.com/ethereum/go-ethereum/common"
)
-// parallelDepth returns the tree depth below which Hash() spawns goroutines.
-func parallelDepth() int {
- return min(bits.Len(uint(runtime.NumCPU())), 8)
-}
-
-// isDirty reports whether a BinaryNode child needs rehashing.
-func isDirty(n BinaryNode) bool {
- switch v := n.(type) {
- case *InternalNode:
- return v.mustRecompute
- case *StemNode:
- return v.mustRecompute
- default:
- return false
- }
-}
-
func keyToPath(depth int, key []byte) ([]byte, error) {
- if depth > 31*8 {
+ if depth >= 31*8 {
return nil, errors.New("node too deep")
}
path := make([]byte, 0, depth+1)
@@ -56,252 +34,12 @@ func keyToPath(depth int, key []byte) ([]byte, error) {
return path, nil
}
-// InternalNode is a binary trie internal node.
+// Invariant: dirty=false implies mustRecompute=false. Every mutation that
+// invalidates the cached hash MUST also mark the blob for re-flush.
type InternalNode struct {
- left, right BinaryNode
- depth int
-
- mustRecompute bool // true if the hash needs to be recomputed
- dirty bool // true if the node's on-disk blob is stale (needs flush)
- hash common.Hash // cached hash when mustRecompute == false
-}
-
-// GetValuesAtStem retrieves the group of values located at the given stem key.
-func (bt *InternalNode) GetValuesAtStem(stem []byte, resolver NodeResolverFn) ([][]byte, error) {
- if bt.depth > 31*8 {
- return nil, errors.New("node too deep")
- }
-
- bit := stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1
- if bit == 0 {
- if hn, ok := bt.left.(HashedNode); ok {
- path, err := keyToPath(bt.depth, stem)
- if err != nil {
- return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
- }
- data, err := resolver(path, common.Hash(hn))
- if err != nil {
- return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
- }
- node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn))
- if err != nil {
- return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err)
- }
- bt.left = node
- }
- return bt.left.GetValuesAtStem(stem, resolver)
- }
-
- if hn, ok := bt.right.(HashedNode); ok {
- path, err := keyToPath(bt.depth, stem)
- if err != nil {
- return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
- }
- data, err := resolver(path, common.Hash(hn))
- if err != nil {
- return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
- }
- node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn))
- if err != nil {
- return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err)
- }
- bt.right = node
- }
- return bt.right.GetValuesAtStem(stem, resolver)
-}
-
-// Get retrieves the value for the given key.
-func (bt *InternalNode) Get(key []byte, resolver NodeResolverFn) ([]byte, error) {
- values, err := bt.GetValuesAtStem(key[:31], resolver)
- if err != nil {
- return nil, fmt.Errorf("get error: %w", err)
- }
- if values == nil {
- return nil, nil
- }
- return values[key[31]], nil
-}
-
-// Insert inserts a new key-value pair into the trie.
-func (bt *InternalNode) Insert(key []byte, value []byte, resolver NodeResolverFn, depth int) (BinaryNode, error) {
- var values [256][]byte
- values[key[31]] = value
- return bt.InsertValuesAtStem(key[:31], values[:], resolver, depth)
-}
-
-// Copy creates a deep copy of the node.
-func (bt *InternalNode) Copy() BinaryNode {
- return &InternalNode{
- left: bt.left.Copy(),
- right: bt.right.Copy(),
- depth: bt.depth,
- mustRecompute: bt.mustRecompute,
- dirty: bt.dirty,
- hash: bt.hash,
- }
-}
-
-// Hash returns the hash of the node.
-func (bt *InternalNode) Hash() common.Hash {
- if !bt.mustRecompute {
- return bt.hash
- }
-
- // At shallow depths, parallelize when both children need rehashing:
- // hash left subtree in a goroutine, right subtree inline, then combine.
- // Skip goroutine overhead when only one child is dirty (common case
- // for narrow state updates that touch a single path through the trie).
- if bt.depth < parallelDepth() && isDirty(bt.left) && isDirty(bt.right) {
- var input [64]byte
- var lh common.Hash
- var wg sync.WaitGroup
- wg.Add(1)
- go func() {
- defer wg.Done()
- lh = bt.left.Hash()
- }()
- rh := bt.right.Hash()
- copy(input[32:], rh[:])
- wg.Wait()
- copy(input[:32], lh[:])
- bt.hash = sha256.Sum256(input[:])
- bt.mustRecompute = false
- return bt.hash
- }
-
- // Deeper nodes: sequential using pooled hasher (goroutine overhead > hash cost)
- h := newSha256()
- defer returnSha256(h)
- if bt.left != nil {
- h.Write(bt.left.Hash().Bytes())
- } else {
- h.Write(zero[:])
- }
- if bt.right != nil {
- h.Write(bt.right.Hash().Bytes())
- } else {
- h.Write(zero[:])
- }
- bt.hash = common.BytesToHash(h.Sum(nil))
- bt.mustRecompute = false
- return bt.hash
-}
-
-// InsertValuesAtStem inserts a full value group at the given stem in the internal node.
-// Already-existing values will be overwritten.
-func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolver NodeResolverFn, depth int) (BinaryNode, error) {
- var err error
- bit := stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1
- if bit == 0 {
- if bt.left == nil {
- bt.left = Empty{}
- }
-
- if hn, ok := bt.left.(HashedNode); ok {
- path, err := keyToPath(bt.depth, stem)
- if err != nil {
- return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
- }
- data, err := resolver(path, common.Hash(hn))
- if err != nil {
- return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
- }
- node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn))
- if err != nil {
- return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
- }
- bt.left = node
- }
-
- bt.left, err = bt.left.InsertValuesAtStem(stem, values, resolver, depth+1)
- bt.mustRecompute = true
- bt.dirty = true
- return bt, err
- }
-
- if bt.right == nil {
- bt.right = Empty{}
- }
-
- if hn, ok := bt.right.(HashedNode); ok {
- path, err := keyToPath(bt.depth, stem)
- if err != nil {
- return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
- }
- data, err := resolver(path, common.Hash(hn))
- if err != nil {
- return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
- }
- node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn))
- if err != nil {
- return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
- }
- bt.right = node
- }
-
- bt.right, err = bt.right.InsertValuesAtStem(stem, values, resolver, depth+1)
- bt.mustRecompute = true
- bt.dirty = true
- return bt, err
-}
-
-// CollectNodes collects all child nodes at a given path, and flushes it
-// into the provided node collector. Clean subtrees (dirty == false) are
-// skipped.
-func (bt *InternalNode) CollectNodes(path []byte, flushfn NodeFlushFn) error {
- if !bt.dirty {
- return nil
- }
- if bt.left != nil {
- var p [256]byte
- copy(p[:], path)
- childpath := p[:len(path)]
- childpath = append(childpath, 0)
- if err := bt.left.CollectNodes(childpath, flushfn); err != nil {
- return err
- }
- }
- if bt.right != nil {
- var p [256]byte
- copy(p[:], path)
- childpath := p[:len(path)]
- childpath = append(childpath, 1)
- if err := bt.right.CollectNodes(childpath, flushfn); err != nil {
- return err
- }
- }
- flushfn(path, bt)
- bt.dirty = false
- return nil
-}
-
-// GetHeight returns the height of the node.
-func (bt *InternalNode) GetHeight() int {
- var (
- leftHeight int
- rightHeight int
- )
- if bt.left != nil {
- leftHeight = bt.left.GetHeight()
- }
- if bt.right != nil {
- rightHeight = bt.right.GetHeight()
- }
- return 1 + max(leftHeight, rightHeight)
-}
-
-func (bt *InternalNode) toDot(parent, path string) string {
- me := fmt.Sprintf("internal%s", path)
- ret := fmt.Sprintf("%s [label=\"I: %x\"]\n", me, bt.Hash())
- if len(parent) > 0 {
- ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me)
- }
-
- if bt.left != nil {
- ret = fmt.Sprintf("%s%s", ret, bt.left.toDot(me, fmt.Sprintf("%s%02x", path, 0)))
- }
- if bt.right != nil {
- ret = fmt.Sprintf("%s%s", ret, bt.right.toDot(me, fmt.Sprintf("%s%02x", path, 1)))
- }
- return ret
+ left, right nodeRef
+ depth uint8
+ mustRecompute bool // hash is stale (cleared by Hash)
+ dirty bool // on-disk blob is stale (cleared by CollectNodes)
+ hash common.Hash
}
diff --git a/trie/bintrie/internal_node_test.go b/trie/bintrie/internal_node_test.go
index ddcec8085d..8d5a75de8c 100644
--- a/trie/bintrie/internal_node_test.go
+++ b/trie/bintrie/internal_node_test.go
@@ -24,35 +24,33 @@ import (
"github.com/ethereum/go-ethereum/common"
)
-// TestInternalNodeGet tests the Get method
+// TestInternalNodeGet tests the Get method via nodeStore.
func TestInternalNodeGet(t *testing.T) {
- // Create a simple tree structure
+ s := newNodeStore()
+
leftStem := make([]byte, 31)
rightStem := make([]byte, 31)
- rightStem[0] = 0x80 // First bit is 1
+ rightStem[0] = 0x80
- var leftValues, rightValues [256][]byte
+ leftValues := make([][]byte, 256)
leftValues[0] = common.HexToHash("0x0101").Bytes()
+ rightValues := make([][]byte, 256)
rightValues[0] = common.HexToHash("0x0202").Bytes()
- node := &InternalNode{
- depth: 0,
- left: &StemNode{
- Stem: leftStem,
- Values: leftValues[:],
- depth: 1,
- },
- right: &StemNode{
- Stem: rightStem,
- Values: rightValues[:],
- depth: 1,
- },
+ // Build tree: root -> left stem, right stem
+ // Insert left stem values
+ s.root = emptyRef
+ if err := s.InsertValuesAtStem(leftStem, leftValues, nil); err != nil {
+ t.Fatal(err)
+ }
+ if err := s.InsertValuesAtStem(rightStem, rightValues, nil); err != nil {
+ t.Fatal(err)
}
// Get value from left subtree
leftKey := make([]byte, 32)
leftKey[31] = 0
- value, err := node.Get(leftKey, nil)
+ value, err := s.Get(leftKey, nil)
if err != nil {
t.Fatalf("Failed to get left value: %v", err)
}
@@ -64,7 +62,7 @@ func TestInternalNodeGet(t *testing.T) {
rightKey := make([]byte, 32)
rightKey[0] = 0x80
rightKey[31] = 0
- value, err = node.Get(rightKey, nil)
+ value, err = s.Get(rightKey, nil)
if err != nil {
t.Fatalf("Failed to get right value: %v", err)
}
@@ -73,29 +71,26 @@ func TestInternalNodeGet(t *testing.T) {
}
}
-// TestInternalNodeGetWithResolver tests Get with HashedNode resolution
+// TestInternalNodeGetWithResolver tests Get with HashedNode resolution via nodeStore.
func TestInternalNodeGetWithResolver(t *testing.T) {
- // Create an internal node with a hashed child
- hashedChild := HashedNode(common.HexToHash("0x1234"))
-
- node := &InternalNode{
- depth: 0,
- left: hashedChild,
- right: Empty{},
- }
+ // Create a store with an internal node containing a hashed child
+ s := newNodeStore()
+ hashedChild := s.newHashedRef(common.HexToHash("0x1234"))
+ rootRef := s.newInternalRef(0)
+ rootNode := s.getInternal(rootRef.Index())
+ rootNode.left = hashedChild
+ rootNode.right = emptyRef
+ s.root = rootRef
// Mock resolver that returns a stem node
resolver := func(path []byte, hash common.Hash) ([]byte, error) {
- if hash == common.Hash(hashedChild) {
+ if hash == common.HexToHash("0x1234") {
+ rs := newNodeStore()
stem := make([]byte, 31)
- var values [256][]byte
- values[5] = common.HexToHash("0xabcd").Bytes()
- stemNode := &StemNode{
- Stem: stem,
- Values: values[:],
- depth: 1,
- }
- return SerializeNode(stemNode), nil
+ ref := rs.newStemRef(stem, 1)
+ sn := rs.getStem(ref.Index())
+ sn.setValue(5, common.HexToHash("0xabcd").Bytes())
+ return rs.serializeNode(ref), nil
}
return nil, errors.New("node not found")
}
@@ -103,7 +98,7 @@ func TestInternalNodeGetWithResolver(t *testing.T) {
// Get value through the hashed node
key := make([]byte, 32)
key[31] = 5
- value, err := node.Get(key, resolver)
+ value, err := s.Get(key, resolver)
if err != nil {
t.Fatalf("Failed to get value: %v", err)
}
@@ -114,179 +109,113 @@ func TestInternalNodeGetWithResolver(t *testing.T) {
}
}
-// TestInternalNodeInsert tests the Insert method
+// TestInternalNodeInsert tests the Insert method via nodeStore.
func TestInternalNodeInsert(t *testing.T) {
- // Start with an internal node with empty children
- node := &InternalNode{
- depth: 0,
- left: Empty{},
- right: Empty{},
- }
+ s := newNodeStore()
- // Insert a value into the left subtree
leftKey := make([]byte, 32)
leftKey[31] = 10
leftValue := common.HexToHash("0x0101").Bytes()
- newNode, err := node.Insert(leftKey, leftValue, nil, 0)
- if err != nil {
+ if err := s.Insert(leftKey, leftValue, nil); err != nil {
t.Fatalf("Failed to insert: %v", err)
}
- internalNode, ok := newNode.(*InternalNode)
- if !ok {
- t.Fatalf("Expected InternalNode, got %T", newNode)
+ // Verify the value was stored
+ value, err := s.Get(leftKey, nil)
+ if err != nil {
+ t.Fatalf("Failed to get: %v", err)
}
-
- // Check that left child is now a StemNode
- leftStem, ok := internalNode.left.(*StemNode)
- if !ok {
- t.Fatalf("Expected left child to be StemNode, got %T", internalNode.left)
- }
-
- // Check the inserted value
- if !bytes.Equal(leftStem.Values[10], leftValue) {
- t.Errorf("Value mismatch: expected %x, got %x", leftValue, leftStem.Values[10])
- }
-
- // Right child should still be Empty
- _, ok = internalNode.right.(Empty)
- if !ok {
- t.Errorf("Expected right child to remain Empty, got %T", internalNode.right)
+ if !bytes.Equal(value, leftValue) {
+ t.Errorf("Value mismatch: expected %x, got %x", leftValue, value)
}
}
-// TestInternalNodeCopy tests the Copy method
+// TestInternalNodeCopy tests the Copy method via nodeStore.
func TestInternalNodeCopy(t *testing.T) {
- // Create an internal node with stem children
- leftStem := &StemNode{
- Stem: make([]byte, 31),
- Values: make([][]byte, 256),
- depth: 1,
- }
- leftStem.Values[0] = common.HexToHash("0x0101").Bytes()
+ s := newNodeStore()
- rightStem := &StemNode{
- Stem: make([]byte, 31),
- Values: make([][]byte, 256),
- depth: 1,
- }
- rightStem.Stem[0] = 0x80
- rightStem.Values[0] = common.HexToHash("0x0202").Bytes()
+ leftKey := make([]byte, 32)
+ leftKey[31] = 0
+ leftValue := common.HexToHash("0x0101").Bytes()
- node := &InternalNode{
- depth: 0,
- left: leftStem,
- right: rightStem,
+ rightKey := make([]byte, 32)
+ rightKey[0] = 0x80
+ rightKey[31] = 0
+ rightValue := common.HexToHash("0x0202").Bytes()
+
+ if err := s.Insert(leftKey, leftValue, nil); err != nil {
+ t.Fatal(err)
+ }
+ if err := s.Insert(rightKey, rightValue, nil); err != nil {
+ t.Fatal(err)
}
- // Create a copy
- copied := node.Copy()
- copiedInternal, ok := copied.(*InternalNode)
- if !ok {
- t.Fatalf("Expected InternalNode, got %T", copied)
- }
+ ns := s.Copy()
- // Check depth
- if copiedInternal.depth != node.depth {
- t.Errorf("Depth mismatch: expected %d, got %d", node.depth, copiedInternal.depth)
- }
-
- // Check that children are copied
- copiedLeft, ok := copiedInternal.left.(*StemNode)
- if !ok {
- t.Fatalf("Expected left child to be StemNode, got %T", copiedInternal.left)
- }
-
- copiedRight, ok := copiedInternal.right.(*StemNode)
- if !ok {
- t.Fatalf("Expected right child to be StemNode, got %T", copiedInternal.right)
- }
-
- // Verify deep copy (children should be different objects)
- if copiedLeft == leftStem {
- t.Error("Left child not properly copied")
- }
- if copiedRight == rightStem {
- t.Error("Right child not properly copied")
- }
-
- // But values should be equal
- if !bytes.Equal(copiedLeft.Values[0], leftStem.Values[0]) {
+ // Values should be equal
+ v1, _ := ns.Get(leftKey, nil)
+ if !bytes.Equal(v1, leftValue) {
t.Error("Left child value mismatch after copy")
}
- if !bytes.Equal(copiedRight.Values[0], rightStem.Values[0]) {
+ v2, _ := ns.Get(rightKey, nil)
+ if !bytes.Equal(v2, rightValue) {
t.Error("Right child value mismatch after copy")
}
}
-// TestInternalNodeHash tests the Hash method
+// TestInternalNodeHash tests the Hash method via nodeStore.
func TestInternalNodeHash(t *testing.T) {
- // Create an internal node
- node := &InternalNode{
- depth: 0,
- left: HashedNode(common.HexToHash("0x1111")),
- right: HashedNode(common.HexToHash("0x2222")),
- }
+ s := newNodeStore()
+ leftRef := s.newHashedRef(common.HexToHash("0x1111"))
+ rightRef := s.newHashedRef(common.HexToHash("0x2222"))
+ rootRef := s.newInternalRef(0)
+ rootNode := s.getInternal(rootRef.Index())
+ rootNode.left = leftRef
+ rootNode.right = rightRef
+ s.root = rootRef
- hash1 := node.Hash()
+ hash1 := s.computeHash(rootRef)
// Hash should be deterministic
- hash2 := node.Hash()
+ hash2 := s.computeHash(rootRef)
if hash1 != hash2 {
t.Errorf("Hash not deterministic: %x != %x", hash1, hash2)
}
// Changing a child should change the hash
- node.left = HashedNode(common.HexToHash("0x3333"))
- node.mustRecompute = true
- hash3 := node.Hash()
+ rootNode.left = s.newHashedRef(common.HexToHash("0x3333"))
+ rootNode.mustRecompute = true
+ hash3 := s.computeHash(rootRef)
if hash1 == hash3 {
t.Error("Hash didn't change after modifying left child")
}
-
- // Test with nil children (should use zero hash)
- nodeWithNil := &InternalNode{
- depth: 0,
- left: nil,
- right: HashedNode(common.HexToHash("0x4444")),
- mustRecompute: true,
- }
- hashWithNil := nodeWithNil.Hash()
- if hashWithNil == (common.Hash{}) {
- t.Error("Hash shouldn't be zero even with nil child")
- }
}
-// TestInternalNodeGetValuesAtStem tests GetValuesAtStem method
+// TestInternalNodeGetValuesAtStem tests GetValuesAtStem method via nodeStore.
func TestInternalNodeGetValuesAtStem(t *testing.T) {
- // Create a tree with values at different stems
+ s := newNodeStore()
+
leftStem := make([]byte, 31)
rightStem := make([]byte, 31)
rightStem[0] = 0x80
- var leftValues, rightValues [256][]byte
+ leftValues := make([][]byte, 256)
leftValues[0] = common.HexToHash("0x0101").Bytes()
leftValues[10] = common.HexToHash("0x0102").Bytes()
+ rightValues := make([][]byte, 256)
rightValues[0] = common.HexToHash("0x0201").Bytes()
rightValues[20] = common.HexToHash("0x0202").Bytes()
- node := &InternalNode{
- depth: 0,
- left: &StemNode{
- Stem: leftStem,
- Values: leftValues[:],
- depth: 1,
- },
- right: &StemNode{
- Stem: rightStem,
- Values: rightValues[:],
- depth: 1,
- },
+ if err := s.InsertValuesAtStem(leftStem, leftValues, nil); err != nil {
+ t.Fatal(err)
+ }
+ if err := s.InsertValuesAtStem(rightStem, rightValues, nil); err != nil {
+ t.Fatal(err)
}
// Get values from left stem
- values, err := node.GetValuesAtStem(leftStem, nil)
+ values, err := s.GetValuesAtStem(leftStem, nil)
if err != nil {
t.Fatalf("Failed to get left values: %v", err)
}
@@ -298,7 +227,7 @@ func TestInternalNodeGetValuesAtStem(t *testing.T) {
}
// Get values from right stem
- values, err = node.GetValuesAtStem(rightStem, nil)
+ values, err = s.GetValuesAtStem(rightStem, nil)
if err != nil {
t.Fatalf("Failed to get right values: %v", err)
}
@@ -310,201 +239,103 @@ func TestInternalNodeGetValuesAtStem(t *testing.T) {
}
}
-// TestInternalNodeInsertValuesAtStem tests InsertValuesAtStem method
+// TestInternalNodeInsertValuesAtStem tests InsertValuesAtStem method via nodeStore.
func TestInternalNodeInsertValuesAtStem(t *testing.T) {
- // Start with an internal node with empty children
- node := &InternalNode{
- depth: 0,
- left: Empty{},
- right: Empty{},
- }
+ s := newNodeStore()
- // Insert values at a stem in the left subtree
stem := make([]byte, 31)
- var values [256][]byte
+ values := make([][]byte, 256)
values[5] = common.HexToHash("0x0505").Bytes()
values[10] = common.HexToHash("0x1010").Bytes()
- newNode, err := node.InsertValuesAtStem(stem, values[:], nil, 0)
- if err != nil {
+ if err := s.InsertValuesAtStem(stem, values, nil); err != nil {
t.Fatalf("Failed to insert values: %v", err)
}
- internalNode, ok := newNode.(*InternalNode)
- if !ok {
- t.Fatalf("Expected InternalNode, got %T", newNode)
+ // Check that the values are stored
+ retrieved, err := s.GetValuesAtStem(stem, nil)
+ if err != nil {
+ t.Fatalf("Failed to get values: %v", err)
}
-
- // Check that left child is now a StemNode with the values
- leftStem, ok := internalNode.left.(*StemNode)
- if !ok {
- t.Fatalf("Expected left child to be StemNode, got %T", internalNode.left)
- }
-
- if !bytes.Equal(leftStem.Values[5], values[5]) {
+ if !bytes.Equal(retrieved[5], values[5]) {
t.Error("Value at index 5 mismatch")
}
- if !bytes.Equal(leftStem.Values[10], values[10]) {
+ if !bytes.Equal(retrieved[10], values[10]) {
t.Error("Value at index 10 mismatch")
}
}
-// TestInternalNodeCollectNodes tests CollectNodes method
+// TestInternalNodeCollectNodes tests CollectNodes method via nodeStore.
func TestInternalNodeCollectNodes(t *testing.T) {
- // Create an internal node with two stem children. All three are
- // marked dirty to mirror production semantics — see CollectNodes.
- leftStem := &StemNode{
- Stem: make([]byte, 31),
- Values: make([][]byte, 256),
- depth: 1,
- dirty: true,
- }
+ s := newNodeStore()
- rightStem := &StemNode{
- Stem: make([]byte, 31),
- Values: make([][]byte, 256),
- depth: 1,
- dirty: true,
- }
- rightStem.Stem[0] = 0x80
+ leftStem := make([]byte, 31)
+ rightStem := make([]byte, 31)
+ rightStem[0] = 0x80
- node := &InternalNode{
- depth: 0,
- left: leftStem,
- right: rightStem,
- dirty: true,
+ leftValues := make([][]byte, 256)
+ rightValues := make([][]byte, 256)
+
+ if err := s.InsertValuesAtStem(leftStem, leftValues, nil); err != nil {
+ t.Fatal(err)
+ }
+ if err := s.InsertValuesAtStem(rightStem, rightValues, nil); err != nil {
+ t.Fatal(err)
}
var collectedPaths [][]byte
- var collectedNodes []BinaryNode
-
- flushFn := func(path []byte, n BinaryNode) {
+ flushFn := func(path []byte, hash common.Hash, serialized []byte) {
pathCopy := make([]byte, len(path))
copy(pathCopy, path)
collectedPaths = append(collectedPaths, pathCopy)
- collectedNodes = append(collectedNodes, n)
}
- err := node.CollectNodes([]byte{1}, flushFn)
+ err := s.collectNodes(s.root, []byte{1}, flushFn)
if err != nil {
t.Fatalf("Failed to collect nodes: %v", err)
}
// Should have collected 3 nodes: left stem, right stem, and the internal node itself
- if len(collectedNodes) != 3 {
- t.Errorf("Expected 3 collected nodes, got %d", len(collectedNodes))
- }
-
- // Check paths
- expectedPaths := [][]byte{
- {1, 0}, // left child
- {1, 1}, // right child
- {1}, // internal node itself
- }
-
- for i, expectedPath := range expectedPaths {
- if !bytes.Equal(collectedPaths[i], expectedPath) {
- t.Errorf("Path %d mismatch: expected %v, got %v", i, expectedPath, collectedPaths[i])
- }
+ if len(collectedPaths) != 3 {
+ t.Errorf("Expected 3 collected nodes, got %d", len(collectedPaths))
}
}
-// TestInternalNodeCollectNodesSkipsClean verifies clean subtrees are not
-// flushed. A dirty internal node over clean children only flushes itself;
-// a fully clean tree flushes nothing.
-func TestInternalNodeCollectNodesSkipsClean(t *testing.T) {
- leftStem := &StemNode{
- Stem: make([]byte, 31),
- Values: make([][]byte, 256),
- depth: 1,
- }
- rightStem := &StemNode{
- Stem: make([]byte, 31),
- Values: make([][]byte, 256),
- depth: 1,
- }
- rightStem.Stem[0] = 0x80
-
- dirtyParent := &InternalNode{
- depth: 0,
- left: leftStem,
- right: rightStem,
- dirty: true,
- }
-
- var collected []BinaryNode
- flushFn := func(_ []byte, n BinaryNode) { collected = append(collected, n) }
-
- if err := dirtyParent.CollectNodes([]byte{1}, flushFn); err != nil {
- t.Fatalf("CollectNodes: %v", err)
- }
- if len(collected) != 1 || collected[0] != dirtyParent {
- t.Fatalf("expected only the dirty parent to be flushed, got %d nodes", len(collected))
- }
- if dirtyParent.dirty {
- t.Errorf("parent dirty flag should be cleared after flush")
- }
-
- // Second call on the same tree should be a no-op: everything is clean.
- collected = nil
- if err := dirtyParent.CollectNodes([]byte{1}, flushFn); err != nil {
- t.Fatalf("CollectNodes (second call): %v", err)
- }
- if len(collected) != 0 {
- t.Errorf("expected no nodes to be flushed on clean tree, got %d", len(collected))
- }
-}
-
-// TestInternalNodeGetHeight tests GetHeight method
+// TestInternalNodeGetHeight tests GetHeight method via nodeStore.
func TestInternalNodeGetHeight(t *testing.T) {
- // Create a tree with different heights
- // Left subtree: depth 2 (internal -> stem)
- // Right subtree: depth 1 (stem)
- leftInternal := &InternalNode{
- depth: 1,
- left: &StemNode{
- Stem: make([]byte, 31),
- Values: make([][]byte, 256),
- depth: 2,
- },
- right: Empty{},
+ s := newNodeStore()
+
+ // Insert values that create a deeper tree
+ stem1 := make([]byte, 31) // left
+ stem2 := make([]byte, 31)
+ stem2[0] = 0x40 // 01... -> goes left at depth 0, right at depth 1
+
+ values1 := make([][]byte, 256)
+ values1[0] = common.HexToHash("0x01").Bytes()
+ values2 := make([][]byte, 256)
+ values2[0] = common.HexToHash("0x02").Bytes()
+
+ if err := s.InsertValuesAtStem(stem1, values1, nil); err != nil {
+ t.Fatal(err)
+ }
+ if err := s.InsertValuesAtStem(stem2, values2, nil); err != nil {
+ t.Fatal(err)
}
- rightStem := &StemNode{
- Stem: make([]byte, 31),
- Values: make([][]byte, 256),
- depth: 1,
- }
-
- node := &InternalNode{
- depth: 0,
- left: leftInternal,
- right: rightStem,
- }
-
- height := node.GetHeight()
- // Height should be max(left height, right height) + 1
- // Left height: 2, Right height: 1, so total: 3
- if height != 3 {
- t.Errorf("Expected height 3, got %d", height)
+ height := s.getHeight(s.root)
+ if height < 2 {
+ t.Errorf("Expected height >= 2, got %d", height)
}
}
-// TestInternalNodeDepthTooLarge tests handling of excessive depth
+// TestInternalNodeDepthTooLarge tests handling of excessive depth via nodeStore.
func TestInternalNodeDepthTooLarge(t *testing.T) {
- // Create an internal node at max depth
- node := &InternalNode{
- depth: 31*8 + 1,
- left: Empty{},
- right: Empty{},
- }
-
- stem := make([]byte, 31)
- _, err := node.GetValuesAtStem(stem, nil)
- if err == nil {
- t.Fatal("Expected error for excessive depth")
- }
- if err.Error() != "node too deep" {
- t.Errorf("Expected 'node too deep' error, got: %v", err)
- }
+ s := newNodeStore()
+ // Creating an internal node beyond max depth should panic
+ defer func() {
+ if r := recover(); r == nil {
+ t.Fatal("Expected panic for excessive depth")
+ }
+ }()
+ s.newInternalRef(31*8 + 1)
}
diff --git a/trie/bintrie/iterator.go b/trie/bintrie/iterator.go
index 048d37f766..31645430c3 100644
--- a/trie/bintrie/iterator.go
+++ b/trie/bintrie/iterator.go
@@ -26,13 +26,14 @@ import (
var errIteratorEnd = errors.New("end of iteration")
type binaryNodeIteratorState struct {
- Node BinaryNode
+ Node nodeRef
Index int
}
type binaryNodeIterator struct {
trie *BinaryTrie
- current BinaryNode
+ store *nodeStore
+ current nodeRef
lastErr error
stack []binaryNodeIteratorState
@@ -40,56 +41,63 @@ type binaryNodeIterator struct {
func newBinaryNodeIterator(t *BinaryTrie, _ []byte) (trie.NodeIterator, error) {
if t.Hash() == zero {
- return &binaryNodeIterator{trie: t, lastErr: errIteratorEnd}, nil
+ return &binaryNodeIterator{trie: t, store: t.store, lastErr: errIteratorEnd}, nil
}
- it := &binaryNodeIterator{trie: t, current: t.root}
- // it.err = it.seek(start)
+ it := &binaryNodeIterator{trie: t, store: t.store, current: t.store.root}
return it, nil
}
-// Next moves the iterator to the next node. If the parameter is false, any child
-// nodes will be skipped.
+// Next moves the iterator to the next node. If descend is false, children of
+// the current node are skipped.
func (it *binaryNodeIterator) Next(descend bool) bool {
if it.lastErr == errIteratorEnd {
- it.lastErr = errIteratorEnd
return false
}
if len(it.stack) == 0 {
- it.stack = append(it.stack, binaryNodeIteratorState{Node: it.trie.root})
- it.current = it.trie.root
-
+ it.stack = append(it.stack, binaryNodeIteratorState{Node: it.trie.store.root})
+ it.current = it.trie.store.root
return true
}
- switch node := it.current.(type) {
- case *InternalNode:
- // index: 0 = nothing visited, 1=left visited, 2=right visited
+ switch it.current.Kind() {
+ case kindInternal:
+ // index: 0 = nothing visited, 1 = left visited, 2 = right visited.
+ node := it.store.getInternal(it.current.Index())
context := &it.stack[len(it.stack)-1]
- // recurse into both children
+ if !descend {
+ // Skip children: pop this node and advance parent.
+ if len(it.stack) == 1 {
+ it.lastErr = errIteratorEnd
+ return false
+ }
+ it.stack = it.stack[:len(it.stack)-1]
+ it.current = it.stack[len(it.stack)-1].Node
+ it.stack[len(it.stack)-1].Index++
+ return it.Next(true)
+ }
+
+ // Recurse into both children.
if context.Index == 0 {
- if _, isempty := node.left.(Empty); node.left != nil && !isempty {
+ if !node.left.IsEmpty() {
it.stack = append(it.stack, binaryNodeIteratorState{Node: node.left})
it.current = node.left
return it.Next(descend)
}
-
context.Index++
}
if context.Index == 1 {
- if _, isempty := node.right.(Empty); node.right != nil && !isempty {
+ if !node.right.IsEmpty() {
it.stack = append(it.stack, binaryNodeIteratorState{Node: node.right})
it.current = node.right
return it.Next(descend)
}
-
context.Index++
}
- // Reached the end of this node, go back to the parent, if
- // this isn't root.
+ // Reached the end of this node; go back to the parent unless we're at the root.
if len(it.stack) == 1 {
it.lastErr = errIteratorEnd
return false
@@ -98,17 +106,18 @@ func (it *binaryNodeIterator) Next(descend bool) bool {
it.current = it.stack[len(it.stack)-1].Node
it.stack[len(it.stack)-1].Index++
return it.Next(descend)
- case *StemNode:
- // Look for the next non-empty value
+
+ case kindStem:
+ // Look for the next non-empty value in this stem.
+ sn := it.store.getStem(it.current.Index())
for i := it.stack[len(it.stack)-1].Index; i < 256; i++ {
- if node.Values[i] != nil {
+ if sn.hasValue(byte(i)) {
it.stack[len(it.stack)-1].Index = i + 1
return true
}
}
- // go back to parent to get the next leaf
- // Check if we're at the root before popping
+ // No more values in this stem; go back to parent to get the next leaf.
if len(it.stack) == 1 {
it.lastErr = errIteratorEnd
return false
@@ -117,51 +126,47 @@ func (it *binaryNodeIterator) Next(descend bool) bool {
it.current = it.stack[len(it.stack)-1].Node
it.stack[len(it.stack)-1].Index++
return it.Next(descend)
- case HashedNode:
- // resolve the node
- resolverPath := it.Path()
- data, err := it.trie.nodeResolver(resolverPath, common.Hash(node))
- if err != nil {
- panic(err)
- }
- if data == nil {
- // Empty/nil node — treat as Empty, backtrack
- it.current = Empty{}
- it.stack[len(it.stack)-1].Node = it.current
- return it.Next(descend)
- }
- it.current, err = DeserializeNodeWithHash(data, len(it.stack)-1, common.Hash(node))
- if err != nil {
- panic(err)
- }
- // update the stack and parent with the resolved node
- it.stack[len(it.stack)-1].Node = it.current
- if len(it.stack) >= 2 {
- parent := &it.stack[len(it.stack)-2]
- if parent.Index == 0 {
- parent.Node.(*InternalNode).left = it.current
- } else {
- parent.Node.(*InternalNode).right = it.current
- }
- }
- return it.Next(descend)
- case Empty:
- // Empty node - go back to parent and continue
- if len(it.stack) <= 1 {
- it.lastErr = errIteratorEnd
+ case kindHashed:
+ // Resolve the hashed node from disk, then rewire the parent to point at the
+ // resolved node in place.
+ if len(it.stack) < 2 {
+ it.lastErr = errors.New("cannot resolve hashed root during iteration")
return false
}
- it.stack = it.stack[:len(it.stack)-1]
- it.current = it.stack[len(it.stack)-1].Node
- it.stack[len(it.stack)-1].Index++
+ hn := it.store.getHashed(it.current.Index())
+ data, err := it.trie.nodeResolver(it.Path(), hn.Hash())
+ if err != nil {
+ it.lastErr = err
+ return false
+ }
+ resolved, err := it.store.deserializeNodeWithHash(data, len(it.stack)-1, hn.Hash())
+ if err != nil {
+ it.lastErr = err
+ return false
+ }
+
+ oldHashedIdx := it.current.Index()
+ it.current = resolved
+ it.stack[len(it.stack)-1].Node = resolved
+ parent := &it.stack[len(it.stack)-2]
+ parentNode := it.store.getInternal(parent.Node.Index())
+ if parent.Index == 0 {
+ parentNode.left = resolved
+ } else {
+ parentNode.right = resolved
+ }
+ it.store.freeHashedNode(oldHashedIdx)
return it.Next(descend)
+
+ case kindEmpty:
+ return false
+
default:
panic("invalid node type")
}
}
-// Error returns the error status of the iterator.
func (it *binaryNodeIterator) Error() error {
if it.lastErr == errIteratorEnd {
return nil
@@ -169,27 +174,28 @@ func (it *binaryNodeIterator) Error() error {
return it.lastErr
}
-// Hash returns the hash of the current node.
func (it *binaryNodeIterator) Hash() common.Hash {
- return it.current.Hash()
+ return it.store.computeHash(it.current)
}
-// Parent returns the hash of the parent of the current node. The hash may be the one
-// grandparent if the immediate parent is an internal node with no hash.
+// Parent returns the hash of the current node's parent. When the immediate
+// parent is an internal node whose hash has not been materialised, the
+// returned hash may be the one of a grandparent instead.
func (it *binaryNodeIterator) Parent() common.Hash {
- return it.stack[len(it.stack)-1].Node.Hash()
+ if len(it.stack) < 2 {
+ return common.Hash{}
+ }
+ return it.store.computeHash(it.stack[len(it.stack)-2].Node)
}
-// Path returns the hex-encoded path to the current node.
-// Callers must not retain references to the return value after calling Next.
-// For leaf nodes, the last element of the path is the 'terminator symbol' 0x10.
+// Path returns the bit-path to the current node.
+// Callers must not retain references to the returned slice after calling Next.
func (it *binaryNodeIterator) Path() []byte {
if it.Leaf() {
return it.LeafKey()
}
var path []byte
for i, state := range it.stack {
- // skip the last byte
if i >= len(it.stack)-1 {
break
}
@@ -198,107 +204,94 @@ func (it *binaryNodeIterator) Path() []byte {
return path
}
-// NodeBlob returns the serialized bytes of the current node.
func (it *binaryNodeIterator) NodeBlob() []byte {
- return SerializeNode(it.current)
+ return it.store.serializeNode(it.current)
}
-// Leaf returns true iff the current node is a leaf node.
-// In a Binary Trie, a StemNode contains up to 256 leaf values.
-// The iterator is only considered to be "at a leaf" when it's positioned
-// at a specific non-nil value within the StemNode, not just at the StemNode itself.
+// Leaf reports whether the iterator is currently positioned at a leaf value.
+// A StemNode holds up to 256 values; the iterator is only "at a leaf" when
+// positioned at a specific non-nil value inside the stem, not merely at the
+// StemNode itself. The stack Index points to the NEXT position after the
+// current value, so Index == 0 means we haven't yielded anything yet.
func (it *binaryNodeIterator) Leaf() bool {
- sn, ok := it.current.(*StemNode)
- if !ok {
+ if it.current.Kind() != kindStem {
return false
}
- // Check if we have a valid stack position
if len(it.stack) == 0 {
return false
}
- // The Index in the stack state points to the NEXT position after the current value.
- // So if Index is 0, we haven't started iterating through the values yet.
- // If Index is 5, we're currently at value[4] (the 5th value, 0-indexed).
idx := it.stack[len(it.stack)-1].Index
if idx == 0 || idx > 256 {
return false
}
- // Check if there's actually a value at the current position
+ sn := it.store.getStem(it.current.Index())
currentValueIndex := idx - 1
- return sn.Values[currentValueIndex] != nil
+ return sn.hasValue(byte(currentValueIndex))
}
-// LeafKey returns the key of the leaf. The method panics if the iterator is not
-// positioned at a leaf. Callers must not retain references to the value after
-// calling Next.
+// LeafKey returns the key of the leaf. Panics if the iterator is not
+// positioned at a leaf. Callers must not retain references to the returned
+// slice after calling Next.
func (it *binaryNodeIterator) LeafKey() []byte {
- leaf, ok := it.current.(*StemNode)
- if !ok {
+ if it.current.Kind() != kindStem {
panic("Leaf() called on an binary node iterator not at a leaf location")
}
- return leaf.Key(it.stack[len(it.stack)-1].Index - 1)
+ sn := it.store.getStem(it.current.Index())
+ return sn.Key(it.stack[len(it.stack)-1].Index - 1)
}
-// LeafBlob returns the content of the leaf. The method panics if the iterator
-// is not positioned at a leaf. Callers must not retain references to the value
-// after calling Next.
+// LeafBlob returns the leaf value. Panics if the iterator is not positioned
+// at a leaf. Callers must not retain references to the returned slice after
+// calling Next.
func (it *binaryNodeIterator) LeafBlob() []byte {
- leaf, ok := it.current.(*StemNode)
- if !ok {
+ if it.current.Kind() != kindStem {
panic("LeafBlob() called on an binary node iterator not at a leaf location")
}
- return leaf.Values[it.stack[len(it.stack)-1].Index-1]
+ sn := it.store.getStem(it.current.Index())
+ return sn.getValue(byte(it.stack[len(it.stack)-1].Index - 1))
}
-// LeafProof returns the Merkle proof of the leaf. The method panics if the
-// iterator is not positioned at a leaf. Callers must not retain references
-// to the value after calling Next.
+// LeafProof returns the Merkle proof of the leaf. Panics if the iterator is
+// not positioned at a leaf. Callers must not retain references to the
+// returned slices after calling Next.
func (it *binaryNodeIterator) LeafProof() [][]byte {
- sn, ok := it.current.(*StemNode)
- if !ok {
+ if it.current.Kind() != kindStem {
panic("LeafProof() called on an binary node iterator not at a leaf location")
}
+ sn := it.store.getStem(it.current.Index())
proof := make([][]byte, 0, len(it.stack)+StemNodeWidth)
- // Build proof by walking up the stack and collecting sibling hashes
+ if len(it.stack) < 2 {
+ proof = append(proof, sn.Stem[:])
+ proof = append(proof, sn.allValues()...)
+ return proof
+ }
+
for i := range it.stack[:len(it.stack)-2] {
state := it.stack[i]
- internalNode := state.Node.(*InternalNode) // should panic if the node isn't an InternalNode
+ internalNode := it.store.getInternal(state.Node.Index())
- // Add the sibling hash to the proof
if state.Index == 0 {
- // We came from left, so include right sibling
- proof = append(proof, internalNode.right.Hash().Bytes())
+ rh := it.store.computeHash(internalNode.right)
+ proof = append(proof, rh.Bytes())
} else {
- // We came from right, so include left sibling
- proof = append(proof, internalNode.left.Hash().Bytes())
+ lh := it.store.computeHash(internalNode.left)
+ proof = append(proof, lh.Bytes())
}
}
// Add the stem and siblings
- proof = append(proof, sn.Stem)
- for _, v := range sn.Values {
- proof = append(proof, v)
- }
+ proof = append(proof, sn.Stem[:])
+ proof = append(proof, sn.allValues()...)
return proof
}
-// AddResolver sets an intermediate database to use for looking up trie nodes
-// before reaching into the real persistent layer.
-//
-// This is not required for normal operation, rather is an optimization for
-// cases where trie nodes can be recovered from some external mechanism without
-// reading from disk. In those cases, this resolver allows short circuiting
-// accesses and returning them from memory.
-//
-// Before adding a similar mechanism to any other place in Geth, consider
-// making trie.Database an interface and wrapping at that level. It's a huge
-// refactor, but it could be worth it if another occurrence arises.
+// AddResolver is a no-op (satisfies the NodeIterator interface).
func (it *binaryNodeIterator) AddResolver(trie.NodeResolver) {
// Not implemented, but should not panic
}
diff --git a/trie/bintrie/iterator_test.go b/trie/bintrie/iterator_test.go
index 3e717c07ba..746f6e8c0f 100644
--- a/trie/bintrie/iterator_test.go
+++ b/trie/bintrie/iterator_test.go
@@ -27,14 +27,13 @@ import (
// makeTrie creates a BinaryTrie populated with the given key-value pairs.
func makeTrie(t *testing.T, entries [][2]common.Hash) *BinaryTrie {
t.Helper()
+ store := newNodeStore()
tr := &BinaryTrie{
- root: NewBinaryNode(),
+ store: store,
tracer: trie.NewPrevalueTracer(),
}
for _, kv := range entries {
- var err error
- tr.root, err = tr.root.Insert(kv[0][:], kv[1][:], nil, 0)
- if err != nil {
+ if err := store.Insert(kv[0][:], kv[1][:], nil); err != nil {
t.Fatal(err)
}
}
@@ -64,7 +63,7 @@ func countLeaves(t *testing.T, tr *BinaryTrie) int {
// no nodes and reports no error.
func TestIteratorEmptyTrie(t *testing.T) {
tr := &BinaryTrie{
- root: Empty{},
+ store: newNodeStore(),
tracer: trie.NewPrevalueTracer(),
}
it, err := newBinaryNodeIterator(tr, nil)
@@ -145,8 +144,8 @@ func TestIteratorEmptyNodeBacktrack(t *testing.T) {
{common.HexToHash("8000000000000000000000000000000000000000000000000000000000000001"), oneKey},
})
- if _, ok := tr.root.(*InternalNode); !ok {
- t.Fatalf("expected InternalNode root, got %T", tr.root)
+ if tr.store.root.Kind() != kindInternal {
+ t.Fatalf("expected InternalNode root, got kind %d", tr.store.root.Kind())
}
if leaves := countLeaves(t, tr); leaves != 2 {
t.Fatalf("expected 2 leaves, got %d (Empty backtrack bug?)", leaves)
@@ -162,18 +161,31 @@ func TestIteratorHashedNodeNilData(t *testing.T) {
{common.HexToHash("8000000000000000000000000000000000000000000000000000000000000001"), oneKey},
})
- root, ok := tr.root.(*InternalNode)
- if !ok {
- t.Fatalf("expected InternalNode root, got %T", tr.root)
+ root := tr.store.root
+ if root.Kind() != kindInternal {
+ t.Fatalf("expected InternalNode root, got kind %d", root.Kind())
}
+ rootNode := tr.store.getInternal(root.Index())
// Replace right child with a zero-hash HashedNode. nodeResolver
// short-circuits on common.Hash{} and returns (nil, nil), which
// triggers the nil-data guard in the iterator.
- root.right = HashedNode(common.Hash{})
+ rootNode.right = tr.store.newHashedRef(common.Hash{})
// Should not panic; the zero-hash right child should be treated as Empty.
- if leaves := countLeaves(t, tr); leaves != 1 {
+ // Since the hashed node can't be resolved (nil data -> empty deserialization),
+ // only the left leaf should be counted.
+ it, err := newBinaryNodeIterator(tr, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ leaves := 0
+ for it.Next(true) {
+ if it.Leaf() {
+ leaves++
+ }
+ }
+ if leaves != 1 {
t.Fatalf("expected 1 leaf (zero-hash right node skipped), got %d", leaves)
}
}
diff --git a/trie/bintrie/node_ref.go b/trie/bintrie/node_ref.go
new file mode 100644
index 0000000000..1c9c0f6284
--- /dev/null
+++ b/trie/bintrie/node_ref.go
@@ -0,0 +1,56 @@
+// Copyright 2026 go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package bintrie
+
+// nodeKind identifies the type of a trie node stored in a nodeRef.
+type nodeKind uint8
+
+const (
+ kindEmpty nodeKind = iota
+ kindInternal
+ kindStem // up to 256 values per stem
+ kindHashed
+)
+
+// nodeRef is a compact, GC-invisible reference to a node in a nodeStore.
+// It packs a 2-bit type tag (bits 31-30) and a 30-bit index (bits 29-0)
+// into a single uint32. Because nodeRef contains no Go pointers, slices
+// of structs containing nodeRef fields are allocated in noscan spans —
+// the garbage collector never examines them.
+type nodeRef uint32
+
+const (
+ kindShift uint32 = 30
+ indexMask uint32 = (1 << kindShift) - 1
+
+ // emptyRef represents an empty node.
+ emptyRef nodeRef = 0
+)
+
+func makeRef(kind nodeKind, idx uint32) nodeRef {
+ if idx > indexMask {
+ panic("nodeRef index overflow")
+ }
+ return nodeRef(uint32(kind)<> kindShift) }
+
+// Index within the typed pool.
+func (r nodeRef) Index() uint32 { return uint32(r) & indexMask }
+
+func (r nodeRef) IsEmpty() bool { return r.Kind() == kindEmpty }
diff --git a/trie/bintrie/node_store.go b/trie/bintrie/node_store.go
new file mode 100644
index 0000000000..8a35f06ee1
--- /dev/null
+++ b/trie/bintrie/node_store.go
@@ -0,0 +1,184 @@
+// Copyright 2026 go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package bintrie
+
+import "github.com/ethereum/go-ethereum/common"
+
+// storeChunkSize is the number of nodes per chunk in each typed pool.
+const storeChunkSize = 4096
+
+// nodeStore is a GC-friendly arena for binary trie nodes. Nodes are packed
+// into typed chunked pools so pointer-free types (InternalNode, HashedNode)
+// land in noscan spans the GC skips entirely.
+type nodeStore struct {
+ internalChunks []*[storeChunkSize]InternalNode
+ internalCount uint32
+
+ stemChunks []*[storeChunkSize]StemNode
+ stemCount uint32
+
+ hashedChunks []*[storeChunkSize]HashedNode
+ hashedCount uint32
+
+ root nodeRef
+
+ // Free list for recycling hashed-node slots after resolve. Internal and
+ // stem nodes are never freed under current semantics (no delete path,
+ // stem-split keeps the old stem at a deeper position), so they don't
+ // have free lists.
+ freeHashed []uint32
+}
+
+func newNodeStore() *nodeStore {
+ return &nodeStore{root: emptyRef}
+}
+
+func (s *nodeStore) allocInternal() uint32 {
+ idx := s.internalCount
+ chunkIdx := idx / storeChunkSize
+ if uint32(len(s.internalChunks)) <= chunkIdx {
+ s.internalChunks = append(s.internalChunks, new([storeChunkSize]InternalNode))
+ }
+ s.internalCount++
+ if s.internalCount > indexMask {
+ panic("internal node pool overflow")
+ }
+ return idx
+}
+
+func (s *nodeStore) getInternal(idx uint32) *InternalNode {
+ return &s.internalChunks[idx/storeChunkSize][idx%storeChunkSize]
+}
+
+func (s *nodeStore) newInternalRef(depth int) nodeRef {
+ if depth > 248 {
+ panic("node depth exceeds maximum binary trie depth")
+ }
+ idx := s.allocInternal()
+ n := s.getInternal(idx)
+ n.depth = uint8(depth)
+ n.mustRecompute = true
+ n.dirty = true
+ return makeRef(kindInternal, idx)
+}
+
+func (s *nodeStore) allocStem() uint32 {
+ idx := s.stemCount
+ chunkIdx := idx / storeChunkSize
+ if uint32(len(s.stemChunks)) <= chunkIdx {
+ s.stemChunks = append(s.stemChunks, new([storeChunkSize]StemNode))
+ }
+ s.stemCount++
+ if s.stemCount > indexMask {
+ panic("stem node pool overflow")
+ }
+ return idx
+}
+
+func (s *nodeStore) getStem(idx uint32) *StemNode {
+ return &s.stemChunks[idx/storeChunkSize][idx%storeChunkSize]
+}
+
+func (s *nodeStore) newStemRef(stem []byte, depth int) nodeRef {
+ if depth > 248 {
+ panic("node depth exceeds maximum binary trie depth")
+ }
+ idx := s.allocStem()
+ sn := s.getStem(idx)
+ copy(sn.Stem[:], stem[:StemSize])
+ sn.depth = uint8(depth)
+ sn.mustRecompute = true
+ sn.dirty = true
+ return makeRef(kindStem, idx)
+}
+
+func (s *nodeStore) allocHashed() uint32 {
+ if n := len(s.freeHashed); n > 0 {
+ idx := s.freeHashed[n-1]
+ s.freeHashed = s.freeHashed[:n-1]
+ *s.getHashed(idx) = HashedNode{}
+ return idx
+ }
+ idx := s.hashedCount
+ chunkIdx := idx / storeChunkSize
+ if uint32(len(s.hashedChunks)) <= chunkIdx {
+ s.hashedChunks = append(s.hashedChunks, new([storeChunkSize]HashedNode))
+ }
+ s.hashedCount++
+ if s.hashedCount > indexMask {
+ panic("hashed node pool overflow")
+ }
+ return idx
+}
+
+func (s *nodeStore) getHashed(idx uint32) *HashedNode {
+ return &s.hashedChunks[idx/storeChunkSize][idx%storeChunkSize]
+}
+
+func (s *nodeStore) freeHashedNode(idx uint32) {
+ s.freeHashed = append(s.freeHashed, idx)
+}
+
+func (s *nodeStore) newHashedRef(hash common.Hash) nodeRef {
+ idx := s.allocHashed()
+ *s.getHashed(idx) = HashedNode(hash)
+ return makeRef(kindHashed, idx)
+}
+
+func (s *nodeStore) Copy() *nodeStore {
+ ns := &nodeStore{
+ root: s.root,
+ internalCount: s.internalCount,
+ stemCount: s.stemCount,
+ hashedCount: s.hashedCount,
+ }
+ ns.internalChunks = make([]*[storeChunkSize]InternalNode, len(s.internalChunks))
+ for i, chunk := range s.internalChunks {
+ cp := *chunk
+ ns.internalChunks[i] = &cp
+ }
+ ns.stemChunks = make([]*[storeChunkSize]StemNode, len(s.stemChunks))
+ for i, chunk := range s.stemChunks {
+ cp := *chunk
+ ns.stemChunks[i] = &cp
+ }
+ // Deep-copy each stem's value slots — they may alias serialized buffers,
+ // so we can't rely on the chunk-wise struct copy above.
+ for i := uint32(0); i < s.stemCount; i++ {
+ src := s.getStem(i)
+ dst := ns.getStem(i)
+ for j, v := range src.values {
+ if v == nil {
+ continue
+ }
+ cp := make([]byte, len(v))
+ copy(cp, v)
+ dst.values[j] = cp
+ }
+ }
+ ns.hashedChunks = make([]*[storeChunkSize]HashedNode, len(s.hashedChunks))
+ for i, chunk := range s.hashedChunks {
+ cp := *chunk
+ ns.hashedChunks[i] = &cp
+ }
+ if len(s.freeHashed) > 0 {
+ ns.freeHashed = make([]uint32, len(s.freeHashed))
+ copy(ns.freeHashed, s.freeHashed)
+ }
+
+ return ns
+}
diff --git a/trie/bintrie/stem_node.go b/trie/bintrie/stem_node.go
index 0ceae6b062..93c55acefa 100644
--- a/trie/bintrie/stem_node.go
+++ b/trie/bintrie/stem_node.go
@@ -17,236 +17,93 @@
package bintrie
import (
- "bytes"
- "errors"
- "fmt"
- "slices"
+ "crypto/sha256"
"github.com/ethereum/go-ethereum/common"
)
-// StemNode represents a group of `NodeWith` values sharing the same stem.
+// StemNode holds up to 256 values sharing a 31-byte stem.
+//
+// Invariant: dirty=false implies mustRecompute=false. Every mutation that
+// invalidates the cached hash MUST also mark the blob for re-flush.
type StemNode struct {
- Stem []byte // Stem path to get to StemNodeWidth values
- Values [][]byte // All values, indexed by the last byte of the key.
- depth int // Depth of the node
+ Stem [StemSize]byte
+ values [StemNodeWidth][]byte // nil == slot absent
- mustRecompute bool // true if the hash needs to be recomputed
- dirty bool // true if the node's on-disk blob is stale (needs flush)
+ depth uint8
+
+ mustRecompute bool // hash is stale (cleared by Hash)
+ dirty bool // on-disk blob is stale (cleared by CollectNodes)
hash common.Hash // cached hash when mustRecompute == false
}
-// Get retrieves the value for the given key.
-func (bt *StemNode) Get(key []byte, _ NodeResolverFn) ([]byte, error) {
- if !bytes.Equal(bt.Stem, key[:StemSize]) {
- return nil, nil
- }
- return bt.Values[key[StemSize]], nil
+func (sn *StemNode) getValue(suffix byte) []byte {
+ return sn.values[suffix]
}
-// Insert inserts a new key-value pair into the node.
-func (bt *StemNode) Insert(key []byte, value []byte, _ NodeResolverFn, depth int) (BinaryNode, error) {
- if !bytes.Equal(bt.Stem, key[:StemSize]) {
- bitStem := bt.Stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1
-
- n := &InternalNode{depth: bt.depth, mustRecompute: true, dirty: true}
- bt.depth++
- // bt is re-parented under n and sits at a new path — rewrite its blob.
- bt.mustRecompute = true
- bt.dirty = true
- var child, other *BinaryNode
- if bitStem == 0 {
- n.left = bt
- child = &n.left
- other = &n.right
- } else {
- n.right = bt
- child = &n.right
- other = &n.left
- }
-
- bitKey := key[n.depth/8] >> (7 - (n.depth % 8)) & 1
- if bitKey == bitStem {
- var err error
- *child, err = (*child).Insert(key, value, nil, depth+1)
- if err != nil {
- return n, fmt.Errorf("insert error: %w", err)
- }
- *other = Empty{}
- } else {
- var values [StemNodeWidth][]byte
- values[key[StemSize]] = value
- *other = &StemNode{
- Stem: slices.Clone(key[:StemSize]),
- Values: values[:],
- depth: depth + 1,
- mustRecompute: true,
- dirty: true,
- }
- }
- return n, nil
- }
- if len(value) != HashSize {
- return bt, errors.New("invalid insertion: value length")
- }
- bt.Values[key[StemSize]] = value
- bt.mustRecompute = true
- bt.dirty = true
- return bt, nil
+func (sn *StemNode) hasValue(suffix byte) bool {
+ return sn.values[suffix] != nil
}
-// Copy creates a deep copy of the node.
-func (bt *StemNode) Copy() BinaryNode {
- var values [StemNodeWidth][]byte
- for i, v := range bt.Values {
- values[i] = slices.Clone(v)
- }
- return &StemNode{
- Stem: slices.Clone(bt.Stem),
- Values: values[:],
- depth: bt.depth,
- hash: bt.hash,
- mustRecompute: bt.mustRecompute,
- dirty: bt.dirty,
- }
+// allValues returns the underlying slot array as a slice. nil entries mean
+// absent. Callers must treat it as read-only.
+func (sn *StemNode) allValues() [][]byte {
+ return sn.values[:]
}
-// GetHeight returns the height of the node.
-func (bt *StemNode) GetHeight() int {
- return 1
+// setValue mutates a value slot and marks the stem for re-hash and
+// re-flush. This is the only API for post-load value mutation; direct
+// values[...] writes are reserved for the on-disk load path in
+// decodeNode, which must leave mustRecompute/dirty at their loaded
+// state.
+func (sn *StemNode) setValue(suffix byte, value []byte) {
+ sn.values[suffix] = value
+ sn.mustRecompute = true
+ sn.dirty = true
}
-// Hash returns the hash of the node.
-func (bt *StemNode) Hash() common.Hash {
- if !bt.mustRecompute {
- return bt.hash
+func (sn *StemNode) Hash() common.Hash {
+ if !sn.mustRecompute {
+ return sn.hash
}
+ // Use sha256.Sum256 (returns [32]byte by value) instead of a pooled
+ // hash.Hash: feeding data[i][:0] into the interface method Sum forces
+ // data to heap (escape analysis is conservative through interfaces).
+ // Sum256 takes []byte and returns by value, so data stays on stack.
var data [StemNodeWidth]common.Hash
- h := newSha256()
- defer returnSha256(h)
- for i, v := range bt.Values {
+
+ for i, v := range sn.values {
if v != nil {
- h.Reset()
- h.Write(v)
- h.Sum(data[i][:0])
+ data[i] = sha256.Sum256(v)
}
}
- h.Reset()
+ var pair [2 * HashSize]byte
for level := 1; level <= 8; level++ {
for i := range StemNodeWidth / (1 << level) {
- h.Reset()
-
if data[i*2] == (common.Hash{}) && data[i*2+1] == (common.Hash{}) {
data[i] = common.Hash{}
continue
}
-
- h.Write(data[i*2][:])
- h.Write(data[i*2+1][:])
- data[i] = common.Hash(h.Sum(nil))
+ copy(pair[:HashSize], data[i*2][:])
+ copy(pair[HashSize:], data[i*2+1][:])
+ data[i] = sha256.Sum256(pair[:])
}
}
- h.Reset()
- h.Write(bt.Stem)
- h.Write([]byte{0})
- h.Write(data[0][:])
- bt.hash = common.BytesToHash(h.Sum(nil))
- bt.mustRecompute = false
- return bt.hash
+ var final [StemSize + 1 + HashSize]byte
+ copy(final[:StemSize], sn.Stem[:])
+ final[StemSize] = 0
+ copy(final[StemSize+1:], data[0][:])
+ sn.hash = sha256.Sum256(final[:])
+ sn.mustRecompute = false
+ return sn.hash
}
-// CollectNodes flushes the stem via the collector when dirty; clean stems
-// are skipped.
-func (bt *StemNode) CollectNodes(path []byte, flush NodeFlushFn) error {
- if !bt.dirty {
- return nil
- }
- flush(path, bt)
- bt.dirty = false
- return nil
-}
-
-// GetValuesAtStem retrieves the group of values located at the given stem key.
-func (bt *StemNode) GetValuesAtStem(stem []byte, _ NodeResolverFn) ([][]byte, error) {
- if !bytes.Equal(bt.Stem, stem) {
- return nil, nil
- }
- return bt.Values[:], nil
-}
-
-// InsertValuesAtStem inserts a full value group at the given stem in the internal node.
-// Already-existing values will be overwritten.
-func (bt *StemNode) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolverFn, depth int) (BinaryNode, error) {
- if !bytes.Equal(bt.Stem, key[:StemSize]) {
- bitStem := bt.Stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1
-
- n := &InternalNode{depth: bt.depth, mustRecompute: true, dirty: true}
- bt.depth++
- // bt is re-parented under n and sits at a new path — rewrite its blob.
- bt.mustRecompute = true
- bt.dirty = true
- var child, other *BinaryNode
- if bitStem == 0 {
- n.left = bt
- child = &n.left
- other = &n.right
- } else {
- n.right = bt
- child = &n.right
- other = &n.left
- }
-
- bitKey := key[n.depth/8] >> (7 - (n.depth % 8)) & 1
- if bitKey == bitStem {
- var err error
- *child, err = (*child).InsertValuesAtStem(key, values, nil, depth+1)
- if err != nil {
- return n, fmt.Errorf("insert error: %w", err)
- }
- *other = Empty{}
- } else {
- *other = &StemNode{
- Stem: slices.Clone(key[:StemSize]),
- Values: values,
- depth: n.depth + 1,
- mustRecompute: true,
- dirty: true,
- }
- }
- return n, nil
- }
-
- // same stem, just merge the two value lists
- for i, v := range values {
- if v != nil {
- bt.Values[i] = v
- bt.mustRecompute = true
- bt.dirty = true
- }
- }
- return bt, nil
-}
-
-func (bt *StemNode) toDot(parent, path string) string {
- me := fmt.Sprintf("stem%s", path)
- ret := fmt.Sprintf("%s [label=\"stem=%x c=%x\"]\n", me, bt.Stem, bt.Hash())
- ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me)
- for i, v := range bt.Values {
- if v != nil {
- ret = fmt.Sprintf("%s%s%x [label=\"%x\"]\n", ret, me, i, v)
- ret = fmt.Sprintf("%s%s -> %s%x\n", ret, me, me, i)
- }
- }
- return ret
-}
-
-// Key returns the full key for the given index.
-func (bt *StemNode) Key(i int) []byte {
+func (sn *StemNode) Key(i int) []byte {
var ret [HashSize]byte
- copy(ret[:], bt.Stem)
+ copy(ret[:], sn.Stem[:])
ret[StemSize] = byte(i)
return ret[:]
}
diff --git a/trie/bintrie/stem_node_test.go b/trie/bintrie/stem_node_test.go
index 2743e7ce9b..5faf903fba 100644
--- a/trie/bintrie/stem_node_test.go
+++ b/trie/bintrie/stem_node_test.go
@@ -23,165 +23,99 @@ import (
"github.com/ethereum/go-ethereum/common"
)
-// TestStemNodeGet tests the Get method for matching stem, non-matching stem,
-// and nil-value suffix scenarios.
-func TestStemNodeGet(t *testing.T) {
- stem := make([]byte, StemSize)
- stem[0] = 0xAB
- var values [StemNodeWidth][]byte
- values[5] = common.HexToHash("0xdeadbeef").Bytes()
-
- node := &StemNode{Stem: stem, Values: values[:], depth: 0}
-
- // Matching stem, populated suffix → returns value.
- key := make([]byte, HashSize)
- copy(key[:StemSize], stem)
- key[StemSize] = 5
- got, err := node.Get(key, nil)
- if err != nil {
- t.Fatalf("Get error: %v", err)
- }
- if !bytes.Equal(got, values[5]) {
- t.Fatalf("Get = %x, want %x", got, values[5])
- }
-
- // Matching stem, empty suffix → returns nil (slot not set).
- key[StemSize] = 99
- got, err = node.Get(key, nil)
- if err != nil {
- t.Fatalf("Get error: %v", err)
- }
- if got != nil {
- t.Fatalf("Get(empty suffix) = %x, want nil", got)
- }
-
- // Non-matching stem → returns nil, nil.
- otherKey := make([]byte, HashSize)
- otherKey[0] = 0xFF
- got, err = node.Get(otherKey, nil)
- if err != nil {
- t.Fatalf("Get error: %v", err)
- }
- if got != nil {
- t.Fatalf("Get(wrong stem) = %x, want nil", got)
- }
-}
-
-// TestStemNodeInsertSameStem tests inserting values with the same stem
+// TestStemNodeInsertSameStem tests inserting values with the same stem via nodeStore.
func TestStemNodeInsertSameStem(t *testing.T) {
+ s := newNodeStore()
+
stem := make([]byte, 31)
for i := range stem {
stem[i] = byte(i)
}
- var values [256][]byte
- values[0] = common.HexToHash("0x0101").Bytes()
-
- node := &StemNode{
- Stem: stem,
- Values: values[:],
- depth: 0,
+ // Insert first value
+ key1 := make([]byte, 32)
+ copy(key1[:31], stem)
+ key1[31] = 0
+ value1 := common.HexToHash("0x0101").Bytes()
+ if err := s.Insert(key1, value1, nil); err != nil {
+ t.Fatal(err)
}
// Insert another value with the same stem but different last byte
- key := make([]byte, 32)
- copy(key[:31], stem)
- key[31] = 10
- value := common.HexToHash("0x0202").Bytes()
-
- newNode, err := node.Insert(key, value, nil, 0)
- if err != nil {
- t.Fatalf("Failed to insert: %v", err)
+ key2 := make([]byte, 32)
+ copy(key2[:31], stem)
+ key2[31] = 10
+ value2 := common.HexToHash("0x0202").Bytes()
+ if err := s.Insert(key2, value2, nil); err != nil {
+ t.Fatal(err)
}
- // Should still be a StemNode
- stemNode, ok := newNode.(*StemNode)
- if !ok {
- t.Fatalf("Expected StemNode, got %T", newNode)
+ // Root should still be a StemNode
+ if s.root.Kind() != kindStem {
+ t.Fatalf("Expected kindStem root, got kind %d", s.root.Kind())
}
// Check that both values are present
- if !bytes.Equal(stemNode.Values[0], values[0]) {
+ v1, _ := s.Get(key1, nil)
+ if !bytes.Equal(v1, value1) {
t.Errorf("Value at index 0 mismatch")
}
- if !bytes.Equal(stemNode.Values[10], value) {
+ v2, _ := s.Get(key2, nil)
+ if !bytes.Equal(v2, value2) {
t.Errorf("Value at index 10 mismatch")
}
}
-// TestStemNodeInsertDifferentStem tests inserting values with different stems
+// TestStemNodeInsertDifferentStem tests inserting values with different stems via nodeStore.
func TestStemNodeInsertDifferentStem(t *testing.T) {
- stem1 := make([]byte, 31)
- for i := range stem1 {
- stem1[i] = 0x00
- }
+ s := newNodeStore()
- var values [256][]byte
- values[0] = common.HexToHash("0x0101").Bytes()
-
- node := &StemNode{
- Stem: stem1,
- Values: values[:],
- depth: 0,
+ // Insert first value with stem of all zeros
+ key1 := make([]byte, 32)
+ key1[31] = 0
+ value1 := common.HexToHash("0x0101").Bytes()
+ if err := s.Insert(key1, value1, nil); err != nil {
+ t.Fatal(err)
}
// Insert with a different stem (first bit different)
- key := make([]byte, 32)
- key[0] = 0x80 // First bit is 1 instead of 0
- value := common.HexToHash("0x0202").Bytes()
-
- newNode, err := node.Insert(key, value, nil, 0)
- if err != nil {
- t.Fatalf("Failed to insert: %v", err)
+ key2 := make([]byte, 32)
+ key2[0] = 0x80 // First bit is 1 instead of 0
+ value2 := common.HexToHash("0x0202").Bytes()
+ if err := s.Insert(key2, value2, nil); err != nil {
+ t.Fatal(err)
}
// Should now be an InternalNode
- internalNode, ok := newNode.(*InternalNode)
- if !ok {
- t.Fatalf("Expected InternalNode, got %T", newNode)
+ if s.root.Kind() != kindInternal {
+ t.Fatalf("Expected kindInternal root, got kind %d", s.root.Kind())
}
// Check depth
- if internalNode.depth != 0 {
- t.Errorf("Expected depth 0, got %d", internalNode.depth)
+ rootNode := s.getInternal(s.root.Index())
+ if rootNode.depth != 0 {
+ t.Errorf("Expected depth 0, got %d", rootNode.depth)
}
- // Original stem should be on the left (bit 0)
- leftStem, ok := internalNode.left.(*StemNode)
- if !ok {
- t.Fatalf("Expected left child to be StemNode, got %T", internalNode.left)
+ // Verify both values are retrievable
+ v1, _ := s.Get(key1, nil)
+ if !bytes.Equal(v1, value1) {
+ t.Error("Value 1 mismatch")
}
- if !bytes.Equal(leftStem.Stem, stem1) {
- t.Errorf("Left stem mismatch")
- }
-
- // New stem should be on the right (bit 1)
- rightStem, ok := internalNode.right.(*StemNode)
- if !ok {
- t.Fatalf("Expected right child to be StemNode, got %T", internalNode.right)
- }
- if !bytes.Equal(rightStem.Stem, key[:31]) {
- t.Errorf("Right stem mismatch")
+ v2, _ := s.Get(key2, nil)
+ if !bytes.Equal(v2, value2) {
+ t.Error("Value 2 mismatch")
}
}
-// TestStemNodeInsertInvalidValueLength tests inserting value with invalid length
+// TestStemNodeInsertInvalidValueLength tests inserting value with invalid length via nodeStore.
func TestStemNodeInsertInvalidValueLength(t *testing.T) {
- stem := make([]byte, 31)
- var values [256][]byte
+ s := newNodeStore()
- node := &StemNode{
- Stem: stem,
- Values: values[:],
- depth: 0,
- }
-
- // Try to insert value with wrong length
key := make([]byte, 32)
- copy(key[:31], stem)
invalidValue := []byte{1, 2, 3} // Not 32 bytes
- _, err := node.Insert(key, invalidValue, nil, 0)
+ err := s.Insert(key, invalidValue, nil)
if err == nil {
t.Fatal("Expected error for invalid value length")
}
@@ -191,221 +125,209 @@ func TestStemNodeInsertInvalidValueLength(t *testing.T) {
}
}
-// TestStemNodeCopy tests the Copy method
+// TestStemNodeCopy tests the Copy method via nodeStore.
func TestStemNodeCopy(t *testing.T) {
- stem := make([]byte, 31)
- for i := range stem {
- stem[i] = byte(i)
+ s := newNodeStore()
+
+ key1 := make([]byte, 32)
+ for i := range 31 {
+ key1[i] = byte(i)
+ }
+ key1[31] = 0
+ value1 := common.HexToHash("0x0101").Bytes()
+
+ key2 := make([]byte, 32)
+ copy(key2[:31], key1[:31])
+ key2[31] = 255
+ value2 := common.HexToHash("0x0202").Bytes()
+
+ if err := s.Insert(key1, value1, nil); err != nil {
+ t.Fatal(err)
+ }
+ if err := s.Insert(key2, value2, nil); err != nil {
+ t.Fatal(err)
}
- var values [256][]byte
- values[0] = common.HexToHash("0x0101").Bytes()
- values[255] = common.HexToHash("0x0202").Bytes()
+ ns := s.Copy()
- node := &StemNode{
- Stem: stem,
- Values: values[:],
- depth: 10,
- }
-
- // Create a copy
- copied := node.Copy()
- copiedStem, ok := copied.(*StemNode)
- if !ok {
- t.Fatalf("Expected StemNode, got %T", copied)
- }
-
- // Check that values are equal but not the same slice
- if !bytes.Equal(copiedStem.Stem, node.Stem) {
- t.Errorf("Stem mismatch after copy")
- }
- if &copiedStem.Stem[0] == &node.Stem[0] {
- t.Error("Stem slice not properly cloned")
- }
-
- // Check values
- if !bytes.Equal(copiedStem.Values[0], node.Values[0]) {
+ // Check that values are equal
+ v1, _ := ns.Get(key1, nil)
+ if !bytes.Equal(v1, value1) {
t.Errorf("Value at index 0 mismatch after copy")
}
- if !bytes.Equal(copiedStem.Values[255], node.Values[255]) {
+ v2, _ := ns.Get(key2, nil)
+ if !bytes.Equal(v2, value2) {
t.Errorf("Value at index 255 mismatch after copy")
}
-
- // Check that value slices are cloned
- if copiedStem.Values[0] != nil && &copiedStem.Values[0][0] == &node.Values[0][0] {
- t.Error("Value slice not properly cloned")
- }
-
- // Check depth
- if copiedStem.depth != node.depth {
- t.Errorf("Depth mismatch: expected %d, got %d", node.depth, copiedStem.depth)
- }
}
-// TestStemNodeHash tests the Hash method
+// TestStemNodeHash tests the Hash method.
func TestStemNodeHash(t *testing.T) {
- stem := make([]byte, 31)
- var values [256][]byte
- values[0] = common.HexToHash("0x0101").Bytes()
+ s := newNodeStore()
- node := &StemNode{
- Stem: stem,
- Values: values[:],
- depth: 0,
+ key := make([]byte, 32)
+ key[31] = 0
+ value := common.HexToHash("0x0101").Bytes()
+ if err := s.Insert(key, value, nil); err != nil {
+ t.Fatal(err)
}
- hash1 := node.Hash()
+ hash1 := s.computeHash(s.root)
// Hash should be deterministic
- hash2 := node.Hash()
+ hash2 := s.computeHash(s.root)
if hash1 != hash2 {
t.Errorf("Hash not deterministic: %x != %x", hash1, hash2)
}
// Changing a value should change the hash
- node.Values[1] = common.HexToHash("0x0202").Bytes()
- node.mustRecompute = true
- hash3 := node.Hash()
+ key2 := make([]byte, 32)
+ key2[31] = 1
+ value2 := common.HexToHash("0x0202").Bytes()
+ if err := s.Insert(key2, value2, nil); err != nil {
+ t.Fatal(err)
+ }
+ hash3 := s.computeHash(s.root)
if hash1 == hash3 {
t.Error("Hash didn't change after modifying values")
}
}
-// TestStemNodeGetValuesAtStem tests GetValuesAtStem method
+// TestStemNodeGetValuesAtStem tests GetValuesAtStem method via nodeStore.
func TestStemNodeGetValuesAtStem(t *testing.T) {
+ s := newNodeStore()
+
stem := make([]byte, 31)
for i := range stem {
stem[i] = byte(i)
}
- var values [256][]byte
+ values := make([][]byte, 256)
values[0] = common.HexToHash("0x0101").Bytes()
values[10] = common.HexToHash("0x0202").Bytes()
values[255] = common.HexToHash("0x0303").Bytes()
- node := &StemNode{
- Stem: stem,
- Values: values[:],
- depth: 0,
+ if err := s.InsertValuesAtStem(stem, values, nil); err != nil {
+ t.Fatal(err)
}
// GetValuesAtStem with matching stem
- retrievedValues, err := node.GetValuesAtStem(stem, nil)
+ retrievedValues, err := s.GetValuesAtStem(stem, nil)
if err != nil {
t.Fatalf("Failed to get values: %v", err)
}
- // Check that all values match
- for i := range 256 {
- if !bytes.Equal(retrievedValues[i], values[i]) {
- t.Errorf("Value mismatch at index %d", i)
- }
+ if !bytes.Equal(retrievedValues[0], values[0]) {
+ t.Error("Value at index 0 mismatch")
+ }
+ if !bytes.Equal(retrievedValues[10], values[10]) {
+ t.Error("Value at index 10 mismatch")
+ }
+ if !bytes.Equal(retrievedValues[255], values[255]) {
+ t.Error("Value at index 255 mismatch")
}
- // GetValuesAtStem with different stem should return nil
+ // GetValuesAtStem with different stem should return nil values
differentStem := make([]byte, 31)
differentStem[0] = 0xFF
- shouldBeNil, err := node.GetValuesAtStem(differentStem, nil)
+ shouldBeEmpty, err := s.GetValuesAtStem(differentStem, nil)
if err != nil {
t.Fatalf("Failed to get values with different stem: %v", err)
}
- if shouldBeNil != nil {
- t.Error("Expected nil for different stem, got non-nil")
+ allNil := true
+ for _, v := range shouldBeEmpty {
+ if v != nil {
+ allNil = false
+ break
+ }
+ }
+ if !allNil {
+ t.Error("Expected all nil values for different stem")
}
}
-// TestStemNodeInsertValuesAtStem tests InsertValuesAtStem method
+// TestStemNodeInsertValuesAtStem tests InsertValuesAtStem method via nodeStore.
func TestStemNodeInsertValuesAtStem(t *testing.T) {
+ s := newNodeStore()
+
stem := make([]byte, 31)
- var values [256][]byte
+ values := make([][]byte, 256)
values[0] = common.HexToHash("0x0101").Bytes()
- node := &StemNode{
- Stem: stem,
- Values: values[:],
- depth: 0,
+ if err := s.InsertValuesAtStem(stem, values, nil); err != nil {
+ t.Fatal(err)
}
// Insert new values at the same stem
- var newValues [256][]byte
+ newValues := make([][]byte, 256)
newValues[1] = common.HexToHash("0x0202").Bytes()
newValues[2] = common.HexToHash("0x0303").Bytes()
- newNode, err := node.InsertValuesAtStem(stem, newValues[:], nil, 0)
- if err != nil {
- t.Fatalf("Failed to insert values: %v", err)
- }
-
- stemNode, ok := newNode.(*StemNode)
- if !ok {
- t.Fatalf("Expected StemNode, got %T", newNode)
+ if err := s.InsertValuesAtStem(stem, newValues, nil); err != nil {
+ t.Fatal(err)
}
// Check that all values are present
- if !bytes.Equal(stemNode.Values[0], values[0]) {
+ retrieved, err := s.GetValuesAtStem(stem, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(retrieved[0], values[0]) {
t.Error("Original value at index 0 missing")
}
- if !bytes.Equal(stemNode.Values[1], newValues[1]) {
+ if !bytes.Equal(retrieved[1], newValues[1]) {
t.Error("New value at index 1 missing")
}
- if !bytes.Equal(stemNode.Values[2], newValues[2]) {
+ if !bytes.Equal(retrieved[2], newValues[2]) {
t.Error("New value at index 2 missing")
}
}
-// TestStemNodeGetHeight tests GetHeight method
+// TestStemNodeGetHeight tests GetHeight method via nodeStore.
func TestStemNodeGetHeight(t *testing.T) {
- node := &StemNode{
- Stem: make([]byte, 31),
- Values: make([][]byte, 256),
- depth: 0,
+ s := newNodeStore()
+
+ key := make([]byte, 32)
+ value := common.HexToHash("0x01").Bytes()
+ if err := s.Insert(key, value, nil); err != nil {
+ t.Fatal(err)
}
- height := node.GetHeight()
+ height := s.getHeight(s.root)
if height != 1 {
t.Errorf("Expected height 1, got %d", height)
}
}
-// TestStemNodeCollectNodes tests CollectNodes method
+// TestStemNodeCollectNodes tests CollectNodes method via nodeStore.
func TestStemNodeCollectNodes(t *testing.T) {
+ s := newNodeStore()
+
stem := make([]byte, 31)
- var values [256][]byte
+ values := make([][]byte, 256)
values[0] = common.HexToHash("0x0101").Bytes()
- node := &StemNode{
- Stem: stem,
- Values: values[:],
- depth: 0,
- dirty: true,
+ if err := s.InsertValuesAtStem(stem, values, nil); err != nil {
+ t.Fatal(err)
}
var collectedPaths [][]byte
- var collectedNodes []BinaryNode
-
- flushFn := func(path []byte, n BinaryNode) {
- // Make a copy of the path
+ flushFn := func(path []byte, hash common.Hash, serialized []byte) {
pathCopy := make([]byte, len(path))
copy(pathCopy, path)
collectedPaths = append(collectedPaths, pathCopy)
- collectedNodes = append(collectedNodes, n)
}
- err := node.CollectNodes([]byte{0, 1, 0}, flushFn)
+ err := s.collectNodes(s.root, []byte{0, 1, 0}, flushFn)
if err != nil {
t.Fatalf("Failed to collect nodes: %v", err)
}
// Should have collected one node (itself)
- if len(collectedNodes) != 1 {
- t.Errorf("Expected 1 collected node, got %d", len(collectedNodes))
- }
-
- // Check that the collected node is the same
- if collectedNodes[0] != node {
- t.Error("Collected node doesn't match original")
+ if len(collectedPaths) != 1 {
+ t.Errorf("Expected 1 collected node, got %d", len(collectedPaths))
}
// Check the path
@@ -413,44 +335,3 @@ func TestStemNodeCollectNodes(t *testing.T) {
t.Errorf("Path mismatch: expected [0, 1, 0], got %v", collectedPaths[0])
}
}
-
-// TestStemNodeCollectNodesSkipsClean verifies that a clean stem is not
-// flushed, and that flushing a dirty stem clears its dirty flag so that
-// a subsequent CollectNodes on the same node is a no-op.
-func TestStemNodeCollectNodesSkipsClean(t *testing.T) {
- stem := make([]byte, 31)
- node := &StemNode{
- Stem: stem,
- Values: make([][]byte, 256),
- depth: 0,
- }
-
- var collected []BinaryNode
- flushFn := func(_ []byte, n BinaryNode) { collected = append(collected, n) }
-
- if err := node.CollectNodes([]byte{0}, flushFn); err != nil {
- t.Fatalf("CollectNodes on clean stem: %v", err)
- }
- if len(collected) != 0 {
- t.Fatalf("expected clean stem not to be flushed, got %d", len(collected))
- }
-
- node.dirty = true
- if err := node.CollectNodes([]byte{0}, flushFn); err != nil {
- t.Fatalf("CollectNodes on dirty stem: %v", err)
- }
- if len(collected) != 1 {
- t.Fatalf("expected dirty stem to be flushed once, got %d", len(collected))
- }
- if node.dirty {
- t.Errorf("stem dirty flag should be cleared after flush")
- }
-
- collected = nil
- if err := node.CollectNodes([]byte{0}, flushFn); err != nil {
- t.Fatalf("CollectNodes after flush: %v", err)
- }
- if len(collected) != 0 {
- t.Errorf("expected no flush on clean stem, got %d", len(collected))
- }
-}
diff --git a/trie/bintrie/store_commit.go b/trie/bintrie/store_commit.go
new file mode 100644
index 0000000000..b142ecb34c
--- /dev/null
+++ b/trie/bintrie/store_commit.go
@@ -0,0 +1,310 @@
+// Copyright 2026 go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package bintrie
+
+import (
+ "crypto/sha256"
+ "errors"
+ "fmt"
+ "math/bits"
+ "runtime"
+ "sync"
+
+ "github.com/ethereum/go-ethereum/common"
+)
+
+type nodeFlushFn func(path []byte, hash common.Hash, serialized []byte)
+
+func (s *nodeStore) Hash() common.Hash {
+ return s.computeHash(s.root)
+}
+
+func (s *nodeStore) computeHash(ref nodeRef) common.Hash {
+ switch ref.Kind() {
+ case kindInternal:
+ return s.hashInternal(ref.Index())
+ case kindStem:
+ return s.getStem(ref.Index()).Hash()
+ case kindHashed:
+ return s.getHashed(ref.Index()).Hash()
+ case kindEmpty:
+ return common.Hash{}
+ default:
+ return common.Hash{}
+ }
+}
+
+// parallelHashDepth is the tree depth below which hashInternal spawns
+// goroutines for shallow-depth parallelism. Computed once at init because
+// NumCPU() never changes after startup.
+var parallelHashDepth = min(bits.Len(uint(runtime.NumCPU())), 8)
+
+// hashInternal hashes an InternalNode and caches the result.
+//
+// At shallow depths (< parallelHashDepth) the left subtree is hashed in a
+// goroutine while the right subtree is hashed inline, then the two digests
+// are combined. Below that threshold the goroutine spawn cost outweighs the
+// hashing work, so deeper nodes hash both children sequentially.
+func (s *nodeStore) hashInternal(idx uint32) common.Hash {
+ node := s.getInternal(idx)
+ if !node.mustRecompute {
+ return node.hash
+ }
+
+ if int(node.depth) < parallelHashDepth {
+ var input [64]byte
+ var lh common.Hash
+ var wg sync.WaitGroup
+ if !node.left.IsEmpty() {
+ wg.Add(1)
+ go func() {
+ // defer wg.Done() so a panic in computeHash still releases
+ // the waiter; without this, a recover() higher in the call
+ // stack would leave the parent stuck in wg.Wait forever.
+ defer wg.Done()
+ lh = s.computeHash(node.left)
+ }()
+ }
+ if !node.right.IsEmpty() {
+ rh := s.computeHash(node.right)
+ copy(input[32:], rh[:])
+ }
+ wg.Wait()
+ copy(input[:32], lh[:])
+ node.hash = sha256.Sum256(input[:])
+ node.mustRecompute = false
+ return node.hash
+ }
+
+ // Deep sequential branch — mirrors the shallow branch's shape to keep
+ // input on the stack. Writing lh/rh through hash.Hash (interface)
+ // forces escape; copy into a local [64]byte and hash it in one shot.
+ var input [64]byte
+ if !node.left.IsEmpty() {
+ lh := s.computeHash(node.left)
+ copy(input[:HashSize], lh[:])
+ }
+ if !node.right.IsEmpty() {
+ rh := s.computeHash(node.right)
+ copy(input[HashSize:], rh[:])
+ }
+ node.hash = sha256.Sum256(input[:])
+ node.mustRecompute = false
+ return node.hash
+}
+
+// SerializeNode serializes a node into the flat on-disk format.
+func (s *nodeStore) serializeNode(ref nodeRef) []byte {
+ switch ref.Kind() {
+ case kindInternal:
+ node := s.getInternal(ref.Index())
+ var serialized [NodeTypeBytes + HashSize + HashSize]byte
+ serialized[0] = nodeTypeInternal
+ lh := s.computeHash(node.left)
+ rh := s.computeHash(node.right)
+ copy(serialized[NodeTypeBytes:NodeTypeBytes+HashSize], lh[:])
+ copy(serialized[NodeTypeBytes+HashSize:], rh[:])
+ return serialized[:]
+
+ case kindStem:
+ sn := s.getStem(ref.Index())
+ // Count present slots to size the blob.
+ var count int
+ for _, v := range sn.values {
+ if v != nil {
+ count++
+ }
+ }
+ serializedLen := NodeTypeBytes + StemSize + StemBitmapSize + count*HashSize
+ serialized := make([]byte, serializedLen)
+ serialized[0] = nodeTypeStem
+ copy(serialized[NodeTypeBytes:NodeTypeBytes+StemSize], sn.Stem[:])
+ bitmap := serialized[NodeTypeBytes+StemSize : NodeTypeBytes+StemSize+StemBitmapSize]
+ offset := NodeTypeBytes + StemSize + StemBitmapSize
+ for i, v := range sn.values {
+ if v != nil {
+ bitmap[i/8] |= 1 << (7 - (i % 8))
+ copy(serialized[offset:offset+HashSize], v)
+ offset += HashSize
+ }
+ }
+ return serialized
+
+ default:
+ panic(fmt.Sprintf("SerializeNode: unexpected node kind %d", ref.Kind()))
+ }
+}
+
+var errInvalidSerializedLength = errors.New("invalid serialized node length")
+
+// DeserializeNode deserializes a node from bytes, recomputing its hash. The
+// returned node is marked dirty (provenance unknown, safe re-flush default).
+func (s *nodeStore) deserializeNode(serialized []byte, depth int) (nodeRef, error) {
+ return s.decodeNode(serialized, depth, common.Hash{}, true, true)
+}
+
+// DeserializeNodeWithHash deserializes a node whose hash is already known and
+// whose blob is already on disk (mustRecompute=false, dirty=false).
+func (s *nodeStore) deserializeNodeWithHash(serialized []byte, depth int, hn common.Hash) (nodeRef, error) {
+ return s.decodeNode(serialized, depth, hn, false, false)
+}
+
+func (s *nodeStore) decodeNode(serialized []byte, depth int, hn common.Hash, mustRecompute, dirty bool) (nodeRef, error) {
+ if len(serialized) == 0 {
+ return emptyRef, nil
+ }
+
+ switch serialized[0] {
+ case nodeTypeInternal:
+ if len(serialized) != NodeTypeBytes+2*HashSize {
+ return emptyRef, errInvalidSerializedLength
+ }
+ var leftHash, rightHash common.Hash
+ copy(leftHash[:], serialized[NodeTypeBytes:NodeTypeBytes+HashSize])
+ copy(rightHash[:], serialized[NodeTypeBytes+HashSize:])
+
+ var leftRef, rightRef nodeRef
+ if leftHash != (common.Hash{}) {
+ leftRef = s.newHashedRef(leftHash)
+ }
+ if rightHash != (common.Hash{}) {
+ rightRef = s.newHashedRef(rightHash)
+ }
+
+ ref := s.newInternalRef(depth)
+ node := s.getInternal(ref.Index())
+ node.left = leftRef
+ node.right = rightRef
+ if !mustRecompute {
+ node.hash = hn
+ node.mustRecompute = false
+ }
+ node.dirty = dirty
+ return ref, nil
+
+ case nodeTypeStem:
+ if len(serialized) < NodeTypeBytes+StemSize+StemBitmapSize {
+ return emptyRef, errInvalidSerializedLength
+ }
+ stemIdx := s.allocStem()
+ sn := s.getStem(stemIdx)
+ copy(sn.Stem[:], serialized[NodeTypeBytes:NodeTypeBytes+StemSize])
+ bitmap := serialized[NodeTypeBytes+StemSize : NodeTypeBytes+StemSize+StemBitmapSize]
+ offset := NodeTypeBytes + StemSize + StemBitmapSize
+ for i := range StemNodeWidth {
+ if bitmap[i/8]>>(7-(i%8))&1 != 1 {
+ continue
+ }
+ if len(serialized) < offset+HashSize {
+ return emptyRef, errInvalidSerializedLength
+ }
+ // Zero-copy: each slot aliases the serialized input buffer.
+ sn.values[i] = serialized[offset : offset+HashSize]
+ offset += HashSize
+ }
+ sn.depth = uint8(depth)
+ sn.hash = hn
+ sn.mustRecompute = mustRecompute
+ sn.dirty = dirty
+ return makeRef(kindStem, stemIdx), nil
+
+ default:
+ return emptyRef, errors.New("invalid node type")
+ }
+}
+
+// CollectNodes flushes every node that needs flushing via flushfn in post-order.
+// Invariant: any ancestor of a node that needs flushing is itself marked, so a
+// clean root means the whole subtree is clean.
+func (s *nodeStore) collectNodes(ref nodeRef, path []byte, flushfn nodeFlushFn) error {
+ switch ref.Kind() {
+ case kindEmpty:
+ return nil
+ case kindInternal:
+ node := s.getInternal(ref.Index())
+ if !node.dirty {
+ return nil
+ }
+ // Reuse path buffer across children: flushfn consumers
+ // (NodeSet.AddNode, tracer.Get) clone via string(path), so in-place
+ // mutation is safe.
+ path = append(path, 0)
+ if err := s.collectNodes(node.left, path, flushfn); err != nil {
+ return err
+ }
+ path[len(path)-1] = 1
+ if err := s.collectNodes(node.right, path, flushfn); err != nil {
+ return err
+ }
+ path = path[:len(path)-1]
+ flushfn(path, s.computeHash(ref), s.serializeNode(ref))
+ node.dirty = false
+ return nil
+ case kindStem:
+ sn := s.getStem(ref.Index())
+ if !sn.dirty {
+ return nil
+ }
+ flushfn(path, s.computeHash(ref), s.serializeNode(ref))
+ sn.dirty = false
+ return nil
+ case kindHashed:
+ return nil // Already committed
+ default:
+ return fmt.Errorf("CollectNodes: unexpected kind %d", ref.Kind())
+ }
+}
+
+func (s *nodeStore) toDot(ref nodeRef, parent, path string) string {
+ switch ref.Kind() {
+ case kindInternal:
+ node := s.getInternal(ref.Index())
+ me := fmt.Sprintf("internal%s", path)
+ ret := fmt.Sprintf("%s [label=\"I: %x\"]\n", me, s.computeHash(ref))
+ if len(parent) > 0 {
+ ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me)
+ }
+ if !node.left.IsEmpty() {
+ ret += s.toDot(node.left, me, fmt.Sprintf("%s%02x", path, 0))
+ }
+ if !node.right.IsEmpty() {
+ ret += s.toDot(node.right, me, fmt.Sprintf("%s%02x", path, 1))
+ }
+ return ret
+ case kindStem:
+ sn := s.getStem(ref.Index())
+ me := fmt.Sprintf("stem%s", path)
+ ret := fmt.Sprintf("%s [label=\"stem=%x c=%x\"]\n", me, sn.Stem, sn.Hash())
+ ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me)
+ for i, v := range sn.values {
+ if v == nil {
+ continue
+ }
+ ret += fmt.Sprintf("%s%x [label=\"%x\"]\n", me, i, v)
+ ret += fmt.Sprintf("%s -> %s%x\n", me, me, i)
+ }
+ return ret
+ case kindHashed:
+ hn := s.getHashed(ref.Index())
+ me := fmt.Sprintf("hash%s", path)
+ ret := fmt.Sprintf("%s [label=\"%x\"]\n", me, hn.Hash())
+ ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me)
+ return ret
+ default:
+ return ""
+ }
+}
diff --git a/trie/bintrie/store_ops.go b/trie/bintrie/store_ops.go
new file mode 100644
index 0000000000..9a73c8bd64
--- /dev/null
+++ b/trie/bintrie/store_ops.go
@@ -0,0 +1,345 @@
+// Copyright 2026 go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package bintrie
+
+import (
+ "errors"
+ "fmt"
+
+ "github.com/ethereum/go-ethereum/common"
+)
+
+// nodeResolverFn resolves a hashed node from the database.
+type nodeResolverFn func([]byte, common.Hash) ([]byte, error)
+
+// GetValue returns the value at (stem, suffix) or nil if absent. Thin
+// wrapper over GetValuesAtStem — the underlying StemNode returns its
+// 256-slot array as a slice header (no allocation), so the per-call cost
+// is the tree walk plus one index.
+func (s *nodeStore) GetValue(stem []byte, suffix byte, resolver nodeResolverFn) ([]byte, error) {
+ values, err := s.GetValuesAtStem(stem, resolver)
+ if err != nil || values == nil {
+ return nil, err
+ }
+ return values[suffix], nil
+}
+
+// GetValuesAtStem returns the 256 value slots at stem, or nil if the stem
+// is not in the trie. The returned slice is a view over the in-place
+// StemNode values array (no allocation) and must be treated read-only.
+func (s *nodeStore) GetValuesAtStem(stem []byte, resolver nodeResolverFn) ([][]byte, error) {
+ cur := s.root
+ var parentIdx uint32
+ var parentIsLeft bool
+
+ for {
+ switch cur.Kind() {
+ case kindInternal:
+ node := s.getInternal(cur.Index())
+ if node.depth >= 31*8 {
+ return nil, errors.New("node too deep")
+ }
+ bit := stem[node.depth/8] >> (7 - (node.depth % 8)) & 1
+ parentIdx = cur.Index()
+ if bit == 0 {
+ parentIsLeft = true
+ cur = node.left
+ } else {
+ parentIsLeft = false
+ cur = node.right
+ }
+
+ case kindStem:
+ sn := s.getStem(cur.Index())
+ if sn.Stem != [StemSize]byte(stem[:StemSize]) {
+ return nil, nil
+ }
+ return sn.allValues(), nil
+
+ case kindHashed:
+ // HashedNode at root is impossible: NewBinaryTrie resolves the
+ // root eagerly before any query. Any HashedNode we encounter here
+ // is necessarily a child of a previously-visited internal node.
+ if resolver == nil {
+ return nil, errors.New("getValuesAtStem: cannot resolve hashed node without resolver")
+ }
+ hn := s.getHashed(cur.Index())
+ parentNode := s.getInternal(parentIdx)
+ path, err := keyToPath(int(parentNode.depth), stem)
+ if err != nil {
+ return nil, fmt.Errorf("getValuesAtStem path error: %w", err)
+ }
+ data, err := resolver(path, hn.Hash())
+ if err != nil {
+ return nil, fmt.Errorf("getValuesAtStem resolve error: %w", err)
+ }
+ resolved, err := s.deserializeNodeWithHash(data, int(parentNode.depth)+1, hn.Hash())
+ if err != nil {
+ return nil, fmt.Errorf("getValuesAtStem deserialization error: %w", err)
+ }
+ s.freeHashedNode(cur.Index())
+ if parentIsLeft {
+ parentNode.left = resolved
+ } else {
+ parentNode.right = resolved
+ }
+ cur = resolved
+
+ case kindEmpty:
+ var values [StemNodeWidth][]byte
+ return values[:], nil
+
+ default:
+ return nil, fmt.Errorf("getValuesAtStem: unexpected node kind %d", cur.Kind())
+ }
+ }
+}
+
+// InsertSingle writes a single value slot at (stem, suffix). Thin wrapper
+// over InsertValuesAtStem — builds a stack-allocated 256-slot array with
+// only the target slot set and delegates. Matches the original design
+// gballet referenced (comment 3101751325): one primary insert path; the
+// single-slot variant dispatches through it so the split / resolve logic
+// lives in one place.
+func (s *nodeStore) InsertSingle(stem []byte, suffix byte, value []byte, resolver nodeResolverFn) error {
+ if len(value) != HashSize {
+ return errors.New("invalid insertion: value length")
+ }
+ var values [StemNodeWidth][]byte
+ values[suffix] = value
+ return s.InsertValuesAtStem(stem, values[:], resolver)
+}
+
+// InsertValuesAtStem writes the supplied value slots at stem. values may be
+// sparse (nil entries are ignored). The recursive implementation dispatches
+// through the same body, so a single code path handles internal descent,
+// HashedNode resolution, stem merge, and stem split.
+func (s *nodeStore) InsertValuesAtStem(stem []byte, values [][]byte, resolver nodeResolverFn) error {
+ var err error
+ s.root, err = s.insertValuesAtStem(s.root, stem, values, resolver, 0)
+ return err
+}
+
+func (s *nodeStore) insertValuesAtStem(ref nodeRef, stem []byte, values [][]byte, resolver nodeResolverFn, depth int) (nodeRef, error) {
+ switch ref.Kind() {
+ case kindInternal:
+ node := s.getInternal(ref.Index())
+ bit := stem[node.depth/8] >> (7 - (node.depth % 8)) & 1
+ if bit == 0 {
+ if node.left.Kind() == kindHashed {
+ if resolver == nil {
+ return ref, errors.New("insertValuesAtStem: cannot resolve hashed node without resolver")
+ }
+ hn := s.getHashed(node.left.Index())
+ path, err := keyToPath(int(node.depth), stem)
+ if err != nil {
+ return ref, fmt.Errorf("InsertValuesAtStem path error: %w", err)
+ }
+ data, err := resolver(path, hn.Hash())
+ if err != nil {
+ return ref, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
+ }
+ resolved, err := s.deserializeNodeWithHash(data, int(node.depth)+1, hn.Hash())
+ if err != nil {
+ return ref, fmt.Errorf("InsertValuesAtStem deserialization error: %w", err)
+ }
+ s.freeHashedNode(node.left.Index())
+ node.left = resolved
+ }
+ newChild, err := s.insertValuesAtStem(node.left, stem, values, resolver, depth+1)
+ if err != nil {
+ return ref, err
+ }
+ node.left = newChild
+ } else {
+ if node.right.Kind() == kindHashed {
+ if resolver == nil {
+ return ref, errors.New("insertValuesAtStem: cannot resolve hashed node without resolver")
+ }
+ hn := s.getHashed(node.right.Index())
+ path, err := keyToPath(int(node.depth), stem)
+ if err != nil {
+ return ref, fmt.Errorf("InsertValuesAtStem path error: %w", err)
+ }
+ data, err := resolver(path, hn.Hash())
+ if err != nil {
+ return ref, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
+ }
+ resolved, err := s.deserializeNodeWithHash(data, int(node.depth)+1, hn.Hash())
+ if err != nil {
+ return ref, fmt.Errorf("InsertValuesAtStem deserialization error: %w", err)
+ }
+ s.freeHashedNode(node.right.Index())
+ node.right = resolved
+ }
+ newChild, err := s.insertValuesAtStem(node.right, stem, values, resolver, depth+1)
+ if err != nil {
+ return ref, err
+ }
+ node.right = newChild
+ }
+ node.mustRecompute = true
+ node.dirty = true
+ return ref, nil
+
+ case kindStem:
+ sn := s.getStem(ref.Index())
+ if sn.Stem == [StemSize]byte(stem[:StemSize]) {
+ // Same stem — merge values (setValue marks dirty+mustRecompute)
+ for i, v := range values {
+ if v != nil {
+ sn.setValue(byte(i), v)
+ }
+ }
+ return ref, nil
+ }
+ // Different stem — split
+ return s.splitStemValuesInsert(ref, stem, values, resolver, depth)
+
+ case kindHashed:
+ hn := s.getHashed(ref.Index())
+ path, err := keyToPath(depth, stem)
+ if err != nil {
+ return ref, fmt.Errorf("InsertValuesAtStem path error: %w", err)
+ }
+ if resolver == nil {
+ return ref, errors.New("InsertValuesAtStem: resolver is nil")
+ }
+ data, err := resolver(path, hn.Hash())
+ if err != nil {
+ return ref, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
+ }
+ resolved, err := s.deserializeNodeWithHash(data, depth, hn.Hash())
+ if err != nil {
+ return ref, fmt.Errorf("InsertValuesAtStem deserialization error: %w", err)
+ }
+ s.freeHashedNode(ref.Index())
+ return s.insertValuesAtStem(resolved, stem, values, resolver, depth)
+
+ case kindEmpty:
+ // Create new StemNode. Flag flips before the value loop so an
+ // all-nil values input still marks the newly-created stem dirty.
+ stemIdx := s.allocStem()
+ sn := s.getStem(stemIdx)
+ copy(sn.Stem[:], stem[:StemSize])
+ sn.depth = uint8(depth)
+ sn.mustRecompute = true
+ sn.dirty = true
+ for i, v := range values {
+ if v != nil {
+ sn.setValue(byte(i), v)
+ }
+ }
+ return makeRef(kindStem, stemIdx), nil
+
+ default:
+ return ref, fmt.Errorf("insertValuesAtStem: unexpected kind %d", ref.Kind())
+ }
+}
+
+// splitStemValuesInsert splits a StemNode when the new stem diverges.
+func (s *nodeStore) splitStemValuesInsert(existingRef nodeRef, newStem []byte, values [][]byte, resolver nodeResolverFn, depth int) (nodeRef, error) {
+ existing := s.getStem(existingRef.Index())
+
+ if int(existing.depth) >= StemSize*8 {
+ panic("splitStemValuesInsert: identical stems")
+ }
+
+ bitStem := existing.Stem[existing.depth/8] >> (7 - (existing.depth % 8)) & 1
+ nRef := s.newInternalRef(int(existing.depth))
+ nNode := s.getInternal(nRef.Index())
+ existing.depth++
+
+ bitKey := newStem[nNode.depth/8] >> (7 - (nNode.depth % 8)) & 1
+ if bitKey == bitStem {
+ // Same direction — need deeper split
+ var child nodeRef
+ if bitStem == 0 {
+ nNode.left = existingRef
+ child = nNode.left
+ } else {
+ nNode.right = existingRef
+ child = nNode.right
+ }
+ newChild, err := s.insertValuesAtStem(child, newStem, values, resolver, depth+1)
+ if err != nil {
+ // Roll back the depth increment so a retry sees the same
+ // existing state and extracts bitStem at the correct offset.
+ // nRef itself leaks (no internal free-list), but the slot is
+ // unreachable from the tree and harmless.
+ existing.depth--
+ return nRef, err
+ }
+ if bitStem == 0 {
+ nNode.left = newChild
+ nNode.right = emptyRef
+ } else {
+ nNode.right = newChild
+ nNode.left = emptyRef
+ }
+ } else {
+ // Divergence — create new StemNode for the new values
+ newStemIdx := s.allocStem()
+ newSn := s.getStem(newStemIdx)
+ copy(newSn.Stem[:], newStem[:StemSize])
+ newSn.depth = nNode.depth + 1
+ newSn.mustRecompute = true
+ newSn.dirty = true
+ for i, v := range values {
+ if v != nil {
+ newSn.setValue(byte(i), v)
+ }
+ }
+ newStemRef := makeRef(kindStem, newStemIdx)
+
+ if bitStem == 0 {
+ nNode.left = existingRef
+ nNode.right = newStemRef
+ } else {
+ nNode.left = newStemRef
+ nNode.right = existingRef
+ }
+ }
+ return nRef, nil
+}
+
+func (s *nodeStore) Insert(key []byte, value []byte, resolver nodeResolverFn) error {
+ return s.InsertSingle(key[:StemSize], key[StemSize], value, resolver)
+}
+
+func (s *nodeStore) Get(key []byte, resolver nodeResolverFn) ([]byte, error) {
+ return s.GetValue(key[:StemSize], key[StemSize], resolver)
+}
+
+func (s *nodeStore) getHeight(ref nodeRef) int {
+ switch ref.Kind() {
+ case kindInternal:
+ node := s.getInternal(ref.Index())
+ lh := s.getHeight(node.left)
+ rh := s.getHeight(node.right)
+ if lh > rh {
+ return 1 + lh
+ }
+ return 1 + rh
+ case kindStem:
+ return 1
+ case kindEmpty:
+ return 0
+ default:
+ return 0
+ }
+}
diff --git a/trie/bintrie/trie.go b/trie/bintrie/trie.go
index 23d014eb33..8c69e0aa00 100644
--- a/trie/bintrie/trie.go
+++ b/trie/bintrie/trie.go
@@ -19,7 +19,6 @@ package bintrie
import (
"bytes"
"encoding/binary"
- "errors"
"fmt"
"github.com/ethereum/go-ethereum/common"
@@ -31,8 +30,6 @@ import (
"github.com/holiman/uint256"
)
-var errInvalidRootType = errors.New("invalid root type")
-
// ChunkedCode represents a sequence of HashSize-byte chunks of code (StemSize bytes of which
// are actual code, and NodeTypeBytes byte is the pushdata offset).
type ChunkedCode []byte
@@ -108,22 +105,17 @@ func ChunkifyCode(code []byte) ChunkedCode {
return chunks
}
-// NewBinaryNode creates a new empty binary trie
-func NewBinaryNode() BinaryNode {
- return Empty{}
-}
-
// BinaryTrie is the implementation of https://eips.ethereum.org/EIPS/eip-7864.
type BinaryTrie struct {
- root BinaryNode
+ store *nodeStore
reader *trie.Reader
tracer *trie.PrevalueTracer
}
// ToDot converts the binary trie to a DOT language representation. Useful for debugging.
func (t *BinaryTrie) ToDot() string {
- t.root.Hash()
- return ToDot(t.root)
+ t.store.computeHash(t.store.root)
+ return t.store.toDot(t.store.root, "", "")
}
// NewBinaryTrie creates a new binary trie.
@@ -133,7 +125,7 @@ func NewBinaryTrie(root common.Hash, db database.NodeDatabase) (*BinaryTrie, err
return nil, err
}
t := &BinaryTrie{
- root: NewBinaryNode(),
+ store: newNodeStore(),
reader: reader,
tracer: trie.NewPrevalueTracer(),
}
@@ -143,11 +135,11 @@ func NewBinaryTrie(root common.Hash, db database.NodeDatabase) (*BinaryTrie, err
if err != nil {
return nil, err
}
- node, err := DeserializeNodeWithHash(blob, 0, root)
+ ref, err := t.store.deserializeNodeWithHash(blob, 0, root)
if err != nil {
return nil, err
}
- t.root = node
+ t.store.root = ref
}
return t, nil
}
@@ -176,29 +168,18 @@ func (t *BinaryTrie) GetKey(key []byte) []byte {
// GetWithHashedKey returns the value, assuming that the key has already
// been hashed.
func (t *BinaryTrie) GetWithHashedKey(key []byte) ([]byte, error) {
- return t.root.Get(key, t.nodeResolver)
+ return t.store.Get(key, t.nodeResolver)
}
// GetAccount returns the account information for the given address.
func (t *BinaryTrie) GetAccount(addr common.Address) (*types.StateAccount, error) {
var (
- values [][]byte
- err error
- acc = &types.StateAccount{}
- key = GetBinaryTreeKey(addr, zero[:])
+ err error
+ acc = &types.StateAccount{}
+ key = GetBinaryTreeKey(addr, zero[:])
)
- switch r := t.root.(type) {
- case *InternalNode:
- values, err = r.GetValuesAtStem(key[:StemSize], t.nodeResolver)
- case *StemNode:
- values, err = r.GetValuesAtStem(key[:StemSize], t.nodeResolver)
- case Empty:
- return nil, nil
- default:
- // This will cover HashedNode but that should be fine since the
- // root node should always be resolved.
- return nil, errInvalidRootType
- }
+
+ values, err := t.store.GetValuesAtStem(key[:StemSize], t.nodeResolver)
if err != nil {
return nil, fmt.Errorf("GetAccount (%x) error: %v", addr, err)
}
@@ -219,7 +200,7 @@ func (t *BinaryTrie) GetAccount(addr common.Address) (*types.StateAccount, error
// If the account has been deleted, BasicData and CodeHash will both be
// 32-byte zero blobs (not nil). If the account is recreated afterwards,
// UpdateAccount overwrites BasicData and CodeHash with non-zero values,
- // so this branch won't activate..
+ // so this branch won't activate.
if bytes.Equal(values[BasicDataLeafKey], zero[:]) &&
bytes.Equal(values[CodeHashLeafKey], zero[:]) {
return nil, nil
@@ -238,13 +219,12 @@ func (t *BinaryTrie) GetAccount(addr common.Address) (*types.StateAccount, error
// not be modified by the caller. If a node was not found in the database, a
// trie.MissingNodeError is returned.
func (t *BinaryTrie) GetStorage(addr common.Address, key []byte) ([]byte, error) {
- return t.root.Get(GetBinaryTreeKeyStorageSlot(addr, key), t.nodeResolver)
+ return t.store.Get(GetBinaryTreeKeyStorageSlot(addr, key), t.nodeResolver)
}
// UpdateAccount updates the account information for the given address.
func (t *BinaryTrie) UpdateAccount(addr common.Address, acc *types.StateAccount, codeLen int) error {
var (
- err error
basicData [HashSize]byte
values = make([][]byte, StemNodeWidth)
stem = GetBinaryTreeKey(addr, zero[:])
@@ -265,15 +245,12 @@ func (t *BinaryTrie) UpdateAccount(addr common.Address, acc *types.StateAccount,
values[BasicDataLeafKey] = basicData[:]
values[CodeHashLeafKey] = acc.CodeHash[:]
- t.root, err = t.root.InsertValuesAtStem(stem, values, t.nodeResolver, 0)
- return err
+ return t.store.InsertValuesAtStem(stem, values, t.nodeResolver)
}
// UpdateStem updates the values for the given stem key.
func (t *BinaryTrie) UpdateStem(key []byte, values [][]byte) error {
- var err error
- t.root, err = t.root.InsertValuesAtStem(key, values, t.nodeResolver, 0)
- return err
+ return t.store.InsertValuesAtStem(key, values, t.nodeResolver)
}
// UpdateStorage associates key with value in the trie. If value has length zero, any
@@ -288,11 +265,10 @@ func (t *BinaryTrie) UpdateStorage(address common.Address, key, value []byte) er
} else {
copy(v[HashSize-len(value):], value[:])
}
- root, err := t.root.Insert(k, v[:], t.nodeResolver, 0)
+ err := t.store.Insert(k, v[:], t.nodeResolver)
if err != nil {
return fmt.Errorf("UpdateStorage (%x) error: %v", address, err)
}
- t.root = root
return nil
}
@@ -307,12 +283,7 @@ func (t *BinaryTrie) DeleteAccount(addr common.Address) error {
values[BasicDataLeafKey] = zero[:]
values[CodeHashLeafKey] = zero[:]
- root, err := t.root.InsertValuesAtStem(stem, values, t.nodeResolver, 0)
- if err != nil {
- return fmt.Errorf("DeleteAccount (%x) error: %v", addr, err)
- }
- t.root = root
- return nil
+ return t.store.InsertValuesAtStem(stem, values, t.nodeResolver)
}
// DeleteStorage removes any existing value for key from the trie. If a node was not
@@ -320,18 +291,17 @@ func (t *BinaryTrie) DeleteAccount(addr common.Address) error {
func (t *BinaryTrie) DeleteStorage(addr common.Address, key []byte) error {
k := GetBinaryTreeKeyStorageSlot(addr, key)
var zero [HashSize]byte
- root, err := t.root.Insert(k, zero[:], t.nodeResolver, 0)
+ err := t.store.Insert(k, zero[:], t.nodeResolver)
if err != nil {
return fmt.Errorf("DeleteStorage (%x) error: %v", addr, err)
}
- t.root = root
return nil
}
// Hash returns the root hash of the trie. It does not write to the database and
// can be used even if the trie doesn't have one.
func (t *BinaryTrie) Hash() common.Hash {
- return t.root.Hash()
+ return t.store.computeHash(t.store.root)
}
// Commit writes all nodes to the trie's memory database, tracking the internal
@@ -339,15 +309,15 @@ func (t *BinaryTrie) Hash() common.Hash {
func (t *BinaryTrie) Commit(_ bool) (common.Hash, *trienode.NodeSet) {
nodeset := trienode.NewNodeSet(common.Hash{})
- // The root can be any type of BinaryNode (InternalNode, StemNode, etc.)
- err := t.root.CollectNodes(nil, func(path []byte, node BinaryNode) {
- serialized := SerializeNode(node)
- nodeset.AddNode(path, trienode.NewNodeWithPrev(node.Hash(), serialized, t.tracer.Get(path)))
+ // Pre-size the path buffer: collectNodes reuses it in-place via
+ // append/truncate; 32 covers typical binary-trie depth without regrowth.
+ pathBuf := make([]byte, 0, 32)
+ err := t.store.collectNodes(t.store.root, pathBuf, func(path []byte, hash common.Hash, serialized []byte) {
+ nodeset.AddNode(path, trienode.NewNodeWithPrev(hash, serialized, t.tracer.Get(path)))
})
if err != nil {
panic(fmt.Errorf("CollectNodes failed: %v", err))
}
- // Serialize root commitment form
return t.Hash(), nodeset
}
@@ -371,7 +341,7 @@ func (t *BinaryTrie) Prove(key []byte, proofDb ethdb.KeyValueWriter) error {
// Copy creates a deep copy of the trie.
func (t *BinaryTrie) Copy() *BinaryTrie {
return &BinaryTrie{
- root: t.root.Copy(),
+ store: t.store.Copy(),
reader: t.reader,
tracer: t.tracer.Copy(),
}
@@ -407,7 +377,6 @@ func (t *BinaryTrie) UpdateContractCode(addr common.Address, codeHash common.Has
if groupOffset == StemNodeWidth-1 || len(chunks)-i <= HashSize {
err = t.UpdateStem(key[:StemSize], values)
-
if err != nil {
return fmt.Errorf("UpdateContractCode (addr=%x) error: %w", addr[:], err)
}
diff --git a/trie/bintrie/trie_test.go b/trie/bintrie/trie_test.go
index c436501464..73aacb76c4 100644
--- a/trie/bintrie/trie_test.go
+++ b/trie/bintrie/trie_test.go
@@ -37,147 +37,130 @@ var (
)
func TestSingleEntry(t *testing.T) {
- tree := NewBinaryNode()
- tree, err := tree.Insert(zeroKey[:], oneKey[:], nil, 0)
- if err != nil {
+ s := newNodeStore()
+ if err := s.Insert(zeroKey[:], oneKey[:], nil); err != nil {
t.Fatal(err)
}
- if tree.GetHeight() != 1 {
+ if s.getHeight(s.root) != 1 {
t.Fatal("invalid depth")
}
expected := common.HexToHash("aab1060e04cb4f5dc6f697ae93156a95714debbf77d54238766adc5709282b6f")
- got := tree.Hash()
+ got := s.Hash()
if got != expected {
t.Fatalf("invalid tree root, got %x, want %x", got, expected)
}
}
func TestTwoEntriesDiffFirstBit(t *testing.T) {
- var err error
- tree := NewBinaryNode()
- tree, err = tree.Insert(zeroKey[:], oneKey[:], nil, 0)
- if err != nil {
+ s := newNodeStore()
+ if err := s.Insert(zeroKey[:], oneKey[:], nil); err != nil {
t.Fatal(err)
}
- tree, err = tree.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000000").Bytes(), twoKey[:], nil, 0)
- if err != nil {
+ if err := s.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000000").Bytes(), twoKey[:], nil); err != nil {
t.Fatal(err)
}
- if tree.GetHeight() != 2 {
+ if s.getHeight(s.root) != 2 {
t.Fatal("invalid height")
}
- if tree.Hash() != common.HexToHash("dfc69c94013a8b3c65395625a719a87534a7cfd38719251ad8c8ea7fe79f065e") {
+ if s.Hash() != common.HexToHash("dfc69c94013a8b3c65395625a719a87534a7cfd38719251ad8c8ea7fe79f065e") {
t.Fatal("invalid tree root")
}
}
func TestOneStemColocatedValues(t *testing.T) {
- var err error
- tree := NewBinaryNode()
- tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil, 0)
- if err != nil {
+ s := newNodeStore()
+ if err := s.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil); err != nil {
t.Fatal(err)
}
- tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil, 0)
- if err != nil {
+ if err := s.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil); err != nil {
t.Fatal(err)
}
- tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000009").Bytes(), threeKey[:], nil, 0)
- if err != nil {
+ if err := s.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000009").Bytes(), threeKey[:], nil); err != nil {
t.Fatal(err)
}
- tree, err = tree.Insert(common.HexToHash("00000000000000000000000000000000000000000000000000000000000000FF").Bytes(), fourKey[:], nil, 0)
- if err != nil {
+ if err := s.Insert(common.HexToHash("00000000000000000000000000000000000000000000000000000000000000FF").Bytes(), fourKey[:], nil); err != nil {
t.Fatal(err)
}
- if tree.GetHeight() != 1 {
+ if s.getHeight(s.root) != 1 {
t.Fatal("invalid height")
}
}
func TestTwoStemColocatedValues(t *testing.T) {
- var err error
- tree := NewBinaryNode()
+ s := newNodeStore()
// stem: 0...0
- tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil, 0)
- if err != nil {
+ if err := s.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil); err != nil {
t.Fatal(err)
}
- tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil, 0)
- if err != nil {
+ if err := s.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil); err != nil {
t.Fatal(err)
}
// stem: 10...0
- tree, err = tree.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil, 0)
- if err != nil {
+ if err := s.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil); err != nil {
t.Fatal(err)
}
- tree, err = tree.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil, 0)
- if err != nil {
+ if err := s.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil); err != nil {
t.Fatal(err)
}
- if tree.GetHeight() != 2 {
+ if s.getHeight(s.root) != 2 {
t.Fatal("invalid height")
}
}
func TestTwoKeysMatchFirst42Bits(t *testing.T) {
- var err error
- tree := NewBinaryNode()
+ s := newNodeStore()
// key1 and key 2 have the same prefix of 42 bits (b0*42+b1+b1) and differ after.
key1 := common.HexToHash("0000000000C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0").Bytes()
key2 := common.HexToHash("0000000000E00000000000000000000000000000000000000000000000000000").Bytes()
- tree, err = tree.Insert(key1, oneKey[:], nil, 0)
- if err != nil {
+ if err := s.Insert(key1, oneKey[:], nil); err != nil {
t.Fatal(err)
}
- tree, err = tree.Insert(key2, twoKey[:], nil, 0)
- if err != nil {
+ if err := s.Insert(key2, twoKey[:], nil); err != nil {
t.Fatal(err)
}
- if tree.GetHeight() != 1+42+1 {
+ if s.getHeight(s.root) != 1+42+1 {
t.Fatal("invalid height")
}
}
+
func TestInsertDuplicateKey(t *testing.T) {
- var err error
- tree := NewBinaryNode()
- tree, err = tree.Insert(oneKey[:], oneKey[:], nil, 0)
- if err != nil {
+ s := newNodeStore()
+ if err := s.Insert(oneKey[:], oneKey[:], nil); err != nil {
t.Fatal(err)
}
- tree, err = tree.Insert(oneKey[:], twoKey[:], nil, 0)
- if err != nil {
+ if err := s.Insert(oneKey[:], twoKey[:], nil); err != nil {
t.Fatal(err)
}
- if tree.GetHeight() != 1 {
+ if s.getHeight(s.root) != 1 {
t.Fatal("invalid height")
}
// Verify that the value is updated
- if !bytes.Equal(tree.(*StemNode).Values[1], twoKey[:]) {
- t.Fatal("invalid height")
+ v, err := s.Get(oneKey[:], nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(v, twoKey[:]) {
+ t.Fatal("value not updated")
}
}
+
func TestLargeNumberOfEntries(t *testing.T) {
- var err error
- tree := NewBinaryNode()
+ s := newNodeStore()
for i := range StemNodeWidth {
var key [HashSize]byte
key[0] = byte(i)
- tree, err = tree.Insert(key[:], ffKey[:], nil, 0)
- if err != nil {
+ if err := s.Insert(key[:], ffKey[:], nil); err != nil {
t.Fatal(err)
}
}
- height := tree.GetHeight()
+ height := s.getHeight(s.root)
if height != 1+8 {
t.Fatalf("invalid height, wanted %d, got %d", 1+8, height)
}
}
func TestMerkleizeMultipleEntries(t *testing.T) {
- var err error
- tree := NewBinaryNode()
+ s := newNodeStore()
keys := [][]byte{
zeroKey[:],
common.HexToHash("8000000000000000000000000000000000000000000000000000000000000000").Bytes(),
@@ -187,12 +170,11 @@ func TestMerkleizeMultipleEntries(t *testing.T) {
for i, key := range keys {
var v [HashSize]byte
binary.LittleEndian.PutUint64(v[:8], uint64(i))
- tree, err = tree.Insert(key, v[:], nil, 0)
- if err != nil {
+ if err := s.Insert(key, v[:], nil); err != nil {
t.Fatal(err)
}
}
- got := tree.Hash()
+ got := s.Hash()
expected := common.HexToHash("9317155862f7a3867660ddd0966ff799a3d16aa4df1e70a7516eaa4a675191b5")
if got != expected {
t.Fatalf("invalid root, expected=%x, got = %x", expected, got)
@@ -206,7 +188,7 @@ func TestMerkleizeMultipleEntries(t *testing.T) {
func TestStorageRoundTrip(t *testing.T) {
tracer := trie.NewPrevalueTracer()
tr := &BinaryTrie{
- root: NewBinaryNode(),
+ store: newNodeStore(),
tracer: tracer,
}
addr := common.HexToAddress("0x1234567890abcdef1234567890abcdef12345678")
@@ -274,7 +256,7 @@ func TestStorageRoundTrip(t *testing.T) {
func newEmptyTestTrie(t *testing.T) *BinaryTrie {
t.Helper()
return &BinaryTrie{
- root: NewBinaryNode(),
+ store: newNodeStore(),
tracer: trie.NewPrevalueTracer(),
}
}
@@ -599,7 +581,7 @@ func TestBinaryTrieWitness(t *testing.T) {
tracer := trie.NewPrevalueTracer()
tr := &BinaryTrie{
- root: NewBinaryNode(),
+ store: newNodeStore(),
tracer: tracer,
}
if w := tr.Witness(); len(w) != 0 {
@@ -626,7 +608,7 @@ func TestBinaryTrieWitness(t *testing.T) {
func testAccount(t *testing.T, addr common.Address, nonce uint64, balance uint64) *BinaryTrie {
t.Helper()
tr := &BinaryTrie{
- root: NewBinaryNode(),
+ store: newNodeStore(),
tracer: trie.NewPrevalueTracer(),
}
acc := &types.StateAccount{
@@ -649,8 +631,8 @@ func TestGetAccountNonMembershipStemRoot(t *testing.T) {
tr := testAccount(t, addr, 42, 100)
// Verify root is a StemNode (single stem inserted).
- if _, ok := tr.root.(*StemNode); !ok {
- t.Fatalf("expected StemNode root, got %T", tr.root)
+ if tr.store.root.Kind() != kindStem {
+ t.Fatalf("expected StemNode root, got kind %d", tr.store.root.Kind())
}
// Query a completely different address — must return nil.
@@ -680,7 +662,7 @@ func TestGetAccountNonMembershipStemRoot(t *testing.T) {
// address returns nil when the trie root is an InternalNode (multi-account trie).
func TestGetAccountNonMembershipInternalRoot(t *testing.T) {
tr := &BinaryTrie{
- root: NewBinaryNode(),
+ store: newNodeStore(),
tracer: trie.NewPrevalueTracer(),
}
@@ -700,8 +682,8 @@ func TestGetAccountNonMembershipInternalRoot(t *testing.T) {
}
// Verify root is an InternalNode.
- if _, ok := tr.root.(*InternalNode); !ok {
- t.Fatalf("expected InternalNode root, got %T", tr.root)
+ if tr.store.root.Kind() != kindInternal {
+ t.Fatalf("expected InternalNode root, got kind %d", tr.store.root.Kind())
}
// Query a non-existent address — must return nil.
@@ -723,8 +705,8 @@ func TestGetStorageNonMembershipStemRoot(t *testing.T) {
tr := testAccount(t, addr, 1, 100)
// Verify root is a StemNode.
- if _, ok := tr.root.(*StemNode); !ok {
- t.Fatalf("expected StemNode root, got %T", tr.root)
+ if tr.store.root.Kind() != kindStem {
+ t.Fatalf("expected StemNode root, got kind %d", tr.store.root.Kind())
}
// Query storage for a different address — must return nil, not panic.
@@ -743,7 +725,7 @@ func TestGetStorageNonMembershipStemRoot(t *testing.T) {
// non-existent address returns nil when the root is an InternalNode.
func TestGetStorageNonMembershipInternalRoot(t *testing.T) {
tr := &BinaryTrie{
- root: NewBinaryNode(),
+ store: newNodeStore(),
tracer: trie.NewPrevalueTracer(),
}
@@ -765,8 +747,8 @@ func TestGetStorageNonMembershipInternalRoot(t *testing.T) {
t.Fatalf("UpdateStorage error: %v", err)
}
- if _, ok := tr.root.(*InternalNode); !ok {
- t.Fatalf("expected InternalNode root, got %T", tr.root)
+ if tr.store.root.Kind() != kindInternal {
+ t.Fatalf("expected InternalNode root, got kind %d", tr.store.root.Kind())
}
// Query storage for a non-existent address — must return nil.
@@ -780,36 +762,29 @@ func TestGetStorageNonMembershipInternalRoot(t *testing.T) {
}
}
-// commitKeyN derives a distinct 32-byte key from a seed integer. Used by
-// TestBinaryTrieCommitIncremental and BenchmarkCollectNodes_SparseWrite to
-// populate a trie with many disjoint stems.
-func commitKeyN(i int) [HashSize]byte {
- var k [HashSize]byte
- binary.BigEndian.PutUint64(k[:8], uint64(i)*0x9e3779b97f4a7c15)
- binary.BigEndian.PutUint64(k[8:16], uint64(i)*0xc2b2ae3d27d4eb4f)
- binary.BigEndian.PutUint64(k[16:24], uint64(i)*0x165667b19e3779f9)
- binary.BigEndian.PutUint64(k[24:32], uint64(i)*0x85ebca77c2b2ae63)
- return k
-}
-
-// TestBinaryTrieCommitIncremental verifies that a second Commit with only a
-// single modified leaf flushes only the path from that leaf to the root,
-// not the entire tree.
-func TestBinaryTrieCommitIncremental(t *testing.T) {
+// TestCommitSkipCleanSubtrees verifies that CollectNodes short-circuits on
+// clean subtrees. First Commit flushes every resolved node; a follow-up
+// Commit with no modifications flushes nothing; a single-leaf modification
+// flushes only the root-to-leaf path.
+func TestCommitSkipCleanSubtrees(t *testing.T) {
tr := &BinaryTrie{
- root: NewBinaryNode(),
+ store: newNodeStore(),
tracer: trie.NewPrevalueTracer(),
}
-
- const n = 512
- keys := make([][HashSize]byte, n)
+ const n = 200
+ key := func(i int) [HashSize]byte {
+ var k [HashSize]byte
+ binary.BigEndian.PutUint64(k[:8], uint64(i+1)*0x9e3779b97f4a7c15)
+ binary.BigEndian.PutUint64(k[8:16], uint64(i+1)*0xc2b2ae3d27d4eb4f)
+ binary.BigEndian.PutUint64(k[16:24], uint64(i+1)*0x165667b19e3779f9)
+ binary.BigEndian.PutUint64(k[24:32], uint64(i+1)*0x85ebca77c2b2ae63)
+ return k
+ }
for i := range n {
- keys[i] = commitKeyN(i + 1)
+ k := key(i)
var v [HashSize]byte
binary.BigEndian.PutUint64(v[24:], uint64(i+1))
- var err error
- tr.root, err = tr.root.Insert(keys[i][:], v[:], nil, 0)
- if err != nil {
+ if err := tr.store.Insert(k[:], v[:], nil); err != nil {
t.Fatalf("Insert %d: %v", i, err)
}
}
@@ -818,69 +793,55 @@ func TestBinaryTrieCommitIncremental(t *testing.T) {
if len(ns1.Nodes) == 0 {
t.Fatal("first Commit produced empty NodeSet")
}
- if len(ns1.Nodes) < n {
- t.Fatalf("first Commit: expected at least %d nodes, got %d", n, len(ns1.Nodes))
- }
- // Second Commit on the same trie with no modifications: NodeSet must
- // be empty because every subtree is clean.
_, nsNoop := tr.Commit(false)
if len(nsNoop.Nodes) != 0 {
- t.Fatalf("no-op Commit: expected empty NodeSet, got %d nodes", len(nsNoop.Nodes))
+ t.Fatalf("no-op Commit: expected empty NodeSet, got %d", len(nsNoop.Nodes))
}
- // Modify a single leaf's value. Only the path from that leaf to the
- // root should appear in the next Commit's NodeSet.
+ // Modify a single leaf — only the root-to-leaf path should flush.
+ k := key(n / 2)
var newVal [HashSize]byte
newVal[0] = 0xff
- var err error
- tr.root, err = tr.root.Insert(keys[n/2][:], newVal[:], nil, 0)
- if err != nil {
+ if err := tr.store.Insert(k[:], newVal[:], nil); err != nil {
t.Fatalf("Insert (modify): %v", err)
}
_, ns2 := tr.Commit(false)
-
- // Path length for a binary trie of n=512 stems is bounded by the
- // internal depth at which the modified stem sits. Allow generous
- // slack: up to 64 nodes is fine, anywhere near n (512) is a regression.
if len(ns2.Nodes) == 0 {
t.Fatal("modified Commit produced empty NodeSet")
}
- if len(ns2.Nodes) > 64 {
- t.Fatalf("modified Commit: expected small NodeSet, got %d nodes (first Commit had %d)", len(ns2.Nodes), len(ns1.Nodes))
+ if len(ns2.Nodes) > 32 {
+ t.Fatalf("modified Commit: expected ≤32 nodes (path+stem), got %d", len(ns2.Nodes))
}
if len(ns2.Nodes) >= len(ns1.Nodes) {
t.Fatalf("expected second NodeSet (%d) to be smaller than first (%d)", len(ns2.Nodes), len(ns1.Nodes))
}
}
-// BenchmarkCollectNodes_SparseWrite measures Commit cost when only one leaf
-// changes between blocks — the common case for state updates. After warm-up
+// BenchmarkCollectNodesSparseWrite measures Commit cost when one leaf
+// changes per block — the common case for state updates. After warm-up
// (populate + initial Commit), each iteration modifies a single leaf and
-// re-Commits. Under the skip-clean optimization, each iteration flushes
-// only the root-to-leaf path; pre-fix behavior would re-flush the entire
-// tree every iteration.
-func BenchmarkCollectNodes_SparseWrite(b *testing.B) {
+// re-Commits. Matches the shape of the same-named benchmark on master so
+// the two trees can be benchstat'd directly.
+func BenchmarkCollectNodesSparseWrite(b *testing.B) {
const n = 10_000
-
tr := &BinaryTrie{
- root: NewBinaryNode(),
+ store: newNodeStore(),
tracer: trie.NewPrevalueTracer(),
}
keys := make([][HashSize]byte, n)
for i := range n {
- keys[i] = commitKeyN(i + 1)
+ binary.BigEndian.PutUint64(keys[i][:8], uint64(i+1)*0x9e3779b97f4a7c15)
+ binary.BigEndian.PutUint64(keys[i][8:16], uint64(i+1)*0xc2b2ae3d27d4eb4f)
+ binary.BigEndian.PutUint64(keys[i][16:24], uint64(i+1)*0x165667b19e3779f9)
+ binary.BigEndian.PutUint64(keys[i][24:32], uint64(i+1)*0x85ebca77c2b2ae63)
var v [HashSize]byte
binary.BigEndian.PutUint64(v[24:], uint64(i+1))
- var err error
- tr.root, err = tr.root.Insert(keys[i][:], v[:], nil, 0)
- if err != nil {
- b.Fatalf("Insert %d: %v", i, err)
+ if err := tr.store.Insert(keys[i][:], v[:], nil); err != nil {
+ b.Fatalf("warmup Insert %d: %v", i, err)
}
}
- // Flush the initial tree so subsequent Commits reflect the
- // single-modification workload we want to measure.
- _, _ = tr.Commit(false)
+ _, _ = tr.Commit(false) // warmup flush
var newVal [HashSize]byte
b.ReportAllocs()
@@ -888,10 +849,8 @@ func BenchmarkCollectNodes_SparseWrite(b *testing.B) {
for i := 0; i < b.N; i++ {
idx := i % n
binary.BigEndian.PutUint64(newVal[24:], uint64(i+1))
- var err error
- tr.root, err = tr.root.Insert(keys[idx][:], newVal[:], nil, 0)
- if err != nil {
- b.Fatalf("Insert at iter %d: %v", i, err)
+ if err := tr.store.Insert(keys[idx][:], newVal[:], nil); err != nil {
+ b.Fatalf("iter %d Insert: %v", i, err)
}
_, ns := tr.Commit(false)
if len(ns.Nodes) == 0 {
diff --git a/triedb/pathdb/database.go b/triedb/pathdb/database.go
index 04c76cfd53..98975d7fa5 100644
--- a/triedb/pathdb/database.go
+++ b/triedb/pathdb/database.go
@@ -102,11 +102,7 @@ func binaryNodeHasher(blob []byte) (common.Hash, error) {
if len(blob) == 0 {
return types.EmptyBinaryHash, nil
}
- n, err := bintrie.DeserializeNode(blob, 0)
- if err != nil {
- return common.Hash{}, err
- }
- return n.Hash(), nil
+ return bintrie.DeserializeAndHash(blob, 0)
}
// Database is a multiple-layered structure for maintaining in-memory states