diff --git a/p2p/discover/table.go b/p2p/discover/table.go index 721cd7b589..016a2d1af3 100644 --- a/p2p/discover/table.go +++ b/p2p/discover/table.go @@ -753,6 +753,41 @@ func (tab *Table) deleteNode(n *enode.Node) { // waitForNodes blocks until the table contains at least n nodes. func (tab *Table) waitForNodes(ctx context.Context, n int) error { + // Wrap ctx so the forwarder goroutine exits when waitForNodes returns, + // regardless of whether the caller's ctx is canceled. + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + // Set up a notification channel that gets unblocked when there was any activity on + // the table. Ultimately this reads from the table's nodeFeed, but can't use the feed + // directly on the same goroutine that takes Table.mutex, it would deadlock. + var notify chan struct{} + var notifyErr error + initsub := func() event.Subscription { + notify = make(chan struct{}, 1) + newnode := make(chan *enode.Node, 1) + sub := tab.nodeFeed.Subscribe(newnode) + go func() { + defer close(notify) + for { + select { + case <-newnode: + select { + case notify <- struct{}{}: + default: + } + case <-ctx.Done(): + notifyErr = ctx.Err() + return + case <-tab.closeReq: + notifyErr = errClosed + return + } + } + }() + return sub + } + getlength := func() (count int) { for _, b := range &tab.buckets { count += len(b.entries) @@ -760,28 +795,24 @@ func (tab *Table) waitForNodes(ctx context.Context, n int) error { return count } - var ch chan *enode.Node for { tab.mutex.Lock() if getlength() >= n { tab.mutex.Unlock() return nil } - if ch == nil { - // Init subscription. - ch = make(chan *enode.Node) - sub := tab.nodeFeed.Subscribe(ch) + if notify == nil { + // Lazily init the subscription. Do this while holding the + // lock so we don't miss any events that change the node count. + sub := initsub() defer sub.Unsubscribe() } tab.mutex.Unlock() - // Wait for a node add event. - select { - case <-ch: - case <-ctx.Done(): - return ctx.Err() - case <-tab.closeReq: - return errClosed + // Wait for table event. + if _, ok := <-notify; !ok { + break } } + return notifyErr } diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go index c3b71ea5a6..a16b4d9cab 100644 --- a/p2p/discover/table_test.go +++ b/p2p/discover/table_test.go @@ -17,6 +17,7 @@ package discover import ( + "context" "crypto/ecdsa" "fmt" "math/rand" @@ -550,6 +551,45 @@ func TestSetFallbackNodes_DNSHostname(t *testing.T) { t.Logf("resolved localhost to %v", resolved.IPAddr()) } +// This test checks that waitForNodes does not block addFoundNode. +// See https://github.com/ethereum/go-ethereum/issues/34881. +func TestTable_waitForNodesLocking(t *testing.T) { + transport := newPingRecorder() + tab, db := newTestTable(transport, Config{}) + defer db.Close() + defer tab.close() + <-tab.initDone + + // waitForNodes will never reach this count, so it stays subscribed + // to nodeFeed and looping for the duration of the test. + waitCtx, cancelWait := context.WithCancel(context.Background()) + defer cancelWait() + waitDone := make(chan struct{}) + go func() { + defer close(waitDone) + tab.waitForNodes(waitCtx, 1<<20) + }() + + // Call addFoundNode in loop to send to the feed. + addDone := make(chan struct{}) + go func() { + defer close(addDone) + for i := range 10000 { + d := 240 + (i % 17) + n := nodeAtDistance(tab.self().ID(), d, intIP(i)) + tab.addFoundNode(n, true) + } + }() + + select { + case <-addDone: + cancelWait() + <-waitDone + case <-time.After(10 * time.Second): + t.Fatal("deadlock detected: add loop did not finish within 10s") + } +} + func newkey() *ecdsa.PrivateKey { key, err := crypto.GenerateKey() if err != nil {