diff --git a/eth/dropper.go b/eth/dropper.go index 5b49a98185..6a8554eb83 100644 --- a/eth/dropper.go +++ b/eth/dropper.go @@ -189,6 +189,25 @@ func (cm *dropper) dropRandomPeer() bool { return true } +// topN returns the top n elements from items by score (descending). +// Only elements with score > 0 are included. The input slice is not modified. +func topN[T any](items []T, n int, score func(T) float64) []T { + if n == 0 || len(items) == 0 { + return nil + } + cp := make([]T, len(items)) + copy(cp, items) + sort.Slice(cp, func(i, j int) bool { return score(cp[i]) > score(cp[j]) }) + + var result []T + for i := 0; i < n && i < len(cp); i++ { + if score(cp[i]) > 0 { + result = append(result, cp[i]) + } + } + return result +} + // protectedPeers computes the set of peers that should not be dropped based // on inclusion stats. Each protection category independently selects its // top-N peers per inbound/dialed pool; the union is returned. @@ -200,42 +219,27 @@ func (cm *dropper) protectedPeers(peers []*p2p.Peer) map[*p2p.Peer]bool { if len(stats) == 0 { return nil } - type peerWithStats struct { - peer *p2p.Peer - s PeerInclusionStats - } - var inbound, dialed []peerWithStats + // Split peers by direction. + var inbound, dialed []*p2p.Peer for _, p := range peers { - entry := peerWithStats{p, stats[p.ID().String()]} if p.Inbound() { - inbound = append(inbound, entry) + inbound = append(inbound, p) } else { - dialed = append(dialed, entry) + dialed = append(dialed, p) } } result := make(map[*p2p.Peer]bool) - - protectTopN := func(entries []peerWithStats, cat protectionCategory) { - n := int(float64(len(entries)) * cat.frac) - if n == 0 { - return - } - sort.Slice(entries, func(i, j int) bool { - return cat.score(entries[i].s) > cat.score(entries[j].s) - }) - for i := 0; i < n && i < len(entries); i++ { - if cat.score(entries[i].s) > 0 { - result[entries[i].peer] = true - } - } - } for _, cat := range protectionCategories { - inCopy := make([]peerWithStats, len(inbound)) - copy(inCopy, inbound) - dialCopy := make([]peerWithStats, len(dialed)) - copy(dialCopy, dialed) - protectTopN(inCopy, cat) - protectTopN(dialCopy, cat) + // Build a score function that looks up the peer's stats. + score := func(p *p2p.Peer) float64 { + return cat.score(stats[p.ID().String()]) + } + for _, p := range topN(inbound, int(float64(len(inbound))*cat.frac), score) { + result[p] = true + } + for _, p := range topN(dialed, int(float64(len(dialed))*cat.frac), score) { + result[p] = true + } } if len(result) > 0 { log.Debug("Protecting high-value peers from drop", "protected", len(result))