mirror of
https://github.com/ethereum/go-ethereum.git
synced 2026-06-21 14:14:30 +00:00
refactor: consolidate once-only registration of extras (#85)
## Why this should be merged Consolidates duplicated logic. Similar rationale to #84. ## How this works New `register.AtMostOnce[T]` type is responsible for limiting calls to `Register()`. ## How this was tested Existing unit tests of `params`. Note that the equivalent functionality in `types` wasn't tested but now is.
This commit is contained in:
parent
25e5ca3eb1
commit
d71677f141
6 changed files with 111 additions and 48 deletions
|
|
@ -21,6 +21,7 @@ import (
|
|||
"io"
|
||||
|
||||
"github.com/ava-labs/libevm/libevm/pseudo"
|
||||
"github.com/ava-labs/libevm/libevm/register"
|
||||
"github.com/ava-labs/libevm/libevm/testonly"
|
||||
"github.com/ava-labs/libevm/rlp"
|
||||
)
|
||||
|
|
@ -37,18 +38,15 @@ import (
|
|||
// The payload can be accessed via the [ExtraPayloads.FromPayloadCarrier] method
|
||||
// of the accessor returned by RegisterExtras.
|
||||
func RegisterExtras[SA any]() ExtraPayloads[SA] {
|
||||
if registeredExtras != nil {
|
||||
panic("re-registration of Extras")
|
||||
}
|
||||
var extra ExtraPayloads[SA]
|
||||
registeredExtras = &extraConstructors{
|
||||
registeredExtras.MustRegister(&extraConstructors{
|
||||
stateAccountType: func() string {
|
||||
var x SA
|
||||
return fmt.Sprintf("%T", x)
|
||||
}(),
|
||||
newStateAccount: pseudo.NewConstructor[SA]().Zero,
|
||||
cloneStateAccount: extra.cloneStateAccount,
|
||||
}
|
||||
})
|
||||
return extra
|
||||
}
|
||||
|
||||
|
|
@ -59,12 +57,10 @@ func RegisterExtras[SA any]() ExtraPayloads[SA] {
|
|||
// defer-called afterwards, either directly or via testing.TB.Cleanup(). This is
|
||||
// a workaround for the single-call limitation on [RegisterExtras].
|
||||
func TestOnlyClearRegisteredExtras() {
|
||||
testonly.OrPanic(func() {
|
||||
registeredExtras = nil
|
||||
})
|
||||
registeredExtras.TestOnlyClear()
|
||||
}
|
||||
|
||||
var registeredExtras *extraConstructors
|
||||
var registeredExtras register.AtMostOnce[*extraConstructors]
|
||||
|
||||
type extraConstructors struct {
|
||||
stateAccountType string
|
||||
|
|
@ -74,10 +70,10 @@ type extraConstructors struct {
|
|||
|
||||
func (e *StateAccountExtra) clone() *StateAccountExtra {
|
||||
switch r := registeredExtras; {
|
||||
case r == nil, e == nil:
|
||||
case !r.Registered(), e == nil:
|
||||
return nil
|
||||
default:
|
||||
return r.cloneStateAccount(e)
|
||||
return r.Get().cloneStateAccount(e)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -146,7 +142,7 @@ func (a *SlimAccount) extra() *StateAccountExtra {
|
|||
func getOrSetNewStateAccountExtra(curr **StateAccountExtra) *StateAccountExtra {
|
||||
if *curr == nil {
|
||||
*curr = &StateAccountExtra{
|
||||
t: registeredExtras.newStateAccount(),
|
||||
t: registeredExtras.Get().newStateAccount(),
|
||||
}
|
||||
}
|
||||
return *curr
|
||||
|
|
@ -154,7 +150,7 @@ func getOrSetNewStateAccountExtra(curr **StateAccountExtra) *StateAccountExtra {
|
|||
|
||||
func (e *StateAccountExtra) payload() *pseudo.Type {
|
||||
if e.t == nil {
|
||||
e.t = registeredExtras.newStateAccount()
|
||||
e.t = registeredExtras.Get().newStateAccount()
|
||||
}
|
||||
return e.t
|
||||
}
|
||||
|
|
@ -196,13 +192,13 @@ var _ interface {
|
|||
// EncodeRLP implements the [rlp.Encoder] interface.
|
||||
func (e *StateAccountExtra) EncodeRLP(w io.Writer) error {
|
||||
switch r := registeredExtras; {
|
||||
case r == nil:
|
||||
case !r.Registered():
|
||||
return nil
|
||||
case e == nil:
|
||||
e = &StateAccountExtra{}
|
||||
fallthrough
|
||||
case e.t == nil:
|
||||
e.t = r.newStateAccount()
|
||||
e.t = r.Get().newStateAccount()
|
||||
}
|
||||
return e.t.EncodeRLP(w)
|
||||
}
|
||||
|
|
@ -210,10 +206,10 @@ func (e *StateAccountExtra) EncodeRLP(w io.Writer) error {
|
|||
// DecodeRLP implements the [rlp.Decoder] interface.
|
||||
func (e *StateAccountExtra) DecodeRLP(s *rlp.Stream) error {
|
||||
switch r := registeredExtras; {
|
||||
case r == nil:
|
||||
case !r.Registered():
|
||||
return nil
|
||||
case e.t == nil:
|
||||
e.t = r.newStateAccount()
|
||||
e.t = r.Get().newStateAccount()
|
||||
fallthrough
|
||||
default:
|
||||
return s.Decode(e.t)
|
||||
|
|
@ -224,10 +220,10 @@ func (e *StateAccountExtra) DecodeRLP(s *rlp.Stream) error {
|
|||
func (e *StateAccountExtra) Format(s fmt.State, verb rune) {
|
||||
var out string
|
||||
switch r := registeredExtras; {
|
||||
case r == nil:
|
||||
case !r.Registered():
|
||||
out = "<nil>"
|
||||
case e == nil, e.t == nil:
|
||||
out = fmt.Sprintf("<nil>[*StateAccountExtra[%s]]", r.stateAccountType)
|
||||
out = fmt.Sprintf("<nil>[*StateAccountExtra[%s]]", r.Get().stateAccountType)
|
||||
default:
|
||||
e.t.Format(s, verb)
|
||||
return
|
||||
|
|
|
|||
68
libevm/register/register.go
Normal file
68
libevm/register/register.go
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
// Copyright 2024 the libevm authors.
|
||||
//
|
||||
// The libevm additions to go-ethereum are free software: you can redistribute
|
||||
// them and/or modify them under the terms of the GNU Lesser General Public License
|
||||
// as published by the Free Software Foundation, either version 3 of the License,
|
||||
// or (at your option) any later version.
|
||||
//
|
||||
// The libevm additions are distributed in the hope that they will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser
|
||||
// General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Lesser General Public License
|
||||
// along with the go-ethereum library. If not, see
|
||||
// <http://www.gnu.org/licenses/>.
|
||||
|
||||
// Package register provides functionality for optional registration of types.
|
||||
package register
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/ava-labs/libevm/libevm/testonly"
|
||||
)
|
||||
|
||||
// An AtMostOnce allows zero or one registration of a T.
|
||||
type AtMostOnce[T any] struct {
|
||||
v *T
|
||||
}
|
||||
|
||||
// ErrReRegistration is returned on all but the first of calls to
|
||||
// [AtMostOnce.Register].
|
||||
var ErrReRegistration = errors.New("re-registration")
|
||||
|
||||
// Register registers `v` or returns [ErrReRegistration] if already called.
|
||||
func (o *AtMostOnce[T]) Register(v T) error {
|
||||
if o.Registered() {
|
||||
return ErrReRegistration
|
||||
}
|
||||
o.v = &v
|
||||
return nil
|
||||
}
|
||||
|
||||
// MustRegister is equivalent to [AtMostOnce.Register], panicking on error.
|
||||
func (o *AtMostOnce[T]) MustRegister(v T) {
|
||||
if err := o.Register(v); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Registered reports whether [AtMostOnce.Register] has been called.
|
||||
func (o *AtMostOnce[T]) Registered() bool {
|
||||
return o.v != nil
|
||||
}
|
||||
|
||||
// Get returns the registered value. It MUST NOT be called before
|
||||
// [AtMostOnce.Register].
|
||||
func (o *AtMostOnce[T]) Get() T {
|
||||
return *o.v
|
||||
}
|
||||
|
||||
// TestOnlyClear clears any previously registered value, returning `o` to its
|
||||
// default state. It panics if called from a non-testing call stack.
|
||||
func (o *AtMostOnce[T]) TestOnlyClear() {
|
||||
testonly.OrPanic(func() {
|
||||
o.v = nil
|
||||
})
|
||||
}
|
||||
|
|
@ -22,7 +22,7 @@ import (
|
|||
"reflect"
|
||||
|
||||
"github.com/ava-labs/libevm/libevm/pseudo"
|
||||
"github.com/ava-labs/libevm/libevm/testonly"
|
||||
"github.com/ava-labs/libevm/libevm/register"
|
||||
)
|
||||
|
||||
// Extras are arbitrary payloads to be added as extra fields in [ChainConfig]
|
||||
|
|
@ -68,20 +68,17 @@ type Extras[C ChainConfigHooks, R RulesHooks] struct {
|
|||
// alter Ethereum behaviour; if this isn't desired then they can embed
|
||||
// [NOOPHooks] to satisfy either interface.
|
||||
func RegisterExtras[C ChainConfigHooks, R RulesHooks](e Extras[C, R]) ExtraPayloads[C, R] {
|
||||
if registeredExtras != nil {
|
||||
panic("re-registration of Extras")
|
||||
}
|
||||
mustBeStructOrPointerToOne[C]()
|
||||
mustBeStructOrPointerToOne[R]()
|
||||
|
||||
payloads := e.payloads()
|
||||
registeredExtras = &extraConstructors{
|
||||
registeredExtras.MustRegister(&extraConstructors{
|
||||
newChainConfig: pseudo.NewConstructor[C]().Zero,
|
||||
newRules: pseudo.NewConstructor[R]().Zero,
|
||||
reuseJSONRoot: e.ReuseJSONRoot,
|
||||
newForRules: e.newForRules,
|
||||
payloads: payloads,
|
||||
}
|
||||
})
|
||||
return payloads
|
||||
}
|
||||
|
||||
|
|
@ -92,14 +89,12 @@ func RegisterExtras[C ChainConfigHooks, R RulesHooks](e Extras[C, R]) ExtraPaylo
|
|||
// defer-called afterwards, either directly or via testing.TB.Cleanup(). This is
|
||||
// a workaround for the single-call limitation on [RegisterExtras].
|
||||
func TestOnlyClearRegisteredExtras() {
|
||||
testonly.OrPanic(func() {
|
||||
registeredExtras = nil
|
||||
})
|
||||
registeredExtras.TestOnlyClear()
|
||||
}
|
||||
|
||||
// registeredExtras holds non-generic constructors for the [Extras] types
|
||||
// registered via [RegisterExtras].
|
||||
var registeredExtras *extraConstructors
|
||||
var registeredExtras register.AtMostOnce[*extraConstructors]
|
||||
|
||||
type extraConstructors struct {
|
||||
newChainConfig, newRules func() *pseudo.Type
|
||||
|
|
@ -115,7 +110,7 @@ type extraConstructors struct {
|
|||
|
||||
func (e *Extras[C, R]) newForRules(c *ChainConfig, r *Rules, blockNum *big.Int, isMerge bool, timestamp uint64) *pseudo.Type {
|
||||
if e.NewRules == nil {
|
||||
return registeredExtras.newRules()
|
||||
return registeredExtras.Get().newRules()
|
||||
}
|
||||
rExtra := e.NewRules(c, r, e.payloads().FromChainConfig(c), blockNum, isMerge, timestamp)
|
||||
return pseudo.From(rExtra).Type
|
||||
|
|
@ -209,8 +204,8 @@ func (e ExtraPayloads[C, R]) hooksFromRules(r *Rules) RulesHooks {
|
|||
// abstract the libevm-specific behaviour outside of original geth code.
|
||||
func (c *ChainConfig) addRulesExtra(r *Rules, blockNum *big.Int, isMerge bool, timestamp uint64) {
|
||||
r.extra = nil
|
||||
if registeredExtras != nil {
|
||||
r.extra = registeredExtras.newForRules(c, r, blockNum, isMerge, timestamp)
|
||||
if registeredExtras.Registered() {
|
||||
r.extra = registeredExtras.Get().newForRules(c, r, blockNum, isMerge, timestamp)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -218,7 +213,7 @@ func (c *ChainConfig) addRulesExtra(r *Rules, blockNum *big.Int, isMerge bool, t
|
|||
// already been called. If the payload hasn't been populated (typically via
|
||||
// unmarshalling of JSON), a nil value is constructed and returned.
|
||||
func (c *ChainConfig) extraPayload() *pseudo.Type {
|
||||
if registeredExtras == nil {
|
||||
if !registeredExtras.Registered() {
|
||||
// This will only happen if someone constructs an [ExtraPayloads]
|
||||
// directly, without a call to [RegisterExtras].
|
||||
//
|
||||
|
|
@ -226,19 +221,19 @@ func (c *ChainConfig) extraPayload() *pseudo.Type {
|
|||
panic(fmt.Sprintf("%T.ExtraPayload() called before RegisterExtras()", c))
|
||||
}
|
||||
if c.extra == nil {
|
||||
c.extra = registeredExtras.newChainConfig()
|
||||
c.extra = registeredExtras.Get().newChainConfig()
|
||||
}
|
||||
return c.extra
|
||||
}
|
||||
|
||||
// extraPayload is equivalent to [ChainConfig.extraPayload].
|
||||
func (r *Rules) extraPayload() *pseudo.Type {
|
||||
if registeredExtras == nil {
|
||||
if !registeredExtras.Registered() {
|
||||
// See ChainConfig.extraPayload() equivalent.
|
||||
panic(fmt.Sprintf("%T.ExtraPayload() called before RegisterExtras()", r))
|
||||
}
|
||||
if r.extra == nil {
|
||||
r.extra = registeredExtras.newRules()
|
||||
r.extra = registeredExtras.Get().newRules()
|
||||
}
|
||||
return r.extra
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/ava-labs/libevm/libevm/pseudo"
|
||||
"github.com/ava-labs/libevm/libevm/register"
|
||||
)
|
||||
|
||||
type rawJSON struct {
|
||||
|
|
@ -255,18 +256,21 @@ func TestExtrasPanic(t *testing.T) {
|
|||
t, func() {
|
||||
RegisterExtras(Extras[struct{ ChainConfigHooks }, struct{ RulesHooks }]{})
|
||||
},
|
||||
"re-registration",
|
||||
register.ErrReRegistration.Error(),
|
||||
)
|
||||
}
|
||||
|
||||
func assertPanics(t *testing.T, fn func(), wantContains string) {
|
||||
t.Helper()
|
||||
defer func() {
|
||||
t.Helper()
|
||||
switch r := recover().(type) {
|
||||
case nil:
|
||||
t.Error("function did not panic as expected")
|
||||
t.Error("function did not panic when panic expected")
|
||||
case string:
|
||||
assert.Contains(t, r, wantContains)
|
||||
case error:
|
||||
assert.Contains(t, r.Error(), wantContains)
|
||||
default:
|
||||
t.Fatalf("BAD TEST SETUP: recover() got unsupported type %T", r)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -69,8 +69,8 @@ type RulesAllowlistHooks interface {
|
|||
// Hooks returns the hooks registered with [RegisterExtras], or [NOOPHooks] if
|
||||
// none were registered.
|
||||
func (c *ChainConfig) Hooks() ChainConfigHooks {
|
||||
if e := registeredExtras; e != nil {
|
||||
return e.payloads.hooksFromChainConfig(c)
|
||||
if e := registeredExtras; e.Registered() {
|
||||
return e.Get().payloads.hooksFromChainConfig(c)
|
||||
}
|
||||
return NOOPHooks{}
|
||||
}
|
||||
|
|
@ -78,8 +78,8 @@ func (c *ChainConfig) Hooks() ChainConfigHooks {
|
|||
// Hooks returns the hooks registered with [RegisterExtras], or [NOOPHooks] if
|
||||
// none were registered.
|
||||
func (r *Rules) Hooks() RulesHooks {
|
||||
if e := registeredExtras; e != nil {
|
||||
return e.payloads.hooksFromRules(r)
|
||||
if e := registeredExtras; e.Registered() {
|
||||
return e.Get().payloads.hooksFromRules(r)
|
||||
}
|
||||
return NOOPHooks{}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -42,11 +42,11 @@ type chainConfigWithExportedExtra struct {
|
|||
// UnmarshalJSON implements the [json.Unmarshaler] interface.
|
||||
func (c *ChainConfig) UnmarshalJSON(data []byte) error {
|
||||
switch reg := registeredExtras; {
|
||||
case reg != nil && !reg.reuseJSONRoot:
|
||||
case reg.Registered() && !reg.Get().reuseJSONRoot:
|
||||
return c.unmarshalJSONWithExtra(data)
|
||||
|
||||
case reg != nil && reg.reuseJSONRoot: // although the latter is redundant, it's clearer
|
||||
c.extra = reg.newChainConfig()
|
||||
case reg.Registered() && reg.Get().reuseJSONRoot: // although the latter is redundant, it's clearer
|
||||
c.extra = reg.Get().newChainConfig()
|
||||
if err := json.Unmarshal(data, c.extra); err != nil {
|
||||
c.extra = nil
|
||||
return err
|
||||
|
|
@ -63,7 +63,7 @@ func (c *ChainConfig) UnmarshalJSON(data []byte) error {
|
|||
func (c *ChainConfig) unmarshalJSONWithExtra(data []byte) error {
|
||||
cc := &chainConfigWithExportedExtra{
|
||||
chainConfigWithoutMethods: (*chainConfigWithoutMethods)(c),
|
||||
Extra: registeredExtras.newChainConfig(),
|
||||
Extra: registeredExtras.Get().newChainConfig(),
|
||||
}
|
||||
if err := json.Unmarshal(data, cc); err != nil {
|
||||
return err
|
||||
|
|
@ -75,10 +75,10 @@ func (c *ChainConfig) unmarshalJSONWithExtra(data []byte) error {
|
|||
// MarshalJSON implements the [json.Marshaler] interface.
|
||||
func (c *ChainConfig) MarshalJSON() ([]byte, error) {
|
||||
switch reg := registeredExtras; {
|
||||
case reg == nil:
|
||||
case !reg.Registered():
|
||||
return json.Marshal((*chainConfigWithoutMethods)(c))
|
||||
|
||||
case !reg.reuseJSONRoot:
|
||||
case !reg.Get().reuseJSONRoot:
|
||||
return c.marshalJSONWithExtra()
|
||||
|
||||
default: // reg.reuseJSONRoot == true
|
||||
|
|
|
|||
Loading…
Reference in a new issue