diff --git a/cmd/devp2p/internal/ethtest/protocol.go b/cmd/devp2p/internal/ethtest/protocol.go
index af76082318..a21d1ca7a1 100644
--- a/cmd/devp2p/internal/ethtest/protocol.go
+++ b/cmd/devp2p/internal/ethtest/protocol.go
@@ -86,9 +86,3 @@ func protoOffset(proto Proto) uint64 {
panic("unhandled protocol")
}
}
-
-// msgTypePtr is the constraint for protocol message types.
-type msgTypePtr[U any] interface {
- *U
- Kind() byte
-}
diff --git a/cmd/devp2p/internal/ethtest/snap.go b/cmd/devp2p/internal/ethtest/snap.go
index f4fce0931f..7c1ca70cc0 100644
--- a/cmd/devp2p/internal/ethtest/snap.go
+++ b/cmd/devp2p/internal/ethtest/snap.go
@@ -30,6 +30,7 @@ import (
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/eth/protocols/snap"
"github.com/ethereum/go-ethereum/internal/utesting"
+ "github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie"
"github.com/ethereum/go-ethereum/trie/trienode"
)
@@ -937,10 +938,14 @@ func (s *Suite) snapGetTrieNodes(t *utesting.T, tc *trieNodesTest) error {
}
// write0 request
+ paths, err := rlp.EncodeToRawList(tc.paths)
+ if err != nil {
+ panic(err)
+ }
req := &snap.GetTrieNodesPacket{
ID: uint64(rand.Int63()),
Root: tc.root,
- Paths: tc.paths,
+ Paths: paths,
Bytes: tc.nBytes,
}
msg, err := conn.snapRequest(snap.GetTrieNodesMsg, req)
diff --git a/cmd/devp2p/internal/ethtest/suite.go b/cmd/devp2p/internal/ethtest/suite.go
index c23360bf82..8bb488e358 100644
--- a/cmd/devp2p/internal/ethtest/suite.go
+++ b/cmd/devp2p/internal/ethtest/suite.go
@@ -34,6 +34,7 @@ import (
"github.com/ethereum/go-ethereum/internal/utesting"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/enode"
+ "github.com/ethereum/go-ethereum/rlp"
"github.com/holiman/uint256"
)
@@ -151,7 +152,11 @@ func (s *Suite) TestGetBlockHeaders(t *utesting.T) {
if err != nil {
t.Fatalf("failed to get headers for given request: %v", err)
}
- if !headersMatch(expected, headers.BlockHeadersRequest) {
+ received, err := headers.List.Items()
+ if err != nil {
+ t.Fatalf("invalid headers received: %v", err)
+ }
+ if !headersMatch(expected, received) {
t.Fatalf("header mismatch: \nexpected %v \ngot %v", expected, headers)
}
}
@@ -237,7 +242,7 @@ concurrently, with different request IDs.`)
// Wait for responses.
// Note they can arrive in either order.
- resp, err := collectResponses(conn, 2, func(msg *eth.BlockHeadersPacket) uint64 {
+ resp, err := collectHeaderResponses(conn, 2, func(msg *eth.BlockHeadersPacket) uint64 {
if msg.RequestId != 111 && msg.RequestId != 222 {
t.Fatalf("response with unknown request ID: %v", msg.RequestId)
}
@@ -248,17 +253,11 @@ concurrently, with different request IDs.`)
}
// Check if headers match.
- resp1 := resp[111]
- if expected, err := s.chain.GetHeaders(req1); err != nil {
- t.Fatalf("failed to get expected headers for request 1: %v", err)
- } else if !headersMatch(expected, resp1.BlockHeadersRequest) {
- t.Fatalf("header mismatch for request ID %v: \nexpected %v \ngot %v", 111, expected, resp1)
+ if err := s.checkHeadersAgainstChain(req1, resp[111]); err != nil {
+ t.Fatal(err)
}
- resp2 := resp[222]
- if expected, err := s.chain.GetHeaders(req2); err != nil {
- t.Fatalf("failed to get expected headers for request 2: %v", err)
- } else if !headersMatch(expected, resp2.BlockHeadersRequest) {
- t.Fatalf("header mismatch for request ID %v: \nexpected %v \ngot %v", 222, expected, resp2)
+ if err := s.checkHeadersAgainstChain(req2, resp[222]); err != nil {
+ t.Fatal(err)
}
}
@@ -303,8 +302,8 @@ same request ID. The node should handle the request by responding to both reques
// Wait for the responses. They can arrive in either order, and we can't tell them
// apart by their request ID, so use the number of headers instead.
- resp, err := collectResponses(conn, 2, func(msg *eth.BlockHeadersPacket) uint64 {
- id := uint64(len(msg.BlockHeadersRequest))
+ resp, err := collectHeaderResponses(conn, 2, func(msg *eth.BlockHeadersPacket) uint64 {
+ id := uint64(msg.List.Len())
if id != 2 && id != 3 {
t.Fatalf("invalid number of headers in response: %d", id)
}
@@ -315,26 +314,35 @@ same request ID. The node should handle the request by responding to both reques
}
// Check if headers match.
- resp1 := resp[2]
- if expected, err := s.chain.GetHeaders(request1); err != nil {
- t.Fatalf("failed to get expected headers for request 1: %v", err)
- } else if !headersMatch(expected, resp1.BlockHeadersRequest) {
- t.Fatalf("headers mismatch: \nexpected %v \ngot %v", expected, resp1)
+ if err := s.checkHeadersAgainstChain(request1, resp[2]); err != nil {
+ t.Fatal(err)
}
- resp2 := resp[3]
- if expected, err := s.chain.GetHeaders(request2); err != nil {
- t.Fatalf("failed to get expected headers for request 2: %v", err)
- } else if !headersMatch(expected, resp2.BlockHeadersRequest) {
- t.Fatalf("headers mismatch: \nexpected %v \ngot %v", expected, resp2)
+ if err := s.checkHeadersAgainstChain(request2, resp[3]); err != nil {
+ t.Fatal(err)
}
}
+func (s *Suite) checkHeadersAgainstChain(req *eth.GetBlockHeadersPacket, resp *eth.BlockHeadersPacket) error {
+ received2, err := resp.List.Items()
+ if err != nil {
+ return fmt.Errorf("invalid headers in response with request ID %v (%d items): %v", resp.RequestId, resp.List.Len(), err)
+ }
+ if expected, err := s.chain.GetHeaders(req); err != nil {
+ return fmt.Errorf("test chain failed to get expected headers for request: %v", err)
+ } else if !headersMatch(expected, received2) {
+ return fmt.Errorf("header mismatch for request ID %v (%d items): \nexpected %v \ngot %v", resp.RequestId, resp.List.Len(), expected, resp)
+ }
+ return nil
+}
+
// collectResponses waits for n messages of type T on the given connection.
// The messsages are collected according to the 'identity' function.
-func collectResponses[T any, P msgTypePtr[T]](conn *Conn, n int, identity func(P) uint64) (map[uint64]P, error) {
- resp := make(map[uint64]P, n)
+//
+// This function is written in a generic way to handle
+func collectHeaderResponses(conn *Conn, n int, identity func(*eth.BlockHeadersPacket) uint64) (map[uint64]*eth.BlockHeadersPacket, error) {
+ resp := make(map[uint64]*eth.BlockHeadersPacket, n)
for range n {
- r := new(T)
+ r := new(eth.BlockHeadersPacket)
if err := conn.ReadMsg(ethProto, eth.BlockHeadersMsg, r); err != nil {
return resp, fmt.Errorf("read error: %v", err)
}
@@ -373,10 +381,8 @@ and expects a response.`)
if got, want := headers.RequestId, req.RequestId; got != want {
t.Fatalf("unexpected request id")
}
- if expected, err := s.chain.GetHeaders(req); err != nil {
- t.Fatalf("failed to get expected block headers: %v", err)
- } else if !headersMatch(expected, headers.BlockHeadersRequest) {
- t.Fatalf("header mismatch: \nexpected %v \ngot %v", expected, headers)
+ if err := s.checkHeadersAgainstChain(req, headers); err != nil {
+ t.Fatal(err)
}
}
@@ -407,9 +413,8 @@ func (s *Suite) TestGetBlockBodies(t *utesting.T) {
if got, want := resp.RequestId, req.RequestId; got != want {
t.Fatalf("unexpected request id in respond", got, want)
}
- bodies := resp.BlockBodiesResponse
- if len(bodies) != len(req.GetBlockBodiesRequest) {
- t.Fatalf("wrong bodies in response: expected %d bodies, got %d", len(req.GetBlockBodiesRequest), len(bodies))
+ if resp.List.Len() != len(req.GetBlockBodiesRequest) {
+ t.Fatalf("wrong bodies in response: expected %d bodies, got %d", len(req.GetBlockBodiesRequest), resp.List.Len())
}
}
@@ -433,7 +438,7 @@ func (s *Suite) TestGetReceipts(t *utesting.T) {
}
}
- // Create block bodies request.
+ // Create receipts request.
req := ð.GetReceiptsPacket{
RequestId: 66,
GetReceiptsRequest: (eth.GetReceiptsRequest)(hashes),
@@ -449,8 +454,8 @@ func (s *Suite) TestGetReceipts(t *utesting.T) {
if got, want := resp.RequestId, req.RequestId; got != want {
t.Fatalf("unexpected request id in respond", got, want)
}
- if len(resp.List) != len(req.GetReceiptsRequest) {
- t.Fatalf("wrong bodies in response: expected %d bodies, got %d", len(req.GetReceiptsRequest), len(resp.List))
+ if resp.List.Len() != len(req.GetReceiptsRequest) {
+ t.Fatalf("wrong receipts in response: expected %d receipts, got %d", len(req.GetReceiptsRequest), resp.List.Len())
}
}
@@ -804,7 +809,11 @@ on another peer connection using GetPooledTransactions.`)
if got, want := msg.RequestId, req.RequestId; got != want {
t.Fatalf("unexpected request id in response: got %d, want %d", got, want)
}
- for _, got := range msg.PooledTransactionsResponse {
+ responseTxs, err := msg.List.Items()
+ if err != nil {
+ t.Fatalf("invalid transactions in response: %v", err)
+ }
+ for _, got := range responseTxs {
if _, exists := set[got.Hash()]; !exists {
t.Fatalf("unexpected tx received: %v", got.Hash())
}
@@ -976,7 +985,9 @@ func (s *Suite) TestBlobViolations(t *utesting.T) {
if err := conn.ReadMsg(ethProto, eth.GetPooledTransactionsMsg, req); err != nil {
t.Fatalf("reading pooled tx request failed: %v", err)
}
- resp := eth.PooledTransactionsPacket{RequestId: req.RequestId, PooledTransactionsResponse: test.resp}
+
+ encTxs, _ := rlp.EncodeToRawList(test.resp)
+ resp := eth.PooledTransactionsPacket{RequestId: req.RequestId, List: encTxs}
if err := conn.Write(ethProto, eth.PooledTransactionsMsg, resp); err != nil {
t.Fatalf("writing pooled tx response failed: %v", err)
}
@@ -1104,7 +1115,8 @@ func (s *Suite) testBadBlobTx(t *utesting.T, tx *types.Transaction, badTx *types
// the good peer is connected, and has announced the tx.
// proceed to send the incorrect one from the bad peer.
- resp := eth.PooledTransactionsPacket{RequestId: req.RequestId, PooledTransactionsResponse: eth.PooledTransactionsResponse(types.Transactions{badTx})}
+ encTxs, _ := rlp.EncodeToRawList([]*types.Transaction{badTx})
+ resp := eth.PooledTransactionsPacket{RequestId: req.RequestId, List: encTxs}
if err := conn.Write(ethProto, eth.PooledTransactionsMsg, resp); err != nil {
errc <- fmt.Errorf("writing pooled tx response failed: %v", err)
return
@@ -1164,7 +1176,8 @@ func (s *Suite) testBadBlobTx(t *utesting.T, tx *types.Transaction, badTx *types
return
}
- resp := eth.PooledTransactionsPacket{RequestId: req.RequestId, PooledTransactionsResponse: eth.PooledTransactionsResponse(types.Transactions{tx})}
+ encTxs, _ := rlp.EncodeToRawList([]*types.Transaction{tx})
+ resp := eth.PooledTransactionsPacket{RequestId: req.RequestId, List: encTxs}
if err := conn.Write(ethProto, eth.PooledTransactionsMsg, resp); err != nil {
errc <- fmt.Errorf("writing pooled tx response failed: %v", err)
return
diff --git a/cmd/devp2p/internal/ethtest/transaction.go b/cmd/devp2p/internal/ethtest/transaction.go
index cbbbbce8d9..8ce26f3e1a 100644
--- a/cmd/devp2p/internal/ethtest/transaction.go
+++ b/cmd/devp2p/internal/ethtest/transaction.go
@@ -26,6 +26,7 @@ import (
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/eth/protocols/eth"
"github.com/ethereum/go-ethereum/internal/utesting"
+ "github.com/ethereum/go-ethereum/rlp"
)
// sendTxs sends the given transactions to the node and
@@ -51,7 +52,8 @@ func (s *Suite) sendTxs(t *utesting.T, txs []*types.Transaction) error {
return fmt.Errorf("peering failed: %v", err)
}
- if err = sendConn.Write(ethProto, eth.TransactionsMsg, eth.TransactionsPacket(txs)); err != nil {
+ encTxs, _ := rlp.EncodeToRawList(txs)
+ if err = sendConn.Write(ethProto, eth.TransactionsMsg, eth.TransactionsPacket{RawList: encTxs}); err != nil {
return fmt.Errorf("failed to write message to connection: %v", err)
}
@@ -68,7 +70,8 @@ func (s *Suite) sendTxs(t *utesting.T, txs []*types.Transaction) error {
}
switch msg := msg.(type) {
case *eth.TransactionsPacket:
- for _, tx := range *msg {
+ txs, _ := msg.Items()
+ for _, tx := range txs {
got[tx.Hash()] = true
}
case *eth.NewPooledTransactionHashesPacket:
@@ -80,9 +83,10 @@ func (s *Suite) sendTxs(t *utesting.T, txs []*types.Transaction) error {
if err != nil {
t.Logf("invalid GetBlockHeaders request: %v", err)
}
+ encHeaders, _ := rlp.EncodeToRawList(headers)
recvConn.Write(ethProto, eth.BlockHeadersMsg, ð.BlockHeadersPacket{
- RequestId: msg.RequestId,
- BlockHeadersRequest: headers,
+ RequestId: msg.RequestId,
+ List: encHeaders,
})
default:
return fmt.Errorf("unexpected eth wire msg: %s", pretty.Sdump(msg))
@@ -167,9 +171,10 @@ func (s *Suite) sendInvalidTxs(t *utesting.T, txs []*types.Transaction) error {
if err != nil {
t.Logf("invalid GetBlockHeaders request: %v", err)
}
+ encHeaders, _ := rlp.EncodeToRawList(headers)
recvConn.Write(ethProto, eth.BlockHeadersMsg, ð.BlockHeadersPacket{
- RequestId: msg.RequestId,
- BlockHeadersRequest: headers,
+ RequestId: msg.RequestId,
+ List: encHeaders,
})
default:
return fmt.Errorf("unexpected eth message: %v", pretty.Sdump(msg))
diff --git a/eth/downloader/downloader_test.go b/eth/downloader/downloader_test.go
index 7fa2522a3d..e5d4a7c59b 100644
--- a/eth/downloader/downloader_test.go
+++ b/eth/downloader/downloader_test.go
@@ -214,10 +214,12 @@ func (dlp *downloadTesterPeer) RequestHeadersByNumber(origin uint64, amount int,
func (dlp *downloadTesterPeer) RequestBodies(hashes []common.Hash, sink chan *eth.Response) (*eth.Request, error) {
blobs := eth.ServiceGetBlockBodiesQuery(dlp.chain, hashes)
- bodies := make([]*eth.BlockBody, len(blobs))
+ bodies := make([]*types.Body, len(blobs))
+ ethbodies := make([]eth.BlockBody, len(blobs))
for i, blob := range blobs {
- bodies[i] = new(eth.BlockBody)
+ bodies[i] = new(types.Body)
rlp.DecodeBytes(blob, bodies[i])
+ rlp.DecodeBytes(blob, ðbodies[i])
}
var (
txsHashes = make([]common.Hash, len(bodies))
@@ -239,9 +241,13 @@ func (dlp *downloadTesterPeer) RequestBodies(hashes []common.Hash, sink chan *et
Peer: dlp.id,
}
res := ð.Response{
- Req: req,
- Res: (*eth.BlockBodiesResponse)(&bodies),
- Meta: [][]common.Hash{txsHashes, uncleHashes, withdrawalHashes},
+ Req: req,
+ Res: (*eth.BlockBodiesResponse)(ðbodies),
+ Meta: eth.BlockBodyHashes{
+ TransactionRoots: txsHashes,
+ UncleHashes: uncleHashes,
+ WithdrawalRoots: withdrawalHashes,
+ },
Time: 1,
Done: make(chan error, 1), // Ignore the returned status
}
@@ -290,14 +296,14 @@ func (dlp *downloadTesterPeer) ID() string {
// RequestAccountRange fetches a batch of accounts rooted in a specific account
// trie, starting with the origin.
-func (dlp *downloadTesterPeer) RequestAccountRange(id uint64, root, origin, limit common.Hash, bytes uint64) error {
+func (dlp *downloadTesterPeer) RequestAccountRange(id uint64, root, origin, limit common.Hash, bytes int) error {
// Create the request and service it
req := &snap.GetAccountRangePacket{
ID: id,
Root: root,
Origin: origin,
Limit: limit,
- Bytes: bytes,
+ Bytes: uint64(bytes),
}
slimaccs, proofs := snap.ServiceGetAccountRangeQuery(dlp.chain, req)
@@ -316,7 +322,7 @@ func (dlp *downloadTesterPeer) RequestAccountRange(id uint64, root, origin, limi
// RequestStorageRanges fetches a batch of storage slots belonging to one or
// more accounts. If slots from only one account is requested, an origin marker
// may also be used to retrieve from there.
-func (dlp *downloadTesterPeer) RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes uint64) error {
+func (dlp *downloadTesterPeer) RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes int) error {
// Create the request and service it
req := &snap.GetStorageRangesPacket{
ID: id,
@@ -324,7 +330,7 @@ func (dlp *downloadTesterPeer) RequestStorageRanges(id uint64, root common.Hash,
Root: root,
Origin: origin,
Limit: limit,
- Bytes: bytes,
+ Bytes: uint64(bytes),
}
storage, proofs := snap.ServiceGetStorageRangesQuery(dlp.chain, req)
@@ -341,25 +347,28 @@ func (dlp *downloadTesterPeer) RequestStorageRanges(id uint64, root common.Hash,
}
// RequestByteCodes fetches a batch of bytecodes by hash.
-func (dlp *downloadTesterPeer) RequestByteCodes(id uint64, hashes []common.Hash, bytes uint64) error {
+func (dlp *downloadTesterPeer) RequestByteCodes(id uint64, hashes []common.Hash, bytes int) error {
req := &snap.GetByteCodesPacket{
ID: id,
Hashes: hashes,
- Bytes: bytes,
+ Bytes: uint64(bytes),
}
codes := snap.ServiceGetByteCodesQuery(dlp.chain, req)
go dlp.dl.downloader.SnapSyncer.OnByteCodes(dlp, id, codes)
return nil
}
-// RequestTrieNodes fetches a batch of account or storage trie nodes rooted in
-// a specific state trie.
-func (dlp *downloadTesterPeer) RequestTrieNodes(id uint64, root common.Hash, paths []snap.TrieNodePathSet, bytes uint64) error {
+// RequestTrieNodes fetches a batch of account or storage trie nodes.
+func (dlp *downloadTesterPeer) RequestTrieNodes(id uint64, root common.Hash, count int, paths []snap.TrieNodePathSet, bytes int) error {
+ encPaths, err := rlp.EncodeToRawList(paths)
+ if err != nil {
+ panic(err)
+ }
req := &snap.GetTrieNodesPacket{
ID: id,
Root: root,
- Paths: paths,
- Bytes: bytes,
+ Paths: encPaths,
+ Bytes: uint64(bytes),
}
nodes, _ := snap.ServiceGetTrieNodesQuery(dlp.chain, req, time.Now())
go dlp.dl.downloader.SnapSyncer.OnTrieNodes(dlp, id, nodes)
diff --git a/eth/downloader/fetchers_concurrent_bodies.go b/eth/downloader/fetchers_concurrent_bodies.go
index 56359b33c9..6a8eb35219 100644
--- a/eth/downloader/fetchers_concurrent_bodies.go
+++ b/eth/downloader/fetchers_concurrent_bodies.go
@@ -88,15 +88,14 @@ func (q *bodyQueue) request(peer *peerConnection, req *fetchRequest, resCh chan
// deliver is responsible for taking a generic response packet from the concurrent
// fetcher, unpacking the body data and delivering it to the downloader's queue.
func (q *bodyQueue) deliver(peer *peerConnection, packet *eth.Response) (int, error) {
- txs, uncles, withdrawals := packet.Res.(*eth.BlockBodiesResponse).Unpack()
- hashsets := packet.Meta.([][]common.Hash) // {txs hashes, uncle hashes, withdrawal hashes}
-
- accepted, err := q.queue.DeliverBodies(peer.id, txs, hashsets[0], uncles, hashsets[1], withdrawals, hashsets[2])
+ resp := packet.Res.(*eth.BlockBodiesResponse)
+ meta := packet.Meta.(eth.BlockBodyHashes)
+ accepted, err := q.queue.DeliverBodies(peer.id, meta, *resp)
switch {
- case err == nil && len(txs) == 0:
+ case err == nil && len(*resp) == 0:
peer.log.Trace("Requested bodies delivered")
case err == nil:
- peer.log.Trace("Delivered new batch of bodies", "count", len(txs), "accepted", accepted)
+ peer.log.Trace("Delivered new batch of bodies", "count", len(*resp), "accepted", accepted)
default:
peer.log.Debug("Failed to deliver retrieved bodies", "err", err)
}
diff --git a/eth/downloader/queue.go b/eth/downloader/queue.go
index 76a14345e5..c0cb9b174a 100644
--- a/eth/downloader/queue.go
+++ b/eth/downloader/queue.go
@@ -29,11 +29,10 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/prque"
"github.com/ethereum/go-ethereum/core/types"
- "github.com/ethereum/go-ethereum/crypto/kzg4844"
"github.com/ethereum/go-ethereum/eth/ethconfig"
+ "github.com/ethereum/go-ethereum/eth/protocols/eth"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/metrics"
- "github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/rlp"
)
@@ -559,63 +558,54 @@ func (q *queue) expire(peer string, pendPool map[string]*fetchRequest, taskQueue
// DeliverBodies injects a block body retrieval response into the results queue.
// The method returns the number of blocks bodies accepted from the delivery and
// also wakes any threads waiting for data delivery.
-func (q *queue) DeliverBodies(id string, txLists [][]*types.Transaction, txListHashes []common.Hash,
- uncleLists [][]*types.Header, uncleListHashes []common.Hash,
- withdrawalLists [][]*types.Withdrawal, withdrawalListHashes []common.Hash,
-) (int, error) {
+func (q *queue) DeliverBodies(id string, hashes eth.BlockBodyHashes, bodies []eth.BlockBody) (int, error) {
q.lock.Lock()
defer q.lock.Unlock()
+ var txLists [][]*types.Transaction
+ var uncleLists [][]*types.Header
+ var withdrawalLists [][]*types.Withdrawal
+
validate := func(index int, header *types.Header) error {
- if txListHashes[index] != header.TxHash {
+ if hashes.TransactionRoots[index] != header.TxHash {
return errInvalidBody
}
- if uncleListHashes[index] != header.UncleHash {
+ if hashes.UncleHashes[index] != header.UncleHash {
return errInvalidBody
}
if header.WithdrawalsHash == nil {
// nil hash means that withdrawals should not be present in body
- if withdrawalLists[index] != nil {
+ if bodies[index].Withdrawals != nil {
return errInvalidBody
}
} else { // non-nil hash: body must have withdrawals
- if withdrawalLists[index] == nil {
+ if bodies[index].Withdrawals == nil {
return errInvalidBody
}
- if withdrawalListHashes[index] != *header.WithdrawalsHash {
+ if hashes.WithdrawalRoots[index] != *header.WithdrawalsHash {
return errInvalidBody
}
}
- // Blocks must have a number of blobs corresponding to the header gas usage,
- // and zero before the Cancun hardfork.
- var blobs int
- for _, tx := range txLists[index] {
- // Count the number of blobs to validate against the header's blobGasUsed
- blobs += len(tx.BlobHashes())
- // Validate the data blobs individually too
- if tx.Type() == types.BlobTxType {
- if len(tx.BlobHashes()) == 0 {
- return errInvalidBody
- }
- for _, hash := range tx.BlobHashes() {
- if !kzg4844.IsValidVersionedHash(hash[:]) {
- return errInvalidBody
- }
- }
- if tx.BlobTxSidecar() != nil {
- return errInvalidBody
- }
- }
+ // decode
+ txs, err := bodies[index].Transactions.Items()
+ if err != nil {
+ return fmt.Errorf("%w: bad transactions: %v", errInvalidBody, err)
}
- if header.BlobGasUsed != nil {
- if want := *header.BlobGasUsed / params.BlobTxBlobGasPerBlob; uint64(blobs) != want { // div because the header is surely good vs the body might be bloated
- return errInvalidBody
+ txLists = append(txLists, txs)
+ uncles, err := bodies[index].Uncles.Items()
+ if err != nil {
+ return fmt.Errorf("%w: bad uncles: %v", errInvalidBody, err)
+ }
+ uncleLists = append(uncleLists, uncles)
+ if bodies[index].Withdrawals != nil {
+ withdrawals, err := bodies[index].Withdrawals.Items()
+ if err != nil {
+ return fmt.Errorf("%w: bad withdrawals: %v", errInvalidBody, err)
}
+ withdrawalLists = append(withdrawalLists, withdrawals)
} else {
- if blobs != 0 {
- return errInvalidBody
- }
+ withdrawalLists = append(withdrawalLists, nil)
}
return nil
}
@@ -626,8 +616,9 @@ func (q *queue) DeliverBodies(id string, txLists [][]*types.Transaction, txListH
result.Withdrawals = withdrawalLists[index]
result.SetBodyDone()
}
+ nresults := len(hashes.TransactionRoots)
return q.deliver(id, q.blockTaskPool, q.blockTaskQueue, q.blockPendPool,
- bodyReqTimer, bodyInMeter, bodyDropMeter, len(txLists), validate, reconstruct)
+ bodyReqTimer, bodyInMeter, bodyDropMeter, nresults, validate, reconstruct)
}
// DeliverReceipts injects a receipt retrieval response into the results queue.
diff --git a/eth/downloader/queue_test.go b/eth/downloader/queue_test.go
index ca71a769de..c7e8a0d1d6 100644
--- a/eth/downloader/queue_test.go
+++ b/eth/downloader/queue_test.go
@@ -30,8 +30,10 @@ import (
"github.com/ethereum/go-ethereum/consensus/ethash"
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/eth/protocols/eth"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/params"
+ "github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie"
)
@@ -323,26 +325,31 @@ func XTestDelivery(t *testing.T) {
emptyList []*types.Header
txset [][]*types.Transaction
uncleset [][]*types.Header
+ bodies []eth.BlockBody
)
numToSkip := rand.Intn(len(f.Headers))
for _, hdr := range f.Headers[0 : len(f.Headers)-numToSkip] {
- txset = append(txset, world.getTransactions(hdr.Number.Uint64()))
+ txs := world.getTransactions(hdr.Number.Uint64())
+ txset = append(txset, txs)
uncleset = append(uncleset, emptyList)
+ txsList, _ := rlp.EncodeToRawList(txs)
+ bodies = append(bodies, eth.BlockBody{Transactions: txsList})
+ }
+ hashes := eth.BlockBodyHashes{
+ TransactionRoots: make([]common.Hash, len(txset)),
+ UncleHashes: make([]common.Hash, len(uncleset)),
+ WithdrawalRoots: make([]common.Hash, len(txset)),
}
- var (
- txsHashes = make([]common.Hash, len(txset))
- uncleHashes = make([]common.Hash, len(uncleset))
- )
hasher := trie.NewStackTrie(nil)
for i, txs := range txset {
- txsHashes[i] = types.DeriveSha(types.Transactions(txs), hasher)
+ hashes.TransactionRoots[i] = types.DeriveSha(types.Transactions(txs), hasher)
}
for i, uncles := range uncleset {
- uncleHashes[i] = types.CalcUncleHash(uncles)
+ hashes.UncleHashes[i] = types.CalcUncleHash(uncles)
}
+
time.Sleep(100 * time.Millisecond)
- _, err := q.DeliverBodies(peer.id, txset, txsHashes, uncleset, uncleHashes, nil, nil)
- if err != nil {
+ if _, err := q.DeliverBodies(peer.id, hashes, bodies); err != nil {
fmt.Printf("delivered %d bodies %v\n", len(txset), err)
}
} else {
diff --git a/eth/handler_eth.go b/eth/handler_eth.go
index 11742b14ad..8704a86af4 100644
--- a/eth/handler_eth.go
+++ b/eth/handler_eth.go
@@ -20,6 +20,7 @@ import (
"errors"
"fmt"
+ "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/eth/protocols/eth"
@@ -61,19 +62,42 @@ func (h *ethHandler) Handle(peer *eth.Peer, packet eth.Packet) error {
return h.txFetcher.Notify(peer.ID(), packet.Types, packet.Sizes, packet.Hashes)
case *eth.TransactionsPacket:
- for _, tx := range *packet {
- if tx.Type() == types.BlobTxType {
- return errors.New("disallowed broadcast blob transaction")
- }
+ txs, err := packet.Items()
+ if err != nil {
+ return fmt.Errorf("Transactions: %v", err)
}
- return h.txFetcher.Enqueue(peer.ID(), *packet, false)
+ if err := handleTransactions(peer, txs, true); err != nil {
+ return fmt.Errorf("Transactions: %v", err)
+ }
+ return h.txFetcher.Enqueue(peer.ID(), txs, false)
- case *eth.PooledTransactionsResponse:
- // If we receive any blob transactions missing sidecars, or with
- // sidecars that don't correspond to the versioned hashes reported
- // in the header, disconnect from the sending peer.
- for _, tx := range *packet {
- if tx.Type() == types.BlobTxType {
+ case *eth.PooledTransactionsPacket:
+ txs, err := packet.List.Items()
+ if err != nil {
+ return fmt.Errorf("PooledTransactions: %v", err)
+ }
+ if err := handleTransactions(peer, txs, false); err != nil {
+ return fmt.Errorf("PooledTransactions: %v", err)
+ }
+ return h.txFetcher.Enqueue(peer.ID(), txs, true)
+
+ default:
+ return fmt.Errorf("unexpected eth packet type: %T", packet)
+ }
+}
+
+// handleTransactions marks all given transactions as known to the peer
+// and performs basic validations.
+func handleTransactions(peer *eth.Peer, list []*types.Transaction, directBroadcast bool) error {
+ seen := make(map[common.Hash]struct{})
+ for _, tx := range list {
+ if tx.Type() == types.BlobTxType {
+ if directBroadcast {
+ return errors.New("disallowed broadcast blob transaction")
+ } else {
+ // If we receive any blob transactions missing sidecars, or with
+ // sidecars that don't correspond to the versioned hashes reported
+ // in the header, disconnect from the sending peer.
if tx.BlobTxSidecar() == nil {
return errors.New("received sidecar-less blob transaction")
}
@@ -82,9 +106,16 @@ func (h *ethHandler) Handle(peer *eth.Peer, packet eth.Packet) error {
}
}
}
- return h.txFetcher.Enqueue(peer.ID(), *packet, true)
- default:
- return fmt.Errorf("unexpected eth packet type: %T", packet)
+ // Check for duplicates.
+ hash := tx.Hash()
+ if _, exists := seen[hash]; exists {
+ return fmt.Errorf("multiple copies of the same hash %v", hash)
+ }
+ seen[hash] = struct{}{}
+
+ // Mark as known.
+ peer.MarkTransaction(hash)
}
+ return nil
}
diff --git a/eth/handler_eth_test.go b/eth/handler_eth_test.go
index 1343cae03e..0330713071 100644
--- a/eth/handler_eth_test.go
+++ b/eth/handler_eth_test.go
@@ -60,11 +60,19 @@ func (h *testEthHandler) Handle(peer *eth.Peer, packet eth.Packet) error {
return nil
case *eth.TransactionsPacket:
- h.txBroadcasts.Send(([]*types.Transaction)(*packet))
+ txs, err := packet.Items()
+ if err != nil {
+ return err
+ }
+ h.txBroadcasts.Send(txs)
return nil
- case *eth.PooledTransactionsResponse:
- h.txBroadcasts.Send(([]*types.Transaction)(*packet))
+ case *eth.PooledTransactionsPacket:
+ txs, err := packet.List.Items()
+ if err != nil {
+ return err
+ }
+ h.txBroadcasts.Send(txs)
return nil
default:
diff --git a/eth/protocols/eth/dispatcher.go b/eth/protocols/eth/dispatcher.go
index cba40596fc..3f78fb4646 100644
--- a/eth/protocols/eth/dispatcher.go
+++ b/eth/protocols/eth/dispatcher.go
@@ -22,6 +22,7 @@ import (
"time"
"github.com/ethereum/go-ethereum/p2p"
+ "github.com/ethereum/go-ethereum/p2p/tracker"
)
var (
@@ -47,9 +48,10 @@ type Request struct {
sink chan *Response // Channel to deliver the response on
cancel chan struct{} // Channel to cancel requests ahead of time
- code uint64 // Message code of the request packet
- want uint64 // Message code of the response packet
- data interface{} // Data content of the request packet
+ code uint64 // Message code of the request packet
+ want uint64 // Message code of the response packet
+ numItems int // Number of requested items
+ data interface{} // Data content of the request packet
Peer string // Demultiplexer if cross-peer requests are batched together
Sent time.Time // Timestamp when the request was sent
@@ -190,19 +192,30 @@ func (p *Peer) dispatchResponse(res *Response, metadata func() interface{}) erro
func (p *Peer) dispatcher() {
pending := make(map[uint64]*Request)
+loop:
for {
select {
case reqOp := <-p.reqDispatch:
req := reqOp.req
req.Sent = time.Now()
- requestTracker.Track(p.id, p.version, req.code, req.want, req.id)
- err := p2p.Send(p.rw, req.code, req.data)
- reqOp.fail <- err
-
- if err == nil {
- pending[req.id] = req
+ treq := tracker.Request{
+ ID: req.id,
+ ReqCode: req.code,
+ RespCode: req.want,
+ Size: req.numItems,
}
+ if err := p.tracker.Track(treq); err != nil {
+ reqOp.fail <- err
+ continue loop
+ }
+ if err := p2p.Send(p.rw, req.code, req.data); err != nil {
+ reqOp.fail <- err
+ continue loop
+ }
+
+ pending[req.id] = req
+ reqOp.fail <- nil
case cancelOp := <-p.reqCancel:
// Retrieve the pending request to cancel and short circuit if it
@@ -220,9 +233,6 @@ func (p *Peer) dispatcher() {
res := resOp.res
res.Req = pending[res.id]
- // Independent if the request exists or not, track this packet
- requestTracker.Fulfil(p.id, p.version, res.code, res.id)
-
switch {
case res.Req == nil:
// Response arrived with an untracked ID. Since even cancelled
@@ -249,6 +259,7 @@ func (p *Peer) dispatcher() {
}
case <-p.term:
+ p.tracker.Stop()
return
}
}
diff --git a/eth/protocols/eth/handler_test.go b/eth/protocols/eth/handler_test.go
index 65c491f815..8f7f82c3a1 100644
--- a/eth/protocols/eth/handler_test.go
+++ b/eth/protocols/eth/handler_test.go
@@ -23,9 +23,11 @@ import (
"math/big"
"math/rand"
"os"
+ "reflect"
"testing"
"time"
+ "github.com/davecgh/go-spew/spew"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/consensus/beacon"
"github.com/ethereum/go-ethereum/consensus/ethash"
@@ -42,6 +44,7 @@ import (
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/rlp"
+ "github.com/ethereum/go-ethereum/trie"
"github.com/holiman/uint256"
)
@@ -360,8 +363,8 @@ func testGetBlockHeaders(t *testing.T, protocol uint) {
GetBlockHeadersRequest: tt.query,
})
if err := p2p.ExpectMsg(peer.app, BlockHeadersMsg, &BlockHeadersPacket{
- RequestId: 123,
- BlockHeadersRequest: headers,
+ RequestId: 123,
+ List: encodeRL(headers),
}); err != nil {
t.Errorf("test %d: headers mismatch: %v", i, err)
}
@@ -374,7 +377,7 @@ func testGetBlockHeaders(t *testing.T, protocol uint) {
RequestId: 456,
GetBlockHeadersRequest: tt.query,
})
- expected := &BlockHeadersPacket{RequestId: 456, BlockHeadersRequest: headers}
+ expected := &BlockHeadersPacket{RequestId: 456, List: encodeRL(headers)}
if err := p2p.ExpectMsg(peer.app, BlockHeadersMsg, expected); err != nil {
t.Errorf("test %d by hash: headers mismatch: %v", i, err)
}
@@ -437,7 +440,7 @@ func testGetBlockBodies(t *testing.T, protocol uint) {
// Collect the hashes to request, and the response to expect
var (
hashes []common.Hash
- bodies []*BlockBody
+ bodies []BlockBody
seen = make(map[int64]bool)
)
for j := 0; j < tt.random; j++ {
@@ -449,7 +452,7 @@ func testGetBlockBodies(t *testing.T, protocol uint) {
block := backend.chain.GetBlockByNumber(uint64(num))
hashes = append(hashes, block.Hash())
if len(bodies) < tt.expected {
- bodies = append(bodies, &BlockBody{Transactions: block.Transactions(), Uncles: block.Uncles(), Withdrawals: block.Withdrawals()})
+ bodies = append(bodies, encodeBody(block))
}
break
}
@@ -459,7 +462,7 @@ func testGetBlockBodies(t *testing.T, protocol uint) {
hashes = append(hashes, hash)
if tt.available[j] && len(bodies) < tt.expected {
block := backend.chain.GetBlockByHash(hash)
- bodies = append(bodies, &BlockBody{Transactions: block.Transactions(), Uncles: block.Uncles(), Withdrawals: block.Withdrawals()})
+ bodies = append(bodies, encodeBody(block))
}
}
@@ -469,14 +472,69 @@ func testGetBlockBodies(t *testing.T, protocol uint) {
GetBlockBodiesRequest: hashes,
})
if err := p2p.ExpectMsg(peer.app, BlockBodiesMsg, &BlockBodiesPacket{
- RequestId: 123,
- BlockBodiesResponse: bodies,
+ RequestId: 123,
+ List: encodeRL(bodies),
}); err != nil {
t.Fatalf("test %d: bodies mismatch: %v", i, err)
}
}
}
+func encodeBody(b *types.Block) BlockBody {
+ body := BlockBody{
+ Transactions: encodeRL([]*types.Transaction(b.Transactions())),
+ Uncles: encodeRL(b.Uncles()),
+ }
+ if b.Withdrawals() != nil {
+ wd := encodeRL([]*types.Withdrawal(b.Withdrawals()))
+ body.Withdrawals = &wd
+ }
+ return body
+}
+
+func TestHashBody(t *testing.T) {
+ key, _ := crypto.HexToECDSA("8a1f9a8f95be41cd7ccb6168179afb4504aefe388d1e14474d32c45c72ce7b7a")
+ signer := types.NewCancunSigner(big.NewInt(1))
+
+ // create block 1
+ header := &types.Header{Number: big.NewInt(11)}
+ txs := []*types.Transaction{
+ types.MustSignNewTx(key, signer, &types.DynamicFeeTx{
+ ChainID: big.NewInt(1),
+ Nonce: 1,
+ Data: []byte("testing"),
+ }),
+ types.MustSignNewTx(key, signer, &types.LegacyTx{
+ Nonce: 2,
+ Data: []byte("testing"),
+ }),
+ }
+ uncles := []*types.Header{{Number: big.NewInt(10)}}
+ body1 := &types.Body{Transactions: txs, Uncles: uncles}
+ block1 := types.NewBlock(header, body1, nil, trie.NewStackTrie(nil))
+
+ // create block 2 (has withdrawals)
+ header2 := &types.Header{Number: big.NewInt(12)}
+ body2 := &types.Body{
+ Withdrawals: []*types.Withdrawal{{Index: 10}, {Index: 11}},
+ }
+ block2 := types.NewBlock(header2, body2, nil, trie.NewStackTrie(nil))
+
+ expectedHashes := BlockBodyHashes{
+ TransactionRoots: []common.Hash{block1.TxHash(), block2.TxHash()},
+ WithdrawalRoots: []common.Hash{common.Hash{}, *block2.Header().WithdrawalsHash},
+ UncleHashes: []common.Hash{block1.UncleHash(), block2.UncleHash()},
+ }
+
+ // compute hash like protocol handler does
+ protocolBodies := []BlockBody{encodeBody(block1), encodeBody(block2)}
+ hashes := hashBodyParts(protocolBodies)
+ if !reflect.DeepEqual(hashes, expectedHashes) {
+ t.Errorf("wrong hashes: %s", spew.Sdump(hashes))
+ t.Logf("expected: %s", spew.Sdump(expectedHashes))
+ }
+}
+
// Tests that the transaction receipts can be retrieved based on hashes.
func TestGetBlockReceipts68(t *testing.T) { testGetBlockReceipts(t, ETH68) }
@@ -528,13 +586,13 @@ func testGetBlockReceipts(t *testing.T, protocol uint) {
// Collect the hashes to request, and the response to expect
var (
hashes []common.Hash
- receipts []*ReceiptList68
+ receipts rlp.RawList[*ReceiptList68]
)
for i := uint64(0); i <= backend.chain.CurrentBlock().Number.Uint64(); i++ {
block := backend.chain.GetBlockByNumber(i)
hashes = append(hashes, block.Hash())
trs := backend.chain.GetReceiptsByHash(block.Hash())
- receipts = append(receipts, NewReceiptList68(trs))
+ receipts.Append(NewReceiptList68(trs))
}
// Send the hash request and verify the response
@@ -688,10 +746,18 @@ func testGetPooledTransaction(t *testing.T, blobTx bool) {
RequestId: 123,
GetPooledTransactionsRequest: []common.Hash{tx.Hash()},
})
- if err := p2p.ExpectMsg(peer.app, PooledTransactionsMsg, PooledTransactionsPacket{
- RequestId: 123,
- PooledTransactionsResponse: []*types.Transaction{tx},
+ if err := p2p.ExpectMsg(peer.app, PooledTransactionsMsg, &PooledTransactionsPacket{
+ RequestId: 123,
+ List: encodeRL([]*types.Transaction{tx}),
}); err != nil {
t.Errorf("pooled transaction mismatch: %v", err)
}
}
+
+func encodeRL[T any](slice []T) rlp.RawList[T] {
+ rl, err := rlp.EncodeToRawList(slice)
+ if err != nil {
+ panic(err)
+ }
+ return rl
+}
diff --git a/eth/protocols/eth/handlers.go b/eth/protocols/eth/handlers.go
index aad3353d88..7f1ccc360d 100644
--- a/eth/protocols/eth/handlers.go
+++ b/eth/protocols/eth/handlers.go
@@ -17,23 +17,22 @@
package eth
import (
+ "bytes"
"encoding/json"
"errors"
"fmt"
- "time"
+ "math"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p/tracker"
"github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie"
)
-// requestTracker is a singleton tracker for eth/66 and newer request times.
-var requestTracker = tracker.New(ProtocolName, 5*time.Minute)
-
func handleGetBlockHeaders(backend Backend, msg Decoder, peer *Peer) error {
// Decode the complex header query
var query GetBlockHeadersPacket
@@ -356,9 +355,18 @@ func handleBlockHeaders(backend Backend, msg Decoder, peer *Peer) error {
if err := msg.Decode(res); err != nil {
return err
}
+ tresp := tracker.Response{ID: res.RequestId, MsgCode: BlockHeadersMsg, Size: res.List.Len()}
+ if err := peer.tracker.Fulfil(tresp); err != nil {
+ return fmt.Errorf("BlockHeaders: %w", err)
+ }
+ headers, err := res.List.Items()
+ if err != nil {
+ return fmt.Errorf("BlockHeaders: %w", err)
+ }
+
metadata := func() interface{} {
- hashes := make([]common.Hash, len(res.BlockHeadersRequest))
- for i, header := range res.BlockHeadersRequest {
+ hashes := make([]common.Hash, len(headers))
+ for i, header := range headers {
hashes[i] = header.Hash()
}
return hashes
@@ -366,7 +374,7 @@ func handleBlockHeaders(backend Backend, msg Decoder, peer *Peer) error {
return peer.dispatchResponse(&Response{
id: res.RequestId,
code: BlockHeadersMsg,
- Res: &res.BlockHeadersRequest,
+ Res: (*BlockHeadersRequest)(&headers),
}, metadata)
}
@@ -376,53 +384,150 @@ func handleBlockBodies(backend Backend, msg Decoder, peer *Peer) error {
if err := msg.Decode(res); err != nil {
return err
}
- metadata := func() interface{} {
- var (
- txsHashes = make([]common.Hash, len(res.BlockBodiesResponse))
- uncleHashes = make([]common.Hash, len(res.BlockBodiesResponse))
- withdrawalHashes = make([]common.Hash, len(res.BlockBodiesResponse))
- )
- hasher := trie.NewStackTrie(nil)
- for i, body := range res.BlockBodiesResponse {
- txsHashes[i] = types.DeriveSha(types.Transactions(body.Transactions), hasher)
- uncleHashes[i] = types.CalcUncleHash(body.Uncles)
- if body.Withdrawals != nil {
- withdrawalHashes[i] = types.DeriveSha(types.Withdrawals(body.Withdrawals), hasher)
- }
- }
- return [][]common.Hash{txsHashes, uncleHashes, withdrawalHashes}
+
+ // Check against the request.
+ length := res.List.Len()
+ tresp := tracker.Response{ID: res.RequestId, MsgCode: BlockBodiesMsg, Size: length}
+ if err := peer.tracker.Fulfil(tresp); err != nil {
+ return fmt.Errorf("BlockBodies: %w", err)
}
+
+ // Collect items and dispatch.
+ items, err := res.List.Items()
+ if err != nil {
+ return fmt.Errorf("BlockBodies: %w", err)
+ }
+ metadata := func() any { return hashBodyParts(items) }
return peer.dispatchResponse(&Response{
id: res.RequestId,
code: BlockBodiesMsg,
- Res: &res.BlockBodiesResponse,
+ Res: (*BlockBodiesResponse)(&items),
}, metadata)
}
+// BlockBodyHashes contains the lists of block body part roots for a list of block bodies.
+type BlockBodyHashes struct {
+ TransactionRoots []common.Hash
+ WithdrawalRoots []common.Hash
+ UncleHashes []common.Hash
+}
+
+func hashBodyParts(items []BlockBody) BlockBodyHashes {
+ h := BlockBodyHashes{
+ TransactionRoots: make([]common.Hash, len(items)),
+ WithdrawalRoots: make([]common.Hash, len(items)),
+ UncleHashes: make([]common.Hash, len(items)),
+ }
+ hasher := trie.NewStackTrie(nil)
+ for i, body := range items {
+ // txs
+ txsList := newDerivableRawList(&body.Transactions, writeTxForHash)
+ h.TransactionRoots[i] = types.DeriveSha(txsList, hasher)
+ // uncles
+ if body.Uncles.Len() == 0 {
+ h.UncleHashes[i] = types.EmptyUncleHash
+ } else {
+ h.UncleHashes[i] = crypto.Keccak256Hash(body.Uncles.Bytes())
+ }
+ // withdrawals
+ if body.Withdrawals != nil {
+ wdlist := newDerivableRawList(body.Withdrawals, nil)
+ h.WithdrawalRoots[i] = types.DeriveSha(wdlist, hasher)
+ }
+ }
+ return h
+}
+
+// derivableRawList implements types.DerivableList for a serialized RLP list.
+type derivableRawList struct {
+ data []byte
+ offsets []uint32
+ write func([]byte, *bytes.Buffer)
+}
+
+func newDerivableRawList[T any](list *rlp.RawList[T], write func([]byte, *bytes.Buffer)) *derivableRawList {
+ dl := derivableRawList{data: list.Content(), write: write}
+ if dl.write == nil {
+ // default transform is identity
+ dl.write = func(b []byte, buf *bytes.Buffer) { buf.Write(b) }
+ }
+ // Assert to ensure 32-bit offsets are valid. This can never trigger
+ // unless a block body component or p2p receipt list is larger than 4GB.
+ if uint(len(dl.data)) > math.MaxUint32 {
+ panic("list data too big for derivableRawList")
+ }
+ it := list.ContentIterator()
+ dl.offsets = make([]uint32, list.Len())
+ for i := 0; it.Next(); i++ {
+ dl.offsets[i] = uint32(it.Offset())
+ }
+ return &dl
+}
+
+// Len returns the number of items in the list.
+func (dl *derivableRawList) Len() int {
+ return len(dl.offsets)
+}
+
+// EncodeIndex writes the i'th item to the buffer.
+func (dl *derivableRawList) EncodeIndex(i int, buf *bytes.Buffer) {
+ start := dl.offsets[i]
+ end := uint32(len(dl.data))
+ if i != len(dl.offsets)-1 {
+ end = dl.offsets[i+1]
+ }
+ dl.write(dl.data[start:end], buf)
+}
+
+// writeTxForHash changes a transaction in 'network encoding' into the format used for
+// the transactions MPT.
+func writeTxForHash(tx []byte, buf *bytes.Buffer) {
+ k, content, _, _ := rlp.Split(tx)
+ if k == rlp.List {
+ buf.Write(tx) // legacy tx
+ } else {
+ buf.Write(content) // typed tx
+ }
+}
+
func handleReceipts[L ReceiptsList](backend Backend, msg Decoder, peer *Peer) error {
// A batch of receipts arrived to one of our previous requests
res := new(ReceiptsPacket[L])
if err := msg.Decode(res); err != nil {
return err
}
+
+ tresp := tracker.Response{ID: res.RequestId, MsgCode: ReceiptsMsg, Size: res.List.Len()}
+ if err := peer.tracker.Fulfil(tresp); err != nil {
+ return fmt.Errorf("Receipts: %w", err)
+ }
+
// Assign temporary hashing buffer to each list item, the same buffer is shared
// between all receipt list instances.
+ receiptLists, err := res.List.Items()
+ if err != nil {
+ return fmt.Errorf("Receipts: %w", err)
+ }
buffers := new(receiptListBuffers)
- for i := range res.List {
- res.List[i].setBuffers(buffers)
+ for i := range receiptLists {
+ receiptLists[i].setBuffers(buffers)
}
metadata := func() interface{} {
hasher := trie.NewStackTrie(nil)
- hashes := make([]common.Hash, len(res.List))
- for i := range res.List {
- hashes[i] = types.DeriveSha(res.List[i], hasher)
+ hashes := make([]common.Hash, len(receiptLists))
+ for i := range receiptLists {
+ hashes[i] = types.DeriveSha(receiptLists[i].Derivable(), hasher)
}
return hashes
}
var enc ReceiptsRLPResponse
- for i := range res.List {
- enc = append(enc, res.List[i].EncodeForStorage())
+ for i := range receiptLists {
+ encReceipts, err := receiptLists[i].EncodeForStorage()
+ if err != nil {
+ return fmt.Errorf("Receipts: invalid list %d: %v", i, err)
+ }
+ enc = append(enc, encReceipts)
}
return peer.dispatchResponse(&Response{
id: res.RequestId,
@@ -446,7 +551,7 @@ func handleNewPooledTransactionHashes(backend Backend, msg Decoder, peer *Peer)
}
// Schedule all the unknown hashes for retrieval
for _, hash := range ann.Hashes {
- peer.markTransaction(hash)
+ peer.MarkTransaction(hash)
}
return backend.Handle(peer, ann)
}
@@ -494,19 +599,8 @@ func handleTransactions(backend Backend, msg Decoder, peer *Peer) error {
if err := msg.Decode(&txs); err != nil {
return err
}
- // Duplicate transactions are not allowed
- seen := make(map[common.Hash]struct{})
- for i, tx := range txs {
- // Validate and mark the remote transaction
- if tx == nil {
- return fmt.Errorf("Transactions: transaction %d is nil", i)
- }
- hash := tx.Hash()
- if _, exists := seen[hash]; exists {
- return fmt.Errorf("Transactions: multiple copies of the same hash %v", hash)
- }
- seen[hash] = struct{}{}
- peer.markTransaction(hash)
+ if txs.Len() > maxTransactionAnnouncements {
+ return fmt.Errorf("too many transactions")
}
return backend.Handle(peer, &txs)
}
@@ -516,28 +610,22 @@ func handlePooledTransactions(backend Backend, msg Decoder, peer *Peer) error {
if !backend.AcceptTxs() {
return nil
}
- // Transactions can be processed, parse all of them and deliver to the pool
- var txs PooledTransactionsPacket
- if err := msg.Decode(&txs); err != nil {
+
+ // Check against request and decode.
+ var resp PooledTransactionsPacket
+ if err := msg.Decode(&resp); err != nil {
return err
}
- // Duplicate transactions are not allowed
- seen := make(map[common.Hash]struct{})
- for i, tx := range txs.PooledTransactionsResponse {
- // Validate and mark the remote transaction
- if tx == nil {
- return fmt.Errorf("PooledTransactions: transaction %d is nil", i)
- }
- hash := tx.Hash()
- if _, exists := seen[hash]; exists {
- return fmt.Errorf("PooledTransactions: multiple copies of the same hash %v", hash)
- }
- seen[hash] = struct{}{}
- peer.markTransaction(hash)
+ tresp := tracker.Response{
+ ID: resp.RequestId,
+ MsgCode: PooledTransactionsMsg,
+ Size: resp.List.Len(),
+ }
+ if err := peer.tracker.Fulfil(tresp); err != nil {
+ return fmt.Errorf("PooledTransactions: %w", err)
}
- requestTracker.Fulfil(peer.id, peer.version, PooledTransactionsMsg, txs.RequestId)
- return backend.Handle(peer, &txs.PooledTransactionsResponse)
+ return backend.Handle(peer, &resp)
}
func handleBlockRangeUpdate(backend Backend, msg Decoder, peer *Peer) error {
diff --git a/eth/protocols/eth/peer.go b/eth/protocols/eth/peer.go
index df20c672c0..4ea2d7158c 100644
--- a/eth/protocols/eth/peer.go
+++ b/eth/protocols/eth/peer.go
@@ -19,11 +19,13 @@ package eth
import (
"math/rand"
"sync/atomic"
+ "time"
mapset "github.com/deckarep/golang-set/v2"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/p2p"
+ "github.com/ethereum/go-ethereum/p2p/tracker"
"github.com/ethereum/go-ethereum/rlp"
)
@@ -43,9 +45,10 @@ const (
// Peer is a collection of relevant information we have about a `eth` peer.
type Peer struct {
+ *p2p.Peer // The embedded P2P package peer
+
id string // Unique ID for the peer, cached
- *p2p.Peer // The embedded P2P package peer
rw p2p.MsgReadWriter // Input/output streams for snap
version uint // Protocol version negotiated
lastRange atomic.Pointer[BlockRangeUpdatePacket]
@@ -55,6 +58,7 @@ type Peer struct {
txBroadcast chan []common.Hash // Channel used to queue transaction propagation requests
txAnnounce chan []common.Hash // Channel used to queue transaction announcement requests
+ tracker *tracker.Tracker
reqDispatch chan *request // Dispatch channel to send requests and track then until fulfillment
reqCancel chan *cancel // Dispatch channel to cancel pending requests and untrack them
resDispatch chan *response // Dispatch channel to fulfil pending requests and untrack them
@@ -65,14 +69,17 @@ type Peer struct {
// NewPeer creates a wrapper for a network connection and negotiated protocol
// version.
func NewPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter, txpool TxPool) *Peer {
+ cap := p2p.Cap{Name: ProtocolName, Version: version}
+ id := p.ID().String()
peer := &Peer{
- id: p.ID().String(),
+ id: id,
Peer: p,
rw: rw,
version: version,
knownTxs: newKnownCache(maxKnownTxs),
txBroadcast: make(chan []common.Hash),
txAnnounce: make(chan []common.Hash),
+ tracker: tracker.New(cap, id, 5*time.Minute),
reqDispatch: make(chan *request),
reqCancel: make(chan *cancel),
resDispatch: make(chan *response),
@@ -115,9 +122,9 @@ func (p *Peer) KnownTransaction(hash common.Hash) bool {
return p.knownTxs.Contains(hash)
}
-// markTransaction marks a transaction as known for the peer, ensuring that it
+// MarkTransaction marks a transaction as known for the peer, ensuring that it
// will never be propagated to this particular peer.
-func (p *Peer) markTransaction(hash common.Hash) {
+func (p *Peer) MarkTransaction(hash common.Hash) {
// If we reached the memory allowance, drop a previously known transaction hash
p.knownTxs.Add(hash)
}
@@ -222,10 +229,11 @@ func (p *Peer) RequestOneHeader(hash common.Hash, sink chan *Response) (*Request
id := rand.Uint64()
req := &Request{
- id: id,
- sink: sink,
- code: GetBlockHeadersMsg,
- want: BlockHeadersMsg,
+ id: id,
+ sink: sink,
+ code: GetBlockHeadersMsg,
+ want: BlockHeadersMsg,
+ numItems: 1,
data: &GetBlockHeadersPacket{
RequestId: id,
GetBlockHeadersRequest: &GetBlockHeadersRequest{
@@ -249,10 +257,11 @@ func (p *Peer) RequestHeadersByHash(origin common.Hash, amount int, skip int, re
id := rand.Uint64()
req := &Request{
- id: id,
- sink: sink,
- code: GetBlockHeadersMsg,
- want: BlockHeadersMsg,
+ id: id,
+ sink: sink,
+ code: GetBlockHeadersMsg,
+ want: BlockHeadersMsg,
+ numItems: amount,
data: &GetBlockHeadersPacket{
RequestId: id,
GetBlockHeadersRequest: &GetBlockHeadersRequest{
@@ -276,10 +285,11 @@ func (p *Peer) RequestHeadersByNumber(origin uint64, amount int, skip int, rever
id := rand.Uint64()
req := &Request{
- id: id,
- sink: sink,
- code: GetBlockHeadersMsg,
- want: BlockHeadersMsg,
+ id: id,
+ sink: sink,
+ code: GetBlockHeadersMsg,
+ want: BlockHeadersMsg,
+ numItems: amount,
data: &GetBlockHeadersPacket{
RequestId: id,
GetBlockHeadersRequest: &GetBlockHeadersRequest{
@@ -303,10 +313,11 @@ func (p *Peer) RequestBodies(hashes []common.Hash, sink chan *Response) (*Reques
id := rand.Uint64()
req := &Request{
- id: id,
- sink: sink,
- code: GetBlockBodiesMsg,
- want: BlockBodiesMsg,
+ id: id,
+ sink: sink,
+ code: GetBlockBodiesMsg,
+ want: BlockBodiesMsg,
+ numItems: len(hashes),
data: &GetBlockBodiesPacket{
RequestId: id,
GetBlockBodiesRequest: hashes,
@@ -324,10 +335,11 @@ func (p *Peer) RequestReceipts(hashes []common.Hash, sink chan *Response) (*Requ
id := rand.Uint64()
req := &Request{
- id: id,
- sink: sink,
- code: GetReceiptsMsg,
- want: ReceiptsMsg,
+ id: id,
+ sink: sink,
+ code: GetReceiptsMsg,
+ want: ReceiptsMsg,
+ numItems: len(hashes),
data: &GetReceiptsPacket{
RequestId: id,
GetReceiptsRequest: hashes,
@@ -344,7 +356,15 @@ func (p *Peer) RequestTxs(hashes []common.Hash) error {
p.Log().Trace("Fetching batch of transactions", "count", len(hashes))
id := rand.Uint64()
- requestTracker.Track(p.id, p.version, GetPooledTransactionsMsg, PooledTransactionsMsg, id)
+ err := p.tracker.Track(tracker.Request{
+ ID: id,
+ ReqCode: GetPooledTransactionsMsg,
+ RespCode: PooledTransactionsMsg,
+ Size: len(hashes),
+ })
+ if err != nil {
+ return err
+ }
return p2p.Send(p.rw, GetPooledTransactionsMsg, &GetPooledTransactionsPacket{
RequestId: id,
GetPooledTransactionsRequest: hashes,
diff --git a/eth/protocols/eth/protocol.go b/eth/protocols/eth/protocol.go
index 7c41e7a996..6ab800f4f4 100644
--- a/eth/protocols/eth/protocol.go
+++ b/eth/protocols/eth/protocol.go
@@ -49,6 +49,9 @@ var protocolLengths = map[uint]uint64{ETH68: 17, ETH69: 18}
// maxMessageSize is the maximum cap on the size of a protocol message.
const maxMessageSize = 10 * 1024 * 1024
+// This is the maximum number of transactions in a Transactions message.
+const maxTransactionAnnouncements = 5000
+
const (
StatusMsg = 0x00
NewBlockHashesMsg = 0x01
@@ -127,7 +130,9 @@ func (p *NewBlockHashesPacket) Unpack() ([]common.Hash, []uint64) {
}
// TransactionsPacket is the network packet for broadcasting new transactions.
-type TransactionsPacket []*types.Transaction
+type TransactionsPacket struct {
+ rlp.RawList[*types.Transaction]
+}
// GetBlockHeadersRequest represents a block header query.
type GetBlockHeadersRequest struct {
@@ -185,7 +190,7 @@ type BlockHeadersRequest []*types.Header
// BlockHeadersPacket represents a block header response over with request ID wrapping.
type BlockHeadersPacket struct {
RequestId uint64
- BlockHeadersRequest
+ List rlp.RawList[*types.Header]
}
// BlockHeadersRLPResponse represents a block header response, to use when we already
@@ -213,14 +218,11 @@ type GetBlockBodiesPacket struct {
GetBlockBodiesRequest
}
-// BlockBodiesResponse is the network packet for block content distribution.
-type BlockBodiesResponse []*BlockBody
-
// BlockBodiesPacket is the network packet for block content distribution with
// request ID wrapping.
type BlockBodiesPacket struct {
RequestId uint64
- BlockBodiesResponse
+ List rlp.RawList[BlockBody]
}
// BlockBodiesRLPResponse is used for replying to block body requests, in cases
@@ -234,25 +236,14 @@ type BlockBodiesRLPPacket struct {
BlockBodiesRLPResponse
}
+// BlockBodiesResponse is the network packet for block content distribution.
+type BlockBodiesResponse []BlockBody
+
// BlockBody represents the data content of a single block.
type BlockBody struct {
- Transactions []*types.Transaction // Transactions contained within a block
- Uncles []*types.Header // Uncles contained within a block
- Withdrawals []*types.Withdrawal `rlp:"optional"` // Withdrawals contained within a block
-}
-
-// Unpack retrieves the transactions and uncles from the range packet and returns
-// them in a split flat format that's more consistent with the internal data structures.
-func (p *BlockBodiesResponse) Unpack() ([][]*types.Transaction, [][]*types.Header, [][]*types.Withdrawal) {
- var (
- txset = make([][]*types.Transaction, len(*p))
- uncleset = make([][]*types.Header, len(*p))
- withdrawalset = make([][]*types.Withdrawal, len(*p))
- )
- for i, body := range *p {
- txset[i], uncleset[i], withdrawalset[i] = body.Transactions, body.Uncles, body.Withdrawals
- }
- return txset, uncleset, withdrawalset
+ Transactions rlp.RawList[*types.Transaction]
+ Uncles rlp.RawList[*types.Header]
+ Withdrawals *rlp.RawList[*types.Withdrawal] `rlp:"optional"`
}
// GetReceiptsRequest represents a block receipts query.
@@ -271,15 +262,15 @@ type ReceiptsResponse []types.Receipts
type ReceiptsList interface {
*ReceiptList68 | *ReceiptList69
setBuffers(*receiptListBuffers)
- EncodeForStorage() rlp.RawValue
- types.DerivableList
+ EncodeForStorage() (rlp.RawValue, error)
+ Derivable() types.DerivableList
}
// ReceiptsPacket is the network packet for block receipts distribution with
// request ID wrapping.
type ReceiptsPacket[L ReceiptsList] struct {
RequestId uint64
- List []L
+ List rlp.RawList[L]
}
// ReceiptsRLPResponse is used for receipts, when we already have it encoded
@@ -314,7 +305,7 @@ type PooledTransactionsResponse []*types.Transaction
// with request ID wrapping.
type PooledTransactionsPacket struct {
RequestId uint64
- PooledTransactionsResponse
+ List rlp.RawList[*types.Transaction]
}
// PooledTransactionsRLPResponse is the network packet for transaction distribution, used
@@ -367,8 +358,8 @@ func (*NewPooledTransactionHashesPacket) Kind() byte { return NewPooledTransac
func (*GetPooledTransactionsRequest) Name() string { return "GetPooledTransactions" }
func (*GetPooledTransactionsRequest) Kind() byte { return GetPooledTransactionsMsg }
-func (*PooledTransactionsResponse) Name() string { return "PooledTransactions" }
-func (*PooledTransactionsResponse) Kind() byte { return PooledTransactionsMsg }
+func (*PooledTransactionsPacket) Name() string { return "PooledTransactions" }
+func (*PooledTransactionsPacket) Kind() byte { return PooledTransactionsMsg }
func (*GetReceiptsRequest) Name() string { return "GetReceipts" }
func (*GetReceiptsRequest) Kind() byte { return GetReceiptsMsg }
diff --git a/eth/protocols/eth/protocol_test.go b/eth/protocols/eth/protocol_test.go
index 8a2559a6c5..e37d72dcd6 100644
--- a/eth/protocols/eth/protocol_test.go
+++ b/eth/protocols/eth/protocol_test.go
@@ -78,34 +78,34 @@ func TestEmptyMessages(t *testing.T) {
for i, msg := range []any{
// Headers
GetBlockHeadersPacket{1111, nil},
- BlockHeadersPacket{1111, nil},
// Bodies
GetBlockBodiesPacket{1111, nil},
- BlockBodiesPacket{1111, nil},
BlockBodiesRLPPacket{1111, nil},
// Receipts
GetReceiptsPacket{1111, nil},
// Transactions
GetPooledTransactionsPacket{1111, nil},
- PooledTransactionsPacket{1111, nil},
PooledTransactionsRLPPacket{1111, nil},
// Headers
- BlockHeadersPacket{1111, BlockHeadersRequest([]*types.Header{})},
+ BlockHeadersPacket{1111, encodeRL([]*types.Header{})},
// Bodies
GetBlockBodiesPacket{1111, GetBlockBodiesRequest([]common.Hash{})},
- BlockBodiesPacket{1111, BlockBodiesResponse([]*BlockBody{})},
+ BlockBodiesPacket{1111, encodeRL([]BlockBody{})},
BlockBodiesRLPPacket{1111, BlockBodiesRLPResponse([]rlp.RawValue{})},
// Receipts
GetReceiptsPacket{1111, GetReceiptsRequest([]common.Hash{})},
- ReceiptsPacket[*ReceiptList68]{1111, []*ReceiptList68{}},
- ReceiptsPacket[*ReceiptList69]{1111, []*ReceiptList69{}},
+ ReceiptsPacket[*ReceiptList68]{1111, encodeRL([]*ReceiptList68{})},
+ ReceiptsPacket[*ReceiptList69]{1111, encodeRL([]*ReceiptList69{})},
// Transactions
GetPooledTransactionsPacket{1111, GetPooledTransactionsRequest([]common.Hash{})},
- PooledTransactionsPacket{1111, PooledTransactionsResponse([]*types.Transaction{})},
+ PooledTransactionsPacket{1111, encodeRL([]*types.Transaction{})},
PooledTransactionsRLPPacket{1111, PooledTransactionsRLPResponse([]rlp.RawValue{})},
} {
- if have, _ := rlp.EncodeToBytes(msg); !bytes.Equal(have, want) {
+ have, err := rlp.EncodeToBytes(msg)
+ if err != nil {
+ t.Errorf("test %d, type %T, error: %v", i, msg, err)
+ } else if !bytes.Equal(have, want) {
t.Errorf("test %d, type %T, have\n\t%x\nwant\n\t%x", i, msg, have, want)
}
}
@@ -116,7 +116,7 @@ func TestMessages(t *testing.T) {
// Some basic structs used during testing
var (
header *types.Header
- blockBody *BlockBody
+ blockBody BlockBody
blockBodyRlp rlp.RawValue
txs []*types.Transaction
txRlps []rlp.RawValue
@@ -150,9 +150,9 @@ func TestMessages(t *testing.T) {
}
}
// init the block body data, both object and rlp form
- blockBody = &BlockBody{
- Transactions: txs,
- Uncles: []*types.Header{header},
+ blockBody = BlockBody{
+ Transactions: encodeRL(txs),
+ Uncles: encodeRL([]*types.Header{header}),
}
blockBodyRlp, err = rlp.EncodeToBytes(blockBody)
if err != nil {
@@ -211,7 +211,7 @@ func TestMessages(t *testing.T) {
common.FromHex("ca820457c682270f050580"),
},
{
- BlockHeadersPacket{1111, BlockHeadersRequest{header}},
+ BlockHeadersPacket{1111, encodeRL([]*types.Header{header})},
common.FromHex("f90202820457f901fcf901f9a00000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000940000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000b90100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000008208ae820d0582115c8215b3821a0a827788a00000000000000000000000000000000000000000000000000000000000000000880000000000000000"),
},
{
@@ -219,7 +219,7 @@ func TestMessages(t *testing.T) {
common.FromHex("f847820457f842a000000000000000000000000000000000000000000000000000000000deadc0dea000000000000000000000000000000000000000000000000000000000feedbeef"),
},
{
- BlockBodiesPacket{1111, BlockBodiesResponse([]*BlockBody{blockBody})},
+ BlockBodiesPacket{1111, encodeRL([]BlockBody{blockBody})},
common.FromHex("f902dc820457f902d6f902d3f8d2f867088504a817c8088302e2489435353535353535353535353535353535353535358202008025a064b1702d9298fee62dfeccc57d322a463ad55ca201256d01f62b45b2e1c21c12a064b1702d9298fee62dfeccc57d322a463ad55ca201256d01f62b45b2e1c21c10f867098504a817c809830334509435353535353535353535353535353535353535358202d98025a052f8f61201b2b11a78d6e866abc9c3db2ae8631fa656bfe5cb53668255367afba052f8f61201b2b11a78d6e866abc9c3db2ae8631fa656bfe5cb53668255367afbf901fcf901f9a00000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000940000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000b90100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000008208ae820d0582115c8215b3821a0a827788a00000000000000000000000000000000000000000000000000000000000000000880000000000000000"),
},
{ // Identical to non-rlp-shortcut version
@@ -231,7 +231,7 @@ func TestMessages(t *testing.T) {
common.FromHex("f847820457f842a000000000000000000000000000000000000000000000000000000000deadc0dea000000000000000000000000000000000000000000000000000000000feedbeef"),
},
{
- ReceiptsPacket[*ReceiptList68]{1111, []*ReceiptList68{NewReceiptList68(receipts)}},
+ ReceiptsPacket[*ReceiptList68]{1111, encodeRL([]*ReceiptList68{NewReceiptList68(receipts)})},
common.FromHex("f902e6820457f902e0f902ddf901688082014db9010000000000000010000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000014000000000000000000000000000000000000000000000000000000000000000000000000000010000000000000000000000000004000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000000000000000000000000400000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000f85ff85d940000000000000000000000000000000000000011f842a0000000000000000000000000000000000000000000000000000000000000deada0000000000000000000000000000000000000000000000000000000000000beef830100ffb9016f01f9016b018201bcb9010000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000800000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000040000000001000000000000000000000000000000000000000000000000040000000000000000000000000004000000000000000000000000000000000000000000000000000000008000400000000000000000000000000000000000000000000000000000000000000000000000000000040f862f860940000000000000000000000000000000000000022f842a00000000000000000000000000000000000000000000000000000000000005668a0000000000000000000000000000000000000000000000000000000000000977386020f0f0f0608"),
},
{
@@ -240,7 +240,7 @@ func TestMessages(t *testing.T) {
common.FromHex("f902e6820457f902e0f902ddf901688082014db9010000000000000010000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000014000000000000000000000000000000000000000000000000000000000000000000000000000010000000000000000000000000004000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000000000000000000000000400000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000f85ff85d940000000000000000000000000000000000000011f842a0000000000000000000000000000000000000000000000000000000000000deada0000000000000000000000000000000000000000000000000000000000000beef830100ffb9016f01f9016b018201bcb9010000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000800000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000040000000001000000000000000000000000000000000000000000000000040000000000000000000000000004000000000000000000000000000000000000000000000000000000008000400000000000000000000000000000000000000000000000000000000000000000000000000000040f862f860940000000000000000000000000000000000000022f842a00000000000000000000000000000000000000000000000000000000000005668a0000000000000000000000000000000000000000000000000000000000000977386020f0f0f0608"),
},
{
- ReceiptsPacket[*ReceiptList69]{1111, []*ReceiptList69{NewReceiptList69(receipts)}},
+ ReceiptsPacket[*ReceiptList69]{1111, encodeRL([]*ReceiptList69{NewReceiptList69(receipts)})},
common.FromHex("f8da820457f8d5f8d3f866808082014df85ff85d940000000000000000000000000000000000000011f842a0000000000000000000000000000000000000000000000000000000000000deada0000000000000000000000000000000000000000000000000000000000000beef830100fff86901018201bcf862f860940000000000000000000000000000000000000022f842a00000000000000000000000000000000000000000000000000000000000005668a0000000000000000000000000000000000000000000000000000000000000977386020f0f0f0608"),
},
{
@@ -248,7 +248,7 @@ func TestMessages(t *testing.T) {
common.FromHex("f847820457f842a000000000000000000000000000000000000000000000000000000000deadc0dea000000000000000000000000000000000000000000000000000000000feedbeef"),
},
{
- PooledTransactionsPacket{1111, PooledTransactionsResponse(txs)},
+ PooledTransactionsPacket{1111, encodeRL(txs)},
common.FromHex("f8d7820457f8d2f867088504a817c8088302e2489435353535353535353535353535353535353535358202008025a064b1702d9298fee62dfeccc57d322a463ad55ca201256d01f62b45b2e1c21c12a064b1702d9298fee62dfeccc57d322a463ad55ca201256d01f62b45b2e1c21c10f867098504a817c809830334509435353535353535353535353535353535353535358202d98025a052f8f61201b2b11a78d6e866abc9c3db2ae8631fa656bfe5cb53668255367afba052f8f61201b2b11a78d6e866abc9c3db2ae8631fa656bfe5cb53668255367afb"),
},
{
diff --git a/eth/protocols/eth/receipt.go b/eth/protocols/eth/receipt.go
index 45c4766b17..06956df9f2 100644
--- a/eth/protocols/eth/receipt.go
+++ b/eth/protocols/eth/receipt.go
@@ -50,8 +50,8 @@ func newReceipt(tr *types.Receipt) Receipt {
}
// decode68 parses a receipt in the eth/68 network encoding.
-func (r *Receipt) decode68(buf *receiptListBuffers, s *rlp.Stream) error {
- k, size, err := s.Kind()
+func (r *Receipt) decode68(b []byte) error {
+ k, content, _, err := rlp.Split(b)
if err != nil {
return err
}
@@ -59,65 +59,84 @@ func (r *Receipt) decode68(buf *receiptListBuffers, s *rlp.Stream) error {
*r = Receipt{}
if k == rlp.List {
// Legacy receipt.
- return r.decodeInnerList(s, false, true)
+ return r.decodeInnerList(b, false, true)
}
// Typed receipt.
- if size < 2 || size > maxReceiptSize {
- return fmt.Errorf("invalid receipt size %d", size)
+ if len(content) < 2 || len(content) > maxReceiptSize {
+ return fmt.Errorf("invalid receipt size %d", len(content))
}
- buf.tmp.Reset()
- buf.tmp.Grow(int(size))
- payload := buf.tmp.Bytes()[:int(size)]
- if err := s.ReadBytes(payload); err != nil {
- return err
- }
- r.TxType = payload[0]
- s2 := rlp.NewStream(bytes.NewReader(payload[1:]), 0)
- return r.decodeInnerList(s2, false, true)
+ r.TxType = content[0]
+ return r.decodeInnerList(content[1:], false, true)
}
// decode69 parses a receipt in the eth/69 network encoding.
-func (r *Receipt) decode69(s *rlp.Stream) error {
+func (r *Receipt) decode69(b []byte) error {
*r = Receipt{}
- return r.decodeInnerList(s, true, false)
+ return r.decodeInnerList(b, true, false)
}
// decodeDatabase parses a receipt in the basic database encoding.
-func (r *Receipt) decodeDatabase(txType byte, s *rlp.Stream) error {
+func (r *Receipt) decodeDatabase(txType byte, b []byte) error {
*r = Receipt{TxType: txType}
- return r.decodeInnerList(s, false, false)
+ return r.decodeInnerList(b, false, false)
}
-func (r *Receipt) decodeInnerList(s *rlp.Stream, readTxType, readBloom bool) error {
- _, err := s.List()
+func (r *Receipt) decodeInnerList(input []byte, readTxType, readBloom bool) error {
+ input, _, err := rlp.SplitList(input)
if err != nil {
- return err
+ return fmt.Errorf("inner list: %v", err)
}
+
+ // txType
if readTxType {
- r.TxType, err = s.Uint8()
+ var txType uint64
+ txType, input, err = rlp.SplitUint64(input)
if err != nil {
return fmt.Errorf("invalid txType: %w", err)
}
+ if txType > 0x7f {
+ return fmt.Errorf("invalid txType: too large")
+ }
+ r.TxType = byte(txType)
}
- r.PostStateOrStatus, err = s.Bytes()
+
+ // status
+ r.PostStateOrStatus, input, err = rlp.SplitString(input)
if err != nil {
return fmt.Errorf("invalid postStateOrStatus: %w", err)
}
- r.GasUsed, err = s.Uint64()
+ if len(r.PostStateOrStatus) > 1 && len(r.PostStateOrStatus) != 32 {
+ return fmt.Errorf("invalid postStateOrStatus length %d", len(r.PostStateOrStatus))
+ }
+
+ // gas
+ r.GasUsed, input, err = rlp.SplitUint64(input)
if err != nil {
return fmt.Errorf("invalid gasUsed: %w", err)
}
+
+ // bloom
if readBloom {
- var b types.Bloom
- if err := s.ReadBytes(b[:]); err != nil {
+ var bloomBytes []byte
+ bloomBytes, input, err = rlp.SplitString(input)
+ if err != nil {
return fmt.Errorf("invalid bloom: %v", err)
}
+ if len(bloomBytes) != types.BloomByteLength {
+ return fmt.Errorf("invalid bloom length %d", len(bloomBytes))
+ }
}
- r.Logs, err = s.Raw()
+
+ // logs
+ _, rest, err := rlp.SplitList(input)
if err != nil {
return fmt.Errorf("invalid logs: %w", err)
}
- return s.ListEnd()
+ if len(rest) != 0 {
+ return fmt.Errorf("junk at end of receipt")
+ }
+ r.Logs = input
+ return nil
}
// encodeForStorage produces the the storage encoding, i.e. the result matches
@@ -223,32 +242,45 @@ func initBuffers(buf **receiptListBuffers) {
}
// encodeForStorage encodes a list of receipts for the database.
-func (buf *receiptListBuffers) encodeForStorage(rs []Receipt) rlp.RawValue {
+func (buf *receiptListBuffers) encodeForStorage(rs rlp.RawList[Receipt], decode func([]byte, *Receipt) error) (rlp.RawValue, error) {
var out bytes.Buffer
w := &buf.enc
w.Reset(&out)
outer := w.List()
- for _, receipts := range rs {
- receipts.encodeForStorage(w)
+ it := rs.ContentIterator()
+ for it.Next() {
+ var receipt Receipt
+ if err := decode(it.Value(), &receipt); err != nil {
+ return nil, err
+ }
+ receipt.encodeForStorage(w)
+ }
+ if it.Err() != nil {
+ return nil, fmt.Errorf("bad list: %v", it.Err())
}
w.ListEnd(outer)
w.Flush()
- return out.Bytes()
+ return out.Bytes(), nil
}
// ReceiptList68 is a block receipt list as downloaded by eth/68.
// This also implements types.DerivableList for validation purposes.
type ReceiptList68 struct {
buf *receiptListBuffers
- items []Receipt
+ items rlp.RawList[Receipt]
}
// NewReceiptList68 creates a receipt list.
// This is slow, and exists for testing purposes.
func NewReceiptList68(trs []*types.Receipt) *ReceiptList68 {
- rl := &ReceiptList68{items: make([]Receipt, len(trs))}
- for i, tr := range trs {
- rl.items[i] = newReceipt(tr)
+ rl := new(ReceiptList68)
+ initBuffers(&rl.buf)
+ enc := rlp.NewEncoderBuffer(nil)
+ for _, tr := range trs {
+ r := newReceipt(tr)
+ r.encodeForNetwork68(rl.buf, &enc)
+ rl.items.AppendRaw(enc.ToBytes())
+ enc.Reset(nil)
}
return rl
}
@@ -266,17 +298,12 @@ func blockReceiptsToNetwork68(blockReceipts, blockBody rlp.RawValue) ([]byte, er
buf receiptListBuffers
)
blockReceiptIter, _ := rlp.NewListIterator(blockReceipts)
- innerReader := bytes.NewReader(nil)
- innerStream := rlp.NewStream(innerReader, 0)
w := rlp.NewEncoderBuffer(&out)
outer := w.List()
for i := 0; blockReceiptIter.Next(); i++ {
- content := blockReceiptIter.Value()
- innerReader.Reset(content)
- innerStream.Reset(innerReader, uint64(len(content)))
- var r Receipt
txType, _ := nextTxType()
- if err := r.decodeDatabase(txType, innerStream); err != nil {
+ var r Receipt
+ if err := r.decodeDatabase(txType, blockReceiptIter.Value()); err != nil {
return nil, fmt.Errorf("invalid database receipt %d: %v", i, err)
}
r.encodeForNetwork68(&buf, &w)
@@ -292,64 +319,51 @@ func (rl *ReceiptList68) setBuffers(buf *receiptListBuffers) {
}
// EncodeForStorage encodes the receipts for storage into the database.
-func (rl *ReceiptList68) EncodeForStorage() rlp.RawValue {
+func (rl *ReceiptList68) EncodeForStorage() (rlp.RawValue, error) {
initBuffers(&rl.buf)
- return rl.buf.encodeForStorage(rl.items)
+ return rl.buf.encodeForStorage(rl.items, func(data []byte, r *Receipt) error {
+ return r.decode68(data)
+ })
}
-// Len implements types.DerivableList.
-func (rl *ReceiptList68) Len() int {
- return len(rl.items)
-}
-
-// EncodeIndex implements types.DerivableList.
-func (rl *ReceiptList68) EncodeIndex(i int, out *bytes.Buffer) {
+// Derivable turns the receipts into a list that can derive the root hash.
+func (rl *ReceiptList68) Derivable() types.DerivableList {
initBuffers(&rl.buf)
- rl.items[i].encodeForHash(rl.buf, out)
+ return newDerivableRawList(&rl.items, func(data []byte, outbuf *bytes.Buffer) {
+ var r Receipt
+ if r.decode68(data) == nil {
+ r.encodeForHash(rl.buf, outbuf)
+ }
+ })
}
// DecodeRLP decodes a list of receipts from the network format.
func (rl *ReceiptList68) DecodeRLP(s *rlp.Stream) error {
- initBuffers(&rl.buf)
- if _, err := s.List(); err != nil {
- return err
- }
- for i := 0; s.MoreDataInList(); i++ {
- var item Receipt
- err := item.decode68(rl.buf, s)
- if err != nil {
- return fmt.Errorf("receipt %d: %v", i, err)
- }
- rl.items = append(rl.items, item)
- }
- return s.ListEnd()
+ return rl.items.DecodeRLP(s)
}
// EncodeRLP encodes the list into the network format of eth/68.
-func (rl *ReceiptList68) EncodeRLP(_w io.Writer) error {
- initBuffers(&rl.buf)
- w := rlp.NewEncoderBuffer(_w)
- outer := w.List()
- for i := range rl.items {
- rl.items[i].encodeForNetwork68(rl.buf, &w)
- }
- w.ListEnd(outer)
- return w.Flush()
+func (rl *ReceiptList68) EncodeRLP(w io.Writer) error {
+ return rl.items.EncodeRLP(w)
}
// ReceiptList69 is the block receipt list as downloaded by eth/69.
// This implements types.DerivableList for validation purposes.
type ReceiptList69 struct {
buf *receiptListBuffers
- items []Receipt
+ items rlp.RawList[Receipt]
}
// NewReceiptList69 creates a receipt list.
// This is slow, and exists for testing purposes.
func NewReceiptList69(trs []*types.Receipt) *ReceiptList69 {
- rl := &ReceiptList69{items: make([]Receipt, len(trs))}
- for i, tr := range trs {
- rl.items[i] = newReceipt(tr)
+ rl := new(ReceiptList69)
+ enc := rlp.NewEncoderBuffer(nil)
+ for _, tr := range trs {
+ r := newReceipt(tr)
+ r.encodeForNetwork69(&enc)
+ rl.items.AppendRaw(enc.ToBytes())
+ enc.Reset(nil)
}
return rl
}
@@ -360,47 +374,32 @@ func (rl *ReceiptList69) setBuffers(buf *receiptListBuffers) {
}
// EncodeForStorage encodes the receipts for storage into the database.
-func (rl *ReceiptList69) EncodeForStorage() rlp.RawValue {
+func (rl *ReceiptList69) EncodeForStorage() (rlp.RawValue, error) {
initBuffers(&rl.buf)
- return rl.buf.encodeForStorage(rl.items)
+ return rl.buf.encodeForStorage(rl.items, func(data []byte, r *Receipt) error {
+ return r.decode69(data)
+ })
}
-// Len implements types.DerivableList.
-func (rl *ReceiptList69) Len() int {
- return len(rl.items)
-}
-
-// EncodeIndex implements types.DerivableList.
-func (rl *ReceiptList69) EncodeIndex(i int, out *bytes.Buffer) {
+// Derivable turns the receipts into a list that can derive the root hash.
+func (rl *ReceiptList69) Derivable() types.DerivableList {
initBuffers(&rl.buf)
- rl.items[i].encodeForHash(rl.buf, out)
+ return newDerivableRawList(&rl.items, func(data []byte, outbuf *bytes.Buffer) {
+ var r Receipt
+ if r.decode69(data) == nil {
+ r.encodeForHash(rl.buf, outbuf)
+ }
+ })
}
// DecodeRLP decodes a list receipts from the network format.
func (rl *ReceiptList69) DecodeRLP(s *rlp.Stream) error {
- if _, err := s.List(); err != nil {
- return err
- }
- for i := 0; s.MoreDataInList(); i++ {
- var item Receipt
- err := item.decode69(s)
- if err != nil {
- return fmt.Errorf("receipt %d: %v", i, err)
- }
- rl.items = append(rl.items, item)
- }
- return s.ListEnd()
+ return rl.items.DecodeRLP(s)
}
// EncodeRLP encodes the list into the network format of eth/69.
-func (rl *ReceiptList69) EncodeRLP(_w io.Writer) error {
- w := rlp.NewEncoderBuffer(_w)
- outer := w.List()
- for i := range rl.items {
- rl.items[i].encodeForNetwork69(&w)
- }
- w.ListEnd(outer)
- return w.Flush()
+func (rl *ReceiptList69) EncodeRLP(w io.Writer) error {
+ return rl.items.EncodeRLP(w)
}
// blockReceiptsToNetwork69 takes a slice of rlp-encoded receipts, and transactions,
diff --git a/eth/protocols/eth/receipt_test.go b/eth/protocols/eth/receipt_test.go
index 3c73c07396..39a2728f7f 100644
--- a/eth/protocols/eth/receipt_test.go
+++ b/eth/protocols/eth/receipt_test.go
@@ -63,6 +63,18 @@ var receiptsTests = []struct {
input: []types.ReceiptForStorage{{CumulativeGasUsed: 555, Status: 1, Logs: receiptsTestLogs2}},
txs: []*types.Transaction{types.NewTx(&types.AccessListTx{})},
},
+ {
+ input: []types.ReceiptForStorage{
+ {CumulativeGasUsed: 111, PostState: common.HexToHash("0x1111").Bytes(), Logs: receiptsTestLogs1},
+ {CumulativeGasUsed: 222, Status: 0, Logs: receiptsTestLogs2},
+ {CumulativeGasUsed: 333, Status: 1, Logs: nil},
+ },
+ txs: []*types.Transaction{
+ types.NewTx(&types.LegacyTx{}),
+ types.NewTx(&types.AccessListTx{}),
+ types.NewTx(&types.DynamicFeeTx{}),
+ },
+ },
}
func init() {
@@ -103,7 +115,10 @@ func TestReceiptList69(t *testing.T) {
if err := rlp.DecodeBytes(network, &rl); err != nil {
t.Fatalf("test[%d]: can't decode network receipts: %v", i, err)
}
- rlStorageEnc := rl.EncodeForStorage()
+ rlStorageEnc, err := rl.EncodeForStorage()
+ if err != nil {
+ t.Fatalf("test[%d]: error from EncodeForStorage: %v", i, err)
+ }
if !bytes.Equal(rlStorageEnc, canonDB) {
t.Fatalf("test[%d]: re-encoded receipts not equal\nhave: %x\nwant: %x", i, rlStorageEnc, canonDB)
}
@@ -113,7 +128,7 @@ func TestReceiptList69(t *testing.T) {
}
// compute root hash from ReceiptList69 and compare.
- responseHash := types.DeriveSha(&rl, trie.NewStackTrie(nil))
+ responseHash := types.DeriveSha(rl.Derivable(), trie.NewStackTrie(nil))
if responseHash != test.root {
t.Fatalf("test[%d]: wrong root hash from ReceiptList69\nhave: %v\nwant: %v", i, responseHash, test.root)
}
@@ -140,7 +155,10 @@ func TestReceiptList68(t *testing.T) {
if err := rlp.DecodeBytes(network, &rl); err != nil {
t.Fatalf("test[%d]: can't decode network receipts: %v", i, err)
}
- rlStorageEnc := rl.EncodeForStorage()
+ rlStorageEnc, err := rl.EncodeForStorage()
+ if err != nil {
+ t.Fatalf("test[%d]: error from EncodeForStorage: %v", i, err)
+ }
if !bytes.Equal(rlStorageEnc, canonDB) {
t.Fatalf("test[%d]: re-encoded receipts not equal\nhave: %x\nwant: %x", i, rlStorageEnc, canonDB)
}
@@ -150,7 +168,7 @@ func TestReceiptList68(t *testing.T) {
}
// compute root hash from ReceiptList68 and compare.
- responseHash := types.DeriveSha(&rl, trie.NewStackTrie(nil))
+ responseHash := types.DeriveSha(rl.Derivable(), trie.NewStackTrie(nil))
if responseHash != test.root {
t.Fatalf("test[%d]: wrong root hash from ReceiptList68\nhave: %v\nwant: %v", i, responseHash, test.root)
}
diff --git a/eth/protocols/snap/handler.go b/eth/protocols/snap/handler.go
index 3249720f90..071a0419fb 100644
--- a/eth/protocols/snap/handler.go
+++ b/eth/protocols/snap/handler.go
@@ -31,6 +31,8 @@ import (
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr"
+ "github.com/ethereum/go-ethereum/p2p/tracker"
+ "github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie"
"github.com/ethereum/go-ethereum/trie/trienode"
"github.com/ethereum/go-ethereum/triedb/database"
@@ -94,6 +96,7 @@ func MakeProtocols(backend Backend) []p2p.Protocol {
Length: protocolLengths[version],
Run: func(p *p2p.Peer, rw p2p.MsgReadWriter) error {
return backend.RunPeer(NewPeer(version, p, rw), func(peer *Peer) error {
+ defer peer.Close()
return Handle(backend, peer)
})
},
@@ -149,7 +152,6 @@ func HandleMessage(backend Backend, peer *Peer) error {
// Handle the message depending on its contents
switch {
case msg.Code == GetAccountRangeMsg:
- // Decode the account retrieval request
var req GetAccountRangePacket
if err := msg.Decode(&req); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
@@ -165,23 +167,40 @@ func HandleMessage(backend Backend, peer *Peer) error {
})
case msg.Code == AccountRangeMsg:
- // A range of accounts arrived to one of our previous requests
- res := new(AccountRangePacket)
+ res := new(accountRangeInput)
if err := msg.Decode(res); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
}
+
+ // Check response validity.
+ if len := res.Proof.Len(); len > 128 {
+ return fmt.Errorf("AccountRange: invalid proof (length %d)", len)
+ }
+ tresp := tracker.Response{ID: res.ID, MsgCode: AccountRangeMsg, Size: len(res.Accounts.Content())}
+ if err := peer.tracker.Fulfil(tresp); err != nil {
+ return err
+ }
+
+ // Decode.
+ accounts, err := res.Accounts.Items()
+ if err != nil {
+ return fmt.Errorf("AccountRange: invalid accounts list: %v", err)
+ }
+ proof, err := res.Proof.Items()
+ if err != nil {
+ return fmt.Errorf("AccountRange: invalid proof: %v", err)
+ }
+
// Ensure the range is monotonically increasing
- for i := 1; i < len(res.Accounts); i++ {
- if bytes.Compare(res.Accounts[i-1].Hash[:], res.Accounts[i].Hash[:]) >= 0 {
- return fmt.Errorf("accounts not monotonically increasing: #%d [%x] vs #%d [%x]", i-1, res.Accounts[i-1].Hash[:], i, res.Accounts[i].Hash[:])
+ for i := 1; i < len(accounts); i++ {
+ if bytes.Compare(accounts[i-1].Hash[:], accounts[i].Hash[:]) >= 0 {
+ return fmt.Errorf("accounts not monotonically increasing: #%d [%x] vs #%d [%x]", i-1, accounts[i-1].Hash[:], i, accounts[i].Hash[:])
}
}
- requestTracker.Fulfil(peer.id, peer.version, AccountRangeMsg, res.ID)
- return backend.Handle(peer, res)
+ return backend.Handle(peer, &AccountRangePacket{res.ID, accounts, proof})
case msg.Code == GetStorageRangesMsg:
- // Decode the storage retrieval request
var req GetStorageRangesPacket
if err := msg.Decode(&req); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
@@ -197,25 +216,42 @@ func HandleMessage(backend Backend, peer *Peer) error {
})
case msg.Code == StorageRangesMsg:
- // A range of storage slots arrived to one of our previous requests
- res := new(StorageRangesPacket)
+ res := new(storageRangesInput)
if err := msg.Decode(res); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
}
+
+ // Check response validity.
+ if len := res.Proof.Len(); len > 128 {
+ return fmt.Errorf("StorageRangesMsg: invalid proof (length %d)", len)
+ }
+ tresp := tracker.Response{ID: res.ID, MsgCode: StorageRangesMsg, Size: len(res.Slots.Content())}
+ if err := peer.tracker.Fulfil(tresp); err != nil {
+ return fmt.Errorf("StorageRangesMsg: %w", err)
+ }
+
+ // Decode.
+ slotLists, err := res.Slots.Items()
+ if err != nil {
+ return fmt.Errorf("AccountRange: invalid accounts list: %v", err)
+ }
+ proof, err := res.Proof.Items()
+ if err != nil {
+ return fmt.Errorf("AccountRange: invalid proof: %v", err)
+ }
+
// Ensure the ranges are monotonically increasing
- for i, slots := range res.Slots {
+ for i, slots := range slotLists {
for j := 1; j < len(slots); j++ {
if bytes.Compare(slots[j-1].Hash[:], slots[j].Hash[:]) >= 0 {
return fmt.Errorf("storage slots not monotonically increasing for account #%d: #%d [%x] vs #%d [%x]", i, j-1, slots[j-1].Hash[:], j, slots[j].Hash[:])
}
}
}
- requestTracker.Fulfil(peer.id, peer.version, StorageRangesMsg, res.ID)
- return backend.Handle(peer, res)
+ return backend.Handle(peer, &StorageRangesPacket{res.ID, slotLists, proof})
case msg.Code == GetByteCodesMsg:
- // Decode bytecode retrieval request
var req GetByteCodesPacket
if err := msg.Decode(&req); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
@@ -230,17 +266,25 @@ func HandleMessage(backend Backend, peer *Peer) error {
})
case msg.Code == ByteCodesMsg:
- // A batch of byte codes arrived to one of our previous requests
- res := new(ByteCodesPacket)
+ res := new(byteCodesInput)
if err := msg.Decode(res); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
}
- requestTracker.Fulfil(peer.id, peer.version, ByteCodesMsg, res.ID)
- return backend.Handle(peer, res)
+ length := res.Codes.Len()
+ tresp := tracker.Response{ID: res.ID, MsgCode: ByteCodesMsg, Size: length}
+ if err := peer.tracker.Fulfil(tresp); err != nil {
+ return fmt.Errorf("ByteCodes: %w", err)
+ }
+
+ codes, err := res.Codes.Items()
+ if err != nil {
+ return fmt.Errorf("ByteCodes: %w", err)
+ }
+
+ return backend.Handle(peer, &ByteCodesPacket{res.ID, codes})
case msg.Code == GetTrieNodesMsg:
- // Decode trie node retrieval request
var req GetTrieNodesPacket
if err := msg.Decode(&req); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
@@ -257,14 +301,21 @@ func HandleMessage(backend Backend, peer *Peer) error {
})
case msg.Code == TrieNodesMsg:
- // A batch of trie nodes arrived to one of our previous requests
- res := new(TrieNodesPacket)
+ res := new(trieNodesInput)
if err := msg.Decode(res); err != nil {
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
}
- requestTracker.Fulfil(peer.id, peer.version, TrieNodesMsg, res.ID)
- return backend.Handle(peer, res)
+ tresp := tracker.Response{ID: res.ID, MsgCode: TrieNodesMsg, Size: res.Nodes.Len()}
+ if err := peer.tracker.Fulfil(tresp); err != nil {
+ return fmt.Errorf("TrieNodes: %w", err)
+ }
+ nodes, err := res.Nodes.Items()
+ if err != nil {
+ return fmt.Errorf("TrieNodes: %w", err)
+ }
+
+ return backend.Handle(peer, &TrieNodesPacket{res.ID, nodes})
default:
return fmt.Errorf("%w: %v", errInvalidMsgCode, msg.Code)
@@ -512,21 +563,32 @@ func ServiceGetTrieNodesQuery(chain *core.BlockChain, req *GetTrieNodesPacket, s
if reader == nil {
reader, _ = triedb.StateReader(req.Root)
}
+
// Retrieve trie nodes until the packet size limit is reached
var (
- nodes [][]byte
- bytes uint64
- loads int // Trie hash expansions to count database reads
+ outerIt = req.Paths.ContentIterator()
+ nodes [][]byte
+ bytes uint64
+ loads int // Trie hash expansions to count database reads
)
- for _, pathset := range req.Paths {
- switch len(pathset) {
+ for outerIt.Next() {
+ innerIt, err := rlp.NewListIterator(outerIt.Value())
+ if err != nil {
+ return nodes, err
+ }
+
+ switch innerIt.Count() {
case 0:
// Ensure we penalize invalid requests
return nil, fmt.Errorf("%w: zero-item pathset requested", errBadRequest)
case 1:
// If we're only retrieving an account trie node, fetch it directly
- blob, resolved, err := accTrie.GetNode(pathset[0])
+ accKey := nextBytes(&innerIt)
+ if accKey == nil {
+ return nodes, fmt.Errorf("%w: invalid account node request", errBadRequest)
+ }
+ blob, resolved, err := accTrie.GetNode(accKey)
loads += resolved // always account database reads, even for failures
if err != nil {
break
@@ -535,33 +597,41 @@ func ServiceGetTrieNodesQuery(chain *core.BlockChain, req *GetTrieNodesPacket, s
bytes += uint64(len(blob))
default:
- var stRoot common.Hash
-
// Storage slots requested, open the storage trie and retrieve from there
+ accKey := nextBytes(&innerIt)
+ if accKey == nil {
+ return nodes, fmt.Errorf("%w: invalid account storage request", errBadRequest)
+ }
+ var stRoot common.Hash
if reader == nil {
// We don't have the requested state snapshotted yet (or it is stale),
// but can look up the account via the trie instead.
- account, err := accTrie.GetAccountByHash(common.BytesToHash(pathset[0]))
+ account, err := accTrie.GetAccountByHash(common.BytesToHash(accKey))
loads += 8 // We don't know the exact cost of lookup, this is an estimate
if err != nil || account == nil {
break
}
stRoot = account.Root
} else {
- account, err := reader.Account(common.BytesToHash(pathset[0]))
+ account, err := reader.Account(common.BytesToHash(accKey))
loads++ // always account database reads, even for failures
if err != nil || account == nil {
break
}
stRoot = common.BytesToHash(account.Root)
}
- id := trie.StorageTrieID(req.Root, common.BytesToHash(pathset[0]), stRoot)
+
+ id := trie.StorageTrieID(req.Root, common.BytesToHash(accKey), stRoot)
stTrie, err := trie.NewStateTrie(id, triedb)
loads++ // always account database reads, even for failures
if err != nil {
break
}
- for _, path := range pathset[1:] {
+ for innerIt.Next() {
+ path, _, err := rlp.SplitString(innerIt.Value())
+ if err != nil {
+ return nil, fmt.Errorf("%w: invalid storage key: %v", errBadRequest, err)
+ }
blob, resolved, err := stTrie.GetNode(path)
loads += resolved // always account database reads, even for failures
if err != nil {
@@ -584,6 +654,17 @@ func ServiceGetTrieNodesQuery(chain *core.BlockChain, req *GetTrieNodesPacket, s
return nodes, nil
}
+func nextBytes(it *rlp.Iterator) []byte {
+ if !it.Next() {
+ return nil
+ }
+ content, _, err := rlp.SplitString(it.Value())
+ if err != nil {
+ return nil
+ }
+ return content
+}
+
// NodeInfo represents a short summary of the `snap` sub-protocol metadata
// known about the host peer.
type NodeInfo struct{}
diff --git a/eth/protocols/snap/peer.go b/eth/protocols/snap/peer.go
index c57931678c..0b96de4158 100644
--- a/eth/protocols/snap/peer.go
+++ b/eth/protocols/snap/peer.go
@@ -17,9 +17,13 @@
package snap
import (
+ "time"
+
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p"
+ "github.com/ethereum/go-ethereum/p2p/tracker"
+ "github.com/ethereum/go-ethereum/rlp"
)
// Peer is a collection of relevant information we have about a `snap` peer.
@@ -29,6 +33,7 @@ type Peer struct {
*p2p.Peer // The embedded P2P package peer
rw p2p.MsgReadWriter // Input/output streams for snap
version uint // Protocol version negotiated
+ tracker *tracker.Tracker
logger log.Logger // Contextual logger with the peer id injected
}
@@ -36,22 +41,26 @@ type Peer struct {
// NewPeer creates a wrapper for a network connection and negotiated protocol
// version.
func NewPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) *Peer {
+ cap := p2p.Cap{Name: ProtocolName, Version: version}
id := p.ID().String()
return &Peer{
id: id,
Peer: p,
rw: rw,
version: version,
+ tracker: tracker.New(cap, id, 1*time.Minute),
logger: log.New("peer", id[:8]),
}
}
// NewFakePeer creates a fake snap peer without a backing p2p peer, for testing purposes.
func NewFakePeer(version uint, id string, rw p2p.MsgReadWriter) *Peer {
+ cap := p2p.Cap{Name: ProtocolName, Version: version}
return &Peer{
id: id,
rw: rw,
version: version,
+ tracker: tracker.New(cap, id, 1*time.Minute),
logger: log.New("peer", id[:8]),
}
}
@@ -71,63 +80,99 @@ func (p *Peer) Log() log.Logger {
return p.logger
}
+// Close releases resources associated with the peer.
+func (p *Peer) Close() {
+ p.tracker.Stop()
+}
+
// RequestAccountRange fetches a batch of accounts rooted in a specific account
// trie, starting with the origin.
-func (p *Peer) RequestAccountRange(id uint64, root common.Hash, origin, limit common.Hash, bytes uint64) error {
+func (p *Peer) RequestAccountRange(id uint64, root common.Hash, origin, limit common.Hash, bytes int) error {
p.logger.Trace("Fetching range of accounts", "reqid", id, "root", root, "origin", origin, "limit", limit, "bytes", common.StorageSize(bytes))
- requestTracker.Track(p.id, p.version, GetAccountRangeMsg, AccountRangeMsg, id)
+ err := p.tracker.Track(tracker.Request{
+ ReqCode: GetAccountRangeMsg,
+ RespCode: AccountRangeMsg,
+ ID: id,
+ Size: 2 * bytes,
+ })
+ if err != nil {
+ return err
+ }
return p2p.Send(p.rw, GetAccountRangeMsg, &GetAccountRangePacket{
ID: id,
Root: root,
Origin: origin,
Limit: limit,
- Bytes: bytes,
+ Bytes: uint64(bytes),
})
}
// RequestStorageRanges fetches a batch of storage slots belonging to one or more
// accounts. If slots from only one account is requested, an origin marker may also
// be used to retrieve from there.
-func (p *Peer) RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes uint64) error {
+func (p *Peer) RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes int) error {
if len(accounts) == 1 && origin != nil {
p.logger.Trace("Fetching range of large storage slots", "reqid", id, "root", root, "account", accounts[0], "origin", common.BytesToHash(origin), "limit", common.BytesToHash(limit), "bytes", common.StorageSize(bytes))
} else {
p.logger.Trace("Fetching ranges of small storage slots", "reqid", id, "root", root, "accounts", len(accounts), "first", accounts[0], "bytes", common.StorageSize(bytes))
}
- requestTracker.Track(p.id, p.version, GetStorageRangesMsg, StorageRangesMsg, id)
+
+ p.tracker.Track(tracker.Request{
+ ReqCode: GetStorageRangesMsg,
+ RespCode: StorageRangesMsg,
+ ID: id,
+ Size: 2 * bytes,
+ })
return p2p.Send(p.rw, GetStorageRangesMsg, &GetStorageRangesPacket{
ID: id,
Root: root,
Accounts: accounts,
Origin: origin,
Limit: limit,
- Bytes: bytes,
+ Bytes: uint64(bytes),
})
}
// RequestByteCodes fetches a batch of bytecodes by hash.
-func (p *Peer) RequestByteCodes(id uint64, hashes []common.Hash, bytes uint64) error {
+func (p *Peer) RequestByteCodes(id uint64, hashes []common.Hash, bytes int) error {
p.logger.Trace("Fetching set of byte codes", "reqid", id, "hashes", len(hashes), "bytes", common.StorageSize(bytes))
- requestTracker.Track(p.id, p.version, GetByteCodesMsg, ByteCodesMsg, id)
+ err := p.tracker.Track(tracker.Request{
+ ReqCode: GetByteCodesMsg,
+ RespCode: ByteCodesMsg,
+ ID: id,
+ Size: len(hashes), // ByteCodes is limited by the length of the hash list.
+ })
+ if err != nil {
+ return err
+ }
return p2p.Send(p.rw, GetByteCodesMsg, &GetByteCodesPacket{
ID: id,
Hashes: hashes,
- Bytes: bytes,
+ Bytes: uint64(bytes),
})
}
// RequestTrieNodes fetches a batch of account or storage trie nodes rooted in
-// a specific state trie.
-func (p *Peer) RequestTrieNodes(id uint64, root common.Hash, paths []TrieNodePathSet, bytes uint64) error {
+// a specific state trie. The `count` is the total count of paths being requested.
+func (p *Peer) RequestTrieNodes(id uint64, root common.Hash, count int, paths []TrieNodePathSet, bytes int) error {
p.logger.Trace("Fetching set of trie nodes", "reqid", id, "root", root, "pathsets", len(paths), "bytes", common.StorageSize(bytes))
- requestTracker.Track(p.id, p.version, GetTrieNodesMsg, TrieNodesMsg, id)
+ err := p.tracker.Track(tracker.Request{
+ ReqCode: GetTrieNodesMsg,
+ RespCode: TrieNodesMsg,
+ ID: id,
+ Size: count, // TrieNodes is limited by number of items.
+ })
+ if err != nil {
+ return err
+ }
+ encPaths, _ := rlp.EncodeToRawList(paths)
return p2p.Send(p.rw, GetTrieNodesMsg, &GetTrieNodesPacket{
ID: id,
Root: root,
- Paths: paths,
- Bytes: bytes,
+ Paths: encPaths,
+ Bytes: uint64(bytes),
})
}
diff --git a/eth/protocols/snap/protocol.go b/eth/protocols/snap/protocol.go
index 0db206b081..25fe25822b 100644
--- a/eth/protocols/snap/protocol.go
+++ b/eth/protocols/snap/protocol.go
@@ -78,6 +78,12 @@ type GetAccountRangePacket struct {
Bytes uint64 // Soft limit at which to stop returning data
}
+type accountRangeInput struct {
+ ID uint64 // ID of the request this is a response for
+ Accounts rlp.RawList[*AccountData] // List of consecutive accounts from the trie
+ Proof rlp.RawList[[]byte] // List of trie nodes proving the account range
+}
+
// AccountRangePacket represents an account query response.
type AccountRangePacket struct {
ID uint64 // ID of the request this is a response for
@@ -123,6 +129,12 @@ type GetStorageRangesPacket struct {
Bytes uint64 // Soft limit at which to stop returning data
}
+type storageRangesInput struct {
+ ID uint64 // ID of the request this is a response for
+ Slots rlp.RawList[[]*StorageData] // Lists of consecutive storage slots for the requested accounts
+ Proof rlp.RawList[[]byte] // Merkle proofs for the *last* slot range, if it's incomplete
+}
+
// StorageRangesPacket represents a storage slot query response.
type StorageRangesPacket struct {
ID uint64 // ID of the request this is a response for
@@ -161,6 +173,11 @@ type GetByteCodesPacket struct {
Bytes uint64 // Soft limit at which to stop returning data
}
+type byteCodesInput struct {
+ ID uint64 // ID of the request this is a response for
+ Codes rlp.RawList[[]byte] // Requested contract bytecodes
+}
+
// ByteCodesPacket represents a contract bytecode query response.
type ByteCodesPacket struct {
ID uint64 // ID of the request this is a response for
@@ -169,10 +186,10 @@ type ByteCodesPacket struct {
// GetTrieNodesPacket represents a state trie node query.
type GetTrieNodesPacket struct {
- ID uint64 // Request ID to match up responses with
- Root common.Hash // Root hash of the account trie to serve
- Paths []TrieNodePathSet // Trie node hashes to retrieve the nodes for
- Bytes uint64 // Soft limit at which to stop returning data
+ ID uint64 // Request ID to match up responses with
+ Root common.Hash // Root hash of the account trie to serve
+ Paths rlp.RawList[TrieNodePathSet] // Trie node hashes to retrieve the nodes for
+ Bytes uint64 // Soft limit at which to stop returning data
}
// TrieNodePathSet is a list of trie node paths to retrieve. A naive way to
@@ -187,6 +204,11 @@ type GetTrieNodesPacket struct {
// that a slot is accessed before the account path is fully expanded.
type TrieNodePathSet [][]byte
+type trieNodesInput struct {
+ ID uint64 // ID of the request this is a response for
+ Nodes rlp.RawList[[]byte] // Requested state trie nodes
+}
+
// TrieNodesPacket represents a state trie node query response.
type TrieNodesPacket struct {
ID uint64 // ID of the request this is a response for
diff --git a/eth/protocols/snap/sync.go b/eth/protocols/snap/sync.go
index cf4e494645..08e85c896a 100644
--- a/eth/protocols/snap/sync.go
+++ b/eth/protocols/snap/sync.go
@@ -412,19 +412,19 @@ type SyncPeer interface {
// RequestAccountRange fetches a batch of accounts rooted in a specific account
// trie, starting with the origin.
- RequestAccountRange(id uint64, root, origin, limit common.Hash, bytes uint64) error
+ RequestAccountRange(id uint64, root, origin, limit common.Hash, bytes int) error
// RequestStorageRanges fetches a batch of storage slots belonging to one or
// more accounts. If slots from only one account is requested, an origin marker
// may also be used to retrieve from there.
- RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes uint64) error
+ RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes int) error
// RequestByteCodes fetches a batch of bytecodes by hash.
- RequestByteCodes(id uint64, hashes []common.Hash, bytes uint64) error
+ RequestByteCodes(id uint64, hashes []common.Hash, bytes int) error
// RequestTrieNodes fetches a batch of account or storage trie nodes rooted in
// a specific state trie.
- RequestTrieNodes(id uint64, root common.Hash, paths []TrieNodePathSet, bytes uint64) error
+ RequestTrieNodes(id uint64, root common.Hash, count int, paths []TrieNodePathSet, bytes int) error
// Log retrieves the peer's own contextual logger.
Log() log.Logger
@@ -1102,7 +1102,7 @@ func (s *Syncer) assignAccountTasks(success chan *accountResponse, fail chan *ac
if cap < minRequestSize { // Don't bother with peers below a bare minimum performance
cap = minRequestSize
}
- if err := peer.RequestAccountRange(reqid, root, req.origin, req.limit, uint64(cap)); err != nil {
+ if err := peer.RequestAccountRange(reqid, root, req.origin, req.limit, cap); err != nil {
peer.Log().Debug("Failed to request account range", "err", err)
s.scheduleRevertAccountRequest(req)
}
@@ -1359,7 +1359,7 @@ func (s *Syncer) assignStorageTasks(success chan *storageResponse, fail chan *st
if subtask != nil {
origin, limit = req.origin[:], req.limit[:]
}
- if err := peer.RequestStorageRanges(reqid, root, accounts, origin, limit, uint64(cap)); err != nil {
+ if err := peer.RequestStorageRanges(reqid, root, accounts, origin, limit, cap); err != nil {
log.Debug("Failed to request storage", "err", err)
s.scheduleRevertStorageRequest(req)
}
@@ -1492,7 +1492,7 @@ func (s *Syncer) assignTrienodeHealTasks(success chan *trienodeHealResponse, fai
defer s.pend.Done()
// Attempt to send the remote request and revert if it fails
- if err := peer.RequestTrieNodes(reqid, root, pathsets, maxRequestSize); err != nil {
+ if err := peer.RequestTrieNodes(reqid, root, len(paths), pathsets, maxRequestSize); err != nil {
log.Debug("Failed to request trienode healers", "err", err)
s.scheduleRevertTrienodeHealRequest(req)
}
diff --git a/eth/protocols/snap/sync_test.go b/eth/protocols/snap/sync_test.go
index ed83af860a..b11ad4e78a 100644
--- a/eth/protocols/snap/sync_test.go
+++ b/eth/protocols/snap/sync_test.go
@@ -119,10 +119,10 @@ func BenchmarkHashing(b *testing.B) {
}
type (
- accountHandlerFunc func(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) error
- storageHandlerFunc func(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error
- trieHandlerFunc func(t *testPeer, requestId uint64, root common.Hash, paths []TrieNodePathSet, cap uint64) error
- codeHandlerFunc func(t *testPeer, id uint64, hashes []common.Hash, max uint64) error
+ accountHandlerFunc func(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap int) error
+ storageHandlerFunc func(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max int) error
+ trieHandlerFunc func(t *testPeer, requestId uint64, root common.Hash, paths []TrieNodePathSet, cap int) error
+ codeHandlerFunc func(t *testPeer, id uint64, hashes []common.Hash, max int) error
)
type testPeer struct {
@@ -182,21 +182,21 @@ Trienode requests: %d
`, t.nAccountRequests, t.nStorageRequests, t.nBytecodeRequests, t.nTrienodeRequests)
}
-func (t *testPeer) RequestAccountRange(id uint64, root, origin, limit common.Hash, bytes uint64) error {
+func (t *testPeer) RequestAccountRange(id uint64, root, origin, limit common.Hash, bytes int) error {
t.logger.Trace("Fetching range of accounts", "reqid", id, "root", root, "origin", origin, "limit", limit, "bytes", common.StorageSize(bytes))
t.nAccountRequests++
go t.accountRequestHandler(t, id, root, origin, limit, bytes)
return nil
}
-func (t *testPeer) RequestTrieNodes(id uint64, root common.Hash, paths []TrieNodePathSet, bytes uint64) error {
+func (t *testPeer) RequestTrieNodes(id uint64, root common.Hash, count int, paths []TrieNodePathSet, bytes int) error {
t.logger.Trace("Fetching set of trie nodes", "reqid", id, "root", root, "pathsets", len(paths), "bytes", common.StorageSize(bytes))
t.nTrienodeRequests++
go t.trieRequestHandler(t, id, root, paths, bytes)
return nil
}
-func (t *testPeer) RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes uint64) error {
+func (t *testPeer) RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes int) error {
t.nStorageRequests++
if len(accounts) == 1 && origin != nil {
t.logger.Trace("Fetching range of large storage slots", "reqid", id, "root", root, "account", accounts[0], "origin", common.BytesToHash(origin), "limit", common.BytesToHash(limit), "bytes", common.StorageSize(bytes))
@@ -207,7 +207,7 @@ func (t *testPeer) RequestStorageRanges(id uint64, root common.Hash, accounts []
return nil
}
-func (t *testPeer) RequestByteCodes(id uint64, hashes []common.Hash, bytes uint64) error {
+func (t *testPeer) RequestByteCodes(id uint64, hashes []common.Hash, bytes int) error {
t.nBytecodeRequests++
t.logger.Trace("Fetching set of byte codes", "reqid", id, "hashes", len(hashes), "bytes", common.StorageSize(bytes))
go t.codeRequestHandler(t, id, hashes, bytes)
@@ -215,7 +215,7 @@ func (t *testPeer) RequestByteCodes(id uint64, hashes []common.Hash, bytes uint6
}
// defaultTrieRequestHandler is a well-behaving handler for trie healing requests
-func defaultTrieRequestHandler(t *testPeer, requestId uint64, root common.Hash, paths []TrieNodePathSet, cap uint64) error {
+func defaultTrieRequestHandler(t *testPeer, requestId uint64, root common.Hash, paths []TrieNodePathSet, cap int) error {
// Pass the response
var nodes [][]byte
for _, pathset := range paths {
@@ -244,7 +244,7 @@ func defaultTrieRequestHandler(t *testPeer, requestId uint64, root common.Hash,
}
// defaultAccountRequestHandler is a well-behaving handler for AccountRangeRequests
-func defaultAccountRequestHandler(t *testPeer, id uint64, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) error {
+func defaultAccountRequestHandler(t *testPeer, id uint64, root common.Hash, origin common.Hash, limit common.Hash, cap int) error {
keys, vals, proofs := createAccountRequestResponse(t, root, origin, limit, cap)
if err := t.remote.OnAccounts(t, id, keys, vals, proofs); err != nil {
t.test.Errorf("Remote side rejected our delivery: %v", err)
@@ -254,8 +254,8 @@ func defaultAccountRequestHandler(t *testPeer, id uint64, root common.Hash, orig
return nil
}
-func createAccountRequestResponse(t *testPeer, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) (keys []common.Hash, vals [][]byte, proofs [][]byte) {
- var size uint64
+func createAccountRequestResponse(t *testPeer, root common.Hash, origin common.Hash, limit common.Hash, cap int) (keys []common.Hash, vals [][]byte, proofs [][]byte) {
+ var size int
if limit == (common.Hash{}) {
limit = common.MaxHash
}
@@ -266,7 +266,7 @@ func createAccountRequestResponse(t *testPeer, root common.Hash, origin common.H
if bytes.Compare(origin[:], entry.k) <= 0 {
keys = append(keys, common.BytesToHash(entry.k))
vals = append(vals, entry.v)
- size += uint64(32 + len(entry.v))
+ size += 32 + len(entry.v)
}
// If we've exceeded the request threshold, abort
if bytes.Compare(entry.k, limit[:]) >= 0 {
@@ -290,7 +290,7 @@ func createAccountRequestResponse(t *testPeer, root common.Hash, origin common.H
}
// defaultStorageRequestHandler is a well-behaving storage request handler
-func defaultStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, bOrigin, bLimit []byte, max uint64) error {
+func defaultStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, bOrigin, bLimit []byte, max int) error {
hashes, slots, proofs := createStorageRequestResponse(t, root, accounts, bOrigin, bLimit, max)
if err := t.remote.OnStorage(t, requestId, hashes, slots, proofs); err != nil {
t.test.Errorf("Remote side rejected our delivery: %v", err)
@@ -299,7 +299,7 @@ func defaultStorageRequestHandler(t *testPeer, requestId uint64, root common.Has
return nil
}
-func defaultCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max uint64) error {
+func defaultCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max int) error {
var bytecodes [][]byte
for _, h := range hashes {
bytecodes = append(bytecodes, getCodeByHash(h))
@@ -311,8 +311,8 @@ func defaultCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max
return nil
}
-func createStorageRequestResponse(t *testPeer, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) (hashes [][]common.Hash, slots [][][]byte, proofs [][]byte) {
- var size uint64
+func createStorageRequestResponse(t *testPeer, root common.Hash, accounts []common.Hash, origin, limit []byte, max int) (hashes [][]common.Hash, slots [][][]byte, proofs [][]byte) {
+ var size int
for _, account := range accounts {
// The first account might start from a different origin and end sooner
var originHash common.Hash
@@ -338,7 +338,7 @@ func createStorageRequestResponse(t *testPeer, root common.Hash, accounts []comm
}
keys = append(keys, common.BytesToHash(entry.k))
vals = append(vals, entry.v)
- size += uint64(32 + len(entry.v))
+ size += 32 + len(entry.v)
if bytes.Compare(entry.k, limitHash[:]) >= 0 {
break
}
@@ -377,8 +377,8 @@ func createStorageRequestResponse(t *testPeer, root common.Hash, accounts []comm
// createStorageRequestResponseAlwaysProve tests a cornercase, where the peer always
// supplies the proof for the last account, even if it is 'complete'.
-func createStorageRequestResponseAlwaysProve(t *testPeer, root common.Hash, accounts []common.Hash, bOrigin, bLimit []byte, max uint64) (hashes [][]common.Hash, slots [][][]byte, proofs [][]byte) {
- var size uint64
+func createStorageRequestResponseAlwaysProve(t *testPeer, root common.Hash, accounts []common.Hash, bOrigin, bLimit []byte, max int) (hashes [][]common.Hash, slots [][][]byte, proofs [][]byte) {
+ var size int
max = max * 3 / 4
var origin common.Hash
@@ -395,7 +395,7 @@ func createStorageRequestResponseAlwaysProve(t *testPeer, root common.Hash, acco
}
keys = append(keys, common.BytesToHash(entry.k))
vals = append(vals, entry.v)
- size += uint64(32 + len(entry.v))
+ size += 32 + len(entry.v)
if size > max {
exit = true
}
@@ -433,34 +433,34 @@ func createStorageRequestResponseAlwaysProve(t *testPeer, root common.Hash, acco
}
// emptyRequestAccountRangeFn is a rejects AccountRangeRequests
-func emptyRequestAccountRangeFn(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) error {
+func emptyRequestAccountRangeFn(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap int) error {
t.remote.OnAccounts(t, requestId, nil, nil, nil)
return nil
}
-func nonResponsiveRequestAccountRangeFn(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) error {
+func nonResponsiveRequestAccountRangeFn(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap int) error {
return nil
}
-func emptyTrieRequestHandler(t *testPeer, requestId uint64, root common.Hash, paths []TrieNodePathSet, cap uint64) error {
+func emptyTrieRequestHandler(t *testPeer, requestId uint64, root common.Hash, paths []TrieNodePathSet, cap int) error {
t.remote.OnTrieNodes(t, requestId, nil)
return nil
}
-func nonResponsiveTrieRequestHandler(t *testPeer, requestId uint64, root common.Hash, paths []TrieNodePathSet, cap uint64) error {
+func nonResponsiveTrieRequestHandler(t *testPeer, requestId uint64, root common.Hash, paths []TrieNodePathSet, cap int) error {
return nil
}
-func emptyStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error {
+func emptyStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max int) error {
t.remote.OnStorage(t, requestId, nil, nil, nil)
return nil
}
-func nonResponsiveStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error {
+func nonResponsiveStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max int) error {
return nil
}
-func proofHappyStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error {
+func proofHappyStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max int) error {
hashes, slots, proofs := createStorageRequestResponseAlwaysProve(t, root, accounts, origin, limit, max)
if err := t.remote.OnStorage(t, requestId, hashes, slots, proofs); err != nil {
t.test.Errorf("Remote side rejected our delivery: %v", err)
@@ -475,7 +475,7 @@ func proofHappyStorageRequestHandler(t *testPeer, requestId uint64, root common.
// return nil
//}
-func corruptCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max uint64) error {
+func corruptCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max int) error {
var bytecodes [][]byte
for _, h := range hashes {
// Send back the hashes
@@ -489,7 +489,7 @@ func corruptCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max
return nil
}
-func cappedCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max uint64) error {
+func cappedCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max int) error {
var bytecodes [][]byte
for _, h := range hashes[:1] {
bytecodes = append(bytecodes, getCodeByHash(h))
@@ -503,11 +503,11 @@ func cappedCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max
}
// starvingStorageRequestHandler is somewhat well-behaving storage handler, but it caps the returned results to be very small
-func starvingStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error {
+func starvingStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max int) error {
return defaultStorageRequestHandler(t, requestId, root, accounts, origin, limit, 500)
}
-func starvingAccountRequestHandler(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) error {
+func starvingAccountRequestHandler(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap int) error {
return defaultAccountRequestHandler(t, requestId, root, origin, limit, 500)
}
@@ -515,7 +515,7 @@ func starvingAccountRequestHandler(t *testPeer, requestId uint64, root common.Ha
// return defaultAccountRequestHandler(t, requestId-1, root, origin, 500)
//}
-func corruptAccountRequestHandler(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) error {
+func corruptAccountRequestHandler(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap int) error {
hashes, accounts, proofs := createAccountRequestResponse(t, root, origin, limit, cap)
if len(proofs) > 0 {
proofs = proofs[1:]
@@ -529,7 +529,7 @@ func corruptAccountRequestHandler(t *testPeer, requestId uint64, root common.Has
}
// corruptStorageRequestHandler doesn't provide good proofs
-func corruptStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error {
+func corruptStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max int) error {
hashes, slots, proofs := createStorageRequestResponse(t, root, accounts, origin, limit, max)
if len(proofs) > 0 {
proofs = proofs[1:]
@@ -542,7 +542,7 @@ func corruptStorageRequestHandler(t *testPeer, requestId uint64, root common.Has
return nil
}
-func noProofStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error {
+func noProofStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max int) error {
hashes, slots, _ := createStorageRequestResponse(t, root, accounts, origin, limit, max)
if err := t.remote.OnStorage(t, requestId, hashes, slots, nil); err != nil {
t.logger.Info("remote error on delivery (as expected)", "error", err)
@@ -577,7 +577,7 @@ func testSyncBloatedProof(t *testing.T, scheme string) {
source.accountTrie = sourceAccountTrie.Copy()
source.accountValues = elems
- source.accountRequestHandler = func(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) error {
+ source.accountRequestHandler = func(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap int) error {
var (
keys []common.Hash
vals [][]byte
@@ -1165,7 +1165,7 @@ func testSyncNoStorageAndOneCodeCappedPeer(t *testing.T, scheme string) {
var counter int
syncer := setupSyncer(
nodeScheme,
- mkSource("capped", func(t *testPeer, id uint64, hashes []common.Hash, max uint64) error {
+ mkSource("capped", func(t *testPeer, id uint64, hashes []common.Hash, max int) error {
counter++
return cappedCodeRequestHandler(t, id, hashes, max)
}),
@@ -1432,7 +1432,7 @@ func testSyncWithUnevenStorage(t *testing.T, scheme string) {
source.accountValues = accounts
source.setStorageTries(storageTries)
source.storageValues = storageElems
- source.storageRequestHandler = func(t *testPeer, reqId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error {
+ source.storageRequestHandler = func(t *testPeer, reqId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max int) error {
return defaultStorageRequestHandler(t, reqId, root, accounts, origin, limit, 128) // retrieve storage in large mode
}
return source
diff --git a/eth/protocols/snap/tracker.go b/eth/protocols/snap/tracker.go
deleted file mode 100644
index 2cf59cc23a..0000000000
--- a/eth/protocols/snap/tracker.go
+++ /dev/null
@@ -1,26 +0,0 @@
-// Copyright 2021 The go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see .
-
-package snap
-
-import (
- "time"
-
- "github.com/ethereum/go-ethereum/p2p/tracker"
-)
-
-// requestTracker is a singleton tracker for request times.
-var requestTracker = tracker.New(ProtocolName, time.Minute)
diff --git a/p2p/tracker/tracker.go b/p2p/tracker/tracker.go
index a1cf6f1119..3a9135aa3b 100644
--- a/p2p/tracker/tracker.go
+++ b/p2p/tracker/tracker.go
@@ -18,52 +18,63 @@ package tracker
import (
"container/list"
+ "errors"
"fmt"
"sync"
"time"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/metrics"
+ "github.com/ethereum/go-ethereum/p2p"
)
const (
- // trackedGaugeName is the prefix of the per-packet request tracking.
trackedGaugeName = "p2p/tracked"
-
- // lostMeterName is the prefix of the per-packet request expirations.
- lostMeterName = "p2p/lost"
-
- // staleMeterName is the prefix of the per-packet stale responses.
- staleMeterName = "p2p/stale"
-
- // waitHistName is the prefix of the per-packet (req only) waiting time histograms.
- waitHistName = "p2p/wait"
+ lostMeterName = "p2p/lost"
+ staleMeterName = "p2p/stale"
+ waitHistName = "p2p/wait"
// maxTrackedPackets is a huge number to act as a failsafe on the number of
// pending requests the node will track. It should never be hit unless an
// attacker figures out a way to spin requests.
- maxTrackedPackets = 100000
+ maxTrackedPackets = 10000
)
-// request tracks sent network requests which have not yet received a response.
-type request struct {
- peer string
- version uint // Protocol version
+var (
+ ErrNoMatchingRequest = errors.New("no matching request")
+ ErrTooManyItems = errors.New("response is larger than request allows")
+ ErrCollision = errors.New("request ID collision")
+ ErrCodeMismatch = errors.New("wrong response code for request")
+ ErrLimitReached = errors.New("request limit reached")
+ ErrStopped = errors.New("tracker stopped")
+)
- reqCode uint64 // Protocol message code of the request
- resCode uint64 // Protocol message code of the expected response
+// Request tracks sent network requests which have not yet received a response.
+type Request struct {
+ ID uint64 // Request ID
+ Size int // Number/size of requested items
+ ReqCode uint64 // Protocol message code of the request
+ RespCode uint64 // Protocol message code of the expected response
time time.Time // Timestamp when the request was made
expire *list.Element // Expiration marker to untrack it
}
+type Response struct {
+ ID uint64 // Request ID of the response
+ MsgCode uint64 // Protocol message code
+ Size int // number/size of items in response
+}
+
// Tracker is a pending network request tracker to measure how much time it takes
// a remote peer to respond.
type Tracker struct {
- protocol string // Protocol capability identifier for the metrics
- timeout time.Duration // Global timeout after which to drop a tracked packet
+ cap p2p.Cap // Protocol capability identifier for the metrics
- pending map[uint64]*request // Currently pending requests
+ peer string // Peer ID
+ timeout time.Duration // Global timeout after which to drop a tracked packet
+
+ pending map[uint64]*Request // Currently pending requests
expire *list.List // Linked list tracking the expiration order
wake *time.Timer // Timer tracking the expiration of the next item
@@ -72,52 +83,53 @@ type Tracker struct {
// New creates a new network request tracker to monitor how much time it takes to
// fill certain requests and how individual peers perform.
-func New(protocol string, timeout time.Duration) *Tracker {
+func New(cap p2p.Cap, peerID string, timeout time.Duration) *Tracker {
return &Tracker{
- protocol: protocol,
- timeout: timeout,
- pending: make(map[uint64]*request),
- expire: list.New(),
+ cap: cap,
+ peer: peerID,
+ timeout: timeout,
+ pending: make(map[uint64]*Request),
+ expire: list.New(),
}
}
// Track adds a network request to the tracker to wait for a response to arrive
// or until the request it cancelled or times out.
-func (t *Tracker) Track(peer string, version uint, reqCode uint64, resCode uint64, id uint64) {
- if !metrics.Enabled() {
- return
- }
+func (t *Tracker) Track(req Request) error {
t.lock.Lock()
defer t.lock.Unlock()
+ if t.expire == nil {
+ return ErrStopped
+ }
+
// If there's a duplicate request, we've just random-collided (or more probably,
// we have a bug), report it. We could also add a metric, but we're not really
// expecting ourselves to be buggy, so a noisy warning should be enough.
- if _, ok := t.pending[id]; ok {
- log.Error("Network request id collision", "protocol", t.protocol, "version", version, "code", reqCode, "id", id)
- return
+ if _, ok := t.pending[req.ID]; ok {
+ log.Error("Network request id collision", "cap", t.cap, "code", req.ReqCode, "id", req.ID)
+ return ErrCollision
}
// If we have too many pending requests, bail out instead of leaking memory
if pending := len(t.pending); pending >= maxTrackedPackets {
- log.Error("Request tracker exceeded allowance", "pending", pending, "peer", peer, "protocol", t.protocol, "version", version, "code", reqCode)
- return
+ log.Error("Request tracker exceeded allowance", "pending", pending, "peer", t.peer, "cap", t.cap, "code", req.ReqCode)
+ return ErrLimitReached
}
+
// Id doesn't exist yet, start tracking it
- t.pending[id] = &request{
- peer: peer,
- version: version,
- reqCode: reqCode,
- resCode: resCode,
- time: time.Now(),
- expire: t.expire.PushBack(id),
+ req.time = time.Now()
+ req.expire = t.expire.PushBack(req.ID)
+ t.pending[req.ID] = &req
+
+ if metrics.Enabled() {
+ t.trackedGauge(req.ReqCode).Inc(1)
}
- g := fmt.Sprintf("%s/%s/%d/%#02x", trackedGaugeName, t.protocol, version, reqCode)
- metrics.GetOrRegisterGauge(g, nil).Inc(1)
// If we've just inserted the first item, start the expiration timer
if t.wake == nil {
t.wake = time.AfterFunc(t.timeout, t.clean)
}
+ return nil
}
// clean is called automatically when a preset time passes without a response
@@ -142,11 +154,10 @@ func (t *Tracker) clean() {
t.expire.Remove(head)
delete(t.pending, id)
- g := fmt.Sprintf("%s/%s/%d/%#02x", trackedGaugeName, t.protocol, req.version, req.reqCode)
- metrics.GetOrRegisterGauge(g, nil).Dec(1)
-
- m := fmt.Sprintf("%s/%s/%d/%#02x", lostMeterName, t.protocol, req.version, req.reqCode)
- metrics.GetOrRegisterMeter(m, nil).Mark(1)
+ if metrics.Enabled() {
+ t.trackedGauge(req.ReqCode).Dec(1)
+ t.lostMeter(req.ReqCode).Mark(1)
+ }
}
t.schedule()
}
@@ -161,46 +172,92 @@ func (t *Tracker) schedule() {
t.wake = time.AfterFunc(time.Until(t.pending[t.expire.Front().Value.(uint64)].time.Add(t.timeout)), t.clean)
}
-// Fulfil fills a pending request, if any is available, reporting on various metrics.
-func (t *Tracker) Fulfil(peer string, version uint, code uint64, id uint64) {
- if !metrics.Enabled() {
- return
+// Stop reclaims resources of the tracker.
+func (t *Tracker) Stop() {
+ t.lock.Lock()
+ defer t.lock.Unlock()
+
+ if t.wake != nil {
+ t.wake.Stop()
+ t.wake = nil
}
+ if metrics.Enabled() {
+ // Ensure metrics are decremented for pending requests.
+ counts := make(map[uint64]int64)
+ for _, req := range t.pending {
+ counts[req.ReqCode]++
+ }
+ for code, count := range counts {
+ t.trackedGauge(code).Dec(count)
+ }
+ }
+ clear(t.pending)
+ t.expire = nil
+}
+
+// Fulfil fills a pending request, if any is available, reporting on various metrics.
+func (t *Tracker) Fulfil(resp Response) error {
t.lock.Lock()
defer t.lock.Unlock()
// If it's a non existing request, track as stale response
- req, ok := t.pending[id]
+ req, ok := t.pending[resp.ID]
if !ok {
- m := fmt.Sprintf("%s/%s/%d/%#02x", staleMeterName, t.protocol, version, code)
- metrics.GetOrRegisterMeter(m, nil).Mark(1)
- return
+ if metrics.Enabled() {
+ t.staleMeter(resp.MsgCode).Mark(1)
+ }
+ return ErrNoMatchingRequest
}
+
// If the response is funky, it might be some active attack
- if req.peer != peer || req.version != version || req.resCode != code {
- log.Warn("Network response id collision",
- "have", fmt.Sprintf("%s:%s/%d:%d", peer, t.protocol, version, code),
- "want", fmt.Sprintf("%s:%s/%d:%d", peer, t.protocol, req.version, req.resCode),
+ if req.RespCode != resp.MsgCode {
+ log.Warn("Network response code collision",
+ "have", fmt.Sprintf("%s:%s/%d:%d", t.peer, t.cap.Name, t.cap.Version, resp.MsgCode),
+ "want", fmt.Sprintf("%s:%s/%d:%d", t.peer, t.cap.Name, t.cap.Version, req.RespCode),
)
- return
+ return ErrCodeMismatch
}
+ if resp.Size > req.Size {
+ return ErrTooManyItems
+ }
+
// Everything matches, mark the request serviced and meter it
wasHead := req.expire.Prev() == nil
t.expire.Remove(req.expire)
- delete(t.pending, id)
+ delete(t.pending, req.ID)
if wasHead {
if t.wake.Stop() {
t.schedule()
}
}
- g := fmt.Sprintf("%s/%s/%d/%#02x", trackedGaugeName, t.protocol, req.version, req.reqCode)
- metrics.GetOrRegisterGauge(g, nil).Dec(1)
- h := fmt.Sprintf("%s/%s/%d/%#02x", waitHistName, t.protocol, req.version, req.reqCode)
- sampler := func() metrics.Sample {
- return metrics.ResettingSample(
- metrics.NewExpDecaySample(1028, 0.015),
- )
+ // Update request metrics.
+ if metrics.Enabled() {
+ t.trackedGauge(req.ReqCode).Dec(1)
+ t.waitHistogram(req.ReqCode).Update(time.Since(req.time).Microseconds())
}
- metrics.GetOrRegisterHistogramLazy(h, nil, sampler).Update(time.Since(req.time).Microseconds())
+ return nil
+}
+
+func (t *Tracker) trackedGauge(code uint64) *metrics.Gauge {
+ name := fmt.Sprintf("%s/%s/%d/%#02x", trackedGaugeName, t.cap.Name, t.cap.Version, code)
+ return metrics.GetOrRegisterGauge(name, nil)
+}
+
+func (t *Tracker) lostMeter(code uint64) *metrics.Meter {
+ name := fmt.Sprintf("%s/%s/%d/%#02x", lostMeterName, t.cap.Name, t.cap.Version, code)
+ return metrics.GetOrRegisterMeter(name, nil)
+}
+
+func (t *Tracker) staleMeter(code uint64) *metrics.Meter {
+ name := fmt.Sprintf("%s/%s/%d/%#02x", staleMeterName, t.cap.Name, t.cap.Version, code)
+ return metrics.GetOrRegisterMeter(name, nil)
+}
+
+func (t *Tracker) waitHistogram(code uint64) metrics.Histogram {
+ name := fmt.Sprintf("%s/%s/%d/%#02x", waitHistName, t.cap.Name, t.cap.Version, code)
+ sampler := func() metrics.Sample {
+ return metrics.ResettingSample(metrics.NewExpDecaySample(1028, 0.015))
+ }
+ return metrics.GetOrRegisterHistogramLazy(name, nil, sampler)
}
diff --git a/p2p/tracker/tracker_test.go b/p2p/tracker/tracker_test.go
new file mode 100644
index 0000000000..a37a59f70f
--- /dev/null
+++ b/p2p/tracker/tracker_test.go
@@ -0,0 +1,64 @@
+// Copyright 2026 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package tracker
+
+import (
+ "testing"
+ "time"
+
+ "github.com/ethereum/go-ethereum/metrics"
+ "github.com/ethereum/go-ethereum/p2p"
+)
+
+// This checks that metrics gauges for pending requests are be decremented when a
+// Tracker is stopped.
+func TestMetricsOnStop(t *testing.T) {
+ metrics.Enable()
+
+ cap := p2p.Cap{Name: "test", Version: 1}
+ tr := New(cap, "peer1", time.Minute)
+
+ // Track some requests with different ReqCodes.
+ var id uint64
+ for i := 0; i < 3; i++ {
+ tr.Track(Request{ID: id, ReqCode: 0x01, RespCode: 0x02, Size: 1})
+ id++
+ }
+ for i := 0; i < 5; i++ {
+ tr.Track(Request{ID: id, ReqCode: 0x03, RespCode: 0x04, Size: 1})
+ id++
+ }
+
+ gauge1 := tr.trackedGauge(0x01)
+ gauge2 := tr.trackedGauge(0x03)
+
+ if gauge1.Snapshot().Value() != 3 {
+ t.Fatalf("gauge1 value mismatch: got %d, want 3", gauge1.Snapshot().Value())
+ }
+ if gauge2.Snapshot().Value() != 5 {
+ t.Fatalf("gauge2 value mismatch: got %d, want 5", gauge2.Snapshot().Value())
+ }
+
+ tr.Stop()
+
+ if gauge1.Snapshot().Value() != 0 {
+ t.Fatalf("gauge1 value after stop: got %d, want 0", gauge1.Snapshot().Value())
+ }
+ if gauge2.Snapshot().Value() != 0 {
+ t.Fatalf("gauge2 value after stop: got %d, want 0", gauge2.Snapshot().Value())
+ }
+}