diff --git a/core/blockchain_reader.go b/core/blockchain_reader.go
index f1b40d0d0c..8b026680d2 100644
--- a/core/blockchain_reader.go
+++ b/core/blockchain_reader.go
@@ -296,6 +296,14 @@ func (bc *BlockChain) GetReceiptsRLP(hash common.Hash) rlp.RawValue {
return rawdb.ReadReceiptsRLP(bc.db, hash, number)
}
+func (bc *BlockChain) GetAccessListRLP(hash common.Hash) rlp.RawValue {
+ number, ok := rawdb.ReadHeaderNumber(bc.db, hash)
+ if !ok {
+ return nil
+ }
+ return rawdb.ReadAccessListRLP(bc.db, hash, number)
+}
+
// GetUnclesInChain retrieves all the uncles from a given block backwards until
// a specific distance is reached.
func (bc *BlockChain) GetUnclesInChain(block *types.Block, length int) []*types.Header {
diff --git a/eth/downloader/downloader_test.go b/eth/downloader/downloader_test.go
index 01a994dbfd..9280d455fb 100644
--- a/eth/downloader/downloader_test.go
+++ b/eth/downloader/downloader_test.go
@@ -370,7 +370,7 @@ func (dlp *downloadTesterPeer) RequestTrieNodes(id uint64, root common.Hash, cou
Paths: encPaths,
Bytes: uint64(bytes),
}
- nodes, _ := snap.ServiceGetTrieNodesQuery(dlp.chain, req, time.Now())
+ nodes, _ := snap.ServiceGetTrieNodesQuery(dlp.chain, req)
go dlp.dl.downloader.SnapSyncer.OnTrieNodes(dlp, id, nodes)
return nil
}
diff --git a/eth/protocols/eth/handler.go b/eth/protocols/eth/handler.go
index 59512f5be7..f7d25bd8ca 100644
--- a/eth/protocols/eth/handler.go
+++ b/eth/protocols/eth/handler.go
@@ -167,7 +167,6 @@ func Handle(backend Backend, peer *Peer) error {
type msgHandler func(backend Backend, msg Decoder, peer *Peer) error
type Decoder interface {
Decode(val interface{}) error
- Time() time.Time
}
var eth69 = map[uint64]msgHandler{
diff --git a/eth/protocols/snap/handler.go b/eth/protocols/snap/handler.go
index 071a0419fb..26545f2960 100644
--- a/eth/protocols/snap/handler.go
+++ b/eth/protocols/snap/handler.go
@@ -17,25 +17,14 @@
package snap
import (
- "bytes"
"fmt"
"time"
- "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core"
- "github.com/ethereum/go-ethereum/core/rawdb"
- "github.com/ethereum/go-ethereum/core/state/snapshot"
- "github.com/ethereum/go-ethereum/core/types"
- "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/metrics"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr"
- "github.com/ethereum/go-ethereum/p2p/tracker"
- "github.com/ethereum/go-ethereum/rlp"
- "github.com/ethereum/go-ethereum/trie"
- "github.com/ethereum/go-ethereum/trie/trienode"
- "github.com/ethereum/go-ethereum/triedb/database"
)
const (
@@ -55,6 +44,10 @@ const (
// number is there to limit the number of disk lookups.
maxTrieNodeLookups = 1024
+	// maxAccessListLookups is the maximum number of BALs to serve. This number
+ // is there to limit the number of disk lookups.
+ maxAccessListLookups = 1024
+
// maxTrieNodeTimeSpent is the maximum time we should spend on looking up trie nodes.
// If we spend too much time, then it's a fairly high chance of timing out
// at the remote side, which means all the work is in vain.
@@ -123,6 +116,34 @@ func Handle(backend Backend, peer *Peer) error {
}
}
+type msgHandler func(backend Backend, msg Decoder, peer *Peer) error
+type Decoder interface {
+ Decode(val interface{}) error
+}
+
+var snap1 = map[uint64]msgHandler{
+ GetAccountRangeMsg: handleGetAccountRange,
+ AccountRangeMsg: handleAccountRange,
+ GetStorageRangesMsg: handleGetStorageRanges,
+ StorageRangesMsg: handleStorageRanges,
+ GetByteCodesMsg: handleGetByteCodes,
+ ByteCodesMsg: handleByteCodes,
+ GetTrieNodesMsg: handleGetTrienodes,
+ TrieNodesMsg: handleTrieNodes,
+}
+
+// nolint:unused
+var snap2 = map[uint64]msgHandler{
+ GetAccountRangeMsg: handleGetAccountRange,
+ AccountRangeMsg: handleAccountRange,
+ GetStorageRangesMsg: handleGetStorageRanges,
+ StorageRangesMsg: handleStorageRanges,
+ GetByteCodesMsg: handleGetByteCodes,
+ ByteCodesMsg: handleByteCodes,
+ GetAccessListsMsg: handleGetAccessLists,
+ // AccessListsMsg: TODO
+}
+
// HandleMessage is invoked whenever an inbound message is received from a
// remote peer on the `snap` protocol. The remote connection is torn down upon
// returning any error.
@@ -136,8 +157,19 @@ func HandleMessage(backend Backend, peer *Peer) error {
return fmt.Errorf("%w: %v > %v", errMsgTooLarge, msg.Size, maxMessageSize)
}
defer msg.Discard()
- start := time.Now()
+
+ var handlers map[uint64]msgHandler
+ switch peer.version {
+ case SNAP1:
+ handlers = snap1
+ //case SNAP2:
+ // handlers = snap2
+ default:
+		return fmt.Errorf("unknown snap protocol version: %v", peer.version)
+ }
+
// Track the amount of time it takes to serve the request and run the handler
+ start := time.Now()
if metrics.Enabled() {
h := fmt.Sprintf("%s/%s/%d/%#02x", p2p.HandleHistName, ProtocolName, peer.Version(), msg.Code)
defer func(start time.Time) {
@@ -149,520 +181,11 @@ func HandleMessage(backend Backend, peer *Peer) error {
metrics.GetOrRegisterHistogramLazy(h, nil, sampler).Update(time.Since(start).Microseconds())
}(start)
}
- // Handle the message depending on its contents
- switch {
- case msg.Code == GetAccountRangeMsg:
- var req GetAccountRangePacket
- if err := msg.Decode(&req); err != nil {
- return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
- }
- // Service the request, potentially returning nothing in case of errors
- accounts, proofs := ServiceGetAccountRangeQuery(backend.Chain(), &req)
- // Send back anything accumulated (or empty in case of errors)
- return p2p.Send(peer.rw, AccountRangeMsg, &AccountRangePacket{
- ID: req.ID,
- Accounts: accounts,
- Proof: proofs,
- })
-
- case msg.Code == AccountRangeMsg:
- res := new(accountRangeInput)
- if err := msg.Decode(res); err != nil {
- return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
- }
-
- // Check response validity.
- if len := res.Proof.Len(); len > 128 {
- return fmt.Errorf("AccountRange: invalid proof (length %d)", len)
- }
- tresp := tracker.Response{ID: res.ID, MsgCode: AccountRangeMsg, Size: len(res.Accounts.Content())}
- if err := peer.tracker.Fulfil(tresp); err != nil {
- return err
- }
-
- // Decode.
- accounts, err := res.Accounts.Items()
- if err != nil {
- return fmt.Errorf("AccountRange: invalid accounts list: %v", err)
- }
- proof, err := res.Proof.Items()
- if err != nil {
- return fmt.Errorf("AccountRange: invalid proof: %v", err)
- }
-
- // Ensure the range is monotonically increasing
- for i := 1; i < len(accounts); i++ {
- if bytes.Compare(accounts[i-1].Hash[:], accounts[i].Hash[:]) >= 0 {
- return fmt.Errorf("accounts not monotonically increasing: #%d [%x] vs #%d [%x]", i-1, accounts[i-1].Hash[:], i, accounts[i].Hash[:])
- }
- }
-
- return backend.Handle(peer, &AccountRangePacket{res.ID, accounts, proof})
-
- case msg.Code == GetStorageRangesMsg:
- var req GetStorageRangesPacket
- if err := msg.Decode(&req); err != nil {
- return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
- }
- // Service the request, potentially returning nothing in case of errors
- slots, proofs := ServiceGetStorageRangesQuery(backend.Chain(), &req)
-
- // Send back anything accumulated (or empty in case of errors)
- return p2p.Send(peer.rw, StorageRangesMsg, &StorageRangesPacket{
- ID: req.ID,
- Slots: slots,
- Proof: proofs,
- })
-
- case msg.Code == StorageRangesMsg:
- res := new(storageRangesInput)
- if err := msg.Decode(res); err != nil {
- return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
- }
-
- // Check response validity.
- if len := res.Proof.Len(); len > 128 {
- return fmt.Errorf("StorageRangesMsg: invalid proof (length %d)", len)
- }
- tresp := tracker.Response{ID: res.ID, MsgCode: StorageRangesMsg, Size: len(res.Slots.Content())}
- if err := peer.tracker.Fulfil(tresp); err != nil {
- return fmt.Errorf("StorageRangesMsg: %w", err)
- }
-
- // Decode.
- slotLists, err := res.Slots.Items()
- if err != nil {
- return fmt.Errorf("AccountRange: invalid accounts list: %v", err)
- }
- proof, err := res.Proof.Items()
- if err != nil {
- return fmt.Errorf("AccountRange: invalid proof: %v", err)
- }
-
- // Ensure the ranges are monotonically increasing
- for i, slots := range slotLists {
- for j := 1; j < len(slots); j++ {
- if bytes.Compare(slots[j-1].Hash[:], slots[j].Hash[:]) >= 0 {
- return fmt.Errorf("storage slots not monotonically increasing for account #%d: #%d [%x] vs #%d [%x]", i, j-1, slots[j-1].Hash[:], j, slots[j].Hash[:])
- }
- }
- }
-
- return backend.Handle(peer, &StorageRangesPacket{res.ID, slotLists, proof})
-
- case msg.Code == GetByteCodesMsg:
- var req GetByteCodesPacket
- if err := msg.Decode(&req); err != nil {
- return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
- }
- // Service the request, potentially returning nothing in case of errors
- codes := ServiceGetByteCodesQuery(backend.Chain(), &req)
-
- // Send back anything accumulated (or empty in case of errors)
- return p2p.Send(peer.rw, ByteCodesMsg, &ByteCodesPacket{
- ID: req.ID,
- Codes: codes,
- })
-
- case msg.Code == ByteCodesMsg:
- res := new(byteCodesInput)
- if err := msg.Decode(res); err != nil {
- return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
- }
-
- length := res.Codes.Len()
- tresp := tracker.Response{ID: res.ID, MsgCode: ByteCodesMsg, Size: length}
- if err := peer.tracker.Fulfil(tresp); err != nil {
- return fmt.Errorf("ByteCodes: %w", err)
- }
-
- codes, err := res.Codes.Items()
- if err != nil {
- return fmt.Errorf("ByteCodes: %w", err)
- }
-
- return backend.Handle(peer, &ByteCodesPacket{res.ID, codes})
-
- case msg.Code == GetTrieNodesMsg:
- var req GetTrieNodesPacket
- if err := msg.Decode(&req); err != nil {
- return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
- }
- // Service the request, potentially returning nothing in case of errors
- nodes, err := ServiceGetTrieNodesQuery(backend.Chain(), &req, start)
- if err != nil {
- return err
- }
- // Send back anything accumulated (or empty in case of errors)
- return p2p.Send(peer.rw, TrieNodesMsg, &TrieNodesPacket{
- ID: req.ID,
- Nodes: nodes,
- })
-
- case msg.Code == TrieNodesMsg:
- res := new(trieNodesInput)
- if err := msg.Decode(res); err != nil {
- return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
- }
-
- tresp := tracker.Response{ID: res.ID, MsgCode: TrieNodesMsg, Size: res.Nodes.Len()}
- if err := peer.tracker.Fulfil(tresp); err != nil {
- return fmt.Errorf("TrieNodes: %w", err)
- }
- nodes, err := res.Nodes.Items()
- if err != nil {
- return fmt.Errorf("TrieNodes: %w", err)
- }
-
- return backend.Handle(peer, &TrieNodesPacket{res.ID, nodes})
-
- default:
- return fmt.Errorf("%w: %v", errInvalidMsgCode, msg.Code)
+ if handler := handlers[msg.Code]; handler != nil {
+ return handler(backend, msg, peer)
}
-}
-
-// ServiceGetAccountRangeQuery assembles the response to an account range query.
-// It is exposed to allow external packages to test protocol behavior.
-func ServiceGetAccountRangeQuery(chain *core.BlockChain, req *GetAccountRangePacket) ([]*AccountData, [][]byte) {
- if req.Bytes > softResponseLimit {
- req.Bytes = softResponseLimit
- }
- // Retrieve the requested state and bail out if non existent
- tr, err := trie.New(trie.StateTrieID(req.Root), chain.TrieDB())
- if err != nil {
- return nil, nil
- }
- // Temporary solution: using the snapshot interface for both cases.
- // This can be removed once the hash scheme is deprecated.
- var it snapshot.AccountIterator
- if chain.TrieDB().Scheme() == rawdb.HashScheme {
- // The snapshot is assumed to be available in hash mode if
- // the SNAP protocol is enabled.
- it, err = chain.Snapshots().AccountIterator(req.Root, req.Origin)
- } else {
- it, err = chain.TrieDB().AccountIterator(req.Root, req.Origin)
- }
- if err != nil {
- return nil, nil
- }
- // Iterate over the requested range and pile accounts up
- var (
- accounts []*AccountData
- size uint64
- last common.Hash
- )
- for it.Next() {
- hash, account := it.Hash(), common.CopyBytes(it.Account())
-
- // Track the returned interval for the Merkle proofs
- last = hash
-
- // Assemble the reply item
- size += uint64(common.HashLength + len(account))
- accounts = append(accounts, &AccountData{
- Hash: hash,
- Body: account,
- })
- // If we've exceeded the request threshold, abort
- if bytes.Compare(hash[:], req.Limit[:]) >= 0 {
- break
- }
- if size > req.Bytes {
- break
- }
- }
- it.Release()
-
- // Generate the Merkle proofs for the first and last account
- proof := trienode.NewProofSet()
- if err := tr.Prove(req.Origin[:], proof); err != nil {
- log.Warn("Failed to prove account range", "origin", req.Origin, "err", err)
- return nil, nil
- }
- if last != (common.Hash{}) {
- if err := tr.Prove(last[:], proof); err != nil {
- log.Warn("Failed to prove account range", "last", last, "err", err)
- return nil, nil
- }
- }
- return accounts, proof.List()
-}
-
-func ServiceGetStorageRangesQuery(chain *core.BlockChain, req *GetStorageRangesPacket) ([][]*StorageData, [][]byte) {
- if req.Bytes > softResponseLimit {
- req.Bytes = softResponseLimit
- }
- // TODO(karalabe): Do we want to enforce > 0 accounts and 1 account if origin is set?
- // TODO(karalabe): - Logging locally is not ideal as remote faults annoy the local user
- // TODO(karalabe): - Dropping the remote peer is less flexible wrt client bugs (slow is better than non-functional)
-
- // Calculate the hard limit at which to abort, even if mid storage trie
- hardLimit := uint64(float64(req.Bytes) * (1 + stateLookupSlack))
-
- // Retrieve storage ranges until the packet limit is reached
- var (
- slots [][]*StorageData
- proofs [][]byte
- size uint64
- )
- for _, account := range req.Accounts {
- // If we've exceeded the requested data limit, abort without opening
- // a new storage range (that we'd need to prove due to exceeded size)
- if size >= req.Bytes {
- break
- }
- // The first account might start from a different origin and end sooner
- var origin common.Hash
- if len(req.Origin) > 0 {
- origin, req.Origin = common.BytesToHash(req.Origin), nil
- }
- var limit = common.MaxHash
- if len(req.Limit) > 0 {
- limit, req.Limit = common.BytesToHash(req.Limit), nil
- }
- // Retrieve the requested state and bail out if non existent
- var (
- err error
- it snapshot.StorageIterator
- )
- // Temporary solution: using the snapshot interface for both cases.
- // This can be removed once the hash scheme is deprecated.
- if chain.TrieDB().Scheme() == rawdb.HashScheme {
- // The snapshot is assumed to be available in hash mode if
- // the SNAP protocol is enabled.
- it, err = chain.Snapshots().StorageIterator(req.Root, account, origin)
- } else {
- it, err = chain.TrieDB().StorageIterator(req.Root, account, origin)
- }
- if err != nil {
- return nil, nil
- }
- // Iterate over the requested range and pile slots up
- var (
- storage []*StorageData
- last common.Hash
- abort bool
- )
- for it.Next() {
- if size >= hardLimit {
- abort = true
- break
- }
- hash, slot := it.Hash(), common.CopyBytes(it.Slot())
-
- // Track the returned interval for the Merkle proofs
- last = hash
-
- // Assemble the reply item
- size += uint64(common.HashLength + len(slot))
- storage = append(storage, &StorageData{
- Hash: hash,
- Body: slot,
- })
- // If we've exceeded the request threshold, abort
- if bytes.Compare(hash[:], limit[:]) >= 0 {
- break
- }
- }
- if len(storage) > 0 {
- slots = append(slots, storage)
- }
- it.Release()
-
- // Generate the Merkle proofs for the first and last storage slot, but
- // only if the response was capped. If the entire storage trie included
- // in the response, no need for any proofs.
- if origin != (common.Hash{}) || (abort && len(storage) > 0) {
- // Request started at a non-zero hash or was capped prematurely, add
- // the endpoint Merkle proofs
- accTrie, err := trie.NewStateTrie(trie.StateTrieID(req.Root), chain.TrieDB())
- if err != nil {
- return nil, nil
- }
- acc, err := accTrie.GetAccountByHash(account)
- if err != nil || acc == nil {
- return nil, nil
- }
- id := trie.StorageTrieID(req.Root, account, acc.Root)
- stTrie, err := trie.NewStateTrie(id, chain.TrieDB())
- if err != nil {
- return nil, nil
- }
- proof := trienode.NewProofSet()
- if err := stTrie.Prove(origin[:], proof); err != nil {
- log.Warn("Failed to prove storage range", "origin", req.Origin, "err", err)
- return nil, nil
- }
- if last != (common.Hash{}) {
- if err := stTrie.Prove(last[:], proof); err != nil {
- log.Warn("Failed to prove storage range", "last", last, "err", err)
- return nil, nil
- }
- }
- proofs = append(proofs, proof.List()...)
- // Proof terminates the reply as proofs are only added if a node
- // refuses to serve more data (exception when a contract fetch is
- // finishing, but that's that).
- break
- }
- }
- return slots, proofs
-}
-
-// ServiceGetByteCodesQuery assembles the response to a byte codes query.
-// It is exposed to allow external packages to test protocol behavior.
-func ServiceGetByteCodesQuery(chain *core.BlockChain, req *GetByteCodesPacket) [][]byte {
- if req.Bytes > softResponseLimit {
- req.Bytes = softResponseLimit
- }
- if len(req.Hashes) > maxCodeLookups {
- req.Hashes = req.Hashes[:maxCodeLookups]
- }
- // Retrieve bytecodes until the packet size limit is reached
- var (
- codes [][]byte
- bytes uint64
- )
- for _, hash := range req.Hashes {
- if hash == types.EmptyCodeHash {
- // Peers should not request the empty code, but if they do, at
- // least sent them back a correct response without db lookups
- codes = append(codes, []byte{})
- } else if blob := chain.ContractCodeWithPrefix(hash); len(blob) > 0 {
- codes = append(codes, blob)
- bytes += uint64(len(blob))
- }
- if bytes > req.Bytes {
- break
- }
- }
- return codes
-}
-
-// ServiceGetTrieNodesQuery assembles the response to a trie nodes query.
-// It is exposed to allow external packages to test protocol behavior.
-func ServiceGetTrieNodesQuery(chain *core.BlockChain, req *GetTrieNodesPacket, start time.Time) ([][]byte, error) {
- if req.Bytes > softResponseLimit {
- req.Bytes = softResponseLimit
- }
- // Make sure we have the state associated with the request
- triedb := chain.TrieDB()
-
- accTrie, err := trie.NewStateTrie(trie.StateTrieID(req.Root), triedb)
- if err != nil {
- // We don't have the requested state available, bail out
- return nil, nil
- }
- // The 'reader' might be nil, in which case we cannot serve storage slots
- // via snapshot.
- var reader database.StateReader
- if chain.Snapshots() != nil {
- reader = chain.Snapshots().Snapshot(req.Root)
- }
- if reader == nil {
- reader, _ = triedb.StateReader(req.Root)
- }
-
- // Retrieve trie nodes until the packet size limit is reached
- var (
- outerIt = req.Paths.ContentIterator()
- nodes [][]byte
- bytes uint64
- loads int // Trie hash expansions to count database reads
- )
- for outerIt.Next() {
- innerIt, err := rlp.NewListIterator(outerIt.Value())
- if err != nil {
- return nodes, err
- }
-
- switch innerIt.Count() {
- case 0:
- // Ensure we penalize invalid requests
- return nil, fmt.Errorf("%w: zero-item pathset requested", errBadRequest)
-
- case 1:
- // If we're only retrieving an account trie node, fetch it directly
- accKey := nextBytes(&innerIt)
- if accKey == nil {
- return nodes, fmt.Errorf("%w: invalid account node request", errBadRequest)
- }
- blob, resolved, err := accTrie.GetNode(accKey)
- loads += resolved // always account database reads, even for failures
- if err != nil {
- break
- }
- nodes = append(nodes, blob)
- bytes += uint64(len(blob))
-
- default:
- // Storage slots requested, open the storage trie and retrieve from there
- accKey := nextBytes(&innerIt)
- if accKey == nil {
- return nodes, fmt.Errorf("%w: invalid account storage request", errBadRequest)
- }
- var stRoot common.Hash
- if reader == nil {
- // We don't have the requested state snapshotted yet (or it is stale),
- // but can look up the account via the trie instead.
- account, err := accTrie.GetAccountByHash(common.BytesToHash(accKey))
- loads += 8 // We don't know the exact cost of lookup, this is an estimate
- if err != nil || account == nil {
- break
- }
- stRoot = account.Root
- } else {
- account, err := reader.Account(common.BytesToHash(accKey))
- loads++ // always account database reads, even for failures
- if err != nil || account == nil {
- break
- }
- stRoot = common.BytesToHash(account.Root)
- }
-
- id := trie.StorageTrieID(req.Root, common.BytesToHash(accKey), stRoot)
- stTrie, err := trie.NewStateTrie(id, triedb)
- loads++ // always account database reads, even for failures
- if err != nil {
- break
- }
- for innerIt.Next() {
- path, _, err := rlp.SplitString(innerIt.Value())
- if err != nil {
- return nil, fmt.Errorf("%w: invalid storage key: %v", errBadRequest, err)
- }
- blob, resolved, err := stTrie.GetNode(path)
- loads += resolved // always account database reads, even for failures
- if err != nil {
- break
- }
- nodes = append(nodes, blob)
- bytes += uint64(len(blob))
-
- // Sanity check limits to avoid DoS on the store trie loads
- if bytes > req.Bytes || loads > maxTrieNodeLookups || time.Since(start) > maxTrieNodeTimeSpent {
- break
- }
- }
- }
- // Abort request processing if we've exceeded our limits
- if bytes > req.Bytes || loads > maxTrieNodeLookups || time.Since(start) > maxTrieNodeTimeSpent {
- break
- }
- }
- return nodes, nil
-}
-
-func nextBytes(it *rlp.Iterator) []byte {
- if !it.Next() {
- return nil
- }
- content, _, err := rlp.SplitString(it.Value())
- if err != nil {
- return nil
- }
- return content
+ return fmt.Errorf("%w: %v", errInvalidMsgCode, msg.Code)
}
// NodeInfo represents a short summary of the `snap` sub-protocol metadata
diff --git a/eth/protocols/snap/handler_fuzzing_test.go b/eth/protocols/snap/handler_fuzzing_test.go
index 4930ae9ae6..a52da0aac5 100644
--- a/eth/protocols/snap/handler_fuzzing_test.go
+++ b/eth/protocols/snap/handler_fuzzing_test.go
@@ -60,6 +60,12 @@ func FuzzTrieNodes(f *testing.F) {
})
}
+func FuzzAccessLists(f *testing.F) {
+ f.Fuzz(func(t *testing.T, data []byte) {
+ doFuzz(data, &GetAccessListsPacket{}, GetAccessListsMsg)
+ })
+}
+
func doFuzz(input []byte, obj interface{}, code int) {
bc := getChain()
defer bc.Stop()
diff --git a/eth/protocols/snap/handler_test.go b/eth/protocols/snap/handler_test.go
new file mode 100644
index 0000000000..cb4b378a8d
--- /dev/null
+++ b/eth/protocols/snap/handler_test.go
@@ -0,0 +1,195 @@
+// Copyright 2026 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snap
+
+import (
+ "bytes"
+ "testing"
+ "time"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/consensus/ethash"
+ "github.com/ethereum/go-ethereum/core"
+ "github.com/ethereum/go-ethereum/core/rawdb"
+ "github.com/ethereum/go-ethereum/params"
+ "github.com/ethereum/go-ethereum/rlp"
+)
+
+// getChainWithBALs creates a minimal test chain with BALs stored for each block.
+// It returns the chain, block hashes, and the stored BAL data.
+func getChainWithBALs(nBlocks int, balSize int) (*core.BlockChain, []common.Hash, []rlp.RawValue) {
+ gspec := &core.Genesis{
+ Config: params.TestChainConfig,
+ }
+ db := rawdb.NewMemoryDatabase()
+ _, blocks, _ := core.GenerateChainWithGenesis(gspec, ethash.NewFaker(), nBlocks, func(i int, gen *core.BlockGen) {})
+ options := &core.BlockChainConfig{
+ TrieCleanLimit: 0,
+ TrieDirtyLimit: 0,
+ TrieTimeLimit: 5 * time.Minute,
+ NoPrefetch: true,
+ SnapshotLimit: 0,
+ }
+ bc, err := core.NewBlockChain(db, gspec, ethash.NewFaker(), options)
+ if err != nil {
+ panic(err)
+ }
+ if _, err := bc.InsertChain(blocks); err != nil {
+ panic(err)
+ }
+
+ // Store BALs for each block
+ var hashes []common.Hash
+ var bals []rlp.RawValue
+ for _, block := range blocks {
+ hash := block.Hash()
+ number := block.NumberU64()
+ bal := make(rlp.RawValue, balSize)
+
+ // Fill with data based on block number
+ for j := range bal {
+ bal[j] = byte(number + uint64(j))
+ }
+ rawdb.WriteAccessListRLP(db, hash, number, bal)
+ hashes = append(hashes, hash)
+ bals = append(bals, bal)
+ }
+ return bc, hashes, bals
+}
+
+// TestServiceGetAccessListsQuery verifies that known block hashes return the
+// correct BALs with positional correspondence.
+func TestServiceGetAccessListsQuery(t *testing.T) {
+ t.Parallel()
+ bc, hashes, bals := getChainWithBALs(5, 100)
+ defer bc.Stop()
+ req := &GetAccessListsPacket{
+ ID: 1,
+ Hashes: hashes,
+ }
+ result := ServiceGetAccessListsQuery(bc, req)
+
+ // Verify the results
+ if len(result) != len(hashes) {
+ t.Fatalf("expected %d results, got %d", len(hashes), len(result))
+ }
+ for i, bal := range result {
+ if !bytes.Equal(bal, bals[i]) {
+ t.Errorf("BAL %d mismatch: got %x, want %x", i, bal, bals[i])
+ }
+ }
+}
+
+// TestServiceGetAccessListsQueryEmpty verifies that unknown block hashes return
+// nil placeholders and that mixed known/unknown hashes preserve alignment.
+func TestServiceGetAccessListsQueryEmpty(t *testing.T) {
+ t.Parallel()
+ bc, hashes, bals := getChainWithBALs(3, 100)
+ defer bc.Stop()
+ unknown := common.HexToHash("0xdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef")
+ mixed := []common.Hash{hashes[0], unknown, hashes[1], unknown, hashes[2]}
+ req := &GetAccessListsPacket{
+ ID: 2,
+ Hashes: mixed,
+ }
+ result := ServiceGetAccessListsQuery(bc, req)
+
+ // Verify length
+ if len(result) != len(mixed) {
+ t.Fatalf("expected %d results, got %d", len(mixed), len(result))
+ }
+
+ // Check positional correspondence
+ if !bytes.Equal(result[0], bals[0]) {
+ t.Errorf("index 0: expected known BAL, got %x", result[0])
+ }
+ if result[1] != nil {
+ t.Errorf("index 1: expected nil for unknown hash, got %x", result[1])
+ }
+ if !bytes.Equal(result[2], bals[1]) {
+ t.Errorf("index 2: expected known BAL, got %x", result[2])
+ }
+ if result[3] != nil {
+ t.Errorf("index 3: expected nil for unknown hash, got %x", result[3])
+ }
+ if !bytes.Equal(result[4], bals[2]) {
+ t.Errorf("index 4: expected known BAL, got %x", result[4])
+ }
+}
+
+// TestServiceGetAccessListsQueryCap verifies that requests exceeding
+// maxAccessListLookups are capped.
+func TestServiceGetAccessListsQueryCap(t *testing.T) {
+ t.Parallel()
+
+ bc, _, _ := getChainWithBALs(2, 100)
+ defer bc.Stop()
+
+ // Create a request with more hashes than the cap
+ hashes := make([]common.Hash, maxAccessListLookups+100)
+ for i := range hashes {
+ hashes[i] = common.BytesToHash([]byte{byte(i), byte(i >> 8)})
+ }
+ req := &GetAccessListsPacket{
+ ID: 3,
+ Hashes: hashes,
+ }
+ result := ServiceGetAccessListsQuery(bc, req)
+
+ // Can't get more than maxAccessListLookups results
+ if len(result) > maxAccessListLookups {
+ t.Fatalf("expected at most %d results, got %d", maxAccessListLookups, len(result))
+ }
+}
+
+// TestServiceGetAccessListsQueryByteLimit verifies that the response stops
+// once the byte limit is exceeded. The handler appends the entry that crosses
+// the limit before breaking, so the total size will exceed the limit by at
+// most one BAL.
+func TestServiceGetAccessListsQueryByteLimit(t *testing.T) {
+ t.Parallel()
+
+ // The handler will return 3/5 entries (3MB total) then break.
+ balSize := 1024 * 1024
+ nBlocks := 5
+ bc, hashes, _ := getChainWithBALs(nBlocks, balSize)
+ defer bc.Stop()
+ req := &GetAccessListsPacket{
+ ID: 0,
+ Hashes: hashes,
+ }
+ result := ServiceGetAccessListsQuery(bc, req)
+
+ // Should have stopped before returning all blocks
+ if len(result) >= nBlocks {
+ t.Fatalf("expected fewer than %d results due to byte limit, got %d", nBlocks, len(result))
+ }
+
+ // Should have returned at least one
+ if len(result) == 0 {
+ t.Fatal("expected at least one result")
+ }
+
+ // The total size should exceed the limit (the entry that crosses it is included)
+ var total uint64
+ for _, bal := range result {
+ total += uint64(len(bal))
+ }
+ if total <= softResponseLimit {
+ t.Errorf("total response size %d should exceed soft limit %d (includes one entry past limit)", total, softResponseLimit)
+ }
+}
diff --git a/eth/protocols/snap/handlers.go b/eth/protocols/snap/handlers.go
new file mode 100644
index 0000000000..64522343f9
--- /dev/null
+++ b/eth/protocols/snap/handlers.go
@@ -0,0 +1,593 @@
+// Copyright 2026 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snap
+
+import (
+ "bytes"
+ "fmt"
+ "time"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core"
+ "github.com/ethereum/go-ethereum/core/rawdb"
+ "github.com/ethereum/go-ethereum/core/state/snapshot"
+ "github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/p2p"
+ "github.com/ethereum/go-ethereum/p2p/tracker"
+ "github.com/ethereum/go-ethereum/rlp"
+ "github.com/ethereum/go-ethereum/trie"
+ "github.com/ethereum/go-ethereum/trie/trienode"
+ "github.com/ethereum/go-ethereum/triedb/database"
+)
+
+// handleGetAccountRange services an incoming account range request and
+// replies with the accounts plus boundary proofs (both empty on any error).
+func handleGetAccountRange(backend Backend, msg Decoder, peer *Peer) error {
+	var req GetAccountRangePacket
+	if err := msg.Decode(&req); err != nil {
+		return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+	}
+	// Assemble the response; both parts are nil if the state is unavailable
+	accounts, proofs := ServiceGetAccountRangeQuery(backend.Chain(), &req)
+
+	reply := &AccountRangePacket{
+		ID:       req.ID,
+		Accounts: accounts,
+		Proof:    proofs,
+	}
+	return p2p.Send(peer.rw, AccountRangeMsg, reply)
+}
+
+// ServiceGetAccountRangeQuery assembles the response to an account range query.
+// It is exposed to allow external packages to test protocol behavior.
+//
+// The reply holds consecutive accounts from the state identified by req.Root,
+// starting at req.Origin and bounded both by req.Limit (last hash to serve)
+// and req.Bytes (soft size cap). Merkle proofs for the range boundaries are
+// attached so the requester can verify completeness against req.Root. Nil
+// results signal that the requested state (or its iterator) is unavailable.
+func ServiceGetAccountRangeQuery(chain *core.BlockChain, req *GetAccountRangePacket) ([]*AccountData, [][]byte) {
+	// Clamp the requested size to the protocol-wide soft cap
+	if req.Bytes > softResponseLimit {
+		req.Bytes = softResponseLimit
+	}
+	// Retrieve the requested state and bail out if non existent
+	tr, err := trie.New(trie.StateTrieID(req.Root), chain.TrieDB())
+	if err != nil {
+		return nil, nil
+	}
+	// Temporary solution: using the snapshot interface for both cases.
+	// This can be removed once the hash scheme is deprecated.
+	var it snapshot.AccountIterator
+	if chain.TrieDB().Scheme() == rawdb.HashScheme {
+		// The snapshot is assumed to be available in hash mode if
+		// the SNAP protocol is enabled.
+		it, err = chain.Snapshots().AccountIterator(req.Root, req.Origin)
+	} else {
+		it, err = chain.TrieDB().AccountIterator(req.Root, req.Origin)
+	}
+	if err != nil {
+		return nil, nil
+	}
+	// Iterate over the requested range and pile accounts up
+	var (
+		accounts []*AccountData
+		size     uint64
+		last     common.Hash
+	)
+	for it.Next() {
+		hash, account := it.Hash(), common.CopyBytes(it.Account())
+
+		// Track the returned interval for the Merkle proofs
+		last = hash
+
+		// Assemble the reply item
+		size += uint64(common.HashLength + len(account))
+		accounts = append(accounts, &AccountData{
+			Hash: hash,
+			Body: account,
+		})
+		// If we've exceeded the request threshold, abort. Both checks run
+		// after the append, so the account at req.Limit and the entry that
+		// first crosses req.Bytes are still included in the reply.
+		if bytes.Compare(hash[:], req.Limit[:]) >= 0 {
+			break
+		}
+		if size > req.Bytes {
+			break
+		}
+	}
+	it.Release()
+
+	// Generate the Merkle proofs for the first and last account
+	proof := trienode.NewProofSet()
+	if err := tr.Prove(req.Origin[:], proof); err != nil {
+		log.Warn("Failed to prove account range", "origin", req.Origin, "err", err)
+		return nil, nil
+	}
+	if last != (common.Hash{}) {
+		if err := tr.Prove(last[:], proof); err != nil {
+			log.Warn("Failed to prove account range", "last", last, "err", err)
+			return nil, nil
+		}
+	}
+	return accounts, proof.List()
+}
+
+// handleAccountRange validates an incoming AccountRange response, matches it
+// to a tracked request and forwards the decoded payload to the backend.
+func handleAccountRange(backend Backend, msg Decoder, peer *Peer) error {
+	res := new(accountRangeInput)
+	if err := msg.Decode(res); err != nil {
+		return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+	}
+
+	// Check response validity. Use a distinct name for the proof length so
+	// the builtin len is not shadowed.
+	if plen := res.Proof.Len(); plen > 128 {
+		return fmt.Errorf("AccountRange: invalid proof (length %d)", plen)
+	}
+	tresp := tracker.Response{ID: res.ID, MsgCode: AccountRangeMsg, Size: len(res.Accounts.Content())}
+	if err := peer.tracker.Fulfil(tresp); err != nil {
+		// Wrap with the message name, matching the other response handlers
+		return fmt.Errorf("AccountRange: %w", err)
+	}
+
+	// Decode.
+	accounts, err := res.Accounts.Items()
+	if err != nil {
+		return fmt.Errorf("AccountRange: invalid accounts list: %v", err)
+	}
+	proof, err := res.Proof.Items()
+	if err != nil {
+		return fmt.Errorf("AccountRange: invalid proof: %v", err)
+	}
+
+	// Ensure the range is monotonically increasing
+	for i := 1; i < len(accounts); i++ {
+		if bytes.Compare(accounts[i-1].Hash[:], accounts[i].Hash[:]) >= 0 {
+			return fmt.Errorf("accounts not monotonically increasing: #%d [%x] vs #%d [%x]", i-1, accounts[i-1].Hash[:], i, accounts[i].Hash[:])
+		}
+	}
+
+	return backend.Handle(peer, &AccountRangePacket{res.ID, accounts, proof})
+}
+
+// handleGetStorageRanges services an incoming storage ranges request and
+// replies with the slots plus any boundary proofs (both empty on errors).
+func handleGetStorageRanges(backend Backend, msg Decoder, peer *Peer) error {
+	var req GetStorageRangesPacket
+	if err := msg.Decode(&req); err != nil {
+		return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+	}
+	// Assemble the response; both parts are nil if the state is unavailable
+	slots, proofs := ServiceGetStorageRangesQuery(backend.Chain(), &req)
+
+	reply := &StorageRangesPacket{
+		ID:    req.ID,
+		Slots: slots,
+		Proof: proofs,
+	}
+	return p2p.Send(peer.rw, StorageRangesMsg, reply)
+}
+
+// ServiceGetStorageRangesQuery assembles the response to a storage ranges
+// query. It is exposed to allow external packages to test protocol behavior.
+//
+// For each requested account, storage slots are iterated until the byte
+// limits are hit. req.Origin and req.Limit apply only to the first account
+// and are consumed (set to nil) on first use. If a range is served partially
+// (non-zero origin or aborted on the hard limit), boundary Merkle proofs are
+// attached and the reply is terminated there. Nil results signal that the
+// requested state is unavailable.
+func ServiceGetStorageRangesQuery(chain *core.BlockChain, req *GetStorageRangesPacket) ([][]*StorageData, [][]byte) {
+	if req.Bytes > softResponseLimit {
+		req.Bytes = softResponseLimit
+	}
+	// TODO(karalabe): Do we want to enforce > 0 accounts and 1 account if origin is set?
+	// TODO(karalabe): - Logging locally is not ideal as remote faults annoy the local user
+	// TODO(karalabe): - Dropping the remote peer is less flexible wrt client bugs (slow is better than non-functional)
+
+	// Calculate the hard limit at which to abort, even if mid storage trie
+	hardLimit := uint64(float64(req.Bytes) * (1 + stateLookupSlack))
+
+	// Retrieve storage ranges until the packet limit is reached
+	var (
+		slots  [][]*StorageData
+		proofs [][]byte
+		size   uint64
+	)
+	for _, account := range req.Accounts {
+		// If we've exceeded the requested data limit, abort without opening
+		// a new storage range (that we'd need to prove due to exceeded size)
+		if size >= req.Bytes {
+			break
+		}
+		// The first account might start from a different origin and end sooner.
+		// Origin/limit are consumed here so subsequent accounts iterate their
+		// full storage range.
+		var origin common.Hash
+		if len(req.Origin) > 0 {
+			origin, req.Origin = common.BytesToHash(req.Origin), nil
+		}
+		var limit = common.MaxHash
+		if len(req.Limit) > 0 {
+			limit, req.Limit = common.BytesToHash(req.Limit), nil
+		}
+		// Retrieve the requested state and bail out if non existent
+		var (
+			err error
+			it  snapshot.StorageIterator
+		)
+		// Temporary solution: using the snapshot interface for both cases.
+		// This can be removed once the hash scheme is deprecated.
+		if chain.TrieDB().Scheme() == rawdb.HashScheme {
+			// The snapshot is assumed to be available in hash mode if
+			// the SNAP protocol is enabled.
+			it, err = chain.Snapshots().StorageIterator(req.Root, account, origin)
+		} else {
+			it, err = chain.TrieDB().StorageIterator(req.Root, account, origin)
+		}
+		if err != nil {
+			return nil, nil
+		}
+		// Iterate over the requested range and pile slots up
+		var (
+			storage []*StorageData
+			last    common.Hash
+			abort   bool
+		)
+		for it.Next() {
+			// Hard cap: stop mid-trie; the partial range is proven below
+			if size >= hardLimit {
+				abort = true
+				break
+			}
+			hash, slot := it.Hash(), common.CopyBytes(it.Slot())
+
+			// Track the returned interval for the Merkle proofs
+			last = hash
+
+			// Assemble the reply item
+			size += uint64(common.HashLength + len(slot))
+			storage = append(storage, &StorageData{
+				Hash: hash,
+				Body: slot,
+			})
+			// If we've exceeded the request threshold, abort. The check runs
+			// after the append, so the slot at the limit hash is included.
+			if bytes.Compare(hash[:], limit[:]) >= 0 {
+				break
+			}
+		}
+		if len(storage) > 0 {
+			slots = append(slots, storage)
+		}
+		it.Release()
+
+		// Generate the Merkle proofs for the first and last storage slot, but
+		// only if the response was capped. If the entire storage trie included
+		// in the response, no need for any proofs.
+		if origin != (common.Hash{}) || (abort && len(storage) > 0) {
+			// Request started at a non-zero hash or was capped prematurely, add
+			// the endpoint Merkle proofs
+			accTrie, err := trie.NewStateTrie(trie.StateTrieID(req.Root), chain.TrieDB())
+			if err != nil {
+				return nil, nil
+			}
+			acc, err := accTrie.GetAccountByHash(account)
+			if err != nil || acc == nil {
+				return nil, nil
+			}
+			id := trie.StorageTrieID(req.Root, account, acc.Root)
+			stTrie, err := trie.NewStateTrie(id, chain.TrieDB())
+			if err != nil {
+				return nil, nil
+			}
+			proof := trienode.NewProofSet()
+			if err := stTrie.Prove(origin[:], proof); err != nil {
+				log.Warn("Failed to prove storage range", "origin", req.Origin, "err", err)
+				return nil, nil
+			}
+			if last != (common.Hash{}) {
+				if err := stTrie.Prove(last[:], proof); err != nil {
+					log.Warn("Failed to prove storage range", "last", last, "err", err)
+					return nil, nil
+				}
+			}
+			proofs = append(proofs, proof.List()...)
+			// Proof terminates the reply as proofs are only added if a node
+			// refuses to serve more data (exception when a contract fetch is
+			// finishing, but that's that).
+			break
+		}
+	}
+	return slots, proofs
+}
+
+// handleStorageRanges validates an incoming StorageRanges response, matches
+// it to a tracked request and forwards the decoded payload to the backend.
+func handleStorageRanges(backend Backend, msg Decoder, peer *Peer) error {
+	res := new(storageRangesInput)
+	if err := msg.Decode(res); err != nil {
+		return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+	}
+
+	// Check response validity. Use a distinct name for the proof length so
+	// the builtin len is not shadowed.
+	if plen := res.Proof.Len(); plen > 128 {
+		return fmt.Errorf("StorageRangesMsg: invalid proof (length %d)", plen)
+	}
+	tresp := tracker.Response{ID: res.ID, MsgCode: StorageRangesMsg, Size: len(res.Slots.Content())}
+	if err := peer.tracker.Fulfil(tresp); err != nil {
+		return fmt.Errorf("StorageRangesMsg: %w", err)
+	}
+
+	// Decode. The error prefixes name this message, not AccountRange (the
+	// originals were copy-pasted from handleAccountRange).
+	slotLists, err := res.Slots.Items()
+	if err != nil {
+		return fmt.Errorf("StorageRangesMsg: invalid slots list: %v", err)
+	}
+	proof, err := res.Proof.Items()
+	if err != nil {
+		return fmt.Errorf("StorageRangesMsg: invalid proof: %v", err)
+	}
+
+	// Ensure the ranges are monotonically increasing
+	for i, slots := range slotLists {
+		for j := 1; j < len(slots); j++ {
+			if bytes.Compare(slots[j-1].Hash[:], slots[j].Hash[:]) >= 0 {
+				return fmt.Errorf("storage slots not monotonically increasing for account #%d: #%d [%x] vs #%d [%x]", i, j-1, slots[j-1].Hash[:], j, slots[j].Hash[:])
+			}
+		}
+	}
+
+	return backend.Handle(peer, &StorageRangesPacket{res.ID, slotLists, proof})
+}
+
+// handleGetByteCodes services an incoming bytecode request, replying with
+// whatever contract code could be retrieved within the size limit.
+func handleGetByteCodes(backend Backend, msg Decoder, peer *Peer) error {
+	var req GetByteCodesPacket
+	if err := msg.Decode(&req); err != nil {
+		return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+	}
+	// Assemble and ship the response in one go; an empty code list is a
+	// valid answer when nothing could be served
+	return p2p.Send(peer.rw, ByteCodesMsg, &ByteCodesPacket{
+		ID:    req.ID,
+		Codes: ServiceGetByteCodesQuery(backend.Chain(), &req),
+	})
+}
+
+// ServiceGetByteCodesQuery assembles the response to a byte codes query.
+// It is exposed to allow external packages to test protocol behavior.
+//
+// Codes are returned in request order; unknown hashes are silently skipped,
+// and assembly stops once the cumulative payload exceeds req.Bytes (the
+// entry crossing the limit is still included).
+func ServiceGetByteCodesQuery(chain *core.BlockChain, req *GetByteCodesPacket) [][]byte {
+	if req.Bytes > softResponseLimit {
+		req.Bytes = softResponseLimit
+	}
+	if len(req.Hashes) > maxCodeLookups {
+		req.Hashes = req.Hashes[:maxCodeLookups]
+	}
+	// Retrieve bytecodes until the packet size limit is reached. The counter
+	// is named size rather than bytes to avoid shadowing the bytes package
+	// imported by this file.
+	var (
+		codes = make([][]byte, 0, len(req.Hashes))
+		size  uint64
+	)
+	for _, hash := range req.Hashes {
+		if hash == types.EmptyCodeHash {
+			// Peers should not request the empty code, but if they do, at
+			// least sent them back a correct response without db lookups
+			codes = append(codes, []byte{})
+		} else if blob := chain.ContractCodeWithPrefix(hash); len(blob) > 0 {
+			codes = append(codes, blob)
+			size += uint64(len(blob))
+		}
+		if size > req.Bytes {
+			break
+		}
+	}
+	return codes
+}
+
+// handleByteCodes validates an incoming ByteCodes response, matches it to a
+// tracked request and forwards the decoded payload to the backend.
+func handleByteCodes(backend Backend, msg Decoder, peer *Peer) error {
+	res := new(byteCodesInput)
+	if err := msg.Decode(res); err != nil {
+		return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+	}
+	// Fulfil the tracked request before decoding the payload
+	tresp := tracker.Response{
+		ID:      res.ID,
+		MsgCode: ByteCodesMsg,
+		Size:    res.Codes.Len(),
+	}
+	if err := peer.tracker.Fulfil(tresp); err != nil {
+		return fmt.Errorf("ByteCodes: %w", err)
+	}
+	codes, err := res.Codes.Items()
+	if err != nil {
+		return fmt.Errorf("ByteCodes: %w", err)
+	}
+	return backend.Handle(peer, &ByteCodesPacket{res.ID, codes})
+}
+
+// handleGetTrienodes services an incoming trie node request, replying with
+// whatever nodes were assembled before an error or a limit was hit.
+func handleGetTrienodes(backend Backend, msg Decoder, peer *Peer) error {
+	var req GetTrieNodesPacket
+	if err := msg.Decode(&req); err != nil {
+		return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+	}
+	// Malformed pathsets surface as an error and penalize the requester
+	nodes, err := ServiceGetTrieNodesQuery(backend.Chain(), &req)
+	if err != nil {
+		return err
+	}
+	reply := &TrieNodesPacket{
+		ID:    req.ID,
+		Nodes: nodes,
+	}
+	return p2p.Send(peer.rw, TrieNodesMsg, reply)
+}
+
+// nextBytes advances the iterator and returns the content of its next RLP
+// string item, or nil if the iterator is exhausted or the item is malformed.
+func nextBytes(it *rlp.Iterator) []byte {
+	if it.Next() {
+		if content, _, err := rlp.SplitString(it.Value()); err == nil {
+			return content
+		}
+	}
+	return nil
+}
+
+// ServiceGetTrieNodesQuery assembles the response to a trie nodes query.
+// It is exposed to allow external packages to test protocol behavior.
+//
+// Each pathset in req.Paths is either a single account-trie node path, or an
+// account key followed by storage-trie node paths for that account. Empty
+// pathsets are rejected with an error (penalizing the requester); individual
+// lookup failures merely skip the item. Assembly stops once the byte, lookup
+// or time budget is exhausted.
+func ServiceGetTrieNodesQuery(chain *core.BlockChain, req *GetTrieNodesPacket) ([][]byte, error) {
+	start := time.Now()
+	if req.Bytes > softResponseLimit {
+		req.Bytes = softResponseLimit
+	}
+	// Make sure we have the state associated with the request
+	triedb := chain.TrieDB()
+
+	accTrie, err := trie.NewStateTrie(trie.StateTrieID(req.Root), triedb)
+	if err != nil {
+		// We don't have the requested state available, bail out
+		return nil, nil
+	}
+	// The 'reader' might be nil, in which case we cannot serve storage slots
+	// via snapshot.
+	var reader database.StateReader
+	if chain.Snapshots() != nil {
+		reader = chain.Snapshots().Snapshot(req.Root)
+	}
+	if reader == nil {
+		reader, _ = triedb.StateReader(req.Root)
+	}
+
+	// Retrieve trie nodes until the packet size limit is reached
+	var (
+		outerIt = req.Paths.ContentIterator()
+		nodes   [][]byte
+		bytes   uint64
+		loads   int // Trie hash expansions to count database reads
+	)
+	for outerIt.Next() {
+		innerIt, err := rlp.NewListIterator(outerIt.Value())
+		if err != nil {
+			return nodes, err
+		}
+
+		// Note: a plain break inside the cases below exits only the switch,
+		// skipping the current pathset and continuing with the next one.
+		switch innerIt.Count() {
+		case 0:
+			// Ensure we penalize invalid requests
+			return nil, fmt.Errorf("%w: zero-item pathset requested", errBadRequest)
+
+		case 1:
+			// If we're only retrieving an account trie node, fetch it directly
+			accKey := nextBytes(&innerIt)
+			if accKey == nil {
+				return nodes, fmt.Errorf("%w: invalid account node request", errBadRequest)
+			}
+			blob, resolved, err := accTrie.GetNode(accKey)
+			loads += resolved // always account database reads, even for failures
+			if err != nil {
+				break
+			}
+			nodes = append(nodes, blob)
+			bytes += uint64(len(blob))
+
+		default:
+			// Storage slots requested, open the storage trie and retrieve from there
+			accKey := nextBytes(&innerIt)
+			if accKey == nil {
+				return nodes, fmt.Errorf("%w: invalid account storage request", errBadRequest)
+			}
+			var stRoot common.Hash
+			if reader == nil {
+				// We don't have the requested state snapshotted yet (or it is stale),
+				// but can look up the account via the trie instead.
+				account, err := accTrie.GetAccountByHash(common.BytesToHash(accKey))
+				loads += 8 // We don't know the exact cost of lookup, this is an estimate
+				if err != nil || account == nil {
+					break
+				}
+				stRoot = account.Root
+			} else {
+				account, err := reader.Account(common.BytesToHash(accKey))
+				loads++ // always account database reads, even for failures
+				if err != nil || account == nil {
+					break
+				}
+				stRoot = common.BytesToHash(account.Root)
+			}
+
+			id := trie.StorageTrieID(req.Root, common.BytesToHash(accKey), stRoot)
+			stTrie, err := trie.NewStateTrie(id, triedb)
+			loads++ // always account database reads, even for failures
+			if err != nil {
+				break
+			}
+			for innerIt.Next() {
+				path, _, err := rlp.SplitString(innerIt.Value())
+				if err != nil {
+					return nil, fmt.Errorf("%w: invalid storage key: %v", errBadRequest, err)
+				}
+				blob, resolved, err := stTrie.GetNode(path)
+				loads += resolved // always account database reads, even for failures
+				if err != nil {
+					break
+				}
+				nodes = append(nodes, blob)
+				bytes += uint64(len(blob))
+
+				// Sanity check limits to avoid DoS on the store trie loads
+				if bytes > req.Bytes || loads > maxTrieNodeLookups || time.Since(start) > maxTrieNodeTimeSpent {
+					break
+				}
+			}
+		}
+		// Abort request processing if we've exceeded our limits
+		if bytes > req.Bytes || loads > maxTrieNodeLookups || time.Since(start) > maxTrieNodeTimeSpent {
+			break
+		}
+	}
+	return nodes, nil
+}
+
+// handleTrieNodes validates an incoming TrieNodes response, matches it to a
+// tracked request and forwards the decoded payload to the backend.
+func handleTrieNodes(backend Backend, msg Decoder, peer *Peer) error {
+	res := new(trieNodesInput)
+	if err := msg.Decode(res); err != nil {
+		return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+	}
+	// Fulfil the tracked request before decoding the payload
+	tresp := tracker.Response{
+		ID:      res.ID,
+		MsgCode: TrieNodesMsg,
+		Size:    res.Nodes.Len(),
+	}
+	if err := peer.tracker.Fulfil(tresp); err != nil {
+		return fmt.Errorf("TrieNodes: %w", err)
+	}
+	nodes, err := res.Nodes.Items()
+	if err != nil {
+		return fmt.Errorf("TrieNodes: %w", err)
+	}
+	return backend.Handle(peer, &TrieNodesPacket{res.ID, nodes})
+}
+
+// handleGetAccessLists services an incoming block access list request,
+// replying with the RLP-encoded BALs for the requested block hashes.
+//
+// nolint:unused
+func handleGetAccessLists(backend Backend, msg Decoder, peer *Peer) error {
+	var req GetAccessListsPacket
+	if err := msg.Decode(&req); err != nil {
+		return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+	}
+	reply := &AccessListsPacket{
+		ID:          req.ID,
+		AccessLists: ServiceGetAccessListsQuery(backend.Chain(), &req),
+	}
+	return p2p.Send(peer.rw, AccessListsMsg, reply)
+}
+
+// ServiceGetAccessListsQuery assembles the response to an access list query.
+// It is exposed to allow external packages to test protocol behavior.
+//
+// The reply preserves request order: a nil entry marks a hash whose BAL is
+// unavailable. At most maxAccessListLookups hashes are served, and assembly
+// stops once the cumulative size passes the soft response limit (the entry
+// crossing the limit is still included).
+func ServiceGetAccessListsQuery(chain *core.BlockChain, req *GetAccessListsPacket) []rlp.RawValue {
+	// Cap the number of lookups to bound the database work per request
+	if len(req.Hashes) > maxAccessListLookups {
+		req.Hashes = req.Hashes[:maxAccessListLookups]
+	}
+	// One reply entry per served hash; the counter is named size rather than
+	// bytes to avoid shadowing the bytes package imported by this file.
+	var (
+		bals = make([]rlp.RawValue, 0, len(req.Hashes))
+		size uint64
+	)
+	for _, hash := range req.Hashes {
+		if bal := chain.GetAccessListRLP(hash); len(bal) > 0 {
+			bals = append(bals, bal)
+			size += uint64(len(bal))
+		} else {
+			// Either the block is unknown or the BAL doesn't exist
+			bals = append(bals, nil)
+		}
+		if size > softResponseLimit {
+			break
+		}
+	}
+	return bals
+}
diff --git a/eth/protocols/snap/protocol.go b/eth/protocols/snap/protocol.go
index 25fe25822b..57b29bbe36 100644
--- a/eth/protocols/snap/protocol.go
+++ b/eth/protocols/snap/protocol.go
@@ -28,6 +28,7 @@ import (
// Constants to match up protocol versions and messages
const (
SNAP1 = 1
+ //SNAP2 = 2
)
// ProtocolName is the official short name of the `snap` protocol used during
@@ -40,7 +41,7 @@ var ProtocolVersions = []uint{SNAP1}
// protocolLengths are the number of implemented message corresponding to
// different protocol versions.
-var protocolLengths = map[uint]uint64{SNAP1: 8}
+var protocolLengths = map[uint]uint64{ /*SNAP2: 10,*/ SNAP1: 8}
// maxMessageSize is the maximum cap on the size of a protocol message.
const maxMessageSize = 10 * 1024 * 1024
@@ -54,6 +55,8 @@ const (
ByteCodesMsg = 0x05
GetTrieNodesMsg = 0x06
TrieNodesMsg = 0x07
+ GetAccessListsMsg = 0x08
+ AccessListsMsg = 0x09
)
var (
@@ -215,6 +218,20 @@ type TrieNodesPacket struct {
Nodes [][]byte // Requested state trie nodes
}
+// GetAccessListsPacket requests BALs (block access lists) for a set of
+// block hashes, identified by the hash of the block they belong to.
+type GetAccessListsPacket struct {
+	ID     uint64        // Request ID to match up responses with
+	Hashes []common.Hash // Block hashes to retrieve BALs for
+}
+
+// AccessListsPacket is the response to GetAccessListsPacket.
+// Each entry corresponds to the requested hash at the same index.
+// Empty entries indicate the BAL is unavailable for that hash.
+type AccessListsPacket struct {
+	ID          uint64         // ID of the request this is a response for
+	AccessLists []rlp.RawValue // Requested BALs as raw RLP
+}
+
func (*GetAccountRangePacket) Name() string { return "GetAccountRange" }
func (*GetAccountRangePacket) Kind() byte { return GetAccountRangeMsg }
@@ -238,3 +255,9 @@ func (*GetTrieNodesPacket) Kind() byte { return GetTrieNodesMsg }
func (*TrieNodesPacket) Name() string { return "TrieNodes" }
func (*TrieNodesPacket) Kind() byte { return TrieNodesMsg }
+
+func (*GetAccessListsPacket) Name() string { return "GetAccessLists" }
+func (*GetAccessListsPacket) Kind() byte { return GetAccessListsMsg }
+
+func (*AccessListsPacket) Name() string { return "AccessLists" }
+func (*AccessListsPacket) Kind() byte { return AccessListsMsg }