trie/bintrie: replace BinaryNode interface with GC-free NodeRef arena (#34055)

## Summary

Replace the `BinaryNode` interface with `NodeRef uint32` indices into
typed arena pools, eliminating GC-scanned pointers from binary trie
nodes.

Inspired by [fjl's
observation](https://github.com/ethereum/go-ethereum/pull/34034#issuecomment-4075176446):
> *"if the binary trie produces such a large graph, it should probably
be changed so that the trie node type does not contain pointers. The
runtime does not scan objects that do not contain pointers, so it can
really help with the performance to build it this way."*

### The problem

CPU profiling of the binary trie (EIP-7864) showed **44% of CPU time in
garbage collection**. Each `InternalNode` held two `BinaryNode`
interface values (2 pointer-words each), and the GC scanned every one.
With ~25K `InternalNode`s in memory during block processing, this
created enormous GC pressure.

### The solution

`NodeRef` is a compact `uint32` (2-bit kind tag + 30-bit pool index).
`NodeStore` manages chunked typed pools per node kind:
- **InternalNode pool**: ZERO Go pointers (children are `NodeRef`, hash
is `[32]byte`) → noscan spans
- **HashedNode pool**: ZERO Go pointers → noscan spans
- **StemNode pool**: retains `Values [][]byte` (matching existing
format)

The serialization format is unchanged — flat InternalNode
`[type][leftHash][rightHash]` = 65 bytes.

## Benchmark: Apple M4 Pro (`--benchtime=10s --count=3`, on top of
#34021)

| Metric | Baseline | Arena | Delta |
|--------|----------|-------|-------|
| Approve (Mgas/s) | 374 | 382 | **+2.1%** |
| BalanceOf (Mgas/s) | 885 | 901 | **+1.8%** |
| Approve allocs/op | 775K | **607K** | **-21.7%** |
| BalanceOf allocs/op | 265K | **228K** | **-14.0%** |

## Benchmark: AMD EPYC 48-core (50GB state, execution-specs ERC-20, on
top of #34021 + #34032)

| Benchmark | Baseline | Arena | Delta |
|-----------|----------|-------|-------|
| erc20_approve (write) | 22.4 Mgas/s | **27.0 Mgas/s** | **+20.5%** |
| mixed_sload_sstore | 62.9 Mgas/s | **97.3 Mgas/s** | **+54.7%** |
| erc20_balanceof (read) | 180.8 Mgas/s | 167.6 Mgas/s | -7.3% (cold
cache variance) |

The arena benefit scales with heap size — the EPYC (larger heap, more GC
pressure) shows much larger gains than the M4 Pro (efficient unified
memory). The mixed workload baseline was unstable (62.9 vs 16.3 Mgas/s
between runs due to GC-induced throughput collapse); the arena
eliminates this entirely (95-97 Mgas/s, stable).

## Dependencies

Benchmarked with #34021 (H01 N+1 fix) + #34032 (R14 parallel hashing).
No code dependency — applies independently to master.

All test suites pass (`trie/bintrie` with `-race`, `core/state`,
`triedb/pathdb`, `cmd/geth`).

---------

Co-authored-by: Guillaume Ballet <3272758+gballet@users.noreply.github.com>
This commit is contained in:
CPerezz 2026-04-20 14:08:30 +02:00 committed by GitHub
parent 29e0a6f404
commit b6d415c88d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 1689 additions and 2092 deletions

View file

@ -16,140 +16,32 @@
package bintrie
import (
"errors"
"github.com/ethereum/go-ethereum/common"
)
type (
NodeFlushFn func([]byte, BinaryNode)
NodeResolverFn func([]byte, common.Hash) ([]byte, error)
)
import "github.com/ethereum/go-ethereum/common"
// zero is the zero value for a 32-byte array.
var zero [32]byte
const (
StemNodeWidth = 256 // Number of child per leaf node
StemSize = 31 // Number of bytes to travel before reaching a group of leaves
NodeTypeBytes = 1 // Size of node type prefix in serialization
HashSize = 32 // Size of a hash in bytes
BitmapSize = 32 // Size of the bitmap in a stem node
StemNodeWidth = 256 // Number of children per leaf node
StemSize = 31 // Number of bytes to travel before reaching a group of leaves
NodeTypeBytes = 1 // Size of node type prefix in serialization
HashSize = 32 // Size of a hash in bytes
StemBitmapSize = 32 // Size of the bitmap in a stem node (256 values = 32 bytes)
)
const (
nodeTypeStem = iota + 1 // Stem node, contains a stem and a bitmap of values
nodeTypeStem = iota + 1
nodeTypeInternal
)
// BinaryNode is an interface for a binary trie node.
type BinaryNode interface {
Get([]byte, NodeResolverFn) ([]byte, error)
Insert([]byte, []byte, NodeResolverFn, int) (BinaryNode, error)
Copy() BinaryNode
Hash() common.Hash
GetValuesAtStem([]byte, NodeResolverFn) ([][]byte, error)
InsertValuesAtStem([]byte, [][]byte, NodeResolverFn, int) (BinaryNode, error)
CollectNodes([]byte, NodeFlushFn) error
toDot(parent, path string) string
GetHeight() int
}
// SerializeNode serializes a binary trie node into a byte slice.
func SerializeNode(node BinaryNode) []byte {
switch n := (node).(type) {
case *InternalNode:
// InternalNode: 1 byte type + 32 bytes left hash + 32 bytes right hash
var serialized [NodeTypeBytes + HashSize + HashSize]byte
serialized[0] = nodeTypeInternal
copy(serialized[1:33], n.left.Hash().Bytes())
copy(serialized[33:65], n.right.Hash().Bytes())
return serialized[:]
case *StemNode:
// StemNode: 1 byte type + 31 bytes stem + 32 bytes bitmap + 256*32 bytes values
var serialized [NodeTypeBytes + StemSize + BitmapSize + StemNodeWidth*HashSize]byte
serialized[0] = nodeTypeStem
copy(serialized[NodeTypeBytes:NodeTypeBytes+StemSize], n.Stem)
bitmap := serialized[NodeTypeBytes+StemSize : NodeTypeBytes+StemSize+BitmapSize]
offset := NodeTypeBytes + StemSize + BitmapSize
for i, v := range n.Values {
if v != nil {
bitmap[i/8] |= 1 << (7 - (i % 8))
copy(serialized[offset:offset+HashSize], v)
offset += HashSize
}
}
// Only return the actual data, not the entire array
return serialized[:offset]
default:
panic("invalid node type")
// DeserializeAndHash deserializes a node from bytes and returns its hash.
// This is a convenience function for external callers that need to compute
// the hash of a serialized node without maintaining a nodeStore.
func DeserializeAndHash(blob []byte, depth int) (common.Hash, error) {
s := newNodeStore()
ref, err := s.deserializeNode(blob, depth)
if err != nil {
return common.Hash{}, err
}
}
var invalidSerializedLength = errors.New("invalid serialized node length")
// DeserializeNode deserializes a binary trie node from a byte slice. The
// hash will be recomputed from the deserialized data.
func DeserializeNode(serialized []byte, depth int) (BinaryNode, error) {
return deserializeNode(serialized, depth, common.Hash{}, true, true)
}
// DeserializeNodeWithHash deserializes a binary trie node from a byte slice, using the provided hash.
func DeserializeNodeWithHash(serialized []byte, depth int, hn common.Hash) (BinaryNode, error) {
return deserializeNode(serialized, depth, hn, false, false)
}
func deserializeNode(serialized []byte, depth int, hn common.Hash, mustRecompute, dirty bool) (BinaryNode, error) {
if len(serialized) == 0 {
return Empty{}, nil
}
switch serialized[0] {
case nodeTypeInternal:
if len(serialized) != 65 {
return nil, invalidSerializedLength
}
return &InternalNode{
depth: depth,
left: HashedNode(common.BytesToHash(serialized[1:33])),
right: HashedNode(common.BytesToHash(serialized[33:65])),
hash: hn,
mustRecompute: mustRecompute,
dirty: dirty,
}, nil
case nodeTypeStem:
if len(serialized) < 64 {
return nil, invalidSerializedLength
}
var values [StemNodeWidth][]byte
bitmap := serialized[NodeTypeBytes+StemSize : NodeTypeBytes+StemSize+BitmapSize]
offset := NodeTypeBytes + StemSize + BitmapSize
for i := range StemNodeWidth {
if bitmap[i/8]>>(7-(i%8))&1 == 1 {
if len(serialized) < offset+HashSize {
return nil, invalidSerializedLength
}
values[i] = serialized[offset : offset+HashSize]
offset += HashSize
}
}
return &StemNode{
Stem: serialized[NodeTypeBytes : NodeTypeBytes+StemSize],
Values: values[:],
depth: depth,
hash: hn,
mustRecompute: mustRecompute,
dirty: dirty,
}, nil
default:
return nil, errors.New("invalid node type")
}
}
// ToDot converts the binary trie to a DOT language representation. Useful for debugging.
func ToDot(root BinaryNode) string {
return root.toDot("", "")
return s.computeHash(ref), nil
}

View file

@ -23,79 +23,100 @@ import (
"github.com/ethereum/go-ethereum/common"
)
// TestSerializeDeserializeInternalNode tests serialization and deserialization of InternalNode
// TestSerializeDeserializeInternalNode tests flat 65-byte serialization and
// deserialization of InternalNode through nodeStore.
func TestSerializeDeserializeInternalNode(t *testing.T) {
// Create an internal node with two hashed children
leftHash := common.HexToHash("0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef")
rightHash := common.HexToHash("0xfedcba0987654321fedcba0987654321fedcba0987654321fedcba0987654321")
node := &InternalNode{
depth: 5,
left: HashedNode(leftHash),
right: HashedNode(rightHash),
}
s := newNodeStore()
leftRef := s.newHashedRef(leftHash)
rightRef := s.newHashedRef(rightHash)
// Serialize the node
serialized := SerializeNode(node)
rootRef := s.newInternalRef(0)
rootNode := s.getInternal(rootRef.Index())
rootNode.left = leftRef
rootNode.right = rightRef
s.root = rootRef
// Check the serialized format
// Serialize the node — flat 65-byte format
serialized := s.serializeNode(rootRef)
// Check the serialized format: [type(1)][leftHash(32)][rightHash(32)]
if serialized[0] != nodeTypeInternal {
t.Errorf("Expected type byte to be %d, got %d", nodeTypeInternal, serialized[0])
}
if len(serialized) != 65 {
t.Errorf("Expected serialized length to be 65, got %d", len(serialized))
expectedLen := NodeTypeBytes + 2*HashSize // 1 + 64 = 65
if len(serialized) != expectedLen {
t.Errorf("Expected serialized length to be %d, got %d", expectedLen, len(serialized))
}
// Deserialize the node
deserialized, err := DeserializeNode(serialized, 5)
// Check that left and right hashes are embedded directly
if !bytes.Equal(serialized[NodeTypeBytes:NodeTypeBytes+HashSize], leftHash[:]) {
t.Error("Left hash not found at expected position")
}
if !bytes.Equal(serialized[NodeTypeBytes+HashSize:], rightHash[:]) {
t.Error("Right hash not found at expected position")
}
// Deserialize into a new store
ds := newNodeStore()
deserialized, err := ds.deserializeNode(serialized, 0)
if err != nil {
t.Fatalf("Failed to deserialize node: %v", err)
}
// Check that it's an internal node
internalNode, ok := deserialized.(*InternalNode)
if !ok {
t.Fatalf("Expected InternalNode, got %T", deserialized)
// Root should be an InternalNode
if deserialized.Kind() != kindInternal {
t.Fatalf("Expected kindInternal, got kind %d", deserialized.Kind())
}
// Check the depth
if internalNode.depth != 5 {
t.Errorf("Expected depth 5, got %d", internalNode.depth)
internalNode := ds.getInternal(deserialized.Index())
if internalNode.depth != 0 {
t.Errorf("Expected depth 0, got %d", internalNode.depth)
}
// Check the left and right hashes
if internalNode.left.Hash() != leftHash {
t.Errorf("Left hash mismatch: expected %x, got %x", leftHash, internalNode.left.Hash())
// Left child should be a HashedNode with the correct hash
if internalNode.left.Kind() != kindHashed {
t.Fatalf("Expected left child to be kindHashed, got %d", internalNode.left.Kind())
}
if ds.computeHash(internalNode.left) != leftHash {
t.Errorf("Left hash mismatch: expected %x, got %x", leftHash, ds.computeHash(internalNode.left))
}
if internalNode.right.Hash() != rightHash {
t.Errorf("Right hash mismatch: expected %x, got %x", rightHash, internalNode.right.Hash())
// Right child should be a HashedNode with the correct hash
if internalNode.right.Kind() != kindHashed {
t.Fatalf("Expected right child to be kindHashed, got %d", internalNode.right.Kind())
}
if ds.computeHash(internalNode.right) != rightHash {
t.Errorf("Right hash mismatch: expected %x, got %x", rightHash, ds.computeHash(internalNode.right))
}
}
// TestSerializeDeserializeStemNode tests serialization and deserialization of StemNode
// TestSerializeDeserializeStemNode tests serialization and deserialization of StemNode through nodeStore.
func TestSerializeDeserializeStemNode(t *testing.T) {
// Create a stem node with some values
stem := make([]byte, StemSize)
for i := range stem {
stem[i] = byte(i)
}
var values [StemNodeWidth][]byte
// Add some values at different indices
values[0] = common.HexToHash("0x0101010101010101010101010101010101010101010101010101010101010101").Bytes()
values[10] = common.HexToHash("0x0202020202020202020202020202020202020202020202020202020202020202").Bytes()
values[255] = common.HexToHash("0x0303030303030303030303030303030303030303030303030303030303030303").Bytes()
node := &StemNode{
Stem: stem,
Values: values[:],
depth: 10,
s := newNodeStore()
ref := s.newStemRef(stem, 10)
sn := s.getStem(ref.Index())
for i, v := range values {
if v != nil {
sn.setValue(byte(i), v)
}
}
// Serialize the node
serialized := SerializeNode(node)
serialized := s.serializeNode(ref)
// Check the serialized format
if serialized[0] != nodeTypeStem {
@ -107,31 +128,32 @@ func TestSerializeDeserializeStemNode(t *testing.T) {
t.Errorf("Stem mismatch in serialized data")
}
// Deserialize the node
deserialized, err := DeserializeNode(serialized, 10)
// Deserialize into a new store
ds := newNodeStore()
deserializedRef, err := ds.deserializeNode(serialized, 10)
if err != nil {
t.Fatalf("Failed to deserialize node: %v", err)
}
// Check that it's a stem node
stemNode, ok := deserialized.(*StemNode)
if !ok {
t.Fatalf("Expected StemNode, got %T", deserialized)
if deserializedRef.Kind() != kindStem {
t.Fatalf("Expected kindStem, got kind %d", deserializedRef.Kind())
}
stemNode := ds.getStem(deserializedRef.Index())
// Check the stem
if !bytes.Equal(stemNode.Stem, stem) {
if !bytes.Equal(stemNode.Stem[:], stem) {
t.Errorf("Stem mismatch after deserialization")
}
// Check the values
if !bytes.Equal(stemNode.Values[0], values[0]) {
if !bytes.Equal(stemNode.getValue(0), values[0]) {
t.Errorf("Value at index 0 mismatch")
}
if !bytes.Equal(stemNode.Values[10], values[10]) {
if !bytes.Equal(stemNode.getValue(10), values[10]) {
t.Errorf("Value at index 10 mismatch")
}
if !bytes.Equal(stemNode.Values[255], values[255]) {
if !bytes.Equal(stemNode.getValue(255), values[255]) {
t.Errorf("Value at index 255 mismatch")
}
@ -140,43 +162,43 @@ func TestSerializeDeserializeStemNode(t *testing.T) {
if i == 0 || i == 10 || i == 255 {
continue
}
if stemNode.Values[i] != nil {
t.Errorf("Expected nil value at index %d, got %x", i, stemNode.Values[i])
if stemNode.hasValue(byte(i)) {
t.Errorf("Expected no value at index %d, got %x", i, stemNode.getValue(byte(i)))
}
}
}
// TestDeserializeEmptyNode tests deserialization of empty node
// TestDeserializeEmptyNode tests deserialization of empty node.
func TestDeserializeEmptyNode(t *testing.T) {
// Empty byte slice should deserialize to Empty node
deserialized, err := DeserializeNode([]byte{}, 0)
s := newNodeStore()
deserialized, err := s.deserializeNode([]byte{}, 0)
if err != nil {
t.Fatalf("Failed to deserialize empty node: %v", err)
}
_, ok := deserialized.(Empty)
if !ok {
t.Fatalf("Expected Empty node, got %T", deserialized)
if !deserialized.IsEmpty() {
t.Fatalf("Expected emptyRef, got kind %d", deserialized.Kind())
}
}
// TestDeserializeInvalidType tests deserialization with invalid type byte
// TestDeserializeInvalidType tests deserialization with invalid type byte.
func TestDeserializeInvalidType(t *testing.T) {
// Create invalid serialized data with unknown type byte
s := newNodeStore()
invalidData := []byte{99, 0, 0, 0} // Type byte 99 is invalid
_, err := DeserializeNode(invalidData, 0)
_, err := s.deserializeNode(invalidData, 0)
if err == nil {
t.Fatal("Expected error for invalid type byte, got nil")
}
}
// TestDeserializeInvalidLength tests deserialization with invalid data length
// TestDeserializeInvalidLength tests deserialization with invalid data length.
func TestDeserializeInvalidLength(t *testing.T) {
// InternalNode with type byte 1 but wrong length
invalidData := []byte{nodeTypeInternal, 0, 0} // Too short for internal node
s := newNodeStore()
// InternalNode with valid type byte but wrong length (needs exactly 65 bytes)
invalidData := []byte{nodeTypeInternal, 0, 0, 0}
_, err := DeserializeNode(invalidData, 0)
_, err := s.deserializeNode(invalidData, 0)
if err == nil {
t.Fatal("Expected error for invalid data length, got nil")
}
@ -186,7 +208,7 @@ func TestDeserializeInvalidLength(t *testing.T) {
}
}
// TestKeyToPath tests the keyToPath function
// TestKeyToPath tests the keyToPath function.
func TestKeyToPath(t *testing.T) {
tests := []struct {
name string
@ -218,14 +240,14 @@ func TestKeyToPath(t *testing.T) {
},
{
name: "max valid depth",
depth: StemSize * 8,
depth: StemSize*8 - 1,
key: make([]byte, HashSize),
expected: make([]byte, StemSize*8+1),
expected: make([]byte, StemSize*8),
wantErr: false,
},
{
name: "depth too large",
depth: StemSize*8 + 1,
depth: StemSize * 8,
key: make([]byte, HashSize),
wantErr: true,
},

View file

@ -1,76 +0,0 @@
// Copyright 2025 go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package bintrie
import (
"slices"
"github.com/ethereum/go-ethereum/common"
)
type Empty struct{}
func (e Empty) Get(_ []byte, _ NodeResolverFn) ([]byte, error) {
return nil, nil
}
func (e Empty) Insert(key []byte, value []byte, _ NodeResolverFn, depth int) (BinaryNode, error) {
var values [256][]byte
values[key[31]] = value
return &StemNode{
Stem: slices.Clone(key[:31]),
Values: values[:],
depth: depth,
mustRecompute: true,
dirty: true,
}, nil
}
func (e Empty) Copy() BinaryNode {
return Empty{}
}
func (e Empty) Hash() common.Hash {
return common.Hash{}
}
func (e Empty) GetValuesAtStem(_ []byte, _ NodeResolverFn) ([][]byte, error) {
var values [256][]byte
return values[:], nil
}
func (e Empty) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolverFn, depth int) (BinaryNode, error) {
return &StemNode{
Stem: slices.Clone(key[:31]),
Values: values,
depth: depth,
mustRecompute: true,
dirty: true,
}, nil
}
func (e Empty) CollectNodes(_ []byte, _ NodeFlushFn) error {
return nil
}
func (e Empty) toDot(parent string, path string) string {
return ""
}
func (e Empty) GetHeight() int {
return 0
}

View file

@ -1,264 +0,0 @@
// Copyright 2025 go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package bintrie
import (
"bytes"
"testing"
"github.com/ethereum/go-ethereum/common"
)
// TestEmptyGet tests the Get method
func TestEmptyGet(t *testing.T) {
node := Empty{}
key := make([]byte, 32)
value, err := node.Get(key, nil)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if value != nil {
t.Errorf("Expected nil value from empty node, got %x", value)
}
}
// TestEmptyInsert tests the Insert method
func TestEmptyInsert(t *testing.T) {
node := Empty{}
key := make([]byte, 32)
key[0] = 0x12
key[31] = 0x34
value := common.HexToHash("0xabcd").Bytes()
newNode, err := node.Insert(key, value, nil, 0)
if err != nil {
t.Fatalf("Failed to insert: %v", err)
}
// Should create a StemNode
stemNode, ok := newNode.(*StemNode)
if !ok {
t.Fatalf("Expected StemNode, got %T", newNode)
}
// Check the stem (first 31 bytes of key)
if !bytes.Equal(stemNode.Stem, key[:31]) {
t.Errorf("Stem mismatch: expected %x, got %x", key[:31], stemNode.Stem)
}
// Check the value at the correct index (last byte of key)
if !bytes.Equal(stemNode.Values[key[31]], value) {
t.Errorf("Value mismatch at index %d: expected %x, got %x", key[31], value, stemNode.Values[key[31]])
}
// Check that other values are nil
for i := 0; i < 256; i++ {
if i != int(key[31]) && stemNode.Values[i] != nil {
t.Errorf("Expected nil value at index %d, got %x", i, stemNode.Values[i])
}
}
}
// TestEmptyCopy tests the Copy method
func TestEmptyCopy(t *testing.T) {
node := Empty{}
copied := node.Copy()
copiedEmpty, ok := copied.(Empty)
if !ok {
t.Fatalf("Expected Empty, got %T", copied)
}
// Both should be empty
if node != copiedEmpty {
// Empty is a zero-value struct, so copies should be equal
t.Errorf("Empty nodes should be equal")
}
}
// TestEmptyHash tests the Hash method
func TestEmptyHash(t *testing.T) {
node := Empty{}
hash := node.Hash()
// Empty node should have zero hash
if hash != (common.Hash{}) {
t.Errorf("Expected zero hash for empty node, got %x", hash)
}
}
// TestEmptyGetValuesAtStem tests the GetValuesAtStem method
func TestEmptyGetValuesAtStem(t *testing.T) {
node := Empty{}
stem := make([]byte, 31)
values, err := node.GetValuesAtStem(stem, nil)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
// Should return an array of 256 nil values
if len(values) != 256 {
t.Errorf("Expected 256 values, got %d", len(values))
}
for i, v := range values {
if v != nil {
t.Errorf("Expected nil value at index %d, got %x", i, v)
}
}
}
// TestEmptyInsertValuesAtStem tests the InsertValuesAtStem method
func TestEmptyInsertValuesAtStem(t *testing.T) {
node := Empty{}
stem := make([]byte, 31)
stem[0] = 0x42
var values [256][]byte
values[0] = common.HexToHash("0x0101").Bytes()
values[10] = common.HexToHash("0x0202").Bytes()
values[255] = common.HexToHash("0x0303").Bytes()
newNode, err := node.InsertValuesAtStem(stem, values[:], nil, 5)
if err != nil {
t.Fatalf("Failed to insert values: %v", err)
}
// Should create a StemNode
stemNode, ok := newNode.(*StemNode)
if !ok {
t.Fatalf("Expected StemNode, got %T", newNode)
}
// Check the stem
if !bytes.Equal(stemNode.Stem, stem) {
t.Errorf("Stem mismatch: expected %x, got %x", stem, stemNode.Stem)
}
// Check the depth
if stemNode.depth != 5 {
t.Errorf("Depth mismatch: expected 5, got %d", stemNode.depth)
}
// Check the values
if !bytes.Equal(stemNode.Values[0], values[0]) {
t.Error("Value at index 0 mismatch")
}
if !bytes.Equal(stemNode.Values[10], values[10]) {
t.Error("Value at index 10 mismatch")
}
if !bytes.Equal(stemNode.Values[255], values[255]) {
t.Error("Value at index 255 mismatch")
}
// Check that values is the same slice (not a copy)
if &stemNode.Values[0] != &values[0] {
t.Error("Expected values to be the same slice reference")
}
}
// TestEmptyCollectNodes tests the CollectNodes method
func TestEmptyCollectNodes(t *testing.T) {
node := Empty{}
var collected []BinaryNode
flushFn := func(path []byte, n BinaryNode) {
collected = append(collected, n)
}
err := node.CollectNodes([]byte{0, 1, 0}, flushFn)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
// Should not collect anything for empty node
if len(collected) != 0 {
t.Errorf("Expected no collected nodes for empty, got %d", len(collected))
}
}
// TestEmptyToDot tests the toDot method
func TestEmptyToDot(t *testing.T) {
node := Empty{}
dot := node.toDot("parent", "010")
// Should return empty string for empty node
if dot != "" {
t.Errorf("Expected empty string for empty node toDot, got %s", dot)
}
}
// TestEmptyGetHeight tests the GetHeight method
func TestEmptyGetHeight(t *testing.T) {
node := Empty{}
height := node.GetHeight()
// Empty node should have height 0
if height != 0 {
t.Errorf("Expected height 0 for empty node, got %d", height)
}
}
// TestEmptyInsertMarksDirty verifies that a StemNode produced by Empty.Insert
// is marked dirty. Without this, CollectNodes would skip the freshly created
// stem and its blob would never reach disk, producing "missing trie node"
// errors on subsequent reads.
func TestEmptyInsertMarksDirty(t *testing.T) {
key := make([]byte, 32)
key[0] = 0xaa
val := make([]byte, 32)
val[0] = 0xbb
n, err := Empty{}.Insert(key, val, nil, 0)
if err != nil {
t.Fatalf("Insert: %v", err)
}
sn, ok := n.(*StemNode)
if !ok {
t.Fatalf("expected *StemNode, got %T", n)
}
if !sn.dirty {
t.Fatalf("stem produced by Empty.Insert must have dirty=true")
}
}
// TestEmptyInsertValuesAtStemMarksDirty is the analogous guard for the
// bulk-insert entry point. Fresh stems created here must be dirty.
func TestEmptyInsertValuesAtStemMarksDirty(t *testing.T) {
key := make([]byte, 32)
key[0] = 0xcc
values := make([][]byte, 256)
values[0] = make([]byte, 32)
n, err := Empty{}.InsertValuesAtStem(key, values, nil, 3)
if err != nil {
t.Fatalf("InsertValuesAtStem: %v", err)
}
sn, ok := n.(*StemNode)
if !ok {
t.Fatalf("expected *StemNode, got %T", n)
}
if !sn.dirty {
t.Fatalf("stem produced by Empty.InsertValuesAtStem must have dirty=true")
}
}

View file

@ -16,75 +16,10 @@
package bintrie
import (
"errors"
"fmt"
"github.com/ethereum/go-ethereum/common"
)
import "github.com/ethereum/go-ethereum/common"
// HashedNode is an unresolved node — only its hash is known.
type HashedNode common.Hash
func (h HashedNode) Get(_ []byte, _ NodeResolverFn) ([]byte, error) {
panic("not implemented") // TODO: Implement
}
func (h HashedNode) Insert(key []byte, value []byte, resolver NodeResolverFn, depth int) (BinaryNode, error) {
return nil, errors.New("insert not implemented for hashed node")
}
func (h HashedNode) Copy() BinaryNode {
nh := common.Hash(h)
return HashedNode(nh)
}
func (h HashedNode) Hash() common.Hash {
return common.Hash(h)
}
func (h HashedNode) GetValuesAtStem(_ []byte, _ NodeResolverFn) ([][]byte, error) {
return nil, errors.New("attempted to get values from an unresolved node")
}
func (h HashedNode) InsertValuesAtStem(stem []byte, values [][]byte, resolver NodeResolverFn, depth int) (BinaryNode, error) {
// Step 1: Generate the path for this node's position in the tree
path, err := keyToPath(depth, stem)
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem path generation error: %w", err)
}
if resolver == nil {
return nil, errors.New("InsertValuesAtStem resolve error: resolver is nil")
}
// Step 2: Resolve the hashed node to get the actual node data
data, err := resolver(path, common.Hash(h))
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
}
// Step 3: Deserialize the resolved data into a concrete node
node, err := DeserializeNodeWithHash(data, depth, common.Hash(h))
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
}
// Step 4: Call InsertValuesAtStem on the resolved concrete node
return node.InsertValuesAtStem(stem, values, resolver, depth)
}
func (h HashedNode) toDot(parent string, path string) string {
me := fmt.Sprintf("hash%s", path)
ret := fmt.Sprintf("%s [label=\"%x\"]\n", me, h)
ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me)
return ret
}
func (h HashedNode) CollectNodes([]byte, NodeFlushFn) error {
// HashedNodes are already persisted in the database and don't need to be collected.
return nil
}
func (h HashedNode) GetHeight() int {
panic("tried to get the height of a hashed node, this is a bug")
}
// Hash returns the node's hash.
func (h HashedNode) Hash() common.Hash { return common.Hash(h) }

View file

@ -18,180 +18,137 @@ package bintrie
import (
"bytes"
"errors"
"testing"
"github.com/ethereum/go-ethereum/common"
)
// TestHashedNodeHash tests the Hash method
// TestHashedNodeHash tests the Hash method via nodeStore.
func TestHashedNodeHash(t *testing.T) {
hash := common.HexToHash("0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef")
node := HashedNode(hash)
s := newNodeStore()
ref := s.newHashedRef(hash)
// Hash should return the stored hash
if node.Hash() != hash {
t.Errorf("Hash mismatch: expected %x, got %x", hash, node.Hash())
if s.computeHash(ref) != hash {
t.Errorf("Hash mismatch: expected %x, got %x", hash, s.computeHash(ref))
}
}
// TestHashedNodeCopy tests the Copy method
// TestHashedNodeCopy tests the Copy method via nodeStore.
func TestHashedNodeCopy(t *testing.T) {
hash := common.HexToHash("0xabcdef")
node := HashedNode(hash)
s := newNodeStore()
ref := s.newHashedRef(hash)
s.root = ref
copied := node.Copy()
copiedHash, ok := copied.(HashedNode)
if !ok {
t.Fatalf("Expected HashedNode, got %T", copied)
}
ns := s.Copy()
copiedHash := ns.computeHash(ns.root)
// Hash should be the same
if common.Hash(copiedHash) != hash {
if copiedHash != hash {
t.Errorf("Hash mismatch after copy: expected %x, got %x", hash, copiedHash)
}
// But should be a different object
if &node == &copiedHash {
t.Error("Copy returned same object reference")
}
}
// TestHashedNodeInsert tests that Insert returns an error
func TestHashedNodeInsert(t *testing.T) {
node := HashedNode(common.HexToHash("0x1234"))
key := make([]byte, HashSize)
value := make([]byte, HashSize)
_, err := node.Insert(key, value, nil, 0)
if err == nil {
t.Fatal("Expected error for Insert on HashedNode")
}
if err.Error() != "insert not implemented for hashed node" {
t.Errorf("Unexpected error message: %v", err)
}
}
// TestHashedNodeGetValuesAtStem tests that GetValuesAtStem returns an error
func TestHashedNodeGetValuesAtStem(t *testing.T) {
node := HashedNode(common.HexToHash("0x1234"))
stem := make([]byte, StemSize)
_, err := node.GetValuesAtStem(stem, nil)
if err == nil {
t.Fatal("Expected error for GetValuesAtStem on HashedNode")
}
if err.Error() != "attempted to get values from an unresolved node" {
t.Errorf("Unexpected error message: %v", err)
}
}
// TestHashedNodeInsertValuesAtStem tests that InsertValuesAtStem returns an error
// TestHashedNodeInsertValuesAtStem tests InsertValuesAtStem resolution via nodeStore.
func TestHashedNodeInsertValuesAtStem(t *testing.T) {
node := HashedNode(common.HexToHash("0x1234"))
// Test 1: nil resolver should return an error
s := newNodeStore()
hashedRef := s.newHashedRef(common.HexToHash("0x1234"))
s.root = hashedRef
stem := make([]byte, StemSize)
values := make([][]byte, StemNodeWidth)
// Test 1: nil resolver should return an error
_, err := node.InsertValuesAtStem(stem, values, nil, 0)
err := s.InsertValuesAtStem(stem, values, nil)
if err == nil {
t.Fatal("Expected error for InsertValuesAtStem on HashedNode with nil resolver")
}
if err.Error() != "InsertValuesAtStem resolve error: resolver is nil" {
t.Errorf("Unexpected error message: %v", err)
t.Fatal("Expected error for InsertValuesAtStem with nil resolver")
}
// Test 2: mock resolver returning invalid data should return deserialization error
mockResolver := func(path []byte, hash common.Hash) ([]byte, error) {
// Return invalid/nonsense data that cannot be deserialized
return []byte{0xff, 0xff, 0xff}, nil
}
_, err = node.InsertValuesAtStem(stem, values, mockResolver, 0)
if err == nil {
t.Fatal("Expected error for InsertValuesAtStem on HashedNode with invalid resolver data")
}
s2 := newNodeStore()
hashedRef2 := s2.newHashedRef(common.HexToHash("0x1234"))
s2.root = hashedRef2
expectedPrefix := "InsertValuesAtStem node deserialization error:"
if len(err.Error()) < len(expectedPrefix) || err.Error()[:len(expectedPrefix)] != expectedPrefix {
t.Errorf("Expected deserialization error, got: %v", err)
err = s2.InsertValuesAtStem(stem, values, mockResolver)
if err == nil {
t.Fatal("Expected error for InsertValuesAtStem with invalid resolver data")
}
// Test 3: mock resolver returning valid serialized node should succeed
stem = make([]byte, StemSize)
stem[0] = 0xaa
var originalValues [StemNodeWidth][]byte
originalValues := make([][]byte, StemNodeWidth)
originalValues[0] = common.HexToHash("0x1111111111111111111111111111111111111111111111111111111111111111").Bytes()
originalValues[1] = common.HexToHash("0x2222222222222222222222222222222222222222222222222222222222222222").Bytes()
originalNode := &StemNode{
Stem: stem,
Values: originalValues[:],
depth: 0,
// Build the serialized node
rs := newNodeStore()
ref := rs.newStemRef(stem, 0)
sn := rs.getStem(ref.Index())
for i, v := range originalValues {
if v != nil {
sn.setValue(byte(i), v)
}
}
serialized := rs.serializeNode(ref)
// Serialize the node
serialized := SerializeNode(originalNode)
// Create a mock resolver that returns the serialized node
validResolver := func(path []byte, hash common.Hash) ([]byte, error) {
return serialized, nil
}
var newValues [StemNodeWidth][]byte
s3 := newNodeStore()
hashedRef3 := s3.newHashedRef(common.HexToHash("0x1234"))
s3.root = hashedRef3
newValues := make([][]byte, StemNodeWidth)
newValues[2] = common.HexToHash("0x3333333333333333333333333333333333333333333333333333333333333333").Bytes()
resolvedNode, err := node.InsertValuesAtStem(stem, newValues[:], validResolver, 0)
err = s3.InsertValuesAtStem(stem, newValues, validResolver)
if err != nil {
t.Fatalf("Expected successful resolution and insertion, got error: %v", err)
}
resultStem, ok := resolvedNode.(*StemNode)
if !ok {
t.Fatalf("Expected resolved node to be *StemNode, got %T", resolvedNode)
// Verify original values are preserved
retrieved, err := s3.GetValuesAtStem(stem, nil)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(resultStem.Stem, stem) {
t.Errorf("Stem mismatch: expected %x, got %x", stem, resultStem.Stem)
if !bytes.Equal(retrieved[0], originalValues[0]) {
t.Errorf("Original value at index 0 not preserved")
}
// Verify the original values are preserved
if !bytes.Equal(resultStem.Values[0], originalValues[0]) {
t.Errorf("Original value at index 0 not preserved: expected %x, got %x", originalValues[0], resultStem.Values[0])
if !bytes.Equal(retrieved[1], originalValues[1]) {
t.Errorf("Original value at index 1 not preserved")
}
if !bytes.Equal(resultStem.Values[1], originalValues[1]) {
t.Errorf("Original value at index 1 not preserved: expected %x, got %x", originalValues[1], resultStem.Values[1])
}
// Verify the new value was inserted
if !bytes.Equal(resultStem.Values[2], newValues[2]) {
t.Errorf("New value at index 2 not inserted correctly: expected %x, got %x", newValues[2], resultStem.Values[2])
if !bytes.Equal(retrieved[2], newValues[2]) {
t.Errorf("New value at index 2 not inserted correctly")
}
}
// TestHashedNodeToDot tests the toDot method for visualization
func TestHashedNodeToDot(t *testing.T) {
hash := common.HexToHash("0x1234")
node := HashedNode(hash)
// TestHashedNodeGetError tests that getting through an unresolved HashedNode root returns error.
func TestHashedNodeGetError(t *testing.T) {
s := newNodeStore()
// Create root as hashed, then try to resolve through InternalNode parent
rootRef := s.newInternalRef(0)
rootNode := s.getInternal(rootRef.Index())
hashedLeft := s.newHashedRef(common.HexToHash("0x1234"))
rootNode.left = hashedLeft
rootNode.right = emptyRef
s.root = rootRef
dot := node.toDot("parent", "010")
key := make([]byte, 32) // goes left
key[31] = 5
// Should contain the hash value and parent connection
expectedHash := "hash010"
if !contains(dot, expectedHash) {
t.Errorf("Expected dot output to contain %s", expectedHash)
resolver := func(path []byte, hash common.Hash) ([]byte, error) {
return nil, errors.New("node not found")
}
if !contains(dot, "parent -> hash010") {
t.Error("Expected dot output to contain parent connection")
_, err := s.Get(key, resolver)
if err == nil {
t.Fatal("Expected error when resolver fails")
}
}
// Helper function
func contains(s, substr string) bool {
return len(s) >= len(substr) && s != "" && len(substr) > 0
}

View file

@ -17,35 +17,13 @@
package bintrie
import (
"crypto/sha256"
"errors"
"fmt"
"math/bits"
"runtime"
"sync"
"github.com/ethereum/go-ethereum/common"
)
// parallelDepth returns the tree depth below which Hash() spawns goroutines.
func parallelDepth() int {
return min(bits.Len(uint(runtime.NumCPU())), 8)
}
// isDirty reports whether a BinaryNode child needs rehashing.
func isDirty(n BinaryNode) bool {
switch v := n.(type) {
case *InternalNode:
return v.mustRecompute
case *StemNode:
return v.mustRecompute
default:
return false
}
}
func keyToPath(depth int, key []byte) ([]byte, error) {
if depth > 31*8 {
if depth >= 31*8 {
return nil, errors.New("node too deep")
}
path := make([]byte, 0, depth+1)
@ -56,252 +34,12 @@ func keyToPath(depth int, key []byte) ([]byte, error) {
return path, nil
}
// InternalNode is a binary trie internal node.
// Invariant: dirty=false implies mustRecompute=false. Every mutation that
// invalidates the cached hash MUST also mark the blob for re-flush.
type InternalNode struct {
left, right BinaryNode
depth int
mustRecompute bool // true if the hash needs to be recomputed
dirty bool // true if the node's on-disk blob is stale (needs flush)
hash common.Hash // cached hash when mustRecompute == false
}
// GetValuesAtStem retrieves the group of values located at the given stem key.
func (bt *InternalNode) GetValuesAtStem(stem []byte, resolver NodeResolverFn) ([][]byte, error) {
if bt.depth > 31*8 {
return nil, errors.New("node too deep")
}
bit := stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1
if bit == 0 {
if hn, ok := bt.left.(HashedNode); ok {
path, err := keyToPath(bt.depth, stem)
if err != nil {
return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
}
data, err := resolver(path, common.Hash(hn))
if err != nil {
return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
}
node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn))
if err != nil {
return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err)
}
bt.left = node
}
return bt.left.GetValuesAtStem(stem, resolver)
}
if hn, ok := bt.right.(HashedNode); ok {
path, err := keyToPath(bt.depth, stem)
if err != nil {
return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
}
data, err := resolver(path, common.Hash(hn))
if err != nil {
return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
}
node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn))
if err != nil {
return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err)
}
bt.right = node
}
return bt.right.GetValuesAtStem(stem, resolver)
}
// Get retrieves the value for the given key.
func (bt *InternalNode) Get(key []byte, resolver NodeResolverFn) ([]byte, error) {
values, err := bt.GetValuesAtStem(key[:31], resolver)
if err != nil {
return nil, fmt.Errorf("get error: %w", err)
}
if values == nil {
return nil, nil
}
return values[key[31]], nil
}
// Insert inserts a new key-value pair into the trie.
func (bt *InternalNode) Insert(key []byte, value []byte, resolver NodeResolverFn, depth int) (BinaryNode, error) {
var values [256][]byte
values[key[31]] = value
return bt.InsertValuesAtStem(key[:31], values[:], resolver, depth)
}
// Copy creates a deep copy of the node.
func (bt *InternalNode) Copy() BinaryNode {
return &InternalNode{
left: bt.left.Copy(),
right: bt.right.Copy(),
depth: bt.depth,
mustRecompute: bt.mustRecompute,
dirty: bt.dirty,
hash: bt.hash,
}
}
// Hash returns the hash of the node.
func (bt *InternalNode) Hash() common.Hash {
if !bt.mustRecompute {
return bt.hash
}
// At shallow depths, parallelize when both children need rehashing:
// hash left subtree in a goroutine, right subtree inline, then combine.
// Skip goroutine overhead when only one child is dirty (common case
// for narrow state updates that touch a single path through the trie).
if bt.depth < parallelDepth() && isDirty(bt.left) && isDirty(bt.right) {
var input [64]byte
var lh common.Hash
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
lh = bt.left.Hash()
}()
rh := bt.right.Hash()
copy(input[32:], rh[:])
wg.Wait()
copy(input[:32], lh[:])
bt.hash = sha256.Sum256(input[:])
bt.mustRecompute = false
return bt.hash
}
// Deeper nodes: sequential using pooled hasher (goroutine overhead > hash cost)
h := newSha256()
defer returnSha256(h)
if bt.left != nil {
h.Write(bt.left.Hash().Bytes())
} else {
h.Write(zero[:])
}
if bt.right != nil {
h.Write(bt.right.Hash().Bytes())
} else {
h.Write(zero[:])
}
bt.hash = common.BytesToHash(h.Sum(nil))
bt.mustRecompute = false
return bt.hash
}
// InsertValuesAtStem inserts a full value group at the given stem in the internal node.
// Already-existing values will be overwritten.
func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolver NodeResolverFn, depth int) (BinaryNode, error) {
var err error
bit := stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1
if bit == 0 {
if bt.left == nil {
bt.left = Empty{}
}
if hn, ok := bt.left.(HashedNode); ok {
path, err := keyToPath(bt.depth, stem)
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
}
data, err := resolver(path, common.Hash(hn))
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
}
node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn))
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
}
bt.left = node
}
bt.left, err = bt.left.InsertValuesAtStem(stem, values, resolver, depth+1)
bt.mustRecompute = true
bt.dirty = true
return bt, err
}
if bt.right == nil {
bt.right = Empty{}
}
if hn, ok := bt.right.(HashedNode); ok {
path, err := keyToPath(bt.depth, stem)
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
}
data, err := resolver(path, common.Hash(hn))
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
}
node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn))
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
}
bt.right = node
}
bt.right, err = bt.right.InsertValuesAtStem(stem, values, resolver, depth+1)
bt.mustRecompute = true
bt.dirty = true
return bt, err
}
// CollectNodes collects all child nodes at a given path, and flushes it
// into the provided node collector. Clean subtrees (dirty == false) are
// skipped.
func (bt *InternalNode) CollectNodes(path []byte, flushfn NodeFlushFn) error {
if !bt.dirty {
return nil
}
if bt.left != nil {
var p [256]byte
copy(p[:], path)
childpath := p[:len(path)]
childpath = append(childpath, 0)
if err := bt.left.CollectNodes(childpath, flushfn); err != nil {
return err
}
}
if bt.right != nil {
var p [256]byte
copy(p[:], path)
childpath := p[:len(path)]
childpath = append(childpath, 1)
if err := bt.right.CollectNodes(childpath, flushfn); err != nil {
return err
}
}
flushfn(path, bt)
bt.dirty = false
return nil
}
// GetHeight returns the height of the node.
func (bt *InternalNode) GetHeight() int {
var (
leftHeight int
rightHeight int
)
if bt.left != nil {
leftHeight = bt.left.GetHeight()
}
if bt.right != nil {
rightHeight = bt.right.GetHeight()
}
return 1 + max(leftHeight, rightHeight)
}
func (bt *InternalNode) toDot(parent, path string) string {
me := fmt.Sprintf("internal%s", path)
ret := fmt.Sprintf("%s [label=\"I: %x\"]\n", me, bt.Hash())
if len(parent) > 0 {
ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me)
}
if bt.left != nil {
ret = fmt.Sprintf("%s%s", ret, bt.left.toDot(me, fmt.Sprintf("%s%02x", path, 0)))
}
if bt.right != nil {
ret = fmt.Sprintf("%s%s", ret, bt.right.toDot(me, fmt.Sprintf("%s%02x", path, 1)))
}
return ret
left, right nodeRef
depth uint8
mustRecompute bool // hash is stale (cleared by Hash)
dirty bool // on-disk blob is stale (cleared by CollectNodes)
hash common.Hash
}

View file

@ -24,35 +24,33 @@ import (
"github.com/ethereum/go-ethereum/common"
)
// TestInternalNodeGet tests the Get method
// TestInternalNodeGet tests the Get method via nodeStore.
func TestInternalNodeGet(t *testing.T) {
// Create a simple tree structure
s := newNodeStore()
leftStem := make([]byte, 31)
rightStem := make([]byte, 31)
rightStem[0] = 0x80 // First bit is 1
rightStem[0] = 0x80
var leftValues, rightValues [256][]byte
leftValues := make([][]byte, 256)
leftValues[0] = common.HexToHash("0x0101").Bytes()
rightValues := make([][]byte, 256)
rightValues[0] = common.HexToHash("0x0202").Bytes()
node := &InternalNode{
depth: 0,
left: &StemNode{
Stem: leftStem,
Values: leftValues[:],
depth: 1,
},
right: &StemNode{
Stem: rightStem,
Values: rightValues[:],
depth: 1,
},
// Build tree: root -> left stem, right stem
// Insert left stem values
s.root = emptyRef
if err := s.InsertValuesAtStem(leftStem, leftValues, nil); err != nil {
t.Fatal(err)
}
if err := s.InsertValuesAtStem(rightStem, rightValues, nil); err != nil {
t.Fatal(err)
}
// Get value from left subtree
leftKey := make([]byte, 32)
leftKey[31] = 0
value, err := node.Get(leftKey, nil)
value, err := s.Get(leftKey, nil)
if err != nil {
t.Fatalf("Failed to get left value: %v", err)
}
@ -64,7 +62,7 @@ func TestInternalNodeGet(t *testing.T) {
rightKey := make([]byte, 32)
rightKey[0] = 0x80
rightKey[31] = 0
value, err = node.Get(rightKey, nil)
value, err = s.Get(rightKey, nil)
if err != nil {
t.Fatalf("Failed to get right value: %v", err)
}
@ -73,29 +71,26 @@ func TestInternalNodeGet(t *testing.T) {
}
}
// TestInternalNodeGetWithResolver tests Get with HashedNode resolution
// TestInternalNodeGetWithResolver tests Get with HashedNode resolution via nodeStore.
func TestInternalNodeGetWithResolver(t *testing.T) {
// Create an internal node with a hashed child
hashedChild := HashedNode(common.HexToHash("0x1234"))
node := &InternalNode{
depth: 0,
left: hashedChild,
right: Empty{},
}
// Create a store with an internal node containing a hashed child
s := newNodeStore()
hashedChild := s.newHashedRef(common.HexToHash("0x1234"))
rootRef := s.newInternalRef(0)
rootNode := s.getInternal(rootRef.Index())
rootNode.left = hashedChild
rootNode.right = emptyRef
s.root = rootRef
// Mock resolver that returns a stem node
resolver := func(path []byte, hash common.Hash) ([]byte, error) {
if hash == common.Hash(hashedChild) {
if hash == common.HexToHash("0x1234") {
rs := newNodeStore()
stem := make([]byte, 31)
var values [256][]byte
values[5] = common.HexToHash("0xabcd").Bytes()
stemNode := &StemNode{
Stem: stem,
Values: values[:],
depth: 1,
}
return SerializeNode(stemNode), nil
ref := rs.newStemRef(stem, 1)
sn := rs.getStem(ref.Index())
sn.setValue(5, common.HexToHash("0xabcd").Bytes())
return rs.serializeNode(ref), nil
}
return nil, errors.New("node not found")
}
@ -103,7 +98,7 @@ func TestInternalNodeGetWithResolver(t *testing.T) {
// Get value through the hashed node
key := make([]byte, 32)
key[31] = 5
value, err := node.Get(key, resolver)
value, err := s.Get(key, resolver)
if err != nil {
t.Fatalf("Failed to get value: %v", err)
}
@ -114,179 +109,113 @@ func TestInternalNodeGetWithResolver(t *testing.T) {
}
}
// TestInternalNodeInsert tests the Insert method
// TestInternalNodeInsert tests the Insert method via nodeStore.
func TestInternalNodeInsert(t *testing.T) {
// Start with an internal node with empty children
node := &InternalNode{
depth: 0,
left: Empty{},
right: Empty{},
}
s := newNodeStore()
// Insert a value into the left subtree
leftKey := make([]byte, 32)
leftKey[31] = 10
leftValue := common.HexToHash("0x0101").Bytes()
newNode, err := node.Insert(leftKey, leftValue, nil, 0)
if err != nil {
if err := s.Insert(leftKey, leftValue, nil); err != nil {
t.Fatalf("Failed to insert: %v", err)
}
internalNode, ok := newNode.(*InternalNode)
if !ok {
t.Fatalf("Expected InternalNode, got %T", newNode)
// Verify the value was stored
value, err := s.Get(leftKey, nil)
if err != nil {
t.Fatalf("Failed to get: %v", err)
}
// Check that left child is now a StemNode
leftStem, ok := internalNode.left.(*StemNode)
if !ok {
t.Fatalf("Expected left child to be StemNode, got %T", internalNode.left)
}
// Check the inserted value
if !bytes.Equal(leftStem.Values[10], leftValue) {
t.Errorf("Value mismatch: expected %x, got %x", leftValue, leftStem.Values[10])
}
// Right child should still be Empty
_, ok = internalNode.right.(Empty)
if !ok {
t.Errorf("Expected right child to remain Empty, got %T", internalNode.right)
if !bytes.Equal(value, leftValue) {
t.Errorf("Value mismatch: expected %x, got %x", leftValue, value)
}
}
// TestInternalNodeCopy tests the Copy method
// TestInternalNodeCopy tests the Copy method via nodeStore.
func TestInternalNodeCopy(t *testing.T) {
// Create an internal node with stem children
leftStem := &StemNode{
Stem: make([]byte, 31),
Values: make([][]byte, 256),
depth: 1,
}
leftStem.Values[0] = common.HexToHash("0x0101").Bytes()
s := newNodeStore()
rightStem := &StemNode{
Stem: make([]byte, 31),
Values: make([][]byte, 256),
depth: 1,
}
rightStem.Stem[0] = 0x80
rightStem.Values[0] = common.HexToHash("0x0202").Bytes()
leftKey := make([]byte, 32)
leftKey[31] = 0
leftValue := common.HexToHash("0x0101").Bytes()
node := &InternalNode{
depth: 0,
left: leftStem,
right: rightStem,
rightKey := make([]byte, 32)
rightKey[0] = 0x80
rightKey[31] = 0
rightValue := common.HexToHash("0x0202").Bytes()
if err := s.Insert(leftKey, leftValue, nil); err != nil {
t.Fatal(err)
}
if err := s.Insert(rightKey, rightValue, nil); err != nil {
t.Fatal(err)
}
// Create a copy
copied := node.Copy()
copiedInternal, ok := copied.(*InternalNode)
if !ok {
t.Fatalf("Expected InternalNode, got %T", copied)
}
ns := s.Copy()
// Check depth
if copiedInternal.depth != node.depth {
t.Errorf("Depth mismatch: expected %d, got %d", node.depth, copiedInternal.depth)
}
// Check that children are copied
copiedLeft, ok := copiedInternal.left.(*StemNode)
if !ok {
t.Fatalf("Expected left child to be StemNode, got %T", copiedInternal.left)
}
copiedRight, ok := copiedInternal.right.(*StemNode)
if !ok {
t.Fatalf("Expected right child to be StemNode, got %T", copiedInternal.right)
}
// Verify deep copy (children should be different objects)
if copiedLeft == leftStem {
t.Error("Left child not properly copied")
}
if copiedRight == rightStem {
t.Error("Right child not properly copied")
}
// But values should be equal
if !bytes.Equal(copiedLeft.Values[0], leftStem.Values[0]) {
// Values should be equal
v1, _ := ns.Get(leftKey, nil)
if !bytes.Equal(v1, leftValue) {
t.Error("Left child value mismatch after copy")
}
if !bytes.Equal(copiedRight.Values[0], rightStem.Values[0]) {
v2, _ := ns.Get(rightKey, nil)
if !bytes.Equal(v2, rightValue) {
t.Error("Right child value mismatch after copy")
}
}
// TestInternalNodeHash tests the Hash method
// TestInternalNodeHash tests the Hash method via nodeStore.
func TestInternalNodeHash(t *testing.T) {
// Create an internal node
node := &InternalNode{
depth: 0,
left: HashedNode(common.HexToHash("0x1111")),
right: HashedNode(common.HexToHash("0x2222")),
}
s := newNodeStore()
leftRef := s.newHashedRef(common.HexToHash("0x1111"))
rightRef := s.newHashedRef(common.HexToHash("0x2222"))
rootRef := s.newInternalRef(0)
rootNode := s.getInternal(rootRef.Index())
rootNode.left = leftRef
rootNode.right = rightRef
s.root = rootRef
hash1 := node.Hash()
hash1 := s.computeHash(rootRef)
// Hash should be deterministic
hash2 := node.Hash()
hash2 := s.computeHash(rootRef)
if hash1 != hash2 {
t.Errorf("Hash not deterministic: %x != %x", hash1, hash2)
}
// Changing a child should change the hash
node.left = HashedNode(common.HexToHash("0x3333"))
node.mustRecompute = true
hash3 := node.Hash()
rootNode.left = s.newHashedRef(common.HexToHash("0x3333"))
rootNode.mustRecompute = true
hash3 := s.computeHash(rootRef)
if hash1 == hash3 {
t.Error("Hash didn't change after modifying left child")
}
// Test with nil children (should use zero hash)
nodeWithNil := &InternalNode{
depth: 0,
left: nil,
right: HashedNode(common.HexToHash("0x4444")),
mustRecompute: true,
}
hashWithNil := nodeWithNil.Hash()
if hashWithNil == (common.Hash{}) {
t.Error("Hash shouldn't be zero even with nil child")
}
}
// TestInternalNodeGetValuesAtStem tests GetValuesAtStem method
// TestInternalNodeGetValuesAtStem tests GetValuesAtStem method via nodeStore.
func TestInternalNodeGetValuesAtStem(t *testing.T) {
// Create a tree with values at different stems
s := newNodeStore()
leftStem := make([]byte, 31)
rightStem := make([]byte, 31)
rightStem[0] = 0x80
var leftValues, rightValues [256][]byte
leftValues := make([][]byte, 256)
leftValues[0] = common.HexToHash("0x0101").Bytes()
leftValues[10] = common.HexToHash("0x0102").Bytes()
rightValues := make([][]byte, 256)
rightValues[0] = common.HexToHash("0x0201").Bytes()
rightValues[20] = common.HexToHash("0x0202").Bytes()
node := &InternalNode{
depth: 0,
left: &StemNode{
Stem: leftStem,
Values: leftValues[:],
depth: 1,
},
right: &StemNode{
Stem: rightStem,
Values: rightValues[:],
depth: 1,
},
if err := s.InsertValuesAtStem(leftStem, leftValues, nil); err != nil {
t.Fatal(err)
}
if err := s.InsertValuesAtStem(rightStem, rightValues, nil); err != nil {
t.Fatal(err)
}
// Get values from left stem
values, err := node.GetValuesAtStem(leftStem, nil)
values, err := s.GetValuesAtStem(leftStem, nil)
if err != nil {
t.Fatalf("Failed to get left values: %v", err)
}
@ -298,7 +227,7 @@ func TestInternalNodeGetValuesAtStem(t *testing.T) {
}
// Get values from right stem
values, err = node.GetValuesAtStem(rightStem, nil)
values, err = s.GetValuesAtStem(rightStem, nil)
if err != nil {
t.Fatalf("Failed to get right values: %v", err)
}
@ -310,201 +239,103 @@ func TestInternalNodeGetValuesAtStem(t *testing.T) {
}
}
// TestInternalNodeInsertValuesAtStem tests InsertValuesAtStem method
// TestInternalNodeInsertValuesAtStem tests InsertValuesAtStem method via nodeStore.
func TestInternalNodeInsertValuesAtStem(t *testing.T) {
// Start with an internal node with empty children
node := &InternalNode{
depth: 0,
left: Empty{},
right: Empty{},
}
s := newNodeStore()
// Insert values at a stem in the left subtree
stem := make([]byte, 31)
var values [256][]byte
values := make([][]byte, 256)
values[5] = common.HexToHash("0x0505").Bytes()
values[10] = common.HexToHash("0x1010").Bytes()
newNode, err := node.InsertValuesAtStem(stem, values[:], nil, 0)
if err != nil {
if err := s.InsertValuesAtStem(stem, values, nil); err != nil {
t.Fatalf("Failed to insert values: %v", err)
}
internalNode, ok := newNode.(*InternalNode)
if !ok {
t.Fatalf("Expected InternalNode, got %T", newNode)
// Check that the values are stored
retrieved, err := s.GetValuesAtStem(stem, nil)
if err != nil {
t.Fatalf("Failed to get values: %v", err)
}
// Check that left child is now a StemNode with the values
leftStem, ok := internalNode.left.(*StemNode)
if !ok {
t.Fatalf("Expected left child to be StemNode, got %T", internalNode.left)
}
if !bytes.Equal(leftStem.Values[5], values[5]) {
if !bytes.Equal(retrieved[5], values[5]) {
t.Error("Value at index 5 mismatch")
}
if !bytes.Equal(leftStem.Values[10], values[10]) {
if !bytes.Equal(retrieved[10], values[10]) {
t.Error("Value at index 10 mismatch")
}
}
// TestInternalNodeCollectNodes tests CollectNodes method
// TestInternalNodeCollectNodes tests CollectNodes method via nodeStore.
func TestInternalNodeCollectNodes(t *testing.T) {
// Create an internal node with two stem children. All three are
// marked dirty to mirror production semantics — see CollectNodes.
leftStem := &StemNode{
Stem: make([]byte, 31),
Values: make([][]byte, 256),
depth: 1,
dirty: true,
}
s := newNodeStore()
rightStem := &StemNode{
Stem: make([]byte, 31),
Values: make([][]byte, 256),
depth: 1,
dirty: true,
}
rightStem.Stem[0] = 0x80
leftStem := make([]byte, 31)
rightStem := make([]byte, 31)
rightStem[0] = 0x80
node := &InternalNode{
depth: 0,
left: leftStem,
right: rightStem,
dirty: true,
leftValues := make([][]byte, 256)
rightValues := make([][]byte, 256)
if err := s.InsertValuesAtStem(leftStem, leftValues, nil); err != nil {
t.Fatal(err)
}
if err := s.InsertValuesAtStem(rightStem, rightValues, nil); err != nil {
t.Fatal(err)
}
var collectedPaths [][]byte
var collectedNodes []BinaryNode
flushFn := func(path []byte, n BinaryNode) {
flushFn := func(path []byte, hash common.Hash, serialized []byte) {
pathCopy := make([]byte, len(path))
copy(pathCopy, path)
collectedPaths = append(collectedPaths, pathCopy)
collectedNodes = append(collectedNodes, n)
}
err := node.CollectNodes([]byte{1}, flushFn)
err := s.collectNodes(s.root, []byte{1}, flushFn)
if err != nil {
t.Fatalf("Failed to collect nodes: %v", err)
}
// Should have collected 3 nodes: left stem, right stem, and the internal node itself
if len(collectedNodes) != 3 {
t.Errorf("Expected 3 collected nodes, got %d", len(collectedNodes))
}
// Check paths
expectedPaths := [][]byte{
{1, 0}, // left child
{1, 1}, // right child
{1}, // internal node itself
}
for i, expectedPath := range expectedPaths {
if !bytes.Equal(collectedPaths[i], expectedPath) {
t.Errorf("Path %d mismatch: expected %v, got %v", i, expectedPath, collectedPaths[i])
}
if len(collectedPaths) != 3 {
t.Errorf("Expected 3 collected nodes, got %d", len(collectedPaths))
}
}
// TestInternalNodeCollectNodesSkipsClean verifies clean subtrees are not
// flushed. A dirty internal node over clean children only flushes itself;
// a fully clean tree flushes nothing.
func TestInternalNodeCollectNodesSkipsClean(t *testing.T) {
leftStem := &StemNode{
Stem: make([]byte, 31),
Values: make([][]byte, 256),
depth: 1,
}
rightStem := &StemNode{
Stem: make([]byte, 31),
Values: make([][]byte, 256),
depth: 1,
}
rightStem.Stem[0] = 0x80
dirtyParent := &InternalNode{
depth: 0,
left: leftStem,
right: rightStem,
dirty: true,
}
var collected []BinaryNode
flushFn := func(_ []byte, n BinaryNode) { collected = append(collected, n) }
if err := dirtyParent.CollectNodes([]byte{1}, flushFn); err != nil {
t.Fatalf("CollectNodes: %v", err)
}
if len(collected) != 1 || collected[0] != dirtyParent {
t.Fatalf("expected only the dirty parent to be flushed, got %d nodes", len(collected))
}
if dirtyParent.dirty {
t.Errorf("parent dirty flag should be cleared after flush")
}
// Second call on the same tree should be a no-op: everything is clean.
collected = nil
if err := dirtyParent.CollectNodes([]byte{1}, flushFn); err != nil {
t.Fatalf("CollectNodes (second call): %v", err)
}
if len(collected) != 0 {
t.Errorf("expected no nodes to be flushed on clean tree, got %d", len(collected))
}
}
// TestInternalNodeGetHeight tests GetHeight method
// TestInternalNodeGetHeight tests GetHeight method via nodeStore.
func TestInternalNodeGetHeight(t *testing.T) {
// Create a tree with different heights
// Left subtree: depth 2 (internal -> stem)
// Right subtree: depth 1 (stem)
leftInternal := &InternalNode{
depth: 1,
left: &StemNode{
Stem: make([]byte, 31),
Values: make([][]byte, 256),
depth: 2,
},
right: Empty{},
s := newNodeStore()
// Insert values that create a deeper tree
stem1 := make([]byte, 31) // left
stem2 := make([]byte, 31)
stem2[0] = 0x40 // 01... -> goes left at depth 0, right at depth 1
values1 := make([][]byte, 256)
values1[0] = common.HexToHash("0x01").Bytes()
values2 := make([][]byte, 256)
values2[0] = common.HexToHash("0x02").Bytes()
if err := s.InsertValuesAtStem(stem1, values1, nil); err != nil {
t.Fatal(err)
}
if err := s.InsertValuesAtStem(stem2, values2, nil); err != nil {
t.Fatal(err)
}
rightStem := &StemNode{
Stem: make([]byte, 31),
Values: make([][]byte, 256),
depth: 1,
}
node := &InternalNode{
depth: 0,
left: leftInternal,
right: rightStem,
}
height := node.GetHeight()
// Height should be max(left height, right height) + 1
// Left height: 2, Right height: 1, so total: 3
if height != 3 {
t.Errorf("Expected height 3, got %d", height)
height := s.getHeight(s.root)
if height < 2 {
t.Errorf("Expected height >= 2, got %d", height)
}
}
// TestInternalNodeDepthTooLarge tests handling of excessive depth
// TestInternalNodeDepthTooLarge tests handling of excessive depth via nodeStore.
func TestInternalNodeDepthTooLarge(t *testing.T) {
// Create an internal node at max depth
node := &InternalNode{
depth: 31*8 + 1,
left: Empty{},
right: Empty{},
}
stem := make([]byte, 31)
_, err := node.GetValuesAtStem(stem, nil)
if err == nil {
t.Fatal("Expected error for excessive depth")
}
if err.Error() != "node too deep" {
t.Errorf("Expected 'node too deep' error, got: %v", err)
}
s := newNodeStore()
// Creating an internal node beyond max depth should panic
defer func() {
if r := recover(); r == nil {
t.Fatal("Expected panic for excessive depth")
}
}()
s.newInternalRef(31*8 + 1)
}

View file

@ -26,13 +26,14 @@ import (
var errIteratorEnd = errors.New("end of iteration")
type binaryNodeIteratorState struct {
Node BinaryNode
Node nodeRef
Index int
}
type binaryNodeIterator struct {
trie *BinaryTrie
current BinaryNode
store *nodeStore
current nodeRef
lastErr error
stack []binaryNodeIteratorState
@ -40,56 +41,63 @@ type binaryNodeIterator struct {
func newBinaryNodeIterator(t *BinaryTrie, _ []byte) (trie.NodeIterator, error) {
if t.Hash() == zero {
return &binaryNodeIterator{trie: t, lastErr: errIteratorEnd}, nil
return &binaryNodeIterator{trie: t, store: t.store, lastErr: errIteratorEnd}, nil
}
it := &binaryNodeIterator{trie: t, current: t.root}
// it.err = it.seek(start)
it := &binaryNodeIterator{trie: t, store: t.store, current: t.store.root}
return it, nil
}
// Next moves the iterator to the next node. If the parameter is false, any child
// nodes will be skipped.
// Next moves the iterator to the next node. If descend is false, children of
// the current node are skipped.
func (it *binaryNodeIterator) Next(descend bool) bool {
if it.lastErr == errIteratorEnd {
it.lastErr = errIteratorEnd
return false
}
if len(it.stack) == 0 {
it.stack = append(it.stack, binaryNodeIteratorState{Node: it.trie.root})
it.current = it.trie.root
it.stack = append(it.stack, binaryNodeIteratorState{Node: it.trie.store.root})
it.current = it.trie.store.root
return true
}
switch node := it.current.(type) {
case *InternalNode:
// index: 0 = nothing visited, 1=left visited, 2=right visited
switch it.current.Kind() {
case kindInternal:
// index: 0 = nothing visited, 1 = left visited, 2 = right visited.
node := it.store.getInternal(it.current.Index())
context := &it.stack[len(it.stack)-1]
// recurse into both children
if !descend {
// Skip children: pop this node and advance parent.
if len(it.stack) == 1 {
it.lastErr = errIteratorEnd
return false
}
it.stack = it.stack[:len(it.stack)-1]
it.current = it.stack[len(it.stack)-1].Node
it.stack[len(it.stack)-1].Index++
return it.Next(true)
}
// Recurse into both children.
if context.Index == 0 {
if _, isempty := node.left.(Empty); node.left != nil && !isempty {
if !node.left.IsEmpty() {
it.stack = append(it.stack, binaryNodeIteratorState{Node: node.left})
it.current = node.left
return it.Next(descend)
}
context.Index++
}
if context.Index == 1 {
if _, isempty := node.right.(Empty); node.right != nil && !isempty {
if !node.right.IsEmpty() {
it.stack = append(it.stack, binaryNodeIteratorState{Node: node.right})
it.current = node.right
return it.Next(descend)
}
context.Index++
}
// Reached the end of this node, go back to the parent, if
// this isn't root.
// Reached the end of this node; go back to the parent unless we're at the root.
if len(it.stack) == 1 {
it.lastErr = errIteratorEnd
return false
@ -98,17 +106,18 @@ func (it *binaryNodeIterator) Next(descend bool) bool {
it.current = it.stack[len(it.stack)-1].Node
it.stack[len(it.stack)-1].Index++
return it.Next(descend)
case *StemNode:
// Look for the next non-empty value
case kindStem:
// Look for the next non-empty value in this stem.
sn := it.store.getStem(it.current.Index())
for i := it.stack[len(it.stack)-1].Index; i < 256; i++ {
if node.Values[i] != nil {
if sn.hasValue(byte(i)) {
it.stack[len(it.stack)-1].Index = i + 1
return true
}
}
// go back to parent to get the next leaf
// Check if we're at the root before popping
// No more values in this stem; go back to parent to get the next leaf.
if len(it.stack) == 1 {
it.lastErr = errIteratorEnd
return false
@ -117,51 +126,47 @@ func (it *binaryNodeIterator) Next(descend bool) bool {
it.current = it.stack[len(it.stack)-1].Node
it.stack[len(it.stack)-1].Index++
return it.Next(descend)
case HashedNode:
// resolve the node
resolverPath := it.Path()
data, err := it.trie.nodeResolver(resolverPath, common.Hash(node))
if err != nil {
panic(err)
}
if data == nil {
// Empty/nil node — treat as Empty, backtrack
it.current = Empty{}
it.stack[len(it.stack)-1].Node = it.current
return it.Next(descend)
}
it.current, err = DeserializeNodeWithHash(data, len(it.stack)-1, common.Hash(node))
if err != nil {
panic(err)
}
// update the stack and parent with the resolved node
it.stack[len(it.stack)-1].Node = it.current
if len(it.stack) >= 2 {
parent := &it.stack[len(it.stack)-2]
if parent.Index == 0 {
parent.Node.(*InternalNode).left = it.current
} else {
parent.Node.(*InternalNode).right = it.current
}
}
return it.Next(descend)
case Empty:
// Empty node - go back to parent and continue
if len(it.stack) <= 1 {
it.lastErr = errIteratorEnd
case kindHashed:
// Resolve the hashed node from disk, then rewire the parent to point at the
// resolved node in place.
if len(it.stack) < 2 {
it.lastErr = errors.New("cannot resolve hashed root during iteration")
return false
}
it.stack = it.stack[:len(it.stack)-1]
it.current = it.stack[len(it.stack)-1].Node
it.stack[len(it.stack)-1].Index++
hn := it.store.getHashed(it.current.Index())
data, err := it.trie.nodeResolver(it.Path(), hn.Hash())
if err != nil {
it.lastErr = err
return false
}
resolved, err := it.store.deserializeNodeWithHash(data, len(it.stack)-1, hn.Hash())
if err != nil {
it.lastErr = err
return false
}
oldHashedIdx := it.current.Index()
it.current = resolved
it.stack[len(it.stack)-1].Node = resolved
parent := &it.stack[len(it.stack)-2]
parentNode := it.store.getInternal(parent.Node.Index())
if parent.Index == 0 {
parentNode.left = resolved
} else {
parentNode.right = resolved
}
it.store.freeHashedNode(oldHashedIdx)
return it.Next(descend)
case kindEmpty:
return false
default:
panic("invalid node type")
}
}
// Error returns the error status of the iterator.
func (it *binaryNodeIterator) Error() error {
if it.lastErr == errIteratorEnd {
return nil
@ -169,27 +174,28 @@ func (it *binaryNodeIterator) Error() error {
return it.lastErr
}
// Hash returns the hash of the current node.
func (it *binaryNodeIterator) Hash() common.Hash {
return it.current.Hash()
return it.store.computeHash(it.current)
}
// Parent returns the hash of the parent of the current node. The hash may be the one
// grandparent if the immediate parent is an internal node with no hash.
// Parent returns the hash of the current node's parent. When the immediate
// parent is an internal node whose hash has not been materialised, the
// returned hash may be the one of a grandparent instead.
func (it *binaryNodeIterator) Parent() common.Hash {
return it.stack[len(it.stack)-1].Node.Hash()
if len(it.stack) < 2 {
return common.Hash{}
}
return it.store.computeHash(it.stack[len(it.stack)-2].Node)
}
// Path returns the hex-encoded path to the current node.
// Callers must not retain references to the return value after calling Next.
// For leaf nodes, the last element of the path is the 'terminator symbol' 0x10.
// Path returns the bit-path to the current node.
// Callers must not retain references to the returned slice after calling Next.
func (it *binaryNodeIterator) Path() []byte {
if it.Leaf() {
return it.LeafKey()
}
var path []byte
for i, state := range it.stack {
// skip the last byte
if i >= len(it.stack)-1 {
break
}
@ -198,107 +204,94 @@ func (it *binaryNodeIterator) Path() []byte {
return path
}
// NodeBlob returns the serialized bytes of the current node.
func (it *binaryNodeIterator) NodeBlob() []byte {
return SerializeNode(it.current)
return it.store.serializeNode(it.current)
}
// Leaf returns true iff the current node is a leaf node.
// In a Binary Trie, a StemNode contains up to 256 leaf values.
// The iterator is only considered to be "at a leaf" when it's positioned
// at a specific non-nil value within the StemNode, not just at the StemNode itself.
// Leaf reports whether the iterator is currently positioned at a leaf value.
// A StemNode holds up to 256 values; the iterator is only "at a leaf" when
// positioned at a specific non-nil value inside the stem, not merely at the
// StemNode itself. The stack Index points to the NEXT position after the
// current value, so Index == 0 means we haven't yielded anything yet.
func (it *binaryNodeIterator) Leaf() bool {
sn, ok := it.current.(*StemNode)
if !ok {
if it.current.Kind() != kindStem {
return false
}
// Check if we have a valid stack position
if len(it.stack) == 0 {
return false
}
// The Index in the stack state points to the NEXT position after the current value.
// So if Index is 0, we haven't started iterating through the values yet.
// If Index is 5, we're currently at value[4] (the 5th value, 0-indexed).
idx := it.stack[len(it.stack)-1].Index
if idx == 0 || idx > 256 {
return false
}
// Check if there's actually a value at the current position
sn := it.store.getStem(it.current.Index())
currentValueIndex := idx - 1
return sn.Values[currentValueIndex] != nil
return sn.hasValue(byte(currentValueIndex))
}
// LeafKey returns the key of the leaf. The method panics if the iterator is not
// positioned at a leaf. Callers must not retain references to the value after
// calling Next.
// LeafKey returns the key of the leaf. Panics if the iterator is not
// positioned at a leaf. Callers must not retain references to the returned
// slice after calling Next.
func (it *binaryNodeIterator) LeafKey() []byte {
leaf, ok := it.current.(*StemNode)
if !ok {
if it.current.Kind() != kindStem {
panic("Leaf() called on an binary node iterator not at a leaf location")
}
return leaf.Key(it.stack[len(it.stack)-1].Index - 1)
sn := it.store.getStem(it.current.Index())
return sn.Key(it.stack[len(it.stack)-1].Index - 1)
}
// LeafBlob returns the content of the leaf. The method panics if the iterator
// is not positioned at a leaf. Callers must not retain references to the value
// after calling Next.
// LeafBlob returns the leaf value. Panics if the iterator is not positioned
// at a leaf. Callers must not retain references to the returned slice after
// calling Next.
func (it *binaryNodeIterator) LeafBlob() []byte {
leaf, ok := it.current.(*StemNode)
if !ok {
if it.current.Kind() != kindStem {
panic("LeafBlob() called on an binary node iterator not at a leaf location")
}
return leaf.Values[it.stack[len(it.stack)-1].Index-1]
sn := it.store.getStem(it.current.Index())
return sn.getValue(byte(it.stack[len(it.stack)-1].Index - 1))
}
// LeafProof returns the Merkle proof of the leaf. The method panics if the
// iterator is not positioned at a leaf. Callers must not retain references
// to the value after calling Next.
// LeafProof returns the Merkle proof of the leaf. Panics if the iterator is
// not positioned at a leaf. Callers must not retain references to the
// returned slices after calling Next.
func (it *binaryNodeIterator) LeafProof() [][]byte {
sn, ok := it.current.(*StemNode)
if !ok {
if it.current.Kind() != kindStem {
panic("LeafProof() called on an binary node iterator not at a leaf location")
}
sn := it.store.getStem(it.current.Index())
proof := make([][]byte, 0, len(it.stack)+StemNodeWidth)
// Build proof by walking up the stack and collecting sibling hashes
if len(it.stack) < 2 {
proof = append(proof, sn.Stem[:])
proof = append(proof, sn.allValues()...)
return proof
}
for i := range it.stack[:len(it.stack)-2] {
state := it.stack[i]
internalNode := state.Node.(*InternalNode) // should panic if the node isn't an InternalNode
internalNode := it.store.getInternal(state.Node.Index())
// Add the sibling hash to the proof
if state.Index == 0 {
// We came from left, so include right sibling
proof = append(proof, internalNode.right.Hash().Bytes())
rh := it.store.computeHash(internalNode.right)
proof = append(proof, rh.Bytes())
} else {
// We came from right, so include left sibling
proof = append(proof, internalNode.left.Hash().Bytes())
lh := it.store.computeHash(internalNode.left)
proof = append(proof, lh.Bytes())
}
}
// Add the stem and siblings
proof = append(proof, sn.Stem)
for _, v := range sn.Values {
proof = append(proof, v)
}
proof = append(proof, sn.Stem[:])
proof = append(proof, sn.allValues()...)
return proof
}
// AddResolver sets an intermediate database to use for looking up trie nodes
// before reaching into the real persistent layer.
//
// This is not required for normal operation, rather is an optimization for
// cases where trie nodes can be recovered from some external mechanism without
// reading from disk. In those cases, this resolver allows short circuiting
// accesses and returning them from memory.
//
// Before adding a similar mechanism to any other place in Geth, consider
// making trie.Database an interface and wrapping at that level. It's a huge
// refactor, but it could be worth it if another occurrence arises.
// AddResolver is a no-op (satisfies the NodeIterator interface).
func (it *binaryNodeIterator) AddResolver(trie.NodeResolver) {
// Not implemented, but should not panic
}

View file

@ -27,14 +27,13 @@ import (
// makeTrie creates a BinaryTrie populated with the given key-value pairs.
func makeTrie(t *testing.T, entries [][2]common.Hash) *BinaryTrie {
t.Helper()
store := newNodeStore()
tr := &BinaryTrie{
root: NewBinaryNode(),
store: store,
tracer: trie.NewPrevalueTracer(),
}
for _, kv := range entries {
var err error
tr.root, err = tr.root.Insert(kv[0][:], kv[1][:], nil, 0)
if err != nil {
if err := store.Insert(kv[0][:], kv[1][:], nil); err != nil {
t.Fatal(err)
}
}
@ -64,7 +63,7 @@ func countLeaves(t *testing.T, tr *BinaryTrie) int {
// no nodes and reports no error.
func TestIteratorEmptyTrie(t *testing.T) {
tr := &BinaryTrie{
root: Empty{},
store: newNodeStore(),
tracer: trie.NewPrevalueTracer(),
}
it, err := newBinaryNodeIterator(tr, nil)
@ -145,8 +144,8 @@ func TestIteratorEmptyNodeBacktrack(t *testing.T) {
{common.HexToHash("8000000000000000000000000000000000000000000000000000000000000001"), oneKey},
})
if _, ok := tr.root.(*InternalNode); !ok {
t.Fatalf("expected InternalNode root, got %T", tr.root)
if tr.store.root.Kind() != kindInternal {
t.Fatalf("expected InternalNode root, got kind %d", tr.store.root.Kind())
}
if leaves := countLeaves(t, tr); leaves != 2 {
t.Fatalf("expected 2 leaves, got %d (Empty backtrack bug?)", leaves)
@ -162,18 +161,31 @@ func TestIteratorHashedNodeNilData(t *testing.T) {
{common.HexToHash("8000000000000000000000000000000000000000000000000000000000000001"), oneKey},
})
root, ok := tr.root.(*InternalNode)
if !ok {
t.Fatalf("expected InternalNode root, got %T", tr.root)
root := tr.store.root
if root.Kind() != kindInternal {
t.Fatalf("expected InternalNode root, got kind %d", root.Kind())
}
rootNode := tr.store.getInternal(root.Index())
// Replace right child with a zero-hash HashedNode. nodeResolver
// short-circuits on common.Hash{} and returns (nil, nil), which
// triggers the nil-data guard in the iterator.
root.right = HashedNode(common.Hash{})
rootNode.right = tr.store.newHashedRef(common.Hash{})
// Should not panic; the zero-hash right child should be treated as Empty.
if leaves := countLeaves(t, tr); leaves != 1 {
// Since the hashed node can't be resolved (nil data -> empty deserialization),
// only the left leaf should be counted.
it, err := newBinaryNodeIterator(tr, nil)
if err != nil {
t.Fatal(err)
}
leaves := 0
for it.Next(true) {
if it.Leaf() {
leaves++
}
}
if leaves != 1 {
t.Fatalf("expected 1 leaf (zero-hash right node skipped), got %d", leaves)
}
}

56
trie/bintrie/node_ref.go Normal file
View file

@ -0,0 +1,56 @@
// Copyright 2026 go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package bintrie
// nodeKind identifies the type of a trie node stored in a nodeRef.
type nodeKind uint8
const (
kindEmpty nodeKind = iota
kindInternal
kindStem // up to 256 values per stem
kindHashed
)
// nodeRef is a compact, GC-invisible reference to a node in a nodeStore.
// It packs a 2-bit type tag (bits 31-30) and a 30-bit index (bits 29-0)
// into a single uint32. Because nodeRef contains no Go pointers, slices
// of structs containing nodeRef fields are allocated in noscan spans —
// the garbage collector never examines them.
type nodeRef uint32
const (
kindShift uint32 = 30
indexMask uint32 = (1 << kindShift) - 1
// emptyRef represents an empty node.
emptyRef nodeRef = 0
)
func makeRef(kind nodeKind, idx uint32) nodeRef {
if idx > indexMask {
panic("nodeRef index overflow")
}
return nodeRef(uint32(kind)<<kindShift | idx)
}
func (r nodeRef) Kind() nodeKind { return nodeKind(uint32(r) >> kindShift) }
// Index within the typed pool.
func (r nodeRef) Index() uint32 { return uint32(r) & indexMask }
func (r nodeRef) IsEmpty() bool { return r.Kind() == kindEmpty }

184
trie/bintrie/node_store.go Normal file
View file

@ -0,0 +1,184 @@
// Copyright 2026 go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package bintrie
import "github.com/ethereum/go-ethereum/common"
// storeChunkSize is the number of nodes per chunk in each typed pool.
const storeChunkSize = 4096
// nodeStore is a GC-friendly arena for binary trie nodes. Nodes are packed
// into typed chunked pools so pointer-free types (InternalNode, HashedNode)
// land in noscan spans the GC skips entirely.
type nodeStore struct {
internalChunks []*[storeChunkSize]InternalNode
internalCount uint32
stemChunks []*[storeChunkSize]StemNode
stemCount uint32
hashedChunks []*[storeChunkSize]HashedNode
hashedCount uint32
root nodeRef
// Free list for recycling hashed-node slots after resolve. Internal and
// stem nodes are never freed under current semantics (no delete path,
// stem-split keeps the old stem at a deeper position), so they don't
// have free lists.
freeHashed []uint32
}
func newNodeStore() *nodeStore {
return &nodeStore{root: emptyRef}
}
func (s *nodeStore) allocInternal() uint32 {
idx := s.internalCount
chunkIdx := idx / storeChunkSize
if uint32(len(s.internalChunks)) <= chunkIdx {
s.internalChunks = append(s.internalChunks, new([storeChunkSize]InternalNode))
}
s.internalCount++
if s.internalCount > indexMask {
panic("internal node pool overflow")
}
return idx
}
func (s *nodeStore) getInternal(idx uint32) *InternalNode {
return &s.internalChunks[idx/storeChunkSize][idx%storeChunkSize]
}
func (s *nodeStore) newInternalRef(depth int) nodeRef {
if depth > 248 {
panic("node depth exceeds maximum binary trie depth")
}
idx := s.allocInternal()
n := s.getInternal(idx)
n.depth = uint8(depth)
n.mustRecompute = true
n.dirty = true
return makeRef(kindInternal, idx)
}
func (s *nodeStore) allocStem() uint32 {
idx := s.stemCount
chunkIdx := idx / storeChunkSize
if uint32(len(s.stemChunks)) <= chunkIdx {
s.stemChunks = append(s.stemChunks, new([storeChunkSize]StemNode))
}
s.stemCount++
if s.stemCount > indexMask {
panic("stem node pool overflow")
}
return idx
}
func (s *nodeStore) getStem(idx uint32) *StemNode {
return &s.stemChunks[idx/storeChunkSize][idx%storeChunkSize]
}
func (s *nodeStore) newStemRef(stem []byte, depth int) nodeRef {
if depth > 248 {
panic("node depth exceeds maximum binary trie depth")
}
idx := s.allocStem()
sn := s.getStem(idx)
copy(sn.Stem[:], stem[:StemSize])
sn.depth = uint8(depth)
sn.mustRecompute = true
sn.dirty = true
return makeRef(kindStem, idx)
}
func (s *nodeStore) allocHashed() uint32 {
if n := len(s.freeHashed); n > 0 {
idx := s.freeHashed[n-1]
s.freeHashed = s.freeHashed[:n-1]
*s.getHashed(idx) = HashedNode{}
return idx
}
idx := s.hashedCount
chunkIdx := idx / storeChunkSize
if uint32(len(s.hashedChunks)) <= chunkIdx {
s.hashedChunks = append(s.hashedChunks, new([storeChunkSize]HashedNode))
}
s.hashedCount++
if s.hashedCount > indexMask {
panic("hashed node pool overflow")
}
return idx
}
func (s *nodeStore) getHashed(idx uint32) *HashedNode {
return &s.hashedChunks[idx/storeChunkSize][idx%storeChunkSize]
}
func (s *nodeStore) freeHashedNode(idx uint32) {
s.freeHashed = append(s.freeHashed, idx)
}
func (s *nodeStore) newHashedRef(hash common.Hash) nodeRef {
idx := s.allocHashed()
*s.getHashed(idx) = HashedNode(hash)
return makeRef(kindHashed, idx)
}
func (s *nodeStore) Copy() *nodeStore {
ns := &nodeStore{
root: s.root,
internalCount: s.internalCount,
stemCount: s.stemCount,
hashedCount: s.hashedCount,
}
ns.internalChunks = make([]*[storeChunkSize]InternalNode, len(s.internalChunks))
for i, chunk := range s.internalChunks {
cp := *chunk
ns.internalChunks[i] = &cp
}
ns.stemChunks = make([]*[storeChunkSize]StemNode, len(s.stemChunks))
for i, chunk := range s.stemChunks {
cp := *chunk
ns.stemChunks[i] = &cp
}
// Deep-copy each stem's value slots — they may alias serialized buffers,
// so we can't rely on the chunk-wise struct copy above.
for i := uint32(0); i < s.stemCount; i++ {
src := s.getStem(i)
dst := ns.getStem(i)
for j, v := range src.values {
if v == nil {
continue
}
cp := make([]byte, len(v))
copy(cp, v)
dst.values[j] = cp
}
}
ns.hashedChunks = make([]*[storeChunkSize]HashedNode, len(s.hashedChunks))
for i, chunk := range s.hashedChunks {
cp := *chunk
ns.hashedChunks[i] = &cp
}
if len(s.freeHashed) > 0 {
ns.freeHashed = make([]uint32, len(s.freeHashed))
copy(ns.freeHashed, s.freeHashed)
}
return ns
}

View file

@ -17,236 +17,93 @@
package bintrie
import (
"bytes"
"errors"
"fmt"
"slices"
"crypto/sha256"
"github.com/ethereum/go-ethereum/common"
)
// StemNode represents a group of `NodeWith` values sharing the same stem.
// StemNode holds up to 256 values sharing a 31-byte stem.
//
// Invariant: dirty=false implies mustRecompute=false. Every mutation that
// invalidates the cached hash MUST also mark the blob for re-flush.
type StemNode struct {
Stem []byte // Stem path to get to StemNodeWidth values
Values [][]byte // All values, indexed by the last byte of the key.
depth int // Depth of the node
Stem [StemSize]byte
values [StemNodeWidth][]byte // nil == slot absent
mustRecompute bool // true if the hash needs to be recomputed
dirty bool // true if the node's on-disk blob is stale (needs flush)
depth uint8
mustRecompute bool // hash is stale (cleared by Hash)
dirty bool // on-disk blob is stale (cleared by CollectNodes)
hash common.Hash // cached hash when mustRecompute == false
}
// Get retrieves the value for the given key.
func (bt *StemNode) Get(key []byte, _ NodeResolverFn) ([]byte, error) {
if !bytes.Equal(bt.Stem, key[:StemSize]) {
return nil, nil
}
return bt.Values[key[StemSize]], nil
func (sn *StemNode) getValue(suffix byte) []byte {
return sn.values[suffix]
}
// Insert inserts a new key-value pair into the node.
func (bt *StemNode) Insert(key []byte, value []byte, _ NodeResolverFn, depth int) (BinaryNode, error) {
if !bytes.Equal(bt.Stem, key[:StemSize]) {
bitStem := bt.Stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1
n := &InternalNode{depth: bt.depth, mustRecompute: true, dirty: true}
bt.depth++
// bt is re-parented under n and sits at a new path — rewrite its blob.
bt.mustRecompute = true
bt.dirty = true
var child, other *BinaryNode
if bitStem == 0 {
n.left = bt
child = &n.left
other = &n.right
} else {
n.right = bt
child = &n.right
other = &n.left
}
bitKey := key[n.depth/8] >> (7 - (n.depth % 8)) & 1
if bitKey == bitStem {
var err error
*child, err = (*child).Insert(key, value, nil, depth+1)
if err != nil {
return n, fmt.Errorf("insert error: %w", err)
}
*other = Empty{}
} else {
var values [StemNodeWidth][]byte
values[key[StemSize]] = value
*other = &StemNode{
Stem: slices.Clone(key[:StemSize]),
Values: values[:],
depth: depth + 1,
mustRecompute: true,
dirty: true,
}
}
return n, nil
}
if len(value) != HashSize {
return bt, errors.New("invalid insertion: value length")
}
bt.Values[key[StemSize]] = value
bt.mustRecompute = true
bt.dirty = true
return bt, nil
func (sn *StemNode) hasValue(suffix byte) bool {
return sn.values[suffix] != nil
}
// Copy creates a deep copy of the node.
func (bt *StemNode) Copy() BinaryNode {
var values [StemNodeWidth][]byte
for i, v := range bt.Values {
values[i] = slices.Clone(v)
}
return &StemNode{
Stem: slices.Clone(bt.Stem),
Values: values[:],
depth: bt.depth,
hash: bt.hash,
mustRecompute: bt.mustRecompute,
dirty: bt.dirty,
}
// allValues returns the underlying slot array as a slice. nil entries mean
// absent. Callers must treat it as read-only.
func (sn *StemNode) allValues() [][]byte {
return sn.values[:]
}
// GetHeight returns the height of the node.
func (bt *StemNode) GetHeight() int {
return 1
// setValue mutates a value slot and marks the stem for re-hash and
// re-flush. This is the only API for post-load value mutation; direct
// values[...] writes are reserved for the on-disk load path in
// decodeNode, which must leave mustRecompute/dirty at their loaded
// state.
func (sn *StemNode) setValue(suffix byte, value []byte) {
sn.values[suffix] = value
sn.mustRecompute = true
sn.dirty = true
}
// Hash returns the hash of the node.
func (bt *StemNode) Hash() common.Hash {
if !bt.mustRecompute {
return bt.hash
func (sn *StemNode) Hash() common.Hash {
if !sn.mustRecompute {
return sn.hash
}
// Use sha256.Sum256 (returns [32]byte by value) instead of a pooled
// hash.Hash: feeding data[i][:0] into the interface method Sum forces
// data to heap (escape analysis is conservative through interfaces).
// Sum256 takes []byte and returns by value, so data stays on stack.
var data [StemNodeWidth]common.Hash
h := newSha256()
defer returnSha256(h)
for i, v := range bt.Values {
for i, v := range sn.values {
if v != nil {
h.Reset()
h.Write(v)
h.Sum(data[i][:0])
data[i] = sha256.Sum256(v)
}
}
h.Reset()
var pair [2 * HashSize]byte
for level := 1; level <= 8; level++ {
for i := range StemNodeWidth / (1 << level) {
h.Reset()
if data[i*2] == (common.Hash{}) && data[i*2+1] == (common.Hash{}) {
data[i] = common.Hash{}
continue
}
h.Write(data[i*2][:])
h.Write(data[i*2+1][:])
data[i] = common.Hash(h.Sum(nil))
copy(pair[:HashSize], data[i*2][:])
copy(pair[HashSize:], data[i*2+1][:])
data[i] = sha256.Sum256(pair[:])
}
}
h.Reset()
h.Write(bt.Stem)
h.Write([]byte{0})
h.Write(data[0][:])
bt.hash = common.BytesToHash(h.Sum(nil))
bt.mustRecompute = false
return bt.hash
var final [StemSize + 1 + HashSize]byte
copy(final[:StemSize], sn.Stem[:])
final[StemSize] = 0
copy(final[StemSize+1:], data[0][:])
sn.hash = sha256.Sum256(final[:])
sn.mustRecompute = false
return sn.hash
}
// CollectNodes flushes the stem via the collector when dirty; clean stems
// are skipped.
func (bt *StemNode) CollectNodes(path []byte, flush NodeFlushFn) error {
if !bt.dirty {
return nil
}
flush(path, bt)
bt.dirty = false
return nil
}
// GetValuesAtStem retrieves the group of values located at the given stem key.
func (bt *StemNode) GetValuesAtStem(stem []byte, _ NodeResolverFn) ([][]byte, error) {
if !bytes.Equal(bt.Stem, stem) {
return nil, nil
}
return bt.Values[:], nil
}
// InsertValuesAtStem inserts a full value group at the given stem in the internal node.
// Already-existing values will be overwritten.
func (bt *StemNode) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolverFn, depth int) (BinaryNode, error) {
if !bytes.Equal(bt.Stem, key[:StemSize]) {
bitStem := bt.Stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1
n := &InternalNode{depth: bt.depth, mustRecompute: true, dirty: true}
bt.depth++
// bt is re-parented under n and sits at a new path — rewrite its blob.
bt.mustRecompute = true
bt.dirty = true
var child, other *BinaryNode
if bitStem == 0 {
n.left = bt
child = &n.left
other = &n.right
} else {
n.right = bt
child = &n.right
other = &n.left
}
bitKey := key[n.depth/8] >> (7 - (n.depth % 8)) & 1
if bitKey == bitStem {
var err error
*child, err = (*child).InsertValuesAtStem(key, values, nil, depth+1)
if err != nil {
return n, fmt.Errorf("insert error: %w", err)
}
*other = Empty{}
} else {
*other = &StemNode{
Stem: slices.Clone(key[:StemSize]),
Values: values,
depth: n.depth + 1,
mustRecompute: true,
dirty: true,
}
}
return n, nil
}
// same stem, just merge the two value lists
for i, v := range values {
if v != nil {
bt.Values[i] = v
bt.mustRecompute = true
bt.dirty = true
}
}
return bt, nil
}
func (bt *StemNode) toDot(parent, path string) string {
me := fmt.Sprintf("stem%s", path)
ret := fmt.Sprintf("%s [label=\"stem=%x c=%x\"]\n", me, bt.Stem, bt.Hash())
ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me)
for i, v := range bt.Values {
if v != nil {
ret = fmt.Sprintf("%s%s%x [label=\"%x\"]\n", ret, me, i, v)
ret = fmt.Sprintf("%s%s -> %s%x\n", ret, me, me, i)
}
}
return ret
}
// Key returns the full key for the given index.
func (bt *StemNode) Key(i int) []byte {
func (sn *StemNode) Key(i int) []byte {
var ret [HashSize]byte
copy(ret[:], bt.Stem)
copy(ret[:], sn.Stem[:])
ret[StemSize] = byte(i)
return ret[:]
}

View file

@ -23,165 +23,99 @@ import (
"github.com/ethereum/go-ethereum/common"
)
// TestStemNodeGet tests the Get method for matching stem, non-matching stem,
// and nil-value suffix scenarios.
func TestStemNodeGet(t *testing.T) {
stem := make([]byte, StemSize)
stem[0] = 0xAB
var values [StemNodeWidth][]byte
values[5] = common.HexToHash("0xdeadbeef").Bytes()
node := &StemNode{Stem: stem, Values: values[:], depth: 0}
// Matching stem, populated suffix → returns value.
key := make([]byte, HashSize)
copy(key[:StemSize], stem)
key[StemSize] = 5
got, err := node.Get(key, nil)
if err != nil {
t.Fatalf("Get error: %v", err)
}
if !bytes.Equal(got, values[5]) {
t.Fatalf("Get = %x, want %x", got, values[5])
}
// Matching stem, empty suffix → returns nil (slot not set).
key[StemSize] = 99
got, err = node.Get(key, nil)
if err != nil {
t.Fatalf("Get error: %v", err)
}
if got != nil {
t.Fatalf("Get(empty suffix) = %x, want nil", got)
}
// Non-matching stem → returns nil, nil.
otherKey := make([]byte, HashSize)
otherKey[0] = 0xFF
got, err = node.Get(otherKey, nil)
if err != nil {
t.Fatalf("Get error: %v", err)
}
if got != nil {
t.Fatalf("Get(wrong stem) = %x, want nil", got)
}
}
// TestStemNodeInsertSameStem tests inserting values with the same stem
// TestStemNodeInsertSameStem tests inserting values with the same stem via nodeStore.
func TestStemNodeInsertSameStem(t *testing.T) {
s := newNodeStore()
stem := make([]byte, 31)
for i := range stem {
stem[i] = byte(i)
}
var values [256][]byte
values[0] = common.HexToHash("0x0101").Bytes()
node := &StemNode{
Stem: stem,
Values: values[:],
depth: 0,
// Insert first value
key1 := make([]byte, 32)
copy(key1[:31], stem)
key1[31] = 0
value1 := common.HexToHash("0x0101").Bytes()
if err := s.Insert(key1, value1, nil); err != nil {
t.Fatal(err)
}
// Insert another value with the same stem but different last byte
key := make([]byte, 32)
copy(key[:31], stem)
key[31] = 10
value := common.HexToHash("0x0202").Bytes()
newNode, err := node.Insert(key, value, nil, 0)
if err != nil {
t.Fatalf("Failed to insert: %v", err)
key2 := make([]byte, 32)
copy(key2[:31], stem)
key2[31] = 10
value2 := common.HexToHash("0x0202").Bytes()
if err := s.Insert(key2, value2, nil); err != nil {
t.Fatal(err)
}
// Should still be a StemNode
stemNode, ok := newNode.(*StemNode)
if !ok {
t.Fatalf("Expected StemNode, got %T", newNode)
// Root should still be a StemNode
if s.root.Kind() != kindStem {
t.Fatalf("Expected kindStem root, got kind %d", s.root.Kind())
}
// Check that both values are present
if !bytes.Equal(stemNode.Values[0], values[0]) {
v1, _ := s.Get(key1, nil)
if !bytes.Equal(v1, value1) {
t.Errorf("Value at index 0 mismatch")
}
if !bytes.Equal(stemNode.Values[10], value) {
v2, _ := s.Get(key2, nil)
if !bytes.Equal(v2, value2) {
t.Errorf("Value at index 10 mismatch")
}
}
// TestStemNodeInsertDifferentStem tests inserting values with different stems
// TestStemNodeInsertDifferentStem tests inserting values with different stems via nodeStore.
func TestStemNodeInsertDifferentStem(t *testing.T) {
stem1 := make([]byte, 31)
for i := range stem1 {
stem1[i] = 0x00
}
s := newNodeStore()
var values [256][]byte
values[0] = common.HexToHash("0x0101").Bytes()
node := &StemNode{
Stem: stem1,
Values: values[:],
depth: 0,
// Insert first value with stem of all zeros
key1 := make([]byte, 32)
key1[31] = 0
value1 := common.HexToHash("0x0101").Bytes()
if err := s.Insert(key1, value1, nil); err != nil {
t.Fatal(err)
}
// Insert with a different stem (first bit different)
key := make([]byte, 32)
key[0] = 0x80 // First bit is 1 instead of 0
value := common.HexToHash("0x0202").Bytes()
newNode, err := node.Insert(key, value, nil, 0)
if err != nil {
t.Fatalf("Failed to insert: %v", err)
key2 := make([]byte, 32)
key2[0] = 0x80 // First bit is 1 instead of 0
value2 := common.HexToHash("0x0202").Bytes()
if err := s.Insert(key2, value2, nil); err != nil {
t.Fatal(err)
}
// Should now be an InternalNode
internalNode, ok := newNode.(*InternalNode)
if !ok {
t.Fatalf("Expected InternalNode, got %T", newNode)
if s.root.Kind() != kindInternal {
t.Fatalf("Expected kindInternal root, got kind %d", s.root.Kind())
}
// Check depth
if internalNode.depth != 0 {
t.Errorf("Expected depth 0, got %d", internalNode.depth)
rootNode := s.getInternal(s.root.Index())
if rootNode.depth != 0 {
t.Errorf("Expected depth 0, got %d", rootNode.depth)
}
// Original stem should be on the left (bit 0)
leftStem, ok := internalNode.left.(*StemNode)
if !ok {
t.Fatalf("Expected left child to be StemNode, got %T", internalNode.left)
// Verify both values are retrievable
v1, _ := s.Get(key1, nil)
if !bytes.Equal(v1, value1) {
t.Error("Value 1 mismatch")
}
if !bytes.Equal(leftStem.Stem, stem1) {
t.Errorf("Left stem mismatch")
}
// New stem should be on the right (bit 1)
rightStem, ok := internalNode.right.(*StemNode)
if !ok {
t.Fatalf("Expected right child to be StemNode, got %T", internalNode.right)
}
if !bytes.Equal(rightStem.Stem, key[:31]) {
t.Errorf("Right stem mismatch")
v2, _ := s.Get(key2, nil)
if !bytes.Equal(v2, value2) {
t.Error("Value 2 mismatch")
}
}
// TestStemNodeInsertInvalidValueLength tests inserting value with invalid length
// TestStemNodeInsertInvalidValueLength tests inserting value with invalid length via nodeStore.
func TestStemNodeInsertInvalidValueLength(t *testing.T) {
stem := make([]byte, 31)
var values [256][]byte
s := newNodeStore()
node := &StemNode{
Stem: stem,
Values: values[:],
depth: 0,
}
// Try to insert value with wrong length
key := make([]byte, 32)
copy(key[:31], stem)
invalidValue := []byte{1, 2, 3} // Not 32 bytes
_, err := node.Insert(key, invalidValue, nil, 0)
err := s.Insert(key, invalidValue, nil)
if err == nil {
t.Fatal("Expected error for invalid value length")
}
@ -191,221 +125,209 @@ func TestStemNodeInsertInvalidValueLength(t *testing.T) {
}
}
// TestStemNodeCopy tests the Copy method
// TestStemNodeCopy tests the Copy method via nodeStore.
func TestStemNodeCopy(t *testing.T) {
stem := make([]byte, 31)
for i := range stem {
stem[i] = byte(i)
s := newNodeStore()
key1 := make([]byte, 32)
for i := range 31 {
key1[i] = byte(i)
}
key1[31] = 0
value1 := common.HexToHash("0x0101").Bytes()
key2 := make([]byte, 32)
copy(key2[:31], key1[:31])
key2[31] = 255
value2 := common.HexToHash("0x0202").Bytes()
if err := s.Insert(key1, value1, nil); err != nil {
t.Fatal(err)
}
if err := s.Insert(key2, value2, nil); err != nil {
t.Fatal(err)
}
var values [256][]byte
values[0] = common.HexToHash("0x0101").Bytes()
values[255] = common.HexToHash("0x0202").Bytes()
ns := s.Copy()
node := &StemNode{
Stem: stem,
Values: values[:],
depth: 10,
}
// Create a copy
copied := node.Copy()
copiedStem, ok := copied.(*StemNode)
if !ok {
t.Fatalf("Expected StemNode, got %T", copied)
}
// Check that values are equal but not the same slice
if !bytes.Equal(copiedStem.Stem, node.Stem) {
t.Errorf("Stem mismatch after copy")
}
if &copiedStem.Stem[0] == &node.Stem[0] {
t.Error("Stem slice not properly cloned")
}
// Check values
if !bytes.Equal(copiedStem.Values[0], node.Values[0]) {
// Check that values are equal
v1, _ := ns.Get(key1, nil)
if !bytes.Equal(v1, value1) {
t.Errorf("Value at index 0 mismatch after copy")
}
if !bytes.Equal(copiedStem.Values[255], node.Values[255]) {
v2, _ := ns.Get(key2, nil)
if !bytes.Equal(v2, value2) {
t.Errorf("Value at index 255 mismatch after copy")
}
// Check that value slices are cloned
if copiedStem.Values[0] != nil && &copiedStem.Values[0][0] == &node.Values[0][0] {
t.Error("Value slice not properly cloned")
}
// Check depth
if copiedStem.depth != node.depth {
t.Errorf("Depth mismatch: expected %d, got %d", node.depth, copiedStem.depth)
}
}
// TestStemNodeHash tests the Hash method
// TestStemNodeHash tests the Hash method.
func TestStemNodeHash(t *testing.T) {
stem := make([]byte, 31)
var values [256][]byte
values[0] = common.HexToHash("0x0101").Bytes()
s := newNodeStore()
node := &StemNode{
Stem: stem,
Values: values[:],
depth: 0,
key := make([]byte, 32)
key[31] = 0
value := common.HexToHash("0x0101").Bytes()
if err := s.Insert(key, value, nil); err != nil {
t.Fatal(err)
}
hash1 := node.Hash()
hash1 := s.computeHash(s.root)
// Hash should be deterministic
hash2 := node.Hash()
hash2 := s.computeHash(s.root)
if hash1 != hash2 {
t.Errorf("Hash not deterministic: %x != %x", hash1, hash2)
}
// Changing a value should change the hash
node.Values[1] = common.HexToHash("0x0202").Bytes()
node.mustRecompute = true
hash3 := node.Hash()
key2 := make([]byte, 32)
key2[31] = 1
value2 := common.HexToHash("0x0202").Bytes()
if err := s.Insert(key2, value2, nil); err != nil {
t.Fatal(err)
}
hash3 := s.computeHash(s.root)
if hash1 == hash3 {
t.Error("Hash didn't change after modifying values")
}
}
// TestStemNodeGetValuesAtStem tests GetValuesAtStem method
// TestStemNodeGetValuesAtStem tests GetValuesAtStem method via nodeStore.
func TestStemNodeGetValuesAtStem(t *testing.T) {
s := newNodeStore()
stem := make([]byte, 31)
for i := range stem {
stem[i] = byte(i)
}
var values [256][]byte
values := make([][]byte, 256)
values[0] = common.HexToHash("0x0101").Bytes()
values[10] = common.HexToHash("0x0202").Bytes()
values[255] = common.HexToHash("0x0303").Bytes()
node := &StemNode{
Stem: stem,
Values: values[:],
depth: 0,
if err := s.InsertValuesAtStem(stem, values, nil); err != nil {
t.Fatal(err)
}
// GetValuesAtStem with matching stem
retrievedValues, err := node.GetValuesAtStem(stem, nil)
retrievedValues, err := s.GetValuesAtStem(stem, nil)
if err != nil {
t.Fatalf("Failed to get values: %v", err)
}
// Check that all values match
for i := range 256 {
if !bytes.Equal(retrievedValues[i], values[i]) {
t.Errorf("Value mismatch at index %d", i)
}
if !bytes.Equal(retrievedValues[0], values[0]) {
t.Error("Value at index 0 mismatch")
}
if !bytes.Equal(retrievedValues[10], values[10]) {
t.Error("Value at index 10 mismatch")
}
if !bytes.Equal(retrievedValues[255], values[255]) {
t.Error("Value at index 255 mismatch")
}
// GetValuesAtStem with different stem should return nil
// GetValuesAtStem with different stem should return nil values
differentStem := make([]byte, 31)
differentStem[0] = 0xFF
shouldBeNil, err := node.GetValuesAtStem(differentStem, nil)
shouldBeEmpty, err := s.GetValuesAtStem(differentStem, nil)
if err != nil {
t.Fatalf("Failed to get values with different stem: %v", err)
}
if shouldBeNil != nil {
t.Error("Expected nil for different stem, got non-nil")
allNil := true
for _, v := range shouldBeEmpty {
if v != nil {
allNil = false
break
}
}
if !allNil {
t.Error("Expected all nil values for different stem")
}
}
// TestStemNodeInsertValuesAtStem tests InsertValuesAtStem method
// TestStemNodeInsertValuesAtStem tests InsertValuesAtStem method via nodeStore.
func TestStemNodeInsertValuesAtStem(t *testing.T) {
s := newNodeStore()
stem := make([]byte, 31)
var values [256][]byte
values := make([][]byte, 256)
values[0] = common.HexToHash("0x0101").Bytes()
node := &StemNode{
Stem: stem,
Values: values[:],
depth: 0,
if err := s.InsertValuesAtStem(stem, values, nil); err != nil {
t.Fatal(err)
}
// Insert new values at the same stem
var newValues [256][]byte
newValues := make([][]byte, 256)
newValues[1] = common.HexToHash("0x0202").Bytes()
newValues[2] = common.HexToHash("0x0303").Bytes()
newNode, err := node.InsertValuesAtStem(stem, newValues[:], nil, 0)
if err != nil {
t.Fatalf("Failed to insert values: %v", err)
}
stemNode, ok := newNode.(*StemNode)
if !ok {
t.Fatalf("Expected StemNode, got %T", newNode)
if err := s.InsertValuesAtStem(stem, newValues, nil); err != nil {
t.Fatal(err)
}
// Check that all values are present
if !bytes.Equal(stemNode.Values[0], values[0]) {
retrieved, err := s.GetValuesAtStem(stem, nil)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(retrieved[0], values[0]) {
t.Error("Original value at index 0 missing")
}
if !bytes.Equal(stemNode.Values[1], newValues[1]) {
if !bytes.Equal(retrieved[1], newValues[1]) {
t.Error("New value at index 1 missing")
}
if !bytes.Equal(stemNode.Values[2], newValues[2]) {
if !bytes.Equal(retrieved[2], newValues[2]) {
t.Error("New value at index 2 missing")
}
}
// TestStemNodeGetHeight tests GetHeight method
// TestStemNodeGetHeight tests GetHeight method via nodeStore.
func TestStemNodeGetHeight(t *testing.T) {
node := &StemNode{
Stem: make([]byte, 31),
Values: make([][]byte, 256),
depth: 0,
s := newNodeStore()
key := make([]byte, 32)
value := common.HexToHash("0x01").Bytes()
if err := s.Insert(key, value, nil); err != nil {
t.Fatal(err)
}
height := node.GetHeight()
height := s.getHeight(s.root)
if height != 1 {
t.Errorf("Expected height 1, got %d", height)
}
}
// TestStemNodeCollectNodes tests CollectNodes method
// TestStemNodeCollectNodes tests CollectNodes method via nodeStore.
func TestStemNodeCollectNodes(t *testing.T) {
s := newNodeStore()
stem := make([]byte, 31)
var values [256][]byte
values := make([][]byte, 256)
values[0] = common.HexToHash("0x0101").Bytes()
node := &StemNode{
Stem: stem,
Values: values[:],
depth: 0,
dirty: true,
if err := s.InsertValuesAtStem(stem, values, nil); err != nil {
t.Fatal(err)
}
var collectedPaths [][]byte
var collectedNodes []BinaryNode
flushFn := func(path []byte, n BinaryNode) {
// Make a copy of the path
flushFn := func(path []byte, hash common.Hash, serialized []byte) {
pathCopy := make([]byte, len(path))
copy(pathCopy, path)
collectedPaths = append(collectedPaths, pathCopy)
collectedNodes = append(collectedNodes, n)
}
err := node.CollectNodes([]byte{0, 1, 0}, flushFn)
err := s.collectNodes(s.root, []byte{0, 1, 0}, flushFn)
if err != nil {
t.Fatalf("Failed to collect nodes: %v", err)
}
// Should have collected one node (itself)
if len(collectedNodes) != 1 {
t.Errorf("Expected 1 collected node, got %d", len(collectedNodes))
}
// Check that the collected node is the same
if collectedNodes[0] != node {
t.Error("Collected node doesn't match original")
if len(collectedPaths) != 1 {
t.Errorf("Expected 1 collected node, got %d", len(collectedPaths))
}
// Check the path
@ -413,44 +335,3 @@ func TestStemNodeCollectNodes(t *testing.T) {
t.Errorf("Path mismatch: expected [0, 1, 0], got %v", collectedPaths[0])
}
}
// TestStemNodeCollectNodesSkipsClean verifies that a clean stem is not
// flushed, and that flushing a dirty stem clears its dirty flag so that
// a subsequent CollectNodes on the same node is a no-op.
func TestStemNodeCollectNodesSkipsClean(t *testing.T) {
stem := make([]byte, 31)
node := &StemNode{
Stem: stem,
Values: make([][]byte, 256),
depth: 0,
}
var collected []BinaryNode
flushFn := func(_ []byte, n BinaryNode) { collected = append(collected, n) }
if err := node.CollectNodes([]byte{0}, flushFn); err != nil {
t.Fatalf("CollectNodes on clean stem: %v", err)
}
if len(collected) != 0 {
t.Fatalf("expected clean stem not to be flushed, got %d", len(collected))
}
node.dirty = true
if err := node.CollectNodes([]byte{0}, flushFn); err != nil {
t.Fatalf("CollectNodes on dirty stem: %v", err)
}
if len(collected) != 1 {
t.Fatalf("expected dirty stem to be flushed once, got %d", len(collected))
}
if node.dirty {
t.Errorf("stem dirty flag should be cleared after flush")
}
collected = nil
if err := node.CollectNodes([]byte{0}, flushFn); err != nil {
t.Fatalf("CollectNodes after flush: %v", err)
}
if len(collected) != 0 {
t.Errorf("expected no flush on clean stem, got %d", len(collected))
}
}

View file

@ -0,0 +1,310 @@
// Copyright 2026 go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package bintrie
import (
"crypto/sha256"
"errors"
"fmt"
"math/bits"
"runtime"
"sync"
"github.com/ethereum/go-ethereum/common"
)
type nodeFlushFn func(path []byte, hash common.Hash, serialized []byte)
func (s *nodeStore) Hash() common.Hash {
return s.computeHash(s.root)
}
func (s *nodeStore) computeHash(ref nodeRef) common.Hash {
switch ref.Kind() {
case kindInternal:
return s.hashInternal(ref.Index())
case kindStem:
return s.getStem(ref.Index()).Hash()
case kindHashed:
return s.getHashed(ref.Index()).Hash()
case kindEmpty:
return common.Hash{}
default:
return common.Hash{}
}
}
// parallelHashDepth is the tree depth below which hashInternal spawns
// goroutines for shallow-depth parallelism. Computed once at init because
// NumCPU() never changes after startup.
var parallelHashDepth = min(bits.Len(uint(runtime.NumCPU())), 8)
// hashInternal hashes an InternalNode and caches the result.
//
// At shallow depths (< parallelHashDepth) the left subtree is hashed in a
// goroutine while the right subtree is hashed inline, then the two digests
// are combined. Below that threshold the goroutine spawn cost outweighs the
// hashing work, so deeper nodes hash both children sequentially.
func (s *nodeStore) hashInternal(idx uint32) common.Hash {
node := s.getInternal(idx)
if !node.mustRecompute {
return node.hash
}
if int(node.depth) < parallelHashDepth {
var input [64]byte
var lh common.Hash
var wg sync.WaitGroup
if !node.left.IsEmpty() {
wg.Add(1)
go func() {
// defer wg.Done() so a panic in computeHash still releases
// the waiter; without this, a recover() higher in the call
// stack would leave the parent stuck in wg.Wait forever.
defer wg.Done()
lh = s.computeHash(node.left)
}()
}
if !node.right.IsEmpty() {
rh := s.computeHash(node.right)
copy(input[32:], rh[:])
}
wg.Wait()
copy(input[:32], lh[:])
node.hash = sha256.Sum256(input[:])
node.mustRecompute = false
return node.hash
}
// Deep sequential branch — mirrors the shallow branch's shape to keep
// input on the stack. Writing lh/rh through hash.Hash (interface)
// forces escape; copy into a local [64]byte and hash it in one shot.
var input [64]byte
if !node.left.IsEmpty() {
lh := s.computeHash(node.left)
copy(input[:HashSize], lh[:])
}
if !node.right.IsEmpty() {
rh := s.computeHash(node.right)
copy(input[HashSize:], rh[:])
}
node.hash = sha256.Sum256(input[:])
node.mustRecompute = false
return node.hash
}
// SerializeNode serializes a node into the flat on-disk format.
func (s *nodeStore) serializeNode(ref nodeRef) []byte {
switch ref.Kind() {
case kindInternal:
node := s.getInternal(ref.Index())
var serialized [NodeTypeBytes + HashSize + HashSize]byte
serialized[0] = nodeTypeInternal
lh := s.computeHash(node.left)
rh := s.computeHash(node.right)
copy(serialized[NodeTypeBytes:NodeTypeBytes+HashSize], lh[:])
copy(serialized[NodeTypeBytes+HashSize:], rh[:])
return serialized[:]
case kindStem:
sn := s.getStem(ref.Index())
// Count present slots to size the blob.
var count int
for _, v := range sn.values {
if v != nil {
count++
}
}
serializedLen := NodeTypeBytes + StemSize + StemBitmapSize + count*HashSize
serialized := make([]byte, serializedLen)
serialized[0] = nodeTypeStem
copy(serialized[NodeTypeBytes:NodeTypeBytes+StemSize], sn.Stem[:])
bitmap := serialized[NodeTypeBytes+StemSize : NodeTypeBytes+StemSize+StemBitmapSize]
offset := NodeTypeBytes + StemSize + StemBitmapSize
for i, v := range sn.values {
if v != nil {
bitmap[i/8] |= 1 << (7 - (i % 8))
copy(serialized[offset:offset+HashSize], v)
offset += HashSize
}
}
return serialized
default:
panic(fmt.Sprintf("SerializeNode: unexpected node kind %d", ref.Kind()))
}
}
var errInvalidSerializedLength = errors.New("invalid serialized node length")
// DeserializeNode deserializes a node from bytes, recomputing its hash. The
// returned node is marked dirty (provenance unknown, safe re-flush default).
func (s *nodeStore) deserializeNode(serialized []byte, depth int) (nodeRef, error) {
return s.decodeNode(serialized, depth, common.Hash{}, true, true)
}
// DeserializeNodeWithHash deserializes a node whose hash is already known and
// whose blob is already on disk (mustRecompute=false, dirty=false).
func (s *nodeStore) deserializeNodeWithHash(serialized []byte, depth int, hn common.Hash) (nodeRef, error) {
return s.decodeNode(serialized, depth, hn, false, false)
}
func (s *nodeStore) decodeNode(serialized []byte, depth int, hn common.Hash, mustRecompute, dirty bool) (nodeRef, error) {
if len(serialized) == 0 {
return emptyRef, nil
}
switch serialized[0] {
case nodeTypeInternal:
if len(serialized) != NodeTypeBytes+2*HashSize {
return emptyRef, errInvalidSerializedLength
}
var leftHash, rightHash common.Hash
copy(leftHash[:], serialized[NodeTypeBytes:NodeTypeBytes+HashSize])
copy(rightHash[:], serialized[NodeTypeBytes+HashSize:])
var leftRef, rightRef nodeRef
if leftHash != (common.Hash{}) {
leftRef = s.newHashedRef(leftHash)
}
if rightHash != (common.Hash{}) {
rightRef = s.newHashedRef(rightHash)
}
ref := s.newInternalRef(depth)
node := s.getInternal(ref.Index())
node.left = leftRef
node.right = rightRef
if !mustRecompute {
node.hash = hn
node.mustRecompute = false
}
node.dirty = dirty
return ref, nil
case nodeTypeStem:
if len(serialized) < NodeTypeBytes+StemSize+StemBitmapSize {
return emptyRef, errInvalidSerializedLength
}
stemIdx := s.allocStem()
sn := s.getStem(stemIdx)
copy(sn.Stem[:], serialized[NodeTypeBytes:NodeTypeBytes+StemSize])
bitmap := serialized[NodeTypeBytes+StemSize : NodeTypeBytes+StemSize+StemBitmapSize]
offset := NodeTypeBytes + StemSize + StemBitmapSize
for i := range StemNodeWidth {
if bitmap[i/8]>>(7-(i%8))&1 != 1 {
continue
}
if len(serialized) < offset+HashSize {
return emptyRef, errInvalidSerializedLength
}
// Zero-copy: each slot aliases the serialized input buffer.
sn.values[i] = serialized[offset : offset+HashSize]
offset += HashSize
}
sn.depth = uint8(depth)
sn.hash = hn
sn.mustRecompute = mustRecompute
sn.dirty = dirty
return makeRef(kindStem, stemIdx), nil
default:
return emptyRef, errors.New("invalid node type")
}
}
// CollectNodes flushes every node that needs flushing via flushfn in post-order.
// Invariant: any ancestor of a node that needs flushing is itself marked, so a
// clean root means the whole subtree is clean.
func (s *nodeStore) collectNodes(ref nodeRef, path []byte, flushfn nodeFlushFn) error {
switch ref.Kind() {
case kindEmpty:
return nil
case kindInternal:
node := s.getInternal(ref.Index())
if !node.dirty {
return nil
}
// Reuse path buffer across children: flushfn consumers
// (NodeSet.AddNode, tracer.Get) clone via string(path), so in-place
// mutation is safe.
path = append(path, 0)
if err := s.collectNodes(node.left, path, flushfn); err != nil {
return err
}
path[len(path)-1] = 1
if err := s.collectNodes(node.right, path, flushfn); err != nil {
return err
}
path = path[:len(path)-1]
flushfn(path, s.computeHash(ref), s.serializeNode(ref))
node.dirty = false
return nil
case kindStem:
sn := s.getStem(ref.Index())
if !sn.dirty {
return nil
}
flushfn(path, s.computeHash(ref), s.serializeNode(ref))
sn.dirty = false
return nil
case kindHashed:
return nil // Already committed
default:
return fmt.Errorf("CollectNodes: unexpected kind %d", ref.Kind())
}
}
func (s *nodeStore) toDot(ref nodeRef, parent, path string) string {
switch ref.Kind() {
case kindInternal:
node := s.getInternal(ref.Index())
me := fmt.Sprintf("internal%s", path)
ret := fmt.Sprintf("%s [label=\"I: %x\"]\n", me, s.computeHash(ref))
if len(parent) > 0 {
ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me)
}
if !node.left.IsEmpty() {
ret += s.toDot(node.left, me, fmt.Sprintf("%s%02x", path, 0))
}
if !node.right.IsEmpty() {
ret += s.toDot(node.right, me, fmt.Sprintf("%s%02x", path, 1))
}
return ret
case kindStem:
sn := s.getStem(ref.Index())
me := fmt.Sprintf("stem%s", path)
ret := fmt.Sprintf("%s [label=\"stem=%x c=%x\"]\n", me, sn.Stem, sn.Hash())
ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me)
for i, v := range sn.values {
if v == nil {
continue
}
ret += fmt.Sprintf("%s%x [label=\"%x\"]\n", me, i, v)
ret += fmt.Sprintf("%s -> %s%x\n", me, me, i)
}
return ret
case kindHashed:
hn := s.getHashed(ref.Index())
me := fmt.Sprintf("hash%s", path)
ret := fmt.Sprintf("%s [label=\"%x\"]\n", me, hn.Hash())
ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me)
return ret
default:
return ""
}
}

345
trie/bintrie/store_ops.go Normal file
View file

@ -0,0 +1,345 @@
// Copyright 2026 go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package bintrie
import (
"errors"
"fmt"
"github.com/ethereum/go-ethereum/common"
)
// nodeResolverFn resolves a hashed node from the database.
type nodeResolverFn func([]byte, common.Hash) ([]byte, error)
// GetValue returns the value at (stem, suffix) or nil if absent. Thin
// wrapper over GetValuesAtStem — the underlying StemNode returns its
// 256-slot array as a slice header (no allocation), so the per-call cost
// is the tree walk plus one index.
func (s *nodeStore) GetValue(stem []byte, suffix byte, resolver nodeResolverFn) ([]byte, error) {
values, err := s.GetValuesAtStem(stem, resolver)
if err != nil || values == nil {
return nil, err
}
return values[suffix], nil
}
// GetValuesAtStem returns the 256 value slots at stem, or nil if the stem
// is not in the trie. The returned slice is a view over the in-place
// StemNode values array (no allocation) and must be treated read-only.
func (s *nodeStore) GetValuesAtStem(stem []byte, resolver nodeResolverFn) ([][]byte, error) {
cur := s.root
var parentIdx uint32
var parentIsLeft bool
for {
switch cur.Kind() {
case kindInternal:
node := s.getInternal(cur.Index())
if node.depth >= 31*8 {
return nil, errors.New("node too deep")
}
bit := stem[node.depth/8] >> (7 - (node.depth % 8)) & 1
parentIdx = cur.Index()
if bit == 0 {
parentIsLeft = true
cur = node.left
} else {
parentIsLeft = false
cur = node.right
}
case kindStem:
sn := s.getStem(cur.Index())
if sn.Stem != [StemSize]byte(stem[:StemSize]) {
return nil, nil
}
return sn.allValues(), nil
case kindHashed:
// HashedNode at root is impossible: NewBinaryTrie resolves the
// root eagerly before any query. Any HashedNode we encounter here
// is necessarily a child of a previously-visited internal node.
if resolver == nil {
return nil, errors.New("getValuesAtStem: cannot resolve hashed node without resolver")
}
hn := s.getHashed(cur.Index())
parentNode := s.getInternal(parentIdx)
path, err := keyToPath(int(parentNode.depth), stem)
if err != nil {
return nil, fmt.Errorf("getValuesAtStem path error: %w", err)
}
data, err := resolver(path, hn.Hash())
if err != nil {
return nil, fmt.Errorf("getValuesAtStem resolve error: %w", err)
}
resolved, err := s.deserializeNodeWithHash(data, int(parentNode.depth)+1, hn.Hash())
if err != nil {
return nil, fmt.Errorf("getValuesAtStem deserialization error: %w", err)
}
s.freeHashedNode(cur.Index())
if parentIsLeft {
parentNode.left = resolved
} else {
parentNode.right = resolved
}
cur = resolved
case kindEmpty:
var values [StemNodeWidth][]byte
return values[:], nil
default:
return nil, fmt.Errorf("getValuesAtStem: unexpected node kind %d", cur.Kind())
}
}
}
// InsertSingle writes a single value slot at (stem, suffix). Thin wrapper
// over InsertValuesAtStem — builds a stack-allocated 256-slot array with
// only the target slot set and delegates. Matches the original design
// gballet referenced (comment 3101751325): one primary insert path; the
// single-slot variant dispatches through it so the split / resolve logic
// lives in one place.
func (s *nodeStore) InsertSingle(stem []byte, suffix byte, value []byte, resolver nodeResolverFn) error {
if len(value) != HashSize {
return errors.New("invalid insertion: value length")
}
var values [StemNodeWidth][]byte
values[suffix] = value
return s.InsertValuesAtStem(stem, values[:], resolver)
}
// InsertValuesAtStem writes the supplied value slots at stem. values may be
// sparse (nil entries are ignored). The recursive implementation dispatches
// through the same body, so a single code path handles internal descent,
// HashedNode resolution, stem merge, and stem split.
func (s *nodeStore) InsertValuesAtStem(stem []byte, values [][]byte, resolver nodeResolverFn) error {
var err error
s.root, err = s.insertValuesAtStem(s.root, stem, values, resolver, 0)
return err
}
func (s *nodeStore) insertValuesAtStem(ref nodeRef, stem []byte, values [][]byte, resolver nodeResolverFn, depth int) (nodeRef, error) {
switch ref.Kind() {
case kindInternal:
node := s.getInternal(ref.Index())
bit := stem[node.depth/8] >> (7 - (node.depth % 8)) & 1
if bit == 0 {
if node.left.Kind() == kindHashed {
if resolver == nil {
return ref, errors.New("insertValuesAtStem: cannot resolve hashed node without resolver")
}
hn := s.getHashed(node.left.Index())
path, err := keyToPath(int(node.depth), stem)
if err != nil {
return ref, fmt.Errorf("InsertValuesAtStem path error: %w", err)
}
data, err := resolver(path, hn.Hash())
if err != nil {
return ref, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
}
resolved, err := s.deserializeNodeWithHash(data, int(node.depth)+1, hn.Hash())
if err != nil {
return ref, fmt.Errorf("InsertValuesAtStem deserialization error: %w", err)
}
s.freeHashedNode(node.left.Index())
node.left = resolved
}
newChild, err := s.insertValuesAtStem(node.left, stem, values, resolver, depth+1)
if err != nil {
return ref, err
}
node.left = newChild
} else {
if node.right.Kind() == kindHashed {
if resolver == nil {
return ref, errors.New("insertValuesAtStem: cannot resolve hashed node without resolver")
}
hn := s.getHashed(node.right.Index())
path, err := keyToPath(int(node.depth), stem)
if err != nil {
return ref, fmt.Errorf("InsertValuesAtStem path error: %w", err)
}
data, err := resolver(path, hn.Hash())
if err != nil {
return ref, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
}
resolved, err := s.deserializeNodeWithHash(data, int(node.depth)+1, hn.Hash())
if err != nil {
return ref, fmt.Errorf("InsertValuesAtStem deserialization error: %w", err)
}
s.freeHashedNode(node.right.Index())
node.right = resolved
}
newChild, err := s.insertValuesAtStem(node.right, stem, values, resolver, depth+1)
if err != nil {
return ref, err
}
node.right = newChild
}
node.mustRecompute = true
node.dirty = true
return ref, nil
case kindStem:
sn := s.getStem(ref.Index())
if sn.Stem == [StemSize]byte(stem[:StemSize]) {
// Same stem — merge values (setValue marks dirty+mustRecompute)
for i, v := range values {
if v != nil {
sn.setValue(byte(i), v)
}
}
return ref, nil
}
// Different stem — split
return s.splitStemValuesInsert(ref, stem, values, resolver, depth)
case kindHashed:
hn := s.getHashed(ref.Index())
path, err := keyToPath(depth, stem)
if err != nil {
return ref, fmt.Errorf("InsertValuesAtStem path error: %w", err)
}
if resolver == nil {
return ref, errors.New("InsertValuesAtStem: resolver is nil")
}
data, err := resolver(path, hn.Hash())
if err != nil {
return ref, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
}
resolved, err := s.deserializeNodeWithHash(data, depth, hn.Hash())
if err != nil {
return ref, fmt.Errorf("InsertValuesAtStem deserialization error: %w", err)
}
s.freeHashedNode(ref.Index())
return s.insertValuesAtStem(resolved, stem, values, resolver, depth)
case kindEmpty:
// Create new StemNode. Flag flips before the value loop so an
// all-nil values input still marks the newly-created stem dirty.
stemIdx := s.allocStem()
sn := s.getStem(stemIdx)
copy(sn.Stem[:], stem[:StemSize])
sn.depth = uint8(depth)
sn.mustRecompute = true
sn.dirty = true
for i, v := range values {
if v != nil {
sn.setValue(byte(i), v)
}
}
return makeRef(kindStem, stemIdx), nil
default:
return ref, fmt.Errorf("insertValuesAtStem: unexpected kind %d", ref.Kind())
}
}
// splitStemValuesInsert splits a StemNode when the new stem diverges.
func (s *nodeStore) splitStemValuesInsert(existingRef nodeRef, newStem []byte, values [][]byte, resolver nodeResolverFn, depth int) (nodeRef, error) {
existing := s.getStem(existingRef.Index())
if int(existing.depth) >= StemSize*8 {
panic("splitStemValuesInsert: identical stems")
}
bitStem := existing.Stem[existing.depth/8] >> (7 - (existing.depth % 8)) & 1
nRef := s.newInternalRef(int(existing.depth))
nNode := s.getInternal(nRef.Index())
existing.depth++
bitKey := newStem[nNode.depth/8] >> (7 - (nNode.depth % 8)) & 1
if bitKey == bitStem {
// Same direction — need deeper split
var child nodeRef
if bitStem == 0 {
nNode.left = existingRef
child = nNode.left
} else {
nNode.right = existingRef
child = nNode.right
}
newChild, err := s.insertValuesAtStem(child, newStem, values, resolver, depth+1)
if err != nil {
// Roll back the depth increment so a retry sees the same
// existing state and extracts bitStem at the correct offset.
// nRef itself leaks (no internal free-list), but the slot is
// unreachable from the tree and harmless.
existing.depth--
return nRef, err
}
if bitStem == 0 {
nNode.left = newChild
nNode.right = emptyRef
} else {
nNode.right = newChild
nNode.left = emptyRef
}
} else {
// Divergence — create new StemNode for the new values
newStemIdx := s.allocStem()
newSn := s.getStem(newStemIdx)
copy(newSn.Stem[:], newStem[:StemSize])
newSn.depth = nNode.depth + 1
newSn.mustRecompute = true
newSn.dirty = true
for i, v := range values {
if v != nil {
newSn.setValue(byte(i), v)
}
}
newStemRef := makeRef(kindStem, newStemIdx)
if bitStem == 0 {
nNode.left = existingRef
nNode.right = newStemRef
} else {
nNode.left = newStemRef
nNode.right = existingRef
}
}
return nRef, nil
}
func (s *nodeStore) Insert(key []byte, value []byte, resolver nodeResolverFn) error {
return s.InsertSingle(key[:StemSize], key[StemSize], value, resolver)
}
func (s *nodeStore) Get(key []byte, resolver nodeResolverFn) ([]byte, error) {
return s.GetValue(key[:StemSize], key[StemSize], resolver)
}
func (s *nodeStore) getHeight(ref nodeRef) int {
switch ref.Kind() {
case kindInternal:
node := s.getInternal(ref.Index())
lh := s.getHeight(node.left)
rh := s.getHeight(node.right)
if lh > rh {
return 1 + lh
}
return 1 + rh
case kindStem:
return 1
case kindEmpty:
return 0
default:
return 0
}
}

View file

@ -19,7 +19,6 @@ package bintrie
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"github.com/ethereum/go-ethereum/common"
@ -31,8 +30,6 @@ import (
"github.com/holiman/uint256"
)
var errInvalidRootType = errors.New("invalid root type")
// ChunkedCode represents a sequence of HashSize-byte chunks of code (StemSize bytes of which
// are actual code, and NodeTypeBytes byte is the pushdata offset).
type ChunkedCode []byte
@ -108,22 +105,17 @@ func ChunkifyCode(code []byte) ChunkedCode {
return chunks
}
// NewBinaryNode creates a new empty binary trie
func NewBinaryNode() BinaryNode {
return Empty{}
}
// BinaryTrie is the implementation of https://eips.ethereum.org/EIPS/eip-7864.
type BinaryTrie struct {
root BinaryNode
store *nodeStore
reader *trie.Reader
tracer *trie.PrevalueTracer
}
// ToDot converts the binary trie to a DOT language representation. Useful for debugging.
func (t *BinaryTrie) ToDot() string {
t.root.Hash()
return ToDot(t.root)
t.store.computeHash(t.store.root)
return t.store.toDot(t.store.root, "", "")
}
// NewBinaryTrie creates a new binary trie.
@ -133,7 +125,7 @@ func NewBinaryTrie(root common.Hash, db database.NodeDatabase) (*BinaryTrie, err
return nil, err
}
t := &BinaryTrie{
root: NewBinaryNode(),
store: newNodeStore(),
reader: reader,
tracer: trie.NewPrevalueTracer(),
}
@ -143,11 +135,11 @@ func NewBinaryTrie(root common.Hash, db database.NodeDatabase) (*BinaryTrie, err
if err != nil {
return nil, err
}
node, err := DeserializeNodeWithHash(blob, 0, root)
ref, err := t.store.deserializeNodeWithHash(blob, 0, root)
if err != nil {
return nil, err
}
t.root = node
t.store.root = ref
}
return t, nil
}
@ -176,29 +168,18 @@ func (t *BinaryTrie) GetKey(key []byte) []byte {
// GetWithHashedKey returns the value, assuming that the key has already
// been hashed.
func (t *BinaryTrie) GetWithHashedKey(key []byte) ([]byte, error) {
return t.root.Get(key, t.nodeResolver)
return t.store.Get(key, t.nodeResolver)
}
// GetAccount returns the account information for the given address.
func (t *BinaryTrie) GetAccount(addr common.Address) (*types.StateAccount, error) {
var (
values [][]byte
err error
acc = &types.StateAccount{}
key = GetBinaryTreeKey(addr, zero[:])
err error
acc = &types.StateAccount{}
key = GetBinaryTreeKey(addr, zero[:])
)
switch r := t.root.(type) {
case *InternalNode:
values, err = r.GetValuesAtStem(key[:StemSize], t.nodeResolver)
case *StemNode:
values, err = r.GetValuesAtStem(key[:StemSize], t.nodeResolver)
case Empty:
return nil, nil
default:
// This will cover HashedNode but that should be fine since the
// root node should always be resolved.
return nil, errInvalidRootType
}
values, err := t.store.GetValuesAtStem(key[:StemSize], t.nodeResolver)
if err != nil {
return nil, fmt.Errorf("GetAccount (%x) error: %v", addr, err)
}
@ -219,7 +200,7 @@ func (t *BinaryTrie) GetAccount(addr common.Address) (*types.StateAccount, error
// If the account has been deleted, BasicData and CodeHash will both be
// 32-byte zero blobs (not nil). If the account is recreated afterwards,
// UpdateAccount overwrites BasicData and CodeHash with non-zero values,
// so this branch won't activate..
// so this branch won't activate.
if bytes.Equal(values[BasicDataLeafKey], zero[:]) &&
bytes.Equal(values[CodeHashLeafKey], zero[:]) {
return nil, nil
@ -238,13 +219,12 @@ func (t *BinaryTrie) GetAccount(addr common.Address) (*types.StateAccount, error
// not be modified by the caller. If a node was not found in the database, a
// trie.MissingNodeError is returned.
func (t *BinaryTrie) GetStorage(addr common.Address, key []byte) ([]byte, error) {
return t.root.Get(GetBinaryTreeKeyStorageSlot(addr, key), t.nodeResolver)
return t.store.Get(GetBinaryTreeKeyStorageSlot(addr, key), t.nodeResolver)
}
// UpdateAccount updates the account information for the given address.
func (t *BinaryTrie) UpdateAccount(addr common.Address, acc *types.StateAccount, codeLen int) error {
var (
err error
basicData [HashSize]byte
values = make([][]byte, StemNodeWidth)
stem = GetBinaryTreeKey(addr, zero[:])
@ -265,15 +245,12 @@ func (t *BinaryTrie) UpdateAccount(addr common.Address, acc *types.StateAccount,
values[BasicDataLeafKey] = basicData[:]
values[CodeHashLeafKey] = acc.CodeHash[:]
t.root, err = t.root.InsertValuesAtStem(stem, values, t.nodeResolver, 0)
return err
return t.store.InsertValuesAtStem(stem, values, t.nodeResolver)
}
// UpdateStem updates the values for the given stem key.
func (t *BinaryTrie) UpdateStem(key []byte, values [][]byte) error {
var err error
t.root, err = t.root.InsertValuesAtStem(key, values, t.nodeResolver, 0)
return err
return t.store.InsertValuesAtStem(key, values, t.nodeResolver)
}
// UpdateStorage associates key with value in the trie. If value has length zero, any
@ -288,11 +265,10 @@ func (t *BinaryTrie) UpdateStorage(address common.Address, key, value []byte) er
} else {
copy(v[HashSize-len(value):], value[:])
}
root, err := t.root.Insert(k, v[:], t.nodeResolver, 0)
err := t.store.Insert(k, v[:], t.nodeResolver)
if err != nil {
return fmt.Errorf("UpdateStorage (%x) error: %v", address, err)
}
t.root = root
return nil
}
@ -307,12 +283,7 @@ func (t *BinaryTrie) DeleteAccount(addr common.Address) error {
values[BasicDataLeafKey] = zero[:]
values[CodeHashLeafKey] = zero[:]
root, err := t.root.InsertValuesAtStem(stem, values, t.nodeResolver, 0)
if err != nil {
return fmt.Errorf("DeleteAccount (%x) error: %v", addr, err)
}
t.root = root
return nil
return t.store.InsertValuesAtStem(stem, values, t.nodeResolver)
}
// DeleteStorage removes any existing value for key from the trie. If a node was not
@ -320,18 +291,17 @@ func (t *BinaryTrie) DeleteAccount(addr common.Address) error {
func (t *BinaryTrie) DeleteStorage(addr common.Address, key []byte) error {
k := GetBinaryTreeKeyStorageSlot(addr, key)
var zero [HashSize]byte
root, err := t.root.Insert(k, zero[:], t.nodeResolver, 0)
err := t.store.Insert(k, zero[:], t.nodeResolver)
if err != nil {
return fmt.Errorf("DeleteStorage (%x) error: %v", addr, err)
}
t.root = root
return nil
}
// Hash returns the root hash of the trie. It does not write to the database and
// can be used even if the trie doesn't have one.
func (t *BinaryTrie) Hash() common.Hash {
return t.root.Hash()
return t.store.computeHash(t.store.root)
}
// Commit writes all nodes to the trie's memory database, tracking the internal
@ -339,15 +309,15 @@ func (t *BinaryTrie) Hash() common.Hash {
func (t *BinaryTrie) Commit(_ bool) (common.Hash, *trienode.NodeSet) {
nodeset := trienode.NewNodeSet(common.Hash{})
// The root can be any type of BinaryNode (InternalNode, StemNode, etc.)
err := t.root.CollectNodes(nil, func(path []byte, node BinaryNode) {
serialized := SerializeNode(node)
nodeset.AddNode(path, trienode.NewNodeWithPrev(node.Hash(), serialized, t.tracer.Get(path)))
// Pre-size the path buffer: collectNodes reuses it in-place via
// append/truncate; 32 covers typical binary-trie depth without regrowth.
pathBuf := make([]byte, 0, 32)
err := t.store.collectNodes(t.store.root, pathBuf, func(path []byte, hash common.Hash, serialized []byte) {
nodeset.AddNode(path, trienode.NewNodeWithPrev(hash, serialized, t.tracer.Get(path)))
})
if err != nil {
panic(fmt.Errorf("CollectNodes failed: %v", err))
}
// Serialize root commitment form
return t.Hash(), nodeset
}
@ -371,7 +341,7 @@ func (t *BinaryTrie) Prove(key []byte, proofDb ethdb.KeyValueWriter) error {
// Copy creates a deep copy of the trie.
func (t *BinaryTrie) Copy() *BinaryTrie {
return &BinaryTrie{
root: t.root.Copy(),
store: t.store.Copy(),
reader: t.reader,
tracer: t.tracer.Copy(),
}
@ -407,7 +377,6 @@ func (t *BinaryTrie) UpdateContractCode(addr common.Address, codeHash common.Has
if groupOffset == StemNodeWidth-1 || len(chunks)-i <= HashSize {
err = t.UpdateStem(key[:StemSize], values)
if err != nil {
return fmt.Errorf("UpdateContractCode (addr=%x) error: %w", addr[:], err)
}

View file

@ -37,147 +37,130 @@ var (
)
func TestSingleEntry(t *testing.T) {
tree := NewBinaryNode()
tree, err := tree.Insert(zeroKey[:], oneKey[:], nil, 0)
if err != nil {
s := newNodeStore()
if err := s.Insert(zeroKey[:], oneKey[:], nil); err != nil {
t.Fatal(err)
}
if tree.GetHeight() != 1 {
if s.getHeight(s.root) != 1 {
t.Fatal("invalid depth")
}
expected := common.HexToHash("aab1060e04cb4f5dc6f697ae93156a95714debbf77d54238766adc5709282b6f")
got := tree.Hash()
got := s.Hash()
if got != expected {
t.Fatalf("invalid tree root, got %x, want %x", got, expected)
}
}
func TestTwoEntriesDiffFirstBit(t *testing.T) {
var err error
tree := NewBinaryNode()
tree, err = tree.Insert(zeroKey[:], oneKey[:], nil, 0)
if err != nil {
s := newNodeStore()
if err := s.Insert(zeroKey[:], oneKey[:], nil); err != nil {
t.Fatal(err)
}
tree, err = tree.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000000").Bytes(), twoKey[:], nil, 0)
if err != nil {
if err := s.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000000").Bytes(), twoKey[:], nil); err != nil {
t.Fatal(err)
}
if tree.GetHeight() != 2 {
if s.getHeight(s.root) != 2 {
t.Fatal("invalid height")
}
if tree.Hash() != common.HexToHash("dfc69c94013a8b3c65395625a719a87534a7cfd38719251ad8c8ea7fe79f065e") {
if s.Hash() != common.HexToHash("dfc69c94013a8b3c65395625a719a87534a7cfd38719251ad8c8ea7fe79f065e") {
t.Fatal("invalid tree root")
}
}
func TestOneStemColocatedValues(t *testing.T) {
var err error
tree := NewBinaryNode()
tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil, 0)
if err != nil {
s := newNodeStore()
if err := s.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil); err != nil {
t.Fatal(err)
}
tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil, 0)
if err != nil {
if err := s.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil); err != nil {
t.Fatal(err)
}
tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000009").Bytes(), threeKey[:], nil, 0)
if err != nil {
if err := s.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000009").Bytes(), threeKey[:], nil); err != nil {
t.Fatal(err)
}
tree, err = tree.Insert(common.HexToHash("00000000000000000000000000000000000000000000000000000000000000FF").Bytes(), fourKey[:], nil, 0)
if err != nil {
if err := s.Insert(common.HexToHash("00000000000000000000000000000000000000000000000000000000000000FF").Bytes(), fourKey[:], nil); err != nil {
t.Fatal(err)
}
if tree.GetHeight() != 1 {
if s.getHeight(s.root) != 1 {
t.Fatal("invalid height")
}
}
func TestTwoStemColocatedValues(t *testing.T) {
var err error
tree := NewBinaryNode()
s := newNodeStore()
// stem: 0...0
tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil, 0)
if err != nil {
if err := s.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil); err != nil {
t.Fatal(err)
}
tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil, 0)
if err != nil {
if err := s.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil); err != nil {
t.Fatal(err)
}
// stem: 10...0
tree, err = tree.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil, 0)
if err != nil {
if err := s.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil); err != nil {
t.Fatal(err)
}
tree, err = tree.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil, 0)
if err != nil {
if err := s.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil); err != nil {
t.Fatal(err)
}
if tree.GetHeight() != 2 {
if s.getHeight(s.root) != 2 {
t.Fatal("invalid height")
}
}
func TestTwoKeysMatchFirst42Bits(t *testing.T) {
var err error
tree := NewBinaryNode()
s := newNodeStore()
// key1 and key 2 have the same prefix of 42 bits (b0*42+b1+b1) and differ after.
key1 := common.HexToHash("0000000000C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0").Bytes()
key2 := common.HexToHash("0000000000E00000000000000000000000000000000000000000000000000000").Bytes()
tree, err = tree.Insert(key1, oneKey[:], nil, 0)
if err != nil {
if err := s.Insert(key1, oneKey[:], nil); err != nil {
t.Fatal(err)
}
tree, err = tree.Insert(key2, twoKey[:], nil, 0)
if err != nil {
if err := s.Insert(key2, twoKey[:], nil); err != nil {
t.Fatal(err)
}
if tree.GetHeight() != 1+42+1 {
if s.getHeight(s.root) != 1+42+1 {
t.Fatal("invalid height")
}
}
func TestInsertDuplicateKey(t *testing.T) {
var err error
tree := NewBinaryNode()
tree, err = tree.Insert(oneKey[:], oneKey[:], nil, 0)
if err != nil {
s := newNodeStore()
if err := s.Insert(oneKey[:], oneKey[:], nil); err != nil {
t.Fatal(err)
}
tree, err = tree.Insert(oneKey[:], twoKey[:], nil, 0)
if err != nil {
if err := s.Insert(oneKey[:], twoKey[:], nil); err != nil {
t.Fatal(err)
}
if tree.GetHeight() != 1 {
if s.getHeight(s.root) != 1 {
t.Fatal("invalid height")
}
// Verify that the value is updated
if !bytes.Equal(tree.(*StemNode).Values[1], twoKey[:]) {
t.Fatal("invalid height")
v, err := s.Get(oneKey[:], nil)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(v, twoKey[:]) {
t.Fatal("value not updated")
}
}
func TestLargeNumberOfEntries(t *testing.T) {
var err error
tree := NewBinaryNode()
s := newNodeStore()
for i := range StemNodeWidth {
var key [HashSize]byte
key[0] = byte(i)
tree, err = tree.Insert(key[:], ffKey[:], nil, 0)
if err != nil {
if err := s.Insert(key[:], ffKey[:], nil); err != nil {
t.Fatal(err)
}
}
height := tree.GetHeight()
height := s.getHeight(s.root)
if height != 1+8 {
t.Fatalf("invalid height, wanted %d, got %d", 1+8, height)
}
}
func TestMerkleizeMultipleEntries(t *testing.T) {
var err error
tree := NewBinaryNode()
s := newNodeStore()
keys := [][]byte{
zeroKey[:],
common.HexToHash("8000000000000000000000000000000000000000000000000000000000000000").Bytes(),
@ -187,12 +170,11 @@ func TestMerkleizeMultipleEntries(t *testing.T) {
for i, key := range keys {
var v [HashSize]byte
binary.LittleEndian.PutUint64(v[:8], uint64(i))
tree, err = tree.Insert(key, v[:], nil, 0)
if err != nil {
if err := s.Insert(key, v[:], nil); err != nil {
t.Fatal(err)
}
}
got := tree.Hash()
got := s.Hash()
expected := common.HexToHash("9317155862f7a3867660ddd0966ff799a3d16aa4df1e70a7516eaa4a675191b5")
if got != expected {
t.Fatalf("invalid root, expected=%x, got = %x", expected, got)
@ -206,7 +188,7 @@ func TestMerkleizeMultipleEntries(t *testing.T) {
func TestStorageRoundTrip(t *testing.T) {
tracer := trie.NewPrevalueTracer()
tr := &BinaryTrie{
root: NewBinaryNode(),
store: newNodeStore(),
tracer: tracer,
}
addr := common.HexToAddress("0x1234567890abcdef1234567890abcdef12345678")
@ -274,7 +256,7 @@ func TestStorageRoundTrip(t *testing.T) {
func newEmptyTestTrie(t *testing.T) *BinaryTrie {
t.Helper()
return &BinaryTrie{
root: NewBinaryNode(),
store: newNodeStore(),
tracer: trie.NewPrevalueTracer(),
}
}
@ -599,7 +581,7 @@ func TestBinaryTrieWitness(t *testing.T) {
tracer := trie.NewPrevalueTracer()
tr := &BinaryTrie{
root: NewBinaryNode(),
store: newNodeStore(),
tracer: tracer,
}
if w := tr.Witness(); len(w) != 0 {
@ -626,7 +608,7 @@ func TestBinaryTrieWitness(t *testing.T) {
func testAccount(t *testing.T, addr common.Address, nonce uint64, balance uint64) *BinaryTrie {
t.Helper()
tr := &BinaryTrie{
root: NewBinaryNode(),
store: newNodeStore(),
tracer: trie.NewPrevalueTracer(),
}
acc := &types.StateAccount{
@ -649,8 +631,8 @@ func TestGetAccountNonMembershipStemRoot(t *testing.T) {
tr := testAccount(t, addr, 42, 100)
// Verify root is a StemNode (single stem inserted).
if _, ok := tr.root.(*StemNode); !ok {
t.Fatalf("expected StemNode root, got %T", tr.root)
if tr.store.root.Kind() != kindStem {
t.Fatalf("expected StemNode root, got kind %d", tr.store.root.Kind())
}
// Query a completely different address — must return nil.
@ -680,7 +662,7 @@ func TestGetAccountNonMembershipStemRoot(t *testing.T) {
// address returns nil when the trie root is an InternalNode (multi-account trie).
func TestGetAccountNonMembershipInternalRoot(t *testing.T) {
tr := &BinaryTrie{
root: NewBinaryNode(),
store: newNodeStore(),
tracer: trie.NewPrevalueTracer(),
}
@ -700,8 +682,8 @@ func TestGetAccountNonMembershipInternalRoot(t *testing.T) {
}
// Verify root is an InternalNode.
if _, ok := tr.root.(*InternalNode); !ok {
t.Fatalf("expected InternalNode root, got %T", tr.root)
if tr.store.root.Kind() != kindInternal {
t.Fatalf("expected InternalNode root, got kind %d", tr.store.root.Kind())
}
// Query a non-existent address — must return nil.
@ -723,8 +705,8 @@ func TestGetStorageNonMembershipStemRoot(t *testing.T) {
tr := testAccount(t, addr, 1, 100)
// Verify root is a StemNode.
if _, ok := tr.root.(*StemNode); !ok {
t.Fatalf("expected StemNode root, got %T", tr.root)
if tr.store.root.Kind() != kindStem {
t.Fatalf("expected StemNode root, got kind %d", tr.store.root.Kind())
}
// Query storage for a different address — must return nil, not panic.
@ -743,7 +725,7 @@ func TestGetStorageNonMembershipStemRoot(t *testing.T) {
// non-existent address returns nil when the root is an InternalNode.
func TestGetStorageNonMembershipInternalRoot(t *testing.T) {
tr := &BinaryTrie{
root: NewBinaryNode(),
store: newNodeStore(),
tracer: trie.NewPrevalueTracer(),
}
@ -765,8 +747,8 @@ func TestGetStorageNonMembershipInternalRoot(t *testing.T) {
t.Fatalf("UpdateStorage error: %v", err)
}
if _, ok := tr.root.(*InternalNode); !ok {
t.Fatalf("expected InternalNode root, got %T", tr.root)
if tr.store.root.Kind() != kindInternal {
t.Fatalf("expected InternalNode root, got kind %d", tr.store.root.Kind())
}
// Query storage for a non-existent address — must return nil.
@ -780,36 +762,29 @@ func TestGetStorageNonMembershipInternalRoot(t *testing.T) {
}
}
// commitKeyN derives a distinct 32-byte key from a seed integer. Used by
// TestBinaryTrieCommitIncremental and BenchmarkCollectNodes_SparseWrite to
// populate a trie with many disjoint stems.
func commitKeyN(i int) [HashSize]byte {
var k [HashSize]byte
binary.BigEndian.PutUint64(k[:8], uint64(i)*0x9e3779b97f4a7c15)
binary.BigEndian.PutUint64(k[8:16], uint64(i)*0xc2b2ae3d27d4eb4f)
binary.BigEndian.PutUint64(k[16:24], uint64(i)*0x165667b19e3779f9)
binary.BigEndian.PutUint64(k[24:32], uint64(i)*0x85ebca77c2b2ae63)
return k
}
// TestBinaryTrieCommitIncremental verifies that a second Commit with only a
// single modified leaf flushes only the path from that leaf to the root,
// not the entire tree.
func TestBinaryTrieCommitIncremental(t *testing.T) {
// TestCommitSkipCleanSubtrees verifies that CollectNodes short-circuits on
// clean subtrees. First Commit flushes every resolved node; a follow-up
// Commit with no modifications flushes nothing; a single-leaf modification
// flushes only the root-to-leaf path.
func TestCommitSkipCleanSubtrees(t *testing.T) {
tr := &BinaryTrie{
root: NewBinaryNode(),
store: newNodeStore(),
tracer: trie.NewPrevalueTracer(),
}
const n = 512
keys := make([][HashSize]byte, n)
const n = 200
key := func(i int) [HashSize]byte {
var k [HashSize]byte
binary.BigEndian.PutUint64(k[:8], uint64(i+1)*0x9e3779b97f4a7c15)
binary.BigEndian.PutUint64(k[8:16], uint64(i+1)*0xc2b2ae3d27d4eb4f)
binary.BigEndian.PutUint64(k[16:24], uint64(i+1)*0x165667b19e3779f9)
binary.BigEndian.PutUint64(k[24:32], uint64(i+1)*0x85ebca77c2b2ae63)
return k
}
for i := range n {
keys[i] = commitKeyN(i + 1)
k := key(i)
var v [HashSize]byte
binary.BigEndian.PutUint64(v[24:], uint64(i+1))
var err error
tr.root, err = tr.root.Insert(keys[i][:], v[:], nil, 0)
if err != nil {
if err := tr.store.Insert(k[:], v[:], nil); err != nil {
t.Fatalf("Insert %d: %v", i, err)
}
}
@ -818,69 +793,55 @@ func TestBinaryTrieCommitIncremental(t *testing.T) {
if len(ns1.Nodes) == 0 {
t.Fatal("first Commit produced empty NodeSet")
}
if len(ns1.Nodes) < n {
t.Fatalf("first Commit: expected at least %d nodes, got %d", n, len(ns1.Nodes))
}
// Second Commit on the same trie with no modifications: NodeSet must
// be empty because every subtree is clean.
_, nsNoop := tr.Commit(false)
if len(nsNoop.Nodes) != 0 {
t.Fatalf("no-op Commit: expected empty NodeSet, got %d nodes", len(nsNoop.Nodes))
t.Fatalf("no-op Commit: expected empty NodeSet, got %d", len(nsNoop.Nodes))
}
// Modify a single leaf's value. Only the path from that leaf to the
// root should appear in the next Commit's NodeSet.
// Modify a single leaf — only the root-to-leaf path should flush.
k := key(n / 2)
var newVal [HashSize]byte
newVal[0] = 0xff
var err error
tr.root, err = tr.root.Insert(keys[n/2][:], newVal[:], nil, 0)
if err != nil {
if err := tr.store.Insert(k[:], newVal[:], nil); err != nil {
t.Fatalf("Insert (modify): %v", err)
}
_, ns2 := tr.Commit(false)
// Path length for a binary trie of n=512 stems is bounded by the
// internal depth at which the modified stem sits. Allow generous
// slack: up to 64 nodes is fine, anywhere near n (512) is a regression.
if len(ns2.Nodes) == 0 {
t.Fatal("modified Commit produced empty NodeSet")
}
if len(ns2.Nodes) > 64 {
t.Fatalf("modified Commit: expected small NodeSet, got %d nodes (first Commit had %d)", len(ns2.Nodes), len(ns1.Nodes))
if len(ns2.Nodes) > 32 {
t.Fatalf("modified Commit: expected ≤32 nodes (path+stem), got %d", len(ns2.Nodes))
}
if len(ns2.Nodes) >= len(ns1.Nodes) {
t.Fatalf("expected second NodeSet (%d) to be smaller than first (%d)", len(ns2.Nodes), len(ns1.Nodes))
}
}
// BenchmarkCollectNodes_SparseWrite measures Commit cost when only one leaf
// changes between blocks — the common case for state updates. After warm-up
// BenchmarkCollectNodesSparseWrite measures Commit cost when one leaf
// changes per block — the common case for state updates. After warm-up
// (populate + initial Commit), each iteration modifies a single leaf and
// re-Commits. Under the skip-clean optimization, each iteration flushes
// only the root-to-leaf path; pre-fix behavior would re-flush the entire
// tree every iteration.
func BenchmarkCollectNodes_SparseWrite(b *testing.B) {
// re-Commits. Matches the shape of the same-named benchmark on master so
// the two trees can be benchstat'd directly.
func BenchmarkCollectNodesSparseWrite(b *testing.B) {
const n = 10_000
tr := &BinaryTrie{
root: NewBinaryNode(),
store: newNodeStore(),
tracer: trie.NewPrevalueTracer(),
}
keys := make([][HashSize]byte, n)
for i := range n {
keys[i] = commitKeyN(i + 1)
binary.BigEndian.PutUint64(keys[i][:8], uint64(i+1)*0x9e3779b97f4a7c15)
binary.BigEndian.PutUint64(keys[i][8:16], uint64(i+1)*0xc2b2ae3d27d4eb4f)
binary.BigEndian.PutUint64(keys[i][16:24], uint64(i+1)*0x165667b19e3779f9)
binary.BigEndian.PutUint64(keys[i][24:32], uint64(i+1)*0x85ebca77c2b2ae63)
var v [HashSize]byte
binary.BigEndian.PutUint64(v[24:], uint64(i+1))
var err error
tr.root, err = tr.root.Insert(keys[i][:], v[:], nil, 0)
if err != nil {
b.Fatalf("Insert %d: %v", i, err)
if err := tr.store.Insert(keys[i][:], v[:], nil); err != nil {
b.Fatalf("warmup Insert %d: %v", i, err)
}
}
// Flush the initial tree so subsequent Commits reflect the
// single-modification workload we want to measure.
_, _ = tr.Commit(false)
_, _ = tr.Commit(false) // warmup flush
var newVal [HashSize]byte
b.ReportAllocs()
@ -888,10 +849,8 @@ func BenchmarkCollectNodes_SparseWrite(b *testing.B) {
for i := 0; i < b.N; i++ {
idx := i % n
binary.BigEndian.PutUint64(newVal[24:], uint64(i+1))
var err error
tr.root, err = tr.root.Insert(keys[idx][:], newVal[:], nil, 0)
if err != nil {
b.Fatalf("Insert at iter %d: %v", i, err)
if err := tr.store.Insert(keys[idx][:], newVal[:], nil); err != nil {
b.Fatalf("iter %d Insert: %v", i, err)
}
_, ns := tr.Commit(false)
if len(ns.Nodes) == 0 {

View file

@ -102,11 +102,7 @@ func binaryNodeHasher(blob []byte) (common.Hash, error) {
if len(blob) == 0 {
return types.EmptyBinaryHash, nil
}
n, err := bintrie.DeserializeNode(blob, 0)
if err != nil {
return common.Hash{}, err
}
return n.Hash(), nil
return bintrie.DeserializeAndHash(blob, 0)
}
// Database is a multiple-layered structure for maintaining in-memory states