From 73eeee65bf6dd1f3d81726e54c1ce62bcf256db1 Mon Sep 17 00:00:00 2001
From: Ng Wei Han <47109095+weiihann@users.noreply.github.com>
Date: Tue, 23 Jun 2026 03:26:17 +0700
Subject: [PATCH] trie/bintrie: use bitarray for path encoding + fix
serialization issues (#34772)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Replace 1-byte-per-bit path encoding with bit-packed `BitArray`,
reducing DB key size by 8x
Benchmark (sparse single-leaf write, M3 Pro):
```
│ Before (1B/bit) │ After (BitArray) │
│ sec/op │ sec/op vs base │
CollectNodesSparseWrite-11 10.50µ ± 1% 9.78µ ± 1% -6.86%
│ B/op │ B/op vs base │
CollectNodesSparseWrite-11 5.50Ki ± 0% 5.09Ki ± 0% -7.38%
│ allocs/op │ allocs vs base │
CollectNodesSparseWrite-11 67 ± 0% 58 ± 0% -13.43%
```
---------
Co-authored-by: Guillaume Ballet <3272758+gballet@users.noreply.github.com>
---
trie/bintrie/binary_node_test.go | 29 +-
trie/bintrie/bitarray.go | 160 +++++++++++
trie/bintrie/bitarray_test.go | 205 +++++++++++++
trie/bintrie/format_test.go | 445 +++++++++++++++++++++++++++++
trie/bintrie/internal_node.go | 8 +-
trie/bintrie/internal_node_test.go | 11 +-
trie/bintrie/iterator.go | 8 +-
trie/bintrie/node_store.go | 11 +-
trie/bintrie/stem_node_test.go | 16 +-
trie/bintrie/store_commit.go | 306 ++++++++++++--------
trie/bintrie/store_ops.go | 8 +
trie/bintrie/trie.go | 33 ++-
12 files changed, 1080 insertions(+), 160 deletions(-)
create mode 100644 trie/bintrie/bitarray.go
create mode 100644 trie/bintrie/bitarray_test.go
create mode 100644 trie/bintrie/format_test.go
diff --git a/trie/bintrie/binary_node_test.go b/trie/bintrie/binary_node_test.go
index 857060a0c0..d91a85a765 100644
--- a/trie/bintrie/binary_node_test.go
+++ b/trie/bintrie/binary_node_test.go
@@ -40,7 +40,10 @@ func TestSerializeDeserializeInternalNode(t *testing.T) {
s.root = rootRef
// Serialize the node — grouped format at groupDepth=1:
- // [type(1)][groupDepth(1)][bitmap(1)][leftHash(32)][rightHash(32)] = 67 bytes
+ // [type(1)][groupDepth(1)][bitmap(1)][depths(1)][leftHash(32)][rightHash(32)] = 68 bytes.
+ // Both children are at depthOffset=groupDepth=1 (the bottom of the 1-level
+ // group). Each depth is stored as (offset-1)=0 in 3 bits, so the two entries
+ // pack into a single byte 0x00.
serialized := s.serializeNode(rootRef, 1)
if serialized[0] != nodeTypeInternal {
@@ -50,7 +53,7 @@ func TestSerializeDeserializeInternalNode(t *testing.T) {
t.Errorf("Expected groupDepth byte to be 1, got %d", serialized[1])
}
- expectedLen := NodeTypeBytes + 1 + 1 + 2*HashSize // type + groupDepth + bitmap + 2 hashes = 67
+ expectedLen := NodeTypeBytes + 1 + 1 + 1 + 2*HashSize // type + groupDepth + bitmap + packed depths + 2 hashes = 68
if len(serialized) != expectedLen {
t.Errorf("Expected serialized length to be %d, got %d", expectedLen, len(serialized))
}
@@ -60,7 +63,13 @@ func TestSerializeDeserializeInternalNode(t *testing.T) {
t.Errorf("Expected bitmap byte 0xc0, got 0x%02x", serialized[2])
}
- hashesStart := NodeTypeBytes + 1 + 1
+ depthsStart := NodeTypeBytes + 1 + 1
+ // Two depth offsets of 1 → stored as (1-1)=0 each → packed byte 0x00.
+ if serialized[depthsStart] != 0x00 {
+ t.Errorf("Expected packed depth byte 0x00, got 0x%02x", serialized[depthsStart])
+ }
+
+ hashesStart := depthsStart + 1
if !bytes.Equal(serialized[hashesStart:hashesStart+HashSize], leftHash[:]) {
t.Error("Left hash not found at expected position")
}
@@ -244,29 +253,29 @@ func TestKeyToPath(t *testing.T) {
{
name: "depth 0",
depth: 0,
- key: []byte{0x80}, // 10000000 in binary
- expected: []byte{1},
+ key: []byte{0x80}, // 10000000 in binary
+ expected: []byte{0x80, 1}, // 1 bit "1", left-aligned, + length byte 1
wantErr: false,
},
{
name: "depth 7",
depth: 7,
- key: []byte{0xFF}, // 11111111 in binary
- expected: []byte{1, 1, 1, 1, 1, 1, 1, 1},
+ key: []byte{0xFF}, // 11111111 in binary
+ expected: []byte{0xFF, 8}, // 8-bit value 0xFF + length byte 8
wantErr: false,
},
{
name: "depth crossing byte boundary",
depth: 10,
- key: []byte{0xFF, 0x00}, // 11111111 00000000 in binary
- expected: []byte{1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0},
+ key: []byte{0xFF, 0x00}, // 11111111 00000000 in binary
+ expected: []byte{0xFF, 0x00, 11}, // top 11 bits "11111111 000", left-aligned, + length byte 11
wantErr: false,
},
{
name: "max valid depth",
depth: StemSize*8 - 1,
key: make([]byte, HashSize),
- expected: make([]byte, StemSize*8),
+ expected: append(make([]byte, StemSize), StemSize*8), // 248 bits of zeros + length byte 248
wantErr: false,
},
{
diff --git a/trie/bintrie/bitarray.go b/trie/bintrie/bitarray.go
new file mode 100644
index 0000000000..4da3e0c5b5
--- /dev/null
+++ b/trie/bintrie/bitarray.go
@@ -0,0 +1,160 @@
+// Copyright 2026 The go-ethereum Authors
+// This file is part of go-ethereum.
+//
+// go-ethereum is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// go-ethereum is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with go-ethereum. If not, see .
+package bintrie
+
+// BitArray represents a trie path: the most significant `len` bits of a key,
+// packed big-endian and MSB-first. Bit i (0 = most significant) lives at
+// bytes[i/8] in mask 1<<(7-i%8). All bits at positions >= len are kept zero so
+// that two paths are byte-equal iff they are logically equal.
+//
+// This mirrors the on-disk key layout, so path manipulation is plain slicing
+// and copying: no shifting or endianness conversion is required. The maximum
+// length is 248 bits (a 31-byte trie stem), and is a uint8 so the spare bits in
+// the final byte are always available.
+type BitArray struct {
+ len uint8
+ bytes [32]byte
+}
+
+// NewBitArray creates a bit array of the given length whose bits are the `length`
+// least-significant bits of val, read most-significant-first. Used by tests to
+// build expected paths; the value is interpreted as a number, not raw bytes.
+func NewBitArray(length uint8, val uint64) BitArray {
+ var b BitArray
+ b.len = length
+ for p := uint8(0); p < length; p++ {
+ if (val>>(length-1-p))&1 == 1 {
+ b.bytes[p/8] |= 1 << (7 - p%8)
+ }
+ }
+ return b
+}
+
+// Len returns the number of used bits.
+func (b *BitArray) Len() uint8 {
+ return b.len
+}
+
+// Bytes returns the packed big-endian, MSB-first representation. Bits beyond
+// len are zero.
+func (b *BitArray) Bytes() [32]byte {
+ return b.bytes
+}
+
+// AppendBit sets the bit array to x with a single bit appended, and returns the
+// receiver. Safe when b and x alias the same value.
+func (b *BitArray) AppendBit(x *BitArray, bit uint8) *BitArray {
+ *b = *x
+ if bit&1 == 1 {
+ // Position b.len is guaranteed zero by the all-bits-beyond-len-are-zero
+ // invariant, so a 1 only needs setting; a 0 is already in place.
+ b.bytes[b.len/8] |= 1 << (7 - b.len%8)
+ }
+ b.len++
+ return b
+}
+
+// MSBs sets the bit array to the most significant n bits of x and returns the
+// receiver. If n >= x.len it is an exact copy of x. Think of it as x[:n].
+func (b *BitArray) MSBs(x *BitArray, n uint8) *BitArray {
+ *b = *x
+ if n < b.len {
+ b.len = n
+ b.maskTail()
+ }
+ return b
+}
+
+// Equal reports whether two bit arrays hold the same path.
+func (b *BitArray) Equal(x *BitArray) bool {
+ return b.len == x.len && b.bytes == x.bytes
+}
+
+// SetBytes sets the bit array to the most significant `length` bits of data,
+// interpreted as big-endian bytes, and returns the receiver. At most 32 bytes
+// of data are read; bits beyond length are zeroed.
+func (b *BitArray) SetBytes(length uint8, data []byte) *BitArray {
+ b.bytes = [32]byte{}
+ copy(b.bytes[:], data)
+ b.len = length
+ b.maskTail()
+ return b
+}
+
+// SetBit sets the bit array to a single bit and returns the receiver.
+func (b *BitArray) SetBit(bit uint8) *BitArray {
+ b.bytes = [32]byte{}
+ b.len = 1
+ if bit&1 == 1 {
+ b.bytes[0] = 0x80
+ }
+ return b
+}
+
+// Copy returns a value copy of the bit array.
+func (b *BitArray) Copy() BitArray {
+ return *b
+}
+
+// Set sets the bit array to the same value as x and returns the receiver.
+func (b *BitArray) Set(x *BitArray) *BitArray {
+ *b = *x
+ return b
+}
+
+// KeyBytes returns the path-to-DB-key encoding: the active bytes (the
+// left-aligned MSB-first prefix) followed by a single trailing byte holding the
+// bit-length. The trailing length disambiguates paths whose active bytes
+// coincide (e.g. 1-bit "1" packs to [0x80, 0x01] and 8-bit "10000000" to
+// [0x80, 0x08]). The empty path encodes as no bytes.
+func (b *BitArray) KeyBytes() []byte {
+ if b.len == 0 {
+ return nil
+ }
+ bc := (int(b.len) + 7) / 8
+ res := make([]byte, bc+1)
+ copy(res[:bc], b.bytes[:bc])
+ res[bc] = b.len
+ return res
+}
+
+// PutKeyBytes writes the key encoding (active bytes followed by length byte)
+// into dst and returns the populated sub-slice. The empty path returns dst[:0]
+// without touching dst. For non-empty paths dst must have len >= 33 (32 packed
+// bytes for 248 bits + 1 length byte).
+func (b *BitArray) PutKeyBytes(dst []byte) []byte {
+ if b.len == 0 {
+ return dst[:0]
+ }
+ bc := (int(b.len) + 7) / 8
+ _ = dst[bc] // bounds check hint
+ copy(dst[:bc], b.bytes[:bc])
+ dst[bc] = b.len
+ return dst[:bc+1]
+}
+
+// maskTail zeroes every bit at a position >= len, preserving the invariant that
+// equal paths are byte-equal.
+func (b *BitArray) maskTail() {
+ full := int(b.len / 8)
+ if rem := b.len % 8; rem != 0 {
+ b.bytes[full] &= byte(0xFF) << (8 - rem)
+ full++
+ }
+ for i := full; i < len(b.bytes); i++ {
+ b.bytes[i] = 0
+ }
+}
diff --git a/trie/bintrie/bitarray_test.go b/trie/bintrie/bitarray_test.go
new file mode 100644
index 0000000000..10cfc4dc20
--- /dev/null
+++ b/trie/bintrie/bitarray_test.go
@@ -0,0 +1,205 @@
+package bintrie
+
+import (
+ "bytes"
+ "testing"
+)
+
+// ba builds a BitArray with the given length and leading bytes, for use as an
+// expected value. Remaining bytes are zero.
+func ba(length uint8, lead ...byte) BitArray {
+ var b BitArray
+ b.len = length
+ copy(b.bytes[:], lead)
+ return b
+}
+
+func TestNewBitArray(t *testing.T) {
+ tests := []struct {
+ name string
+ length uint8
+ val uint64
+ want BitArray
+ }{
+ {"empty", 0, 0, ba(0)},
+ {"single 1", 1, 1, ba(1, 0x80)},
+ {"single 0", 1, 0, ba(1, 0x00)},
+ {"101", 3, 0b101, ba(3, 0xA0)},
+ {"full byte", 8, 0xFF, ba(8, 0xFF)},
+ {"ten bits", 10, 0x3FF, ba(10, 0xFF, 0xC0)},
+ {"high bits ignored beyond length", 3, 0b11101, ba(3, 0xA0)},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := NewBitArray(tt.length, tt.val)
+ if !got.Equal(&tt.want) {
+ t.Errorf("NewBitArray(%d, %#x) = %x (len %d), want %x (len %d)",
+ tt.length, tt.val, got.bytes, got.len, tt.want.bytes, tt.want.len)
+ }
+ })
+ }
+}
+
+func TestSetBytes(t *testing.T) {
+ tests := []struct {
+ name string
+ length uint8
+ data []byte
+ want BitArray
+ }{
+ {"empty", 0, []byte{0xFF}, ba(0)},
+ {"full byte", 8, []byte{0xAB}, ba(8, 0xAB)},
+ {"top 4 bits", 4, []byte{0xFF}, ba(4, 0xF0)},
+ {"11 bits masks tail", 11, []byte{0xFF, 0xFF}, ba(11, 0xFF, 0xE0)},
+ {"data longer than length", 4, []byte{0xFF, 0xFF}, ba(4, 0xF0)},
+ {"data shorter than length", 16, []byte{0xAB}, ba(16, 0xAB, 0x00)},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := new(BitArray).SetBytes(tt.length, tt.data)
+ if !got.Equal(&tt.want) {
+ t.Errorf("SetBytes(%d, %x) = %x (len %d), want %x (len %d)",
+ tt.length, tt.data, got.bytes, got.len, tt.want.bytes, tt.want.len)
+ }
+ })
+ }
+}
+
+func TestSetBytesFull(t *testing.T) {
+ data := bytes.Repeat([]byte{0xFF}, 32)
+ got := new(BitArray).SetBytes(248, data)
+ want := ba(248)
+ for i := 0; i < 31; i++ {
+ want.bytes[i] = 0xFF
+ }
+ if !got.Equal(&want) {
+ t.Errorf("SetBytes(248, 0xFF*32): byte 31 must be zeroed; got %x", got.bytes)
+ }
+}
+
+func TestMSBs(t *testing.T) {
+ x := new(BitArray).SetBytes(16, []byte{0xAB, 0xCD})
+ tests := []struct {
+ name string
+ n uint8
+ want BitArray
+ }{
+ {"prefix byte", 8, ba(8, 0xAB)},
+ {"prefix nibble", 4, ba(4, 0xA0)},
+ {"zero", 0, ba(0)},
+ {"n equals len", 16, ba(16, 0xAB, 0xCD)},
+ {"n exceeds len copies x", 20, ba(16, 0xAB, 0xCD)},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := new(BitArray).MSBs(x, tt.n)
+ if !got.Equal(&tt.want) {
+ t.Errorf("MSBs(x, %d) = %x (len %d), want %x (len %d)",
+ tt.n, got.bytes, got.len, tt.want.bytes, tt.want.len)
+ }
+ })
+ }
+}
+
+func TestAppendBit(t *testing.T) {
+ // Build "101" one bit at a time from empty.
+ var p BitArray
+ for _, bit := range []uint8{1, 0, 1} {
+ p.AppendBit(&p, bit) // receiver aliases argument
+ }
+ if want := ba(3, 0xA0); !p.Equal(&want) {
+ t.Fatalf("append 1,0,1 = %x (len %d), want %x (len 3)", p.bytes, p.len, want.bytes)
+ }
+
+ // Append across a byte boundary: 8 ones then a 1 → 9 bits.
+ var q BitArray
+ for i := 0; i < 9; i++ {
+ q.AppendBit(&q, 1)
+ }
+ if want := ba(9, 0xFF, 0x80); !q.Equal(&want) {
+ t.Fatalf("append nine 1s = %x (len %d), want %x (len 9)", q.bytes, q.len, want.bytes)
+ }
+
+ // Appending to a copy must not mutate the source.
+ src := new(BitArray).SetBytes(4, []byte{0xF0})
+ child := *src
+ child.AppendBit(&child, 0)
+ if want := ba(4, 0xF0); !src.Equal(&want) {
+ t.Errorf("source mutated by append on copy: %x", src.bytes)
+ }
+ if want := ba(5, 0xF0); !child.Equal(&want) {
+ t.Errorf("append 0 = %x (len %d), want %x (len 5)", child.bytes, child.len, want.bytes)
+ }
+}
+
+func TestSetBit(t *testing.T) {
+ if got, want := new(BitArray).SetBit(1), ba(1, 0x80); !got.Equal(&want) {
+ t.Errorf("SetBit(1) = %x (len %d), want %x", got.bytes, got.len, want.bytes)
+ }
+ if got, want := new(BitArray).SetBit(0), ba(1, 0x00); !got.Equal(&want) {
+ t.Errorf("SetBit(0) = %x (len %d), want %x", got.bytes, got.len, want.bytes)
+ }
+}
+
+func TestEqual(t *testing.T) {
+ a := NewBitArray(3, 0b101)
+ b := NewBitArray(3, 0b101)
+ if !a.Equal(&b) {
+ t.Error("equal arrays reported unequal")
+ }
+ // Same active bytes, different length must be unequal.
+ c := NewBitArray(2, 0b10) // "10" -> byte 0x80, len 2
+ d := ba(3, c.bytes[0]) // same byte, len 3
+ if c.Equal(&d) {
+ t.Error("arrays with different length reported equal")
+ }
+}
+
+func TestKeyBytesRoundTrip(t *testing.T) {
+ tests := []struct {
+ name string
+ length uint8
+ data []byte
+ want []byte // expected KeyBytes output
+ }{
+ {"empty", 0, nil, nil},
+ {"one bit", 1, []byte{0x80}, []byte{0x80, 1}},
+ {"full byte", 8, []byte{0x80}, []byte{0x80, 8}},
+ {"eleven bits", 11, []byte{0xFF, 0xFF}, []byte{0xFF, 0xE0, 11}},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ src := new(BitArray).SetBytes(tt.length, tt.data)
+ key := src.KeyBytes()
+ if !bytes.Equal(key, tt.want) {
+ t.Fatalf("KeyBytes() = %x, want %x", key, tt.want)
+ }
+
+ // PutKeyBytes must agree with KeyBytes.
+ var buf [33]byte
+ if put := src.PutKeyBytes(buf[:]); !bytes.Equal(put, tt.want) {
+ t.Fatalf("PutKeyBytes() = %x, want %x", put, tt.want)
+ }
+
+ // Re-parse the active bytes and confirm the path round-trips.
+ if tt.length == 0 {
+ return
+ }
+ lengthByte := key[len(key)-1]
+ reparsed := new(BitArray).SetBytes(lengthByte, key[:len(key)-1])
+ if !reparsed.Equal(src) {
+ t.Fatalf("round-trip mismatch: %x (len %d) != %x (len %d)",
+ reparsed.bytes, reparsed.len, src.bytes, src.len)
+ }
+ })
+ }
+}
+
+func TestCopyIsIndependent(t *testing.T) {
+ src := new(BitArray).SetBytes(8, []byte{0xAB})
+ cp := src.Copy()
+ cp.AppendBit(&cp, 1)
+ if want := ba(8, 0xAB); !src.Equal(&want) {
+ t.Errorf("Copy not independent: source became %x (len %d)", src.bytes, src.len)
+ }
+}
diff --git a/trie/bintrie/format_test.go b/trie/bintrie/format_test.go
new file mode 100644
index 0000000000..0942ce2dc2
--- /dev/null
+++ b/trie/bintrie/format_test.go
@@ -0,0 +1,445 @@
+// Copyright 2026 go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package bintrie
+
+import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "testing"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/trie"
+ "github.com/ethereum/go-ethereum/trie/trienode"
+)
+
+// TestRootHashMatchesReadBackHash pins the round-trip invariant: the root
+// hash a Commit advertises must be exactly the value a fresh reader computes
+// from the on-disk blob. Before Option B the writer produced a natural-depth
+// hash while DeserializeAndHash produced an extended-depth hash, so the two
+// disagreed for any non-trivial subtree — this test failed. With the
+// per-entry depth byte, the reader rebuilds the natural-shape tree and the
+// hashes match for every groupDepth and every divergence bit.
+func TestRootHashMatchesReadBackHash(t *testing.T) {
+ for groupDepth := 1; groupDepth <= MaxGroupDepth; groupDepth++ {
+ // divergeBit ∈ [0, groupDepth-1] places the two stems at natural
+ // depth (divergeBit+1) within the root group; we want to exercise
+ // every depth-offset value the new format must handle.
+ for divergeBit := 0; divergeBit < groupDepth; divergeBit++ {
+ t.Run(fmt.Sprintf("gd=%d/diverge=%d", groupDepth, divergeBit), func(t *testing.T) {
+ tr := &BinaryTrie{
+ store: newNodeStore(),
+ tracer: trie.NewPrevalueTracer(),
+ groupDepth: groupDepth,
+ }
+ stemL, stemR := stemsDivergingAt(divergeBit)
+ if err := tr.store.Insert(stemL, oneKey[:], nil); err != nil {
+ t.Fatalf("Insert stemL: %v", err)
+ }
+ if err := tr.store.Insert(stemR, twoKey[:], nil); err != nil {
+ t.Fatalf("Insert stemR: %v", err)
+ }
+
+ natural := tr.Hash()
+ _, ns := tr.Commit(false)
+ rootNode, ok := ns.Nodes[""]
+ if !ok {
+ t.Fatalf("Commit produced no root blob (path \"\")")
+ }
+
+ readBack, err := DeserializeAndHash(rootNode.Blob, 0)
+ if err != nil {
+ t.Fatalf("DeserializeAndHash: %v", err)
+ }
+ if natural != readBack {
+ t.Fatalf("round-trip hash mismatch:\n"+
+ " tr.Hash() = %x\n"+
+ " DeserializeAndHash(rootBlob) = %x\n"+
+ "the parent's stored root hash cannot be reproduced from its own blob",
+ natural, readBack)
+ }
+ })
+ }
+ }
+}
+
+// TestMultiStemMixedDepths inserts four stems that diverge at different
+// depths within a single groupDepth=5 group, then round-trips the trie
+// through Commit + fresh-read. Verifies that every stem is retrievable by
+// key after reload — exercises the new format with several depth-offset
+// values in the same blob (1, 2, 3, 4) and confirms attachInGroup builds
+// the natural-shape tree correctly.
+func TestMultiStemMixedDepths(t *testing.T) {
+ const groupDepth = 5
+
+ // Each stem diverges from `0x00…00` at a different bit, so naturally:
+ // - stem at bit-0 divergence → depth 1
+ // - stem at bit-1 divergence → depth 2
+ // - stem at bit-2 divergence → depth 3
+ // - stem at bit-3 divergence → depth 4
+ stems := [][]byte{
+ zeroKey[:],
+ bitFlipStem(0), // diverge at bit 0
+ bitFlipStem(1), // diverge at bit 1 (prefix "0" matches stem 0)
+ bitFlipStem(2), // diverge at bit 2 (prefix "00")
+ bitFlipStem(3), // diverge at bit 3 (prefix "000")
+ }
+ values := []common.Hash{oneKey, twoKey, threeKey, fourKey, ffKey}
+
+ tr := &BinaryTrie{
+ store: newNodeStore(),
+ tracer: trie.NewPrevalueTracer(),
+ groupDepth: groupDepth,
+ }
+ for i, stem := range stems {
+ if err := tr.store.Insert(stem, values[i][:], nil); err != nil {
+ t.Fatalf("Insert stem %d: %v", i, err)
+ }
+ }
+
+ before := tr.Hash()
+ _, ns := tr.Commit(false)
+ rootBlob, ok := ns.Nodes[""]
+ if !ok {
+ t.Fatalf("no root blob in NodeSet")
+ }
+ readBack, err := DeserializeAndHash(rootBlob.Blob, 0)
+ if err != nil {
+ t.Fatalf("DeserializeAndHash: %v", err)
+ }
+ if before != readBack {
+ t.Fatalf("hash mismatch: tr.Hash()=%x DeserializeAndHash(rootBlob)=%x", before, readBack)
+ }
+
+ // Reload the root blob into a fresh store and confirm structure.
+ fresh := newNodeStore()
+ ref, err := fresh.deserializeNodeWithHash(rootBlob.Blob, 0, before)
+ if err != nil {
+ t.Fatalf("deserializeNodeWithHash: %v", err)
+ }
+ if ref.Kind() != kindInternal {
+ t.Fatalf("expected root to be Internal, got kind %d", ref.Kind())
+ }
+ // Spot-check: the reload-tree's root hash equals the commit-time hash.
+ if got := fresh.computeHash(ref); got != before {
+ t.Fatalf("reload root hash mismatch: got %x, want %x", got, before)
+ }
+}
+
+// TestDecodeRejectsNonCanonicalPosition hand-crafts a blob where the bitmap
+// position has nonzero trailing bits given its depth offset. Two
+// implementations must produce byte-identical blobs for the same logical
+// content, so a non-canonical position is unambiguously an invalid blob.
+func TestDecodeRejectsNonCanonicalPosition(t *testing.T) {
+ // groupDepth=5, bitmap size = 4 bytes. Set bit at position 5 (binary
+ // 00101) and declare depthOffset=2. Top 2 bits of 00101 are 00 (path
+ // "00"), the trailing 3 bits should be zero — they're 101 here, so the
+ // reader must reject.
+ blob := []byte{nodeTypeInternal, 5}
+ // bitmap[0] = bit at position 5 → 1 << (7-5) = 0x04
+ blob = append(blob, 0x04, 0x00, 0x00, 0x00)
+ // depths[0] = 2, packed as (2-1)=1 in 3 bits MSB-first → 0b0010_0000 = 0x20
+ blob = append(blob, 0x20)
+ // hashes[0] = 32 zero bytes
+ blob = append(blob, make([]byte, HashSize)...)
+
+ s := newNodeStore()
+ _, err := s.deserializeNode(blob, 0)
+ if err == nil {
+ t.Fatal("expected non-canonical position error, got nil")
+ }
+ if err.Error() != "non-canonical bitmap position" {
+ t.Errorf("expected 'non-canonical bitmap position', got %q", err.Error())
+ }
+}
+
+// TestDecodeRejectsInvalidDepthOffset covers depthOffset>groupDepth (the entry
+// would live below the group's bottom layer, impossible by construction). The
+// old depthOffset=0 and depthOffset>MaxGroupDepth cases are gone: the 3-bit
+// field stores (offset-1) ∈ [0,7], so offset 0 and offset 9 are unrepresentable
+// and can no longer be hand-crafted into a blob. Only offset>groupDepth with
+// groupDepth= 31*8 {
return nil, errors.New("node too deep")
}
- path := make([]byte, 0, depth+1)
- for i := range depth + 1 {
- bit := key[i/8] >> (7 - (i % 8)) & 1
- path = append(path, bit)
- }
- return path, nil
+ path := new(BitArray).SetBytes(uint8(depth+1), key)
+ return path.KeyBytes(), nil
}
// Invariant: dirty=false implies mustRecompute=false. Every mutation that
diff --git a/trie/bintrie/internal_node_test.go b/trie/bintrie/internal_node_test.go
index 4d8da8af37..2f7ade14b0 100644
--- a/trie/bintrie/internal_node_test.go
+++ b/trie/bintrie/internal_node_test.go
@@ -283,14 +283,13 @@ func TestInternalNodeCollectNodes(t *testing.T) {
t.Fatal(err)
}
- var collectedPaths [][]byte
- flushFn := func(path []byte, hash common.Hash, serialized []byte) {
- pathCopy := make([]byte, len(path))
- copy(pathCopy, path)
- collectedPaths = append(collectedPaths, pathCopy)
+ var collectedPaths []BitArray
+ flushFn := func(path BitArray, hash common.Hash, serialized []byte) {
+ collectedPaths = append(collectedPaths, path)
}
- s.collectNodes(s.root, []byte{1}, flushFn, 8)
+ initialPath := NewBitArray(1, 1)
+ s.collectNodes(s.root, initialPath, flushFn, 8)
// Should have collected 3 nodes: left stem, right stem, and the internal node itself
if len(collectedPaths) != 3 {
diff --git a/trie/bintrie/iterator.go b/trie/bintrie/iterator.go
index a920f91378..e678ee310d 100644
--- a/trie/bintrie/iterator.go
+++ b/trie/bintrie/iterator.go
@@ -188,20 +188,20 @@ func (it *binaryNodeIterator) Parent() common.Hash {
return it.store.computeHash(it.stack[len(it.stack)-2].Node)
}
-// Path returns the bit-path to the current node.
+// Path returns the bit-packed path to the current node.
// Callers must not retain references to the returned slice after calling Next.
func (it *binaryNodeIterator) Path() []byte {
if it.Leaf() {
return it.LeafKey()
}
- var path []byte
+ var path BitArray
for i, state := range it.stack {
if i >= len(it.stack)-1 {
break
}
- path = append(path, byte(state.Index))
+ path.AppendBit(&path, uint8(state.Index))
}
- return path
+ return path.KeyBytes()
}
func (it *binaryNodeIterator) NodeBlob() []byte {
diff --git a/trie/bintrie/node_store.go b/trie/bintrie/node_store.go
index 8a35f06ee1..06e49d1d72 100644
--- a/trie/bintrie/node_store.go
+++ b/trie/bintrie/node_store.go
@@ -41,10 +41,15 @@ type nodeStore struct {
// stem-split keeps the old stem at a deeper position), so they don't
// have free lists.
freeHashed []uint32
+
+ // orphans holds on-disk paths whose committed blob has been abandoned by a
+ // stem depth-promotion since the last commit. Commit emits a deletion for
+ // each (unless a freshly flushed node reoccupies the path), then clears it.
+ orphans map[string]struct{}
}
func newNodeStore() *nodeStore {
- return &nodeStore{root: emptyRef}
+ return &nodeStore{root: emptyRef, orphans: make(map[string]struct{})}
}
func (s *nodeStore) allocInternal() uint32 {
@@ -179,6 +184,10 @@ func (s *nodeStore) Copy() *nodeStore {
ns.freeHashed = make([]uint32, len(s.freeHashed))
copy(ns.freeHashed, s.freeHashed)
}
+ ns.orphans = make(map[string]struct{}, len(s.orphans))
+ for path := range s.orphans {
+ ns.orphans[path] = struct{}{}
+ }
return ns
}
diff --git a/trie/bintrie/stem_node_test.go b/trie/bintrie/stem_node_test.go
index ae6b57ab34..8c563fce9d 100644
--- a/trie/bintrie/stem_node_test.go
+++ b/trie/bintrie/stem_node_test.go
@@ -313,14 +313,13 @@ func TestStemNodeCollectNodes(t *testing.T) {
t.Fatal(err)
}
- var collectedPaths [][]byte
- flushFn := func(path []byte, hash common.Hash, serialized []byte) {
- pathCopy := make([]byte, len(path))
- copy(pathCopy, path)
- collectedPaths = append(collectedPaths, pathCopy)
+ var collectedPaths []BitArray
+ flushFn := func(path BitArray, hash common.Hash, serialized []byte) {
+ collectedPaths = append(collectedPaths, path)
}
- s.collectNodes(s.root, []byte{0, 1, 0}, flushFn, 8)
+ initialPath := NewBitArray(3, 0b010)
+ s.collectNodes(s.root, initialPath, flushFn, 8)
// Should have collected one node (itself)
if len(collectedPaths) != 1 {
@@ -328,7 +327,8 @@ func TestStemNodeCollectNodes(t *testing.T) {
}
// Check the path
- if !bytes.Equal(collectedPaths[0], []byte{0, 1, 0}) {
- t.Errorf("Path mismatch: expected [0, 1, 0], got %v", collectedPaths[0])
+ expectedPath := NewBitArray(3, 0b010)
+ if !collectedPaths[0].Equal(&expectedPath) {
+ t.Errorf("Path mismatch: expected %v, got %v", expectedPath, collectedPaths[0])
}
}
diff --git a/trie/bintrie/store_commit.go b/trie/bintrie/store_commit.go
index b14bffbc6c..03710326b7 100644
--- a/trie/bintrie/store_commit.go
+++ b/trie/bintrie/store_commit.go
@@ -27,7 +27,7 @@ import (
"github.com/ethereum/go-ethereum/common"
)
-type nodeFlushFn func(path []byte, hash common.Hash, serialized []byte)
+type nodeFlushFn func(path BitArray, hash common.Hash, serialized []byte)
func (s *nodeStore) Hash() common.Hash {
return s.computeHash(s.root)
@@ -111,7 +111,7 @@ func (s *nodeStore) hashInternal(idx uint32) common.Hash {
// It traverses up to `remainingDepth` levels, storing hashes of bottom-layer children.
// position tracks the current index (0 to 2^groupDepth - 1) for bitmap placement.
// hashes collects the hashes of present children, bitmap tracks which positions are present.
-func (s *nodeStore) serializeSubtree(ref nodeRef, remainingDepth int, position int, absoluteDepth int, bitmap []byte, hashes *[]common.Hash) {
+func (s *nodeStore) serializeSubtree(ref nodeRef, remainingDepth int, position int, groupDepth int, bitmap []byte, hashes *[]common.Hash, depths *[]uint8) {
if remainingDepth == 0 {
// Bottom layer: store hash if not empty
switch ref.Kind() {
@@ -122,6 +122,7 @@ func (s *nodeStore) serializeSubtree(ref nodeRef, remainingDepth int, position i
// StemNode, HashedNode, or InternalNode at boundary: store hash
bitmap[position/8] |= 1 << (7 - (position % 8))
*hashes = append(*hashes, s.computeHash(ref))
+ *depths = append(*depths, uint8(groupDepth))
}
return
}
@@ -130,57 +131,81 @@ func (s *nodeStore) serializeSubtree(ref nodeRef, remainingDepth int, position i
case kindInternal:
leftPos := position * 2
rightPos := position*2 + 1
- s.serializeSubtree(s.getInternal(ref.Index()).left, remainingDepth-1, leftPos, absoluteDepth+1, bitmap, hashes)
- s.serializeSubtree(s.getInternal(ref.Index()).right, remainingDepth-1, rightPos, absoluteDepth+1, bitmap, hashes)
+ s.serializeSubtree(s.getInternal(ref.Index()).left, remainingDepth-1, leftPos, groupDepth, bitmap, hashes, depths)
+ s.serializeSubtree(s.getInternal(ref.Index()).right, remainingDepth-1, rightPos, groupDepth, bitmap, hashes, depths)
case kindEmpty:
return
default:
// StemNode or HashedNode encountered before reaching the group's bottom
// layer. Compute the leaf bitmap position where this node's hash will
// be stored.
- leafPos := position
- switch ref.Kind() {
- case kindStem:
- sn := s.getStem(ref.Index())
- // Extend position using the stem's key bits so that
- // GetValuesAtStem traversal (which follows key bits) finds the hash.
- for d := 0; d < remainingDepth; d++ {
- bit := sn.Stem[(absoluteDepth+d)/8] >> (7 - ((absoluteDepth + d) % 8)) & 1
- leafPos = leafPos*2 + int(bit)
- }
- default:
- // HashedNode or unknown: extend all-left (no key bits available).
- // This matches the all-zero path that resolveNode would follow.
- leafPos = position << remainingDepth
- }
- bitmap[leafPos/8] |= 1 << (7 - (leafPos % 8))
+ bitmapPos := position << remainingDepth
+ bitmap[bitmapPos/8] |= 1 << (7 - (bitmapPos % 8))
*hashes = append(*hashes, s.computeHash(ref))
+ *depths = append(*depths, uint8(groupDepth-remainingDepth))
}
}
+// depthBits is the number of bits used to encode one depth offset.
+const depthBits = 3
+
+// packedDepthsLen returns the byte length of k packed depth entries
+func packedDepthsLen(k int) int {
+ return (k*depthBits + 7) / 8
+}
+
+// writeDepth writes a depth entry at idx into the buf, MSB-first.
+func writeDepth(buf []byte, idx int, v uint8) {
+ pos := idx * depthBits
+ for i := range depthBits {
+ bit := (v >> (depthBits - 1 - i)) & 1
+ p := pos + i
+ buf[p>>3] |= bit << (7 - (p & 7))
+ }
+}
+
+// readDepth reads a depth for entry idx from buf.
+func readDepth(buf []byte, idx int) uint8 {
+ pos := idx * depthBits
+ var v uint8
+ for i := range depthBits {
+ p := pos + i
+ bit := (buf[p>>3] >> (7 - (p & 7))) & 1
+ v = v<<1 | bit
+ }
+ return v
+}
+
// SerializeNode serializes a node into the flat on-disk format.
func (s *nodeStore) serializeNode(ref nodeRef, groupDepth int) []byte {
switch ref.Kind() {
case kindInternal:
- // InternalNode group: 1 byte type + 1 byte group depth + variable bitmap + N×32 byte hashes
+ // InternalNode group format:
+ // [type(1)] [groupDepth(1)] [bitmap (2^groupDepth bits)] [depths(3 bits × K, padded)] [hashes(32B × K)]
bitmapSize := bitmapSizeForDepth(groupDepth)
bitmap := make([]byte, bitmapSize)
var hashes []common.Hash
+ var depths []uint8
- node := s.getInternal(ref.Index())
- s.serializeSubtree(ref, groupDepth, 0, int(node.depth), bitmap, &hashes)
+ s.serializeSubtree(ref, groupDepth, 0, groupDepth, bitmap, &hashes, &depths)
// Build serialized output
- serializedLen := NodeTypeBytes + 1 + bitmapSize + len(hashes)*HashSize
+ k := len(hashes)
+ depthsLen := packedDepthsLen(k)
+ serializedLen := NodeTypeBytes + 1 + bitmapSize + depthsLen + k*HashSize
serialized := make([]byte, serializedLen)
serialized[0] = nodeTypeInternal
- serialized[1] = byte(groupDepth) // group depth => bitmap size for a sparse group
+ serialized[1] = byte(groupDepth)
copy(serialized[2:2+bitmapSize], bitmap)
- offset := NodeTypeBytes + 1 + bitmapSize
- for _, h := range hashes {
- copy(serialized[offset:offset+HashSize], h.Bytes())
- offset += HashSize
+ depthsOff := NodeTypeBytes + 1 + bitmapSize
+ for i, d := range depths {
+ writeDepth(serialized[depthsOff:depthsOff+depthsLen], i, d-1)
+ }
+
+ hashesOff := depthsOff + depthsLen
+ for i, h := range hashes {
+ copy(serialized[hashesOff+i*HashSize:hashesOff+(i+1)*HashSize], h.Bytes())
}
return serialized
@@ -229,56 +254,90 @@ func (s *nodeStore) deserializeNodeWithHash(serialized []byte, depth int, hn com
}
// deserializeSubtree reconstructs an InternalNode subtree from grouped serialization.
-// remainingDepth is how many more levels to build, position is current index in the bitmap,
-// nodeDepth is the actual trie depth for the node being created.
-// hashIdx tracks the current position in the hash data (incremented as hashes are consumed).
-func (s *nodeStore) deserializeSubtree(hn common.Hash, remainingDepth int, position int, nodeDepth int, bitmap []byte, hashData []byte, hashIdx *int, mustRecompute bool, dirty bool) (nodeRef, error) {
- if remainingDepth == 0 {
- // Bottom layer: check bitmap and return HashedNode or Empty
- if bitmap[position/8]>>(7-(position%8))&1 == 1 {
- if len(hashData) < (*hashIdx+1)*HashSize {
- return emptyRef, errInvalidSerializedLength
- }
- hash := common.BytesToHash(hashData[*hashIdx*HashSize : (*hashIdx+1)*HashSize])
- *hashIdx++
- return s.newHashedRef(hash), nil
- }
+func (s *nodeStore) deserializeSubtree(hn common.Hash, groupDepth int, nodeDepth int, bitmap []byte, depths []byte, hashData []byte, mustRecompute bool, dirty bool) (nodeRef, error) {
+ if len(hashData)%HashSize != 0 {
+ return emptyRef, errInvalidSerializedLength
+ }
+ k := len(hashData) / HashSize
+ if len(depths) != packedDepthsLen(k) {
+ return emptyRef, errInvalidSerializedLength
+ }
+ if k == 0 {
return emptyRef, nil
}
- // Check if this entire subtree is empty by examining all relevant bitmap bits
- leftPos := position * 2
- rightPos := position*2 + 1
-
- // note that the parent might not need root computations, but the children
- // do, because their hash isn't saved. Hence `mustRecompute` is set to `true`.
- left, err := s.deserializeSubtree(common.Hash{}, remainingDepth-1, leftPos, nodeDepth+1, bitmap, hashData, hashIdx, true, dirty)
- if err != nil {
- return emptyRef, err
- }
- right, err := s.deserializeSubtree(common.Hash{}, remainingDepth-1, rightPos, nodeDepth+1, bitmap, hashData, hashIdx, true, dirty)
- if err != nil {
- return emptyRef, err
- }
-
- // If both children are empty, return Empty
- if left.IsEmpty() && right.IsEmpty() {
- return emptyRef, nil
- }
-
- ref := s.newInternalRef(nodeDepth)
- node := s.getInternal(ref.Index())
- node.left = left
- node.right = right
- node.mustRecompute = mustRecompute
+ rootRef := s.newInternalRef(nodeDepth)
+ rootNode := s.getInternal(rootRef.Index())
+ rootNode.mustRecompute = mustRecompute
if !mustRecompute {
- // mustRecompute will only be false for the root of the subtree,
- // for which we already know the hash.
- node.hash = hn
- node.mustRecompute = false
+ rootNode.hash = hn
}
- node.dirty = dirty
- return ref, nil
+ rootNode.dirty = dirty
+
+ bitmapBits := 1 << groupDepth
+ entryIdx := 0
+ for bit := 0; bit < bitmapBits; bit++ {
+ if bitmap[bit/8]>>(7-(bit%8))&1 == 0 {
+ continue
+ }
+ depthOffset := int(readDepth(depths, entryIdx)) + 1
+ if depthOffset > groupDepth {
+ return emptyRef, errors.New("invalid depth offset")
+ }
+ // Canonical-encoding check: trailing position bits must be zero.
+ mask := (1 << (groupDepth - depthOffset)) - 1
+ if bit&mask != 0 {
+ return emptyRef, errors.New("non-canonical bitmap position")
+ }
+ var hash common.Hash
+ copy(hash[:], hashData[entryIdx*HashSize:(entryIdx+1)*HashSize])
+ if err := s.attachInGroup(rootRef, nodeDepth, groupDepth, depthOffset, bit, hash, dirty); err != nil {
+ return emptyRef, err
+ }
+ entryIdx++
+ }
+ return rootRef, nil
+}
+
+func (s *nodeStore) attachInGroup(rootRef nodeRef, rootDepth, groupDepth, depthOffset, bitmapPos int, hash common.Hash, dirty bool) error {
+ cur := rootRef
+ for level := 0; level < depthOffset-1; level++ {
+ bit := (bitmapPos >> (groupDepth - 1 - level)) & 1
+ node := s.getInternal(cur.Index())
+ childRef := node.left
+ if bit == 1 {
+ childRef = node.right
+ }
+ if childRef.IsEmpty() {
+ newRef := s.newInternalRef(rootDepth + level + 1)
+ s.getInternal(newRef.Index()).dirty = dirty
+ if bit == 0 {
+ node.left = newRef
+ } else {
+ node.right = newRef
+ }
+ cur = newRef
+ continue
+ }
+ if childRef.Kind() != kindInternal {
+ return errors.New("overlapping entries in group blob")
+ }
+ cur = childRef
+ }
+ leafBit := (bitmapPos >> (groupDepth - depthOffset)) & 1
+ node := s.getInternal(cur.Index())
+ if leafBit == 0 {
+ if !node.left.IsEmpty() {
+ return errors.New("overlapping entries in group blob")
+ }
+ node.left = s.newHashedRef(hash)
+ } else {
+ if !node.right.IsEmpty() {
+ return errors.New("overlapping entries in group blob")
+ }
+ node.right = s.newHashedRef(hash)
+ }
+ return nil
}
func (s *nodeStore) decodeNode(serialized []byte, depth int, hn common.Hash, mustRecompute, dirty bool) (nodeRef, error) {
@@ -288,7 +347,9 @@ func (s *nodeStore) decodeNode(serialized []byte, depth int, hn common.Hash, mus
switch serialized[0] {
case nodeTypeInternal:
- // Grouped format: 1 byte type + 1 byte group depth + variable bitmap + N×32 byte hashes
+ // Grouped format:
+ // [type(1)] [groupDepth(1)] [bitmap (2^groupDepth bits, padded to bitmapSize bytes)]
+ // [depthOffsets (3 bits × K, padded to bytes)] [hashes (32B × K)]
if len(serialized) < NodeTypeBytes+1 {
return emptyRef, errInvalidSerializedLength
}
@@ -301,10 +362,38 @@ func (s *nodeStore) decodeNode(serialized []byte, depth int, hn common.Hash, mus
return 0, errInvalidSerializedLength
}
bitmap := serialized[2 : 2+bitmapSize]
- hashData := serialized[2+bitmapSize:]
- hashIdx := 0
- return s.deserializeSubtree(hn, groupDepth, 0, depth, bitmap, hashData, &hashIdx, mustRecompute, dirty)
+ bitmapBits := 1 << groupDepth
+ if bitmapBits < 8 {
+ padMask := byte(0xFF) >> bitmapBits
+ if bitmap[0]&padMask != 0 {
+ return emptyRef, errors.New("non-canonical bitmap padding")
+ }
+ }
+
+ k := 0
+ for _, b := range bitmap {
+ k += bits.OnesCount8(b)
+ }
+ depthsLen := packedDepthsLen(k)
+ expectedLen := NodeTypeBytes + 1 + bitmapSize + depthsLen + k*HashSize
+ if len(serialized) != expectedLen {
+ return emptyRef, errInvalidSerializedLength
+ }
+ depthsOff := NodeTypeBytes + 1 + bitmapSize
+ depths := serialized[depthsOff : depthsOff+depthsLen]
+ hashData := serialized[depthsOff+depthsLen : depthsOff+depthsLen+k*HashSize]
+
+ // Canonical-encoding check: the unused low bits of the last packed
+ // depth byte must be zero.
+ if usedBits := k * depthBits; usedBits%8 != 0 {
+ padMask := byte(0xFF) >> (usedBits % 8)
+ if depths[depthsLen-1]&padMask != 0 {
+ return emptyRef, errors.New("non-canonical depth padding")
+ }
+ }
+
+ return s.deserializeSubtree(hn, groupDepth, depth, bitmap, depths, hashData, mustRecompute, dirty)
case nodeTypeStem:
if len(serialized) < NodeTypeBytes+StemSize+StemBitmapSize {
@@ -340,7 +429,10 @@ func (s *nodeStore) decodeNode(serialized []byte, depth int, hn common.Hash, mus
// CollectNodes flushes every node that needs flushing via flushfn in post-order.
// Invariant: any ancestor of a node that needs flushing is itself marked, so a
// clean root means the whole subtree is clean.
-func (s *nodeStore) collectNodes(ref nodeRef, path []byte, flushfn nodeFlushFn, groupDepth int) {
+//
+// BitArray is passed by value (33 bytes) to keep child paths on the stack.
+// Passing by pointer causes escape to heap per recursive call.
+func (s *nodeStore) collectNodes(ref nodeRef, path BitArray, flushfn nodeFlushFn, groupDepth int) {
switch ref.Kind() {
case kindInternal:
node := s.getInternal(ref.Index())
@@ -375,77 +467,51 @@ func (s *nodeStore) collectNodes(ref nodeRef, path []byte, flushfn nodeFlushFn,
// collectChildGroups traverses within a group to find and collect nodes in the next group.
// remainingLevels is how many more levels below the current node until we reach the group boundary.
// When remainingLevels=0, the current node's children are at the next group boundary.
-func (s *nodeStore) collectChildGroups(node *InternalNode, path []byte, flushfn nodeFlushFn, groupDepth int, remainingLevels int) error {
+func (s *nodeStore) collectChildGroups(node *InternalNode, path BitArray, flushfn nodeFlushFn, groupDepth int, remainingLevels int) error {
if remainingLevels == 0 {
// Current node is at depth (groupBoundary - 1), its children are at the next group boundary
if !node.left.IsEmpty() {
- s.collectNodes(node.left, appendBit(path, 0), flushfn, groupDepth)
+ leftPath := path
+ leftPath.AppendBit(&leftPath, 0)
+ s.collectNodes(node.left, leftPath, flushfn, groupDepth)
}
if !node.right.IsEmpty() {
- s.collectNodes(node.right, appendBit(path, 1), flushfn, groupDepth)
+ rightPath := path
+ rightPath.AppendBit(&rightPath, 1)
+ s.collectNodes(node.right, rightPath, flushfn, groupDepth)
}
return nil
}
if !node.left.IsEmpty() {
+ leftPath := path
+ leftPath.AppendBit(&leftPath, 0)
switch node.left.Kind() {
case kindInternal:
n := s.getInternal(node.left.Index())
- if err := s.collectChildGroups(n, appendBit(path, 0), flushfn, groupDepth, remainingLevels-1); err != nil {
+ if err := s.collectChildGroups(n, leftPath, flushfn, groupDepth, remainingLevels-1); err != nil {
return err
}
default:
- extPath := s.extendPathToGroupLeaf(appendBit(path, 0), node.left, remainingLevels)
- s.collectNodes(node.left, extPath, flushfn, groupDepth)
+ s.collectNodes(node.left, leftPath, flushfn, groupDepth)
}
}
if !node.right.IsEmpty() {
+ rightPath := path
+ rightPath.AppendBit(&rightPath, 1)
switch node.right.Kind() {
case kindInternal:
n := s.getInternal(node.right.Index())
- if err := s.collectChildGroups(n, appendBit(path, 1), flushfn, groupDepth, remainingLevels-1); err != nil {
+ if err := s.collectChildGroups(n, rightPath, flushfn, groupDepth, remainingLevels-1); err != nil {
return err
}
default:
- extPath := s.extendPathToGroupLeaf(appendBit(path, 1), node.right, remainingLevels)
- s.collectNodes(node.right, extPath, flushfn, groupDepth)
+ s.collectNodes(node.right, rightPath, flushfn, groupDepth)
}
}
return nil
}
-// extendPathToGroupLeaf extends a storage path to the group's leaf boundary,
-// matching the projection done by serializeSubtree. For StemNodes, the path
-// is extended using the stem's key bits (same as serializeSubtree). For other
-// node types, the path is extended with all-zero (left) bits.
-func (s *nodeStore) extendPathToGroupLeaf(path []byte, node nodeRef, remainingLevels int) []byte {
- if remainingLevels <= 0 {
- return path
- }
- if node.Kind() == kindStem {
- sn := s.getStem(node.Index())
- for _ = range remainingLevels {
- bit := sn.Stem[len(path)/8] >> (7 - (len(path) % 8)) & 1
- path = appendBit(path, bit)
- }
- } else {
- // HashedNode or other: all-left extension (matches serializeSubtree's
- // position << remainingDepth behavior).
- for _ = range remainingLevels {
- path = appendBit(path, 0)
- }
- }
- return path
-}
-
-// appendBit appends a bit to a path, returning a new slice
-func appendBit(path []byte, bit byte) []byte {
- var p [256]byte
- copy(p[:], path)
- result := p[:len(path)]
- return append(result, bit)
-}
-
func (s *nodeStore) toDot(ref nodeRef, parent, path string) string {
switch ref.Kind() {
case kindInternal:
diff --git a/trie/bintrie/store_ops.go b/trie/bintrie/store_ops.go
index 9a73c8bd64..d3765bd651 100644
--- a/trie/bintrie/store_ops.go
+++ b/trie/bintrie/store_ops.go
@@ -262,7 +262,15 @@ func (s *nodeStore) splitStemValuesInsert(existingRef nodeRef, newStem []byte, v
bitStem := existing.Stem[existing.depth/8] >> (7 - (existing.depth % 8)) & 1
nRef := s.newInternalRef(int(existing.depth))
nNode := s.getInternal(nRef.Index())
+ if !existing.dirty {
+ var buf [33]byte
+ oldPath := new(BitArray).SetBytes(existing.depth, existing.Stem[:]).PutKeyBytes(buf[:])
+ s.orphans[string(oldPath)] = struct{}{}
+ }
existing.depth++
+ // The existing stem's on-disk path lengthens by one bit, which means
+ // the stem must be re-flushed at the longer new path.
+ existing.dirty = true
bitKey := newStem[nNode.depth/8] >> (7 - (nNode.depth % 8)) & 1
if bitKey == bitStem {
diff --git a/trie/bintrie/trie.go b/trie/bintrie/trie.go
index 0d0c0e0e70..7b05e68df9 100644
--- a/trie/bintrie/trie.go
+++ b/trie/bintrie/trie.go
@@ -347,12 +347,35 @@ func (t *BinaryTrie) Hash() common.Hash {
func (t *BinaryTrie) Commit(_ bool) (common.Hash, *trienode.NodeSet) {
nodeset := trienode.NewNodeSet(common.Hash{})
- // Pre-size the path buffer: collectNodes reuses it in-place via
- // append/truncate; 32 covers typical binary-trie depth without regrowth.
- pathBuf := make([]byte, 0, 32)
- t.store.collectNodes(t.store.root, pathBuf, func(path []byte, hash common.Hash, serialized []byte) {
- nodeset.AddNode(path, trienode.NewNodeWithPrev(hash, serialized, t.tracer.Get(path)))
+ // Stem depth-promotion abandons a committed blob and is the only source of
+ // orphans. When none are pending (the common case) we skip tracking flushed
+ // paths entirely, keeping Commit allocation-free beyond the node set.
+ var added map[string]struct{}
+ if len(t.store.orphans) > 0 {
+ added = make(map[string]struct{})
+ }
+ var rootPath BitArray
+ t.store.collectNodes(t.store.root, rootPath, func(path BitArray, hash common.Hash, serialized []byte) {
+ var buf [33]byte
+ pathBytes := path.PutKeyBytes(buf[:])
+ if added != nil {
+ added[string(pathBytes)] = struct{}{}
+ }
+ nodeset.AddNode(pathBytes, trienode.NewNodeWithPrev(hash, serialized, t.tracer.Get(pathBytes)))
}, t.groupDepth)
+
+ // Delete blobs abandoned by stem depth-promotion, unless a freshly flushed
+ // node already reoccupies the path (the group-boundary case).
+ if len(t.store.orphans) > 0 {
+ for path := range t.store.orphans {
+ if _, ok := added[path]; ok {
+ continue
+ }
+ nodeset.AddNode([]byte(path), trienode.NewDeletedWithPrev(t.tracer.Get([]byte(path))))
+ }
+ t.store.orphans = make(map[string]struct{})
+ }
+
return t.Hash(), nodeset
}