mirror of
https://github.com/ethereum/go-ethereum.git
synced 2026-06-12 01:41:36 +00:00
trie/bintrie: replace BinaryNode interface with GC-free NodeRef arena
Replace the BinaryNode interface (which uses Go interface pointers that the GC must scan) with NodeRef uint32 indices into typed arena pools. NodeRef packs a 2-bit kind tag and 30-bit pool index into a single uint32, making it invisible to the garbage collector. NodeStore manages chunked typed pools per node kind: - InternalNode pool: ZERO Go pointers (children are NodeRef, hash is [32]byte) → allocated in noscan spans, GC skips entirely - HashedNode pool: ZERO Go pointers → noscan spans - StemNode pool: ONE pointer per node (valueData []byte) → minimal GC For a trie with 25K InternalNodes, this reduces GC-scanned pointer-words from ~125K to ~10K (85% reduction). CPU profiling showed 44% of time in GC; this refactor directly addresses that bottleneck. Serialization format is unchanged — the on-disk representation is fully compatible. All existing tests pass.
This commit is contained in:
parent
61bfacc52f
commit
8a5e777fde
20 changed files with 2170 additions and 2011 deletions
|
|
@ -16,140 +16,43 @@
|
|||
|
||||
package bintrie
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
)
|
||||
|
||||
type (
|
||||
NodeFlushFn func([]byte, BinaryNode)
|
||||
NodeResolverFn func([]byte, common.Hash) ([]byte, error)
|
||||
)
|
||||
import "github.com/ethereum/go-ethereum/common"
|
||||
|
||||
// zero is the zero value for a 32-byte array.
|
||||
var zero [32]byte
|
||||
|
||||
const (
|
||||
StemNodeWidth = 256 // Number of child per leaf node
|
||||
StemSize = 31 // Number of bytes to travel before reaching a group of leaves
|
||||
NodeTypeBytes = 1 // Size of node type prefix in serialization
|
||||
HashSize = 32 // Size of a hash in bytes
|
||||
BitmapSize = 32 // Size of the bitmap in a stem node
|
||||
StemNodeWidth = 256 // Number of children per leaf node
|
||||
StemSize = 31 // Number of bytes to travel before reaching a group of leaves
|
||||
NodeTypeBytes = 1 // Size of node type prefix in serialization
|
||||
HashSize = 32 // Size of a hash in bytes
|
||||
StemBitmapSize = 32 // Size of the bitmap in a stem node (256 values = 32 bytes)
|
||||
|
||||
// MaxGroupDepth is the maximum allowed group depth for InternalNode serialization.
|
||||
MaxGroupDepth = 8
|
||||
)
|
||||
|
||||
// BitmapSizeForDepth returns the bitmap size in bytes for a given group depth.
|
||||
func BitmapSizeForDepth(groupDepth int) int {
|
||||
if groupDepth <= 3 {
|
||||
return 1
|
||||
}
|
||||
return 1 << (groupDepth - 3)
|
||||
}
|
||||
|
||||
const (
|
||||
nodeTypeStem = iota + 1 // Stem node, contains a stem and a bitmap of values
|
||||
nodeTypeStem = iota + 1
|
||||
nodeTypeInternal
|
||||
)
|
||||
|
||||
// BinaryNode is an interface for a binary trie node.
|
||||
type BinaryNode interface {
|
||||
Get([]byte, NodeResolverFn) ([]byte, error)
|
||||
Insert([]byte, []byte, NodeResolverFn, int) (BinaryNode, error)
|
||||
Copy() BinaryNode
|
||||
Hash() common.Hash
|
||||
GetValuesAtStem([]byte, NodeResolverFn) ([][]byte, error)
|
||||
InsertValuesAtStem([]byte, [][]byte, NodeResolverFn, int) (BinaryNode, error)
|
||||
CollectNodes([]byte, NodeFlushFn) error
|
||||
|
||||
toDot(parent, path string) string
|
||||
GetHeight() int
|
||||
}
|
||||
|
||||
// SerializeNode serializes a binary trie node into a byte slice.
|
||||
func SerializeNode(node BinaryNode) []byte {
|
||||
switch n := (node).(type) {
|
||||
case *InternalNode:
|
||||
// InternalNode: 1 byte type + 32 bytes left hash + 32 bytes right hash
|
||||
var serialized [NodeTypeBytes + HashSize + HashSize]byte
|
||||
serialized[0] = nodeTypeInternal
|
||||
copy(serialized[1:33], n.left.Hash().Bytes())
|
||||
copy(serialized[33:65], n.right.Hash().Bytes())
|
||||
return serialized[:]
|
||||
case *StemNode:
|
||||
// StemNode: 1 byte type + 31 bytes stem + 32 bytes bitmap + 256*32 bytes values
|
||||
var serialized [NodeTypeBytes + StemSize + BitmapSize + StemNodeWidth*HashSize]byte
|
||||
serialized[0] = nodeTypeStem
|
||||
copy(serialized[NodeTypeBytes:NodeTypeBytes+StemSize], n.Stem)
|
||||
bitmap := serialized[NodeTypeBytes+StemSize : NodeTypeBytes+StemSize+BitmapSize]
|
||||
offset := NodeTypeBytes + StemSize + BitmapSize
|
||||
for i, v := range n.Values {
|
||||
if v != nil {
|
||||
bitmap[i/8] |= 1 << (7 - (i % 8))
|
||||
copy(serialized[offset:offset+HashSize], v)
|
||||
offset += HashSize
|
||||
}
|
||||
}
|
||||
// Only return the actual data, not the entire array
|
||||
return serialized[:offset]
|
||||
default:
|
||||
panic("invalid node type")
|
||||
// DeserializeAndHash deserializes a node from bytes and returns its hash.
|
||||
// This is a convenience function for external callers that need to compute
|
||||
// the hash of a serialized node without maintaining a NodeStore.
|
||||
func DeserializeAndHash(blob []byte, depth int) (common.Hash, error) {
|
||||
s := NewNodeStore()
|
||||
ref, err := s.DeserializeNode(blob, depth)
|
||||
if err != nil {
|
||||
return common.Hash{}, err
|
||||
}
|
||||
}
|
||||
|
||||
var invalidSerializedLength = errors.New("invalid serialized node length")
|
||||
|
||||
// DeserializeNode deserializes a binary trie node from a byte slice. The
|
||||
// hash will be recomputed from the deserialized data.
|
||||
func DeserializeNode(serialized []byte, depth int) (BinaryNode, error) {
|
||||
return deserializeNode(serialized, depth, common.Hash{}, true, true)
|
||||
}
|
||||
|
||||
// DeserializeNodeWithHash deserializes a binary trie node from a byte slice, using the provided hash.
|
||||
func DeserializeNodeWithHash(serialized []byte, depth int, hn common.Hash) (BinaryNode, error) {
|
||||
return deserializeNode(serialized, depth, hn, false, false)
|
||||
}
|
||||
|
||||
func deserializeNode(serialized []byte, depth int, hn common.Hash, mustRecompute, dirty bool) (BinaryNode, error) {
|
||||
if len(serialized) == 0 {
|
||||
return Empty{}, nil
|
||||
}
|
||||
|
||||
switch serialized[0] {
|
||||
case nodeTypeInternal:
|
||||
if len(serialized) != 65 {
|
||||
return nil, invalidSerializedLength
|
||||
}
|
||||
return &InternalNode{
|
||||
depth: depth,
|
||||
left: HashedNode(common.BytesToHash(serialized[1:33])),
|
||||
right: HashedNode(common.BytesToHash(serialized[33:65])),
|
||||
hash: hn,
|
||||
mustRecompute: mustRecompute,
|
||||
dirty: dirty,
|
||||
}, nil
|
||||
case nodeTypeStem:
|
||||
if len(serialized) < 64 {
|
||||
return nil, invalidSerializedLength
|
||||
}
|
||||
var values [StemNodeWidth][]byte
|
||||
bitmap := serialized[NodeTypeBytes+StemSize : NodeTypeBytes+StemSize+BitmapSize]
|
||||
offset := NodeTypeBytes + StemSize + BitmapSize
|
||||
|
||||
for i := range StemNodeWidth {
|
||||
if bitmap[i/8]>>(7-(i%8))&1 == 1 {
|
||||
if len(serialized) < offset+HashSize {
|
||||
return nil, invalidSerializedLength
|
||||
}
|
||||
values[i] = serialized[offset : offset+HashSize]
|
||||
offset += HashSize
|
||||
}
|
||||
}
|
||||
return &StemNode{
|
||||
Stem: serialized[NodeTypeBytes : NodeTypeBytes+StemSize],
|
||||
Values: values[:],
|
||||
depth: depth,
|
||||
hash: hn,
|
||||
mustRecompute: mustRecompute,
|
||||
dirty: dirty,
|
||||
}, nil
|
||||
default:
|
||||
return nil, errors.New("invalid node type")
|
||||
}
|
||||
}
|
||||
|
||||
// ToDot converts the binary trie to a DOT language representation. Useful for debugging.
|
||||
func ToDot(root BinaryNode) string {
|
||||
return root.toDot("", "")
|
||||
return s.ComputeHash(ref), nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,78 +24,118 @@ import (
|
|||
)
|
||||
|
||||
// TestSerializeDeserializeInternalNode tests serialization and deserialization of InternalNode
|
||||
// with the grouped subtree format through NodeStore.
|
||||
func TestSerializeDeserializeInternalNode(t *testing.T) {
|
||||
// Create an internal node with two hashed children
|
||||
leftHash := common.HexToHash("0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef")
|
||||
rightHash := common.HexToHash("0xfedcba0987654321fedcba0987654321fedcba0987654321fedcba0987654321")
|
||||
|
||||
node := &InternalNode{
|
||||
depth: 5,
|
||||
left: HashedNode(leftHash),
|
||||
right: HashedNode(rightHash),
|
||||
}
|
||||
s := NewNodeStore()
|
||||
leftRef := s.newHashedRef(leftHash)
|
||||
rightRef := s.newHashedRef(rightHash)
|
||||
|
||||
// Serialize the node
|
||||
serialized := SerializeNode(node)
|
||||
rootRef := s.newInternalRef(0)
|
||||
rootNode := s.getInternal(rootRef.Index())
|
||||
rootNode.left = leftRef
|
||||
rootNode.right = rightRef
|
||||
s.SetRoot(rootRef)
|
||||
|
||||
// Serialize the node with default group depth of 8
|
||||
serialized := s.SerializeNode(rootRef, MaxGroupDepth)
|
||||
|
||||
// Check the serialized format
|
||||
if serialized[0] != nodeTypeInternal {
|
||||
t.Errorf("Expected type byte to be %d, got %d", nodeTypeInternal, serialized[0])
|
||||
}
|
||||
|
||||
if len(serialized) != 65 {
|
||||
t.Errorf("Expected serialized length to be 65, got %d", len(serialized))
|
||||
if serialized[1] != MaxGroupDepth {
|
||||
t.Errorf("Expected group depth to be %d, got %d", MaxGroupDepth, serialized[1])
|
||||
}
|
||||
|
||||
// Deserialize the node
|
||||
deserialized, err := DeserializeNode(serialized, 5)
|
||||
bitmapSize := BitmapSizeForDepth(MaxGroupDepth)
|
||||
expectedLen := 1 + 1 + bitmapSize + 2*HashSize
|
||||
if len(serialized) != expectedLen {
|
||||
t.Errorf("Expected serialized length to be %d, got %d", expectedLen, len(serialized))
|
||||
}
|
||||
|
||||
// Check bitmap bits
|
||||
bitmap := serialized[2 : 2+bitmapSize]
|
||||
if bitmap[0]&0x80 == 0 {
|
||||
t.Error("Expected bit 0 to be set in bitmap (left child)")
|
||||
}
|
||||
if bitmap[16]&0x80 == 0 {
|
||||
t.Error("Expected bit 128 to be set in bitmap (right child)")
|
||||
}
|
||||
|
||||
// Deserialize into a new store
|
||||
ds := NewNodeStore()
|
||||
deserialized, err := ds.DeserializeNode(serialized, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to deserialize node: %v", err)
|
||||
}
|
||||
|
||||
// Check that it's an internal node
|
||||
internalNode, ok := deserialized.(*InternalNode)
|
||||
if !ok {
|
||||
t.Fatalf("Expected InternalNode, got %T", deserialized)
|
||||
// Root should be an InternalNode
|
||||
if deserialized.Kind() != KindInternal {
|
||||
t.Fatalf("Expected KindInternal, got kind %d", deserialized.Kind())
|
||||
}
|
||||
|
||||
// Check the depth
|
||||
if internalNode.depth != 5 {
|
||||
t.Errorf("Expected depth 5, got %d", internalNode.depth)
|
||||
internalNode := ds.getInternal(deserialized.Index())
|
||||
if internalNode.depth != 0 {
|
||||
t.Errorf("Expected depth 0, got %d", internalNode.depth)
|
||||
}
|
||||
|
||||
// Check the left and right hashes
|
||||
if internalNode.left.Hash() != leftHash {
|
||||
t.Errorf("Left hash mismatch: expected %x, got %x", leftHash, internalNode.left.Hash())
|
||||
// Navigate to position 0 (8 left turns) to find the left hash
|
||||
node0 := navigateToLeafRef(ds, deserialized, 0, 8)
|
||||
if ds.ComputeHash(node0) != leftHash {
|
||||
t.Errorf("Left hash mismatch: expected %x, got %x", leftHash, ds.ComputeHash(node0))
|
||||
}
|
||||
|
||||
if internalNode.right.Hash() != rightHash {
|
||||
t.Errorf("Right hash mismatch: expected %x, got %x", rightHash, internalNode.right.Hash())
|
||||
// Navigate to position 128 (right, then 7 lefts) to find the right hash
|
||||
node128 := navigateToLeafRef(ds, deserialized, 128, 8)
|
||||
if ds.ComputeHash(node128) != rightHash {
|
||||
t.Errorf("Right hash mismatch: expected %x, got %x", rightHash, ds.ComputeHash(node128))
|
||||
}
|
||||
}
|
||||
|
||||
// TestSerializeDeserializeStemNode tests serialization and deserialization of StemNode
|
||||
// navigateToLeafRef navigates to a specific position in the tree using NodeStore.
|
||||
func navigateToLeafRef(s *NodeStore, ref NodeRef, position, depth int) NodeRef {
|
||||
cur := ref
|
||||
for d := 0; d < depth; d++ {
|
||||
if cur.Kind() != KindInternal {
|
||||
return cur
|
||||
}
|
||||
in := s.getInternal(cur.Index())
|
||||
bit := (position >> (depth - 1 - d)) & 1
|
||||
if bit == 0 {
|
||||
cur = in.left
|
||||
} else {
|
||||
cur = in.right
|
||||
}
|
||||
}
|
||||
return cur
|
||||
}
|
||||
|
||||
// TestSerializeDeserializeStemNode tests serialization and deserialization of StemNode through NodeStore.
|
||||
func TestSerializeDeserializeStemNode(t *testing.T) {
|
||||
// Create a stem node with some values
|
||||
stem := make([]byte, StemSize)
|
||||
for i := range stem {
|
||||
stem[i] = byte(i)
|
||||
}
|
||||
|
||||
var values [StemNodeWidth][]byte
|
||||
// Add some values at different indices
|
||||
values[0] = common.HexToHash("0x0101010101010101010101010101010101010101010101010101010101010101").Bytes()
|
||||
values[10] = common.HexToHash("0x0202020202020202020202020202020202020202020202020202020202020202").Bytes()
|
||||
values[255] = common.HexToHash("0x0303030303030303030303030303030303030303030303030303030303030303").Bytes()
|
||||
|
||||
node := &StemNode{
|
||||
Stem: stem,
|
||||
Values: values[:],
|
||||
depth: 10,
|
||||
s := NewNodeStore()
|
||||
ref := s.newStemRef(stem, 10)
|
||||
sn := s.getStem(ref.Index())
|
||||
for i, v := range values {
|
||||
if v != nil {
|
||||
sn.setValue(byte(i), v)
|
||||
}
|
||||
}
|
||||
|
||||
// Serialize the node
|
||||
serialized := SerializeNode(node)
|
||||
serialized := s.SerializeNode(ref, MaxGroupDepth)
|
||||
|
||||
// Check the serialized format
|
||||
if serialized[0] != nodeTypeStem {
|
||||
|
|
@ -107,31 +147,32 @@ func TestSerializeDeserializeStemNode(t *testing.T) {
|
|||
t.Errorf("Stem mismatch in serialized data")
|
||||
}
|
||||
|
||||
// Deserialize the node
|
||||
deserialized, err := DeserializeNode(serialized, 10)
|
||||
// Deserialize into a new store
|
||||
ds := NewNodeStore()
|
||||
deserializedRef, err := ds.DeserializeNode(serialized, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to deserialize node: %v", err)
|
||||
}
|
||||
|
||||
// Check that it's a stem node
|
||||
stemNode, ok := deserialized.(*StemNode)
|
||||
if !ok {
|
||||
t.Fatalf("Expected StemNode, got %T", deserialized)
|
||||
if deserializedRef.Kind() != KindStem {
|
||||
t.Fatalf("Expected KindStem, got kind %d", deserializedRef.Kind())
|
||||
}
|
||||
|
||||
stemNode := ds.getStem(deserializedRef.Index())
|
||||
|
||||
// Check the stem
|
||||
if !bytes.Equal(stemNode.Stem, stem) {
|
||||
if !bytes.Equal(stemNode.Stem[:], stem) {
|
||||
t.Errorf("Stem mismatch after deserialization")
|
||||
}
|
||||
|
||||
// Check the values
|
||||
if !bytes.Equal(stemNode.Values[0], values[0]) {
|
||||
if !bytes.Equal(stemNode.getValue(0), values[0]) {
|
||||
t.Errorf("Value at index 0 mismatch")
|
||||
}
|
||||
if !bytes.Equal(stemNode.Values[10], values[10]) {
|
||||
if !bytes.Equal(stemNode.getValue(10), values[10]) {
|
||||
t.Errorf("Value at index 10 mismatch")
|
||||
}
|
||||
if !bytes.Equal(stemNode.Values[255], values[255]) {
|
||||
if !bytes.Equal(stemNode.getValue(255), values[255]) {
|
||||
t.Errorf("Value at index 255 mismatch")
|
||||
}
|
||||
|
||||
|
|
@ -140,43 +181,43 @@ func TestSerializeDeserializeStemNode(t *testing.T) {
|
|||
if i == 0 || i == 10 || i == 255 {
|
||||
continue
|
||||
}
|
||||
if stemNode.Values[i] != nil {
|
||||
t.Errorf("Expected nil value at index %d, got %x", i, stemNode.Values[i])
|
||||
if stemNode.hasValue(byte(i)) {
|
||||
t.Errorf("Expected no value at index %d, got %x", i, stemNode.getValue(byte(i)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeserializeEmptyNode tests deserialization of empty node
|
||||
// TestDeserializeEmptyNode tests deserialization of empty node.
|
||||
func TestDeserializeEmptyNode(t *testing.T) {
|
||||
// Empty byte slice should deserialize to Empty node
|
||||
deserialized, err := DeserializeNode([]byte{}, 0)
|
||||
s := NewNodeStore()
|
||||
deserialized, err := s.DeserializeNode([]byte{}, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to deserialize empty node: %v", err)
|
||||
}
|
||||
|
||||
_, ok := deserialized.(Empty)
|
||||
if !ok {
|
||||
t.Fatalf("Expected Empty node, got %T", deserialized)
|
||||
if !deserialized.IsEmpty() {
|
||||
t.Fatalf("Expected EmptyRef, got kind %d", deserialized.Kind())
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeserializeInvalidType tests deserialization with invalid type byte
|
||||
// TestDeserializeInvalidType tests deserialization with invalid type byte.
|
||||
func TestDeserializeInvalidType(t *testing.T) {
|
||||
// Create invalid serialized data with unknown type byte
|
||||
s := NewNodeStore()
|
||||
invalidData := []byte{99, 0, 0, 0} // Type byte 99 is invalid
|
||||
|
||||
_, err := DeserializeNode(invalidData, 0)
|
||||
_, err := s.DeserializeNode(invalidData, 0)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for invalid type byte, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeserializeInvalidLength tests deserialization with invalid data length
|
||||
// TestDeserializeInvalidLength tests deserialization with invalid data length.
|
||||
func TestDeserializeInvalidLength(t *testing.T) {
|
||||
// InternalNode with type byte 1 but wrong length
|
||||
invalidData := []byte{nodeTypeInternal, 0, 0} // Too short for internal node
|
||||
s := NewNodeStore()
|
||||
// InternalNode with valid type byte and group depth but too short for bitmap
|
||||
invalidData := []byte{nodeTypeInternal, 8, 0, 0} // Too short for bitmap (needs 32 bytes)
|
||||
|
||||
_, err := DeserializeNode(invalidData, 0)
|
||||
_, err := s.DeserializeNode(invalidData, 0)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for invalid data length, got nil")
|
||||
}
|
||||
|
|
@ -186,7 +227,7 @@ func TestDeserializeInvalidLength(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// TestKeyToPath tests the keyToPath function
|
||||
// TestKeyToPath tests the keyToPath function.
|
||||
func TestKeyToPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
|
@ -218,14 +259,14 @@ func TestKeyToPath(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "max valid depth",
|
||||
depth: StemSize * 8,
|
||||
depth: StemSize*8 - 1,
|
||||
key: make([]byte, HashSize),
|
||||
expected: make([]byte, StemSize*8+1),
|
||||
expected: make([]byte, StemSize*8),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "depth too large",
|
||||
depth: StemSize*8 + 1,
|
||||
depth: StemSize * 8,
|
||||
key: make([]byte, HashSize),
|
||||
wantErr: true,
|
||||
},
|
||||
|
|
|
|||
|
|
@ -1,76 +0,0 @@
|
|||
// Copyright 2025 go-ethereum Authors
|
||||
// This file is part of the go-ethereum library.
|
||||
//
|
||||
// The go-ethereum library is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Lesser General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// The go-ethereum library is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Lesser General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Lesser General Public License
|
||||
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package bintrie
|
||||
|
||||
import (
|
||||
"slices"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
)
|
||||
|
||||
type Empty struct{}
|
||||
|
||||
func (e Empty) Get(_ []byte, _ NodeResolverFn) ([]byte, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (e Empty) Insert(key []byte, value []byte, _ NodeResolverFn, depth int) (BinaryNode, error) {
|
||||
var values [256][]byte
|
||||
values[key[31]] = value
|
||||
return &StemNode{
|
||||
Stem: slices.Clone(key[:31]),
|
||||
Values: values[:],
|
||||
depth: depth,
|
||||
mustRecompute: true,
|
||||
dirty: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e Empty) Copy() BinaryNode {
|
||||
return Empty{}
|
||||
}
|
||||
|
||||
func (e Empty) Hash() common.Hash {
|
||||
return common.Hash{}
|
||||
}
|
||||
|
||||
func (e Empty) GetValuesAtStem(_ []byte, _ NodeResolverFn) ([][]byte, error) {
|
||||
var values [256][]byte
|
||||
return values[:], nil
|
||||
}
|
||||
|
||||
func (e Empty) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolverFn, depth int) (BinaryNode, error) {
|
||||
return &StemNode{
|
||||
Stem: slices.Clone(key[:31]),
|
||||
Values: values,
|
||||
depth: depth,
|
||||
mustRecompute: true,
|
||||
dirty: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e Empty) CollectNodes(_ []byte, _ NodeFlushFn) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e Empty) toDot(parent string, path string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (e Empty) GetHeight() int {
|
||||
return 0
|
||||
}
|
||||
|
|
@ -1,264 +0,0 @@
|
|||
// Copyright 2025 go-ethereum Authors
|
||||
// This file is part of the go-ethereum library.
|
||||
//
|
||||
// The go-ethereum library is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Lesser General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// The go-ethereum library is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Lesser General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Lesser General Public License
|
||||
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package bintrie
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
)
|
||||
|
||||
// TestEmptyGet tests the Get method
|
||||
func TestEmptyGet(t *testing.T) {
|
||||
node := Empty{}
|
||||
|
||||
key := make([]byte, 32)
|
||||
value, err := node.Get(key, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if value != nil {
|
||||
t.Errorf("Expected nil value from empty node, got %x", value)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmptyInsert tests the Insert method
|
||||
func TestEmptyInsert(t *testing.T) {
|
||||
node := Empty{}
|
||||
|
||||
key := make([]byte, 32)
|
||||
key[0] = 0x12
|
||||
key[31] = 0x34
|
||||
value := common.HexToHash("0xabcd").Bytes()
|
||||
|
||||
newNode, err := node.Insert(key, value, nil, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insert: %v", err)
|
||||
}
|
||||
|
||||
// Should create a StemNode
|
||||
stemNode, ok := newNode.(*StemNode)
|
||||
if !ok {
|
||||
t.Fatalf("Expected StemNode, got %T", newNode)
|
||||
}
|
||||
|
||||
// Check the stem (first 31 bytes of key)
|
||||
if !bytes.Equal(stemNode.Stem, key[:31]) {
|
||||
t.Errorf("Stem mismatch: expected %x, got %x", key[:31], stemNode.Stem)
|
||||
}
|
||||
|
||||
// Check the value at the correct index (last byte of key)
|
||||
if !bytes.Equal(stemNode.Values[key[31]], value) {
|
||||
t.Errorf("Value mismatch at index %d: expected %x, got %x", key[31], value, stemNode.Values[key[31]])
|
||||
}
|
||||
|
||||
// Check that other values are nil
|
||||
for i := 0; i < 256; i++ {
|
||||
if i != int(key[31]) && stemNode.Values[i] != nil {
|
||||
t.Errorf("Expected nil value at index %d, got %x", i, stemNode.Values[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmptyCopy tests the Copy method
|
||||
func TestEmptyCopy(t *testing.T) {
|
||||
node := Empty{}
|
||||
|
||||
copied := node.Copy()
|
||||
copiedEmpty, ok := copied.(Empty)
|
||||
if !ok {
|
||||
t.Fatalf("Expected Empty, got %T", copied)
|
||||
}
|
||||
|
||||
// Both should be empty
|
||||
if node != copiedEmpty {
|
||||
// Empty is a zero-value struct, so copies should be equal
|
||||
t.Errorf("Empty nodes should be equal")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmptyHash tests the Hash method
|
||||
func TestEmptyHash(t *testing.T) {
|
||||
node := Empty{}
|
||||
|
||||
hash := node.Hash()
|
||||
|
||||
// Empty node should have zero hash
|
||||
if hash != (common.Hash{}) {
|
||||
t.Errorf("Expected zero hash for empty node, got %x", hash)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmptyGetValuesAtStem tests the GetValuesAtStem method
|
||||
func TestEmptyGetValuesAtStem(t *testing.T) {
|
||||
node := Empty{}
|
||||
|
||||
stem := make([]byte, 31)
|
||||
values, err := node.GetValuesAtStem(stem, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Should return an array of 256 nil values
|
||||
if len(values) != 256 {
|
||||
t.Errorf("Expected 256 values, got %d", len(values))
|
||||
}
|
||||
|
||||
for i, v := range values {
|
||||
if v != nil {
|
||||
t.Errorf("Expected nil value at index %d, got %x", i, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmptyInsertValuesAtStem tests the InsertValuesAtStem method
|
||||
func TestEmptyInsertValuesAtStem(t *testing.T) {
|
||||
node := Empty{}
|
||||
|
||||
stem := make([]byte, 31)
|
||||
stem[0] = 0x42
|
||||
|
||||
var values [256][]byte
|
||||
values[0] = common.HexToHash("0x0101").Bytes()
|
||||
values[10] = common.HexToHash("0x0202").Bytes()
|
||||
values[255] = common.HexToHash("0x0303").Bytes()
|
||||
|
||||
newNode, err := node.InsertValuesAtStem(stem, values[:], nil, 5)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insert values: %v", err)
|
||||
}
|
||||
|
||||
// Should create a StemNode
|
||||
stemNode, ok := newNode.(*StemNode)
|
||||
if !ok {
|
||||
t.Fatalf("Expected StemNode, got %T", newNode)
|
||||
}
|
||||
|
||||
// Check the stem
|
||||
if !bytes.Equal(stemNode.Stem, stem) {
|
||||
t.Errorf("Stem mismatch: expected %x, got %x", stem, stemNode.Stem)
|
||||
}
|
||||
|
||||
// Check the depth
|
||||
if stemNode.depth != 5 {
|
||||
t.Errorf("Depth mismatch: expected 5, got %d", stemNode.depth)
|
||||
}
|
||||
|
||||
// Check the values
|
||||
if !bytes.Equal(stemNode.Values[0], values[0]) {
|
||||
t.Error("Value at index 0 mismatch")
|
||||
}
|
||||
if !bytes.Equal(stemNode.Values[10], values[10]) {
|
||||
t.Error("Value at index 10 mismatch")
|
||||
}
|
||||
if !bytes.Equal(stemNode.Values[255], values[255]) {
|
||||
t.Error("Value at index 255 mismatch")
|
||||
}
|
||||
|
||||
// Check that values is the same slice (not a copy)
|
||||
if &stemNode.Values[0] != &values[0] {
|
||||
t.Error("Expected values to be the same slice reference")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmptyCollectNodes tests the CollectNodes method
|
||||
func TestEmptyCollectNodes(t *testing.T) {
|
||||
node := Empty{}
|
||||
|
||||
var collected []BinaryNode
|
||||
flushFn := func(path []byte, n BinaryNode) {
|
||||
collected = append(collected, n)
|
||||
}
|
||||
|
||||
err := node.CollectNodes([]byte{0, 1, 0}, flushFn)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Should not collect anything for empty node
|
||||
if len(collected) != 0 {
|
||||
t.Errorf("Expected no collected nodes for empty, got %d", len(collected))
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmptyToDot tests the toDot method
|
||||
func TestEmptyToDot(t *testing.T) {
|
||||
node := Empty{}
|
||||
|
||||
dot := node.toDot("parent", "010")
|
||||
|
||||
// Should return empty string for empty node
|
||||
if dot != "" {
|
||||
t.Errorf("Expected empty string for empty node toDot, got %s", dot)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmptyGetHeight tests the GetHeight method
|
||||
func TestEmptyGetHeight(t *testing.T) {
|
||||
node := Empty{}
|
||||
|
||||
height := node.GetHeight()
|
||||
|
||||
// Empty node should have height 0
|
||||
if height != 0 {
|
||||
t.Errorf("Expected height 0 for empty node, got %d", height)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmptyInsertMarksDirty verifies that a StemNode produced by Empty.Insert
|
||||
// is marked dirty. Without this, CollectNodes would skip the freshly created
|
||||
// stem and its blob would never reach disk, producing "missing trie node"
|
||||
// errors on subsequent reads.
|
||||
func TestEmptyInsertMarksDirty(t *testing.T) {
|
||||
key := make([]byte, 32)
|
||||
key[0] = 0xaa
|
||||
val := make([]byte, 32)
|
||||
val[0] = 0xbb
|
||||
n, err := Empty{}.Insert(key, val, nil, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("Insert: %v", err)
|
||||
}
|
||||
sn, ok := n.(*StemNode)
|
||||
if !ok {
|
||||
t.Fatalf("expected *StemNode, got %T", n)
|
||||
}
|
||||
if !sn.dirty {
|
||||
t.Fatalf("stem produced by Empty.Insert must have dirty=true")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmptyInsertValuesAtStemMarksDirty is the analogous guard for the
|
||||
// bulk-insert entry point. Fresh stems created here must be dirty.
|
||||
func TestEmptyInsertValuesAtStemMarksDirty(t *testing.T) {
|
||||
key := make([]byte, 32)
|
||||
key[0] = 0xcc
|
||||
values := make([][]byte, 256)
|
||||
values[0] = make([]byte, 32)
|
||||
n, err := Empty{}.InsertValuesAtStem(key, values, nil, 3)
|
||||
if err != nil {
|
||||
t.Fatalf("InsertValuesAtStem: %v", err)
|
||||
}
|
||||
sn, ok := n.(*StemNode)
|
||||
if !ok {
|
||||
t.Fatalf("expected *StemNode, got %T", n)
|
||||
}
|
||||
if !sn.dirty {
|
||||
t.Fatalf("stem produced by Empty.InsertValuesAtStem must have dirty=true")
|
||||
}
|
||||
}
|
||||
|
|
@ -16,75 +16,9 @@
|
|||
|
||||
package bintrie
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
import "github.com/ethereum/go-ethereum/common"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
)
|
||||
|
||||
type HashedNode common.Hash
|
||||
|
||||
func (h HashedNode) Get(_ []byte, _ NodeResolverFn) ([]byte, error) {
|
||||
panic("not implemented") // TODO: Implement
|
||||
}
|
||||
|
||||
func (h HashedNode) Insert(key []byte, value []byte, resolver NodeResolverFn, depth int) (BinaryNode, error) {
|
||||
return nil, errors.New("insert not implemented for hashed node")
|
||||
}
|
||||
|
||||
func (h HashedNode) Copy() BinaryNode {
|
||||
nh := common.Hash(h)
|
||||
return HashedNode(nh)
|
||||
}
|
||||
|
||||
func (h HashedNode) Hash() common.Hash {
|
||||
return common.Hash(h)
|
||||
}
|
||||
|
||||
func (h HashedNode) GetValuesAtStem(_ []byte, _ NodeResolverFn) ([][]byte, error) {
|
||||
return nil, errors.New("attempted to get values from an unresolved node")
|
||||
}
|
||||
|
||||
func (h HashedNode) InsertValuesAtStem(stem []byte, values [][]byte, resolver NodeResolverFn, depth int) (BinaryNode, error) {
|
||||
// Step 1: Generate the path for this node's position in the tree
|
||||
path, err := keyToPath(depth, stem)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("InsertValuesAtStem path generation error: %w", err)
|
||||
}
|
||||
|
||||
if resolver == nil {
|
||||
return nil, errors.New("InsertValuesAtStem resolve error: resolver is nil")
|
||||
}
|
||||
|
||||
// Step 2: Resolve the hashed node to get the actual node data
|
||||
data, err := resolver(path, common.Hash(h))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
|
||||
}
|
||||
|
||||
// Step 3: Deserialize the resolved data into a concrete node
|
||||
node, err := DeserializeNodeWithHash(data, depth, common.Hash(h))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
|
||||
}
|
||||
|
||||
// Step 4: Call InsertValuesAtStem on the resolved concrete node
|
||||
return node.InsertValuesAtStem(stem, values, resolver, depth)
|
||||
}
|
||||
|
||||
func (h HashedNode) toDot(parent string, path string) string {
|
||||
me := fmt.Sprintf("hash%s", path)
|
||||
ret := fmt.Sprintf("%s [label=\"%x\"]\n", me, h)
|
||||
ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me)
|
||||
return ret
|
||||
}
|
||||
|
||||
func (h HashedNode) CollectNodes([]byte, NodeFlushFn) error {
|
||||
// HashedNodes are already persisted in the database and don't need to be collected.
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h HashedNode) GetHeight() int {
|
||||
panic("tried to get the height of a hashed node, this is a bug")
|
||||
// HashedNode represents an unresolved node that only stores its hash.
|
||||
type HashedNode struct {
|
||||
hash common.Hash
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,180 +18,137 @@ package bintrie
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
)
|
||||
|
||||
// TestHashedNodeHash tests the Hash method
|
||||
// TestHashedNodeHash tests the Hash method via NodeStore.
|
||||
func TestHashedNodeHash(t *testing.T) {
|
||||
hash := common.HexToHash("0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef")
|
||||
node := HashedNode(hash)
|
||||
s := NewNodeStore()
|
||||
ref := s.newHashedRef(hash)
|
||||
|
||||
// Hash should return the stored hash
|
||||
if node.Hash() != hash {
|
||||
t.Errorf("Hash mismatch: expected %x, got %x", hash, node.Hash())
|
||||
if s.ComputeHash(ref) != hash {
|
||||
t.Errorf("Hash mismatch: expected %x, got %x", hash, s.ComputeHash(ref))
|
||||
}
|
||||
}
|
||||
|
||||
// TestHashedNodeCopy tests the Copy method
|
||||
// TestHashedNodeCopy tests the Copy method via NodeStore.
|
||||
func TestHashedNodeCopy(t *testing.T) {
|
||||
hash := common.HexToHash("0xabcdef")
|
||||
node := HashedNode(hash)
|
||||
s := NewNodeStore()
|
||||
ref := s.newHashedRef(hash)
|
||||
s.SetRoot(ref)
|
||||
|
||||
copied := node.Copy()
|
||||
copiedHash, ok := copied.(HashedNode)
|
||||
if !ok {
|
||||
t.Fatalf("Expected HashedNode, got %T", copied)
|
||||
}
|
||||
ns := s.Copy()
|
||||
copiedHash := ns.ComputeHash(ns.Root())
|
||||
|
||||
// Hash should be the same
|
||||
if common.Hash(copiedHash) != hash {
|
||||
if copiedHash != hash {
|
||||
t.Errorf("Hash mismatch after copy: expected %x, got %x", hash, copiedHash)
|
||||
}
|
||||
|
||||
// But should be a different object
|
||||
if &node == &copiedHash {
|
||||
t.Error("Copy returned same object reference")
|
||||
}
|
||||
}
|
||||
|
||||
// TestHashedNodeInsert tests that Insert returns an error
|
||||
func TestHashedNodeInsert(t *testing.T) {
|
||||
node := HashedNode(common.HexToHash("0x1234"))
|
||||
|
||||
key := make([]byte, HashSize)
|
||||
value := make([]byte, HashSize)
|
||||
|
||||
_, err := node.Insert(key, value, nil, 0)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for Insert on HashedNode")
|
||||
}
|
||||
|
||||
if err.Error() != "insert not implemented for hashed node" {
|
||||
t.Errorf("Unexpected error message: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHashedNodeGetValuesAtStem tests that GetValuesAtStem returns an error
|
||||
func TestHashedNodeGetValuesAtStem(t *testing.T) {
|
||||
node := HashedNode(common.HexToHash("0x1234"))
|
||||
|
||||
stem := make([]byte, StemSize)
|
||||
_, err := node.GetValuesAtStem(stem, nil)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for GetValuesAtStem on HashedNode")
|
||||
}
|
||||
|
||||
if err.Error() != "attempted to get values from an unresolved node" {
|
||||
t.Errorf("Unexpected error message: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHashedNodeInsertValuesAtStem tests that InsertValuesAtStem returns an error
|
||||
// TestHashedNodeInsertValuesAtStem tests InsertValuesAtStem resolution via NodeStore.
|
||||
func TestHashedNodeInsertValuesAtStem(t *testing.T) {
|
||||
node := HashedNode(common.HexToHash("0x1234"))
|
||||
// Test 1: nil resolver should return an error
|
||||
s := NewNodeStore()
|
||||
hashedRef := s.newHashedRef(common.HexToHash("0x1234"))
|
||||
s.SetRoot(hashedRef)
|
||||
|
||||
stem := make([]byte, StemSize)
|
||||
values := make([][]byte, StemNodeWidth)
|
||||
|
||||
// Test 1: nil resolver should return an error
|
||||
_, err := node.InsertValuesAtStem(stem, values, nil, 0)
|
||||
err := s.InsertValuesAtStem(stem, values, nil)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for InsertValuesAtStem on HashedNode with nil resolver")
|
||||
}
|
||||
|
||||
if err.Error() != "InsertValuesAtStem resolve error: resolver is nil" {
|
||||
t.Errorf("Unexpected error message: %v", err)
|
||||
t.Fatal("Expected error for InsertValuesAtStem with nil resolver")
|
||||
}
|
||||
|
||||
// Test 2: mock resolver returning invalid data should return deserialization error
|
||||
mockResolver := func(path []byte, hash common.Hash) ([]byte, error) {
|
||||
// Return invalid/nonsense data that cannot be deserialized
|
||||
return []byte{0xff, 0xff, 0xff}, nil
|
||||
}
|
||||
|
||||
_, err = node.InsertValuesAtStem(stem, values, mockResolver, 0)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for InsertValuesAtStem on HashedNode with invalid resolver data")
|
||||
}
|
||||
s2 := NewNodeStore()
|
||||
hashedRef2 := s2.newHashedRef(common.HexToHash("0x1234"))
|
||||
s2.SetRoot(hashedRef2)
|
||||
|
||||
expectedPrefix := "InsertValuesAtStem node deserialization error:"
|
||||
if len(err.Error()) < len(expectedPrefix) || err.Error()[:len(expectedPrefix)] != expectedPrefix {
|
||||
t.Errorf("Expected deserialization error, got: %v", err)
|
||||
err = s2.InsertValuesAtStem(stem, values, mockResolver)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for InsertValuesAtStem with invalid resolver data")
|
||||
}
|
||||
|
||||
// Test 3: mock resolver returning valid serialized node should succeed
|
||||
stem = make([]byte, StemSize)
|
||||
stem[0] = 0xaa
|
||||
var originalValues [StemNodeWidth][]byte
|
||||
originalValues := make([][]byte, StemNodeWidth)
|
||||
originalValues[0] = common.HexToHash("0x1111111111111111111111111111111111111111111111111111111111111111").Bytes()
|
||||
originalValues[1] = common.HexToHash("0x2222222222222222222222222222222222222222222222222222222222222222").Bytes()
|
||||
|
||||
originalNode := &StemNode{
|
||||
Stem: stem,
|
||||
Values: originalValues[:],
|
||||
depth: 0,
|
||||
// Build the serialized node
|
||||
rs := NewNodeStore()
|
||||
ref := rs.newStemRef(stem, 0)
|
||||
sn := rs.getStem(ref.Index())
|
||||
for i, v := range originalValues {
|
||||
if v != nil {
|
||||
sn.setValue(byte(i), v)
|
||||
}
|
||||
}
|
||||
serialized := rs.SerializeNode(ref, MaxGroupDepth)
|
||||
|
||||
// Serialize the node
|
||||
serialized := SerializeNode(originalNode)
|
||||
|
||||
// Create a mock resolver that returns the serialized node
|
||||
validResolver := func(path []byte, hash common.Hash) ([]byte, error) {
|
||||
return serialized, nil
|
||||
}
|
||||
|
||||
var newValues [StemNodeWidth][]byte
|
||||
s3 := NewNodeStore()
|
||||
hashedRef3 := s3.newHashedRef(common.HexToHash("0x1234"))
|
||||
s3.SetRoot(hashedRef3)
|
||||
|
||||
newValues := make([][]byte, StemNodeWidth)
|
||||
newValues[2] = common.HexToHash("0x3333333333333333333333333333333333333333333333333333333333333333").Bytes()
|
||||
|
||||
resolvedNode, err := node.InsertValuesAtStem(stem, newValues[:], validResolver, 0)
|
||||
err = s3.InsertValuesAtStem(stem, newValues, validResolver)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected successful resolution and insertion, got error: %v", err)
|
||||
}
|
||||
|
||||
resultStem, ok := resolvedNode.(*StemNode)
|
||||
if !ok {
|
||||
t.Fatalf("Expected resolved node to be *StemNode, got %T", resolvedNode)
|
||||
// Verify original values are preserved
|
||||
retrieved, err := s3.GetValuesAtStem(stem, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(resultStem.Stem, stem) {
|
||||
t.Errorf("Stem mismatch: expected %x, got %x", stem, resultStem.Stem)
|
||||
if !bytes.Equal(retrieved[0], originalValues[0]) {
|
||||
t.Errorf("Original value at index 0 not preserved")
|
||||
}
|
||||
|
||||
// Verify the original values are preserved
|
||||
if !bytes.Equal(resultStem.Values[0], originalValues[0]) {
|
||||
t.Errorf("Original value at index 0 not preserved: expected %x, got %x", originalValues[0], resultStem.Values[0])
|
||||
if !bytes.Equal(retrieved[1], originalValues[1]) {
|
||||
t.Errorf("Original value at index 1 not preserved")
|
||||
}
|
||||
if !bytes.Equal(resultStem.Values[1], originalValues[1]) {
|
||||
t.Errorf("Original value at index 1 not preserved: expected %x, got %x", originalValues[1], resultStem.Values[1])
|
||||
}
|
||||
|
||||
// Verify the new value was inserted
|
||||
if !bytes.Equal(resultStem.Values[2], newValues[2]) {
|
||||
t.Errorf("New value at index 2 not inserted correctly: expected %x, got %x", newValues[2], resultStem.Values[2])
|
||||
if !bytes.Equal(retrieved[2], newValues[2]) {
|
||||
t.Errorf("New value at index 2 not inserted correctly")
|
||||
}
|
||||
}
|
||||
|
||||
// TestHashedNodeToDot tests the toDot method for visualization
|
||||
func TestHashedNodeToDot(t *testing.T) {
|
||||
hash := common.HexToHash("0x1234")
|
||||
node := HashedNode(hash)
|
||||
// TestHashedNodeGetError tests that getting through an unresolved HashedNode root returns error.
|
||||
func TestHashedNodeGetError(t *testing.T) {
|
||||
s := NewNodeStore()
|
||||
// Create root as hashed, then try to resolve through InternalNode parent
|
||||
rootRef := s.newInternalRef(0)
|
||||
rootNode := s.getInternal(rootRef.Index())
|
||||
hashedLeft := s.newHashedRef(common.HexToHash("0x1234"))
|
||||
rootNode.left = hashedLeft
|
||||
rootNode.right = EmptyRef
|
||||
s.SetRoot(rootRef)
|
||||
|
||||
dot := node.toDot("parent", "010")
|
||||
key := make([]byte, 32) // goes left
|
||||
key[31] = 5
|
||||
|
||||
// Should contain the hash value and parent connection
|
||||
expectedHash := "hash010"
|
||||
if !contains(dot, expectedHash) {
|
||||
t.Errorf("Expected dot output to contain %s", expectedHash)
|
||||
resolver := func(path []byte, hash common.Hash) ([]byte, error) {
|
||||
return nil, errors.New("node not found")
|
||||
}
|
||||
|
||||
if !contains(dot, "parent -> hash010") {
|
||||
t.Error("Expected dot output to contain parent connection")
|
||||
_, err := s.Get(key, resolver)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when resolver fails")
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && s != "" && len(substr) > 0
|
||||
}
|
||||
|
|
|
|||
|
|
@ -37,3 +37,11 @@ func newSha256() hash.Hash {
|
|||
func returnSha256(h hash.Hash) {
|
||||
sha256Pool.Put(h)
|
||||
}
|
||||
|
||||
// sha256Sum256 computes a sha256 digest and returns it as a common.Hash.
|
||||
func sha256Sum256(data []byte) [32]byte {
|
||||
return sha256.Sum256(data)
|
||||
}
|
||||
|
||||
// parallelHashDepth controls below which depth hashing is parallelised.
|
||||
const parallelHashDepth = 4
|
||||
|
|
|
|||
|
|
@ -17,35 +17,13 @@
|
|||
package bintrie
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/bits"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
)
|
||||
|
||||
// parallelDepth returns the tree depth below which Hash() spawns goroutines.
|
||||
func parallelDepth() int {
|
||||
return min(bits.Len(uint(runtime.NumCPU())), 8)
|
||||
}
|
||||
|
||||
// isDirty reports whether a BinaryNode child needs rehashing.
|
||||
func isDirty(n BinaryNode) bool {
|
||||
switch v := n.(type) {
|
||||
case *InternalNode:
|
||||
return v.mustRecompute
|
||||
case *StemNode:
|
||||
return v.mustRecompute
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func keyToPath(depth int, key []byte) ([]byte, error) {
|
||||
if depth > 31*8 {
|
||||
if depth >= 31*8 {
|
||||
return nil, errors.New("node too deep")
|
||||
}
|
||||
path := make([]byte, 0, depth+1)
|
||||
|
|
@ -56,252 +34,20 @@ func keyToPath(depth int, key []byte) ([]byte, error) {
|
|||
return path, nil
|
||||
}
|
||||
|
||||
// makeKeyPath is a simplified version of keyToPath that doesn't return an error.
|
||||
func makeKeyPath(depth int, key []byte) []byte {
|
||||
path := make([]byte, 0, depth+1)
|
||||
for i := range depth + 1 {
|
||||
bit := key[i/8] >> (7 - (i % 8)) & 1
|
||||
path = append(path, bit)
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
// InternalNode is a binary trie internal node.
|
||||
type InternalNode struct {
|
||||
left, right BinaryNode
|
||||
depth int
|
||||
|
||||
left, right NodeRef
|
||||
depth uint8
|
||||
mustRecompute bool // true if the hash needs to be recomputed
|
||||
dirty bool // true if the node's on-disk blob is stale (needs flush)
|
||||
hash common.Hash // cached hash when mustRecompute == false
|
||||
}
|
||||
|
||||
// GetValuesAtStem retrieves the group of values located at the given stem key.
|
||||
func (bt *InternalNode) GetValuesAtStem(stem []byte, resolver NodeResolverFn) ([][]byte, error) {
|
||||
if bt.depth > 31*8 {
|
||||
return nil, errors.New("node too deep")
|
||||
}
|
||||
|
||||
bit := stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1
|
||||
if bit == 0 {
|
||||
if hn, ok := bt.left.(HashedNode); ok {
|
||||
path, err := keyToPath(bt.depth, stem)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
|
||||
}
|
||||
data, err := resolver(path, common.Hash(hn))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
|
||||
}
|
||||
node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err)
|
||||
}
|
||||
bt.left = node
|
||||
}
|
||||
return bt.left.GetValuesAtStem(stem, resolver)
|
||||
}
|
||||
|
||||
if hn, ok := bt.right.(HashedNode); ok {
|
||||
path, err := keyToPath(bt.depth, stem)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
|
||||
}
|
||||
data, err := resolver(path, common.Hash(hn))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
|
||||
}
|
||||
node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err)
|
||||
}
|
||||
bt.right = node
|
||||
}
|
||||
return bt.right.GetValuesAtStem(stem, resolver)
|
||||
}
|
||||
|
||||
// Get retrieves the value for the given key.
|
||||
func (bt *InternalNode) Get(key []byte, resolver NodeResolverFn) ([]byte, error) {
|
||||
values, err := bt.GetValuesAtStem(key[:31], resolver)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get error: %w", err)
|
||||
}
|
||||
if values == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return values[key[31]], nil
|
||||
}
|
||||
|
||||
// Insert inserts a new key-value pair into the trie.
|
||||
func (bt *InternalNode) Insert(key []byte, value []byte, resolver NodeResolverFn, depth int) (BinaryNode, error) {
|
||||
var values [256][]byte
|
||||
values[key[31]] = value
|
||||
return bt.InsertValuesAtStem(key[:31], values[:], resolver, depth)
|
||||
}
|
||||
|
||||
// Copy creates a deep copy of the node.
|
||||
func (bt *InternalNode) Copy() BinaryNode {
|
||||
return &InternalNode{
|
||||
left: bt.left.Copy(),
|
||||
right: bt.right.Copy(),
|
||||
depth: bt.depth,
|
||||
mustRecompute: bt.mustRecompute,
|
||||
dirty: bt.dirty,
|
||||
hash: bt.hash,
|
||||
}
|
||||
}
|
||||
|
||||
// Hash returns the hash of the node.
|
||||
func (bt *InternalNode) Hash() common.Hash {
|
||||
if !bt.mustRecompute {
|
||||
return bt.hash
|
||||
}
|
||||
|
||||
// At shallow depths, parallelize when both children need rehashing:
|
||||
// hash left subtree in a goroutine, right subtree inline, then combine.
|
||||
// Skip goroutine overhead when only one child is dirty (common case
|
||||
// for narrow state updates that touch a single path through the trie).
|
||||
if bt.depth < parallelDepth() && isDirty(bt.left) && isDirty(bt.right) {
|
||||
var input [64]byte
|
||||
var lh common.Hash
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
lh = bt.left.Hash()
|
||||
}()
|
||||
rh := bt.right.Hash()
|
||||
copy(input[32:], rh[:])
|
||||
wg.Wait()
|
||||
copy(input[:32], lh[:])
|
||||
bt.hash = sha256.Sum256(input[:])
|
||||
bt.mustRecompute = false
|
||||
return bt.hash
|
||||
}
|
||||
|
||||
// Deeper nodes: sequential using pooled hasher (goroutine overhead > hash cost)
|
||||
h := newSha256()
|
||||
defer returnSha256(h)
|
||||
if bt.left != nil {
|
||||
h.Write(bt.left.Hash().Bytes())
|
||||
} else {
|
||||
h.Write(zero[:])
|
||||
}
|
||||
if bt.right != nil {
|
||||
h.Write(bt.right.Hash().Bytes())
|
||||
} else {
|
||||
h.Write(zero[:])
|
||||
}
|
||||
bt.hash = common.BytesToHash(h.Sum(nil))
|
||||
bt.mustRecompute = false
|
||||
return bt.hash
|
||||
}
|
||||
|
||||
// InsertValuesAtStem inserts a full value group at the given stem in the internal node.
|
||||
// Already-existing values will be overwritten.
|
||||
func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolver NodeResolverFn, depth int) (BinaryNode, error) {
|
||||
var err error
|
||||
bit := stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1
|
||||
if bit == 0 {
|
||||
if bt.left == nil {
|
||||
bt.left = Empty{}
|
||||
}
|
||||
|
||||
if hn, ok := bt.left.(HashedNode); ok {
|
||||
path, err := keyToPath(bt.depth, stem)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
|
||||
}
|
||||
data, err := resolver(path, common.Hash(hn))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
|
||||
}
|
||||
node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
|
||||
}
|
||||
bt.left = node
|
||||
}
|
||||
|
||||
bt.left, err = bt.left.InsertValuesAtStem(stem, values, resolver, depth+1)
|
||||
bt.mustRecompute = true
|
||||
bt.dirty = true
|
||||
return bt, err
|
||||
}
|
||||
|
||||
if bt.right == nil {
|
||||
bt.right = Empty{}
|
||||
}
|
||||
|
||||
if hn, ok := bt.right.(HashedNode); ok {
|
||||
path, err := keyToPath(bt.depth, stem)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
|
||||
}
|
||||
data, err := resolver(path, common.Hash(hn))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
|
||||
}
|
||||
node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
|
||||
}
|
||||
bt.right = node
|
||||
}
|
||||
|
||||
bt.right, err = bt.right.InsertValuesAtStem(stem, values, resolver, depth+1)
|
||||
bt.mustRecompute = true
|
||||
bt.dirty = true
|
||||
return bt, err
|
||||
}
|
||||
|
||||
// CollectNodes collects all child nodes at a given path, and flushes it
|
||||
// into the provided node collector. Clean subtrees (dirty == false) are
|
||||
// skipped.
|
||||
func (bt *InternalNode) CollectNodes(path []byte, flushfn NodeFlushFn) error {
|
||||
if !bt.dirty {
|
||||
return nil
|
||||
}
|
||||
if bt.left != nil {
|
||||
var p [256]byte
|
||||
copy(p[:], path)
|
||||
childpath := p[:len(path)]
|
||||
childpath = append(childpath, 0)
|
||||
if err := bt.left.CollectNodes(childpath, flushfn); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if bt.right != nil {
|
||||
var p [256]byte
|
||||
copy(p[:], path)
|
||||
childpath := p[:len(path)]
|
||||
childpath = append(childpath, 1)
|
||||
if err := bt.right.CollectNodes(childpath, flushfn); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
flushfn(path, bt)
|
||||
bt.dirty = false
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetHeight returns the height of the node.
|
||||
func (bt *InternalNode) GetHeight() int {
|
||||
var (
|
||||
leftHeight int
|
||||
rightHeight int
|
||||
)
|
||||
if bt.left != nil {
|
||||
leftHeight = bt.left.GetHeight()
|
||||
}
|
||||
if bt.right != nil {
|
||||
rightHeight = bt.right.GetHeight()
|
||||
}
|
||||
return 1 + max(leftHeight, rightHeight)
|
||||
}
|
||||
|
||||
func (bt *InternalNode) toDot(parent, path string) string {
|
||||
me := fmt.Sprintf("internal%s", path)
|
||||
ret := fmt.Sprintf("%s [label=\"I: %x\"]\n", me, bt.Hash())
|
||||
if len(parent) > 0 {
|
||||
ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me)
|
||||
}
|
||||
|
||||
if bt.left != nil {
|
||||
ret = fmt.Sprintf("%s%s", ret, bt.left.toDot(me, fmt.Sprintf("%s%02x", path, 0)))
|
||||
}
|
||||
if bt.right != nil {
|
||||
ret = fmt.Sprintf("%s%s", ret, bt.right.toDot(me, fmt.Sprintf("%s%02x", path, 1)))
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,35 +24,33 @@ import (
|
|||
"github.com/ethereum/go-ethereum/common"
|
||||
)
|
||||
|
||||
// TestInternalNodeGet tests the Get method
|
||||
// TestInternalNodeGet tests the Get method via NodeStore.
|
||||
func TestInternalNodeGet(t *testing.T) {
|
||||
// Create a simple tree structure
|
||||
s := NewNodeStore()
|
||||
|
||||
leftStem := make([]byte, 31)
|
||||
rightStem := make([]byte, 31)
|
||||
rightStem[0] = 0x80 // First bit is 1
|
||||
rightStem[0] = 0x80
|
||||
|
||||
var leftValues, rightValues [256][]byte
|
||||
leftValues := make([][]byte, 256)
|
||||
leftValues[0] = common.HexToHash("0x0101").Bytes()
|
||||
rightValues := make([][]byte, 256)
|
||||
rightValues[0] = common.HexToHash("0x0202").Bytes()
|
||||
|
||||
node := &InternalNode{
|
||||
depth: 0,
|
||||
left: &StemNode{
|
||||
Stem: leftStem,
|
||||
Values: leftValues[:],
|
||||
depth: 1,
|
||||
},
|
||||
right: &StemNode{
|
||||
Stem: rightStem,
|
||||
Values: rightValues[:],
|
||||
depth: 1,
|
||||
},
|
||||
// Build tree: root -> left stem, right stem
|
||||
// Insert left stem values
|
||||
s.SetRoot(EmptyRef)
|
||||
if err := s.InsertValuesAtStem(leftStem, leftValues, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := s.InsertValuesAtStem(rightStem, rightValues, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Get value from left subtree
|
||||
leftKey := make([]byte, 32)
|
||||
leftKey[31] = 0
|
||||
value, err := node.Get(leftKey, nil)
|
||||
value, err := s.Get(leftKey, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get left value: %v", err)
|
||||
}
|
||||
|
|
@ -64,7 +62,7 @@ func TestInternalNodeGet(t *testing.T) {
|
|||
rightKey := make([]byte, 32)
|
||||
rightKey[0] = 0x80
|
||||
rightKey[31] = 0
|
||||
value, err = node.Get(rightKey, nil)
|
||||
value, err = s.Get(rightKey, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get right value: %v", err)
|
||||
}
|
||||
|
|
@ -73,29 +71,26 @@ func TestInternalNodeGet(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// TestInternalNodeGetWithResolver tests Get with HashedNode resolution
|
||||
// TestInternalNodeGetWithResolver tests Get with HashedNode resolution via NodeStore.
|
||||
func TestInternalNodeGetWithResolver(t *testing.T) {
|
||||
// Create an internal node with a hashed child
|
||||
hashedChild := HashedNode(common.HexToHash("0x1234"))
|
||||
|
||||
node := &InternalNode{
|
||||
depth: 0,
|
||||
left: hashedChild,
|
||||
right: Empty{},
|
||||
}
|
||||
// Create a store with an internal node containing a hashed child
|
||||
s := NewNodeStore()
|
||||
hashedChild := s.newHashedRef(common.HexToHash("0x1234"))
|
||||
rootRef := s.newInternalRef(0)
|
||||
rootNode := s.getInternal(rootRef.Index())
|
||||
rootNode.left = hashedChild
|
||||
rootNode.right = EmptyRef
|
||||
s.SetRoot(rootRef)
|
||||
|
||||
// Mock resolver that returns a stem node
|
||||
resolver := func(path []byte, hash common.Hash) ([]byte, error) {
|
||||
if hash == common.Hash(hashedChild) {
|
||||
if hash == common.HexToHash("0x1234") {
|
||||
rs := NewNodeStore()
|
||||
stem := make([]byte, 31)
|
||||
var values [256][]byte
|
||||
values[5] = common.HexToHash("0xabcd").Bytes()
|
||||
stemNode := &StemNode{
|
||||
Stem: stem,
|
||||
Values: values[:],
|
||||
depth: 1,
|
||||
}
|
||||
return SerializeNode(stemNode), nil
|
||||
ref := rs.newStemRef(stem, 1)
|
||||
sn := rs.getStem(ref.Index())
|
||||
sn.setValue(5, common.HexToHash("0xabcd").Bytes())
|
||||
return rs.SerializeNode(ref, MaxGroupDepth), nil
|
||||
}
|
||||
return nil, errors.New("node not found")
|
||||
}
|
||||
|
|
@ -103,7 +98,7 @@ func TestInternalNodeGetWithResolver(t *testing.T) {
|
|||
// Get value through the hashed node
|
||||
key := make([]byte, 32)
|
||||
key[31] = 5
|
||||
value, err := node.Get(key, resolver)
|
||||
value, err := s.Get(key, resolver)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get value: %v", err)
|
||||
}
|
||||
|
|
@ -114,179 +109,113 @@ func TestInternalNodeGetWithResolver(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// TestInternalNodeInsert tests the Insert method
|
||||
// TestInternalNodeInsert tests the Insert method via NodeStore.
|
||||
func TestInternalNodeInsert(t *testing.T) {
|
||||
// Start with an internal node with empty children
|
||||
node := &InternalNode{
|
||||
depth: 0,
|
||||
left: Empty{},
|
||||
right: Empty{},
|
||||
}
|
||||
s := NewNodeStore()
|
||||
|
||||
// Insert a value into the left subtree
|
||||
leftKey := make([]byte, 32)
|
||||
leftKey[31] = 10
|
||||
leftValue := common.HexToHash("0x0101").Bytes()
|
||||
|
||||
newNode, err := node.Insert(leftKey, leftValue, nil, 0)
|
||||
if err != nil {
|
||||
if err := s.Insert(leftKey, leftValue, nil); err != nil {
|
||||
t.Fatalf("Failed to insert: %v", err)
|
||||
}
|
||||
|
||||
internalNode, ok := newNode.(*InternalNode)
|
||||
if !ok {
|
||||
t.Fatalf("Expected InternalNode, got %T", newNode)
|
||||
// Verify the value was stored
|
||||
value, err := s.Get(leftKey, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get: %v", err)
|
||||
}
|
||||
|
||||
// Check that left child is now a StemNode
|
||||
leftStem, ok := internalNode.left.(*StemNode)
|
||||
if !ok {
|
||||
t.Fatalf("Expected left child to be StemNode, got %T", internalNode.left)
|
||||
}
|
||||
|
||||
// Check the inserted value
|
||||
if !bytes.Equal(leftStem.Values[10], leftValue) {
|
||||
t.Errorf("Value mismatch: expected %x, got %x", leftValue, leftStem.Values[10])
|
||||
}
|
||||
|
||||
// Right child should still be Empty
|
||||
_, ok = internalNode.right.(Empty)
|
||||
if !ok {
|
||||
t.Errorf("Expected right child to remain Empty, got %T", internalNode.right)
|
||||
if !bytes.Equal(value, leftValue) {
|
||||
t.Errorf("Value mismatch: expected %x, got %x", leftValue, value)
|
||||
}
|
||||
}
|
||||
|
||||
// TestInternalNodeCopy tests the Copy method
|
||||
// TestInternalNodeCopy tests the Copy method via NodeStore.
|
||||
func TestInternalNodeCopy(t *testing.T) {
|
||||
// Create an internal node with stem children
|
||||
leftStem := &StemNode{
|
||||
Stem: make([]byte, 31),
|
||||
Values: make([][]byte, 256),
|
||||
depth: 1,
|
||||
}
|
||||
leftStem.Values[0] = common.HexToHash("0x0101").Bytes()
|
||||
s := NewNodeStore()
|
||||
|
||||
rightStem := &StemNode{
|
||||
Stem: make([]byte, 31),
|
||||
Values: make([][]byte, 256),
|
||||
depth: 1,
|
||||
}
|
||||
rightStem.Stem[0] = 0x80
|
||||
rightStem.Values[0] = common.HexToHash("0x0202").Bytes()
|
||||
leftKey := make([]byte, 32)
|
||||
leftKey[31] = 0
|
||||
leftValue := common.HexToHash("0x0101").Bytes()
|
||||
|
||||
node := &InternalNode{
|
||||
depth: 0,
|
||||
left: leftStem,
|
||||
right: rightStem,
|
||||
rightKey := make([]byte, 32)
|
||||
rightKey[0] = 0x80
|
||||
rightKey[31] = 0
|
||||
rightValue := common.HexToHash("0x0202").Bytes()
|
||||
|
||||
if err := s.Insert(leftKey, leftValue, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := s.Insert(rightKey, rightValue, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a copy
|
||||
copied := node.Copy()
|
||||
copiedInternal, ok := copied.(*InternalNode)
|
||||
if !ok {
|
||||
t.Fatalf("Expected InternalNode, got %T", copied)
|
||||
}
|
||||
ns := s.Copy()
|
||||
|
||||
// Check depth
|
||||
if copiedInternal.depth != node.depth {
|
||||
t.Errorf("Depth mismatch: expected %d, got %d", node.depth, copiedInternal.depth)
|
||||
}
|
||||
|
||||
// Check that children are copied
|
||||
copiedLeft, ok := copiedInternal.left.(*StemNode)
|
||||
if !ok {
|
||||
t.Fatalf("Expected left child to be StemNode, got %T", copiedInternal.left)
|
||||
}
|
||||
|
||||
copiedRight, ok := copiedInternal.right.(*StemNode)
|
||||
if !ok {
|
||||
t.Fatalf("Expected right child to be StemNode, got %T", copiedInternal.right)
|
||||
}
|
||||
|
||||
// Verify deep copy (children should be different objects)
|
||||
if copiedLeft == leftStem {
|
||||
t.Error("Left child not properly copied")
|
||||
}
|
||||
if copiedRight == rightStem {
|
||||
t.Error("Right child not properly copied")
|
||||
}
|
||||
|
||||
// But values should be equal
|
||||
if !bytes.Equal(copiedLeft.Values[0], leftStem.Values[0]) {
|
||||
// Values should be equal
|
||||
v1, _ := ns.Get(leftKey, nil)
|
||||
if !bytes.Equal(v1, leftValue) {
|
||||
t.Error("Left child value mismatch after copy")
|
||||
}
|
||||
if !bytes.Equal(copiedRight.Values[0], rightStem.Values[0]) {
|
||||
v2, _ := ns.Get(rightKey, nil)
|
||||
if !bytes.Equal(v2, rightValue) {
|
||||
t.Error("Right child value mismatch after copy")
|
||||
}
|
||||
}
|
||||
|
||||
// TestInternalNodeHash tests the Hash method
|
||||
// TestInternalNodeHash tests the Hash method via NodeStore.
|
||||
func TestInternalNodeHash(t *testing.T) {
|
||||
// Create an internal node
|
||||
node := &InternalNode{
|
||||
depth: 0,
|
||||
left: HashedNode(common.HexToHash("0x1111")),
|
||||
right: HashedNode(common.HexToHash("0x2222")),
|
||||
}
|
||||
s := NewNodeStore()
|
||||
leftRef := s.newHashedRef(common.HexToHash("0x1111"))
|
||||
rightRef := s.newHashedRef(common.HexToHash("0x2222"))
|
||||
rootRef := s.newInternalRef(0)
|
||||
rootNode := s.getInternal(rootRef.Index())
|
||||
rootNode.left = leftRef
|
||||
rootNode.right = rightRef
|
||||
s.SetRoot(rootRef)
|
||||
|
||||
hash1 := node.Hash()
|
||||
hash1 := s.ComputeHash(rootRef)
|
||||
|
||||
// Hash should be deterministic
|
||||
hash2 := node.Hash()
|
||||
hash2 := s.ComputeHash(rootRef)
|
||||
if hash1 != hash2 {
|
||||
t.Errorf("Hash not deterministic: %x != %x", hash1, hash2)
|
||||
}
|
||||
|
||||
// Changing a child should change the hash
|
||||
node.left = HashedNode(common.HexToHash("0x3333"))
|
||||
node.mustRecompute = true
|
||||
hash3 := node.Hash()
|
||||
rootNode.left = s.newHashedRef(common.HexToHash("0x3333"))
|
||||
rootNode.mustRecompute = true
|
||||
hash3 := s.ComputeHash(rootRef)
|
||||
if hash1 == hash3 {
|
||||
t.Error("Hash didn't change after modifying left child")
|
||||
}
|
||||
|
||||
// Test with nil children (should use zero hash)
|
||||
nodeWithNil := &InternalNode{
|
||||
depth: 0,
|
||||
left: nil,
|
||||
right: HashedNode(common.HexToHash("0x4444")),
|
||||
mustRecompute: true,
|
||||
}
|
||||
hashWithNil := nodeWithNil.Hash()
|
||||
if hashWithNil == (common.Hash{}) {
|
||||
t.Error("Hash shouldn't be zero even with nil child")
|
||||
}
|
||||
}
|
||||
|
||||
// TestInternalNodeGetValuesAtStem tests GetValuesAtStem method
|
||||
// TestInternalNodeGetValuesAtStem tests GetValuesAtStem method via NodeStore.
|
||||
func TestInternalNodeGetValuesAtStem(t *testing.T) {
|
||||
// Create a tree with values at different stems
|
||||
s := NewNodeStore()
|
||||
|
||||
leftStem := make([]byte, 31)
|
||||
rightStem := make([]byte, 31)
|
||||
rightStem[0] = 0x80
|
||||
|
||||
var leftValues, rightValues [256][]byte
|
||||
leftValues := make([][]byte, 256)
|
||||
leftValues[0] = common.HexToHash("0x0101").Bytes()
|
||||
leftValues[10] = common.HexToHash("0x0102").Bytes()
|
||||
rightValues := make([][]byte, 256)
|
||||
rightValues[0] = common.HexToHash("0x0201").Bytes()
|
||||
rightValues[20] = common.HexToHash("0x0202").Bytes()
|
||||
|
||||
node := &InternalNode{
|
||||
depth: 0,
|
||||
left: &StemNode{
|
||||
Stem: leftStem,
|
||||
Values: leftValues[:],
|
||||
depth: 1,
|
||||
},
|
||||
right: &StemNode{
|
||||
Stem: rightStem,
|
||||
Values: rightValues[:],
|
||||
depth: 1,
|
||||
},
|
||||
if err := s.InsertValuesAtStem(leftStem, leftValues, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := s.InsertValuesAtStem(rightStem, rightValues, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Get values from left stem
|
||||
values, err := node.GetValuesAtStem(leftStem, nil)
|
||||
values, err := s.GetValuesAtStem(leftStem, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get left values: %v", err)
|
||||
}
|
||||
|
|
@ -298,7 +227,7 @@ func TestInternalNodeGetValuesAtStem(t *testing.T) {
|
|||
}
|
||||
|
||||
// Get values from right stem
|
||||
values, err = node.GetValuesAtStem(rightStem, nil)
|
||||
values, err = s.GetValuesAtStem(rightStem, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get right values: %v", err)
|
||||
}
|
||||
|
|
@ -310,201 +239,103 @@ func TestInternalNodeGetValuesAtStem(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// TestInternalNodeInsertValuesAtStem tests InsertValuesAtStem method
|
||||
// TestInternalNodeInsertValuesAtStem tests InsertValuesAtStem method via NodeStore.
|
||||
func TestInternalNodeInsertValuesAtStem(t *testing.T) {
|
||||
// Start with an internal node with empty children
|
||||
node := &InternalNode{
|
||||
depth: 0,
|
||||
left: Empty{},
|
||||
right: Empty{},
|
||||
}
|
||||
s := NewNodeStore()
|
||||
|
||||
// Insert values at a stem in the left subtree
|
||||
stem := make([]byte, 31)
|
||||
var values [256][]byte
|
||||
values := make([][]byte, 256)
|
||||
values[5] = common.HexToHash("0x0505").Bytes()
|
||||
values[10] = common.HexToHash("0x1010").Bytes()
|
||||
|
||||
newNode, err := node.InsertValuesAtStem(stem, values[:], nil, 0)
|
||||
if err != nil {
|
||||
if err := s.InsertValuesAtStem(stem, values, nil); err != nil {
|
||||
t.Fatalf("Failed to insert values: %v", err)
|
||||
}
|
||||
|
||||
internalNode, ok := newNode.(*InternalNode)
|
||||
if !ok {
|
||||
t.Fatalf("Expected InternalNode, got %T", newNode)
|
||||
// Check that the values are stored
|
||||
retrieved, err := s.GetValuesAtStem(stem, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get values: %v", err)
|
||||
}
|
||||
|
||||
// Check that left child is now a StemNode with the values
|
||||
leftStem, ok := internalNode.left.(*StemNode)
|
||||
if !ok {
|
||||
t.Fatalf("Expected left child to be StemNode, got %T", internalNode.left)
|
||||
}
|
||||
|
||||
if !bytes.Equal(leftStem.Values[5], values[5]) {
|
||||
if !bytes.Equal(retrieved[5], values[5]) {
|
||||
t.Error("Value at index 5 mismatch")
|
||||
}
|
||||
if !bytes.Equal(leftStem.Values[10], values[10]) {
|
||||
if !bytes.Equal(retrieved[10], values[10]) {
|
||||
t.Error("Value at index 10 mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
// TestInternalNodeCollectNodes tests CollectNodes method
|
||||
// TestInternalNodeCollectNodes tests CollectNodes method via NodeStore.
|
||||
func TestInternalNodeCollectNodes(t *testing.T) {
|
||||
// Create an internal node with two stem children. All three are
|
||||
// marked dirty to mirror production semantics — see CollectNodes.
|
||||
leftStem := &StemNode{
|
||||
Stem: make([]byte, 31),
|
||||
Values: make([][]byte, 256),
|
||||
depth: 1,
|
||||
dirty: true,
|
||||
}
|
||||
s := NewNodeStore()
|
||||
|
||||
rightStem := &StemNode{
|
||||
Stem: make([]byte, 31),
|
||||
Values: make([][]byte, 256),
|
||||
depth: 1,
|
||||
dirty: true,
|
||||
}
|
||||
rightStem.Stem[0] = 0x80
|
||||
leftStem := make([]byte, 31)
|
||||
rightStem := make([]byte, 31)
|
||||
rightStem[0] = 0x80
|
||||
|
||||
node := &InternalNode{
|
||||
depth: 0,
|
||||
left: leftStem,
|
||||
right: rightStem,
|
||||
dirty: true,
|
||||
leftValues := make([][]byte, 256)
|
||||
rightValues := make([][]byte, 256)
|
||||
|
||||
if err := s.InsertValuesAtStem(leftStem, leftValues, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := s.InsertValuesAtStem(rightStem, rightValues, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var collectedPaths [][]byte
|
||||
var collectedNodes []BinaryNode
|
||||
|
||||
flushFn := func(path []byte, n BinaryNode) {
|
||||
flushFn := func(path []byte, hash common.Hash, serialized []byte) {
|
||||
pathCopy := make([]byte, len(path))
|
||||
copy(pathCopy, path)
|
||||
collectedPaths = append(collectedPaths, pathCopy)
|
||||
collectedNodes = append(collectedNodes, n)
|
||||
}
|
||||
|
||||
err := node.CollectNodes([]byte{1}, flushFn)
|
||||
err := s.CollectNodes(s.Root(), []byte{1}, flushFn, MaxGroupDepth)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to collect nodes: %v", err)
|
||||
}
|
||||
|
||||
// Should have collected 3 nodes: left stem, right stem, and the internal node itself
|
||||
if len(collectedNodes) != 3 {
|
||||
t.Errorf("Expected 3 collected nodes, got %d", len(collectedNodes))
|
||||
}
|
||||
|
||||
// Check paths
|
||||
expectedPaths := [][]byte{
|
||||
{1, 0}, // left child
|
||||
{1, 1}, // right child
|
||||
{1}, // internal node itself
|
||||
}
|
||||
|
||||
for i, expectedPath := range expectedPaths {
|
||||
if !bytes.Equal(collectedPaths[i], expectedPath) {
|
||||
t.Errorf("Path %d mismatch: expected %v, got %v", i, expectedPath, collectedPaths[i])
|
||||
}
|
||||
if len(collectedPaths) != 3 {
|
||||
t.Errorf("Expected 3 collected nodes, got %d", len(collectedPaths))
|
||||
}
|
||||
}
|
||||
|
||||
// TestInternalNodeCollectNodesSkipsClean verifies clean subtrees are not
|
||||
// flushed. A dirty internal node over clean children only flushes itself;
|
||||
// a fully clean tree flushes nothing.
|
||||
func TestInternalNodeCollectNodesSkipsClean(t *testing.T) {
|
||||
leftStem := &StemNode{
|
||||
Stem: make([]byte, 31),
|
||||
Values: make([][]byte, 256),
|
||||
depth: 1,
|
||||
}
|
||||
rightStem := &StemNode{
|
||||
Stem: make([]byte, 31),
|
||||
Values: make([][]byte, 256),
|
||||
depth: 1,
|
||||
}
|
||||
rightStem.Stem[0] = 0x80
|
||||
|
||||
dirtyParent := &InternalNode{
|
||||
depth: 0,
|
||||
left: leftStem,
|
||||
right: rightStem,
|
||||
dirty: true,
|
||||
}
|
||||
|
||||
var collected []BinaryNode
|
||||
flushFn := func(_ []byte, n BinaryNode) { collected = append(collected, n) }
|
||||
|
||||
if err := dirtyParent.CollectNodes([]byte{1}, flushFn); err != nil {
|
||||
t.Fatalf("CollectNodes: %v", err)
|
||||
}
|
||||
if len(collected) != 1 || collected[0] != dirtyParent {
|
||||
t.Fatalf("expected only the dirty parent to be flushed, got %d nodes", len(collected))
|
||||
}
|
||||
if dirtyParent.dirty {
|
||||
t.Errorf("parent dirty flag should be cleared after flush")
|
||||
}
|
||||
|
||||
// Second call on the same tree should be a no-op: everything is clean.
|
||||
collected = nil
|
||||
if err := dirtyParent.CollectNodes([]byte{1}, flushFn); err != nil {
|
||||
t.Fatalf("CollectNodes (second call): %v", err)
|
||||
}
|
||||
if len(collected) != 0 {
|
||||
t.Errorf("expected no nodes to be flushed on clean tree, got %d", len(collected))
|
||||
}
|
||||
}
|
||||
|
||||
// TestInternalNodeGetHeight tests GetHeight method
|
||||
// TestInternalNodeGetHeight tests GetHeight method via NodeStore.
|
||||
func TestInternalNodeGetHeight(t *testing.T) {
|
||||
// Create a tree with different heights
|
||||
// Left subtree: depth 2 (internal -> stem)
|
||||
// Right subtree: depth 1 (stem)
|
||||
leftInternal := &InternalNode{
|
||||
depth: 1,
|
||||
left: &StemNode{
|
||||
Stem: make([]byte, 31),
|
||||
Values: make([][]byte, 256),
|
||||
depth: 2,
|
||||
},
|
||||
right: Empty{},
|
||||
s := NewNodeStore()
|
||||
|
||||
// Insert values that create a deeper tree
|
||||
stem1 := make([]byte, 31) // left
|
||||
stem2 := make([]byte, 31)
|
||||
stem2[0] = 0x40 // 01... -> goes left at depth 0, right at depth 1
|
||||
|
||||
values1 := make([][]byte, 256)
|
||||
values1[0] = common.HexToHash("0x01").Bytes()
|
||||
values2 := make([][]byte, 256)
|
||||
values2[0] = common.HexToHash("0x02").Bytes()
|
||||
|
||||
if err := s.InsertValuesAtStem(stem1, values1, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := s.InsertValuesAtStem(stem2, values2, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rightStem := &StemNode{
|
||||
Stem: make([]byte, 31),
|
||||
Values: make([][]byte, 256),
|
||||
depth: 1,
|
||||
}
|
||||
|
||||
node := &InternalNode{
|
||||
depth: 0,
|
||||
left: leftInternal,
|
||||
right: rightStem,
|
||||
}
|
||||
|
||||
height := node.GetHeight()
|
||||
// Height should be max(left height, right height) + 1
|
||||
// Left height: 2, Right height: 1, so total: 3
|
||||
if height != 3 {
|
||||
t.Errorf("Expected height 3, got %d", height)
|
||||
height := s.GetHeight(s.Root())
|
||||
if height < 2 {
|
||||
t.Errorf("Expected height >= 2, got %d", height)
|
||||
}
|
||||
}
|
||||
|
||||
// TestInternalNodeDepthTooLarge tests handling of excessive depth
|
||||
// TestInternalNodeDepthTooLarge tests handling of excessive depth via NodeStore.
|
||||
func TestInternalNodeDepthTooLarge(t *testing.T) {
|
||||
// Create an internal node at max depth
|
||||
node := &InternalNode{
|
||||
depth: 31*8 + 1,
|
||||
left: Empty{},
|
||||
right: Empty{},
|
||||
}
|
||||
|
||||
stem := make([]byte, 31)
|
||||
_, err := node.GetValuesAtStem(stem, nil)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for excessive depth")
|
||||
}
|
||||
if err.Error() != "node too deep" {
|
||||
t.Errorf("Expected 'node too deep' error, got: %v", err)
|
||||
}
|
||||
s := NewNodeStore()
|
||||
// Creating an internal node beyond max depth should panic
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Fatal("Expected panic for excessive depth")
|
||||
}
|
||||
}()
|
||||
s.newInternalRef(31*8 + 1)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -26,13 +26,14 @@ import (
|
|||
var errIteratorEnd = errors.New("end of iteration")
|
||||
|
||||
type binaryNodeIteratorState struct {
|
||||
Node BinaryNode
|
||||
Node NodeRef
|
||||
Index int
|
||||
}
|
||||
|
||||
type binaryNodeIterator struct {
|
||||
trie *BinaryTrie
|
||||
current BinaryNode
|
||||
store *NodeStore
|
||||
current NodeRef
|
||||
lastErr error
|
||||
|
||||
stack []binaryNodeIteratorState
|
||||
|
|
@ -40,56 +41,59 @@ type binaryNodeIterator struct {
|
|||
|
||||
func newBinaryNodeIterator(t *BinaryTrie, _ []byte) (trie.NodeIterator, error) {
|
||||
if t.Hash() == zero {
|
||||
return &binaryNodeIterator{trie: t, lastErr: errIteratorEnd}, nil
|
||||
return &binaryNodeIterator{trie: t, store: t.store, lastErr: errIteratorEnd}, nil
|
||||
}
|
||||
it := &binaryNodeIterator{trie: t, current: t.root}
|
||||
// it.err = it.seek(start)
|
||||
it := &binaryNodeIterator{trie: t, store: t.store, current: t.store.Root()}
|
||||
return it, nil
|
||||
}
|
||||
|
||||
// Next moves the iterator to the next node. If the parameter is false, any child
|
||||
// nodes will be skipped.
|
||||
// Next moves the iterator to the next node.
|
||||
func (it *binaryNodeIterator) Next(descend bool) bool {
|
||||
if it.lastErr == errIteratorEnd {
|
||||
it.lastErr = errIteratorEnd
|
||||
return false
|
||||
}
|
||||
|
||||
if len(it.stack) == 0 {
|
||||
it.stack = append(it.stack, binaryNodeIteratorState{Node: it.trie.root})
|
||||
it.current = it.trie.root
|
||||
|
||||
it.stack = append(it.stack, binaryNodeIteratorState{Node: it.trie.store.Root()})
|
||||
it.current = it.trie.store.Root()
|
||||
return true
|
||||
}
|
||||
|
||||
switch node := it.current.(type) {
|
||||
case *InternalNode:
|
||||
// index: 0 = nothing visited, 1=left visited, 2=right visited
|
||||
switch it.current.Kind() {
|
||||
case KindInternal:
|
||||
node := it.store.getInternal(it.current.Index())
|
||||
context := &it.stack[len(it.stack)-1]
|
||||
|
||||
// recurse into both children
|
||||
if !descend {
|
||||
// Skip children: pop this node and advance parent
|
||||
if len(it.stack) == 1 {
|
||||
it.lastErr = errIteratorEnd
|
||||
return false
|
||||
}
|
||||
it.stack = it.stack[:len(it.stack)-1]
|
||||
it.current = it.stack[len(it.stack)-1].Node
|
||||
it.stack[len(it.stack)-1].Index++
|
||||
return it.Next(true)
|
||||
}
|
||||
|
||||
if context.Index == 0 {
|
||||
if _, isempty := node.left.(Empty); node.left != nil && !isempty {
|
||||
if !node.left.IsEmpty() {
|
||||
it.stack = append(it.stack, binaryNodeIteratorState{Node: node.left})
|
||||
it.current = node.left
|
||||
return it.Next(descend)
|
||||
}
|
||||
|
||||
context.Index++
|
||||
}
|
||||
|
||||
if context.Index == 1 {
|
||||
if _, isempty := node.right.(Empty); node.right != nil && !isempty {
|
||||
if !node.right.IsEmpty() {
|
||||
it.stack = append(it.stack, binaryNodeIteratorState{Node: node.right})
|
||||
it.current = node.right
|
||||
return it.Next(descend)
|
||||
}
|
||||
|
||||
context.Index++
|
||||
}
|
||||
|
||||
// Reached the end of this node, go back to the parent, if
|
||||
// this isn't root.
|
||||
if len(it.stack) == 1 {
|
||||
it.lastErr = errIteratorEnd
|
||||
return false
|
||||
|
|
@ -98,17 +102,16 @@ func (it *binaryNodeIterator) Next(descend bool) bool {
|
|||
it.current = it.stack[len(it.stack)-1].Node
|
||||
it.stack[len(it.stack)-1].Index++
|
||||
return it.Next(descend)
|
||||
case *StemNode:
|
||||
// Look for the next non-empty value
|
||||
|
||||
case KindStem:
|
||||
sn := it.store.getStem(it.current.Index())
|
||||
for i := it.stack[len(it.stack)-1].Index; i < 256; i++ {
|
||||
if node.Values[i] != nil {
|
||||
if sn.hasValue(byte(i)) {
|
||||
it.stack[len(it.stack)-1].Index = i + 1
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// go back to parent to get the next leaf
|
||||
// Check if we're at the root before popping
|
||||
if len(it.stack) == 1 {
|
||||
it.lastErr = errIteratorEnd
|
||||
return false
|
||||
|
|
@ -117,45 +120,39 @@ func (it *binaryNodeIterator) Next(descend bool) bool {
|
|||
it.current = it.stack[len(it.stack)-1].Node
|
||||
it.stack[len(it.stack)-1].Index++
|
||||
return it.Next(descend)
|
||||
case HashedNode:
|
||||
// resolve the node
|
||||
resolverPath := it.Path()
|
||||
data, err := it.trie.nodeResolver(resolverPath, common.Hash(node))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if data == nil {
|
||||
// Empty/nil node — treat as Empty, backtrack
|
||||
it.current = Empty{}
|
||||
it.stack[len(it.stack)-1].Node = it.current
|
||||
return it.Next(descend)
|
||||
}
|
||||
it.current, err = DeserializeNodeWithHash(data, len(it.stack)-1, common.Hash(node))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// update the stack and parent with the resolved node
|
||||
it.stack[len(it.stack)-1].Node = it.current
|
||||
if len(it.stack) >= 2 {
|
||||
parent := &it.stack[len(it.stack)-2]
|
||||
if parent.Index == 0 {
|
||||
parent.Node.(*InternalNode).left = it.current
|
||||
} else {
|
||||
parent.Node.(*InternalNode).right = it.current
|
||||
}
|
||||
}
|
||||
return it.Next(descend)
|
||||
case Empty:
|
||||
// Empty node - go back to parent and continue
|
||||
if len(it.stack) <= 1 {
|
||||
it.lastErr = errIteratorEnd
|
||||
case KindHashed:
|
||||
if len(it.stack) < 2 {
|
||||
it.lastErr = errors.New("cannot resolve hashed root during iteration")
|
||||
return false
|
||||
}
|
||||
it.stack = it.stack[:len(it.stack)-1]
|
||||
it.current = it.stack[len(it.stack)-1].Node
|
||||
it.stack[len(it.stack)-1].Index++
|
||||
hn := it.store.getHashed(it.current.Index())
|
||||
data, err := it.trie.nodeResolver(it.Path(), hn.hash)
|
||||
if err != nil {
|
||||
it.lastErr = err
|
||||
return false
|
||||
}
|
||||
resolved, err := it.store.DeserializeNodeWithHash(data, len(it.stack)-1, hn.hash)
|
||||
if err != nil {
|
||||
it.lastErr = err
|
||||
return false
|
||||
}
|
||||
|
||||
// Update the stack and parent with the resolved node
|
||||
it.current = resolved
|
||||
it.stack[len(it.stack)-1].Node = resolved
|
||||
parent := &it.stack[len(it.stack)-2]
|
||||
parentNode := it.store.getInternal(parent.Node.Index())
|
||||
if parent.Index == 0 {
|
||||
parentNode.left = resolved
|
||||
} else {
|
||||
parentNode.right = resolved
|
||||
}
|
||||
return it.Next(descend)
|
||||
|
||||
case KindEmpty:
|
||||
return false
|
||||
|
||||
default:
|
||||
panic("invalid node type")
|
||||
}
|
||||
|
|
@ -171,25 +168,21 @@ func (it *binaryNodeIterator) Error() error {
|
|||
|
||||
// Hash returns the hash of the current node.
|
||||
func (it *binaryNodeIterator) Hash() common.Hash {
|
||||
return it.current.Hash()
|
||||
return it.store.ComputeHash(it.current)
|
||||
}
|
||||
|
||||
// Parent returns the hash of the parent of the current node. The hash may be the one
|
||||
// grandparent if the immediate parent is an internal node with no hash.
|
||||
// Parent returns the hash of the parent of the current node.
|
||||
func (it *binaryNodeIterator) Parent() common.Hash {
|
||||
return it.stack[len(it.stack)-1].Node.Hash()
|
||||
return it.store.ComputeHash(it.stack[len(it.stack)-1].Node)
|
||||
}
|
||||
|
||||
// Path returns the hex-encoded path to the current node.
|
||||
// Callers must not retain references to the return value after calling Next.
|
||||
// For leaf nodes, the last element of the path is the 'terminator symbol' 0x10.
|
||||
func (it *binaryNodeIterator) Path() []byte {
|
||||
if it.Leaf() {
|
||||
return it.LeafKey()
|
||||
}
|
||||
var path []byte
|
||||
for i, state := range it.stack {
|
||||
// skip the last byte
|
||||
if i >= len(it.stack)-1 {
|
||||
break
|
||||
}
|
||||
|
|
@ -200,105 +193,78 @@ func (it *binaryNodeIterator) Path() []byte {
|
|||
|
||||
// NodeBlob returns the serialized bytes of the current node.
|
||||
func (it *binaryNodeIterator) NodeBlob() []byte {
|
||||
return SerializeNode(it.current)
|
||||
return it.store.SerializeNode(it.current, MaxGroupDepth)
|
||||
}
|
||||
|
||||
// Leaf returns true iff the current node is a leaf node.
|
||||
// In a Binary Trie, a StemNode contains up to 256 leaf values.
|
||||
// The iterator is only considered to be "at a leaf" when it's positioned
|
||||
// at a specific non-nil value within the StemNode, not just at the StemNode itself.
|
||||
func (it *binaryNodeIterator) Leaf() bool {
|
||||
sn, ok := it.current.(*StemNode)
|
||||
if !ok {
|
||||
if it.current.Kind() != KindStem {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if we have a valid stack position
|
||||
if len(it.stack) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// The Index in the stack state points to the NEXT position after the current value.
|
||||
// So if Index is 0, we haven't started iterating through the values yet.
|
||||
// If Index is 5, we're currently at value[4] (the 5th value, 0-indexed).
|
||||
idx := it.stack[len(it.stack)-1].Index
|
||||
if idx == 0 || idx > 256 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if there's actually a value at the current position
|
||||
sn := it.store.getStem(it.current.Index())
|
||||
currentValueIndex := idx - 1
|
||||
return sn.Values[currentValueIndex] != nil
|
||||
return sn.hasValue(byte(currentValueIndex))
|
||||
}
|
||||
|
||||
// LeafKey returns the key of the leaf. The method panics if the iterator is not
|
||||
// positioned at a leaf. Callers must not retain references to the value after
|
||||
// calling Next.
|
||||
// LeafKey returns the key of the leaf.
|
||||
func (it *binaryNodeIterator) LeafKey() []byte {
|
||||
leaf, ok := it.current.(*StemNode)
|
||||
if !ok {
|
||||
if it.current.Kind() != KindStem {
|
||||
panic("Leaf() called on an binary node iterator not at a leaf location")
|
||||
}
|
||||
return leaf.Key(it.stack[len(it.stack)-1].Index - 1)
|
||||
sn := it.store.getStem(it.current.Index())
|
||||
return sn.Key(it.stack[len(it.stack)-1].Index - 1)
|
||||
}
|
||||
|
||||
// LeafBlob returns the content of the leaf. The method panics if the iterator
|
||||
// is not positioned at a leaf. Callers must not retain references to the value
|
||||
// after calling Next.
|
||||
// LeafBlob returns the content of the leaf.
|
||||
func (it *binaryNodeIterator) LeafBlob() []byte {
|
||||
leaf, ok := it.current.(*StemNode)
|
||||
if !ok {
|
||||
if it.current.Kind() != KindStem {
|
||||
panic("LeafBlob() called on an binary node iterator not at a leaf location")
|
||||
}
|
||||
return leaf.Values[it.stack[len(it.stack)-1].Index-1]
|
||||
sn := it.store.getStem(it.current.Index())
|
||||
return sn.getValue(byte(it.stack[len(it.stack)-1].Index - 1))
|
||||
}
|
||||
|
||||
// LeafProof returns the Merkle proof of the leaf. The method panics if the
|
||||
// iterator is not positioned at a leaf. Callers must not retain references
|
||||
// to the value after calling Next.
|
||||
// LeafProof returns the Merkle proof of the leaf.
|
||||
func (it *binaryNodeIterator) LeafProof() [][]byte {
|
||||
sn, ok := it.current.(*StemNode)
|
||||
if !ok {
|
||||
if it.current.Kind() != KindStem {
|
||||
panic("LeafProof() called on an binary node iterator not at a leaf location")
|
||||
}
|
||||
sn := it.store.getStem(it.current.Index())
|
||||
|
||||
proof := make([][]byte, 0, len(it.stack)+StemNodeWidth)
|
||||
|
||||
// Build proof by walking up the stack and collecting sibling hashes
|
||||
for i := range it.stack[:len(it.stack)-2] {
|
||||
state := it.stack[i]
|
||||
internalNode := state.Node.(*InternalNode) // should panic if the node isn't an InternalNode
|
||||
internalNode := it.store.getInternal(state.Node.Index())
|
||||
|
||||
// Add the sibling hash to the proof
|
||||
if state.Index == 0 {
|
||||
// We came from left, so include right sibling
|
||||
proof = append(proof, internalNode.right.Hash().Bytes())
|
||||
rh := it.store.ComputeHash(internalNode.right)
|
||||
proof = append(proof, rh.Bytes())
|
||||
} else {
|
||||
// We came from right, so include left sibling
|
||||
proof = append(proof, internalNode.left.Hash().Bytes())
|
||||
lh := it.store.ComputeHash(internalNode.left)
|
||||
proof = append(proof, lh.Bytes())
|
||||
}
|
||||
}
|
||||
|
||||
// Add the stem and siblings
|
||||
proof = append(proof, sn.Stem)
|
||||
for _, v := range sn.Values {
|
||||
proof = append(proof, v)
|
||||
}
|
||||
proof = append(proof, sn.Stem[:])
|
||||
proof = append(proof, sn.allValues()...)
|
||||
|
||||
return proof
|
||||
}
|
||||
|
||||
// AddResolver sets an intermediate database to use for looking up trie nodes
|
||||
// before reaching into the real persistent layer.
|
||||
//
|
||||
// This is not required for normal operation, rather is an optimization for
|
||||
// cases where trie nodes can be recovered from some external mechanism without
|
||||
// reading from disk. In those cases, this resolver allows short circuiting
|
||||
// accesses and returning them from memory.
|
||||
//
|
||||
// Before adding a similar mechanism to any other place in Geth, consider
|
||||
// making trie.Database an interface and wrapping at that level. It's a huge
|
||||
// refactor, but it could be worth it if another occurrence arises.
|
||||
func (it *binaryNodeIterator) AddResolver(trie.NodeResolver) {
|
||||
// Not implemented, but should not panic
|
||||
}
|
||||
|
|
|
|||
|
|
@ -27,14 +27,14 @@ import (
|
|||
// makeTrie creates a BinaryTrie populated with the given key-value pairs.
|
||||
func makeTrie(t *testing.T, entries [][2]common.Hash) *BinaryTrie {
|
||||
t.Helper()
|
||||
store := NewNodeStore()
|
||||
tr := &BinaryTrie{
|
||||
root: NewBinaryNode(),
|
||||
tracer: trie.NewPrevalueTracer(),
|
||||
store: store,
|
||||
tracer: trie.NewPrevalueTracer(),
|
||||
groupDepth: MaxGroupDepth,
|
||||
}
|
||||
for _, kv := range entries {
|
||||
var err error
|
||||
tr.root, err = tr.root.Insert(kv[0][:], kv[1][:], nil, 0)
|
||||
if err != nil {
|
||||
if err := store.Insert(kv[0][:], kv[1][:], nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
|
@ -64,8 +64,9 @@ func countLeaves(t *testing.T, tr *BinaryTrie) int {
|
|||
// no nodes and reports no error.
|
||||
func TestIteratorEmptyTrie(t *testing.T) {
|
||||
tr := &BinaryTrie{
|
||||
root: Empty{},
|
||||
tracer: trie.NewPrevalueTracer(),
|
||||
store: NewNodeStore(),
|
||||
tracer: trie.NewPrevalueTracer(),
|
||||
groupDepth: MaxGroupDepth,
|
||||
}
|
||||
it, err := newBinaryNodeIterator(tr, nil)
|
||||
if err != nil {
|
||||
|
|
@ -145,8 +146,8 @@ func TestIteratorEmptyNodeBacktrack(t *testing.T) {
|
|||
{common.HexToHash("8000000000000000000000000000000000000000000000000000000000000001"), oneKey},
|
||||
})
|
||||
|
||||
if _, ok := tr.root.(*InternalNode); !ok {
|
||||
t.Fatalf("expected InternalNode root, got %T", tr.root)
|
||||
if tr.store.Root().Kind() != KindInternal {
|
||||
t.Fatalf("expected InternalNode root, got kind %d", tr.store.Root().Kind())
|
||||
}
|
||||
if leaves := countLeaves(t, tr); leaves != 2 {
|
||||
t.Fatalf("expected 2 leaves, got %d (Empty backtrack bug?)", leaves)
|
||||
|
|
@ -162,18 +163,31 @@ func TestIteratorHashedNodeNilData(t *testing.T) {
|
|||
{common.HexToHash("8000000000000000000000000000000000000000000000000000000000000001"), oneKey},
|
||||
})
|
||||
|
||||
root, ok := tr.root.(*InternalNode)
|
||||
if !ok {
|
||||
t.Fatalf("expected InternalNode root, got %T", tr.root)
|
||||
root := tr.store.Root()
|
||||
if root.Kind() != KindInternal {
|
||||
t.Fatalf("expected InternalNode root, got kind %d", root.Kind())
|
||||
}
|
||||
rootNode := tr.store.getInternal(root.Index())
|
||||
|
||||
// Replace right child with a zero-hash HashedNode. nodeResolver
|
||||
// short-circuits on common.Hash{} and returns (nil, nil), which
|
||||
// triggers the nil-data guard in the iterator.
|
||||
root.right = HashedNode(common.Hash{})
|
||||
rootNode.right = tr.store.newHashedRef(common.Hash{})
|
||||
|
||||
// Should not panic; the zero-hash right child should be treated as Empty.
|
||||
if leaves := countLeaves(t, tr); leaves != 1 {
|
||||
// Since the hashed node can't be resolved (nil data -> empty deserialization),
|
||||
// only the left leaf should be counted.
|
||||
it, err := newBinaryNodeIterator(tr, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
leaves := 0
|
||||
for it.Next(true) {
|
||||
if it.Leaf() {
|
||||
leaves++
|
||||
}
|
||||
}
|
||||
if leaves != 1 {
|
||||
t.Fatalf("expected 1 leaf (zero-hash right node skipped), got %d", leaves)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
57
trie/bintrie/node_ref.go
Normal file
57
trie/bintrie/node_ref.go
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
// Copyright 2025 go-ethereum Authors
|
||||
// This file is part of the go-ethereum library.
|
||||
//
|
||||
// The go-ethereum library is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Lesser General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// The go-ethereum library is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Lesser General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Lesser General Public License
|
||||
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package bintrie
|
||||
|
||||
// NodeKind identifies the type of a trie node stored in a NodeRef.
|
||||
type NodeKind uint8
|
||||
|
||||
const (
|
||||
KindEmpty NodeKind = iota // no node
|
||||
KindInternal // internal binary branching node
|
||||
KindStem // leaf group containing up to 256 values
|
||||
KindHashed // unresolved node (hash only)
|
||||
KindInvalid // sentinel for validation
|
||||
)
|
||||
|
||||
// NodeRef is a compact, GC-invisible reference to a node in a NodeStore.
|
||||
// It packs a 2-bit type tag (bits 31-30) and a 30-bit index (bits 29-0)
|
||||
// into a single uint32. Because NodeRef contains no Go pointers, slices
|
||||
// of structs containing NodeRef fields are allocated in noscan spans —
|
||||
// the garbage collector never examines them.
|
||||
type NodeRef uint32
|
||||
|
||||
const (
|
||||
kindShift uint32 = 30
|
||||
indexMask uint32 = (1 << kindShift) - 1
|
||||
|
||||
// EmptyRef is the zero-value NodeRef, representing an empty node.
|
||||
EmptyRef NodeRef = 0
|
||||
)
|
||||
|
||||
// MakeRef creates a NodeRef from a kind and pool index.
|
||||
func MakeRef(kind NodeKind, idx uint32) NodeRef {
|
||||
return NodeRef(uint32(kind)<<kindShift | (idx & indexMask))
|
||||
}
|
||||
|
||||
// Kind returns the node type tag.
|
||||
func (r NodeRef) Kind() NodeKind { return NodeKind(uint32(r) >> kindShift) }
|
||||
|
||||
// Index returns the pool index within the node's typed pool.
|
||||
func (r NodeRef) Index() uint32 { return uint32(r) & indexMask }
|
||||
|
||||
// IsEmpty returns true if this ref represents an empty node.
|
||||
func (r NodeRef) IsEmpty() bool { return r == EmptyRef }
|
||||
239
trie/bintrie/node_store.go
Normal file
239
trie/bintrie/node_store.go
Normal file
|
|
@ -0,0 +1,239 @@
|
|||
// Copyright 2025 go-ethereum Authors
|
||||
// This file is part of the go-ethereum library.
|
||||
//
|
||||
// The go-ethereum library is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Lesser General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// The go-ethereum library is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Lesser General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Lesser General Public License
|
||||
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package bintrie
|
||||
|
||||
import "github.com/ethereum/go-ethereum/common"
|
||||
|
||||
// storeChunkSize is the number of nodes per chunk in each typed pool.
|
||||
// Using fixed-size array chunks ensures that pointers to nodes within
|
||||
// existing chunks remain valid when new chunks are added (no reallocation
|
||||
// of the backing data, only the outer pointer slice grows).
|
||||
const storeChunkSize = 4096
|
||||
|
||||
// NodeStore is a GC-friendly arena for binary trie nodes.
|
||||
//
|
||||
// Instead of allocating each node as a separate heap object with interface
|
||||
// pointers (which the GC must scan), NodeStore packs nodes into typed chunked
|
||||
// pools. InternalNode and HashedNode contain ZERO Go pointers, so their pool
|
||||
// backing arrays are allocated in noscan spans — the GC skips them entirely.
|
||||
// StemNode has one pointer (valueData []byte) per node.
|
||||
//
|
||||
// For a trie with 25K InternalNodes, this reduces GC-scanned pointer-words
|
||||
// from ~125K (with interface-based nodes) to ~25K (just StemNode valueData),
|
||||
// an ~80% reduction. At mainnet scale (millions of nodes), this prevents
|
||||
// multi-second GC pauses.
|
||||
type NodeStore struct {
|
||||
// InternalNode pool — NOSCAN: InternalNode contains zero Go pointers.
|
||||
// Children are NodeRef (uint32), hash is [32]byte.
|
||||
internalChunks []*[storeChunkSize]InternalNode
|
||||
internalCount uint32
|
||||
|
||||
// StemNode pool — each StemNode has one pointer (valueData []byte).
|
||||
// Still much better than the old design where each InternalNode had
|
||||
// two BinaryNode interface pointers (4 pointer-words each).
|
||||
stemChunks []*[storeChunkSize]StemNode
|
||||
stemCount uint32
|
||||
|
||||
// HashedNode pool — NOSCAN: HashedNode is just [32]byte.
|
||||
hashedChunks []*[storeChunkSize]HashedNode
|
||||
hashedCount uint32
|
||||
|
||||
root NodeRef
|
||||
|
||||
// Free lists for recycling deleted node slots.
|
||||
freeInternals []uint32
|
||||
freeStems []uint32
|
||||
freeHashed []uint32
|
||||
}
|
||||
|
||||
// NewNodeStore creates a new empty NodeStore.
|
||||
func NewNodeStore() *NodeStore {
|
||||
return &NodeStore{root: EmptyRef}
|
||||
}
|
||||
|
||||
// Root returns the root NodeRef.
|
||||
func (s *NodeStore) Root() NodeRef { return s.root }
|
||||
|
||||
// SetRoot sets the root NodeRef.
|
||||
func (s *NodeStore) SetRoot(ref NodeRef) { s.root = ref }
|
||||
|
||||
// --- InternalNode allocation ---
|
||||
|
||||
func (s *NodeStore) allocInternal() uint32 {
|
||||
if n := len(s.freeInternals); n > 0 {
|
||||
idx := s.freeInternals[n-1]
|
||||
s.freeInternals = s.freeInternals[:n-1]
|
||||
*s.getInternal(idx) = InternalNode{}
|
||||
return idx
|
||||
}
|
||||
idx := s.internalCount
|
||||
chunkIdx := idx / storeChunkSize
|
||||
if uint32(len(s.internalChunks)) <= chunkIdx {
|
||||
s.internalChunks = append(s.internalChunks, new([storeChunkSize]InternalNode))
|
||||
}
|
||||
s.internalCount++
|
||||
if s.internalCount > indexMask {
|
||||
panic("internal node pool overflow")
|
||||
}
|
||||
return idx
|
||||
}
|
||||
|
||||
func (s *NodeStore) getInternal(idx uint32) *InternalNode {
|
||||
return &s.internalChunks[idx/storeChunkSize][idx%storeChunkSize]
|
||||
}
|
||||
|
||||
// newInternalRef allocates an InternalNode and returns its NodeRef.
|
||||
func (s *NodeStore) newInternalRef(depth int) NodeRef {
|
||||
if depth > 248 {
|
||||
panic("node depth exceeds maximum binary trie depth")
|
||||
}
|
||||
idx := s.allocInternal()
|
||||
n := s.getInternal(idx)
|
||||
n.depth = uint8(depth)
|
||||
n.mustRecompute = true
|
||||
return MakeRef(KindInternal, idx)
|
||||
}
|
||||
|
||||
// --- StemNode allocation ---
|
||||
|
||||
func (s *NodeStore) allocStem() uint32 {
|
||||
if n := len(s.freeStems); n > 0 {
|
||||
idx := s.freeStems[n-1]
|
||||
s.freeStems = s.freeStems[:n-1]
|
||||
*s.getStem(idx) = StemNode{}
|
||||
return idx
|
||||
}
|
||||
idx := s.stemCount
|
||||
chunkIdx := idx / storeChunkSize
|
||||
if uint32(len(s.stemChunks)) <= chunkIdx {
|
||||
s.stemChunks = append(s.stemChunks, new([storeChunkSize]StemNode))
|
||||
}
|
||||
s.stemCount++
|
||||
if s.stemCount > indexMask {
|
||||
panic("internal node pool overflow")
|
||||
}
|
||||
return idx
|
||||
}
|
||||
|
||||
func (s *NodeStore) getStem(idx uint32) *StemNode {
|
||||
return &s.stemChunks[idx/storeChunkSize][idx%storeChunkSize]
|
||||
}
|
||||
|
||||
// newStemRef allocates a StemNode with the given stem/depth and returns its NodeRef.
|
||||
func (s *NodeStore) newStemRef(stem []byte, depth int) NodeRef {
|
||||
if depth > 248 {
|
||||
panic("node depth exceeds maximum binary trie depth")
|
||||
}
|
||||
idx := s.allocStem()
|
||||
sn := s.getStem(idx)
|
||||
copy(sn.Stem[:], stem[:StemSize])
|
||||
sn.depth = uint8(depth)
|
||||
sn.mustRecompute = true
|
||||
return MakeRef(KindStem, idx)
|
||||
}
|
||||
|
||||
// --- HashedNode allocation ---
|
||||
|
||||
func (s *NodeStore) allocHashed() uint32 {
|
||||
if n := len(s.freeHashed); n > 0 {
|
||||
idx := s.freeHashed[n-1]
|
||||
s.freeHashed = s.freeHashed[:n-1]
|
||||
*s.getHashed(idx) = HashedNode{}
|
||||
return idx
|
||||
}
|
||||
idx := s.hashedCount
|
||||
chunkIdx := idx / storeChunkSize
|
||||
if uint32(len(s.hashedChunks)) <= chunkIdx {
|
||||
s.hashedChunks = append(s.hashedChunks, new([storeChunkSize]HashedNode))
|
||||
}
|
||||
s.hashedCount++
|
||||
if s.hashedCount > indexMask {
|
||||
panic("internal node pool overflow")
|
||||
}
|
||||
return idx
|
||||
}
|
||||
|
||||
func (s *NodeStore) getHashed(idx uint32) *HashedNode {
|
||||
return &s.hashedChunks[idx/storeChunkSize][idx%storeChunkSize]
|
||||
}
|
||||
|
||||
func (s *NodeStore) freeHashedNode(idx uint32) {
|
||||
s.freeHashed = append(s.freeHashed, idx)
|
||||
}
|
||||
|
||||
// newHashedRef allocates a HashedNode and returns its NodeRef.
|
||||
func (s *NodeStore) newHashedRef(hash common.Hash) NodeRef {
|
||||
idx := s.allocHashed()
|
||||
*s.getHashed(idx) = HashedNode{hash: hash}
|
||||
return MakeRef(KindHashed, idx)
|
||||
}
|
||||
|
||||
// Copy creates a deep copy of the NodeStore and all its nodes.
|
||||
func (s *NodeStore) Copy() *NodeStore {
|
||||
ns := &NodeStore{
|
||||
root: s.root,
|
||||
internalCount: s.internalCount,
|
||||
stemCount: s.stemCount,
|
||||
hashedCount: s.hashedCount,
|
||||
}
|
||||
|
||||
// Deep copy internal chunks
|
||||
ns.internalChunks = make([]*[storeChunkSize]InternalNode, len(s.internalChunks))
|
||||
for i, chunk := range s.internalChunks {
|
||||
cp := *chunk
|
||||
ns.internalChunks[i] = &cp
|
||||
}
|
||||
|
||||
// Deep copy stem chunks (need to deep copy valueData)
|
||||
ns.stemChunks = make([]*[storeChunkSize]StemNode, len(s.stemChunks))
|
||||
for i, chunk := range s.stemChunks {
|
||||
cp := *chunk
|
||||
ns.stemChunks[i] = &cp
|
||||
}
|
||||
// Deep copy pointer fields for each active stem
|
||||
for i := uint32(0); i < s.stemCount; i++ {
|
||||
src := s.getStem(i)
|
||||
dst := ns.getStem(i)
|
||||
if len(src.valueData) > 0 {
|
||||
dst.valueData = make([]byte, len(src.valueData))
|
||||
copy(dst.valueData, src.valueData)
|
||||
}
|
||||
}
|
||||
|
||||
// Deep copy hashed chunks
|
||||
ns.hashedChunks = make([]*[storeChunkSize]HashedNode, len(s.hashedChunks))
|
||||
for i, chunk := range s.hashedChunks {
|
||||
cp := *chunk
|
||||
ns.hashedChunks[i] = &cp
|
||||
}
|
||||
|
||||
// Copy free lists
|
||||
if len(s.freeInternals) > 0 {
|
||||
ns.freeInternals = make([]uint32, len(s.freeInternals))
|
||||
copy(ns.freeInternals, s.freeInternals)
|
||||
}
|
||||
if len(s.freeStems) > 0 {
|
||||
ns.freeStems = make([]uint32, len(s.freeStems))
|
||||
copy(ns.freeStems, s.freeStems)
|
||||
}
|
||||
if len(s.freeHashed) > 0 {
|
||||
ns.freeHashed = make([]uint32, len(s.freeHashed))
|
||||
copy(ns.freeHashed, s.freeHashed)
|
||||
}
|
||||
|
||||
return ns
|
||||
}
|
||||
|
|
@ -17,119 +17,135 @@
|
|||
package bintrie
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"math/bits"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
)
|
||||
|
||||
// StemNode represents a group of `NodeWith` values sharing the same stem.
|
||||
// StemNode represents a group of `StemNodeWidth` values sharing the same stem.
|
||||
// It uses a packed representation: bitmap indicates which of the 256 positions
|
||||
// have values, and valueData stores the values contiguously in bitmap order.
|
||||
type StemNode struct {
|
||||
Stem []byte // Stem path to get to StemNodeWidth values
|
||||
Values [][]byte // All values, indexed by the last byte of the key.
|
||||
depth int // Depth of the node
|
||||
Stem [StemSize]byte // Stem path to get to StemNodeWidth values
|
||||
bitmap [StemBitmapSize]byte // bitmap indicating which positions have values
|
||||
valueData []byte // packed value data (count * HashSize bytes)
|
||||
count uint16 // number of values present
|
||||
depth uint8 // Depth of the node
|
||||
shared bool // true if valueData is shared with serialized input
|
||||
|
||||
mustRecompute bool // true if the hash needs to be recomputed
|
||||
dirty bool // true if the node's on-disk blob is stale (needs flush)
|
||||
hash common.Hash // cached hash when mustRecompute == false
|
||||
}
|
||||
|
||||
// Get retrieves the value for the given key.
|
||||
func (bt *StemNode) Get(key []byte, _ NodeResolverFn) ([]byte, error) {
|
||||
if !bytes.Equal(bt.Stem, key[:StemSize]) {
|
||||
return nil, nil
|
||||
// posInData returns the index within valueData for the given suffix.
|
||||
// Returns -1 if the suffix is not present.
|
||||
func (sn *StemNode) posInData(suffix byte) int {
|
||||
idx := int(suffix)
|
||||
if sn.bitmap[idx/8]>>(7-(idx%8))&1 == 0 {
|
||||
return -1
|
||||
}
|
||||
return bt.Values[key[StemSize]], nil
|
||||
// Count the bits set before this position to determine the offset
|
||||
pos := 0
|
||||
byteIdx := idx / 8
|
||||
for i := 0; i < byteIdx; i++ {
|
||||
pos += bits.OnesCount8(sn.bitmap[i])
|
||||
}
|
||||
// Count bits in the partial byte
|
||||
mask := byte(0xFF) << (8 - (idx % 8))
|
||||
pos += bits.OnesCount8(sn.bitmap[byteIdx] & mask)
|
||||
return pos
|
||||
}
|
||||
|
||||
// Insert inserts a new key-value pair into the node.
|
||||
func (bt *StemNode) Insert(key []byte, value []byte, _ NodeResolverFn, depth int) (BinaryNode, error) {
|
||||
if !bytes.Equal(bt.Stem, key[:StemSize]) {
|
||||
bitStem := bt.Stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1
|
||||
// getValue returns the value at the given suffix, or nil if not present.
|
||||
func (sn *StemNode) getValue(suffix byte) []byte {
|
||||
pos := sn.posInData(suffix)
|
||||
if pos < 0 {
|
||||
return nil
|
||||
}
|
||||
start := pos * HashSize
|
||||
return sn.valueData[start : start+HashSize]
|
||||
}
|
||||
|
||||
n := &InternalNode{depth: bt.depth, mustRecompute: true, dirty: true}
|
||||
bt.depth++
|
||||
// bt is re-parented under n and sits at a new path — rewrite its blob.
|
||||
bt.mustRecompute = true
|
||||
bt.dirty = true
|
||||
var child, other *BinaryNode
|
||||
if bitStem == 0 {
|
||||
n.left = bt
|
||||
child = &n.left
|
||||
other = &n.right
|
||||
} else {
|
||||
n.right = bt
|
||||
child = &n.right
|
||||
other = &n.left
|
||||
// hasValue returns true if the given suffix has a value.
|
||||
func (sn *StemNode) hasValue(suffix byte) bool {
|
||||
idx := int(suffix)
|
||||
return sn.bitmap[idx/8]>>(7-(idx%8))&1 == 1
|
||||
}
|
||||
|
||||
// allValues returns all 256 values (nil for absent positions).
|
||||
func (sn *StemNode) allValues() [][]byte {
|
||||
values := make([][]byte, StemNodeWidth)
|
||||
dataIdx := 0
|
||||
for i := range StemNodeWidth {
|
||||
if sn.bitmap[i/8]>>(7-(i%8))&1 == 1 {
|
||||
values[i] = sn.valueData[dataIdx*HashSize : (dataIdx+1)*HashSize]
|
||||
dataIdx++
|
||||
}
|
||||
|
||||
bitKey := key[n.depth/8] >> (7 - (n.depth % 8)) & 1
|
||||
if bitKey == bitStem {
|
||||
var err error
|
||||
*child, err = (*child).Insert(key, value, nil, depth+1)
|
||||
if err != nil {
|
||||
return n, fmt.Errorf("insert error: %w", err)
|
||||
}
|
||||
*other = Empty{}
|
||||
} else {
|
||||
var values [StemNodeWidth][]byte
|
||||
values[key[StemSize]] = value
|
||||
*other = &StemNode{
|
||||
Stem: slices.Clone(key[:StemSize]),
|
||||
Values: values[:],
|
||||
depth: depth + 1,
|
||||
mustRecompute: true,
|
||||
dirty: true,
|
||||
}
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
if len(value) != HashSize {
|
||||
return bt, errors.New("invalid insertion: value length")
|
||||
}
|
||||
bt.Values[key[StemSize]] = value
|
||||
bt.mustRecompute = true
|
||||
bt.dirty = true
|
||||
return bt, nil
|
||||
return values
|
||||
}
|
||||
|
||||
// Copy creates a deep copy of the node.
|
||||
func (bt *StemNode) Copy() BinaryNode {
|
||||
var values [StemNodeWidth][]byte
|
||||
for i, v := range bt.Values {
|
||||
values[i] = slices.Clone(v)
|
||||
}
|
||||
return &StemNode{
|
||||
Stem: slices.Clone(bt.Stem),
|
||||
Values: values[:],
|
||||
depth: bt.depth,
|
||||
hash: bt.hash,
|
||||
mustRecompute: bt.mustRecompute,
|
||||
dirty: bt.dirty,
|
||||
// ensureWritable makes the valueData writable (copies if shared with serialized input).
|
||||
func (sn *StemNode) ensureWritable() {
|
||||
if sn.shared || cap(sn.valueData)-len(sn.valueData) < HashSize {
|
||||
newData := make([]byte, len(sn.valueData), len(sn.valueData)+HashSize*4)
|
||||
copy(newData, sn.valueData)
|
||||
sn.valueData = newData
|
||||
sn.shared = false
|
||||
}
|
||||
}
|
||||
|
||||
// GetHeight returns the height of the node.
|
||||
func (bt *StemNode) GetHeight() int {
|
||||
return 1
|
||||
// setValue sets or inserts a value at the given suffix.
|
||||
func (sn *StemNode) setValue(suffix byte, value []byte) {
|
||||
idx := int(suffix)
|
||||
pos := sn.posInData(suffix)
|
||||
if pos >= 0 {
|
||||
// Overwrite existing value
|
||||
copy(sn.valueData[pos*HashSize:], value[:HashSize])
|
||||
return
|
||||
}
|
||||
// New value: insert into bitmap and valueData at the correct position.
|
||||
sn.bitmap[idx/8] |= 1 << (7 - (idx % 8))
|
||||
sn.count++
|
||||
|
||||
// Find the correct position in valueData (count bits before this position).
|
||||
insertPos := 0
|
||||
byteIdx := idx / 8
|
||||
for i := 0; i < byteIdx; i++ {
|
||||
insertPos += bits.OnesCount8(sn.bitmap[i])
|
||||
}
|
||||
mask := byte(0xFF) << (8 - (idx % 8))
|
||||
insertPos += bits.OnesCount8(sn.bitmap[byteIdx] & mask)
|
||||
|
||||
// Insert value at the correct position in valueData.
|
||||
insertOffset := insertPos * HashSize
|
||||
// Grow the slice
|
||||
sn.valueData = append(sn.valueData, make([]byte, HashSize)...)
|
||||
// Shift data after insertion point
|
||||
copy(sn.valueData[insertOffset+HashSize:], sn.valueData[insertOffset:len(sn.valueData)-HashSize])
|
||||
// Copy the new value
|
||||
copy(sn.valueData[insertOffset:], value[:HashSize])
|
||||
}
|
||||
|
||||
// Hash returns the hash of the node.
|
||||
func (bt *StemNode) Hash() common.Hash {
|
||||
if !bt.mustRecompute {
|
||||
return bt.hash
|
||||
func (sn *StemNode) Hash() common.Hash {
|
||||
if !sn.mustRecompute {
|
||||
return sn.hash
|
||||
}
|
||||
|
||||
var data [StemNodeWidth]common.Hash
|
||||
h := newSha256()
|
||||
defer returnSha256(h)
|
||||
for i, v := range bt.Values {
|
||||
if v != nil {
|
||||
|
||||
// Hash each present value
|
||||
dataIdx := 0
|
||||
for i := range StemNodeWidth {
|
||||
if sn.bitmap[i/8]>>(7-(i%8))&1 == 1 {
|
||||
v := sn.valueData[dataIdx*HashSize : (dataIdx+1)*HashSize]
|
||||
h.Reset()
|
||||
h.Write(v)
|
||||
h.Sum(data[i][:0])
|
||||
dataIdx++
|
||||
}
|
||||
}
|
||||
h.Reset()
|
||||
|
|
@ -150,103 +166,18 @@ func (bt *StemNode) Hash() common.Hash {
|
|||
}
|
||||
|
||||
h.Reset()
|
||||
h.Write(bt.Stem)
|
||||
h.Write(sn.Stem[:])
|
||||
h.Write([]byte{0})
|
||||
h.Write(data[0][:])
|
||||
bt.hash = common.BytesToHash(h.Sum(nil))
|
||||
bt.mustRecompute = false
|
||||
return bt.hash
|
||||
}
|
||||
|
||||
// CollectNodes flushes the stem via the collector when dirty; clean stems
|
||||
// are skipped.
|
||||
func (bt *StemNode) CollectNodes(path []byte, flush NodeFlushFn) error {
|
||||
if !bt.dirty {
|
||||
return nil
|
||||
}
|
||||
flush(path, bt)
|
||||
bt.dirty = false
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetValuesAtStem retrieves the group of values located at the given stem key.
|
||||
func (bt *StemNode) GetValuesAtStem(stem []byte, _ NodeResolverFn) ([][]byte, error) {
|
||||
if !bytes.Equal(bt.Stem, stem) {
|
||||
return nil, nil
|
||||
}
|
||||
return bt.Values[:], nil
|
||||
}
|
||||
|
||||
// InsertValuesAtStem inserts a full value group at the given stem in the internal node.
|
||||
// Already-existing values will be overwritten.
|
||||
func (bt *StemNode) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolverFn, depth int) (BinaryNode, error) {
|
||||
if !bytes.Equal(bt.Stem, key[:StemSize]) {
|
||||
bitStem := bt.Stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1
|
||||
|
||||
n := &InternalNode{depth: bt.depth, mustRecompute: true, dirty: true}
|
||||
bt.depth++
|
||||
// bt is re-parented under n and sits at a new path — rewrite its blob.
|
||||
bt.mustRecompute = true
|
||||
bt.dirty = true
|
||||
var child, other *BinaryNode
|
||||
if bitStem == 0 {
|
||||
n.left = bt
|
||||
child = &n.left
|
||||
other = &n.right
|
||||
} else {
|
||||
n.right = bt
|
||||
child = &n.right
|
||||
other = &n.left
|
||||
}
|
||||
|
||||
bitKey := key[n.depth/8] >> (7 - (n.depth % 8)) & 1
|
||||
if bitKey == bitStem {
|
||||
var err error
|
||||
*child, err = (*child).InsertValuesAtStem(key, values, nil, depth+1)
|
||||
if err != nil {
|
||||
return n, fmt.Errorf("insert error: %w", err)
|
||||
}
|
||||
*other = Empty{}
|
||||
} else {
|
||||
*other = &StemNode{
|
||||
Stem: slices.Clone(key[:StemSize]),
|
||||
Values: values,
|
||||
depth: n.depth + 1,
|
||||
mustRecompute: true,
|
||||
dirty: true,
|
||||
}
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// same stem, just merge the two value lists
|
||||
for i, v := range values {
|
||||
if v != nil {
|
||||
bt.Values[i] = v
|
||||
bt.mustRecompute = true
|
||||
bt.dirty = true
|
||||
}
|
||||
}
|
||||
return bt, nil
|
||||
}
|
||||
|
||||
func (bt *StemNode) toDot(parent, path string) string {
|
||||
me := fmt.Sprintf("stem%s", path)
|
||||
ret := fmt.Sprintf("%s [label=\"stem=%x c=%x\"]\n", me, bt.Stem, bt.Hash())
|
||||
ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me)
|
||||
for i, v := range bt.Values {
|
||||
if v != nil {
|
||||
ret = fmt.Sprintf("%s%s%x [label=\"%x\"]\n", ret, me, i, v)
|
||||
ret = fmt.Sprintf("%s%s -> %s%x\n", ret, me, me, i)
|
||||
}
|
||||
}
|
||||
return ret
|
||||
sn.hash = common.BytesToHash(h.Sum(nil))
|
||||
sn.mustRecompute = false
|
||||
return sn.hash
|
||||
}
|
||||
|
||||
// Key returns the full key for the given index.
|
||||
func (bt *StemNode) Key(i int) []byte {
|
||||
func (sn *StemNode) Key(i int) []byte {
|
||||
var ret [HashSize]byte
|
||||
copy(ret[:], bt.Stem)
|
||||
copy(ret[:], sn.Stem[:])
|
||||
ret[StemSize] = byte(i)
|
||||
return ret[:]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,165 +23,99 @@ import (
|
|||
"github.com/ethereum/go-ethereum/common"
|
||||
)
|
||||
|
||||
// TestStemNodeGet tests the Get method for matching stem, non-matching stem,
|
||||
// and nil-value suffix scenarios.
|
||||
func TestStemNodeGet(t *testing.T) {
|
||||
stem := make([]byte, StemSize)
|
||||
stem[0] = 0xAB
|
||||
var values [StemNodeWidth][]byte
|
||||
values[5] = common.HexToHash("0xdeadbeef").Bytes()
|
||||
|
||||
node := &StemNode{Stem: stem, Values: values[:], depth: 0}
|
||||
|
||||
// Matching stem, populated suffix → returns value.
|
||||
key := make([]byte, HashSize)
|
||||
copy(key[:StemSize], stem)
|
||||
key[StemSize] = 5
|
||||
got, err := node.Get(key, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Get error: %v", err)
|
||||
}
|
||||
if !bytes.Equal(got, values[5]) {
|
||||
t.Fatalf("Get = %x, want %x", got, values[5])
|
||||
}
|
||||
|
||||
// Matching stem, empty suffix → returns nil (slot not set).
|
||||
key[StemSize] = 99
|
||||
got, err = node.Get(key, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Get error: %v", err)
|
||||
}
|
||||
if got != nil {
|
||||
t.Fatalf("Get(empty suffix) = %x, want nil", got)
|
||||
}
|
||||
|
||||
// Non-matching stem → returns nil, nil.
|
||||
otherKey := make([]byte, HashSize)
|
||||
otherKey[0] = 0xFF
|
||||
got, err = node.Get(otherKey, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Get error: %v", err)
|
||||
}
|
||||
if got != nil {
|
||||
t.Fatalf("Get(wrong stem) = %x, want nil", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStemNodeInsertSameStem tests inserting values with the same stem
|
||||
// TestStemNodeInsertSameStem tests inserting values with the same stem via NodeStore.
|
||||
func TestStemNodeInsertSameStem(t *testing.T) {
|
||||
s := NewNodeStore()
|
||||
|
||||
stem := make([]byte, 31)
|
||||
for i := range stem {
|
||||
stem[i] = byte(i)
|
||||
}
|
||||
|
||||
var values [256][]byte
|
||||
values[0] = common.HexToHash("0x0101").Bytes()
|
||||
|
||||
node := &StemNode{
|
||||
Stem: stem,
|
||||
Values: values[:],
|
||||
depth: 0,
|
||||
// Insert first value
|
||||
key1 := make([]byte, 32)
|
||||
copy(key1[:31], stem)
|
||||
key1[31] = 0
|
||||
value1 := common.HexToHash("0x0101").Bytes()
|
||||
if err := s.Insert(key1, value1, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Insert another value with the same stem but different last byte
|
||||
key := make([]byte, 32)
|
||||
copy(key[:31], stem)
|
||||
key[31] = 10
|
||||
value := common.HexToHash("0x0202").Bytes()
|
||||
|
||||
newNode, err := node.Insert(key, value, nil, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insert: %v", err)
|
||||
key2 := make([]byte, 32)
|
||||
copy(key2[:31], stem)
|
||||
key2[31] = 10
|
||||
value2 := common.HexToHash("0x0202").Bytes()
|
||||
if err := s.Insert(key2, value2, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Should still be a StemNode
|
||||
stemNode, ok := newNode.(*StemNode)
|
||||
if !ok {
|
||||
t.Fatalf("Expected StemNode, got %T", newNode)
|
||||
// Root should still be a StemNode
|
||||
if s.Root().Kind() != KindStem {
|
||||
t.Fatalf("Expected KindStem root, got kind %d", s.Root().Kind())
|
||||
}
|
||||
|
||||
// Check that both values are present
|
||||
if !bytes.Equal(stemNode.Values[0], values[0]) {
|
||||
v1, _ := s.Get(key1, nil)
|
||||
if !bytes.Equal(v1, value1) {
|
||||
t.Errorf("Value at index 0 mismatch")
|
||||
}
|
||||
if !bytes.Equal(stemNode.Values[10], value) {
|
||||
v2, _ := s.Get(key2, nil)
|
||||
if !bytes.Equal(v2, value2) {
|
||||
t.Errorf("Value at index 10 mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStemNodeInsertDifferentStem tests inserting values with different stems
|
||||
// TestStemNodeInsertDifferentStem tests inserting values with different stems via NodeStore.
|
||||
func TestStemNodeInsertDifferentStem(t *testing.T) {
|
||||
stem1 := make([]byte, 31)
|
||||
for i := range stem1 {
|
||||
stem1[i] = 0x00
|
||||
}
|
||||
s := NewNodeStore()
|
||||
|
||||
var values [256][]byte
|
||||
values[0] = common.HexToHash("0x0101").Bytes()
|
||||
|
||||
node := &StemNode{
|
||||
Stem: stem1,
|
||||
Values: values[:],
|
||||
depth: 0,
|
||||
// Insert first value with stem of all zeros
|
||||
key1 := make([]byte, 32)
|
||||
key1[31] = 0
|
||||
value1 := common.HexToHash("0x0101").Bytes()
|
||||
if err := s.Insert(key1, value1, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Insert with a different stem (first bit different)
|
||||
key := make([]byte, 32)
|
||||
key[0] = 0x80 // First bit is 1 instead of 0
|
||||
value := common.HexToHash("0x0202").Bytes()
|
||||
|
||||
newNode, err := node.Insert(key, value, nil, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insert: %v", err)
|
||||
key2 := make([]byte, 32)
|
||||
key2[0] = 0x80 // First bit is 1 instead of 0
|
||||
value2 := common.HexToHash("0x0202").Bytes()
|
||||
if err := s.Insert(key2, value2, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Should now be an InternalNode
|
||||
internalNode, ok := newNode.(*InternalNode)
|
||||
if !ok {
|
||||
t.Fatalf("Expected InternalNode, got %T", newNode)
|
||||
if s.Root().Kind() != KindInternal {
|
||||
t.Fatalf("Expected KindInternal root, got kind %d", s.Root().Kind())
|
||||
}
|
||||
|
||||
// Check depth
|
||||
if internalNode.depth != 0 {
|
||||
t.Errorf("Expected depth 0, got %d", internalNode.depth)
|
||||
rootNode := s.getInternal(s.Root().Index())
|
||||
if rootNode.depth != 0 {
|
||||
t.Errorf("Expected depth 0, got %d", rootNode.depth)
|
||||
}
|
||||
|
||||
// Original stem should be on the left (bit 0)
|
||||
leftStem, ok := internalNode.left.(*StemNode)
|
||||
if !ok {
|
||||
t.Fatalf("Expected left child to be StemNode, got %T", internalNode.left)
|
||||
// Verify both values are retrievable
|
||||
v1, _ := s.Get(key1, nil)
|
||||
if !bytes.Equal(v1, value1) {
|
||||
t.Error("Value 1 mismatch")
|
||||
}
|
||||
if !bytes.Equal(leftStem.Stem, stem1) {
|
||||
t.Errorf("Left stem mismatch")
|
||||
}
|
||||
|
||||
// New stem should be on the right (bit 1)
|
||||
rightStem, ok := internalNode.right.(*StemNode)
|
||||
if !ok {
|
||||
t.Fatalf("Expected right child to be StemNode, got %T", internalNode.right)
|
||||
}
|
||||
if !bytes.Equal(rightStem.Stem, key[:31]) {
|
||||
t.Errorf("Right stem mismatch")
|
||||
v2, _ := s.Get(key2, nil)
|
||||
if !bytes.Equal(v2, value2) {
|
||||
t.Error("Value 2 mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStemNodeInsertInvalidValueLength tests inserting value with invalid length
|
||||
// TestStemNodeInsertInvalidValueLength tests inserting value with invalid length via NodeStore.
|
||||
func TestStemNodeInsertInvalidValueLength(t *testing.T) {
|
||||
stem := make([]byte, 31)
|
||||
var values [256][]byte
|
||||
s := NewNodeStore()
|
||||
|
||||
node := &StemNode{
|
||||
Stem: stem,
|
||||
Values: values[:],
|
||||
depth: 0,
|
||||
}
|
||||
|
||||
// Try to insert value with wrong length
|
||||
key := make([]byte, 32)
|
||||
copy(key[:31], stem)
|
||||
invalidValue := []byte{1, 2, 3} // Not 32 bytes
|
||||
|
||||
_, err := node.Insert(key, invalidValue, nil, 0)
|
||||
err := s.Insert(key, invalidValue, nil)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for invalid value length")
|
||||
}
|
||||
|
|
@ -191,221 +125,209 @@ func TestStemNodeInsertInvalidValueLength(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// TestStemNodeCopy tests the Copy method
|
||||
// TestStemNodeCopy tests the Copy method via NodeStore.
|
||||
func TestStemNodeCopy(t *testing.T) {
|
||||
stem := make([]byte, 31)
|
||||
for i := range stem {
|
||||
stem[i] = byte(i)
|
||||
s := NewNodeStore()
|
||||
|
||||
key1 := make([]byte, 32)
|
||||
for i := range 31 {
|
||||
key1[i] = byte(i)
|
||||
}
|
||||
key1[31] = 0
|
||||
value1 := common.HexToHash("0x0101").Bytes()
|
||||
|
||||
key2 := make([]byte, 32)
|
||||
copy(key2[:31], key1[:31])
|
||||
key2[31] = 255
|
||||
value2 := common.HexToHash("0x0202").Bytes()
|
||||
|
||||
if err := s.Insert(key1, value1, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := s.Insert(key2, value2, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var values [256][]byte
|
||||
values[0] = common.HexToHash("0x0101").Bytes()
|
||||
values[255] = common.HexToHash("0x0202").Bytes()
|
||||
ns := s.Copy()
|
||||
|
||||
node := &StemNode{
|
||||
Stem: stem,
|
||||
Values: values[:],
|
||||
depth: 10,
|
||||
}
|
||||
|
||||
// Create a copy
|
||||
copied := node.Copy()
|
||||
copiedStem, ok := copied.(*StemNode)
|
||||
if !ok {
|
||||
t.Fatalf("Expected StemNode, got %T", copied)
|
||||
}
|
||||
|
||||
// Check that values are equal but not the same slice
|
||||
if !bytes.Equal(copiedStem.Stem, node.Stem) {
|
||||
t.Errorf("Stem mismatch after copy")
|
||||
}
|
||||
if &copiedStem.Stem[0] == &node.Stem[0] {
|
||||
t.Error("Stem slice not properly cloned")
|
||||
}
|
||||
|
||||
// Check values
|
||||
if !bytes.Equal(copiedStem.Values[0], node.Values[0]) {
|
||||
// Check that values are equal
|
||||
v1, _ := ns.Get(key1, nil)
|
||||
if !bytes.Equal(v1, value1) {
|
||||
t.Errorf("Value at index 0 mismatch after copy")
|
||||
}
|
||||
if !bytes.Equal(copiedStem.Values[255], node.Values[255]) {
|
||||
v2, _ := ns.Get(key2, nil)
|
||||
if !bytes.Equal(v2, value2) {
|
||||
t.Errorf("Value at index 255 mismatch after copy")
|
||||
}
|
||||
|
||||
// Check that value slices are cloned
|
||||
if copiedStem.Values[0] != nil && &copiedStem.Values[0][0] == &node.Values[0][0] {
|
||||
t.Error("Value slice not properly cloned")
|
||||
}
|
||||
|
||||
// Check depth
|
||||
if copiedStem.depth != node.depth {
|
||||
t.Errorf("Depth mismatch: expected %d, got %d", node.depth, copiedStem.depth)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStemNodeHash tests the Hash method
|
||||
// TestStemNodeHash tests the Hash method.
|
||||
func TestStemNodeHash(t *testing.T) {
|
||||
stem := make([]byte, 31)
|
||||
var values [256][]byte
|
||||
values[0] = common.HexToHash("0x0101").Bytes()
|
||||
s := NewNodeStore()
|
||||
|
||||
node := &StemNode{
|
||||
Stem: stem,
|
||||
Values: values[:],
|
||||
depth: 0,
|
||||
key := make([]byte, 32)
|
||||
key[31] = 0
|
||||
value := common.HexToHash("0x0101").Bytes()
|
||||
if err := s.Insert(key, value, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
hash1 := node.Hash()
|
||||
hash1 := s.ComputeHash(s.Root())
|
||||
|
||||
// Hash should be deterministic
|
||||
hash2 := node.Hash()
|
||||
hash2 := s.ComputeHash(s.Root())
|
||||
if hash1 != hash2 {
|
||||
t.Errorf("Hash not deterministic: %x != %x", hash1, hash2)
|
||||
}
|
||||
|
||||
// Changing a value should change the hash
|
||||
node.Values[1] = common.HexToHash("0x0202").Bytes()
|
||||
node.mustRecompute = true
|
||||
hash3 := node.Hash()
|
||||
key2 := make([]byte, 32)
|
||||
key2[31] = 1
|
||||
value2 := common.HexToHash("0x0202").Bytes()
|
||||
if err := s.Insert(key2, value2, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
hash3 := s.ComputeHash(s.Root())
|
||||
if hash1 == hash3 {
|
||||
t.Error("Hash didn't change after modifying values")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStemNodeGetValuesAtStem tests GetValuesAtStem method
|
||||
// TestStemNodeGetValuesAtStem tests GetValuesAtStem method via NodeStore.
|
||||
func TestStemNodeGetValuesAtStem(t *testing.T) {
|
||||
s := NewNodeStore()
|
||||
|
||||
stem := make([]byte, 31)
|
||||
for i := range stem {
|
||||
stem[i] = byte(i)
|
||||
}
|
||||
|
||||
var values [256][]byte
|
||||
values := make([][]byte, 256)
|
||||
values[0] = common.HexToHash("0x0101").Bytes()
|
||||
values[10] = common.HexToHash("0x0202").Bytes()
|
||||
values[255] = common.HexToHash("0x0303").Bytes()
|
||||
|
||||
node := &StemNode{
|
||||
Stem: stem,
|
||||
Values: values[:],
|
||||
depth: 0,
|
||||
if err := s.InsertValuesAtStem(stem, values, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// GetValuesAtStem with matching stem
|
||||
retrievedValues, err := node.GetValuesAtStem(stem, nil)
|
||||
retrievedValues, err := s.GetValuesAtStem(stem, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get values: %v", err)
|
||||
}
|
||||
|
||||
// Check that all values match
|
||||
for i := range 256 {
|
||||
if !bytes.Equal(retrievedValues[i], values[i]) {
|
||||
t.Errorf("Value mismatch at index %d", i)
|
||||
}
|
||||
if !bytes.Equal(retrievedValues[0], values[0]) {
|
||||
t.Error("Value at index 0 mismatch")
|
||||
}
|
||||
if !bytes.Equal(retrievedValues[10], values[10]) {
|
||||
t.Error("Value at index 10 mismatch")
|
||||
}
|
||||
if !bytes.Equal(retrievedValues[255], values[255]) {
|
||||
t.Error("Value at index 255 mismatch")
|
||||
}
|
||||
|
||||
// GetValuesAtStem with different stem should return nil
|
||||
// GetValuesAtStem with different stem should return nil values
|
||||
differentStem := make([]byte, 31)
|
||||
differentStem[0] = 0xFF
|
||||
|
||||
shouldBeNil, err := node.GetValuesAtStem(differentStem, nil)
|
||||
shouldBeEmpty, err := s.GetValuesAtStem(differentStem, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get values with different stem: %v", err)
|
||||
}
|
||||
|
||||
if shouldBeNil != nil {
|
||||
t.Error("Expected nil for different stem, got non-nil")
|
||||
allNil := true
|
||||
for _, v := range shouldBeEmpty {
|
||||
if v != nil {
|
||||
allNil = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if !allNil {
|
||||
t.Error("Expected all nil values for different stem")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStemNodeInsertValuesAtStem tests InsertValuesAtStem method
|
||||
// TestStemNodeInsertValuesAtStem tests InsertValuesAtStem method via NodeStore.
|
||||
func TestStemNodeInsertValuesAtStem(t *testing.T) {
|
||||
s := NewNodeStore()
|
||||
|
||||
stem := make([]byte, 31)
|
||||
var values [256][]byte
|
||||
values := make([][]byte, 256)
|
||||
values[0] = common.HexToHash("0x0101").Bytes()
|
||||
|
||||
node := &StemNode{
|
||||
Stem: stem,
|
||||
Values: values[:],
|
||||
depth: 0,
|
||||
if err := s.InsertValuesAtStem(stem, values, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Insert new values at the same stem
|
||||
var newValues [256][]byte
|
||||
newValues := make([][]byte, 256)
|
||||
newValues[1] = common.HexToHash("0x0202").Bytes()
|
||||
newValues[2] = common.HexToHash("0x0303").Bytes()
|
||||
|
||||
newNode, err := node.InsertValuesAtStem(stem, newValues[:], nil, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insert values: %v", err)
|
||||
}
|
||||
|
||||
stemNode, ok := newNode.(*StemNode)
|
||||
if !ok {
|
||||
t.Fatalf("Expected StemNode, got %T", newNode)
|
||||
if err := s.InsertValuesAtStem(stem, newValues, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Check that all values are present
|
||||
if !bytes.Equal(stemNode.Values[0], values[0]) {
|
||||
retrieved, err := s.GetValuesAtStem(stem, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(retrieved[0], values[0]) {
|
||||
t.Error("Original value at index 0 missing")
|
||||
}
|
||||
if !bytes.Equal(stemNode.Values[1], newValues[1]) {
|
||||
if !bytes.Equal(retrieved[1], newValues[1]) {
|
||||
t.Error("New value at index 1 missing")
|
||||
}
|
||||
if !bytes.Equal(stemNode.Values[2], newValues[2]) {
|
||||
if !bytes.Equal(retrieved[2], newValues[2]) {
|
||||
t.Error("New value at index 2 missing")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStemNodeGetHeight tests GetHeight method
|
||||
// TestStemNodeGetHeight tests GetHeight method via NodeStore.
|
||||
func TestStemNodeGetHeight(t *testing.T) {
|
||||
node := &StemNode{
|
||||
Stem: make([]byte, 31),
|
||||
Values: make([][]byte, 256),
|
||||
depth: 0,
|
||||
s := NewNodeStore()
|
||||
|
||||
key := make([]byte, 32)
|
||||
value := common.HexToHash("0x01").Bytes()
|
||||
if err := s.Insert(key, value, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
height := node.GetHeight()
|
||||
height := s.GetHeight(s.Root())
|
||||
if height != 1 {
|
||||
t.Errorf("Expected height 1, got %d", height)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStemNodeCollectNodes tests CollectNodes method
|
||||
// TestStemNodeCollectNodes tests CollectNodes method via NodeStore.
|
||||
func TestStemNodeCollectNodes(t *testing.T) {
|
||||
s := NewNodeStore()
|
||||
|
||||
stem := make([]byte, 31)
|
||||
var values [256][]byte
|
||||
values := make([][]byte, 256)
|
||||
values[0] = common.HexToHash("0x0101").Bytes()
|
||||
|
||||
node := &StemNode{
|
||||
Stem: stem,
|
||||
Values: values[:],
|
||||
depth: 0,
|
||||
dirty: true,
|
||||
if err := s.InsertValuesAtStem(stem, values, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var collectedPaths [][]byte
|
||||
var collectedNodes []BinaryNode
|
||||
|
||||
flushFn := func(path []byte, n BinaryNode) {
|
||||
// Make a copy of the path
|
||||
flushFn := func(path []byte, hash common.Hash, serialized []byte) {
|
||||
pathCopy := make([]byte, len(path))
|
||||
copy(pathCopy, path)
|
||||
collectedPaths = append(collectedPaths, pathCopy)
|
||||
collectedNodes = append(collectedNodes, n)
|
||||
}
|
||||
|
||||
err := node.CollectNodes([]byte{0, 1, 0}, flushFn)
|
||||
err := s.CollectNodes(s.Root(), []byte{0, 1, 0}, flushFn, MaxGroupDepth)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to collect nodes: %v", err)
|
||||
}
|
||||
|
||||
// Should have collected one node (itself)
|
||||
if len(collectedNodes) != 1 {
|
||||
t.Errorf("Expected 1 collected node, got %d", len(collectedNodes))
|
||||
}
|
||||
|
||||
// Check that the collected node is the same
|
||||
if collectedNodes[0] != node {
|
||||
t.Error("Collected node doesn't match original")
|
||||
if len(collectedPaths) != 1 {
|
||||
t.Errorf("Expected 1 collected node, got %d", len(collectedPaths))
|
||||
}
|
||||
|
||||
// Check the path
|
||||
|
|
@ -413,44 +335,3 @@ func TestStemNodeCollectNodes(t *testing.T) {
|
|||
t.Errorf("Path mismatch: expected [0, 1, 0], got %v", collectedPaths[0])
|
||||
}
|
||||
}
|
||||
|
||||
// TestStemNodeCollectNodesSkipsClean verifies that a clean stem is not
|
||||
// flushed, and that flushing a dirty stem clears its dirty flag so that
|
||||
// a subsequent CollectNodes on the same node is a no-op.
|
||||
func TestStemNodeCollectNodesSkipsClean(t *testing.T) {
|
||||
stem := make([]byte, 31)
|
||||
node := &StemNode{
|
||||
Stem: stem,
|
||||
Values: make([][]byte, 256),
|
||||
depth: 0,
|
||||
}
|
||||
|
||||
var collected []BinaryNode
|
||||
flushFn := func(_ []byte, n BinaryNode) { collected = append(collected, n) }
|
||||
|
||||
if err := node.CollectNodes([]byte{0}, flushFn); err != nil {
|
||||
t.Fatalf("CollectNodes on clean stem: %v", err)
|
||||
}
|
||||
if len(collected) != 0 {
|
||||
t.Fatalf("expected clean stem not to be flushed, got %d", len(collected))
|
||||
}
|
||||
|
||||
node.dirty = true
|
||||
if err := node.CollectNodes([]byte{0}, flushFn); err != nil {
|
||||
t.Fatalf("CollectNodes on dirty stem: %v", err)
|
||||
}
|
||||
if len(collected) != 1 {
|
||||
t.Fatalf("expected dirty stem to be flushed once, got %d", len(collected))
|
||||
}
|
||||
if node.dirty {
|
||||
t.Errorf("stem dirty flag should be cleared after flush")
|
||||
}
|
||||
|
||||
collected = nil
|
||||
if err := node.CollectNodes([]byte{0}, flushFn); err != nil {
|
||||
t.Fatalf("CollectNodes after flush: %v", err)
|
||||
}
|
||||
if len(collected) != 0 {
|
||||
t.Errorf("expected no flush on clean stem, got %d", len(collected))
|
||||
}
|
||||
}
|
||||
|
|
|
|||
477
trie/bintrie/store_commit.go
Normal file
477
trie/bintrie/store_commit.go
Normal file
|
|
@ -0,0 +1,477 @@
|
|||
// Copyright 2025 go-ethereum Authors
|
||||
// This file is part of the go-ethereum library.
|
||||
//
|
||||
// The go-ethereum library is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Lesser General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// The go-ethereum library is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Lesser General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Lesser General Public License
|
||||
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package bintrie
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/bits"
|
||||
"sync"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
)
|
||||
|
||||
// NodeFlushFn is called during commit to flush serialized nodes.
|
||||
type NodeFlushFn func(path []byte, hash common.Hash, serialized []byte)
|
||||
|
||||
// Hash computes and returns the root hash.
|
||||
func (s *NodeStore) Hash() common.Hash {
|
||||
return s.ComputeHash(s.root)
|
||||
}
|
||||
|
||||
// ComputeHash computes the hash of the node referenced by ref.
|
||||
func (s *NodeStore) ComputeHash(ref NodeRef) common.Hash {
|
||||
switch ref.Kind() {
|
||||
case KindInternal:
|
||||
return s.hashInternal(ref.Index())
|
||||
case KindStem:
|
||||
return s.getStem(ref.Index()).Hash()
|
||||
case KindHashed:
|
||||
return s.getHashed(ref.Index()).hash
|
||||
case KindEmpty:
|
||||
return common.Hash{}
|
||||
default:
|
||||
return common.Hash{}
|
||||
}
|
||||
}
|
||||
|
||||
// hashInternal computes the hash of an InternalNode. At shallow depths
|
||||
// (< parallelHashDepth), the left subtree is hashed in a goroutine while
|
||||
// the right subtree is hashed inline. This is safe because left and right
|
||||
// subtrees are disjoint in a well-formed tree — no node appears in both.
|
||||
// ComputeHash must not be called concurrently with mutations to the NodeStore.
|
||||
func (s *NodeStore) hashInternal(idx uint32) common.Hash {
|
||||
node := s.getInternal(idx)
|
||||
if !node.mustRecompute {
|
||||
return node.hash
|
||||
}
|
||||
|
||||
if node.depth < parallelHashDepth {
|
||||
var input [64]byte
|
||||
var lh common.Hash
|
||||
var wg sync.WaitGroup
|
||||
if !node.left.IsEmpty() {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
lh = s.ComputeHash(node.left)
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
if !node.right.IsEmpty() {
|
||||
rh := s.ComputeHash(node.right)
|
||||
copy(input[32:], rh[:])
|
||||
}
|
||||
wg.Wait()
|
||||
copy(input[:32], lh[:])
|
||||
node.hash = sha256Sum256(input[:])
|
||||
node.mustRecompute = false
|
||||
return node.hash
|
||||
}
|
||||
|
||||
var input [64]byte
|
||||
if !node.left.IsEmpty() {
|
||||
lh := s.ComputeHash(node.left)
|
||||
copy(input[:32], lh[:])
|
||||
}
|
||||
if !node.right.IsEmpty() {
|
||||
rh := s.ComputeHash(node.right)
|
||||
copy(input[32:], rh[:])
|
||||
}
|
||||
node.hash = sha256Sum256(input[:])
|
||||
node.mustRecompute = false
|
||||
return node.hash
|
||||
}
|
||||
|
||||
// --- Serialization ---
|
||||
|
||||
// countSubtreeChildren counts non-empty children at the bottom layer of a subtree.
|
||||
func (s *NodeStore) countSubtreeChildren(ref NodeRef, remainingDepth int) int {
|
||||
if remainingDepth == 0 {
|
||||
if ref.IsEmpty() {
|
||||
return 0
|
||||
}
|
||||
return 1
|
||||
}
|
||||
if ref.Kind() == KindInternal {
|
||||
node := s.getInternal(ref.Index())
|
||||
return s.countSubtreeChildren(node.left, remainingDepth-1) + s.countSubtreeChildren(node.right, remainingDepth-1)
|
||||
}
|
||||
if ref.IsEmpty() {
|
||||
return 0
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
// serializeSubtreeDirect writes child hashes directly into the serialized buffer.
|
||||
func (s *NodeStore) serializeSubtreeDirect(ref NodeRef, remainingDepth int, position int, absoluteDepth int, bitmap []byte, out []byte, hashOffset *int) {
|
||||
if remainingDepth == 0 {
|
||||
if ref.IsEmpty() {
|
||||
return
|
||||
}
|
||||
bitmap[position/8] |= 1 << (7 - (position % 8))
|
||||
h := s.ComputeHash(ref)
|
||||
copy(out[*hashOffset:*hashOffset+HashSize], h[:])
|
||||
*hashOffset += HashSize
|
||||
return
|
||||
}
|
||||
|
||||
if ref.Kind() == KindInternal {
|
||||
node := s.getInternal(ref.Index())
|
||||
leftPos := position * 2
|
||||
rightPos := position*2 + 1
|
||||
s.serializeSubtreeDirect(node.left, remainingDepth-1, leftPos, absoluteDepth+1, bitmap, out, hashOffset)
|
||||
s.serializeSubtreeDirect(node.right, remainingDepth-1, rightPos, absoluteDepth+1, bitmap, out, hashOffset)
|
||||
return
|
||||
}
|
||||
|
||||
if ref.IsEmpty() {
|
||||
return
|
||||
}
|
||||
|
||||
// Leaf (StemNode or HashedNode) at a non-zero remaining depth: compute its position.
|
||||
leafPos := position
|
||||
if ref.Kind() == KindStem {
|
||||
sn := s.getStem(ref.Index())
|
||||
for d := 0; d < remainingDepth; d++ {
|
||||
bit := sn.Stem[(absoluteDepth+d)/8] >> (7 - ((absoluteDepth + d) % 8)) & 1
|
||||
leafPos = leafPos*2 + int(bit)
|
||||
}
|
||||
} else {
|
||||
leafPos = position << remainingDepth
|
||||
}
|
||||
bitmap[leafPos/8] |= 1 << (7 - (leafPos % 8))
|
||||
h := s.ComputeHash(ref)
|
||||
copy(out[*hashOffset:*hashOffset+HashSize], h[:])
|
||||
*hashOffset += HashSize
|
||||
}
|
||||
|
||||
// SerializeNode serializes a node referenced by ref.
|
||||
func (s *NodeStore) SerializeNode(ref NodeRef, groupDepth int) []byte {
|
||||
if groupDepth < 1 || groupDepth > MaxGroupDepth {
|
||||
panic("groupDepth must be between 1 and 8")
|
||||
}
|
||||
|
||||
switch ref.Kind() {
|
||||
case KindInternal:
|
||||
node := s.getInternal(ref.Index())
|
||||
bitmapSize := BitmapSizeForDepth(groupDepth)
|
||||
hashCount := s.countSubtreeChildren(ref, groupDepth)
|
||||
serializedLen := NodeTypeBytes + 1 + bitmapSize + hashCount*HashSize
|
||||
serialized := make([]byte, serializedLen)
|
||||
serialized[0] = nodeTypeInternal
|
||||
serialized[1] = byte(groupDepth)
|
||||
bitmap := serialized[2 : 2+bitmapSize]
|
||||
hashOffset := 2 + bitmapSize
|
||||
s.serializeSubtreeDirect(ref, groupDepth, 0, int(node.depth), bitmap, serialized, &hashOffset)
|
||||
return serialized
|
||||
|
||||
case KindStem:
|
||||
sn := s.getStem(ref.Index())
|
||||
serializedLen := NodeTypeBytes + StemSize + StemBitmapSize + len(sn.valueData)
|
||||
serialized := make([]byte, serializedLen)
|
||||
serialized[0] = nodeTypeStem
|
||||
copy(serialized[NodeTypeBytes:NodeTypeBytes+StemSize], sn.Stem[:])
|
||||
copy(serialized[NodeTypeBytes+StemSize:NodeTypeBytes+StemSize+StemBitmapSize], sn.bitmap[:])
|
||||
copy(serialized[NodeTypeBytes+StemSize+StemBitmapSize:], sn.valueData)
|
||||
return serialized
|
||||
|
||||
default:
|
||||
panic(fmt.Sprintf("SerializeNode: unexpected node kind %d", ref.Kind()))
|
||||
}
|
||||
}
|
||||
|
||||
// --- Deserialization ---
|
||||
|
||||
var errInvalidSerializedLength = errors.New("invalid serialized node length")
|
||||
|
||||
// DeserializeNode deserializes a node from bytes, recomputing its hash.
|
||||
func (s *NodeStore) DeserializeNode(serialized []byte, depth int) (NodeRef, error) {
|
||||
return s.deserializeNode(serialized, depth, common.Hash{}, true)
|
||||
}
|
||||
|
||||
// DeserializeNodeWithHash deserializes a node, using the provided hash.
|
||||
func (s *NodeStore) DeserializeNodeWithHash(serialized []byte, depth int, hn common.Hash) (NodeRef, error) {
|
||||
return s.deserializeNode(serialized, depth, hn, false)
|
||||
}
|
||||
|
||||
func (s *NodeStore) deserializeNode(serialized []byte, depth int, hn common.Hash, mustRecompute bool) (NodeRef, error) {
|
||||
if len(serialized) == 0 {
|
||||
return EmptyRef, nil
|
||||
}
|
||||
|
||||
switch serialized[0] {
|
||||
case nodeTypeInternal:
|
||||
if len(serialized) < NodeTypeBytes+1 {
|
||||
return EmptyRef, errInvalidSerializedLength
|
||||
}
|
||||
groupDepth := int(serialized[1])
|
||||
if groupDepth < 1 || groupDepth > MaxGroupDepth {
|
||||
return EmptyRef, errors.New("invalid group depth")
|
||||
}
|
||||
bitmapSize := BitmapSizeForDepth(groupDepth)
|
||||
if len(serialized) < NodeTypeBytes+1+bitmapSize {
|
||||
return EmptyRef, errInvalidSerializedLength
|
||||
}
|
||||
bitmap := serialized[2 : 2+bitmapSize]
|
||||
hashData := serialized[2+bitmapSize:]
|
||||
|
||||
hashIdx := 0
|
||||
ref, err := s.deserializeSubtree(groupDepth, 0, depth, bitmap, hashData, &hashIdx, mustRecompute)
|
||||
if err != nil {
|
||||
return EmptyRef, err
|
||||
}
|
||||
if ref.Kind() == KindInternal && !mustRecompute {
|
||||
s.getInternal(ref.Index()).hash = hn
|
||||
}
|
||||
return ref, nil
|
||||
|
||||
case nodeTypeStem:
|
||||
if len(serialized) < 64 {
|
||||
return EmptyRef, errInvalidSerializedLength
|
||||
}
|
||||
stemIdx := s.allocStem()
|
||||
sn := s.getStem(stemIdx)
|
||||
copy(sn.Stem[:], serialized[NodeTypeBytes:NodeTypeBytes+StemSize])
|
||||
copy(sn.bitmap[:], serialized[NodeTypeBytes+StemSize:NodeTypeBytes+StemSize+StemBitmapSize])
|
||||
|
||||
var count uint16
|
||||
for i := range StemBitmapSize {
|
||||
count += uint16(bits.OnesCount8(sn.bitmap[i]))
|
||||
}
|
||||
sn.count = count
|
||||
dataStart := NodeTypeBytes + StemSize + StemBitmapSize
|
||||
dataEnd := dataStart + int(count)*HashSize
|
||||
if len(serialized) < dataEnd {
|
||||
return EmptyRef, errInvalidSerializedLength
|
||||
}
|
||||
// Zero-copy sub-slice of serialized data
|
||||
sn.valueData = serialized[dataStart:dataEnd]
|
||||
sn.shared = true
|
||||
sn.depth = uint8(depth)
|
||||
sn.hash = hn
|
||||
sn.mustRecompute = mustRecompute
|
||||
return MakeRef(KindStem, stemIdx), nil
|
||||
|
||||
default:
|
||||
return EmptyRef, errors.New("invalid node type")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NodeStore) deserializeSubtree(remainingDepth int, position int, nodeDepth int, bitmap []byte, hashData []byte, hashIdx *int, mustRecompute bool) (NodeRef, error) {
|
||||
if remainingDepth == 0 {
|
||||
if bitmap[position/8]>>(7-(position%8))&1 == 1 {
|
||||
if len(hashData) < (*hashIdx+1)*HashSize {
|
||||
return EmptyRef, errInvalidSerializedLength
|
||||
}
|
||||
hash := common.BytesToHash(hashData[*hashIdx*HashSize : (*hashIdx+1)*HashSize])
|
||||
*hashIdx++
|
||||
return s.newHashedRef(hash), nil
|
||||
}
|
||||
return EmptyRef, nil
|
||||
}
|
||||
|
||||
leftPos := position * 2
|
||||
rightPos := position*2 + 1
|
||||
|
||||
left, err := s.deserializeSubtree(remainingDepth-1, leftPos, nodeDepth+1, bitmap, hashData, hashIdx, true)
|
||||
if err != nil {
|
||||
return EmptyRef, err
|
||||
}
|
||||
right, err := s.deserializeSubtree(remainingDepth-1, rightPos, nodeDepth+1, bitmap, hashData, hashIdx, true)
|
||||
if err != nil {
|
||||
return EmptyRef, err
|
||||
}
|
||||
|
||||
if left.IsEmpty() && right.IsEmpty() {
|
||||
return EmptyRef, nil
|
||||
}
|
||||
|
||||
if nodeDepth > 248 {
|
||||
panic("node depth exceeds maximum binary trie depth")
|
||||
}
|
||||
idx := s.allocInternal()
|
||||
n := s.getInternal(idx)
|
||||
n.depth = uint8(nodeDepth)
|
||||
n.left = left
|
||||
n.right = right
|
||||
n.mustRecompute = mustRecompute
|
||||
return MakeRef(KindInternal, idx), nil
|
||||
}
|
||||
|
||||
// --- CollectNodes (Commit) ---
|
||||
|
||||
// CollectNodes traverses the trie, flushing nodes at group boundaries.
|
||||
func (s *NodeStore) CollectNodes(ref NodeRef, path []byte, flushfn NodeFlushFn, groupDepth int) error {
|
||||
if groupDepth < 1 || groupDepth > MaxGroupDepth {
|
||||
return errors.New("groupDepth must be between 1 and 8")
|
||||
}
|
||||
buf := make([]byte, len(path), len(path)+MaxGroupDepth+1)
|
||||
copy(buf, path)
|
||||
return s.collectNodesBuf(ref, buf, flushfn, groupDepth)
|
||||
}
|
||||
|
||||
func (s *NodeStore) collectNodesBuf(ref NodeRef, buf []byte, flushfn NodeFlushFn, groupDepth int) error {
|
||||
switch ref.Kind() {
|
||||
case KindInternal:
|
||||
node := s.getInternal(ref.Index())
|
||||
if int(node.depth)%groupDepth == 0 {
|
||||
if err := s.collectChildGroupsBuf(ref, buf, flushfn, groupDepth, groupDepth-1); err != nil {
|
||||
return err
|
||||
}
|
||||
serialized := s.SerializeNode(ref, groupDepth)
|
||||
flushfn(buf, s.ComputeHash(ref), serialized)
|
||||
return nil
|
||||
}
|
||||
return s.collectChildGroupsBuf(ref, buf, flushfn, groupDepth, groupDepth-(int(node.depth)%groupDepth)-1)
|
||||
|
||||
case KindStem:
|
||||
serialized := s.SerializeNode(ref, groupDepth)
|
||||
flushfn(buf, s.ComputeHash(ref), serialized)
|
||||
return nil
|
||||
|
||||
case KindHashed, KindEmpty:
|
||||
return nil
|
||||
|
||||
default:
|
||||
return fmt.Errorf("collectNodesBuf: unexpected kind %d", ref.Kind())
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NodeStore) collectChildGroupsBuf(ref NodeRef, buf []byte, flushfn NodeFlushFn, groupDepth int, remainingLevels int) error {
|
||||
if ref.Kind() != KindInternal {
|
||||
return nil
|
||||
}
|
||||
node := s.getInternal(ref.Index())
|
||||
saved := len(buf)
|
||||
childDepth := int(node.depth) + 1
|
||||
|
||||
if remainingLevels == 0 {
|
||||
if !node.left.IsEmpty() {
|
||||
buf = append(buf, 0)
|
||||
if err := s.collectNodesBuf(node.left, buf, flushfn, groupDepth); err != nil {
|
||||
return err
|
||||
}
|
||||
buf = buf[:saved]
|
||||
}
|
||||
if !node.right.IsEmpty() {
|
||||
buf = append(buf, 1)
|
||||
if err := s.collectNodesBuf(node.right, buf, flushfn, groupDepth); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Left child
|
||||
if !node.left.IsEmpty() {
|
||||
if node.left.Kind() == KindInternal {
|
||||
buf = append(buf, 0)
|
||||
if err := s.collectChildGroupsBuf(node.left, buf, flushfn, groupDepth, remainingLevels-1); err != nil {
|
||||
return err
|
||||
}
|
||||
buf = buf[:saved]
|
||||
} else {
|
||||
buf = append(buf, 0)
|
||||
buf = s.extendPathBuf(buf, node.left, remainingLevels, childDepth)
|
||||
if err := s.collectNodesBuf(node.left, buf, flushfn, groupDepth); err != nil {
|
||||
return err
|
||||
}
|
||||
buf = buf[:saved]
|
||||
}
|
||||
}
|
||||
|
||||
// Right child
|
||||
if !node.right.IsEmpty() {
|
||||
if node.right.Kind() == KindInternal {
|
||||
buf = append(buf, 1)
|
||||
if err := s.collectChildGroupsBuf(node.right, buf, flushfn, groupDepth, remainingLevels-1); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
buf = append(buf, 1)
|
||||
buf = s.extendPathBuf(buf, node.right, remainingLevels, childDepth)
|
||||
if err := s.collectNodesBuf(node.right, buf, flushfn, groupDepth); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extendPathBuf extends the path buffer to the group's leaf boundary.
|
||||
func (s *NodeStore) extendPathBuf(buf []byte, ref NodeRef, remainingLevels int, absoluteDepth int) []byte {
|
||||
if remainingLevels <= 0 {
|
||||
return buf
|
||||
}
|
||||
if ref.Kind() == KindStem {
|
||||
sn := s.getStem(ref.Index())
|
||||
for d := 0; d < remainingLevels; d++ {
|
||||
bit := sn.Stem[(absoluteDepth+d)/8] >> (7 - ((absoluteDepth + d) % 8)) & 1
|
||||
buf = append(buf, bit)
|
||||
}
|
||||
} else {
|
||||
for d := 0; d < remainingLevels; d++ {
|
||||
buf = append(buf, 0)
|
||||
}
|
||||
}
|
||||
return buf
|
||||
}
|
||||
|
||||
// ToDot generates a DOT representation for debugging.
|
||||
func (s *NodeStore) ToDot(ref NodeRef, parent, path string) string {
|
||||
switch ref.Kind() {
|
||||
case KindInternal:
|
||||
node := s.getInternal(ref.Index())
|
||||
me := fmt.Sprintf("internal%s", path)
|
||||
ret := fmt.Sprintf("%s [label=\"I: %x\"]\n", me, s.ComputeHash(ref))
|
||||
if len(parent) > 0 {
|
||||
ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me)
|
||||
}
|
||||
if !node.left.IsEmpty() {
|
||||
ret += s.ToDot(node.left, me, fmt.Sprintf("%s%02x", path, 0))
|
||||
}
|
||||
if !node.right.IsEmpty() {
|
||||
ret += s.ToDot(node.right, me, fmt.Sprintf("%s%02x", path, 1))
|
||||
}
|
||||
return ret
|
||||
case KindStem:
|
||||
sn := s.getStem(ref.Index())
|
||||
me := fmt.Sprintf("stem%s", path)
|
||||
ret := fmt.Sprintf("%s [label=\"stem=%x c=%x\"]\n", me, sn.Stem, sn.Hash())
|
||||
ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me)
|
||||
idx := 0
|
||||
for i := range StemNodeWidth {
|
||||
if sn.bitmap[i/8]>>(7-i%8)&1 != 1 {
|
||||
continue
|
||||
}
|
||||
v := sn.valueData[idx*HashSize : (idx+1)*HashSize]
|
||||
idx++
|
||||
ret += fmt.Sprintf("%s%x [label=\"%x\"]\n", me, i, v)
|
||||
ret += fmt.Sprintf("%s -> %s%x\n", me, me, i)
|
||||
}
|
||||
return ret
|
||||
case KindHashed:
|
||||
hn := s.getHashed(ref.Index())
|
||||
me := fmt.Sprintf("hash%s", path)
|
||||
ret := fmt.Sprintf("%s [label=\"%x\"]\n", me, hn.hash)
|
||||
ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me)
|
||||
return ret
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
556
trie/bintrie/store_ops.go
Normal file
556
trie/bintrie/store_ops.go
Normal file
|
|
@ -0,0 +1,556 @@
|
|||
// Copyright 2025 go-ethereum Authors
|
||||
// This file is part of the go-ethereum library.
|
||||
//
|
||||
// The go-ethereum library is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Lesser General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// The go-ethereum library is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Lesser General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Lesser General Public License
|
||||
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package bintrie
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
)
|
||||
|
||||
// NodeResolverFn resolves a hashed node from the database.
|
||||
type NodeResolverFn func([]byte, common.Hash) ([]byte, error)
|
||||
|
||||
// GetSingle retrieves a single value at stem[suffix] from the trie root.
|
||||
func (s *NodeStore) GetSingle(stem []byte, suffix byte, resolver NodeResolverFn) ([]byte, error) {
|
||||
return s.getSingle(s.root, stem, suffix, resolver)
|
||||
}
|
||||
|
||||
// getSingle retrieves a single value using iterative traversal.
|
||||
func (s *NodeStore) getSingle(ref NodeRef, stem []byte, suffix byte, resolver NodeResolverFn) ([]byte, error) {
|
||||
cur := ref
|
||||
// Track parent for HashedNode resolution (update parent's child ref).
|
||||
var parentIdx uint32
|
||||
var parentIsLeft bool
|
||||
hasParent := false
|
||||
|
||||
for {
|
||||
switch cur.Kind() {
|
||||
case KindInternal:
|
||||
node := s.getInternal(cur.Index())
|
||||
if node.depth >= 31*8 {
|
||||
return nil, errors.New("node too deep")
|
||||
}
|
||||
bit := stem[node.depth/8] >> (7 - (node.depth % 8)) & 1
|
||||
parentIdx = cur.Index()
|
||||
hasParent = true
|
||||
if bit == 0 {
|
||||
parentIsLeft = true
|
||||
cur = node.left
|
||||
} else {
|
||||
parentIsLeft = false
|
||||
cur = node.right
|
||||
}
|
||||
|
||||
case KindStem:
|
||||
sn := s.getStem(cur.Index())
|
||||
if sn.Stem != [StemSize]byte(stem[:StemSize]) {
|
||||
return nil, nil
|
||||
}
|
||||
return sn.getValue(suffix), nil
|
||||
|
||||
case KindHashed:
|
||||
if !hasParent {
|
||||
return nil, errors.New("getSingle: hashed node at root")
|
||||
}
|
||||
hn := s.getHashed(cur.Index())
|
||||
parentNode := s.getInternal(parentIdx)
|
||||
path := makeKeyPath(int(parentNode.depth), stem)
|
||||
data, err := resolver(path, hn.hash)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getSingle resolve error: %w", err)
|
||||
}
|
||||
resolved, err := s.DeserializeNodeWithHash(data, int(parentNode.depth)+1, hn.hash)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getSingle deserialization error: %w", err)
|
||||
}
|
||||
// Update parent's child ref
|
||||
s.freeHashedNode(cur.Index())
|
||||
if parentIsLeft {
|
||||
parentNode.left = resolved
|
||||
} else {
|
||||
parentNode.right = resolved
|
||||
}
|
||||
cur = resolved
|
||||
|
||||
case KindEmpty:
|
||||
return nil, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("getSingle: unexpected node kind %d", cur.Kind())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetValuesAtStem retrieves all 256 values at a stem.
|
||||
func (s *NodeStore) GetValuesAtStem(stem []byte, resolver NodeResolverFn) ([][]byte, error) {
|
||||
return s.getValuesAtStem(s.root, stem, resolver)
|
||||
}
|
||||
|
||||
// getValuesAtStem uses iterative traversal to find the StemNode.
|
||||
func (s *NodeStore) getValuesAtStem(ref NodeRef, stem []byte, resolver NodeResolverFn) ([][]byte, error) {
|
||||
cur := ref
|
||||
var parentIdx uint32
|
||||
var parentIsLeft bool
|
||||
hasParent := false
|
||||
|
||||
for {
|
||||
switch cur.Kind() {
|
||||
case KindInternal:
|
||||
node := s.getInternal(cur.Index())
|
||||
if node.depth >= 31*8 {
|
||||
return nil, errors.New("node too deep")
|
||||
}
|
||||
bit := stem[node.depth/8] >> (7 - (node.depth % 8)) & 1
|
||||
parentIdx = cur.Index()
|
||||
hasParent = true
|
||||
if bit == 0 {
|
||||
parentIsLeft = true
|
||||
cur = node.left
|
||||
} else {
|
||||
parentIsLeft = false
|
||||
cur = node.right
|
||||
}
|
||||
|
||||
case KindStem:
|
||||
sn := s.getStem(cur.Index())
|
||||
if sn.Stem != [StemSize]byte(stem[:StemSize]) {
|
||||
return nil, nil
|
||||
}
|
||||
return sn.allValues(), nil
|
||||
|
||||
case KindHashed:
|
||||
if !hasParent {
|
||||
return nil, errors.New("getValuesAtStem: hashed node at root")
|
||||
}
|
||||
hn := s.getHashed(cur.Index())
|
||||
parentNode := s.getInternal(parentIdx)
|
||||
path := makeKeyPath(int(parentNode.depth), stem)
|
||||
data, err := resolver(path, hn.hash)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getValuesAtStem resolve error: %w", err)
|
||||
}
|
||||
resolved, err := s.DeserializeNodeWithHash(data, int(parentNode.depth)+1, hn.hash)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getValuesAtStem deserialization error: %w", err)
|
||||
}
|
||||
s.freeHashedNode(cur.Index())
|
||||
if parentIsLeft {
|
||||
parentNode.left = resolved
|
||||
} else {
|
||||
parentNode.right = resolved
|
||||
}
|
||||
cur = resolved
|
||||
|
||||
case KindEmpty:
|
||||
var values [StemNodeWidth][]byte
|
||||
return values[:], nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("getValuesAtStem: unexpected node kind %d", cur.Kind())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// InsertSingle inserts a single value at stem[suffix] into the trie.
|
||||
func (s *NodeStore) InsertSingle(stem []byte, suffix byte, value []byte, resolver NodeResolverFn) error {
|
||||
if len(value) != HashSize {
|
||||
return errors.New("invalid insertion: value length")
|
||||
}
|
||||
|
||||
// Handle root-is-empty case
|
||||
if s.root.IsEmpty() {
|
||||
ref := s.newStemRef(stem, 0)
|
||||
sn := s.getStem(ref.Index())
|
||||
sn.setValue(suffix, value)
|
||||
s.root = ref
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle root-is-stem case
|
||||
if s.root.Kind() == KindStem {
|
||||
sn := s.getStem(s.root.Index())
|
||||
if sn.Stem == [StemSize]byte(stem[:StemSize]) {
|
||||
sn.ensureWritable()
|
||||
sn.setValue(suffix, value)
|
||||
sn.mustRecompute = true
|
||||
return nil
|
||||
}
|
||||
// Different stem — promote root to internal node via split
|
||||
newRoot := s.splitStemInsert(s.root, stem, suffix, value, int(sn.depth))
|
||||
s.root = newRoot
|
||||
return nil
|
||||
}
|
||||
|
||||
// Root is an InternalNode — iterative descent
|
||||
return s.insertSingleInternal(stem, suffix, value, resolver)
|
||||
}
|
||||
|
||||
// insertSingleInternal handles insertion when root is an InternalNode.
|
||||
func (s *NodeStore) insertSingleInternal(stem []byte, suffix byte, value []byte, resolver NodeResolverFn) error {
|
||||
type pathEntry struct {
|
||||
internalIdx uint32
|
||||
isLeft bool
|
||||
}
|
||||
var pathStack [256]pathEntry // stack-allocated, max depth 248
|
||||
pathLen := 0
|
||||
|
||||
cur := s.root
|
||||
|
||||
for {
|
||||
switch cur.Kind() {
|
||||
case KindInternal:
|
||||
node := s.getInternal(cur.Index())
|
||||
node.mustRecompute = true
|
||||
bit := stem[node.depth/8] >> (7 - (node.depth % 8)) & 1
|
||||
pathStack[pathLen] = pathEntry{internalIdx: cur.Index(), isLeft: bit == 0}
|
||||
pathLen++
|
||||
if bit == 0 {
|
||||
cur = node.left
|
||||
} else {
|
||||
cur = node.right
|
||||
}
|
||||
|
||||
case KindStem:
|
||||
sn := s.getStem(cur.Index())
|
||||
if sn.Stem == [StemSize]byte(stem[:StemSize]) {
|
||||
sn.ensureWritable()
|
||||
sn.setValue(suffix, value)
|
||||
sn.mustRecompute = true
|
||||
return nil
|
||||
}
|
||||
// Different stem — split
|
||||
parentDepth := int(s.getInternal(pathStack[pathLen-1].internalIdx).depth) + 1
|
||||
newRef := s.splitStemInsert(cur, stem, suffix, value, parentDepth)
|
||||
p := pathStack[pathLen-1]
|
||||
parent := s.getInternal(p.internalIdx)
|
||||
if p.isLeft {
|
||||
parent.left = newRef
|
||||
} else {
|
||||
parent.right = newRef
|
||||
}
|
||||
return nil
|
||||
|
||||
case KindHashed:
|
||||
if pathLen == 0 {
|
||||
return errors.New("insertSingle: hashed node at root")
|
||||
}
|
||||
p := pathStack[pathLen-1]
|
||||
parentNode := s.getInternal(p.internalIdx)
|
||||
hn := s.getHashed(cur.Index())
|
||||
path := makeKeyPath(int(parentNode.depth), stem)
|
||||
data, err := resolver(path, hn.hash)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insertSingle resolve error: %w", err)
|
||||
}
|
||||
resolved, err := s.DeserializeNodeWithHash(data, int(parentNode.depth)+1, hn.hash)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insertSingle deserialization error: %w", err)
|
||||
}
|
||||
s.freeHashedNode(cur.Index())
|
||||
if p.isLeft {
|
||||
parentNode.left = resolved
|
||||
} else {
|
||||
parentNode.right = resolved
|
||||
}
|
||||
cur = resolved
|
||||
|
||||
case KindEmpty:
|
||||
parentDepth := int(s.getInternal(pathStack[pathLen-1].internalIdx).depth) + 1
|
||||
ref := s.newStemRef(stem, parentDepth)
|
||||
sn := s.getStem(ref.Index())
|
||||
sn.setValue(suffix, value)
|
||||
p := pathStack[pathLen-1]
|
||||
parent := s.getInternal(p.internalIdx)
|
||||
if p.isLeft {
|
||||
parent.left = ref
|
||||
} else {
|
||||
parent.right = ref
|
||||
}
|
||||
return nil
|
||||
|
||||
default:
|
||||
return fmt.Errorf("insertSingle: unexpected node kind %d", cur.Kind())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// splitStemInsert handles the case where we need to split a StemNode
|
||||
// into a chain of InternalNodes because the new key has a different stem.
|
||||
func (s *NodeStore) splitStemInsert(existingRef NodeRef, newStem []byte, suffix byte, value []byte, depth int) NodeRef {
|
||||
existing := s.getStem(existingRef.Index())
|
||||
existingDepth := depth
|
||||
|
||||
var firstRef NodeRef
|
||||
var lastInternalIdx uint32
|
||||
var lastIsLeft bool
|
||||
first := true
|
||||
|
||||
for {
|
||||
bitExisting := existing.Stem[existingDepth/8] >> (7 - (existingDepth % 8)) & 1
|
||||
bitNew := newStem[existingDepth/8] >> (7 - (existingDepth % 8)) & 1
|
||||
|
||||
newInternalIdx := s.allocInternal()
|
||||
newInternal := s.getInternal(newInternalIdx)
|
||||
newInternal.depth = uint8(existingDepth)
|
||||
newInternal.mustRecompute = true
|
||||
newRef := MakeRef(KindInternal, newInternalIdx)
|
||||
|
||||
if first {
|
||||
firstRef = newRef
|
||||
first = false
|
||||
} else {
|
||||
parent := s.getInternal(lastInternalIdx)
|
||||
if lastIsLeft {
|
||||
parent.left = newRef
|
||||
} else {
|
||||
parent.right = newRef
|
||||
}
|
||||
}
|
||||
|
||||
if bitExisting != bitNew {
|
||||
// Divergence point
|
||||
existing.depth = uint8(existingDepth + 1)
|
||||
|
||||
newStemIdx := s.allocStem()
|
||||
newSn := s.getStem(newStemIdx)
|
||||
copy(newSn.Stem[:], newStem[:StemSize])
|
||||
newSn.depth = uint8(existingDepth + 1)
|
||||
newSn.mustRecompute = true
|
||||
newSn.setValue(suffix, value)
|
||||
newStemRef := MakeRef(KindStem, newStemIdx)
|
||||
|
||||
if bitExisting == 0 {
|
||||
newInternal.left = existingRef
|
||||
newInternal.right = newStemRef
|
||||
} else {
|
||||
newInternal.left = newStemRef
|
||||
newInternal.right = existingRef
|
||||
}
|
||||
return firstRef
|
||||
}
|
||||
|
||||
// Same bit — continue splitting
|
||||
lastInternalIdx = newInternalIdx
|
||||
lastIsLeft = (bitExisting == 0)
|
||||
existingDepth++
|
||||
}
|
||||
}
|
||||
|
||||
// InsertValuesAtStem inserts a full group of values at the given stem.
|
||||
func (s *NodeStore) InsertValuesAtStem(stem []byte, values [][]byte, resolver NodeResolverFn) error {
|
||||
newRoot, err := s.insertValuesAtStem(s.root, stem, values, resolver, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.root = newRoot
|
||||
return nil
|
||||
}
|
||||
|
||||
// insertValuesAtStem recursively inserts values at a stem.
|
||||
func (s *NodeStore) insertValuesAtStem(ref NodeRef, stem []byte, values [][]byte, resolver NodeResolverFn, depth int) (NodeRef, error) {
|
||||
switch ref.Kind() {
|
||||
case KindInternal:
|
||||
node := s.getInternal(ref.Index())
|
||||
bit := stem[node.depth/8] >> (7 - (node.depth % 8)) & 1
|
||||
if bit == 0 {
|
||||
if node.left.Kind() == KindHashed {
|
||||
hn := s.getHashed(node.left.Index())
|
||||
path := makeKeyPath(int(node.depth), stem)
|
||||
data, err := resolver(path, hn.hash)
|
||||
if err != nil {
|
||||
return ref, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
|
||||
}
|
||||
resolved, err := s.DeserializeNodeWithHash(data, int(node.depth)+1, hn.hash)
|
||||
if err != nil {
|
||||
return ref, fmt.Errorf("InsertValuesAtStem deserialization error: %w", err)
|
||||
}
|
||||
s.freeHashedNode(node.left.Index())
|
||||
node.left = resolved
|
||||
}
|
||||
newChild, err := s.insertValuesAtStem(node.left, stem, values, resolver, depth+1)
|
||||
if err != nil {
|
||||
return ref, err
|
||||
}
|
||||
node.left = newChild
|
||||
} else {
|
||||
if node.right.Kind() == KindHashed {
|
||||
hn := s.getHashed(node.right.Index())
|
||||
path := makeKeyPath(int(node.depth), stem)
|
||||
data, err := resolver(path, hn.hash)
|
||||
if err != nil {
|
||||
return ref, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
|
||||
}
|
||||
resolved, err := s.DeserializeNodeWithHash(data, int(node.depth)+1, hn.hash)
|
||||
if err != nil {
|
||||
return ref, fmt.Errorf("InsertValuesAtStem deserialization error: %w", err)
|
||||
}
|
||||
s.freeHashedNode(node.right.Index())
|
||||
node.right = resolved
|
||||
}
|
||||
newChild, err := s.insertValuesAtStem(node.right, stem, values, resolver, depth+1)
|
||||
if err != nil {
|
||||
return ref, err
|
||||
}
|
||||
node.right = newChild
|
||||
}
|
||||
node.mustRecompute = true
|
||||
return ref, nil
|
||||
|
||||
case KindStem:
|
||||
sn := s.getStem(ref.Index())
|
||||
if sn.Stem == [StemSize]byte(stem[:StemSize]) {
|
||||
// Same stem — merge values
|
||||
sn.ensureWritable()
|
||||
for i, v := range values {
|
||||
if v != nil {
|
||||
sn.setValue(byte(i), v)
|
||||
sn.mustRecompute = true
|
||||
}
|
||||
}
|
||||
return ref, nil
|
||||
}
|
||||
// Different stem — split
|
||||
return s.splitStemValuesInsert(ref, stem, values, resolver, depth)
|
||||
|
||||
case KindHashed:
|
||||
hn := s.getHashed(ref.Index())
|
||||
path, err := keyToPath(depth, stem)
|
||||
if err != nil {
|
||||
return ref, fmt.Errorf("InsertValuesAtStem path error: %w", err)
|
||||
}
|
||||
if resolver == nil {
|
||||
return ref, errors.New("InsertValuesAtStem: resolver is nil")
|
||||
}
|
||||
data, err := resolver(path, hn.hash)
|
||||
if err != nil {
|
||||
return ref, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
|
||||
}
|
||||
resolved, err := s.DeserializeNodeWithHash(data, depth, hn.hash)
|
||||
if err != nil {
|
||||
return ref, fmt.Errorf("InsertValuesAtStem deserialization error: %w", err)
|
||||
}
|
||||
s.freeHashedNode(ref.Index())
|
||||
return s.insertValuesAtStem(resolved, stem, values, resolver, depth)
|
||||
|
||||
case KindEmpty:
|
||||
// Create new StemNode
|
||||
stemIdx := s.allocStem()
|
||||
sn := s.getStem(stemIdx)
|
||||
copy(sn.Stem[:], stem[:StemSize])
|
||||
sn.depth = uint8(depth)
|
||||
sn.mustRecompute = true
|
||||
for i, v := range values {
|
||||
if v != nil {
|
||||
sn.count++
|
||||
sn.bitmap[i/8] |= 1 << (7 - (i % 8))
|
||||
sn.valueData = append(sn.valueData, v[:HashSize]...)
|
||||
}
|
||||
}
|
||||
return MakeRef(KindStem, stemIdx), nil
|
||||
|
||||
default:
|
||||
return ref, fmt.Errorf("insertValuesAtStem: unexpected kind %d", ref.Kind())
|
||||
}
|
||||
}
|
||||
|
||||
// splitStemValuesInsert handles splitting a StemNode when inserting values with a different stem.
|
||||
func (s *NodeStore) splitStemValuesInsert(existingRef NodeRef, newStem []byte, values [][]byte, resolver NodeResolverFn, depth int) (NodeRef, error) {
|
||||
existing := s.getStem(existingRef.Index())
|
||||
|
||||
bitStem := existing.Stem[existing.depth/8] >> (7 - (existing.depth % 8)) & 1
|
||||
nRef := s.newInternalRef(int(existing.depth))
|
||||
nNode := s.getInternal(nRef.Index())
|
||||
existing.depth++
|
||||
|
||||
bitKey := newStem[nNode.depth/8] >> (7 - (nNode.depth % 8)) & 1
|
||||
if bitKey == bitStem {
|
||||
// Same direction — need deeper split
|
||||
var child NodeRef
|
||||
if bitStem == 0 {
|
||||
nNode.left = existingRef
|
||||
child = nNode.left
|
||||
} else {
|
||||
nNode.right = existingRef
|
||||
child = nNode.right
|
||||
}
|
||||
newChild, err := s.insertValuesAtStem(child, newStem, values, resolver, depth+1)
|
||||
if err != nil {
|
||||
return nRef, err
|
||||
}
|
||||
if bitStem == 0 {
|
||||
nNode.left = newChild
|
||||
nNode.right = EmptyRef
|
||||
} else {
|
||||
nNode.right = newChild
|
||||
nNode.left = EmptyRef
|
||||
}
|
||||
} else {
|
||||
// Divergence — create new StemNode for the new values
|
||||
newStemIdx := s.allocStem()
|
||||
newSn := s.getStem(newStemIdx)
|
||||
copy(newSn.Stem[:], newStem[:StemSize])
|
||||
newSn.depth = nNode.depth + 1
|
||||
newSn.mustRecompute = true
|
||||
for i, v := range values {
|
||||
if v != nil {
|
||||
newSn.setValue(byte(i), v)
|
||||
}
|
||||
}
|
||||
newStemRef := MakeRef(KindStem, newStemIdx)
|
||||
|
||||
if bitStem == 0 {
|
||||
nNode.left = existingRef
|
||||
nNode.right = newStemRef
|
||||
} else {
|
||||
nNode.left = newStemRef
|
||||
nNode.right = existingRef
|
||||
}
|
||||
}
|
||||
return nRef, nil
|
||||
}
|
||||
|
||||
// Insert inserts a key-value pair using the full 32-byte key.
|
||||
func (s *NodeStore) Insert(key []byte, value []byte, resolver NodeResolverFn) error {
|
||||
return s.InsertSingle(key[:StemSize], key[StemSize], value, resolver)
|
||||
}
|
||||
|
||||
// Get retrieves the value for the given 32-byte key.
|
||||
func (s *NodeStore) Get(key []byte, resolver NodeResolverFn) ([]byte, error) {
|
||||
return s.GetSingle(key[:StemSize], key[StemSize], resolver)
|
||||
}
|
||||
|
||||
// GetHeight returns the height of the trie rooted at ref.
|
||||
func (s *NodeStore) GetHeight(ref NodeRef) int {
|
||||
switch ref.Kind() {
|
||||
case KindInternal:
|
||||
node := s.getInternal(ref.Index())
|
||||
lh := s.GetHeight(node.left)
|
||||
rh := s.GetHeight(node.right)
|
||||
if lh > rh {
|
||||
return 1 + lh
|
||||
}
|
||||
return 1 + rh
|
||||
case KindStem:
|
||||
return 1
|
||||
case KindEmpty:
|
||||
return 0
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
|
@ -19,7 +19,6 @@ package bintrie
|
|||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
|
|
@ -31,8 +30,6 @@ import (
|
|||
"github.com/holiman/uint256"
|
||||
)
|
||||
|
||||
var errInvalidRootType = errors.New("invalid root type")
|
||||
|
||||
// ChunkedCode represents a sequence of HashSize-byte chunks of code (StemSize bytes of which
|
||||
// are actual code, and NodeTypeBytes byte is the pushdata offset).
|
||||
type ChunkedCode []byte
|
||||
|
|
@ -108,22 +105,18 @@ func ChunkifyCode(code []byte) ChunkedCode {
|
|||
return chunks
|
||||
}
|
||||
|
||||
// NewBinaryNode creates a new empty binary trie
|
||||
func NewBinaryNode() BinaryNode {
|
||||
return Empty{}
|
||||
}
|
||||
|
||||
// BinaryTrie is the implementation of https://eips.ethereum.org/EIPS/eip-7864.
|
||||
type BinaryTrie struct {
|
||||
root BinaryNode
|
||||
reader *trie.Reader
|
||||
tracer *trie.PrevalueTracer
|
||||
store *NodeStore
|
||||
reader *trie.Reader
|
||||
tracer *trie.PrevalueTracer
|
||||
groupDepth int
|
||||
}
|
||||
|
||||
// ToDot converts the binary trie to a DOT language representation. Useful for debugging.
|
||||
func (t *BinaryTrie) ToDot() string {
|
||||
t.root.Hash()
|
||||
return ToDot(t.root)
|
||||
t.store.ComputeHash(t.store.Root())
|
||||
return t.store.ToDot(t.store.Root(), "", "")
|
||||
}
|
||||
|
||||
// NewBinaryTrie creates a new binary trie.
|
||||
|
|
@ -133,9 +126,10 @@ func NewBinaryTrie(root common.Hash, db database.NodeDatabase) (*BinaryTrie, err
|
|||
return nil, err
|
||||
}
|
||||
t := &BinaryTrie{
|
||||
root: NewBinaryNode(),
|
||||
reader: reader,
|
||||
tracer: trie.NewPrevalueTracer(),
|
||||
store: NewNodeStore(),
|
||||
reader: reader,
|
||||
tracer: trie.NewPrevalueTracer(),
|
||||
groupDepth: MaxGroupDepth,
|
||||
}
|
||||
// Parse the root node if it's not empty
|
||||
if root != types.EmptyBinaryHash && root != types.EmptyRootHash {
|
||||
|
|
@ -143,11 +137,11 @@ func NewBinaryTrie(root common.Hash, db database.NodeDatabase) (*BinaryTrie, err
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
node, err := DeserializeNodeWithHash(blob, 0, root)
|
||||
ref, err := t.store.DeserializeNodeWithHash(blob, 0, root)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.root = node
|
||||
t.store.SetRoot(ref)
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
|
@ -176,29 +170,18 @@ func (t *BinaryTrie) GetKey(key []byte) []byte {
|
|||
// GetWithHashedKey returns the value, assuming that the key has already
|
||||
// been hashed.
|
||||
func (t *BinaryTrie) GetWithHashedKey(key []byte) ([]byte, error) {
|
||||
return t.root.Get(key, t.nodeResolver)
|
||||
return t.store.Get(key, t.nodeResolver)
|
||||
}
|
||||
|
||||
// GetAccount returns the account information for the given address.
|
||||
func (t *BinaryTrie) GetAccount(addr common.Address) (*types.StateAccount, error) {
|
||||
var (
|
||||
values [][]byte
|
||||
err error
|
||||
acc = &types.StateAccount{}
|
||||
key = GetBinaryTreeKey(addr, zero[:])
|
||||
)
|
||||
switch r := t.root.(type) {
|
||||
case *InternalNode:
|
||||
values, err = r.GetValuesAtStem(key[:StemSize], t.nodeResolver)
|
||||
case *StemNode:
|
||||
values, err = r.GetValuesAtStem(key[:StemSize], t.nodeResolver)
|
||||
case Empty:
|
||||
return nil, nil
|
||||
default:
|
||||
// This will cover HashedNode but that should be fine since the
|
||||
// root node should always be resolved.
|
||||
return nil, errInvalidRootType
|
||||
}
|
||||
|
||||
values, err := t.store.GetValuesAtStem(key[:StemSize], t.nodeResolver)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetAccount (%x) error: %v", addr, err)
|
||||
}
|
||||
|
|
@ -219,7 +202,7 @@ func (t *BinaryTrie) GetAccount(addr common.Address) (*types.StateAccount, error
|
|||
// If the account has been deleted, BasicData and CodeHash will both be
|
||||
// 32-byte zero blobs (not nil). If the account is recreated afterwards,
|
||||
// UpdateAccount overwrites BasicData and CodeHash with non-zero values,
|
||||
// so this branch won't activate..
|
||||
// so this branch won't activate.
|
||||
if bytes.Equal(values[BasicDataLeafKey], zero[:]) &&
|
||||
bytes.Equal(values[CodeHashLeafKey], zero[:]) {
|
||||
return nil, nil
|
||||
|
|
@ -238,13 +221,12 @@ func (t *BinaryTrie) GetAccount(addr common.Address) (*types.StateAccount, error
|
|||
// not be modified by the caller. If a node was not found in the database, a
|
||||
// trie.MissingNodeError is returned.
|
||||
func (t *BinaryTrie) GetStorage(addr common.Address, key []byte) ([]byte, error) {
|
||||
return t.root.Get(GetBinaryTreeKeyStorageSlot(addr, key), t.nodeResolver)
|
||||
return t.store.Get(GetBinaryTreeKeyStorageSlot(addr, key), t.nodeResolver)
|
||||
}
|
||||
|
||||
// UpdateAccount updates the account information for the given address.
|
||||
func (t *BinaryTrie) UpdateAccount(addr common.Address, acc *types.StateAccount, codeLen int) error {
|
||||
var (
|
||||
err error
|
||||
basicData [HashSize]byte
|
||||
values = make([][]byte, StemNodeWidth)
|
||||
stem = GetBinaryTreeKey(addr, zero[:])
|
||||
|
|
@ -265,15 +247,12 @@ func (t *BinaryTrie) UpdateAccount(addr common.Address, acc *types.StateAccount,
|
|||
values[BasicDataLeafKey] = basicData[:]
|
||||
values[CodeHashLeafKey] = acc.CodeHash[:]
|
||||
|
||||
t.root, err = t.root.InsertValuesAtStem(stem, values, t.nodeResolver, 0)
|
||||
return err
|
||||
return t.store.InsertValuesAtStem(stem, values, t.nodeResolver)
|
||||
}
|
||||
|
||||
// UpdateStem updates the values for the given stem key.
|
||||
func (t *BinaryTrie) UpdateStem(key []byte, values [][]byte) error {
|
||||
var err error
|
||||
t.root, err = t.root.InsertValuesAtStem(key, values, t.nodeResolver, 0)
|
||||
return err
|
||||
return t.store.InsertValuesAtStem(key, values, t.nodeResolver)
|
||||
}
|
||||
|
||||
// UpdateStorage associates key with value in the trie. If value has length zero, any
|
||||
|
|
@ -288,11 +267,10 @@ func (t *BinaryTrie) UpdateStorage(address common.Address, key, value []byte) er
|
|||
} else {
|
||||
copy(v[HashSize-len(value):], value[:])
|
||||
}
|
||||
root, err := t.root.Insert(k, v[:], t.nodeResolver, 0)
|
||||
err := t.store.Insert(k, v[:], t.nodeResolver)
|
||||
if err != nil {
|
||||
return fmt.Errorf("UpdateStorage (%x) error: %v", address, err)
|
||||
}
|
||||
t.root = root
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -307,12 +285,7 @@ func (t *BinaryTrie) DeleteAccount(addr common.Address) error {
|
|||
values[BasicDataLeafKey] = zero[:]
|
||||
values[CodeHashLeafKey] = zero[:]
|
||||
|
||||
root, err := t.root.InsertValuesAtStem(stem, values, t.nodeResolver, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("DeleteAccount (%x) error: %v", addr, err)
|
||||
}
|
||||
t.root = root
|
||||
return nil
|
||||
return t.store.InsertValuesAtStem(stem, values, t.nodeResolver)
|
||||
}
|
||||
|
||||
// DeleteStorage removes any existing value for key from the trie. If a node was not
|
||||
|
|
@ -320,18 +293,17 @@ func (t *BinaryTrie) DeleteAccount(addr common.Address) error {
|
|||
func (t *BinaryTrie) DeleteStorage(addr common.Address, key []byte) error {
|
||||
k := GetBinaryTreeKeyStorageSlot(addr, key)
|
||||
var zero [HashSize]byte
|
||||
root, err := t.root.Insert(k, zero[:], t.nodeResolver, 0)
|
||||
err := t.store.Insert(k, zero[:], t.nodeResolver)
|
||||
if err != nil {
|
||||
return fmt.Errorf("DeleteStorage (%x) error: %v", addr, err)
|
||||
}
|
||||
t.root = root
|
||||
return nil
|
||||
}
|
||||
|
||||
// Hash returns the root hash of the trie. It does not write to the database and
|
||||
// can be used even if the trie doesn't have one.
|
||||
func (t *BinaryTrie) Hash() common.Hash {
|
||||
return t.root.Hash()
|
||||
return t.store.ComputeHash(t.store.Root())
|
||||
}
|
||||
|
||||
// Commit writes all nodes to the trie's memory database, tracking the internal
|
||||
|
|
@ -339,15 +311,17 @@ func (t *BinaryTrie) Hash() common.Hash {
|
|||
func (t *BinaryTrie) Commit(_ bool) (common.Hash, *trienode.NodeSet) {
|
||||
nodeset := trienode.NewNodeSet(common.Hash{})
|
||||
|
||||
// The root can be any type of BinaryNode (InternalNode, StemNode, etc.)
|
||||
err := t.root.CollectNodes(nil, func(path []byte, node BinaryNode) {
|
||||
serialized := SerializeNode(node)
|
||||
nodeset.AddNode(path, trienode.NewNodeWithPrev(node.Hash(), serialized, t.tracer.Get(path)))
|
||||
})
|
||||
groupDepth := t.groupDepth
|
||||
if groupDepth == 0 {
|
||||
groupDepth = MaxGroupDepth
|
||||
}
|
||||
|
||||
err := t.store.CollectNodes(t.store.Root(), nil, func(path []byte, hash common.Hash, serialized []byte) {
|
||||
nodeset.AddNode(path, trienode.NewNodeWithPrev(hash, serialized, t.tracer.Get(path)))
|
||||
}, groupDepth)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("CollectNodes failed: %v", err))
|
||||
}
|
||||
// Serialize root commitment form
|
||||
return t.Hash(), nodeset
|
||||
}
|
||||
|
||||
|
|
@ -371,9 +345,10 @@ func (t *BinaryTrie) Prove(key []byte, proofDb ethdb.KeyValueWriter) error {
|
|||
// Copy creates a deep copy of the trie.
|
||||
func (t *BinaryTrie) Copy() *BinaryTrie {
|
||||
return &BinaryTrie{
|
||||
root: t.root.Copy(),
|
||||
reader: t.reader,
|
||||
tracer: t.tracer.Copy(),
|
||||
store: t.store.Copy(),
|
||||
reader: t.reader,
|
||||
tracer: t.tracer.Copy(),
|
||||
groupDepth: t.groupDepth,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -407,7 +382,6 @@ func (t *BinaryTrie) UpdateContractCode(addr common.Address, codeHash common.Has
|
|||
|
||||
if groupOffset == StemNodeWidth-1 || len(chunks)-i <= HashSize {
|
||||
err = t.UpdateStem(key[:StemSize], values)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("UpdateContractCode (addr=%x) error: %w", addr[:], err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -37,147 +37,130 @@ var (
|
|||
)
|
||||
|
||||
func TestSingleEntry(t *testing.T) {
|
||||
tree := NewBinaryNode()
|
||||
tree, err := tree.Insert(zeroKey[:], oneKey[:], nil, 0)
|
||||
if err != nil {
|
||||
s := NewNodeStore()
|
||||
if err := s.Insert(zeroKey[:], oneKey[:], nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if tree.GetHeight() != 1 {
|
||||
if s.GetHeight(s.Root()) != 1 {
|
||||
t.Fatal("invalid depth")
|
||||
}
|
||||
expected := common.HexToHash("aab1060e04cb4f5dc6f697ae93156a95714debbf77d54238766adc5709282b6f")
|
||||
got := tree.Hash()
|
||||
got := s.Hash()
|
||||
if got != expected {
|
||||
t.Fatalf("invalid tree root, got %x, want %x", got, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTwoEntriesDiffFirstBit(t *testing.T) {
|
||||
var err error
|
||||
tree := NewBinaryNode()
|
||||
tree, err = tree.Insert(zeroKey[:], oneKey[:], nil, 0)
|
||||
if err != nil {
|
||||
s := NewNodeStore()
|
||||
if err := s.Insert(zeroKey[:], oneKey[:], nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tree, err = tree.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000000").Bytes(), twoKey[:], nil, 0)
|
||||
if err != nil {
|
||||
if err := s.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000000").Bytes(), twoKey[:], nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if tree.GetHeight() != 2 {
|
||||
if s.GetHeight(s.Root()) != 2 {
|
||||
t.Fatal("invalid height")
|
||||
}
|
||||
if tree.Hash() != common.HexToHash("dfc69c94013a8b3c65395625a719a87534a7cfd38719251ad8c8ea7fe79f065e") {
|
||||
if s.Hash() != common.HexToHash("dfc69c94013a8b3c65395625a719a87534a7cfd38719251ad8c8ea7fe79f065e") {
|
||||
t.Fatal("invalid tree root")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOneStemColocatedValues(t *testing.T) {
|
||||
var err error
|
||||
tree := NewBinaryNode()
|
||||
tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil, 0)
|
||||
if err != nil {
|
||||
s := NewNodeStore()
|
||||
if err := s.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil, 0)
|
||||
if err != nil {
|
||||
if err := s.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000009").Bytes(), threeKey[:], nil, 0)
|
||||
if err != nil {
|
||||
if err := s.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000009").Bytes(), threeKey[:], nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tree, err = tree.Insert(common.HexToHash("00000000000000000000000000000000000000000000000000000000000000FF").Bytes(), fourKey[:], nil, 0)
|
||||
if err != nil {
|
||||
if err := s.Insert(common.HexToHash("00000000000000000000000000000000000000000000000000000000000000FF").Bytes(), fourKey[:], nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if tree.GetHeight() != 1 {
|
||||
if s.GetHeight(s.Root()) != 1 {
|
||||
t.Fatal("invalid height")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTwoStemColocatedValues(t *testing.T) {
|
||||
var err error
|
||||
tree := NewBinaryNode()
|
||||
s := NewNodeStore()
|
||||
// stem: 0...0
|
||||
tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil, 0)
|
||||
if err != nil {
|
||||
if err := s.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil, 0)
|
||||
if err != nil {
|
||||
if err := s.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// stem: 10...0
|
||||
tree, err = tree.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil, 0)
|
||||
if err != nil {
|
||||
if err := s.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tree, err = tree.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil, 0)
|
||||
if err != nil {
|
||||
if err := s.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if tree.GetHeight() != 2 {
|
||||
if s.GetHeight(s.Root()) != 2 {
|
||||
t.Fatal("invalid height")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTwoKeysMatchFirst42Bits(t *testing.T) {
|
||||
var err error
|
||||
tree := NewBinaryNode()
|
||||
s := NewNodeStore()
|
||||
// key1 and key 2 have the same prefix of 42 bits (b0*42+b1+b1) and differ after.
|
||||
key1 := common.HexToHash("0000000000C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0").Bytes()
|
||||
key2 := common.HexToHash("0000000000E00000000000000000000000000000000000000000000000000000").Bytes()
|
||||
tree, err = tree.Insert(key1, oneKey[:], nil, 0)
|
||||
if err != nil {
|
||||
if err := s.Insert(key1, oneKey[:], nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tree, err = tree.Insert(key2, twoKey[:], nil, 0)
|
||||
if err != nil {
|
||||
if err := s.Insert(key2, twoKey[:], nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if tree.GetHeight() != 1+42+1 {
|
||||
if s.GetHeight(s.Root()) != 1+42+1 {
|
||||
t.Fatal("invalid height")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInsertDuplicateKey(t *testing.T) {
|
||||
var err error
|
||||
tree := NewBinaryNode()
|
||||
tree, err = tree.Insert(oneKey[:], oneKey[:], nil, 0)
|
||||
if err != nil {
|
||||
s := NewNodeStore()
|
||||
if err := s.Insert(oneKey[:], oneKey[:], nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tree, err = tree.Insert(oneKey[:], twoKey[:], nil, 0)
|
||||
if err != nil {
|
||||
if err := s.Insert(oneKey[:], twoKey[:], nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if tree.GetHeight() != 1 {
|
||||
if s.GetHeight(s.Root()) != 1 {
|
||||
t.Fatal("invalid height")
|
||||
}
|
||||
// Verify that the value is updated
|
||||
if !bytes.Equal(tree.(*StemNode).Values[1], twoKey[:]) {
|
||||
t.Fatal("invalid height")
|
||||
v, err := s.Get(oneKey[:], nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(v, twoKey[:]) {
|
||||
t.Fatal("value not updated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLargeNumberOfEntries(t *testing.T) {
|
||||
var err error
|
||||
tree := NewBinaryNode()
|
||||
s := NewNodeStore()
|
||||
for i := range StemNodeWidth {
|
||||
var key [HashSize]byte
|
||||
key[0] = byte(i)
|
||||
tree, err = tree.Insert(key[:], ffKey[:], nil, 0)
|
||||
if err != nil {
|
||||
if err := s.Insert(key[:], ffKey[:], nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
height := tree.GetHeight()
|
||||
height := s.GetHeight(s.Root())
|
||||
if height != 1+8 {
|
||||
t.Fatalf("invalid height, wanted %d, got %d", 1+8, height)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMerkleizeMultipleEntries(t *testing.T) {
|
||||
var err error
|
||||
tree := NewBinaryNode()
|
||||
s := NewNodeStore()
|
||||
keys := [][]byte{
|
||||
zeroKey[:],
|
||||
common.HexToHash("8000000000000000000000000000000000000000000000000000000000000000").Bytes(),
|
||||
|
|
@ -187,12 +170,11 @@ func TestMerkleizeMultipleEntries(t *testing.T) {
|
|||
for i, key := range keys {
|
||||
var v [HashSize]byte
|
||||
binary.LittleEndian.PutUint64(v[:8], uint64(i))
|
||||
tree, err = tree.Insert(key, v[:], nil, 0)
|
||||
if err != nil {
|
||||
if err := s.Insert(key, v[:], nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
got := tree.Hash()
|
||||
got := s.Hash()
|
||||
expected := common.HexToHash("9317155862f7a3867660ddd0966ff799a3d16aa4df1e70a7516eaa4a675191b5")
|
||||
if got != expected {
|
||||
t.Fatalf("invalid root, expected=%x, got = %x", expected, got)
|
||||
|
|
@ -206,8 +188,9 @@ func TestMerkleizeMultipleEntries(t *testing.T) {
|
|||
func TestStorageRoundTrip(t *testing.T) {
|
||||
tracer := trie.NewPrevalueTracer()
|
||||
tr := &BinaryTrie{
|
||||
root: NewBinaryNode(),
|
||||
tracer: tracer,
|
||||
store: NewNodeStore(),
|
||||
tracer: tracer,
|
||||
groupDepth: MaxGroupDepth,
|
||||
}
|
||||
addr := common.HexToAddress("0x1234567890abcdef1234567890abcdef12345678")
|
||||
|
||||
|
|
@ -274,8 +257,9 @@ func TestStorageRoundTrip(t *testing.T) {
|
|||
func newEmptyTestTrie(t *testing.T) *BinaryTrie {
|
||||
t.Helper()
|
||||
return &BinaryTrie{
|
||||
root: NewBinaryNode(),
|
||||
tracer: trie.NewPrevalueTracer(),
|
||||
store: NewNodeStore(),
|
||||
tracer: trie.NewPrevalueTracer(),
|
||||
groupDepth: MaxGroupDepth,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -599,8 +583,9 @@ func TestBinaryTrieWitness(t *testing.T) {
|
|||
tracer := trie.NewPrevalueTracer()
|
||||
|
||||
tr := &BinaryTrie{
|
||||
root: NewBinaryNode(),
|
||||
tracer: tracer,
|
||||
store: NewNodeStore(),
|
||||
tracer: tracer,
|
||||
groupDepth: MaxGroupDepth,
|
||||
}
|
||||
if w := tr.Witness(); len(w) != 0 {
|
||||
t.Fatal("expected empty witness for fresh trie")
|
||||
|
|
@ -626,8 +611,9 @@ func TestBinaryTrieWitness(t *testing.T) {
|
|||
func testAccount(t *testing.T, addr common.Address, nonce uint64, balance uint64) *BinaryTrie {
|
||||
t.Helper()
|
||||
tr := &BinaryTrie{
|
||||
root: NewBinaryNode(),
|
||||
tracer: trie.NewPrevalueTracer(),
|
||||
store: NewNodeStore(),
|
||||
tracer: trie.NewPrevalueTracer(),
|
||||
groupDepth: MaxGroupDepth,
|
||||
}
|
||||
acc := &types.StateAccount{
|
||||
Nonce: nonce,
|
||||
|
|
@ -649,8 +635,8 @@ func TestGetAccountNonMembershipStemRoot(t *testing.T) {
|
|||
tr := testAccount(t, addr, 42, 100)
|
||||
|
||||
// Verify root is a StemNode (single stem inserted).
|
||||
if _, ok := tr.root.(*StemNode); !ok {
|
||||
t.Fatalf("expected StemNode root, got %T", tr.root)
|
||||
if tr.store.Root().Kind() != KindStem {
|
||||
t.Fatalf("expected StemNode root, got kind %d", tr.store.Root().Kind())
|
||||
}
|
||||
|
||||
// Query a completely different address — must return nil.
|
||||
|
|
@ -680,8 +666,9 @@ func TestGetAccountNonMembershipStemRoot(t *testing.T) {
|
|||
// address returns nil when the trie root is an InternalNode (multi-account trie).
|
||||
func TestGetAccountNonMembershipInternalRoot(t *testing.T) {
|
||||
tr := &BinaryTrie{
|
||||
root: NewBinaryNode(),
|
||||
tracer: trie.NewPrevalueTracer(),
|
||||
store: NewNodeStore(),
|
||||
tracer: trie.NewPrevalueTracer(),
|
||||
groupDepth: MaxGroupDepth,
|
||||
}
|
||||
|
||||
// Insert two accounts whose binary tree keys have different first bits
|
||||
|
|
@ -700,8 +687,8 @@ func TestGetAccountNonMembershipInternalRoot(t *testing.T) {
|
|||
}
|
||||
|
||||
// Verify root is an InternalNode.
|
||||
if _, ok := tr.root.(*InternalNode); !ok {
|
||||
t.Fatalf("expected InternalNode root, got %T", tr.root)
|
||||
if tr.store.Root().Kind() != KindInternal {
|
||||
t.Fatalf("expected InternalNode root, got kind %d", tr.store.Root().Kind())
|
||||
}
|
||||
|
||||
// Query a non-existent address — must return nil.
|
||||
|
|
@ -723,8 +710,8 @@ func TestGetStorageNonMembershipStemRoot(t *testing.T) {
|
|||
tr := testAccount(t, addr, 1, 100)
|
||||
|
||||
// Verify root is a StemNode.
|
||||
if _, ok := tr.root.(*StemNode); !ok {
|
||||
t.Fatalf("expected StemNode root, got %T", tr.root)
|
||||
if tr.store.Root().Kind() != KindStem {
|
||||
t.Fatalf("expected StemNode root, got kind %d", tr.store.Root().Kind())
|
||||
}
|
||||
|
||||
// Query storage for a different address — must return nil, not panic.
|
||||
|
|
@ -743,8 +730,9 @@ func TestGetStorageNonMembershipStemRoot(t *testing.T) {
|
|||
// non-existent address returns nil when the root is an InternalNode.
|
||||
func TestGetStorageNonMembershipInternalRoot(t *testing.T) {
|
||||
tr := &BinaryTrie{
|
||||
root: NewBinaryNode(),
|
||||
tracer: trie.NewPrevalueTracer(),
|
||||
store: NewNodeStore(),
|
||||
tracer: trie.NewPrevalueTracer(),
|
||||
groupDepth: MaxGroupDepth,
|
||||
}
|
||||
|
||||
addr := common.HexToAddress("0x1234567890abcdef1234567890abcdef12345678")
|
||||
|
|
@ -765,8 +753,8 @@ func TestGetStorageNonMembershipInternalRoot(t *testing.T) {
|
|||
t.Fatalf("UpdateStorage error: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := tr.root.(*InternalNode); !ok {
|
||||
t.Fatalf("expected InternalNode root, got %T", tr.root)
|
||||
if tr.store.Root().Kind() != KindInternal {
|
||||
t.Fatalf("expected InternalNode root, got kind %d", tr.store.Root().Kind())
|
||||
}
|
||||
|
||||
// Query storage for a non-existent address — must return nil.
|
||||
|
|
|
|||
|
|
@ -102,11 +102,7 @@ func binaryNodeHasher(blob []byte) (common.Hash, error) {
|
|||
if len(blob) == 0 {
|
||||
return types.EmptyBinaryHash, nil
|
||||
}
|
||||
n, err := bintrie.DeserializeNode(blob, 0)
|
||||
if err != nil {
|
||||
return common.Hash{}, err
|
||||
}
|
||||
return n.Hash(), nil
|
||||
return bintrie.DeserializeAndHash(blob, 0)
|
||||
}
|
||||
|
||||
// Database is a multiple-layered structure for maintaining in-memory states
|
||||
|
|
|
|||
Loading…
Reference in a new issue