p2p/discover: pass node instead of node ID to TALKREQ handler (#31075)

This is for the implementation of Portal Network in the Shisui client.
Their handler needs access to the node object in order to send further
calls to the requesting node. This is a breaking API change but it
should be fine, since there are basically no known users of TALKREQ
outside of Portal network.

---------

Signed-off-by: thinkAfCod <q315xia@163.com>
Co-authored-by: Felix Lange <fjl@twurst.com>
This commit is contained in:
thinkAfCod 2025-04-02 20:56:21 +08:00 committed by GitHub
parent 3e4fbce034
commit d2176f463b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 45 additions and 16 deletions

View file

@ -39,7 +39,7 @@ const talkHandlerLaunchTimeout = 400 * time.Millisecond
// Note that talk handlers are expected to come up with a response very quickly, within at
// most 200ms or so. If the handler takes longer than that, the remote end may time out
// and wont receive the response.
type TalkRequestHandler func(enode.ID, *net.UDPAddr, []byte) []byte
type TalkRequestHandler func(*enode.Node, *net.UDPAddr, []byte) []byte
type talkSystem struct {
transport *UDPv5
@ -72,13 +72,19 @@ func (t *talkSystem) register(protocol string, handler TalkRequestHandler) {
// handleRequest handles a talk request.
func (t *talkSystem) handleRequest(id enode.ID, addr netip.AddrPort, req *v5wire.TalkRequest) {
n := t.transport.codec.SessionNode(id, addr.String())
if n == nil {
// The node must be contained in the session here, since we wouldn't have
// received the request otherwise.
panic("missing node in session")
}
t.mutex.Lock()
handler, ok := t.handlers[req.Protocol]
t.mutex.Unlock()
if !ok {
resp := &v5wire.TalkResponse{ReqID: req.ReqID}
t.transport.sendResponse(id, addr, resp)
t.transport.sendResponse(n.ID(), addr, resp)
return
}
@ -90,9 +96,9 @@ func (t *talkSystem) handleRequest(id enode.ID, addr netip.AddrPort, req *v5wire
go func() {
defer func() { t.slots <- struct{}{} }()
udpAddr := &net.UDPAddr{IP: addr.Addr().AsSlice(), Port: int(addr.Port())}
respMessage := handler(id, udpAddr, req.Message)
respMessage := handler(n, udpAddr, req.Message)
resp := &v5wire.TalkResponse{ReqID: req.ReqID, Message: respMessage}
t.transport.sendFromAnotherThread(id, addr, resp)
t.transport.sendFromAnotherThread(n.ID(), addr, resp)
}()
case <-timeout.C:
// Couldn't get it in time, drop the request.

View file

@ -64,6 +64,9 @@ type codecV5 interface {
// CurrentChallenge returns the most recent WHOAREYOU challenge that was encoded to given node.
// This will return a non-nil value if there is an active handshake attempt with the node, and nil otherwise.
CurrentChallenge(id enode.ID, addr string) *v5wire.Whoareyou
// SessionNode returns a node that has completed the handshake.
SessionNode(id enode.ID, addr string) *enode.Node
}
// UDPv5 is the implementation of protocol version 5.

View file

@ -492,7 +492,7 @@ func TestUDPv5_talkHandling(t *testing.T) {
defer test.close()
var recvMessage []byte
test.udp.RegisterTalkHandler("test", func(id enode.ID, addr *net.UDPAddr, message []byte) []byte {
test.udp.RegisterTalkHandler("test", func(n *enode.Node, addr *net.UDPAddr, message []byte) []byte {
recvMessage = message
return []byte("test response")
})
@ -811,6 +811,10 @@ func (c *testCodec) Decode(input []byte, addr string) (enode.ID, *enode.Node, v5
return frame.NodeID, nil, p, nil
}
func (c *testCodec) SessionNode(id enode.ID, addr string) *enode.Node {
return c.test.nodesByID[id].Node()
}
func (c *testCodec) decodeFrame(input []byte) (frame testCodecFrame, p v5wire.Packet, err error) {
if err = rlp.DecodeBytes(input, &frame); err != nil {
return frame, nil, fmt.Errorf("invalid frame: %v", err)

View file

@ -359,7 +359,7 @@ func (c *Codec) encodeHandshakeHeader(toID enode.ID, addr string, challenge *Who
}
// TODO: this should happen when the first authenticated message is received
c.sc.storeNewSession(toID, addr, session)
c.sc.storeNewSession(toID, addr, session, challenge.Node)
// Encode the auth header.
var (
@ -534,7 +534,7 @@ func (c *Codec) decodeHandshakeMessage(fromAddr string, head *Header, headerData
}
// Handshake OK, drop the challenge and store the new session keys.
c.sc.storeNewSession(auth.h.SrcID, fromAddr, session)
c.sc.storeNewSession(auth.h.SrcID, fromAddr, session, node)
c.sc.deleteHandshake(auth.h.SrcID, fromAddr)
return node, msg, nil
}
@ -656,6 +656,10 @@ func (c *Codec) decryptMessage(input, nonce, headerData, readKey []byte) (Packet
return DecodeMessage(msgdata[0], msgdata[1:])
}
func (c *Codec) SessionNode(id enode.ID, addr string) *enode.Node {
return c.sc.readNode(id, addr)
}
// checkValid performs some basic validity checks on the header.
// The packetLen here is the length remaining after the static header.
func (h *StaticHeader) checkValid(packetLen int, protocolID [6]byte) error {

View file

@ -166,7 +166,7 @@ func TestHandshake_rekey(t *testing.T) {
readKey: []byte("BBBBBBBBBBBBBBBB"),
writeKey: []byte("AAAAAAAAAAAAAAAA"),
}
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), session)
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), session, net.nodeB.n())
// A -> B FINDNODE (encrypted with zero keys)
findnode, authTag := net.nodeA.encode(t, net.nodeB, &Findnode{})
@ -209,8 +209,8 @@ func TestHandshake_rekey2(t *testing.T) {
readKey: []byte("CCCCCCCCCCCCCCCC"),
writeKey: []byte("DDDDDDDDDDDDDDDD"),
}
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), initKeysA)
net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), initKeysB)
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), initKeysA, net.nodeB.n())
net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), initKeysB, net.nodeA.n())
// A -> B FINDNODE encrypted with initKeysA
findnode, authTag := net.nodeA.encode(t, net.nodeB, &Findnode{Distances: []uint{3}})
@ -362,8 +362,8 @@ func TestTestVectorsV5(t *testing.T) {
ENRSeq: 2,
},
prep: func(net *handshakeTest) {
net.nodeA.c.sc.storeNewSession(idB, addr, session)
net.nodeB.c.sc.storeNewSession(idA, addr, session.keysFlipped())
net.nodeA.c.sc.storeNewSession(idB, addr, session, net.nodeB.n())
net.nodeB.c.sc.storeNewSession(idA, addr, session.keysFlipped(), net.nodeA.n())
},
},
{
@ -499,8 +499,8 @@ func BenchmarkV5_DecodePing(b *testing.B) {
readKey: []byte{233, 203, 93, 195, 86, 47, 177, 186, 227, 43, 2, 141, 244, 230, 120, 17},
writeKey: []byte{79, 145, 252, 171, 167, 216, 252, 161, 208, 190, 176, 106, 214, 39, 178, 134},
}
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), session)
net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), session.keysFlipped())
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), session, net.nodeB.n())
net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), session.keysFlipped(), net.nodeA.n())
addrB := net.nodeA.addr()
ping := &Ping{ReqID: []byte("reqid"), ENRSeq: 5}
enc, _, err := net.nodeA.c.Encode(net.nodeB.id(), addrB, ping, nil)

View file

@ -54,11 +54,12 @@ type session struct {
writeKey []byte
readKey []byte
nonceCounter uint32
node *enode.Node
}
// keysFlipped returns a copy of s with the read and write keys flipped.
func (s *session) keysFlipped() *session {
return &session{s.readKey, s.writeKey, s.nonceCounter}
return &session{s.readKey, s.writeKey, s.nonceCounter, s.node}
}
func NewSessionCache(maxItems int, clock mclock.Clock) *SessionCache {
@ -103,8 +104,19 @@ func (sc *SessionCache) readKey(id enode.ID, addr string) []byte {
return nil
}
func (sc *SessionCache) readNode(id enode.ID, addr string) *enode.Node {
if s := sc.session(id, addr); s != nil {
return s.node
}
return nil
}
// storeNewSession stores new encryption keys in the cache.
func (sc *SessionCache) storeNewSession(id enode.ID, addr string, s *session) {
func (sc *SessionCache) storeNewSession(id enode.ID, addr string, s *session, n *enode.Node) {
if n == nil {
panic("nil node in storeNewSession")
}
s.node = n
sc.sessions.Add(sessionID{id, addr}, s)
}