From 46be549b0024352f712ba83a393ea30c869c733a Mon Sep 17 00:00:00 2001 From: Gary Rong Date: Tue, 28 Apr 2026 23:31:28 +0800 Subject: [PATCH] core: extend the journal and introduce the ability for traversal --- core/state/journal.go | 85 +++++++++++++- core/state/journal_test.go | 213 +++++++++++++++++++++++++++++++++++ core/state/statedb.go | 9 ++ core/state/statedb_hooked.go | 4 + core/vm/evm.go | 76 ++++++++----- core/vm/interface.go | 6 + 6 files changed, 361 insertions(+), 32 deletions(-) create mode 100644 core/state/journal_test.go diff --git a/core/state/journal.go b/core/state/journal.go index 6a7f54ebc8..c80efa93db 100644 --- a/core/state/journal.go +++ b/core/state/journal.go @@ -27,9 +27,24 @@ import ( "github.com/holiman/uint256" ) +// frameRange is a half-open interval [start, end) of journal entry indices, +// used to record the slice of entries occupied by a closed child call frame. +type frameRange struct { + start, end int +} + type revision struct { id int journalIndex int + // closedChildren holds the [start, end) ranges of child call frames that + // have been closed under this revision via closeSnapshot. Together with + // journalIndex (this frame's own start) and the current journal length + // (this frame's tentative end) they describe the slice of entries that + // belong directly to this frame, with descendant frames' entries excluded. + // + // Invariant: ranges are appended in increasing order, are non-overlapping, + // and lie entirely within [journalIndex, len(entries)). 
+ closedChildren []frameRange } // journalEntry is a modification entry in the state change journal that can be @@ -86,7 +101,7 @@ func (j *journal) reset() { func (j *journal) snapshot() int { id := j.nextRevisionId j.nextRevisionId++ - j.validRevisions = append(j.validRevisions, revision{id, j.length()}) + j.validRevisions = append(j.validRevisions, revision{id: id, journalIndex: j.length()}) return id } @@ -106,6 +121,64 @@ func (j *journal) revertToSnapshot(revid int, s *StateDB) { j.validRevisions = j.validRevisions[:idx] } +// closeSnapshot marks the end of the call frame identified by revid without +// reverting any state. The frame's entry range [snapshot_index, current_length) +// is recorded on its parent revision so callers can later iterate the parent's +// own entries while skipping over closed children (and, transitively, their +// descendants — descendant ranges are absorbed into the closing child's range +// when the descendant itself was closed earlier under that child). +// +// closeSnapshot must be invoked in LIFO order: revid must identify the topmost +// snapshot. It panics otherwise. The corresponding revision is popped, so a +// subsequent revertToSnapshot on the same id is no longer valid. +func (j *journal) closeSnapshot(revid int) { + if len(j.validRevisions) == 0 { + panic(fmt.Errorf("revision id %v cannot be closed: no open snapshot", revid)) + } + top := len(j.validRevisions) - 1 + if j.validRevisions[top].id != revid { + panic(fmt.Errorf("revision id %v cannot be closed: top is %v", + revid, j.validRevisions[top].id)) + } + closed := frameRange{ + start: j.validRevisions[top].journalIndex, + end: len(j.entries), + } + // Only propagate non-empty ranges, and only if there is a parent frame to + // receive them. The outermost frame has nothing to bubble up to. 
+ if closed.start < closed.end && top > 0 { + parent := &j.validRevisions[top-1] + parent.closedChildren = append(parent.closedChildren, closed) + } + // Drop this revision's bookkeeping. validRevisions' backing array is reused + // by later snapshots, so nil the popped slot rather than pin its ranges. + j.validRevisions[top].closedChildren = nil + j.validRevisions = j.validRevisions[:top] +} + +// frameEntries invokes visit for each entry that belongs directly to the +// current (topmost) call frame, skipping entries that lie within any closed +// child frame's range. Entries are visited in append order. If no frame is +// open, frameEntries is a no-op. +// +// nolint:unused +func (j *journal) frameEntries(visit func(entry journalEntry)) { + if len(j.validRevisions) == 0 { + return + } + rev := j.validRevisions[len(j.validRevisions)-1] + idx := rev.journalIndex + for _, child := range rev.closedChildren { + for ; idx < child.start; idx++ { + visit(j.entries[idx]) + } + idx = child.end + } + for ; idx < len(j.entries); idx++ { + visit(j.entries[idx]) + } +} + // append inserts a new modification entry to the end of the change journal. 
func (j *journal) append(entry journalEntry) { j.entries = append(j.entries, entry) @@ -244,10 +317,18 @@ func (j *journal) copy() *journal { for i := 0; i < j.length(); i++ { entries = append(entries, j.entries[i].copy()) } + revisions := make([]revision, len(j.validRevisions)) + for i, r := range j.validRevisions { + revisions[i] = revision{ + id: r.id, + journalIndex: r.journalIndex, + closedChildren: slices.Clone(r.closedChildren), + } + } return &journal{ entries: entries, dirties: maps.Clone(j.dirties), - validRevisions: slices.Clone(j.validRevisions), + validRevisions: revisions, nextRevisionId: j.nextRevisionId, stateBytesCharged: maps.Clone(j.stateBytesCharged), } diff --git a/core/state/journal_test.go b/core/state/journal_test.go new file mode 100644 index 0000000000..0e0e2b55b1 --- /dev/null +++ b/core/state/journal_test.go @@ -0,0 +1,213 @@ +// Copyright 2026 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. +package state + +import ( + "slices" + "testing" + + "github.com/ethereum/go-ethereum/common" +) + +// tagEntry is a minimal journalEntry used by journal tests. It carries an +// integer tag so frameEntries iteration order can be verified, and is a no-op +// on revert so the surrounding StateDB can be a zero value. 
+type tagEntry struct { + tag int +} + +func (t tagEntry) revert(*StateDB) {} +func (t tagEntry) dirtied() (common.Address, bool) { return common.Address{}, false } +func (t tagEntry) copy() journalEntry { return t } + +// frameTags drives frameEntries and returns the visited tags in order. +func frameTags(j *journal) []int { + var got []int + j.frameEntries(func(e journalEntry) { + got = append(got, e.(tagEntry).tag) + }) + return got +} + +// didPanic reports whether fn panicked. +func didPanic(fn func()) (panicked bool) { + defer func() { + if r := recover(); r != nil { + panicked = true + } + }() + fn() + return false +} + +// TestJournalFrameTracking covers the happy paths of closeSnapshot and +// frameEntries together: basic single-child filtering, empty-range elision, +// multiple siblings, transitive descendant absorption, and the no-open-frame +// edge case for frameEntries. Building one composite scenario and asserting +// at each step keeps the expected behaviour as a connected story rather than +// scattering it across many tiny tests. +func TestJournalFrameTracking(t *testing.T) { + j := newJournal() + + // frameEntries on an empty journal is a no-op. + if got := frameTags(j); len(got) != 0 { + t.Fatalf("empty journal frameEntries: have %v, want []", got) + } + + j.snapshot() + j.append(tagEntry{1}) // outer + + // Closing an empty child frame must not record a degenerate range. + empty := j.snapshot() + j.closeSnapshot(empty) + if got := j.validRevisions[0].closedChildren; len(got) != 0 { + t.Fatalf("empty child should not propagate, have %+v", got) + } + + // First sibling child: two entries, then close. Range goes onto outer. + c1 := j.snapshot() + c1Start := len(j.entries) + j.append(tagEntry{10}) + j.append(tagEntry{11}) + c1End := len(j.entries) + j.closeSnapshot(c1) + + j.append(tagEntry{2}) // outer between siblings + + // Second sibling, with a grandchild closed inside it. 
After the + // grandchild closes, more entries appear in the child before it itself + // closes. The outer must end up with a single range that covers the + // child (which transitively covers the grandchild). + c2 := j.snapshot() + c2Start := len(j.entries) + j.append(tagEntry{20}) + + gc := j.snapshot() + j.append(tagEntry{300}) + j.closeSnapshot(gc) + + j.append(tagEntry{21}) + c2End := len(j.entries) + j.closeSnapshot(c2) + + j.append(tagEntry{3}) // outer after both siblings + + got := j.validRevisions[0].closedChildren + want := []frameRange{{c1Start, c1End}, {c2Start, c2End}} + if !slices.Equal(got, want) { + t.Fatalf("closedChildren: have %+v, want %+v", got, want) + } + if tags := frameTags(j); !slices.Equal(tags, []int{1, 2, 3}) { + t.Fatalf("frameEntries: have %v, want [1 2 3]", tags) + } + + // Closing the outermost (no-parent) frame is allowed: there is nothing + // to populate, but the revision is still popped and its range silently + // dropped. The journal ends up with no open frames. + outer := j.validRevisions[0].id + j.closeSnapshot(outer) + if len(j.validRevisions) != 0 { + t.Fatalf("after closing outermost, have %d open revisions, want 0", len(j.validRevisions)) + } +} + +// TestJournalCloseSnapshotPanics asserts the LIFO precondition: closing when +// no snapshot is open, or closing a revision while a more recent snapshot is +// still open above it, must panic rather than silently mutate state. Closing +// the outermost (no-parent) frame *is* permitted and is covered in +// TestJournalFrameTracking. 
+func TestJournalCloseSnapshotPanics(t *testing.T) { + j := newJournal() + if !didPanic(func() { j.closeSnapshot(0) }) { + t.Fatal("closing with no open snapshot should panic") + } + bottom := j.snapshot() + j.snapshot() // a more recent snapshot is now on top + if !didPanic(func() { j.closeSnapshot(bottom) }) { + t.Fatal("closing a snapshot that is not the most recent should panic") + } +} + +// TestJournalRevertInteractions verifies the two cross-cuts between revert +// and close: reverting a parent that has absorbed closed children also +// throws away the children's entries, and reverting a child (rather than +// closing it) leaves no closed-child range on the parent. +func TestJournalRevertInteractions(t *testing.T) { + t.Run("revertParentWithClosedChild", func(t *testing.T) { + j := newJournal() + outer := j.snapshot() + j.append(tagEntry{1}) + + c := j.snapshot() + j.append(tagEntry{10}) + j.append(tagEntry{11}) + j.closeSnapshot(c) + + j.append(tagEntry{2}) + j.revertToSnapshot(outer, &StateDB{}) + + if len(j.entries) != 0 || len(j.validRevisions) != 0 { + t.Fatalf("after revert have entries=%d revisions=%d, want both 0", + len(j.entries), len(j.validRevisions)) + } + }) + t.Run("revertedChildLeavesNoRange", func(t *testing.T) { + j := newJournal() + j.snapshot() + j.append(tagEntry{1}) + + c := j.snapshot() + j.append(tagEntry{10}) + j.revertToSnapshot(c, &StateDB{}) + j.append(tagEntry{2}) + + if got := j.validRevisions[0].closedChildren; len(got) != 0 { + t.Fatalf("reverted child should not appear in closedChildren, have %+v", got) + } + if tags := frameTags(j); !slices.Equal(tags, []int{1, 2}) { + t.Fatalf("frameEntries: have %v, want [1 2]", tags) + } + }) +} + +// TestJournalCopyAndReset checks that the bookkeeping for closed-child ranges +// participates in journal.copy (deep-copied, not aliased) and journal.reset +// (cleared along with everything else). 
+func TestJournalCopyAndReset(t *testing.T) { + j := newJournal() + j.snapshot() + j.append(tagEntry{1}) + c := j.snapshot() + j.append(tagEntry{10}) + j.closeSnapshot(c) + + cp := j.copy() + if !slices.Equal(cp.validRevisions[0].closedChildren, j.validRevisions[0].closedChildren) { + t.Fatalf("copy lost closedChildren: orig=%+v copy=%+v", + j.validRevisions[0].closedChildren, cp.validRevisions[0].closedChildren) + } + cp.validRevisions[0].closedChildren = append(cp.validRevisions[0].closedChildren, frameRange{99, 100}) + if len(j.validRevisions[0].closedChildren) != 1 { + t.Fatal("original aliased copy's closedChildren slice") + } + + j.reset() + if len(j.entries) != 0 || len(j.validRevisions) != 0 { + t.Fatalf("after reset have entries=%d revisions=%d, want both 0", + len(j.entries), len(j.validRevisions)) + } +} diff --git a/core/state/statedb.go b/core/state/statedb.go index ee28605c03..875bde5d5a 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -753,6 +753,15 @@ func (s *StateDB) RevertToSnapshot(revid int) { s.journal.revertToSnapshot(revid, s) } +// CloseSnapshot marks the call frame identified by revid as completed without +// reverting any state. Its journal entry range is recorded on the parent +// frame so the parent can later iterate its own entries while skipping over +// closed children. revid must identify the topmost open snapshot (i.e. frames +// must be closed in LIFO order). It panics otherwise. +func (s *StateDB) CloseSnapshot(revid int) { + s.journal.closeSnapshot(revid) +} + // GetRefund returns the current value of the refund counter. 
func (s *StateDB) GetRefund() uint64 { return s.refund diff --git a/core/state/statedb_hooked.go b/core/state/statedb_hooked.go index a9ea4d5397..8886cacb70 100644 --- a/core/state/statedb_hooked.go +++ b/core/state/statedb_hooked.go @@ -147,6 +147,10 @@ func (s *hookedStateDB) RevertToSnapshot(i int) { s.inner.RevertToSnapshot(i) } +func (s *hookedStateDB) CloseSnapshot(i int) { + s.inner.CloseSnapshot(i) +} + func (s *hookedStateDB) Snapshot() int { return s.inner.Snapshot() } diff --git a/core/vm/evm.go b/core/vm/evm.go index daa751b95c..1cf78814e7 100644 --- a/core/vm/evm.go +++ b/core/vm/evm.go @@ -276,6 +276,7 @@ func (evm *EVM) Call(caller common.Address, addr common.Address, input []byte, g if !isPrecompile && evm.chainRules.IsEIP158 && value.IsZero() { // Calling a non-existing account, don't do anything. + evm.StateDB.CloseSnapshot(snapshot) return nil, gas, nil } evm.StateDB.CreateAccount(addr) @@ -322,15 +323,18 @@ func (evm *EVM) Call(caller common.Address, addr common.Address, input []byte, g // TODO: consider clearing up unused snapshots: //} else { // evm.StateDB.DiscardSnapshot(snapshot) - } else if evm.chainRules.IsAmsterdam { - // Charge state costs - bytesCharged := evm.StateDB.StateChangedBytes(innerSnapshot) - stateGasCost := GasCosts{StateGas: bytesCharged * int64(evm.Context.CostPerStateByte)} - if !gas.CanAfford(stateGasCost) { - gas.Exhaust() - return ret, gas, ErrOutOfGas + } else { + evm.StateDB.CloseSnapshot(snapshot) + if evm.chainRules.IsAmsterdam { + // Charge state costs + bytesCharged := evm.StateDB.StateChangedBytes(innerSnapshot) + stateGasCost := GasCosts{StateGas: bytesCharged * int64(evm.Context.CostPerStateByte)} + if !gas.CanAfford(stateGasCost) { + gas.Exhaust() + return ret, gas, ErrOutOfGas + } + gas.Charge(stateGasCost) } - gas.Charge(stateGasCost) } return ret, gas, err } @@ -382,14 +386,17 @@ func (evm *EVM) CallCode(caller common.Address, addr common.Address, input []byt } gas.Exhaust() } - } else if 
evm.chainRules.IsAmsterdam { - bytesCharged := evm.StateDB.StateChangedBytes(snapshot) - stateGasCost := GasCosts{StateGas: bytesCharged * int64(evm.Context.CostPerStateByte)} - if !gas.CanAfford(stateGasCost) { - gas.Exhaust() - return ret, gas, ErrOutOfGas + } else { + evm.StateDB.CloseSnapshot(snapshot) + if evm.chainRules.IsAmsterdam { + bytesCharged := evm.StateDB.StateChangedBytes(snapshot) + stateGasCost := GasCosts{StateGas: bytesCharged * int64(evm.Context.CostPerStateByte)} + if !gas.CanAfford(stateGasCost) { + gas.Exhaust() + return ret, gas, ErrOutOfGas + } + gas.Charge(stateGasCost) } - gas.Charge(stateGasCost) } return ret, gas, err } @@ -434,15 +441,19 @@ func (evm *EVM) DelegateCall(originCaller common.Address, caller common.Address, } gas.Exhaust() } - } else if evm.chainRules.IsAmsterdam { - bytesCharged := evm.StateDB.StateChangedBytes(snapshot) - stateGasCost := GasCosts{StateGas: bytesCharged * int64(evm.Context.CostPerStateByte)} - if !gas.CanAfford(stateGasCost) { - gas.Exhaust() - return ret, gas, ErrOutOfGas + } else { + evm.StateDB.CloseSnapshot(snapshot) + if evm.chainRules.IsAmsterdam { + bytesCharged := evm.StateDB.StateChangedBytes(snapshot) + stateGasCost := GasCosts{StateGas: bytesCharged * int64(evm.Context.CostPerStateByte)} + if !gas.CanAfford(stateGasCost) { + gas.Exhaust() + return ret, gas, ErrOutOfGas + } + gas.Charge(stateGasCost) } - gas.Charge(stateGasCost) } + return ret, gas, err } @@ -497,6 +508,8 @@ func (evm *EVM) StaticCall(caller common.Address, addr common.Address, input []b } gas.Exhaust() } + } else { + evm.StateDB.CloseSnapshot(snapshot) } return ret, gas, err } @@ -607,15 +620,18 @@ func (evm *EVM) create(caller common.Address, code []byte, gas GasBudget, value if err != ErrExecutionReverted { contract.UseGas(GasCosts{RegularGas: contract.Gas.RegularGas}, evm.Config.Tracer, tracing.GasChangeCallFailedExecution) } - } else if evm.chainRules.IsAmsterdam { - // Charge initcode's state changes to the created 
contract's gas. - bytesCharged := evm.StateDB.StateChangedBytes(initSnapshot) - stateGasCost := GasCosts{StateGas: bytesCharged * int64(evm.Context.CostPerStateByte)} - if !contract.Gas.CanAfford(stateGasCost) { - contract.Gas.Exhaust() - return ret, address, contract.Gas, ErrOutOfGas + } else { + evm.StateDB.CloseSnapshot(snapshot) + if evm.chainRules.IsAmsterdam { + // Charge initcode's state changes to the created contract's gas. + bytesCharged := evm.StateDB.StateChangedBytes(initSnapshot) + stateGasCost := GasCosts{StateGas: bytesCharged * int64(evm.Context.CostPerStateByte)} + if !contract.Gas.CanAfford(stateGasCost) { + contract.Gas.Exhaust() + return ret, address, contract.Gas, ErrOutOfGas + } + contract.Gas.Charge(stateGasCost) } - contract.Gas.Charge(stateGasCost) } return ret, address, contract.Gas, err } diff --git a/core/vm/interface.go b/core/vm/interface.go index a58fddac22..9c132061c6 100644 --- a/core/vm/interface.go +++ b/core/vm/interface.go @@ -87,6 +87,12 @@ type StateDB interface { Prepare(rules params.Rules, sender, coinbase common.Address, dest *common.Address, precompiles []common.Address, txAccesses types.AccessList) RevertToSnapshot(int) + + // CloseSnapshot marks the given snapshot's call frame as completed without + // reverting any state. The call frame's entry range is recorded on the + // parent frame so the parent can later iterate its own entries while + // skipping over closed children. Snapshots must be closed in LIFO order. + CloseSnapshot(int) Snapshot() int AddLog(*types.Log)