1
0
Fork 0
forked from forks/go-ethereum

p2p/discover: pass node instead of node ID to TALKREQ handler (#31075)

This is for the implementation of Portal Network in the Shisui client.
Their handler needs access to the node object in order to send further
calls to the requesting node. This is a breaking API change but it
should be fine, since there are basically no known users of TALKREQ
outside of Portal network.

---------

Signed-off-by: thinkAfCod <q315xia@163.com>
Co-authored-by: Felix Lange <fjl@twurst.com>
This commit is contained in:
thinkAfCod 2025-04-02 20:56:21 +08:00 committed by GitHub
parent 3e4fbce034
commit d2176f463b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 45 additions and 16 deletions

View file

@ -39,7 +39,7 @@ const talkHandlerLaunchTimeout = 400 * time.Millisecond
// Note that talk handlers are expected to come up with a response very quickly, within at // Note that talk handlers are expected to come up with a response very quickly, within at
// most 200ms or so. If the handler takes longer than that, the remote end may time out // most 200ms or so. If the handler takes longer than that, the remote end may time out
// and wont receive the response. // and wont receive the response.
type TalkRequestHandler func(enode.ID, *net.UDPAddr, []byte) []byte type TalkRequestHandler func(*enode.Node, *net.UDPAddr, []byte) []byte
type talkSystem struct { type talkSystem struct {
transport *UDPv5 transport *UDPv5
@ -72,13 +72,19 @@ func (t *talkSystem) register(protocol string, handler TalkRequestHandler) {
// handleRequest handles a talk request. // handleRequest handles a talk request.
func (t *talkSystem) handleRequest(id enode.ID, addr netip.AddrPort, req *v5wire.TalkRequest) { func (t *talkSystem) handleRequest(id enode.ID, addr netip.AddrPort, req *v5wire.TalkRequest) {
n := t.transport.codec.SessionNode(id, addr.String())
if n == nil {
// The node must be contained in the session here, since we wouldn't have
// received the request otherwise.
panic("missing node in session")
}
t.mutex.Lock() t.mutex.Lock()
handler, ok := t.handlers[req.Protocol] handler, ok := t.handlers[req.Protocol]
t.mutex.Unlock() t.mutex.Unlock()
if !ok { if !ok {
resp := &v5wire.TalkResponse{ReqID: req.ReqID} resp := &v5wire.TalkResponse{ReqID: req.ReqID}
t.transport.sendResponse(id, addr, resp) t.transport.sendResponse(n.ID(), addr, resp)
return return
} }
@ -90,9 +96,9 @@ func (t *talkSystem) handleRequest(id enode.ID, addr netip.AddrPort, req *v5wire
go func() { go func() {
defer func() { t.slots <- struct{}{} }() defer func() { t.slots <- struct{}{} }()
udpAddr := &net.UDPAddr{IP: addr.Addr().AsSlice(), Port: int(addr.Port())} udpAddr := &net.UDPAddr{IP: addr.Addr().AsSlice(), Port: int(addr.Port())}
respMessage := handler(id, udpAddr, req.Message) respMessage := handler(n, udpAddr, req.Message)
resp := &v5wire.TalkResponse{ReqID: req.ReqID, Message: respMessage} resp := &v5wire.TalkResponse{ReqID: req.ReqID, Message: respMessage}
t.transport.sendFromAnotherThread(id, addr, resp) t.transport.sendFromAnotherThread(n.ID(), addr, resp)
}() }()
case <-timeout.C: case <-timeout.C:
// Couldn't get it in time, drop the request. // Couldn't get it in time, drop the request.

View file

@ -64,6 +64,9 @@ type codecV5 interface {
// CurrentChallenge returns the most recent WHOAREYOU challenge that was encoded to given node. // CurrentChallenge returns the most recent WHOAREYOU challenge that was encoded to given node.
// This will return a non-nil value if there is an active handshake attempt with the node, and nil otherwise. // This will return a non-nil value if there is an active handshake attempt with the node, and nil otherwise.
CurrentChallenge(id enode.ID, addr string) *v5wire.Whoareyou CurrentChallenge(id enode.ID, addr string) *v5wire.Whoareyou
// SessionNode returns a node that has completed the handshake.
SessionNode(id enode.ID, addr string) *enode.Node
} }
// UDPv5 is the implementation of protocol version 5. // UDPv5 is the implementation of protocol version 5.

View file

@ -492,7 +492,7 @@ func TestUDPv5_talkHandling(t *testing.T) {
defer test.close() defer test.close()
var recvMessage []byte var recvMessage []byte
test.udp.RegisterTalkHandler("test", func(id enode.ID, addr *net.UDPAddr, message []byte) []byte { test.udp.RegisterTalkHandler("test", func(n *enode.Node, addr *net.UDPAddr, message []byte) []byte {
recvMessage = message recvMessage = message
return []byte("test response") return []byte("test response")
}) })
@ -811,6 +811,10 @@ func (c *testCodec) Decode(input []byte, addr string) (enode.ID, *enode.Node, v5
return frame.NodeID, nil, p, nil return frame.NodeID, nil, p, nil
} }
func (c *testCodec) SessionNode(id enode.ID, addr string) *enode.Node {
return c.test.nodesByID[id].Node()
}
func (c *testCodec) decodeFrame(input []byte) (frame testCodecFrame, p v5wire.Packet, err error) { func (c *testCodec) decodeFrame(input []byte) (frame testCodecFrame, p v5wire.Packet, err error) {
if err = rlp.DecodeBytes(input, &frame); err != nil { if err = rlp.DecodeBytes(input, &frame); err != nil {
return frame, nil, fmt.Errorf("invalid frame: %v", err) return frame, nil, fmt.Errorf("invalid frame: %v", err)

View file

@ -359,7 +359,7 @@ func (c *Codec) encodeHandshakeHeader(toID enode.ID, addr string, challenge *Who
} }
// TODO: this should happen when the first authenticated message is received // TODO: this should happen when the first authenticated message is received
c.sc.storeNewSession(toID, addr, session) c.sc.storeNewSession(toID, addr, session, challenge.Node)
// Encode the auth header. // Encode the auth header.
var ( var (
@ -534,7 +534,7 @@ func (c *Codec) decodeHandshakeMessage(fromAddr string, head *Header, headerData
} }
// Handshake OK, drop the challenge and store the new session keys. // Handshake OK, drop the challenge and store the new session keys.
c.sc.storeNewSession(auth.h.SrcID, fromAddr, session) c.sc.storeNewSession(auth.h.SrcID, fromAddr, session, node)
c.sc.deleteHandshake(auth.h.SrcID, fromAddr) c.sc.deleteHandshake(auth.h.SrcID, fromAddr)
return node, msg, nil return node, msg, nil
} }
@ -656,6 +656,10 @@ func (c *Codec) decryptMessage(input, nonce, headerData, readKey []byte) (Packet
return DecodeMessage(msgdata[0], msgdata[1:]) return DecodeMessage(msgdata[0], msgdata[1:])
} }
func (c *Codec) SessionNode(id enode.ID, addr string) *enode.Node {
return c.sc.readNode(id, addr)
}
// checkValid performs some basic validity checks on the header. // checkValid performs some basic validity checks on the header.
// The packetLen here is the length remaining after the static header. // The packetLen here is the length remaining after the static header.
func (h *StaticHeader) checkValid(packetLen int, protocolID [6]byte) error { func (h *StaticHeader) checkValid(packetLen int, protocolID [6]byte) error {

View file

@ -166,7 +166,7 @@ func TestHandshake_rekey(t *testing.T) {
readKey: []byte("BBBBBBBBBBBBBBBB"), readKey: []byte("BBBBBBBBBBBBBBBB"),
writeKey: []byte("AAAAAAAAAAAAAAAA"), writeKey: []byte("AAAAAAAAAAAAAAAA"),
} }
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), session) net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), session, net.nodeB.n())
// A -> B FINDNODE (encrypted with zero keys) // A -> B FINDNODE (encrypted with zero keys)
findnode, authTag := net.nodeA.encode(t, net.nodeB, &Findnode{}) findnode, authTag := net.nodeA.encode(t, net.nodeB, &Findnode{})
@ -209,8 +209,8 @@ func TestHandshake_rekey2(t *testing.T) {
readKey: []byte("CCCCCCCCCCCCCCCC"), readKey: []byte("CCCCCCCCCCCCCCCC"),
writeKey: []byte("DDDDDDDDDDDDDDDD"), writeKey: []byte("DDDDDDDDDDDDDDDD"),
} }
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), initKeysA) net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), initKeysA, net.nodeB.n())
net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), initKeysB) net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), initKeysB, net.nodeA.n())
// A -> B FINDNODE encrypted with initKeysA // A -> B FINDNODE encrypted with initKeysA
findnode, authTag := net.nodeA.encode(t, net.nodeB, &Findnode{Distances: []uint{3}}) findnode, authTag := net.nodeA.encode(t, net.nodeB, &Findnode{Distances: []uint{3}})
@ -362,8 +362,8 @@ func TestTestVectorsV5(t *testing.T) {
ENRSeq: 2, ENRSeq: 2,
}, },
prep: func(net *handshakeTest) { prep: func(net *handshakeTest) {
net.nodeA.c.sc.storeNewSession(idB, addr, session) net.nodeA.c.sc.storeNewSession(idB, addr, session, net.nodeB.n())
net.nodeB.c.sc.storeNewSession(idA, addr, session.keysFlipped()) net.nodeB.c.sc.storeNewSession(idA, addr, session.keysFlipped(), net.nodeA.n())
}, },
}, },
{ {
@ -499,8 +499,8 @@ func BenchmarkV5_DecodePing(b *testing.B) {
readKey: []byte{233, 203, 93, 195, 86, 47, 177, 186, 227, 43, 2, 141, 244, 230, 120, 17}, readKey: []byte{233, 203, 93, 195, 86, 47, 177, 186, 227, 43, 2, 141, 244, 230, 120, 17},
writeKey: []byte{79, 145, 252, 171, 167, 216, 252, 161, 208, 190, 176, 106, 214, 39, 178, 134}, writeKey: []byte{79, 145, 252, 171, 167, 216, 252, 161, 208, 190, 176, 106, 214, 39, 178, 134},
} }
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), session) net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), session, net.nodeB.n())
net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), session.keysFlipped()) net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), session.keysFlipped(), net.nodeA.n())
addrB := net.nodeA.addr() addrB := net.nodeA.addr()
ping := &Ping{ReqID: []byte("reqid"), ENRSeq: 5} ping := &Ping{ReqID: []byte("reqid"), ENRSeq: 5}
enc, _, err := net.nodeA.c.Encode(net.nodeB.id(), addrB, ping, nil) enc, _, err := net.nodeA.c.Encode(net.nodeB.id(), addrB, ping, nil)

View file

@ -54,11 +54,12 @@ type session struct {
writeKey []byte writeKey []byte
readKey []byte readKey []byte
nonceCounter uint32 nonceCounter uint32
node *enode.Node
} }
// keysFlipped returns a copy of s with the read and write keys flipped. // keysFlipped returns a copy of s with the read and write keys flipped.
func (s *session) keysFlipped() *session { func (s *session) keysFlipped() *session {
return &session{s.readKey, s.writeKey, s.nonceCounter} return &session{s.readKey, s.writeKey, s.nonceCounter, s.node}
} }
func NewSessionCache(maxItems int, clock mclock.Clock) *SessionCache { func NewSessionCache(maxItems int, clock mclock.Clock) *SessionCache {
@ -103,8 +104,19 @@ func (sc *SessionCache) readKey(id enode.ID, addr string) []byte {
return nil return nil
} }
func (sc *SessionCache) readNode(id enode.ID, addr string) *enode.Node {
if s := sc.session(id, addr); s != nil {
return s.node
}
return nil
}
// storeNewSession stores new encryption keys in the cache. // storeNewSession stores new encryption keys in the cache.
func (sc *SessionCache) storeNewSession(id enode.ID, addr string, s *session) { func (sc *SessionCache) storeNewSession(id enode.ID, addr string, s *session, n *enode.Node) {
if n == nil {
panic("nil node in storeNewSession")
}
s.node = n
sc.sessions.Add(sessionID{id, addr}, s) sc.sessions.Add(sessionID{id, addr}, s)
} }