diff --git a/p2p/dial.go b/p2p/dial.go index 225709427c..f9463d6d89 100644 --- a/p2p/dial.go +++ b/p2p/dial.go @@ -76,6 +76,7 @@ var ( errSelf = errors.New("is self") errAlreadyDialing = errors.New("already dialing") errAlreadyConnected = errors.New("already connected") + errPendingInbound = errors.New("peer has pending inbound connection") errRecentlyDialed = errors.New("recently dialed") errNetRestrict = errors.New("not contained in netrestrict list") errNoPort = errors.New("node does not provide TCP port") @@ -104,12 +105,15 @@ type dialScheduler struct { remStaticCh chan *enode.Node addPeerCh chan *conn remPeerCh chan *conn + addPendingCh chan enode.ID + remPendingCh chan enode.ID // Everything below here belongs to loop and // should only be accessed by code on the loop goroutine. - dialing map[enode.ID]*dialTask // active tasks - peers map[enode.ID]struct{} // all connected peers - dialPeers int // current number of dialed peers + dialing map[enode.ID]*dialTask // active tasks + peers map[enode.ID]struct{} // all connected peers + pendingInbound map[enode.ID]struct{} // in-progress inbound connections + dialPeers int // current number of dialed peers // The static map tracks all static dial tasks. The subset of usable static dial tasks // (i.e. those passing checkDial) is kept in staticPool. The scheduler prefers @@ -163,19 +167,22 @@ func (cfg dialConfig) withDefaults() dialConfig { func newDialScheduler(config dialConfig, it enode.Iterator, setupFunc dialSetupFunc) *dialScheduler { cfg := config.withDefaults() d := &dialScheduler{ - dialConfig: cfg, - historyTimer: mclock.NewAlarm(cfg.clock), - setupFunc: setupFunc, - dnsLookupFunc: net.DefaultResolver.LookupNetIP, - dialing: make(map[enode.ID]*dialTask), - static: make(map[enode.ID]*dialTask), - peers: make(map[enode.ID]struct{}), - doneCh: make(chan *dialTask), - nodesIn: make(chan *enode.Node), - addStaticCh: make(chan *enode.Node), - remStaticCh: make(chan *enode.Node), - addPeerCh: make(chan *conn), - remPeerCh: make(chan *conn), + dialConfig: cfg, + historyTimer: mclock.NewAlarm(cfg.clock), + setupFunc: setupFunc, + dnsLookupFunc: net.DefaultResolver.LookupNetIP, + dialing: make(map[enode.ID]*dialTask), + static: make(map[enode.ID]*dialTask), + peers: make(map[enode.ID]struct{}), + pendingInbound: make(map[enode.ID]struct{}), + doneCh: make(chan *dialTask), + nodesIn: make(chan *enode.Node), + addStaticCh: make(chan *enode.Node), + remStaticCh: make(chan *enode.Node), + addPeerCh: make(chan *conn), + remPeerCh: make(chan *conn), + addPendingCh: make(chan enode.ID), + remPendingCh: make(chan enode.ID), } d.lastStatsLog = d.clock.Now() d.ctx, d.cancel = context.WithCancel(context.Background()) @@ -223,6 +230,22 @@ func (d *dialScheduler) peerRemoved(c *conn) { } } +// inboundPending notifies the scheduler about a pending inbound connection. +func (d *dialScheduler) inboundPending(id enode.ID) { + select { + case d.addPendingCh <- id: + case <-d.ctx.Done(): + } +} + +// inboundCompleted notifies the scheduler that an inbound connection completed or failed. +func (d *dialScheduler) inboundCompleted(id enode.ID) { + select { + case d.remPendingCh <- id: + case <-d.ctx.Done(): + } +} + // loop is the main loop of the dialer. func (d *dialScheduler) loop(it enode.Iterator) { var ( @@ -276,6 +299,15 @@ loop: delete(d.peers, c.node.ID()) d.updateStaticPool(c.node.ID()) + case id := <-d.addPendingCh: + d.pendingInbound[id] = struct{}{} + d.log.Trace("Marked node as pending inbound", "id", id) + + case id := <-d.remPendingCh: + delete(d.pendingInbound, id) + d.updateStaticPool(id) + d.log.Trace("Unmarked node as pending inbound", "id", id) + case node := <-d.addStaticCh: id := node.ID() _, exists := d.static[id] @@ -390,6 +422,9 @@ func (d *dialScheduler) checkDial(n *enode.Node) error { if _, ok := d.peers[n.ID()]; ok { return errAlreadyConnected } + if _, ok := d.pendingInbound[n.ID()]; ok { + return errPendingInbound + } if d.netRestrict != nil && !d.netRestrict.ContainsAddr(n.IPAddr()) { return errNetRestrict } diff --git a/p2p/dial_test.go b/p2p/dial_test.go index f18dacce2a..9684aa6e91 100644 --- a/p2p/dial_test.go +++ b/p2p/dial_test.go @@ -423,6 +423,82 @@ func TestDialSchedDNSHostname(t *testing.T) { }) } +// This test checks that nodes with pending inbound connections are not dialed. +func TestDialSchedPendingInbound(t *testing.T) { + t.Parallel() + + config := dialConfig{ + maxActiveDials: 5, + maxDialPeers: 4, + } + runDialTest(t, config, []dialTestRound{ + // 2 peers are connected, leaving 2 dial slots. + // Node 0x03 has a pending inbound connection. + // Discovered nodes 0x03, 0x04, 0x05 but only 0x04 and 0x05 should be dialed. + { + peersAdded: []*conn{ + {flags: dynDialedConn, node: newNode(uintID(0x01), "127.0.0.1:30303")}, + {flags: dynDialedConn, node: newNode(uintID(0x02), "127.0.0.2:30303")}, + }, + update: func(d *dialScheduler) { + d.inboundPending(uintID(0x03)) + }, + discovered: []*enode.Node{ + newNode(uintID(0x03), "127.0.0.3:30303"), // not dialed because pending inbound + newNode(uintID(0x04), "127.0.0.4:30303"), + newNode(uintID(0x05), "127.0.0.5:30303"), + }, + wantNewDials: []*enode.Node{ + newNode(uintID(0x04), "127.0.0.4:30303"), + newNode(uintID(0x05), "127.0.0.5:30303"), + }, + }, + // Pending inbound connection for 0x03 completes successfully. + // Node 0x03 becomes a connected peer. + // One dial slot remains, node 0x06 is dialed. + { + update: func(d *dialScheduler) { + // Pending inbound completes + d.inboundCompleted(uintID(0x03)) + }, + peersAdded: []*conn{ + {flags: inboundConn, node: newNode(uintID(0x03), "127.0.0.3:30303")}, + }, + succeeded: []enode.ID{ + uintID(0x04), + }, + failed: []enode.ID{ + uintID(0x05), + }, + discovered: []*enode.Node{ + newNode(uintID(0x03), "127.0.0.3:30303"), // not dialed, now connected + newNode(uintID(0x06), "127.0.0.6:30303"), + }, + wantNewDials: []*enode.Node{ + newNode(uintID(0x06), "127.0.0.6:30303"), + }, + }, + // Inbound peer 0x03 disconnects. + // Another pending inbound starts for 0x07. + // Only 0x03 should be dialed, not 0x07. + { + peersRemoved: []enode.ID{ + uintID(0x03), + }, + update: func(d *dialScheduler) { + d.inboundPending(uintID(0x07)) + }, + discovered: []*enode.Node{ + newNode(uintID(0x03), "127.0.0.3:30303"), + newNode(uintID(0x07), "127.0.0.7:30303"), // not dialed because pending inbound + }, + wantNewDials: []*enode.Node{ + newNode(uintID(0x03), "127.0.0.3:30303"), + }, + }, + }) +} + // ------- // Code below here is the framework for the tests above. diff --git a/p2p/server.go b/p2p/server.go index 10c855f1c4..6d2323f9ce 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -686,8 +686,11 @@ running: // Ensure that the trusted flag is set before checking against MaxPeers. c.flags |= trustedConn } - // TODO: track in-progress inbound node IDs (pre-Peer) to avoid dialing them. - c.cont <- srv.postHandshakeChecks(peers, inboundCount, c) + err := srv.postHandshakeChecks(peers, inboundCount, c) + if err == nil && c.flags&inboundConn != 0 { + srv.dialsched.inboundPending(c.node.ID()) + } + c.cont <- err case c := <-srv.checkpointAddPeer: // At this point the connection is past the protocol handshake. @@ -870,6 +873,11 @@ func (srv *Server) checkInboundConn(remoteIP netip.Addr) error { // or the handshakes have failed. func (srv *Server) SetupConn(fd net.Conn, flags connFlag, dialDest *enode.Node) error { c := &conn{fd: fd, flags: flags, cont: make(chan error)} + defer func() { + if c.is(inboundConn) && c.node != nil { + srv.dialsched.inboundCompleted(c.node.ID()) + } + }() if dialDest == nil { c.transport = srv.newTransport(fd, nil) } else {