diff --git a/cmd/devp2p/internal/v5test/discv5tests.go b/cmd/devp2p/internal/v5test/discv5tests.go index efe9144069..4dc2507693 100644 --- a/cmd/devp2p/internal/v5test/discv5tests.go +++ b/cmd/devp2p/internal/v5test/discv5tests.go @@ -257,34 +257,50 @@ that they are returned by FINDNODE.`) // Create bystanders. nodes := make([]*bystander, 5) - added := make(chan enode.ID, len(nodes)) + liveCh := make(chan enode.ID, len(nodes)) for i := range nodes { - nodes[i] = newBystander(t, s, added) + nodes[i] = newBystander(t, s, liveCh) defer nodes[i].close() } - // Get them added to the remote table. + // Prefill each bystander with the full bystander set so background FINDNODE + // lookups see useful routing data instead of empty responses. + known := make([]*enode.Node, 0, len(nodes)) + for _, bn := range nodes { + known = append(known, bn.conn.localNode.Node()) + } + for _, bn := range nodes { + bn.known = append([]*enode.Node(nil), known...) + } + + // Wait until enough bystanders have actually become live, i.e. the remote node + // has revalidated them by sending PING and receiving our PONG. + requiredLiveNodes := len(nodes) timeout := 60 * time.Second timeoutCh := time.After(timeout) - for count := 0; count < len(nodes); { + liveSet := make(map[enode.ID]*enode.Node) + for len(liveSet) < requiredLiveNodes { select { - case id := <-added: - t.Logf("bystander node %v added to remote table", id) - count++ + case id := <-liveCh: + for _, bn := range nodes { + if bn.id() == id { + liveSet[id] = bn.conn.localNode.Node() + break + } + } + t.Logf("bystander node %v became live", id) case <-timeoutCh: - t.Errorf("remote added %d bystander nodes in %v, need %d to continue", count, timeout, len(nodes)) - t.Logf("this can happen if the node has a non-empty table from previous runs") + t.Errorf("remote revalidated %d bystander nodes in %v, need %d to continue", len(liveSet), timeout, requiredLiveNodes) return } } - t.Logf("all %d bystander nodes were added", len(nodes)) + t.Logf("continuing after all %d bystander nodes became live", len(liveSet)) - // Collect our nodes by distance. + // Collect live nodes by distance. var dists []uint expect := make(map[enode.ID]*enode.Node) - for _, bn := range nodes { - n := bn.conn.localNode.Node() - expect[n.ID()] = n + for id, n := range liveSet { + expect[id] = n d := uint(enode.LogDist(n.ID(), s.Dest.ID())) if !slices.Contains(dists, d) { dists = append(dists, d) @@ -295,42 +311,63 @@ that they are returned by FINDNODE.`) t.Log("requesting nodes") conn, l1 := s.listen1(t) defer conn.close() - foundNodes, err := conn.findnode(l1, dists) - if err != nil { - t.Fatal(err) - } - t.Logf("remote returned %d nodes for distance list %v", len(foundNodes), dists) - for _, n := range foundNodes { - delete(expect, n.ID()) - } - if len(expect) > 0 { - t.Errorf("missing %d nodes in FINDNODE result", len(expect)) - t.Logf("this can happen if the test is run multiple times in quick succession") - t.Logf("and the remote node hasn't removed dead nodes from previous runs yet") - } else { - t.Logf("all %d expected nodes were returned", len(nodes)) + + const maxAttempts = 5 + const retryInterval = 2 * time.Second + + for attempt := 1; attempt <= maxAttempts; attempt++ { + foundNodes, err := conn.findnode(l1, dists) + if err != nil { + t.Fatal(err) + } + missing := make(map[enode.ID]struct{}) + for id := range expect { + missing[id] = struct{}{} + } + for _, n := range foundNodes { + delete(missing, n.ID()) + } + t.Logf("attempt %d: remote returned %d nodes for distance list %v, missing %d", attempt, len(foundNodes), dists, len(missing)) + if len(missing) == 0 { + t.Logf("all %d expected live nodes were returned", len(expect)) + return + } + if attempt < maxAttempts { + time.Sleep(retryInterval) + } } + t.Errorf("missing nodes in FINDNODE result after %d attempts", maxAttempts) + t.Logf("this can happen if the node has a non-empty table from previous runs") } // A bystander is a node whose only purpose is filling a spot in the remote table. type bystander struct { - dest *enode.Node - conn *conn - l net.PacketConn + dest *enode.Node + conn *conn + l net.PacketConn + known []*enode.Node - addedCh chan enode.ID - done sync.WaitGroup + liveCh chan enode.ID + sent map[v5wire.Nonce]v5wire.Packet + done sync.WaitGroup } -func newBystander(t *utesting.T, s *Suite, added chan enode.ID) *bystander { +func newBystander(t *utesting.T, s *Suite, live chan enode.ID) *bystander { conn, l := s.listen1(t) conn.setEndpoint(l) // bystander nodes need IP/port to get pinged bn := &bystander{ - conn: conn, - l: l, - dest: s.Dest, - addedCh: added, + conn: conn, + l: l, + dest: s.Dest, + liveCh: live, + sent: make(map[v5wire.Nonce]v5wire.Packet), } + // Establish an initial session and let the remote learn this node before + // switching to the passive responder loop below. + conn.reqresp(l, &v5wire.Ping{ + ReqID: conn.nextReqID(), + ENRSeq: conn.localNode.Seq(), + }) bn.done.Add(1) go bn.loop() return bn @@ -351,48 +388,57 @@ func (bn *bystander) close() { func (bn *bystander) loop() { defer bn.done.Done() - var ( - lastPing time.Time - wasAdded bool - ) for { - // Ping the remote node. - if !wasAdded && time.Since(lastPing) > 10*time.Second { - bn.conn.reqresp(bn.l, &v5wire.Ping{ - ReqID: bn.conn.nextReqID(), - ENRSeq: bn.dest.Seq(), - }) - lastPing = time.Now() - } - // Answer packets. - switch p := bn.conn.read(bn.l).(type) { - case *v5wire.Ping: - bn.conn.write(bn.l, &v5wire.Pong{ - ReqID: p.ReqID, - ENRSeq: bn.conn.localNode.Seq(), - ToIP: bn.dest.IP(), - ToPort: uint16(bn.dest.UDP()), - }, nil) - wasAdded = true - bn.notifyAdded() - case *v5wire.Findnode: - bn.conn.write(bn.l, &v5wire.Nodes{ReqID: p.ReqID, RespCount: 1}, nil) - wasAdded = true - bn.notifyAdded() - case *v5wire.TalkRequest: - bn.conn.write(bn.l, &v5wire.TalkResponse{ReqID: p.ReqID}, nil) - case *readError: - if !netutil.IsTemporaryError(p.err) { - bn.conn.logf("shutting down: %v", p.err) - return + p, from := bn.conn.readFrom(bn.l) + switch p := p.(type) { + case *v5wire.Whoareyou: + p.Node = bn.dest + if resp, ok := bn.sent[p.Nonce]; ok { + nonce := bn.conn.writeTo(bn.l, resp, p, from) + delete(bn.sent, p.Nonce) + bn.sent[nonce] = resp + } else { + bn.conn.writeTo(bn.l, &v5wire.Ping{ + ReqID: bn.conn.nextReqID(), + ENRSeq: bn.conn.localNode.Seq(), + }, p, from) } + case *v5wire.Ping: + resp := &v5wire.Pong{ + ReqID: append([]byte(nil), p.ReqID...), + ENRSeq: bn.conn.localNode.Seq(), + ToIP: from.IP, + ToPort: uint16(from.Port), + } + nonce := bn.conn.writeTo(bn.l, resp, nil, from) + bn.sent[nonce] = resp + bn.notifyLive() + case *v5wire.Findnode: + resp := &v5wire.Nodes{ReqID: append([]byte(nil), p.ReqID...), RespCount: 1} + for _, n := range bn.known { + if slices.Contains(p.Distances, uint(enode.LogDist(n.ID(), bn.id()))) { + resp.Nodes = append(resp.Nodes, n.Record()) + } + } + nonce := bn.conn.writeTo(bn.l, resp, nil, from) + bn.sent[nonce] = resp + case *v5wire.TalkRequest: + resp := &v5wire.TalkResponse{ReqID: append([]byte(nil), p.ReqID...)} + nonce := bn.conn.writeTo(bn.l, resp, nil, from) + bn.sent[nonce] = resp + case *readError: + if netutil.IsTemporaryError(p.err) || v5wire.IsInvalidHeader(p.err) { + continue + } + bn.conn.logf("shutting down: %v", p.err) + return } } } -func (bn *bystander) notifyAdded() { - if bn.addedCh != nil { - bn.addedCh <- bn.id() - bn.addedCh = nil +func (bn *bystander) notifyLive() { + if bn.liveCh != nil { + bn.liveCh <- bn.id() + bn.liveCh = nil } } diff --git a/cmd/devp2p/internal/v5test/framework.go b/cmd/devp2p/internal/v5test/framework.go index acab1eef79..0532dafb2c 100644 --- a/cmd/devp2p/internal/v5test/framework.go +++ b/cmd/devp2p/internal/v5test/framework.go @@ -127,14 +127,16 @@ func (tc *conn) nextReqID() []byte { // The request is retried if a handshake is requested. func (tc *conn) reqresp(c net.PacketConn, req v5wire.Packet) v5wire.Packet { reqnonce := tc.write(c, req, nil) - switch resp := tc.read(c).(type) { + resp, from := tc.readFrom(c) + switch resp := resp.(type) { case *v5wire.Whoareyou: if resp.Nonce != reqnonce { return readErrorf("wrong nonce %x in WHOAREYOU (want %x)", resp.Nonce[:], reqnonce[:]) } resp.Node = tc.remote - tc.write(c, req, resp) - return tc.read(c) + tc.writeTo(c, req, resp, from) + resp2, _ := tc.readFrom(c) + return resp2 default: return resp } @@ -150,21 +152,24 @@ func (tc *conn) findnode(c net.PacketConn, dists []uint) ([]*enode.Node, error) results []*enode.Node ) for n := 1; n > 0; { - switch resp := tc.read(c).(type) { + resp, from := tc.readFrom(c) + switch resp := resp.(type) { case *v5wire.Whoareyou: // Handle handshake. if resp.Nonce == reqnonce { resp.Node = tc.remote - tc.write(c, findnode, resp) + tc.writeTo(c, findnode, resp, from) } else { return nil, fmt.Errorf("unexpected WHOAREYOU (nonce %x), waiting for NODES", resp.Nonce[:]) } case *v5wire.Ping: // Handle ping from remote. - tc.write(c, &v5wire.Pong{ + tc.writeTo(c, &v5wire.Pong{ ReqID: resp.ReqID, ENRSeq: tc.localNode.Seq(), - }, nil) + ToIP: from.IP, + ToPort: uint16(from.Port), + }, nil, from) case *v5wire.Nodes: // Got NODES! Check request ID. if !bytes.Equal(resp.ReqID, findnode.ReqID) { @@ -200,11 +205,16 @@ func (tc *conn) findnode(c net.PacketConn, dists []uint) ([]*enode.Node, error) // write sends a packet on the given connection. func (tc *conn) write(c net.PacketConn, p v5wire.Packet, challenge *v5wire.Whoareyou) v5wire.Nonce { + return tc.writeTo(c, p, challenge, tc.remoteAddr) +} + +// writeTo sends a packet on the given connection to the given UDP address. +func (tc *conn) writeTo(c net.PacketConn, p v5wire.Packet, challenge *v5wire.Whoareyou, to *net.UDPAddr) v5wire.Nonce { packet, nonce, err := tc.codec.Encode(tc.remote.ID(), tc.remoteAddr.String(), p, challenge) if err != nil { panic(fmt.Errorf("can't encode %v packet: %v", p.Name(), err)) } - if _, err := c.WriteTo(packet, tc.remoteAddr); err != nil { + if _, err := c.WriteTo(packet, to); err != nil { tc.logf("Can't send %s: %v", p.Name(), err) } else { tc.logf(">> %s", p.Name()) @@ -214,24 +224,30 @@ func (tc *conn) write(c net.PacketConn, p v5wire.Packet, challenge *v5wire.Whoar // read waits for an incoming packet on the given connection. func (tc *conn) read(c net.PacketConn) v5wire.Packet { + p, _ := tc.readFrom(c) + return p +} + +// readFrom waits for an incoming packet and returns its source address. +func (tc *conn) readFrom(c net.PacketConn) (v5wire.Packet, *net.UDPAddr) { buf := make([]byte, 1280) if err := c.SetReadDeadline(time.Now().Add(waitTime)); err != nil { - return &readError{err} + return &readError{err}, nil } - n, _, err := c.ReadFrom(buf) + n, from, err := c.ReadFrom(buf) if err != nil { - return &readError{err} + return &readError{err}, nil } - // Always use tc.remoteAddr for session lookup. The actual source address of - // the packet may differ from tc.remoteAddr when the remote node is reachable - // via multiple networks (e.g. Docker bridge vs. overlay), but the codec's - // session cache is keyed by the address used during Encode. + udpFrom, _ := from.(*net.UDPAddr) + // Use tc.remoteAddr for codec/session lookup because the fixture keys sessions + // by the advertised endpoint, but return the actual UDP source so responses can + // comply with the spec and go back to the request envelope address. _, _, p, err := tc.codec.Decode(buf[:n], tc.remoteAddr.String()) if err != nil { - return &readError{err} + return &readError{err}, udpFrom } tc.logf("<< %s", p.Name()) - return p + return p, udpFrom } // logf prints to the test log.