mirror of
https://github.com/ethereum/go-ethereum.git
synced 2026-04-08 02:47:30 +00:00
eth/protocols/snap: fix block accessList encoding rule (#34644)
This PR refactors the encoding rules for `AccessListsPacket` in the wire protocol. Specifically: - The response is now encoded as a list of `rlp.RawValue` - `rlp.EmptyString` is used as a placeholder for unavailable BAL objects
This commit is contained in:
parent
bd6530a1d4
commit
b5d322000c
3 changed files with 175 additions and 56 deletions
|
|
@ -18,33 +18,48 @@ package snap
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
"github.com/ethereum/go-ethereum/consensus/beacon"
|
||||
"github.com/ethereum/go-ethereum/consensus/ethash"
|
||||
"github.com/ethereum/go-ethereum/core"
|
||||
"github.com/ethereum/go-ethereum/core/rawdb"
|
||||
"github.com/ethereum/go-ethereum/core/types/bal"
|
||||
"github.com/ethereum/go-ethereum/params"
|
||||
"github.com/ethereum/go-ethereum/rlp"
|
||||
)
|
||||
|
||||
func makeTestBAL(minSize int) *bal.BlockAccessList {
|
||||
n := minSize/33 + 1 // 33 bytes per storage read slot in RLP
|
||||
access := bal.AccountAccess{
|
||||
Address: common.HexToAddress("0x01"),
|
||||
StorageReads: make([][32]byte, n),
|
||||
}
|
||||
for i := range access.StorageReads {
|
||||
binary.BigEndian.PutUint64(access.StorageReads[i][24:], uint64(i))
|
||||
}
|
||||
return &bal.BlockAccessList{Accesses: []bal.AccountAccess{access}}
|
||||
}
|
||||
|
||||
// getChainWithBALs creates a minimal test chain with BALs stored for each block.
|
||||
// It returns the chain, block hashes, and the stored BAL data.
|
||||
func getChainWithBALs(nBlocks int, balSize int) (*core.BlockChain, []common.Hash, []rlp.RawValue) {
|
||||
gspec := &core.Genesis{
|
||||
Config: params.TestChainConfig,
|
||||
Config: params.MergedTestChainConfig,
|
||||
}
|
||||
db := rawdb.NewMemoryDatabase()
|
||||
_, blocks, _ := core.GenerateChainWithGenesis(gspec, ethash.NewFaker(), nBlocks, func(i int, gen *core.BlockGen) {})
|
||||
engine := beacon.New(ethash.NewFaker())
|
||||
_, blocks, _ := core.GenerateChainWithGenesis(gspec, engine, nBlocks, func(i int, gen *core.BlockGen) {})
|
||||
options := &core.BlockChainConfig{
|
||||
TrieCleanLimit: 0,
|
||||
TrieDirtyLimit: 0,
|
||||
TrieTimeLimit: 5 * time.Minute,
|
||||
NoPrefetch: true,
|
||||
SnapshotLimit: 0,
|
||||
StateScheme: rawdb.PathScheme,
|
||||
TrieTimeLimit: 5 * time.Minute,
|
||||
NoPrefetch: true,
|
||||
}
|
||||
bc, err := core.NewBlockChain(db, gspec, ethash.NewFaker(), options)
|
||||
bc, err := core.NewBlockChain(db, gspec, engine, options)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
|
@ -53,20 +68,22 @@ func getChainWithBALs(nBlocks int, balSize int) (*core.BlockChain, []common.Hash
|
|||
}
|
||||
|
||||
// Store BALs for each block
|
||||
var hashes []common.Hash
|
||||
var bals []rlp.RawValue
|
||||
var (
|
||||
hashes []common.Hash
|
||||
bals []rlp.RawValue
|
||||
)
|
||||
for _, block := range blocks {
|
||||
hash := block.Hash()
|
||||
number := block.NumberU64()
|
||||
bal := make(rlp.RawValue, balSize)
|
||||
|
||||
// Fill with data based on block number
|
||||
for j := range bal {
|
||||
bal[j] = byte(number + uint64(j))
|
||||
bytes, err := rlp.EncodeToBytes(makeTestBAL(balSize))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
rawdb.WriteAccessListRLP(db, hash, number, bal)
|
||||
rawdb.WriteAccessListRLP(db, hash, number, bytes)
|
||||
hashes = append(hashes, hash)
|
||||
bals = append(bals, bal)
|
||||
bals = append(bals, bytes)
|
||||
}
|
||||
return bc, hashes, bals
|
||||
}
|
||||
|
|
@ -85,13 +102,18 @@ func TestServiceGetAccessListsQuery(t *testing.T) {
|
|||
result := ServiceGetAccessListsQuery(bc, req)
|
||||
|
||||
// Verify the results
|
||||
if len(result) != len(hashes) {
|
||||
t.Fatalf("expected %d results, got %d", len(hashes), len(result))
|
||||
if result.Len() != len(hashes) {
|
||||
t.Fatalf("expected %d results, got %d", len(hashes), result.Len())
|
||||
}
|
||||
for i, bal := range result {
|
||||
if !bytes.Equal(bal, bals[i]) {
|
||||
t.Errorf("BAL %d mismatch: got %x, want %x", i, bal, bals[i])
|
||||
var (
|
||||
index int
|
||||
it = result.ContentIterator()
|
||||
)
|
||||
for it.Next() {
|
||||
if !bytes.Equal(it.Value(), bals[index]) {
|
||||
t.Errorf("BAL %d mismatch: got %x, want %x", index, it.Value(), bals[index])
|
||||
}
|
||||
index++
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -111,25 +133,23 @@ func TestServiceGetAccessListsQueryEmpty(t *testing.T) {
|
|||
result := ServiceGetAccessListsQuery(bc, req)
|
||||
|
||||
// Verify length
|
||||
if len(result) != len(mixed) {
|
||||
t.Fatalf("expected %d results, got %d", len(mixed), len(result))
|
||||
if result.Len() != len(mixed) {
|
||||
t.Fatalf("expected %d results, got %d", len(mixed), result.Len())
|
||||
}
|
||||
|
||||
// Check positional correspondence
|
||||
if !bytes.Equal(result[0], bals[0]) {
|
||||
t.Errorf("index 0: expected known BAL, got %x", result[0])
|
||||
var expectVal = []rlp.RawValue{
|
||||
bals[0], rlp.EmptyString, bals[1], rlp.EmptyString, bals[2],
|
||||
}
|
||||
if result[1] != nil {
|
||||
t.Errorf("index 1: expected nil for unknown hash, got %x", result[1])
|
||||
}
|
||||
if !bytes.Equal(result[2], bals[1]) {
|
||||
t.Errorf("index 2: expected known BAL, got %x", result[2])
|
||||
}
|
||||
if result[3] != nil {
|
||||
t.Errorf("index 3: expected nil for unknown hash, got %x", result[3])
|
||||
}
|
||||
if !bytes.Equal(result[4], bals[2]) {
|
||||
t.Errorf("index 4: expected known BAL, got %x", result[4])
|
||||
var (
|
||||
index int
|
||||
it = result.ContentIterator()
|
||||
)
|
||||
for it.Next() {
|
||||
if !bytes.Equal(it.Value(), expectVal[index]) {
|
||||
t.Errorf("BAL %d mismatch: got %x, want %x", index, it.Value(), expectVal[index])
|
||||
}
|
||||
index++
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -154,8 +174,8 @@ func TestServiceGetAccessListsQueryCap(t *testing.T) {
|
|||
result := ServiceGetAccessListsQuery(bc, req)
|
||||
|
||||
// Can't get more than maxAccessListLookups results
|
||||
if len(result) > maxAccessListLookups {
|
||||
t.Fatalf("expected at most %d results, got %d", maxAccessListLookups, len(result))
|
||||
if result.Len() > maxAccessListLookups {
|
||||
t.Fatalf("expected at most %d results, got %d", maxAccessListLookups, result.Len())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -179,21 +199,116 @@ func TestServiceGetAccessListsQueryByteLimit(t *testing.T) {
|
|||
result := ServiceGetAccessListsQuery(bc, req)
|
||||
|
||||
// Should have stopped before returning all blocks
|
||||
if len(result) >= nBlocks {
|
||||
t.Fatalf("expected fewer than %d results due to byte limit, got %d", nBlocks, len(result))
|
||||
if result.Len() >= nBlocks {
|
||||
t.Fatalf("expected fewer than %d results due to byte limit, got %d", nBlocks, result.Len())
|
||||
}
|
||||
|
||||
// Should have returned at least one
|
||||
if len(result) == 0 {
|
||||
if result.Len() == 0 {
|
||||
t.Fatal("expected at least one result")
|
||||
}
|
||||
|
||||
// The total size should exceed the limit (the entry that crosses it is included)
|
||||
var total uint64
|
||||
for _, bal := range result {
|
||||
total += uint64(len(bal))
|
||||
}
|
||||
if total <= softResponseLimit {
|
||||
t.Errorf("total response size %d should exceed soft limit %d (includes one entry past limit)", total, softResponseLimit)
|
||||
if result.Size() <= softResponseLimit {
|
||||
t.Errorf("total response size %d should exceed soft limit %d (includes one entry past limit)", result.Size(), softResponseLimit)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetAccessListResponseDecoding verifies that an AccessListsPacket
|
||||
// round-trips through RLP encode/decode, preserving positional
|
||||
// correspondence and correctly representing absent BALs as empty strings.
|
||||
func TestGetAccessListResponseDecoding(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Build two real BALs of different sizes.
|
||||
bal1 := makeTestBAL(100)
|
||||
bal2 := makeTestBAL(200)
|
||||
bytes1, _ := rlp.EncodeToBytes(bal1)
|
||||
bytes2, _ := rlp.EncodeToBytes(bal2)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
items []rlp.RawValue // nil entry = unavailable BAL
|
||||
counts int // expected decoded length
|
||||
}{
|
||||
{
|
||||
name: "all present",
|
||||
items: []rlp.RawValue{bytes1, bytes2},
|
||||
counts: 2,
|
||||
},
|
||||
{
|
||||
name: "all absent",
|
||||
items: []rlp.RawValue{rlp.EmptyString, rlp.EmptyString, rlp.EmptyString},
|
||||
counts: 3,
|
||||
},
|
||||
{
|
||||
name: "mixed present and absent",
|
||||
items: []rlp.RawValue{bytes1, rlp.EmptyString, bytes2, rlp.EmptyString},
|
||||
counts: 4,
|
||||
},
|
||||
{
|
||||
name: "empty response",
|
||||
items: []rlp.RawValue{},
|
||||
counts: 0,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Build the packet using Append.
|
||||
var orig AccessListsPacket
|
||||
orig.ID = 42
|
||||
for _, item := range tt.items {
|
||||
if err := orig.AccessLists.AppendRaw(item); err != nil {
|
||||
t.Fatalf("AppendRaw failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Encode -> Decode round-trip.
|
||||
enc, err := rlp.EncodeToBytes(&orig)
|
||||
if err != nil {
|
||||
t.Fatalf("encode failed: %v", err)
|
||||
}
|
||||
var dec AccessListsPacket
|
||||
if err := rlp.DecodeBytes(enc, &dec); err != nil {
|
||||
t.Fatalf("decode failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify ID preserved.
|
||||
if dec.ID != orig.ID {
|
||||
t.Fatalf("ID mismatch: got %d, want %d", dec.ID, orig.ID)
|
||||
}
|
||||
|
||||
// Verify element count.
|
||||
if dec.AccessLists.Len() != tt.counts {
|
||||
t.Fatalf("length mismatch: got %d, want %d", dec.AccessLists.Len(), tt.counts)
|
||||
}
|
||||
|
||||
// Verify each element positionally.
|
||||
it := dec.AccessLists.ContentIterator()
|
||||
for i, want := range tt.items {
|
||||
if !it.Next() {
|
||||
t.Fatalf("iterator exhausted at index %d", i)
|
||||
}
|
||||
got := it.Value()
|
||||
if !bytes.Equal(got, want) {
|
||||
t.Errorf("element %d: got %x, want %x", i, got, want)
|
||||
}
|
||||
if !bytes.Equal(got, rlp.EmptyString) {
|
||||
obj := new(bal.BlockAccessList)
|
||||
if err := rlp.DecodeBytes(got, obj); err != nil {
|
||||
t.Fatalf("decode failed: %v", err)
|
||||
}
|
||||
if bytes.Equal(got, bytes1) && !reflect.DeepEqual(obj, bal1) {
|
||||
t.Fatalf("decode failed: got %x, want %x", obj, bal1)
|
||||
}
|
||||
if bytes.Equal(got, bytes2) && !reflect.DeepEqual(obj, bal2) {
|
||||
t.Fatalf("decode failed: got %x, want %x", obj, bal2)
|
||||
}
|
||||
}
|
||||
}
|
||||
if it.Next() {
|
||||
t.Error("iterator has extra elements after expected end")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -559,16 +559,15 @@ func handleGetAccessLists(backend Backend, msg Decoder, peer *Peer) error {
|
|||
if err := msg.Decode(&req); err != nil {
|
||||
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
|
||||
}
|
||||
bals := ServiceGetAccessListsQuery(backend.Chain(), &req)
|
||||
return p2p.Send(peer.rw, AccessListsMsg, &AccessListsPacket{
|
||||
ID: req.ID,
|
||||
AccessLists: bals,
|
||||
AccessLists: ServiceGetAccessListsQuery(backend.Chain(), &req),
|
||||
})
|
||||
}
|
||||
|
||||
// ServiceGetAccessListsQuery assembles the response to an access list query.
|
||||
// It is exposed to allow external packages to test protocol behavior.
|
||||
func ServiceGetAccessListsQuery(chain *core.BlockChain, req *GetAccessListsPacket) []rlp.RawValue {
|
||||
func ServiceGetAccessListsQuery(chain *core.BlockChain, req *GetAccessListsPacket) rlp.RawList[rlp.RawValue] {
|
||||
if req.Bytes > softResponseLimit {
|
||||
req.Bytes = softResponseLimit
|
||||
}
|
||||
|
|
@ -577,20 +576,25 @@ func ServiceGetAccessListsQuery(chain *core.BlockChain, req *GetAccessListsPacke
|
|||
req.Hashes = req.Hashes[:maxAccessListLookups]
|
||||
}
|
||||
var (
|
||||
bals []rlp.RawValue
|
||||
bytes uint64
|
||||
err error
|
||||
bytes uint64
|
||||
response = rlp.RawList[rlp.RawValue]{}
|
||||
)
|
||||
for _, hash := range req.Hashes {
|
||||
if bal := chain.GetAccessListRLP(hash); len(bal) > 0 {
|
||||
bals = append(bals, bal)
|
||||
err = response.AppendRaw(bal)
|
||||
bytes += uint64(len(bal))
|
||||
} else {
|
||||
// Either the block is unknown or the BAL doesn't exist
|
||||
bals = append(bals, nil)
|
||||
err = response.AppendRaw(rlp.EmptyString)
|
||||
bytes += 1
|
||||
}
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
if bytes > req.Bytes {
|
||||
break
|
||||
}
|
||||
}
|
||||
return bals
|
||||
return response
|
||||
}
|
||||
|
|
|
|||
|
|
@ -229,8 +229,8 @@ type GetAccessListsPacket struct {
|
|||
// Each entry corresponds to the requested hash at the same index.
|
||||
// Empty entries indicate the BAL is unavailable.
|
||||
type AccessListsPacket struct {
|
||||
ID uint64 // ID of the request this is a response for
|
||||
AccessLists []rlp.RawValue // Requested BALs
|
||||
ID uint64 // ID of the request this is a response for
|
||||
AccessLists rlp.RawList[rlp.RawValue] // Requested BALs
|
||||
}
|
||||
|
||||
func (*GetAccountRangePacket) Name() string { return "GetAccountRange" }
|
||||
|
|
|
|||
Loading…
Reference in a new issue