core/state: track the account originals

This commit is contained in:
Gary Rong 2026-04-23 12:14:46 +08:00 committed by Jared Wasinger
parent 53c3d614a9
commit 5b9a06d2f6
3 changed files with 358 additions and 54 deletions

View file

@ -31,9 +31,6 @@ type revision struct {
journalIndex int
}
// journalMutation represents a set of mutations applied to a certain account.
type journalMutation uint8
// journalMutationKind indicates the type of account mutation.
type journalMutationKind uint8
@ -47,13 +44,6 @@ const (
journalMutationKindStorage
)
func (k journalMutationKind) mask() journalMutation {
if k == 0 {
return 0
}
return journalMutation(1) << (k - 1)
}
type journalMutationCounts struct {
touch int
create int
@ -64,25 +54,60 @@ type journalMutationCounts struct {
storage int
}
// journalMutationState tracks, per account, both the per-kind count of mutation
// entries currently present in the journal and the pre-tx value of each
// metadata field captured on its first touch (balance/nonce/code).
// The *Set flags indicate whether the corresponding field has been mutated
// at least once in the current tx window; they are cleared when all entries
// of that kind are reverted. Storage slots are tracked elsewhere.
type journalMutationState struct {
mask journalMutation
counts journalMutationCounts
balance *uint256.Int
balanceSet bool
nonce uint64
nonceSet bool
code []byte
codeSet bool
}
func (s *journalMutationState) add(kind journalMutationKind) {
s.counts.add(kind)
s.mask |= kind.mask()
}
func (s *journalMutationState) remove(kind journalMutationKind) bool {
if s.counts.remove(kind) {
s.mask &^= kind.mask()
// remove drops one occurrence of the given mutation kind. It returns two
// booleans: kindEmpty is true when no entries of that kind remain for the
// account, and stateEmpty is true when no entries of any kind remain.
func (s *journalMutationState) remove(kind journalMutationKind) (kindEmpty, stateEmpty bool) {
kindEmpty = s.counts.remove(kind)
return kindEmpty, s.counts == (journalMutationCounts{})
}
// clearKind drops the stashed original for the given mutation kind. It is
// invoked during revert once no journal entries of that kind remain for the
// account. Kinds that don't correspond to a tracked metadata field are no-ops.
func (s *journalMutationState) clearKind(kind journalMutationKind) {
switch kind {
case journalMutationKindBalance:
s.balance = nil
s.balanceSet = false
case journalMutationKindNonce:
s.nonce = 0
s.nonceSet = false
case journalMutationKindCode:
s.code = nil
s.codeSet = false
}
return s.mask == 0
}
func (s journalMutationState) copy() *journalMutationState {
cpy := s
if s.balance != nil {
cpy.balance = new(uint256.Int).Set(s.balance)
}
if s.code != nil {
cpy.code = slices.Clone(s.code)
}
return &cpy
}
@ -146,12 +171,55 @@ type journalEntry interface {
copy() journalEntry
}
// stashBalance records prev as the pre-tx balance of addr, iff this is the
// first balance touch seen in the current tx. Subsequent balance writes are
// ignored so the stored value remains the true pre-tx original.
func (j *journal) stashBalance(addr common.Address, prev *uint256.Int) {
s := j.mutationStateFor(addr)
if s.balanceSet {
return
}
s.balance = prev.Clone()
s.balanceSet = true
}
// stashNonce records prev as the pre-tx nonce of addr on first touch.
func (j *journal) stashNonce(addr common.Address, prev uint64) {
s := j.mutationStateFor(addr)
if s.nonceSet {
return
}
s.nonce = prev
s.nonceSet = true
}
// stashCode records prev as the pre-tx code of addr on first touch.
func (j *journal) stashCode(addr common.Address, prev []byte) {
s := j.mutationStateFor(addr)
if s.codeSet {
return
}
s.code = slices.Clone(prev)
s.codeSet = true
}
// mutationStateFor returns the mutation state for addr, creating an empty one
// if absent.
func (j *journal) mutationStateFor(addr common.Address) *journalMutationState {
s := j.mutations[addr]
if s == nil {
s = new(journalMutationState)
j.mutations[addr] = s
}
return s
}
// journal contains the list of state modifications applied since the last state
// commit. These are tracked to be able to be reverted in the case of an execution
// exception or request for reversal.
type journal struct {
entries []journalEntry // Current changes tracked by the journal
mutations map[common.Address]*journalMutationState // Account mutation state accumulated across entries
mutations map[common.Address]*journalMutationState // Per-account mutation kinds and pre-tx originals
validRevisions []revision
nextRevisionId int
@ -224,7 +292,14 @@ func (j *journal) revert(statedb *StateDB, snapshot int) {
if state == nil {
panic(fmt.Errorf("journal mutation tracking missing for %x", addr[:]))
}
if state.remove(kind) {
kindEmpty, stateEmpty := state.remove(kind)
if kindEmpty {
// No entries of this kind remain for this account; drop the
// corresponding stashed original so the state mirrors the
// live mutation set.
state.clearKind(kind)
}
if stateEmpty {
delete(j.mutations, addr)
}
}
@ -249,24 +324,6 @@ func (j *journal) ripemdMagic() {
state.add(journalMutationKindTouch)
}
func (j *journal) mutation(addr common.Address) journalMutation {
if state := j.mutations[addr]; state != nil {
return state.mask
}
return 0
}
func (j *journal) mutationSet() map[common.Address]journalMutation {
if j.mutations == nil {
return nil
}
out := make(map[common.Address]journalMutation, len(j.mutations))
for addr, state := range j.mutations {
out[addr] = state.mask
}
return out
}
// length returns the current number of entries in the journal.
func (j *journal) length() int {
return len(j.entries)
@ -335,6 +392,7 @@ func (j *journal) refundChange(previous uint64) {
}
func (j *journal) balanceChange(addr common.Address, previous *uint256.Int) {
j.stashBalance(addr, previous)
j.append(balanceChange{
account: addr,
prev: previous.Clone(),
@ -342,6 +400,7 @@ func (j *journal) balanceChange(addr common.Address, previous *uint256.Int) {
}
func (j *journal) setCode(address common.Address, prevCode []byte) {
j.stashCode(address, prevCode)
j.append(codeChange{
account: address,
prevCode: prevCode,
@ -349,6 +408,7 @@ func (j *journal) setCode(address common.Address, prevCode []byte) {
}
func (j *journal) nonceChange(address common.Address, prev uint64) {
j.stashNonce(address, prev)
j.append(nonceChange{
account: address,
prev: prev,

219
core/state/journal_test.go Normal file
View file

@ -0,0 +1,219 @@
// Copyright 2026 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package state
import (
"testing"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/tracing"
"github.com/ethereum/go-ethereum/core/types"
"github.com/holiman/uint256"
)
// fuzzJournalAddrs is a small fixed pool used by the fuzz harness to force
// repeated collisions on the same account, which exercises the multi-entry
// path in the journal's mutation tracking and originals cleanup on revert.
// It deliberately excludes the RIPEMD-160 precompile (0x03), which has a
// consensus-level touch/revert exception that would complicate invariants.
var fuzzJournalAddrs = []common.Address{
common.BytesToAddress([]byte{0x11}),
common.BytesToAddress([]byte{0x22}),
common.BytesToAddress([]byte{0x44}),
}
// checkJournalInvariants validates that:
// - journal.mutations exactly reflects the dirty entries currently in
// journal.entries (per-kind counts and mask match what you'd get by
// walking the entries from scratch).
// - journal.originals mirrors that set for the three tracked metadata kinds
// (balance/nonce/code): a *Set flag is true iff the account currently has
// at least one corresponding entry in the journal.
// - An address is present in originals only if it also has at least one
// tracked-kind mutation in the journal.
func checkJournalInvariants(t *testing.T, j *journal) {
t.Helper()
// Reconstruct the expected per-address counts from the live entries.
expected := make(map[common.Address]*journalMutationCounts)
for _, e := range j.entries {
addr, kind, dirty := e.mutation()
if !dirty {
continue
}
c := expected[addr]
if c == nil {
c = &journalMutationCounts{}
expected[addr] = c
}
c.add(kind)
}
if len(j.mutations) != len(expected) {
t.Fatalf("mutations size %d, want %d", len(j.mutations), len(expected))
}
for addr, state := range j.mutations {
want, ok := expected[addr]
if !ok {
t.Fatalf("mutations has extra address %x", addr)
}
if state.counts != *want {
t.Fatalf("addr %x: counts=%+v want=%+v", addr, state.counts, *want)
}
// First-touch *Set flags must mirror the live per-kind counts.
if state.balanceSet != (want.balance > 0) {
t.Fatalf("addr %x: balanceSet=%v want=%v (balance count=%d)",
addr, state.balanceSet, want.balance > 0, want.balance)
}
if state.nonceSet != (want.nonce > 0) {
t.Fatalf("addr %x: nonceSet=%v want=%v (nonce count=%d)",
addr, state.nonceSet, want.nonce > 0, want.nonce)
}
if state.codeSet != (want.code > 0) {
t.Fatalf("addr %x: codeSet=%v want=%v (code count=%d)",
addr, state.codeSet, want.code > 0, want.code)
}
}
}
// FuzzJournal drives a randomised sequence of state mutations, snapshots and
// reverts against a fresh StateDB and validates the journal's internal
// bookkeeping invariants after every step. It also asserts that reverting
// back to the root snapshot empties mutations, originals and entries
// completely. The seed corpus ensures the test also runs as a regular unit
// test via `go test -run FuzzJournal`.
func FuzzJournal(f *testing.F) {
seeds := [][]byte{
// balance then full revert (simplest a→b→a case).
{0x00, 0x00, 0x05, 0x05, 0x00},
// balance+nonce+code mixed, then revert to root.
{0x00, 0x00, 0x01, 0x01, 0x01, 0x02, 0x02, 0x02, 0x00, 0x03, 0x05, 0x00},
// snapshot, mutate, revert, mutate again.
{0x04, 0x00, 0x00, 0x07, 0x05, 0x00, 0x00, 0x01, 0x05},
// storage interleaved with metadata.
{0x03, 0x00, 0x01, 0x00, 0x01, 0x05, 0x03, 0x02, 0x02, 0x04, 0x03, 0x01, 0x07},
// many ops, no explicit revert — exercises steady-state invariants.
{0x00, 0x01, 0x02, 0x00, 0x01, 0x02, 0x03, 0x00, 0x01, 0x02,
0x03, 0x04, 0x00, 0x01, 0x02, 0x00, 0x06, 0x08, 0x0a, 0x0c},
}
for _, s := range seeds {
f.Add(s)
}
f.Fuzz(func(t *testing.T, data []byte) {
sdb, err := New(types.EmptyRootHash, NewDatabaseForTesting())
if err != nil {
t.Fatal(err)
}
root := sdb.Snapshot()
// Stack of snapshot IDs taken during the fuzz loop.
var pending []int
// readByte returns the next byte and advances the cursor. Returns
// (0, false) if exhausted.
i := 0
readByte := func() (byte, bool) {
if i >= len(data) {
return 0, false
}
b := data[i]
i++
return b, true
}
for {
op, ok := readByte()
if !ok {
break
}
switch op % 6 {
case 0: // SetBalance
a, ok1 := readByte()
v, ok2 := readByte()
if !ok1 || !ok2 {
break
}
addr := fuzzJournalAddrs[int(a)%len(fuzzJournalAddrs)]
sdb.SetBalance(addr, uint256.NewInt(uint64(v)), tracing.BalanceChangeUnspecified)
case 1: // SetNonce
a, ok1 := readByte()
n, ok2 := readByte()
if !ok1 || !ok2 {
break
}
addr := fuzzJournalAddrs[int(a)%len(fuzzJournalAddrs)]
sdb.SetNonce(addr, uint64(n), tracing.NonceChangeUnspecified)
case 2: // SetCode
a, ok1 := readByte()
l, ok2 := readByte()
if !ok1 || !ok2 {
break
}
addr := fuzzJournalAddrs[int(a)%len(fuzzJournalAddrs)]
code := make([]byte, int(l)%8)
for k := range code {
b, ok := readByte()
if !ok {
break
}
code[k] = b
}
sdb.SetCode(addr, code, tracing.CodeChangeUnspecified)
case 3: // SetState (storage; tracked as mutation kind, no original)
a, ok1 := readByte()
k, ok2 := readByte()
v, ok3 := readByte()
if !ok1 || !ok2 || !ok3 {
break
}
addr := fuzzJournalAddrs[int(a)%len(fuzzJournalAddrs)]
sdb.SetState(addr,
common.BytesToHash([]byte{k}),
common.BytesToHash([]byte{v}))
case 4: // Snapshot
pending = append(pending, sdb.Snapshot())
case 5: // RevertToSnapshot
if len(pending) == 0 {
break
}
sel, ok := readByte()
if !ok {
break
}
idx := int(sel) % len(pending)
sdb.RevertToSnapshot(pending[idx])
pending = pending[:idx]
}
checkJournalInvariants(t, sdb.journal)
}
// After reverting to the root snapshot, the journal must be fully
// drained: no entries, no mutations, no originals. This is the core
// guarantee the user cares about — "all mutations against a single
// account reverted" taken to its limit across every account.
sdb.RevertToSnapshot(root)
checkJournalInvariants(t, sdb.journal)
if n := len(sdb.journal.entries); n != 0 {
t.Fatalf("entries not drained after revert-to-root: %d remain", n)
}
if n := len(sdb.journal.mutations); n != 0 {
t.Fatalf("mutations not drained after revert-to-root: %d remain", n)
}
})
}

View file

@ -662,12 +662,30 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v",
state.GetLogs(common.Hash{}, 0, common.Hash{}, 0), checkstate.GetLogs(common.Hash{}, 0, common.Hash{}, 0))
}
if !maps.Equal(state.journal.mutationSet(), checkstate.journal.mutationSet()) {
return fmt.Errorf("journal mutation set mismatch.\nhave:\n%v\nwant:\n%v\n", state.journal.mutationSet(), checkstate.journal.mutationSet())
if !equalMutationSets(state.journal.mutations, checkstate.journal.mutations) {
return fmt.Errorf("journal mutation set mismatch.\nhave:\n%v\nwant:\n%v\n", state.journal.mutations, checkstate.journal.mutations)
}
return nil
}
// equalMutationSets checks that two journal mutation maps have the same set of
// addresses and, for each address, the same per-kind counts. The stashed
// original values are ignored because comparing them across two independent
// state databases (with distinct pointer identities) isn't the point of this
// check — we only care that the two journals agree on what was touched.
func equalMutationSets(a, b map[common.Address]*journalMutationState) bool {
if len(a) != len(b) {
return false
}
for addr, sa := range a {
sb, ok := b[addr]
if !ok || sa.counts != sb.counts {
return false
}
}
return true
}
func TestTouchDelete(t *testing.T) {
s := newStateEnv()
s.state.getOrNewStateObject(common.Address{})
@ -677,11 +695,11 @@ func TestTouchDelete(t *testing.T) {
snapshot := s.state.Snapshot()
s.state.AddBalance(common.Address{}, new(uint256.Int), tracing.BalanceChangeUnspecified)
if len(s.state.journal.mutationSet()) != 1 {
if len(s.state.journal.mutations) != 1 {
t.Fatal("expected one mutated state object")
}
s.state.RevertToSnapshot(snapshot)
if len(s.state.journal.mutationSet()) != 0 {
if len(s.state.journal.mutations) != 0 {
t.Fatal("expected no journal mutations")
}
}
@ -691,8 +709,8 @@ func TestJournalMutationTracking(t *testing.T) {
addr := common.HexToAddress("0x01")
key := common.HexToHash("0x02")
if got := state.journal.mutation(addr); got != 0 {
t.Fatalf("unexpected initial mutation set: %v", got)
if _, ok := state.journal.mutations[addr]; ok {
t.Fatal("unexpected initial mutation entry")
}
snapshot := state.Snapshot()
@ -701,23 +719,30 @@ func TestJournalMutationTracking(t *testing.T) {
state.SetCode(addr, []byte{0x1}, tracing.CodeChangeUnspecified)
state.SetState(addr, key, common.Hash{0x3})
want := journalMutationKindCreate.mask() |
journalMutationKindBalance.mask() |
journalMutationKindNonce.mask() |
journalMutationKindCode.mask() |
journalMutationKindStorage.mask()
if got := state.journal.mutation(addr); got != want {
t.Fatalf("mutation set mismatch: have %08b, want %08b", got, want)
want := journalMutationCounts{
create: 1,
balance: 1,
nonce: 1,
code: 1,
storage: 1,
}
checkCounts := func(got *journalMutationState, label string) {
t.Helper()
if got == nil {
t.Fatalf("%s: missing mutation entry for %x", label, addr)
}
if got.counts != want {
t.Fatalf("%s: counts=%+v, want=%+v", label, got.counts, want)
}
}
checkCounts(state.journal.mutations[addr], "state")
copy := state.Copy()
if got := copy.journal.mutation(addr); got != want {
t.Fatalf("copy mutation set mismatch: have %08b, want %08b", got, want)
}
checkCounts(copy.journal.mutations[addr], "copy")
state.RevertToSnapshot(snapshot)
if got := state.journal.mutation(addr); got != 0 {
t.Fatalf("unexpected mutation set after revert: %08b", got)
if _, ok := state.journal.mutations[addr]; ok {
t.Fatalf("unexpected mutation entry after revert")
}
}