diff --git a/trie/bintrie/binary_node.go b/trie/bintrie/binary_node.go index 8905a82285..fbfd4a6ca7 100644 --- a/trie/bintrie/binary_node.go +++ b/trie/bintrie/binary_node.go @@ -16,140 +16,43 @@ package bintrie -import ( - "errors" - - "github.com/ethereum/go-ethereum/common" -) - -type ( - NodeFlushFn func([]byte, BinaryNode) - NodeResolverFn func([]byte, common.Hash) ([]byte, error) -) +import "github.com/ethereum/go-ethereum/common" // zero is the zero value for a 32-byte array. var zero [32]byte const ( - StemNodeWidth = 256 // Number of child per leaf node - StemSize = 31 // Number of bytes to travel before reaching a group of leaves - NodeTypeBytes = 1 // Size of node type prefix in serialization - HashSize = 32 // Size of a hash in bytes - BitmapSize = 32 // Size of the bitmap in a stem node + StemNodeWidth = 256 // Number of children per leaf node + StemSize = 31 // Number of bytes to travel before reaching a group of leaves + NodeTypeBytes = 1 // Size of node type prefix in serialization + HashSize = 32 // Size of a hash in bytes + StemBitmapSize = 32 // Size of the bitmap in a stem node (256 values = 32 bytes) + + // MaxGroupDepth is the maximum allowed group depth for InternalNode serialization. + MaxGroupDepth = 8 ) +// BitmapSizeForDepth returns the bitmap size in bytes for a given group depth. +func BitmapSizeForDepth(groupDepth int) int { + if groupDepth <= 3 { + return 1 + } + return 1 << (groupDepth - 3) +} + const ( - nodeTypeStem = iota + 1 // Stem node, contains a stem and a bitmap of values + nodeTypeStem = iota + 1 nodeTypeInternal ) -// BinaryNode is an interface for a binary trie node. 
-type BinaryNode interface { - Get([]byte, NodeResolverFn) ([]byte, error) - Insert([]byte, []byte, NodeResolverFn, int) (BinaryNode, error) - Copy() BinaryNode - Hash() common.Hash - GetValuesAtStem([]byte, NodeResolverFn) ([][]byte, error) - InsertValuesAtStem([]byte, [][]byte, NodeResolverFn, int) (BinaryNode, error) - CollectNodes([]byte, NodeFlushFn) error - - toDot(parent, path string) string - GetHeight() int -} - -// SerializeNode serializes a binary trie node into a byte slice. -func SerializeNode(node BinaryNode) []byte { - switch n := (node).(type) { - case *InternalNode: - // InternalNode: 1 byte type + 32 bytes left hash + 32 bytes right hash - var serialized [NodeTypeBytes + HashSize + HashSize]byte - serialized[0] = nodeTypeInternal - copy(serialized[1:33], n.left.Hash().Bytes()) - copy(serialized[33:65], n.right.Hash().Bytes()) - return serialized[:] - case *StemNode: - // StemNode: 1 byte type + 31 bytes stem + 32 bytes bitmap + 256*32 bytes values - var serialized [NodeTypeBytes + StemSize + BitmapSize + StemNodeWidth*HashSize]byte - serialized[0] = nodeTypeStem - copy(serialized[NodeTypeBytes:NodeTypeBytes+StemSize], n.Stem) - bitmap := serialized[NodeTypeBytes+StemSize : NodeTypeBytes+StemSize+BitmapSize] - offset := NodeTypeBytes + StemSize + BitmapSize - for i, v := range n.Values { - if v != nil { - bitmap[i/8] |= 1 << (7 - (i % 8)) - copy(serialized[offset:offset+HashSize], v) - offset += HashSize - } - } - // Only return the actual data, not the entire array - return serialized[:offset] - default: - panic("invalid node type") +// DeserializeAndHash deserializes a node from bytes and returns its hash. +// This is a convenience function for external callers that need to compute +// the hash of a serialized node without maintaining a NodeStore. 
+func DeserializeAndHash(blob []byte, depth int) (common.Hash, error) { + s := NewNodeStore() + ref, err := s.DeserializeNode(blob, depth) + if err != nil { + return common.Hash{}, err } -} - -var invalidSerializedLength = errors.New("invalid serialized node length") - -// DeserializeNode deserializes a binary trie node from a byte slice. The -// hash will be recomputed from the deserialized data. -func DeserializeNode(serialized []byte, depth int) (BinaryNode, error) { - return deserializeNode(serialized, depth, common.Hash{}, true, true) -} - -// DeserializeNodeWithHash deserializes a binary trie node from a byte slice, using the provided hash. -func DeserializeNodeWithHash(serialized []byte, depth int, hn common.Hash) (BinaryNode, error) { - return deserializeNode(serialized, depth, hn, false, false) -} - -func deserializeNode(serialized []byte, depth int, hn common.Hash, mustRecompute, dirty bool) (BinaryNode, error) { - if len(serialized) == 0 { - return Empty{}, nil - } - - switch serialized[0] { - case nodeTypeInternal: - if len(serialized) != 65 { - return nil, invalidSerializedLength - } - return &InternalNode{ - depth: depth, - left: HashedNode(common.BytesToHash(serialized[1:33])), - right: HashedNode(common.BytesToHash(serialized[33:65])), - hash: hn, - mustRecompute: mustRecompute, - dirty: dirty, - }, nil - case nodeTypeStem: - if len(serialized) < 64 { - return nil, invalidSerializedLength - } - var values [StemNodeWidth][]byte - bitmap := serialized[NodeTypeBytes+StemSize : NodeTypeBytes+StemSize+BitmapSize] - offset := NodeTypeBytes + StemSize + BitmapSize - - for i := range StemNodeWidth { - if bitmap[i/8]>>(7-(i%8))&1 == 1 { - if len(serialized) < offset+HashSize { - return nil, invalidSerializedLength - } - values[i] = serialized[offset : offset+HashSize] - offset += HashSize - } - } - return &StemNode{ - Stem: serialized[NodeTypeBytes : NodeTypeBytes+StemSize], - Values: values[:], - depth: depth, - hash: hn, - mustRecompute: mustRecompute, - 
dirty: dirty, - }, nil - default: - return nil, errors.New("invalid node type") - } -} - -// ToDot converts the binary trie to a DOT language representation. Useful for debugging. -func ToDot(root BinaryNode) string { - return root.toDot("", "") + return s.ComputeHash(ref), nil } diff --git a/trie/bintrie/binary_node_test.go b/trie/bintrie/binary_node_test.go index 242743ba53..f8a00bc9f1 100644 --- a/trie/bintrie/binary_node_test.go +++ b/trie/bintrie/binary_node_test.go @@ -24,78 +24,118 @@ import ( ) // TestSerializeDeserializeInternalNode tests serialization and deserialization of InternalNode +// with the grouped subtree format through NodeStore. func TestSerializeDeserializeInternalNode(t *testing.T) { - // Create an internal node with two hashed children leftHash := common.HexToHash("0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef") rightHash := common.HexToHash("0xfedcba0987654321fedcba0987654321fedcba0987654321fedcba0987654321") - node := &InternalNode{ - depth: 5, - left: HashedNode(leftHash), - right: HashedNode(rightHash), - } + s := NewNodeStore() + leftRef := s.newHashedRef(leftHash) + rightRef := s.newHashedRef(rightHash) - // Serialize the node - serialized := SerializeNode(node) + rootRef := s.newInternalRef(0) + rootNode := s.getInternal(rootRef.Index()) + rootNode.left = leftRef + rootNode.right = rightRef + s.SetRoot(rootRef) + + // Serialize the node with default group depth of 8 + serialized := s.SerializeNode(rootRef, MaxGroupDepth) // Check the serialized format if serialized[0] != nodeTypeInternal { t.Errorf("Expected type byte to be %d, got %d", nodeTypeInternal, serialized[0]) } - - if len(serialized) != 65 { - t.Errorf("Expected serialized length to be 65, got %d", len(serialized)) + if serialized[1] != MaxGroupDepth { + t.Errorf("Expected group depth to be %d, got %d", MaxGroupDepth, serialized[1]) } - // Deserialize the node - deserialized, err := DeserializeNode(serialized, 5) + bitmapSize := 
BitmapSizeForDepth(MaxGroupDepth) + expectedLen := 1 + 1 + bitmapSize + 2*HashSize + if len(serialized) != expectedLen { + t.Errorf("Expected serialized length to be %d, got %d", expectedLen, len(serialized)) + } + + // Check bitmap bits + bitmap := serialized[2 : 2+bitmapSize] + if bitmap[0]&0x80 == 0 { + t.Error("Expected bit 0 to be set in bitmap (left child)") + } + if bitmap[16]&0x80 == 0 { + t.Error("Expected bit 128 to be set in bitmap (right child)") + } + + // Deserialize into a new store + ds := NewNodeStore() + deserialized, err := ds.DeserializeNode(serialized, 0) if err != nil { t.Fatalf("Failed to deserialize node: %v", err) } - // Check that it's an internal node - internalNode, ok := deserialized.(*InternalNode) - if !ok { - t.Fatalf("Expected InternalNode, got %T", deserialized) + // Root should be an InternalNode + if deserialized.Kind() != KindInternal { + t.Fatalf("Expected KindInternal, got kind %d", deserialized.Kind()) } - // Check the depth - if internalNode.depth != 5 { - t.Errorf("Expected depth 5, got %d", internalNode.depth) + internalNode := ds.getInternal(deserialized.Index()) + if internalNode.depth != 0 { + t.Errorf("Expected depth 0, got %d", internalNode.depth) } - // Check the left and right hashes - if internalNode.left.Hash() != leftHash { - t.Errorf("Left hash mismatch: expected %x, got %x", leftHash, internalNode.left.Hash()) + // Navigate to position 0 (8 left turns) to find the left hash + node0 := navigateToLeafRef(ds, deserialized, 0, 8) + if ds.ComputeHash(node0) != leftHash { + t.Errorf("Left hash mismatch: expected %x, got %x", leftHash, ds.ComputeHash(node0)) } - if internalNode.right.Hash() != rightHash { - t.Errorf("Right hash mismatch: expected %x, got %x", rightHash, internalNode.right.Hash()) + // Navigate to position 128 (right, then 7 lefts) to find the right hash + node128 := navigateToLeafRef(ds, deserialized, 128, 8) + if ds.ComputeHash(node128) != rightHash { + t.Errorf("Right hash mismatch: expected %x, got 
%x", rightHash, ds.ComputeHash(node128)) } } -// TestSerializeDeserializeStemNode tests serialization and deserialization of StemNode +// navigateToLeafRef navigates to a specific position in the tree using NodeStore. +func navigateToLeafRef(s *NodeStore, ref NodeRef, position, depth int) NodeRef { + cur := ref + for d := 0; d < depth; d++ { + if cur.Kind() != KindInternal { + return cur + } + in := s.getInternal(cur.Index()) + bit := (position >> (depth - 1 - d)) & 1 + if bit == 0 { + cur = in.left + } else { + cur = in.right + } + } + return cur +} + +// TestSerializeDeserializeStemNode tests serialization and deserialization of StemNode through NodeStore. func TestSerializeDeserializeStemNode(t *testing.T) { - // Create a stem node with some values stem := make([]byte, StemSize) for i := range stem { stem[i] = byte(i) } var values [StemNodeWidth][]byte - // Add some values at different indices values[0] = common.HexToHash("0x0101010101010101010101010101010101010101010101010101010101010101").Bytes() values[10] = common.HexToHash("0x0202020202020202020202020202020202020202020202020202020202020202").Bytes() values[255] = common.HexToHash("0x0303030303030303030303030303030303030303030303030303030303030303").Bytes() - node := &StemNode{ - Stem: stem, - Values: values[:], - depth: 10, + s := NewNodeStore() + ref := s.newStemRef(stem, 10) + sn := s.getStem(ref.Index()) + for i, v := range values { + if v != nil { + sn.setValue(byte(i), v) + } } // Serialize the node - serialized := SerializeNode(node) + serialized := s.SerializeNode(ref, MaxGroupDepth) // Check the serialized format if serialized[0] != nodeTypeStem { @@ -107,31 +147,32 @@ func TestSerializeDeserializeStemNode(t *testing.T) { t.Errorf("Stem mismatch in serialized data") } - // Deserialize the node - deserialized, err := DeserializeNode(serialized, 10) + // Deserialize into a new store + ds := NewNodeStore() + deserializedRef, err := ds.DeserializeNode(serialized, 10) if err != nil { t.Fatalf("Failed to 
deserialize node: %v", err) } - // Check that it's a stem node - stemNode, ok := deserialized.(*StemNode) - if !ok { - t.Fatalf("Expected StemNode, got %T", deserialized) + if deserializedRef.Kind() != KindStem { + t.Fatalf("Expected KindStem, got kind %d", deserializedRef.Kind()) } + stemNode := ds.getStem(deserializedRef.Index()) + // Check the stem - if !bytes.Equal(stemNode.Stem, stem) { + if !bytes.Equal(stemNode.Stem[:], stem) { t.Errorf("Stem mismatch after deserialization") } // Check the values - if !bytes.Equal(stemNode.Values[0], values[0]) { + if !bytes.Equal(stemNode.getValue(0), values[0]) { t.Errorf("Value at index 0 mismatch") } - if !bytes.Equal(stemNode.Values[10], values[10]) { + if !bytes.Equal(stemNode.getValue(10), values[10]) { t.Errorf("Value at index 10 mismatch") } - if !bytes.Equal(stemNode.Values[255], values[255]) { + if !bytes.Equal(stemNode.getValue(255), values[255]) { t.Errorf("Value at index 255 mismatch") } @@ -140,43 +181,43 @@ func TestSerializeDeserializeStemNode(t *testing.T) { if i == 0 || i == 10 || i == 255 { continue } - if stemNode.Values[i] != nil { - t.Errorf("Expected nil value at index %d, got %x", i, stemNode.Values[i]) + if stemNode.hasValue(byte(i)) { + t.Errorf("Expected no value at index %d, got %x", i, stemNode.getValue(byte(i))) } } } -// TestDeserializeEmptyNode tests deserialization of empty node +// TestDeserializeEmptyNode tests deserialization of empty node. 
func TestDeserializeEmptyNode(t *testing.T) { - // Empty byte slice should deserialize to Empty node - deserialized, err := DeserializeNode([]byte{}, 0) + s := NewNodeStore() + deserialized, err := s.DeserializeNode([]byte{}, 0) if err != nil { t.Fatalf("Failed to deserialize empty node: %v", err) } - _, ok := deserialized.(Empty) - if !ok { - t.Fatalf("Expected Empty node, got %T", deserialized) + if !deserialized.IsEmpty() { + t.Fatalf("Expected EmptyRef, got kind %d", deserialized.Kind()) } } -// TestDeserializeInvalidType tests deserialization with invalid type byte +// TestDeserializeInvalidType tests deserialization with invalid type byte. func TestDeserializeInvalidType(t *testing.T) { - // Create invalid serialized data with unknown type byte + s := NewNodeStore() invalidData := []byte{99, 0, 0, 0} // Type byte 99 is invalid - _, err := DeserializeNode(invalidData, 0) + _, err := s.DeserializeNode(invalidData, 0) if err == nil { t.Fatal("Expected error for invalid type byte, got nil") } } -// TestDeserializeInvalidLength tests deserialization with invalid data length +// TestDeserializeInvalidLength tests deserialization with invalid data length. func TestDeserializeInvalidLength(t *testing.T) { - // InternalNode with type byte 1 but wrong length - invalidData := []byte{nodeTypeInternal, 0, 0} // Too short for internal node + s := NewNodeStore() + // InternalNode with valid type byte and group depth but too short for bitmap + invalidData := []byte{nodeTypeInternal, 8, 0, 0} // Too short for bitmap (needs 32 bytes) - _, err := DeserializeNode(invalidData, 0) + _, err := s.DeserializeNode(invalidData, 0) if err == nil { t.Fatal("Expected error for invalid data length, got nil") } @@ -186,7 +227,7 @@ func TestDeserializeInvalidLength(t *testing.T) { } } -// TestKeyToPath tests the keyToPath function +// TestKeyToPath tests the keyToPath function. 
func TestKeyToPath(t *testing.T) { tests := []struct { name string @@ -218,14 +259,14 @@ func TestKeyToPath(t *testing.T) { }, { name: "max valid depth", - depth: StemSize * 8, + depth: StemSize*8 - 1, key: make([]byte, HashSize), - expected: make([]byte, StemSize*8+1), + expected: make([]byte, StemSize*8), wantErr: false, }, { name: "depth too large", - depth: StemSize*8 + 1, + depth: StemSize * 8, key: make([]byte, HashSize), wantErr: true, }, diff --git a/trie/bintrie/empty.go b/trie/bintrie/empty.go deleted file mode 100644 index c47e284dac..0000000000 --- a/trie/bintrie/empty.go +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright 2025 go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . 
- -package bintrie - -import ( - "slices" - - "github.com/ethereum/go-ethereum/common" -) - -type Empty struct{} - -func (e Empty) Get(_ []byte, _ NodeResolverFn) ([]byte, error) { - return nil, nil -} - -func (e Empty) Insert(key []byte, value []byte, _ NodeResolverFn, depth int) (BinaryNode, error) { - var values [256][]byte - values[key[31]] = value - return &StemNode{ - Stem: slices.Clone(key[:31]), - Values: values[:], - depth: depth, - mustRecompute: true, - dirty: true, - }, nil -} - -func (e Empty) Copy() BinaryNode { - return Empty{} -} - -func (e Empty) Hash() common.Hash { - return common.Hash{} -} - -func (e Empty) GetValuesAtStem(_ []byte, _ NodeResolverFn) ([][]byte, error) { - var values [256][]byte - return values[:], nil -} - -func (e Empty) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolverFn, depth int) (BinaryNode, error) { - return &StemNode{ - Stem: slices.Clone(key[:31]), - Values: values, - depth: depth, - mustRecompute: true, - dirty: true, - }, nil -} - -func (e Empty) CollectNodes(_ []byte, _ NodeFlushFn) error { - return nil -} - -func (e Empty) toDot(parent string, path string) string { - return "" -} - -func (e Empty) GetHeight() int { - return 0 -} diff --git a/trie/bintrie/empty_test.go b/trie/bintrie/empty_test.go deleted file mode 100644 index 4da1ed15a0..0000000000 --- a/trie/bintrie/empty_test.go +++ /dev/null @@ -1,264 +0,0 @@ -// Copyright 2025 go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . - -package bintrie - -import ( - "bytes" - "testing" - - "github.com/ethereum/go-ethereum/common" -) - -// TestEmptyGet tests the Get method -func TestEmptyGet(t *testing.T) { - node := Empty{} - - key := make([]byte, 32) - value, err := node.Get(key, nil) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - if value != nil { - t.Errorf("Expected nil value from empty node, got %x", value) - } -} - -// TestEmptyInsert tests the Insert method -func TestEmptyInsert(t *testing.T) { - node := Empty{} - - key := make([]byte, 32) - key[0] = 0x12 - key[31] = 0x34 - value := common.HexToHash("0xabcd").Bytes() - - newNode, err := node.Insert(key, value, nil, 0) - if err != nil { - t.Fatalf("Failed to insert: %v", err) - } - - // Should create a StemNode - stemNode, ok := newNode.(*StemNode) - if !ok { - t.Fatalf("Expected StemNode, got %T", newNode) - } - - // Check the stem (first 31 bytes of key) - if !bytes.Equal(stemNode.Stem, key[:31]) { - t.Errorf("Stem mismatch: expected %x, got %x", key[:31], stemNode.Stem) - } - - // Check the value at the correct index (last byte of key) - if !bytes.Equal(stemNode.Values[key[31]], value) { - t.Errorf("Value mismatch at index %d: expected %x, got %x", key[31], value, stemNode.Values[key[31]]) - } - - // Check that other values are nil - for i := 0; i < 256; i++ { - if i != int(key[31]) && stemNode.Values[i] != nil { - t.Errorf("Expected nil value at index %d, got %x", i, stemNode.Values[i]) - } - } -} - -// TestEmptyCopy tests the Copy method -func TestEmptyCopy(t *testing.T) { - node := Empty{} - - copied := node.Copy() - copiedEmpty, ok := copied.(Empty) - if !ok { - t.Fatalf("Expected Empty, got %T", copied) - } - - // Both should be empty - if node != copiedEmpty { - // Empty is a zero-value struct, so copies should be 
equal - t.Errorf("Empty nodes should be equal") - } -} - -// TestEmptyHash tests the Hash method -func TestEmptyHash(t *testing.T) { - node := Empty{} - - hash := node.Hash() - - // Empty node should have zero hash - if hash != (common.Hash{}) { - t.Errorf("Expected zero hash for empty node, got %x", hash) - } -} - -// TestEmptyGetValuesAtStem tests the GetValuesAtStem method -func TestEmptyGetValuesAtStem(t *testing.T) { - node := Empty{} - - stem := make([]byte, 31) - values, err := node.GetValuesAtStem(stem, nil) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - // Should return an array of 256 nil values - if len(values) != 256 { - t.Errorf("Expected 256 values, got %d", len(values)) - } - - for i, v := range values { - if v != nil { - t.Errorf("Expected nil value at index %d, got %x", i, v) - } - } -} - -// TestEmptyInsertValuesAtStem tests the InsertValuesAtStem method -func TestEmptyInsertValuesAtStem(t *testing.T) { - node := Empty{} - - stem := make([]byte, 31) - stem[0] = 0x42 - - var values [256][]byte - values[0] = common.HexToHash("0x0101").Bytes() - values[10] = common.HexToHash("0x0202").Bytes() - values[255] = common.HexToHash("0x0303").Bytes() - - newNode, err := node.InsertValuesAtStem(stem, values[:], nil, 5) - if err != nil { - t.Fatalf("Failed to insert values: %v", err) - } - - // Should create a StemNode - stemNode, ok := newNode.(*StemNode) - if !ok { - t.Fatalf("Expected StemNode, got %T", newNode) - } - - // Check the stem - if !bytes.Equal(stemNode.Stem, stem) { - t.Errorf("Stem mismatch: expected %x, got %x", stem, stemNode.Stem) - } - - // Check the depth - if stemNode.depth != 5 { - t.Errorf("Depth mismatch: expected 5, got %d", stemNode.depth) - } - - // Check the values - if !bytes.Equal(stemNode.Values[0], values[0]) { - t.Error("Value at index 0 mismatch") - } - if !bytes.Equal(stemNode.Values[10], values[10]) { - t.Error("Value at index 10 mismatch") - } - if !bytes.Equal(stemNode.Values[255], values[255]) { - 
t.Error("Value at index 255 mismatch") - } - - // Check that values is the same slice (not a copy) - if &stemNode.Values[0] != &values[0] { - t.Error("Expected values to be the same slice reference") - } -} - -// TestEmptyCollectNodes tests the CollectNodes method -func TestEmptyCollectNodes(t *testing.T) { - node := Empty{} - - var collected []BinaryNode - flushFn := func(path []byte, n BinaryNode) { - collected = append(collected, n) - } - - err := node.CollectNodes([]byte{0, 1, 0}, flushFn) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - // Should not collect anything for empty node - if len(collected) != 0 { - t.Errorf("Expected no collected nodes for empty, got %d", len(collected)) - } -} - -// TestEmptyToDot tests the toDot method -func TestEmptyToDot(t *testing.T) { - node := Empty{} - - dot := node.toDot("parent", "010") - - // Should return empty string for empty node - if dot != "" { - t.Errorf("Expected empty string for empty node toDot, got %s", dot) - } -} - -// TestEmptyGetHeight tests the GetHeight method -func TestEmptyGetHeight(t *testing.T) { - node := Empty{} - - height := node.GetHeight() - - // Empty node should have height 0 - if height != 0 { - t.Errorf("Expected height 0 for empty node, got %d", height) - } -} - -// TestEmptyInsertMarksDirty verifies that a StemNode produced by Empty.Insert -// is marked dirty. Without this, CollectNodes would skip the freshly created -// stem and its blob would never reach disk, producing "missing trie node" -// errors on subsequent reads. 
-func TestEmptyInsertMarksDirty(t *testing.T) { - key := make([]byte, 32) - key[0] = 0xaa - val := make([]byte, 32) - val[0] = 0xbb - n, err := Empty{}.Insert(key, val, nil, 0) - if err != nil { - t.Fatalf("Insert: %v", err) - } - sn, ok := n.(*StemNode) - if !ok { - t.Fatalf("expected *StemNode, got %T", n) - } - if !sn.dirty { - t.Fatalf("stem produced by Empty.Insert must have dirty=true") - } -} - -// TestEmptyInsertValuesAtStemMarksDirty is the analogous guard for the -// bulk-insert entry point. Fresh stems created here must be dirty. -func TestEmptyInsertValuesAtStemMarksDirty(t *testing.T) { - key := make([]byte, 32) - key[0] = 0xcc - values := make([][]byte, 256) - values[0] = make([]byte, 32) - n, err := Empty{}.InsertValuesAtStem(key, values, nil, 3) - if err != nil { - t.Fatalf("InsertValuesAtStem: %v", err) - } - sn, ok := n.(*StemNode) - if !ok { - t.Fatalf("expected *StemNode, got %T", n) - } - if !sn.dirty { - t.Fatalf("stem produced by Empty.InsertValuesAtStem must have dirty=true") - } -} diff --git a/trie/bintrie/hashed_node.go b/trie/bintrie/hashed_node.go index e44c6d1e8a..9d349beaf8 100644 --- a/trie/bintrie/hashed_node.go +++ b/trie/bintrie/hashed_node.go @@ -16,75 +16,9 @@ package bintrie -import ( - "errors" - "fmt" +import "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/common" -) - -type HashedNode common.Hash - -func (h HashedNode) Get(_ []byte, _ NodeResolverFn) ([]byte, error) { - panic("not implemented") // TODO: Implement -} - -func (h HashedNode) Insert(key []byte, value []byte, resolver NodeResolverFn, depth int) (BinaryNode, error) { - return nil, errors.New("insert not implemented for hashed node") -} - -func (h HashedNode) Copy() BinaryNode { - nh := common.Hash(h) - return HashedNode(nh) -} - -func (h HashedNode) Hash() common.Hash { - return common.Hash(h) -} - -func (h HashedNode) GetValuesAtStem(_ []byte, _ NodeResolverFn) ([][]byte, error) { - return nil, errors.New("attempted to get values from 
an unresolved node") -} - -func (h HashedNode) InsertValuesAtStem(stem []byte, values [][]byte, resolver NodeResolverFn, depth int) (BinaryNode, error) { - // Step 1: Generate the path for this node's position in the tree - path, err := keyToPath(depth, stem) - if err != nil { - return nil, fmt.Errorf("InsertValuesAtStem path generation error: %w", err) - } - - if resolver == nil { - return nil, errors.New("InsertValuesAtStem resolve error: resolver is nil") - } - - // Step 2: Resolve the hashed node to get the actual node data - data, err := resolver(path, common.Hash(h)) - if err != nil { - return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err) - } - - // Step 3: Deserialize the resolved data into a concrete node - node, err := DeserializeNodeWithHash(data, depth, common.Hash(h)) - if err != nil { - return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err) - } - - // Step 4: Call InsertValuesAtStem on the resolved concrete node - return node.InsertValuesAtStem(stem, values, resolver, depth) -} - -func (h HashedNode) toDot(parent string, path string) string { - me := fmt.Sprintf("hash%s", path) - ret := fmt.Sprintf("%s [label=\"%x\"]\n", me, h) - ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me) - return ret -} - -func (h HashedNode) CollectNodes([]byte, NodeFlushFn) error { - // HashedNodes are already persisted in the database and don't need to be collected. - return nil -} - -func (h HashedNode) GetHeight() int { - panic("tried to get the height of a hashed node, this is a bug") +// HashedNode represents an unresolved node that only stores its hash. 
+type HashedNode struct { + hash common.Hash } diff --git a/trie/bintrie/hashed_node_test.go b/trie/bintrie/hashed_node_test.go index f9e6984888..b723f7f256 100644 --- a/trie/bintrie/hashed_node_test.go +++ b/trie/bintrie/hashed_node_test.go @@ -18,180 +18,137 @@ package bintrie import ( "bytes" + "errors" "testing" "github.com/ethereum/go-ethereum/common" ) -// TestHashedNodeHash tests the Hash method +// TestHashedNodeHash tests the Hash method via NodeStore. func TestHashedNodeHash(t *testing.T) { hash := common.HexToHash("0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef") - node := HashedNode(hash) + s := NewNodeStore() + ref := s.newHashedRef(hash) - // Hash should return the stored hash - if node.Hash() != hash { - t.Errorf("Hash mismatch: expected %x, got %x", hash, node.Hash()) + if s.ComputeHash(ref) != hash { + t.Errorf("Hash mismatch: expected %x, got %x", hash, s.ComputeHash(ref)) } } -// TestHashedNodeCopy tests the Copy method +// TestHashedNodeCopy tests the Copy method via NodeStore. 
func TestHashedNodeCopy(t *testing.T) { hash := common.HexToHash("0xabcdef") - node := HashedNode(hash) + s := NewNodeStore() + ref := s.newHashedRef(hash) + s.SetRoot(ref) - copied := node.Copy() - copiedHash, ok := copied.(HashedNode) - if !ok { - t.Fatalf("Expected HashedNode, got %T", copied) - } + ns := s.Copy() + copiedHash := ns.ComputeHash(ns.Root()) - // Hash should be the same - if common.Hash(copiedHash) != hash { + if copiedHash != hash { t.Errorf("Hash mismatch after copy: expected %x, got %x", hash, copiedHash) } - - // But should be a different object - if &node == &copiedHash { - t.Error("Copy returned same object reference") - } } -// TestHashedNodeInsert tests that Insert returns an error -func TestHashedNodeInsert(t *testing.T) { - node := HashedNode(common.HexToHash("0x1234")) - - key := make([]byte, HashSize) - value := make([]byte, HashSize) - - _, err := node.Insert(key, value, nil, 0) - if err == nil { - t.Fatal("Expected error for Insert on HashedNode") - } - - if err.Error() != "insert not implemented for hashed node" { - t.Errorf("Unexpected error message: %v", err) - } -} - -// TestHashedNodeGetValuesAtStem tests that GetValuesAtStem returns an error -func TestHashedNodeGetValuesAtStem(t *testing.T) { - node := HashedNode(common.HexToHash("0x1234")) - - stem := make([]byte, StemSize) - _, err := node.GetValuesAtStem(stem, nil) - if err == nil { - t.Fatal("Expected error for GetValuesAtStem on HashedNode") - } - - if err.Error() != "attempted to get values from an unresolved node" { - t.Errorf("Unexpected error message: %v", err) - } -} - -// TestHashedNodeInsertValuesAtStem tests that InsertValuesAtStem returns an error +// TestHashedNodeInsertValuesAtStem tests InsertValuesAtStem resolution via NodeStore. 
func TestHashedNodeInsertValuesAtStem(t *testing.T) { - node := HashedNode(common.HexToHash("0x1234")) + // Test 1: nil resolver should return an error + s := NewNodeStore() + hashedRef := s.newHashedRef(common.HexToHash("0x1234")) + s.SetRoot(hashedRef) stem := make([]byte, StemSize) values := make([][]byte, StemNodeWidth) - // Test 1: nil resolver should return an error - _, err := node.InsertValuesAtStem(stem, values, nil, 0) + err := s.InsertValuesAtStem(stem, values, nil) if err == nil { - t.Fatal("Expected error for InsertValuesAtStem on HashedNode with nil resolver") - } - - if err.Error() != "InsertValuesAtStem resolve error: resolver is nil" { - t.Errorf("Unexpected error message: %v", err) + t.Fatal("Expected error for InsertValuesAtStem with nil resolver") } // Test 2: mock resolver returning invalid data should return deserialization error mockResolver := func(path []byte, hash common.Hash) ([]byte, error) { - // Return invalid/nonsense data that cannot be deserialized return []byte{0xff, 0xff, 0xff}, nil } - _, err = node.InsertValuesAtStem(stem, values, mockResolver, 0) - if err == nil { - t.Fatal("Expected error for InsertValuesAtStem on HashedNode with invalid resolver data") - } + s2 := NewNodeStore() + hashedRef2 := s2.newHashedRef(common.HexToHash("0x1234")) + s2.SetRoot(hashedRef2) - expectedPrefix := "InsertValuesAtStem node deserialization error:" - if len(err.Error()) < len(expectedPrefix) || err.Error()[:len(expectedPrefix)] != expectedPrefix { - t.Errorf("Expected deserialization error, got: %v", err) + err = s2.InsertValuesAtStem(stem, values, mockResolver) + if err == nil { + t.Fatal("Expected error for InsertValuesAtStem with invalid resolver data") } // Test 3: mock resolver returning valid serialized node should succeed stem = make([]byte, StemSize) stem[0] = 0xaa - var originalValues [StemNodeWidth][]byte + originalValues := make([][]byte, StemNodeWidth) originalValues[0] = 
common.HexToHash("0x1111111111111111111111111111111111111111111111111111111111111111").Bytes() originalValues[1] = common.HexToHash("0x2222222222222222222222222222222222222222222222222222222222222222").Bytes() - originalNode := &StemNode{ - Stem: stem, - Values: originalValues[:], - depth: 0, + // Build the serialized node + rs := NewNodeStore() + ref := rs.newStemRef(stem, 0) + sn := rs.getStem(ref.Index()) + for i, v := range originalValues { + if v != nil { + sn.setValue(byte(i), v) + } } + serialized := rs.SerializeNode(ref, MaxGroupDepth) - // Serialize the node - serialized := SerializeNode(originalNode) - - // Create a mock resolver that returns the serialized node validResolver := func(path []byte, hash common.Hash) ([]byte, error) { return serialized, nil } - var newValues [StemNodeWidth][]byte + s3 := NewNodeStore() + hashedRef3 := s3.newHashedRef(common.HexToHash("0x1234")) + s3.SetRoot(hashedRef3) + + newValues := make([][]byte, StemNodeWidth) newValues[2] = common.HexToHash("0x3333333333333333333333333333333333333333333333333333333333333333").Bytes() - resolvedNode, err := node.InsertValuesAtStem(stem, newValues[:], validResolver, 0) + err = s3.InsertValuesAtStem(stem, newValues, validResolver) if err != nil { t.Fatalf("Expected successful resolution and insertion, got error: %v", err) } - resultStem, ok := resolvedNode.(*StemNode) - if !ok { - t.Fatalf("Expected resolved node to be *StemNode, got %T", resolvedNode) + // Verify original values are preserved + retrieved, err := s3.GetValuesAtStem(stem, nil) + if err != nil { + t.Fatal(err) } - - if !bytes.Equal(resultStem.Stem, stem) { - t.Errorf("Stem mismatch: expected %x, got %x", stem, resultStem.Stem) + if !bytes.Equal(retrieved[0], originalValues[0]) { + t.Errorf("Original value at index 0 not preserved") } - - // Verify the original values are preserved - if !bytes.Equal(resultStem.Values[0], originalValues[0]) { - t.Errorf("Original value at index 0 not preserved: expected %x, got %x", 
originalValues[0], resultStem.Values[0]) + if !bytes.Equal(retrieved[1], originalValues[1]) { + t.Errorf("Original value at index 1 not preserved") } - if !bytes.Equal(resultStem.Values[1], originalValues[1]) { - t.Errorf("Original value at index 1 not preserved: expected %x, got %x", originalValues[1], resultStem.Values[1]) - } - - // Verify the new value was inserted - if !bytes.Equal(resultStem.Values[2], newValues[2]) { - t.Errorf("New value at index 2 not inserted correctly: expected %x, got %x", newValues[2], resultStem.Values[2]) + if !bytes.Equal(retrieved[2], newValues[2]) { + t.Errorf("New value at index 2 not inserted correctly") } } -// TestHashedNodeToDot tests the toDot method for visualization -func TestHashedNodeToDot(t *testing.T) { - hash := common.HexToHash("0x1234") - node := HashedNode(hash) +// TestHashedNodeGetError tests that Get returns an error when an unresolved HashedNode child cannot be resolved. +func TestHashedNodeGetError(t *testing.T) { + s := NewNodeStore() + // Build an InternalNode root whose left child is an unresolved HashedNode + rootRef := s.newInternalRef(0) + rootNode := s.getInternal(rootRef.Index()) + hashedLeft := s.newHashedRef(common.HexToHash("0x1234")) + rootNode.left = hashedLeft + rootNode.right = EmptyRef + s.SetRoot(rootRef) - dot := node.toDot("parent", "010") + key := make([]byte, 32) // goes left + key[31] = 5 - // Should contain the hash value and parent connection - expectedHash := "hash010" - if !contains(dot, expectedHash) { - t.Errorf("Expected dot output to contain %s", expectedHash) + resolver := func(path []byte, hash common.Hash) ([]byte, error) { + return nil, errors.New("node not found") } - if !contains(dot, "parent -> hash010") { - t.Error("Expected dot output to contain parent connection") + _, err := s.Get(key, resolver) + if err == nil { + t.Fatal("Expected error when resolver fails") } } - -// Helper function -func contains(s, substr string) bool { - return len(s) >= len(substr) && s != "" && len(substr) > 0 
-} diff --git a/trie/bintrie/hasher.go b/trie/bintrie/hasher.go index b81c145723..d56cba96c5 100644 --- a/trie/bintrie/hasher.go +++ b/trie/bintrie/hasher.go @@ -37,3 +37,11 @@ func newSha256() hash.Hash { func returnSha256(h hash.Hash) { sha256Pool.Put(h) } + +// sha256Sum256 computes a sha256 digest and returns it as a 32-byte array. +func sha256Sum256(data []byte) [32]byte { + return sha256.Sum256(data) +} + +// parallelHashDepth controls below which depth hashing is parallelised. +const parallelHashDepth = 4 diff --git a/trie/bintrie/internal_node.go b/trie/bintrie/internal_node.go index 811f65bcd8..51beecefd5 100644 --- a/trie/bintrie/internal_node.go +++ b/trie/bintrie/internal_node.go @@ -17,35 +17,13 @@ package bintrie import ( - "crypto/sha256" "errors" - "fmt" - "math/bits" - "runtime" - "sync" "github.com/ethereum/go-ethereum/common" ) -// parallelDepth returns the tree depth below which Hash() spawns goroutines. -func parallelDepth() int { - return min(bits.Len(uint(runtime.NumCPU())), 8) -} - -// isDirty reports whether a BinaryNode child needs rehashing. -func isDirty(n BinaryNode) bool { - switch v := n.(type) { - case *InternalNode: - return v.mustRecompute - case *StemNode: - return v.mustRecompute - default: - return false - } -} - func keyToPath(depth int, key []byte) ([]byte, error) { - if depth > 31*8 { + if depth >= 31*8 { return nil, errors.New("node too deep") } path := make([]byte, 0, depth+1) @@ -56,252 +34,20 @@ func keyToPath(depth int, key []byte) ([]byte, error) { return path, nil } +// makeKeyPath is a simplified version of keyToPath that doesn't return an error. +func makeKeyPath(depth int, key []byte) []byte { + path := make([]byte, 0, depth+1) + for i := range depth + 1 { + bit := key[i/8] >> (7 - (i % 8)) & 1 + path = append(path, bit) + } + return path +} + // InternalNode is a binary trie internal node. 
type InternalNode struct { - left, right BinaryNode - depth int - + left, right NodeRef + depth uint8 mustRecompute bool // true if the hash needs to be recomputed - dirty bool // true if the node's on-disk blob is stale (needs flush) hash common.Hash // cached hash when mustRecompute == false } - -// GetValuesAtStem retrieves the group of values located at the given stem key. -func (bt *InternalNode) GetValuesAtStem(stem []byte, resolver NodeResolverFn) ([][]byte, error) { - if bt.depth > 31*8 { - return nil, errors.New("node too deep") - } - - bit := stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1 - if bit == 0 { - if hn, ok := bt.left.(HashedNode); ok { - path, err := keyToPath(bt.depth, stem) - if err != nil { - return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err) - } - data, err := resolver(path, common.Hash(hn)) - if err != nil { - return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err) - } - node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn)) - if err != nil { - return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err) - } - bt.left = node - } - return bt.left.GetValuesAtStem(stem, resolver) - } - - if hn, ok := bt.right.(HashedNode); ok { - path, err := keyToPath(bt.depth, stem) - if err != nil { - return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err) - } - data, err := resolver(path, common.Hash(hn)) - if err != nil { - return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err) - } - node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn)) - if err != nil { - return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err) - } - bt.right = node - } - return bt.right.GetValuesAtStem(stem, resolver) -} - -// Get retrieves the value for the given key. 
-func (bt *InternalNode) Get(key []byte, resolver NodeResolverFn) ([]byte, error) { - values, err := bt.GetValuesAtStem(key[:31], resolver) - if err != nil { - return nil, fmt.Errorf("get error: %w", err) - } - if values == nil { - return nil, nil - } - return values[key[31]], nil -} - -// Insert inserts a new key-value pair into the trie. -func (bt *InternalNode) Insert(key []byte, value []byte, resolver NodeResolverFn, depth int) (BinaryNode, error) { - var values [256][]byte - values[key[31]] = value - return bt.InsertValuesAtStem(key[:31], values[:], resolver, depth) -} - -// Copy creates a deep copy of the node. -func (bt *InternalNode) Copy() BinaryNode { - return &InternalNode{ - left: bt.left.Copy(), - right: bt.right.Copy(), - depth: bt.depth, - mustRecompute: bt.mustRecompute, - dirty: bt.dirty, - hash: bt.hash, - } -} - -// Hash returns the hash of the node. -func (bt *InternalNode) Hash() common.Hash { - if !bt.mustRecompute { - return bt.hash - } - - // At shallow depths, parallelize when both children need rehashing: - // hash left subtree in a goroutine, right subtree inline, then combine. - // Skip goroutine overhead when only one child is dirty (common case - // for narrow state updates that touch a single path through the trie). 
- if bt.depth < parallelDepth() && isDirty(bt.left) && isDirty(bt.right) { - var input [64]byte - var lh common.Hash - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - lh = bt.left.Hash() - }() - rh := bt.right.Hash() - copy(input[32:], rh[:]) - wg.Wait() - copy(input[:32], lh[:]) - bt.hash = sha256.Sum256(input[:]) - bt.mustRecompute = false - return bt.hash - } - - // Deeper nodes: sequential using pooled hasher (goroutine overhead > hash cost) - h := newSha256() - defer returnSha256(h) - if bt.left != nil { - h.Write(bt.left.Hash().Bytes()) - } else { - h.Write(zero[:]) - } - if bt.right != nil { - h.Write(bt.right.Hash().Bytes()) - } else { - h.Write(zero[:]) - } - bt.hash = common.BytesToHash(h.Sum(nil)) - bt.mustRecompute = false - return bt.hash -} - -// InsertValuesAtStem inserts a full value group at the given stem in the internal node. -// Already-existing values will be overwritten. -func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolver NodeResolverFn, depth int) (BinaryNode, error) { - var err error - bit := stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1 - if bit == 0 { - if bt.left == nil { - bt.left = Empty{} - } - - if hn, ok := bt.left.(HashedNode); ok { - path, err := keyToPath(bt.depth, stem) - if err != nil { - return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err) - } - data, err := resolver(path, common.Hash(hn)) - if err != nil { - return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err) - } - node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn)) - if err != nil { - return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err) - } - bt.left = node - } - - bt.left, err = bt.left.InsertValuesAtStem(stem, values, resolver, depth+1) - bt.mustRecompute = true - bt.dirty = true - return bt, err - } - - if bt.right == nil { - bt.right = Empty{} - } - - if hn, ok := bt.right.(HashedNode); ok { - path, err := keyToPath(bt.depth, stem) - if err != 
nil { - return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err) - } - data, err := resolver(path, common.Hash(hn)) - if err != nil { - return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err) - } - node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn)) - if err != nil { - return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err) - } - bt.right = node - } - - bt.right, err = bt.right.InsertValuesAtStem(stem, values, resolver, depth+1) - bt.mustRecompute = true - bt.dirty = true - return bt, err -} - -// CollectNodes collects all child nodes at a given path, and flushes it -// into the provided node collector. Clean subtrees (dirty == false) are -// skipped. -func (bt *InternalNode) CollectNodes(path []byte, flushfn NodeFlushFn) error { - if !bt.dirty { - return nil - } - if bt.left != nil { - var p [256]byte - copy(p[:], path) - childpath := p[:len(path)] - childpath = append(childpath, 0) - if err := bt.left.CollectNodes(childpath, flushfn); err != nil { - return err - } - } - if bt.right != nil { - var p [256]byte - copy(p[:], path) - childpath := p[:len(path)] - childpath = append(childpath, 1) - if err := bt.right.CollectNodes(childpath, flushfn); err != nil { - return err - } - } - flushfn(path, bt) - bt.dirty = false - return nil -} - -// GetHeight returns the height of the node. 
-func (bt *InternalNode) GetHeight() int { - var ( - leftHeight int - rightHeight int - ) - if bt.left != nil { - leftHeight = bt.left.GetHeight() - } - if bt.right != nil { - rightHeight = bt.right.GetHeight() - } - return 1 + max(leftHeight, rightHeight) -} - -func (bt *InternalNode) toDot(parent, path string) string { - me := fmt.Sprintf("internal%s", path) - ret := fmt.Sprintf("%s [label=\"I: %x\"]\n", me, bt.Hash()) - if len(parent) > 0 { - ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me) - } - - if bt.left != nil { - ret = fmt.Sprintf("%s%s", ret, bt.left.toDot(me, fmt.Sprintf("%s%02x", path, 0))) - } - if bt.right != nil { - ret = fmt.Sprintf("%s%s", ret, bt.right.toDot(me, fmt.Sprintf("%s%02x", path, 1))) - } - return ret -} diff --git a/trie/bintrie/internal_node_test.go b/trie/bintrie/internal_node_test.go index ddcec8085d..7ae94b1b8c 100644 --- a/trie/bintrie/internal_node_test.go +++ b/trie/bintrie/internal_node_test.go @@ -24,35 +24,33 @@ import ( "github.com/ethereum/go-ethereum/common" ) -// TestInternalNodeGet tests the Get method +// TestInternalNodeGet tests the Get method via NodeStore. 
func TestInternalNodeGet(t *testing.T) { - // Create a simple tree structure + s := NewNodeStore() + leftStem := make([]byte, 31) rightStem := make([]byte, 31) - rightStem[0] = 0x80 // First bit is 1 + rightStem[0] = 0x80 - var leftValues, rightValues [256][]byte + leftValues := make([][]byte, 256) leftValues[0] = common.HexToHash("0x0101").Bytes() + rightValues := make([][]byte, 256) rightValues[0] = common.HexToHash("0x0202").Bytes() - node := &InternalNode{ - depth: 0, - left: &StemNode{ - Stem: leftStem, - Values: leftValues[:], - depth: 1, - }, - right: &StemNode{ - Stem: rightStem, - Values: rightValues[:], - depth: 1, - }, + // Build tree: root -> left stem, right stem + // Insert left stem values + s.SetRoot(EmptyRef) + if err := s.InsertValuesAtStem(leftStem, leftValues, nil); err != nil { + t.Fatal(err) + } + if err := s.InsertValuesAtStem(rightStem, rightValues, nil); err != nil { + t.Fatal(err) } // Get value from left subtree leftKey := make([]byte, 32) leftKey[31] = 0 - value, err := node.Get(leftKey, nil) + value, err := s.Get(leftKey, nil) if err != nil { t.Fatalf("Failed to get left value: %v", err) } @@ -64,7 +62,7 @@ func TestInternalNodeGet(t *testing.T) { rightKey := make([]byte, 32) rightKey[0] = 0x80 rightKey[31] = 0 - value, err = node.Get(rightKey, nil) + value, err = s.Get(rightKey, nil) if err != nil { t.Fatalf("Failed to get right value: %v", err) } @@ -73,29 +71,26 @@ func TestInternalNodeGet(t *testing.T) { } } -// TestInternalNodeGetWithResolver tests Get with HashedNode resolution +// TestInternalNodeGetWithResolver tests Get with HashedNode resolution via NodeStore. 
func TestInternalNodeGetWithResolver(t *testing.T) { - // Create an internal node with a hashed child - hashedChild := HashedNode(common.HexToHash("0x1234")) - - node := &InternalNode{ - depth: 0, - left: hashedChild, - right: Empty{}, - } + // Create a store with an internal node containing a hashed child + s := NewNodeStore() + hashedChild := s.newHashedRef(common.HexToHash("0x1234")) + rootRef := s.newInternalRef(0) + rootNode := s.getInternal(rootRef.Index()) + rootNode.left = hashedChild + rootNode.right = EmptyRef + s.SetRoot(rootRef) // Mock resolver that returns a stem node resolver := func(path []byte, hash common.Hash) ([]byte, error) { - if hash == common.Hash(hashedChild) { + if hash == common.HexToHash("0x1234") { + rs := NewNodeStore() stem := make([]byte, 31) - var values [256][]byte - values[5] = common.HexToHash("0xabcd").Bytes() - stemNode := &StemNode{ - Stem: stem, - Values: values[:], - depth: 1, - } - return SerializeNode(stemNode), nil + ref := rs.newStemRef(stem, 1) + sn := rs.getStem(ref.Index()) + sn.setValue(5, common.HexToHash("0xabcd").Bytes()) + return rs.SerializeNode(ref, MaxGroupDepth), nil } return nil, errors.New("node not found") } @@ -103,7 +98,7 @@ func TestInternalNodeGetWithResolver(t *testing.T) { // Get value through the hashed node key := make([]byte, 32) key[31] = 5 - value, err := node.Get(key, resolver) + value, err := s.Get(key, resolver) if err != nil { t.Fatalf("Failed to get value: %v", err) } @@ -114,179 +109,113 @@ func TestInternalNodeGetWithResolver(t *testing.T) { } } -// TestInternalNodeInsert tests the Insert method +// TestInternalNodeInsert tests the Insert method via NodeStore. 
func TestInternalNodeInsert(t *testing.T) { - // Start with an internal node with empty children - node := &InternalNode{ - depth: 0, - left: Empty{}, - right: Empty{}, - } + s := NewNodeStore() - // Insert a value into the left subtree leftKey := make([]byte, 32) leftKey[31] = 10 leftValue := common.HexToHash("0x0101").Bytes() - newNode, err := node.Insert(leftKey, leftValue, nil, 0) - if err != nil { + if err := s.Insert(leftKey, leftValue, nil); err != nil { t.Fatalf("Failed to insert: %v", err) } - internalNode, ok := newNode.(*InternalNode) - if !ok { - t.Fatalf("Expected InternalNode, got %T", newNode) + // Verify the value was stored + value, err := s.Get(leftKey, nil) + if err != nil { + t.Fatalf("Failed to get: %v", err) } - - // Check that left child is now a StemNode - leftStem, ok := internalNode.left.(*StemNode) - if !ok { - t.Fatalf("Expected left child to be StemNode, got %T", internalNode.left) - } - - // Check the inserted value - if !bytes.Equal(leftStem.Values[10], leftValue) { - t.Errorf("Value mismatch: expected %x, got %x", leftValue, leftStem.Values[10]) - } - - // Right child should still be Empty - _, ok = internalNode.right.(Empty) - if !ok { - t.Errorf("Expected right child to remain Empty, got %T", internalNode.right) + if !bytes.Equal(value, leftValue) { + t.Errorf("Value mismatch: expected %x, got %x", leftValue, value) } } -// TestInternalNodeCopy tests the Copy method +// TestInternalNodeCopy tests the Copy method via NodeStore. 
func TestInternalNodeCopy(t *testing.T) { - // Create an internal node with stem children - leftStem := &StemNode{ - Stem: make([]byte, 31), - Values: make([][]byte, 256), - depth: 1, - } - leftStem.Values[0] = common.HexToHash("0x0101").Bytes() + s := NewNodeStore() - rightStem := &StemNode{ - Stem: make([]byte, 31), - Values: make([][]byte, 256), - depth: 1, - } - rightStem.Stem[0] = 0x80 - rightStem.Values[0] = common.HexToHash("0x0202").Bytes() + leftKey := make([]byte, 32) + leftKey[31] = 0 + leftValue := common.HexToHash("0x0101").Bytes() - node := &InternalNode{ - depth: 0, - left: leftStem, - right: rightStem, + rightKey := make([]byte, 32) + rightKey[0] = 0x80 + rightKey[31] = 0 + rightValue := common.HexToHash("0x0202").Bytes() + + if err := s.Insert(leftKey, leftValue, nil); err != nil { + t.Fatal(err) + } + if err := s.Insert(rightKey, rightValue, nil); err != nil { + t.Fatal(err) } - // Create a copy - copied := node.Copy() - copiedInternal, ok := copied.(*InternalNode) - if !ok { - t.Fatalf("Expected InternalNode, got %T", copied) - } + ns := s.Copy() - // Check depth - if copiedInternal.depth != node.depth { - t.Errorf("Depth mismatch: expected %d, got %d", node.depth, copiedInternal.depth) - } - - // Check that children are copied - copiedLeft, ok := copiedInternal.left.(*StemNode) - if !ok { - t.Fatalf("Expected left child to be StemNode, got %T", copiedInternal.left) - } - - copiedRight, ok := copiedInternal.right.(*StemNode) - if !ok { - t.Fatalf("Expected right child to be StemNode, got %T", copiedInternal.right) - } - - // Verify deep copy (children should be different objects) - if copiedLeft == leftStem { - t.Error("Left child not properly copied") - } - if copiedRight == rightStem { - t.Error("Right child not properly copied") - } - - // But values should be equal - if !bytes.Equal(copiedLeft.Values[0], leftStem.Values[0]) { + // Values should be equal + v1, _ := ns.Get(leftKey, nil) + if !bytes.Equal(v1, leftValue) { t.Error("Left child 
value mismatch after copy") } - if !bytes.Equal(copiedRight.Values[0], rightStem.Values[0]) { + v2, _ := ns.Get(rightKey, nil) + if !bytes.Equal(v2, rightValue) { t.Error("Right child value mismatch after copy") } } -// TestInternalNodeHash tests the Hash method +// TestInternalNodeHash tests the Hash method via NodeStore. func TestInternalNodeHash(t *testing.T) { - // Create an internal node - node := &InternalNode{ - depth: 0, - left: HashedNode(common.HexToHash("0x1111")), - right: HashedNode(common.HexToHash("0x2222")), - } + s := NewNodeStore() + leftRef := s.newHashedRef(common.HexToHash("0x1111")) + rightRef := s.newHashedRef(common.HexToHash("0x2222")) + rootRef := s.newInternalRef(0) + rootNode := s.getInternal(rootRef.Index()) + rootNode.left = leftRef + rootNode.right = rightRef + s.SetRoot(rootRef) - hash1 := node.Hash() + hash1 := s.ComputeHash(rootRef) // Hash should be deterministic - hash2 := node.Hash() + hash2 := s.ComputeHash(rootRef) if hash1 != hash2 { t.Errorf("Hash not deterministic: %x != %x", hash1, hash2) } // Changing a child should change the hash - node.left = HashedNode(common.HexToHash("0x3333")) - node.mustRecompute = true - hash3 := node.Hash() + rootNode.left = s.newHashedRef(common.HexToHash("0x3333")) + rootNode.mustRecompute = true + hash3 := s.ComputeHash(rootRef) if hash1 == hash3 { t.Error("Hash didn't change after modifying left child") } - - // Test with nil children (should use zero hash) - nodeWithNil := &InternalNode{ - depth: 0, - left: nil, - right: HashedNode(common.HexToHash("0x4444")), - mustRecompute: true, - } - hashWithNil := nodeWithNil.Hash() - if hashWithNil == (common.Hash{}) { - t.Error("Hash shouldn't be zero even with nil child") - } } -// TestInternalNodeGetValuesAtStem tests GetValuesAtStem method +// TestInternalNodeGetValuesAtStem tests GetValuesAtStem method via NodeStore. 
func TestInternalNodeGetValuesAtStem(t *testing.T) { - // Create a tree with values at different stems + s := NewNodeStore() + leftStem := make([]byte, 31) rightStem := make([]byte, 31) rightStem[0] = 0x80 - var leftValues, rightValues [256][]byte + leftValues := make([][]byte, 256) leftValues[0] = common.HexToHash("0x0101").Bytes() leftValues[10] = common.HexToHash("0x0102").Bytes() + rightValues := make([][]byte, 256) rightValues[0] = common.HexToHash("0x0201").Bytes() rightValues[20] = common.HexToHash("0x0202").Bytes() - node := &InternalNode{ - depth: 0, - left: &StemNode{ - Stem: leftStem, - Values: leftValues[:], - depth: 1, - }, - right: &StemNode{ - Stem: rightStem, - Values: rightValues[:], - depth: 1, - }, + if err := s.InsertValuesAtStem(leftStem, leftValues, nil); err != nil { + t.Fatal(err) + } + if err := s.InsertValuesAtStem(rightStem, rightValues, nil); err != nil { + t.Fatal(err) } // Get values from left stem - values, err := node.GetValuesAtStem(leftStem, nil) + values, err := s.GetValuesAtStem(leftStem, nil) if err != nil { t.Fatalf("Failed to get left values: %v", err) } @@ -298,7 +227,7 @@ func TestInternalNodeGetValuesAtStem(t *testing.T) { } // Get values from right stem - values, err = node.GetValuesAtStem(rightStem, nil) + values, err = s.GetValuesAtStem(rightStem, nil) if err != nil { t.Fatalf("Failed to get right values: %v", err) } @@ -310,201 +239,103 @@ func TestInternalNodeGetValuesAtStem(t *testing.T) { } } -// TestInternalNodeInsertValuesAtStem tests InsertValuesAtStem method +// TestInternalNodeInsertValuesAtStem tests InsertValuesAtStem method via NodeStore. 
func TestInternalNodeInsertValuesAtStem(t *testing.T) { - // Start with an internal node with empty children - node := &InternalNode{ - depth: 0, - left: Empty{}, - right: Empty{}, - } + s := NewNodeStore() - // Insert values at a stem in the left subtree stem := make([]byte, 31) - var values [256][]byte + values := make([][]byte, 256) values[5] = common.HexToHash("0x0505").Bytes() values[10] = common.HexToHash("0x1010").Bytes() - newNode, err := node.InsertValuesAtStem(stem, values[:], nil, 0) - if err != nil { + if err := s.InsertValuesAtStem(stem, values, nil); err != nil { t.Fatalf("Failed to insert values: %v", err) } - internalNode, ok := newNode.(*InternalNode) - if !ok { - t.Fatalf("Expected InternalNode, got %T", newNode) + // Check that the values are stored + retrieved, err := s.GetValuesAtStem(stem, nil) + if err != nil { + t.Fatalf("Failed to get values: %v", err) } - - // Check that left child is now a StemNode with the values - leftStem, ok := internalNode.left.(*StemNode) - if !ok { - t.Fatalf("Expected left child to be StemNode, got %T", internalNode.left) - } - - if !bytes.Equal(leftStem.Values[5], values[5]) { + if !bytes.Equal(retrieved[5], values[5]) { t.Error("Value at index 5 mismatch") } - if !bytes.Equal(leftStem.Values[10], values[10]) { + if !bytes.Equal(retrieved[10], values[10]) { t.Error("Value at index 10 mismatch") } } -// TestInternalNodeCollectNodes tests CollectNodes method +// TestInternalNodeCollectNodes tests CollectNodes method via NodeStore. func TestInternalNodeCollectNodes(t *testing.T) { - // Create an internal node with two stem children. All three are - // marked dirty to mirror production semantics — see CollectNodes. 
- leftStem := &StemNode{ - Stem: make([]byte, 31), - Values: make([][]byte, 256), - depth: 1, - dirty: true, - } + s := NewNodeStore() - rightStem := &StemNode{ - Stem: make([]byte, 31), - Values: make([][]byte, 256), - depth: 1, - dirty: true, - } - rightStem.Stem[0] = 0x80 + leftStem := make([]byte, 31) + rightStem := make([]byte, 31) + rightStem[0] = 0x80 - node := &InternalNode{ - depth: 0, - left: leftStem, - right: rightStem, - dirty: true, + leftValues := make([][]byte, 256) + rightValues := make([][]byte, 256) + + if err := s.InsertValuesAtStem(leftStem, leftValues, nil); err != nil { + t.Fatal(err) + } + if err := s.InsertValuesAtStem(rightStem, rightValues, nil); err != nil { + t.Fatal(err) } var collectedPaths [][]byte - var collectedNodes []BinaryNode - - flushFn := func(path []byte, n BinaryNode) { + flushFn := func(path []byte, hash common.Hash, serialized []byte) { pathCopy := make([]byte, len(path)) copy(pathCopy, path) collectedPaths = append(collectedPaths, pathCopy) - collectedNodes = append(collectedNodes, n) } - err := node.CollectNodes([]byte{1}, flushFn) + err := s.CollectNodes(s.Root(), []byte{1}, flushFn, MaxGroupDepth) if err != nil { t.Fatalf("Failed to collect nodes: %v", err) } // Should have collected 3 nodes: left stem, right stem, and the internal node itself - if len(collectedNodes) != 3 { - t.Errorf("Expected 3 collected nodes, got %d", len(collectedNodes)) - } - - // Check paths - expectedPaths := [][]byte{ - {1, 0}, // left child - {1, 1}, // right child - {1}, // internal node itself - } - - for i, expectedPath := range expectedPaths { - if !bytes.Equal(collectedPaths[i], expectedPath) { - t.Errorf("Path %d mismatch: expected %v, got %v", i, expectedPath, collectedPaths[i]) - } + if len(collectedPaths) != 3 { + t.Errorf("Expected 3 collected nodes, got %d", len(collectedPaths)) } } -// TestInternalNodeCollectNodesSkipsClean verifies clean subtrees are not -// flushed. 
A dirty internal node over clean children only flushes itself; -// a fully clean tree flushes nothing. -func TestInternalNodeCollectNodesSkipsClean(t *testing.T) { - leftStem := &StemNode{ - Stem: make([]byte, 31), - Values: make([][]byte, 256), - depth: 1, - } - rightStem := &StemNode{ - Stem: make([]byte, 31), - Values: make([][]byte, 256), - depth: 1, - } - rightStem.Stem[0] = 0x80 - - dirtyParent := &InternalNode{ - depth: 0, - left: leftStem, - right: rightStem, - dirty: true, - } - - var collected []BinaryNode - flushFn := func(_ []byte, n BinaryNode) { collected = append(collected, n) } - - if err := dirtyParent.CollectNodes([]byte{1}, flushFn); err != nil { - t.Fatalf("CollectNodes: %v", err) - } - if len(collected) != 1 || collected[0] != dirtyParent { - t.Fatalf("expected only the dirty parent to be flushed, got %d nodes", len(collected)) - } - if dirtyParent.dirty { - t.Errorf("parent dirty flag should be cleared after flush") - } - - // Second call on the same tree should be a no-op: everything is clean. - collected = nil - if err := dirtyParent.CollectNodes([]byte{1}, flushFn); err != nil { - t.Fatalf("CollectNodes (second call): %v", err) - } - if len(collected) != 0 { - t.Errorf("expected no nodes to be flushed on clean tree, got %d", len(collected)) - } -} - -// TestInternalNodeGetHeight tests GetHeight method +// TestInternalNodeGetHeight tests GetHeight method via NodeStore. func TestInternalNodeGetHeight(t *testing.T) { - // Create a tree with different heights - // Left subtree: depth 2 (internal -> stem) - // Right subtree: depth 1 (stem) - leftInternal := &InternalNode{ - depth: 1, - left: &StemNode{ - Stem: make([]byte, 31), - Values: make([][]byte, 256), - depth: 2, - }, - right: Empty{}, + s := NewNodeStore() + + // Insert values that create a deeper tree + stem1 := make([]byte, 31) // left + stem2 := make([]byte, 31) + stem2[0] = 0x40 // 01... 
-> goes left at depth 0, right at depth 1 + + values1 := make([][]byte, 256) + values1[0] = common.HexToHash("0x01").Bytes() + values2 := make([][]byte, 256) + values2[0] = common.HexToHash("0x02").Bytes() + + if err := s.InsertValuesAtStem(stem1, values1, nil); err != nil { + t.Fatal(err) + } + if err := s.InsertValuesAtStem(stem2, values2, nil); err != nil { + t.Fatal(err) } - rightStem := &StemNode{ - Stem: make([]byte, 31), - Values: make([][]byte, 256), - depth: 1, - } - - node := &InternalNode{ - depth: 0, - left: leftInternal, - right: rightStem, - } - - height := node.GetHeight() - // Height should be max(left height, right height) + 1 - // Left height: 2, Right height: 1, so total: 3 - if height != 3 { - t.Errorf("Expected height 3, got %d", height) + height := s.GetHeight(s.Root()) + if height < 2 { + t.Errorf("Expected height >= 2, got %d", height) } } -// TestInternalNodeDepthTooLarge tests handling of excessive depth +// TestInternalNodeDepthTooLarge tests handling of excessive depth via NodeStore. 
func TestInternalNodeDepthTooLarge(t *testing.T) { - // Create an internal node at max depth - node := &InternalNode{ - depth: 31*8 + 1, - left: Empty{}, - right: Empty{}, - } - - stem := make([]byte, 31) - _, err := node.GetValuesAtStem(stem, nil) - if err == nil { - t.Fatal("Expected error for excessive depth") - } - if err.Error() != "node too deep" { - t.Errorf("Expected 'node too deep' error, got: %v", err) - } + s := NewNodeStore() + // Creating an internal node beyond max depth should panic + defer func() { + if r := recover(); r == nil { + t.Fatal("Expected panic for excessive depth") + } + }() + s.newInternalRef(31*8 + 1) } diff --git a/trie/bintrie/iterator.go b/trie/bintrie/iterator.go index 048d37f766..4fe927376f 100644 --- a/trie/bintrie/iterator.go +++ b/trie/bintrie/iterator.go @@ -26,13 +26,14 @@ import ( var errIteratorEnd = errors.New("end of iteration") type binaryNodeIteratorState struct { - Node BinaryNode + Node NodeRef Index int } type binaryNodeIterator struct { trie *BinaryTrie - current BinaryNode + store *NodeStore + current NodeRef lastErr error stack []binaryNodeIteratorState @@ -40,56 +41,59 @@ type binaryNodeIterator struct { func newBinaryNodeIterator(t *BinaryTrie, _ []byte) (trie.NodeIterator, error) { if t.Hash() == zero { - return &binaryNodeIterator{trie: t, lastErr: errIteratorEnd}, nil + return &binaryNodeIterator{trie: t, store: t.store, lastErr: errIteratorEnd}, nil } - it := &binaryNodeIterator{trie: t, current: t.root} - // it.err = it.seek(start) + it := &binaryNodeIterator{trie: t, store: t.store, current: t.store.Root()} return it, nil } -// Next moves the iterator to the next node. If the parameter is false, any child -// nodes will be skipped. +// Next moves the iterator to the next node. 
func (it *binaryNodeIterator) Next(descend bool) bool { if it.lastErr == errIteratorEnd { - it.lastErr = errIteratorEnd return false } if len(it.stack) == 0 { - it.stack = append(it.stack, binaryNodeIteratorState{Node: it.trie.root}) - it.current = it.trie.root - + it.stack = append(it.stack, binaryNodeIteratorState{Node: it.trie.store.Root()}) + it.current = it.trie.store.Root() return true } - switch node := it.current.(type) { - case *InternalNode: - // index: 0 = nothing visited, 1=left visited, 2=right visited + switch it.current.Kind() { + case KindInternal: + node := it.store.getInternal(it.current.Index()) context := &it.stack[len(it.stack)-1] - // recurse into both children + if !descend { + // Skip children: pop this node and advance parent + if len(it.stack) == 1 { + it.lastErr = errIteratorEnd + return false + } + it.stack = it.stack[:len(it.stack)-1] + it.current = it.stack[len(it.stack)-1].Node + it.stack[len(it.stack)-1].Index++ + return it.Next(true) + } + if context.Index == 0 { - if _, isempty := node.left.(Empty); node.left != nil && !isempty { + if !node.left.IsEmpty() { it.stack = append(it.stack, binaryNodeIteratorState{Node: node.left}) it.current = node.left return it.Next(descend) } - context.Index++ } if context.Index == 1 { - if _, isempty := node.right.(Empty); node.right != nil && !isempty { + if !node.right.IsEmpty() { it.stack = append(it.stack, binaryNodeIteratorState{Node: node.right}) it.current = node.right return it.Next(descend) } - context.Index++ } - // Reached the end of this node, go back to the parent, if - // this isn't root. 
if len(it.stack) == 1 { it.lastErr = errIteratorEnd return false @@ -98,17 +102,16 @@ func (it *binaryNodeIterator) Next(descend bool) bool { it.current = it.stack[len(it.stack)-1].Node it.stack[len(it.stack)-1].Index++ return it.Next(descend) - case *StemNode: - // Look for the next non-empty value + + case KindStem: + sn := it.store.getStem(it.current.Index()) for i := it.stack[len(it.stack)-1].Index; i < 256; i++ { - if node.Values[i] != nil { + if sn.hasValue(byte(i)) { it.stack[len(it.stack)-1].Index = i + 1 return true } } - // go back to parent to get the next leaf - // Check if we're at the root before popping if len(it.stack) == 1 { it.lastErr = errIteratorEnd return false @@ -117,45 +120,39 @@ func (it *binaryNodeIterator) Next(descend bool) bool { it.current = it.stack[len(it.stack)-1].Node it.stack[len(it.stack)-1].Index++ return it.Next(descend) - case HashedNode: - // resolve the node - resolverPath := it.Path() - data, err := it.trie.nodeResolver(resolverPath, common.Hash(node)) - if err != nil { - panic(err) - } - if data == nil { - // Empty/nil node — treat as Empty, backtrack - it.current = Empty{} - it.stack[len(it.stack)-1].Node = it.current - return it.Next(descend) - } - it.current, err = DeserializeNodeWithHash(data, len(it.stack)-1, common.Hash(node)) - if err != nil { - panic(err) - } - // update the stack and parent with the resolved node - it.stack[len(it.stack)-1].Node = it.current - if len(it.stack) >= 2 { - parent := &it.stack[len(it.stack)-2] - if parent.Index == 0 { - parent.Node.(*InternalNode).left = it.current - } else { - parent.Node.(*InternalNode).right = it.current - } - } - return it.Next(descend) - case Empty: - // Empty node - go back to parent and continue - if len(it.stack) <= 1 { - it.lastErr = errIteratorEnd + case KindHashed: + if len(it.stack) < 2 { + it.lastErr = errors.New("cannot resolve hashed root during iteration") return false } - it.stack = it.stack[:len(it.stack)-1] - it.current = 
it.stack[len(it.stack)-1].Node - it.stack[len(it.stack)-1].Index++ + hn := it.store.getHashed(it.current.Index()) + data, err := it.trie.nodeResolver(it.Path(), hn.hash) + if err != nil { + it.lastErr = err + return false + } + resolved, err := it.store.DeserializeNodeWithHash(data, len(it.stack)-1, hn.hash) + if err != nil { + it.lastErr = err + return false + } + + // Update the stack and parent with the resolved node + it.current = resolved + it.stack[len(it.stack)-1].Node = resolved + parent := &it.stack[len(it.stack)-2] + parentNode := it.store.getInternal(parent.Node.Index()) + if parent.Index == 0 { + parentNode.left = resolved + } else { + parentNode.right = resolved + } return it.Next(descend) + + case KindEmpty: + return false + default: panic("invalid node type") } @@ -171,25 +168,21 @@ func (it *binaryNodeIterator) Error() error { // Hash returns the hash of the current node. func (it *binaryNodeIterator) Hash() common.Hash { - return it.current.Hash() + return it.store.ComputeHash(it.current) } -// Parent returns the hash of the parent of the current node. The hash may be the one -// grandparent if the immediate parent is an internal node with no hash. +// Parent returns the hash of the parent of the current node. func (it *binaryNodeIterator) Parent() common.Hash { - return it.stack[len(it.stack)-1].Node.Hash() + return it.store.ComputeHash(it.stack[len(it.stack)-1].Node) } // Path returns the hex-encoded path to the current node. -// Callers must not retain references to the return value after calling Next. -// For leaf nodes, the last element of the path is the 'terminator symbol' 0x10. func (it *binaryNodeIterator) Path() []byte { if it.Leaf() { return it.LeafKey() } var path []byte for i, state := range it.stack { - // skip the last byte if i >= len(it.stack)-1 { break } @@ -200,105 +193,78 @@ func (it *binaryNodeIterator) Path() []byte { // NodeBlob returns the serialized bytes of the current node. 
func (it *binaryNodeIterator) NodeBlob() []byte { - return SerializeNode(it.current) + return it.store.SerializeNode(it.current, MaxGroupDepth) } // Leaf returns true iff the current node is a leaf node. -// In a Binary Trie, a StemNode contains up to 256 leaf values. -// The iterator is only considered to be "at a leaf" when it's positioned -// at a specific non-nil value within the StemNode, not just at the StemNode itself. func (it *binaryNodeIterator) Leaf() bool { - sn, ok := it.current.(*StemNode) - if !ok { + if it.current.Kind() != KindStem { return false } - // Check if we have a valid stack position if len(it.stack) == 0 { return false } - // The Index in the stack state points to the NEXT position after the current value. - // So if Index is 0, we haven't started iterating through the values yet. - // If Index is 5, we're currently at value[4] (the 5th value, 0-indexed). idx := it.stack[len(it.stack)-1].Index if idx == 0 || idx > 256 { return false } - // Check if there's actually a value at the current position + sn := it.store.getStem(it.current.Index()) currentValueIndex := idx - 1 - return sn.Values[currentValueIndex] != nil + return sn.hasValue(byte(currentValueIndex)) } -// LeafKey returns the key of the leaf. The method panics if the iterator is not -// positioned at a leaf. Callers must not retain references to the value after -// calling Next. +// LeafKey returns the key of the leaf. func (it *binaryNodeIterator) LeafKey() []byte { - leaf, ok := it.current.(*StemNode) - if !ok { + if it.current.Kind() != KindStem { panic("Leaf() called on an binary node iterator not at a leaf location") } - return leaf.Key(it.stack[len(it.stack)-1].Index - 1) + sn := it.store.getStem(it.current.Index()) + return sn.Key(it.stack[len(it.stack)-1].Index - 1) } -// LeafBlob returns the content of the leaf. The method panics if the iterator -// is not positioned at a leaf. Callers must not retain references to the value -// after calling Next. 
+// LeafBlob returns the content of the leaf. func (it *binaryNodeIterator) LeafBlob() []byte { - leaf, ok := it.current.(*StemNode) - if !ok { + if it.current.Kind() != KindStem { panic("LeafBlob() called on an binary node iterator not at a leaf location") } - return leaf.Values[it.stack[len(it.stack)-1].Index-1] + sn := it.store.getStem(it.current.Index()) + return sn.getValue(byte(it.stack[len(it.stack)-1].Index - 1)) } -// LeafProof returns the Merkle proof of the leaf. The method panics if the -// iterator is not positioned at a leaf. Callers must not retain references -// to the value after calling Next. +// LeafProof returns the Merkle proof of the leaf. func (it *binaryNodeIterator) LeafProof() [][]byte { - sn, ok := it.current.(*StemNode) - if !ok { + if it.current.Kind() != KindStem { panic("LeafProof() called on an binary node iterator not at a leaf location") } + sn := it.store.getStem(it.current.Index()) proof := make([][]byte, 0, len(it.stack)+StemNodeWidth) - // Build proof by walking up the stack and collecting sibling hashes for i := range it.stack[:len(it.stack)-2] { state := it.stack[i] - internalNode := state.Node.(*InternalNode) // should panic if the node isn't an InternalNode + internalNode := it.store.getInternal(state.Node.Index()) - // Add the sibling hash to the proof if state.Index == 0 { - // We came from left, so include right sibling - proof = append(proof, internalNode.right.Hash().Bytes()) + rh := it.store.ComputeHash(internalNode.right) + proof = append(proof, rh.Bytes()) } else { - // We came from right, so include left sibling - proof = append(proof, internalNode.left.Hash().Bytes()) + lh := it.store.ComputeHash(internalNode.left) + proof = append(proof, lh.Bytes()) } } // Add the stem and siblings - proof = append(proof, sn.Stem) - for _, v := range sn.Values { - proof = append(proof, v) - } + proof = append(proof, sn.Stem[:]) + proof = append(proof, sn.allValues()...) 
return proof } // AddResolver sets an intermediate database to use for looking up trie nodes // before reaching into the real persistent layer. -// -// This is not required for normal operation, rather is an optimization for -// cases where trie nodes can be recovered from some external mechanism without -// reading from disk. In those cases, this resolver allows short circuiting -// accesses and returning them from memory. -// -// Before adding a similar mechanism to any other place in Geth, consider -// making trie.Database an interface and wrapping at that level. It's a huge -// refactor, but it could be worth it if another occurrence arises. func (it *binaryNodeIterator) AddResolver(trie.NodeResolver) { // Not implemented, but should not panic } diff --git a/trie/bintrie/iterator_test.go b/trie/bintrie/iterator_test.go index 3e717c07ba..954e3b8400 100644 --- a/trie/bintrie/iterator_test.go +++ b/trie/bintrie/iterator_test.go @@ -27,14 +27,14 @@ import ( // makeTrie creates a BinaryTrie populated with the given key-value pairs. func makeTrie(t *testing.T, entries [][2]common.Hash) *BinaryTrie { t.Helper() + store := NewNodeStore() tr := &BinaryTrie{ - root: NewBinaryNode(), - tracer: trie.NewPrevalueTracer(), + store: store, + tracer: trie.NewPrevalueTracer(), + groupDepth: MaxGroupDepth, } for _, kv := range entries { - var err error - tr.root, err = tr.root.Insert(kv[0][:], kv[1][:], nil, 0) - if err != nil { + if err := store.Insert(kv[0][:], kv[1][:], nil); err != nil { t.Fatal(err) } } @@ -64,8 +64,9 @@ func countLeaves(t *testing.T, tr *BinaryTrie) int { // no nodes and reports no error. 
func TestIteratorEmptyTrie(t *testing.T) { tr := &BinaryTrie{ - root: Empty{}, - tracer: trie.NewPrevalueTracer(), + store: NewNodeStore(), + tracer: trie.NewPrevalueTracer(), + groupDepth: MaxGroupDepth, } it, err := newBinaryNodeIterator(tr, nil) if err != nil { @@ -145,8 +146,8 @@ func TestIteratorEmptyNodeBacktrack(t *testing.T) { {common.HexToHash("8000000000000000000000000000000000000000000000000000000000000001"), oneKey}, }) - if _, ok := tr.root.(*InternalNode); !ok { - t.Fatalf("expected InternalNode root, got %T", tr.root) + if tr.store.Root().Kind() != KindInternal { + t.Fatalf("expected InternalNode root, got kind %d", tr.store.Root().Kind()) } if leaves := countLeaves(t, tr); leaves != 2 { t.Fatalf("expected 2 leaves, got %d (Empty backtrack bug?)", leaves) @@ -162,18 +163,31 @@ func TestIteratorHashedNodeNilData(t *testing.T) { {common.HexToHash("8000000000000000000000000000000000000000000000000000000000000001"), oneKey}, }) - root, ok := tr.root.(*InternalNode) - if !ok { - t.Fatalf("expected InternalNode root, got %T", tr.root) + root := tr.store.Root() + if root.Kind() != KindInternal { + t.Fatalf("expected InternalNode root, got kind %d", root.Kind()) } + rootNode := tr.store.getInternal(root.Index()) // Replace right child with a zero-hash HashedNode. nodeResolver // short-circuits on common.Hash{} and returns (nil, nil), which // triggers the nil-data guard in the iterator. - root.right = HashedNode(common.Hash{}) + rootNode.right = tr.store.newHashedRef(common.Hash{}) // Should not panic; the zero-hash right child should be treated as Empty. - if leaves := countLeaves(t, tr); leaves != 1 { + // Since the hashed node can't be resolved (nil data -> empty deserialization), + // only the left leaf should be counted. 
+ it, err := newBinaryNodeIterator(tr, nil) + if err != nil { + t.Fatal(err) + } + leaves := 0 + for it.Next(true) { + if it.Leaf() { + leaves++ + } + } + if leaves != 1 { t.Fatalf("expected 1 leaf (zero-hash right node skipped), got %d", leaves) } } diff --git a/trie/bintrie/node_ref.go b/trie/bintrie/node_ref.go new file mode 100644 index 0000000000..da6d24a711 --- /dev/null +++ b/trie/bintrie/node_ref.go @@ -0,0 +1,57 @@ +// Copyright 2025 go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package bintrie + +// NodeKind identifies the type of a trie node stored in a NodeRef. +type NodeKind uint8 + +const ( + KindEmpty NodeKind = iota // no node + KindInternal // internal binary branching node + KindStem // leaf group containing up to 256 values + KindHashed // unresolved node (hash only) + KindInvalid // sentinel for validation +) + +// NodeRef is a compact, GC-invisible reference to a node in a NodeStore. +// It packs a 2-bit type tag (bits 31-30) and a 30-bit index (bits 29-0) +// into a single uint32. Because NodeRef contains no Go pointers, slices +// of structs containing NodeRef fields are allocated in noscan spans — +// the garbage collector never examines them. 
+type NodeRef uint32 + +const ( + kindShift uint32 = 30 + indexMask uint32 = (1 << kindShift) - 1 + + // EmptyRef is the zero-value NodeRef, representing an empty node. + EmptyRef NodeRef = 0 +) + +// MakeRef creates a NodeRef from a kind and pool index. +func MakeRef(kind NodeKind, idx uint32) NodeRef { + return NodeRef(uint32(kind)<> kindShift) } + +// Index returns the pool index within the node's typed pool. +func (r NodeRef) Index() uint32 { return uint32(r) & indexMask } + +// IsEmpty returns true if this ref represents an empty node. +func (r NodeRef) IsEmpty() bool { return r == EmptyRef } diff --git a/trie/bintrie/node_store.go b/trie/bintrie/node_store.go new file mode 100644 index 0000000000..79670a4331 --- /dev/null +++ b/trie/bintrie/node_store.go @@ -0,0 +1,239 @@ +// Copyright 2025 go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package bintrie + +import "github.com/ethereum/go-ethereum/common" + +// storeChunkSize is the number of nodes per chunk in each typed pool. +// Using fixed-size array chunks ensures that pointers to nodes within +// existing chunks remain valid when new chunks are added (no reallocation +// of the backing data, only the outer pointer slice grows). 
+const storeChunkSize = 4096 + +// NodeStore is a GC-friendly arena for binary trie nodes. +// +// Instead of allocating each node as a separate heap object with interface +// pointers (which the GC must scan), NodeStore packs nodes into typed chunked +// pools. InternalNode and HashedNode contain ZERO Go pointers, so their pool +// backing arrays are allocated in noscan spans — the GC skips them entirely. +// StemNode has one pointer (valueData []byte) per node. +// +// For a trie with 25K InternalNodes, this reduces GC-scanned pointer-words +// from ~125K (with interface-based nodes) to ~25K (just StemNode valueData), +// an ~80% reduction. At mainnet scale (millions of nodes), this prevents +// multi-second GC pauses. +type NodeStore struct { + // InternalNode pool — NOSCAN: InternalNode contains zero Go pointers. + // Children are NodeRef (uint32), hash is [32]byte. + internalChunks []*[storeChunkSize]InternalNode + internalCount uint32 + + // StemNode pool — each StemNode has one pointer (valueData []byte). + // Still much better than the old design where each InternalNode had + // two BinaryNode interface pointers (4 pointer-words each). + stemChunks []*[storeChunkSize]StemNode + stemCount uint32 + + // HashedNode pool — NOSCAN: HashedNode is just [32]byte. + hashedChunks []*[storeChunkSize]HashedNode + hashedCount uint32 + + root NodeRef + + // Free lists for recycling deleted node slots. + freeInternals []uint32 + freeStems []uint32 + freeHashed []uint32 +} + +// NewNodeStore creates a new empty NodeStore. +func NewNodeStore() *NodeStore { + return &NodeStore{root: EmptyRef} +} + +// Root returns the root NodeRef. +func (s *NodeStore) Root() NodeRef { return s.root } + +// SetRoot sets the root NodeRef. 
+func (s *NodeStore) SetRoot(ref NodeRef) { s.root = ref } + +// --- InternalNode allocation --- + +func (s *NodeStore) allocInternal() uint32 { + if n := len(s.freeInternals); n > 0 { + idx := s.freeInternals[n-1] + s.freeInternals = s.freeInternals[:n-1] + *s.getInternal(idx) = InternalNode{} + return idx + } + idx := s.internalCount + chunkIdx := idx / storeChunkSize + if uint32(len(s.internalChunks)) <= chunkIdx { + s.internalChunks = append(s.internalChunks, new([storeChunkSize]InternalNode)) + } + s.internalCount++ + if s.internalCount > indexMask { + panic("internal node pool overflow") + } + return idx +} + +func (s *NodeStore) getInternal(idx uint32) *InternalNode { + return &s.internalChunks[idx/storeChunkSize][idx%storeChunkSize] +} + +// newInternalRef allocates an InternalNode and returns its NodeRef. +func (s *NodeStore) newInternalRef(depth int) NodeRef { + if depth > 248 { + panic("node depth exceeds maximum binary trie depth") + } + idx := s.allocInternal() + n := s.getInternal(idx) + n.depth = uint8(depth) + n.mustRecompute = true + return MakeRef(KindInternal, idx) +} + +// --- StemNode allocation --- + +func (s *NodeStore) allocStem() uint32 { + if n := len(s.freeStems); n > 0 { + idx := s.freeStems[n-1] + s.freeStems = s.freeStems[:n-1] + *s.getStem(idx) = StemNode{} + return idx + } + idx := s.stemCount + chunkIdx := idx / storeChunkSize + if uint32(len(s.stemChunks)) <= chunkIdx { + s.stemChunks = append(s.stemChunks, new([storeChunkSize]StemNode)) + } + s.stemCount++ + if s.stemCount > indexMask { + panic("internal node pool overflow") + } + return idx +} + +func (s *NodeStore) getStem(idx uint32) *StemNode { + return &s.stemChunks[idx/storeChunkSize][idx%storeChunkSize] +} + +// newStemRef allocates a StemNode with the given stem/depth and returns its NodeRef. 
+func (s *NodeStore) newStemRef(stem []byte, depth int) NodeRef { + if depth > 248 { + panic("node depth exceeds maximum binary trie depth") + } + idx := s.allocStem() + sn := s.getStem(idx) + copy(sn.Stem[:], stem[:StemSize]) + sn.depth = uint8(depth) + sn.mustRecompute = true + return MakeRef(KindStem, idx) +} + +// --- HashedNode allocation --- + +func (s *NodeStore) allocHashed() uint32 { + if n := len(s.freeHashed); n > 0 { + idx := s.freeHashed[n-1] + s.freeHashed = s.freeHashed[:n-1] + *s.getHashed(idx) = HashedNode{} + return idx + } + idx := s.hashedCount + chunkIdx := idx / storeChunkSize + if uint32(len(s.hashedChunks)) <= chunkIdx { + s.hashedChunks = append(s.hashedChunks, new([storeChunkSize]HashedNode)) + } + s.hashedCount++ + if s.hashedCount > indexMask { + panic("internal node pool overflow") + } + return idx +} + +func (s *NodeStore) getHashed(idx uint32) *HashedNode { + return &s.hashedChunks[idx/storeChunkSize][idx%storeChunkSize] +} + +func (s *NodeStore) freeHashedNode(idx uint32) { + s.freeHashed = append(s.freeHashed, idx) +} + +// newHashedRef allocates a HashedNode and returns its NodeRef. +func (s *NodeStore) newHashedRef(hash common.Hash) NodeRef { + idx := s.allocHashed() + *s.getHashed(idx) = HashedNode{hash: hash} + return MakeRef(KindHashed, idx) +} + +// Copy creates a deep copy of the NodeStore and all its nodes. 
+func (s *NodeStore) Copy() *NodeStore { + ns := &NodeStore{ + root: s.root, + internalCount: s.internalCount, + stemCount: s.stemCount, + hashedCount: s.hashedCount, + } + + // Deep copy internal chunks + ns.internalChunks = make([]*[storeChunkSize]InternalNode, len(s.internalChunks)) + for i, chunk := range s.internalChunks { + cp := *chunk + ns.internalChunks[i] = &cp + } + + // Deep copy stem chunks (need to deep copy valueData) + ns.stemChunks = make([]*[storeChunkSize]StemNode, len(s.stemChunks)) + for i, chunk := range s.stemChunks { + cp := *chunk + ns.stemChunks[i] = &cp + } + // Deep copy pointer fields for each active stem + for i := uint32(0); i < s.stemCount; i++ { + src := s.getStem(i) + dst := ns.getStem(i) + if len(src.valueData) > 0 { + dst.valueData = make([]byte, len(src.valueData)) + copy(dst.valueData, src.valueData) + } + } + + // Deep copy hashed chunks + ns.hashedChunks = make([]*[storeChunkSize]HashedNode, len(s.hashedChunks)) + for i, chunk := range s.hashedChunks { + cp := *chunk + ns.hashedChunks[i] = &cp + } + + // Copy free lists + if len(s.freeInternals) > 0 { + ns.freeInternals = make([]uint32, len(s.freeInternals)) + copy(ns.freeInternals, s.freeInternals) + } + if len(s.freeStems) > 0 { + ns.freeStems = make([]uint32, len(s.freeStems)) + copy(ns.freeStems, s.freeStems) + } + if len(s.freeHashed) > 0 { + ns.freeHashed = make([]uint32, len(s.freeHashed)) + copy(ns.freeHashed, s.freeHashed) + } + + return ns +} diff --git a/trie/bintrie/stem_node.go b/trie/bintrie/stem_node.go index 0ceae6b062..9f1ac6d2e0 100644 --- a/trie/bintrie/stem_node.go +++ b/trie/bintrie/stem_node.go @@ -17,119 +17,135 @@ package bintrie import ( - "bytes" - "errors" - "fmt" - "slices" + "math/bits" "github.com/ethereum/go-ethereum/common" ) -// StemNode represents a group of `NodeWith` values sharing the same stem. +// StemNode represents a group of `StemNodeWidth` values sharing the same stem. 
+// It uses a packed representation: bitmap indicates which of the 256 positions +// have values, and valueData stores the values contiguously in bitmap order. type StemNode struct { - Stem []byte // Stem path to get to StemNodeWidth values - Values [][]byte // All values, indexed by the last byte of the key. - depth int // Depth of the node + Stem [StemSize]byte // Stem path to get to StemNodeWidth values + bitmap [StemBitmapSize]byte // bitmap indicating which positions have values + valueData []byte // packed value data (count * HashSize bytes) + count uint16 // number of values present + depth uint8 // Depth of the node + shared bool // true if valueData is shared with serialized input mustRecompute bool // true if the hash needs to be recomputed - dirty bool // true if the node's on-disk blob is stale (needs flush) hash common.Hash // cached hash when mustRecompute == false } -// Get retrieves the value for the given key. -func (bt *StemNode) Get(key []byte, _ NodeResolverFn) ([]byte, error) { - if !bytes.Equal(bt.Stem, key[:StemSize]) { - return nil, nil +// posInData returns the index within valueData for the given suffix. +// Returns -1 if the suffix is not present. +func (sn *StemNode) posInData(suffix byte) int { + idx := int(suffix) + if sn.bitmap[idx/8]>>(7-(idx%8))&1 == 0 { + return -1 } - return bt.Values[key[StemSize]], nil + // Count the bits set before this position to determine the offset + pos := 0 + byteIdx := idx / 8 + for i := 0; i < byteIdx; i++ { + pos += bits.OnesCount8(sn.bitmap[i]) + } + // Count bits in the partial byte + mask := byte(0xFF) << (8 - (idx % 8)) + pos += bits.OnesCount8(sn.bitmap[byteIdx] & mask) + return pos } -// Insert inserts a new key-value pair into the node. 
-func (bt *StemNode) Insert(key []byte, value []byte, _ NodeResolverFn, depth int) (BinaryNode, error) { - if !bytes.Equal(bt.Stem, key[:StemSize]) { - bitStem := bt.Stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1 +// getValue returns the value at the given suffix, or nil if not present. +func (sn *StemNode) getValue(suffix byte) []byte { + pos := sn.posInData(suffix) + if pos < 0 { + return nil + } + start := pos * HashSize + return sn.valueData[start : start+HashSize] +} - n := &InternalNode{depth: bt.depth, mustRecompute: true, dirty: true} - bt.depth++ - // bt is re-parented under n and sits at a new path — rewrite its blob. - bt.mustRecompute = true - bt.dirty = true - var child, other *BinaryNode - if bitStem == 0 { - n.left = bt - child = &n.left - other = &n.right - } else { - n.right = bt - child = &n.right - other = &n.left +// hasValue returns true if the given suffix has a value. +func (sn *StemNode) hasValue(suffix byte) bool { + idx := int(suffix) + return sn.bitmap[idx/8]>>(7-(idx%8))&1 == 1 +} + +// allValues returns all 256 values (nil for absent positions). 
+func (sn *StemNode) allValues() [][]byte { + values := make([][]byte, StemNodeWidth) + dataIdx := 0 + for i := range StemNodeWidth { + if sn.bitmap[i/8]>>(7-(i%8))&1 == 1 { + values[i] = sn.valueData[dataIdx*HashSize : (dataIdx+1)*HashSize] + dataIdx++ } - - bitKey := key[n.depth/8] >> (7 - (n.depth % 8)) & 1 - if bitKey == bitStem { - var err error - *child, err = (*child).Insert(key, value, nil, depth+1) - if err != nil { - return n, fmt.Errorf("insert error: %w", err) - } - *other = Empty{} - } else { - var values [StemNodeWidth][]byte - values[key[StemSize]] = value - *other = &StemNode{ - Stem: slices.Clone(key[:StemSize]), - Values: values[:], - depth: depth + 1, - mustRecompute: true, - dirty: true, - } - } - return n, nil } - if len(value) != HashSize { - return bt, errors.New("invalid insertion: value length") - } - bt.Values[key[StemSize]] = value - bt.mustRecompute = true - bt.dirty = true - return bt, nil + return values } -// Copy creates a deep copy of the node. -func (bt *StemNode) Copy() BinaryNode { - var values [StemNodeWidth][]byte - for i, v := range bt.Values { - values[i] = slices.Clone(v) - } - return &StemNode{ - Stem: slices.Clone(bt.Stem), - Values: values[:], - depth: bt.depth, - hash: bt.hash, - mustRecompute: bt.mustRecompute, - dirty: bt.dirty, +// ensureWritable makes the valueData writable (copies if shared with serialized input). +func (sn *StemNode) ensureWritable() { + if sn.shared || cap(sn.valueData)-len(sn.valueData) < HashSize { + newData := make([]byte, len(sn.valueData), len(sn.valueData)+HashSize*4) + copy(newData, sn.valueData) + sn.valueData = newData + sn.shared = false } } -// GetHeight returns the height of the node. -func (bt *StemNode) GetHeight() int { - return 1 +// setValue sets or inserts a value at the given suffix. 
+func (sn *StemNode) setValue(suffix byte, value []byte) { + idx := int(suffix) + pos := sn.posInData(suffix) + if pos >= 0 { + // Overwrite existing value + copy(sn.valueData[pos*HashSize:], value[:HashSize]) + return + } + // New value: insert into bitmap and valueData at the correct position. + sn.bitmap[idx/8] |= 1 << (7 - (idx % 8)) + sn.count++ + + // Find the correct position in valueData (count bits before this position). + insertPos := 0 + byteIdx := idx / 8 + for i := 0; i < byteIdx; i++ { + insertPos += bits.OnesCount8(sn.bitmap[i]) + } + mask := byte(0xFF) << (8 - (idx % 8)) + insertPos += bits.OnesCount8(sn.bitmap[byteIdx] & mask) + + // Insert value at the correct position in valueData. + insertOffset := insertPos * HashSize + // Grow the slice + sn.valueData = append(sn.valueData, make([]byte, HashSize)...) + // Shift data after insertion point + copy(sn.valueData[insertOffset+HashSize:], sn.valueData[insertOffset:len(sn.valueData)-HashSize]) + // Copy the new value + copy(sn.valueData[insertOffset:], value[:HashSize]) } // Hash returns the hash of the node. -func (bt *StemNode) Hash() common.Hash { - if !bt.mustRecompute { - return bt.hash +func (sn *StemNode) Hash() common.Hash { + if !sn.mustRecompute { + return sn.hash } var data [StemNodeWidth]common.Hash h := newSha256() defer returnSha256(h) - for i, v := range bt.Values { - if v != nil { + + // Hash each present value + dataIdx := 0 + for i := range StemNodeWidth { + if sn.bitmap[i/8]>>(7-(i%8))&1 == 1 { + v := sn.valueData[dataIdx*HashSize : (dataIdx+1)*HashSize] h.Reset() h.Write(v) h.Sum(data[i][:0]) + dataIdx++ } } h.Reset() @@ -150,103 +166,18 @@ func (bt *StemNode) Hash() common.Hash { } h.Reset() - h.Write(bt.Stem) + h.Write(sn.Stem[:]) h.Write([]byte{0}) h.Write(data[0][:]) - bt.hash = common.BytesToHash(h.Sum(nil)) - bt.mustRecompute = false - return bt.hash -} - -// CollectNodes flushes the stem via the collector when dirty; clean stems -// are skipped. 
-func (bt *StemNode) CollectNodes(path []byte, flush NodeFlushFn) error { - if !bt.dirty { - return nil - } - flush(path, bt) - bt.dirty = false - return nil -} - -// GetValuesAtStem retrieves the group of values located at the given stem key. -func (bt *StemNode) GetValuesAtStem(stem []byte, _ NodeResolverFn) ([][]byte, error) { - if !bytes.Equal(bt.Stem, stem) { - return nil, nil - } - return bt.Values[:], nil -} - -// InsertValuesAtStem inserts a full value group at the given stem in the internal node. -// Already-existing values will be overwritten. -func (bt *StemNode) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolverFn, depth int) (BinaryNode, error) { - if !bytes.Equal(bt.Stem, key[:StemSize]) { - bitStem := bt.Stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1 - - n := &InternalNode{depth: bt.depth, mustRecompute: true, dirty: true} - bt.depth++ - // bt is re-parented under n and sits at a new path — rewrite its blob. - bt.mustRecompute = true - bt.dirty = true - var child, other *BinaryNode - if bitStem == 0 { - n.left = bt - child = &n.left - other = &n.right - } else { - n.right = bt - child = &n.right - other = &n.left - } - - bitKey := key[n.depth/8] >> (7 - (n.depth % 8)) & 1 - if bitKey == bitStem { - var err error - *child, err = (*child).InsertValuesAtStem(key, values, nil, depth+1) - if err != nil { - return n, fmt.Errorf("insert error: %w", err) - } - *other = Empty{} - } else { - *other = &StemNode{ - Stem: slices.Clone(key[:StemSize]), - Values: values, - depth: n.depth + 1, - mustRecompute: true, - dirty: true, - } - } - return n, nil - } - - // same stem, just merge the two value lists - for i, v := range values { - if v != nil { - bt.Values[i] = v - bt.mustRecompute = true - bt.dirty = true - } - } - return bt, nil -} - -func (bt *StemNode) toDot(parent, path string) string { - me := fmt.Sprintf("stem%s", path) - ret := fmt.Sprintf("%s [label=\"stem=%x c=%x\"]\n", me, bt.Stem, bt.Hash()) - ret = fmt.Sprintf("%s %s -> %s\n", ret, 
parent, me) - for i, v := range bt.Values { - if v != nil { - ret = fmt.Sprintf("%s%s%x [label=\"%x\"]\n", ret, me, i, v) - ret = fmt.Sprintf("%s%s -> %s%x\n", ret, me, me, i) - } - } - return ret + sn.hash = common.BytesToHash(h.Sum(nil)) + sn.mustRecompute = false + return sn.hash } // Key returns the full key for the given index. -func (bt *StemNode) Key(i int) []byte { +func (sn *StemNode) Key(i int) []byte { var ret [HashSize]byte - copy(ret[:], bt.Stem) + copy(ret[:], sn.Stem[:]) ret[StemSize] = byte(i) return ret[:] } diff --git a/trie/bintrie/stem_node_test.go b/trie/bintrie/stem_node_test.go index 2743e7ce9b..9ad77ff438 100644 --- a/trie/bintrie/stem_node_test.go +++ b/trie/bintrie/stem_node_test.go @@ -23,165 +23,99 @@ import ( "github.com/ethereum/go-ethereum/common" ) -// TestStemNodeGet tests the Get method for matching stem, non-matching stem, -// and nil-value suffix scenarios. -func TestStemNodeGet(t *testing.T) { - stem := make([]byte, StemSize) - stem[0] = 0xAB - var values [StemNodeWidth][]byte - values[5] = common.HexToHash("0xdeadbeef").Bytes() - - node := &StemNode{Stem: stem, Values: values[:], depth: 0} - - // Matching stem, populated suffix → returns value. - key := make([]byte, HashSize) - copy(key[:StemSize], stem) - key[StemSize] = 5 - got, err := node.Get(key, nil) - if err != nil { - t.Fatalf("Get error: %v", err) - } - if !bytes.Equal(got, values[5]) { - t.Fatalf("Get = %x, want %x", got, values[5]) - } - - // Matching stem, empty suffix → returns nil (slot not set). - key[StemSize] = 99 - got, err = node.Get(key, nil) - if err != nil { - t.Fatalf("Get error: %v", err) - } - if got != nil { - t.Fatalf("Get(empty suffix) = %x, want nil", got) - } - - // Non-matching stem → returns nil, nil. 
- otherKey := make([]byte, HashSize) - otherKey[0] = 0xFF - got, err = node.Get(otherKey, nil) - if err != nil { - t.Fatalf("Get error: %v", err) - } - if got != nil { - t.Fatalf("Get(wrong stem) = %x, want nil", got) - } -} - -// TestStemNodeInsertSameStem tests inserting values with the same stem +// TestStemNodeInsertSameStem tests inserting values with the same stem via NodeStore. func TestStemNodeInsertSameStem(t *testing.T) { + s := NewNodeStore() + stem := make([]byte, 31) for i := range stem { stem[i] = byte(i) } - var values [256][]byte - values[0] = common.HexToHash("0x0101").Bytes() - - node := &StemNode{ - Stem: stem, - Values: values[:], - depth: 0, + // Insert first value + key1 := make([]byte, 32) + copy(key1[:31], stem) + key1[31] = 0 + value1 := common.HexToHash("0x0101").Bytes() + if err := s.Insert(key1, value1, nil); err != nil { + t.Fatal(err) } // Insert another value with the same stem but different last byte - key := make([]byte, 32) - copy(key[:31], stem) - key[31] = 10 - value := common.HexToHash("0x0202").Bytes() - - newNode, err := node.Insert(key, value, nil, 0) - if err != nil { - t.Fatalf("Failed to insert: %v", err) + key2 := make([]byte, 32) + copy(key2[:31], stem) + key2[31] = 10 + value2 := common.HexToHash("0x0202").Bytes() + if err := s.Insert(key2, value2, nil); err != nil { + t.Fatal(err) } - // Should still be a StemNode - stemNode, ok := newNode.(*StemNode) - if !ok { - t.Fatalf("Expected StemNode, got %T", newNode) + // Root should still be a StemNode + if s.Root().Kind() != KindStem { + t.Fatalf("Expected KindStem root, got kind %d", s.Root().Kind()) } // Check that both values are present - if !bytes.Equal(stemNode.Values[0], values[0]) { + v1, _ := s.Get(key1, nil) + if !bytes.Equal(v1, value1) { t.Errorf("Value at index 0 mismatch") } - if !bytes.Equal(stemNode.Values[10], value) { + v2, _ := s.Get(key2, nil) + if !bytes.Equal(v2, value2) { t.Errorf("Value at index 10 mismatch") } } -// TestStemNodeInsertDifferentStem 
tests inserting values with different stems +// TestStemNodeInsertDifferentStem tests inserting values with different stems via NodeStore. func TestStemNodeInsertDifferentStem(t *testing.T) { - stem1 := make([]byte, 31) - for i := range stem1 { - stem1[i] = 0x00 - } + s := NewNodeStore() - var values [256][]byte - values[0] = common.HexToHash("0x0101").Bytes() - - node := &StemNode{ - Stem: stem1, - Values: values[:], - depth: 0, + // Insert first value with stem of all zeros + key1 := make([]byte, 32) + key1[31] = 0 + value1 := common.HexToHash("0x0101").Bytes() + if err := s.Insert(key1, value1, nil); err != nil { + t.Fatal(err) } // Insert with a different stem (first bit different) - key := make([]byte, 32) - key[0] = 0x80 // First bit is 1 instead of 0 - value := common.HexToHash("0x0202").Bytes() - - newNode, err := node.Insert(key, value, nil, 0) - if err != nil { - t.Fatalf("Failed to insert: %v", err) + key2 := make([]byte, 32) + key2[0] = 0x80 // First bit is 1 instead of 0 + value2 := common.HexToHash("0x0202").Bytes() + if err := s.Insert(key2, value2, nil); err != nil { + t.Fatal(err) } // Should now be an InternalNode - internalNode, ok := newNode.(*InternalNode) - if !ok { - t.Fatalf("Expected InternalNode, got %T", newNode) + if s.Root().Kind() != KindInternal { + t.Fatalf("Expected KindInternal root, got kind %d", s.Root().Kind()) } // Check depth - if internalNode.depth != 0 { - t.Errorf("Expected depth 0, got %d", internalNode.depth) + rootNode := s.getInternal(s.Root().Index()) + if rootNode.depth != 0 { + t.Errorf("Expected depth 0, got %d", rootNode.depth) } - // Original stem should be on the left (bit 0) - leftStem, ok := internalNode.left.(*StemNode) - if !ok { - t.Fatalf("Expected left child to be StemNode, got %T", internalNode.left) + // Verify both values are retrievable + v1, _ := s.Get(key1, nil) + if !bytes.Equal(v1, value1) { + t.Error("Value 1 mismatch") } - if !bytes.Equal(leftStem.Stem, stem1) { - t.Errorf("Left stem mismatch") - 
} - - // New stem should be on the right (bit 1) - rightStem, ok := internalNode.right.(*StemNode) - if !ok { - t.Fatalf("Expected right child to be StemNode, got %T", internalNode.right) - } - if !bytes.Equal(rightStem.Stem, key[:31]) { - t.Errorf("Right stem mismatch") + v2, _ := s.Get(key2, nil) + if !bytes.Equal(v2, value2) { + t.Error("Value 2 mismatch") } } -// TestStemNodeInsertInvalidValueLength tests inserting value with invalid length +// TestStemNodeInsertInvalidValueLength tests inserting value with invalid length via NodeStore. func TestStemNodeInsertInvalidValueLength(t *testing.T) { - stem := make([]byte, 31) - var values [256][]byte + s := NewNodeStore() - node := &StemNode{ - Stem: stem, - Values: values[:], - depth: 0, - } - - // Try to insert value with wrong length key := make([]byte, 32) - copy(key[:31], stem) invalidValue := []byte{1, 2, 3} // Not 32 bytes - _, err := node.Insert(key, invalidValue, nil, 0) + err := s.Insert(key, invalidValue, nil) if err == nil { t.Fatal("Expected error for invalid value length") } @@ -191,221 +125,209 @@ func TestStemNodeInsertInvalidValueLength(t *testing.T) { } } -// TestStemNodeCopy tests the Copy method +// TestStemNodeCopy tests the Copy method via NodeStore. 
func TestStemNodeCopy(t *testing.T) { - stem := make([]byte, 31) - for i := range stem { - stem[i] = byte(i) + s := NewNodeStore() + + key1 := make([]byte, 32) + for i := range 31 { + key1[i] = byte(i) + } + key1[31] = 0 + value1 := common.HexToHash("0x0101").Bytes() + + key2 := make([]byte, 32) + copy(key2[:31], key1[:31]) + key2[31] = 255 + value2 := common.HexToHash("0x0202").Bytes() + + if err := s.Insert(key1, value1, nil); err != nil { + t.Fatal(err) + } + if err := s.Insert(key2, value2, nil); err != nil { + t.Fatal(err) } - var values [256][]byte - values[0] = common.HexToHash("0x0101").Bytes() - values[255] = common.HexToHash("0x0202").Bytes() + ns := s.Copy() - node := &StemNode{ - Stem: stem, - Values: values[:], - depth: 10, - } - - // Create a copy - copied := node.Copy() - copiedStem, ok := copied.(*StemNode) - if !ok { - t.Fatalf("Expected StemNode, got %T", copied) - } - - // Check that values are equal but not the same slice - if !bytes.Equal(copiedStem.Stem, node.Stem) { - t.Errorf("Stem mismatch after copy") - } - if &copiedStem.Stem[0] == &node.Stem[0] { - t.Error("Stem slice not properly cloned") - } - - // Check values - if !bytes.Equal(copiedStem.Values[0], node.Values[0]) { + // Check that values are equal + v1, _ := ns.Get(key1, nil) + if !bytes.Equal(v1, value1) { t.Errorf("Value at index 0 mismatch after copy") } - if !bytes.Equal(copiedStem.Values[255], node.Values[255]) { + v2, _ := ns.Get(key2, nil) + if !bytes.Equal(v2, value2) { t.Errorf("Value at index 255 mismatch after copy") } - - // Check that value slices are cloned - if copiedStem.Values[0] != nil && &copiedStem.Values[0][0] == &node.Values[0][0] { - t.Error("Value slice not properly cloned") - } - - // Check depth - if copiedStem.depth != node.depth { - t.Errorf("Depth mismatch: expected %d, got %d", node.depth, copiedStem.depth) - } } -// TestStemNodeHash tests the Hash method +// TestStemNodeHash tests the Hash method. 
func TestStemNodeHash(t *testing.T) { - stem := make([]byte, 31) - var values [256][]byte - values[0] = common.HexToHash("0x0101").Bytes() + s := NewNodeStore() - node := &StemNode{ - Stem: stem, - Values: values[:], - depth: 0, + key := make([]byte, 32) + key[31] = 0 + value := common.HexToHash("0x0101").Bytes() + if err := s.Insert(key, value, nil); err != nil { + t.Fatal(err) } - hash1 := node.Hash() + hash1 := s.ComputeHash(s.Root()) // Hash should be deterministic - hash2 := node.Hash() + hash2 := s.ComputeHash(s.Root()) if hash1 != hash2 { t.Errorf("Hash not deterministic: %x != %x", hash1, hash2) } // Changing a value should change the hash - node.Values[1] = common.HexToHash("0x0202").Bytes() - node.mustRecompute = true - hash3 := node.Hash() + key2 := make([]byte, 32) + key2[31] = 1 + value2 := common.HexToHash("0x0202").Bytes() + if err := s.Insert(key2, value2, nil); err != nil { + t.Fatal(err) + } + hash3 := s.ComputeHash(s.Root()) if hash1 == hash3 { t.Error("Hash didn't change after modifying values") } } -// TestStemNodeGetValuesAtStem tests GetValuesAtStem method +// TestStemNodeGetValuesAtStem tests GetValuesAtStem method via NodeStore. 
func TestStemNodeGetValuesAtStem(t *testing.T) { + s := NewNodeStore() + stem := make([]byte, 31) for i := range stem { stem[i] = byte(i) } - var values [256][]byte + values := make([][]byte, 256) values[0] = common.HexToHash("0x0101").Bytes() values[10] = common.HexToHash("0x0202").Bytes() values[255] = common.HexToHash("0x0303").Bytes() - node := &StemNode{ - Stem: stem, - Values: values[:], - depth: 0, + if err := s.InsertValuesAtStem(stem, values, nil); err != nil { + t.Fatal(err) } // GetValuesAtStem with matching stem - retrievedValues, err := node.GetValuesAtStem(stem, nil) + retrievedValues, err := s.GetValuesAtStem(stem, nil) if err != nil { t.Fatalf("Failed to get values: %v", err) } - // Check that all values match - for i := range 256 { - if !bytes.Equal(retrievedValues[i], values[i]) { - t.Errorf("Value mismatch at index %d", i) - } + if !bytes.Equal(retrievedValues[0], values[0]) { + t.Error("Value at index 0 mismatch") + } + if !bytes.Equal(retrievedValues[10], values[10]) { + t.Error("Value at index 10 mismatch") + } + if !bytes.Equal(retrievedValues[255], values[255]) { + t.Error("Value at index 255 mismatch") } - // GetValuesAtStem with different stem should return nil + // GetValuesAtStem with different stem should return nil values differentStem := make([]byte, 31) differentStem[0] = 0xFF - shouldBeNil, err := node.GetValuesAtStem(differentStem, nil) + shouldBeEmpty, err := s.GetValuesAtStem(differentStem, nil) if err != nil { t.Fatalf("Failed to get values with different stem: %v", err) } - if shouldBeNil != nil { - t.Error("Expected nil for different stem, got non-nil") + allNil := true + for _, v := range shouldBeEmpty { + if v != nil { + allNil = false + break + } + } + if !allNil { + t.Error("Expected all nil values for different stem") } } -// TestStemNodeInsertValuesAtStem tests InsertValuesAtStem method +// TestStemNodeInsertValuesAtStem tests InsertValuesAtStem method via NodeStore. 
func TestStemNodeInsertValuesAtStem(t *testing.T) { + s := NewNodeStore() + stem := make([]byte, 31) - var values [256][]byte + values := make([][]byte, 256) values[0] = common.HexToHash("0x0101").Bytes() - node := &StemNode{ - Stem: stem, - Values: values[:], - depth: 0, + if err := s.InsertValuesAtStem(stem, values, nil); err != nil { + t.Fatal(err) } // Insert new values at the same stem - var newValues [256][]byte + newValues := make([][]byte, 256) newValues[1] = common.HexToHash("0x0202").Bytes() newValues[2] = common.HexToHash("0x0303").Bytes() - newNode, err := node.InsertValuesAtStem(stem, newValues[:], nil, 0) - if err != nil { - t.Fatalf("Failed to insert values: %v", err) - } - - stemNode, ok := newNode.(*StemNode) - if !ok { - t.Fatalf("Expected StemNode, got %T", newNode) + if err := s.InsertValuesAtStem(stem, newValues, nil); err != nil { + t.Fatal(err) } // Check that all values are present - if !bytes.Equal(stemNode.Values[0], values[0]) { + retrieved, err := s.GetValuesAtStem(stem, nil) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(retrieved[0], values[0]) { t.Error("Original value at index 0 missing") } - if !bytes.Equal(stemNode.Values[1], newValues[1]) { + if !bytes.Equal(retrieved[1], newValues[1]) { t.Error("New value at index 1 missing") } - if !bytes.Equal(stemNode.Values[2], newValues[2]) { + if !bytes.Equal(retrieved[2], newValues[2]) { t.Error("New value at index 2 missing") } } -// TestStemNodeGetHeight tests GetHeight method +// TestStemNodeGetHeight tests GetHeight method via NodeStore. 
func TestStemNodeGetHeight(t *testing.T) { - node := &StemNode{ - Stem: make([]byte, 31), - Values: make([][]byte, 256), - depth: 0, + s := NewNodeStore() + + key := make([]byte, 32) + value := common.HexToHash("0x01").Bytes() + if err := s.Insert(key, value, nil); err != nil { + t.Fatal(err) } - height := node.GetHeight() + height := s.GetHeight(s.Root()) if height != 1 { t.Errorf("Expected height 1, got %d", height) } } -// TestStemNodeCollectNodes tests CollectNodes method +// TestStemNodeCollectNodes tests CollectNodes method via NodeStore. func TestStemNodeCollectNodes(t *testing.T) { + s := NewNodeStore() + stem := make([]byte, 31) - var values [256][]byte + values := make([][]byte, 256) values[0] = common.HexToHash("0x0101").Bytes() - node := &StemNode{ - Stem: stem, - Values: values[:], - depth: 0, - dirty: true, + if err := s.InsertValuesAtStem(stem, values, nil); err != nil { + t.Fatal(err) } var collectedPaths [][]byte - var collectedNodes []BinaryNode - - flushFn := func(path []byte, n BinaryNode) { - // Make a copy of the path + flushFn := func(path []byte, hash common.Hash, serialized []byte) { pathCopy := make([]byte, len(path)) copy(pathCopy, path) collectedPaths = append(collectedPaths, pathCopy) - collectedNodes = append(collectedNodes, n) } - err := node.CollectNodes([]byte{0, 1, 0}, flushFn) + err := s.CollectNodes(s.Root(), []byte{0, 1, 0}, flushFn, MaxGroupDepth) if err != nil { t.Fatalf("Failed to collect nodes: %v", err) } // Should have collected one node (itself) - if len(collectedNodes) != 1 { - t.Errorf("Expected 1 collected node, got %d", len(collectedNodes)) - } - - // Check that the collected node is the same - if collectedNodes[0] != node { - t.Error("Collected node doesn't match original") + if len(collectedPaths) != 1 { + t.Errorf("Expected 1 collected node, got %d", len(collectedPaths)) } // Check the path @@ -413,44 +335,3 @@ func TestStemNodeCollectNodes(t *testing.T) { t.Errorf("Path mismatch: expected [0, 1, 0], got %v", 
collectedPaths[0]) } } - -// TestStemNodeCollectNodesSkipsClean verifies that a clean stem is not -// flushed, and that flushing a dirty stem clears its dirty flag so that -// a subsequent CollectNodes on the same node is a no-op. -func TestStemNodeCollectNodesSkipsClean(t *testing.T) { - stem := make([]byte, 31) - node := &StemNode{ - Stem: stem, - Values: make([][]byte, 256), - depth: 0, - } - - var collected []BinaryNode - flushFn := func(_ []byte, n BinaryNode) { collected = append(collected, n) } - - if err := node.CollectNodes([]byte{0}, flushFn); err != nil { - t.Fatalf("CollectNodes on clean stem: %v", err) - } - if len(collected) != 0 { - t.Fatalf("expected clean stem not to be flushed, got %d", len(collected)) - } - - node.dirty = true - if err := node.CollectNodes([]byte{0}, flushFn); err != nil { - t.Fatalf("CollectNodes on dirty stem: %v", err) - } - if len(collected) != 1 { - t.Fatalf("expected dirty stem to be flushed once, got %d", len(collected)) - } - if node.dirty { - t.Errorf("stem dirty flag should be cleared after flush") - } - - collected = nil - if err := node.CollectNodes([]byte{0}, flushFn); err != nil { - t.Fatalf("CollectNodes after flush: %v", err) - } - if len(collected) != 0 { - t.Errorf("expected no flush on clean stem, got %d", len(collected)) - } -} diff --git a/trie/bintrie/store_commit.go b/trie/bintrie/store_commit.go new file mode 100644 index 0000000000..3f71a71e38 --- /dev/null +++ b/trie/bintrie/store_commit.go @@ -0,0 +1,477 @@ +// Copyright 2025 go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. 
+// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package bintrie + +import ( + "errors" + "fmt" + "math/bits" + "sync" + + "github.com/ethereum/go-ethereum/common" +) + +// NodeFlushFn is called during commit to flush serialized nodes. +type NodeFlushFn func(path []byte, hash common.Hash, serialized []byte) + +// Hash computes and returns the root hash. +func (s *NodeStore) Hash() common.Hash { + return s.ComputeHash(s.root) +} + +// ComputeHash computes the hash of the node referenced by ref. +func (s *NodeStore) ComputeHash(ref NodeRef) common.Hash { + switch ref.Kind() { + case KindInternal: + return s.hashInternal(ref.Index()) + case KindStem: + return s.getStem(ref.Index()).Hash() + case KindHashed: + return s.getHashed(ref.Index()).hash + case KindEmpty: + return common.Hash{} + default: + return common.Hash{} + } +} + +// hashInternal computes the hash of an InternalNode. At shallow depths +// (< parallelHashDepth), the left subtree is hashed in a goroutine while +// the right subtree is hashed inline. This is safe because left and right +// subtrees are disjoint in a well-formed tree — no node appears in both. +// ComputeHash must not be called concurrently with mutations to the NodeStore. 
+func (s *NodeStore) hashInternal(idx uint32) common.Hash { + node := s.getInternal(idx) + if !node.mustRecompute { + return node.hash + } + + if node.depth < parallelHashDepth { + var input [64]byte + var lh common.Hash + var wg sync.WaitGroup + if !node.left.IsEmpty() { + wg.Add(1) + go func() { + lh = s.ComputeHash(node.left) + wg.Done() + }() + } + if !node.right.IsEmpty() { + rh := s.ComputeHash(node.right) + copy(input[32:], rh[:]) + } + wg.Wait() + copy(input[:32], lh[:]) + node.hash = sha256Sum256(input[:]) + node.mustRecompute = false + return node.hash + } + + var input [64]byte + if !node.left.IsEmpty() { + lh := s.ComputeHash(node.left) + copy(input[:32], lh[:]) + } + if !node.right.IsEmpty() { + rh := s.ComputeHash(node.right) + copy(input[32:], rh[:]) + } + node.hash = sha256Sum256(input[:]) + node.mustRecompute = false + return node.hash +} + +// --- Serialization --- + +// countSubtreeChildren counts non-empty children at the bottom layer of a subtree. +func (s *NodeStore) countSubtreeChildren(ref NodeRef, remainingDepth int) int { + if remainingDepth == 0 { + if ref.IsEmpty() { + return 0 + } + return 1 + } + if ref.Kind() == KindInternal { + node := s.getInternal(ref.Index()) + return s.countSubtreeChildren(node.left, remainingDepth-1) + s.countSubtreeChildren(node.right, remainingDepth-1) + } + if ref.IsEmpty() { + return 0 + } + return 1 +} + +// serializeSubtreeDirect writes child hashes directly into the serialized buffer. 
+func (s *NodeStore) serializeSubtreeDirect(ref NodeRef, remainingDepth int, position int, absoluteDepth int, bitmap []byte, out []byte, hashOffset *int) { + if remainingDepth == 0 { + if ref.IsEmpty() { + return + } + bitmap[position/8] |= 1 << (7 - (position % 8)) + h := s.ComputeHash(ref) + copy(out[*hashOffset:*hashOffset+HashSize], h[:]) + *hashOffset += HashSize + return + } + + if ref.Kind() == KindInternal { + node := s.getInternal(ref.Index()) + leftPos := position * 2 + rightPos := position*2 + 1 + s.serializeSubtreeDirect(node.left, remainingDepth-1, leftPos, absoluteDepth+1, bitmap, out, hashOffset) + s.serializeSubtreeDirect(node.right, remainingDepth-1, rightPos, absoluteDepth+1, bitmap, out, hashOffset) + return + } + + if ref.IsEmpty() { + return + } + + // Leaf (StemNode or HashedNode) at a non-zero remaining depth: compute its position. + leafPos := position + if ref.Kind() == KindStem { + sn := s.getStem(ref.Index()) + for d := 0; d < remainingDepth; d++ { + bit := sn.Stem[(absoluteDepth+d)/8] >> (7 - ((absoluteDepth + d) % 8)) & 1 + leafPos = leafPos*2 + int(bit) + } + } else { + leafPos = position << remainingDepth + } + bitmap[leafPos/8] |= 1 << (7 - (leafPos % 8)) + h := s.ComputeHash(ref) + copy(out[*hashOffset:*hashOffset+HashSize], h[:]) + *hashOffset += HashSize +} + +// SerializeNode serializes a node referenced by ref. 
+func (s *NodeStore) SerializeNode(ref NodeRef, groupDepth int) []byte { + if groupDepth < 1 || groupDepth > MaxGroupDepth { + panic("groupDepth must be between 1 and 8") + } + + switch ref.Kind() { + case KindInternal: + node := s.getInternal(ref.Index()) + bitmapSize := BitmapSizeForDepth(groupDepth) + hashCount := s.countSubtreeChildren(ref, groupDepth) + serializedLen := NodeTypeBytes + 1 + bitmapSize + hashCount*HashSize + serialized := make([]byte, serializedLen) + serialized[0] = nodeTypeInternal + serialized[1] = byte(groupDepth) + bitmap := serialized[2 : 2+bitmapSize] + hashOffset := 2 + bitmapSize + s.serializeSubtreeDirect(ref, groupDepth, 0, int(node.depth), bitmap, serialized, &hashOffset) + return serialized + + case KindStem: + sn := s.getStem(ref.Index()) + serializedLen := NodeTypeBytes + StemSize + StemBitmapSize + len(sn.valueData) + serialized := make([]byte, serializedLen) + serialized[0] = nodeTypeStem + copy(serialized[NodeTypeBytes:NodeTypeBytes+StemSize], sn.Stem[:]) + copy(serialized[NodeTypeBytes+StemSize:NodeTypeBytes+StemSize+StemBitmapSize], sn.bitmap[:]) + copy(serialized[NodeTypeBytes+StemSize+StemBitmapSize:], sn.valueData) + return serialized + + default: + panic(fmt.Sprintf("SerializeNode: unexpected node kind %d", ref.Kind())) + } +} + +// --- Deserialization --- + +var errInvalidSerializedLength = errors.New("invalid serialized node length") + +// DeserializeNode deserializes a node from bytes, recomputing its hash. +func (s *NodeStore) DeserializeNode(serialized []byte, depth int) (NodeRef, error) { + return s.deserializeNode(serialized, depth, common.Hash{}, true) +} + +// DeserializeNodeWithHash deserializes a node, using the provided hash. 
+func (s *NodeStore) DeserializeNodeWithHash(serialized []byte, depth int, hn common.Hash) (NodeRef, error) { + return s.deserializeNode(serialized, depth, hn, false) +} + +func (s *NodeStore) deserializeNode(serialized []byte, depth int, hn common.Hash, mustRecompute bool) (NodeRef, error) { + if len(serialized) == 0 { + return EmptyRef, nil + } + + switch serialized[0] { + case nodeTypeInternal: + if len(serialized) < NodeTypeBytes+1 { + return EmptyRef, errInvalidSerializedLength + } + groupDepth := int(serialized[1]) + if groupDepth < 1 || groupDepth > MaxGroupDepth { + return EmptyRef, errors.New("invalid group depth") + } + bitmapSize := BitmapSizeForDepth(groupDepth) + if len(serialized) < NodeTypeBytes+1+bitmapSize { + return EmptyRef, errInvalidSerializedLength + } + bitmap := serialized[2 : 2+bitmapSize] + hashData := serialized[2+bitmapSize:] + + hashIdx := 0 + ref, err := s.deserializeSubtree(groupDepth, 0, depth, bitmap, hashData, &hashIdx, mustRecompute) + if err != nil { + return EmptyRef, err + } + if ref.Kind() == KindInternal && !mustRecompute { + s.getInternal(ref.Index()).hash = hn + } + return ref, nil + + case nodeTypeStem: + if len(serialized) < 64 { + return EmptyRef, errInvalidSerializedLength + } + stemIdx := s.allocStem() + sn := s.getStem(stemIdx) + copy(sn.Stem[:], serialized[NodeTypeBytes:NodeTypeBytes+StemSize]) + copy(sn.bitmap[:], serialized[NodeTypeBytes+StemSize:NodeTypeBytes+StemSize+StemBitmapSize]) + + var count uint16 + for i := range StemBitmapSize { + count += uint16(bits.OnesCount8(sn.bitmap[i])) + } + sn.count = count + dataStart := NodeTypeBytes + StemSize + StemBitmapSize + dataEnd := dataStart + int(count)*HashSize + if len(serialized) < dataEnd { + return EmptyRef, errInvalidSerializedLength + } + // Zero-copy sub-slice of serialized data + sn.valueData = serialized[dataStart:dataEnd] + sn.shared = true + sn.depth = uint8(depth) + sn.hash = hn + sn.mustRecompute = mustRecompute + return MakeRef(KindStem, stemIdx), nil 
+ + default: + return EmptyRef, errors.New("invalid node type") + } +} + +func (s *NodeStore) deserializeSubtree(remainingDepth int, position int, nodeDepth int, bitmap []byte, hashData []byte, hashIdx *int, mustRecompute bool) (NodeRef, error) { + if remainingDepth == 0 { + if bitmap[position/8]>>(7-(position%8))&1 == 1 { + if len(hashData) < (*hashIdx+1)*HashSize { + return EmptyRef, errInvalidSerializedLength + } + hash := common.BytesToHash(hashData[*hashIdx*HashSize : (*hashIdx+1)*HashSize]) + *hashIdx++ + return s.newHashedRef(hash), nil + } + return EmptyRef, nil + } + + leftPos := position * 2 + rightPos := position*2 + 1 + + left, err := s.deserializeSubtree(remainingDepth-1, leftPos, nodeDepth+1, bitmap, hashData, hashIdx, true) + if err != nil { + return EmptyRef, err + } + right, err := s.deserializeSubtree(remainingDepth-1, rightPos, nodeDepth+1, bitmap, hashData, hashIdx, true) + if err != nil { + return EmptyRef, err + } + + if left.IsEmpty() && right.IsEmpty() { + return EmptyRef, nil + } + + if nodeDepth > 248 { + panic("node depth exceeds maximum binary trie depth") + } + idx := s.allocInternal() + n := s.getInternal(idx) + n.depth = uint8(nodeDepth) + n.left = left + n.right = right + n.mustRecompute = mustRecompute + return MakeRef(KindInternal, idx), nil +} + +// --- CollectNodes (Commit) --- + +// CollectNodes traverses the trie, flushing nodes at group boundaries. 
+func (s *NodeStore) CollectNodes(ref NodeRef, path []byte, flushfn NodeFlushFn, groupDepth int) error { + if groupDepth < 1 || groupDepth > MaxGroupDepth { + return errors.New("groupDepth must be between 1 and 8") + } + buf := make([]byte, len(path), len(path)+MaxGroupDepth+1) + copy(buf, path) + return s.collectNodesBuf(ref, buf, flushfn, groupDepth) +} + +func (s *NodeStore) collectNodesBuf(ref NodeRef, buf []byte, flushfn NodeFlushFn, groupDepth int) error { + switch ref.Kind() { + case KindInternal: + node := s.getInternal(ref.Index()) + if int(node.depth)%groupDepth == 0 { + if err := s.collectChildGroupsBuf(ref, buf, flushfn, groupDepth, groupDepth-1); err != nil { + return err + } + serialized := s.SerializeNode(ref, groupDepth) + flushfn(buf, s.ComputeHash(ref), serialized) + return nil + } + return s.collectChildGroupsBuf(ref, buf, flushfn, groupDepth, groupDepth-(int(node.depth)%groupDepth)-1) + + case KindStem: + serialized := s.SerializeNode(ref, groupDepth) + flushfn(buf, s.ComputeHash(ref), serialized) + return nil + + case KindHashed, KindEmpty: + return nil + + default: + return fmt.Errorf("collectNodesBuf: unexpected kind %d", ref.Kind()) + } +} + +func (s *NodeStore) collectChildGroupsBuf(ref NodeRef, buf []byte, flushfn NodeFlushFn, groupDepth int, remainingLevels int) error { + if ref.Kind() != KindInternal { + return nil + } + node := s.getInternal(ref.Index()) + saved := len(buf) + childDepth := int(node.depth) + 1 + + if remainingLevels == 0 { + if !node.left.IsEmpty() { + buf = append(buf, 0) + if err := s.collectNodesBuf(node.left, buf, flushfn, groupDepth); err != nil { + return err + } + buf = buf[:saved] + } + if !node.right.IsEmpty() { + buf = append(buf, 1) + if err := s.collectNodesBuf(node.right, buf, flushfn, groupDepth); err != nil { + return err + } + } + return nil + } + + // Left child + if !node.left.IsEmpty() { + if node.left.Kind() == KindInternal { + buf = append(buf, 0) + if err := s.collectChildGroupsBuf(node.left, buf, 
flushfn, groupDepth, remainingLevels-1); err != nil { + return err + } + buf = buf[:saved] + } else { + buf = append(buf, 0) + buf = s.extendPathBuf(buf, node.left, remainingLevels, childDepth) + if err := s.collectNodesBuf(node.left, buf, flushfn, groupDepth); err != nil { + return err + } + buf = buf[:saved] + } + } + + // Right child + if !node.right.IsEmpty() { + if node.right.Kind() == KindInternal { + buf = append(buf, 1) + if err := s.collectChildGroupsBuf(node.right, buf, flushfn, groupDepth, remainingLevels-1); err != nil { + return err + } + } else { + buf = append(buf, 1) + buf = s.extendPathBuf(buf, node.right, remainingLevels, childDepth) + if err := s.collectNodesBuf(node.right, buf, flushfn, groupDepth); err != nil { + return err + } + } + } + + return nil +} + +// extendPathBuf extends the path buffer to the group's leaf boundary. +func (s *NodeStore) extendPathBuf(buf []byte, ref NodeRef, remainingLevels int, absoluteDepth int) []byte { + if remainingLevels <= 0 { + return buf + } + if ref.Kind() == KindStem { + sn := s.getStem(ref.Index()) + for d := 0; d < remainingLevels; d++ { + bit := sn.Stem[(absoluteDepth+d)/8] >> (7 - ((absoluteDepth + d) % 8)) & 1 + buf = append(buf, bit) + } + } else { + for d := 0; d < remainingLevels; d++ { + buf = append(buf, 0) + } + } + return buf +} + +// ToDot generates a DOT representation for debugging. 
+func (s *NodeStore) ToDot(ref NodeRef, parent, path string) string { + switch ref.Kind() { + case KindInternal: + node := s.getInternal(ref.Index()) + me := fmt.Sprintf("internal%s", path) + ret := fmt.Sprintf("%s [label=\"I: %x\"]\n", me, s.ComputeHash(ref)) + if len(parent) > 0 { + ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me) + } + if !node.left.IsEmpty() { + ret += s.ToDot(node.left, me, fmt.Sprintf("%s%02x", path, 0)) + } + if !node.right.IsEmpty() { + ret += s.ToDot(node.right, me, fmt.Sprintf("%s%02x", path, 1)) + } + return ret + case KindStem: + sn := s.getStem(ref.Index()) + me := fmt.Sprintf("stem%s", path) + ret := fmt.Sprintf("%s [label=\"stem=%x c=%x\"]\n", me, sn.Stem, sn.Hash()) + ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me) + idx := 0 + for i := range StemNodeWidth { + if sn.bitmap[i/8]>>(7-i%8)&1 != 1 { + continue + } + v := sn.valueData[idx*HashSize : (idx+1)*HashSize] + idx++ + ret += fmt.Sprintf("%s%x [label=\"%x\"]\n", me, i, v) + ret += fmt.Sprintf("%s -> %s%x\n", me, me, i) + } + return ret + case KindHashed: + hn := s.getHashed(ref.Index()) + me := fmt.Sprintf("hash%s", path) + ret := fmt.Sprintf("%s [label=\"%x\"]\n", me, hn.hash) + ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me) + return ret + default: + return "" + } +} diff --git a/trie/bintrie/store_ops.go b/trie/bintrie/store_ops.go new file mode 100644 index 0000000000..72b2c2389b --- /dev/null +++ b/trie/bintrie/store_ops.go @@ -0,0 +1,556 @@ +// Copyright 2025 go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. 
+// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package bintrie + +import ( + "errors" + "fmt" + + "github.com/ethereum/go-ethereum/common" +) + +// NodeResolverFn resolves a hashed node from the database. +type NodeResolverFn func([]byte, common.Hash) ([]byte, error) + +// GetSingle retrieves a single value at stem[suffix] from the trie root. +func (s *NodeStore) GetSingle(stem []byte, suffix byte, resolver NodeResolverFn) ([]byte, error) { + return s.getSingle(s.root, stem, suffix, resolver) +} + +// getSingle retrieves a single value using iterative traversal. +func (s *NodeStore) getSingle(ref NodeRef, stem []byte, suffix byte, resolver NodeResolverFn) ([]byte, error) { + cur := ref + // Track parent for HashedNode resolution (update parent's child ref). 
+ var parentIdx uint32 + var parentIsLeft bool + hasParent := false + + for { + switch cur.Kind() { + case KindInternal: + node := s.getInternal(cur.Index()) + if node.depth >= 31*8 { + return nil, errors.New("node too deep") + } + bit := stem[node.depth/8] >> (7 - (node.depth % 8)) & 1 + parentIdx = cur.Index() + hasParent = true + if bit == 0 { + parentIsLeft = true + cur = node.left + } else { + parentIsLeft = false + cur = node.right + } + + case KindStem: + sn := s.getStem(cur.Index()) + if sn.Stem != [StemSize]byte(stem[:StemSize]) { + return nil, nil + } + return sn.getValue(suffix), nil + + case KindHashed: + if !hasParent { + return nil, errors.New("getSingle: hashed node at root") + } + hn := s.getHashed(cur.Index()) + parentNode := s.getInternal(parentIdx) + path := makeKeyPath(int(parentNode.depth), stem) + data, err := resolver(path, hn.hash) + if err != nil { + return nil, fmt.Errorf("getSingle resolve error: %w", err) + } + resolved, err := s.DeserializeNodeWithHash(data, int(parentNode.depth)+1, hn.hash) + if err != nil { + return nil, fmt.Errorf("getSingle deserialization error: %w", err) + } + // Update parent's child ref + s.freeHashedNode(cur.Index()) + if parentIsLeft { + parentNode.left = resolved + } else { + parentNode.right = resolved + } + cur = resolved + + case KindEmpty: + return nil, nil + + default: + return nil, fmt.Errorf("getSingle: unexpected node kind %d", cur.Kind()) + } + } +} + +// GetValuesAtStem retrieves all 256 values at a stem. +func (s *NodeStore) GetValuesAtStem(stem []byte, resolver NodeResolverFn) ([][]byte, error) { + return s.getValuesAtStem(s.root, stem, resolver) +} + +// getValuesAtStem uses iterative traversal to find the StemNode. 
+func (s *NodeStore) getValuesAtStem(ref NodeRef, stem []byte, resolver NodeResolverFn) ([][]byte, error) { + cur := ref + var parentIdx uint32 + var parentIsLeft bool + hasParent := false + + for { + switch cur.Kind() { + case KindInternal: + node := s.getInternal(cur.Index()) + if node.depth >= 31*8 { + return nil, errors.New("node too deep") + } + bit := stem[node.depth/8] >> (7 - (node.depth % 8)) & 1 + parentIdx = cur.Index() + hasParent = true + if bit == 0 { + parentIsLeft = true + cur = node.left + } else { + parentIsLeft = false + cur = node.right + } + + case KindStem: + sn := s.getStem(cur.Index()) + if sn.Stem != [StemSize]byte(stem[:StemSize]) { + return nil, nil + } + return sn.allValues(), nil + + case KindHashed: + if !hasParent { + return nil, errors.New("getValuesAtStem: hashed node at root") + } + hn := s.getHashed(cur.Index()) + parentNode := s.getInternal(parentIdx) + path := makeKeyPath(int(parentNode.depth), stem) + data, err := resolver(path, hn.hash) + if err != nil { + return nil, fmt.Errorf("getValuesAtStem resolve error: %w", err) + } + resolved, err := s.DeserializeNodeWithHash(data, int(parentNode.depth)+1, hn.hash) + if err != nil { + return nil, fmt.Errorf("getValuesAtStem deserialization error: %w", err) + } + s.freeHashedNode(cur.Index()) + if parentIsLeft { + parentNode.left = resolved + } else { + parentNode.right = resolved + } + cur = resolved + + case KindEmpty: + var values [StemNodeWidth][]byte + return values[:], nil + + default: + return nil, fmt.Errorf("getValuesAtStem: unexpected node kind %d", cur.Kind()) + } + } +} + +// InsertSingle inserts a single value at stem[suffix] into the trie. 
+func (s *NodeStore) InsertSingle(stem []byte, suffix byte, value []byte, resolver NodeResolverFn) error { + if len(value) != HashSize { + return errors.New("invalid insertion: value length") + } + + // Handle root-is-empty case + if s.root.IsEmpty() { + ref := s.newStemRef(stem, 0) + sn := s.getStem(ref.Index()) + sn.setValue(suffix, value) + s.root = ref + return nil + } + + // Handle root-is-stem case + if s.root.Kind() == KindStem { + sn := s.getStem(s.root.Index()) + if sn.Stem == [StemSize]byte(stem[:StemSize]) { + sn.ensureWritable() + sn.setValue(suffix, value) + sn.mustRecompute = true + return nil + } + // Different stem — promote root to internal node via split + newRoot := s.splitStemInsert(s.root, stem, suffix, value, int(sn.depth)) + s.root = newRoot + return nil + } + + // Root is an InternalNode — iterative descent + return s.insertSingleInternal(stem, suffix, value, resolver) +} + +// insertSingleInternal handles insertion when root is an InternalNode. +func (s *NodeStore) insertSingleInternal(stem []byte, suffix byte, value []byte, resolver NodeResolverFn) error { + type pathEntry struct { + internalIdx uint32 + isLeft bool + } + var pathStack [256]pathEntry // stack-allocated, max depth 248 + pathLen := 0 + + cur := s.root + + for { + switch cur.Kind() { + case KindInternal: + node := s.getInternal(cur.Index()) + node.mustRecompute = true + bit := stem[node.depth/8] >> (7 - (node.depth % 8)) & 1 + pathStack[pathLen] = pathEntry{internalIdx: cur.Index(), isLeft: bit == 0} + pathLen++ + if bit == 0 { + cur = node.left + } else { + cur = node.right + } + + case KindStem: + sn := s.getStem(cur.Index()) + if sn.Stem == [StemSize]byte(stem[:StemSize]) { + sn.ensureWritable() + sn.setValue(suffix, value) + sn.mustRecompute = true + return nil + } + // Different stem — split + parentDepth := int(s.getInternal(pathStack[pathLen-1].internalIdx).depth) + 1 + newRef := s.splitStemInsert(cur, stem, suffix, value, parentDepth) + p := pathStack[pathLen-1] + 
parent := s.getInternal(p.internalIdx) + if p.isLeft { + parent.left = newRef + } else { + parent.right = newRef + } + return nil + + case KindHashed: + if pathLen == 0 { + return errors.New("insertSingle: hashed node at root") + } + p := pathStack[pathLen-1] + parentNode := s.getInternal(p.internalIdx) + hn := s.getHashed(cur.Index()) + path := makeKeyPath(int(parentNode.depth), stem) + data, err := resolver(path, hn.hash) + if err != nil { + return fmt.Errorf("insertSingle resolve error: %w", err) + } + resolved, err := s.DeserializeNodeWithHash(data, int(parentNode.depth)+1, hn.hash) + if err != nil { + return fmt.Errorf("insertSingle deserialization error: %w", err) + } + s.freeHashedNode(cur.Index()) + if p.isLeft { + parentNode.left = resolved + } else { + parentNode.right = resolved + } + cur = resolved + + case KindEmpty: + parentDepth := int(s.getInternal(pathStack[pathLen-1].internalIdx).depth) + 1 + ref := s.newStemRef(stem, parentDepth) + sn := s.getStem(ref.Index()) + sn.setValue(suffix, value) + p := pathStack[pathLen-1] + parent := s.getInternal(p.internalIdx) + if p.isLeft { + parent.left = ref + } else { + parent.right = ref + } + return nil + + default: + return fmt.Errorf("insertSingle: unexpected node kind %d", cur.Kind()) + } + } +} + +// splitStemInsert handles the case where we need to split a StemNode +// into a chain of InternalNodes because the new key has a different stem. 
+func (s *NodeStore) splitStemInsert(existingRef NodeRef, newStem []byte, suffix byte, value []byte, depth int) NodeRef { + existing := s.getStem(existingRef.Index()) + existingDepth := depth + + var firstRef NodeRef + var lastInternalIdx uint32 + var lastIsLeft bool + first := true + + for { + bitExisting := existing.Stem[existingDepth/8] >> (7 - (existingDepth % 8)) & 1 + bitNew := newStem[existingDepth/8] >> (7 - (existingDepth % 8)) & 1 + + newInternalIdx := s.allocInternal() + newInternal := s.getInternal(newInternalIdx) + newInternal.depth = uint8(existingDepth) + newInternal.mustRecompute = true + newRef := MakeRef(KindInternal, newInternalIdx) + + if first { + firstRef = newRef + first = false + } else { + parent := s.getInternal(lastInternalIdx) + if lastIsLeft { + parent.left = newRef + } else { + parent.right = newRef + } + } + + if bitExisting != bitNew { + // Divergence point + existing.depth = uint8(existingDepth + 1) + + newStemIdx := s.allocStem() + newSn := s.getStem(newStemIdx) + copy(newSn.Stem[:], newStem[:StemSize]) + newSn.depth = uint8(existingDepth + 1) + newSn.mustRecompute = true + newSn.setValue(suffix, value) + newStemRef := MakeRef(KindStem, newStemIdx) + + if bitExisting == 0 { + newInternal.left = existingRef + newInternal.right = newStemRef + } else { + newInternal.left = newStemRef + newInternal.right = existingRef + } + return firstRef + } + + // Same bit — continue splitting + lastInternalIdx = newInternalIdx + lastIsLeft = (bitExisting == 0) + existingDepth++ + } +} + +// InsertValuesAtStem inserts a full group of values at the given stem. +func (s *NodeStore) InsertValuesAtStem(stem []byte, values [][]byte, resolver NodeResolverFn) error { + newRoot, err := s.insertValuesAtStem(s.root, stem, values, resolver, 0) + if err != nil { + return err + } + s.root = newRoot + return nil +} + +// insertValuesAtStem recursively inserts values at a stem. 
+func (s *NodeStore) insertValuesAtStem(ref NodeRef, stem []byte, values [][]byte, resolver NodeResolverFn, depth int) (NodeRef, error) { + switch ref.Kind() { + case KindInternal: + node := s.getInternal(ref.Index()) + bit := stem[node.depth/8] >> (7 - (node.depth % 8)) & 1 + if bit == 0 { + if node.left.Kind() == KindHashed { + hn := s.getHashed(node.left.Index()) + path := makeKeyPath(int(node.depth), stem) + data, err := resolver(path, hn.hash) + if err != nil { + return ref, fmt.Errorf("InsertValuesAtStem resolve error: %w", err) + } + resolved, err := s.DeserializeNodeWithHash(data, int(node.depth)+1, hn.hash) + if err != nil { + return ref, fmt.Errorf("InsertValuesAtStem deserialization error: %w", err) + } + s.freeHashedNode(node.left.Index()) + node.left = resolved + } + newChild, err := s.insertValuesAtStem(node.left, stem, values, resolver, depth+1) + if err != nil { + return ref, err + } + node.left = newChild + } else { + if node.right.Kind() == KindHashed { + hn := s.getHashed(node.right.Index()) + path := makeKeyPath(int(node.depth), stem) + data, err := resolver(path, hn.hash) + if err != nil { + return ref, fmt.Errorf("InsertValuesAtStem resolve error: %w", err) + } + resolved, err := s.DeserializeNodeWithHash(data, int(node.depth)+1, hn.hash) + if err != nil { + return ref, fmt.Errorf("InsertValuesAtStem deserialization error: %w", err) + } + s.freeHashedNode(node.right.Index()) + node.right = resolved + } + newChild, err := s.insertValuesAtStem(node.right, stem, values, resolver, depth+1) + if err != nil { + return ref, err + } + node.right = newChild + } + node.mustRecompute = true + return ref, nil + + case KindStem: + sn := s.getStem(ref.Index()) + if sn.Stem == [StemSize]byte(stem[:StemSize]) { + // Same stem — merge values + sn.ensureWritable() + for i, v := range values { + if v != nil { + sn.setValue(byte(i), v) + sn.mustRecompute = true + } + } + return ref, nil + } + // Different stem — split + return s.splitStemValuesInsert(ref, stem, 
values, resolver, depth) + + case KindHashed: + hn := s.getHashed(ref.Index()) + path, err := keyToPath(depth, stem) + if err != nil { + return ref, fmt.Errorf("InsertValuesAtStem path error: %w", err) + } + if resolver == nil { + return ref, errors.New("InsertValuesAtStem: resolver is nil") + } + data, err := resolver(path, hn.hash) + if err != nil { + return ref, fmt.Errorf("InsertValuesAtStem resolve error: %w", err) + } + resolved, err := s.DeserializeNodeWithHash(data, depth, hn.hash) + if err != nil { + return ref, fmt.Errorf("InsertValuesAtStem deserialization error: %w", err) + } + s.freeHashedNode(ref.Index()) + return s.insertValuesAtStem(resolved, stem, values, resolver, depth) + + case KindEmpty: + // Create new StemNode + stemIdx := s.allocStem() + sn := s.getStem(stemIdx) + copy(sn.Stem[:], stem[:StemSize]) + sn.depth = uint8(depth) + sn.mustRecompute = true + for i, v := range values { + if v != nil { + sn.count++ + sn.bitmap[i/8] |= 1 << (7 - (i % 8)) + sn.valueData = append(sn.valueData, v[:HashSize]...) + } + } + return MakeRef(KindStem, stemIdx), nil + + default: + return ref, fmt.Errorf("insertValuesAtStem: unexpected kind %d", ref.Kind()) + } +} + +// splitStemValuesInsert handles splitting a StemNode when inserting values with a different stem. 
+func (s *NodeStore) splitStemValuesInsert(existingRef NodeRef, newStem []byte, values [][]byte, resolver NodeResolverFn, depth int) (NodeRef, error) { + existing := s.getStem(existingRef.Index()) + + bitStem := existing.Stem[existing.depth/8] >> (7 - (existing.depth % 8)) & 1 + nRef := s.newInternalRef(int(existing.depth)) + nNode := s.getInternal(nRef.Index()) + existing.depth++ + + bitKey := newStem[nNode.depth/8] >> (7 - (nNode.depth % 8)) & 1 + if bitKey == bitStem { + // Same direction — need deeper split + var child NodeRef + if bitStem == 0 { + nNode.left = existingRef + child = nNode.left + } else { + nNode.right = existingRef + child = nNode.right + } + newChild, err := s.insertValuesAtStem(child, newStem, values, resolver, depth+1) + if err != nil { + return nRef, err + } + if bitStem == 0 { + nNode.left = newChild + nNode.right = EmptyRef + } else { + nNode.right = newChild + nNode.left = EmptyRef + } + } else { + // Divergence — create new StemNode for the new values + newStemIdx := s.allocStem() + newSn := s.getStem(newStemIdx) + copy(newSn.Stem[:], newStem[:StemSize]) + newSn.depth = nNode.depth + 1 + newSn.mustRecompute = true + for i, v := range values { + if v != nil { + newSn.setValue(byte(i), v) + } + } + newStemRef := MakeRef(KindStem, newStemIdx) + + if bitStem == 0 { + nNode.left = existingRef + nNode.right = newStemRef + } else { + nNode.left = newStemRef + nNode.right = existingRef + } + } + return nRef, nil +} + +// Insert inserts a key-value pair using the full 32-byte key. +func (s *NodeStore) Insert(key []byte, value []byte, resolver NodeResolverFn) error { + return s.InsertSingle(key[:StemSize], key[StemSize], value, resolver) +} + +// Get retrieves the value for the given 32-byte key. +func (s *NodeStore) Get(key []byte, resolver NodeResolverFn) ([]byte, error) { + return s.GetSingle(key[:StemSize], key[StemSize], resolver) +} + +// GetHeight returns the height of the trie rooted at ref. 
+func (s *NodeStore) GetHeight(ref NodeRef) int { + switch ref.Kind() { + case KindInternal: + node := s.getInternal(ref.Index()) + lh := s.GetHeight(node.left) + rh := s.GetHeight(node.right) + if lh > rh { + return 1 + lh + } + return 1 + rh + case KindStem: + return 1 + case KindEmpty: + return 0 + default: + return 0 + } +} diff --git a/trie/bintrie/trie.go b/trie/bintrie/trie.go index 23d014eb33..39fa4dac30 100644 --- a/trie/bintrie/trie.go +++ b/trie/bintrie/trie.go @@ -19,7 +19,6 @@ package bintrie import ( "bytes" "encoding/binary" - "errors" "fmt" "github.com/ethereum/go-ethereum/common" @@ -31,8 +30,6 @@ import ( "github.com/holiman/uint256" ) -var errInvalidRootType = errors.New("invalid root type") - // ChunkedCode represents a sequence of HashSize-byte chunks of code (StemSize bytes of which // are actual code, and NodeTypeBytes byte is the pushdata offset). type ChunkedCode []byte @@ -108,22 +105,18 @@ func ChunkifyCode(code []byte) ChunkedCode { return chunks } -// NewBinaryNode creates a new empty binary trie -func NewBinaryNode() BinaryNode { - return Empty{} -} - // BinaryTrie is the implementation of https://eips.ethereum.org/EIPS/eip-7864. type BinaryTrie struct { - root BinaryNode - reader *trie.Reader - tracer *trie.PrevalueTracer + store *NodeStore + reader *trie.Reader + tracer *trie.PrevalueTracer + groupDepth int } // ToDot converts the binary trie to a DOT language representation. Useful for debugging. func (t *BinaryTrie) ToDot() string { - t.root.Hash() - return ToDot(t.root) + t.store.ComputeHash(t.store.Root()) + return t.store.ToDot(t.store.Root(), "", "") } // NewBinaryTrie creates a new binary trie. 
@@ -133,9 +126,10 @@ func NewBinaryTrie(root common.Hash, db database.NodeDatabase) (*BinaryTrie, err return nil, err } t := &BinaryTrie{ - root: NewBinaryNode(), - reader: reader, - tracer: trie.NewPrevalueTracer(), + store: NewNodeStore(), + reader: reader, + tracer: trie.NewPrevalueTracer(), + groupDepth: MaxGroupDepth, } // Parse the root node if it's not empty if root != types.EmptyBinaryHash && root != types.EmptyRootHash { @@ -143,11 +137,11 @@ func NewBinaryTrie(root common.Hash, db database.NodeDatabase) (*BinaryTrie, err if err != nil { return nil, err } - node, err := DeserializeNodeWithHash(blob, 0, root) + ref, err := t.store.DeserializeNodeWithHash(blob, 0, root) if err != nil { return nil, err } - t.root = node + t.store.SetRoot(ref) } return t, nil } @@ -176,29 +170,18 @@ func (t *BinaryTrie) GetKey(key []byte) []byte { // GetWithHashedKey returns the value, assuming that the key has already // been hashed. func (t *BinaryTrie) GetWithHashedKey(key []byte) ([]byte, error) { - return t.root.Get(key, t.nodeResolver) + return t.store.Get(key, t.nodeResolver) } // GetAccount returns the account information for the given address. func (t *BinaryTrie) GetAccount(addr common.Address) (*types.StateAccount, error) { var ( - values [][]byte err error acc = &types.StateAccount{} key = GetBinaryTreeKey(addr, zero[:]) ) - switch r := t.root.(type) { - case *InternalNode: - values, err = r.GetValuesAtStem(key[:StemSize], t.nodeResolver) - case *StemNode: - values, err = r.GetValuesAtStem(key[:StemSize], t.nodeResolver) - case Empty: - return nil, nil - default: - // This will cover HashedNode but that should be fine since the - // root node should always be resolved. 
- return nil, errInvalidRootType - } + + values, err := t.store.GetValuesAtStem(key[:StemSize], t.nodeResolver) if err != nil { return nil, fmt.Errorf("GetAccount (%x) error: %v", addr, err) } @@ -219,7 +202,7 @@ func (t *BinaryTrie) GetAccount(addr common.Address) (*types.StateAccount, error // If the account has been deleted, BasicData and CodeHash will both be // 32-byte zero blobs (not nil). If the account is recreated afterwards, // UpdateAccount overwrites BasicData and CodeHash with non-zero values, - // so this branch won't activate.. + // so this branch won't activate. if bytes.Equal(values[BasicDataLeafKey], zero[:]) && bytes.Equal(values[CodeHashLeafKey], zero[:]) { return nil, nil @@ -238,13 +221,12 @@ func (t *BinaryTrie) GetAccount(addr common.Address) (*types.StateAccount, error // not be modified by the caller. If a node was not found in the database, a // trie.MissingNodeError is returned. func (t *BinaryTrie) GetStorage(addr common.Address, key []byte) ([]byte, error) { - return t.root.Get(GetBinaryTreeKeyStorageSlot(addr, key), t.nodeResolver) + return t.store.Get(GetBinaryTreeKeyStorageSlot(addr, key), t.nodeResolver) } // UpdateAccount updates the account information for the given address. func (t *BinaryTrie) UpdateAccount(addr common.Address, acc *types.StateAccount, codeLen int) error { var ( - err error basicData [HashSize]byte values = make([][]byte, StemNodeWidth) stem = GetBinaryTreeKey(addr, zero[:]) @@ -265,15 +247,12 @@ func (t *BinaryTrie) UpdateAccount(addr common.Address, acc *types.StateAccount, values[BasicDataLeafKey] = basicData[:] values[CodeHashLeafKey] = acc.CodeHash[:] - t.root, err = t.root.InsertValuesAtStem(stem, values, t.nodeResolver, 0) - return err + return t.store.InsertValuesAtStem(stem, values, t.nodeResolver) } // UpdateStem updates the values for the given stem key. 
func (t *BinaryTrie) UpdateStem(key []byte, values [][]byte) error { - var err error - t.root, err = t.root.InsertValuesAtStem(key, values, t.nodeResolver, 0) - return err + return t.store.InsertValuesAtStem(key, values, t.nodeResolver) } // UpdateStorage associates key with value in the trie. If value has length zero, any @@ -288,11 +267,10 @@ func (t *BinaryTrie) UpdateStorage(address common.Address, key, value []byte) er } else { copy(v[HashSize-len(value):], value[:]) } - root, err := t.root.Insert(k, v[:], t.nodeResolver, 0) + err := t.store.Insert(k, v[:], t.nodeResolver) if err != nil { return fmt.Errorf("UpdateStorage (%x) error: %v", address, err) } - t.root = root return nil } @@ -307,12 +285,7 @@ func (t *BinaryTrie) DeleteAccount(addr common.Address) error { values[BasicDataLeafKey] = zero[:] values[CodeHashLeafKey] = zero[:] - root, err := t.root.InsertValuesAtStem(stem, values, t.nodeResolver, 0) - if err != nil { - return fmt.Errorf("DeleteAccount (%x) error: %v", addr, err) - } - t.root = root - return nil + return t.store.InsertValuesAtStem(stem, values, t.nodeResolver) } // DeleteStorage removes any existing value for key from the trie. If a node was not @@ -320,18 +293,17 @@ func (t *BinaryTrie) DeleteAccount(addr common.Address) error { func (t *BinaryTrie) DeleteStorage(addr common.Address, key []byte) error { k := GetBinaryTreeKeyStorageSlot(addr, key) var zero [HashSize]byte - root, err := t.root.Insert(k, zero[:], t.nodeResolver, 0) + err := t.store.Insert(k, zero[:], t.nodeResolver) if err != nil { return fmt.Errorf("DeleteStorage (%x) error: %v", addr, err) } - t.root = root return nil } // Hash returns the root hash of the trie. It does not write to the database and // can be used even if the trie doesn't have one. 
func (t *BinaryTrie) Hash() common.Hash { - return t.root.Hash() + return t.store.ComputeHash(t.store.Root()) } // Commit writes all nodes to the trie's memory database, tracking the internal @@ -339,15 +311,17 @@ func (t *BinaryTrie) Hash() common.Hash { func (t *BinaryTrie) Commit(_ bool) (common.Hash, *trienode.NodeSet) { nodeset := trienode.NewNodeSet(common.Hash{}) - // The root can be any type of BinaryNode (InternalNode, StemNode, etc.) - err := t.root.CollectNodes(nil, func(path []byte, node BinaryNode) { - serialized := SerializeNode(node) - nodeset.AddNode(path, trienode.NewNodeWithPrev(node.Hash(), serialized, t.tracer.Get(path))) - }) + groupDepth := t.groupDepth + if groupDepth == 0 { + groupDepth = MaxGroupDepth + } + + err := t.store.CollectNodes(t.store.Root(), nil, func(path []byte, hash common.Hash, serialized []byte) { + nodeset.AddNode(path, trienode.NewNodeWithPrev(hash, serialized, t.tracer.Get(path))) + }, groupDepth) if err != nil { panic(fmt.Errorf("CollectNodes failed: %v", err)) } - // Serialize root commitment form return t.Hash(), nodeset } @@ -371,9 +345,10 @@ func (t *BinaryTrie) Prove(key []byte, proofDb ethdb.KeyValueWriter) error { // Copy creates a deep copy of the trie. 
func (t *BinaryTrie) Copy() *BinaryTrie { return &BinaryTrie{ - root: t.root.Copy(), - reader: t.reader, - tracer: t.tracer.Copy(), + store: t.store.Copy(), + reader: t.reader, + tracer: t.tracer.Copy(), + groupDepth: t.groupDepth, } } @@ -407,7 +382,6 @@ func (t *BinaryTrie) UpdateContractCode(addr common.Address, codeHash common.Has if groupOffset == StemNodeWidth-1 || len(chunks)-i <= HashSize { err = t.UpdateStem(key[:StemSize], values) - if err != nil { return fmt.Errorf("UpdateContractCode (addr=%x) error: %w", addr[:], err) } diff --git a/trie/bintrie/trie_test.go b/trie/bintrie/trie_test.go index c436501464..b06bb4841f 100644 --- a/trie/bintrie/trie_test.go +++ b/trie/bintrie/trie_test.go @@ -37,147 +37,130 @@ var ( ) func TestSingleEntry(t *testing.T) { - tree := NewBinaryNode() - tree, err := tree.Insert(zeroKey[:], oneKey[:], nil, 0) - if err != nil { + s := NewNodeStore() + if err := s.Insert(zeroKey[:], oneKey[:], nil); err != nil { t.Fatal(err) } - if tree.GetHeight() != 1 { + if s.GetHeight(s.Root()) != 1 { t.Fatal("invalid depth") } expected := common.HexToHash("aab1060e04cb4f5dc6f697ae93156a95714debbf77d54238766adc5709282b6f") - got := tree.Hash() + got := s.Hash() if got != expected { t.Fatalf("invalid tree root, got %x, want %x", got, expected) } } func TestTwoEntriesDiffFirstBit(t *testing.T) { - var err error - tree := NewBinaryNode() - tree, err = tree.Insert(zeroKey[:], oneKey[:], nil, 0) - if err != nil { + s := NewNodeStore() + if err := s.Insert(zeroKey[:], oneKey[:], nil); err != nil { t.Fatal(err) } - tree, err = tree.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000000").Bytes(), twoKey[:], nil, 0) - if err != nil { + if err := s.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000000").Bytes(), twoKey[:], nil); err != nil { t.Fatal(err) } - if tree.GetHeight() != 2 { + if s.GetHeight(s.Root()) != 2 { t.Fatal("invalid height") } - if tree.Hash() != 
common.HexToHash("dfc69c94013a8b3c65395625a719a87534a7cfd38719251ad8c8ea7fe79f065e") { + if s.Hash() != common.HexToHash("dfc69c94013a8b3c65395625a719a87534a7cfd38719251ad8c8ea7fe79f065e") { t.Fatal("invalid tree root") } } func TestOneStemColocatedValues(t *testing.T) { - var err error - tree := NewBinaryNode() - tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil, 0) - if err != nil { + s := NewNodeStore() + if err := s.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil); err != nil { t.Fatal(err) } - tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil, 0) - if err != nil { + if err := s.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil); err != nil { t.Fatal(err) } - tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000009").Bytes(), threeKey[:], nil, 0) - if err != nil { + if err := s.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000009").Bytes(), threeKey[:], nil); err != nil { t.Fatal(err) } - tree, err = tree.Insert(common.HexToHash("00000000000000000000000000000000000000000000000000000000000000FF").Bytes(), fourKey[:], nil, 0) - if err != nil { + if err := s.Insert(common.HexToHash("00000000000000000000000000000000000000000000000000000000000000FF").Bytes(), fourKey[:], nil); err != nil { t.Fatal(err) } - if tree.GetHeight() != 1 { + if s.GetHeight(s.Root()) != 1 { t.Fatal("invalid height") } } func TestTwoStemColocatedValues(t *testing.T) { - var err error - tree := NewBinaryNode() + s := NewNodeStore() // stem: 0...0 - tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil, 0) - if err != nil { + if err := 
s.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil); err != nil { t.Fatal(err) } - tree, err = tree.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil, 0) - if err != nil { + if err := s.Insert(common.HexToHash("0000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil); err != nil { t.Fatal(err) } // stem: 10...0 - tree, err = tree.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil, 0) - if err != nil { + if err := s.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000003").Bytes(), oneKey[:], nil); err != nil { t.Fatal(err) } - tree, err = tree.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil, 0) - if err != nil { + if err := s.Insert(common.HexToHash("8000000000000000000000000000000000000000000000000000000000000004").Bytes(), twoKey[:], nil); err != nil { t.Fatal(err) } - if tree.GetHeight() != 2 { + if s.GetHeight(s.Root()) != 2 { t.Fatal("invalid height") } } func TestTwoKeysMatchFirst42Bits(t *testing.T) { - var err error - tree := NewBinaryNode() + s := NewNodeStore() // key1 and key 2 have the same prefix of 42 bits (b0*42+b1+b1) and differ after. 
key1 := common.HexToHash("0000000000C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0C0").Bytes() key2 := common.HexToHash("0000000000E00000000000000000000000000000000000000000000000000000").Bytes() - tree, err = tree.Insert(key1, oneKey[:], nil, 0) - if err != nil { + if err := s.Insert(key1, oneKey[:], nil); err != nil { t.Fatal(err) } - tree, err = tree.Insert(key2, twoKey[:], nil, 0) - if err != nil { + if err := s.Insert(key2, twoKey[:], nil); err != nil { t.Fatal(err) } - if tree.GetHeight() != 1+42+1 { + if s.GetHeight(s.Root()) != 1+42+1 { t.Fatal("invalid height") } } + func TestInsertDuplicateKey(t *testing.T) { - var err error - tree := NewBinaryNode() - tree, err = tree.Insert(oneKey[:], oneKey[:], nil, 0) - if err != nil { + s := NewNodeStore() + if err := s.Insert(oneKey[:], oneKey[:], nil); err != nil { t.Fatal(err) } - tree, err = tree.Insert(oneKey[:], twoKey[:], nil, 0) - if err != nil { + if err := s.Insert(oneKey[:], twoKey[:], nil); err != nil { t.Fatal(err) } - if tree.GetHeight() != 1 { + if s.GetHeight(s.Root()) != 1 { t.Fatal("invalid height") } // Verify that the value is updated - if !bytes.Equal(tree.(*StemNode).Values[1], twoKey[:]) { - t.Fatal("invalid height") + v, err := s.Get(oneKey[:], nil) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(v, twoKey[:]) { + t.Fatal("value not updated") } } + func TestLargeNumberOfEntries(t *testing.T) { - var err error - tree := NewBinaryNode() + s := NewNodeStore() for i := range StemNodeWidth { var key [HashSize]byte key[0] = byte(i) - tree, err = tree.Insert(key[:], ffKey[:], nil, 0) - if err != nil { + if err := s.Insert(key[:], ffKey[:], nil); err != nil { t.Fatal(err) } } - height := tree.GetHeight() + height := s.GetHeight(s.Root()) if height != 1+8 { t.Fatalf("invalid height, wanted %d, got %d", 1+8, height) } } func TestMerkleizeMultipleEntries(t *testing.T) { - var err error - tree := NewBinaryNode() + s := NewNodeStore() keys := [][]byte{ zeroKey[:], 
common.HexToHash("8000000000000000000000000000000000000000000000000000000000000000").Bytes(), @@ -187,12 +170,11 @@ func TestMerkleizeMultipleEntries(t *testing.T) { for i, key := range keys { var v [HashSize]byte binary.LittleEndian.PutUint64(v[:8], uint64(i)) - tree, err = tree.Insert(key, v[:], nil, 0) - if err != nil { + if err := s.Insert(key, v[:], nil); err != nil { t.Fatal(err) } } - got := tree.Hash() + got := s.Hash() expected := common.HexToHash("9317155862f7a3867660ddd0966ff799a3d16aa4df1e70a7516eaa4a675191b5") if got != expected { t.Fatalf("invalid root, expected=%x, got = %x", expected, got) @@ -206,8 +188,9 @@ func TestMerkleizeMultipleEntries(t *testing.T) { func TestStorageRoundTrip(t *testing.T) { tracer := trie.NewPrevalueTracer() tr := &BinaryTrie{ - root: NewBinaryNode(), - tracer: tracer, + store: NewNodeStore(), + tracer: tracer, + groupDepth: MaxGroupDepth, } addr := common.HexToAddress("0x1234567890abcdef1234567890abcdef12345678") @@ -274,8 +257,9 @@ func TestStorageRoundTrip(t *testing.T) { func newEmptyTestTrie(t *testing.T) *BinaryTrie { t.Helper() return &BinaryTrie{ - root: NewBinaryNode(), - tracer: trie.NewPrevalueTracer(), + store: NewNodeStore(), + tracer: trie.NewPrevalueTracer(), + groupDepth: MaxGroupDepth, } } @@ -599,8 +583,9 @@ func TestBinaryTrieWitness(t *testing.T) { tracer := trie.NewPrevalueTracer() tr := &BinaryTrie{ - root: NewBinaryNode(), - tracer: tracer, + store: NewNodeStore(), + tracer: tracer, + groupDepth: MaxGroupDepth, } if w := tr.Witness(); len(w) != 0 { t.Fatal("expected empty witness for fresh trie") @@ -626,8 +611,9 @@ func TestBinaryTrieWitness(t *testing.T) { func testAccount(t *testing.T, addr common.Address, nonce uint64, balance uint64) *BinaryTrie { t.Helper() tr := &BinaryTrie{ - root: NewBinaryNode(), - tracer: trie.NewPrevalueTracer(), + store: NewNodeStore(), + tracer: trie.NewPrevalueTracer(), + groupDepth: MaxGroupDepth, } acc := &types.StateAccount{ Nonce: nonce, @@ -649,8 +635,8 @@ func 
TestGetAccountNonMembershipStemRoot(t *testing.T) { tr := testAccount(t, addr, 42, 100) // Verify root is a StemNode (single stem inserted). - if _, ok := tr.root.(*StemNode); !ok { - t.Fatalf("expected StemNode root, got %T", tr.root) + if tr.store.Root().Kind() != KindStem { + t.Fatalf("expected StemNode root, got kind %d", tr.store.Root().Kind()) } // Query a completely different address — must return nil. @@ -680,8 +666,9 @@ func TestGetAccountNonMembershipStemRoot(t *testing.T) { // address returns nil when the trie root is an InternalNode (multi-account trie). func TestGetAccountNonMembershipInternalRoot(t *testing.T) { tr := &BinaryTrie{ - root: NewBinaryNode(), - tracer: trie.NewPrevalueTracer(), + store: NewNodeStore(), + tracer: trie.NewPrevalueTracer(), + groupDepth: MaxGroupDepth, } // Insert two accounts whose binary tree keys have different first bits @@ -700,8 +687,8 @@ func TestGetAccountNonMembershipInternalRoot(t *testing.T) { } // Verify root is an InternalNode. - if _, ok := tr.root.(*InternalNode); !ok { - t.Fatalf("expected InternalNode root, got %T", tr.root) + if tr.store.Root().Kind() != KindInternal { + t.Fatalf("expected InternalNode root, got kind %d", tr.store.Root().Kind()) } // Query a non-existent address — must return nil. @@ -723,8 +710,8 @@ func TestGetStorageNonMembershipStemRoot(t *testing.T) { tr := testAccount(t, addr, 1, 100) // Verify root is a StemNode. - if _, ok := tr.root.(*StemNode); !ok { - t.Fatalf("expected StemNode root, got %T", tr.root) + if tr.store.Root().Kind() != KindStem { + t.Fatalf("expected StemNode root, got kind %d", tr.store.Root().Kind()) } // Query storage for a different address — must return nil, not panic. @@ -743,8 +730,9 @@ func TestGetStorageNonMembershipStemRoot(t *testing.T) { // non-existent address returns nil when the root is an InternalNode. 
func TestGetStorageNonMembershipInternalRoot(t *testing.T) { tr := &BinaryTrie{ - root: NewBinaryNode(), - tracer: trie.NewPrevalueTracer(), + store: NewNodeStore(), + tracer: trie.NewPrevalueTracer(), + groupDepth: MaxGroupDepth, } addr := common.HexToAddress("0x1234567890abcdef1234567890abcdef12345678") @@ -765,8 +753,8 @@ func TestGetStorageNonMembershipInternalRoot(t *testing.T) { t.Fatalf("UpdateStorage error: %v", err) } - if _, ok := tr.root.(*InternalNode); !ok { - t.Fatalf("expected InternalNode root, got %T", tr.root) + if tr.store.Root().Kind() != KindInternal { + t.Fatalf("expected InternalNode root, got kind %d", tr.store.Root().Kind()) } // Query storage for a non-existent address — must return nil. diff --git a/triedb/pathdb/database.go b/triedb/pathdb/database.go index 04c76cfd53..98975d7fa5 100644 --- a/triedb/pathdb/database.go +++ b/triedb/pathdb/database.go @@ -102,11 +102,7 @@ func binaryNodeHasher(blob []byte) (common.Hash, error) { if len(blob) == 0 { return types.EmptyBinaryHash, nil } - n, err := bintrie.DeserializeNode(blob, 0) - if err != nil { - return common.Hash{}, err - } - return n.Hash(), nil + return bintrie.DeserializeAndHash(blob, 0) } // Database is a multiple-layered structure for maintaining in-memory states