trie/bintrie: replace BinaryNode interface with GC-free NodeRef arena (#34055)

## Summary

Replace the `BinaryNode` interface with `NodeRef uint32` indices into
typed arena pools, eliminating GC-scanned pointers from binary trie
nodes.

Inspired by [fjl's
observation](https://github.com/ethereum/go-ethereum/pull/34034#issuecomment-4075176446):
> *"if the binary trie produces such a large graph, it should probably
be changed so that the trie node type does not contain pointers. The
runtime does not scan objects that do not contain pointers, so it can
really help with the performance to build it this way."*

### The problem

CPU profiling of the binary trie (EIP-7864) showed **44% of CPU time in
garbage collection**. Each `InternalNode` held two `BinaryNode`
interface values (2 pointer-words each), and the GC scanned every one.
With ~25K `InternalNode`s in memory during block processing, this
created enormous GC pressure.

### The solution

`NodeRef` is a compact `uint32` (2-bit kind tag + 30-bit pool index).
`NodeStore` manages chunked typed pools per node kind:
- **InternalNode pool**: ZERO Go pointers (children are `NodeRef`, hash
is `[32]byte`) → noscan spans
- **HashedNode pool**: ZERO Go pointers → noscan spans
- **StemNode pool**: retains `Values [][]byte` (matching existing
format)

The serialization format is unchanged — flat InternalNode
`[type][leftHash][rightHash]` = 65 bytes.

## Benchmark: Apple M4 Pro (`--benchtime=10s --count=3`, on top of #34021)

| Metric | Baseline | Arena | Delta |
|--------|----------|-------|-------|
| Approve (Mgas/s) | 374 | 382 | **+2.1%** |
| BalanceOf (Mgas/s) | 885 | 901 | **+1.8%** |
| Approve allocs/op | 775K | **607K** | **-21.7%** |
| BalanceOf allocs/op | 265K | **228K** | **-14.0%** |

## Benchmark: AMD EPYC 48-core (50GB state, execution-specs ERC-20, on top of #34021 + #34032)

| Benchmark | Baseline | Arena | Delta |
|-----------|----------|-------|-------|
| erc20_approve (write) | 22.4 Mgas/s | **27.0 Mgas/s** | **+20.5%** |
| mixed_sload_sstore | 62.9 Mgas/s | **97.3 Mgas/s** | **+54.7%** |
| erc20_balanceof (read) | 180.8 Mgas/s | 167.6 Mgas/s | -7.3% (cold cache variance) |

The arena benefit scales with heap size — the EPYC (larger heap, more GC
pressure) shows much larger gains than the M4 Pro (efficient unified
memory). The mixed workload baseline was unstable (62.9 vs 16.3 Mgas/s
between runs due to GC-induced throughput collapse); the arena
eliminates this entirely (95-97 Mgas/s, stable).

## Dependencies

Benchmarked with #34021 (H01 N+1 fix) + #34032 (R14 parallel hashing).
No code dependency — applies independently to master.

All test suites pass (`trie/bintrie` with `-race`, `core/state`,
`triedb/pathdb`, `cmd/geth`).

---------

Co-authored-by: Guillaume Ballet <3272758+gballet@users.noreply.github.com>
This commit is contained in:
CPerezz 2026-04-20 14:08:30 +02:00 committed by GitHub
parent 29e0a6f404
commit b6d415c88d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 1689 additions and 2092 deletions

View file

@ -16,140 +16,32 @@
package bintrie package bintrie
import ( import "github.com/ethereum/go-ethereum/common"
"errors"
"github.com/ethereum/go-ethereum/common"
)
type (
NodeFlushFn func([]byte, BinaryNode)
NodeResolverFn func([]byte, common.Hash) ([]byte, error)
)
// zero is the zero value for a 32-byte array. // zero is the zero value for a 32-byte array.
var zero [32]byte var zero [32]byte
const ( const (
StemNodeWidth = 256 // Number of child per leaf node StemNodeWidth = 256 // Number of children per leaf node
StemSize = 31 // Number of bytes to travel before reaching a group of leaves StemSize = 31 // Number of bytes to travel before reaching a group of leaves
NodeTypeBytes = 1 // Size of node type prefix in serialization NodeTypeBytes = 1 // Size of node type prefix in serialization
HashSize = 32 // Size of a hash in bytes HashSize = 32 // Size of a hash in bytes
BitmapSize = 32 // Size of the bitmap in a stem node StemBitmapSize = 32 // Size of the bitmap in a stem node (256 values = 32 bytes)
) )
const ( const (
nodeTypeStem = iota + 1 // Stem node, contains a stem and a bitmap of values nodeTypeStem = iota + 1
nodeTypeInternal nodeTypeInternal
) )
// BinaryNode is an interface for a binary trie node. // DeserializeAndHash deserializes a node from bytes and returns its hash.
type BinaryNode interface { // This is a convenience function for external callers that need to compute
Get([]byte, NodeResolverFn) ([]byte, error) // the hash of a serialized node without maintaining a nodeStore.
Insert([]byte, []byte, NodeResolverFn, int) (BinaryNode, error) func DeserializeAndHash(blob []byte, depth int) (common.Hash, error) {
Copy() BinaryNode s := newNodeStore()
Hash() common.Hash ref, err := s.deserializeNode(blob, depth)
GetValuesAtStem([]byte, NodeResolverFn) ([][]byte, error) if err != nil {
InsertValuesAtStem([]byte, [][]byte, NodeResolverFn, int) (BinaryNode, error) return common.Hash{}, err
CollectNodes([]byte, NodeFlushFn) error
toDot(parent, path string) string
GetHeight() int
}
// SerializeNode serializes a binary trie node into a byte slice.
func SerializeNode(node BinaryNode) []byte {
switch n := (node).(type) {
case *InternalNode:
// InternalNode: 1 byte type + 32 bytes left hash + 32 bytes right hash
var serialized [NodeTypeBytes + HashSize + HashSize]byte
serialized[0] = nodeTypeInternal
copy(serialized[1:33], n.left.Hash().Bytes())
copy(serialized[33:65], n.right.Hash().Bytes())
return serialized[:]
case *StemNode:
// StemNode: 1 byte type + 31 bytes stem + 32 bytes bitmap + 256*32 bytes values
var serialized [NodeTypeBytes + StemSize + BitmapSize + StemNodeWidth*HashSize]byte
serialized[0] = nodeTypeStem
copy(serialized[NodeTypeBytes:NodeTypeBytes+StemSize], n.Stem)
bitmap := serialized[NodeTypeBytes+StemSize : NodeTypeBytes+StemSize+BitmapSize]
offset := NodeTypeBytes + StemSize + BitmapSize
for i, v := range n.Values {
if v != nil {
bitmap[i/8] |= 1 << (7 - (i % 8))
copy(serialized[offset:offset+HashSize], v)
offset += HashSize
}
}
// Only return the actual data, not the entire array
return serialized[:offset]
default:
panic("invalid node type")
} }
} return s.computeHash(ref), nil
var invalidSerializedLength = errors.New("invalid serialized node length")
// DeserializeNode deserializes a binary trie node from a byte slice. The
// hash will be recomputed from the deserialized data.
func DeserializeNode(serialized []byte, depth int) (BinaryNode, error) {
return deserializeNode(serialized, depth, common.Hash{}, true, true)
}
// DeserializeNodeWithHash deserializes a binary trie node from a byte slice, using the provided hash.
func DeserializeNodeWithHash(serialized []byte, depth int, hn common.Hash) (BinaryNode, error) {
return deserializeNode(serialized, depth, hn, false, false)
}
func deserializeNode(serialized []byte, depth int, hn common.Hash, mustRecompute, dirty bool) (BinaryNode, error) {
if len(serialized) == 0 {
return Empty{}, nil
}
switch serialized[0] {
case nodeTypeInternal:
if len(serialized) != 65 {
return nil, invalidSerializedLength
}
return &InternalNode{
depth: depth,
left: HashedNode(common.BytesToHash(serialized[1:33])),
right: HashedNode(common.BytesToHash(serialized[33:65])),
hash: hn,
mustRecompute: mustRecompute,
dirty: dirty,
}, nil
case nodeTypeStem:
if len(serialized) < 64 {
return nil, invalidSerializedLength
}
var values [StemNodeWidth][]byte
bitmap := serialized[NodeTypeBytes+StemSize : NodeTypeBytes+StemSize+BitmapSize]
offset := NodeTypeBytes + StemSize + BitmapSize
for i := range StemNodeWidth {
if bitmap[i/8]>>(7-(i%8))&1 == 1 {
if len(serialized) < offset+HashSize {
return nil, invalidSerializedLength
}
values[i] = serialized[offset : offset+HashSize]
offset += HashSize
}
}
return &StemNode{
Stem: serialized[NodeTypeBytes : NodeTypeBytes+StemSize],
Values: values[:],
depth: depth,
hash: hn,
mustRecompute: mustRecompute,
dirty: dirty,
}, nil
default:
return nil, errors.New("invalid node type")
}
}
// ToDot converts the binary trie to a DOT language representation. Useful for debugging.
func ToDot(root BinaryNode) string {
return root.toDot("", "")
} }

View file

@ -23,79 +23,100 @@ import (
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
) )
// TestSerializeDeserializeInternalNode tests serialization and deserialization of InternalNode // TestSerializeDeserializeInternalNode tests flat 65-byte serialization and
// deserialization of InternalNode through nodeStore.
func TestSerializeDeserializeInternalNode(t *testing.T) { func TestSerializeDeserializeInternalNode(t *testing.T) {
// Create an internal node with two hashed children
leftHash := common.HexToHash("0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef") leftHash := common.HexToHash("0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef")
rightHash := common.HexToHash("0xfedcba0987654321fedcba0987654321fedcba0987654321fedcba0987654321") rightHash := common.HexToHash("0xfedcba0987654321fedcba0987654321fedcba0987654321fedcba0987654321")
node := &InternalNode{ s := newNodeStore()
depth: 5, leftRef := s.newHashedRef(leftHash)
left: HashedNode(leftHash), rightRef := s.newHashedRef(rightHash)
right: HashedNode(rightHash),
}
// Serialize the node rootRef := s.newInternalRef(0)
serialized := SerializeNode(node) rootNode := s.getInternal(rootRef.Index())
rootNode.left = leftRef
rootNode.right = rightRef
s.root = rootRef
// Check the serialized format // Serialize the node — flat 65-byte format
serialized := s.serializeNode(rootRef)
// Check the serialized format: [type(1)][leftHash(32)][rightHash(32)]
if serialized[0] != nodeTypeInternal { if serialized[0] != nodeTypeInternal {
t.Errorf("Expected type byte to be %d, got %d", nodeTypeInternal, serialized[0]) t.Errorf("Expected type byte to be %d, got %d", nodeTypeInternal, serialized[0])
} }
if len(serialized) != 65 { expectedLen := NodeTypeBytes + 2*HashSize // 1 + 64 = 65
t.Errorf("Expected serialized length to be 65, got %d", len(serialized)) if len(serialized) != expectedLen {
t.Errorf("Expected serialized length to be %d, got %d", expectedLen, len(serialized))
} }
// Deserialize the node // Check that left and right hashes are embedded directly
deserialized, err := DeserializeNode(serialized, 5) if !bytes.Equal(serialized[NodeTypeBytes:NodeTypeBytes+HashSize], leftHash[:]) {
t.Error("Left hash not found at expected position")
}
if !bytes.Equal(serialized[NodeTypeBytes+HashSize:], rightHash[:]) {
t.Error("Right hash not found at expected position")
}
// Deserialize into a new store
ds := newNodeStore()
deserialized, err := ds.deserializeNode(serialized, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to deserialize node: %v", err) t.Fatalf("Failed to deserialize node: %v", err)
} }
// Check that it's an internal node // Root should be an InternalNode
internalNode, ok := deserialized.(*InternalNode) if deserialized.Kind() != kindInternal {
if !ok { t.Fatalf("Expected kindInternal, got kind %d", deserialized.Kind())
t.Fatalf("Expected InternalNode, got %T", deserialized)
} }
// Check the depth internalNode := ds.getInternal(deserialized.Index())
if internalNode.depth != 5 { if internalNode.depth != 0 {
t.Errorf("Expected depth 5, got %d", internalNode.depth) t.Errorf("Expected depth 0, got %d", internalNode.depth)
} }
// Check the left and right hashes // Left child should be a HashedNode with the correct hash
if internalNode.left.Hash() != leftHash { if internalNode.left.Kind() != kindHashed {
t.Errorf("Left hash mismatch: expected %x, got %x", leftHash, internalNode.left.Hash()) t.Fatalf("Expected left child to be kindHashed, got %d", internalNode.left.Kind())
}
if ds.computeHash(internalNode.left) != leftHash {
t.Errorf("Left hash mismatch: expected %x, got %x", leftHash, ds.computeHash(internalNode.left))
} }
if internalNode.right.Hash() != rightHash { // Right child should be a HashedNode with the correct hash
t.Errorf("Right hash mismatch: expected %x, got %x", rightHash, internalNode.right.Hash()) if internalNode.right.Kind() != kindHashed {
t.Fatalf("Expected right child to be kindHashed, got %d", internalNode.right.Kind())
}
if ds.computeHash(internalNode.right) != rightHash {
t.Errorf("Right hash mismatch: expected %x, got %x", rightHash, ds.computeHash(internalNode.right))
} }
} }
// TestSerializeDeserializeStemNode tests serialization and deserialization of StemNode // TestSerializeDeserializeStemNode tests serialization and deserialization of StemNode through nodeStore.
func TestSerializeDeserializeStemNode(t *testing.T) { func TestSerializeDeserializeStemNode(t *testing.T) {
// Create a stem node with some values
stem := make([]byte, StemSize) stem := make([]byte, StemSize)
for i := range stem { for i := range stem {
stem[i] = byte(i) stem[i] = byte(i)
} }
var values [StemNodeWidth][]byte var values [StemNodeWidth][]byte
// Add some values at different indices
values[0] = common.HexToHash("0x0101010101010101010101010101010101010101010101010101010101010101").Bytes() values[0] = common.HexToHash("0x0101010101010101010101010101010101010101010101010101010101010101").Bytes()
values[10] = common.HexToHash("0x0202020202020202020202020202020202020202020202020202020202020202").Bytes() values[10] = common.HexToHash("0x0202020202020202020202020202020202020202020202020202020202020202").Bytes()
values[255] = common.HexToHash("0x0303030303030303030303030303030303030303030303030303030303030303").Bytes() values[255] = common.HexToHash("0x0303030303030303030303030303030303030303030303030303030303030303").Bytes()
node := &StemNode{ s := newNodeStore()
Stem: stem, ref := s.newStemRef(stem, 10)
Values: values[:], sn := s.getStem(ref.Index())
depth: 10, for i, v := range values {
if v != nil {
sn.setValue(byte(i), v)
}
} }
// Serialize the node // Serialize the node
serialized := SerializeNode(node) serialized := s.serializeNode(ref)
// Check the serialized format // Check the serialized format
if serialized[0] != nodeTypeStem { if serialized[0] != nodeTypeStem {
@ -107,31 +128,32 @@ func TestSerializeDeserializeStemNode(t *testing.T) {
t.Errorf("Stem mismatch in serialized data") t.Errorf("Stem mismatch in serialized data")
} }
// Deserialize the node // Deserialize into a new store
deserialized, err := DeserializeNode(serialized, 10) ds := newNodeStore()
deserializedRef, err := ds.deserializeNode(serialized, 10)
if err != nil { if err != nil {
t.Fatalf("Failed to deserialize node: %v", err) t.Fatalf("Failed to deserialize node: %v", err)
} }
// Check that it's a stem node if deserializedRef.Kind() != kindStem {
stemNode, ok := deserialized.(*StemNode) t.Fatalf("Expected kindStem, got kind %d", deserializedRef.Kind())
if !ok {
t.Fatalf("Expected StemNode, got %T", deserialized)
} }
stemNode := ds.getStem(deserializedRef.Index())
// Check the stem // Check the stem
if !bytes.Equal(stemNode.Stem, stem) { if !bytes.Equal(stemNode.Stem[:], stem) {
t.Errorf("Stem mismatch after deserialization") t.Errorf("Stem mismatch after deserialization")
} }
// Check the values // Check the values
if !bytes.Equal(stemNode.Values[0], values[0]) { if !bytes.Equal(stemNode.getValue(0), values[0]) {
t.Errorf("Value at index 0 mismatch") t.Errorf("Value at index 0 mismatch")
} }
if !bytes.Equal(stemNode.Values[10], values[10]) { if !bytes.Equal(stemNode.getValue(10), values[10]) {
t.Errorf("Value at index 10 mismatch") t.Errorf("Value at index 10 mismatch")
} }
if !bytes.Equal(stemNode.Values[255], values[255]) { if !bytes.Equal(stemNode.getValue(255), values[255]) {
t.Errorf("Value at index 255 mismatch") t.Errorf("Value at index 255 mismatch")
} }
@ -140,43 +162,43 @@ func TestSerializeDeserializeStemNode(t *testing.T) {
if i == 0 || i == 10 || i == 255 { if i == 0 || i == 10 || i == 255 {
continue continue
} }
if stemNode.Values[i] != nil { if stemNode.hasValue(byte(i)) {
t.Errorf("Expected nil value at index %d, got %x", i, stemNode.Values[i]) t.Errorf("Expected no value at index %d, got %x", i, stemNode.getValue(byte(i)))
} }
} }
} }
// TestDeserializeEmptyNode tests deserialization of empty node // TestDeserializeEmptyNode tests deserialization of empty node.
func TestDeserializeEmptyNode(t *testing.T) { func TestDeserializeEmptyNode(t *testing.T) {
// Empty byte slice should deserialize to Empty node s := newNodeStore()
deserialized, err := DeserializeNode([]byte{}, 0) deserialized, err := s.deserializeNode([]byte{}, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to deserialize empty node: %v", err) t.Fatalf("Failed to deserialize empty node: %v", err)
} }
_, ok := deserialized.(Empty) if !deserialized.IsEmpty() {
if !ok { t.Fatalf("Expected emptyRef, got kind %d", deserialized.Kind())
t.Fatalf("Expected Empty node, got %T", deserialized)
} }
} }
// TestDeserializeInvalidType tests deserialization with invalid type byte // TestDeserializeInvalidType tests deserialization with invalid type byte.
func TestDeserializeInvalidType(t *testing.T) { func TestDeserializeInvalidType(t *testing.T) {
// Create invalid serialized data with unknown type byte s := newNodeStore()
invalidData := []byte{99, 0, 0, 0} // Type byte 99 is invalid invalidData := []byte{99, 0, 0, 0} // Type byte 99 is invalid
_, err := DeserializeNode(invalidData, 0) _, err := s.deserializeNode(invalidData, 0)
if err == nil { if err == nil {
t.Fatal("Expected error for invalid type byte, got nil") t.Fatal("Expected error for invalid type byte, got nil")
} }
} }
// TestDeserializeInvalidLength tests deserialization with invalid data length // TestDeserializeInvalidLength tests deserialization with invalid data length.
func TestDeserializeInvalidLength(t *testing.T) { func TestDeserializeInvalidLength(t *testing.T) {
// InternalNode with type byte 1 but wrong length s := newNodeStore()
invalidData := []byte{nodeTypeInternal, 0, 0} // Too short for internal node // InternalNode with valid type byte but wrong length (needs exactly 65 bytes)
invalidData := []byte{nodeTypeInternal, 0, 0, 0}
_, err := DeserializeNode(invalidData, 0) _, err := s.deserializeNode(invalidData, 0)
if err == nil { if err == nil {
t.Fatal("Expected error for invalid data length, got nil") t.Fatal("Expected error for invalid data length, got nil")
} }
@ -186,7 +208,7 @@ func TestDeserializeInvalidLength(t *testing.T) {
} }
} }
// TestKeyToPath tests the keyToPath function // TestKeyToPath tests the keyToPath function.
func TestKeyToPath(t *testing.T) { func TestKeyToPath(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@ -218,14 +240,14 @@ func TestKeyToPath(t *testing.T) {
}, },
{ {
name: "max valid depth", name: "max valid depth",
depth: StemSize * 8, depth: StemSize*8 - 1,
key: make([]byte, HashSize), key: make([]byte, HashSize),
expected: make([]byte, StemSize*8+1), expected: make([]byte, StemSize*8),
wantErr: false, wantErr: false,
}, },
{ {
name: "depth too large", name: "depth too large",
depth: StemSize*8 + 1, depth: StemSize * 8,
key: make([]byte, HashSize), key: make([]byte, HashSize),
wantErr: true, wantErr: true,
}, },

View file

@ -1,76 +0,0 @@
// Copyright 2025 go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package bintrie
import (
"slices"
"github.com/ethereum/go-ethereum/common"
)
type Empty struct{}
func (e Empty) Get(_ []byte, _ NodeResolverFn) ([]byte, error) {
return nil, nil
}
func (e Empty) Insert(key []byte, value []byte, _ NodeResolverFn, depth int) (BinaryNode, error) {
var values [256][]byte
values[key[31]] = value
return &StemNode{
Stem: slices.Clone(key[:31]),
Values: values[:],
depth: depth,
mustRecompute: true,
dirty: true,
}, nil
}
func (e Empty) Copy() BinaryNode {
return Empty{}
}
func (e Empty) Hash() common.Hash {
return common.Hash{}
}
func (e Empty) GetValuesAtStem(_ []byte, _ NodeResolverFn) ([][]byte, error) {
var values [256][]byte
return values[:], nil
}
func (e Empty) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolverFn, depth int) (BinaryNode, error) {
return &StemNode{
Stem: slices.Clone(key[:31]),
Values: values,
depth: depth,
mustRecompute: true,
dirty: true,
}, nil
}
func (e Empty) CollectNodes(_ []byte, _ NodeFlushFn) error {
return nil
}
func (e Empty) toDot(parent string, path string) string {
return ""
}
func (e Empty) GetHeight() int {
return 0
}

View file

@ -1,264 +0,0 @@
// Copyright 2025 go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package bintrie
import (
"bytes"
"testing"
"github.com/ethereum/go-ethereum/common"
)
// TestEmptyGet tests the Get method
func TestEmptyGet(t *testing.T) {
node := Empty{}
key := make([]byte, 32)
value, err := node.Get(key, nil)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if value != nil {
t.Errorf("Expected nil value from empty node, got %x", value)
}
}
// TestEmptyInsert tests the Insert method
func TestEmptyInsert(t *testing.T) {
node := Empty{}
key := make([]byte, 32)
key[0] = 0x12
key[31] = 0x34
value := common.HexToHash("0xabcd").Bytes()
newNode, err := node.Insert(key, value, nil, 0)
if err != nil {
t.Fatalf("Failed to insert: %v", err)
}
// Should create a StemNode
stemNode, ok := newNode.(*StemNode)
if !ok {
t.Fatalf("Expected StemNode, got %T", newNode)
}
// Check the stem (first 31 bytes of key)
if !bytes.Equal(stemNode.Stem, key[:31]) {
t.Errorf("Stem mismatch: expected %x, got %x", key[:31], stemNode.Stem)
}
// Check the value at the correct index (last byte of key)
if !bytes.Equal(stemNode.Values[key[31]], value) {
t.Errorf("Value mismatch at index %d: expected %x, got %x", key[31], value, stemNode.Values[key[31]])
}
// Check that other values are nil
for i := 0; i < 256; i++ {
if i != int(key[31]) && stemNode.Values[i] != nil {
t.Errorf("Expected nil value at index %d, got %x", i, stemNode.Values[i])
}
}
}
// TestEmptyCopy tests the Copy method
func TestEmptyCopy(t *testing.T) {
node := Empty{}
copied := node.Copy()
copiedEmpty, ok := copied.(Empty)
if !ok {
t.Fatalf("Expected Empty, got %T", copied)
}
// Both should be empty
if node != copiedEmpty {
// Empty is a zero-value struct, so copies should be equal
t.Errorf("Empty nodes should be equal")
}
}
// TestEmptyHash tests the Hash method
func TestEmptyHash(t *testing.T) {
node := Empty{}
hash := node.Hash()
// Empty node should have zero hash
if hash != (common.Hash{}) {
t.Errorf("Expected zero hash for empty node, got %x", hash)
}
}
// TestEmptyGetValuesAtStem tests the GetValuesAtStem method
func TestEmptyGetValuesAtStem(t *testing.T) {
node := Empty{}
stem := make([]byte, 31)
values, err := node.GetValuesAtStem(stem, nil)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
// Should return an array of 256 nil values
if len(values) != 256 {
t.Errorf("Expected 256 values, got %d", len(values))
}
for i, v := range values {
if v != nil {
t.Errorf("Expected nil value at index %d, got %x", i, v)
}
}
}
// TestEmptyInsertValuesAtStem tests the InsertValuesAtStem method
func TestEmptyInsertValuesAtStem(t *testing.T) {
node := Empty{}
stem := make([]byte, 31)
stem[0] = 0x42
var values [256][]byte
values[0] = common.HexToHash("0x0101").Bytes()
values[10] = common.HexToHash("0x0202").Bytes()
values[255] = common.HexToHash("0x0303").Bytes()
newNode, err := node.InsertValuesAtStem(stem, values[:], nil, 5)
if err != nil {
t.Fatalf("Failed to insert values: %v", err)
}
// Should create a StemNode
stemNode, ok := newNode.(*StemNode)
if !ok {
t.Fatalf("Expected StemNode, got %T", newNode)
}
// Check the stem
if !bytes.Equal(stemNode.Stem, stem) {
t.Errorf("Stem mismatch: expected %x, got %x", stem, stemNode.Stem)
}
// Check the depth
if stemNode.depth != 5 {
t.Errorf("Depth mismatch: expected 5, got %d", stemNode.depth)
}
// Check the values
if !bytes.Equal(stemNode.Values[0], values[0]) {
t.Error("Value at index 0 mismatch")
}
if !bytes.Equal(stemNode.Values[10], values[10]) {
t.Error("Value at index 10 mismatch")
}
if !bytes.Equal(stemNode.Values[255], values[255]) {
t.Error("Value at index 255 mismatch")
}
// Check that values is the same slice (not a copy)
if &stemNode.Values[0] != &values[0] {
t.Error("Expected values to be the same slice reference")
}
}
// TestEmptyCollectNodes tests the CollectNodes method
func TestEmptyCollectNodes(t *testing.T) {
node := Empty{}
var collected []BinaryNode
flushFn := func(path []byte, n BinaryNode) {
collected = append(collected, n)
}
err := node.CollectNodes([]byte{0, 1, 0}, flushFn)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
// Should not collect anything for empty node
if len(collected) != 0 {
t.Errorf("Expected no collected nodes for empty, got %d", len(collected))
}
}
// TestEmptyToDot tests the toDot method
func TestEmptyToDot(t *testing.T) {
node := Empty{}
dot := node.toDot("parent", "010")
// Should return empty string for empty node
if dot != "" {
t.Errorf("Expected empty string for empty node toDot, got %s", dot)
}
}
// TestEmptyGetHeight tests the GetHeight method
func TestEmptyGetHeight(t *testing.T) {
node := Empty{}
height := node.GetHeight()
// Empty node should have height 0
if height != 0 {
t.Errorf("Expected height 0 for empty node, got %d", height)
}
}
// TestEmptyInsertMarksDirty verifies that a StemNode produced by Empty.Insert
// is marked dirty. Without this, CollectNodes would skip the freshly created
// stem and its blob would never reach disk, producing "missing trie node"
// errors on subsequent reads.
func TestEmptyInsertMarksDirty(t *testing.T) {
key := make([]byte, 32)
key[0] = 0xaa
val := make([]byte, 32)
val[0] = 0xbb
n, err := Empty{}.Insert(key, val, nil, 0)
if err != nil {
t.Fatalf("Insert: %v", err)
}
sn, ok := n.(*StemNode)
if !ok {
t.Fatalf("expected *StemNode, got %T", n)
}
if !sn.dirty {
t.Fatalf("stem produced by Empty.Insert must have dirty=true")
}
}
// TestEmptyInsertValuesAtStemMarksDirty is the analogous guard for the
// bulk-insert entry point. Fresh stems created here must be dirty.
func TestEmptyInsertValuesAtStemMarksDirty(t *testing.T) {
key := make([]byte, 32)
key[0] = 0xcc
values := make([][]byte, 256)
values[0] = make([]byte, 32)
n, err := Empty{}.InsertValuesAtStem(key, values, nil, 3)
if err != nil {
t.Fatalf("InsertValuesAtStem: %v", err)
}
sn, ok := n.(*StemNode)
if !ok {
t.Fatalf("expected *StemNode, got %T", n)
}
if !sn.dirty {
t.Fatalf("stem produced by Empty.InsertValuesAtStem must have dirty=true")
}
}

View file

@ -16,75 +16,10 @@
package bintrie package bintrie
import ( import "github.com/ethereum/go-ethereum/common"
"errors"
"fmt"
"github.com/ethereum/go-ethereum/common"
)
// HashedNode is an unresolved node — only its hash is known.
type HashedNode common.Hash type HashedNode common.Hash
func (h HashedNode) Get(_ []byte, _ NodeResolverFn) ([]byte, error) { // Hash returns the node's hash.
panic("not implemented") // TODO: Implement func (h HashedNode) Hash() common.Hash { return common.Hash(h) }
}
func (h HashedNode) Insert(key []byte, value []byte, resolver NodeResolverFn, depth int) (BinaryNode, error) {
return nil, errors.New("insert not implemented for hashed node")
}
func (h HashedNode) Copy() BinaryNode {
nh := common.Hash(h)
return HashedNode(nh)
}
func (h HashedNode) Hash() common.Hash {
return common.Hash(h)
}
func (h HashedNode) GetValuesAtStem(_ []byte, _ NodeResolverFn) ([][]byte, error) {
return nil, errors.New("attempted to get values from an unresolved node")
}
func (h HashedNode) InsertValuesAtStem(stem []byte, values [][]byte, resolver NodeResolverFn, depth int) (BinaryNode, error) {
// Step 1: Generate the path for this node's position in the tree
path, err := keyToPath(depth, stem)
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem path generation error: %w", err)
}
if resolver == nil {
return nil, errors.New("InsertValuesAtStem resolve error: resolver is nil")
}
// Step 2: Resolve the hashed node to get the actual node data
data, err := resolver(path, common.Hash(h))
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
}
// Step 3: Deserialize the resolved data into a concrete node
node, err := DeserializeNodeWithHash(data, depth, common.Hash(h))
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
}
// Step 4: Call InsertValuesAtStem on the resolved concrete node
return node.InsertValuesAtStem(stem, values, resolver, depth)
}
func (h HashedNode) toDot(parent string, path string) string {
me := fmt.Sprintf("hash%s", path)
ret := fmt.Sprintf("%s [label=\"%x\"]\n", me, h)
ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me)
return ret
}
func (h HashedNode) CollectNodes([]byte, NodeFlushFn) error {
// HashedNodes are already persisted in the database and don't need to be collected.
return nil
}
func (h HashedNode) GetHeight() int {
panic("tried to get the height of a hashed node, this is a bug")
}

View file

@ -18,180 +18,137 @@ package bintrie
import ( import (
"bytes" "bytes"
"errors"
"testing" "testing"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
) )
// TestHashedNodeHash tests the Hash method // TestHashedNodeHash tests the Hash method via nodeStore.
func TestHashedNodeHash(t *testing.T) { func TestHashedNodeHash(t *testing.T) {
hash := common.HexToHash("0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef") hash := common.HexToHash("0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef")
node := HashedNode(hash) s := newNodeStore()
ref := s.newHashedRef(hash)
// Hash should return the stored hash if s.computeHash(ref) != hash {
if node.Hash() != hash { t.Errorf("Hash mismatch: expected %x, got %x", hash, s.computeHash(ref))
t.Errorf("Hash mismatch: expected %x, got %x", hash, node.Hash())
} }
} }
// TestHashedNodeCopy tests the Copy method // TestHashedNodeCopy tests the Copy method via nodeStore.
func TestHashedNodeCopy(t *testing.T) { func TestHashedNodeCopy(t *testing.T) {
hash := common.HexToHash("0xabcdef") hash := common.HexToHash("0xabcdef")
node := HashedNode(hash) s := newNodeStore()
ref := s.newHashedRef(hash)
s.root = ref
copied := node.Copy() ns := s.Copy()
copiedHash, ok := copied.(HashedNode) copiedHash := ns.computeHash(ns.root)
if !ok {
t.Fatalf("Expected HashedNode, got %T", copied)
}
// Hash should be the same if copiedHash != hash {
if common.Hash(copiedHash) != hash {
t.Errorf("Hash mismatch after copy: expected %x, got %x", hash, copiedHash) t.Errorf("Hash mismatch after copy: expected %x, got %x", hash, copiedHash)
} }
// But should be a different object
if &node == &copiedHash {
t.Error("Copy returned same object reference")
}
} }
// TestHashedNodeInsert tests that Insert returns an error // TestHashedNodeInsertValuesAtStem tests InsertValuesAtStem resolution via nodeStore.
func TestHashedNodeInsert(t *testing.T) {
node := HashedNode(common.HexToHash("0x1234"))
key := make([]byte, HashSize)
value := make([]byte, HashSize)
_, err := node.Insert(key, value, nil, 0)
if err == nil {
t.Fatal("Expected error for Insert on HashedNode")
}
if err.Error() != "insert not implemented for hashed node" {
t.Errorf("Unexpected error message: %v", err)
}
}
// TestHashedNodeGetValuesAtStem tests that GetValuesAtStem returns an error
func TestHashedNodeGetValuesAtStem(t *testing.T) {
node := HashedNode(common.HexToHash("0x1234"))
stem := make([]byte, StemSize)
_, err := node.GetValuesAtStem(stem, nil)
if err == nil {
t.Fatal("Expected error for GetValuesAtStem on HashedNode")
}
if err.Error() != "attempted to get values from an unresolved node" {
t.Errorf("Unexpected error message: %v", err)
}
}
// TestHashedNodeInsertValuesAtStem tests that InsertValuesAtStem returns an error
func TestHashedNodeInsertValuesAtStem(t *testing.T) { func TestHashedNodeInsertValuesAtStem(t *testing.T) {
node := HashedNode(common.HexToHash("0x1234")) // Test 1: nil resolver should return an error
s := newNodeStore()
hashedRef := s.newHashedRef(common.HexToHash("0x1234"))
s.root = hashedRef
stem := make([]byte, StemSize) stem := make([]byte, StemSize)
values := make([][]byte, StemNodeWidth) values := make([][]byte, StemNodeWidth)
// Test 1: nil resolver should return an error err := s.InsertValuesAtStem(stem, values, nil)
_, err := node.InsertValuesAtStem(stem, values, nil, 0)
if err == nil { if err == nil {
t.Fatal("Expected error for InsertValuesAtStem on HashedNode with nil resolver") t.Fatal("Expected error for InsertValuesAtStem with nil resolver")
}
if err.Error() != "InsertValuesAtStem resolve error: resolver is nil" {
t.Errorf("Unexpected error message: %v", err)
} }
// Test 2: mock resolver returning invalid data should return deserialization error // Test 2: mock resolver returning invalid data should return deserialization error
mockResolver := func(path []byte, hash common.Hash) ([]byte, error) { mockResolver := func(path []byte, hash common.Hash) ([]byte, error) {
// Return invalid/nonsense data that cannot be deserialized
return []byte{0xff, 0xff, 0xff}, nil return []byte{0xff, 0xff, 0xff}, nil
} }
_, err = node.InsertValuesAtStem(stem, values, mockResolver, 0) s2 := newNodeStore()
if err == nil { hashedRef2 := s2.newHashedRef(common.HexToHash("0x1234"))
t.Fatal("Expected error for InsertValuesAtStem on HashedNode with invalid resolver data") s2.root = hashedRef2
}
expectedPrefix := "InsertValuesAtStem node deserialization error:" err = s2.InsertValuesAtStem(stem, values, mockResolver)
if len(err.Error()) < len(expectedPrefix) || err.Error()[:len(expectedPrefix)] != expectedPrefix { if err == nil {
t.Errorf("Expected deserialization error, got: %v", err) t.Fatal("Expected error for InsertValuesAtStem with invalid resolver data")
} }
// Test 3: mock resolver returning valid serialized node should succeed // Test 3: mock resolver returning valid serialized node should succeed
stem = make([]byte, StemSize) stem = make([]byte, StemSize)
stem[0] = 0xaa stem[0] = 0xaa
var originalValues [StemNodeWidth][]byte originalValues := make([][]byte, StemNodeWidth)
originalValues[0] = common.HexToHash("0x1111111111111111111111111111111111111111111111111111111111111111").Bytes() originalValues[0] = common.HexToHash("0x1111111111111111111111111111111111111111111111111111111111111111").Bytes()
originalValues[1] = common.HexToHash("0x2222222222222222222222222222222222222222222222222222222222222222").Bytes() originalValues[1] = common.HexToHash("0x2222222222222222222222222222222222222222222222222222222222222222").Bytes()
originalNode := &StemNode{ // Build the serialized node
Stem: stem, rs := newNodeStore()
Values: originalValues[:], ref := rs.newStemRef(stem, 0)
depth: 0, sn := rs.getStem(ref.Index())
for i, v := range originalValues {
if v != nil {
sn.setValue(byte(i), v)
}
} }
serialized := rs.serializeNode(ref)
// Serialize the node
serialized := SerializeNode(originalNode)
// Create a mock resolver that returns the serialized node
validResolver := func(path []byte, hash common.Hash) ([]byte, error) { validResolver := func(path []byte, hash common.Hash) ([]byte, error) {
return serialized, nil return serialized, nil
} }
var newValues [StemNodeWidth][]byte s3 := newNodeStore()
hashedRef3 := s3.newHashedRef(common.HexToHash("0x1234"))
s3.root = hashedRef3
newValues := make([][]byte, StemNodeWidth)
newValues[2] = common.HexToHash("0x3333333333333333333333333333333333333333333333333333333333333333").Bytes() newValues[2] = common.HexToHash("0x3333333333333333333333333333333333333333333333333333333333333333").Bytes()
resolvedNode, err := node.InsertValuesAtStem(stem, newValues[:], validResolver, 0) err = s3.InsertValuesAtStem(stem, newValues, validResolver)
if err != nil { if err != nil {
t.Fatalf("Expected successful resolution and insertion, got error: %v", err) t.Fatalf("Expected successful resolution and insertion, got error: %v", err)
} }
resultStem, ok := resolvedNode.(*StemNode) // Verify original values are preserved
if !ok { retrieved, err := s3.GetValuesAtStem(stem, nil)
t.Fatalf("Expected resolved node to be *StemNode, got %T", resolvedNode) if err != nil {
t.Fatal(err)
} }
if !bytes.Equal(retrieved[0], originalValues[0]) {
if !bytes.Equal(resultStem.Stem, stem) { t.Errorf("Original value at index 0 not preserved")
t.Errorf("Stem mismatch: expected %x, got %x", stem, resultStem.Stem)
} }
if !bytes.Equal(retrieved[1], originalValues[1]) {
// Verify the original values are preserved t.Errorf("Original value at index 1 not preserved")
if !bytes.Equal(resultStem.Values[0], originalValues[0]) {
t.Errorf("Original value at index 0 not preserved: expected %x, got %x", originalValues[0], resultStem.Values[0])
} }
if !bytes.Equal(resultStem.Values[1], originalValues[1]) { if !bytes.Equal(retrieved[2], newValues[2]) {
t.Errorf("Original value at index 1 not preserved: expected %x, got %x", originalValues[1], resultStem.Values[1]) t.Errorf("New value at index 2 not inserted correctly")
}
// Verify the new value was inserted
if !bytes.Equal(resultStem.Values[2], newValues[2]) {
t.Errorf("New value at index 2 not inserted correctly: expected %x, got %x", newValues[2], resultStem.Values[2])
} }
} }
// TestHashedNodeToDot tests the toDot method for visualization // TestHashedNodeGetError tests that getting through an unresolved HashedNode root returns error.
func TestHashedNodeToDot(t *testing.T) { func TestHashedNodeGetError(t *testing.T) {
hash := common.HexToHash("0x1234") s := newNodeStore()
node := HashedNode(hash) // Create root as hashed, then try to resolve through InternalNode parent
rootRef := s.newInternalRef(0)
rootNode := s.getInternal(rootRef.Index())
hashedLeft := s.newHashedRef(common.HexToHash("0x1234"))
rootNode.left = hashedLeft
rootNode.right = emptyRef
s.root = rootRef
dot := node.toDot("parent", "010") key := make([]byte, 32) // goes left
key[31] = 5
// Should contain the hash value and parent connection resolver := func(path []byte, hash common.Hash) ([]byte, error) {
expectedHash := "hash010" return nil, errors.New("node not found")
if !contains(dot, expectedHash) {
t.Errorf("Expected dot output to contain %s", expectedHash)
} }
if !contains(dot, "parent -> hash010") { _, err := s.Get(key, resolver)
t.Error("Expected dot output to contain parent connection") if err == nil {
t.Fatal("Expected error when resolver fails")
} }
} }
// contains reports whether substr occurs anywhere within s.
//
// An empty substr (or an empty s) yields false, matching the guards of the
// previous implementation. The previous body only compared lengths, so it
// returned true for any two non-empty strings with len(s) >= len(substr) and
// every assertion built on it passed vacuously.
//
// The search is written by hand because this file does not import "strings".
func contains(s, substr string) bool {
	if s == "" || substr == "" {
		return false
	}
	// Slide a window of len(substr) bytes across s and compare at each offset.
	for i := 0; i+len(substr) <= len(s); i++ {
		if s[i:i+len(substr)] == substr {
			return true
		}
	}
	return false
}

View file

@ -17,35 +17,13 @@
package bintrie package bintrie
import ( import (
"crypto/sha256"
"errors" "errors"
"fmt"
"math/bits"
"runtime"
"sync"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
) )
// parallelDepth returns the tree depth below which Hash() spawns goroutines.
// The value is the bit length of the CPU count, capped at 8.
func parallelDepth() int {
	const maxDepth = 8
	depth := bits.Len(uint(runtime.NumCPU()))
	if depth > maxDepth {
		return maxDepth
	}
	return depth
}
// isDirty reports whether a BinaryNode child needs rehashing.
// Only *InternalNode and *StemNode track a mustRecompute flag; every other
// node kind, including a nil interface value, is reported as clean.
func isDirty(n BinaryNode) bool {
	if internal, ok := n.(*InternalNode); ok {
		return internal.mustRecompute
	}
	if stem, ok := n.(*StemNode); ok {
		return stem.mustRecompute
	}
	return false
}
func keyToPath(depth int, key []byte) ([]byte, error) { func keyToPath(depth int, key []byte) ([]byte, error) {
if depth > 31*8 { if depth >= 31*8 {
return nil, errors.New("node too deep") return nil, errors.New("node too deep")
} }
path := make([]byte, 0, depth+1) path := make([]byte, 0, depth+1)
@ -56,252 +34,12 @@ func keyToPath(depth int, key []byte) ([]byte, error) {
return path, nil return path, nil
} }
// InternalNode is a binary trie internal node. // Invariant: dirty=false implies mustRecompute=false. Every mutation that
// invalidates the cached hash MUST also mark the blob for re-flush.
type InternalNode struct { type InternalNode struct {
left, right BinaryNode left, right nodeRef
depth int depth uint8
mustRecompute bool // hash is stale (cleared by Hash)
mustRecompute bool // true if the hash needs to be recomputed dirty bool // on-disk blob is stale (cleared by CollectNodes)
dirty bool // true if the node's on-disk blob is stale (needs flush) hash common.Hash
hash common.Hash // cached hash when mustRecompute == false
}
// GetValuesAtStem retrieves the group of values located at the given stem key.
//
// The bit of stem at index bt.depth selects the left (0) or right (1) child.
// If that child is a HashedNode it is first loaded through resolver,
// deserialized at depth bt.depth+1 and stored back into the node, so later
// lookups hit the resolved child directly. The lookup then continues in the
// child.
//
// It returns (nil, nil) when the selected child is absent (nil), and an error
// when the node is too deep to index the stem, when a HashedNode must be
// resolved but resolver is nil, or when resolution/deserialization fails.
func (bt *InternalNode) GetValuesAtStem(stem []byte, resolver NodeResolverFn) ([][]byte, error) {
	// Callers pass a 31-byte stem (31*8 bits), so bit indices stop at
	// 31*8-1; a node at depth 31*8 would index stem[31] and panic.
	if bt.depth >= 31*8 {
		return nil, errors.New("node too deep")
	}
	// Work through a pointer to the selected child so that both directions
	// share one resolution path.
	child := &bt.left
	if stem[bt.depth/8]>>(7-(bt.depth%8))&1 == 1 {
		child = &bt.right
	}
	if hn, ok := (*child).(HashedNode); ok {
		if resolver == nil {
			return nil, errors.New("GetValuesAtStem resolve error: resolver is nil")
		}
		path, err := keyToPath(bt.depth, stem)
		if err != nil {
			return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
		}
		data, err := resolver(path, common.Hash(hn))
		if err != nil {
			return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
		}
		node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn))
		if err != nil {
			return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err)
		}
		*child = node
	}
	// The other methods of this type accept nil children; treat one as an
	// absent subtree instead of calling a method on a nil interface.
	if *child == nil {
		return nil, nil
	}
	return (*child).GetValuesAtStem(stem, resolver)
}
// Get retrieves the value for the given key.
// The first 31 bytes of key form the stem used to locate the value group;
// byte 31 selects the entry within that group. A missing group yields
// (nil, nil).
func (bt *InternalNode) Get(key []byte, resolver NodeResolverFn) ([]byte, error) {
	group, err := bt.GetValuesAtStem(key[:31], resolver)
	switch {
	case err != nil:
		return nil, fmt.Errorf("get error: %w", err)
	case group == nil:
		return nil, nil
	default:
		return group[key[31]], nil
	}
}
// Insert inserts a new key-value pair into the trie.
// It builds a 256-entry value group holding only value at index key[31] and
// delegates to InsertValuesAtStem with the stem key[:31].
func (bt *InternalNode) Insert(key []byte, value []byte, resolver NodeResolverFn, depth int) (BinaryNode, error) {
	slot := key[31]
	group := make([][]byte, 256)
	group[slot] = value
	return bt.InsertValuesAtStem(key[:31], group, resolver, depth)
}
// Copy creates a deep copy of the node.
//
// depth, the cached hash and the mustRecompute/dirty flags are copied
// verbatim, and each present child is copied via its own Copy method. A nil
// child stays nil in the copy: the other methods of this type accept nil
// children, so Copy must not dereference one.
func (bt *InternalNode) Copy() BinaryNode {
	dup := &InternalNode{
		depth:         bt.depth,
		mustRecompute: bt.mustRecompute,
		dirty:         bt.dirty,
		hash:          bt.hash,
	}
	if bt.left != nil {
		dup.left = bt.left.Copy()
	}
	if bt.right != nil {
		dup.right = bt.right.Copy()
	}
	return dup
}
// Hash returns the hash of the node.
//
// While mustRecompute is false the cached bt.hash is returned unchanged.
// Otherwise the hash is recomputed from the 64-byte concatenation of the left
// child's hash and the right child's hash, stored in bt.hash, and
// mustRecompute is cleared.
func (bt *InternalNode) Hash() common.Hash {
if !bt.mustRecompute {
return bt.hash
}
// At shallow depths, parallelize when both children need rehashing:
// hash left subtree in a goroutine, right subtree inline, then combine.
// Skip goroutine overhead when only one child is dirty (common case
// for narrow state updates that touch a single path through the trie).
// isDirty reports false for a nil child, so both children are non-nil
// inside this branch.
if bt.depth < parallelDepth() && isDirty(bt.left) && isDirty(bt.right) {
var input [64]byte
// lh is written only by the goroutine below and read only after
// wg.Wait(), which orders the write before the read.
var lh common.Hash
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
lh = bt.left.Hash()
}()
// The right subtree is hashed on the calling goroutine while the left
// subtree is hashed concurrently.
rh := bt.right.Hash()
copy(input[32:], rh[:])
wg.Wait()
copy(input[:32], lh[:])
bt.hash = sha256.Sum256(input[:])
bt.mustRecompute = false
return bt.hash
}
// Deeper nodes: sequential using pooled hasher (goroutine overhead > hash cost)
h := newSha256()
defer returnSha256(h)
// A nil child contributes the bytes of `zero` in place of a child hash.
if bt.left != nil {
h.Write(bt.left.Hash().Bytes())
} else {
h.Write(zero[:])
}
if bt.right != nil {
h.Write(bt.right.Hash().Bytes())
} else {
h.Write(zero[:])
}
bt.hash = common.BytesToHash(h.Sum(nil))
bt.mustRecompute = false
return bt.hash
}
// InsertValuesAtStem inserts a full value group at the given stem in the internal node.
// Already-existing values will be overwritten.
//
// The bit of stem at index bt.depth selects the left (0) or right (1) child.
// A nil child is replaced by Empty{}, and a HashedNode child is loaded through
// resolver and deserialized at depth bt.depth+1 before the insert recurses
// into it with depth+1.
//
// On success the (possibly new) child returned by the recursive insert is
// stored and bt is returned. If the recursive insert fails, the existing child
// is kept: the previous code assigned the child's return value
// unconditionally, which replaced the subtree with nil whenever the child
// reported a resolve error. The node is marked mustRecompute and dirty in
// both cases, as before.
//
// Errors are returned when the node is too deep to index the stem, when a
// HashedNode must be resolved but resolver is nil, when resolution or
// deserialization fails, or when the recursive insert fails.
func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolver NodeResolverFn, depth int) (BinaryNode, error) {
	// Callers pass a 31-byte stem (31*8 bits); guard before indexing it so a
	// too-deep node reports an error instead of panicking.
	if bt.depth >= 31*8 {
		return nil, errors.New("node too deep")
	}
	// Work through a pointer to the selected child so that both directions
	// share one code path.
	child := &bt.left
	if stem[bt.depth/8]>>(7-(bt.depth%8))&1 == 1 {
		child = &bt.right
	}
	if *child == nil {
		*child = Empty{}
	}
	if hn, ok := (*child).(HashedNode); ok {
		if resolver == nil {
			return nil, errors.New("InsertValuesAtStem resolve error: resolver is nil")
		}
		path, err := keyToPath(bt.depth, stem)
		if err != nil {
			return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
		}
		data, err := resolver(path, common.Hash(hn))
		if err != nil {
			return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
		}
		node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn))
		if err != nil {
			return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
		}
		*child = node
	}
	updated, err := (*child).InsertValuesAtStem(stem, values, resolver, depth+1)
	if err == nil {
		// Only adopt the returned child on success; on failure its value is
		// not meaningful and must not overwrite the existing subtree.
		*child = updated
	}
	// Conservatively invalidate the cached hash and on-disk blob even on
	// failure, in case the child was partially modified before erroring.
	bt.mustRecompute = true
	bt.dirty = true
	return bt, err
}
// CollectNodes walks the dirty part of the subtree rooted at this node and
// hands each visited node to flushfn. A node whose dirty flag is false is
// skipped together with its subtree. Children are visited before the node
// itself (left, then right), each under path extended by 0 or 1; afterwards
// the node is flushed under path and its dirty flag is cleared.
func (bt *InternalNode) CollectNodes(path []byte, flushfn NodeFlushFn) error {
	if !bt.dirty {
		return nil
	}
	for bit := byte(0); bit < 2; bit++ {
		// Read the field on each iteration so the right child is fetched
		// only after the left subtree has been collected.
		child := bt.left
		if bit == 1 {
			child = bt.right
		}
		if child == nil {
			continue
		}
		// Build the child path in a fresh local buffer so the caller's
		// path slice is never appended to.
		var buf [256]byte
		copy(buf[:], path)
		childpath := append(buf[:len(path)], bit)
		if err := child.CollectNodes(childpath, flushfn); err != nil {
			return err
		}
	}
	flushfn(path, bt)
	bt.dirty = false
	return nil
}
// GetHeight returns the height of the node: one more than the taller of its
// two children's heights, where a nil child counts as height zero.
func (bt *InternalNode) GetHeight() int {
	lh, rh := 0, 0
	if bt.left != nil {
		lh = bt.left.GetHeight()
	}
	if bt.right != nil {
		rh = bt.right.GetHeight()
	}
	if lh < rh {
		return rh + 1
	}
	return lh + 1
}
// toDot renders this node and its subtree as text, presumably Graphviz DOT
// (the name and the `a -> b` / `[label=...]` output suggest so).
//
// The node is named "internal<path>" and labelled with its hash; calling
// bt.Hash() here may recompute the hash if it is stale. When parent is
// non-empty an edge "parent -> node" is emitted. Non-nil children are rendered
// recursively with this node as their parent and path extended by "00" (left)
// or "01" (right).
func (bt *InternalNode) toDot(parent, path string) string {
me := fmt.Sprintf("internal%s", path)
ret := fmt.Sprintf("%s [label=\"I: %x\"]\n", me, bt.Hash())
if len(parent) > 0 {
ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me)
}
if bt.left != nil {
ret = fmt.Sprintf("%s%s", ret, bt.left.toDot(me, fmt.Sprintf("%s%02x", path, 0)))
}
if bt.right != nil {
ret = fmt.Sprintf("%s%s", ret, bt.right.toDot(me, fmt.Sprintf("%s%02x", path, 1)))
}
return ret
} }

View file

@ -24,35 +24,33 @@ import (
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
) )
// TestInternalNodeGet tests the Get method // TestInternalNodeGet tests the Get method via nodeStore.
func TestInternalNodeGet(t *testing.T) { func TestInternalNodeGet(t *testing.T) {
// Create a simple tree structure s := newNodeStore()
leftStem := make([]byte, 31) leftStem := make([]byte, 31)
rightStem := make([]byte, 31) rightStem := make([]byte, 31)
rightStem[0] = 0x80 // First bit is 1 rightStem[0] = 0x80
var leftValues, rightValues [256][]byte leftValues := make([][]byte, 256)
leftValues[0] = common.HexToHash("0x0101").Bytes() leftValues[0] = common.HexToHash("0x0101").Bytes()
rightValues := make([][]byte, 256)
rightValues[0] = common.HexToHash("0x0202").Bytes() rightValues[0] = common.HexToHash("0x0202").Bytes()
node := &InternalNode{ // Build tree: root -> left stem, right stem
depth: 0, // Insert left stem values
left: &StemNode{ s.root = emptyRef
Stem: leftStem, if err := s.InsertValuesAtStem(leftStem, leftValues, nil); err != nil {
Values: leftValues[:], t.Fatal(err)
depth: 1, }
}, if err := s.InsertValuesAtStem(rightStem, rightValues, nil); err != nil {
right: &StemNode{ t.Fatal(err)
Stem: rightStem,
Values: rightValues[:],
depth: 1,
},
} }
// Get value from left subtree // Get value from left subtree
leftKey := make([]byte, 32) leftKey := make([]byte, 32)
leftKey[31] = 0 leftKey[31] = 0
value, err := node.Get(leftKey, nil) value, err := s.Get(leftKey, nil)
if err != nil { if err != nil {
t.Fatalf("Failed to get left value: %v", err) t.Fatalf("Failed to get left value: %v", err)
} }
@ -64,7 +62,7 @@ func TestInternalNodeGet(t *testing.T) {
rightKey := make([]byte, 32) rightKey := make([]byte, 32)
rightKey[0] = 0x80 rightKey[0] = 0x80
rightKey[31] = 0 rightKey[31] = 0
value, err = node.Get(rightKey, nil) value, err = s.Get(rightKey, nil)
if err != nil { if err != nil {
t.Fatalf("Failed to get right value: %v", err) t.Fatalf("Failed to get right value: %v", err)
} }
@ -73,29 +71,26 @@ func TestInternalNodeGet(t *testing.T) {
} }
} }
// TestInternalNodeGetWithResolver tests Get with HashedNode resolution // TestInternalNodeGetWithResolver tests Get with HashedNode resolution via nodeStore.
func TestInternalNodeGetWithResolver(t *testing.T) { func TestInternalNodeGetWithResolver(t *testing.T) {
// Create an internal node with a hashed child // Create a store with an internal node containing a hashed child
hashedChild := HashedNode(common.HexToHash("0x1234")) s := newNodeStore()
hashedChild := s.newHashedRef(common.HexToHash("0x1234"))
node := &InternalNode{ rootRef := s.newInternalRef(0)
depth: 0, rootNode := s.getInternal(rootRef.Index())
left: hashedChild, rootNode.left = hashedChild
right: Empty{}, rootNode.right = emptyRef
} s.root = rootRef
// Mock resolver that returns a stem node // Mock resolver that returns a stem node
resolver := func(path []byte, hash common.Hash) ([]byte, error) { resolver := func(path []byte, hash common.Hash) ([]byte, error) {
if hash == common.Hash(hashedChild) { if hash == common.HexToHash("0x1234") {
rs := newNodeStore()
stem := make([]byte, 31) stem := make([]byte, 31)
var values [256][]byte ref := rs.newStemRef(stem, 1)
values[5] = common.HexToHash("0xabcd").Bytes() sn := rs.getStem(ref.Index())
stemNode := &StemNode{ sn.setValue(5, common.HexToHash("0xabcd").Bytes())
Stem: stem, return rs.serializeNode(ref), nil
Values: values[:],
depth: 1,
}
return SerializeNode(stemNode), nil
} }
return nil, errors.New("node not found") return nil, errors.New("node not found")
} }
@ -103,7 +98,7 @@ func TestInternalNodeGetWithResolver(t *testing.T) {
// Get value through the hashed node // Get value through the hashed node
key := make([]byte, 32) key := make([]byte, 32)
key[31] = 5 key[31] = 5
value, err := node.Get(key, resolver) value, err := s.Get(key, resolver)
if err != nil { if err != nil {
t.Fatalf("Failed to get value: %v", err) t.Fatalf("Failed to get value: %v", err)
} }
@ -114,179 +109,113 @@ func TestInternalNodeGetWithResolver(t *testing.T) {
} }
} }
// TestInternalNodeInsert tests the Insert method // TestInternalNodeInsert tests the Insert method via nodeStore.
func TestInternalNodeInsert(t *testing.T) { func TestInternalNodeInsert(t *testing.T) {
// Start with an internal node with empty children s := newNodeStore()
node := &InternalNode{
depth: 0,
left: Empty{},
right: Empty{},
}
// Insert a value into the left subtree
leftKey := make([]byte, 32) leftKey := make([]byte, 32)
leftKey[31] = 10 leftKey[31] = 10
leftValue := common.HexToHash("0x0101").Bytes() leftValue := common.HexToHash("0x0101").Bytes()
newNode, err := node.Insert(leftKey, leftValue, nil, 0) if err := s.Insert(leftKey, leftValue, nil); err != nil {
if err != nil {
t.Fatalf("Failed to insert: %v", err) t.Fatalf("Failed to insert: %v", err)
} }
internalNode, ok := newNode.(*InternalNode) // Verify the value was stored
if !ok { value, err := s.Get(leftKey, nil)
t.Fatalf("Expected InternalNode, got %T", newNode) if err != nil {
t.Fatalf("Failed to get: %v", err)
} }
if !bytes.Equal(value, leftValue) {
// Check that left child is now a StemNode t.Errorf("Value mismatch: expected %x, got %x", leftValue, value)
leftStem, ok := internalNode.left.(*StemNode)
if !ok {
t.Fatalf("Expected left child to be StemNode, got %T", internalNode.left)
}
// Check the inserted value
if !bytes.Equal(leftStem.Values[10], leftValue) {
t.Errorf("Value mismatch: expected %x, got %x", leftValue, leftStem.Values[10])
}
// Right child should still be Empty
_, ok = internalNode.right.(Empty)
if !ok {
t.Errorf("Expected right child to remain Empty, got %T", internalNode.right)
} }
} }
// TestInternalNodeCopy tests the Copy method // TestInternalNodeCopy tests the Copy method via nodeStore.
func TestInternalNodeCopy(t *testing.T) { func TestInternalNodeCopy(t *testing.T) {
// Create an internal node with stem children s := newNodeStore()
leftStem := &StemNode{
Stem: make([]byte, 31),
Values: make([][]byte, 256),
depth: 1,
}
leftStem.Values[0] = common.HexToHash("0x0101").Bytes()
rightStem := &StemNode{ leftKey := make([]byte, 32)
Stem: make([]byte, 31), leftKey[31] = 0
Values: make([][]byte, 256), leftValue := common.HexToHash("0x0101").Bytes()
depth: 1,
}
rightStem.Stem[0] = 0x80
rightStem.Values[0] = common.HexToHash("0x0202").Bytes()
node := &InternalNode{ rightKey := make([]byte, 32)
depth: 0, rightKey[0] = 0x80
left: leftStem, rightKey[31] = 0
right: rightStem, rightValue := common.HexToHash("0x0202").Bytes()
if err := s.Insert(leftKey, leftValue, nil); err != nil {
t.Fatal(err)
}
if err := s.Insert(rightKey, rightValue, nil); err != nil {
t.Fatal(err)
} }
// Create a copy ns := s.Copy()
copied := node.Copy()
copiedInternal, ok := copied.(*InternalNode)
if !ok {
t.Fatalf("Expected InternalNode, got %T", copied)
}
// Check depth // Values should be equal
if copiedInternal.depth != node.depth { v1, _ := ns.Get(leftKey, nil)
t.Errorf("Depth mismatch: expected %d, got %d", node.depth, copiedInternal.depth) if !bytes.Equal(v1, leftValue) {
}
// Check that children are copied
copiedLeft, ok := copiedInternal.left.(*StemNode)
if !ok {
t.Fatalf("Expected left child to be StemNode, got %T", copiedInternal.left)
}
copiedRight, ok := copiedInternal.right.(*StemNode)
if !ok {
t.Fatalf("Expected right child to be StemNode, got %T", copiedInternal.right)
}
// Verify deep copy (children should be different objects)
if copiedLeft == leftStem {
t.Error("Left child not properly copied")
}
if copiedRight == rightStem {
t.Error("Right child not properly copied")
}
// But values should be equal
if !bytes.Equal(copiedLeft.Values[0], leftStem.Values[0]) {
t.Error("Left child value mismatch after copy") t.Error("Left child value mismatch after copy")
} }
if !bytes.Equal(copiedRight.Values[0], rightStem.Values[0]) { v2, _ := ns.Get(rightKey, nil)
if !bytes.Equal(v2, rightValue) {
t.Error("Right child value mismatch after copy") t.Error("Right child value mismatch after copy")
} }
} }
// TestInternalNodeHash tests the Hash method // TestInternalNodeHash tests the Hash method via nodeStore.
func TestInternalNodeHash(t *testing.T) { func TestInternalNodeHash(t *testing.T) {
// Create an internal node s := newNodeStore()
node := &InternalNode{ leftRef := s.newHashedRef(common.HexToHash("0x1111"))
depth: 0, rightRef := s.newHashedRef(common.HexToHash("0x2222"))
left: HashedNode(common.HexToHash("0x1111")), rootRef := s.newInternalRef(0)
right: HashedNode(common.HexToHash("0x2222")), rootNode := s.getInternal(rootRef.Index())
} rootNode.left = leftRef
rootNode.right = rightRef
s.root = rootRef
hash1 := node.Hash() hash1 := s.computeHash(rootRef)
// Hash should be deterministic // Hash should be deterministic
hash2 := node.Hash() hash2 := s.computeHash(rootRef)
if hash1 != hash2 { if hash1 != hash2 {
t.Errorf("Hash not deterministic: %x != %x", hash1, hash2) t.Errorf("Hash not deterministic: %x != %x", hash1, hash2)
} }
// Changing a child should change the hash // Changing a child should change the hash
node.left = HashedNode(common.HexToHash("0x3333")) rootNode.left = s.newHashedRef(common.HexToHash("0x3333"))
node.mustRecompute = true rootNode.mustRecompute = true
hash3 := node.Hash() hash3 := s.computeHash(rootRef)
if hash1 == hash3 { if hash1 == hash3 {
t.Error("Hash didn't change after modifying left child") t.Error("Hash didn't change after modifying left child")
} }
// Test with nil children (should use zero hash)
nodeWithNil := &InternalNode{
depth: 0,
left: nil,
right: HashedNode(common.HexToHash("0x4444")),
mustRecompute: true,
}
hashWithNil := nodeWithNil.Hash()
if hashWithNil == (common.Hash{}) {
t.Error("Hash shouldn't be zero even with nil child")
}
} }
// TestInternalNodeGetValuesAtStem tests GetValuesAtStem method // TestInternalNodeGetValuesAtStem tests GetValuesAtStem method via nodeStore.
func TestInternalNodeGetValuesAtStem(t *testing.T) { func TestInternalNodeGetValuesAtStem(t *testing.T) {
// Create a tree with values at different stems s := newNodeStore()
leftStem := make([]byte, 31) leftStem := make([]byte, 31)
rightStem := make([]byte, 31) rightStem := make([]byte, 31)
rightStem[0] = 0x80 rightStem[0] = 0x80
var leftValues, rightValues [256][]byte leftValues := make([][]byte, 256)
leftValues[0] = common.HexToHash("0x0101").Bytes() leftValues[0] = common.HexToHash("0x0101").Bytes()
leftValues[10] = common.HexToHash("0x0102").Bytes() leftValues[10] = common.HexToHash("0x0102").Bytes()
rightValues := make([][]byte, 256)
rightValues[0] = common.HexToHash("0x0201").Bytes() rightValues[0] = common.HexToHash("0x0201").Bytes()
rightValues[20] = common.HexToHash("0x0202").Bytes() rightValues[20] = common.HexToHash("0x0202").Bytes()
node := &InternalNode{ if err := s.InsertValuesAtStem(leftStem, leftValues, nil); err != nil {
depth: 0, t.Fatal(err)
left: &StemNode{ }
Stem: leftStem, if err := s.InsertValuesAtStem(rightStem, rightValues, nil); err != nil {
Values: leftValues[:], t.Fatal(err)
depth: 1,
},
right: &StemNode{
Stem: rightStem,
Values: rightValues[:],
depth: 1,
},
} }
// Get values from left stem // Get values from left stem
values, err := node.GetValuesAtStem(leftStem, nil) values, err := s.GetValuesAtStem(leftStem, nil)
if err != nil { if err != nil {
t.Fatalf("Failed to get left values: %v", err) t.Fatalf("Failed to get left values: %v", err)
} }
@ -298,7 +227,7 @@ func TestInternalNodeGetValuesAtStem(t *testing.T) {
} }
// Get values from right stem // Get values from right stem
values, err = node.GetValuesAtStem(rightStem, nil) values, err = s.GetValuesAtStem(rightStem, nil)
if err != nil { if err != nil {
t.Fatalf("Failed to get right values: %v", err) t.Fatalf("Failed to get right values: %v", err)
} }
@ -310,201 +239,103 @@ func TestInternalNodeGetValuesAtStem(t *testing.T) {
} }
} }
// TestInternalNodeInsertValuesAtStem tests InsertValuesAtStem method // TestInternalNodeInsertValuesAtStem tests InsertValuesAtStem method via nodeStore.
func TestInternalNodeInsertValuesAtStem(t *testing.T) { func TestInternalNodeInsertValuesAtStem(t *testing.T) {
// Start with an internal node with empty children s := newNodeStore()
node := &InternalNode{
depth: 0,
left: Empty{},
right: Empty{},
}
// Insert values at a stem in the left subtree
stem := make([]byte, 31) stem := make([]byte, 31)
var values [256][]byte values := make([][]byte, 256)
values[5] = common.HexToHash("0x0505").Bytes() values[5] = common.HexToHash("0x0505").Bytes()
values[10] = common.HexToHash("0x1010").Bytes() values[10] = common.HexToHash("0x1010").Bytes()
newNode, err := node.InsertValuesAtStem(stem, values[:], nil, 0) if err := s.InsertValuesAtStem(stem, values, nil); err != nil {
if err != nil {
t.Fatalf("Failed to insert values: %v", err) t.Fatalf("Failed to insert values: %v", err)
} }
internalNode, ok := newNode.(*InternalNode) // Check that the values are stored
if !ok { retrieved, err := s.GetValuesAtStem(stem, nil)
t.Fatalf("Expected InternalNode, got %T", newNode) if err != nil {
t.Fatalf("Failed to get values: %v", err)
} }
if !bytes.Equal(retrieved[5], values[5]) {
// Check that left child is now a StemNode with the values
leftStem, ok := internalNode.left.(*StemNode)
if !ok {
t.Fatalf("Expected left child to be StemNode, got %T", internalNode.left)
}
if !bytes.Equal(leftStem.Values[5], values[5]) {
t.Error("Value at index 5 mismatch") t.Error("Value at index 5 mismatch")
} }
if !bytes.Equal(leftStem.Values[10], values[10]) { if !bytes.Equal(retrieved[10], values[10]) {
t.Error("Value at index 10 mismatch") t.Error("Value at index 10 mismatch")
} }
} }
// TestInternalNodeCollectNodes tests CollectNodes method // TestInternalNodeCollectNodes tests CollectNodes method via nodeStore.
func TestInternalNodeCollectNodes(t *testing.T) { func TestInternalNodeCollectNodes(t *testing.T) {
// Create an internal node with two stem children. All three are s := newNodeStore()
// marked dirty to mirror production semantics — see CollectNodes.
leftStem := &StemNode{
Stem: make([]byte, 31),
Values: make([][]byte, 256),
depth: 1,
dirty: true,
}
rightStem := &StemNode{ leftStem := make([]byte, 31)
Stem: make([]byte, 31), rightStem := make([]byte, 31)
Values: make([][]byte, 256), rightStem[0] = 0x80
depth: 1,
dirty: true,
}
rightStem.Stem[0] = 0x80
node := &InternalNode{ leftValues := make([][]byte, 256)
depth: 0, rightValues := make([][]byte, 256)
left: leftStem,
right: rightStem, if err := s.InsertValuesAtStem(leftStem, leftValues, nil); err != nil {
dirty: true, t.Fatal(err)
}
if err := s.InsertValuesAtStem(rightStem, rightValues, nil); err != nil {
t.Fatal(err)
} }
var collectedPaths [][]byte var collectedPaths [][]byte
var collectedNodes []BinaryNode flushFn := func(path []byte, hash common.Hash, serialized []byte) {
flushFn := func(path []byte, n BinaryNode) {
pathCopy := make([]byte, len(path)) pathCopy := make([]byte, len(path))
copy(pathCopy, path) copy(pathCopy, path)
collectedPaths = append(collectedPaths, pathCopy) collectedPaths = append(collectedPaths, pathCopy)
collectedNodes = append(collectedNodes, n)
} }
err := node.CollectNodes([]byte{1}, flushFn) err := s.collectNodes(s.root, []byte{1}, flushFn)
if err != nil { if err != nil {
t.Fatalf("Failed to collect nodes: %v", err) t.Fatalf("Failed to collect nodes: %v", err)
} }
// Should have collected 3 nodes: left stem, right stem, and the internal node itself // Should have collected 3 nodes: left stem, right stem, and the internal node itself
if len(collectedNodes) != 3 { if len(collectedPaths) != 3 {
t.Errorf("Expected 3 collected nodes, got %d", len(collectedNodes)) t.Errorf("Expected 3 collected nodes, got %d", len(collectedPaths))
}
// Check paths
expectedPaths := [][]byte{
{1, 0}, // left child
{1, 1}, // right child
{1}, // internal node itself
}
for i, expectedPath := range expectedPaths {
if !bytes.Equal(collectedPaths[i], expectedPath) {
t.Errorf("Path %d mismatch: expected %v, got %v", i, expectedPath, collectedPaths[i])
}
} }
} }
// TestInternalNodeCollectNodesSkipsClean verifies clean subtrees are not // TestInternalNodeGetHeight tests GetHeight method via nodeStore.
// flushed. A dirty internal node over clean children only flushes itself;
// a fully clean tree flushes nothing.
func TestInternalNodeCollectNodesSkipsClean(t *testing.T) {
leftStem := &StemNode{
Stem: make([]byte, 31),
Values: make([][]byte, 256),
depth: 1,
}
rightStem := &StemNode{
Stem: make([]byte, 31),
Values: make([][]byte, 256),
depth: 1,
}
rightStem.Stem[0] = 0x80
dirtyParent := &InternalNode{
depth: 0,
left: leftStem,
right: rightStem,
dirty: true,
}
var collected []BinaryNode
flushFn := func(_ []byte, n BinaryNode) { collected = append(collected, n) }
if err := dirtyParent.CollectNodes([]byte{1}, flushFn); err != nil {
t.Fatalf("CollectNodes: %v", err)
}
if len(collected) != 1 || collected[0] != dirtyParent {
t.Fatalf("expected only the dirty parent to be flushed, got %d nodes", len(collected))
}
if dirtyParent.dirty {
t.Errorf("parent dirty flag should be cleared after flush")
}
// Second call on the same tree should be a no-op: everything is clean.
collected = nil
if err := dirtyParent.CollectNodes([]byte{1}, flushFn); err != nil {
t.Fatalf("CollectNodes (second call): %v", err)
}
if len(collected) != 0 {
t.Errorf("expected no nodes to be flushed on clean tree, got %d", len(collected))
}
}
// TestInternalNodeGetHeight tests GetHeight method
func TestInternalNodeGetHeight(t *testing.T) { func TestInternalNodeGetHeight(t *testing.T) {
// Create a tree with different heights s := newNodeStore()
// Left subtree: depth 2 (internal -> stem)
// Right subtree: depth 1 (stem) // Insert values that create a deeper tree
leftInternal := &InternalNode{ stem1 := make([]byte, 31) // left
depth: 1, stem2 := make([]byte, 31)
left: &StemNode{ stem2[0] = 0x40 // 01... -> goes left at depth 0, right at depth 1
Stem: make([]byte, 31),
Values: make([][]byte, 256), values1 := make([][]byte, 256)
depth: 2, values1[0] = common.HexToHash("0x01").Bytes()
}, values2 := make([][]byte, 256)
right: Empty{}, values2[0] = common.HexToHash("0x02").Bytes()
if err := s.InsertValuesAtStem(stem1, values1, nil); err != nil {
t.Fatal(err)
}
if err := s.InsertValuesAtStem(stem2, values2, nil); err != nil {
t.Fatal(err)
} }
rightStem := &StemNode{ height := s.getHeight(s.root)
Stem: make([]byte, 31), if height < 2 {
Values: make([][]byte, 256), t.Errorf("Expected height >= 2, got %d", height)
depth: 1,
}
node := &InternalNode{
depth: 0,
left: leftInternal,
right: rightStem,
}
height := node.GetHeight()
// Height should be max(left height, right height) + 1
// Left height: 2, Right height: 1, so total: 3
if height != 3 {
t.Errorf("Expected height 3, got %d", height)
} }
} }
// TestInternalNodeDepthTooLarge tests handling of excessive depth // TestInternalNodeDepthTooLarge tests handling of excessive depth via nodeStore.
func TestInternalNodeDepthTooLarge(t *testing.T) { func TestInternalNodeDepthTooLarge(t *testing.T) {
// Create an internal node at max depth s := newNodeStore()
node := &InternalNode{ // Creating an internal node beyond max depth should panic
depth: 31*8 + 1, defer func() {
left: Empty{}, if r := recover(); r == nil {
right: Empty{}, t.Fatal("Expected panic for excessive depth")
} }
}()
stem := make([]byte, 31) s.newInternalRef(31*8 + 1)
_, err := node.GetValuesAtStem(stem, nil)
if err == nil {
t.Fatal("Expected error for excessive depth")
}
if err.Error() != "node too deep" {
t.Errorf("Expected 'node too deep' error, got: %v", err)
}
} }

View file

@ -26,13 +26,14 @@ import (
var errIteratorEnd = errors.New("end of iteration") var errIteratorEnd = errors.New("end of iteration")
type binaryNodeIteratorState struct { type binaryNodeIteratorState struct {
Node BinaryNode Node nodeRef
Index int Index int
} }
type binaryNodeIterator struct { type binaryNodeIterator struct {
trie *BinaryTrie trie *BinaryTrie
current BinaryNode store *nodeStore
current nodeRef
lastErr error lastErr error
stack []binaryNodeIteratorState stack []binaryNodeIteratorState
@ -40,56 +41,63 @@ type binaryNodeIterator struct {
func newBinaryNodeIterator(t *BinaryTrie, _ []byte) (trie.NodeIterator, error) { func newBinaryNodeIterator(t *BinaryTrie, _ []byte) (trie.NodeIterator, error) {
if t.Hash() == zero { if t.Hash() == zero {
return &binaryNodeIterator{trie: t, lastErr: errIteratorEnd}, nil return &binaryNodeIterator{trie: t, store: t.store, lastErr: errIteratorEnd}, nil
} }
it := &binaryNodeIterator{trie: t, current: t.root} it := &binaryNodeIterator{trie: t, store: t.store, current: t.store.root}
// it.err = it.seek(start)
return it, nil return it, nil
} }
// Next moves the iterator to the next node. If the parameter is false, any child // Next moves the iterator to the next node. If descend is false, children of
// nodes will be skipped. // the current node are skipped.
func (it *binaryNodeIterator) Next(descend bool) bool { func (it *binaryNodeIterator) Next(descend bool) bool {
if it.lastErr == errIteratorEnd { if it.lastErr == errIteratorEnd {
it.lastErr = errIteratorEnd
return false return false
} }
if len(it.stack) == 0 { if len(it.stack) == 0 {
it.stack = append(it.stack, binaryNodeIteratorState{Node: it.trie.root}) it.stack = append(it.stack, binaryNodeIteratorState{Node: it.trie.store.root})
it.current = it.trie.root it.current = it.trie.store.root
return true return true
} }
switch node := it.current.(type) { switch it.current.Kind() {
case *InternalNode: case kindInternal:
// index: 0 = nothing visited, 1=left visited, 2=right visited // index: 0 = nothing visited, 1 = left visited, 2 = right visited.
node := it.store.getInternal(it.current.Index())
context := &it.stack[len(it.stack)-1] context := &it.stack[len(it.stack)-1]
// recurse into both children if !descend {
// Skip children: pop this node and advance parent.
if len(it.stack) == 1 {
it.lastErr = errIteratorEnd
return false
}
it.stack = it.stack[:len(it.stack)-1]
it.current = it.stack[len(it.stack)-1].Node
it.stack[len(it.stack)-1].Index++
return it.Next(true)
}
// Recurse into both children.
if context.Index == 0 { if context.Index == 0 {
if _, isempty := node.left.(Empty); node.left != nil && !isempty { if !node.left.IsEmpty() {
it.stack = append(it.stack, binaryNodeIteratorState{Node: node.left}) it.stack = append(it.stack, binaryNodeIteratorState{Node: node.left})
it.current = node.left it.current = node.left
return it.Next(descend) return it.Next(descend)
} }
context.Index++ context.Index++
} }
if context.Index == 1 { if context.Index == 1 {
if _, isempty := node.right.(Empty); node.right != nil && !isempty { if !node.right.IsEmpty() {
it.stack = append(it.stack, binaryNodeIteratorState{Node: node.right}) it.stack = append(it.stack, binaryNodeIteratorState{Node: node.right})
it.current = node.right it.current = node.right
return it.Next(descend) return it.Next(descend)
} }
context.Index++ context.Index++
} }
// Reached the end of this node, go back to the parent, if // Reached the end of this node; go back to the parent unless we're at the root.
// this isn't root.
if len(it.stack) == 1 { if len(it.stack) == 1 {
it.lastErr = errIteratorEnd it.lastErr = errIteratorEnd
return false return false
@ -98,17 +106,18 @@ func (it *binaryNodeIterator) Next(descend bool) bool {
it.current = it.stack[len(it.stack)-1].Node it.current = it.stack[len(it.stack)-1].Node
it.stack[len(it.stack)-1].Index++ it.stack[len(it.stack)-1].Index++
return it.Next(descend) return it.Next(descend)
case *StemNode:
// Look for the next non-empty value case kindStem:
// Look for the next non-empty value in this stem.
sn := it.store.getStem(it.current.Index())
for i := it.stack[len(it.stack)-1].Index; i < 256; i++ { for i := it.stack[len(it.stack)-1].Index; i < 256; i++ {
if node.Values[i] != nil { if sn.hasValue(byte(i)) {
it.stack[len(it.stack)-1].Index = i + 1 it.stack[len(it.stack)-1].Index = i + 1
return true return true
} }
} }
// go back to parent to get the next leaf // No more values in this stem; go back to parent to get the next leaf.
// Check if we're at the root before popping
if len(it.stack) == 1 { if len(it.stack) == 1 {
it.lastErr = errIteratorEnd it.lastErr = errIteratorEnd
return false return false
@ -117,51 +126,47 @@ func (it *binaryNodeIterator) Next(descend bool) bool {
it.current = it.stack[len(it.stack)-1].Node it.current = it.stack[len(it.stack)-1].Node
it.stack[len(it.stack)-1].Index++ it.stack[len(it.stack)-1].Index++
return it.Next(descend) return it.Next(descend)
case HashedNode:
// resolve the node
resolverPath := it.Path()
data, err := it.trie.nodeResolver(resolverPath, common.Hash(node))
if err != nil {
panic(err)
}
if data == nil {
// Empty/nil node — treat as Empty, backtrack
it.current = Empty{}
it.stack[len(it.stack)-1].Node = it.current
return it.Next(descend)
}
it.current, err = DeserializeNodeWithHash(data, len(it.stack)-1, common.Hash(node))
if err != nil {
panic(err)
}
// update the stack and parent with the resolved node case kindHashed:
it.stack[len(it.stack)-1].Node = it.current // Resolve the hashed node from disk, then rewire the parent to point at the
if len(it.stack) >= 2 { // resolved node in place.
parent := &it.stack[len(it.stack)-2] if len(it.stack) < 2 {
if parent.Index == 0 { it.lastErr = errors.New("cannot resolve hashed root during iteration")
parent.Node.(*InternalNode).left = it.current
} else {
parent.Node.(*InternalNode).right = it.current
}
}
return it.Next(descend)
case Empty:
// Empty node - go back to parent and continue
if len(it.stack) <= 1 {
it.lastErr = errIteratorEnd
return false return false
} }
it.stack = it.stack[:len(it.stack)-1] hn := it.store.getHashed(it.current.Index())
it.current = it.stack[len(it.stack)-1].Node data, err := it.trie.nodeResolver(it.Path(), hn.Hash())
it.stack[len(it.stack)-1].Index++ if err != nil {
it.lastErr = err
return false
}
resolved, err := it.store.deserializeNodeWithHash(data, len(it.stack)-1, hn.Hash())
if err != nil {
it.lastErr = err
return false
}
oldHashedIdx := it.current.Index()
it.current = resolved
it.stack[len(it.stack)-1].Node = resolved
parent := &it.stack[len(it.stack)-2]
parentNode := it.store.getInternal(parent.Node.Index())
if parent.Index == 0 {
parentNode.left = resolved
} else {
parentNode.right = resolved
}
it.store.freeHashedNode(oldHashedIdx)
return it.Next(descend) return it.Next(descend)
case kindEmpty:
return false
default: default:
panic("invalid node type") panic("invalid node type")
} }
} }
// Error returns the error status of the iterator.
func (it *binaryNodeIterator) Error() error { func (it *binaryNodeIterator) Error() error {
if it.lastErr == errIteratorEnd { if it.lastErr == errIteratorEnd {
return nil return nil
@ -169,27 +174,28 @@ func (it *binaryNodeIterator) Error() error {
return it.lastErr return it.lastErr
} }
// Hash returns the hash of the current node.
func (it *binaryNodeIterator) Hash() common.Hash { func (it *binaryNodeIterator) Hash() common.Hash {
return it.current.Hash() return it.store.computeHash(it.current)
} }
// Parent returns the hash of the parent of the current node. The hash may be the one // Parent returns the hash of the current node's parent. When the immediate
// grandparent if the immediate parent is an internal node with no hash. // parent is an internal node whose hash has not been materialised, the
// returned hash may be the one of a grandparent instead.
func (it *binaryNodeIterator) Parent() common.Hash { func (it *binaryNodeIterator) Parent() common.Hash {
return it.stack[len(it.stack)-1].Node.Hash() if len(it.stack) < 2 {
return common.Hash{}
}
return it.store.computeHash(it.stack[len(it.stack)-2].Node)
} }
// Path returns the hex-encoded path to the current node. // Path returns the bit-path to the current node.
// Callers must not retain references to the return value after calling Next. // Callers must not retain references to the returned slice after calling Next.
// For leaf nodes, the last element of the path is the 'terminator symbol' 0x10.
func (it *binaryNodeIterator) Path() []byte { func (it *binaryNodeIterator) Path() []byte {
if it.Leaf() { if it.Leaf() {
return it.LeafKey() return it.LeafKey()
} }
var path []byte var path []byte
for i, state := range it.stack { for i, state := range it.stack {
// skip the last byte
if i >= len(it.stack)-1 { if i >= len(it.stack)-1 {
break break
} }
@ -198,107 +204,94 @@ func (it *binaryNodeIterator) Path() []byte {
return path return path
} }
// NodeBlob returns the serialized bytes of the current node.
func (it *binaryNodeIterator) NodeBlob() []byte { func (it *binaryNodeIterator) NodeBlob() []byte {
return SerializeNode(it.current) return it.store.serializeNode(it.current)
} }
// Leaf returns true iff the current node is a leaf node. // Leaf reports whether the iterator is currently positioned at a leaf value.
// In a Binary Trie, a StemNode contains up to 256 leaf values. // A StemNode holds up to 256 values; the iterator is only "at a leaf" when
// The iterator is only considered to be "at a leaf" when it's positioned // positioned at a specific non-nil value inside the stem, not merely at the
// at a specific non-nil value within the StemNode, not just at the StemNode itself. // StemNode itself. The stack Index points to the NEXT position after the
// current value, so Index == 0 means we haven't yielded anything yet.
func (it *binaryNodeIterator) Leaf() bool { func (it *binaryNodeIterator) Leaf() bool {
sn, ok := it.current.(*StemNode) if it.current.Kind() != kindStem {
if !ok {
return false return false
} }
// Check if we have a valid stack position
if len(it.stack) == 0 { if len(it.stack) == 0 {
return false return false
} }
// The Index in the stack state points to the NEXT position after the current value.
// So if Index is 0, we haven't started iterating through the values yet.
// If Index is 5, we're currently at value[4] (the 5th value, 0-indexed).
idx := it.stack[len(it.stack)-1].Index idx := it.stack[len(it.stack)-1].Index
if idx == 0 || idx > 256 { if idx == 0 || idx > 256 {
return false return false
} }
// Check if there's actually a value at the current position sn := it.store.getStem(it.current.Index())
currentValueIndex := idx - 1 currentValueIndex := idx - 1
return sn.Values[currentValueIndex] != nil return sn.hasValue(byte(currentValueIndex))
} }
// LeafKey returns the key of the leaf. The method panics if the iterator is not // LeafKey returns the key of the leaf. Panics if the iterator is not
// positioned at a leaf. Callers must not retain references to the value after // positioned at a leaf. Callers must not retain references to the returned
// calling Next. // slice after calling Next.
func (it *binaryNodeIterator) LeafKey() []byte { func (it *binaryNodeIterator) LeafKey() []byte {
leaf, ok := it.current.(*StemNode) if it.current.Kind() != kindStem {
if !ok {
panic("Leaf() called on an binary node iterator not at a leaf location") panic("Leaf() called on an binary node iterator not at a leaf location")
} }
return leaf.Key(it.stack[len(it.stack)-1].Index - 1) sn := it.store.getStem(it.current.Index())
return sn.Key(it.stack[len(it.stack)-1].Index - 1)
} }
// LeafBlob returns the content of the leaf. The method panics if the iterator // LeafBlob returns the leaf value. Panics if the iterator is not positioned
// is not positioned at a leaf. Callers must not retain references to the value // at a leaf. Callers must not retain references to the returned slice after
// after calling Next. // calling Next.
func (it *binaryNodeIterator) LeafBlob() []byte { func (it *binaryNodeIterator) LeafBlob() []byte {
leaf, ok := it.current.(*StemNode) if it.current.Kind() != kindStem {
if !ok {
panic("LeafBlob() called on an binary node iterator not at a leaf location") panic("LeafBlob() called on an binary node iterator not at a leaf location")
} }
return leaf.Values[it.stack[len(it.stack)-1].Index-1] sn := it.store.getStem(it.current.Index())
return sn.getValue(byte(it.stack[len(it.stack)-1].Index - 1))
} }
// LeafProof returns the Merkle proof of the leaf. The method panics if the // LeafProof returns the Merkle proof of the leaf. Panics if the iterator is
// iterator is not positioned at a leaf. Callers must not retain references // not positioned at a leaf. Callers must not retain references to the
// to the value after calling Next. // returned slices after calling Next.
func (it *binaryNodeIterator) LeafProof() [][]byte { func (it *binaryNodeIterator) LeafProof() [][]byte {
sn, ok := it.current.(*StemNode) if it.current.Kind() != kindStem {
if !ok {
panic("LeafProof() called on an binary node iterator not at a leaf location") panic("LeafProof() called on an binary node iterator not at a leaf location")
} }
sn := it.store.getStem(it.current.Index())
proof := make([][]byte, 0, len(it.stack)+StemNodeWidth) proof := make([][]byte, 0, len(it.stack)+StemNodeWidth)
// Build proof by walking up the stack and collecting sibling hashes if len(it.stack) < 2 {
proof = append(proof, sn.Stem[:])
proof = append(proof, sn.allValues()...)
return proof
}
for i := range it.stack[:len(it.stack)-2] { for i := range it.stack[:len(it.stack)-2] {
state := it.stack[i] state := it.stack[i]
internalNode := state.Node.(*InternalNode) // should panic if the node isn't an InternalNode internalNode := it.store.getInternal(state.Node.Index())
// Add the sibling hash to the proof
if state.Index == 0 { if state.Index == 0 {
// We came from left, so include right sibling rh := it.store.computeHash(internalNode.right)
proof = append(proof, internalNode.right.Hash().Bytes()) proof = append(proof, rh.Bytes())
} else { } else {
// We came from right, so include left sibling lh := it.store.computeHash(internalNode.left)
proof = append(proof, internalNode.left.Hash().Bytes()) proof = append(proof, lh.Bytes())
} }
} }
// Add the stem and siblings // Add the stem and siblings
proof = append(proof, sn.Stem) proof = append(proof, sn.Stem[:])
for _, v := range sn.Values { proof = append(proof, sn.allValues()...)
proof = append(proof, v)
}
return proof return proof
} }
// AddResolver sets an intermediate database to use for looking up trie nodes // AddResolver is a no-op (satisfies the NodeIterator interface).
// before reaching into the real persistent layer.
//
// This is not required for normal operation, rather is an optimization for
// cases where trie nodes can be recovered from some external mechanism without
// reading from disk. In those cases, this resolver allows short circuiting
// accesses and returning them from memory.
//
// Before adding a similar mechanism to any other place in Geth, consider
// making trie.Database an interface and wrapping at that level. It's a huge
// refactor, but it could be worth it if another occurrence arises.
func (it *binaryNodeIterator) AddResolver(trie.NodeResolver) { func (it *binaryNodeIterator) AddResolver(trie.NodeResolver) {
// Not implemented, but should not panic // Not implemented, but should not panic
} }

View file

@ -27,14 +27,13 @@ import (
// makeTrie creates a BinaryTrie populated with the given key-value pairs. // makeTrie creates a BinaryTrie populated with the given key-value pairs.
func makeTrie(t *testing.T, entries [][2]common.Hash) *BinaryTrie { func makeTrie(t *testing.T, entries [][2]common.Hash) *BinaryTrie {
t.Helper() t.Helper()
store := newNodeStore()
tr := &BinaryTrie{ tr := &BinaryTrie{
root: NewBinaryNode(), store: store,
tracer: trie.NewPrevalueTracer(), tracer: trie.NewPrevalueTracer(),
} }
for _, kv := range entries { for _, kv := range entries {
var err error if err := store.Insert(kv[0][:], kv[1][:], nil); err != nil {
tr.root, err = tr.root.Insert(kv[0][:], kv[1][:], nil, 0)
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }
@ -64,7 +63,7 @@ func countLeaves(t *testing.T, tr *BinaryTrie) int {
// no nodes and reports no error. // no nodes and reports no error.
func TestIteratorEmptyTrie(t *testing.T) { func TestIteratorEmptyTrie(t *testing.T) {
tr := &BinaryTrie{ tr := &BinaryTrie{
root: Empty{}, store: newNodeStore(),
tracer: trie.NewPrevalueTracer(), tracer: trie.NewPrevalueTracer(),
} }
it, err := newBinaryNodeIterator(tr, nil) it, err := newBinaryNodeIterator(tr, nil)
@ -145,8 +144,8 @@ func TestIteratorEmptyNodeBacktrack(t *testing.T) {
{common.HexToHash("8000000000000000000000000000000000000000000000000000000000000001"), oneKey}, {common.HexToHash("8000000000000000000000000000000000000000000000000000000000000001"), oneKey},
}) })
if _, ok := tr.root.(*InternalNode); !ok { if tr.store.root.Kind() != kindInternal {
t.Fatalf("expected InternalNode root, got %T", tr.root) t.Fatalf("expected InternalNode root, got kind %d", tr.store.root.Kind())
} }
if leaves := countLeaves(t, tr); leaves != 2 { if leaves := countLeaves(t, tr); leaves != 2 {
t.Fatalf("expected 2 leaves, got %d (Empty backtrack bug?)", leaves) t.Fatalf("expected 2 leaves, got %d (Empty backtrack bug?)", leaves)
@ -162,18 +161,31 @@ func TestIteratorHashedNodeNilData(t *testing.T) {
{common.HexToHash("8000000000000000000000000000000000000000000000000000000000000001"), oneKey}, {common.HexToHash("8000000000000000000000000000000000000000000000000000000000000001"), oneKey},
}) })
root, ok := tr.root.(*InternalNode) root := tr.store.root
if !ok { if root.Kind() != kindInternal {
t.Fatalf("expected InternalNode root, got %T", tr.root) t.Fatalf("expected InternalNode root, got kind %d", root.Kind())
} }
rootNode := tr.store.getInternal(root.Index())
// Replace right child with a zero-hash HashedNode. nodeResolver // Replace right child with a zero-hash HashedNode. nodeResolver
// short-circuits on common.Hash{} and returns (nil, nil), which // short-circuits on common.Hash{} and returns (nil, nil), which
// triggers the nil-data guard in the iterator. // triggers the nil-data guard in the iterator.
root.right = HashedNode(common.Hash{}) rootNode.right = tr.store.newHashedRef(common.Hash{})
// Should not panic; the zero-hash right child should be treated as Empty. // Should not panic; the zero-hash right child should be treated as Empty.
if leaves := countLeaves(t, tr); leaves != 1 { // Since the hashed node can't be resolved (nil data -> empty deserialization),
// only the left leaf should be counted.
it, err := newBinaryNodeIterator(tr, nil)
if err != nil {
t.Fatal(err)
}
leaves := 0
for it.Next(true) {
if it.Leaf() {
leaves++
}
}
if leaves != 1 {
t.Fatalf("expected 1 leaf (zero-hash right node skipped), got %d", leaves) t.Fatalf("expected 1 leaf (zero-hash right node skipped), got %d", leaves)
} }
} }

56
trie/bintrie/node_ref.go Normal file
View file

@ -0,0 +1,56 @@
// Copyright 2026 go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package bintrie
// nodeKind identifies the type of a trie node stored in a nodeRef.
// It is the value carried in the 2-bit tag of a nodeRef, so at most four
// distinct kinds can be represented.
type nodeKind uint8
// Node kinds, numbered from zero by iota. kindEmpty is the zero value, so a
// zero nodeRef decodes to kindEmpty.
const (
kindEmpty nodeKind = iota // no node; the tag of emptyRef
kindInternal // branch node whose children are nodeRefs
kindStem // up to 256 values per stem
kindHashed // node known only by its hash until it is resolved
)
// nodeRef is a compact, GC-invisible reference to a node in a nodeStore.
// It packs a 2-bit type tag (bits 31-30) and a 30-bit index (bits 29-0)
// into a single uint32. Because nodeRef contains no Go pointers, slices
// of structs containing nodeRef fields are allocated in noscan spans —
// the garbage collector never examines them.
//
// The index is only meaningful together with the tag: it addresses a slot in
// the pool that the nodeStore keeps for that kind.
type nodeRef uint32
const (
// kindShift is the bit position at which the kind tag starts; everything
// below it is index.
kindShift uint32 = 30
// indexMask selects the index bits of a nodeRef. It is also the largest
// index makeRef accepts.
indexMask uint32 = (1 << kindShift) - 1
// emptyRef represents an empty node. It is the zero value of nodeRef
// (kindEmpty tag, index 0).
emptyRef nodeRef = 0
)
// makeRef packs kind and idx into a nodeRef: kind is stored in the bits at and
// above kindShift, idx in the bits below it.
//
// It panics if idx does not fit in the index bits or if kind does not fit in
// the tag bits. Without the kind check an out-of-range kind would be shifted
// out of the uint32 and silently produce a reference carrying the wrong tag.
func makeRef(kind nodeKind, idx uint32) nodeRef {
	// Largest tag value that survives the shift (3 for a 2-bit tag).
	const maxKind = ^uint32(0) >> kindShift
	if uint32(kind) > maxKind {
		panic("nodeRef kind overflow")
	}
	if idx > indexMask {
		panic("nodeRef index overflow")
	}
	return nodeRef(uint32(kind)<<kindShift | idx)
}
// Kind extracts the node-kind tag stored in the bits of r at and above
// kindShift.
func (r nodeRef) Kind() nodeKind {
	tag := uint32(r) >> kindShift
	return nodeKind(tag)
}
// Index returns the position of the referenced node inside the pool for its
// kind, i.e. r with the kind tag masked off.
func (r nodeRef) Index() uint32 {
	raw := uint32(r)
	return raw & indexMask
}
// IsEmpty reports whether r is tagged kindEmpty. Only the tag is inspected;
// the index bits are ignored.
func (r nodeRef) IsEmpty() bool {
	kind := r.Kind()
	return kind == kindEmpty
}

184
trie/bintrie/node_store.go Normal file
View file

@ -0,0 +1,184 @@
// Copyright 2026 go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package bintrie
import "github.com/ethereum/go-ethereum/common"
// storeChunkSize is the number of nodes per chunk in each typed pool.
// A pool index idx addresses chunk idx/storeChunkSize at offset
// idx%storeChunkSize.
const storeChunkSize = 4096
// nodeStore is a GC-friendly arena for binary trie nodes. Nodes are packed
// into typed chunked pools so pointer-free types (InternalNode, HashedNode)
// land in noscan spans the GC skips entirely.
//
// Each pool is a slice of pointers to fixed-size chunks. Growing a pool
// appends a new chunk and never moves existing ones, so a pointer obtained
// from a get* accessor stays valid across later allocations.
type nodeStore struct {
// InternalNode pool. internalCount is the number of slots handed out so
// far, which is also the index the next allocation will return.
internalChunks []*[storeChunkSize]InternalNode
internalCount uint32
// StemNode pool; stemCount has the same meaning as internalCount.
stemChunks []*[storeChunkSize]StemNode
stemCount uint32
// HashedNode pool. hashedCount counts every slot ever taken from the
// chunks, including slots currently parked on freeHashed.
hashedChunks []*[storeChunkSize]HashedNode
hashedCount uint32
// root references the trie's root node; emptyRef for a new store.
root nodeRef
// Free list for recycling hashed-node slots after resolve. Internal and
// stem nodes are never freed under current semantics (no delete path,
// stem-split keeps the old stem at a deeper position), so they don't
// have free lists.
freeHashed []uint32
}
// newNodeStore returns a store with no allocated chunks whose root is the
// empty reference.
func newNodeStore() *nodeStore {
	store := new(nodeStore)
	store.root = emptyRef
	return store
}
// allocInternal reserves the next slot in the InternalNode pool and returns
// its index, appending a new chunk when the slot falls in a chunk that does
// not exist yet. The caller is responsible for initialising the slot.
//
// It panics with "internal node pool overflow" once internalCount has reached
// indexMask. The capacity check runs before anything is modified, so a
// recovered panic leaves the store exactly as it was (no chunk appended, no
// counter bumped).
func (s *nodeStore) allocInternal() uint32 {
	idx := s.internalCount
	if idx >= indexMask {
		panic("internal node pool overflow")
	}
	if chunkIdx := idx / storeChunkSize; uint32(len(s.internalChunks)) <= chunkIdx {
		s.internalChunks = append(s.internalChunks, new([storeChunkSize]InternalNode))
	}
	s.internalCount = idx + 1
	return idx
}
// getInternal returns a pointer to the InternalNode slot at pool index idx.
// The pointer refers to the slot inside its chunk, so writes through it
// modify the store. idx must address an existing chunk; otherwise the lookup
// panics.
func (s *nodeStore) getInternal(idx uint32) *InternalNode {
	chunk := s.internalChunks[idx/storeChunkSize]
	return &chunk[idx%storeChunkSize]
}
// newInternalRef allocates an InternalNode at the given depth, marks it as
// needing both a re-hash (mustRecompute) and a re-flush (dirty), and returns
// a kindInternal reference to it.
//
// depth must lie in [0, 248] (248 corresponds to 31 stem bytes of 8 bits).
// Out-of-range values panic before any slot is allocated; the lower bound
// matters because depth is stored as a uint8 and a negative value would
// otherwise wrap to a large, valid-looking depth.
func (s *nodeStore) newInternalRef(depth int) nodeRef {
	if depth < 0 {
		panic("negative node depth")
	}
	if depth > 248 {
		panic("node depth exceeds maximum binary trie depth")
	}
	idx := s.allocInternal()
	n := s.getInternal(idx)
	n.depth = uint8(depth)
	n.mustRecompute = true
	n.dirty = true
	return makeRef(kindInternal, idx)
}
// allocStem reserves the next slot in the StemNode pool and returns its
// index, appending a new chunk when the slot falls in a chunk that does not
// exist yet. The caller is responsible for initialising the slot.
//
// It panics with "stem node pool overflow" once stemCount has reached
// indexMask. The capacity check runs before anything is modified, so a
// recovered panic leaves the store exactly as it was (no chunk appended, no
// counter bumped).
func (s *nodeStore) allocStem() uint32 {
	idx := s.stemCount
	if idx >= indexMask {
		panic("stem node pool overflow")
	}
	if chunkIdx := idx / storeChunkSize; uint32(len(s.stemChunks)) <= chunkIdx {
		s.stemChunks = append(s.stemChunks, new([storeChunkSize]StemNode))
	}
	s.stemCount = idx + 1
	return idx
}
// getStem returns a pointer to the StemNode slot at pool index idx. The
// pointer refers to the slot inside its chunk, so writes through it modify
// the store. idx must address an existing chunk; otherwise the lookup panics.
func (s *nodeStore) getStem(idx uint32) *StemNode {
	chunk := s.stemChunks[idx/storeChunkSize]
	return &chunk[idx%storeChunkSize]
}
// newStemRef allocates a StemNode at the given depth, copies the first
// StemSize bytes of stem into it, marks it as needing both a re-hash
// (mustRecompute) and a re-flush (dirty), and returns a kindStem reference.
// Bytes of stem beyond StemSize are ignored; the node does not alias stem.
//
// All arguments are validated before a slot is allocated, so a recovered
// panic does not leak a pool slot:
//   - depth must lie in [0, 248]; the lower bound matters because depth is
//     stored as a uint8 and a negative value would otherwise wrap.
//   - stem must hold at least StemSize bytes.
func (s *nodeStore) newStemRef(stem []byte, depth int) nodeRef {
	if depth < 0 {
		panic("negative node depth")
	}
	if depth > 248 {
		panic("node depth exceeds maximum binary trie depth")
	}
	if len(stem) < StemSize {
		panic("stem shorter than StemSize")
	}
	idx := s.allocStem()
	sn := s.getStem(idx)
	copy(sn.Stem[:], stem[:StemSize])
	sn.depth = uint8(depth)
	sn.mustRecompute = true
	sn.dirty = true
	return makeRef(kindStem, idx)
}
// allocHashed returns the index of a HashedNode slot that is free to use.
//
// A slot from the free list is preferred: the most recently freed index is
// popped and its contents are reset to the zero HashedNode. Only when the
// free list is empty is a fresh slot taken from the pool, appending a new
// chunk if needed.
//
// It panics with "hashed node pool overflow" once hashedCount has reached
// indexMask and no recycled slot is available. The capacity check runs before
// anything is modified, so a recovered panic leaves the store exactly as it
// was.
func (s *nodeStore) allocHashed() uint32 {
	if n := len(s.freeHashed); n > 0 {
		idx := s.freeHashed[n-1]
		s.freeHashed = s.freeHashed[:n-1]
		// Clear the recycled slot so the previous occupant's hash cannot
		// be observed through the new reference.
		*s.getHashed(idx) = HashedNode{}
		return idx
	}
	idx := s.hashedCount
	if idx >= indexMask {
		panic("hashed node pool overflow")
	}
	if chunkIdx := idx / storeChunkSize; uint32(len(s.hashedChunks)) <= chunkIdx {
		s.hashedChunks = append(s.hashedChunks, new([storeChunkSize]HashedNode))
	}
	s.hashedCount = idx + 1
	return idx
}
// getHashed returns a pointer to the HashedNode slot at pool index idx. The
// pointer refers to the slot inside its chunk, so writes through it modify
// the store. idx must address an existing chunk; otherwise the lookup panics.
func (s *nodeStore) getHashed(idx uint32) *HashedNode {
	chunk := s.hashedChunks[idx/storeChunkSize]
	return &chunk[idx%storeChunkSize]
}
// freeHashedNode puts the HashedNode slot idx on the free list so that a
// later allocHashed call can hand it out again. The slot's contents are left
// untouched here, and idx is not checked against the pool or against entries
// already on the list.
func (s *nodeStore) freeHashedNode(idx uint32) {
	recycled := append(s.freeHashed, idx)
	s.freeHashed = recycled
}
// newHashedRef stores hash in a HashedNode slot obtained from allocHashed and
// returns a kindHashed reference to that slot.
func (s *nodeStore) newHashedRef(hash common.Hash) nodeRef {
	slot := s.allocHashed()
	node := s.getHashed(slot)
	*node = HashedNode(hash)
	return makeRef(kindHashed, slot)
}
// Copy returns an independent duplicate of the store. Every chunk of every
// pool is copied by value, the hashed free list is duplicated, and the root
// reference and slot counters are carried over unchanged, so all nodeRefs
// valid in s address the equivalent node in the copy.
func (s *nodeStore) Copy() *nodeStore {
	dup := &nodeStore{
		root:           s.root,
		internalCount:  s.internalCount,
		stemCount:      s.stemCount,
		hashedCount:    s.hashedCount,
		internalChunks: cloneStoreChunks(s.internalChunks),
		stemChunks:     cloneStoreChunks(s.stemChunks),
		hashedChunks:   cloneStoreChunks(s.hashedChunks),
	}

	// The chunk copy above duplicates only the slice headers held in each
	// stem, so both stores would still share the value bytes (which may alias
	// serialized buffers). Give every present value its own backing array.
	for idx := uint32(0); idx < s.stemCount; idx++ {
		from, to := s.getStem(idx), dup.getStem(idx)
		for slot := range from.values {
			val := from.values[slot]
			if val != nil {
				to.values[slot] = append(make([]byte, 0, len(val)), val...)
			}
		}
	}

	if n := len(s.freeHashed); n > 0 {
		dup.freeHashed = make([]uint32, n)
		copy(dup.freeHashed, s.freeHashed)
	}
	return dup
}

// cloneStoreChunks returns a new chunk list in which every chunk is a by-value
// copy of the corresponding chunk in chunks. The result is non-nil even when
// chunks is empty.
func cloneStoreChunks[T any](chunks []*[storeChunkSize]T) []*[storeChunkSize]T {
	out := make([]*[storeChunkSize]T, len(chunks))
	for i := range chunks {
		fresh := new([storeChunkSize]T)
		*fresh = *chunks[i]
		out[i] = fresh
	}
	return out
}

View file

@ -17,236 +17,93 @@
package bintrie package bintrie
import ( import (
"bytes" "crypto/sha256"
"errors"
"fmt"
"slices"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
) )
// StemNode represents a group of `NodeWith` values sharing the same stem. // StemNode holds up to 256 values sharing a 31-byte stem.
//
// Invariant: dirty=false implies mustRecompute=false. Every mutation that
// invalidates the cached hash MUST also mark the blob for re-flush.
type StemNode struct { type StemNode struct {
Stem []byte // Stem path to get to StemNodeWidth values Stem [StemSize]byte
Values [][]byte // All values, indexed by the last byte of the key. values [StemNodeWidth][]byte // nil == slot absent
depth int // Depth of the node
mustRecompute bool // true if the hash needs to be recomputed depth uint8
dirty bool // true if the node's on-disk blob is stale (needs flush)
mustRecompute bool // hash is stale (cleared by Hash)
dirty bool // on-disk blob is stale (cleared by CollectNodes)
hash common.Hash // cached hash when mustRecompute == false hash common.Hash // cached hash when mustRecompute == false
} }
// Get retrieves the value for the given key. func (sn *StemNode) getValue(suffix byte) []byte {
func (bt *StemNode) Get(key []byte, _ NodeResolverFn) ([]byte, error) { return sn.values[suffix]
if !bytes.Equal(bt.Stem, key[:StemSize]) {
return nil, nil
}
return bt.Values[key[StemSize]], nil
} }
// Insert inserts a new key-value pair into the node. func (sn *StemNode) hasValue(suffix byte) bool {
func (bt *StemNode) Insert(key []byte, value []byte, _ NodeResolverFn, depth int) (BinaryNode, error) { return sn.values[suffix] != nil
if !bytes.Equal(bt.Stem, key[:StemSize]) {
bitStem := bt.Stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1
n := &InternalNode{depth: bt.depth, mustRecompute: true, dirty: true}
bt.depth++
// bt is re-parented under n and sits at a new path — rewrite its blob.
bt.mustRecompute = true
bt.dirty = true
var child, other *BinaryNode
if bitStem == 0 {
n.left = bt
child = &n.left
other = &n.right
} else {
n.right = bt
child = &n.right
other = &n.left
}
bitKey := key[n.depth/8] >> (7 - (n.depth % 8)) & 1
if bitKey == bitStem {
var err error
*child, err = (*child).Insert(key, value, nil, depth+1)
if err != nil {
return n, fmt.Errorf("insert error: %w", err)
}
*other = Empty{}
} else {
var values [StemNodeWidth][]byte
values[key[StemSize]] = value
*other = &StemNode{
Stem: slices.Clone(key[:StemSize]),
Values: values[:],
depth: depth + 1,
mustRecompute: true,
dirty: true,
}
}
return n, nil
}
if len(value) != HashSize {
return bt, errors.New("invalid insertion: value length")
}
bt.Values[key[StemSize]] = value
bt.mustRecompute = true
bt.dirty = true
return bt, nil
} }
// Copy creates a deep copy of the node. // allValues returns the underlying slot array as a slice. nil entries mean
func (bt *StemNode) Copy() BinaryNode { // absent. Callers must treat it as read-only.
var values [StemNodeWidth][]byte func (sn *StemNode) allValues() [][]byte {
for i, v := range bt.Values { return sn.values[:]
values[i] = slices.Clone(v)
}
return &StemNode{
Stem: slices.Clone(bt.Stem),
Values: values[:],
depth: bt.depth,
hash: bt.hash,
mustRecompute: bt.mustRecompute,
dirty: bt.dirty,
}
} }
// GetHeight returns the height of the node. // setValue mutates a value slot and marks the stem for re-hash and
func (bt *StemNode) GetHeight() int { // re-flush. This is the only API for post-load value mutation; direct
return 1 // values[...] writes are reserved for the on-disk load path in
// decodeNode, which must leave mustRecompute/dirty at their loaded
// state.
func (sn *StemNode) setValue(suffix byte, value []byte) {
sn.values[suffix] = value
sn.mustRecompute = true
sn.dirty = true
} }
// Hash returns the hash of the node. func (sn *StemNode) Hash() common.Hash {
func (bt *StemNode) Hash() common.Hash { if !sn.mustRecompute {
if !bt.mustRecompute { return sn.hash
return bt.hash
} }
// Use sha256.Sum256 (returns [32]byte by value) instead of a pooled
// hash.Hash: feeding data[i][:0] into the interface method Sum forces
// data to heap (escape analysis is conservative through interfaces).
// Sum256 takes []byte and returns by value, so data stays on stack.
var data [StemNodeWidth]common.Hash var data [StemNodeWidth]common.Hash
h := newSha256()
defer returnSha256(h) for i, v := range sn.values {
for i, v := range bt.Values {
if v != nil { if v != nil {
h.Reset() data[i] = sha256.Sum256(v)
h.Write(v)
h.Sum(data[i][:0])
} }
} }
h.Reset()
var pair [2 * HashSize]byte
for level := 1; level <= 8; level++ { for level := 1; level <= 8; level++ {
for i := range StemNodeWidth / (1 << level) { for i := range StemNodeWidth / (1 << level) {
h.Reset()
if data[i*2] == (common.Hash{}) && data[i*2+1] == (common.Hash{}) { if data[i*2] == (common.Hash{}) && data[i*2+1] == (common.Hash{}) {
data[i] = common.Hash{} data[i] = common.Hash{}
continue continue
} }
copy(pair[:HashSize], data[i*2][:])
h.Write(data[i*2][:]) copy(pair[HashSize:], data[i*2+1][:])
h.Write(data[i*2+1][:]) data[i] = sha256.Sum256(pair[:])
data[i] = common.Hash(h.Sum(nil))
} }
} }
h.Reset() var final [StemSize + 1 + HashSize]byte
h.Write(bt.Stem) copy(final[:StemSize], sn.Stem[:])
h.Write([]byte{0}) final[StemSize] = 0
h.Write(data[0][:]) copy(final[StemSize+1:], data[0][:])
bt.hash = common.BytesToHash(h.Sum(nil)) sn.hash = sha256.Sum256(final[:])
bt.mustRecompute = false sn.mustRecompute = false
return bt.hash return sn.hash
} }
// CollectNodes flushes the stem via the collector when dirty; clean stems func (sn *StemNode) Key(i int) []byte {
// are skipped.
func (bt *StemNode) CollectNodes(path []byte, flush NodeFlushFn) error {
if !bt.dirty {
return nil
}
flush(path, bt)
bt.dirty = false
return nil
}
// GetValuesAtStem retrieves the group of values located at the given stem key.
func (bt *StemNode) GetValuesAtStem(stem []byte, _ NodeResolverFn) ([][]byte, error) {
if !bytes.Equal(bt.Stem, stem) {
return nil, nil
}
return bt.Values[:], nil
}
// InsertValuesAtStem inserts a full value group at the given stem in the internal node.
// Already-existing values will be overwritten.
func (bt *StemNode) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolverFn, depth int) (BinaryNode, error) {
if !bytes.Equal(bt.Stem, key[:StemSize]) {
bitStem := bt.Stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1
n := &InternalNode{depth: bt.depth, mustRecompute: true, dirty: true}
bt.depth++
// bt is re-parented under n and sits at a new path — rewrite its blob.
bt.mustRecompute = true
bt.dirty = true
var child, other *BinaryNode
if bitStem == 0 {
n.left = bt
child = &n.left
other = &n.right
} else {
n.right = bt
child = &n.right
other = &n.left
}
bitKey := key[n.depth/8] >> (7 - (n.depth % 8)) & 1
if bitKey == bitStem {
var err error
*child, err = (*child).InsertValuesAtStem(key, values, nil, depth+1)
if err != nil {
return n, fmt.Errorf("insert error: %w", err)
}
*other = Empty{}
} else {
*other = &StemNode{
Stem: slices.Clone(key[:StemSize]),
Values: values,
depth: n.depth + 1,
mustRecompute: true,
dirty: true,
}
}
return n, nil
}
// same stem, just merge the two value lists
for i, v := range values {
if v != nil {
bt.Values[i] = v
bt.mustRecompute = true
bt.dirty = true
}
}
return bt, nil
}
func (bt *StemNode) toDot(parent, path string) string {
me := fmt.Sprintf("stem%s", path)
ret := fmt.Sprintf("%s [label=\"stem=%x c=%x\"]\n", me, bt.Stem, bt.Hash())
ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me)
for i, v := range bt.Values {
if v != nil {
ret = fmt.Sprintf("%s%s%x [label=\"%x\"]\n", ret, me, i, v)
ret = fmt.Sprintf("%s%s -> %s%x\n", ret, me, me, i)
}
}
return ret
}
// Key returns the full key for the given index.
func (bt *StemNode) Key(i int) []byte {
var ret [HashSize]byte var ret [HashSize]byte
copy(ret[:], bt.Stem) copy(ret[:], sn.Stem[:])
ret[StemSize] = byte(i) ret[StemSize] = byte(i)
return ret[:] return ret[:]
} }

View file

@ -23,165 +23,99 @@ import (
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
) )
// TestStemNodeGet tests the Get method for matching stem, non-matching stem, // TestStemNodeInsertSameStem tests inserting values with the same stem via nodeStore.
// and nil-value suffix scenarios.
func TestStemNodeGet(t *testing.T) {
stem := make([]byte, StemSize)
stem[0] = 0xAB
var values [StemNodeWidth][]byte
values[5] = common.HexToHash("0xdeadbeef").Bytes()
node := &StemNode{Stem: stem, Values: values[:], depth: 0}
// Matching stem, populated suffix → returns value.
key := make([]byte, HashSize)
copy(key[:StemSize], stem)
key[StemSize] = 5
got, err := node.Get(key, nil)
if err != nil {
t.Fatalf("Get error: %v", err)
}
if !bytes.Equal(got, values[5]) {
t.Fatalf("Get = %x, want %x", got, values[5])
}
// Matching stem, empty suffix → returns nil (slot not set).
key[StemSize] = 99
got, err = node.Get(key, nil)
if err != nil {
t.Fatalf("Get error: %v", err)
}
if got != nil {
t.Fatalf("Get(empty suffix) = %x, want nil", got)
}
// Non-matching stem → returns nil, nil.
otherKey := make([]byte, HashSize)
otherKey[0] = 0xFF
got, err = node.Get(otherKey, nil)
if err != nil {
t.Fatalf("Get error: %v", err)
}
if got != nil {
t.Fatalf("Get(wrong stem) = %x, want nil", got)
}
}
// TestStemNodeInsertSameStem tests inserting values with the same stem
func TestStemNodeInsertSameStem(t *testing.T) { func TestStemNodeInsertSameStem(t *testing.T) {
s := newNodeStore()
stem := make([]byte, 31) stem := make([]byte, 31)
for i := range stem { for i := range stem {
stem[i] = byte(i) stem[i] = byte(i)
} }
var values [256][]byte // Insert first value
values[0] = common.HexToHash("0x0101").Bytes() key1 := make([]byte, 32)
copy(key1[:31], stem)
node := &StemNode{ key1[31] = 0
Stem: stem, value1 := common.HexToHash("0x0101").Bytes()
Values: values[:], if err := s.Insert(key1, value1, nil); err != nil {
depth: 0, t.Fatal(err)
} }
// Insert another value with the same stem but different last byte // Insert another value with the same stem but different last byte
key := make([]byte, 32) key2 := make([]byte, 32)
copy(key[:31], stem) copy(key2[:31], stem)
key[31] = 10 key2[31] = 10
value := common.HexToHash("0x0202").Bytes() value2 := common.HexToHash("0x0202").Bytes()
if err := s.Insert(key2, value2, nil); err != nil {
newNode, err := node.Insert(key, value, nil, 0) t.Fatal(err)
if err != nil {
t.Fatalf("Failed to insert: %v", err)
} }
// Should still be a StemNode // Root should still be a StemNode
stemNode, ok := newNode.(*StemNode) if s.root.Kind() != kindStem {
if !ok { t.Fatalf("Expected kindStem root, got kind %d", s.root.Kind())
t.Fatalf("Expected StemNode, got %T", newNode)
} }
// Check that both values are present // Check that both values are present
if !bytes.Equal(stemNode.Values[0], values[0]) { v1, _ := s.Get(key1, nil)
if !bytes.Equal(v1, value1) {
t.Errorf("Value at index 0 mismatch") t.Errorf("Value at index 0 mismatch")
} }
if !bytes.Equal(stemNode.Values[10], value) { v2, _ := s.Get(key2, nil)
if !bytes.Equal(v2, value2) {
t.Errorf("Value at index 10 mismatch") t.Errorf("Value at index 10 mismatch")
} }
} }
// TestStemNodeInsertDifferentStem tests inserting values with different stems // TestStemNodeInsertDifferentStem tests inserting values with different stems via nodeStore.
func TestStemNodeInsertDifferentStem(t *testing.T) { func TestStemNodeInsertDifferentStem(t *testing.T) {
stem1 := make([]byte, 31) s := newNodeStore()
for i := range stem1 {
stem1[i] = 0x00
}
var values [256][]byte // Insert first value with stem of all zeros
values[0] = common.HexToHash("0x0101").Bytes() key1 := make([]byte, 32)
key1[31] = 0
node := &StemNode{ value1 := common.HexToHash("0x0101").Bytes()
Stem: stem1, if err := s.Insert(key1, value1, nil); err != nil {
Values: values[:], t.Fatal(err)
depth: 0,
} }
// Insert with a different stem (first bit different) // Insert with a different stem (first bit different)
key := make([]byte, 32) key2 := make([]byte, 32)
key[0] = 0x80 // First bit is 1 instead of 0 key2[0] = 0x80 // First bit is 1 instead of 0
value := common.HexToHash("0x0202").Bytes() value2 := common.HexToHash("0x0202").Bytes()
if err := s.Insert(key2, value2, nil); err != nil {
newNode, err := node.Insert(key, value, nil, 0) t.Fatal(err)
if err != nil {
t.Fatalf("Failed to insert: %v", err)
} }
// Should now be an InternalNode // Should now be an InternalNode
internalNode, ok := newNode.(*InternalNode) if s.root.Kind() != kindInternal {
if !ok { t.Fatalf("Expected kindInternal root, got kind %d", s.root.Kind())
t.Fatalf("Expected InternalNode, got %T", newNode)
} }
// Check depth // Check depth
if internalNode.depth != 0 { rootNode := s.getInternal(s.root.Index())
t.Errorf("Expected depth 0, got %d", internalNode.depth) if rootNode.depth != 0 {
t.Errorf("Expected depth 0, got %d", rootNode.depth)
} }
// Original stem should be on the left (bit 0) // Verify both values are retrievable
leftStem, ok := internalNode.left.(*StemNode) v1, _ := s.Get(key1, nil)
if !ok { if !bytes.Equal(v1, value1) {
t.Fatalf("Expected left child to be StemNode, got %T", internalNode.left) t.Error("Value 1 mismatch")
} }
if !bytes.Equal(leftStem.Stem, stem1) { v2, _ := s.Get(key2, nil)
t.Errorf("Left stem mismatch") if !bytes.Equal(v2, value2) {
} t.Error("Value 2 mismatch")
// New stem should be on the right (bit 1)
rightStem, ok := internalNode.right.(*StemNode)
if !ok {
t.Fatalf("Expected right child to be StemNode, got %T", internalNode.right)
}
if !bytes.Equal(rightStem.Stem, key[:31]) {
t.Errorf("Right stem mismatch")
} }
} }
// TestStemNodeInsertInvalidValueLength tests inserting value with invalid length // TestStemNodeInsertInvalidValueLength tests inserting value with invalid length via nodeStore.
func TestStemNodeInsertInvalidValueLength(t *testing.T) { func TestStemNodeInsertInvalidValueLength(t *testing.T) {
stem := make([]byte, 31) s := newNodeStore()
var values [256][]byte
node := &StemNode{
Stem: stem,
Values: values[:],
depth: 0,
}
// Try to insert value with wrong length
key := make([]byte, 32) key := make([]byte, 32)
copy(key[:31], stem)
invalidValue := []byte{1, 2, 3} // Not 32 bytes invalidValue := []byte{1, 2, 3} // Not 32 bytes
_, err := node.Insert(key, invalidValue, nil, 0) err := s.Insert(key, invalidValue, nil)
if err == nil { if err == nil {
t.Fatal("Expected error for invalid value length") t.Fatal("Expected error for invalid value length")
} }
@ -191,221 +125,209 @@ func TestStemNodeInsertInvalidValueLength(t *testing.T) {
} }
} }
// TestStemNodeCopy tests the Copy method // TestStemNodeCopy tests the Copy method via nodeStore.
func TestStemNodeCopy(t *testing.T) { func TestStemNodeCopy(t *testing.T) {
stem := make([]byte, 31) s := newNodeStore()
for i := range stem {
stem[i] = byte(i) key1 := make([]byte, 32)
for i := range 31 {
key1[i] = byte(i)
}
key1[31] = 0
value1 := common.HexToHash("0x0101").Bytes()
key2 := make([]byte, 32)
copy(key2[:31], key1[:31])
key2[31] = 255
value2 := common.HexToHash("0x0202").Bytes()
if err := s.Insert(key1, value1, nil); err != nil {
t.Fatal(err)
}
if err := s.Insert(key2, value2, nil); err != nil {
t.Fatal(err)
} }
var values [256][]byte ns := s.Copy()
values[0] = common.HexToHash("0x0101").Bytes()
values[255] = common.HexToHash("0x0202").Bytes()
node := &StemNode{ // Check that values are equal
Stem: stem, v1, _ := ns.Get(key1, nil)
Values: values[:], if !bytes.Equal(v1, value1) {
depth: 10,
}
// Create a copy
copied := node.Copy()
copiedStem, ok := copied.(*StemNode)
if !ok {
t.Fatalf("Expected StemNode, got %T", copied)
}
// Check that values are equal but not the same slice
if !bytes.Equal(copiedStem.Stem, node.Stem) {
t.Errorf("Stem mismatch after copy")
}
if &copiedStem.Stem[0] == &node.Stem[0] {
t.Error("Stem slice not properly cloned")
}
// Check values
if !bytes.Equal(copiedStem.Values[0], node.Values[0]) {
t.Errorf("Value at index 0 mismatch after copy") t.Errorf("Value at index 0 mismatch after copy")
} }
if !bytes.Equal(copiedStem.Values[255], node.Values[255]) { v2, _ := ns.Get(key2, nil)
if !bytes.Equal(v2, value2) {
t.Errorf("Value at index 255 mismatch after copy") t.Errorf("Value at index 255 mismatch after copy")
} }
// Check that value slices are cloned
if copiedStem.Values[0] != nil && &copiedStem.Values[0][0] == &node.Values[0][0] {
t.Error("Value slice not properly cloned")
}
// Check depth
if copiedStem.depth != node.depth {
t.Errorf("Depth mismatch: expected %d, got %d", node.depth, copiedStem.depth)
}
} }
// TestStemNodeHash tests the Hash method // TestStemNodeHash tests the Hash method.
func TestStemNodeHash(t *testing.T) { func TestStemNodeHash(t *testing.T) {
stem := make([]byte, 31) s := newNodeStore()
var values [256][]byte
values[0] = common.HexToHash("0x0101").Bytes()
node := &StemNode{ key := make([]byte, 32)
Stem: stem, key[31] = 0
Values: values[:], value := common.HexToHash("0x0101").Bytes()
depth: 0, if err := s.Insert(key, value, nil); err != nil {
t.Fatal(err)
} }
hash1 := node.Hash() hash1 := s.computeHash(s.root)
// Hash should be deterministic // Hash should be deterministic
hash2 := node.Hash() hash2 := s.computeHash(s.root)
if hash1 != hash2 { if hash1 != hash2 {
t.Errorf("Hash not deterministic: %x != %x", hash1, hash2) t.Errorf("Hash not deterministic: %x != %x", hash1, hash2)
} }
// Changing a value should change the hash // Changing a value should change the hash
node.Values[1] = common.HexToHash("0x0202").Bytes() key2 := make([]byte, 32)
node.mustRecompute = true key2[31] = 1
hash3 := node.Hash() value2 := common.HexToHash("0x0202").Bytes()
if err := s.Insert(key2, value2, nil); err != nil {
t.Fatal(err)
}
hash3 := s.computeHash(s.root)
if hash1 == hash3 { if hash1 == hash3 {
t.Error("Hash didn't change after modifying values") t.Error("Hash didn't change after modifying values")
} }
} }
// TestStemNodeGetValuesAtStem tests GetValuesAtStem method // TestStemNodeGetValuesAtStem tests GetValuesAtStem method via nodeStore.
func TestStemNodeGetValuesAtStem(t *testing.T) { func TestStemNodeGetValuesAtStem(t *testing.T) {
s := newNodeStore()
stem := make([]byte, 31) stem := make([]byte, 31)
for i := range stem { for i := range stem {
stem[i] = byte(i) stem[i] = byte(i)
} }
var values [256][]byte values := make([][]byte, 256)
values[0] = common.HexToHash("0x0101").Bytes() values[0] = common.HexToHash("0x0101").Bytes()
values[10] = common.HexToHash("0x0202").Bytes() values[10] = common.HexToHash("0x0202").Bytes()
values[255] = common.HexToHash("0x0303").Bytes() values[255] = common.HexToHash("0x0303").Bytes()
node := &StemNode{ if err := s.InsertValuesAtStem(stem, values, nil); err != nil {
Stem: stem, t.Fatal(err)
Values: values[:],
depth: 0,
} }
// GetValuesAtStem with matching stem // GetValuesAtStem with matching stem
retrievedValues, err := node.GetValuesAtStem(stem, nil) retrievedValues, err := s.GetValuesAtStem(stem, nil)
if err != nil { if err != nil {
t.Fatalf("Failed to get values: %v", err) t.Fatalf("Failed to get values: %v", err)
} }
// Check that all values match if !bytes.Equal(retrievedValues[0], values[0]) {
for i := range 256 { t.Error("Value at index 0 mismatch")
if !bytes.Equal(retrievedValues[i], values[i]) { }
t.Errorf("Value mismatch at index %d", i) if !bytes.Equal(retrievedValues[10], values[10]) {
} t.Error("Value at index 10 mismatch")
}
if !bytes.Equal(retrievedValues[255], values[255]) {
t.Error("Value at index 255 mismatch")
} }
// GetValuesAtStem with different stem should return nil // GetValuesAtStem with different stem should return nil values
differentStem := make([]byte, 31) differentStem := make([]byte, 31)
differentStem[0] = 0xFF differentStem[0] = 0xFF
shouldBeNil, err := node.GetValuesAtStem(differentStem, nil) shouldBeEmpty, err := s.GetValuesAtStem(differentStem, nil)
if err != nil { if err != nil {
t.Fatalf("Failed to get values with different stem: %v", err) t.Fatalf("Failed to get values with different stem: %v", err)
} }
if shouldBeNil != nil { allNil := true
t.Error("Expected nil for different stem, got non-nil") for _, v := range shouldBeEmpty {
if v != nil {
allNil = false
break
}
}
if !allNil {
t.Error("Expected all nil values for different stem")
} }
} }
// TestStemNodeInsertValuesAtStem tests InsertValuesAtStem method // TestStemNodeInsertValuesAtStem tests InsertValuesAtStem method via nodeStore.
func TestStemNodeInsertValuesAtStem(t *testing.T) { func TestStemNodeInsertValuesAtStem(t *testing.T) {
s := newNodeStore()
stem := make([]byte, 31) stem := make([]byte, 31)
var values [256][]byte values := make([][]byte, 256)
values[0] = common.HexToHash("0x0101").Bytes() values[0] = common.HexToHash("0x0101").Bytes()
node := &StemNode{ if err := s.InsertValuesAtStem(stem, values, nil); err != nil {
Stem: stem, t.Fatal(err)
Values: values[:],
depth: 0,
} }
// Insert new values at the same stem // Insert new values at the same stem
var newValues [256][]byte newValues := make([][]byte, 256)
newValues[1] = common.HexToHash("0x0202").Bytes() newValues[1] = common.HexToHash("0x0202").Bytes()
newValues[2] = common.HexToHash("0x0303").Bytes() newValues[2] = common.HexToHash("0x0303").Bytes()
newNode, err := node.InsertValuesAtStem(stem, newValues[:], nil, 0) if err := s.InsertValuesAtStem(stem, newValues, nil); err != nil {
if err != nil { t.Fatal(err)
t.Fatalf("Failed to insert values: %v", err)
}
stemNode, ok := newNode.(*StemNode)
if !ok {
t.Fatalf("Expected StemNode, got %T", newNode)
} }
// Check that all values are present // Check that all values are present
if !bytes.Equal(stemNode.Values[0], values[0]) { retrieved, err := s.GetValuesAtStem(stem, nil)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(retrieved[0], values[0]) {
t.Error("Original value at index 0 missing") t.Error("Original value at index 0 missing")
} }
if !bytes.Equal(stemNode.Values[1], newValues[1]) { if !bytes.Equal(retrieved[1], newValues[1]) {
t.Error("New value at index 1 missing") t.Error("New value at index 1 missing")
} }
if !bytes.Equal(stemNode.Values[2], newValues[2]) { if !bytes.Equal(retrieved[2], newValues[2]) {
t.Error("New value at index 2 missing") t.Error("New value at index 2 missing")
} }
} }
// TestStemNodeGetHeight tests GetHeight method // TestStemNodeGetHeight tests GetHeight method via nodeStore.
func TestStemNodeGetHeight(t *testing.T) { func TestStemNodeGetHeight(t *testing.T) {
node := &StemNode{ s := newNodeStore()
Stem: make([]byte, 31),
Values: make([][]byte, 256), key := make([]byte, 32)
depth: 0, value := common.HexToHash("0x01").Bytes()
if err := s.Insert(key, value, nil); err != nil {
t.Fatal(err)
} }
height := node.GetHeight() height := s.getHeight(s.root)
if height != 1 { if height != 1 {
t.Errorf("Expected height 1, got %d", height) t.Errorf("Expected height 1, got %d", height)
} }
} }
// TestStemNodeCollectNodes tests CollectNodes method // TestStemNodeCollectNodes tests CollectNodes method via nodeStore.
func TestStemNodeCollectNodes(t *testing.T) { func TestStemNodeCollectNodes(t *testing.T) {
s := newNodeStore()
stem := make([]byte, 31) stem := make([]byte, 31)
var values [256][]byte values := make([][]byte, 256)
values[0] = common.HexToHash("0x0101").Bytes() values[0] = common.HexToHash("0x0101").Bytes()
node := &StemNode{ if err := s.InsertValuesAtStem(stem, values, nil); err != nil {
Stem: stem, t.Fatal(err)
Values: values[:],
depth: 0,
dirty: true,
} }
var collectedPaths [][]byte var collectedPaths [][]byte
var collectedNodes []BinaryNode flushFn := func(path []byte, hash common.Hash, serialized []byte) {
flushFn := func(path []byte, n BinaryNode) {
// Make a copy of the path
pathCopy := make([]byte, len(path)) pathCopy := make([]byte, len(path))
copy(pathCopy, path) copy(pathCopy, path)
collectedPaths = append(collectedPaths, pathCopy) collectedPaths = append(collectedPaths, pathCopy)
collectedNodes = append(collectedNodes, n)
} }
err := node.CollectNodes([]byte{0, 1, 0}, flushFn) err := s.collectNodes(s.root, []byte{0, 1, 0}, flushFn)
if err != nil { if err != nil {
t.Fatalf("Failed to collect nodes: %v", err) t.Fatalf("Failed to collect nodes: %v", err)
} }
// Should have collected one node (itself) // Should have collected one node (itself)
if len(collectedNodes) != 1 { if len(collectedPaths) != 1 {
t.Errorf("Expected 1 collected node, got %d", len(collectedNodes)) t.Errorf("Expected 1 collected node, got %d", len(collectedPaths))
}
// Check that the collected node is the same
if collectedNodes[0] != node {
t.Error("Collected node doesn't match original")
} }
// Check the path // Check the path
@ -413,44 +335,3 @@ func TestStemNodeCollectNodes(t *testing.T) {
t.Errorf("Path mismatch: expected [0, 1, 0], got %v", collectedPaths[0]) t.Errorf("Path mismatch: expected [0, 1, 0], got %v", collectedPaths[0])
} }
} }
// TestStemNodeCollectNodesSkipsClean verifies that a clean stem is not
// flushed, and that flushing a dirty stem clears its dirty flag so that
// a subsequent CollectNodes on the same node is a no-op.
func TestStemNodeCollectNodesSkipsClean(t *testing.T) {
stem := make([]byte, 31)
node := &StemNode{
Stem: stem,
Values: make([][]byte, 256),
depth: 0,
}
var collected []BinaryNode
flushFn := func(_ []byte, n BinaryNode) { collected = append(collected, n) }
if err := node.CollectNodes([]byte{0}, flushFn); err != nil {
t.Fatalf("CollectNodes on clean stem: %v", err)
}
if len(collected) != 0 {
t.Fatalf("expected clean stem not to be flushed, got %d", len(collected))
}
node.dirty = true
if err := node.CollectNodes([]byte{0}, flushFn); err != nil {
t.Fatalf("CollectNodes on dirty stem: %v", err)
}
if len(collected) != 1 {
t.Fatalf("expected dirty stem to be flushed once, got %d", len(collected))
}
if node.dirty {
t.Errorf("stem dirty flag should be cleared after flush")
}
collected = nil
if err := node.CollectNodes([]byte{0}, flushFn); err != nil {
t.Fatalf("CollectNodes after flush: %v", err)
}
if len(collected) != 0 {
t.Errorf("expected no flush on clean stem, got %d", len(collected))
}
}

View file

@ -0,0 +1,310 @@
// Copyright 2026 go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package bintrie
import (
"crypto/sha256"
"errors"
"fmt"
"math/bits"
"runtime"
"sync"
"github.com/ethereum/go-ethereum/common"
)
// nodeFlushFn receives one node during collectNodes: the node's path, its
// hash, and its serialized blob. collectNodes reuses and mutates the path
// buffer between calls, so an implementation that retains path must copy it.
type nodeFlushFn func(path []byte, hash common.Hash, serialized []byte)
// Hash returns the hash of the node currently referenced by s.root, as
// produced by computeHash.
func (s *nodeStore) Hash() common.Hash {
	root := s.root
	return s.computeHash(root)
}
// computeHash returns the hash of the node behind ref, dispatching on the
// reference kind. Empty references, and any kind not handled below, hash to
// the zero value.
func (s *nodeStore) computeHash(ref nodeRef) common.Hash {
	var digest common.Hash
	switch ref.Kind() {
	case kindInternal:
		digest = s.hashInternal(ref.Index())
	case kindStem:
		digest = s.getStem(ref.Index()).Hash()
	case kindHashed:
		digest = s.getHashed(ref.Index()).Hash()
	}
	return digest
}
// parallelHashDepth is the tree depth below which hashInternal spawns
// goroutines for shallow-depth parallelism. Its value is the bit length of
// runtime.NumCPU(), capped at 8 (for example 8 CPUs gives 4, so internal
// nodes at depths 0..3 fork their left child). Computed once at init because
// NumCPU() never changes after startup.
var parallelHashDepth = min(bits.Len(uint(runtime.NumCPU())), 8)
// hashInternal hashes the InternalNode at pool index idx and caches the
// result. The hash is sha256(leftHash || rightHash), where an empty child
// contributes 32 zero bytes. If the node's mustRecompute flag is clear the
// cached hash is returned without touching the children.
//
// At shallow depths (< parallelHashDepth) the left subtree is hashed in a
// goroutine while the right subtree is hashed inline, then the two digests
// are combined. Below that threshold the goroutine spawn cost outweighs the
// hashing work, so deeper nodes hash both children sequentially.
func (s *nodeStore) hashInternal(idx uint32) common.Hash {
node := s.getInternal(idx)
if !node.mustRecompute {
return node.hash
}
if int(node.depth) < parallelHashDepth {
var input [64]byte
var lh common.Hash
var wg sync.WaitGroup
if !node.left.IsEmpty() {
wg.Add(1)
go func() {
// defer wg.Done() so the waiter is released on every exit path
// of this goroutine. NOTE(review): this is not a panic-safety
// mechanism. A panic raised here cannot be intercepted by a
// recover() in the parent's call stack; an unrecovered panic in
// any goroutine terminates the program.
defer wg.Done()
lh = s.computeHash(node.left)
}()
}
if !node.right.IsEmpty() {
rh := s.computeHash(node.right)
copy(input[32:], rh[:])
}
// lh is only read after Wait, which orders it after the goroutine's write.
wg.Wait()
copy(input[:32], lh[:])
node.hash = sha256.Sum256(input[:])
node.mustRecompute = false
return node.hash
}
// Deep sequential branch — mirrors the shallow branch's shape to keep
// input on the stack. Writing lh/rh through hash.Hash (interface)
// forces escape; copy into a local [64]byte and hash it in one shot.
var input [64]byte
if !node.left.IsEmpty() {
lh := s.computeHash(node.left)
copy(input[:HashSize], lh[:])
}
if !node.right.IsEmpty() {
rh := s.computeHash(node.right)
copy(input[HashSize:], rh[:])
}
node.hash = sha256.Sum256(input[:])
node.mustRecompute = false
return node.hash
}
// serializeNode encodes the node behind ref into the flat on-disk format:
//
//   - internal node: [type][left child hash][right child hash]
//   - stem node:     [type][stem][presence bitmap][present values in slot order]
//
// Child hashes of an internal node come from computeHash. Any other node
// kind panics.
func (s *nodeStore) serializeNode(ref nodeRef) []byte {
	kind := ref.Kind()

	if kind == kindInternal {
		branch := s.getInternal(ref.Index())
		blob := make([]byte, NodeTypeBytes+2*HashSize)
		blob[0] = nodeTypeInternal
		left := s.computeHash(branch.left)
		right := s.computeHash(branch.right)
		copy(blob[NodeTypeBytes:], left[:])
		copy(blob[NodeTypeBytes+HashSize:], right[:])
		return blob
	}
	if kind != kindStem {
		panic(fmt.Sprintf("SerializeNode: unexpected node kind %d", kind))
	}

	stem := s.getStem(ref.Index())
	headerLen := NodeTypeBytes + StemSize + StemBitmapSize

	// First pass: count populated slots so the blob can be sized exactly.
	present := 0
	for i := range stem.values {
		if stem.values[i] != nil {
			present++
		}
	}

	blob := make([]byte, headerLen+present*HashSize)
	blob[0] = nodeTypeStem
	copy(blob[NodeTypeBytes:NodeTypeBytes+StemSize], stem.Stem[:])

	// Second pass: set the MSB-first presence bit and append each value.
	bitmap := blob[NodeTypeBytes+StemSize : headerLen]
	cursor := headerLen
	for slot, val := range stem.values {
		if val == nil {
			continue
		}
		bitmap[slot/8] |= 0x80 >> (slot % 8)
		copy(blob[cursor:cursor+HashSize], val)
		cursor += HashSize
	}
	return blob
}
// errInvalidSerializedLength is returned by decodeNode when a serialized
// blob's length does not fit its node type: an internal node that is not
// exactly type+two hashes long, or a stem node shorter than its header plus
// the values its bitmap announces.
var errInvalidSerializedLength = errors.New("invalid serialized node length")
// deserializeNode decodes a serialized node whose hash is not known. The
// resulting node is flagged for hash recomputation and marked dirty
// (provenance unknown, so re-flushing is the safe default).
func (s *nodeStore) deserializeNode(serialized []byte, depth int) (nodeRef, error) {
	const (
		mustRecompute = true
		dirty         = true
	)
	var unknownHash common.Hash
	return s.decodeNode(serialized, depth, unknownHash, mustRecompute, dirty)
}
// deserializeNodeWithHash decodes a serialized node whose hash hn is already
// known and whose blob is already on disk, so the node is created with
// mustRecompute=false and dirty=false.
func (s *nodeStore) deserializeNodeWithHash(serialized []byte, depth int, hn common.Hash) (nodeRef, error) {
	const (
		mustRecompute = false
		dirty         = false
	)
	return s.decodeNode(serialized, depth, hn, mustRecompute, dirty)
}
// decodeNode parses a serialized node and materialises it in the store.
//
// An empty blob decodes to emptyRef. An internal blob must be exactly
// type+leftHash+rightHash long; each non-zero child hash becomes a hashed
// reference. A stem blob carries the stem, a presence bitmap (MSB first) and
// the present values packed in slot order.
//
// hn is the node's known hash; it is installed on internal nodes only when
// mustRecompute is false, and always stored on stem nodes. mustRecompute and
// dirty set the corresponding flags on the new node.
//
// Stem value slots alias the serialized buffer (zero-copy), so the caller
// must not modify serialized afterwards.
//
// Returns errInvalidSerializedLength for a blob whose length does not fit its
// type, and an "invalid node type" error for an unknown type byte. No node is
// allocated when a stem blob fails validation.
func (s *nodeStore) decodeNode(serialized []byte, depth int, hn common.Hash, mustRecompute, dirty bool) (nodeRef, error) {
	if len(serialized) == 0 {
		return emptyRef, nil
	}
	switch serialized[0] {
	case nodeTypeInternal:
		if len(serialized) != NodeTypeBytes+2*HashSize {
			return emptyRef, errInvalidSerializedLength
		}
		var leftHash, rightHash common.Hash
		copy(leftHash[:], serialized[NodeTypeBytes:NodeTypeBytes+HashSize])
		copy(rightHash[:], serialized[NodeTypeBytes+HashSize:])
		// A zero hash denotes an absent child; leave its ref at the zero value.
		var leftRef, rightRef nodeRef
		if leftHash != (common.Hash{}) {
			leftRef = s.newHashedRef(leftHash)
		}
		if rightHash != (common.Hash{}) {
			rightRef = s.newHashedRef(rightHash)
		}
		ref := s.newInternalRef(depth)
		node := s.getInternal(ref.Index())
		node.left = leftRef
		node.right = rightRef
		if !mustRecompute {
			node.hash = hn
			node.mustRecompute = false
		}
		node.dirty = dirty
		return ref, nil
	case nodeTypeStem:
		headerLen := NodeTypeBytes + StemSize + StemBitmapSize
		if len(serialized) < headerLen {
			return emptyRef, errInvalidSerializedLength
		}
		bitmap := serialized[NodeTypeBytes+StemSize : headerLen]

		// Validate the payload length before allocating: a truncated blob
		// must not leave a half-populated stem slot behind in the arena.
		present := 0
		for i := range StemNodeWidth {
			if bitmap[i/8]>>(7-(i%8))&1 == 1 {
				present++
			}
		}
		if len(serialized) < headerLen+present*HashSize {
			return emptyRef, errInvalidSerializedLength
		}

		stemIdx := s.allocStem()
		sn := s.getStem(stemIdx)
		copy(sn.Stem[:], serialized[NodeTypeBytes:NodeTypeBytes+StemSize])
		offset := headerLen
		for i := range StemNodeWidth {
			if bitmap[i/8]>>(7-(i%8))&1 != 1 {
				continue
			}
			// Zero-copy: each slot aliases the serialized input buffer. The
			// capacity is capped so an append on one slot cannot spill into
			// the next value.
			sn.values[i] = serialized[offset : offset+HashSize : offset+HashSize]
			offset += HashSize
		}
		sn.depth = uint8(depth)
		sn.hash = hn
		sn.mustRecompute = mustRecompute
		sn.dirty = dirty
		return makeRef(kindStem, stemIdx), nil
	default:
		return emptyRef, errors.New("invalid node type")
	}
}
// collectNodes walks the subtree at ref in post-order and hands every dirty
// node to flushfn as (path, hash, serialized blob), clearing the node's dirty
// flag afterwards.
//
// Invariant relied upon: any ancestor of a node that needs flushing is itself
// marked dirty, so a clean internal node short-circuits its whole subtree.
// Empty and hashed references have nothing to flush.
func (s *nodeStore) collectNodes(ref nodeRef, path []byte, flushfn nodeFlushFn) error {
	switch kind := ref.Kind(); kind {
	case kindEmpty, kindHashed:
		// Nothing in memory to write: absent, or already committed.
		return nil

	case kindStem:
		stem := s.getStem(ref.Index())
		if stem.dirty {
			flushfn(path, s.computeHash(ref), s.serializeNode(ref))
			stem.dirty = false
		}
		return nil

	case kindInternal:
		branch := s.getInternal(ref.Index())
		if !branch.dirty {
			return nil
		}
		// One extra path byte is shared by both children and rewritten in
		// place: flushfn consumers (NodeSet.AddNode, tracer.Get) clone via
		// string(path), so in-place mutation is safe.
		childPath := append(path, 0)
		last := len(childPath) - 1
		for bit, child := range [...]*nodeRef{&branch.left, &branch.right} {
			childPath[last] = byte(bit)
			if err := s.collectNodes(*child, childPath, flushfn); err != nil {
				return err
			}
		}
		// Children first, then the node itself at its own (shorter) path.
		flushfn(childPath[:last], s.computeHash(ref), s.serializeNode(ref))
		branch.dirty = false
		return nil

	default:
		return fmt.Errorf("CollectNodes: unexpected kind %d", kind)
	}
}
// toDot renders the subtree at ref as Graphviz DOT statements, for debugging.
//
// parent is the DOT identifier of the parent node, or "" when ref is the root
// of the rendering. path is the accumulated path suffix that makes each
// identifier unique. The edge from parent is emitted only when parent is
// non-empty, for every node kind, so a root stem or hashed node does not
// produce a dangling " -> node" edge. Empty references render as "".
func (s *nodeStore) toDot(ref nodeRef, parent, path string) string {
	switch ref.Kind() {
	case kindInternal:
		node := s.getInternal(ref.Index())
		me := fmt.Sprintf("internal%s", path)
		ret := fmt.Sprintf("%s [label=\"I: %x\"]\n", me, s.computeHash(ref))
		if len(parent) > 0 {
			ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me)
		}
		if !node.left.IsEmpty() {
			ret += s.toDot(node.left, me, fmt.Sprintf("%s%02x", path, 0))
		}
		if !node.right.IsEmpty() {
			ret += s.toDot(node.right, me, fmt.Sprintf("%s%02x", path, 1))
		}
		return ret
	case kindStem:
		sn := s.getStem(ref.Index())
		me := fmt.Sprintf("stem%s", path)
		ret := fmt.Sprintf("%s [label=\"stem=%x c=%x\"]\n", me, sn.Stem, sn.Hash())
		if len(parent) > 0 {
			ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me)
		}
		// One child node per populated value slot, keyed by its hex index.
		for i, v := range sn.values {
			if v == nil {
				continue
			}
			ret += fmt.Sprintf("%s%x [label=\"%x\"]\n", me, i, v)
			ret += fmt.Sprintf("%s -> %s%x\n", me, me, i)
		}
		return ret
	case kindHashed:
		hn := s.getHashed(ref.Index())
		me := fmt.Sprintf("hash%s", path)
		ret := fmt.Sprintf("%s [label=\"%x\"]\n", me, hn.Hash())
		if len(parent) > 0 {
			ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me)
		}
		return ret
	default:
		return ""
	}
}

345
trie/bintrie/store_ops.go Normal file
View file

@ -0,0 +1,345 @@
// Copyright 2026 go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package bintrie
import (
"errors"
"fmt"
"github.com/ethereum/go-ethereum/common"
)
// nodeResolverFn resolves a hashed node from the database. Callers in this
// file invoke it with the node's path (as produced by keyToPath) and the
// node's hash, and feed the returned bytes to deserializeNodeWithHash; a
// non-nil error aborts the operation that needed the node.
type nodeResolverFn func([]byte, common.Hash) ([]byte, error)
// GetValue looks up the single slot (stem, suffix). It delegates the tree
// walk to GetValuesAtStem and indexes the result; a nil result from that
// call (stem not present) or an error is passed straight through with a nil
// value.
func (s *nodeStore) GetValue(stem []byte, suffix byte, resolver nodeResolverFn) ([]byte, error) {
	slots, err := s.GetValuesAtStem(stem, resolver)
	switch {
	case err != nil:
		return nil, err
	case slots == nil:
		return nil, nil
	}
	return slots[suffix], nil
}
// GetValuesAtStem walks the trie from the root along the bits of stem and
// returns the value slots found there:
//
//   - a stem node whose stem matches: the result of its allValues(), which
//     callers must treat as read-only;
//   - a stem node with a different stem: nil;
//   - an empty subtree: a fresh, non-nil slice of StemNodeWidth nil entries.
//
// Hashed children met on the way are loaded through resolver, deserialized
// and spliced into their parent so later walks do not resolve them again.
// An error is returned when a hashed node is met without a resolver, when
// resolution or deserialization fails, or when an internal node sits at
// depth 31*8 or deeper.
func (s *nodeStore) GetValuesAtStem(stem []byte, resolver nodeResolverFn) ([][]byte, error) {
	var (
		parentIdx uint32 // pool index of the most recently visited internal node
		wentLeft  bool   // whether the walk left that node through its left child
	)
	at := s.root
	for {
		switch at.Kind() {
		case kindInternal:
			inner := s.getInternal(at.Index())
			if inner.depth >= 31*8 {
				return nil, errors.New("node too deep")
			}
			// The stem bit at this node's depth picks the branch.
			wentLeft = stem[inner.depth/8]>>(7-(inner.depth%8))&1 == 0
			parentIdx = at.Index()
			if wentLeft {
				at = inner.left
			} else {
				at = inner.right
			}

		case kindStem:
			leaf := s.getStem(at.Index())
			if leaf.Stem == [StemSize]byte(stem[:StemSize]) {
				return leaf.allValues(), nil
			}
			return nil, nil

		case kindHashed:
			// A hashed root is not expected here (NewBinaryTrie resolves the
			// root eagerly), so parentIdx always refers to an internal node
			// visited earlier in this walk.
			if resolver == nil {
				return nil, errors.New("getValuesAtStem: cannot resolve hashed node without resolver")
			}
			want := s.getHashed(at.Index()).Hash()
			parent := s.getInternal(parentIdx)
			path, err := keyToPath(int(parent.depth), stem)
			if err != nil {
				return nil, fmt.Errorf("getValuesAtStem path error: %w", err)
			}
			blob, err := resolver(path, want)
			if err != nil {
				return nil, fmt.Errorf("getValuesAtStem resolve error: %w", err)
			}
			loaded, err := s.deserializeNodeWithHash(blob, int(parent.depth)+1, want)
			if err != nil {
				return nil, fmt.Errorf("getValuesAtStem deserialization error: %w", err)
			}
			// Swap the placeholder for the real node and keep walking from it.
			s.freeHashedNode(at.Index())
			if wentLeft {
				parent.left = loaded
			} else {
				parent.right = loaded
			}
			at = loaded

		case kindEmpty:
			var none [StemNodeWidth][]byte
			return none[:], nil

		default:
			return nil, fmt.Errorf("getValuesAtStem: unexpected node kind %d", at.Kind())
		}
	}
}
// InsertSingle stores value in the slot (stem, suffix). The value must be
// exactly HashSize bytes long, otherwise an error is returned and nothing is
// written. The write is expressed as a sparse StemNodeWidth-slot update with
// only the target slot set and handed to InsertValuesAtStem, so descent,
// hashed-node resolution and stem splitting share one code path.
func (s *nodeStore) InsertSingle(stem []byte, suffix byte, value []byte, resolver nodeResolverFn) error {
	if len(value) != HashSize {
		return errors.New("invalid insertion: value length")
	}
	var slots [StemNodeWidth][]byte
	slots[suffix] = value
	err := s.InsertValuesAtStem(stem, slots[:], resolver)
	return err
}
// InsertValuesAtStem applies values to the slots of stem, starting the
// recursive insert at the root with depth 0. values may be sparse: nil
// entries leave the corresponding slot untouched. The store's root is
// replaced by whatever reference the recursive insert returns, including
// when it also returns an error.
func (s *nodeStore) InsertValuesAtStem(stem []byte, values [][]byte, resolver nodeResolverFn) error {
	newRoot, err := s.insertValuesAtStem(s.root, stem, values, resolver, 0)
	s.root = newRoot
	return err
}
// insertValuesAtStem inserts values at stem into the subtree rooted at ref,
// which sits at the given depth, and returns the reference that must replace
// ref in its parent.
//
//   - internal: follow the stem bit at the node's depth, resolving a hashed
//     child first, then mark the node dirty and in need of rehashing;
//   - stem: merge into it when the stems match, otherwise split;
//   - hashed: resolve, free the placeholder, and retry on the resolved node;
//   - empty: allocate a new stem node holding the values.
//
// On error the returned reference is the one the caller passed in, except
// where a callee documents otherwise.
func (s *nodeStore) insertValuesAtStem(ref nodeRef, stem []byte, values [][]byte, resolver nodeResolverFn, depth int) (nodeRef, error) {
	switch ref.Kind() {
	case kindInternal:
		node := s.getInternal(ref.Index())
		// Point at the child slot selected by the stem bit so both
		// directions share the code below.
		slot := &node.left
		if stem[node.depth/8]>>(7-(node.depth%8))&1 != 0 {
			slot = &node.right
		}
		if slot.Kind() == kindHashed {
			if resolver == nil {
				return ref, errors.New("insertValuesAtStem: cannot resolve hashed node without resolver")
			}
			hn := s.getHashed(slot.Index())
			path, err := keyToPath(int(node.depth), stem)
			if err != nil {
				return ref, fmt.Errorf("InsertValuesAtStem path error: %w", err)
			}
			blob, err := resolver(path, hn.Hash())
			if err != nil {
				return ref, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
			}
			loaded, err := s.deserializeNodeWithHash(blob, int(node.depth)+1, hn.Hash())
			if err != nil {
				return ref, fmt.Errorf("InsertValuesAtStem deserialization error: %w", err)
			}
			s.freeHashedNode(slot.Index())
			*slot = loaded
		}
		updated, err := s.insertValuesAtStem(*slot, stem, values, resolver, depth+1)
		if err != nil {
			return ref, err
		}
		*slot = updated
		node.mustRecompute = true
		node.dirty = true
		return ref, nil

	case kindStem:
		leaf := s.getStem(ref.Index())
		if leaf.Stem != [StemSize]byte(stem[:StemSize]) {
			// A different stem lives here: push both down under new
			// internal nodes.
			return s.splitStemValuesInsert(ref, stem, values, resolver, depth)
		}
		// Same stem: overwrite the supplied slots. setValue is relied on to
		// flag the node as dirty and needing a rehash.
		for slotIdx, val := range values {
			if val == nil {
				continue
			}
			leaf.setValue(byte(slotIdx), val)
		}
		return ref, nil

	case kindHashed:
		hn := s.getHashed(ref.Index())
		path, err := keyToPath(depth, stem)
		if err != nil {
			return ref, fmt.Errorf("InsertValuesAtStem path error: %w", err)
		}
		if resolver == nil {
			return ref, errors.New("InsertValuesAtStem: resolver is nil")
		}
		blob, err := resolver(path, hn.Hash())
		if err != nil {
			return ref, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
		}
		loaded, err := s.deserializeNodeWithHash(blob, depth, hn.Hash())
		if err != nil {
			return ref, fmt.Errorf("InsertValuesAtStem deserialization error: %w", err)
		}
		s.freeHashedNode(ref.Index())
		return s.insertValuesAtStem(loaded, stem, values, resolver, depth)

	case kindEmpty:
		// Fresh stem node. The flags are set before the value loop so that
		// an all-nil values input still leaves the new node marked dirty.
		idx := s.allocStem()
		leaf := s.getStem(idx)
		copy(leaf.Stem[:], stem[:StemSize])
		leaf.depth = uint8(depth)
		leaf.mustRecompute = true
		leaf.dirty = true
		for slotIdx, val := range values {
			if val == nil {
				continue
			}
			leaf.setValue(byte(slotIdx), val)
		}
		return makeRef(kindStem, idx), nil

	default:
		return ref, fmt.Errorf("insertValuesAtStem: unexpected kind %d", ref.Kind())
	}
}
// splitStemValuesInsert handles an insert that lands on a stem node holding
// a different stem. It puts a new internal node at the existing stem's depth
// and moves the existing stem one level down:
//
//   - if both stems take the same branch at that depth, the insert recurses
//     into that branch (splitting again as needed) and the other branch is
//     left empty;
//   - if they diverge, a new stem node is created for newStem/values and the
//     two stems become the internal node's children.
//
// It returns the reference of the new internal node. If the recursive insert
// fails, the existing stem's depth and dirty flag are restored and
// existingRef is returned, so the caller's tree is left as it was; the
// freshly allocated internal node is simply abandoned (there is no free-list
// for internal nodes).
//
// It panics if the existing stem is already at depth StemSize*8, which would
// mean the two stems are identical.
func (s *nodeStore) splitStemValuesInsert(existingRef nodeRef, newStem []byte, values [][]byte, resolver nodeResolverFn, depth int) (nodeRef, error) {
	existing := s.getStem(existingRef.Index())
	if int(existing.depth) >= StemSize*8 {
		panic("splitStemValuesInsert: identical stems")
	}
	bitStem := existing.Stem[existing.depth/8] >> (7 - (existing.depth % 8)) & 1

	nRef := s.newInternalRef(int(existing.depth))
	nNode := s.getInternal(nRef.Index())

	// The existing stem moves one level down, i.e. to a new path. The node
	// collector skips stems that are not dirty, so flag it to make sure it is
	// flushed at its new location even if it was clean before the split.
	wasDirty := existing.dirty
	existing.depth++
	existing.dirty = true

	bitKey := newStem[nNode.depth/8] >> (7 - (nNode.depth % 8)) & 1
	if bitKey == bitStem {
		// Same direction: the split has to happen further down.
		newChild, err := s.insertValuesAtStem(existingRef, newStem, values, resolver, depth+1)
		if err != nil {
			// Undo the relocation so a retry sees the original state, and
			// hand back the original reference rather than the half-built
			// internal node, which would otherwise end up linked into the
			// tree (InsertValuesAtStem stores the returned ref as root).
			existing.depth--
			existing.dirty = wasDirty
			return existingRef, err
		}
		if bitStem == 0 {
			nNode.left = newChild
			nNode.right = emptyRef
		} else {
			nNode.right = newChild
			nNode.left = emptyRef
		}
		return nRef, nil
	}

	// Divergence: create a new stem node for the incoming values. Flags are
	// set before the loop so an all-nil values input still marks it dirty.
	newStemIdx := s.allocStem()
	newSn := s.getStem(newStemIdx)
	copy(newSn.Stem[:], newStem[:StemSize])
	newSn.depth = nNode.depth + 1
	newSn.mustRecompute = true
	newSn.dirty = true
	for i, v := range values {
		if v != nil {
			newSn.setValue(byte(i), v)
		}
	}
	newStemRef := makeRef(kindStem, newStemIdx)
	if bitStem == 0 {
		nNode.left = existingRef
		nNode.right = newStemRef
	} else {
		nNode.left = newStemRef
		nNode.right = existingRef
	}
	return nRef, nil
}
// Insert stores value under a full key: the first StemSize bytes select the
// stem and the byte at index StemSize selects the slot within it. key must
// therefore hold at least StemSize+1 bytes; a shorter key panics on slicing.
// Validation of value and the actual write are done by InsertSingle.
func (s *nodeStore) Insert(key []byte, value []byte, resolver nodeResolverFn) error {
	stem, suffix := key[:StemSize], key[StemSize]
	return s.InsertSingle(stem, suffix, value, resolver)
}
// Get reads the value stored under a full key: the first StemSize bytes
// select the stem and the byte at index StemSize selects the slot within it.
// key must hold at least StemSize+1 bytes; a shorter key panics on slicing.
// The lookup itself is performed by GetValue.
func (s *nodeStore) Get(key []byte, resolver nodeResolverFn) ([]byte, error) {
	stem, suffix := key[:StemSize], key[StemSize]
	return s.GetValue(stem, suffix, resolver)
}
// getHeight returns the height of the subtree rooted at ref: 1 for a stem
// node, one more than the taller child for an internal node, and 0 for an
// empty reference or any other kind (hashed nodes are not resolved here).
func (s *nodeStore) getHeight(ref nodeRef) int {
	switch ref.Kind() {
	case kindStem:
		return 1
	case kindInternal:
		inner := s.getInternal(ref.Index())
		tallest := s.getHeight(inner.left)
		if rh := s.getHeight(inner.right); rh > tallest {
			tallest = rh
		}
		return 1 + tallest
	}
	return 0
}

View file

@ -19,7 +19,6 @@ package bintrie
import ( import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
@ -31,8 +30,6 @@ import (
"github.com/holiman/uint256" "github.com/holiman/uint256"
) )
var errInvalidRootType = errors.New("invalid root type")
// ChunkedCode represents a sequence of HashSize-byte chunks of code (StemSize bytes of which // ChunkedCode represents a sequence of HashSize-byte chunks of code (StemSize bytes of which
// are actual code, and NodeTypeBytes byte is the pushdata offset). // are actual code, and NodeTypeBytes byte is the pushdata offset).
type ChunkedCode []byte type ChunkedCode []byte
@ -108,22 +105,17 @@ func ChunkifyCode(code []byte) ChunkedCode {
return chunks return chunks
} }
// NewBinaryNode creates a new empty binary trie
func NewBinaryNode() BinaryNode {
return Empty{}
}
// BinaryTrie is the implementation of https://eips.ethereum.org/EIPS/eip-7864. // BinaryTrie is the implementation of https://eips.ethereum.org/EIPS/eip-7864.
type BinaryTrie struct { type BinaryTrie struct {
root BinaryNode store *nodeStore
reader *trie.Reader reader *trie.Reader
tracer *trie.PrevalueTracer tracer *trie.PrevalueTracer
} }
// ToDot converts the binary trie to a DOT language representation. Useful for debugging. // ToDot converts the binary trie to a DOT language representation. Useful for debugging.
func (t *BinaryTrie) ToDot() string { func (t *BinaryTrie) ToDot() string {
t.root.Hash() t.store.computeHash(t.store.root)
return ToDot(t.root) return t.store.toDot(t.store.root, "", "")
} }
// NewBinaryTrie creates a new binary trie. // NewBinaryTrie creates a new binary trie.
@ -133,7 +125,7 @@ func NewBinaryTrie(root common.Hash, db database.NodeDatabase) (*BinaryTrie, err
return nil, err return nil, err
} }
t := &BinaryTrie{ t := &BinaryTrie{
root: NewBinaryNode(), store: newNodeStore(),
reader: reader, reader: reader,
tracer: trie.NewPrevalueTracer(), tracer: trie.NewPrevalueTracer(),
} }
@ -143,11 +135,11 @@ func NewBinaryTrie(root common.Hash, db database.NodeDatabase) (*BinaryTrie, err
if err != nil { if err != nil {
return nil, err return nil, err
} }
node, err := DeserializeNodeWithHash(blob, 0, root) ref, err := t.store.deserializeNodeWithHash(blob, 0, root)
if err != nil { if err != nil {
return nil, err return nil, err
} }
t.root = node t.store.root = ref
} }
return t, nil return t, nil
} }
@ -176,29 +168,18 @@ func (t *BinaryTrie) GetKey(key []byte) []byte {
// GetWithHashedKey returns the value, assuming that the key has already // GetWithHashedKey returns the value, assuming that the key has already
// been hashed. // been hashed.
func (t *BinaryTrie) GetWithHashedKey(key []byte) ([]byte, error) { func (t *BinaryTrie) GetWithHashedKey(key []byte) ([]byte, error) {
return t.root.Get(key, t.nodeResolver) return t.store.Get(key, t.nodeResolver)
} }
// GetAccount returns the account information for the given address. // GetAccount returns the account information for the given address.
func (t *BinaryTrie) GetAccount(addr common.Address) (*types.StateAccount, error) { func (t *BinaryTrie) GetAccount(addr common.Address) (*types.StateAccount, error) {
var ( var (
values [][]byte err error
err error acc = &types.StateAccount{}
acc = &types.StateAccount{} key = GetBinaryTreeKey(addr, zero[:])
key = GetBinaryTreeKey(addr, zero[:])
) )
switch r := t.root.(type) {
case *InternalNode: values, err := t.store.GetValuesAtStem(key[:StemSize], t.nodeResolver)
values, err = r.GetValuesAtStem(key[:StemSize], t.nodeResolver)
case *StemNode:
values, err = r.GetValuesAtStem(key[:StemSize], t.nodeResolver)
case Empty:
return nil, nil
default:
// This will cover HashedNode but that should be fine since the
// root node should always be resolved.
return nil, errInvalidRootType
}
if err != nil { if err != nil {
return nil, fmt.Errorf("GetAccount (%x) error: %v", addr, err) return nil, fmt.Errorf("GetAccount (%x) error: %v", addr, err)
} }
@ -219,7 +200,7 @@ func (t *BinaryTrie) GetAccount(addr common.Address) (*types.StateAccount, error
// If the account has been deleted, BasicData and CodeHash will both be // If the account has been deleted, BasicData and CodeHash will both be
// 32-byte zero blobs (not nil). If the account is recreated afterwards, // 32-byte zero blobs (not nil). If the account is recreated afterwards,
// UpdateAccount overwrites BasicData and CodeHash with non-zero values, // UpdateAccount overwrites BasicData and CodeHash with non-zero values,
// so this branch won't activate.. // so this branch won't activate.
if bytes.Equal(values[BasicDataLeafKey], zero[:]) && if bytes.Equal(values[BasicDataLeafKey], zero[:]) &&
bytes.Equal(values[CodeHashLeafKey], zero[:]) { bytes.Equal(values[CodeHashLeafKey], zero[:]) {
return nil, nil return nil, nil
@ -238,13 +219,12 @@ func (t *BinaryTrie) GetAccount(addr common.Address) (*types.StateAccount, error
// not be modified by the caller. If a node was not found in the database, a // not be modified by the caller. If a node was not found in the database, a
// trie.MissingNodeError is returned. // trie.MissingNodeError is returned.
func (t *BinaryTrie) GetStorage(addr common.Address, key []byte) ([]byte, error) { func (t *BinaryTrie) GetStorage(addr common.Address, key []byte) ([]byte, error) {
return t.root.Get(GetBinaryTreeKeyStorageSlot(addr, key), t.nodeResolver) return t.store.Get(GetBinaryTreeKeyStorageSlot(addr, key), t.nodeResolver)
} }
// UpdateAccount updates the account information for the given address. // UpdateAccount updates the account information for the given address.
func (t *BinaryTrie) UpdateAccount(addr common.Address, acc *types.StateAccount, codeLen int) error { func (t *BinaryTrie) UpdateAccount(addr common.Address, acc *types.StateAccount, codeLen int) error {
var ( var (
err error
basicData [HashSize]byte basicData [HashSize]byte
values = make([][]byte, StemNodeWidth) values = make([][]byte, StemNodeWidth)
stem = GetBinaryTreeKey(addr, zero[:]) stem = GetBinaryTreeKey(addr, zero[:])
@ -265,15 +245,12 @@ func (t *BinaryTrie) UpdateAccount(addr common.Address, acc *types.StateAccount,
values[BasicDataLeafKey] = basicData[:] values[BasicDataLeafKey] = basicData[:]
values[CodeHashLeafKey] = acc.CodeHash[:] values[CodeHashLeafKey] = acc.CodeHash[:]
t.root, err = t.root.InsertValuesAtStem(stem, values, t.nodeResolver, 0) return t.store.InsertValuesAtStem(stem, values, t.nodeResolver)
return err
} }
// UpdateStem updates the values for the given stem key. // UpdateStem updates the values for the given stem key.
func (t *BinaryTrie) UpdateStem(key []byte, values [][]byte) error { func (t *BinaryTrie) UpdateStem(key []byte, values [][]byte) error {
var err error return t.store.InsertValuesAtStem(key, values, t.nodeResolver)
t.root, err = t.root.InsertValuesAtStem(key, values, t.nodeResolver, 0)
return err
} }
// UpdateStorage associates key with value in the trie. If value has length zero, any // UpdateStorage associates key with value in the trie. If value has length zero, any
@ -288,11 +265,10 @@ func (t *BinaryTrie) UpdateStorage(address common.Address, key, value []byte) er
} else { } else {
copy(v[HashSize-len(value):], value[:]) copy(v[HashSize-len(value):], value[:])
} }
root, err := t.root.Insert(k, v[:], t.nodeResolver, 0) err := t.store.Insert(k, v[:], t.nodeResolver)
if err != nil { if err != nil {
return fmt.Errorf("UpdateStorage (%x) error: %v", address, err) return fmt.Errorf("UpdateStorage (%x) error: %v", address, err)
} }
t.root = root
return nil return nil
} }
@ -307,12 +283,7 @@ func (t *BinaryTrie) DeleteAccount(addr common.Address) error {
values[BasicDataLeafKey] = zero[:] values[BasicDataLeafKey] = zero[:]
values[CodeHashLeafKey] = zero[:] values[CodeHashLeafKey] = zero[:]
root, err := t.root.InsertValuesAtStem(stem, values, t.nodeResolver, 0) return t.store.InsertValuesAtStem(stem, values, t.nodeResolver)
if err != nil {
return fmt.Errorf("DeleteAccount (%x) error: %v", addr, err)
}
t.root = root
return nil
} }
// DeleteStorage removes any existing value for key from the trie. If a node was not // DeleteStorage removes any existing value for key from the trie. If a node was not
@ -320,18 +291,17 @@ func (t *BinaryTrie) DeleteAccount(addr common.Address) error {
func (t *BinaryTrie) DeleteStorage(addr common.Address, key []byte) error { func (t *BinaryTrie) DeleteStorage(addr common.Address, key []byte) error {
k := GetBinaryTreeKeyStorageSlot(addr, key) k := GetBinaryTreeKeyStorageSlot(addr, key)
var zero [HashSize]byte var zero [HashSize]byte
root, err := t.root.Insert(k, zero[:], t.nodeResolver, 0) err := t.store.Insert(k, zero[:], t.nodeResolver)
if err != nil { if err != nil {
return fmt.Errorf("DeleteStorage (%x) error: %v", addr, err) return fmt.Errorf("DeleteStorage (%x) error: %v", addr, err)
} }
t.root = root
return nil return nil
} }
// Hash returns the root hash of the trie. It does not write to the database and // Hash returns the root hash of the trie. It does not write to the database and
// can be used even if the trie doesn't have one. // can be used even if the trie doesn't have one.
func (t *BinaryTrie) Hash() common.Hash { func (t *BinaryTrie) Hash() common.Hash {
return t.root.Hash() return t.store.computeHash(t.store.root)
} }
// Commit writes all nodes to the trie's memory database, tracking the internal // Commit writes all nodes to the trie's memory database, tracking the internal
@ -339,15 +309,15 @@ func (t *BinaryTrie) Hash() common.Hash {
func (t *BinaryTrie) Commit(_ bool) (common.Hash, *trienode.NodeSet) { func (t *BinaryTrie) Commit(_ bool) (common.Hash, *trienode.NodeSet) {
nodeset := trienode.NewNodeSet(common.Hash{}) nodeset := trienode.NewNodeSet(common.Hash{})
// The root can be any type of BinaryNode (InternalNode, StemNode, etc.) // Pre-size the path buffer: collectNodes reuses it in-place via
err := t.root.CollectNodes(nil, func(path []byte, node BinaryNode) { // append/truncate; 32 covers typical binary-trie depth without regrowth.
serialized := SerializeNode(node) pathBuf := make([]byte, 0, 32)
nodeset.AddNode(path, trienode.NewNodeWithPrev(node.Hash(), serialized, t.tracer.Get(path))) err := t.store.collectNodes(t.store.root, pathBuf, func(path []byte, hash common.Hash, serialized []byte) {
nodeset.AddNode(path, trienode.NewNodeWithPrev(hash, serialized, t.tracer.Get(path)))
}) })
if err != nil { if err != nil {
panic(fmt.Errorf("CollectNodes failed: %v", err)) panic(fmt.Errorf("CollectNodes failed: %v", err))
} }
// Serialize root commitment form
return t.Hash(), nodeset return t.Hash(), nodeset
} }
@ -371,7 +341,7 @@ func (t *BinaryTrie) Prove(key []byte, proofDb ethdb.KeyValueWriter) error {
// Copy creates a deep copy of the trie. // Copy creates a deep copy of the trie.
func (t *BinaryTrie) Copy() *BinaryTrie { func (t *BinaryTrie) Copy() *BinaryTrie {
return &BinaryTrie{ return &BinaryTrie{
root: t.root.Copy(), store: t.store.Copy(),
reader: t.reader, reader: t.reader,
tracer: t.tracer.Copy(), tracer: t.tracer.Copy(),
} }
@ -407,7 +377,6 @@ func (t *BinaryTrie) UpdateContractCode(addr common.Address, codeHash common.Has
if groupOffset == StemNodeWidth-1 || len(chunks)-i <= HashSize { if groupOffset == StemNodeWidth-1 || len(chunks)-i <= HashSize {
err = t.UpdateStem(key[:StemSize], values) err = t.UpdateStem(key[:StemSize], values)
if err != nil { if err != nil {
return fmt.Errorf("UpdateContractCode (addr=%x) error: %w", addr[:], err) return fmt.Errorf("UpdateContractCode (addr=%x) error: %w", addr[:], err)
} }

View file

@ -37,147 +37,130 @@ var (
) )
func TestSingleEntry(t *testing.T) { func TestSingleEntry(t *testing.T) {
tree := NewBinaryNode() s := newNodeStore()
tree, err := tree.Insert(zeroKey[:], oneKey[:], nil, 0) if err := s.Insert(zeroKey[:], oneKey[:], nil); err != nil {
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if tree.GetHeight() != 1 { if s.getHeight(s.root) != 1 {
t.Fatal("invalid depth") t.Fatal("invalid depth")
} }
expected := common.HexToHash("aab1060e04cb4f5dc6f697ae93156a95714debbf77d54238766adc5709282b6f") expected := common.HexToHash("aab1060e04cb4f5dc6f697ae93156a95714debbf77d54238766adc5709282b6f")
got := tree.Hash() got := s.Hash()
if got != expected { if got != expected {
t.Fatalf("invalid tree root, got %x, want %x", got, expected) t.Fatalf("invalid tree root, got %x, want %x", got, expected)
} }
} }
func TestTwoEntriesDiffFirstBit(t *testing.T) { func TestTwoEntriesDiffFirstBit(t *testing.T) {
var err error s := newNodeStore()
tree := NewBinaryNode() if err := s.Insert(zeroKey[:], oneKey[:], nil); err != nil {
tree, err = tree.Insert(zeroKey[:], oneKey[:], nil, 0)
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
tree, err = tree.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000000").Bytes(), twoKey[:], nil, 0) if err := s.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000000").Bytes(), twoKey[:], nil); err != nil {
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if tree.GetHeight() != 2 { if s.getHeight(s.root) != 2 {
t.Fatal("invalid height") t.Fatal("invalid height")
} }
if tree.Hash() != common.HexToHash("dfc69c94013a8b3c65395625a719a87534a7cfd38719251ad8c8ea7fe79f065e") { if s.Hash() != common.HexToHash("dfc69c94013a8b3c65395625a719a87534a7cfd38719251ad8c8ea7fe79f065e") {
t.Fatal("invalid tree root") t.Fatal("invalid tree root")
} }
} }
func TestOneStemColocatedValues(t *testing.T) { func TestOneStemColocatedValues(t *testing.T) {
var err error s := newNodeStore()
tree := NewBinaryNode() if err := s.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil); err != nil {
tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil, 0)
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil, 0) if err := s.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil); err != nil {
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000009").Bytes(), threeKey[:], nil, 0) if err := s.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000009").Bytes(), threeKey[:], nil); err != nil {
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
tree, err = tree.Insert(common.HexToHash("00000000000000000000000000000000000000000000000000000000000000FF").Bytes(), fourKey[:], nil, 0) if err := s.Insert(common.HexToHash("00000000000000000000000000000000000000000000000000000000000000FF").Bytes(), fourKey[:], nil); err != nil {
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if tree.GetHeight() != 1 { if s.getHeight(s.root) != 1 {
t.Fatal("invalid height") t.Fatal("invalid height")
} }
} }
func TestTwoStemColocatedValues(t *testing.T) { func TestTwoStemColocatedValues(t *testing.T) {
var err error s := newNodeStore()
tree := NewBinaryNode()
// stem: 0...0 // stem: 0...0
tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil, 0) if err := s.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil); err != nil {
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil, 0) if err := s.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil); err != nil {
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// stem: 10...0 // stem: 10...0
tree, err = tree.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil, 0) if err := s.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil); err != nil {
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
tree, err = tree.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil, 0) if err := s.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil); err != nil {
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if tree.GetHeight() != 2 { if s.getHeight(s.root) != 2 {
t.Fatal("invalid height") t.Fatal("invalid height")
} }
} }
func TestTwoKeysMatchFirst42Bits(t *testing.T) { func TestTwoKeysMatchFirst42Bits(t *testing.T) {
var err error s := newNodeStore()
tree := NewBinaryNode()
// key1 and key 2 have the same prefix of 42 bits (b0*42+b1+b1) and differ after. // key1 and key 2 have the same prefix of 42 bits (b0*42+b1+b1) and differ after.
key1 := common.HexToHash("0000000000C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0").Bytes() key1 := common.HexToHash("0000000000C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0").Bytes()
key2 := common.HexToHash("0000000000E00000000000000000000000000000000000000000000000000000").Bytes() key2 := common.HexToHash("0000000000E00000000000000000000000000000000000000000000000000000").Bytes()
tree, err = tree.Insert(key1, oneKey[:], nil, 0) if err := s.Insert(key1, oneKey[:], nil); err != nil {
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
tree, err = tree.Insert(key2, twoKey[:], nil, 0) if err := s.Insert(key2, twoKey[:], nil); err != nil {
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if tree.GetHeight() != 1+42+1 { if s.getHeight(s.root) != 1+42+1 {
t.Fatal("invalid height") t.Fatal("invalid height")
} }
} }
func TestInsertDuplicateKey(t *testing.T) { func TestInsertDuplicateKey(t *testing.T) {
var err error s := newNodeStore()
tree := NewBinaryNode() if err := s.Insert(oneKey[:], oneKey[:], nil); err != nil {
tree, err = tree.Insert(oneKey[:], oneKey[:], nil, 0)
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
tree, err = tree.Insert(oneKey[:], twoKey[:], nil, 0) if err := s.Insert(oneKey[:], twoKey[:], nil); err != nil {
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if tree.GetHeight() != 1 { if s.getHeight(s.root) != 1 {
t.Fatal("invalid height") t.Fatal("invalid height")
} }
// Verify that the value is updated // Verify that the value is updated
if !bytes.Equal(tree.(*StemNode).Values[1], twoKey[:]) { v, err := s.Get(oneKey[:], nil)
t.Fatal("invalid height") if err != nil {
t.Fatal(err)
}
if !bytes.Equal(v, twoKey[:]) {
t.Fatal("value not updated")
} }
} }
func TestLargeNumberOfEntries(t *testing.T) { func TestLargeNumberOfEntries(t *testing.T) {
var err error s := newNodeStore()
tree := NewBinaryNode()
for i := range StemNodeWidth { for i := range StemNodeWidth {
var key [HashSize]byte var key [HashSize]byte
key[0] = byte(i) key[0] = byte(i)
tree, err = tree.Insert(key[:], ffKey[:], nil, 0) if err := s.Insert(key[:], ffKey[:], nil); err != nil {
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }
height := tree.GetHeight() height := s.getHeight(s.root)
if height != 1+8 { if height != 1+8 {
t.Fatalf("invalid height, wanted %d, got %d", 1+8, height) t.Fatalf("invalid height, wanted %d, got %d", 1+8, height)
} }
} }
func TestMerkleizeMultipleEntries(t *testing.T) { func TestMerkleizeMultipleEntries(t *testing.T) {
var err error s := newNodeStore()
tree := NewBinaryNode()
keys := [][]byte{ keys := [][]byte{
zeroKey[:], zeroKey[:],
common.HexToHash("8000000000000000000000000000000000000000000000000000000000000000").Bytes(), common.HexToHash("8000000000000000000000000000000000000000000000000000000000000000").Bytes(),
@ -187,12 +170,11 @@ func TestMerkleizeMultipleEntries(t *testing.T) {
for i, key := range keys { for i, key := range keys {
var v [HashSize]byte var v [HashSize]byte
binary.LittleEndian.PutUint64(v[:8], uint64(i)) binary.LittleEndian.PutUint64(v[:8], uint64(i))
tree, err = tree.Insert(key, v[:], nil, 0) if err := s.Insert(key, v[:], nil); err != nil {
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }
got := tree.Hash() got := s.Hash()
expected := common.HexToHash("9317155862f7a3867660ddd0966ff799a3d16aa4df1e70a7516eaa4a675191b5") expected := common.HexToHash("9317155862f7a3867660ddd0966ff799a3d16aa4df1e70a7516eaa4a675191b5")
if got != expected { if got != expected {
t.Fatalf("invalid root, expected=%x, got = %x", expected, got) t.Fatalf("invalid root, expected=%x, got = %x", expected, got)
@ -206,7 +188,7 @@ func TestMerkleizeMultipleEntries(t *testing.T) {
func TestStorageRoundTrip(t *testing.T) { func TestStorageRoundTrip(t *testing.T) {
tracer := trie.NewPrevalueTracer() tracer := trie.NewPrevalueTracer()
tr := &BinaryTrie{ tr := &BinaryTrie{
root: NewBinaryNode(), store: newNodeStore(),
tracer: tracer, tracer: tracer,
} }
addr := common.HexToAddress("0x1234567890abcdef1234567890abcdef12345678") addr := common.HexToAddress("0x1234567890abcdef1234567890abcdef12345678")
@ -274,7 +256,7 @@ func TestStorageRoundTrip(t *testing.T) {
func newEmptyTestTrie(t *testing.T) *BinaryTrie { func newEmptyTestTrie(t *testing.T) *BinaryTrie {
t.Helper() t.Helper()
return &BinaryTrie{ return &BinaryTrie{
root: NewBinaryNode(), store: newNodeStore(),
tracer: trie.NewPrevalueTracer(), tracer: trie.NewPrevalueTracer(),
} }
} }
@ -599,7 +581,7 @@ func TestBinaryTrieWitness(t *testing.T) {
tracer := trie.NewPrevalueTracer() tracer := trie.NewPrevalueTracer()
tr := &BinaryTrie{ tr := &BinaryTrie{
root: NewBinaryNode(), store: newNodeStore(),
tracer: tracer, tracer: tracer,
} }
if w := tr.Witness(); len(w) != 0 { if w := tr.Witness(); len(w) != 0 {
@ -626,7 +608,7 @@ func TestBinaryTrieWitness(t *testing.T) {
func testAccount(t *testing.T, addr common.Address, nonce uint64, balance uint64) *BinaryTrie { func testAccount(t *testing.T, addr common.Address, nonce uint64, balance uint64) *BinaryTrie {
t.Helper() t.Helper()
tr := &BinaryTrie{ tr := &BinaryTrie{
root: NewBinaryNode(), store: newNodeStore(),
tracer: trie.NewPrevalueTracer(), tracer: trie.NewPrevalueTracer(),
} }
acc := &types.StateAccount{ acc := &types.StateAccount{
@ -649,8 +631,8 @@ func TestGetAccountNonMembershipStemRoot(t *testing.T) {
tr := testAccount(t, addr, 42, 100) tr := testAccount(t, addr, 42, 100)
// Verify root is a StemNode (single stem inserted). // Verify root is a StemNode (single stem inserted).
if _, ok := tr.root.(*StemNode); !ok { if tr.store.root.Kind() != kindStem {
t.Fatalf("expected StemNode root, got %T", tr.root) t.Fatalf("expected StemNode root, got kind %d", tr.store.root.Kind())
} }
// Query a completely different address — must return nil. // Query a completely different address — must return nil.
@ -680,7 +662,7 @@ func TestGetAccountNonMembershipStemRoot(t *testing.T) {
// address returns nil when the trie root is an InternalNode (multi-account trie). // address returns nil when the trie root is an InternalNode (multi-account trie).
func TestGetAccountNonMembershipInternalRoot(t *testing.T) { func TestGetAccountNonMembershipInternalRoot(t *testing.T) {
tr := &BinaryTrie{ tr := &BinaryTrie{
root: NewBinaryNode(), store: newNodeStore(),
tracer: trie.NewPrevalueTracer(), tracer: trie.NewPrevalueTracer(),
} }
@ -700,8 +682,8 @@ func TestGetAccountNonMembershipInternalRoot(t *testing.T) {
} }
// Verify root is an InternalNode. // Verify root is an InternalNode.
if _, ok := tr.root.(*InternalNode); !ok { if tr.store.root.Kind() != kindInternal {
t.Fatalf("expected InternalNode root, got %T", tr.root) t.Fatalf("expected InternalNode root, got kind %d", tr.store.root.Kind())
} }
// Query a non-existent address — must return nil. // Query a non-existent address — must return nil.
@ -723,8 +705,8 @@ func TestGetStorageNonMembershipStemRoot(t *testing.T) {
tr := testAccount(t, addr, 1, 100) tr := testAccount(t, addr, 1, 100)
// Verify root is a StemNode. // Verify root is a StemNode.
if _, ok := tr.root.(*StemNode); !ok { if tr.store.root.Kind() != kindStem {
t.Fatalf("expected StemNode root, got %T", tr.root) t.Fatalf("expected StemNode root, got kind %d", tr.store.root.Kind())
} }
// Query storage for a different address — must return nil, not panic. // Query storage for a different address — must return nil, not panic.
@ -743,7 +725,7 @@ func TestGetStorageNonMembershipStemRoot(t *testing.T) {
// non-existent address returns nil when the root is an InternalNode. // non-existent address returns nil when the root is an InternalNode.
func TestGetStorageNonMembershipInternalRoot(t *testing.T) { func TestGetStorageNonMembershipInternalRoot(t *testing.T) {
tr := &BinaryTrie{ tr := &BinaryTrie{
root: NewBinaryNode(), store: newNodeStore(),
tracer: trie.NewPrevalueTracer(), tracer: trie.NewPrevalueTracer(),
} }
@ -765,8 +747,8 @@ func TestGetStorageNonMembershipInternalRoot(t *testing.T) {
t.Fatalf("UpdateStorage error: %v", err) t.Fatalf("UpdateStorage error: %v", err)
} }
if _, ok := tr.root.(*InternalNode); !ok { if tr.store.root.Kind() != kindInternal {
t.Fatalf("expected InternalNode root, got %T", tr.root) t.Fatalf("expected InternalNode root, got kind %d", tr.store.root.Kind())
} }
// Query storage for a non-existent address — must return nil. // Query storage for a non-existent address — must return nil.
@ -780,36 +762,29 @@ func TestGetStorageNonMembershipInternalRoot(t *testing.T) {
} }
} }
// commitKeyN derives a distinct 32-byte key from a seed integer. Used by // TestCommitSkipCleanSubtrees verifies that CollectNodes short-circuits on
// TestBinaryTrieCommitIncremental and BenchmarkCollectNodes_SparseWrite to // clean subtrees. First Commit flushes every resolved node; a follow-up
// populate a trie with many disjoint stems. // Commit with no modifications flushes nothing; a single-leaf modification
func commitKeyN(i int) [HashSize]byte { // flushes only the root-to-leaf path.
var k [HashSize]byte func TestCommitSkipCleanSubtrees(t *testing.T) {
binary.BigEndian.PutUint64(k[:8], uint64(i)*0x9e3779b97f4a7c15)
binary.BigEndian.PutUint64(k[8:16], uint64(i)*0xc2b2ae3d27d4eb4f)
binary.BigEndian.PutUint64(k[16:24], uint64(i)*0x165667b19e3779f9)
binary.BigEndian.PutUint64(k[24:32], uint64(i)*0x85ebca77c2b2ae63)
return k
}
// TestBinaryTrieCommitIncremental verifies that a second Commit with only a
// single modified leaf flushes only the path from that leaf to the root,
// not the entire tree.
func TestBinaryTrieCommitIncremental(t *testing.T) {
tr := &BinaryTrie{ tr := &BinaryTrie{
root: NewBinaryNode(), store: newNodeStore(),
tracer: trie.NewPrevalueTracer(), tracer: trie.NewPrevalueTracer(),
} }
const n = 200
const n = 512 key := func(i int) [HashSize]byte {
keys := make([][HashSize]byte, n) var k [HashSize]byte
binary.BigEndian.PutUint64(k[:8], uint64(i+1)*0x9e3779b97f4a7c15)
binary.BigEndian.PutUint64(k[8:16], uint64(i+1)*0xc2b2ae3d27d4eb4f)
binary.BigEndian.PutUint64(k[16:24], uint64(i+1)*0x165667b19e3779f9)
binary.BigEndian.PutUint64(k[24:32], uint64(i+1)*0x85ebca77c2b2ae63)
return k
}
for i := range n { for i := range n {
keys[i] = commitKeyN(i + 1) k := key(i)
var v [HashSize]byte var v [HashSize]byte
binary.BigEndian.PutUint64(v[24:], uint64(i+1)) binary.BigEndian.PutUint64(v[24:], uint64(i+1))
var err error if err := tr.store.Insert(k[:], v[:], nil); err != nil {
tr.root, err = tr.root.Insert(keys[i][:], v[:], nil, 0)
if err != nil {
t.Fatalf("Insert %d: %v", i, err) t.Fatalf("Insert %d: %v", i, err)
} }
} }
@ -818,69 +793,55 @@ func TestBinaryTrieCommitIncremental(t *testing.T) {
if len(ns1.Nodes) == 0 { if len(ns1.Nodes) == 0 {
t.Fatal("first Commit produced empty NodeSet") t.Fatal("first Commit produced empty NodeSet")
} }
if len(ns1.Nodes) < n {
t.Fatalf("first Commit: expected at least %d nodes, got %d", n, len(ns1.Nodes))
}
// Second Commit on the same trie with no modifications: NodeSet must
// be empty because every subtree is clean.
_, nsNoop := tr.Commit(false) _, nsNoop := tr.Commit(false)
if len(nsNoop.Nodes) != 0 { if len(nsNoop.Nodes) != 0 {
t.Fatalf("no-op Commit: expected empty NodeSet, got %d nodes", len(nsNoop.Nodes)) t.Fatalf("no-op Commit: expected empty NodeSet, got %d", len(nsNoop.Nodes))
} }
// Modify a single leaf's value. Only the path from that leaf to the // Modify a single leaf — only the root-to-leaf path should flush.
// root should appear in the next Commit's NodeSet. k := key(n / 2)
var newVal [HashSize]byte var newVal [HashSize]byte
newVal[0] = 0xff newVal[0] = 0xff
var err error if err := tr.store.Insert(k[:], newVal[:], nil); err != nil {
tr.root, err = tr.root.Insert(keys[n/2][:], newVal[:], nil, 0)
if err != nil {
t.Fatalf("Insert (modify): %v", err) t.Fatalf("Insert (modify): %v", err)
} }
_, ns2 := tr.Commit(false) _, ns2 := tr.Commit(false)
// Path length for a binary trie of n=512 stems is bounded by the
// internal depth at which the modified stem sits. Allow generous
// slack: up to 64 nodes is fine, anywhere near n (512) is a regression.
if len(ns2.Nodes) == 0 { if len(ns2.Nodes) == 0 {
t.Fatal("modified Commit produced empty NodeSet") t.Fatal("modified Commit produced empty NodeSet")
} }
if len(ns2.Nodes) > 64 { if len(ns2.Nodes) > 32 {
t.Fatalf("modified Commit: expected small NodeSet, got %d nodes (first Commit had %d)", len(ns2.Nodes), len(ns1.Nodes)) t.Fatalf("modified Commit: expected ≤32 nodes (path+stem), got %d", len(ns2.Nodes))
} }
if len(ns2.Nodes) >= len(ns1.Nodes) { if len(ns2.Nodes) >= len(ns1.Nodes) {
t.Fatalf("expected second NodeSet (%d) to be smaller than first (%d)", len(ns2.Nodes), len(ns1.Nodes)) t.Fatalf("expected second NodeSet (%d) to be smaller than first (%d)", len(ns2.Nodes), len(ns1.Nodes))
} }
} }
// BenchmarkCollectNodes_SparseWrite measures Commit cost when only one leaf // BenchmarkCollectNodesSparseWrite measures Commit cost when one leaf
// changes between blocks — the common case for state updates. After warm-up // changes per block — the common case for state updates. After warm-up
// (populate + initial Commit), each iteration modifies a single leaf and // (populate + initial Commit), each iteration modifies a single leaf and
// re-Commits. Under the skip-clean optimization, each iteration flushes // re-Commits. Matches the shape of the same-named benchmark on master so
// only the root-to-leaf path; pre-fix behavior would re-flush the entire // the two trees can be benchstat'd directly.
// tree every iteration. func BenchmarkCollectNodesSparseWrite(b *testing.B) {
func BenchmarkCollectNodes_SparseWrite(b *testing.B) {
const n = 10_000 const n = 10_000
tr := &BinaryTrie{ tr := &BinaryTrie{
root: NewBinaryNode(), store: newNodeStore(),
tracer: trie.NewPrevalueTracer(), tracer: trie.NewPrevalueTracer(),
} }
keys := make([][HashSize]byte, n) keys := make([][HashSize]byte, n)
for i := range n { for i := range n {
keys[i] = commitKeyN(i + 1) binary.BigEndian.PutUint64(keys[i][:8], uint64(i+1)*0x9e3779b97f4a7c15)
binary.BigEndian.PutUint64(keys[i][8:16], uint64(i+1)*0xc2b2ae3d27d4eb4f)
binary.BigEndian.PutUint64(keys[i][16:24], uint64(i+1)*0x165667b19e3779f9)
binary.BigEndian.PutUint64(keys[i][24:32], uint64(i+1)*0x85ebca77c2b2ae63)
var v [HashSize]byte var v [HashSize]byte
binary.BigEndian.PutUint64(v[24:], uint64(i+1)) binary.BigEndian.PutUint64(v[24:], uint64(i+1))
var err error if err := tr.store.Insert(keys[i][:], v[:], nil); err != nil {
tr.root, err = tr.root.Insert(keys[i][:], v[:], nil, 0) b.Fatalf("warmup Insert %d: %v", i, err)
if err != nil {
b.Fatalf("Insert %d: %v", i, err)
} }
} }
// Flush the initial tree so subsequent Commits reflect the _, _ = tr.Commit(false) // warmup flush
// single-modification workload we want to measure.
_, _ = tr.Commit(false)
var newVal [HashSize]byte var newVal [HashSize]byte
b.ReportAllocs() b.ReportAllocs()
@ -888,10 +849,8 @@ func BenchmarkCollectNodes_SparseWrite(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
idx := i % n idx := i % n
binary.BigEndian.PutUint64(newVal[24:], uint64(i+1)) binary.BigEndian.PutUint64(newVal[24:], uint64(i+1))
var err error if err := tr.store.Insert(keys[idx][:], newVal[:], nil); err != nil {
tr.root, err = tr.root.Insert(keys[idx][:], newVal[:], nil, 0) b.Fatalf("iter %d Insert: %v", i, err)
if err != nil {
b.Fatalf("Insert at iter %d: %v", i, err)
} }
_, ns := tr.Commit(false) _, ns := tr.Commit(false)
if len(ns.Nodes) == 0 { if len(ns.Nodes) == 0 {

View file

@ -102,11 +102,7 @@ func binaryNodeHasher(blob []byte) (common.Hash, error) {
if len(blob) == 0 { if len(blob) == 0 {
return types.EmptyBinaryHash, nil return types.EmptyBinaryHash, nil
} }
n, err := bintrie.DeserializeNode(blob, 0) return bintrie.DeserializeAndHash(blob, 0)
if err != nil {
return common.Hash{}, err
}
return n.Hash(), nil
} }
// Database is a multiple-layered structure for maintaining in-memory states // Database is a multiple-layered structure for maintaining in-memory states