From d464b9e485ab11aef9c8f7593e379a758d74b894 Mon Sep 17 00:00:00 2001 From: weiihann Date: Mon, 20 Apr 2026 21:09:21 +0800 Subject: [PATCH 1/4] trie/bintrie: use bitarray for path encoding --- trie/bintrie/binary_node_test.go | 8 +- trie/bintrie/bitarray.go | 584 +++++++++++++++ trie/bintrie/bitarray_test.go | 1078 ++++++++++++++++++++++++++++ trie/bintrie/internal_node.go | 10 +- trie/bintrie/internal_node_test.go | 11 +- trie/bintrie/iterator.go | 8 +- trie/bintrie/stem_node_test.go | 16 +- trie/bintrie/store_commit.go | 29 +- trie/bintrie/trie.go | 10 +- 9 files changed, 1708 insertions(+), 46 deletions(-) create mode 100644 trie/bintrie/bitarray.go create mode 100644 trie/bintrie/bitarray_test.go diff --git a/trie/bintrie/binary_node_test.go b/trie/bintrie/binary_node_test.go index 857060a0c0..ff3b8225bb 100644 --- a/trie/bintrie/binary_node_test.go +++ b/trie/bintrie/binary_node_test.go @@ -245,28 +245,28 @@ func TestKeyToPath(t *testing.T) { name: "depth 0", depth: 0, key: []byte{0x80}, // 10000000 in binary - expected: []byte{1}, + expected: []byte{1}, // 1 bit packed: MSB=1 → 0x01 wantErr: false, }, { name: "depth 7", depth: 7, key: []byte{0xFF}, // 11111111 in binary - expected: []byte{1, 1, 1, 1, 1, 1, 1, 1}, + expected: []byte{0xFF}, // 8 bits packed into 1 byte wantErr: false, }, { name: "depth crossing byte boundary", depth: 10, key: []byte{0xFF, 0x00}, // 11111111 00000000 in binary - expected: []byte{1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0}, + expected: []byte{0x07, 0xF8}, // 11 bits = 11111111000 → 0x07F8 wantErr: false, }, { name: "max valid depth", depth: StemSize*8 - 1, key: make([]byte, HashSize), - expected: make([]byte, StemSize*8), + expected: make([]byte, StemSize), // 248 bits of zeros → 31 packed bytes wantErr: false, }, { diff --git a/trie/bintrie/bitarray.go b/trie/bintrie/bitarray.go new file mode 100644 index 0000000000..5b39a629d3 --- /dev/null +++ b/trie/bintrie/bitarray.go @@ -0,0 +1,584 @@ +// Copyright 2026 The go-ethereum Authors +// This file is part of go-ethereum. +// +// go-ethereum is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// go-ethereum is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with go-ethereum. If not, see . +package bintrie + +import ( + "encoding/binary" + "encoding/hex" + "fmt" + "math" +) + +const ( + maxUint64 = uint64(math.MaxUint64) // 0xFFFFFFFFFFFFFFFF + maxUint8 = uint8(math.MaxUint8) +) + +var emptyBitArray = new(BitArray) + +// BitArray represents a bit array with length representing the number of used bits. +// It uses a little endian representation to do bitwise operations of the words efficiently. +// For example, if len is 10, it means that the 2^9, 2^8, ..., 2^0 bits are used. +// The max length is 255 bits (uint8), because our use case only need up to 248 bits for a given trie key. +// Although words can be used to represent 256 bits, we don't want to add an additional byte for the length. +type BitArray struct { + len uint8 // number of used bits + words [4]uint64 // little endian (i.e. words[0] is the least significant) +} + +// NewBitArray creates a new bit array with the given length and value. +func NewBitArray(length uint8, val uint64) BitArray { + var b BitArray + b.SetUint64(length, val) + return b +} + +func (b *BitArray) Len() uint8 { + return b.len +} + +// Bytes returns the bytes representation of the bit array in big endian format +func (b *BitArray) Bytes() [32]byte { + var res [32]byte + + binary.BigEndian.PutUint64(res[0:8], b.words[3]) + binary.BigEndian.PutUint64(res[8:16], b.words[2]) + binary.BigEndian.PutUint64(res[16:24], b.words[1]) + binary.BigEndian.PutUint64(res[24:32], b.words[0]) + + return res +} + +// Append sets the bit array to the concatenation of x and y and returns the bit array. +// For example: +// +// x = 000 (len=3) +// y = 111 (len=3) +// Append(x,y) = 000111 (len=6) +func (b *BitArray) Append(x, y *BitArray) *BitArray { + if x.len == 0 { + return b.Set(y) + } + if y.len == 0 { + return b.Set(x) + } + if x.len > maxUint8-y.len { + panic("error on bitarray append: result would exceed maximum length of 255 bits") + } + + // Shift left by y's length and OR with y + return b.lsh(x, y.len).or(b, y) +} + +// AppendBit sets the bit array to the concatenation of x and a single bit. +// Equivalent to Append(x, {bit}) but avoids allocating a temporary BitArray. +func (b *BitArray) AppendBit(x *BitArray, bit uint8) *BitArray { + if x.len == 0 { + return b.SetBit(bit) + } + b.lsh(x, 1) + b.words[0] |= uint64(bit & 1) + return b +} + +// MSBs sets the bit array to the most significant 'n' bits of x, that is position 0 to n (exclusive). +// If n >= x.len, the bit array is an exact copy of x. +// Think of this method as array[0:n] +// For example: +// +// x = 11001011 (len=8) +// MSBs(x, 4) = 1100 (len=4) +// MSBs(x, 10) = 11001011 (len=8, original x) +// MSBs(x, 0) = 0 (len=0) +func (b *BitArray) MSBs(x *BitArray, n uint8) *BitArray { + if n >= x.len { + return b.Set(x) + } + + return b.rsh(x, x.len-n) +} + +// Equal checks if two bit arrays are equal +func (b *BitArray) Equal(x *BitArray) bool { + if b == nil || x == nil { + panic("bit array is nil") + } + + return b.len == x.len && b.words == x.words +} + +// SetBytes interprets the data as the big-endian bytes, sets the bit array to that value and returns it. +// If the data is larger than 32 bytes, only the first 32 bytes are used. +func (b *BitArray) SetBytes(length uint8, data []byte) *BitArray { + switch l := len(data); l { + case 0: + b.clear() + case 1: + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, uint64(data[0]) + case 2: + _ = data[1] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, uint64(binary.BigEndian.Uint16(data[0:2])) + case 3: + _ = data[2] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, uint64(binary.BigEndian.Uint16(data[1:3]))|uint64(data[0])<<16 + case 4: + _ = data[3] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, uint64(binary.BigEndian.Uint32(data[0:4])) + case 5: + _ = data[4] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, bigEndianUint40(data[0:5]) + case 6: + _ = data[5] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, bigEndianUint48(data[0:6]) + case 7: + _ = data[6] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, bigEndianUint56(data[0:7]) + case 8: + _ = data[7] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, binary.BigEndian.Uint64(data[0:8]) + case 9: + _ = data[8] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, uint64(data[0]), binary.BigEndian.Uint64(data[1:9]) + case 10: + _ = data[9] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, uint64(binary.BigEndian.Uint16(data[0:2])), binary.BigEndian.Uint64(data[2:10]) + case 11: + _ = data[10] + b.words[3], b.words[2] = 0, 0 + b.words[1], b.words[0] = uint64(binary.BigEndian.Uint16(data[1:3]))|uint64(data[0])<<16, binary.BigEndian.Uint64(data[3:11]) + case 12: + _ = data[11] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, uint64(binary.BigEndian.Uint32(data[0:4])), binary.BigEndian.Uint64(data[4:12]) + case 13: + _ = data[12] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, bigEndianUint40(data[0:5]), binary.BigEndian.Uint64(data[5:13]) + case 14: + _ = data[13] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, bigEndianUint48(data[0:6]), binary.BigEndian.Uint64(data[6:14]) + case 15: + _ = data[14] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, bigEndianUint56(data[0:7]), binary.BigEndian.Uint64(data[7:15]) + case 16: + _ = data[15] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, binary.BigEndian.Uint64(data[0:8]), binary.BigEndian.Uint64(data[8:16]) + case 17: + _ = data[16] + b.words[3], b.words[2] = 0, uint64(data[0]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[1:9]), binary.BigEndian.Uint64(data[9:17]) + case 18: + _ = data[17] + b.words[3], b.words[2] = 0, uint64(binary.BigEndian.Uint16(data[0:2])) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[2:10]), binary.BigEndian.Uint64(data[10:18]) + case 19: + _ = data[18] + b.words[3], b.words[2] = 0, uint64(binary.BigEndian.Uint16(data[1:3]))|uint64(data[0])<<16 + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[3:11]), binary.BigEndian.Uint64(data[11:19]) + case 20: + _ = data[19] + b.words[3], b.words[2] = 0, uint64(binary.BigEndian.Uint32(data[0:4])) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[4:12]), binary.BigEndian.Uint64(data[12:20]) + case 21: + _ = data[20] + b.words[3], b.words[2] = 0, bigEndianUint40(data[0:5]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[5:13]), binary.BigEndian.Uint64(data[13:21]) + case 22: + _ = data[21] + b.words[3], b.words[2] = 0, bigEndianUint48(data[0:6]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[6:14]), binary.BigEndian.Uint64(data[14:22]) + case 23: + _ = data[22] + b.words[3], b.words[2] = 0, bigEndianUint56(data[0:7]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[7:15]), binary.BigEndian.Uint64(data[15:23]) + case 24: + _ = data[23] + b.words[3], b.words[2] = 0, binary.BigEndian.Uint64(data[0:8]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[8:16]), binary.BigEndian.Uint64(data[16:24]) + case 25: + _ = data[24] + b.words[3], b.words[2] = uint64(data[0]), binary.BigEndian.Uint64(data[1:9]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[9:17]), binary.BigEndian.Uint64(data[17:25]) + case 26: + _ = data[25] + b.words[3], b.words[2] = uint64(binary.BigEndian.Uint16(data[0:2])), binary.BigEndian.Uint64(data[2:10]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[10:18]), binary.BigEndian.Uint64(data[18:26]) + case 27: + _ = data[26] + b.words[3] = uint64(binary.BigEndian.Uint16(data[1:3])) | uint64(data[0])<<16 + b.words[2] = binary.BigEndian.Uint64(data[3:11]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[11:19]), binary.BigEndian.Uint64(data[19:27]) + case 28: + _ = data[27] + b.words[3], b.words[2] = uint64(binary.BigEndian.Uint32(data[0:4])), binary.BigEndian.Uint64(data[4:12]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[12:20]), binary.BigEndian.Uint64(data[20:28]) + case 29: + _ = data[28] + b.words[3], b.words[2] = bigEndianUint40(data[0:5]), binary.BigEndian.Uint64(data[5:13]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[13:21]), binary.BigEndian.Uint64(data[21:29]) + case 30: + _ = data[29] + b.words[3], b.words[2] = bigEndianUint48(data[0:6]), binary.BigEndian.Uint64(data[6:14]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[14:22]), binary.BigEndian.Uint64(data[22:30]) + case 31: + _ = data[30] + b.words[3], b.words[2] = bigEndianUint56(data[0:7]), binary.BigEndian.Uint64(data[7:15]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[15:23]), binary.BigEndian.Uint64(data[23:31]) + default: + b.setBytes32(data) + } + b.len = length + b.truncateToLength() + return b +} + +// SetUint64 sets the bit array to the uint64 representation of a bit array. +func (b *BitArray) SetUint64(length uint8, data uint64) *BitArray { + b.words[0] = data + b.words[1], b.words[2], b.words[3] = 0, 0, 0 + b.len = length + b.truncateToLength() + return b +} + +// SetBit sets the bit array to a single bit. +func (b *BitArray) SetBit(bit uint8) *BitArray { + b.len = 1 + b.words[0] = uint64(bit & 1) + b.words[1], b.words[2], b.words[3] = 0, 0, 0 + return b +} + +// Copy returns a deep copy of the bit array. +func (b *BitArray) Copy() BitArray { + var res BitArray + res.Set(b) + return res +} + +// String returns a string representation of the bit array. +// This is typically used for logging or debugging. +func (b *BitArray) String() string { + bt := b.Bytes() + return fmt.Sprintf("(%d) %s", b.len, hex.EncodeToString(bt[:])) +} + +// Bit returns the bit value at position n, where n = 0 is MSB. +// If n is out of bounds, returns 0. +func (b *BitArray) Bit(n uint8) uint8 { + if n >= b.Len() { + return 0 + } + + return b.bitFromLSB(b.Len() - n - 1) +} + +// Set sets the bit array to the same value as x. +func (b *BitArray) Set(x *BitArray) *BitArray { + b.len = x.len + b.words[0] = x.words[0] + b.words[1] = x.words[1] + b.words[2] = x.words[2] + b.words[3] = x.words[3] + return b +} + +// ActiveBytes returns a slice containing only the bytes that are actually used by the bit array, +// as specified by the length. The returned slice is in big-endian order. +// +// Example: +// +// len = 10, words = [0x3FF, 0, 0, 0] -> [0x03, 0xFF] +func (b *BitArray) ActiveBytes() []byte { + wordsBytes := b.Bytes() + return wordsBytes[32-b.byteCount():] +} + +// PutActiveBytes writes the active bytes into dst (which must be at least 32 bytes) +// and returns the populated sub-slice. No heap allocation occurs because the +// backing array is owned by the caller. +func (b *BitArray) PutActiveBytes(dst *[32]byte) []byte { + binary.BigEndian.PutUint64(dst[0:8], b.words[3]) + binary.BigEndian.PutUint64(dst[8:16], b.words[2]) + binary.BigEndian.PutUint64(dst[16:24], b.words[1]) + binary.BigEndian.PutUint64(dst[24:32], b.words[0]) + return dst[32-b.byteCount():] +} + +// bitFromLSB returns the bit value at position n, where n = 0 is LSB. +// If n is out of bounds, returns 0. +func (b *BitArray) bitFromLSB(n uint8) uint8 { + if n >= b.len { + return 0 + } + + if (b.words[n/64] & (1 << (n % 64))) != 0 { + return 1 + } + + return 0 +} + +// copyLsb sets the bit array to the least significant 'n' bits of x. +// n is counted from the least significant bit, starting at 0. +// If length >= x.len, the bit array is an exact copy of x. +// For example: +// +// x = 11001011 (len=8) +// copyLsb(x, 4) = 1011 (len=4) +// copyLsb(x, 10) = 11001011 (len=8, original x) +// copyLsb(x, 0) = 0 (len=0) +func (b *BitArray) copyLsb(x *BitArray, n uint8) *BitArray { + if n >= x.len { + return b.Set(x) + } + + b.len = n + + switch { + case n == 0: + b.words = [4]uint64{0, 0, 0, 0} + case n <= 64: + b.words[0] = x.words[0] & (maxUint64 >> (64 - n)) + b.words[1], b.words[2], b.words[3] = 0, 0, 0 + case n <= 128: + b.words[0] = x.words[0] + b.words[1] = x.words[1] & (maxUint64 >> (128 - n)) + b.words[2], b.words[3] = 0, 0 + case n <= 192: + b.words[0] = x.words[0] + b.words[1] = x.words[1] + b.words[2] = x.words[2] & (maxUint64 >> (192 - n)) + b.words[3] = 0 + default: + b.words[0] = x.words[0] + b.words[1] = x.words[1] + b.words[2] = x.words[2] + b.words[3] = x.words[3] & (maxUint64 >> (256 - uint16(n))) + } + + return b +} + +// lsb returns the least significant bits of `x` with `n` counted from the most significant bit, starting at 0. +// Think of this method as array[n:] +// For example: +// +// x = 11001011 (len=8) +// lsb(x, 1) = 1001011 (len=7) +// lsb(x, 10) = 0 (len=0) +// lsb(x, 0) = 11001011 (len=8, original x) +func (b *BitArray) lsb(x *BitArray, n uint8) *BitArray { + if n == 0 { + return b.Set(x) + } + + if n > x.Len() { + return b.clear() + } + + return b.copyLsb(x, x.Len()-n) +} + +// or sets the bit array to x | y and returns the bit array. +func (b *BitArray) or(x, y *BitArray) *BitArray { + b.words[0] = x.words[0] | y.words[0] + b.words[1] = x.words[1] | y.words[1] + b.words[2] = x.words[2] | y.words[2] + b.words[3] = x.words[3] | y.words[3] + b.len = x.len + return b +} + +// rsh sets the bit array to x >> n and returns the bit array. +func (b *BitArray) rsh(x *BitArray, n uint8) *BitArray { + if x.len == 0 { + return b.Set(x) + } + + if n >= x.len { + return b.clear() + } + + switch { + case n == 0: + return b.Set(x) + case n >= 192: + b.rsh192(x) + b.len = x.len - n + n -= 192 + b.words[0] >>= n + case n >= 128: + b.rsh128(x) + b.len = x.len - n + n -= 128 + b.words[0] = (b.words[0] >> n) | (b.words[1] << (64 - n)) + b.words[1] >>= n + case n >= 64: + b.rsh64(x) + b.len = x.len - n + n -= 64 + b.words[0] = (b.words[0] >> n) | (b.words[1] << (64 - n)) + b.words[1] = (b.words[1] >> n) | (b.words[2] << (64 - n)) + b.words[2] >>= n + default: + b.Set(x) + b.len -= n + b.words[0] = (b.words[0] >> n) | (b.words[1] << (64 - n)) + b.words[1] = (b.words[1] >> n) | (b.words[2] << (64 - n)) + b.words[2] = (b.words[2] >> n) | (b.words[3] << (64 - n)) + b.words[3] >>= n + } + + b.truncateToLength() + return b +} + +// lsh sets the bit array to x << n and returns the bit array. +func (b *BitArray) lsh(x *BitArray, n uint8) *BitArray { + if x.len == 0 || n == 0 { + return b.Set(x) + } + + // If the result will overflow, we set the length to the max length + // but we still shift `n` bits + if n > maxUint8-x.len { + b.len = maxUint8 + } else { + b.len = x.len + n + } + + switch { + case n >= 192: + b.lsh192(x) + n -= 192 + b.words[3] <<= n + case n >= 128: + b.lsh128(x) + n -= 128 + b.words[3] = (b.words[3] << n) | (b.words[2] >> (64 - n)) + b.words[2] <<= n + case n >= 64: + b.lsh64(x) + n -= 64 + b.words[3] = (b.words[3] << n) | (b.words[2] >> (64 - n)) + b.words[2] = (b.words[2] << n) | (b.words[1] >> (64 - n)) + b.words[1] <<= n + default: + b.words[3], b.words[2], b.words[1], b.words[0] = x.words[3], x.words[2], x.words[1], x.words[0] + b.words[3] = (b.words[3] << n) | (b.words[2] >> (64 - n)) + b.words[2] = (b.words[2] << n) | (b.words[1] >> (64 - n)) + b.words[1] = (b.words[1] << n) | (b.words[0] >> (64 - n)) + b.words[0] <<= n + } + + b.truncateToLength() + return b +} + +func (b *BitArray) setBytes32(data []byte) { + _ = data[31] // bound check hint, see https://golang.org/issue/14808 + b.words[3] = binary.BigEndian.Uint64(data[0:8]) + b.words[2] = binary.BigEndian.Uint64(data[8:16]) + b.words[1] = binary.BigEndian.Uint64(data[16:24]) + b.words[0] = binary.BigEndian.Uint64(data[24:32]) +} + +// byteCount returns the minimum number of bytes needed to represent the bit array. +// It rounds up to the nearest byte. +func (b *BitArray) byteCount() uint { + const bits8 = 8 + return (uint(b.len) + (bits8 - 1)) / uint(bits8) +} + +func (b *BitArray) rsh64(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = 0, x.words[3], x.words[2], x.words[1] +} + +func (b *BitArray) rsh128(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, x.words[3], x.words[2] +} + +func (b *BitArray) rsh192(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, x.words[3] +} + +func (b *BitArray) lsh64(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = x.words[2], x.words[1], x.words[0], 0 +} + +func (b *BitArray) lsh128(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = x.words[1], x.words[0], 0, 0 +} + +func (b *BitArray) lsh192(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = x.words[0], 0, 0, 0 +} + +func (b *BitArray) clear() *BitArray { + b.len = 0 + b.words[0], b.words[1], b.words[2], b.words[3] = 0, 0, 0, 0 + return b +} + +// truncateToLength truncates the bit array to the specified length, ensuring that any unused bits are all zeros. +// +// Example: +// +// b := &BitArray{ +// len: 5, +// words: [4]uint64{ +// 0xFFFFFFFFFFFFFFFF, // Before: all bits are 1 +// 0x0, 0x0, 0x0, +// }, +// } +// b.truncateToLength() +// // After: only first 5 bits remain +// // words[0] = 0x000000000000001F +// // words[1..3] = 0x0 +func (b *BitArray) truncateToLength() { + switch { + case b.len == 0: + b.words = [4]uint64{0, 0, 0, 0} + case b.len <= 64: + b.words[0] &= maxUint64 >> (64 - b.len) + b.words[1], b.words[2], b.words[3] = 0, 0, 0 + case b.len <= 128: + b.words[1] &= maxUint64 >> (128 - b.len) + b.words[2], b.words[3] = 0, 0 + case b.len <= 192: + b.words[2] &= maxUint64 >> (192 - b.len) + b.words[3] = 0 + default: + b.words[3] &= maxUint64 >> (256 - uint16(b.len)) + } +} + +func bigEndianUint40(b []byte) uint64 { + _ = b[4] // bounds check hint to compiler; see golang.org/issue/14808 + return uint64(b[4]) | uint64(b[3])<<8 | uint64(b[2])<<16 | uint64(b[1])<<24 | + uint64(b[0])<<32 +} + +func bigEndianUint48(b []byte) uint64 { + _ = b[5] // bounds check hint to compiler; see golang.org/issue/14808 + return uint64(b[5]) | uint64(b[4])<<8 | uint64(b[3])<<16 | uint64(b[2])<<24 | + uint64(b[1])<<32 | uint64(b[0])<<40 +} + +func bigEndianUint56(b []byte) uint64 { + _ = b[6] // bounds check hint to compiler; see golang.org/issue/14808 + return uint64(b[6]) | uint64(b[5])<<8 | uint64(b[4])<<16 | uint64(b[3])<<24 | + uint64(b[2])<<32 | uint64(b[1])<<40 | uint64(b[0])<<48 +} diff --git a/trie/bintrie/bitarray_test.go b/trie/bintrie/bitarray_test.go new file mode 100644 index 0000000000..3ea01f14e5 --- /dev/null +++ b/trie/bintrie/bitarray_test.go @@ -0,0 +1,1078 @@ +package bintrie + +import ( + "bytes" + "encoding/binary" + "math/bits" + "testing" +) + +const ( + ones63 = 0x7FFFFFFFFFFFFFFF // 63 bits of 1 +) + +func TestBytes(t *testing.T) { + tests := []struct { + name string + ba BitArray + want [32]byte + }{ + { + name: "length == 0", + ba: BitArray{len: 0, words: [4]uint64{0, 0, 0, 0}}, + want: [32]byte{}, + }, + { + name: "length < 64", + ba: BitArray{len: 38, words: [4]uint64{0x3FFFFFFFFF, 0, 0, 0}}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[24:32], 0x3FFFFFFFFF) + return b + }(), + }, + { + name: "64 <= length < 128", + ba: BitArray{len: 100, words: [4]uint64{maxUint64, 0xFFFFFFFFF, 0, 0}}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFF) + binary.BigEndian.PutUint64(b[24:32], maxUint64) + return b + }(), + }, + { + name: "128 <= length < 192", + ba: BitArray{len: 130, words: [4]uint64{maxUint64, maxUint64, 0x3, 0}}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[8:16], 0x3) + binary.BigEndian.PutUint64(b[16:24], maxUint64) + binary.BigEndian.PutUint64(b[24:32], maxUint64) + return b + }(), + }, + { + name: "192 <= length < 255", + ba: BitArray{len: 201, words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x1FF}}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[0:8], 0x1FF) + binary.BigEndian.PutUint64(b[8:16], maxUint64) + binary.BigEndian.PutUint64(b[16:24], maxUint64) + binary.BigEndian.PutUint64(b[24:32], maxUint64) + return b + }(), + }, + { + name: "length == 254", + ba: BitArray{len: 254, words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x3FFFFFFFFFFFFFFF}}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[0:8], 0x3FFFFFFFFFFFFFFF) + binary.BigEndian.PutUint64(b[8:16], maxUint64) + binary.BigEndian.PutUint64(b[16:24], maxUint64) + binary.BigEndian.PutUint64(b[24:32], maxUint64) + return b + }(), + }, + { + name: "length == 255", + ba: BitArray{len: 255, words: [4]uint64{maxUint64, maxUint64, maxUint64, ones63}}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[0:8], ones63) + binary.BigEndian.PutUint64(b[8:16], maxUint64) + binary.BigEndian.PutUint64(b[16:24], maxUint64) + binary.BigEndian.PutUint64(b[24:32], maxUint64) + return b + }(), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.ba.Bytes() + if !bytes.Equal(got[:], tt.want[:]) { + t.Errorf("BitArray.Bytes() = %v, want %v", got, tt.want) + } + + // check if the received bytes has the same bit count as the BitArray.len + count := 0 + for _, b := range got { + count += bits.OnesCount8(b) + } + if count != int(tt.ba.len) { + t.Errorf("BitArray.Bytes() bit count = %v, want %v", count, tt.ba.len) + } + }) + } +} + +func TestRsh(t *testing.T) { + tests := []struct { + name string + initial *BitArray + shiftBy uint8 + expected *BitArray + }{ + { + name: "zero length array", + initial: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + shiftBy: 5, + expected: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "shift by 0", + initial: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + shiftBy: 0, + expected: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "shift by more than length", + initial: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + shiftBy: 65, + expected: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "shift by less than 64", + initial: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + shiftBy: 32, + expected: &BitArray{ + len: 96, + words: [4]uint64{maxUint64, 0x00000000FFFFFFFF, 0, 0}, + }, + }, + { + name: "shift by exactly 64", + initial: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + shiftBy: 64, + expected: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "shift by 127", + initial: &BitArray{ + len: 255, + words: [4]uint64{maxUint64, maxUint64, maxUint64, ones63}, + }, + shiftBy: 127, + expected: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + }, + { + name: "shift by 128", + initial: &BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, maxUint64}, + }, + shiftBy: 128, + expected: &BitArray{ + len: 123, + words: [4]uint64{maxUint64, 0x7FFFFFFFFFFFFFF, 0, 0}, + }, + }, + { + name: "shift by 192", + initial: &BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, maxUint64}, + }, + shiftBy: 192, + expected: &BitArray{ + len: 59, + words: [4]uint64{0x7FFFFFFFFFFFFFF, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := new(BitArray).rsh(tt.initial, tt.shiftBy) + if !result.Equal(tt.expected) { + t.Errorf("rsh() got = %+v, want %+v", result, tt.expected) + } + }) + } +} + +func TestLsh(t *testing.T) { + tests := []struct { + name string + x *BitArray + n uint8 + want *BitArray + }{ + { + name: "empty array", + x: emptyBitArray, + n: 5, + want: emptyBitArray, + }, + { + name: "shift by 0", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + n: 0, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "shift within first word", + x: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + n: 4, + want: &BitArray{ + len: 8, + words: [4]uint64{0xF0, 0, 0, 0}, // 11110000 + }, + }, + { + name: "shift across word boundary", + x: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + n: 62, + want: &BitArray{ + len: 66, + words: [4]uint64{0xC000000000000000, 0x3, 0, 0}, + }, + }, + { + name: "shift by 64 (full word)", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + n: 64, + want: &BitArray{ + len: 72, + words: [4]uint64{0, 0xFF, 0, 0}, + }, + }, + { + name: "shift by 128", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + n: 128, + want: &BitArray{ + len: 136, + words: [4]uint64{0, 0, 0xFF, 0}, + }, + }, + { + name: "shift by 192", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + n: 192, + want: &BitArray{ + len: 200, + words: [4]uint64{0, 0, 0, 0xFF}, + }, + }, + { + name: "shift causing length overflow", + x: &BitArray{ + len: 200, + words: [4]uint64{0xFF, 0, 0, 0}, + }, + n: 60, + want: &BitArray{ + len: 255, // capped at maxUint8 + words: [4]uint64{ + 0xF000000000000000, + 0xF, + 0, + 0, + }, + }, + }, + { + name: "shift sparse bits", + x: &BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 + }, + n: 4, + want: &BitArray{ + len: 12, + words: [4]uint64{0xAA0, 0, 0, 0}, // 101010100000 + }, + }, + { + name: "shift partial word across boundary", + x: &BitArray{ + len: 100, + words: [4]uint64{0xFF, 0xFF, 0, 0}, + }, + n: 60, + want: &BitArray{ + len: 160, + words: [4]uint64{ + 0xF000000000000000, + 0xF00000000000000F, + 0xF, + 0, + }, + }, + }, + { + name: "near maximum length shift", + x: &BitArray{ + len: 251, + words: [4]uint64{0xFF, 0, 0, 0}, + }, + n: 4, + want: &BitArray{ + len: 255, // capped at maxUint8 + words: [4]uint64{0xFF0, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).lsh(tt.x, tt.n) + if !got.Equal(tt.want) { + t.Errorf("Lsh() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAppend(t *testing.T) { + tests := []struct { + name string + x *BitArray + y *BitArray + want *BitArray + }{ + { + name: "both empty arrays", + x: emptyBitArray, + y: emptyBitArray, + want: emptyBitArray, + }, + { + name: "first array empty", + x: emptyBitArray, + y: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + want: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + }, + { + name: "second array empty", + x: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + y: emptyBitArray, + want: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + }, + { + name: "within first word", + x: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + y: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + want: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + }, + { + name: "different lengths within word", + x: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + y: &BitArray{ + len: 2, + words: [4]uint64{0x3, 0, 0, 0}, // 11 + }, + want: &BitArray{ + len: 6, + words: [4]uint64{0x3F, 0, 0, 0}, // 111111 + }, + }, + { + name: "across word boundary", + x: &BitArray{ + len: 62, + words: [4]uint64{0x3FFFFFFFFFFFFFFF, 0, 0, 0}, + }, + y: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + want: &BitArray{ + len: 66, + words: [4]uint64{maxUint64, 0x3, 0, 0}, + }, + }, + { + name: "across multiple words", + x: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + y: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + want: &BitArray{ + len: 192, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, + }, + }, + { + name: "sparse bits", + x: &BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 + }, + y: &BitArray{ + len: 8, + words: [4]uint64{0x55, 0, 0, 0}, // 01010101 + }, + want: &BitArray{ + len: 16, + words: [4]uint64{0xAA55, 0, 0, 0}, // 1010101001010101 + }, + }, + { + name: "result exactly at length limit", + x: &BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFFF}, + }, + y: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, + }, + want: &BitArray{ + len: 255, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFFF}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).Append(tt.x, tt.y) + if !got.Equal(tt.want) { + t.Errorf("Append() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestLSBs(t *testing.T) { + tests := []struct { + name string + x *BitArray + pos uint8 + want *BitArray + }{ + { + name: "zero position", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + pos: 0, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "position beyond length", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + pos: 65, + want: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "get last 4 bits", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + pos: 4, + want: &BitArray{ + len: 4, + words: [4]uint64{0x0F, 0, 0, 0}, // 1111 + }, + }, + { + name: "get bits across word boundary", + x: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + pos: 64, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "get bits from max length array", + x: &BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, + }, + pos: 200, + want: &BitArray{ + len: 51, + words: [4]uint64{0x7FFFFFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "empty array", + x: emptyBitArray, + pos: 1, + want: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "sparse bits", + x: &BitArray{ + len: 16, + words: [4]uint64{0xAAAA, 0, 0, 0}, // 1010101010101010 + }, + pos: 8, + want: &BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 + }, + }, + { + name: "position equals length", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + pos: 64, + want: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).lsb(tt.x, tt.pos) + if !got.Equal(tt.want) { + t.Errorf("LSBs() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestLSBsFromLSB(t *testing.T) { + tests := []struct { + name string + initial BitArray + length uint8 + expected BitArray + }{ + { + name: "zero", + initial: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + length: 0, + expected: BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "get 32 LSBs", + initial: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + length: 32, + expected: BitArray{ + len: 32, + words: [4]uint64{0x00000000FFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "get 1 LSB", + initial: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + length: 1, + expected: BitArray{ + len: 1, + words: [4]uint64{0x1, 0, 0, 0}, + }, + }, + { + name: "get 100 LSBs across words", + initial: BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + length: 100, + expected: BitArray{ + len: 100, + words: [4]uint64{maxUint64, 0x0000000FFFFFFFFF, 0, 0}, + }, + }, + { + name: "get 64 LSBs at word boundary", + initial: BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + length: 64, + expected: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "get 128 LSBs at word boundary", + initial: BitArray{ + len: 192, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, + }, + length: 128, + expected: BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + }, + { + name: "get 150 LSBs in third word", + initial: BitArray{ + len: 192, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, + }, + length: 150, + expected: BitArray{ + len: 150, + words: [4]uint64{maxUint64, maxUint64, 0x3FFFFF, 0}, + }, + }, + { + name: "get 220 LSBs in fourth word", + initial: BitArray{ + len: 255, + words: [4]uint64{maxUint64, maxUint64, maxUint64, maxUint64}, + }, + length: 220, + expected: BitArray{ + len: 220, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0xFFFFFFF}, + }, + }, + { + name: "get 251 LSBs", + initial: BitArray{ + len: 255, + words: [4]uint64{maxUint64, maxUint64, maxUint64, maxUint64}, + }, + length: 251, + expected: BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, + }, + }, + { + name: "get 100 LSBs from sparse bits", + initial: BitArray{ + len: 128, + words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x5555555555555555, 0, 0}, + }, + length: 100, + expected: BitArray{ + len: 100, + words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x0000000555555555, 0, 0}, + }, + }, + { + name: "no change when new length equals current length", + initial: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + length: 64, + expected: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "no change when new length greater than current length", + initial: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + length: 128, + expected: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := new(BitArray).copyLsb(&tt.initial, tt.length) + if !result.Equal(&tt.expected) { + t.Errorf("Truncate() got = %+v, want %+v", result, tt.expected) + } + }) + } +} + +func TestMSBs(t *testing.T) { + tests := []struct { + name string + x *BitArray + n uint8 + want *BitArray + }{ + { + name: "empty array", + x: emptyBitArray, + n: 0, + want: emptyBitArray, + }, + { + name: "get all bits", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + n: 64, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "get more bits than available", + x: &BitArray{ + len: 32, + words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, + }, + n: 64, + want: &BitArray{ + len: 32, + words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "get half of available bits", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + n: 32, + want: &BitArray{ + len: 32, + words: [4]uint64{0xFFFFFFFF00000000 >> 32, 0, 0, 0}, + }, + }, + { + name: "get MSBs across word boundary", + x: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + n: 100, + want: &BitArray{ + len: 100, + words: [4]uint64{maxUint64, maxUint64 >> 28, 0, 0}, + }, + }, + { + name: "get MSBs from max length array", + x: &BitArray{ + len: 255, + words: [4]uint64{maxUint64, maxUint64, maxUint64, ones63}, + }, + n: 64, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "get zero bits", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + n: 0, + want: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "sparse bits", + x: &BitArray{ + len: 128, + words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x5555555555555555, 0, 0}, + }, + n: 64, + want: &BitArray{ + len: 64, + words: [4]uint64{0x5555555555555555, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).MSBs(tt.x, tt.n) + if !got.Equal(tt.want) { + t.Errorf("MSBs() = %v, want %v", got, tt.want) + } + + if got.len != tt.want.len { + t.Errorf("MSBs() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSetBit(t *testing.T) { + tests := []struct { + name string + bit uint8 + want BitArray + }{ + { + name: "set bit 0", + bit: 0, + want: BitArray{ + len: 1, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "set bit 1", + bit: 1, + want: BitArray{ + len: 1, + words: [4]uint64{1, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).SetBit(tt.bit) + if !got.Equal(&tt.want) { + t.Errorf("SetBit(%v) = %v, want %v", tt.bit, got, tt.want) + } + }) + } +} + +func TestSetBytes(t *testing.T) { + tests := []struct { + name string + length uint8 + data []byte + want BitArray + }{ + { + name: "empty data", + length: 0, + data: []byte{}, + want: BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "single byte", + length: 8, + data: []byte{0xFF}, + want: BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, + }, + }, + { + name: "two bytes", + length: 16, + data: []byte{0xAA, 0xFF}, + want: BitArray{ + len: 16, + words: [4]uint64{0xAAFF, 0, 0, 0}, + }, + }, + { + name: "three bytes", + length: 24, + data: []byte{0xAA, 0xBB, 0xCC}, + want: BitArray{ + len: 24, + words: [4]uint64{0xAABBCC, 0, 0, 0}, + }, + }, + { + name: "four bytes", + length: 32, + data: []byte{0xAA, 0xBB, 0xCC, 0xDD}, + want: BitArray{ + len: 32, + words: [4]uint64{0xAABBCCDD, 0, 0, 0}, + }, + }, + { + name: "eight bytes (full word)", + length: 64, + data: []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, + want: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "sixteen bytes (two words)", + length: 128, + data: []byte{ + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, + }, + want: BitArray{ + len: 128, + words: [4]uint64{ + 0xAAAAAAAAAAAAAAAA, + 0xFFFFFFFFFFFFFFFF, + 0, 0, + }, + }, + }, + { + name: "thirty-two bytes (full array)", + length: 251, + data: []byte{ + 0x7F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + }, + want: BitArray{ + len: 251, + words: [4]uint64{ + maxUint64, + maxUint64, + maxUint64, + 0x7FFFFFFFFFFFFFF, + }, + }, + }, + { + name: "truncate to length", + length: 4, + data: []byte{0xFF}, + want: BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, + }, + }, + { + name: "data larger than 32 bytes", + length: 251, + data: []byte{ + 0x7F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, // extra bytes should be ignored + }, + want: BitArray{ + len: 251, + words: [4]uint64{ + maxUint64, + maxUint64, + maxUint64, + 0x7FFFFFFFFFFFFFF, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).SetBytes(tt.length, tt.data) + if !got.Equal(&tt.want) { + t.Errorf("SetBytes(%d, %v) = %v, want %v", tt.length, tt.data, got, tt.want) + } + }) + } +} diff --git a/trie/bintrie/internal_node.go b/trie/bintrie/internal_node.go index b83cb92d87..3d7010a194 100644 --- a/trie/bintrie/internal_node.go +++ b/trie/bintrie/internal_node.go @@ -26,12 +26,10 @@ func keyToPath(depth int, key []byte) ([]byte, error) { if depth >= 31*8 { return nil, errors.New("node too deep") } - path := make([]byte, 0, depth+1) - for i := range depth + 1 { - bit := key[i/8] >> (7 - (i % 8)) & 1 - path = append(path, bit) - } - return path, nil + keyLen := min(len(key), 31) + ba := new(BitArray).SetBytes(uint8(keyLen*8), key[:keyLen]) + path := new(BitArray).MSBs(ba, uint8(depth+1)) + return path.ActiveBytes(), nil } // Invariant: dirty=false implies mustRecompute=false. Every mutation that diff --git a/trie/bintrie/internal_node_test.go b/trie/bintrie/internal_node_test.go index 4d8da8af37..2f7ade14b0 100644 --- a/trie/bintrie/internal_node_test.go +++ b/trie/bintrie/internal_node_test.go @@ -283,14 +283,13 @@ func TestInternalNodeCollectNodes(t *testing.T) { t.Fatal(err) } - var collectedPaths [][]byte - flushFn := func(path []byte, hash common.Hash, serialized []byte) { - pathCopy := make([]byte, len(path)) - copy(pathCopy, path) - collectedPaths = append(collectedPaths, pathCopy) + var collectedPaths []BitArray + flushFn := func(path BitArray, hash common.Hash, serialized []byte) { + collectedPaths = append(collectedPaths, path) } - s.collectNodes(s.root, []byte{1}, flushFn, 8) + initialPath := NewBitArray(1, 1) + s.collectNodes(s.root, initialPath, flushFn, 8) // Should have collected 3 nodes: left stem, right stem, and the internal node itself if len(collectedPaths) != 3 { diff --git a/trie/bintrie/iterator.go b/trie/bintrie/iterator.go index a920f91378..88622ddaae 100644 --- a/trie/bintrie/iterator.go +++ b/trie/bintrie/iterator.go @@ -188,20 +188,20 @@ func (it *binaryNodeIterator) Parent() common.Hash { return it.store.computeHash(it.stack[len(it.stack)-2].Node) } -// Path returns the bit-path to the current node. +// Path returns the bit-packed path to the current node. // Callers must not retain references to the returned slice after calling Next. func (it *binaryNodeIterator) Path() []byte { if it.Leaf() { return it.LeafKey() } - var path []byte + var path BitArray for i, state := range it.stack { if i >= len(it.stack)-1 { break } - path = append(path, byte(state.Index)) + path.AppendBit(&path, uint8(state.Index)) } - return path + return path.ActiveBytes() } func (it *binaryNodeIterator) NodeBlob() []byte { diff --git a/trie/bintrie/stem_node_test.go b/trie/bintrie/stem_node_test.go index ae6b57ab34..8c563fce9d 100644 --- a/trie/bintrie/stem_node_test.go +++ b/trie/bintrie/stem_node_test.go @@ -313,14 +313,13 @@ func TestStemNodeCollectNodes(t *testing.T) { t.Fatal(err) } - var collectedPaths [][]byte - flushFn := func(path []byte, hash common.Hash, serialized []byte) { - pathCopy := make([]byte, len(path)) - copy(pathCopy, path) - collectedPaths = append(collectedPaths, pathCopy) + var collectedPaths []BitArray + flushFn := func(path BitArray, hash common.Hash, serialized []byte) { + collectedPaths = append(collectedPaths, path) } - s.collectNodes(s.root, []byte{0, 1, 0}, flushFn, 8) + initialPath := NewBitArray(3, 0b010) + s.collectNodes(s.root, initialPath, flushFn, 8) // Should have collected one node (itself) if len(collectedPaths) != 1 { @@ -328,7 +327,8 @@ func TestStemNodeCollectNodes(t *testing.T) { } // Check the path - if !bytes.Equal(collectedPaths[0], []byte{0, 1, 0}) { - t.Errorf("Path mismatch: expected [0, 1, 0], got %v", collectedPaths[0]) + expectedPath := NewBitArray(3, 0b010) + if !collectedPaths[0].Equal(&expectedPath) { + t.Errorf("Path mismatch: expected %v, got %v", expectedPath, collectedPaths[0]) } } diff --git a/trie/bintrie/store_commit.go b/trie/bintrie/store_commit.go index b14bffbc6c..82220387eb 100644 --- a/trie/bintrie/store_commit.go +++ b/trie/bintrie/store_commit.go @@ -27,7 +27,7 @@ import ( "github.com/ethereum/go-ethereum/common" ) -type nodeFlushFn func(path []byte, hash common.Hash, serialized []byte) +type nodeFlushFn func(path BitArray, hash common.Hash, serialized []byte) func (s *nodeStore) Hash() common.Hash { return s.computeHash(s.root) @@ -340,7 +340,10 @@ func (s *nodeStore) decodeNode(serialized []byte, depth int, hn common.Hash, mus // CollectNodes flushes every node that needs flushing via flushfn in post-order. // Invariant: any ancestor of a node that needs flushing is itself marked, so a // clean root means the whole subtree is clean. -func (s *nodeStore) collectNodes(ref nodeRef, path []byte, flushfn nodeFlushFn, groupDepth int) { +// +// BitArray is passed by value (33 bytes) to keep child paths on the stack. +// Passing by pointer causes escape to heap per recursive call. +func (s *nodeStore) collectNodes(ref nodeRef, path BitArray, flushfn nodeFlushFn, groupDepth int) { switch ref.Kind() { case kindInternal: node := s.getInternal(ref.Index()) @@ -375,7 +378,7 @@ func (s *nodeStore) collectNodes(ref nodeRef, path []byte, flushfn nodeFlushFn, // collectChildGroups traverses within a group to find and collect nodes in the next group. // remainingLevels is how many more levels below the current node until we reach the group boundary. // When remainingLevels=0, the current node's children are at the next group boundary. -func (s *nodeStore) collectChildGroups(node *InternalNode, path []byte, flushfn nodeFlushFn, groupDepth int, remainingLevels int) error { +func (s *nodeStore) collectChildGroups(node *InternalNode, path BitArray, flushfn nodeFlushFn, groupDepth int, remainingLevels int) error { if remainingLevels == 0 { // Current node is at depth (groupBoundary - 1), its children are at the next group boundary if !node.left.IsEmpty() { @@ -418,32 +421,32 @@ func (s *nodeStore) collectChildGroups(node *InternalNode, path []byte, flushfn // matching the projection done by serializeSubtree. For StemNodes, the path // is extended using the stem's key bits (same as serializeSubtree). For other // node types, the path is extended with all-zero (left) bits. -func (s *nodeStore) extendPathToGroupLeaf(path []byte, node nodeRef, remainingLevels int) []byte { +func (s *nodeStore) extendPathToGroupLeaf(path BitArray, node nodeRef, remainingLevels int) BitArray { if remainingLevels <= 0 { return path } if node.Kind() == kindStem { sn := s.getStem(node.Index()) - for _ = range remainingLevels { - bit := sn.Stem[len(path)/8] >> (7 - (len(path) % 8)) & 1 + for range remainingLevels { + n := path.Len() + bit := sn.Stem[n/8] >> (7 - (n % 8)) & 1 path = appendBit(path, bit) } } else { // HashedNode or other: all-left extension (matches serializeSubtree's // position << remainingDepth behavior). - for _ = range remainingLevels { + for range remainingLevels { path = appendBit(path, 0) } } return path } -// appendBit appends a bit to a path, returning a new slice -func appendBit(path []byte, bit byte) []byte { - var p [256]byte - copy(p[:], path) - result := p[:len(path)] - return append(result, bit) +// appendBit returns a new BitArray with bit appended to path. +func appendBit(path BitArray, bit uint8) BitArray { + var p BitArray + p.AppendBit(&path, bit) + return p } func (s *nodeStore) toDot(ref nodeRef, parent, path string) string { diff --git a/trie/bintrie/trie.go b/trie/bintrie/trie.go index e3436e3df1..eee84ad92a 100644 --- a/trie/bintrie/trie.go +++ b/trie/bintrie/trie.go @@ -319,11 +319,11 @@ func (t *BinaryTrie) Hash() common.Hash { func (t *BinaryTrie) Commit(_ bool) (common.Hash, *trienode.NodeSet) { nodeset := trienode.NewNodeSet(common.Hash{}) - // Pre-size the path buffer: collectNodes reuses it in-place via - // append/truncate; 32 covers typical binary-trie depth without regrowth. - pathBuf := make([]byte, 0, 32) - t.store.collectNodes(t.store.root, pathBuf, func(path []byte, hash common.Hash, serialized []byte) { - nodeset.AddNode(path, trienode.NewNodeWithPrev(hash, serialized, t.tracer.Get(path))) + var rootPath BitArray + t.store.collectNodes(t.store.root, rootPath, func(path BitArray, hash common.Hash, serialized []byte) { + var buf [32]byte + pathBytes := path.PutActiveBytes(&buf) + nodeset.AddNode(pathBytes, trienode.NewNodeWithPrev(hash, serialized, t.tracer.Get(pathBytes))) }, t.groupDepth) return t.Hash(), nodeset } From 012bec0eb1fdf398a6fd7ab49be09ab935f7d9ce Mon Sep 17 00:00:00 2001 From: weiihann Date: Fri, 8 May 2026 19:36:16 +0800 Subject: [PATCH 2/4] trie/bintrie: postpend bit-length to disambiguate path encoding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The compact LSB-aligned encoding via ActiveBytes packed paths into ceil(len/8) bytes without recording the bit-length. Two distinct paths whose bit-lengths fell in the same byte-bucket and whose integer values matched produced identical bytes — e.g. the 1-bit path "1" and the 8-bit path "00000001" both encoded to [0x01], so two stems sitting at depths 1 and 8 on different branches could clobber each other in nodeset.AddNode. Replace ActiveBytes/PutActiveBytes with KeyBytes/PutKeyBytes, which append a uint8 bit-length byte after the active bytes. Postpend (rather than prepend) so nodes along a single root-to-leaf descent share leading path bytes, improving LSM block locality during traversal. The empty path is encoded as no bytes (not [0x00]): byteCount=0 is unique to len=0 so no disambiguation byte is needed. This keeps the root's DB key empty, matching the resolver's existing nil-path convention. --- trie/bintrie/binary_node_test.go | 14 +++++------ trie/bintrie/bitarray.go | 41 ++++++++++++++++++++++++-------- trie/bintrie/internal_node.go | 2 +- trie/bintrie/iterator.go | 2 +- trie/bintrie/trie.go | 4 ++-- 5 files changed, 42 insertions(+), 21 deletions(-) diff --git a/trie/bintrie/binary_node_test.go b/trie/bintrie/binary_node_test.go index ff3b8225bb..e30c875e9d 100644 --- a/trie/bintrie/binary_node_test.go +++ b/trie/bintrie/binary_node_test.go @@ -244,29 +244,29 @@ func TestKeyToPath(t *testing.T) { { name: "depth 0", depth: 0, - key: []byte{0x80}, // 10000000 in binary - expected: []byte{1}, // 1 bit packed: MSB=1 → 0x01 + key: []byte{0x80}, // 10000000 in binary + expected: []byte{0x01, 1}, // 1-bit value 0x01 + length byte 1 wantErr: false, }, { name: "depth 7", depth: 7, - key: []byte{0xFF}, // 11111111 in binary - expected: []byte{0xFF}, // 8 bits packed into 1 byte + key: []byte{0xFF}, // 11111111 in binary + expected: []byte{0xFF, 8}, // 8-bit value 0xFF + length byte 8 wantErr: false, }, { name: "depth crossing byte boundary", depth: 10, - key: []byte{0xFF, 0x00}, // 11111111 00000000 in binary - expected: []byte{0x07, 0xF8}, // 11 bits = 11111111000 → 0x07F8 + key: []byte{0xFF, 0x00}, // 11111111 00000000 in binary + expected: []byte{0x07, 0xF8, 11}, // 11-bit value 0x07F8 + length byte 11 wantErr: false, }, { name: "max valid depth", depth: StemSize*8 - 1, key: make([]byte, HashSize), - expected: make([]byte, StemSize), // 248 bits of zeros → 31 packed bytes + expected: append(make([]byte, StemSize), StemSize*8), // 248 bits of zeros + length byte 248 wantErr: false, }, { diff --git a/trie/bintrie/bitarray.go b/trie/bintrie/bitarray.go index 5b39a629d3..b202be94e4 100644 --- a/trie/bintrie/bitarray.go +++ b/trie/bintrie/bitarray.go @@ -294,26 +294,47 @@ func (b *BitArray) Set(x *BitArray) *BitArray { return b } -// ActiveBytes returns a slice containing only the bytes that are actually used by the bit array, -// as specified by the length. The returned slice is in big-endian order. +// KeyBytes returns the path-to-DB-key encoding: the active bytes in big-endian +// order followed by a single trailing byte holding the bit-length. The trailing +// length disambiguates paths whose active bytes coincide (e.g. 1-bit "1" and +// 8-bit "00000001" both pack to integer value 1, but their key encodings are +// [0x01, 0x01] and [0x01, 0x08] respectively). +// +// The empty path is encoded as no bytes: byteCount=0 is unique to len=0, so +// no disambiguation byte is needed. // // Example: // -// len = 10, words = [0x3FF, 0, 0, 0] -> [0x03, 0xFF] -func (b *BitArray) ActiveBytes() []byte { +// len = 10, words = [0x3FF, 0, 0, 0] -> [0x03, 0xFF, 0x0A] +func (b *BitArray) KeyBytes() []byte { + if b.len == 0 { + return nil + } + bc := b.byteCount() + res := make([]byte, bc+1) wordsBytes := b.Bytes() - return wordsBytes[32-b.byteCount():] + copy(res[:bc], wordsBytes[32-bc:]) + res[bc] = b.len + return res } -// PutActiveBytes writes the active bytes into dst (which must be at least 32 bytes) -// and returns the populated sub-slice. No heap allocation occurs because the -// backing array is owned by the caller. -func (b *BitArray) PutActiveBytes(dst *[32]byte) []byte { +// PutKeyBytes writes the key encoding (active bytes followed by length byte) +// into dst and returns the populated sub-slice. The empty path returns dst[:0] +// without touching dst. For non-empty paths dst must have len >= 33 (32 packed +// bytes for 248 bits + 1 length byte). +func (b *BitArray) PutKeyBytes(dst []byte) []byte { + if b.len == 0 { + return dst[:0] + } + _ = dst[32] // bounds check hint binary.BigEndian.PutUint64(dst[0:8], b.words[3]) binary.BigEndian.PutUint64(dst[8:16], b.words[2]) binary.BigEndian.PutUint64(dst[16:24], b.words[1]) binary.BigEndian.PutUint64(dst[24:32], b.words[0]) - return dst[32-b.byteCount():] + bc := b.byteCount() + copy(dst, dst[32-bc:32]) + dst[bc] = b.len + return dst[:bc+1] } // bitFromLSB returns the bit value at position n, where n = 0 is LSB. diff --git a/trie/bintrie/internal_node.go b/trie/bintrie/internal_node.go index 3d7010a194..992384da3e 100644 --- a/trie/bintrie/internal_node.go +++ b/trie/bintrie/internal_node.go @@ -29,7 +29,7 @@ func keyToPath(depth int, key []byte) ([]byte, error) { keyLen := min(len(key), 31) ba := new(BitArray).SetBytes(uint8(keyLen*8), key[:keyLen]) path := new(BitArray).MSBs(ba, uint8(depth+1)) - return path.ActiveBytes(), nil + return path.KeyBytes(), nil } // Invariant: dirty=false implies mustRecompute=false. Every mutation that diff --git a/trie/bintrie/iterator.go b/trie/bintrie/iterator.go index 88622ddaae..e678ee310d 100644 --- a/trie/bintrie/iterator.go +++ b/trie/bintrie/iterator.go @@ -201,7 +201,7 @@ func (it *binaryNodeIterator) Path() []byte { } path.AppendBit(&path, uint8(state.Index)) } - return path.ActiveBytes() + return path.KeyBytes() } func (it *binaryNodeIterator) NodeBlob() []byte { diff --git a/trie/bintrie/trie.go b/trie/bintrie/trie.go index eee84ad92a..653d419f44 100644 --- a/trie/bintrie/trie.go +++ b/trie/bintrie/trie.go @@ -321,8 +321,8 @@ func (t *BinaryTrie) Commit(_ bool) (common.Hash, *trienode.NodeSet) { var rootPath BitArray t.store.collectNodes(t.store.root, rootPath, func(path BitArray, hash common.Hash, serialized []byte) { - var buf [32]byte - pathBytes := path.PutActiveBytes(&buf) + var buf [33]byte + pathBytes := path.PutKeyBytes(buf[:]) nodeset.AddNode(pathBytes, trienode.NewNodeWithPrev(hash, serialized, t.tracer.Get(pathBytes))) }, t.groupDepth) return t.Hash(), nodeset From a1eaa21f24b0b03fbaf65a12a82256a93054ea08 Mon Sep 17 00:00:00 2001 From: weiihann Date: Wed, 13 May 2026 09:39:19 +0800 Subject: [PATCH 3/4] trie/bintrie: fix hashInternal at group boundaries to match read-back hash For an InternalNode at a group-boundary depth, hashInternal previously computed pure SHA256(left, right) recursion over the natural-depth in-memory tree built by UpdateStem. But serializeSubtree extends stems to the group's bottom layer via key-bit extension, so the on-disk blob encodes an extended-depth structure. When a fresh reader deserializes that blob, hashInternal walks the extended-depth in-memory tree and produces a different value. The result was that for any subtree with multiple stems sharing a prefix shorter than groupDepth, the parent's stored child-hash (computed from the natural-depth in-memory tree at commit time) did not equal the child blob's read-back hash. Geth's own write-read cycle was internally inconsistent: state-actor's groundtruth test, which feeds the same stems through state-actor's streaming builder and geth's UpdateStem + Commit and diffs the resulting on-disk node sets, fails at n=4 with a mismatched slot hash in the root group blob. At a group boundary, recompute the hash via serializeSubtree + groupedRecursiveHash so that the parent stores the same value the reader will compute when it deserializes the child blob. The fix is gated on groupDepth > 0, so nodeStore tests that construct the store directly without going through NewBinaryTrie retain the existing pure-SHA256 recursion semantics. Verification: - All existing trie/bintrie tests pass unchanged. - state-actor/generator's TestStreamingMatchesGethCommit (which compares state-actor's streaming builder output to geth's Commit output byte-for-byte at n=2,4,8,32,128) now passes. --- trie/bintrie/node_store.go | 10 +++++++ trie/bintrie/store_commit.go | 55 ++++++++++++++++++++++++++++++++++++ trie/bintrie/trie.go | 4 ++- 3 files changed, 68 insertions(+), 1 deletion(-) diff --git a/trie/bintrie/node_store.go b/trie/bintrie/node_store.go index 8a35f06ee1..8527f6a1d8 100644 --- a/trie/bintrie/node_store.go +++ b/trie/bintrie/node_store.go @@ -41,6 +41,16 @@ type nodeStore struct { // stem-split keeps the old stem at a deeper position), so they don't // have free lists. freeHashed []uint32 + + // groupDepth, when > 0, makes hashInternal compute the same hash that + // would be produced by serializing the node to a group blob and + // recursively hashing the blob's bottom-layer leaves. This matches the + // hash a fresh reader would compute via deserializeSubtree, keeping the + // parent-stored child hash byte-equal to the child's read-back hash. + // When 0, hashInternal falls back to the natural-depth SHA256 recursion + // used by tests that construct nodeStore directly without going through + // NewBinaryTrie. + groupDepth int } func newNodeStore() *nodeStore { diff --git a/trie/bintrie/store_commit.go b/trie/bintrie/store_commit.go index 82220387eb..800ac0546a 100644 --- a/trie/bintrie/store_commit.go +++ b/trie/bintrie/store_commit.go @@ -59,12 +59,30 @@ var parallelHashDepth = min(bits.Len(uint(runtime.NumCPU())), 8) // goroutine while the right subtree is hashed inline, then the two digests // are combined. Below that threshold the goroutine spawn cost outweighs the // hashing work, so deeper nodes hash both children sequentially. +// +// At a group boundary (depth % groupDepth == 0, with groupDepth > 0) the +// hash is computed from the group's bottom-layer slot hashes via the same +// serialize-then-recursive-hash that a fresh reader applies after reading +// the node's blob from disk. This guarantees the parent's stored child +// hash equals the child's read-back hash byte-for-byte, regardless of +// whether the in-memory subtree placed its stems at natural depth (via +// UpdateStem split) or extended depth (via deserializeSubtree). func (s *nodeStore) hashInternal(idx uint32) common.Hash { node := s.getInternal(idx) if !node.mustRecompute { return node.hash } + if s.groupDepth > 0 && int(node.depth)%s.groupDepth == 0 { + bitmapSize := bitmapSizeForDepth(s.groupDepth) + bitmap := make([]byte, bitmapSize) + var hashes []common.Hash + s.serializeSubtree(makeRef(kindInternal, idx), s.groupDepth, 0, int(node.depth), bitmap, &hashes) + node.hash = groupedRecursiveHash(s.groupDepth, bitmap, hashes) + node.mustRecompute = false + return node.hash + } + if int(node.depth) < parallelHashDepth { var input [64]byte var lh common.Hash @@ -107,6 +125,43 @@ func (s *nodeStore) hashInternal(idx uint32) common.Hash { return node.hash } +// groupedRecursiveHash computes the recursive SHA256 hash of a group-blob +// subtree, given the bitmap and present-hash list produced by serializeSubtree. +// +// The output is byte-equal to what hashInternal would compute on a tree +// produced by deserializeSubtree reading the same (bitmap, hashes) — i.e., +// it's the hash the fresh-reader path produces. Use this from hashInternal +// at group-boundary depths so the parent's stored child hash matches the +// child's read-back hash regardless of in-memory stem placement. +func groupedRecursiveHash(groupDepth int, bitmap []byte, hashes []common.Hash) common.Hash { + nSlots := 1 << groupDepth + leaves := make([]common.Hash, nSlots) + hashIdx := 0 + for i := 0; i < nSlots; i++ { + if bitmap[i/8]>>(7-(i%8))&1 == 1 { + leaves[i] = hashes[hashIdx] + hashIdx++ + } + } + level := leaves + var zero common.Hash + for len(level) > 1 { + next := make([]common.Hash, len(level)/2) + for i := 0; i < len(next); i++ { + l, r := level[2*i], level[2*i+1] + if l == zero && r == zero { + continue + } + var buf [64]byte + copy(buf[:32], l[:]) + copy(buf[32:], r[:]) + next[i] = sha256.Sum256(buf[:]) + } + level = next + } + return level[0] +} + // serializeSubtree recursively collects child hashes from a subtree of InternalNodes. // It traverses up to `remainingDepth` levels, storing hashes of bottom-layer children. // position tracks the current index (0 to 2^groupDepth - 1) for bitmap placement. diff --git a/trie/bintrie/trie.go b/trie/bintrie/trie.go index 653d419f44..9507dd91b6 100644 --- a/trie/bintrie/trie.go +++ b/trie/bintrie/trie.go @@ -133,8 +133,10 @@ func NewBinaryTrie(root common.Hash, db database.NodeDatabase, groupDepth int) ( if err != nil { return nil, err } + store := newNodeStore() + store.groupDepth = groupDepth t := &BinaryTrie{ - store: newNodeStore(), + store: store, reader: reader, tracer: trie.NewPrevalueTracer(), groupDepth: groupDepth, From bdb7b64173fff535b02945c5c0a706ae55751cdd Mon Sep 17 00:00:00 2001 From: weiihann Date: Wed, 13 May 2026 09:39:31 +0800 Subject: [PATCH 4/4] trie/bintrie: mark promoted stem dirty during splitStemValuesInsert MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When splitStemValuesInsert inserts a new stem that shares a prefix with an existing stem, it increments the existing stem's depth and inserts a new internal node above it. The existing stem's on-disk path is derived from its depth via collectChildGroups + extendPathToGroupLeaf, so promoting its depth means it should be flushed at a new path. Previously, only the new stem (created in the divergence branch) was marked dirty. The promoted existing stem retained whatever dirty value it had — false if it was just deserialized from disk via a HashedNode resolve. collectNodes would then skip flushing the existing stem at its new path, while the new ancestor internal blob (also dirty) overwrites the existing stem's old blob at the prior path. The stem's data is left with no on-disk home, breaking subsequent reads with "missing trie node". The bug surfaces in the integration-test harness (state-actor builds a DB with single-stem-per-slot at depth 8, geth then mutates by adding a new stem that shares ≥8 prefix bits with the existing stem). After mutation, geth's `getValuesAtStem` resolves a HashedNode whose blob should be at the extended-depth path but isn't on disk. Mark `existing.dirty = true` when promoting the depth so collectNodes re-flushes the stem at its new path. Verification: the 100MB integration-test harness (which previously failed at block 9-10 with "missing trie node bdaf89... (path c96010)") now runs cleanly through 200+ blocks of ERC20 deploys and bloat transactions without any missing-trie-node errors. --- trie/bintrie/store_ops.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/trie/bintrie/store_ops.go b/trie/bintrie/store_ops.go index 9a73c8bd64..1924978ad2 100644 --- a/trie/bintrie/store_ops.go +++ b/trie/bintrie/store_ops.go @@ -263,6 +263,12 @@ func (s *nodeStore) splitStemValuesInsert(existingRef nodeRef, newStem []byte, v nRef := s.newInternalRef(int(existing.depth)) nNode := s.getInternal(nRef.Index()) existing.depth++ + // The existing stem's on-disk path is derived from its depth via + // extendPathToGroupLeaf. Promoting its depth changes that path, so the + // stem must be re-flushed at the new path; otherwise the old blob (at + // the prior path) gets overwritten by the new ancestor internal blob + // and the stem's data has no on-disk home. + existing.dirty = true bitKey := newStem[nNode.depth/8] >> (7 - (nNode.depth % 8)) & 1 if bitKey == bitStem {