diff --git a/cmd/devp2p/internal/ethtest/protocol.go b/cmd/devp2p/internal/ethtest/protocol.go index af76082318..a21d1ca7a1 100644 --- a/cmd/devp2p/internal/ethtest/protocol.go +++ b/cmd/devp2p/internal/ethtest/protocol.go @@ -86,9 +86,3 @@ func protoOffset(proto Proto) uint64 { panic("unhandled protocol") } } - -// msgTypePtr is the constraint for protocol message types. -type msgTypePtr[U any] interface { - *U - Kind() byte -} diff --git a/cmd/devp2p/internal/ethtest/snap.go b/cmd/devp2p/internal/ethtest/snap.go index f4fce0931f..7c1ca70cc0 100644 --- a/cmd/devp2p/internal/ethtest/snap.go +++ b/cmd/devp2p/internal/ethtest/snap.go @@ -30,6 +30,7 @@ import ( "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/eth/protocols/snap" "github.com/ethereum/go-ethereum/internal/utesting" + "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/trie/trienode" ) @@ -937,10 +938,14 @@ func (s *Suite) snapGetTrieNodes(t *utesting.T, tc *trieNodesTest) error { } // write0 request + paths, err := rlp.EncodeToRawList(tc.paths) + if err != nil { + panic(err) + } req := &snap.GetTrieNodesPacket{ ID: uint64(rand.Int63()), Root: tc.root, - Paths: tc.paths, + Paths: paths, Bytes: tc.nBytes, } msg, err := conn.snapRequest(snap.GetTrieNodesMsg, req) diff --git a/cmd/devp2p/internal/ethtest/suite.go b/cmd/devp2p/internal/ethtest/suite.go index c23360bf82..8bb488e358 100644 --- a/cmd/devp2p/internal/ethtest/suite.go +++ b/cmd/devp2p/internal/ethtest/suite.go @@ -34,6 +34,7 @@ import ( "github.com/ethereum/go-ethereum/internal/utesting" "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/rlp" "github.com/holiman/uint256" ) @@ -151,7 +152,11 @@ func (s *Suite) TestGetBlockHeaders(t *utesting.T) { if err != nil { t.Fatalf("failed to get headers for given request: %v", err) } - if !headersMatch(expected, headers.BlockHeadersRequest) { + received, err := headers.List.Items() + if err != nil { + t.Fatalf("invalid headers received: %v", err) + } + if !headersMatch(expected, received) { t.Fatalf("header mismatch: \nexpected %v \ngot %v", expected, headers) } } @@ -237,7 +242,7 @@ concurrently, with different request IDs.`) // Wait for responses. // Note they can arrive in either order. - resp, err := collectResponses(conn, 2, func(msg *eth.BlockHeadersPacket) uint64 { + resp, err := collectHeaderResponses(conn, 2, func(msg *eth.BlockHeadersPacket) uint64 { if msg.RequestId != 111 && msg.RequestId != 222 { t.Fatalf("response with unknown request ID: %v", msg.RequestId) } @@ -248,17 +253,11 @@ concurrently, with different request IDs.`) } // Check if headers match. - resp1 := resp[111] - if expected, err := s.chain.GetHeaders(req1); err != nil { - t.Fatalf("failed to get expected headers for request 1: %v", err) - } else if !headersMatch(expected, resp1.BlockHeadersRequest) { - t.Fatalf("header mismatch for request ID %v: \nexpected %v \ngot %v", 111, expected, resp1) + if err := s.checkHeadersAgainstChain(req1, resp[111]); err != nil { + t.Fatal(err) } - resp2 := resp[222] - if expected, err := s.chain.GetHeaders(req2); err != nil { - t.Fatalf("failed to get expected headers for request 2: %v", err) - } else if !headersMatch(expected, resp2.BlockHeadersRequest) { - t.Fatalf("header mismatch for request ID %v: \nexpected %v \ngot %v", 222, expected, resp2) + if err := s.checkHeadersAgainstChain(req2, resp[222]); err != nil { + t.Fatal(err) } } @@ -303,8 +302,8 @@ same request ID. The node should handle the request by responding to both reques // Wait for the responses. They can arrive in either order, and we can't tell them // apart by their request ID, so use the number of headers instead. - resp, err := collectResponses(conn, 2, func(msg *eth.BlockHeadersPacket) uint64 { - id := uint64(len(msg.BlockHeadersRequest)) + resp, err := collectHeaderResponses(conn, 2, func(msg *eth.BlockHeadersPacket) uint64 { + id := uint64(msg.List.Len()) if id != 2 && id != 3 { t.Fatalf("invalid number of headers in response: %d", id) } @@ -315,26 +314,35 @@ same request ID. The node should handle the request by responding to both reques } // Check if headers match. - resp1 := resp[2] - if expected, err := s.chain.GetHeaders(request1); err != nil { - t.Fatalf("failed to get expected headers for request 1: %v", err) - } else if !headersMatch(expected, resp1.BlockHeadersRequest) { - t.Fatalf("headers mismatch: \nexpected %v \ngot %v", expected, resp1) + if err := s.checkHeadersAgainstChain(request1, resp[2]); err != nil { + t.Fatal(err) } - resp2 := resp[3] - if expected, err := s.chain.GetHeaders(request2); err != nil { - t.Fatalf("failed to get expected headers for request 2: %v", err) - } else if !headersMatch(expected, resp2.BlockHeadersRequest) { - t.Fatalf("headers mismatch: \nexpected %v \ngot %v", expected, resp2) + if err := s.checkHeadersAgainstChain(request2, resp[3]); err != nil { + t.Fatal(err) } } +func (s *Suite) checkHeadersAgainstChain(req *eth.GetBlockHeadersPacket, resp *eth.BlockHeadersPacket) error { + received2, err := resp.List.Items() + if err != nil { + return fmt.Errorf("invalid headers in response with request ID %v (%d items): %v", resp.RequestId, resp.List.Len(), err) + } + if expected, err := s.chain.GetHeaders(req); err != nil { + return fmt.Errorf("test chain failed to get expected headers for request: %v", err) + } else if !headersMatch(expected, received2) { + return fmt.Errorf("header mismatch for request ID %v (%d items): \nexpected %v \ngot %v", resp.RequestId, resp.List.Len(), expected, resp) + } + return nil +} + // collectResponses waits for n messages of type T on the given connection. // The messsages are collected according to the 'identity' function. -func collectResponses[T any, P msgTypePtr[T]](conn *Conn, n int, identity func(P) uint64) (map[uint64]P, error) { - resp := make(map[uint64]P, n) +// +// This function is written in a generic way to handle +func collectHeaderResponses(conn *Conn, n int, identity func(*eth.BlockHeadersPacket) uint64) (map[uint64]*eth.BlockHeadersPacket, error) { + resp := make(map[uint64]*eth.BlockHeadersPacket, n) for range n { - r := new(T) + r := new(eth.BlockHeadersPacket) if err := conn.ReadMsg(ethProto, eth.BlockHeadersMsg, r); err != nil { return resp, fmt.Errorf("read error: %v", err) } @@ -373,10 +381,8 @@ and expects a response.`) if got, want := headers.RequestId, req.RequestId; got != want { t.Fatalf("unexpected request id") } - if expected, err := s.chain.GetHeaders(req); err != nil { - t.Fatalf("failed to get expected block headers: %v", err) - } else if !headersMatch(expected, headers.BlockHeadersRequest) { - t.Fatalf("header mismatch: \nexpected %v \ngot %v", expected, headers) + if err := s.checkHeadersAgainstChain(req, headers); err != nil { + t.Fatal(err) } } @@ -407,9 +413,8 @@ func (s *Suite) TestGetBlockBodies(t *utesting.T) { if got, want := resp.RequestId, req.RequestId; got != want { t.Fatalf("unexpected request id in respond", got, want) } - bodies := resp.BlockBodiesResponse - if len(bodies) != len(req.GetBlockBodiesRequest) { - t.Fatalf("wrong bodies in response: expected %d bodies, got %d", len(req.GetBlockBodiesRequest), len(bodies)) + if resp.List.Len() != len(req.GetBlockBodiesRequest) { + t.Fatalf("wrong bodies in response: expected %d bodies, got %d", len(req.GetBlockBodiesRequest), resp.List.Len()) } } @@ -433,7 +438,7 @@ func (s *Suite) TestGetReceipts(t *utesting.T) { } } - // Create block bodies request. + // Create receipts request. req := ð.GetReceiptsPacket{ RequestId: 66, GetReceiptsRequest: (eth.GetReceiptsRequest)(hashes), @@ -449,8 +454,8 @@ func (s *Suite) TestGetReceipts(t *utesting.T) { if got, want := resp.RequestId, req.RequestId; got != want { t.Fatalf("unexpected request id in respond", got, want) } - if len(resp.List) != len(req.GetReceiptsRequest) { - t.Fatalf("wrong bodies in response: expected %d bodies, got %d", len(req.GetReceiptsRequest), len(resp.List)) + if resp.List.Len() != len(req.GetReceiptsRequest) { + t.Fatalf("wrong receipts in response: expected %d receipts, got %d", len(req.GetReceiptsRequest), resp.List.Len()) } } @@ -804,7 +809,11 @@ on another peer connection using GetPooledTransactions.`) if got, want := msg.RequestId, req.RequestId; got != want { t.Fatalf("unexpected request id in response: got %d, want %d", got, want) } - for _, got := range msg.PooledTransactionsResponse { + responseTxs, err := msg.List.Items() + if err != nil { + t.Fatalf("invalid transactions in response: %v", err) + } + for _, got := range responseTxs { if _, exists := set[got.Hash()]; !exists { t.Fatalf("unexpected tx received: %v", got.Hash()) } @@ -976,7 +985,9 @@ func (s *Suite) TestBlobViolations(t *utesting.T) { if err := conn.ReadMsg(ethProto, eth.GetPooledTransactionsMsg, req); err != nil { t.Fatalf("reading pooled tx request failed: %v", err) } - resp := eth.PooledTransactionsPacket{RequestId: req.RequestId, PooledTransactionsResponse: test.resp} + + encTxs, _ := rlp.EncodeToRawList(test.resp) + resp := eth.PooledTransactionsPacket{RequestId: req.RequestId, List: encTxs} if err := conn.Write(ethProto, eth.PooledTransactionsMsg, resp); err != nil { t.Fatalf("writing pooled tx response failed: %v", err) } @@ -1104,7 +1115,8 @@ func (s *Suite) testBadBlobTx(t *utesting.T, tx *types.Transaction, badTx *types // the good peer is connected, and has announced the tx. // proceed to send the incorrect one from the bad peer. - resp := eth.PooledTransactionsPacket{RequestId: req.RequestId, PooledTransactionsResponse: eth.PooledTransactionsResponse(types.Transactions{badTx})} + encTxs, _ := rlp.EncodeToRawList([]*types.Transaction{badTx}) + resp := eth.PooledTransactionsPacket{RequestId: req.RequestId, List: encTxs} if err := conn.Write(ethProto, eth.PooledTransactionsMsg, resp); err != nil { errc <- fmt.Errorf("writing pooled tx response failed: %v", err) return @@ -1164,7 +1176,8 @@ func (s *Suite) testBadBlobTx(t *utesting.T, tx *types.Transaction, badTx *types return } - resp := eth.PooledTransactionsPacket{RequestId: req.RequestId, PooledTransactionsResponse: eth.PooledTransactionsResponse(types.Transactions{tx})} + encTxs, _ := rlp.EncodeToRawList([]*types.Transaction{tx}) + resp := eth.PooledTransactionsPacket{RequestId: req.RequestId, List: encTxs} if err := conn.Write(ethProto, eth.PooledTransactionsMsg, resp); err != nil { errc <- fmt.Errorf("writing pooled tx response failed: %v", err) return diff --git a/cmd/devp2p/internal/ethtest/transaction.go b/cmd/devp2p/internal/ethtest/transaction.go index cbbbbce8d9..8ce26f3e1a 100644 --- a/cmd/devp2p/internal/ethtest/transaction.go +++ b/cmd/devp2p/internal/ethtest/transaction.go @@ -26,6 +26,7 @@ import ( "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/eth/protocols/eth" "github.com/ethereum/go-ethereum/internal/utesting" + "github.com/ethereum/go-ethereum/rlp" ) // sendTxs sends the given transactions to the node and @@ -51,7 +52,8 @@ func (s *Suite) sendTxs(t *utesting.T, txs []*types.Transaction) error { return fmt.Errorf("peering failed: %v", err) } - if err = sendConn.Write(ethProto, eth.TransactionsMsg, eth.TransactionsPacket(txs)); err != nil { + encTxs, _ := rlp.EncodeToRawList(txs) + if err = sendConn.Write(ethProto, eth.TransactionsMsg, eth.TransactionsPacket{RawList: encTxs}); err != nil { return fmt.Errorf("failed to write message to connection: %v", err) } @@ -68,7 +70,8 @@ func (s *Suite) sendTxs(t *utesting.T, txs []*types.Transaction) error { } switch msg := msg.(type) { case *eth.TransactionsPacket: - for _, tx := range *msg { + txs, _ := msg.Items() + for _, tx := range txs { got[tx.Hash()] = true } case *eth.NewPooledTransactionHashesPacket: @@ -80,9 +83,10 @@ func (s *Suite) sendTxs(t *utesting.T, txs []*types.Transaction) error { if err != nil { t.Logf("invalid GetBlockHeaders request: %v", err) } + encHeaders, _ := rlp.EncodeToRawList(headers) recvConn.Write(ethProto, eth.BlockHeadersMsg, ð.BlockHeadersPacket{ - RequestId: msg.RequestId, - BlockHeadersRequest: headers, + RequestId: msg.RequestId, + List: encHeaders, }) default: return fmt.Errorf("unexpected eth wire msg: %s", pretty.Sdump(msg)) @@ -167,9 +171,10 @@ func (s *Suite) sendInvalidTxs(t *utesting.T, txs []*types.Transaction) error { if err != nil { t.Logf("invalid GetBlockHeaders request: %v", err) } + encHeaders, _ := rlp.EncodeToRawList(headers) recvConn.Write(ethProto, eth.BlockHeadersMsg, ð.BlockHeadersPacket{ - RequestId: msg.RequestId, - BlockHeadersRequest: headers, + RequestId: msg.RequestId, + List: encHeaders, }) default: return fmt.Errorf("unexpected eth message: %v", pretty.Sdump(msg)) diff --git a/eth/downloader/downloader_test.go b/eth/downloader/downloader_test.go index 7fa2522a3d..e5d4a7c59b 100644 --- a/eth/downloader/downloader_test.go +++ b/eth/downloader/downloader_test.go @@ -214,10 +214,12 @@ func (dlp *downloadTesterPeer) RequestHeadersByNumber(origin uint64, amount int, func (dlp *downloadTesterPeer) RequestBodies(hashes []common.Hash, sink chan *eth.Response) (*eth.Request, error) { blobs := eth.ServiceGetBlockBodiesQuery(dlp.chain, hashes) - bodies := make([]*eth.BlockBody, len(blobs)) + bodies := make([]*types.Body, len(blobs)) + ethbodies := make([]eth.BlockBody, len(blobs)) for i, blob := range blobs { - bodies[i] = new(eth.BlockBody) + bodies[i] = new(types.Body) rlp.DecodeBytes(blob, bodies[i]) + rlp.DecodeBytes(blob, ðbodies[i]) } var ( txsHashes = make([]common.Hash, len(bodies)) @@ -239,9 +241,13 @@ func (dlp *downloadTesterPeer) RequestBodies(hashes []common.Hash, sink chan *et Peer: dlp.id, } res := ð.Response{ - Req: req, - Res: (*eth.BlockBodiesResponse)(&bodies), - Meta: [][]common.Hash{txsHashes, uncleHashes, withdrawalHashes}, + Req: req, + Res: (*eth.BlockBodiesResponse)(ðbodies), + Meta: eth.BlockBodyHashes{ + TransactionRoots: txsHashes, + UncleHashes: uncleHashes, + WithdrawalRoots: withdrawalHashes, + }, Time: 1, Done: make(chan error, 1), // Ignore the returned status } @@ -290,14 +296,14 @@ func (dlp *downloadTesterPeer) ID() string { // RequestAccountRange fetches a batch of accounts rooted in a specific account // trie, starting with the origin. -func (dlp *downloadTesterPeer) RequestAccountRange(id uint64, root, origin, limit common.Hash, bytes uint64) error { +func (dlp *downloadTesterPeer) RequestAccountRange(id uint64, root, origin, limit common.Hash, bytes int) error { // Create the request and service it req := &snap.GetAccountRangePacket{ ID: id, Root: root, Origin: origin, Limit: limit, - Bytes: bytes, + Bytes: uint64(bytes), } slimaccs, proofs := snap.ServiceGetAccountRangeQuery(dlp.chain, req) @@ -316,7 +322,7 @@ func (dlp *downloadTesterPeer) RequestAccountRange(id uint64, root, origin, limi // RequestStorageRanges fetches a batch of storage slots belonging to one or // more accounts. If slots from only one account is requested, an origin marker // may also be used to retrieve from there. -func (dlp *downloadTesterPeer) RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes uint64) error { +func (dlp *downloadTesterPeer) RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes int) error { // Create the request and service it req := &snap.GetStorageRangesPacket{ ID: id, @@ -324,7 +330,7 @@ func (dlp *downloadTesterPeer) RequestStorageRanges(id uint64, root common.Hash, Root: root, Origin: origin, Limit: limit, - Bytes: bytes, + Bytes: uint64(bytes), } storage, proofs := snap.ServiceGetStorageRangesQuery(dlp.chain, req) @@ -341,25 +347,28 @@ func (dlp *downloadTesterPeer) RequestStorageRanges(id uint64, root common.Hash, } // RequestByteCodes fetches a batch of bytecodes by hash. -func (dlp *downloadTesterPeer) RequestByteCodes(id uint64, hashes []common.Hash, bytes uint64) error { +func (dlp *downloadTesterPeer) RequestByteCodes(id uint64, hashes []common.Hash, bytes int) error { req := &snap.GetByteCodesPacket{ ID: id, Hashes: hashes, - Bytes: bytes, + Bytes: uint64(bytes), } codes := snap.ServiceGetByteCodesQuery(dlp.chain, req) go dlp.dl.downloader.SnapSyncer.OnByteCodes(dlp, id, codes) return nil } -// RequestTrieNodes fetches a batch of account or storage trie nodes rooted in -// a specific state trie. -func (dlp *downloadTesterPeer) RequestTrieNodes(id uint64, root common.Hash, paths []snap.TrieNodePathSet, bytes uint64) error { +// RequestTrieNodes fetches a batch of account or storage trie nodes. +func (dlp *downloadTesterPeer) RequestTrieNodes(id uint64, root common.Hash, count int, paths []snap.TrieNodePathSet, bytes int) error { + encPaths, err := rlp.EncodeToRawList(paths) + if err != nil { + panic(err) + } req := &snap.GetTrieNodesPacket{ ID: id, Root: root, - Paths: paths, - Bytes: bytes, + Paths: encPaths, + Bytes: uint64(bytes), } nodes, _ := snap.ServiceGetTrieNodesQuery(dlp.chain, req, time.Now()) go dlp.dl.downloader.SnapSyncer.OnTrieNodes(dlp, id, nodes) diff --git a/eth/downloader/fetchers_concurrent_bodies.go b/eth/downloader/fetchers_concurrent_bodies.go index 56359b33c9..6a8eb35219 100644 --- a/eth/downloader/fetchers_concurrent_bodies.go +++ b/eth/downloader/fetchers_concurrent_bodies.go @@ -88,15 +88,14 @@ func (q *bodyQueue) request(peer *peerConnection, req *fetchRequest, resCh chan // deliver is responsible for taking a generic response packet from the concurrent // fetcher, unpacking the body data and delivering it to the downloader's queue. func (q *bodyQueue) deliver(peer *peerConnection, packet *eth.Response) (int, error) { - txs, uncles, withdrawals := packet.Res.(*eth.BlockBodiesResponse).Unpack() - hashsets := packet.Meta.([][]common.Hash) // {txs hashes, uncle hashes, withdrawal hashes} - - accepted, err := q.queue.DeliverBodies(peer.id, txs, hashsets[0], uncles, hashsets[1], withdrawals, hashsets[2]) + resp := packet.Res.(*eth.BlockBodiesResponse) + meta := packet.Meta.(eth.BlockBodyHashes) + accepted, err := q.queue.DeliverBodies(peer.id, meta, *resp) switch { - case err == nil && len(txs) == 0: + case err == nil && len(*resp) == 0: peer.log.Trace("Requested bodies delivered") case err == nil: - peer.log.Trace("Delivered new batch of bodies", "count", len(txs), "accepted", accepted) + peer.log.Trace("Delivered new batch of bodies", "count", len(*resp), "accepted", accepted) default: peer.log.Debug("Failed to deliver retrieved bodies", "err", err) } diff --git a/eth/downloader/queue.go b/eth/downloader/queue.go index 76a14345e5..c0cb9b174a 100644 --- a/eth/downloader/queue.go +++ b/eth/downloader/queue.go @@ -29,11 +29,10 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/prque" "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/crypto/kzg4844" "github.com/ethereum/go-ethereum/eth/ethconfig" + "github.com/ethereum/go-ethereum/eth/protocols/eth" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/metrics" - "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/rlp" ) @@ -559,63 +558,54 @@ func (q *queue) expire(peer string, pendPool map[string]*fetchRequest, taskQueue // DeliverBodies injects a block body retrieval response into the results queue. // The method returns the number of blocks bodies accepted from the delivery and // also wakes any threads waiting for data delivery. -func (q *queue) DeliverBodies(id string, txLists [][]*types.Transaction, txListHashes []common.Hash, - uncleLists [][]*types.Header, uncleListHashes []common.Hash, - withdrawalLists [][]*types.Withdrawal, withdrawalListHashes []common.Hash, -) (int, error) { +func (q *queue) DeliverBodies(id string, hashes eth.BlockBodyHashes, bodies []eth.BlockBody) (int, error) { q.lock.Lock() defer q.lock.Unlock() + var txLists [][]*types.Transaction + var uncleLists [][]*types.Header + var withdrawalLists [][]*types.Withdrawal + validate := func(index int, header *types.Header) error { - if txListHashes[index] != header.TxHash { + if hashes.TransactionRoots[index] != header.TxHash { return errInvalidBody } - if uncleListHashes[index] != header.UncleHash { + if hashes.UncleHashes[index] != header.UncleHash { return errInvalidBody } if header.WithdrawalsHash == nil { // nil hash means that withdrawals should not be present in body - if withdrawalLists[index] != nil { + if bodies[index].Withdrawals != nil { return errInvalidBody } } else { // non-nil hash: body must have withdrawals - if withdrawalLists[index] == nil { + if bodies[index].Withdrawals == nil { return errInvalidBody } - if withdrawalListHashes[index] != *header.WithdrawalsHash { + if hashes.WithdrawalRoots[index] != *header.WithdrawalsHash { return errInvalidBody } } - // Blocks must have a number of blobs corresponding to the header gas usage, - // and zero before the Cancun hardfork. - var blobs int - for _, tx := range txLists[index] { - // Count the number of blobs to validate against the header's blobGasUsed - blobs += len(tx.BlobHashes()) - // Validate the data blobs individually too - if tx.Type() == types.BlobTxType { - if len(tx.BlobHashes()) == 0 { - return errInvalidBody - } - for _, hash := range tx.BlobHashes() { - if !kzg4844.IsValidVersionedHash(hash[:]) { - return errInvalidBody - } - } - if tx.BlobTxSidecar() != nil { - return errInvalidBody - } - } + // decode + txs, err := bodies[index].Transactions.Items() + if err != nil { + return fmt.Errorf("%w: bad transactions: %v", errInvalidBody, err) } - if header.BlobGasUsed != nil { - if want := *header.BlobGasUsed / params.BlobTxBlobGasPerBlob; uint64(blobs) != want { // div because the header is surely good vs the body might be bloated - return errInvalidBody + txLists = append(txLists, txs) + uncles, err := bodies[index].Uncles.Items() + if err != nil { + return fmt.Errorf("%w: bad uncles: %v", errInvalidBody, err) + } + uncleLists = append(uncleLists, uncles) + if bodies[index].Withdrawals != nil { + withdrawals, err := bodies[index].Withdrawals.Items() + if err != nil { + return fmt.Errorf("%w: bad withdrawals: %v", errInvalidBody, err) } + withdrawalLists = append(withdrawalLists, withdrawals) } else { - if blobs != 0 { - return errInvalidBody - } + withdrawalLists = append(withdrawalLists, nil) } return nil } @@ -626,8 +616,9 @@ func (q *queue) DeliverBodies(id string, txLists [][]*types.Transaction, txListH result.Withdrawals = withdrawalLists[index] result.SetBodyDone() } + nresults := len(hashes.TransactionRoots) return q.deliver(id, q.blockTaskPool, q.blockTaskQueue, q.blockPendPool, - bodyReqTimer, bodyInMeter, bodyDropMeter, len(txLists), validate, reconstruct) + bodyReqTimer, bodyInMeter, bodyDropMeter, nresults, validate, reconstruct) } // DeliverReceipts injects a receipt retrieval response into the results queue. diff --git a/eth/downloader/queue_test.go b/eth/downloader/queue_test.go index ca71a769de..c7e8a0d1d6 100644 --- a/eth/downloader/queue_test.go +++ b/eth/downloader/queue_test.go @@ -30,8 +30,10 @@ import ( "github.com/ethereum/go-ethereum/consensus/ethash" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/eth/protocols/eth" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" ) @@ -323,26 +325,31 @@ func XTestDelivery(t *testing.T) { emptyList []*types.Header txset [][]*types.Transaction uncleset [][]*types.Header + bodies []eth.BlockBody ) numToSkip := rand.Intn(len(f.Headers)) for _, hdr := range f.Headers[0 : len(f.Headers)-numToSkip] { - txset = append(txset, world.getTransactions(hdr.Number.Uint64())) + txs := world.getTransactions(hdr.Number.Uint64()) + txset = append(txset, txs) uncleset = append(uncleset, emptyList) + txsList, _ := rlp.EncodeToRawList(txs) + bodies = append(bodies, eth.BlockBody{Transactions: txsList}) + } + hashes := eth.BlockBodyHashes{ + TransactionRoots: make([]common.Hash, len(txset)), + UncleHashes: make([]common.Hash, len(uncleset)), + WithdrawalRoots: make([]common.Hash, len(txset)), } - var ( - txsHashes = make([]common.Hash, len(txset)) - uncleHashes = make([]common.Hash, len(uncleset)) - ) hasher := trie.NewStackTrie(nil) for i, txs := range txset { - txsHashes[i] = types.DeriveSha(types.Transactions(txs), hasher) + hashes.TransactionRoots[i] = types.DeriveSha(types.Transactions(txs), hasher) } for i, uncles := range uncleset { - uncleHashes[i] = types.CalcUncleHash(uncles) + hashes.UncleHashes[i] = types.CalcUncleHash(uncles) } + time.Sleep(100 * time.Millisecond) - _, err := q.DeliverBodies(peer.id, txset, txsHashes, uncleset, uncleHashes, nil, nil) - if err != nil { + if _, err := q.DeliverBodies(peer.id, hashes, bodies); err != nil { fmt.Printf("delivered %d bodies %v\n", len(txset), err) } } else { diff --git a/eth/handler_eth.go b/eth/handler_eth.go index 11742b14ad..8704a86af4 100644 --- a/eth/handler_eth.go +++ b/eth/handler_eth.go @@ -20,6 +20,7 @@ import ( "errors" "fmt" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/eth/protocols/eth" @@ -61,19 +62,42 @@ func (h *ethHandler) Handle(peer *eth.Peer, packet eth.Packet) error { return h.txFetcher.Notify(peer.ID(), packet.Types, packet.Sizes, packet.Hashes) case *eth.TransactionsPacket: - for _, tx := range *packet { - if tx.Type() == types.BlobTxType { - return errors.New("disallowed broadcast blob transaction") - } + txs, err := packet.Items() + if err != nil { + return fmt.Errorf("Transactions: %v", err) } - return h.txFetcher.Enqueue(peer.ID(), *packet, false) + if err := handleTransactions(peer, txs, true); err != nil { + return fmt.Errorf("Transactions: %v", err) + } + return h.txFetcher.Enqueue(peer.ID(), txs, false) - case *eth.PooledTransactionsResponse: - // If we receive any blob transactions missing sidecars, or with - // sidecars that don't correspond to the versioned hashes reported - // in the header, disconnect from the sending peer. - for _, tx := range *packet { - if tx.Type() == types.BlobTxType { + case *eth.PooledTransactionsPacket: + txs, err := packet.List.Items() + if err != nil { + return fmt.Errorf("PooledTransactions: %v", err) + } + if err := handleTransactions(peer, txs, false); err != nil { + return fmt.Errorf("PooledTransactions: %v", err) + } + return h.txFetcher.Enqueue(peer.ID(), txs, true) + + default: + return fmt.Errorf("unexpected eth packet type: %T", packet) + } +} + +// handleTransactions marks all given transactions as known to the peer +// and performs basic validations. +func handleTransactions(peer *eth.Peer, list []*types.Transaction, directBroadcast bool) error { + seen := make(map[common.Hash]struct{}) + for _, tx := range list { + if tx.Type() == types.BlobTxType { + if directBroadcast { + return errors.New("disallowed broadcast blob transaction") + } else { + // If we receive any blob transactions missing sidecars, or with + // sidecars that don't correspond to the versioned hashes reported + // in the header, disconnect from the sending peer. if tx.BlobTxSidecar() == nil { return errors.New("received sidecar-less blob transaction") } @@ -82,9 +106,16 @@ func (h *ethHandler) Handle(peer *eth.Peer, packet eth.Packet) error { } } } - return h.txFetcher.Enqueue(peer.ID(), *packet, true) - default: - return fmt.Errorf("unexpected eth packet type: %T", packet) + // Check for duplicates. + hash := tx.Hash() + if _, exists := seen[hash]; exists { + return fmt.Errorf("multiple copies of the same hash %v", hash) + } + seen[hash] = struct{}{} + + // Mark as known. + peer.MarkTransaction(hash) } + return nil } diff --git a/eth/handler_eth_test.go b/eth/handler_eth_test.go index 1343cae03e..0330713071 100644 --- a/eth/handler_eth_test.go +++ b/eth/handler_eth_test.go @@ -60,11 +60,19 @@ func (h *testEthHandler) Handle(peer *eth.Peer, packet eth.Packet) error { return nil case *eth.TransactionsPacket: - h.txBroadcasts.Send(([]*types.Transaction)(*packet)) + txs, err := packet.Items() + if err != nil { + return err + } + h.txBroadcasts.Send(txs) return nil - case *eth.PooledTransactionsResponse: - h.txBroadcasts.Send(([]*types.Transaction)(*packet)) + case *eth.PooledTransactionsPacket: + txs, err := packet.List.Items() + if err != nil { + return err + } + h.txBroadcasts.Send(txs) return nil default: diff --git a/eth/protocols/eth/dispatcher.go b/eth/protocols/eth/dispatcher.go index cba40596fc..3f78fb4646 100644 --- a/eth/protocols/eth/dispatcher.go +++ b/eth/protocols/eth/dispatcher.go @@ -22,6 +22,7 @@ import ( "time" "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/tracker" ) var ( @@ -47,9 +48,10 @@ type Request struct { sink chan *Response // Channel to deliver the response on cancel chan struct{} // Channel to cancel requests ahead of time - code uint64 // Message code of the request packet - want uint64 // Message code of the response packet - data interface{} // Data content of the request packet + code uint64 // Message code of the request packet + want uint64 // Message code of the response packet + numItems int // Number of requested items + data interface{} // Data content of the request packet Peer string // Demultiplexer if cross-peer requests are batched together Sent time.Time // Timestamp when the request was sent @@ -190,19 +192,30 @@ func (p *Peer) dispatchResponse(res *Response, metadata func() interface{}) erro func (p *Peer) dispatcher() { pending := make(map[uint64]*Request) +loop: for { select { case reqOp := <-p.reqDispatch: req := reqOp.req req.Sent = time.Now() - requestTracker.Track(p.id, p.version, req.code, req.want, req.id) - err := p2p.Send(p.rw, req.code, req.data) - reqOp.fail <- err - - if err == nil { - pending[req.id] = req + treq := tracker.Request{ + ID: req.id, + ReqCode: req.code, + RespCode: req.want, + Size: req.numItems, } + if err := p.tracker.Track(treq); err != nil { + reqOp.fail <- err + continue loop + } + if err := p2p.Send(p.rw, req.code, req.data); err != nil { + reqOp.fail <- err + continue loop + } + + pending[req.id] = req + reqOp.fail <- nil case cancelOp := <-p.reqCancel: // Retrieve the pending request to cancel and short circuit if it @@ -220,9 +233,6 @@ func (p *Peer) dispatcher() { res := resOp.res res.Req = pending[res.id] - // Independent if the request exists or not, track this packet - requestTracker.Fulfil(p.id, p.version, res.code, res.id) - switch { case res.Req == nil: // Response arrived with an untracked ID. Since even cancelled @@ -249,6 +259,7 @@ func (p *Peer) dispatcher() { } case <-p.term: + p.tracker.Stop() return } } diff --git a/eth/protocols/eth/handler_test.go b/eth/protocols/eth/handler_test.go index 65c491f815..8f7f82c3a1 100644 --- a/eth/protocols/eth/handler_test.go +++ b/eth/protocols/eth/handler_test.go @@ -23,9 +23,11 @@ import ( "math/big" "math/rand" "os" + "reflect" "testing" "time" + "github.com/davecgh/go-spew/spew" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/consensus/beacon" "github.com/ethereum/go-ethereum/consensus/ethash" @@ -42,6 +44,7 @@ import ( "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/trie" "github.com/holiman/uint256" ) @@ -360,8 +363,8 @@ func testGetBlockHeaders(t *testing.T, protocol uint) { GetBlockHeadersRequest: tt.query, }) if err := p2p.ExpectMsg(peer.app, BlockHeadersMsg, &BlockHeadersPacket{ - RequestId: 123, - BlockHeadersRequest: headers, + RequestId: 123, + List: encodeRL(headers), }); err != nil { t.Errorf("test %d: headers mismatch: %v", i, err) } @@ -374,7 +377,7 @@ func testGetBlockHeaders(t *testing.T, protocol uint) { RequestId: 456, GetBlockHeadersRequest: tt.query, }) - expected := &BlockHeadersPacket{RequestId: 456, BlockHeadersRequest: headers} + expected := &BlockHeadersPacket{RequestId: 456, List: encodeRL(headers)} if err := p2p.ExpectMsg(peer.app, BlockHeadersMsg, expected); err != nil { t.Errorf("test %d by hash: headers mismatch: %v", i, err) } @@ -437,7 +440,7 @@ func testGetBlockBodies(t *testing.T, protocol uint) { // Collect the hashes to request, and the response to expect var ( hashes []common.Hash - bodies []*BlockBody + bodies []BlockBody seen = make(map[int64]bool) ) for j := 0; j < tt.random; j++ { @@ -449,7 +452,7 @@ func testGetBlockBodies(t *testing.T, protocol uint) { block := backend.chain.GetBlockByNumber(uint64(num)) hashes = append(hashes, block.Hash()) if len(bodies) < tt.expected { - bodies = append(bodies, &BlockBody{Transactions: block.Transactions(), Uncles: block.Uncles(), Withdrawals: block.Withdrawals()}) + bodies = append(bodies, encodeBody(block)) } break } @@ -459,7 +462,7 @@ func testGetBlockBodies(t *testing.T, protocol uint) { hashes = append(hashes, hash) if tt.available[j] && len(bodies) < tt.expected { block := backend.chain.GetBlockByHash(hash) - bodies = append(bodies, &BlockBody{Transactions: block.Transactions(), Uncles: block.Uncles(), Withdrawals: block.Withdrawals()}) + bodies = append(bodies, encodeBody(block)) } } @@ -469,14 +472,69 @@ func testGetBlockBodies(t *testing.T, protocol uint) { GetBlockBodiesRequest: hashes, }) if err := p2p.ExpectMsg(peer.app, BlockBodiesMsg, &BlockBodiesPacket{ - RequestId: 123, - BlockBodiesResponse: bodies, + RequestId: 123, + List: encodeRL(bodies), }); err != nil { t.Fatalf("test %d: bodies mismatch: %v", i, err) } } } +func encodeBody(b *types.Block) BlockBody { + body := BlockBody{ + Transactions: encodeRL([]*types.Transaction(b.Transactions())), + Uncles: encodeRL(b.Uncles()), + } + if b.Withdrawals() != nil { + wd := encodeRL([]*types.Withdrawal(b.Withdrawals())) + body.Withdrawals = &wd + } + return body +} + +func TestHashBody(t *testing.T) { + key, _ := crypto.HexToECDSA("8a1f9a8f95be41cd7ccb6168179afb4504aefe388d1e14474d32c45c72ce7b7a") + signer := types.NewCancunSigner(big.NewInt(1)) + + // create block 1 + header := &types.Header{Number: big.NewInt(11)} + txs := []*types.Transaction{ + types.MustSignNewTx(key, signer, &types.DynamicFeeTx{ + ChainID: big.NewInt(1), + Nonce: 1, + Data: []byte("testing"), + }), + types.MustSignNewTx(key, signer, &types.LegacyTx{ + Nonce: 2, + Data: []byte("testing"), + }), + } + uncles := []*types.Header{{Number: big.NewInt(10)}} + body1 := &types.Body{Transactions: txs, Uncles: uncles} + block1 := types.NewBlock(header, body1, nil, trie.NewStackTrie(nil)) + + // create block 2 (has withdrawals) + header2 := &types.Header{Number: big.NewInt(12)} + body2 := &types.Body{ + Withdrawals: []*types.Withdrawal{{Index: 10}, {Index: 11}}, + } + block2 := types.NewBlock(header2, body2, nil, trie.NewStackTrie(nil)) + + expectedHashes := BlockBodyHashes{ + TransactionRoots: []common.Hash{block1.TxHash(), block2.TxHash()}, + WithdrawalRoots: []common.Hash{common.Hash{}, *block2.Header().WithdrawalsHash}, + UncleHashes: []common.Hash{block1.UncleHash(), block2.UncleHash()}, + } + + // compute hash like protocol handler does + protocolBodies := []BlockBody{encodeBody(block1), encodeBody(block2)} + hashes := hashBodyParts(protocolBodies) + if !reflect.DeepEqual(hashes, expectedHashes) { + t.Errorf("wrong hashes: %s", spew.Sdump(hashes)) + t.Logf("expected: %s", spew.Sdump(expectedHashes)) + } +} + // Tests that the transaction receipts can be retrieved based on hashes. func TestGetBlockReceipts68(t *testing.T) { testGetBlockReceipts(t, ETH68) } @@ -528,13 +586,13 @@ func testGetBlockReceipts(t *testing.T, protocol uint) { // Collect the hashes to request, and the response to expect var ( hashes []common.Hash - receipts []*ReceiptList68 + receipts rlp.RawList[*ReceiptList68] ) for i := uint64(0); i <= backend.chain.CurrentBlock().Number.Uint64(); i++ { block := backend.chain.GetBlockByNumber(i) hashes = append(hashes, block.Hash()) trs := backend.chain.GetReceiptsByHash(block.Hash()) - receipts = append(receipts, NewReceiptList68(trs)) + receipts.Append(NewReceiptList68(trs)) } // Send the hash request and verify the response @@ -688,10 +746,18 @@ func testGetPooledTransaction(t *testing.T, blobTx bool) { RequestId: 123, GetPooledTransactionsRequest: []common.Hash{tx.Hash()}, }) - if err := p2p.ExpectMsg(peer.app, PooledTransactionsMsg, PooledTransactionsPacket{ - RequestId: 123, - PooledTransactionsResponse: []*types.Transaction{tx}, + if err := p2p.ExpectMsg(peer.app, PooledTransactionsMsg, &PooledTransactionsPacket{ + RequestId: 123, + List: encodeRL([]*types.Transaction{tx}), }); err != nil { t.Errorf("pooled transaction mismatch: %v", err) } } + +func encodeRL[T any](slice []T) rlp.RawList[T] { + rl, err := rlp.EncodeToRawList(slice) + if err != nil { + panic(err) + } + return rl +} diff --git a/eth/protocols/eth/handlers.go b/eth/protocols/eth/handlers.go index aad3353d88..7f1ccc360d 100644 --- a/eth/protocols/eth/handlers.go +++ b/eth/protocols/eth/handlers.go @@ -17,23 +17,22 @@ package eth import ( + "bytes" "encoding/json" "errors" "fmt" - "time" + "math" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p/tracker" "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" ) -// requestTracker is a singleton tracker for eth/66 and newer request times. -var requestTracker = tracker.New(ProtocolName, 5*time.Minute) - func handleGetBlockHeaders(backend Backend, msg Decoder, peer *Peer) error { // Decode the complex header query var query GetBlockHeadersPacket @@ -356,9 +355,18 @@ func handleBlockHeaders(backend Backend, msg Decoder, peer *Peer) error { if err := msg.Decode(res); err != nil { return err } + tresp := tracker.Response{ID: res.RequestId, MsgCode: BlockHeadersMsg, Size: res.List.Len()} + if err := peer.tracker.Fulfil(tresp); err != nil { + return fmt.Errorf("BlockHeaders: %w", err) + } + headers, err := res.List.Items() + if err != nil { + return fmt.Errorf("BlockHeaders: %w", err) + } + metadata := func() interface{} { - hashes := make([]common.Hash, len(res.BlockHeadersRequest)) - for i, header := range res.BlockHeadersRequest { + hashes := make([]common.Hash, len(headers)) + for i, header := range headers { hashes[i] = header.Hash() } return hashes @@ -366,7 +374,7 @@ func handleBlockHeaders(backend Backend, msg Decoder, peer *Peer) error { return peer.dispatchResponse(&Response{ id: res.RequestId, code: BlockHeadersMsg, - Res: &res.BlockHeadersRequest, + Res: (*BlockHeadersRequest)(&headers), }, metadata) } @@ -376,53 +384,150 @@ func handleBlockBodies(backend Backend, msg Decoder, peer *Peer) error { if err := msg.Decode(res); err != nil { return err } - metadata := func() interface{} { - var ( - txsHashes = make([]common.Hash, len(res.BlockBodiesResponse)) - uncleHashes = make([]common.Hash, len(res.BlockBodiesResponse)) - withdrawalHashes = make([]common.Hash, len(res.BlockBodiesResponse)) - ) - hasher := trie.NewStackTrie(nil) - for i, body := range res.BlockBodiesResponse { - txsHashes[i] = types.DeriveSha(types.Transactions(body.Transactions), hasher) - uncleHashes[i] = types.CalcUncleHash(body.Uncles) - if body.Withdrawals != nil { - withdrawalHashes[i] = types.DeriveSha(types.Withdrawals(body.Withdrawals), hasher) - } - } - return [][]common.Hash{txsHashes, uncleHashes, withdrawalHashes} + + // Check against the request. + length := res.List.Len() + tresp := tracker.Response{ID: res.RequestId, MsgCode: BlockBodiesMsg, Size: length} + if err := peer.tracker.Fulfil(tresp); err != nil { + return fmt.Errorf("BlockBodies: %w", err) } + + // Collect items and dispatch. + items, err := res.List.Items() + if err != nil { + return fmt.Errorf("BlockBodies: %w", err) + } + metadata := func() any { return hashBodyParts(items) } return peer.dispatchResponse(&Response{ id: res.RequestId, code: BlockBodiesMsg, - Res: &res.BlockBodiesResponse, + Res: (*BlockBodiesResponse)(&items), }, metadata) } +// BlockBodyHashes contains the lists of block body part roots for a list of block bodies. +type BlockBodyHashes struct { + TransactionRoots []common.Hash + WithdrawalRoots []common.Hash + UncleHashes []common.Hash +} + +func hashBodyParts(items []BlockBody) BlockBodyHashes { + h := BlockBodyHashes{ + TransactionRoots: make([]common.Hash, len(items)), + WithdrawalRoots: make([]common.Hash, len(items)), + UncleHashes: make([]common.Hash, len(items)), + } + hasher := trie.NewStackTrie(nil) + for i, body := range items { + // txs + txsList := newDerivableRawList(&body.Transactions, writeTxForHash) + h.TransactionRoots[i] = types.DeriveSha(txsList, hasher) + // uncles + if body.Uncles.Len() == 0 { + h.UncleHashes[i] = types.EmptyUncleHash + } else { + h.UncleHashes[i] = crypto.Keccak256Hash(body.Uncles.Bytes()) + } + // withdrawals + if body.Withdrawals != nil { + wdlist := newDerivableRawList(body.Withdrawals, nil) + h.WithdrawalRoots[i] = types.DeriveSha(wdlist, hasher) + } + } + return h +} + +// derivableRawList implements types.DerivableList for a serialized RLP list. +type derivableRawList struct { + data []byte + offsets []uint32 + write func([]byte, *bytes.Buffer) +} + +func newDerivableRawList[T any](list *rlp.RawList[T], write func([]byte, *bytes.Buffer)) *derivableRawList { + dl := derivableRawList{data: list.Content(), write: write} + if dl.write == nil { + // default transform is identity + dl.write = func(b []byte, buf *bytes.Buffer) { buf.Write(b) } + } + // Assert to ensure 32-bit offsets are valid. This can never trigger + // unless a block body component or p2p receipt list is larger than 4GB. + if uint(len(dl.data)) > math.MaxUint32 { + panic("list data too big for derivableRawList") + } + it := list.ContentIterator() + dl.offsets = make([]uint32, list.Len()) + for i := 0; it.Next(); i++ { + dl.offsets[i] = uint32(it.Offset()) + } + return &dl +} + +// Len returns the number of items in the list. +func (dl *derivableRawList) Len() int { + return len(dl.offsets) +} + +// EncodeIndex writes the i'th item to the buffer. +func (dl *derivableRawList) EncodeIndex(i int, buf *bytes.Buffer) { + start := dl.offsets[i] + end := uint32(len(dl.data)) + if i != len(dl.offsets)-1 { + end = dl.offsets[i+1] + } + dl.write(dl.data[start:end], buf) +} + +// writeTxForHash changes a transaction in 'network encoding' into the format used for +// the transactions MPT. +func writeTxForHash(tx []byte, buf *bytes.Buffer) { + k, content, _, _ := rlp.Split(tx) + if k == rlp.List { + buf.Write(tx) // legacy tx + } else { + buf.Write(content) // typed tx + } +} + func handleReceipts[L ReceiptsList](backend Backend, msg Decoder, peer *Peer) error { // A batch of receipts arrived to one of our previous requests res := new(ReceiptsPacket[L]) if err := msg.Decode(res); err != nil { return err } + + tresp := tracker.Response{ID: res.RequestId, MsgCode: ReceiptsMsg, Size: res.List.Len()} + if err := peer.tracker.Fulfil(tresp); err != nil { + return fmt.Errorf("Receipts: %w", err) + } + // Assign temporary hashing buffer to each list item, the same buffer is shared // between all receipt list instances. + receiptLists, err := res.List.Items() + if err != nil { + return fmt.Errorf("Receipts: %w", err) + } buffers := new(receiptListBuffers) - for i := range res.List { - res.List[i].setBuffers(buffers) + for i := range receiptLists { + receiptLists[i].setBuffers(buffers) } metadata := func() interface{} { hasher := trie.NewStackTrie(nil) - hashes := make([]common.Hash, len(res.List)) - for i := range res.List { - hashes[i] = types.DeriveSha(res.List[i], hasher) + hashes := make([]common.Hash, len(receiptLists)) + for i := range receiptLists { + hashes[i] = types.DeriveSha(receiptLists[i].Derivable(), hasher) } return hashes } var enc ReceiptsRLPResponse - for i := range res.List { - enc = append(enc, res.List[i].EncodeForStorage()) + for i := range receiptLists { + encReceipts, err := receiptLists[i].EncodeForStorage() + if err != nil { + return fmt.Errorf("Receipts: invalid list %d: %v", i, err) + } + enc = append(enc, encReceipts) } return peer.dispatchResponse(&Response{ id: res.RequestId, @@ -446,7 +551,7 @@ func handleNewPooledTransactionHashes(backend Backend, msg Decoder, peer *Peer) } // Schedule all the unknown hashes for retrieval for _, hash := range ann.Hashes { - peer.markTransaction(hash) + peer.MarkTransaction(hash) } return backend.Handle(peer, ann) } @@ -494,19 +599,8 @@ func handleTransactions(backend Backend, msg Decoder, peer *Peer) error { if err := msg.Decode(&txs); err != nil { return err } - // Duplicate transactions are not allowed - seen := make(map[common.Hash]struct{}) - for i, tx := range txs { - // Validate and mark the remote transaction - if tx == nil { - return fmt.Errorf("Transactions: transaction %d is nil", i) - } - hash := tx.Hash() - if _, exists := seen[hash]; exists { - return fmt.Errorf("Transactions: multiple copies of the same hash %v", hash) - } - seen[hash] = struct{}{} - peer.markTransaction(hash) + if txs.Len() > maxTransactionAnnouncements { + return fmt.Errorf("too many transactions") } return backend.Handle(peer, &txs) } @@ -516,28 +610,22 @@ func handlePooledTransactions(backend Backend, msg Decoder, peer *Peer) error { if !backend.AcceptTxs() { return nil } - // Transactions can be processed, parse all of them and deliver to the pool - var txs PooledTransactionsPacket - if err := msg.Decode(&txs); err != nil { + + // Check against request and decode. + var resp PooledTransactionsPacket + if err := msg.Decode(&resp); err != nil { return err } - // Duplicate transactions are not allowed - seen := make(map[common.Hash]struct{}) - for i, tx := range txs.PooledTransactionsResponse { - // Validate and mark the remote transaction - if tx == nil { - return fmt.Errorf("PooledTransactions: transaction %d is nil", i) - } - hash := tx.Hash() - if _, exists := seen[hash]; exists { - return fmt.Errorf("PooledTransactions: multiple copies of the same hash %v", hash) - } - seen[hash] = struct{}{} - peer.markTransaction(hash) + tresp := tracker.Response{ + ID: resp.RequestId, + MsgCode: PooledTransactionsMsg, + Size: resp.List.Len(), + } + if err := peer.tracker.Fulfil(tresp); err != nil { + return fmt.Errorf("PooledTransactions: %w", err) } - requestTracker.Fulfil(peer.id, peer.version, PooledTransactionsMsg, txs.RequestId) - return backend.Handle(peer, &txs.PooledTransactionsResponse) + return backend.Handle(peer, &resp) } func handleBlockRangeUpdate(backend Backend, msg Decoder, peer *Peer) error { diff --git a/eth/protocols/eth/peer.go b/eth/protocols/eth/peer.go index df20c672c0..4ea2d7158c 100644 --- a/eth/protocols/eth/peer.go +++ b/eth/protocols/eth/peer.go @@ -19,11 +19,13 @@ package eth import ( "math/rand" "sync/atomic" + "time" mapset "github.com/deckarep/golang-set/v2" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/tracker" "github.com/ethereum/go-ethereum/rlp" ) @@ -43,9 +45,10 @@ const ( // Peer is a collection of relevant information we have about a `eth` peer. type Peer struct { + *p2p.Peer // The embedded P2P package peer + id string // Unique ID for the peer, cached - *p2p.Peer // The embedded P2P package peer rw p2p.MsgReadWriter // Input/output streams for snap version uint // Protocol version negotiated lastRange atomic.Pointer[BlockRangeUpdatePacket] @@ -55,6 +58,7 @@ type Peer struct { txBroadcast chan []common.Hash // Channel used to queue transaction propagation requests txAnnounce chan []common.Hash // Channel used to queue transaction announcement requests + tracker *tracker.Tracker reqDispatch chan *request // Dispatch channel to send requests and track then until fulfillment reqCancel chan *cancel // Dispatch channel to cancel pending requests and untrack them resDispatch chan *response // Dispatch channel to fulfil pending requests and untrack them @@ -65,14 +69,17 @@ type Peer struct { // NewPeer creates a wrapper for a network connection and negotiated protocol // version. func NewPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter, txpool TxPool) *Peer { + cap := p2p.Cap{Name: ProtocolName, Version: version} + id := p.ID().String() peer := &Peer{ - id: p.ID().String(), + id: id, Peer: p, rw: rw, version: version, knownTxs: newKnownCache(maxKnownTxs), txBroadcast: make(chan []common.Hash), txAnnounce: make(chan []common.Hash), + tracker: tracker.New(cap, id, 5*time.Minute), reqDispatch: make(chan *request), reqCancel: make(chan *cancel), resDispatch: make(chan *response), @@ -115,9 +122,9 @@ func (p *Peer) KnownTransaction(hash common.Hash) bool { return p.knownTxs.Contains(hash) } -// markTransaction marks a transaction as known for the peer, ensuring that it +// MarkTransaction marks a transaction as known for the peer, ensuring that it // will never be propagated to this particular peer. -func (p *Peer) markTransaction(hash common.Hash) { +func (p *Peer) MarkTransaction(hash common.Hash) { // If we reached the memory allowance, drop a previously known transaction hash p.knownTxs.Add(hash) } @@ -222,10 +229,11 @@ func (p *Peer) RequestOneHeader(hash common.Hash, sink chan *Response) (*Request id := rand.Uint64() req := &Request{ - id: id, - sink: sink, - code: GetBlockHeadersMsg, - want: BlockHeadersMsg, + id: id, + sink: sink, + code: GetBlockHeadersMsg, + want: BlockHeadersMsg, + numItems: 1, data: &GetBlockHeadersPacket{ RequestId: id, GetBlockHeadersRequest: &GetBlockHeadersRequest{ @@ -249,10 +257,11 @@ func (p *Peer) RequestHeadersByHash(origin common.Hash, amount int, skip int, re id := rand.Uint64() req := &Request{ - id: id, - sink: sink, - code: GetBlockHeadersMsg, - want: BlockHeadersMsg, + id: id, + sink: sink, + code: GetBlockHeadersMsg, + want: BlockHeadersMsg, + numItems: amount, data: &GetBlockHeadersPacket{ RequestId: id, GetBlockHeadersRequest: &GetBlockHeadersRequest{ @@ -276,10 +285,11 @@ func (p *Peer) RequestHeadersByNumber(origin uint64, amount int, skip int, rever id := rand.Uint64() req := &Request{ - id: id, - sink: sink, - code: GetBlockHeadersMsg, - want: BlockHeadersMsg, + id: id, + sink: sink, + code: GetBlockHeadersMsg, + want: BlockHeadersMsg, + numItems: amount, data: &GetBlockHeadersPacket{ RequestId: id, GetBlockHeadersRequest: &GetBlockHeadersRequest{ @@ -303,10 +313,11 @@ func (p *Peer) RequestBodies(hashes []common.Hash, sink chan *Response) (*Reques id := rand.Uint64() req := &Request{ - id: id, - sink: sink, - code: GetBlockBodiesMsg, - want: BlockBodiesMsg, + id: id, + sink: sink, + code: GetBlockBodiesMsg, + want: BlockBodiesMsg, + numItems: len(hashes), data: &GetBlockBodiesPacket{ RequestId: id, GetBlockBodiesRequest: hashes, @@ -324,10 +335,11 @@ func (p *Peer) RequestReceipts(hashes []common.Hash, sink chan *Response) (*Requ id := rand.Uint64() req := &Request{ - id: id, - sink: sink, - code: GetReceiptsMsg, - want: ReceiptsMsg, + id: id, + sink: sink, + code: GetReceiptsMsg, + want: ReceiptsMsg, + numItems: len(hashes), data: &GetReceiptsPacket{ RequestId: id, GetReceiptsRequest: hashes, @@ -344,7 +356,15 @@ func (p *Peer) RequestTxs(hashes []common.Hash) error { p.Log().Trace("Fetching batch of transactions", "count", len(hashes)) id := rand.Uint64() - requestTracker.Track(p.id, p.version, GetPooledTransactionsMsg, PooledTransactionsMsg, id) + err := p.tracker.Track(tracker.Request{ + ID: id, + ReqCode: GetPooledTransactionsMsg, + RespCode: PooledTransactionsMsg, + Size: len(hashes), + }) + if err != nil { + return err + } return p2p.Send(p.rw, GetPooledTransactionsMsg, &GetPooledTransactionsPacket{ RequestId: id, GetPooledTransactionsRequest: hashes, diff --git a/eth/protocols/eth/protocol.go b/eth/protocols/eth/protocol.go index 7c41e7a996..6ab800f4f4 100644 --- a/eth/protocols/eth/protocol.go +++ b/eth/protocols/eth/protocol.go @@ -49,6 +49,9 @@ var protocolLengths = map[uint]uint64{ETH68: 17, ETH69: 18} // maxMessageSize is the maximum cap on the size of a protocol message. const maxMessageSize = 10 * 1024 * 1024 +// This is the maximum number of transactions in a Transactions message. +const maxTransactionAnnouncements = 5000 + const ( StatusMsg = 0x00 NewBlockHashesMsg = 0x01 @@ -127,7 +130,9 @@ func (p *NewBlockHashesPacket) Unpack() ([]common.Hash, []uint64) { } // TransactionsPacket is the network packet for broadcasting new transactions. -type TransactionsPacket []*types.Transaction +type TransactionsPacket struct { + rlp.RawList[*types.Transaction] +} // GetBlockHeadersRequest represents a block header query. type GetBlockHeadersRequest struct { @@ -185,7 +190,7 @@ type BlockHeadersRequest []*types.Header // BlockHeadersPacket represents a block header response over with request ID wrapping. type BlockHeadersPacket struct { RequestId uint64 - BlockHeadersRequest + List rlp.RawList[*types.Header] } // BlockHeadersRLPResponse represents a block header response, to use when we already @@ -213,14 +218,11 @@ type GetBlockBodiesPacket struct { GetBlockBodiesRequest } -// BlockBodiesResponse is the network packet for block content distribution. -type BlockBodiesResponse []*BlockBody - // BlockBodiesPacket is the network packet for block content distribution with // request ID wrapping. type BlockBodiesPacket struct { RequestId uint64 - BlockBodiesResponse + List rlp.RawList[BlockBody] } // BlockBodiesRLPResponse is used for replying to block body requests, in cases @@ -234,25 +236,14 @@ type BlockBodiesRLPPacket struct { BlockBodiesRLPResponse } +// BlockBodiesResponse is the network packet for block content distribution. +type BlockBodiesResponse []BlockBody + // BlockBody represents the data content of a single block. type BlockBody struct { - Transactions []*types.Transaction // Transactions contained within a block - Uncles []*types.Header // Uncles contained within a block - Withdrawals []*types.Withdrawal `rlp:"optional"` // Withdrawals contained within a block -} - -// Unpack retrieves the transactions and uncles from the range packet and returns -// them in a split flat format that's more consistent with the internal data structures. -func (p *BlockBodiesResponse) Unpack() ([][]*types.Transaction, [][]*types.Header, [][]*types.Withdrawal) { - var ( - txset = make([][]*types.Transaction, len(*p)) - uncleset = make([][]*types.Header, len(*p)) - withdrawalset = make([][]*types.Withdrawal, len(*p)) - ) - for i, body := range *p { - txset[i], uncleset[i], withdrawalset[i] = body.Transactions, body.Uncles, body.Withdrawals - } - return txset, uncleset, withdrawalset + Transactions rlp.RawList[*types.Transaction] + Uncles rlp.RawList[*types.Header] + Withdrawals *rlp.RawList[*types.Withdrawal] `rlp:"optional"` } // GetReceiptsRequest represents a block receipts query. @@ -271,15 +262,15 @@ type ReceiptsResponse []types.Receipts type ReceiptsList interface { *ReceiptList68 | *ReceiptList69 setBuffers(*receiptListBuffers) - EncodeForStorage() rlp.RawValue - types.DerivableList + EncodeForStorage() (rlp.RawValue, error) + Derivable() types.DerivableList } // ReceiptsPacket is the network packet for block receipts distribution with // request ID wrapping. type ReceiptsPacket[L ReceiptsList] struct { RequestId uint64 - List []L + List rlp.RawList[L] } // ReceiptsRLPResponse is used for receipts, when we already have it encoded @@ -314,7 +305,7 @@ type PooledTransactionsResponse []*types.Transaction // with request ID wrapping. type PooledTransactionsPacket struct { RequestId uint64 - PooledTransactionsResponse + List rlp.RawList[*types.Transaction] } // PooledTransactionsRLPResponse is the network packet for transaction distribution, used @@ -367,8 +358,8 @@ func (*NewPooledTransactionHashesPacket) Kind() byte { return NewPooledTransac func (*GetPooledTransactionsRequest) Name() string { return "GetPooledTransactions" } func (*GetPooledTransactionsRequest) Kind() byte { return GetPooledTransactionsMsg } -func (*PooledTransactionsResponse) Name() string { return "PooledTransactions" } -func (*PooledTransactionsResponse) Kind() byte { return PooledTransactionsMsg } +func (*PooledTransactionsPacket) Name() string { return "PooledTransactions" } +func (*PooledTransactionsPacket) Kind() byte { return PooledTransactionsMsg } func (*GetReceiptsRequest) Name() string { return "GetReceipts" } func (*GetReceiptsRequest) Kind() byte { return GetReceiptsMsg } diff --git a/eth/protocols/eth/protocol_test.go b/eth/protocols/eth/protocol_test.go index 8a2559a6c5..e37d72dcd6 100644 --- a/eth/protocols/eth/protocol_test.go +++ b/eth/protocols/eth/protocol_test.go @@ -78,34 +78,34 @@ func TestEmptyMessages(t *testing.T) { for i, msg := range []any{ // Headers GetBlockHeadersPacket{1111, nil}, - BlockHeadersPacket{1111, nil}, // Bodies GetBlockBodiesPacket{1111, nil}, - BlockBodiesPacket{1111, nil}, BlockBodiesRLPPacket{1111, nil}, // Receipts GetReceiptsPacket{1111, nil}, // Transactions GetPooledTransactionsPacket{1111, nil}, - PooledTransactionsPacket{1111, nil}, PooledTransactionsRLPPacket{1111, nil}, // Headers - BlockHeadersPacket{1111, BlockHeadersRequest([]*types.Header{})}, + BlockHeadersPacket{1111, encodeRL([]*types.Header{})}, // Bodies GetBlockBodiesPacket{1111, GetBlockBodiesRequest([]common.Hash{})}, - BlockBodiesPacket{1111, BlockBodiesResponse([]*BlockBody{})}, + BlockBodiesPacket{1111, encodeRL([]BlockBody{})}, BlockBodiesRLPPacket{1111, BlockBodiesRLPResponse([]rlp.RawValue{})}, // Receipts GetReceiptsPacket{1111, GetReceiptsRequest([]common.Hash{})}, - ReceiptsPacket[*ReceiptList68]{1111, []*ReceiptList68{}}, - ReceiptsPacket[*ReceiptList69]{1111, []*ReceiptList69{}}, + ReceiptsPacket[*ReceiptList68]{1111, encodeRL([]*ReceiptList68{})}, + ReceiptsPacket[*ReceiptList69]{1111, encodeRL([]*ReceiptList69{})}, // Transactions GetPooledTransactionsPacket{1111, GetPooledTransactionsRequest([]common.Hash{})}, - PooledTransactionsPacket{1111, PooledTransactionsResponse([]*types.Transaction{})}, + PooledTransactionsPacket{1111, encodeRL([]*types.Transaction{})}, PooledTransactionsRLPPacket{1111, PooledTransactionsRLPResponse([]rlp.RawValue{})}, } { - if have, _ := rlp.EncodeToBytes(msg); !bytes.Equal(have, want) { + have, err := rlp.EncodeToBytes(msg) + if err != nil { + t.Errorf("test %d, type %T, error: %v", i, msg, err) + } else if !bytes.Equal(have, want) { t.Errorf("test %d, type %T, have\n\t%x\nwant\n\t%x", i, msg, have, want) } } @@ -116,7 +116,7 @@ func TestMessages(t *testing.T) { // Some basic structs used during testing var ( header *types.Header - blockBody *BlockBody + blockBody BlockBody blockBodyRlp rlp.RawValue txs []*types.Transaction txRlps []rlp.RawValue @@ -150,9 +150,9 @@ func TestMessages(t *testing.T) { } } // init the block body data, both object and rlp form - blockBody = &BlockBody{ - Transactions: txs, - Uncles: []*types.Header{header}, + blockBody = BlockBody{ + Transactions: encodeRL(txs), + Uncles: encodeRL([]*types.Header{header}), } blockBodyRlp, err = rlp.EncodeToBytes(blockBody) if err != nil { @@ -211,7 +211,7 @@ func TestMessages(t *testing.T) { common.FromHex("ca820457c682270f050580"), }, { - BlockHeadersPacket{1111, BlockHeadersRequest{header}}, + BlockHeadersPacket{1111, encodeRL([]*types.Header{header})}, common.FromHex("f90202820457f901fcf901f9a00000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000940000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000b90100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000008208ae820d0582115c8215b3821a0a827788a00000000000000000000000000000000000000000000000000000000000000000880000000000000000"), }, { @@ -219,7 +219,7 @@ func TestMessages(t *testing.T) { common.FromHex("f847820457f842a000000000000000000000000000000000000000000000000000000000deadc0dea000000000000000000000000000000000000000000000000000000000feedbeef"), }, { - BlockBodiesPacket{1111, BlockBodiesResponse([]*BlockBody{blockBody})}, + BlockBodiesPacket{1111, encodeRL([]BlockBody{blockBody})}, common.FromHex("f902dc820457f902d6f902d3f8d2f867088504a817c8088302e2489435353535353535353535353535353535353535358202008025a064b1702d9298fee62dfeccc57d322a463ad55ca201256d01f62b45b2e1c21c12a064b1702d9298fee62dfeccc57d322a463ad55ca201256d01f62b45b2e1c21c10f867098504a817c809830334509435353535353535353535353535353535353535358202d98025a052f8f61201b2b11a78d6e866abc9c3db2ae8631fa656bfe5cb53668255367afba052f8f61201b2b11a78d6e866abc9c3db2ae8631fa656bfe5cb53668255367afbf901fcf901f9a00000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000940000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000b90100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000008208ae820d0582115c8215b3821a0a827788a00000000000000000000000000000000000000000000000000000000000000000880000000000000000"), }, { // Identical to non-rlp-shortcut version @@ -231,7 +231,7 @@ func TestMessages(t *testing.T) { common.FromHex("f847820457f842a000000000000000000000000000000000000000000000000000000000deadc0dea000000000000000000000000000000000000000000000000000000000feedbeef"), }, { - ReceiptsPacket[*ReceiptList68]{1111, []*ReceiptList68{NewReceiptList68(receipts)}}, + ReceiptsPacket[*ReceiptList68]{1111, encodeRL([]*ReceiptList68{NewReceiptList68(receipts)})}, common.FromHex("f902e6820457f902e0f902ddf901688082014db9010000000000000010000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000014000000000000000000000000000000000000000000000000000000000000000000000000000010000000000000000000000000004000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000000000000000000000000400000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000f85ff85d940000000000000000000000000000000000000011f842a0000000000000000000000000000000000000000000000000000000000000deada0000000000000000000000000000000000000000000000000000000000000beef830100ffb9016f01f9016b018201bcb9010000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000800000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000040000000001000000000000000000000000000000000000000000000000040000000000000000000000000004000000000000000000000000000000000000000000000000000000008000400000000000000000000000000000000000000000000000000000000000000000000000000000040f862f860940000000000000000000000000000000000000022f842a00000000000000000000000000000000000000000000000000000000000005668a0000000000000000000000000000000000000000000000000000000000000977386020f0f0f0608"), }, { @@ -240,7 +240,7 @@ func TestMessages(t *testing.T) { common.FromHex("f902e6820457f902e0f902ddf901688082014db9010000000000000010000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000014000000000000000000000000000000000000000000000000000000000000000000000000000010000000000000000000000000004000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000000000000000000000000400000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000f85ff85d940000000000000000000000000000000000000011f842a0000000000000000000000000000000000000000000000000000000000000deada0000000000000000000000000000000000000000000000000000000000000beef830100ffb9016f01f9016b018201bcb9010000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000800000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000040000000001000000000000000000000000000000000000000000000000040000000000000000000000000004000000000000000000000000000000000000000000000000000000008000400000000000000000000000000000000000000000000000000000000000000000000000000000040f862f860940000000000000000000000000000000000000022f842a00000000000000000000000000000000000000000000000000000000000005668a0000000000000000000000000000000000000000000000000000000000000977386020f0f0f0608"), }, { - ReceiptsPacket[*ReceiptList69]{1111, []*ReceiptList69{NewReceiptList69(receipts)}}, + ReceiptsPacket[*ReceiptList69]{1111, encodeRL([]*ReceiptList69{NewReceiptList69(receipts)})}, common.FromHex("f8da820457f8d5f8d3f866808082014df85ff85d940000000000000000000000000000000000000011f842a0000000000000000000000000000000000000000000000000000000000000deada0000000000000000000000000000000000000000000000000000000000000beef830100fff86901018201bcf862f860940000000000000000000000000000000000000022f842a00000000000000000000000000000000000000000000000000000000000005668a0000000000000000000000000000000000000000000000000000000000000977386020f0f0f0608"), }, { @@ -248,7 +248,7 @@ func TestMessages(t *testing.T) { common.FromHex("f847820457f842a000000000000000000000000000000000000000000000000000000000deadc0dea000000000000000000000000000000000000000000000000000000000feedbeef"), }, { - PooledTransactionsPacket{1111, PooledTransactionsResponse(txs)}, + PooledTransactionsPacket{1111, encodeRL(txs)}, common.FromHex("f8d7820457f8d2f867088504a817c8088302e2489435353535353535353535353535353535353535358202008025a064b1702d9298fee62dfeccc57d322a463ad55ca201256d01f62b45b2e1c21c12a064b1702d9298fee62dfeccc57d322a463ad55ca201256d01f62b45b2e1c21c10f867098504a817c809830334509435353535353535353535353535353535353535358202d98025a052f8f61201b2b11a78d6e866abc9c3db2ae8631fa656bfe5cb53668255367afba052f8f61201b2b11a78d6e866abc9c3db2ae8631fa656bfe5cb53668255367afb"), }, { diff --git a/eth/protocols/eth/receipt.go b/eth/protocols/eth/receipt.go index 45c4766b17..06956df9f2 100644 --- a/eth/protocols/eth/receipt.go +++ b/eth/protocols/eth/receipt.go @@ -50,8 +50,8 @@ func newReceipt(tr *types.Receipt) Receipt { } // decode68 parses a receipt in the eth/68 network encoding. -func (r *Receipt) decode68(buf *receiptListBuffers, s *rlp.Stream) error { - k, size, err := s.Kind() +func (r *Receipt) decode68(b []byte) error { + k, content, _, err := rlp.Split(b) if err != nil { return err } @@ -59,65 +59,84 @@ func (r *Receipt) decode68(buf *receiptListBuffers, s *rlp.Stream) error { *r = Receipt{} if k == rlp.List { // Legacy receipt. - return r.decodeInnerList(s, false, true) + return r.decodeInnerList(b, false, true) } // Typed receipt. - if size < 2 || size > maxReceiptSize { - return fmt.Errorf("invalid receipt size %d", size) + if len(content) < 2 || len(content) > maxReceiptSize { + return fmt.Errorf("invalid receipt size %d", len(content)) } - buf.tmp.Reset() - buf.tmp.Grow(int(size)) - payload := buf.tmp.Bytes()[:int(size)] - if err := s.ReadBytes(payload); err != nil { - return err - } - r.TxType = payload[0] - s2 := rlp.NewStream(bytes.NewReader(payload[1:]), 0) - return r.decodeInnerList(s2, false, true) + r.TxType = content[0] + return r.decodeInnerList(content[1:], false, true) } // decode69 parses a receipt in the eth/69 network encoding. -func (r *Receipt) decode69(s *rlp.Stream) error { +func (r *Receipt) decode69(b []byte) error { *r = Receipt{} - return r.decodeInnerList(s, true, false) + return r.decodeInnerList(b, true, false) } // decodeDatabase parses a receipt in the basic database encoding. -func (r *Receipt) decodeDatabase(txType byte, s *rlp.Stream) error { +func (r *Receipt) decodeDatabase(txType byte, b []byte) error { *r = Receipt{TxType: txType} - return r.decodeInnerList(s, false, false) + return r.decodeInnerList(b, false, false) } -func (r *Receipt) decodeInnerList(s *rlp.Stream, readTxType, readBloom bool) error { - _, err := s.List() +func (r *Receipt) decodeInnerList(input []byte, readTxType, readBloom bool) error { + input, _, err := rlp.SplitList(input) if err != nil { - return err + return fmt.Errorf("inner list: %v", err) } + + // txType if readTxType { - r.TxType, err = s.Uint8() + var txType uint64 + txType, input, err = rlp.SplitUint64(input) if err != nil { return fmt.Errorf("invalid txType: %w", err) } + if txType > 0x7f { + return fmt.Errorf("invalid txType: too large") + } + r.TxType = byte(txType) } - r.PostStateOrStatus, err = s.Bytes() + + // status + r.PostStateOrStatus, input, err = rlp.SplitString(input) if err != nil { return fmt.Errorf("invalid postStateOrStatus: %w", err) } - r.GasUsed, err = s.Uint64() + if len(r.PostStateOrStatus) > 1 && len(r.PostStateOrStatus) != 32 { + return fmt.Errorf("invalid postStateOrStatus length %d", len(r.PostStateOrStatus)) + } + + // gas + r.GasUsed, input, err = rlp.SplitUint64(input) if err != nil { return fmt.Errorf("invalid gasUsed: %w", err) } + + // bloom if readBloom { - var b types.Bloom - if err := s.ReadBytes(b[:]); err != nil { + var bloomBytes []byte + bloomBytes, input, err = rlp.SplitString(input) + if err != nil { return fmt.Errorf("invalid bloom: %v", err) } + if len(bloomBytes) != types.BloomByteLength { + return fmt.Errorf("invalid bloom length %d", len(bloomBytes)) + } } - r.Logs, err = s.Raw() + + // logs + _, rest, err := rlp.SplitList(input) if err != nil { return fmt.Errorf("invalid logs: %w", err) } - return s.ListEnd() + if len(rest) != 0 { + return fmt.Errorf("junk at end of receipt") + } + r.Logs = input + return nil } // encodeForStorage produces the the storage encoding, i.e. the result matches @@ -223,32 +242,45 @@ func initBuffers(buf **receiptListBuffers) { } // encodeForStorage encodes a list of receipts for the database. -func (buf *receiptListBuffers) encodeForStorage(rs []Receipt) rlp.RawValue { +func (buf *receiptListBuffers) encodeForStorage(rs rlp.RawList[Receipt], decode func([]byte, *Receipt) error) (rlp.RawValue, error) { var out bytes.Buffer w := &buf.enc w.Reset(&out) outer := w.List() - for _, receipts := range rs { - receipts.encodeForStorage(w) + it := rs.ContentIterator() + for it.Next() { + var receipt Receipt + if err := decode(it.Value(), &receipt); err != nil { + return nil, err + } + receipt.encodeForStorage(w) + } + if it.Err() != nil { + return nil, fmt.Errorf("bad list: %v", it.Err()) } w.ListEnd(outer) w.Flush() - return out.Bytes() + return out.Bytes(), nil } // ReceiptList68 is a block receipt list as downloaded by eth/68. // This also implements types.DerivableList for validation purposes. type ReceiptList68 struct { buf *receiptListBuffers - items []Receipt + items rlp.RawList[Receipt] } // NewReceiptList68 creates a receipt list. // This is slow, and exists for testing purposes. func NewReceiptList68(trs []*types.Receipt) *ReceiptList68 { - rl := &ReceiptList68{items: make([]Receipt, len(trs))} - for i, tr := range trs { - rl.items[i] = newReceipt(tr) + rl := new(ReceiptList68) + initBuffers(&rl.buf) + enc := rlp.NewEncoderBuffer(nil) + for _, tr := range trs { + r := newReceipt(tr) + r.encodeForNetwork68(rl.buf, &enc) + rl.items.AppendRaw(enc.ToBytes()) + enc.Reset(nil) } return rl } @@ -266,17 +298,12 @@ func blockReceiptsToNetwork68(blockReceipts, blockBody rlp.RawValue) ([]byte, er buf receiptListBuffers ) blockReceiptIter, _ := rlp.NewListIterator(blockReceipts) - innerReader := bytes.NewReader(nil) - innerStream := rlp.NewStream(innerReader, 0) w := rlp.NewEncoderBuffer(&out) outer := w.List() for i := 0; blockReceiptIter.Next(); i++ { - content := blockReceiptIter.Value() - innerReader.Reset(content) - innerStream.Reset(innerReader, uint64(len(content))) - var r Receipt txType, _ := nextTxType() - if err := r.decodeDatabase(txType, innerStream); err != nil { + var r Receipt + if err := r.decodeDatabase(txType, blockReceiptIter.Value()); err != nil { return nil, fmt.Errorf("invalid database receipt %d: %v", i, err) } r.encodeForNetwork68(&buf, &w) @@ -292,64 +319,51 @@ func (rl *ReceiptList68) setBuffers(buf *receiptListBuffers) { } // EncodeForStorage encodes the receipts for storage into the database. -func (rl *ReceiptList68) EncodeForStorage() rlp.RawValue { +func (rl *ReceiptList68) EncodeForStorage() (rlp.RawValue, error) { initBuffers(&rl.buf) - return rl.buf.encodeForStorage(rl.items) + return rl.buf.encodeForStorage(rl.items, func(data []byte, r *Receipt) error { + return r.decode68(data) + }) } -// Len implements types.DerivableList. -func (rl *ReceiptList68) Len() int { - return len(rl.items) -} - -// EncodeIndex implements types.DerivableList. -func (rl *ReceiptList68) EncodeIndex(i int, out *bytes.Buffer) { +// Derivable turns the receipts into a list that can derive the root hash. +func (rl *ReceiptList68) Derivable() types.DerivableList { initBuffers(&rl.buf) - rl.items[i].encodeForHash(rl.buf, out) + return newDerivableRawList(&rl.items, func(data []byte, outbuf *bytes.Buffer) { + var r Receipt + if r.decode68(data) == nil { + r.encodeForHash(rl.buf, outbuf) + } + }) } // DecodeRLP decodes a list of receipts from the network format. func (rl *ReceiptList68) DecodeRLP(s *rlp.Stream) error { - initBuffers(&rl.buf) - if _, err := s.List(); err != nil { - return err - } - for i := 0; s.MoreDataInList(); i++ { - var item Receipt - err := item.decode68(rl.buf, s) - if err != nil { - return fmt.Errorf("receipt %d: %v", i, err) - } - rl.items = append(rl.items, item) - } - return s.ListEnd() + return rl.items.DecodeRLP(s) } // EncodeRLP encodes the list into the network format of eth/68. -func (rl *ReceiptList68) EncodeRLP(_w io.Writer) error { - initBuffers(&rl.buf) - w := rlp.NewEncoderBuffer(_w) - outer := w.List() - for i := range rl.items { - rl.items[i].encodeForNetwork68(rl.buf, &w) - } - w.ListEnd(outer) - return w.Flush() +func (rl *ReceiptList68) EncodeRLP(w io.Writer) error { + return rl.items.EncodeRLP(w) } // ReceiptList69 is the block receipt list as downloaded by eth/69. // This implements types.DerivableList for validation purposes. type ReceiptList69 struct { buf *receiptListBuffers - items []Receipt + items rlp.RawList[Receipt] } // NewReceiptList69 creates a receipt list. // This is slow, and exists for testing purposes. func NewReceiptList69(trs []*types.Receipt) *ReceiptList69 { - rl := &ReceiptList69{items: make([]Receipt, len(trs))} - for i, tr := range trs { - rl.items[i] = newReceipt(tr) + rl := new(ReceiptList69) + enc := rlp.NewEncoderBuffer(nil) + for _, tr := range trs { + r := newReceipt(tr) + r.encodeForNetwork69(&enc) + rl.items.AppendRaw(enc.ToBytes()) + enc.Reset(nil) } return rl } @@ -360,47 +374,32 @@ func (rl *ReceiptList69) setBuffers(buf *receiptListBuffers) { } // EncodeForStorage encodes the receipts for storage into the database. -func (rl *ReceiptList69) EncodeForStorage() rlp.RawValue { +func (rl *ReceiptList69) EncodeForStorage() (rlp.RawValue, error) { initBuffers(&rl.buf) - return rl.buf.encodeForStorage(rl.items) + return rl.buf.encodeForStorage(rl.items, func(data []byte, r *Receipt) error { + return r.decode69(data) + }) } -// Len implements types.DerivableList. -func (rl *ReceiptList69) Len() int { - return len(rl.items) -} - -// EncodeIndex implements types.DerivableList. -func (rl *ReceiptList69) EncodeIndex(i int, out *bytes.Buffer) { +// Derivable turns the receipts into a list that can derive the root hash. +func (rl *ReceiptList69) Derivable() types.DerivableList { initBuffers(&rl.buf) - rl.items[i].encodeForHash(rl.buf, out) + return newDerivableRawList(&rl.items, func(data []byte, outbuf *bytes.Buffer) { + var r Receipt + if r.decode69(data) == nil { + r.encodeForHash(rl.buf, outbuf) + } + }) } // DecodeRLP decodes a list receipts from the network format. func (rl *ReceiptList69) DecodeRLP(s *rlp.Stream) error { - if _, err := s.List(); err != nil { - return err - } - for i := 0; s.MoreDataInList(); i++ { - var item Receipt - err := item.decode69(s) - if err != nil { - return fmt.Errorf("receipt %d: %v", i, err) - } - rl.items = append(rl.items, item) - } - return s.ListEnd() + return rl.items.DecodeRLP(s) } // EncodeRLP encodes the list into the network format of eth/69. -func (rl *ReceiptList69) EncodeRLP(_w io.Writer) error { - w := rlp.NewEncoderBuffer(_w) - outer := w.List() - for i := range rl.items { - rl.items[i].encodeForNetwork69(&w) - } - w.ListEnd(outer) - return w.Flush() +func (rl *ReceiptList69) EncodeRLP(w io.Writer) error { + return rl.items.EncodeRLP(w) } // blockReceiptsToNetwork69 takes a slice of rlp-encoded receipts, and transactions, diff --git a/eth/protocols/eth/receipt_test.go b/eth/protocols/eth/receipt_test.go index 3c73c07396..39a2728f7f 100644 --- a/eth/protocols/eth/receipt_test.go +++ b/eth/protocols/eth/receipt_test.go @@ -63,6 +63,18 @@ var receiptsTests = []struct { input: []types.ReceiptForStorage{{CumulativeGasUsed: 555, Status: 1, Logs: receiptsTestLogs2}}, txs: []*types.Transaction{types.NewTx(&types.AccessListTx{})}, }, + { + input: []types.ReceiptForStorage{ + {CumulativeGasUsed: 111, PostState: common.HexToHash("0x1111").Bytes(), Logs: receiptsTestLogs1}, + {CumulativeGasUsed: 222, Status: 0, Logs: receiptsTestLogs2}, + {CumulativeGasUsed: 333, Status: 1, Logs: nil}, + }, + txs: []*types.Transaction{ + types.NewTx(&types.LegacyTx{}), + types.NewTx(&types.AccessListTx{}), + types.NewTx(&types.DynamicFeeTx{}), + }, + }, } func init() { @@ -103,7 +115,10 @@ func TestReceiptList69(t *testing.T) { if err := rlp.DecodeBytes(network, &rl); err != nil { t.Fatalf("test[%d]: can't decode network receipts: %v", i, err) } - rlStorageEnc := rl.EncodeForStorage() + rlStorageEnc, err := rl.EncodeForStorage() + if err != nil { + t.Fatalf("test[%d]: error from EncodeForStorage: %v", i, err) + } if !bytes.Equal(rlStorageEnc, canonDB) { t.Fatalf("test[%d]: re-encoded receipts not equal\nhave: %x\nwant: %x", i, rlStorageEnc, canonDB) } @@ -113,7 +128,7 @@ func TestReceiptList69(t *testing.T) { } // compute root hash from ReceiptList69 and compare. - responseHash := types.DeriveSha(&rl, trie.NewStackTrie(nil)) + responseHash := types.DeriveSha(rl.Derivable(), trie.NewStackTrie(nil)) if responseHash != test.root { t.Fatalf("test[%d]: wrong root hash from ReceiptList69\nhave: %v\nwant: %v", i, responseHash, test.root) } @@ -140,7 +155,10 @@ func TestReceiptList68(t *testing.T) { if err := rlp.DecodeBytes(network, &rl); err != nil { t.Fatalf("test[%d]: can't decode network receipts: %v", i, err) } - rlStorageEnc := rl.EncodeForStorage() + rlStorageEnc, err := rl.EncodeForStorage() + if err != nil { + t.Fatalf("test[%d]: error from EncodeForStorage: %v", i, err) + } if !bytes.Equal(rlStorageEnc, canonDB) { t.Fatalf("test[%d]: re-encoded receipts not equal\nhave: %x\nwant: %x", i, rlStorageEnc, canonDB) } @@ -150,7 +168,7 @@ func TestReceiptList68(t *testing.T) { } // compute root hash from ReceiptList68 and compare. - responseHash := types.DeriveSha(&rl, trie.NewStackTrie(nil)) + responseHash := types.DeriveSha(rl.Derivable(), trie.NewStackTrie(nil)) if responseHash != test.root { t.Fatalf("test[%d]: wrong root hash from ReceiptList68\nhave: %v\nwant: %v", i, responseHash, test.root) } diff --git a/eth/protocols/snap/handler.go b/eth/protocols/snap/handler.go index 3249720f90..071a0419fb 100644 --- a/eth/protocols/snap/handler.go +++ b/eth/protocols/snap/handler.go @@ -31,6 +31,8 @@ import ( "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enr" + "github.com/ethereum/go-ethereum/p2p/tracker" + "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/trie/trienode" "github.com/ethereum/go-ethereum/triedb/database" @@ -94,6 +96,7 @@ func MakeProtocols(backend Backend) []p2p.Protocol { Length: protocolLengths[version], Run: func(p *p2p.Peer, rw p2p.MsgReadWriter) error { return backend.RunPeer(NewPeer(version, p, rw), func(peer *Peer) error { + defer peer.Close() return Handle(backend, peer) }) }, @@ -149,7 +152,6 @@ func HandleMessage(backend Backend, peer *Peer) error { // Handle the message depending on its contents switch { case msg.Code == GetAccountRangeMsg: - // Decode the account retrieval request var req GetAccountRangePacket if err := msg.Decode(&req); err != nil { return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) @@ -165,23 +167,40 @@ func HandleMessage(backend Backend, peer *Peer) error { }) case msg.Code == AccountRangeMsg: - // A range of accounts arrived to one of our previous requests - res := new(AccountRangePacket) + res := new(accountRangeInput) if err := msg.Decode(res); err != nil { return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) } + + // Check response validity. + if len := res.Proof.Len(); len > 128 { + return fmt.Errorf("AccountRange: invalid proof (length %d)", len) + } + tresp := tracker.Response{ID: res.ID, MsgCode: AccountRangeMsg, Size: len(res.Accounts.Content())} + if err := peer.tracker.Fulfil(tresp); err != nil { + return err + } + + // Decode. + accounts, err := res.Accounts.Items() + if err != nil { + return fmt.Errorf("AccountRange: invalid accounts list: %v", err) + } + proof, err := res.Proof.Items() + if err != nil { + return fmt.Errorf("AccountRange: invalid proof: %v", err) + } + // Ensure the range is monotonically increasing - for i := 1; i < len(res.Accounts); i++ { - if bytes.Compare(res.Accounts[i-1].Hash[:], res.Accounts[i].Hash[:]) >= 0 { - return fmt.Errorf("accounts not monotonically increasing: #%d [%x] vs #%d [%x]", i-1, res.Accounts[i-1].Hash[:], i, res.Accounts[i].Hash[:]) + for i := 1; i < len(accounts); i++ { + if bytes.Compare(accounts[i-1].Hash[:], accounts[i].Hash[:]) >= 0 { + return fmt.Errorf("accounts not monotonically increasing: #%d [%x] vs #%d [%x]", i-1, accounts[i-1].Hash[:], i, accounts[i].Hash[:]) } } - requestTracker.Fulfil(peer.id, peer.version, AccountRangeMsg, res.ID) - return backend.Handle(peer, res) + return backend.Handle(peer, &AccountRangePacket{res.ID, accounts, proof}) case msg.Code == GetStorageRangesMsg: - // Decode the storage retrieval request var req GetStorageRangesPacket if err := msg.Decode(&req); err != nil { return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) @@ -197,25 +216,42 @@ func HandleMessage(backend Backend, peer *Peer) error { }) case msg.Code == StorageRangesMsg: - // A range of storage slots arrived to one of our previous requests - res := new(StorageRangesPacket) + res := new(storageRangesInput) if err := msg.Decode(res); err != nil { return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) } + + // Check response validity. + if len := res.Proof.Len(); len > 128 { + return fmt.Errorf("StorageRangesMsg: invalid proof (length %d)", len) + } + tresp := tracker.Response{ID: res.ID, MsgCode: StorageRangesMsg, Size: len(res.Slots.Content())} + if err := peer.tracker.Fulfil(tresp); err != nil { + return fmt.Errorf("StorageRangesMsg: %w", err) + } + + // Decode. + slotLists, err := res.Slots.Items() + if err != nil { + return fmt.Errorf("AccountRange: invalid accounts list: %v", err) + } + proof, err := res.Proof.Items() + if err != nil { + return fmt.Errorf("AccountRange: invalid proof: %v", err) + } + // Ensure the ranges are monotonically increasing - for i, slots := range res.Slots { + for i, slots := range slotLists { for j := 1; j < len(slots); j++ { if bytes.Compare(slots[j-1].Hash[:], slots[j].Hash[:]) >= 0 { return fmt.Errorf("storage slots not monotonically increasing for account #%d: #%d [%x] vs #%d [%x]", i, j-1, slots[j-1].Hash[:], j, slots[j].Hash[:]) } } } - requestTracker.Fulfil(peer.id, peer.version, StorageRangesMsg, res.ID) - return backend.Handle(peer, res) + return backend.Handle(peer, &StorageRangesPacket{res.ID, slotLists, proof}) case msg.Code == GetByteCodesMsg: - // Decode bytecode retrieval request var req GetByteCodesPacket if err := msg.Decode(&req); err != nil { return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) @@ -230,17 +266,25 @@ func HandleMessage(backend Backend, peer *Peer) error { }) case msg.Code == ByteCodesMsg: - // A batch of byte codes arrived to one of our previous requests - res := new(ByteCodesPacket) + res := new(byteCodesInput) if err := msg.Decode(res); err != nil { return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) } - requestTracker.Fulfil(peer.id, peer.version, ByteCodesMsg, res.ID) - return backend.Handle(peer, res) + length := res.Codes.Len() + tresp := tracker.Response{ID: res.ID, MsgCode: ByteCodesMsg, Size: length} + if err := peer.tracker.Fulfil(tresp); err != nil { + return fmt.Errorf("ByteCodes: %w", err) + } + + codes, err := res.Codes.Items() + if err != nil { + return fmt.Errorf("ByteCodes: %w", err) + } + + return backend.Handle(peer, &ByteCodesPacket{res.ID, codes}) case msg.Code == GetTrieNodesMsg: - // Decode trie node retrieval request var req GetTrieNodesPacket if err := msg.Decode(&req); err != nil { return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) @@ -257,14 +301,21 @@ func HandleMessage(backend Backend, peer *Peer) error { }) case msg.Code == TrieNodesMsg: - // A batch of trie nodes arrived to one of our previous requests - res := new(TrieNodesPacket) + res := new(trieNodesInput) if err := msg.Decode(res); err != nil { return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) } - requestTracker.Fulfil(peer.id, peer.version, TrieNodesMsg, res.ID) - return backend.Handle(peer, res) + tresp := tracker.Response{ID: res.ID, MsgCode: TrieNodesMsg, Size: res.Nodes.Len()} + if err := peer.tracker.Fulfil(tresp); err != nil { + return fmt.Errorf("TrieNodes: %w", err) + } + nodes, err := res.Nodes.Items() + if err != nil { + return fmt.Errorf("TrieNodes: %w", err) + } + + return backend.Handle(peer, &TrieNodesPacket{res.ID, nodes}) default: return fmt.Errorf("%w: %v", errInvalidMsgCode, msg.Code) @@ -512,21 +563,32 @@ func ServiceGetTrieNodesQuery(chain *core.BlockChain, req *GetTrieNodesPacket, s if reader == nil { reader, _ = triedb.StateReader(req.Root) } + // Retrieve trie nodes until the packet size limit is reached var ( - nodes [][]byte - bytes uint64 - loads int // Trie hash expansions to count database reads + outerIt = req.Paths.ContentIterator() + nodes [][]byte + bytes uint64 + loads int // Trie hash expansions to count database reads ) - for _, pathset := range req.Paths { - switch len(pathset) { + for outerIt.Next() { + innerIt, err := rlp.NewListIterator(outerIt.Value()) + if err != nil { + return nodes, err + } + + switch innerIt.Count() { case 0: // Ensure we penalize invalid requests return nil, fmt.Errorf("%w: zero-item pathset requested", errBadRequest) case 1: // If we're only retrieving an account trie node, fetch it directly - blob, resolved, err := accTrie.GetNode(pathset[0]) + accKey := nextBytes(&innerIt) + if accKey == nil { + return nodes, fmt.Errorf("%w: invalid account node request", errBadRequest) + } + blob, resolved, err := accTrie.GetNode(accKey) loads += resolved // always account database reads, even for failures if err != nil { break @@ -535,33 +597,41 @@ func ServiceGetTrieNodesQuery(chain *core.BlockChain, req *GetTrieNodesPacket, s bytes += uint64(len(blob)) default: - var stRoot common.Hash - // Storage slots requested, open the storage trie and retrieve from there + accKey := nextBytes(&innerIt) + if accKey == nil { + return nodes, fmt.Errorf("%w: invalid account storage request", errBadRequest) + } + var stRoot common.Hash if reader == nil { // We don't have the requested state snapshotted yet (or it is stale), // but can look up the account via the trie instead. - account, err := accTrie.GetAccountByHash(common.BytesToHash(pathset[0])) + account, err := accTrie.GetAccountByHash(common.BytesToHash(accKey)) loads += 8 // We don't know the exact cost of lookup, this is an estimate if err != nil || account == nil { break } stRoot = account.Root } else { - account, err := reader.Account(common.BytesToHash(pathset[0])) + account, err := reader.Account(common.BytesToHash(accKey)) loads++ // always account database reads, even for failures if err != nil || account == nil { break } stRoot = common.BytesToHash(account.Root) } - id := trie.StorageTrieID(req.Root, common.BytesToHash(pathset[0]), stRoot) + + id := trie.StorageTrieID(req.Root, common.BytesToHash(accKey), stRoot) stTrie, err := trie.NewStateTrie(id, triedb) loads++ // always account database reads, even for failures if err != nil { break } - for _, path := range pathset[1:] { + for innerIt.Next() { + path, _, err := rlp.SplitString(innerIt.Value()) + if err != nil { + return nil, fmt.Errorf("%w: invalid storage key: %v", errBadRequest, err) + } blob, resolved, err := stTrie.GetNode(path) loads += resolved // always account database reads, even for failures if err != nil { @@ -584,6 +654,17 @@ func ServiceGetTrieNodesQuery(chain *core.BlockChain, req *GetTrieNodesPacket, s return nodes, nil } +func nextBytes(it *rlp.Iterator) []byte { + if !it.Next() { + return nil + } + content, _, err := rlp.SplitString(it.Value()) + if err != nil { + return nil + } + return content +} + // NodeInfo represents a short summary of the `snap` sub-protocol metadata // known about the host peer. type NodeInfo struct{} diff --git a/eth/protocols/snap/peer.go b/eth/protocols/snap/peer.go index c57931678c..0b96de4158 100644 --- a/eth/protocols/snap/peer.go +++ b/eth/protocols/snap/peer.go @@ -17,9 +17,13 @@ package snap import ( + "time" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/tracker" + "github.com/ethereum/go-ethereum/rlp" ) // Peer is a collection of relevant information we have about a `snap` peer. @@ -29,6 +33,7 @@ type Peer struct { *p2p.Peer // The embedded P2P package peer rw p2p.MsgReadWriter // Input/output streams for snap version uint // Protocol version negotiated + tracker *tracker.Tracker logger log.Logger // Contextual logger with the peer id injected } @@ -36,22 +41,26 @@ type Peer struct { // NewPeer creates a wrapper for a network connection and negotiated protocol // version. func NewPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) *Peer { + cap := p2p.Cap{Name: ProtocolName, Version: version} id := p.ID().String() return &Peer{ id: id, Peer: p, rw: rw, version: version, + tracker: tracker.New(cap, id, 1*time.Minute), logger: log.New("peer", id[:8]), } } // NewFakePeer creates a fake snap peer without a backing p2p peer, for testing purposes. func NewFakePeer(version uint, id string, rw p2p.MsgReadWriter) *Peer { + cap := p2p.Cap{Name: ProtocolName, Version: version} return &Peer{ id: id, rw: rw, version: version, + tracker: tracker.New(cap, id, 1*time.Minute), logger: log.New("peer", id[:8]), } } @@ -71,63 +80,99 @@ func (p *Peer) Log() log.Logger { return p.logger } +// Close releases resources associated with the peer. +func (p *Peer) Close() { + p.tracker.Stop() +} + // RequestAccountRange fetches a batch of accounts rooted in a specific account // trie, starting with the origin. -func (p *Peer) RequestAccountRange(id uint64, root common.Hash, origin, limit common.Hash, bytes uint64) error { +func (p *Peer) RequestAccountRange(id uint64, root common.Hash, origin, limit common.Hash, bytes int) error { p.logger.Trace("Fetching range of accounts", "reqid", id, "root", root, "origin", origin, "limit", limit, "bytes", common.StorageSize(bytes)) - requestTracker.Track(p.id, p.version, GetAccountRangeMsg, AccountRangeMsg, id) + err := p.tracker.Track(tracker.Request{ + ReqCode: GetAccountRangeMsg, + RespCode: AccountRangeMsg, + ID: id, + Size: 2 * bytes, + }) + if err != nil { + return err + } return p2p.Send(p.rw, GetAccountRangeMsg, &GetAccountRangePacket{ ID: id, Root: root, Origin: origin, Limit: limit, - Bytes: bytes, + Bytes: uint64(bytes), }) } // RequestStorageRanges fetches a batch of storage slots belonging to one or more // accounts. If slots from only one account is requested, an origin marker may also // be used to retrieve from there. -func (p *Peer) RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes uint64) error { +func (p *Peer) RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes int) error { if len(accounts) == 1 && origin != nil { p.logger.Trace("Fetching range of large storage slots", "reqid", id, "root", root, "account", accounts[0], "origin", common.BytesToHash(origin), "limit", common.BytesToHash(limit), "bytes", common.StorageSize(bytes)) } else { p.logger.Trace("Fetching ranges of small storage slots", "reqid", id, "root", root, "accounts", len(accounts), "first", accounts[0], "bytes", common.StorageSize(bytes)) } - requestTracker.Track(p.id, p.version, GetStorageRangesMsg, StorageRangesMsg, id) + + p.tracker.Track(tracker.Request{ + ReqCode: GetStorageRangesMsg, + RespCode: StorageRangesMsg, + ID: id, + Size: 2 * bytes, + }) return p2p.Send(p.rw, GetStorageRangesMsg, &GetStorageRangesPacket{ ID: id, Root: root, Accounts: accounts, Origin: origin, Limit: limit, - Bytes: bytes, + Bytes: uint64(bytes), }) } // RequestByteCodes fetches a batch of bytecodes by hash. -func (p *Peer) RequestByteCodes(id uint64, hashes []common.Hash, bytes uint64) error { +func (p *Peer) RequestByteCodes(id uint64, hashes []common.Hash, bytes int) error { p.logger.Trace("Fetching set of byte codes", "reqid", id, "hashes", len(hashes), "bytes", common.StorageSize(bytes)) - requestTracker.Track(p.id, p.version, GetByteCodesMsg, ByteCodesMsg, id) + err := p.tracker.Track(tracker.Request{ + ReqCode: GetByteCodesMsg, + RespCode: ByteCodesMsg, + ID: id, + Size: len(hashes), // ByteCodes is limited by the length of the hash list. + }) + if err != nil { + return err + } return p2p.Send(p.rw, GetByteCodesMsg, &GetByteCodesPacket{ ID: id, Hashes: hashes, - Bytes: bytes, + Bytes: uint64(bytes), }) } // RequestTrieNodes fetches a batch of account or storage trie nodes rooted in -// a specific state trie. -func (p *Peer) RequestTrieNodes(id uint64, root common.Hash, paths []TrieNodePathSet, bytes uint64) error { +// a specific state trie. The `count` is the total count of paths being requested. +func (p *Peer) RequestTrieNodes(id uint64, root common.Hash, count int, paths []TrieNodePathSet, bytes int) error { p.logger.Trace("Fetching set of trie nodes", "reqid", id, "root", root, "pathsets", len(paths), "bytes", common.StorageSize(bytes)) - requestTracker.Track(p.id, p.version, GetTrieNodesMsg, TrieNodesMsg, id) + err := p.tracker.Track(tracker.Request{ + ReqCode: GetTrieNodesMsg, + RespCode: TrieNodesMsg, + ID: id, + Size: count, // TrieNodes is limited by number of items. + }) + if err != nil { + return err + } + encPaths, _ := rlp.EncodeToRawList(paths) return p2p.Send(p.rw, GetTrieNodesMsg, &GetTrieNodesPacket{ ID: id, Root: root, - Paths: paths, - Bytes: bytes, + Paths: encPaths, + Bytes: uint64(bytes), }) } diff --git a/eth/protocols/snap/protocol.go b/eth/protocols/snap/protocol.go index 0db206b081..25fe25822b 100644 --- a/eth/protocols/snap/protocol.go +++ b/eth/protocols/snap/protocol.go @@ -78,6 +78,12 @@ type GetAccountRangePacket struct { Bytes uint64 // Soft limit at which to stop returning data } +type accountRangeInput struct { + ID uint64 // ID of the request this is a response for + Accounts rlp.RawList[*AccountData] // List of consecutive accounts from the trie + Proof rlp.RawList[[]byte] // List of trie nodes proving the account range +} + // AccountRangePacket represents an account query response. type AccountRangePacket struct { ID uint64 // ID of the request this is a response for @@ -123,6 +129,12 @@ type GetStorageRangesPacket struct { Bytes uint64 // Soft limit at which to stop returning data } +type storageRangesInput struct { + ID uint64 // ID of the request this is a response for + Slots rlp.RawList[[]*StorageData] // Lists of consecutive storage slots for the requested accounts + Proof rlp.RawList[[]byte] // Merkle proofs for the *last* slot range, if it's incomplete +} + // StorageRangesPacket represents a storage slot query response. type StorageRangesPacket struct { ID uint64 // ID of the request this is a response for @@ -161,6 +173,11 @@ type GetByteCodesPacket struct { Bytes uint64 // Soft limit at which to stop returning data } +type byteCodesInput struct { + ID uint64 // ID of the request this is a response for + Codes rlp.RawList[[]byte] // Requested contract bytecodes +} + // ByteCodesPacket represents a contract bytecode query response. type ByteCodesPacket struct { ID uint64 // ID of the request this is a response for @@ -169,10 +186,10 @@ type ByteCodesPacket struct { // GetTrieNodesPacket represents a state trie node query. type GetTrieNodesPacket struct { - ID uint64 // Request ID to match up responses with - Root common.Hash // Root hash of the account trie to serve - Paths []TrieNodePathSet // Trie node hashes to retrieve the nodes for - Bytes uint64 // Soft limit at which to stop returning data + ID uint64 // Request ID to match up responses with + Root common.Hash // Root hash of the account trie to serve + Paths rlp.RawList[TrieNodePathSet] // Trie node hashes to retrieve the nodes for + Bytes uint64 // Soft limit at which to stop returning data } // TrieNodePathSet is a list of trie node paths to retrieve. A naive way to @@ -187,6 +204,11 @@ type GetTrieNodesPacket struct { // that a slot is accessed before the account path is fully expanded. type TrieNodePathSet [][]byte +type trieNodesInput struct { + ID uint64 // ID of the request this is a response for + Nodes rlp.RawList[[]byte] // Requested state trie nodes +} + // TrieNodesPacket represents a state trie node query response. type TrieNodesPacket struct { ID uint64 // ID of the request this is a response for diff --git a/eth/protocols/snap/sync.go b/eth/protocols/snap/sync.go index cf4e494645..08e85c896a 100644 --- a/eth/protocols/snap/sync.go +++ b/eth/protocols/snap/sync.go @@ -412,19 +412,19 @@ type SyncPeer interface { // RequestAccountRange fetches a batch of accounts rooted in a specific account // trie, starting with the origin. - RequestAccountRange(id uint64, root, origin, limit common.Hash, bytes uint64) error + RequestAccountRange(id uint64, root, origin, limit common.Hash, bytes int) error // RequestStorageRanges fetches a batch of storage slots belonging to one or // more accounts. If slots from only one account is requested, an origin marker // may also be used to retrieve from there. - RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes uint64) error + RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes int) error // RequestByteCodes fetches a batch of bytecodes by hash. - RequestByteCodes(id uint64, hashes []common.Hash, bytes uint64) error + RequestByteCodes(id uint64, hashes []common.Hash, bytes int) error // RequestTrieNodes fetches a batch of account or storage trie nodes rooted in // a specific state trie. - RequestTrieNodes(id uint64, root common.Hash, paths []TrieNodePathSet, bytes uint64) error + RequestTrieNodes(id uint64, root common.Hash, count int, paths []TrieNodePathSet, bytes int) error // Log retrieves the peer's own contextual logger. Log() log.Logger @@ -1102,7 +1102,7 @@ func (s *Syncer) assignAccountTasks(success chan *accountResponse, fail chan *ac if cap < minRequestSize { // Don't bother with peers below a bare minimum performance cap = minRequestSize } - if err := peer.RequestAccountRange(reqid, root, req.origin, req.limit, uint64(cap)); err != nil { + if err := peer.RequestAccountRange(reqid, root, req.origin, req.limit, cap); err != nil { peer.Log().Debug("Failed to request account range", "err", err) s.scheduleRevertAccountRequest(req) } @@ -1359,7 +1359,7 @@ func (s *Syncer) assignStorageTasks(success chan *storageResponse, fail chan *st if subtask != nil { origin, limit = req.origin[:], req.limit[:] } - if err := peer.RequestStorageRanges(reqid, root, accounts, origin, limit, uint64(cap)); err != nil { + if err := peer.RequestStorageRanges(reqid, root, accounts, origin, limit, cap); err != nil { log.Debug("Failed to request storage", "err", err) s.scheduleRevertStorageRequest(req) } @@ -1492,7 +1492,7 @@ func (s *Syncer) assignTrienodeHealTasks(success chan *trienodeHealResponse, fai defer s.pend.Done() // Attempt to send the remote request and revert if it fails - if err := peer.RequestTrieNodes(reqid, root, pathsets, maxRequestSize); err != nil { + if err := peer.RequestTrieNodes(reqid, root, len(paths), pathsets, maxRequestSize); err != nil { log.Debug("Failed to request trienode healers", "err", err) s.scheduleRevertTrienodeHealRequest(req) } diff --git a/eth/protocols/snap/sync_test.go b/eth/protocols/snap/sync_test.go index ed83af860a..b11ad4e78a 100644 --- a/eth/protocols/snap/sync_test.go +++ b/eth/protocols/snap/sync_test.go @@ -119,10 +119,10 @@ func BenchmarkHashing(b *testing.B) { } type ( - accountHandlerFunc func(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) error - storageHandlerFunc func(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error - trieHandlerFunc func(t *testPeer, requestId uint64, root common.Hash, paths []TrieNodePathSet, cap uint64) error - codeHandlerFunc func(t *testPeer, id uint64, hashes []common.Hash, max uint64) error + accountHandlerFunc func(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap int) error + storageHandlerFunc func(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max int) error + trieHandlerFunc func(t *testPeer, requestId uint64, root common.Hash, paths []TrieNodePathSet, cap int) error + codeHandlerFunc func(t *testPeer, id uint64, hashes []common.Hash, max int) error ) type testPeer struct { @@ -182,21 +182,21 @@ Trienode requests: %d `, t.nAccountRequests, t.nStorageRequests, t.nBytecodeRequests, t.nTrienodeRequests) } -func (t *testPeer) RequestAccountRange(id uint64, root, origin, limit common.Hash, bytes uint64) error { +func (t *testPeer) RequestAccountRange(id uint64, root, origin, limit common.Hash, bytes int) error { t.logger.Trace("Fetching range of accounts", "reqid", id, "root", root, "origin", origin, "limit", limit, "bytes", common.StorageSize(bytes)) t.nAccountRequests++ go t.accountRequestHandler(t, id, root, origin, limit, bytes) return nil } -func (t *testPeer) RequestTrieNodes(id uint64, root common.Hash, paths []TrieNodePathSet, bytes uint64) error { +func (t *testPeer) RequestTrieNodes(id uint64, root common.Hash, count int, paths []TrieNodePathSet, bytes int) error { t.logger.Trace("Fetching set of trie nodes", "reqid", id, "root", root, "pathsets", len(paths), "bytes", common.StorageSize(bytes)) t.nTrienodeRequests++ go t.trieRequestHandler(t, id, root, paths, bytes) return nil } -func (t *testPeer) RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes uint64) error { +func (t *testPeer) RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes int) error { t.nStorageRequests++ if len(accounts) == 1 && origin != nil { t.logger.Trace("Fetching range of large storage slots", "reqid", id, "root", root, "account", accounts[0], "origin", common.BytesToHash(origin), "limit", common.BytesToHash(limit), "bytes", common.StorageSize(bytes)) @@ -207,7 +207,7 @@ func (t *testPeer) RequestStorageRanges(id uint64, root common.Hash, accounts [] return nil } -func (t *testPeer) RequestByteCodes(id uint64, hashes []common.Hash, bytes uint64) error { +func (t *testPeer) RequestByteCodes(id uint64, hashes []common.Hash, bytes int) error { t.nBytecodeRequests++ t.logger.Trace("Fetching set of byte codes", "reqid", id, "hashes", len(hashes), "bytes", common.StorageSize(bytes)) go t.codeRequestHandler(t, id, hashes, bytes) @@ -215,7 +215,7 @@ func (t *testPeer) RequestByteCodes(id uint64, hashes []common.Hash, bytes uint6 } // defaultTrieRequestHandler is a well-behaving handler for trie healing requests -func defaultTrieRequestHandler(t *testPeer, requestId uint64, root common.Hash, paths []TrieNodePathSet, cap uint64) error { +func defaultTrieRequestHandler(t *testPeer, requestId uint64, root common.Hash, paths []TrieNodePathSet, cap int) error { // Pass the response var nodes [][]byte for _, pathset := range paths { @@ -244,7 +244,7 @@ func defaultTrieRequestHandler(t *testPeer, requestId uint64, root common.Hash, } // defaultAccountRequestHandler is a well-behaving handler for AccountRangeRequests -func defaultAccountRequestHandler(t *testPeer, id uint64, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) error { +func defaultAccountRequestHandler(t *testPeer, id uint64, root common.Hash, origin common.Hash, limit common.Hash, cap int) error { keys, vals, proofs := createAccountRequestResponse(t, root, origin, limit, cap) if err := t.remote.OnAccounts(t, id, keys, vals, proofs); err != nil { t.test.Errorf("Remote side rejected our delivery: %v", err) @@ -254,8 +254,8 @@ func defaultAccountRequestHandler(t *testPeer, id uint64, root common.Hash, orig return nil } -func createAccountRequestResponse(t *testPeer, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) (keys []common.Hash, vals [][]byte, proofs [][]byte) { - var size uint64 +func createAccountRequestResponse(t *testPeer, root common.Hash, origin common.Hash, limit common.Hash, cap int) (keys []common.Hash, vals [][]byte, proofs [][]byte) { + var size int if limit == (common.Hash{}) { limit = common.MaxHash } @@ -266,7 +266,7 @@ func createAccountRequestResponse(t *testPeer, root common.Hash, origin common.H if bytes.Compare(origin[:], entry.k) <= 0 { keys = append(keys, common.BytesToHash(entry.k)) vals = append(vals, entry.v) - size += uint64(32 + len(entry.v)) + size += 32 + len(entry.v) } // If we've exceeded the request threshold, abort if bytes.Compare(entry.k, limit[:]) >= 0 { @@ -290,7 +290,7 @@ func createAccountRequestResponse(t *testPeer, root common.Hash, origin common.H } // defaultStorageRequestHandler is a well-behaving storage request handler -func defaultStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, bOrigin, bLimit []byte, max uint64) error { +func defaultStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, bOrigin, bLimit []byte, max int) error { hashes, slots, proofs := createStorageRequestResponse(t, root, accounts, bOrigin, bLimit, max) if err := t.remote.OnStorage(t, requestId, hashes, slots, proofs); err != nil { t.test.Errorf("Remote side rejected our delivery: %v", err) @@ -299,7 +299,7 @@ func defaultStorageRequestHandler(t *testPeer, requestId uint64, root common.Has return nil } -func defaultCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max uint64) error { +func defaultCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max int) error { var bytecodes [][]byte for _, h := range hashes { bytecodes = append(bytecodes, getCodeByHash(h)) @@ -311,8 +311,8 @@ func defaultCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max return nil } -func createStorageRequestResponse(t *testPeer, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) (hashes [][]common.Hash, slots [][][]byte, proofs [][]byte) { - var size uint64 +func createStorageRequestResponse(t *testPeer, root common.Hash, accounts []common.Hash, origin, limit []byte, max int) (hashes [][]common.Hash, slots [][][]byte, proofs [][]byte) { + var size int for _, account := range accounts { // The first account might start from a different origin and end sooner var originHash common.Hash @@ -338,7 +338,7 @@ func createStorageRequestResponse(t *testPeer, root common.Hash, accounts []comm } keys = append(keys, common.BytesToHash(entry.k)) vals = append(vals, entry.v) - size += uint64(32 + len(entry.v)) + size += 32 + len(entry.v) if bytes.Compare(entry.k, limitHash[:]) >= 0 { break } @@ -377,8 +377,8 @@ func createStorageRequestResponse(t *testPeer, root common.Hash, accounts []comm // createStorageRequestResponseAlwaysProve tests a cornercase, where the peer always // supplies the proof for the last account, even if it is 'complete'. -func createStorageRequestResponseAlwaysProve(t *testPeer, root common.Hash, accounts []common.Hash, bOrigin, bLimit []byte, max uint64) (hashes [][]common.Hash, slots [][][]byte, proofs [][]byte) { - var size uint64 +func createStorageRequestResponseAlwaysProve(t *testPeer, root common.Hash, accounts []common.Hash, bOrigin, bLimit []byte, max int) (hashes [][]common.Hash, slots [][][]byte, proofs [][]byte) { + var size int max = max * 3 / 4 var origin common.Hash @@ -395,7 +395,7 @@ func createStorageRequestResponseAlwaysProve(t *testPeer, root common.Hash, acco } keys = append(keys, common.BytesToHash(entry.k)) vals = append(vals, entry.v) - size += uint64(32 + len(entry.v)) + size += 32 + len(entry.v) if size > max { exit = true } @@ -433,34 +433,34 @@ func createStorageRequestResponseAlwaysProve(t *testPeer, root common.Hash, acco } // emptyRequestAccountRangeFn is a rejects AccountRangeRequests -func emptyRequestAccountRangeFn(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) error { +func emptyRequestAccountRangeFn(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap int) error { t.remote.OnAccounts(t, requestId, nil, nil, nil) return nil } -func nonResponsiveRequestAccountRangeFn(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) error { +func nonResponsiveRequestAccountRangeFn(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap int) error { return nil } -func emptyTrieRequestHandler(t *testPeer, requestId uint64, root common.Hash, paths []TrieNodePathSet, cap uint64) error { +func emptyTrieRequestHandler(t *testPeer, requestId uint64, root common.Hash, paths []TrieNodePathSet, cap int) error { t.remote.OnTrieNodes(t, requestId, nil) return nil } -func nonResponsiveTrieRequestHandler(t *testPeer, requestId uint64, root common.Hash, paths []TrieNodePathSet, cap uint64) error { +func nonResponsiveTrieRequestHandler(t *testPeer, requestId uint64, root common.Hash, paths []TrieNodePathSet, cap int) error { return nil } -func emptyStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error { +func emptyStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max int) error { t.remote.OnStorage(t, requestId, nil, nil, nil) return nil } -func nonResponsiveStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error { +func nonResponsiveStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max int) error { return nil } -func proofHappyStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error { +func proofHappyStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max int) error { hashes, slots, proofs := createStorageRequestResponseAlwaysProve(t, root, accounts, origin, limit, max) if err := t.remote.OnStorage(t, requestId, hashes, slots, proofs); err != nil { t.test.Errorf("Remote side rejected our delivery: %v", err) @@ -475,7 +475,7 @@ func proofHappyStorageRequestHandler(t *testPeer, requestId uint64, root common. // return nil //} -func corruptCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max uint64) error { +func corruptCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max int) error { var bytecodes [][]byte for _, h := range hashes { // Send back the hashes @@ -489,7 +489,7 @@ func corruptCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max return nil } -func cappedCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max uint64) error { +func cappedCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max int) error { var bytecodes [][]byte for _, h := range hashes[:1] { bytecodes = append(bytecodes, getCodeByHash(h)) @@ -503,11 +503,11 @@ func cappedCodeRequestHandler(t *testPeer, id uint64, hashes []common.Hash, max } // starvingStorageRequestHandler is somewhat well-behaving storage handler, but it caps the returned results to be very small -func starvingStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error { +func starvingStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max int) error { return defaultStorageRequestHandler(t, requestId, root, accounts, origin, limit, 500) } -func starvingAccountRequestHandler(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) error { +func starvingAccountRequestHandler(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap int) error { return defaultAccountRequestHandler(t, requestId, root, origin, limit, 500) } @@ -515,7 +515,7 @@ func starvingAccountRequestHandler(t *testPeer, requestId uint64, root common.Ha // return defaultAccountRequestHandler(t, requestId-1, root, origin, 500) //} -func corruptAccountRequestHandler(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) error { +func corruptAccountRequestHandler(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap int) error { hashes, accounts, proofs := createAccountRequestResponse(t, root, origin, limit, cap) if len(proofs) > 0 { proofs = proofs[1:] @@ -529,7 +529,7 @@ func corruptAccountRequestHandler(t *testPeer, requestId uint64, root common.Has } // corruptStorageRequestHandler doesn't provide good proofs -func corruptStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error { +func corruptStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max int) error { hashes, slots, proofs := createStorageRequestResponse(t, root, accounts, origin, limit, max) if len(proofs) > 0 { proofs = proofs[1:] @@ -542,7 +542,7 @@ func corruptStorageRequestHandler(t *testPeer, requestId uint64, root common.Has return nil } -func noProofStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error { +func noProofStorageRequestHandler(t *testPeer, requestId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max int) error { hashes, slots, _ := createStorageRequestResponse(t, root, accounts, origin, limit, max) if err := t.remote.OnStorage(t, requestId, hashes, slots, nil); err != nil { t.logger.Info("remote error on delivery (as expected)", "error", err) @@ -577,7 +577,7 @@ func testSyncBloatedProof(t *testing.T, scheme string) { source.accountTrie = sourceAccountTrie.Copy() source.accountValues = elems - source.accountRequestHandler = func(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap uint64) error { + source.accountRequestHandler = func(t *testPeer, requestId uint64, root common.Hash, origin common.Hash, limit common.Hash, cap int) error { var ( keys []common.Hash vals [][]byte @@ -1165,7 +1165,7 @@ func testSyncNoStorageAndOneCodeCappedPeer(t *testing.T, scheme string) { var counter int syncer := setupSyncer( nodeScheme, - mkSource("capped", func(t *testPeer, id uint64, hashes []common.Hash, max uint64) error { + mkSource("capped", func(t *testPeer, id uint64, hashes []common.Hash, max int) error { counter++ return cappedCodeRequestHandler(t, id, hashes, max) }), @@ -1432,7 +1432,7 @@ func testSyncWithUnevenStorage(t *testing.T, scheme string) { source.accountValues = accounts source.setStorageTries(storageTries) source.storageValues = storageElems - source.storageRequestHandler = func(t *testPeer, reqId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max uint64) error { + source.storageRequestHandler = func(t *testPeer, reqId uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, max int) error { return defaultStorageRequestHandler(t, reqId, root, accounts, origin, limit, 128) // retrieve storage in large mode } return source diff --git a/eth/protocols/snap/tracker.go b/eth/protocols/snap/tracker.go deleted file mode 100644 index 2cf59cc23a..0000000000 --- a/eth/protocols/snap/tracker.go +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright 2021 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . - -package snap - -import ( - "time" - - "github.com/ethereum/go-ethereum/p2p/tracker" -) - -// requestTracker is a singleton tracker for request times. -var requestTracker = tracker.New(ProtocolName, time.Minute) diff --git a/p2p/tracker/tracker.go b/p2p/tracker/tracker.go index a1cf6f1119..3a9135aa3b 100644 --- a/p2p/tracker/tracker.go +++ b/p2p/tracker/tracker.go @@ -18,52 +18,63 @@ package tracker import ( "container/list" + "errors" "fmt" "sync" "time" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/metrics" + "github.com/ethereum/go-ethereum/p2p" ) const ( - // trackedGaugeName is the prefix of the per-packet request tracking. trackedGaugeName = "p2p/tracked" - - // lostMeterName is the prefix of the per-packet request expirations. - lostMeterName = "p2p/lost" - - // staleMeterName is the prefix of the per-packet stale responses. - staleMeterName = "p2p/stale" - - // waitHistName is the prefix of the per-packet (req only) waiting time histograms. - waitHistName = "p2p/wait" + lostMeterName = "p2p/lost" + staleMeterName = "p2p/stale" + waitHistName = "p2p/wait" // maxTrackedPackets is a huge number to act as a failsafe on the number of // pending requests the node will track. It should never be hit unless an // attacker figures out a way to spin requests. - maxTrackedPackets = 100000 + maxTrackedPackets = 10000 ) -// request tracks sent network requests which have not yet received a response. -type request struct { - peer string - version uint // Protocol version +var ( + ErrNoMatchingRequest = errors.New("no matching request") + ErrTooManyItems = errors.New("response is larger than request allows") + ErrCollision = errors.New("request ID collision") + ErrCodeMismatch = errors.New("wrong response code for request") + ErrLimitReached = errors.New("request limit reached") + ErrStopped = errors.New("tracker stopped") +) - reqCode uint64 // Protocol message code of the request - resCode uint64 // Protocol message code of the expected response +// Request tracks sent network requests which have not yet received a response. +type Request struct { + ID uint64 // Request ID + Size int // Number/size of requested items + ReqCode uint64 // Protocol message code of the request + RespCode uint64 // Protocol message code of the expected response time time.Time // Timestamp when the request was made expire *list.Element // Expiration marker to untrack it } +type Response struct { + ID uint64 // Request ID of the response + MsgCode uint64 // Protocol message code + Size int // number/size of items in response +} + // Tracker is a pending network request tracker to measure how much time it takes // a remote peer to respond. type Tracker struct { - protocol string // Protocol capability identifier for the metrics - timeout time.Duration // Global timeout after which to drop a tracked packet + cap p2p.Cap // Protocol capability identifier for the metrics - pending map[uint64]*request // Currently pending requests + peer string // Peer ID + timeout time.Duration // Global timeout after which to drop a tracked packet + + pending map[uint64]*Request // Currently pending requests expire *list.List // Linked list tracking the expiration order wake *time.Timer // Timer tracking the expiration of the next item @@ -72,52 +83,53 @@ type Tracker struct { // New creates a new network request tracker to monitor how much time it takes to // fill certain requests and how individual peers perform. -func New(protocol string, timeout time.Duration) *Tracker { +func New(cap p2p.Cap, peerID string, timeout time.Duration) *Tracker { return &Tracker{ - protocol: protocol, - timeout: timeout, - pending: make(map[uint64]*request), - expire: list.New(), + cap: cap, + peer: peerID, + timeout: timeout, + pending: make(map[uint64]*Request), + expire: list.New(), } } // Track adds a network request to the tracker to wait for a response to arrive // or until the request it cancelled or times out. -func (t *Tracker) Track(peer string, version uint, reqCode uint64, resCode uint64, id uint64) { - if !metrics.Enabled() { - return - } +func (t *Tracker) Track(req Request) error { t.lock.Lock() defer t.lock.Unlock() + if t.expire == nil { + return ErrStopped + } + // If there's a duplicate request, we've just random-collided (or more probably, // we have a bug), report it. We could also add a metric, but we're not really // expecting ourselves to be buggy, so a noisy warning should be enough. - if _, ok := t.pending[id]; ok { - log.Error("Network request id collision", "protocol", t.protocol, "version", version, "code", reqCode, "id", id) - return + if _, ok := t.pending[req.ID]; ok { + log.Error("Network request id collision", "cap", t.cap, "code", req.ReqCode, "id", req.ID) + return ErrCollision } // If we have too many pending requests, bail out instead of leaking memory if pending := len(t.pending); pending >= maxTrackedPackets { - log.Error("Request tracker exceeded allowance", "pending", pending, "peer", peer, "protocol", t.protocol, "version", version, "code", reqCode) - return + log.Error("Request tracker exceeded allowance", "pending", pending, "peer", t.peer, "cap", t.cap, "code", req.ReqCode) + return ErrLimitReached } + // Id doesn't exist yet, start tracking it - t.pending[id] = &request{ - peer: peer, - version: version, - reqCode: reqCode, - resCode: resCode, - time: time.Now(), - expire: t.expire.PushBack(id), + req.time = time.Now() + req.expire = t.expire.PushBack(req.ID) + t.pending[req.ID] = &req + + if metrics.Enabled() { + t.trackedGauge(req.ReqCode).Inc(1) } - g := fmt.Sprintf("%s/%s/%d/%#02x", trackedGaugeName, t.protocol, version, reqCode) - metrics.GetOrRegisterGauge(g, nil).Inc(1) // If we've just inserted the first item, start the expiration timer if t.wake == nil { t.wake = time.AfterFunc(t.timeout, t.clean) } + return nil } // clean is called automatically when a preset time passes without a response @@ -142,11 +154,10 @@ func (t *Tracker) clean() { t.expire.Remove(head) delete(t.pending, id) - g := fmt.Sprintf("%s/%s/%d/%#02x", trackedGaugeName, t.protocol, req.version, req.reqCode) - metrics.GetOrRegisterGauge(g, nil).Dec(1) - - m := fmt.Sprintf("%s/%s/%d/%#02x", lostMeterName, t.protocol, req.version, req.reqCode) - metrics.GetOrRegisterMeter(m, nil).Mark(1) + if metrics.Enabled() { + t.trackedGauge(req.ReqCode).Dec(1) + t.lostMeter(req.ReqCode).Mark(1) + } } t.schedule() } @@ -161,46 +172,92 @@ func (t *Tracker) schedule() { t.wake = time.AfterFunc(time.Until(t.pending[t.expire.Front().Value.(uint64)].time.Add(t.timeout)), t.clean) } -// Fulfil fills a pending request, if any is available, reporting on various metrics. -func (t *Tracker) Fulfil(peer string, version uint, code uint64, id uint64) { - if !metrics.Enabled() { - return +// Stop reclaims resources of the tracker. +func (t *Tracker) Stop() { + t.lock.Lock() + defer t.lock.Unlock() + + if t.wake != nil { + t.wake.Stop() + t.wake = nil } + if metrics.Enabled() { + // Ensure metrics are decremented for pending requests. + counts := make(map[uint64]int64) + for _, req := range t.pending { + counts[req.ReqCode]++ + } + for code, count := range counts { + t.trackedGauge(code).Dec(count) + } + } + clear(t.pending) + t.expire = nil +} + +// Fulfil fills a pending request, if any is available, reporting on various metrics. +func (t *Tracker) Fulfil(resp Response) error { t.lock.Lock() defer t.lock.Unlock() // If it's a non existing request, track as stale response - req, ok := t.pending[id] + req, ok := t.pending[resp.ID] if !ok { - m := fmt.Sprintf("%s/%s/%d/%#02x", staleMeterName, t.protocol, version, code) - metrics.GetOrRegisterMeter(m, nil).Mark(1) - return + if metrics.Enabled() { + t.staleMeter(resp.MsgCode).Mark(1) + } + return ErrNoMatchingRequest } + // If the response is funky, it might be some active attack - if req.peer != peer || req.version != version || req.resCode != code { - log.Warn("Network response id collision", - "have", fmt.Sprintf("%s:%s/%d:%d", peer, t.protocol, version, code), - "want", fmt.Sprintf("%s:%s/%d:%d", peer, t.protocol, req.version, req.resCode), + if req.RespCode != resp.MsgCode { + log.Warn("Network response code collision", + "have", fmt.Sprintf("%s:%s/%d:%d", t.peer, t.cap.Name, t.cap.Version, resp.MsgCode), + "want", fmt.Sprintf("%s:%s/%d:%d", t.peer, t.cap.Name, t.cap.Version, req.RespCode), ) - return + return ErrCodeMismatch } + if resp.Size > req.Size { + return ErrTooManyItems + } + // Everything matches, mark the request serviced and meter it wasHead := req.expire.Prev() == nil t.expire.Remove(req.expire) - delete(t.pending, id) + delete(t.pending, req.ID) if wasHead { if t.wake.Stop() { t.schedule() } } - g := fmt.Sprintf("%s/%s/%d/%#02x", trackedGaugeName, t.protocol, req.version, req.reqCode) - metrics.GetOrRegisterGauge(g, nil).Dec(1) - h := fmt.Sprintf("%s/%s/%d/%#02x", waitHistName, t.protocol, req.version, req.reqCode) - sampler := func() metrics.Sample { - return metrics.ResettingSample( - metrics.NewExpDecaySample(1028, 0.015), - ) + // Update request metrics. + if metrics.Enabled() { + t.trackedGauge(req.ReqCode).Dec(1) + t.waitHistogram(req.ReqCode).Update(time.Since(req.time).Microseconds()) } - metrics.GetOrRegisterHistogramLazy(h, nil, sampler).Update(time.Since(req.time).Microseconds()) + return nil +} + +func (t *Tracker) trackedGauge(code uint64) *metrics.Gauge { + name := fmt.Sprintf("%s/%s/%d/%#02x", trackedGaugeName, t.cap.Name, t.cap.Version, code) + return metrics.GetOrRegisterGauge(name, nil) +} + +func (t *Tracker) lostMeter(code uint64) *metrics.Meter { + name := fmt.Sprintf("%s/%s/%d/%#02x", lostMeterName, t.cap.Name, t.cap.Version, code) + return metrics.GetOrRegisterMeter(name, nil) +} + +func (t *Tracker) staleMeter(code uint64) *metrics.Meter { + name := fmt.Sprintf("%s/%s/%d/%#02x", staleMeterName, t.cap.Name, t.cap.Version, code) + return metrics.GetOrRegisterMeter(name, nil) +} + +func (t *Tracker) waitHistogram(code uint64) metrics.Histogram { + name := fmt.Sprintf("%s/%s/%d/%#02x", waitHistName, t.cap.Name, t.cap.Version, code) + sampler := func() metrics.Sample { + return metrics.ResettingSample(metrics.NewExpDecaySample(1028, 0.015)) + } + return metrics.GetOrRegisterHistogramLazy(name, nil, sampler) } diff --git a/p2p/tracker/tracker_test.go b/p2p/tracker/tracker_test.go new file mode 100644 index 0000000000..a37a59f70f --- /dev/null +++ b/p2p/tracker/tracker_test.go @@ -0,0 +1,64 @@ +// Copyright 2026 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package tracker + +import ( + "testing" + "time" + + "github.com/ethereum/go-ethereum/metrics" + "github.com/ethereum/go-ethereum/p2p" +) + +// This checks that metrics gauges for pending requests are be decremented when a +// Tracker is stopped. +func TestMetricsOnStop(t *testing.T) { + metrics.Enable() + + cap := p2p.Cap{Name: "test", Version: 1} + tr := New(cap, "peer1", time.Minute) + + // Track some requests with different ReqCodes. + var id uint64 + for i := 0; i < 3; i++ { + tr.Track(Request{ID: id, ReqCode: 0x01, RespCode: 0x02, Size: 1}) + id++ + } + for i := 0; i < 5; i++ { + tr.Track(Request{ID: id, ReqCode: 0x03, RespCode: 0x04, Size: 1}) + id++ + } + + gauge1 := tr.trackedGauge(0x01) + gauge2 := tr.trackedGauge(0x03) + + if gauge1.Snapshot().Value() != 3 { + t.Fatalf("gauge1 value mismatch: got %d, want 3", gauge1.Snapshot().Value()) + } + if gauge2.Snapshot().Value() != 5 { + t.Fatalf("gauge2 value mismatch: got %d, want 5", gauge2.Snapshot().Value()) + } + + tr.Stop() + + if gauge1.Snapshot().Value() != 0 { + t.Fatalf("gauge1 value after stop: got %d, want 0", gauge1.Snapshot().Value()) + } + if gauge2.Snapshot().Value() != 0 { + t.Fatalf("gauge2 value after stop: got %d, want 0", gauge2.Snapshot().Value()) + } +}