diff --git a/p2p/discover/v5_udp.go b/p2p/discover/v5_udp.go index 9e849751c1..6f7c797152 100644 --- a/p2p/discover/v5_udp.go +++ b/p2p/discover/v5_udp.go @@ -50,11 +50,20 @@ const ( // encoding/decoding and with the handshake; the UDPv5 object handles higher-level concerns. type codecV5 interface { // Encode encodes a packet. - Encode(enode.ID, string, v5wire.Packet, *v5wire.Whoareyou) ([]byte, v5wire.Nonce, error) + // + // If the underlying type of 'p' is *v5wire.Whoareyou, a Whoareyou challenge packet is + // encoded. If the 'challenge' parameter is non-nil, the packet is encoded as a + // handshake message packet. Otherwise, the packet will be encoded as an ordinary + // message packet. + Encode(id enode.ID, addr string, p v5wire.Packet, challenge *v5wire.Whoareyou) ([]byte, v5wire.Nonce, error) // Decode decodes a packet. It returns a *v5wire.Unknown packet if decryption fails. // The *enode.Node return value is non-nil when the input contains a handshake response. - Decode([]byte, string) (enode.ID, *enode.Node, v5wire.Packet, error) + Decode(b []byte, addr string) (enode.ID, *enode.Node, v5wire.Packet, error) + + // CurrentChallenge returns the most recent WHOAREYOU challenge that was encoded to given node. + // This will return a non-nil value if there is an active handshake attempt with the node, and nil otherwise. + CurrentChallenge(id enode.ID, addr string) *v5wire.Whoareyou } // UDPv5 is the implementation of protocol version 5. @@ -824,6 +833,19 @@ func (t *UDPv5) handle(p v5wire.Packet, fromID enode.ID, fromAddr netip.AddrPort // handleUnknown initiates a handshake by responding with WHOAREYOU. func (t *UDPv5) handleUnknown(p *v5wire.Unknown, fromID enode.ID, fromAddr netip.AddrPort) { + currentChallenge := t.codec.CurrentChallenge(fromID, fromAddr.String()) + if currentChallenge != nil { + // This case happens when the sender issues multiple concurrent requests. + // Since we only support one in-progress handshake at a time, we need to tell + // them which handshake attempt they need to complete. We tell them to use the + // existing handshake attempt since the response to that one might still be in + // transit. + t.log.Debug("Repeating discv5 handshake challenge", "id", fromID, "addr", fromAddr) + t.sendResponse(fromID, fromAddr, currentChallenge) + return + } + + // Send a fresh challenge. challenge := &v5wire.Whoareyou{Nonce: p.Nonce} crand.Read(challenge.IDNonce[:]) if n := t.GetNode(fromID); n != nil { diff --git a/p2p/discover/v5_udp_test.go b/p2p/discover/v5_udp_test.go index 371f414760..3026dff538 100644 --- a/p2p/discover/v5_udp_test.go +++ b/p2p/discover/v5_udp_test.go @@ -140,6 +140,26 @@ func TestUDPv5_unknownPacket(t *testing.T) { test.waitPacketOut(func(p *v5wire.Whoareyou, addr netip.AddrPort, _ v5wire.Nonce) { check(p, 0) }) +} + +func TestUDPv5_unknownPacketKnownNode(t *testing.T) { + t.Parallel() + test := newUDPV5Test(t) + defer test.close() + + nonce := v5wire.Nonce{1, 2, 3} + check := func(p *v5wire.Whoareyou, wantSeq uint64) { + t.Helper() + if p.Nonce != nonce { + t.Error("wrong nonce in WHOAREYOU:", p.Nonce, nonce) + } + if p.IDNonce == ([16]byte{}) { + t.Error("all zero ID nonce") + } + if p.RecordSeq != wantSeq { + t.Errorf("wrong record seq %d in WHOAREYOU, want %d", p.RecordSeq, wantSeq) + } + } // Make node known. n := test.getNode(test.remotekey, test.remoteaddr).Node() @@ -151,6 +171,42 @@ func TestUDPv5_unknownPacket(t *testing.T) { }) } +// This test checks that, when multiple 'unknown' packets are received during a handshake, +// the node sticks to the first handshake attempt. +func TestUDPv5_handshakeRepeatChallenge(t *testing.T) { + t.Parallel() + test := newUDPV5Test(t) + defer test.close() + + nonce1 := v5wire.Nonce{1} + nonce2 := v5wire.Nonce{2} + nonce3 := v5wire.Nonce{3} + check := func(p *v5wire.Whoareyou, wantNonce v5wire.Nonce) { + t.Helper() + if p.Nonce != wantNonce { + t.Error("wrong nonce in WHOAREYOU:", p.Nonce, wantNonce) + } + } + + // Unknown packet from unknown node. + test.packetIn(&v5wire.Unknown{Nonce: nonce1}) + test.waitPacketOut(func(p *v5wire.Whoareyou, addr netip.AddrPort, _ v5wire.Nonce) { + check(p, nonce1) + }) + + // Second unknown packet. Here we expect the response to reference the + // first unknown packet. + test.packetIn(&v5wire.Unknown{Nonce: nonce2}) + test.waitPacketOut(func(p *v5wire.Whoareyou, addr netip.AddrPort, _ v5wire.Nonce) { + check(p, nonce1) + }) + // Third unknown packet. This should still return the first nonce. + test.packetIn(&v5wire.Unknown{Nonce: nonce3}) + test.waitPacketOut(func(p *v5wire.Whoareyou, addr netip.AddrPort, _ v5wire.Nonce) { + check(p, nonce1) + }) +} + // This test checks that incoming FINDNODE calls are handled correctly. func TestUDPv5_findnodeHandling(t *testing.T) { t.Parallel() @@ -698,6 +754,8 @@ type testCodec struct { test *udpV5Test id enode.ID ctr uint64 + + sentChallenges map[enode.ID]*v5wire.Whoareyou } type testCodecFrame struct { @@ -712,11 +770,23 @@ func (c *testCodec) Encode(toID enode.ID, addr string, p v5wire.Packet, _ *v5wir var authTag v5wire.Nonce binary.BigEndian.PutUint64(authTag[:], c.ctr) + if w, ok := p.(*v5wire.Whoareyou); ok { + // Store recently sent Whoareyou challenges. + if c.sentChallenges == nil { + c.sentChallenges = make(map[enode.ID]*v5wire.Whoareyou) + } + c.sentChallenges[toID] = w + } + penc, _ := rlp.EncodeToBytes(p) frame, err := rlp.EncodeToBytes(testCodecFrame{c.id, authTag, p.Kind(), penc}) return frame, authTag, err } +func (c *testCodec) CurrentChallenge(id enode.ID, addr string) *v5wire.Whoareyou { + return c.sentChallenges[id] +} + func (c *testCodec) Decode(input []byte, addr string) (enode.ID, *enode.Node, v5wire.Packet, error) { frame, p, err := c.decodeFrame(input) if err != nil { diff --git a/p2p/discover/v5wire/encoding.go b/p2p/discover/v5wire/encoding.go index 904a3ddec6..e50b7cd16d 100644 --- a/p2p/discover/v5wire/encoding.go +++ b/p2p/discover/v5wire/encoding.go @@ -245,6 +245,12 @@ func (c *Codec) EncodeRaw(id enode.ID, head Header, msgdata []byte) ([]byte, err return c.buf.Bytes(), nil } +// CurrentChallenge returns the latest challenge sent to the given node. +// This will return non-nil while a handshake is in progress. +func (c *Codec) CurrentChallenge(id enode.ID, addr string) *Whoareyou { + return c.sc.getHandshake(id, addr) +} + func (c *Codec) writeHeaders(head *Header) { c.buf.Reset() c.buf.Write(head.IV[:])