This commit is contained in:
Ng Wei Han 2026-04-16 22:36:52 -07:00 committed by GitHub
commit c4c379e7cb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 130 additions and 233 deletions

View file

@ -64,8 +64,8 @@ func SerializeNode(node BinaryNode) []byte {
// InternalNode: 1 byte type + 32 bytes left hash + 32 bytes right hash
var serialized [NodeTypeBytes + HashSize + HashSize]byte
serialized[0] = nodeTypeInternal
copy(serialized[1:33], n.left.Hash().Bytes())
copy(serialized[33:65], n.right.Hash().Bytes())
copy(serialized[1:33], n.children[0].Hash().Bytes())
copy(serialized[33:65], n.children[1].Hash().Bytes())
return serialized[:]
case *StemNode:
// StemNode: 1 byte type + 31 bytes stem + 32 bytes bitmap + 256*32 bytes values
@ -112,9 +112,11 @@ func deserializeNode(serialized []byte, depth int, hn common.Hash, mustRecompute
return nil, invalidSerializedLength
}
return &InternalNode{
depth: depth,
left: HashedNode(common.BytesToHash(serialized[1:33])),
right: HashedNode(common.BytesToHash(serialized[33:65])),
depth: depth,
children: [2]BinaryNode{
HashedNode(common.BytesToHash(serialized[1:33])),
HashedNode(common.BytesToHash(serialized[33:65])),
},
hash: hn,
mustRecompute: mustRecompute,
}, nil

View file

@ -30,9 +30,8 @@ func TestSerializeDeserializeInternalNode(t *testing.T) {
rightHash := common.HexToHash("0xfedcba0987654321fedcba0987654321fedcba0987654321fedcba0987654321")
node := &InternalNode{
depth: 5,
left: HashedNode(leftHash),
right: HashedNode(rightHash),
depth: 5,
children: [2]BinaryNode{HashedNode(leftHash), HashedNode(rightHash)},
}
// Serialize the node
@ -64,13 +63,13 @@ func TestSerializeDeserializeInternalNode(t *testing.T) {
t.Errorf("Expected depth 5, got %d", internalNode.depth)
}
// Check the left and right hashes
if internalNode.left.Hash() != leftHash {
t.Errorf("Left hash mismatch: expected %x, got %x", leftHash, internalNode.left.Hash())
// Check the children hashes
if internalNode.children[0].Hash() != leftHash {
t.Errorf("Left hash mismatch: expected %x, got %x", leftHash, internalNode.children[0].Hash())
}
if internalNode.right.Hash() != rightHash {
t.Errorf("Right hash mismatch: expected %x, got %x", rightHash, internalNode.right.Hash())
if internalNode.children[1].Hash() != rightHash {
t.Errorf("Right hash mismatch: expected %x, got %x", rightHash, internalNode.children[1].Hash())
}
}

View file

@ -58,8 +58,8 @@ func keyToPath(depth int, key []byte) ([]byte, error) {
// InternalNode is a binary trie internal node.
type InternalNode struct {
left, right BinaryNode
depth int
children [2]BinaryNode // 0: left, 1: right
depth int
mustRecompute bool // true if the hash needs to be recomputed
hash common.Hash // cached hash when mustRecompute == false
@ -70,28 +70,8 @@ func (bt *InternalNode) GetValuesAtStem(stem []byte, resolver NodeResolverFn) ([
if bt.depth > 31*8 {
return nil, errors.New("node too deep")
}
bit := stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1
if bit == 0 {
if hn, ok := bt.left.(HashedNode); ok {
path, err := keyToPath(bt.depth, stem)
if err != nil {
return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
}
data, err := resolver(path, common.Hash(hn))
if err != nil {
return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
}
node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn))
if err != nil {
return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err)
}
bt.left = node
}
return bt.left.GetValuesAtStem(stem, resolver)
}
if hn, ok := bt.right.(HashedNode); ok {
if hn, ok := bt.children[bit].(HashedNode); ok {
path, err := keyToPath(bt.depth, stem)
if err != nil {
return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
@ -104,9 +84,9 @@ func (bt *InternalNode) GetValuesAtStem(stem []byte, resolver NodeResolverFn) ([
if err != nil {
return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err)
}
bt.right = node
bt.children[bit] = node
}
return bt.right.GetValuesAtStem(stem, resolver)
return bt.children[bit].GetValuesAtStem(stem, resolver)
}
// Get retrieves the value for the given key.
@ -131,8 +111,7 @@ func (bt *InternalNode) Insert(key []byte, value []byte, resolver NodeResolverFn
// Copy creates a deep copy of the node.
func (bt *InternalNode) Copy() BinaryNode {
return &InternalNode{
left: bt.left.Copy(),
right: bt.right.Copy(),
children: [2]BinaryNode{bt.children[0].Copy(), bt.children[1].Copy()},
depth: bt.depth,
mustRecompute: bt.mustRecompute,
hash: bt.hash,
@ -149,16 +128,16 @@ func (bt *InternalNode) Hash() common.Hash {
// hash left subtree in a goroutine, right subtree inline, then combine.
// Skip goroutine overhead when only one child is dirty (common case
// for narrow state updates that touch a single path through the trie).
if bt.depth < parallelDepth() && isDirty(bt.left) && isDirty(bt.right) {
if bt.depth < parallelDepth() && isDirty(bt.children[0]) && isDirty(bt.children[1]) {
var input [64]byte
var lh common.Hash
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
lh = bt.left.Hash()
lh = bt.children[0].Hash()
}()
rh := bt.right.Hash()
rh := bt.children[1].Hash()
copy(input[32:], rh[:])
wg.Wait()
copy(input[:32], lh[:])
@ -170,15 +149,12 @@ func (bt *InternalNode) Hash() common.Hash {
// Deeper nodes: sequential using pooled hasher (goroutine overhead > hash cost)
h := newSha256()
defer returnSha256(h)
if bt.left != nil {
h.Write(bt.left.Hash().Bytes())
} else {
h.Write(zero[:])
}
if bt.right != nil {
h.Write(bt.right.Hash().Bytes())
} else {
h.Write(zero[:])
for _, child := range bt.children {
if child != nil {
h.Write(child.Hash().Bytes())
} else {
h.Write(zero[:])
}
}
bt.hash = common.BytesToHash(h.Sum(nil))
bt.mustRecompute = false
@ -188,39 +164,11 @@ func (bt *InternalNode) Hash() common.Hash {
// InsertValuesAtStem inserts a full value group at the given stem in the internal node.
// Already-existing values will be overwritten.
func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolver NodeResolverFn, depth int) (BinaryNode, error) {
var err error
bit := stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1
if bit == 0 {
if bt.left == nil {
bt.left = Empty{}
}
if hn, ok := bt.left.(HashedNode); ok {
path, err := keyToPath(bt.depth, stem)
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
}
data, err := resolver(path, common.Hash(hn))
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
}
node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn))
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
}
bt.left = node
}
bt.left, err = bt.left.InsertValuesAtStem(stem, values, resolver, depth+1)
bt.mustRecompute = true
return bt, err
if bt.children[bit] == nil {
bt.children[bit] = Empty{}
}
if bt.right == nil {
bt.right = Empty{}
}
if hn, ok := bt.right.(HashedNode); ok {
if hn, ok := bt.children[bit].(HashedNode); ok {
path, err := keyToPath(bt.depth, stem)
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
@ -233,10 +181,10 @@ func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolve
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
}
bt.right = node
bt.children[bit] = node
}
bt.right, err = bt.right.InsertValuesAtStem(stem, values, resolver, depth+1)
var err error
bt.children[bit], err = bt.children[bit].InsertValuesAtStem(stem, values, resolver, depth+1)
bt.mustRecompute = true
return bt, err
}
@ -244,22 +192,15 @@ func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolve
// CollectNodes collects all child nodes at a given path, and flushes it
// into the provided node collector.
func (bt *InternalNode) CollectNodes(path []byte, flushfn NodeFlushFn) error {
if bt.left != nil {
var p [256]byte
copy(p[:], path)
childpath := p[:len(path)]
childpath = append(childpath, 0)
if err := bt.left.CollectNodes(childpath, flushfn); err != nil {
return err
}
}
if bt.right != nil {
var p [256]byte
copy(p[:], path)
childpath := p[:len(path)]
childpath = append(childpath, 1)
if err := bt.right.CollectNodes(childpath, flushfn); err != nil {
return err
for i, child := range bt.children {
if child != nil {
var p [256]byte
copy(p[:], path)
childpath := p[:len(path)]
childpath = append(childpath, byte(i))
if err := child.CollectNodes(childpath, flushfn); err != nil {
return err
}
}
}
flushfn(path, bt)
@ -268,17 +209,13 @@ func (bt *InternalNode) CollectNodes(path []byte, flushfn NodeFlushFn) error {
// GetHeight returns the height of the node.
func (bt *InternalNode) GetHeight() int {
var (
leftHeight int
rightHeight int
)
if bt.left != nil {
leftHeight = bt.left.GetHeight()
var maxHeight int
for _, child := range bt.children {
if child != nil {
maxHeight = max(maxHeight, child.GetHeight())
}
}
if bt.right != nil {
rightHeight = bt.right.GetHeight()
}
return 1 + max(leftHeight, rightHeight)
return 1 + maxHeight
}
func (bt *InternalNode) toDot(parent, path string) string {
@ -287,12 +224,10 @@ func (bt *InternalNode) toDot(parent, path string) string {
if len(parent) > 0 {
ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me)
}
if bt.left != nil {
ret = fmt.Sprintf("%s%s", ret, bt.left.toDot(me, fmt.Sprintf("%s%02x", path, 0)))
}
if bt.right != nil {
ret = fmt.Sprintf("%s%s", ret, bt.right.toDot(me, fmt.Sprintf("%s%02x", path, 1)))
for i, child := range bt.children {
if child != nil {
ret = fmt.Sprintf("%s%s", ret, child.toDot(me, fmt.Sprintf("%s%02x", path, i)))
}
}
return ret
}

View file

@ -37,15 +37,17 @@ func TestInternalNodeGet(t *testing.T) {
node := &InternalNode{
depth: 0,
left: &StemNode{
Stem: leftStem,
Values: leftValues[:],
depth: 1,
},
right: &StemNode{
Stem: rightStem,
Values: rightValues[:],
depth: 1,
children: [2]BinaryNode{
&StemNode{
Stem: leftStem,
Values: leftValues[:],
depth: 1,
},
&StemNode{
Stem: rightStem,
Values: rightValues[:],
depth: 1,
},
},
}
@ -79,9 +81,8 @@ func TestInternalNodeGetWithResolver(t *testing.T) {
hashedChild := HashedNode(common.HexToHash("0x1234"))
node := &InternalNode{
depth: 0,
left: hashedChild,
right: Empty{},
depth: 0,
children: [2]BinaryNode{hashedChild, Empty{}},
}
// Mock resolver that returns a stem node
@ -118,9 +119,8 @@ func TestInternalNodeGetWithResolver(t *testing.T) {
func TestInternalNodeInsert(t *testing.T) {
// Start with an internal node with empty children
node := &InternalNode{
depth: 0,
left: Empty{},
right: Empty{},
depth: 0,
children: [2]BinaryNode{Empty{}, Empty{}},
}
// Insert a value into the left subtree
@ -139,9 +139,9 @@ func TestInternalNodeInsert(t *testing.T) {
}
// Check that left child is now a StemNode
leftStem, ok := internalNode.left.(*StemNode)
leftStem, ok := internalNode.children[0].(*StemNode)
if !ok {
t.Fatalf("Expected left child to be StemNode, got %T", internalNode.left)
t.Fatalf("Expected left child to be StemNode, got %T", internalNode.children[0])
}
// Check the inserted value
@ -150,9 +150,9 @@ func TestInternalNodeInsert(t *testing.T) {
}
// Right child should still be Empty
_, ok = internalNode.right.(Empty)
_, ok = internalNode.children[1].(Empty)
if !ok {
t.Errorf("Expected right child to remain Empty, got %T", internalNode.right)
t.Errorf("Expected right child to remain Empty, got %T", internalNode.children[1])
}
}
@ -175,9 +175,8 @@ func TestInternalNodeCopy(t *testing.T) {
rightStem.Values[0] = common.HexToHash("0x0202").Bytes()
node := &InternalNode{
depth: 0,
left: leftStem,
right: rightStem,
depth: 0,
children: [2]BinaryNode{leftStem, rightStem},
}
// Create a copy
@ -193,14 +192,14 @@ func TestInternalNodeCopy(t *testing.T) {
}
// Check that children are copied
copiedLeft, ok := copiedInternal.left.(*StemNode)
copiedLeft, ok := copiedInternal.children[0].(*StemNode)
if !ok {
t.Fatalf("Expected left child to be StemNode, got %T", copiedInternal.left)
t.Fatalf("Expected left child to be StemNode, got %T", copiedInternal.children[0])
}
copiedRight, ok := copiedInternal.right.(*StemNode)
copiedRight, ok := copiedInternal.children[1].(*StemNode)
if !ok {
t.Fatalf("Expected right child to be StemNode, got %T", copiedInternal.right)
t.Fatalf("Expected right child to be StemNode, got %T", copiedInternal.children[1])
}
// Verify deep copy (children should be different objects)
@ -224,9 +223,8 @@ func TestInternalNodeCopy(t *testing.T) {
func TestInternalNodeHash(t *testing.T) {
// Create an internal node
node := &InternalNode{
depth: 0,
left: HashedNode(common.HexToHash("0x1111")),
right: HashedNode(common.HexToHash("0x2222")),
depth: 0,
children: [2]BinaryNode{HashedNode(common.HexToHash("0x1111")), HashedNode(common.HexToHash("0x2222"))},
}
hash1 := node.Hash()
@ -238,7 +236,7 @@ func TestInternalNodeHash(t *testing.T) {
}
// Changing a child should change the hash
node.left = HashedNode(common.HexToHash("0x3333"))
node.children[0] = HashedNode(common.HexToHash("0x3333"))
node.mustRecompute = true
hash3 := node.Hash()
if hash1 == hash3 {
@ -248,8 +246,7 @@ func TestInternalNodeHash(t *testing.T) {
// Test with nil children (should use zero hash)
nodeWithNil := &InternalNode{
depth: 0,
left: nil,
right: HashedNode(common.HexToHash("0x4444")),
children: [2]BinaryNode{nil, HashedNode(common.HexToHash("0x4444"))},
mustRecompute: true,
}
hashWithNil := nodeWithNil.Hash()
@ -273,15 +270,17 @@ func TestInternalNodeGetValuesAtStem(t *testing.T) {
node := &InternalNode{
depth: 0,
left: &StemNode{
Stem: leftStem,
Values: leftValues[:],
depth: 1,
},
right: &StemNode{
Stem: rightStem,
Values: rightValues[:],
depth: 1,
children: [2]BinaryNode{
&StemNode{
Stem: leftStem,
Values: leftValues[:],
depth: 1,
},
&StemNode{
Stem: rightStem,
Values: rightValues[:],
depth: 1,
},
},
}
@ -314,9 +313,8 @@ func TestInternalNodeGetValuesAtStem(t *testing.T) {
func TestInternalNodeInsertValuesAtStem(t *testing.T) {
// Start with an internal node with empty children
node := &InternalNode{
depth: 0,
left: Empty{},
right: Empty{},
depth: 0,
children: [2]BinaryNode{Empty{}, Empty{}},
}
// Insert values at a stem in the left subtree
@ -336,9 +334,9 @@ func TestInternalNodeInsertValuesAtStem(t *testing.T) {
}
// Check that left child is now a StemNode with the values
leftStem, ok := internalNode.left.(*StemNode)
leftStem, ok := internalNode.children[0].(*StemNode)
if !ok {
t.Fatalf("Expected left child to be StemNode, got %T", internalNode.left)
t.Fatalf("Expected left child to be StemNode, got %T", internalNode.children[0])
}
if !bytes.Equal(leftStem.Values[5], values[5]) {
@ -366,9 +364,8 @@ func TestInternalNodeCollectNodes(t *testing.T) {
rightStem.Stem[0] = 0x80
node := &InternalNode{
depth: 0,
left: leftStem,
right: rightStem,
depth: 0,
children: [2]BinaryNode{leftStem, rightStem},
}
var collectedPaths [][]byte
@ -412,12 +409,14 @@ func TestInternalNodeGetHeight(t *testing.T) {
// Right subtree: depth 1 (stem)
leftInternal := &InternalNode{
depth: 1,
left: &StemNode{
Stem: make([]byte, 31),
Values: make([][]byte, 256),
depth: 2,
children: [2]BinaryNode{
&StemNode{
Stem: make([]byte, 31),
Values: make([][]byte, 256),
depth: 2,
},
Empty{},
},
right: Empty{},
}
rightStem := &StemNode{
@ -427,9 +426,8 @@ func TestInternalNodeGetHeight(t *testing.T) {
}
node := &InternalNode{
depth: 0,
left: leftInternal,
right: rightStem,
depth: 0,
children: [2]BinaryNode{leftInternal, rightStem},
}
height := node.GetHeight()
@ -444,9 +442,8 @@ func TestInternalNodeGetHeight(t *testing.T) {
func TestInternalNodeDepthTooLarge(t *testing.T) {
// Create an internal node at max depth
node := &InternalNode{
depth: 31*8 + 1,
left: Empty{},
right: Empty{},
depth: 31*8 + 1,
children: [2]BinaryNode{Empty{}, Empty{}},
}
stem := make([]byte, 31)

View file

@ -67,24 +67,13 @@ func (it *binaryNodeIterator) Next(descend bool) bool {
// index: 0 = nothing visited, 1=left visited, 2=right visited
context := &it.stack[len(it.stack)-1]
// recurse into both children
if context.Index == 0 {
if _, isempty := node.left.(Empty); node.left != nil && !isempty {
it.stack = append(it.stack, binaryNodeIteratorState{Node: node.left})
it.current = node.left
for context.Index < 2 {
child := node.children[context.Index]
if _, isempty := child.(Empty); child != nil && !isempty {
it.stack = append(it.stack, binaryNodeIteratorState{Node: child})
it.current = child
return it.Next(descend)
}
context.Index++
}
if context.Index == 1 {
if _, isempty := node.right.(Empty); node.right != nil && !isempty {
it.stack = append(it.stack, binaryNodeIteratorState{Node: node.right})
it.current = node.right
return it.Next(descend)
}
context.Index++
}
@ -139,11 +128,7 @@ func (it *binaryNodeIterator) Next(descend bool) bool {
it.stack[len(it.stack)-1].Node = it.current
if len(it.stack) >= 2 {
parent := &it.stack[len(it.stack)-2]
if parent.Index == 0 {
parent.Node.(*InternalNode).left = it.current
} else {
parent.Node.(*InternalNode).right = it.current
}
parent.Node.(*InternalNode).children[parent.Index] = it.current
}
return it.Next(descend)
case Empty:
@ -268,15 +253,8 @@ func (it *binaryNodeIterator) LeafProof() [][]byte {
for i := range it.stack[:len(it.stack)-2] {
state := it.stack[i]
internalNode := state.Node.(*InternalNode) // should panic if the node isn't an InternalNode
// Add the sibling hash to the proof
if state.Index == 0 {
// We came from left, so include right sibling
proof = append(proof, internalNode.right.Hash().Bytes())
} else {
// We came from right, so include left sibling
proof = append(proof, internalNode.left.Hash().Bytes())
}
sibling := internalNode.children[1-state.Index]
proof = append(proof, sibling.Hash().Bytes())
}
// Add the stem and siblings

View file

@ -170,7 +170,7 @@ func TestIteratorHashedNodeNilData(t *testing.T) {
// Replace right child with a zero-hash HashedNode. nodeResolver
// short-circuits on common.Hash{} and returns (nil, nil), which
// triggers the nil-data guard in the iterator.
root.right = HashedNode(common.Hash{})
root.children[1] = HashedNode(common.Hash{})
// Should not panic; the zero-hash right child should be treated as Empty.
if leaves := countLeaves(t, tr); leaves != 1 {

View file

@ -50,16 +50,9 @@ func (bt *StemNode) Insert(key []byte, value []byte, _ NodeResolverFn, depth int
n := &InternalNode{depth: bt.depth, mustRecompute: true}
bt.depth++
var child, other *BinaryNode
if bitStem == 0 {
n.left = bt
child = &n.left
other = &n.right
} else {
n.right = bt
child = &n.right
other = &n.left
}
n.children[bitStem] = bt
child := &n.children[bitStem]
other := &n.children[1-bitStem]
bitKey := key[n.depth/8] >> (7 - (n.depth % 8)) & 1
if bitKey == bitStem {
@ -174,16 +167,9 @@ func (bt *StemNode) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolv
n := &InternalNode{depth: bt.depth, mustRecompute: true}
bt.depth++
var child, other *BinaryNode
if bitStem == 0 {
n.left = bt
child = &n.left
other = &n.right
} else {
n.right = bt
child = &n.right
other = &n.left
}
n.children[bitStem] = bt
child := &n.children[bitStem]
other := &n.children[1-bitStem]
bitKey := key[n.depth/8] >> (7 - (n.depth % 8)) & 1
if bitKey == bitStem {

View file

@ -147,18 +147,18 @@ func TestStemNodeInsertDifferentStem(t *testing.T) {
}
// Original stem should be on the left (bit 0)
leftStem, ok := internalNode.left.(*StemNode)
leftStem, ok := internalNode.children[0].(*StemNode)
if !ok {
t.Fatalf("Expected left child to be StemNode, got %T", internalNode.left)
t.Fatalf("Expected left child to be StemNode, got %T", internalNode.children[0])
}
if !bytes.Equal(leftStem.Stem, stem1) {
t.Errorf("Left stem mismatch")
}
// New stem should be on the right (bit 1)
rightStem, ok := internalNode.right.(*StemNode)
rightStem, ok := internalNode.children[1].(*StemNode)
if !ok {
t.Fatalf("Expected right child to be StemNode, got %T", internalNode.right)
t.Fatalf("Expected right child to be StemNode, got %T", internalNode.children[1])
}
if !bytes.Equal(rightStem.Stem, key[:31]) {
t.Errorf("Right stem mismatch")