trie/bintrie: cache hashes of clean nodes so as not to rehash the whole tree

This commit is contained in:
Guillaume Ballet 2026-03-05 18:30:20 +01:00
parent a0fb8102fe
commit ad2125ecb2
11 changed files with 73 additions and 39 deletions

View file

@ -91,7 +91,7 @@ func SerializeNode(node BinaryNode) []byte {
var invalidSerializedLength = errors.New("invalid serialized node length")
// DeserializeNode deserializes a binary trie node from a byte slice.
func DeserializeNode(serialized []byte, depth int) (BinaryNode, error) {
func DeserializeNode(serialized []byte, depth int, hn common.Hash) (BinaryNode, error) {
if len(serialized) == 0 {
return Empty{}, nil
}
@ -105,6 +105,7 @@ func DeserializeNode(serialized []byte, depth int) (BinaryNode, error) {
depth: depth,
left: HashedNode(common.BytesToHash(serialized[1:33])),
right: HashedNode(common.BytesToHash(serialized[33:65])),
hash: hn,
}, nil
case nodeTypeStem:
if len(serialized) < 64 {
@ -127,6 +128,7 @@ func DeserializeNode(serialized []byte, depth int) (BinaryNode, error) {
Stem: serialized[NodeTypeBytes : NodeTypeBytes+StemSize],
Values: values[:],
depth: depth,
hash: hn,
}, nil
default:
return nil, errors.New("invalid node type")

View file

@ -48,7 +48,7 @@ func TestSerializeDeserializeInternalNode(t *testing.T) {
}
// Deserialize the node
deserialized, err := DeserializeNode(serialized, 5)
deserialized, err := DeserializeNode(serialized, 5, common.Hash{})
if err != nil {
t.Fatalf("Failed to deserialize node: %v", err)
}
@ -108,7 +108,7 @@ func TestSerializeDeserializeStemNode(t *testing.T) {
}
// Deserialize the node
deserialized, err := DeserializeNode(serialized, 10)
deserialized, err := DeserializeNode(serialized, 10, common.Hash{})
if err != nil {
t.Fatalf("Failed to deserialize node: %v", err)
}
@ -149,7 +149,7 @@ func TestSerializeDeserializeStemNode(t *testing.T) {
// TestDeserializeEmptyNode tests deserialization of empty node
func TestDeserializeEmptyNode(t *testing.T) {
// Empty byte slice should deserialize to Empty node
deserialized, err := DeserializeNode([]byte{}, 0)
deserialized, err := DeserializeNode([]byte{}, 0, common.Hash{})
if err != nil {
t.Fatalf("Failed to deserialize empty node: %v", err)
}
@ -165,7 +165,7 @@ func TestDeserializeInvalidType(t *testing.T) {
// Create invalid serialized data with unknown type byte
invalidData := []byte{99, 0, 0, 0} // Type byte 99 is invalid
_, err := DeserializeNode(invalidData, 0)
_, err := DeserializeNode(invalidData, 0, common.Hash{})
if err == nil {
t.Fatal("Expected error for invalid type byte, got nil")
}
@ -176,7 +176,7 @@ func TestDeserializeInvalidLength(t *testing.T) {
// InternalNode with type byte 1 but wrong length
invalidData := []byte{nodeTypeInternal, 0, 0} // Too short for internal node
_, err := DeserializeNode(invalidData, 0)
_, err := DeserializeNode(invalidData, 0, common.Hash{})
if err == nil {
t.Fatal("Expected error for invalid data length, got nil")
}

View file

@ -32,9 +32,10 @@ func (e Empty) Insert(key []byte, value []byte, _ NodeResolverFn, depth int) (Bi
var values [256][]byte
values[key[31]] = value
return &StemNode{
Stem: slices.Clone(key[:31]),
Values: values[:],
depth: depth,
Stem: slices.Clone(key[:31]),
Values: values[:],
depth: depth,
mustRecompute: true,
}, nil
}
@ -53,9 +54,10 @@ func (e Empty) GetValuesAtStem(_ []byte, _ NodeResolverFn) ([][]byte, error) {
func (e Empty) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolverFn, depth int) (BinaryNode, error) {
return &StemNode{
Stem: slices.Clone(key[:31]),
Values: values,
depth: depth,
Stem: slices.Clone(key[:31]),
Values: values,
depth: depth,
mustRecompute: true,
}, nil
}

View file

@ -64,7 +64,7 @@ func (h HashedNode) InsertValuesAtStem(stem []byte, values [][]byte, resolver No
}
// Step 3: Deserialize the resolved data into a concrete node
node, err := DeserializeNode(data, depth)
node, err := DeserializeNode(data, depth, common.Hash(h))
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
}

View file

@ -40,6 +40,9 @@ func keyToPath(depth int, key []byte) ([]byte, error) {
type InternalNode struct {
left, right BinaryNode
depth int
mustRecompute bool // true if the hash needs to be recomputed
hash common.Hash // cached hash when mustRecompute == false
}
// GetValuesAtStem retrieves the group of values located at the given stem key.
@ -59,7 +62,7 @@ func (bt *InternalNode) GetValuesAtStem(stem []byte, resolver NodeResolverFn) ([
if err != nil {
return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
}
node, err := DeserializeNode(data, bt.depth+1)
node, err := DeserializeNode(data, bt.depth+1, common.Hash(hn))
if err != nil {
return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err)
}
@ -77,7 +80,7 @@ func (bt *InternalNode) GetValuesAtStem(stem []byte, resolver NodeResolverFn) ([
if err != nil {
return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
}
node, err := DeserializeNode(data, bt.depth+1)
node, err := DeserializeNode(data, bt.depth+1, common.Hash(hn))
if err != nil {
return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err)
}
@ -108,14 +111,20 @@ func (bt *InternalNode) Insert(key []byte, value []byte, resolver NodeResolverFn
// Copy creates a deep copy of the node.
func (bt *InternalNode) Copy() BinaryNode {
return &InternalNode{
left: bt.left.Copy(),
right: bt.right.Copy(),
depth: bt.depth,
left: bt.left.Copy(),
right: bt.right.Copy(),
depth: bt.depth,
mustRecompute: bt.mustRecompute,
hash: bt.hash,
}
}
// Hash returns the hash of the node.
func (bt *InternalNode) Hash() common.Hash {
if !bt.mustRecompute {
return bt.hash
}
h := sha256.New()
if bt.left != nil {
h.Write(bt.left.Hash().Bytes())
@ -127,7 +136,9 @@ func (bt *InternalNode) Hash() common.Hash {
} else {
h.Write(zero[:])
}
return common.BytesToHash(h.Sum(nil))
bt.hash = common.BytesToHash(h.Sum(nil))
bt.mustRecompute = false
return bt.hash
}
// InsertValuesAtStem inserts a full value group at the given stem in the internal node.
@ -149,7 +160,7 @@ func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolve
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
}
node, err := DeserializeNode(data, bt.depth+1)
node, err := DeserializeNode(data, bt.depth+1, common.Hash(hn))
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
}
@ -157,6 +168,7 @@ func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolve
}
bt.left, err = bt.left.InsertValuesAtStem(stem, values, resolver, depth+1)
bt.mustRecompute = true
return bt, err
}
@ -173,7 +185,7 @@ func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolve
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
}
node, err := DeserializeNode(data, bt.depth+1)
node, err := DeserializeNode(data, bt.depth+1, common.Hash(hn))
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
}
@ -181,6 +193,7 @@ func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolve
}
bt.right, err = bt.right.InsertValuesAtStem(stem, values, resolver, depth+1)
bt.mustRecompute = true
return bt, err
}

View file

@ -239,6 +239,7 @@ func TestInternalNodeHash(t *testing.T) {
// Changing a child should change the hash
node.left = HashedNode(common.HexToHash("0x3333"))
node.mustRecompute = true
hash3 := node.Hash()
if hash1 == hash3 {
t.Error("Hash didn't change after modifying left child")
@ -246,9 +247,10 @@ func TestInternalNodeHash(t *testing.T) {
// Test with nil children (should use zero hash)
nodeWithNil := &InternalNode{
depth: 0,
left: nil,
right: HashedNode(common.HexToHash("0x4444")),
depth: 0,
left: nil,
right: HashedNode(common.HexToHash("0x4444")),
mustRecompute: true,
}
hashWithNil := nodeWithNil.Hash()
if hashWithNil == (common.Hash{}) {

View file

@ -123,7 +123,7 @@ func (it *binaryNodeIterator) Next(descend bool) bool {
if err != nil {
panic(err)
}
it.current, err = DeserializeNode(data, len(it.stack)-1)
it.current, err = DeserializeNode(data, len(it.stack)-1, common.Hash(node))
if err != nil {
panic(err)
}

View file

@ -31,6 +31,9 @@ type StemNode struct {
Stem []byte // Stem path to get to StemNodeWidth values
Values [][]byte // All values, indexed by the last byte of the key.
depth int // Depth of the node
mustRecompute bool // true if the hash needs to be recomputed
hash common.Hash // cached hash when mustRecompute == false
}
// Get retrieves the value for the given key.
@ -43,7 +46,7 @@ func (bt *StemNode) Insert(key []byte, value []byte, _ NodeResolverFn, depth int
if !bytes.Equal(bt.Stem, key[:StemSize]) {
bitStem := bt.Stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1
n := &InternalNode{depth: bt.depth}
n := &InternalNode{depth: bt.depth, mustRecompute: true}
bt.depth++
var child, other *BinaryNode
if bitStem == 0 {
@ -68,9 +71,10 @@ func (bt *StemNode) Insert(key []byte, value []byte, _ NodeResolverFn, depth int
var values [StemNodeWidth][]byte
values[key[StemSize]] = value
*other = &StemNode{
Stem: slices.Clone(key[:StemSize]),
Values: values[:],
depth: depth + 1,
Stem: slices.Clone(key[:StemSize]),
Values: values[:],
depth: depth + 1,
mustRecompute: true,
}
}
return n, nil
@ -89,9 +93,11 @@ func (bt *StemNode) Copy() BinaryNode {
values[i] = slices.Clone(v)
}
return &StemNode{
Stem: slices.Clone(bt.Stem),
Values: values[:],
depth: bt.depth,
Stem: slices.Clone(bt.Stem),
Values: values[:],
depth: bt.depth,
hash: bt.hash,
mustRecompute: bt.mustRecompute,
}
}
@ -102,6 +108,10 @@ func (bt *StemNode) GetHeight() int {
// Hash returns the hash of the node.
func (bt *StemNode) Hash() common.Hash {
if !bt.mustRecompute {
return bt.hash
}
var data [StemNodeWidth]common.Hash
for i, v := range bt.Values {
if v != nil {
@ -130,7 +140,9 @@ func (bt *StemNode) Hash() common.Hash {
h.Write(bt.Stem)
h.Write([]byte{0})
h.Write(data[0][:])
return common.BytesToHash(h.Sum(nil))
bt.hash = common.BytesToHash(h.Sum(nil))
bt.mustRecompute = false
return bt.hash
}
// CollectNodes collects all child nodes at a given path, and flushes it
@ -154,7 +166,7 @@ func (bt *StemNode) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolv
if !bytes.Equal(bt.Stem, key[:StemSize]) {
bitStem := bt.Stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1
n := &InternalNode{depth: bt.depth}
n := &InternalNode{depth: bt.depth, mustRecompute: true}
bt.depth++
var child, other *BinaryNode
if bitStem == 0 {
@ -177,9 +189,10 @@ func (bt *StemNode) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolv
*other = Empty{}
} else {
*other = &StemNode{
Stem: slices.Clone(key[:StemSize]),
Values: values,
depth: n.depth + 1,
Stem: slices.Clone(key[:StemSize]),
Values: values,
depth: n.depth + 1,
mustRecompute: true,
}
}
return n, nil
@ -189,6 +202,7 @@ func (bt *StemNode) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolv
for i, v := range values {
if v != nil {
bt.Values[i] = v
bt.mustRecompute = true
}
}
return bt, nil

View file

@ -220,6 +220,7 @@ func TestStemNodeHash(t *testing.T) {
// Changing a value should change the hash
node.Values[1] = common.HexToHash("0x0202").Bytes()
node.mustRecompute = true
hash3 := node.Hash()
if hash1 == hash3 {
t.Error("Hash didn't change after modifying values")

View file

@ -143,7 +143,7 @@ func NewBinaryTrie(root common.Hash, db database.NodeDatabase) (*BinaryTrie, err
if err != nil {
return nil, err
}
node, err := DeserializeNode(blob, 0)
node, err := DeserializeNode(blob, 0, root)
if err != nil {
return nil, err
}

View file

@ -102,7 +102,7 @@ func binaryNodeHasher(blob []byte) (common.Hash, error) {
if len(blob) == 0 {
return types.EmptyVerkleHash, nil
}
n, err := bintrie.DeserializeNode(blob, 0)
n, err := bintrie.DeserializeNode(blob, 0, common.Hash{})
if err != nil {
return common.Hash{}, err
}