go-ethereum/trie/bintrie/store_commit.go
Ng Wei Han 73eeee65bf
Some checks are pending
/ Linux Build (push) Waiting to run
/ Linux Build (arm) (push) Waiting to run
/ Keeper Build (push) Waiting to run
/ Windows Build (push) Waiting to run
/ Docker Image (push) Waiting to run
trie/bintrie: use bitarray for path encoding + fix serialization issues (#34772)
Replace 1-byte-per-bit path encoding with bit-packed `BitArray`,
reducing DB key size by 8x

Benchmark (sparse single-leaf write, M3 Pro):
```
                           │  Before (1B/bit)  │  After (BitArray)  │
                           │      sec/op       │  sec/op   vs base  │
CollectNodesSparseWrite-11        10.50µ ± 1%   9.78µ ± 1%  -6.86%
                           │       B/op        │   B/op    vs base  │
CollectNodesSparseWrite-11       5.50Ki ± 0%   5.09Ki ± 0%  -7.38%
                           │    allocs/op      │ allocs/op  vs base │
CollectNodesSparseWrite-11          67 ± 0%       58 ± 0%  -13.43%
```

---------

Co-authored-by: Guillaume Ballet <3272758+gballet@users.noreply.github.com>
2026-06-22 22:26:17 +02:00

553 lines
18 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Copyright 2026 go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package bintrie
import (
"crypto/sha256"
"errors"
"fmt"
"math/bits"
"runtime"
"sync"
"github.com/ethereum/go-ethereum/common"
)
// nodeFlushFn is the sink collectNodes hands each persisted node to: the
// node's trie path, its hash, and its serialized on-disk blob.
type nodeFlushFn func(nodePath BitArray, nodeHash common.Hash, blob []byte)
// Hash returns the hash of the store's root node, recomputing any internal
// hashes that are marked stale on the way down.
func (s *nodeStore) Hash() common.Hash {
	root := s.root
	return s.computeHash(root)
}
// computeHash returns the hash of the node referenced by ref, dispatching on
// its kind. Empty references (and any kind not listed below) hash to the
// all-zero value.
func (s *nodeStore) computeHash(ref nodeRef) common.Hash {
	var digest common.Hash
	switch ref.Kind() {
	case kindInternal:
		digest = s.hashInternal(ref.Index())
	case kindStem:
		digest = s.getStem(ref.Index()).Hash()
	case kindHashed:
		digest = s.getHashed(ref.Index()).Hash()
	}
	return digest
}
// parallelHashDepth is the tree depth below which hashInternal hashes the
// two subtrees of a node concurrently. It is log2-ish in the CPU count
// (bits.Len of NumCPU), capped at 8. It is evaluated once at package init;
// runtime.NumCPU reflects the CPU set observed at process startup, so
// recomputing it later would not change the result.
var parallelHashDepth = func() int {
	depth := bits.Len(uint(runtime.NumCPU()))
	if depth > 8 {
		return 8
	}
	return depth
}()
// hashInternal returns the hash of the InternalNode at idx, recomputing and
// caching it when the node is marked mustRecompute.
//
// The hash is sha256(left || right), where an empty child contributes 32 zero
// bytes. At shallow depths (< parallelHashDepth), when BOTH children are
// present, the left subtree is hashed on a separate goroutine while the right
// subtree is hashed inline. With a single child there is nothing to overlap,
// so no goroutine is spawned. Deeper nodes are always hashed sequentially
// because the goroutine cost would outweigh the hashing work.
func (s *nodeStore) hashInternal(idx uint32) common.Hash {
	node := s.getInternal(idx)
	if !node.mustRecompute {
		return node.hash
	}
	// input is hashed in one shot with sha256.Sum256 rather than streamed
	// through a hash.Hash interface, which would force it to escape.
	var input [64]byte
	hasLeft := !node.left.IsEmpty()
	hasRight := !node.right.IsEmpty()
	if hasLeft && hasRight && int(node.depth) < parallelHashDepth {
		var lh common.Hash
		var wg sync.WaitGroup
		wg.Add(1)
		go func() {
			// defer keeps Done paired with the goroutine's exit on every path.
			// Note that a panic here still terminates the process: recover()
			// in the caller cannot catch a panic raised on another goroutine.
			defer wg.Done()
			lh = s.computeHash(node.left)
		}()
		rh := s.computeHash(node.right)
		wg.Wait()
		copy(input[:HashSize], lh[:])
		copy(input[HashSize:], rh[:])
	} else {
		if hasLeft {
			lh := s.computeHash(node.left)
			copy(input[:HashSize], lh[:])
		}
		if hasRight {
			rh := s.computeHash(node.right)
			copy(input[HashSize:], rh[:])
		}
	}
	node.hash = sha256.Sum256(input[:])
	node.mustRecompute = false
	return node.hash
}
// serializeSubtree walks one serialization group rooted at ref, descending
// through InternalNodes for up to remainingDepth levels, and records every
// non-empty node at which the walk stops. A node stops the walk either
// because it sits on the group's bottom layer or because it is a stem/hashed
// node found higher up.
//
// position is ref's index within its level of the group (it has
// groupDepth-remainingDepth significant bits). For each recorded node this
// function:
//   - sets the bitmap bit at position<<remainingDepth (MSB-first within each
//     byte), i.e. the leftmost bottom-layer slot the node covers;
//   - appends the node's hash to hashes;
//   - appends its depth relative to the group root (groupDepth-remainingDepth)
//     to depths.
//
// Empty references record nothing. Children are visited left before right, so
// entries come out in ascending bitmap-position order.
func (s *nodeStore) serializeSubtree(ref nodeRef, remainingDepth int, position int, groupDepth int, bitmap []byte, hashes *[]common.Hash, depths *[]uint8) {
	kind := ref.Kind()
	if kind == kindEmpty {
		return
	}
	if remainingDepth != 0 && kind == kindInternal {
		inner := s.getInternal(ref.Index())
		s.serializeSubtree(inner.left, remainingDepth-1, 2*position, groupDepth, bitmap, hashes, depths)
		s.serializeSubtree(inner.right, remainingDepth-1, 2*position+1, groupDepth, bitmap, hashes, depths)
		return
	}
	// Terminal entry: bottom layer (remainingDepth == 0, any non-empty kind),
	// or a stem/hashed node encountered before reaching the bottom layer.
	slot := position << remainingDepth
	bitmap[slot/8] |= 1 << (7 - slot%8)
	*hashes = append(*hashes, s.computeHash(ref))
	*depths = append(*depths, uint8(groupDepth-remainingDepth))
}
// depthBits is the width, in bits, of one packed depth entry. Three bits
// hold the stored values 0..7.
const depthBits = 3

// packedDepthsLen returns how many bytes k packed depth entries occupy,
// rounding the total bit count up to a whole byte.
func packedDepthsLen(k int) int {
	bitsNeeded := k * depthBits
	return (bitsNeeded + 7) / 8
}

// writeDepth stores the low depthBits bits of v as entry idx of buf, most
// significant bit first, with bytes filled from their high bit downward.
// It ORs bits in, so the target bits of buf are expected to start out zero.
func writeDepth(buf []byte, idx int, v uint8) {
	base := idx * depthBits
	for off := 0; off < depthBits; off++ {
		pos := base + off
		b := (v >> uint(depthBits-1-off)) & 1
		buf[pos/8] |= b << uint(7-pos%8)
	}
}

// readDepth returns entry idx of buf, decoding the layout written by
// writeDepth.
func readDepth(buf []byte, idx int) uint8 {
	base := idx * depthBits
	var out uint8
	for off := 0; off < depthBits; off++ {
		pos := base + off
		out <<= 1
		out |= (buf[pos/8] >> uint(7-pos%8)) & 1
	}
	return out
}
// serializeNode encodes the node referenced by ref into its flat on-disk form.
// Only internal and stem nodes can be serialized; any other kind panics.
//
// Internal nodes are written as a whole group of groupDepth levels:
//
//	[type(1)] [groupDepth(1)] [bitmap (2^groupDepth bits)] [depths(3 bits × K, padded)] [hashes(32B × K)]
//
// where each stored depth is the entry's depth within the group minus one.
//
// Stem nodes are written as:
//
//	[type(1)] [stem(StemSize)] [bitmap(StemBitmapSize)] [values(HashSize × present)]
//
// with one bitmap bit (MSB-first) per non-nil value slot.
func (s *nodeStore) serializeNode(ref nodeRef, groupDepth int) []byte {
	kind := ref.Kind()
	if kind == kindStem {
		stem := s.getStem(ref.Index())
		present := 0
		for _, val := range stem.values {
			if val != nil {
				present++
			}
		}
		header := NodeTypeBytes + StemSize + StemBitmapSize
		out := make([]byte, header+present*HashSize)
		out[0] = nodeTypeStem
		copy(out[NodeTypeBytes:NodeTypeBytes+StemSize], stem.Stem[:])
		presence := out[NodeTypeBytes+StemSize : header]
		cursor := header
		for slot, val := range stem.values {
			if val == nil {
				continue
			}
			presence[slot/8] |= 1 << (7 - slot%8)
			copy(out[cursor:cursor+HashSize], val)
			cursor += HashSize
		}
		return out
	}
	if kind != kindInternal {
		panic(fmt.Sprintf("SerializeNode: unexpected node kind %d", ref.Kind()))
	}

	groupBitmapSize := bitmapSizeForDepth(groupDepth)
	groupBitmap := make([]byte, groupBitmapSize)
	var entryHashes []common.Hash
	var entryDepths []uint8
	s.serializeSubtree(ref, groupDepth, 0, groupDepth, groupBitmap, &entryHashes, &entryDepths)

	count := len(entryHashes)
	packedLen := packedDepthsLen(count)
	depthsStart := NodeTypeBytes + 1 + groupBitmapSize
	hashesStart := depthsStart + packedLen
	out := make([]byte, hashesStart+count*HashSize)
	out[0] = nodeTypeInternal
	out[1] = byte(groupDepth)
	copy(out[2:2+groupBitmapSize], groupBitmap)
	packed := out[depthsStart : depthsStart+packedLen]
	for i, d := range entryDepths {
		// Depths within a group are always >= 1, so store d-1 to fit 1..8 in 3 bits.
		writeDepth(packed, i, d-1)
	}
	for i := range entryHashes {
		start := hashesStart + i*HashSize
		copy(out[start:start+HashSize], entryHashes[i].Bytes())
	}
	return out
}
// errInvalidSerializedLength is returned by the decoding functions when a
// blob's length does not match what its header declares.
var (
	errInvalidSerializedLength = errors.New("invalid serialized node length")
)
// deserializeNode decodes a node blob whose hash is not known in advance.
// The returned node is flagged to recompute its hash on demand, and is marked
// dirty because its provenance is unknown (re-flushing it is the safe default).
func (s *nodeStore) deserializeNode(serialized []byte, depth int) (nodeRef, error) {
	const (
		recompute = true // no trusted hash supplied
		markDirty = true // unknown provenance: flush again to be safe
	)
	return s.decodeNode(serialized, depth, common.Hash{}, recompute, markDirty)
}
// deserializeNodeWithHash decodes a node blob whose hash hn is already known
// and whose blob is already on disk. The node trusts hn (mustRecompute=false)
// and starts clean (dirty=false).
func (s *nodeStore) deserializeNodeWithHash(serialized []byte, depth int, hn common.Hash) (nodeRef, error) {
	const (
		recompute = false // hn is taken as the node's hash
		markDirty = false // the blob is already persisted
	)
	return s.decodeNode(serialized, depth, hn, recompute, markDirty)
}
// deserializeSubtree reconstructs an InternalNode group from its serialized
// parts and returns a reference to the group's root.
//
// The inputs are laid out as follows:
//   - bitmap: one bit per bottom-layer slot (2^groupDepth bits, MSB-first).
//   - depths: one packed depth entry per set bitmap bit, storing depth-1.
//   - hashData: HashSize bytes per set bitmap bit, in ascending bitmap order.
//
// Every entry is attached as a HashedNode at its recorded depth, and the
// InternalNodes on the way to it are created as needed. The root gets
// mustRecompute and dirty as given; hn is installed only when mustRecompute
// is false.
//
// Errors:
//   - A group with no entries decodes to emptyRef.
//   - Length or count mismatches return errInvalidSerializedLength. These are
//     checked before any node is allocated.
//   - Out-of-range depths, non-canonical positions or overlapping entries
//     return a descriptive error.
func (s *nodeStore) deserializeSubtree(hn common.Hash, groupDepth int, nodeDepth int, bitmap []byte, depths []byte, hashData []byte, mustRecompute bool, dirty bool) (nodeRef, error) {
	if len(hashData)%HashSize != 0 {
		return emptyRef, errInvalidSerializedLength
	}
	k := len(hashData) / HashSize
	if len(depths) != packedDepthsLen(k) {
		return emptyRef, errInvalidSerializedLength
	}
	bitmapBits := 1 << groupDepth
	if len(bitmap)*8 < bitmapBits {
		return emptyRef, errInvalidSerializedLength
	}
	// Each set bit consumes exactly one depth entry and one hash. Check the
	// counts agree up front: otherwise extra bits would index past depths or
	// hashData (panic), and missing bits would silently drop trailing hashes.
	present := 0
	for bit := 0; bit < bitmapBits; bit++ {
		present += int(bitmap[bit/8] >> (7 - (bit % 8)) & 1)
	}
	if present != k {
		return emptyRef, errInvalidSerializedLength
	}
	if k == 0 {
		return emptyRef, nil
	}
	rootRef := s.newInternalRef(nodeDepth)
	rootNode := s.getInternal(rootRef.Index())
	rootNode.mustRecompute = mustRecompute
	if !mustRecompute {
		rootNode.hash = hn
	}
	rootNode.dirty = dirty
	entryIdx := 0
	for bit := 0; bit < bitmapBits; bit++ {
		if bitmap[bit/8]>>(7-(bit%8))&1 == 0 {
			continue
		}
		depthOffset := int(readDepth(depths, entryIdx)) + 1
		if depthOffset > groupDepth {
			return emptyRef, errors.New("invalid depth offset")
		}
		// Canonical-encoding check: trailing position bits must be zero.
		mask := (1 << (groupDepth - depthOffset)) - 1
		if bit&mask != 0 {
			return emptyRef, errors.New("non-canonical bitmap position")
		}
		var hash common.Hash
		copy(hash[:], hashData[entryIdx*HashSize:(entryIdx+1)*HashSize])
		if err := s.attachInGroup(rootRef, nodeDepth, groupDepth, depthOffset, bit, hash, dirty); err != nil {
			return emptyRef, err
		}
		entryIdx++
	}
	return rootRef, nil
}
// attachInGroup places a HashedNode carrying hash inside the group rooted at
// rootRef. The node goes at depthOffset levels below the root, along the path
// given by the high bits of bitmapPos (a groupDepth-bit, MSB-first position;
// 0 = left, 1 = right).
//
// Missing InternalNodes on the path are created at their absolute depth
// (rootDepth + level + 1) with the given dirty flag. It returns an
// "overlapping entries" error if the path runs into a non-internal node, or
// if the target slot is already occupied.
//
// The parent pointer is re-resolved via getInternal after each allocation
// instead of reusing one obtained before it. NOTE(review): this guards
// against store growth moving existing nodes; confirm whether the slabs
// guarantee pointer stability.
func (s *nodeStore) attachInGroup(rootRef nodeRef, rootDepth, groupDepth, depthOffset, bitmapPos int, hash common.Hash, dirty bool) error {
	cur := rootRef
	for level := 0; level < depthOffset-1; level++ {
		goRight := (bitmapPos>>(groupDepth-1-level))&1 == 1
		node := s.getInternal(cur.Index())
		child := node.left
		if goRight {
			child = node.right
		}
		if child.IsEmpty() {
			child = s.newInternalRef(rootDepth + level + 1)
			s.getInternal(child.Index()).dirty = dirty
			// Re-resolve the parent: the allocation above may have grown the
			// storage that node pointed into.
			node = s.getInternal(cur.Index())
			if goRight {
				node.right = child
			} else {
				node.left = child
			}
		} else if child.Kind() != kindInternal {
			return errors.New("overlapping entries in group blob")
		}
		cur = child
	}
	goRight := (bitmapPos>>(groupDepth-depthOffset))&1 == 1
	node := s.getInternal(cur.Index())
	occupant := node.left
	if goRight {
		occupant = node.right
	}
	if !occupant.IsEmpty() {
		return errors.New("overlapping entries in group blob")
	}
	leaf := s.newHashedRef(hash)
	node = s.getInternal(cur.Index())
	if goRight {
		node.right = leaf
	} else {
		node.left = leaf
	}
	return nil
}
// decodeNode parses a serialized node blob, loads it into the store and
// returns a reference to it. The layout is selected by the first byte:
//
//	nodeTypeInternal: [type(1)] [groupDepth(1)] [bitmap (2^groupDepth bits, padded to bitmapSize bytes)]
//	                  [depthOffsets (3 bits × K, padded to bytes)] [hashes (32B × K)]
//	nodeTypeStem:     [type(1)] [stem(StemSize)] [bitmap(StemBitmapSize)] [values (HashSize × present)]
//
// Decoding behaviour:
//   - depth is the trie depth of the decoded node.
//   - mustRecompute and dirty are applied to the decoded root. For internal
//     groups hn is installed only when mustRecompute is false; stems store
//     hn as given.
//   - An empty blob decodes to emptyRef.
//
// Both layouts must have exactly the length their header implies; a blob of
// any other length yields errInvalidSerializedLength. Stem validation runs
// before the stem slot is allocated, so a malformed blob does not consume a
// slot.
//
// Stem values alias serialized (zero-copy), so the caller must not mutate
// serialized afterwards. Each value's capacity is clipped to HashSize so that
// appending to one value cannot overwrite its neighbour.
func (s *nodeStore) decodeNode(serialized []byte, depth int, hn common.Hash, mustRecompute, dirty bool) (nodeRef, error) {
	if len(serialized) == 0 {
		return emptyRef, nil
	}
	switch serialized[0] {
	case nodeTypeInternal:
		if len(serialized) < NodeTypeBytes+1 {
			return emptyRef, errInvalidSerializedLength
		}
		groupDepth := int(serialized[1])
		if groupDepth < 1 || groupDepth > MaxGroupDepth {
			return emptyRef, errors.New("invalid group depth")
		}
		bitmapSize := bitmapSizeForDepth(groupDepth)
		if len(serialized) < NodeTypeBytes+1+bitmapSize {
			return emptyRef, errInvalidSerializedLength
		}
		bitmap := serialized[2 : 2+bitmapSize]
		bitmapBits := 1 << groupDepth
		if bitmapBits < 8 {
			// Canonical-encoding check: unused low bits of the single bitmap
			// byte must be zero.
			padMask := byte(0xFF) >> bitmapBits
			if bitmap[0]&padMask != 0 {
				return emptyRef, errors.New("non-canonical bitmap padding")
			}
		}
		k := 0
		for _, b := range bitmap {
			k += bits.OnesCount8(b)
		}
		depthsLen := packedDepthsLen(k)
		expectedLen := NodeTypeBytes + 1 + bitmapSize + depthsLen + k*HashSize
		if len(serialized) != expectedLen {
			return emptyRef, errInvalidSerializedLength
		}
		depthsOff := NodeTypeBytes + 1 + bitmapSize
		depths := serialized[depthsOff : depthsOff+depthsLen]
		hashData := serialized[depthsOff+depthsLen : depthsOff+depthsLen+k*HashSize]
		// Canonical-encoding check: the unused low bits of the last packed
		// depth byte must be zero.
		if usedBits := k * depthBits; usedBits%8 != 0 {
			padMask := byte(0xFF) >> (usedBits % 8)
			if depths[depthsLen-1]&padMask != 0 {
				return emptyRef, errors.New("non-canonical depth padding")
			}
		}
		return s.deserializeSubtree(hn, groupDepth, depth, bitmap, depths, hashData, mustRecompute, dirty)
	case nodeTypeStem:
		header := NodeTypeBytes + StemSize + StemBitmapSize
		if len(serialized) < header {
			return emptyRef, errInvalidSerializedLength
		}
		bitmap := serialized[NodeTypeBytes+StemSize : header]
		present := 0
		for i := range StemNodeWidth {
			if bitmap[i/8]>>(7-(i%8))&1 == 1 {
				present++
			}
		}
		// Exact length, as for internal groups: no missing and no trailing bytes.
		if len(serialized) != header+present*HashSize {
			return emptyRef, errInvalidSerializedLength
		}
		stemIdx := s.allocStem()
		sn := s.getStem(stemIdx)
		copy(sn.Stem[:], serialized[NodeTypeBytes:NodeTypeBytes+StemSize])
		offset := header
		for i := range StemNodeWidth {
			if bitmap[i/8]>>(7-(i%8))&1 != 1 {
				continue
			}
			// Zero-copy: each slot aliases the serialized input buffer. The
			// full slice expression caps capacity so an append reallocates
			// instead of overwriting the next slot's bytes.
			sn.values[i] = serialized[offset : offset+HashSize : offset+HashSize]
			offset += HashSize
		}
		sn.depth = uint8(depth)
		sn.hash = hn
		sn.mustRecompute = mustRecompute
		sn.dirty = dirty
		return makeRef(kindStem, stemIdx), nil
	default:
		return emptyRef, errors.New("invalid node type")
	}
}
// collectNodes passes every dirty node under ref to flushfn, children before
// parents (post-order), and clears the dirty flag of each node it flushes.
// Internal nodes are flushed only at group boundaries
// (depth % groupDepth == 0), as one blob per group. Stems are flushed
// individually. Hashed and empty references produce nothing.
//
// Invariant: any ancestor of a node that needs flushing is itself marked, so
// a clean root means the whole subtree is clean.
//
// BitArray is passed by value (33 bytes) to keep child paths on the stack.
// Passing by pointer causes escape to heap per recursive call.
func (s *nodeStore) collectNodes(ref nodeRef, path BitArray, flushfn nodeFlushFn, groupDepth int) {
	switch ref.Kind() {
	case kindHashed, kindEmpty:
		return
	case kindStem:
		stem := s.getStem(ref.Index())
		if !stem.dirty {
			return
		}
		flushfn(path, s.computeHash(ref), s.serializeNode(ref, groupDepth))
		stem.dirty = false
	case kindInternal:
		node := s.getInternal(ref.Index())
		if !node.dirty {
			return
		}
		offset := int(node.depth) % groupDepth
		// First reach down to the roots of the next group layer (and any stems
		// inside this group) so they are flushed before this group's blob.
		s.collectChildGroups(node, path, flushfn, groupDepth, groupDepth-offset-1)
		if offset != 0 {
			// Not on a group boundary. This shouldn't happen when called from
			// the root, so only the traversal above is performed.
			return
		}
		flushfn(path, s.computeHash(ref), s.serializeNode(ref, groupDepth))
		node.dirty = false
	default:
		panic(fmt.Sprintf("CollectNodes: unexpected kind %d", ref.Kind()))
	}
}
// collectChildGroups walks down through the InternalNodes of the current
// group below node, extending path by one bit per level (0 = left,
// 1 = right). It hands each node where the walk stops to collectNodes:
//   - the children of nodes at the last level (remainingLevels == 0), which
//     are the roots of the next group;
//   - any non-internal child met earlier (e.g. a stem inside the group).
//
// The error result is only ever non-nil if a recursive call returns one.
func (s *nodeStore) collectChildGroups(node *InternalNode, path BitArray, flushfn nodeFlushFn, groupDepth int, remainingLevels int) error {
	if left := node.left; !left.IsEmpty() {
		childPath := path
		childPath.AppendBit(&childPath, 0)
		if remainingLevels != 0 && left.Kind() == kindInternal {
			if err := s.collectChildGroups(s.getInternal(left.Index()), childPath, flushfn, groupDepth, remainingLevels-1); err != nil {
				return err
			}
		} else {
			s.collectNodes(left, childPath, flushfn, groupDepth)
		}
	}
	if right := node.right; !right.IsEmpty() {
		childPath := path
		childPath.AppendBit(&childPath, 1)
		if remainingLevels != 0 && right.Kind() == kindInternal {
			if err := s.collectChildGroups(s.getInternal(right.Index()), childPath, flushfn, groupDepth, remainingLevels-1); err != nil {
				return err
			}
		} else {
			s.collectNodes(right, childPath, flushfn, groupDepth)
		}
	}
	return nil
}
// toDot renders the subtree at ref as Graphviz DOT statements: node
// declarations plus parent->child edges, without a surrounding graph wrapper.
//
// Parameters:
//   - parent: the DOT identifier of the enclosing node, or "" at the root.
//   - path: the accumulated bit path ("0"/"1" per level), used to make
//     identifiers unique.
//
// Edges from parent are emitted only when parent is non-empty, so a stem or
// hashed root no longer yields a dangling " -> x" statement. This matches
// what the internal-node case already did.
func (s *nodeStore) toDot(ref nodeRef, parent, path string) string {
	switch ref.Kind() {
	case kindInternal:
		node := s.getInternal(ref.Index())
		me := fmt.Sprintf("internal%s", path)
		ret := fmt.Sprintf("%s [label=\"I: %x\"]\n", me, s.computeHash(ref))
		if len(parent) > 0 {
			ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me)
		}
		if !node.left.IsEmpty() {
			ret += s.toDot(node.left, me, path+"0")
		}
		if !node.right.IsEmpty() {
			ret += s.toDot(node.right, me, path+"1")
		}
		return ret
	case kindStem:
		sn := s.getStem(ref.Index())
		me := fmt.Sprintf("stem%s", path)
		ret := fmt.Sprintf("%s [label=\"stem=%x c=%x\"]\n", me, sn.Stem, sn.Hash())
		if len(parent) > 0 {
			ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me)
		}
		for i, v := range sn.values {
			if v == nil {
				continue
			}
			ret += fmt.Sprintf("%s%x [label=\"%x\"]\n", me, i, v)
			ret += fmt.Sprintf("%s -> %s%x\n", me, me, i)
		}
		return ret
	case kindHashed:
		hn := s.getHashed(ref.Index())
		me := fmt.Sprintf("hash%s", path)
		ret := fmt.Sprintf("%s [label=\"%x\"]\n", me, hn.Hash())
		if len(parent) > 0 {
			ret = fmt.Sprintf("%s %s -> %s\n", ret, parent, me)
		}
		return ret
	default:
		return ""
	}
}