refactor: params extra types are zero values not nil pointers by default (#13)

* refactor: extra types `C` + `R` are never plumbed as `*C` / `*R`

* refactor: force use of `pseudo.Constructor.Zero()` instead of `NilPointer()`

* feat: `pseudo.PointerTo()`

* feat: `params.ExtraPayloadGetter[C,R].PointerFromChainConfig(...) *C` and `Rules => *R` equiv

* test: shallow copy of `ChainConfig`/`Rules` includes extras
This commit is contained in:
Arran Schlosberg 2024-09-12 07:54:08 +01:00 committed by GitHub
parent 72744cebe7
commit d31803a0ee
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 229 additions and 63 deletions

View file

@ -32,7 +32,7 @@ type Stub struct {
// Register is a convenience wrapper for registering s as both the
// [params.ChainConfigHooks] and [params.RulesHooks] via [Register].
func (s *Stub) Register(tb testing.TB) {
Register(tb, params.Extras[Stub, Stub]{
Register(tb, params.Extras[*Stub, *Stub]{
NewRules: func(_ *params.ChainConfig, _ *params.Rules, _ *Stub, blockNum *big.Int, isMerge bool, timestamp uint64) *Stub {
return s
},

View file

@ -60,6 +60,27 @@ func Zero[T any]() *Pseudo[T] {
return From[T](x)
}
// PointerTo is equivalent to [From] called with a pointer to the payload
// carried by `t`. It first confirms that the payload is of type `T`.
func PointerTo[T any](t *Type) (*Pseudo[*T], error) {
c, ok := t.val.(*concrete[T])
if !ok {
var want *T
return nil, fmt.Errorf("cannot create *Pseudo[%T] from *Type carrying %T", want, t.val.get())
}
return From(&c.val), nil
}
// MustPointerTo is equivalent to [PointerTo] except that it panics instead of
// returning an error.
func MustPointerTo[T any](t *Type) *Pseudo[*T] {
p, err := PointerTo[T](t)
if err != nil {
panic(err)
}
return p
}
// Interface returns the wrapped value as an `any`, equivalent to
// [reflect.Value.Interface]. Prefer [Value.Get].
func (t *Type) Interface() any { return t.val.get() }

View file

@ -77,3 +77,26 @@ func ExamplePseudo_TypeAndValue() {
_ = typ
_ = val
}
func TestPointer(t *testing.T) {
type carrier struct {
payload int
}
typ, val := From(carrier{42}).TypeAndValue()
t.Run("invalid type", func(t *testing.T) {
_, err := PointerTo[int](typ)
require.Errorf(t, err, "PointerTo[int](%T)", carrier{})
})
t.Run("valid type", func(t *testing.T) {
ptrVal := MustPointerTo[carrier](typ).Value
assert.Equal(t, 42, val.Get().payload, "before setting via pointer")
var ptr *carrier = ptrVal.Get()
ptr.payload = 314159
assert.Equal(t, 314159, val.Get().payload, "after setting via pointer")
})
}

View file

@ -25,25 +25,27 @@ type Extras[C ChainConfigHooks, R RulesHooks] struct {
// NewRules, if non-nil is called at the end of [ChainConfig.Rules] with the
// newly created [Rules] and other context from the method call. Its
// returned value will be the extra payload of the [Rules]. If NewRules is
// nil then so too will the [Rules] extra payload be a nil `*R`.
// nil then so too will the [Rules] extra payload be a zero-value `R`.
//
// NewRules MAY modify the [Rules] but MUST NOT modify the [ChainConfig].
NewRules func(_ *ChainConfig, _ *Rules, _ *C, blockNum *big.Int, isMerge bool, timestamp uint64) *R
// TODO(arr4n): add the [Rules] to the return signature to make it clearer
// that the caller can modify the generated Rules.
NewRules func(_ *ChainConfig, _ *Rules, _ C, blockNum *big.Int, isMerge bool, timestamp uint64) R
}
// RegisterExtras registers the types `C` and `R` such that they are carried as
// extra payloads in [ChainConfig] and [Rules] structs, respectively. It is
// expected to be called in an `init()` function and MUST NOT be called more
// than once. Both `C` and `R` MUST be structs.
// than once. Both `C` and `R` MUST be structs or pointers to structs.
//
// After registration, JSON unmarshalling of a [ChainConfig] will create a new
// `*C` and unmarshal the JSON key "extra" into it. Conversely, JSON marshalling
// will populate the "extra" key with the contents of the `*C`. Both the
// `C` and unmarshal the JSON key "extra" into it. Conversely, JSON marshalling
// will populate the "extra" key with the contents of the `C`. Both the
// [json.Marshaler] and [json.Unmarshaler] interfaces are honoured if
// implemented by `C` and/or `R.`
//
// Calls to [ChainConfig.Rules] will call the `NewRules` function of the
// registered [Extras] to create a new `*R`.
// registered [Extras] to create a new `R`.
//
// The payloads can be accessed via the [ExtraPayloadGetter.FromChainConfig] and
// [ExtraPayloadGetter.FromRules] methods of the getter returned by
@ -54,16 +56,16 @@ func RegisterExtras[C ChainConfigHooks, R RulesHooks](e Extras[C, R]) ExtraPaylo
if registeredExtras != nil {
panic("re-registration of Extras")
}
mustBeStruct[C]()
mustBeStruct[R]()
mustBeStructOrPointerToOne[C]()
mustBeStructOrPointerToOne[R]()
getter := e.getter()
registeredExtras = &extraConstructors{
chainConfig: pseudo.NewConstructor[C](),
rules: pseudo.NewConstructor[R](),
reuseJSONRoot: e.ReuseJSONRoot,
newForRules: e.newForRules,
getter: getter,
newChainConfig: pseudo.NewConstructor[C]().Zero,
newRules: pseudo.NewConstructor[R]().Zero,
reuseJSONRoot: e.ReuseJSONRoot,
newForRules: e.newForRules,
getter: getter,
}
return getter
}
@ -95,9 +97,9 @@ func TestOnlyClearRegisteredExtras() {
var registeredExtras *extraConstructors
type extraConstructors struct {
chainConfig, rules pseudo.Constructor
reuseJSONRoot bool
newForRules func(_ *ChainConfig, _ *Rules, blockNum *big.Int, isMerge bool, timestamp uint64) *pseudo.Type
newChainConfig, newRules func() *pseudo.Type
reuseJSONRoot bool
newForRules func(_ *ChainConfig, _ *Rules, blockNum *big.Int, isMerge bool, timestamp uint64) *pseudo.Type
// use top-level hooksFrom<X>() functions instead of these as they handle
// instances where no [Extras] were registered.
getter interface {
@ -108,7 +110,7 @@ type extraConstructors struct {
func (e *Extras[C, R]) newForRules(c *ChainConfig, r *Rules, blockNum *big.Int, isMerge bool, timestamp uint64) *pseudo.Type {
if e.NewRules == nil {
return registeredExtras.rules.NilPointer()
return registeredExtras.newRules()
}
rExtra := e.NewRules(c, r, e.getter().FromChainConfig(c), blockNum, isMerge, timestamp)
return pseudo.From(rExtra).Type
@ -116,19 +118,26 @@ func (e *Extras[C, R]) newForRules(c *ChainConfig, r *Rules, blockNum *big.Int,
func (*Extras[C, R]) getter() (g ExtraPayloadGetter[C, R]) { return }
// mustBeStruct panics if `T` isn't a struct.
func mustBeStruct[T any]() {
// mustBeStructOrPointerToOne panics if `T` isn't a struct or a *struct.
func mustBeStructOrPointerToOne[T any]() {
var x T
if k := reflect.TypeOf(x).Kind(); k != reflect.Struct {
panic(notStructMessage[T]())
switch t := reflect.TypeOf(x); t.Kind() {
case reflect.Struct:
return
case reflect.Pointer:
if t.Elem().Kind() == reflect.Struct {
return
}
}
panic(notStructMessage[T]())
}
// notStructMessage returns the message with which [mustBeStruct] might panic.
// It exists to avoid change-detector tests should the message contents change.
// notStructMessage returns the message with which [mustBeStructOrPointerToOne]
// might panic. It exists to avoid change-detector tests should the message
// contents change.
func notStructMessage[T any]() string {
var x T
return fmt.Sprintf("%T is not a struct", x)
return fmt.Sprintf("%T is not a struct nor a pointer to a struct", x)
}
// An ExtraPayloadGettter provides strongly typed access to the extra payloads
@ -139,33 +148,37 @@ type ExtraPayloadGetter[C ChainConfigHooks, R RulesHooks] struct {
}
// FromChainConfig returns the ChainConfig's extra payload.
func (ExtraPayloadGetter[C, R]) FromChainConfig(c *ChainConfig) *C {
return pseudo.MustNewValue[*C](c.extraPayload()).Get()
func (ExtraPayloadGetter[C, R]) FromChainConfig(c *ChainConfig) C {
return pseudo.MustNewValue[C](c.extraPayload()).Get()
}
// PointerFromChainConfig returns a pointer to the ChainConfig's extra payload.
// This is guaranteed to be non-nil.
func (ExtraPayloadGetter[C, R]) PointerFromChainConfig(c *ChainConfig) *C {
return pseudo.MustPointerTo[C](c.extraPayload()).Value.Get()
}
// hooksFromChainConfig is equivalent to FromChainConfig(), but returns an
// interface instead of the concrete type implementing it; this allows it to be
// used in non-generic code. If the concrete-type value is nil (typically
// because no [Extras] were registered) a [noopHooks] is returned so it can be
// used without nil checks.
// used in non-generic code.
func (e ExtraPayloadGetter[C, R]) hooksFromChainConfig(c *ChainConfig) ChainConfigHooks {
if h := e.FromChainConfig(c); h != nil {
return *h
}
return NOOPHooks{}
return e.FromChainConfig(c)
}
// FromRules returns the Rules' extra payload.
func (ExtraPayloadGetter[C, R]) FromRules(r *Rules) *R {
return pseudo.MustNewValue[*R](r.extraPayload()).Get()
func (ExtraPayloadGetter[C, R]) FromRules(r *Rules) R {
return pseudo.MustNewValue[R](r.extraPayload()).Get()
}
// PointerFromRules returns a pointer to the Rules's extra payload. This is
// guaranteed to be non-nil.
func (ExtraPayloadGetter[C, R]) PointerFromRules(r *Rules) *R {
return pseudo.MustPointerTo[R](r.extraPayload()).Value.Get()
}
// hooksFromRules is the [RulesHooks] equivalent of hooksFromChainConfig().
func (e ExtraPayloadGetter[C, R]) hooksFromRules(r *Rules) RulesHooks {
if h := e.FromRules(r); h != nil {
return *h
}
return NOOPHooks{}
return e.FromRules(r)
}
// addRulesExtra is called at the end of [ChainConfig.Rules]; it exists to
@ -189,7 +202,7 @@ func (c *ChainConfig) extraPayload() *pseudo.Type {
panic(fmt.Sprintf("%T.ExtraPayload() called before RegisterExtras()", c))
}
if c.extra == nil {
c.extra = registeredExtras.chainConfig.NilPointer()
c.extra = registeredExtras.newChainConfig()
}
return c.extra
}
@ -201,7 +214,7 @@ func (r *Rules) extraPayload() *pseudo.Type {
panic(fmt.Sprintf("%T.ExtraPayload() called before RegisterExtras()", r))
}
if r.extra == nil {
r.extra = registeredExtras.rules.NilPointer()
r.extra = registeredExtras.newRules()
}
return r.extra
}

View file

@ -50,26 +50,36 @@ func TestRegisterExtras(t *testing.T) {
name: "Rules payload copied from ChainConfig payload",
register: func() {
RegisterExtras(Extras[ccExtraA, rulesExtraA]{
NewRules: func(cc *ChainConfig, r *Rules, ex *ccExtraA, _ *big.Int, _ bool, _ uint64) *rulesExtraA {
return &rulesExtraA{
NewRules: func(cc *ChainConfig, r *Rules, ex ccExtraA, _ *big.Int, _ bool, _ uint64) rulesExtraA {
return rulesExtraA{
A: ex.A,
}
},
})
},
ccExtra: pseudo.From(&ccExtraA{
ccExtra: pseudo.From(ccExtraA{
A: "hello",
}).Type,
wantRulesExtra: &rulesExtraA{
wantRulesExtra: rulesExtraA{
A: "hello",
},
},
{
name: "no NewForRules() function results in typed but nil pointer",
name: "no NewForRules() function results in zero value",
register: func() {
RegisterExtras(Extras[ccExtraB, rulesExtraB]{})
},
ccExtra: pseudo.From(&ccExtraB{
ccExtra: pseudo.From(ccExtraB{
B: "world",
}).Type,
wantRulesExtra: rulesExtraB{},
},
{
name: "no NewForRules() function results in nil pointer",
register: func() {
RegisterExtras(Extras[ccExtraB, *rulesExtraB]{})
},
ccExtra: pseudo.From(ccExtraB{
B: "world",
}).Type,
wantRulesExtra: (*rulesExtraB)(nil),
@ -79,10 +89,10 @@ func TestRegisterExtras(t *testing.T) {
register: func() {
RegisterExtras(Extras[rawJSON, struct{ RulesHooks }]{})
},
ccExtra: pseudo.From(&rawJSON{
ccExtra: pseudo.From(rawJSON{
RawMessage: []byte(`"hello, world"`),
}).Type,
wantRulesExtra: (*struct{ RulesHooks })(nil),
wantRulesExtra: struct{ RulesHooks }{},
},
}
@ -111,6 +121,75 @@ func TestRegisterExtras(t *testing.T) {
}
}
func TestModificationOfZeroExtras(t *testing.T) {
type (
ccExtra struct {
X int
NOOPHooks
}
rulesExtra struct {
X int
NOOPHooks
}
)
TestOnlyClearRegisteredExtras()
t.Cleanup(TestOnlyClearRegisteredExtras)
getter := RegisterExtras(Extras[ccExtra, rulesExtra]{})
config := new(ChainConfig)
rules := new(Rules)
// These assertion helpers are defined before any modifications so that the
// closure is demonstrably over the original zero values.
assertChainConfigExtra := func(t *testing.T, want ccExtra, msg string) {
t.Helper()
assert.Equalf(t, want, getter.FromChainConfig(config), "%T: "+msg, &config)
}
assertRulesExtra := func(t *testing.T, want rulesExtra, msg string) {
t.Helper()
assert.Equalf(t, want, getter.FromRules(rules), "%T: "+msg, &rules)
}
assertChainConfigExtra(t, ccExtra{}, "zero value")
assertRulesExtra(t, rulesExtra{}, "zero value")
const answer = 42
getter.PointerFromChainConfig(config).X = answer
assertChainConfigExtra(t, ccExtra{X: answer}, "after setting via pointer field")
const pi = 314159
getter.PointerFromRules(rules).X = pi
assertRulesExtra(t, rulesExtra{X: pi}, "after setting via pointer field")
ccReplace := ccExtra{X: 142857}
*getter.PointerFromChainConfig(config) = ccReplace
assertChainConfigExtra(t, ccReplace, "after replacement of entire extra via `*pointer = x`")
rulesReplace := rulesExtra{X: 18101986}
*getter.PointerFromRules(rules) = rulesReplace
assertRulesExtra(t, rulesReplace, "after replacement of entire extra via `*pointer = x`")
if t.Failed() {
// The test of shallow copying is now guaranteed to fail.
return
}
t.Run("shallow copy", func(t *testing.T) {
ccCopy := *config
rCopy := *rules
assert.Equal(t, getter.FromChainConfig(&ccCopy), ccReplace, "ChainConfig extras copied")
assert.Equal(t, getter.FromRules(&rCopy), rulesReplace, "Rules extras copied")
const seqUp = 123456789
getter.PointerFromChainConfig(&ccCopy).X = seqUp
assertChainConfigExtra(t, ccExtra{X: seqUp}, "original changed because copy only shallow")
const seqDown = 987654321
getter.PointerFromRules(&rCopy).X = seqDown
assertRulesExtra(t, rulesExtra{X: seqDown}, "original changed because copy only shallow")
})
}
func TestExtrasPanic(t *testing.T) {
TestOnlyClearRegisteredExtras()
defer TestOnlyClearRegisteredExtras()
@ -131,7 +210,7 @@ func TestExtrasPanic(t *testing.T) {
assertPanics(
t, func() {
mustBeStruct[int]()
mustBeStructOrPointerToOne[int]()
},
notStructMessage[int](),
)

View file

@ -40,8 +40,8 @@ var getter params.ExtraPayloadGetter[ChainConfigExtra, RulesExtra]
// constructRulesExtra acts as an adjunct to the [params.ChainConfig.Rules]
// method. Its primary purpose is to construct the extra payload for the
// [params.Rules] but it MAY also modify the [params.Rules].
func constructRulesExtra(c *params.ChainConfig, r *params.Rules, cEx *ChainConfigExtra, blockNum *big.Int, isMerge bool, timestamp uint64) *RulesExtra {
return &RulesExtra{
func constructRulesExtra(c *params.ChainConfig, r *params.Rules, cEx ChainConfigExtra, blockNum *big.Int, isMerge bool, timestamp uint64) RulesExtra {
return RulesExtra{
IsMyFork: cEx.MyForkTime != nil && *cEx.MyForkTime <= timestamp,
timestamp: timestamp,
}
@ -66,12 +66,12 @@ type RulesExtra struct {
}
// FromChainConfig returns the extra payload carried by the ChainConfig.
func FromChainConfig(c *params.ChainConfig) *ChainConfigExtra {
func FromChainConfig(c *params.ChainConfig) ChainConfigExtra {
return getter.FromChainConfig(c)
}
// FromRules returns the extra payload carried by the Rules.
func FromRules(r *params.Rules) *RulesExtra {
func FromRules(r *params.Rules) RulesExtra {
return getter.FromRules(r)
}
@ -137,16 +137,14 @@ func ExampleExtraPayloadGetter() {
fmt.Println("Chain ID", config.ChainID) // original geth fields work as expected
ccExtra := FromChainConfig(config) // extraparams.FromChainConfig() in practice
if ccExtra != nil && ccExtra.MyForkTime != nil {
if ccExtra.MyForkTime != nil {
fmt.Println("Fork time", *ccExtra.MyForkTime)
}
for _, time := range []uint64{forkTime - 1, forkTime, forkTime + 1} {
rules := config.Rules(nil, false, time)
rExtra := FromRules(&rules) // extraparams.FromRules() in practice
if rExtra != nil {
fmt.Printf("IsMyFork at %v: %t\n", rExtra.timestamp, rExtra.IsMyFork)
}
fmt.Printf("IsMyFork at %v: %t\n", rExtra.timestamp, rExtra.IsMyFork)
}
// Output:

View file

@ -30,7 +30,7 @@ func (c *ChainConfig) UnmarshalJSON(data []byte) error {
return c.unmarshalJSONWithExtra(data)
case reg != nil && reg.reuseJSONRoot: // although the latter is redundant, it's clearer
c.extra = reg.chainConfig.NilPointer()
c.extra = reg.newChainConfig()
if err := json.Unmarshal(data, c.extra); err != nil {
c.extra = nil
return err
@ -47,7 +47,7 @@ func (c *ChainConfig) UnmarshalJSON(data []byte) error {
func (c *ChainConfig) unmarshalJSONWithExtra(data []byte) error {
cc := &chainConfigWithExportedExtra{
chainConfigWithoutMethods: (*chainConfigWithoutMethods)(c),
Extra: registeredExtras.chainConfig.NilPointer(),
Extra: registeredExtras.newChainConfig(),
}
if err := json.Unmarshal(data, cc); err != nil {
return err

View file

@ -40,7 +40,7 @@ func TestChainConfigJSONRoundTrip(t *testing.T) {
},
},
{
name: "reuse top-level JSON",
name: "reuse top-level JSON with non-pointer",
register: func() {
RegisterExtras(Extras[rootJSONChainConfigExtra, NOOPHooks]{
ReuseJSONRoot: true,
@ -50,13 +50,29 @@ func TestChainConfigJSONRoundTrip(t *testing.T) {
"chainId": 5678,
"foo": "hello"
}`,
want: &ChainConfig{
ChainID: big.NewInt(5678),
extra: pseudo.From(rootJSONChainConfigExtra{TopLevelFoo: "hello"}).Type,
},
},
{
name: "reuse top-level JSON with pointer",
register: func() {
RegisterExtras(Extras[*rootJSONChainConfigExtra, NOOPHooks]{
ReuseJSONRoot: true,
})
},
jsonInput: `{
"chainId": 5678,
"foo": "hello"
}`,
want: &ChainConfig{
ChainID: big.NewInt(5678),
extra: pseudo.From(&rootJSONChainConfigExtra{TopLevelFoo: "hello"}).Type,
},
},
{
name: "nested JSON",
name: "nested JSON with non-pointer",
register: func() {
RegisterExtras(Extras[nestedChainConfigExtra, NOOPHooks]{
ReuseJSONRoot: false, // explicit zero value only for tests
@ -66,6 +82,22 @@ func TestChainConfigJSONRoundTrip(t *testing.T) {
"chainId": 42,
"extra": {"foo": "world"}
}`,
want: &ChainConfig{
ChainID: big.NewInt(42),
extra: pseudo.From(nestedChainConfigExtra{NestedFoo: "world"}).Type,
},
},
{
name: "nested JSON with pointer",
register: func() {
RegisterExtras(Extras[*nestedChainConfigExtra, NOOPHooks]{
ReuseJSONRoot: false, // explicit zero value only for tests
})
},
jsonInput: `{
"chainId": 42,
"extra": {"foo": "world"}
}`,
want: &ChainConfig{
ChainID: big.NewInt(42),
extra: pseudo.From(&nestedChainConfigExtra{NestedFoo: "world"}).Type,