review feedback + fix node hash computation

This commit is contained in:
Guillaume Ballet 2026-03-05 21:45:28 +01:00
parent ad2125ecb2
commit e0643adc51
8 changed files with 36 additions and 23 deletions

View file

@ -90,8 +90,18 @@ func SerializeNode(node BinaryNode) []byte {
var invalidSerializedLength = errors.New("invalid serialized node length")
// DeserializeNode deserializes a binary trie node from a byte slice.
func DeserializeNode(serialized []byte, depth int, hn common.Hash) (BinaryNode, error) {
// DeserializeNode deserializes a binary trie node from a byte slice. The
// hash will be recomputed from the deserialized data.
func DeserializeNode(serialized []byte, depth int) (BinaryNode, error) {
return deserializeNode(serialized, depth, common.Hash{}, true)
}
// DeserializeNodeWithHash deserializes a binary trie node from a byte slice, using the provided hash.
func DeserializeNodeWithHash(serialized []byte, depth int, hn common.Hash) (BinaryNode, error) {
return deserializeNode(serialized, depth, hn, false)
}
func deserializeNode(serialized []byte, depth int, hn common.Hash, mustRecompute bool) (BinaryNode, error) {
if len(serialized) == 0 {
return Empty{}, nil
}
@ -102,10 +112,11 @@ func DeserializeNode(serialized []byte, depth int, hn common.Hash) (BinaryNode,
return nil, invalidSerializedLength
}
return &InternalNode{
depth: depth,
left: HashedNode(common.BytesToHash(serialized[1:33])),
right: HashedNode(common.BytesToHash(serialized[33:65])),
hash: hn,
depth: depth,
left: HashedNode(common.BytesToHash(serialized[1:33])),
right: HashedNode(common.BytesToHash(serialized[33:65])),
hash: hn,
mustRecompute: mustRecompute,
}, nil
case nodeTypeStem:
if len(serialized) < 64 {
@ -125,10 +136,11 @@ func DeserializeNode(serialized []byte, depth int, hn common.Hash) (BinaryNode,
}
}
return &StemNode{
Stem: serialized[NodeTypeBytes : NodeTypeBytes+StemSize],
Values: values[:],
depth: depth,
hash: hn,
Stem: serialized[NodeTypeBytes : NodeTypeBytes+StemSize],
Values: values[:],
depth: depth,
hash: hn,
mustRecompute: mustRecompute,
}, nil
default:
return nil, errors.New("invalid node type")

View file

@ -48,7 +48,7 @@ func TestSerializeDeserializeInternalNode(t *testing.T) {
}
// Deserialize the node
deserialized, err := DeserializeNode(serialized, 5, common.Hash{})
deserialized, err := DeserializeNode(serialized, 5)
if err != nil {
t.Fatalf("Failed to deserialize node: %v", err)
}
@ -108,7 +108,7 @@ func TestSerializeDeserializeStemNode(t *testing.T) {
}
// Deserialize the node
deserialized, err := DeserializeNode(serialized, 10, common.Hash{})
deserialized, err := DeserializeNode(serialized, 10)
if err != nil {
t.Fatalf("Failed to deserialize node: %v", err)
}
@ -149,7 +149,7 @@ func TestSerializeDeserializeStemNode(t *testing.T) {
// TestDeserializeEmptyNode tests deserialization of empty node
func TestDeserializeEmptyNode(t *testing.T) {
// Empty byte slice should deserialize to Empty node
deserialized, err := DeserializeNode([]byte{}, 0, common.Hash{})
deserialized, err := DeserializeNode([]byte{}, 0)
if err != nil {
t.Fatalf("Failed to deserialize empty node: %v", err)
}
@ -165,7 +165,7 @@ func TestDeserializeInvalidType(t *testing.T) {
// Create invalid serialized data with unknown type byte
invalidData := []byte{99, 0, 0, 0} // Type byte 99 is invalid
_, err := DeserializeNode(invalidData, 0, common.Hash{})
_, err := DeserializeNode(invalidData, 0)
if err == nil {
t.Fatal("Expected error for invalid type byte, got nil")
}
@ -176,7 +176,7 @@ func TestDeserializeInvalidLength(t *testing.T) {
// InternalNode with type byte 1 but wrong length
invalidData := []byte{nodeTypeInternal, 0, 0} // Too short for internal node
_, err := DeserializeNode(invalidData, 0, common.Hash{})
_, err := DeserializeNode(invalidData, 0)
if err == nil {
t.Fatal("Expected error for invalid data length, got nil")
}

View file

@ -64,7 +64,7 @@ func (h HashedNode) InsertValuesAtStem(stem []byte, values [][]byte, resolver No
}
// Step 3: Deserialize the resolved data into a concrete node
node, err := DeserializeNode(data, depth, common.Hash(h))
node, err := DeserializeNodeWithHash(data, depth, common.Hash(h))
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
}

View file

@ -62,7 +62,7 @@ func (bt *InternalNode) GetValuesAtStem(stem []byte, resolver NodeResolverFn) ([
if err != nil {
return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
}
node, err := DeserializeNode(data, bt.depth+1, common.Hash(hn))
node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn))
if err != nil {
return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err)
}
@ -80,7 +80,7 @@ func (bt *InternalNode) GetValuesAtStem(stem []byte, resolver NodeResolverFn) ([
if err != nil {
return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
}
node, err := DeserializeNode(data, bt.depth+1, common.Hash(hn))
node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn))
if err != nil {
return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err)
}
@ -160,7 +160,7 @@ func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolve
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
}
node, err := DeserializeNode(data, bt.depth+1, common.Hash(hn))
node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn))
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
}
@ -185,7 +185,7 @@ func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolve
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
}
node, err := DeserializeNode(data, bt.depth+1, common.Hash(hn))
node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn))
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
}

View file

@ -123,7 +123,7 @@ func (it *binaryNodeIterator) Next(descend bool) bool {
if err != nil {
panic(err)
}
it.current, err = DeserializeNode(data, len(it.stack)-1, common.Hash(node))
it.current, err = DeserializeNodeWithHash(data, len(it.stack)-1, common.Hash(node))
if err != nil {
panic(err)
}

View file

@ -83,6 +83,7 @@ func (bt *StemNode) Insert(key []byte, value []byte, _ NodeResolverFn, depth int
return bt, errors.New("invalid insertion: value length")
}
bt.Values[key[StemSize]] = value
bt.mustRecompute = true
return bt, nil
}

View file

@ -143,7 +143,7 @@ func NewBinaryTrie(root common.Hash, db database.NodeDatabase) (*BinaryTrie, err
if err != nil {
return nil, err
}
node, err := DeserializeNode(blob, 0, root)
node, err := DeserializeNodeWithHash(blob, 0, root)
if err != nil {
return nil, err
}

View file

@ -102,7 +102,7 @@ func binaryNodeHasher(blob []byte) (common.Hash, error) {
if len(blob) == 0 {
return types.EmptyVerkleHash, nil
}
n, err := bintrie.DeserializeNode(blob, 0, common.Hash{})
n, err := bintrie.DeserializeNode(blob, 0)
if err != nil {
return common.Hash{}, err
}