diff --git a/core/vm/contract.go b/core/vm/contract.go index ab5102ed53..5393960bc6 100644 --- a/core/vm/contract.go +++ b/core/vm/contract.go @@ -138,7 +138,13 @@ func (c *Contract) UseGas(gas GasCosts, logger *tracing.Hooks, reason tracing.Ga } // RefundGas refunds gas to the contract -func (c *Contract) RefundGas(gas GasCosts, logger *tracing.Hooks, reason tracing.GasChangeReason) { +func (c *Contract) RefundGas(err error, gas GasCosts, logger *tracing.Hooks, reason tracing.GasChangeReason) { + // If the preceding call errored, return the state gas + // to the parent call + if err != nil { + gas.StateGas += gas.StateGasCharged + gas.StateGasCharged = 0 + } if gas.Max() == 0 { return } @@ -147,13 +153,6 @@ func (c *Contract) RefundGas(gas GasCosts, logger *tracing.Hooks, reason tracing } c.Gas.RegularGas += gas.RegularGas c.Gas.StateGas = gas.StateGas - /* - if c.Gas.StateGas < gas.StateGasCharged { - // We overcharged StateGas, during reverts we need to add back to regular gas - missing := c.Gas.StateGasCharged - c.Gas.StateGas - c.Gas.RegularGas += missing - } - */ c.Gas.StateGasCharged += gas.StateGasCharged } diff --git a/core/vm/evm.go b/core/vm/evm.go index 1f26fbafd6..dbe183a34e 100644 --- a/core/vm/evm.go +++ b/core/vm/evm.go @@ -321,9 +321,6 @@ func (evm *EVM) Call(caller common.Address, addr common.Address, input []byte, g } gas.RegularGas = 0 } - // Restore gas to the parent on child error - gas.StateGas += gas.StateGasCharged - gas.StateGasCharged = 0 } return ret, gas, err } @@ -380,9 +377,6 @@ func (evm *EVM) CallCode(caller common.Address, addr common.Address, input []byt } gas.RegularGas = 0 } - // Restore gas to the parent on child error - gas.StateGas += gas.StateGasCharged - gas.StateGasCharged = 0 } return ret, gas, err } @@ -432,9 +426,6 @@ func (evm *EVM) DelegateCall(originCaller common.Address, caller common.Address, } gas.RegularGas = 0 } - // Restore gas to the parent on child error - gas.StateGas += gas.StateGasCharged - gas.StateGasCharged = 0 } return ret, gas, err } @@ -495,9 +486,6 @@ func (evm *EVM) StaticCall(caller common.Address, addr common.Address, input []b } gas.RegularGas = 0 } - // Restore gas to the parent on child error - gas.StateGas += gas.StateGasCharged - gas.StateGasCharged = 0 } return ret, gas, err } @@ -611,9 +599,6 @@ func (evm *EVM) create(caller common.Address, code []byte, gas GasCosts, value * if err != ErrExecutionReverted { contract.UseGas(contract.Gas, evm.Config.Tracer, tracing.GasChangeCallFailedExecution) } - // Restore gas to the parent on child error - contract.Gas.StateGas += contract.Gas.StateGasCharged - contract.Gas.StateGasCharged = 0 } return ret, address, contract.Gas, err } diff --git a/core/vm/instructions.go b/core/vm/instructions.go index 2b3c8e982a..1092f779f6 100644 --- a/core/vm/instructions.go +++ b/core/vm/instructions.go @@ -689,7 +689,7 @@ func opCreate(pc *uint64, evm *EVM, scope *ScopeContext) ([]byte, error) { } scope.Stack.push(&stackvalue) - scope.Contract.RefundGas(returnGas, evm.Config.Tracer, tracing.GasChangeCallLeftOverRefunded) + scope.Contract.RefundGas(suberr, returnGas, evm.Config.Tracer, tracing.GasChangeCallLeftOverRefunded) if suberr == ErrExecutionReverted { evm.returnData = res // set REVERT data to return data buffer @@ -736,7 +736,7 @@ func opCreate2(pc *uint64, evm *EVM, scope *ScopeContext) ([]byte, error) { } scope.Stack.push(&stackvalue) - scope.Contract.RefundGas(returnGas, evm.Config.Tracer, tracing.GasChangeCallLeftOverRefunded) + scope.Contract.RefundGas(suberr, returnGas, evm.Config.Tracer, tracing.GasChangeCallLeftOverRefunded) if suberr == ErrExecutionReverted { evm.returnData = res // set REVERT data to return data buffer @@ -776,7 +776,7 @@ func opCall(pc *uint64, evm *EVM, scope *ScopeContext) ([]byte, error) { scope.Memory.Set(retOffset.Uint64(), retSize.Uint64(), ret) } - scope.Contract.RefundGas(returnGas, evm.Config.Tracer, tracing.GasChangeCallLeftOverRefunded) + scope.Contract.RefundGas(err, returnGas, evm.Config.Tracer, tracing.GasChangeCallLeftOverRefunded) evm.returnData = ret return ret, nil @@ -809,7 +809,7 @@ func opCallCode(pc *uint64, evm *EVM, scope *ScopeContext) ([]byte, error) { scope.Memory.Set(retOffset.Uint64(), retSize.Uint64(), ret) } - scope.Contract.RefundGas(returnGas, evm.Config.Tracer, tracing.GasChangeCallLeftOverRefunded) + scope.Contract.RefundGas(err, returnGas, evm.Config.Tracer, tracing.GasChangeCallLeftOverRefunded) evm.returnData = ret return ret, nil @@ -839,7 +839,7 @@ func opDelegateCall(pc *uint64, evm *EVM, scope *ScopeContext) ([]byte, error) { scope.Memory.Set(retOffset.Uint64(), retSize.Uint64(), ret) } - scope.Contract.RefundGas(returnGas, evm.Config.Tracer, tracing.GasChangeCallLeftOverRefunded) + scope.Contract.RefundGas(err, returnGas, evm.Config.Tracer, tracing.GasChangeCallLeftOverRefunded) evm.returnData = ret return ret, nil @@ -869,7 +869,7 @@ func opStaticCall(pc *uint64, evm *EVM, scope *ScopeContext) ([]byte, error) { scope.Memory.Set(retOffset.Uint64(), retSize.Uint64(), ret) } - scope.Contract.RefundGas(returnGas, evm.Config.Tracer, tracing.GasChangeCallLeftOverRefunded) + scope.Contract.RefundGas(err, returnGas, evm.Config.Tracer, tracing.GasChangeCallLeftOverRefunded) evm.returnData = ret return ret, nil