diff --git a/core/blockchain.go b/core/blockchain.go index 5164f9a1e5..5494b0345d 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -1182,7 +1182,7 @@ func (bc *BlockChain) SnapSyncComplete(hash common.Hash) error { // Set up the snapshot tree from the synced flat state. Snap/2 downloads // flat state directly as the snapshot. if bc.snaps != nil { - bc.snaps.RebuildFromSyncedState(root) + bc.snaps.InitFromSyncedState(root) } // If all checks out, manually set the head block. diff --git a/core/state/snapshot/snapshot.go b/core/state/snapshot/snapshot.go index 4a69ec3f49..91695c7eb3 100644 --- a/core/state/snapshot/snapshot.go +++ b/core/state/snapshot/snapshot.go @@ -727,17 +727,25 @@ func (t *Tree) Rebuild(root common.Hash) { } } -// RebuildFromSyncedState sets up the snapshot tree to use flat state that was +// InitFromSyncedState sets up the snapshot tree to use flat state that was // already downloaded by snap sync. Unlike Rebuild, it does NOT regenerate the // snapshot from the trie. -func (t *Tree) RebuildFromSyncedState(root common.Hash) { +func (t *Tree) InitFromSyncedState(root common.Hash) { t.lock.Lock() defer t.lock.Unlock() + + // Delete any recovery flag in the database. rawdb.DeleteSnapshotRecoveryNumber(t.diskdb) rawdb.DeleteSnapshotDisabled(t.diskdb) + + // Write the new root. rawdb.WriteSnapshotRoot(t.diskdb, root) + + // Clear the journal. journalProgress(t.diskdb, nil, nil) log.Info("Setting up snapshot from synced state", "root", root) + + // Replace t.layers with a single diskLayer pointing at the root. t.layers = map[common.Hash]snapshot{ root: &diskLayer{ diskdb: t.diskdb, diff --git a/eth/downloader/downloader.go b/eth/downloader/downloader.go index 11da4d0bb6..e291042e53 100644 --- a/eth/downloader/downloader.go +++ b/eth/downloader/downloader.go @@ -868,24 +868,31 @@ func (d *Downloader) importBlockResults(results []*fetchResult) error { } // checkDeepReorg checks if the old pivot block was reorged by comparing its -// state root against the current canonical chain. If the canonical header at -// the old pivot's block number has a different state root, the syncer's flat -// state is from the old fork and must be wiped. Returns true if a deep reorg -// was detected. +// state root against the current canonical chain. Returns true if the +// canonical header at the old pivot's block number has a different state root, +// meaning the syncer's flat state is from the old fork and must be wiped. // -// Returns false (no reorg) when the canonical hash or header is missing. This -// avoids false positives from pruned or not-yet-downloaded data. If the chain +// Returns false conservatively when canonical data is missing. If the chain // really did shorten past the old pivot, sync.catchUp's from > to guard will // catch this. func checkDeepReorg(db ethdb.Database, oldNumber uint64, oldRoot common.Hash) bool { + // No canonical hash at the old pivot height. This could mean the chain was + // reorged to a shorter fork, or that headers for this height haven't been + // downloaded yet. Can't tell the two apart here, so don't wipe. oldHash := rawdb.ReadCanonicalHash(db, oldNumber) if oldHash == (common.Hash{}) { return false } + + // Canonical hash exists but the header is missing (pruned or corrupted). + // Nothing to compare against, so don't wipe. oldHeader := rawdb.ReadHeader(db, oldHash, oldNumber) if oldHeader == nil { return false } + + // Canonical root at this height differs from what we were syncing against — + // the old pivot was reorged out. return oldHeader.Root != oldRoot } diff --git a/eth/protocols/snap/bal_apply.go b/eth/protocols/snap/bal_apply.go index 5ec9d21420..5bb388f9ba 100644 --- a/eth/protocols/snap/bal_apply.go +++ b/eth/protocols/snap/bal_apply.go @@ -48,6 +48,7 @@ func verifyAccessList(b *bal.BlockAccessList, header *types.Header) error { func (s *Syncer) applyAccessList(b *bal.BlockAccessList) error { batch := s.db.NewBatch() + // Iterate over all accounts in the access list for _, access := range b.Accesses { addr := common.Address(access.Address) accountHash := crypto.Keccak256Hash(addr[:]) diff --git a/eth/protocols/snap/handler.go b/eth/protocols/snap/handler.go index 26545f2960..14463757bc 100644 --- a/eth/protocols/snap/handler.go +++ b/eth/protocols/snap/handler.go @@ -141,7 +141,7 @@ var snap2 = map[uint64]msgHandler{ GetByteCodesMsg: handleGetByteCodes, ByteCodesMsg: handleByteCodes, GetAccessListsMsg: handleGetAccessLists, - // AccessListsMsg: TODO + AccessListsMsg: handleAccessLists, } // HandleMessage is invoked whenever an inbound message is received from a diff --git a/eth/protocols/snap/sync.go b/eth/protocols/snap/sync.go index 5537c441e8..3c33709fca 100644 --- a/eth/protocols/snap/sync.go +++ b/eth/protocols/snap/sync.go @@ -520,13 +520,13 @@ func (s *Syncer) Sync(root common.Hash, number uint64, cancel chan struct{}) err log.Info("Starting state download", "root", root) for { // Download: fetch all required state data - err := s.download(cancel) + err := s.downloadState(cancel) if err == errPivotStale { // Pivot moved: catch up to new pivot if err := s.catchUp(cancel); err != nil { return err } - s.resetDownload(root, number) + s.resetDownloadState(root, number) log.Info("Resuming state download", "root", root) continue } @@ -558,7 +558,7 @@ func (s *Syncer) Sync(root common.Hash, number uint64, cancel chan struct{}) err // download runs the bulk flat-state download. It fetches // account ranges, storage slots, and bytecodes, writing flat state to disk. -func (s *Syncer) download(cancel chan struct{}) error { +func (s *Syncer) downloadState(cancel chan struct{}) error { // If the pivot moved since the last run (downloader cancelled and restarted // us with a new root), signal catch-up before downloading. if s.previousRoot != s.root { @@ -638,14 +638,14 @@ func (s *Syncer) download(cancel chan struct{}) error { } } -// resetDownload resets the download state for a new pivot after catch-up. +// resetDownloadState resets the download state for a new pivot after catch-up. // It regenerates the task list for accounts not yet downloaded, clears // in-flight requests, and updates the root. -func (s *Syncer) resetDownload(root common.Hash, number uint64) { +func (s *Syncer) resetDownloadState(root common.Hash, number uint64) { s.lock.Lock() s.root = root s.number = number - s.previousRoot = root // Prevent download() from returning errPivotStale again + s.previousRoot = root // Prevent downloadState() from returning errPivotStale again s.previousNumber = number // Clear stateless peers bc they may be able to serve the new pivot @@ -662,16 +662,13 @@ func (s *Syncer) catchUp(cancel chan struct{}) error { to := s.number s.lock.RUnlock() - // The new pivot must be ahead of the old one. This can fail if a reorg - // replaced the block at the pivot height (same number, different root) - // or if a deep reorg shortened the chain past the old pivot. In either - // case, catch-up can't roll forward, so wipe progress and return an - // error so the caller restarts with a fresh sync. - // - // Note: this check lives here rather than in checkDeepReorg because - // catchUp is reached both when the downloader actively moves the pivot - // (via restartSnapSync) and when the syncer resumes from persisted - // progress after a restart. checkDeepReorg only covers the former. + // The new pivot must be ahead of the old one. The range is inverted if + // a reorg replaced the block at the pivot height (same number, different + // root) or if the chain shortened past the old pivot. In either case, + // catch-up can't roll forward — wipe progress and restart. This also + // catches reorgs missed by checkDeepReorg, which only runs when the + // downloader actively restarts the syncer, not on resume from persisted + // progress. if from > to { log.Warn("Catch-up range inverted, wiping sync progress", "from", from, "to", to) rawdb.WriteSnapshotSyncStatus(s.db, nil) @@ -745,8 +742,6 @@ func (s *Syncer) fetchAccessLists(hashes []common.Hash, cancel chan struct{}) ([ pending[h] = struct{}{} } fetched := make(map[common.Hash]rlp.RawValue, len(hashes)) - - // Create ephemeral channels for this fetch cycle var ( accessListReqFails = make(chan *accessListRequest) accessListResps = make(chan *accessListResponse) @@ -785,6 +780,7 @@ func (s *Syncer) fetchAccessLists(hashes []common.Hash, cancel chan struct{}) ([ s.processAccessListResponse(res, pending, fetched) } } + // Assemble results in input order results := make([]rlp.RawValue, len(hashes)) for i, h := range hashes { @@ -869,15 +865,19 @@ func (s *Syncer) assignAccessListTasks(pending map[common.Hash]struct{}, success // processAccessListResponse handles a successful access list response by // matching results to pending hashes and storing them. func (s *Syncer) processAccessListResponse(res *accessListResponse, pending map[common.Hash]struct{}, fetched map[common.Hash]rlp.RawValue) { - // Each response entry corresponds to the requested hash at the same index + // Each response entry corresponds to the requested hash at the same index. for i, raw := range res.accessLists { - if i >= len(res.req.hashes) { - break - } h := res.req.hashes[i] + + // Peer doesn't have this BAL. Add it back to pending for retry. + if bytes.Equal(raw, rlp.EmptyString) { + pending[h] = struct{}{} + continue + } fetched[h] = raw delete(pending, h) } + // Re-add hashes that were not served back to pending for i := len(res.accessLists); i < len(res.req.hashes); i++ { pending[res.req.hashes[i]] = struct{}{} @@ -2388,6 +2388,12 @@ func (s *Syncer) OnAccessLists(peer SyncPeer, id uint64, accessLists rlp.RawList s.scheduleRevertAccessListRequest(req) return nil } + if len(bals) > len(req.hashes) { + s.lock.Unlock() + s.scheduleRevertAccessListRequest(req) + logger.Warn("Peer sent more BALs than requested", "count", len(bals), "requested", len(req.hashes)) + return errors.New("more BALs than requested") + } s.lock.Unlock() // Response validated, send it to the scheduler for filling. diff --git a/eth/protocols/snap/sync_test.go b/eth/protocols/snap/sync_test.go index 813360fc61..7ebcd2e3c3 100644 --- a/eth/protocols/snap/sync_test.go +++ b/eth/protocols/snap/sync_test.go @@ -1909,7 +1909,7 @@ func TestSlotEstimation(t *testing.T) { // TestPivotMoveDetection verifies that when the syncer is restarted with a // different root (simulating the downloader's cancel+restart on pivot move), -// download() returns errPivotStale immediately. +// downloadState() returns errPivotStale immediately. func TestPivotMoveDetection(t *testing.T) { t.Parallel() @@ -1937,45 +1937,14 @@ func TestPivotMoveDetection(t *testing.T) { if syncer.root != rootB { t.Fatalf("root mismatch: got %v, want %v", syncer.root, rootB) } - // download() should detect the mismatch and return errPivotStale + // downloadState() should detect the mismatch and return errPivotStale cancel := make(chan struct{}) - err := syncer.download(cancel) + err := syncer.downloadState(cancel) if err != errPivotStale { t.Fatalf("expected errPivotStale, got %v", err) } } -// TestNoPivotMoveOnSameRoot verifies that when the syncer is restarted with -// the same root, download() does not return errPivotStale. -func TestNoPivotMoveOnSameRoot(t *testing.T) { - t.Parallel() - - rootA := common.HexToHash("0xaaaa") - - db := rawdb.NewMemoryDatabase() - syncer := NewSyncer(db, rawdb.HashScheme) - - // Simulate a previous sync run against rootA - syncer.root = rootA - syncer.tasks = []*accountTask{ - {Next: common.Hash{}, Last: common.MaxHash, SubTasks: make(map[common.Hash][]*storageTask), stateCompleted: make(map[common.Hash]struct{})}, - } - syncer.saveSyncStatus() - - // Simulate restart with the same root - syncer.root = rootA - syncer.previousRoot = rootA - syncer.loadSyncStatus() - - if syncer.previousRoot != rootA { - t.Fatalf("previousRoot mismatch: got %v, want %v", syncer.previousRoot, rootA) - } - // previousRoot == root, so no pivot move detected - if syncer.previousRoot != syncer.root { - t.Fatalf("expected previousRoot == root, got %v != %v", syncer.previousRoot, syncer.root) - } -} - // TestCatchUpInvertedRange verifies that catchUp returns an error and wipes // sync progress when the new pivot is at the same (or lower) block number as // the old pivot.. @@ -2003,51 +1972,6 @@ func TestCatchUpInvertedRange(t *testing.T) { } } -// TestFlatStateDownload verifies that download() writes flat state to disk -// and makes no trie node requests. -func TestFlatStateDownload(t *testing.T) { - t.Parallel() - testFlatStateDownload(t, rawdb.HashScheme) - testFlatStateDownload(t, rawdb.PathScheme) -} - -func testFlatStateDownload(t *testing.T, scheme string) { - var ( - once sync.Once - cancel = make(chan struct{}) - term = func() { - once.Do(func() { - close(cancel) - }) - } - ) - nodeScheme, sourceAccountTrie, elems := makeAccountTrieNoStorage(100, scheme) - mkSource := func(name string) *testPeer { - source := newTestPeer(name, t, term) - source.accountTrie = sourceAccountTrie.Copy() - source.accountValues = elems - return source - } - syncer := setupSyncer(nodeScheme, mkSource("source")) - - // Call download() directly to avoid rebuildTrie - syncer.root = sourceAccountTrie.Hash() - syncer.previousRoot = syncer.root // No pivot move - syncer.loadSyncStatus() - if err := syncer.download(cancel); err != nil { - t.Fatalf("download failed: %v", err) - } - - // Verify flat state was written - for _, entry := range elems { - hash := common.BytesToHash(entry.k) - data := rawdb.ReadAccountSnapshot(syncer.db, hash) - if len(data) == 0 { - t.Errorf("missing account snapshot for %x", hash) - } - } -} - // TestInterruptedDownloadRecovery verifies that partially completed download // state is persisted and resumed on restart. func TestInterruptedDownloadRecovery(t *testing.T) { @@ -2086,7 +2010,7 @@ func testInterruptedDownloadRecovery(t *testing.T, scheme string) { syncer1.root = root syncer1.previousRoot = root syncer1.loadSyncStatus() - syncer1.download(cancel1) + syncer1.downloadState(cancel1) // Save progress for _, task := range syncer1.tasks { @@ -2123,7 +2047,7 @@ func testInterruptedDownloadRecovery(t *testing.T, scheme string) { syncer2.root = root syncer2.previousRoot = root syncer2.loadSyncStatus() - if err := syncer2.download(cancel2); err != nil { + if err := syncer2.downloadState(cancel2); err != nil { t.Fatalf("resumed download failed: %v", err) } @@ -2333,7 +2257,7 @@ func TestInterruptedRebuildRecovery(t *testing.T) { root := sourceAccountTrie.Hash() // First run: complete download, save status, simulate interruption - // before rebuild by calling download() directly + // before rebuild by calling downloadState() directly var ( once1 sync.Once cancel1 = make(chan struct{}) @@ -2350,7 +2274,7 @@ func TestInterruptedRebuildRecovery(t *testing.T) { syncer1.previousRoot = root syncer1.loadSyncStatus() - if err := syncer1.download(cancel1); err != nil { + if err := syncer1.downloadState(cancel1); err != nil { t.Fatalf("download failed: %v", err) } // Save status (simulating what Sync's defer does) @@ -2386,50 +2310,6 @@ func TestInterruptedRebuildRecovery(t *testing.T) { } } -// TestFetchAccessListsSinglePeer verifies fetching BALs from a single peer. -func TestFetchAccessListsSinglePeer(t *testing.T) { - t.Parallel() - var ( - once sync.Once - cancel = make(chan struct{}) - term = func() { once.Do(func() { close(cancel) }) } - ) - - // Create test BALs - hashes := []common.Hash{ - common.HexToHash("0x01"), - common.HexToHash("0x02"), - common.HexToHash("0x03"), - } - bals := make(map[common.Hash]rlp.RawValue) - for _, h := range hashes { - cb := bal.NewConstructionBlockAccessList() - cb.BalanceChange(0, common.HexToAddress("0xaa"), uint256.NewInt(uint64(h[31]))) - var buf bytes.Buffer - if err := cb.EncodeRLP(&buf); err != nil { - t.Fatal(err) - } - bals[h] = buf.Bytes() - } - source := newTestPeer("source", t, term) - source.accessLists = bals - syncer := setupSyncer(rawdb.HashScheme, source) - results, err := syncer.fetchAccessLists(hashes, cancel) - if err != nil { - t.Fatalf("fetchAccessLists failed: %v", err) - } - if len(results) != len(hashes) { - t.Fatalf("result count mismatch: got %d, want %d", len(results), len(hashes)) - } - - // Verify results match input order - for i, h := range hashes { - if !bytes.Equal(results[i], bals[h]) { - t.Errorf("result %d mismatch", i) - } - } -} - // TestFetchAccessListsMultiplePeers verifies that fetch distributes work // across multiple idle peers. func TestFetchAccessListsMultiplePeers(t *testing.T) { @@ -2467,6 +2347,12 @@ func TestFetchAccessListsMultiplePeers(t *testing.T) { if len(results) != len(hashes) { t.Fatalf("result count mismatch: got %d, want %d", len(results), len(hashes)) } + // Verify results match expected content in request order + for i, h := range hashes { + if !bytes.Equal(results[i], bals[h]) { + t.Errorf("result %d content mismatch for hash %v", i, h) + } + } } // TestFetchAccessListsPeerTimeout verifies that timed-out requests are retried @@ -2700,6 +2586,91 @@ func TestFetchAccessListsShortResponse(t *testing.T) { } } +// TestFetchAccessListsEmptyPlaceholder verifies that when a peer returns +// rlp.EmptyString placeholders for BALs it doesn't have, those placeholders +// are not silently accepted as valid results. +func TestFetchAccessListsEmptyPlaceholder(t *testing.T) { + t.Parallel() + var ( + once sync.Once + cancel = make(chan struct{}) + term = func() { once.Do(func() { close(cancel) }) } + ) + hashes := []common.Hash{ + common.HexToHash("0x01"), + common.HexToHash("0x02"), + common.HexToHash("0x03"), + } + + // Build BALs for all 3 hashes + allBALs := make(map[common.Hash]rlp.RawValue) + for _, h := range hashes { + cb := bal.NewConstructionBlockAccessList() + cb.BalanceChange(0, common.HexToAddress("0xaa"), uint256.NewInt(uint64(h[31]))) + var buf bytes.Buffer + if err := cb.EncodeRLP(&buf); err != nil { + t.Fatal(err) + } + allBALs[h] = buf.Bytes() + } + + // partialPeer has BALs for hashes 0 and 2. The server + // handler returns rlp.EmptyString for the missing BAL. + partialPeer := newTestPeer("partial", t, term) + partialPeer.accessListRequestHandler = func(tp *testPeer, id uint64, reqHashes []common.Hash, max int) error { + var results []rlp.RawValue + for _, h := range reqHashes { + if raw, ok := allBALs[h]; ok && h != hashes[1] { + results = append(results, raw) + } else { + results = append(results, rlp.EmptyString) + } + } + rawList, _ := rlp.EncodeToRawList(results) + if err := tp.remote.OnAccessLists(tp, id, rawList); err != nil { + tp.test.Errorf("delivery rejected: %v", err) + tp.term() + } + return nil + } + + // fullPeer has all BALs + fullPeer := newTestPeer("full", t, term) + fullPeer.accessLists = allBALs + syncer := setupSyncer(rawdb.HashScheme, partialPeer, fullPeer) + + // Pre-seed capacity so partialPeer gets all 3 hashes + syncer.rates.Update(partialPeer.id, AccessListsMsg, time.Millisecond, 100) + done := make(chan struct{}) + var ( + results []rlp.RawValue + fetchErr error + ) + go func() { + results, fetchErr = syncer.fetchAccessLists(hashes, cancel) + close(done) + }() + + // Wait for fetch to complete + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("fetchAccessLists hung") + } + if fetchErr != nil { + t.Fatalf("fetchAccessLists failed: %v", fetchErr) + } + + // Verify the results are valid. + for i, raw := range results { + var accessList bal.BlockAccessList + if err := rlp.DecodeBytes(raw, &accessList); err != nil { + t.Errorf("result %d (hash %v) is not a valid BAL: %v (got raw bytes %x)", + i, hashes[i], err, raw) + } + } +} + func newDbConfig(scheme string) *triedb.Config { if scheme == rawdb.HashScheme { return &triedb.Config{}