mirror of
https://github.com/ethereum/go-ethereum.git
synced 2026-06-19 13:21:37 +00:00
trie/bintrie: cache hashes of clean nodes so as not to rehash the whole tree
This commit is contained in:
parent
a0fb8102fe
commit
ad2125ecb2
11 changed files with 73 additions and 39 deletions
|
|
@ -91,7 +91,7 @@ func SerializeNode(node BinaryNode) []byte {
|
|||
var invalidSerializedLength = errors.New("invalid serialized node length")
|
||||
|
||||
// DeserializeNode deserializes a binary trie node from a byte slice.
|
||||
func DeserializeNode(serialized []byte, depth int) (BinaryNode, error) {
|
||||
func DeserializeNode(serialized []byte, depth int, hn common.Hash) (BinaryNode, error) {
|
||||
if len(serialized) == 0 {
|
||||
return Empty{}, nil
|
||||
}
|
||||
|
|
@ -105,6 +105,7 @@ func DeserializeNode(serialized []byte, depth int) (BinaryNode, error) {
|
|||
depth: depth,
|
||||
left: HashedNode(common.BytesToHash(serialized[1:33])),
|
||||
right: HashedNode(common.BytesToHash(serialized[33:65])),
|
||||
hash: hn,
|
||||
}, nil
|
||||
case nodeTypeStem:
|
||||
if len(serialized) < 64 {
|
||||
|
|
@ -127,6 +128,7 @@ func DeserializeNode(serialized []byte, depth int) (BinaryNode, error) {
|
|||
Stem: serialized[NodeTypeBytes : NodeTypeBytes+StemSize],
|
||||
Values: values[:],
|
||||
depth: depth,
|
||||
hash: hn,
|
||||
}, nil
|
||||
default:
|
||||
return nil, errors.New("invalid node type")
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ func TestSerializeDeserializeInternalNode(t *testing.T) {
|
|||
}
|
||||
|
||||
// Deserialize the node
|
||||
deserialized, err := DeserializeNode(serialized, 5)
|
||||
deserialized, err := DeserializeNode(serialized, 5, common.Hash{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to deserialize node: %v", err)
|
||||
}
|
||||
|
|
@ -108,7 +108,7 @@ func TestSerializeDeserializeStemNode(t *testing.T) {
|
|||
}
|
||||
|
||||
// Deserialize the node
|
||||
deserialized, err := DeserializeNode(serialized, 10)
|
||||
deserialized, err := DeserializeNode(serialized, 10, common.Hash{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to deserialize node: %v", err)
|
||||
}
|
||||
|
|
@ -149,7 +149,7 @@ func TestSerializeDeserializeStemNode(t *testing.T) {
|
|||
// TestDeserializeEmptyNode tests deserialization of empty node
|
||||
func TestDeserializeEmptyNode(t *testing.T) {
|
||||
// Empty byte slice should deserialize to Empty node
|
||||
deserialized, err := DeserializeNode([]byte{}, 0)
|
||||
deserialized, err := DeserializeNode([]byte{}, 0, common.Hash{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to deserialize empty node: %v", err)
|
||||
}
|
||||
|
|
@ -165,7 +165,7 @@ func TestDeserializeInvalidType(t *testing.T) {
|
|||
// Create invalid serialized data with unknown type byte
|
||||
invalidData := []byte{99, 0, 0, 0} // Type byte 99 is invalid
|
||||
|
||||
_, err := DeserializeNode(invalidData, 0)
|
||||
_, err := DeserializeNode(invalidData, 0, common.Hash{})
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for invalid type byte, got nil")
|
||||
}
|
||||
|
|
@ -176,7 +176,7 @@ func TestDeserializeInvalidLength(t *testing.T) {
|
|||
// InternalNode with type byte 1 but wrong length
|
||||
invalidData := []byte{nodeTypeInternal, 0, 0} // Too short for internal node
|
||||
|
||||
_, err := DeserializeNode(invalidData, 0)
|
||||
_, err := DeserializeNode(invalidData, 0, common.Hash{})
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for invalid data length, got nil")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -32,9 +32,10 @@ func (e Empty) Insert(key []byte, value []byte, _ NodeResolverFn, depth int) (Bi
|
|||
var values [256][]byte
|
||||
values[key[31]] = value
|
||||
return &StemNode{
|
||||
Stem: slices.Clone(key[:31]),
|
||||
Values: values[:],
|
||||
depth: depth,
|
||||
Stem: slices.Clone(key[:31]),
|
||||
Values: values[:],
|
||||
depth: depth,
|
||||
mustRecompute: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
@ -53,9 +54,10 @@ func (e Empty) GetValuesAtStem(_ []byte, _ NodeResolverFn) ([][]byte, error) {
|
|||
|
||||
func (e Empty) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolverFn, depth int) (BinaryNode, error) {
|
||||
return &StemNode{
|
||||
Stem: slices.Clone(key[:31]),
|
||||
Values: values,
|
||||
depth: depth,
|
||||
Stem: slices.Clone(key[:31]),
|
||||
Values: values,
|
||||
depth: depth,
|
||||
mustRecompute: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -64,7 +64,7 @@ func (h HashedNode) InsertValuesAtStem(stem []byte, values [][]byte, resolver No
|
|||
}
|
||||
|
||||
// Step 3: Deserialize the resolved data into a concrete node
|
||||
node, err := DeserializeNode(data, depth)
|
||||
node, err := DeserializeNode(data, depth, common.Hash(h))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -40,6 +40,9 @@ func keyToPath(depth int, key []byte) ([]byte, error) {
|
|||
type InternalNode struct {
|
||||
left, right BinaryNode
|
||||
depth int
|
||||
|
||||
mustRecompute bool // true if the hash needs to be recomputed
|
||||
hash common.Hash // cached hash when mustRecompute == false
|
||||
}
|
||||
|
||||
// GetValuesAtStem retrieves the group of values located at the given stem key.
|
||||
|
|
@ -59,7 +62,7 @@ func (bt *InternalNode) GetValuesAtStem(stem []byte, resolver NodeResolverFn) ([
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
|
||||
}
|
||||
node, err := DeserializeNode(data, bt.depth+1)
|
||||
node, err := DeserializeNode(data, bt.depth+1, common.Hash(hn))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err)
|
||||
}
|
||||
|
|
@ -77,7 +80,7 @@ func (bt *InternalNode) GetValuesAtStem(stem []byte, resolver NodeResolverFn) ([
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
|
||||
}
|
||||
node, err := DeserializeNode(data, bt.depth+1)
|
||||
node, err := DeserializeNode(data, bt.depth+1, common.Hash(hn))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err)
|
||||
}
|
||||
|
|
@ -108,14 +111,20 @@ func (bt *InternalNode) Insert(key []byte, value []byte, resolver NodeResolverFn
|
|||
// Copy creates a deep copy of the node.
|
||||
func (bt *InternalNode) Copy() BinaryNode {
|
||||
return &InternalNode{
|
||||
left: bt.left.Copy(),
|
||||
right: bt.right.Copy(),
|
||||
depth: bt.depth,
|
||||
left: bt.left.Copy(),
|
||||
right: bt.right.Copy(),
|
||||
depth: bt.depth,
|
||||
mustRecompute: bt.mustRecompute,
|
||||
hash: bt.hash,
|
||||
}
|
||||
}
|
||||
|
||||
// Hash returns the hash of the node.
|
||||
func (bt *InternalNode) Hash() common.Hash {
|
||||
if !bt.mustRecompute {
|
||||
return bt.hash
|
||||
}
|
||||
|
||||
h := sha256.New()
|
||||
if bt.left != nil {
|
||||
h.Write(bt.left.Hash().Bytes())
|
||||
|
|
@ -127,7 +136,9 @@ func (bt *InternalNode) Hash() common.Hash {
|
|||
} else {
|
||||
h.Write(zero[:])
|
||||
}
|
||||
return common.BytesToHash(h.Sum(nil))
|
||||
bt.hash = common.BytesToHash(h.Sum(nil))
|
||||
bt.mustRecompute = false
|
||||
return bt.hash
|
||||
}
|
||||
|
||||
// InsertValuesAtStem inserts a full value group at the given stem in the internal node.
|
||||
|
|
@ -149,7 +160,7 @@ func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolve
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
|
||||
}
|
||||
node, err := DeserializeNode(data, bt.depth+1)
|
||||
node, err := DeserializeNode(data, bt.depth+1, common.Hash(hn))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
|
||||
}
|
||||
|
|
@ -157,6 +168,7 @@ func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolve
|
|||
}
|
||||
|
||||
bt.left, err = bt.left.InsertValuesAtStem(stem, values, resolver, depth+1)
|
||||
bt.mustRecompute = true
|
||||
return bt, err
|
||||
}
|
||||
|
||||
|
|
@ -173,7 +185,7 @@ func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolve
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
|
||||
}
|
||||
node, err := DeserializeNode(data, bt.depth+1)
|
||||
node, err := DeserializeNode(data, bt.depth+1, common.Hash(hn))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
|
||||
}
|
||||
|
|
@ -181,6 +193,7 @@ func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolve
|
|||
}
|
||||
|
||||
bt.right, err = bt.right.InsertValuesAtStem(stem, values, resolver, depth+1)
|
||||
bt.mustRecompute = true
|
||||
return bt, err
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -239,6 +239,7 @@ func TestInternalNodeHash(t *testing.T) {
|
|||
|
||||
// Changing a child should change the hash
|
||||
node.left = HashedNode(common.HexToHash("0x3333"))
|
||||
node.mustRecompute = true
|
||||
hash3 := node.Hash()
|
||||
if hash1 == hash3 {
|
||||
t.Error("Hash didn't change after modifying left child")
|
||||
|
|
@ -246,9 +247,10 @@ func TestInternalNodeHash(t *testing.T) {
|
|||
|
||||
// Test with nil children (should use zero hash)
|
||||
nodeWithNil := &InternalNode{
|
||||
depth: 0,
|
||||
left: nil,
|
||||
right: HashedNode(common.HexToHash("0x4444")),
|
||||
depth: 0,
|
||||
left: nil,
|
||||
right: HashedNode(common.HexToHash("0x4444")),
|
||||
mustRecompute: true,
|
||||
}
|
||||
hashWithNil := nodeWithNil.Hash()
|
||||
if hashWithNil == (common.Hash{}) {
|
||||
|
|
|
|||
|
|
@ -123,7 +123,7 @@ func (it *binaryNodeIterator) Next(descend bool) bool {
|
|||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
it.current, err = DeserializeNode(data, len(it.stack)-1)
|
||||
it.current, err = DeserializeNode(data, len(it.stack)-1, common.Hash(node))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -31,6 +31,9 @@ type StemNode struct {
|
|||
Stem []byte // Stem path to get to StemNodeWidth values
|
||||
Values [][]byte // All values, indexed by the last byte of the key.
|
||||
depth int // Depth of the node
|
||||
|
||||
mustRecompute bool // true if the hash needs to be recomputed
|
||||
hash common.Hash // cached hash when mustRecompute == false
|
||||
}
|
||||
|
||||
// Get retrieves the value for the given key.
|
||||
|
|
@ -43,7 +46,7 @@ func (bt *StemNode) Insert(key []byte, value []byte, _ NodeResolverFn, depth int
|
|||
if !bytes.Equal(bt.Stem, key[:StemSize]) {
|
||||
bitStem := bt.Stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1
|
||||
|
||||
n := &InternalNode{depth: bt.depth}
|
||||
n := &InternalNode{depth: bt.depth, mustRecompute: true}
|
||||
bt.depth++
|
||||
var child, other *BinaryNode
|
||||
if bitStem == 0 {
|
||||
|
|
@ -68,9 +71,10 @@ func (bt *StemNode) Insert(key []byte, value []byte, _ NodeResolverFn, depth int
|
|||
var values [StemNodeWidth][]byte
|
||||
values[key[StemSize]] = value
|
||||
*other = &StemNode{
|
||||
Stem: slices.Clone(key[:StemSize]),
|
||||
Values: values[:],
|
||||
depth: depth + 1,
|
||||
Stem: slices.Clone(key[:StemSize]),
|
||||
Values: values[:],
|
||||
depth: depth + 1,
|
||||
mustRecompute: true,
|
||||
}
|
||||
}
|
||||
return n, nil
|
||||
|
|
@ -89,9 +93,11 @@ func (bt *StemNode) Copy() BinaryNode {
|
|||
values[i] = slices.Clone(v)
|
||||
}
|
||||
return &StemNode{
|
||||
Stem: slices.Clone(bt.Stem),
|
||||
Values: values[:],
|
||||
depth: bt.depth,
|
||||
Stem: slices.Clone(bt.Stem),
|
||||
Values: values[:],
|
||||
depth: bt.depth,
|
||||
hash: bt.hash,
|
||||
mustRecompute: bt.mustRecompute,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -102,6 +108,10 @@ func (bt *StemNode) GetHeight() int {
|
|||
|
||||
// Hash returns the hash of the node.
|
||||
func (bt *StemNode) Hash() common.Hash {
|
||||
if !bt.mustRecompute {
|
||||
return bt.hash
|
||||
}
|
||||
|
||||
var data [StemNodeWidth]common.Hash
|
||||
for i, v := range bt.Values {
|
||||
if v != nil {
|
||||
|
|
@ -130,7 +140,9 @@ func (bt *StemNode) Hash() common.Hash {
|
|||
h.Write(bt.Stem)
|
||||
h.Write([]byte{0})
|
||||
h.Write(data[0][:])
|
||||
return common.BytesToHash(h.Sum(nil))
|
||||
bt.hash = common.BytesToHash(h.Sum(nil))
|
||||
bt.mustRecompute = false
|
||||
return bt.hash
|
||||
}
|
||||
|
||||
// CollectNodes collects all child nodes at a given path, and flushes it
|
||||
|
|
@ -154,7 +166,7 @@ func (bt *StemNode) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolv
|
|||
if !bytes.Equal(bt.Stem, key[:StemSize]) {
|
||||
bitStem := bt.Stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1
|
||||
|
||||
n := &InternalNode{depth: bt.depth}
|
||||
n := &InternalNode{depth: bt.depth, mustRecompute: true}
|
||||
bt.depth++
|
||||
var child, other *BinaryNode
|
||||
if bitStem == 0 {
|
||||
|
|
@ -177,9 +189,10 @@ func (bt *StemNode) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolv
|
|||
*other = Empty{}
|
||||
} else {
|
||||
*other = &StemNode{
|
||||
Stem: slices.Clone(key[:StemSize]),
|
||||
Values: values,
|
||||
depth: n.depth + 1,
|
||||
Stem: slices.Clone(key[:StemSize]),
|
||||
Values: values,
|
||||
depth: n.depth + 1,
|
||||
mustRecompute: true,
|
||||
}
|
||||
}
|
||||
return n, nil
|
||||
|
|
@ -189,6 +202,7 @@ func (bt *StemNode) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolv
|
|||
for i, v := range values {
|
||||
if v != nil {
|
||||
bt.Values[i] = v
|
||||
bt.mustRecompute = true
|
||||
}
|
||||
}
|
||||
return bt, nil
|
||||
|
|
|
|||
|
|
@ -220,6 +220,7 @@ func TestStemNodeHash(t *testing.T) {
|
|||
|
||||
// Changing a value should change the hash
|
||||
node.Values[1] = common.HexToHash("0x0202").Bytes()
|
||||
node.mustRecompute = true
|
||||
hash3 := node.Hash()
|
||||
if hash1 == hash3 {
|
||||
t.Error("Hash didn't change after modifying values")
|
||||
|
|
|
|||
|
|
@ -143,7 +143,7 @@ func NewBinaryTrie(root common.Hash, db database.NodeDatabase) (*BinaryTrie, err
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
node, err := DeserializeNode(blob, 0)
|
||||
node, err := DeserializeNode(blob, 0, root)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -102,7 +102,7 @@ func binaryNodeHasher(blob []byte) (common.Hash, error) {
|
|||
if len(blob) == 0 {
|
||||
return types.EmptyVerkleHash, nil
|
||||
}
|
||||
n, err := bintrie.DeserializeNode(blob, 0)
|
||||
n, err := bintrie.DeserializeNode(blob, 0, common.Hash{})
|
||||
if err != nil {
|
||||
return common.Hash{}, err
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue