core: check nilness before validating bal

This commit is contained in:
Gary Rong 2026-05-19 10:50:34 +08:00
parent 285a13c43d
commit ededc0f238
3 changed files with 20 additions and 14 deletions

View file

@ -115,23 +115,22 @@ func (v *BlockValidator) ValidateBody(block *types.Block) error {
// Amsterdam hard fork. // Amsterdam hard fork.
if v.config.IsAmsterdam(block.Number(), block.Time()) { if v.config.IsAmsterdam(block.Number(), block.Time()) {
if block.Header().BlockAccessListHash == nil { if block.Header().BlockAccessListHash == nil {
return fmt.Errorf("block access list hash not set in header") return errors.New("block access list hash not set in header")
} }
// If the block does not come with an access list, we compute the access list // If the block does not include an access list, compute it locally during
// locally as part of execution and validate against the header's access list // execution and validate it against the access list hash in the header.
// hash. //
// If the block includes an attached access list, validate it directly here.
if block.AccessList() != nil { if block.AccessList() != nil {
computed := block.AccessList().Hash() computed := block.AccessList().Hash()
if *block.Header().BlockAccessListHash != computed { if *block.Header().BlockAccessListHash != computed {
return fmt.Errorf("access list hash mismatch, computed: %x, remote: %x", computed, *block.Header().BlockAccessListHash) return fmt.Errorf("access list hash mismatch, computed: %x, remote: %x", computed, *block.Header().BlockAccessListHash)
} else if err := block.AccessList().Validate(); err != nil { } else if err := block.AccessList().Validate(block.GasLimit()); err != nil {
return fmt.Errorf("invalid block access list: %v", err)
} else if err := block.AccessList().ValidateSize(block.GasLimit()); err != nil {
return fmt.Errorf("invalid block access list: %v", err) return fmt.Errorf("invalid block access list: %v", err)
} }
} }
} else if block.Header().BlockAccessListHash != nil || block.AccessList() != nil { } else if block.Header().BlockAccessListHash != nil || block.AccessList() != nil {
return fmt.Errorf("block had access list before Amsterdam") return errors.New("block had access list before Amsterdam")
} }
// Ancestor block must be known. // Ancestor block must be known.
@ -185,12 +184,18 @@ func (v *BlockValidator) ValidateState(block *types.Block, statedb *state.StateD
} }
// Verify Block-level accessList once Amsterdam is enabled // Verify Block-level accessList once Amsterdam is enabled
if v.config.IsAmsterdam(block.Number(), block.Time()) { if v.config.IsAmsterdam(block.Number(), block.Time()) {
if res.Bal == nil {
return errors.New("block access list is not available in amsterdam")
}
if block.Header().BlockAccessListHash == nil {
return errors.New("block access list hash not set in header")
}
enc := res.Bal.ToEncodingObj() enc := res.Bal.ToEncodingObj()
local, remote := enc.Hash(), *block.Header().BlockAccessListHash local, remote := enc.Hash(), *block.Header().BlockAccessListHash
if local != remote { if local != remote {
return fmt.Errorf("access list hash mismatch, local: %x, remote: %x", local, remote) return fmt.Errorf("access list hash mismatch, local: %x, remote: %x", local, remote)
} }
if err := enc.ValidateSize(block.GasLimit()); err != nil { if err := enc.Validate(block.GasLimit()); err != nil {
return fmt.Errorf("invalid block access list: %v", err) return fmt.Errorf("invalid block access list: %v", err)
} }
} }

View file

@ -78,7 +78,7 @@ func (e *BlockAccessList) DecodeRLP(s *rlp.Stream) error {
// Validate returns an error if the contents of the access list are not ordered // Validate returns an error if the contents of the access list are not ordered
// according to the spec or any code changes are contained which exceed protocol // according to the spec or any code changes are contained which exceed protocol
// max code size. // max code size.
func (e *BlockAccessList) Validate() error { func (e *BlockAccessList) Validate(blockGasLimit uint64) error {
if !slices.IsSortedFunc(*e, func(a, b AccountAccess) int { if !slices.IsSortedFunc(*e, func(a, b AccountAccess) int {
return bytes.Compare(a.Address[:], b.Address[:]) return bytes.Compare(a.Address[:], b.Address[:])
}) { }) {
@ -89,7 +89,7 @@ func (e *BlockAccessList) Validate() error {
return err return err
} }
} }
return nil return e.ValidateSize(blockGasLimit)
} }
// itemCount returns the number of items in the BAL for EIP-7928 size-constraint // itemCount returns the number of items in the BAL for EIP-7928 size-constraint

View file

@ -19,6 +19,7 @@ package bal
import ( import (
"bytes" "bytes"
"cmp" "cmp"
"math"
"reflect" "reflect"
"slices" "slices"
"testing" "testing"
@ -357,7 +358,7 @@ func TestBlockAccessListValidateSize(t *testing.T) {
func TestBlockAccessListValidation(t *testing.T) { func TestBlockAccessListValidation(t *testing.T) {
// Validate the block access list after RLP decoding // Validate the block access list after RLP decoding
enc := makeTestBAL(true) enc := makeTestBAL(true)
if err := enc.Validate(); err != nil { if err := enc.Validate(math.MaxUint64); err != nil {
t.Fatalf("Unexpected validation error: %v", err) t.Fatalf("Unexpected validation error: %v", err)
} }
var buf bytes.Buffer var buf bytes.Buffer
@ -369,14 +370,14 @@ func TestBlockAccessListValidation(t *testing.T) {
if err := dec.DecodeRLP(rlp.NewStream(bytes.NewReader(buf.Bytes()), 0)); err != nil { if err := dec.DecodeRLP(rlp.NewStream(bytes.NewReader(buf.Bytes()), 0)); err != nil {
t.Fatalf("Unexpected RLP-decode error: %v", err) t.Fatalf("Unexpected RLP-decode error: %v", err)
} }
if err := dec.Validate(); err != nil { if err := dec.Validate(math.MaxUint64); err != nil {
t.Fatalf("Unexpected validation error: %v", err) t.Fatalf("Unexpected validation error: %v", err)
} }
// Validate the derived block access list // Validate the derived block access list
cBAL := makeTestConstructionBAL() cBAL := makeTestConstructionBAL()
listB := cBAL.ToEncodingObj() listB := cBAL.ToEncodingObj()
if err := listB.Validate(); err != nil { if err := listB.Validate(math.MaxUint64); err != nil {
t.Fatalf("Unexpected validation error: %v", err) t.Fatalf("Unexpected validation error: %v", err)
} }
} }