From 184bde8ca035ebb0d69168f86d8f5be92271011f Mon Sep 17 00:00:00 2001 From: jonny rhea <5555162+jrhea@users.noreply.github.com> Date: Wed, 22 Apr 2026 14:42:20 -0500 Subject: [PATCH] eth/protocols/snap, eth/downloader: version SyncProgress and use *types.Header for pivot tracking eth/protocols/snap: remove unnecessary Sync() loop, drop errPivotStale and resetDownloadState eth/protocols/snap: move pivot-reorg detection into Sync(), rename checkDeepReorg to isPivotReorged eth/protocols/snap: don't apply BALs to accounts that haven't been downloaded yet. eth/protocols/snap: drop empty accounts and zero-value storage on BAL apply eth/protocols/snap: wipe flat state on sync reset, consolidate reorg detection eth/protocols/snap: persist Complete=true on sync completion to skip redundant resyncs eth/protocols/snap: persist catchUp progress incrementally to enable resume eth/protocols/snap: verify BALs during fetch to route around bad peers, make catchUp cancelable core/rawdb,eth/protocols/snap: add DeleteSnapshotSyncStatus helper --- core/rawdb/accessors_snapshot.go | 7 + eth/downloader/downloader.go | 52 +-- eth/downloader/statesync.go | 24 +- eth/protocols/snap/bal_apply.go | 46 +- eth/protocols/snap/bal_apply_test.go | 166 +++++++ eth/protocols/snap/progress_test.go | 224 +++++---- eth/protocols/snap/sync.go | 486 +++++++++++++------- eth/protocols/snap/sync_test.go | 663 +++++++++++++++++++++++---- 8 files changed, 1207 insertions(+), 461 deletions(-) diff --git a/core/rawdb/accessors_snapshot.go b/core/rawdb/accessors_snapshot.go index 5cea581fcd..8872b00fc2 100644 --- a/core/rawdb/accessors_snapshot.go +++ b/core/rawdb/accessors_snapshot.go @@ -208,3 +208,10 @@ func WriteSnapshotSyncStatus(db ethdb.KeyValueWriter, status []byte) { log.Crit("Failed to store snapshot sync status", "err", err) } } + +// DeleteSnapshotSyncStatus removes the serialized sync status from the database. +func DeleteSnapshotSyncStatus(db ethdb.KeyValueWriter) { + if err := db.Delete(snapshotSyncStatusKey); err != nil { + log.Crit("Failed to remove snapshot sync status", "err", err) + } +} diff --git a/eth/downloader/downloader.go b/eth/downloader/downloader.go index e291042e53..5e13210152 100644 --- a/eth/downloader/downloader.go +++ b/eth/downloader/downloader.go @@ -867,55 +867,13 @@ func (d *Downloader) importBlockResults(results []*fetchResult) error { return nil } -// checkDeepReorg checks if the old pivot block was reorged by comparing its -// state root against the current canonical chain. Returns true if the -// canonical header at the old pivot's block number has a different state root, -// meaning the syncer's flat state is from the old fork and must be wiped. -// -// Returns false conservatively when canonical data is missing. If the chain -// really did shorten past the old pivot, sync.catchUp's from > to guard will -// catch this. -func checkDeepReorg(db ethdb.Database, oldNumber uint64, oldRoot common.Hash) bool { - // No canonical hash at the old pivot height. This could mean the chain was - // reorged to a shorter fork, or that headers for this height haven't been - // downloaded yet. Can't tell the two apart here, so don't wipe. - oldHash := rawdb.ReadCanonicalHash(db, oldNumber) - if oldHash == (common.Hash{}) { - return false - } - - // Canonical hash exists but the header is missing (pruned or corrupted). - // Nothing to compare against, so don't wipe. - oldHeader := rawdb.ReadHeader(db, oldHash, oldNumber) - if oldHeader == nil { - return false - } - - // Canonical root at this height differs from what we were syncing against — - // the old pivot was reorged out. - return oldHeader.Root != oldRoot -} - -// restartSnapSync cancels the current state sync and starts a new one with the -// given root. Before restarting, it checks for deep reorgs and wipes sync -// progress if the old pivot was reorged. -func (d *Downloader) restartSnapSync(oldSync *stateSync, newRoot common.Hash, newNumber uint64) *stateSync { - if checkDeepReorg(d.stateDB, oldSync.number, oldSync.root) { - log.Warn("Deep reorg detected, restarting snap sync from scratch", - "number", oldSync.number, "oldRoot", oldSync.root) - rawdb.WriteSnapshotSyncStatus(d.stateDB, nil) - } - oldSync.Cancel() - return d.syncState(newRoot, newNumber) -} - // processSnapSyncContent takes fetch results from the queue and writes them to the // database. It also controls the synchronisation of state nodes of the pivot block. func (d *Downloader) processSnapSyncContent() error { // Start syncing state of the reported head block. This should get us most of // the state of the pivot block. d.pivotLock.RLock() - sync := d.syncState(d.pivotHeader.Root, d.pivotHeader.Number.Uint64()) + sync := d.syncState(d.pivotHeader) d.pivotLock.RUnlock() defer func() { @@ -985,8 +943,9 @@ func (d *Downloader) processSnapSyncContent() error { if oldPivot == nil { // no results piling up, we can move the pivot if !d.committed.Load() { // not yet passed the pivot, we can move the pivot - if pivot.Root != sync.root { // pivot position changed, we can move the pivot - sync = d.restartSnapSync(sync, pivot.Root, pivot.Number.Uint64()) + if pivot.Hash() != sync.pivot.Hash() { // pivot position changed, we can move the pivot + sync.Cancel() + sync = d.syncState(pivot) go closeOnErr(sync) } } @@ -1000,7 +959,8 @@ func (d *Downloader) processSnapSyncContent() error { if P != nil { // If new pivot block found, cancel old state retrieval and restart if oldPivot != P { - sync = d.restartSnapSync(sync, P.Header.Root, P.Header.Number.Uint64()) + sync.Cancel() + sync = d.syncState(P.Header) go closeOnErr(sync) oldPivot = P } diff --git a/eth/downloader/statesync.go b/eth/downloader/statesync.go index 873a190af7..7b755ca968 100644 --- a/eth/downloader/statesync.go +++ b/eth/downloader/statesync.go @@ -19,14 +19,14 @@ package downloader import ( "sync" - "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/log" ) -// syncState starts downloading state with the given root hash and block number. -func (d *Downloader) syncState(root common.Hash, number uint64) *stateSync { +// syncState starts downloading state with the given pivot header. +func (d *Downloader) syncState(pivot *types.Header) *stateSync { // Create the state sync - s := newStateSync(d, root, number) + s := newStateSync(d, pivot) select { case d.stateSyncStart <- s: // If we tell the statesync to restart with a new root, we also need @@ -58,7 +58,7 @@ func (d *Downloader) stateFetcher() { // runStateSync runs a state synchronisation until it completes or another root // hash is requested to be switched over to. func (d *Downloader) runStateSync(s *stateSync) *stateSync { - log.Trace("State sync starting", "root", s.root) + log.Trace("State sync starting", "pivot", s.pivot.Hash(), "number", s.pivot.Number) go s.run() defer s.Cancel() @@ -75,11 +75,10 @@ func (d *Downloader) runStateSync(s *stateSync) *stateSync { } // stateSync schedules requests for downloading a particular state trie defined -// by a given state root. +// by a given pivot header. type stateSync struct { - d *Downloader // Downloader instance to access and manage current peerset - root common.Hash // State root currently being synced - number uint64 // Block number of the pivot + d *Downloader // Downloader instance to access and manage current peerset + pivot *types.Header // Pivot header currently being synced started chan struct{} // Started is signalled once the sync loop starts cancel chan struct{} // Channel to signal a termination request @@ -90,11 +89,10 @@ type stateSync struct { // newStateSync creates a new state trie download scheduler. This method does not // yet start the sync. The user needs to call run to initiate. -func newStateSync(d *Downloader, root common.Hash, number uint64) *stateSync { +func newStateSync(d *Downloader, pivot *types.Header) *stateSync { return &stateSync{ d: d, - root: root, - number: number, + pivot: pivot, cancel: make(chan struct{}), done: make(chan struct{}), started: make(chan struct{}), @@ -106,7 +104,7 @@ func newStateSync(d *Downloader, root common.Hash, number uint64) *stateSync { // finish. func (s *stateSync) run() { close(s.started) - s.err = s.d.SnapSyncer.Sync(s.root, s.number, s.cancel) + s.err = s.d.SnapSyncer.Sync(s.pivot, s.cancel) close(s.done) } diff --git a/eth/protocols/snap/bal_apply.go b/eth/protocols/snap/bal_apply.go index 5bb388f9ba..1399fc58c9 100644 --- a/eth/protocols/snap/bal_apply.go +++ b/eth/protocols/snap/bal_apply.go @@ -41,6 +41,18 @@ func verifyAccessList(b *bal.BlockAccessList, header *types.Header) error { return nil } +// isFetched tell us if accountHash has been downloaded. +func (s *Syncer) isFetched(accountHash common.Hash) bool { + s.lock.RLock() + defer s.lock.RUnlock() + for _, task := range s.tasks { + if bytes.Compare(accountHash[:], task.Last[:]) <= 0 { + return bytes.Compare(accountHash[:], task.Next[:]) < 0 + } + } + return true +} + // applyAccessList applies a single block's access list diffs to the flat state // in the database. For each account, it applies the post-block values (highest // TxIdx entry) for balance, nonce, code, and storage. The storageRoot field is @@ -53,6 +65,11 @@ func (s *Syncer) applyAccessList(b *bal.BlockAccessList) error { addr := common.Address(access.Address) accountHash := crypto.Keccak256Hash(addr[:]) + // Skip accounts whose hash range hasn't been downloaded yet. + if !s.isFetched(accountHash) { + continue + } + // Read the existing account from flat state (may not exist yet) var ( account types.StateAccount @@ -95,22 +112,35 @@ func (s *Syncer) applyAccessList(b *bal.BlockAccessList) error { } } - // Apply storage writes (last entry per slot = post-block state) + // Apply storage writes (last entry per slot = post-block state). for _, slotWrites := range access.StorageWrites { if n := len(slotWrites.Accesses); n > 0 { value := slotWrites.Accesses[n-1].ValueAfter storageHash := crypto.Keccak256Hash(slotWrites.Slot[:]) - rawdb.WriteStorageSnapshot(batch, accountHash, storageHash, value[:]) + if value == (common.Hash{}) { + rawdb.DeleteStorageSnapshot(batch, accountHash, storageHash) + } else { + rawdb.WriteStorageSnapshot(batch, accountHash, storageHash, value[:]) + } } } // Don't create empty accounts in flat state (EIP-161). - // This handles the case where an account is created and - // self-destructed in the same transaction. The BAL will - // include it with a balance change to zero, but the account - // should not exist in state. - if isNew && account.Balance.IsZero() && account.Nonce == 0 && - bytes.Equal(account.CodeHash, types.EmptyCodeHash[:]) { + isEmpty := account.Balance.IsZero() && account.Nonce == 0 && + bytes.Equal(account.CodeHash, types.EmptyCodeHash[:]) + switch { + case isEmpty && isNew: + // This handles the case where an account is created and + // self-destructed in the same transaction. The BAL will + // include it with a balance change to zero, but the account + // should not exist in state. + continue + case isEmpty && !isNew: + // Existing account got fully drained (e.g., pre-funded + // address that gets deployed to with init code that + // self-destructs). Delete the entry so the trie rebuild + // doesn't pick it up as an empty leaf. + rawdb.DeleteAccountSnapshot(batch, accountHash) continue } diff --git a/eth/protocols/snap/bal_apply_test.go b/eth/protocols/snap/bal_apply_test.go index acb6d35a14..c22df170c4 100644 --- a/eth/protocols/snap/bal_apply_test.go +++ b/eth/protocols/snap/bal_apply_test.go @@ -201,6 +201,41 @@ func TestAccessListApplicationMultiTx(t *testing.T) { } } +// TestAccessListApplicationZeroStorage verifies that a BAL slot write with a +// zero post-value deletes the snapshot entry instead of writing 32 zero +// bytes. +func TestAccessListApplicationZeroStorage(t *testing.T) { + t.Parallel() + db := rawdb.NewMemoryDatabase() + syncer := NewSyncer(db, rawdb.HashScheme) + addr := common.HexToAddress("0x06") + accountHash := crypto.Keccak256Hash(addr[:]) + + // Existing account with a non-zero storage slot. + original := types.StateAccount{ + Nonce: 1, + Balance: uint256.NewInt(1), + Root: types.EmptyRootHash, + CodeHash: types.EmptyCodeHash[:], + } + rawdb.WriteAccountSnapshot(db, accountHash, types.SlimAccountRLP(original)) + rawSlot := common.HexToHash("0xaa") + slotHash := crypto.Keccak256Hash(rawSlot[:]) + rawdb.WriteStorageSnapshot(db, accountHash, slotHash, common.HexToHash("0x42").Bytes()) + + // BAL writes the slot to zero (deletion). + cb := bal.NewConstructionBlockAccessList() + cb.StorageWrite(0, addr, rawSlot, common.Hash{}) + b := buildTestBAL(t, &cb) + if err := syncer.applyAccessList(b); err != nil { + t.Fatalf("applyAccessList failed: %v", err) + } + + if val := rawdb.ReadStorageSnapshot(db, accountHash, slotHash); len(val) != 0 { + t.Errorf("zeroed slot should have been deleted, got %x", val) + } +} + // TestAccessListApplicationNewAccount verifies that applyAccessList creates // new accounts that don't exist in the DB yet. func TestAccessListApplicationNewAccount(t *testing.T) { @@ -255,6 +290,100 @@ func TestAccessListApplicationNewAccount(t *testing.T) { } } +// TestAccessListApplicationSkipsUnfetched verifies that applyAccessList does +// not write account entries for addresses whose hash falls in a range that +// hasn't been downloaded yet. +func TestAccessListApplicationSkipsUnfetched(t *testing.T) { + t.Parallel() + db := rawdb.NewMemoryDatabase() + syncer := NewSyncer(db, rawdb.HashScheme) + + // Pick two addresses and order them by hash. + addrA := common.HexToAddress("0x01") + addrB := common.HexToAddress("0x02") + hashA := crypto.Keccak256Hash(addrA[:]) + hashB := crypto.Keccak256Hash(addrB[:]) + fetchedAddr, fetchedHash := addrA, hashA + unfetchedAddr, unfetchedHash := addrB, hashB + if bytes.Compare(hashA[:], hashB[:]) > 0 { + fetchedAddr, fetchedHash = addrB, hashB + unfetchedAddr, unfetchedHash = addrA, hashA + } + + // One remaining task covering [unfetchedHash, MaxHash]: the fetched hash + // is below Next so isFetched returns true; the unfetched hash equals Next + // so isFetched returns false. + syncer.tasks = []*accountTask{{ + Next: unfetchedHash, + Last: common.MaxHash, + SubTasks: make(map[common.Hash][]*storageTask), + stateCompleted: make(map[common.Hash]struct{}), + }} + + cb := bal.NewConstructionBlockAccessList() + cb.BalanceChange(0, fetchedAddr, uint256.NewInt(100)) + cb.BalanceChange(0, unfetchedAddr, uint256.NewInt(200)) + b := buildTestBAL(t, &cb) + + if err := syncer.applyAccessList(b); err != nil { + t.Fatalf("applyAccessList failed: %v", err) + } + + // The fetched account should have been written. + if data := rawdb.ReadAccountSnapshot(db, fetchedHash); len(data) == 0 { + t.Error("expected fetched account to be written") + } + // The unfetched account should not have been touched. + if data := rawdb.ReadAccountSnapshot(db, unfetchedHash); len(data) != 0 { + t.Errorf("unfetched account should not be written, got %x", data) + } +} + +// TestAccessListApplicationSkipsUnfetchedStorage verifies that storage writes +// are also skipped when the parent account's hash range isn't downloaded yet. +func TestAccessListApplicationSkipsUnfetchedStorage(t *testing.T) { + t.Parallel() + db := rawdb.NewMemoryDatabase() + syncer := NewSyncer(db, rawdb.HashScheme) + + addrA := common.HexToAddress("0x01") + addrB := common.HexToAddress("0x02") + hashA := crypto.Keccak256Hash(addrA[:]) + hashB := crypto.Keccak256Hash(addrB[:]) + + unfetchedAddr, unfetchedHash := addrB, hashB + if bytes.Compare(hashA[:], hashB[:]) > 0 { + unfetchedAddr, unfetchedHash = addrA, hashA + } + + syncer.tasks = []*accountTask{{ + Next: unfetchedHash, + Last: common.MaxHash, + SubTasks: make(map[common.Hash][]*storageTask), + stateCompleted: make(map[common.Hash]struct{}), + }} + + // BAL touches an unfetched account with a storage write AND an empty + // balance mutation. Neither should result in any flat-state writes. + rawSlot := common.HexToHash("0xaa") + slotHash := crypto.Keccak256Hash(rawSlot[:]) + cb := bal.NewConstructionBlockAccessList() + cb.BalanceChange(0, unfetchedAddr, uint256.NewInt(0)) // empty mutation + cb.StorageWrite(0, unfetchedAddr, rawSlot, common.HexToHash("0xff")) + b := buildTestBAL(t, &cb) + + if err := syncer.applyAccessList(b); err != nil { + t.Fatalf("applyAccessList failed: %v", err) + } + + if data := rawdb.ReadAccountSnapshot(db, unfetchedHash); len(data) != 0 { + t.Errorf("unfetched account should not be written, got %x", data) + } + if val := rawdb.ReadStorageSnapshot(db, unfetchedHash, slotHash); len(val) != 0 { + t.Errorf("storage for unfetched account should not be written, got %x", val) + } +} + // TestAccessListApplicationSameTxCreateDestroy tests the edge case where an // account is created and self-destructed in the same transaction during the // pivot gap. Per EIP-7928, such accounts appear in the BAL with a balance @@ -297,3 +426,40 @@ func TestAccessListApplicationSameTxCreateDestroy(t *testing.T) { account.Balance, account.Nonce, account.CodeHash, account.Root) } } + +// TestAccessListApplicationDestroyExisting verifies that when a BAL reduces +// an existing flat-state account to nonce=0, balance=0, empty code (the +// pre-funded destruction pattern), applyAccessList deletes the entry rather +// than leaving it zereod. +func TestAccessListApplicationDestroyExisting(t *testing.T) { + t.Parallel() + db := rawdb.NewMemoryDatabase() + syncer := NewSyncer(db, rawdb.HashScheme) + addr := common.HexToAddress("0x05") + accountHash := crypto.Keccak256Hash(addr[:]) + + // Pre-funded account: has balance, no nonce, no code. + original := types.StateAccount{ + Nonce: 0, + Balance: uint256.NewInt(1000), + Root: types.EmptyRootHash, + CodeHash: types.EmptyCodeHash[:], + } + rawdb.WriteAccountSnapshot(db, accountHash, types.SlimAccountRLP(original)) + + // The BAL zeros the balance. Nonce and code were already empty, so + // the account ends up fully empty after applying. + cb := bal.NewConstructionBlockAccessList() + cb.BalanceChange(0, addr, uint256.NewInt(0)) + b := buildTestBAL(t, &cb) + if err := syncer.applyAccessList(b); err != nil { + t.Fatalf("applyAccessList failed: %v", err) + } + + if data := rawdb.ReadAccountSnapshot(db, accountHash); len(data) != 0 { + account, _ := types.FullAccount(data) + t.Errorf("destroyed account should have been deleted from flat state, "+ + "got balance=%v, nonce=%d, codeHash=%x", + account.Balance, account.Nonce, account.CodeHash) + } +} diff --git a/eth/protocols/snap/progress_test.go b/eth/protocols/snap/progress_test.go index 1d9a6b8474..fef90c3dda 100644 --- a/eth/protocols/snap/progress_test.go +++ b/eth/protocols/snap/progress_test.go @@ -18,137 +18,123 @@ package snap import ( "encoding/json" + "math/big" "testing" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/types" ) -// Legacy sync progress definitions -type legacyStorageTask struct { - Next common.Hash // Next account to sync in this interval - Last common.Hash // Last account to sync in this interval -} +// TestSyncProgressV1Discarded verifies that a persisted blob written in the +// old unversioned format (raw JSON, no version prefix) is detected and +// discarded on load, that the syncer falls through to a fresh start, and +// that any orphan flat-state entries from the prior format are wiped. +func TestSyncProgressV1Discarded(t *testing.T) { + db := rawdb.NewMemoryDatabase() -type legacyAccountTask struct { - Next common.Hash // Next account to sync in this interval - Last common.Hash // Last account to sync in this interval - SubTasks map[common.Hash][]*legacyStorageTask // Storage intervals needing fetching for large contracts -} - -type legacyProgress struct { - Tasks []*legacyAccountTask // The suspended account tasks (contract tasks within) -} - -func compareProgress(a legacyProgress, b SyncProgress) bool { - if len(a.Tasks) != len(b.Tasks) { - return false + // Write a raw JSON blob (no version byte) to simulate progress persisted + // by a prior geth binary (snap/1 format). + legacy := map[string]any{ + "Root": common.HexToHash("0xaaaa"), + "BlockNumber": uint64(42), + "Tasks": []any{}, } - for i := 0; i < len(a.Tasks); i++ { - if a.Tasks[i].Next != b.Tasks[i].Next { - return false - } - if a.Tasks[i].Last != b.Tasks[i].Last { - return false - } - // new fields are not checked here - - if len(a.Tasks[i].SubTasks) != len(b.Tasks[i].SubTasks) { - return false - } - for addrHash, subTasksA := range a.Tasks[i].SubTasks { - subTasksB, ok := b.Tasks[i].SubTasks[addrHash] - if !ok || len(subTasksB) != len(subTasksA) { - return false - } - for j := 0; j < len(subTasksA); j++ { - if subTasksA[j].Next != subTasksB[j].Next { - return false - } - if subTasksA[j].Last != subTasksB[j].Last { - return false - } - } - } - } - return true -} - -func makeLegacyProgress() legacyProgress { - return legacyProgress{ - Tasks: []*legacyAccountTask{ - { - Next: common.Hash{}, - Last: common.Hash{0x77}, - SubTasks: map[common.Hash][]*legacyStorageTask{ - {0x1}: { - { - Next: common.Hash{}, - Last: common.Hash{0xff}, - }, - }, - }, - }, - { - Next: common.Hash{0x88}, - Last: common.Hash{0xff}, - }, - }, - } -} - -func convertLegacy(legacy legacyProgress) SyncProgress { - var progress SyncProgress - for i, task := range legacy.Tasks { - subTasks := make(map[common.Hash][]*storageTask) - for owner, list := range task.SubTasks { - var cpy []*storageTask - for i := 0; i < len(list); i++ { - cpy = append(cpy, &storageTask{ - Next: list[i].Next, - Last: list[i].Last, - }) - } - subTasks[owner] = cpy - } - accountTask := &accountTask{ - Next: task.Next, - Last: task.Last, - SubTasks: subTasks, - } - if i == 0 { - accountTask.StorageCompleted = []common.Hash{{0xaa}, {0xbb}} // fulfill new fields - } - progress.Tasks = append(progress.Tasks, accountTask) - } - return progress -} - -func TestSyncProgressCompatibility(t *testing.T) { - // Decode serialized bytes of legacy progress, backward compatibility - legacy := makeLegacyProgress() blob, err := json.Marshal(legacy) if err != nil { - t.Fatalf("Failed to marshal progress %v", err) - } - var dec SyncProgress - if err := json.Unmarshal(blob, &dec); err != nil { - t.Fatalf("Failed to unmarshal progress %v", err) - } - if !compareProgress(legacy, dec) { - t.Fatal("sync progress is not backward compatible") + t.Fatalf("marshal legacy: %v", err) } + rawdb.WriteSnapshotSyncStatus(db, blob) - // Decode serialized bytes of new format progress - progress := convertLegacy(legacy) - blob, err = json.Marshal(progress) - if err != nil { - t.Fatalf("Failed to marshal progress %v", err) + // Pre-write orphan flat-state entries that should be wiped on fresh start. + orphanAccountHash := common.HexToHash("0xdeadbeef") + rawdb.WriteAccountSnapshot(db, orphanAccountHash, []byte{0xde, 0xad}) + orphanStorageAccount := common.HexToHash("0xfeedface") + orphanStorageSlot := common.HexToHash("0xabcd") + rawdb.WriteStorageSnapshot(db, orphanStorageAccount, orphanStorageSlot, []byte{0xff, 0xff}) + + syncer := NewSyncer(db, rawdb.HashScheme) + syncer.loadSyncStatus() + + if syncer.previousPivot != nil { + t.Fatalf("expected previousPivot nil after discarding old format, got %+v", syncer.previousPivot) } - var legacyDec legacyProgress - if err := json.Unmarshal(blob, &legacyDec); err != nil { - t.Fatalf("Failed to unmarshal progress %v", err) + if len(syncer.tasks) != accountConcurrency { + t.Fatalf("expected fresh task split of %d, got %d", accountConcurrency, len(syncer.tasks)) } - if !compareProgress(legacyDec, progress) { - t.Fatal("sync progress is not forward compatible") + if data := rawdb.ReadAccountSnapshot(db, orphanAccountHash); len(data) != 0 { + t.Errorf("orphan account snapshot should be wiped, got %x", data) + } + if val := rawdb.ReadStorageSnapshot(db, orphanStorageAccount, orphanStorageSlot); len(val) != 0 { + t.Errorf("orphan storage snapshot should be wiped, got %x", val) + } +} + +// TestSyncProgressV2RoundTrip verifies that the persisted blob is framed +// with the expected version byte at offset 0, and that all six status +// counters survive the round-trip. +func TestSyncProgressV2RoundTrip(t *testing.T) { + db := rawdb.NewMemoryDatabase() + + saver := NewSyncer(db, rawdb.HashScheme) + saver.pivot = &types.Header{Number: new(big.Int).SetUint64(123), Difficulty: common.Big0} + saver.accountSynced = 1 + saver.accountBytes = 2 + saver.bytecodeSynced = 3 + saver.bytecodeBytes = 4 + saver.storageSynced = 5 + saver.storageBytes = 6 + saver.saveSyncStatus() + + raw := rawdb.ReadSnapshotSyncStatus(db) + if len(raw) == 0 || raw[0] != syncProgressVersion { + t.Fatalf("expected version byte %d at offset 0, got blob %x", syncProgressVersion, raw) + } + + loader := NewSyncer(db, rawdb.HashScheme) + loader.loadSyncStatus() + for _, c := range []struct { + name string + got uint64 + want uint64 + }{ + {"accountSynced", loader.accountSynced, 1}, + {"accountBytes", uint64(loader.accountBytes), 2}, + {"bytecodeSynced", loader.bytecodeSynced, 3}, + {"bytecodeBytes", uint64(loader.bytecodeBytes), 4}, + {"storageSynced", loader.storageSynced, 5}, + {"storageBytes", uint64(loader.storageBytes), 6}, + } { + if c.got != c.want { + t.Errorf("%s mismatch: got %d, want %d", c.name, c.got, c.want) + } + } +} + +// TestSyncProgressCorruptPayload verifies that a persisted blob with the +// correct version byte but unparseable JSON body is discarded, triggers a +// fresh-start fall-through (not a panic or a stale-state load), and the +// orphan flat state is wiped along with the corrupt status. +func TestSyncProgressCorruptPayload(t *testing.T) { + db := rawdb.NewMemoryDatabase() + + // Version byte followed by garbage that isn't valid JSON. + rawdb.WriteSnapshotSyncStatus(db, []byte{syncProgressVersion, 0x7b, 0x7b, 0x7b}) + + // Pre-write orphan flat-state entries that should be wiped on fresh start. + orphanAccountHash := common.HexToHash("0xdeadbeef") + rawdb.WriteAccountSnapshot(db, orphanAccountHash, []byte{0xde, 0xad}) + + syncer := NewSyncer(db, rawdb.HashScheme) + syncer.loadSyncStatus() + + if syncer.previousPivot != nil { + t.Fatalf("expected previousPivot nil after corrupt payload, got %+v", syncer.previousPivot) + } + if len(syncer.tasks) != accountConcurrency { + t.Fatalf("expected fresh task split of %d, got %d", accountConcurrency, len(syncer.tasks)) + } + if data := rawdb.ReadAccountSnapshot(db, orphanAccountHash); len(data) != 0 { + t.Errorf("orphan account snapshot should be wiped, got %x", data) } } diff --git a/eth/protocols/snap/sync.go b/eth/protocols/snap/sync.go index 3c33709fca..ffd857d105 100644 --- a/eth/protocols/snap/sync.go +++ b/eth/protocols/snap/sync.go @@ -64,7 +64,7 @@ const ( // come close to that, requesting 4x should be a good approximation. maxCodeRequestCount = maxRequestSize / (24 * 1024) * 4 - // maxAccessListRequestCount is the maximum number of block access lists to + // maxAccessListRequestCount is the maximum number of block BALs to // request in a single query. BALs average ~72 KiB compressed (per EIP-7928), // and EIP-8189 recommends a 2 MiB response soft limit, so we target ~28 // blocks per request to avoid server-side truncation. @@ -73,6 +73,11 @@ const ( // to avoid server-side truncation and re-requesting. It is currently based on // the assumption that the gas limit is 60M. maxAccessListRequestCount = 28 + + // syncProgressVersion is the version byte prepended to the JSON-encoded + // SyncProgress when persisted. On load, a mismatching version byte causes + // the persisted progress to be discarded and sync to start fresh. + syncProgressVersion byte = 2 ) var ( @@ -89,6 +94,11 @@ var ( // terminated. var ErrCancelled = errors.New("sync cancelled") +// errAccessListPeersExhausted is returned from fetchAccessLists when every +// connected peer has been marked stateless for BAL requests and there +// are still hashes left to fetch. +var errAccessListPeersExhausted = errors.New("all peers exhausted for BAL requests") + // accountRequest tracks a pending account range request to ensure responses are // to actual requests and to validate any security constraints. // @@ -292,9 +302,9 @@ type storageTask struct { // sync. Opposed to full and fast sync, there is no way to restart a suspended // snap sync without prior knowledge of the suspension point. type SyncProgress struct { - Root common.Hash // State root being synced (for pivot move detection) - BlockNumber uint64 // Block number of the pivot - Tasks []*accountTask // The suspended account tasks (contract tasks within) + Pivot *types.Header // Pivot header being synced (for pivot move and reorg detection) + Tasks []*accountTask // The suspended account tasks (contract tasks within) + Complete bool // True once sync ran to completion for Pivot // Status report during syncing phase AccountSynced uint64 // Number of accounts downloaded @@ -303,7 +313,6 @@ type SyncProgress struct { BytecodeBytes common.StorageSize // Number of bytecode bytes downloaded StorageSynced uint64 // Number of storage slots downloaded StorageBytes common.StorageSize // Number of storage trie bytes persisted to disk - } // SyncPeer abstracts out the methods required for a peer to be synced against @@ -347,12 +356,11 @@ type Syncer struct { db ethdb.Database // Database to store the trie nodes into (and dedup) scheme string // Node scheme used in node database - root common.Hash // Current state trie root being synced - number uint64 // Block number of the current pivot - previousRoot common.Hash // Root from previous sync run (for pivot move detection) - previousNumber uint64 // Block number of the previous pivot - tasks []*accountTask // Current account task set being synced - update chan struct{} // Notification channel for possible sync progression + pivot *types.Header // Current pivot header being synced + previousPivot *types.Header // Pivot from previous sync run (for pivot move detection) + complete bool // Whether the persisted progress was a completed sync + tasks []*accountTask // Current account task set being synced + update chan struct{} // Notification channel for possible sync progression peers map[string]SyncPeer // Currently active peers to download from peerJoin *event.Feed // Event feed to react to peers joining @@ -364,12 +372,12 @@ type Syncer struct { accountIdlers map[string]struct{} // Peers that aren't serving account requests bytecodeIdlers map[string]struct{} // Peers that aren't serving bytecode requests storageIdlers map[string]struct{} // Peers that aren't serving storage requests - accessListIdlers map[string]struct{} // Peers that aren't serving access list requests + accessListIdlers map[string]struct{} // Peers that aren't serving BAL requests accountReqs map[uint64]*accountRequest // Account requests currently running bytecodeReqs map[uint64]*bytecodeRequest // Bytecode requests currently running storageReqs map[uint64]*storageRequest // Storage requests currently running - accessListReqs map[uint64]*accessListRequest // Access list requests currently running + accessListReqs map[uint64]*accessListRequest // BAL requests currently running accountSynced uint64 // Number of accounts downloaded accountBytes common.StorageSize // Number of account trie bytes persisted to disk @@ -384,7 +392,7 @@ type Syncer struct { logTime time.Time // Time instance when status was last reported pend sync.WaitGroup // Tracks network request goroutines for graceful shutdown - lock sync.RWMutex // Protects fields that can change outside of sync (peers, reqs, root) + lock sync.RWMutex // Protects fields that can change outside of sync (peers, reqs, pivot) } // NewSyncer creates a new snapshot syncer to download the Ethereum state over the @@ -469,44 +477,47 @@ func (s *Syncer) Unregister(id string) error { return nil } -// errPivotStale is returned from download when the pivot has become stale -// and the syncer needs to perform access list catch-up before continuing. -var errPivotStale = errors.New("pivot stale") - // Sync starts (or resumes a previous) sync cycle to iterate over a state trie -// with the given root and reconstruct the nodes based on the snapshot leaves. -// The number parameter is the block number of the pivot block. -func (s *Syncer) Sync(root common.Hash, number uint64, cancel chan struct{}) error { +// with the given pivot header and reconstruct the nodes based on the snapshot +// leaves. +func (s *Syncer) Sync(pivot *types.Header, cancel chan struct{}) error { + if pivot == nil { + return errors.New("snap sync: pivot header is nil") + } s.lock.Lock() - s.root = root - s.number = number - s.previousRoot = root // Default: no pivot move. loadSyncStatus may overwrite. - s.previousNumber = number + s.pivot = pivot + s.previousPivot = nil // loadSyncStatus overwrites when resuming from persisted progress s.statelessPeers = make(map[string]struct{}) s.lock.Unlock() if s.startTime.IsZero() { s.startTime = time.Now() } + root := pivot.Root // Retrieve the previous sync status from DB. If there's no persisted // status, sync is either fresh or already complete. s.loadSyncStatus() - var syncComplete bool - defer func() { - if !syncComplete { - for _, task := range s.tasks { - s.forwardAccountTask(task) - } - s.cleanAccountTasks() - s.saveSyncStatus() - } - }() - log.Debug("Starting snapshot sync cycle", "root", root) - defer s.report(true) + // isPivotChanged is true when we have prior progress against a different + // pivot. That means we need to roll forward via catchUp, or wipe and + // restart if the prior pivot was reorged out. + isPivotChanged := s.previousPivot != nil && s.previousPivot.Hash() != s.pivot.Hash() + + // Skip if we've already finished syncing this pivot. + if !isPivotChanged && s.complete { + log.Info("Snap sync already complete for this pivot", "root", root) + return nil + } + + // We're committing to running this sync. Clear the complete flag so a + // mid-run save (on cancel or error) doesn't persist a stale Complete=true + // status from a prior pivot. + s.lock.Lock() + s.complete = false + s.lock.Unlock() - // Whether sync completed or not, disregard any future packets defer func() { + // Whether sync completed or not, disregard any future packets log.Debug("Terminating snapshot sync cycle", "root", root) s.lock.Lock() s.accountReqs = make(map[uint64]*accountRequest) @@ -514,57 +525,68 @@ func (s *Syncer) Sync(root common.Hash, number uint64, cancel chan struct{}) err s.bytecodeReqs = make(map[uint64]*bytecodeRequest) s.accessListReqs = make(map[uint64]*accessListRequest) s.lock.Unlock() + + // Persist final task state. + for _, task := range s.tasks { + s.forwardAccountTask(task) + } + s.cleanAccountTasks() + s.saveSyncStatus() + + // Log final progress. + s.report(true) }() - // Sync loop - log.Info("Starting state download", "root", root) - for { - // Download: fetch all required state data - err := s.downloadState(cancel) - if err == errPivotStale { - // Pivot moved: catch up to new pivot - if err := s.catchUp(cancel); err != nil { - return err - } - s.resetDownloadState(root, number) - log.Info("Resuming state download", "root", root) - continue - } + log.Debug("Starting snapshot sync cycle", "root", root) - // Download error that isn't a stale pivot. This is typically due to - // the downloader cancelling the sync because the pivot moved. This - // error propagates to the downloader which will restart the sync with - // a new root. - if err != nil { + // If we resumed against a different pivot, decide whether the persisted + // progress is still usable. If yes, roll forward via BAL catch-up. If not, + // wipe everything and restart fresh. + if isPivotChanged { + if isPivotReorged(s.db, s.previousPivot, s.pivot) { + log.Warn("Persisted progress unusable, restarting snap sync from scratch", + "number", s.previousPivot.Number, "oldHash", s.previousPivot.Hash()) + s.resetSyncState() + } else if err := s.catchUp(cancel); err != nil { return err } - log.Info("State download complete", "root", root) - - // Trie rebuild: build all tries from flat state and verify root - log.Info("Starting trie rebuild", "root", root) - if err := triedb.GenerateTrie(s.db, s.scheme, root); err != nil { - return err - } - log.Info("Trie rebuild complete", "root", root) - - // Sync complete: clear persisted status so we don't re-run. - // Set syncComplete to prevent the deferred saveSyncStatus from - // overwriting the nil. - syncComplete = true - rawdb.WriteSnapshotSyncStatus(s.db, nil) - return nil } + + // Pin previousPivot to the current pivot before downloadState runs. + // This is what saveSyncStatus persists. If the download is interrupted + // and the next Sync gets a different pivot, this is how isPivotReorged + // recognizes the partial flat state belongs to the old pivot. Without + // it, isPivotReorged sees nil, skips the reorg branch, and downloadState + // would resume from the persisted task markers but mix the old pivot's + // already-downloaded accounts with the new pivot's data. + s.lock.Lock() + s.previousPivot = s.pivot + s.lock.Unlock() + + log.Info("Starting state download", "root", root) + if err := s.downloadState(cancel); err != nil { + return err + } + log.Info("State download complete", "root", root) + + log.Info("Starting trie generation", "root", root) + if err := triedb.GenerateTrie(s.db, s.scheme, root); err != nil { + return err + } + log.Info("Trie generation complete", "root", root) + + // Mark sync complete. The deferred saveSyncStatus persists this with + // Complete=true so a follow-up Sync call for the same pivot can skip + // the work entirely. + s.lock.Lock() + s.complete = true + s.lock.Unlock() + return nil } // download runs the bulk flat-state download. It fetches // account ranges, storage slots, and bytecodes, writing flat state to disk. func (s *Syncer) downloadState(cancel chan struct{}) error { - // If the pivot moved since the last run (downloader cancelled and restarted - // us with a new root), signal catch-up before downloading. - if s.previousRoot != s.root { - return errPivotStale - } - // Subscribe to peer events peerJoin := make(chan string, 16) peerJoinSub := s.peerJoin.Subscribe(peerJoin) @@ -638,95 +660,105 @@ func (s *Syncer) downloadState(cancel chan struct{}) error { } } -// resetDownloadState resets the download state for a new pivot after catch-up. -// It regenerates the task list for accounts not yet downloaded, clears -// in-flight requests, and updates the root. -func (s *Syncer) resetDownloadState(root common.Hash, number uint64) { - s.lock.Lock() - s.root = root - s.number = number - s.previousRoot = root // Prevent downloadState() from returning errPivotStale again - s.previousNumber = number +// isPivotReorged reports whether the previous pivot is no longer usable +// as a starting point for forward catch-up. Either it was reorged out +// of the canonical chain, or the new pivot doesn't advance past it. +func isPivotReorged(db ethdb.Database, prev, curr *types.Header) bool { + // If the new pivot is at or below the old one, there's nothing for + // catchUp to roll forward. + if curr.Number.Cmp(prev.Number) <= 0 { + return true + } + // If there's no canonical hash at the old pivot's height, something + // is wrong. Headers up to the new pivot should already be indexed, + // so a missing entry at an earlier block means the chain state is + // broken. The most common cause is a chain rewind across the + // snap-synced pivot, which resets head to genesis and deletes + // canonical entries above it (see rewindPathHead in core/blockchain.go). + // Bail and let the fresh sync recover. + canonical := rawdb.ReadCanonicalHash(db, prev.Number.Uint64()) + if canonical == (common.Hash{}) { + return true + } - // Clear stateless peers bc they may be able to serve the new pivot - s.statelessPeers = make(map[string]struct{}) - s.lock.Unlock() + // If canonical at the old pivot's height has a different hash, the + // old pivot was reorged out. + return canonical != prev.Hash() } -// catchUp runs the BAL catch-up. When the pivot has moved (previousRoot != -// root), it fetches BALs for the gap blocks, verifies them against -// block headers, and applies the diffs to roll flat state forward. +// catchUp runs the BAL catch-up. When the pivot has moved, it fetches BALs +// for the gap blocks, verifies them against block headers, and applies the +// diffs to roll flat state forward. func (s *Syncer) catchUp(cancel chan struct{}) error { s.lock.RLock() - from := s.previousNumber + 1 - to := s.number + from := s.previousPivot.Number.Uint64() + 1 + to := s.pivot.Number.Uint64() s.lock.RUnlock() + log.Info("Starting BAL catch-up", "from", from, "to", to, "blocks", to-from+1) - // The new pivot must be ahead of the old one. The range is inverted if - // a reorg replaced the block at the pivot height (same number, different - // root) or if the chain shortened past the old pivot. In either case, - // catch-up can't roll forward — wipe progress and restart. This also - // catches reorgs missed by checkDeepReorg, which only runs when the - // downloader actively restarts the syncer, not on resume from persisted - // progress. - if from > to { - log.Warn("Catch-up range inverted, wiping sync progress", "from", from, "to", to) - rawdb.WriteSnapshotSyncStatus(s.db, nil) - return fmt.Errorf("catch-up range inverted (from %d > to %d): pivot reorged", from, to) - } - log.Info("Starting access list catch-up", "from", from, "to", to, "blocks", to-from+1) - - // Collect block hashes for the gap range + // Collect block hashes and headers for the gap range. hashes := make([]common.Hash, 0, to-from+1) + headers := make(map[common.Hash]*types.Header, to-from+1) for num := from; num <= to; num++ { hash := rawdb.ReadCanonicalHash(s.db, num) if hash == (common.Hash{}) { return fmt.Errorf("missing canonical hash for block %d during catch-up", num) } - hashes = append(hashes, hash) - } - - // Fetch BALs from peers - rawBALs, err := s.fetchAccessLists(hashes, cancel) - if err != nil { - return err - } - - // Verify and apply each BAL in block order - for i, raw := range rawBALs { - num := from + uint64(i) - hash := hashes[i] - - // Decode the raw RLP into a BlockAccessList - var bal bal.BlockAccessList - if err := rlp.DecodeBytes(raw, &bal); err != nil { - return fmt.Errorf("failed to decode BAL for block %d: %v", num, err) - } - - // Verify against the block header header := rawdb.ReadHeader(s.db, hash, num) if header == nil { return fmt.Errorf("missing header for block %d (hash %v) during catch-up", num, hash) } - if err := verifyAccessList(&bal, header); err != nil { - return fmt.Errorf("BAL verification failed for block %d: %v", num, err) + hashes = append(hashes, hash) + headers[hash] = header + } + + // Fetch BALs from peers + rawBALs, err := s.fetchAccessLists(hashes, headers, cancel) + if err != nil { + return err + } + + // Apply each BAL in block order. BALs are already verified by fetchAccessLists. + for i, raw := range rawBALs { + select { + case <-cancel: + return ErrCancelled + default: + } + num := from + uint64(i) + hash := hashes[i] + + // Decode the raw RLP into a BAL. + var b bal.BlockAccessList + if err := rlp.DecodeBytes(raw, &b); err != nil { + return fmt.Errorf("failed to decode BAL for block %d: %v", num, err) } - // Apply the state diffs - if err := s.applyAccessList(&bal); err != nil { + // applyAccessList failures are persistent. If a block's apply fails + // here, the next Sync will resume from this block and hit the same + // failure. Auto-recovery isn't implemented yet. + if err := s.applyAccessList(&b); err != nil { return fmt.Errorf("BAL application failed for block %d: %v", num, err) } + + // Persist incremental progress so a crash mid-catchUp can resume + // from the next unapplied block. + s.lock.Lock() + s.previousPivot = headers[hash] + s.lock.Unlock() + s.saveSyncStatus() } - log.Info("Access list catch-up complete", "blocks", len(rawBALs)) + log.Info("BAL catch-up complete", "blocks", len(rawBALs)) return nil } // fetchAccessLists fetches BALs for the given block hashes from // remote peers. It runs its own event loop to assign requests -// to idle peers and process responses asynchronously. Results are returned in -// the same order as the input hashes. -func (s *Syncer) fetchAccessLists(hashes []common.Hash, cancel chan struct{}) ([]rlp.RawValue, error) { - log.Debug("Fetching access lists for catch-up", "blocks", len(hashes)) +// to idle peers and process responses asynchronously. Each BAL is verified +// against its header before being accepted. Results are returned in the +// same order as the input hashes. +func (s *Syncer) fetchAccessLists(hashes []common.Hash, headers map[common.Hash]*types.Header, cancel chan struct{}) ([]rlp.RawValue, error) { + log.Debug("Fetching BALs for catch-up", "blocks", len(hashes)) // Subscribe to peer events peerJoin := make(chan string, 16) @@ -745,11 +777,29 @@ func (s *Syncer) fetchAccessLists(hashes []common.Hash, cancel chan struct{}) ([ var ( accessListReqFails = make(chan *accessListRequest) accessListResps = make(chan *accessListResponse) + lastStallLog = time.Now() ) for len(fetched) < len(hashes) { - // Assign access list retrieval tasks to idle peers + // Assign BAL retrieval tasks to idle peers s.assignAccessListTasks(pending, accessListResps, accessListReqFails, cancel) + // If every peer is now stateless and nothing is in flight, no event + // short of cancel or a new peer joining can move us forward. Surface + // this so the caller can return and let a higher-level retry happen + // against a fresh peer set. + if s.accessListPeersExhausted() { + log.Warn("BAL peers exhausted, stopping catch-up early", + "fetched", len(fetched), "remaining", len(pending)) + return nil, errAccessListPeersExhausted + } + + // Periodic visibility while stalled with peers connected but idle. + if len(pending) > 0 && time.Since(lastStallLog) > 30*time.Second { + lastStallLog = time.Now() + log.Warn("BAL catch-up stalled, awaiting peers", + "fetched", len(fetched), "remaining", len(pending)) + } + // Wait for something to happen select { case <-s.update: @@ -777,7 +827,7 @@ func (s *Syncer) fetchAccessLists(hashes []common.Hash, cancel chan struct{}) ([ pending[h] = struct{}{} } case res := <-accessListResps: - s.processAccessListResponse(res, pending, fetched) + s.processAccessListResponse(res, headers, pending, fetched) } } @@ -789,7 +839,7 @@ func (s *Syncer) fetchAccessLists(hashes []common.Hash, cancel chan struct{}) ([ return results, nil } -// assignAccessListTasks attempts to assign access list fetch requests to idle +// assignAccessListTasks attempts to assign BAL fetch requests to idle // peers for any hashes still in pending. func (s *Syncer) assignAccessListTasks(pending map[common.Hash]struct{}, success chan *accessListResponse, fail chan *accessListRequest, cancel chan struct{}) { s.lock.Lock() @@ -842,7 +892,7 @@ func (s *Syncer) assignAccessListTasks(pending map[common.Hash]struct{}, success stale: make(chan struct{}), } req.timeout = time.AfterFunc(s.rates.TargetTimeout(), func() { - peer.Log().Debug("Access list request timed out", "reqid", reqid) + peer.Log().Debug("BAL request timed out", "reqid", reqid) s.rates.Update(idle, AccessListsMsg, 0, 0) s.scheduleRevertAccessListRequest(req) }) @@ -855,16 +905,22 @@ func (s *Syncer) assignAccessListTasks(pending map[common.Hash]struct{}, success // Attempt to send the remote request and revert if it fails if err := peer.RequestAccessLists(reqid, batch, softResponseLimit); err != nil { - log.Debug("Failed to request access lists", "err", err) + log.Debug("Failed to request BALs", "err", err) s.scheduleRevertAccessListRequest(req) } }() } } -// processAccessListResponse handles a successful access list response by -// matching results to pending hashes and storing them. -func (s *Syncer) processAccessListResponse(res *accessListResponse, pending map[common.Hash]struct{}, fetched map[common.Hash]rlp.RawValue) { +// processAccessListResponse handles a successful BAL response. It +// verifies each non-empty BAL against the corresponding block header and +// stores the verified ones in fetched. +func (s *Syncer) processAccessListResponse(res *accessListResponse, headers map[common.Hash]*types.Header, pending map[common.Hash]struct{}, fetched map[common.Hash]rlp.RawValue) { + type verified struct { + hash common.Hash + raw rlp.RawValue + } + var ok []verified // Each response entry corresponds to the requested hash at the same index. for i, raw := range res.accessLists { h := res.req.hashes[i] @@ -874,25 +930,68 @@ func (s *Syncer) processAccessListResponse(res *accessListResponse, pending map[ pending[h] = struct{}{} continue } - fetched[h] = raw - delete(pending, h) + var b bal.BlockAccessList + if err := rlp.DecodeBytes(raw, &b); err != nil { + log.Warn("Peer sent unparseable BAL, marking stateless", + "peer", res.req.peer, "block", h, "err", err) + s.rejectAccessListResponse(res, pending) + return + } + header, found := headers[h] + if !found { + // Caller must supply a header for every requested hash. + log.Error("Missing header for fetched BAL", "block", h) + s.rejectAccessListResponse(res, pending) + return + } + if err := verifyAccessList(&b, header); err != nil { + log.Warn("Peer sent BAL that failed verification, marking stateless", + "peer", res.req.peer, "block", h, "err", err) + s.rejectAccessListResponse(res, pending) + return + } + ok = append(ok, verified{hash: h, raw: raw}) } // Re-add hashes that were not served back to pending for i := len(res.accessLists); i < len(res.req.hashes); i++ { pending[res.req.hashes[i]] = struct{}{} } + + // Commit the verified entries. + for _, v := range ok { + fetched[v.hash] = v.raw + delete(pending, v.hash) + } +} + +// rejectAccessListResponse marks the responding peer stateless and re-adds +// every hash from the request to pending so the work moves to other peers. +func (s *Syncer) rejectAccessListResponse(res *accessListResponse, pending map[common.Hash]struct{}) { + s.lock.Lock() + s.statelessPeers[res.req.peer] = struct{}{} + s.lock.Unlock() + for _, h := range res.req.hashes { + pending[h] = struct{}{} + } } // loadSyncStatus retrieves a previously aborted sync status from the database, -// or generates a fresh one if none is available. +// or generates a fresh one if none is available. The persisted blob is framed +// as `[version byte | JSON payload]`; a missing or mismatching version byte +// causes the progress to be discarded and sync to start fresh. func (s *Syncer) loadSyncStatus() { var progress SyncProgress - if status := rawdb.ReadSnapshotSyncStatus(s.db); status != nil { - if err := json.Unmarshal(status, &progress); err != nil { + if raw := rawdb.ReadSnapshotSyncStatus(s.db); len(raw) > 0 { + if raw[0] != syncProgressVersion { + log.Info("Discarding old-format sync progress", "version", raw[0], "expected", syncProgressVersion) + } else if err := json.Unmarshal(raw[1:], &progress); err != nil { log.Error("Failed to decode snap sync status", "err", err) } else { + s.lock.Lock() + defer s.lock.Unlock() + for _, task := range progress.Tasks { log.Debug("Scheduled account sync task", "from", task.Next, "last", task.Last) } @@ -905,11 +1004,8 @@ func (s *Syncer) loadSyncStatus() { } task.StorageCompleted = nil } - s.lock.Lock() - defer s.lock.Unlock() - - s.previousRoot = progress.Root - s.previousNumber = progress.BlockNumber + s.previousPivot = progress.Pivot + s.complete = progress.Complete s.accountSynced = progress.AccountSynced s.accountBytes = progress.AccountBytes s.bytecodeSynced = progress.BytecodeSynced @@ -920,9 +1016,27 @@ func (s *Syncer) loadSyncStatus() { } } // Either we've failed to decode the previous state, or there was none. - // Start a fresh sync by chunking up the account range and scheduling - // them for retrieval. + s.resetSyncState() +} + +// resetSyncState wipes all persisted snap-sync data (sync status, account +// and storage snapshots) and re-initializes in-memory state with a fresh +// chunking of the account hash range. +func (s *Syncer) resetSyncState() { + rawdb.DeleteSnapshotSyncStatus(s.db) + if err := s.db.DeleteRange(rawdb.SnapshotAccountPrefix, []byte{rawdb.SnapshotAccountPrefix[0] + 1}); err != nil { + log.Crit("Failed to wipe account snapshot range", "err", err) + } + if err := s.db.DeleteRange(rawdb.SnapshotStoragePrefix, []byte{rawdb.SnapshotStoragePrefix[0] + 1}); err != nil { + log.Crit("Failed to wipe storage snapshot range", "err", err) + } + + s.lock.Lock() + defer s.lock.Unlock() + s.tasks = nil + s.previousPivot = nil + s.complete = false s.accountSynced, s.accountBytes = 0, 0 s.bytecodeSynced, s.bytecodeBytes = 0, 0 s.storageSynced, s.storageBytes = 0, 0 @@ -951,7 +1065,7 @@ func (s *Syncer) loadSyncStatus() { } } -// saveSyncStatus marshals the remaining sync tasks into leveldb. +// saveSyncStatus marshals the remaining sync tasks into db. func (s *Syncer) saveSyncStatus() { // Serialize any partial progress to disk before spinning down for _, task := range s.tasks { @@ -964,11 +1078,11 @@ func (s *Syncer) saveSyncStatus() { log.Debug("Leftover completed storages", "number", len(task.StorageCompleted), "next", task.Next, "last", task.Last) } } - // Store the actual progress markers + // Store the actual progress markers. progress := &SyncProgress{ - Root: s.root, - BlockNumber: s.number, + Pivot: s.previousPivot, Tasks: s.tasks, + Complete: s.complete, AccountSynced: s.accountSynced, AccountBytes: s.accountBytes, BytecodeSynced: s.bytecodeSynced, @@ -976,10 +1090,12 @@ func (s *Syncer) saveSyncStatus() { StorageSynced: s.storageSynced, StorageBytes: s.storageBytes, } - status, err := json.Marshal(progress) + blob, err := json.Marshal(progress) if err != nil { panic(err) // This can only fail during implementation } + // Prepend the version byte so future format changes can be detected on load. + status := append([]byte{syncProgressVersion}, blob...) rawdb.WriteSnapshotSyncStatus(s.db, status) } @@ -1125,7 +1241,7 @@ func (s *Syncer) assignAccountTasks(success chan *accountResponse, fail chan *ac peer.Log().Debug("Failed to request account range", "err", err) s.scheduleRevertAccountRequest(req) } - }(s.root) + }(s.pivot.Root) // Inject the request into the task to block further assignments task.req = req @@ -1354,7 +1470,7 @@ func (s *Syncer) assignStorageTasks(success chan *storageResponse, fail chan *st log.Debug("Failed to request storage", "err", err) s.scheduleRevertStorageRequest(req) } - }(s.root) + }(s.pivot.Root) // Inject the request into the subtask to block further assignments if subtask != nil { @@ -1564,13 +1680,13 @@ func (s *Syncer) scheduleRevertAccessListRequest(req *accessListRequest) { } } -// revertAccessListRequest cleans up an access list request and returns all +// revertAccessListRequest cleans up an BAL request and returns all // failed retrieval tasks to the scheduler for reassignment. func (s *Syncer) revertAccessListRequest(req *accessListRequest) { - log.Debug("Reverting access list request", "peer", req.peer) + log.Debug("Reverting BAL request", "peer", req.peer) select { case <-req.stale: - log.Trace("Access list request already reverted", "peer", req.peer, "reqid", req.id) + log.Trace("BAL request already reverted", "peer", req.peer, "reqid", req.id) return default: } @@ -2024,7 +2140,7 @@ func (s *Syncer) OnAccounts(peer SyncPeer, id uint64, hashes []common.Hash, acco // retrieved was either already pruned remotely, or the peer is not yet // synced to our head. if len(hashes) == 0 && len(accounts) == 0 && len(proof) == 0 { - logger.Debug("Peer rejected account range request", "root", s.root) + logger.Debug("Peer rejected account range request", "root", s.pivot.Root) s.statelessPeers[peer.ID()] = struct{}{} s.lock.Unlock() @@ -2032,7 +2148,7 @@ func (s *Syncer) OnAccounts(peer SyncPeer, id uint64, hashes []common.Hash, acco s.scheduleRevertAccountRequest(req) return nil } - root := s.root + root := s.pivot.Root s.lock.Unlock() // Reconstruct a partial trie from the response and verify it @@ -2326,7 +2442,7 @@ func (s *Syncer) OnStorage(peer SyncPeer, id uint64, hashes [][]common.Hash, slo return nil } -// OnAccessLists is a callback method to invoke when a batch of access lists +// OnAccessLists is a callback method to invoke when a batch of BALs // are received from a remote peer. func (s *Syncer) OnAccessLists(peer SyncPeer, id uint64, accessLists rlp.RawList[rlp.RawValue]) error { // Convert RawList to slice of raw values @@ -2363,7 +2479,7 @@ func (s *Syncer) OnAccessLists(peer SyncPeer, id uint64, accessLists rlp.RawList req, ok := s.accessListReqs[id] if !ok { // Request stale, perhaps the peer timed out but came through in the end - logger.Warn("Unexpected access list packet") + logger.Warn("Unexpected BAL packet") s.lock.Unlock() return nil } @@ -2380,7 +2496,7 @@ func (s *Syncer) OnAccessLists(peer SyncPeer, id uint64, accessLists rlp.RawList // Response is valid, but check if peer is signalling that it does not have // the requested data. if len(bals) == 0 { - logger.Debug("Peer rejected access list request") + logger.Debug("Peer rejected BAL request") s.statelessPeers[peer.ID()] = struct{}{} s.lock.Unlock() @@ -2479,6 +2595,26 @@ func estimateRemainingSlots(hashes int, last common.Hash) (uint64, error) { return space.Uint64() - uint64(hashes), nil } +// accessListPeersExhausted reports whether forward progress on BAL fetches is +// impossible: at least one peer is connected, every connected peer is marked +// stateless, and no BAL requests are in flight. +func (s *Syncer) accessListPeersExhausted() bool { + s.lock.RLock() + defer s.lock.RUnlock() + if len(s.peers) == 0 { + return false + } + if len(s.accessListReqs) > 0 { + return false + } + for id := range s.peers { + if _, ok := s.statelessPeers[id]; !ok { + return false + } + } + return true +} + // sortIdlePeers builds a list of idle peers sorted by download capacity // (highest first), filtering out stateless peers. Must be called with s.lock held. func (s *Syncer) sortIdlePeers(idlerSet map[string]struct{}, msgCode uint64) *capacitySort { diff --git a/eth/protocols/snap/sync_test.go b/eth/protocols/snap/sync_test.go index 7ebcd2e3c3..b43b7812a3 100644 --- a/eth/protocols/snap/sync_test.go +++ b/eth/protocols/snap/sync_test.go @@ -20,6 +20,7 @@ import ( "bytes" "crypto/rand" "encoding/binary" + "errors" "fmt" "math/big" mrand "math/rand" @@ -178,7 +179,7 @@ func (t *testPeer) ID() string { return t.id } func (t *testPeer) Log() log.Logger { return t.logger } func (t *testPeer) Stats() string { - return fmt.Sprintf(`Account requests: %d Storage requests: %d Bytecode requests: %d`, t.nAccountRequests, t.nStorageRequests, t.nBytecodeRequests) + return fmt.Sprintf(`Account requests: %d Storage requests: %d Bytecode requests: %d`, t.nAccountRequests.Load(), t.nStorageRequests.Load(), t.nBytecodeRequests.Load()) } func (t *testPeer) RequestAccountRange(id uint64, root, origin, limit common.Hash, bytes int) error { @@ -207,7 +208,7 @@ func (t *testPeer) RequestByteCodes(id uint64, hashes []common.Hash, bytes int) } func (t *testPeer) RequestAccessLists(id uint64, hashes []common.Hash, bytes int) error { - t.nAccessListRequests++ + t.nAccessListRequests.Add(1) t.logger.Trace("Fetching set of BALs", "reqid", id, "hashes", len(hashes), "bytes", common.StorageSize(bytes)) go t.accessListRequestHandler(t, id, hashes, bytes) return nil @@ -592,7 +593,7 @@ func testSyncBloatedProof(t *testing.T, scheme string) { return nil } syncer := setupSyncer(nodeScheme, source) - if err := syncer.Sync(sourceAccountTrie.Hash(), 0, cancel); err == nil { + if err := syncer.Sync(mkPivot(0, sourceAccountTrie.Hash()), cancel); err == nil { t.Fatal("No error returned from incomplete/cancelled sync") } } @@ -607,6 +608,33 @@ func setupSyncer(scheme string, peers ...*testPeer) *Syncer { return syncer } +// mkPivot builds a minimal pivot header with the given block number and state +// root, suitable for test calls into Syncer.Sync. +func mkPivot(num uint64, root common.Hash) *types.Header { + return &types.Header{ + Number: new(big.Int).SetUint64(num), + Root: root, + Difficulty: common.Big0, + } +} + +// makeAccessListHeaders builds a header map keyed by block hash where each +// header's BlockAccessListHash matches the BAL it points to. fetchAccessLists +// uses these headers to verify peer responses, so tests need to provide them +// alongside any BALs they expect to be accepted. +func makeAccessListHeaders(bals map[common.Hash]rlp.RawValue) map[common.Hash]*types.Header { + headers := make(map[common.Hash]*types.Header, len(bals)) + for h, raw := range bals { + var b bal.BlockAccessList + if err := rlp.DecodeBytes(raw, &b); err != nil { + continue + } + bh := b.Hash() + headers[h] = &types.Header{BlockAccessListHash: &bh} + } + return headers +} + // TestSync tests a basic sync with one peer func TestSync(t *testing.T) { t.Parallel() @@ -634,7 +662,7 @@ func testSync(t *testing.T, scheme string) { return source } syncer := setupSyncer(nodeScheme, mkSource("source")) - if err := syncer.Sync(sourceAccountTrie.Hash(), 0, cancel); err != nil { + if err := syncer.Sync(mkPivot(0, sourceAccountTrie.Hash()), cancel); err != nil { t.Fatalf("sync failed: %v", err) } verifyTrie(scheme, syncer.db, sourceAccountTrie.Hash(), t) @@ -669,7 +697,7 @@ func testSyncTinyTriePanic(t *testing.T, scheme string) { } syncer := setupSyncer(nodeScheme, mkSource("source")) done := checkStall(t, term) - if err := syncer.Sync(sourceAccountTrie.Hash(), 0, cancel); err != nil { + if err := syncer.Sync(mkPivot(0, sourceAccountTrie.Hash()), cancel); err != nil { t.Fatalf("sync failed: %v", err) } close(done) @@ -704,7 +732,7 @@ func testMultiSync(t *testing.T, scheme string) { } syncer := setupSyncer(nodeScheme, mkSource("sourceA"), mkSource("sourceB")) done := checkStall(t, term) - if err := syncer.Sync(sourceAccountTrie.Hash(), 0, cancel); err != nil { + if err := syncer.Sync(mkPivot(0, sourceAccountTrie.Hash()), cancel); err != nil { t.Fatalf("sync failed: %v", err) } close(done) @@ -741,7 +769,7 @@ func testSyncWithStorage(t *testing.T, scheme string) { } syncer := setupSyncer(scheme, mkSource("sourceA")) done := checkStall(t, term) - if err := syncer.Sync(sourceAccountTrie.Hash(), 0, cancel); err != nil { + if err := syncer.Sync(mkPivot(0, sourceAccountTrie.Hash()), cancel); err != nil { t.Fatalf("sync failed: %v", err) } close(done) @@ -791,7 +819,7 @@ func testMultiSyncManyUseless(t *testing.T, scheme string) { mkSource("noStorage", true, false), ) done := checkStall(t, term) - if err := syncer.Sync(sourceAccountTrie.Hash(), 0, cancel); err != nil { + if err := syncer.Sync(mkPivot(0, sourceAccountTrie.Hash()), cancel); err != nil { t.Fatalf("sync failed: %v", err) } close(done) @@ -846,7 +874,7 @@ func testMultiSyncManyUselessWithLowTimeout(t *testing.T, scheme string) { syncer.rates.OverrideTTLLimit = time.Millisecond done := checkStall(t, term) - if err := syncer.Sync(sourceAccountTrie.Hash(), 0, cancel); err != nil { + if err := syncer.Sync(mkPivot(0, sourceAccountTrie.Hash()), cancel); err != nil { t.Fatalf("sync failed: %v", err) } close(done) @@ -899,7 +927,7 @@ func testMultiSyncManyUnresponsive(t *testing.T, scheme string) { syncer.rates.OverrideTTLLimit = time.Millisecond done := checkStall(t, term) - if err := syncer.Sync(sourceAccountTrie.Hash(), 0, cancel); err != nil { + if err := syncer.Sync(mkPivot(0, sourceAccountTrie.Hash()), cancel); err != nil { t.Fatalf("sync failed: %v", err) } close(done) @@ -953,7 +981,7 @@ func testSyncBoundaryAccountTrie(t *testing.T, scheme string) { mkSource("peer-b"), ) done := checkStall(t, term) - if err := syncer.Sync(sourceAccountTrie.Hash(), 0, cancel); err != nil { + if err := syncer.Sync(mkPivot(0, sourceAccountTrie.Hash()), cancel); err != nil { t.Fatalf("sync failed: %v", err) } close(done) @@ -1000,7 +1028,7 @@ func testSyncNoStorageAndOneCappedPeer(t *testing.T, scheme string) { mkSource("capped", true), ) done := checkStall(t, term) - if err := syncer.Sync(sourceAccountTrie.Hash(), 0, cancel); err != nil { + if err := syncer.Sync(mkPivot(0, sourceAccountTrie.Hash()), cancel); err != nil { t.Fatalf("sync failed: %v", err) } close(done) @@ -1045,7 +1073,7 @@ func testSyncNoStorageAndOneCodeCorruptPeer(t *testing.T, scheme string) { mkSource("corrupt", corruptCodeRequestHandler), ) done := checkStall(t, term) - if err := syncer.Sync(sourceAccountTrie.Hash(), 0, cancel); err != nil { + if err := syncer.Sync(mkPivot(0, sourceAccountTrie.Hash()), cancel); err != nil { t.Fatalf("sync failed: %v", err) } close(done) @@ -1088,7 +1116,7 @@ func testSyncNoStorageAndOneAccountCorruptPeer(t *testing.T, scheme string) { mkSource("corrupt", corruptAccountRequestHandler), ) done := checkStall(t, term) - if err := syncer.Sync(sourceAccountTrie.Hash(), 0, cancel); err != nil { + if err := syncer.Sync(mkPivot(0, sourceAccountTrie.Hash()), cancel); err != nil { t.Fatalf("sync failed: %v", err) } close(done) @@ -1134,7 +1162,7 @@ func testSyncNoStorageAndOneCodeCappedPeer(t *testing.T, scheme string) { }), ) done := checkStall(t, term) - if err := syncer.Sync(sourceAccountTrie.Hash(), 0, cancel); err != nil { + if err := syncer.Sync(mkPivot(0, sourceAccountTrie.Hash()), cancel); err != nil { t.Fatalf("sync failed: %v", err) } close(done) @@ -1186,7 +1214,7 @@ func testSyncBoundaryStorageTrie(t *testing.T, scheme string) { mkSource("peer-b"), ) done := checkStall(t, term) - if err := syncer.Sync(sourceAccountTrie.Hash(), 0, cancel); err != nil { + if err := syncer.Sync(mkPivot(0, sourceAccountTrie.Hash()), cancel); err != nil { t.Fatalf("sync failed: %v", err) } close(done) @@ -1233,7 +1261,7 @@ func testSyncWithStorageAndOneCappedPeer(t *testing.T, scheme string) { mkSource("slow", true), ) done := checkStall(t, term) - if err := syncer.Sync(sourceAccountTrie.Hash(), 0, cancel); err != nil { + if err := syncer.Sync(mkPivot(0, sourceAccountTrie.Hash()), cancel); err != nil { t.Fatalf("sync failed: %v", err) } close(done) @@ -1279,7 +1307,7 @@ func testSyncWithStorageAndCorruptPeer(t *testing.T, scheme string) { mkSource("corrupt", corruptStorageRequestHandler), ) done := checkStall(t, term) - if err := syncer.Sync(sourceAccountTrie.Hash(), 0, cancel); err != nil { + if err := syncer.Sync(mkPivot(0, sourceAccountTrie.Hash()), cancel); err != nil { t.Fatalf("sync failed: %v", err) } close(done) @@ -1322,7 +1350,7 @@ func testSyncWithStorageAndNonProvingPeer(t *testing.T, scheme string) { mkSource("corrupt", noProofStorageRequestHandler), ) done := checkStall(t, term) - if err := syncer.Sync(sourceAccountTrie.Hash(), 0, cancel); err != nil { + if err := syncer.Sync(mkPivot(0, sourceAccountTrie.Hash()), cancel); err != nil { t.Fatalf("sync failed: %v", err) } close(done) @@ -1362,7 +1390,7 @@ func testSyncWithStorageMisbehavingProve(t *testing.T, scheme string) { return source } syncer := setupSyncer(nodeScheme, mkSource("sourceA")) - if err := syncer.Sync(sourceAccountTrie.Hash(), 0, cancel); err != nil { + if err := syncer.Sync(mkPivot(0, sourceAccountTrie.Hash()), cancel); err != nil { t.Fatalf("sync failed: %v", err) } verifyTrie(scheme, syncer.db, sourceAccountTrie.Hash(), t) @@ -1401,7 +1429,7 @@ func testSyncWithUnevenStorage(t *testing.T, scheme string) { return source } syncer := setupSyncer(scheme, mkSource("source")) - if err := syncer.Sync(accountTrie.Hash(), 0, cancel); err != nil { + if err := syncer.Sync(mkPivot(0, accountTrie.Hash()), cancel); err != nil { t.Fatalf("sync failed: %v", err) } verifyTrie(scheme, syncer.db, accountTrie.Hash(), t) @@ -1907,68 +1935,157 @@ func TestSlotEstimation(t *testing.T) { } } -// TestPivotMoveDetection verifies that when the syncer is restarted with a -// different root (simulating the downloader's cancel+restart on pivot move), -// downloadState() returns errPivotStale immediately. -func TestPivotMoveDetection(t *testing.T) { +// TestIsPivotReorged verifies the four conditions isPivotReorged covers: +// reorged out, non-advancing pivot, missing canonical, and the happy path +// where the previous pivot is still canonical and the new pivot advances. +func TestIsPivotReorged(t *testing.T) { t.Parallel() - rootA := common.HexToHash("0xaaaa") - rootB := common.HexToHash("0xbbbb") + // Reorged: canonical hash at prev's height differs from prev. The + // previous pivot was reorged out by an alternate chain at the same + // (or higher) height. + t.Run("Reorged_DifferentHash", func(t *testing.T) { + db := rawdb.NewMemoryDatabase() + prev := mkPivot(100, common.HexToHash("0xaaaa")) + curr := mkPivot(105, common.HexToHash("0xcccc")) + canonical := mkPivot(100, common.HexToHash("0xbbbb")) + rawdb.WriteHeader(db, canonical) + rawdb.WriteCanonicalHash(db, canonical.Hash(), canonical.Number.Uint64()) - db := rawdb.NewMemoryDatabase() - syncer := NewSyncer(db, rawdb.HashScheme) + if !isPivotReorged(db, prev, curr) { + t.Fatal("expected reorg detection when canonical hash differs") + } + }) - // Simulate a previous sync run against rootA with some pending tasks - syncer.root = rootA - syncer.tasks = []*accountTask{ - {Next: common.Hash{}, Last: common.MaxHash, SubTasks: make(map[common.Hash][]*storageTask), stateCompleted: make(map[common.Hash]struct{})}, - } - syncer.saveSyncStatus() + // NonAdvancingPivot: new pivot is at or below the old one. There's + // nothing for catchUp to roll forward, regardless of canonical state. + t.Run("NonAdvancingPivot", func(t *testing.T) { + db := rawdb.NewMemoryDatabase() + prev := mkPivot(100, common.HexToHash("0xaaaa")) + curr := mkPivot(95, common.HexToHash("0xcccc")) + rawdb.WriteHeader(db, prev) + rawdb.WriteCanonicalHash(db, prev.Hash(), prev.Number.Uint64()) - // Simulate downloader restarting us with rootB (as Sync() would do) - syncer.root = rootB - syncer.previousRoot = rootB // Sync() sets this as default - syncer.loadSyncStatus() // Overwrites previousRoot with persisted rootA + if !isPivotReorged(db, prev, curr) { + t.Fatal("expected reorg detection when new pivot is at or below the old one") + } + }) - if syncer.previousRoot != rootA { - t.Fatalf("previousRoot mismatch: got %v, want %v", syncer.previousRoot, rootA) - } - if syncer.root != rootB { - t.Fatalf("root mismatch: got %v, want %v", syncer.root, rootB) - } - // downloadState() should detect the mismatch and return errPivotStale - cancel := make(chan struct{}) - err := syncer.downloadState(cancel) - if err != errPivotStale { - t.Fatalf("expected errPivotStale, got %v", err) - } + // MissingCanonical: canonical hash at prev's height is absent while + // curr advances past it. By the time Sync is called, headers up to + // curr should be indexed, so this implies broken chain state. + t.Run("MissingCanonical", func(t *testing.T) { + db := rawdb.NewMemoryDatabase() + prev := mkPivot(100, common.HexToHash("0xaaaa")) + curr := mkPivot(105, common.HexToHash("0xcccc")) + + if !isPivotReorged(db, prev, curr) { + t.Fatal("expected reorg detection when canonical hash is missing at prev's height") + } + }) + + // NotReorged_SameHash: prev is still canonical and curr advances past + // it. Catch-up is feasible. + t.Run("NotReorged_SameHash", func(t *testing.T) { + db := rawdb.NewMemoryDatabase() + prev := mkPivot(100, common.HexToHash("0xaaaa")) + curr := mkPivot(105, common.HexToHash("0xcccc")) + rawdb.WriteHeader(db, prev) + rawdb.WriteCanonicalHash(db, prev.Hash(), prev.Number.Uint64()) + + if isPivotReorged(db, prev, curr) { + t.Fatal("should not detect reorg when prev is canonical and curr advances") + } + }) } -// TestCatchUpInvertedRange verifies that catchUp returns an error and wipes -// sync progress when the new pivot is at the same (or lower) block number as -// the old pivot.. -func TestCatchUpInvertedRange(t *testing.T) { +// TestSyncDetectsPivotReorged exercises the reorg-handling branch in Sync +// end-to-end. +// +// Setup: persisted progress points at an orphan pivot at block 100; the new +// canonical header at block 100 has a different hash. Sync is then called with +// a new pivot at the same height. +// +// If isPivotReorged works, loadSyncStatus restores previousPivot, the check +// flags it as reorged, resetSyncState clears previousPivot, catchUp is +// skipped, and the fresh download proceeds to completion. +// +// If detection doesn't fire, the pivot-move check would call catchUp with +// from = 101 and to = 100 — the inverted-range guard surfaces that as an +// error, failing the test. So Sync returning nil is the positive signal that +// reorg detection and the reset worked. +func TestSyncDetectsPivotReorged(t *testing.T) { t.Parallel() + + nodeScheme, sourceAccountTrie, elems := makeAccountTrieNoStorage(100, rawdb.HashScheme) + root := sourceAccountTrie.Hash() + db := rawdb.NewMemoryDatabase() - syncer := NewSyncer(db, rawdb.HashScheme) - // Simulate: old pivot at block 100, new pivot at block 100 (same number, - // different root). This happens when a reorg replaces the pivot block. - syncer.previousNumber = 100 - syncer.number = 100 + // Persist progress against an orphan pivot — same height as the new + // canonical pivot we'll sync to, different hash. Populate a partial task + // and non-zero counter so the reset path has something to clean up. + orphanPivot := mkPivot(100, common.HexToHash("0xdead")) + seed := NewSyncer(db, nodeScheme) + // previousPivot reflects where flat state matches and it is what + // saveSyncStatus persists. Set it to simulate a prior sync reaching + // orphanPivot. + seed.previousPivot = orphanPivot + seed.pivot = orphanPivot + seed.accountSynced = 42 + seed.tasks = []*accountTask{{ + Next: common.HexToHash("0x80"), + Last: common.MaxHash, + SubTasks: make(map[common.Hash][]*storageTask), + stateCompleted: make(map[common.Hash]struct{}), + }} + seed.saveSyncStatus() - // Write some sync progress so we can verify it gets wiped - rawdb.WriteSnapshotSyncStatus(db, []byte("some progress")) - cancel := make(chan struct{}) - err := syncer.catchUp(cancel) - if err == nil { - t.Fatal("expected error from catchUp with inverted range") + // Pre-write orphan flat-state entries at hashes the test peer won't + // re-serve. After resetSyncState wipes the snapshot ranges, these + // should be gone. + orphanAccountHash := common.HexToHash("0xdeadbeef") + rawdb.WriteAccountSnapshot(db, orphanAccountHash, []byte{0xde, 0xad}) + orphanStorageAccount := common.HexToHash("0xfeedfacefeedfacefeedfacefeedfacefeedfacefeedfacefeedfacefeedface") + orphanStorageSlot := common.HexToHash("0xabcd") + rawdb.WriteStorageSnapshot(db, orphanStorageAccount, orphanStorageSlot, []byte{0xff, 0xff}) + + // Canonical header at block 100 is newPivot — different hash from the + // orphan pivot, which is what isPivotReorged will detect. + newPivot := mkPivot(100, root) + rawdb.WriteHeader(db, newPivot) + rawdb.WriteCanonicalHash(db, newPivot.Hash(), newPivot.Number.Uint64()) + + var ( + once sync.Once + cancel = make(chan struct{}) + term = func() { once.Do(func() { close(cancel) }) } + ) + syncer := NewSyncer(db, nodeScheme) + src := newTestPeer("source", t, term) + src.accountTrie = sourceAccountTrie.Copy() + src.accountValues = elems + syncer.Register(src) + src.remote = syncer + + if err := syncer.Sync(newPivot, cancel); err != nil { + t.Fatalf("sync failed (reorg detection likely broken): %v", err) } - - // Verify sync progress was wiped - if status := rawdb.ReadSnapshotSyncStatus(db); status != nil { - t.Fatal("sync progress should be wiped after inverted catch-up range") + // After successful completion, status should be marked Complete=true + // against the new (canonical) pivot. + loader := NewSyncer(db, nodeScheme) + loader.loadSyncStatus() + if !loader.complete { + t.Fatal("sync status should be marked Complete=true after successful completion") + } + if loader.previousPivot == nil || loader.previousPivot.Hash() != newPivot.Hash() { + t.Fatalf("expected persisted pivot to match new pivot") + } + if data := rawdb.ReadAccountSnapshot(db, orphanAccountHash); len(data) != 0 { + t.Errorf("orphan account snapshot should be wiped, got %x", data) + } + if val := rawdb.ReadStorageSnapshot(db, orphanStorageAccount, orphanStorageSlot); len(val) != 0 { + t.Errorf("orphan storage snapshot should be wiped, got %x", val) } } @@ -2007,8 +2124,9 @@ func testInterruptedDownloadRecovery(t *testing.T, scheme string) { src1.accountRequestHandler = cancelAfterHandler syncer1.Register(src1) src1.remote = syncer1 - syncer1.root = root - syncer1.previousRoot = root + pivot := mkPivot(0, root) + syncer1.pivot = pivot + syncer1.previousPivot = pivot // Sync sets this before downloadState syncer1.loadSyncStatus() syncer1.downloadState(cancel1) @@ -2044,8 +2162,9 @@ func testInterruptedDownloadRecovery(t *testing.T, scheme string) { src2.accountValues = elems syncer2.Register(src2) src2.remote = syncer2 - syncer2.root = root - syncer2.previousRoot = root + pivot2 := mkPivot(0, root) + syncer2.pivot = pivot2 + syncer2.previousPivot = pivot2 // Sync sets this before downloadState syncer2.loadSyncStatus() if err := syncer2.downloadState(cancel2); err != nil { t.Fatalf("resumed download failed: %v", err) @@ -2059,6 +2178,52 @@ func testInterruptedDownloadRecovery(t *testing.T, scheme string) { } } +// TestSyncPersistsPivotDuringDownload verifies that after a fresh Sync is +// interrupted mid-download, the persisted previousPivot equals the current +// pivot (not nil). Without this, a follow-up Sync at a different pivot +// would not see that the partial flat state belongs to the old pivot, and +// would mix old-pivot accounts with new-pivot data. +func TestSyncPersistsPivotDuringDownload(t *testing.T) { + t.Parallel() + nodeScheme, sourceAccountTrie, elems := makeAccountTrieNoStorage(100, rawdb.HashScheme) + + var ( + once sync.Once + cancel = make(chan struct{}) + term = func() { once.Do(func() { close(cancel) }) } + responses atomic.Int32 + ) + db := rawdb.NewMemoryDatabase() + syncer := NewSyncer(db, nodeScheme) + src := newTestPeer("source", t, term) + src.accountTrie = sourceAccountTrie.Copy() + src.accountValues = elems + src.accountRequestHandler = func(tp *testPeer, id uint64, root common.Hash, origin common.Hash, limit common.Hash, cap int) error { + if responses.Add(1) > 2 { + term() + return nil + } + return defaultAccountRequestHandler(tp, id, root, origin, limit, cap) + } + syncer.Register(src) + src.remote = syncer + + pivot := mkPivot(0, sourceAccountTrie.Hash()) + // Sync should be interrupted by the cancel after a couple of responses. + _ = syncer.Sync(pivot, cancel) + + // Persisted previousPivot must equal the pivot, so a follow-up Sync at a + // different pivot can recognize the partial flat state belongs to this one. + loader := NewSyncer(db, nodeScheme) + loader.loadSyncStatus() + if loader.previousPivot == nil { + t.Fatal("expected persisted previousPivot to be set after interrupted download, got nil") + } + if loader.previousPivot.Hash() != pivot.Hash() { + t.Errorf("persisted previousPivot mismatch: got %v, want %v", loader.previousPivot.Hash(), pivot.Hash()) + } +} + // TestPivotMovement verifies the full pivot move flow: download with rootA, // cancel+restart with rootB, catch-up applies BAL diffs, download resumes // and completes against the new state. @@ -2179,7 +2344,7 @@ func testPivotMovement(t *testing.T, scheme string, pivotMoves int) { } syncer1.Register(src1) src1.remote = syncer1 - syncer1.Sync(rootA, numA, cancel1) + syncer1.Sync(mkPivot(numA, rootA), cancel1) // Subsequent runs: each move triggers catch-up then resumes download for i, move := range moves { @@ -2195,7 +2360,7 @@ func testPivotMovement(t *testing.T, scheme string, pivotMoves int) { src.accessLists = move.bals syncer.Register(src) src.remote = syncer - if err := syncer.Sync(move.root, move.blockNum, cancel); err != nil { + if err := syncer.Sync(mkPivot(move.blockNum, move.root), cancel); err != nil { t.Fatalf("pivot move %d: sync failed: %v", i+1, err) } @@ -2214,16 +2379,151 @@ func testPivotMovement(t *testing.T, scheme string, pivotMoves int) { } } -// TestSyncStatusClearedAfterCompletion verifies that the persisted sync status -// is cleared after a full sync completes (download + trie rebuild), so the -// next Sync() call starts fresh. -func TestSyncStatusClearedAfterCompletion(t *testing.T) { +// TestCatchUpPersistsIncrementally verifies that catchUp updates and persists +// previousPivot after each successfully applied BAL. If a later block in the +// gap fails to apply, the persisted state reflects the last successful block, +// so a follow-up Sync can resume from there rather than reapplying everything. +func TestCatchUpPersistsIncrementally(t *testing.T) { t.Parallel() - testSyncStatusClearedAfterCompletion(t, rawdb.HashScheme) - testSyncStatusClearedAfterCompletion(t, rawdb.PathScheme) + + nodeScheme, sourceAccountTrie, elems, addrs := makeAccountTrieWithAddresses(100, rawdb.HashScheme) + rootA := sourceAccountTrie.Hash() + numA := uint64(100) + + goodAddr := addrs[0] + corruptAddr := addrs[1] + + type balBlock struct { + header *types.Header + bal rlp.RawValue + } + + db := rawdb.NewMemoryDatabase() + emptyHash := common.Hash{} + zero := uint64(0) + + // Write the header and canonical hash for block A so the reorg-detection + // canonical-lookup in Sync passes (otherwise it'd treat A as reorged out + // and reset instead of running catchUp). + pivotAHeader := &types.Header{ + Number: new(big.Int).SetUint64(numA), Root: rootA, Difficulty: common.Big0, + BaseFee: common.Big0, WithdrawalsHash: &emptyHash, + BlobGasUsed: &zero, ExcessBlobGas: &zero, + ParentBeaconRoot: &emptyHash, RequestsHash: &emptyHash, + } + rawdb.WriteHeader(db, pivotAHeader) + rawdb.WriteCanonicalHash(db, pivotAHeader.Hash(), numA) + pivotA := pivotAHeader + + // Build three sequential BAL blocks (A+1, A+2, A+3). The first two touch + // goodAddr, the third touches corruptAddr so that block's apply fails + // once we've corrupted that account's snapshot. + blocks := make([]balBlock, 3) + for i := 0; i < 3; i++ { + blockNum := numA + uint64(i) + 1 + target := goodAddr + if i == 2 { + target = corruptAddr + } + balance := uint256.NewInt(uint64(1000 * (i + 1))) + + cb := bal.NewConstructionBlockAccessList() + cb.BalanceChange(0, target, balance) + var buf bytes.Buffer + if err := cb.EncodeRLP(&buf); err != nil { + t.Fatal(err) + } + var b bal.BlockAccessList + if err := rlp.DecodeBytes(buf.Bytes(), &b); err != nil { + t.Fatal(err) + } + balHash := b.Hash() + header := &types.Header{ + Number: new(big.Int).SetUint64(blockNum), Difficulty: common.Big0, + BaseFee: common.Big0, WithdrawalsHash: &emptyHash, + BlobGasUsed: &zero, ExcessBlobGas: &zero, + ParentBeaconRoot: &emptyHash, RequestsHash: &emptyHash, + BlockAccessListHash: &balHash, + } + rawdb.WriteHeader(db, header) + rawdb.WriteCanonicalHash(db, header.Hash(), blockNum) + blocks[i] = balBlock{header: header, bal: buf.Bytes()} + } + + // First sync: complete sync to A so persisted state has previousPivot=A, + // flat state covers all accounts. + { + var ( + once sync.Once + cancel = make(chan struct{}) + term = func() { once.Do(func() { close(cancel) }) } + ) + syncer := NewSyncer(db, nodeScheme) + src := newTestPeer("seed", t, term) + src.accountTrie = sourceAccountTrie.Copy() + src.accountValues = elems + syncer.Register(src) + src.remote = syncer + if err := syncer.Sync(pivotA, cancel); err != nil { + t.Fatalf("seed sync failed: %v", err) + } + } + + // Corrupt the flat-state snapshot for corruptAddr so applyAccessList will + // fail when block A+3's BAL touches it. types.FullAccount rejects this + // payload as undecodable. + rawdb.WriteAccountSnapshot(db, crypto.Keccak256Hash(corruptAddr[:]), []byte{0xff, 0xff, 0xff, 0xff}) + + // Second sync: target is A+3. catchUp should apply A+1 and A+2 (good + // account), persist after each, then fail on A+3 (corrupt account). + pivotB := blocks[2].header + balsByHash := map[common.Hash]rlp.RawValue{ + blocks[0].header.Hash(): blocks[0].bal, + blocks[1].header.Hash(): blocks[1].bal, + blocks[2].header.Hash(): blocks[2].bal, + } + + var ( + once sync.Once + cancel = make(chan struct{}) + term = func() { once.Do(func() { close(cancel) }) } + ) + syncer := NewSyncer(db, nodeScheme) + src := newTestPeer("catchup", t, term) + src.accountTrie = sourceAccountTrie.Copy() + src.accountValues = elems + src.accessLists = balsByHash + syncer.Register(src) + src.remote = syncer + + if err := syncer.Sync(pivotB, cancel); err == nil { + t.Fatal("expected Sync to fail when applyAccessList hits corrupt flat state") + } + + // Persisted previousPivot should now reflect the last successfully applied + // block (A+2). Without per-iteration saves, it would still be at A. + loader := NewSyncer(db, nodeScheme) + loader.loadSyncStatus() + if loader.previousPivot == nil { + t.Fatal("expected persisted previousPivot to be set after partial catchUp") + } + wantHash := blocks[1].header.Hash() + if loader.previousPivot.Hash() != wantHash { + t.Errorf("persisted previousPivot mismatch after partial catchUp: got %v, want %v (block A+2)", + loader.previousPivot.Hash(), wantHash) + } } -func testSyncStatusClearedAfterCompletion(t *testing.T, scheme string) { +// TestSyncStatusMarkedCompleteAfterCompletion verifies that after a full sync +// completes, the persisted sync status has Complete=true. This lets a +// subsequent Sync call distinguish "already done" from "fresh node" and skip. +func TestSyncStatusMarkedCompleteAfterCompletion(t *testing.T) { + t.Parallel() + testSyncStatusMarkedCompleteAfterCompletion(t, rawdb.HashScheme) + testSyncStatusMarkedCompleteAfterCompletion(t, rawdb.PathScheme) +} + +func testSyncStatusMarkedCompleteAfterCompletion(t *testing.T, scheme string) { var ( once sync.Once cancel = make(chan struct{}) @@ -2238,12 +2538,61 @@ func testSyncStatusClearedAfterCompletion(t *testing.T, scheme string) { return source } syncer := setupSyncer(nodeScheme, mkSource("source")) - if err := syncer.Sync(sourceAccountTrie.Hash(), 0, cancel); err != nil { + pivot := mkPivot(0, sourceAccountTrie.Hash()) + if err := syncer.Sync(pivot, cancel); err != nil { t.Fatalf("sync failed: %v", err) } - // After successful sync, status should be cleared - if status := rawdb.ReadSnapshotSyncStatus(syncer.db); status != nil { - t.Fatal("sync status should be nil after successful completion") + + // After successful sync, persisted status should be present with + // Complete=true and the pivot we synced to. + loader := NewSyncer(syncer.db, nodeScheme) + loader.loadSyncStatus() + if !loader.complete { + t.Fatal("expected persisted status to have Complete=true after successful sync") + } + if loader.previousPivot == nil || loader.previousPivot.Hash() != pivot.Hash() { + t.Fatalf("expected persisted pivot to match synced pivot") + } +} + +// TestSyncSkipsIfAlreadyComplete verifies that a follow-up Sync call for the +// same pivot returns immediately without doing any work, since the persisted +// status indicates the sync is already complete. To prove the skip path actually +// fires, we deliberately wipe the flat state between the two calls. If it skips, +// Sync returns nil without touching flat state. If it doesn't kip, GenerateTrie +// would run against an empty snapshot and fail with a root mismatch. +func TestSyncSkipsIfAlreadyComplete(t *testing.T) { + t.Parallel() + + nodeScheme, sourceAccountTrie, elems := makeAccountTrieNoStorage(100, rawdb.HashScheme) + pivot := mkPivot(0, sourceAccountTrie.Hash()) + + var ( + once1 sync.Once + cancel1 = make(chan struct{}) + term1 = func() { once1.Do(func() { close(cancel1) }) } + ) + src1 := newTestPeer("source1", t, term1) + src1.accountTrie = sourceAccountTrie.Copy() + src1.accountValues = elems + syncer := setupSyncer(nodeScheme, src1) + if err := syncer.Sync(pivot, cancel1); err != nil { + t.Fatalf("first sync failed: %v", err) + } + + // Wipe the flat state. The persisted status (with Complete=true) stays. + if err := syncer.db.DeleteRange(rawdb.SnapshotAccountPrefix, []byte{rawdb.SnapshotAccountPrefix[0] + 1}); err != nil { + t.Fatalf("failed to wipe account snapshot: %v", err) + } + if err := syncer.db.DeleteRange(rawdb.SnapshotStoragePrefix, []byte{rawdb.SnapshotStoragePrefix[0] + 1}); err != nil { + t.Fatalf("failed to wipe storage snapshot: %v", err) + } + + // Second sync must take the skip path. If it didn't, the empty flat + // state would cause GenerateTrie to fail with a root mismatch. + cancel2 := make(chan struct{}) + if err := syncer.Sync(pivot, cancel2); err != nil { + t.Fatalf("second sync should have skipped, got error: %v", err) } } @@ -2270,8 +2619,9 @@ func TestInterruptedRebuildRecovery(t *testing.T) { src1.accountValues = elems syncer1.Register(src1) src1.remote = syncer1 - syncer1.root = root - syncer1.previousRoot = root + pivot := mkPivot(0, root) + syncer1.pivot = pivot + syncer1.previousPivot = pivot // Sync sets this before downloadState syncer1.loadSyncStatus() if err := syncer1.downloadState(cancel1); err != nil { @@ -2301,12 +2651,14 @@ func TestInterruptedRebuildRecovery(t *testing.T) { syncer2.Register(src2) src2.remote = syncer2 - if err := syncer2.Sync(root, 0, cancel2); err != nil { + if err := syncer2.Sync(mkPivot(0, root), cancel2); err != nil { t.Fatalf("resumed sync failed: %v", err) } - // After rebuild completes, status should be cleared - if status := rawdb.ReadSnapshotSyncStatus(db); status != nil { - t.Fatal("sync status should be nil after rebuild completes") + // After rebuild completes, status should be marked Complete=true. + loader := NewSyncer(db, nodeScheme) + loader.loadSyncStatus() + if !loader.complete { + t.Fatal("sync status should be marked Complete=true after rebuild completes") } } @@ -2340,7 +2692,7 @@ func TestFetchAccessListsMultiplePeers(t *testing.T) { return source } syncer := setupSyncer(rawdb.HashScheme, mkSource("peer-a"), mkSource("peer-b"), mkSource("peer-c")) - results, err := syncer.fetchAccessLists(hashes, cancel) + results, err := syncer.fetchAccessLists(hashes, makeAccessListHeaders(bals), cancel) if err != nil { t.Fatalf("fetchAccessLists failed: %v", err) } @@ -2386,7 +2738,7 @@ func TestFetchAccessListsPeerTimeout(t *testing.T) { good.accessLists = bals syncer := setupSyncer(rawdb.HashScheme, nonResponsive, good) syncer.rates.OverrideTTLLimit = time.Millisecond // Fast timeout - results, err := syncer.fetchAccessLists(hashes, cancel) + results, err := syncer.fetchAccessLists(hashes, makeAccessListHeaders(bals), cancel) if err != nil { t.Fatalf("fetchAccessLists failed: %v", err) } @@ -2422,7 +2774,7 @@ func TestFetchAccessListsPeerRejection(t *testing.T) { good := newTestPeer("good", t, term) good.accessLists = bals syncer := setupSyncer(rawdb.HashScheme, rejector, good) - results, err := syncer.fetchAccessLists(hashes, cancel) + results, err := syncer.fetchAccessLists(hashes, makeAccessListHeaders(bals), cancel) if err != nil { t.Fatalf("fetchAccessLists failed: %v", err) } @@ -2450,7 +2802,7 @@ func TestFetchAccessListsCancel(t *testing.T) { time.Sleep(50 * time.Millisecond) close(cancel) }() - _, err := syncer.fetchAccessLists(hashes, cancel) + _, err := syncer.fetchAccessLists(hashes, nil, cancel) if err != ErrCancelled { t.Fatalf("expected ErrCancelled, got %v", err) } @@ -2487,7 +2839,7 @@ func TestFetchAccessListsPeerDrop(t *testing.T) { good := newTestPeer("good", t, term) good.accessLists = bals syncer := setupSyncer(rawdb.HashScheme, dropped, good) - results, err := syncer.fetchAccessLists(hashes, cancel) + results, err := syncer.fetchAccessLists(hashes, makeAccessListHeaders(bals), cancel) if err != nil { t.Fatalf("fetchAccessLists failed: %v", err) } @@ -2561,7 +2913,7 @@ func TestFetchAccessListsShortResponse(t *testing.T) { fetchErr error ) go func() { - results, fetchErr = syncer.fetchAccessLists(hashes, cancel) + results, fetchErr = syncer.fetchAccessLists(hashes, makeAccessListHeaders(allBALs), cancel) close(done) }() @@ -2647,7 +2999,7 @@ func TestFetchAccessListsEmptyPlaceholder(t *testing.T) { fetchErr error ) go func() { - results, fetchErr = syncer.fetchAccessLists(hashes, cancel) + results, fetchErr = syncer.fetchAccessLists(hashes, makeAccessListHeaders(allBALs), cancel) close(done) }() @@ -2671,6 +3023,117 @@ func TestFetchAccessListsEmptyPlaceholder(t *testing.T) { } } +// TestFetchAccessListsRejectsBadBAL verifies that when a peer delivers a BAL +// whose hash doesn't match the canonical block header, fetchAccessLists marks +// the peer stateless, drops the response, and surfaces the exhaustion error +// once no other peers can serve the work. +func TestFetchAccessListsRejectsBadBAL(t *testing.T) { + t.Parallel() + var ( + once sync.Once + cancel = make(chan struct{}) + term = func() { once.Do(func() { close(cancel) }) } + ) + hash := common.HexToHash("0x01") + hashes := []common.Hash{hash} + + // Build a BAL we'll actually serve. + cb := bal.NewConstructionBlockAccessList() + cb.BalanceChange(0, common.HexToAddress("0xaa"), uint256.NewInt(42)) + var buf bytes.Buffer + if err := cb.EncodeRLP(&buf); err != nil { + t.Fatal(err) + } + served := buf.Bytes() + + // Build a header whose BlockAccessListHash points at something else, so + // the served BAL fails verification. + mismatch := common.HexToHash("0xdeadbeef") + headers := map[common.Hash]*types.Header{ + hash: {BlockAccessListHash: &mismatch}, + } + + peer := newTestPeer("liar", t, term) + peer.accessLists = map[common.Hash]rlp.RawValue{hash: served} + syncer := setupSyncer(rawdb.HashScheme, peer) + + results, err := syncer.fetchAccessLists(hashes, headers, cancel) + if !errors.Is(err, errAccessListPeersExhausted) { + t.Fatalf("expected errAccessListPeersExhausted, got %v", err) + } + if results != nil { + t.Errorf("expected nil results on error, got %v", results) + } + syncer.lock.RLock() + _, stateless := syncer.statelessPeers[peer.id] + syncer.lock.RUnlock() + if !stateless { + t.Error("expected liar peer to be marked stateless after bad BAL") + } +} + +// TestCatchUpRetriesOnBadBAL verifies that when one peer serves a BAL that +// fails verification but another serves a valid one, fetchAccessLists routes +// the work around the bad peer and returns the verified BAL. +func TestCatchUpRetriesOnBadBAL(t *testing.T) { + t.Parallel() + var ( + once sync.Once + cancel = make(chan struct{}) + term = func() { once.Do(func() { close(cancel) }) } + ) + hash := common.HexToHash("0x01") + hashes := []common.Hash{hash} + + cb := bal.NewConstructionBlockAccessList() + cb.BalanceChange(0, common.HexToAddress("0xaa"), uint256.NewInt(42)) + var buf bytes.Buffer + if err := cb.EncodeRLP(&buf); err != nil { + t.Fatal(err) + } + good := buf.Bytes() + + // A second BAL with different content used as the "bad" payload. It + // decodes cleanly but its hash will not match the header. + other := bal.NewConstructionBlockAccessList() + other.BalanceChange(0, common.HexToAddress("0xbb"), uint256.NewInt(99)) + var otherBuf bytes.Buffer + if err := other.EncodeRLP(&otherBuf); err != nil { + t.Fatal(err) + } + bad := otherBuf.Bytes() + + headers := makeAccessListHeaders(map[common.Hash]rlp.RawValue{hash: good}) + + liar := newTestPeer("liar", t, term) + liar.accessLists = map[common.Hash]rlp.RawValue{hash: bad} + honest := newTestPeer("honest", t, term) + honest.accessLists = map[common.Hash]rlp.RawValue{hash: good} + + syncer := setupSyncer(rawdb.HashScheme, liar, honest) + // Bias the capacity sort so the liar is asked first, exercising the + // reject-and-retry path rather than getting lucky on assignment order. + syncer.rates.Update(liar.id, AccessListsMsg, time.Millisecond, 1000) + + results, err := syncer.fetchAccessLists(hashes, headers, cancel) + if err != nil { + t.Fatalf("fetchAccessLists failed: %v", err) + } + if !bytes.Equal(results[0], good) { + t.Errorf("expected the honest BAL, got %x", results[0]) + } + syncer.lock.RLock() + _, liarStateless := syncer.statelessPeers[liar.id] + _, honestStateless := syncer.statelessPeers[honest.id] + syncer.lock.RUnlock() + if !liarStateless { + t.Error("expected liar to be marked stateless") + } + if honestStateless { + t.Error("expected honest peer to remain in good standing") + } +} + func newDbConfig(scheme string) *triedb.Config { if scheme == rawdb.HashScheme { return &triedb.Config{}