feat: tracked pending inbound in dial.go

This commit is contained in:
jeevan-sid 2025-12-09 22:54:13 +05:30
parent e58c785424
commit 7d4c805681
2 changed files with 71 additions and 23 deletions

View file

@ -76,6 +76,7 @@ var (
errSelf = errors.New("is self") errSelf = errors.New("is self")
errAlreadyDialing = errors.New("already dialing") errAlreadyDialing = errors.New("already dialing")
errAlreadyConnected = errors.New("already connected") errAlreadyConnected = errors.New("already connected")
errPendingInbound = errors.New("peer has pending inbound connection")
errRecentlyDialed = errors.New("recently dialed") errRecentlyDialed = errors.New("recently dialed")
errNetRestrict = errors.New("not contained in netrestrict list") errNetRestrict = errors.New("not contained in netrestrict list")
errNoPort = errors.New("node does not provide TCP port") errNoPort = errors.New("node does not provide TCP port")
@ -104,12 +105,15 @@ type dialScheduler struct {
remStaticCh chan *enode.Node remStaticCh chan *enode.Node
addPeerCh chan *conn addPeerCh chan *conn
remPeerCh chan *conn remPeerCh chan *conn
addPendingCh chan enode.ID
remPendingCh chan enode.ID
// Everything below here belongs to loop and // Everything below here belongs to loop and
// should only be accessed by code on the loop goroutine. // should only be accessed by code on the loop goroutine.
dialing map[enode.ID]*dialTask // active tasks dialing map[enode.ID]*dialTask // active tasks
peers map[enode.ID]struct{} // all connected peers peers map[enode.ID]struct{} // all connected peers
dialPeers int // current number of dialed peers pendingInbound map[enode.ID]struct{} // in-progress inbound connections
dialPeers int // current number of dialed peers
// The static map tracks all static dial tasks. The subset of usable static dial tasks // The static map tracks all static dial tasks. The subset of usable static dial tasks
// (i.e. those passing checkDial) is kept in staticPool. The scheduler prefers // (i.e. those passing checkDial) is kept in staticPool. The scheduler prefers
@ -163,19 +167,22 @@ func (cfg dialConfig) withDefaults() dialConfig {
func newDialScheduler(config dialConfig, it enode.Iterator, setupFunc dialSetupFunc) *dialScheduler { func newDialScheduler(config dialConfig, it enode.Iterator, setupFunc dialSetupFunc) *dialScheduler {
cfg := config.withDefaults() cfg := config.withDefaults()
d := &dialScheduler{ d := &dialScheduler{
dialConfig: cfg, dialConfig: cfg,
historyTimer: mclock.NewAlarm(cfg.clock), historyTimer: mclock.NewAlarm(cfg.clock),
setupFunc: setupFunc, setupFunc: setupFunc,
dnsLookupFunc: net.DefaultResolver.LookupNetIP, dnsLookupFunc: net.DefaultResolver.LookupNetIP,
dialing: make(map[enode.ID]*dialTask), dialing: make(map[enode.ID]*dialTask),
static: make(map[enode.ID]*dialTask), static: make(map[enode.ID]*dialTask),
peers: make(map[enode.ID]struct{}), peers: make(map[enode.ID]struct{}),
doneCh: make(chan *dialTask), pendingInbound: make(map[enode.ID]struct{}),
nodesIn: make(chan *enode.Node), doneCh: make(chan *dialTask),
addStaticCh: make(chan *enode.Node), nodesIn: make(chan *enode.Node),
remStaticCh: make(chan *enode.Node), addStaticCh: make(chan *enode.Node),
addPeerCh: make(chan *conn), remStaticCh: make(chan *enode.Node),
remPeerCh: make(chan *conn), addPeerCh: make(chan *conn),
remPeerCh: make(chan *conn),
addPendingCh: make(chan enode.ID),
remPendingCh: make(chan enode.ID),
} }
d.lastStatsLog = d.clock.Now() d.lastStatsLog = d.clock.Now()
d.ctx, d.cancel = context.WithCancel(context.Background()) d.ctx, d.cancel = context.WithCancel(context.Background())
@ -223,6 +230,22 @@ func (d *dialScheduler) peerRemoved(c *conn) {
} }
} }
// inboundPending notifies the scheduler about a pending inbound connection.
func (d *dialScheduler) inboundPending(id enode.ID) {
select {
case d.addPendingCh <- id:
case <-d.ctx.Done():
}
}
// inboundCompleted notifies the scheduler that an inbound connection completed or failed.
func (d *dialScheduler) inboundCompleted(id enode.ID) {
select {
case d.remPendingCh <- id:
case <-d.ctx.Done():
}
}
// loop is the main loop of the dialer. // loop is the main loop of the dialer.
func (d *dialScheduler) loop(it enode.Iterator) { func (d *dialScheduler) loop(it enode.Iterator) {
var ( var (
@ -276,6 +299,15 @@ loop:
delete(d.peers, c.node.ID()) delete(d.peers, c.node.ID())
d.updateStaticPool(c.node.ID()) d.updateStaticPool(c.node.ID())
case id := <-d.addPendingCh:
d.pendingInbound[id] = struct{}{}
d.log.Trace("Marked node as pending inbound", "id", id)
case id := <-d.remPendingCh:
delete(d.pendingInbound, id)
d.updateStaticPool(id)
d.log.Trace("Unmarked node as pending inbound", "id", id)
case node := <-d.addStaticCh: case node := <-d.addStaticCh:
id := node.ID() id := node.ID()
_, exists := d.static[id] _, exists := d.static[id]
@ -390,6 +422,9 @@ func (d *dialScheduler) checkDial(n *enode.Node) error {
if _, ok := d.peers[n.ID()]; ok { if _, ok := d.peers[n.ID()]; ok {
return errAlreadyConnected return errAlreadyConnected
} }
if _, ok := d.pendingInbound[n.ID()]; ok {
return errPendingInbound
}
if d.netRestrict != nil && !d.netRestrict.ContainsAddr(n.IPAddr()) { if d.netRestrict != nil && !d.netRestrict.ContainsAddr(n.IPAddr()) {
return errNetRestrict return errNetRestrict
} }

View file

@ -682,22 +682,28 @@ running:
case c := <-srv.checkpointPostHandshake: case c := <-srv.checkpointPostHandshake:
// A connection has passed the encryption handshake so // A connection has passed the encryption handshake so
// the remote identity is known (but hasn't been verified yet). // the remote identity is known (but hasn't been verified yet).
if trusted[c.node.ID()] { nodeID := c.node.ID()
if trusted[nodeID] {
// Ensure that the trusted flag is set before checking against MaxPeers. // Ensure that the trusted flag is set before checking against MaxPeers.
c.flags |= trustedConn c.flags |= trustedConn
} }
// TODO: track in-progress inbound node IDs (pre-Peer) to avoid dialing them. err := srv.postHandshakeChecks(peers, inboundCount, c)
c.cont <- srv.postHandshakeChecks(peers, inboundCount, c) if err == nil && c.flags&inboundConn != 0 {
srv.dialsched.inboundPending(c.node.ID())
}
c.cont <- err
case c := <-srv.checkpointAddPeer: case c := <-srv.checkpointAddPeer:
// At this point the connection is past the protocol handshake. // At this point the connection is past the protocol handshake.
// Its capabilities are known and the remote identity is verified. // Its capabilities are known and the remote identity is verified.
nodeID := c.node.ID()
err := srv.addPeerChecks(peers, inboundCount, c) err := srv.addPeerChecks(peers, inboundCount, c)
if err == nil { if err == nil {
// The handshakes are done and it passed all checks. // The handshakes are done and it passed all checks.
p := srv.launchPeer(c) p := srv.launchPeer(c)
peers[c.node.ID()] = p peers[nodeID] = p
srv.log.Debug("Adding p2p peer", "peercount", len(peers), "id", p.ID(), "conn", c.flags, "addr", p.RemoteAddr(), "name", p.Name()) srv.log.Debug("Adding p2p peer", "peercount", len(peers), "id", p.ID(),
"conn", c.flags, "addr", p.RemoteAddr(), "name", p.Name())
srv.dialsched.peerAdded(c) srv.dialsched.peerAdded(c)
if p.Inbound() { if p.Inbound() {
inboundCount++ inboundCount++
@ -714,8 +720,10 @@ running:
case pd := <-srv.delpeer: case pd := <-srv.delpeer:
// A peer disconnected. // A peer disconnected.
d := common.PrettyDuration(mclock.Now() - pd.created) d := common.PrettyDuration(mclock.Now() - pd.created)
delete(peers, pd.ID()) nodeID := pd.ID()
srv.log.Debug("Removing p2p peer", "peercount", len(peers), "id", pd.ID(), "duration", d, "req", pd.requested, "err", pd.err) delete(peers, nodeID)
srv.log.Debug("Removing p2p peer", "peercount", len(peers), "id", nodeID,
"duration", d, "req", pd.requested, "err", pd.err)
srv.dialsched.peerRemoved(pd.rw) srv.dialsched.peerRemoved(pd.rw)
if pd.Inbound() { if pd.Inbound() {
inboundCount-- inboundCount--
@ -870,6 +878,11 @@ func (srv *Server) checkInboundConn(remoteIP netip.Addr) error {
// or the handshakes have failed. // or the handshakes have failed.
func (srv *Server) SetupConn(fd net.Conn, flags connFlag, dialDest *enode.Node) error { func (srv *Server) SetupConn(fd net.Conn, flags connFlag, dialDest *enode.Node) error {
c := &conn{fd: fd, flags: flags, cont: make(chan error)} c := &conn{fd: fd, flags: flags, cont: make(chan error)}
defer func() {
if c.is(inboundConn) && c.node != nil {
srv.dialsched.inboundCompleted(c.node.ID())
}
}()
if dialDest == nil { if dialDest == nil {
c.transport = srv.newTransport(fd, nil) c.transport = srv.newTransport(fd, nil)
} else { } else {