diff --git a/cmd/devp2p/internal/ethtest/protocol.go b/cmd/devp2p/internal/ethtest/protocol.go index a21d1ca7a1..af76082318 100644 --- a/cmd/devp2p/internal/ethtest/protocol.go +++ b/cmd/devp2p/internal/ethtest/protocol.go @@ -86,3 +86,9 @@ func protoOffset(proto Proto) uint64 { panic("unhandled protocol") } } + +// msgTypePtr is the constraint for protocol message types. +type msgTypePtr[U any] interface { + *U + Kind() byte +} diff --git a/cmd/devp2p/internal/ethtest/suite.go b/cmd/devp2p/internal/ethtest/suite.go index 47327b6844..c23360bf82 100644 --- a/cmd/devp2p/internal/ethtest/suite.go +++ b/cmd/devp2p/internal/ethtest/suite.go @@ -196,6 +196,7 @@ to check if the node disconnects after receiving multiple invalid requests.`) func (s *Suite) TestSimultaneousRequests(t *utesting.T) { t.Log(`This test requests blocks headers from the node, performing two requests concurrently, with different request IDs.`) + conn, err := s.dialAndPeer(nil) if err != nil { t.Fatalf("peering failed: %v", err) @@ -235,37 +236,36 @@ concurrently, with different request IDs.`) } // Wait for responses. - headers1 := new(eth.BlockHeadersPacket) - if err := conn.ReadMsg(ethProto, eth.BlockHeadersMsg, &headers1); err != nil { - t.Fatalf("error reading block headers msg: %v", err) - } - if got, want := headers1.RequestId, req1.RequestId; got != want { - t.Fatalf("unexpected request id in response: got %d, want %d", got, want) - } - headers2 := new(eth.BlockHeadersPacket) - if err := conn.ReadMsg(ethProto, eth.BlockHeadersMsg, &headers2); err != nil { - t.Fatalf("error reading block headers msg: %v", err) - } - if got, want := headers2.RequestId, req2.RequestId; got != want { - t.Fatalf("unexpected request id in response: got %d, want %d", got, want) + // Note they can arrive in either order. + resp, err := collectResponses(conn, 2, func(msg *eth.BlockHeadersPacket) uint64 { + if msg.RequestId != 111 && msg.RequestId != 222 { + t.Fatalf("response with unknown request ID: %v", msg.RequestId) + } + return msg.RequestId + }) + if err != nil { + t.Fatal(err) } - // Check received headers for accuracy. + // Check if headers match. + resp1 := resp[111] if expected, err := s.chain.GetHeaders(req1); err != nil { t.Fatalf("failed to get expected headers for request 1: %v", err) - } else if !headersMatch(expected, headers1.BlockHeadersRequest) { - t.Fatalf("header mismatch: \nexpected %v \ngot %v", expected, headers1) + } else if !headersMatch(expected, resp1.BlockHeadersRequest) { + t.Fatalf("header mismatch for request ID %v: \nexpected %v \ngot %v", 111, expected, resp1) } + resp2 := resp[222] if expected, err := s.chain.GetHeaders(req2); err != nil { t.Fatalf("failed to get expected headers for request 2: %v", err) - } else if !headersMatch(expected, headers2.BlockHeadersRequest) { - t.Fatalf("header mismatch: \nexpected %v \ngot %v", expected, headers2) + } else if !headersMatch(expected, resp2.BlockHeadersRequest) { + t.Fatalf("header mismatch for request ID %v: \nexpected %v \ngot %v", 222, expected, resp2) } } func (s *Suite) TestSameRequestID(t *utesting.T) { t.Log(`This test requests block headers, performing two concurrent requests with the same request ID. The node should handle the request by responding to both requests.`) + conn, err := s.dialAndPeer(nil) if err != nil { t.Fatalf("peering failed: %v", err) @@ -289,7 +289,7 @@ same request ID. The node should handle the request by responding to both reques Origin: eth.HashOrNumber{ Number: 33, }, - Amount: 2, + Amount: 3, }, } @@ -301,35 +301,52 @@ same request ID. The node should handle the request by responding to both reques t.Fatalf("failed to write to connection: %v", err) } - // Wait for the responses. - headers1 := new(eth.BlockHeadersPacket) - if err := conn.ReadMsg(ethProto, eth.BlockHeadersMsg, &headers1); err != nil { - t.Fatalf("error reading from connection: %v", err) - } - if got, want := headers1.RequestId, request1.RequestId; got != want { - t.Fatalf("unexpected request id: got %d, want %d", got, want) - } - headers2 := new(eth.BlockHeadersPacket) - if err := conn.ReadMsg(ethProto, eth.BlockHeadersMsg, &headers2); err != nil { - t.Fatalf("error reading from connection: %v", err) - } - if got, want := headers2.RequestId, request2.RequestId; got != want { - t.Fatalf("unexpected request id: got %d, want %d", got, want) + // Wait for the responses. They can arrive in either order, and we can't tell them + // apart by their request ID, so use the number of headers instead. + resp, err := collectResponses(conn, 2, func(msg *eth.BlockHeadersPacket) uint64 { + id := uint64(len(msg.BlockHeadersRequest)) + if id != 2 && id != 3 { + t.Fatalf("invalid number of headers in response: %d", id) + } + return id + }) + if err != nil { + t.Fatal(err) } // Check if headers match. + resp1 := resp[2] if expected, err := s.chain.GetHeaders(request1); err != nil { - t.Fatalf("failed to get expected block headers: %v", err) - } else if !headersMatch(expected, headers1.BlockHeadersRequest) { - t.Fatalf("header mismatch: \nexpected %v \ngot %v", expected, headers1) + t.Fatalf("failed to get expected headers for request 1: %v", err) + } else if !headersMatch(expected, resp1.BlockHeadersRequest) { + t.Fatalf("headers mismatch: \nexpected %v \ngot %v", expected, resp1) } + resp2 := resp[3] if expected, err := s.chain.GetHeaders(request2); err != nil { - t.Fatalf("failed to get expected block headers: %v", err) - } else if !headersMatch(expected, headers2.BlockHeadersRequest) { - t.Fatalf("header mismatch: \nexpected %v \ngot %v", expected, headers2) + t.Fatalf("failed to get expected headers for request 2: %v", err) + } else if !headersMatch(expected, resp2.BlockHeadersRequest) { + t.Fatalf("headers mismatch: \nexpected %v \ngot %v", expected, resp2) } } +// collectResponses waits for n messages of type T on the given connection. +// The messsages are collected according to the 'identity' function. +func collectResponses[T any, P msgTypePtr[T]](conn *Conn, n int, identity func(P) uint64) (map[uint64]P, error) { + resp := make(map[uint64]P, n) + for range n { + r := new(T) + if err := conn.ReadMsg(ethProto, eth.BlockHeadersMsg, r); err != nil { + return resp, fmt.Errorf("read error: %v", err) + } + id := identity(r) + if resp[id] != nil { + return resp, fmt.Errorf("duplicate response %v", r) + } + resp[id] = r + } + return resp, nil +} + func (s *Suite) TestZeroRequestID(t *utesting.T) { t.Log(`This test sends a GetBlockHeaders message with a request-id of zero, and expects a response.`) @@ -887,7 +904,7 @@ func (s *Suite) makeBlobTxs(count, blobs int, discriminator byte) (txs types.Tra from, nonce := s.chain.GetSender(5) for i := 0; i < count; i++ { // Make blob data, max of 2 blobs per tx. - blobdata := make([]byte, blobs%3) + blobdata := make([]byte, min(blobs, 2)) for i := range blobdata { blobdata[i] = discriminator blobs -= 1