forked from forks/go-ethereum
p2p/discover: pass node instead of node ID to TALKREQ handler (#31075)
This is for the implementation of Portal Network in the Shisui client. Their handler needs access to the node object in order to send further calls to the requesting node. This is a breaking API change but it should be fine, since there are basically no known users of TALKREQ outside of Portal network. --------- Signed-off-by: thinkAfCod <q315xia@163.com> Co-authored-by: Felix Lange <fjl@twurst.com>
This commit is contained in:
parent
3e4fbce034
commit
d2176f463b
6 changed files with 45 additions and 16 deletions
|
|
@ -39,7 +39,7 @@ const talkHandlerLaunchTimeout = 400 * time.Millisecond
|
|||
// Note that talk handlers are expected to come up with a response very quickly, within at
|
||||
// most 200ms or so. If the handler takes longer than that, the remote end may time out
|
||||
// and wont receive the response.
|
||||
type TalkRequestHandler func(enode.ID, *net.UDPAddr, []byte) []byte
|
||||
type TalkRequestHandler func(*enode.Node, *net.UDPAddr, []byte) []byte
|
||||
|
||||
type talkSystem struct {
|
||||
transport *UDPv5
|
||||
|
|
@ -72,13 +72,19 @@ func (t *talkSystem) register(protocol string, handler TalkRequestHandler) {
|
|||
|
||||
// handleRequest handles a talk request.
|
||||
func (t *talkSystem) handleRequest(id enode.ID, addr netip.AddrPort, req *v5wire.TalkRequest) {
|
||||
n := t.transport.codec.SessionNode(id, addr.String())
|
||||
if n == nil {
|
||||
// The node must be contained in the session here, since we wouldn't have
|
||||
// received the request otherwise.
|
||||
panic("missing node in session")
|
||||
}
|
||||
t.mutex.Lock()
|
||||
handler, ok := t.handlers[req.Protocol]
|
||||
t.mutex.Unlock()
|
||||
|
||||
if !ok {
|
||||
resp := &v5wire.TalkResponse{ReqID: req.ReqID}
|
||||
t.transport.sendResponse(id, addr, resp)
|
||||
t.transport.sendResponse(n.ID(), addr, resp)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -90,9 +96,9 @@ func (t *talkSystem) handleRequest(id enode.ID, addr netip.AddrPort, req *v5wire
|
|||
go func() {
|
||||
defer func() { t.slots <- struct{}{} }()
|
||||
udpAddr := &net.UDPAddr{IP: addr.Addr().AsSlice(), Port: int(addr.Port())}
|
||||
respMessage := handler(id, udpAddr, req.Message)
|
||||
respMessage := handler(n, udpAddr, req.Message)
|
||||
resp := &v5wire.TalkResponse{ReqID: req.ReqID, Message: respMessage}
|
||||
t.transport.sendFromAnotherThread(id, addr, resp)
|
||||
t.transport.sendFromAnotherThread(n.ID(), addr, resp)
|
||||
}()
|
||||
case <-timeout.C:
|
||||
// Couldn't get it in time, drop the request.
|
||||
|
|
|
|||
|
|
@ -64,6 +64,9 @@ type codecV5 interface {
|
|||
// CurrentChallenge returns the most recent WHOAREYOU challenge that was encoded to given node.
|
||||
// This will return a non-nil value if there is an active handshake attempt with the node, and nil otherwise.
|
||||
CurrentChallenge(id enode.ID, addr string) *v5wire.Whoareyou
|
||||
|
||||
// SessionNode returns a node that has completed the handshake.
|
||||
SessionNode(id enode.ID, addr string) *enode.Node
|
||||
}
|
||||
|
||||
// UDPv5 is the implementation of protocol version 5.
|
||||
|
|
|
|||
|
|
@ -492,7 +492,7 @@ func TestUDPv5_talkHandling(t *testing.T) {
|
|||
defer test.close()
|
||||
|
||||
var recvMessage []byte
|
||||
test.udp.RegisterTalkHandler("test", func(id enode.ID, addr *net.UDPAddr, message []byte) []byte {
|
||||
test.udp.RegisterTalkHandler("test", func(n *enode.Node, addr *net.UDPAddr, message []byte) []byte {
|
||||
recvMessage = message
|
||||
return []byte("test response")
|
||||
})
|
||||
|
|
@ -811,6 +811,10 @@ func (c *testCodec) Decode(input []byte, addr string) (enode.ID, *enode.Node, v5
|
|||
return frame.NodeID, nil, p, nil
|
||||
}
|
||||
|
||||
func (c *testCodec) SessionNode(id enode.ID, addr string) *enode.Node {
|
||||
return c.test.nodesByID[id].Node()
|
||||
}
|
||||
|
||||
func (c *testCodec) decodeFrame(input []byte) (frame testCodecFrame, p v5wire.Packet, err error) {
|
||||
if err = rlp.DecodeBytes(input, &frame); err != nil {
|
||||
return frame, nil, fmt.Errorf("invalid frame: %v", err)
|
||||
|
|
|
|||
|
|
@ -359,7 +359,7 @@ func (c *Codec) encodeHandshakeHeader(toID enode.ID, addr string, challenge *Who
|
|||
}
|
||||
|
||||
// TODO: this should happen when the first authenticated message is received
|
||||
c.sc.storeNewSession(toID, addr, session)
|
||||
c.sc.storeNewSession(toID, addr, session, challenge.Node)
|
||||
|
||||
// Encode the auth header.
|
||||
var (
|
||||
|
|
@ -534,7 +534,7 @@ func (c *Codec) decodeHandshakeMessage(fromAddr string, head *Header, headerData
|
|||
}
|
||||
|
||||
// Handshake OK, drop the challenge and store the new session keys.
|
||||
c.sc.storeNewSession(auth.h.SrcID, fromAddr, session)
|
||||
c.sc.storeNewSession(auth.h.SrcID, fromAddr, session, node)
|
||||
c.sc.deleteHandshake(auth.h.SrcID, fromAddr)
|
||||
return node, msg, nil
|
||||
}
|
||||
|
|
@ -656,6 +656,10 @@ func (c *Codec) decryptMessage(input, nonce, headerData, readKey []byte) (Packet
|
|||
return DecodeMessage(msgdata[0], msgdata[1:])
|
||||
}
|
||||
|
||||
func (c *Codec) SessionNode(id enode.ID, addr string) *enode.Node {
|
||||
return c.sc.readNode(id, addr)
|
||||
}
|
||||
|
||||
// checkValid performs some basic validity checks on the header.
|
||||
// The packetLen here is the length remaining after the static header.
|
||||
func (h *StaticHeader) checkValid(packetLen int, protocolID [6]byte) error {
|
||||
|
|
|
|||
|
|
@ -166,7 +166,7 @@ func TestHandshake_rekey(t *testing.T) {
|
|||
readKey: []byte("BBBBBBBBBBBBBBBB"),
|
||||
writeKey: []byte("AAAAAAAAAAAAAAAA"),
|
||||
}
|
||||
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), session)
|
||||
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), session, net.nodeB.n())
|
||||
|
||||
// A -> B FINDNODE (encrypted with zero keys)
|
||||
findnode, authTag := net.nodeA.encode(t, net.nodeB, &Findnode{})
|
||||
|
|
@ -209,8 +209,8 @@ func TestHandshake_rekey2(t *testing.T) {
|
|||
readKey: []byte("CCCCCCCCCCCCCCCC"),
|
||||
writeKey: []byte("DDDDDDDDDDDDDDDD"),
|
||||
}
|
||||
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), initKeysA)
|
||||
net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), initKeysB)
|
||||
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), initKeysA, net.nodeB.n())
|
||||
net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), initKeysB, net.nodeA.n())
|
||||
|
||||
// A -> B FINDNODE encrypted with initKeysA
|
||||
findnode, authTag := net.nodeA.encode(t, net.nodeB, &Findnode{Distances: []uint{3}})
|
||||
|
|
@ -362,8 +362,8 @@ func TestTestVectorsV5(t *testing.T) {
|
|||
ENRSeq: 2,
|
||||
},
|
||||
prep: func(net *handshakeTest) {
|
||||
net.nodeA.c.sc.storeNewSession(idB, addr, session)
|
||||
net.nodeB.c.sc.storeNewSession(idA, addr, session.keysFlipped())
|
||||
net.nodeA.c.sc.storeNewSession(idB, addr, session, net.nodeB.n())
|
||||
net.nodeB.c.sc.storeNewSession(idA, addr, session.keysFlipped(), net.nodeA.n())
|
||||
},
|
||||
},
|
||||
{
|
||||
|
|
@ -499,8 +499,8 @@ func BenchmarkV5_DecodePing(b *testing.B) {
|
|||
readKey: []byte{233, 203, 93, 195, 86, 47, 177, 186, 227, 43, 2, 141, 244, 230, 120, 17},
|
||||
writeKey: []byte{79, 145, 252, 171, 167, 216, 252, 161, 208, 190, 176, 106, 214, 39, 178, 134},
|
||||
}
|
||||
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), session)
|
||||
net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), session.keysFlipped())
|
||||
net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), session, net.nodeB.n())
|
||||
net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), session.keysFlipped(), net.nodeA.n())
|
||||
addrB := net.nodeA.addr()
|
||||
ping := &Ping{ReqID: []byte("reqid"), ENRSeq: 5}
|
||||
enc, _, err := net.nodeA.c.Encode(net.nodeB.id(), addrB, ping, nil)
|
||||
|
|
|
|||
|
|
@ -54,11 +54,12 @@ type session struct {
|
|||
writeKey []byte
|
||||
readKey []byte
|
||||
nonceCounter uint32
|
||||
node *enode.Node
|
||||
}
|
||||
|
||||
// keysFlipped returns a copy of s with the read and write keys flipped.
|
||||
func (s *session) keysFlipped() *session {
|
||||
return &session{s.readKey, s.writeKey, s.nonceCounter}
|
||||
return &session{s.readKey, s.writeKey, s.nonceCounter, s.node}
|
||||
}
|
||||
|
||||
func NewSessionCache(maxItems int, clock mclock.Clock) *SessionCache {
|
||||
|
|
@ -103,8 +104,19 @@ func (sc *SessionCache) readKey(id enode.ID, addr string) []byte {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (sc *SessionCache) readNode(id enode.ID, addr string) *enode.Node {
|
||||
if s := sc.session(id, addr); s != nil {
|
||||
return s.node
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// storeNewSession stores new encryption keys in the cache.
|
||||
func (sc *SessionCache) storeNewSession(id enode.ID, addr string, s *session) {
|
||||
func (sc *SessionCache) storeNewSession(id enode.ID, addr string, s *session, n *enode.Node) {
|
||||
if n == nil {
|
||||
panic("nil node in storeNewSession")
|
||||
}
|
||||
s.node = n
|
||||
sc.sessions.Add(sessionID{id, addr}, s)
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue