go-ethereum/trie/bintrie/internal_node.go
2026-01-23 11:29:29 +01:00

295 lines
8.9 KiB
Go

// Copyright 2025 go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package bintrie
import (
"crypto/sha256"
"errors"
"fmt"
"github.com/ethereum/go-ethereum/common"
)
// keyToPath expands the first depth+1 bits of key into a path holding one
// byte per bit (each 0 or 1). Bits are read most-significant first within
// each key byte.
//
// It returns an error when depth exceeds 31*8, when depth+1 is negative, or
// when key holds fewer than depth+1 bits.
func keyToPath(depth int, key []byte) ([]byte, error) {
	if depth > 31*8 {
		return nil, errors.New("node too deep")
	}
	bits := depth + 1
	if bits < 0 {
		return nil, errors.New("negative path length")
	}
	// Bits 0..depth inclusive are read below, so the key has to provide at
	// least that many; otherwise indexing it would panic.
	if bits > len(key)*8 {
		return nil, errors.New("key too short for node depth")
	}
	path := make([]byte, bits)
	for i := range path {
		path[i] = key[i/8] >> (7 - (i % 8)) & 1
	}
	return path, nil
}
// InternalNode is a branch node of the binary trie: it has up to two
// children and records the depth at which it sits.
type InternalNode struct {
	// left is the child followed when the stem bit at index depth is 0.
	left BinaryNode
	// right is the child followed when the stem bit at index depth is 1.
	right BinaryNode
	// depth is the index of the stem bit this node branches on.
	depth int
}
// GetValuesAtStem retrieves the group of values located at the given stem key.
//
// The stem bit at index bt.depth selects the child to descend into (0 for
// left, 1 for right). When that child is a HashedNode it is first loaded
// through resolver, deserialized at depth bt.depth+1 and stored in place of
// the hash, so this call may mutate the receiver. A nil child is answered as
// if it were Empty{}, which is how InsertValuesAtStem treats nil children.
//
// An error is returned when the node is deeper than 31*8, when stem is too
// short to contain the bit at bt.depth, or when resolving or deserializing
// the child fails.
func (bt *InternalNode) GetValuesAtStem(stem []byte, resolver NodeResolverFn) ([][]byte, error) {
	if bt.depth > 31*8 {
		return nil, errors.New("node too deep")
	}
	// Guard the index below: a stem shorter than the node depth requires
	// would otherwise cause an out-of-range panic.
	if bt.depth/8 >= len(stem) {
		return nil, errors.New("stem too short for node depth")
	}

	// Pick the child slot by address so that a resolved node can be written
	// back without duplicating the logic for each side.
	child := &bt.left
	if bit := stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1; bit == 1 {
		child = &bt.right
	}
	if *child == nil {
		return Empty{}.GetValuesAtStem(stem, resolver)
	}

	if hn, ok := (*child).(HashedNode); ok {
		path, err := keyToPath(bt.depth, stem)
		if err != nil {
			return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
		}
		data, err := resolver(path, common.Hash(hn))
		if err != nil {
			return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
		}
		node, err := DeserializeNode(data, bt.depth+1)
		if err != nil {
			return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err)
		}
		*child = node
	}
	return (*child).GetValuesAtStem(stem, resolver)
}
// Get looks up the value stored under key. The first 31 bytes of key are
// used as the stem passed to GetValuesAtStem and byte 31 indexes into the
// returned value group. It returns (nil, nil) when GetValuesAtStem yields a
// nil group.
func (bt *InternalNode) Get(key []byte, resolver NodeResolverFn) ([]byte, error) {
	group, err := bt.GetValuesAtStem(key[:31], resolver)
	switch {
	case err != nil:
		return nil, fmt.Errorf("get error: %w", err)
	case group == nil:
		return nil, nil
	}
	return group[key[31]], nil
}
// Insert stores a single key-value pair. It builds a 256-entry value group
// whose only non-nil slot is the one indexed by key[31], then delegates to
// InsertValuesAtStem with the first 31 bytes of key as the stem.
func (bt *InternalNode) Insert(key []byte, value []byte, resolver NodeResolverFn, depth int) (BinaryNode, error) {
	group := make([][]byte, 256)
	group[key[31]] = value
	stem := key[:31]
	return bt.InsertValuesAtStem(stem, group, resolver, depth)
}
// Copy creates a deep copy of the node by copying each child in turn.
//
// A nil child stays nil in the copy. The other methods of InternalNode
// (Hash, GetHeight, toDot, collectChildGroups) accept nil children, so Copy
// must not panic on them either.
func (bt *InternalNode) Copy() BinaryNode {
	dup := &InternalNode{depth: bt.depth}
	if bt.left != nil {
		dup.left = bt.left.Copy()
	}
	if bt.right != nil {
		dup.right = bt.right.Copy()
	}
	return dup
}
// Hash computes the node hash as SHA-256 over the left child's hash followed
// by the right child's hash. A nil child contributes the package-level zero
// value in place of a hash.
func (bt *InternalNode) Hash() common.Hash {
	hasher := sha256.New()
	for _, child := range [2]BinaryNode{bt.left, bt.right} {
		if child == nil {
			hasher.Write(zero[:])
			continue
		}
		hasher.Write(child.Hash().Bytes())
	}
	digest := hasher.Sum(nil)
	return common.BytesToHash(digest)
}
// InsertValuesAtStem inserts a full value group at the given stem in the internal node.
// Already-existing values will be overwritten.
//
// The stem bit at index bt.depth selects the child to descend into (0 for
// left, 1 for right). A nil child is replaced by Empty{} first, and a
// HashedNode child is loaded through resolver and deserialized at depth
// bt.depth+1 before recursing. The recursive call receives depth+1.
//
// On success the selected child is replaced by the node returned from the
// recursive insert and bt itself is returned. If the recursive insert fails,
// the child is left as it was (so a failed insert cannot drop the subtree)
// and bt is returned together with the error. A nil node is returned when
// the node is deeper than 31*8, when stem is too short to contain the bit at
// bt.depth, or when resolving or deserializing the child fails.
func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolver NodeResolverFn, depth int) (BinaryNode, error) {
	if bt.depth > 31*8 {
		return nil, errors.New("node too deep")
	}
	// Guard the index below: a stem shorter than the node depth requires
	// would otherwise cause an out-of-range panic.
	if bt.depth/8 >= len(stem) {
		return nil, errors.New("stem too short for node depth")
	}

	// Pick the child slot by address so both sides share one code path.
	child := &bt.left
	if bit := stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1; bit == 1 {
		child = &bt.right
	}
	if *child == nil {
		*child = Empty{}
	}

	if hn, ok := (*child).(HashedNode); ok {
		path, err := keyToPath(bt.depth, stem)
		if err != nil {
			return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
		}
		data, err := resolver(path, common.Hash(hn))
		if err != nil {
			return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
		}
		node, err := DeserializeNode(data, bt.depth+1)
		if err != nil {
			return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
		}
		*child = node
	}

	updated, err := (*child).InsertValuesAtStem(stem, values, resolver, depth+1)
	if err != nil {
		// Keep the existing child: the returned node is not meaningful
		// when the insert failed.
		return bt, err
	}
	*child = updated
	return bt, nil
}
// CollectNodes walks the subtree rooted at bt and hands nodes that sit on a
// group boundary (depth % groupDepth == 0) to flushfn. Nodes strictly inside
// a group are not flushed on their own.
//
// When bt is on a boundary, the deeper groups are collected first and bt is
// flushed afterwards with the given path. When it is not, the traversal
// simply continues towards the next boundary. groupDepth must lie in
// [1, MaxGroupDepth], otherwise an error is returned.
func (bt *InternalNode) CollectNodes(path []byte, flushfn NodeFlushFn, groupDepth int) error {
	if groupDepth < 1 || groupDepth > MaxGroupDepth {
		return errors.New("groupDepth must be between 1 and 8")
	}
	offset := bt.depth % groupDepth
	if offset != 0 {
		// Inside a group: not expected when the walk starts at the root,
		// but tolerated by descending the remaining levels of this group.
		return bt.collectChildGroups(path, flushfn, groupDepth, groupDepth-offset-1)
	}
	// On a boundary: descendants in deeper groups go first, then this node.
	if err := bt.collectChildGroups(path, flushfn, groupDepth, groupDepth-1); err != nil {
		return err
	}
	flushfn(path, bt)
	return nil
}
// collectChildGroups descends from bt within the current group in order to
// reach and collect the nodes of the next group.
//
// remainingLevels counts the levels left below bt before the next group
// boundary; at 0, bt's children are themselves on that boundary and are
// handed to CollectNodes. Above 0, *InternalNode children are traversed
// recursively with remainingLevels-1, while any other child type is handed
// to its own CollectNodes. The left child is always processed before the
// right one, nil children are skipped, and the first error aborts the walk.
func (bt *InternalNode) collectChildGroups(path []byte, flushfn NodeFlushFn, groupDepth int, remainingLevels int) error {
	atBoundary := remainingLevels == 0
	for bit := byte(0); bit < 2; bit++ {
		// Fetch each side only when its turn comes, so the right child is
		// read after the left one has been fully processed.
		child := bt.left
		if bit == 1 {
			child = bt.right
		}
		if child == nil {
			continue
		}
		childPath := appendBit(path, bit)

		var err error
		if inner, isInternal := child.(*InternalNode); isInternal && !atBoundary {
			// Still inside the group: keep walking down.
			err = inner.collectChildGroups(childPath, flushfn, groupDepth, remainingLevels-1)
		} else {
			// Either the child sits on the next boundary, or it is not an
			// internal node and handles its own collection.
			err = child.CollectNodes(childPath, flushfn, groupDepth)
		}
		if err != nil {
			return err
		}
	}
	return nil
}
// appendBit returns a newly allocated slice holding the contents of path
// followed by bit. The input slice and its backing array are never modified
// or shared with the result, and paths of any length are accepted.
func appendBit(path []byte, bit byte) []byte {
	// Allocate exactly what is needed instead of a fixed-size scratch
	// buffer, which both wasted memory and limited the path length.
	result := make([]byte, len(path)+1)
	copy(result, path)
	result[len(path)] = bit
	return result
}
// GetHeight returns one plus the greater of the two children's heights. A
// nil child counts as height 0.
func (bt *InternalNode) GetHeight() int {
	var heights [2]int
	if l := bt.left; l != nil {
		heights[0] = l.GetHeight()
	}
	if r := bt.right; r != nil {
		heights[1] = r.GetHeight()
	}
	return max(heights[0], heights[1]) + 1
}
// toDot renders this node and its subtree as DOT statements. The node is
// named "internal" followed by path and labelled with its hash. When parent
// is non-empty an edge from parent to this node is emitted. Children are
// rendered left first, with "00" (left) or "01" (right) appended to path;
// nil children are skipped.
func (bt *InternalNode) toDot(parent, path string) string {
	self := fmt.Sprintf("internal%s", path)
	out := fmt.Sprintf("%s [label=\"I: %x\"]\n", self, bt.Hash())
	if parent != "" {
		out = fmt.Sprintf("%s %s -> %s\n", out, parent, self)
	}
	for bit, child := range [2]BinaryNode{bt.left, bt.right} {
		if child == nil {
			continue
		}
		out += child.toDot(self, fmt.Sprintf("%s%02x", path, bit))
	}
	return out
}