This commit is contained in:
Ng Wei Han 2026-04-16 22:36:52 -07:00 committed by GitHub
commit c4c379e7cb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 130 additions and 233 deletions

View file

@ -64,8 +64,8 @@ func SerializeNode(node BinaryNode) []byte {
// InternalNode: 1 byte type + 32 bytes left hash + 32 bytes right hash // InternalNode: 1 byte type + 32 bytes left hash + 32 bytes right hash
var serialized [NodeTypeBytes + HashSize + HashSize]byte var serialized [NodeTypeBytes + HashSize + HashSize]byte
serialized[0] = nodeTypeInternal serialized[0] = nodeTypeInternal
copy(serialized[1:33], n.left.Hash().Bytes()) copy(serialized[1:33], n.children[0].Hash().Bytes())
copy(serialized[33:65], n.right.Hash().Bytes()) copy(serialized[33:65], n.children[1].Hash().Bytes())
return serialized[:] return serialized[:]
case *StemNode: case *StemNode:
// StemNode: 1 byte type + 31 bytes stem + 32 bytes bitmap + 256*32 bytes values // StemNode: 1 byte type + 31 bytes stem + 32 bytes bitmap + 256*32 bytes values
@ -112,9 +112,11 @@ func deserializeNode(serialized []byte, depth int, hn common.Hash, mustRecompute
return nil, invalidSerializedLength return nil, invalidSerializedLength
} }
return &InternalNode{ return &InternalNode{
depth: depth, depth: depth,
left: HashedNode(common.BytesToHash(serialized[1:33])), children: [2]BinaryNode{
right: HashedNode(common.BytesToHash(serialized[33:65])), HashedNode(common.BytesToHash(serialized[1:33])),
HashedNode(common.BytesToHash(serialized[33:65])),
},
hash: hn, hash: hn,
mustRecompute: mustRecompute, mustRecompute: mustRecompute,
}, nil }, nil

View file

@ -30,9 +30,8 @@ func TestSerializeDeserializeInternalNode(t *testing.T) {
rightHash := common.HexToHash("0xfedcba0987654321fedcba0987654321fedcba0987654321fedcba0987654321") rightHash := common.HexToHash("0xfedcba0987654321fedcba0987654321fedcba0987654321fedcba0987654321")
node := &InternalNode{ node := &InternalNode{
depth: 5, depth: 5,
left: HashedNode(leftHash), children: [2]BinaryNode{HashedNode(leftHash), HashedNode(rightHash)},
right: HashedNode(rightHash),
} }
// Serialize the node // Serialize the node
@ -64,13 +63,13 @@ func TestSerializeDeserializeInternalNode(t *testing.T) {
t.Errorf("Expected depth 5, got %d", internalNode.depth) t.Errorf("Expected depth 5, got %d", internalNode.depth)
} }
// Check the left and right hashes // Check the children hashes
if internalNode.left.Hash() != leftHash { if internalNode.children[0].Hash() != leftHash {
t.Errorf("Left hash mismatch: expected %x, got %x", leftHash, internalNode.left.Hash()) t.Errorf("Left hash mismatch: expected %x, got %x", leftHash, internalNode.children[0].Hash())
} }
if internalNode.right.Hash() != rightHash { if internalNode.children[1].Hash() != rightHash {
t.Errorf("Right hash mismatch: expected %x, got %x", rightHash, internalNode.right.Hash()) t.Errorf("Right hash mismatch: expected %x, got %x", rightHash, internalNode.children[1].Hash())
} }
} }

View file

@ -58,8 +58,8 @@ func keyToPath(depth int, key []byte) ([]byte, error) {
// InternalNode is a binary trie internal node. // InternalNode is a binary trie internal node.
type InternalNode struct { type InternalNode struct {
left, right BinaryNode children [2]BinaryNode // 0: left, 1: right
depth int depth int
mustRecompute bool // true if the hash needs to be recomputed mustRecompute bool // true if the hash needs to be recomputed
hash common.Hash // cached hash when mustRecompute == false hash common.Hash // cached hash when mustRecompute == false
@ -70,28 +70,8 @@ func (bt *InternalNode) GetValuesAtStem(stem []byte, resolver NodeResolverFn) ([
if bt.depth > 31*8 { if bt.depth > 31*8 {
return nil, errors.New("node too deep") return nil, errors.New("node too deep")
} }
bit := stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1 bit := stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1
if bit == 0 { if hn, ok := bt.children[bit].(HashedNode); ok {
if hn, ok := bt.left.(HashedNode); ok {
path, err := keyToPath(bt.depth, stem)
if err != nil {
return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
}
data, err := resolver(path, common.Hash(hn))
if err != nil {
return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
}
node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn))
if err != nil {
return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err)
}
bt.left = node
}
return bt.left.GetValuesAtStem(stem, resolver)
}
if hn, ok := bt.right.(HashedNode); ok {
path, err := keyToPath(bt.depth, stem) path, err := keyToPath(bt.depth, stem)
if err != nil { if err != nil {
return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err) return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
@ -104,9 +84,9 @@ func (bt *InternalNode) GetValuesAtStem(stem []byte, resolver NodeResolverFn) ([
if err != nil { if err != nil {
return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err) return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err)
} }
bt.right = node bt.children[bit] = node
} }
return bt.right.GetValuesAtStem(stem, resolver) return bt.children[bit].GetValuesAtStem(stem, resolver)
} }
// Get retrieves the value for the given key. // Get retrieves the value for the given key.
@ -131,8 +111,7 @@ func (bt *InternalNode) Insert(key []byte, value []byte, resolver NodeResolverFn
// Copy creates a deep copy of the node. // Copy creates a deep copy of the node.
func (bt *InternalNode) Copy() BinaryNode { func (bt *InternalNode) Copy() BinaryNode {
return &InternalNode{ return &InternalNode{
left: bt.left.Copy(), children: [2]BinaryNode{bt.children[0].Copy(), bt.children[1].Copy()},
right: bt.right.Copy(),
depth: bt.depth, depth: bt.depth,
mustRecompute: bt.mustRecompute, mustRecompute: bt.mustRecompute,
hash: bt.hash, hash: bt.hash,
@ -149,16 +128,16 @@ func (bt *InternalNode) Hash() common.Hash {
// hash left subtree in a goroutine, right subtree inline, then combine. // hash left subtree in a goroutine, right subtree inline, then combine.
// Skip goroutine overhead when only one child is dirty (common case // Skip goroutine overhead when only one child is dirty (common case
// for narrow state updates that touch a single path through the trie). // for narrow state updates that touch a single path through the trie).
if bt.depth < parallelDepth() && isDirty(bt.left) && isDirty(bt.right) { if bt.depth < parallelDepth() && isDirty(bt.children[0]) && isDirty(bt.children[1]) {
var input [64]byte var input [64]byte
var lh common.Hash var lh common.Hash
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
lh = bt.left.Hash() lh = bt.children[0].Hash()
}() }()
rh := bt.right.Hash() rh := bt.children[1].Hash()
copy(input[32:], rh[:]) copy(input[32:], rh[:])
wg.Wait() wg.Wait()
copy(input[:32], lh[:]) copy(input[:32], lh[:])
@ -170,15 +149,12 @@ func (bt *InternalNode) Hash() common.Hash {
// Deeper nodes: sequential using pooled hasher (goroutine overhead > hash cost) // Deeper nodes: sequential using pooled hasher (goroutine overhead > hash cost)
h := newSha256() h := newSha256()
defer returnSha256(h) defer returnSha256(h)
if bt.left != nil { for _, child := range bt.children {
h.Write(bt.left.Hash().Bytes()) if child != nil {
} else { h.Write(child.Hash().Bytes())
h.Write(zero[:]) } else {
} h.Write(zero[:])
if bt.right != nil { }
h.Write(bt.right.Hash().Bytes())
} else {
h.Write(zero[:])
} }
bt.hash = common.BytesToHash(h.Sum(nil)) bt.hash = common.BytesToHash(h.Sum(nil))
bt.mustRecompute = false bt.mustRecompute = false
@ -188,39 +164,11 @@ func (bt *InternalNode) Hash() common.Hash {
// InsertValuesAtStem inserts a full value group at the given stem in the internal node. // InsertValuesAtStem inserts a full value group at the given stem in the internal node.
// Already-existing values will be overwritten. // Already-existing values will be overwritten.
func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolver NodeResolverFn, depth int) (BinaryNode, error) { func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolver NodeResolverFn, depth int) (BinaryNode, error) {
var err error
bit := stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1 bit := stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1
if bit == 0 { if bt.children[bit] == nil {
if bt.left == nil { bt.children[bit] = Empty{}
bt.left = Empty{}
}
if hn, ok := bt.left.(HashedNode); ok {
path, err := keyToPath(bt.depth, stem)
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
}
data, err := resolver(path, common.Hash(hn))
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
}
node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn))
if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
}
bt.left = node
}
bt.left, err = bt.left.InsertValuesAtStem(stem, values, resolver, depth+1)
bt.mustRecompute = true
return bt, err
} }
if hn, ok := bt.children[bit].(HashedNode); ok {
if bt.right == nil {
bt.right = Empty{}
}
if hn, ok := bt.right.(HashedNode); ok {
path, err := keyToPath(bt.depth, stem) path, err := keyToPath(bt.depth, stem)
if err != nil { if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err) return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
@ -233,10 +181,10 @@ func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolve
if err != nil { if err != nil {
return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err) return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
} }
bt.right = node bt.children[bit] = node
} }
var err error
bt.right, err = bt.right.InsertValuesAtStem(stem, values, resolver, depth+1) bt.children[bit], err = bt.children[bit].InsertValuesAtStem(stem, values, resolver, depth+1)
bt.mustRecompute = true bt.mustRecompute = true
return bt, err return bt, err
} }
@ -244,22 +192,15 @@ func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolve
// CollectNodes collects all child nodes at a given path, and flushes it // CollectNodes collects all child nodes at a given path, and flushes it
// into the provided node collector. // into the provided node collector.
func (bt *InternalNode) CollectNodes(path []byte, flushfn NodeFlushFn) error { func (bt *InternalNode) CollectNodes(path []byte, flushfn NodeFlushFn) error {
if bt.left != nil { for i, child := range bt.children {
var p [256]byte if child != nil {
copy(p[:], path) var p [256]byte
childpath := p[:len(path)] copy(p[:], path)
childpath = append(childpath, 0) childpath := p[:len(path)]
if err := bt.left.CollectNodes(childpath, flushfn); err != nil { childpath = append(childpath, byte(i))
return err if err := child.CollectNodes(childpath, flushfn); err != nil {
} return err
} }
if bt.right != nil {
var p [256]byte
copy(p[:], path)
childpath := p[:len(path)]
childpath = append(childpath, 1)
if err := bt.right.CollectNodes(childpath, flushfn); err != nil {
return err
} }
} }
flushfn(path, bt) flushfn(path, bt)
@ -268,17 +209,13 @@ func (bt *InternalNode) CollectNodes(path []byte, flushfn NodeFlushFn) error {
// GetHeight returns the height of the node. // GetHeight returns the height of the node.
func (bt *InternalNode) GetHeight() int { func (bt *InternalNode) GetHeight() int {
var ( var maxHeight int
leftHeight int for _, child := range bt.children {
rightHeight int if child != nil {
) maxHeight = max(maxHeight, child.GetHeight())
if bt.left != nil { }
leftHeight = bt.left.GetHeight()
} }
if bt.right != nil { return 1 + maxHeight
rightHeight = bt.right.GetHeight()
}
return 1 + max(leftHeight, rightHeight)
} }
func (bt *InternalNode) toDot(parent, path string) string { func (bt *InternalNode) toDot(parent, path string) string {
@ -287,12 +224,10 @@ func (bt *InternalNode) toDot(parent, path string) string {
if len(parent) > 0 { if len(parent) > 0 {
ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me) ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me)
} }
for i, child := range bt.children {
if bt.left != nil { if child != nil {
ret = fmt.Sprintf("%s%s", ret, bt.left.toDot(me, fmt.Sprintf("%s%02x", path, 0))) ret = fmt.Sprintf("%s%s", ret, child.toDot(me, fmt.Sprintf("%s%02x", path, i)))
} }
if bt.right != nil {
ret = fmt.Sprintf("%s%s", ret, bt.right.toDot(me, fmt.Sprintf("%s%02x", path, 1)))
} }
return ret return ret
} }

View file

@ -37,15 +37,17 @@ func TestInternalNodeGet(t *testing.T) {
node := &InternalNode{ node := &InternalNode{
depth: 0, depth: 0,
left: &StemNode{ children: [2]BinaryNode{
Stem: leftStem, &StemNode{
Values: leftValues[:], Stem: leftStem,
depth: 1, Values: leftValues[:],
}, depth: 1,
right: &StemNode{ },
Stem: rightStem, &StemNode{
Values: rightValues[:], Stem: rightStem,
depth: 1, Values: rightValues[:],
depth: 1,
},
}, },
} }
@ -79,9 +81,8 @@ func TestInternalNodeGetWithResolver(t *testing.T) {
hashedChild := HashedNode(common.HexToHash("0x1234")) hashedChild := HashedNode(common.HexToHash("0x1234"))
node := &InternalNode{ node := &InternalNode{
depth: 0, depth: 0,
left: hashedChild, children: [2]BinaryNode{hashedChild, Empty{}},
right: Empty{},
} }
// Mock resolver that returns a stem node // Mock resolver that returns a stem node
@ -118,9 +119,8 @@ func TestInternalNodeGetWithResolver(t *testing.T) {
func TestInternalNodeInsert(t *testing.T) { func TestInternalNodeInsert(t *testing.T) {
// Start with an internal node with empty children // Start with an internal node with empty children
node := &InternalNode{ node := &InternalNode{
depth: 0, depth: 0,
left: Empty{}, children: [2]BinaryNode{Empty{}, Empty{}},
right: Empty{},
} }
// Insert a value into the left subtree // Insert a value into the left subtree
@ -139,9 +139,9 @@ func TestInternalNodeInsert(t *testing.T) {
} }
// Check that left child is now a StemNode // Check that left child is now a StemNode
leftStem, ok := internalNode.left.(*StemNode) leftStem, ok := internalNode.children[0].(*StemNode)
if !ok { if !ok {
t.Fatalf("Expected left child to be StemNode, got %T", internalNode.left) t.Fatalf("Expected left child to be StemNode, got %T", internalNode.children[0])
} }
// Check the inserted value // Check the inserted value
@ -150,9 +150,9 @@ func TestInternalNodeInsert(t *testing.T) {
} }
// Right child should still be Empty // Right child should still be Empty
_, ok = internalNode.right.(Empty) _, ok = internalNode.children[1].(Empty)
if !ok { if !ok {
t.Errorf("Expected right child to remain Empty, got %T", internalNode.right) t.Errorf("Expected right child to remain Empty, got %T", internalNode.children[1])
} }
} }
@ -175,9 +175,8 @@ func TestInternalNodeCopy(t *testing.T) {
rightStem.Values[0] = common.HexToHash("0x0202").Bytes() rightStem.Values[0] = common.HexToHash("0x0202").Bytes()
node := &InternalNode{ node := &InternalNode{
depth: 0, depth: 0,
left: leftStem, children: [2]BinaryNode{leftStem, rightStem},
right: rightStem,
} }
// Create a copy // Create a copy
@ -193,14 +192,14 @@ func TestInternalNodeCopy(t *testing.T) {
} }
// Check that children are copied // Check that children are copied
copiedLeft, ok := copiedInternal.left.(*StemNode) copiedLeft, ok := copiedInternal.children[0].(*StemNode)
if !ok { if !ok {
t.Fatalf("Expected left child to be StemNode, got %T", copiedInternal.left) t.Fatalf("Expected left child to be StemNode, got %T", copiedInternal.children[0])
} }
copiedRight, ok := copiedInternal.right.(*StemNode) copiedRight, ok := copiedInternal.children[1].(*StemNode)
if !ok { if !ok {
t.Fatalf("Expected right child to be StemNode, got %T", copiedInternal.right) t.Fatalf("Expected right child to be StemNode, got %T", copiedInternal.children[1])
} }
// Verify deep copy (children should be different objects) // Verify deep copy (children should be different objects)
@ -224,9 +223,8 @@ func TestInternalNodeCopy(t *testing.T) {
func TestInternalNodeHash(t *testing.T) { func TestInternalNodeHash(t *testing.T) {
// Create an internal node // Create an internal node
node := &InternalNode{ node := &InternalNode{
depth: 0, depth: 0,
left: HashedNode(common.HexToHash("0x1111")), children: [2]BinaryNode{HashedNode(common.HexToHash("0x1111")), HashedNode(common.HexToHash("0x2222"))},
right: HashedNode(common.HexToHash("0x2222")),
} }
hash1 := node.Hash() hash1 := node.Hash()
@ -238,7 +236,7 @@ func TestInternalNodeHash(t *testing.T) {
} }
// Changing a child should change the hash // Changing a child should change the hash
node.left = HashedNode(common.HexToHash("0x3333")) node.children[0] = HashedNode(common.HexToHash("0x3333"))
node.mustRecompute = true node.mustRecompute = true
hash3 := node.Hash() hash3 := node.Hash()
if hash1 == hash3 { if hash1 == hash3 {
@ -248,8 +246,7 @@ func TestInternalNodeHash(t *testing.T) {
// Test with nil children (should use zero hash) // Test with nil children (should use zero hash)
nodeWithNil := &InternalNode{ nodeWithNil := &InternalNode{
depth: 0, depth: 0,
left: nil, children: [2]BinaryNode{nil, HashedNode(common.HexToHash("0x4444"))},
right: HashedNode(common.HexToHash("0x4444")),
mustRecompute: true, mustRecompute: true,
} }
hashWithNil := nodeWithNil.Hash() hashWithNil := nodeWithNil.Hash()
@ -273,15 +270,17 @@ func TestInternalNodeGetValuesAtStem(t *testing.T) {
node := &InternalNode{ node := &InternalNode{
depth: 0, depth: 0,
left: &StemNode{ children: [2]BinaryNode{
Stem: leftStem, &StemNode{
Values: leftValues[:], Stem: leftStem,
depth: 1, Values: leftValues[:],
}, depth: 1,
right: &StemNode{ },
Stem: rightStem, &StemNode{
Values: rightValues[:], Stem: rightStem,
depth: 1, Values: rightValues[:],
depth: 1,
},
}, },
} }
@ -314,9 +313,8 @@ func TestInternalNodeGetValuesAtStem(t *testing.T) {
func TestInternalNodeInsertValuesAtStem(t *testing.T) { func TestInternalNodeInsertValuesAtStem(t *testing.T) {
// Start with an internal node with empty children // Start with an internal node with empty children
node := &InternalNode{ node := &InternalNode{
depth: 0, depth: 0,
left: Empty{}, children: [2]BinaryNode{Empty{}, Empty{}},
right: Empty{},
} }
// Insert values at a stem in the left subtree // Insert values at a stem in the left subtree
@ -336,9 +334,9 @@ func TestInternalNodeInsertValuesAtStem(t *testing.T) {
} }
// Check that left child is now a StemNode with the values // Check that left child is now a StemNode with the values
leftStem, ok := internalNode.left.(*StemNode) leftStem, ok := internalNode.children[0].(*StemNode)
if !ok { if !ok {
t.Fatalf("Expected left child to be StemNode, got %T", internalNode.left) t.Fatalf("Expected left child to be StemNode, got %T", internalNode.children[0])
} }
if !bytes.Equal(leftStem.Values[5], values[5]) { if !bytes.Equal(leftStem.Values[5], values[5]) {
@ -366,9 +364,8 @@ func TestInternalNodeCollectNodes(t *testing.T) {
rightStem.Stem[0] = 0x80 rightStem.Stem[0] = 0x80
node := &InternalNode{ node := &InternalNode{
depth: 0, depth: 0,
left: leftStem, children: [2]BinaryNode{leftStem, rightStem},
right: rightStem,
} }
var collectedPaths [][]byte var collectedPaths [][]byte
@ -412,12 +409,14 @@ func TestInternalNodeGetHeight(t *testing.T) {
// Right subtree: depth 1 (stem) // Right subtree: depth 1 (stem)
leftInternal := &InternalNode{ leftInternal := &InternalNode{
depth: 1, depth: 1,
left: &StemNode{ children: [2]BinaryNode{
Stem: make([]byte, 31), &StemNode{
Values: make([][]byte, 256), Stem: make([]byte, 31),
depth: 2, Values: make([][]byte, 256),
depth: 2,
},
Empty{},
}, },
right: Empty{},
} }
rightStem := &StemNode{ rightStem := &StemNode{
@ -427,9 +426,8 @@ func TestInternalNodeGetHeight(t *testing.T) {
} }
node := &InternalNode{ node := &InternalNode{
depth: 0, depth: 0,
left: leftInternal, children: [2]BinaryNode{leftInternal, rightStem},
right: rightStem,
} }
height := node.GetHeight() height := node.GetHeight()
@ -444,9 +442,8 @@ func TestInternalNodeGetHeight(t *testing.T) {
func TestInternalNodeDepthTooLarge(t *testing.T) { func TestInternalNodeDepthTooLarge(t *testing.T) {
// Create an internal node at max depth // Create an internal node at max depth
node := &InternalNode{ node := &InternalNode{
depth: 31*8 + 1, depth: 31*8 + 1,
left: Empty{}, children: [2]BinaryNode{Empty{}, Empty{}},
right: Empty{},
} }
stem := make([]byte, 31) stem := make([]byte, 31)

View file

@ -67,24 +67,13 @@ func (it *binaryNodeIterator) Next(descend bool) bool {
// index: 0 = nothing visited, 1=left visited, 2=right visited // index: 0 = nothing visited, 1=left visited, 2=right visited
context := &it.stack[len(it.stack)-1] context := &it.stack[len(it.stack)-1]
// recurse into both children for context.Index < 2 {
if context.Index == 0 { child := node.children[context.Index]
if _, isempty := node.left.(Empty); node.left != nil && !isempty { if _, isempty := child.(Empty); child != nil && !isempty {
it.stack = append(it.stack, binaryNodeIteratorState{Node: node.left}) it.stack = append(it.stack, binaryNodeIteratorState{Node: child})
it.current = node.left it.current = child
return it.Next(descend) return it.Next(descend)
} }
context.Index++
}
if context.Index == 1 {
if _, isempty := node.right.(Empty); node.right != nil && !isempty {
it.stack = append(it.stack, binaryNodeIteratorState{Node: node.right})
it.current = node.right
return it.Next(descend)
}
context.Index++ context.Index++
} }
@ -139,11 +128,7 @@ func (it *binaryNodeIterator) Next(descend bool) bool {
it.stack[len(it.stack)-1].Node = it.current it.stack[len(it.stack)-1].Node = it.current
if len(it.stack) >= 2 { if len(it.stack) >= 2 {
parent := &it.stack[len(it.stack)-2] parent := &it.stack[len(it.stack)-2]
if parent.Index == 0 { parent.Node.(*InternalNode).children[parent.Index] = it.current
parent.Node.(*InternalNode).left = it.current
} else {
parent.Node.(*InternalNode).right = it.current
}
} }
return it.Next(descend) return it.Next(descend)
case Empty: case Empty:
@ -268,15 +253,8 @@ func (it *binaryNodeIterator) LeafProof() [][]byte {
for i := range it.stack[:len(it.stack)-2] { for i := range it.stack[:len(it.stack)-2] {
state := it.stack[i] state := it.stack[i]
internalNode := state.Node.(*InternalNode) // should panic if the node isn't an InternalNode internalNode := state.Node.(*InternalNode) // should panic if the node isn't an InternalNode
sibling := internalNode.children[1-state.Index]
// Add the sibling hash to the proof proof = append(proof, sibling.Hash().Bytes())
if state.Index == 0 {
// We came from left, so include right sibling
proof = append(proof, internalNode.right.Hash().Bytes())
} else {
// We came from right, so include left sibling
proof = append(proof, internalNode.left.Hash().Bytes())
}
} }
// Add the stem and siblings // Add the stem and siblings

View file

@ -170,7 +170,7 @@ func TestIteratorHashedNodeNilData(t *testing.T) {
// Replace right child with a zero-hash HashedNode. nodeResolver // Replace right child with a zero-hash HashedNode. nodeResolver
// short-circuits on common.Hash{} and returns (nil, nil), which // short-circuits on common.Hash{} and returns (nil, nil), which
// triggers the nil-data guard in the iterator. // triggers the nil-data guard in the iterator.
root.right = HashedNode(common.Hash{}) root.children[1] = HashedNode(common.Hash{})
// Should not panic; the zero-hash right child should be treated as Empty. // Should not panic; the zero-hash right child should be treated as Empty.
if leaves := countLeaves(t, tr); leaves != 1 { if leaves := countLeaves(t, tr); leaves != 1 {

View file

@ -50,16 +50,9 @@ func (bt *StemNode) Insert(key []byte, value []byte, _ NodeResolverFn, depth int
n := &InternalNode{depth: bt.depth, mustRecompute: true} n := &InternalNode{depth: bt.depth, mustRecompute: true}
bt.depth++ bt.depth++
var child, other *BinaryNode n.children[bitStem] = bt
if bitStem == 0 { child := &n.children[bitStem]
n.left = bt other := &n.children[1-bitStem]
child = &n.left
other = &n.right
} else {
n.right = bt
child = &n.right
other = &n.left
}
bitKey := key[n.depth/8] >> (7 - (n.depth % 8)) & 1 bitKey := key[n.depth/8] >> (7 - (n.depth % 8)) & 1
if bitKey == bitStem { if bitKey == bitStem {
@ -174,16 +167,9 @@ func (bt *StemNode) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolv
n := &InternalNode{depth: bt.depth, mustRecompute: true} n := &InternalNode{depth: bt.depth, mustRecompute: true}
bt.depth++ bt.depth++
var child, other *BinaryNode n.children[bitStem] = bt
if bitStem == 0 { child := &n.children[bitStem]
n.left = bt other := &n.children[1-bitStem]
child = &n.left
other = &n.right
} else {
n.right = bt
child = &n.right
other = &n.left
}
bitKey := key[n.depth/8] >> (7 - (n.depth % 8)) & 1 bitKey := key[n.depth/8] >> (7 - (n.depth % 8)) & 1
if bitKey == bitStem { if bitKey == bitStem {

View file

@ -147,18 +147,18 @@ func TestStemNodeInsertDifferentStem(t *testing.T) {
} }
// Original stem should be on the left (bit 0) // Original stem should be on the left (bit 0)
leftStem, ok := internalNode.left.(*StemNode) leftStem, ok := internalNode.children[0].(*StemNode)
if !ok { if !ok {
t.Fatalf("Expected left child to be StemNode, got %T", internalNode.left) t.Fatalf("Expected left child to be StemNode, got %T", internalNode.children[0])
} }
if !bytes.Equal(leftStem.Stem, stem1) { if !bytes.Equal(leftStem.Stem, stem1) {
t.Errorf("Left stem mismatch") t.Errorf("Left stem mismatch")
} }
// New stem should be on the right (bit 1) // New stem should be on the right (bit 1)
rightStem, ok := internalNode.right.(*StemNode) rightStem, ok := internalNode.children[1].(*StemNode)
if !ok { if !ok {
t.Fatalf("Expected right child to be StemNode, got %T", internalNode.right) t.Fatalf("Expected right child to be StemNode, got %T", internalNode.children[1])
} }
if !bytes.Equal(rightStem.Stem, key[:31]) { if !bytes.Equal(rightStem.Stem, key[:31]) {
t.Errorf("Right stem mismatch") t.Errorf("Right stem mismatch")