From fc19b631052947bd8bd5908ed6881cf5a6777331 Mon Sep 17 00:00:00 2001
From: Sahil Sojitra
Date: Fri, 22 May 2026 16:15:31 +0530
Subject: [PATCH] core/types, eth: reduce header RLP decode allocations

---
 core/types/block.go               | 216 ++++++++++++++++++++++++++++++
 core/types/block_test.go          |  88 +++++++++++++
 eth/protocols/eth/handler_test.go |  40 ++++++
 eth/protocols/eth/handlers.go     |  15 ++-
 rlp/decode.go                     |  64 ++++++++++
 5 files changed, 422 insertions(+), 1 deletion(-)

diff --git a/core/types/block.go b/core/types/block.go
index eab458e88a..9a1d79df4f 100644
--- a/core/types/block.go
+++ b/core/types/block.go
@@ -128,6 +128,222 @@ func (h *Header) Hash() common.Hash {
 	return rlpHash(h)
 }
 
+// DecodeHeader parses exactly one RLP-encoded header out of b and stores it in
+// dst. It bypasses the generic reflection path used by rlp.DecodeBytes and,
+// like that function, rejects input that continues past the first value.
+func DecodeHeader(b []byte, dst *Header) error {
+	s := rlp.NewBytesStream(b)
+	defer rlp.PutStream(s)
+
+	err := dst.DecodeRLP(s)
+	switch {
+	case err != nil:
+		return err
+	case s.Remaining() != 0:
+		return rlp.ErrMoreThanOneValue
+	}
+	return nil
+}
+
+// DecodeRLP reads a block header from s one field at a time, in encoding
+// order. The fifteen fields up to and including Nonce are mandatory. Every
+// later field is optional: when the list ends early, that field and all
+// optional fields after it are set to nil. Pointer fields that are already
+// non-nil and the Extra buffer are decoded into rather than reallocated.
+func (h *Header) DecodeRLP(s *rlp.Stream) error {
+	var err error
+	if _, err = s.List(); err != nil {
+		return err
+	}
+	// stop ends decoding at the first absent optional field.
+	stop := func(first int) error {
+		h.clearOptionalFieldsFrom(first)
+		return s.ListEnd()
+	}
+
+	// Mandatory fields.
+	if err = s.ReadBytes(h.ParentHash[:]); err != nil {
+		return fmt.Errorf("read ParentHash: %w", err)
+	}
+	if err = s.ReadBytes(h.UncleHash[:]); err != nil {
+		return fmt.Errorf("read UncleHash: %w", err)
+	}
+	if err = s.ReadBytes(h.Coinbase[:]); err != nil {
+		return fmt.Errorf("read Coinbase: %w", err)
+	}
+	if err = s.ReadBytes(h.Root[:]); err != nil {
+		return fmt.Errorf("read Root: %w", err)
+	}
+	if err = s.ReadBytes(h.TxHash[:]); err != nil {
+		return fmt.Errorf("read TxHash: %w", err)
+	}
+	if err = s.ReadBytes(h.ReceiptHash[:]); err != nil {
+		return fmt.Errorf("read ReceiptHash: %w", err)
+	}
+	if err = s.ReadBytes(h.Bloom[:]); err != nil {
+		return fmt.Errorf("read Bloom: %w", err)
+	}
+	if h.Difficulty == nil {
+		h.Difficulty = new(big.Int)
+	}
+	if err = s.ReadBigInt(h.Difficulty); err != nil {
+		return fmt.Errorf("read Difficulty: %w", err)
+	}
+	if h.Number == nil {
+		h.Number = new(big.Int)
+	}
+	if err = s.ReadBigInt(h.Number); err != nil {
+		return fmt.Errorf("read Number: %w", err)
+	}
+	if h.GasLimit, err = s.Uint64(); err != nil {
+		return fmt.Errorf("read GasLimit: %w", err)
+	}
+	if h.GasUsed, err = s.Uint64(); err != nil {
+		return fmt.Errorf("read GasUsed: %w", err)
+	}
+	if h.Time, err = s.Uint64(); err != nil {
+		return fmt.Errorf("read Time: %w", err)
+	}
+	if h.Extra, err = s.AppendBytes(h.Extra[:0]); err != nil {
+		return fmt.Errorf("read Extra: %w", err)
+	}
+	if err = s.ReadBytes(h.MixDigest[:]); err != nil {
+		return fmt.Errorf("read MixDigest: %w", err)
+	}
+	if err = s.ReadBytes(h.Nonce[:]); err != nil {
+		return fmt.Errorf("read Nonce: %w", err)
+	}
+
+	// Optional fields, in the order of the header*Field constants.
+	if !s.MoreDataInList() {
+		return stop(headerBaseFeeField)
+	}
+	if h.BaseFee == nil {
+		h.BaseFee = new(big.Int)
+	}
+	if err = s.ReadBigInt(h.BaseFee); err != nil {
+		return fmt.Errorf("read BaseFee: %w", err)
+	}
+
+	if !s.MoreDataInList() {
+		return stop(headerWithdrawalsHashField)
+	}
+	if h.WithdrawalsHash == nil {
+		h.WithdrawalsHash = new(common.Hash)
+	}
+	if err = s.ReadBytes(h.WithdrawalsHash[:]); err != nil {
+		return fmt.Errorf("read WithdrawalsHash: %w", err)
+	}
+
+	if !s.MoreDataInList() {
+		return stop(headerBlobGasUsedField)
+	}
+	if h.BlobGasUsed == nil {
+		h.BlobGasUsed = new(uint64)
+	}
+	if *h.BlobGasUsed, err = s.Uint64(); err != nil {
+		return fmt.Errorf("read BlobGasUsed: %w", err)
+	}
+
+	if !s.MoreDataInList() {
+		return stop(headerExcessBlobGasField)
+	}
+	if h.ExcessBlobGas == nil {
+		h.ExcessBlobGas = new(uint64)
+	}
+	if *h.ExcessBlobGas, err = s.Uint64(); err != nil {
+		return fmt.Errorf("read ExcessBlobGas: %w", err)
+	}
+
+	if !s.MoreDataInList() {
+		return stop(headerParentBeaconRootField)
+	}
+	if h.ParentBeaconRoot == nil {
+		h.ParentBeaconRoot = new(common.Hash)
+	}
+	if err = s.ReadBytes(h.ParentBeaconRoot[:]); err != nil {
+		return fmt.Errorf("read ParentBeaconRoot: %w", err)
+	}
+
+	if !s.MoreDataInList() {
+		return stop(headerRequestsHashField)
+	}
+	if h.RequestsHash == nil {
+		h.RequestsHash = new(common.Hash)
+	}
+	if err = s.ReadBytes(h.RequestsHash[:]); err != nil {
+		return fmt.Errorf("read RequestsHash: %w", err)
+	}
+
+	if !s.MoreDataInList() {
+		return stop(headerBlockAccessListHashField)
+	}
+	if h.BlockAccessListHash == nil {
+		h.BlockAccessListHash = new(common.Hash)
+	}
+	if err = s.ReadBytes(h.BlockAccessListHash[:]); err != nil {
+		return fmt.Errorf("read BlockAccessListHash: %w", err)
+	}
+
+	if !s.MoreDataInList() {
+		return stop(headerSlotNumberField)
+	}
+	if h.SlotNumber == nil {
+		h.SlotNumber = new(uint64)
+	}
+	if *h.SlotNumber, err = s.Uint64(); err != nil {
+		return fmt.Errorf("read SlotNumber: %w", err)
+	}
+	if err = s.ListEnd(); err != nil {
+		return fmt.Errorf("close header: %w", err)
+	}
+	return nil
+}
+
+// Positions of the optional header fields, in encoding order. They are the
+// arguments accepted by clearOptionalFieldsFrom.
+const (
+	headerBaseFeeField = iota
+	headerWithdrawalsHashField
+	headerBlobGasUsedField
+	headerExcessBlobGasField
+	headerParentBeaconRootField
+	headerRequestsHashField
+	headerBlockAccessListHashField
+	headerSlotNumberField
+)
+
+// clearOptionalFieldsFrom sets the optional field at position field, and every
+// optional field after it, to nil. Fields before that position are untouched.
+func (h *Header) clearOptionalFieldsFrom(field int) {
+	// Each case falls through so that all later fields are cleared as well.
+	switch {
+	case field <= headerBaseFeeField:
+		h.BaseFee = nil
+		fallthrough
+	case field <= headerWithdrawalsHashField:
+		h.WithdrawalsHash = nil
+		fallthrough
+	case field <= headerBlobGasUsedField:
+		h.BlobGasUsed = nil
+		fallthrough
+	case field <= headerExcessBlobGasField:
+		h.ExcessBlobGas = nil
+		fallthrough
+	case field <= headerParentBeaconRootField:
+		h.ParentBeaconRoot = nil
+		fallthrough
+	case field <= headerRequestsHashField:
+		h.RequestsHash = nil
+		fallthrough
+	case field <= headerBlockAccessListHashField:
+		h.BlockAccessListHash = nil
+		fallthrough
+	case field <= headerSlotNumberField:
+		h.SlotNumber = nil
+	}
+}
+
 var headerSize = common.StorageSize(reflect.TypeFor[Header]().Size())
 
 // Size returns the approximate memory used by all internal contents.
It is used
diff --git a/core/types/block_test.go b/core/types/block_test.go
index 5fa4756a50..f288acd980 100644
--- a/core/types/block_test.go
+++ b/core/types/block_test.go
@@ -373,3 +373,101 @@ func TestRlpDecodeParentHash(t *testing.T) {
 		}
 	}
 }
+
+// TestDecodeHeaderReuseClearsOptionalFields decodes a header carrying every
+// optional field and then decodes a header without any of them into the same
+// Header value. The second decode must leave no state from the first behind.
+func TestDecodeHeaderReuseClearsOptionalFields(t *testing.T) {
+	full := &Header{
+		ParentHash:          common.Hash{0x01},
+		UncleHash:           EmptyUncleHash,
+		Coinbase:            common.Address{0x02},
+		Root:                common.Hash{0x03},
+		TxHash:              EmptyTxsHash,
+		ReceiptHash:         EmptyReceiptsHash,
+		Difficulty:          big.NewInt(1),
+		Number:              big.NewInt(2),
+		GasLimit:            30_000_000,
+		GasUsed:             21_000,
+		Time:                123,
+		Extra:               []byte{0x04, 0x05},
+		MixDigest:           common.Hash{0x06},
+		Nonce:               BlockNonce{0x07},
+		BaseFee:             big.NewInt(params.InitialBaseFee),
+		WithdrawalsHash:     &common.Hash{0x08},
+		BlobGasUsed:         new(uint64),
+		ExcessBlobGas:       new(uint64),
+		ParentBeaconRoot:    &common.Hash{0x09},
+		RequestsHash:        &common.Hash{0x0a},
+		BlockAccessListHash: &common.Hash{0x0b},
+		SlotNumber:          new(uint64),
+	}
+	*full.BlobGasUsed = 1
+	*full.ExcessBlobGas = 2
+	*full.SlotNumber = 3
+
+	legacy := *full
+	legacy.ParentHash = common.Hash{0x11}
+	legacy.Extra = []byte{0x12}
+	legacy.BaseFee = nil
+	legacy.WithdrawalsHash = nil
+	legacy.BlobGasUsed = nil
+	legacy.ExcessBlobGas = nil
+	legacy.ParentBeaconRoot = nil
+	legacy.RequestsHash = nil
+	legacy.BlockAccessListHash = nil
+	legacy.SlotNumber = nil
+
+	fullRLP, err := rlp.EncodeToBytes(full)
+	if err != nil {
+		t.Fatal(err)
+	}
+	legacyRLP, err := rlp.EncodeToBytes(&legacy)
+	if err != nil {
+		t.Fatal(err)
+	}
+	var decoded Header
+	if err := DecodeHeader(fullRLP, &decoded); err != nil {
+		t.Fatalf("decode full header: %v", err)
+	}
+	if decoded.BaseFee == nil || decoded.WithdrawalsHash == nil || decoded.BlobGasUsed == nil ||
+		decoded.ExcessBlobGas == nil || decoded.ParentBeaconRoot == nil || decoded.RequestsHash == nil ||
+		decoded.BlockAccessListHash == nil || decoded.SlotNumber == nil {
+		t.Fatalf("full header optional fields not decoded")
+	}
+	if err := DecodeHeader(legacyRLP, &decoded); err != nil {
+		t.Fatalf("decode legacy header: %v", err)
+	}
+	if decoded.BaseFee != nil || decoded.WithdrawalsHash != nil || decoded.BlobGasUsed != nil ||
+		decoded.ExcessBlobGas != nil || decoded.ParentBeaconRoot != nil || decoded.RequestsHash != nil ||
+		decoded.BlockAccessListHash != nil || decoded.SlotNumber != nil {
+		t.Fatalf("legacy decode retained optional fields: %#v", decoded)
+	}
+	if !bytes.Equal(decoded.Extra, legacy.Extra) {
+		t.Fatalf("extra mismatch: got %x want %x", decoded.Extra, legacy.Extra)
+	}
+	// The hash is computed over the header's RLP encoding, so this also catches
+	// mandatory fields (such as ParentHash) left over from the first decode.
+	if got, want := decoded.Hash(), legacy.Hash(); got != want {
+		t.Fatalf("legacy header hash mismatch: got %x want %x", got, want)
+	}
+}
+
+// TestDecodeHeaderRejectsTrailingData checks that DecodeHeader reports
+// rlp.ErrMoreThanOneValue when extra bytes follow the encoded header.
+func TestDecodeHeaderRejectsTrailingData(t *testing.T) {
+	enc, err := rlp.EncodeToBytes(&Header{
+		UncleHash:   EmptyUncleHash,
+		TxHash:      EmptyTxsHash,
+		ReceiptHash: EmptyReceiptsHash,
+		Difficulty:  big.NewInt(1),
+		Number:      big.NewInt(1),
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+	var h Header
+	if err := DecodeHeader(append(enc, 0x80), &h); err != rlp.ErrMoreThanOneValue {
+		t.Fatalf("unexpected error: got %v want %v", err, rlp.ErrMoreThanOneValue)
+	}
+}
diff --git a/eth/protocols/eth/handler_test.go b/eth/protocols/eth/handler_test.go
index 3f40fdb3b3..123d0f6113 100644
--- a/eth/protocols/eth/handler_test.go
+++ b/eth/protocols/eth/handler_test.go
@@ -177,6 +177,46 @@ func (b *testBackend) Handle(*Peer, Packet) error {
 // Tests that block headers can be retrieved from a remote chain based on user queries.
func TestGetBlockHeaders69(t *testing.T) { testGetBlockHeaders(t, ETH69) } +func TestDecodeBlockHeadersKeepsDistinctHeaders(t *testing.T) { + headers := []*types.Header{ + { + UncleHash: types.EmptyUncleHash, + TxHash: types.EmptyTxsHash, + ReceiptHash: types.EmptyReceiptsHash, + Difficulty: big.NewInt(1), + Number: big.NewInt(1), + Extra: []byte{0x01}, + }, + { + UncleHash: types.EmptyUncleHash, + TxHash: types.EmptyTxsHash, + ReceiptHash: types.EmptyReceiptsHash, + Difficulty: big.NewInt(2), + Number: big.NewInt(2), + Extra: []byte{0x02}, + }, + } + list := encodeRL(headers) + decoded, err := decodeBlockHeaders(&list) + if err != nil { + t.Fatal(err) + } + if len(decoded) != len(headers) { + t.Fatalf("decoded header count mismatch: got %d want %d", len(decoded), len(headers)) + } + if decoded[0] == decoded[1] { + t.Fatal("decoded headers alias the same object") + } + for i := range headers { + if decoded[i].Hash() != headers[i].Hash() { + t.Fatalf("header %d hash mismatch: got %x want %x", i, decoded[i].Hash(), headers[i].Hash()) + } + if !bytes.Equal(decoded[i].Extra, headers[i].Extra) { + t.Fatalf("header %d extra mismatch: got %x want %x", i, decoded[i].Extra, headers[i].Extra) + } + } +} + func testGetBlockHeaders(t *testing.T, protocol uint) { t.Parallel() diff --git a/eth/protocols/eth/handlers.go b/eth/protocols/eth/handlers.go index 71942cc9ad..742cdd9ec3 100644 --- a/eth/protocols/eth/handlers.go +++ b/eth/protocols/eth/handlers.go @@ -355,7 +355,7 @@ func handleBlockHeaders(backend Backend, msg Decoder, peer *Peer) error { if err := peer.tracker.Fulfil(tresp); err != nil { return fmt.Errorf("BlockHeaders: %w", err) } - headers, err := res.List.Items() + headers, err := decodeBlockHeaders(&res.List) if err != nil { return fmt.Errorf("BlockHeaders: %w", err) } @@ -374,6 +374,19 @@ func handleBlockHeaders(backend Backend, msg Decoder, peer *Peer) error { }, metadata) } +func decodeBlockHeaders(list *rlp.RawList[*types.Header]) ([]*types.Header, error) { + 
headers := make([]*types.Header, 0, list.Len()) + it := list.ContentIterator() + for it.Next() { + header := new(types.Header) + if err := types.DecodeHeader(it.Value(), header); err != nil { + return headers, err + } + headers = append(headers, header) + } + return headers, nil +} + func handleBlockBodies(backend Backend, msg Decoder, peer *Peer) error { // A batch of block bodies arrived to one of our previous requests res := new(BlockBodiesPacket) diff --git a/rlp/decode.go b/rlp/decode.go index 19074072fb..93a8250155 100644 --- a/rlp/decode.go +++ b/rlp/decode.go @@ -586,6 +586,9 @@ type ByteReader interface { type Stream struct { r ByteReader + // Inline storage for NewBytesStream so the slice header doesn't escape per call. + sliceRdr sliceReader + remaining uint64 // number of bytes remaining to be read from r size uint64 // size of value ahead kinderr error // error from last readKind @@ -596,6 +599,11 @@ type Stream struct { limited bool // true if input limit is in effect } +// Remaining returns the number of bytes remaining to be read. +func (s *Stream) Remaining() uint64 { + return s.remaining +} + // NewStream creates a new decoding stream reading from r. // // If r implements the ByteReader interface, Stream will @@ -619,6 +627,21 @@ func NewStream(r io.Reader, inputLimit uint64) *Stream { return s } +// NewBytesStream returns a pooled Stream reading from b. The caller must return +// the stream with PutStream when decoding is done. +func NewBytesStream(b []byte) *Stream { + stream := streamPool.Get().(*Stream) + stream.sliceRdr = b + stream.Reset(&stream.sliceRdr, uint64(len(b))) + return stream +} + +// PutStream returns a Stream obtained from NewBytesStream to the pool. +func PutStream(stream *Stream) { + stream.sliceRdr = nil // release caller's backing array + streamPool.Put(stream) +} + // NewListStream creates a new stream that pretends to be positioned // at an encoded list of the given length. 
func NewListStream(r io.Reader, len uint64) *Stream { @@ -686,6 +709,42 @@ func (s *Stream) ReadBytes(b []byte) error { } } +// AppendBytes decodes the next RLP string and appends its contents to dst, +// returning the extended slice. Pass dst[:0] to reuse an existing buffer. +func (s *Stream) AppendBytes(dst []byte) ([]byte, error) { + kind, size, err := s.Kind() + if err != nil { + return dst, err + } + switch kind { + case Byte: + s.kind = -1 + return append(dst, s.byteval), nil + case String: + cur := len(dst) + if size > uint64(int(^uint(0)>>1)-cur) { + return dst, ErrValueTooLarge + } + need := cur + int(size) + if cap(dst) < need { + grown := make([]byte, need) + copy(grown, dst) + dst = grown + } else { + dst = dst[:need] + } + if err = s.readFull(dst[cur:]); err != nil { + return dst, err + } + if size == 1 && dst[cur] < 128 { + return dst, ErrCanonSize + } + return dst, nil + default: + return dst, ErrExpectedString + } +} + // Raw reads a raw encoded value including RLP type information. func (s *Stream) Raw() ([]byte, error) { kind, size, err := s.Kind() @@ -845,6 +904,11 @@ func (s *Stream) BigInt() (*big.Int, error) { return i, nil } +// ReadBigInt decodes the next value as a big integer into dst. +func (s *Stream) ReadBigInt(dst *big.Int) error { + return s.decodeBigInt(dst) +} + func (s *Stream) decodeBigInt(dst *big.Int) error { var buffer []byte kind, size, err := s.Kind()