From e0643adc515182e96c39596a98ca38f1dfb19430 Mon Sep 17 00:00:00 2001 From: Guillaume Ballet <3272758+gballet@users.noreply.github.com> Date: Thu, 5 Mar 2026 21:45:28 +0100 Subject: [PATCH] review feedback + fix node hash computation --- trie/bintrie/binary_node.go | 32 ++++++++++++++++++++++---------- trie/bintrie/binary_node_test.go | 10 +++++----- trie/bintrie/hashed_node.go | 2 +- trie/bintrie/internal_node.go | 8 ++++---- trie/bintrie/iterator.go | 2 +- trie/bintrie/stem_node.go | 1 + trie/bintrie/trie.go | 2 +- triedb/pathdb/database.go | 2 +- 8 files changed, 36 insertions(+), 23 deletions(-) diff --git a/trie/bintrie/binary_node.go b/trie/bintrie/binary_node.go index 214eb24dd2..a7392ec958 100644 --- a/trie/bintrie/binary_node.go +++ b/trie/bintrie/binary_node.go @@ -90,8 +90,18 @@ func SerializeNode(node BinaryNode) []byte { var invalidSerializedLength = errors.New("invalid serialized node length") -// DeserializeNode deserializes a binary trie node from a byte slice. -func DeserializeNode(serialized []byte, depth int, hn common.Hash) (BinaryNode, error) { +// DeserializeNode deserializes a binary trie node from a byte slice. The +// hash will be recomputed from the deserialized data. +func DeserializeNode(serialized []byte, depth int) (BinaryNode, error) { + return deserializeNode(serialized, depth, common.Hash{}, true) +} + +// DeserializeNodeWithHash deserializes a binary trie node from a byte slice, using the provided hash. +func DeserializeNodeWithHash(serialized []byte, depth int, hn common.Hash) (BinaryNode, error) { + return deserializeNode(serialized, depth, hn, false) +} + +func deserializeNode(serialized []byte, depth int, hn common.Hash, mustRecompute bool) (BinaryNode, error) { if len(serialized) == 0 { return Empty{}, nil } @@ -102,10 +112,11 @@ func DeserializeNode(serialized []byte, depth int, hn common.Hash) (BinaryNode, return nil, invalidSerializedLength } return &InternalNode{ - depth: depth, - left: HashedNode(common.BytesToHash(serialized[1:33])), - right: HashedNode(common.BytesToHash(serialized[33:65])), - hash: hn, + depth: depth, + left: HashedNode(common.BytesToHash(serialized[1:33])), + right: HashedNode(common.BytesToHash(serialized[33:65])), + hash: hn, + mustRecompute: mustRecompute, }, nil case nodeTypeStem: if len(serialized) < 64 { @@ -125,10 +136,11 @@ func DeserializeNode(serialized []byte, depth int, hn common.Hash) (BinaryNode, } } return &StemNode{ - Stem: serialized[NodeTypeBytes : NodeTypeBytes+StemSize], - Values: values[:], - depth: depth, - hash: hn, + Stem: serialized[NodeTypeBytes : NodeTypeBytes+StemSize], + Values: values[:], + depth: depth, + hash: hn, + mustRecompute: mustRecompute, }, nil default: return nil, errors.New("invalid node type") diff --git a/trie/bintrie/binary_node_test.go b/trie/bintrie/binary_node_test.go index 4b1c4a33de..242743ba53 100644 --- a/trie/bintrie/binary_node_test.go +++ b/trie/bintrie/binary_node_test.go @@ -48,7 +48,7 @@ func TestSerializeDeserializeInternalNode(t *testing.T) { } // Deserialize the node - deserialized, err := DeserializeNode(serialized, 5, common.Hash{}) + deserialized, err := DeserializeNode(serialized, 5) if err != nil { t.Fatalf("Failed to deserialize node: %v", err) } @@ -108,7 +108,7 @@ func TestSerializeDeserializeStemNode(t *testing.T) { } // Deserialize the node - deserialized, err := DeserializeNode(serialized, 10, common.Hash{}) + deserialized, err := DeserializeNode(serialized, 10) if err != nil { t.Fatalf("Failed to deserialize node: %v", err) } @@ -149,7 +149,7 @@ func TestSerializeDeserializeStemNode(t *testing.T) { // TestDeserializeEmptyNode tests deserialization of empty node func TestDeserializeEmptyNode(t *testing.T) { // Empty byte slice should deserialize to Empty node - deserialized, err := DeserializeNode([]byte{}, 0, common.Hash{}) + deserialized, err := DeserializeNode([]byte{}, 0) if err != nil { t.Fatalf("Failed to deserialize empty node: %v", err) } @@ -165,7 +165,7 @@ func TestDeserializeInvalidType(t *testing.T) { // Create invalid serialized data with unknown type byte invalidData := []byte{99, 0, 0, 0} // Type byte 99 is invalid - _, err := DeserializeNode(invalidData, 0, common.Hash{}) + _, err := DeserializeNode(invalidData, 0) if err == nil { t.Fatal("Expected error for invalid type byte, got nil") } @@ -176,7 +176,7 @@ func TestDeserializeInvalidLength(t *testing.T) { // InternalNode with type byte 1 but wrong length invalidData := []byte{nodeTypeInternal, 0, 0} // Too short for internal node - _, err := DeserializeNode(invalidData, 0, common.Hash{}) + _, err := DeserializeNode(invalidData, 0) if err == nil { t.Fatal("Expected error for invalid data length, got nil") } diff --git a/trie/bintrie/hashed_node.go b/trie/bintrie/hashed_node.go index 0d16543844..e44c6d1e8a 100644 --- a/trie/bintrie/hashed_node.go +++ b/trie/bintrie/hashed_node.go @@ -64,7 +64,7 @@ func (h HashedNode) InsertValuesAtStem(stem []byte, values [][]byte, resolver No } // Step 3: Deserialize the resolved data into a concrete node - node, err := DeserializeNode(data, depth, common.Hash(h)) + node, err := DeserializeNodeWithHash(data, depth, common.Hash(h)) if err != nil { return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err) } diff --git a/trie/bintrie/internal_node.go b/trie/bintrie/internal_node.go index 797e639d81..2d02e240be 100644 --- a/trie/bintrie/internal_node.go +++ b/trie/bintrie/internal_node.go @@ -62,7 +62,7 @@ func (bt *InternalNode) GetValuesAtStem(stem []byte, resolver NodeResolverFn) ([ if err != nil { return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err) } - node, err := DeserializeNode(data, bt.depth+1, common.Hash(hn)) + node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn)) if err != nil { return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err) } @@ -80,7 +80,7 @@ func (bt *InternalNode) GetValuesAtStem(stem []byte, resolver NodeResolverFn) ([ if err != nil { return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err) } - node, err := DeserializeNode(data, bt.depth+1, common.Hash(hn)) + node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn)) if err != nil { return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err) } @@ -160,7 +160,7 @@ func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolve if err != nil { return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err) } - node, err := DeserializeNode(data, bt.depth+1, common.Hash(hn)) + node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn)) if err != nil { return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err) } @@ -185,7 +185,7 @@ func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolve if err != nil { return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err) } - node, err := DeserializeNode(data, bt.depth+1, common.Hash(hn)) + node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn)) if err != nil { return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err) } diff --git a/trie/bintrie/iterator.go b/trie/bintrie/iterator.go index a97f8ddb8c..917f82efc9 100644 --- a/trie/bintrie/iterator.go +++ b/trie/bintrie/iterator.go @@ -123,7 +123,7 @@ func (it *binaryNodeIterator) Next(descend bool) bool { if err != nil { panic(err) } - it.current, err = DeserializeNode(data, len(it.stack)-1, common.Hash(node)) + it.current, err = DeserializeNodeWithHash(data, len(it.stack)-1, common.Hash(node)) if err != nil { panic(err) } diff --git a/trie/bintrie/stem_node.go b/trie/bintrie/stem_node.go index 1dd4cdbedf..f1ae2361ff 100644 --- a/trie/bintrie/stem_node.go +++ b/trie/bintrie/stem_node.go @@ -83,6 +83,7 @@ func (bt *StemNode) Insert(key []byte, value []byte, _ NodeResolverFn, depth int return bt, errors.New("invalid insertion: value length") } bt.Values[key[StemSize]] = value + bt.mustRecompute = true return bt, nil } diff --git a/trie/bintrie/trie.go b/trie/bintrie/trie.go index 1bb2928e96..6c29239a87 100644 --- a/trie/bintrie/trie.go +++ b/trie/bintrie/trie.go @@ -143,7 +143,7 @@ func NewBinaryTrie(root common.Hash, db database.NodeDatabase) (*BinaryTrie, err if err != nil { return nil, err } - node, err := DeserializeNode(blob, 0, root) + node, err := DeserializeNodeWithHash(blob, 0, root) if err != nil { return nil, err } diff --git a/triedb/pathdb/database.go b/triedb/pathdb/database.go index a5dd43260c..5255602a4e 100644 --- a/triedb/pathdb/database.go +++ b/triedb/pathdb/database.go @@ -102,7 +102,7 @@ func binaryNodeHasher(blob []byte) (common.Hash, error) { if len(blob) == 0 { return types.EmptyVerkleHash, nil } - n, err := bintrie.DeserializeNode(blob, 0, common.Hash{}) + n, err := bintrie.DeserializeNode(blob, 0) if err != nil { return common.Hash{}, err }