From 00da4f51fffbbfaf2a9b962ba2324629d5f0c45f Mon Sep 17 00:00:00 2001 From: Jonny Rhea <5555162+jrhea@users.noreply.github.com> Date: Fri, 3 Apr 2026 01:10:32 -0500 Subject: [PATCH] core, eth/protocols/snap: Snap/2 Protocol + BAL Serving (#34083) Implement the snap/2 wire protocol with BAL serving --------- Co-authored-by: Gary Rong --- core/blockchain_reader.go | 8 + eth/downloader/downloader_test.go | 2 +- eth/protocols/eth/handler.go | 1 - eth/protocols/snap/handler.go | 571 ++------------------ eth/protocols/snap/handler_fuzzing_test.go | 6 + eth/protocols/snap/handler_test.go | 195 +++++++ eth/protocols/snap/handlers.go | 593 +++++++++++++++++++++ eth/protocols/snap/protocol.go | 25 +- 8 files changed, 874 insertions(+), 527 deletions(-) create mode 100644 eth/protocols/snap/handler_test.go create mode 100644 eth/protocols/snap/handlers.go diff --git a/core/blockchain_reader.go b/core/blockchain_reader.go index f1b40d0d0c..8b026680d2 100644 --- a/core/blockchain_reader.go +++ b/core/blockchain_reader.go @@ -296,6 +296,14 @@ func (bc *BlockChain) GetReceiptsRLP(hash common.Hash) rlp.RawValue { return rawdb.ReadReceiptsRLP(bc.db, hash, number) } +func (bc *BlockChain) GetAccessListRLP(hash common.Hash) rlp.RawValue { + number, ok := rawdb.ReadHeaderNumber(bc.db, hash) + if !ok { + return nil + } + return rawdb.ReadAccessListRLP(bc.db, hash, number) +} + // GetUnclesInChain retrieves all the uncles from a given block backwards until // a specific distance is reached. func (bc *BlockChain) GetUnclesInChain(block *types.Block, length int) []*types.Header { diff --git a/eth/downloader/downloader_test.go b/eth/downloader/downloader_test.go index 01a994dbfd..9280d455fb 100644 --- a/eth/downloader/downloader_test.go +++ b/eth/downloader/downloader_test.go @@ -370,7 +370,7 @@ func (dlp *downloadTesterPeer) RequestTrieNodes(id uint64, root common.Hash, cou Paths: encPaths, Bytes: uint64(bytes), } - nodes, _ := snap.ServiceGetTrieNodesQuery(dlp.chain, req, time.Now()) + nodes, _ := snap.ServiceGetTrieNodesQuery(dlp.chain, req) go dlp.dl.downloader.SnapSyncer.OnTrieNodes(dlp, id, nodes) return nil } diff --git a/eth/protocols/eth/handler.go b/eth/protocols/eth/handler.go index 59512f5be7..f7d25bd8ca 100644 --- a/eth/protocols/eth/handler.go +++ b/eth/protocols/eth/handler.go @@ -167,7 +167,6 @@ func Handle(backend Backend, peer *Peer) error { type msgHandler func(backend Backend, msg Decoder, peer *Peer) error type Decoder interface { Decode(val interface{}) error - Time() time.Time } var eth69 = map[uint64]msgHandler{ diff --git a/eth/protocols/snap/handler.go b/eth/protocols/snap/handler.go index 071a0419fb..26545f2960 100644 --- a/eth/protocols/snap/handler.go +++ b/eth/protocols/snap/handler.go @@ -17,25 +17,14 @@ package snap import ( - "bytes" "fmt" "time" - "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" - "github.com/ethereum/go-ethereum/core/rawdb" - "github.com/ethereum/go-ethereum/core/state/snapshot" - "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enr" - "github.com/ethereum/go-ethereum/p2p/tracker" - "github.com/ethereum/go-ethereum/rlp" - "github.com/ethereum/go-ethereum/trie" - "github.com/ethereum/go-ethereum/trie/trienode" - "github.com/ethereum/go-ethereum/triedb/database" ) const ( @@ -55,6 +44,10 @@ const ( // number is there to limit the number of disk lookups. maxTrieNodeLookups = 1024 + // maxAccessListLookups is the maximum number of BALs to server. This number + // is there to limit the number of disk lookups. + maxAccessListLookups = 1024 + // maxTrieNodeTimeSpent is the maximum time we should spend on looking up trie nodes. // If we spend too much time, then it's a fairly high chance of timing out // at the remote side, which means all the work is in vain. @@ -123,6 +116,34 @@ func Handle(backend Backend, peer *Peer) error { } } +type msgHandler func(backend Backend, msg Decoder, peer *Peer) error +type Decoder interface { + Decode(val interface{}) error +} + +var snap1 = map[uint64]msgHandler{ + GetAccountRangeMsg: handleGetAccountRange, + AccountRangeMsg: handleAccountRange, + GetStorageRangesMsg: handleGetStorageRanges, + StorageRangesMsg: handleStorageRanges, + GetByteCodesMsg: handleGetByteCodes, + ByteCodesMsg: handleByteCodes, + GetTrieNodesMsg: handleGetTrienodes, + TrieNodesMsg: handleTrieNodes, +} + +// nolint:unused +var snap2 = map[uint64]msgHandler{ + GetAccountRangeMsg: handleGetAccountRange, + AccountRangeMsg: handleAccountRange, + GetStorageRangesMsg: handleGetStorageRanges, + StorageRangesMsg: handleStorageRanges, + GetByteCodesMsg: handleGetByteCodes, + ByteCodesMsg: handleByteCodes, + GetAccessListsMsg: handleGetAccessLists, + // AccessListsMsg: TODO +} + // HandleMessage is invoked whenever an inbound message is received from a // remote peer on the `snap` protocol. The remote connection is torn down upon // returning any error. @@ -136,8 +157,19 @@ func HandleMessage(backend Backend, peer *Peer) error { return fmt.Errorf("%w: %v > %v", errMsgTooLarge, msg.Size, maxMessageSize) } defer msg.Discard() - start := time.Now() + + var handlers map[uint64]msgHandler + switch peer.version { + case SNAP1: + handlers = snap1 + //case SNAP2: + // handlers = snap2 + default: + return fmt.Errorf("unknown eth protocol version: %v", peer.version) + } + // Track the amount of time it takes to serve the request and run the handler + start := time.Now() if metrics.Enabled() { h := fmt.Sprintf("%s/%s/%d/%#02x", p2p.HandleHistName, ProtocolName, peer.Version(), msg.Code) defer func(start time.Time) { @@ -149,520 +181,11 @@ func HandleMessage(backend Backend, peer *Peer) error { metrics.GetOrRegisterHistogramLazy(h, nil, sampler).Update(time.Since(start).Microseconds()) }(start) } - // Handle the message depending on its contents - switch { - case msg.Code == GetAccountRangeMsg: - var req GetAccountRangePacket - if err := msg.Decode(&req); err != nil { - return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) - } - // Service the request, potentially returning nothing in case of errors - accounts, proofs := ServiceGetAccountRangeQuery(backend.Chain(), &req) - // Send back anything accumulated (or empty in case of errors) - return p2p.Send(peer.rw, AccountRangeMsg, &AccountRangePacket{ - ID: req.ID, - Accounts: accounts, - Proof: proofs, - }) - - case msg.Code == AccountRangeMsg: - res := new(accountRangeInput) - if err := msg.Decode(res); err != nil { - return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) - } - - // Check response validity. - if len := res.Proof.Len(); len > 128 { - return fmt.Errorf("AccountRange: invalid proof (length %d)", len) - } - tresp := tracker.Response{ID: res.ID, MsgCode: AccountRangeMsg, Size: len(res.Accounts.Content())} - if err := peer.tracker.Fulfil(tresp); err != nil { - return err - } - - // Decode. - accounts, err := res.Accounts.Items() - if err != nil { - return fmt.Errorf("AccountRange: invalid accounts list: %v", err) - } - proof, err := res.Proof.Items() - if err != nil { - return fmt.Errorf("AccountRange: invalid proof: %v", err) - } - - // Ensure the range is monotonically increasing - for i := 1; i < len(accounts); i++ { - if bytes.Compare(accounts[i-1].Hash[:], accounts[i].Hash[:]) >= 0 { - return fmt.Errorf("accounts not monotonically increasing: #%d [%x] vs #%d [%x]", i-1, accounts[i-1].Hash[:], i, accounts[i].Hash[:]) - } - } - - return backend.Handle(peer, &AccountRangePacket{res.ID, accounts, proof}) - - case msg.Code == GetStorageRangesMsg: - var req GetStorageRangesPacket - if err := msg.Decode(&req); err != nil { - return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) - } - // Service the request, potentially returning nothing in case of errors - slots, proofs := ServiceGetStorageRangesQuery(backend.Chain(), &req) - - // Send back anything accumulated (or empty in case of errors) - return p2p.Send(peer.rw, StorageRangesMsg, &StorageRangesPacket{ - ID: req.ID, - Slots: slots, - Proof: proofs, - }) - - case msg.Code == StorageRangesMsg: - res := new(storageRangesInput) - if err := msg.Decode(res); err != nil { - return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) - } - - // Check response validity. - if len := res.Proof.Len(); len > 128 { - return fmt.Errorf("StorageRangesMsg: invalid proof (length %d)", len) - } - tresp := tracker.Response{ID: res.ID, MsgCode: StorageRangesMsg, Size: len(res.Slots.Content())} - if err := peer.tracker.Fulfil(tresp); err != nil { - return fmt.Errorf("StorageRangesMsg: %w", err) - } - - // Decode. - slotLists, err := res.Slots.Items() - if err != nil { - return fmt.Errorf("AccountRange: invalid accounts list: %v", err) - } - proof, err := res.Proof.Items() - if err != nil { - return fmt.Errorf("AccountRange: invalid proof: %v", err) - } - - // Ensure the ranges are monotonically increasing - for i, slots := range slotLists { - for j := 1; j < len(slots); j++ { - if bytes.Compare(slots[j-1].Hash[:], slots[j].Hash[:]) >= 0 { - return fmt.Errorf("storage slots not monotonically increasing for account #%d: #%d [%x] vs #%d [%x]", i, j-1, slots[j-1].Hash[:], j, slots[j].Hash[:]) - } - } - } - - return backend.Handle(peer, &StorageRangesPacket{res.ID, slotLists, proof}) - - case msg.Code == GetByteCodesMsg: - var req GetByteCodesPacket - if err := msg.Decode(&req); err != nil { - return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) - } - // Service the request, potentially returning nothing in case of errors - codes := ServiceGetByteCodesQuery(backend.Chain(), &req) - - // Send back anything accumulated (or empty in case of errors) - return p2p.Send(peer.rw, ByteCodesMsg, &ByteCodesPacket{ - ID: req.ID, - Codes: codes, - }) - - case msg.Code == ByteCodesMsg: - res := new(byteCodesInput) - if err := msg.Decode(res); err != nil { - return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) - } - - length := res.Codes.Len() - tresp := tracker.Response{ID: res.ID, MsgCode: ByteCodesMsg, Size: length} - if err := peer.tracker.Fulfil(tresp); err != nil { - return fmt.Errorf("ByteCodes: %w", err) - } - - codes, err := res.Codes.Items() - if err != nil { - return fmt.Errorf("ByteCodes: %w", err) - } - - return backend.Handle(peer, &ByteCodesPacket{res.ID, codes}) - - case msg.Code == GetTrieNodesMsg: - var req GetTrieNodesPacket - if err := msg.Decode(&req); err != nil { - return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) - } - // Service the request, potentially returning nothing in case of errors - nodes, err := ServiceGetTrieNodesQuery(backend.Chain(), &req, start) - if err != nil { - return err - } - // Send back anything accumulated (or empty in case of errors) - return p2p.Send(peer.rw, TrieNodesMsg, &TrieNodesPacket{ - ID: req.ID, - Nodes: nodes, - }) - - case msg.Code == TrieNodesMsg: - res := new(trieNodesInput) - if err := msg.Decode(res); err != nil { - return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) - } - - tresp := tracker.Response{ID: res.ID, MsgCode: TrieNodesMsg, Size: res.Nodes.Len()} - if err := peer.tracker.Fulfil(tresp); err != nil { - return fmt.Errorf("TrieNodes: %w", err) - } - nodes, err := res.Nodes.Items() - if err != nil { - return fmt.Errorf("TrieNodes: %w", err) - } - - return backend.Handle(peer, &TrieNodesPacket{res.ID, nodes}) - - default: - return fmt.Errorf("%w: %v", errInvalidMsgCode, msg.Code) + if handler := handlers[msg.Code]; handler != nil { + return handler(backend, msg, peer) } -} - -// ServiceGetAccountRangeQuery assembles the response to an account range query. -// It is exposed to allow external packages to test protocol behavior. -func ServiceGetAccountRangeQuery(chain *core.BlockChain, req *GetAccountRangePacket) ([]*AccountData, [][]byte) { - if req.Bytes > softResponseLimit { - req.Bytes = softResponseLimit - } - // Retrieve the requested state and bail out if non existent - tr, err := trie.New(trie.StateTrieID(req.Root), chain.TrieDB()) - if err != nil { - return nil, nil - } - // Temporary solution: using the snapshot interface for both cases. - // This can be removed once the hash scheme is deprecated. - var it snapshot.AccountIterator - if chain.TrieDB().Scheme() == rawdb.HashScheme { - // The snapshot is assumed to be available in hash mode if - // the SNAP protocol is enabled. - it, err = chain.Snapshots().AccountIterator(req.Root, req.Origin) - } else { - it, err = chain.TrieDB().AccountIterator(req.Root, req.Origin) - } - if err != nil { - return nil, nil - } - // Iterate over the requested range and pile accounts up - var ( - accounts []*AccountData - size uint64 - last common.Hash - ) - for it.Next() { - hash, account := it.Hash(), common.CopyBytes(it.Account()) - - // Track the returned interval for the Merkle proofs - last = hash - - // Assemble the reply item - size += uint64(common.HashLength + len(account)) - accounts = append(accounts, &AccountData{ - Hash: hash, - Body: account, - }) - // If we've exceeded the request threshold, abort - if bytes.Compare(hash[:], req.Limit[:]) >= 0 { - break - } - if size > req.Bytes { - break - } - } - it.Release() - - // Generate the Merkle proofs for the first and last account - proof := trienode.NewProofSet() - if err := tr.Prove(req.Origin[:], proof); err != nil { - log.Warn("Failed to prove account range", "origin", req.Origin, "err", err) - return nil, nil - } - if last != (common.Hash{}) { - if err := tr.Prove(last[:], proof); err != nil { - log.Warn("Failed to prove account range", "last", last, "err", err) - return nil, nil - } - } - return accounts, proof.List() -} - -func ServiceGetStorageRangesQuery(chain *core.BlockChain, req *GetStorageRangesPacket) ([][]*StorageData, [][]byte) { - if req.Bytes > softResponseLimit { - req.Bytes = softResponseLimit - } - // TODO(karalabe): Do we want to enforce > 0 accounts and 1 account if origin is set? - // TODO(karalabe): - Logging locally is not ideal as remote faults annoy the local user - // TODO(karalabe): - Dropping the remote peer is less flexible wrt client bugs (slow is better than non-functional) - - // Calculate the hard limit at which to abort, even if mid storage trie - hardLimit := uint64(float64(req.Bytes) * (1 + stateLookupSlack)) - - // Retrieve storage ranges until the packet limit is reached - var ( - slots [][]*StorageData - proofs [][]byte - size uint64 - ) - for _, account := range req.Accounts { - // If we've exceeded the requested data limit, abort without opening - // a new storage range (that we'd need to prove due to exceeded size) - if size >= req.Bytes { - break - } - // The first account might start from a different origin and end sooner - var origin common.Hash - if len(req.Origin) > 0 { - origin, req.Origin = common.BytesToHash(req.Origin), nil - } - var limit = common.MaxHash - if len(req.Limit) > 0 { - limit, req.Limit = common.BytesToHash(req.Limit), nil - } - // Retrieve the requested state and bail out if non existent - var ( - err error - it snapshot.StorageIterator - ) - // Temporary solution: using the snapshot interface for both cases. - // This can be removed once the hash scheme is deprecated. - if chain.TrieDB().Scheme() == rawdb.HashScheme { - // The snapshot is assumed to be available in hash mode if - // the SNAP protocol is enabled. - it, err = chain.Snapshots().StorageIterator(req.Root, account, origin) - } else { - it, err = chain.TrieDB().StorageIterator(req.Root, account, origin) - } - if err != nil { - return nil, nil - } - // Iterate over the requested range and pile slots up - var ( - storage []*StorageData - last common.Hash - abort bool - ) - for it.Next() { - if size >= hardLimit { - abort = true - break - } - hash, slot := it.Hash(), common.CopyBytes(it.Slot()) - - // Track the returned interval for the Merkle proofs - last = hash - - // Assemble the reply item - size += uint64(common.HashLength + len(slot)) - storage = append(storage, &StorageData{ - Hash: hash, - Body: slot, - }) - // If we've exceeded the request threshold, abort - if bytes.Compare(hash[:], limit[:]) >= 0 { - break - } - } - if len(storage) > 0 { - slots = append(slots, storage) - } - it.Release() - - // Generate the Merkle proofs for the first and last storage slot, but - // only if the response was capped. If the entire storage trie included - // in the response, no need for any proofs. - if origin != (common.Hash{}) || (abort && len(storage) > 0) { - // Request started at a non-zero hash or was capped prematurely, add - // the endpoint Merkle proofs - accTrie, err := trie.NewStateTrie(trie.StateTrieID(req.Root), chain.TrieDB()) - if err != nil { - return nil, nil - } - acc, err := accTrie.GetAccountByHash(account) - if err != nil || acc == nil { - return nil, nil - } - id := trie.StorageTrieID(req.Root, account, acc.Root) - stTrie, err := trie.NewStateTrie(id, chain.TrieDB()) - if err != nil { - return nil, nil - } - proof := trienode.NewProofSet() - if err := stTrie.Prove(origin[:], proof); err != nil { - log.Warn("Failed to prove storage range", "origin", req.Origin, "err", err) - return nil, nil - } - if last != (common.Hash{}) { - if err := stTrie.Prove(last[:], proof); err != nil { - log.Warn("Failed to prove storage range", "last", last, "err", err) - return nil, nil - } - } - proofs = append(proofs, proof.List()...) - // Proof terminates the reply as proofs are only added if a node - // refuses to serve more data (exception when a contract fetch is - // finishing, but that's that). - break - } - } - return slots, proofs -} - -// ServiceGetByteCodesQuery assembles the response to a byte codes query. -// It is exposed to allow external packages to test protocol behavior. -func ServiceGetByteCodesQuery(chain *core.BlockChain, req *GetByteCodesPacket) [][]byte { - if req.Bytes > softResponseLimit { - req.Bytes = softResponseLimit - } - if len(req.Hashes) > maxCodeLookups { - req.Hashes = req.Hashes[:maxCodeLookups] - } - // Retrieve bytecodes until the packet size limit is reached - var ( - codes [][]byte - bytes uint64 - ) - for _, hash := range req.Hashes { - if hash == types.EmptyCodeHash { - // Peers should not request the empty code, but if they do, at - // least sent them back a correct response without db lookups - codes = append(codes, []byte{}) - } else if blob := chain.ContractCodeWithPrefix(hash); len(blob) > 0 { - codes = append(codes, blob) - bytes += uint64(len(blob)) - } - if bytes > req.Bytes { - break - } - } - return codes -} - -// ServiceGetTrieNodesQuery assembles the response to a trie nodes query. -// It is exposed to allow external packages to test protocol behavior. -func ServiceGetTrieNodesQuery(chain *core.BlockChain, req *GetTrieNodesPacket, start time.Time) ([][]byte, error) { - if req.Bytes > softResponseLimit { - req.Bytes = softResponseLimit - } - // Make sure we have the state associated with the request - triedb := chain.TrieDB() - - accTrie, err := trie.NewStateTrie(trie.StateTrieID(req.Root), triedb) - if err != nil { - // We don't have the requested state available, bail out - return nil, nil - } - // The 'reader' might be nil, in which case we cannot serve storage slots - // via snapshot. - var reader database.StateReader - if chain.Snapshots() != nil { - reader = chain.Snapshots().Snapshot(req.Root) - } - if reader == nil { - reader, _ = triedb.StateReader(req.Root) - } - - // Retrieve trie nodes until the packet size limit is reached - var ( - outerIt = req.Paths.ContentIterator() - nodes [][]byte - bytes uint64 - loads int // Trie hash expansions to count database reads - ) - for outerIt.Next() { - innerIt, err := rlp.NewListIterator(outerIt.Value()) - if err != nil { - return nodes, err - } - - switch innerIt.Count() { - case 0: - // Ensure we penalize invalid requests - return nil, fmt.Errorf("%w: zero-item pathset requested", errBadRequest) - - case 1: - // If we're only retrieving an account trie node, fetch it directly - accKey := nextBytes(&innerIt) - if accKey == nil { - return nodes, fmt.Errorf("%w: invalid account node request", errBadRequest) - } - blob, resolved, err := accTrie.GetNode(accKey) - loads += resolved // always account database reads, even for failures - if err != nil { - break - } - nodes = append(nodes, blob) - bytes += uint64(len(blob)) - - default: - // Storage slots requested, open the storage trie and retrieve from there - accKey := nextBytes(&innerIt) - if accKey == nil { - return nodes, fmt.Errorf("%w: invalid account storage request", errBadRequest) - } - var stRoot common.Hash - if reader == nil { - // We don't have the requested state snapshotted yet (or it is stale), - // but can look up the account via the trie instead. - account, err := accTrie.GetAccountByHash(common.BytesToHash(accKey)) - loads += 8 // We don't know the exact cost of lookup, this is an estimate - if err != nil || account == nil { - break - } - stRoot = account.Root - } else { - account, err := reader.Account(common.BytesToHash(accKey)) - loads++ // always account database reads, even for failures - if err != nil || account == nil { - break - } - stRoot = common.BytesToHash(account.Root) - } - - id := trie.StorageTrieID(req.Root, common.BytesToHash(accKey), stRoot) - stTrie, err := trie.NewStateTrie(id, triedb) - loads++ // always account database reads, even for failures - if err != nil { - break - } - for innerIt.Next() { - path, _, err := rlp.SplitString(innerIt.Value()) - if err != nil { - return nil, fmt.Errorf("%w: invalid storage key: %v", errBadRequest, err) - } - blob, resolved, err := stTrie.GetNode(path) - loads += resolved // always account database reads, even for failures - if err != nil { - break - } - nodes = append(nodes, blob) - bytes += uint64(len(blob)) - - // Sanity check limits to avoid DoS on the store trie loads - if bytes > req.Bytes || loads > maxTrieNodeLookups || time.Since(start) > maxTrieNodeTimeSpent { - break - } - } - } - // Abort request processing if we've exceeded our limits - if bytes > req.Bytes || loads > maxTrieNodeLookups || time.Since(start) > maxTrieNodeTimeSpent { - break - } - } - return nodes, nil -} - -func nextBytes(it *rlp.Iterator) []byte { - if !it.Next() { - return nil - } - content, _, err := rlp.SplitString(it.Value()) - if err != nil { - return nil - } - return content + return fmt.Errorf("%w: %v", errInvalidMsgCode, msg.Code) } // NodeInfo represents a short summary of the `snap` sub-protocol metadata diff --git a/eth/protocols/snap/handler_fuzzing_test.go b/eth/protocols/snap/handler_fuzzing_test.go index 4930ae9ae6..a52da0aac5 100644 --- a/eth/protocols/snap/handler_fuzzing_test.go +++ b/eth/protocols/snap/handler_fuzzing_test.go @@ -60,6 +60,12 @@ func FuzzTrieNodes(f *testing.F) { }) } +func FuzzAccessLists(f *testing.F) { + f.Fuzz(func(t *testing.T, data []byte) { + doFuzz(data, &GetAccessListsPacket{}, GetAccessListsMsg) + }) +} + func doFuzz(input []byte, obj interface{}, code int) { bc := getChain() defer bc.Stop() diff --git a/eth/protocols/snap/handler_test.go b/eth/protocols/snap/handler_test.go new file mode 100644 index 0000000000..cb4b378a8d --- /dev/null +++ b/eth/protocols/snap/handler_test.go @@ -0,0 +1,195 @@ +// Copyright 2026 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package snap + +import ( + "bytes" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/consensus/ethash" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/rlp" +) + +// getChainWithBALs creates a minimal test chain with BALs stored for each block. +// It returns the chain, block hashes, and the stored BAL data. +func getChainWithBALs(nBlocks int, balSize int) (*core.BlockChain, []common.Hash, []rlp.RawValue) { + gspec := &core.Genesis{ + Config: params.TestChainConfig, + } + db := rawdb.NewMemoryDatabase() + _, blocks, _ := core.GenerateChainWithGenesis(gspec, ethash.NewFaker(), nBlocks, func(i int, gen *core.BlockGen) {}) + options := &core.BlockChainConfig{ + TrieCleanLimit: 0, + TrieDirtyLimit: 0, + TrieTimeLimit: 5 * time.Minute, + NoPrefetch: true, + SnapshotLimit: 0, + } + bc, err := core.NewBlockChain(db, gspec, ethash.NewFaker(), options) + if err != nil { + panic(err) + } + if _, err := bc.InsertChain(blocks); err != nil { + panic(err) + } + + // Store BALs for each block + var hashes []common.Hash + var bals []rlp.RawValue + for _, block := range blocks { + hash := block.Hash() + number := block.NumberU64() + bal := make(rlp.RawValue, balSize) + + // Fill with data based on block number + for j := range bal { + bal[j] = byte(number + uint64(j)) + } + rawdb.WriteAccessListRLP(db, hash, number, bal) + hashes = append(hashes, hash) + bals = append(bals, bal) + } + return bc, hashes, bals +} + +// TestServiceGetAccessListsQuery verifies that known block hashes return the +// correct BALs with positional correspondence. +func TestServiceGetAccessListsQuery(t *testing.T) { + t.Parallel() + bc, hashes, bals := getChainWithBALs(5, 100) + defer bc.Stop() + req := &GetAccessListsPacket{ + ID: 1, + Hashes: hashes, + } + result := ServiceGetAccessListsQuery(bc, req) + + // Verify the results + if len(result) != len(hashes) { + t.Fatalf("expected %d results, got %d", len(hashes), len(result)) + } + for i, bal := range result { + if !bytes.Equal(bal, bals[i]) { + t.Errorf("BAL %d mismatch: got %x, want %x", i, bal, bals[i]) + } + } +} + +// TestServiceGetAccessListsQueryEmpty verifies that unknown block hashes return +// nil placeholders and that mixed known/unknown hashes preserve alignment. +func TestServiceGetAccessListsQueryEmpty(t *testing.T) { + t.Parallel() + bc, hashes, bals := getChainWithBALs(3, 100) + defer bc.Stop() + unknown := common.HexToHash("0xdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef") + mixed := []common.Hash{hashes[0], unknown, hashes[1], unknown, hashes[2]} + req := &GetAccessListsPacket{ + ID: 2, + Hashes: mixed, + } + result := ServiceGetAccessListsQuery(bc, req) + + // Verify length + if len(result) != len(mixed) { + t.Fatalf("expected %d results, got %d", len(mixed), len(result)) + } + + // Check positional correspondence + if !bytes.Equal(result[0], bals[0]) { + t.Errorf("index 0: expected known BAL, got %x", result[0]) + } + if result[1] != nil { + t.Errorf("index 1: expected nil for unknown hash, got %x", result[1]) + } + if !bytes.Equal(result[2], bals[1]) { + t.Errorf("index 2: expected known BAL, got %x", result[2]) + } + if result[3] != nil { + t.Errorf("index 3: expected nil for unknown hash, got %x", result[3]) + } + if !bytes.Equal(result[4], bals[2]) { + t.Errorf("index 4: expected known BAL, got %x", result[4]) + } +} + +// TestServiceGetAccessListsQueryCap verifies that requests exceeding +// maxAccessListLookups are capped. +func TestServiceGetAccessListsQueryCap(t *testing.T) { + t.Parallel() + + bc, _, _ := getChainWithBALs(2, 100) + defer bc.Stop() + + // Create a request with more hashes than the cap + hashes := make([]common.Hash, maxAccessListLookups+100) + for i := range hashes { + hashes[i] = common.BytesToHash([]byte{byte(i), byte(i >> 8)}) + } + req := &GetAccessListsPacket{ + ID: 3, + Hashes: hashes, + } + result := ServiceGetAccessListsQuery(bc, req) + + // Can't get more than maxAccessListLookups results + if len(result) > maxAccessListLookups { + t.Fatalf("expected at most %d results, got %d", maxAccessListLookups, len(result)) + } +} + +// TestServiceGetAccessListsQueryByteLimit verifies that the response stops +// once the byte limit is exceeded. The handler appends the entry that crosses +// the limit before breaking, so the total size will exceed the limit by at +// most one BAL. +func TestServiceGetAccessListsQueryByteLimit(t *testing.T) { + t.Parallel() + + // The handler will return 3/5 entries (3MB total) then break. + balSize := 1024 * 1024 + nBlocks := 5 + bc, hashes, _ := getChainWithBALs(nBlocks, balSize) + defer bc.Stop() + req := &GetAccessListsPacket{ + ID: 0, + Hashes: hashes, + } + result := ServiceGetAccessListsQuery(bc, req) + + // Should have stopped before returning all blocks + if len(result) >= nBlocks { + t.Fatalf("expected fewer than %d results due to byte limit, got %d", nBlocks, len(result)) + } + + // Should have returned at least one + if len(result) == 0 { + t.Fatal("expected at least one result") + } + + // The total size should exceed the limit (the entry that crosses it is included) + var total uint64 + for _, bal := range result { + total += uint64(len(bal)) + } + if total <= softResponseLimit { + t.Errorf("total response size %d should exceed soft limit %d (includes one entry past limit)", total, softResponseLimit) + } +} diff --git a/eth/protocols/snap/handlers.go b/eth/protocols/snap/handlers.go new file mode 100644 index 0000000000..64522343f9 --- /dev/null +++ b/eth/protocols/snap/handlers.go @@ -0,0 +1,593 @@ +// Copyright 2026 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see + +package snap + +import ( + "bytes" + "fmt" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/state/snapshot" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/tracker" + "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/trie" + "github.com/ethereum/go-ethereum/trie/trienode" + "github.com/ethereum/go-ethereum/triedb/database" +) + +func handleGetAccountRange(backend Backend, msg Decoder, peer *Peer) error { + var req GetAccountRangePacket + if err := msg.Decode(&req); err != nil { + return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) + } + // Service the request, potentially returning nothing in case of errors + accounts, proofs := ServiceGetAccountRangeQuery(backend.Chain(), &req) + + // Send back anything accumulated (or empty in case of errors) + return p2p.Send(peer.rw, AccountRangeMsg, &AccountRangePacket{ + ID: req.ID, + Accounts: accounts, + Proof: proofs, + }) +} + +// ServiceGetAccountRangeQuery assembles the response to an account range query. +// It is exposed to allow external packages to test protocol behavior. +func ServiceGetAccountRangeQuery(chain *core.BlockChain, req *GetAccountRangePacket) ([]*AccountData, [][]byte) { + if req.Bytes > softResponseLimit { + req.Bytes = softResponseLimit + } + // Retrieve the requested state and bail out if non existent + tr, err := trie.New(trie.StateTrieID(req.Root), chain.TrieDB()) + if err != nil { + return nil, nil + } + // Temporary solution: using the snapshot interface for both cases. + // This can be removed once the hash scheme is deprecated. + var it snapshot.AccountIterator + if chain.TrieDB().Scheme() == rawdb.HashScheme { + // The snapshot is assumed to be available in hash mode if + // the SNAP protocol is enabled. + it, err = chain.Snapshots().AccountIterator(req.Root, req.Origin) + } else { + it, err = chain.TrieDB().AccountIterator(req.Root, req.Origin) + } + if err != nil { + return nil, nil + } + // Iterate over the requested range and pile accounts up + var ( + accounts []*AccountData + size uint64 + last common.Hash + ) + for it.Next() { + hash, account := it.Hash(), common.CopyBytes(it.Account()) + + // Track the returned interval for the Merkle proofs + last = hash + + // Assemble the reply item + size += uint64(common.HashLength + len(account)) + accounts = append(accounts, &AccountData{ + Hash: hash, + Body: account, + }) + // If we've exceeded the request threshold, abort + if bytes.Compare(hash[:], req.Limit[:]) >= 0 { + break + } + if size > req.Bytes { + break + } + } + it.Release() + + // Generate the Merkle proofs for the first and last account + proof := trienode.NewProofSet() + if err := tr.Prove(req.Origin[:], proof); err != nil { + log.Warn("Failed to prove account range", "origin", req.Origin, "err", err) + return nil, nil + } + if last != (common.Hash{}) { + if err := tr.Prove(last[:], proof); err != nil { + log.Warn("Failed to prove account range", "last", last, "err", err) + return nil, nil + } + } + return accounts, proof.List() +} + +func handleAccountRange(backend Backend, msg Decoder, peer *Peer) error { + res := new(accountRangeInput) + if err := msg.Decode(res); err != nil { + return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) + } + + // Check response validity. + if len := res.Proof.Len(); len > 128 { + return fmt.Errorf("AccountRange: invalid proof (length %d)", len) + } + tresp := tracker.Response{ID: res.ID, MsgCode: AccountRangeMsg, Size: len(res.Accounts.Content())} + if err := peer.tracker.Fulfil(tresp); err != nil { + return err + } + + // Decode. + accounts, err := res.Accounts.Items() + if err != nil { + return fmt.Errorf("AccountRange: invalid accounts list: %v", err) + } + proof, err := res.Proof.Items() + if err != nil { + return fmt.Errorf("AccountRange: invalid proof: %v", err) + } + + // Ensure the range is monotonically increasing + for i := 1; i < len(accounts); i++ { + if bytes.Compare(accounts[i-1].Hash[:], accounts[i].Hash[:]) >= 0 { + return fmt.Errorf("accounts not monotonically increasing: #%d [%x] vs #%d [%x]", i-1, accounts[i-1].Hash[:], i, accounts[i].Hash[:]) + } + } + + return backend.Handle(peer, &AccountRangePacket{res.ID, accounts, proof}) +} + +func handleGetStorageRanges(backend Backend, msg Decoder, peer *Peer) error { + var req GetStorageRangesPacket + if err := msg.Decode(&req); err != nil { + return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) + } + // Service the request, potentially returning nothing in case of errors + slots, proofs := ServiceGetStorageRangesQuery(backend.Chain(), &req) + + // Send back anything accumulated (or empty in case of errors) + return p2p.Send(peer.rw, StorageRangesMsg, &StorageRangesPacket{ + ID: req.ID, + Slots: slots, + Proof: proofs, + }) +} + +func ServiceGetStorageRangesQuery(chain *core.BlockChain, req *GetStorageRangesPacket) ([][]*StorageData, [][]byte) { + if req.Bytes > softResponseLimit { + req.Bytes = softResponseLimit + } + // TODO(karalabe): Do we want to enforce > 0 accounts and 1 account if origin is set? + // TODO(karalabe): - Logging locally is not ideal as remote faults annoy the local user + // TODO(karalabe): - Dropping the remote peer is less flexible wrt client bugs (slow is better than non-functional) + + // Calculate the hard limit at which to abort, even if mid storage trie + hardLimit := uint64(float64(req.Bytes) * (1 + stateLookupSlack)) + + // Retrieve storage ranges until the packet limit is reached + var ( + slots [][]*StorageData + proofs [][]byte + size uint64 + ) + for _, account := range req.Accounts { + // If we've exceeded the requested data limit, abort without opening + // a new storage range (that we'd need to prove due to exceeded size) + if size >= req.Bytes { + break + } + // The first account might start from a different origin and end sooner + var origin common.Hash + if len(req.Origin) > 0 { + origin, req.Origin = common.BytesToHash(req.Origin), nil + } + var limit = common.MaxHash + if len(req.Limit) > 0 { + limit, req.Limit = common.BytesToHash(req.Limit), nil + } + // Retrieve the requested state and bail out if non existent + var ( + err error + it snapshot.StorageIterator + ) + // Temporary solution: using the snapshot interface for both cases. + // This can be removed once the hash scheme is deprecated. + if chain.TrieDB().Scheme() == rawdb.HashScheme { + // The snapshot is assumed to be available in hash mode if + // the SNAP protocol is enabled. + it, err = chain.Snapshots().StorageIterator(req.Root, account, origin) + } else { + it, err = chain.TrieDB().StorageIterator(req.Root, account, origin) + } + if err != nil { + return nil, nil + } + // Iterate over the requested range and pile slots up + var ( + storage []*StorageData + last common.Hash + abort bool + ) + for it.Next() { + if size >= hardLimit { + abort = true + break + } + hash, slot := it.Hash(), common.CopyBytes(it.Slot()) + + // Track the returned interval for the Merkle proofs + last = hash + + // Assemble the reply item + size += uint64(common.HashLength + len(slot)) + storage = append(storage, &StorageData{ + Hash: hash, + Body: slot, + }) + // If we've exceeded the request threshold, abort + if bytes.Compare(hash[:], limit[:]) >= 0 { + break + } + } + if len(storage) > 0 { + slots = append(slots, storage) + } + it.Release() + + // Generate the Merkle proofs for the first and last storage slot, but + // only if the response was capped. If the entire storage trie included + // in the response, no need for any proofs. + if origin != (common.Hash{}) || (abort && len(storage) > 0) { + // Request started at a non-zero hash or was capped prematurely, add + // the endpoint Merkle proofs + accTrie, err := trie.NewStateTrie(trie.StateTrieID(req.Root), chain.TrieDB()) + if err != nil { + return nil, nil + } + acc, err := accTrie.GetAccountByHash(account) + if err != nil || acc == nil { + return nil, nil + } + id := trie.StorageTrieID(req.Root, account, acc.Root) + stTrie, err := trie.NewStateTrie(id, chain.TrieDB()) + if err != nil { + return nil, nil + } + proof := trienode.NewProofSet() + if err := stTrie.Prove(origin[:], proof); err != nil { + log.Warn("Failed to prove storage range", "origin", req.Origin, "err", err) + return nil, nil + } + if last != (common.Hash{}) { + if err := stTrie.Prove(last[:], proof); err != nil { + log.Warn("Failed to prove storage range", "last", last, "err", err) + return nil, nil + } + } + proofs = append(proofs, proof.List()...) + // Proof terminates the reply as proofs are only added if a node + // refuses to serve more data (exception when a contract fetch is + // finishing, but that's that). + break + } + } + return slots, proofs +} + +func handleStorageRanges(backend Backend, msg Decoder, peer *Peer) error { + res := new(storageRangesInput) + if err := msg.Decode(res); err != nil { + return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) + } + + // Check response validity. + if len := res.Proof.Len(); len > 128 { + return fmt.Errorf("StorageRangesMsg: invalid proof (length %d)", len) + } + tresp := tracker.Response{ID: res.ID, MsgCode: StorageRangesMsg, Size: len(res.Slots.Content())} + if err := peer.tracker.Fulfil(tresp); err != nil { + return fmt.Errorf("StorageRangesMsg: %w", err) + } + + // Decode. + slotLists, err := res.Slots.Items() + if err != nil { + return fmt.Errorf("AccountRange: invalid accounts list: %v", err) + } + proof, err := res.Proof.Items() + if err != nil { + return fmt.Errorf("AccountRange: invalid proof: %v", err) + } + + // Ensure the ranges are monotonically increasing + for i, slots := range slotLists { + for j := 1; j < len(slots); j++ { + if bytes.Compare(slots[j-1].Hash[:], slots[j].Hash[:]) >= 0 { + return fmt.Errorf("storage slots not monotonically increasing for account #%d: #%d [%x] vs #%d [%x]", i, j-1, slots[j-1].Hash[:], j, slots[j].Hash[:]) + } + } + } + + return backend.Handle(peer, &StorageRangesPacket{res.ID, slotLists, proof}) +} + +func handleGetByteCodes(backend Backend, msg Decoder, peer *Peer) error { + var req GetByteCodesPacket + if err := msg.Decode(&req); err != nil { + return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) + } + // Service the request, potentially returning nothing in case of errors + codes := ServiceGetByteCodesQuery(backend.Chain(), &req) + + // Send back anything accumulated (or empty in case of errors) + return p2p.Send(peer.rw, ByteCodesMsg, &ByteCodesPacket{ + ID: req.ID, + Codes: codes, + }) +} + +// ServiceGetByteCodesQuery assembles the response to a byte codes query. +// It is exposed to allow external packages to test protocol behavior. +func ServiceGetByteCodesQuery(chain *core.BlockChain, req *GetByteCodesPacket) [][]byte { + if req.Bytes > softResponseLimit { + req.Bytes = softResponseLimit + } + if len(req.Hashes) > maxCodeLookups { + req.Hashes = req.Hashes[:maxCodeLookups] + } + // Retrieve bytecodes until the packet size limit is reached + var ( + codes [][]byte + bytes uint64 + ) + for _, hash := range req.Hashes { + if hash == types.EmptyCodeHash { + // Peers should not request the empty code, but if they do, at + // least sent them back a correct response without db lookups + codes = append(codes, []byte{}) + } else if blob := chain.ContractCodeWithPrefix(hash); len(blob) > 0 { + codes = append(codes, blob) + bytes += uint64(len(blob)) + } + if bytes > req.Bytes { + break + } + } + return codes +} + +func handleByteCodes(backend Backend, msg Decoder, peer *Peer) error { + res := new(byteCodesInput) + if err := msg.Decode(res); err != nil { + return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) + } + + length := res.Codes.Len() + tresp := tracker.Response{ID: res.ID, MsgCode: ByteCodesMsg, Size: length} + if err := peer.tracker.Fulfil(tresp); err != nil { + return fmt.Errorf("ByteCodes: %w", err) + } + + codes, err := res.Codes.Items() + if err != nil { + return fmt.Errorf("ByteCodes: %w", err) + } + + return backend.Handle(peer, &ByteCodesPacket{res.ID, codes}) +} + +func handleGetTrienodes(backend Backend, msg Decoder, peer *Peer) error { + var req GetTrieNodesPacket + if err := msg.Decode(&req); err != nil { + return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) + } + // Service the request, potentially returning nothing in case of errors + nodes, err := ServiceGetTrieNodesQuery(backend.Chain(), &req) + if err != nil { + return err + } + // Send back anything accumulated (or empty in case of errors) + return p2p.Send(peer.rw, TrieNodesMsg, &TrieNodesPacket{ + ID: req.ID, + Nodes: nodes, + }) +} + +func nextBytes(it *rlp.Iterator) []byte { + if !it.Next() { + return nil + } + content, _, err := rlp.SplitString(it.Value()) + if err != nil { + return nil + } + return content +} + +// ServiceGetTrieNodesQuery assembles the response to a trie nodes query. +// It is exposed to allow external packages to test protocol behavior. +func ServiceGetTrieNodesQuery(chain *core.BlockChain, req *GetTrieNodesPacket) ([][]byte, error) { + start := time.Now() + if req.Bytes > softResponseLimit { + req.Bytes = softResponseLimit + } + // Make sure we have the state associated with the request + triedb := chain.TrieDB() + + accTrie, err := trie.NewStateTrie(trie.StateTrieID(req.Root), triedb) + if err != nil { + // We don't have the requested state available, bail out + return nil, nil + } + // The 'reader' might be nil, in which case we cannot serve storage slots + // via snapshot. + var reader database.StateReader + if chain.Snapshots() != nil { + reader = chain.Snapshots().Snapshot(req.Root) + } + if reader == nil { + reader, _ = triedb.StateReader(req.Root) + } + + // Retrieve trie nodes until the packet size limit is reached + var ( + outerIt = req.Paths.ContentIterator() + nodes [][]byte + bytes uint64 + loads int // Trie hash expansions to count database reads + ) + for outerIt.Next() { + innerIt, err := rlp.NewListIterator(outerIt.Value()) + if err != nil { + return nodes, err + } + + switch innerIt.Count() { + case 0: + // Ensure we penalize invalid requests + return nil, fmt.Errorf("%w: zero-item pathset requested", errBadRequest) + + case 1: + // If we're only retrieving an account trie node, fetch it directly + accKey := nextBytes(&innerIt) + if accKey == nil { + return nodes, fmt.Errorf("%w: invalid account node request", errBadRequest) + } + blob, resolved, err := accTrie.GetNode(accKey) + loads += resolved // always account database reads, even for failures + if err != nil { + break + } + nodes = append(nodes, blob) + bytes += uint64(len(blob)) + + default: + // Storage slots requested, open the storage trie and retrieve from there + accKey := nextBytes(&innerIt) + if accKey == nil { + return nodes, fmt.Errorf("%w: invalid account storage request", errBadRequest) + } + var stRoot common.Hash + if reader == nil { + // We don't have the requested state snapshotted yet (or it is stale), + // but can look up the account via the trie instead. + account, err := accTrie.GetAccountByHash(common.BytesToHash(accKey)) + loads += 8 // We don't know the exact cost of lookup, this is an estimate + if err != nil || account == nil { + break + } + stRoot = account.Root + } else { + account, err := reader.Account(common.BytesToHash(accKey)) + loads++ // always account database reads, even for failures + if err != nil || account == nil { + break + } + stRoot = common.BytesToHash(account.Root) + } + + id := trie.StorageTrieID(req.Root, common.BytesToHash(accKey), stRoot) + stTrie, err := trie.NewStateTrie(id, triedb) + loads++ // always account database reads, even for failures + if err != nil { + break + } + for innerIt.Next() { + path, _, err := rlp.SplitString(innerIt.Value()) + if err != nil { + return nil, fmt.Errorf("%w: invalid storage key: %v", errBadRequest, err) + } + blob, resolved, err := stTrie.GetNode(path) + loads += resolved // always account database reads, even for failures + if err != nil { + break + } + nodes = append(nodes, blob) + bytes += uint64(len(blob)) + + // Sanity check limits to avoid DoS on the store trie loads + if bytes > req.Bytes || loads > maxTrieNodeLookups || time.Since(start) > maxTrieNodeTimeSpent { + break + } + } + } + // Abort request processing if we've exceeded our limits + if bytes > req.Bytes || loads > maxTrieNodeLookups || time.Since(start) > maxTrieNodeTimeSpent { + break + } + } + return nodes, nil +} + +func handleTrieNodes(backend Backend, msg Decoder, peer *Peer) error { + res := new(trieNodesInput) + if err := msg.Decode(res); err != nil { + return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) + } + + tresp := tracker.Response{ID: res.ID, MsgCode: TrieNodesMsg, Size: res.Nodes.Len()} + if err := peer.tracker.Fulfil(tresp); err != nil { + return fmt.Errorf("TrieNodes: %w", err) + } + nodes, err := res.Nodes.Items() + if err != nil { + return fmt.Errorf("TrieNodes: %w", err) + } + + return backend.Handle(peer, &TrieNodesPacket{res.ID, nodes}) +} + +// nolint:unused +func handleGetAccessLists(backend Backend, msg Decoder, peer *Peer) error { + var req GetAccessListsPacket + if err := msg.Decode(&req); err != nil { + return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) + } + bals := ServiceGetAccessListsQuery(backend.Chain(), &req) + return p2p.Send(peer.rw, AccessListsMsg, &AccessListsPacket{ + ID: req.ID, + AccessLists: bals, + }) +} + +// ServiceGetAccessListsQuery assembles the response to an access list query. +// It is exposed to allow external packages to test protocol behavior. +func ServiceGetAccessListsQuery(chain *core.BlockChain, req *GetAccessListsPacket) []rlp.RawValue { + // Cap the number of lookups + if len(req.Hashes) > maxAccessListLookups { + req.Hashes = req.Hashes[:maxAccessListLookups] + } + var ( + bals []rlp.RawValue + bytes uint64 + ) + for _, hash := range req.Hashes { + if bal := chain.GetAccessListRLP(hash); len(bal) > 0 { + bals = append(bals, bal) + bytes += uint64(len(bal)) + } else { + // Either the block is unknown or the BAL doesn't exist + bals = append(bals, nil) + } + if bytes > softResponseLimit { + break + } + } + return bals +} diff --git a/eth/protocols/snap/protocol.go b/eth/protocols/snap/protocol.go index 25fe25822b..57b29bbe36 100644 --- a/eth/protocols/snap/protocol.go +++ b/eth/protocols/snap/protocol.go @@ -28,6 +28,7 @@ import ( // Constants to match up protocol versions and messages const ( SNAP1 = 1 + //SNAP2 = 2 ) // ProtocolName is the official short name of the `snap` protocol used during @@ -40,7 +41,7 @@ var ProtocolVersions = []uint{SNAP1} // protocolLengths are the number of implemented message corresponding to // different protocol versions. -var protocolLengths = map[uint]uint64{SNAP1: 8} +var protocolLengths = map[uint]uint64{ /*SNAP2: 10,*/ SNAP1: 8} // maxMessageSize is the maximum cap on the size of a protocol message. const maxMessageSize = 10 * 1024 * 1024 @@ -54,6 +55,8 @@ const ( ByteCodesMsg = 0x05 GetTrieNodesMsg = 0x06 TrieNodesMsg = 0x07 + GetAccessListsMsg = 0x08 + AccessListsMsg = 0x09 ) var ( @@ -215,6 +218,20 @@ type TrieNodesPacket struct { Nodes [][]byte // Requested state trie nodes } +// GetAccessListsPacket requests BALs for a set of block hashes. +type GetAccessListsPacket struct { + ID uint64 // Request ID to match up responses with + Hashes []common.Hash // Block hashes to retrieve BALs for +} + +// AccessListsPacket is the response to GetAccessListsPacket. +// Each entry corresponds to the requested hash at the same index. +// Empty entries indicate the BAL is unavailable. +type AccessListsPacket struct { + ID uint64 // ID of the request this is a response for + AccessLists []rlp.RawValue // Requested BALs +} + func (*GetAccountRangePacket) Name() string { return "GetAccountRange" } func (*GetAccountRangePacket) Kind() byte { return GetAccountRangeMsg } @@ -238,3 +255,9 @@ func (*GetTrieNodesPacket) Kind() byte { return GetTrieNodesMsg } func (*TrieNodesPacket) Name() string { return "TrieNodes" } func (*TrieNodesPacket) Kind() byte { return TrieNodesMsg } + +func (*GetAccessListsPacket) Name() string { return "GetAccessLists" } +func (*GetAccessListsPacket) Kind() byte { return GetAccessListsMsg } + +func (*AccessListsPacket) Name() string { return "AccessLists" } +func (*AccessListsPacket) Kind() byte { return AccessListsMsg }