diff --git a/trie/bintrie/binary_node.go b/trie/bintrie/binary_node.go index 690489b2aa..214eb24dd2 100644 --- a/trie/bintrie/binary_node.go +++ b/trie/bintrie/binary_node.go @@ -91,7 +91,7 @@ func SerializeNode(node BinaryNode) []byte { var invalidSerializedLength = errors.New("invalid serialized node length") // DeserializeNode deserializes a binary trie node from a byte slice. -func DeserializeNode(serialized []byte, depth int) (BinaryNode, error) { +func DeserializeNode(serialized []byte, depth int, hn common.Hash) (BinaryNode, error) { if len(serialized) == 0 { return Empty{}, nil } @@ -105,6 +105,7 @@ func DeserializeNode(serialized []byte, depth int) (BinaryNode, error) { depth: depth, left: HashedNode(common.BytesToHash(serialized[1:33])), right: HashedNode(common.BytesToHash(serialized[33:65])), + hash: hn, }, nil case nodeTypeStem: if len(serialized) < 64 { @@ -127,6 +128,7 @@ func DeserializeNode(serialized []byte, depth int) (BinaryNode, error) { Stem: serialized[NodeTypeBytes : NodeTypeBytes+StemSize], Values: values[:], depth: depth, + hash: hn, }, nil default: return nil, errors.New("invalid node type") diff --git a/trie/bintrie/binary_node_test.go b/trie/bintrie/binary_node_test.go index 242743ba53..4b1c4a33de 100644 --- a/trie/bintrie/binary_node_test.go +++ b/trie/bintrie/binary_node_test.go @@ -48,7 +48,7 @@ func TestSerializeDeserializeInternalNode(t *testing.T) { } // Deserialize the node - deserialized, err := DeserializeNode(serialized, 5) + deserialized, err := DeserializeNode(serialized, 5, common.Hash{}) if err != nil { t.Fatalf("Failed to deserialize node: %v", err) } @@ -108,7 +108,7 @@ func TestSerializeDeserializeStemNode(t *testing.T) { } // Deserialize the node - deserialized, err := DeserializeNode(serialized, 10) + deserialized, err := DeserializeNode(serialized, 10, common.Hash{}) if err != nil { t.Fatalf("Failed to deserialize node: %v", err) } @@ -149,7 +149,7 @@ func TestSerializeDeserializeStemNode(t *testing.T) { // TestDeserializeEmptyNode tests deserialization of empty node func TestDeserializeEmptyNode(t *testing.T) { // Empty byte slice should deserialize to Empty node - deserialized, err := DeserializeNode([]byte{}, 0) + deserialized, err := DeserializeNode([]byte{}, 0, common.Hash{}) if err != nil { t.Fatalf("Failed to deserialize empty node: %v", err) } @@ -165,7 +165,7 @@ func TestDeserializeInvalidType(t *testing.T) { // Create invalid serialized data with unknown type byte invalidData := []byte{99, 0, 0, 0} // Type byte 99 is invalid - _, err := DeserializeNode(invalidData, 0) + _, err := DeserializeNode(invalidData, 0, common.Hash{}) if err == nil { t.Fatal("Expected error for invalid type byte, got nil") } @@ -176,7 +176,7 @@ func TestDeserializeInvalidLength(t *testing.T) { // InternalNode with type byte 1 but wrong length invalidData := []byte{nodeTypeInternal, 0, 0} // Too short for internal node - _, err := DeserializeNode(invalidData, 0) + _, err := DeserializeNode(invalidData, 0, common.Hash{}) if err == nil { t.Fatal("Expected error for invalid data length, got nil") } diff --git a/trie/bintrie/empty.go b/trie/bintrie/empty.go index 7cfe373b35..252146a4a7 100644 --- a/trie/bintrie/empty.go +++ b/trie/bintrie/empty.go @@ -32,9 +32,10 @@ func (e Empty) Insert(key []byte, value []byte, _ NodeResolverFn, depth int) (Bi var values [256][]byte values[key[31]] = value return &StemNode{ - Stem: slices.Clone(key[:31]), - Values: values[:], - depth: depth, + Stem: slices.Clone(key[:31]), + Values: values[:], + depth: depth, + mustRecompute: true, }, nil } @@ -53,9 +54,10 @@ func (e Empty) GetValuesAtStem(_ []byte, _ NodeResolverFn) ([][]byte, error) { func (e Empty) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolverFn, depth int) (BinaryNode, error) { return &StemNode{ - Stem: slices.Clone(key[:31]), - Values: values, - depth: depth, + Stem: slices.Clone(key[:31]), + Values: values, + depth: depth, + mustRecompute: true, }, nil } diff --git a/trie/bintrie/hashed_node.go b/trie/bintrie/hashed_node.go index e4d8c2e7ac..0d16543844 100644 --- a/trie/bintrie/hashed_node.go +++ b/trie/bintrie/hashed_node.go @@ -64,7 +64,7 @@ func (h HashedNode) InsertValuesAtStem(stem []byte, values [][]byte, resolver No } // Step 3: Deserialize the resolved data into a concrete node - node, err := DeserializeNode(data, depth) + node, err := DeserializeNode(data, depth, common.Hash(h)) if err != nil { return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err) } diff --git a/trie/bintrie/internal_node.go b/trie/bintrie/internal_node.go index 0a7bece521..797e639d81 100644 --- a/trie/bintrie/internal_node.go +++ b/trie/bintrie/internal_node.go @@ -40,6 +40,9 @@ func keyToPath(depth int, key []byte) ([]byte, error) { type InternalNode struct { left, right BinaryNode depth int + + mustRecompute bool // true if the hash needs to be recomputed + hash common.Hash // cached hash when mustRecompute == false } // GetValuesAtStem retrieves the group of values located at the given stem key. @@ -59,7 +62,7 @@ func (bt *InternalNode) GetValuesAtStem(stem []byte, resolver NodeResolverFn) ([ if err != nil { return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err) } - node, err := DeserializeNode(data, bt.depth+1) + node, err := DeserializeNode(data, bt.depth+1, common.Hash(hn)) if err != nil { return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err) } @@ -77,7 +80,7 @@ func (bt *InternalNode) GetValuesAtStem(stem []byte, resolver NodeResolverFn) ([ if err != nil { return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err) } - node, err := DeserializeNode(data, bt.depth+1) + node, err := DeserializeNode(data, bt.depth+1, common.Hash(hn)) if err != nil { return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err) } @@ -108,14 +111,20 @@ func (bt *InternalNode) Insert(key []byte, value []byte, resolver NodeResolverFn // Copy creates a deep copy of the node. func (bt *InternalNode) Copy() BinaryNode { return &InternalNode{ - left: bt.left.Copy(), - right: bt.right.Copy(), - depth: bt.depth, + left: bt.left.Copy(), + right: bt.right.Copy(), + depth: bt.depth, + mustRecompute: bt.mustRecompute, + hash: bt.hash, } } // Hash returns the hash of the node. func (bt *InternalNode) Hash() common.Hash { + if !bt.mustRecompute { + return bt.hash + } + h := sha256.New() if bt.left != nil { h.Write(bt.left.Hash().Bytes()) @@ -127,7 +136,9 @@ func (bt *InternalNode) Hash() common.Hash { } else { h.Write(zero[:]) } - return common.BytesToHash(h.Sum(nil)) + bt.hash = common.BytesToHash(h.Sum(nil)) + bt.mustRecompute = false + return bt.hash } // InsertValuesAtStem inserts a full value group at the given stem in the internal node. @@ -149,7 +160,7 @@ func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolve if err != nil { return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err) } - node, err := DeserializeNode(data, bt.depth+1) + node, err := DeserializeNode(data, bt.depth+1, common.Hash(hn)) if err != nil { return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err) } @@ -157,6 +168,7 @@ func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolve } bt.left, err = bt.left.InsertValuesAtStem(stem, values, resolver, depth+1) + bt.mustRecompute = true return bt, err } @@ -173,7 +185,7 @@ func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolve if err != nil { return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err) } - node, err := DeserializeNode(data, bt.depth+1) + node, err := DeserializeNode(data, bt.depth+1, common.Hash(hn)) if err != nil { return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err) } @@ -181,6 +193,7 @@ func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolve } bt.right, err = bt.right.InsertValuesAtStem(stem, values, resolver, depth+1) + bt.mustRecompute = true return bt, err } diff --git a/trie/bintrie/internal_node_test.go b/trie/bintrie/internal_node_test.go index 158d8b7147..69097483fd 100644 --- a/trie/bintrie/internal_node_test.go +++ b/trie/bintrie/internal_node_test.go @@ -239,6 +239,7 @@ func TestInternalNodeHash(t *testing.T) { // Changing a child should change the hash node.left = HashedNode(common.HexToHash("0x3333")) + node.mustRecompute = true hash3 := node.Hash() if hash1 == hash3 { t.Error("Hash didn't change after modifying left child") @@ -246,9 +247,10 @@ func TestInternalNodeHash(t *testing.T) { // Test with nil children (should use zero hash) nodeWithNil := &InternalNode{ - depth: 0, - left: nil, - right: HashedNode(common.HexToHash("0x4444")), + depth: 0, + left: nil, + right: HashedNode(common.HexToHash("0x4444")), + mustRecompute: true, } hashWithNil := nodeWithNil.Hash() if hashWithNil == (common.Hash{}) { diff --git a/trie/bintrie/iterator.go b/trie/bintrie/iterator.go index 9b863ed1e3..a97f8ddb8c 100644 --- a/trie/bintrie/iterator.go +++ b/trie/bintrie/iterator.go @@ -123,7 +123,7 @@ func (it *binaryNodeIterator) Next(descend bool) bool { if err != nil { panic(err) } - it.current, err = DeserializeNode(data, len(it.stack)-1) + it.current, err = DeserializeNode(data, len(it.stack)-1, common.Hash(node)) if err != nil { panic(err) } diff --git a/trie/bintrie/stem_node.go b/trie/bintrie/stem_node.go index 60856b42ce..1dd4cdbedf 100644 --- a/trie/bintrie/stem_node.go +++ b/trie/bintrie/stem_node.go @@ -31,6 +31,9 @@ type StemNode struct { Stem []byte // Stem path to get to StemNodeWidth values Values [][]byte // All values, indexed by the last byte of the key. depth int // Depth of the node + + mustRecompute bool // true if the hash needs to be recomputed + hash common.Hash // cached hash when mustRecompute == false } // Get retrieves the value for the given key. @@ -43,7 +46,7 @@ func (bt *StemNode) Insert(key []byte, value []byte, _ NodeResolverFn, depth int if !bytes.Equal(bt.Stem, key[:StemSize]) { bitStem := bt.Stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1 - n := &InternalNode{depth: bt.depth} + n := &InternalNode{depth: bt.depth, mustRecompute: true} bt.depth++ var child, other *BinaryNode if bitStem == 0 { @@ -68,9 +71,10 @@ func (bt *StemNode) Insert(key []byte, value []byte, _ NodeResolverFn, depth int var values [StemNodeWidth][]byte values[key[StemSize]] = value *other = &StemNode{ - Stem: slices.Clone(key[:StemSize]), - Values: values[:], - depth: depth + 1, + Stem: slices.Clone(key[:StemSize]), + Values: values[:], + depth: depth + 1, + mustRecompute: true, } } return n, nil @@ -89,9 +93,11 @@ func (bt *StemNode) Copy() BinaryNode { values[i] = slices.Clone(v) } return &StemNode{ - Stem: slices.Clone(bt.Stem), - Values: values[:], - depth: bt.depth, + Stem: slices.Clone(bt.Stem), + Values: values[:], + depth: bt.depth, + hash: bt.hash, + mustRecompute: bt.mustRecompute, } } @@ -102,6 +108,10 @@ func (bt *StemNode) GetHeight() int { // Hash returns the hash of the node. func (bt *StemNode) Hash() common.Hash { + if !bt.mustRecompute { + return bt.hash + } + var data [StemNodeWidth]common.Hash for i, v := range bt.Values { if v != nil { @@ -130,7 +140,9 @@ func (bt *StemNode) Hash() common.Hash { h.Write(bt.Stem) h.Write([]byte{0}) h.Write(data[0][:]) - return common.BytesToHash(h.Sum(nil)) + bt.hash = common.BytesToHash(h.Sum(nil)) + bt.mustRecompute = false + return bt.hash } // CollectNodes collects all child nodes at a given path, and flushes it @@ -154,7 +166,7 @@ func (bt *StemNode) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolv if !bytes.Equal(bt.Stem, key[:StemSize]) { bitStem := bt.Stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1 - n := &InternalNode{depth: bt.depth} + n := &InternalNode{depth: bt.depth, mustRecompute: true} bt.depth++ var child, other *BinaryNode if bitStem == 0 { @@ -177,9 +189,10 @@ func (bt *StemNode) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolv *other = Empty{} } else { *other = &StemNode{ - Stem: slices.Clone(key[:StemSize]), - Values: values, - depth: n.depth + 1, + Stem: slices.Clone(key[:StemSize]), + Values: values, + depth: n.depth + 1, + mustRecompute: true, } } return n, nil @@ -189,6 +202,7 @@ func (bt *StemNode) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolv for i, v := range values { if v != nil { bt.Values[i] = v + bt.mustRecompute = true } } return bt, nil diff --git a/trie/bintrie/stem_node_test.go b/trie/bintrie/stem_node_test.go index d8d6844427..92c1b49e02 100644 --- a/trie/bintrie/stem_node_test.go +++ b/trie/bintrie/stem_node_test.go @@ -220,6 +220,7 @@ func TestStemNodeHash(t *testing.T) { // Changing a value should change the hash node.Values[1] = common.HexToHash("0x0202").Bytes() + node.mustRecompute = true hash3 := node.Hash() if hash1 == hash3 { t.Error("Hash didn't change after modifying values") diff --git a/trie/bintrie/trie.go b/trie/bintrie/trie.go index a509c471b8..1bb2928e96 100644 --- a/trie/bintrie/trie.go +++ b/trie/bintrie/trie.go @@ -143,7 +143,7 @@ func NewBinaryTrie(root common.Hash, db database.NodeDatabase) (*BinaryTrie, err if err != nil { return nil, err } - node, err := DeserializeNode(blob, 0) + node, err := DeserializeNode(blob, 0, root) if err != nil { return nil, err } diff --git a/triedb/pathdb/database.go b/triedb/pathdb/database.go index 5255602a4e..a5dd43260c 100644 --- a/triedb/pathdb/database.go +++ b/triedb/pathdb/database.go @@ -102,7 +102,7 @@ func binaryNodeHasher(blob []byte) (common.Hash, error) { if len(blob) == 0 { return types.EmptyVerkleHash, nil } - n, err := bintrie.DeserializeNode(blob, 0) + n, err := bintrie.DeserializeNode(blob, 0, common.Hash{}) if err != nil { return common.Hash{}, err }