eth/protocols/snap: fix issue where empty-list results for BALs are not retried. Also fix an edge case where peers return more data than requested.

This commit is contained in:
jonny rhea 2026-04-09 16:26:36 -05:00
parent 6b830ce8fb
commit f11de2b261
7 changed files with 152 additions and 159 deletions

View file

@ -1182,7 +1182,7 @@ func (bc *BlockChain) SnapSyncComplete(hash common.Hash) error {
// Set up the snapshot tree from the synced flat state. Snap/2 downloads
// flat state directly as the snapshot.
if bc.snaps != nil {
bc.snaps.RebuildFromSyncedState(root)
bc.snaps.InitFromSyncedState(root)
}
// If all checks out, manually set the head block.

View file

@ -727,17 +727,25 @@ func (t *Tree) Rebuild(root common.Hash) {
}
}
// RebuildFromSyncedState sets up the snapshot tree to use flat state that was
// InitFromSyncedState sets up the snapshot tree to use flat state that was
// already downloaded by snap sync. Unlike Rebuild, it does NOT regenerate the
// snapshot from the trie.
func (t *Tree) RebuildFromSyncedState(root common.Hash) {
func (t *Tree) InitFromSyncedState(root common.Hash) {
t.lock.Lock()
defer t.lock.Unlock()
// Delete any recovery flag in the database.
rawdb.DeleteSnapshotRecoveryNumber(t.diskdb)
rawdb.DeleteSnapshotDisabled(t.diskdb)
// Write the new root.
rawdb.WriteSnapshotRoot(t.diskdb, root)
// Clear the journal.
journalProgress(t.diskdb, nil, nil)
log.Info("Setting up snapshot from synced state", "root", root)
// Replace t.layers with a single diskLayer pointing at the root.
t.layers = map[common.Hash]snapshot{
root: &diskLayer{
diskdb: t.diskdb,

View file

@ -868,24 +868,31 @@ func (d *Downloader) importBlockResults(results []*fetchResult) error {
}
// checkDeepReorg checks if the old pivot block was reorged by comparing its
// state root against the current canonical chain. If the canonical header at
// the old pivot's block number has a different state root, the syncer's flat
// state is from the old fork and must be wiped. Returns true if a deep reorg
// was detected.
// state root against the current canonical chain. Returns true if the
// canonical header at the old pivot's block number has a different state root,
// meaning the syncer's flat state is from the old fork and must be wiped.
//
// Returns false (no reorg) when the canonical hash or header is missing. This
// avoids false positives from pruned or not-yet-downloaded data. If the chain
// Returns false conservatively when canonical data is missing. If the chain
// really did shorten past the old pivot, sync.catchUp's from > to guard will
// catch this.
func checkDeepReorg(db ethdb.Database, oldNumber uint64, oldRoot common.Hash) bool {
	// Resolve the canonical block currently recorded at the old pivot height.
	// An empty hash is ambiguous: the chain may have been reorged onto a
	// shorter fork, or the headers for this height may simply not have been
	// downloaded yet. The two cases cannot be told apart here, so report no
	// reorg and leave the flat state alone.
	var zero common.Hash
	canonHash := rawdb.ReadCanonicalHash(db, oldNumber)
	if canonHash == zero {
		return false
	}
	// Only a present header can be compared. A canonical hash without a
	// header leaves nothing to check the root against, so the state is kept.
	if canonHeader := rawdb.ReadHeader(db, canonHash, oldNumber); canonHeader != nil {
		// A different state root at this height means the block we were
		// syncing against is no longer the canonical one.
		return canonHeader.Root != oldRoot
	}
	return false
}

View file

@ -48,6 +48,7 @@ func verifyAccessList(b *bal.BlockAccessList, header *types.Header) error {
func (s *Syncer) applyAccessList(b *bal.BlockAccessList) error {
batch := s.db.NewBatch()
// Iterate over all accounts in the access list
for _, access := range b.Accesses {
addr := common.Address(access.Address)
accountHash := crypto.Keccak256Hash(addr[:])

View file

@ -141,7 +141,7 @@ var snap2 = map[uint64]msgHandler{
GetByteCodesMsg: handleGetByteCodes,
ByteCodesMsg: handleByteCodes,
GetAccessListsMsg: handleGetAccessLists,
// AccessListsMsg: TODO
AccessListsMsg: handleAccessLists,
}
// HandleMessage is invoked whenever an inbound message is received from a

View file

@ -520,13 +520,13 @@ func (s *Syncer) Sync(root common.Hash, number uint64, cancel chan struct{}) err
log.Info("Starting state download", "root", root)
for {
// Download: fetch all required state data
err := s.download(cancel)
err := s.downloadState(cancel)
if err == errPivotStale {
// Pivot moved: catch up to new pivot
if err := s.catchUp(cancel); err != nil {
return err
}
s.resetDownload(root, number)
s.resetDownloadState(root, number)
log.Info("Resuming state download", "root", root)
continue
}
@ -558,7 +558,7 @@ func (s *Syncer) Sync(root common.Hash, number uint64, cancel chan struct{}) err
// downloadState runs the bulk flat-state download. It fetches
// account ranges, storage slots, and bytecodes, writing flat state to disk.
func (s *Syncer) download(cancel chan struct{}) error {
func (s *Syncer) downloadState(cancel chan struct{}) error {
// If the pivot moved since the last run (downloader cancelled and restarted
// us with a new root), signal catch-up before downloading.
if s.previousRoot != s.root {
@ -638,14 +638,14 @@ func (s *Syncer) download(cancel chan struct{}) error {
}
}
// resetDownload resets the download state for a new pivot after catch-up.
// resetDownloadState resets the download state for a new pivot after catch-up.
// It regenerates the task list for accounts not yet downloaded, clears
// in-flight requests, and updates the root.
func (s *Syncer) resetDownload(root common.Hash, number uint64) {
func (s *Syncer) resetDownloadState(root common.Hash, number uint64) {
s.lock.Lock()
s.root = root
s.number = number
s.previousRoot = root // Prevent download() from returning errPivotStale again
s.previousRoot = root // Prevent downloadState() from returning errPivotStale again
s.previousNumber = number
// Clear stateless peers because they may be able to serve the new pivot
@ -662,16 +662,13 @@ func (s *Syncer) catchUp(cancel chan struct{}) error {
to := s.number
s.lock.RUnlock()
// The new pivot must be ahead of the old one. This can fail if a reorg
// replaced the block at the pivot height (same number, different root)
// or if a deep reorg shortened the chain past the old pivot. In either
// case, catch-up can't roll forward, so wipe progress and return an
// error so the caller restarts with a fresh sync.
//
// Note: this check lives here rather than in checkDeepReorg because
// catchUp is reached both when the downloader actively moves the pivot
// (via restartSnapSync) and when the syncer resumes from persisted
// progress after a restart. checkDeepReorg only covers the former.
// The new pivot must be ahead of the old one. The range is inverted if
// a reorg replaced the block at the pivot height (same number, different
// root) or if the chain shortened past the old pivot. In either case,
// catch-up can't roll forward — wipe progress and restart. This also
// catches reorgs missed by checkDeepReorg, which only runs when the
// downloader actively restarts the syncer, not on resume from persisted
// progress.
if from > to {
log.Warn("Catch-up range inverted, wiping sync progress", "from", from, "to", to)
rawdb.WriteSnapshotSyncStatus(s.db, nil)
@ -745,8 +742,6 @@ func (s *Syncer) fetchAccessLists(hashes []common.Hash, cancel chan struct{}) ([
pending[h] = struct{}{}
}
fetched := make(map[common.Hash]rlp.RawValue, len(hashes))
// Create ephemeral channels for this fetch cycle
var (
accessListReqFails = make(chan *accessListRequest)
accessListResps = make(chan *accessListResponse)
@ -785,6 +780,7 @@ func (s *Syncer) fetchAccessLists(hashes []common.Hash, cancel chan struct{}) ([
s.processAccessListResponse(res, pending, fetched)
}
}
// Assemble results in input order
results := make([]rlp.RawValue, len(hashes))
for i, h := range hashes {
@ -869,15 +865,19 @@ func (s *Syncer) assignAccessListTasks(pending map[common.Hash]struct{}, success
// processAccessListResponse handles a successful access list response by
// matching results to pending hashes and storing them.
func (s *Syncer) processAccessListResponse(res *accessListResponse, pending map[common.Hash]struct{}, fetched map[common.Hash]rlp.RawValue) {
// Each response entry corresponds to the requested hash at the same index
// Each response entry corresponds to the requested hash at the same index.
for i, raw := range res.accessLists {
if i >= len(res.req.hashes) {
break
}
h := res.req.hashes[i]
// Peer doesn't have this BAL. Add it back to pending for retry.
if bytes.Equal(raw, rlp.EmptyString) {
pending[h] = struct{}{}
continue
}
fetched[h] = raw
delete(pending, h)
}
// Re-add hashes that were not served back to pending
for i := len(res.accessLists); i < len(res.req.hashes); i++ {
pending[res.req.hashes[i]] = struct{}{}
@ -2388,6 +2388,12 @@ func (s *Syncer) OnAccessLists(peer SyncPeer, id uint64, accessLists rlp.RawList
s.scheduleRevertAccessListRequest(req)
return nil
}
if len(bals) > len(req.hashes) {
s.lock.Unlock()
s.scheduleRevertAccessListRequest(req)
logger.Warn("Peer sent more BALs than requested", "count", len(bals), "requested", len(req.hashes))
return errors.New("more BALs than requested")
}
s.lock.Unlock()
// Response validated, send it to the scheduler for filling.

View file

@ -1909,7 +1909,7 @@ func TestSlotEstimation(t *testing.T) {
// TestPivotMoveDetection verifies that when the syncer is restarted with a
// different root (simulating the downloader's cancel+restart on pivot move),
// download() returns errPivotStale immediately.
// downloadState() returns errPivotStale immediately.
func TestPivotMoveDetection(t *testing.T) {
t.Parallel()
@ -1937,45 +1937,14 @@ func TestPivotMoveDetection(t *testing.T) {
if syncer.root != rootB {
t.Fatalf("root mismatch: got %v, want %v", syncer.root, rootB)
}
// download() should detect the mismatch and return errPivotStale
// downloadState() should detect the mismatch and return errPivotStale
cancel := make(chan struct{})
err := syncer.download(cancel)
err := syncer.downloadState(cancel)
if err != errPivotStale {
t.Fatalf("expected errPivotStale, got %v", err)
}
}
// TestNoPivotMoveOnSameRoot verifies that when the syncer is restarted with
// the same root, previousRoot still equals root once the sync status has been
// reloaded, i.e. the condition used to detect a pivot move does not hold.
//
// NOTE(review): the test never calls download() itself; it only asserts on
// the previousRoot/root fields after loadSyncStatus.
func TestNoPivotMoveOnSameRoot(t *testing.T) {
t.Parallel()
rootA := common.HexToHash("0xaaaa")
db := rawdb.NewMemoryDatabase()
syncer := NewSyncer(db, rawdb.HashScheme)
// Simulate a previous sync run against rootA: a single account task spanning
// the whole hash range, persisted via saveSyncStatus.
syncer.root = rootA
syncer.tasks = []*accountTask{
{Next: common.Hash{}, Last: common.MaxHash, SubTasks: make(map[common.Hash][]*storageTask), stateCompleted: make(map[common.Hash]struct{})},
}
syncer.saveSyncStatus()
// Simulate restart with the same root. The order matters here: root and
// previousRoot are set before the persisted status is loaded again.
syncer.root = rootA
syncer.previousRoot = rootA
syncer.loadSyncStatus()
// Loading the persisted status must leave previousRoot at rootA.
if syncer.previousRoot != rootA {
t.Fatalf("previousRoot mismatch: got %v, want %v", syncer.previousRoot, rootA)
}
// previousRoot == root, so no pivot move detected
if syncer.previousRoot != syncer.root {
t.Fatalf("expected previousRoot == root, got %v != %v", syncer.previousRoot, syncer.root)
}
}
// TestCatchUpInvertedRange verifies that catchUp returns an error and wipes
// sync progress when the new pivot is at the same (or lower) block number as
// the old pivot.
@ -2003,51 +1972,6 @@ func TestCatchUpInvertedRange(t *testing.T) {
}
}
// TestFlatStateDownload verifies that download() writes flat state to disk
// and makes no trie node requests.
func TestFlatStateDownload(t *testing.T) {
t.Parallel()
testFlatStateDownload(t, rawdb.HashScheme)
testFlatStateDownload(t, rawdb.PathScheme)
}
// testFlatStateDownload syncs a 100-account, storage-free trie from a single
// peer by invoking download() directly, then checks that every account has a
// non-empty snapshot entry in the syncer's database.
func testFlatStateDownload(t *testing.T, scheme string) {
	// cancel is closed at most once, however often term is invoked.
	var (
		closeOnce sync.Once
		cancel    = make(chan struct{})
	)
	term := func() {
		closeOnce.Do(func() { close(cancel) })
	}
	nodeScheme, srcTrie, accounts := makeAccountTrieNoStorage(100, scheme)

	// A single peer serving its own copy of the account trie.
	peer := newTestPeer("source", t, term)
	peer.accountTrie = srcTrie.Copy()
	peer.accountValues = accounts

	syncer := setupSyncer(nodeScheme, peer)

	// Drive download() directly instead of going through Sync, so that the
	// trie rebuild step is not part of this test. previousRoot is set equal
	// to root so that no pivot move is signalled.
	syncer.root = srcTrie.Hash()
	syncer.previousRoot = syncer.root
	syncer.loadSyncStatus()
	if err := syncer.download(cancel); err != nil {
		t.Fatalf("download failed: %v", err)
	}
	// Every source account must now have a flat snapshot entry.
	for _, account := range accounts {
		key := common.BytesToHash(account.k)
		if snap := rawdb.ReadAccountSnapshot(syncer.db, key); len(snap) == 0 {
			t.Errorf("missing account snapshot for %x", key)
		}
	}
}
// TestInterruptedDownloadRecovery verifies that partially completed download
// state is persisted and resumed on restart.
func TestInterruptedDownloadRecovery(t *testing.T) {
@ -2086,7 +2010,7 @@ func testInterruptedDownloadRecovery(t *testing.T, scheme string) {
syncer1.root = root
syncer1.previousRoot = root
syncer1.loadSyncStatus()
syncer1.download(cancel1)
syncer1.downloadState(cancel1)
// Save progress
for _, task := range syncer1.tasks {
@ -2123,7 +2047,7 @@ func testInterruptedDownloadRecovery(t *testing.T, scheme string) {
syncer2.root = root
syncer2.previousRoot = root
syncer2.loadSyncStatus()
if err := syncer2.download(cancel2); err != nil {
if err := syncer2.downloadState(cancel2); err != nil {
t.Fatalf("resumed download failed: %v", err)
}
@ -2333,7 +2257,7 @@ func TestInterruptedRebuildRecovery(t *testing.T) {
root := sourceAccountTrie.Hash()
// First run: complete download, save status, simulate interruption
// before rebuild by calling download() directly
// before rebuild by calling downloadState() directly
var (
once1 sync.Once
cancel1 = make(chan struct{})
@ -2350,7 +2274,7 @@ func TestInterruptedRebuildRecovery(t *testing.T) {
syncer1.previousRoot = root
syncer1.loadSyncStatus()
if err := syncer1.download(cancel1); err != nil {
if err := syncer1.downloadState(cancel1); err != nil {
t.Fatalf("download failed: %v", err)
}
// Save status (simulating what Sync's defer does)
@ -2386,50 +2310,6 @@ func TestInterruptedRebuildRecovery(t *testing.T) {
}
}
// TestFetchAccessListsSinglePeer verifies that BALs served by a single peer
// are all fetched and returned in the same order as the requested hashes.
func TestFetchAccessListsSinglePeer(t *testing.T) {
	t.Parallel()

	// cancel is closed at most once, however often term is invoked.
	var (
		closeOnce sync.Once
		cancel    = make(chan struct{})
	)
	term := func() { closeOnce.Do(func() { close(cancel) }) }

	hashes := []common.Hash{
		common.HexToHash("0x01"),
		common.HexToHash("0x02"),
		common.HexToHash("0x03"),
	}
	// Build one BAL per hash. The balance value is taken from the last byte
	// of the hash, so each encoded BAL is distinct.
	expected := make(map[common.Hash]rlp.RawValue)
	for _, hash := range hashes {
		builder := bal.NewConstructionBlockAccessList()
		builder.BalanceChange(0, common.HexToAddress("0xaa"), uint256.NewInt(uint64(hash[31])))

		var encoded bytes.Buffer
		if err := builder.EncodeRLP(&encoded); err != nil {
			t.Fatal(err)
		}
		expected[hash] = encoded.Bytes()
	}
	peer := newTestPeer("source", t, term)
	peer.accessLists = expected

	syncer := setupSyncer(rawdb.HashScheme, peer)

	got, err := syncer.fetchAccessLists(hashes, cancel)
	if err != nil {
		t.Fatalf("fetchAccessLists failed: %v", err)
	}
	if len(got) != len(hashes) {
		t.Fatalf("result count mismatch: got %d, want %d", len(got), len(hashes))
	}
	// Results must line up with the input hashes position by position.
	for i := range hashes {
		if want := expected[hashes[i]]; !bytes.Equal(got[i], want) {
			t.Errorf("result %d mismatch", i)
		}
	}
}
// TestFetchAccessListsMultiplePeers verifies that fetch distributes work
// across multiple idle peers.
func TestFetchAccessListsMultiplePeers(t *testing.T) {
@ -2467,6 +2347,12 @@ func TestFetchAccessListsMultiplePeers(t *testing.T) {
if len(results) != len(hashes) {
t.Fatalf("result count mismatch: got %d, want %d", len(results), len(hashes))
}
// Verify results match expected content in request order
for i, h := range hashes {
if !bytes.Equal(results[i], bals[h]) {
t.Errorf("result %d content mismatch for hash %v", i, h)
}
}
}
// TestFetchAccessListsPeerTimeout verifies that timed-out requests are retried
@ -2700,6 +2586,91 @@ func TestFetchAccessListsShortResponse(t *testing.T) {
}
}
// TestFetchAccessListsEmptyPlaceholder verifies that when a peer returns
// rlp.EmptyString placeholders for BALs it doesn't have, those placeholders
// are not silently accepted as valid results: every requested hash must end
// up with its real BAL, returned in request order.
func TestFetchAccessListsEmptyPlaceholder(t *testing.T) {
	t.Parallel()

	var (
		once   sync.Once
		cancel = make(chan struct{})
		term   = func() { once.Do(func() { close(cancel) }) }
	)
	hashes := []common.Hash{
		common.HexToHash("0x01"),
		common.HexToHash("0x02"),
		common.HexToHash("0x03"),
	}
	// Build BALs for all 3 hashes. The balance value is taken from the last
	// byte of the hash, so each encoded BAL is distinct.
	allBALs := make(map[common.Hash]rlp.RawValue)
	for _, h := range hashes {
		cb := bal.NewConstructionBlockAccessList()
		cb.BalanceChange(0, common.HexToAddress("0xaa"), uint256.NewInt(uint64(h[31])))
		var buf bytes.Buffer
		if err := cb.EncodeRLP(&buf); err != nil {
			t.Fatal(err)
		}
		allBALs[h] = buf.Bytes()
	}
	// partialPeer has BALs for hashes 0 and 2. For the missing BAL it answers
	// with an rlp.EmptyString placeholder at the same position.
	partialPeer := newTestPeer("partial", t, term)
	partialPeer.accessListRequestHandler = func(tp *testPeer, id uint64, reqHashes []common.Hash, max int) error {
		var results []rlp.RawValue
		for _, h := range reqHashes {
			if raw, ok := allBALs[h]; ok && h != hashes[1] {
				results = append(results, raw)
			} else {
				results = append(results, rlp.EmptyString)
			}
		}
		// An encoding failure is a bug in the test fixture; surface it instead
		// of delivering a bogus response.
		rawList, err := rlp.EncodeToRawList(results)
		if err != nil {
			tp.test.Errorf("failed to encode access lists: %v", err)
			tp.term()
			return nil
		}
		if err := tp.remote.OnAccessLists(tp, id, rawList); err != nil {
			tp.test.Errorf("delivery rejected: %v", err)
			tp.term()
		}
		return nil
	}
	// fullPeer has all BALs
	fullPeer := newTestPeer("full", t, term)
	fullPeer.accessLists = allBALs

	syncer := setupSyncer(rawdb.HashScheme, partialPeer, fullPeer)

	// Pre-seed capacity so partialPeer gets all 3 hashes
	syncer.rates.Update(partialPeer.id, AccessListsMsg, time.Millisecond, 100)

	// Run the fetch in the background so a hang can be detected. results and
	// fetchErr are only read after done is closed.
	done := make(chan struct{})
	var (
		results  []rlp.RawValue
		fetchErr error
	)
	go func() {
		results, fetchErr = syncer.fetchAccessLists(hashes, cancel)
		close(done)
	}()
	// Wait for fetch to complete
	select {
	case <-done:
	case <-time.After(5 * time.Second):
		// Close cancel so the fetch goroutine is signalled to stop rather
		// than being left running after the test has failed.
		term()
		t.Fatal("fetchAccessLists hung")
	}
	if fetchErr != nil {
		t.Fatalf("fetchAccessLists failed: %v", fetchErr)
	}
	// Without this check an empty or short result slice would make the loop
	// below pass vacuously.
	if len(results) != len(hashes) {
		t.Fatalf("result count mismatch: got %d, want %d", len(results), len(hashes))
	}
	// Every result must be a decodable BAL and must match the expected
	// content for the hash at the same position.
	for i, h := range hashes {
		raw := results[i]
		var accessList bal.BlockAccessList
		if err := rlp.DecodeBytes(raw, &accessList); err != nil {
			t.Errorf("result %d (hash %v) is not a valid BAL: %v (got raw bytes %x)",
				i, h, err, raw)
			continue
		}
		if !bytes.Equal(raw, allBALs[h]) {
			t.Errorf("result %d content mismatch for hash %v", i, h)
		}
	}
}
func newDbConfig(scheme string) *triedb.Config {
if scheme == rawdb.HashScheme {
return &triedb.Config{}