diff --git a/core/state/journal.go b/core/state/journal.go
index d8d8923e64..8f2e888b38 100644
--- a/core/state/journal.go
+++ b/core/state/journal.go
@@ -31,9 +31,6 @@ type revision struct {
 	journalIndex int
 }
 
-// journalMutation represents a set of mutations applied to a certain account.
-type journalMutation uint8
-
 // journalMutationKind indicates the type of account mutation.
 type journalMutationKind uint8
 
@@ -47,13 +44,6 @@ const (
 	journalMutationKindStorage
 )
 
-func (k journalMutationKind) mask() journalMutation {
-	if k == 0 {
-		return 0
-	}
-	return journalMutation(1) << (k - 1)
-}
-
 type journalMutationCounts struct {
 	touch   int
 	create  int
@@ -64,25 +54,60 @@ type journalMutationCounts struct {
 	storage int
 }
 
+// journalMutationState tracks, per account, both the per-kind count of mutation
+// entries currently present in the journal and the pre-tx value of each
+// metadata field captured on its first touch (balance/nonce/code).
+// The *Set flags indicate whether the corresponding field has been mutated
+// at least once in the current tx window; they are cleared when all entries
+// of that kind are reverted. Storage slots are tracked elsewhere.
 type journalMutationState struct {
-	mask   journalMutation
 	counts journalMutationCounts
+
+	balance    *uint256.Int
+	balanceSet bool
+	nonce      uint64
+	nonceSet   bool
+	code       []byte
+	codeSet    bool
 }
 
 func (s *journalMutationState) add(kind journalMutationKind) {
 	s.counts.add(kind)
-	s.mask |= kind.mask()
 }
 
-func (s *journalMutationState) remove(kind journalMutationKind) bool {
-	if s.counts.remove(kind) {
-		s.mask &^= kind.mask()
+// remove drops one occurrence of the given mutation kind. It returns two
+// booleans: kindEmpty is true when no entries of that kind remain for the
+// account, and stateEmpty is true when no entries of any kind remain.
+func (s *journalMutationState) remove(kind journalMutationKind) (kindEmpty, stateEmpty bool) {
+	kindEmpty = s.counts.remove(kind)
+	return kindEmpty, s.counts == (journalMutationCounts{})
+}
+
+// clearKind drops the stashed original for the given mutation kind. It is
+// invoked during revert once no journal entries of that kind remain for the
+// account. Kinds that don't correspond to a tracked metadata field are no-ops.
+func (s *journalMutationState) clearKind(kind journalMutationKind) {
+	switch kind {
+	case journalMutationKindBalance:
+		s.balance = nil
+		s.balanceSet = false
+	case journalMutationKindNonce:
+		s.nonce = 0
+		s.nonceSet = false
+	case journalMutationKindCode:
+		s.code = nil
+		s.codeSet = false
 	}
-	return s.mask == 0
 }
 
 func (s journalMutationState) copy() *journalMutationState {
 	cpy := s
+	if s.balance != nil {
+		cpy.balance = s.balance.Clone()
+	}
+	if s.code != nil {
+		cpy.code = slices.Clone(s.code)
+	}
 	return &cpy
 }
 
@@ -146,12 +171,55 @@ type journalEntry interface {
 	copy() journalEntry
 }
 
+// stashBalance records prev as the pre-tx balance of addr, iff this is the
+// first balance touch seen in the current tx. Subsequent balance writes are
+// ignored so the stored value remains the true pre-tx original.
+func (j *journal) stashBalance(addr common.Address, prev *uint256.Int) {
+	s := j.mutationStateFor(addr)
+	if s.balanceSet {
+		return
+	}
+	s.balance = prev.Clone()
+	s.balanceSet = true
+}
+
+// stashNonce records prev as the pre-tx nonce of addr on first touch.
+func (j *journal) stashNonce(addr common.Address, prev uint64) {
+	s := j.mutationStateFor(addr)
+	if s.nonceSet {
+		return
+	}
+	s.nonce = prev
+	s.nonceSet = true
+}
+
+// stashCode records prev as the pre-tx code of addr on first touch.
+func (j *journal) stashCode(addr common.Address, prev []byte) {
+	s := j.mutationStateFor(addr)
+	if s.codeSet {
+		return
+	}
+	s.code = slices.Clone(prev)
+	s.codeSet = true
+}
+
+// mutationStateFor returns the mutation state for addr, creating an empty one
+// if absent.
+func (j *journal) mutationStateFor(addr common.Address) *journalMutationState {
+	s := j.mutations[addr]
+	if s == nil {
+		s = new(journalMutationState)
+		j.mutations[addr] = s
+	}
+	return s
+}
+
 // journal contains the list of state modifications applied since the last state
 // commit. These are tracked to be able to be reverted in the case of an execution
 // exception or request for reversal.
 type journal struct {
 	entries   []journalEntry                           // Current changes tracked by the journal
-	mutations map[common.Address]*journalMutationState // Account mutation state accumulated across entries
+	mutations map[common.Address]*journalMutationState // Per-account mutation kinds and pre-tx originals
 
 	validRevisions []revision
 	nextRevisionId int
@@ -224,7 +292,14 @@ func (j *journal) revert(statedb *StateDB, snapshot int) {
 			if state == nil {
 				panic(fmt.Errorf("journal mutation tracking missing for %x", addr[:]))
 			}
-			if state.remove(kind) {
+			kindEmpty, stateEmpty := state.remove(kind)
+			if kindEmpty {
+				// No entries of this kind remain for this account; drop the
+				// corresponding stashed original so the state mirrors the
+				// live mutation set.
+				state.clearKind(kind)
+			}
+			if stateEmpty {
 				delete(j.mutations, addr)
 			}
 		}
@@ -249,24 +324,6 @@ func (j *journal) ripemdMagic() {
 	state.add(journalMutationKindTouch)
 }
 
-func (j *journal) mutation(addr common.Address) journalMutation {
-	if state := j.mutations[addr]; state != nil {
-		return state.mask
-	}
-	return 0
-}
-
-func (j *journal) mutationSet() map[common.Address]journalMutation {
-	if j.mutations == nil {
-		return nil
-	}
-	out := make(map[common.Address]journalMutation, len(j.mutations))
-	for addr, state := range j.mutations {
-		out[addr] = state.mask
-	}
-	return out
-}
-
 // length returns the current number of entries in the journal.
 func (j *journal) length() int {
 	return len(j.entries)
@@ -335,6 +392,7 @@ func (j *journal) refundChange(previous uint64) {
 }
 
 func (j *journal) balanceChange(addr common.Address, previous *uint256.Int) {
+	j.stashBalance(addr, previous)
 	j.append(balanceChange{
 		account: addr,
 		prev:    previous.Clone(),
@@ -342,6 +400,7 @@ func (j *journal) balanceChange(addr common.Address, previous *uint256.Int) {
 }
 
 func (j *journal) setCode(address common.Address, prevCode []byte) {
+	j.stashCode(address, prevCode)
 	j.append(codeChange{
 		account:  address,
 		prevCode: prevCode,
@@ -349,6 +408,7 @@ func (j *journal) setCode(address common.Address, prevCode []byte) {
 }
 
 func (j *journal) nonceChange(address common.Address, prev uint64) {
+	j.stashNonce(address, prev)
 	j.append(nonceChange{
 		account: address,
 		prev:    prev,
diff --git a/core/state/journal_test.go b/core/state/journal_test.go
new file mode 100644
index 0000000000..56e6f575f0
--- /dev/null
+++ b/core/state/journal_test.go
@@ -0,0 +1,219 @@
+// Copyright 2026 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package state
+
+import (
+	"testing"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/core/tracing"
+	"github.com/ethereum/go-ethereum/core/types"
+	"github.com/holiman/uint256"
+)
+
+// fuzzJournalAddrs is a small fixed pool used by the fuzz harness to force
+// repeated collisions on the same account, which exercises the multi-entry
+// path in the journal's mutation tracking and originals cleanup on revert.
+// It deliberately excludes the RIPEMD-160 precompile (0x03), which has a
+// consensus-level touch/revert exception that would complicate invariants.
+var fuzzJournalAddrs = []common.Address{
+	common.BytesToAddress([]byte{0x11}),
+	common.BytesToAddress([]byte{0x22}),
+	common.BytesToAddress([]byte{0x44}),
+}
+
+// checkJournalInvariants validates that:
+//   - journal.mutations exactly reflects the dirty entries currently in
+//     journal.entries (the address set and per-kind counts match what you'd
+//     get by walking the entries from scratch).
+//   - the first-touch flags stored alongside the counts mirror that set for
+//     the three tracked metadata kinds (balance/nonce/code): a *Set flag is
+//     true iff the account currently has at least one corresponding entry.
+//   - consequently no stashed original outlives its journal entries: an
+//     address whose entries are all reverted is absent from the map.
+func checkJournalInvariants(t *testing.T, j *journal) {
+	t.Helper()
+
+	// Reconstruct the expected per-address counts from the live entries.
+	expected := make(map[common.Address]*journalMutationCounts)
+	for _, e := range j.entries {
+		addr, kind, dirty := e.mutation()
+		if !dirty {
+			continue
+		}
+		c := expected[addr]
+		if c == nil {
+			c = &journalMutationCounts{}
+			expected[addr] = c
+		}
+		c.add(kind)
+	}
+
+	if len(j.mutations) != len(expected) {
+		t.Fatalf("mutations size %d, want %d", len(j.mutations), len(expected))
+	}
+	for addr, state := range j.mutations {
+		want, ok := expected[addr]
+		if !ok {
+			t.Fatalf("mutations has extra address %x", addr)
+		}
+		if state.counts != *want {
+			t.Fatalf("addr %x: counts=%+v want=%+v", addr, state.counts, *want)
+		}
+		// First-touch *Set flags must mirror the live per-kind counts.
+		if state.balanceSet != (want.balance > 0) {
+			t.Fatalf("addr %x: balanceSet=%v want=%v (balance count=%d)",
+				addr, state.balanceSet, want.balance > 0, want.balance)
+		}
+		if state.nonceSet != (want.nonce > 0) {
+			t.Fatalf("addr %x: nonceSet=%v want=%v (nonce count=%d)",
+				addr, state.nonceSet, want.nonce > 0, want.nonce)
+		}
+		if state.codeSet != (want.code > 0) {
+			t.Fatalf("addr %x: codeSet=%v want=%v (code count=%d)",
+				addr, state.codeSet, want.code > 0, want.code)
+		}
+	}
+}
+
+// FuzzJournal drives a randomised sequence of state mutations, snapshots and
+// reverts against a fresh StateDB and validates the journal's internal
+// bookkeeping invariants after every step. It also asserts that reverting
+// back to the root snapshot empties entries and mutations (and with them the
+// stashed originals) completely. The seed corpus ensures the test also runs
+// as a regular unit test via `go test -run FuzzJournal`.
+func FuzzJournal(f *testing.F) {
+	seeds := [][]byte{
+		// balance then full revert (simplest a→b→a case).
+		{0x00, 0x00, 0x05, 0x05, 0x00},
+		// balance+nonce+code mixed, then revert to root.
+		{0x00, 0x00, 0x01, 0x01, 0x01, 0x02, 0x02, 0x02, 0x00, 0x03, 0x05, 0x00},
+		// snapshot, mutate, revert, mutate again.
+		{0x04, 0x00, 0x00, 0x07, 0x05, 0x00, 0x00, 0x01, 0x05},
+		// storage interleaved with metadata.
+		{0x03, 0x00, 0x01, 0x00, 0x01, 0x05, 0x03, 0x02, 0x02, 0x04, 0x03, 0x01, 0x07},
+		// many ops, no explicit revert — exercises steady-state invariants.
+		{0x00, 0x01, 0x02, 0x00, 0x01, 0x02, 0x03, 0x00, 0x01, 0x02,
+			0x03, 0x04, 0x00, 0x01, 0x02, 0x00, 0x06, 0x08, 0x0a, 0x0c},
+	}
+	for _, s := range seeds {
+		f.Add(s)
+	}
+
+	f.Fuzz(func(t *testing.T, data []byte) {
+		sdb, err := New(types.EmptyRootHash, NewDatabaseForTesting())
+		if err != nil {
+			t.Fatal(err)
+		}
+		root := sdb.Snapshot()
+
+		// Stack of snapshot IDs taken during the fuzz loop.
+		var pending []int
+
+		// readByte returns the next byte and advances the cursor. Returns
+		// (0, false) if exhausted.
+		i := 0
+		readByte := func() (byte, bool) {
+			if i >= len(data) {
+				return 0, false
+			}
+			b := data[i]
+			i++
+			return b, true
+		}
+
+		for {
+			op, ok := readByte()
+			if !ok {
+				break
+			}
+			switch op % 6 {
+			case 0: // SetBalance
+				a, ok1 := readByte()
+				v, ok2 := readByte()
+				if !ok1 || !ok2 {
+					break
+				}
+				addr := fuzzJournalAddrs[int(a)%len(fuzzJournalAddrs)]
+				sdb.SetBalance(addr, uint256.NewInt(uint64(v)), tracing.BalanceChangeUnspecified)
+			case 1: // SetNonce
+				a, ok1 := readByte()
+				n, ok2 := readByte()
+				if !ok1 || !ok2 {
+					break
+				}
+				addr := fuzzJournalAddrs[int(a)%len(fuzzJournalAddrs)]
+				sdb.SetNonce(addr, uint64(n), tracing.NonceChangeUnspecified)
+			case 2: // SetCode
+				a, ok1 := readByte()
+				l, ok2 := readByte()
+				if !ok1 || !ok2 {
+					break
+				}
+				addr := fuzzJournalAddrs[int(a)%len(fuzzJournalAddrs)]
+				code := make([]byte, int(l)%8)
+				for k := range code {
+					b, ok := readByte()
+					if !ok {
+						break
+					}
+					code[k] = b
+				}
+				sdb.SetCode(addr, code, tracing.CodeChangeUnspecified)
+			case 3: // SetState (storage; tracked as mutation kind, no original)
+				a, ok1 := readByte()
+				k, ok2 := readByte()
+				v, ok3 := readByte()
+				if !ok1 || !ok2 || !ok3 {
+					break
+				}
+				addr := fuzzJournalAddrs[int(a)%len(fuzzJournalAddrs)]
+				sdb.SetState(addr,
+					common.BytesToHash([]byte{k}),
+					common.BytesToHash([]byte{v}))
+			case 4: // Snapshot
+				pending = append(pending, sdb.Snapshot())
+			case 5: // RevertToSnapshot
+				if len(pending) == 0 {
+					break
+				}
+				sel, ok := readByte()
+				if !ok {
+					break
+				}
+				idx := int(sel) % len(pending)
+				sdb.RevertToSnapshot(pending[idx])
+				pending = pending[:idx]
+			}
+			checkJournalInvariants(t, sdb.journal)
+		}
+
+		// After reverting to the root snapshot, the journal must be fully
+		// drained: no entries, no mutations, no originals. This is the core
+		// guarantee the user cares about — "all mutations against a single
+		// account reverted" taken to its limit across every account.
+		sdb.RevertToSnapshot(root)
+		checkJournalInvariants(t, sdb.journal)
+
+		if n := len(sdb.journal.entries); n != 0 {
+			t.Fatalf("entries not drained after revert-to-root: %d remain", n)
+		}
+		if n := len(sdb.journal.mutations); n != 0 {
+			t.Fatalf("mutations not drained after revert-to-root: %d remain", n)
+		}
+	})
+}
diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go
index 8c073e2fc2..c6e14e7477 100644
--- a/core/state/statedb_test.go
+++ b/core/state/statedb_test.go
@@ -662,12 +662,34 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
 		return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v",
 			state.GetLogs(common.Hash{}, 0, common.Hash{}, 0), checkstate.GetLogs(common.Hash{}, 0, common.Hash{}, 0))
 	}
-	if !maps.Equal(state.journal.mutationSet(), checkstate.journal.mutationSet()) {
-		return fmt.Errorf("journal mutation set mismatch.\nhave:\n%v\nwant:\n%v\n", state.journal.mutationSet(), checkstate.journal.mutationSet())
+	if !equalMutationSets(state.journal.mutations, checkstate.journal.mutations) {
+		return fmt.Errorf("journal mutation set mismatch.\nhave:\n%v\nwant:\n%v\n",
+			mutationCounts(state.journal.mutations), mutationCounts(checkstate.journal.mutations))
 	}
 	return nil
 }
 
+// equalMutationSets reports whether two journal mutation maps cover the same
+// addresses with identical per-kind counts. The stashed originals are
+// deliberately ignored: this check only cares that both journals agree on
+// what was touched.
+func equalMutationSets(a, b map[common.Address]*journalMutationState) bool {
+	return maps.EqualFunc(a, b, func(sa, sb *journalMutationState) bool {
+		return sa.counts == sb.counts
+	})
+}
+
+// mutationCounts flattens a journal mutation map into per-address counts.
+// The map values are pointers, which %v would print as bare addresses, so
+// mismatch messages format this view instead.
+func mutationCounts(m map[common.Address]*journalMutationState) map[common.Address]journalMutationCounts {
+	out := make(map[common.Address]journalMutationCounts, len(m))
+	for addr, s := range m {
+		out[addr] = s.counts
+	}
+	return out
+}
+
 func TestTouchDelete(t *testing.T) {
 	s := newStateEnv()
 	s.state.getOrNewStateObject(common.Address{})
@@ -677,11 +699,11 @@ func TestTouchDelete(t *testing.T) {
 	snapshot := s.state.Snapshot()
 	s.state.AddBalance(common.Address{}, new(uint256.Int), tracing.BalanceChangeUnspecified)
 
-	if len(s.state.journal.mutationSet()) != 1 {
+	if len(s.state.journal.mutations) != 1 {
 		t.Fatal("expected one mutated state object")
 	}
 	s.state.RevertToSnapshot(snapshot)
-	if len(s.state.journal.mutationSet()) != 0 {
+	if len(s.state.journal.mutations) != 0 {
 		t.Fatal("expected no journal mutations")
 	}
 }
@@ -691,8 +713,8 @@ func TestJournalMutationTracking(t *testing.T) {
 	addr := common.HexToAddress("0x01")
 	key := common.HexToHash("0x02")
 
-	if got := state.journal.mutation(addr); got != 0 {
-		t.Fatalf("unexpected initial mutation set: %v", got)
+	if _, ok := state.journal.mutations[addr]; ok {
+		t.Fatal("unexpected initial mutation entry")
 	}
 
 	snapshot := state.Snapshot()
@@ -701,23 +723,30 @@ func TestJournalMutationTracking(t *testing.T) {
 	state.SetCode(addr, []byte{0x1}, tracing.CodeChangeUnspecified)
 	state.SetState(addr, key, common.Hash{0x3})
 
-	want := journalMutationKindCreate.mask() |
-		journalMutationKindBalance.mask() |
-		journalMutationKindNonce.mask() |
-		journalMutationKindCode.mask() |
-		journalMutationKindStorage.mask()
-	if got := state.journal.mutation(addr); got != want {
-		t.Fatalf("mutation set mismatch: have %08b, want %08b", got, want)
+	want := journalMutationCounts{
+		create:  1,
+		balance: 1,
+		nonce:   1,
+		code:    1,
+		storage: 1,
 	}
+	checkCounts := func(got *journalMutationState, label string) {
+		t.Helper()
+		if got == nil {
+			t.Fatalf("%s: missing mutation entry for %x", label, addr)
+		}
+		if got.counts != want {
+			t.Fatalf("%s: counts=%+v, want=%+v", label, got.counts, want)
+		}
+	}
+	checkCounts(state.journal.mutations[addr], "state")
 
 	copy := state.Copy()
-	if got := copy.journal.mutation(addr); got != want {
-		t.Fatalf("copy mutation set mismatch: have %08b, want %08b", got, want)
-	}
+	checkCounts(copy.journal.mutations[addr], "copy")
 
 	state.RevertToSnapshot(snapshot)
-	if got := state.journal.mutation(addr); got != 0 {
-		t.Fatalf("unexpected mutation set after revert: %08b", got)
+	if _, ok := state.journal.mutations[addr]; ok {
+		t.Fatal("unexpected mutation entry after revert")
 	}
 }
 