mirror of
https://github.com/ethereum/go-ethereum.git
synced 2026-06-12 09:51:36 +00:00
core, eth/protocols/snap: Snap/2 Protocol + BAL Serving
This commit is contained in:
parent
dc3794e3dc
commit
15144e391e
5 changed files with 283 additions and 2 deletions
|
|
@ -296,6 +296,14 @@ func (bc *BlockChain) GetReceiptsRLP(hash common.Hash) rlp.RawValue {
|
|||
return rawdb.ReadReceiptsRLP(bc.db, hash, number)
|
||||
}
|
||||
|
||||
func (bc *BlockChain) GetAccessListRLP(hash common.Hash) rlp.RawValue {
|
||||
number, ok := rawdb.ReadHeaderNumber(bc.db, hash)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return rawdb.ReadAccessListRLP(bc.db, hash, number)
|
||||
}
|
||||
|
||||
// GetUnclesInChain retrieves all the uncles from a given block backwards until
|
||||
// a specific distance is reached.
|
||||
func (bc *BlockChain) GetUnclesInChain(block *types.Block, length int) []*types.Header {
|
||||
|
|
|
|||
|
|
@ -55,6 +55,10 @@ const (
|
|||
// number is there to limit the number of disk lookups.
|
||||
maxTrieNodeLookups = 1024
|
||||
|
||||
// maxAccessListLookups is the maximum number of BALs to server. This number
|
||||
// is there to limit the number of disk lookups.
|
||||
maxAccessListLookups = 1024
|
||||
|
||||
// maxTrieNodeTimeSpent is the maximum time we should spend on looking up trie nodes.
|
||||
// If we spend too much time, then it's a fairly high chance of timing out
|
||||
// at the remote side, which means all the work is in vain.
|
||||
|
|
@ -317,6 +321,17 @@ func HandleMessage(backend Backend, peer *Peer) error {
|
|||
|
||||
return backend.Handle(peer, &TrieNodesPacket{res.ID, nodes})
|
||||
|
||||
case msg.Code == GetAccessListsMsg:
|
||||
var req GetAccessListsPacket
|
||||
if err := msg.Decode(&req); err != nil {
|
||||
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
|
||||
}
|
||||
bals := ServiceGetAccessListsQuery(backend.Chain(), &req)
|
||||
return p2p.Send(peer.rw, AccessListsMsg, &AccessListsPacket{
|
||||
ID: req.ID,
|
||||
AccessLists: bals,
|
||||
})
|
||||
|
||||
default:
|
||||
return fmt.Errorf("%w: %v", errInvalidMsgCode, msg.Code)
|
||||
}
|
||||
|
|
@ -654,6 +669,35 @@ func ServiceGetTrieNodesQuery(chain *core.BlockChain, req *GetTrieNodesPacket, s
|
|||
return nodes, nil
|
||||
}
|
||||
|
||||
// ServiceGetAccessListsQuery assembles the response to an access list query.
|
||||
// It is exposed to allow external packages to test protocol behavior.
|
||||
func ServiceGetAccessListsQuery(chain *core.BlockChain, req *GetAccessListsPacket) []rlp.RawValue {
|
||||
if req.Bytes > softResponseLimit {
|
||||
req.Bytes = softResponseLimit
|
||||
}
|
||||
// Cap the number of lookups
|
||||
if len(req.Hashes) > maxAccessListLookups {
|
||||
req.Hashes = req.Hashes[:maxAccessListLookups]
|
||||
}
|
||||
var (
|
||||
bals []rlp.RawValue
|
||||
bytes uint64
|
||||
)
|
||||
for _, hash := range req.Hashes {
|
||||
if bal := chain.GetAccessListRLP(hash); len(bal) > 0 {
|
||||
bals = append(bals, bal)
|
||||
bytes += uint64(len(bal))
|
||||
} else {
|
||||
// Either the block is unknown or the BAL doesn't exist
|
||||
bals = append(bals, nil)
|
||||
}
|
||||
if bytes > req.Bytes {
|
||||
break
|
||||
}
|
||||
}
|
||||
return bals
|
||||
}
|
||||
|
||||
func nextBytes(it *rlp.Iterator) []byte {
|
||||
if !it.Next() {
|
||||
return nil
|
||||
|
|
|
|||
|
|
@ -60,6 +60,12 @@ func FuzzTrieNodes(f *testing.F) {
|
|||
})
|
||||
}
|
||||
|
||||
func FuzzAccessLists(f *testing.F) {
|
||||
f.Fuzz(func(t *testing.T, data []byte) {
|
||||
doFuzz(data, &GetAccessListsPacket{}, GetAccessListsMsg)
|
||||
})
|
||||
}
|
||||
|
||||
func doFuzz(input []byte, obj interface{}, code int) {
|
||||
bc := getChain()
|
||||
defer bc.Stop()
|
||||
|
|
|
|||
199
eth/protocols/snap/handler_test.go
Normal file
199
eth/protocols/snap/handler_test.go
Normal file
|
|
@ -0,0 +1,199 @@
|
|||
// Copyright 2026 The go-ethereum Authors
|
||||
// This file is part of the go-ethereum library.
|
||||
//
|
||||
// The go-ethereum library is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Lesser General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// The go-ethereum library is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Lesser General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Lesser General Public License
|
||||
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package snap
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
"github.com/ethereum/go-ethereum/consensus/ethash"
|
||||
"github.com/ethereum/go-ethereum/core"
|
||||
"github.com/ethereum/go-ethereum/core/rawdb"
|
||||
"github.com/ethereum/go-ethereum/params"
|
||||
"github.com/ethereum/go-ethereum/rlp"
|
||||
)
|
||||
|
||||
// getChainWithBALs creates a minimal test chain with BALs stored for each block.
|
||||
// It returns the chain, block hashes, and the stored BAL data.
|
||||
func getChainWithBALs(nBlocks int, balSize int) (*core.BlockChain, []common.Hash, []rlp.RawValue) {
|
||||
gspec := &core.Genesis{
|
||||
Config: params.TestChainConfig,
|
||||
}
|
||||
db := rawdb.NewMemoryDatabase()
|
||||
_, blocks, _ := core.GenerateChainWithGenesis(gspec, ethash.NewFaker(), nBlocks, func(i int, gen *core.BlockGen) {})
|
||||
options := &core.BlockChainConfig{
|
||||
TrieCleanLimit: 0,
|
||||
TrieDirtyLimit: 0,
|
||||
TrieTimeLimit: 5 * time.Minute,
|
||||
NoPrefetch: true,
|
||||
SnapshotLimit: 0,
|
||||
}
|
||||
bc, err := core.NewBlockChain(db, gspec, ethash.NewFaker(), options)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if _, err := bc.InsertChain(blocks); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Store BALs for each block
|
||||
var hashes []common.Hash
|
||||
var bals []rlp.RawValue
|
||||
for _, block := range blocks {
|
||||
hash := block.Hash()
|
||||
number := block.NumberU64()
|
||||
bal := make(rlp.RawValue, balSize)
|
||||
|
||||
// Fill with data based on block number
|
||||
for j := range bal {
|
||||
bal[j] = byte(number + uint64(j))
|
||||
}
|
||||
rawdb.WriteAccessListRLP(db, hash, number, bal)
|
||||
hashes = append(hashes, hash)
|
||||
bals = append(bals, bal)
|
||||
}
|
||||
return bc, hashes, bals
|
||||
}
|
||||
|
||||
// TestServiceGetAccessListsQuery verifies that known block hashes return the
|
||||
// correct BALs with positional correspondence.
|
||||
func TestServiceGetAccessListsQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
bc, hashes, bals := getChainWithBALs(5, 100)
|
||||
defer bc.Stop()
|
||||
req := &GetAccessListsPacket{
|
||||
ID: 1,
|
||||
Hashes: hashes,
|
||||
Bytes: softResponseLimit,
|
||||
}
|
||||
result := ServiceGetAccessListsQuery(bc, req)
|
||||
|
||||
// Verify the results
|
||||
if len(result) != len(hashes) {
|
||||
t.Fatalf("expected %d results, got %d", len(hashes), len(result))
|
||||
}
|
||||
for i, bal := range result {
|
||||
if !bytes.Equal(bal, bals[i]) {
|
||||
t.Errorf("BAL %d mismatch: got %x, want %x", i, bal, bals[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestServiceGetAccessListsQueryEmpty verifies that unknown block hashes return
|
||||
// nil placeholders and that mixed known/unknown hashes preserve alignment.
|
||||
func TestServiceGetAccessListsQueryEmpty(t *testing.T) {
|
||||
t.Parallel()
|
||||
bc, hashes, bals := getChainWithBALs(3, 100)
|
||||
defer bc.Stop()
|
||||
unknown := common.HexToHash("0xdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef")
|
||||
mixed := []common.Hash{hashes[0], unknown, hashes[1], unknown, hashes[2]}
|
||||
req := &GetAccessListsPacket{
|
||||
ID: 2,
|
||||
Hashes: mixed,
|
||||
Bytes: softResponseLimit,
|
||||
}
|
||||
result := ServiceGetAccessListsQuery(bc, req)
|
||||
|
||||
// Verify length
|
||||
if len(result) != len(mixed) {
|
||||
t.Fatalf("expected %d results, got %d", len(mixed), len(result))
|
||||
}
|
||||
|
||||
// Check positional correspondence
|
||||
if !bytes.Equal(result[0], bals[0]) {
|
||||
t.Errorf("index 0: expected known BAL, got %x", result[0])
|
||||
}
|
||||
if result[1] != nil {
|
||||
t.Errorf("index 1: expected nil for unknown hash, got %x", result[1])
|
||||
}
|
||||
if !bytes.Equal(result[2], bals[1]) {
|
||||
t.Errorf("index 2: expected known BAL, got %x", result[2])
|
||||
}
|
||||
if result[3] != nil {
|
||||
t.Errorf("index 3: expected nil for unknown hash, got %x", result[3])
|
||||
}
|
||||
if !bytes.Equal(result[4], bals[2]) {
|
||||
t.Errorf("index 4: expected known BAL, got %x", result[4])
|
||||
}
|
||||
}
|
||||
|
||||
// TestServiceGetAccessListsQueryCap verifies that requests exceeding
|
||||
// maxAccessListLookups are capped.
|
||||
func TestServiceGetAccessListsQueryCap(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
bc, _, _ := getChainWithBALs(2, 100)
|
||||
defer bc.Stop()
|
||||
|
||||
// Create a request with more hashes than the cap
|
||||
hashes := make([]common.Hash, maxAccessListLookups+100)
|
||||
for i := range hashes {
|
||||
hashes[i] = common.BytesToHash([]byte{byte(i), byte(i >> 8)})
|
||||
}
|
||||
req := &GetAccessListsPacket{
|
||||
ID: 3,
|
||||
Hashes: hashes,
|
||||
Bytes: softResponseLimit,
|
||||
}
|
||||
result := ServiceGetAccessListsQuery(bc, req)
|
||||
|
||||
// Can't get more than maxAccessListLookups results
|
||||
if len(result) > maxAccessListLookups {
|
||||
t.Fatalf("expected at most %d results, got %d", maxAccessListLookups, len(result))
|
||||
}
|
||||
}
|
||||
|
||||
// TestServiceGetAccessListsQueryByteLimit verifies that the response stops
|
||||
// once the byte limit is exceeded. The handler appends the entry that crosses
|
||||
// the limit before breaking, so the total size will exceed the limit by at
|
||||
// most one BAL.
|
||||
func TestServiceGetAccessListsQueryByteLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// The handler will return 3/5 entries (3MB total) then break.
|
||||
balSize := 1024 * 1024
|
||||
nBlocks := 5
|
||||
bc, hashes, _ := getChainWithBALs(nBlocks, balSize)
|
||||
defer bc.Stop()
|
||||
req := &GetAccessListsPacket{
|
||||
ID: 0,
|
||||
Hashes: hashes,
|
||||
Bytes: softResponseLimit,
|
||||
}
|
||||
result := ServiceGetAccessListsQuery(bc, req)
|
||||
|
||||
// Should have stopped before returning all blocks
|
||||
if len(result) >= nBlocks {
|
||||
t.Fatalf("expected fewer than %d results due to byte limit, got %d", nBlocks, len(result))
|
||||
}
|
||||
|
||||
// Should have returned at least one
|
||||
if len(result) == 0 {
|
||||
t.Fatal("expected at least one result")
|
||||
}
|
||||
|
||||
// The total size should exceed the limit (the entry that crosses it is included)
|
||||
var total uint64
|
||||
for _, bal := range result {
|
||||
total += uint64(len(bal))
|
||||
}
|
||||
if total <= softResponseLimit {
|
||||
t.Errorf("total response size %d should exceed soft limit %d (includes one entry past limit)", total, softResponseLimit)
|
||||
}
|
||||
}
|
||||
|
|
@ -28,6 +28,7 @@ import (
|
|||
// Constants to match up protocol versions and messages
|
||||
const (
|
||||
SNAP1 = 1
|
||||
SNAP2 = 2
|
||||
)
|
||||
|
||||
// ProtocolName is the official short name of the `snap` protocol used during
|
||||
|
|
@ -36,11 +37,11 @@ const ProtocolName = "snap"
|
|||
|
||||
// ProtocolVersions are the supported versions of the `snap` protocol (first
|
||||
// is primary).
|
||||
var ProtocolVersions = []uint{SNAP1}
|
||||
var ProtocolVersions = []uint{SNAP2, SNAP1}
|
||||
|
||||
// protocolLengths are the number of implemented message corresponding to
|
||||
// different protocol versions.
|
||||
var protocolLengths = map[uint]uint64{SNAP1: 8}
|
||||
var protocolLengths = map[uint]uint64{SNAP2: 10, SNAP1: 8}
|
||||
|
||||
// maxMessageSize is the maximum cap on the size of a protocol message.
|
||||
const maxMessageSize = 10 * 1024 * 1024
|
||||
|
|
@ -54,6 +55,8 @@ const (
|
|||
ByteCodesMsg = 0x05
|
||||
GetTrieNodesMsg = 0x06
|
||||
TrieNodesMsg = 0x07
|
||||
GetAccessListsMsg = 0x08
|
||||
AccessListsMsg = 0x09
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
@ -215,6 +218,21 @@ type TrieNodesPacket struct {
|
|||
Nodes [][]byte // Requested state trie nodes
|
||||
}
|
||||
|
||||
// GetAccessListsPacket requests BALs for a set of block hashes.
|
||||
type GetAccessListsPacket struct {
|
||||
ID uint64 // Request ID to match up responses with
|
||||
Hashes []common.Hash // Block hashes to retrieve BALs for
|
||||
Bytes uint64 // Soft limit at which to stop returning data
|
||||
}
|
||||
|
||||
// AccessListsPacket is the response to GetAccessListsPacket.
|
||||
// Each entry corresponds to the requested hash at the same index.
|
||||
// Empty entries indicate the BAL is unavailable.
|
||||
type AccessListsPacket struct {
|
||||
ID uint64 // ID of the request this is a response for
|
||||
AccessLists []rlp.RawValue // Requested BALs
|
||||
}
|
||||
|
||||
func (*GetAccountRangePacket) Name() string { return "GetAccountRange" }
|
||||
func (*GetAccountRangePacket) Kind() byte { return GetAccountRangeMsg }
|
||||
|
||||
|
|
@ -238,3 +256,9 @@ func (*GetTrieNodesPacket) Kind() byte { return GetTrieNodesMsg }
|
|||
|
||||
func (*TrieNodesPacket) Name() string { return "TrieNodes" }
|
||||
func (*TrieNodesPacket) Kind() byte { return TrieNodesMsg }
|
||||
|
||||
func (*GetAccessListsPacket) Name() string { return "GetAccessLists" }
|
||||
func (*GetAccessListsPacket) Kind() byte { return GetAccessListsMsg }
|
||||
|
||||
func (*AccessListsPacket) Name() string { return "AccessLists" }
|
||||
func (*AccessListsPacket) Kind() byte { return AccessListsMsg }
|
||||
|
|
|
|||
Loading…
Reference in a new issue