From b5d322000cc499d802b2b3768a4eb1a1a5dd1f28 Mon Sep 17 00:00:00 2001 From: rjl493456442 Date: Tue, 7 Apr 2026 20:13:19 +0800 Subject: [PATCH] eth/protocols/snap: fix block accessList encoding rule (#34644) This PR refactors the encoding rules for `AccessListsPacket` in the wire protocol. Specifically: - The response is now encoded as a list of `rlp.RawValue` - `rlp.EmptyString` is used as a placeholder for unavailable BAL objects --- eth/protocols/snap/handler_test.go | 207 ++++++++++++++++++++++------- eth/protocols/snap/handlers.go | 20 +-- eth/protocols/snap/protocol.go | 4 +- 3 files changed, 175 insertions(+), 56 deletions(-) diff --git a/eth/protocols/snap/handler_test.go b/eth/protocols/snap/handler_test.go index 53b22ca7f7..3f6a43a059 100644 --- a/eth/protocols/snap/handler_test.go +++ b/eth/protocols/snap/handler_test.go @@ -18,33 +18,48 @@ package snap import ( "bytes" + "encoding/binary" + "reflect" "testing" "time" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/consensus/beacon" "github.com/ethereum/go-ethereum/consensus/ethash" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/types/bal" "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/rlp" ) +func makeTestBAL(minSize int) *bal.BlockAccessList { + n := minSize/33 + 1 // 33 bytes per storage read slot in RLP + access := bal.AccountAccess{ + Address: common.HexToAddress("0x01"), + StorageReads: make([][32]byte, n), + } + for i := range access.StorageReads { + binary.BigEndian.PutUint64(access.StorageReads[i][24:], uint64(i)) + } + return &bal.BlockAccessList{Accesses: []bal.AccountAccess{access}} +} + // getChainWithBALs creates a minimal test chain with BALs stored for each block. // It returns the chain, block hashes, and the stored BAL data. func getChainWithBALs(nBlocks int, balSize int) (*core.BlockChain, []common.Hash, []rlp.RawValue) { gspec := &core.Genesis{ - Config: params.TestChainConfig, + Config: params.MergedTestChainConfig, } db := rawdb.NewMemoryDatabase() - _, blocks, _ := core.GenerateChainWithGenesis(gspec, ethash.NewFaker(), nBlocks, func(i int, gen *core.BlockGen) {}) + engine := beacon.New(ethash.NewFaker()) + _, blocks, _ := core.GenerateChainWithGenesis(gspec, engine, nBlocks, func(i int, gen *core.BlockGen) {}) options := &core.BlockChainConfig{ - TrieCleanLimit: 0, - TrieDirtyLimit: 0, - TrieTimeLimit: 5 * time.Minute, - NoPrefetch: true, - SnapshotLimit: 0, + StateScheme: rawdb.PathScheme, + TrieTimeLimit: 5 * time.Minute, + NoPrefetch: true, } - bc, err := core.NewBlockChain(db, gspec, ethash.NewFaker(), options) + bc, err := core.NewBlockChain(db, gspec, engine, options) if err != nil { panic(err) } @@ -53,20 +68,22 @@ func getChainWithBALs(nBlocks int, balSize int) (*core.BlockChain, []common.Hash } // Store BALs for each block - var hashes []common.Hash - var bals []rlp.RawValue + var ( + hashes []common.Hash + bals []rlp.RawValue + ) for _, block := range blocks { hash := block.Hash() number := block.NumberU64() - bal := make(rlp.RawValue, balSize) // Fill with data based on block number - for j := range bal { - bal[j] = byte(number + uint64(j)) + bytes, err := rlp.EncodeToBytes(makeTestBAL(balSize)) + if err != nil { + panic(err) } - rawdb.WriteAccessListRLP(db, hash, number, bal) + rawdb.WriteAccessListRLP(db, hash, number, bytes) hashes = append(hashes, hash) - bals = append(bals, bal) + bals = append(bals, bytes) } return bc, hashes, bals } @@ -85,13 +102,18 @@ func TestServiceGetAccessListsQuery(t *testing.T) { result := ServiceGetAccessListsQuery(bc, req) // Verify the results - if len(result) != len(hashes) { - t.Fatalf("expected %d results, got %d", len(hashes), len(result)) + if result.Len() != len(hashes) { + t.Fatalf("expected %d results, got %d", len(hashes), result.Len()) } - for i, bal := range result { - if !bytes.Equal(bal, bals[i]) { - t.Errorf("BAL %d mismatch: got %x, want %x", i, bal, bals[i]) + var ( + index int + it = result.ContentIterator() + ) + for it.Next() { + if !bytes.Equal(it.Value(), bals[index]) { + t.Errorf("BAL %d mismatch: got %x, want %x", index, it.Value(), bals[index]) } + index++ } } @@ -111,25 +133,23 @@ func TestServiceGetAccessListsQueryEmpty(t *testing.T) { result := ServiceGetAccessListsQuery(bc, req) // Verify length - if len(result) != len(mixed) { - t.Fatalf("expected %d results, got %d", len(mixed), len(result)) + if result.Len() != len(mixed) { + t.Fatalf("expected %d results, got %d", len(mixed), result.Len()) } // Check positional correspondence - if !bytes.Equal(result[0], bals[0]) { - t.Errorf("index 0: expected known BAL, got %x", result[0]) + var expectVal = []rlp.RawValue{ + bals[0], rlp.EmptyString, bals[1], rlp.EmptyString, bals[2], } - if result[1] != nil { - t.Errorf("index 1: expected nil for unknown hash, got %x", result[1]) - } - if !bytes.Equal(result[2], bals[1]) { - t.Errorf("index 2: expected known BAL, got %x", result[2]) - } - if result[3] != nil { - t.Errorf("index 3: expected nil for unknown hash, got %x", result[3]) - } - if !bytes.Equal(result[4], bals[2]) { - t.Errorf("index 4: expected known BAL, got %x", result[4]) + var ( + index int + it = result.ContentIterator() + ) + for it.Next() { + if !bytes.Equal(it.Value(), expectVal[index]) { + t.Errorf("BAL %d mismatch: got %x, want %x", index, it.Value(), expectVal[index]) + } + index++ } } @@ -154,8 +174,8 @@ func TestServiceGetAccessListsQueryCap(t *testing.T) { result := ServiceGetAccessListsQuery(bc, req) // Can't get more than maxAccessListLookups results - if len(result) > maxAccessListLookups { - t.Fatalf("expected at most %d results, got %d", maxAccessListLookups, len(result)) + if result.Len() > maxAccessListLookups { + t.Fatalf("expected at most %d results, got %d", maxAccessListLookups, result.Len()) } } @@ -179,21 +199,116 @@ func TestServiceGetAccessListsQueryByteLimit(t *testing.T) { result := ServiceGetAccessListsQuery(bc, req) // Should have stopped before returning all blocks - if len(result) >= nBlocks { - t.Fatalf("expected fewer than %d results due to byte limit, got %d", nBlocks, len(result)) + if result.Len() >= nBlocks { + t.Fatalf("expected fewer than %d results due to byte limit, got %d", nBlocks, result.Len()) } // Should have returned at least one - if len(result) == 0 { + if result.Len() == 0 { t.Fatal("expected at least one result") } // The total size should exceed the limit (the entry that crosses it is included) - var total uint64 - for _, bal := range result { - total += uint64(len(bal)) - } - if total <= softResponseLimit { - t.Errorf("total response size %d should exceed soft limit %d (includes one entry past limit)", total, softResponseLimit) + if result.Size() <= softResponseLimit { + t.Errorf("total response size %d should exceed soft limit %d (includes one entry past limit)", result.Size(), softResponseLimit) + } +} + +// TestGetAccessListResponseDecoding verifies that an AccessListsPacket +// round-trips through RLP encode/decode, preserving positional +// correspondence and correctly representing absent BALs as empty strings. +func TestGetAccessListResponseDecoding(t *testing.T) { + t.Parallel() + + // Build two real BALs of different sizes. + bal1 := makeTestBAL(100) + bal2 := makeTestBAL(200) + bytes1, _ := rlp.EncodeToBytes(bal1) + bytes2, _ := rlp.EncodeToBytes(bal2) + + tests := []struct { + name string + items []rlp.RawValue // nil entry = unavailable BAL + counts int // expected decoded length + }{ + { + name: "all present", + items: []rlp.RawValue{bytes1, bytes2}, + counts: 2, + }, + { + name: "all absent", + items: []rlp.RawValue{rlp.EmptyString, rlp.EmptyString, rlp.EmptyString}, + counts: 3, + }, + { + name: "mixed present and absent", + items: []rlp.RawValue{bytes1, rlp.EmptyString, bytes2, rlp.EmptyString}, + counts: 4, + }, + { + name: "empty response", + items: []rlp.RawValue{}, + counts: 0, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Build the packet using Append. + var orig AccessListsPacket + orig.ID = 42 + for _, item := range tt.items { + if err := orig.AccessLists.AppendRaw(item); err != nil { + t.Fatalf("AppendRaw failed: %v", err) + } + } + + // Encode -> Decode round-trip. + enc, err := rlp.EncodeToBytes(&orig) + if err != nil { + t.Fatalf("encode failed: %v", err) + } + var dec AccessListsPacket + if err := rlp.DecodeBytes(enc, &dec); err != nil { + t.Fatalf("decode failed: %v", err) + } + + // Verify ID preserved. + if dec.ID != orig.ID { + t.Fatalf("ID mismatch: got %d, want %d", dec.ID, orig.ID) + } + + // Verify element count. + if dec.AccessLists.Len() != tt.counts { + t.Fatalf("length mismatch: got %d, want %d", dec.AccessLists.Len(), tt.counts) + } + + // Verify each element positionally. + it := dec.AccessLists.ContentIterator() + for i, want := range tt.items { + if !it.Next() { + t.Fatalf("iterator exhausted at index %d", i) + } + got := it.Value() + if !bytes.Equal(got, want) { + t.Errorf("element %d: got %x, want %x", i, got, want) + } + if !bytes.Equal(got, rlp.EmptyString) { + obj := new(bal.BlockAccessList) + if err := rlp.DecodeBytes(got, obj); err != nil { + t.Fatalf("decode failed: %v", err) + } + if bytes.Equal(got, bytes1) && !reflect.DeepEqual(obj, bal1) { + t.Fatalf("decode failed: got %x, want %x", obj, bal1) + } + if bytes.Equal(got, bytes2) && !reflect.DeepEqual(obj, bal2) { + t.Fatalf("decode failed: got %x, want %x", obj, bal2) + } + } + } + if it.Next() { + t.Error("iterator has extra elements after expected end") + } + }) } } diff --git a/eth/protocols/snap/handlers.go b/eth/protocols/snap/handlers.go index 4d60aab1f6..5a5733bdb4 100644 --- a/eth/protocols/snap/handlers.go +++ b/eth/protocols/snap/handlers.go @@ -559,16 +559,15 @@ func handleGetAccessLists(backend Backend, msg Decoder, peer *Peer) error { if err := msg.Decode(&req); err != nil { return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) } - bals := ServiceGetAccessListsQuery(backend.Chain(), &req) return p2p.Send(peer.rw, AccessListsMsg, &AccessListsPacket{ ID: req.ID, - AccessLists: bals, + AccessLists: ServiceGetAccessListsQuery(backend.Chain(), &req), }) } // ServiceGetAccessListsQuery assembles the response to an access list query. // It is exposed to allow external packages to test protocol behavior. -func ServiceGetAccessListsQuery(chain *core.BlockChain, req *GetAccessListsPacket) []rlp.RawValue { +func ServiceGetAccessListsQuery(chain *core.BlockChain, req *GetAccessListsPacket) rlp.RawList[rlp.RawValue] { if req.Bytes > softResponseLimit { req.Bytes = softResponseLimit } @@ -577,20 +576,25 @@ func ServiceGetAccessListsQuery(chain *core.BlockChain, req *GetAccessListsPacke req.Hashes = req.Hashes[:maxAccessListLookups] } var ( - bals []rlp.RawValue - bytes uint64 + err error + bytes uint64 + response = rlp.RawList[rlp.RawValue]{} ) for _, hash := range req.Hashes { if bal := chain.GetAccessListRLP(hash); len(bal) > 0 { - bals = append(bals, bal) + err = response.AppendRaw(bal) bytes += uint64(len(bal)) } else { // Either the block is unknown or the BAL doesn't exist - bals = append(bals, nil) + err = response.AppendRaw(rlp.EmptyString) + bytes += 1 + } + if err != nil { + break } if bytes > req.Bytes { break } } - return bals + return response } diff --git a/eth/protocols/snap/protocol.go b/eth/protocols/snap/protocol.go index 7913f8b053..685f468da3 100644 --- a/eth/protocols/snap/protocol.go +++ b/eth/protocols/snap/protocol.go @@ -229,8 +229,8 @@ type GetAccessListsPacket struct { // Each entry corresponds to the requested hash at the same index. // Empty entries indicate the BAL is unavailable. type AccessListsPacket struct { - ID uint64 // ID of the request this is a response for - AccessLists []rlp.RawValue // Requested BALs + ID uint64 // ID of the request this is a response for + AccessLists rlp.RawList[rlp.RawValue] // Requested BALs } func (*GetAccountRangePacket) Name() string { return "GetAccountRange" }