core/types/bal: improve code validation in BAL

This commit is contained in:
Gary Rong 2026-04-24 10:00:27 +08:00
parent 1b291219f6
commit 2053222852
2 changed files with 23 additions and 12 deletions

View file

@ -78,14 +78,14 @@ func (e *BlockAccessList) DecodeRLP(s *rlp.Stream) error {
// Validate returns an error if the contents of the access list are not ordered // Validate returns an error if the contents of the access list are not ordered
// according to the spec or any code changes are contained which exceed protocol // according to the spec or any code changes are contained which exceed protocol
// max code size. // max code size.
func (e *BlockAccessList) Validate() error { func (e *BlockAccessList) Validate(rules params.Rules) error {
if !slices.IsSortedFunc(*e, func(a, b AccountAccess) int { if !slices.IsSortedFunc(*e, func(a, b AccountAccess) int {
return bytes.Compare(a.Address[:], b.Address[:]) return bytes.Compare(a.Address[:], b.Address[:])
}) { }) {
return errors.New("block access list accounts not in lexicographic order") return errors.New("block access list accounts not in lexicographic order")
} }
for _, entry := range *e { for _, entry := range *e {
if err := entry.validate(); err != nil { if err := entry.validate(rules); err != nil {
return err return err
} }
} }
@ -159,7 +159,7 @@ type AccountAccess struct {
// validate converts the account accesses out of encoding format. // validate converts the account accesses out of encoding format.
// If any of the keys in the encoding object are not ordered according to the // If any of the keys in the encoding object are not ordered according to the
// spec, an error is returned. // spec, an error is returned.
func (e *AccountAccess) validate() error { func (e *AccountAccess) validate(rules params.Rules) error {
// Check the storage write slots are sorted in order // Check the storage write slots are sorted in order
if !slices.IsSortedFunc(e.StorageWrites, func(a, b encodingSlotWrites) int { if !slices.IsSortedFunc(e.StorageWrites, func(a, b encodingSlotWrites) int {
return a.Slot.Cmp(b.Slot) return a.Slot.Cmp(b.Slot)
@ -200,9 +200,14 @@ func (e *AccountAccess) validate() error {
return errors.New("code changes not in ascending order by tx index") return errors.New("code changes not in ascending order by tx index")
} }
for _, change := range e.CodeChanges { for _, change := range e.CodeChanges {
// TODO(rjl493456442): This check should be fork-aware, since the limit may var sizeLimit int
// differ across forks. switch {
if len(change.Code) > params.MaxCodeSize { case rules.IsAmsterdam:
sizeLimit = params.MaxCodeSizeAmsterdam
default:
sizeLimit = params.MaxCodeSize
}
if len(change.Code) > sizeLimit {
return errors.New("code change contained oversized code") return errors.New("code change contained oversized code")
} }
} }
@ -257,8 +262,8 @@ func (b *ConstructionBlockAccessList) EncodeRLP(wr io.Writer) error {
var _ rlp.Encoder = &ConstructionBlockAccessList{} var _ rlp.Encoder = &ConstructionBlockAccessList{}
// toEncodingObj creates an instance of the ConstructionAccountAccess of the type that is // toEncodingObj creates an instance of the ConstructionAccountAccess of the type
// used as input for the encoding. // that is used as input for the encoding.
func (a *ConstructionAccountAccess) toEncodingObj(addr common.Address) AccountAccess { func (a *ConstructionAccountAccess) toEncodingObj(addr common.Address) AccountAccess {
res := AccountAccess{ res := AccountAccess{
Address: addr, Address: addr,
@ -324,7 +329,12 @@ func (a *ConstructionAccountAccess) toEncodingObj(addr common.Address) AccountAc
for _, idx := range codeIndices { for _, idx := range codeIndices {
res.CodeChanges = append(res.CodeChanges, encodingCodeChange{ res.CodeChanges = append(res.CodeChanges, encodingCodeChange{
TxIndex: idx, TxIndex: idx,
Code: a.CodeChange[idx],
// TODO(rjl493456442) the contract code is not deep-copied.
// In theory the deep-copy is unnecessary, the semantics of
// the function should be probably changed that the returned
// AccessList is unsafe for modification.
Code: a.CodeChange[idx],
}) })
} }
return res return res

View file

@ -25,6 +25,7 @@ import (
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/internal/testrand" "github.com/ethereum/go-ethereum/internal/testrand"
"github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
"github.com/holiman/uint256" "github.com/holiman/uint256"
) )
@ -233,7 +234,7 @@ func TestBlockAccessListCopy(t *testing.T) {
func TestBlockAccessListValidation(t *testing.T) { func TestBlockAccessListValidation(t *testing.T) {
// Validate the block access list after RLP decoding // Validate the block access list after RLP decoding
enc := makeTestBAL(true) enc := makeTestBAL(true)
if err := enc.Validate(); err != nil { if err := enc.Validate(params.Rules{}); err != nil {
t.Fatalf("Unexpected validation error: %v", err) t.Fatalf("Unexpected validation error: %v", err)
} }
var buf bytes.Buffer var buf bytes.Buffer
@ -245,14 +246,14 @@ func TestBlockAccessListValidation(t *testing.T) {
if err := dec.DecodeRLP(rlp.NewStream(bytes.NewReader(buf.Bytes()), 0)); err != nil { if err := dec.DecodeRLP(rlp.NewStream(bytes.NewReader(buf.Bytes()), 0)); err != nil {
t.Fatalf("Unexpected RLP-decode error: %v", err) t.Fatalf("Unexpected RLP-decode error: %v", err)
} }
if err := dec.Validate(); err != nil { if err := dec.Validate(params.Rules{}); err != nil {
t.Fatalf("Unexpected validation error: %v", err) t.Fatalf("Unexpected validation error: %v", err)
} }
// Validate the derived block access list // Validate the derived block access list
cBAL := makeTestConstructionBAL() cBAL := makeTestConstructionBAL()
listB := cBAL.toEncodingObj() listB := cBAL.toEncodingObj()
if err := listB.Validate(); err != nil { if err := listB.Validate(params.Rules{}); err != nil {
t.Fatalf("Unexpected validation error: %v", err) t.Fatalf("Unexpected validation error: %v", err)
} }
} }