diff --git a/core/blockchain_reader.go b/core/blockchain_reader.go index f1b40d0d0c..8b026680d2 100644 --- a/core/blockchain_reader.go +++ b/core/blockchain_reader.go @@ -296,6 +296,14 @@ func (bc *BlockChain) GetReceiptsRLP(hash common.Hash) rlp.RawValue { return rawdb.ReadReceiptsRLP(bc.db, hash, number) } +func (bc *BlockChain) GetAccessListRLP(hash common.Hash) rlp.RawValue { + number, ok := rawdb.ReadHeaderNumber(bc.db, hash) + if !ok { + return nil + } + return rawdb.ReadAccessListRLP(bc.db, hash, number) +} + // GetUnclesInChain retrieves all the uncles from a given block backwards until // a specific distance is reached. func (bc *BlockChain) GetUnclesInChain(block *types.Block, length int) []*types.Header { diff --git a/eth/protocols/snap/handler.go b/eth/protocols/snap/handler.go index 071a0419fb..604604a099 100644 --- a/eth/protocols/snap/handler.go +++ b/eth/protocols/snap/handler.go @@ -55,6 +55,10 @@ const ( // number is there to limit the number of disk lookups. maxTrieNodeLookups = 1024 + // maxAccessListLookups is the maximum number of BALs to server. This number + // is there to limit the number of disk lookups. + maxAccessListLookups = 1024 + // maxTrieNodeTimeSpent is the maximum time we should spend on looking up trie nodes. // If we spend too much time, then it's a fairly high chance of timing out // at the remote side, which means all the work is in vain. @@ -317,6 +321,17 @@ func HandleMessage(backend Backend, peer *Peer) error { return backend.Handle(peer, &TrieNodesPacket{res.ID, nodes}) + case msg.Code == GetAccessListsMsg: + var req GetAccessListsPacket + if err := msg.Decode(&req); err != nil { + return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) + } + bals := ServiceGetAccessListsQuery(backend.Chain(), &req) + return p2p.Send(peer.rw, AccessListsMsg, &AccessListsPacket{ + ID: req.ID, + AccessLists: bals, + }) + default: return fmt.Errorf("%w: %v", errInvalidMsgCode, msg.Code) } @@ -654,6 +669,35 @@ func ServiceGetTrieNodesQuery(chain *core.BlockChain, req *GetTrieNodesPacket, s return nodes, nil } +// ServiceGetAccessListsQuery assembles the response to an access list query. +// It is exposed to allow external packages to test protocol behavior. +func ServiceGetAccessListsQuery(chain *core.BlockChain, req *GetAccessListsPacket) []rlp.RawValue { + if req.Bytes > softResponseLimit { + req.Bytes = softResponseLimit + } + // Cap the number of lookups + if len(req.Hashes) > maxAccessListLookups { + req.Hashes = req.Hashes[:maxAccessListLookups] + } + var ( + bals []rlp.RawValue + bytes uint64 + ) + for _, hash := range req.Hashes { + if bal := chain.GetAccessListRLP(hash); len(bal) > 0 { + bals = append(bals, bal) + bytes += uint64(len(bal)) + } else { + // Either the block is unknown or the BAL doesn't exist + bals = append(bals, nil) + } + if bytes > req.Bytes { + break + } + } + return bals +} + func nextBytes(it *rlp.Iterator) []byte { if !it.Next() { return nil diff --git a/eth/protocols/snap/handler_fuzzing_test.go b/eth/protocols/snap/handler_fuzzing_test.go index 4930ae9ae6..a52da0aac5 100644 --- a/eth/protocols/snap/handler_fuzzing_test.go +++ b/eth/protocols/snap/handler_fuzzing_test.go @@ -60,6 +60,12 @@ func FuzzTrieNodes(f *testing.F) { }) } +func FuzzAccessLists(f *testing.F) { + f.Fuzz(func(t *testing.T, data []byte) { + doFuzz(data, &GetAccessListsPacket{}, GetAccessListsMsg) + }) +} + func doFuzz(input []byte, obj interface{}, code int) { bc := getChain() defer bc.Stop() diff --git a/eth/protocols/snap/handler_test.go b/eth/protocols/snap/handler_test.go new file mode 100644 index 0000000000..53b22ca7f7 --- /dev/null +++ b/eth/protocols/snap/handler_test.go @@ -0,0 +1,199 @@ +// Copyright 2026 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package snap + +import ( + "bytes" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/consensus/ethash" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/rlp" +) + +// getChainWithBALs creates a minimal test chain with BALs stored for each block. +// It returns the chain, block hashes, and the stored BAL data. +func getChainWithBALs(nBlocks int, balSize int) (*core.BlockChain, []common.Hash, []rlp.RawValue) { + gspec := &core.Genesis{ + Config: params.TestChainConfig, + } + db := rawdb.NewMemoryDatabase() + _, blocks, _ := core.GenerateChainWithGenesis(gspec, ethash.NewFaker(), nBlocks, func(i int, gen *core.BlockGen) {}) + options := &core.BlockChainConfig{ + TrieCleanLimit: 0, + TrieDirtyLimit: 0, + TrieTimeLimit: 5 * time.Minute, + NoPrefetch: true, + SnapshotLimit: 0, + } + bc, err := core.NewBlockChain(db, gspec, ethash.NewFaker(), options) + if err != nil { + panic(err) + } + if _, err := bc.InsertChain(blocks); err != nil { + panic(err) + } + + // Store BALs for each block + var hashes []common.Hash + var bals []rlp.RawValue + for _, block := range blocks { + hash := block.Hash() + number := block.NumberU64() + bal := make(rlp.RawValue, balSize) + + // Fill with data based on block number + for j := range bal { + bal[j] = byte(number + uint64(j)) + } + rawdb.WriteAccessListRLP(db, hash, number, bal) + hashes = append(hashes, hash) + bals = append(bals, bal) + } + return bc, hashes, bals +} + +// TestServiceGetAccessListsQuery verifies that known block hashes return the +// correct BALs with positional correspondence. +func TestServiceGetAccessListsQuery(t *testing.T) { + t.Parallel() + bc, hashes, bals := getChainWithBALs(5, 100) + defer bc.Stop() + req := &GetAccessListsPacket{ + ID: 1, + Hashes: hashes, + Bytes: softResponseLimit, + } + result := ServiceGetAccessListsQuery(bc, req) + + // Verify the results + if len(result) != len(hashes) { + t.Fatalf("expected %d results, got %d", len(hashes), len(result)) + } + for i, bal := range result { + if !bytes.Equal(bal, bals[i]) { + t.Errorf("BAL %d mismatch: got %x, want %x", i, bal, bals[i]) + } + } +} + +// TestServiceGetAccessListsQueryEmpty verifies that unknown block hashes return +// nil placeholders and that mixed known/unknown hashes preserve alignment. +func TestServiceGetAccessListsQueryEmpty(t *testing.T) { + t.Parallel() + bc, hashes, bals := getChainWithBALs(3, 100) + defer bc.Stop() + unknown := common.HexToHash("0xdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef") + mixed := []common.Hash{hashes[0], unknown, hashes[1], unknown, hashes[2]} + req := &GetAccessListsPacket{ + ID: 2, + Hashes: mixed, + Bytes: softResponseLimit, + } + result := ServiceGetAccessListsQuery(bc, req) + + // Verify length + if len(result) != len(mixed) { + t.Fatalf("expected %d results, got %d", len(mixed), len(result)) + } + + // Check positional correspondence + if !bytes.Equal(result[0], bals[0]) { + t.Errorf("index 0: expected known BAL, got %x", result[0]) + } + if result[1] != nil { + t.Errorf("index 1: expected nil for unknown hash, got %x", result[1]) + } + if !bytes.Equal(result[2], bals[1]) { + t.Errorf("index 2: expected known BAL, got %x", result[2]) + } + if result[3] != nil { + t.Errorf("index 3: expected nil for unknown hash, got %x", result[3]) + } + if !bytes.Equal(result[4], bals[2]) { + t.Errorf("index 4: expected known BAL, got %x", result[4]) + } +} + +// TestServiceGetAccessListsQueryCap verifies that requests exceeding +// maxAccessListLookups are capped. +func TestServiceGetAccessListsQueryCap(t *testing.T) { + t.Parallel() + + bc, _, _ := getChainWithBALs(2, 100) + defer bc.Stop() + + // Create a request with more hashes than the cap + hashes := make([]common.Hash, maxAccessListLookups+100) + for i := range hashes { + hashes[i] = common.BytesToHash([]byte{byte(i), byte(i >> 8)}) + } + req := &GetAccessListsPacket{ + ID: 3, + Hashes: hashes, + Bytes: softResponseLimit, + } + result := ServiceGetAccessListsQuery(bc, req) + + // Can't get more than maxAccessListLookups results + if len(result) > maxAccessListLookups { + t.Fatalf("expected at most %d results, got %d", maxAccessListLookups, len(result)) + } +} + +// TestServiceGetAccessListsQueryByteLimit verifies that the response stops +// once the byte limit is exceeded. The handler appends the entry that crosses +// the limit before breaking, so the total size will exceed the limit by at +// most one BAL. +func TestServiceGetAccessListsQueryByteLimit(t *testing.T) { + t.Parallel() + + // The handler will return 3/5 entries (3MB total) then break. + balSize := 1024 * 1024 + nBlocks := 5 + bc, hashes, _ := getChainWithBALs(nBlocks, balSize) + defer bc.Stop() + req := &GetAccessListsPacket{ + ID: 0, + Hashes: hashes, + Bytes: softResponseLimit, + } + result := ServiceGetAccessListsQuery(bc, req) + + // Should have stopped before returning all blocks + if len(result) >= nBlocks { + t.Fatalf("expected fewer than %d results due to byte limit, got %d", nBlocks, len(result)) + } + + // Should have returned at least one + if len(result) == 0 { + t.Fatal("expected at least one result") + } + + // The total size should exceed the limit (the entry that crosses it is included) + var total uint64 + for _, bal := range result { + total += uint64(len(bal)) + } + if total <= softResponseLimit { + t.Errorf("total response size %d should exceed soft limit %d (includes one entry past limit)", total, softResponseLimit) + } +} diff --git a/eth/protocols/snap/protocol.go b/eth/protocols/snap/protocol.go index 25fe25822b..9c51f9ddfb 100644 --- a/eth/protocols/snap/protocol.go +++ b/eth/protocols/snap/protocol.go @@ -28,6 +28,7 @@ import ( // Constants to match up protocol versions and messages const ( SNAP1 = 1 + SNAP2 = 2 ) // ProtocolName is the official short name of the `snap` protocol used during @@ -36,11 +37,11 @@ const ProtocolName = "snap" // ProtocolVersions are the supported versions of the `snap` protocol (first // is primary). -var ProtocolVersions = []uint{SNAP1} +var ProtocolVersions = []uint{SNAP2, SNAP1} // protocolLengths are the number of implemented message corresponding to // different protocol versions. -var protocolLengths = map[uint]uint64{SNAP1: 8} +var protocolLengths = map[uint]uint64{SNAP2: 10, SNAP1: 8} // maxMessageSize is the maximum cap on the size of a protocol message. const maxMessageSize = 10 * 1024 * 1024 @@ -54,6 +55,8 @@ const ( ByteCodesMsg = 0x05 GetTrieNodesMsg = 0x06 TrieNodesMsg = 0x07 + GetAccessListsMsg = 0x08 + AccessListsMsg = 0x09 ) var ( @@ -215,6 +218,21 @@ type TrieNodesPacket struct { Nodes [][]byte // Requested state trie nodes } +// GetAccessListsPacket requests BALs for a set of block hashes. +type GetAccessListsPacket struct { + ID uint64 // Request ID to match up responses with + Hashes []common.Hash // Block hashes to retrieve BALs for + Bytes uint64 // Soft limit at which to stop returning data +} + +// AccessListsPacket is the response to GetAccessListsPacket. +// Each entry corresponds to the requested hash at the same index. +// Empty entries indicate the BAL is unavailable. +type AccessListsPacket struct { + ID uint64 // ID of the request this is a response for + AccessLists []rlp.RawValue // Requested BALs +} + func (*GetAccountRangePacket) Name() string { return "GetAccountRange" } func (*GetAccountRangePacket) Kind() byte { return GetAccountRangeMsg } @@ -238,3 +256,9 @@ func (*GetTrieNodesPacket) Kind() byte { return GetTrieNodesMsg } func (*TrieNodesPacket) Name() string { return "TrieNodes" } func (*TrieNodesPacket) Kind() byte { return TrieNodesMsg } + +func (*GetAccessListsPacket) Name() string { return "GetAccessLists" } +func (*GetAccessListsPacket) Kind() byte { return GetAccessListsMsg } + +func (*AccessListsPacket) Name() string { return "AccessLists" } +func (*AccessListsPacket) Kind() byte { return AccessListsMsg }