From 55dae4fd09babea853c28fbf1896faaf74f28f9e Mon Sep 17 00:00:00 2001 From: apetro2 Date: Wed, 25 Mar 2026 21:24:50 +0800 Subject: [PATCH] core/txpool: guard nil head subscription --- core/txpool/txpool.go | 11 ++- core/txpool/txpool_test.go | 144 ++++++++++++++++++++++++++++++++----- 2 files changed, 134 insertions(+), 21 deletions(-) diff --git a/core/txpool/txpool.go b/core/txpool/txpool.go index 0f66a5abef..2ce5db326f 100644 --- a/core/txpool/txpool.go +++ b/core/txpool/txpool.go @@ -69,6 +69,7 @@ type TxPool struct { stateLock sync.RWMutex // The lock for protecting state instance state *state.StateDB // Current state at the blockchain head headCh chan core.ChainHeadEvent + headSub event.Subscription subs event.SubscriptionScope // Subscription scope to unsubscribe all on shutdown quit chan chan error // Quit channel to tear down the head updater @@ -104,15 +105,16 @@ func New(gasTip uint64, chain BlockChain, subpools []SubPool) (*TxPool, error) { term: make(chan struct{}), sync: make(chan chan error), } - if sub := chain.SubscribeChainHeadEvent(pool.headCh); sub != nil { - pool.subs.Track(sub) - } + pool.headSub = chain.SubscribeChainHeadEvent(pool.headCh) reserver := NewReservationTracker() for i, subpool := range subpools { if err := subpool.Init(gasTip, head, reserver.NewHandle(i)); err != nil { for j := i - 1; j >= 0; j-- { subpools[j].Close() } + if pool.headSub != nil { + pool.headSub.Unsubscribe() + } return nil, err } } @@ -124,6 +126,9 @@ func New(gasTip uint64, chain BlockChain, subpools []SubPool) (*TxPool, error) { func (p *TxPool) Close() error { var errs []error + if p.headSub != nil { + p.headSub.Unsubscribe() + } // Terminate the reset loop and wait for it to finish errc := make(chan error) p.quit <- errc diff --git a/core/txpool/txpool_test.go b/core/txpool/txpool_test.go index f28b6b18ad..f320c894c6 100644 --- a/core/txpool/txpool_test.go +++ b/core/txpool/txpool_test.go @@ -17,6 +17,9 @@ package txpool import ( + "errors" + "math/big" + "sync" "testing" "time" @@ -29,7 +32,6 @@ import ( ) type nilHeadSubChain struct{} -type trackedHeadSubChain struct{ nilHeadSubChain } func (nilHeadSubChain) Config() *params.ChainConfig { return params.TestChainConfig } @@ -43,8 +45,129 @@ func (nilHeadSubChain) StateAt(common.Hash) (*state.StateDB, error) { return state.New(types.EmptyRootHash, state.NewDatabaseForTesting()) } -func (trackedHeadSubChain) SubscribeChainHeadEvent(chan<- core.ChainHeadEvent) event.Subscription { - return event.NewSubscription(func(<-chan struct{}) error { return nil }) +type trackedHeadSubChain struct { + nilHeadSubChain + sub *subscriptionSpy +} + +func (c *trackedHeadSubChain) SubscribeChainHeadEvent(chan<- core.ChainHeadEvent) event.Subscription { + c.sub = newSubscriptionSpy() + return c.sub +} + +type subscriptionSpy struct { + err chan error + mu sync.Mutex + once sync.Once + closed bool +} + +func newSubscriptionSpy() *subscriptionSpy { + return &subscriptionSpy{err: make(chan error)} +} + +func (s *subscriptionSpy) Unsubscribe() { + s.once.Do(func() { + s.mu.Lock() + s.closed = true + s.mu.Unlock() + close(s.err) + }) +} + +func (s *subscriptionSpy) Err() <-chan error { + return s.err +} + +func (s *subscriptionSpy) isClosed() bool { + s.mu.Lock() + defer s.mu.Unlock() + + return s.closed +} + +type failingSubPool struct{} + +func (failingSubPool) Filter(*types.Transaction) bool { return false } + +func (failingSubPool) FilterType(byte) bool { return false } + +func (failingSubPool) Init(uint64, *types.Header, Reserver) error { + return errors.New("boom") +} + +func (failingSubPool) Close() error { return nil } + +func (failingSubPool) Reset(*types.Header, *types.Header) {} + +func (failingSubPool) SetGasTip(*big.Int) {} + +func (failingSubPool) Has(common.Hash) bool { return false } + +func (failingSubPool) Get(common.Hash) *types.Transaction { return nil } + +func (failingSubPool) GetRLP(common.Hash) []byte { return nil } + +func (failingSubPool) GetMetadata(common.Hash) *TxMetadata { return nil } + +func (failingSubPool) ValidateTxBasics(*types.Transaction) error { return nil } + +func (failingSubPool) Add([]*types.Transaction, bool) []error { return nil } + +func (failingSubPool) Pending(PendingFilter) map[common.Address][]*LazyTransaction { return nil } + +func (failingSubPool) SubscribeTransactions(chan<- core.NewTxsEvent, bool) event.Subscription { + return nil +} + +func (failingSubPool) Nonce(common.Address) uint64 { return 0 } + +func (failingSubPool) Stats() (int, int) { return 0, 0 } + +func (failingSubPool) Content() (map[common.Address][]*types.Transaction, map[common.Address][]*types.Transaction) { + return nil, nil +} + +func (failingSubPool) ContentFrom(common.Address) ([]*types.Transaction, []*types.Transaction) { + return nil, nil +} + +func (failingSubPool) Status(common.Hash) TxStatus { return TxStatusUnknown } + +func (failingSubPool) Clear() {} + +func TestTxPoolCloseUnsubscribesHeadSubscription(t *testing.T) { + t.Parallel() + + chain := &trackedHeadSubChain{} + pool, err := New(0, chain, nil) + if err != nil { + t.Fatalf("failed to create txpool: %v", err) + } + if chain.sub == nil { + t.Fatal("expected head subscription") + } + if err := pool.Close(); err != nil { + t.Fatalf("unexpected close error: %v", err) + } + if !chain.sub.isClosed() { + t.Fatal("expected head subscription to be unsubscribed on close") + } +} + +func TestTxPoolNewUnsubscribesHeadSubscriptionOnInitFailure(t *testing.T) { + t.Parallel() + + chain := &trackedHeadSubChain{} + if _, err := New(0, chain, []SubPool{failingSubPool{}}); err == nil { + t.Fatal("expected init failure") + } + if chain.sub == nil { + t.Fatal("expected head subscription") + } + if !chain.sub.isClosed() { + t.Fatal("expected head subscription to be unsubscribed on init failure") + } } func TestTxPoolCloseNilHeadSubscription(t *testing.T) { @@ -67,18 +190,3 @@ func TestTxPoolCloseNilHeadSubscription(t *testing.T) { t.Fatal("timed out waiting for txpool loop termination") } } - -func TestTxPoolNewTracksHeadSubscription(t *testing.T) { - t.Parallel() - - pool, err := New(0, trackedHeadSubChain{}, nil) - if err != nil { - t.Fatalf("failed to create txpool: %v", err) - } - if count := pool.subs.Count(); count != 1 { - t.Fatalf("unexpected subscription count: have %d want %d", count, 1) - } - if err := pool.Close(); err != nil { - t.Fatalf("unexpected close error: %v", err) - } -}