diff --git a/eth/dropper.go b/eth/dropper.go index 75eea891ca..741ca5cbd2 100644 --- a/eth/dropper.go +++ b/eth/dropper.go @@ -236,8 +236,21 @@ func (cm *dropper) protectedPeers(peers []*p2p.Peer) map[*p2p.Peer]bool { dialed = append(dialed, p) } } - // protectPool selects the top-frac peers from pool by score and adds them to result. + result := protectedPeersByPool(inbound, dialed, stats) + if len(result) > 0 { + log.Debug("Protecting high-value peers from drop", "protected", len(result)) + } + return result +} + +// protectedPeersByPool selects the union of top-N peers per protection +// category across the given already-split inbound and dialed pools. +// Factored from protectedPeers so tests can exercise the per-pool +// selection logic without needing to construct direction-flagged +// *p2p.Peer instances (which require unexported p2p types). +func protectedPeersByPool(inbound, dialed []*p2p.Peer, stats map[string]PeerInclusionStats) map[*p2p.Peer]bool { result := make(map[*p2p.Peer]bool) + // protectPool selects the top-frac peers from pool by score and adds them to result. protectPool := func(pool []*p2p.Peer, score func(*p2p.Peer) float64, frac float64) { n := int(float64(len(pool)) * frac) if n == 0 { @@ -260,9 +273,6 @@ func (cm *dropper) protectedPeers(peers []*p2p.Peer) map[*p2p.Peer]bool { protectPool(inbound, score, cat.frac) protectPool(dialed, score, cat.frac) } - if len(result) > 0 { - log.Debug("Protecting high-value peers from drop", "protected", len(result)) - } return result } diff --git a/eth/dropper_test.go b/eth/dropper_test.go index ba1eb66faf..ea764c9ab3 100644 --- a/eth/dropper_test.go +++ b/eth/dropper_test.go @@ -122,3 +122,112 @@ func TestProtectedPeersNilFunc(t *testing.T) { t.Fatalf("expected nil with nil stats func, got %v", protected) } } + +// TestProtectedByPoolPerPoolTopN verifies that the top-N selection runs +// independently in each of the inbound and dialed pools, not globally. +// With 10 peers per pool and inclusionProtectionFrac=0.1, exactly 1 peer +// is protected per pool per category — so 2 total (one per pool), both +// for the Finalized category since we don't set RecentIncluded. +func TestProtectedByPoolPerPoolTopN(t *testing.T) { + inbound := makePeers(10) + dialed := makePeers(10) + // Distinguish dialed peer IDs from inbound so stats maps don't collide. + for i := range dialed { + id := enode.ID{byte(100 + i)} + dialed[i] = p2p.NewPeer(id, fmt.Sprintf("dialed%d", i), nil) + } + // Strictly increasing scores: highest wins in each pool. + stats := make(map[string]PeerInclusionStats) + for i, p := range inbound { + stats[p.ID().String()] = PeerInclusionStats{Finalized: int64(1 + i)} + } + for i, p := range dialed { + stats[p.ID().String()] = PeerInclusionStats{Finalized: int64(1 + i)} + } + + protected := protectedPeersByPool(inbound, dialed, stats) + + // Expect top 1 of inbound (inbound[9]) and top 1 of dialed (dialed[9]). + if len(protected) != 2 { + t.Fatalf("expected 2 protected peers (1 per pool), got %d", len(protected)) + } + if !protected[inbound[9]] { + t.Error("expected top inbound peer to be protected") + } + if !protected[dialed[9]] { + t.Error("expected top dialed peer to be protected") + } +} + +// TestProtectedByPoolCrossCategoryOverlap verifies that the union across +// protection categories is correctly deduplicated: a peer that wins in +// multiple categories appears once, and category winners are all +// protected. Uses a pool large enough that frac*len yields n=2 per +// category, so cross-category overlap is observable. +func TestProtectedByPoolCrossCategoryOverlap(t *testing.T) { + // 20 dialed peers so 0.1 * 20 = 2 protected per category. + dialed := makePeers(20) + // P0: high Finalized only. P1: high RecentIncluded only. P2: high both. + // With n=2 per category: + // Finalized winners: P2 (tie-broken-ok), P0 + // RecentIncluded winners: P2, P1 + // Union: {P0, P1, P2}. + stats := make(map[string]PeerInclusionStats) + stats[dialed[0].ID().String()] = PeerInclusionStats{Finalized: 100, RecentIncluded: 0} + stats[dialed[1].ID().String()] = PeerInclusionStats{Finalized: 0, RecentIncluded: 5.0} + stats[dialed[2].ID().String()] = PeerInclusionStats{Finalized: 200, RecentIncluded: 10.0} + + protected := protectedPeersByPool(nil, dialed, stats) + + if len(protected) != 3 { + t.Fatalf("expected 3 protected peers (union of category winners), got %d", len(protected)) + } + for _, idx := range []int{0, 1, 2} { + if !protected[dialed[idx]] { + t.Errorf("peer %d should be protected", idx) + } + } +} + +// TestProtectedByPoolPerPoolIndependence locks in that selection runs +// per-pool, not globally. Every inbound peer scores higher than every +// dialed peer, so a global top-N would pick only inbound peers. Per-pool +// top-N must still protect the top dialed peers. +func TestProtectedByPoolPerPoolIndependence(t *testing.T) { + // 20 inbound, 20 dialed — frac=0.1 → 2 protected per pool per category. + // Global top-4 of Finalized would be inbound[16..19] — zero dialed. + inbound := makePeers(20) + dialed := make([]*p2p.Peer, 20) + for i := range dialed { + id := enode.ID{byte(100 + i)} + dialed[i] = p2p.NewPeer(id, fmt.Sprintf("dialed%d", i), nil) + } + stats := make(map[string]PeerInclusionStats) + // Every inbound peer outscores every dialed peer. + for i, p := range inbound { + stats[p.ID().String()] = PeerInclusionStats{Finalized: int64(1000 + i)} + } + for i, p := range dialed { + stats[p.ID().String()] = PeerInclusionStats{Finalized: int64(1 + i)} + } + + protected := protectedPeersByPool(inbound, dialed, stats) + + // Per-pool top-2 of Finalized: + // inbound: inbound[18], inbound[19] + // dialed: dialed[18], dialed[19] + // Global top-N would contain zero dialed peers, so asserting the top + // dialed peers are protected enforces per-pool independence. + if !protected[dialed[19]] { + t.Fatal("top dialed peer must be protected regardless of globally-higher inbound peers") + } + if !protected[dialed[18]] { + t.Fatal("second-top dialed peer must be protected regardless of globally-higher inbound peers") + } + if !protected[inbound[19]] || !protected[inbound[18]] { + t.Fatal("top inbound peers must also be protected") + } + if len(protected) != 4 { + t.Fatalf("expected 4 protected peers (top-2 of each pool), got %d", len(protected)) + } +}