diff --git a/p2p/discover/v5_udp_test.go b/p2p/discover/v5_udp_test.go index 3026dff538..606b35c4f2 100644 --- a/p2p/discover/v5_udp_test.go +++ b/p2p/discover/v5_udp_test.go @@ -181,29 +181,35 @@ func TestUDPv5_handshakeRepeatChallenge(t *testing.T) { nonce1 := v5wire.Nonce{1} nonce2 := v5wire.Nonce{2} nonce3 := v5wire.Nonce{3} - check := func(p *v5wire.Whoareyou, wantNonce v5wire.Nonce) { + var firstAuthTag *v5wire.Nonce + check := func(p *v5wire.Whoareyou, authTag, wantNonce v5wire.Nonce) { t.Helper() if p.Nonce != wantNonce { - t.Error("wrong nonce in WHOAREYOU:", p.Nonce, wantNonce) + t.Error("wrong nonce in WHOAREYOU:", p.Nonce, "want:", wantNonce) + } + if firstAuthTag == nil { + firstAuthTag = &authTag + } else if authTag != *firstAuthTag { + t.Error("wrong auth tag in WHOAREYOU header:", authTag, "want:", *firstAuthTag) } } // Unknown packet from unknown node. test.packetIn(&v5wire.Unknown{Nonce: nonce1}) - test.waitPacketOut(func(p *v5wire.Whoareyou, addr netip.AddrPort, _ v5wire.Nonce) { - check(p, nonce1) + test.waitPacketOut(func(p *v5wire.Whoareyou, addr netip.AddrPort, authTag v5wire.Nonce) { + check(p, authTag, nonce1) }) // Second unknown packet. Here we expect the response to reference the // first unknown packet. test.packetIn(&v5wire.Unknown{Nonce: nonce2}) - test.waitPacketOut(func(p *v5wire.Whoareyou, addr netip.AddrPort, _ v5wire.Nonce) { - check(p, nonce1) + test.waitPacketOut(func(p *v5wire.Whoareyou, addr netip.AddrPort, authTag v5wire.Nonce) { + check(p, authTag, nonce1) }) // Third unknown packet. This should still return the first nonce. test.packetIn(&v5wire.Unknown{Nonce: nonce3}) - test.waitPacketOut(func(p *v5wire.Whoareyou, addr netip.AddrPort, _ v5wire.Nonce) { - check(p, nonce1) + test.waitPacketOut(func(p *v5wire.Whoareyou, addr netip.AddrPort, authTag v5wire.Nonce) { + check(p, authTag, nonce1) }) } @@ -766,20 +772,30 @@ type testCodecFrame struct { } func (c *testCodec) Encode(toID enode.ID, addr string, p v5wire.Packet, _ *v5wire.Whoareyou) ([]byte, v5wire.Nonce, error) { + // To match the behavior of v5wire.Codec, we return the cached encoding of + // WHOAREYOU challenges. + if wp, ok := p.(*v5wire.Whoareyou); ok && len(wp.Encoded) > 0 { + return wp.Encoded, wp.Nonce, nil + } + c.ctr++ var authTag v5wire.Nonce binary.BigEndian.PutUint64(authTag[:], c.ctr) + penc, _ := rlp.EncodeToBytes(p) + frame, err := rlp.EncodeToBytes(testCodecFrame{c.id, authTag, p.Kind(), penc}) + if err != nil { + return frame, authTag, err + } + // Store recently sent challenges. if w, ok := p.(*v5wire.Whoareyou); ok { - // Store recently sent Whoareyou challenges. + w.Nonce = authTag + w.Encoded = frame if c.sentChallenges == nil { c.sentChallenges = make(map[enode.ID]*v5wire.Whoareyou) } c.sentChallenges[toID] = w } - - penc, _ := rlp.EncodeToBytes(p) - frame, err := rlp.EncodeToBytes(testCodecFrame{c.id, authTag, p.Kind(), penc}) return frame, authTag, err } diff --git a/p2p/discover/v5wire/encoding.go b/p2p/discover/v5wire/encoding.go index e50b7cd16d..b16d14eda5 100644 --- a/p2p/discover/v5wire/encoding.go +++ b/p2p/discover/v5wire/encoding.go @@ -189,6 +189,11 @@ func (c *Codec) Encode(id enode.ID, addr string, packet Packet, challenge *Whoar ) switch { case packet.Kind() == WhoareyouPacket: + // just send the WHOAREYOU packet raw again, rather than the re-encoded challenge data + w := packet.(*Whoareyou) + if len(w.Encoded) > 0 { + return w.Encoded, w.Nonce, nil + } head, err = c.encodeWhoareyou(id, packet.(*Whoareyou)) case challenge != nil: // We have an unanswered challenge, send handshake. @@ -218,15 +223,22 @@ func (c *Codec) Encode(id enode.ID, addr string, packet Packet, challenge *Whoar // Store sent WHOAREYOU challenges. if challenge, ok := packet.(*Whoareyou); ok { challenge.ChallengeData = bytesCopy(&c.buf) + enc, err := c.EncodeRaw(id, head, msgData) + if err != nil { + return nil, Nonce{}, err + } + challenge.Encoded = bytes.Clone(enc) c.sc.storeSentHandshake(id, addr, challenge) - } else if msgData == nil { + return enc, head.Nonce, err + } + + if msgData == nil { headerData := c.buf.Bytes() msgData, err = c.encryptMessage(session, packet, &head, headerData) if err != nil { return nil, Nonce{}, err } } - enc, err := c.EncodeRaw(id, head, msgData) return enc, head.Nonce, err } diff --git a/p2p/discover/v5wire/msg.go b/p2p/discover/v5wire/msg.go index 401db2f6c5..089fd4ebdc 100644 --- a/p2p/discover/v5wire/msg.go +++ b/p2p/discover/v5wire/msg.go @@ -73,6 +73,9 @@ type ( Node *enode.Node sent mclock.AbsTime // for handshake GC. + + // Encoded is packet raw data for sending out, but should not be include in the RLP encoding. + Encoded []byte `rlp:"-"` } // PING is sent during liveness checks.