mirror of
https://github.com/ethereum/go-ethereum.git
synced 2026-07-01 02:37:37 +00:00
At tree depths below bits.Len(NumCPU) (that is, ⌊log2(NumCPU)⌋ + 1, capped at 8), InternalNode.Hash hashes the left subtree in a goroutine while hashing the right subtree inline, provided both children need rehashing. This puts the available CPU cores to work on the top levels of the tree, where subtree hashing is most expensive. When only one child is dirty, the goroutine is skipped to avoid its overhead. Deeper nodes hash sequentially with the existing sync.Pool hasher. The parallel path uses sha256.Sum256 on a stack-allocated buffer to avoid pool contention across goroutines.
298 lines
8.3 KiB
Go
298 lines
8.3 KiB
Go
// Copyright 2025 go-ethereum Authors
|
|
// This file is part of the go-ethereum library.
|
|
//
|
|
// The go-ethereum library is free software: you can redistribute it and/or modify
|
|
// it under the terms of the GNU Lesser General Public License as published by
|
|
// the Free Software Foundation, either version 3 of the License, or
|
|
// (at your option) any later version.
|
|
//
|
|
// The go-ethereum library is distributed in the hope that it will be useful,
|
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
// GNU Lesser General Public License for more details.
|
|
//
|
|
// You should have received a copy of the GNU Lesser General Public License
|
|
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
|
|
|
|
package bintrie
|
|
|
|
import (
|
|
"crypto/sha256"
|
|
"errors"
|
|
"fmt"
|
|
"math/bits"
|
|
"runtime"
|
|
"sync"
|
|
|
|
"github.com/ethereum/go-ethereum/common"
|
|
)
|
|
|
|
// parallelDepth returns the tree depth below which Hash() spawns goroutines.
//
// The value is bits.Len(NumCPU) -- i.e. floor(log2(NumCPU)) + 1 -- but never
// more than 8. Because every node above this depth with two dirty children
// forks once, as many as 2^parallelDepth() subtrees may be hashed at once.
func parallelDepth() int {
	const maxDepth = 8
	depth := bits.Len(uint(runtime.NumCPU()))
	if depth > maxDepth {
		depth = maxDepth
	}
	return depth
}
|
|
|
|
// isDirty reports whether a BinaryNode child needs rehashing.
|
|
func isDirty(n BinaryNode) bool {
|
|
switch v := n.(type) {
|
|
case *InternalNode:
|
|
return v.mustRecompute
|
|
case *StemNode:
|
|
return v.mustRecompute
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
func keyToPath(depth int, key []byte) ([]byte, error) {
|
|
if depth > 31*8 {
|
|
return nil, errors.New("node too deep")
|
|
}
|
|
path := make([]byte, 0, depth+1)
|
|
for i := range depth + 1 {
|
|
bit := key[i/8] >> (7 - (i % 8)) & 1
|
|
path = append(path, bit)
|
|
}
|
|
return path, nil
|
|
}
|
|
|
|
// InternalNode is a binary trie internal node.
//
// A node at depth d branches on bit number d of the stem, counting from the
// most significant bit of the first stem byte: stems whose bit is 0 are
// stored under left, stems whose bit is 1 under right. Either child may be
// nil; Hash substitutes zero bytes for a nil child, and CollectNodes,
// GetHeight and toDot skip nil children.
type InternalNode struct {
	left, right BinaryNode // subtrees for bit value 0 and 1 respectively
	depth int // index of the stem bit this node branches on

	mustRecompute bool // true if the hash needs to be recomputed
	hash common.Hash // cached hash when mustRecompute == false
}
|
|
|
|
// GetValuesAtStem retrieves the group of values located at the given stem key.
|
|
func (bt *InternalNode) GetValuesAtStem(stem []byte, resolver NodeResolverFn) ([][]byte, error) {
|
|
if bt.depth > 31*8 {
|
|
return nil, errors.New("node too deep")
|
|
}
|
|
|
|
bit := stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1
|
|
if bit == 0 {
|
|
if hn, ok := bt.left.(HashedNode); ok {
|
|
path, err := keyToPath(bt.depth, stem)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
|
|
}
|
|
data, err := resolver(path, common.Hash(hn))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
|
|
}
|
|
node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err)
|
|
}
|
|
bt.left = node
|
|
}
|
|
return bt.left.GetValuesAtStem(stem, resolver)
|
|
}
|
|
|
|
if hn, ok := bt.right.(HashedNode); ok {
|
|
path, err := keyToPath(bt.depth, stem)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
|
|
}
|
|
data, err := resolver(path, common.Hash(hn))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("GetValuesAtStem resolve error: %w", err)
|
|
}
|
|
node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("GetValuesAtStem node deserialization error: %w", err)
|
|
}
|
|
bt.right = node
|
|
}
|
|
return bt.right.GetValuesAtStem(stem, resolver)
|
|
}
|
|
|
|
// Get retrieves the value for the given key.
|
|
func (bt *InternalNode) Get(key []byte, resolver NodeResolverFn) ([]byte, error) {
|
|
values, err := bt.GetValuesAtStem(key[:31], resolver)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get error: %w", err)
|
|
}
|
|
if values == nil {
|
|
return nil, nil
|
|
}
|
|
return values[key[31]], nil
|
|
}
|
|
|
|
// Insert inserts a new key-value pair into the trie.
|
|
func (bt *InternalNode) Insert(key []byte, value []byte, resolver NodeResolverFn, depth int) (BinaryNode, error) {
|
|
var values [256][]byte
|
|
values[key[31]] = value
|
|
return bt.InsertValuesAtStem(key[:31], values[:], resolver, depth)
|
|
}
|
|
|
|
// Copy creates a deep copy of the node.
|
|
func (bt *InternalNode) Copy() BinaryNode {
|
|
return &InternalNode{
|
|
left: bt.left.Copy(),
|
|
right: bt.right.Copy(),
|
|
depth: bt.depth,
|
|
mustRecompute: bt.mustRecompute,
|
|
hash: bt.hash,
|
|
}
|
|
}
|
|
|
|
// Hash returns the hash of the node.
|
|
func (bt *InternalNode) Hash() common.Hash {
|
|
if !bt.mustRecompute {
|
|
return bt.hash
|
|
}
|
|
|
|
// At shallow depths, parallelize when both children need rehashing:
|
|
// hash left subtree in a goroutine, right subtree inline, then combine.
|
|
// Skip goroutine overhead when only one child is dirty (common case
|
|
// for narrow state updates that touch a single path through the trie).
|
|
if bt.depth < parallelDepth() && isDirty(bt.left) && isDirty(bt.right) {
|
|
var input [64]byte
|
|
var lh common.Hash
|
|
var wg sync.WaitGroup
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
lh = bt.left.Hash()
|
|
}()
|
|
rh := bt.right.Hash()
|
|
copy(input[32:], rh[:])
|
|
wg.Wait()
|
|
copy(input[:32], lh[:])
|
|
bt.hash = sha256.Sum256(input[:])
|
|
bt.mustRecompute = false
|
|
return bt.hash
|
|
}
|
|
|
|
// Deeper nodes: sequential using pooled hasher (goroutine overhead > hash cost)
|
|
h := newSha256()
|
|
defer returnSha256(h)
|
|
if bt.left != nil {
|
|
h.Write(bt.left.Hash().Bytes())
|
|
} else {
|
|
h.Write(zero[:])
|
|
}
|
|
if bt.right != nil {
|
|
h.Write(bt.right.Hash().Bytes())
|
|
} else {
|
|
h.Write(zero[:])
|
|
}
|
|
bt.hash = common.BytesToHash(h.Sum(nil))
|
|
bt.mustRecompute = false
|
|
return bt.hash
|
|
}
|
|
|
|
// InsertValuesAtStem inserts a full value group at the given stem in the internal node.
|
|
// Already-existing values will be overwritten.
|
|
func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolver NodeResolverFn, depth int) (BinaryNode, error) {
|
|
var err error
|
|
bit := stem[bt.depth/8] >> (7 - (bt.depth % 8)) & 1
|
|
if bit == 0 {
|
|
if bt.left == nil {
|
|
bt.left = Empty{}
|
|
}
|
|
|
|
if hn, ok := bt.left.(HashedNode); ok {
|
|
path, err := keyToPath(bt.depth, stem)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
|
|
}
|
|
data, err := resolver(path, common.Hash(hn))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
|
|
}
|
|
node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
|
|
}
|
|
bt.left = node
|
|
}
|
|
|
|
bt.left, err = bt.left.InsertValuesAtStem(stem, values, resolver, depth+1)
|
|
bt.mustRecompute = true
|
|
return bt, err
|
|
}
|
|
|
|
if bt.right == nil {
|
|
bt.right = Empty{}
|
|
}
|
|
|
|
if hn, ok := bt.right.(HashedNode); ok {
|
|
path, err := keyToPath(bt.depth, stem)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
|
|
}
|
|
data, err := resolver(path, common.Hash(hn))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("InsertValuesAtStem resolve error: %w", err)
|
|
}
|
|
node, err := DeserializeNodeWithHash(data, bt.depth+1, common.Hash(hn))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("InsertValuesAtStem node deserialization error: %w", err)
|
|
}
|
|
bt.right = node
|
|
}
|
|
|
|
bt.right, err = bt.right.InsertValuesAtStem(stem, values, resolver, depth+1)
|
|
bt.mustRecompute = true
|
|
return bt, err
|
|
}
|
|
|
|
// CollectNodes collects all child nodes at a given path, and flushes it
|
|
// into the provided node collector.
|
|
func (bt *InternalNode) CollectNodes(path []byte, flushfn NodeFlushFn) error {
|
|
if bt.left != nil {
|
|
var p [256]byte
|
|
copy(p[:], path)
|
|
childpath := p[:len(path)]
|
|
childpath = append(childpath, 0)
|
|
if err := bt.left.CollectNodes(childpath, flushfn); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
if bt.right != nil {
|
|
var p [256]byte
|
|
copy(p[:], path)
|
|
childpath := p[:len(path)]
|
|
childpath = append(childpath, 1)
|
|
if err := bt.right.CollectNodes(childpath, flushfn); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
flushfn(path, bt)
|
|
return nil
|
|
}
|
|
|
|
// GetHeight returns the height of the node.
|
|
func (bt *InternalNode) GetHeight() int {
|
|
var (
|
|
leftHeight int
|
|
rightHeight int
|
|
)
|
|
if bt.left != nil {
|
|
leftHeight = bt.left.GetHeight()
|
|
}
|
|
if bt.right != nil {
|
|
rightHeight = bt.right.GetHeight()
|
|
}
|
|
return 1 + max(leftHeight, rightHeight)
|
|
}
|
|
|
|
func (bt *InternalNode) toDot(parent, path string) string {
|
|
me := fmt.Sprintf("internal%s", path)
|
|
ret := fmt.Sprintf("%s [label=\"I: %x\"]\n", me, bt.Hash())
|
|
if len(parent) > 0 {
|
|
ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me)
|
|
}
|
|
|
|
if bt.left != nil {
|
|
ret = fmt.Sprintf("%s%s", ret, bt.left.toDot(me, fmt.Sprintf("%s%02x", path, 0)))
|
|
}
|
|
if bt.right != nil {
|
|
ret = fmt.Sprintf("%s%s", ret, bt.right.toDot(me, fmt.Sprintf("%s%02x", path, 1)))
|
|
}
|
|
return ret
|
|
}
|