This commit is contained in:
Sahil Sojitra 2026-05-22 11:42:03 +00:00 committed by GitHub
commit a08b152dea
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 435 additions and 1 deletions

View file

@ -128,6 +128,210 @@ func (h *Header) Hash() common.Hash {
return rlpHash(h)
}
// DecodeHeader decodes one RLP-encoded header from b into dst. It avoids the
// generic reflection path used by rlp.DecodeBytes and rejects trailing input in
// the same way.
func DecodeHeader(b []byte, dst *Header) error {
stream := rlp.NewBytesStream(b)
defer rlp.PutStream(stream)
if err := dst.DecodeRLP(stream); err != nil {
return err
}
if stream.Remaining() != 0 {
return rlp.ErrMoreThanOneValue
}
return nil
}
// DecodeRLP decodes a block header.
func (h *Header) DecodeRLP(s *rlp.Stream) error {
if _, err := s.List(); err != nil {
return err
}
if err := s.ReadBytes(h.ParentHash[:]); err != nil {
return fmt.Errorf("read ParentHash: %w", err)
}
if err := s.ReadBytes(h.UncleHash[:]); err != nil {
return fmt.Errorf("read UncleHash: %w", err)
}
if err := s.ReadBytes(h.Coinbase[:]); err != nil {
return fmt.Errorf("read Coinbase: %w", err)
}
if err := s.ReadBytes(h.Root[:]); err != nil {
return fmt.Errorf("read Root: %w", err)
}
if err := s.ReadBytes(h.TxHash[:]); err != nil {
return fmt.Errorf("read TxHash: %w", err)
}
if err := s.ReadBytes(h.ReceiptHash[:]); err != nil {
return fmt.Errorf("read ReceiptHash: %w", err)
}
if err := s.ReadBytes(h.Bloom[:]); err != nil {
return fmt.Errorf("read Bloom: %w", err)
}
if h.Difficulty == nil {
h.Difficulty = new(big.Int)
}
if err := s.ReadBigInt(h.Difficulty); err != nil {
return fmt.Errorf("read Difficulty: %w", err)
}
if h.Number == nil {
h.Number = new(big.Int)
}
if err := s.ReadBigInt(h.Number); err != nil {
return fmt.Errorf("read Number: %w", err)
}
var err error
if h.GasLimit, err = s.Uint64(); err != nil {
return fmt.Errorf("read GasLimit: %w", err)
}
if h.GasUsed, err = s.Uint64(); err != nil {
return fmt.Errorf("read GasUsed: %w", err)
}
if h.Time, err = s.Uint64(); err != nil {
return fmt.Errorf("read Time: %w", err)
}
if h.Extra, err = s.AppendBytes(h.Extra[:0]); err != nil {
return fmt.Errorf("read Extra: %w", err)
}
if err := s.ReadBytes(h.MixDigest[:]); err != nil {
return fmt.Errorf("read MixDigest: %w", err)
}
if err := s.ReadBytes(h.Nonce[:]); err != nil {
return fmt.Errorf("read Nonce: %w", err)
}
if !s.MoreDataInList() {
h.clearOptionalFieldsFrom(headerBaseFeeField)
return s.ListEnd()
}
if h.BaseFee == nil {
h.BaseFee = new(big.Int)
}
if err := s.ReadBigInt(h.BaseFee); err != nil {
return fmt.Errorf("read BaseFee: %w", err)
}
if !s.MoreDataInList() {
h.clearOptionalFieldsFrom(headerWithdrawalsHashField)
return s.ListEnd()
}
if h.WithdrawalsHash == nil {
h.WithdrawalsHash = new(common.Hash)
}
if err := s.ReadBytes(h.WithdrawalsHash[:]); err != nil {
return fmt.Errorf("read WithdrawalsHash: %w", err)
}
if !s.MoreDataInList() {
h.clearOptionalFieldsFrom(headerBlobGasUsedField)
return s.ListEnd()
}
if h.BlobGasUsed == nil {
h.BlobGasUsed = new(uint64)
}
if *h.BlobGasUsed, err = s.Uint64(); err != nil {
return fmt.Errorf("read BlobGasUsed: %w", err)
}
if !s.MoreDataInList() {
h.clearOptionalFieldsFrom(headerExcessBlobGasField)
return s.ListEnd()
}
if h.ExcessBlobGas == nil {
h.ExcessBlobGas = new(uint64)
}
if *h.ExcessBlobGas, err = s.Uint64(); err != nil {
return fmt.Errorf("read ExcessBlobGas: %w", err)
}
if !s.MoreDataInList() {
h.clearOptionalFieldsFrom(headerParentBeaconRootField)
return s.ListEnd()
}
if h.ParentBeaconRoot == nil {
h.ParentBeaconRoot = new(common.Hash)
}
if err := s.ReadBytes(h.ParentBeaconRoot[:]); err != nil {
return fmt.Errorf("read ParentBeaconRoot: %w", err)
}
if !s.MoreDataInList() {
h.clearOptionalFieldsFrom(headerRequestsHashField)
return s.ListEnd()
}
if h.RequestsHash == nil {
h.RequestsHash = new(common.Hash)
}
if err := s.ReadBytes(h.RequestsHash[:]); err != nil {
return fmt.Errorf("read RequestsHash: %w", err)
}
if !s.MoreDataInList() {
h.clearOptionalFieldsFrom(headerBlockAccessListHashField)
return s.ListEnd()
}
if h.BlockAccessListHash == nil {
h.BlockAccessListHash = new(common.Hash)
}
if err := s.ReadBytes(h.BlockAccessListHash[:]); err != nil {
return fmt.Errorf("read BlockAccessListHash: %w", err)
}
if !s.MoreDataInList() {
h.clearOptionalFieldsFrom(headerSlotNumberField)
return s.ListEnd()
}
if h.SlotNumber == nil {
h.SlotNumber = new(uint64)
}
if *h.SlotNumber, err = s.Uint64(); err != nil {
return fmt.Errorf("read SlotNumber: %w", err)
}
if err := s.ListEnd(); err != nil {
return fmt.Errorf("close header: %w", err)
}
return nil
}
const (
headerBaseFeeField = iota
headerWithdrawalsHashField
headerBlobGasUsedField
headerExcessBlobGasField
headerParentBeaconRootField
headerRequestsHashField
headerBlockAccessListHashField
headerSlotNumberField
)
func (h *Header) clearOptionalFieldsFrom(field int) {
if field <= headerBaseFeeField {
h.BaseFee = nil
}
if field <= headerWithdrawalsHashField {
h.WithdrawalsHash = nil
}
if field <= headerBlobGasUsedField {
h.BlobGasUsed = nil
}
if field <= headerExcessBlobGasField {
h.ExcessBlobGas = nil
}
if field <= headerParentBeaconRootField {
h.ParentBeaconRoot = nil
}
if field <= headerRequestsHashField {
h.RequestsHash = nil
}
if field <= headerBlockAccessListHashField {
h.BlockAccessListHash = nil
}
if field <= headerSlotNumberField {
h.SlotNumber = nil
}
}
var headerSize = common.StorageSize(reflect.TypeFor[Header]().Size())
// Size returns the approximate memory used by all internal contents. It is used

View file

@ -373,3 +373,91 @@ func TestRlpDecodeParentHash(t *testing.T) {
}
}
}
func TestDecodeHeaderReuseClearsOptionalFields(t *testing.T) {
full := &Header{
ParentHash: common.Hash{0x01},
UncleHash: EmptyUncleHash,
Coinbase: common.Address{0x02},
Root: common.Hash{0x03},
TxHash: EmptyTxsHash,
ReceiptHash: EmptyReceiptsHash,
Difficulty: big.NewInt(1),
Number: big.NewInt(2),
GasLimit: 30_000_000,
GasUsed: 21_000,
Time: 123,
Extra: []byte{0x04, 0x05},
MixDigest: common.Hash{0x06},
Nonce: BlockNonce{0x07},
BaseFee: big.NewInt(params.InitialBaseFee),
WithdrawalsHash: &common.Hash{0x08},
BlobGasUsed: new(uint64),
ExcessBlobGas: new(uint64),
ParentBeaconRoot: &common.Hash{0x09},
RequestsHash: &common.Hash{0x0a},
BlockAccessListHash: &common.Hash{0x0b},
SlotNumber: new(uint64),
}
*full.BlobGasUsed = 1
*full.ExcessBlobGas = 2
*full.SlotNumber = 3
legacy := *full
legacy.ParentHash = common.Hash{0x11}
legacy.Extra = []byte{0x12}
legacy.BaseFee = nil
legacy.WithdrawalsHash = nil
legacy.BlobGasUsed = nil
legacy.ExcessBlobGas = nil
legacy.ParentBeaconRoot = nil
legacy.RequestsHash = nil
legacy.BlockAccessListHash = nil
legacy.SlotNumber = nil
fullRLP, err := rlp.EncodeToBytes(full)
if err != nil {
t.Fatal(err)
}
legacyRLP, err := rlp.EncodeToBytes(&legacy)
if err != nil {
t.Fatal(err)
}
var decoded Header
if err := DecodeHeader(fullRLP, &decoded); err != nil {
t.Fatalf("decode full header: %v", err)
}
if decoded.BaseFee == nil || decoded.WithdrawalsHash == nil || decoded.BlobGasUsed == nil ||
decoded.ExcessBlobGas == nil || decoded.ParentBeaconRoot == nil || decoded.RequestsHash == nil ||
decoded.BlockAccessListHash == nil || decoded.SlotNumber == nil {
t.Fatalf("full header optional fields not decoded")
}
if err := DecodeHeader(legacyRLP, &decoded); err != nil {
t.Fatalf("decode legacy header: %v", err)
}
if decoded.BaseFee != nil || decoded.WithdrawalsHash != nil || decoded.BlobGasUsed != nil ||
decoded.ExcessBlobGas != nil || decoded.ParentBeaconRoot != nil || decoded.RequestsHash != nil ||
decoded.BlockAccessListHash != nil || decoded.SlotNumber != nil {
t.Fatalf("legacy decode retained optional fields: %#v", decoded)
}
if !bytes.Equal(decoded.Extra, legacy.Extra) {
t.Fatalf("extra mismatch: got %x want %x", decoded.Extra, legacy.Extra)
}
}
func TestDecodeHeaderRejectsTrailingData(t *testing.T) {
enc, err := rlp.EncodeToBytes(&Header{
UncleHash: EmptyUncleHash,
TxHash: EmptyTxsHash,
ReceiptHash: EmptyReceiptsHash,
Difficulty: big.NewInt(1),
Number: big.NewInt(1),
})
if err != nil {
t.Fatal(err)
}
var h Header
if err := DecodeHeader(append(enc, 0x80), &h); err != rlp.ErrMoreThanOneValue {
t.Fatalf("unexpected error: got %v want %v", err, rlp.ErrMoreThanOneValue)
}
}

View file

@ -177,6 +177,46 @@ func (b *testBackend) Handle(*Peer, Packet) error {
// Tests that block headers can be retrieved from a remote chain based on user queries.
func TestGetBlockHeaders69(t *testing.T) { testGetBlockHeaders(t, ETH69) }
func TestDecodeBlockHeadersKeepsDistinctHeaders(t *testing.T) {
headers := []*types.Header{
{
UncleHash: types.EmptyUncleHash,
TxHash: types.EmptyTxsHash,
ReceiptHash: types.EmptyReceiptsHash,
Difficulty: big.NewInt(1),
Number: big.NewInt(1),
Extra: []byte{0x01},
},
{
UncleHash: types.EmptyUncleHash,
TxHash: types.EmptyTxsHash,
ReceiptHash: types.EmptyReceiptsHash,
Difficulty: big.NewInt(2),
Number: big.NewInt(2),
Extra: []byte{0x02},
},
}
list := encodeRL(headers)
decoded, err := decodeBlockHeaders(&list)
if err != nil {
t.Fatal(err)
}
if len(decoded) != len(headers) {
t.Fatalf("decoded header count mismatch: got %d want %d", len(decoded), len(headers))
}
if decoded[0] == decoded[1] {
t.Fatal("decoded headers alias the same object")
}
for i := range headers {
if decoded[i].Hash() != headers[i].Hash() {
t.Fatalf("header %d hash mismatch: got %x want %x", i, decoded[i].Hash(), headers[i].Hash())
}
if !bytes.Equal(decoded[i].Extra, headers[i].Extra) {
t.Fatalf("header %d extra mismatch: got %x want %x", i, decoded[i].Extra, headers[i].Extra)
}
}
}
func testGetBlockHeaders(t *testing.T, protocol uint) {
t.Parallel()

View file

@ -21,6 +21,7 @@ import (
"encoding/json"
"fmt"
"math"
"math/big"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core"
@ -355,7 +356,7 @@ func handleBlockHeaders(backend Backend, msg Decoder, peer *Peer) error {
if err := peer.tracker.Fulfil(tresp); err != nil {
return fmt.Errorf("BlockHeaders: %w", err)
}
headers, err := res.List.Items()
headers, err := decodeBlockHeaders(&res.List)
if err != nil {
return fmt.Errorf("BlockHeaders: %w", err)
}
@ -374,6 +375,35 @@ func handleBlockHeaders(backend Backend, msg Decoder, peer *Peer) error {
}, metadata)
}
type headerAlloc struct {
header types.Header
difficulty big.Int
number big.Int
}
func decodeBlockHeaders(list *rlp.RawList[*types.Header]) ([]*types.Header, error) {
headers := make([]*types.Header, 0, list.Len())
stream := rlp.NewBytesStream(nil)
defer rlp.PutStream(stream)
it := list.ContentIterator()
for it.Next() {
a := new(headerAlloc)
a.header.Difficulty = &a.difficulty
a.header.Number = &a.number
stream.ResetBytes(it.Value())
if err := a.header.DecodeRLP(stream); err != nil {
return headers, err
}
if stream.Remaining() != 0 {
return headers, rlp.ErrMoreThanOneValue
}
headers = append(headers, &a.header)
}
return headers, nil
}
func handleBlockBodies(backend Backend, msg Decoder, peer *Peer) error {
// A batch of block bodies arrived to one of our previous requests
res := new(BlockBodiesPacket)

View file

@ -586,6 +586,9 @@ type ByteReader interface {
type Stream struct {
r ByteReader
// Inline storage for NewBytesStream so the slice header doesn't escape per call.
sliceRdr sliceReader
remaining uint64 // number of bytes remaining to be read from r
size uint64 // size of value ahead
kinderr error // error from last readKind
@ -596,6 +599,11 @@ type Stream struct {
limited bool // true if input limit is in effect
}
// Remaining returns the number of bytes remaining to be read.
func (s *Stream) Remaining() uint64 {
return s.remaining
}
// NewStream creates a new decoding stream reading from r.
//
// If r implements the ByteReader interface, Stream will
@ -619,6 +627,29 @@ func NewStream(r io.Reader, inputLimit uint64) *Stream {
return s
}
// NewBytesStream returns a pooled Stream reading from b. The caller must return
// the stream with PutStream when decoding is done.
func NewBytesStream(b []byte) *Stream {
stream := streamPool.Get().(*Stream)
stream.sliceRdr = b
stream.Reset(&stream.sliceRdr, uint64(len(b)))
return stream
}
// PutStream returns a Stream obtained from NewBytesStream to the pool.
func PutStream(stream *Stream) {
stream.sliceRdr = nil // release caller's backing array
streamPool.Put(stream)
}
// ResetBytes resets the stream to decode from b, reusing the inline slice
// reader. This allows decoding a batch of values with one pooled stream
// instead of doing a pool round-trip per value.
func (s *Stream) ResetBytes(b []byte) {
s.sliceRdr = b
s.Reset(&s.sliceRdr, uint64(len(b)))
}
// NewListStream creates a new stream that pretends to be positioned
// at an encoded list of the given length.
func NewListStream(r io.Reader, len uint64) *Stream {
@ -686,6 +717,42 @@ func (s *Stream) ReadBytes(b []byte) error {
}
}
// AppendBytes decodes the next RLP string and appends its contents to dst,
// returning the extended slice. Pass dst[:0] to reuse an existing buffer.
func (s *Stream) AppendBytes(dst []byte) ([]byte, error) {
kind, size, err := s.Kind()
if err != nil {
return dst, err
}
switch kind {
case Byte:
s.kind = -1
return append(dst, s.byteval), nil
case String:
cur := len(dst)
if size > uint64(int(^uint(0)>>1)-cur) {
return dst, ErrValueTooLarge
}
need := cur + int(size)
if cap(dst) < need {
grown := make([]byte, need)
copy(grown, dst)
dst = grown
} else {
dst = dst[:need]
}
if err = s.readFull(dst[cur:]); err != nil {
return dst, err
}
if size == 1 && dst[cur] < 128 {
return dst, ErrCanonSize
}
return dst, nil
default:
return dst, ErrExpectedString
}
}
// Raw reads a raw encoded value including RLP type information.
func (s *Stream) Raw() ([]byte, error) {
kind, size, err := s.Kind()
@ -845,6 +912,11 @@ func (s *Stream) BigInt() (*big.Int, error) {
return i, nil
}
// ReadBigInt decodes the next value as a big integer into dst.
func (s *Stream) ReadBigInt(dst *big.Int) error {
return s.decodeBigInt(dst)
}
func (s *Stream) decodeBigInt(dst *big.Int) error {
var buffer []byte
kind, size, err := s.Kind()