eth/downloader: fixes data race between synchronize and other methods (#21201)

This commit is contained in:
Daniel Liu 2025-02-24 18:54:12 +08:00
parent 2955a988ca
commit bf011986a5
5 changed files with 42 additions and 32 deletions

View file

@ -95,7 +95,7 @@ var (
)
type Downloader struct {
mode SyncMode // Synchronisation mode defining the strategy used (per sync cycle)
mode uint32 // Synchronisation mode defining the strategy used (per sync cycle)
mux *event.TypeMux // Event multiplexer to announce sync operation events
queue *queue // Scheduler for selecting the hashes to download
@ -205,13 +205,12 @@ type BlockChain interface {
}
// New creates a new downloader to fetch hashes and blocks from remote peers.
func New(mode SyncMode, stateDb ethdb.Database, mux *event.TypeMux, chain BlockChain, lightchain LightChain, dropPeer peerDropFn, handleProposedBlock proposeBlockHandlerFn) *Downloader {
func New(stateDb ethdb.Database, mux *event.TypeMux, chain BlockChain, lightchain LightChain, dropPeer peerDropFn, handleProposedBlock proposeBlockHandlerFn) *Downloader {
if lightchain == nil {
lightchain = chain
}
dl := &Downloader{
mode: mode,
stateDB: stateDb,
mux: mux,
queue: newQueue(),
@ -254,13 +253,16 @@ func (d *Downloader) Progress() XDPoSChain.SyncProgress {
defer d.syncStatsLock.RUnlock()
current := uint64(0)
switch d.mode {
case FullSync:
mode := d.getMode()
switch {
case d.blockchain != nil && mode == FullSync:
current = d.blockchain.CurrentBlock().NumberU64()
case FastSync:
case d.blockchain != nil && mode == FastSync:
current = d.blockchain.CurrentFastBlock().NumberU64()
case LightSync:
case d.lightchain != nil:
current = d.lightchain.CurrentHeader().Number.Uint64()
default:
log.Error("Unknown downloader chain/mode combo", "light", d.lightchain != nil, "full", d.blockchain != nil, "mode", mode)
}
return XDPoSChain.SyncProgress{
StartingBlock: d.syncStatsChainOrigin,
@ -403,8 +405,8 @@ func (d *Downloader) synchronise(id string, hash common.Hash, td *big.Int, mode
defer d.Cancel() // No matter what, we can't leave the cancel channel open
// Set the requested sync mode, unless it's forbidden
d.mode = mode
// Atomically set the requested sync mode
atomic.StoreUint32(&d.mode, uint32(mode))
// Retrieve the origin peer and initiate the downloading process
p := d.peers.Peer(id)
@ -414,6 +416,10 @@ func (d *Downloader) synchronise(id string, hash common.Hash, td *big.Int, mode
return d.syncWithPeer(p, hash, td)
}
func (d *Downloader) getMode() SyncMode {
return SyncMode(atomic.LoadUint32(&d.mode))
}
// syncWithPeer starts a block synchronization based on the hash chain from the
// specified peer and head hash.
func (d *Downloader) syncWithPeer(p *peerConnection, hash common.Hash, td *big.Int) (err error) {
@ -429,8 +435,9 @@ func (d *Downloader) syncWithPeer(p *peerConnection, hash common.Hash, td *big.I
if p.version < 62 {
return errTooOld
}
mode := d.getMode()
log.Debug("Synchronising with the network", "peer", p.id, "eth", p.version, "head", hash, "td", td, "mode", d.mode)
log.Debug("Synchronising with the network", "peer", p.id, "eth", p.version, "head", hash, "td", td, "mode", mode)
defer func(start time.Time) {
log.Debug("Synchronisation terminated", "elapsed", time.Since(start))
}(time.Now())
@ -455,7 +462,7 @@ func (d *Downloader) syncWithPeer(p *peerConnection, hash common.Hash, td *big.I
// Ensure our origin point is below any fast sync pivot point
pivot := uint64(0)
if d.mode == FastSync {
if mode == FastSync {
if height <= uint64(fsMinFullBlocks) {
origin = 0
} else {
@ -466,11 +473,11 @@ func (d *Downloader) syncWithPeer(p *peerConnection, hash common.Hash, td *big.I
}
}
d.committed = 1
if d.mode == FastSync && pivot != 0 {
if mode == FastSync && pivot != 0 {
d.committed = 0
}
// Initiate the sync using a concurrent header and content retrieval algorithm
d.queue.Prepare(origin+1, d.mode)
d.queue.Prepare(origin+1, mode)
if d.syncInitHook != nil {
d.syncInitHook(origin, height)
}
@ -481,9 +488,9 @@ func (d *Downloader) syncWithPeer(p *peerConnection, hash common.Hash, td *big.I
func() error { return d.fetchReceipts(origin + 1) }, // Receipts are retrieved during fast sync
func() error { return d.processHeaders(origin+1, pivot, td) },
}
if d.mode == FastSync {
if mode == FastSync {
fetchers = append(fetchers, func() error { return d.processFastSyncContent(latest) })
} else if d.mode == FullSync {
} else if mode == FullSync {
fetchers = append(fetchers, func() error { return d.processFullSyncContent(height) })
}
return d.spawnSync(fetchers)
@ -666,7 +673,8 @@ func (d *Downloader) findAncestor(p *peerConnection, remoteHeader *types.Header)
localHeight uint64
remoteHeight = remoteHeader.Number.Uint64()
)
switch d.mode {
mode := d.getMode()
switch mode {
case FullSync:
localHeight = d.blockchain.CurrentBlock().NumberU64()
case FastSync:
@ -675,6 +683,7 @@ func (d *Downloader) findAncestor(p *peerConnection, remoteHeader *types.Header)
localHeight = d.lightchain.CurrentHeader().Number.Uint64()
}
p.log.Debug("Looking for common ancestor", "local", localHeight, "remote", remoteHeight)
if localHeight >= MaxForkAncestry {
floor = int64(localHeight - MaxForkAncestry)
}
@ -726,7 +735,7 @@ func (d *Downloader) findAncestor(p *peerConnection, remoteHeader *types.Header)
n := headers[i].Number.Uint64()
var known bool
switch d.mode {
switch mode {
case FullSync:
known = d.blockchain.HasBlock(h, n)
case FastSync:
@ -799,7 +808,7 @@ func (d *Downloader) findAncestor(p *peerConnection, remoteHeader *types.Header)
n := headers[0].Number.Uint64()
var known bool
switch d.mode {
switch mode {
case FullSync:
known = d.blockchain.HasBlock(h, n)
case FastSync:
@ -874,6 +883,7 @@ func (d *Downloader) fetchHeaders(p *peerConnection, from uint64, pivot uint64)
// Start pulling the header chain skeleton until all is done
getHeaders(from)
mode := d.getMode()
for {
select {
case <-d.cancelCh:
@ -934,7 +944,7 @@ func (d *Downloader) fetchHeaders(p *peerConnection, from uint64, pivot uint64)
if n := len(headers); n > 0 {
// Retrieve the current head we're at
head := uint64(0)
if d.mode == LightSync {
if mode == LightSync {
head = d.lightchain.CurrentHeader().Number.Uint64()
} else {
head = d.blockchain.CurrentFastBlock().NumberU64()
@ -1289,6 +1299,7 @@ func (d *Downloader) fetchParts(deliveryCh chan dataPack, deliver func(dataPack)
func (d *Downloader) processHeaders(origin uint64, pivot uint64, td *big.Int) error {
// Keep a count of uncertain headers to roll back
rollback := []*types.Header{}
mode := d.getMode()
defer func() {
if len(rollback) > 0 {
// Flatten the headers and roll them back
@ -1297,13 +1308,13 @@ func (d *Downloader) processHeaders(origin uint64, pivot uint64, td *big.Int) er
hashes[i] = header.Hash()
}
lastHeader, lastFastBlock, lastBlock := d.lightchain.CurrentHeader().Number, common.Big0, common.Big0
if d.mode != LightSync {
if mode != LightSync {
lastFastBlock = d.blockchain.CurrentFastBlock().Number()
lastBlock = d.blockchain.CurrentBlock().Number()
}
d.lightchain.Rollback(hashes)
curFastBlock, curBlock := common.Big0, common.Big0
if d.mode != LightSync {
if mode != LightSync {
curFastBlock = d.blockchain.CurrentFastBlock().Number()
curBlock = d.blockchain.CurrentBlock().Number()
}
@ -1344,7 +1355,7 @@ func (d *Downloader) processHeaders(origin uint64, pivot uint64, td *big.Int) er
// L: Sync begins, and finds common ancestor at 11
// L: Request new headers up from 11 (R's TD was higher, it must have something)
// R: Nothing to give
if d.mode != LightSync {
if mode != LightSync {
head := d.blockchain.CurrentBlock()
if !gotHeaders && td.Cmp(d.blockchain.GetTd(head.Hash(), head.NumberU64())) > 0 {
return errStallingPeer
@ -1357,7 +1368,7 @@ func (d *Downloader) processHeaders(origin uint64, pivot uint64, td *big.Int) er
// This check cannot be executed "as is" for full imports, since blocks may still be
// queued for processing when the header download completes. However, as long as the
// peer gave us something useful, we're already happy/progressed (above check).
if d.mode == FastSync || d.mode == LightSync {
if mode == FastSync || mode == LightSync {
head := d.lightchain.CurrentHeader()
if td.Cmp(d.lightchain.GetTd(head.Hash(), head.Number.Uint64())) > 0 {
return errStallingPeer
@ -1382,9 +1393,8 @@ func (d *Downloader) processHeaders(origin uint64, pivot uint64, td *big.Int) er
limit = len(headers)
}
chunk := headers[:limit]
// In case of header only syncing, validate the chunk immediately
if d.mode == FastSync || d.mode == LightSync {
if mode == FastSync || mode == LightSync {
// Collect the yet unknown headers to mark them as uncertain
unknown := make([]*types.Header, 0, len(headers))
for _, header := range chunk {
@ -1412,7 +1422,7 @@ func (d *Downloader) processHeaders(origin uint64, pivot uint64, td *big.Int) er
}
}
// Unless we're doing light chains, schedule the headers for associated content retrieval
if d.mode == FullSync || d.mode == FastSync {
if mode == FullSync || mode == FastSync {
// If we've reached the allowed number of pending headers, stall a bit
for d.queue.PendingBlocks() >= maxQueuedHeaders || d.queue.PendingReceipts() >= maxQueuedHeaders {
select {

View file

@ -99,7 +99,7 @@ func newTester() *downloadTester {
tester.stateDb = rawdb.NewMemoryDatabase()
tester.stateDb.Put(genesis.Root().Bytes(), []byte{0x00})
tester.downloader = New(FullSync, tester.stateDb, new(event.TypeMux), tester, nil, tester.dropPeer, tester.handleProposedBlock)
tester.downloader = New(tester.stateDb, new(event.TypeMux), tester, nil, tester.dropPeer, tester.handleProposedBlock)
return tester
}
@ -658,7 +658,7 @@ func assertOwnForkedChain(t *testing.T, tester *downloadTester, common int, leng
blocks += length - common
receipts += length - common
}
if tester.downloader.mode == LightSync {
if tester.downloader.getMode() == LightSync {
blocks, receipts = 1, 1
}
if hs := len(tester.ownHeaders); hs != headers {

View file

@ -18,8 +18,8 @@ package downloader
import "fmt"
// SyncMode represents the synchronisation mode of the downloader.
type SyncMode int
// It is a uint32 as it is used with atomic operations.
type SyncMode uint32
const (
FullSync SyncMode = iota // Synchronise the entire blockchain history from full blocks

View file

@ -209,7 +209,7 @@ func NewProtocolManager(config *params.ChainConfig, mode downloader.SyncMode, ne
}
// Construct the different synchronisation mechanisms
manager.downloader = downloader.New(mode, chaindb, manager.eventMux, blockchain, nil, manager.removePeer, handleProposedBlock)
manager.downloader = downloader.New(chaindb, manager.eventMux, blockchain, nil, manager.removePeer, handleProposedBlock)
validator := func(header *types.Header) error {
return engine.VerifyHeader(blockchain, header, true)

View file

@ -208,7 +208,7 @@ func NewProtocolManager(chainConfig *params.ChainConfig, lightSync bool, protoco
}
if lightSync {
manager.downloader = downloader.New(downloader.LightSync, chainDb, manager.eventMux, nil, blockchain, removePeer, manager.handleProposedBlock)
manager.downloader = downloader.New(chainDb, manager.eventMux, nil, blockchain, removePeer, manager.handleProposedBlock)
manager.peers.notify((*downloaderPeerNotify)(manager))
manager.fetcher = newLightFetcher(manager)
}