p2p/discover: add waitForNodes

This improves the latency of lookups in small networks and test setups. When the local node table runs empty, the lookupIterator will trigger a refresh to try to fill the table again.

The behaviour of lookup in case of an empty table is changed:
- Previously, lookup waited a fixed 1 second before trying to continue the lookup.
- Now, lookup on an empty table returns immediately, and a better wait implementation is part of the lookupIterator. It reinitialises the table and continues the iterator as soon as a node becomes available.
This commit is contained in:
Csaba Kiraly 2025-09-17 10:27:35 +02:00 committed by GitHub
commit a4c9b34730
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 130 additions and 49 deletions

View file

@ -27,6 +27,7 @@ import (
// lookup performs a network search for nodes close to the given target. It approaches the
// target by querying nodes that are closer to it on each iteration. The given target does
// not need to be an actual node identifier.
// lookup on an empty table will return immediately with no nodes.
type lookup struct {
tab *Table
queryfunc queryFunc
@ -49,11 +50,15 @@ func newLookup(ctx context.Context, tab *Table, target enode.ID, q queryFunc) *l
result: nodesByDistance{target: target},
replyCh: make(chan []*enode.Node, alpha),
cancelCh: ctx.Done(),
queries: -1,
}
// Don't query further if we hit ourself.
// Unlikely to happen often in practice.
it.asked[tab.self().ID()] = true
it.seen[tab.self().ID()] = true
// Initialize the lookup with nodes from table.
closest := it.tab.findnodeByID(it.result.target, bucketSize, false)
it.addNodes(closest.entries)
return it
}
@ -64,22 +69,19 @@ func (it *lookup) run() []*enode.Node {
return it.result.entries
}
// empty reports whether replyBuffer, the batch of nodes collected by the most
// recent addNodes call, currently holds no entries.
func (it *lookup) empty() bool {
	pending := len(it.replyBuffer)
	return pending == 0
}
// advance advances the lookup until any new nodes have been found.
// It returns false when the lookup has ended.
func (it *lookup) advance() bool {
for it.startQueries() {
select {
case nodes := <-it.replyCh:
it.replyBuffer = it.replyBuffer[:0]
for _, n := range nodes {
if n != nil && !it.seen[n.ID()] {
it.seen[n.ID()] = true
it.result.push(n, bucketSize)
it.replyBuffer = append(it.replyBuffer, n)
}
}
it.queries--
if len(it.replyBuffer) > 0 {
it.addNodes(nodes)
if !it.empty() {
return true
}
case <-it.cancelCh:
@ -89,6 +91,17 @@ func (it *lookup) advance() bool {
return false
}
// addNodes records the previously unseen nodes among the given ones.
//
// replyBuffer is truncated first (its backing array is reused), so after the
// call it holds exactly the nodes that were new in this batch. Each non-nil
// node whose ID is not yet in seen is marked as seen, pushed onto result via
// push(n, bucketSize), and appended to replyBuffer. Nil entries and nodes that
// were already seen are skipped.
func (it *lookup) addNodes(nodes []*enode.Node) {
	it.replyBuffer = it.replyBuffer[:0]
	for _, candidate := range nodes {
		if candidate == nil {
			continue
		}
		id := candidate.ID()
		if it.seen[id] {
			continue
		}
		it.seen[id] = true
		it.result.push(candidate, bucketSize)
		it.replyBuffer = append(it.replyBuffer, candidate)
	}
}
func (it *lookup) shutdown() {
for it.queries > 0 {
<-it.replyCh
@ -103,20 +116,6 @@ func (it *lookup) startQueries() bool {
return false
}
// The first query returns nodes from the local table.
if it.queries == -1 {
closest := it.tab.findnodeByID(it.result.target, bucketSize, false)
// Avoid finishing the lookup too quickly if table is empty. It'd be better to wait
// for the table to fill in this case, but there is no good mechanism for that
// yet.
if len(closest.entries) == 0 {
it.slowdown()
}
it.queries = 1
it.replyCh <- closest.entries
return true
}
// Ask the closest nodes that we haven't asked yet.
for i := 0; i < len(it.result.entries) && it.queries < alpha; i++ {
n := it.result.entries[i]
@ -130,15 +129,6 @@ func (it *lookup) startQueries() bool {
return it.queries > 0
}
// slowdown pauses the lookup for up to one second. It is used to avoid
// finishing a lookup too quickly when the local table is empty.
//
// The pause ends early when the table is asked to close (tab.closeReq) or when
// the lookup itself is cancelled (cancelCh), so a cancelled lookup is not held
// up for the full second.
func (it *lookup) slowdown() {
	sleep := time.NewTimer(1 * time.Second)
	defer sleep.Stop()
	select {
	case <-sleep.C:
	case <-it.tab.closeReq:
	case <-it.cancelCh:
		// Lookup context is done; nothing left to wait for.
	}
}
func (it *lookup) query(n *enode.Node, reply chan<- []*enode.Node) {
r, err := it.queryfunc(n)
if !errors.Is(err, errClosed) { // avoid recording failures on shutdown.
@ -153,12 +143,16 @@ func (it *lookup) query(n *enode.Node, reply chan<- []*enode.Node) {
// lookupIterator performs lookup operations and iterates over all seen nodes.
// When a lookup finishes, a new one is created through nextLookup.
// lookupIterator waits for table initialization and triggers a table refresh
// when necessary.
type lookupIterator struct {
buffer []*enode.Node
nextLookup lookupFunc
ctx context.Context
cancel func()
lookup *lookup
buffer []*enode.Node
nextLookup lookupFunc
ctx context.Context
cancel func()
lookup *lookup
tabRefreshing <-chan struct{}
}
type lookupFunc func(ctx context.Context) *lookup
@ -182,6 +176,7 @@ func (it *lookupIterator) Next() bool {
if len(it.buffer) > 0 {
it.buffer = it.buffer[1:]
}
// Advance the lookup to refill the buffer.
for len(it.buffer) == 0 {
if it.ctx.Err() != nil {
@ -191,17 +186,55 @@ func (it *lookupIterator) Next() bool {
}
if it.lookup == nil {
it.lookup = it.nextLookup(it.ctx)
if it.lookup.empty() {
// If the lookup is empty right after creation, it means the local table
// is in a degraded state, and we need to wait for it to fill again.
it.lookupFailed(it.lookup.tab, 1*time.Minute)
it.lookup = nil
continue
}
// Yield the initial nodes from the iterator before advancing the lookup.
it.buffer = it.lookup.replyBuffer
continue
}
if !it.lookup.advance() {
it.lookup = nil
continue
}
newNodes := it.lookup.advance()
it.buffer = it.lookup.replyBuffer
if !newNodes {
it.lookup = nil
}
}
return true
}
// lookupFailed handles failed lookup attempts. This can be called when the table has
// exited, or when it runs out of nodes.
//
// It blocks for at most timeout (and no longer than the iterator's own context
// stays alive) while going through three stages:
//
//  1. wait until tab.initDone is signalled,
//  2. wait for a table refresh, reusing the one remembered in it.tabRefreshing
//     or requesting a new one through tab.refresh,
//  3. wait until the table holds at least one node.
//
// If the deadline hits during stage 2, the pending refresh channel is kept in
// it.tabRefreshing so the next call waits on it instead of requesting another
// refresh.
func (it *lookupIterator) lookupFailed(tab *Table, timeout time.Duration) {
	waitCtx, release := context.WithTimeout(it.ctx, timeout)
	defer release()
	expired := waitCtx.Done()

	// Stage 1: table initialization may still be in progress.
	select {
	case <-expired:
		return
	case <-tab.initDone:
	}

	// Stage 2: join the refresh that is already pending, or start one.
	if it.tabRefreshing == nil {
		it.tabRefreshing = tab.refresh()
	}
	select {
	case <-expired:
		return
	case <-it.tabRefreshing:
		// Refresh finished; forget it so a later failure requests a new one.
		it.tabRefreshing = nil
	}

	// Stage 3: wait for the table to fill. The outcome is deliberately
	// ignored; the caller simply retries the lookup afterwards.
	_ = tab.waitForNodes(waitCtx, 1)
}
// Close ends the iterator.
func (it *lookupIterator) Close() {
it.cancel()

View file

@ -32,6 +32,7 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/metrics"
"github.com/ethereum/go-ethereum/p2p/enode"
@ -84,6 +85,7 @@ type Table struct {
closeReq chan struct{}
closed chan struct{}
nodeFeed event.FeedOf[*enode.Node]
nodeAddedHook func(*bucket, *tableNode)
nodeRemovedHook func(*bucket, *tableNode)
}
@ -567,6 +569,8 @@ func (tab *Table) nodeAdded(b *bucket, n *tableNode) {
}
n.addedToBucket = time.Now()
tab.revalidation.nodeAdded(tab, n)
tab.nodeFeed.Send(n.Node)
if tab.nodeAddedHook != nil {
tab.nodeAddedHook(b, n)
}
@ -702,3 +706,38 @@ func (tab *Table) deleteNode(n *enode.Node) {
b := tab.bucket(n.ID())
tab.deleteInBucket(b, n.ID())
}
// waitForNodes blocks until the table contains at least n nodes.
//
// The node count is the sum of len(entries) over all buckets, read while
// holding tab.mutex. If the table is not yet large enough, the method
// subscribes to tab.nodeFeed (once, on first need) and re-checks the count
// after every event received on the subscription.
//
// It returns nil once the count is reached, ctx.Err() when ctx is done, and
// errClosed when tab.closeReq is signalled.
//
// NOTE(review): the subscription channel is unbuffered and the count check
// needs tab.mutex. If nodeFeed.Send is issued while tab.mutex is held, a second
// Send arriving between the receive below and the next Lock could stall both
// sides. Confirm against the locking of the nodeAdded callers.
func (tab *Table) waitForNodes(ctx context.Context, n int) error {
	// tableSize must only be called with tab.mutex held.
	tableSize := func() int {
		total := 0
		for _, b := range &tab.buckets {
			total += len(b.entries)
		}
		return total
	}

	var added chan *enode.Node
	for {
		tab.mutex.Lock()
		enough := tableSize() >= n
		if !enough && added == nil {
			// Subscribe while still holding the lock, so no node addition can
			// slip in between the size check and the subscription.
			added = make(chan *enode.Node)
			sub := tab.nodeFeed.Subscribe(added)
			defer sub.Unsubscribe()
		}
		tab.mutex.Unlock()
		if enough {
			return nil
		}

		// Wait for a node add event, cancellation, or table shutdown.
		select {
		case <-added:
		case <-ctx.Done():
			return ctx.Err()
		case <-tab.closeReq:
			return errClosed
		}
	}
}

View file

@ -24,10 +24,12 @@ import (
"errors"
"fmt"
"io"
"maps"
"math/rand"
"net"
"net/netip"
"reflect"
"slices"
"sync"
"testing"
"time"
@ -509,18 +511,26 @@ func TestUDPv4_smallNetConvergence(t *testing.T) {
// they have all found each other.
status := make(chan error, len(nodes))
for i := range nodes {
node := nodes[i]
self := nodes[i]
go func() {
found := make(map[enode.ID]bool, len(nodes))
it := node.RandomNodes()
missing := make(map[enode.ID]bool, len(nodes))
for _, n := range nodes {
if n.Self().ID() == self.Self().ID() {
continue // skip self
}
missing[n.Self().ID()] = true
}
it := self.RandomNodes()
for it.Next() {
found[it.Node().ID()] = true
if len(found) == len(nodes) {
delete(missing, it.Node().ID())
if len(missing) == 0 {
status <- nil
return
}
}
status <- fmt.Errorf("node %s didn't find all nodes", node.Self().ID().TerminalString())
missingIDs := slices.Collect(maps.Keys(missing))
status <- fmt.Errorf("node %s didn't find all nodes, missing %v", self.Self().ID().TerminalString(), missingIDs)
}()
}
@ -537,7 +547,6 @@ func TestUDPv4_smallNetConvergence(t *testing.T) {
received++
if err != nil {
t.Error("ERROR:", err)
return
}
}
}