diff --git a/core/txpool/legacypool/list.go b/core/txpool/legacypool/list.go index 0c9f13c62f..c401e2a71a 100644 --- a/core/txpool/legacypool/list.go +++ b/core/txpool/legacypool/list.go @@ -301,6 +301,7 @@ func (l *list) Contains(nonce uint64) bool { // If the new transaction is accepted into the list, the lists' cost and gas // thresholds are also potentially updated. func (l *list) Add(tx *types.Transaction, priceBump uint64) (bool, *types.Transaction) { + base := new(uint256.Int).Set(l.totalcost) // If there's an older better transaction, abort old := l.txs.Get(tx.Nonce()) if old != nil { @@ -323,23 +324,22 @@ func (l *list) Add(tx *types.Transaction, priceBump uint64) (bool, *types.Transa if tx.GasFeeCapIntCmp(thresholdFeeCap) < 0 || tx.GasTipCapIntCmp(thresholdTip) < 0 { return false, nil } + // Old is being replaced, subtract old cost + if _, underflow := base.SubOverflow(base, uint256.MustFromBig(old.Cost())); underflow { + panic("totalcost underflow") + } } // Add new tx cost to totalcost cost, overflow := uint256.FromBig(tx.Cost()) if overflow { return false, nil } - total, overflow := new(uint256.Int).AddOverflow(l.totalcost, cost) + total, overflow := new(uint256.Int).AddOverflow(base, cost) if overflow { return false, nil } l.totalcost = total - // Old is being replaced, subtract old cost - if old != nil { - l.subTotalCost([]*types.Transaction{old}) - } - // Otherwise overwrite the old transaction with the current one l.txs.Put(tx) if l.costcap.Cmp(cost) < 0 { diff --git a/core/txpool/legacypool/list_test.go b/core/txpool/legacypool/list_test.go index dc03def26c..5ba9ef15ed 100644 --- a/core/txpool/legacypool/list_test.go +++ b/core/txpool/legacypool/list_test.go @@ -68,6 +68,52 @@ func TestListAddVeryExpensive(t *testing.T) { } } +func TestListAddReplacementAvoidsIntermediateOverflow(t *testing.T) { + key, err := crypto.GenerateKey() + if err != nil { + t.Fatalf("failed to generate key: %v", err) + } + max := new(big.Int).Sub(new(big.Int).Lsh(common.Big1, 256), common.Big1) + oldPrice := new(big.Int).Sub(new(big.Int).Rsh(new(big.Int).Set(max), 1), big.NewInt(100)) + newPrice := new(big.Int).Add(oldPrice, common.Big1) + + oldTx, err := types.SignTx(types.NewTransaction(0, common.Address{}, common.Big0, 1, oldPrice, nil), types.HomesteadSigner{}, key) + if err != nil { + t.Fatalf("failed to sign old tx: %v", err) + } + newTx, err := types.SignTx(types.NewTransaction(0, common.Address{}, common.Big0, 1, newPrice, nil), types.HomesteadSigner{}, key) + if err != nil { + t.Fatalf("failed to sign replacement tx: %v", err) + } + + list := newList(true) + inserted, _ := list.Add(oldTx, 0) + if !inserted { + t.Fatal("failed to insert baseline transaction") + } + inserted, replaced := list.Add(newTx, 0) + if !inserted { + t.Fatal("replacement transaction should not overflow after subtracting old cost") + } + if replaced == nil || replaced.Hash() != oldTx.Hash() { + t.Fatal("expected old transaction to be replaced") + } + want, overflow := uint256.FromBig(newTx.Cost()) + if overflow { + t.Fatal("replacement tx cost overflowed uint256 in test setup") + } + if list.totalcost.Cmp(want) != 0 { + t.Fatalf("totalcost mismatch after replacement: have %v want %v", list.totalcost, want) + } + if tx := list.txs.Get(newTx.Nonce()); tx == nil || tx.Hash() != newTx.Hash() { + t.Fatal("replacement transaction was not stored in list") + } + list.Forward(1) + if list.totalcost.Sign() != 0 { + t.Fatalf("totalcost should be zero after removal, have %v", list.totalcost) + } +} + // TestPriceHeapCmp tests that the price heap comparison function works as intended. // It also tests combinations where the basefee is higher than the gas fee cap, which // are useful to sort in the mempool to support basefee changes.