trie/bintrie: cache hashes of clean nodes so as not to rehash the whole tree (#33961)

This is an optimization that existed for verkle and the MPT, but that
got dropped during the rebase.

Mark the nodes that were modified as needing recomputation, and skip the
hash computation if this is not needed. Otherwise, the whole tree is
hashed, which kills performance.
This commit is contained in:
Guillaume Ballet 2026-03-06 18:06:24 +01:00 committed by GitHub
parent a0fb8102fe
commit 3f1871524f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 86 additions and 39 deletions

View file

@ -90,8 +90,18 @@ func SerializeNode(node BinaryNode) []byte {
var invalidSerializedLength = errors.New("invalid serialized node length")
// DeserializeNode deserializes a binary trie node from a byte slice.
// DeserializeNode deserializes a binary trie node from a byte slice. The
// hash will be recomputed from the deserialized data.
func DeserializeNode(serialized []byte, depth int) (BinaryNode, error) {
return deserializeNode(serialized, depth, common.Hash{}, true)
}
// DeserializeNodeWithHash deserializes a binary trie node from a byte slice, using the provided hash.
func DeserializeNodeWithHash(serialized []byte, depth int, hn common.Hash) (BinaryNode, error) {
return deserializeNode(serialized, depth, hn, false)
}
func deserializeNode(serialized []byte, depth int, hn common.Hash, mustRecompute bool) (BinaryNode, error) {
if len(serialized) == 0 {
return Empty{}, nil
}
@ -102,9 +112,11 @@ func DeserializeNode(serialized []byte, depth int) (BinaryNode, error) {
return nil, invalidSerializedLength
}
return &InternalNode{
depth: depth,
left: HashedNode(common.BytesToHash(serialized[1:33])),
right: HashedNode(common.BytesToHash(serialized[33:65])),
depth: depth,
left: HashedNode(common.BytesToHash(serialized[1:33])),
right: HashedNode(common.BytesToHash(serialized[33:65])),
hash: hn,
mustRecompute: mustRecompute,
}, nil
case nodeTypeStem:
if len(serialized) < 64 {
@ -124,9 +136,11 @@ func DeserializeNode(serialized []byte, depth int) (BinaryNode, error) {
}
}
return &StemNode{
Stem: serialized[NodeTypeBytes : NodeTypeBytes+StemSize],
Values: values[:],
depth: depth,
Stem: serialized[NodeTypeBytes : NodeTypeBytes+StemSize],
Values: values[:],
depth: depth,
hash: hn,
mustRecompute: mustRecompute,
}, nil
default:
return nil, errors.New("invalid node type")

View file

@ -32,9 +32,10 @@ func (e Empty) Insert(key []byte, value []byte, _ NodeResolverFn, depth int) (Bi
var values [256][]byte
values[key[31]] = value
return &StemNode{
Stem: slices.Clone(key[:31]),
Values: values[:],
depth: depth,
Stem: slices.Clone(key[:31]),
Values: values[:],
depth: depth,
mustRecompute: true,
}, nil
}
@ -53,9 +54,10 @@ func (e Empty) GetValuesAtStem(_ []byte, _ NodeResolverFn) ([][]byte, error) {
func (e Empty) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolverFn, depth int) (BinaryNode, error) {
return &StemNode{
Stem: slices.Clone(key[:31]),
Values: values,
depth: depth,
Stem: slices.Clone(key[:31]),
Values: values,
depth: depth,
mustRecompute: true,
}, nil
}

View file

@ -64,7 +64,7 @@ func (h HashedNode) InsertValuesAtStem(stem []byte, values [][]byte, resolver No
}
// Step 3: Deserialize the resolved data into a concrete node
node, err := DeserializeNode(data, depth)
node, err := DeserializeNodeWithHash(data, depth, common.Hash(h))
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
}

View file

@ -40,6 +40,9 @@ func keyToPath(depth int, key []byte) ([]byte, error) {
type InternalNode struct {
left, right BinaryNode
depth int
mustRecompute bool // true if the hash needs to be recomputed
hash common.Hash // cached hash when mustRecompute == false
}
// GetValuesAtStem retrieves the group of values located at the given stem key.
@ -59,7 +62,7 @@ func (bt *InternalNode) GetValuesAtStem(stem []byte, resolver NodeResolverFn) ([
if err != nil {
return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
}
node, err := DeserializeNode(data, bt.depth+1)
node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn))
if err != nil {
return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err)
}
@ -77,7 +80,7 @@ func (bt *InternalNode) GetValuesAtStem(stem []byte, resolver NodeResolverFn) ([
if err != nil {
return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
}
node, err := DeserializeNode(data, bt.depth+1)
node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn))
if err != nil {
return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err)
}
@ -108,14 +111,20 @@ func (bt *InternalNode) Insert(key []byte, value []byte, resolver NodeResolverFn
// Copy creates a deep copy of the node.
func (bt *InternalNode) Copy() BinaryNode {
return &InternalNode{
left: bt.left.Copy(),
right: bt.right.Copy(),
depth: bt.depth,
left: bt.left.Copy(),
right: bt.right.Copy(),
depth: bt.depth,
mustRecompute: bt.mustRecompute,
hash: bt.hash,
}
}
// Hash returns the hash of the node.
func (bt *InternalNode) Hash() common.Hash {
if !bt.mustRecompute {
return bt.hash
}
h := sha256.New()
if bt.left != nil {
h.Write(bt.left.Hash().Bytes())
@ -127,7 +136,9 @@ func (bt *InternalNode) Hash() common.Hash {
} else {
h.Write(zero[:])
}
return common.BytesToHash(h.Sum(nil))
bt.hash = common.BytesToHash(h.Sum(nil))
bt.mustRecompute = false
return bt.hash
}
// InsertValuesAtStem inserts a full value group at the given stem in the internal node.
@ -149,7 +160,7 @@ func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolve
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
}
node, err := DeserializeNode(data, bt.depth+1)
node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn))
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
}
@ -157,6 +168,7 @@ func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolve
}
bt.left, err = bt.left.InsertValuesAtStem(stem, values, resolver, depth+1)
bt.mustRecompute = true
return bt, err
}
@ -173,7 +185,7 @@ func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolve
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
}
node, err := DeserializeNode(data, bt.depth+1)
node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn))
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
}
@ -181,6 +193,7 @@ func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolve
}
bt.right, err = bt.right.InsertValuesAtStem(stem, values, resolver, depth+1)
bt.mustRecompute = true
return bt, err
}

View file

@ -239,6 +239,7 @@ func TestInternalNodeHash(t *testing.T) {
// Changing a child should change the hash
node.left = HashedNode(common.HexToHash("0x3333"))
node.mustRecompute = true
hash3 := node.Hash()
if hash1 == hash3 {
t.Error("Hash didn't change after modifying left child")
@ -246,9 +247,10 @@ func TestInternalNodeHash(t *testing.T) {
// Test with nil children (should use zero hash)
nodeWithNil := &InternalNode{
depth: 0,
left: nil,
right: HashedNode(common.HexToHash("0x4444")),
depth: 0,
left: nil,
right: HashedNode(common.HexToHash("0x4444")),
mustRecompute: true,
}
hashWithNil := nodeWithNil.Hash()
if hashWithNil == (common.Hash{}) {

View file

@ -123,7 +123,7 @@ func (it *binaryNodeIterator) Next(descend bool) bool {
if err != nil {
panic(err)
}
it.current, err = DeserializeNode(data, len(it.stack)-1)
it.current, err = DeserializeNodeWithHash(data, len(it.stack)-1, common.Hash(node))
if err != nil {
panic(err)
}

View file

@ -31,6 +31,9 @@ type StemNode struct {
Stem []byte // Stem path to get to StemNodeWidth values
Values [][]byte // All values, indexed by the last byte of the key.
depth int // Depth of the node
mustRecompute bool // true if the hash needs to be recomputed
hash common.Hash // cached hash when mustRecompute == false
}
// Get retrieves the value for the given key.
@ -43,7 +46,7 @@ func (bt *StemNode) Insert(key []byte, value []byte, _ NodeResolverFn, depth int
if !bytes.Equal(bt.Stem, key[:StemSize]) {
bitStem := bt.Stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1
n := &InternalNode{depth: bt.depth}
n := &InternalNode{depth: bt.depth, mustRecompute: true}
bt.depth++
var child, other *BinaryNode
if bitStem == 0 {
@ -68,9 +71,10 @@ func (bt *StemNode) Insert(key []byte, value []byte, _ NodeResolverFn, depth int
var values [StemNodeWidth][]byte
values[key[StemSize]] = value
*other = &StemNode{
Stem: slices.Clone(key[:StemSize]),
Values: values[:],
depth: depth + 1,
Stem: slices.Clone(key[:StemSize]),
Values: values[:],
depth: depth + 1,
mustRecompute: true,
}
}
return n, nil
@ -79,6 +83,7 @@ func (bt *StemNode) Insert(key []byte, value []byte, _ NodeResolverFn, depth int
return bt, errors.New("invalid insertion: value length")
}
bt.Values[key[StemSize]] = value
bt.mustRecompute = true
return bt, nil
}
@ -89,9 +94,11 @@ func (bt *StemNode) Copy() BinaryNode {
values[i] = slices.Clone(v)
}
return &StemNode{
Stem: slices.Clone(bt.Stem),
Values: values[:],
depth: bt.depth,
Stem: slices.Clone(bt.Stem),
Values: values[:],
depth: bt.depth,
hash: bt.hash,
mustRecompute: bt.mustRecompute,
}
}
@ -102,6 +109,10 @@ func (bt *StemNode) GetHeight() int {
// Hash returns the hash of the node.
func (bt *StemNode) Hash() common.Hash {
if !bt.mustRecompute {
return bt.hash
}
var data [StemNodeWidth]common.Hash
for i, v := range bt.Values {
if v != nil {
@ -130,7 +141,9 @@ func (bt *StemNode) Hash() common.Hash {
h.Write(bt.Stem)
h.Write([]byte{0})
h.Write(data[0][:])
return common.BytesToHash(h.Sum(nil))
bt.hash = common.BytesToHash(h.Sum(nil))
bt.mustRecompute = false
return bt.hash
}
// CollectNodes collects all child nodes at a given path, and flushes it
@ -154,7 +167,7 @@ func (bt *StemNode) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolv
if !bytes.Equal(bt.Stem, key[:StemSize]) {
bitStem := bt.Stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1
n := &InternalNode{depth: bt.depth}
n := &InternalNode{depth: bt.depth, mustRecompute: true}
bt.depth++
var child, other *BinaryNode
if bitStem == 0 {
@ -177,9 +190,10 @@ func (bt *StemNode) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolv
*other = Empty{}
} else {
*other = &StemNode{
Stem: slices.Clone(key[:StemSize]),
Values: values,
depth: n.depth + 1,
Stem: slices.Clone(key[:StemSize]),
Values: values,
depth: n.depth + 1,
mustRecompute: true,
}
}
return n, nil
@ -189,6 +203,7 @@ func (bt *StemNode) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolv
for i, v := range values {
if v != nil {
bt.Values[i] = v
bt.mustRecompute = true
}
}
return bt, nil

View file

@ -220,6 +220,7 @@ func TestStemNodeHash(t *testing.T) {
// Changing a value should change the hash
node.Values[1] = common.HexToHash("0x0202").Bytes()
node.mustRecompute = true
hash3 := node.Hash()
if hash1 == hash3 {
t.Error("Hash didn't change after modifying values")

View file

@ -143,7 +143,7 @@ func NewBinaryTrie(root common.Hash, db database.NodeDatabase) (*BinaryTrie, err
if err != nil {
return nil, err
}
node, err := DeserializeNode(blob, 0)
node, err := DeserializeNodeWithHash(blob, 0, root)
if err != nil {
return nil, err
}