diff --git a/XDCx/XDCx.go b/XDCx/XDCx.go index 6506929984..3dc3386d92 100644 --- a/XDCx/XDCx.go +++ b/XDCx/XDCx.go @@ -292,10 +292,10 @@ func (XDCx *XDCX) GetAveragePriceLastEpoch(chain consensus.ChainContext, statedb // return tokenQuantity (after convert from XDC to token), tokenPriceInXDC, error func (XDCx *XDCX) ConvertXDCToToken(chain consensus.ChainContext, statedb *state.StateDB, tradingStateDb *tradingstate.TradingStateDB, token common.Address, quantity *big.Int) (*big.Int, *big.Int, error) { - if token.String() == common.XDCNativeAddress { + if token == common.XDCNativeAddressBinary { return quantity, common.BasePrice, nil } - tokenPriceInXDC, err := XDCx.GetAveragePriceLastEpoch(chain, statedb, tradingStateDb, token, common.HexToAddress(common.XDCNativeAddress)) + tokenPriceInXDC, err := XDCx.GetAveragePriceLastEpoch(chain, statedb, tradingStateDb, token, common.XDCNativeAddressBinary) if err != nil || tokenPriceInXDC == nil || tokenPriceInXDC.Sign() <= 0 { return common.Big0, common.Big0, err } @@ -595,10 +595,11 @@ func (XDCx *XDCX) GetTriegc() *prque.Prque { func (XDCx *XDCX) GetTradingStateRoot(block *types.Block, author common.Address) (common.Hash, error) { for _, tx := range block.Transactions() { - from := *(tx.From()) - if tx.To() != nil && tx.To().Hex() == common.TradingStateAddr && from.String() == author.String() { - if len(tx.Data()) >= 32 { - return common.BytesToHash(tx.Data()[:32]), nil + to := tx.To() + if to != nil && *to == common.TradingStateAddrBinary && *tx.From() == author { + data := tx.Data() + if len(data) >= 32 { + return common.BytesToHash(data[:32]), nil } } } diff --git a/XDCx/order_processor.go b/XDCx/order_processor.go index ca9ca496e3..86d156fec8 100644 --- a/XDCx/order_processor.go +++ b/XDCx/order_processor.go @@ -236,11 +236,11 @@ func (XDCx *XDCX) processOrderList(coinbase common.Address, chain consensus.Chai maxTradedQuantity = tradingstate.CloneBigInt(amount) } var quotePrice *big.Int - if oldestOrder.QuoteToken.String() != common.XDCNativeAddress { - quotePrice = tradingStateDB.GetLastPrice(tradingstate.GetTradingOrderBookHash(oldestOrder.QuoteToken, common.HexToAddress(common.XDCNativeAddress))) + if oldestOrder.QuoteToken != common.XDCNativeAddressBinary { + quotePrice = tradingStateDB.GetLastPrice(tradingstate.GetTradingOrderBookHash(oldestOrder.QuoteToken, common.XDCNativeAddressBinary)) log.Debug("TryGet quotePrice QuoteToken/XDC", "quotePrice", quotePrice) if quotePrice == nil || quotePrice.Sign() == 0 { - inversePrice := tradingStateDB.GetLastPrice(tradingstate.GetTradingOrderBookHash(common.HexToAddress(common.XDCNativeAddress), oldestOrder.QuoteToken)) + inversePrice := tradingStateDB.GetLastPrice(tradingstate.GetTradingOrderBookHash(common.XDCNativeAddressBinary, oldestOrder.QuoteToken)) quoteTokenDecimal, err := XDCx.GetTokenDecimal(chain, statedb, oldestOrder.QuoteToken) if err != nil || quoteTokenDecimal.Sign() == 0 { return nil, nil, nil, fmt.Errorf("Fail to get tokenDecimal. Token: %v . Err: %v", oldestOrder.QuoteToken.String(), err) @@ -374,10 +374,10 @@ func (XDCx *XDCX) getTradeQuantity(quotePrice *big.Int, coinbase common.Address, if err != nil || quoteTokenDecimal.Sign() == 0 { return tradingstate.Zero, false, nil, fmt.Errorf("Fail to get tokenDecimal. Token: %v . Err: %v", makerOrder.QuoteToken.String(), err) } - if makerOrder.QuoteToken.String() == common.XDCNativeAddress { + if makerOrder.QuoteToken == common.XDCNativeAddressBinary { quotePrice = quoteTokenDecimal } - if takerOrder.ExchangeAddress.String() == makerOrder.ExchangeAddress.String() { + if takerOrder.ExchangeAddress == makerOrder.ExchangeAddress { if err := tradingstate.CheckRelayerFee(takerOrder.ExchangeAddress, new(big.Int).Mul(common.RelayerFee, big.NewInt(2)), statedb); err != nil { log.Debug("Reject order Taker Exchnage = Maker Exchange , relayer not enough fee ", "err", err) return tradingstate.Zero, false, nil, nil diff --git a/XDCx/order_processor_test.go b/XDCx/order_processor_test.go index 7ab396f83f..b3a651ae98 100644 --- a/XDCx/order_processor_test.go +++ b/XDCx/order_processor_test.go @@ -1,12 +1,13 @@ package XDCx import ( - "github.com/XinFinOrg/XDPoSChain/XDCx/tradingstate" - "github.com/XinFinOrg/XDPoSChain/common" - "github.com/XinFinOrg/XDPoSChain/core/rawdb" "math/big" "reflect" "testing" + + "github.com/XinFinOrg/XDPoSChain/XDCx/tradingstate" + "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/core/rawdb" ) func Test_getCancelFeeV1(t *testing.T) { @@ -103,9 +104,9 @@ func Test_getCancelFee(t *testing.T) { XDCx.SetTokenDecimal(testTokenB, tokenBDecimal) // set tokenAPrice = 1 XDC - tradingStateDb.SetMediumPriceBeforeEpoch(tradingstate.GetTradingOrderBookHash(testTokenA, common.HexToAddress(common.XDCNativeAddress)), common.BasePrice) + tradingStateDb.SetMediumPriceBeforeEpoch(tradingstate.GetTradingOrderBookHash(testTokenA, common.XDCNativeAddressBinary), common.BasePrice) // set tokenBPrice = 1 XDC - tradingStateDb.SetMediumPriceBeforeEpoch(tradingstate.GetTradingOrderBookHash(common.HexToAddress(common.XDCNativeAddress), testTokenB), tokenBDecimal) + tradingStateDb.SetMediumPriceBeforeEpoch(tradingstate.GetTradingOrderBookHash(common.XDCNativeAddressBinary, testTokenB), tokenBDecimal) type CancelFeeArg struct { feeRate *big.Int @@ -127,7 +128,7 @@ func Test_getCancelFee(t *testing.T) { feeRate: common.Big0, order: &tradingstate.OrderItem{ BaseToken: testTokenA, - QuoteToken: common.HexToAddress(common.XDCNativeAddress), + QuoteToken: common.XDCNativeAddressBinary, Quantity: new(big.Int).SetUint64(10000), Side: tradingstate.Ask, }, @@ -142,7 +143,7 @@ func Test_getCancelFee(t *testing.T) { feeRate: common.Big0, order: &tradingstate.OrderItem{ BaseToken: testTokenA, - QuoteToken: common.HexToAddress(common.XDCNativeAddress), + QuoteToken: common.XDCNativeAddressBinary, Quantity: new(big.Int).SetUint64(10000), Side: tradingstate.Bid, }, @@ -156,7 +157,7 @@ func Test_getCancelFee(t *testing.T) { CancelFeeArg{ feeRate: new(big.Int).SetUint64(10), // 10/10000= 0.1% order: &tradingstate.OrderItem{ - BaseToken: common.HexToAddress(common.XDCNativeAddress), + BaseToken: common.XDCNativeAddressBinary, QuoteToken: testTokenA, Quantity: new(big.Int).SetUint64(10000), Side: tradingstate.Ask, @@ -172,7 +173,7 @@ func Test_getCancelFee(t *testing.T) { feeRate: new(big.Int).SetUint64(10), // 10/10000= 0.1% order: &tradingstate.OrderItem{ Quantity: new(big.Int).SetUint64(10000), - BaseToken: common.HexToAddress(common.XDCNativeAddress), + BaseToken: common.XDCNativeAddressBinary, QuoteToken: testTokenA, Side: tradingstate.Bid, }, @@ -188,7 +189,7 @@ func Test_getCancelFee(t *testing.T) { CancelFeeArg{ feeRate: common.Big0, order: &tradingstate.OrderItem{ - BaseToken: common.HexToAddress(common.XDCNativeAddress), + BaseToken: common.XDCNativeAddressBinary, QuoteToken: testTokenA, Quantity: new(big.Int).SetUint64(10000), Side: tradingstate.Ask, @@ -203,7 +204,7 @@ func Test_getCancelFee(t *testing.T) { CancelFeeArg{ feeRate: common.Big0, order: &tradingstate.OrderItem{ - BaseToken: common.HexToAddress(common.XDCNativeAddress), + BaseToken: common.XDCNativeAddressBinary, QuoteToken: testTokenA, Quantity: new(big.Int).SetUint64(10000), Side: tradingstate.Bid, @@ -218,7 +219,7 @@ func Test_getCancelFee(t *testing.T) { CancelFeeArg{ feeRate: new(big.Int).SetUint64(10), // 10/10000= 0.1% order: &tradingstate.OrderItem{ - BaseToken: common.HexToAddress(common.XDCNativeAddress), + BaseToken: common.XDCNativeAddressBinary, QuoteToken: testTokenA, Quantity: new(big.Int).SetUint64(10000), Side: tradingstate.Ask, @@ -234,7 +235,7 @@ func Test_getCancelFee(t *testing.T) { feeRate: new(big.Int).SetUint64(10), // 10/10000= 0.1% order: &tradingstate.OrderItem{ Quantity: new(big.Int).SetUint64(10000), - BaseToken: common.HexToAddress(common.XDCNativeAddress), + BaseToken: common.XDCNativeAddressBinary, QuoteToken: testTokenA, Side: tradingstate.Bid, }, diff --git a/XDCx/token.go b/XDCx/token.go index 2bf37e3f2f..5aa8554cf9 100644 --- a/XDCx/token.go +++ b/XDCx/token.go @@ -48,7 +48,7 @@ func (XDCx *XDCX) GetTokenDecimal(chain consensus.ChainContext, statedb *state.S if tokenDecimal, ok := XDCx.tokenDecimalCache.Get(tokenAddr); ok { return tokenDecimal.(*big.Int), nil } - if tokenAddr.String() == common.XDCNativeAddress { + if tokenAddr == common.XDCNativeAddressBinary { XDCx.tokenDecimalCache.Add(tokenAddr, common.BasePrice) return common.BasePrice, nil } diff --git a/XDCx/tradingstate/orderitem.go b/XDCx/tradingstate/orderitem.go index f28e978435..8c39375b5f 100644 --- a/XDCx/tradingstate/orderitem.go +++ b/XDCx/tradingstate/orderitem.go @@ -239,7 +239,7 @@ func (o *OrderItem) verifyRelayer(state *state.StateDB) error { return nil } -//verify signatures +// verify signatures func (o *OrderItem) verifySignature() error { bigstr := o.Nonce.String() n, err := strconv.ParseInt(bigstr, 10, 64) @@ -269,7 +269,7 @@ func (o *OrderItem) verifyOrderType() error { return nil } -//verify order side +// verify order side func (o *OrderItem) verifyOrderSide() error { if o.Side != Bid && o.Side != Ask { @@ -356,11 +356,11 @@ func VerifyPair(statedb *state.StateDB, exchangeAddress, baseToken, quoteToken c func VerifyBalance(statedb *state.StateDB, XDCxStateDb *TradingStateDB, order *types.OrderTransaction, baseDecimal, quoteDecimal *big.Int) error { var quotePrice *big.Int - if order.QuoteToken().String() != common.XDCNativeAddress { - quotePrice = XDCxStateDb.GetLastPrice(GetTradingOrderBookHash(order.QuoteToken(), common.HexToAddress(common.XDCNativeAddress))) + if order.QuoteToken() != common.XDCNativeAddressBinary { + quotePrice = XDCxStateDb.GetLastPrice(GetTradingOrderBookHash(order.QuoteToken(), common.XDCNativeAddressBinary)) log.Debug("TryGet quotePrice QuoteToken/XDC", "quotePrice", quotePrice) if quotePrice == nil || quotePrice.Sign() == 0 { - inversePrice := XDCxStateDb.GetLastPrice(GetTradingOrderBookHash(common.HexToAddress(common.XDCNativeAddress), order.QuoteToken())) + inversePrice := XDCxStateDb.GetLastPrice(GetTradingOrderBookHash(common.XDCNativeAddressBinary, order.QuoteToken())) log.Debug("TryGet inversePrice XDC/QuoteToken", "inversePrice", inversePrice) if inversePrice != nil && inversePrice.Sign() > 0 { quotePrice = new(big.Int).Mul(common.BasePrice, quoteDecimal) diff --git a/XDCx/tradingstate/relayer_state.go b/XDCx/tradingstate/relayer_state.go index 2a1892492d..348f213041 100644 --- a/XDCx/tradingstate/relayer_state.go +++ b/XDCx/tradingstate/relayer_state.go @@ -159,9 +159,9 @@ func CheckRelayerFee(relayer common.Address, fee *big.Int, statedb *state.StateD } func AddTokenBalance(addr common.Address, value *big.Int, token common.Address, statedb *state.StateDB) error { // XDC native - if token.String() == common.XDCNativeAddress { + if token == common.XDCNativeAddressBinary { balance := statedb.GetBalance(addr) - log.Debug("ApplyXDCXMatchedTransaction settle balance: ADD TOKEN XDC NATIVE BEFORE", "token", token.String(), "address", addr.String(), "balance", balance, "orderValue", value) + log.Debug("ApplyXDCXMatchedTransaction settle balance: ADD TOKEN XDC NATIVE BEFORE", "token", common.XDCNativeAddress, "address", addr.String(), "balance", balance, "orderValue", value) statedb.AddBalance(addr, value) balance = statedb.GetBalance(addr) log.Debug("ApplyXDCXMatchedTransaction settle balance: ADD XDC NATIVE BALANCE AFTER", "token", token.String(), "address", addr.String(), "balance", balance, "orderValue", value) @@ -186,10 +186,9 @@ func AddTokenBalance(addr common.Address, value *big.Int, token common.Address, func SubTokenBalance(addr common.Address, value *big.Int, token common.Address, statedb *state.StateDB) error { // XDC native - if token.String() == common.XDCNativeAddress { - + if token == common.XDCNativeAddressBinary { balance := statedb.GetBalance(addr) - log.Debug("ApplyXDCXMatchedTransaction settle balance: SUB XDC NATIVE BALANCE BEFORE", "token", token.String(), "address", addr.String(), "balance", balance, "orderValue", value) + log.Debug("ApplyXDCXMatchedTransaction settle balance: SUB XDC NATIVE BALANCE BEFORE", "token", common.XDCNativeAddress, "address", addr.String(), "balance", balance, "orderValue", value) if balance.Cmp(value) < 0 { return errors.Errorf("value %s in token %s not enough , have : %s , want : %s ", addr.String(), token.String(), balance, value) } @@ -219,7 +218,7 @@ func SubTokenBalance(addr common.Address, value *big.Int, token common.Address, func CheckSubTokenBalance(addr common.Address, value *big.Int, token common.Address, statedb *state.StateDB, mapBalances map[common.Address]map[common.Address]*big.Int) (*big.Int, error) { // XDC native - if token.String() == common.XDCNativeAddress { + if token == common.XDCNativeAddressBinary { var balance *big.Int if value := mapBalances[token][addr]; value != nil { balance = value @@ -256,7 +255,7 @@ func CheckSubTokenBalance(addr common.Address, value *big.Int, token common.Addr func CheckAddTokenBalance(addr common.Address, value *big.Int, token common.Address, statedb *state.StateDB, mapBalances map[common.Address]map[common.Address]*big.Int) (*big.Int, error) { // XDC native - if token.String() == common.XDCNativeAddress { + if token == common.XDCNativeAddressBinary { var balance *big.Int if value := mapBalances[token][addr]; value != nil { balance = value @@ -308,7 +307,7 @@ func CheckSubRelayerFee(relayer common.Address, fee *big.Int, statedb *state.Sta func GetTokenBalance(addr common.Address, token common.Address, statedb *state.StateDB) *big.Int { // XDC native - if token.String() == common.XDCNativeAddress { + if token == common.XDCNativeAddressBinary { return statedb.GetBalance(addr) } // TRC tokens @@ -323,7 +322,7 @@ func GetTokenBalance(addr common.Address, token common.Address, statedb *state.S func SetTokenBalance(addr common.Address, balance *big.Int, token common.Address, statedb *state.StateDB) error { // XDC native - if token.String() == common.XDCNativeAddress { + if token == common.XDCNativeAddressBinary { statedb.SetBalance(addr, balance) return nil } diff --git a/XDCx/tradingstate/settle_balance.go b/XDCx/tradingstate/settle_balance.go index 66996ae5b2..771da0c29b 100644 --- a/XDCx/tradingstate/settle_balance.go +++ b/XDCx/tradingstate/settle_balance.go @@ -52,7 +52,7 @@ func GetSettleBalance(quotePrice *big.Int, takerSide string, takerFeeRate *big.I log.Debug("quantity trade too small", "quoteTokenQuantity", quoteTokenQuantity, "makerFee", makerFee, "defaultFee", defaultFee) return result, ErrQuantityTradeTooSmall } - if quoteToken.String() != common.XDCNativeAddress && quotePrice != nil && quotePrice.Cmp(common.Big0) > 0 { + if quoteToken != common.XDCNativeAddressBinary && quotePrice != nil && quotePrice.Cmp(common.Big0) > 0 { // defaultFeeInXDC defaultFeeInXDC := new(big.Int).Mul(defaultFee, quotePrice) defaultFeeInXDC = new(big.Int).Div(defaultFeeInXDC, quoteTokenDecimal) @@ -69,7 +69,7 @@ func GetSettleBalance(quotePrice *big.Int, takerSide string, takerFeeRate *big.I log.Debug("takerFee too small", "quoteTokenQuantity", quoteTokenQuantity, "takerFee", takerFee, "exTakerReceivedFee", exTakerReceivedFee, "quotePrice", quotePrice, "defaultFeeInXDC", defaultFeeInXDC) return result, ErrQuantityTradeTooSmall } - } else if quoteToken.String() == common.XDCNativeAddress { + } else if quoteToken == common.XDCNativeAddressBinary { exMakerReceivedFee := makerFee if (exMakerReceivedFee.Cmp(common.RelayerFee) <= 0 && exMakerReceivedFee.Sign() > 0) || defaultFee.Cmp(common.RelayerFee) <= 0 { log.Debug("makerFee too small", "quantityToTrade", quantityToTrade, "makerFee", makerFee, "exMakerReceivedFee", exMakerReceivedFee, "makerFeeRate", makerFeeRate, "defaultFee", defaultFee) @@ -108,7 +108,7 @@ func GetSettleBalance(quotePrice *big.Int, takerSide string, takerFeeRate *big.I log.Debug("quantity trade too small", "quoteTokenQuantity", quoteTokenQuantity, "takerFee", takerFee) return result, ErrQuantityTradeTooSmall } - if quoteToken.String() != common.XDCNativeAddress && quotePrice != nil && quotePrice.Cmp(common.Big0) > 0 { + if quoteToken != common.XDCNativeAddressBinary && quotePrice != nil && quotePrice.Cmp(common.Big0) > 0 { // defaultFeeInXDC defaultFeeInXDC := new(big.Int).Mul(defaultFee, quotePrice) defaultFeeInXDC = new(big.Int).Div(defaultFeeInXDC, quoteTokenDecimal) @@ -126,7 +126,7 @@ func GetSettleBalance(quotePrice *big.Int, takerSide string, takerFeeRate *big.I log.Debug("takerFee too small", "quoteTokenQuantity", quoteTokenQuantity, "takerFee", takerFee, "exTakerReceivedFee", exTakerReceivedFee, "quotePrice", quotePrice, "defaultFeeInXDC", defaultFeeInXDC) return result, ErrQuantityTradeTooSmall } - } else if quoteToken.String() == common.XDCNativeAddress { + } else if quoteToken == common.XDCNativeAddressBinary { exMakerReceivedFee := makerFee if (exMakerReceivedFee.Cmp(common.RelayerFee) <= 0 && exMakerReceivedFee.Sign() > 0) || defaultFee.Cmp(common.RelayerFee) <= 0 { log.Debug("makerFee too small", "quantityToTrade", quantityToTrade, "makerFee", makerFee, "exMakerReceivedFee", exMakerReceivedFee, "makerFeeRate", makerFeeRate, "defaultFee", defaultFee) diff --git a/XDCx/tradingstate/settle_balance_test.go b/XDCx/tradingstate/settle_balance_test.go index 674e381fd8..d340226fc7 100644 --- a/XDCx/tradingstate/settle_balance_test.go +++ b/XDCx/tradingstate/settle_balance_test.go @@ -1,10 +1,11 @@ package tradingstate import ( - "github.com/XinFinOrg/XDPoSChain/common" "math/big" "reflect" "testing" + + "github.com/XinFinOrg/XDPoSChain/common" ) func TestGetSettleBalance(t *testing.T) { @@ -89,7 +90,7 @@ func TestGetSettleBalance(t *testing.T) { takerSide: Bid, takerFeeRate: big.NewInt(10), // feeRate 0.1% baseToken: testToken, - quoteToken: common.HexToAddress(common.XDCNativeAddress), + quoteToken: common.XDCNativeAddressBinary, makerPrice: common.BasePrice, makerFeeRate: big.NewInt(10), // feeRate 0.1% baseTokenDecimal: common.BasePrice, @@ -106,7 +107,7 @@ func TestGetSettleBalance(t *testing.T) { takerSide: Bid, takerFeeRate: big.NewInt(5), // feeRate 0.05% baseToken: testToken, - quoteToken: common.HexToAddress(common.XDCNativeAddress), + quoteToken: common.XDCNativeAddressBinary, makerPrice: common.BasePrice, makerFeeRate: big.NewInt(10), // feeRate 0.1% baseTokenDecimal: common.BasePrice, @@ -124,7 +125,7 @@ func TestGetSettleBalance(t *testing.T) { takerSide: Bid, takerFeeRate: big.NewInt(10), // feeRate 0.1% baseToken: testToken, - quoteToken: common.HexToAddress(common.XDCNativeAddress), + quoteToken: common.XDCNativeAddressBinary, makerPrice: common.BasePrice, makerFeeRate: big.NewInt(10), // feeRate 0.1% baseTokenDecimal: common.BasePrice, @@ -132,8 +133,8 @@ func TestGetSettleBalance(t *testing.T) { quantityToTrade: new(big.Int).Mul(big.NewInt(1000), common.BasePrice), }, &SettleBalance{ - Taker: TradeResult{Fee: testFee, InToken: testToken, InTotal: tradeQuantity, OutToken: common.HexToAddress(common.XDCNativeAddress), OutTotal: tradeQuantityIncludedFee}, - Maker: TradeResult{Fee: testFee, InToken: common.HexToAddress(common.XDCNativeAddress), InTotal: tradeQuantityExcludedFee, OutToken: testToken, OutTotal: tradeQuantity}, + Taker: TradeResult{Fee: testFee, InToken: testToken, InTotal: tradeQuantity, OutToken: common.XDCNativeAddressBinary, OutTotal: tradeQuantityIncludedFee}, + Maker: TradeResult{Fee: testFee, InToken: common.XDCNativeAddressBinary, InTotal: tradeQuantityExcludedFee, OutToken: testToken, OutTotal: tradeQuantity}, }, false, }, @@ -196,7 +197,7 @@ func TestGetSettleBalance(t *testing.T) { takerSide: Ask, takerFeeRate: big.NewInt(10), // feeRate 0.1% baseToken: testToken, - quoteToken: common.HexToAddress(common.XDCNativeAddress), + quoteToken: common.XDCNativeAddressBinary, makerPrice: common.BasePrice, makerFeeRate: big.NewInt(10), // feeRate 0.1% baseTokenDecimal: common.BasePrice, @@ -213,7 +214,7 @@ func TestGetSettleBalance(t *testing.T) { takerSide: Ask, takerFeeRate: big.NewInt(5), // feeRate 0.05% baseToken: testToken, - quoteToken: common.HexToAddress(common.XDCNativeAddress), + quoteToken: common.XDCNativeAddressBinary, makerPrice: common.BasePrice, makerFeeRate: big.NewInt(10), // feeRate 0.1% baseTokenDecimal: common.BasePrice, @@ -231,7 +232,7 @@ func TestGetSettleBalance(t *testing.T) { takerSide: Ask, takerFeeRate: big.NewInt(10), // feeRate 15% baseToken: testToken, - quoteToken: common.HexToAddress(common.XDCNativeAddress), + quoteToken: common.XDCNativeAddressBinary, makerPrice: common.BasePrice, makerFeeRate: big.NewInt(10), // feeRate 0.1% baseTokenDecimal: common.BasePrice, @@ -239,8 +240,8 @@ func TestGetSettleBalance(t *testing.T) { quantityToTrade: new(big.Int).Mul(big.NewInt(1000), common.BasePrice), }, &SettleBalance{ - Maker: TradeResult{Fee: testFee, InToken: testToken, InTotal: tradeQuantity, OutToken: common.HexToAddress(common.XDCNativeAddress), OutTotal: tradeQuantityIncludedFee}, - Taker: TradeResult{Fee: testFee, InToken: common.HexToAddress(common.XDCNativeAddress), InTotal: tradeQuantityExcludedFee, OutToken: testToken, OutTotal: tradeQuantity}, + Maker: TradeResult{Fee: testFee, InToken: testToken, InTotal: tradeQuantity, OutToken: common.XDCNativeAddressBinary, OutTotal: tradeQuantityIncludedFee}, + Taker: TradeResult{Fee: testFee, InToken: common.XDCNativeAddressBinary, InTotal: tradeQuantityExcludedFee, OutToken: testToken, OutTotal: tradeQuantity}, }, false, }, diff --git a/XDCxlending/XDCxlending.go b/XDCxlending/XDCxlending.go index 352c224d0a..48ff540778 100644 --- a/XDCxlending/XDCxlending.go +++ b/XDCxlending/XDCxlending.go @@ -696,10 +696,11 @@ func (l *Lending) GetTriegc() *prque.Prque { func (l *Lending) GetLendingStateRoot(block *types.Block, author common.Address) (common.Hash, error) { for _, tx := range block.Transactions() { - from := *(tx.From()) - if tx.To() != nil && tx.To().Hex() == common.TradingStateAddr && from.String() == author.String() { - if len(tx.Data()) >= 64 { - return common.BytesToHash(tx.Data()[32:]), nil + to := tx.To() + if to != nil && *to == common.TradingStateAddrBinary && *tx.From() == author { + data := tx.Data() + if len(data) >= 64 { + return common.BytesToHash(data[32:]), nil } } } diff --git a/XDCxlending/lendingstate/lendingcontract.go b/XDCxlending/lendingstate/lendingcontract.go index fa2df7dc84..eba5b1a065 100644 --- a/XDCxlending/lendingstate/lendingcontract.go +++ b/XDCxlending/lendingstate/lendingcontract.go @@ -1,7 +1,7 @@ package lendingstate import ( - "fmt" + "errors" "math/big" "github.com/XinFinOrg/XDPoSChain/XDCx/tradingstate" @@ -273,10 +273,10 @@ func GetAllLendingBooks(statedb *state.StateDB) (mapLendingBook map[common.Hash] baseTokens := GetSupportedBaseToken(statedb) terms := GetSupportedTerms(statedb) if len(baseTokens) == 0 { - return nil, fmt.Errorf("GetAllLendingBooks: empty baseToken list") + return nil, errors.New("GetAllLendingBooks: empty baseToken list") } if len(terms) == 0 { - return nil, fmt.Errorf("GetAllLendingPairs: empty term list") + return nil, errors.New("GetAllLendingPairs: empty term list") } for _, baseToken := range baseTokens { for _, term := range terms { @@ -295,10 +295,10 @@ func GetAllLendingPairs(statedb *state.StateDB) (allPairs []LendingPair, err err baseTokens := GetSupportedBaseToken(statedb) collaterals := GetAllCollateral(statedb) if len(baseTokens) == 0 { - return allPairs, fmt.Errorf("GetAllLendingPairs: empty baseToken list") + return allPairs, errors.New("GetAllLendingPairs: empty baseToken list") } if len(collaterals) == 0 { - return allPairs, fmt.Errorf("GetAllLendingPairs: empty collateral list") + return allPairs, errors.New("GetAllLendingPairs: empty collateral list") } for _, baseToken := range baseTokens { for _, collateral := range collaterals { diff --git a/XDCxlending/lendingstate/lendingitem.go b/XDCxlending/lendingstate/lendingitem.go index 235a77fa60..dd4553ab87 100644 --- a/XDCxlending/lendingstate/lendingitem.go +++ b/XDCxlending/lendingstate/lendingitem.go @@ -1,6 +1,7 @@ package lendingstate import ( + "errors" "fmt" "math/big" "strconv" @@ -260,7 +261,7 @@ func (l *LendingItem) VerifyCollateral(state *state.StateDB) error { validCollateral := false collateralList := GetCollaterals(state, l.Relayer, l.LendingToken, l.Term) for _, collateral := range collateralList { - if l.CollateralToken.String() == collateral.String() { + if l.CollateralToken == collateral { validCollateral = true break } @@ -359,7 +360,7 @@ func (l *LendingItem) VerifyLendingSignature() error { tx.ImportSignature(V, R, S) from, _ := types.LendingSender(types.LendingTxSigner{}, tx) if from != tx.UserAddress() { - return fmt.Errorf("verify lending item: invalid signature") + return errors.New("verify lending item: invalid signature") } return nil } @@ -411,7 +412,7 @@ func VerifyBalance(isXDCXLendingFork bool, statedb *state.StateDB, lendingStateD defaultFee := new(big.Int).Mul(quantity, new(big.Int).SetUint64(DefaultFeeRate)) defaultFee = new(big.Int).Div(defaultFee, common.XDCXBaseFee) defaultFeeInXDC := common.Big0 - if lendingToken.String() != common.XDCNativeAddress { + if lendingToken != common.XDCNativeAddressBinary { defaultFeeInXDC = new(big.Int).Mul(defaultFee, lendTokenXDCPrice) defaultFeeInXDC = new(big.Int).Div(defaultFeeInXDC, lendingTokenDecimal) } else { @@ -473,10 +474,10 @@ func VerifyBalance(isXDCXLendingFork bool, statedb *state.StateDB, lendingStateD } return nil default: - return fmt.Errorf("VerifyBalance: unknown lending side") + return errors.New("VerifyBalance: unknown lending side") } default: - return fmt.Errorf("VerifyBalance: unknown lending type") + return errors.New("VerifyBalance: unknown lending type") } return nil } diff --git a/XDCxlending/lendingstate/relayer.go b/XDCxlending/lendingstate/relayer.go index 571ecdcd15..536784741b 100644 --- a/XDCxlending/lendingstate/relayer.go +++ b/XDCxlending/lendingstate/relayer.go @@ -114,12 +114,12 @@ func CheckRelayerFee(relayer common.Address, fee *big.Int, statedb *state.StateD } func AddTokenBalance(addr common.Address, value *big.Int, token common.Address, statedb *state.StateDB) error { // XDC native - if token.String() == common.XDCNativeAddress { + if token == common.XDCNativeAddressBinary { balance := statedb.GetBalance(addr) - log.Debug("ApplyXDCXMatchedTransaction settle balance: ADD TOKEN XDC NATIVE BEFORE", "token", token.String(), "address", addr.String(), "balance", balance, "orderValue", value) + log.Debug("ApplyXDCXMatchedTransaction settle balance: ADD TOKEN XDC NATIVE BEFORE", "token", common.XDCNativeAddress, "address", addr.String(), "balance", balance, "orderValue", value) statedb.AddBalance(addr, value) balance = statedb.GetBalance(addr) - log.Debug("ApplyXDCXMatchedTransaction settle balance: ADD XDC NATIVE BALANCE AFTER", "token", token.String(), "address", addr.String(), "balance", balance, "orderValue", value) + log.Debug("ApplyXDCXMatchedTransaction settle balance: ADD XDC NATIVE BALANCE AFTER", "token", common.XDCNativeAddress, "address", addr.String(), "balance", balance, "orderValue", value) return nil } @@ -141,15 +141,15 @@ func AddTokenBalance(addr common.Address, value *big.Int, token common.Address, func SubTokenBalance(addr common.Address, value *big.Int, token common.Address, statedb *state.StateDB) error { // XDC native - if token.String() == common.XDCNativeAddress { + if token == common.XDCNativeAddressBinary { balance := statedb.GetBalance(addr) - log.Debug("ApplyXDCXMatchedTransaction settle balance: SUB XDC NATIVE BALANCE BEFORE", "token", token.String(), "address", addr.String(), "balance", balance, "orderValue", value) + log.Debug("ApplyXDCXMatchedTransaction settle balance: SUB XDC NATIVE BALANCE BEFORE", "token", common.XDCNativeAddress, "address", addr.String(), "balance", balance, "orderValue", value) if balance.Cmp(value) < 0 { - return errors.Errorf("value %s in token %s not enough , have : %s , want : %s ", addr.String(), token.String(), balance, value) + return errors.Errorf("value %s in token %s not enough , have : %s , want : %s ", addr.String(), common.XDCNativeAddress, balance, value) } statedb.SubBalance(addr, value) balance = statedb.GetBalance(addr) - log.Debug("ApplyXDCXMatchedTransaction settle balance: SUB XDC NATIVE BALANCE AFTER", "token", token.String(), "address", addr.String(), "balance", balance, "orderValue", value) + log.Debug("ApplyXDCXMatchedTransaction settle balance: SUB XDC NATIVE BALANCE AFTER", "token", common.XDCNativeAddress, "address", addr.String(), "balance", balance, "orderValue", value) return nil } @@ -174,7 +174,7 @@ func SubTokenBalance(addr common.Address, value *big.Int, token common.Address, func CheckSubTokenBalance(addr common.Address, value *big.Int, token common.Address, statedb *state.StateDB, mapBalances map[common.Address]map[common.Address]*big.Int) (*big.Int, error) { // XDC native - if token.String() == common.XDCNativeAddress { + if token == common.XDCNativeAddressBinary { var balance *big.Int if value := mapBalances[token][addr]; value != nil { balance = value @@ -182,10 +182,10 @@ func CheckSubTokenBalance(addr common.Address, value *big.Int, token common.Addr balance = statedb.GetBalance(addr) } if balance.Cmp(value) < 0 { - return nil, errors.Errorf("value %s in token %s not enough , have : %s , want : %s ", addr.String(), token.String(), balance, value) + return nil, errors.Errorf("value %s in token %s not enough , have : %s , want : %s ", addr.String(), common.XDCNativeAddress, balance, value) } newBalance := new(big.Int).Sub(balance, value) - log.Debug("CheckSubTokenBalance settle balance: SUB XDC NATIVE BALANCE ", "token", token.String(), "address", addr.String(), "balance", balance, "value", value, "newBalance", newBalance) + log.Debug("CheckSubTokenBalance settle balance: SUB XDC NATIVE BALANCE ", "token", common.XDCNativeAddress, "address", addr.String(), "balance", balance, "value", value, "newBalance", newBalance) return newBalance, nil } // TRC tokens @@ -211,7 +211,7 @@ func CheckSubTokenBalance(addr common.Address, value *big.Int, token common.Addr func CheckAddTokenBalance(addr common.Address, value *big.Int, token common.Address, statedb *state.StateDB, mapBalances map[common.Address]map[common.Address]*big.Int) (*big.Int, error) { // XDC native - if token.String() == common.XDCNativeAddress { + if token == common.XDCNativeAddressBinary { var balance *big.Int if value := mapBalances[token][addr]; value != nil { balance = value @@ -219,7 +219,7 @@ func CheckAddTokenBalance(addr common.Address, value *big.Int, token common.Addr balance = statedb.GetBalance(addr) } newBalance := new(big.Int).Add(balance, value) - log.Debug("CheckAddTokenBalance settle balance: ADD XDC NATIVE BALANCE ", "token", token.String(), "address", addr.String(), "balance", balance, "value", value, "newBalance", newBalance) + log.Debug("CheckAddTokenBalance settle balance: ADD XDC NATIVE BALANCE ", "token", common.XDCNativeAddress, "address", addr.String(), "balance", balance, "value", value, "newBalance", newBalance) return newBalance, nil } // TRC tokens @@ -263,7 +263,7 @@ func CheckSubRelayerFee(relayer common.Address, fee *big.Int, statedb *state.Sta func GetTokenBalance(addr common.Address, token common.Address, statedb *state.StateDB) *big.Int { // XDC native - if token.String() == common.XDCNativeAddress { + if token == common.XDCNativeAddressBinary { return statedb.GetBalance(addr) } // TRC tokens @@ -278,7 +278,7 @@ func GetTokenBalance(addr common.Address, token common.Address, statedb *state.S func SetTokenBalance(addr common.Address, balance *big.Int, token common.Address, statedb *state.StateDB) error { // XDC native - if token.String() == common.XDCNativeAddress { + if token == common.XDCNativeAddressBinary { statedb.SetBalance(addr, balance) return nil } diff --git a/XDCxlending/lendingstate/settle_balance.go b/XDCxlending/lendingstate/settle_balance.go index 108a459afe..93ee96d936 100644 --- a/XDCxlending/lendingstate/settle_balance.go +++ b/XDCxlending/lendingstate/settle_balance.go @@ -3,9 +3,10 @@ package lendingstate import ( "encoding/json" "errors" + "math/big" + "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/log" - "math/big" ) const DefaultFeeRate = 100 // 100 / XDCXBaseFee = 100 / 10000 = 1% @@ -71,7 +72,7 @@ func GetSettleBalance(isXDCXLendingFork bool, log.Debug("quantity lending too small", "quantityToLend", quantityToLend, "takerFee", takerFee) return result, ErrQuantityTradeTooSmall } - if lendingToken.String() != common.XDCNativeAddress && lendTokenXDCPrice != nil && lendTokenXDCPrice.Cmp(common.Big0) > 0 { + if lendingToken != common.XDCNativeAddressBinary && lendTokenXDCPrice != nil && lendTokenXDCPrice.Cmp(common.Big0) > 0 { exTakerReceivedFee := new(big.Int).Mul(takerFee, lendTokenXDCPrice) exTakerReceivedFee = new(big.Int).Div(exTakerReceivedFee, lendTokenDecimal) @@ -82,7 +83,7 @@ func GetSettleBalance(isXDCXLendingFork bool, log.Debug("takerFee too small", "quantityToLend", quantityToLend, "takerFee", takerFee, "exTakerReceivedFee", exTakerReceivedFee, "borrowFeeRate", borrowFeeRate, "defaultFeeInXDC", defaultFeeInXDC) return result, ErrQuantityTradeTooSmall } - } else if lendingToken.String() == common.XDCNativeAddress { + } else if lendingToken == common.XDCNativeAddressBinary { exTakerReceivedFee := takerFee if (exTakerReceivedFee.Cmp(common.RelayerLendingFee) <= 0 && exTakerReceivedFee.Sign() > 0) || defaultFee.Cmp(common.RelayerLendingFee) <= 0 { log.Debug("takerFee too small", "quantityToLend", quantityToLend, "takerFee", takerFee, "exTakerReceivedFee", exTakerReceivedFee, "borrowFeeRate", borrowFeeRate, "defaultFee", defaultFee) @@ -121,7 +122,7 @@ func GetSettleBalance(isXDCXLendingFork bool, log.Debug("quantity lending too small", "quantityToLend", quantityToLend, "makerFee", makerFee) return result, ErrQuantityTradeTooSmall } - if lendingToken.String() != common.XDCNativeAddress && lendTokenXDCPrice != nil && lendTokenXDCPrice.Cmp(common.Big0) > 0 { + if lendingToken != common.XDCNativeAddressBinary && lendTokenXDCPrice != nil && lendTokenXDCPrice.Cmp(common.Big0) > 0 { exMakerReceivedFee := new(big.Int).Mul(makerFee, lendTokenXDCPrice) exMakerReceivedFee = new(big.Int).Div(exMakerReceivedFee, lendTokenDecimal) @@ -132,7 +133,7 @@ func GetSettleBalance(isXDCXLendingFork bool, log.Debug("makerFee too small", "quantityToLend", quantityToLend, "makerFee", makerFee, "exMakerReceivedFee", exMakerReceivedFee, "borrowFeeRate", borrowFeeRate, "defaultFeeInXDC", defaultFeeInXDC) return result, ErrQuantityTradeTooSmall } - } else if lendingToken.String() == common.XDCNativeAddress { + } else if lendingToken == common.XDCNativeAddressBinary { exMakerReceivedFee := makerFee if (exMakerReceivedFee.Cmp(common.RelayerLendingFee) <= 0 && exMakerReceivedFee.Sign() > 0) || defaultFee.Cmp(common.RelayerLendingFee) <= 0 { log.Debug("makerFee too small", "quantityToLend", quantityToLend, "makerFee", makerFee, "exMakerReceivedFee", exMakerReceivedFee, "borrowFeeRate", borrowFeeRate, "defaultFee", defaultFee) @@ -171,7 +172,7 @@ func GetSettleBalance(isXDCXLendingFork bool, log.Debug("quantity lending too small", "quantityToLend", quantityToLend, "borrowFee", borrowFee) return result, ErrQuantityTradeTooSmall } - if lendingToken.String() != common.XDCNativeAddress && lendTokenXDCPrice != nil && lendTokenXDCPrice.Cmp(common.Big0) > 0 { + if lendingToken != common.XDCNativeAddressBinary && lendTokenXDCPrice != nil && lendTokenXDCPrice.Cmp(common.Big0) > 0 { // exReceivedFee: the fee amount which borrowingRelayer will receive exReceivedFee := new(big.Int).Mul(borrowFee, lendTokenXDCPrice) exReceivedFee = new(big.Int).Div(exReceivedFee, lendTokenDecimal) @@ -183,7 +184,7 @@ func GetSettleBalance(isXDCXLendingFork bool, log.Debug("takerFee too small", "quantityToLend", quantityToLend, "borrowFee", borrowFee, "exReceivedFee", exReceivedFee, "borrowFeeRate", borrowFeeRate, "defaultFeeInXDC", defaultFeeInXDC) return result, ErrQuantityTradeTooSmall } - } else if lendingToken.String() == common.XDCNativeAddress { + } else if lendingToken == common.XDCNativeAddressBinary { exReceivedFee := borrowFee if (exReceivedFee.Cmp(common.RelayerLendingFee) <= 0 && exReceivedFee.Sign() > 0) || defaultFee.Cmp(common.RelayerLendingFee) <= 0 { log.Debug("takerFee too small", "quantityToLend", quantityToLend, "borrowFee", borrowFee, "exReceivedFee", exReceivedFee, "borrowFeeRate", borrowFeeRate, "defaultFee", defaultFee) diff --git a/XDCxlending/lendingstate/settle_balance_test.go b/XDCxlending/lendingstate/settle_balance_test.go index e4b318b6cf..711777503c 100644 --- a/XDCxlending/lendingstate/settle_balance_test.go +++ b/XDCxlending/lendingstate/settle_balance_test.go @@ -1,10 +1,11 @@ package lendingstate import ( - "github.com/XinFinOrg/XDPoSChain/common" "math/big" "reflect" "testing" + + "github.com/XinFinOrg/XDPoSChain/common" ) func TestCalculateInterestRate(t *testing.T) { @@ -171,7 +172,7 @@ func TestGetSettleBalance(t *testing.T) { common.BasePrice, big.NewInt(150), big.NewInt(100), // 1% - common.HexToAddress(common.XDCNativeAddress), + common.XDCNativeAddressBinary, common.Address{}, common.BasePrice, common.BasePrice, @@ -277,7 +278,7 @@ func TestGetSettleBalance(t *testing.T) { common.BasePrice, big.NewInt(150), big.NewInt(100), // 1% - common.HexToAddress(common.XDCNativeAddress), + common.XDCNativeAddressBinary, collateral, common.BasePrice, common.BasePrice, @@ -288,12 +289,12 @@ func TestGetSettleBalance(t *testing.T) { Fee: common.Big0, InToken: common.Address{}, InTotal: common.Big0, - OutToken: common.HexToAddress(common.XDCNativeAddress), + OutToken: common.XDCNativeAddressBinary, OutTotal: lendQuantity, }, Maker: TradeResult{ Fee: fee, - InToken: common.HexToAddress(common.XDCNativeAddress), + InToken: common.XDCNativeAddressBinary, InTotal: lendQuantityExcluded, OutToken: collateral, OutTotal: collateralLocked, @@ -312,7 +313,7 @@ func TestGetSettleBalance(t *testing.T) { common.BasePrice, big.NewInt(150), big.NewInt(100), // 1% - common.HexToAddress(common.XDCNativeAddress), + common.XDCNativeAddressBinary, collateral, common.BasePrice, common.BasePrice, @@ -323,12 +324,12 @@ func TestGetSettleBalance(t *testing.T) { Fee: common.Big0, InToken: common.Address{}, InTotal: common.Big0, - OutToken: common.HexToAddress(common.XDCNativeAddress), + OutToken: common.XDCNativeAddressBinary, OutTotal: lendQuantity, }, Taker: TradeResult{ Fee: fee, - InToken: common.HexToAddress(common.XDCNativeAddress), + InToken: common.XDCNativeAddressBinary, InTotal: lendQuantityExcluded, OutToken: collateral, OutTotal: collateralLocked, diff --git a/XDCxlending/order_processor.go b/XDCxlending/order_processor.go index 5725c2003b..5af2d27713 100644 --- a/XDCxlending/order_processor.go +++ b/XDCxlending/order_processor.go @@ -2,6 +2,7 @@ package XDCxlending import ( "encoding/json" + "errors" "fmt" "math/big" @@ -264,7 +265,7 @@ func (l *Lending) processOrderList(header *types.Header, coinbase common.Address borrowFee = lendingstate.GetFee(statedb, oldestOrder.Relayer) } if collateralToken.IsZero() { - return nil, nil, nil, fmt.Errorf("empty collateral") + return nil, nil, nil, errors.New("empty collateral") } depositRate, liquidationRate, recallRate := lendingstate.GetCollateralDetail(statedb, collateralToken) if depositRate == nil || depositRate.Sign() <= 0 { @@ -282,10 +283,10 @@ func (l *Lending) processOrderList(header *types.Header, coinbase common.Address return nil, nil, nil, err } if lendTokenXDCPrice == nil || lendTokenXDCPrice.Sign() <= 0 { - return nil, nil, nil, fmt.Errorf("invalid lendToken price") + return nil, nil, nil, errors.New("invalid lendToken price") } if collateralPrice == nil || collateralPrice.Sign() <= 0 { - return nil, nil, nil, fmt.Errorf("invalid collateral price") + return nil, nil, nil, errors.New("invalid collateral price") } tradedQuantity, collateralLockedAmount, rejectMaker, settleBalanceResult, err := l.getLendQuantity(lendTokenXDCPrice, collateralPrice, depositRate, borrowFee, coinbase, chain, header, statedb, order, &oldestOrder, maxTradedQuantity) if err != nil && err == lendingstate.ErrQuantityTradeTooSmall && tradedQuantity != nil && tradedQuantity.Sign() >= 0 { @@ -447,7 +448,7 @@ func (l *Lending) getLendQuantity( if err != nil || collateralTokenDecimal.Sign() == 0 { return lendingstate.Zero, lendingstate.Zero, false, nil, fmt.Errorf("fail to get tokenDecimal. Token: %v . Err: %v", collateralToken.String(), err) } - if takerOrder.Relayer.String() == makerOrder.Relayer.String() { + if takerOrder.Relayer == makerOrder.Relayer { if err := lendingstate.CheckRelayerFee(takerOrder.Relayer, new(big.Int).Mul(common.RelayerLendingFee, big.NewInt(2)), statedb); err != nil { log.Debug("Reject order Taker Exchnage = Maker Exchange , relayer not enough fee ", "err", err) return lendingstate.Zero, lendingstate.Zero, false, nil, nil @@ -621,11 +622,11 @@ func DoSettleBalance(coinbase common.Address, takerOrder, makerOrder *lendingsta } mapBalances[settleBalance.Taker.InToken][takerExOwner] = newTakerFee - newCollateralTokenLock, err := lendingstate.CheckAddTokenBalance(common.HexToAddress(common.LendingLockAddress), settleBalance.Taker.OutTotal, settleBalance.Taker.OutToken, statedb, mapBalances) + newCollateralTokenLock, err := lendingstate.CheckAddTokenBalance(common.LendingLockAddressBinary, settleBalance.Taker.OutTotal, settleBalance.Taker.OutToken, statedb, mapBalances) if err != nil { return err } - mapBalances[settleBalance.Taker.OutToken][common.HexToAddress(common.LendingLockAddress)] = newCollateralTokenLock + mapBalances[settleBalance.Taker.OutToken][common.LendingLockAddressBinary] = newCollateralTokenLock } else { relayerFee, err := lendingstate.CheckSubRelayerFee(makerOrder.Relayer, common.RelayerLendingFee, statedb, map[common.Address]*big.Int{}) if err != nil { @@ -662,11 +663,11 @@ func DoSettleBalance(coinbase common.Address, takerOrder, makerOrder *lendingsta } mapBalances[settleBalance.Maker.InToken][makerExOwner] = newMakerFee - newCollateralTokenLock, err := lendingstate.CheckAddTokenBalance(common.HexToAddress(common.LendingLockAddress), settleBalance.Maker.OutTotal, settleBalance.Maker.OutToken, statedb, mapBalances) + newCollateralTokenLock, err := lendingstate.CheckAddTokenBalance(common.LendingLockAddressBinary, settleBalance.Maker.OutTotal, settleBalance.Maker.OutToken, statedb, mapBalances) if err != nil { return err } - mapBalances[settleBalance.Maker.OutToken][common.HexToAddress(common.LendingLockAddress)] = newCollateralTokenLock + mapBalances[settleBalance.Maker.OutToken][common.LendingLockAddressBinary] = newCollateralTokenLock } masternodeOwner := statedb.GetOwner(coinbase) statedb.AddBalance(masternodeOwner, matchingFee) @@ -789,10 +790,10 @@ func (l *Lending) ProcessTopUp(lendingStateDB *lendingstate.LendingStateDB, stat if lendingTrade == lendingstate.EmptyLendingTrade { return fmt.Errorf("process deposit for emptyLendingTrade is not allowed. lendingTradeId: %v", lendingTradeId.Hex()), true, nil } - if order.UserAddress.String() != lendingTrade.Borrower.String() { + if order.UserAddress != lendingTrade.Borrower { return fmt.Errorf("ProcessTopUp: invalid userAddress . UserAddress: %s . Borrower: %s", order.UserAddress.Hex(), lendingTrade.Borrower.Hex()), true, nil } - if order.Relayer.String() != lendingTrade.BorrowingRelayer.String() { + if order.Relayer != lendingTrade.BorrowingRelayer { return fmt.Errorf("ProcessTopUp: invalid relayerAddress . Got: %s . Expect: %s", order.Relayer.Hex(), lendingTrade.BorrowingRelayer.Hex()), true, nil } if order.Quantity.Sign() <= 0 || lendingTrade.TradeId != lendingTradeId.Big().Uint64() { @@ -810,10 +811,10 @@ func (l *Lending) ProcessRepay(header *types.Header, chain consensus.ChainContex if lendingTrade == lendingstate.EmptyLendingTrade || lendingTrade.TradeId != lendingTradeIdHash.Big().Uint64() { return nil, fmt.Errorf("ProcessRepay for emptyLendingTrade is not allowed. lendingTradeId: %v", lendingTradeId) } - if order.UserAddress.String() != lendingTrade.Borrower.String() { + if order.UserAddress != lendingTrade.Borrower { return nil, fmt.Errorf("ProcessRepay: invalid userAddress . UserAddress: %s . Borrower: %s", order.UserAddress.Hex(), lendingTrade.Borrower.Hex()) } - if order.Relayer.String() != lendingTrade.BorrowingRelayer.String() { + if order.Relayer != lendingTrade.BorrowingRelayer { return nil, fmt.Errorf("ProcessRepay: invalid relayerAddress . Got: %s . Expect: %s", order.Relayer.Hex(), lendingTrade.BorrowingRelayer.Hex()) } return l.ProcessRepayLendingTrade(header, chain, lendingStateDB, statedb, tradingstateDB, lendingBook, lendingTradeId) @@ -854,9 +855,9 @@ func (l *Lending) LiquidationExpiredTrade(header *types.Header, chain consensus. } else { repayAmount = lendingTrade.CollateralLockedAmount } - err = lendingstate.SubTokenBalance(common.HexToAddress(common.LendingLockAddress), lendingTrade.CollateralLockedAmount, lendingTrade.CollateralToken, statedb) + err = lendingstate.SubTokenBalance(common.LendingLockAddressBinary, lendingTrade.CollateralLockedAmount, lendingTrade.CollateralToken, statedb) if err != nil { - log.Warn("LiquidationExpiredTrade SubTokenBalance", "err", err, "LendingLockAddress", common.HexToAddress(common.LendingLockAddress), "lendingTrade.CollateralLockedAmount", *lendingTrade.CollateralLockedAmount, "lendingTrade.CollateralToken", lendingTrade.CollateralToken) + log.Warn("LiquidationExpiredTrade SubTokenBalance", "err", err, "LendingLockAddress", common.LendingLockAddress, "lendingTrade.CollateralLockedAmount", *lendingTrade.CollateralLockedAmount, "lendingTrade.CollateralToken", lendingTrade.CollateralToken) } err = lendingstate.AddTokenBalance(lendingTrade.Investor, repayAmount, lendingTrade.CollateralToken, statedb) if err != nil { @@ -897,9 +898,9 @@ func (l *Lending) LiquidationTrade(lendingStateDB *lendingstate.LendingStateDB, if lendingTrade.TradeId != lendingTradeId { return nil, fmt.Errorf("Lending Trade Id not found : %d ", lendingTradeId) } - err := lendingstate.SubTokenBalance(common.HexToAddress(common.LendingLockAddress), lendingTrade.CollateralLockedAmount, lendingTrade.CollateralToken, statedb) + err := lendingstate.SubTokenBalance(common.LendingLockAddressBinary, lendingTrade.CollateralLockedAmount, lendingTrade.CollateralToken, statedb) if err != nil { - log.Warn("LiquidationTrade SubTokenBalance", "err", err, "LendingLockAddress", common.HexToAddress(common.LendingLockAddress), "lendingTrade.CollateralLockedAmount", *lendingTrade.CollateralLockedAmount, "lendingTrade.CollateralToken", lendingTrade.CollateralToken) + log.Warn("LiquidationTrade SubTokenBalance", "err", err, "LendingLockAddress", common.LendingLockAddress, "lendingTrade.CollateralLockedAmount", *lendingTrade.CollateralLockedAmount, "lendingTrade.CollateralToken", lendingTrade.CollateralToken) } err = lendingstate.AddTokenBalance(lendingTrade.Investor, lendingTrade.CollateralLockedAmount, lendingTrade.CollateralToken, statedb) if err != nil { @@ -1052,10 +1053,10 @@ func (l *Lending) GetCollateralPrices(header *types.Header, chain consensus.Chai func (l *Lending) GetXDCBasePrices(header *types.Header, chain consensus.ChainContext, statedb *state.StateDB, tradingStateDb *tradingstate.TradingStateDB, token common.Address) (*big.Int, error) { - tokenXDCPriceFromContract, updatedBlock := lendingstate.GetCollateralPrice(statedb, token, common.HexToAddress(common.XDCNativeAddress)) + tokenXDCPriceFromContract, updatedBlock := lendingstate.GetCollateralPrice(statedb, token, common.XDCNativeAddressBinary) tokenXDCPriceUpdatedFromContract := updatedBlock.Uint64()/chain.Config().XDPoS.Epoch == header.Number.Uint64()/chain.Config().XDPoS.Epoch - if token == common.HexToAddress(common.XDCNativeAddress) { + if token == common.XDCNativeAddressBinary { return common.BasePrice, nil } else if tokenXDCPriceUpdatedFromContract { // getting lendToken price from contract first @@ -1063,7 +1064,7 @@ func (l *Lending) GetXDCBasePrices(header *types.Header, chain consensus.ChainCo log.Debug("Getting token/XDC price from contract", "price", tokenXDCPriceFromContract) return tokenXDCPriceFromContract, nil } else { - XDCTokenPriceFromContract, updatedBlock := lendingstate.GetCollateralPrice(statedb, common.HexToAddress(common.XDCNativeAddress), token) + XDCTokenPriceFromContract, updatedBlock := lendingstate.GetCollateralPrice(statedb, common.XDCNativeAddressBinary, token) XDCTokenPriceUpdatedFromContract := updatedBlock.Uint64()/chain.Config().XDPoS.Epoch == header.Number.Uint64()/chain.Config().XDPoS.Epoch if XDCTokenPriceUpdatedFromContract && XDCTokenPriceFromContract != nil && XDCTokenPriceFromContract.Sign() > 0 { // getting lendToken price from contract first @@ -1078,7 +1079,7 @@ func (l *Lending) GetXDCBasePrices(header *types.Header, chain consensus.ChainCo tokenXDCPrice = new(big.Int).Div(tokenXDCPrice, XDCTokenPriceFromContract) return tokenXDCPrice, nil } - tokenXDCPrice, err := l.GetMediumTradePriceBeforeEpoch(chain, statedb, tradingStateDb, token, common.HexToAddress(common.XDCNativeAddress)) + tokenXDCPrice, err := l.GetMediumTradePriceBeforeEpoch(chain, statedb, tradingStateDb, token, common.XDCNativeAddressBinary) if err != nil { return nil, err } @@ -1133,9 +1134,9 @@ func (l *Lending) ProcessTopUpLendingTrade(lendingStateDB *lendingstate.LendingS if err != nil { log.Warn("ProcessTopUpLendingTrade SubTokenBalance", "err", err, "lendingTrade.Borrower", lendingTrade.Borrower, "quantity", *quantity, "lendingTrade.CollateralToken", lendingTrade.CollateralToken) } - err = lendingstate.AddTokenBalance(common.HexToAddress(common.LendingLockAddress), quantity, lendingTrade.CollateralToken, statedb) + err = lendingstate.AddTokenBalance(common.LendingLockAddressBinary, quantity, lendingTrade.CollateralToken, statedb) if err != nil { - log.Warn("ProcessTopUpLendingTrade AddTokenBalance", "err", err, "LendingLockAddress", common.HexToAddress(common.LendingLockAddress), "quantity", *quantity, "lendingTrade.CollateralToken", lendingTrade.CollateralToken) + log.Warn("ProcessTopUpLendingTrade AddTokenBalance", "err", err, "LendingLockAddress", common.LendingLockAddress, "quantity", *quantity, "lendingTrade.CollateralToken", lendingTrade.CollateralToken) } oldLockedAmount := lendingTrade.CollateralLockedAmount newLockedAmount := new(big.Int).Add(quantity, oldLockedAmount) @@ -1199,9 +1200,9 @@ func (l *Lending) ProcessRepayLendingTrade(header *types.Header, chain consensus if err != nil { log.Warn("ProcessRepayLendingTrade AddTokenBalance", "err", err, "lendingTrade.Investor", lendingTrade.Investor, "paymentBalance", *paymentBalance, "lendingTrade.LendingToken", lendingTrade.LendingToken) } - err = lendingstate.SubTokenBalance(common.HexToAddress(common.LendingLockAddress), lendingTrade.CollateralLockedAmount, lendingTrade.CollateralToken, statedb) + err = lendingstate.SubTokenBalance(common.LendingLockAddressBinary, lendingTrade.CollateralLockedAmount, lendingTrade.CollateralToken, statedb) if err != nil { - log.Warn("ProcessRepayLendingTrade SubTokenBalance", "err", err, "LendingLockAddress", common.HexToAddress(common.LendingLockAddress), "lendingTrade.CollateralLockedAmount", *lendingTrade.CollateralLockedAmount, "lendingTrade.CollateralToken", lendingTrade.CollateralToken) + log.Warn("ProcessRepayLendingTrade SubTokenBalance", "err", err, "LendingLockAddress", common.LendingLockAddress, "lendingTrade.CollateralLockedAmount", *lendingTrade.CollateralLockedAmount, "lendingTrade.CollateralToken", lendingTrade.CollateralToken) } err = lendingstate.AddTokenBalance(lendingTrade.Borrower, lendingTrade.CollateralLockedAmount, lendingTrade.CollateralToken, statedb) if err != nil { @@ -1255,9 +1256,9 @@ func (l *Lending) ProcessRecallLendingTrade(lendingStateDB *lendingstate.Lending if err != nil { log.Warn("ProcessRecallLendingTrade AddTokenBalance", "err", err, "lendingTrade.Borrower", lendingTrade.Borrower, "recallAmount", *recallAmount, "lendingTrade.CollateralToken", lendingTrade.CollateralToken) } - err = lendingstate.SubTokenBalance(common.HexToAddress(common.LendingLockAddress), recallAmount, lendingTrade.CollateralToken, statedb) + err = lendingstate.SubTokenBalance(common.LendingLockAddressBinary, recallAmount, lendingTrade.CollateralToken, statedb) if err != nil { - log.Warn("ProcessRecallLendingTrade SubTokenBalance", "err", err, "LendingLockAddress", common.HexToAddress(common.LendingLockAddress), "recallAmount", *recallAmount, "lendingTrade.CollateralToken", lendingTrade.CollateralToken) + log.Warn("ProcessRecallLendingTrade SubTokenBalance", "err", err, "LendingLockAddress", common.LendingLockAddress, "recallAmount", *recallAmount, "lendingTrade.CollateralToken", lendingTrade.CollateralToken) } lendingStateDB.UpdateLiquidationPrice(lendingBook, lendingTrade.TradeId, newLiquidationPrice) diff --git a/XDCxlending/order_processor_test.go b/XDCxlending/order_processor_test.go index 028ed09aed..172caa90a2 100644 --- a/XDCxlending/order_processor_test.go +++ b/XDCxlending/order_processor_test.go @@ -1,14 +1,15 @@ package XDCxlending import ( + "math/big" + "reflect" + "testing" + "github.com/XinFinOrg/XDPoSChain/XDCx" "github.com/XinFinOrg/XDPoSChain/XDCx/tradingstate" "github.com/XinFinOrg/XDPoSChain/XDCxlending/lendingstate" "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/core/rawdb" - "math/big" - "reflect" - "testing" ) func Test_getCancelFeeV1(t *testing.T) { @@ -107,9 +108,9 @@ func Test_getCancelFee(t *testing.T) { XDCx.SetTokenDecimal(testTokenB, new(big.Int).Exp(big.NewInt(10), big.NewInt(8), nil)) // set tokenAPrice = 1 XDC - tradingStateDb.SetMediumPriceBeforeEpoch(tradingstate.GetTradingOrderBookHash(testTokenA, common.HexToAddress(common.XDCNativeAddress)), common.BasePrice) + tradingStateDb.SetMediumPriceBeforeEpoch(tradingstate.GetTradingOrderBookHash(testTokenA, common.XDCNativeAddressBinary), common.BasePrice) // set tokenBPrice = 1 XDC - tradingStateDb.SetMediumPriceBeforeEpoch(tradingstate.GetTradingOrderBookHash(testTokenB, common.HexToAddress(common.XDCNativeAddress)), common.BasePrice) + tradingStateDb.SetMediumPriceBeforeEpoch(tradingstate.GetTradingOrderBookHash(testTokenB, common.XDCNativeAddressBinary), common.BasePrice) l := New(XDCx) @@ -132,7 +133,7 @@ func Test_getCancelFee(t *testing.T) { borrowFeeRate: common.Big0, order: &lendingstate.LendingItem{ LendingToken: testTokenA, - CollateralToken: common.HexToAddress(common.XDCNativeAddress), + CollateralToken: common.XDCNativeAddressBinary, Quantity: new(big.Int).SetUint64(10000), Side: lendingstate.Investing, }, @@ -147,7 +148,7 @@ func Test_getCancelFee(t *testing.T) { borrowFeeRate: common.Big0, order: &lendingstate.LendingItem{ LendingToken: testTokenA, - CollateralToken: common.HexToAddress(common.XDCNativeAddress), + CollateralToken: common.XDCNativeAddressBinary, Quantity: new(big.Int).SetUint64(10000), Side: lendingstate.Borrowing, }, @@ -162,7 +163,7 @@ func Test_getCancelFee(t *testing.T) { borrowFeeRate: new(big.Int).SetUint64(30), // 30/10000= 0.3% order: &lendingstate.LendingItem{ LendingToken: testTokenA, - CollateralToken: common.HexToAddress(common.XDCNativeAddress), + CollateralToken: common.XDCNativeAddressBinary, Quantity: new(big.Int).SetUint64(10000), Side: lendingstate.Investing, }, @@ -177,7 +178,7 @@ func Test_getCancelFee(t *testing.T) { borrowFeeRate: new(big.Int).SetUint64(30), // 30/10000= 0.3% order: &lendingstate.LendingItem{ LendingToken: testTokenA, - CollateralToken: common.HexToAddress(common.XDCNativeAddress), + CollateralToken: common.XDCNativeAddressBinary, Quantity: new(big.Int).SetUint64(10000), Side: lendingstate.Borrowing, }, @@ -194,7 +195,7 @@ func Test_getCancelFee(t *testing.T) { CancelFeeArg{ borrowFeeRate: common.Big0, order: &lendingstate.LendingItem{ - LendingToken: common.HexToAddress(common.XDCNativeAddress), + LendingToken: common.XDCNativeAddressBinary, CollateralToken: testTokenA, Quantity: new(big.Int).SetUint64(10000), Side: lendingstate.Investing, @@ -209,7 +210,7 @@ func Test_getCancelFee(t *testing.T) { CancelFeeArg{ borrowFeeRate: common.Big0, order: &lendingstate.LendingItem{ - LendingToken: common.HexToAddress(common.XDCNativeAddress), + LendingToken: common.XDCNativeAddressBinary, CollateralToken: testTokenA, Quantity: new(big.Int).SetUint64(10000), Side: lendingstate.Borrowing, @@ -224,7 +225,7 @@ func Test_getCancelFee(t *testing.T) { CancelFeeArg{ borrowFeeRate: new(big.Int).SetUint64(30), // 30/10000= 0.3% order: &lendingstate.LendingItem{ - LendingToken: common.HexToAddress(common.XDCNativeAddress), + LendingToken: common.XDCNativeAddressBinary, CollateralToken: testTokenA, Quantity: new(big.Int).SetUint64(10000), Side: lendingstate.Investing, @@ -239,7 +240,7 @@ func Test_getCancelFee(t *testing.T) { CancelFeeArg{ borrowFeeRate: new(big.Int).SetUint64(30), // 30/10000= 0.3% order: &lendingstate.LendingItem{ - LendingToken: common.HexToAddress(common.XDCNativeAddress), + LendingToken: common.XDCNativeAddressBinary, CollateralToken: testTokenA, Quantity: new(big.Int).SetUint64(10000), Side: lendingstate.Borrowing, diff --git a/accounts/abi/abi.go b/accounts/abi/abi.go index e3687968a5..0ca97525cd 100644 --- a/accounts/abi/abi.go +++ b/accounts/abi/abi.go @@ -79,19 +79,19 @@ func (abi ABI) Pack(name string, args ...interface{}) ([]byte, error) { // Unpack output in v according to the abi specification func (abi ABI) Unpack(v interface{}, name string, output []byte) (err error) { if len(output) == 0 { - return fmt.Errorf("abi: unmarshalling empty output") + return errors.New("abi: unmarshalling empty output") } // since there can't be naming collisions with contracts and events, // we need to decide whether we're calling a method or an event if method, ok := abi.Methods[name]; ok { if len(output)%32 != 0 { - return fmt.Errorf("abi: improperly formatted output") + return errors.New("abi: improperly formatted output") } return method.Outputs.Unpack(v, output) } else if event, ok := abi.Events[name]; ok { return event.Inputs.Unpack(v, output) } - return fmt.Errorf("abi: could not locate named method or event") + return errors.New("abi: could not locate named method or event") } // UnmarshalJSON implements json.Unmarshaler interface diff --git a/accounts/abi/bind/backends/simulated.go b/accounts/abi/bind/backends/simulated.go index 70bf2c8ac2..0df4bd7450 100644 --- a/accounts/abi/bind/backends/simulated.go +++ b/accounts/abi/bind/backends/simulated.go @@ -25,11 +25,9 @@ import ( "sync" "time" + "github.com/XinFinOrg/XDPoSChain" "github.com/XinFinOrg/XDPoSChain/XDCx" "github.com/XinFinOrg/XDPoSChain/XDCxlending" - "github.com/XinFinOrg/XDPoSChain/core/rawdb" - - "github.com/XinFinOrg/XDPoSChain" "github.com/XinFinOrg/XDPoSChain/accounts" "github.com/XinFinOrg/XDPoSChain/accounts/abi/bind" "github.com/XinFinOrg/XDPoSChain/accounts/keystore" @@ -37,10 +35,10 @@ import ( "github.com/XinFinOrg/XDPoSChain/common/math" "github.com/XinFinOrg/XDPoSChain/consensus/XDPoS" "github.com/XinFinOrg/XDPoSChain/consensus/XDPoS/utils" - "github.com/XinFinOrg/XDPoSChain/consensus/ethash" "github.com/XinFinOrg/XDPoSChain/core" "github.com/XinFinOrg/XDPoSChain/core/bloombits" + "github.com/XinFinOrg/XDPoSChain/core/rawdb" "github.com/XinFinOrg/XDPoSChain/core/state" "github.com/XinFinOrg/XDPoSChain/core/types" "github.com/XinFinOrg/XDPoSChain/core/vm" @@ -365,7 +363,7 @@ func (b *SimulatedBackend) callContract(ctx context.Context, call XDPoSChain.Cal from := statedb.GetOrNewStateObject(call.From) from.SetBalance(math.MaxBig256) // Execute the call. - msg := callmsg{call} + msg := callMsg{call} feeCapacity := state.GetTRC21FeeCapacityFromState(statedb) if msg.To() != nil { if value, ok := feeCapacity[*msg.To()]; ok { @@ -388,7 +386,10 @@ func (b *SimulatedBackend) SendTransaction(ctx context.Context, tx *types.Transa b.mu.Lock() defer b.mu.Unlock() - sender, err := types.Sender(types.HomesteadSigner{}, tx) + // Check transaction validity. + block := b.blockchain.CurrentBlock() + signer := types.MakeSigner(b.blockchain.Config(), block.Number()) + sender, err := types.Sender(signer, tx) if err != nil { panic(fmt.Errorf("invalid transaction: %v", err)) } @@ -397,7 +398,8 @@ func (b *SimulatedBackend) SendTransaction(ctx context.Context, tx *types.Transa panic(fmt.Errorf("invalid transaction nonce: got %d, want %d", tx.Nonce(), nonce)) } - blocks, _ := core.GenerateChain(b.config, b.blockchain.CurrentBlock(), b.blockchain.Engine(), b.database, 1, func(number int, block *core.BlockGen) { + // Include tx in chain. + blocks, _ := core.GenerateChain(b.config, block, b.blockchain.Engine(), b.database, 1, func(number int, block *core.BlockGen) { for _, tx := range b.pendingBlock.Transactions() { block.AddTxWithChain(b.blockchain, tx) } @@ -501,20 +503,21 @@ func (b *SimulatedBackend) GetBlockChain() *core.BlockChain { return b.blockchain } -// callmsg implements core.Message to allow passing it as a transaction simulator. -type callmsg struct { +// callMsg implements core.Message to allow passing it as a transaction simulator. +type callMsg struct { XDPoSChain.CallMsg } -func (m callmsg) From() common.Address { return m.CallMsg.From } -func (m callmsg) Nonce() uint64 { return 0 } -func (m callmsg) CheckNonce() bool { return false } -func (m callmsg) To() *common.Address { return m.CallMsg.To } -func (m callmsg) GasPrice() *big.Int { return m.CallMsg.GasPrice } -func (m callmsg) Gas() uint64 { return m.CallMsg.Gas } -func (m callmsg) Value() *big.Int { return m.CallMsg.Value } -func (m callmsg) Data() []byte { return m.CallMsg.Data } -func (m callmsg) BalanceTokenFee() *big.Int { return m.CallMsg.BalanceTokenFee } +func (m callMsg) From() common.Address { return m.CallMsg.From } +func (m callMsg) Nonce() uint64 { return 0 } +func (m callMsg) CheckNonce() bool { return false } +func (m callMsg) To() *common.Address { return m.CallMsg.To } +func (m callMsg) GasPrice() *big.Int { return m.CallMsg.GasPrice } +func (m callMsg) Gas() uint64 { return m.CallMsg.Gas } +func (m callMsg) Value() *big.Int { return m.CallMsg.Value } +func (m callMsg) Data() []byte { return m.CallMsg.Data } +func (m callMsg) BalanceTokenFee() *big.Int { return m.CallMsg.BalanceTokenFee } +func (m callMsg) AccessList() types.AccessList { return m.CallMsg.AccessList } // filterBackend implements filters.Backend to support filtering for logs without // taking bloom-bits acceleration structures into account. diff --git a/accounts/abi/bind/base.go b/accounts/abi/bind/base.go index 1b5c71701c..d0e2d1c393 100644 --- a/accounts/abi/bind/base.go +++ b/accounts/abi/bind/base.go @@ -335,7 +335,7 @@ func (c *BoundContract) UnpackLog(out interface{}, event string, log types.Log) return errNoEventSignature } if log.Topics[0] != c.abi.Events[event].Id() { - return fmt.Errorf("event signature mismatch") + return errors.New("event signature mismatch") } if len(log.Data) > 0 { if err := c.abi.Unpack(out, event, log.Data); err != nil { diff --git a/accounts/abi/bind/util.go b/accounts/abi/bind/util.go index 7943884f07..dec604a801 100644 --- a/accounts/abi/bind/util.go +++ b/accounts/abi/bind/util.go @@ -18,7 +18,7 @@ package bind import ( "context" - "fmt" + "errors" "time" "github.com/XinFinOrg/XDPoSChain/common" @@ -56,14 +56,14 @@ func WaitMined(ctx context.Context, b DeployBackend, tx *types.Transaction) (*ty // contract address when it is mined. It stops waiting when ctx is canceled. func WaitDeployed(ctx context.Context, b DeployBackend, tx *types.Transaction) (common.Address, error) { if tx.To() != nil { - return common.Address{}, fmt.Errorf("tx is not contract creation") + return common.Address{}, errors.New("tx is not contract creation") } receipt, err := WaitMined(ctx, b, tx) if err != nil { return common.Address{}, err } if receipt.ContractAddress == (common.Address{}) { - return common.Address{}, fmt.Errorf("zero address") + return common.Address{}, errors.New("zero address") } // Check that code has indeed been deployed at the address. // This matters on pre-Homestead chains: OOG in the constructor diff --git a/accounts/abi/reflect.go b/accounts/abi/reflect.go index 2e6bf7098f..c1d411ac95 100644 --- a/accounts/abi/reflect.go +++ b/accounts/abi/reflect.go @@ -17,6 +17,7 @@ package abi import ( + "errors" "fmt" "reflect" ) @@ -117,7 +118,7 @@ func requireUniqueStructFieldNames(args Arguments) error { for _, arg := range args { field := capitalise(arg.Name) if field == "" { - return fmt.Errorf("abi: purely underscored output cannot unpack to struct") + return errors.New("abi: purely underscored output cannot unpack to struct") } if exists[field] { return fmt.Errorf("abi: multiple outputs mapping to the same struct field '%s'", field) diff --git a/accounts/abi/type.go b/accounts/abi/type.go index cec1ce8f56..355ac3ac77 100644 --- a/accounts/abi/type.go +++ b/accounts/abi/type.go @@ -17,6 +17,7 @@ package abi import ( + "errors" "fmt" "reflect" "regexp" @@ -61,7 +62,7 @@ var ( func NewType(t string) (typ Type, err error) { // check that array brackets are equal if they exist if strings.Count(t, "[") != strings.Count(t, "]") { - return Type{}, fmt.Errorf("invalid arg type in abi") + return Type{}, errors.New("invalid arg type in abi") } typ.stringKind = t @@ -98,7 +99,7 @@ func NewType(t string) (typ Type, err error) { } typ.Type = reflect.ArrayOf(typ.Size, embeddedType.Type) } else { - return Type{}, fmt.Errorf("invalid formatting of array type") + return Type{}, errors.New("invalid formatting of array type") } return typ, err } diff --git a/accounts/abi/unpack.go b/accounts/abi/unpack.go index 9f0bccf75b..ff2bacebbc 100644 --- a/accounts/abi/unpack.go +++ b/accounts/abi/unpack.go @@ -18,6 +18,7 @@ package abi import ( "encoding/binary" + "errors" "fmt" "math/big" "reflect" @@ -70,7 +71,7 @@ func readBool(word []byte) (bool, error) { // This enforces that standard by always presenting it as a 24-array (address + sig = 24 bytes) func readFunctionType(t Type, word []byte) (funcTy [24]byte, err error) { if t.T != FunctionTy { - return [24]byte{}, fmt.Errorf("abi: invalid type in call to make function type byte array") + return [24]byte{}, errors.New("abi: invalid type in call to make function type byte array") } if garbage := binary.BigEndian.Uint64(word[24:32]); garbage != 0 { err = fmt.Errorf("abi: got improperly encoded function type, got %v", word) @@ -83,7 +84,7 @@ func readFunctionType(t Type, word []byte) (funcTy [24]byte, err error) { // through reflection, creates a fixed array to be read from func readFixedBytes(t Type, word []byte) (interface{}, error) { if t.T != FixedBytesTy { - return nil, fmt.Errorf("abi: invalid type in call to make fixed byte array") + return nil, errors.New("abi: invalid type in call to make fixed byte array") } // convert array := reflect.New(t.Type).Elem() @@ -123,7 +124,7 @@ func forEachUnpack(t Type, output []byte, start, size int) (interface{}, error) // declare our array refSlice = reflect.New(t.Type).Elem() } else { - return nil, fmt.Errorf("abi: invalid type in array/slice unpacking stage") + return nil, errors.New("abi: invalid type in array/slice unpacking stage") } // Arrays have packed elements, resulting in longer unpack steps. diff --git a/accounts/accounts.go b/accounts/accounts.go index ba57577925..157112c8c1 100644 --- a/accounts/accounts.go +++ b/accounts/accounts.go @@ -18,12 +18,14 @@ package accounts import ( + "fmt" "math/big" ethereum "github.com/XinFinOrg/XDPoSChain" "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/core/types" "github.com/XinFinOrg/XDPoSChain/event" + "golang.org/x/crypto/sha3" ) // Account represents an Ethereum account located at a specific location defined @@ -148,6 +150,34 @@ type Backend interface { Subscribe(sink chan<- WalletEvent) event.Subscription } +// TextHash is a helper function that calculates a hash for the given message that can be +// safely used to calculate a signature from. +// +// The hash is calulcated as +// +// keccak256("\x19Ethereum Signed Message:\n"${message length}${message}). +// +// This gives context to the signed message and prevents signing of transactions. +func TextHash(data []byte) []byte { + hash, _ := TextAndHash(data) + return hash +} + +// TextAndHash is a helper function that calculates a hash for the given message that can be +// safely used to calculate a signature from. +// +// The hash is calulcated as +// +// keccak256("\x19Ethereum Signed Message:\n"${message length}${message}). +// +// This gives context to the signed message and prevents signing of transactions. +func TextAndHash(data []byte) ([]byte, string) { + msg := fmt.Sprintf("\x19Ethereum Signed Message:\n%d%s", len(data), string(data)) + hasher := sha3.NewLegacyKeccak256() + hasher.Write([]byte(msg)) + return hasher.Sum(nil), msg +} + // WalletEventType represents the different event types that can be fired by // the wallet subscription subsystem. type WalletEventType int diff --git a/accounts/keystore/account_cache_test.go b/accounts/keystore/account_cache_test.go index e2436f4160..aa1636efc8 100644 --- a/accounts/keystore/account_cache_test.go +++ b/accounts/keystore/account_cache_test.go @@ -17,6 +17,7 @@ package keystore import ( + "errors" "fmt" "math/rand" "os" @@ -305,7 +306,7 @@ func waitForAccounts(wantAccounts []accounts.Account, ks *KeyStore) error { select { case <-ks.changes: default: - return fmt.Errorf("wasn't notified of new accounts") + return errors.New("wasn't notified of new accounts") } return nil } diff --git a/accounts/keystore/keystore.go b/accounts/keystore/keystore.go index 93854794c4..80ac1102d3 100644 --- a/accounts/keystore/keystore.go +++ b/accounts/keystore/keystore.go @@ -24,7 +24,6 @@ import ( "crypto/ecdsa" crand "crypto/rand" "errors" - "fmt" "math/big" "os" "path/filepath" @@ -288,11 +287,9 @@ func (ks *KeyStore) SignTx(a accounts.Account, tx *types.Transaction, chainID *b if !found { return nil, ErrLocked } - // Depending on the presence of the chain ID, sign with EIP155 or homestead - if chainID != nil { - return types.SignTx(tx, types.NewEIP155Signer(chainID), unlockedKey.PrivateKey) - } - return types.SignTx(tx, types.HomesteadSigner{}, unlockedKey.PrivateKey) + // Depending on the presence of the chain ID, sign with 2718 or homestead + signer := types.LatestSignerForChainID(chainID) + return types.SignTx(tx, signer, unlockedKey.PrivateKey) } // SignHashWithPassphrase signs hash if the private key matching the given address @@ -316,11 +313,9 @@ func (ks *KeyStore) SignTxWithPassphrase(a accounts.Account, passphrase string, } defer zeroKey(key.PrivateKey) - // Depending on the presence of the chain ID, sign with EIP155 or homestead - if chainID != nil { - return types.SignTx(tx, types.NewEIP155Signer(chainID), key.PrivateKey) - } - return types.SignTx(tx, types.HomesteadSigner{}, key.PrivateKey) + // Depending on the presence of the chain ID, sign with or without replay protection. + signer := types.LatestSignerForChainID(chainID) + return types.SignTx(tx, signer, key.PrivateKey) } // Unlock unlocks the given account indefinitely. @@ -459,7 +454,7 @@ func (ks *KeyStore) Import(keyJSON []byte, passphrase, newPassphrase string) (ac func (ks *KeyStore) ImportECDSA(priv *ecdsa.PrivateKey, passphrase string) (accounts.Account, error) { key := newKeyFromECDSA(priv) if ks.cache.hasAddress(key.Address) { - return accounts.Account{}, fmt.Errorf("account already exists") + return accounts.Account{}, errors.New("account already exists") } return ks.importKey(key, passphrase) } diff --git a/accounts/usbwallet/trezor.go b/accounts/usbwallet/trezor.go index 551b69e224..04e57292ba 100644 --- a/accounts/usbwallet/trezor.go +++ b/accounts/usbwallet/trezor.go @@ -27,13 +27,13 @@ import ( "io" "math/big" - "github.com/golang/protobuf/proto" "github.com/XinFinOrg/XDPoSChain/accounts" "github.com/XinFinOrg/XDPoSChain/accounts/usbwallet/internal/trezor" "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/common/hexutil" "github.com/XinFinOrg/XDPoSChain/core/types" "github.com/XinFinOrg/XDPoSChain/log" + "github.com/golang/protobuf/proto" ) // ErrTrezorPINNeeded is returned if opening the trezor requires a PIN code. In @@ -80,13 +80,13 @@ func (w *trezorDriver) Status() (string, error) { // Open implements usbwallet.driver, attempting to initialize the connection to // the Trezor hardware wallet. Initializing the Trezor is a two phase operation: -// * The first phase is to initialize the connection and read the wallet's -// features. This phase is invoked is the provided passphrase is empty. The -// device will display the pinpad as a result and will return an appropriate -// error to notify the user that a second open phase is needed. -// * The second phase is to unlock access to the Trezor, which is done by the -// user actually providing a passphrase mapping a keyboard keypad to the pin -// number of the user (shuffled according to the pinpad displayed). +// - The first phase is to initialize the connection and read the wallet's +// features. This phase is invoked is the provided passphrase is empty. The +// device will display the pinpad as a result and will return an appropriate +// error to notify the user that a second open phase is needed. +// - The second phase is to unlock access to the Trezor, which is done by the +// user actually providing a passphrase mapping a keyboard keypad to the pin +// number of the user (shuffled according to the pinpad displayed). func (w *trezorDriver) Open(device io.ReadWriter, passphrase string) error { w.device, w.failure = device, nil @@ -220,9 +220,11 @@ func (w *trezorDriver) trezorSign(derivationPath []uint32, tx *types.Transaction if chainID == nil { signer = new(types.HomesteadSigner) } else { + // Trezor backend does not support typed transactions yet. signer = types.NewEIP155Signer(chainID) signature[64] = signature[64] - byte(chainID.Uint64()*2+35) } + // Inject the final signature into the transaction and sanity check the sender signed, err := tx.WithSignature(signer, signature) if err != nil { diff --git a/bmt/bmt_test.go b/bmt/bmt_test.go index 7a47963700..ac762993c5 100644 --- a/bmt/bmt_test.go +++ b/bmt/bmt_test.go @@ -19,6 +19,7 @@ package bmt import ( "bytes" crand "crypto/rand" + "errors" "fmt" "hash" "io" @@ -288,7 +289,7 @@ func TestHasherConcurrency(t *testing.T) { var err error select { case <-time.NewTimer(5 * time.Second).C: - err = fmt.Errorf("timed out") + err = errors.New("timed out") case err = <-errc: } if err != nil { @@ -321,7 +322,7 @@ func testHasherCorrectness(bmt hash.Hash, hasher BaseHasher, d []byte, n, count }() select { case <-timeout.C: - err = fmt.Errorf("BMT hash calculation timed out") + err = errors.New("BMT hash calculation timed out") case err = <-c: } return err diff --git a/cicd/devnet/start-local-devnet.sh b/cicd/devnet/start-local-devnet.sh index 4bbd5786c9..c0a9448a86 100755 --- a/cicd/devnet/start-local-devnet.sh +++ b/cicd/devnet/start-local-devnet.sh @@ -55,7 +55,7 @@ echo "Running a node with wallet: ${wallet} at local" --datadir ./tmp/xdcchain --networkid 551 \ -port 30303 --rpc --rpccorsdomain "*" --rpcaddr 0.0.0.0 \ --rpcport 8545 \ ---rpcapi admin,db,eth,debug,miner,net,shh,txpool,personal,web3,XDPoS \ +--rpcapi db,eth,debug,miner,net,shh,txpool,personal,web3,XDPoS \ --rpcvhosts "*" --unlock "${wallet}" --password ./tmp/.pwd --mine \ --gasprice "1" --targetgaslimit "420000000" --verbosity ${log_level} \ --ws --wsaddr=0.0.0.0 --wsport 8555 \ diff --git a/cicd/devnet/start.sh b/cicd/devnet/start.sh index 4c3a81ee5b..707b8b32a5 100755 --- a/cicd/devnet/start.sh +++ b/cicd/devnet/start.sh @@ -77,7 +77,7 @@ XDC --ethstats ${netstats} --gcmode archive \ --datadir /work/xdcchain --networkid 551 \ -port $port --rpc --rpccorsdomain "*" --rpcaddr 0.0.0.0 \ --rpcport $rpc_port \ ---rpcapi admin,db,eth,debug,net,shh,txpool,personal,web3,XDPoS \ +--rpcapi db,eth,debug,net,shh,txpool,personal,web3,XDPoS \ --rpcvhosts "*" --unlock "${wallet}" --password /work/.pwd --mine \ --gasprice "1" --targetgaslimit "420000000" --verbosity ${log_level} \ --debugdatadir /work/xdcchain \ diff --git a/cicd/mainnet/start.sh b/cicd/mainnet/start.sh index 35f11a5d34..0cfbe26e27 100755 --- a/cicd/mainnet/start.sh +++ b/cicd/mainnet/start.sh @@ -76,7 +76,7 @@ XDC --ethstats ${netstats} --gcmode archive \ --datadir /work/xdcchain --networkid 50 \ -port $port --rpc --rpccorsdomain "*" --rpcaddr 0.0.0.0 \ --rpcport $rpc_port \ ---rpcapi admin,db,eth,debug,net,shh,txpool,personal,web3,XDPoS \ +--rpcapi db,eth,debug,net,shh,txpool,personal,web3,XDPoS \ --rpcvhosts "*" --unlock "${wallet}" --password /work/.pwd --mine \ --gasprice "1" --targetgaslimit "420000000" --verbosity ${log_level} \ --debugdatadir /work/xdcchain \ diff --git a/cicd/terraform/main.tf b/cicd/terraform/main.tf index 5a44d22385..dcb6e03914 100644 --- a/cicd/terraform/main.tf +++ b/cicd/terraform/main.tf @@ -86,6 +86,7 @@ module "devnet_rpc" { instance_type = "t3.large" ssh_key_name = local.ssh_key_name rpc_image = local.rpc_image + volume_size = 1500 providers = { aws = aws.ap-southeast-1 @@ -101,6 +102,7 @@ module "testnet_rpc" { instance_type = "t3.large" ssh_key_name = local.ssh_key_name rpc_image = local.rpc_image + volume_size = 1500 providers = { aws = aws.ap-southeast-1 @@ -116,6 +118,7 @@ module "mainnet_rpc" { instance_type = "t3.large" ssh_key_name = local.ssh_key_name rpc_image = local.rpc_image + volume_size = 3000 providers = { aws = aws.ap-southeast-1 diff --git a/cicd/terraform/module/ec2_rpc/main.tf b/cicd/terraform/module/ec2_rpc/main.tf index 75535dd0ac..00594517ac 100644 --- a/cicd/terraform/module/ec2_rpc/main.tf +++ b/cicd/terraform/module/ec2_rpc/main.tf @@ -27,6 +27,9 @@ variable ssh_key_name { variable rpc_image { type = string } +variable volume_size{ + type = number +} resource "aws_security_group" "rpc_sg" { name_prefix = "${var.network}_rpc_sg" @@ -75,9 +78,9 @@ resource "aws_instance" "rpc_instance" { } key_name = var.ssh_key_name vpc_security_group_ids = [aws_security_group.rpc_sg.id] - ebs_block_device { - device_name = "/dev/xvda" - volume_size = 500 + + root_block_device { + volume_size = var.volume_size } diff --git a/cicd/testnet/bootnodes.list b/cicd/testnet/bootnodes.list index db7e2c81ff..9fca6cadab 100644 --- a/cicd/testnet/bootnodes.list +++ b/cicd/testnet/bootnodes.list @@ -1,9 +1,39 @@ enode://1c20e6b46ce608c1fe739e78611225b94e663535b74a1545b1667eac8ff75ed43216306d123306c10e043f228e42cc53cb2728655019292380313393eaaf6e23@188.227.164.51:30301 -enode://1c20e6b46ce608c1fe739e78611225b94e663535b74a1545b1667eac8ff75ed43216306d123306c10e043f228e42cc53cb2728655019292380313393eaaf6e23@95.179.217.201:30301 -enode://1c20e6b46ce608c1fe739e78611225b94e663535b74a1545b1667eac8ff75ed43216306d123306c10e043f228e42cc53cb2728655019292380313393eaaf6e23@149.28.167.190:30301 -enode://1c20e6b46ce608c1fe739e78611225b94e663535b74a1545b1667eac8ff75ed43216306d123306c10e043f228e42cc53cb2728655019292380313393eaaf6e23@194.233.77.19:30301 -enode://1c20e6b46ce608c1fe739e78611225b94e663535b74a1545b1667eac8ff75ed43216306d123306c10e043f228e42cc53cb2728655019292380313393eaaf6e23@144.91.108.231:30301 -enode://1c20e6b46ce608c1fe739e78611225b94e663535b74a1545b1667eac8ff75ed43216306d123306c10e043f228e42cc53cb2728655019292380313393eaaf6e23@207.244.240.232:30301 -enode://1c20e6b46ce608c1fe739e78611225b94e663535b74a1545b1667eac8ff75ed43216306d123306c10e043f228e42cc53cb2728655019292380313393eaaf6e23@66.94.121.62:30301 -enode://1c20e6b46ce608c1fe739e78611225b94e663535b74a1545b1667eac8ff75ed43216306d123306c10e043f228e42cc53cb2728655019292380313393eaaf6e23@144.126.150.69:30301 -enode://1c20e6b46ce608c1fe739e78611225b94e663535b74a1545b1667eac8ff75ed43216306d123306c10e043f228e42cc53cb2728655019292380313393eaaf6e23@161.97.93.168:30301 +enode://278a1ac8aab1381b460788dcf4dcaffd58252f0f4c57d95e4c68b4b2bbf0b0fee83b5c140b3e6eb3af8b8369e76fa6996ccc667f65f9ce35c1d14fbffbfdfd52@95.179.217.201:30301 +enode://518a2b963f0e41ef6520973766a8caad9b13c4261c50a73eedfeaa3ebf3cce623dadc132d8d42ec41a50195c925ca55abd4e2a96258f9d335e3f625e6c2df789@149.28.167.190:30301 +enode://63f5a65ffce84b7123562b8bc871ae63db6dadded14df1dae6ad09b4a30e722a72e522d99aaf0bbad19380a0dd3abe94768ed03a8f68204b5dee6ae16ddb4d83@194.233.77.19:30301 +enode://87497c34ce0102e888040d6ec530e73c33bb0330770105797cf5667f465115d49d53cb583a8b4e34ff14048448c63f03fbd1c269897d7b99716dcdb942b6d707@144.91.108.231:30301 +enode://8c372fc5859e69a49037768f2ece54f14d1d159ec3181d06d7b58f1a54dac4f2f13623f1377a7e506e48e5f33b1752a9726a44b33fdb158ce7d6d5a57c055f65@207.244.240.232:30301 +enode://bed408a65f2894b09ac11ec03eee6b4b1af279e27c81fbe942aa6300ebe5d734913c0c7b30cdbcf98dfebd7a46bf9c79cfc1a8878e20240b5da872dc358be571@66.94.121.62:30301 +enode://dd8973429effba6d6f8edd27efad8544553bb39c82c6350c09e82cd3cfdcdecaf6d0ce9820ed97482aedd0b9236e88d362ca031a65dcbdac52aa292e693723e3@144.126.150.69:30301 +enode://fb5616c009265a162c08fb3984192ffc4df18946dc4d591ba5480011169e6f3f324685b5e102e8123d32e27cd968220d6400b9f1f0a6d1611c130079cd1e71bb@161.97.93.168:30301 +enode://75e95709fd89a6314b0d5363226e3e46c56d98c6ddd3ec3cbdffcb65fdfedc8cbad870e2af1382b97e5986c1465cdb00e621cc01e1910f622ba20ed2359a02a0@213.136.89.186:30304 +enode://34ba74c6a0ef379040243e19c5f673b29155ec5928a5300f1cdfd2215983b7720d52585acb16f7199cc5abdd9ff3657305e901e097f951e35740b0a80fce3e18@185.217.126.17:30304 +enode://d7c070939155be296a8b254d43a0927e6c6777f1352239499fab867af455b9c671bddd5f9ae359ddcd6af9a368746e2f993416bafff2d03a4d974d888daaf020@38.242.244.116:30304 +enode://0ad61fd32bcc99b32f6ead959f2f3472ecd11cd5b555767983b1827240cbac25fd978b94cf79ca8423088e5dd519479a6f7f26fe43a98e0da0b4b1caa995923d@207.180.210.192:30304 +enode://0985cfee342bf68bc21fea7ce728018d3a5f27a43a2c79b78e9103d07ac4893960fc399603d372744dfe2020a970a1daba979210e07c1d27bf6cfc317036ae13@173.249.33.28:30304 +enode://4c750ba2c069e00a8bbe37e45053e04a975a4cc635f1c53506da555bc9cc137da2680b76f48a232b88f762153af83aced601ad45475576100f175c0085750822@209.145.57.76:30304 +enode://d31b551d02ead096d5cbb3adda68fbcfbc76e4f939a6a6fe41e1a4e1be19c6f4b1c45a7545977764267d8bc6b7d1835c0e2060045a223af8ebb81d5cd30d01de@5.189.132.151:30304 +enode://686d8c5b886bf29633b73e5b4f7eb73cb1afa8c11cd5d3cc79027b742eeaff62fe44a42417d18c2d96ff8de2bf2c0c73ce51a07f3d5635b232e2eece090234b6@164.68.125.57:30304 +enode://7325b2eca70dfb9bde340ccfb6a5076f146f2af2401fae7996044207c95c797c2be4c0d90d76183abfdf33395553d6eb05fff6259c8fbe1df884df85d59f40b0@31.220.84.216:30304 +enode://f8e45296452c4e3988f9398bee1e1be029993a5332bd293629397dd71cca281fd7aaa3ade4d92c63a984218a6dd7ffddcd70135cfd5b3a66e0dd124dc9c35a37@31.220.84.220:30304 +enode://d8d8dff11b36dd683daccbb0306be706d97838dc0239aca077ab5475cbd488f9ebafa9a4a242d87f6def31028969756060ea412621d83a4715bd9a47e0787d3d@31.220.84.222:30304 +enode://0d3a38063c594523dbd619033ecc6b87545b204e91aa883b789389483baca964fcc4221832ea470308e1faa56b90705be234a113d63a3f0f7347d1237b58fec2@31.220.84.224:30304 +enode://1f479698e303b5ee9b8d35cdc6c660486b39a3f3f4581a1ecae5464142bef3103abe8665f1e25c864decaf80ee1261c2dc6f4f0caf2a2a66708fe934a63a564c@158.220.87.132:30304 +enode://6dd64c63402ebe46c1043c97ce66b7de88fb220b45f7435964bbfe6af2f3d0bf4d6b4702c274231afdab7d246054f85414befbe6ed0086aa5e9d98272e931278@158.220.87.144:30304 +enode://587d3c6962fedf45f07c758b5d7e07fbd3570e599a69f82d3f4676012327a92644785599ceaac9a35ed4f35478bd3dfe0dd4788b4529fa05cef15ba2e611f045@84.46.248.126:30304 +enode://e9684ecbf96348727f21d1d2253893aa2c5815dd8d9e0e2f8e842dfd92347d1daab1c2cea3a379a66fd6bf5321945eb65cc794a10345a4e1ec0a0121b77481d3@185.209.230.34:30304 +enode://52744d57d65b91fa2ccc43cb5758fca34f97509b068fc1fe1daee4d905cfbae7172a1647c28583b97fc51aaf0bf23b4e47340e2e658fcd5683daaa91c8d41c5d@167.86.125.253:30304 +enode://09aa42d01437643d3f2a02e12e0d6e41a6951281ef80d2a332dd025d5da0d4990b9932695f91ffb0a0998cec2531dc9c0e8e0fd4a3bb0b69409ad76c02da9f4e@167.86.125.15:30304 +enode://017ef59b3f3734aad9277170cd08c2a3231c75a63465085584ca00ffb32daccd2dfe657b2a3539af1d45e6b24251a18f2bc1b0e31b61f57c2039af95ebda1e2c@95.111.237.15:30304 +enode://e3eb10d6616dc9dbe6006bafbe02fceaef60bcce66666ef357b458e91df64b33572014ea2a162873ec6a68af80a405cc0a60e8125f8ec155be7e34caa4e8aeb6@173.212.253.234:30304 +enode://cbba3cf7f151bf79319107d0f24ca0fed9d6bc32bd922b173e5b87174fedd1d25378cf827236f05c00ce403a68d2ba75dda28e16ca43cbe87b0231a9ad16426e@178.18.249.111:30304 +enode://8b9a5f0433ed2dcb7fe0e4424ae6f83158cc73a562c8c5c1733a01eb2c92c1b77ffb46dc1f4b54d8b2392d9fd985661b968d9d062ffc73ece02b58f41589527a@173.249.54.137:30304 +enode://38d42a90e8c00beff03dc2f759a00e2f910b5bfd7b9eaafab3f24682c5e0bd2ae7094823a1af006295922848e3e417188f90766840220f377a780c8c5afa4c47@38.242.129.33:30304 +enode://e9efc55e68dbc38842c5ed43c597d7550d31d2c9d7e074e14a43a9304fa2eb8a2ad0955eef1961e6c5b850ce2061479623ba2ad8ae22db944f17fd4a19a8b902@167.86.106.195:30304 +enode://0c5dc604acbf5f04bd69e015e89c1c015f2641f5dbe8f56c0b2dc1d03e7b7d79e049ed3cca6b0f23cfecfd32c394a53e3d3810b0c4784b92646e658e1d879bf6@84.21.171.77:30304 +enode://a1dd03c142c2db1a786e51250e929f2bf5edf7575f23ecfedec6da4a51e89e7ed36479fb4f40677da496364894b97ab189475fbb6c8b7ca2ecfe4b3ff3d24ca9@85.208.51.215:30304 +enode://94389b49ca856a91b9962ead5a562a64383b7f8fdb819b6fa22d29db984a78b89245e894b8c29facc765348399e1dc4d4a16a787f94f21df8813bc7e703db4fc@167.86.118.99:30304 +enode://b786080fabb1b07046359c1820db88c589c6c4a964ec8b1f1a287c12c4178dec2fbbf406187e16dc29fd69cc5e4960741471bebbfe2ae111036f66b61f16cf41@38.242.129.40:30304 +enode://326d2d562754cdd72e1a4b7d42de44713d97f00c4b6e7250e6896b99d331a395012fe0cc6a2d133bdd8784e02649eb490197b65244f72003fff0c829d8695cb4@167.86.83.167:30304 +enode://c85d71dfb2dfc5abd832ed60bc0197ecb788d0f6bd44655d870dec6002a23345790344d65a2cfc0d2c0e023a7a6f630f99b6575cf5ca57203870f1971d1fc370@167.86.83.166:30304 \ No newline at end of file diff --git a/cicd/testnet/start.sh b/cicd/testnet/start.sh index d5f9a0f443..abfd91cadf 100755 --- a/cicd/testnet/start.sh +++ b/cicd/testnet/start.sh @@ -78,7 +78,7 @@ XDC --ethstats ${netstats} --gcmode archive \ --datadir /work/xdcchain --networkid 51 \ -port $port --rpc --rpccorsdomain "*" --rpcaddr 0.0.0.0 \ --rpcport $rpc_port \ ---rpcapi admin,db,eth,debug,net,shh,txpool,personal,web3,XDPoS \ +--rpcapi db,eth,debug,net,shh,txpool,personal,web3,XDPoS \ --rpcvhosts "*" --unlock "${wallet}" --password /work/.pwd --mine \ --gasprice "1" --targetgaslimit "420000000" --verbosity ${log_level} \ --debugdatadir /work/xdcchain \ diff --git a/cmd/evm/json_logger.go b/cmd/evm/json_logger.go index 90e44f9c4a..a5b8c0fea2 100644 --- a/cmd/evm/json_logger.go +++ b/cmd/evm/json_logger.go @@ -33,15 +33,23 @@ type JSONLogger struct { } func NewJSONLogger(cfg *vm.LogConfig, writer io.Writer) *JSONLogger { - return &JSONLogger{json.NewEncoder(writer), cfg} + l := &JSONLogger{json.NewEncoder(writer), cfg} + if l.cfg == nil { + l.cfg = &vm.LogConfig{} + } + return l } -func (l *JSONLogger) CaptureStart(from common.Address, to common.Address, create bool, input []byte, gas uint64, value *big.Int) error { - return nil +func (l *JSONLogger) CaptureStart(env *vm.EVM, from, to common.Address, create bool, input []byte, gas uint64, value *big.Int) { +} + +func (l *JSONLogger) CaptureFault(*vm.EVM, uint64, vm.OpCode, uint64, uint64, *vm.ScopeContext, int, error) { } // CaptureState outputs state information on the logger. -func (l *JSONLogger) CaptureState(env *vm.EVM, pc uint64, op vm.OpCode, gas, cost uint64, memory *vm.Memory, stack *vm.Stack, contract *vm.Contract, depth int, err error) error { +func (l *JSONLogger) CaptureState(env *vm.EVM, pc uint64, op vm.OpCode, gas, cost uint64, scope *vm.ScopeContext, rData []byte, depth int, err error) { + memory := scope.Memory + stack := scope.Stack log := vm.StructLog{ Pc: pc, Op: op, @@ -63,24 +71,20 @@ func (l *JSONLogger) CaptureState(env *vm.EVM, pc uint64, op vm.OpCode, gas, cos } log.Stack = logstack } - return l.encoder.Encode(log) -} - -// CaptureFault outputs state information on the logger. -func (l *JSONLogger) CaptureFault(env *vm.EVM, pc uint64, op vm.OpCode, gas, cost uint64, memory *vm.Memory, stack *vm.Stack, contract *vm.Contract, depth int, err error) error { - return nil + l.encoder.Encode(log) } // CaptureEnd is triggered at end of execution. -func (l *JSONLogger) CaptureEnd(output []byte, gasUsed uint64, t time.Duration, err error) error { +func (l *JSONLogger) CaptureEnd(output []byte, gasUsed uint64, t time.Duration, err error) { type endLog struct { Output string `json:"output"` GasUsed math.HexOrDecimal64 `json:"gasUsed"` Time time.Duration `json:"time"` Err string `json:"error,omitempty"` } + var errMsg string if err != nil { - return l.encoder.Encode(endLog{common.Bytes2Hex(output), math.HexOrDecimal64(gasUsed), t, err.Error()}) + errMsg = err.Error() } - return l.encoder.Encode(endLog{common.Bytes2Hex(output), math.HexOrDecimal64(gasUsed), t, ""}) + l.encoder.Encode(endLog{common.Bytes2Hex(output), math.HexOrDecimal64(gasUsed), t, errMsg}) } diff --git a/cmd/gc/main.go b/cmd/gc/main.go index ce48f0aae8..4b3b004b77 100644 --- a/cmd/gc/main.go +++ b/cmd/gc/main.go @@ -28,7 +28,7 @@ var ( cacheSize = flag.Int("size", 1000000, "LRU cache size") sercureKey = []byte("secure-key-") nWorker = runtime.NumCPU() / 2 - cleanAddress = []common.Address{common.HexToAddress(common.BlockSigners)} + cleanAddress = []common.Address{common.BlockSignersBinary} cache *lru.Cache finish = int32(0) running = true diff --git a/cmd/puppeth/wizard_genesis.go b/cmd/puppeth/wizard_genesis.go index 92aa2197cb..c1e077c595 100644 --- a/cmd/puppeth/wizard_genesis.go +++ b/cmd/puppeth/wizard_genesis.go @@ -145,8 +145,8 @@ func (w *wizard) makeGenesis() { genesis.Config.XDPoS.V2.CurrentConfig.TimeoutSyncThreshold = w.readDefaultInt(3) fmt.Println() - fmt.Printf("How many v2 vote collection to generate a QC, should be two thirds of masternodes? (default = %f)\n", 0.666) - genesis.Config.XDPoS.V2.CurrentConfig.CertThreshold = w.readDefaultFloat(0.666) + fmt.Printf("Proportion of total masternodes v2 vote collection to generate a QC (float value), should be two thirds of masternodes? (default = %f)\n", 0.667) + genesis.Config.XDPoS.V2.CurrentConfig.CertThreshold = w.readDefaultFloat(0.667) genesis.Config.XDPoS.V2.AllConfigs[0] = genesis.Config.XDPoS.V2.CurrentConfig @@ -197,7 +197,7 @@ func (w *wizard) makeGenesis() { fmt.Println() fmt.Println("What is foundation wallet address? (default = xdc0000000000000000000000000000000000000068)") - genesis.Config.XDPoS.FoudationWalletAddr = w.readDefaultAddress(common.HexToAddress(common.FoudationAddr)) + genesis.Config.XDPoS.FoudationWalletAddr = w.readDefaultAddress(common.FoudationAddrBinary) // Validator Smart Contract Code pKey, _ := crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") @@ -225,7 +225,7 @@ func (w *wizard) makeGenesis() { return true } contractBackend.ForEachStorageAt(ctx, validatorAddress, nil, f) - genesis.Alloc[common.HexToAddress(common.MasternodeVotingSMC)] = core.GenesisAccount{ + genesis.Alloc[common.MasternodeVotingSMCBinary] = core.GenesisAccount{ Balance: validatorCap.Mul(validatorCap, big.NewInt(int64(len(validatorCaps)))), Code: code, Storage: storage, @@ -259,7 +259,7 @@ func (w *wizard) makeGenesis() { fBalance := big.NewInt(0) // 16m fBalance.Add(fBalance, big.NewInt(16*1000*1000)) fBalance.Mul(fBalance, big.NewInt(1000000000000000000)) - genesis.Alloc[common.HexToAddress(common.FoudationAddr)] = core.GenesisAccount{ + genesis.Alloc[common.FoudationAddrBinary] = core.GenesisAccount{ Balance: fBalance, Code: code, Storage: storage, @@ -275,7 +275,7 @@ func (w *wizard) makeGenesis() { code, _ = contractBackend.CodeAt(ctx, blockSignerAddress, nil) storage = make(map[common.Hash]common.Hash) contractBackend.ForEachStorageAt(ctx, blockSignerAddress, nil, f) - genesis.Alloc[common.HexToAddress(common.BlockSigners)] = core.GenesisAccount{ + genesis.Alloc[common.BlockSignersBinary] = core.GenesisAccount{ Balance: big.NewInt(0), Code: code, Storage: storage, @@ -291,7 +291,7 @@ func (w *wizard) makeGenesis() { code, _ = contractBackend.CodeAt(ctx, randomizeAddress, nil) storage = make(map[common.Hash]common.Hash) contractBackend.ForEachStorageAt(ctx, randomizeAddress, nil, f) - genesis.Alloc[common.HexToAddress(common.RandomizeSMC)] = core.GenesisAccount{ + genesis.Alloc[common.RandomizeSMCBinary] = core.GenesisAccount{ Balance: big.NewInt(0), Code: code, Storage: storage, @@ -330,7 +330,7 @@ func (w *wizard) makeGenesis() { subBalance.Add(subBalance, big.NewInt(int64(len(signers))*50*1000)) subBalance.Mul(subBalance, big.NewInt(1000000000000000000)) balance.Sub(balance, subBalance) // 12m - i * 50k - genesis.Alloc[common.HexToAddress(common.TeamAddr)] = core.GenesisAccount{ + genesis.Alloc[common.TeamAddrBinary] = core.GenesisAccount{ Balance: balance, Code: code, Storage: storage, diff --git a/cmd/utils/cmd.go b/cmd/utils/cmd.go index 161a65a417..0e736b95e0 100644 --- a/cmd/utils/cmd.go +++ b/cmd/utils/cmd.go @@ -19,6 +19,7 @@ package utils import ( "compress/gzip" + "errors" "fmt" "io" "os" @@ -130,7 +131,7 @@ func ImportChain(chain *core.BlockChain, fn string) error { for batch := 0; ; batch++ { // Load a batch of RLP blocks. if checkInterrupt() { - return fmt.Errorf("interrupted") + return errors.New("interrupted") } i := 0 for ; i < importBatchSize; i++ { @@ -153,7 +154,7 @@ func ImportChain(chain *core.BlockChain, fn string) error { } // Import the batch. if checkInterrupt() { - return fmt.Errorf("interrupted") + return errors.New("interrupted") } missing := missingBlocks(chain, blocks[:i]) if len(missing) == 0 { diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index 6ad25c7daf..a9fd25c492 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -118,7 +118,7 @@ var ( Enable0xPrefixFlag = cli.BoolFlag{ Name: "enable-0x-prefix", Usage: "Addres use 0x-prefix (Deprecated: this is on by default, to use xdc prefix use --enable-xdc-prefix)", - } + } EnableXDCPrefixFlag = cli.BoolFlag{ Name: "enable-xdc-prefix", Usage: "Addres use xdc-prefix (default = false)", @@ -358,6 +358,11 @@ var ( Name: "vmdebug", Usage: "Record information useful for VM and contract debugging", } + RPCGlobalGasCapFlag = cli.Uint64Flag{ + Name: "rpc.gascap", + Usage: "Sets a cap on gas that can be used in eth_call/estimateGas (0=infinite)", + Value: eth.DefaultConfig.RPCGasCap, + } // Logging and debug settings EthStatsURLFlag = cli.StringFlag{ Name: "ethstats", @@ -1176,6 +1181,11 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *eth.Config) { // TODO(fjl): force-enable this in --dev mode cfg.EnablePreimageRecording = ctx.GlobalBool(VMEnableDebugFlag.Name) } + if cfg.RPCGasCap != 0 { + log.Info("Set global gas cap", "cap", cfg.RPCGasCap) + } else { + log.Info("Global gas cap disabled") + } if ctx.GlobalIsSet(StoreRewardFlag.Name) { common.StoreRewardFolder = filepath.Join(stack.DataDir(), "XDC", "rewards") if _, err := os.Stat(common.StoreRewardFolder); os.IsNotExist(err) { diff --git a/cmd/wnode/main.go b/cmd/wnode/main.go index 5441428e88..5fa29ab96c 100644 --- a/cmd/wnode/main.go +++ b/cmd/wnode/main.go @@ -139,8 +139,8 @@ func processArgs() { } if *asymmetricMode && len(*argPub) > 0 { - pub = crypto.ToECDSAPub(common.FromHex(*argPub)) - if !isKeyValid(pub) { + var err error + if pub, err = crypto.UnmarshalPubkey(common.FromHex(*argPub)); err != nil { utils.Fatalf("invalid public key") } } @@ -337,9 +337,8 @@ func configureNode() { if b == nil { utils.Fatalf("Error: can not convert hexadecimal string") } - pub = crypto.ToECDSAPub(b) - if !isKeyValid(pub) { - utils.Fatalf("Error: invalid public key") + if pub, err = crypto.UnmarshalPubkey(b); err != nil { + utils.Fatalf("Error: invalid peer public key") } } } diff --git a/common/constants.go b/common/constants.go index a2fa952460..cbe997f115 100644 --- a/common/constants.go +++ b/common/constants.go @@ -47,11 +47,12 @@ var TIPXDCXCancellationFee = big.NewInt(38383838) var TIPXDCXCancellationFeeTestnet = big.NewInt(38383838) var TIPXDCXMinerDisable = big.NewInt(88999999900) var TIPXDCXReceiverDisable = big.NewInt(99999999999) -var BerlinBlock = big.NewInt(9999999999) -var LondonBlock = big.NewInt(9999999999) -var MergeBlock = big.NewInt(9999999999) -var ShanghaiBlock = big.NewInt(9999999999) var Eip1559Block = big.NewInt(9999999999) +var TIPXDCXDISABLE = big.NewInt(99999999900) +var BerlinBlock = big.NewInt(76321000) // Target 19th June 2024 +var LondonBlock = big.NewInt(76321000) // Target 19th June 2024 +var MergeBlock = big.NewInt(76321000) // Target 19th June 2024 +var ShanghaiBlock = big.NewInt(76321000) // Target 19th June 2024 var TIPXDCXTestnet = big.NewInt(38383838) var IsTestnet bool = false diff --git a/common/countdown/countdown_test.go b/common/countdown/countdown_test.go index fb5356dfc9..6f1b0e1022 100644 --- a/common/countdown/countdown_test.go +++ b/common/countdown/countdown_test.go @@ -1,7 +1,7 @@ package countdown import ( - "fmt" + "errors" "testing" "time" @@ -76,7 +76,7 @@ func TestCountdownShouldResetEvenIfErrored(t *testing.T) { called := make(chan int) OnTimeoutFn := func(time.Time, interface{}) error { called <- 1 - return fmt.Errorf("ERROR!") + return errors.New("ERROR!") } countdown := NewCountDown(5000 * time.Millisecond) diff --git a/common/types.go b/common/types.go index 905c7d5a4e..c7abab289b 100644 --- a/common/types.go +++ b/common/types.go @@ -50,6 +50,20 @@ const ( XDCZApplyMethod = "0xc6b32f34" ) +var ( + BlockSignersBinary = Address{19: 0x89} // xdc0000000000000000000000000000000000000089 + MasternodeVotingSMCBinary = Address{19: 0x88} // xdc0000000000000000000000000000000000000088 + RandomizeSMCBinary = Address{19: 0x90} // xdc0000000000000000000000000000000000000090 + FoudationAddrBinary = Address{19: 0x68} // xdc0000000000000000000000000000000000000068 + TeamAddrBinary = Address{19: 0x99} // xdc0000000000000000000000000000000000000099 + XDCXAddrBinary = Address{19: 0x91} // xdc0000000000000000000000000000000000000091 + TradingStateAddrBinary = Address{19: 0x92} // xdc0000000000000000000000000000000000000092 + XDCXLendingAddressBinary = Address{19: 0x93} // xdc0000000000000000000000000000000000000093 + XDCXLendingFinalizedTradeAddressBinary = Address{19: 0x94} // xdc0000000000000000000000000000000000000094 + XDCNativeAddressBinary = Address{19: 0x01} // xdc0000000000000000000000000000000000000001 + LendingLockAddressBinary = Address{19: 0x11} // xdc0000000000000000000000000000000000000011 +) + var ( hashT = reflect.TypeOf(Hash{}) addressT = reflect.TypeOf(Address{}) @@ -256,7 +270,7 @@ func (a *Address) Set(other Address) { // MarshalText returns the hex representation of a. func (a Address) MarshalText() ([]byte, error) { // Handle '0x' or 'xdc' prefix here. - if (Enable0xPrefix) { + if Enable0xPrefix { return hexutil.Bytes(a[:]).MarshalText() } else { return hexutil.Bytes(a[:]).MarshalXDCText() diff --git a/common/types_test.go b/common/types_test.go index 1ec2bfc26c..7621a6e4b8 100644 --- a/common/types_test.go +++ b/common/types_test.go @@ -167,3 +167,39 @@ func TestRemoveItemInArray(t *testing.T) { t.Error("fail remove item from array address") } } + +var testCases = []struct { + bin Address + str string +}{ + {BlockSignersBinary, BlockSigners}, + {MasternodeVotingSMCBinary, MasternodeVotingSMC}, + {RandomizeSMCBinary, RandomizeSMC}, + {FoudationAddrBinary, FoudationAddr}, + {TeamAddrBinary, TeamAddr}, + {XDCXAddrBinary, XDCXAddr}, + {TradingStateAddrBinary, TradingStateAddr}, + {XDCXLendingAddressBinary, XDCXLendingAddress}, + {XDCXLendingFinalizedTradeAddressBinary, XDCXLendingFinalizedTradeAddress}, + {XDCNativeAddressBinary, XDCNativeAddress}, + {LendingLockAddressBinary, LendingLockAddress}, +} + +func TestBinaryAddressToString(t *testing.T) { + for _, tt := range testCases { + have := tt.bin.String() + want := tt.str + if have != want { + t.Errorf("fail to convert binary address to string address\nwant:%s\nhave:%s", have, want) + } + } +} +func TestStringToBinaryAddress(t *testing.T) { + for _, tt := range testCases { + want := tt.bin + have := HexToAddress(tt.str) + if have != want { + t.Errorf("fail to convert string address to binary address\nwant:%s\nhave:%s", have, want) + } + } +} diff --git a/consensus/XDPoS/XDPoS.go b/consensus/XDPoS/XDPoS.go index 6d5477a1ee..fc7b341894 100644 --- a/consensus/XDPoS/XDPoS.go +++ b/consensus/XDPoS/XDPoS.go @@ -17,6 +17,7 @@ package XDPoS import ( + "errors" "fmt" "math/big" @@ -438,7 +439,7 @@ func (x *XDPoS) CalculateMissingRounds(chain consensus.ChainReader, header *type case params.ConsensusEngineVersion2: return x.EngineV2.CalculateMissingRounds(chain, header) default: // Default "v1" - return nil, fmt.Errorf("Not supported in the v1 consensus") + return nil, errors.New("Not supported in the v1 consensus") } } diff --git a/consensus/XDPoS/api.go b/consensus/XDPoS/api.go index a922e57f78..7209939662 100644 --- a/consensus/XDPoS/api.go +++ b/consensus/XDPoS/api.go @@ -259,7 +259,7 @@ func (api *API) GetV2BlockByHash(blockHash common.Hash) *V2BlockInfo { func (api *API) NetworkInformation() NetworkInformation { info := NetworkInformation{} info.NetworkId = api.chain.Config().ChainId - info.XDCValidatorAddress = common.HexToAddress(common.MasternodeVotingSMC) + info.XDCValidatorAddress = common.MasternodeVotingSMCBinary if common.IsTestnet { info.LendingAddress = common.HexToAddress(common.LendingRegistrationSMCTestnet) info.RelayerRegistrationAddress = common.HexToAddress(common.RelayerRegistrationSMCTestnet) diff --git a/consensus/XDPoS/engines/engine_v1/engine.go b/consensus/XDPoS/engines/engine_v1/engine.go index 7a862370c8..6317200fa2 100644 --- a/consensus/XDPoS/engines/engine_v1/engine.go +++ b/consensus/XDPoS/engines/engine_v1/engine.go @@ -686,7 +686,7 @@ func (x *XDPoS_v1) GetValidator(creator common.Address, chain consensus.ChainRea if no%epoch == 0 { cpHeader = header } else { - return common.Address{}, fmt.Errorf("couldn't find checkpoint header") + return common.Address{}, errors.New("couldn't find checkpoint header") } } m, err := getM1M2FromCheckpointHeader(cpHeader, header, chain.Config()) diff --git a/consensus/XDPoS/engines/engine_v2/engine.go b/consensus/XDPoS/engines/engine_v2/engine.go index 9cc25d7fab..c2235f5d32 100644 --- a/consensus/XDPoS/engines/engine_v2/engine.go +++ b/consensus/XDPoS/engines/engine_v2/engine.go @@ -661,7 +661,7 @@ func (x *XDPoS_v2) VerifyTimeoutMessage(chain consensus.ChainReader, timeoutMsg } if len(snap.NextEpochCandidates) == 0 { log.Error("[VerifyTimeoutMessage] cannot find NextEpochCandidates from snapshot", "messageGapNumber", timeoutMsg.GapNumber) - return false, fmt.Errorf("Empty master node lists from snapshot") + return false, errors.New("Empty master node lists from snapshot") } verified, signer, err := x.verifyMsgSignature(types.TimeoutSigHash(&types.TimeoutForSign{ @@ -748,7 +748,7 @@ func (x *XDPoS_v2) VerifyBlockInfo(blockChainReader consensus.ChainReader, block // If blockHeader present, then its value shall consistent with what's provided in the blockInfo if blockHeader.Hash() != blockInfo.Hash { log.Warn("[VerifyBlockInfo] BlockHeader and blockInfo mismatch", "BlockInfoHash", blockInfo.Hash.Hex(), "BlockHeaderHash", blockHeader.Hash()) - return fmt.Errorf("[VerifyBlockInfo] Provided blockheader does not match what's in the blockInfo") + return errors.New("[VerifyBlockInfo] Provided blockheader does not match what's in the blockInfo") } } @@ -761,7 +761,7 @@ func (x *XDPoS_v2) VerifyBlockInfo(blockChainReader consensus.ChainReader, block if blockInfo.Number.Cmp(x.config.V2.SwitchBlock) == 0 { if blockInfo.Round != 0 { log.Error("[VerifyBlockInfo] Switch block round is not 0", "BlockInfoHash", blockInfo.Hash.Hex(), "BlockInfoNum", blockInfo.Number, "BlockInfoRound", blockInfo.Round, "blockHeaderNum", blockHeader.Number) - return fmt.Errorf("[VerifyBlockInfo] switch block round have to be 0") + return errors.New("[VerifyBlockInfo] switch block round have to be 0") } return nil } @@ -789,7 +789,7 @@ func (x *XDPoS_v2) verifyQC(blockChainReader consensus.ChainReader, quorumCert * epochInfo, err := x.getEpochSwitchInfo(blockChainReader, parentHeader, quorumCert.ProposedBlockInfo.Hash) if err != nil { log.Error("[verifyQC] Error when getting epoch switch Info to verify QC", "Error", err) - return fmt.Errorf("Fail to verify QC due to failure in getting epoch switch info") + return errors.New("Fail to verify QC due to failure in getting epoch switch info") } signatures, duplicates := UniqueSignatures(quorumCert.Signatures) @@ -821,12 +821,12 @@ func (x *XDPoS_v2) verifyQC(blockChainReader consensus.ChainReader, quorumCert * }), sig, epochInfo.Masternodes) if err != nil { log.Error("[verifyQC] Error while verfying QC message signatures", "Error", err) - haveError = fmt.Errorf("Error while verfying QC message signatures") + haveError = errors.New("Error while verfying QC message signatures") return } if !verified { log.Warn("[verifyQC] Signature not verified doing QC verification", "QC", quorumCert) - haveError = fmt.Errorf("Fail to verify QC due to signature mis-match") + haveError = errors.New("Fail to verify QC due to signature mis-match") return } }(signature) diff --git a/consensus/XDPoS/engines/engine_v2/forensics.go b/consensus/XDPoS/engines/engine_v2/forensics.go index 1c4542ab0e..45f88097d6 100644 --- a/consensus/XDPoS/engines/engine_v2/forensics.go +++ b/consensus/XDPoS/engines/engine_v2/forensics.go @@ -2,6 +2,7 @@ package engine_v2 import ( "encoding/json" + "errors" "fmt" "math/big" "reflect" @@ -49,7 +50,7 @@ func (f *Forensics) SetCommittedQCs(headers []types.Header, incomingQC types.Quo // highestCommitQCs is an array, assign the parentBlockQc and its child as well as its grandchild QC into this array for forensics purposes. if len(headers) != NUM_OF_FORENSICS_QC-1 { log.Error("[SetCommittedQcs] Received input length not equal to 2", len(headers)) - return fmt.Errorf("received headers length not equal to 2 ") + return errors.New("received headers length not equal to 2 ") } var committedQCs []types.QuorumCert @@ -64,11 +65,11 @@ func (f *Forensics) SetCommittedQCs(headers []types.Header, incomingQC types.Quo if i != 0 { if decodedExtraField.QuorumCert.ProposedBlockInfo.Hash != headers[i-1].Hash() { log.Error("[SetCommittedQCs] Headers shall be on the same chain and in the right order", "parentHash", h.ParentHash.Hex(), "headers[i-1].Hash()", headers[i-1].Hash().Hex()) - return fmt.Errorf("headers shall be on the same chain and in the right order") + return errors.New("headers shall be on the same chain and in the right order") } else if i == len(headers)-1 { // The last header shall be pointed by the incoming QC if incomingQC.ProposedBlockInfo.Hash != h.Hash() { log.Error("[SetCommittedQCs] incomingQc is not pointing at the last header received", "hash", h.Hash().Hex(), "incomingQC.ProposedBlockInfo.Hash", incomingQC.ProposedBlockInfo.Hash.Hex()) - return fmt.Errorf("incomingQc is not pointing at the last header received") + return errors.New("incomingQc is not pointing at the last header received") } } } @@ -91,7 +92,7 @@ func (f *Forensics) ProcessForensics(chain consensus.ChainReader, engine *XDPoS_ highestCommittedQCs := f.HighestCommittedQCs if len(highestCommittedQCs) != NUM_OF_FORENSICS_QC { log.Error("[ProcessForensics] HighestCommittedQCs value not set", "incomingQcProposedBlockHash", incomingQC.ProposedBlockInfo.Hash, "incomingQcProposedBlockNumber", incomingQC.ProposedBlockInfo.Number.Uint64(), "incomingQcProposedBlockRound", incomingQC.ProposedBlockInfo.Round) - return fmt.Errorf("HighestCommittedQCs value not set") + return errors.New("HighestCommittedQCs value not set") } // Find the QC1 and QC2. We only care 2 parents in front of the incomingQC. The returned value contains QC1, QC2 and QC3(the incomingQC) incomingQuorunCerts, err := f.findAncestorQCs(chain, incomingQC, 2) @@ -163,7 +164,7 @@ func (f *Forensics) SendForensicProof(chain consensus.ChainReader, engine *XDPoS if ancestorBlock == nil { log.Error("[SendForensicProof] Unable to find the ancestor block by its hash", "Hash", ancestorHash) - return fmt.Errorf("Can't find ancestor block via hash") + return errors.New("Can't find ancestor block via hash") } content, err := json.Marshal(&types.ForensicsContent{ @@ -209,7 +210,7 @@ func (f *Forensics) findAncestorQCs(chain consensus.ChainReader, currentQc types parentHeader := chain.GetHeaderByHash(parentHash) if parentHeader == nil { log.Error("[findAncestorQCs] Forensics findAncestorQCs unable to find its parent block header", "BlockNum", parentHeader.Number.Int64(), "ParentHash", parentHash.Hex()) - return nil, fmt.Errorf("unable to find parent block header in forensics") + return nil, errors.New("unable to find parent block header in forensics") } var decodedExtraField types.ExtraFields_v2 err := utils.DecodeBytesExtraFields(parentHeader.Extra, &decodedExtraField) @@ -318,7 +319,7 @@ func (f *Forensics) findAncestorQcThroughRound(chain consensus.ChainReader, high } ancestorQC = *decodedExtraField.QuorumCert } - return ancestorQC, lowerRoundQCs, higherRoundQCs, fmt.Errorf("[findAncestorQcThroughRound] Could not find ancestor QC") + return ancestorQC, lowerRoundQCs, higherRoundQCs, errors.New("[findAncestorQcThroughRound] Could not find ancestor QC") } func (f *Forensics) FindAncestorBlockHash(chain consensus.ChainReader, firstBlockInfo *types.BlockInfo, secondBlockInfo *types.BlockInfo) (common.Hash, []string, []string, error) { @@ -398,7 +399,7 @@ func (f *Forensics) ProcessVoteEquivocation(chain consensus.ChainReader, engine highestCommittedQCs := f.HighestCommittedQCs if len(highestCommittedQCs) != NUM_OF_FORENSICS_QC { log.Error("[ProcessVoteEquivocation] HighestCommittedQCs value not set", "incomingVoteProposedBlockHash", incomingVote.ProposedBlockInfo.Hash, "incomingVoteProposedBlockNumber", incomingVote.ProposedBlockInfo.Number.Uint64(), "incomingVoteProposedBlockRound", incomingVote.ProposedBlockInfo.Round) - return fmt.Errorf("HighestCommittedQCs value not set") + return errors.New("HighestCommittedQCs value not set") } if incomingVote.ProposedBlockInfo.Round < highestCommittedQCs[NUM_OF_FORENSICS_QC-1].ProposedBlockInfo.Round { log.Debug("Received a too old vote in forensics", "vote", incomingVote) diff --git a/consensus/XDPoS/engines/engine_v2/snapshot.go b/consensus/XDPoS/engines/engine_v2/snapshot.go index 78ce0aff55..ccae598412 100644 --- a/consensus/XDPoS/engines/engine_v2/snapshot.go +++ b/consensus/XDPoS/engines/engine_v2/snapshot.go @@ -64,7 +64,7 @@ func (s *SnapshotV2) GetMappedCandidates() map[common.Address]struct{} { func (s *SnapshotV2) IsCandidates(address common.Address) bool { for _, n := range s.NextEpochCandidates { - if n.String() == address.String() { + if n == address { return true } } diff --git a/consensus/XDPoS/engines/engine_v2/timeout.go b/consensus/XDPoS/engines/engine_v2/timeout.go index 39d8100b24..2f077b2a00 100644 --- a/consensus/XDPoS/engines/engine_v2/timeout.go +++ b/consensus/XDPoS/engines/engine_v2/timeout.go @@ -1,6 +1,7 @@ package engine_v2 import ( + "errors" "fmt" "strconv" "strings" @@ -99,7 +100,7 @@ func (x *XDPoS_v2) verifyTC(chain consensus.ChainReader, timeoutCert *types.Time } if snap == nil || len(snap.NextEpochCandidates) == 0 { log.Error("[verifyTC] Something wrong with the snapshot from gapNumber", "messageGapNumber", timeoutCert.GapNumber, "snapshot", snap) - return fmt.Errorf("empty master node lists from snapshot") + return errors.New("empty master node lists from snapshot") } signatures, duplicates := UniqueSignatures(timeoutCert.Signatures) @@ -145,7 +146,7 @@ func (x *XDPoS_v2) verifyTC(chain consensus.ChainReader, timeoutCert *types.Time haveError = fmt.Errorf("error while verifying TC message signatures, %s", err) } else { log.Warn("[verifyTC] Signature not verified doing TC verification", "timeoutCert.Round", timeoutCert.Round, "timeoutCert.GapNumber", timeoutCert.GapNumber, "Signatures len", len(signatures)) - haveError = fmt.Errorf("fail to verify TC due to signature mis-match") + haveError = errors.New("fail to verify TC due to signature mis-match") } } mutex.Unlock() // Unlock after modifying haveError diff --git a/consensus/XDPoS/engines/engine_v2/utils.go b/consensus/XDPoS/engines/engine_v2/utils.go index e8006951ea..2019881628 100644 --- a/consensus/XDPoS/engines/engine_v2/utils.go +++ b/consensus/XDPoS/engines/engine_v2/utils.go @@ -1,6 +1,7 @@ package engine_v2 import ( + "errors" "fmt" "github.com/XinFinOrg/XDPoSChain/accounts" @@ -105,7 +106,7 @@ func (x *XDPoS_v2) signSignature(signingHash common.Hash) (types.Signature, erro func (x *XDPoS_v2) verifyMsgSignature(signedHashToBeVerified common.Hash, signature types.Signature, masternodes []common.Address) (bool, common.Address, error) { var signerAddress common.Address if len(masternodes) == 0 { - return false, signerAddress, fmt.Errorf("Empty masternode list detected when verifying message signatures") + return false, signerAddress, errors.New("Empty masternode list detected when verifying message signatures") } // Recover the public key and the Ethereum address pubkey, err := crypto.Ecrecover(signedHashToBeVerified.Bytes(), signature) diff --git a/consensus/XDPoS/engines/engine_v2/vote.go b/consensus/XDPoS/engines/engine_v2/vote.go index 1ec2d4b24e..de79585297 100644 --- a/consensus/XDPoS/engines/engine_v2/vote.go +++ b/consensus/XDPoS/engines/engine_v2/vote.go @@ -1,6 +1,7 @@ package engine_v2 import ( + "errors" "fmt" "math/big" "strconv" @@ -78,7 +79,7 @@ func (x *XDPoS_v2) voteHandler(chain consensus.ChainReader, voteMsg *types.Vote) epochInfo, err := x.getEpochSwitchInfo(chain, chain.CurrentHeader(), chain.CurrentHeader().Hash()) if err != nil { log.Error("[voteHandler] Error when getting epoch switch Info", "error", err) - return fmt.Errorf("Fail on voteHandler due to failure in getting epoch switch info") + return errors.New("Fail on voteHandler due to failure in getting epoch switch info") } certThreshold := x.config.V2.Config(uint64(voteMsg.ProposedBlockInfo.Round)).CertThreshold @@ -177,7 +178,7 @@ func (x *XDPoS_v2) onVotePoolThresholdReached(chain consensus.ChainReader, poole epochInfo, err := x.getEpochSwitchInfo(chain, chain.CurrentHeader(), chain.CurrentHeader().Hash()) if err != nil { log.Error("[voteHandler] Error when getting epoch switch Info", "error", err) - return fmt.Errorf("Fail on voteHandler due to failure in getting epoch switch info") + return errors.New("Fail on voteHandler due to failure in getting epoch switch info") } // Skip and wait for the next vote to process again if valid votes is less than what we required diff --git a/consensus/XDPoS/utils/utils.go b/consensus/XDPoS/utils/utils.go index 1b4e816dd1..62b394c836 100644 --- a/consensus/XDPoS/utils/utils.go +++ b/consensus/XDPoS/utils/utils.go @@ -2,6 +2,7 @@ package utils import ( "bytes" + "errors" "fmt" "reflect" "sort" @@ -79,7 +80,7 @@ func CompareSignersLists(list1 []common.Address, list2 []common.Address) bool { // Decode extra fields for consensus version >= 2 (XDPoS 2.0 and future versions) func DecodeBytesExtraFields(b []byte, val interface{}) error { if len(b) == 0 { - return fmt.Errorf("extra field is 0 length") + return errors.New("extra field is 0 length") } switch b[0] { case 2: diff --git a/consensus/tests/engine_v1_tests/helper.go b/consensus/tests/engine_v1_tests/helper.go index 3b777e487b..52baa648e0 100644 --- a/consensus/tests/engine_v1_tests/helper.go +++ b/consensus/tests/engine_v1_tests/helper.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/hex" + "errors" "fmt" "math/big" "math/rand" @@ -144,11 +145,11 @@ func getCommonBackend(t *testing.T, chainConfig *params.ChainConfig) *backends.S // create test backend with smart contract in it contractBackend2 := backends.NewXDCSimulatedBackend(core.GenesisAlloc{ - acc1Addr: {Balance: new(big.Int).SetUint64(10000000000)}, - acc2Addr: {Balance: new(big.Int).SetUint64(10000000000)}, - acc3Addr: {Balance: new(big.Int).SetUint64(10000000000)}, - voterAddr: {Balance: new(big.Int).SetUint64(10000000000)}, - common.HexToAddress(common.MasternodeVotingSMC): {Balance: new(big.Int).SetUint64(1), Code: code, Storage: storage}, // Binding the MasternodeVotingSMC with newly created 'code' for SC execution + acc1Addr: {Balance: new(big.Int).SetUint64(10000000000)}, + acc2Addr: {Balance: new(big.Int).SetUint64(10000000000)}, + acc3Addr: {Balance: new(big.Int).SetUint64(10000000000)}, + voterAddr: {Balance: new(big.Int).SetUint64(10000000000)}, + common.MasternodeVotingSMCBinary: {Balance: new(big.Int).SetUint64(1), Code: code, Storage: storage}, // Binding the MasternodeVotingSMC with newly created 'code' for SC execution }, 10000000, chainConfig) return contractBackend2 @@ -163,7 +164,7 @@ func transferTx(t *testing.T, to common.Address, transferAmount int64) *types.Tr amount := big.NewInt(transferAmount) nonce := uint64(1) tx := types.NewTransaction(nonce, to, amount, gasLimit, gasPrice, data) - signedTX, err := types.SignTx(tx, types.NewEIP155Signer(big.NewInt(chainID)), voterKey) + signedTX, err := types.SignTx(tx, types.LatestSignerForChainID(big.NewInt(chainID)), voterKey) if err != nil { t.Fatal(err) } @@ -178,12 +179,12 @@ func voteTX(gasLimit uint64, nonce uint64, addr string) (*types.Transaction, err amountInt := new(big.Int) amount, ok := amountInt.SetString("60000", 10) if !ok { - return nil, fmt.Errorf("big int init failed") + return nil, errors.New("big int init failed") } - to := common.HexToAddress(common.MasternodeVotingSMC) + to := common.MasternodeVotingSMCBinary tx := types.NewTransaction(nonce, to, amount, gasLimit, gasPrice, data) - signedTX, err := types.SignTx(tx, types.NewEIP155Signer(big.NewInt(chainID)), voterKey) + signedTX, err := types.SignTx(tx, types.LatestSignerForChainID(big.NewInt(chainID)), voterKey) if err != nil { return nil, err } @@ -213,7 +214,7 @@ func GetSnapshotSigner(bc *BlockChain, header *types.Header) (signersList, error } func GetCandidateFromCurrentSmartContract(backend bind.ContractBackend, t *testing.T) masterNodes { - addr := common.HexToAddress(common.MasternodeVotingSMC) + addr := common.MasternodeVotingSMCBinary validator, err := contractValidator.NewXDCValidator(addr, backend) if err != nil { t.Fatal(err) @@ -321,7 +322,7 @@ func CreateBlock(blockchain *BlockChain, chainConfig *params.ChainConfig, starti // Sign all the things for v1 block use v1 sigHash function sighash, err := signFn(accounts.Account{Address: signer}, blockchain.Engine().(*XDPoS.XDPoS).SigHash(header).Bytes()) if err != nil { - panic(fmt.Errorf("Error when sign last v1 block hash during test block creation")) + panic(errors.New("Error when sign last v1 block hash during test block creation")) } copy(header.Extra[len(header.Extra)-utils.ExtraSeal:], sighash) } @@ -411,7 +412,7 @@ func createBlockFromHeader(bc *BlockChain, customHeader *types.Header, txs []*ty // nonce := uint64(0) // to := common.HexToAddress("xdc35658f7b2a9e7701e65e7a654659eb1c481d1dc5") // tx := types.NewTransaction(nonce, to, amount, gasLimit, gasPrice, data) -// signedTX, err := types.SignTx(tx, types.NewEIP155Signer(big.NewInt(chainID)), acc4Key) +// signedTX, err := types.SignTx(tx, types.LatestSignerForChainID(big.NewInt(chainID)), acc4Key) // if err != nil { // t.Fatal(err) // } diff --git a/consensus/tests/engine_v2_tests/helper.go b/consensus/tests/engine_v2_tests/helper.go index f11baf1b2d..454f482d71 100644 --- a/consensus/tests/engine_v2_tests/helper.go +++ b/consensus/tests/engine_v2_tests/helper.go @@ -5,6 +5,7 @@ import ( "context" "crypto/ecdsa" "encoding/hex" + "errors" "fmt" "math/big" "math/rand" @@ -103,12 +104,12 @@ func voteTX(gasLimit uint64, nonce uint64, addr string) (*types.Transaction, err amountInt := new(big.Int) amount, ok := amountInt.SetString("60000", 10) if !ok { - return nil, fmt.Errorf("big int init failed") + return nil, errors.New("big int init failed") } - to := common.HexToAddress(common.MasternodeVotingSMC) + to := common.MasternodeVotingSMCBinary tx := types.NewTransaction(nonce, to, amount, gasLimit, gasPrice, data) - signedTX, err := types.SignTx(tx, types.NewEIP155Signer(big.NewInt(chainID)), voterKey) + signedTX, err := types.SignTx(tx, types.LatestSignerForChainID(big.NewInt(chainID)), voterKey) if err != nil { return nil, err } @@ -189,11 +190,11 @@ func getCommonBackend(t *testing.T, chainConfig *params.ChainConfig) *backends.S // create test backend with smart contract in it contractBackend2 := backends.NewXDCSimulatedBackend(core.GenesisAlloc{ - acc1Addr: {Balance: new(big.Int).SetUint64(10000000000)}, - acc2Addr: {Balance: new(big.Int).SetUint64(10000000000)}, - acc3Addr: {Balance: new(big.Int).SetUint64(10000000000)}, - voterAddr: {Balance: new(big.Int).SetUint64(10000000000)}, - common.HexToAddress(common.MasternodeVotingSMC): {Balance: new(big.Int).SetUint64(1), Code: code, Storage: storage}, // Binding the MasternodeVotingSMC with newly created 'code' for SC execution + acc1Addr: {Balance: new(big.Int).SetUint64(10000000000)}, + acc2Addr: {Balance: new(big.Int).SetUint64(10000000000)}, + acc3Addr: {Balance: new(big.Int).SetUint64(10000000000)}, + voterAddr: {Balance: new(big.Int).SetUint64(10000000000)}, + common.MasternodeVotingSMCBinary: {Balance: new(big.Int).SetUint64(1), Code: code, Storage: storage}, // Binding the MasternodeVotingSMC with newly created 'code' for SC execution }, 10000000, chainConfig) return contractBackend2 @@ -273,19 +274,19 @@ func getMultiCandidatesBackend(t *testing.T, chainConfig *params.ChainConfig, n // create test backend with smart contract in it contractBackend2 := backends.NewXDCSimulatedBackend(core.GenesisAlloc{ - acc1Addr: {Balance: new(big.Int).SetUint64(10000000000)}, - acc2Addr: {Balance: new(big.Int).SetUint64(10000000000)}, - acc3Addr: {Balance: new(big.Int).SetUint64(10000000000)}, - voterAddr: {Balance: new(big.Int).SetUint64(10000000000)}, - common.HexToAddress(common.MasternodeVotingSMC): {Balance: new(big.Int).SetUint64(1), Code: code, Storage: storage}, // Binding the MasternodeVotingSMC with newly created 'code' for SC execution + acc1Addr: {Balance: new(big.Int).SetUint64(10000000000)}, + acc2Addr: {Balance: new(big.Int).SetUint64(10000000000)}, + acc3Addr: {Balance: new(big.Int).SetUint64(10000000000)}, + voterAddr: {Balance: new(big.Int).SetUint64(10000000000)}, + common.MasternodeVotingSMCBinary: {Balance: new(big.Int).SetUint64(1), Code: code, Storage: storage}, // Binding the MasternodeVotingSMC with newly created 'code' for SC execution }, 10000000, chainConfig) return contractBackend2 } func signingTxWithKey(header *types.Header, nonce uint64, privateKey *ecdsa.PrivateKey) (*types.Transaction, error) { - tx := contracts.CreateTxSign(header.Number, header.Hash(), nonce, common.HexToAddress(common.BlockSigners)) - s := types.NewEIP155Signer(big.NewInt(chainID)) + tx := contracts.CreateTxSign(header.Number, header.Hash(), nonce, common.BlockSignersBinary) + s := types.LatestSignerForChainID(big.NewInt(chainID)) h := s.Hash(tx) sig, err := crypto.Sign(h[:], privateKey) if err != nil { @@ -299,8 +300,8 @@ func signingTxWithKey(header *types.Header, nonce uint64, privateKey *ecdsa.Priv } func signingTxWithSignerFn(header *types.Header, nonce uint64, signer common.Address, signFn func(account accounts.Account, hash []byte) ([]byte, error)) (*types.Transaction, error) { - tx := contracts.CreateTxSign(header.Number, header.Hash(), nonce, common.HexToAddress(common.BlockSigners)) - s := types.NewEIP155Signer(big.NewInt(chainID)) + tx := contracts.CreateTxSign(header.Number, header.Hash(), nonce, common.BlockSignersBinary) + s := types.LatestSignerForChainID(big.NewInt(chainID)) h := s.Hash(tx) sig, err := signFn(accounts.Account{Address: signer}, h[:]) if err != nil { @@ -335,7 +336,7 @@ func GetSnapshotSigner(bc *BlockChain, header *types.Header) (signersList, error } func GetCandidateFromCurrentSmartContract(backend bind.ContractBackend, t *testing.T) masterNodes { - addr := common.HexToAddress(common.MasternodeVotingSMC) + addr := common.MasternodeVotingSMCBinary validator, err := contractValidator.NewXDCValidator(addr, backend) if err != nil { t.Fatal(err) @@ -633,7 +634,7 @@ func CreateBlock(blockchain *BlockChain, chainConfig *params.ChainConfig, starti // Sign all the things for v1 block use v1 sigHash function sighash, err := signFn(accounts.Account{Address: signer}, blockchain.Engine().(*XDPoS.XDPoS).SigHash(header).Bytes()) if err != nil { - panic(fmt.Errorf("Error when sign last v1 block hash during test block creation")) + panic(errors.New("Error when sign last v1 block hash during test block creation")) } copy(header.Extra[len(header.Extra)-utils.ExtraSeal:], sighash) } @@ -737,7 +738,7 @@ func findSignerAndSignFn(bc *BlockChain, header *types.Header, signer common.Add var decodedExtraField types.ExtraFields_v2 err := utils.DecodeBytesExtraFields(header.Extra, &decodedExtraField) if err != nil { - panic(fmt.Errorf("fail to seal header for v2 block")) + panic(errors.New("fail to seal header for v2 block")) } round := decodedExtraField.Round masterNodes := getMasternodesList(signer) @@ -757,7 +758,7 @@ func findSignerAndSignFn(bc *BlockChain, header *types.Header, signer common.Add } addressedSignFn = signFn if err != nil { - panic(fmt.Errorf("Error trying to use one of the pre-defined private key to sign")) + panic(errors.New("Error trying to use one of the pre-defined private key to sign")) } } diff --git a/console/bridge.go b/console/bridge.go index 80606a6878..80f8095336 100644 --- a/console/bridge.go +++ b/console/bridge.go @@ -18,6 +18,7 @@ package console import ( "encoding/json" + "errors" "fmt" "io" "reflect" @@ -75,18 +76,18 @@ func (b *bridge) NewAccount(call jsre.Call) (goja.Value, error) { return nil, err } if password != confirm { - return nil, fmt.Errorf("passwords don't match!") + return nil, errors.New("passwords don't match!") } // A single string password was specified, use that case len(call.Arguments) == 1 && call.Argument(0).ToString() != nil: password = call.Argument(0).ToString().String() default: - return nil, fmt.Errorf("expected 0 or 1 string argument") + return nil, errors.New("expected 0 or 1 string argument") } // Password acquired, execute the call and return newAccount, callable := goja.AssertFunction(getJeth(call.VM).Get("newAccount")) if !callable { - return nil, fmt.Errorf("jeth.newAccount is not callable") + return nil, errors.New("jeth.newAccount is not callable") } ret, err := newAccount(goja.Null(), call.VM.ToValue(password)) if err != nil { @@ -100,7 +101,7 @@ func (b *bridge) NewAccount(call jsre.Call) (goja.Value, error) { func (b *bridge) OpenWallet(call jsre.Call) (goja.Value, error) { // Make sure we have a wallet specified to open if call.Argument(0).ToObject(call.VM).ClassName() != "String" { - return nil, fmt.Errorf("first argument must be the wallet URL to open") + return nil, errors.New("first argument must be the wallet URL to open") } wallet := call.Argument(0) @@ -113,7 +114,7 @@ func (b *bridge) OpenWallet(call jsre.Call) (goja.Value, error) { // Open the wallet and return if successful in itself openWallet, callable := goja.AssertFunction(getJeth(call.VM).Get("openWallet")) if !callable { - return nil, fmt.Errorf("jeth.openWallet is not callable") + return nil, errors.New("jeth.openWallet is not callable") } val, err := openWallet(goja.Null(), wallet, passwd) if err == nil { @@ -147,7 +148,7 @@ func (b *bridge) readPassphraseAndReopenWallet(call jsre.Call) (goja.Value, erro } openWallet, callable := goja.AssertFunction(getJeth(call.VM).Get("openWallet")) if !callable { - return nil, fmt.Errorf("jeth.openWallet is not callable") + return nil, errors.New("jeth.openWallet is not callable") } return openWallet(goja.Null(), wallet, call.VM.ToValue(input)) } @@ -168,7 +169,7 @@ func (b *bridge) readPinAndReopenWallet(call jsre.Call) (goja.Value, error) { } openWallet, callable := goja.AssertFunction(getJeth(call.VM).Get("openWallet")) if !callable { - return nil, fmt.Errorf("jeth.openWallet is not callable") + return nil, errors.New("jeth.openWallet is not callable") } return openWallet(goja.Null(), wallet, call.VM.ToValue(input)) } @@ -180,7 +181,7 @@ func (b *bridge) readPinAndReopenWallet(call jsre.Call) (goja.Value, error) { func (b *bridge) UnlockAccount(call jsre.Call) (goja.Value, error) { // Make sure we have an account specified to unlock. if call.Argument(0).ExportType().Kind() != reflect.String { - return nil, fmt.Errorf("first argument must be the account to unlock") + return nil, errors.New("first argument must be the account to unlock") } account := call.Argument(0) @@ -195,7 +196,7 @@ func (b *bridge) UnlockAccount(call jsre.Call) (goja.Value, error) { passwd = call.VM.ToValue(input) } else { if call.Argument(1).ExportType().Kind() != reflect.String { - return nil, fmt.Errorf("password must be a string") + return nil, errors.New("password must be a string") } passwd = call.Argument(1) } @@ -204,7 +205,7 @@ func (b *bridge) UnlockAccount(call jsre.Call) (goja.Value, error) { duration := goja.Null() if !goja.IsUndefined(call.Argument(2)) && !goja.IsNull(call.Argument(2)) { if !isNumber(call.Argument(2)) { - return nil, fmt.Errorf("unlock duration must be a number") + return nil, errors.New("unlock duration must be a number") } duration = call.Argument(2) } @@ -212,7 +213,7 @@ func (b *bridge) UnlockAccount(call jsre.Call) (goja.Value, error) { // Send the request to the backend and return. unlockAccount, callable := goja.AssertFunction(getJeth(call.VM).Get("unlockAccount")) if !callable { - return nil, fmt.Errorf("jeth.unlockAccount is not callable") + return nil, errors.New("jeth.unlockAccount is not callable") } return unlockAccount(goja.Null(), account, passwd, duration) } @@ -228,10 +229,10 @@ func (b *bridge) Sign(call jsre.Call) (goja.Value, error) { ) if message.ExportType().Kind() != reflect.String { - return nil, fmt.Errorf("first argument must be the message to sign") + return nil, errors.New("first argument must be the message to sign") } if account.ExportType().Kind() != reflect.String { - return nil, fmt.Errorf("second argument must be the account to sign with") + return nil, errors.New("second argument must be the account to sign with") } // if the password is not given or null ask the user and ensure password is a string @@ -243,13 +244,13 @@ func (b *bridge) Sign(call jsre.Call) (goja.Value, error) { } passwd = call.VM.ToValue(input) } else if passwd.ExportType().Kind() != reflect.String { - return nil, fmt.Errorf("third argument must be the password to unlock the account") + return nil, errors.New("third argument must be the password to unlock the account") } // Send the request to the backend and return sign, callable := goja.AssertFunction(getJeth(call.VM).Get("unlockAccount")) if !callable { - return nil, fmt.Errorf("jeth.unlockAccount is not callable") + return nil, errors.New("jeth.unlockAccount is not callable") } return sign(goja.Null(), message, account, passwd) } @@ -257,7 +258,7 @@ func (b *bridge) Sign(call jsre.Call) (goja.Value, error) { // Sleep will block the console for the specified number of seconds. func (b *bridge) Sleep(call jsre.Call) (goja.Value, error) { if !isNumber(call.Argument(0)) { - return nil, fmt.Errorf("usage: sleep()") + return nil, errors.New("usage: sleep()") } sleep := call.Argument(0).ToFloat() time.Sleep(time.Duration(sleep * float64(time.Second))) @@ -274,17 +275,17 @@ func (b *bridge) SleepBlocks(call jsre.Call) (goja.Value, error) { ) nArgs := len(call.Arguments) if nArgs == 0 { - return nil, fmt.Errorf("usage: sleepBlocks([, max sleep in seconds])") + return nil, errors.New("usage: sleepBlocks([, max sleep in seconds])") } if nArgs >= 1 { if !isNumber(call.Argument(0)) { - return nil, fmt.Errorf("expected number as first argument") + return nil, errors.New("expected number as first argument") } blocks = call.Argument(0).ToInteger() } if nArgs >= 2 { if isNumber(call.Argument(1)) { - return nil, fmt.Errorf("expected number as second argument") + return nil, errors.New("expected number as second argument") } sleep = call.Argument(1).ToInteger() } @@ -361,7 +362,7 @@ func (b *bridge) Send(call jsre.Call) (goja.Value, error) { JSON := call.VM.Get("JSON").ToObject(call.VM) parse, callable := goja.AssertFunction(JSON.Get("parse")) if !callable { - return nil, fmt.Errorf("JSON.parse is not a function") + return nil, errors.New("JSON.parse is not a function") } resultVal, err := parse(goja.Null(), call.VM.ToValue(string(result))) if err != nil { diff --git a/contracts/trc21issuer/trc21issuer_test.go b/contracts/trc21issuer/trc21issuer_test.go index e0744b26d1..698aa3f97d 100644 --- a/contracts/trc21issuer/trc21issuer_test.go +++ b/contracts/trc21issuer/trc21issuer_test.go @@ -91,10 +91,13 @@ func TestFeeTxWithTRC21Token(t *testing.T) { t.Fatal("check balance after fail transfer in tr20: ", err, "get", balance, "transfer", airDropAmount) } - //check balance fee + // check balance fee balanceIssuerFee, err = trc21Issuer.GetTokenCapacity(trc21TokenAddr) - if err != nil || balanceIssuerFee.Cmp(remainFee) != 0 { - t.Fatal("can't get balance token fee in smart contract: ", err, "got", balanceIssuerFee, "wanted", remainFee) + if err != nil { + t.Fatal("can't get balance token fee in smart contract: ", err) + } + if balanceIssuerFee.Cmp(remainFee) != 0 { + t.Fatal("check balance token fee in smart contract: got", balanceIssuerFee, "wanted", remainFee) } //check trc21 SMC balance balance, err = contractBackend.BalanceAt(nil, trc21IssuerAddr, nil) diff --git a/contracts/utils.go b/contracts/utils.go index 8f605a5e86..6aba084f9a 100644 --- a/contracts/utils.go +++ b/contracts/utils.go @@ -87,7 +87,7 @@ func CreateTransactionSign(chainConfig *params.ChainConfig, pool *core.TxPool, m // Create and send tx to smart contract for sign validate block. nonce := pool.Nonce(account.Address) - tx := CreateTxSign(block.Number(), block.Hash(), nonce, common.HexToAddress(common.BlockSigners)) + tx := CreateTxSign(block.Number(), block.Hash(), nonce, common.BlockSignersBinary) txSigned, err := wallet.SignTx(account, tx, chainConfig.ChainId) if err != nil { log.Error("Fail to create tx sign", "error", err) @@ -112,7 +112,7 @@ func CreateTransactionSign(chainConfig *params.ChainConfig, pool *core.TxPool, m // Only process when private key empty in state db. // Save randomize key into state db. randomizeKeyValue := RandStringByte(32) - tx, err := BuildTxSecretRandomize(nonce+1, common.HexToAddress(common.RandomizeSMC), chainConfig.XDPoS.Epoch, randomizeKeyValue) + tx, err := BuildTxSecretRandomize(nonce+1, common.RandomizeSMCBinary, chainConfig.XDPoS.Epoch, randomizeKeyValue) if err != nil { log.Error("Fail to get tx opening for randomize", "error", err) return err @@ -141,7 +141,7 @@ func CreateTransactionSign(chainConfig *params.ChainConfig, pool *core.TxPool, m return err } - tx, err := BuildTxOpeningRandomize(nonce+1, common.HexToAddress(common.RandomizeSMC), randomizeKeyValue) + tx, err := BuildTxOpeningRandomize(nonce+1, common.RandomizeSMCBinary, randomizeKeyValue) if err != nil { log.Error("Fail to get tx opening for randomize", "error", err) return err @@ -232,7 +232,7 @@ func GetSignersByExecutingEVM(addrBlockSigner common.Address, client bind.Contra // Get random from randomize contract. func GetRandomizeFromContract(client bind.ContractBackend, addrMasternode common.Address) (int64, error) { - randomize, err := randomizeContract.NewXDCRandomize(common.HexToAddress(common.RandomizeSMC), client) + randomize, err := randomizeContract.NewXDCRandomize(common.RandomizeSMCBinary, client) if err != nil { log.Error("Fail to get instance of randomize", "error", err) } diff --git a/contracts/validator/validator_test.go b/contracts/validator/validator_test.go index 047f34deac..6ebe924b5c 100644 --- a/contracts/validator/validator_test.go +++ b/contracts/validator/validator_test.go @@ -150,7 +150,7 @@ func TestRewardBalance(t *testing.T) { logCaps[i] = &logCap{accounts[randIndex].From.String(), randCap} } - foundationAddr := common.HexToAddress(common.FoudationAddr) + foundationAddr := common.FoudationAddrBinary totalReward := new(big.Int).SetInt64(15 * 1000) rewards, err := GetRewardBalancesRate(foundationAddr, acc3Addr, totalReward, baseValidator) if err != nil { @@ -309,13 +309,13 @@ func TestStatedbUtils(t *testing.T) { return true } contractBackend.ForEachStorageAt(ctx, validatorAddress, nil, f) - genesisAlloc[common.HexToAddress(common.MasternodeVotingSMC)] = core.GenesisAccount{ + genesisAlloc[common.MasternodeVotingSMCBinary] = core.GenesisAccount{ Balance: validatorCap, Code: code, Storage: storage, } contractBackendForValidator := backends.NewXDCSimulatedBackend(genesisAlloc, 10000000, params.TestXDPoSMockChainConfig) - validator, err := NewValidator(transactOpts, common.HexToAddress(common.MasternodeVotingSMC), contractBackendForValidator) + validator, err := NewValidator(transactOpts, common.MasternodeVotingSMCBinary, contractBackendForValidator) if err != nil { t.Fatalf("can't get validator object: %v", err) } @@ -379,4 +379,4 @@ func TestStatedbUtils(t *testing.T) { t.Fatalf("cap should not be zero") } } -} \ No newline at end of file +} diff --git a/core/bench_test.go b/core/bench_test.go index 5884292402..7cfed07f45 100644 --- a/core/bench_test.go +++ b/core/bench_test.go @@ -85,7 +85,7 @@ func genValueTx(nbytes int) func(int, *BlockGen) { return func(i int, gen *BlockGen) { toaddr := common.Address{} data := make([]byte, nbytes) - gas, _ := IntrinsicGas(data, false, false) + gas, _ := IntrinsicGas(data, nil, false, false) tx, _ := types.SignTx(types.NewTransaction(gen.TxNonce(benchRootAddr), toaddr, big.NewInt(1), gas, nil, data), types.HomesteadSigner{}, benchRootKey) gen.AddTx(tx) } diff --git a/core/block_validator.go b/core/block_validator.go index a4144aa495..e713342d9c 100644 --- a/core/block_validator.go +++ b/core/block_validator.go @@ -17,6 +17,7 @@ package core import ( + "errors" "fmt" "github.com/XinFinOrg/XDPoSChain/XDCx/tradingstate" @@ -113,7 +114,7 @@ func (v *BlockValidator) ValidateTradingOrder(statedb *state.StateDB, XDCxStated } XDCXService := XDPoSEngine.GetXDCXService() if XDCXService == nil { - return fmt.Errorf("XDCx not found") + return errors.New("XDCx not found") } log.Debug("verify matching transaction found a TxMatches Batch", "numTxMatches", len(txMatchBatch.Data)) tradingResult := map[common.Hash]tradingstate.MatchingResult{} @@ -149,11 +150,11 @@ func (v *BlockValidator) ValidateLendingOrder(statedb *state.StateDB, lendingSta } XDCXService := XDPoSEngine.GetXDCXService() if XDCXService == nil { - return fmt.Errorf("XDCx not found") + return errors.New("XDCx not found") } lendingService := XDPoSEngine.GetLendingService() if lendingService == nil { - return fmt.Errorf("lendingService not found") + return errors.New("lendingService not found") } log.Debug("verify lendingItem ", "numItems", len(batch.Data)) lendingResult := map[common.Hash]lendingstate.MatchingResult{} diff --git a/core/blockchain.go b/core/blockchain.go index 090a52b4bf..927e4628df 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -58,6 +58,11 @@ var ( blockInsertTimer = metrics.NewRegisteredTimer("chain/inserts", nil) CheckpointCh = make(chan int) ErrNoGenesis = errors.New("Genesis not found in chain") + + blockReorgMeter = metrics.NewRegisteredMeter("chain/reorg/executes", nil) + blockReorgAddMeter = metrics.NewRegisteredMeter("chain/reorg/add", nil) + blockReorgDropMeter = metrics.NewRegisteredMeter("chain/reorg/drop", nil) + blockReorgInvalidatedTx = metrics.NewRegisteredMeter("chain/reorg/invalidTx", nil) ) const ( @@ -252,6 +257,11 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par return bc, nil } +// GetVMConfig returns the block chain VM config. +func (bc *BlockChain) GetVMConfig() *vm.Config { + return &bc.vmConfig +} + // NewBlockChainEx extend old blockchain, add order state db func NewBlockChainEx(db ethdb.Database, XDCxDb ethdb.XDCxDatabase, cacheConfig *CacheConfig, chainConfig *params.ChainConfig, engine consensus.Engine, vmConfig vm.Config) (*BlockChain, error) { blockchain, err := NewBlockChain(db, cacheConfig, chainConfig, engine, vmConfig) @@ -2187,10 +2197,10 @@ func (bc *BlockChain) reorg(oldBlock, newBlock *types.Block) error { } } if oldBlock == nil { - return fmt.Errorf("Invalid old chain") + return errors.New("Invalid old chain") } if newBlock == nil { - return fmt.Errorf("Invalid new chain") + return errors.New("Invalid new chain") } for { @@ -2206,10 +2216,10 @@ func (bc *BlockChain) reorg(oldBlock, newBlock *types.Block) error { oldBlock, newBlock = bc.GetBlock(oldBlock.ParentHash(), oldBlock.NumberU64()-1), bc.GetBlock(newBlock.ParentHash(), newBlock.NumberU64()-1) if oldBlock == nil { - return fmt.Errorf("Invalid old chain") + return errors.New("Invalid old chain") } if newBlock == nil { - return fmt.Errorf("Invalid new chain") + return errors.New("Invalid new chain") } } // Ensure XDPoS engine committed block will be not reverted @@ -2245,6 +2255,9 @@ func (bc *BlockChain) reorg(oldBlock, newBlock *types.Block) error { } logFn("Chain split detected", "number", commonBlock.Number(), "hash", commonBlock.Hash(), "drop", len(oldChain), "dropfrom", oldChain[0].Hash(), "add", len(newChain), "addfrom", newChain[0].Hash()) + blockReorgAddMeter.Mark(int64(len(newChain))) + blockReorgDropMeter.Mark(int64(len(oldChain))) + blockReorgMeter.Mark(1) } else { log.Error("Impossible reorg, please file an issue", "oldnum", oldBlock.Number(), "oldhash", oldBlock.Hash(), "newnum", newBlock.Number(), "newhash", newBlock.Hash()) } @@ -2549,7 +2562,7 @@ func (bc *BlockChain) UpdateM1() error { if err != nil { return err } - addr := common.HexToAddress(common.MasternodeVotingSMC) + addr := common.MasternodeVotingSMCBinary validator, err := contractValidator.NewXDCValidator(addr, client) if err != nil { return err diff --git a/core/blockchain_test.go b/core/blockchain_test.go index 9dc4b65117..c19a68dcc1 100644 --- a/core/blockchain_test.go +++ b/core/blockchain_test.go @@ -561,7 +561,7 @@ func TestFastVsFullChains(t *testing.T) { Alloc: GenesisAlloc{address: {Balance: funds}}, } genesis = gspec.MustCommit(gendb) - signer = types.NewEIP155Signer(gspec.Config.ChainId) + signer = types.LatestSigner(gspec.Config) ) blocks, receipts := GenerateChain(gspec.Config, genesis, ethash.NewFaker(), gendb, 1024, func(i int, block *BlockGen) { block.SetCoinbase(common.Address{0x00}) @@ -736,7 +736,7 @@ func TestChainTxReorgs(t *testing.T) { }, } genesis = gspec.MustCommit(db) - signer = types.NewEIP155Signer(gspec.Config.ChainId) + signer = types.LatestSigner(gspec.Config) ) // Create two transactions shared between the chains: @@ -842,7 +842,7 @@ func TestLogReorgs(t *testing.T) { code = common.Hex2Bytes("60606040525b7f24ec1d3ff24c2f6ff210738839dbc339cd45a5294d85c79361016243157aae7b60405180905060405180910390a15b600a8060416000396000f360606040526008565b00") gspec = &Genesis{Config: params.TestChainConfig, Alloc: GenesisAlloc{addr1: {Balance: big.NewInt(10000000000000)}}} genesis = gspec.MustCommit(db) - signer = types.NewEIP155Signer(gspec.Config.ChainId) + signer = types.LatestSigner(gspec.Config) ) blockchain, _ := NewBlockChain(db, nil, gspec.Config, ethash.NewFaker(), vm.Config{}) @@ -889,7 +889,7 @@ func TestLogReorgs(t *testing.T) { // Alloc: GenesisAlloc{addr1: {Balance: big.NewInt(10000000000000)}}, // } // genesis = gspec.MustCommit(db) -// signer = types.NewEIP155Signer(gspec.Config.ChainId) +// signer = types.LatestSigner(gspec.Config) // ) // // blockchain, _ := NewBlockChain(db, nil, gspec.Config, ethash.NewFaker(), vm.Config{}) @@ -1015,7 +1015,7 @@ func TestEIP155Transition(t *testing.T) { funds = big.NewInt(1000000000) deleteAddr = common.Address{1} gspec = &Genesis{ - Config: ¶ms.ChainConfig{ChainId: big.NewInt(1), EIP155Block: big.NewInt(2), HomesteadBlock: new(big.Int)}, + Config: ¶ms.ChainConfig{ChainId: big.NewInt(1), EIP150Block: big.NewInt(0), EIP155Block: big.NewInt(2), HomesteadBlock: new(big.Int)}, Alloc: GenesisAlloc{address: {Balance: funds}, deleteAddr: {Balance: new(big.Int)}}, } genesis = gspec.MustCommit(db) @@ -1046,7 +1046,7 @@ func TestEIP155Transition(t *testing.T) { } block.AddTx(tx) - tx, err = basicTx(types.NewEIP155Signer(gspec.Config.ChainId)) + tx, err = basicTx(types.LatestSigner(gspec.Config)) if err != nil { t.Fatal(err) } @@ -1058,7 +1058,7 @@ func TestEIP155Transition(t *testing.T) { } block.AddTx(tx) - tx, err = basicTx(types.NewEIP155Signer(gspec.Config.ChainId)) + tx, err = basicTx(types.LatestSigner(gspec.Config)) if err != nil { t.Fatal(err) } @@ -1086,7 +1086,7 @@ func TestEIP155Transition(t *testing.T) { } // generate an invalid chain id transaction - config := ¶ms.ChainConfig{ChainId: big.NewInt(2), EIP155Block: big.NewInt(2), HomesteadBlock: new(big.Int)} + config := ¶ms.ChainConfig{ChainId: big.NewInt(2), EIP150Block: big.NewInt(0), EIP155Block: big.NewInt(2), HomesteadBlock: new(big.Int)} blocks, _ = GenerateChain(config, blocks[len(blocks)-1], ethash.NewFaker(), db, 4, func(i int, block *BlockGen) { var ( tx *types.Transaction @@ -1095,9 +1095,8 @@ func TestEIP155Transition(t *testing.T) { return types.SignTx(types.NewTransaction(block.TxNonce(address), common.Address{}, new(big.Int), 21000, new(big.Int), nil), signer, key) } ) - switch i { - case 0: - tx, err = basicTx(types.NewEIP155Signer(big.NewInt(2))) + if i == 0 { + tx, err = basicTx(types.LatestSigner(config)) if err != nil { t.Fatal(err) } @@ -1136,7 +1135,7 @@ func TestEIP161AccountRemoval(t *testing.T) { var ( tx *types.Transaction err error - signer = types.NewEIP155Signer(gspec.Config.ChainId) + signer = types.LatestSigner(gspec.Config) ) switch i { case 0: @@ -1167,7 +1166,7 @@ func TestEIP161AccountRemoval(t *testing.T) { t.Error("account should not exist") } - // account musn't be created post eip 161 + // account mustn't be created post eip 161 if _, err := blockchain.InsertChain(types.Blocks{blocks[2]}); err != nil { t.Fatal(err) } @@ -1412,3 +1411,93 @@ func TestAreTwoBlocksSamePath(t *testing.T) { }) } + +// TestEIP2718Transition tests that an EIP-2718 transaction will be accepted +// after the fork block has passed. This is verified by sending an EIP-2930 +// access list transaction, which specifies a single slot access, and then +// checking that the gas usage of a hot SLOAD and a cold SLOAD are calculated +// correctly. +func TestEIP2718Transition(t *testing.T) { + var ( + aa = common.HexToAddress("0x000000000000000000000000000000000000aaaa") + + // Generate a canonical chain to act as the main dataset + engine = ethash.NewFaker() + db = rawdb.NewMemoryDatabase() + + // A sender who makes transactions, has some funds + key, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") + address = crypto.PubkeyToAddress(key.PublicKey) + funds = big.NewInt(1000000000) + gspec = &Genesis{ + Config: ¶ms.ChainConfig{ + ChainId: new(big.Int).SetBytes([]byte("eip1559")), + HomesteadBlock: big.NewInt(0), + DAOForkBlock: nil, + DAOForkSupport: true, + EIP150Block: big.NewInt(0), + EIP155Block: big.NewInt(0), + EIP158Block: big.NewInt(0), + ByzantiumBlock: big.NewInt(0), + ConstantinopleBlock: big.NewInt(0), + PetersburgBlock: big.NewInt(0), + IstanbulBlock: big.NewInt(0), + Eip1559Block: big.NewInt(0), + }, + Alloc: GenesisAlloc{ + address: {Balance: funds}, + // The address 0xAAAA sloads 0x00 and 0x01 + aa: { + Code: []byte{ + byte(vm.PC), + byte(vm.PC), + byte(vm.SLOAD), + byte(vm.SLOAD), + }, + Nonce: 0, + Balance: big.NewInt(0), + }, + }, + } + genesis = gspec.MustCommit(db) + ) + + blocks, _ := GenerateChain(gspec.Config, genesis, engine, db, 1, func(i int, b *BlockGen) { + b.SetCoinbase(common.Address{1}) + + // One transaction to 0xAAAA + signer := types.LatestSigner(gspec.Config) + tx, _ := types.SignNewTx(key, signer, &types.AccessListTx{ + ChainID: gspec.Config.ChainId, + Nonce: 0, + To: &aa, + Gas: 30000, + GasPrice: big.NewInt(1), + AccessList: types.AccessList{{ + Address: aa, + StorageKeys: []common.Hash{{0}}, + }}, + }) + b.AddTx(tx) + }) + + // Import the canonical chain + diskdb := rawdb.NewMemoryDatabase() + gspec.MustCommit(diskdb) + + chain, err := NewBlockChain(diskdb, nil, gspec.Config, engine, vm.Config{}) + if err != nil { + t.Fatalf("failed to create tester chain: %v", err) + } + if n, err := chain.InsertChain(blocks); err != nil { + t.Fatalf("block %d: failed to insert into chain: %v", n, err) + } + + block := chain.GetBlockByNumber(1) + + // Expected gas is intrinsic + 2 * pc + hot load + cold load, since only one load is in the access list + expected := params.TxGas + params.TxAccessListAddressGas + params.TxAccessListStorageKeyGas + vm.GasQuickStep*2 + vm.WarmStorageReadCostEIP2929 + vm.ColdSloadCostEIP2929 + if block.GasUsed() != expected { + t.Fatalf("incorrect amount of gas spent: expected %d, got %d", expected, block.GasUsed()) + } +} diff --git a/core/chain_indexer.go b/core/chain_indexer.go index 7be107be85..dd6466ab37 100644 --- a/core/chain_indexer.go +++ b/core/chain_indexer.go @@ -18,6 +18,7 @@ package core import ( "encoding/binary" + "errors" "fmt" "sync" "sync/atomic" @@ -357,7 +358,7 @@ func (c *ChainIndexer) processSection(section uint64, lastHead common.Hash) (com if header == nil { return common.Hash{}, fmt.Errorf("block #%d [%x…] not found", number, hash[:4]) } else if header.ParentHash != lastHead { - return common.Hash{}, fmt.Errorf("chain reorged during section processing") + return common.Hash{}, errors.New("chain reorged during section processing") } c.backend.Process(header) lastHead = header.Hash() diff --git a/core/database_util_test.go b/core/database_util_test.go index a38a68fd41..ecd843a3e8 100644 --- a/core/database_util_test.go +++ b/core/database_util_test.go @@ -18,11 +18,11 @@ package core import ( "bytes" - "github.com/XinFinOrg/XDPoSChain/core/rawdb" "math/big" "testing" "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/core/rawdb" "github.com/XinFinOrg/XDPoSChain/core/types" "github.com/XinFinOrg/XDPoSChain/crypto/sha3" "github.com/XinFinOrg/XDPoSChain/rlp" @@ -335,6 +335,10 @@ func TestLookupStorage(t *testing.T) { func TestBlockReceiptStorage(t *testing.T) { db := rawdb.NewMemoryDatabase() + // Create a live block since we need metadata to reconstruct the receipt + tx1 := types.NewTransaction(1, common.HexToAddress("0x1"), big.NewInt(1), 1, big.NewInt(1), nil) + tx2 := types.NewTransaction(2, common.HexToAddress("0x2"), big.NewInt(2), 2, big.NewInt(2), nil) + receipt1 := &types.Receipt{ Status: types.ReceiptStatusFailed, CumulativeGasUsed: 1, @@ -342,10 +346,12 @@ func TestBlockReceiptStorage(t *testing.T) { {Address: common.BytesToAddress([]byte{0x11})}, {Address: common.BytesToAddress([]byte{0x01, 0x11})}, }, - TxHash: common.BytesToHash([]byte{0x11, 0x11}), + TxHash: tx1.Hash(), ContractAddress: common.BytesToAddress([]byte{0x01, 0x11, 0x11}), GasUsed: 111111, } + receipt1.Bloom = types.CreateBloom(types.Receipts{receipt1}) + receipt2 := &types.Receipt{ PostState: common.Hash{2}.Bytes(), CumulativeGasUsed: 2, @@ -353,10 +359,12 @@ func TestBlockReceiptStorage(t *testing.T) { {Address: common.BytesToAddress([]byte{0x22})}, {Address: common.BytesToAddress([]byte{0x02, 0x22})}, }, - TxHash: common.BytesToHash([]byte{0x22, 0x22}), + TxHash: tx2.Hash(), ContractAddress: common.BytesToAddress([]byte{0x02, 0x22, 0x22}), GasUsed: 222222, } + receipt2.Bloom = types.CreateBloom(types.Receipts{receipt2}) + receipts := []*types.Receipt{receipt1, receipt2} // Check that no receipt entries are in a pristine database diff --git a/core/error.go b/core/error.go index f6559bf06f..6268c0dc9a 100644 --- a/core/error.go +++ b/core/error.go @@ -16,7 +16,11 @@ package core -import "errors" +import ( + "errors" + + "github.com/XinFinOrg/XDPoSChain/core/types" +) var ( // ErrKnownBlock is returned when a block to import is already known locally. @@ -38,4 +42,8 @@ var ( ErrNotFoundM1 = errors.New("list M1 not found ") ErrStopPreparingBlock = errors.New("stop calculating a block not verified by M2") + + // ErrTxTypeNotSupported is returned if a transaction is not supported in the + // current network configuration. + ErrTxTypeNotSupported = types.ErrTxTypeNotSupported ) diff --git a/core/genesis.go b/core/genesis.go index 79a73f34db..205074b714 100644 --- a/core/genesis.go +++ b/core/genesis.go @@ -141,10 +141,10 @@ func (e *GenesisMismatchError) Error() string { // SetupGenesisBlock writes or updates the genesis block in db. // The block that will be used is: // -// genesis == nil genesis != nil -// +------------------------------------------ -// db has no genesis | main-net default | genesis -// db has genesis | from DB | genesis (if compatible) +// genesis == nil genesis != nil +// +------------------------------------------ +// db has no genesis | main-net default | genesis +// db has genesis | from DB | genesis (if compatible) // // The stored chain configuration will be updated if it is compatible (i.e. does not // specify a fork block below the local head block). In case of a conflict, the @@ -200,7 +200,7 @@ func SetupGenesisBlock(db ethdb.Database, genesis *Genesis) (*params.ChainConfig // are returned to the caller unless we're already at block zero. height := GetBlockNumber(db, GetHeadHeaderHash(db)) if height == missingNumber { - return newcfg, stored, fmt.Errorf("missing block number for head header hash") + return newcfg, stored, errors.New("missing block number for head header hash") } compatErr := storedcfg.CheckCompatible(newcfg, height) if compatErr != nil && height != 0 && compatErr.RewindTo != 0 { @@ -275,7 +275,7 @@ func (g *Genesis) ToBlock(db ethdb.Database) *types.Block { func (g *Genesis) Commit(db ethdb.Database) (*types.Block, error) { block := g.ToBlock(db) if block.Number().Sign() != 0 { - return nil, fmt.Errorf("can't commit genesis block with number > 0") + return nil, errors.New("can't commit genesis block with number > 0") } if err := WriteTd(db, block.Hash(), block.NumberU64(), g.Difficulty); err != nil { return nil, err diff --git a/core/lending_pool.go b/core/lending_pool.go index fc0deb1821..89aac0665f 100644 --- a/core/lending_pool.go +++ b/core/lending_pool.go @@ -435,7 +435,7 @@ func (pool *LendingPool) validateNewLending(cloneStateDb *state.StateDB, cloneLe validCollateral := false collateralList := lendingstate.GetCollaterals(cloneStateDb, tx.RelayerAddress(), tx.LendingToken(), tx.Term()) for _, collateral := range collateralList { - if tx.CollateralToken().String() == collateral.String() { + if tx.CollateralToken() == collateral { validCollateral = true break } @@ -476,10 +476,10 @@ func (pool *LendingPool) validateRepayLending(cloneStateDb *state.StateDB, clone if lendingTrade == lendingstate.EmptyLendingTrade { return ErrInvalidLendingTradeID } - if tx.UserAddress().String() != lendingTrade.Borrower.String() { + if tx.UserAddress() != lendingTrade.Borrower { return ErrInvalidLendingUserAddress } - if tx.RelayerAddress().String() != lendingTrade.BorrowingRelayer.String() { + if tx.RelayerAddress() != lendingTrade.BorrowingRelayer { return ErrInvalidLendingRelayer } if err := pool.validateBalance(cloneStateDb, cloneLendingStateDb, tx, tx.CollateralToken()); err != nil { @@ -499,10 +499,10 @@ func (pool *LendingPool) validateTopupLending(cloneStateDb *state.StateDB, clone if lendingTrade == lendingstate.EmptyLendingTrade { return ErrInvalidLendingTradeID } - if tx.UserAddress().String() != lendingTrade.Borrower.String() { + if tx.UserAddress() != lendingTrade.Borrower { return ErrInvalidLendingUserAddress } - if tx.RelayerAddress().String() != lendingTrade.BorrowingRelayer.String() { + if tx.RelayerAddress() != lendingTrade.BorrowingRelayer { return ErrInvalidLendingRelayer } if err := pool.validateBalance(cloneStateDb, cloneLendingStateDb, tx, lendingTrade.CollateralToken); err != nil { @@ -519,7 +519,7 @@ func (pool *LendingPool) validateBalance(cloneStateDb *state.StateDB, cloneLendi XDCXServ := XDPoSEngine.GetXDCXService() lendingServ := XDPoSEngine.GetLendingService() if XDCXServ == nil { - return fmt.Errorf("XDCx not found in order validation") + return errors.New("XDCx not found in order validation") } lendingTokenDecimal, err := XDCXServ.GetTokenDecimal(pool.chain, cloneStateDb, tx.LendingToken()) if err != nil { @@ -554,10 +554,10 @@ func (pool *LendingPool) validateBalance(cloneStateDb *state.StateDB, cloneLendi } } if lendTokenXDCPrice == nil || lendTokenXDCPrice.Sign() == 0 { - if tx.LendingToken().String() == common.XDCNativeAddress { + if tx.LendingToken() == common.XDCNativeAddressBinary { lendTokenXDCPrice = common.BasePrice } else { - lendTokenXDCPrice, err = lendingServ.GetMediumTradePriceBeforeEpoch(pool.chain, cloneStateDb, cloneTradingStateDb, tx.LendingToken(), common.HexToAddress(common.XDCNativeAddress)) + lendTokenXDCPrice, err = lendingServ.GetMediumTradePriceBeforeEpoch(pool.chain, cloneStateDb, cloneTradingStateDb, tx.LendingToken(), common.XDCNativeAddressBinary) if err != nil { return err } diff --git a/core/lending_pool_test.go b/core/lending_pool_test.go index 2d8104403f..5ffe82d58c 100644 --- a/core/lending_pool_test.go +++ b/core/lending_pool_test.go @@ -3,6 +3,13 @@ package core import ( "context" "fmt" + "log" + "math/big" + "strconv" + "strings" + "testing" + "time" + "github.com/XinFinOrg/XDPoSChain/XDCxlending/lendingstate" "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/core/types" @@ -10,12 +17,6 @@ import ( "github.com/XinFinOrg/XDPoSChain/crypto/sha3" "github.com/XinFinOrg/XDPoSChain/ethclient" "github.com/XinFinOrg/XDPoSChain/rpc" - "log" - "math/big" - "strconv" - "strings" - "testing" - "time" ) type LendingMsg struct { @@ -197,7 +198,7 @@ func TestSendLending(t *testing.T) { testSendLending(key, nonce, USDAddress, common.Address{}, new(big.Int).Mul(_1E8, big.NewInt(1000)), interestRate, lendingstate.Investing, lendingstate.LendingStatusNew, true, 0, 0, common.Hash{}, "") nonce++ time.Sleep(time.Second) - testSendLending(key, nonce, USDAddress, common.HexToAddress(common.XDCNativeAddress), new(big.Int).Mul(_1E8, big.NewInt(1000)), interestRate, lendingstate.Borrowing, lendingstate.LendingStatusNew, true, 0, 0, common.Hash{}, "") + testSendLending(key, nonce, USDAddress, common.XDCNativeAddressBinary, new(big.Int).Mul(_1E8, big.NewInt(1000)), interestRate, lendingstate.Borrowing, lendingstate.LendingStatusNew, true, 0, 0, common.Hash{}, "") nonce++ time.Sleep(time.Second) @@ -206,7 +207,7 @@ func TestSendLending(t *testing.T) { testSendLending(key, nonce, BTCAddress, common.Address{}, new(big.Int).Mul(_1E18, big.NewInt(1)), interestRate, lendingstate.Investing, lendingstate.LendingStatusNew, true, 0, 0, common.Hash{}, "") nonce++ time.Sleep(time.Second) - testSendLending(key, nonce, BTCAddress, common.HexToAddress(common.XDCNativeAddress), new(big.Int).Mul(_1E18, big.NewInt(1)), interestRate, lendingstate.Borrowing, lendingstate.LendingStatusNew, true, 0, 0, common.Hash{}, "") + testSendLending(key, nonce, BTCAddress, common.XDCNativeAddressBinary, new(big.Int).Mul(_1E18, big.NewInt(1)), interestRate, lendingstate.Borrowing, lendingstate.LendingStatusNew, true, 0, 0, common.Hash{}, "") nonce++ time.Sleep(time.Second) @@ -221,19 +222,19 @@ func TestSendLending(t *testing.T) { // lendToken: XDC, collateral: BTC // amount 1000 XDC - testSendLending(key, nonce, common.HexToAddress(common.XDCNativeAddress), common.Address{}, new(big.Int).Mul(_1E18, big.NewInt(1000)), interestRate, lendingstate.Investing, lendingstate.LendingStatusNew, true, 0, 0, common.Hash{}, "") + testSendLending(key, nonce, common.XDCNativeAddressBinary, common.Address{}, new(big.Int).Mul(_1E18, big.NewInt(1000)), interestRate, lendingstate.Investing, lendingstate.LendingStatusNew, true, 0, 0, common.Hash{}, "") nonce++ time.Sleep(time.Second) - testSendLending(key, nonce, common.HexToAddress(common.XDCNativeAddress), BTCAddress, new(big.Int).Mul(_1E18, big.NewInt(1000)), interestRate, lendingstate.Borrowing, lendingstate.LendingStatusNew, true, 0, 0, common.Hash{}, "") + testSendLending(key, nonce, common.XDCNativeAddressBinary, BTCAddress, new(big.Int).Mul(_1E18, big.NewInt(1000)), interestRate, lendingstate.Borrowing, lendingstate.LendingStatusNew, true, 0, 0, common.Hash{}, "") nonce++ time.Sleep(time.Second) // lendToken: XDC, collateral: ETH // amount 1000 XDC - testSendLending(key, nonce, common.HexToAddress(common.XDCNativeAddress), common.Address{}, new(big.Int).Mul(_1E18, big.NewInt(1000)), interestRate, lendingstate.Investing, lendingstate.LendingStatusNew, true, 0, 0, common.Hash{}, "") + testSendLending(key, nonce, common.XDCNativeAddressBinary, common.Address{}, new(big.Int).Mul(_1E18, big.NewInt(1000)), interestRate, lendingstate.Investing, lendingstate.LendingStatusNew, true, 0, 0, common.Hash{}, "") nonce++ time.Sleep(time.Second) - testSendLending(key, nonce, common.HexToAddress(common.XDCNativeAddress), ETHAddress, new(big.Int).Mul(_1E18, big.NewInt(1000)), interestRate, lendingstate.Borrowing, lendingstate.LendingStatusNew, true, 0, 0, common.Hash{}, "") + testSendLending(key, nonce, common.XDCNativeAddressBinary, ETHAddress, new(big.Int).Mul(_1E18, big.NewInt(1000)), interestRate, lendingstate.Borrowing, lendingstate.LendingStatusNew, true, 0, 0, common.Hash{}, "") nonce++ time.Sleep(time.Second) } @@ -282,6 +283,6 @@ func TestRecallLending(t *testing.T) { t.Error("fail to get nonce") t.FailNow() } - testSendLending(key, nonce, USDAddress, common.HexToAddress(common.XDCNativeAddress), new(big.Int).Mul(_1E8, big.NewInt(1000)), interestRate, lendingstate.Borrowing, lendingstate.LendingStatusNew, true, 0, 0, common.Hash{}, "") + testSendLending(key, nonce, USDAddress, common.XDCNativeAddressBinary, new(big.Int).Mul(_1E8, big.NewInt(1000)), interestRate, lendingstate.Borrowing, lendingstate.LendingStatusNew, true, 0, 0, common.Hash{}, "") time.Sleep(2 * time.Second) } diff --git a/core/order_pool.go b/core/order_pool.go index c8708149b8..dfade0ecd4 100644 --- a/core/order_pool.go +++ b/core/order_pool.go @@ -468,7 +468,7 @@ func (pool *OrderPool) validateOrder(tx *types.OrderTransaction) error { } XDCXServ := XDPoSEngine.GetXDCXService() if XDCXServ == nil { - return fmt.Errorf("XDCx not found in order validation") + return errors.New("XDCx not found in order validation") } baseDecimal, err := XDCXServ.GetTokenDecimal(pool.chain, cloneStateDb, tx.BaseToken()) if err != nil { diff --git a/core/order_pool_test.go b/core/order_pool_test.go index ea11b265bc..1f03c692ee 100644 --- a/core/order_pool_test.go +++ b/core/order_pool_test.go @@ -2,17 +2,18 @@ package core import ( "context" - "github.com/XinFinOrg/XDPoSChain/common" - "github.com/XinFinOrg/XDPoSChain/core/types" - "github.com/XinFinOrg/XDPoSChain/crypto" - "github.com/XinFinOrg/XDPoSChain/ethclient" - "github.com/XinFinOrg/XDPoSChain/rpc" "log" "math/big" "strconv" "strings" "testing" "time" + + "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/core/types" + "github.com/XinFinOrg/XDPoSChain/crypto" + "github.com/XinFinOrg/XDPoSChain/ethclient" + "github.com/XinFinOrg/XDPoSChain/rpc" ) type OrderMsg struct { @@ -89,7 +90,7 @@ func testSendOrder(t *testing.T, amount, price *big.Int, side string, status str Price: price, ExchangeAddress: common.HexToAddress("0x0D3ab14BBaD3D99F4203bd7a11aCB94882050E7e"), UserAddress: crypto.PubkeyToAddress(privateKey.PublicKey), - BaseToken: common.HexToAddress(common.XDCNativeAddress), + BaseToken: common.XDCNativeAddressBinary, QuoteToken: BTCAddress, Status: status, Side: side, @@ -124,7 +125,7 @@ func testSendOrderXDCUSD(t *testing.T, amount, price *big.Int, side string, stat Price: price, ExchangeAddress: common.HexToAddress("0x0D3ab14BBaD3D99F4203bd7a11aCB94882050E7e"), UserAddress: crypto.PubkeyToAddress(privateKey.PublicKey), - BaseToken: common.HexToAddress(common.XDCNativeAddress), + BaseToken: common.XDCNativeAddressBinary, QuoteToken: USDAddress, Status: status, Side: side, @@ -194,7 +195,7 @@ func testSendOrderXDCBTC(t *testing.T, amount, price *big.Int, side string, stat Price: price, ExchangeAddress: common.HexToAddress("0x0D3ab14BBaD3D99F4203bd7a11aCB94882050E7e"), UserAddress: crypto.PubkeyToAddress(privateKey.PublicKey), - BaseToken: common.HexToAddress(common.XDCNativeAddress), + BaseToken: common.XDCNativeAddressBinary, QuoteToken: BTCAddress, Status: status, Side: side, diff --git a/core/state/statedb.go b/core/state/statedb.go index e73273a710..bac5106f5b 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -730,6 +730,32 @@ func (s *StateDB) Commit(deleteEmptyObjects bool) (root common.Hash, err error) return root, err } +// PrepareAccessList handles the preparatory steps for executing a state transition with +// regards to both EIP-2929 and EIP-2930: +// +// - Add sender to access list (2929) +// - Add destination to access list (2929) +// - Add precompiles to access list (2929) +// - Add the contents of the optional tx access list (2930) +// +// This method should only be called if Yolov3/Berlin/2929+2930 is applicable at the current number. +func (s *StateDB) PrepareAccessList(sender common.Address, dst *common.Address, precompiles []common.Address, list types.AccessList) { + s.AddAddressToAccessList(sender) + if dst != nil { + s.AddAddressToAccessList(*dst) + // If it's a create-tx, the destination will be added inside evm.create + } + for _, addr := range precompiles { + s.AddAddressToAccessList(addr) + } + for _, el := range list { + s.AddAddressToAccessList(el.Address) + for _, key := range el.StorageKeys { + s.AddSlotToAccessList(el.Address, key) + } + } +} + // AddAddressToAccessList adds the given address to the access list func (s *StateDB) AddAddressToAccessList(addr common.Address) { if s.accessList.AddAddress(addr) { @@ -770,6 +796,6 @@ func (s *StateDB) GetOwner(candidate common.Address) common.Address { // validatorsState[_candidate].owner; locValidatorsState := GetLocMappingAtKey(candidate.Hash(), slot) locCandidateOwner := locValidatorsState.Add(locValidatorsState, new(big.Int).SetUint64(uint64(0))) - ret := s.GetState(common.HexToAddress(common.MasternodeVotingSMC), common.BigToHash(locCandidateOwner)) + ret := s.GetState(common.MasternodeVotingSMCBinary, common.BigToHash(locCandidateOwner)) return common.HexToAddress(ret.Hex()) } diff --git a/core/state/statedb_utils.go b/core/state/statedb_utils.go index a9c210b569..7bd2c4d355 100644 --- a/core/state/statedb_utils.go +++ b/core/state/statedb_utils.go @@ -20,7 +20,7 @@ func GetSigners(statedb *StateDB, block *types.Block) []common.Address { slot := slotBlockSignerMapping["blockSigners"] keys := []common.Hash{} keyArrSlot := GetLocMappingAtKey(block.Hash(), slot) - arrSlot := statedb.GetState(common.HexToAddress(common.BlockSigners), common.BigToHash(keyArrSlot)) + arrSlot := statedb.GetState(common.BlockSignersBinary, common.BigToHash(keyArrSlot)) arrLength := arrSlot.Big().Uint64() for i := uint64(0); i < arrLength; i++ { key := GetLocDynamicArrAtElement(common.BigToHash(keyArrSlot), i, 1) @@ -28,7 +28,7 @@ func GetSigners(statedb *StateDB, block *types.Block) []common.Address { } rets := []common.Address{} for _, key := range keys { - ret := statedb.GetState(common.HexToAddress(common.BlockSigners), key) + ret := statedb.GetState(common.BlockSignersBinary, key) rets = append(rets, common.HexToAddress(ret.Hex())) } @@ -45,7 +45,7 @@ var ( func GetSecret(statedb *StateDB, address common.Address) [][32]byte { slot := slotRandomizeMapping["randomSecret"] locSecret := GetLocMappingAtKey(address.Hash(), slot) - arrLength := statedb.GetState(common.HexToAddress(common.RandomizeSMC), common.BigToHash(locSecret)) + arrLength := statedb.GetState(common.RandomizeSMCBinary, common.BigToHash(locSecret)) keys := []common.Hash{} for i := uint64(0); i < arrLength.Big().Uint64(); i++ { key := GetLocDynamicArrAtElement(common.BigToHash(locSecret), i, 1) @@ -53,7 +53,7 @@ func GetSecret(statedb *StateDB, address common.Address) [][32]byte { } rets := [][32]byte{} for _, key := range keys { - ret := statedb.GetState(common.HexToAddress(common.RandomizeSMC), key) + ret := statedb.GetState(common.RandomizeSMCBinary, key) rets = append(rets, ret) } return rets @@ -62,7 +62,7 @@ func GetSecret(statedb *StateDB, address common.Address) [][32]byte { func GetOpening(statedb *StateDB, address common.Address) [32]byte { slot := slotRandomizeMapping["randomOpening"] locOpening := GetLocMappingAtKey(address.Hash(), slot) - ret := statedb.GetState(common.HexToAddress(common.RandomizeSMC), common.BigToHash(locOpening)) + ret := statedb.GetState(common.RandomizeSMCBinary, common.BigToHash(locOpening)) return ret } @@ -92,13 +92,13 @@ var ( func GetCandidates(statedb *StateDB) []common.Address { slot := slotValidatorMapping["candidates"] slotHash := common.BigToHash(new(big.Int).SetUint64(slot)) - arrLength := statedb.GetState(common.HexToAddress(common.MasternodeVotingSMC), slotHash) + arrLength := statedb.GetState(common.MasternodeVotingSMCBinary, slotHash) count := arrLength.Big().Uint64() rets := make([]common.Address, 0, count) for i := uint64(0); i < count; i++ { key := GetLocDynamicArrAtElement(slotHash, i, 1) - ret := statedb.GetState(common.HexToAddress(common.MasternodeVotingSMC), key) + ret := statedb.GetState(common.MasternodeVotingSMCBinary, key) if !ret.IsZero() { rets = append(rets, common.HexToAddress(ret.Hex())) } @@ -112,7 +112,7 @@ func GetCandidateOwner(statedb *StateDB, candidate common.Address) common.Addres // validatorsState[_candidate].owner; locValidatorsState := GetLocMappingAtKey(candidate.Hash(), slot) locCandidateOwner := locValidatorsState.Add(locValidatorsState, new(big.Int).SetUint64(uint64(0))) - ret := statedb.GetState(common.HexToAddress(common.MasternodeVotingSMC), common.BigToHash(locCandidateOwner)) + ret := statedb.GetState(common.MasternodeVotingSMCBinary, common.BigToHash(locCandidateOwner)) return common.HexToAddress(ret.Hex()) } @@ -121,7 +121,7 @@ func GetCandidateCap(statedb *StateDB, candidate common.Address) *big.Int { // validatorsState[_candidate].cap; locValidatorsState := GetLocMappingAtKey(candidate.Hash(), slot) locCandidateCap := locValidatorsState.Add(locValidatorsState, new(big.Int).SetUint64(uint64(1))) - ret := statedb.GetState(common.HexToAddress(common.MasternodeVotingSMC), common.BigToHash(locCandidateCap)) + ret := statedb.GetState(common.MasternodeVotingSMCBinary, common.BigToHash(locCandidateCap)) return ret.Big() } @@ -129,7 +129,7 @@ func GetVoters(statedb *StateDB, candidate common.Address) []common.Address { //mapping(address => address[]) voters; slot := slotValidatorMapping["voters"] locVoters := GetLocMappingAtKey(candidate.Hash(), slot) - arrLength := statedb.GetState(common.HexToAddress(common.MasternodeVotingSMC), common.BigToHash(locVoters)) + arrLength := statedb.GetState(common.MasternodeVotingSMCBinary, common.BigToHash(locVoters)) keys := []common.Hash{} for i := uint64(0); i < arrLength.Big().Uint64(); i++ { key := GetLocDynamicArrAtElement(common.BigToHash(locVoters), i, 1) @@ -137,7 +137,7 @@ func GetVoters(statedb *StateDB, candidate common.Address) []common.Address { } rets := []common.Address{} for _, key := range keys { - ret := statedb.GetState(common.HexToAddress(common.MasternodeVotingSMC), key) + ret := statedb.GetState(common.MasternodeVotingSMCBinary, key) rets = append(rets, common.HexToAddress(ret.Hex())) } @@ -149,6 +149,6 @@ func GetVoterCap(statedb *StateDB, candidate, voter common.Address) *big.Int { locValidatorsState := GetLocMappingAtKey(candidate.Hash(), slot) locCandidateVoters := locValidatorsState.Add(locValidatorsState, new(big.Int).SetUint64(uint64(2))) retByte := crypto.Keccak256(voter.Hash().Bytes(), common.BigToHash(locCandidateVoters).Bytes()) - ret := statedb.GetState(common.HexToAddress(common.MasternodeVotingSMC), common.BytesToHash(retByte)) + ret := statedb.GetState(common.MasternodeVotingSMCBinary, common.BytesToHash(retByte)) return ret.Big() } diff --git a/core/state_processor.go b/core/state_processor.go index e21df09779..e4bab6797b 100644 --- a/core/state_processor.go +++ b/core/state_processor.go @@ -80,7 +80,7 @@ func (p *StateProcessor) Process(block *types.Block, statedb *state.StateDB, tra misc.ApplyDAOHardFork(statedb) } if common.TIPSigning.Cmp(header.Number) == 0 { - statedb.DeleteAddress(common.HexToAddress(common.BlockSigners)) + statedb.DeleteAddress(common.BlockSignersBinary) } parentState := statedb.Copy() InitSignerInTransactions(p.config, header, block.Transactions()) @@ -146,7 +146,7 @@ func (p *StateProcessor) ProcessBlockNoValidator(cBlock *CalculatedBlock, stated misc.ApplyDAOHardFork(statedb) } if common.TIPSigning.Cmp(header.Number) == 0 { - statedb.DeleteAddress(common.HexToAddress(common.BlockSigners)) + statedb.DeleteAddress(common.BlockSignersBinary) } if cBlock.stop { return nil, nil, 0, ErrStopPreparingBlock @@ -215,13 +215,14 @@ func (p *StateProcessor) ProcessBlockNoValidator(cBlock *CalculatedBlock, stated // for the transaction, gas used and an error if the transaction failed, // indicating the block was invalid. func ApplyTransaction(config *params.ChainConfig, tokensFee map[common.Address]*big.Int, bc *BlockChain, author *common.Address, gp *GasPool, statedb *state.StateDB, XDCxState *tradingstate.TradingStateDB, header *types.Header, tx *types.Transaction, usedGas *uint64, cfg vm.Config) (*types.Receipt, uint64, error, bool) { - if tx.To() != nil && tx.To().String() == common.BlockSigners && config.IsTIPSigning(header.Number) { + to := tx.To() + if to != nil && *to == common.BlockSignersBinary && config.IsTIPSigning(header.Number) { return ApplySignTransaction(config, statedb, header, tx, usedGas) } - if tx.To() != nil && tx.To().String() == common.TradingStateAddr && config.IsTIPXDCXReceiver(header.Number) { + if to != nil && *to == common.TradingStateAddrBinary && config.IsTIPXDCXReceiver(header.Number) { return ApplyEmptyTransaction(config, statedb, header, tx, usedGas) } - if tx.To() != nil && tx.To().String() == common.XDCXLendingAddress && config.IsTIPXDCXReceiver(header.Number) { + if to != nil && *to == common.XDCXLendingAddressBinary && config.IsTIPXDCXReceiver(header.Number) { return ApplyEmptyTransaction(config, statedb, header, tx, usedGas) } if tx.IsTradingTransaction() && config.IsTIPXDCXReceiver(header.Number) { @@ -233,8 +234,8 @@ func ApplyTransaction(config *params.ChainConfig, tokensFee map[common.Address]* } var balanceFee *big.Int - if tx.To() != nil { - if value, ok := tokensFee[*tx.To()]; ok { + if to != nil { + if value, ok := tokensFee[*to]; ok { balanceFee = value } } @@ -242,23 +243,12 @@ func ApplyTransaction(config *params.ChainConfig, tokensFee map[common.Address]* if err != nil { return nil, 0, err, false } - // Create a new context to be used in the EVM environment + // Create a new context to be used in the EVM environment. context := NewEVMContext(msg, header, bc, author) // Create a new environment which holds all relevant information // about the transaction and calling mechanisms. vmenv := vm.NewEVM(context, statedb, XDCxState, config, cfg) - if config.IsEIP1559(header.Number) { - statedb.AddAddressToAccessList(msg.From()) - if dst := msg.To(); dst != nil { - statedb.AddAddressToAccessList(*dst) - // If it's a create-tx, the destination will be added inside evm.create - } - for _, addr := range vmenv.ActivePrecompiles() { - statedb.AddAddressToAccessList(addr) - } - } - // If we don't have an explicit author (i.e. not mining), extract from the header var beneficiary common.Address if author == nil { @@ -419,7 +409,8 @@ func ApplyTransaction(config *params.ChainConfig, tokensFee map[common.Address]* if err != nil { return nil, 0, err, false } - // Update the state with pending changes + + // Update the state with pending changes. var root []byte if config.IsByzantium(header.Number) { statedb.Finalise(true) @@ -428,23 +419,30 @@ func ApplyTransaction(config *params.ChainConfig, tokensFee map[common.Address]* } *usedGas += gas - // Create a new receipt for the transaction, storing the intermediate root and gas used by the tx - // based on the eip phase, we're passing wether the root touch-delete accounts. - receipt := types.NewReceipt(root, failed, *usedGas) + // Create a new receipt for the transaction, storing the intermediate root and gas used + // by the tx. + receipt := &types.Receipt{Type: tx.Type(), PostState: root, CumulativeGasUsed: *usedGas} + if failed { + receipt.Status = types.ReceiptStatusFailed + } else { + receipt.Status = types.ReceiptStatusSuccessful + } receipt.TxHash = tx.Hash() receipt.GasUsed = gas - // if the transaction created a contract, store the creation address in the receipt. + + // If the transaction created a contract, store the creation address in the receipt. if msg.To() == nil { receipt.ContractAddress = crypto.CreateAddress(vmenv.Context.Origin, tx.Nonce()) } - // Set the receipt logs and create a bloom for filtering + + // Set the receipt logs and create the bloom filter. receipt.Logs = statedb.GetLogs(tx.Hash()) receipt.Bloom = types.CreateBloom(types.Receipts{receipt}) receipt.BlockHash = statedb.BlockHash() receipt.BlockNumber = header.Number receipt.TransactionIndex = uint(statedb.TxIndex()) if balanceFee != nil && failed { - state.PayFeeWithTRC21TxFail(statedb, msg.From(), *tx.To()) + state.PayFeeWithTRC21TxFail(statedb, msg.From(), *to) } return receipt, gas, err, balanceFee != nil } @@ -476,7 +474,7 @@ func ApplySignTransaction(config *params.ChainConfig, statedb *state.StateDB, he // if the transaction created a contract, store the creation address in the receipt. // Set the receipt logs and create a bloom for filtering log := &types.Log{} - log.Address = common.HexToAddress(common.BlockSigners) + log.Address = common.BlockSignersBinary log.BlockNumber = header.Number.Uint64() statedb.AddLog(log) receipt.Logs = statedb.GetLogs(tx.Hash()) @@ -532,7 +530,6 @@ func InitSignerInTransactions(config *params.ChainConfig, header *types.Header, go func(from int, to int) { for j := from; j < to; j++ { types.CacheSigner(signer, txs[j]) - txs[j].CacheHash() } wg.Done() }(from, to) diff --git a/core/state_transition.go b/core/state_transition.go index c9c4fdfefd..48d6ebcae4 100644 --- a/core/state_transition.go +++ b/core/state_transition.go @@ -22,6 +22,7 @@ import ( "math/big" "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/core/types" "github.com/XinFinOrg/XDPoSChain/core/vm" "github.com/XinFinOrg/XDPoSChain/log" "github.com/XinFinOrg/XDPoSChain/params" @@ -42,8 +43,10 @@ The state transitioning model does all all the necessary work to work out a vali 3) Create a new state object if the recipient is \0*32 4) Value transfer == If contract creation == - 4a) Attempt to run transaction data - 4b) If valid, use result as code for the new state object + + 4a) Attempt to run transaction data + 4b) If valid, use result as code for the new state object + == end == 5) Run Script section 6) Derive new state root @@ -74,13 +77,14 @@ type Message interface { CheckNonce() bool Data() []byte BalanceTokenFee() *big.Int + AccessList() types.AccessList } // IntrinsicGas computes the 'intrinsic gas' for a message with the given data. -func IntrinsicGas(data []byte, contractCreation, homestead bool) (uint64, error) { +func IntrinsicGas(data []byte, accessList types.AccessList, isContractCreation, isHomestead bool) (uint64, error) { // Set the starting gas for the raw transaction var gas uint64 - if contractCreation && homestead { + if isContractCreation && isHomestead { gas = params.TxGasContractCreation } else { gas = params.TxGas @@ -106,6 +110,10 @@ func IntrinsicGas(data []byte, contractCreation, homestead bool) (uint64, error) } gas += z * params.TxDataZeroGas } + if accessList != nil { + gas += uint64(len(accessList)) * params.TxAccessListAddressGas + gas += uint64(accessList.StorageKeys()) * params.TxAccessListStorageKeyGas + } return gas, nil } @@ -226,7 +234,7 @@ func (st *StateTransition) TransitionDb(owner common.Address) (ret []byte, usedG contractCreation := msg.To() == nil // Pay intrinsic gas - gas, err := IntrinsicGas(st.data, contractCreation, homestead) + gas, err := IntrinsicGas(st.data, st.msg.AccessList(), contractCreation, homestead) if err != nil { return nil, 0, false, err, nil } @@ -234,6 +242,10 @@ func (st *StateTransition) TransitionDb(owner common.Address) (ret []byte, usedG return nil, 0, false, err, nil } + if rules := st.evm.ChainConfig().Rules(st.evm.Context.BlockNumber); rules.IsEIP1559 { + st.state.PrepareAccessList(msg.From(), msg.To(), vm.ActivePrecompiles(rules), msg.AccessList()) + } + var ( evm = st.evm // vm errors do not effect consensus and are therefor diff --git a/core/token_validator.go b/core/token_validator.go index f3ee929163..f604be504f 100644 --- a/core/token_validator.go +++ b/core/token_validator.go @@ -27,6 +27,7 @@ import ( "github.com/XinFinOrg/XDPoSChain/consensus" "github.com/XinFinOrg/XDPoSChain/contracts/XDCx/contract" "github.com/XinFinOrg/XDPoSChain/core/state" + "github.com/XinFinOrg/XDPoSChain/core/types" "github.com/XinFinOrg/XDPoSChain/core/vm" "github.com/XinFinOrg/XDPoSChain/log" ) @@ -37,20 +38,21 @@ const ( getDecimalFunction = "decimals" ) -// callmsg implements core.Message to allow passing it as a transaction simulator. -type callmsg struct { +// callMsg implements core.Message to allow passing it as a transaction simulator. +type callMsg struct { ethereum.CallMsg } -func (m callmsg) From() common.Address { return m.CallMsg.From } -func (m callmsg) Nonce() uint64 { return 0 } -func (m callmsg) CheckNonce() bool { return false } -func (m callmsg) To() *common.Address { return m.CallMsg.To } -func (m callmsg) GasPrice() *big.Int { return m.CallMsg.GasPrice } -func (m callmsg) Gas() uint64 { return m.CallMsg.Gas } -func (m callmsg) Value() *big.Int { return m.CallMsg.Value } -func (m callmsg) Data() []byte { return m.CallMsg.Data } -func (m callmsg) BalanceTokenFee() *big.Int { return m.CallMsg.BalanceTokenFee } +func (m callMsg) From() common.Address { return m.CallMsg.From } +func (m callMsg) Nonce() uint64 { return 0 } +func (m callMsg) CheckNonce() bool { return false } +func (m callMsg) To() *common.Address { return m.CallMsg.To } +func (m callMsg) GasPrice() *big.Int { return m.CallMsg.GasPrice } +func (m callMsg) Gas() uint64 { return m.CallMsg.Gas } +func (m callMsg) Value() *big.Int { return m.CallMsg.Value } +func (m callMsg) Data() []byte { return m.CallMsg.Data } +func (m callMsg) BalanceTokenFee() *big.Int { return m.CallMsg.BalanceTokenFee } +func (m callMsg) AccessList() types.AccessList { return m.CallMsg.AccessList } type SimulatedBackend interface { CallContractWithState(call ethereum.CallMsg, chain consensus.ChainContext, statedb *state.StateDB) ([]byte, error) @@ -99,7 +101,7 @@ func CallContractWithState(call ethereum.CallMsg, chain consensus.ChainContext, call.Value = new(big.Int) } // Execute the call. - msg := callmsg{call} + msg := callMsg{call} feeCapacity := state.GetTRC21FeeCapacityFromState(statedb) if msg.To() != nil { if value, ok := feeCapacity[*msg.To()]; ok { diff --git a/core/tx_list.go b/core/tx_list.go index 2c3c33eb59..5f623806d6 100644 --- a/core/tx_list.go +++ b/core/tx_list.go @@ -24,7 +24,6 @@ import ( "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/core/types" - "github.com/XinFinOrg/XDPoSChain/log" ) // nonceHeap is a heap.Interface implementation over 64bit unsigned integers for @@ -99,7 +98,30 @@ func (m *txSortedMap) Forward(threshold uint64) types.Transactions { // Filter iterates over the list of transactions and removes all of them for which // the specified function evaluates to true. +// Filter, as opposed to 'filter', re-initialises the heap after the operation is done. +// If you want to do several consecutive filterings, it's therefore better to first +// do a .filter(func1) followed by .Filter(func2) or reheap() func (m *txSortedMap) Filter(filter func(*types.Transaction) bool) types.Transactions { + removed := m.filter(filter) + // If transactions were removed, the heap and cache are ruined + if len(removed) > 0 { + m.reheap() + } + return removed +} + +func (m *txSortedMap) reheap() { + *m.index = make([]uint64, 0, len(m.items)) + for nonce := range m.items { + *m.index = append(*m.index, nonce) + } + heap.Init(m.index) + m.cache = nil +} + +// filter is identical to Filter, but **does not** regenerate the heap. This method +// should only be used if followed immediately by a call to Filter or reheap() +func (m *txSortedMap) filter(filter func(*types.Transaction) bool) types.Transactions { var removed types.Transactions // Collect all the transactions to filter out @@ -109,14 +131,7 @@ func (m *txSortedMap) Filter(filter func(*types.Transaction) bool) types.Transac delete(m.items, nonce) } } - // If transactions were removed, the heap and cache are ruined if len(removed) > 0 { - *m.index = make([]uint64, 0, len(m.items)) - for nonce := range m.items { - *m.index = append(*m.index, nonce) - } - heap.Init(m.index) - m.cache = nil } return removed @@ -197,10 +212,7 @@ func (m *txSortedMap) Len() int { return len(m.items) } -// Flatten creates a nonce-sorted slice of transactions based on the loosely -// sorted internal representation. The result of the sorting is cached in case -// it's requested again before any modifications are made to the contents. -func (m *txSortedMap) Flatten() types.Transactions { +func (m *txSortedMap) flatten() types.Transactions { // If the sorting was not cached yet, create and cache it if m.cache == nil { m.cache = make(types.Transactions, 0, len(m.items)) @@ -209,12 +221,27 @@ func (m *txSortedMap) Flatten() types.Transactions { } sort.Sort(types.TxByNonce(m.cache)) } + return m.cache +} + +// Flatten creates a nonce-sorted slice of transactions based on the loosely +// sorted internal representation. The result of the sorting is cached in case +// it's requested again before any modifications are made to the contents. +func (m *txSortedMap) Flatten() types.Transactions { // Copy the cache to prevent accidental modifications - txs := make(types.Transactions, len(m.cache)) - copy(txs, m.cache) + cache := m.flatten() + txs := make(types.Transactions, len(cache)) + copy(txs, cache) return txs } +// LastElement returns the last element of a flattened list, thus, the +// transaction with the highest nonce +func (m *txSortedMap) LastElement() *types.Transaction { + cache := m.flatten() + return cache[len(cache)-1] +} + // txList is a "list" of transactions belonging to an account, sorted by account // nonce. The same type can be used both for storing contiguous transactions for // the executable/pending queue; and for storing gapped transactions for the non- @@ -255,11 +282,15 @@ func (l *txList) Add(tx *types.Transaction, priceBump uint64) (bool, *types.Tran return false, nil } if old != nil { - threshold := new(big.Int).Div(new(big.Int).Mul(old.GasPrice(), big.NewInt(100+int64(priceBump))), big.NewInt(100)) + // threshold = oldGP * (100 + priceBump) / 100 + a := big.NewInt(100 + int64(priceBump)) + a = a.Mul(a, old.GasPrice()) + b := big.NewInt(100) + threshold := a.Div(a, b) // Have to ensure that the new gas price is higher than the old gas // price as well as checking the percentage threshold to ensure that // this is accurate for low (Wei-level) gas price replacements - if old.GasPrice().Cmp(tx.GasPrice()) >= 0 || threshold.Cmp(tx.GasPrice()) > 0 { + if old.GasPriceCmp(tx) >= 0 || tx.GasPriceIntCmp(threshold) < 0 { return false, nil } } @@ -303,24 +334,27 @@ func (l *txList) Filter(costLimit *big.Int, gasLimit uint64, trc21Issuers map[co maximum := costLimit if tx.To() != nil { if feeCapacity, ok := trc21Issuers[*tx.To()]; ok { - return new(big.Int).Add(costLimit, feeCapacity).Cmp(tx.TxCost(number)) < 0 || tx.Gas() > gasLimit + return tx.Gas() > gasLimit || new(big.Int).Add(costLimit, feeCapacity).Cmp(tx.TxCost(number)) < 0 } } - return tx.Cost().Cmp(maximum) > 0 || tx.Gas() > gasLimit + return tx.Gas() > gasLimit || tx.Cost().Cmp(maximum) > 0 }) - // If the list was strict, filter anything above the lowest nonce + if len(removed) == 0 { + return nil, nil + } var invalids types.Transactions - - if l.strict && len(removed) > 0 { + // If the list was strict, filter anything above the lowest nonce + if l.strict { lowest := uint64(math.MaxUint64) for _, tx := range removed { if nonce := tx.Nonce(); lowest > nonce { lowest = nonce } } - invalids = l.txs.Filter(func(tx *types.Transaction) bool { return tx.Nonce() > lowest }) + invalids = l.txs.filter(func(tx *types.Transaction) bool { return tx.Nonce() > lowest }) } + l.txs.reheap() return removed, invalids } @@ -374,6 +408,12 @@ func (l *txList) Flatten() types.Transactions { return l.txs.Flatten() } +// LastElement returns the last element of a flattened list, thus, the +// transaction with the highest nonce +func (l *txList) LastElement() *types.Transaction { + return l.txs.LastElement() +} + // priceHeap is a heap.Interface implementation over transactions for retrieving // price-sorted transactions to discard when the pool fills up. type priceHeap []*types.Transaction @@ -383,7 +423,7 @@ func (h priceHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } func (h priceHeap) Less(i, j int) bool { // Sort primarily by price, returning the cheaper one - switch h[i].GasPrice().Cmp(h[j].GasPrice()) { + switch h[i].GasPriceCmp(h[j]) { case -1: return true case 1: @@ -406,24 +446,29 @@ func (h *priceHeap) Pop() interface{} { } // txPricedList is a price-sorted heap to allow operating on transactions pool -// contents in a price-incrementing way. +// contents in a price-incrementing way. It's built opon the all transactions +// in txpool but only interested in the remote part. It means only remote transactions +// will be considered for tracking, sorting, eviction, etc. type txPricedList struct { - all *txLookup // Pointer to the map of all transactions - items *priceHeap // Heap of prices of all the stored transactions - stales int // Number of stale price points to (re-heap trigger) + all *txLookup // Pointer to the map of all transactions + remotes *priceHeap // Heap of prices of all the stored **remote** transactions + stales int // Number of stale price points to (re-heap trigger) } // newTxPricedList creates a new price-sorted transaction heap. func newTxPricedList(all *txLookup) *txPricedList { return &txPricedList{ - all: all, - items: new(priceHeap), + all: all, + remotes: new(priceHeap), } } // Put inserts a new transaction into the heap. -func (l *txPricedList) Put(tx *types.Transaction) { - heap.Push(l.items, tx) +func (l *txPricedList) Put(tx *types.Transaction, local bool) { + if local { + return + } + heap.Push(l.remotes, tx) } // Removed notifies the prices transaction list that an old transaction dropped @@ -432,100 +477,95 @@ func (l *txPricedList) Put(tx *types.Transaction) { func (l *txPricedList) Removed(count int) { // Bump the stale counter, but exit if still too low (< 25%) l.stales += count - if l.stales <= len(*l.items)/4 { + if l.stales <= len(*l.remotes)/4 { return } // Seems we've reached a critical number of stale transactions, reheap - reheap := make(priceHeap, 0, l.all.Count()) - - l.stales, l.items = 0, &reheap - l.all.Range(func(hash common.Hash, tx *types.Transaction) bool { - *l.items = append(*l.items, tx) - return true - }) - heap.Init(l.items) + l.Reheap() } // Cap finds all the transactions below the given price threshold, drops them -// from the priced list and returs them for further removal from the entire pool. -func (l *txPricedList) Cap(threshold *big.Int, local *accountSet) types.Transactions { +// from the priced list and returns them for further removal from the entire pool. +// +// Note: only remote transactions will be considered for eviction. +func (l *txPricedList) Cap(threshold *big.Int) types.Transactions { drop := make(types.Transactions, 0, 128) // Remote underpriced transactions to drop - save := make(types.Transactions, 0, 64) // Local underpriced transactions to keep - - for len(*l.items) > 0 { + for len(*l.remotes) > 0 { // Discard stale transactions if found during cleanup - tx := heap.Pop(l.items).(*types.Transaction) - if l.all.Get(tx.Hash()) == nil { + cheapest := (*l.remotes)[0] + if l.all.GetRemote(cheapest.Hash()) == nil { // Removed or migrated + heap.Pop(l.remotes) l.stales-- continue } // Stop the discards if we've reached the threshold - if tx.GasPrice().Cmp(threshold) >= 0 { - save = append(save, tx) + if cheapest.GasPriceIntCmp(threshold) >= 0 { break } - // Non stale transaction found, discard unless local - if local.containsTx(tx) { - save = append(save, tx) - } else { - drop = append(drop, tx) - } - } - for _, tx := range save { - heap.Push(l.items, tx) + heap.Pop(l.remotes) + drop = append(drop, cheapest) } return drop } // Underpriced checks whether a transaction is cheaper than (or as cheap as) the -// lowest priced transaction currently being tracked. -func (l *txPricedList) Underpriced(tx *types.Transaction, local *accountSet) bool { - // Local transactions cannot be underpriced - if local.containsTx(tx) { - return false - } +// lowest priced (remote) transaction currently being tracked. +func (l *txPricedList) Underpriced(tx *types.Transaction) bool { // Discard stale price points if found at the heap start - for len(*l.items) > 0 { - head := []*types.Transaction(*l.items)[0] - if l.all.Get(head.Hash()) == nil { + for len(*l.remotes) > 0 { + head := []*types.Transaction(*l.remotes)[0] + if l.all.GetRemote(head.Hash()) == nil { // Removed or migrated l.stales-- - heap.Pop(l.items) + heap.Pop(l.remotes) continue } break } // Check if the transaction is underpriced or not - if len(*l.items) == 0 { - log.Error("Pricing query for empty pool") // This cannot happen, print to catch programming errors - return false + if len(*l.remotes) == 0 { + return false // There is no remote transaction at all. } - cheapest := []*types.Transaction(*l.items)[0] - return cheapest.GasPrice().Cmp(tx.GasPrice()) >= 0 + // If the remote transaction is even cheaper than the + // cheapest one tracked locally, reject it. + cheapest := []*types.Transaction(*l.remotes)[0] + return cheapest.GasPriceCmp(tx) >= 0 } // Discard finds a number of most underpriced transactions, removes them from the // priced list and returns them for further removal from the entire pool. -func (l *txPricedList) Discard(count int, local *accountSet) types.Transactions { - drop := make(types.Transactions, 0, count) // Remote underpriced transactions to drop - save := make(types.Transactions, 0, 64) // Local underpriced transactions to keep - - for len(*l.items) > 0 && count > 0 { +// +// Note local transaction won't be considered for eviction. +func (l *txPricedList) Discard(slots int, force bool) (types.Transactions, bool) { + drop := make(types.Transactions, 0, slots) // Remote underpriced transactions to drop + for len(*l.remotes) > 0 && slots > 0 { // Discard stale transactions if found during cleanup - tx := heap.Pop(l.items).(*types.Transaction) - if l.all.Get(tx.Hash()) == nil { + tx := heap.Pop(l.remotes).(*types.Transaction) + if l.all.GetRemote(tx.Hash()) == nil { // Removed or migrated l.stales-- continue } - // Non stale transaction found, discard unless local - if local.containsTx(tx) { - save = append(save, tx) - } else { - drop = append(drop, tx) - count-- + // Non stale transaction found, discard it + drop = append(drop, tx) + slots -= numSlots(tx) + } + // If we still can't make enough room for the new transaction + if slots > 0 && !force { + for _, tx := range drop { + heap.Push(l.remotes, tx) } + return nil, false } - for _, tx := range save { - heap.Push(l.items, tx) - } - return drop + return drop, true +} + +// Reheap forcibly rebuilds the heap based on the current remote transaction set. +func (l *txPricedList) Reheap() { + reheap := make(priceHeap, 0, l.all.RemoteCount()) + + l.stales, l.remotes = 0, &reheap + l.all.Range(func(hash common.Hash, tx *types.Transaction, local bool) bool { + *l.remotes = append(*l.remotes, tx) + return true + }, false, true) // Only iterate remotes + heap.Init(l.remotes) } diff --git a/core/tx_list_test.go b/core/tx_list_test.go index f0ec8eb8b4..36a0196f1e 100644 --- a/core/tx_list_test.go +++ b/core/tx_list_test.go @@ -17,6 +17,7 @@ package core import ( + "math/big" "math/rand" "testing" @@ -49,3 +50,21 @@ func TestStrictTxListAdd(t *testing.T) { } } } + +func BenchmarkTxListAdd(t *testing.B) { + // Generate a list of transactions to insert + key, _ := crypto.GenerateKey() + + txs := make(types.Transactions, 100000) + for i := 0; i < len(txs); i++ { + txs[i] = transaction(uint64(i), 0, key) + } + // Insert the transactions in a random order + list := newTxList(true) + priceLimit := big.NewInt(int64(DefaultTxPoolConfig.PriceLimit)) + t.ResetTimer() + for _, v := range rand.Perm(len(txs)) { + list.Add(txs[v], DefaultTxPoolConfig.PriceBump) + list.Filter(priceLimit, DefaultTxPoolConfig.PriceBump, nil, nil) + } +} diff --git a/core/tx_pool.go b/core/tx_pool.go index b2c271f016..64359a972b 100644 --- a/core/tx_pool.go +++ b/core/tx_pool.go @@ -39,9 +39,25 @@ import ( const ( // chainHeadChanSize is the size of channel listening to ChainHeadEvent. chainHeadChanSize = 10 + + // txSlotSize is used to calculate how many data slots a single transaction + // takes up based on its size. The slots are used as DoS protection, ensuring + // that validating a new transaction remains a constant operation (in reality + // O(maxslots), where max slots are 4 currently). + txSlotSize = 32 * 1024 + + // txMaxSize is the maximum size a single transaction can have. This field has + // non-trivial consequences: larger transactions are significantly harder and + // more expensive to propagate; larger transactions also take more resources + // to validate whether they fit into the pool or not. + txMaxSize = 2 * txSlotSize // 64KB, don't bump without EIP-2464 support ) var ( + // ErrAlreadyKnown is returned if the transactions is already contained + // within the pool. + ErrAlreadyKnown = errors.New("already known") + // ErrInvalidSender is returned if the transaction contains an invalid signature. ErrInvalidSender = errors.New("invalid sender") @@ -53,6 +69,10 @@ var ( // configured for the transaction pool. ErrUnderpriced = errors.New("transaction underpriced") + // ErrTxPoolOverflow is returned if the transaction pool is full and can't accpet + // another remote transaction. + ErrTxPoolOverflow = errors.New("txpool is full") + // ErrReplaceUnderpriced is returned if a transaction is attempted to be replaced // with a different one without the required price bump. ErrReplaceUnderpriced = errors.New("replacement transaction underpriced") @@ -69,7 +89,7 @@ var ( // maximum allowance of the current block. ErrGasLimit = errors.New("exceeds block gas limit") - // ErrNegativeValue is a sanity error to ensure noone is able to specify a + // ErrNegativeValue is a sanity error to ensure no one is able to specify a // transaction with a negative value. ErrNegativeValue = errors.New("negative value") @@ -104,15 +124,19 @@ var ( queuedReplaceMeter = metrics.NewRegisteredMeter("txpool/queued/replace", nil) queuedRateLimitMeter = metrics.NewRegisteredMeter("txpool/queued/ratelimit", nil) // Dropped due to rate limiting queuedNofundsMeter = metrics.NewRegisteredMeter("txpool/queued/nofunds", nil) // Dropped due to out-of-funds + queuedEvictionMeter = metrics.NewRegisteredMeter("txpool/queued/eviction", nil) // Dropped due to lifetime // General tx metrics - validMeter = metrics.NewRegisteredMeter("txpool/valid", nil) + knownTxMeter = metrics.NewRegisteredMeter("txpool/known", nil) + validTxMeter = metrics.NewRegisteredMeter("txpool/valid", nil) invalidTxMeter = metrics.NewRegisteredMeter("txpool/invalid", nil) underpricedTxMeter = metrics.NewRegisteredMeter("txpool/underpriced", nil) + overflowedTxMeter = metrics.NewRegisteredMeter("txpool/overflowed", nil) pendingGauge = metrics.NewRegisteredGauge("txpool/pending", nil) queuedGauge = metrics.NewRegisteredGauge("txpool/queued", nil) localGauge = metrics.NewRegisteredGauge("txpool/local", nil) + slotsGauge = metrics.NewRegisteredGauge("txpool/slots", nil) ) // TxStatus is the current status of a transaction as seen by the pool. @@ -259,6 +283,7 @@ type TxPool struct { reorgShutdownCh chan struct{} // requests shutdown of scheduleReorgLoop wg sync.WaitGroup // tracks loop, scheduleReorgLoop + eip2718 bool // Fork indicator whether we are using EIP-2718 type transactions. IsSigner func(address common.Address) bool trc21FeeCapacity map[common.Address]*big.Int } @@ -278,7 +303,7 @@ func NewTxPool(config TxPoolConfig, chainconfig *params.ChainConfig, chain block config: config, chainconfig: chainconfig, chain: chain, - signer: types.NewEIP155Signer(chainconfig.ChainId), + signer: types.LatestSigner(chainconfig), pending: make(map[common.Address]*txList), queue: make(map[common.Address]*txList), beats: make(map[common.Address]time.Time), @@ -369,7 +394,7 @@ func (pool *TxPool) loop() { prevPending, prevQueued, prevStales = pending, queued, stales } - // Handle inactive account transaction eviction + // Handle inactive account transaction eviction case <-evict.C: pool.mu.Lock() for addr := range pool.queue { @@ -379,14 +404,16 @@ func (pool *TxPool) loop() { } // Any non-locals old enough should be removed if time.Since(pool.beats[addr]) > pool.config.Lifetime { - for _, tx := range pool.queue[addr].Flatten() { + list := pool.queue[addr].Flatten() + for _, tx := range list { pool.removeTx(tx.Hash(), true) } + queuedEvictionMeter.Mark(int64(len(list))) } } pool.mu.Unlock() - // Handle local transaction journal rotation + // Handle local transaction journal rotation case <-journal.C: if pool.journal != nil { pool.mu.Lock() @@ -435,7 +462,7 @@ func (pool *TxPool) SetGasPrice(price *big.Int) { defer pool.mu.Unlock() pool.gasPrice = price - for _, tx := range pool.priced.Cap(price, pool.locals) { + for _, tx := range pool.priced.Cap(price) { pool.removeTx(tx.Hash(), false) } log.Info("Transaction pool price threshold updated", "price", price) @@ -539,6 +566,14 @@ func (pool *TxPool) GetSender(tx *types.Transaction) (common.Address, error) { // validateTx checks whether a transaction is valid according to the consensus // rules and adheres to some heuristic limits of the local node (price and size). func (pool *TxPool) validateTx(tx *types.Transaction, local bool) error { + // Accept only legacy transactions until EIP-2718/2930 activates. + if !pool.eip2718 && tx.Type() != types.LegacyTxType { + return ErrTxTypeNotSupported + } + // Reject transactions over defined size to prevent DOS attacks + if uint64(tx.Size()) > txMaxSize { + return ErrOversizedData + } // check if sender is in black list if tx.From() != nil && common.Blacklist[*tx.From()] { return fmt.Errorf("Reject transaction with sender in black-list: %v", tx.From().Hex()) @@ -547,11 +582,6 @@ func (pool *TxPool) validateTx(tx *types.Transaction, local bool) error { if tx.To() != nil && common.Blacklist[*tx.To()] { return fmt.Errorf("Reject transaction with receiver in black-list: %v", tx.To().Hex()) } - - // Heuristic limit, reject transactions over 32KB to prevent DOS attacks - if tx.Size() > 32*1024 { - return ErrOversizedData - } // Transactions can't be negative. This may never happen using RLP decoded // transactions but may occur if you create a transaction using the RPC. if tx.Value().Sign() < 0 { @@ -561,14 +591,13 @@ func (pool *TxPool) validateTx(tx *types.Transaction, local bool) error { if pool.currentMaxGas < tx.Gas() { return ErrGasLimit } - // Make sure the transaction is signed properly + // Make sure the transaction is signed properly. from, err := types.Sender(pool.signer, tx) if err != nil { return ErrInvalidSender } // Drop non-local transactions under our own minimal accepted gas price - local = local || pool.locals.contains(from) // account may be local even if the transaction arrived from the network - if !local && pool.gasPrice.Cmp(tx.GasPrice()) > 0 { + if !local && tx.GasPriceIntCmp(pool.gasPrice) < 0 { if !tx.IsSpecialTransaction() || (pool.IsSigner != nil && !pool.IsSigner(from)) { return ErrUnderpriced } @@ -606,7 +635,7 @@ func (pool *TxPool) validateTx(tx *types.Transaction, local bool) error { if tx.To() == nil || (tx.To() != nil && !tx.IsSpecialTransaction()) { // Ensure the transaction has more gas than the basic tx fee. - intrGas, err := IntrinsicGas(tx.Data(), tx.To() == nil, true) + intrGas, err := IntrinsicGas(tx.Data(), tx.AccessList(), tx.To() == nil, true) if err != nil { return err } @@ -659,39 +688,50 @@ func (pool *TxPool) add(tx *types.Transaction, local bool) (replaced bool, err e hash := tx.Hash() if pool.all.Get(hash) != nil { log.Trace("Discarding already known transaction", "hash", hash) - return false, fmt.Errorf("known transaction: %x", hash) + knownTxMeter.Mark(1) + return false, ErrAlreadyKnown } + // Make the local flag. If it's from local source or it's from the network but + // the sender is marked as local previously, treat it as the local transaction. + isLocal := local || pool.locals.containsTx(tx) // If the transaction fails basic validation, discard it - if err := pool.validateTx(tx, local); err != nil { + if err := pool.validateTx(tx, isLocal); err != nil { log.Trace("Discarding invalid transaction", "hash", hash, "err", err) invalidTxMeter.Mark(1) return false, err } - from, _ := types.Sender(pool.signer, tx) // already validated if tx.IsSpecialTransaction() && pool.IsSigner != nil && pool.IsSigner(from) && pool.pendingNonces.get(from) == tx.Nonce() { - return pool.promoteSpecialTx(from, tx) + return pool.promoteSpecialTx(from, tx, isLocal) } - // If the transaction pool is full, discard underpriced transactions if uint64(pool.all.Count()) >= pool.config.GlobalSlots+pool.config.GlobalQueue { log.Debug("Add transaction to pool full", "hash", hash, "nonce", tx.Nonce()) // If the new transaction is underpriced, don't accept it - if !local && pool.priced.Underpriced(tx, pool.locals) { + if !isLocal && pool.priced.Underpriced(tx) { log.Trace("Discarding underpriced transaction", "hash", hash, "price", tx.GasPrice()) underpricedTxMeter.Mark(1) return false, ErrUnderpriced } - // New transaction is better than our worse ones, make room for it - drop := pool.priced.Discard(pool.all.Count()-int(pool.config.GlobalSlots+pool.config.GlobalQueue-1), pool.locals) + // New transaction is better than our worse ones, make room for it. + // If it's a local transaction, forcibly discard all available transactions. + // Otherwise if we can't make enough room for new one, abort the operation. + drop, success := pool.priced.Discard(pool.all.Slots()-int(pool.config.GlobalSlots+pool.config.GlobalQueue)+numSlots(tx), isLocal) + + // Special case, we still can't make the room for the new remote one. + if !isLocal && !success { + log.Trace("Discarding overflown transaction", "hash", hash) + overflowedTxMeter.Mark(1) + return false, ErrTxPoolOverflow + } + // Kick out the underpriced remote transactions. for _, tx := range drop { log.Trace("Discarding freshly underpriced transaction", "hash", tx.Hash(), "price", tx.GasPrice()) underpricedTxMeter.Mark(1) pool.removeTx(tx.Hash(), false) } } - // Try to replace an existing transaction in the pending pool if list := pool.pending[from]; list != nil && list.Overlaps(tx) { // Nonce already pending, check if required price bump is met @@ -706,28 +746,28 @@ func (pool *TxPool) add(tx *types.Transaction, local bool) (replaced bool, err e pool.priced.Removed(1) pendingReplaceMeter.Mark(1) } - pool.all.Add(tx) - pool.priced.Put(tx) + pool.all.Add(tx, isLocal) + pool.priced.Put(tx, isLocal) pool.journalTx(from, tx) pool.queueTxEvent(tx) log.Trace("Pooled new executable transaction", "hash", hash, "from", from, "to", tx.To()) + + // Successful promotion, bump the heartbeat + pool.beats[from] = time.Now() return old != nil, nil } - // New transaction isn't replacing a pending one, push into queue - replaced, err = pool.enqueueTx(hash, tx) + replaced, err = pool.enqueueTx(hash, tx, isLocal, true) if err != nil { return false, err } - // Mark local addresses and journal local transactions - if local { - if !pool.locals.contains(from) { - log.Info("Setting new local account", "address", from) - pool.locals.add(from) - } + if local && !pool.locals.contains(from) { + log.Info("Setting new local account", "address", from) + pool.locals.add(from) + pool.priced.Removed(pool.all.RemoteToLocals(pool.locals)) // Migrate the remotes if it's marked as local first time. } - if local || pool.locals.contains(from) { + if isLocal { localGauge.Inc(1) } pool.journalTx(from, tx) @@ -739,7 +779,7 @@ func (pool *TxPool) add(tx *types.Transaction, local bool) (replaced bool, err e // enqueueTx inserts a new transaction into the non-executable transaction queue. // // Note, this method assumes the pool lock is held! -func (pool *TxPool) enqueueTx(hash common.Hash, tx *types.Transaction) (bool, error) { +func (pool *TxPool) enqueueTx(hash common.Hash, tx *types.Transaction, local bool, addAll bool) (bool, error) { // Try to insert the transaction into the future queue from, _ := types.Sender(pool.signer, tx) // already validated if pool.queue[from] == nil { @@ -760,9 +800,18 @@ func (pool *TxPool) enqueueTx(hash common.Hash, tx *types.Transaction) (bool, er // Nothing was replaced, bump the queued counter queuedGauge.Inc(1) } - if pool.all.Get(hash) == nil { - pool.all.Add(tx) - pool.priced.Put(tx) + // If the transaction isn't in lookup set but it's expected to be there, + // show the error log. + if pool.all.Get(hash) == nil && !addAll { + log.Error("Missing transaction in lookup set, please report the issue", "hash", hash) + } + if addAll { + pool.all.Add(tx, local) + pool.priced.Put(tx, local) + } + // If we never record the heartbeat, do it right now. + if _, exist := pool.beats[from]; !exist { + pool.beats[from] = time.Now() } return old != nil, nil } @@ -803,30 +852,26 @@ func (pool *TxPool) promoteTx(addr common.Address, hash common.Hash, tx *types.T if old != nil { pool.all.Remove(old.Hash()) pool.priced.Removed(1) - pendingReplaceMeter.Mark(1) } else { // Nothing was replaced, bump the pending counter pendingGauge.Inc(1) } - // Failsafe to work around direct pending inserts (tests) - if pool.all.Get(hash) == nil { - pool.all.Add(tx) - pool.priced.Put(tx) - } // Set the potentially new pending nonce and notify any subsystems of the new tx - pool.beats[addr] = time.Now() pool.pendingNonces.set(addr, tx.Nonce()+1) + // Successful promotion, bump the heartbeat + pool.beats[addr] = time.Now() return true } -func (pool *TxPool) promoteSpecialTx(addr common.Address, tx *types.Transaction) (bool, error) { +func (pool *TxPool) promoteSpecialTx(addr common.Address, tx *types.Transaction, isLocal bool) (bool, error) { // Try to insert the transaction into the pending queue if pool.pending[addr] == nil { pool.pending[addr] = newTxList(true) } list := pool.pending[addr] + old := list.txs.Get(tx.Nonce()) if old != nil && old.IsSpecialTransaction() { return false, ErrDuplicateSpecialTransaction @@ -849,7 +894,7 @@ func (pool *TxPool) promoteSpecialTx(addr common.Address, tx *types.Transaction) } // Failsafe to work around direct pending inserts (tests) if pool.all.Get(tx.Hash()) == nil { - pool.all.Add(tx) + pool.all.Add(tx, isLocal) } // Set the potentially new pending nonce and notify any subsystems of the new tx pool.beats[addr] = time.Now() @@ -889,7 +934,7 @@ func (pool *TxPool) AddRemotesSync(txs []*types.Transaction) []error { } // This is like AddRemotes with a single transaction, but waits for pool reorganization. Tests use this method. -func (pool *TxPool) AddRemoteSync(tx *types.Transaction) error { +func (pool *TxPool) addRemoteSync(tx *types.Transaction) error { errs := pool.AddRemotesSync([]*types.Transaction{tx}) return errs[0] } @@ -905,15 +950,48 @@ func (pool *TxPool) AddRemote(tx *types.Transaction) error { // addTxs attempts to queue a batch of transactions if they are valid. func (pool *TxPool) addTxs(txs []*types.Transaction, local, sync bool) []error { - // Cache senders in transactions before obtaining lock (pool.signer is immutable) - for _, tx := range txs { - types.Sender(pool.signer, tx) + // Filter out known ones without obtaining the pool lock or recovering signatures + var ( + errs = make([]error, len(txs)) + news = make([]*types.Transaction, 0, len(txs)) + ) + for i, tx := range txs { + // If the transaction is known, pre-set the error slot + if pool.all.Get(tx.Hash()) != nil { + errs[i] = ErrAlreadyKnown + knownTxMeter.Mark(1) + continue + } + // Exclude transactions with invalid signatures as soon as + // possible and cache senders in transactions before + // obtaining lock + _, err := types.Sender(pool.signer, tx) + if err != nil { + errs[i] = ErrInvalidSender + invalidTxMeter.Mark(1) + continue + } + // Accumulate all unknown transactions for deeper processing + news = append(news, tx) + } + if len(news) == 0 { + return errs } + // Process all the new transaction and merge any errors into the original slice pool.mu.Lock() - errs, dirtyAddrs := pool.addTxsLocked(txs, local) + newErrs, dirtyAddrs := pool.addTxsLocked(news, local) pool.mu.Unlock() + var nilSlot = 0 + for _, err := range newErrs { + for errs[nilSlot] != nil { + nilSlot++ + } + errs[nilSlot] = err + nilSlot++ + } + // Reorg the pool internals if needed and return done := pool.requestPromoteExecutables(dirtyAddrs) if sync { <-done @@ -933,26 +1011,29 @@ func (pool *TxPool) addTxsLocked(txs []*types.Transaction, local bool) ([]error, dirty.addTx(tx) } } - validMeter.Mark(int64(len(dirty.accounts))) + validTxMeter.Mark(int64(len(dirty.accounts))) return errs, dirty } // Status returns the status (unknown/pending/queued) of a batch of transactions // identified by their hashes. func (pool *TxPool) Status(hashes []common.Hash) []TxStatus { - pool.mu.RLock() - defer pool.mu.RUnlock() - status := make([]TxStatus, len(hashes)) for i, hash := range hashes { - if tx := pool.all.Get(hash); tx != nil { - from, _ := types.Sender(pool.signer, tx) // already validated - if pool.pending[from] != nil && pool.pending[from].txs.items[tx.Nonce()] != nil { - status[i] = TxStatusPending - } else { - status[i] = TxStatusQueued - } + tx := pool.Get(hash) + if tx == nil { + continue } + from, _ := types.Sender(pool.signer, tx) // already validated + pool.mu.RLock() + if txList := pool.pending[from]; txList != nil && txList.txs.items[tx.Nonce()] != nil { + status[i] = TxStatusPending + } else if txList := pool.queue[from]; txList != nil && txList.txs.items[tx.Nonce()] != nil { + status[i] = TxStatusQueued + } + // implicit else: the tx may have been included into a block between + // checking pool.Get and obtaining the lock. In that case, TxStatusUnknown is correct + pool.mu.RUnlock() } return status } @@ -962,6 +1043,12 @@ func (pool *TxPool) Get(hash common.Hash) *types.Transaction { return pool.all.Get(hash) } +// Has returns an indicator whether txpool has a transaction cached with the +// given hash. +func (pool *TxPool) Has(hash common.Hash) bool { + return pool.all.Get(hash) != nil +} + // removeTx removes a single transaction from the queue, moving all subsequent // transactions back to the future queue. func (pool *TxPool) removeTx(hash common.Hash, outofbound bool) { @@ -986,11 +1073,11 @@ func (pool *TxPool) removeTx(hash common.Hash, outofbound bool) { // If no more pending transactions are left, remove the list if pending.Empty() { delete(pool.pending, addr) - delete(pool.beats, addr) } // Postpone any invalidated transactions for _, tx := range invalids { - pool.enqueueTx(tx.Hash(), tx) + // Internal shuffle shouldn't touch the lookup set. + pool.enqueueTx(tx.Hash(), tx, false, false) } // Update the account nonce if needed pool.pendingNonces.setIfLower(addr, tx.Nonce()) @@ -1007,11 +1094,12 @@ func (pool *TxPool) removeTx(hash common.Hash, outofbound bool) { } if future.Empty() { delete(pool.queue, addr) + delete(pool.beats, addr) } } } -// requestPromoteExecutables requests a pool reset to the new head block. +// requestReset requests a pool reset to the new head block. // The returned channel is closed when the reset has occurred. func (pool *TxPool) requestReset(oldHead *types.Header, newHead *types.Header) chan struct{} { select { @@ -1118,7 +1206,10 @@ func (pool *TxPool) runReorg(done chan struct{}, reset *txpoolResetRequest, dirt defer close(done) var promoteAddrs []common.Address - if dirtyAccounts != nil { + if dirtyAccounts != nil && reset == nil { + // Only dirty accounts need to be promoted, unless we're resetting. + // For resets, all addresses in the tx queue will be promoted and + // the flatten operation can be avoided. promoteAddrs = dirtyAccounts.flatten() } pool.mu.Lock() @@ -1134,20 +1225,14 @@ func (pool *TxPool) runReorg(done chan struct{}, reset *txpoolResetRequest, dirt } } // Reset needs promote for all addresses - promoteAddrs = promoteAddrs[:0] + promoteAddrs = make([]common.Address, 0, len(pool.queue)) for addr := range pool.queue { promoteAddrs = append(promoteAddrs, addr) } } // Check for pending transactions for every account that sent new ones promoted := pool.promoteExecutables(promoteAddrs) - for _, tx := range promoted { - addr, _ := types.Sender(pool.signer, tx) - if _, ok := events[addr]; !ok { - events[addr] = newTxSortedMap() - } - events[addr].Put(tx) - } + // If a new block appeared, validate the pool of pending transactions. This will // remove any transaction that has been included in the block or was invalidated // because of another transaction (e.g. higher gas price). @@ -1160,12 +1245,19 @@ func (pool *TxPool) runReorg(done chan struct{}, reset *txpoolResetRequest, dirt // Update all accounts to the latest known pending nonce for addr, list := range pool.pending { - txs := list.Flatten() // Heavy but will be cached and is needed by the miner anyway - pool.pendingNonces.set(addr, txs[len(txs)-1].Nonce()+1) + highestPending := list.LastElement() + pool.pendingNonces.set(addr, highestPending.Nonce()+1) } pool.mu.Unlock() // Notify subsystems for newly added transactions + for _, tx := range promoted { + addr, _ := types.Sender(pool.signer, tx) + if _, ok := events[addr]; !ok { + events[addr] = newTxSortedMap() + } + events[addr].Put(tx) + } if len(events) > 0 { var txs []*types.Transaction for _, set := range events { @@ -1200,44 +1292,45 @@ func (pool *TxPool) reset(oldHead, newHead *types.Header) { // head from the chain. // If that is the case, we don't have the lost transactions any more, and // there's nothing to add - if newNum < oldNum { - // If the reorg ended up on a lower number, it's indicative of setHead being the cause - log.Debug("Skipping transaction reset caused by setHead", - "old", oldHead.Hash(), "oldnum", oldNum, "new", newHead.Hash(), "newnum", newNum) - } else { + if newNum >= oldNum { // If we reorged to a same or higher number, then it's not a case of setHead log.Warn("Transaction pool reset with missing oldhead", "old", oldHead.Hash(), "oldnum", oldNum, "new", newHead.Hash(), "newnum", newNum) - } - return - } - for rem.NumberU64() > add.NumberU64() { - discarded = append(discarded, rem.Transactions()...) - if rem = pool.chain.GetBlock(rem.ParentHash(), rem.NumberU64()-1); rem == nil { - log.Error("Unrooted old chain seen by tx pool", "block", oldHead.Number, "hash", oldHead.Hash()) return } - } - for add.NumberU64() > rem.NumberU64() { - included = append(included, add.Transactions()...) - if add = pool.chain.GetBlock(add.ParentHash(), add.NumberU64()-1); add == nil { - log.Error("Unrooted new chain seen by tx pool", "block", newHead.Number, "hash", newHead.Hash()) - return + // If the reorg ended up on a lower number, it's indicative of setHead being the cause + log.Debug("Skipping transaction reset caused by setHead", + "old", oldHead.Hash(), "oldnum", oldNum, "new", newHead.Hash(), "newnum", newNum) + // We still need to update the current state s.th. the lost transactions can be readded by the user + } else { + for rem.NumberU64() > add.NumberU64() { + discarded = append(discarded, rem.Transactions()...) + if rem = pool.chain.GetBlock(rem.ParentHash(), rem.NumberU64()-1); rem == nil { + log.Error("Unrooted old chain seen by tx pool", "block", oldHead.Number, "hash", oldHead.Hash()) + return + } } - } - for rem.Hash() != add.Hash() { - discarded = append(discarded, rem.Transactions()...) - if rem = pool.chain.GetBlock(rem.ParentHash(), rem.NumberU64()-1); rem == nil { - log.Error("Unrooted old chain seen by tx pool", "block", oldHead.Number, "hash", oldHead.Hash()) - return + for add.NumberU64() > rem.NumberU64() { + included = append(included, add.Transactions()...) + if add = pool.chain.GetBlock(add.ParentHash(), add.NumberU64()-1); add == nil { + log.Error("Unrooted new chain seen by tx pool", "block", newHead.Number, "hash", newHead.Hash()) + return + } } - included = append(included, add.Transactions()...) - if add = pool.chain.GetBlock(add.ParentHash(), add.NumberU64()-1); add == nil { - log.Error("Unrooted new chain seen by tx pool", "block", newHead.Number, "hash", newHead.Hash()) - return + for rem.Hash() != add.Hash() { + discarded = append(discarded, rem.Transactions()...) + if rem = pool.chain.GetBlock(rem.ParentHash(), rem.NumberU64()-1); rem == nil { + log.Error("Unrooted old chain seen by tx pool", "block", oldHead.Number, "hash", oldHead.Hash()) + return + } + included = append(included, add.Transactions()...) + if add = pool.chain.GetBlock(add.ParentHash(), add.NumberU64()-1); add == nil { + log.Error("Unrooted new chain seen by tx pool", "block", newHead.Number, "hash", newHead.Hash()) + return + } } + reinject = types.TxDifference(discarded, included) } - reinject = types.TxDifference(discarded, included) } } // Initialize the internal state to the current head @@ -1258,6 +1351,10 @@ func (pool *TxPool) reset(oldHead, newHead *types.Header) { log.Debug("Reinjecting stale transactions", "count", len(reinject)) senderCacher.recover(pool.signer, reinject) pool.addTxsLocked(reinject, false) + + // Update all fork indicator by next pending block number. + next := new(big.Int).Add(newHead.Number, big.NewInt(1)) + pool.eip2718 = pool.chainconfig.IsEIP1559(next) } // promoteExecutables moves transactions that have become processable from the @@ -1283,8 +1380,8 @@ func (pool *TxPool) promoteExecutables(accounts []common.Address) []*types.Trans for _, tx := range forwards { hash := tx.Hash() pool.all.Remove(hash) - log.Trace("Removed old queued transaction", "hash", hash) } + log.Trace("Removed old queued transactions", "count", len(forwards)) // Drop all transactions that are too costly (low balance or out of gas) var number *big.Int = nil if pool.chain.CurrentHeader() != nil { @@ -1294,8 +1391,8 @@ func (pool *TxPool) promoteExecutables(accounts []common.Address) []*types.Trans for _, tx := range drops { hash := tx.Hash() pool.all.Remove(hash) - log.Trace("Removed unpayable queued transaction", "hash", hash) } + log.Trace("Removed unpayable queued transactions", "count", len(drops)) queuedNofundsMeter.Mark(int64(len(drops))) // Gather all executable transactions and promote them @@ -1303,10 +1400,10 @@ func (pool *TxPool) promoteExecutables(accounts []common.Address) []*types.Trans for _, tx := range readies { hash := tx.Hash() if pool.promoteTx(addr, hash, tx) { - log.Trace("Promoting queued transaction", "hash", hash) promoted = append(promoted, tx) } } + log.Trace("Promoted queued transactions", "count", len(promoted)) queuedGauge.Dec(int64(len(readies))) // Drop all transactions over the allowed limit @@ -1329,6 +1426,7 @@ func (pool *TxPool) promoteExecutables(accounts []common.Address) []*types.Trans // Delete the entire queue entry if it became empty. if list.Empty() { delete(pool.queue, addr) + delete(pool.beats, addr) } } return promoted @@ -1498,7 +1596,9 @@ func (pool *TxPool) demoteUnexecutables() { for _, tx := range invalids { hash := tx.Hash() log.Trace("Demoting pending transaction", "hash", hash) - pool.enqueueTx(hash, tx) + + // Internal shuffle shouldn't touch the lookup set. + pool.enqueueTx(hash, tx, false, false) } pendingGauge.Dec(int64(len(olds) + len(drops) + len(invalids))) if pool.locals.contains(addr) { @@ -1510,14 +1610,17 @@ func (pool *TxPool) demoteUnexecutables() { for _, tx := range gapped { hash := tx.Hash() log.Warn("Demoting invalidated transaction", "hash", hash) - pool.enqueueTx(hash, tx) + + // Internal shuffle shouldn't touch the lookup set. + pool.enqueueTx(hash, tx, false, false) } pendingGauge.Dec(int64(len(gapped))) + // This might happen in a reorg, so log it to the metering + blockReorgInvalidatedTx.Mark(int64(len(gapped))) } - // Delete the entire queue entry if it became empty. + // Delete the entire pending entry if it became empty. if list.Empty() { delete(pool.pending, addr) - delete(pool.beats, addr) } } } @@ -1561,6 +1664,10 @@ func (as *accountSet) contains(addr common.Address) bool { return exist } +func (as *accountSet) empty() bool { + return len(as.accounts) == 0 +} + // containsTx checks if the sender of a given tx is within the set. If the sender // cannot be derived, this method returns false. func (as *accountSet) containsTx(tx *types.Transaction) bool { @@ -1604,8 +1711,8 @@ func (as *accountSet) merge(other *accountSet) { as.cache = nil } -// txLookup is used internally by TxPool to track transactions while allowing lookup without -// mutex contention. +// txLookup is used internally by TxPool to track transactions while allowing +// lookup without mutex contention. // // Note, although this type is properly protected against concurrent access, it // is **not** a type that should ever be mutated or even exposed outside of the @@ -1613,26 +1720,43 @@ func (as *accountSet) merge(other *accountSet) { // internal mechanisms. The sole purpose of the type is to permit out-of-bound // peeking into the pool in TxPool.Get without having to acquire the widely scoped // TxPool.mu mutex. +// +// This lookup set combines the notion of "local transactions", which is useful +// to build upper-level structure. type txLookup struct { - all map[common.Hash]*types.Transaction - lock sync.RWMutex + slots int + lock sync.RWMutex + locals map[common.Hash]*types.Transaction + remotes map[common.Hash]*types.Transaction } // newTxLookup returns a new txLookup structure. func newTxLookup() *txLookup { return &txLookup{ - all: make(map[common.Hash]*types.Transaction), + locals: make(map[common.Hash]*types.Transaction), + remotes: make(map[common.Hash]*types.Transaction), } } -// Range calls f on each key and value present in the map. -func (t *txLookup) Range(f func(hash common.Hash, tx *types.Transaction) bool) { +// Range calls f on each key and value present in the map. The callback passed +// should return the indicator whether the iteration needs to be continued. +// Callers need to specify which set (or both) to be iterated. +func (t *txLookup) Range(f func(hash common.Hash, tx *types.Transaction, local bool) bool, local bool, remote bool) { t.lock.RLock() defer t.lock.RUnlock() - for key, value := range t.all { - if !f(key, value) { - break + if local { + for key, value := range t.locals { + if !f(key, value, true) { + return + } + } + } + if remote { + for key, value := range t.remotes { + if !f(key, value, false) { + return + } } } } @@ -1642,23 +1766,73 @@ func (t *txLookup) Get(hash common.Hash) *types.Transaction { t.lock.RLock() defer t.lock.RUnlock() - return t.all[hash] + if tx := t.locals[hash]; tx != nil { + return tx + } + return t.remotes[hash] } -// Count returns the current number of items in the lookup. +// GetLocal returns a transaction if it exists in the lookup, or nil if not found. +func (t *txLookup) GetLocal(hash common.Hash) *types.Transaction { + t.lock.RLock() + defer t.lock.RUnlock() + + return t.locals[hash] +} + +// GetRemote returns a transaction if it exists in the lookup, or nil if not found. +func (t *txLookup) GetRemote(hash common.Hash) *types.Transaction { + t.lock.RLock() + defer t.lock.RUnlock() + + return t.remotes[hash] +} + +// Count returns the current number of transactions in the lookup. func (t *txLookup) Count() int { t.lock.RLock() defer t.lock.RUnlock() - return len(t.all) + return len(t.locals) + len(t.remotes) +} + +// LocalCount returns the current number of local transactions in the lookup. +func (t *txLookup) LocalCount() int { + t.lock.RLock() + defer t.lock.RUnlock() + + return len(t.locals) +} + +// RemoteCount returns the current number of remote transactions in the lookup. +func (t *txLookup) RemoteCount() int { + t.lock.RLock() + defer t.lock.RUnlock() + + return len(t.remotes) +} + +// Slots returns the current number of slots used in the lookup. +func (t *txLookup) Slots() int { + t.lock.RLock() + defer t.lock.RUnlock() + + return t.slots } // Add adds a transaction to the lookup. -func (t *txLookup) Add(tx *types.Transaction) { +func (t *txLookup) Add(tx *types.Transaction, local bool) { t.lock.Lock() defer t.lock.Unlock() - t.all[tx.Hash()] = tx + t.slots += numSlots(tx) + slotsGauge.Update(int64(t.slots)) + + if local { + t.locals[tx.Hash()] = tx + } else { + t.remotes[tx.Hash()] = tx + } } // Remove removes a transaction from the lookup. @@ -1666,5 +1840,39 @@ func (t *txLookup) Remove(hash common.Hash) { t.lock.Lock() defer t.lock.Unlock() - delete(t.all, hash) + tx, ok := t.locals[hash] + if !ok { + tx, ok = t.remotes[hash] + } + if !ok { + log.Error("No transaction found to be deleted", "hash", hash) + return + } + t.slots -= numSlots(tx) + slotsGauge.Update(int64(t.slots)) + + delete(t.locals, hash) + delete(t.remotes, hash) +} + +// RemoteToLocals migrates the transactions belongs to the given locals to locals +// set. The assumption is held the locals set is thread-safe to be used. +func (t *txLookup) RemoteToLocals(locals *accountSet) int { + t.lock.Lock() + defer t.lock.Unlock() + + var migrated int + for hash, tx := range t.remotes { + if locals.containsTx(tx) { + t.locals[hash] = tx + delete(t.remotes, hash) + migrated += 1 + } + } + return migrated +} + +// numSlots calculates the number of slots needed for a single transaction. +func numSlots(tx *types.Transaction) int { + return int((tx.Size() + txSlotSize - 1) / txSlotSize) } diff --git a/core/tx_pool_test.go b/core/tx_pool_test.go index b494d93ce5..dbbdd24baa 100644 --- a/core/tx_pool_test.go +++ b/core/tx_pool_test.go @@ -94,10 +94,18 @@ func pricedTransaction(nonce uint64, gaslimit uint64, gasprice *big.Int, key *ec return tx } +func pricedDataTransaction(nonce uint64, gaslimit uint64, gasprice *big.Int, key *ecdsa.PrivateKey, bytes uint64) *types.Transaction { + data := make([]byte, bytes) + rand.Read(data) + + tx, _ := types.SignTx(types.NewTransaction(nonce, common.Address{}, big.NewInt(0), gaslimit, gasprice, data), types.HomesteadSigner{}, key) + return tx +} + func setupTxPool() (*TxPool, *ecdsa.PrivateKey) { diskdb := rawdb.NewMemoryDatabase() statedb, _ := state.New(common.Hash{}, state.NewDatabase(diskdb)) - blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)} + blockchain := &testBlockChain{statedb, 10000000, new(event.Feed)} key, _ := crypto.GenerateKey() pool := NewTxPool(testTxPoolConfig, params.TestChainConfig, blockchain) @@ -115,8 +123,10 @@ func validateTxPoolInternals(pool *TxPool) error { if total := pool.all.Count(); total != pending+queued { return fmt.Errorf("total transaction count %d != %d pending + %d queued", total, pending, queued) } - if priced := pool.priced.items.Len() - pool.priced.stales; priced != pending+queued { - return fmt.Errorf("total priced transaction count %d != %d pending + %d queued", priced, pending, queued) + pool.priced.Reheap() + priced, remote := pool.priced.remotes.Len(), pool.all.RemoteCount() + if priced != remote { + return fmt.Errorf("total priced transaction count %d != %d", priced, remote) } // Ensure the next nonce to assign is the correct one for addr, txs := range pool.pending { @@ -289,7 +299,7 @@ func TestTransactionQueue(t *testing.T) { pool.currentState.AddBalance(from, big.NewInt(1000)) <-pool.requestReset(nil, nil) - pool.enqueueTx(tx.Hash(), tx) + pool.enqueueTx(tx.Hash(), tx, false, true) <-pool.requestPromoteExecutables(newAccountSet(pool.signer, from)) if len(pool.pending) != 1 { t.Error("expected valid txs to be 1 is", len(pool.pending)) @@ -298,7 +308,7 @@ func TestTransactionQueue(t *testing.T) { tx = transaction(1, 100, key) from, _ = deriveSender(tx) pool.currentState.SetNonce(from, 2) - pool.enqueueTx(tx.Hash(), tx) + pool.enqueueTx(tx.Hash(), tx, false, true) <-pool.requestPromoteExecutables(newAccountSet(pool.signer, from)) if _, ok := pool.pending[from].txs.items[tx.Nonce()]; ok { @@ -323,9 +333,9 @@ func TestTransactionQueue2(t *testing.T) { pool.currentState.AddBalance(from, big.NewInt(1000)) pool.reset(nil, nil) - pool.enqueueTx(tx1.Hash(), tx1) - pool.enqueueTx(tx2.Hash(), tx2) - pool.enqueueTx(tx3.Hash(), tx3) + pool.enqueueTx(tx1.Hash(), tx1, false, true) + pool.enqueueTx(tx2.Hash(), tx2, false, true) + pool.enqueueTx(tx3.Hash(), tx3, false, true) pool.promoteExecutables([]common.Address{from}) if len(pool.pending) != 1 { @@ -488,7 +498,7 @@ func TestTransactionDropping(t *testing.T) { pool, key := setupTxPool() defer pool.Stop() - account, _ := deriveSender(transaction(0, 0, key)) + account := crypto.PubkeyToAddress(key.PublicKey) pool.currentState.AddBalance(account, big.NewInt(1000)) // Add some pending and some queued transactions @@ -500,12 +510,21 @@ func TestTransactionDropping(t *testing.T) { tx11 = transaction(11, 200, key) tx12 = transaction(12, 300, key) ) + pool.all.Add(tx0, false) + pool.priced.Put(tx0, false) pool.promoteTx(account, tx0.Hash(), tx0) + + pool.all.Add(tx1, false) + pool.priced.Put(tx1, false) pool.promoteTx(account, tx1.Hash(), tx1) + + pool.all.Add(tx2, false) + pool.priced.Put(tx2, false) pool.promoteTx(account, tx2.Hash(), tx2) - pool.enqueueTx(tx10.Hash(), tx10) - pool.enqueueTx(tx11.Hash(), tx11) - pool.enqueueTx(tx12.Hash(), tx12) + + pool.enqueueTx(tx10.Hash(), tx10, false, true) + pool.enqueueTx(tx11.Hash(), tx11, false, true) + pool.enqueueTx(tx12.Hash(), tx12, false, true) // Check that pre and post validations leave the pool as is if pool.pending[account].Len() != 3 { @@ -698,7 +717,7 @@ func TestTransactionGapFilling(t *testing.T) { pool, key := setupTxPool() defer pool.Stop() - account, _ := deriveSender(transaction(0, 0, key)) + account := crypto.PubkeyToAddress(key.PublicKey) pool.currentState.AddBalance(account, big.NewInt(1000000)) // Keep track of transaction events to ensure all executables get announced @@ -725,7 +744,7 @@ func TestTransactionGapFilling(t *testing.T) { t.Fatalf("pool internal state corrupted: %v", err) } // Fill the nonce gap and ensure all transactions become pending - if err := pool.AddRemoteSync(transaction(1, 100000, key)); err != nil { + if err := pool.addRemoteSync(transaction(1, 100000, key)); err != nil { t.Fatalf("failed to add gapped transaction: %v", err) } pending, queued = pool.Stats() @@ -752,12 +771,12 @@ func TestTransactionQueueAccountLimiting(t *testing.T) { pool, key := setupTxPool() defer pool.Stop() - account, _ := deriveSender(transaction(0, 0, key)) + account := crypto.PubkeyToAddress(key.PublicKey) pool.currentState.AddBalance(account, big.NewInt(1000000)) testTxPoolConfig.AccountQueue = 10 // Keep queuing up transactions and make sure all above a limit are dropped for i := uint64(1); i <= testTxPoolConfig.AccountQueue; i++ { - if err := pool.AddRemoteSync(transaction(i, 100000, key)); err != nil { + if err := pool.addRemoteSync(transaction(i, 100000, key)); err != nil { t.Fatalf("tx %d: failed to add transaction: %v", i, err) } if len(pool.pending) != 0 { @@ -884,7 +903,7 @@ func testTransactionQueueTimeLimiting(t *testing.T, nolocals bool) { common.MinGasPrice = big.NewInt(0) // Reduce the eviction interval to a testable amount defer func(old time.Duration) { evictionInterval = old }(evictionInterval) - evictionInterval = time.Second + evictionInterval = time.Millisecond * 100 // Create the pool to test the non-expiration enforcement db := rawdb.NewMemoryDatabase() @@ -922,6 +941,22 @@ func testTransactionQueueTimeLimiting(t *testing.T, nolocals bool) { if err := validateTxPoolInternals(pool); err != nil { t.Fatalf("pool internal state corrupted: %v", err) } + + // Allow the eviction interval to run + time.Sleep(2 * evictionInterval) + + // Transactions should not be evicted from the queue yet since lifetime duration has not passed + pending, queued = pool.Stats() + if pending != 0 { + t.Fatalf("pending transactions mismatched: have %d, want %d", pending, 0) + } + if queued != 2 { + t.Fatalf("queued transactions mismatched: have %d, want %d", queued, 2) + } + if err := validateTxPoolInternals(pool); err != nil { + t.Fatalf("pool internal state corrupted: %v", err) + } + // Wait a bit for eviction to run and clean up any leftovers, and ensure only the local remains time.Sleep(2 * config.Lifetime) @@ -941,6 +976,72 @@ func testTransactionQueueTimeLimiting(t *testing.T, nolocals bool) { if err := validateTxPoolInternals(pool); err != nil { t.Fatalf("pool internal state corrupted: %v", err) } + + // remove current transactions and increase nonce to prepare for a reset and cleanup + statedb.SetNonce(crypto.PubkeyToAddress(remote.PublicKey), 2) + statedb.SetNonce(crypto.PubkeyToAddress(local.PublicKey), 2) + <-pool.requestReset(nil, nil) + + // make sure queue, pending are cleared + pending, queued = pool.Stats() + if pending != 0 { + t.Fatalf("pending transactions mismatched: have %d, want %d", pending, 0) + } + if queued != 0 { + t.Fatalf("queued transactions mismatched: have %d, want %d", queued, 0) + } + if err := validateTxPoolInternals(pool); err != nil { + t.Fatalf("pool internal state corrupted: %v", err) + } + + // Queue gapped transactions + if err := pool.AddLocal(pricedTransaction(4, 100000, big.NewInt(1), local)); err != nil { + t.Fatalf("failed to add remote transaction: %v", err) + } + if err := pool.addRemoteSync(pricedTransaction(4, 100000, big.NewInt(1), remote)); err != nil { + t.Fatalf("failed to add remote transaction: %v", err) + } + time.Sleep(5 * evictionInterval) // A half lifetime pass + + // Queue executable transactions, the life cycle should be restarted. + if err := pool.AddLocal(pricedTransaction(2, 100000, big.NewInt(1), local)); err != nil { + t.Fatalf("failed to add remote transaction: %v", err) + } + if err := pool.addRemoteSync(pricedTransaction(2, 100000, big.NewInt(1), remote)); err != nil { + t.Fatalf("failed to add remote transaction: %v", err) + } + time.Sleep(6 * evictionInterval) + + // All gapped transactions shouldn't be kicked out + pending, queued = pool.Stats() + if pending != 2 { + t.Fatalf("pending transactions mismatched: have %d, want %d", pending, 2) + } + if queued != 2 { + t.Fatalf("queued transactions mismatched: have %d, want %d", queued, 3) + } + if err := validateTxPoolInternals(pool); err != nil { + t.Fatalf("pool internal state corrupted: %v", err) + } + + // The whole life time pass after last promotion, kick out stale transactions + time.Sleep(2 * config.Lifetime) + pending, queued = pool.Stats() + if pending != 2 { + t.Fatalf("pending transactions mismatched: have %d, want %d", pending, 2) + } + if nolocals { + if queued != 0 { + t.Fatalf("queued transactions mismatched: have %d, want %d", queued, 0) + } + } else { + if queued != 1 { + t.Fatalf("queued transactions mismatched: have %d, want %d", queued, 1) + } + } + if err := validateTxPoolInternals(pool); err != nil { + t.Fatalf("pool internal state corrupted: %v", err) + } } // Tests that even if the transaction count belonging to a single account goes @@ -953,7 +1054,7 @@ func TestTransactionPendingLimiting(t *testing.T) { pool, key := setupTxPool() defer pool.Stop() - account, _ := deriveSender(transaction(0, 0, key)) + account := crypto.PubkeyToAddress(key.PublicKey) pool.currentState.AddBalance(account, big.NewInt(1000000)) testTxPoolConfig.AccountQueue = 10 // Keep track of transaction events to ensure all executables get announced @@ -963,7 +1064,7 @@ func TestTransactionPendingLimiting(t *testing.T) { // Keep queuing up transactions and make sure all above a limit are dropped for i := uint64(0); i < testTxPoolConfig.AccountQueue; i++ { - if err := pool.AddRemoteSync(transaction(i, 100000, key)); err != nil { + if err := pool.addRemoteSync(transaction(i, 100000, key)); err != nil { t.Fatalf("tx %d: failed to add transaction: %v", i, err) } if pool.pending[account].Len() != int(i)+1 { @@ -1033,6 +1134,62 @@ func TestTransactionPendingGlobalLimiting(t *testing.T) { } } +// Test the limit on transaction size is enforced correctly. +// This test verifies every transaction having allowed size +// is added to the pool, and longer transactions are rejected. +func TestTransactionAllowedTxSize(t *testing.T) { + t.Parallel() + + // Create a test account and fund it + pool, key := setupTxPool() + defer pool.Stop() + + account := crypto.PubkeyToAddress(key.PublicKey) + pool.currentState.AddBalance(account, big.NewInt(1000000000)) + + // Compute maximal data size for transactions (lower bound). + // + // It is assumed the fields in the transaction (except of the data) are: + // - nonce <= 32 bytes + // - gasPrice <= 32 bytes + // - gasLimit <= 32 bytes + // - recipient == 20 bytes + // - value <= 32 bytes + // - signature == 65 bytes + // All those fields are summed up to at most 213 bytes. + baseSize := uint64(213) + dataSize := txMaxSize - baseSize + + // Try adding a transaction with maximal allowed size + tx := pricedDataTransaction(0, pool.currentMaxGas, big.NewInt(1), key, dataSize) + if err := pool.addRemoteSync(tx); err != nil { + t.Fatalf("failed to add transaction of size %d, close to maximal: %v", int(tx.Size()), err) + } + // Try adding a transaction with random allowed size + if err := pool.addRemoteSync(pricedDataTransaction(1, pool.currentMaxGas, big.NewInt(1), key, uint64(rand.Intn(int(dataSize))))); err != nil { + t.Fatalf("failed to add transaction of random allowed size: %v", err) + } + // Try adding a transaction of minimal not allowed size + if err := pool.addRemoteSync(pricedDataTransaction(2, pool.currentMaxGas, big.NewInt(1), key, txMaxSize)); err == nil { + t.Fatalf("expected rejection on slightly oversize transaction") + } + // Try adding a transaction of random not allowed size + if err := pool.addRemoteSync(pricedDataTransaction(2, pool.currentMaxGas, big.NewInt(1), key, dataSize+1+uint64(rand.Intn(int(10*txMaxSize))))); err == nil { + t.Fatalf("expected rejection on oversize transaction") + } + // Run some sanity checks on the pool internals + pending, queued := pool.Stats() + if pending != 2 { + t.Fatalf("pending transactions mismatched: have %d, want %d", pending, 2) + } + if queued != 0 { + t.Fatalf("queued transactions mismatched: have %d, want %d", queued, 0) + } + if err := validateTxPoolInternals(pool); err != nil { + t.Fatalf("pool internal state corrupted: %v", err) + } +} + // Tests that if transactions start being capped, transactions are also removed from 'all' func TestTransactionCapClearsFromAll(t *testing.T) { t.Parallel() @@ -1458,7 +1615,7 @@ func TestTransactionPoolStableUnderpricing(t *testing.T) { t.Fatalf("pool internal state corrupted: %v", err) } // Ensure that adding high priced transactions drops a cheap, but doesn't produce a gap - if err := pool.AddRemoteSync(pricedTransaction(0, 100000, big.NewInt(3), keys[1])); err != nil { + if err := pool.addRemoteSync(pricedTransaction(0, 100000, big.NewInt(3), keys[1])); err != nil { t.Fatalf("failed to add well priced transaction: %v", err) } pending, queued = pool.Stats() @@ -1476,6 +1633,71 @@ func TestTransactionPoolStableUnderpricing(t *testing.T) { } } +// Tests that the pool rejects duplicate transactions. +func TestTransactionDeduplication(t *testing.T) { + t.Parallel() + + // Create the pool to test the pricing enforcement with + statedb, _ := state.New(common.Hash{}, state.NewDatabase(rawdb.NewMemoryDatabase())) + blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)} + + pool := NewTxPool(testTxPoolConfig, params.TestChainConfig, blockchain) + defer pool.Stop() + + // Create a test account to add transactions with + key, _ := crypto.GenerateKey() + pool.currentState.AddBalance(crypto.PubkeyToAddress(key.PublicKey), big.NewInt(1000000000)) + + // Create a batch of transactions and add a few of them + txs := make([]*types.Transaction, common.LimitThresholdNonceInQueue) + for i := 0; i < len(txs); i++ { + txs[i] = pricedTransaction(uint64(i), 100000, big.NewInt(1), key) + } + var firsts []*types.Transaction + for i := 0; i < len(txs); i += 2 { + firsts = append(firsts, txs[i]) + } + errs := pool.AddRemotesSync(firsts) + if len(errs) != len(firsts) { + t.Fatalf("first add mismatching result count: have %d, want %d", len(errs), len(firsts)) + } + for i, err := range errs { + if err != nil { + t.Errorf("add %d failed: %v", i, err) + } + } + pending, queued := pool.Stats() + if pending != 1 { + t.Fatalf("pending transactions mismatched: have %d, want %d", pending, 1) + } + if queued != len(txs)/2-1 { + t.Fatalf("queued transactions mismatched: have %d, want %d", queued, len(txs)/2-1) + } + // Try to add all of them now and ensure previous ones error out as knowns + errs = pool.AddRemotesSync(txs) + if len(errs) != len(txs) { + t.Fatalf("all add mismatching result count: have %d, want %d", len(errs), len(txs)) + } + for i, err := range errs { + if i%2 == 0 && err == nil { + t.Errorf("add %d succeeded, should have failed as known", i) + } + if i%2 == 1 && err != nil { + t.Errorf("add %d failed: %v", i, err) + } + } + pending, queued = pool.Stats() + if pending != len(txs) { + t.Fatalf("pending transactions mismatched: have %d, want %d", pending, len(txs)) + } + if queued != 0 { + t.Fatalf("queued transactions mismatched: have %d, want %d", queued, 0) + } + if err := validateTxPoolInternals(pool); err != nil { + t.Fatalf("pool internal state corrupted: %v", err) + } +} + // Tests that the pool rejects replacement transactions that don't meet the minimum // price bump required. func TestTransactionReplacement(t *testing.T) { @@ -1502,7 +1724,7 @@ func TestTransactionReplacement(t *testing.T) { price := int64(100) threshold := (price * (100 + int64(testTxPoolConfig.PriceBump))) / 100 - if err := pool.AddRemoteSync(pricedTransaction(0, 100000, big.NewInt(1), key)); err != nil { + if err := pool.addRemoteSync(pricedTransaction(0, 100000, big.NewInt(1), key)); err != nil { t.Fatalf("failed to add original cheap pending transaction: %v", err) } if err := pool.AddRemote(pricedTransaction(0, 100001, big.NewInt(1), key)); err != ErrReplaceUnderpriced { @@ -1515,7 +1737,7 @@ func TestTransactionReplacement(t *testing.T) { t.Fatalf("cheap replacement event firing failed: %v", err) } - if err := pool.AddRemoteSync(pricedTransaction(0, 100000, big.NewInt(price), key)); err != nil { + if err := pool.addRemoteSync(pricedTransaction(0, 100000, big.NewInt(price), key)); err != nil { t.Fatalf("failed to add original proper pending transaction: %v", err) } if err := pool.AddRemote(pricedTransaction(0, 100001, big.NewInt(threshold-1), key)); err != ErrReplaceUnderpriced { @@ -1606,7 +1828,7 @@ func testTransactionJournaling(t *testing.T, nolocals bool) { if err := pool.AddLocal(pricedTransaction(2, 100000, big.NewInt(1), local)); err != nil { t.Fatalf("failed to add local transaction: %v", err) } - if err := pool.AddRemoteSync(pricedTransaction(0, 100000, big.NewInt(1), remote)); err != nil { + if err := pool.addRemoteSync(pricedTransaction(0, 100000, big.NewInt(1), remote)); err != nil { t.Fatalf("failed to add remote transaction: %v", err) } pending, queued := pool.Stats() @@ -1728,6 +1950,24 @@ func TestTransactionStatusCheck(t *testing.T) { } } +// Test the transaction slots consumption is computed correctly +func TestTransactionSlotCount(t *testing.T) { + t.Parallel() + + key, _ := crypto.GenerateKey() + + // Check that an empty transaction consumes a single slot + smallTx := pricedDataTransaction(0, 0, big.NewInt(0), key, 0) + if slots := numSlots(smallTx); slots != 1 { + t.Fatalf("small transactions slot count mismatch: have %d want %d", slots, 1) + } + // Check that a large transaction consumes the correct number of slots + bigTx := pricedDataTransaction(0, 0, big.NewInt(0), key, uint64(10*txSlotSize)) + if slots := numSlots(bigTx); slots != 11 { + t.Fatalf("big transactions slot count mismatch: have %d want %d", slots, 11) + } +} + // Benchmarks the speed of validating the contents of the pending queue of the // transaction pool. func BenchmarkPendingDemotion100(b *testing.B) { benchmarkPendingDemotion(b, 100) } @@ -1739,7 +1979,7 @@ func benchmarkPendingDemotion(b *testing.B, size int) { pool, key := setupTxPool() defer pool.Stop() - account, _ := deriveSender(transaction(0, 0, key)) + account := crypto.PubkeyToAddress(key.PublicKey) pool.currentState.AddBalance(account, big.NewInt(1000000)) for i := 0; i < size; i++ { @@ -1764,12 +2004,12 @@ func benchmarkFuturePromotion(b *testing.B, size int) { pool, key := setupTxPool() defer pool.Stop() - account, _ := deriveSender(transaction(0, 0, key)) + account := crypto.PubkeyToAddress(key.PublicKey) pool.currentState.AddBalance(account, big.NewInt(1000000)) for i := 0; i < size; i++ { tx := transaction(uint64(1+i), 100000, key) - pool.enqueueTx(tx.Hash(), tx) + pool.enqueueTx(tx.Hash(), tx, false, true) } // Benchmark the speed of pool validation b.ResetTimer() @@ -1779,16 +2019,20 @@ func benchmarkFuturePromotion(b *testing.B, size int) { } // Benchmarks the speed of batched transaction insertion. -func BenchmarkPoolBatchInsert100(b *testing.B) { benchmarkPoolBatchInsert(b, 100) } -func BenchmarkPoolBatchInsert1000(b *testing.B) { benchmarkPoolBatchInsert(b, 1000) } -func BenchmarkPoolBatchInsert10000(b *testing.B) { benchmarkPoolBatchInsert(b, 10000) } +func BenchmarkPoolBatchInsert100(b *testing.B) { benchmarkPoolBatchInsert(b, 100, false) } +func BenchmarkPoolBatchInsert1000(b *testing.B) { benchmarkPoolBatchInsert(b, 1000, false) } +func BenchmarkPoolBatchInsert10000(b *testing.B) { benchmarkPoolBatchInsert(b, 10000, false) } -func benchmarkPoolBatchInsert(b *testing.B, size int) { +func BenchmarkPoolBatchLocalInsert100(b *testing.B) { benchmarkPoolBatchInsert(b, 100, true) } +func BenchmarkPoolBatchLocalInsert1000(b *testing.B) { benchmarkPoolBatchInsert(b, 1000, true) } +func BenchmarkPoolBatchLocalInsert10000(b *testing.B) { benchmarkPoolBatchInsert(b, 10000, true) } + +func benchmarkPoolBatchInsert(b *testing.B, size int, local bool) { // Generate a batch of transactions to enqueue into the pool pool, key := setupTxPool() defer pool.Stop() - account, _ := deriveSender(transaction(0, 0, key)) + account := crypto.PubkeyToAddress(key.PublicKey) pool.currentState.AddBalance(account, big.NewInt(1000000)) batches := make([]types.Transactions, b.N) @@ -1801,6 +2045,45 @@ func benchmarkPoolBatchInsert(b *testing.B, size int) { // Benchmark importing the transactions into the queue b.ResetTimer() for _, batch := range batches { - pool.AddRemotes(batch) + if local { + pool.AddLocals(batch) + } else { + pool.AddRemotes(batch) + } + } +} + +func BenchmarkInsertRemoteWithAllLocals(b *testing.B) { + // Allocate keys for testing + key, _ := crypto.GenerateKey() + account := crypto.PubkeyToAddress(key.PublicKey) + + remoteKey, _ := crypto.GenerateKey() + remoteAddr := crypto.PubkeyToAddress(remoteKey.PublicKey) + + locals := make([]*types.Transaction, 4096+1024) // Occupy all slots + for i := 0; i < len(locals); i++ { + locals[i] = transaction(uint64(i), 100000, key) + } + remotes := make([]*types.Transaction, 1000) + for i := 0; i < len(remotes); i++ { + remotes[i] = pricedTransaction(uint64(i), 100000, big.NewInt(2), remoteKey) // Higher gasprice + } + // Benchmark importing the transactions into the queue + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StopTimer() + pool, _ := setupTxPool() + pool.currentState.AddBalance(account, big.NewInt(100000000)) + for _, local := range locals { + pool.AddLocal(local) + } + b.StartTimer() + // Assign a high enough balance for testing + pool.currentState.AddBalance(remoteAddr, big.NewInt(100000000)) + for i := 0; i < len(remotes); i++ { + pool.AddRemotes([]*types.Transaction{remotes[i]}) + } + pool.Stop() } } diff --git a/core/types/access_list_tx.go b/core/types/access_list_tx.go new file mode 100644 index 0000000000..f80044e108 --- /dev/null +++ b/core/types/access_list_tx.go @@ -0,0 +1,115 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package types + +import ( + "math/big" + + "github.com/XinFinOrg/XDPoSChain/common" +) + +//go:generate gencodec -type AccessTuple -out gen_access_tuple.go + +// AccessList is an EIP-2930 access list. +type AccessList []AccessTuple + +// AccessTuple is the element type of an access list. +type AccessTuple struct { + Address common.Address `json:"address" gencodec:"required"` + StorageKeys []common.Hash `json:"storageKeys" gencodec:"required"` +} + +// StorageKeys returns the total number of storage keys in the access list. +func (al AccessList) StorageKeys() int { + sum := 0 + for _, tuple := range al { + sum += len(tuple.StorageKeys) + } + return sum +} + +// AccessListTx is the data of EIP-2930 access list transactions. +type AccessListTx struct { + ChainID *big.Int // destination chain ID + Nonce uint64 // nonce of sender account + GasPrice *big.Int // wei per gas + Gas uint64 // gas limit + To *common.Address `rlp:"nil"` // nil means contract creation + Value *big.Int // wei amount + Data []byte // contract invocation input data + AccessList AccessList // EIP-2930 access list + V, R, S *big.Int // signature values +} + +// copy creates a deep copy of the transaction data and initializes all fields. +func (tx *AccessListTx) copy() TxData { + cpy := &AccessListTx{ + Nonce: tx.Nonce, + To: tx.To, // TODO: copy pointed-to address + Data: common.CopyBytes(tx.Data), + Gas: tx.Gas, + // These are copied below. + AccessList: make(AccessList, len(tx.AccessList)), + Value: new(big.Int), + ChainID: new(big.Int), + GasPrice: new(big.Int), + V: new(big.Int), + R: new(big.Int), + S: new(big.Int), + } + copy(cpy.AccessList, tx.AccessList) + if tx.Value != nil { + cpy.Value.Set(tx.Value) + } + if tx.ChainID != nil { + cpy.ChainID.Set(tx.ChainID) + } + if tx.GasPrice != nil { + cpy.GasPrice.Set(tx.GasPrice) + } + if tx.V != nil { + cpy.V.Set(tx.V) + } + if tx.R != nil { + cpy.R.Set(tx.R) + } + if tx.S != nil { + cpy.S.Set(tx.S) + } + return cpy +} + +// accessors for innerTx. + +func (tx *AccessListTx) txType() byte { return AccessListTxType } +func (tx *AccessListTx) chainID() *big.Int { return tx.ChainID } +func (tx *AccessListTx) protected() bool { return true } +func (tx *AccessListTx) accessList() AccessList { return tx.AccessList } +func (tx *AccessListTx) data() []byte { return tx.Data } +func (tx *AccessListTx) gas() uint64 { return tx.Gas } +func (tx *AccessListTx) gasPrice() *big.Int { return tx.GasPrice } +func (tx *AccessListTx) value() *big.Int { return tx.Value } +func (tx *AccessListTx) nonce() uint64 { return tx.Nonce } +func (tx *AccessListTx) to() *common.Address { return tx.To } + +func (tx *AccessListTx) rawSignatureValues() (v, r, s *big.Int) { + return tx.V, tx.R, tx.S +} + +func (tx *AccessListTx) setSignatureValues(chainID, v, r, s *big.Int) { + tx.ChainID, tx.V, tx.R, tx.S = chainID, v, r, s +} diff --git a/core/types/block.go b/core/types/block.go index cbeb565373..071666a801 100644 --- a/core/types/block.go +++ b/core/types/block.go @@ -29,7 +29,6 @@ import ( "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/common/hexutil" - "github.com/XinFinOrg/XDPoSChain/crypto/sha3" "github.com/XinFinOrg/XDPoSChain/rlp" ) @@ -155,13 +154,6 @@ func (h *Header) Size() common.StorageSize { return common.StorageSize(unsafe.Sizeof(*h)) + common.StorageSize(len(h.Extra)+(h.Difficulty.BitLen()+h.Number.BitLen()+h.Time.BitLen())/8) } -func rlpHash(x interface{}) (h common.Hash) { - hw := sha3.NewKeccak256() - rlp.Encode(hw, x) - hw.Sum(h[:0]) - return h -} - // Body is a simple (mutable, non-safe) data container for storing and moving // a block's data contents (transactions and uncles) together. type Body struct { diff --git a/core/types/block_test.go b/core/types/block_test.go index c95aaae717..3cb50180b9 100644 --- a/core/types/block_test.go +++ b/core/types/block_test.go @@ -17,13 +17,18 @@ package types import ( + "bytes" + "hash" "math/big" + "reflect" "testing" - "bytes" "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/common/math" + "github.com/XinFinOrg/XDPoSChain/crypto" + "github.com/XinFinOrg/XDPoSChain/params" "github.com/XinFinOrg/XDPoSChain/rlp" - "reflect" + "golang.org/x/crypto/sha3" ) // from bcValidBlockTest.json, "SimpleTx" @@ -59,3 +64,152 @@ func TestBlockEncoding(t *testing.T) { t.Errorf("encoded block mismatch:\ngot: %x\nwant: %x", ourBlockEnc, blockEnc) } } + +func TestEIP2718BlockEncoding(t *testing.T) { + blockEnc := common.FromHex("f9031cf90214a00000000000000000000000000000000000000000000000000000000000000000a01dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347948888f1f195afa192cfee860698584c030f4c9db1a0ef1552a40b7165c3cd773806b9e0c165b75356e0314bf0706f279c729f51e017a0e6e49996c7ec59f7a23d22b83239a60151512c65613bf84a0d7da336399ebc4aa0cafe75574d59780665a97fbfd11365c7545aa8f1abf4e5e12e8243334ef7286bb901000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000083020000820200832fefd882a410845506eb0796636f6f6c65737420626c6f636b206f6e20636861696ea0bd4472abb6659ebe3ee06ee4d7b72a00a9f4d001caca51342001075469aff49888a13a5a8c8f2bb1c4808080f90101f85f800a82c35094095e7baea6a6c7c4c2dfeb977efac326af552d870a801ba09bea4c4daac7c7c52e093e6a4c35dbbcf8856f1af7b059ba20253e70848d094fa08a8fae537ce25ed8cb5af9adac3f141af69bd515bd2ba031522df09b97dd72b1b89e01f89b01800a8301e24194095e7baea6a6c7c4c2dfeb977efac326af552d878080f838f7940000000000000000000000000000000000000001e1a0000000000000000000000000000000000000000000000000000000000000000001a03dbacc8d0259f2508625e97fdfc57cd85fdd16e5821bc2c10bdd1a52649e8335a0476e10695b183a87b0aa292a7f4b78ef0c3fbe62aa2c42c84e1d9c3da159ef14c0") + var block Block + if err := rlp.DecodeBytes(blockEnc, &block); err != nil { + t.Fatal("decode error: ", err) + } + + check := func(f string, got, want interface{}) { + if !reflect.DeepEqual(got, want) { + t.Errorf("%s mismatch: got %v, want %v", f, got, want) + } + } + check("Difficulty", block.Difficulty(), big.NewInt(131072)) + check("GasLimit", block.GasLimit(), uint64(3141592)) + check("GasUsed", block.GasUsed(), uint64(42000)) + check("Coinbase", block.Coinbase(), common.HexToAddress("8888f1f195afa192cfee860698584c030f4c9db1")) + check("MixDigest", block.MixDigest(), common.HexToHash("bd4472abb6659ebe3ee06ee4d7b72a00a9f4d001caca51342001075469aff498")) + check("Root", block.Root(), common.HexToHash("ef1552a40b7165c3cd773806b9e0c165b75356e0314bf0706f279c729f51e017")) + check("Nonce", block.Nonce(), uint64(0xa13a5a8c8f2bb1c4)) + check("Time", block.Time().Uint64(), uint64(1426516743)) + check("Size", block.Size(), common.StorageSize(len(blockEnc))) + + // Create legacy tx. + to := common.HexToAddress("095e7baea6a6c7c4c2dfeb977efac326af552d87") + tx1 := NewTx(&LegacyTx{ + Nonce: 0, + To: &to, + Value: big.NewInt(10), + Gas: 50000, + GasPrice: big.NewInt(10), + }) + sig := common.Hex2Bytes("9bea4c4daac7c7c52e093e6a4c35dbbcf8856f1af7b059ba20253e70848d094f8a8fae537ce25ed8cb5af9adac3f141af69bd515bd2ba031522df09b97dd72b100") + tx1, _ = tx1.WithSignature(HomesteadSigner{}, sig) + + // Create ACL tx. + addr := common.HexToAddress("0x0000000000000000000000000000000000000001") + tx2 := NewTx(&AccessListTx{ + ChainID: big.NewInt(1), + Nonce: 0, + To: &to, + Gas: 123457, + GasPrice: big.NewInt(10), + AccessList: AccessList{{Address: addr, StorageKeys: []common.Hash{{0}}}}, + }) + sig2 := common.Hex2Bytes("3dbacc8d0259f2508625e97fdfc57cd85fdd16e5821bc2c10bdd1a52649e8335476e10695b183a87b0aa292a7f4b78ef0c3fbe62aa2c42c84e1d9c3da159ef1401") + tx2, _ = tx2.WithSignature(NewEIP2930Signer(big.NewInt(1)), sig2) + + check("len(Transactions)", len(block.Transactions()), 2) + check("Transactions[0].Hash", block.Transactions()[0].Hash(), tx1.Hash()) + check("Transactions[1].Hash", block.Transactions()[1].Hash(), tx2.Hash()) + check("Transactions[1].Type()", block.Transactions()[1].Type(), uint8(AccessListTxType)) + + ourBlockEnc, err := rlp.EncodeToBytes(&block) + if err != nil { + t.Fatal("encode error: ", err) + } + if !bytes.Equal(ourBlockEnc, blockEnc) { + t.Errorf("encoded block mismatch:\ngot: %x\nwant: %x", ourBlockEnc, blockEnc) + } +} + +func TestUncleHash(t *testing.T) { + uncles := make([]*Header, 0) + h := CalcUncleHash(uncles) + exp := common.HexToHash("1dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347") + if h != exp { + t.Fatalf("empty uncle hash is wrong, got %x != %x", h, exp) + } +} + +var benchBuffer = bytes.NewBuffer(make([]byte, 0, 32000)) + +func BenchmarkEncodeBlock(b *testing.B) { + block := makeBenchBlock() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + benchBuffer.Reset() + if err := rlp.Encode(benchBuffer, block); err != nil { + b.Fatal(err) + } + } +} + +// testHasher is the helper tool for transaction/receipt list hashing. +// The original hasher is trie, in order to get rid of import cycle, +// use the testing hasher instead. +type testHasher struct { + hasher hash.Hash +} + +func newHasher() *testHasher { + return &testHasher{hasher: sha3.NewLegacyKeccak256()} +} + +func (h *testHasher) Reset() { + h.hasher.Reset() +} + +func (h *testHasher) Update(key, val []byte) { + h.hasher.Write(key) + h.hasher.Write(val) +} + +func (h *testHasher) Hash() common.Hash { + return common.BytesToHash(h.hasher.Sum(nil)) +} + +func makeBenchBlock() *Block { + var ( + key, _ = crypto.GenerateKey() + txs = make([]*Transaction, 70) + receipts = make([]*Receipt, len(txs)) + signer = LatestSigner(params.TestChainConfig) + uncles = make([]*Header, 3) + ) + header := &Header{ + Difficulty: math.BigPow(11, 11), + Number: math.BigPow(2, 9), + GasLimit: 12345678, + GasUsed: 1476322, + Time: big.NewInt(9876543), + Extra: []byte("coolest block on chain"), + } + for i := range txs { + amount := math.BigPow(2, int64(i)) + price := big.NewInt(300000) + data := make([]byte, 100) + tx := NewTransaction(uint64(i), common.Address{}, amount, 123457, price, data) + signedTx, err := SignTx(tx, signer, key) + if err != nil { + panic(err) + } + txs[i] = signedTx + receipts[i] = NewReceipt(make([]byte, 32), false, tx.Gas()) + } + for i := range uncles { + uncles[i] = &Header{ + Difficulty: math.BigPow(11, 11), + Number: math.BigPow(2, 9), + GasLimit: 12345678, + GasUsed: 1476322, + Time: big.NewInt(9876543), + Extra: []byte("benchmark uncle"), + } + } + return NewBlock(header, txs, uncles, receipts) +} diff --git a/core/types/consensus_v2_test.go b/core/types/consensus_v2_test.go index 55687f150b..890dcba628 100644 --- a/core/types/consensus_v2_test.go +++ b/core/types/consensus_v2_test.go @@ -1,6 +1,7 @@ package types import ( + "errors" "fmt" "math/big" "reflect" @@ -15,11 +16,11 @@ import ( // Decode extra fields for consensus version >= 2 (XDPoS 2.0 and future versions) func DecodeBytesExtraFields(b []byte, val interface{}) error { if len(b) == 0 { - return fmt.Errorf("extra field is 0 length") + return errors.New("extra field is 0 length") } switch b[0] { case 1: - return fmt.Errorf("consensus version 1 is not applicable for decoding extra fields") + return errors.New("consensus version 1 is not applicable for decoding extra fields") case 2: return rlp.DecodeBytes(b[1:], val) default: diff --git a/core/types/gen_access_tuple.go b/core/types/gen_access_tuple.go new file mode 100644 index 0000000000..d23b32c009 --- /dev/null +++ b/core/types/gen_access_tuple.go @@ -0,0 +1,43 @@ +// Code generated by github.com/fjl/gencodec. DO NOT EDIT. + +package types + +import ( + "encoding/json" + "errors" + + "github.com/XinFinOrg/XDPoSChain/common" +) + +// MarshalJSON marshals as JSON. +func (a AccessTuple) MarshalJSON() ([]byte, error) { + type AccessTuple struct { + Address common.Address `json:"address" gencodec:"required"` + StorageKeys []common.Hash `json:"storageKeys" gencodec:"required"` + } + var enc AccessTuple + enc.Address = a.Address + enc.StorageKeys = a.StorageKeys + return json.Marshal(&enc) +} + +// UnmarshalJSON unmarshals from JSON. +func (a *AccessTuple) UnmarshalJSON(input []byte) error { + type AccessTuple struct { + Address *common.Address `json:"address" gencodec:"required"` + StorageKeys []common.Hash `json:"storageKeys" gencodec:"required"` + } + var dec AccessTuple + if err := json.Unmarshal(input, &dec); err != nil { + return err + } + if dec.Address == nil { + return errors.New("missing required field 'address' for AccessTuple") + } + a.Address = *dec.Address + if dec.StorageKeys == nil { + return errors.New("missing required field 'storageKeys' for AccessTuple") + } + a.StorageKeys = dec.StorageKeys + return nil +} diff --git a/core/types/gen_receipt_json.go b/core/types/gen_receipt_json.go index ee3e8c7c2a..b72ef0270d 100644 --- a/core/types/gen_receipt_json.go +++ b/core/types/gen_receipt_json.go @@ -13,8 +13,10 @@ import ( var _ = (*receiptMarshaling)(nil) +// MarshalJSON marshals as JSON. func (r Receipt) MarshalJSON() ([]byte, error) { type Receipt struct { + Type hexutil.Uint64 `json:"type,omitempty"` PostState hexutil.Bytes `json:"root"` Status hexutil.Uint `json:"status"` CumulativeGasUsed hexutil.Uint64 `json:"cumulativeGasUsed" gencodec:"required"` @@ -28,6 +30,7 @@ func (r Receipt) MarshalJSON() ([]byte, error) { TransactionIndex hexutil.Uint `json:"transactionIndex"` } var enc Receipt + enc.Type = hexutil.Uint64(r.Type) enc.PostState = r.PostState enc.Status = hexutil.Uint(r.Status) enc.CumulativeGasUsed = hexutil.Uint64(r.CumulativeGasUsed) @@ -42,8 +45,10 @@ func (r Receipt) MarshalJSON() ([]byte, error) { return json.Marshal(&enc) } +// UnmarshalJSON unmarshals from JSON. func (r *Receipt) UnmarshalJSON(input []byte) error { type Receipt struct { + Type *hexutil.Uint64 `json:"type,omitempty"` PostState *hexutil.Bytes `json:"root"` Status *hexutil.Uint `json:"status"` CumulativeGasUsed *hexutil.Uint64 `json:"cumulativeGasUsed" gencodec:"required"` @@ -60,6 +65,9 @@ func (r *Receipt) UnmarshalJSON(input []byte) error { if err := json.Unmarshal(input, &dec); err != nil { return err } + if dec.Type != nil { + r.Type = uint8(*dec.Type) + } if dec.PostState != nil { r.PostState = *dec.PostState } diff --git a/core/types/gen_tx_json.go b/core/types/gen_tx_json.go deleted file mode 100644 index 71f4c7d373..0000000000 --- a/core/types/gen_tx_json.go +++ /dev/null @@ -1,101 +0,0 @@ -// Code generated by github.com/fjl/gencodec. DO NOT EDIT. - -package types - -import ( - "encoding/json" - "errors" - "math/big" - - "github.com/XinFinOrg/XDPoSChain/common" - "github.com/XinFinOrg/XDPoSChain/common/hexutil" -) - -var _ = (*txdataMarshaling)(nil) - -// MarshalJSON marshals as JSON. -func (t txdata) MarshalJSON() ([]byte, error) { - type txdata struct { - AccountNonce hexutil.Uint64 `json:"nonce" gencodec:"required"` - Price *hexutil.Big `json:"gasPrice" gencodec:"required"` - GasLimit hexutil.Uint64 `json:"gas" gencodec:"required"` - Recipient *common.Address `json:"to" rlp:"nil"` - Amount *hexutil.Big `json:"value" gencodec:"required"` - Payload hexutil.Bytes `json:"input" gencodec:"required"` - V *hexutil.Big `json:"v" gencodec:"required"` - R *hexutil.Big `json:"r" gencodec:"required"` - S *hexutil.Big `json:"s" gencodec:"required"` - Hash *common.Hash `json:"hash" rlp:"-"` - } - var enc txdata - enc.AccountNonce = hexutil.Uint64(t.AccountNonce) - enc.Price = (*hexutil.Big)(t.Price) - enc.GasLimit = hexutil.Uint64(t.GasLimit) - enc.Recipient = t.Recipient - enc.Amount = (*hexutil.Big)(t.Amount) - enc.Payload = t.Payload - enc.V = (*hexutil.Big)(t.V) - enc.R = (*hexutil.Big)(t.R) - enc.S = (*hexutil.Big)(t.S) - enc.Hash = t.Hash - return json.Marshal(&enc) -} - -// UnmarshalJSON unmarshals from JSON. -func (t *txdata) UnmarshalJSON(input []byte) error { - type txdata struct { - AccountNonce *hexutil.Uint64 `json:"nonce" gencodec:"required"` - Price *hexutil.Big `json:"gasPrice" gencodec:"required"` - GasLimit *hexutil.Uint64 `json:"gas" gencodec:"required"` - Recipient *common.Address `json:"to" rlp:"nil"` - Amount *hexutil.Big `json:"value" gencodec:"required"` - Payload *hexutil.Bytes `json:"input" gencodec:"required"` - V *hexutil.Big `json:"v" gencodec:"required"` - R *hexutil.Big `json:"r" gencodec:"required"` - S *hexutil.Big `json:"s" gencodec:"required"` - Hash *common.Hash `json:"hash" rlp:"-"` - } - var dec txdata - if err := json.Unmarshal(input, &dec); err != nil { - return err - } - if dec.AccountNonce == nil { - return errors.New("missing required field 'nonce' for txdata") - } - t.AccountNonce = uint64(*dec.AccountNonce) - if dec.Price == nil { - return errors.New("missing required field 'gasPrice' for txdata") - } - t.Price = (*big.Int)(dec.Price) - if dec.GasLimit == nil { - return errors.New("missing required field 'gas' for txdata") - } - t.GasLimit = uint64(*dec.GasLimit) - if dec.Recipient != nil { - t.Recipient = dec.Recipient - } - if dec.Amount == nil { - return errors.New("missing required field 'value' for txdata") - } - t.Amount = (*big.Int)(dec.Amount) - if dec.Payload == nil { - return errors.New("missing required field 'input' for txdata") - } - t.Payload = *dec.Payload - if dec.V == nil { - return errors.New("missing required field 'v' for txdata") - } - t.V = (*big.Int)(dec.V) - if dec.R == nil { - return errors.New("missing required field 'r' for txdata") - } - t.R = (*big.Int)(dec.R) - if dec.S == nil { - return errors.New("missing required field 's' for txdata") - } - t.S = (*big.Int)(dec.S) - if dec.Hash != nil { - t.Hash = dec.Hash - } - return nil -} diff --git a/core/types/hashing.go b/core/types/hashing.go new file mode 100644 index 0000000000..73a5cfe53e --- /dev/null +++ b/core/types/hashing.go @@ -0,0 +1,59 @@ +// Copyright 2014 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package types + +import ( + "bytes" + "sync" + + "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/crypto" + "github.com/XinFinOrg/XDPoSChain/rlp" + "golang.org/x/crypto/sha3" +) + +// hasherPool holds LegacyKeccak256 hashers for rlpHash. +var hasherPool = sync.Pool{ + New: func() interface{} { return sha3.NewLegacyKeccak256() }, +} + +// deriveBufferPool holds temporary encoder buffers for DeriveSha and TX encoding. +var encodeBufferPool = sync.Pool{ + New: func() interface{} { return new(bytes.Buffer) }, +} + +// rlpHash encodes x and hashes the encoded bytes. +func rlpHash(x interface{}) (h common.Hash) { + sha := hasherPool.Get().(crypto.KeccakState) + defer hasherPool.Put(sha) + sha.Reset() + rlp.Encode(sha, x) + sha.Read(h[:]) + return h +} + +// prefixedRlpHash writes the prefix into the hasher before rlp-encoding x. +// It's used for typed transactions. +func prefixedRlpHash(prefix byte, x interface{}) (h common.Hash) { + sha := hasherPool.Get().(crypto.KeccakState) + defer hasherPool.Put(sha) + sha.Reset() + sha.Write([]byte{prefix}) + rlp.Encode(sha, x) + sha.Read(h[:]) + return h +} diff --git a/core/types/legacy_tx.go b/core/types/legacy_tx.go new file mode 100644 index 0000000000..146a1e2877 --- /dev/null +++ b/core/types/legacy_tx.go @@ -0,0 +1,111 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package types + +import ( + "math/big" + + "github.com/XinFinOrg/XDPoSChain/common" +) + +// LegacyTx is the transaction data of regular Ethereum transactions. +type LegacyTx struct { + Nonce uint64 // nonce of sender account + GasPrice *big.Int // wei per gas + Gas uint64 // gas limit + To *common.Address `rlp:"nil"` // nil means contract creation + Value *big.Int // wei amount + Data []byte // contract invocation input data + V, R, S *big.Int // signature values +} + +// NewTransaction creates an unsigned legacy transaction. +// Deprecated: use NewTx instead. +func NewTransaction(nonce uint64, to common.Address, amount *big.Int, gasLimit uint64, gasPrice *big.Int, data []byte) *Transaction { + return NewTx(&LegacyTx{ + Nonce: nonce, + To: &to, + Value: amount, + Gas: gasLimit, + GasPrice: gasPrice, + Data: data, + }) +} + +// NewContractCreation creates an unsigned legacy transaction. +// Deprecated: use NewTx instead. +func NewContractCreation(nonce uint64, amount *big.Int, gasLimit uint64, gasPrice *big.Int, data []byte) *Transaction { + return NewTx(&LegacyTx{ + Nonce: nonce, + Value: amount, + Gas: gasLimit, + GasPrice: gasPrice, + Data: data, + }) +} + +// copy creates a deep copy of the transaction data and initializes all fields. +func (tx *LegacyTx) copy() TxData { + cpy := &LegacyTx{ + Nonce: tx.Nonce, + To: tx.To, // TODO: copy pointed-to address + Data: common.CopyBytes(tx.Data), + Gas: tx.Gas, + // These are initialized below. + Value: new(big.Int), + GasPrice: new(big.Int), + V: new(big.Int), + R: new(big.Int), + S: new(big.Int), + } + if tx.Value != nil { + cpy.Value.Set(tx.Value) + } + if tx.GasPrice != nil { + cpy.GasPrice.Set(tx.GasPrice) + } + if tx.V != nil { + cpy.V.Set(tx.V) + } + if tx.R != nil { + cpy.R.Set(tx.R) + } + if tx.S != nil { + cpy.S.Set(tx.S) + } + return cpy +} + +// accessors for innerTx. + +func (tx *LegacyTx) txType() byte { return LegacyTxType } +func (tx *LegacyTx) chainID() *big.Int { return deriveChainId(tx.V) } +func (tx *LegacyTx) accessList() AccessList { return nil } +func (tx *LegacyTx) data() []byte { return tx.Data } +func (tx *LegacyTx) gas() uint64 { return tx.Gas } +func (tx *LegacyTx) gasPrice() *big.Int { return tx.GasPrice } +func (tx *LegacyTx) value() *big.Int { return tx.Value } +func (tx *LegacyTx) nonce() uint64 { return tx.Nonce } +func (tx *LegacyTx) to() *common.Address { return tx.To } + +func (tx *LegacyTx) rawSignatureValues() (v, r, s *big.Int) { + return tx.V, tx.R, tx.S +} + +func (tx *LegacyTx) setSignatureValues(chainID, v, r, s *big.Int) { + tx.V, tx.R, tx.S = v, r, s +} diff --git a/core/types/log_test.go b/core/types/log_test.go index b6ab7810b7..fbd26f72ff 100644 --- a/core/types/log_test.go +++ b/core/types/log_test.go @@ -18,7 +18,7 @@ package types import ( "encoding/json" - "fmt" + "errors" "reflect" "testing" @@ -97,7 +97,7 @@ var unmarshalLogTests = map[string]struct { }, "missing data": { input: `{"address":"0xecf8f87f810ecf450940c9f60066b4a7a501d6a7","blockHash":"0x656c34545f90a730a19008c0e7a7cd4fb3895064b48d6d69761bd5abad681056","blockNumber":"0x1ecfa4","logIndex":"0x2","topics":["0xddf252ad1be2c89b69c2b068fc378daa952ba7f163c4a11628f55a4df523b3ef","0x00000000000000000000000080b2c9d7cbbf30a1b0fc8983c647d754c6525615","0x000000000000000000000000f9dff387dcb5cc4cca5b91adb07a95f54e9f1bb6"],"transactionHash":"0x3b198bfd5d2907285af009e9ae84a0ecd63677110d89d7e030251acb87f6487e","transactionIndex":"0x3"}`, - wantError: fmt.Errorf("missing required field 'data' for Log"), + wantError: errors.New("missing required field 'data' for Log"), }, } diff --git a/core/types/receipt.go b/core/types/receipt.go index cec362f5a6..d1b3cc46bf 100644 --- a/core/types/receipt.go +++ b/core/types/receipt.go @@ -18,6 +18,7 @@ package types import ( "bytes" + "errors" "fmt" "io" "math/big" @@ -35,6 +36,9 @@ var ( receiptStatusSuccessfulRLP = []byte{0x01} ) +// This error is returned when a typed receipt is decoded, but the string is empty. +var errEmptyTypedReceipt = errors.New("empty typed receipt bytes") + const ( // ReceiptStatusFailed is the status code of a transaction if execution failed. ReceiptStatusFailed = uint(0) @@ -46,6 +50,7 @@ const ( // Receipt represents the results of a transaction. type Receipt struct { // Consensus fields: These fields are defined by the Yellow Paper + Type uint8 `json:"type,omitempty"` PostState []byte `json:"root"` Status uint `json:"status"` CumulativeGasUsed uint64 `json:"cumulativeGasUsed" gencodec:"required"` @@ -66,6 +71,7 @@ type Receipt struct { } type receiptMarshaling struct { + Type hexutil.Uint64 PostState hexutil.Bytes Status hexutil.Uint CumulativeGasUsed hexutil.Uint64 @@ -82,7 +88,18 @@ type receiptRLP struct { Logs []*Log } -type receiptStorageRLP struct { +// v4StoredReceiptRLP is the storage encoding of a receipt used in database version 4. +type v4StoredReceiptRLP struct { + PostStateOrStatus []byte + CumulativeGasUsed uint64 + TxHash common.Hash + ContractAddress common.Address + Logs []*LogForStorage + GasUsed uint64 +} + +// v3StoredReceiptRLP is the original storage encoding of a receipt including some unnecessary fields. +type v3StoredReceiptRLP struct { PostStateOrStatus []byte CumulativeGasUsed uint64 Bloom Bloom @@ -93,8 +110,13 @@ type receiptStorageRLP struct { } // NewReceipt creates a barebone transaction receipt, copying the init fields. +// Deprecated: create receipts using a struct literal instead. func NewReceipt(root []byte, failed bool, cumulativeGasUsed uint64) *Receipt { - r := &Receipt{PostState: common.CopyBytes(root), CumulativeGasUsed: cumulativeGasUsed} + r := &Receipt{ + Type: LegacyTxType, + PostState: common.CopyBytes(root), + CumulativeGasUsed: cumulativeGasUsed, + } if failed { r.Status = ReceiptStatusFailed } else { @@ -106,21 +128,65 @@ func NewReceipt(root []byte, failed bool, cumulativeGasUsed uint64) *Receipt { // EncodeRLP implements rlp.Encoder, and flattens the consensus fields of a receipt // into an RLP stream. If no post state is present, byzantium fork is assumed. func (r *Receipt) EncodeRLP(w io.Writer) error { - return rlp.Encode(w, &receiptRLP{r.statusEncoding(), r.CumulativeGasUsed, r.Bloom, r.Logs}) + data := &receiptRLP{r.statusEncoding(), r.CumulativeGasUsed, r.Bloom, r.Logs} + if r.Type == LegacyTxType { + return rlp.Encode(w, data) + } + // It's an EIP-2718 typed TX receipt. + if r.Type != AccessListTxType { + return ErrTxTypeNotSupported + } + buf := encodeBufferPool.Get().(*bytes.Buffer) + defer encodeBufferPool.Put(buf) + buf.Reset() + buf.WriteByte(r.Type) + if err := rlp.Encode(buf, data); err != nil { + return err + } + return rlp.Encode(w, buf.Bytes()) } // DecodeRLP implements rlp.Decoder, and loads the consensus fields of a receipt // from an RLP stream. func (r *Receipt) DecodeRLP(s *rlp.Stream) error { - var dec receiptRLP - if err := s.Decode(&dec); err != nil { + kind, _, err := s.Kind() + switch { + case err != nil: return err + case kind == rlp.List: + // It's a legacy receipt. + var dec receiptRLP + if err := s.Decode(&dec); err != nil { + return err + } + r.Type = LegacyTxType + return r.setFromRLP(dec) + case kind == rlp.String: + // It's an EIP-2718 typed tx receipt. + b, err := s.Bytes() + if err != nil { + return err + } + if len(b) == 0 { + return errEmptyTypedReceipt + } + r.Type = b[0] + if r.Type == AccessListTxType { + var dec receiptRLP + if err := rlp.DecodeBytes(b[1:], &dec); err != nil { + return err + } + return r.setFromRLP(dec) + } + return ErrTxTypeNotSupported + default: + return rlp.ErrExpectedList } - if err := r.setStatus(dec.PostStateOrStatus); err != nil { - return err - } - r.CumulativeGasUsed, r.Bloom, r.Logs = dec.CumulativeGasUsed, dec.Bloom, dec.Logs - return nil +} + +func (r *Receipt) setFromRLP(data receiptRLP) error { + r.CumulativeGasUsed, r.Bloom, r.Logs = data.CumulativeGasUsed, data.Bloom, data.Logs + return r.setStatus(data.PostStateOrStatus) } func (r *Receipt) setStatus(postStateOrStatus []byte) error { @@ -151,7 +217,6 @@ func (r *Receipt) statusEncoding() []byte { // to approximate and limit the memory consumption of various caches. func (r *Receipt) Size() common.StorageSize { size := common.StorageSize(unsafe.Sizeof(*r)) + common.StorageSize(len(r.PostState)) - size += common.StorageSize(len(r.Logs)) * common.StorageSize(unsafe.Sizeof(Log{})) for _, log := range r.Logs { size += common.StorageSize(len(log.Topics)*common.HashLength + len(log.Data)) @@ -174,7 +239,7 @@ type ReceiptForStorage Receipt // EncodeRLP implements rlp.Encoder, and flattens all content fields of a receipt // into an RLP stream. func (r *ReceiptForStorage) EncodeRLP(w io.Writer) error { - enc := &receiptStorageRLP{ + enc := &v3StoredReceiptRLP{ PostStateOrStatus: (*Receipt)(r).statusEncoding(), CumulativeGasUsed: r.CumulativeGasUsed, Bloom: r.Bloom, @@ -192,25 +257,64 @@ func (r *ReceiptForStorage) EncodeRLP(w io.Writer) error { // DecodeRLP implements rlp.Decoder, and loads both consensus and implementation // fields of a receipt from an RLP stream. func (r *ReceiptForStorage) DecodeRLP(s *rlp.Stream) error { - var dec receiptStorageRLP - if err := s.Decode(&dec); err != nil { + // Retrieve the entire receipt blob as we need to try multiple decoders + blob, err := s.Raw() + if err != nil { return err } - if err := (*Receipt)(r).setStatus(dec.PostStateOrStatus); err != nil { + // Try decoding from the newest format for future proofness, then the older one + // for old nodes that just upgraded. V4 was an intermediate unreleased format so + // we do need to decode it, but it's not common (try last). + if err := decodeV3StoredReceiptRLP(r, blob); err == nil { + return nil + } + return decodeV4StoredReceiptRLP(r, blob) +} + +func decodeV3StoredReceiptRLP(r *ReceiptForStorage, blob []byte) error { + var stored v3StoredReceiptRLP + if err := rlp.DecodeBytes(blob, &stored); err != nil { + return err + } + if err := (*Receipt)(r).setStatus(stored.PostStateOrStatus); err != nil { return err } // Assign the consensus fields - r.CumulativeGasUsed, r.Bloom = dec.CumulativeGasUsed, dec.Bloom - r.Logs = make([]*Log, len(dec.Logs)) - for i, log := range dec.Logs { + r.CumulativeGasUsed = stored.CumulativeGasUsed + r.Bloom = stored.Bloom + r.Logs = make([]*Log, len(stored.Logs)) + for i, log := range stored.Logs { r.Logs[i] = (*Log)(log) } // Assign the implementation fields - r.TxHash, r.ContractAddress, r.GasUsed = dec.TxHash, dec.ContractAddress, dec.GasUsed + r.TxHash = stored.TxHash + r.ContractAddress = stored.ContractAddress + r.GasUsed = stored.GasUsed return nil } -// Receipts is a wrapper around a Receipt array to implement DerivableList. +func decodeV4StoredReceiptRLP(r *ReceiptForStorage, blob []byte) error { + var stored v4StoredReceiptRLP + if err := rlp.DecodeBytes(blob, &stored); err != nil { + return err + } + if err := (*Receipt)(r).setStatus(stored.PostStateOrStatus); err != nil { + return err + } + r.CumulativeGasUsed = stored.CumulativeGasUsed + r.TxHash = stored.TxHash + r.ContractAddress = stored.ContractAddress + r.GasUsed = stored.GasUsed + r.Logs = make([]*Log, len(stored.Logs)) + for i, log := range stored.Logs { + r.Logs[i] = (*Log)(log) + } + r.Bloom = CreateBloom(Receipts{(*Receipt)(r)}) + + return nil +} + +// Receipts implements DerivableList for receipts. type Receipts []*Receipt // Len returns the number of receipts in this list. diff --git a/core/types/receipt_test.go b/core/types/receipt_test.go new file mode 100644 index 0000000000..82fec06c96 --- /dev/null +++ b/core/types/receipt_test.go @@ -0,0 +1,170 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package types + +import ( + "bytes" + "math/big" + "reflect" + "testing" + + "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/rlp" +) + +func TestDecodeEmptyTypedReceipt(t *testing.T) { + input := []byte{0x80} + var r Receipt + err := rlp.DecodeBytes(input, &r) + if err != errEmptyTypedReceipt { + t.Fatal("wrong error:", err) + } +} + +func TestLegacyReceiptDecoding(t *testing.T) { + tests := []struct { + name string + encode func(*Receipt) ([]byte, error) + }{ + { + "V4StoredReceiptRLP", + encodeAsV4StoredReceiptRLP, + }, + { + "V3StoredReceiptRLP", + encodeAsV3StoredReceiptRLP, + }, + } + + tx := NewTransaction(1, common.HexToAddress("0x1"), big.NewInt(1), 1, big.NewInt(1), nil) + receipt := &Receipt{ + Status: ReceiptStatusFailed, + CumulativeGasUsed: 1, + Logs: []*Log{ + { + Address: common.BytesToAddress([]byte{0x11}), + Topics: []common.Hash{common.HexToHash("dead"), common.HexToHash("beef")}, + Data: []byte{0x01, 0x00, 0xff}, + }, + { + Address: common.BytesToAddress([]byte{0x01, 0x11}), + Topics: []common.Hash{common.HexToHash("dead"), common.HexToHash("beef")}, + Data: []byte{0x01, 0x00, 0xff}, + }, + }, + TxHash: tx.Hash(), + ContractAddress: common.BytesToAddress([]byte{0x01, 0x11, 0x11}), + GasUsed: 111111, + } + receipt.Bloom = CreateBloom(Receipts{receipt}) + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + enc, err := tc.encode(receipt) + if err != nil { + t.Fatalf("Error encoding receipt: %v", err) + } + var dec ReceiptForStorage + if err := rlp.DecodeBytes(enc, &dec); err != nil { + t.Fatalf("Error decoding RLP receipt: %v", err) + } + // Check whether all consensus fields are correct. + if dec.Status != receipt.Status { + t.Fatalf("Receipt status mismatch, want %v, have %v", receipt.Status, dec.Status) + } + if dec.CumulativeGasUsed != receipt.CumulativeGasUsed { + t.Fatalf("Receipt CumulativeGasUsed mismatch, want %v, have %v", receipt.CumulativeGasUsed, dec.CumulativeGasUsed) + } + if dec.Bloom != receipt.Bloom { + t.Fatalf("Bloom data mismatch, want %v, have %v", receipt.Bloom, dec.Bloom) + } + if len(dec.Logs) != len(receipt.Logs) { + t.Fatalf("Receipt log number mismatch, want %v, have %v", len(receipt.Logs), len(dec.Logs)) + } + for i := 0; i < len(dec.Logs); i++ { + if dec.Logs[i].Address != receipt.Logs[i].Address { + t.Fatalf("Receipt log %d address mismatch, want %v, have %v", i, receipt.Logs[i].Address, dec.Logs[i].Address) + } + if !reflect.DeepEqual(dec.Logs[i].Topics, receipt.Logs[i].Topics) { + t.Fatalf("Receipt log %d topics mismatch, want %v, have %v", i, receipt.Logs[i].Topics, dec.Logs[i].Topics) + } + if !bytes.Equal(dec.Logs[i].Data, receipt.Logs[i].Data) { + t.Fatalf("Receipt log %d data mismatch, want %v, have %v", i, receipt.Logs[i].Data, dec.Logs[i].Data) + } + } + }) + } +} + +func encodeAsV4StoredReceiptRLP(want *Receipt) ([]byte, error) { + stored := &v4StoredReceiptRLP{ + PostStateOrStatus: want.statusEncoding(), + CumulativeGasUsed: want.CumulativeGasUsed, + TxHash: want.TxHash, + ContractAddress: want.ContractAddress, + Logs: make([]*LogForStorage, len(want.Logs)), + GasUsed: want.GasUsed, + } + for i, log := range want.Logs { + stored.Logs[i] = (*LogForStorage)(log) + } + return rlp.EncodeToBytes(stored) +} + +func encodeAsV3StoredReceiptRLP(want *Receipt) ([]byte, error) { + stored := &v3StoredReceiptRLP{ + PostStateOrStatus: want.statusEncoding(), + CumulativeGasUsed: want.CumulativeGasUsed, + Bloom: want.Bloom, + TxHash: want.TxHash, + ContractAddress: want.ContractAddress, + Logs: make([]*LogForStorage, len(want.Logs)), + GasUsed: want.GasUsed, + } + for i, log := range want.Logs { + stored.Logs[i] = (*LogForStorage)(log) + } + return rlp.EncodeToBytes(stored) +} + +// TestTypedReceiptEncodingDecoding reproduces a flaw that existed in the receipt +// rlp decoder, which failed due to a shadowing error. +func TestTypedReceiptEncodingDecoding(t *testing.T) { + var payload = common.FromHex("f9043eb9010c01f90108018262d4b9010000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000c0b9010c01f901080182cd14b9010000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000c0b9010d01f901090183013754b9010000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000c0b9010d01f90109018301a194b9010000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000c0") + check := func(bundle []*Receipt) { + t.Helper() + for i, receipt := range bundle { + if got, want := receipt.Type, uint8(1); got != want { + t.Fatalf("bundle %d: got %x, want %x", i, got, want) + } + } + } + { + var bundle []*Receipt + rlp.DecodeBytes(payload, &bundle) + check(bundle) + } + { + var bundle []*Receipt + r := bytes.NewReader(payload) + s := rlp.NewStream(r, uint64(len(payload))) + if err := s.Decode(&bundle); err != nil { + t.Fatal(err) + } + check(bundle) + } +} diff --git a/core/types/transaction.go b/core/types/transaction.go index 03f1bbe39a..b26c0fc206 100644 --- a/core/types/transaction.go +++ b/core/types/transaction.go @@ -24,241 +24,346 @@ import ( "io" "math/big" "sync/atomic" + "time" "github.com/XinFinOrg/XDPoSChain/common" - "github.com/XinFinOrg/XDPoSChain/common/hexutil" "github.com/XinFinOrg/XDPoSChain/crypto" "github.com/XinFinOrg/XDPoSChain/rlp" ) //go:generate gencodec -type txdata -field-override txdataMarshaling -out gen_tx_json.go - var ( ErrInvalidSig = errors.New("invalid transaction v, r, s values") + ErrUnexpectedProtection = errors.New("transaction type does not supported EIP-155 protected signatures") + ErrInvalidTxType = errors.New("transaction type not valid in this context") + ErrTxTypeNotSupported = errors.New("transaction type not supported") + ErrGasFeeCapTooLow = errors.New("fee cap less than base fee") + errShortTypedTx = errors.New("typed transaction too short") + errInvalidYParity = errors.New("'yParity' field must be 0 or 1") + errVYParityMismatch = errors.New("'v' and 'yParity' fields do not match") + errVYParityMissing = errors.New("missing 'yParity' or 'v' field in transaction") + errEmptyTypedTx = errors.New("empty typed transaction bytes") errNoSigner = errors.New("missing signing methods") - skipNonceDestinationAddress = map[string]bool{ - common.XDCXAddr: true, - common.TradingStateAddr: true, - common.XDCXLendingAddress: true, - common.XDCXLendingFinalizedTradeAddress: true, + skipNonceDestinationAddress = map[common.Address]bool{ + common.XDCXAddrBinary: true, + common.TradingStateAddrBinary: true, + common.XDCXLendingAddressBinary: true, + common.XDCXLendingFinalizedTradeAddressBinary: true, } ) -// deriveSigner makes a *best* guess about which signer to use. -func deriveSigner(V *big.Int) Signer { - if V.Sign() != 0 && isProtectedV(V) { - return NewEIP155Signer(deriveChainId(V)) - } else { - return HomesteadSigner{} - } -} +// Transaction types. +const ( + LegacyTxType = iota + AccessListTxType +) +// Transaction is an Ethereum transaction. type Transaction struct { - data txdata + inner TxData // Consensus contents of a transaction + time time.Time // Time first seen locally (spam avoidance) + // caches hash atomic.Value size atomic.Value from atomic.Value } -type txdata struct { - AccountNonce uint64 `json:"nonce" gencodec:"required"` - Price *big.Int `json:"gasPrice" gencodec:"required"` - GasLimit uint64 `json:"gas" gencodec:"required"` - Recipient *common.Address `json:"to" rlp:"nil"` // nil means contract creation - Amount *big.Int `json:"value" gencodec:"required"` - Payload []byte `json:"input" gencodec:"required"` - - // Signature values - V *big.Int `json:"v" gencodec:"required"` - R *big.Int `json:"r" gencodec:"required"` - S *big.Int `json:"s" gencodec:"required"` - - // This is only used when marshaling to JSON. - Hash *common.Hash `json:"hash" rlp:"-"` +// NewTx creates a new transaction. +func NewTx(inner TxData) *Transaction { + tx := new(Transaction) + tx.setDecoded(inner.copy(), 0) + return tx } -type txdataMarshaling struct { - AccountNonce hexutil.Uint64 - Price *hexutil.Big - GasLimit hexutil.Uint64 - Amount *hexutil.Big - Payload hexutil.Bytes - V *hexutil.Big - R *hexutil.Big - S *hexutil.Big +// TxData is the underlying data of a transaction. +// +// This is implemented by LegacyTx and AccessListTx. +type TxData interface { + txType() byte // returns the type ID + copy() TxData // creates a deep copy and initializes all fields + + chainID() *big.Int + accessList() AccessList + data() []byte + gas() uint64 + gasPrice() *big.Int + value() *big.Int + nonce() uint64 + to() *common.Address + + rawSignatureValues() (v, r, s *big.Int) + setSignatureValues(chainID, v, r, s *big.Int) } -func NewTransaction(nonce uint64, to common.Address, amount *big.Int, gasLimit uint64, gasPrice *big.Int, data []byte) *Transaction { - return newTransaction(nonce, &to, amount, gasLimit, gasPrice, data) -} - -func NewContractCreation(nonce uint64, amount *big.Int, gasLimit uint64, gasPrice *big.Int, data []byte) *Transaction { - return newTransaction(nonce, nil, amount, gasLimit, gasPrice, data) -} - -func newTransaction(nonce uint64, to *common.Address, amount *big.Int, gasLimit uint64, gasPrice *big.Int, data []byte) *Transaction { - if len(data) > 0 { - data = common.CopyBytes(data) +// EncodeRLP implements rlp.Encoder +func (tx *Transaction) EncodeRLP(w io.Writer) error { + if tx.Type() == LegacyTxType { + return rlp.Encode(w, tx.inner) } - d := txdata{ - AccountNonce: nonce, - Recipient: to, - Payload: data, - Amount: new(big.Int), - GasLimit: gasLimit, - Price: new(big.Int), - V: new(big.Int), - R: new(big.Int), - S: new(big.Int), + // It's an EIP-2718 typed TX envelope. + buf := encodeBufferPool.Get().(*bytes.Buffer) + defer encodeBufferPool.Put(buf) + buf.Reset() + if err := tx.encodeTyped(buf); err != nil { + return err } - if amount != nil { - d.Amount.Set(amount) + return rlp.Encode(w, buf.Bytes()) +} + +// encodeTyped writes the canonical encoding of a typed transaction to w. +func (tx *Transaction) encodeTyped(w *bytes.Buffer) error { + w.WriteByte(tx.Type()) + return rlp.Encode(w, tx.inner) +} + +// MarshalBinary returns the canonical encoding of the transaction. +// For legacy transactions, it returns the RLP encoding. For EIP-2718 typed +// transactions, it returns the type and payload. +func (tx *Transaction) MarshalBinary() ([]byte, error) { + if tx.Type() == LegacyTxType { + return rlp.EncodeToBytes(tx.inner) } - if gasPrice != nil { - d.Price.Set(gasPrice) + var buf bytes.Buffer + err := tx.encodeTyped(&buf) + return buf.Bytes(), err +} + +// DecodeRLP implements rlp.Decoder +func (tx *Transaction) DecodeRLP(s *rlp.Stream) error { + kind, size, err := s.Kind() + switch { + case err != nil: + return err + case kind == rlp.List: + // It's a legacy transaction. + var inner LegacyTx + err := s.Decode(&inner) + if err == nil { + tx.setDecoded(&inner, int(rlp.ListSize(size))) + } + return err + case kind == rlp.String: + // It's an EIP-2718 typed TX envelope. + var b []byte + if b, err = s.Bytes(); err != nil { + return err + } + inner, err := tx.decodeTyped(b) + if err == nil { + tx.setDecoded(inner, len(b)) + } + return err + default: + return rlp.ErrExpectedList + } +} + +// UnmarshalBinary decodes the canonical encoding of transactions. +// It supports legacy RLP transactions and EIP2718 typed transactions. +func (tx *Transaction) UnmarshalBinary(b []byte) error { + if len(b) > 0 && b[0] > 0x7f { + // It's a legacy transaction. + var data LegacyTx + err := rlp.DecodeBytes(b, &data) + if err != nil { + return err + } + tx.setDecoded(&data, len(b)) + return nil + } + // It's an EIP2718 typed transaction envelope. + inner, err := tx.decodeTyped(b) + if err != nil { + return err + } + tx.setDecoded(inner, len(b)) + return nil +} + +// decodeTyped decodes a typed transaction from the canonical format. +func (tx *Transaction) decodeTyped(b []byte) (TxData, error) { + if len(b) == 0 { + return nil, errEmptyTypedTx + } + switch b[0] { + case AccessListTxType: + var inner AccessListTx + err := rlp.DecodeBytes(b[1:], &inner) + return &inner, err + default: + return nil, ErrTxTypeNotSupported + } +} + +// setDecoded sets the inner transaction and size after decoding. +func (tx *Transaction) setDecoded(inner TxData, size int) { + tx.inner = inner + tx.time = time.Now() + if size > 0 { + tx.size.Store(common.StorageSize(size)) + } +} + +func sanityCheckSignature(v *big.Int, r *big.Int, s *big.Int, maybeProtected bool) error { + if isProtectedV(v) && !maybeProtected { + return ErrUnexpectedProtection } - return &Transaction{data: d} -} + var plainV byte + if isProtectedV(v) { + chainID := deriveChainId(v).Uint64() + plainV = byte(v.Uint64() - 35 - 2*chainID) + } else if maybeProtected { + // Only EIP-155 signatures can be optionally protected. Since + // we determined this v value is not protected, it must be a + // raw 27 or 28. + plainV = byte(v.Uint64() - 27) + } else { + // If the signature is not optionally protected, we assume it + // must already be equal to the recovery id. + plainV = byte(v.Uint64()) + } + if !crypto.ValidateSignatureValues(plainV, r, s, false) { + return ErrInvalidSig + } -// ChainId returns which chain id this transaction was signed for (if at all) -func (tx *Transaction) ChainId() *big.Int { - return deriveChainId(tx.data.V) -} - -// Protected returns whether the transaction is protected from replay protection. -func (tx *Transaction) Protected() bool { - return isProtectedV(tx.data.V) + return nil } func isProtectedV(V *big.Int) bool { if V.BitLen() <= 8 { v := V.Uint64() - return v != 27 && v != 28 + return v != 27 && v != 28 && v != 1 && v != 0 } - // anything not 27 or 28 are considered unprotected + // anything not 27 or 28 is considered protected return true } -// EncodeRLP implements rlp.Encoder -func (tx *Transaction) EncodeRLP(w io.Writer) error { - return rlp.Encode(w, &tx.data) +// Protected says whether the transaction is replay-protected. +func (tx *Transaction) Protected() bool { + switch tx := tx.inner.(type) { + case *LegacyTx: + return tx.V != nil && isProtectedV(tx.V) + default: + return true + } } -// DecodeRLP implements rlp.Decoder -func (tx *Transaction) DecodeRLP(s *rlp.Stream) error { - _, size, _ := s.Kind() - err := s.Decode(&tx.data) - if err == nil { - tx.size.Store(common.StorageSize(rlp.ListSize(size))) - } - - return err +// Type returns the transaction type. +func (tx *Transaction) Type() uint8 { + return tx.inner.txType() } -// MarshalJSON encodes the web3 RPC transaction format. -func (tx *Transaction) MarshalJSON() ([]byte, error) { - hash := tx.Hash() - data := tx.data - data.Hash = &hash - return data.MarshalJSON() +// ChainId returns the EIP155 chain ID of the transaction. The return value will always be +// non-nil. For legacy transactions which are not replay-protected, the return value is +// zero. +func (tx *Transaction) ChainId() *big.Int { + return tx.inner.chainID() } -// UnmarshalJSON decodes the web3 RPC transaction format. -func (tx *Transaction) UnmarshalJSON(input []byte) error { - var dec txdata - if err := dec.UnmarshalJSON(input); err != nil { - return err - } - var V byte - if isProtectedV(dec.V) { - chainID := deriveChainId(dec.V).Uint64() - V = byte(dec.V.Uint64() - 35 - 2*chainID) - } else { - V = byte(dec.V.Uint64() - 27) - } - if !crypto.ValidateSignatureValues(V, dec.R, dec.S, false) { - return ErrInvalidSig - } - *tx = Transaction{data: dec} - return nil -} +// Data returns the input data of the transaction. +func (tx *Transaction) Data() []byte { return tx.inner.data() } -func (tx *Transaction) Data() []byte { return common.CopyBytes(tx.data.Payload) } -func (tx *Transaction) Gas() uint64 { return tx.data.GasLimit } -func (tx *Transaction) GasPrice() *big.Int { return new(big.Int).Set(tx.data.Price) } -func (tx *Transaction) Value() *big.Int { return new(big.Int).Set(tx.data.Amount) } -func (tx *Transaction) Nonce() uint64 { return tx.data.AccountNonce } -func (tx *Transaction) CheckNonce() bool { return true } +// AccessList returns the access list of the transaction. +func (tx *Transaction) AccessList() AccessList { return tx.inner.accessList() } + +// Gas returns the gas limit of the transaction. +func (tx *Transaction) Gas() uint64 { return tx.inner.gas() } + +// GasPrice returns the gas price of the transaction. +func (tx *Transaction) GasPrice() *big.Int { return new(big.Int).Set(tx.inner.gasPrice()) } + +// Value returns the ether amount of the transaction. +func (tx *Transaction) Value() *big.Int { return new(big.Int).Set(tx.inner.value()) } + +// Nonce returns the sender account nonce of the transaction. +func (tx *Transaction) Nonce() uint64 { return tx.inner.nonce() } // To returns the recipient address of the transaction. -// It returns nil if the transaction is a contract creation. +// For contract-creation transactions, To returns nil. func (tx *Transaction) To() *common.Address { - if tx.data.Recipient == nil { + // Copy the pointed-to address. + ito := tx.inner.to() + if ito == nil { return nil } - to := *tx.data.Recipient - return &to + cpy := *ito + return &cpy } func (tx *Transaction) From() *common.Address { - if tx.data.V != nil { - signer := deriveSigner(tx.data.V) - if f, err := Sender(signer, tx); err != nil { - return nil - } else { - return &f - } + var signer Signer + if tx.Protected() { + signer = LatestSignerForChainID(tx.ChainId()) } else { + signer = HomesteadSigner{} + } + from, err := Sender(signer, tx) + if err != nil { return nil } + return &from } -// Hash hashes the RLP encoding of tx. -// It uniquely identifies the transaction. +// RawSignatureValues returns the V, R, S signature values of the transaction. +// The return values should not be modified by the caller. +func (tx *Transaction) RawSignatureValues() (v, r, s *big.Int) { + return tx.inner.rawSignatureValues() +} + +// GasPriceCmp compares the gas prices of two transactions. +func (tx *Transaction) GasPriceCmp(other *Transaction) int { + return tx.inner.gasPrice().Cmp(other.inner.gasPrice()) +} + +// GasPriceIntCmp compares the gas price of the transaction against the given price. +func (tx *Transaction) GasPriceIntCmp(other *big.Int) int { + return tx.inner.gasPrice().Cmp(other) +} + +// Hash returns the transaction hash. func (tx *Transaction) Hash() common.Hash { if hash := tx.hash.Load(); hash != nil { return hash.(common.Hash) } - v := rlpHash(tx) - tx.hash.Store(v) - return v -} -func (tx *Transaction) CacheHash() { - v := rlpHash(tx) - tx.hash.Store(v) + var h common.Hash + if tx.Type() == LegacyTxType { + h = rlpHash(tx.inner) + } else { + h = prefixedRlpHash(tx.Type(), tx.inner) + } + tx.hash.Store(h) + return h } // Size returns the true RLP encoded storage size of the transaction, either by -// encoding and returning it, or returning a previsouly cached value. +// encoding and returning it, or returning a previously cached value. func (tx *Transaction) Size() common.StorageSize { if size := tx.size.Load(); size != nil { return size.(common.StorageSize) } c := writeCounter(0) - rlp.Encode(&c, &tx.data) + rlp.Encode(&c, &tx.inner) tx.size.Store(common.StorageSize(c)) return common.StorageSize(c) } // AsMessage returns the transaction as a core.Message. -// -// AsMessage requires a signer to derive the sender. -// -// XXX Rename message to something less arbitrary? func (tx *Transaction) AsMessage(s Signer, balanceFee *big.Int, number *big.Int) (Message, error) { msg := Message{ - nonce: tx.data.AccountNonce, - gasLimit: tx.data.GasLimit, - gasPrice: new(big.Int).Set(tx.data.Price), - to: tx.data.Recipient, - amount: tx.data.Amount, - data: tx.data.Payload, + nonce: tx.Nonce(), + gasLimit: tx.Gas(), + gasPrice: new(big.Int).Set(tx.GasPrice()), + to: tx.To(), + amount: tx.Value(), + data: tx.Data(), + accessList: tx.AccessList(), checkNonce: true, balanceTokenFee: balanceFee, } + var err error msg.from, err = Sender(s, tx) if balanceFee != nil { @@ -274,151 +379,94 @@ func (tx *Transaction) AsMessage(s Signer, balanceFee *big.Int, number *big.Int) } // WithSignature returns a new transaction with the given signature. -// This signature needs to be formatted as described in the yellow paper (v+27). +// This signature needs to be in the [R || S || V] format where V is 0 or 1. func (tx *Transaction) WithSignature(signer Signer, sig []byte) (*Transaction, error) { r, s, v, err := signer.SignatureValues(tx, sig) if err != nil { return nil, err } - cpy := &Transaction{data: tx.data} - cpy.data.R, cpy.data.S, cpy.data.V = r, s, v - return cpy, nil + cpy := tx.inner.copy() + cpy.setSignatureValues(signer.ChainID(), v, r, s) + return &Transaction{inner: cpy, time: tx.time}, nil } -// Cost returns amount + gasprice * gaslimit. +// Cost returns gas * gasPrice + value. func (tx *Transaction) Cost() *big.Int { - total := new(big.Int).Mul(tx.data.Price, new(big.Int).SetUint64(tx.data.GasLimit)) - total.Add(total, tx.data.Amount) + total := new(big.Int).Mul(tx.GasPrice(), new(big.Int).SetUint64(tx.Gas())) + total.Add(total, tx.Value()) return total } -// Cost returns amount + gasprice * gaslimit. +// TxCost returns gas * gasPrice + value. func (tx *Transaction) TxCost(number *big.Int) *big.Int { - total := new(big.Int).Mul(common.GetGasPrice(number), new(big.Int).SetUint64(tx.data.GasLimit)) - total.Add(total, tx.data.Amount) + total := new(big.Int).Mul(common.GetGasPrice(number), new(big.Int).SetUint64(tx.Gas())) + total.Add(total, tx.Value()) return total } -func (tx *Transaction) RawSignatureValues() (*big.Int, *big.Int, *big.Int) { - return tx.data.V, tx.data.R, tx.data.S -} - func (tx *Transaction) IsSpecialTransaction() bool { - if tx.To() == nil { - return false - } - toBytes := tx.To().Bytes() - randomizeSMCBytes := common.HexToAddress(common.RandomizeSMC).Bytes() - blockSignersBytes := common.HexToAddress(common.BlockSigners).Bytes() - return bytes.Equal(toBytes, randomizeSMCBytes) || bytes.Equal(toBytes, blockSignersBytes) + to := tx.To() + return to != nil && (*to == common.RandomizeSMCBinary || *to == common.BlockSignersBinary) } func (tx *Transaction) IsTradingTransaction() bool { - if tx.To() == nil { - return false - } - - if tx.To().String() != common.XDCXAddr { - return false - } - - return true + to := tx.To() + return to != nil && *to == common.XDCXAddrBinary } func (tx *Transaction) IsLendingTransaction() bool { - if tx.To() == nil { - return false - } - - if tx.To().String() != common.XDCXLendingAddress { - return false - } - return true + to := tx.To() + return to != nil && *to == common.XDCXLendingAddressBinary } func (tx *Transaction) IsLendingFinalizedTradeTransaction() bool { - if tx.To() == nil { - return false - } - - if tx.To().String() != common.XDCXLendingFinalizedTradeAddress { - return false - } - return true + to := tx.To() + return to != nil && *to == common.XDCXLendingFinalizedTradeAddressBinary } func (tx *Transaction) IsSkipNonceTransaction() bool { - if tx.To() == nil { - return false - } - if skip := skipNonceDestinationAddress[tx.To().String()]; skip { - return true - } - return false + to := tx.To() + return to != nil && skipNonceDestinationAddress[*to] } func (tx *Transaction) IsSigningTransaction() bool { - if tx.To() == nil { + to := tx.To() + if to == nil || *to != common.BlockSignersBinary { return false } - - if tx.To().String() != common.BlockSigners { + data := tx.Data() + if len(data) != (32*2 + 4) { return false } - - method := common.ToHex(tx.Data()[0:4]) - - if method != common.SignMethod { - return false - } - - if len(tx.Data()) != (32*2 + 4) { - return false - } - - return true + method := common.ToHex(data[0:4]) + return method == common.SignMethod } func (tx *Transaction) IsVotingTransaction() (bool, *common.Address) { - if tx.To() == nil { + to := tx.To() + if to == nil || *to != common.MasternodeVotingSMCBinary { return false, nil } - b := (tx.To().String() == common.MasternodeVotingSMC) - - if !b { - return b, nil + var end int + data := tx.Data() + method := common.ToHex(data[0:4]) + if method == common.VoteMethod || method == common.ProposeMethod || method == common.ResignMethod { + end = len(data) + } else if method == common.UnvoteMethod { + end = len(data) - 32 + } else { + return false, nil } - method := common.ToHex(tx.Data()[0:4]) - if b = (method == common.VoteMethod); b { - addr := tx.Data()[len(tx.Data())-20:] - m := common.BytesToAddress(addr) - return b, &m - } + addr := data[end-20 : end] + m := common.BytesToAddress(addr) + return true, &m - if b = (method == common.UnvoteMethod); b { - addr := tx.Data()[len(tx.Data())-32-20 : len(tx.Data())-32] - m := common.BytesToAddress(addr) - return b, &m - } - - if b = (method == common.ProposeMethod); b { - addr := tx.Data()[len(tx.Data())-20:] - m := common.BytesToAddress(addr) - return b, &m - } - - if b = (method == common.ResignMethod); b { - addr := tx.Data()[len(tx.Data())-20:] - m := common.BytesToAddress(addr) - return b, &m - } - - return b, nil } func (tx *Transaction) IsXDCXApplyTransaction() bool { - if tx.To() == nil { + to := tx.To() + if to == nil { return false } @@ -426,26 +474,22 @@ func (tx *Transaction) IsXDCXApplyTransaction() bool { if common.IsTestnet { addr = common.XDCXListingSMCTestNet } - if tx.To().String() != addr.String() { + if *to != addr { return false } - - method := common.ToHex(tx.Data()[0:4]) - - if method != common.XDCXApplyMethod { - return false - } - + data := tx.Data() // 4 bytes for function name // 32 bytes for 1 parameter - if len(tx.Data()) != (32 + 4) { + if len(data) != (32 + 4) { return false } - return true + method := common.ToHex(data[0:4]) + return method == common.XDCXApplyMethod } func (tx *Transaction) IsXDCZApplyTransaction() bool { - if tx.To() == nil { + to := tx.To() + if to == nil { return false } @@ -453,45 +497,39 @@ func (tx *Transaction) IsXDCZApplyTransaction() bool { if common.IsTestnet { addr = common.TRC21IssuerSMCTestNet } - if tx.To().String() != addr.String() { + if *to != addr { return false } - - method := common.ToHex(tx.Data()[0:4]) - if method != common.XDCZApplyMethod { - return false - } - + data := tx.Data() // 4 bytes for function name // 32 bytes for 1 parameter - if len(tx.Data()) != (32 + 4) { + if len(data) != (32 + 4) { return false } - - return true + method := common.ToHex(data[0:4]) + return method == common.XDCZApplyMethod } func (tx *Transaction) String() string { var from, to string - if tx.data.V != nil { - // make a best guess about the signer and use that to derive - // the sender. - signer := deriveSigner(tx.data.V) - if f, err := Sender(signer, tx); err != nil { // derive but don't cache - from = "[invalid sender: invalid sig]" - } else { - from = fmt.Sprintf("%x", f[:]) - } + + sender := tx.From() + if sender != nil { + from = fmt.Sprintf("%x", sender[:]) } else { - from = "[invalid sender: nil V field]" + from = "[invalid sender]" } - if tx.data.Recipient == nil { + receiver := tx.To() + if receiver == nil { to = "[contract creation]" } else { - to = fmt.Sprintf("%x", tx.data.Recipient[:]) + to = fmt.Sprintf("%x", receiver[:]) } - enc, _ := rlp.EncodeToBytes(&tx.data) + + enc, _ := rlp.EncodeToBytes(tx.Data()) + v, r, s := tx.RawSignatureValues() + return fmt.Sprintf(` TX(%x) Contract: %v @@ -508,17 +546,17 @@ func (tx *Transaction) String() string { Hex: %x `, tx.Hash(), - tx.data.Recipient == nil, + receiver == nil, from, to, - tx.data.AccountNonce, - tx.data.Price, - tx.data.GasLimit, - tx.data.Amount, - tx.data.Payload, - tx.data.V, - tx.data.R, - tx.data.S, + tx.Nonce(), + tx.GasPrice(), + tx.Gas(), + tx.Value(), + tx.Data(), + v, + r, + s, enc, ) } @@ -562,40 +600,47 @@ func TxDifference(a, b Transactions) (keep Transactions) { type TxByNonce Transactions func (s TxByNonce) Len() int { return len(s) } -func (s TxByNonce) Less(i, j int) bool { return s[i].data.AccountNonce < s[j].data.AccountNonce } +func (s TxByNonce) Less(i, j int) bool { return s[i].Nonce() < s[j].Nonce() } func (s TxByNonce) Swap(i, j int) { s[i], s[j] = s[j], s[i] } -// TxByPrice implements both the sort and the heap interface, making it useful +// TxByPriceAndTime implements both the sort and the heap interface, making it useful // for all at once sorting as well as individually adding and removing elements. -type TxByPrice struct { +type TxByPriceAndTime struct { txs Transactions payersSwap map[common.Address]*big.Int } -func (s TxByPrice) Len() int { return len(s.txs) } -func (s TxByPrice) Less(i, j int) bool { - i_price := s.txs[i].data.Price +func (s TxByPriceAndTime) Len() int { return len(s.txs) } +func (s TxByPriceAndTime) Less(i, j int) bool { + i_price := s.txs[i].GasPrice() if s.txs[i].To() != nil { if _, ok := s.payersSwap[*s.txs[i].To()]; ok { i_price = common.TRC21GasPrice } } - j_price := s.txs[j].data.Price + j_price := s.txs[j].GasPrice() if s.txs[j].To() != nil { if _, ok := s.payersSwap[*s.txs[j].To()]; ok { j_price = common.TRC21GasPrice } } - return i_price.Cmp(j_price) > 0 -} -func (s TxByPrice) Swap(i, j int) { s.txs[i], s.txs[j] = s.txs[j], s.txs[i] } -func (s *TxByPrice) Push(x interface{}) { + // If the prices are equal, use the time the transaction was first seen for + // deterministic sorting + cmp := i_price.Cmp(j_price) + if cmp == 0 { + return s.txs[i].time.Before(s.txs[j].time) + } + return cmp > 0 +} +func (s TxByPriceAndTime) Swap(i, j int) { s.txs[i], s.txs[j] = s.txs[j], s.txs[i] } + +func (s *TxByPriceAndTime) Push(x interface{}) { s.txs = append(s.txs, x.(*Transaction)) } -func (s *TxByPrice) Pop() interface{} { +func (s *TxByPriceAndTime) Pop() interface{} { old := s.txs n := len(old) x := old[n-1] @@ -608,7 +653,7 @@ func (s *TxByPrice) Pop() interface{} { // entire batches of transactions for non-executable accounts. type TransactionsByPriceAndNonce struct { txs map[common.Address]Transactions // Per account nonce-sorted list of transactions - heads TxByPrice // Next transaction for each unique account (price heap) + heads TxByPriceAndTime // Next transaction for each unique account (price heap) signer Signer // Signer for the set of transactions } @@ -617,11 +662,11 @@ type TransactionsByPriceAndNonce struct { // // Note, the input map is reowned so the caller should not interact any more with // if after providing it to the constructor. - +// // It also classifies special txs and normal txs func NewTransactionsByPriceAndNonce(signer Signer, txs map[common.Address]Transactions, signers map[common.Address]struct{}, payersSwap map[common.Address]*big.Int) (*TransactionsByPriceAndNonce, Transactions) { - // Initialize a price based heap with the head transactions - heads := TxByPrice{} + // Initialize a price and received time based heap with the head transactions + heads := TxByPriceAndTime{} heads.payersSwap = payersSwap specialTxs := Transactions{} for _, accTxs := range txs { @@ -698,11 +743,12 @@ type Message struct { gasLimit uint64 gasPrice *big.Int data []byte + accessList AccessList checkNonce bool balanceTokenFee *big.Int } -func NewMessage(from common.Address, to *common.Address, nonce uint64, amount *big.Int, gasLimit uint64, gasPrice *big.Int, data []byte, checkNonce bool, balanceTokenFee *big.Int, number *big.Int) Message { +func NewMessage(from common.Address, to *common.Address, nonce uint64, amount *big.Int, gasLimit uint64, gasPrice *big.Int, data []byte, accessList AccessList, checkNonce bool, balanceTokenFee *big.Int, number *big.Int) Message { if balanceTokenFee != nil { gasPrice = common.GetGasPrice(number) } @@ -714,6 +760,7 @@ func NewMessage(from common.Address, to *common.Address, nonce uint64, amount *b gasLimit: gasLimit, gasPrice: gasPrice, data: data, + accessList: accessList, checkNonce: checkNonce, balanceTokenFee: balanceTokenFee, } @@ -728,5 +775,6 @@ func (m Message) Gas() uint64 { return m.gasLimit } func (m Message) Nonce() uint64 { return m.nonce } func (m Message) Data() []byte { return m.data } func (m Message) CheckNonce() bool { return m.checkNonce } +func (m Message) AccessList() AccessList { return m.accessList } func (m *Message) SetNonce(nonce uint64) { m.nonce = nonce } diff --git a/core/types/transaction_marshalling.go b/core/types/transaction_marshalling.go new file mode 100644 index 0000000000..91403994bf --- /dev/null +++ b/core/types/transaction_marshalling.go @@ -0,0 +1,187 @@ +package types + +import ( + "encoding/json" + "errors" + "math/big" + + "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/common/hexutil" +) + +// txJSON is the JSON representation of transactions. +type txJSON struct { + Type hexutil.Uint64 `json:"type"` + + // Common transaction fields: + Nonce *hexutil.Uint64 `json:"nonce"` + GasPrice *hexutil.Big `json:"gasPrice"` + Gas *hexutil.Uint64 `json:"gas"` + Value *hexutil.Big `json:"value"` + Data *hexutil.Bytes `json:"input"` + V *hexutil.Big `json:"v"` + R *hexutil.Big `json:"r"` + S *hexutil.Big `json:"s"` + To *common.Address `json:"to"` + + // Access list transaction fields: + ChainID *hexutil.Big `json:"chainId,omitempty"` + AccessList *AccessList `json:"accessList,omitempty"` + + // Only used for encoding: + Hash common.Hash `json:"hash"` +} + +// MarshalJSON marshals as JSON with a hash. +func (t *Transaction) MarshalJSON() ([]byte, error) { + var enc txJSON + // These are set for all tx types. + enc.Hash = t.Hash() + enc.Type = hexutil.Uint64(t.Type()) + + // Other fields are set conditionally depending on tx type. + switch tx := t.inner.(type) { + case *LegacyTx: + enc.Nonce = (*hexutil.Uint64)(&tx.Nonce) + enc.Gas = (*hexutil.Uint64)(&tx.Gas) + enc.GasPrice = (*hexutil.Big)(tx.GasPrice) + enc.Value = (*hexutil.Big)(tx.Value) + enc.Data = (*hexutil.Bytes)(&tx.Data) + enc.To = t.To() + enc.V = (*hexutil.Big)(tx.V) + enc.R = (*hexutil.Big)(tx.R) + enc.S = (*hexutil.Big)(tx.S) + case *AccessListTx: + enc.ChainID = (*hexutil.Big)(tx.ChainID) + enc.AccessList = &tx.AccessList + enc.Nonce = (*hexutil.Uint64)(&tx.Nonce) + enc.Gas = (*hexutil.Uint64)(&tx.Gas) + enc.GasPrice = (*hexutil.Big)(tx.GasPrice) + enc.Value = (*hexutil.Big)(tx.Value) + enc.Data = (*hexutil.Bytes)(&tx.Data) + enc.To = t.To() + enc.V = (*hexutil.Big)(tx.V) + enc.R = (*hexutil.Big)(tx.R) + enc.S = (*hexutil.Big)(tx.S) + } + return json.Marshal(&enc) +} + +// UnmarshalJSON unmarshals from JSON. +func (t *Transaction) UnmarshalJSON(input []byte) error { + var dec txJSON + if err := json.Unmarshal(input, &dec); err != nil { + return err + } + + // Decode / verify fields according to transaction type. + var inner TxData + switch dec.Type { + case LegacyTxType: + var itx LegacyTx + inner = &itx + if dec.To != nil { + itx.To = dec.To + } + if dec.Nonce == nil { + return errors.New("missing required field 'nonce' in transaction") + } + itx.Nonce = uint64(*dec.Nonce) + if dec.GasPrice == nil { + return errors.New("missing required field 'gasPrice' in transaction") + } + itx.GasPrice = (*big.Int)(dec.GasPrice) + if dec.Gas == nil { + return errors.New("missing required field 'gas' in transaction") + } + itx.Gas = uint64(*dec.Gas) + if dec.Value == nil { + return errors.New("missing required field 'value' in transaction") + } + itx.Value = (*big.Int)(dec.Value) + if dec.Data == nil { + return errors.New("missing required field 'input' in transaction") + } + itx.Data = *dec.Data + if dec.V == nil { + return errors.New("missing required field 'v' in transaction") + } + itx.V = (*big.Int)(dec.V) + if dec.R == nil { + return errors.New("missing required field 'r' in transaction") + } + itx.R = (*big.Int)(dec.R) + if dec.S == nil { + return errors.New("missing required field 's' in transaction") + } + itx.S = (*big.Int)(dec.S) + withSignature := itx.V.Sign() != 0 || itx.R.Sign() != 0 || itx.S.Sign() != 0 + if withSignature { + if err := sanityCheckSignature(itx.V, itx.R, itx.S, true); err != nil { + return err + } + } + + case AccessListTxType: + var itx AccessListTx + inner = &itx + // Access list is optional for now. + if dec.AccessList != nil { + itx.AccessList = *dec.AccessList + } + if dec.ChainID == nil { + return errors.New("missing required field 'chainId' in transaction") + } + itx.ChainID = (*big.Int)(dec.ChainID) + if dec.To != nil { + itx.To = dec.To + } + if dec.Nonce == nil { + return errors.New("missing required field 'nonce' in transaction") + } + itx.Nonce = uint64(*dec.Nonce) + if dec.GasPrice == nil { + return errors.New("missing required field 'gasPrice' in transaction") + } + itx.GasPrice = (*big.Int)(dec.GasPrice) + if dec.Gas == nil { + return errors.New("missing required field 'gas' in transaction") + } + itx.Gas = uint64(*dec.Gas) + if dec.Value == nil { + return errors.New("missing required field 'value' in transaction") + } + itx.Value = (*big.Int)(dec.Value) + if dec.Data == nil { + return errors.New("missing required field 'input' in transaction") + } + itx.Data = *dec.Data + if dec.V == nil { + return errors.New("missing required field 'v' in transaction") + } + itx.V = (*big.Int)(dec.V) + if dec.R == nil { + return errors.New("missing required field 'r' in transaction") + } + itx.R = (*big.Int)(dec.R) + if dec.S == nil { + return errors.New("missing required field 's' in transaction") + } + itx.S = (*big.Int)(dec.S) + withSignature := itx.V.Sign() != 0 || itx.R.Sign() != 0 || itx.S.Sign() != 0 + if withSignature { + if err := sanityCheckSignature(itx.V, itx.R, itx.S, false); err != nil { + return err + } + } + + default: + return ErrTxTypeNotSupported + } + + // Now set the inner transaction. + t.setDecoded(inner, 0) + + // TODO: check hash here? + return nil +} diff --git a/core/types/transaction_signing.go b/core/types/transaction_signing.go index 3658ed43b9..f4174dae48 100644 --- a/core/types/transaction_signing.go +++ b/core/types/transaction_signing.go @@ -27,9 +27,8 @@ import ( "github.com/XinFinOrg/XDPoSChain/params" ) -var ( - ErrInvalidChainId = errors.New("invalid chain id for signer") -) +var ErrInvalidChainId = errors.New("invalid chain id for signer") +var ErrInvalidNilTx = errors.New("invalid nil tx") // sigCache is used to cache the derived sender and contains // the signer used to derive it. @@ -42,6 +41,8 @@ type sigCache struct { func MakeSigner(config *params.ChainConfig, blockNumber *big.Int) Signer { var signer Signer switch { + case config.IsEIP1559(blockNumber): + signer = NewEIP2930Signer(config.ChainId) case config.IsEIP155(blockNumber): signer = NewEIP155Signer(config.ChainId) case config.IsHomestead(blockNumber): @@ -52,7 +53,40 @@ func MakeSigner(config *params.ChainConfig, blockNumber *big.Int) Signer { return signer } -// SignTx signs the transaction using the given signer and private key +// LatestSigner returns the 'most permissive' Signer available for the given chain +// configuration. Specifically, this enables support of EIP-155 replay protection and +// EIP-2930 access list transactions when their respective forks are scheduled to occur at +// any block number in the chain config. +// +// Use this in transaction-handling code where the current block number is unknown. If you +// have the current block number available, use MakeSigner instead. +func LatestSigner(config *params.ChainConfig) Signer { + if config.ChainId != nil { + if common.Eip1559Block.Uint64() != 9999999999 || config.Eip1559Block != nil { + return NewEIP2930Signer(config.ChainId) + } + if config.EIP155Block != nil { + return NewEIP155Signer(config.ChainId) + } + } + return HomesteadSigner{} +} + +// LatestSignerForChainID returns the 'most permissive' Signer available. Specifically, +// this enables support for EIP-155 replay protection and all implemented EIP-2718 +// transaction types if chainID is non-nil. +// +// Use this in transaction-handling code where the current block number and fork +// configuration are unknown. If you have a ChainConfig, use LatestSigner instead. +// If you have a ChainConfig and know the current block number, use MakeSigner instead. +func LatestSignerForChainID(chainID *big.Int) Signer { + if chainID == nil { + return HomesteadSigner{} + } + return NewEIP2930Signer(chainID) +} + +// SignTx signs the transaction using the given signer and private key. func SignTx(tx *Transaction, s Signer, prv *ecdsa.PrivateKey) (*Transaction, error) { h := s.Hash(tx) sig, err := crypto.Sign(h[:], prv) @@ -62,6 +96,27 @@ func SignTx(tx *Transaction, s Signer, prv *ecdsa.PrivateKey) (*Transaction, err return tx.WithSignature(s, sig) } +// SignNewTx creates a transaction and signs it. +func SignNewTx(prv *ecdsa.PrivateKey, s Signer, txdata TxData) (*Transaction, error) { + tx := NewTx(txdata) + h := s.Hash(tx) + sig, err := crypto.Sign(h[:], prv) + if err != nil { + return nil, err + } + return tx.WithSignature(s, sig) +} + +// MustSignNewTx creates a transaction and signs it. +// This panics if the transaction cannot be signed. +func MustSignNewTx(prv *ecdsa.PrivateKey, s Signer, txdata TxData) *Transaction { + tx, err := SignNewTx(prv, s, txdata) + if err != nil { + panic(err) + } + return tx +} + // Sender returns the address derived from the signature (V, R, S) using secp256k1 // elliptic curve and an error if it failed deriving or upon an incorrect // signature. @@ -70,6 +125,10 @@ func SignTx(tx *Transaction, s Signer, prv *ecdsa.PrivateKey) (*Transaction, err // signing method. The cache is invalidated if the cached signer does // not match the signer used in the current call. func Sender(signer Signer, tx *Transaction) (common.Address, error) { + if tx == nil { + return common.Address{}, ErrInvalidNilTx + } + if sc := tx.from.Load(); sc != nil { sigCache := sc.(sigCache) // If the signer used to derive from in a previous @@ -88,21 +147,112 @@ func Sender(signer Signer, tx *Transaction) (common.Address, error) { return addr, nil } -// Signer encapsulates transaction signature handling. Note that this interface is not a -// stable API and may change at any time to accommodate new protocol rules. +// Signer encapsulates transaction signature handling. The name of this type is slightly +// misleading because Signers don't actually sign, they're just for validating and +// processing of signatures. +// +// Note that this interface is not a stable API and may change at any time to accommodate +// new protocol rules. type Signer interface { // Sender returns the sender address of the transaction. Sender(tx *Transaction) (common.Address, error) + // SignatureValues returns the raw R, S, V values corresponding to the // given signature. SignatureValues(tx *Transaction, sig []byte) (r, s, v *big.Int, err error) - // Hash returns the hash to be signed. + ChainID() *big.Int + + // Hash returns 'signature hash', i.e. the transaction hash that is signed by the + // private key. This hash does not uniquely identify the transaction. Hash(tx *Transaction) common.Hash + // Equal returns true if the given signer is the same as the receiver. Equal(Signer) bool } -// EIP155Transaction implements Signer using the EIP155 rules. +type eip2930Signer struct{ EIP155Signer } + +// NewEIP2930Signer returns a signer that accepts EIP-2930 access list transactions, +// EIP-155 replay protected transactions, and legacy Homestead transactions. +func NewEIP2930Signer(chainId *big.Int) Signer { + return eip2930Signer{NewEIP155Signer(chainId)} +} + +func (s eip2930Signer) ChainID() *big.Int { + return s.chainId +} + +func (s eip2930Signer) Equal(s2 Signer) bool { + x, ok := s2.(eip2930Signer) + return ok && x.chainId.Cmp(s.chainId) == 0 +} + +func (s eip2930Signer) Sender(tx *Transaction) (common.Address, error) { + V, R, S := tx.RawSignatureValues() + switch tx.Type() { + case LegacyTxType: + return s.EIP155Signer.Sender(tx) + case AccessListTxType: + // ACL txs are defined to use 0 and 1 as their recovery id, add + // 27 to become equivalent to unprotected Homestead signatures. + V = new(big.Int).Add(V, big.NewInt(27)) + default: + return common.Address{}, ErrTxTypeNotSupported + } + if tx.ChainId().Cmp(s.chainId) != 0 { + return common.Address{}, ErrInvalidChainId + } + return recoverPlain(s.Hash(tx), R, S, V, true) +} + +func (s eip2930Signer) SignatureValues(tx *Transaction, sig []byte) (R, S, V *big.Int, err error) { + switch txdata := tx.inner.(type) { + case *LegacyTx: + return s.EIP155Signer.SignatureValues(tx, sig) + case *AccessListTx: + // Check that chain ID of tx matches the signer. We also accept ID zero here, + // because it indicates that the chain ID was not specified in the tx. + if txdata.ChainID.Sign() != 0 && txdata.ChainID.Cmp(s.chainId) != 0 { + return nil, nil, nil, ErrInvalidChainId + } + R, S, _ = decodeSignature(sig) + V = big.NewInt(int64(sig[64])) + default: + return nil, nil, nil, ErrTxTypeNotSupported + } + return R, S, V, nil +} + +// Hash returns the hash to be signed by the sender. +// It does not uniquely identify the transaction. +func (s eip2930Signer) Hash(tx *Transaction) common.Hash { + switch tx.Type() { + case LegacyTxType: + return s.EIP155Signer.Hash(tx) + case AccessListTxType: + return prefixedRlpHash( + tx.Type(), + []interface{}{ + s.chainId, + tx.Nonce(), + tx.GasPrice(), + tx.Gas(), + tx.To(), + tx.Value(), + tx.Data(), + tx.AccessList(), + }) + default: + // This _should_ not happen, but in case someone sends in a bad + // json struct via RPC, it's probably more prudent to return an + // empty hash instead of killing the node with a panic + //panic("Unsupported transaction type: %d", tx.typ) + return common.Hash{} + } +} + +// EIP155Signer implements Signer using the EIP-155 rules. This accepts transactions which +// are replay-protected as well as unprotected homestead transactions. type EIP155Signer struct { chainId, chainIdMul *big.Int } @@ -117,6 +267,10 @@ func NewEIP155Signer(chainId *big.Int) EIP155Signer { } } +func (s EIP155Signer) ChainID() *big.Int { + return s.chainId +} + func (s EIP155Signer) Equal(s2 Signer) bool { eip155, ok := s2.(EIP155Signer) return ok && eip155.chainId.Cmp(s.chainId) == 0 @@ -125,24 +279,28 @@ func (s EIP155Signer) Equal(s2 Signer) bool { var big8 = big.NewInt(8) func (s EIP155Signer) Sender(tx *Transaction) (common.Address, error) { + if tx.Type() != LegacyTxType { + return common.Address{}, ErrTxTypeNotSupported + } if !tx.Protected() { return HomesteadSigner{}.Sender(tx) } if tx.ChainId().Cmp(s.chainId) != 0 { return common.Address{}, ErrInvalidChainId } - V := new(big.Int).Sub(tx.data.V, s.chainIdMul) + V, R, S := tx.RawSignatureValues() + V = new(big.Int).Sub(V, s.chainIdMul) V.Sub(V, big8) - return recoverPlain(s.Hash(tx), tx.data.R, tx.data.S, V, true) + return recoverPlain(s.Hash(tx), R, S, V, true) } -// WithSignature returns a new transaction with the given signature. This signature +// SignatureValues returns signature values. This signature // needs to be in the [R || S || V] format where V is 0 or 1. func (s EIP155Signer) SignatureValues(tx *Transaction, sig []byte) (R, S, V *big.Int, err error) { - R, S, V, err = HomesteadSigner{}.SignatureValues(tx, sig) - if err != nil { - return nil, nil, nil, err + if tx.Type() != LegacyTxType { + return nil, nil, nil, ErrTxTypeNotSupported } + R, S, V = decodeSignature(sig) if s.chainId.Sign() != 0 { V = big.NewInt(int64(sig[64] + 35)) V.Add(V, s.chainIdMul) @@ -154,12 +312,12 @@ func (s EIP155Signer) SignatureValues(tx *Transaction, sig []byte) (R, S, V *big // It does not uniquely identify the transaction. func (s EIP155Signer) Hash(tx *Transaction) common.Hash { return rlpHash([]interface{}{ - tx.data.AccountNonce, - tx.data.Price, - tx.data.GasLimit, - tx.data.Recipient, - tx.data.Amount, - tx.data.Payload, + tx.Nonce(), + tx.GasPrice(), + tx.Gas(), + tx.To(), + tx.Value(), + tx.Data(), s.chainId, uint(0), uint(0), }) } @@ -168,6 +326,10 @@ func (s EIP155Signer) Hash(tx *Transaction) common.Hash { // homestead rules. type HomesteadSigner struct{ FrontierSigner } +func (s HomesteadSigner) ChainID() *big.Int { + return nil +} + func (s HomesteadSigner) Equal(s2 Signer) bool { _, ok := s2.(HomesteadSigner) return ok @@ -180,25 +342,39 @@ func (hs HomesteadSigner) SignatureValues(tx *Transaction, sig []byte) (r, s, v } func (hs HomesteadSigner) Sender(tx *Transaction) (common.Address, error) { - return recoverPlain(hs.Hash(tx), tx.data.R, tx.data.S, tx.data.V, true) + if tx.Type() != LegacyTxType { + return common.Address{}, ErrTxTypeNotSupported + } + v, r, s := tx.RawSignatureValues() + return recoverPlain(hs.Hash(tx), r, s, v, true) } type FrontierSigner struct{} +func (s FrontierSigner) ChainID() *big.Int { + return nil +} + func (s FrontierSigner) Equal(s2 Signer) bool { _, ok := s2.(FrontierSigner) return ok } +func (fs FrontierSigner) Sender(tx *Transaction) (common.Address, error) { + if tx.Type() != LegacyTxType { + return common.Address{}, ErrTxTypeNotSupported + } + v, r, s := tx.RawSignatureValues() + return recoverPlain(fs.Hash(tx), r, s, v, false) +} + // SignatureValues returns signature values. This signature // needs to be in the [R || S || V] format where V is 0 or 1. func (fs FrontierSigner) SignatureValues(tx *Transaction, sig []byte) (r, s, v *big.Int, err error) { - if len(sig) != 65 { - panic(fmt.Sprintf("wrong size for signature: got %d, want 65", len(sig))) + if tx.Type() != LegacyTxType { + return nil, nil, nil, ErrTxTypeNotSupported } - r = new(big.Int).SetBytes(sig[:32]) - s = new(big.Int).SetBytes(sig[32:64]) - v = new(big.Int).SetBytes([]byte{sig[64] + 27}) + r, s, v = decodeSignature(sig) return r, s, v, nil } @@ -206,17 +382,23 @@ func (fs FrontierSigner) SignatureValues(tx *Transaction, sig []byte) (r, s, v * // It does not uniquely identify the transaction. func (fs FrontierSigner) Hash(tx *Transaction) common.Hash { return rlpHash([]interface{}{ - tx.data.AccountNonce, - tx.data.Price, - tx.data.GasLimit, - tx.data.Recipient, - tx.data.Amount, - tx.data.Payload, + tx.Nonce(), + tx.GasPrice(), + tx.Gas(), + tx.To(), + tx.Value(), + tx.Data(), }) } -func (fs FrontierSigner) Sender(tx *Transaction) (common.Address, error) { - return recoverPlain(fs.Hash(tx), tx.data.R, tx.data.S, tx.data.V, false) +func decodeSignature(sig []byte) (r, s, v *big.Int) { + if len(sig) != crypto.SignatureLength { + panic(fmt.Sprintf("wrong size for signature: got %d, want %d", len(sig), crypto.SignatureLength)) + } + r = new(big.Int).SetBytes(sig[:32]) + s = new(big.Int).SetBytes(sig[32:64]) + v = new(big.Int).SetBytes([]byte{sig[64] + 27}) + return r, s, v } func recoverPlain(sighash common.Hash, R, S, Vb *big.Int, homestead bool) (common.Address, error) { diff --git a/core/types/transaction_test.go b/core/types/transaction_test.go index 234ef322c5..b5b84be08c 100644 --- a/core/types/transaction_test.go +++ b/core/types/transaction_test.go @@ -20,8 +20,12 @@ import ( "bytes" "crypto/ecdsa" "encoding/json" + "errors" + "fmt" "math/big" + "reflect" "testing" + "time" "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/crypto" @@ -31,6 +35,8 @@ import ( // The values in those tests are from the Transaction Tests // at github.com/ethereum/tests. var ( + testAddr = common.HexToAddress("b94f5374fce5edbc8e2a8697c15331677e6ebf0b") + emptyTx = NewTransaction( 0, common.HexToAddress("095e7baea6a6c7c4c2dfeb977efac326af552d87"), @@ -40,7 +46,7 @@ var ( rightvrsTx, _ = NewTransaction( 3, - common.HexToAddress("b94f5374fce5edbc8e2a8697c15331677e6ebf0b"), + testAddr, big.NewInt(10), 2000, big.NewInt(1), @@ -49,8 +55,32 @@ var ( HomesteadSigner{}, common.Hex2Bytes("98ff921201554726367d2be8c804a7ff89ccf285ebc57dff8ae4c44b9c19ac4a8887321be575c8095f789dd4c743dfe42c1820f9231f98a962b210e3ac2452a301"), ) + + emptyEip2718Tx = NewTx(&AccessListTx{ + ChainID: big.NewInt(1), + Nonce: 3, + To: &testAddr, + Value: big.NewInt(10), + Gas: 25000, + GasPrice: big.NewInt(1), + Data: common.FromHex("5544"), + }) + + signedEip2718Tx, _ = emptyEip2718Tx.WithSignature( + NewEIP2930Signer(big.NewInt(1)), + common.Hex2Bytes("c9519f4f2b30335884581971573fadf60c6204f59a911df35ee8a540456b266032f1e8e2c5dd761f9e4f88f41c8310aeaba26a8bfcdacfedfa12ec3862d3752101"), + ) ) +func TestDecodeEmptyTypedTx(t *testing.T) { + input := []byte{0x80} + var tx Transaction + err := rlp.DecodeBytes(input, &tx) + if err != errEmptyTypedTx { + t.Fatal("wrong error:", err) + } +} + func TestTransactionSigHash(t *testing.T) { var homestead HomesteadSigner if homestead.Hash(emptyTx) != common.HexToHash("c775b99e7ad12f50d819fcd602390467e28141316969f4b57f0626f74fe3b386") { @@ -72,6 +102,117 @@ func TestTransactionEncode(t *testing.T) { } } +func TestEIP2718TransactionSigHash(t *testing.T) { + s := NewEIP2930Signer(big.NewInt(1)) + if s.Hash(emptyEip2718Tx) != common.HexToHash("49b486f0ec0a60dfbbca2d30cb07c9e8ffb2a2ff41f29a1ab6737475f6ff69f3") { + t.Errorf("empty EIP-2718 transaction hash mismatch, got %x", s.Hash(emptyEip2718Tx)) + } + if s.Hash(signedEip2718Tx) != common.HexToHash("49b486f0ec0a60dfbbca2d30cb07c9e8ffb2a2ff41f29a1ab6737475f6ff69f3") { + t.Errorf("signed EIP-2718 transaction hash mismatch, got %x", s.Hash(signedEip2718Tx)) + } +} + +// This test checks signature operations on access list transactions. +func TestEIP2930Signer(t *testing.T) { + var ( + key, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") + keyAddr = crypto.PubkeyToAddress(key.PublicKey) + signer1 = NewEIP2930Signer(big.NewInt(1)) + signer2 = NewEIP2930Signer(big.NewInt(2)) + tx0 = NewTx(&AccessListTx{Nonce: 1}) + tx1 = NewTx(&AccessListTx{ChainID: big.NewInt(1), Nonce: 1}) + tx2, _ = SignNewTx(key, signer2, &AccessListTx{ChainID: big.NewInt(2), Nonce: 1}) + ) + + tests := []struct { + tx *Transaction + signer Signer + wantSignerHash common.Hash + wantSenderErr error + wantSignErr error + wantHash common.Hash // after signing + }{ + { + tx: tx0, + signer: signer1, + wantSignerHash: common.HexToHash("846ad7672f2a3a40c1f959cd4a8ad21786d620077084d84c8d7c077714caa139"), + wantSenderErr: ErrInvalidChainId, + wantHash: common.HexToHash("1ccd12d8bbdb96ea391af49a35ab641e219b2dd638dea375f2bc94dd290f2549"), + }, + { + tx: tx1, + signer: signer1, + wantSenderErr: ErrInvalidSig, + wantSignerHash: common.HexToHash("846ad7672f2a3a40c1f959cd4a8ad21786d620077084d84c8d7c077714caa139"), + wantHash: common.HexToHash("1ccd12d8bbdb96ea391af49a35ab641e219b2dd638dea375f2bc94dd290f2549"), + }, + { + // This checks what happens when trying to sign an unsigned tx for the wrong chain. + tx: tx1, + signer: signer2, + wantSenderErr: ErrInvalidChainId, + wantSignerHash: common.HexToHash("367967247499343401261d718ed5aa4c9486583e4d89251afce47f4a33c33362"), + wantSignErr: ErrInvalidChainId, + }, + { + // This checks what happens when trying to re-sign a signed tx for the wrong chain. + tx: tx2, + signer: signer1, + wantSenderErr: ErrInvalidChainId, + wantSignerHash: common.HexToHash("846ad7672f2a3a40c1f959cd4a8ad21786d620077084d84c8d7c077714caa139"), + wantSignErr: ErrInvalidChainId, + }, + } + + for i, test := range tests { + sigHash := test.signer.Hash(test.tx) + if sigHash != test.wantSignerHash { + t.Errorf("test %d: wrong sig hash: got %x, want %x", i, sigHash, test.wantSignerHash) + } + sender, err := Sender(test.signer, test.tx) + if !errors.Is(err, test.wantSenderErr) { + t.Errorf("test %d: wrong Sender error %q", i, err) + } + if err == nil && sender != keyAddr { + t.Errorf("test %d: wrong sender address %x", i, sender) + } + signedTx, err := SignTx(test.tx, test.signer, key) + if !errors.Is(err, test.wantSignErr) { + t.Fatalf("test %d: wrong SignTx error %q", i, err) + } + if signedTx != nil { + if signedTx.Hash() != test.wantHash { + t.Errorf("test %d: wrong tx hash after signing: got %x, want %x", i, signedTx.Hash(), test.wantHash) + } + } + } +} + +func TestEIP2718TransactionEncode(t *testing.T) { + // RLP representation + { + have, err := rlp.EncodeToBytes(signedEip2718Tx) + if err != nil { + t.Fatalf("encode error: %v", err) + } + want := common.FromHex("b86601f8630103018261a894b94f5374fce5edbc8e2a8697c15331677e6ebf0b0a825544c001a0c9519f4f2b30335884581971573fadf60c6204f59a911df35ee8a540456b2660a032f1e8e2c5dd761f9e4f88f41c8310aeaba26a8bfcdacfedfa12ec3862d37521") + if !bytes.Equal(have, want) { + t.Errorf("encoded RLP mismatch, got %x", have) + } + } + // Binary representation + { + have, err := signedEip2718Tx.MarshalBinary() + if err != nil { + t.Fatalf("encode error: %v", err) + } + want := common.FromHex("01f8630103018261a894b94f5374fce5edbc8e2a8697c15331677e6ebf0b0a825544c001a0c9519f4f2b30335884581971573fadf60c6204f59a911df35ee8a540456b2660a032f1e8e2c5dd761f9e4f88f41c8310aeaba26a8bfcdacfedfa12ec3862d37521") + if !bytes.Equal(have, want) { + t.Errorf("encoded RLP mismatch, got %x", have) + } + } +} + func decodeTx(data []byte) (*Transaction, error) { var tx Transaction t, err := &tx, rlp.Decode(bytes.NewReader(data), &tx) @@ -233,3 +374,174 @@ func TestTransactionJSON(t *testing.T) { } } } + +// Tests that if multiple transactions have the same price, the ones seen earlier +// are prioritized to avoid network spam attacks aiming for a specific ordering. +func TestTransactionTimeSort(t *testing.T) { + // Generate a batch of accounts to start with + keys := make([]*ecdsa.PrivateKey, 5) + for i := 0; i < len(keys); i++ { + keys[i], _ = crypto.GenerateKey() + } + signer := HomesteadSigner{} + + // Generate a batch of transactions with overlapping prices, but different creation times + groups := map[common.Address]Transactions{} + for start, key := range keys { + addr := crypto.PubkeyToAddress(key.PublicKey) + + tx, _ := SignTx(NewTransaction(0, common.Address{}, big.NewInt(100), 100, big.NewInt(1), nil), signer, key) + tx.time = time.Unix(0, int64(len(keys)-start)) + + groups[addr] = append(groups[addr], tx) + } + // Sort the transactions and cross check the nonce ordering + txset, _ := NewTransactionsByPriceAndNonce(signer, groups, nil, map[common.Address]*big.Int{}) + + txs := Transactions{} + for tx := txset.Peek(); tx != nil; tx = txset.Peek() { + txs = append(txs, tx) + txset.Shift() + } + if len(txs) != len(keys) { + t.Errorf("expected %d transactions, found %d", len(keys), len(txs)) + } + for i, txi := range txs { + fromi, _ := Sender(signer, txi) + if i+1 < len(txs) { + next := txs[i+1] + fromNext, _ := Sender(signer, next) + + if txi.GasPrice().Cmp(next.GasPrice()) < 0 { + t.Errorf("invalid gasprice ordering: tx #%d (A=%x P=%v) < tx #%d (A=%x P=%v)", i, fromi[:4], txi.GasPrice(), i+1, fromNext[:4], next.GasPrice()) + } + // Make sure time order is ascending if the txs have the same gas price + if txi.GasPrice().Cmp(next.GasPrice()) == 0 && txi.time.After(next.time) { + t.Errorf("invalid received time ordering: tx #%d (A=%x T=%v) > tx #%d (A=%x T=%v)", i, fromi[:4], txi.time, i+1, fromNext[:4], next.time) + } + } + } +} + +// TestTransactionCoding tests serializing/de-serializing to/from rlp and JSON. +func TestTransactionCoding(t *testing.T) { + key, err := crypto.GenerateKey() + if err != nil { + t.Fatalf("could not generate key: %v", err) + } + var ( + signer = NewEIP2930Signer(common.Big1) + addr = common.HexToAddress("0x0000000000000000000000000000000000000001") + recipient = common.HexToAddress("095e7baea6a6c7c4c2dfeb977efac326af552d87") + accesses = AccessList{{Address: addr, StorageKeys: []common.Hash{{0}}}} + ) + for i := uint64(0); i < 500; i++ { + var txdata TxData + switch i % 5 { + case 0: + // Legacy tx. + txdata = &LegacyTx{ + Nonce: i, + To: &recipient, + Gas: 1, + GasPrice: big.NewInt(2), + Data: []byte("abcdef"), + } + case 1: + // Legacy tx contract creation. + txdata = &LegacyTx{ + Nonce: i, + Gas: 1, + GasPrice: big.NewInt(2), + Data: []byte("abcdef"), + } + case 2: + // Tx with non-zero access list. + txdata = &AccessListTx{ + ChainID: big.NewInt(1), + Nonce: i, + To: &recipient, + Gas: 123457, + GasPrice: big.NewInt(10), + AccessList: accesses, + Data: []byte("abcdef"), + } + case 3: + // Tx with empty access list. + txdata = &AccessListTx{ + ChainID: big.NewInt(1), + Nonce: i, + To: &recipient, + Gas: 123457, + GasPrice: big.NewInt(10), + Data: []byte("abcdef"), + } + case 4: + // Contract creation with access list. + txdata = &AccessListTx{ + ChainID: big.NewInt(1), + Nonce: i, + Gas: 123457, + GasPrice: big.NewInt(10), + AccessList: accesses, + } + } + tx, err := SignNewTx(key, signer, txdata) + if err != nil { + t.Fatalf("could not sign transaction: %v", err) + } + // RLP + parsedTx, err := encodeDecodeBinary(tx) + if err != nil { + t.Fatal(err) + } + assertEqual(parsedTx, tx) + + // JSON + parsedTx, err = encodeDecodeJSON(tx) + if err != nil { + t.Fatal(err) + } + assertEqual(parsedTx, tx) + } +} + +func encodeDecodeJSON(tx *Transaction) (*Transaction, error) { + data, err := json.Marshal(tx) + if err != nil { + return nil, fmt.Errorf("json encoding failed: %v", err) + } + var parsedTx = &Transaction{} + if err := json.Unmarshal(data, &parsedTx); err != nil { + return nil, fmt.Errorf("json decoding failed: %v", err) + } + return parsedTx, nil +} + +func encodeDecodeBinary(tx *Transaction) (*Transaction, error) { + data, err := tx.MarshalBinary() + if err != nil { + return nil, fmt.Errorf("rlp encoding failed: %v", err) + } + var parsedTx = &Transaction{} + if err := parsedTx.UnmarshalBinary(data); err != nil { + return nil, fmt.Errorf("rlp decoding failed: %v", err) + } + return parsedTx, nil +} + +func assertEqual(orig *Transaction, cpy *Transaction) error { + // compare nonce, price, gaslimit, recipient, amount, payload, V, R, S + if want, got := orig.Hash(), cpy.Hash(); want != got { + return fmt.Errorf("parsed tx differs from original tx, want %v, got %v", want, got) + } + if want, got := orig.ChainId(), cpy.ChainId(); want.Cmp(got) != 0 { + return fmt.Errorf("invalid chain id, want %d, got %d", want, got) + } + if orig.AccessList() != nil { + if !reflect.DeepEqual(orig.AccessList(), cpy.AccessList()) { + return errors.New("access list wrong!") + } + } + return nil +} diff --git a/core/vm/access_list_tracer.go b/core/vm/access_list_tracer.go new file mode 100644 index 0000000000..1093af7c4d --- /dev/null +++ b/core/vm/access_list_tracer.go @@ -0,0 +1,177 @@ +// Copyright 2021 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package vm + +import ( + "math/big" + "time" + + "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/core/types" +) + +// accessList is an accumulator for the set of accounts and storage slots an EVM +// contract execution touches. +type accessList map[common.Address]accessListSlots + +// accessListSlots is an accumulator for the set of storage slots within a single +// contract that an EVM contract execution touches. +type accessListSlots map[common.Hash]struct{} + +// newAccessList creates a new accessList. +func newAccessList() accessList { + return make(map[common.Address]accessListSlots) +} + +// addAddress adds an address to the accesslist. +func (al accessList) addAddress(address common.Address) { + // Set address if not previously present + if _, present := al[address]; !present { + al[address] = make(map[common.Hash]struct{}) + } +} + +// addSlot adds a storage slot to the accesslist. +func (al accessList) addSlot(address common.Address, slot common.Hash) { + // Set address if not previously present + al.addAddress(address) + + // Set the slot on the surely existent storage set + al[address][slot] = struct{}{} +} + +// equal checks if the content of the current access list is the same as the +// content of the other one. +func (al accessList) equal(other accessList) bool { + // Cross reference the accounts first + if len(al) != len(other) { + return false + } + for addr := range al { + if _, ok := other[addr]; !ok { + return false + } + } + for addr := range other { + if _, ok := al[addr]; !ok { + return false + } + } + // Accounts match, cross reference the storage slots too + for addr, slots := range al { + otherslots := other[addr] + + if len(slots) != len(otherslots) { + return false + } + for hash := range slots { + if _, ok := otherslots[hash]; !ok { + return false + } + } + for hash := range otherslots { + if _, ok := slots[hash]; !ok { + return false + } + } + } + return true +} + +// accesslist converts the accesslist to a types.AccessList. +func (al accessList) accessList() types.AccessList { + acl := make(types.AccessList, 0, len(al)) + for addr, slots := range al { + tuple := types.AccessTuple{Address: addr} + for slot := range slots { + tuple.StorageKeys = append(tuple.StorageKeys, slot) + } + acl = append(acl, tuple) + } + return acl +} + +// AccessListTracer is a tracer that accumulates touched accounts and storage +// slots into an internal set. +type AccessListTracer struct { + excl map[common.Address]struct{} // Set of account to exclude from the list + list accessList // Set of accounts and storage slots touched +} + +// NewAccessListTracer creates a new tracer that can generate AccessLists. +// An optional AccessList can be specified to occupy slots and addresses in +// the resulting accesslist. +func NewAccessListTracer(acl types.AccessList, from, to common.Address, precompiles []common.Address) *AccessListTracer { + excl := map[common.Address]struct{}{ + from: {}, to: {}, + } + for _, addr := range precompiles { + excl[addr] = struct{}{} + } + list := newAccessList() + for _, al := range acl { + if _, ok := excl[al.Address]; !ok { + list.addAddress(al.Address) + } + for _, slot := range al.StorageKeys { + list.addSlot(al.Address, slot) + } + } + return &AccessListTracer{ + excl: excl, + list: list, + } +} + +func (a *AccessListTracer) CaptureStart(env *EVM, from common.Address, to common.Address, create bool, input []byte, gas uint64, value *big.Int) { +} + +// CaptureState captures all opcodes that touch storage or addresses and adds them to the accesslist. +func (a *AccessListTracer) CaptureState(env *EVM, pc uint64, op OpCode, gas, cost uint64, scope *ScopeContext, rData []byte, depth int, err error) { + stack := scope.Stack + if (op == SLOAD || op == SSTORE) && stack.len() >= 1 { + slot := common.Hash(stack.data[stack.len()-1].Bytes32()) + a.list.addSlot(scope.Contract.Address(), slot) + } + if (op == EXTCODECOPY || op == EXTCODEHASH || op == EXTCODESIZE || op == BALANCE || op == SELFDESTRUCT) && stack.len() >= 1 { + addr := common.Address(stack.data[stack.len()-1].Bytes20()) + if _, ok := a.excl[addr]; !ok { + a.list.addAddress(addr) + } + } + if (op == DELEGATECALL || op == CALL || op == STATICCALL || op == CALLCODE) && stack.len() >= 5 { + addr := common.Address(stack.data[stack.len()-2].Bytes20()) + if _, ok := a.excl[addr]; !ok { + a.list.addAddress(addr) + } + } +} + +func (*AccessListTracer) CaptureFault(env *EVM, pc uint64, op OpCode, gas, cost uint64, scope *ScopeContext, depth int, err error) { +} + +func (*AccessListTracer) CaptureEnd(output []byte, gasUsed uint64, t time.Duration, err error) {} + +// AccessList returns the current accesslist maintained by the tracer. +func (a *AccessListTracer) AccessList() types.AccessList { + return a.list.accessList() +} + +// Equal returns if the content of two access list traces are equal. +func (a *AccessListTracer) Equal(other *AccessListTracer) bool { + return a.list.equal(other.list) +} diff --git a/core/vm/contracts.go b/core/vm/contracts.go index 79cd67875a..175db51a87 100644 --- a/core/vm/contracts.go +++ b/core/vm/contracts.go @@ -120,6 +120,20 @@ func init() { } } +// ActivePrecompiles returns the precompiles enabled with the current configuration. +func ActivePrecompiles(rules params.Rules) []common.Address { + switch { + case rules.IsXDCxDisable: + return PrecompiledAddressesXDCv2 + case rules.IsIstanbul: + return PrecompiledAddressesIstanbul + case rules.IsByzantium: + return PrecompiledAddressesByzantium + default: + return PrecompiledAddressesHomestead + } +} + // RunPrecompiledContract runs and evaluates the output of a precompiled contract. func RunPrecompiledContract(p PrecompiledContract, input []byte, contract *Contract) (ret []byte, err error) { gas := p.RequiredGas(input) diff --git a/core/vm/eips.go b/core/vm/eips.go index e24c22683b..69cca508b3 100644 --- a/core/vm/eips.go +++ b/core/vm/eips.go @@ -67,9 +67,9 @@ func enable1884(jt *JumpTable) { } } -func opSelfBalance(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - balance, _ := uint256.FromBig(interpreter.evm.StateDB.GetBalance(callContext.contract.Address())) - callContext.stack.push(balance) +func opSelfBalance(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + balance, _ := uint256.FromBig(interpreter.evm.StateDB.GetBalance(scope.Contract.Address())) + scope.Stack.push(balance) return nil, nil } @@ -86,9 +86,9 @@ func enable1344(jt *JumpTable) { } // opChainID implements CHAINID opcode -func opChainID(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opChainID(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { chainId, _ := uint256.FromBig(interpreter.evm.chainConfig.ChainId) - callContext.stack.push(chainId) + scope.Stack.push(chainId) return nil, nil } @@ -149,9 +149,9 @@ func enable3198(jt *JumpTable) { } // opBaseFee implements BASEFEE opcode -func opBaseFee(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opBaseFee(pc *uint64, interpreter *EVMInterpreter, callContext *ScopeContext) ([]byte, error) { baseFee, _ := uint256.FromBig(common.MinGasPrice50x) - callContext.stack.push(baseFee) + callContext.Stack.push(baseFee) return nil, nil } @@ -167,7 +167,7 @@ func enable3855(jt *JumpTable) { } // opPush0 implements the PUSH0 opcode -func opPush0(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - callContext.stack.push(new(uint256.Int)) +func opPush0(pc *uint64, interpreter *EVMInterpreter, callContext *ScopeContext) ([]byte, error) { + callContext.Stack.push(new(uint256.Int)) return nil, nil } diff --git a/core/vm/evm.go b/core/vm/evm.go index 626ffc3ba4..cdc0f9a48a 100644 --- a/core/vm/evm.go +++ b/core/vm/evm.go @@ -42,21 +42,6 @@ type ( GetHashFunc func(uint64) common.Hash ) -// ActivePrecompiles returns the addresses of the precompiles enabled with the current -// configuration -func (evm *EVM) ActivePrecompiles() []common.Address { - switch { - case evm.chainRules.IsXDCxDisable: - return PrecompiledAddressesXDCv2 - case evm.chainRules.IsIstanbul: - return PrecompiledAddressesIstanbul - case evm.chainRules.IsByzantium: - return PrecompiledAddressesByzantium - default: - return PrecompiledAddressesHomestead - } -} - func (evm *EVM) precompile(addr common.Address) (PrecompiledContract, bool) { var precompiles map[common.Address]PrecompiledContract switch { @@ -248,7 +233,7 @@ func (evm *EVM) Call(caller ContractRef, addr common.Address, input []byte, gas if !isPrecompile && evm.chainRules.IsEIP158 && value.Sign() == 0 { // Calling a non existing account, don't do anything, but ping the tracer if evm.vmConfig.Debug && evm.depth == 0 { - evm.vmConfig.Tracer.CaptureStart(caller.Address(), addr, false, input, gas, value) + evm.vmConfig.Tracer.CaptureStart(evm, caller.Address(), addr, false, input, gas, value) evm.vmConfig.Tracer.CaptureEnd(ret, 0, 0, nil) } return nil, gas, nil @@ -266,7 +251,7 @@ func (evm *EVM) Call(caller ContractRef, addr common.Address, input []byte, gas // Capture the tracer start/end events in debug mode if evm.vmConfig.Debug && evm.depth == 0 { - evm.vmConfig.Tracer.CaptureStart(caller.Address(), addr, false, input, gas, value) + evm.vmConfig.Tracer.CaptureStart(evm, caller.Address(), addr, false, input, gas, value) defer func() { // Lazy evaluation of the parameters evm.vmConfig.Tracer.CaptureEnd(ret, gas-contract.Gas, time.Since(start), err) @@ -439,7 +424,7 @@ func (evm *EVM) create(caller ContractRef, codeAndHash *codeAndHash, gas uint64, contract.SetCodeOptionalHash(&address, codeAndHash) if evm.vmConfig.Debug && evm.depth == 0 { - evm.vmConfig.Tracer.CaptureStart(caller.Address(), address, true, codeAndHash.code, gas, value) + evm.vmConfig.Tracer.CaptureStart(evm, caller.Address(), address, true, codeAndHash.code, gas, value) } start := time.Now() diff --git a/core/vm/instructions.go b/core/vm/instructions.go index a020654440..39f0032abe 100644 --- a/core/vm/instructions.go +++ b/core/vm/instructions.go @@ -26,68 +26,68 @@ import ( "golang.org/x/crypto/sha3" ) -func opAdd(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x, y := callContext.stack.pop(), callContext.stack.peek() +func opAdd(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x, y := scope.Stack.pop(), scope.Stack.peek() y.Add(&x, y) return nil, nil } -func opSub(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x, y := callContext.stack.pop(), callContext.stack.peek() +func opSub(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x, y := scope.Stack.pop(), scope.Stack.peek() y.Sub(&x, y) return nil, nil } -func opMul(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x, y := callContext.stack.pop(), callContext.stack.peek() +func opMul(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x, y := scope.Stack.pop(), scope.Stack.peek() y.Mul(&x, y) return nil, nil } -func opDiv(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x, y := callContext.stack.pop(), callContext.stack.peek() +func opDiv(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x, y := scope.Stack.pop(), scope.Stack.peek() y.Div(&x, y) return nil, nil } -func opSdiv(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x, y := callContext.stack.pop(), callContext.stack.peek() +func opSdiv(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x, y := scope.Stack.pop(), scope.Stack.peek() y.SDiv(&x, y) return nil, nil } -func opMod(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x, y := callContext.stack.pop(), callContext.stack.peek() +func opMod(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x, y := scope.Stack.pop(), scope.Stack.peek() y.Mod(&x, y) return nil, nil } -func opSmod(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x, y := callContext.stack.pop(), callContext.stack.peek() +func opSmod(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x, y := scope.Stack.pop(), scope.Stack.peek() y.SMod(&x, y) return nil, nil } -func opExp(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - base, exponent := callContext.stack.pop(), callContext.stack.peek() +func opExp(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + base, exponent := scope.Stack.pop(), scope.Stack.peek() exponent.Exp(&base, exponent) return nil, nil } -func opSignExtend(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - back, num := callContext.stack.pop(), callContext.stack.peek() +func opSignExtend(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + back, num := scope.Stack.pop(), scope.Stack.peek() num.ExtendSign(num, &back) return nil, nil } -func opNot(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x := callContext.stack.peek() +func opNot(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x := scope.Stack.peek() x.Not(x) return nil, nil } -func opLt(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x, y := callContext.stack.pop(), callContext.stack.peek() +func opLt(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x, y := scope.Stack.pop(), scope.Stack.peek() if x.Lt(y) { y.SetOne() } else { @@ -96,8 +96,8 @@ func opLt(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte return nil, nil } -func opGt(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x, y := callContext.stack.pop(), callContext.stack.peek() +func opGt(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x, y := scope.Stack.pop(), scope.Stack.peek() if x.Gt(y) { y.SetOne() } else { @@ -106,8 +106,8 @@ func opGt(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte return nil, nil } -func opSlt(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x, y := callContext.stack.pop(), callContext.stack.peek() +func opSlt(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x, y := scope.Stack.pop(), scope.Stack.peek() if x.Slt(y) { y.SetOne() } else { @@ -116,8 +116,8 @@ func opSlt(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byt return nil, nil } -func opSgt(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x, y := callContext.stack.pop(), callContext.stack.peek() +func opSgt(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x, y := scope.Stack.pop(), scope.Stack.peek() if x.Sgt(y) { y.SetOne() } else { @@ -126,8 +126,8 @@ func opSgt(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byt return nil, nil } -func opEq(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x, y := callContext.stack.pop(), callContext.stack.peek() +func opEq(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x, y := scope.Stack.pop(), scope.Stack.peek() if x.Eq(y) { y.SetOne() } else { @@ -136,8 +136,8 @@ func opEq(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte return nil, nil } -func opIszero(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x := callContext.stack.peek() +func opIszero(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x := scope.Stack.peek() if x.IsZero() { x.SetOne() } else { @@ -146,32 +146,32 @@ func opIszero(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([] return nil, nil } -func opAnd(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x, y := callContext.stack.pop(), callContext.stack.peek() +func opAnd(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x, y := scope.Stack.pop(), scope.Stack.peek() y.And(&x, y) return nil, nil } -func opOr(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x, y := callContext.stack.pop(), callContext.stack.peek() +func opOr(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x, y := scope.Stack.pop(), scope.Stack.peek() y.Or(&x, y) return nil, nil } -func opXor(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x, y := callContext.stack.pop(), callContext.stack.peek() +func opXor(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x, y := scope.Stack.pop(), scope.Stack.peek() y.Xor(&x, y) return nil, nil } -func opByte(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - th, val := callContext.stack.pop(), callContext.stack.peek() +func opByte(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + th, val := scope.Stack.pop(), scope.Stack.peek() val.Byte(&th) return nil, nil } -func opAddmod(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x, y, z := callContext.stack.pop(), callContext.stack.pop(), callContext.stack.peek() +func opAddmod(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x, y, z := scope.Stack.pop(), scope.Stack.pop(), scope.Stack.peek() if z.IsZero() { z.Clear() } else { @@ -180,8 +180,8 @@ func opAddmod(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([] return nil, nil } -func opMulmod(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x, y, z := callContext.stack.pop(), callContext.stack.pop(), callContext.stack.peek() +func opMulmod(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x, y, z := scope.Stack.pop(), scope.Stack.pop(), scope.Stack.peek() z.MulMod(&x, &y, z) return nil, nil } @@ -189,9 +189,9 @@ func opMulmod(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([] // opSHL implements Shift Left // The SHL instruction (shift left) pops 2 values from the stack, first arg1 and then arg2, // and pushes on the stack arg2 shifted to the left by arg1 number of bits. -func opSHL(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opSHL(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { // Note, second operand is left in the stack; accumulate result into it, and no need to push it afterwards - shift, value := callContext.stack.pop(), callContext.stack.peek() + shift, value := scope.Stack.pop(), scope.Stack.peek() if shift.LtUint64(256) { value.Lsh(value, uint(shift.Uint64())) } else { @@ -203,9 +203,9 @@ func opSHL(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byt // opSHR implements Logical Shift Right // The SHR instruction (logical shift right) pops 2 values from the stack, first arg1 and then arg2, // and pushes on the stack arg2 shifted to the right by arg1 number of bits with zero fill. -func opSHR(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opSHR(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { // Note, second operand is left in the stack; accumulate result into it, and no need to push it afterwards - shift, value := callContext.stack.pop(), callContext.stack.peek() + shift, value := scope.Stack.pop(), scope.Stack.peek() if shift.LtUint64(256) { value.Rsh(value, uint(shift.Uint64())) } else { @@ -217,8 +217,8 @@ func opSHR(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byt // opSAR implements Arithmetic Shift Right // The SAR instruction (arithmetic shift right) pops 2 values from the stack, first arg1 and then arg2, // and pushes on the stack arg2 shifted to the right by arg1 number of bits with sign extension. -func opSAR(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - shift, value := callContext.stack.pop(), callContext.stack.peek() +func opSAR(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + shift, value := scope.Stack.pop(), scope.Stack.peek() if shift.GtUint64(256) { if value.Sign() >= 0 { value.Clear() @@ -233,9 +233,9 @@ func opSAR(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byt return nil, nil } -func opKeccak256(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - offset, size := callContext.stack.pop(), callContext.stack.peek() - data := callContext.memory.GetPtr(int64(offset.Uint64()), int64(size.Uint64())) +func opKeccak256(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + offset, size := scope.Stack.pop(), scope.Stack.peek() + data := scope.Memory.GetPtr(int64(offset.Uint64()), int64(size.Uint64())) if interpreter.hasher == nil { interpreter.hasher = sha3.NewLegacyKeccak256().(keccakState) @@ -253,37 +253,37 @@ func opKeccak256(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) size.SetBytes(interpreter.hasherBuf[:]) return nil, nil } -func opAddress(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - callContext.stack.push(new(uint256.Int).SetBytes(callContext.contract.Address().Bytes())) +func opAddress(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + scope.Stack.push(new(uint256.Int).SetBytes(scope.Contract.Address().Bytes())) return nil, nil } -func opBalance(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - slot := callContext.stack.peek() +func opBalance(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + slot := scope.Stack.peek() address := common.Address(slot.Bytes20()) slot.SetFromBig(interpreter.evm.StateDB.GetBalance(address)) return nil, nil } -func opOrigin(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - callContext.stack.push(new(uint256.Int).SetBytes(interpreter.evm.Origin.Bytes())) +func opOrigin(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + scope.Stack.push(new(uint256.Int).SetBytes(interpreter.evm.Origin.Bytes())) return nil, nil } -func opCaller(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - callContext.stack.push(new(uint256.Int).SetBytes(callContext.contract.Caller().Bytes())) +func opCaller(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + scope.Stack.push(new(uint256.Int).SetBytes(scope.Contract.Caller().Bytes())) return nil, nil } -func opCallValue(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - v, _ := uint256.FromBig(callContext.contract.value) - callContext.stack.push(v) +func opCallValue(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + v, _ := uint256.FromBig(scope.Contract.value) + scope.Stack.push(v) return nil, nil } -func opCallDataLoad(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x := callContext.stack.peek() +func opCallDataLoad(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x := scope.Stack.peek() if offset, overflow := x.Uint64WithOverflow(); !overflow { - data := getData(callContext.contract.Input, offset, 32) + data := getData(scope.Contract.Input, offset, 32) x.SetBytes(data) } else { x.Clear() @@ -291,16 +291,16 @@ func opCallDataLoad(pc *uint64, interpreter *EVMInterpreter, callContext *callCt return nil, nil } -func opCallDataSize(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - callContext.stack.push(new(uint256.Int).SetUint64(uint64(len(callContext.contract.Input)))) +func opCallDataSize(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + scope.Stack.push(new(uint256.Int).SetUint64(uint64(len(scope.Contract.Input)))) return nil, nil } -func opCallDataCopy(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opCallDataCopy(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { var ( - memOffset = callContext.stack.pop() - dataOffset = callContext.stack.pop() - length = callContext.stack.pop() + memOffset = scope.Stack.pop() + dataOffset = scope.Stack.pop() + length = scope.Stack.pop() ) dataOffset64, overflow := dataOffset.Uint64WithOverflow() if overflow { @@ -309,21 +309,21 @@ func opCallDataCopy(pc *uint64, interpreter *EVMInterpreter, callContext *callCt // These values are checked for overflow during gas cost calculation memOffset64 := memOffset.Uint64() length64 := length.Uint64() - callContext.memory.Set(memOffset64, length64, getData(callContext.contract.Input, dataOffset64, length64)) + scope.Memory.Set(memOffset64, length64, getData(scope.Contract.Input, dataOffset64, length64)) return nil, nil } -func opReturnDataSize(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - callContext.stack.push(new(uint256.Int).SetUint64(uint64(len(interpreter.returnData)))) +func opReturnDataSize(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + scope.Stack.push(new(uint256.Int).SetUint64(uint64(len(interpreter.returnData)))) return nil, nil } -func opReturnDataCopy(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opReturnDataCopy(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { var ( - memOffset = callContext.stack.pop() - dataOffset = callContext.stack.pop() - length = callContext.stack.pop() + memOffset = scope.Stack.pop() + dataOffset = scope.Stack.pop() + length = scope.Stack.pop() ) offset64, overflow := dataOffset.Uint64WithOverflow() @@ -337,42 +337,42 @@ func opReturnDataCopy(pc *uint64, interpreter *EVMInterpreter, callContext *call if overflow || uint64(len(interpreter.returnData)) < end64 { return nil, ErrReturnDataOutOfBounds } - callContext.memory.Set(memOffset.Uint64(), length.Uint64(), interpreter.returnData[offset64:end64]) + scope.Memory.Set(memOffset.Uint64(), length.Uint64(), interpreter.returnData[offset64:end64]) return nil, nil } -func opExtCodeSize(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - slot := callContext.stack.peek() +func opExtCodeSize(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + slot := scope.Stack.peek() slot.SetUint64(uint64(interpreter.evm.StateDB.GetCodeSize(common.Address(slot.Bytes20())))) return nil, nil } -func opCodeSize(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opCodeSize(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { l := new(uint256.Int) - l.SetUint64(uint64(len(callContext.contract.Code))) - callContext.stack.push(l) + l.SetUint64(uint64(len(scope.Contract.Code))) + scope.Stack.push(l) return nil, nil } -func opCodeCopy(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opCodeCopy(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { var ( - memOffset = callContext.stack.pop() - codeOffset = callContext.stack.pop() - length = callContext.stack.pop() + memOffset = scope.Stack.pop() + codeOffset = scope.Stack.pop() + length = scope.Stack.pop() ) uint64CodeOffset, overflow := codeOffset.Uint64WithOverflow() if overflow { uint64CodeOffset = 0xffffffffffffffff } - codeCopy := getData(callContext.contract.Code, uint64CodeOffset, length.Uint64()) - callContext.memory.Set(memOffset.Uint64(), length.Uint64(), codeCopy) + codeCopy := getData(scope.Contract.Code, uint64CodeOffset, length.Uint64()) + scope.Memory.Set(memOffset.Uint64(), length.Uint64(), codeCopy) return nil, nil } -func opExtCodeCopy(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opExtCodeCopy(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { var ( - stack = callContext.stack + stack = scope.Stack a = stack.pop() memOffset = stack.pop() codeOffset = stack.pop() @@ -384,7 +384,7 @@ func opExtCodeCopy(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx } addr := common.Address(a.Bytes20()) codeCopy := getData(interpreter.evm.StateDB.GetCode(addr), uint64CodeOffset, length.Uint64()) - callContext.memory.Set(memOffset.Uint64(), length.Uint64(), codeCopy) + scope.Memory.Set(memOffset.Uint64(), length.Uint64(), codeCopy) return nil, nil } @@ -392,16 +392,21 @@ func opExtCodeCopy(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx // opExtCodeHash returns the code hash of a specified account. // There are several cases when the function is called, while we can relay everything // to `state.GetCodeHash` function to ensure the correctness. -// (1) Caller tries to get the code hash of a normal contract account, state +// +// (1) Caller tries to get the code hash of a normal contract account, state +// // should return the relative code hash and set it as the result. // -// (2) Caller tries to get the code hash of a non-existent account, state should +// (2) Caller tries to get the code hash of a non-existent account, state should +// // return common.Hash{} and zero will be set as the result. // -// (3) Caller tries to get the code hash for an account without contract code, +// (3) Caller tries to get the code hash for an account without contract code, +// // state should return emptyCodeHash(0xc5d246...) as the result. // -// (4) Caller tries to get the code hash of a precompiled account, the result +// (4) Caller tries to get the code hash of a precompiled account, the result +// // should be zero or emptyCodeHash. // // It is worth noting that in order to avoid unnecessary create and clean, @@ -410,13 +415,15 @@ func opExtCodeCopy(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx // If the precompile account is not transferred any amount on a private or // customized chain, the return value will be zero. // -// (5) Caller tries to get the code hash for an account which is marked as suicided +// (5) Caller tries to get the code hash for an account which is marked as suicided +// // in the current transaction, the code hash of this account should be returned. // -// (6) Caller tries to get the code hash for an account which is marked as deleted, +// (6) Caller tries to get the code hash for an account which is marked as deleted, +// // this account should be regarded as a non-existent account and zero should be returned. -func opExtCodeHash(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - slot := callContext.stack.peek() +func opExtCodeHash(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + slot := scope.Stack.peek() address := common.Address(slot.Bytes20()) if interpreter.evm.StateDB.Empty(address) { slot.Clear() @@ -426,14 +433,14 @@ func opExtCodeHash(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx return nil, nil } -func opGasprice(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opGasprice(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { v, _ := uint256.FromBig(interpreter.evm.GasPrice) - callContext.stack.push(v) + scope.Stack.push(v) return nil, nil } -func opBlockhash(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - num := callContext.stack.peek() +func opBlockhash(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + num := scope.Stack.peek() num64, overflow := num.Uint64WithOverflow() if overflow { num.Clear() @@ -454,108 +461,108 @@ func opBlockhash(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) return nil, nil } -func opCoinbase(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - callContext.stack.push(new(uint256.Int).SetBytes(interpreter.evm.Coinbase.Bytes())) +func opCoinbase(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + scope.Stack.push(new(uint256.Int).SetBytes(interpreter.evm.Coinbase.Bytes())) return nil, nil } -func opTimestamp(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opTimestamp(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { v, _ := uint256.FromBig(interpreter.evm.Time) - callContext.stack.push(v) + scope.Stack.push(v) return nil, nil } -func opNumber(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opNumber(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { v, _ := uint256.FromBig(interpreter.evm.BlockNumber) - callContext.stack.push(v) + scope.Stack.push(v) return nil, nil } -func opDifficulty(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opDifficulty(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { v, _ := uint256.FromBig(interpreter.evm.Difficulty) - callContext.stack.push(v) + scope.Stack.push(v) return nil, nil } -func opRandom(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opRandom(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { var v *uint256.Int if interpreter.evm.Context.Random != nil { v = new(uint256.Int).SetBytes((interpreter.evm.Context.Random.Bytes())) } else { // if context random is not set, use emptyCodeHash as default v = new(uint256.Int).SetBytes(emptyCodeHash.Bytes()) } - callContext.stack.push(v) + scope.Stack.push(v) return nil, nil } -func opGasLimit(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - callContext.stack.push(new(uint256.Int).SetUint64(interpreter.evm.GasLimit)) +func opGasLimit(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + scope.Stack.push(new(uint256.Int).SetUint64(interpreter.evm.GasLimit)) return nil, nil } -func opPop(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - callContext.stack.pop() +func opPop(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + scope.Stack.pop() return nil, nil } -func opMload(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - v := callContext.stack.peek() +func opMload(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + v := scope.Stack.peek() offset := int64(v.Uint64()) - v.SetBytes(callContext.memory.GetPtr(offset, 32)) + v.SetBytes(scope.Memory.GetPtr(offset, 32)) return nil, nil } -func opMstore(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opMstore(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { // pop value of the stack - mStart, val := callContext.stack.pop(), callContext.stack.pop() - callContext.memory.Set32(mStart.Uint64(), &val) + mStart, val := scope.Stack.pop(), scope.Stack.pop() + scope.Memory.Set32(mStart.Uint64(), &val) return nil, nil } -func opMstore8(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - off, val := callContext.stack.pop(), callContext.stack.pop() - callContext.memory.store[off.Uint64()] = byte(val.Uint64()) +func opMstore8(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + off, val := scope.Stack.pop(), scope.Stack.pop() + scope.Memory.store[off.Uint64()] = byte(val.Uint64()) return nil, nil } -func opSload(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - loc := callContext.stack.peek() +func opSload(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + loc := scope.Stack.peek() hash := common.Hash(loc.Bytes32()) - val := interpreter.evm.StateDB.GetState(callContext.contract.Address(), hash) + val := interpreter.evm.StateDB.GetState(scope.Contract.Address(), hash) loc.SetBytes(val.Bytes()) return nil, nil } -func opSstore(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opSstore(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { if interpreter.readOnly { return nil, ErrWriteProtection } - loc := callContext.stack.pop() - val := callContext.stack.pop() - interpreter.evm.StateDB.SetState(callContext.contract.Address(), + loc := scope.Stack.pop() + val := scope.Stack.pop() + interpreter.evm.StateDB.SetState(scope.Contract.Address(), common.Hash(loc.Bytes32()), common.Hash(val.Bytes32())) return nil, nil } -func opJump(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opJump(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { if atomic.LoadInt32(&interpreter.evm.abort) != 0 { return nil, errStopToken } - pos := callContext.stack.pop() - if !callContext.contract.validJumpdest(&pos) { + pos := scope.Stack.pop() + if !scope.Contract.validJumpdest(&pos) { return nil, ErrInvalidJump } *pc = pos.Uint64() - 1 // pc will be increased by the interpreter loop return nil, nil } -func opJumpi(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opJumpi(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { if atomic.LoadInt32(&interpreter.evm.abort) != 0 { return nil, errStopToken } - pos, cond := callContext.stack.pop(), callContext.stack.pop() + pos, cond := scope.Stack.pop(), scope.Stack.pop() if !cond.IsZero() { - if !callContext.contract.validJumpdest(&pos) { + if !scope.Contract.validJumpdest(&pos) { return nil, ErrInvalidJump } *pc = pos.Uint64() - 1 // pc will be increased by the interpreter loop @@ -563,34 +570,34 @@ func opJumpi(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]b return nil, nil } -func opJumpdest(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opJumpdest(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { return nil, nil } -func opPc(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - callContext.stack.push(new(uint256.Int).SetUint64(*pc)) +func opPc(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + scope.Stack.push(new(uint256.Int).SetUint64(*pc)) return nil, nil } -func opMsize(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - callContext.stack.push(new(uint256.Int).SetUint64(uint64(callContext.memory.Len()))) +func opMsize(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + scope.Stack.push(new(uint256.Int).SetUint64(uint64(scope.Memory.Len()))) return nil, nil } -func opGas(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - callContext.stack.push(new(uint256.Int).SetUint64(callContext.contract.Gas)) +func opGas(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + scope.Stack.push(new(uint256.Int).SetUint64(scope.Contract.Gas)) return nil, nil } -func opCreate(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opCreate(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { if interpreter.readOnly { return nil, ErrWriteProtection } var ( - value = callContext.stack.pop() - offset, size = callContext.stack.pop(), callContext.stack.pop() - input = callContext.memory.GetCopy(int64(offset.Uint64()), int64(size.Uint64())) - gas = callContext.contract.Gas + value = scope.Stack.pop() + offset, size = scope.Stack.pop(), scope.Stack.pop() + input = scope.Memory.GetCopy(int64(offset.Uint64()), int64(size.Uint64())) + gas = scope.Contract.Gas ) if interpreter.evm.chainRules.IsEIP150 { gas -= gas / 64 @@ -598,8 +605,8 @@ func opCreate(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([] // reuse size int for stackvalue stackvalue := size - callContext.contract.UseGas(gas) - res, addr, returnGas, suberr := interpreter.evm.Create(callContext.contract, input, gas, value.ToBig()) + scope.Contract.UseGas(gas) + res, addr, returnGas, suberr := interpreter.evm.Create(scope.Contract, input, gas, value.ToBig()) // Push item on the stack based on the returned error. If the ruleset is // homestead we must check for CodeStoreOutOfGasError (homestead only // rule) and treat as an error, if the ruleset is frontier we must @@ -611,8 +618,8 @@ func opCreate(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([] } else { stackvalue.SetBytes(addr.Bytes()) } - callContext.stack.push(&stackvalue) - callContext.contract.Gas += returnGas + scope.Stack.push(&stackvalue) + scope.Contract.Gas += returnGas if suberr == ErrExecutionReverted { interpreter.returnData = res // set REVERT data to return data buffer @@ -622,24 +629,24 @@ func opCreate(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([] return nil, nil } -func opCreate2(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opCreate2(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { if interpreter.readOnly { return nil, ErrWriteProtection } var ( - endowment = callContext.stack.pop() - offset, size = callContext.stack.pop(), callContext.stack.pop() - salt = callContext.stack.pop() - input = callContext.memory.GetCopy(int64(offset.Uint64()), int64(size.Uint64())) - gas = callContext.contract.Gas + endowment = scope.Stack.pop() + offset, size = scope.Stack.pop(), scope.Stack.pop() + salt = scope.Stack.pop() + input = scope.Memory.GetCopy(int64(offset.Uint64()), int64(size.Uint64())) + gas = scope.Contract.Gas ) // Apply EIP150 gas -= gas / 64 - callContext.contract.UseGas(gas) + scope.Contract.UseGas(gas) // reuse size int for stackvalue stackvalue := size - res, addr, returnGas, suberr := interpreter.evm.Create2(callContext.contract, input, gas, + res, addr, returnGas, suberr := interpreter.evm.Create2(scope.Contract, input, gas, endowment.ToBig(), salt.ToBig()) // Push item on the stack based on the returned error. if suberr != nil { @@ -647,8 +654,8 @@ func opCreate2(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([ } else { stackvalue.SetBytes(addr.Bytes()) } - callContext.stack.push(&stackvalue) - callContext.contract.Gas += returnGas + scope.Stack.push(&stackvalue) + scope.Contract.Gas += returnGas if suberr == ErrExecutionReverted { interpreter.returnData = res // set REVERT data to return data buffer @@ -658,8 +665,8 @@ func opCreate2(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([ return nil, nil } -func opCall(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - stack := callContext.stack +func opCall(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + stack := scope.Stack // Pop gas. The actual gas in interpreter.evm.callGasTemp. // We can use this as a temporary value temp := stack.pop() @@ -668,7 +675,7 @@ func opCall(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]by addr, value, inOffset, inSize, retOffset, retSize := stack.pop(), stack.pop(), stack.pop(), stack.pop(), stack.pop(), stack.pop() toAddr := common.Address(addr.Bytes20()) // Get the arguments from the memory. - args := callContext.memory.GetPtr(int64(inOffset.Uint64()), int64(inSize.Uint64())) + args := scope.Memory.GetPtr(int64(inOffset.Uint64()), int64(inSize.Uint64())) if interpreter.readOnly && !value.IsZero() { return nil, ErrWriteProtection @@ -676,7 +683,7 @@ func opCall(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]by if !value.IsZero() { gas += params.CallStipend } - ret, returnGas, err := interpreter.evm.Call(callContext.contract, toAddr, args, gas, value.ToBig()) + ret, returnGas, err := interpreter.evm.Call(scope.Contract, toAddr, args, gas, value.ToBig()) if err != nil { temp.Clear() } else { @@ -685,17 +692,17 @@ func opCall(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]by stack.push(&temp) if err == nil || err == ErrExecutionReverted { ret = common.CopyBytes(ret) - callContext.memory.Set(retOffset.Uint64(), retSize.Uint64(), ret) + scope.Memory.Set(retOffset.Uint64(), retSize.Uint64(), ret) } - callContext.contract.Gas += returnGas + scope.Contract.Gas += returnGas interpreter.returnData = ret return ret, nil } -func opCallCode(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opCallCode(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { // Pop gas. The actual gas is in interpreter.evm.callGasTemp. - stack := callContext.stack + stack := scope.Stack // We use it as a temporary value temp := stack.pop() gas := interpreter.evm.callGasTemp @@ -703,12 +710,12 @@ func opCallCode(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ( addr, value, inOffset, inSize, retOffset, retSize := stack.pop(), stack.pop(), stack.pop(), stack.pop(), stack.pop(), stack.pop() toAddr := common.Address(addr.Bytes20()) // Get arguments from the memory. - args := callContext.memory.GetPtr(int64(inOffset.Uint64()), int64(inSize.Uint64())) + args := scope.Memory.GetPtr(int64(inOffset.Uint64()), int64(inSize.Uint64())) if !value.IsZero() { gas += params.CallStipend } - ret, returnGas, err := interpreter.evm.CallCode(callContext.contract, toAddr, args, gas, value.ToBig()) + ret, returnGas, err := interpreter.evm.CallCode(scope.Contract, toAddr, args, gas, value.ToBig()) if err != nil { temp.Clear() } else { @@ -717,16 +724,16 @@ func opCallCode(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ( stack.push(&temp) if err == nil || err == ErrExecutionReverted { ret = common.CopyBytes(ret) - callContext.memory.Set(retOffset.Uint64(), retSize.Uint64(), ret) + scope.Memory.Set(retOffset.Uint64(), retSize.Uint64(), ret) } - callContext.contract.Gas += returnGas + scope.Contract.Gas += returnGas interpreter.returnData = ret return ret, nil } -func opDelegateCall(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - stack := callContext.stack +func opDelegateCall(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + stack := scope.Stack // Pop gas. The actual gas is in interpreter.evm.callGasTemp. // We use it as a temporary value temp := stack.pop() @@ -735,9 +742,9 @@ func opDelegateCall(pc *uint64, interpreter *EVMInterpreter, callContext *callCt addr, inOffset, inSize, retOffset, retSize := stack.pop(), stack.pop(), stack.pop(), stack.pop(), stack.pop() toAddr := common.Address(addr.Bytes20()) // Get arguments from the memory. - args := callContext.memory.GetPtr(int64(inOffset.Uint64()), int64(inSize.Uint64())) + args := scope.Memory.GetPtr(int64(inOffset.Uint64()), int64(inSize.Uint64())) - ret, returnGas, err := interpreter.evm.DelegateCall(callContext.contract, toAddr, args, gas) + ret, returnGas, err := interpreter.evm.DelegateCall(scope.Contract, toAddr, args, gas) if err != nil { temp.Clear() } else { @@ -746,17 +753,17 @@ func opDelegateCall(pc *uint64, interpreter *EVMInterpreter, callContext *callCt stack.push(&temp) if err == nil || err == ErrExecutionReverted { ret = common.CopyBytes(ret) - callContext.memory.Set(retOffset.Uint64(), retSize.Uint64(), ret) + scope.Memory.Set(retOffset.Uint64(), retSize.Uint64(), ret) } - callContext.contract.Gas += returnGas + scope.Contract.Gas += returnGas interpreter.returnData = ret return ret, nil } -func opStaticCall(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opStaticCall(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { // Pop gas. The actual gas is in interpreter.evm.callGasTemp. - stack := callContext.stack + stack := scope.Stack // We use it as a temporary value temp := stack.pop() gas := interpreter.evm.callGasTemp @@ -764,9 +771,9 @@ func opStaticCall(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) addr, inOffset, inSize, retOffset, retSize := stack.pop(), stack.pop(), stack.pop(), stack.pop(), stack.pop() toAddr := common.Address(addr.Bytes20()) // Get arguments from the memory. - args := callContext.memory.GetPtr(int64(inOffset.Uint64()), int64(inSize.Uint64())) + args := scope.Memory.GetPtr(int64(inOffset.Uint64()), int64(inSize.Uint64())) - ret, returnGas, err := interpreter.evm.StaticCall(callContext.contract, toAddr, args, gas) + ret, returnGas, err := interpreter.evm.StaticCall(scope.Contract, toAddr, args, gas) if err != nil { temp.Clear() } else { @@ -775,45 +782,45 @@ func opStaticCall(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) stack.push(&temp) if err == nil || err == ErrExecutionReverted { ret = common.CopyBytes(ret) - callContext.memory.Set(retOffset.Uint64(), retSize.Uint64(), ret) + scope.Memory.Set(retOffset.Uint64(), retSize.Uint64(), ret) } - callContext.contract.Gas += returnGas + scope.Contract.Gas += returnGas interpreter.returnData = ret return ret, nil } -func opReturn(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - offset, size := callContext.stack.pop(), callContext.stack.pop() - ret := callContext.memory.GetPtr(int64(offset.Uint64()), int64(size.Uint64())) +func opReturn(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + offset, size := scope.Stack.pop(), scope.Stack.pop() + ret := scope.Memory.GetPtr(int64(offset.Uint64()), int64(size.Uint64())) return ret, errStopToken } -func opRevert(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - offset, size := callContext.stack.pop(), callContext.stack.pop() - ret := callContext.memory.GetPtr(int64(offset.Uint64()), int64(size.Uint64())) +func opRevert(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + offset, size := scope.Stack.pop(), scope.Stack.pop() + ret := scope.Memory.GetPtr(int64(offset.Uint64()), int64(size.Uint64())) interpreter.returnData = ret return ret, ErrExecutionReverted } -func opUndefined(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - return nil, &ErrInvalidOpCode{opcode: OpCode(callContext.contract.Code[*pc])} +func opUndefined(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + return nil, &ErrInvalidOpCode{opcode: OpCode(scope.Contract.Code[*pc])} } -func opStop(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opStop(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { return nil, errStopToken } -func opSelfdestruct(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opSelfdestruct(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { if interpreter.readOnly { return nil, ErrWriteProtection } - beneficiary := callContext.stack.pop() - balance := interpreter.evm.StateDB.GetBalance(callContext.contract.Address()) + beneficiary := scope.Stack.pop() + balance := interpreter.evm.StateDB.GetBalance(scope.Contract.Address()) interpreter.evm.StateDB.AddBalance(common.Address(beneficiary.Bytes20()), balance) - interpreter.evm.StateDB.Suicide(callContext.contract.Address()) + interpreter.evm.StateDB.Suicide(scope.Contract.Address()) return nil, errStopToken } @@ -821,21 +828,21 @@ func opSelfdestruct(pc *uint64, interpreter *EVMInterpreter, callContext *callCt // make log instruction function func makeLog(size int) executionFunc { - return func(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { + return func(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { if interpreter.readOnly { return nil, ErrWriteProtection } topics := make([]common.Hash, size) - stack := callContext.stack + stack := scope.Stack mStart, mSize := stack.pop(), stack.pop() for i := 0; i < size; i++ { addr := stack.pop() topics[i] = common.Hash(addr.Bytes32()) } - d := callContext.memory.GetCopy(int64(mStart.Uint64()), int64(mSize.Uint64())) + d := scope.Memory.GetCopy(int64(mStart.Uint64()), int64(mSize.Uint64())) interpreter.evm.StateDB.AddLog(&types.Log{ - Address: callContext.contract.Address(), + Address: scope.Contract.Address(), Topics: topics, Data: d, // This is a non-consensus field, but assigned here because @@ -848,24 +855,24 @@ func makeLog(size int) executionFunc { } // opPush1 is a specialized version of pushN -func opPush1(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opPush1(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { var ( - codeLen = uint64(len(callContext.contract.Code)) + codeLen = uint64(len(scope.Contract.Code)) integer = new(uint256.Int) ) *pc += 1 if *pc < codeLen { - callContext.stack.push(integer.SetUint64(uint64(callContext.contract.Code[*pc]))) + scope.Stack.push(integer.SetUint64(uint64(scope.Contract.Code[*pc]))) } else { - callContext.stack.push(integer.Clear()) + scope.Stack.push(integer.Clear()) } return nil, nil } // make push instruction function func makePush(size uint64, pushByteSize int) executionFunc { - return func(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - codeLen := len(callContext.contract.Code) + return func(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + codeLen := len(scope.Contract.Code) startMin := codeLen if int(*pc+1) < startMin { @@ -878,8 +885,8 @@ func makePush(size uint64, pushByteSize int) executionFunc { } integer := new(uint256.Int) - callContext.stack.push(integer.SetBytes(common.RightPadBytes( - callContext.contract.Code[startMin:endMin], pushByteSize))) + scope.Stack.push(integer.SetBytes(common.RightPadBytes( + scope.Contract.Code[startMin:endMin], pushByteSize))) *pc += size return nil, nil @@ -888,8 +895,8 @@ func makePush(size uint64, pushByteSize int) executionFunc { // make dup instruction function func makeDup(size int64) executionFunc { - return func(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - callContext.stack.dup(int(size)) + return func(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + scope.Stack.dup(int(size)) return nil, nil } } @@ -898,8 +905,8 @@ func makeDup(size int64) executionFunc { func makeSwap(size int64) executionFunc { // switch n + 1 otherwise n would be swapped with n size++ - return func(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - callContext.stack.swap(int(size)) + return func(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + scope.Stack.swap(int(size)) return nil, nil } } diff --git a/core/vm/instructions_test.go b/core/vm/instructions_test.go index 200aa14772..580a008b4e 100644 --- a/core/vm/instructions_test.go +++ b/core/vm/instructions_test.go @@ -107,7 +107,7 @@ func testTwoOperandOp(t *testing.T, tests []TwoOperandTestcase, opFn executionFu expected := new(uint256.Int).SetBytes(common.Hex2Bytes(test.Expected)) stack.push(x) stack.push(y) - opFn(&pc, evmInterpreter, &callCtx{nil, stack, nil}) + opFn(&pc, evmInterpreter, &ScopeContext{nil, stack, nil}) if len(stack.data) != 1 { t.Errorf("Expected one item on stack after %v, got %d: ", name, len(stack.data)) } @@ -222,7 +222,7 @@ func TestAddMod(t *testing.T) { stack.push(z) stack.push(y) stack.push(x) - opAddmod(&pc, evmInterpreter, &callCtx{nil, stack, nil}) + opAddmod(&pc, evmInterpreter, &ScopeContext{nil, stack, nil}) actual := stack.pop() if actual.Cmp(expected) != 0 { t.Errorf("Testcase %d, expected %x, got %x", i, expected, actual) @@ -244,7 +244,7 @@ func getResult(args []*twoOperandParams, opFn executionFunc) []TwoOperandTestcas y := new(uint256.Int).SetBytes(common.Hex2Bytes(param.y)) stack.push(x) stack.push(y) - _, err := opFn(&pc, interpreter, &callCtx{nil, stack, nil}) + _, err := opFn(&pc, interpreter, &ScopeContext{nil, stack, nil}) if err != nil { log.Fatalln(err) } @@ -308,7 +308,7 @@ func opBenchmark(bench *testing.B, op executionFunc, args ...string) { a.SetBytes(arg) stack.push(a) } - op(&pc, evmInterpreter, &callCtx{nil, stack, nil}) + op(&pc, evmInterpreter, &ScopeContext{nil, stack, nil}) stack.pop() } } @@ -535,13 +535,13 @@ func TestOpMstore(t *testing.T) { v := "abcdef00000000000000abba000000000deaf000000c0de00100000000133700" stack.push(new(uint256.Int).SetBytes(common.Hex2Bytes(v))) stack.push(new(uint256.Int)) - opMstore(&pc, evmInterpreter, &callCtx{mem, stack, nil}) + opMstore(&pc, evmInterpreter, &ScopeContext{mem, stack, nil}) if got := common.Bytes2Hex(mem.GetCopy(0, 32)); got != v { t.Fatalf("Mstore fail, got %v, expected %v", got, v) } stack.push(new(uint256.Int).SetUint64(0x1)) stack.push(new(uint256.Int)) - opMstore(&pc, evmInterpreter, &callCtx{mem, stack, nil}) + opMstore(&pc, evmInterpreter, &ScopeContext{mem, stack, nil}) if common.Bytes2Hex(mem.GetCopy(0, 32)) != "0000000000000000000000000000000000000000000000000000000000000001" { t.Fatalf("Mstore failed to overwrite previous value") } @@ -565,7 +565,7 @@ func BenchmarkOpMstore(bench *testing.B) { for i := 0; i < bench.N; i++ { stack.push(value) stack.push(memStart) - opMstore(&pc, evmInterpreter, &callCtx{mem, stack, nil}) + opMstore(&pc, evmInterpreter, &ScopeContext{mem, stack, nil}) } } @@ -585,7 +585,7 @@ func BenchmarkOpKeccak256(bench *testing.B) { for i := 0; i < bench.N; i++ { stack.push(uint256.NewInt(32)) stack.push(start) - opKeccak256(&pc, evmInterpreter, &callCtx{mem, stack, nil}) + opKeccak256(&pc, evmInterpreter, &ScopeContext{mem, stack, nil}) } } @@ -681,7 +681,7 @@ func TestRandom(t *testing.T) { pc = uint64(0) evmInterpreter = NewEVMInterpreter(env, env.vmConfig) ) - opRandom(&pc, evmInterpreter, &callCtx{nil, stack, nil}) + opRandom(&pc, evmInterpreter, &ScopeContext{nil, stack, nil}) if len(stack.data) != 1 { t.Errorf("Expected one item on stack after %v, got %d: ", tt.name, len(stack.data)) } diff --git a/core/vm/interface.go b/core/vm/interface.go index 2b209c0f09..903a957a20 100644 --- a/core/vm/interface.go +++ b/core/vm/interface.go @@ -57,6 +57,7 @@ type StateDB interface { // is defined according to EIP161 (balance = nonce = code = 0). Empty(common.Address) bool + PrepareAccessList(sender common.Address, dest *common.Address, precompiles []common.Address, txAccesses types.AccessList) AddressInAccessList(addr common.Address) bool SlotInAccessList(addr common.Address, slot common.Hash) (addressOk bool, slotOk bool) // AddAddressToAccessList adds the given address to the access list. This operation is safe to perform diff --git a/core/vm/interpreter.go b/core/vm/interpreter.go index d65ba6b021..9d16be1a66 100644 --- a/core/vm/interpreter.go +++ b/core/vm/interpreter.go @@ -60,12 +60,12 @@ type Interpreter interface { CanRun([]byte) bool } -// callCtx contains the things that are per-call, such as stack and memory, +// ScopeContext contains the things that are per-call, such as stack and memory, // but not transients like pc and gas -type callCtx struct { - memory *Memory - stack *Stack - contract *Contract +type ScopeContext struct { + Memory *Memory + Stack *Stack + Contract *Contract } // keccakState wraps sha3.state. In addition to the usual hash methods, it also supports @@ -170,10 +170,10 @@ func (in *EVMInterpreter) Run(contract *Contract, input []byte, readOnly bool) ( op OpCode // current opcode mem = NewMemory() // bound memory stack = newstack() // local stack - callContext = &callCtx{ - memory: mem, - stack: stack, - contract: contract, + callContext = &ScopeContext{ + Memory: mem, + Stack: stack, + Contract: contract, } // For optimisation reason we're using uint64 as the program counter. // It's theoretically possible to go above 2^64. The YP defines the PC @@ -192,9 +192,9 @@ func (in *EVMInterpreter) Run(contract *Contract, input []byte, readOnly bool) ( defer func() { if err != nil { if !logged { - in.cfg.Tracer.CaptureState(in.evm, pcCopy, op, gasCopy, cost, mem, stack, contract, in.evm.depth, err) + in.cfg.Tracer.CaptureState(in.evm, pcCopy, op, gasCopy, cost, callContext, in.returnData, in.evm.depth, err) } else { - in.cfg.Tracer.CaptureFault(in.evm, pcCopy, op, gasCopy, cost, mem, stack, contract, in.evm.depth, err) + in.cfg.Tracer.CaptureFault(in.evm, pcCopy, op, gasCopy, cost, callContext, in.evm.depth, err) } } }() @@ -257,7 +257,7 @@ func (in *EVMInterpreter) Run(contract *Contract, input []byte, readOnly bool) ( } if in.cfg.Debug { - in.cfg.Tracer.CaptureState(in.evm, pc, op, gasCopy, cost, mem, stack, contract, in.evm.depth, err) + in.cfg.Tracer.CaptureState(in.evm, pc, op, gasCopy, cost, callContext, in.returnData, in.evm.depth, err) logged = true } diff --git a/core/vm/jump_table.go b/core/vm/jump_table.go index d20871600d..4ae7c1e8ce 100644 --- a/core/vm/jump_table.go +++ b/core/vm/jump_table.go @@ -21,7 +21,7 @@ import ( ) type ( - executionFunc func(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) + executionFunc func(pc *uint64, interpreter *EVMInterpreter, callContext *ScopeContext) ([]byte, error) gasFunc func(*EVM, *Contract, *Stack, *Memory, uint64) (uint64, error) // last parameter is the requested memory size as a uint64 // memorySizeFunc returns the required size, and whether the operation overflowed a uint64 memorySizeFunc func(*Stack) (size uint64, overflow bool) diff --git a/core/vm/logger.go b/core/vm/logger.go index dd7d5c2a1a..a6a995eef9 100644 --- a/core/vm/logger.go +++ b/core/vm/logger.go @@ -18,10 +18,10 @@ package vm import ( "encoding/hex" - "errors" "fmt" "io" "math/big" + "strings" "time" "github.com/XinFinOrg/XDPoSChain/common" @@ -31,8 +31,6 @@ import ( "github.com/XinFinOrg/XDPoSChain/params" ) -var errTraceLimitReached = errors.New("the number of logs reached the specified limit") - // Storage represents a contract's storage. type Storage map[common.Hash]common.Hash @@ -105,10 +103,10 @@ func (s *StructLog) ErrorString() string { // Note that reference types are actual VM data structures; make copies // if you need to retain them beyond the current call. type Tracer interface { - CaptureStart(from common.Address, to common.Address, create bool, input []byte, gas uint64, value *big.Int) error - CaptureState(env *EVM, pc uint64, op OpCode, gas, cost uint64, memory *Memory, stack *Stack, contract *Contract, depth int, err error) error - CaptureFault(env *EVM, pc uint64, op OpCode, gas, cost uint64, memory *Memory, stack *Stack, contract *Contract, depth int, err error) error - CaptureEnd(output []byte, gasUsed uint64, t time.Duration, err error) error + CaptureStart(env *EVM, from common.Address, to common.Address, create bool, input []byte, gas uint64, value *big.Int) + CaptureState(env *EVM, pc uint64, op OpCode, gas, cost uint64, scope *ScopeContext, rData []byte, depth int, err error) + CaptureFault(env *EVM, pc uint64, op OpCode, gas, cost uint64, scope *ScopeContext, depth int, err error) + CaptureEnd(output []byte, gasUsed uint64, t time.Duration, err error) } // StructLogger is an EVM state logger and implements Tracer. @@ -137,17 +135,19 @@ func NewStructLogger(cfg *LogConfig) *StructLogger { } // CaptureStart implements the Tracer interface to initialize the tracing operation. -func (l *StructLogger) CaptureStart(from common.Address, to common.Address, create bool, input []byte, gas uint64, value *big.Int) error { - return nil +func (l *StructLogger) CaptureStart(env *EVM, from common.Address, to common.Address, create bool, input []byte, gas uint64, value *big.Int) { } // CaptureState logs a new structured log message and pushes it out to the environment // -// CaptureState also tracks SSTORE ops to track dirty values. -func (l *StructLogger) CaptureState(env *EVM, pc uint64, op OpCode, gas, cost uint64, memory *Memory, stack *Stack, contract *Contract, depth int, err error) error { +// CaptureState also tracks SLOAD/SSTORE ops to track storage change. +func (l *StructLogger) CaptureState(env *EVM, pc uint64, op OpCode, gas, cost uint64, scope *ScopeContext, rData []byte, depth int, err error) { + memory := scope.Memory + stack := scope.Stack + contract := scope.Contract // check if already accumulated the specified number of logs if l.cfg.Limit != 0 && l.cfg.Limit <= len(l.logs) { - return errTraceLimitReached + return } // initialise new changed values storage container for this contract @@ -188,17 +188,15 @@ func (l *StructLogger) CaptureState(env *EVM, pc uint64, op OpCode, gas, cost ui log := StructLog{pc, op, gas, cost, mem, memory.Len(), stck, storage, depth, env.StateDB.GetRefund(), err} l.logs = append(l.logs, log) - return nil } // CaptureFault implements the Tracer interface to trace an execution fault // while running an opcode. -func (l *StructLogger) CaptureFault(env *EVM, pc uint64, op OpCode, gas, cost uint64, memory *Memory, stack *Stack, contract *Contract, depth int, err error) error { - return nil +func (l *StructLogger) CaptureFault(env *EVM, pc uint64, op OpCode, gas, cost uint64, scope *ScopeContext, depth int, err error) { } // CaptureEnd is called after the call finishes to finalize the tracing. -func (l *StructLogger) CaptureEnd(output []byte, gasUsed uint64, t time.Duration, err error) error { +func (l *StructLogger) CaptureEnd(output []byte, gasUsed uint64, t time.Duration, err error) { l.output = output l.err = err if l.cfg.Debug { @@ -207,7 +205,6 @@ func (l *StructLogger) CaptureEnd(output []byte, gasUsed uint64, t time.Duration fmt.Printf(" error: %v\n", err) } } - return nil } // StructLogs returns the captured log entries. @@ -261,3 +258,65 @@ func WriteLogs(writer io.Writer, logs []*types.Log) { fmt.Fprintln(writer) } } + +type mdLogger struct { + out io.Writer + cfg *LogConfig +} + +// NewMarkdownLogger creates a logger which outputs information in a format adapted +// for human readability, and is also a valid markdown table +func NewMarkdownLogger(cfg *LogConfig, writer io.Writer) *mdLogger { + l := &mdLogger{writer, cfg} + if l.cfg == nil { + l.cfg = &LogConfig{} + } + return l +} + +func (t *mdLogger) CaptureStart(env *EVM, from common.Address, to common.Address, create bool, input []byte, gas uint64, value *big.Int) { + if !create { + fmt.Fprintf(t.out, "From: `%v`\nTo: `%v`\nData: `0x%x`\nGas: `%d`\nValue `%v` wei\n", + from.String(), to.String(), + input, gas, value) + } else { + fmt.Fprintf(t.out, "From: `%v`\nCreate at: `%v`\nData: `0x%x`\nGas: `%d`\nValue `%v` wei\n", + from.String(), to.String(), + input, gas, value) + } + + fmt.Fprintf(t.out, ` +| Pc | Op | Cost | Stack | RStack | Refund | +|-------|-------------|------|-----------|-----------|---------| +`) +} + +// CaptureState also tracks SLOAD/SSTORE ops to track storage change. +func (t *mdLogger) CaptureState(env *EVM, pc uint64, op OpCode, gas, cost uint64, scope *ScopeContext, rData []byte, depth int, err error) { + stack := scope.Stack + fmt.Fprintf(t.out, "| %4d | %10v | %3d |", pc, op, cost) + + if !t.cfg.DisableStack { + // format stack + var a []string + for _, elem := range stack.data { + a = append(a, fmt.Sprintf("%v", elem.String())) + } + b := fmt.Sprintf("[%v]", strings.Join(a, ",")) + fmt.Fprintf(t.out, "%10v |", b) + } + fmt.Fprintf(t.out, "%10v |", env.StateDB.GetRefund()) + fmt.Fprintln(t.out, "") + if err != nil { + fmt.Fprintf(t.out, "Error: %v\n", err) + } +} + +func (t *mdLogger) CaptureFault(env *EVM, pc uint64, op OpCode, gas, cost uint64, scope *ScopeContext, depth int, err error) { + fmt.Fprintf(t.out, "\nError: at pc=%d, op=%v: %v\n", pc, op, err) +} + +func (t *mdLogger) CaptureEnd(output []byte, gasUsed uint64, tm time.Duration, err error) { + fmt.Fprintf(t.out, "\nOutput: `0x%x`\nConsumed gas: `%d`\nError: `%v`\n", + output, gasUsed, err) +} diff --git a/core/vm/logger_json.go b/core/vm/logger_json.go index 4a63120a3d..64bb3a9fe9 100644 --- a/core/vm/logger_json.go +++ b/core/vm/logger_json.go @@ -41,12 +41,16 @@ func NewJSONLogger(cfg *LogConfig, writer io.Writer) *JSONLogger { return l } -func (l *JSONLogger) CaptureStart(from common.Address, to common.Address, create bool, input []byte, gas uint64, value *big.Int) error { - return nil +func (l *JSONLogger) CaptureStart(env *EVM, from, to common.Address, create bool, input []byte, gas uint64, value *big.Int) { } +func (l *JSONLogger) CaptureFault(*EVM, uint64, OpCode, uint64, uint64, *ScopeContext, int, error) {} + // CaptureState outputs state information on the logger. -func (l *JSONLogger) CaptureState(env *EVM, pc uint64, op OpCode, gas, cost uint64, memory *Memory, stack *Stack, contract *Contract, depth int, err error) error { +func (l *JSONLogger) CaptureState(env *EVM, pc uint64, op OpCode, gas, cost uint64, scope *ScopeContext, rData []byte, depth int, err error) { + memory := scope.Memory + stack := scope.Stack + log := StructLog{ Pc: pc, Op: op, @@ -69,16 +73,11 @@ func (l *JSONLogger) CaptureState(env *EVM, pc uint64, op OpCode, gas, cost uint } log.Stack = logstack } - return l.encoder.Encode(log) -} - -// CaptureFault outputs state information on the logger. -func (l *JSONLogger) CaptureFault(env *EVM, pc uint64, op OpCode, gas, cost uint64, memory *Memory, stack *Stack, contract *Contract, depth int, err error) error { - return l.CaptureState(env, pc, op, gas, cost, memory, stack, contract, depth, err) + l.encoder.Encode(log) } // CaptureEnd is triggered at end of execution. -func (l *JSONLogger) CaptureEnd(output []byte, gasUsed uint64, t time.Duration, err error) error { +func (l *JSONLogger) CaptureEnd(output []byte, gasUsed uint64, t time.Duration, err error) { type endLog struct { Output string `json:"output"` GasUsed math.HexOrDecimal64 `json:"gasUsed"` @@ -89,5 +88,5 @@ func (l *JSONLogger) CaptureEnd(output []byte, gasUsed uint64, t time.Duration, if err != nil { errMsg = err.Error() } - return l.encoder.Encode(endLog{common.Bytes2Hex(output), math.HexOrDecimal64(gasUsed), t, errMsg}) + l.encoder.Encode(endLog{common.Bytes2Hex(output), math.HexOrDecimal64(gasUsed), t, errMsg}) } diff --git a/core/vm/logger_test.go b/core/vm/logger_test.go index f6c15d9cbb..ef5fb12510 100644 --- a/core/vm/logger_test.go +++ b/core/vm/logger_test.go @@ -52,14 +52,17 @@ func TestStoreCapture(t *testing.T) { var ( env = NewEVM(Context{}, &dummyStatedb{}, nil, params.TestChainConfig, Config{}) logger = NewStructLogger(nil) - mem = NewMemory() - stack = newstack() contract = NewContract(&dummyContractRef{}, &dummyContractRef{}, new(big.Int), 0) + scope = &ScopeContext{ + Memory: NewMemory(), + Stack: newstack(), + Contract: contract, + } ) - stack.push(uint256.NewInt(1)) - stack.push(new(uint256.Int)) + scope.Stack.push(uint256.NewInt(1)) + scope.Stack.push(new(uint256.Int)) var index common.Hash - logger.CaptureState(env, 0, SSTORE, 0, 0, mem, stack, contract, 0, nil) + logger.CaptureState(env, 0, SSTORE, 0, 0, scope, nil, 0, nil) if len(logger.changedValues[contract.Address()]) == 0 { t.Fatalf("expected exactly 1 changed value on address %x, got %d", contract.Address(), len(logger.changedValues[contract.Address()])) } diff --git a/core/vm/runtime/runtime.go b/core/vm/runtime/runtime.go index 5c643ee320..a8741fa12d 100644 --- a/core/vm/runtime/runtime.go +++ b/core/vm/runtime/runtime.go @@ -107,13 +107,8 @@ func Execute(code, input []byte, cfg *Config) ([]byte, *state.StateDB, error) { vmenv = NewEnv(cfg) sender = vm.AccountRef(cfg.Origin) ) - if cfg.ChainConfig.IsEIP1559(vmenv.BlockNumber) { - cfg.State.AddAddressToAccessList(cfg.Origin) - cfg.State.AddAddressToAccessList(address) - for _, addr := range vmenv.ActivePrecompiles() { - cfg.State.AddAddressToAccessList(addr) - cfg.State.AddAddressToAccessList(addr) - } + if rules := cfg.ChainConfig.Rules(vmenv.Context.BlockNumber); rules.IsEIP1559 { + cfg.State.PrepareAccessList(cfg.Origin, &address, vm.ActivePrecompiles(rules), nil) } cfg.State.CreateAccount(address) // set the receiver's (the executing contract) code for execution. @@ -145,13 +140,9 @@ func Create(input []byte, cfg *Config) ([]byte, common.Address, uint64, error) { vmenv = NewEnv(cfg) sender = vm.AccountRef(cfg.Origin) ) - if cfg.ChainConfig.IsEIP1559(vmenv.BlockNumber) { - cfg.State.AddAddressToAccessList(cfg.Origin) - for _, addr := range vmenv.ActivePrecompiles() { - cfg.State.AddAddressToAccessList(addr) - } + if rules := cfg.ChainConfig.Rules(vmenv.Context.BlockNumber); rules.IsEIP1559 { + cfg.State.PrepareAccessList(cfg.Origin, nil, vm.ActivePrecompiles(rules), nil) } - // Call the code with the given configuration. code, address, leftOverGas, err := vmenv.Create( sender, @@ -173,12 +164,9 @@ func Call(address common.Address, input []byte, cfg *Config) ([]byte, uint64, er vmenv := NewEnv(cfg) sender := cfg.State.GetOrNewStateObject(cfg.Origin) - if cfg.ChainConfig.IsEIP1559(vmenv.BlockNumber) { - cfg.State.AddAddressToAccessList(cfg.Origin) - cfg.State.AddAddressToAccessList(address) - for _, addr := range vmenv.ActivePrecompiles() { - cfg.State.AddAddressToAccessList(addr) - } + statedb := cfg.State + if rules := cfg.ChainConfig.Rules(vmenv.Context.BlockNumber); rules.IsEIP1559 { + statedb.PrepareAccessList(cfg.Origin, &address, vm.ActivePrecompiles(rules), nil) } // Call the code with the given configuration. @@ -189,6 +177,5 @@ func Call(address common.Address, input []byte, cfg *Config) ([]byte, uint64, er cfg.GasLimit, cfg.Value, ) - return ret, leftOverGas, err } diff --git a/core/vm/runtime/runtime_test.go b/core/vm/runtime/runtime_test.go index 4c65513a4c..826b431f48 100644 --- a/core/vm/runtime/runtime_test.go +++ b/core/vm/runtime/runtime_test.go @@ -19,6 +19,7 @@ package runtime import ( "fmt" "math/big" + "os" "strings" "testing" @@ -329,32 +330,161 @@ func TestBlockhash(t *testing.T) { } } -// BenchmarkSimpleLoop test a pretty simple loop which loops -// 1M (1 048 575) times. -// Takes about 200 ms -func BenchmarkSimpleLoop(b *testing.B) { - // 0xfffff = 1048575 loops - code := []byte{ - byte(vm.PUSH3), 0x0f, 0xff, 0xff, - byte(vm.JUMPDEST), // [ count ] - byte(vm.PUSH1), 1, // [count, 1] - byte(vm.SWAP1), // [1, count] - byte(vm.SUB), // [ count -1 ] - byte(vm.DUP1), // [ count -1 , count-1] - byte(vm.PUSH1), 4, // [count-1, count -1, label] - byte(vm.JUMPI), // [ 0 ] - byte(vm.STOP), +// benchmarkNonModifyingCode benchmarks code, but if the code modifies the +// state, this should not be used, since it does not reset the state between runs. +func benchmarkNonModifyingCode(gas uint64, code []byte, name string, b *testing.B) { + cfg := new(Config) + setDefaults(cfg) + cfg.State, _ = state.New(common.Hash{}, state.NewDatabase(rawdb.NewMemoryDatabase())) + cfg.GasLimit = gas + var ( + destination = common.BytesToAddress([]byte("contract")) + vmenv = NewEnv(cfg) + sender = vm.AccountRef(cfg.Origin) + ) + cfg.State.CreateAccount(destination) + eoa := common.HexToAddress("E0") + { + cfg.State.CreateAccount(eoa) + cfg.State.SetNonce(eoa, 100) } + reverting := common.HexToAddress("EE") + { + cfg.State.CreateAccount(reverting) + cfg.State.SetCode(reverting, []byte{ + byte(vm.PUSH1), 0x00, + byte(vm.PUSH1), 0x00, + byte(vm.REVERT), + }) + } + + //cfg.State.CreateAccount(cfg.Origin) + // set the receiver's (the executing contract) code for execution. + cfg.State.SetCode(destination, code) + vmenv.Call(sender, destination, nil, gas, cfg.Value) + + b.Run(name, func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + vmenv.Call(sender, destination, nil, gas, cfg.Value) + } + }) +} + +// BenchmarkSimpleLoop test a pretty simple loop which loops until OOG +// 55 ms +func BenchmarkSimpleLoop(b *testing.B) { + + staticCallIdentity := []byte{ + byte(vm.JUMPDEST), // [ count ] + // push args for the call + byte(vm.PUSH1), 0, // out size + byte(vm.DUP1), // out offset + byte(vm.DUP1), // out insize + byte(vm.DUP1), // in offset + byte(vm.PUSH1), 0x4, // address of identity + byte(vm.GAS), // gas + byte(vm.STATICCALL), + byte(vm.POP), // pop return value + byte(vm.PUSH1), 0, // jumpdestination + byte(vm.JUMP), + } + + callIdentity := []byte{ + byte(vm.JUMPDEST), // [ count ] + // push args for the call + byte(vm.PUSH1), 0, // out size + byte(vm.DUP1), // out offset + byte(vm.DUP1), // out insize + byte(vm.DUP1), // in offset + byte(vm.DUP1), // value + byte(vm.PUSH1), 0x4, // address of identity + byte(vm.GAS), // gas + byte(vm.CALL), + byte(vm.POP), // pop return value + byte(vm.PUSH1), 0, // jumpdestination + byte(vm.JUMP), + } + + callInexistant := []byte{ + byte(vm.JUMPDEST), // [ count ] + // push args for the call + byte(vm.PUSH1), 0, // out size + byte(vm.DUP1), // out offset + byte(vm.DUP1), // out insize + byte(vm.DUP1), // in offset + byte(vm.DUP1), // value + byte(vm.PUSH1), 0xff, // address of existing contract + byte(vm.GAS), // gas + byte(vm.CALL), + byte(vm.POP), // pop return value + byte(vm.PUSH1), 0, // jumpdestination + byte(vm.JUMP), + } + + callEOA := []byte{ + byte(vm.JUMPDEST), // [ count ] + // push args for the call + byte(vm.PUSH1), 0, // out size + byte(vm.DUP1), // out offset + byte(vm.DUP1), // out insize + byte(vm.DUP1), // in offset + byte(vm.DUP1), // value + byte(vm.PUSH1), 0xE0, // address of EOA + byte(vm.GAS), // gas + byte(vm.CALL), + byte(vm.POP), // pop return value + byte(vm.PUSH1), 0, // jumpdestination + byte(vm.JUMP), + } + + loopingCode := []byte{ + byte(vm.JUMPDEST), // [ count ] + // push args for the call + byte(vm.PUSH1), 0, // out size + byte(vm.DUP1), // out offset + byte(vm.DUP1), // out insize + byte(vm.DUP1), // in offset + byte(vm.PUSH1), 0x4, // address of identity + byte(vm.GAS), // gas + + byte(vm.POP), byte(vm.POP), byte(vm.POP), byte(vm.POP), byte(vm.POP), byte(vm.POP), + byte(vm.PUSH1), 0, // jumpdestination + byte(vm.JUMP), + } + + calllRevertingContractWithInput := []byte{ + byte(vm.JUMPDEST), // + // push args for the call + byte(vm.PUSH1), 0, // out size + byte(vm.DUP1), // out offset + byte(vm.PUSH1), 0x20, // in size + byte(vm.PUSH1), 0x00, // in offset + byte(vm.PUSH1), 0x00, // value + byte(vm.PUSH1), 0xEE, // address of reverting contract + byte(vm.GAS), // gas + byte(vm.CALL), + byte(vm.POP), // pop return value + byte(vm.PUSH1), 0, // jumpdestination + byte(vm.JUMP), + } + //tracer := vm.NewJSONLogger(nil, os.Stdout) - //Execute(code, nil, &Config{ + //Execute(loopingCode, nil, &Config{ // EVMConfig: vm.Config{ // Debug: true, // Tracer: tracer, // }}) + // 100M gas + benchmarkNonModifyingCode(100000000, staticCallIdentity, "staticcall-identity-100M", b) + benchmarkNonModifyingCode(100000000, callIdentity, "call-identity-100M", b) + benchmarkNonModifyingCode(100000000, loopingCode, "loop-100M", b) + benchmarkNonModifyingCode(100000000, callInexistant, "call-nonexist-100M", b) + benchmarkNonModifyingCode(100000000, callEOA, "call-EOA-100M", b) + benchmarkNonModifyingCode(100000000, calllRevertingContractWithInput, "call-reverting-100M", b) - for i := 0; i < b.N; i++ { - Execute(code, nil, nil) - } + //benchmarkNonModifyingCode(10000000, staticCallIdentity, "staticcall-identity-10M", b) + //benchmarkNonModifyingCode(10000000, loopingCode, "loop-10M", b) } // TestEip2929Cases contains various testcases that are used for @@ -381,7 +511,8 @@ func TestEip2929Cases(t *testing.T) { code, ops) Execute(code, nil, &Config{ EVMConfig: vm.Config{ - Debug: false, + Debug: true, + Tracer: vm.NewMarkdownLogger(nil, os.Stdout), ExtraEips: []int{2929}, }, }) diff --git a/crypto/crypto.go b/crypto/crypto.go index 2213bf0c14..f8387cb733 100644 --- a/crypto/crypto.go +++ b/crypto/crypto.go @@ -17,57 +17,94 @@ package crypto import ( + "bufio" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" "encoding/hex" "errors" "fmt" + "hash" "io" + "io/ioutil" "math/big" "os" "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/common/math" - "github.com/XinFinOrg/XDPoSChain/crypto/sha3" "github.com/XinFinOrg/XDPoSChain/rlp" + "golang.org/x/crypto/sha3" ) +// SignatureLength indicates the byte length required to carry a signature with recovery id. +const SignatureLength = 64 + 1 // 64 bytes ECDSA signature + 1 byte recovery id + +// RecoveryIDOffset points to the byte offset within the signature that contains the recovery id. +const RecoveryIDOffset = 64 + +// DigestLength sets the signature digest exact length +const DigestLength = 32 + var ( - secp256k1_N, _ = new(big.Int).SetString("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141", 16) - secp256k1_halfN = new(big.Int).Div(secp256k1_N, big.NewInt(2)) + secp256k1N, _ = new(big.Int).SetString("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141", 16) + secp256k1halfN = new(big.Int).Div(secp256k1N, big.NewInt(2)) ) +var errInvalidPubkey = errors.New("invalid secp256k1 public key") + +// KeccakState wraps sha3.state. In addition to the usual hash methods, it also supports +// Read to get a variable amount of data from the hash state. Read is faster than Sum +// because it doesn't copy the internal state, but also modifies the internal state. +type KeccakState interface { + hash.Hash + Read([]byte) (int, error) +} + +// NewKeccakState creates a new KeccakState +func NewKeccakState() KeccakState { + return sha3.NewLegacyKeccak256().(KeccakState) +} + +// HashData hashes the provided data using the KeccakState and returns a 32 byte hash +func HashData(kh KeccakState, data []byte) (h common.Hash) { + kh.Reset() + kh.Write(data) + kh.Read(h[:]) + return h +} + // Keccak256 calculates and returns the Keccak256 hash of the input data. func Keccak256(data ...[]byte) []byte { - d := sha3.NewKeccak256() + b := make([]byte, 32) + d := NewKeccakState() for _, b := range data { d.Write(b) } - return d.Sum(nil) + d.Read(b) + return b } // Keccak256Hash calculates and returns the Keccak256 hash of the input data, // converting it to an internal Hash data structure. func Keccak256Hash(data ...[]byte) (h common.Hash) { - d := sha3.NewKeccak256() + d := NewKeccakState() for _, b := range data { d.Write(b) } - d.Sum(h[:0]) + d.Read(h[:]) return h } // Keccak512 calculates and returns the Keccak512 hash of the input data. func Keccak512(data ...[]byte) []byte { - d := sha3.NewKeccak512() + d := sha3.NewLegacyKeccak512() for _, b := range data { d.Write(b) } return d.Sum(nil) } -// Creates an ethereum address given the bytes and the nonce +// CreateAddress creates an ethereum address given the bytes and the nonce func CreateAddress(b common.Address, nonce uint64) common.Address { data, _ := rlp.EncodeToBytes([]interface{}{b, nonce}) return common.BytesToAddress(Keccak256(data)[12:]) @@ -104,12 +141,12 @@ func toECDSA(d []byte, strict bool) (*ecdsa.PrivateKey, error) { priv.D = new(big.Int).SetBytes(d) // The priv.D must < N - if priv.D.Cmp(secp256k1_N) >= 0 { - return nil, fmt.Errorf("invalid private key, >=N") + if priv.D.Cmp(secp256k1N) >= 0 { + return nil, errors.New("invalid private key, >=N") } // The priv.D must not be zero or negative. if priv.D.Sign() <= 0 { - return nil, fmt.Errorf("invalid private key, zero or negative") + return nil, errors.New("invalid private key, zero or negative") } priv.PublicKey.X, priv.PublicKey.Y = priv.PublicKey.Curve.ScalarBaseMult(d) @@ -127,12 +164,13 @@ func FromECDSA(priv *ecdsa.PrivateKey) []byte { return math.PaddedBigBytes(priv.D, priv.Params().BitSize/8) } -func ToECDSAPub(pub []byte) *ecdsa.PublicKey { - if len(pub) == 0 { - return nil - } +// UnmarshalPubkey converts bytes to a secp256k1 public key. +func UnmarshalPubkey(pub []byte) (*ecdsa.PublicKey, error) { x, y := elliptic.Unmarshal(S256(), pub) - return &ecdsa.PublicKey{Curve: S256(), X: x, Y: y} + if x == nil { + return nil, errInvalidPubkey + } + return &ecdsa.PublicKey{Curve: S256(), X: x, Y: y}, nil } func FromECDSAPub(pub *ecdsa.PublicKey) []byte { @@ -145,38 +183,77 @@ func FromECDSAPub(pub *ecdsa.PublicKey) []byte { // HexToECDSA parses a secp256k1 private key. func HexToECDSA(hexkey string) (*ecdsa.PrivateKey, error) { b, err := hex.DecodeString(hexkey) - if err != nil { - return nil, errors.New("invalid hex string") + if byteErr, ok := err.(hex.InvalidByteError); ok { + return nil, fmt.Errorf("invalid hex character %q in private key", byte(byteErr)) + } else if err != nil { + return nil, errors.New("invalid hex data for private key") } return ToECDSA(b) } // LoadECDSA loads a secp256k1 private key from the given file. func LoadECDSA(file string) (*ecdsa.PrivateKey, error) { - buf := make([]byte, 64) fd, err := os.Open(file) if err != nil { return nil, err } defer fd.Close() - if _, err := io.ReadFull(fd, buf); err != nil { + + r := bufio.NewReader(fd) + buf := make([]byte, 64) + n, err := readASCII(buf, r) + if err != nil { + return nil, err + } else if n != len(buf) { + return nil, errors.New("key file too short, want 64 hex characters") + } + if err := checkKeyFileEnd(r); err != nil { return nil, err } - key, err := hex.DecodeString(string(buf)) - if err != nil { - return nil, err + return HexToECDSA(string(buf)) +} + +// readASCII reads into 'buf', stopping when the buffer is full or +// when a non-printable control character is encountered. +func readASCII(buf []byte, r *bufio.Reader) (n int, err error) { + for ; n < len(buf); n++ { + buf[n], err = r.ReadByte() + switch { + case err == io.EOF || buf[n] < '!': + return n, nil + case err != nil: + return n, err + } + } + return n, nil +} + +// checkKeyFileEnd skips over additional newlines at the end of a key file. +func checkKeyFileEnd(r *bufio.Reader) error { + for i := 0; ; i++ { + b, err := r.ReadByte() + switch { + case err == io.EOF: + return nil + case err != nil: + return err + case b != '\n' && b != '\r': + return fmt.Errorf("invalid character %q at end of key file", b) + case i >= 2: + return errors.New("key file too long, want 64 hex characters") + } } - return ToECDSA(key) } // SaveECDSA saves a secp256k1 private key to the given file with // restrictive permissions. The key data is saved hex-encoded. func SaveECDSA(file string, key *ecdsa.PrivateKey) error { k := hex.EncodeToString(FromECDSA(key)) - return os.WriteFile(file, []byte(k), 0600) + return ioutil.WriteFile(file, []byte(k), 0600) } +// GenerateKey generates a new private key. func GenerateKey() (*ecdsa.PrivateKey, error) { return ecdsa.GenerateKey(S256(), rand.Reader) } @@ -189,11 +266,11 @@ func ValidateSignatureValues(v byte, r, s *big.Int, homestead bool) bool { } // reject upper range of s values (ECDSA malleability) // see discussion in secp256k1/libsecp256k1/include/secp256k1.h - if homestead && s.Cmp(secp256k1_halfN) > 0 { + if homestead && s.Cmp(secp256k1halfN) > 0 { return false } // Frontier: allow s to be in full N range - return r.Cmp(secp256k1_N) < 0 && s.Cmp(secp256k1_N) < 0 && (v == 0 || v == 1) + return r.Cmp(secp256k1N) < 0 && s.Cmp(secp256k1N) < 0 && (v == 0 || v == 1) } func PubkeyToAddress(p ecdsa.PublicKey) common.Address { diff --git a/crypto/crypto_test.go b/crypto/crypto_test.go index f1910e5c26..9e1bb2639b 100644 --- a/crypto/crypto_test.go +++ b/crypto/crypto_test.go @@ -20,11 +20,14 @@ import ( "bytes" "crypto/ecdsa" "encoding/hex" + "io/ioutil" "math/big" "os" + "reflect" "testing" "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/common/hexutil" ) var testAddrHex = "970e8128ab834e8eac17ab8e3812f010678cf791" @@ -55,6 +58,33 @@ func BenchmarkSha3(b *testing.B) { } } +func TestUnmarshalPubkey(t *testing.T) { + key, err := UnmarshalPubkey(nil) + if err != errInvalidPubkey || key != nil { + t.Fatalf("expected error, got %v, %v", err, key) + } + key, err = UnmarshalPubkey([]byte{1, 2, 3}) + if err != errInvalidPubkey || key != nil { + t.Fatalf("expected error, got %v, %v", err, key) + } + + var ( + enc, _ = hex.DecodeString("04760c4460e5336ac9bbd87952a3c7ec4363fc0a97bd31c86430806e287b437fd1b01abc6e1db640cf3106b520344af1d58b00b57823db3e1407cbc433e1b6d04d") + dec = &ecdsa.PublicKey{ + Curve: S256(), + X: hexutil.MustDecodeBig("0x760c4460e5336ac9bbd87952a3c7ec4363fc0a97bd31c86430806e287b437fd1"), + Y: hexutil.MustDecodeBig("0xb01abc6e1db640cf3106b520344af1d58b00b57823db3e1407cbc433e1b6d04d"), + } + ) + key, err = UnmarshalPubkey(enc) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !reflect.DeepEqual(key, dec) { + t.Fatal("wrong result") + } +} + func TestSign(t *testing.T) { key, _ := HexToECDSA(testPrivHex) addr := common.HexToAddress(testAddrHex) @@ -68,7 +98,7 @@ func TestSign(t *testing.T) { if err != nil { t.Errorf("ECRecover error: %s", err) } - pubKey := ToECDSAPub(recoveredPub) + pubKey, _ := UnmarshalPubkey(recoveredPub) recoveredAddr := PubkeyToAddress(*pubKey) if addr != recoveredAddr { t.Errorf("Address mismatch: want: %x have: %x", addr, recoveredAddr) @@ -109,39 +139,82 @@ func TestNewContractAddress(t *testing.T) { checkAddr(t, common.HexToAddress("c9ddedf451bc62ce88bf9292afb13df35b670699"), caddr2) } -func TestLoadECDSAFile(t *testing.T) { - keyBytes := common.FromHex(testPrivHex) - fileName0 := "test_key0" - fileName1 := "test_key1" - checkKey := func(k *ecdsa.PrivateKey) { - checkAddr(t, PubkeyToAddress(k.PublicKey), common.HexToAddress(testAddrHex)) - loadedKeyBytes := FromECDSA(k) - if !bytes.Equal(loadedKeyBytes, keyBytes) { - t.Fatalf("private key mismatch: want: %x have: %x", keyBytes, loadedKeyBytes) +func TestLoadECDSA(t *testing.T) { + tests := []struct { + input string + err string + }{ + // good + {input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"}, + {input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef\n"}, + {input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef\n\r"}, + {input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef\r\n"}, + {input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef\n\n"}, + {input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef\n\r"}, + // bad + { + input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcde", + err: "key file too short, want 64 hex characters", + }, + { + input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcde\n", + err: "key file too short, want 64 hex characters", + }, + { + input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdeX", + err: "invalid hex character 'X' in private key", + }, + { + input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdefX", + err: "invalid character 'X' at end of key file", + }, + { + input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef\n\n\n", + err: "key file too long, want 64 hex characters", + }, + } + + for _, test := range tests { + f, err := ioutil.TempFile("", "loadecdsa_test.*.txt") + if err != nil { + t.Fatal(err) + } + filename := f.Name() + f.WriteString(test.input) + f.Close() + + _, err = LoadECDSA(filename) + switch { + case err != nil && test.err == "": + t.Fatalf("unexpected error for input %q:\n %v", test.input, err) + case err != nil && err.Error() != test.err: + t.Fatalf("wrong error for input %q:\n %v", test.input, err) + case err == nil && test.err != "": + t.Fatalf("LoadECDSA did not return error for input %q", test.input) } } +} - os.WriteFile(fileName0, []byte(testPrivHex), 0600) - defer os.Remove(fileName0) - - key0, err := LoadECDSA(fileName0) +func TestSaveECDSA(t *testing.T) { + f, err := ioutil.TempFile("", "saveecdsa_test.*.txt") if err != nil { t.Fatal(err) } - checkKey(key0) + file := f.Name() + f.Close() + defer os.Remove(file) - // again, this time with SaveECDSA instead of manual save: - err = SaveECDSA(fileName1, key0) + key, _ := HexToECDSA(testPrivHex) + if err := SaveECDSA(file, key); err != nil { + t.Fatal(err) + } + loaded, err := LoadECDSA(file) if err != nil { t.Fatal(err) } - defer os.Remove(fileName1) - - key1, err := LoadECDSA(fileName1) - if err != nil { - t.Fatal(err) + if !reflect.DeepEqual(key, loaded) { + t.Fatal("loaded key not equal to saved key") } - checkKey(key1) } func TestValidateSignatureValues(t *testing.T) { @@ -153,7 +226,7 @@ func TestValidateSignatureValues(t *testing.T) { minusOne := big.NewInt(-1) one := common.Big1 zero := common.Big0 - secp256k1nMinus1 := new(big.Int).Sub(secp256k1_N, common.Big1) + secp256k1nMinus1 := new(big.Int).Sub(secp256k1N, common.Big1) // correct v,r,s check(true, 0, one, one) @@ -180,9 +253,9 @@ func TestValidateSignatureValues(t *testing.T) { // correct sig with max r,s check(true, 0, secp256k1nMinus1, secp256k1nMinus1) // correct v, combinations of incorrect r,s at upper limit - check(false, 0, secp256k1_N, secp256k1nMinus1) - check(false, 0, secp256k1nMinus1, secp256k1_N) - check(false, 0, secp256k1_N, secp256k1_N) + check(false, 0, secp256k1N, secp256k1nMinus1) + check(false, 0, secp256k1nMinus1, secp256k1N) + check(false, 0, secp256k1N, secp256k1N) // current callers ensures r,s cannot be negative, but let's test for that too // as crypto package could be used stand-alone diff --git a/crypto/ecies/ecies.go b/crypto/ecies/ecies.go index 1474181482..bb1c8d2ff4 100644 --- a/crypto/ecies/ecies.go +++ b/crypto/ecies/ecies.go @@ -35,6 +35,7 @@ import ( "crypto/elliptic" "crypto/hmac" "crypto/subtle" + "errors" "fmt" "hash" "io" @@ -42,12 +43,12 @@ import ( ) var ( - ErrImport = fmt.Errorf("ecies: failed to import key") - ErrInvalidCurve = fmt.Errorf("ecies: invalid elliptic curve") - ErrInvalidParams = fmt.Errorf("ecies: invalid ECIES parameters") - ErrInvalidPublicKey = fmt.Errorf("ecies: invalid public key") - ErrSharedKeyIsPointAtInfinity = fmt.Errorf("ecies: shared key is point at infinity") - ErrSharedKeyTooBig = fmt.Errorf("ecies: shared key params are too big") + ErrImport = errors.New("ecies: failed to import key") + ErrInvalidCurve = errors.New("ecies: invalid elliptic curve") + ErrInvalidParams = errors.New("ecies: invalid ECIES parameters") + ErrInvalidPublicKey = errors.New("ecies: invalid public key") + ErrSharedKeyIsPointAtInfinity = errors.New("ecies: shared key is point at infinity") + ErrSharedKeyTooBig = errors.New("ecies: shared key params are too big") ) // PublicKey is a representation of an elliptic curve public key. @@ -138,9 +139,9 @@ func (prv *PrivateKey) GenerateShared(pub *PublicKey, skLen, macLen int) (sk []b } var ( - ErrKeyDataTooLong = fmt.Errorf("ecies: can't supply requested key data") - ErrSharedTooLong = fmt.Errorf("ecies: shared secret is too long") - ErrInvalidMessage = fmt.Errorf("ecies: invalid message") + ErrKeyDataTooLong = errors.New("ecies: can't supply requested key data") + ErrSharedTooLong = errors.New("ecies: shared secret is too long") + ErrInvalidMessage = errors.New("ecies: invalid message") ) var ( diff --git a/crypto/ecies/ecies_test.go b/crypto/ecies/ecies_test.go index 15f5b392e5..0d3fb3bbac 100644 --- a/crypto/ecies/ecies_test.go +++ b/crypto/ecies/ecies_test.go @@ -35,6 +35,7 @@ import ( "crypto/rand" "crypto/sha256" "encoding/hex" + "errors" "fmt" "math/big" "testing" @@ -64,7 +65,7 @@ func TestKDF(t *testing.T) { } } -var ErrBadSharedKeys = fmt.Errorf("ecies: shared keys don't match") +var ErrBadSharedKeys = errors.New("ecies: shared keys don't match") // cmpParams compares a set of ECIES parameters. We assume, as per the // docs, that AES is the only supported symmetric encryption algorithm. diff --git a/crypto/ecies/params.go b/crypto/ecies/params.go index 969cc4a3bd..bd5969f1a9 100644 --- a/crypto/ecies/params.go +++ b/crypto/ecies/params.go @@ -39,7 +39,7 @@ import ( "crypto/elliptic" "crypto/sha256" "crypto/sha512" - "fmt" + "errors" "hash" ethcrypto "github.com/XinFinOrg/XDPoSChain/crypto" @@ -47,8 +47,8 @@ import ( var ( DefaultCurve = ethcrypto.S256() - ErrUnsupportedECDHAlgorithm = fmt.Errorf("ecies: unsupported ECDH algorithm") - ErrUnsupportedECIESParameters = fmt.Errorf("ecies: unsupported ECIES parameters") + ErrUnsupportedECDHAlgorithm = errors.New("ecies: unsupported ECDH algorithm") + ErrUnsupportedECIESParameters = errors.New("ecies: unsupported ECIES parameters") ) type ECIESParams struct { diff --git a/eth/api_backend.go b/eth/api_backend.go index 041e088bb2..c087ec2d38 100644 --- a/eth/api_backend.go +++ b/eth/api_backend.go @@ -245,12 +245,14 @@ func (b *EthApiBackend) GetTd(blockHash common.Hash) *big.Int { return b.eth.blockchain.GetTdByHash(blockHash) } -func (b *EthApiBackend) GetEVM(ctx context.Context, msg core.Message, state *state.StateDB, XDCxState *tradingstate.TradingStateDB, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) { - state.SetBalance(msg.From(), math.MaxBig256) +func (b *EthApiBackend) GetEVM(ctx context.Context, msg core.Message, state *state.StateDB, XDCxState *tradingstate.TradingStateDB, header *types.Header, vmConfig *vm.Config) (*vm.EVM, func() error, error) { vmError := func() error { return nil } - + if vmConfig == nil { + vmConfig = b.eth.blockchain.GetVMConfig() + } + state.SetBalance(msg.From(), math.MaxBig256) context := core.NewEVMContext(msg, header, b.eth.BlockChain(), nil) - return vm.NewEVM(context, state, XDCxState, b.eth.chainConfig, vmCfg), vmError, nil + return vm.NewEVM(context, state, XDCxState, b.eth.chainConfig, *vmConfig), vmError, nil } func (b *EthApiBackend) SubscribeRemovedLogsEvent(ch chan<- core.RemovedLogsEvent) event.Subscription { @@ -346,6 +348,10 @@ func (b *EthApiBackend) EventMux() *event.TypeMux { return b.eth.EventMux() } +func (b *EthApiBackend) RPCGasCap() uint64 { + return b.eth.config.RPCGasCap +} + func (b *EthApiBackend) AccountManager() *accounts.Manager { return b.eth.AccountManager() } @@ -375,6 +381,10 @@ func (b *EthApiBackend) GetEngine() consensus.Engine { return b.eth.engine } +func (b *EthApiBackend) StateAtBlock(ctx context.Context, block *types.Block, reexec uint64, base *state.StateDB, checkLive bool) (*state.StateDB, error) { + return b.eth.stateAtBlock(block, reexec, base, checkLive) +} + func (s *EthApiBackend) GetRewardByHash(hash common.Hash) map[string]map[string]map[string]*big.Int { header := s.eth.blockchain.GetHeaderByHash(hash) if header != nil { diff --git a/eth/api_tracer.go b/eth/api_tracer.go index 2a0f916b5f..01abc14b34 100644 --- a/eth/api_tracer.go +++ b/eth/api_tracer.go @@ -62,6 +62,20 @@ type TraceConfig struct { Reexec *uint64 } +// txTraceContext is the contextual infos about a transaction before it gets run. +type txTraceContext struct { + index int // Index of the transaction within the block + hash common.Hash // Hash of the transaction + block common.Hash // Hash of the block containing the transaction +} + +// TraceCallConfig is the config for traceCall API. It holds one more +// field to override the state for tracing. +type TraceCallConfig struct { + TraceConfig + StateOverrides *ethapi.StateOverride +} + // txTraceResult is the result of a single transaction trace. type txTraceResult struct { Result interface{} `json:"result,omitempty"` // Trace results produced by the tracer @@ -209,9 +223,14 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl } } msg, _ := tx.AsMessage(signer, balacne, task.block.Number()) + txctx := &txTraceContext{ + index: i, + hash: tx.Hash(), + block: task.block.Hash(), + } vmctx := core.NewEVMContext(msg, task.block.Header(), api.eth.blockchain, nil) - res, err := api.traceTx(ctx, msg, vmctx, task.statedb, config) + res, err := api.traceTx(ctx, msg, txctx, vmctx, task.statedb, config) if err != nil { task.results[i] = &txTraceResult{Error: err.Error()} log.Warn("Tracing failed", "hash", tx.Hash(), "block", task.block.NumberU64(), "err", err) @@ -434,6 +453,7 @@ func (api *PrivateDebugAPI) traceBlock(ctx context.Context, block *types.Block, if threads > len(txs) { threads = len(txs) } + blockHash := block.Hash() for th := 0; th < threads; th++ { pend.Add(1) go func() { @@ -449,9 +469,14 @@ func (api *PrivateDebugAPI) traceBlock(ctx context.Context, block *types.Block, } } msg, _ := txs[task.index].AsMessage(signer, balacne, block.Number()) + txctx := &txTraceContext{ + index: task.index, + hash: txs[task.index].Hash(), + block: blockHash, + } vmctx := core.NewEVMContext(msg, block.Header(), api.eth.blockchain, nil) - res, err := api.traceTx(ctx, msg, vmctx, task.statedb, config) + res, err := api.traceTx(ctx, msg, txctx, vmctx, task.statedb, config) if err != nil { results[task.index] = &txTraceResult{Error: err.Error()} continue @@ -478,6 +503,7 @@ func (api *PrivateDebugAPI) traceBlock(ctx context.Context, block *types.Block, } // Generate the next state snapshot fast without tracing msg, _ := tx.AsMessage(signer, balacne, block.Number()) + statedb.Prepare(tx.Hash(), block.Hash(), i) vmctx := core.NewEVMContext(msg, block.Header(), api.eth.blockchain, nil) vmenv := vm.NewEVM(vmctx, statedb, XDCxState, api.config, vm.Config{}) @@ -594,14 +620,72 @@ func (api *PrivateDebugAPI) TraceTransaction(ctx context.Context, hash common.Ha if err != nil { return nil, err } - // Trace the transaction and return - return api.traceTx(ctx, msg, vmctx, statedb, config) + + txctx := &txTraceContext{ + index: int(index), + hash: hash, + block: blockHash, + } + return api.traceTx(ctx, msg, txctx, vmctx, statedb, config) +} + +// TraceCall lets you trace a given eth_call. It collects the structured logs +// created during the execution of EVM if the given transaction was added on +// top of the provided block and returns them as a JSON object. +// You can provide -2 as a block number to trace on top of the pending block. +func (api *PrivateDebugAPI) TraceCall(ctx context.Context, args ethapi.CallArgs, blockNrOrHash rpc.BlockNumberOrHash, config *TraceCallConfig) (interface{}, error) { + // Try to retrieve the specified block + var ( + err error + block *types.Block + ) + if hash, ok := blockNrOrHash.Hash(); ok { + block, err = api.eth.ApiBackend.BlockByHash(ctx, hash) + } else if number, ok := blockNrOrHash.Number(); ok { + if number == rpc.PendingBlockNumber { + // We don't have access to the miner here. For tracing 'future' transactions, + // it can be done with block- and state-overrides instead, which offers + // more flexibility and stability than trying to trace on 'pending', since + // the contents of 'pending' is unstable and probably not a true representation + // of what the next actual block is likely to contain. + return nil, errors.New("tracing on top of pending is not supported") + } + block, err = api.eth.ApiBackend.BlockByNumber(ctx, number) + } else { + return nil, errors.New("invalid arguments; neither block nor hash specified") + } + if err != nil { + return nil, err + } + // try to recompute the state + reexec := defaultTraceReexec + if config != nil && config.Reexec != nil { + reexec = *config.Reexec + } + statedb, err := api.eth.ApiBackend.StateAtBlock(ctx, block, reexec, nil, true) + if err != nil { + return nil, err + } + // Apply the customized state rules if required. + if config != nil { + if err := config.StateOverrides.Apply(statedb); err != nil { + return nil, err + } + } + // Execute the trace + msg := args.ToMessage(api.eth.ApiBackend, block.Number(), api.eth.ApiBackend.RPCGasCap()) + vmctx := core.NewEVMContext(msg, block.Header(), api.eth.blockchain, nil) + var traceConfig *TraceConfig + if config != nil { + traceConfig = &config.TraceConfig + } + return api.traceTx(ctx, msg, new(txTraceContext), vmctx, statedb, traceConfig) } // traceTx configures a new tracer according to the provided configuration, and // executes the given message in the provided environment. The return value will // be tracer dependent. -func (api *PrivateDebugAPI) traceTx(ctx context.Context, message core.Message, vmctx vm.Context, statedb *state.StateDB, config *TraceConfig) (interface{}, error) { +func (api *PrivateDebugAPI) traceTx(ctx context.Context, message core.Message, txctx *txTraceContext, vmctx vm.Context, statedb *state.StateDB, config *TraceConfig) (interface{}, error) { // Assemble the structured logger or the JavaScript tracer var ( tracer vm.Tracer @@ -637,6 +721,9 @@ func (api *PrivateDebugAPI) traceTx(ctx context.Context, message core.Message, v // Run the transaction with tracing enabled. vmenv := vm.NewEVM(vmctx, statedb, nil, api.config, vm.Config{Debug: true, Tracer: tracer}) + // Call Prepare to clear out the statedb access list + statedb.Prepare(txctx.hash, txctx.block, txctx.index) + owner := common.Address{} ret, gas, failed, err, _ := core.ApplyMessage(vmenv, message, new(core.GasPool).AddGas(message.Gas()), owner) if err != nil { @@ -678,7 +765,7 @@ func (api *PrivateDebugAPI) computeTxEnv(blockHash common.Hash, txIndex int, ree // Recompute transactions up to the target index. feeCapacity := state.GetTRC21FeeCapacityFromState(statedb) if common.TIPSigning.Cmp(block.Header().Number) == 0 { - statedb.DeleteAddress(common.HexToAddress(common.BlockSigners)) + statedb.DeleteAddress(common.BlockSignersBinary) } core.InitSignerInTransactions(api.config, block.Header(), block.Transactions()) balanceUpdated := map[common.Address]*big.Int{} diff --git a/eth/backend.go b/eth/backend.go index 91c60d40d9..a572015b77 100644 --- a/eth/backend.go +++ b/eth/backend.go @@ -447,7 +447,7 @@ func (s *Ethereum) Etherbase() (eb common.Address, err error) { return etherbase, nil } } - return common.Address{}, fmt.Errorf("etherbase must be explicitly specified") + return common.Address{}, errors.New("etherbase must be explicitly specified") } // set in js console via admin interface or wrapper from cli flags @@ -475,7 +475,7 @@ func (s *Ethereum) ValidateMasternode() (bool, error) { return false, nil } } else { - return false, fmt.Errorf("Only verify masternode permission in XDPoS protocol") + return false, errors.New("Only verify masternode permission in XDPoS protocol") } return true, nil } @@ -487,7 +487,7 @@ func (s *Ethereum) ValidateMasternodeTestnet() (bool, error) { return false, err } if s.chainConfig.XDPoS == nil { - return false, fmt.Errorf("Only verify masternode permission in XDPoS protocol") + return false, errors.New("Only verify masternode permission in XDPoS protocol") } masternodes := []common.Address{ common.HexToAddress("0x3Ea0A3555f9B1dE983572BfF6444aeb1899eC58C"), diff --git a/eth/bft/bft_handler_test.go b/eth/bft/bft_handler_test.go index 5426fc5699..11d052ba3f 100644 --- a/eth/bft/bft_handler_test.go +++ b/eth/bft/bft_handler_test.go @@ -1,7 +1,7 @@ package bft import ( - "fmt" + "errors" "math/big" "sync/atomic" "testing" @@ -102,7 +102,7 @@ func TestNotBoardcastInvalidVote(t *testing.T) { targetVotes := 0 tester.bfter.consensus.verifyVote = func(chain consensus.ChainReader, vote *types.Vote) (bool, error) { - return false, fmt.Errorf("This is invalid vote") + return false, errors.New("This is invalid vote") } tester.bfter.consensus.voteHandler = func(chain consensus.ChainReader, vote *types.Vote) error { diff --git a/eth/config.go b/eth/config.go index c932a7ae44..73a95aa675 100644 --- a/eth/config.go +++ b/eth/config.go @@ -50,7 +50,8 @@ var DefaultConfig = Config{ TrieTimeout: 5 * time.Minute, GasPrice: big.NewInt(0.25 * params.Shannon), - TxPool: core.DefaultTxPoolConfig, + TxPool: core.DefaultTxPoolConfig, + RPCGasCap: 25000000, GPO: gasprice.Config{ Blocks: 20, Percentile: 60, @@ -114,6 +115,9 @@ type Config struct { // Miscellaneous options DocRoot string `toml:"-"` + + // RPCGasCap is the global gas cap for eth-call variants. + RPCGasCap uint64 } type configMarshaling struct { diff --git a/eth/downloader/downloader.go b/eth/downloader/downloader.go index 5e32e6b6d8..1a500142c8 100644 --- a/eth/downloader/downloader.go +++ b/eth/downloader/downloader.go @@ -980,22 +980,22 @@ func (d *Downloader) fetchReceipts(from uint64) error { // various callbacks to handle the slight differences between processing them. // // The instrumentation parameters: -// - errCancel: error type to return if the fetch operation is cancelled (mostly makes logging nicer) -// - deliveryCh: channel from which to retrieve downloaded data packets (merged from all concurrent peers) -// - deliver: processing callback to deliver data packets into type specific download queues (usually within `queue`) -// - wakeCh: notification channel for waking the fetcher when new tasks are available (or sync completed) -// - expire: task callback method to abort requests that took too long and return the faulty peers (traffic shaping) -// - pending: task callback for the number of requests still needing download (detect completion/non-completability) -// - inFlight: task callback for the number of in-progress requests (wait for all active downloads to finish) -// - throttle: task callback to check if the processing queue is full and activate throttling (bound memory use) -// - reserve: task callback to reserve new download tasks to a particular peer (also signals partial completions) -// - fetchHook: tester callback to notify of new tasks being initiated (allows testing the scheduling logic) -// - fetch: network callback to actually send a particular download request to a physical remote peer -// - cancel: task callback to abort an in-flight download request and allow rescheduling it (in case of lost peer) -// - capacity: network callback to retrieve the estimated type-specific bandwidth capacity of a peer (traffic shaping) -// - idle: network callback to retrieve the currently (type specific) idle peers that can be assigned tasks -// - setIdle: network callback to set a peer back to idle and update its estimated capacity (traffic shaping) -// - kind: textual label of the type being downloaded to display in log mesages +// - errCancel: error type to return if the fetch operation is cancelled (mostly makes logging nicer) +// - deliveryCh: channel from which to retrieve downloaded data packets (merged from all concurrent peers) +// - deliver: processing callback to deliver data packets into type specific download queues (usually within `queue`) +// - wakeCh: notification channel for waking the fetcher when new tasks are available (or sync completed) +// - expire: task callback method to abort requests that took too long and return the faulty peers (traffic shaping) +// - pending: task callback for the number of requests still needing download (detect completion/non-completability) +// - inFlight: task callback for the number of in-progress requests (wait for all active downloads to finish) +// - throttle: task callback to check if the processing queue is full and activate throttling (bound memory use) +// - reserve: task callback to reserve new download tasks to a particular peer (also signals partial completions) +// - fetchHook: tester callback to notify of new tasks being initiated (allows testing the scheduling logic) +// - fetch: network callback to actually send a particular download request to a physical remote peer +// - cancel: task callback to abort an in-flight download request and allow rescheduling it (in case of lost peer) +// - capacity: network callback to retrieve the estimated type-specific bandwidth capacity of a peer (traffic shaping) +// - idle: network callback to retrieve the currently (type specific) idle peers that can be assigned tasks +// - setIdle: network callback to set a peer back to idle and update its estimated capacity (traffic shaping) +// - kind: textual label of the type being downloaded to display in log mesages func (d *Downloader) fetchParts(errCancel error, deliveryCh chan dataPack, deliver func(dataPack) (int, error), wakeCh chan bool, expire func() map[string]int, pending func() int, inFlight func() bool, throttle func() bool, reserve func(*peerConnection, int) (*fetchRequest, bool, error), fetchHook func([]*types.Header), fetch func(*peerConnection, *fetchRequest) error, cancel func(*fetchRequest), capacity func(*peerConnection) int, @@ -1036,7 +1036,7 @@ func (d *Downloader) fetchParts(errCancel error, deliveryCh chan dataPack, deliv case err == nil: peer.log.Trace("Delivered new batch of data", "type", kind, "count", packet.Stats()) default: - peer.log.Trace("Failed to deliver retrieved data", "type", kind, "err", err) + peer.log.Debug("Failed to deliver retrieved data", "type", kind, "err", err) } } // Blocks assembled, try to update the progress diff --git a/eth/filters/api.go b/eth/filters/api.go index 952598046d..04256a1096 100644 --- a/eth/filters/api.go +++ b/eth/filters/api.go @@ -387,7 +387,7 @@ func (api *PublicFilterAPI) GetFilterLogs(ctx context.Context, id rpc.ID) ([]*ty api.filtersMu.Unlock() if !found || f.typ != LogsSubscription { - return nil, fmt.Errorf("filter not found") + return nil, errors.New("filter not found") } var filter *Filter @@ -446,7 +446,7 @@ func (api *PublicFilterAPI) GetFilterChanges(id rpc.ID) (interface{}, error) { } } - return []interface{}{}, fmt.Errorf("filter not found") + return []interface{}{}, errors.New("filter not found") } // returnHashes is a helper that will return an empty hash array case the given hash array is nil, @@ -485,7 +485,7 @@ func (args *FilterCriteria) UnmarshalJSON(data []byte) error { if raw.BlockHash != nil { if raw.FromBlock != nil || raw.ToBlock != nil { // BlockHash is mutually exclusive with FromBlock/ToBlock criteria - return fmt.Errorf("cannot specify both BlockHash and FromBlock/ToBlock, choose one or the other") + return errors.New("cannot specify both BlockHash and FromBlock/ToBlock, choose one or the other") } args.BlockHash = raw.BlockHash } else { @@ -558,11 +558,11 @@ func (args *FilterCriteria) UnmarshalJSON(data []byte) error { } args.Topics[i] = append(args.Topics[i], parsed) } else { - return fmt.Errorf("invalid topic(s)") + return errors.New("invalid topic(s)") } } default: - return fmt.Errorf("invalid topic(s)") + return errors.New("invalid topic(s)") } } } diff --git a/eth/filters/filter_system.go b/eth/filters/filter_system.go index 2d91b771ef..7bf7db6297 100644 --- a/eth/filters/filter_system.go +++ b/eth/filters/filter_system.go @@ -21,7 +21,6 @@ package filters import ( "context" "errors" - "fmt" "sync" "time" @@ -227,7 +226,7 @@ func (es *EventSystem) SubscribeLogs(crit ethereum.FilterQuery, logs chan []*typ if from >= 0 && to == rpc.LatestBlockNumber { return es.subscribeLogs(crit, logs), nil } - return nil, fmt.Errorf("invalid from and to block combination: from > to") + return nil, errors.New("invalid from and to block combination: from > to") } // subscribeMinedPendingLogs creates a subscription that returned mined and diff --git a/eth/gasprice/gasprice.go b/eth/gasprice/gasprice.go index 8dfaeec43a..5449279037 100644 --- a/eth/gasprice/gasprice.go +++ b/eth/gasprice/gasprice.go @@ -162,7 +162,7 @@ type transactionsByGasPrice []*types.Transaction func (t transactionsByGasPrice) Len() int { return len(t) } func (t transactionsByGasPrice) Swap(i, j int) { t[i], t[j] = t[j], t[i] } -func (t transactionsByGasPrice) Less(i, j int) bool { return t[i].GasPrice().Cmp(t[j].GasPrice()) < 0 } +func (t transactionsByGasPrice) Less(i, j int) bool { return t[i].GasPriceCmp(t[j]) < 0 } // getBlockPrices calculates the lowest transaction gas price in a given block // and sends it to the result channel. If the block is empty, price is nil. diff --git a/eth/gen_config.go b/eth/gen_config.go index e58954b7b2..1959ee559d 100644 --- a/eth/gen_config.go +++ b/eth/gen_config.go @@ -4,46 +4,52 @@ package eth import ( "math/big" + "time" "github.com/XinFinOrg/XDPoSChain/common" - "github.com/XinFinOrg/XDPoSChain/common/hexutil" "github.com/XinFinOrg/XDPoSChain/consensus/ethash" "github.com/XinFinOrg/XDPoSChain/core" "github.com/XinFinOrg/XDPoSChain/eth/downloader" "github.com/XinFinOrg/XDPoSChain/eth/gasprice" ) -var _ = (*configMarshaling)(nil) - +// MarshalTOML marshals as TOML. func (c Config) MarshalTOML() (interface{}, error) { type Config struct { Genesis *core.Genesis `toml:",omitempty"` NetworkId uint64 SyncMode downloader.SyncMode + NoPruning bool LightServ int `toml:",omitempty"` LightPeers int `toml:",omitempty"` SkipBcVersionCheck bool `toml:"-"` DatabaseHandles int `toml:"-"` DatabaseCache int + TrieCache int + TrieTimeout time.Duration Etherbase common.Address `toml:",omitempty"` MinerThreads int `toml:",omitempty"` - ExtraData hexutil.Bytes `toml:",omitempty"` + ExtraData []byte `toml:",omitempty"` GasPrice *big.Int Ethash ethash.Config TxPool core.TxPoolConfig GPO gasprice.Config EnablePreimageRecording bool DocRoot string `toml:"-"` + RPCGasCap uint64 } var enc Config enc.Genesis = c.Genesis enc.NetworkId = c.NetworkId enc.SyncMode = c.SyncMode + enc.NoPruning = c.NoPruning enc.LightServ = c.LightServ enc.LightPeers = c.LightPeers enc.SkipBcVersionCheck = c.SkipBcVersionCheck enc.DatabaseHandles = c.DatabaseHandles enc.DatabaseCache = c.DatabaseCache + enc.TrieCache = c.TrieCache + enc.TrieTimeout = c.TrieTimeout enc.Etherbase = c.Etherbase enc.MinerThreads = c.MinerThreads enc.ExtraData = c.ExtraData @@ -53,28 +59,34 @@ func (c Config) MarshalTOML() (interface{}, error) { enc.GPO = c.GPO enc.EnablePreimageRecording = c.EnablePreimageRecording enc.DocRoot = c.DocRoot + enc.RPCGasCap = c.RPCGasCap return &enc, nil } +// UnmarshalTOML unmarshals from TOML. func (c *Config) UnmarshalTOML(unmarshal func(interface{}) error) error { type Config struct { Genesis *core.Genesis `toml:",omitempty"` NetworkId *uint64 SyncMode *downloader.SyncMode + NoPruning *bool LightServ *int `toml:",omitempty"` LightPeers *int `toml:",omitempty"` SkipBcVersionCheck *bool `toml:"-"` DatabaseHandles *int `toml:"-"` DatabaseCache *int + TrieCache *int + TrieTimeout *time.Duration Etherbase *common.Address `toml:",omitempty"` MinerThreads *int `toml:",omitempty"` - ExtraData *hexutil.Bytes `toml:",omitempty"` + ExtraData []byte `toml:",omitempty"` GasPrice *big.Int Ethash *ethash.Config TxPool *core.TxPoolConfig GPO *gasprice.Config EnablePreimageRecording *bool DocRoot *string `toml:"-"` + RPCGasCap *uint64 } var dec Config if err := unmarshal(&dec); err != nil { @@ -89,6 +101,9 @@ func (c *Config) UnmarshalTOML(unmarshal func(interface{}) error) error { if dec.SyncMode != nil { c.SyncMode = *dec.SyncMode } + if dec.NoPruning != nil { + c.NoPruning = *dec.NoPruning + } if dec.LightServ != nil { c.LightServ = *dec.LightServ } @@ -104,6 +119,12 @@ func (c *Config) UnmarshalTOML(unmarshal func(interface{}) error) error { if dec.DatabaseCache != nil { c.DatabaseCache = *dec.DatabaseCache } + if dec.TrieCache != nil { + c.TrieCache = *dec.TrieCache + } + if dec.TrieTimeout != nil { + c.TrieTimeout = *dec.TrieTimeout + } if dec.Etherbase != nil { c.Etherbase = *dec.Etherbase } @@ -111,7 +132,7 @@ func (c *Config) UnmarshalTOML(unmarshal func(interface{}) error) error { c.MinerThreads = *dec.MinerThreads } if dec.ExtraData != nil { - c.ExtraData = *dec.ExtraData + c.ExtraData = dec.ExtraData } if dec.GasPrice != nil { c.GasPrice = dec.GasPrice @@ -131,5 +152,8 @@ func (c *Config) UnmarshalTOML(unmarshal func(interface{}) error) error { if dec.DocRoot != nil { c.DocRoot = *dec.DocRoot } + if dec.RPCGasCap != nil { + c.RPCGasCap = *dec.RPCGasCap + } return nil } diff --git a/eth/hooks/engine_v1_hooks.go b/eth/hooks/engine_v1_hooks.go index adb131b2ea..950e231ba0 100644 --- a/eth/hooks/engine_v1_hooks.go +++ b/eth/hooks/engine_v1_hooks.go @@ -216,7 +216,7 @@ func AttachConsensusV1Hooks(adaptor *XDPoS.XDPoS, bc *core.BlockChain, chainConf if err != nil { return nil, err } - addr := common.HexToAddress(common.MasternodeVotingSMC) + addr := common.MasternodeVotingSMCBinary validator, err := contractValidator.NewXDCValidator(addr, client) if err != nil { return nil, err @@ -318,9 +318,6 @@ func getValidators(bc *core.BlockChain, masternodes []common.Address) ([]byte, e // Get secrets and opening at epoc block checkpoint. var candidates []int64 - if err != nil { - return nil, err - } lenSigners := int64(len(masternodes)) if lenSigners > 0 { for _, addr := range masternodes { diff --git a/eth/hooks/engine_v2_hooks.go b/eth/hooks/engine_v2_hooks.go index 4047014a4d..f356af27a7 100644 --- a/eth/hooks/engine_v2_hooks.go +++ b/eth/hooks/engine_v2_hooks.go @@ -2,7 +2,6 @@ package hooks import ( "errors" - "fmt" "math/big" "time" @@ -41,7 +40,7 @@ func AttachConsensusV2Hooks(adaptor *XDPoS.XDPoS, bc *core.BlockChain, chainConf if timeout > 30 { // wait over 30s log.Error("[V2 Hook Penalty] parentHeader is nil, wait too long not writen in to disk", "parentNumber", parentNumber) - return []common.Address{}, fmt.Errorf("parentHeader is nil") + return []common.Address{}, errors.New("parentHeader is nil") } } diff --git a/eth/state_accessor.go b/eth/state_accessor.go new file mode 100644 index 0000000000..140830c5ea --- /dev/null +++ b/eth/state_accessor.go @@ -0,0 +1,137 @@ +// Copyright 2021 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package eth + +import ( + "errors" + "fmt" + "time" + + "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/core/state" + "github.com/XinFinOrg/XDPoSChain/core/types" + "github.com/XinFinOrg/XDPoSChain/core/vm" + "github.com/XinFinOrg/XDPoSChain/log" + "github.com/XinFinOrg/XDPoSChain/trie" +) + +// stateAtBlock retrieves the state database associated with a certain block. +// If no state is locally available for the given block, a number of blocks +// are attempted to be reexecuted to generate the desired state. The optional +// base layer statedb can be passed then it's regarded as the statedb of the +// parent block. +func (eth *Ethereum) stateAtBlock(block *types.Block, reexec uint64, base *state.StateDB, checkLive bool) (statedb *state.StateDB, err error) { + var ( + current *types.Block + database state.Database + report = true + origin = block.NumberU64() + ) + // Check the live database first if we have the state fully available, use that. + if checkLive { + statedb, err = eth.blockchain.StateAt(block.Root()) + if err == nil { + return statedb, nil + } + } + if base != nil { + // The optional base statedb is given, mark the start point as parent block + statedb, database, report = base, base.Database(), false + current = eth.blockchain.GetBlock(block.ParentHash(), block.NumberU64()-1) + } else { + // Otherwise try to reexec blocks until we find a state or reach our limit + current = block + + // Create an ephemeral trie.Database for isolating the live one. Otherwise + // the internal junks created by tracing will be persisted into the disk. + database = state.NewDatabaseWithCache(eth.chainDb, 16) + // If we didn't check the dirty database, do check the clean one, otherwise + // we would rewind past a persisted block (specific corner case is chain + // tracing from the genesis). + if !checkLive { + statedb, err = state.New(current.Root(), database) + if err == nil { + return statedb, nil + } + } + // Database does not have the state for the given block, try to regenerate + for i := uint64(0); i < reexec; i++ { + if current.NumberU64() == 0 { + return nil, errors.New("genesis state is missing") + } + parent := eth.blockchain.GetBlock(current.ParentHash(), current.NumberU64()-1) + if parent == nil { + return nil, fmt.Errorf("missing block %v %d", current.ParentHash(), current.NumberU64()-1) + } + current = parent + + statedb, err = state.New(current.Root(), database) + if err == nil { + break + } + } + if err != nil { + switch err.(type) { + case *trie.MissingNodeError: + return nil, fmt.Errorf("required historical state unavailable (reexec=%d)", reexec) + default: + return nil, err + } + } + } + // State was available at historical point, regenerate + var ( + start = time.Now() + logged time.Time + parent common.Hash + ) + for current.NumberU64() < origin { + // Print progress logs if long enough time elapsed + if time.Since(logged) > 8*time.Second && report { + log.Info("Regenerating historical state", "block", current.NumberU64()+1, "target", origin, "remaining", origin-current.NumberU64()-1, "elapsed", time.Since(start)) + logged = time.Now() + } + // Retrieve the next block to regenerate and process it + next := current.NumberU64() + 1 + if current = eth.blockchain.GetBlockByNumber(next); current == nil { + return nil, fmt.Errorf("block #%d not found", next) + } + _, _, _, err := eth.blockchain.Processor().Process(current, statedb, nil, vm.Config{}, nil) + if err != nil { + return nil, fmt.Errorf("processing block %d failed: %v", current.NumberU64(), err) + } + // Finalize the state so any modifications are written to the trie + root, err := statedb.Commit(eth.blockchain.Config().IsEIP158(current.Number())) + if err != nil { + return nil, err + } + statedb, err = state.New(root, database) + if err != nil { + return nil, fmt.Errorf("state reset after block %d failed: %v", current.NumberU64(), err) + } + database.TrieDB().Reference(root, common.Hash{}) + if parent != (common.Hash{}) { + database.TrieDB().Dereference(parent) + } + parent = root + } + if report { + nodes, imgs := database.TrieDB().Size() + log.Info("Historical state regenerated", "block", current.NumberU64(), "elapsed", time.Since(start), "nodes", nodes, "preimages", imgs) + } + return statedb, nil +} diff --git a/eth/tracers/tracer.go b/eth/tracers/tracer.go index 235f752bac..cc4bac5147 100644 --- a/eth/tracers/tracer.go +++ b/eth/tracers/tracer.go @@ -27,6 +27,7 @@ import ( "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/common/hexutil" + "github.com/XinFinOrg/XDPoSChain/core" "github.com/XinFinOrg/XDPoSChain/core/vm" "github.com/XinFinOrg/XDPoSChain/crypto" "github.com/XinFinOrg/XDPoSChain/log" @@ -286,8 +287,6 @@ func (cw *contractWrapper) pushObject(vm *duktape.Context) { // Tracer provides an implementation of Tracer that evaluates a Javascript // function for each VM execution step. type Tracer struct { - inited bool // Flag whether the context was already inited from the EVM - vm *duktape.Context // Javascript VM instance tracerObject int // Stack index of the tracer JavaScript object @@ -428,17 +427,17 @@ func New(code string) (*Tracer, error) { tracer.tracerObject = 0 // yeah, nice, eval can't return the index itself if !tracer.vm.GetPropString(tracer.tracerObject, "step") { - return nil, fmt.Errorf("trace object must expose a function step()") + return nil, errors.New("trace object must expose a function step()") } tracer.vm.Pop() if !tracer.vm.GetPropString(tracer.tracerObject, "fault") { - return nil, fmt.Errorf("trace object must expose a function fault()") + return nil, errors.New("trace object must expose a function fault()") } tracer.vm.Pop() if !tracer.vm.GetPropString(tracer.tracerObject, "result") { - return nil, fmt.Errorf("trace object must expose a function result()") + return nil, errors.New("trace object must expose a function result()") } tracer.vm.Pop() @@ -526,7 +525,7 @@ func wrapError(context string, err error) error { } // CaptureStart implements the Tracer interface to initialize the tracing operation. -func (jst *Tracer) CaptureStart(from common.Address, to common.Address, create bool, input []byte, gas uint64, value *big.Int) error { +func (jst *Tracer) CaptureStart(env *vm.EVM, from common.Address, to common.Address, create bool, input []byte, gas uint64, value *big.Int) { jst.ctx["type"] = "CALL" if create { jst.ctx["type"] = "CREATE" @@ -537,73 +536,73 @@ func (jst *Tracer) CaptureStart(from common.Address, to common.Address, create b jst.ctx["gas"] = gas jst.ctx["value"] = value - return nil + // Initialize the context + jst.ctx["block"] = env.Context.BlockNumber.Uint64() + jst.dbWrapper.db = env.StateDB + // Compute intrinsic gas + isHomestead := env.ChainConfig().IsHomestead(env.Context.BlockNumber) + intrinsicGas, err := core.IntrinsicGas(input, nil, jst.ctx["type"] == "CREATE", isHomestead) + if err != nil { + return + } + jst.ctx["intrinsicGas"] = intrinsicGas } // CaptureState implements the Tracer interface to trace a single step of VM execution. -func (jst *Tracer) CaptureState(env *vm.EVM, pc uint64, op vm.OpCode, gas, cost uint64, memory *vm.Memory, stack *vm.Stack, contract *vm.Contract, depth int, err error) error { - if jst.err == nil { - // Initialize the context if it wasn't done yet - if !jst.inited { - jst.ctx["block"] = env.BlockNumber.Uint64() - jst.inited = true - } - // If tracing was interrupted, set the error and stop - if atomic.LoadUint32(&jst.interrupt) > 0 { - jst.err = jst.reason - return nil - } - jst.opWrapper.op = op - jst.stackWrapper.stack = stack - jst.memoryWrapper.memory = memory - jst.contractWrapper.contract = contract - jst.dbWrapper.db = env.StateDB - - *jst.pcValue = uint(pc) - *jst.gasValue = uint(gas) - *jst.costValue = uint(cost) - *jst.depthValue = uint(depth) - *jst.refundValue = uint(env.StateDB.GetRefund()) - - jst.errorValue = nil - if err != nil { - jst.errorValue = new(string) - *jst.errorValue = err.Error() - } - _, err := jst.call("step", "log", "db") - if err != nil { - jst.err = wrapError("step", err) - } +func (jst *Tracer) CaptureState(env *vm.EVM, pc uint64, op vm.OpCode, gas, cost uint64, scope *vm.ScopeContext, rData []byte, depth int, err error) { + if jst.err != nil { + return + } + // If tracing was interrupted, set the error and stop + if atomic.LoadUint32(&jst.interrupt) > 0 { + jst.err = jst.reason + return + } + jst.opWrapper.op = op + jst.stackWrapper.stack = scope.Stack + jst.memoryWrapper.memory = scope.Memory + jst.contractWrapper.contract = scope.Contract + + *jst.pcValue = uint(pc) + *jst.gasValue = uint(gas) + *jst.costValue = uint(cost) + *jst.depthValue = uint(depth) + *jst.refundValue = uint(env.StateDB.GetRefund()) + + jst.errorValue = nil + if err != nil { + jst.errorValue = new(string) + *jst.errorValue = err.Error() + } + + if _, err := jst.call("step", "log", "db"); err != nil { + jst.err = wrapError("step", err) } - return nil } // CaptureFault implements the Tracer interface to trace an execution fault -// while running an opcode. -func (jst *Tracer) CaptureFault(env *vm.EVM, pc uint64, op vm.OpCode, gas, cost uint64, memory *vm.Memory, stack *vm.Stack, contract *vm.Contract, depth int, err error) error { - if jst.err == nil { - // Apart from the error, everything matches the previous invocation - jst.errorValue = new(string) - *jst.errorValue = err.Error() - - _, err := jst.call("fault", "log", "db") - if err != nil { - jst.err = wrapError("fault", err) - } +func (jst *Tracer) CaptureFault(env *vm.EVM, pc uint64, op vm.OpCode, gas, cost uint64, scope *vm.ScopeContext, depth int, err error) { + if jst.err != nil { + return + } + // Apart from the error, everything matches the previous invocation + jst.errorValue = new(string) + *jst.errorValue = err.Error() + + if _, err := jst.call("fault", "log", "db"); err != nil { + jst.err = wrapError("fault", err) } - return nil } // CaptureEnd is called after the call finishes to finalize the tracing. -func (jst *Tracer) CaptureEnd(output []byte, gasUsed uint64, t time.Duration, err error) error { +func (jst *Tracer) CaptureEnd(output []byte, gasUsed uint64, t time.Duration, err error) { jst.ctx["output"] = output - jst.ctx["gasUsed"] = gasUsed jst.ctx["time"] = t.String() + jst.ctx["gasUsed"] = gasUsed if err != nil { jst.ctx["error"] = err.Error() } - return nil } // GetResult calls the Javascript 'result' function and returns its value, or any accumulated error diff --git a/eth/tracers/tracer_test.go b/eth/tracers/tracer_test.go index 577fd1e576..e648973481 100644 --- a/eth/tracers/tracer_test.go +++ b/eth/tracers/tracer_test.go @@ -47,21 +47,150 @@ type dummyStatedb struct { state.StateDB } -func (*dummyStatedb) GetRefund() uint64 { return 1337 } +func (*dummyStatedb) GetRefund() uint64 { return 1337 } +func (*dummyStatedb) GetBalance(addr common.Address) *big.Int { return new(big.Int) } func runTrace(tracer *Tracer) (json.RawMessage, error) { env := vm.NewEVM(vm.Context{BlockNumber: big.NewInt(1)}, &dummyStatedb{}, nil, params.TestChainConfig, vm.Config{Debug: true, Tracer: tracer}) - + var ( + startGas uint64 = 10000 + value = big.NewInt(0) + ) contract := vm.NewContract(account{}, account{}, big.NewInt(0), 10000) contract.Code = []byte{byte(vm.PUSH1), 0x1, byte(vm.PUSH1), 0x1, 0x0} - _, err := env.Interpreter().Run(contract, []byte{}, false) + tracer.CaptureStart(env, contract.Caller(), contract.Address(), false, []byte{}, startGas, value) + ret, err := env.Interpreter().Run(contract, []byte{}, false) + tracer.CaptureEnd(ret, startGas-contract.Gas, 1, err) if err != nil { return nil, err } return tracer.GetResult() } +func TestTracer(t *testing.T) { + execTracer := func(code string) []byte { + t.Helper() + tracer, err := New(code) + if err != nil { + t.Fatal(err) + } + ret, err := runTrace(tracer) + if err != nil { + t.Fatal(err) + } + return ret + } + for i, tt := range []struct { + code string + want string + }{ + { // tests that we don't panic on bad arguments to memory access + code: "{depths: [], step: function(log) { this.depths.push(log.memory.slice(-1,-2)); }, fault: function() {}, result: function() { return this.depths; }}", + want: `[{},{},{}]`, + }, { // tests that we don't panic on bad arguments to stack peeks + code: "{depths: [], step: function(log) { this.depths.push(log.stack.peek(-1)); }, fault: function() {}, result: function() { return this.depths; }}", + want: `["0","0","0"]`, + }, { // tests that we don't panic on bad arguments to memory getUint + code: "{ depths: [], step: function(log, db) { this.depths.push(log.memory.getUint(-64));}, fault: function() {}, result: function() { return this.depths; }}", + want: `["0","0","0"]`, + }, { // tests some general counting + code: "{count: 0, step: function() { this.count += 1; }, fault: function() {}, result: function() { return this.count; }}", + want: `3`, + }, { // tests that depth is reported correctly + code: "{depths: [], step: function(log) { this.depths.push(log.stack.length()); }, fault: function() {}, result: function() { return this.depths; }}", + want: `[0,1,2]`, + }, { // tests to-string of opcodes + code: "{opcodes: [], step: function(log) { this.opcodes.push(log.op.toString()); }, fault: function() {}, result: function() { return this.opcodes; }}", + want: `["PUSH1","PUSH1","STOP"]`, + }, { // tests intrinsic gas + code: "{depths: [], step: function() {}, fault: function() {}, result: function(ctx) { return ctx.gasUsed+'.'+ctx.intrinsicGas; }}", + want: `"6.21000"`, + }, + } { + if have := execTracer(tt.code); tt.want != string(have) { + t.Errorf("testcase %d: expected return value to be %s got %s\n\tcode: %v", i, tt.want, string(have), tt.code) + } + } +} + +func TestHalt(t *testing.T) { + t.Skip("duktape doesn't support abortion") + + timeout := errors.New("stahp") + tracer, err := New("{step: function() { while(1); }, result: function() { return null; }}") + if err != nil { + t.Fatal(err) + } + + go func() { + time.Sleep(1 * time.Second) + tracer.Stop(timeout) + }() + + if _, err = runTrace(tracer); err.Error() != "stahp in server-side tracer function 'step'" { + t.Errorf("Expected timeout error, got %v", err) + } +} + +func TestHaltBetweenSteps(t *testing.T) { + tracer, err := New("{step: function() {}, fault: function() {}, result: function() { return null; }}") + if err != nil { + t.Fatal(err) + } + env := vm.NewEVM(vm.Context{BlockNumber: big.NewInt(1)}, &dummyStatedb{}, nil, params.TestChainConfig, vm.Config{Debug: true, Tracer: tracer}) + scope := &vm.ScopeContext{ + Contract: vm.NewContract(&account{}, &account{}, big.NewInt(0), 0), + } + + tracer.CaptureState(env, 0, 0, 0, 0, scope, nil, 0, nil) + timeout := errors.New("stahp") + tracer.Stop(timeout) + tracer.CaptureState(env, 0, 0, 0, 0, scope, nil, 0, nil) + + if _, err := tracer.GetResult(); err.Error() != timeout.Error() { + t.Errorf("Expected timeout error, got %v", err) + } +} + +// TestNoStepExec tests a regular value transfer (no exec), and accessing the statedb +// in 'result' +func TestNoStepExec(t *testing.T) { + runEmptyTrace := func(tracer *Tracer) (json.RawMessage, error) { + env := vm.NewEVM(vm.Context{BlockNumber: big.NewInt(1)}, &dummyStatedb{}, nil, params.TestChainConfig, vm.Config{Debug: true, Tracer: tracer}) + startGas := uint64(10000) + contract := vm.NewContract(account{}, account{}, big.NewInt(0), startGas) + tracer.CaptureStart(env, contract.Caller(), contract.Address(), false, []byte{}, startGas, big.NewInt(0)) + tracer.CaptureEnd(nil, startGas-contract.Gas, 1, nil) + return tracer.GetResult() + } + execTracer := func(code string) []byte { + t.Helper() + tracer, err := New(code) + if err != nil { + t.Fatal(err) + } + ret, err := runEmptyTrace(tracer) + if err != nil { + t.Fatal(err) + } + return ret + } + for i, tt := range []struct { + code string + want string + }{ + { // tests that we don't panic on accessing the db methods + code: "{depths: [], step: function() {}, fault: function() {}, result: function(ctx, db){ return db.getBalance(ctx.to)} }", + want: `"0"`, + }, + } { + if have := execTracer(tt.code); tt.want != string(have) { + t.Errorf("testcase %d: expected return value to be %s got %s\n\tcode: %v", i, tt.want, string(have), tt.code) + } + } +} + // TestRegressionPanicSlice tests that we don't panic on bad arguments to memory access func TestRegressionPanicSlice(t *testing.T) { tracer, err := New("{depths: [], step: function(log) { this.depths.push(log.memory.slice(-1,-2)); }, fault: function() {}, result: function() { return this.depths; }}") @@ -139,40 +268,3 @@ func TestOpcodes(t *testing.T) { t.Errorf("Expected return value to be [\"PUSH1\",\"PUSH1\",\"STOP\"], got %s", string(ret)) } } - -func TestHalt(t *testing.T) { - t.Skip("duktape doesn't support abortion") - - timeout := errors.New("stahp") - tracer, err := New("{step: function() { while(1); }, result: function() { return null; }}") - if err != nil { - t.Fatal(err) - } - - go func() { - time.Sleep(1 * time.Second) - tracer.Stop(timeout) - }() - - if _, err = runTrace(tracer); err.Error() != "stahp in server-side tracer function 'step'" { - t.Errorf("Expected timeout error, got %v", err) - } -} - -func TestHaltBetweenSteps(t *testing.T) { - tracer, err := New("{step: function() {}, fault: function() {}, result: function() { return null; }}") - if err != nil { - t.Fatal(err) - } - env := vm.NewEVM(vm.Context{BlockNumber: big.NewInt(1)}, &dummyStatedb{}, nil, params.TestChainConfig, vm.Config{Debug: true, Tracer: tracer}) - contract := vm.NewContract(&account{}, &account{}, big.NewInt(0), 0) - - tracer.CaptureState(env, 0, 0, 0, 0, nil, nil, contract, 0, nil) - timeout := errors.New("stahp") - tracer.Stop(timeout) - tracer.CaptureState(env, 0, 0, 0, 0, nil, nil, contract, 0, nil) - - if _, err := tracer.GetResult(); err.Error() != timeout.Error() { - t.Errorf("Expected timeout error, got %v", err) - } -} diff --git a/ethclient/ethclient.go b/ethclient/ethclient.go index 745f5656a6..592b7db356 100644 --- a/ethclient/ethclient.go +++ b/ethclient/ethclient.go @@ -70,6 +70,16 @@ func (ec *Client) BlockByNumber(ctx context.Context, number *big.Int) (*types.Bl return ec.getBlock(ctx, "eth_getBlockByNumber", toBlockNumArg(number), true) } +// BlockReceipts returns the receipts of a given block number or hash +func (ec *Client) BlockReceipts(ctx context.Context, blockNrOrHash rpc.BlockNumberOrHash) ([]*types.Receipt, error) { + var r []*types.Receipt + err := ec.c.CallContext(ctx, &r, "eth_getBlockReceipts", blockNrOrHash) + if err == nil && r == nil { + return nil, ethereum.NotFound + } + return r, err +} + type rpcBlock struct { Hash common.Hash `json:"hash"` Transactions []rpcTransaction `json:"transactions"` @@ -95,16 +105,16 @@ func (ec *Client) getBlock(ctx context.Context, method string, args ...interface } // Quick-verify transaction and uncle lists. This mostly helps with debugging the server. if head.UncleHash == types.EmptyUncleHash && len(body.UncleHashes) > 0 { - return nil, fmt.Errorf("server returned non-empty uncle list but block header indicates no uncles") + return nil, errors.New("server returned non-empty uncle list but block header indicates no uncles") } if head.UncleHash != types.EmptyUncleHash && len(body.UncleHashes) == 0 { - return nil, fmt.Errorf("server returned empty uncle list but block header indicates uncles") + return nil, errors.New("server returned empty uncle list but block header indicates uncles") } if head.TxHash == types.EmptyRootHash && len(body.Transactions) > 0 { - return nil, fmt.Errorf("server returned non-empty transaction list but block header indicates no transactions") + return nil, errors.New("server returned non-empty transaction list but block header indicates no transactions") } if head.TxHash != types.EmptyRootHash && len(body.Transactions) == 0 { - return nil, fmt.Errorf("server returned empty transaction list but block header indicates transactions") + return nil, errors.New("server returned empty transaction list but block header indicates transactions") } // Load uncles because they are not included in the block response. var uncles []*types.Header @@ -187,7 +197,7 @@ func (ec *Client) TransactionByHash(ctx context.Context, hash common.Hash) (tx * } else if json == nil { return nil, false, ethereum.NotFound } else if _, r, _ := json.tx.RawSignatureValues(); r == nil { - return nil, false, fmt.Errorf("server returned transaction without signature") + return nil, false, errors.New("server returned transaction without signature") } setSenderFromServer(json.tx, json.From, json.BlockHash) return json.tx, json.BlockNumber == nil, nil @@ -233,7 +243,7 @@ func (ec *Client) TransactionInBlock(ctx context.Context, blockHash common.Hash, if json == nil { return nil, ethereum.NotFound } else if _, r, _ := json.tx.RawSignatureValues(); r == nil { - return nil, fmt.Errorf("server returned transaction without signature") + return nil, errors.New("server returned transaction without signature") } } setSenderFromServer(json.tx, json.From, json.BlockHash) @@ -479,7 +489,7 @@ func (ec *Client) EstimateGas(ctx context.Context, msg ethereum.CallMsg) (uint64 // If the transaction was a contract creation use the TransactionReceipt method to get the // contract address after the transaction has been mined. func (ec *Client) SendTransaction(ctx context.Context, tx *types.Transaction) error { - data, err := rlp.EncodeToBytes(tx) + data, err := tx.MarshalBinary() if err != nil { return err } diff --git a/ethclient/signer.go b/ethclient/signer.go index d41bd379f4..1db03a9c95 100644 --- a/ethclient/signer.go +++ b/ethclient/signer.go @@ -51,9 +51,14 @@ func (s *senderFromServer) Sender(tx *types.Transaction) (common.Address, error) return s.addr, nil } +func (s *senderFromServer) ChainID() *big.Int { + panic("can't sign with senderFromServer") +} + func (s *senderFromServer) Hash(tx *types.Transaction) common.Hash { panic("can't sign with senderFromServer") } + func (s *senderFromServer) SignatureValues(tx *types.Transaction, sig []byte) (R, S, V *big.Int, err error) { panic("can't sign with senderFromServer") } diff --git a/event/feed_test.go b/event/feed_test.go index a82c103033..6c4c91b5b4 100644 --- a/event/feed_test.go +++ b/event/feed_test.go @@ -17,6 +17,7 @@ package event import ( + "errors" "fmt" "reflect" "sync" @@ -68,7 +69,7 @@ func checkPanic(want error, fn func()) (err error) { defer func() { panic := recover() if panic == nil { - err = fmt.Errorf("didn't panic") + err = errors.New("didn't panic") } else if !reflect.DeepEqual(panic, want) { err = fmt.Errorf("panicked with wrong error: got %q, want %q", panic, want) } diff --git a/interfaces.go b/interfaces.go index 88cf2fcb0c..dfd96b217a 100644 --- a/interfaces.go +++ b/interfaces.go @@ -120,6 +120,7 @@ type CallMsg struct { Value *big.Int // amount of wei sent along with the call Data []byte // input data, usually an ABI-encoded contract method invocation BalanceTokenFee *big.Int + AccessList types.AccessList // EIP-2930 access list. } // A ContractCaller provides contract calls, essentially transactions that are executed by diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go index d234802a1b..650135e067 100644 --- a/internal/ethapi/api.go +++ b/internal/ethapi/api.go @@ -393,7 +393,7 @@ func (s *PrivateAccountAPI) SendTransaction(ctx context.Context, args SendTxArgs if err != nil { return common.Hash{}, err } - return submitTransaction(ctx, s.b, signed) + return SubmitTransaction(ctx, s.b, signed) } // SignTransaction will create a transaction from the given arguments and @@ -404,19 +404,20 @@ func (s *PrivateAccountAPI) SignTransaction(ctx context.Context, args SendTxArgs // No need to obtain the noncelock mutex, since we won't be sending this // tx into the transaction pool, but right back to the user if args.Gas == nil { - return nil, fmt.Errorf("gas not specified") + return nil, errors.New("gas not specified") } if args.GasPrice == nil { - return nil, fmt.Errorf("gasPrice not specified") + return nil, errors.New("gasPrice not specified") } if args.Nonce == nil { - return nil, fmt.Errorf("nonce not specified") + return nil, errors.New("nonce not specified") } signed, err := s.signTransaction(ctx, args, passwd) if err != nil { + log.Warn("Failed transaction sign attempt", "from", args.From, "to", args.To, "value", args.Value.ToInt(), "err", err) return nil, err } - data, err := rlp.EncodeToBytes(signed) + data, err := signed.MarshalBinary() if err != nil { return nil, err } @@ -473,21 +474,19 @@ func (s *PrivateAccountAPI) Sign(ctx context.Context, data hexutil.Bytes, addr c // // https://github.com/XinFinOrg/XDPoSChain/wiki/Management-APIs#personal_ecRecover func (s *PrivateAccountAPI) EcRecover(ctx context.Context, data, sig hexutil.Bytes) (common.Address, error) { - if len(sig) != 65 { - return common.Address{}, fmt.Errorf("signature must be 65 bytes long") + if len(sig) != crypto.SignatureLength { + return common.Address{}, fmt.Errorf("signature must be %d bytes long", crypto.SignatureLength) } - if sig[64] != 27 && sig[64] != 28 { - return common.Address{}, fmt.Errorf("invalid Ethereum signature (V is not 27 or 28)") + if sig[crypto.RecoveryIDOffset] != 27 && sig[crypto.RecoveryIDOffset] != 28 { + return common.Address{}, errors.New("invalid Ethereum signature (V is not 27 or 28)") } - sig[64] -= 27 // Transform yellow paper V from 27/28 to 0/1 + sig[crypto.RecoveryIDOffset] -= 27 // Transform yellow paper V from 27/28 to 0/1 - rpk, err := crypto.Ecrecover(signHash(data), sig) + rpk, err := crypto.SigToPub(accounts.TextHash(data), sig) if err != nil { return common.Address{}, err } - pubKey := crypto.ToECDSAPub(rpk) - recoveredAddr := crypto.PubkeyToAddress(*pubKey) - return recoveredAddr, nil + return crypto.PubkeyToAddress(*rpk), nil } // SignAndSendTransaction was renamed to SendTransaction. This method is deprecated @@ -533,6 +532,49 @@ func (s *PublicBlockChainAPI) GetBalance(ctx context.Context, address common.Add return (*hexutil.Big)(state.GetBalance(address)), state.Error() } +// GetTransactionAndReceiptProof returns the Trie transaction and receipt proof of the given transaction hash. +func (s *PublicBlockChainAPI) GetTransactionAndReceiptProof(ctx context.Context, hash common.Hash) (map[string]interface{}, error) { + tx, blockHash, _, index := core.GetTransaction(s.b.ChainDb(), hash) + if tx == nil { + return nil, nil + } + block, err := s.b.GetBlock(ctx, blockHash) + if err != nil { + return nil, err + } + tx_tr := deriveTrie(block.Transactions()) + + keybuf := new(bytes.Buffer) + rlp.Encode(keybuf, uint(index)) + var tx_proof proofPairList + if err := tx_tr.Prove(keybuf.Bytes(), 0, &tx_proof); err != nil { + return nil, err + } + receipts, err := s.b.GetReceipts(ctx, blockHash) + if err != nil { + return nil, err + } + if len(receipts) <= int(index) { + return nil, nil + } + receipt_tr := deriveTrie(receipts) + var receipt_proof proofPairList + if err := receipt_tr.Prove(keybuf.Bytes(), 0, &receipt_proof); err != nil { + return nil, err + } + fields := map[string]interface{}{ + "blockHash": blockHash, + "txRoot": tx_tr.Hash(), + "receiptRoot": receipt_tr.Hash(), + "key": hexutil.Encode(keybuf.Bytes()), + "txProofKeys": tx_proof.keys, + "txProofValues": tx_proof.values, + "receiptProofKeys": receipt_proof.keys, + "receiptProofValues": receipt_proof.values, + } + return fields, nil +} + // GetBlockByNumber returns the requested block. When blockNr is -1 the chain head is returned. When fullTx is true all // transactions in the block are returned in full detail, otherwise only the transaction hash is returned. func (s *PublicBlockChainAPI) GetBlockByNumber(ctx context.Context, blockNr rpc.BlockNumber, fullTx bool) (map[string]interface{}, error) { @@ -653,6 +695,34 @@ func (s *PublicBlockChainAPI) GetStorageAt(ctx context.Context, address common.A return res[:], state.Error() } +// GetBlockReceipts returns the block receipts for the given block hash or number or tag. +func (s *PublicBlockChainAPI) GetBlockReceipts(ctx context.Context, blockNrOrHash rpc.BlockNumberOrHash) ([]map[string]interface{}, error) { + block, err := s.b.BlockByNumberOrHash(ctx, blockNrOrHash) + if block == nil || err != nil { + // When the block doesn't exist, the RPC method should return JSON null + // as per specification. + return nil, nil + } + receipts, err := s.b.GetReceipts(ctx, block.Hash()) + if err != nil { + return nil, err + } + txs := block.Transactions() + if len(txs) != len(receipts) { + return nil, fmt.Errorf("receipts length mismatch: %d vs %d", len(txs), len(receipts)) + } + + // Derive the sender. + signer := types.MakeSigner(s.b.ChainConfig(), block.Number()) + + result := make([]map[string]interface{}, len(receipts)) + for i, receipt := range receipts { + result[i] = marshalReceipt(receipt, block.Hash(), block.NumberU64(), signer, txs[i], i) + } + + return result, nil +} + // OverrideAccount indicates the overriding fields of account during the execution // of a message call. // Note, state and stateDiff can't be specified at the same time. If state is @@ -845,7 +915,7 @@ func (s *PublicBlockChainAPI) GetCandidateStatus(ctx context.Context, coinbaseAd } maxMasternodes = s.b.ChainConfig().XDPoS.V2.Config(uint64(round)).MaxMasternodes } else { - return result, fmt.Errorf("undefined XDPoS consensus engine") + return result, errors.New("undefined XDPoS consensus engine") } } else if s.b.ChainConfig().IsTIPIncreaseMasternodes(block.Number()) { maxMasternodes = common.MaxMasternodesV2 @@ -1036,7 +1106,7 @@ func (s *PublicBlockChainAPI) GetCandidates(ctx context.Context, epoch rpc.Epoch } maxMasternodes = s.b.ChainConfig().XDPoS.V2.Config(uint64(round)).MaxMasternodes } else { - return result, fmt.Errorf("undefined XDPoS consensus engine") + return result, errors.New("undefined XDPoS consensus engine") } } else if s.b.ChainConfig().IsTIPIncreaseMasternodes(block.Number()) { maxMasternodes = common.MaxMasternodesV2 @@ -1127,7 +1197,7 @@ func (s *PublicBlockChainAPI) getCandidatesFromSmartContract() ([]utils.Masterno return []utils.Masternode{}, err } - addr := common.HexToAddress(common.MasternodeVotingSMC) + addr := common.MasternodeVotingSMCBinary validator, err := contractValidator.NewXDCValidator(addr, client) if err != nil { return []utils.Masternode{}, err @@ -1157,15 +1227,74 @@ func (s *PublicBlockChainAPI) getCandidatesFromSmartContract() ([]utils.Masterno // CallArgs represents the arguments for a call. type CallArgs struct { - From common.Address `json:"from"` - To *common.Address `json:"to"` - Gas hexutil.Uint64 `json:"gas"` - GasPrice hexutil.Big `json:"gasPrice"` - Value hexutil.Big `json:"value"` - Data hexutil.Bytes `json:"data"` + From *common.Address `json:"from"` + To *common.Address `json:"to"` + Gas *hexutil.Uint64 `json:"gas"` + GasPrice *hexutil.Big `json:"gasPrice"` + Value *hexutil.Big `json:"value"` + Data *hexutil.Bytes `json:"data"` + AccessList *types.AccessList `json:"accessList"` } -func DoCall(ctx context.Context, b Backend, args CallArgs, blockNrOrHash rpc.BlockNumberOrHash, overrides *StateOverride, vmCfg vm.Config, timeout time.Duration) ([]byte, uint64, bool, error, error) { +// ToMessage converts CallArgs to the Message type used by the core evm +// TODO: set balanceTokenFee +func (args *CallArgs) ToMessage(b Backend, number *big.Int, globalGasCap uint64) types.Message { + // Set sender address or use a default if none specified + var addr common.Address + if args.From == nil || *args.From == (common.Address{}) { + if wallets := b.AccountManager().Wallets(); len(wallets) > 0 { + if accounts := wallets[0].Accounts(); len(accounts) > 0 { + addr = accounts[0].Address + } + } + } else { + addr = *args.From + } + + // Set default gas & gas price if none were set + gas := globalGasCap + if gas == 0 { + gas = uint64(math.MaxUint64 / 2) + } + if args.Gas != nil { + gas = uint64(*args.Gas) + } + if globalGasCap != 0 && globalGasCap < gas { + log.Warn("Caller gas above allowance, capping", "requested", gas, "cap", globalGasCap) + gas = globalGasCap + } + gasPrice := new(big.Int) + if args.GasPrice != nil { + gasPrice = args.GasPrice.ToInt() + } + if gasPrice.Sign() <= 0 { + gasPrice = new(big.Int).SetUint64(defaultGasPrice) + } + + value := new(big.Int) + if args.Value != nil { + value = args.Value.ToInt() + } + + var data []byte + if args.Data != nil { + data = *args.Data + } + + var accessList types.AccessList + if args.AccessList != nil { + accessList = *args.AccessList + } + + balanceTokenFee := big.NewInt(0).SetUint64(gas) + balanceTokenFee = balanceTokenFee.Mul(balanceTokenFee, gasPrice) + + // Create new call message + msg := types.NewMessage(addr, args.To, 0, value, gas, gasPrice, data, accessList, false, balanceTokenFee, number) + return msg +} + +func DoCall(ctx context.Context, b Backend, args CallArgs, blockNrOrHash rpc.BlockNumberOrHash, overrides *StateOverride, vmCfg vm.Config, timeout time.Duration, globalGasCap uint64) ([]byte, uint64, bool, error, error) { defer func(start time.Time) { log.Debug("Executing EVM call finished", "runtime", time.Since(start)) }(time.Now()) statedb, header, err := b.StateAndHeaderByNumberOrHash(ctx, blockNrOrHash) @@ -1175,27 +1304,8 @@ func DoCall(ctx context.Context, b Backend, args CallArgs, blockNrOrHash rpc.Blo if err := overrides.Apply(statedb); err != nil { return nil, 0, false, err, nil } - // Set sender address or use a default if none specified - addr := args.From - if addr == (common.Address{}) { - if wallets := b.AccountManager().Wallets(); len(wallets) > 0 { - if accounts := wallets[0].Accounts(); len(accounts) > 0 { - addr = accounts[0].Address - } - } - } - // Set default gas & gas price if none were set - gas, gasPrice := uint64(args.Gas), args.GasPrice.ToInt() - if gas == 0 { - gas = math.MaxUint64 / 2 - } - if gasPrice.Sign() == 0 { - gasPrice = new(big.Int).SetUint64(defaultGasPrice) - } - balanceTokenFee := big.NewInt(0).SetUint64(gas) - balanceTokenFee = balanceTokenFee.Mul(balanceTokenFee, gasPrice) - // Create new call message - msg := types.NewMessage(addr, args.To, 0, args.Value.ToInt(), gas, gasPrice, args.Data, false, balanceTokenFee, header.Number) + + msg := args.ToMessage(b, header.Number, globalGasCap) // Setup context so it may be cancelled the call has completed // or, in case of unmetered gas, setup a context with a timeout. @@ -1209,7 +1319,7 @@ func DoCall(ctx context.Context, b Backend, args CallArgs, blockNrOrHash rpc.Blo // this makes sure resources are cleaned up. defer cancel() - block, err := b.BlockByNumberOrHash(ctx, blockNrOrHash) + block, err := b.BlockByHash(ctx, header.Hash()) if err != nil { return nil, 0, false, err, nil } @@ -1221,8 +1331,9 @@ func DoCall(ctx context.Context, b Backend, args CallArgs, blockNrOrHash rpc.Blo if err != nil { return nil, 0, false, err, nil } + // Get a new instance of the EVM. - evm, vmError, err := b.GetEVM(ctx, msg, statedb, XDCxState, header, vmCfg) + evm, vmError, err := b.GetEVM(ctx, msg, statedb, XDCxState, header, &vmCfg) if err != nil { return nil, 0, false, err, nil } @@ -1233,8 +1344,7 @@ func DoCall(ctx context.Context, b Backend, args CallArgs, blockNrOrHash rpc.Blo evm.Cancel() }() - // Setup the gas pool (also for unmetered requests) - // and apply the message. + // Execute the message. gp := new(core.GasPool).AddGas(math.MaxUint64) owner := common.Address{} res, gas, failed, err, vmErr := core.ApplyMessage(evm, msg, gp, owner) @@ -1289,7 +1399,7 @@ func (s *PublicBlockChainAPI) Call(ctx context.Context, args CallArgs, blockNrOr latest := rpc.BlockNumberOrHashWithNumber(rpc.LatestBlockNumber) blockNrOrHash = &latest } - result, _, failed, err, vmErr := DoCall(ctx, s.b, args, *blockNrOrHash, overrides, vm.Config{}, 5*time.Second) + result, _, failed, err, vmErr := DoCall(ctx, s.b, args, *blockNrOrHash, overrides, vm.Config{}, 5*time.Second, s.b.RPCGasCap()) if err != nil { return nil, err } @@ -1301,7 +1411,7 @@ func (s *PublicBlockChainAPI) Call(ctx context.Context, args CallArgs, blockNrOr return (hexutil.Bytes)(result), vmErr } -func DoEstimateGas(ctx context.Context, b Backend, args CallArgs, blockNrOrHash rpc.BlockNumberOrHash, overrides *StateOverride) (hexutil.Uint64, error) { +func DoEstimateGas(ctx context.Context, b Backend, args CallArgs, blockNrOrHash rpc.BlockNumberOrHash, overrides *StateOverride, gasCap uint64) (hexutil.Uint64, error) { // Retrieve the base state and mutate it with any overrides state, _, err := b.StateAndHeaderByNumberOrHash(ctx, blockNrOrHash) if state == nil || err != nil { @@ -1316,23 +1426,36 @@ func DoEstimateGas(ctx context.Context, b Backend, args CallArgs, blockNrOrHash hi uint64 cap uint64 ) - if uint64(args.Gas) >= params.TxGas { - hi = uint64(args.Gas) + // Use zero address if sender unspecified. + if args.From == nil { + args.From = new(common.Address) + } + // Determine the highest gas limit can be used during the estimation. + if args.Gas != nil && uint64(*args.Gas) >= params.TxGas { + hi = uint64(*args.Gas) } else { // Retrieve the current pending block to act as the gas ceiling block, err := b.BlockByNumberOrHash(ctx, blockNrOrHash) if err != nil { return 0, err } + if block == nil { + return 0, errors.New("block not found") + } hi = block.GasLimit() } + // Recap the highest gas allowance with specified gascap. + if gasCap != 0 && hi > gasCap { + log.Warn("Caller gas above allowance, capping", "requested", hi, "cap", gasCap) + hi = gasCap + } cap = hi // Create a helper to check if a gas allowance results in an executable transaction executable := func(gas uint64) (bool, []byte, error, error) { - args.Gas = hexutil.Uint64(gas) + args.Gas = (*hexutil.Uint64)(&gas) - res, _, failed, err, vmErr := DoCall(ctx, b, args, blockNrOrHash, nil, vm.Config{}, 0) + res, _, failed, err, vmErr := DoCall(ctx, b, args, blockNrOrHash, nil, vm.Config{}, 0, gasCap) if err != nil { if errors.Is(err, vm.ErrOutOfGas) || errors.Is(err, core.ErrIntrinsicGas) { return false, nil, nil, nil // Special case, raise gas limit @@ -1350,8 +1473,8 @@ func DoEstimateGas(ctx context.Context, b Backend, args CallArgs, blockNrOrHash // directly try 21000. Returning 21000 without any execution is dangerous as // some tx field combos might bump the price up even for plain transfers (e.g. // unused access list items). Ever so slightly wasteful, but safer overall. - if len(args.Data) == 0 && args.To != nil { - if state.GetCodeSize(*args.To) == 0 { + if args.Data == nil || len(*args.Data) == 0 { + if args.To != nil && state.GetCodeSize(*args.To) == 0 { ok, _, err, _ := executable(params.TxGas) if ok && err == nil { return hexutil.Uint64(params.TxGas), nil @@ -1407,7 +1530,7 @@ func (s *PublicBlockChainAPI) EstimateGas(ctx context.Context, args CallArgs, bl if blockNrOrHash != nil { bNrOrHash = *blockNrOrHash } - return DoEstimateGas(ctx, s.b, args, bNrOrHash, overrides) + return DoEstimateGas(ctx, s.b, args, bNrOrHash, overrides, s.b.RPCGasCap()) } // ExecutionResult groups all structured logs emitted by the EVM @@ -1681,33 +1804,43 @@ func (s *PublicBlockChainAPI) rpcOutputBlockSigners(b *types.Block, ctx context. // RPCTransaction represents a transaction that will serialize to the RPC representation of a transaction type RPCTransaction struct { - BlockHash common.Hash `json:"blockHash"` - BlockNumber *hexutil.Big `json:"blockNumber"` - From common.Address `json:"from"` - Gas hexutil.Uint64 `json:"gas"` - GasPrice *hexutil.Big `json:"gasPrice"` - Hash common.Hash `json:"hash"` - Input hexutil.Bytes `json:"input"` - Nonce hexutil.Uint64 `json:"nonce"` - To *common.Address `json:"to"` - TransactionIndex hexutil.Uint `json:"transactionIndex"` - Value *hexutil.Big `json:"value"` - V *hexutil.Big `json:"v"` - R *hexutil.Big `json:"r"` - S *hexutil.Big `json:"s"` + BlockHash *common.Hash `json:"blockHash"` + BlockNumber *hexutil.Big `json:"blockNumber"` + From common.Address `json:"from"` + Gas hexutil.Uint64 `json:"gas"` + GasPrice *hexutil.Big `json:"gasPrice"` + Hash common.Hash `json:"hash"` + Input hexutil.Bytes `json:"input"` + Nonce hexutil.Uint64 `json:"nonce"` + To *common.Address `json:"to"` + TransactionIndex *hexutil.Uint64 `json:"transactionIndex"` + Value *hexutil.Big `json:"value"` + Type hexutil.Uint64 `json:"type"` + Accesses *types.AccessList `json:"accessList,omitempty"` + ChainID *hexutil.Big `json:"chainId,omitempty"` + V *hexutil.Big `json:"v"` + R *hexutil.Big `json:"r"` + S *hexutil.Big `json:"s"` } // newRPCTransaction returns a transaction that will serialize to the RPC // representation, with the given location metadata set (if available). func newRPCTransaction(tx *types.Transaction, blockHash common.Hash, blockNumber uint64, index uint64) *RPCTransaction { - var signer types.Signer = types.FrontierSigner{} + // Determine the signer. For replay-protected transactions, use the most permissive + // signer, because we assume that signers are backwards-compatible with old + // transactions. For non-protected transactions, the homestead signer signer is used + // because the return value of ChainId is zero for those transactions. + var signer types.Signer if tx.Protected() { - signer = types.NewEIP155Signer(tx.ChainId()) + signer = types.LatestSignerForChainID(tx.ChainId()) + } else { + signer = types.HomesteadSigner{} } + from, _ := types.Sender(signer, tx) v, r, s := tx.RawSignatureValues() - result := &RPCTransaction{ + Type: hexutil.Uint64(tx.Type()), From: from, Gas: hexutil.Uint64(tx.Gas()), GasPrice: (*hexutil.Big)(tx.GasPrice()), @@ -1721,9 +1854,14 @@ func newRPCTransaction(tx *types.Transaction, blockHash common.Hash, blockNumber S: (*hexutil.Big)(s), } if blockHash != (common.Hash{}) { - result.BlockHash = blockHash + result.BlockHash = &blockHash result.BlockNumber = (*hexutil.Big)(new(big.Int).SetUint64(blockNumber)) - result.TransactionIndex = hexutil.Uint(index) + result.TransactionIndex = (*hexutil.Uint64)(&index) + } + if tx.Type() == types.AccessListTxType { + al := tx.AccessList() + result.Accesses = &al + result.ChainID = (*hexutil.Big)(tx.ChainId()) } return result } @@ -1748,7 +1886,7 @@ func newRPCRawTransactionFromBlockIndex(b *types.Block, index uint64) hexutil.By if index >= uint64(len(txs)) { return nil } - blob, _ := rlp.EncodeToBytes(txs[index]) + blob, _ := txs[index].MarshalBinary() return blob } @@ -1762,10 +1900,131 @@ func newRPCTransactionFromBlockHash(b *types.Block, hash common.Hash) *RPCTransa return nil } +// accessListResult returns an optional accesslist +// Its the result of the `debug_createAccessList` RPC call. +// It contains an error if the transaction itself failed. +type accessListResult struct { + Accesslist *types.AccessList `json:"accessList"` + Error string `json:"error,omitempty"` + GasUsed hexutil.Uint64 `json:"gasUsed"` +} + +// CreateAccessList creates a EIP-2930 type AccessList for the given transaction. +// Reexec and BlockNrOrHash can be specified to create the accessList on top of a certain state. +func (s *PublicBlockChainAPI) CreateAccessList(ctx context.Context, args SendTxArgs, blockNrOrHash *rpc.BlockNumberOrHash) (*accessListResult, error) { + bNrOrHash := rpc.BlockNumberOrHashWithNumber(rpc.PendingBlockNumber) + if blockNrOrHash != nil { + bNrOrHash = *blockNrOrHash + } + acl, gasUsed, vmerr, err := AccessList(ctx, s.b, bNrOrHash, args) + if err != nil { + return nil, err + } + result := &accessListResult{Accesslist: &acl, GasUsed: hexutil.Uint64(gasUsed)} + if vmerr != nil { + result.Error = vmerr.Error() + } + return result, nil +} + +// AccessList creates an access list for the given transaction. +// If the accesslist creation fails an error is returned. +// If the transaction itself fails, an vmErr is returned. +func AccessList(ctx context.Context, b Backend, blockNrOrHash rpc.BlockNumberOrHash, args SendTxArgs) (acl types.AccessList, gasUsed uint64, vmErr error, err error) { + // Retrieve the execution context + db, header, err := b.StateAndHeaderByNumberOrHash(ctx, blockNrOrHash) + if db == nil || err != nil { + return nil, 0, nil, err + } + block, err := b.BlockByHash(ctx, header.Hash()) + if err != nil { + return nil, 0, nil, err + } + author, err := b.GetEngine().Author(block.Header()) + if err != nil { + return nil, 0, nil, err + } + XDCxState, err := b.XDCxService().GetTradingState(block, author) + if err != nil { + return nil, 0, nil, err + } + owner := common.Address{} + + // If the gas amount is not set, extract this as it will depend on access + // lists and we'll need to reestimate every time + nogas := args.Gas == nil + + // Ensure any missing fields are filled, extract the recipient and input data + if err := args.setDefaults(ctx, b); err != nil { + return nil, 0, nil, err + } + var to common.Address + if args.To != nil { + to = *args.To + } else { + to = crypto.CreateAddress(args.From, uint64(*args.Nonce)) + } + var input []byte + if args.Input != nil { + input = *args.Input + } else if args.Data != nil { + input = *args.Data + } + // Retrieve the precompiles since they don't need to be added to the access list + precompiles := vm.ActivePrecompiles(b.ChainConfig().Rules(header.Number)) + + // Create an initial tracer + prevTracer := vm.NewAccessListTracer(nil, args.From, to, precompiles) + if args.AccessList != nil { + prevTracer = vm.NewAccessListTracer(*args.AccessList, args.From, to, precompiles) + } + for { + // Retrieve the current access list to expand + accessList := prevTracer.AccessList() + log.Trace("Creating access list", "input", accessList) + + // If no gas amount was specified, each unique access list needs it's own + // gas calculation. This is quite expensive, but we need to be accurate + // and it's convered by the sender only anyway. + if nogas { + args.Gas = nil + if err := args.setDefaults(ctx, b); err != nil { + return nil, 0, nil, err // shouldn't happen, just in case + } + } + // Copy the original db so we don't modify it + statedb := db.Copy() + feeCapacity := state.GetTRC21FeeCapacityFromState(statedb) + var balanceTokenFee *big.Int + if value, ok := feeCapacity[to]; ok { + balanceTokenFee = value + } + msg := types.NewMessage(args.From, args.To, uint64(*args.Nonce), args.Value.ToInt(), uint64(*args.Gas), args.GasPrice.ToInt(), input, accessList, false, balanceTokenFee, header.Number) + + // Apply the transaction with the access list tracer + tracer := vm.NewAccessListTracer(accessList, args.From, to, precompiles) + config := vm.Config{Tracer: tracer, Debug: true} + vmenv, _, err := b.GetEVM(ctx, msg, statedb, XDCxState, header, &config) + if err != nil { + return nil, 0, nil, err + } + // TODO: determine the value of owner + _, UsedGas, _, err, vmErr := core.ApplyMessage(vmenv, msg, new(core.GasPool).AddGas(msg.Gas()), owner) + if err != nil { + return nil, 0, nil, fmt.Errorf("failed to apply transaction: %v err: %v", args.toTransaction().Hash(), err) + } + if tracer.Equal(prevTracer) { + return accessList, UsedGas, vmErr, nil + } + prevTracer = tracer + } +} + // PublicTransactionPoolAPI exposes methods for the RPC interface type PublicTransactionPoolAPI struct { b Backend nonceLock *AddrLocker + signer types.Signer } // PublicTransactionPoolAPI exposes methods for the RPC interface @@ -1776,7 +2035,10 @@ type PublicXDCXTransactionPoolAPI struct { // NewPublicTransactionPoolAPI creates a new RPC service with methods specific for the transaction pool. func NewPublicTransactionPoolAPI(b Backend, nonceLock *AddrLocker) *PublicTransactionPoolAPI { - return &PublicTransactionPoolAPI{b, nonceLock} + // The signer used by the API should always be the 'latest' known one because we expect + // signers to be backwards-compatible with old transactions. + signer := types.LatestSigner(b.ChainConfig()) + return &PublicTransactionPoolAPI{b, nonceLock, signer} } // NewPublicTransactionPoolAPI creates a new RPC service with methods specific for the transaction pool. @@ -1869,17 +2131,16 @@ func (s *PublicTransactionPoolAPI) GetTransactionByHash(ctx context.Context, has // GetRawTransactionByHash returns the bytes of the transaction for the given hash. func (s *PublicTransactionPoolAPI) GetRawTransactionByHash(ctx context.Context, hash common.Hash) (hexutil.Bytes, error) { - var tx *types.Transaction - // Retrieve a finalized transaction, or a pooled otherwise - if tx, _, _, _ = core.GetTransaction(s.b.ChainDb(), hash); tx == nil { + tx, _, _, _ := core.GetTransaction(s.b.ChainDb(), hash) + if tx == nil { if tx = s.b.GetPoolTransaction(hash); tx == nil { // Transaction not found anywhere, abort return nil, nil } } // Serialize to RLP and return - return rlp.EncodeToBytes(tx) + return tx.MarshalBinary() } // GetTransactionReceipt returns the transaction receipt for the given transaction hash. @@ -1897,10 +2158,9 @@ func (s *PublicTransactionPoolAPI) GetTransactionReceipt(ctx context.Context, ha } receipt := receipts[index] - var signer types.Signer = types.FrontierSigner{} - if tx.Protected() { - signer = types.NewEIP155Signer(tx.ChainId()) - } + // Derive the sender. + bigblock := new(big.Int).SetUint64(blockNumber) + signer := types.MakeSigner(s.b.ChainConfig(), bigblock) from, _ := types.Sender(signer, tx) fields := map[string]interface{}{ @@ -1915,6 +2175,7 @@ func (s *PublicTransactionPoolAPI) GetTransactionReceipt(ctx context.Context, ha "contractAddress": nil, "logs": receipt.Logs, "logsBloom": receipt.Bloom, + "type": hexutil.Uint(tx.Type()), } // Assign receipt status or post state. @@ -1933,6 +2194,44 @@ func (s *PublicTransactionPoolAPI) GetTransactionReceipt(ctx context.Context, ha return fields, nil } +// marshalReceipt marshals a transaction receipt into a JSON object. +func marshalReceipt(receipt *types.Receipt, blockHash common.Hash, blockNumber uint64, signer types.Signer, tx *types.Transaction, txIndex int) map[string]interface{} { + from, _ := types.Sender(signer, tx) + + fields := map[string]interface{}{ + "blockHash": blockHash, + "blockNumber": hexutil.Uint64(blockNumber), + "transactionHash": tx.Hash(), + "transactionIndex": hexutil.Uint64(txIndex), + "from": from, + "to": tx.To(), + "gasUsed": hexutil.Uint64(receipt.GasUsed), + "cumulativeGasUsed": hexutil.Uint64(receipt.CumulativeGasUsed), + "contractAddress": nil, + "logs": receipt.Logs, + "logsBloom": receipt.Bloom, + "type": hexutil.Uint(tx.Type()), + // uncomment below line after EIP-1559 + // TODO: "effectiveGasPrice": (*hexutil.Big)(receipt.EffectiveGasPrice), + } + + // Assign receipt status or post state. + if len(receipt.PostState) > 0 { + fields["root"] = hexutil.Bytes(receipt.PostState) + } else { + fields["status"] = hexutil.Uint(receipt.Status) + } + if receipt.Logs == nil { + fields["logs"] = []*types.Log{} + } + + // If the ContractAddress is 20 0x0 bytes, assume it is not a contract creation + if receipt.ContractAddress != (common.Address{}) { + fields["contractAddress"] = receipt.ContractAddress + } + return fields +} + // sign is a helper function that signs a transaction with the private key of the given address. func (s *PublicTransactionPoolAPI) sign(addr common.Address, tx *types.Transaction) (*types.Transaction, error) { // Look up the wallet containing the requested signer @@ -1962,14 +2261,14 @@ type SendTxArgs struct { // newer name and should be preferred by clients. Data *hexutil.Bytes `json:"data"` Input *hexutil.Bytes `json:"input"` + + // For non-legacy transactions + AccessList *types.AccessList `json:"accessList,omitempty"` + ChainID *hexutil.Big `json:"chainId,omitempty"` } -// setDefaults is a helper function that fills in default values for unspecified tx fields. +// setDefaults fills in default values for unspecified tx fields. func (args *SendTxArgs) setDefaults(ctx context.Context, b Backend) error { - if args.Gas == nil { - args.Gas = new(hexutil.Uint64) - *(*uint64)(args.Gas) = 90000 - } if args.GasPrice == nil { price, err := b.SuggestPrice(ctx) if err != nil { @@ -2002,45 +2301,98 @@ func (args *SendTxArgs) setDefaults(ctx context.Context, b Backend) error { return errors.New(`contract creation without any data provided`) } } + // Estimate the gas usage if necessary. + if args.Gas == nil { + // For backwards-compatibility reason, we try both input and data + // but input is preferred. + input := args.Input + if input == nil { + input = args.Data + } + callArgs := CallArgs{ + From: &args.From, // From shouldn't be nil + To: args.To, + GasPrice: args.GasPrice, + Value: args.Value, + Data: input, + AccessList: args.AccessList, + } + pendingBlockNr := rpc.BlockNumberOrHashWithNumber(rpc.PendingBlockNumber) + estimated, err := DoEstimateGas(ctx, b, callArgs, pendingBlockNr, nil, b.RPCGasCap()) + if err != nil { + return err + } + args.Gas = &estimated + log.Trace("Estimate gas usage automatically", "gas", args.Gas) + + } + if args.ChainID == nil { + id := (*hexutil.Big)(b.ChainConfig().ChainId) + args.ChainID = id + } return nil } +// toTransaction converts the arguments to a transaction. +// This assumes that setDefaults has been called. func (args *SendTxArgs) toTransaction() *types.Transaction { var input []byte - if args.Data != nil { - input = *args.Data - } else if args.Input != nil { + if args.Input != nil { input = *args.Input + } else if args.Data != nil { + input = *args.Data } - if args.To == nil { - return types.NewContractCreation(uint64(*args.Nonce), (*big.Int)(args.Value), uint64(*args.Gas), (*big.Int)(args.GasPrice), input) + var data types.TxData + if args.AccessList == nil { + data = &types.LegacyTx{ + To: args.To, + Nonce: uint64(*args.Nonce), + Gas: uint64(*args.Gas), + GasPrice: (*big.Int)(args.GasPrice), + Value: (*big.Int)(args.Value), + Data: input, + } + } else { + data = &types.AccessListTx{ + To: args.To, + ChainID: (*big.Int)(args.ChainID), + Nonce: uint64(*args.Nonce), + Gas: uint64(*args.Gas), + GasPrice: (*big.Int)(args.GasPrice), + Value: (*big.Int)(args.Value), + Data: input, + AccessList: *args.AccessList, + } } - return types.NewTransaction(uint64(*args.Nonce), *args.To, (*big.Int)(args.Value), uint64(*args.Gas), (*big.Int)(args.GasPrice), input) + return types.NewTx(data) } -// submitTransaction is a helper function that submits tx to txPool and logs a message. -func submitTransaction(ctx context.Context, b Backend, tx *types.Transaction) (common.Hash, error) { +// SubmitTransaction is a helper function that submits tx to txPool and logs a message. +func SubmitTransaction(ctx context.Context, b Backend, tx *types.Transaction) (common.Hash, error) { if tx.To() != nil && tx.IsSpecialTransaction() { return common.Hash{}, errors.New("Dont allow transaction sent to BlockSigners & RandomizeSMC smart contract via API") } if err := b.SendTx(ctx, tx); err != nil { return common.Hash{}, err } + + // Print a log with full tx details for manual investigations and interventions + signer := types.MakeSigner(b.ChainConfig(), b.CurrentBlock().Number()) + from, err := types.Sender(signer, tx) + if err != nil { + return common.Hash{}, err + } + if tx.To() == nil { - signer := types.MakeSigner(b.ChainConfig(), b.CurrentBlock().Number()) - from, err := types.Sender(signer, tx) - if err != nil { - return common.Hash{}, err - } addr := crypto.CreateAddress(from, tx.Nonce()) - log.Trace("Submitted contract creation", "fullhash", tx.Hash().Hex(), "contract", addr.Hex()) + log.Info("Submitted contract creation", "hash", tx.Hash().Hex(), "from", from, "nonce", tx.Nonce(), "contract", addr.Hex(), "value", tx.Value()) } else { - log.Trace("Submitted transaction", "fullhash", tx.Hash().Hex(), "recipient", tx.To()) + log.Info("Submitted transaction", "hash", tx.Hash().Hex(), "from", from, "nonce", tx.Nonce(), "recipient", tx.To(), "value", tx.Value()) } return tx.Hash(), nil } -// submitTransaction is a helper function that submits tx to txPool and logs a message. +// SubmitTransaction is a helper function that submits tx to txPool and logs a message. func submitOrderTransaction(ctx context.Context, b Backend, tx *types.OrderTransaction) (common.Hash, error) { if err := b.SendOrderTx(ctx, tx); err != nil { @@ -2092,17 +2444,33 @@ func (s *PublicTransactionPoolAPI) SendTransaction(ctx context.Context, args Sen if err != nil { return common.Hash{}, err } - return submitTransaction(ctx, s.b, signed) + return SubmitTransaction(ctx, s.b, signed) +} + +// FillTransaction fills the defaults (nonce, gas, gasPrice) on a given unsigned transaction, +// and returns it to the caller for further processing (signing + broadcast) +func (s *PublicTransactionPoolAPI) FillTransaction(ctx context.Context, args SendTxArgs) (*SignTransactionResult, error) { + // Set some sanity defaults and terminate on failure + if err := args.setDefaults(ctx, s.b); err != nil { + return nil, err + } + // Assemble the transaction and obtain rlp + tx := args.toTransaction() + data, err := tx.MarshalBinary() + if err != nil { + return nil, err + } + return &SignTransactionResult{data, tx}, nil } // SendRawTransaction will add the signed transaction to the transaction pool. // The sender is responsible for signing the transaction and using the correct nonce. -func (s *PublicTransactionPoolAPI) SendRawTransaction(ctx context.Context, encodedTx hexutil.Bytes) (common.Hash, error) { +func (s *PublicTransactionPoolAPI) SendRawTransaction(ctx context.Context, input hexutil.Bytes) (common.Hash, error) { tx := new(types.Transaction) - if err := rlp.DecodeBytes(encodedTx, tx); err != nil { + if err := tx.UnmarshalBinary(input); err != nil { return common.Hash{}, err } - return submitTransaction(ctx, s.b, tx) + return SubmitTransaction(ctx, s.b, tx) } // SendOrderRawTransaction will add the signed transaction to the transaction pool. @@ -2980,13 +3348,13 @@ type SignTransactionResult struct { // the given from address and it needs to be unlocked. func (s *PublicTransactionPoolAPI) SignTransaction(ctx context.Context, args SendTxArgs) (*SignTransactionResult, error) { if args.Gas == nil { - return nil, fmt.Errorf("gas not specified") + return nil, errors.New("gas not specified") } if args.GasPrice == nil { - return nil, fmt.Errorf("gasPrice not specified") + return nil, errors.New("gasPrice not specified") } if args.Nonce == nil { - return nil, fmt.Errorf("nonce not specified") + return nil, errors.New("nonce not specified") } if err := args.setDefaults(ctx, s.b); err != nil { return nil, err @@ -2995,29 +3363,30 @@ func (s *PublicTransactionPoolAPI) SignTransaction(ctx context.Context, args Sen if err != nil { return nil, err } - data, err := rlp.EncodeToBytes(tx) + data, err := tx.MarshalBinary() if err != nil { return nil, err } return &SignTransactionResult{data, tx}, nil } -// PendingTransactions returns the transactions that are in the transaction pool and have a from address that is one of -// the accounts this node manages. +// PendingTransactions returns the transactions that are in the transaction pool +// and have a from address that is one of the accounts this node manages. func (s *PublicTransactionPoolAPI) PendingTransactions() ([]*RPCTransaction, error) { pending, err := s.b.GetPoolTransactions() if err != nil { return nil, err } - + accounts := make(map[common.Address]struct{}) + for _, wallet := range s.b.AccountManager().Wallets() { + for _, account := range wallet.Accounts() { + accounts[account.Address] = struct{}{} + } + } transactions := make([]*RPCTransaction, 0, len(pending)) for _, tx := range pending { - var signer types.Signer = types.HomesteadSigner{} - if tx.Protected() { - signer = types.NewEIP155Signer(tx.ChainId()) - } - from, _ := types.Sender(signer, tx) - if _, err := s.b.AccountManager().Find(accounts.Account{Address: from}); err == nil { + from, _ := types.Sender(s.signer, tx) + if _, exists := accounts[from]; exists { transactions = append(transactions, newRPCPendingTransaction(tx)) } } @@ -3028,25 +3397,22 @@ func (s *PublicTransactionPoolAPI) PendingTransactions() ([]*RPCTransaction, err // the given transaction from the pool and reinsert it with the new gas price and limit. func (s *PublicTransactionPoolAPI) Resend(ctx context.Context, sendArgs SendTxArgs, gasPrice *hexutil.Big, gasLimit *hexutil.Uint64) (common.Hash, error) { if sendArgs.Nonce == nil { - return common.Hash{}, fmt.Errorf("missing transaction nonce in transaction spec") + return common.Hash{}, errors.New("missing transaction nonce in transaction spec") } if err := sendArgs.setDefaults(ctx, s.b); err != nil { return common.Hash{}, err } matchTx := sendArgs.toTransaction() + + // Iterate the pending list for replacement pending, err := s.b.GetPoolTransactions() if err != nil { return common.Hash{}, err } - for _, p := range pending { - var signer types.Signer = types.HomesteadSigner{} - if p.Protected() { - signer = types.NewEIP155Signer(p.ChainId()) - } - wantSigHash := signer.Hash(matchTx) - - if pFrom, err := types.Sender(signer, p); err == nil && pFrom == sendArgs.From && signer.Hash(p) == wantSigHash { + wantSigHash := s.signer.Hash(matchTx) + pFrom, err := types.Sender(s.signer, p) + if err == nil && pFrom == sendArgs.From && s.signer.Hash(p) == wantSigHash { // Match. Re-sign and send the transaction. if gasPrice != nil && (*big.Int)(gasPrice).Sign() != 0 { sendArgs.GasPrice = gasPrice @@ -3064,8 +3430,7 @@ func (s *PublicTransactionPoolAPI) Resend(ctx context.Context, sendArgs SendTxAr return signedTx.Hash(), nil } } - - return common.Hash{}, fmt.Errorf("Transaction %#x not found", matchTx.Hash()) + return common.Hash{}, fmt.Errorf("transaction %#x not found", matchTx.Hash()) } // PublicDebugAPI is the collection of Ethereum APIs exposed over the public @@ -3129,7 +3494,7 @@ func (api *PrivateDebugAPI) ChaindbProperty(property string) (string, error) { LDB() *leveldb.DB }) if !ok { - return "", fmt.Errorf("chaindbProperty does not work for memory databases") + return "", errors.New("chaindbProperty does not work for memory databases") } if property == "" { property = "leveldb.stats" @@ -3144,7 +3509,7 @@ func (api *PrivateDebugAPI) ChaindbCompact() error { LDB() *leveldb.DB }) if !ok { - return fmt.Errorf("chaindbCompact does not work for memory databases") + return errors.New("chaindbCompact does not work for memory databases") } for b := byte(0); b < 255; b++ { log.Info("Compacting chain database", "range", fmt.Sprintf("0x%0.2X-0x%0.2X", b, b+1)) @@ -3288,3 +3653,18 @@ func (s *PublicBlockChainAPI) GetStakerROIMasternode(masternode common.Address) return 100.0 / float64(totalCap.Div(totalCap, voterRewardAYear).Uint64()) } + +// checkTxFee is an internal function used to check whether the fee of +// the given transaction is _reasonable_(under the cap). +func checkTxFee(gasPrice *big.Int, gas uint64, cap float64) error { + // Short circuit if there is no cap for transaction fee at all. + if cap == 0 { + return nil + } + feeEth := new(big.Float).Quo(new(big.Float).SetInt(new(big.Int).Mul(gasPrice, new(big.Int).SetUint64(gas))), new(big.Float).SetInt(big.NewInt(params.Ether))) + feeFloat, _ := feeEth.Float64() + if feeFloat > cap { + return fmt.Errorf("tx fee (%.2f ether) exceeds the configured cap (%.2f ether)", feeFloat, cap) + } + return nil +} diff --git a/internal/ethapi/backend.go b/internal/ethapi/backend.go index cde446db7f..00a9aea8ac 100644 --- a/internal/ethapi/backend.go +++ b/internal/ethapi/backend.go @@ -51,6 +51,7 @@ type Backend interface { ChainDb() ethdb.Database EventMux() *event.TypeMux AccountManager() *accounts.Manager + RPCGasCap() uint64 // global gas cap for eth_call over rpc: DoS protection XDCxService() *XDCx.XDCX LendingService() *XDCxlending.Lending @@ -67,7 +68,7 @@ type Backend interface { GetBlock(ctx context.Context, blockHash common.Hash) (*types.Block, error) GetReceipts(ctx context.Context, blockHash common.Hash) (types.Receipts, error) GetTd(blockHash common.Hash) *big.Int - GetEVM(ctx context.Context, msg core.Message, state *state.StateDB, XDCxState *tradingstate.TradingStateDB, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) + GetEVM(ctx context.Context, msg core.Message, state *state.StateDB, XDCxState *tradingstate.TradingStateDB, header *types.Header, vmConfig *vm.Config) (*vm.EVM, func() error, error) SubscribeChainEvent(ch chan<- core.ChainEvent) event.Subscription SubscribeChainHeadEvent(ch chan<- core.ChainHeadEvent) event.Subscription SubscribeChainSideEvent(ch chan<- core.ChainSideEvent) event.Subscription diff --git a/internal/ethapi/trie_proof.go b/internal/ethapi/trie_proof.go new file mode 100644 index 0000000000..4a7747ab8f --- /dev/null +++ b/internal/ethapi/trie_proof.go @@ -0,0 +1,39 @@ +package ethapi + +import ( + "bytes" + + "github.com/XinFinOrg/XDPoSChain/common/hexutil" + "github.com/XinFinOrg/XDPoSChain/core/types" + "github.com/XinFinOrg/XDPoSChain/rlp" + "github.com/XinFinOrg/XDPoSChain/trie" +) + +// proofPairList implements ethdb.KeyValueWriter and collects the proofs as +// hex-strings of key and value for delivery to rpc-caller. +type proofPairList struct { + keys []string + values []string +} + +func (n *proofPairList) Put(key []byte, value []byte) error { + n.keys = append(n.keys, hexutil.Encode(key)) + n.values = append(n.values, hexutil.Encode(value)) + return nil +} + +func (n *proofPairList) Delete(key []byte) error { + panic("not supported") +} + +// modified from core/types/derive_sha.go +func deriveTrie(list types.DerivableList) *trie.Trie { + keybuf := new(bytes.Buffer) + trie := new(trie.Trie) + for i := 0; i < list.Len(); i++ { + keybuf.Reset() + rlp.Encode(keybuf, uint(i)) + trie.Update(keybuf.Bytes(), list.GetRlp(i)) + } + return trie +} diff --git a/internal/ethapi/trie_proof_test.go b/internal/ethapi/trie_proof_test.go new file mode 100644 index 0000000000..34922b9c60 --- /dev/null +++ b/internal/ethapi/trie_proof_test.go @@ -0,0 +1,106 @@ +package ethapi + +import ( + "bytes" + "errors" + "math/big" + "reflect" + "testing" + + "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/common/hexutil" + "github.com/XinFinOrg/XDPoSChain/core/types" + "github.com/XinFinOrg/XDPoSChain/rlp" + "github.com/XinFinOrg/XDPoSChain/trie" +) + +// implement interface only for testing verifyProof +func (n *proofPairList) Has(key []byte) (bool, error) { + key_hex := hexutil.Encode(key) + for _, k := range n.keys { + if k == key_hex { + return true, nil + } + } + return false, nil +} + +func (n *proofPairList) Get(key []byte) ([]byte, error) { + key_hex := hexutil.Encode(key) + for i, k := range n.keys { + if k == key_hex { + b, err := hexutil.Decode(n.values[i]) + if err != nil { + return nil, err + } + return b, nil + } + } + return nil, errors.New("key not found") +} + +func TestTransactionProof(t *testing.T) { + to1 := common.HexToAddress("0x01") + to2 := common.HexToAddress("0x02") + t1 := types.NewTransaction(1, to1, big.NewInt(1), 1, big.NewInt(1), []byte{}) + t2 := types.NewTransaction(2, to2, big.NewInt(2), 2, big.NewInt(2), []byte{}) + t3 := types.NewTransaction(3, to2, big.NewInt(3), 3, big.NewInt(3), []byte{}) + t4 := types.NewTransaction(4, to1, big.NewInt(4), 4, big.NewInt(4), []byte{}) + transactions := types.Transactions([]*types.Transaction{t1, t2, t3, t4}) + tr := deriveTrie(transactions) + // for verifying the proof + root := types.DeriveSha(transactions) + for i := 0; i < transactions.Len(); i++ { + var proof proofPairList + keybuf := new(bytes.Buffer) + rlp.Encode(keybuf, uint(i)) + if err := tr.Prove(keybuf.Bytes(), 0, &proof); err != nil { + t.Fatal("Prove err:", err) + } + // verify the proof + value, err := trie.VerifyProof(root, keybuf.Bytes(), &proof) + if err != nil { + t.Fatal("verify proof error") + } + encodedTransaction, err := rlp.EncodeToBytes(transactions[i]) + if err != nil { + t.Fatal("encode receipt error") + } + if !reflect.DeepEqual(encodedTransaction, value) { + t.Fatal("verify does not return the receipt we want") + } + } +} + +func TestReceiptProof(t *testing.T) { + root1 := []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} + root2 := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} + r1 := types.NewReceipt(root1, false, 1) + r2 := types.NewReceipt(root2, true, 2) + r3 := types.NewReceipt(root2, false, 3) + r4 := types.NewReceipt(root1, true, 4) + receipts := types.Receipts([]*types.Receipt{r1, r2, r3, r4}) + tr := deriveTrie(receipts) + // for verifying the proof + root := types.DeriveSha(receipts) + for i := 0; i < receipts.Len(); i++ { + var proof proofPairList + keybuf := new(bytes.Buffer) + rlp.Encode(keybuf, uint(i)) + if err := tr.Prove(keybuf.Bytes(), 0, &proof); err != nil { + t.Fatal("Prove err:", err) + } + // verify the proof + value, err := trie.VerifyProof(root, keybuf.Bytes(), &proof) + if err != nil { + t.Fatal("verify proof error") + } + encodedReceipt, err := rlp.EncodeToBytes(receipts[i]) + if err != nil { + t.Fatal("encode receipt error") + } + if !reflect.DeepEqual(encodedReceipt, value) { + t.Fatal("verify does not return the receipt we want") + } + } +} diff --git a/internal/guide/guide_test.go b/internal/guide/guide_test.go index 9077ae52a4..64b9e162fa 100644 --- a/internal/guide/guide_test.go +++ b/internal/guide/guide_test.go @@ -30,6 +30,7 @@ import ( "time" "github.com/XinFinOrg/XDPoSChain/accounts/keystore" + "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/core/types" ) @@ -74,7 +75,8 @@ func TestAccountManagement(t *testing.T) { if err != nil { t.Fatalf("Failed to create signer account: %v", err) } - tx, chain := new(types.Transaction), big.NewInt(1) + tx := types.NewTransaction(0, common.Address{}, big.NewInt(0), 0, big.NewInt(0), nil) + chain := big.NewInt(1) // Sign a transaction with a single authorization if _, err := ks.SignTxWithPassphrase(signer, "Signer password", tx, chain); err != nil { diff --git a/internal/web3ext/web3ext.go b/internal/web3ext/web3ext.go index c07cf978c8..dbb3f41df4 100644 --- a/internal/web3ext/web3ext.go +++ b/internal/web3ext/web3ext.go @@ -420,6 +420,12 @@ web3._extend({ params: 2, inputFormatter: [null, null] }), + new web3._extend.Method({ + name: 'traceCall', + call: 'debug_traceCall', + params: 3, + inputFormatter: [null, null, null] + }), new web3._extend.Method({ name: 'preimage', call: 'debug_preimage', @@ -496,6 +502,11 @@ web3._extend({ call: 'eth_getRewardByHash', params: 1 }), + new web3._extend.Method({ + name: 'getTransactionAndReceiptProof', + call: 'eth_getTransactionAndReceiptProof', + params: 1 + }), new web3._extend.Method({ name: 'getRawTransactionFromBlock', call: function(args) { @@ -510,6 +521,17 @@ web3._extend({ params: 2, inputFormatter: [web3._extend.formatters.inputAddressFormatter, web3._extend.formatters.inputBlockNumberFormatter] }), + new web3._extend.Method({ + name: 'createAccessList', + call: 'eth_createAccessList', + params: 2, + inputFormatter: [null, web3._extend.formatters.inputBlockNumberFormatter], + }), + new web3._extend.Method({ + name: 'getBlockReceipts', + call: 'eth_getBlockReceipts', + params: 1, + }), ], properties: [ new web3._extend.Property({ diff --git a/les/api_backend.go b/les/api_backend.go index b82c8ca91d..abdd952a6e 100644 --- a/les/api_backend.go +++ b/les/api_backend.go @@ -171,10 +171,13 @@ func (b *LesApiBackend) GetTd(blockHash common.Hash) *big.Int { return b.eth.blockchain.GetTdByHash(blockHash) } -func (b *LesApiBackend) GetEVM(ctx context.Context, msg core.Message, state *state.StateDB, XDCxState *tradingstate.TradingStateDB, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) { +func (b *LesApiBackend) GetEVM(ctx context.Context, msg core.Message, state *state.StateDB, XDCxState *tradingstate.TradingStateDB, header *types.Header, vmConfig *vm.Config) (*vm.EVM, func() error, error) { + if vmConfig == nil { + vmConfig = new(vm.Config) + } state.SetBalance(msg.From(), math.MaxBig256) context := core.NewEVMContext(msg, header, b.eth.blockchain, nil) - return vm.NewEVM(context, state, XDCxState, b.eth.chainConfig, vmCfg), state.Error, nil + return vm.NewEVM(context, state, XDCxState, b.eth.chainConfig, *vmConfig), state.Error, nil } func (b *LesApiBackend) SendTx(ctx context.Context, signedTx *types.Transaction) error { @@ -266,6 +269,10 @@ func (b *LesApiBackend) AccountManager() *accounts.Manager { return b.eth.accountManager } +func (b *LesApiBackend) RPCGasCap() uint64 { + return b.eth.config.RPCGasCap +} + func (b *LesApiBackend) BloomStatus() (uint64, uint64) { if b.eth.bloomIndexer == nil { return 0, 0 diff --git a/les/backend.go b/les/backend.go index b172e8fc00..a4b2074744 100644 --- a/les/backend.go +++ b/les/backend.go @@ -18,7 +18,7 @@ package les import ( - "fmt" + "errors" "sync" "time" @@ -155,12 +155,12 @@ type LightDummyAPI struct{} // Etherbase is the address that mining rewards will be send to func (s *LightDummyAPI) Etherbase() (common.Address, error) { - return common.Address{}, fmt.Errorf("not supported") + return common.Address{}, errors.New("not supported") } // Coinbase is the address that mining rewards will be send to (alias for Etherbase) func (s *LightDummyAPI) Coinbase() (common.Address, error) { - return common.Address{}, fmt.Errorf("not supported") + return common.Address{}, errors.New("not supported") } // Hashrate returns the POW hashrate diff --git a/les/odr_test.go b/les/odr_test.go index bbaa7c0dfa..1234bd2853 100644 --- a/les/odr_test.go +++ b/les/odr_test.go @@ -19,7 +19,6 @@ package les import ( "bytes" "context" - "github.com/XinFinOrg/XDPoSChain/core/rawdb" "math/big" "testing" "time" @@ -27,6 +26,7 @@ import ( "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/common/math" "github.com/XinFinOrg/XDPoSChain/core" + "github.com/XinFinOrg/XDPoSChain/core/rawdb" "github.com/XinFinOrg/XDPoSChain/core/state" "github.com/XinFinOrg/XDPoSChain/core/types" "github.com/XinFinOrg/XDPoSChain/core/vm" @@ -133,7 +133,7 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai if value, ok := feeCapacity[testContractAddr]; ok { balanceTokenFee = value } - msg := callmsg{types.NewMessage(from.Address(), &testContractAddr, 0, new(big.Int), 100000, new(big.Int), data, false, balanceTokenFee, header.Number)} + msg := callmsg{types.NewMessage(from.Address(), &testContractAddr, 0, new(big.Int), 100000, new(big.Int), data, nil, false, balanceTokenFee, header.Number)} context := core.NewEVMContext(msg, header, bc, nil) vmenv := vm.NewEVM(context, statedb, nil, config, vm.Config{}) @@ -153,7 +153,7 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai if value, ok := feeCapacity[testContractAddr]; ok { balanceTokenFee = value } - msg := callmsg{types.NewMessage(testBankAddress, &testContractAddr, 0, new(big.Int), 100000, new(big.Int), data, false, balanceTokenFee, header.Number)} + msg := callmsg{types.NewMessage(testBankAddress, &testContractAddr, 0, new(big.Int), 100000, new(big.Int), data, nil, false, balanceTokenFee, header.Number)} context := core.NewEVMContext(msg, header, lc, nil) vmenv := vm.NewEVM(context, statedb, nil, config, vm.Config{}) gp := new(core.GasPool).AddGas(math.MaxUint64) diff --git a/les/peer.go b/les/peer.go index ef635795bf..cbc7b99570 100644 --- a/les/peer.go +++ b/les/peer.go @@ -291,7 +291,7 @@ func (p *peer) RequestHelperTrieProofs(reqID, cost uint64, reqs []HelperTrieReq) reqsV1 := make([]ChtReq, len(reqs)) for i, req := range reqs { if req.Type != htCanonical || req.AuxReq != auxHeader || len(req.Key) != 8 { - return fmt.Errorf("Request invalid in LES/1 mode") + return errors.New("Request invalid in LES/1 mode") } blockNum := binary.BigEndian.Uint64(req.Key) // convert HelperTrie request to old CHT request diff --git a/les/retrieve.go b/les/retrieve.go index 91014f1964..509b8a3e35 100644 --- a/les/retrieve.go +++ b/les/retrieve.go @@ -22,7 +22,7 @@ import ( "context" "crypto/rand" "encoding/binary" - "fmt" + "errors" "sync" "time" @@ -119,7 +119,7 @@ func (rm *retrieveManager) retrieve(ctx context.Context, reqID uint64, req *dist case <-ctx.Done(): sentReq.stop(ctx.Err()) case <-shutdown: - sentReq.stop(fmt.Errorf("Client is shutting down")) + sentReq.stop(errors.New("Client is shutting down")) } return sentReq.getError() } diff --git a/light/odr_test.go b/light/odr_test.go index 4905260ade..83e1c807ca 100644 --- a/light/odr_test.go +++ b/light/odr_test.go @@ -20,12 +20,13 @@ import ( "bytes" "context" "errors" - "github.com/XinFinOrg/XDPoSChain/consensus" - "github.com/XinFinOrg/XDPoSChain/core/rawdb" "math/big" "testing" "time" + "github.com/XinFinOrg/XDPoSChain/consensus" + "github.com/XinFinOrg/XDPoSChain/core/rawdb" + "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/common/math" "github.com/XinFinOrg/XDPoSChain/consensus/ethash" @@ -183,7 +184,7 @@ func odrContractCall(ctx context.Context, db ethdb.Database, bc *core.BlockChain if value, ok := feeCapacity[testContractAddr]; ok { balanceTokenFee = value } - msg := callmsg{types.NewMessage(testBankAddress, &testContractAddr, 0, new(big.Int), 1000000, new(big.Int), data, false, balanceTokenFee, header.Number)} + msg := callmsg{types.NewMessage(testBankAddress, &testContractAddr, 0, new(big.Int), 1000000, new(big.Int), data, nil, false, balanceTokenFee, header.Number)} context := core.NewEVMContext(msg, header, chain, nil) vmenv := vm.NewEVM(context, st, nil, config, vm.Config{}) gp := new(core.GasPool).AddGas(math.MaxUint64) diff --git a/light/trie_test.go b/light/trie_test.go index fb030af871..2332043d26 100644 --- a/light/trie_test.go +++ b/light/trie_test.go @@ -19,10 +19,12 @@ package light import ( "bytes" "context" + "errors" "fmt" - "github.com/XinFinOrg/XDPoSChain/core/rawdb" "testing" + "github.com/XinFinOrg/XDPoSChain/core/rawdb" + "github.com/XinFinOrg/XDPoSChain/consensus/ethash" "github.com/XinFinOrg/XDPoSChain/core" "github.com/XinFinOrg/XDPoSChain/core/state" @@ -74,9 +76,9 @@ func diffTries(t1, t2 state.Trie) error { case i2.Err != nil: return fmt.Errorf("light trie iterator error: %v", i1.Err) case i1.Next(): - return fmt.Errorf("full trie iterator has more k/v pairs") + return errors.New("full trie iterator has more k/v pairs") case i2.Next(): - return fmt.Errorf("light trie iterator has more k/v pairs") + return errors.New("light trie iterator has more k/v pairs") } return nil } diff --git a/light/txpool.go b/light/txpool.go index 281af18b2a..292e91b92d 100644 --- a/light/txpool.go +++ b/light/txpool.go @@ -19,6 +19,7 @@ package light import ( "context" "fmt" + "math/big" "sync" "time" @@ -30,7 +31,6 @@ import ( "github.com/XinFinOrg/XDPoSChain/event" "github.com/XinFinOrg/XDPoSChain/log" "github.com/XinFinOrg/XDPoSChain/params" - "github.com/XinFinOrg/XDPoSChain/rlp" ) const ( @@ -67,6 +67,7 @@ type TxPool struct { clearIdx uint64 // earliest block nr that can contain mined tx info homestead bool + eip2718 bool // Fork indicator whether we are in the eip2718 stage. } // TxRelayBackend provides an interface to the mechanism that forwards transacions @@ -91,7 +92,7 @@ type TxRelayBackend interface { func NewTxPool(config *params.ChainConfig, chain *LightChain, relay TxRelayBackend) *TxPool { pool := &TxPool{ config: config, - signer: types.NewEIP155Signer(config.ChainId), + signer: types.LatestSigner(config), nonce: make(map[common.Address]uint64), pending: make(map[common.Hash]*types.Transaction), mined: make(map[common.Hash][]*types.Transaction), @@ -310,8 +311,11 @@ func (pool *TxPool) setNewHead(head *types.Header) { txc, _ := pool.reorgOnNewHead(ctx, head) m, r := txc.getLists() pool.relay.NewHead(pool.head, m, r) + + // Update fork indicator by next pending block number + next := new(big.Int).Add(head.Number, big.NewInt(1)) pool.homestead = pool.config.IsHomestead(head.Number) - pool.signer = types.MakeSigner(pool.config, head.Number) + pool.eip2718 = pool.config.IsEIP1559(next) } // Stop stops the light transaction pool @@ -403,7 +407,7 @@ func (pool *TxPool) validateTx(ctx context.Context, tx *types.Transaction) error } // Should supply enough intrinsic gas - gas, err := core.IntrinsicGas(tx.Data(), tx.To() == nil, pool.homestead) + gas, err := core.IntrinsicGas(tx.Data(), tx.AccessList(), tx.To() == nil, pool.homestead) if err != nil { return err } @@ -452,8 +456,7 @@ func (self *TxPool) add(ctx context.Context, tx *types.Transaction) error { func (self *TxPool) Add(ctx context.Context, tx *types.Transaction) error { self.mu.Lock() defer self.mu.Unlock() - - data, err := rlp.EncodeToBytes(tx) + data, err := tx.MarshalBinary() if err != nil { return err } diff --git a/miner/worker.go b/miner/worker.go index 823cd74258..a97f55ceab 100644 --- a/miner/worker.go +++ b/miner/worker.go @@ -19,17 +19,15 @@ package miner import ( "bytes" "encoding/binary" - - "github.com/XinFinOrg/XDPoSChain/XDCxlending/lendingstate" - "github.com/XinFinOrg/XDPoSChain/accounts" - + "errors" "math/big" "sync" "sync/atomic" "time" "github.com/XinFinOrg/XDPoSChain/XDCx/tradingstate" - + "github.com/XinFinOrg/XDPoSChain/XDCxlending/lendingstate" + "github.com/XinFinOrg/XDPoSChain/accounts" "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/consensus" "github.com/XinFinOrg/XDPoSChain/consensus/XDPoS" @@ -466,10 +464,13 @@ func (self *worker) push(work *Work) { // makeCurrent creates a new environment for the current cycle. func (self *worker) makeCurrent(parent *types.Block, header *types.Header) error { + // Retrieve the parent state to execute on top and start a prefetcher for + // the miner to speed block sealing up a bit state, err := self.chain.StateAt(parent.Root()) if err != nil { return err } + author, _ := self.chain.Engine().Author(parent.Header()) var XDCxState *tradingstate.TradingStateDB var lendingState *lendingstate.LendingStateDB @@ -490,7 +491,7 @@ func (self *worker) makeCurrent(parent *types.Block, header *types.Header) error work := &Work{ config: self.config, - signer: types.NewEIP155Signer(self.config.ChainId), + signer: types.MakeSigner(self.config, header.Number), state: state, parentState: state.Copy(), tradingState: XDCxState, @@ -613,7 +614,7 @@ func (self *worker) commitNewWork() { misc.ApplyDAOHardFork(work.state) } if common.TIPSigning.Cmp(header.Number) == 0 { - work.state.DeleteAddress(common.HexToAddress(common.BlockSigners)) + work.state.DeleteAddress(common.BlockSignersBinary) } // won't grasp txs at checkpoint var ( @@ -698,7 +699,7 @@ func (self *worker) commitNewWork() { return } nonce := work.state.GetNonce(self.coinbase) - tx := types.NewTransaction(nonce, common.HexToAddress(common.XDCXAddr), big.NewInt(0), txMatchGasLimit, big.NewInt(0), txMatchBytes) + tx := types.NewTransaction(nonce, common.XDCXAddrBinary, big.NewInt(0), txMatchGasLimit, big.NewInt(0), txMatchBytes) txM, err := wallet.SignTx(accounts.Account{Address: self.coinbase}, tx, self.config.ChainId) if err != nil { log.Error("Fail to create tx matches", "error", err) @@ -728,7 +729,7 @@ func (self *worker) commitNewWork() { return } nonce := work.state.GetNonce(self.coinbase) - lendingTx := types.NewTransaction(nonce, common.HexToAddress(common.XDCXLendingAddress), big.NewInt(0), txMatchGasLimit, big.NewInt(0), lendingDataBytes) + lendingTx := types.NewTransaction(nonce, common.XDCXLendingAddressBinary, big.NewInt(0), txMatchGasLimit, big.NewInt(0), lendingDataBytes) signedLendingTx, err := wallet.SignTx(accounts.Account{Address: self.coinbase}, lendingTx, self.config.ChainId) if err != nil { log.Error("Fail to create lending tx", "error", err) @@ -752,7 +753,7 @@ func (self *worker) commitNewWork() { return } nonce := work.state.GetNonce(self.coinbase) - finalizedTx := types.NewTransaction(nonce, common.HexToAddress(common.XDCXLendingFinalizedTradeAddress), big.NewInt(0), txMatchGasLimit, big.NewInt(0), finalizedTradeData) + finalizedTx := types.NewTransaction(nonce, common.XDCXLendingFinalizedTradeAddressBinary, big.NewInt(0), txMatchGasLimit, big.NewInt(0), finalizedTradeData) signedFinalizedTx, err := wallet.SignTx(accounts.Account{Address: self.coinbase}, finalizedTx, self.config.ChainId) if err != nil { log.Error("Fail to create lending tx", "error", err) @@ -771,7 +772,7 @@ func (self *worker) commitNewWork() { XDCxStateRoot := work.tradingState.IntermediateRoot() LendingStateRoot := work.lendingState.IntermediateRoot() txData := append(XDCxStateRoot.Bytes(), LendingStateRoot.Bytes()...) - tx := types.NewTransaction(work.state.GetNonce(self.coinbase), common.HexToAddress(common.TradingStateAddr), big.NewInt(0), txMatchGasLimit, big.NewInt(0), txData) + tx := types.NewTransaction(work.state.GetNonce(self.coinbase), common.TradingStateAddrBinary, big.NewInt(0), txMatchGasLimit, big.NewInt(0), txData) txStateRoot, err := wallet.SignTx(accounts.Account{Address: self.coinbase}, tx, self.config.ChainId) if err != nil { log.Error("Fail to create tx state root", "error", err) @@ -807,26 +808,27 @@ func (env *Work) commitTransactions(mux *event.TypeMux, balanceFee map[common.Ad var coalescedLogs []*types.Log // first priority for special Txs for _, tx := range specialTxs { - + to := tx.To() //HF number for black-list if (env.header.Number.Uint64() >= common.BlackListHFNumber) && !common.IsTestnet { + from := tx.From() // check if sender is in black list - if tx.From() != nil && common.Blacklist[*tx.From()] { - log.Debug("Skipping transaction with sender in black-list", "sender", tx.From().Hex()) + if from != nil && common.Blacklist[*from] { + log.Debug("Skipping transaction with sender in black-list", "sender", from.Hex()) continue } // check if receiver is in black list - if tx.To() != nil && common.Blacklist[*tx.To()] { - log.Debug("Skipping transaction with receiver in black-list", "receiver", tx.To().Hex()) + if to != nil && common.Blacklist[*to] { + log.Debug("Skipping transaction with receiver in black-list", "receiver", to.Hex()) continue } } - + data := tx.Data() // validate minFee slot for XDCZ if tx.IsXDCZApplyTransaction() { copyState, _ := bc.State() - if err := core.ValidateXDCZApplyTransaction(bc, nil, copyState, common.BytesToAddress(tx.Data()[4:])); err != nil { - log.Debug("XDCZApply: invalid token", "token", common.BytesToAddress(tx.Data()[4:]).Hex()) + if err := core.ValidateXDCZApplyTransaction(bc, nil, copyState, common.BytesToAddress(data[4:])); err != nil { + log.Debug("XDCZApply: invalid token", "token", common.BytesToAddress(data[4:]).Hex()) txs.Pop() continue } @@ -834,8 +836,8 @@ func (env *Work) commitTransactions(mux *event.TypeMux, balanceFee map[common.Ad // validate balance slot, token decimal for XDCX if tx.IsXDCXApplyTransaction() { copyState, _ := bc.State() - if err := core.ValidateXDCXApplyTransaction(bc, nil, copyState, common.BytesToAddress(tx.Data()[4:])); err != nil { - log.Debug("XDCXApply: invalid token", "token", common.BytesToAddress(tx.Data()[4:]).Hex()) + if err := core.ValidateXDCXApplyTransaction(bc, nil, copyState, common.BytesToAddress(data[4:])); err != nil { + log.Debug("XDCXApply: invalid token", "token", common.BytesToAddress(data[4:]).Hex()) txs.Pop() continue } @@ -852,38 +854,39 @@ func (env *Work) commitTransactions(mux *event.TypeMux, balanceFee map[common.Ad from, _ := types.Sender(env.signer, tx) // Check whether the tx is replay protected. If we're not in the EIP155 hf // phase, start ignoring the sender until we do. + hash := tx.Hash() if tx.Protected() && !env.config.IsEIP155(env.header.Number) { - log.Trace("Ignoring reply protected special transaction", "hash", tx.Hash(), "eip155", env.config.EIP155Block) + log.Trace("Ignoring reply protected special transaction", "hash", hash, "eip155", env.config.EIP155Block) continue } - if tx.To().Hex() == common.BlockSigners { - if len(tx.Data()) < 68 { - log.Trace("Data special transaction invalid length", "hash", tx.Hash(), "data", len(tx.Data())) + if *to == common.BlockSignersBinary { + if len(data) < 68 { + log.Trace("Data special transaction invalid length", "hash", hash, "data", len(data)) continue } - blkNumber := binary.BigEndian.Uint64(tx.Data()[8:40]) + blkNumber := binary.BigEndian.Uint64(data[8:40]) if blkNumber >= env.header.Number.Uint64() || blkNumber <= env.header.Number.Uint64()-env.config.XDPoS.Epoch*2 { - log.Trace("Data special transaction invalid number", "hash", tx.Hash(), "blkNumber", blkNumber, "miner", env.header.Number) + log.Trace("Data special transaction invalid number", "hash", hash, "blkNumber", blkNumber, "miner", env.header.Number) continue } } // Start executing the transaction - env.state.Prepare(tx.Hash(), common.Hash{}, env.tcount) + env.state.Prepare(hash, common.Hash{}, env.tcount) nonce := env.state.GetNonce(from) if nonce != tx.Nonce() && !tx.IsSkipNonceTransaction() { - log.Trace("Skipping account with special transaction invalid nonce", "sender", from, "nonce", nonce, "tx nonce ", tx.Nonce(), "to", tx.To()) + log.Trace("Skipping account with special transaction invalid nonce", "sender", from, "nonce", nonce, "tx nonce ", tx.Nonce(), "to", to) continue } err, logs, tokenFeeUsed, gas := env.commitTransaction(balanceFee, tx, bc, coinbase, gp) switch err { case core.ErrNonceTooLow: // New head notification data race between the transaction pool and miner, shift - log.Trace("Skipping special transaction with low nonce", "sender", from, "nonce", tx.Nonce(), "to", tx.To()) + log.Trace("Skipping special transaction with low nonce", "sender", from, "nonce", tx.Nonce(), "to", to) case core.ErrNonceTooHigh: // Reorg notification data race between the transaction pool and miner, skip account = - log.Trace("Skipping account with special transaction hight nonce", "sender", from, "nonce", tx.Nonce(), "to", tx.To()) + log.Trace("Skipping account with special transaction hight nonce", "sender", from, "nonce", tx.Nonce(), "to", to) case nil: // Everything ok, collect the logs and shift in the next transaction from the same account coalescedLogs = append(coalescedLogs, logs...) @@ -892,12 +895,12 @@ func (env *Work) commitTransactions(mux *event.TypeMux, balanceFee map[common.Ad default: // Strange error, discard the transaction and get the next in line (note, the // nonce-too-high clause will prevent us from executing in vain). - log.Debug("Add Special Transaction failed, account skipped", "hash", tx.Hash(), "sender", from, "nonce", tx.Nonce(), "to", tx.To(), "err", err) + log.Debug("Add Special Transaction failed, account skipped", "hash", hash, "sender", from, "nonce", tx.Nonce(), "to", to, "err", err) } if tokenFeeUsed { fee := common.GetGasFee(env.header.Number.Uint64(), gas) - balanceFee[*tx.To()] = new(big.Int).Sub(balanceFee[*tx.To()], fee) - balanceUpdated[*tx.To()] = balanceFee[*tx.To()] + balanceFee[*to] = new(big.Int).Sub(balanceFee[*to], fee) + balanceUpdated[*to] = balanceFee[*to] totalFeeUsed = totalFeeUsed.Add(totalFeeUsed, fee) } } @@ -919,26 +922,28 @@ func (env *Work) commitTransactions(mux *event.TypeMux, balanceFee map[common.Ad } //HF number for black-list + to := tx.To() if (env.header.Number.Uint64() >= common.BlackListHFNumber) && !common.IsTestnet { + from := tx.From() // check if sender is in black list - if tx.From() != nil && common.Blacklist[*tx.From()] { - log.Debug("Skipping transaction with sender in black-list", "sender", tx.From().Hex()) + if from != nil && common.Blacklist[*from] { + log.Debug("Skipping transaction with sender in black-list", "sender", from.Hex()) txs.Pop() continue } // check if receiver is in black list - if tx.To() != nil && common.Blacklist[*tx.To()] { - log.Debug("Skipping transaction with receiver in black-list", "receiver", tx.To().Hex()) + if to != nil && common.Blacklist[*to] { + log.Debug("Skipping transaction with receiver in black-list", "receiver", to.Hex()) txs.Shift() continue } } - + data := tx.Data() // validate minFee slot for XDCZ if tx.IsXDCZApplyTransaction() { copyState, _ := bc.State() - if err := core.ValidateXDCZApplyTransaction(bc, nil, copyState, common.BytesToAddress(tx.Data()[4:])); err != nil { - log.Debug("XDCZApply: invalid token", "token", common.BytesToAddress(tx.Data()[4:]).Hex()) + if err := core.ValidateXDCZApplyTransaction(bc, nil, copyState, common.BytesToAddress(data[4:])); err != nil { + log.Debug("XDCZApply: invalid token", "token", common.BytesToAddress(data[4:]).Hex()) txs.Pop() continue } @@ -946,8 +951,8 @@ func (env *Work) commitTransactions(mux *event.TypeMux, balanceFee map[common.Ad // validate balance slot, token decimal for XDCX if tx.IsXDCXApplyTransaction() { copyState, _ := bc.State() - if err := core.ValidateXDCXApplyTransaction(bc, nil, copyState, common.BytesToAddress(tx.Data()[4:])); err != nil { - log.Debug("XDCXApply: invalid token", "token", common.BytesToAddress(tx.Data()[4:]).Hex()) + if err := core.ValidateXDCXApplyTransaction(bc, nil, copyState, common.BytesToAddress(data[4:])); err != nil { + log.Debug("XDCXApply: invalid token", "token", common.BytesToAddress(data[4:]).Hex()) txs.Pop() continue } @@ -958,15 +963,16 @@ func (env *Work) commitTransactions(mux *event.TypeMux, balanceFee map[common.Ad // // We use the eip155 signer regardless of the current hf. from, _ := types.Sender(env.signer, tx) + hash := tx.Hash() // Check whether the tx is replay protected. If we're not in the EIP155 hf // phase, start ignoring the sender until we do. if tx.Protected() && !env.config.IsEIP155(env.header.Number) { - log.Trace("Ignoring reply protected transaction", "hash", tx.Hash(), "eip155", env.config.EIP155Block) + log.Trace("Ignoring reply protected transaction", "hash", hash, "eip155", env.config.EIP155Block) txs.Pop() continue } // Start executing the transaction - env.state.Prepare(tx.Hash(), common.Hash{}, env.tcount) + env.state.Prepare(hash, common.Hash{}, env.tcount) nonce := env.state.GetNonce(from) if nonce > tx.Nonce() { // New head notification data race between the transaction pool and miner, shift @@ -981,38 +987,43 @@ func (env *Work) commitTransactions(mux *event.TypeMux, balanceFee map[common.Ad continue } err, logs, tokenFeeUsed, gas := env.commitTransaction(balanceFee, tx, bc, coinbase, gp) - switch err { - case core.ErrGasLimitReached: + switch { + case errors.Is(err, core.ErrGasLimitReached): // Pop the current out-of-gas transaction without shifting in the next from the account log.Trace("Gas limit exceeded for current block", "sender", from) txs.Pop() - case core.ErrNonceTooLow: + case errors.Is(err, core.ErrNonceTooLow): // New head notification data race between the transaction pool and miner, shift log.Trace("Skipping transaction with low nonce", "sender", from, "nonce", tx.Nonce()) txs.Shift() - case core.ErrNonceTooHigh: + case errors.Is(err, core.ErrNonceTooHigh): // Reorg notification data race between the transaction pool and miner, skip account = log.Trace("Skipping account with high nonce", "sender", from, "nonce", tx.Nonce()) txs.Pop() - case nil: + case errors.Is(err, nil): // Everything ok, collect the logs and shift in the next transaction from the same account coalescedLogs = append(coalescedLogs, logs...) env.tcount++ txs.Shift() + case errors.Is(err, core.ErrTxTypeNotSupported): + // Pop the unsupported transaction without shifting in the next from the account + log.Trace("Skipping unsupported transaction type", "sender", from, "type", tx.Type()) + txs.Pop() + default: // Strange error, discard the transaction and get the next in line (note, the // nonce-too-high clause will prevent us from executing in vain). - log.Debug("Transaction failed, account skipped", "hash", tx.Hash(), "err", err) + log.Debug("Transaction failed, account skipped", "hash", hash, "err", err) txs.Shift() } if tokenFeeUsed { fee := common.GetGasFee(env.header.Number.Uint64(), gas) - balanceFee[*tx.To()] = new(big.Int).Sub(balanceFee[*tx.To()], fee) - balanceUpdated[*tx.To()] = balanceFee[*tx.To()] + balanceFee[*to] = new(big.Int).Sub(balanceFee[*to], fee) + balanceUpdated[*to] = balanceFee[*to] totalFeeUsed = totalFeeUsed.Add(totalFeeUsed, fee) } } diff --git a/node/api.go b/node/api.go index e03f668b17..8a36115244 100644 --- a/node/api.go +++ b/node/api.go @@ -18,6 +18,7 @@ package node import ( "context" + "errors" "fmt" "strings" "time" @@ -169,7 +170,7 @@ func (api *PrivateAdminAPI) StopRPC() (bool, error) { defer api.node.lock.Unlock() if api.node.httpHandler == nil { - return false, fmt.Errorf("HTTP RPC not running") + return false, errors.New("HTTP RPC not running") } api.node.stopHTTP() return true, nil @@ -223,7 +224,7 @@ func (api *PrivateAdminAPI) StopWS() (bool, error) { defer api.node.lock.Unlock() if api.node.wsHandler == nil { - return false, fmt.Errorf("WebSocket RPC not running") + return false, errors.New("WebSocket RPC not running") } api.node.stopWS() return true, nil diff --git a/node/service_test.go b/node/service_test.go index 86fb1a6fbd..23dc807084 100644 --- a/node/service_test.go +++ b/node/service_test.go @@ -17,6 +17,7 @@ package node import ( + "errors" "fmt" "os" "path/filepath" @@ -70,7 +71,7 @@ func TestContextServices(t *testing.T) { verifier := func(ctx *ServiceContext) (Service, error) { var objA *NoopServiceA if ctx.Service(&objA) != nil { - return nil, fmt.Errorf("former service not found") + return nil, errors.New("former service not found") } var objB *NoopServiceB if err := ctx.Service(&objB); err != ErrServiceUnknown { diff --git a/p2p/discv5/net.go b/p2p/discv5/net.go index c24823967c..097771553d 100644 --- a/p2p/discv5/net.go +++ b/p2p/discv5/net.go @@ -1063,7 +1063,7 @@ func (net *Network) checkPacket(n *Node, ev nodeEvent, pkt *ingressPacket) error case pongPacket: if !bytes.Equal(pkt.data.(*pong).ReplyTok, n.pingEcho) { // fmt.Println("pong reply token mismatch") - return fmt.Errorf("pong reply token mismatch") + return errors.New("pong reply token mismatch") } n.pingEcho = nil } diff --git a/p2p/discv5/ticket.go b/p2p/discv5/ticket.go index 7c82d2db97..2a6d40e244 100644 --- a/p2p/discv5/ticket.go +++ b/p2p/discv5/ticket.go @@ -19,6 +19,7 @@ package discv5 import ( "bytes" "encoding/binary" + "errors" "fmt" "math" "math/rand" @@ -95,7 +96,7 @@ func pongToTicket(localTime mclock.AbsTime, topics []Topic, node *Node, p *ingre return nil, fmt.Errorf("bad wait period list: got %d values, want %d", len(topics), len(wps)) } if rlpHash(topics) != p.data.(*pong).TopicHash { - return nil, fmt.Errorf("bad topic hash") + return nil, errors.New("bad topic hash") } t := &ticket{ issueTime: localTime, diff --git a/p2p/enr/enr.go b/p2p/enr/enr.go index 28c8e26da5..5aca3ab25a 100644 --- a/p2p/enr/enr.go +++ b/p2p/enr/enr.go @@ -275,7 +275,7 @@ func (r *Record) verifySignature() error { if err := r.Load(&entry); err != nil { return err } else if len(entry) != 33 { - return fmt.Errorf("invalid public key") + return errors.New("invalid public key") } // Verify the signature. diff --git a/p2p/nat/natpmp.go b/p2p/nat/natpmp.go index 577a424fbe..df4764cc10 100644 --- a/p2p/nat/natpmp.go +++ b/p2p/nat/natpmp.go @@ -17,12 +17,13 @@ package nat import ( + "errors" "fmt" "net" "strings" "time" - "github.com/jackpal/go-nat-pmp" + natpmp "github.com/jackpal/go-nat-pmp" ) // natPMPClient adapts the NAT-PMP protocol implementation so it conforms to @@ -46,7 +47,7 @@ func (n *pmp) ExternalIP() (net.IP, error) { func (n *pmp) AddMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error { if lifetime <= 0 { - return fmt.Errorf("lifetime must not be <= 0") + return errors.New("lifetime must not be <= 0") } // Note order of port arguments is switched between our // AddMapping and the client's AddPortMapping. diff --git a/p2p/peer.go b/p2p/peer.go index a7eb05621f..1d1cfc8906 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -17,6 +17,7 @@ package p2p import ( + "errors" "fmt" "io" "net" @@ -409,7 +410,7 @@ func (rw *protoRW) WriteMsg(msg Msg) (err error) { // as well but we don't want to rely on that. rw.werr <- err case <-rw.closed: - err = fmt.Errorf("shutting down") + err = errors.New("shutting down") } return err } diff --git a/p2p/protocols/protocol_test.go b/p2p/protocols/protocol_test.go index a4247917c7..bcd3186f6c 100644 --- a/p2p/protocols/protocol_test.go +++ b/p2p/protocols/protocol_test.go @@ -43,7 +43,7 @@ type kill struct { type drop struct { } -/// protoHandshake represents module-independent aspects of the protocol and is +// / protoHandshake represents module-independent aspects of the protocol and is // the first message peers send and receive as part the initial exchange type protoHandshake struct { Version uint // local and remote peer should have identical version @@ -241,7 +241,7 @@ func runModuleHandshake(t *testing.T, resp uint, errs ...error) { } func TestModuleHandshakeError(t *testing.T) { - runModuleHandshake(t, 43, fmt.Errorf("handshake mismatch remote 43 > local 42")) + runModuleHandshake(t, 43, errors.New("handshake mismatch remote 43 > local 42")) } func TestModuleHandshakeSuccess(t *testing.T) { @@ -376,14 +376,14 @@ WAIT: func TestMultiplePeersDropSelf(t *testing.T) { runMultiplePeers(t, 0, - fmt.Errorf("subprotocol error"), - fmt.Errorf("Message handler error: (msg code 3): dropped"), + errors.New("subprotocol error"), + errors.New("Message handler error: (msg code 3): dropped"), ) } func TestMultiplePeersDropOther(t *testing.T) { runMultiplePeers(t, 1, - fmt.Errorf("Message handler error: (msg code 3): dropped"), - fmt.Errorf("subprotocol error"), + errors.New("Message handler error: (msg code 3): dropped"), + errors.New("subprotocol error"), ) } diff --git a/p2p/rlpx.go b/p2p/rlpx.go index 66efd41997..5ceb897eae 100644 --- a/p2p/rlpx.go +++ b/p2p/rlpx.go @@ -147,7 +147,7 @@ func readProtocolHandshake(rw MsgReader, our *protoHandshake) (*protoHandshake, return nil, err } if msg.Size > baseProtocolMaxMsgSize { - return nil, fmt.Errorf("message too big") + return nil, errors.New("message too big") } if msg.Code == discMsg { // Disconnect before protocol handshake is valid according to the @@ -528,9 +528,9 @@ func importPublicKey(pubKey []byte) (*ecies.PublicKey, error) { return nil, fmt.Errorf("invalid public key length %v (expect 64/65)", len(pubKey)) } // TODO: fewer pointless conversions - pub := crypto.ToECDSAPub(pubKey65) - if pub.X == nil { - return nil, fmt.Errorf("invalid public key") + pub, err := crypto.UnmarshalPubkey(pubKey65) + if err != nil { + return nil, err } return ecies.ImportECDSAPublic(pub), nil } diff --git a/p2p/server.go b/p2p/server.go index f20cba6dde..2ccb4cf17a 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -20,7 +20,6 @@ package p2p import ( "crypto/ecdsa" "errors" - "fmt" "net" "sync" "time" @@ -365,7 +364,7 @@ type sharedUDPConn struct { func (s *sharedUDPConn) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) { packet, ok := <-s.unhandled if !ok { - return 0, nil, fmt.Errorf("Connection was closed") + return 0, nil, errors.New("Connection was closed") } l := len(packet.Data) if l > len(b) { @@ -397,7 +396,7 @@ func (srv *Server) Start() (err error) { // static fields if srv.PrivateKey == nil { - return fmt.Errorf("Server.PrivateKey must be set to a non-nil key") + return errors.New("Server.PrivateKey must be set to a non-nil key") } if srv.newTransport == nil { srv.newTransport = newRLPX diff --git a/p2p/testing/protocolsession.go b/p2p/testing/protocolsession.go index da221189b9..2c0133b111 100644 --- a/p2p/testing/protocolsession.go +++ b/p2p/testing/protocolsession.go @@ -273,7 +273,7 @@ func (self *ProtocolSession) TestDisconnected(disconnects ...*Disconnect) error } delete(expects, event.Peer) case <-timeout: - return fmt.Errorf("timed out waiting for peers to disconnect") + return errors.New("timed out waiting for peers to disconnect") } } return nil diff --git a/params/config.go b/params/config.go index ed488836d8..ff2de6a7eb 100644 --- a/params/config.go +++ b/params/config.go @@ -363,6 +363,14 @@ type ChainConfig struct { ByzantiumBlock *big.Int `json:"byzantiumBlock,omitempty"` // Byzantium switch block (nil = no fork, 0 = already on byzantium) ConstantinopleBlock *big.Int `json:"constantinopleBlock,omitempty"` // Constantinople switch block (nil = no fork, 0 = already activated) + PetersburgBlock *big.Int `json:"petersburgBlock,omitempty"` + IstanbulBlock *big.Int `json:"istanbulBlock,omitempty"` + BerlinBlock *big.Int `json:"berlinBlock,omitempty"` + LondonBlock *big.Int `json:"londonBlock,omitempty"` + MergeBlock *big.Int `json:"mergeBlock,omitempty"` + ShanghaiBlock *big.Int `json:"shanghaiBlock,omitempty"` + Eip1559Block *big.Int `json:"eip1559Block,omitempty"` + // Various consensus engines Ethash *EthashConfig `json:"ethash,omitempty"` Clique *CliqueConfig `json:"clique,omitempty"` @@ -552,37 +560,37 @@ func (c *ChainConfig) IsConstantinople(num *big.Int) bool { // - equal to or greater than the PetersburgBlock fork block, // - OR is nil, and Constantinople is active func (c *ChainConfig) IsPetersburg(num *big.Int) bool { - return isForked(common.TIPXDCXCancellationFee, num) + return isForked(common.TIPXDCXCancellationFee, num) || isForked(c.PetersburgBlock, num) } // IsIstanbul returns whether num is either equal to the Istanbul fork block or greater. func (c *ChainConfig) IsIstanbul(num *big.Int) bool { - return isForked(common.TIPXDCXCancellationFee, num) + return isForked(common.TIPXDCXCancellationFee, num) || isForked(c.IstanbulBlock, num) } // IsBerlin returns whether num is either equal to the Berlin fork block or greater. func (c *ChainConfig) IsBerlin(num *big.Int) bool { - return isForked(common.BerlinBlock, num) + return isForked(common.BerlinBlock, num) || isForked(c.BerlinBlock, num) } // IsLondon returns whether num is either equal to the London fork block or greater. func (c *ChainConfig) IsLondon(num *big.Int) bool { - return isForked(common.LondonBlock, num) + return isForked(common.LondonBlock, num) || isForked(c.LondonBlock, num) } // IsMerge returns whether num is either equal to the Merge fork block or greater. // Different from Geth which uses `block.difficulty != nil` func (c *ChainConfig) IsMerge(num *big.Int) bool { - return isForked(common.MergeBlock, num) + return isForked(common.MergeBlock, num) || isForked(c.MergeBlock, num) } // IsShanghai returns whether num is either equal to the Shanghai fork block or greater. func (c *ChainConfig) IsShanghai(num *big.Int) bool { - return isForked(common.ShanghaiBlock, num) + return isForked(common.ShanghaiBlock, num) || isForked(c.ShanghaiBlock, num) } func (c *ChainConfig) IsEIP1559(num *big.Int) bool { - return isForked(common.Eip1559Block, num) + return isForked(common.Eip1559Block, num) || isForked(c.Eip1559Block, num) } func (c *ChainConfig) IsTIP2019(num *big.Int) bool { diff --git a/params/protocol_params.go b/params/protocol_params.go index 93419fc367..b8cd64f723 100644 --- a/params/protocol_params.go +++ b/params/protocol_params.go @@ -61,6 +61,10 @@ const ( CreateGas uint64 = 32000 // Once per CREATE operation & contract-creation transaction. SuicideRefundGas uint64 = 24000 // Refunded following a suicide operation. MemoryGas uint64 = 3 // Times the address of the (highest referenced byte in memory + 1). NOTE: referencing happens on read, write and in instructions such as RETURN and CALL. + + TxAccessListAddressGas uint64 = 2400 // Per address specified in EIP 2930 access list + TxAccessListStorageKeyGas uint64 = 1900 // Per storage key specified in EIP 2930 access list + TxDataNonZeroGas uint64 = 68 // Per byte of data attached to a transaction that is not equal to zero. NOTE: Not payable on data of calls between transactions. MaxCodeSize = 24576 // Maximum bytecode to permit for a contract diff --git a/params/version.go b/params/version.go index 41efb96e34..6e9d48d38d 100644 --- a/params/version.go +++ b/params/version.go @@ -22,7 +22,7 @@ import ( const ( VersionMajor = 2 // Major version component of the current release - VersionMinor = 2 // Minor version component of the current release + VersionMinor = 3 // Minor version component of the current release VersionPatch = 0 // Patch version component of the current release VersionMeta = "beta1" // Version metadata to append to the version string ) diff --git a/rlp/decode.go b/rlp/decode.go index 60d9dab2b5..1d32ff8523 100644 --- a/rlp/decode.go +++ b/rlp/decode.go @@ -26,103 +26,80 @@ import ( "math/big" "reflect" "strings" + "sync" + + "github.com/XinFinOrg/XDPoSChain/rlp/internal/rlpstruct" + "github.com/holiman/uint256" ) +//lint:ignore ST1012 EOL is not an error. + +// EOL is returned when the end of the current list +// has been reached during streaming. +var EOL = errors.New("rlp: end of list") + var ( + ErrExpectedString = errors.New("rlp: expected String or Byte") + ErrExpectedList = errors.New("rlp: expected List") + ErrCanonInt = errors.New("rlp: non-canonical integer format") + ErrCanonSize = errors.New("rlp: non-canonical size information") + ErrElemTooLarge = errors.New("rlp: element is larger than containing list") + ErrValueTooLarge = errors.New("rlp: value size exceeds available input length") + ErrMoreThanOneValue = errors.New("rlp: input contains more than one value") + + // internal errors + errNotInList = errors.New("rlp: call of ListEnd outside of any list") + errNotAtEOL = errors.New("rlp: call of ListEnd not positioned at EOL") + errUintOverflow = errors.New("rlp: uint overflow") errNoPointer = errors.New("rlp: interface given to Decode must be a pointer") errDecodeIntoNil = errors.New("rlp: pointer given to Decode must not be nil") + errUint256Large = errors.New("rlp: value too large for uint256") + + streamPool = sync.Pool{ + New: func() interface{} { return new(Stream) }, + } ) -// Decoder is implemented by types that require custom RLP -// decoding rules or need to decode into private fields. +// Decoder is implemented by types that require custom RLP decoding rules or need to decode +// into private fields. // -// The DecodeRLP method should read one value from the given -// Stream. It is not forbidden to read less or more, but it might -// be confusing. +// The DecodeRLP method should read one value from the given Stream. It is not forbidden to +// read less or more, but it might be confusing. type Decoder interface { DecodeRLP(*Stream) error } -// Decode parses RLP-encoded data from r and stores the result in the -// value pointed to by val. Val must be a non-nil pointer. If r does -// not implement ByteReader, Decode will do its own buffering. +// Decode parses RLP-encoded data from r and stores the result in the value pointed to by +// val. Please see package-level documentation for the decoding rules. Val must be a +// non-nil pointer. // -// Decode uses the following type-dependent decoding rules: +// If r does not implement ByteReader, Decode will do its own buffering. // -// If the type implements the Decoder interface, decode calls -// DecodeRLP. +// Note that Decode does not set an input limit for all readers and may be vulnerable to +// panics cause by huge value sizes. If you need an input limit, use // -// To decode into a pointer, Decode will decode into the value pointed -// to. If the pointer is nil, a new value of the pointer's element -// type is allocated. If the pointer is non-nil, the existing value -// will be reused. -// -// To decode into a struct, Decode expects the input to be an RLP -// list. The decoded elements of the list are assigned to each public -// field in the order given by the struct's definition. The input list -// must contain an element for each decoded field. Decode returns an -// error if there are too few or too many elements. -// -// The decoding of struct fields honours certain struct tags, "tail", -// "nil" and "-". -// -// The "-" tag ignores fields. -// -// For an explanation of "tail", see the example. -// -// The "nil" tag applies to pointer-typed fields and changes the decoding -// rules for the field such that input values of size zero decode as a nil -// pointer. This tag can be useful when decoding recursive types. -// -// type StructWithEmptyOK struct { -// Foo *[20]byte `rlp:"nil"` -// } -// -// To decode into a slice, the input must be a list and the resulting -// slice will contain the input elements in order. For byte slices, -// the input must be an RLP string. Array types decode similarly, with -// the additional restriction that the number of input elements (or -// bytes) must match the array's length. -// -// To decode into a Go string, the input must be an RLP string. The -// input bytes are taken as-is and will not necessarily be valid UTF-8. -// -// To decode into an unsigned integer type, the input must also be an RLP -// string. The bytes are interpreted as a big endian representation of -// the integer. If the RLP string is larger than the bit size of the -// type, Decode will return an error. Decode also supports *big.Int. -// There is no size limit for big integers. -// -// To decode into an interface value, Decode stores one of these -// in the value: -// -// []interface{}, for RLP lists -// []byte, for RLP strings -// -// Non-empty interface types are not supported, nor are booleans, -// signed integers, floating point numbers, maps, channels and -// functions. -// -// Note that Decode does not set an input limit for all readers -// and may be vulnerable to panics cause by huge value sizes. If -// you need an input limit, use -// -// NewStream(r, limit).Decode(val) +// NewStream(r, limit).Decode(val) func Decode(r io.Reader, val interface{}) error { - // TODO: this could use a Stream from a pool. - return NewStream(r, 0).Decode(val) + stream := streamPool.Get().(*Stream) + defer streamPool.Put(stream) + + stream.Reset(r, 0) + return stream.Decode(val) } -// DecodeBytes parses RLP data from b into val. -// Please see the documentation of Decode for the decoding rules. -// The input must contain exactly one value and no trailing data. +// DecodeBytes parses RLP data from b into val. Please see package-level documentation for +// the decoding rules. The input must contain exactly one value and no trailing data. func DecodeBytes(b []byte, val interface{}) error { - // TODO: this could use a Stream from a pool. - r := bytes.NewReader(b) - if err := NewStream(r, uint64(len(b))).Decode(val); err != nil { + r := (*sliceReader)(&b) + + stream := streamPool.Get().(*Stream) + defer streamPool.Put(stream) + + stream.Reset(r, uint64(len(b))) + if err := stream.Decode(val); err != nil { return err } - if r.Len() > 0 { + if len(b) > 0 { return ErrMoreThanOneValue } return nil @@ -173,21 +150,26 @@ func addErrorContext(err error, ctx string) error { var ( decoderInterface = reflect.TypeOf(new(Decoder)).Elem() bigInt = reflect.TypeOf(big.Int{}) + u256Int = reflect.TypeOf(uint256.Int{}) ) -func makeDecoder(typ reflect.Type, tags tags) (dec decoder, err error) { +func makeDecoder(typ reflect.Type, tags rlpstruct.Tags) (dec decoder, err error) { kind := typ.Kind() switch { case typ == rawValueType: return decodeRawValue, nil - case typ.Implements(decoderInterface): - return decodeDecoder, nil - case kind != reflect.Ptr && reflect.PtrTo(typ).Implements(decoderInterface): - return decodeDecoderNoPtr, nil - case typ.AssignableTo(reflect.PtrTo(bigInt)): + case typ.AssignableTo(reflect.PointerTo(bigInt)): return decodeBigInt, nil case typ.AssignableTo(bigInt): return decodeBigIntNoPtr, nil + case typ == reflect.PointerTo(u256Int): + return decodeU256, nil + case typ == u256Int: + return decodeU256NoPtr, nil + case kind == reflect.Ptr: + return makePtrDecoder(typ, tags) + case reflect.PointerTo(typ).Implements(decoderInterface): + return decodeDecoder, nil case isUint(kind): return decodeUint, nil case kind == reflect.Bool: @@ -198,11 +180,6 @@ func makeDecoder(typ reflect.Type, tags tags) (dec decoder, err error) { return makeListDecoder(typ, tags) case kind == reflect.Struct: return makeStructDecoder(typ) - case kind == reflect.Ptr: - if tags.nilOK { - return makeOptionalPtrDecoder(typ) - } - return makePtrDecoder(typ) case kind == reflect.Interface: return decodeInterface, nil default: @@ -252,35 +229,48 @@ func decodeBigIntNoPtr(s *Stream, val reflect.Value) error { } func decodeBigInt(s *Stream, val reflect.Value) error { - b, err := s.Bytes() - if err != nil { - return wrapStreamError(err, val.Type()) - } i := val.Interface().(*big.Int) if i == nil { i = new(big.Int) val.Set(reflect.ValueOf(i)) } - // Reject leading zero bytes - if len(b) > 0 && b[0] == 0 { - return wrapStreamError(ErrCanonInt, val.Type()) + + err := s.decodeBigInt(i) + if err != nil { + return wrapStreamError(err, val.Type()) } - i.SetBytes(b) return nil } -func makeListDecoder(typ reflect.Type, tag tags) (decoder, error) { +func decodeU256NoPtr(s *Stream, val reflect.Value) error { + return decodeU256(s, val.Addr()) +} + +func decodeU256(s *Stream, val reflect.Value) error { + i := val.Interface().(*uint256.Int) + if i == nil { + i = new(uint256.Int) + val.Set(reflect.ValueOf(i)) + } + + err := s.ReadUint256(i) + if err != nil { + return wrapStreamError(err, val.Type()) + } + return nil +} + +func makeListDecoder(typ reflect.Type, tag rlpstruct.Tags) (decoder, error) { etype := typ.Elem() - if etype.Kind() == reflect.Uint8 && !reflect.PtrTo(etype).Implements(decoderInterface) { + if etype.Kind() == reflect.Uint8 && !reflect.PointerTo(etype).Implements(decoderInterface) { if typ.Kind() == reflect.Array { return decodeByteArray, nil - } else { - return decodeByteSlice, nil } + return decodeByteSlice, nil } - etypeinfo, err := cachedTypeInfo1(etype, tags{}) - if err != nil { - return nil, err + etypeinfo := theTC.infoWhileGenerating(etype, rlpstruct.Tags{}) + if etypeinfo.decoderErr != nil { + return nil, etypeinfo.decoderErr } var dec decoder switch { @@ -288,7 +278,7 @@ func makeListDecoder(typ reflect.Type, tag tags) (decoder, error) { dec = func(s *Stream, val reflect.Value) error { return decodeListArray(s, val, etypeinfo.decoder) } - case tag.tail: + case tag.Tail: // A slice with "tail" tag can occur as the last field // of a struct and is supposed to swallow all remaining // list elements. The struct decoder already called s.List, @@ -381,25 +371,23 @@ func decodeByteArray(s *Stream, val reflect.Value) error { if err != nil { return err } - vlen := val.Len() + slice := byteArrayBytes(val, val.Len()) switch kind { case Byte: - if vlen == 0 { + if len(slice) == 0 { return &decodeError{msg: "input string too long", typ: val.Type()} - } - if vlen > 1 { + } else if len(slice) > 1 { return &decodeError{msg: "input string too short", typ: val.Type()} } - bv, _ := s.Uint() - val.Index(0).SetUint(bv) + slice[0] = s.byteval + s.kind = -1 case String: - if uint64(vlen) < size { + if uint64(len(slice)) < size { return &decodeError{msg: "input string too long", typ: val.Type()} } - if uint64(vlen) > size { + if uint64(len(slice)) > size { return &decodeError{msg: "input string too short", typ: val.Type()} } - slice := val.Slice(0, vlen).Interface().([]byte) if err := s.readFull(slice); err != nil { return err } @@ -418,13 +406,25 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) { if err != nil { return nil, err } + for _, f := range fields { + if f.info.decoderErr != nil { + return nil, structFieldError{typ, f.index, f.info.decoderErr} + } + } dec := func(s *Stream, val reflect.Value) (err error) { if _, err := s.List(); err != nil { return wrapStreamError(err, typ) } - for _, f := range fields { + for i, f := range fields { err := f.info.decoder(s, val.Field(f.index)) if err == EOL { + if f.optional { + // The field is optional, so reaching the end of the list before + // reaching the last field is acceptable. All remaining undecoded + // fields are zeroed. + zeroFields(val, fields[i:]) + break + } return &decodeError{msg: "too few elements", typ: typ} } else if err != nil { return addErrorContext(err, "."+typ.Field(f.index).Name) @@ -435,15 +435,29 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) { return dec, nil } -// makePtrDecoder creates a decoder that decodes into -// the pointer's element type. -func makePtrDecoder(typ reflect.Type) (decoder, error) { - etype := typ.Elem() - etypeinfo, err := cachedTypeInfo1(etype, tags{}) - if err != nil { - return nil, err +func zeroFields(structval reflect.Value, fields []field) { + for _, f := range fields { + fv := structval.Field(f.index) + fv.Set(reflect.Zero(fv.Type())) } - dec := func(s *Stream, val reflect.Value) (err error) { +} + +// makePtrDecoder creates a decoder that decodes into the pointer's element type. +func makePtrDecoder(typ reflect.Type, tag rlpstruct.Tags) (decoder, error) { + etype := typ.Elem() + etypeinfo := theTC.infoWhileGenerating(etype, rlpstruct.Tags{}) + switch { + case etypeinfo.decoderErr != nil: + return nil, etypeinfo.decoderErr + case !tag.NilOK: + return makeSimplePtrDecoder(etype, etypeinfo), nil + default: + return makeNilPtrDecoder(etype, etypeinfo, tag), nil + } +} + +func makeSimplePtrDecoder(etype reflect.Type, etypeinfo *typeinfo) decoder { + return func(s *Stream, val reflect.Value) (err error) { newval := val if val.IsNil() { newval = reflect.New(etype) @@ -453,30 +467,39 @@ func makePtrDecoder(typ reflect.Type) (decoder, error) { } return err } - return dec, nil } -// makeOptionalPtrDecoder creates a decoder that decodes empty values -// as nil. Non-empty values are decoded into a value of the element type, -// just like makePtrDecoder does. +// makeNilPtrDecoder creates a decoder that decodes empty values as nil. Non-empty +// values are decoded into a value of the element type, just like makePtrDecoder does. // // This decoder is used for pointer-typed struct fields with struct tag "nil". -func makeOptionalPtrDecoder(typ reflect.Type) (decoder, error) { - etype := typ.Elem() - etypeinfo, err := cachedTypeInfo1(etype, tags{}) - if err != nil { - return nil, err - } - dec := func(s *Stream, val reflect.Value) (err error) { +func makeNilPtrDecoder(etype reflect.Type, etypeinfo *typeinfo, ts rlpstruct.Tags) decoder { + typ := reflect.PointerTo(etype) + nilPtr := reflect.Zero(typ) + + // Determine the value kind that results in nil pointer. + nilKind := typeNilKind(etype, ts) + + return func(s *Stream, val reflect.Value) (err error) { kind, size, err := s.Kind() - if err != nil || size == 0 && kind != Byte { + if err != nil { + val.Set(nilPtr) + return wrapStreamError(err, typ) + } + // Handle empty values as a nil pointer. + if kind != Byte && size == 0 { + if kind != nilKind { + return &decodeError{ + msg: fmt.Sprintf("wrong kind of empty value (got %v, want %v)", kind, nilKind), + typ: typ, + } + } // rearm s.Kind. This is important because the input // position must advance to the next value even though // we don't read anything. s.kind = -1 - // set the pointer to nil. - val.Set(reflect.Zero(typ)) - return err + val.Set(nilPtr) + return nil } newval := val if val.IsNil() { @@ -487,7 +510,6 @@ func makeOptionalPtrDecoder(typ reflect.Type) (decoder, error) { } return err } - return dec, nil } var ifsliceType = reflect.TypeOf([]interface{}{}) @@ -516,25 +538,12 @@ func decodeInterface(s *Stream, val reflect.Value) error { return nil } -// This decoder is used for non-pointer values of types -// that implement the Decoder interface using a pointer receiver. -func decodeDecoderNoPtr(s *Stream, val reflect.Value) error { +func decodeDecoder(s *Stream, val reflect.Value) error { return val.Addr().Interface().(Decoder).DecodeRLP(s) } -func decodeDecoder(s *Stream, val reflect.Value) error { - // Decoder instances are not handled using the pointer rule if the type - // implements Decoder with pointer receiver (i.e. always) - // because it might handle empty values specially. - // We need to allocate one here in this case, like makePtrDecoder does. - if val.Kind() == reflect.Ptr && val.IsNil() { - val.Set(reflect.New(val.Type().Elem())) - } - return val.Interface().(Decoder).DecodeRLP(s) -} - // Kind represents the kind of value contained in an RLP stream. -type Kind int +type Kind int8 const ( Byte Kind = iota @@ -555,29 +564,6 @@ func (k Kind) String() string { } } -var ( - // EOL is returned when the end of the current list - // has been reached during streaming. - EOL = errors.New("rlp: end of list") - - // Actual Errors - ErrExpectedString = errors.New("rlp: expected String or Byte") - ErrExpectedList = errors.New("rlp: expected List") - ErrCanonInt = errors.New("rlp: non-canonical integer format") - ErrCanonSize = errors.New("rlp: non-canonical size information") - ErrElemTooLarge = errors.New("rlp: element is larger than containing list") - ErrValueTooLarge = errors.New("rlp: value size exceeds available input length") - - // This error is reported by DecodeBytes if the slice contains - // additional data after the first RLP value. - ErrMoreThanOneValue = errors.New("rlp: input contains more than one value") - - // internal errors - errNotInList = errors.New("rlp: call of ListEnd outside of any list") - errNotAtEOL = errors.New("rlp: call of ListEnd not positioned at EOL") - errUintOverflow = errors.New("rlp: uint overflow") -) - // ByteReader must be implemented by any input reader for a Stream. It // is implemented by e.g. bufio.Reader and bytes.Reader. type ByteReader interface { @@ -600,22 +586,16 @@ type ByteReader interface { type Stream struct { r ByteReader - // number of bytes remaining to be read from r. - remaining uint64 - limited bool - - // auxiliary buffer for integer decoding - uintbuf []byte - - kind Kind // kind of value ahead - size uint64 // size of value ahead - byteval byte // value of single byte in type tag - kinderr error // error from last readKind - stack []listpos + remaining uint64 // number of bytes remaining to be read from r + size uint64 // size of value ahead + kinderr error // error from last readKind + stack []uint64 // list sizes + uintbuf [32]byte // auxiliary buffer for integer decoding + kind Kind // kind of value ahead + byteval byte // value of single byte in type tag + limited bool // true if input limit is in effect } -type listpos struct{ pos, size uint64 } - // NewStream creates a new decoding stream reading from r. // // If r implements the ByteReader interface, Stream will @@ -675,6 +655,37 @@ func (s *Stream) Bytes() ([]byte, error) { } } +// ReadBytes decodes the next RLP value and stores the result in b. +// The value size must match len(b) exactly. +func (s *Stream) ReadBytes(b []byte) error { + kind, size, err := s.Kind() + if err != nil { + return err + } + switch kind { + case Byte: + if len(b) != 1 { + return fmt.Errorf("input value has wrong size 1, want %d", len(b)) + } + b[0] = s.byteval + s.kind = -1 // rearm Kind + return nil + case String: + if uint64(len(b)) != size { + return fmt.Errorf("input value has wrong size %d, want %d", size, len(b)) + } + if err = s.readFull(b); err != nil { + return err + } + if size == 1 && b[0] < 128 { + return ErrCanonSize + } + return nil + default: + return ErrExpectedString + } +} + // Raw reads a raw encoded value including RLP type information. func (s *Stream) Raw() ([]byte, error) { kind, size, err := s.Kind() @@ -685,8 +696,8 @@ func (s *Stream) Raw() ([]byte, error) { s.kind = -1 // rearm Kind return []byte{s.byteval}, nil } - // the original header has already been read and is no longer - // available. read content and put a new header in front of it. + // The original header has already been read and is no longer + // available. Read content and put a new header in front of it. start := headsize(size) buf := make([]byte, uint64(start)+size) if err := s.readFull(buf[start:]); err != nil { @@ -703,10 +714,31 @@ func (s *Stream) Raw() ([]byte, error) { // Uint reads an RLP string of up to 8 bytes and returns its contents // as an unsigned integer. If the input does not contain an RLP string, the // returned error will be ErrExpectedString. +// +// Deprecated: use s.Uint64 instead. func (s *Stream) Uint() (uint64, error) { return s.uint(64) } +func (s *Stream) Uint64() (uint64, error) { + return s.uint(64) +} + +func (s *Stream) Uint32() (uint32, error) { + i, err := s.uint(32) + return uint32(i), err +} + +func (s *Stream) Uint16() (uint16, error) { + i, err := s.uint(16) + return uint16(i), err +} + +func (s *Stream) Uint8() (uint8, error) { + i, err := s.uint(8) + return uint8(i), err +} + func (s *Stream) uint(maxbits int) (uint64, error) { kind, size, err := s.Kind() if err != nil { @@ -769,7 +801,14 @@ func (s *Stream) List() (size uint64, err error) { if kind != List { return 0, ErrExpectedList } - s.stack = append(s.stack, listpos{0, size}) + + // Remove size of inner list from outer list before pushing the new size + // onto the stack. This ensures that the remaining outer list size will + // be correct after the matching call to ListEnd. + if inList, limit := s.listLimit(); inList { + s.stack[len(s.stack)-1] = limit - size + } + s.stack = append(s.stack, size) s.kind = -1 s.size = 0 return size, nil @@ -778,22 +817,116 @@ func (s *Stream) List() (size uint64, err error) { // ListEnd returns to the enclosing list. // The input reader must be positioned at the end of a list. func (s *Stream) ListEnd() error { - if len(s.stack) == 0 { + // Ensure that no more data is remaining in the current list. + if inList, listLimit := s.listLimit(); !inList { return errNotInList - } - tos := s.stack[len(s.stack)-1] - if tos.pos != tos.size { + } else if listLimit > 0 { return errNotAtEOL } s.stack = s.stack[:len(s.stack)-1] // pop - if len(s.stack) > 0 { - s.stack[len(s.stack)-1].pos += tos.size - } s.kind = -1 s.size = 0 return nil } +// MoreDataInList reports whether the current list context contains +// more data to be read. +func (s *Stream) MoreDataInList() bool { + _, listLimit := s.listLimit() + return listLimit > 0 +} + +// BigInt decodes an arbitrary-size integer value. +func (s *Stream) BigInt() (*big.Int, error) { + i := new(big.Int) + if err := s.decodeBigInt(i); err != nil { + return nil, err + } + return i, nil +} + +func (s *Stream) decodeBigInt(dst *big.Int) error { + var buffer []byte + kind, size, err := s.Kind() + switch { + case err != nil: + return err + case kind == List: + return ErrExpectedString + case kind == Byte: + buffer = s.uintbuf[:1] + buffer[0] = s.byteval + s.kind = -1 // re-arm Kind + case size == 0: + // Avoid zero-length read. + s.kind = -1 + case size <= uint64(len(s.uintbuf)): + // For integers smaller than s.uintbuf, allocating a buffer + // can be avoided. + buffer = s.uintbuf[:size] + if err := s.readFull(buffer); err != nil { + return err + } + // Reject inputs where single byte encoding should have been used. + if size == 1 && buffer[0] < 128 { + return ErrCanonSize + } + default: + // For large integers, a temporary buffer is needed. + buffer = make([]byte, size) + if err := s.readFull(buffer); err != nil { + return err + } + } + + // Reject leading zero bytes. + if len(buffer) > 0 && buffer[0] == 0 { + return ErrCanonInt + } + // Set the integer bytes. + dst.SetBytes(buffer) + return nil +} + +// ReadUint256 decodes the next value as a uint256. +func (s *Stream) ReadUint256(dst *uint256.Int) error { + var buffer []byte + kind, size, err := s.Kind() + switch { + case err != nil: + return err + case kind == List: + return ErrExpectedString + case kind == Byte: + buffer = s.uintbuf[:1] + buffer[0] = s.byteval + s.kind = -1 // re-arm Kind + case size == 0: + // Avoid zero-length read. + s.kind = -1 + case size <= uint64(len(s.uintbuf)): + // All possible uint256 values fit into s.uintbuf. + buffer = s.uintbuf[:size] + if err := s.readFull(buffer); err != nil { + return err + } + // Reject inputs where single byte encoding should have been used. + if size == 1 && buffer[0] < 128 { + return ErrCanonSize + } + default: + return errUint256Large + } + + // Reject leading zero bytes. + if len(buffer) > 0 && buffer[0] == 0 { + return ErrCanonInt + } + // Set the integer bytes. + dst.SetBytes(buffer) + return nil +} + // Decode decodes a value and stores the result in the value pointed // to by val. Please see the documentation for the Decode function // to learn about the decoding rules. @@ -809,14 +942,14 @@ func (s *Stream) Decode(val interface{}) error { if rval.IsNil() { return errDecodeIntoNil } - info, err := cachedTypeInfo(rtyp.Elem(), tags{}) + decoder, err := cachedDecoder(rtyp.Elem()) if err != nil { return err } - err = info.decoder(s, rval.Elem()) + err = decoder(s, rval.Elem()) if decErr, ok := err.(*decodeError); ok && len(decErr.ctx) > 0 { - // add decode target type to error so context has more meaning + // Add decode target type to error so context has more meaning. decErr.ctx = append(decErr.ctx, fmt.Sprint("(", rtyp.Elem(), ")")) } return err @@ -839,6 +972,9 @@ func (s *Stream) Reset(r io.Reader, inputLimit uint64) { case *bytes.Reader: s.remaining = uint64(br.Len()) s.limited = true + case *bytes.Buffer: + s.remaining = uint64(br.Len()) + s.limited = true case *strings.Reader: s.remaining = uint64(br.Len()) s.limited = true @@ -857,9 +993,8 @@ func (s *Stream) Reset(r io.Reader, inputLimit uint64) { s.size = 0 s.kind = -1 s.kinderr = nil - if s.uintbuf == nil { - s.uintbuf = make([]byte, 8) - } + s.byteval = 0 + s.uintbuf = [32]byte{} } // Kind returns the kind and size of the next value in the @@ -874,35 +1009,29 @@ func (s *Stream) Reset(r io.Reader, inputLimit uint64) { // the value. Subsequent calls to Kind (until the value is decoded) // will not advance the input reader and return cached information. func (s *Stream) Kind() (kind Kind, size uint64, err error) { - var tos *listpos - if len(s.stack) > 0 { - tos = &s.stack[len(s.stack)-1] + if s.kind >= 0 { + return s.kind, s.size, s.kinderr } - if s.kind < 0 { - s.kinderr = nil - // Don't read further if we're at the end of the - // innermost list. - if tos != nil && tos.pos == tos.size { - return 0, 0, EOL - } - s.kind, s.size, s.kinderr = s.readKind() - if s.kinderr == nil { - if tos == nil { - // At toplevel, check that the value is smaller - // than the remaining input length. - if s.limited && s.size > s.remaining { - s.kinderr = ErrValueTooLarge - } - } else { - // Inside a list, check that the value doesn't overflow the list. - if s.size > tos.size-tos.pos { - s.kinderr = ErrElemTooLarge - } - } + + // Check for end of list. This needs to be done here because readKind + // checks against the list size, and would return the wrong error. + inList, listLimit := s.listLimit() + if inList && listLimit == 0 { + return 0, 0, EOL + } + // Read the actual size tag. + s.kind, s.size, s.kinderr = s.readKind() + if s.kinderr == nil { + // Check the data size of the value ahead against input limits. This + // is done here because many decoders require allocating an input + // buffer matching the value size. Checking it here protects those + // decoders from inputs declaring very large value size. + if inList && s.size > listLimit { + s.kinderr = ErrElemTooLarge + } else if s.limited && s.size > s.remaining { + s.kinderr = ErrValueTooLarge } } - // Note: this might return a sticky error generated - // by an earlier call to readKind. return s.kind, s.size, s.kinderr } @@ -929,37 +1058,35 @@ func (s *Stream) readKind() (kind Kind, size uint64, err error) { s.byteval = b return Byte, 0, nil case b < 0xB8: - // Otherwise, if a string is 0-55 bytes long, - // the RLP encoding consists of a single byte with value 0x80 plus the - // length of the string followed by the string. The range of the first - // byte is thus [0x80, 0xB7]. + // Otherwise, if a string is 0-55 bytes long, the RLP encoding consists + // of a single byte with value 0x80 plus the length of the string + // followed by the string. The range of the first byte is thus [0x80, 0xB7]. return String, uint64(b - 0x80), nil case b < 0xC0: - // If a string is more than 55 bytes long, the - // RLP encoding consists of a single byte with value 0xB7 plus the length - // of the length of the string in binary form, followed by the length of - // the string, followed by the string. For example, a length-1024 string - // would be encoded as 0xB90400 followed by the string. The range of - // the first byte is thus [0xB8, 0xBF]. + // If a string is more than 55 bytes long, the RLP encoding consists of a + // single byte with value 0xB7 plus the length of the length of the + // string in binary form, followed by the length of the string, followed + // by the string. For example, a length-1024 string would be encoded as + // 0xB90400 followed by the string. The range of the first byte is thus + // [0xB8, 0xBF]. size, err = s.readUint(b - 0xB7) if err == nil && size < 56 { err = ErrCanonSize } return String, size, err case b < 0xF8: - // If the total payload of a list - // (i.e. the combined length of all its items) is 0-55 bytes long, the - // RLP encoding consists of a single byte with value 0xC0 plus the length - // of the list followed by the concatenation of the RLP encodings of the - // items. The range of the first byte is thus [0xC0, 0xF7]. + // If the total payload of a list (i.e. the combined length of all its + // items) is 0-55 bytes long, the RLP encoding consists of a single byte + // with value 0xC0 plus the length of the list followed by the + // concatenation of the RLP encodings of the items. The range of the + // first byte is thus [0xC0, 0xF7]. return List, uint64(b - 0xC0), nil default: - // If the total payload of a list is more than 55 bytes long, - // the RLP encoding consists of a single byte with value 0xF7 - // plus the length of the length of the payload in binary - // form, followed by the length of the payload, followed by - // the concatenation of the RLP encodings of the items. The - // range of the first byte is thus [0xF8, 0xFF]. + // If the total payload of a list is more than 55 bytes long, the RLP + // encoding consists of a single byte with value 0xF7 plus the length of + // the length of the payload in binary form, followed by the length of + // the payload, followed by the concatenation of the RLP encodings of + // the items. The range of the first byte is thus [0xF8, 0xFF]. size, err = s.readUint(b - 0xF7) if err == nil && size < 56 { err = ErrCanonSize @@ -977,23 +1104,22 @@ func (s *Stream) readUint(size byte) (uint64, error) { b, err := s.readByte() return uint64(b), err default: + buffer := s.uintbuf[:8] + clear(buffer) start := int(8 - size) - for i := 0; i < start; i++ { - s.uintbuf[i] = 0 - } - if err := s.readFull(s.uintbuf[start:]); err != nil { + if err := s.readFull(buffer[start:]); err != nil { return 0, err } - if s.uintbuf[start] == 0 { - // Note: readUint is also used to decode integer - // values. The error needs to be adjusted to become - // ErrCanonInt in this case. + if buffer[start] == 0 { + // Note: readUint is also used to decode integer values. + // The error needs to be adjusted to become ErrCanonInt in this case. return 0, ErrCanonSize } - return binary.BigEndian.Uint64(s.uintbuf), nil + return binary.BigEndian.Uint64(buffer[:]), nil } } +// readFull reads into buf from the underlying stream. func (s *Stream) readFull(buf []byte) (err error) { if err := s.willRead(uint64(len(buf))); err != nil { return err @@ -1004,11 +1130,18 @@ func (s *Stream) readFull(buf []byte) (err error) { n += nn } if err == io.EOF { - err = io.ErrUnexpectedEOF + if n < len(buf) { + err = io.ErrUnexpectedEOF + } else { + // Readers are allowed to give EOF even though the read succeeded. + // In such cases, we discard the EOF, like io.ReadFull() does. + err = nil + } } return err } +// readByte reads a single byte from the underlying stream. func (s *Stream) readByte() (byte, error) { if err := s.willRead(1); err != nil { return 0, err @@ -1020,16 +1153,16 @@ func (s *Stream) readByte() (byte, error) { return b, err } +// willRead is called before any read from the underlying stream. It checks +// n against size limits, and updates the limits if n doesn't overflow them. func (s *Stream) willRead(n uint64) error { s.kind = -1 // rearm Kind - if len(s.stack) > 0 { - // check list overflow - tos := s.stack[len(s.stack)-1] - if n > tos.size-tos.pos { + if inList, limit := s.listLimit(); inList { + if n > limit { return ErrElemTooLarge } - s.stack[len(s.stack)-1].pos += n + s.stack[len(s.stack)-1] = limit - n } if s.limited { if n > s.remaining { @@ -1039,3 +1172,31 @@ func (s *Stream) willRead(n uint64) error { } return nil } + +// listLimit returns the amount of data remaining in the innermost list. +func (s *Stream) listLimit() (inList bool, limit uint64) { + if len(s.stack) == 0 { + return false, 0 + } + return true, s.stack[len(s.stack)-1] +} + +type sliceReader []byte + +func (sr *sliceReader) Read(b []byte) (int, error) { + if len(*sr) == 0 { + return 0, io.EOF + } + n := copy(b, *sr) + *sr = (*sr)[n:] + return n, nil +} + +func (sr *sliceReader) ReadByte() (byte, error) { + if len(*sr) == 0 { + return 0, io.EOF + } + b := (*sr)[0] + *sr = (*sr)[1:] + return b, nil +} diff --git a/rlp/decode_test.go b/rlp/decode_test.go index 4d8abd0012..38cca38aa5 100644 --- a/rlp/decode_test.go +++ b/rlp/decode_test.go @@ -26,6 +26,9 @@ import ( "reflect" "strings" "testing" + + "github.com/XinFinOrg/XDPoSChain/common/math" + "github.com/holiman/uint256" ) func TestStreamKind(t *testing.T) { @@ -284,6 +287,47 @@ func TestStreamRaw(t *testing.T) { } } +func TestStreamReadBytes(t *testing.T) { + tests := []struct { + input string + size int + err string + }{ + // kind List + {input: "C0", size: 1, err: "rlp: expected String or Byte"}, + // kind Byte + {input: "04", size: 0, err: "input value has wrong size 1, want 0"}, + {input: "04", size: 1}, + {input: "04", size: 2, err: "input value has wrong size 1, want 2"}, + // kind String + {input: "820102", size: 0, err: "input value has wrong size 2, want 0"}, + {input: "820102", size: 1, err: "input value has wrong size 2, want 1"}, + {input: "820102", size: 2}, + {input: "820102", size: 3, err: "input value has wrong size 2, want 3"}, + } + + for _, test := range tests { + test := test + name := fmt.Sprintf("input_%s/size_%d", test.input, test.size) + t.Run(name, func(t *testing.T) { + s := NewStream(bytes.NewReader(unhex(test.input)), 0) + b := make([]byte, test.size) + err := s.ReadBytes(b) + if test.err == "" { + if err != nil { + t.Errorf("unexpected error %q", err) + } + } else { + if err == nil { + t.Errorf("expected error, got nil") + } else if err.Error() != test.err { + t.Errorf("wrong error %q", err) + } + } + }) + } +} + func TestDecodeErrors(t *testing.T) { r := bytes.NewReader(nil) @@ -327,6 +371,15 @@ type recstruct struct { Child *recstruct `rlp:"nil"` } +type bigIntStruct struct { + I *big.Int + B string +} + +type invalidNilTag struct { + X []byte `rlp:"nil"` +} + type invalidTail1 struct { A uint `rlp:"tail"` B string @@ -347,19 +400,79 @@ type tailUint struct { Tail []uint `rlp:"tail"` } -var ( - veryBigInt = big.NewInt(0).Add( - big.NewInt(0).Lsh(big.NewInt(0xFFFFFFFFFFFFFF), 16), - big.NewInt(0xFFFF), - ) -) +type tailPrivateFields struct { + A uint + Tail []uint `rlp:"tail"` + x, y bool //lint:ignore U1000 unused fields required for testing purposes. +} -type hasIgnoredField struct { +type nilListUint struct { + X *uint `rlp:"nilList"` +} + +type nilStringSlice struct { + X *[]uint `rlp:"nilString"` +} + +type intField struct { + X int +} + +type optionalFields struct { + A uint + B uint `rlp:"optional"` + C uint `rlp:"optional"` +} + +type optionalAndTailField struct { + A uint + B uint `rlp:"optional"` + Tail []uint `rlp:"tail"` +} + +type optionalBigIntField struct { + A uint + B *big.Int `rlp:"optional"` +} + +type optionalPtrField struct { + A uint + B *[3]byte `rlp:"optional"` +} + +type nonOptionalPtrField struct { + A uint + B *[3]byte +} + +type multipleOptionalFields struct { + A *[3]byte `rlp:"optional"` + B *[3]byte `rlp:"optional"` +} + +type optionalPtrFieldNil struct { + A uint + B *[3]byte `rlp:"optional,nil"` +} + +type ignoredField struct { A uint B uint `rlp:"-"` C uint } +var ( + veryBigInt = new(big.Int).Add( + new(big.Int).Lsh(big.NewInt(0xFFFFFFFFFFFFFF), 16), + big.NewInt(0xFFFF), + ) + veryVeryBigInt = new(big.Int).Exp(veryBigInt, big.NewInt(8), nil) +) + +var ( + veryBigInt256, _ = uint256.FromBig(veryBigInt) +) + var decodeTests = []decodeTest{ // booleans {input: "01", ptr: new(bool), value: true}, @@ -428,12 +541,31 @@ var decodeTests = []decodeTest{ {input: "C0", ptr: new(string), error: "rlp: expected input string or byte for string"}, // big ints + {input: "80", ptr: new(*big.Int), value: big.NewInt(0)}, {input: "01", ptr: new(*big.Int), value: big.NewInt(1)}, {input: "89FFFFFFFFFFFFFFFFFF", ptr: new(*big.Int), value: veryBigInt}, + {input: "B848FFFFFFFFFFFFFFFFF800000000000000001BFFFFFFFFFFFFFFFFC8000000000000000045FFFFFFFFFFFFFFFFC800000000000000001BFFFFFFFFFFFFFFFFF8000000000000000001", ptr: new(*big.Int), value: veryVeryBigInt}, {input: "10", ptr: new(big.Int), value: *big.NewInt(16)}, // non-pointer also works + + // big int errors {input: "C0", ptr: new(*big.Int), error: "rlp: expected input string or byte for *big.Int"}, - {input: "820001", ptr: new(big.Int), error: "rlp: non-canonical integer (leading zero bytes) for *big.Int"}, - {input: "8105", ptr: new(big.Int), error: "rlp: non-canonical size information for *big.Int"}, + {input: "00", ptr: new(*big.Int), error: "rlp: non-canonical integer (leading zero bytes) for *big.Int"}, + {input: "820001", ptr: new(*big.Int), error: "rlp: non-canonical integer (leading zero bytes) for *big.Int"}, + {input: "8105", ptr: new(*big.Int), error: "rlp: non-canonical size information for *big.Int"}, + + // uint256 + {input: "80", ptr: new(*uint256.Int), value: uint256.NewInt(0)}, + {input: "01", ptr: new(*uint256.Int), value: uint256.NewInt(1)}, + {input: "88FFFFFFFFFFFFFFFF", ptr: new(*uint256.Int), value: uint256.NewInt(math.MaxUint64)}, + {input: "89FFFFFFFFFFFFFFFFFF", ptr: new(*uint256.Int), value: veryBigInt256}, + {input: "10", ptr: new(uint256.Int), value: *uint256.NewInt(16)}, // non-pointer also works + + // uint256 errors + {input: "C0", ptr: new(*uint256.Int), error: "rlp: expected input string or byte for *uint256.Int"}, + {input: "00", ptr: new(*uint256.Int), error: "rlp: non-canonical integer (leading zero bytes) for *uint256.Int"}, + {input: "820001", ptr: new(*uint256.Int), error: "rlp: non-canonical integer (leading zero bytes) for *uint256.Int"}, + {input: "8105", ptr: new(*uint256.Int), error: "rlp: non-canonical size information for *uint256.Int"}, + {input: "A1FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00", ptr: new(*uint256.Int), error: "rlp: value too large for uint256"}, // structs { @@ -446,6 +578,13 @@ var decodeTests = []decodeTest{ ptr: new(recstruct), value: recstruct{1, &recstruct{2, &recstruct{3, nil}}}, }, + { + // This checks that empty big.Int works correctly in struct context. It's easy to + // miss the update of s.kind for this case, so it needs its own test. + input: "C58083343434", + ptr: new(bigIntStruct), + value: bigIntStruct{new(big.Int), "444"}, + }, // struct errors { @@ -479,20 +618,20 @@ var decodeTests = []decodeTest{ error: "rlp: expected input string or byte for uint, decoding into (rlp.recstruct).Child.I", }, { - input: "C0", - ptr: new(invalidTail1), - error: "rlp: invalid struct tag \"tail\" for rlp.invalidTail1.A (must be on last field)", - }, - { - input: "C0", - ptr: new(invalidTail2), - error: "rlp: invalid struct tag \"tail\" for rlp.invalidTail2.B (field type is not slice)", + input: "C103", + ptr: new(intField), + error: "rlp: type int is not RLP-serializable (struct field rlp.intField.X)", }, { input: "C50102C20102", ptr: new(tailUint), error: "rlp: expected input string or byte for uint, decoding into (rlp.tailUint).Tail[1]", }, + { + input: "C0", + ptr: new(invalidNilTag), + error: `rlp: invalid struct tag "nil" for rlp.invalidNilTag.X (field is not a pointer)`, + }, // struct tag "tail" { @@ -510,12 +649,192 @@ var decodeTests = []decodeTest{ ptr: new(tailRaw), value: tailRaw{A: 1, Tail: []RawValue{}}, }, + { + input: "C3010203", + ptr: new(tailPrivateFields), + value: tailPrivateFields{A: 1, Tail: []uint{2, 3}}, + }, + { + input: "C0", + ptr: new(invalidTail1), + error: `rlp: invalid struct tag "tail" for rlp.invalidTail1.A (must be on last field)`, + }, + { + input: "C0", + ptr: new(invalidTail2), + error: `rlp: invalid struct tag "tail" for rlp.invalidTail2.B (field type is not slice)`, + }, // struct tag "-" { input: "C20102", - ptr: new(hasIgnoredField), - value: hasIgnoredField{A: 1, C: 2}, + ptr: new(ignoredField), + value: ignoredField{A: 1, C: 2}, + }, + + // struct tag "nilList" + { + input: "C180", + ptr: new(nilListUint), + error: "rlp: wrong kind of empty value (got String, want List) for *uint, decoding into (rlp.nilListUint).X", + }, + { + input: "C1C0", + ptr: new(nilListUint), + value: nilListUint{}, + }, + { + input: "C103", + ptr: new(nilListUint), + value: func() interface{} { + v := uint(3) + return nilListUint{X: &v} + }(), + }, + + // struct tag "nilString" + { + input: "C1C0", + ptr: new(nilStringSlice), + error: "rlp: wrong kind of empty value (got List, want String) for *[]uint, decoding into (rlp.nilStringSlice).X", + }, + { + input: "C180", + ptr: new(nilStringSlice), + value: nilStringSlice{}, + }, + { + input: "C2C103", + ptr: new(nilStringSlice), + value: nilStringSlice{X: &[]uint{3}}, + }, + + // struct tag "optional" + { + input: "C101", + ptr: new(optionalFields), + value: optionalFields{1, 0, 0}, + }, + { + input: "C20102", + ptr: new(optionalFields), + value: optionalFields{1, 2, 0}, + }, + { + input: "C3010203", + ptr: new(optionalFields), + value: optionalFields{1, 2, 3}, + }, + { + input: "C401020304", + ptr: new(optionalFields), + error: "rlp: input list has too many elements for rlp.optionalFields", + }, + { + input: "C101", + ptr: new(optionalAndTailField), + value: optionalAndTailField{A: 1}, + }, + { + input: "C20102", + ptr: new(optionalAndTailField), + value: optionalAndTailField{A: 1, B: 2, Tail: []uint{}}, + }, + { + input: "C401020304", + ptr: new(optionalAndTailField), + value: optionalAndTailField{A: 1, B: 2, Tail: []uint{3, 4}}, + }, + { + input: "C101", + ptr: new(optionalBigIntField), + value: optionalBigIntField{A: 1, B: nil}, + }, + { + input: "C20102", + ptr: new(optionalBigIntField), + value: optionalBigIntField{A: 1, B: big.NewInt(2)}, + }, + { + input: "C101", + ptr: new(optionalPtrField), + value: optionalPtrField{A: 1}, + }, + { + input: "C20180", // not accepted because "optional" doesn't enable "nil" + ptr: new(optionalPtrField), + error: "rlp: input string too short for [3]uint8, decoding into (rlp.optionalPtrField).B", + }, + { + input: "C20102", + ptr: new(optionalPtrField), + error: "rlp: input string too short for [3]uint8, decoding into (rlp.optionalPtrField).B", + }, + { + input: "C50183010203", + ptr: new(optionalPtrField), + value: optionalPtrField{A: 1, B: &[3]byte{1, 2, 3}}, + }, + { + // all optional fields nil + input: "C0", + ptr: new(multipleOptionalFields), + value: multipleOptionalFields{A: nil, B: nil}, + }, + { + // all optional fields set + input: "C88301020383010203", + ptr: new(multipleOptionalFields), + value: multipleOptionalFields{A: &[3]byte{1, 2, 3}, B: &[3]byte{1, 2, 3}}, + }, + { + // nil optional field appears before a non-nil one + input: "C58083010203", + ptr: new(multipleOptionalFields), + error: "rlp: input string too short for [3]uint8, decoding into (rlp.multipleOptionalFields).A", + }, + { + // decode a nil ptr into a ptr that is not nil or not optional + input: "C20180", + ptr: new(nonOptionalPtrField), + error: "rlp: input string too short for [3]uint8, decoding into (rlp.nonOptionalPtrField).B", + }, + { + input: "C101", + ptr: new(optionalPtrFieldNil), + value: optionalPtrFieldNil{A: 1}, + }, + { + input: "C20180", // accepted because "nil" tag allows empty input + ptr: new(optionalPtrFieldNil), + value: optionalPtrFieldNil{A: 1}, + }, + { + input: "C20102", + ptr: new(optionalPtrFieldNil), + error: "rlp: input string too short for [3]uint8, decoding into (rlp.optionalPtrFieldNil).B", + }, + + // struct tag "optional" field clearing + { + input: "C101", + ptr: &optionalFields{A: 9, B: 8, C: 7}, + value: optionalFields{A: 1, B: 0, C: 0}, + }, + { + input: "C20102", + ptr: &optionalFields{A: 9, B: 8, C: 7}, + value: optionalFields{A: 1, B: 2, C: 0}, + }, + { + input: "C20102", + ptr: &optionalAndTailField{A: 9, B: 8, Tail: []uint{7, 6, 5}}, + value: optionalAndTailField{A: 1, B: 2, Tail: []uint{}}, + }, + { + input: "C101", + ptr: &optionalPtrField{A: 9, B: &[3]byte{8, 7, 6}}, + value: optionalPtrField{A: 1}, }, // RawValue @@ -591,6 +910,26 @@ func TestDecodeWithByteReader(t *testing.T) { }) } +func testDecodeWithEncReader(t *testing.T, n int) { + s := strings.Repeat("0", n) + _, r, _ := EncodeToReader(s) + var decoded string + err := Decode(r, &decoded) + if err != nil { + t.Errorf("Unexpected decode error with n=%v: %v", n, err) + } + if decoded != s { + t.Errorf("Decode mismatch with n=%v", n) + } +} + +// This is a regression test checking that decoding from encReader +// works for RLP values of size 8192 bytes or more. +func TestDecodeWithEncReader(t *testing.T) { + testDecodeWithEncReader(t, 8188) // length with header is 8191 + testDecodeWithEncReader(t, 8189) // length with header is 8192 +} + // plainReader reads from a byte slice but does not // implement ReadByte. It is also not recognized by the // size validation. This is useful to test how the decoder @@ -661,6 +1000,22 @@ func TestDecodeDecoder(t *testing.T) { } } +func TestDecodeDecoderNilPointer(t *testing.T) { + var s struct { + T1 *testDecoder `rlp:"nil"` + T2 *testDecoder + } + if err := Decode(bytes.NewReader(unhex("C2C002")), &s); err != nil { + t.Fatalf("Decode error: %v", err) + } + if s.T1 != nil { + t.Errorf("decoder T1 allocated for empty input (called: %v)", s.T1.called) + } + if s.T2 == nil || !s.T2.called { + t.Errorf("decoder T2 not allocated/called") + } +} + type byteDecoder byte func (bd *byteDecoder) DecodeRLP(s *Stream) error { @@ -691,13 +1046,66 @@ func TestDecoderInByteSlice(t *testing.T) { } } +type unencodableDecoder func() + +func (f *unencodableDecoder) DecodeRLP(s *Stream) error { + if _, err := s.List(); err != nil { + return err + } + if err := s.ListEnd(); err != nil { + return err + } + *f = func() {} + return nil +} + +func TestDecoderFunc(t *testing.T) { + var x func() + if err := DecodeBytes([]byte{0xC0}, (*unencodableDecoder)(&x)); err != nil { + t.Fatal(err) + } + x() +} + +// This tests the validity checks for fields with struct tag "optional". +func TestInvalidOptionalField(t *testing.T) { + type ( + invalid1 struct { + A uint `rlp:"optional"` + B uint + } + invalid2 struct { + T []uint `rlp:"tail,optional"` + } + invalid3 struct { + T []uint `rlp:"optional,tail"` + } + ) + + tests := []struct { + v interface{} + err string + }{ + {v: new(invalid1), err: `rlp: invalid struct tag "" for rlp.invalid1.B (must be optional because preceding field "A" is optional)`}, + {v: new(invalid2), err: `rlp: invalid struct tag "optional" for rlp.invalid2.T (also has "tail" tag)`}, + {v: new(invalid3), err: `rlp: invalid struct tag "tail" for rlp.invalid3.T (also has "optional" tag)`}, + } + for _, test := range tests { + err := DecodeBytes(unhex("C20102"), test.v) + if err == nil { + t.Errorf("no error for %T", test.v) + } else if err.Error() != test.err { + t.Errorf("wrong error for %T: %v", test.v, err.Error()) + } + } +} + func ExampleDecode() { input, _ := hex.DecodeString("C90A1486666F6F626172") type example struct { - A, B uint - private uint // private fields are ignored - String string + A, B uint + String string } var s example @@ -708,7 +1116,7 @@ func ExampleDecode() { fmt.Printf("Decoded value: %#v\n", s) } // Output: - // Decoded value: rlp.example{A:0xa, B:0x14, private:0x0, String:"foobar"} + // Decoded value: rlp.example{A:0xa, B:0x14, String:"foobar"} } func ExampleDecode_structTagNil() { @@ -768,7 +1176,7 @@ func ExampleStream() { // [102 111 111 98 97 114] } -func BenchmarkDecode(b *testing.B) { +func BenchmarkDecodeUints(b *testing.B) { enc := encodeTestSlice(90000) b.SetBytes(int64(len(enc))) b.ReportAllocs() @@ -783,7 +1191,7 @@ func BenchmarkDecode(b *testing.B) { } } -func BenchmarkDecodeIntSliceReuse(b *testing.B) { +func BenchmarkDecodeUintsReused(b *testing.B) { enc := encodeTestSlice(100000) b.SetBytes(int64(len(enc))) b.ReportAllocs() @@ -798,6 +1206,65 @@ func BenchmarkDecodeIntSliceReuse(b *testing.B) { } } +func BenchmarkDecodeByteArrayStruct(b *testing.B) { + enc, err := EncodeToBytes(&byteArrayStruct{}) + if err != nil { + b.Fatal(err) + } + b.SetBytes(int64(len(enc))) + b.ReportAllocs() + b.ResetTimer() + + var out byteArrayStruct + for i := 0; i < b.N; i++ { + if err := DecodeBytes(enc, &out); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkDecodeBigInts(b *testing.B) { + ints := make([]*big.Int, 200) + for i := range ints { + ints[i] = math.BigPow(2, int64(i)) + } + enc, err := EncodeToBytes(ints) + if err != nil { + b.Fatal(err) + } + b.SetBytes(int64(len(enc))) + b.ReportAllocs() + b.ResetTimer() + + var out []*big.Int + for i := 0; i < b.N; i++ { + if err := DecodeBytes(enc, &out); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkDecodeU256Ints(b *testing.B) { + ints := make([]*uint256.Int, 200) + for i := range ints { + ints[i], _ = uint256.FromBig(math.BigPow(2, int64(i))) + } + enc, err := EncodeToBytes(ints) + if err != nil { + b.Fatal(err) + } + b.SetBytes(int64(len(enc))) + b.ReportAllocs() + b.ResetTimer() + + var out []*uint256.Int + for i := 0; i < b.N; i++ { + if err := DecodeBytes(enc, &out); err != nil { + b.Fatal(err) + } + } +} + func encodeTestSlice(n uint) []byte { s := make([]uint, n) for i := uint(0); i < n; i++ { @@ -811,7 +1278,7 @@ func encodeTestSlice(n uint) []byte { } func unhex(str string) []byte { - b, err := hex.DecodeString(strings.Replace(str, " ", "", -1)) + b, err := hex.DecodeString(strings.ReplaceAll(str, " ", "")) if err != nil { panic(fmt.Sprintf("invalid hex string: %q", str)) } diff --git a/rlp/doc.go b/rlp/doc.go index b3a81fe232..eeeee9a43a 100644 --- a/rlp/doc.go +++ b/rlp/doc.go @@ -17,17 +17,142 @@ /* Package rlp implements the RLP serialization format. -The purpose of RLP (Recursive Linear Prefix) is to encode arbitrarily -nested arrays of binary data, and RLP is the main encoding method used -to serialize objects in Ethereum. The only purpose of RLP is to encode -structure; encoding specific atomic data types (eg. strings, ints, -floats) is left up to higher-order protocols; in Ethereum integers -must be represented in big endian binary form with no leading zeroes -(thus making the integer value zero equivalent to the empty byte -array). +The purpose of RLP (Recursive Linear Prefix) is to encode arbitrarily nested arrays of +binary data, and RLP is the main encoding method used to serialize objects in Ethereum. +The only purpose of RLP is to encode structure; encoding specific atomic data types (eg. +strings, ints, floats) is left up to higher-order protocols. In Ethereum integers must be +represented in big endian binary form with no leading zeroes (thus making the integer +value zero equivalent to the empty string). -RLP values are distinguished by a type tag. The type tag precedes the -value in the input stream and defines the size and kind of the bytes -that follow. +RLP values are distinguished by a type tag. The type tag precedes the value in the input +stream and defines the size and kind of the bytes that follow. + +# Encoding Rules + +Package rlp uses reflection and encodes RLP based on the Go type of the value. + +If the type implements the Encoder interface, Encode calls EncodeRLP. It does not +call EncodeRLP on nil pointer values. + +To encode a pointer, the value being pointed to is encoded. A nil pointer to a struct +type, slice or array always encodes as an empty RLP list unless the slice or array has +element type byte. A nil pointer to any other value encodes as the empty string. + +Struct values are encoded as an RLP list of all their encoded public fields. Recursive +struct types are supported. + +To encode slices and arrays, the elements are encoded as an RLP list of the value's +elements. Note that arrays and slices with element type uint8 or byte are always encoded +as an RLP string. + +A Go string is encoded as an RLP string. + +An unsigned integer value is encoded as an RLP string. Zero always encodes as an empty RLP +string. big.Int values are treated as integers. Signed integers (int, int8, int16, ...) +are not supported and will return an error when encoding. + +Boolean values are encoded as the unsigned integers zero (false) and one (true). + +An interface value encodes as the value contained in the interface. + +Floating point numbers, maps, channels and functions are not supported. + +# Decoding Rules + +Decoding uses the following type-dependent rules: + +If the type implements the Decoder interface, DecodeRLP is called. + +To decode into a pointer, the value will be decoded as the element type of the pointer. If +the pointer is nil, a new value of the pointer's element type is allocated. If the pointer +is non-nil, the existing value will be reused. Note that package rlp never leaves a +pointer-type struct field as nil unless one of the "nil" struct tags is present. + +To decode into a struct, decoding expects the input to be an RLP list. The decoded +elements of the list are assigned to each public field in the order given by the struct's +definition. The input list must contain an element for each decoded field. Decoding +returns an error if there are too few or too many elements for the struct. + +To decode into a slice, the input must be a list and the resulting slice will contain the +input elements in order. For byte slices, the input must be an RLP string. Array types +decode similarly, with the additional restriction that the number of input elements (or +bytes) must match the array's defined length. + +To decode into a Go string, the input must be an RLP string. The input bytes are taken +as-is and will not necessarily be valid UTF-8. + +To decode into an unsigned integer type, the input must also be an RLP string. The bytes +are interpreted as a big endian representation of the integer. If the RLP string is larger +than the bit size of the type, decoding will return an error. Decode also supports +*big.Int. There is no size limit for big integers. + +To decode into a boolean, the input must contain an unsigned integer of value zero (false) +or one (true). + +To decode into an interface value, one of these types is stored in the value: + + []interface{}, for RLP lists + []byte, for RLP strings + +Non-empty interface types are not supported when decoding. +Signed integers, floating point numbers, maps, channels and functions cannot be decoded into. + +# Struct Tags + +As with other encoding packages, the "-" tag ignores fields. + + type StructWithIgnoredField struct{ + Ignored uint `rlp:"-"` + Field uint + } + +Go struct values encode/decode as RLP lists. There are two ways of influencing the mapping +of fields to list elements. The "tail" tag, which may only be used on the last exported +struct field, allows slurping up any excess list elements into a slice. + + type StructWithTail struct{ + Field uint + Tail []string `rlp:"tail"` + } + +The "optional" tag says that the field may be omitted if it is zero-valued. If this tag is +used on a struct field, all subsequent public fields must also be declared optional. + +When encoding a struct with optional fields, the output RLP list contains all values up to +the last non-zero optional field. + +When decoding into a struct, optional fields may be omitted from the end of the input +list. For the example below, this means input lists of one, two, or three elements are +accepted. + + type StructWithOptionalFields struct{ + Required uint + Optional1 uint `rlp:"optional"` + Optional2 uint `rlp:"optional"` + } + +The "nil", "nilList" and "nilString" tags apply to pointer-typed fields only, and change +the decoding rules for the field type. For regular pointer fields without the "nil" tag, +input values must always match the required input length exactly and the decoder does not +produce nil values. When the "nil" tag is set, input values of size zero decode as a nil +pointer. This is especially useful for recursive types. + + type StructWithNilField struct { + Field *[3]byte `rlp:"nil"` + } + +In the example above, Field allows two possible input sizes. For input 0xC180 (a list +containing an empty string) Field is set to nil after decoding. For input 0xC483000000 (a +list containing a 3-byte string), Field is set to a non-nil array pointer. + +RLP supports two kinds of empty values: empty lists and empty strings. When using the +"nil" tag, the kind of empty value allowed for a type is chosen automatically. A field +whose Go type is a pointer to an unsigned integer, string, boolean or byte array/slice +expects an empty RLP string. Any other pointer field type encodes/decodes as an empty RLP +list. + +The choice of null value can be made explicit with the "nilList" and "nilString" struct +tags. Using these tags encodes/decodes a Go nil pointer value as the empty RLP value kind +defined by the tag. */ package rlp diff --git a/rlp/encbuffer.go b/rlp/encbuffer.go new file mode 100644 index 0000000000..8d3a3b2293 --- /dev/null +++ b/rlp/encbuffer.go @@ -0,0 +1,423 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package rlp + +import ( + "encoding/binary" + "io" + "math/big" + "reflect" + "sync" + + "github.com/holiman/uint256" +) + +type encBuffer struct { + str []byte // string data, contains everything except list headers + lheads []listhead // all list headers + lhsize int // sum of sizes of all encoded list headers + sizebuf [9]byte // auxiliary buffer for uint encoding +} + +// The global encBuffer pool. +var encBufferPool = sync.Pool{ + New: func() interface{} { return new(encBuffer) }, +} + +func getEncBuffer() *encBuffer { + buf := encBufferPool.Get().(*encBuffer) + buf.reset() + return buf +} + +func (buf *encBuffer) reset() { + buf.lhsize = 0 + buf.str = buf.str[:0] + buf.lheads = buf.lheads[:0] +} + +// size returns the length of the encoded data. +func (buf *encBuffer) size() int { + return len(buf.str) + buf.lhsize +} + +// makeBytes creates the encoder output. +func (buf *encBuffer) makeBytes() []byte { + out := make([]byte, buf.size()) + buf.copyTo(out) + return out +} + +func (buf *encBuffer) copyTo(dst []byte) { + strpos := 0 + pos := 0 + for _, head := range buf.lheads { + // write string data before header + n := copy(dst[pos:], buf.str[strpos:head.offset]) + pos += n + strpos += n + // write the header + enc := head.encode(dst[pos:]) + pos += len(enc) + } + // copy string data after the last list header + copy(dst[pos:], buf.str[strpos:]) +} + +// writeTo writes the encoder output to w. +func (buf *encBuffer) writeTo(w io.Writer) (err error) { + strpos := 0 + for _, head := range buf.lheads { + // write string data before header + if head.offset-strpos > 0 { + n, err := w.Write(buf.str[strpos:head.offset]) + strpos += n + if err != nil { + return err + } + } + // write the header + enc := head.encode(buf.sizebuf[:]) + if _, err = w.Write(enc); err != nil { + return err + } + } + if strpos < len(buf.str) { + // write string data after the last list header + _, err = w.Write(buf.str[strpos:]) + } + return err +} + +// Write implements io.Writer and appends b directly to the output. +func (buf *encBuffer) Write(b []byte) (int, error) { + buf.str = append(buf.str, b...) + return len(b), nil +} + +// writeBool writes b as the integer 0 (false) or 1 (true). +func (buf *encBuffer) writeBool(b bool) { + if b { + buf.str = append(buf.str, 0x01) + } else { + buf.str = append(buf.str, 0x80) + } +} + +func (buf *encBuffer) writeUint64(i uint64) { + if i == 0 { + buf.str = append(buf.str, 0x80) + } else if i < 128 { + // fits single byte + buf.str = append(buf.str, byte(i)) + } else { + s := putint(buf.sizebuf[1:], i) + buf.sizebuf[0] = 0x80 + byte(s) + buf.str = append(buf.str, buf.sizebuf[:s+1]...) + } +} + +func (buf *encBuffer) writeBytes(b []byte) { + if len(b) == 1 && b[0] <= 0x7F { + // fits single byte, no string header + buf.str = append(buf.str, b[0]) + } else { + buf.encodeStringHeader(len(b)) + buf.str = append(buf.str, b...) + } +} + +func (buf *encBuffer) writeString(s string) { + buf.writeBytes([]byte(s)) +} + +// wordBytes is the number of bytes in a big.Word +const wordBytes = (32 << (uint64(^big.Word(0)) >> 63)) / 8 + +// writeBigInt writes i as an integer. +func (buf *encBuffer) writeBigInt(i *big.Int) { + bitlen := i.BitLen() + if bitlen <= 64 { + buf.writeUint64(i.Uint64()) + return + } + // Integer is larger than 64 bits, encode from i.Bits(). + // The minimal byte length is bitlen rounded up to the next + // multiple of 8, divided by 8. + length := ((bitlen + 7) & -8) >> 3 + buf.encodeStringHeader(length) + buf.str = append(buf.str, make([]byte, length)...) + index := length + bytesBuf := buf.str[len(buf.str)-length:] + for _, d := range i.Bits() { + for j := 0; j < wordBytes && index > 0; j++ { + index-- + bytesBuf[index] = byte(d) + d >>= 8 + } + } +} + +// writeUint256 writes z as an integer. +func (buf *encBuffer) writeUint256(z *uint256.Int) { + bitlen := z.BitLen() + if bitlen <= 64 { + buf.writeUint64(z.Uint64()) + return + } + nBytes := byte((bitlen + 7) / 8) + var b [33]byte + binary.BigEndian.PutUint64(b[1:9], z[3]) + binary.BigEndian.PutUint64(b[9:17], z[2]) + binary.BigEndian.PutUint64(b[17:25], z[1]) + binary.BigEndian.PutUint64(b[25:33], z[0]) + b[32-nBytes] = 0x80 + nBytes + buf.str = append(buf.str, b[32-nBytes:]...) +} + +// list adds a new list header to the header stack. It returns the index of the header. +// Call listEnd with this index after encoding the content of the list. +func (buf *encBuffer) list() int { + buf.lheads = append(buf.lheads, listhead{offset: len(buf.str), size: buf.lhsize}) + return len(buf.lheads) - 1 +} + +func (buf *encBuffer) listEnd(index int) { + lh := &buf.lheads[index] + lh.size = buf.size() - lh.offset - lh.size + if lh.size < 56 { + buf.lhsize++ // length encoded into kind tag + } else { + buf.lhsize += 1 + intsize(uint64(lh.size)) + } +} + +func (buf *encBuffer) encode(val interface{}) error { + rval := reflect.ValueOf(val) + writer, err := cachedWriter(rval.Type()) + if err != nil { + return err + } + return writer(rval, buf) +} + +func (buf *encBuffer) encodeStringHeader(size int) { + if size < 56 { + buf.str = append(buf.str, 0x80+byte(size)) + } else { + sizesize := putint(buf.sizebuf[1:], uint64(size)) + buf.sizebuf[0] = 0xB7 + byte(sizesize) + buf.str = append(buf.str, buf.sizebuf[:sizesize+1]...) + } +} + +// encReader is the io.Reader returned by EncodeToReader. +// It releases its encbuf at EOF. +type encReader struct { + buf *encBuffer // the buffer we're reading from. this is nil when we're at EOF. + lhpos int // index of list header that we're reading + strpos int // current position in string buffer + piece []byte // next piece to be read +} + +func (r *encReader) Read(b []byte) (n int, err error) { + for { + if r.piece = r.next(); r.piece == nil { + // Put the encode buffer back into the pool at EOF when it + // is first encountered. Subsequent calls still return EOF + // as the error but the buffer is no longer valid. + if r.buf != nil { + encBufferPool.Put(r.buf) + r.buf = nil + } + return n, io.EOF + } + nn := copy(b[n:], r.piece) + n += nn + if nn < len(r.piece) { + // piece didn't fit, see you next time. + r.piece = r.piece[nn:] + return n, nil + } + r.piece = nil + } +} + +// next returns the next piece of data to be read. +// it returns nil at EOF. +func (r *encReader) next() []byte { + switch { + case r.buf == nil: + return nil + + case r.piece != nil: + // There is still data available for reading. + return r.piece + + case r.lhpos < len(r.buf.lheads): + // We're before the last list header. + head := r.buf.lheads[r.lhpos] + sizebefore := head.offset - r.strpos + if sizebefore > 0 { + // String data before header. + p := r.buf.str[r.strpos:head.offset] + r.strpos += sizebefore + return p + } + r.lhpos++ + return head.encode(r.buf.sizebuf[:]) + + case r.strpos < len(r.buf.str): + // String data at the end, after all list headers. + p := r.buf.str[r.strpos:] + r.strpos = len(r.buf.str) + return p + + default: + return nil + } +} + +func encBufferFromWriter(w io.Writer) *encBuffer { + switch w := w.(type) { + case EncoderBuffer: + return w.buf + case *EncoderBuffer: + return w.buf + case *encBuffer: + return w + default: + return nil + } +} + +// EncoderBuffer is a buffer for incremental encoding. +// +// The zero value is NOT ready for use. To get a usable buffer, +// create it using NewEncoderBuffer or call Reset. +type EncoderBuffer struct { + buf *encBuffer + dst io.Writer + + ownBuffer bool +} + +// NewEncoderBuffer creates an encoder buffer. +func NewEncoderBuffer(dst io.Writer) EncoderBuffer { + var w EncoderBuffer + w.Reset(dst) + return w +} + +// Reset truncates the buffer and sets the output destination. +func (w *EncoderBuffer) Reset(dst io.Writer) { + if w.buf != nil && !w.ownBuffer { + panic("can't Reset derived EncoderBuffer") + } + + // If the destination writer has an *encBuffer, use it. + // Note that w.ownBuffer is left false here. + if dst != nil { + if outer := encBufferFromWriter(dst); outer != nil { + *w = EncoderBuffer{outer, nil, false} + return + } + } + + // Get a fresh buffer. + if w.buf == nil { + w.buf = encBufferPool.Get().(*encBuffer) + w.ownBuffer = true + } + w.buf.reset() + w.dst = dst +} + +// Flush writes encoded RLP data to the output writer. This can only be called once. +// If you want to re-use the buffer after Flush, you must call Reset. +func (w *EncoderBuffer) Flush() error { + var err error + if w.dst != nil { + err = w.buf.writeTo(w.dst) + } + // Release the internal buffer. + if w.ownBuffer { + encBufferPool.Put(w.buf) + } + *w = EncoderBuffer{} + return err +} + +// ToBytes returns the encoded bytes. +func (w *EncoderBuffer) ToBytes() []byte { + return w.buf.makeBytes() +} + +// AppendToBytes appends the encoded bytes to dst. +func (w *EncoderBuffer) AppendToBytes(dst []byte) []byte { + size := w.buf.size() + out := append(dst, make([]byte, size)...) + w.buf.copyTo(out[len(dst):]) + return out +} + +// Write appends b directly to the encoder output. +func (w EncoderBuffer) Write(b []byte) (int, error) { + return w.buf.Write(b) +} + +// WriteBool writes b as the integer 0 (false) or 1 (true). +func (w EncoderBuffer) WriteBool(b bool) { + w.buf.writeBool(b) +} + +// WriteUint64 encodes an unsigned integer. +func (w EncoderBuffer) WriteUint64(i uint64) { + w.buf.writeUint64(i) +} + +// WriteBigInt encodes a big.Int as an RLP string. +// Note: Unlike with Encode, the sign of i is ignored. +func (w EncoderBuffer) WriteBigInt(i *big.Int) { + w.buf.writeBigInt(i) +} + +// WriteUint256 encodes uint256.Int as an RLP string. +func (w EncoderBuffer) WriteUint256(i *uint256.Int) { + w.buf.writeUint256(i) +} + +// WriteBytes encodes b as an RLP string. +func (w EncoderBuffer) WriteBytes(b []byte) { + w.buf.writeBytes(b) +} + +// WriteString encodes s as an RLP string. +func (w EncoderBuffer) WriteString(s string) { + w.buf.writeString(s) +} + +// List starts a list. It returns an internal index. Call EndList with +// this index after encoding the content to finish the list. +func (w EncoderBuffer) List() int { + return w.buf.list() +} + +// ListEnd finishes the given list. +func (w EncoderBuffer) ListEnd(index int) { + w.buf.listEnd(index) +} diff --git a/rlp/encbuffer_example_test.go b/rlp/encbuffer_example_test.go new file mode 100644 index 0000000000..f737f3e40c --- /dev/null +++ b/rlp/encbuffer_example_test.go @@ -0,0 +1,45 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package rlp_test + +import ( + "bytes" + "fmt" + + "github.com/XinFinOrg/XDPoSChain/rlp" +) + +func ExampleEncoderBuffer() { + var w bytes.Buffer + + // Encode [4, [5, 6]] to w. + buf := rlp.NewEncoderBuffer(&w) + l1 := buf.List() + buf.WriteUint64(4) + l2 := buf.List() + buf.WriteUint64(5) + buf.WriteUint64(6) + buf.ListEnd(l2) + buf.ListEnd(l1) + + if err := buf.Flush(); err != nil { + panic(err) + } + fmt.Printf("%X\n", w.Bytes()) + // Output: + // C404C20506 +} diff --git a/rlp/encode.go b/rlp/encode.go index 44592c2f53..9435cfc22c 100644 --- a/rlp/encode.go +++ b/rlp/encode.go @@ -17,20 +17,28 @@ package rlp import ( + "errors" "fmt" "io" "math/big" "reflect" - "sync" + + "github.com/XinFinOrg/XDPoSChain/rlp/internal/rlpstruct" + "github.com/holiman/uint256" ) var ( // Common encoded values. // These are useful when implementing EncodeRLP. + + // EmptyString is the encoding of an empty string. EmptyString = []byte{0x80} - EmptyList = []byte{0xC0} + // EmptyList is the encoding of an empty list. + EmptyList = []byte{0xC0} ) +var ErrNegativeBigInt = errors.New("rlp: cannot encode negative big.Int") + // Encoder is implemented by types that require custom // encoding rules or want to encode private fields. type Encoder interface { @@ -49,80 +57,48 @@ type Encoder interface { // perform many small writes in some cases. Consider making w // buffered. // -// Encode uses the following type-dependent encoding rules: -// -// If the type implements the Encoder interface, Encode calls -// EncodeRLP. This is true even for nil pointers, please see the -// documentation for Encoder. -// -// To encode a pointer, the value being pointed to is encoded. For nil -// pointers, Encode will encode the zero value of the type. A nil -// pointer to a struct type always encodes as an empty RLP list. -// A nil pointer to an array encodes as an empty list (or empty string -// if the array has element type byte). -// -// Struct values are encoded as an RLP list of all their encoded -// public fields. Recursive struct types are supported. -// -// To encode slices and arrays, the elements are encoded as an RLP -// list of the value's elements. Note that arrays and slices with -// element type uint8 or byte are always encoded as an RLP string. -// -// A Go string is encoded as an RLP string. -// -// An unsigned integer value is encoded as an RLP string. Zero always -// encodes as an empty RLP string. Encode also supports *big.Int. -// -// An interface value encodes as the value contained in the interface. -// -// Boolean values are not supported, nor are signed integers, floating -// point numbers, maps, channels and functions. +// Please see package-level documentation of encoding rules. func Encode(w io.Writer, val interface{}) error { - if outer, ok := w.(*encbuf); ok { - // Encode was called by some type's EncodeRLP. - // Avoid copying by writing to the outer encbuf directly. - return outer.encode(val) + // Optimization: reuse *encBuffer when called by EncodeRLP. + if buf := encBufferFromWriter(w); buf != nil { + return buf.encode(val) } - eb := encbufPool.Get().(*encbuf) - defer encbufPool.Put(eb) - eb.reset() - if err := eb.encode(val); err != nil { + + buf := getEncBuffer() + defer encBufferPool.Put(buf) + if err := buf.encode(val); err != nil { return err } - return eb.toWriter(w) + return buf.writeTo(w) } -// EncodeBytes returns the RLP encoding of val. -// Please see the documentation of Encode for the encoding rules. +// EncodeToBytes returns the RLP encoding of val. +// Please see package-level documentation for the encoding rules. func EncodeToBytes(val interface{}) ([]byte, error) { - eb := encbufPool.Get().(*encbuf) - defer encbufPool.Put(eb) - eb.reset() - if err := eb.encode(val); err != nil { + buf := getEncBuffer() + defer encBufferPool.Put(buf) + + if err := buf.encode(val); err != nil { return nil, err } - return eb.toBytes(), nil + return buf.makeBytes(), nil } -// EncodeReader returns a reader from which the RLP encoding of val +// EncodeToReader returns a reader from which the RLP encoding of val // can be read. The returned size is the total size of the encoded // data. // // Please see the documentation of Encode for the encoding rules. func EncodeToReader(val interface{}) (size int, r io.Reader, err error) { - eb := encbufPool.Get().(*encbuf) - eb.reset() - if err := eb.encode(val); err != nil { + buf := getEncBuffer() + if err := buf.encode(val); err != nil { + encBufferPool.Put(buf) return 0, nil, err } - return eb.size(), &encReader{buf: eb}, nil -} - -type encbuf struct { - str []byte // string data, contains everything except list headers - lheads []*listhead // all list headers - lhsize int // sum of sizes of all encoded list headers - sizebuf []byte // 9-byte auxiliary buffer for uint encoding + // Note: can't put the reader back into the pool here + // because it is held by encReader. The reader puts it + // back when it has been fully consumed. + return buf.size(), &encReader{buf: buf}, nil } type listhead struct { @@ -151,214 +127,32 @@ func puthead(buf []byte, smalltag, largetag byte, size uint64) int { if size < 56 { buf[0] = smalltag + byte(size) return 1 - } else { - sizesize := putint(buf[1:], size) - buf[0] = largetag + byte(sizesize) - return sizesize + 1 } + sizesize := putint(buf[1:], size) + buf[0] = largetag + byte(sizesize) + return sizesize + 1 } -// encbufs are pooled. -var encbufPool = sync.Pool{ - New: func() interface{} { return &encbuf{sizebuf: make([]byte, 9)} }, -} - -func (w *encbuf) reset() { - w.lhsize = 0 - if w.str != nil { - w.str = w.str[:0] - } - if w.lheads != nil { - w.lheads = w.lheads[:0] - } -} - -// encbuf implements io.Writer so it can be passed it into EncodeRLP. -func (w *encbuf) Write(b []byte) (int, error) { - w.str = append(w.str, b...) - return len(b), nil -} - -func (w *encbuf) encode(val interface{}) error { - rval := reflect.ValueOf(val) - ti, err := cachedTypeInfo(rval.Type(), tags{}) - if err != nil { - return err - } - return ti.writer(rval, w) -} - -func (w *encbuf) encodeStringHeader(size int) { - if size < 56 { - w.str = append(w.str, 0x80+byte(size)) - } else { - // TODO: encode to w.str directly - sizesize := putint(w.sizebuf[1:], uint64(size)) - w.sizebuf[0] = 0xB7 + byte(sizesize) - w.str = append(w.str, w.sizebuf[:sizesize+1]...) - } -} - -func (w *encbuf) encodeString(b []byte) { - if len(b) == 1 && b[0] <= 0x7F { - // fits single byte, no string header - w.str = append(w.str, b[0]) - } else { - w.encodeStringHeader(len(b)) - w.str = append(w.str, b...) - } -} - -func (w *encbuf) list() *listhead { - lh := &listhead{offset: len(w.str), size: w.lhsize} - w.lheads = append(w.lheads, lh) - return lh -} - -func (w *encbuf) listEnd(lh *listhead) { - lh.size = w.size() - lh.offset - lh.size - if lh.size < 56 { - w.lhsize += 1 // length encoded into kind tag - } else { - w.lhsize += 1 + intsize(uint64(lh.size)) - } -} - -func (w *encbuf) size() int { - return len(w.str) + w.lhsize -} - -func (w *encbuf) toBytes() []byte { - out := make([]byte, w.size()) - strpos := 0 - pos := 0 - for _, head := range w.lheads { - // write string data before header - n := copy(out[pos:], w.str[strpos:head.offset]) - pos += n - strpos += n - // write the header - enc := head.encode(out[pos:]) - pos += len(enc) - } - // copy string data after the last list header - copy(out[pos:], w.str[strpos:]) - return out -} - -func (w *encbuf) toWriter(out io.Writer) (err error) { - strpos := 0 - for _, head := range w.lheads { - // write string data before header - if head.offset-strpos > 0 { - n, err := out.Write(w.str[strpos:head.offset]) - strpos += n - if err != nil { - return err - } - } - // write the header - enc := head.encode(w.sizebuf) - if _, err = out.Write(enc); err != nil { - return err - } - } - if strpos < len(w.str) { - // write string data after the last list header - _, err = out.Write(w.str[strpos:]) - } - return err -} - -// encReader is the io.Reader returned by EncodeToReader. -// It releases its encbuf at EOF. -type encReader struct { - buf *encbuf // the buffer we're reading from. this is nil when we're at EOF. - lhpos int // index of list header that we're reading - strpos int // current position in string buffer - piece []byte // next piece to be read -} - -func (r *encReader) Read(b []byte) (n int, err error) { - for { - if r.piece = r.next(); r.piece == nil { - // Put the encode buffer back into the pool at EOF when it - // is first encountered. Subsequent calls still return EOF - // as the error but the buffer is no longer valid. - if r.buf != nil { - encbufPool.Put(r.buf) - r.buf = nil - } - return n, io.EOF - } - nn := copy(b[n:], r.piece) - n += nn - if nn < len(r.piece) { - // piece didn't fit, see you next time. - r.piece = r.piece[nn:] - return n, nil - } - r.piece = nil - } -} - -// next returns the next piece of data to be read. -// it returns nil at EOF. -func (r *encReader) next() []byte { - switch { - case r.buf == nil: - return nil - - case r.piece != nil: - // There is still data available for reading. - return r.piece - - case r.lhpos < len(r.buf.lheads): - // We're before the last list header. - head := r.buf.lheads[r.lhpos] - sizebefore := head.offset - r.strpos - if sizebefore > 0 { - // String data before header. - p := r.buf.str[r.strpos:head.offset] - r.strpos += sizebefore - return p - } else { - r.lhpos++ - return head.encode(r.buf.sizebuf) - } - - case r.strpos < len(r.buf.str): - // String data at the end, after all list headers. - p := r.buf.str[r.strpos:] - r.strpos = len(r.buf.str) - return p - - default: - return nil - } -} - -var ( - encoderInterface = reflect.TypeOf(new(Encoder)).Elem() - big0 = big.NewInt(0) -) +var encoderInterface = reflect.TypeOf(new(Encoder)).Elem() // makeWriter creates a writer function for the given type. -func makeWriter(typ reflect.Type, ts tags) (writer, error) { +func makeWriter(typ reflect.Type, ts rlpstruct.Tags) (writer, error) { kind := typ.Kind() switch { case typ == rawValueType: return writeRawValue, nil - case typ.Implements(encoderInterface): - return writeEncoder, nil - case kind != reflect.Ptr && reflect.PtrTo(typ).Implements(encoderInterface): - return writeEncoderNoPtr, nil - case kind == reflect.Interface: - return writeInterface, nil - case typ.AssignableTo(reflect.PtrTo(bigInt)): + case typ.AssignableTo(reflect.PointerTo(bigInt)): return writeBigIntPtr, nil case typ.AssignableTo(bigInt): return writeBigIntNoPtr, nil + case typ == reflect.PointerTo(u256Int): + return writeU256IntPtr, nil + case typ == u256Int: + return writeU256IntNoPtr, nil + case kind == reflect.Ptr: + return makePtrWriter(typ, ts) + case reflect.PointerTo(typ).Implements(encoderInterface): + return makeEncoderWriter(typ), nil case isUint(kind): return writeUint, nil case kind == reflect.Bool: @@ -368,97 +162,116 @@ func makeWriter(typ reflect.Type, ts tags) (writer, error) { case kind == reflect.Slice && isByte(typ.Elem()): return writeBytes, nil case kind == reflect.Array && isByte(typ.Elem()): - return writeByteArray, nil + return makeByteArrayWriter(typ), nil case kind == reflect.Slice || kind == reflect.Array: return makeSliceWriter(typ, ts) case kind == reflect.Struct: return makeStructWriter(typ) - case kind == reflect.Ptr: - return makePtrWriter(typ) + case kind == reflect.Interface: + return writeInterface, nil default: return nil, fmt.Errorf("rlp: type %v is not RLP-serializable", typ) } } -func isByte(typ reflect.Type) bool { - return typ.Kind() == reflect.Uint8 && !typ.Implements(encoderInterface) -} - -func writeRawValue(val reflect.Value, w *encbuf) error { +func writeRawValue(val reflect.Value, w *encBuffer) error { w.str = append(w.str, val.Bytes()...) return nil } -func writeUint(val reflect.Value, w *encbuf) error { - i := val.Uint() - if i == 0 { - w.str = append(w.str, 0x80) - } else if i < 128 { - // fits single byte - w.str = append(w.str, byte(i)) - } else { - // TODO: encode int to w.str directly - s := putint(w.sizebuf[1:], i) - w.sizebuf[0] = 0x80 + byte(s) - w.str = append(w.str, w.sizebuf[:s+1]...) - } +func writeUint(val reflect.Value, w *encBuffer) error { + w.writeUint64(val.Uint()) return nil } -func writeBool(val reflect.Value, w *encbuf) error { - if val.Bool() { - w.str = append(w.str, 0x01) - } else { - w.str = append(w.str, 0x80) - } +func writeBool(val reflect.Value, w *encBuffer) error { + w.writeBool(val.Bool()) return nil } -func writeBigIntPtr(val reflect.Value, w *encbuf) error { +func writeBigIntPtr(val reflect.Value, w *encBuffer) error { ptr := val.Interface().(*big.Int) if ptr == nil { w.str = append(w.str, 0x80) return nil } - return writeBigInt(ptr, w) + if ptr.Sign() == -1 { + return ErrNegativeBigInt + } + w.writeBigInt(ptr) + return nil } -func writeBigIntNoPtr(val reflect.Value, w *encbuf) error { +func writeBigIntNoPtr(val reflect.Value, w *encBuffer) error { i := val.Interface().(big.Int) - return writeBigInt(&i, w) + if i.Sign() == -1 { + return ErrNegativeBigInt + } + w.writeBigInt(&i) + return nil } -func writeBigInt(i *big.Int, w *encbuf) error { - if cmp := i.Cmp(big0); cmp == -1 { - return fmt.Errorf("rlp: cannot encode negative *big.Int") - } else if cmp == 0 { +func writeU256IntPtr(val reflect.Value, w *encBuffer) error { + ptr := val.Interface().(*uint256.Int) + if ptr == nil { w.str = append(w.str, 0x80) + return nil + } + w.writeUint256(ptr) + return nil +} + +func writeU256IntNoPtr(val reflect.Value, w *encBuffer) error { + i := val.Interface().(uint256.Int) + w.writeUint256(&i) + return nil +} + +func writeBytes(val reflect.Value, w *encBuffer) error { + w.writeBytes(val.Bytes()) + return nil +} + +func makeByteArrayWriter(typ reflect.Type) writer { + switch typ.Len() { + case 0: + return writeLengthZeroByteArray + case 1: + return writeLengthOneByteArray + default: + length := typ.Len() + return func(val reflect.Value, w *encBuffer) error { + if !val.CanAddr() { + // Getting the byte slice of val requires it to be addressable. Make it + // addressable by copying. + copy := reflect.New(val.Type()).Elem() + copy.Set(val) + val = copy + } + slice := byteArrayBytes(val, length) + w.encodeStringHeader(len(slice)) + w.str = append(w.str, slice...) + return nil + } + } +} + +func writeLengthZeroByteArray(val reflect.Value, w *encBuffer) error { + w.str = append(w.str, 0x80) + return nil +} + +func writeLengthOneByteArray(val reflect.Value, w *encBuffer) error { + b := byte(val.Index(0).Uint()) + if b <= 0x7f { + w.str = append(w.str, b) } else { - w.encodeString(i.Bytes()) + w.str = append(w.str, 0x81, b) } return nil } -func writeBytes(val reflect.Value, w *encbuf) error { - w.encodeString(val.Bytes()) - return nil -} - -func writeByteArray(val reflect.Value, w *encbuf) error { - if !val.CanAddr() { - // Slice requires the value to be addressable. - // Make it addressable by copying. - copy := reflect.New(val.Type()).Elem() - copy.Set(val) - val = copy - } - size := val.Len() - slice := val.Slice(0, size).Bytes() - w.encodeString(slice) - return nil -} - -func writeString(val reflect.Value, w *encbuf) error { +func writeString(val reflect.Value, w *encBuffer) error { s := val.String() if len(s) == 1 && s[0] <= 0x7f { // fits single byte, no string header @@ -470,27 +283,7 @@ func writeString(val reflect.Value, w *encbuf) error { return nil } -func writeEncoder(val reflect.Value, w *encbuf) error { - return val.Interface().(Encoder).EncodeRLP(w) -} - -// writeEncoderNoPtr handles non-pointer values that implement Encoder -// with a pointer receiver. -func writeEncoderNoPtr(val reflect.Value, w *encbuf) error { - if !val.CanAddr() { - // We can't get the address. It would be possible to make the - // value addressable by creating a shallow copy, but this - // creates other problems so we're not doing it (yet). - // - // package json simply doesn't call MarshalJSON for cases like - // this, but encodes the value as if it didn't implement the - // interface. We don't want to handle it that way. - return fmt.Errorf("rlp: game over: unadressable value of type %v, EncodeRLP is pointer method", val.Type()) - } - return val.Addr().Interface().(Encoder).EncodeRLP(w) -} - -func writeInterface(val reflect.Value, w *encbuf) error { +func writeInterface(val reflect.Value, w *encBuffer) error { if val.IsNil() { // Write empty list. This is consistent with the previous RLP // encoder that we had and should therefore avoid any @@ -499,31 +292,51 @@ func writeInterface(val reflect.Value, w *encbuf) error { return nil } eval := val.Elem() - ti, err := cachedTypeInfo(eval.Type(), tags{}) + writer, err := cachedWriter(eval.Type()) if err != nil { return err } - return ti.writer(eval, w) + return writer(eval, w) } -func makeSliceWriter(typ reflect.Type, ts tags) (writer, error) { - etypeinfo, err := cachedTypeInfo1(typ.Elem(), tags{}) - if err != nil { - return nil, err +func makeSliceWriter(typ reflect.Type, ts rlpstruct.Tags) (writer, error) { + etypeinfo := theTC.infoWhileGenerating(typ.Elem(), rlpstruct.Tags{}) + if etypeinfo.writerErr != nil { + return nil, etypeinfo.writerErr } - writer := func(val reflect.Value, w *encbuf) error { - if !ts.tail { - defer w.listEnd(w.list()) - } - vlen := val.Len() - for i := 0; i < vlen; i++ { - if err := etypeinfo.writer(val.Index(i), w); err != nil { - return err + + var wfn writer + if ts.Tail { + // This is for struct tail slices. + // w.list is not called for them. + wfn = func(val reflect.Value, w *encBuffer) error { + vlen := val.Len() + for i := 0; i < vlen; i++ { + if err := etypeinfo.writer(val.Index(i), w); err != nil { + return err + } } + return nil + } + } else { + // This is for regular slices and arrays. + wfn = func(val reflect.Value, w *encBuffer) error { + vlen := val.Len() + if vlen == 0 { + w.str = append(w.str, 0xC0) + return nil + } + listOffset := w.list() + for i := 0; i < vlen; i++ { + if err := etypeinfo.writer(val.Index(i), w); err != nil { + return err + } + } + w.listEnd(listOffset) + return nil } - return nil } - return writer, nil + return wfn, nil } func makeStructWriter(typ reflect.Type) (writer, error) { @@ -531,56 +344,86 @@ func makeStructWriter(typ reflect.Type) (writer, error) { if err != nil { return nil, err } - writer := func(val reflect.Value, w *encbuf) error { - lh := w.list() - for _, f := range fields { - if err := f.info.writer(val.Field(f.index), w); err != nil { - return err - } + for _, f := range fields { + if f.info.writerErr != nil { + return nil, structFieldError{typ, f.index, f.info.writerErr} } - w.listEnd(lh) + } + + var writer writer + firstOptionalField := firstOptionalField(fields) + if firstOptionalField == len(fields) { + // This is the writer function for structs without any optional fields. + writer = func(val reflect.Value, w *encBuffer) error { + lh := w.list() + for _, f := range fields { + if err := f.info.writer(val.Field(f.index), w); err != nil { + return err + } + } + w.listEnd(lh) + return nil + } + } else { + // If there are any "optional" fields, the writer needs to perform additional + // checks to determine the output list length. + writer = func(val reflect.Value, w *encBuffer) error { + lastField := len(fields) - 1 + for ; lastField >= firstOptionalField; lastField-- { + if !val.Field(fields[lastField].index).IsZero() { + break + } + } + lh := w.list() + for i := 0; i <= lastField; i++ { + if err := fields[i].info.writer(val.Field(fields[i].index), w); err != nil { + return err + } + } + w.listEnd(lh) + return nil + } + } + return writer, nil +} + +func makePtrWriter(typ reflect.Type, ts rlpstruct.Tags) (writer, error) { + nilEncoding := byte(0xC0) + if typeNilKind(typ.Elem(), ts) == String { + nilEncoding = 0x80 + } + + etypeinfo := theTC.infoWhileGenerating(typ.Elem(), rlpstruct.Tags{}) + if etypeinfo.writerErr != nil { + return nil, etypeinfo.writerErr + } + + writer := func(val reflect.Value, w *encBuffer) error { + if ev := val.Elem(); ev.IsValid() { + return etypeinfo.writer(ev, w) + } + w.str = append(w.str, nilEncoding) return nil } return writer, nil } -func makePtrWriter(typ reflect.Type) (writer, error) { - etypeinfo, err := cachedTypeInfo1(typ.Elem(), tags{}) - if err != nil { - return nil, err - } - - // determine nil pointer handler - var nilfunc func(*encbuf) error - kind := typ.Elem().Kind() - switch { - case kind == reflect.Array && isByte(typ.Elem().Elem()): - nilfunc = func(w *encbuf) error { - w.str = append(w.str, 0x80) - return nil - } - case kind == reflect.Struct || kind == reflect.Array: - nilfunc = func(w *encbuf) error { - // encoding the zero value of a struct/array could trigger - // infinite recursion, avoid that. - w.listEnd(w.list()) - return nil - } - default: - zero := reflect.Zero(typ.Elem()) - nilfunc = func(w *encbuf) error { - return etypeinfo.writer(zero, w) +func makeEncoderWriter(typ reflect.Type) writer { + if typ.Implements(encoderInterface) { + return func(val reflect.Value, w *encBuffer) error { + return val.Interface().(Encoder).EncodeRLP(w) } } - - writer := func(val reflect.Value, w *encbuf) error { - if val.IsNil() { - return nilfunc(w) - } else { - return etypeinfo.writer(val.Elem(), w) + w := func(val reflect.Value, w *encBuffer) error { + if !val.CanAddr() { + // package json simply doesn't call MarshalJSON for this case, but encodes the + // value as if it didn't implement the interface. We don't want to handle it that + // way. + return fmt.Errorf("rlp: unaddressable value of type %v, EncodeRLP is pointer method", val.Type()) } + return val.Addr().Interface().(Encoder).EncodeRLP(w) } - return writer, err + return w } // putint writes i to the beginning of b in big endian byte diff --git a/rlp/encode_test.go b/rlp/encode_test.go index 12e8f27555..5fc2d116ef 100644 --- a/rlp/encode_test.go +++ b/rlp/encode_test.go @@ -22,8 +22,12 @@ import ( "fmt" "io" "math/big" + "runtime" "sync" "testing" + + "github.com/XinFinOrg/XDPoSChain/common/math" + "github.com/holiman/uint256" ) type testEncoder struct { @@ -32,12 +36,19 @@ type testEncoder struct { func (e *testEncoder) EncodeRLP(w io.Writer) error { if e == nil { - w.Write([]byte{0, 0, 0, 0}) - } else if e.err != nil { - return e.err - } else { - w.Write([]byte{0, 1, 0, 1, 0, 1, 0, 1, 0, 1}) + panic("EncodeRLP called on nil value") } + if e.err != nil { + return e.err + } + w.Write([]byte{0, 1, 0, 1, 0, 1, 0, 1, 0, 1}) + return nil +} + +type testEncoderValueMethod struct{} + +func (e testEncoderValueMethod) EncodeRLP(w io.Writer) error { + w.Write([]byte{0xFA, 0xFE, 0xF0}) return nil } @@ -48,6 +59,13 @@ func (e byteEncoder) EncodeRLP(w io.Writer) error { return nil } +type undecodableEncoder func() + +func (f undecodableEncoder) EncodeRLP(w io.Writer) error { + w.Write([]byte{0xF5, 0xF5, 0xF5}) + return nil +} + type encodableReader struct { A, B uint } @@ -102,35 +120,95 @@ var encTests = []encTest{ {val: big.NewInt(0xFFFFFFFFFFFF), output: "86FFFFFFFFFFFF"}, {val: big.NewInt(0xFFFFFFFFFFFFFF), output: "87FFFFFFFFFFFFFF"}, { - val: big.NewInt(0).SetBytes(unhex("102030405060708090A0B0C0D0E0F2")), + val: new(big.Int).SetBytes(unhex("102030405060708090A0B0C0D0E0F2")), output: "8F102030405060708090A0B0C0D0E0F2", }, { - val: big.NewInt(0).SetBytes(unhex("0100020003000400050006000700080009000A000B000C000D000E01")), + val: new(big.Int).SetBytes(unhex("0100020003000400050006000700080009000A000B000C000D000E01")), output: "9C0100020003000400050006000700080009000A000B000C000D000E01", }, { - val: big.NewInt(0).SetBytes(unhex("010000000000000000000000000000000000000000000000000000000000000000")), + val: new(big.Int).SetBytes(unhex("010000000000000000000000000000000000000000000000000000000000000000")), output: "A1010000000000000000000000000000000000000000000000000000000000000000", }, + { + val: veryBigInt, + output: "89FFFFFFFFFFFFFFFFFF", + }, + { + val: veryVeryBigInt, + output: "B848FFFFFFFFFFFFFFFFF800000000000000001BFFFFFFFFFFFFFFFFC8000000000000000045FFFFFFFFFFFFFFFFC800000000000000001BFFFFFFFFFFFFFFFFF8000000000000000001", + }, // non-pointer big.Int {val: *big.NewInt(0), output: "80"}, {val: *big.NewInt(0xFFFFFF), output: "83FFFFFF"}, // negative ints are not supported - {val: big.NewInt(-1), error: "rlp: cannot encode negative *big.Int"}, + {val: big.NewInt(-1), error: "rlp: cannot encode negative big.Int"}, + {val: *big.NewInt(-1), error: "rlp: cannot encode negative big.Int"}, - // byte slices, strings + // uint256 + {val: uint256.NewInt(0), output: "80"}, + {val: uint256.NewInt(1), output: "01"}, + {val: uint256.NewInt(127), output: "7F"}, + {val: uint256.NewInt(128), output: "8180"}, + {val: uint256.NewInt(256), output: "820100"}, + {val: uint256.NewInt(1024), output: "820400"}, + {val: uint256.NewInt(0xFFFFFF), output: "83FFFFFF"}, + {val: uint256.NewInt(0xFFFFFFFF), output: "84FFFFFFFF"}, + {val: uint256.NewInt(0xFFFFFFFFFF), output: "85FFFFFFFFFF"}, + {val: uint256.NewInt(0xFFFFFFFFFFFF), output: "86FFFFFFFFFFFF"}, + {val: uint256.NewInt(0xFFFFFFFFFFFFFF), output: "87FFFFFFFFFFFFFF"}, + { + val: new(uint256.Int).SetBytes(unhex("102030405060708090A0B0C0D0E0F2")), + output: "8F102030405060708090A0B0C0D0E0F2", + }, + { + val: new(uint256.Int).SetBytes(unhex("0100020003000400050006000700080009000A000B000C000D000E01")), + output: "9C0100020003000400050006000700080009000A000B000C000D000E01", + }, + // non-pointer uint256.Int + {val: *uint256.NewInt(0), output: "80"}, + {val: *uint256.NewInt(0xFFFFFF), output: "83FFFFFF"}, + + // byte arrays + {val: [0]byte{}, output: "80"}, + {val: [1]byte{0}, output: "00"}, + {val: [1]byte{1}, output: "01"}, + {val: [1]byte{0x7F}, output: "7F"}, + {val: [1]byte{0x80}, output: "8180"}, + {val: [1]byte{0xFF}, output: "81FF"}, + {val: [3]byte{1, 2, 3}, output: "83010203"}, + {val: [57]byte{1, 2, 3}, output: "B839010203000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"}, + + // named byte type arrays + {val: [0]namedByteType{}, output: "80"}, + {val: [1]namedByteType{0}, output: "00"}, + {val: [1]namedByteType{1}, output: "01"}, + {val: [1]namedByteType{0x7F}, output: "7F"}, + {val: [1]namedByteType{0x80}, output: "8180"}, + {val: [1]namedByteType{0xFF}, output: "81FF"}, + {val: [3]namedByteType{1, 2, 3}, output: "83010203"}, + {val: [57]namedByteType{1, 2, 3}, output: "B839010203000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"}, + + // byte slices {val: []byte{}, output: "80"}, + {val: []byte{0}, output: "00"}, {val: []byte{0x7E}, output: "7E"}, {val: []byte{0x7F}, output: "7F"}, {val: []byte{0x80}, output: "8180"}, {val: []byte{1, 2, 3}, output: "83010203"}, + // named byte type slices + {val: []namedByteType{}, output: "80"}, + {val: []namedByteType{0}, output: "00"}, + {val: []namedByteType{0x7E}, output: "7E"}, + {val: []namedByteType{0x7F}, output: "7F"}, + {val: []namedByteType{0x80}, output: "8180"}, {val: []namedByteType{1, 2, 3}, output: "83010203"}, - {val: [...]namedByteType{1, 2, 3}, output: "83010203"}, + // strings {val: "", output: "80"}, {val: "\x7E", output: "7E"}, {val: "\x7F", output: "7F"}, @@ -203,6 +281,12 @@ var encTests = []encTest{ output: "F90200CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376", }, + // Non-byte arrays are encoded as lists. + // Note that it is important to test [4]uint64 specifically, + // because that's the underlying type of uint256.Int. + {val: [4]uint32{1, 2, 3, 4}, output: "C401020304"}, + {val: [4]uint64{1, 2, 3, 4}, output: "C401020304"}, + // RawValue {val: RawValue(unhex("01")), output: "01"}, {val: RawValue(unhex("82FFFF")), output: "82FFFF"}, @@ -213,11 +297,34 @@ var encTests = []encTest{ {val: simplestruct{A: 3, B: "foo"}, output: "C50383666F6F"}, {val: &recstruct{5, nil}, output: "C205C0"}, {val: &recstruct{5, &recstruct{4, &recstruct{3, nil}}}, output: "C605C404C203C0"}, + {val: &intField{X: 3}, error: "rlp: type int is not RLP-serializable (struct field rlp.intField.X)"}, + + // struct tag "-" + {val: &ignoredField{A: 1, B: 2, C: 3}, output: "C20103"}, + + // struct tag "tail" {val: &tailRaw{A: 1, Tail: []RawValue{unhex("02"), unhex("03")}}, output: "C3010203"}, {val: &tailRaw{A: 1, Tail: []RawValue{unhex("02")}}, output: "C20102"}, {val: &tailRaw{A: 1, Tail: []RawValue{}}, output: "C101"}, {val: &tailRaw{A: 1, Tail: nil}, output: "C101"}, - {val: &hasIgnoredField{A: 1, B: 2, C: 3}, output: "C20103"}, + + // struct tag "optional" + {val: &optionalFields{}, output: "C180"}, + {val: &optionalFields{A: 1}, output: "C101"}, + {val: &optionalFields{A: 1, B: 2}, output: "C20102"}, + {val: &optionalFields{A: 1, B: 2, C: 3}, output: "C3010203"}, + {val: &optionalFields{A: 1, B: 0, C: 3}, output: "C3018003"}, + {val: &optionalAndTailField{A: 1}, output: "C101"}, + {val: &optionalAndTailField{A: 1, B: 2}, output: "C20102"}, + {val: &optionalAndTailField{A: 1, Tail: []uint{5, 6}}, output: "C401800506"}, + {val: &optionalAndTailField{A: 1, Tail: []uint{5, 6}}, output: "C401800506"}, + {val: &optionalBigIntField{A: 1}, output: "C101"}, + {val: &optionalPtrField{A: 1}, output: "C101"}, + {val: &optionalPtrFieldNil{A: 1}, output: "C101"}, + {val: &multipleOptionalFields{A: nil, B: nil}, output: "C0"}, + {val: &multipleOptionalFields{A: &[3]byte{1, 2, 3}, B: &[3]byte{1, 2, 3}}, output: "C88301020383010203"}, + {val: &multipleOptionalFields{A: nil, B: &[3]byte{1, 2, 3}}, output: "C58083010203"}, // encodes without error but decode will fail + {val: &nonOptionalPtrField{A: 1}, output: "C20180"}, // encodes without error but decode will fail // nil {val: (*uint)(nil), output: "80"}, @@ -225,26 +332,73 @@ var encTests = []encTest{ {val: (*[]byte)(nil), output: "80"}, {val: (*[10]byte)(nil), output: "80"}, {val: (*big.Int)(nil), output: "80"}, + {val: (*uint256.Int)(nil), output: "80"}, {val: (*[]string)(nil), output: "C0"}, {val: (*[10]string)(nil), output: "C0"}, {val: (*[]interface{})(nil), output: "C0"}, {val: (*[]struct{ uint })(nil), output: "C0"}, {val: (*interface{})(nil), output: "C0"}, + // nil struct fields + { + val: struct { + X *[]byte + }{}, + output: "C180", + }, + { + val: struct { + X *[2]byte + }{}, + output: "C180", + }, + { + val: struct { + X *uint64 + }{}, + output: "C180", + }, + { + val: struct { + X *uint64 `rlp:"nilList"` + }{}, + output: "C1C0", + }, + { + val: struct { + X *[]uint64 + }{}, + output: "C1C0", + }, + { + val: struct { + X *[]uint64 `rlp:"nilString"` + }{}, + output: "C180", + }, + // interfaces {val: []io.Reader{reader}, output: "C3C20102"}, // the contained value is a struct // Encoder - {val: (*testEncoder)(nil), output: "00000000"}, + {val: (*testEncoder)(nil), output: "C0"}, {val: &testEncoder{}, output: "00010001000100010001"}, {val: &testEncoder{errors.New("test error")}, error: "test error"}, - // verify that pointer method testEncoder.EncodeRLP is called for + {val: struct{ E testEncoderValueMethod }{}, output: "C3FAFEF0"}, + {val: struct{ E *testEncoderValueMethod }{}, output: "C1C0"}, + + // Verify that the Encoder interface works for unsupported types like func(). + {val: undecodableEncoder(func() {}), output: "F5F5F5"}, + + // Verify that pointer method testEncoder.EncodeRLP is called for // addressable non-pointer values. {val: &struct{ TE testEncoder }{testEncoder{}}, output: "CA00010001000100010001"}, {val: &struct{ TE testEncoder }{testEncoder{errors.New("test error")}}, error: "test error"}, - // verify the error for non-addressable non-pointer Encoder - {val: testEncoder{}, error: "rlp: game over: unadressable value of type rlp.testEncoder, EncodeRLP is pointer method"}, - // verify the special case for []byte + + // Verify the error for non-addressable non-pointer Encoder. + {val: testEncoder{}, error: "rlp: unaddressable value of type rlp.testEncoder, EncodeRLP is pointer method"}, + + // Verify Encoder takes precedence over []byte. {val: []byteEncoder{0, 1, 2, 3, 4}, output: "C5C0C0C0C0C0"}, } @@ -280,6 +434,21 @@ func TestEncodeToBytes(t *testing.T) { runEncTests(t, EncodeToBytes) } +func TestEncodeAppendToBytes(t *testing.T) { + buffer := make([]byte, 20) + runEncTests(t, func(val interface{}) ([]byte, error) { + w := NewEncoderBuffer(nil) + defer w.Flush() + + err := Encode(w, val) + if err != nil { + return nil, err + } + output := w.AppendToBytes(buffer[:0]) + return output, nil + }) +} + func TestEncodeToReader(t *testing.T) { runEncTests(t, func(val interface{}) ([]byte, error) { _, r, err := EncodeToReader(val) @@ -338,3 +507,132 @@ func TestEncodeToReaderReturnToPool(t *testing.T) { } wg.Wait() } + +var sink interface{} + +func BenchmarkIntsize(b *testing.B) { + for i := 0; i < b.N; i++ { + sink = intsize(0x12345678) + } +} + +func BenchmarkPutint(b *testing.B) { + buf := make([]byte, 8) + for i := 0; i < b.N; i++ { + putint(buf, 0x12345678) + sink = buf + } +} + +func BenchmarkEncodeBigInts(b *testing.B) { + ints := make([]*big.Int, 200) + for i := range ints { + ints[i] = math.BigPow(2, int64(i)) + } + out := bytes.NewBuffer(make([]byte, 0, 4096)) + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + out.Reset() + if err := Encode(out, ints); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkEncodeU256Ints(b *testing.B) { + ints := make([]*uint256.Int, 200) + for i := range ints { + ints[i], _ = uint256.FromBig(math.BigPow(2, int64(i))) + } + out := bytes.NewBuffer(make([]byte, 0, 4096)) + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + out.Reset() + if err := Encode(out, ints); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkEncodeConcurrentInterface(b *testing.B) { + type struct1 struct { + A string + B *big.Int + C [20]byte + } + value := []interface{}{ + uint(999), + &struct1{A: "hello", B: big.NewInt(0xFFFFFFFF)}, + [10]byte{1, 2, 3, 4, 5, 6}, + []string{"yeah", "yeah", "yeah"}, + } + + var wg sync.WaitGroup + for cpu := 0; cpu < runtime.NumCPU(); cpu++ { + wg.Add(1) + go func() { + defer wg.Done() + + var buffer bytes.Buffer + for i := 0; i < b.N; i++ { + buffer.Reset() + err := Encode(&buffer, value) + if err != nil { + panic(err) + } + } + }() + } + wg.Wait() +} + +type byteArrayStruct struct { + A [20]byte + B [32]byte + C [32]byte +} + +func BenchmarkEncodeByteArrayStruct(b *testing.B) { + var out bytes.Buffer + var value byteArrayStruct + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + out.Reset() + if err := Encode(&out, &value); err != nil { + b.Fatal(err) + } + } +} + +type structSliceElem struct { + X uint64 + Y uint64 + Z uint64 +} + +type structPtrSlice []*structSliceElem + +func BenchmarkEncodeStructPtrSlice(b *testing.B) { + var out bytes.Buffer + var value = structPtrSlice{ + &structSliceElem{1, 1, 1}, + &structSliceElem{2, 2, 2}, + &structSliceElem{3, 3, 3}, + &structSliceElem{5, 5, 5}, + &structSliceElem{6, 6, 6}, + &structSliceElem{7, 7, 7}, + } + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + out.Reset() + if err := Encode(&out, &value); err != nil { + b.Fatal(err) + } + } +} diff --git a/rlp/encoder_example_test.go b/rlp/encoder_example_test.go index 1cffa241c2..da0465a533 100644 --- a/rlp/encoder_example_test.go +++ b/rlp/encoder_example_test.go @@ -14,11 +14,13 @@ // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see . -package rlp +package rlp_test import ( "fmt" "io" + + "github.com/XinFinOrg/XDPoSChain/rlp" ) type MyCoolType struct { @@ -28,27 +30,19 @@ type MyCoolType struct { // EncodeRLP writes x as RLP list [a, b] that omits the Name field. func (x *MyCoolType) EncodeRLP(w io.Writer) (err error) { - // Note: the receiver can be a nil pointer. This allows you to - // control the encoding of nil, but it also means that you have to - // check for a nil receiver. - if x == nil { - err = Encode(w, []uint{0, 0}) - } else { - err = Encode(w, []uint{x.a, x.b}) - } - return err + return rlp.Encode(w, []uint{x.a, x.b}) } func ExampleEncoder() { var t *MyCoolType // t is nil pointer to MyCoolType - bytes, _ := EncodeToBytes(t) + bytes, _ := rlp.EncodeToBytes(t) fmt.Printf("%v → %X\n", t, bytes) t = &MyCoolType{Name: "foobar", a: 5, b: 6} - bytes, _ = EncodeToBytes(t) + bytes, _ = rlp.EncodeToBytes(t) fmt.Printf("%v → %X\n", t, bytes) // Output: - // → C28080 + // → C0 // &{foobar 5 6} → C20506 } diff --git a/rlp/internal/rlpstruct/rlpstruct.go b/rlp/internal/rlpstruct/rlpstruct.go new file mode 100644 index 0000000000..2e3eeb6881 --- /dev/null +++ b/rlp/internal/rlpstruct/rlpstruct.go @@ -0,0 +1,213 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +// Package rlpstruct implements struct processing for RLP encoding/decoding. +// +// In particular, this package handles all rules around field filtering, +// struct tags and nil value determination. +package rlpstruct + +import ( + "fmt" + "reflect" + "strings" +) + +// Field represents a struct field. +type Field struct { + Name string + Index int + Exported bool + Type Type + Tag string +} + +// Type represents the attributes of a Go type. +type Type struct { + Name string + Kind reflect.Kind + IsEncoder bool // whether type implements rlp.Encoder + IsDecoder bool // whether type implements rlp.Decoder + Elem *Type // non-nil for Kind values of Ptr, Slice, Array +} + +// DefaultNilValue determines whether a nil pointer to t encodes/decodes +// as an empty string or empty list. +func (t Type) DefaultNilValue() NilKind { + k := t.Kind + if isUint(k) || k == reflect.String || k == reflect.Bool || isByteArray(t) { + return NilKindString + } + return NilKindList +} + +// NilKind is the RLP value encoded in place of nil pointers. +type NilKind uint8 + +const ( + NilKindString NilKind = 0x80 + NilKindList NilKind = 0xC0 +) + +// Tags represents struct tags. +type Tags struct { + // rlp:"nil" controls whether empty input results in a nil pointer. + // nilKind is the kind of empty value allowed for the field. + NilKind NilKind + NilOK bool + + // rlp:"optional" allows for a field to be missing in the input list. + // If this is set, all subsequent fields must also be optional. + Optional bool + + // rlp:"tail" controls whether this field swallows additional list elements. It can + // only be set for the last field, which must be of slice type. + Tail bool + + // rlp:"-" ignores fields. + Ignored bool +} + +// TagError is raised for invalid struct tags. +type TagError struct { + StructType string + + // These are set by this package. + Field string + Tag string + Err string +} + +func (e TagError) Error() string { + field := "field " + e.Field + if e.StructType != "" { + field = e.StructType + "." + e.Field + } + return fmt.Sprintf("rlp: invalid struct tag %q for %s (%s)", e.Tag, field, e.Err) +} + +// ProcessFields filters the given struct fields, returning only fields +// that should be considered for encoding/decoding. +func ProcessFields(allFields []Field) ([]Field, []Tags, error) { + lastPublic := lastPublicField(allFields) + + // Gather all exported fields and their tags. + var fields []Field + var tags []Tags + for _, field := range allFields { + if !field.Exported { + continue + } + ts, err := parseTag(field, lastPublic) + if err != nil { + return nil, nil, err + } + if ts.Ignored { + continue + } + fields = append(fields, field) + tags = append(tags, ts) + } + + // Verify optional field consistency. If any optional field exists, + // all fields after it must also be optional. Note: optional + tail + // is supported. + var anyOptional bool + var firstOptionalName string + for i, ts := range tags { + name := fields[i].Name + if ts.Optional || ts.Tail { + if !anyOptional { + firstOptionalName = name + } + anyOptional = true + } else { + if anyOptional { + msg := fmt.Sprintf("must be optional because preceding field %q is optional", firstOptionalName) + return nil, nil, TagError{Field: name, Err: msg} + } + } + } + return fields, tags, nil +} + +func parseTag(field Field, lastPublic int) (Tags, error) { + name := field.Name + tag := reflect.StructTag(field.Tag) + var ts Tags + for _, t := range strings.Split(tag.Get("rlp"), ",") { + switch t = strings.TrimSpace(t); t { + case "": + // empty tag is allowed for some reason + case "-": + ts.Ignored = true + case "nil", "nilString", "nilList": + ts.NilOK = true + if field.Type.Kind != reflect.Ptr { + return ts, TagError{Field: name, Tag: t, Err: "field is not a pointer"} + } + switch t { + case "nil": + ts.NilKind = field.Type.Elem.DefaultNilValue() + case "nilString": + ts.NilKind = NilKindString + case "nilList": + ts.NilKind = NilKindList + } + case "optional": + ts.Optional = true + if ts.Tail { + return ts, TagError{Field: name, Tag: t, Err: `also has "tail" tag`} + } + case "tail": + ts.Tail = true + if field.Index != lastPublic { + return ts, TagError{Field: name, Tag: t, Err: "must be on last field"} + } + if ts.Optional { + return ts, TagError{Field: name, Tag: t, Err: `also has "optional" tag`} + } + if field.Type.Kind != reflect.Slice { + return ts, TagError{Field: name, Tag: t, Err: "field type is not slice"} + } + default: + return ts, TagError{Field: name, Tag: t, Err: "unknown tag"} + } + } + return ts, nil +} + +func lastPublicField(fields []Field) int { + last := 0 + for _, f := range fields { + if f.Exported { + last = f.Index + } + } + return last +} + +func isUint(k reflect.Kind) bool { + return k >= reflect.Uint && k <= reflect.Uintptr +} + +func isByte(typ Type) bool { + return typ.Kind == reflect.Uint8 && !typ.IsEncoder +} + +func isByteArray(typ Type) bool { + return (typ.Kind == reflect.Slice || typ.Kind == reflect.Array) && isByte(*typ.Elem) +} diff --git a/rlp/iterator.go b/rlp/iterator.go new file mode 100644 index 0000000000..95bd3f2582 --- /dev/null +++ b/rlp/iterator.go @@ -0,0 +1,59 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package rlp + +type listIterator struct { + data []byte + next []byte + err error +} + +// NewListIterator creates an iterator for the (list) represented by data +func NewListIterator(data RawValue) (*listIterator, error) { + k, t, c, err := readKind(data) + if err != nil { + return nil, err + } + if k != List { + return nil, ErrExpectedList + } + it := &listIterator{ + data: data[t : t+c], + } + return it, nil +} + +// Next forwards the iterator one step, returns true if it was not at end yet +func (it *listIterator) Next() bool { + if len(it.data) == 0 { + return false + } + _, t, c, err := readKind(it.data) + it.next = it.data[:t+c] + it.data = it.data[t+c:] + it.err = err + return true +} + +// Value returns the current value +func (it *listIterator) Value() []byte { + return it.next +} + +func (it *listIterator) Err() error { + return it.err +} diff --git a/rlp/iterator_test.go b/rlp/iterator_test.go new file mode 100644 index 0000000000..82ac7bfa6e --- /dev/null +++ b/rlp/iterator_test.go @@ -0,0 +1,59 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package rlp + +import ( + "testing" + + "github.com/XinFinOrg/XDPoSChain/common/hexutil" +) + +// TestIterator tests some basic things about the ListIterator. A more +// comprehensive test can be found in core/rlp_test.go, where we can +// use both types and rlp without dependency cycles +func TestIterator(t *testing.T) { + bodyRlpHex := "0xf902cbf8d6f869800182c35094000000000000000000000000000000000000aaaa808a000000000000000000001ba01025c66fad28b4ce3370222624d952c35529e602af7cbe04f667371f61b0e3b3a00ab8813514d1217059748fd903288ace1b4001a4bc5fbde2790debdc8167de2ff869010182c35094000000000000000000000000000000000000aaaa808a000000000000000000001ca05ac4cf1d19be06f3742c21df6c49a7e929ceb3dbaf6a09f3cfb56ff6828bd9a7a06875970133a35e63ac06d360aa166d228cc013e9b96e0a2cae7f55b22e1ee2e8f901f0f901eda0c75448377c0e426b8017b23c5f77379ecf69abc1d5c224284ad3ba1c46c59adaa00000000000000000000000000000000000000000000000000000000000000000940000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000b9010000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000808080808080a00000000000000000000000000000000000000000000000000000000000000000880000000000000000" + bodyRlp := hexutil.MustDecode(bodyRlpHex) + + it, err := NewListIterator(bodyRlp) + if err != nil { + t.Fatal(err) + } + // Check that txs exist + if !it.Next() { + t.Fatal("expected two elems, got zero") + } + txs := it.Value() + // Check that uncles exist + if !it.Next() { + t.Fatal("expected two elems, got one") + } + txit, err := NewListIterator(txs) + if err != nil { + t.Fatal(err) + } + var i = 0 + for txit.Next() { + if txit.err != nil { + t.Fatal(txit.err) + } + i++ + } + if exp := 2; i != exp { + t.Errorf("count wrong, expected %d got %d", i, exp) + } +} diff --git a/rlp/raw.go b/rlp/raw.go index 2b3f328f66..773aa7e614 100644 --- a/rlp/raw.go +++ b/rlp/raw.go @@ -28,12 +28,53 @@ type RawValue []byte var rawValueType = reflect.TypeOf(RawValue{}) +// StringSize returns the encoded size of a string. +func StringSize(s string) uint64 { + switch { + case len(s) == 0: + return 1 + case len(s) == 1: + if s[0] <= 0x7f { + return 1 + } else { + return 2 + } + default: + return uint64(headsize(uint64(len(s))) + len(s)) + } +} + +// BytesSize returns the encoded size of a byte slice. +func BytesSize(b []byte) uint64 { + switch { + case len(b) == 0: + return 1 + case len(b) == 1: + if b[0] <= 0x7f { + return 1 + } else { + return 2 + } + default: + return uint64(headsize(uint64(len(b))) + len(b)) + } +} + // ListSize returns the encoded size of an RLP list with the given // content size. func ListSize(contentSize uint64) uint64 { return uint64(headsize(contentSize)) + contentSize } +// IntSize returns the encoded size of the integer x. Note: The return type of this +// function is 'int' for backwards-compatibility reasons. The result is always positive. +func IntSize(x uint64) int { + if x < 0x80 { + return 1 + } + return 1 + intsize(x) +} + // Split returns the content of first RLP value and any // bytes after the value as subslices of b. func Split(b []byte) (k Kind, content, rest []byte, err error) { @@ -57,6 +98,32 @@ func SplitString(b []byte) (content, rest []byte, err error) { return content, rest, nil } +// SplitUint64 decodes an integer at the beginning of b. +// It also returns the remaining data after the integer in 'rest'. +func SplitUint64(b []byte) (x uint64, rest []byte, err error) { + content, rest, err := SplitString(b) + if err != nil { + return 0, b, err + } + switch { + case len(content) == 0: + return 0, rest, nil + case len(content) == 1: + if content[0] == 0 { + return 0, b, ErrCanonInt + } + return uint64(content[0]), rest, nil + case len(content) > 8: + return 0, b, errUintOverflow + default: + x, err = readSize(content, byte(len(content))) + if err != nil { + return 0, b, ErrCanonInt + } + return x, rest, nil + } +} + // SplitList splits b into the content of a list and any remaining // bytes after the list. func SplitList(b []byte) (content, rest []byte, err error) { @@ -154,3 +221,74 @@ func readSize(b []byte, slen byte) (uint64, error) { } return s, nil } + +// AppendUint64 appends the RLP encoding of i to b, and returns the resulting slice. +func AppendUint64(b []byte, i uint64) []byte { + if i == 0 { + return append(b, 0x80) + } else if i < 128 { + return append(b, byte(i)) + } + switch { + case i < (1 << 8): + return append(b, 0x81, byte(i)) + case i < (1 << 16): + return append(b, 0x82, + byte(i>>8), + byte(i), + ) + case i < (1 << 24): + return append(b, 0x83, + byte(i>>16), + byte(i>>8), + byte(i), + ) + case i < (1 << 32): + return append(b, 0x84, + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + case i < (1 << 40): + return append(b, 0x85, + byte(i>>32), + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + + case i < (1 << 48): + return append(b, 0x86, + byte(i>>40), + byte(i>>32), + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + case i < (1 << 56): + return append(b, 0x87, + byte(i>>48), + byte(i>>40), + byte(i>>32), + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + + default: + return append(b, 0x88, + byte(i>>56), + byte(i>>48), + byte(i>>40), + byte(i>>32), + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + } +} diff --git a/rlp/raw_test.go b/rlp/raw_test.go index 2aad042100..7b3255eca3 100644 --- a/rlp/raw_test.go +++ b/rlp/raw_test.go @@ -18,9 +18,10 @@ package rlp import ( "bytes" + "errors" "io" - "reflect" "testing" + "testing/quick" ) func TestCountValues(t *testing.T) { @@ -53,21 +54,84 @@ func TestCountValues(t *testing.T) { if count != test.count { t.Errorf("test %d: count mismatch, got %d want %d\ninput: %s", i, count, test.count, test.input) } - if !reflect.DeepEqual(err, test.err) { + if !errors.Is(err, test.err) { t.Errorf("test %d: err mismatch, got %q want %q\ninput: %s", i, err, test.err, test.input) } } } -func TestSplitTypes(t *testing.T) { - if _, _, err := SplitString(unhex("C100")); err != ErrExpectedString { - t.Errorf("SplitString returned %q, want %q", err, ErrExpectedString) +func TestSplitString(t *testing.T) { + for i, test := range []string{ + "C0", + "C100", + "C3010203", + "C88363617483646F67", + "F8384C6F72656D20697073756D20646F6C6F722073697420616D65742C20636F6E7365637465747572206164697069736963696E6720656C6974", + } { + if _, _, err := SplitString(unhex(test)); !errors.Is(err, ErrExpectedString) { + t.Errorf("test %d: error mismatch: have %q, want %q", i, err, ErrExpectedString) + } } - if _, _, err := SplitList(unhex("01")); err != ErrExpectedList { - t.Errorf("SplitString returned %q, want %q", err, ErrExpectedList) +} + +func TestSplitList(t *testing.T) { + for i, test := range []string{ + "80", + "00", + "01", + "8180", + "81FF", + "820400", + "83636174", + "83646F67", + "B8384C6F72656D20697073756D20646F6C6F722073697420616D65742C20636F6E7365637465747572206164697069736963696E6720656C6974", + } { + if _, _, err := SplitList(unhex(test)); !errors.Is(err, ErrExpectedList) { + t.Errorf("test %d: error mismatch: have %q, want %q", i, err, ErrExpectedList) + } } - if _, _, err := SplitList(unhex("81FF")); err != ErrExpectedList { - t.Errorf("SplitString returned %q, want %q", err, ErrExpectedList) +} + +func TestSplitUint64(t *testing.T) { + tests := []struct { + input string + val uint64 + rest string + err error + }{ + {"01", 1, "", nil}, + {"7FFF", 0x7F, "FF", nil}, + {"80FF", 0, "FF", nil}, + {"81FAFF", 0xFA, "FF", nil}, + {"82FAFAFF", 0xFAFA, "FF", nil}, + {"83FAFAFAFF", 0xFAFAFA, "FF", nil}, + {"84FAFAFAFAFF", 0xFAFAFAFA, "FF", nil}, + {"85FAFAFAFAFAFF", 0xFAFAFAFAFA, "FF", nil}, + {"86FAFAFAFAFAFAFF", 0xFAFAFAFAFAFA, "FF", nil}, + {"87FAFAFAFAFAFAFAFF", 0xFAFAFAFAFAFAFA, "FF", nil}, + {"88FAFAFAFAFAFAFAFAFF", 0xFAFAFAFAFAFAFAFA, "FF", nil}, + + // errors + {"", 0, "", io.ErrUnexpectedEOF}, + {"00", 0, "00", ErrCanonInt}, + {"81", 0, "81", ErrValueTooLarge}, + {"8100", 0, "8100", ErrCanonSize}, + {"8200FF", 0, "8200FF", ErrCanonInt}, + {"8103FF", 0, "8103FF", ErrCanonSize}, + {"89FAFAFAFAFAFAFAFAFAFF", 0, "89FAFAFAFAFAFAFAFAFAFF", errUintOverflow}, + } + + for i, test := range tests { + val, rest, err := SplitUint64(unhex(test.input)) + if val != test.val { + t.Errorf("test %d: val mismatch: got %x, want %x (input %q)", i, val, test.val, test.input) + } + if !bytes.Equal(rest, unhex(test.rest)) { + t.Errorf("test %d: rest mismatch: got %x, want %s (input %q)", i, rest, test.rest, test.input) + } + if err != test.err { + t.Errorf("test %d: error mismatch: got %q, want %q", i, err, test.err) + } } } @@ -78,7 +142,9 @@ func TestSplit(t *testing.T) { val, rest string err error }{ + {input: "00FFFF", kind: Byte, val: "00", rest: "FFFF"}, {input: "01FFFF", kind: Byte, val: "01", rest: "FFFF"}, + {input: "7FFFFF", kind: Byte, val: "7F", rest: "FFFF"}, {input: "80FFFF", kind: String, val: "", rest: "FFFF"}, {input: "C3010203", kind: List, val: "010203"}, @@ -194,3 +260,79 @@ func TestReadSize(t *testing.T) { } } } + +func TestAppendUint64(t *testing.T) { + tests := []struct { + input uint64 + slice []byte + output string + }{ + {0, nil, "80"}, + {1, nil, "01"}, + {2, nil, "02"}, + {127, nil, "7F"}, + {128, nil, "8180"}, + {129, nil, "8181"}, + {0xFFFFFF, nil, "83FFFFFF"}, + {127, []byte{1, 2, 3}, "0102037F"}, + {0xFFFFFF, []byte{1, 2, 3}, "01020383FFFFFF"}, + } + + for _, test := range tests { + x := AppendUint64(test.slice, test.input) + if !bytes.Equal(x, unhex(test.output)) { + t.Errorf("AppendUint64(%v, %d): got %x, want %s", test.slice, test.input, x, test.output) + } + + // Check that IntSize returns the appended size. + length := len(x) - len(test.slice) + if s := IntSize(test.input); s != length { + t.Errorf("IntSize(%d): got %d, want %d", test.input, s, length) + } + } +} + +func TestAppendUint64Random(t *testing.T) { + fn := func(i uint64) bool { + enc, _ := EncodeToBytes(i) + encAppend := AppendUint64(nil, i) + return bytes.Equal(enc, encAppend) + } + config := quick.Config{MaxCountScale: 50} + if err := quick.Check(fn, &config); err != nil { + t.Fatal(err) + } +} + +func TestBytesSize(t *testing.T) { + tests := []struct { + v []byte + size uint64 + }{ + {v: []byte{}, size: 1}, + {v: []byte{0x1}, size: 1}, + {v: []byte{0x7E}, size: 1}, + {v: []byte{0x7F}, size: 1}, + {v: []byte{0x80}, size: 2}, + {v: []byte{0xFF}, size: 2}, + {v: []byte{0xFF, 0xF0}, size: 3}, + {v: make([]byte, 55), size: 56}, + {v: make([]byte, 56), size: 58}, + } + + for _, test := range tests { + s := BytesSize(test.v) + if s != test.size { + t.Errorf("BytesSize(%#x) -> %d, want %d", test.v, s, test.size) + } + s = StringSize(string(test.v)) + if s != test.size { + t.Errorf("StringSize(%#x) -> %d, want %d", test.v, s, test.size) + } + // Sanity check: + enc, _ := EncodeToBytes(test.v) + if uint64(len(enc)) != test.size { + t.Errorf("len(EncodeToBytes(%#x)) -> %d, test says %d", test.v, len(enc), test.size) + } + } +} diff --git a/rlp/rlpgen/gen.go b/rlp/rlpgen/gen.go new file mode 100644 index 0000000000..ed502c09a7 --- /dev/null +++ b/rlp/rlpgen/gen.go @@ -0,0 +1,800 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package main + +import ( + "bytes" + "fmt" + "go/format" + "go/types" + "sort" + + "github.com/XinFinOrg/XDPoSChain/rlp/internal/rlpstruct" +) + +// buildContext keeps the data needed for make*Op. +type buildContext struct { + topType *types.Named // the type we're creating methods for + + encoderIface *types.Interface + decoderIface *types.Interface + rawValueType *types.Named + + typeToStructCache map[types.Type]*rlpstruct.Type +} + +func newBuildContext(packageRLP *types.Package) *buildContext { + enc := packageRLP.Scope().Lookup("Encoder").Type().Underlying() + dec := packageRLP.Scope().Lookup("Decoder").Type().Underlying() + rawv := packageRLP.Scope().Lookup("RawValue").Type() + return &buildContext{ + typeToStructCache: make(map[types.Type]*rlpstruct.Type), + encoderIface: enc.(*types.Interface), + decoderIface: dec.(*types.Interface), + rawValueType: rawv.(*types.Named), + } +} + +func (bctx *buildContext) isEncoder(typ types.Type) bool { + return types.Implements(typ, bctx.encoderIface) +} + +func (bctx *buildContext) isDecoder(typ types.Type) bool { + return types.Implements(typ, bctx.decoderIface) +} + +// typeToStructType converts typ to rlpstruct.Type. +func (bctx *buildContext) typeToStructType(typ types.Type) *rlpstruct.Type { + if prev := bctx.typeToStructCache[typ]; prev != nil { + return prev // short-circuit for recursive types. + } + + // Resolve named types to their underlying type, but keep the name. + name := types.TypeString(typ, nil) + for { + utype := typ.Underlying() + if utype == typ { + break + } + typ = utype + } + + // Create the type and store it in cache. + t := &rlpstruct.Type{ + Name: name, + Kind: typeReflectKind(typ), + IsEncoder: bctx.isEncoder(typ), + IsDecoder: bctx.isDecoder(typ), + } + bctx.typeToStructCache[typ] = t + + // Assign element type. + switch typ.(type) { + case *types.Array, *types.Slice, *types.Pointer: + etype := typ.(interface{ Elem() types.Type }).Elem() + t.Elem = bctx.typeToStructType(etype) + } + return t +} + +// genContext is passed to the gen* methods of op when generating +// the output code. It tracks packages to be imported by the output +// file and assigns unique names of temporary variables. +type genContext struct { + inPackage *types.Package + imports map[string]struct{} + tempCounter int +} + +func newGenContext(inPackage *types.Package) *genContext { + return &genContext{ + inPackage: inPackage, + imports: make(map[string]struct{}), + } +} + +func (ctx *genContext) temp() string { + v := fmt.Sprintf("_tmp%d", ctx.tempCounter) + ctx.tempCounter++ + return v +} + +func (ctx *genContext) resetTemp() { + ctx.tempCounter = 0 +} + +func (ctx *genContext) addImport(path string) { + if path == ctx.inPackage.Path() { + return // avoid importing the package that we're generating in. + } + // TODO: renaming? + ctx.imports[path] = struct{}{} +} + +// importsList returns all packages that need to be imported. +func (ctx *genContext) importsList() []string { + imp := make([]string, 0, len(ctx.imports)) + for k := range ctx.imports { + imp = append(imp, k) + } + sort.Strings(imp) + return imp +} + +// qualify is the types.Qualifier used for printing types. +func (ctx *genContext) qualify(pkg *types.Package) string { + if pkg.Path() == ctx.inPackage.Path() { + return "" + } + ctx.addImport(pkg.Path()) + // TODO: renaming? + return pkg.Name() +} + +type op interface { + // genWrite creates the encoder. The generated code should write v, + // which is any Go expression, to the rlp.EncoderBuffer 'w'. + genWrite(ctx *genContext, v string) string + + // genDecode creates the decoder. The generated code should read + // a value from the rlp.Stream 'dec' and store it to dst. + genDecode(ctx *genContext) (string, string) +} + +// basicOp handles basic types bool, uint*, string. +type basicOp struct { + typ types.Type + writeMethod string // EncoderBuffer writer method name + writeArgType types.Type // parameter type of writeMethod + decMethod string + decResultType types.Type // return type of decMethod + decUseBitSize bool // if true, result bit size is appended to decMethod +} + +func (*buildContext) makeBasicOp(typ *types.Basic) (op, error) { + op := basicOp{typ: typ} + kind := typ.Kind() + switch { + case kind == types.Bool: + op.writeMethod = "WriteBool" + op.writeArgType = types.Typ[types.Bool] + op.decMethod = "Bool" + op.decResultType = types.Typ[types.Bool] + case kind >= types.Uint8 && kind <= types.Uint64: + op.writeMethod = "WriteUint64" + op.writeArgType = types.Typ[types.Uint64] + op.decMethod = "Uint" + op.decResultType = typ + op.decUseBitSize = true + case kind == types.String: + op.writeMethod = "WriteString" + op.writeArgType = types.Typ[types.String] + op.decMethod = "String" + op.decResultType = types.Typ[types.String] + default: + return nil, fmt.Errorf("unhandled basic type: %v", typ) + } + return op, nil +} + +func (*buildContext) makeByteSliceOp(typ *types.Slice) op { + if !isByte(typ.Elem()) { + panic("non-byte slice type in makeByteSliceOp") + } + bslice := types.NewSlice(types.Typ[types.Uint8]) + return basicOp{ + typ: typ, + writeMethod: "WriteBytes", + writeArgType: bslice, + decMethod: "Bytes", + decResultType: bslice, + } +} + +func (bctx *buildContext) makeRawValueOp() op { + bslice := types.NewSlice(types.Typ[types.Uint8]) + return basicOp{ + typ: bctx.rawValueType, + writeMethod: "Write", + writeArgType: bslice, + decMethod: "Raw", + decResultType: bslice, + } +} + +func (op basicOp) writeNeedsConversion() bool { + return !types.AssignableTo(op.typ, op.writeArgType) +} + +func (op basicOp) decodeNeedsConversion() bool { + return !types.AssignableTo(op.decResultType, op.typ) +} + +func (op basicOp) genWrite(ctx *genContext, v string) string { + if op.writeNeedsConversion() { + v = fmt.Sprintf("%s(%s)", op.writeArgType, v) + } + return fmt.Sprintf("w.%s(%s)\n", op.writeMethod, v) +} + +func (op basicOp) genDecode(ctx *genContext) (string, string) { + var ( + resultV = ctx.temp() + result = resultV + method = op.decMethod + ) + if op.decUseBitSize { + // Note: For now, this only works for platform-independent integer + // sizes. makeBasicOp forbids the platform-dependent types. + var sizes types.StdSizes + method = fmt.Sprintf("%s%d", op.decMethod, sizes.Sizeof(op.typ)*8) + } + + // Call the decoder method. + var b bytes.Buffer + fmt.Fprintf(&b, "%s, err := dec.%s()\n", resultV, method) + fmt.Fprintf(&b, "if err != nil { return err }\n") + if op.decodeNeedsConversion() { + conv := ctx.temp() + fmt.Fprintf(&b, "%s := %s(%s)\n", conv, types.TypeString(op.typ, ctx.qualify), resultV) + result = conv + } + return result, b.String() +} + +// byteArrayOp handles [...]byte. +type byteArrayOp struct { + typ types.Type + name types.Type // name != typ for named byte array types (e.g. common.Address) +} + +func (bctx *buildContext) makeByteArrayOp(name *types.Named, typ *types.Array) byteArrayOp { + nt := types.Type(name) + if name == nil { + nt = typ + } + return byteArrayOp{typ, nt} +} + +func (op byteArrayOp) genWrite(ctx *genContext, v string) string { + return fmt.Sprintf("w.WriteBytes(%s[:])\n", v) +} + +func (op byteArrayOp) genDecode(ctx *genContext) (string, string) { + var resultV = ctx.temp() + + var b bytes.Buffer + fmt.Fprintf(&b, "var %s %s\n", resultV, types.TypeString(op.name, ctx.qualify)) + fmt.Fprintf(&b, "if err := dec.ReadBytes(%s[:]); err != nil { return err }\n", resultV) + return resultV, b.String() +} + +// bigIntOp handles big.Int. +// This exists because big.Int has it's own decoder operation on rlp.Stream, +// but the decode method returns *big.Int, so it needs to be dereferenced. +type bigIntOp struct { + pointer bool +} + +func (op bigIntOp) genWrite(ctx *genContext, v string) string { + var b bytes.Buffer + + fmt.Fprintf(&b, "if %s.Sign() == -1 {\n", v) + fmt.Fprintf(&b, " return rlp.ErrNegativeBigInt\n") + fmt.Fprintf(&b, "}\n") + dst := v + if !op.pointer { + dst = "&" + v + } + fmt.Fprintf(&b, "w.WriteBigInt(%s)\n", dst) + + // Wrap with nil check. + if op.pointer { + code := b.String() + b.Reset() + fmt.Fprintf(&b, "if %s == nil {\n", v) + fmt.Fprintf(&b, " w.Write(rlp.EmptyString)") + fmt.Fprintf(&b, "} else {\n") + fmt.Fprint(&b, code) + fmt.Fprintf(&b, "}\n") + } + + return b.String() +} + +func (op bigIntOp) genDecode(ctx *genContext) (string, string) { + var resultV = ctx.temp() + + var b bytes.Buffer + fmt.Fprintf(&b, "%s, err := dec.BigInt()\n", resultV) + fmt.Fprintf(&b, "if err != nil { return err }\n") + + result := resultV + if !op.pointer { + result = "(*" + resultV + ")" + } + return result, b.String() +} + +// uint256Op handles "github.com/holiman/uint256".Int +type uint256Op struct { + pointer bool +} + +func (op uint256Op) genWrite(ctx *genContext, v string) string { + var b bytes.Buffer + + dst := v + if !op.pointer { + dst = "&" + v + } + fmt.Fprintf(&b, "w.WriteUint256(%s)\n", dst) + + // Wrap with nil check. + if op.pointer { + code := b.String() + b.Reset() + fmt.Fprintf(&b, "if %s == nil {\n", v) + fmt.Fprintf(&b, " w.Write(rlp.EmptyString)") + fmt.Fprintf(&b, "} else {\n") + fmt.Fprint(&b, code) + fmt.Fprintf(&b, "}\n") + } + + return b.String() +} + +func (op uint256Op) genDecode(ctx *genContext) (string, string) { + ctx.addImport("github.com/holiman/uint256") + + var b bytes.Buffer + resultV := ctx.temp() + fmt.Fprintf(&b, "var %s uint256.Int\n", resultV) + fmt.Fprintf(&b, "if err := dec.ReadUint256(&%s); err != nil { return err }\n", resultV) + + result := resultV + if op.pointer { + result = "&" + resultV + } + return result, b.String() +} + +// encoderDecoderOp handles rlp.Encoder and rlp.Decoder. +// In order to be used with this, the type must implement both interfaces. +// This restriction may be lifted in the future by creating separate ops for +// encoding and decoding. +type encoderDecoderOp struct { + typ types.Type +} + +func (op encoderDecoderOp) genWrite(ctx *genContext, v string) string { + return fmt.Sprintf("if err := %s.EncodeRLP(w); err != nil { return err }\n", v) +} + +func (op encoderDecoderOp) genDecode(ctx *genContext) (string, string) { + // DecodeRLP must have pointer receiver, and this is verified in makeOp. + etyp := op.typ.(*types.Pointer).Elem() + var resultV = ctx.temp() + + var b bytes.Buffer + fmt.Fprintf(&b, "%s := new(%s)\n", resultV, types.TypeString(etyp, ctx.qualify)) + fmt.Fprintf(&b, "if err := %s.DecodeRLP(dec); err != nil { return err }\n", resultV) + return resultV, b.String() +} + +// ptrOp handles pointer types. +type ptrOp struct { + elemTyp types.Type + elem op + nilOK bool + nilValue rlpstruct.NilKind +} + +func (bctx *buildContext) makePtrOp(elemTyp types.Type, tags rlpstruct.Tags) (op, error) { + elemOp, err := bctx.makeOp(nil, elemTyp, rlpstruct.Tags{}) + if err != nil { + return nil, err + } + op := ptrOp{elemTyp: elemTyp, elem: elemOp} + + // Determine nil value. + if tags.NilOK { + op.nilOK = true + op.nilValue = tags.NilKind + } else { + styp := bctx.typeToStructType(elemTyp) + op.nilValue = styp.DefaultNilValue() + } + return op, nil +} + +func (op ptrOp) genWrite(ctx *genContext, v string) string { + // Note: in writer functions, accesses to v are read-only, i.e. v is any Go + // expression. To make all accesses work through the pointer, we substitute + // v with (*v). This is required for most accesses including `v`, `call(v)`, + // and `v[index]` on slices. + // + // For `v.field` and `v[:]` on arrays, the dereference operation is not required. + var vv string + _, isStruct := op.elem.(structOp) + _, isByteArray := op.elem.(byteArrayOp) + if isStruct || isByteArray { + vv = v + } else { + vv = fmt.Sprintf("(*%s)", v) + } + + var b bytes.Buffer + fmt.Fprintf(&b, "if %s == nil {\n", v) + fmt.Fprintf(&b, " w.Write([]byte{0x%X})\n", op.nilValue) + fmt.Fprintf(&b, "} else {\n") + fmt.Fprintf(&b, " %s", op.elem.genWrite(ctx, vv)) + fmt.Fprintf(&b, "}\n") + return b.String() +} + +func (op ptrOp) genDecode(ctx *genContext) (string, string) { + result, code := op.elem.genDecode(ctx) + if !op.nilOK { + // If nil pointers are not allowed, we can just decode the element. + return "&" + result, code + } + + // nil is allowed, so check the kind and size first. + // If size is zero and kind matches the nilKind of the type, + // the value decodes as a nil pointer. + var ( + resultV = ctx.temp() + kindV = ctx.temp() + sizeV = ctx.temp() + wantKind string + ) + if op.nilValue == rlpstruct.NilKindList { + wantKind = "rlp.List" + } else { + wantKind = "rlp.String" + } + var b bytes.Buffer + fmt.Fprintf(&b, "var %s %s\n", resultV, types.TypeString(types.NewPointer(op.elemTyp), ctx.qualify)) + fmt.Fprintf(&b, "if %s, %s, err := dec.Kind(); err != nil {\n", kindV, sizeV) + fmt.Fprintf(&b, " return err\n") + fmt.Fprintf(&b, "} else if %s != 0 || %s != %s {\n", sizeV, kindV, wantKind) + fmt.Fprint(&b, code) + fmt.Fprintf(&b, " %s = &%s\n", resultV, result) + fmt.Fprintf(&b, "}\n") + return resultV, b.String() +} + +// structOp handles struct types. +type structOp struct { + named *types.Named + typ *types.Struct + fields []*structField + optionalFields []*structField +} + +type structField struct { + name string + typ types.Type + elem op +} + +func (bctx *buildContext) makeStructOp(named *types.Named, typ *types.Struct) (op, error) { + // Convert fields to []rlpstruct.Field. + var allStructFields []rlpstruct.Field + for i := 0; i < typ.NumFields(); i++ { + f := typ.Field(i) + allStructFields = append(allStructFields, rlpstruct.Field{ + Name: f.Name(), + Exported: f.Exported(), + Index: i, + Tag: typ.Tag(i), + Type: *bctx.typeToStructType(f.Type()), + }) + } + + // Filter/validate fields. + fields, tags, err := rlpstruct.ProcessFields(allStructFields) + if err != nil { + return nil, err + } + + // Create field ops. + var op = structOp{named: named, typ: typ} + for i, field := range fields { + // Advanced struct tags are not supported yet. + tag := tags[i] + if err := checkUnsupportedTags(field.Name, tag); err != nil { + return nil, err + } + typ := typ.Field(field.Index).Type() + elem, err := bctx.makeOp(nil, typ, tags[i]) + if err != nil { + return nil, fmt.Errorf("field %s: %v", field.Name, err) + } + f := &structField{name: field.Name, typ: typ, elem: elem} + if tag.Optional { + op.optionalFields = append(op.optionalFields, f) + } else { + op.fields = append(op.fields, f) + } + } + return op, nil +} + +func checkUnsupportedTags(field string, tag rlpstruct.Tags) error { + if tag.Tail { + return fmt.Errorf(`field %s has unsupported struct tag "tail"`, field) + } + return nil +} + +func (op structOp) genWrite(ctx *genContext, v string) string { + var b bytes.Buffer + var listMarker = ctx.temp() + fmt.Fprintf(&b, "%s := w.List()\n", listMarker) + for _, field := range op.fields { + selector := v + "." + field.name + fmt.Fprint(&b, field.elem.genWrite(ctx, selector)) + } + op.writeOptionalFields(&b, ctx, v) + fmt.Fprintf(&b, "w.ListEnd(%s)\n", listMarker) + return b.String() +} + +func (op structOp) writeOptionalFields(b *bytes.Buffer, ctx *genContext, v string) { + if len(op.optionalFields) == 0 { + return + } + // First check zero-ness of all optional fields. + var zeroV = make([]string, len(op.optionalFields)) + for i, field := range op.optionalFields { + selector := v + "." + field.name + zeroV[i] = ctx.temp() + fmt.Fprintf(b, "%s := %s\n", zeroV[i], nonZeroCheck(selector, field.typ, ctx.qualify)) + } + // Now write the fields. + for i, field := range op.optionalFields { + selector := v + "." + field.name + cond := "" + for j := i; j < len(op.optionalFields); j++ { + if j > i { + cond += " || " + } + cond += zeroV[j] + } + fmt.Fprintf(b, "if %s {\n", cond) + fmt.Fprint(b, field.elem.genWrite(ctx, selector)) + fmt.Fprintf(b, "}\n") + } +} + +func (op structOp) genDecode(ctx *genContext) (string, string) { + // Get the string representation of the type. + // Here, named types are handled separately because the output + // would contain a copy of the struct definition otherwise. + var typeName string + if op.named != nil { + typeName = types.TypeString(op.named, ctx.qualify) + } else { + typeName = types.TypeString(op.typ, ctx.qualify) + } + + // Create struct object. + var resultV = ctx.temp() + var b bytes.Buffer + fmt.Fprintf(&b, "var %s %s\n", resultV, typeName) + + // Decode fields. + fmt.Fprintf(&b, "{\n") + fmt.Fprintf(&b, "if _, err := dec.List(); err != nil { return err }\n") + for _, field := range op.fields { + result, code := field.elem.genDecode(ctx) + fmt.Fprintf(&b, "// %s:\n", field.name) + fmt.Fprint(&b, code) + fmt.Fprintf(&b, "%s.%s = %s\n", resultV, field.name, result) + } + op.decodeOptionalFields(&b, ctx, resultV) + fmt.Fprintf(&b, "if err := dec.ListEnd(); err != nil { return err }\n") + fmt.Fprintf(&b, "}\n") + return resultV, b.String() +} + +func (op structOp) decodeOptionalFields(b *bytes.Buffer, ctx *genContext, resultV string) { + var suffix bytes.Buffer + for _, field := range op.optionalFields { + result, code := field.elem.genDecode(ctx) + fmt.Fprintf(b, "// %s:\n", field.name) + fmt.Fprintf(b, "if dec.MoreDataInList() {\n") + fmt.Fprint(b, code) + fmt.Fprintf(b, "%s.%s = %s\n", resultV, field.name, result) + fmt.Fprintf(&suffix, "}\n") + } + suffix.WriteTo(b) +} + +// sliceOp handles slice types. +type sliceOp struct { + typ *types.Slice + elemOp op +} + +func (bctx *buildContext) makeSliceOp(typ *types.Slice) (op, error) { + elemOp, err := bctx.makeOp(nil, typ.Elem(), rlpstruct.Tags{}) + if err != nil { + return nil, err + } + return sliceOp{typ: typ, elemOp: elemOp}, nil +} + +func (op sliceOp) genWrite(ctx *genContext, v string) string { + var ( + listMarker = ctx.temp() // holds return value of w.List() + iterElemV = ctx.temp() // iteration variable + elemCode = op.elemOp.genWrite(ctx, iterElemV) + ) + + var b bytes.Buffer + fmt.Fprintf(&b, "%s := w.List()\n", listMarker) + fmt.Fprintf(&b, "for _, %s := range %s {\n", iterElemV, v) + fmt.Fprint(&b, elemCode) + fmt.Fprintf(&b, "}\n") + fmt.Fprintf(&b, "w.ListEnd(%s)\n", listMarker) + return b.String() +} + +func (op sliceOp) genDecode(ctx *genContext) (string, string) { + var sliceV = ctx.temp() // holds the output slice + elemResult, elemCode := op.elemOp.genDecode(ctx) + + var b bytes.Buffer + fmt.Fprintf(&b, "var %s %s\n", sliceV, types.TypeString(op.typ, ctx.qualify)) + fmt.Fprintf(&b, "if _, err := dec.List(); err != nil { return err }\n") + fmt.Fprintf(&b, "for dec.MoreDataInList() {\n") + fmt.Fprintf(&b, " %s", elemCode) + fmt.Fprintf(&b, " %s = append(%s, %s)\n", sliceV, sliceV, elemResult) + fmt.Fprintf(&b, "}\n") + fmt.Fprintf(&b, "if err := dec.ListEnd(); err != nil { return err }\n") + return sliceV, b.String() +} + +func (bctx *buildContext) makeOp(name *types.Named, typ types.Type, tags rlpstruct.Tags) (op, error) { + switch typ := typ.(type) { + case *types.Named: + if isBigInt(typ) { + return bigIntOp{}, nil + } + if isUint256(typ) { + return uint256Op{}, nil + } + if typ == bctx.rawValueType { + return bctx.makeRawValueOp(), nil + } + if bctx.isDecoder(typ) { + return nil, fmt.Errorf("type %v implements rlp.Decoder with non-pointer receiver", typ) + } + // TODO: same check for encoder? + return bctx.makeOp(typ, typ.Underlying(), tags) + case *types.Pointer: + if isBigInt(typ.Elem()) { + return bigIntOp{pointer: true}, nil + } + if isUint256(typ.Elem()) { + return uint256Op{pointer: true}, nil + } + // Encoder/Decoder interfaces. + if bctx.isEncoder(typ) { + if bctx.isDecoder(typ) { + return encoderDecoderOp{typ}, nil + } + return nil, fmt.Errorf("type %v implements rlp.Encoder but not rlp.Decoder", typ) + } + if bctx.isDecoder(typ) { + return nil, fmt.Errorf("type %v implements rlp.Decoder but not rlp.Encoder", typ) + } + // Default pointer handling. + return bctx.makePtrOp(typ.Elem(), tags) + case *types.Basic: + return bctx.makeBasicOp(typ) + case *types.Struct: + return bctx.makeStructOp(name, typ) + case *types.Slice: + etyp := typ.Elem() + if isByte(etyp) && !bctx.isEncoder(etyp) { + return bctx.makeByteSliceOp(typ), nil + } + return bctx.makeSliceOp(typ) + case *types.Array: + etyp := typ.Elem() + if isByte(etyp) && !bctx.isEncoder(etyp) { + return bctx.makeByteArrayOp(name, typ), nil + } + return nil, fmt.Errorf("unhandled array type: %v", typ) + default: + return nil, fmt.Errorf("unhandled type: %v", typ) + } +} + +// generateDecoder generates the DecodeRLP method on 'typ'. +func generateDecoder(ctx *genContext, typ string, op op) []byte { + ctx.resetTemp() + ctx.addImport(pathOfPackageRLP) + + result, code := op.genDecode(ctx) + var b bytes.Buffer + fmt.Fprintf(&b, "func (obj *%s) DecodeRLP(dec *rlp.Stream) error {\n", typ) + fmt.Fprint(&b, code) + fmt.Fprintf(&b, " *obj = %s\n", result) + fmt.Fprintf(&b, " return nil\n") + fmt.Fprintf(&b, "}\n") + return b.Bytes() +} + +// generateEncoder generates the EncodeRLP method on 'typ'. +func generateEncoder(ctx *genContext, typ string, op op) []byte { + ctx.resetTemp() + ctx.addImport("io") + ctx.addImport(pathOfPackageRLP) + + var b bytes.Buffer + fmt.Fprintf(&b, "func (obj *%s) EncodeRLP(_w io.Writer) error {\n", typ) + fmt.Fprintf(&b, " w := rlp.NewEncoderBuffer(_w)\n") + fmt.Fprint(&b, op.genWrite(ctx, "obj")) + fmt.Fprintf(&b, " return w.Flush()\n") + fmt.Fprintf(&b, "}\n") + return b.Bytes() +} + +func (bctx *buildContext) generate(typ *types.Named, encoder, decoder bool) ([]byte, error) { + bctx.topType = typ + + pkg := typ.Obj().Pkg() + op, err := bctx.makeOp(nil, typ, rlpstruct.Tags{}) + if err != nil { + return nil, err + } + + var ( + ctx = newGenContext(pkg) + encSource []byte + decSource []byte + ) + if encoder { + encSource = generateEncoder(ctx, typ.Obj().Name(), op) + } + if decoder { + decSource = generateDecoder(ctx, typ.Obj().Name(), op) + } + + var b bytes.Buffer + fmt.Fprintf(&b, "package %s\n\n", pkg.Name()) + for _, imp := range ctx.importsList() { + fmt.Fprintf(&b, "import %q\n", imp) + } + if encoder { + fmt.Fprintln(&b) + b.Write(encSource) + } + if decoder { + fmt.Fprintln(&b) + b.Write(decSource) + } + + source := b.Bytes() + // fmt.Println(string(source)) + return format.Source(source) +} diff --git a/rlp/rlpgen/gen_test.go b/rlp/rlpgen/gen_test.go new file mode 100644 index 0000000000..3b4f5df287 --- /dev/null +++ b/rlp/rlpgen/gen_test.go @@ -0,0 +1,107 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package main + +import ( + "bytes" + "fmt" + "go/ast" + "go/importer" + "go/parser" + "go/token" + "go/types" + "os" + "path/filepath" + "testing" +) + +// Package RLP is loaded only once and reused for all tests. +var ( + testFset = token.NewFileSet() + testImporter = importer.ForCompiler(testFset, "source", nil).(types.ImporterFrom) + testPackageRLP *types.Package +) + +func init() { + cwd, err := os.Getwd() + if err != nil { + panic(err) + } + testPackageRLP, err = testImporter.ImportFrom(pathOfPackageRLP, cwd, 0) + if err != nil { + panic(fmt.Errorf("can't load package RLP: %v", err)) + } +} + +var tests = []string{"uints", "nil", "rawvalue", "optional", "bigint", "uint256"} + +func TestOutput(t *testing.T) { + for _, test := range tests { + test := test + t.Run(test, func(t *testing.T) { + inputFile := filepath.Join("testdata", test+".in.txt") + outputFile := filepath.Join("testdata", test+".out.txt") + bctx, typ, err := loadTestSource(inputFile, "Test") + if err != nil { + t.Fatal("error loading test source:", err) + } + output, err := bctx.generate(typ, true, true) + if err != nil { + t.Fatal("error in generate:", err) + } + + // Set this environment variable to regenerate the test outputs. + if os.Getenv("WRITE_TEST_FILES") != "" { + os.WriteFile(outputFile, output, 0644) + } + + // Check if output matches. + wantOutput, err := os.ReadFile(outputFile) + if err != nil { + t.Fatal("error loading expected test output:", err) + } + if !bytes.Equal(output, wantOutput) { + t.Fatalf("output mismatch, want: %v got %v", string(wantOutput), string(output)) + } + }) + } +} + +func loadTestSource(file string, typeName string) (*buildContext, *types.Named, error) { + // Load the test input. + content, err := os.ReadFile(file) + if err != nil { + return nil, nil, err + } + f, err := parser.ParseFile(testFset, file, content, 0) + if err != nil { + return nil, nil, err + } + conf := types.Config{Importer: testImporter} + pkg, err := conf.Check("test", testFset, []*ast.File{f}, nil) + if err != nil { + return nil, nil, err + } + + // Find the test struct. + bctx := newBuildContext(testPackageRLP) + typ, err := lookupStructType(pkg.Scope(), typeName) + if err != nil { + return nil, nil, fmt.Errorf("can't find type %s: %v", typeName, err) + } + return bctx, typ, nil +} diff --git a/rlp/rlpgen/main.go b/rlp/rlpgen/main.go new file mode 100644 index 0000000000..7271882306 --- /dev/null +++ b/rlp/rlpgen/main.go @@ -0,0 +1,144 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package main + +import ( + "bytes" + "errors" + "flag" + "fmt" + "go/types" + "os" + + "golang.org/x/tools/go/packages" +) + +const pathOfPackageRLP = "github.com/XinFinOrg/XDPoSChain/rlp" + +func main() { + var ( + pkgdir = flag.String("dir", ".", "input package") + output = flag.String("out", "-", "output file (default is stdout)") + genEncoder = flag.Bool("encoder", true, "generate EncodeRLP?") + genDecoder = flag.Bool("decoder", false, "generate DecodeRLP?") + typename = flag.String("type", "", "type to generate methods for") + ) + flag.Parse() + + cfg := Config{ + Dir: *pkgdir, + Type: *typename, + GenerateEncoder: *genEncoder, + GenerateDecoder: *genDecoder, + } + code, err := cfg.process() + if err != nil { + fatal(err) + } + if *output == "-" { + os.Stdout.Write(code) + } else if err := os.WriteFile(*output, code, 0600); err != nil { + fatal(err) + } +} + +func fatal(args ...interface{}) { + fmt.Fprintln(os.Stderr, args...) + os.Exit(1) +} + +type Config struct { + Dir string // input package directory + Type string + + GenerateEncoder bool + GenerateDecoder bool +} + +// process generates the Go code. +func (cfg *Config) process() (code []byte, err error) { + // Load packages. + pcfg := &packages.Config{ + Mode: packages.NeedName | packages.NeedTypes, + Dir: cfg.Dir, + } + ps, err := packages.Load(pcfg, pathOfPackageRLP, ".") + if err != nil { + return nil, err + } + if len(ps) == 0 { + return nil, fmt.Errorf("no Go package found in %s", cfg.Dir) + } + packages.PrintErrors(ps) + + // Find the packages that were loaded. + var ( + pkg *types.Package + packageRLP *types.Package + ) + for _, p := range ps { + if len(p.Errors) > 0 { + return nil, fmt.Errorf("package %s has errors", p.PkgPath) + } + if p.PkgPath == pathOfPackageRLP { + packageRLP = p.Types + } else { + pkg = p.Types + } + } + bctx := newBuildContext(packageRLP) + + // Find the type and generate. + typ, err := lookupStructType(pkg.Scope(), cfg.Type) + if err != nil { + return nil, fmt.Errorf("can't find %s in %s: %v", cfg.Type, pkg, err) + } + code, err = bctx.generate(typ, cfg.GenerateEncoder, cfg.GenerateDecoder) + if err != nil { + return nil, err + } + + // Add build comments. + // This is done here to avoid processing these lines with gofmt. + var header bytes.Buffer + fmt.Fprint(&header, "// Code generated by rlpgen. DO NOT EDIT.\n\n") + return append(header.Bytes(), code...), nil +} + +func lookupStructType(scope *types.Scope, name string) (*types.Named, error) { + typ, err := lookupType(scope, name) + if err != nil { + return nil, err + } + _, ok := typ.Underlying().(*types.Struct) + if !ok { + return nil, errors.New("not a struct type") + } + return typ, nil +} + +func lookupType(scope *types.Scope, name string) (*types.Named, error) { + obj := scope.Lookup(name) + if obj == nil { + return nil, errors.New("no such identifier") + } + typ, ok := obj.(*types.TypeName) + if !ok { + return nil, errors.New("not a type") + } + return typ.Type().(*types.Named), nil +} diff --git a/rlp/rlpgen/testdata/bigint.in.txt b/rlp/rlpgen/testdata/bigint.in.txt new file mode 100644 index 0000000000..d23d84a287 --- /dev/null +++ b/rlp/rlpgen/testdata/bigint.in.txt @@ -0,0 +1,10 @@ +// -*- mode: go -*- + +package test + +import "math/big" + +type Test struct { + Int *big.Int + IntNoPtr big.Int +} diff --git a/rlp/rlpgen/testdata/bigint.out.txt b/rlp/rlpgen/testdata/bigint.out.txt new file mode 100644 index 0000000000..faab1bed46 --- /dev/null +++ b/rlp/rlpgen/testdata/bigint.out.txt @@ -0,0 +1,49 @@ +package test + +import "github.com/XinFinOrg/XDPoSChain/rlp" +import "io" + +func (obj *Test) EncodeRLP(_w io.Writer) error { + w := rlp.NewEncoderBuffer(_w) + _tmp0 := w.List() + if obj.Int == nil { + w.Write(rlp.EmptyString) + } else { + if obj.Int.Sign() == -1 { + return rlp.ErrNegativeBigInt + } + w.WriteBigInt(obj.Int) + } + if obj.IntNoPtr.Sign() == -1 { + return rlp.ErrNegativeBigInt + } + w.WriteBigInt(&obj.IntNoPtr) + w.ListEnd(_tmp0) + return w.Flush() +} + +func (obj *Test) DecodeRLP(dec *rlp.Stream) error { + var _tmp0 Test + { + if _, err := dec.List(); err != nil { + return err + } + // Int: + _tmp1, err := dec.BigInt() + if err != nil { + return err + } + _tmp0.Int = _tmp1 + // IntNoPtr: + _tmp2, err := dec.BigInt() + if err != nil { + return err + } + _tmp0.IntNoPtr = (*_tmp2) + if err := dec.ListEnd(); err != nil { + return err + } + } + *obj = _tmp0 + return nil +} diff --git a/rlp/rlpgen/testdata/nil.in.txt b/rlp/rlpgen/testdata/nil.in.txt new file mode 100644 index 0000000000..a28ff34487 --- /dev/null +++ b/rlp/rlpgen/testdata/nil.in.txt @@ -0,0 +1,30 @@ +// -*- mode: go -*- + +package test + +type Aux struct{ + A uint32 +} + +type Test struct{ + Uint8 *byte `rlp:"nil"` + Uint8List *byte `rlp:"nilList"` + + Uint32 *uint32 `rlp:"nil"` + Uint32List *uint32 `rlp:"nilList"` + + Uint64 *uint64 `rlp:"nil"` + Uint64List *uint64 `rlp:"nilList"` + + String *string `rlp:"nil"` + StringList *string `rlp:"nilList"` + + ByteArray *[3]byte `rlp:"nil"` + ByteArrayList *[3]byte `rlp:"nilList"` + + ByteSlice *[]byte `rlp:"nil"` + ByteSliceList *[]byte `rlp:"nilList"` + + Struct *Aux `rlp:"nil"` + StructString *Aux `rlp:"nilString"` +} diff --git a/rlp/rlpgen/testdata/nil.out.txt b/rlp/rlpgen/testdata/nil.out.txt new file mode 100644 index 0000000000..7f3459682b --- /dev/null +++ b/rlp/rlpgen/testdata/nil.out.txt @@ -0,0 +1,289 @@ +package test + +import "github.com/XinFinOrg/XDPoSChain/rlp" +import "io" + +func (obj *Test) EncodeRLP(_w io.Writer) error { + w := rlp.NewEncoderBuffer(_w) + _tmp0 := w.List() + if obj.Uint8 == nil { + w.Write([]byte{0x80}) + } else { + w.WriteUint64(uint64((*obj.Uint8))) + } + if obj.Uint8List == nil { + w.Write([]byte{0xC0}) + } else { + w.WriteUint64(uint64((*obj.Uint8List))) + } + if obj.Uint32 == nil { + w.Write([]byte{0x80}) + } else { + w.WriteUint64(uint64((*obj.Uint32))) + } + if obj.Uint32List == nil { + w.Write([]byte{0xC0}) + } else { + w.WriteUint64(uint64((*obj.Uint32List))) + } + if obj.Uint64 == nil { + w.Write([]byte{0x80}) + } else { + w.WriteUint64((*obj.Uint64)) + } + if obj.Uint64List == nil { + w.Write([]byte{0xC0}) + } else { + w.WriteUint64((*obj.Uint64List)) + } + if obj.String == nil { + w.Write([]byte{0x80}) + } else { + w.WriteString((*obj.String)) + } + if obj.StringList == nil { + w.Write([]byte{0xC0}) + } else { + w.WriteString((*obj.StringList)) + } + if obj.ByteArray == nil { + w.Write([]byte{0x80}) + } else { + w.WriteBytes(obj.ByteArray[:]) + } + if obj.ByteArrayList == nil { + w.Write([]byte{0xC0}) + } else { + w.WriteBytes(obj.ByteArrayList[:]) + } + if obj.ByteSlice == nil { + w.Write([]byte{0x80}) + } else { + w.WriteBytes((*obj.ByteSlice)) + } + if obj.ByteSliceList == nil { + w.Write([]byte{0xC0}) + } else { + w.WriteBytes((*obj.ByteSliceList)) + } + if obj.Struct == nil { + w.Write([]byte{0xC0}) + } else { + _tmp1 := w.List() + w.WriteUint64(uint64(obj.Struct.A)) + w.ListEnd(_tmp1) + } + if obj.StructString == nil { + w.Write([]byte{0x80}) + } else { + _tmp2 := w.List() + w.WriteUint64(uint64(obj.StructString.A)) + w.ListEnd(_tmp2) + } + w.ListEnd(_tmp0) + return w.Flush() +} + +func (obj *Test) DecodeRLP(dec *rlp.Stream) error { + var _tmp0 Test + { + if _, err := dec.List(); err != nil { + return err + } + // Uint8: + var _tmp2 *byte + if _tmp3, _tmp4, err := dec.Kind(); err != nil { + return err + } else if _tmp4 != 0 || _tmp3 != rlp.String { + _tmp1, err := dec.Uint8() + if err != nil { + return err + } + _tmp2 = &_tmp1 + } + _tmp0.Uint8 = _tmp2 + // Uint8List: + var _tmp6 *byte + if _tmp7, _tmp8, err := dec.Kind(); err != nil { + return err + } else if _tmp8 != 0 || _tmp7 != rlp.List { + _tmp5, err := dec.Uint8() + if err != nil { + return err + } + _tmp6 = &_tmp5 + } + _tmp0.Uint8List = _tmp6 + // Uint32: + var _tmp10 *uint32 + if _tmp11, _tmp12, err := dec.Kind(); err != nil { + return err + } else if _tmp12 != 0 || _tmp11 != rlp.String { + _tmp9, err := dec.Uint32() + if err != nil { + return err + } + _tmp10 = &_tmp9 + } + _tmp0.Uint32 = _tmp10 + // Uint32List: + var _tmp14 *uint32 + if _tmp15, _tmp16, err := dec.Kind(); err != nil { + return err + } else if _tmp16 != 0 || _tmp15 != rlp.List { + _tmp13, err := dec.Uint32() + if err != nil { + return err + } + _tmp14 = &_tmp13 + } + _tmp0.Uint32List = _tmp14 + // Uint64: + var _tmp18 *uint64 + if _tmp19, _tmp20, err := dec.Kind(); err != nil { + return err + } else if _tmp20 != 0 || _tmp19 != rlp.String { + _tmp17, err := dec.Uint64() + if err != nil { + return err + } + _tmp18 = &_tmp17 + } + _tmp0.Uint64 = _tmp18 + // Uint64List: + var _tmp22 *uint64 + if _tmp23, _tmp24, err := dec.Kind(); err != nil { + return err + } else if _tmp24 != 0 || _tmp23 != rlp.List { + _tmp21, err := dec.Uint64() + if err != nil { + return err + } + _tmp22 = &_tmp21 + } + _tmp0.Uint64List = _tmp22 + // String: + var _tmp26 *string + if _tmp27, _tmp28, err := dec.Kind(); err != nil { + return err + } else if _tmp28 != 0 || _tmp27 != rlp.String { + _tmp25, err := dec.String() + if err != nil { + return err + } + _tmp26 = &_tmp25 + } + _tmp0.String = _tmp26 + // StringList: + var _tmp30 *string + if _tmp31, _tmp32, err := dec.Kind(); err != nil { + return err + } else if _tmp32 != 0 || _tmp31 != rlp.List { + _tmp29, err := dec.String() + if err != nil { + return err + } + _tmp30 = &_tmp29 + } + _tmp0.StringList = _tmp30 + // ByteArray: + var _tmp34 *[3]byte + if _tmp35, _tmp36, err := dec.Kind(); err != nil { + return err + } else if _tmp36 != 0 || _tmp35 != rlp.String { + var _tmp33 [3]byte + if err := dec.ReadBytes(_tmp33[:]); err != nil { + return err + } + _tmp34 = &_tmp33 + } + _tmp0.ByteArray = _tmp34 + // ByteArrayList: + var _tmp38 *[3]byte + if _tmp39, _tmp40, err := dec.Kind(); err != nil { + return err + } else if _tmp40 != 0 || _tmp39 != rlp.List { + var _tmp37 [3]byte + if err := dec.ReadBytes(_tmp37[:]); err != nil { + return err + } + _tmp38 = &_tmp37 + } + _tmp0.ByteArrayList = _tmp38 + // ByteSlice: + var _tmp42 *[]byte + if _tmp43, _tmp44, err := dec.Kind(); err != nil { + return err + } else if _tmp44 != 0 || _tmp43 != rlp.String { + _tmp41, err := dec.Bytes() + if err != nil { + return err + } + _tmp42 = &_tmp41 + } + _tmp0.ByteSlice = _tmp42 + // ByteSliceList: + var _tmp46 *[]byte + if _tmp47, _tmp48, err := dec.Kind(); err != nil { + return err + } else if _tmp48 != 0 || _tmp47 != rlp.List { + _tmp45, err := dec.Bytes() + if err != nil { + return err + } + _tmp46 = &_tmp45 + } + _tmp0.ByteSliceList = _tmp46 + // Struct: + var _tmp51 *Aux + if _tmp52, _tmp53, err := dec.Kind(); err != nil { + return err + } else if _tmp53 != 0 || _tmp52 != rlp.List { + var _tmp49 Aux + { + if _, err := dec.List(); err != nil { + return err + } + // A: + _tmp50, err := dec.Uint32() + if err != nil { + return err + } + _tmp49.A = _tmp50 + if err := dec.ListEnd(); err != nil { + return err + } + } + _tmp51 = &_tmp49 + } + _tmp0.Struct = _tmp51 + // StructString: + var _tmp56 *Aux + if _tmp57, _tmp58, err := dec.Kind(); err != nil { + return err + } else if _tmp58 != 0 || _tmp57 != rlp.String { + var _tmp54 Aux + { + if _, err := dec.List(); err != nil { + return err + } + // A: + _tmp55, err := dec.Uint32() + if err != nil { + return err + } + _tmp54.A = _tmp55 + if err := dec.ListEnd(); err != nil { + return err + } + } + _tmp56 = &_tmp54 + } + _tmp0.StructString = _tmp56 + if err := dec.ListEnd(); err != nil { + return err + } + } + *obj = _tmp0 + return nil +} diff --git a/rlp/rlpgen/testdata/optional.in.txt b/rlp/rlpgen/testdata/optional.in.txt new file mode 100644 index 0000000000..f1ac9f7899 --- /dev/null +++ b/rlp/rlpgen/testdata/optional.in.txt @@ -0,0 +1,17 @@ +// -*- mode: go -*- + +package test + +type Aux struct { + A uint64 +} + +type Test struct { + Uint64 uint64 `rlp:"optional"` + Pointer *uint64 `rlp:"optional"` + String string `rlp:"optional"` + Slice []uint64 `rlp:"optional"` + Array [3]byte `rlp:"optional"` + NamedStruct Aux `rlp:"optional"` + AnonStruct struct{ A string } `rlp:"optional"` +} diff --git a/rlp/rlpgen/testdata/optional.out.txt b/rlp/rlpgen/testdata/optional.out.txt new file mode 100644 index 0000000000..8b4cfa1817 --- /dev/null +++ b/rlp/rlpgen/testdata/optional.out.txt @@ -0,0 +1,153 @@ +package test + +import "github.com/XinFinOrg/XDPoSChain/rlp" +import "io" + +func (obj *Test) EncodeRLP(_w io.Writer) error { + w := rlp.NewEncoderBuffer(_w) + _tmp0 := w.List() + _tmp1 := obj.Uint64 != 0 + _tmp2 := obj.Pointer != nil + _tmp3 := obj.String != "" + _tmp4 := len(obj.Slice) > 0 + _tmp5 := obj.Array != ([3]byte{}) + _tmp6 := obj.NamedStruct != (Aux{}) + _tmp7 := obj.AnonStruct != (struct{ A string }{}) + if _tmp1 || _tmp2 || _tmp3 || _tmp4 || _tmp5 || _tmp6 || _tmp7 { + w.WriteUint64(obj.Uint64) + } + if _tmp2 || _tmp3 || _tmp4 || _tmp5 || _tmp6 || _tmp7 { + if obj.Pointer == nil { + w.Write([]byte{0x80}) + } else { + w.WriteUint64((*obj.Pointer)) + } + } + if _tmp3 || _tmp4 || _tmp5 || _tmp6 || _tmp7 { + w.WriteString(obj.String) + } + if _tmp4 || _tmp5 || _tmp6 || _tmp7 { + _tmp8 := w.List() + for _, _tmp9 := range obj.Slice { + w.WriteUint64(_tmp9) + } + w.ListEnd(_tmp8) + } + if _tmp5 || _tmp6 || _tmp7 { + w.WriteBytes(obj.Array[:]) + } + if _tmp6 || _tmp7 { + _tmp10 := w.List() + w.WriteUint64(obj.NamedStruct.A) + w.ListEnd(_tmp10) + } + if _tmp7 { + _tmp11 := w.List() + w.WriteString(obj.AnonStruct.A) + w.ListEnd(_tmp11) + } + w.ListEnd(_tmp0) + return w.Flush() +} + +func (obj *Test) DecodeRLP(dec *rlp.Stream) error { + var _tmp0 Test + { + if _, err := dec.List(); err != nil { + return err + } + // Uint64: + if dec.MoreDataInList() { + _tmp1, err := dec.Uint64() + if err != nil { + return err + } + _tmp0.Uint64 = _tmp1 + // Pointer: + if dec.MoreDataInList() { + _tmp2, err := dec.Uint64() + if err != nil { + return err + } + _tmp0.Pointer = &_tmp2 + // String: + if dec.MoreDataInList() { + _tmp3, err := dec.String() + if err != nil { + return err + } + _tmp0.String = _tmp3 + // Slice: + if dec.MoreDataInList() { + var _tmp4 []uint64 + if _, err := dec.List(); err != nil { + return err + } + for dec.MoreDataInList() { + _tmp5, err := dec.Uint64() + if err != nil { + return err + } + _tmp4 = append(_tmp4, _tmp5) + } + if err := dec.ListEnd(); err != nil { + return err + } + _tmp0.Slice = _tmp4 + // Array: + if dec.MoreDataInList() { + var _tmp6 [3]byte + if err := dec.ReadBytes(_tmp6[:]); err != nil { + return err + } + _tmp0.Array = _tmp6 + // NamedStruct: + if dec.MoreDataInList() { + var _tmp7 Aux + { + if _, err := dec.List(); err != nil { + return err + } + // A: + _tmp8, err := dec.Uint64() + if err != nil { + return err + } + _tmp7.A = _tmp8 + if err := dec.ListEnd(); err != nil { + return err + } + } + _tmp0.NamedStruct = _tmp7 + // AnonStruct: + if dec.MoreDataInList() { + var _tmp9 struct{ A string } + { + if _, err := dec.List(); err != nil { + return err + } + // A: + _tmp10, err := dec.String() + if err != nil { + return err + } + _tmp9.A = _tmp10 + if err := dec.ListEnd(); err != nil { + return err + } + } + _tmp0.AnonStruct = _tmp9 + } + } + } + } + } + } + } + if err := dec.ListEnd(); err != nil { + return err + } + } + *obj = _tmp0 + return nil +} diff --git a/rlp/rlpgen/testdata/rawvalue.in.txt b/rlp/rlpgen/testdata/rawvalue.in.txt new file mode 100644 index 0000000000..daa050c3f6 --- /dev/null +++ b/rlp/rlpgen/testdata/rawvalue.in.txt @@ -0,0 +1,11 @@ +// -*- mode: go -*- + +package test + +import "github.com/XinFinOrg/XDPoSChain/rlp" + +type Test struct { + RawValue rlp.RawValue + PointerToRawValue *rlp.RawValue + SliceOfRawValue []rlp.RawValue +} diff --git a/rlp/rlpgen/testdata/rawvalue.out.txt b/rlp/rlpgen/testdata/rawvalue.out.txt new file mode 100644 index 0000000000..35bf145dcc --- /dev/null +++ b/rlp/rlpgen/testdata/rawvalue.out.txt @@ -0,0 +1,64 @@ +package test + +import "github.com/XinFinOrg/XDPoSChain/rlp" +import "io" + +func (obj *Test) EncodeRLP(_w io.Writer) error { + w := rlp.NewEncoderBuffer(_w) + _tmp0 := w.List() + w.Write(obj.RawValue) + if obj.PointerToRawValue == nil { + w.Write([]byte{0x80}) + } else { + w.Write((*obj.PointerToRawValue)) + } + _tmp1 := w.List() + for _, _tmp2 := range obj.SliceOfRawValue { + w.Write(_tmp2) + } + w.ListEnd(_tmp1) + w.ListEnd(_tmp0) + return w.Flush() +} + +func (obj *Test) DecodeRLP(dec *rlp.Stream) error { + var _tmp0 Test + { + if _, err := dec.List(); err != nil { + return err + } + // RawValue: + _tmp1, err := dec.Raw() + if err != nil { + return err + } + _tmp0.RawValue = _tmp1 + // PointerToRawValue: + _tmp2, err := dec.Raw() + if err != nil { + return err + } + _tmp0.PointerToRawValue = &_tmp2 + // SliceOfRawValue: + var _tmp3 []rlp.RawValue + if _, err := dec.List(); err != nil { + return err + } + for dec.MoreDataInList() { + _tmp4, err := dec.Raw() + if err != nil { + return err + } + _tmp3 = append(_tmp3, _tmp4) + } + if err := dec.ListEnd(); err != nil { + return err + } + _tmp0.SliceOfRawValue = _tmp3 + if err := dec.ListEnd(); err != nil { + return err + } + } + *obj = _tmp0 + return nil +} diff --git a/rlp/rlpgen/testdata/uint256.in.txt b/rlp/rlpgen/testdata/uint256.in.txt new file mode 100644 index 0000000000..ed16e0a788 --- /dev/null +++ b/rlp/rlpgen/testdata/uint256.in.txt @@ -0,0 +1,10 @@ +// -*- mode: go -*- + +package test + +import "github.com/holiman/uint256" + +type Test struct { + Int *uint256.Int + IntNoPtr uint256.Int +} diff --git a/rlp/rlpgen/testdata/uint256.out.txt b/rlp/rlpgen/testdata/uint256.out.txt new file mode 100644 index 0000000000..b560aa0e42 --- /dev/null +++ b/rlp/rlpgen/testdata/uint256.out.txt @@ -0,0 +1,44 @@ +package test + +import "github.com/XinFinOrg/XDPoSChain/rlp" +import "github.com/holiman/uint256" +import "io" + +func (obj *Test) EncodeRLP(_w io.Writer) error { + w := rlp.NewEncoderBuffer(_w) + _tmp0 := w.List() + if obj.Int == nil { + w.Write(rlp.EmptyString) + } else { + w.WriteUint256(obj.Int) + } + w.WriteUint256(&obj.IntNoPtr) + w.ListEnd(_tmp0) + return w.Flush() +} + +func (obj *Test) DecodeRLP(dec *rlp.Stream) error { + var _tmp0 Test + { + if _, err := dec.List(); err != nil { + return err + } + // Int: + var _tmp1 uint256.Int + if err := dec.ReadUint256(&_tmp1); err != nil { + return err + } + _tmp0.Int = &_tmp1 + // IntNoPtr: + var _tmp2 uint256.Int + if err := dec.ReadUint256(&_tmp2); err != nil { + return err + } + _tmp0.IntNoPtr = _tmp2 + if err := dec.ListEnd(); err != nil { + return err + } + } + *obj = _tmp0 + return nil +} diff --git a/rlp/rlpgen/testdata/uints.in.txt b/rlp/rlpgen/testdata/uints.in.txt new file mode 100644 index 0000000000..8095da997d --- /dev/null +++ b/rlp/rlpgen/testdata/uints.in.txt @@ -0,0 +1,10 @@ +// -*- mode: go -*- + +package test + +type Test struct{ + A uint8 + B uint16 + C uint32 + D uint64 +} diff --git a/rlp/rlpgen/testdata/uints.out.txt b/rlp/rlpgen/testdata/uints.out.txt new file mode 100644 index 0000000000..cf973ec9a4 --- /dev/null +++ b/rlp/rlpgen/testdata/uints.out.txt @@ -0,0 +1,53 @@ +package test + +import "github.com/XinFinOrg/XDPoSChain/rlp" +import "io" + +func (obj *Test) EncodeRLP(_w io.Writer) error { + w := rlp.NewEncoderBuffer(_w) + _tmp0 := w.List() + w.WriteUint64(uint64(obj.A)) + w.WriteUint64(uint64(obj.B)) + w.WriteUint64(uint64(obj.C)) + w.WriteUint64(obj.D) + w.ListEnd(_tmp0) + return w.Flush() +} + +func (obj *Test) DecodeRLP(dec *rlp.Stream) error { + var _tmp0 Test + { + if _, err := dec.List(); err != nil { + return err + } + // A: + _tmp1, err := dec.Uint8() + if err != nil { + return err + } + _tmp0.A = _tmp1 + // B: + _tmp2, err := dec.Uint16() + if err != nil { + return err + } + _tmp0.B = _tmp2 + // C: + _tmp3, err := dec.Uint32() + if err != nil { + return err + } + _tmp0.C = _tmp3 + // D: + _tmp4, err := dec.Uint64() + if err != nil { + return err + } + _tmp0.D = _tmp4 + if err := dec.ListEnd(); err != nil { + return err + } + } + *obj = _tmp0 + return nil +} diff --git a/rlp/rlpgen/types.go b/rlp/rlpgen/types.go new file mode 100644 index 0000000000..ea7dc96d88 --- /dev/null +++ b/rlp/rlpgen/types.go @@ -0,0 +1,124 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package main + +import ( + "fmt" + "go/types" + "reflect" +) + +// typeReflectKind gives the reflect.Kind that represents typ. +func typeReflectKind(typ types.Type) reflect.Kind { + switch typ := typ.(type) { + case *types.Basic: + k := typ.Kind() + if k >= types.Bool && k <= types.Complex128 { + // value order matches for Bool..Complex128 + return reflect.Bool + reflect.Kind(k-types.Bool) + } + if k == types.String { + return reflect.String + } + if k == types.UnsafePointer { + return reflect.UnsafePointer + } + panic(fmt.Errorf("unhandled BasicKind %v", k)) + case *types.Array: + return reflect.Array + case *types.Chan: + return reflect.Chan + case *types.Interface: + return reflect.Interface + case *types.Map: + return reflect.Map + case *types.Pointer: + return reflect.Ptr + case *types.Signature: + return reflect.Func + case *types.Slice: + return reflect.Slice + case *types.Struct: + return reflect.Struct + default: + panic(fmt.Errorf("unhandled type %T", typ)) + } +} + +// nonZeroCheck returns the expression that checks whether 'v' is a non-zero value of type 'vtyp'. +func nonZeroCheck(v string, vtyp types.Type, qualify types.Qualifier) string { + // Resolve type name. + typ := resolveUnderlying(vtyp) + switch typ := typ.(type) { + case *types.Basic: + k := typ.Kind() + switch { + case k == types.Bool: + return v + case k >= types.Uint && k <= types.Complex128: + return fmt.Sprintf("%s != 0", v) + case k == types.String: + return fmt.Sprintf(`%s != ""`, v) + default: + panic(fmt.Errorf("unhandled BasicKind %v", k)) + } + case *types.Array, *types.Struct: + return fmt.Sprintf("%s != (%s{})", v, types.TypeString(vtyp, qualify)) + case *types.Interface, *types.Pointer, *types.Signature: + return fmt.Sprintf("%s != nil", v) + case *types.Slice, *types.Map: + return fmt.Sprintf("len(%s) > 0", v) + default: + panic(fmt.Errorf("unhandled type %T", typ)) + } +} + +// isBigInt checks whether 'typ' is "math/big".Int. +func isBigInt(typ types.Type) bool { + named, ok := typ.(*types.Named) + if !ok { + return false + } + name := named.Obj() + return name.Pkg().Path() == "math/big" && name.Name() == "Int" +} + +// isUint256 checks whether 'typ' is "github.com/holiman/uint256".Int. +func isUint256(typ types.Type) bool { + named, ok := typ.(*types.Named) + if !ok { + return false + } + name := named.Obj() + return name.Pkg().Path() == "github.com/holiman/uint256" && name.Name() == "Int" +} + +// isByte checks whether the underlying type of 'typ' is uint8. +func isByte(typ types.Type) bool { + basic, ok := resolveUnderlying(typ).(*types.Basic) + return ok && basic.Kind() == types.Uint8 +} + +func resolveUnderlying(typ types.Type) types.Type { + for { + t := typ.Underlying() + if t == typ { + return t + } + typ = t + } +} diff --git a/rlp/safe.go b/rlp/safe.go new file mode 100644 index 0000000000..3c910337b6 --- /dev/null +++ b/rlp/safe.go @@ -0,0 +1,27 @@ +// Copyright 2021 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +//go:build nacl || js || !cgo +// +build nacl js !cgo + +package rlp + +import "reflect" + +// byteArrayBytes returns a slice of the byte array v. +func byteArrayBytes(v reflect.Value, length int) []byte { + return v.Slice(0, length).Bytes() +} diff --git a/rlp/typecache.go b/rlp/typecache.go index 3df799e1ec..57e4f2e46a 100644 --- a/rlp/typecache.go +++ b/rlp/typecache.go @@ -18,139 +18,221 @@ package rlp import ( "fmt" + "maps" "reflect" - "strings" "sync" + "sync/atomic" + + "github.com/XinFinOrg/XDPoSChain/rlp/internal/rlpstruct" ) -var ( - typeCacheMutex sync.RWMutex - typeCache = make(map[typekey]*typeinfo) -) - +// typeinfo is an entry in the type cache. type typeinfo struct { - decoder - writer -} - -// represents struct tags -type tags struct { - // rlp:"nil" controls whether empty input results in a nil pointer. - nilOK bool - // rlp:"tail" controls whether this field swallows additional list - // elements. It can only be set for the last field, which must be - // of slice type. - tail bool - // rlp:"-" ignores fields. - ignored bool + decoder decoder + decoderErr error // error from makeDecoder + writer writer + writerErr error // error from makeWriter } +// typekey is the key of a type in typeCache. It includes the struct tags because +// they might generate a different decoder. type typekey struct { reflect.Type - // the key must include the struct tags because they - // might generate a different decoder. - tags + rlpstruct.Tags } type decoder func(*Stream, reflect.Value) error -type writer func(reflect.Value, *encbuf) error +type writer func(reflect.Value, *encBuffer) error -func cachedTypeInfo(typ reflect.Type, tags tags) (*typeinfo, error) { - typeCacheMutex.RLock() - info := typeCache[typekey{typ, tags}] - typeCacheMutex.RUnlock() - if info != nil { - return info, nil - } - // not in the cache, need to generate info for this type. - typeCacheMutex.Lock() - defer typeCacheMutex.Unlock() - return cachedTypeInfo1(typ, tags) +var theTC = newTypeCache() + +type typeCache struct { + cur atomic.Value + + // This lock synchronizes writers. + mu sync.Mutex + next map[typekey]*typeinfo } -func cachedTypeInfo1(typ reflect.Type, tags tags) (*typeinfo, error) { +func newTypeCache() *typeCache { + c := new(typeCache) + c.cur.Store(make(map[typekey]*typeinfo)) + return c +} + +func cachedDecoder(typ reflect.Type) (decoder, error) { + info := theTC.info(typ) + return info.decoder, info.decoderErr +} + +func cachedWriter(typ reflect.Type) (writer, error) { + info := theTC.info(typ) + return info.writer, info.writerErr +} + +func (c *typeCache) info(typ reflect.Type) *typeinfo { + key := typekey{Type: typ} + if info := c.cur.Load().(map[typekey]*typeinfo)[key]; info != nil { + return info + } + + // Not in the cache, need to generate info for this type. + return c.generate(typ, rlpstruct.Tags{}) +} + +func (c *typeCache) generate(typ reflect.Type, tags rlpstruct.Tags) *typeinfo { + c.mu.Lock() + defer c.mu.Unlock() + + cur := c.cur.Load().(map[typekey]*typeinfo) + if info := cur[typekey{typ, tags}]; info != nil { + return info + } + + // Copy cur to next. + c.next = maps.Clone(cur) + + // Generate. + info := c.infoWhileGenerating(typ, tags) + + // next -> cur + c.cur.Store(c.next) + c.next = nil + return info +} + +func (c *typeCache) infoWhileGenerating(typ reflect.Type, tags rlpstruct.Tags) *typeinfo { key := typekey{typ, tags} - info := typeCache[key] - if info != nil { - // another goroutine got the write lock first - return info, nil + if info := c.next[key]; info != nil { + return info } - // put a dummmy value into the cache before generating. - // if the generator tries to lookup itself, it will get + // Put a dummy value into the cache before generating. + // If the generator tries to lookup itself, it will get // the dummy value and won't call itself recursively. - typeCache[key] = new(typeinfo) - info, err := genTypeInfo(typ, tags) - if err != nil { - // remove the dummy value if the generator fails - delete(typeCache, key) - return nil, err - } - *typeCache[key] = *info - return typeCache[key], err + info := new(typeinfo) + c.next[key] = info + info.generate(typ, tags) + return info } type field struct { - index int - info *typeinfo + index int + info *typeinfo + optional bool } +// structFields resolves the typeinfo of all public fields in a struct type. func structFields(typ reflect.Type) (fields []field, err error) { + // Convert fields to rlpstruct.Field. + var allStructFields []rlpstruct.Field for i := 0; i < typ.NumField(); i++ { - if f := typ.Field(i); f.PkgPath == "" { // exported - tags, err := parseStructTag(typ, i) - if err != nil { - return nil, err - } - if tags.ignored { - continue - } - info, err := cachedTypeInfo1(f.Type, tags) - if err != nil { - return nil, err - } - fields = append(fields, field{i, info}) + rf := typ.Field(i) + allStructFields = append(allStructFields, rlpstruct.Field{ + Name: rf.Name, + Index: i, + Exported: rf.PkgPath == "", + Tag: string(rf.Tag), + Type: *rtypeToStructType(rf.Type, nil), + }) + } + + // Filter/validate fields. + structFields, structTags, err := rlpstruct.ProcessFields(allStructFields) + if err != nil { + if tagErr, ok := err.(rlpstruct.TagError); ok { + tagErr.StructType = typ.String() + return nil, tagErr } + return nil, err + } + + // Resolve typeinfo. + for i, sf := range structFields { + typ := typ.Field(sf.Index).Type + tags := structTags[i] + info := theTC.infoWhileGenerating(typ, tags) + fields = append(fields, field{sf.Index, info, tags.Optional}) } return fields, nil } -func parseStructTag(typ reflect.Type, fi int) (tags, error) { - f := typ.Field(fi) - var ts tags - for _, t := range strings.Split(f.Tag.Get("rlp"), ",") { - switch t = strings.TrimSpace(t); t { - case "": - case "-": - ts.ignored = true - case "nil": - ts.nilOK = true - case "tail": - ts.tail = true - if fi != typ.NumField()-1 { - return ts, fmt.Errorf(`rlp: invalid struct tag "tail" for %v.%s (must be on last field)`, typ, f.Name) - } - if f.Type.Kind() != reflect.Slice { - return ts, fmt.Errorf(`rlp: invalid struct tag "tail" for %v.%s (field type is not slice)`, typ, f.Name) - } - default: - return ts, fmt.Errorf("rlp: unknown struct tag %q on %v.%s", t, typ, f.Name) +// firstOptionalField returns the index of the first field with "optional" tag. +func firstOptionalField(fields []field) int { + for i, f := range fields { + if f.optional { + return i } } - return ts, nil + return len(fields) } -func genTypeInfo(typ reflect.Type, tags tags) (info *typeinfo, err error) { - info = new(typeinfo) - if info.decoder, err = makeDecoder(typ, tags); err != nil { - return nil, err +type structFieldError struct { + typ reflect.Type + field int + err error +} + +func (e structFieldError) Error() string { + return fmt.Sprintf("%v (struct field %v.%s)", e.err, e.typ, e.typ.Field(e.field).Name) +} + +func (i *typeinfo) generate(typ reflect.Type, tags rlpstruct.Tags) { + i.decoder, i.decoderErr = makeDecoder(typ, tags) + i.writer, i.writerErr = makeWriter(typ, tags) +} + +// rtypeToStructType converts typ to rlpstruct.Type. +func rtypeToStructType(typ reflect.Type, rec map[reflect.Type]*rlpstruct.Type) *rlpstruct.Type { + k := typ.Kind() + if k == reflect.Invalid { + panic("invalid kind") } - if info.writer, err = makeWriter(typ, tags); err != nil { - return nil, err + + if prev := rec[typ]; prev != nil { + return prev // short-circuit for recursive types + } + if rec == nil { + rec = make(map[reflect.Type]*rlpstruct.Type) + } + + t := &rlpstruct.Type{ + Name: typ.String(), + Kind: k, + IsEncoder: typ.Implements(encoderInterface), + IsDecoder: typ.Implements(decoderInterface), + } + rec[typ] = t + if k == reflect.Array || k == reflect.Slice || k == reflect.Ptr { + t.Elem = rtypeToStructType(typ.Elem(), rec) + } + return t +} + +// typeNilKind gives the RLP value kind for nil pointers to 'typ'. +func typeNilKind(typ reflect.Type, tags rlpstruct.Tags) Kind { + styp := rtypeToStructType(typ, nil) + + var nk rlpstruct.NilKind + if tags.NilOK { + nk = tags.NilKind + } else { + nk = styp.DefaultNilValue() + } + switch nk { + case rlpstruct.NilKindString: + return String + case rlpstruct.NilKindList: + return List + default: + panic("invalid nil kind value") } - return info, nil } func isUint(k reflect.Kind) bool { return k >= reflect.Uint && k <= reflect.Uintptr } + +func isByte(typ reflect.Type) bool { + return typ.Kind() == reflect.Uint8 && !typ.Implements(encoderInterface) +} diff --git a/rlp/unsafe.go b/rlp/unsafe.go new file mode 100644 index 0000000000..10868caaf2 --- /dev/null +++ b/rlp/unsafe.go @@ -0,0 +1,30 @@ +// Copyright 2021 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +//go:build !nacl && !js && cgo +// +build !nacl,!js,cgo + +package rlp + +import ( + "reflect" + "unsafe" +) + +// byteArrayBytes returns a slice of the byte array v. +func byteArrayBytes(v reflect.Value, length int) []byte { + return unsafe.Slice((*byte)(unsafe.Pointer(v.UnsafeAddr())), length) +} diff --git a/rpc/types.go b/rpc/types.go index a48dd7a1f3..f3b50a259b 100644 --- a/rpc/types.go +++ b/rpc/types.go @@ -19,7 +19,7 @@ package rpc import ( "context" "encoding/json" - "fmt" + "errors" "math" "strings" @@ -109,7 +109,7 @@ func (bn *BlockNumber) UnmarshalJSON(data []byte) error { return err } if blckNum > math.MaxInt64 { - return fmt.Errorf("block number larger than int64") + return errors.New("block number larger than int64") } *bn = BlockNumber(blckNum) return nil @@ -131,7 +131,7 @@ func (e *EpochNumber) UnmarshalJSON(data []byte) error { return err } if eNum > math.MaxInt64 { - return fmt.Errorf("EpochNumber too high") + return errors.New("EpochNumber too high") } *e = EpochNumber(eNum) @@ -154,7 +154,7 @@ func (bnh *BlockNumberOrHash) UnmarshalJSON(data []byte) error { err := json.Unmarshal(data, &e) if err == nil { if e.BlockNumber != nil && e.BlockHash != nil { - return fmt.Errorf("cannot specify both BlockHash and BlockNumber, choose one or the other") + return errors.New("cannot specify both BlockHash and BlockNumber, choose one or the other") } bnh.BlockNumber = e.BlockNumber bnh.BlockHash = e.BlockHash @@ -198,7 +198,7 @@ func (bnh *BlockNumberOrHash) UnmarshalJSON(data []byte) error { return err } if blckNum > math.MaxInt64 { - return fmt.Errorf("blocknumber too high") + return errors.New("blocknumber too high") } bn := BlockNumber(blckNum) bnh.BlockNumber = &bn diff --git a/tests/block_test_util.go b/tests/block_test_util.go index 684fbf1d66..72e9f21080 100644 --- a/tests/block_test_util.go +++ b/tests/block_test_util.go @@ -21,10 +21,12 @@ import ( "bytes" "encoding/hex" "encoding/json" + "errors" "fmt" - "github.com/XinFinOrg/XDPoSChain/core/rawdb" "math/big" + "github.com/XinFinOrg/XDPoSChain/core/rawdb" + "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/common/hexutil" "github.com/XinFinOrg/XDPoSChain/common/math" @@ -150,17 +152,18 @@ func (t *BlockTest) genesis(config *params.ChainConfig) *core.Genesis { } } -/* See https://github.com/ethereum/tests/wiki/Blockchain-Tests-II +/* +See https://github.com/ethereum/tests/wiki/Blockchain-Tests-II - Whether a block is valid or not is a bit subtle, it's defined by presence of - blockHeader, transactions and uncleHeaders fields. If they are missing, the block is - invalid and we must verify that we do not accept it. + Whether a block is valid or not is a bit subtle, it's defined by presence of + blockHeader, transactions and uncleHeaders fields. If they are missing, the block is + invalid and we must verify that we do not accept it. - Since some tests mix valid and invalid blocks we need to check this for every block. + Since some tests mix valid and invalid blocks we need to check this for every block. - If a block is invalid it does not necessarily fail the test, if it's invalidness is - expected we are expected to ignore it and continue processing and then validate the - post state. + If a block is invalid it does not necessarily fail the test, if it's invalidness is + expected we are expected to ignore it and continue processing and then validate the + post state. */ func (t *BlockTest) insertBlocks(blockchain *core.BlockChain) ([]btBlock, error) { validBlocks := make([]btBlock, 0) @@ -185,7 +188,7 @@ func (t *BlockTest) insertBlocks(blockchain *core.BlockChain) ([]btBlock, error) } } if b.BlockHeader == nil { - return nil, fmt.Errorf("Block insertion should have failed") + return nil, errors.New("Block insertion should have failed") } // validate RLP decoding by checking all values against test file JSON diff --git a/tests/init_test.go b/tests/init_test.go index 77a084d95e..53fe8b4120 100644 --- a/tests/init_test.go +++ b/tests/init_test.go @@ -18,6 +18,7 @@ package tests import ( "encoding/json" + "errors" "fmt" "io" "os" @@ -169,7 +170,7 @@ func (tm *testMatcher) checkFailure(t *testing.T, name string, err error) error t.Logf("error: %v", err) return nil } else { - return fmt.Errorf("test succeeded unexpectedly") + return errors.New("test succeeded unexpectedly") } } return err diff --git a/tests/rlp_test_util.go b/tests/rlp_test_util.go index 9fb53097f3..9315d8d8ba 100644 --- a/tests/rlp_test_util.go +++ b/tests/rlp_test_util.go @@ -46,7 +46,7 @@ type RLPTest struct { func (t *RLPTest) Run() error { outb, err := hex.DecodeString(t.Out) if err != nil { - return fmt.Errorf("invalid hex in Out") + return errors.New("invalid hex in Out") } // Handle simple decoding tests with no actual In value. @@ -74,7 +74,7 @@ func checkDecodeInterface(b []byte, isValid bool) error { case isValid && err != nil: return fmt.Errorf("decoding failed: %v", err) case !isValid && err == nil: - return fmt.Errorf("decoding of invalid value succeeded") + return errors.New("decoding of invalid value succeeded") } return nil } diff --git a/tests/state_test_util.go b/tests/state_test_util.go index 18d09b0bc7..742d2c5955 100644 --- a/tests/state_test_util.go +++ b/tests/state_test_util.go @@ -135,23 +135,16 @@ func (t *StateTest) Run(subtest StateSubtest, vmconfig vm.Config) (*state.StateD if err != nil { return nil, err } + + // Prepare the EVM. context := core.NewEVMContext(msg, block.Header(), nil, &t.json.Env.Coinbase) context.GetHash = vmTestBlockHash evm := vm.NewEVM(context, statedb, nil, config, vmconfig) - if config.IsEIP1559(context.BlockNumber) { - statedb.AddAddressToAccessList(msg.From()) - if dst := msg.To(); dst != nil { - statedb.AddAddressToAccessList(*dst) - // If it's a create-tx, the destination will be added inside evm.create - } - for _, addr := range evm.ActivePrecompiles() { - statedb.AddAddressToAccessList(addr) - } - } + // Execute the message. + snapshot := statedb.Snapshot() gaspool := new(core.GasPool) gaspool.AddGas(block.GasLimit()) - snapshot := statedb.Snapshot() coinbase := &t.json.Env.Coinbase if _, _, _, err, _ := core.ApplyMessage(evm, msg, gaspool, *coinbase); err != nil { @@ -160,6 +153,8 @@ func (t *StateTest) Run(subtest StateSubtest, vmconfig vm.Config) (*state.StateD if logs := rlpHash(statedb.Logs()); logs != common.Hash(post.Logs) { return statedb, fmt.Errorf("post state logs hash mismatch: got %x, want %x", logs, post.Logs) } + + // Commit block root, _ := statedb.Commit(config.IsEIP158(block.Number())) if root != common.Hash(post.Root) { return statedb, fmt.Errorf("post state root mismatch: got %x, want %x", root, post.Root) @@ -245,7 +240,7 @@ func (tx *stTransaction) toMessage(ps stPostState, number *big.Int) (core.Messag if err != nil { return nil, fmt.Errorf("invalid tx data %q", dataHex) } - msg := types.NewMessage(from, to, tx.Nonce, value, gasLimit, tx.GasPrice, data, true, nil, number) + msg := types.NewMessage(from, to, tx.Nonce, value, gasLimit, tx.GasPrice, data, nil, true, nil, number) return msg, nil } diff --git a/tests/vm_test_util.go b/tests/vm_test_util.go index bd516a1827..b9bdffe485 100644 --- a/tests/vm_test_util.go +++ b/tests/vm_test_util.go @@ -19,6 +19,7 @@ package tests import ( "bytes" "encoding/json" + "errors" "fmt" "math/big" @@ -85,10 +86,10 @@ func (t *VMTest) Run(vmconfig vm.Config) error { if t.json.GasRemaining == nil { if err == nil { - return fmt.Errorf("gas unspecified (indicating an error), but VM returned no error") + return errors.New("gas unspecified (indicating an error), but VM returned no error") } if gasRemaining > 0 { - return fmt.Errorf("gas unspecified (indicating an error), but VM returned gas remaining > 0") + return errors.New("gas unspecified (indicating an error), but VM returned gas remaining > 0") } return nil } diff --git a/trie/committer.go b/trie/committer.go index 9db314d980..435da51981 100644 --- a/trie/committer.go +++ b/trie/committer.go @@ -22,6 +22,7 @@ import ( "sync" "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/crypto" "github.com/XinFinOrg/XDPoSChain/rlp" "golang.org/x/crypto/sha3" ) @@ -46,7 +47,7 @@ type leaf struct { // processed sequentially - onleaf will never be called in parallel or out of order. type committer struct { tmp sliceBuffer - sha keccakState + sha crypto.KeccakState onleaf LeafCallback leafCh chan *leaf @@ -57,7 +58,7 @@ var committerPool = sync.Pool{ New: func() interface{} { return &committer{ tmp: make(sliceBuffer, 0, 550), // cap is as large as a full FullNode. - sha: sha3.NewLegacyKeccak256().(keccakState), + sha: sha3.NewLegacyKeccak256().(crypto.KeccakState), } }, } diff --git a/trie/hasher.go b/trie/hasher.go index a2b385ad5b..1dc9aae689 100644 --- a/trie/hasher.go +++ b/trie/hasher.go @@ -17,21 +17,13 @@ package trie import ( - "hash" "sync" + "github.com/XinFinOrg/XDPoSChain/crypto" "github.com/XinFinOrg/XDPoSChain/rlp" "golang.org/x/crypto/sha3" ) -// keccakState wraps sha3.state. In addition to the usual hash methods, it also supports -// Read to get a variable amount of data from the hash state. Read is faster than Sum -// because it doesn't copy the internal state, but also modifies the internal state. -type keccakState interface { - hash.Hash - Read([]byte) (int, error) -} - type sliceBuffer []byte func (b *sliceBuffer) Write(data []byte) (n int, err error) { @@ -46,7 +38,7 @@ func (b *sliceBuffer) Reset() { // hasher is a type used for the trie Hash operation. A hasher has some // internal preallocated temp space type hasher struct { - sha keccakState + sha crypto.KeccakState tmp sliceBuffer parallel bool // Whether to use paralallel threads when hashing } @@ -56,7 +48,7 @@ var hasherPool = sync.Pool{ New: func() interface{} { return &hasher{ tmp: make(sliceBuffer, 0, 550), // cap is as large as a full FullNode. - sha: sha3.NewLegacyKeccak256().(keccakState), + sha: sha3.NewLegacyKeccak256().(crypto.KeccakState), } }, } diff --git a/trie/proof.go b/trie/proof.go index fff4142401..0db68c103c 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -395,11 +395,11 @@ func hasRightElement(node Node, key []byte) bool { // Expect the normal case, this function can also be used to verify the following // range proofs(note this function doesn't accept zero element proof): // -// - All elements proof. In this case the left and right proof can be nil, but the -// range should be all the leaves in the trie. +// - All elements proof. In this case the left and right proof can be nil, but the +// range should be all the leaves in the trie. // -// - One element proof. In this case no matter the left edge proof is a non-existent -// proof or not, we can always verify the correctness of the proof. +// - One element proof. In this case no matter the left edge proof is a non-existent +// proof or not, we can always verify the correctness of the proof. // // Except returning the error to indicate the proof is valid or not, the function will // also return a flag to indicate whether there exists more accounts/slots in the trie. @@ -439,7 +439,7 @@ func VerifyRangeProof(rootHash common.Hash, firstKey []byte, keys [][]byte, valu return err, false } if !bytes.Equal(val, values[0]) { - return fmt.Errorf("correct proof but invalid data"), false + return errors.New("correct proof but invalid data"), false } return nil, hasRightElement(root, keys[0]) } diff --git a/trie/trie_test.go b/trie/trie_test.go index 89d47bf7dc..202fc76d80 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -19,6 +19,7 @@ package trie import ( "bytes" "encoding/binary" + "errors" "fmt" "math/big" "math/rand" @@ -446,7 +447,7 @@ func runRandTest(rt randTest) bool { checktr.Update(it.Key, it.Value) } if tr.Hash() != checktr.Hash() { - rt[i].err = fmt.Errorf("hash mismatch in opItercheckhash") + rt[i].err = errors.New("hash mismatch in opItercheckhash") } } // Abort the test on error. diff --git a/whisper/whisperv5/api.go b/whisper/whisperv5/api.go index 89e9c2860b..b28ea5075d 100644 --- a/whisper/whisperv5/api.go +++ b/whisper/whisperv5/api.go @@ -256,8 +256,7 @@ func (api *PublicWhisperAPI) Post(ctx context.Context, req NewMessage) (bool, er // Set asymmetric key that is used to encrypt the message if pubKeyGiven { - params.Dst = crypto.ToECDSAPub(req.PublicKey) - if !ValidatePublicKey(params.Dst) { + if params.Dst, err = crypto.UnmarshalPubkey(req.PublicKey); err != nil { return false, ErrInvalidPublicKey } } @@ -333,8 +332,7 @@ func (api *PublicWhisperAPI) Messages(ctx context.Context, crit Criteria) (*rpc. } if len(crit.Sig) > 0 { - filter.Src = crypto.ToECDSAPub(crit.Sig) - if !ValidatePublicKey(filter.Src) { + if filter.Src, err = crypto.UnmarshalPubkey(crit.Sig); err != nil { return nil, ErrInvalidSigningPubKey } } @@ -473,7 +471,7 @@ func (api *PublicWhisperAPI) GetFilterMessages(id string) ([]*Message, error) { f := api.w.GetFilter(id) if f == nil { api.mu.Unlock() - return nil, fmt.Errorf("filter not found") + return nil, errors.New("filter not found") } api.lastUsed[id] = time.Now() api.mu.Unlock() @@ -517,8 +515,7 @@ func (api *PublicWhisperAPI) NewMessageFilter(req Criteria) (string, error) { } if len(req.Sig) > 0 { - src = crypto.ToECDSAPub(req.Sig) - if !ValidatePublicKey(src) { + if src, err = crypto.UnmarshalPubkey(req.Sig); err != nil { return "", ErrInvalidSigningPubKey } } diff --git a/whisper/whisperv5/filter.go b/whisper/whisperv5/filter.go index d91e9747bf..5c64d46910 100644 --- a/whisper/whisperv5/filter.go +++ b/whisper/whisperv5/filter.go @@ -18,6 +18,7 @@ package whisperv5 import ( "crypto/ecdsa" + "errors" "fmt" "sync" @@ -66,7 +67,7 @@ func (fs *Filters) Install(watcher *Filter) (string, error) { defer fs.mutex.Unlock() if fs.watchers[id] != nil { - return "", fmt.Errorf("failed to generate unique ID") + return "", errors.New("failed to generate unique ID") } if watcher.expectsSymmetricEncryption() { diff --git a/whisper/whisperv5/whisper.go b/whisper/whisperv5/whisper.go index 66901034f6..9bbb1cae01 100644 --- a/whisper/whisperv5/whisper.go +++ b/whisper/whisperv5/whisper.go @@ -247,7 +247,7 @@ func (w *Whisper) NewKeyPair() (string, error) { return "", err } if !validatePrivateKey(key) { - return "", fmt.Errorf("failed to generate valid key") + return "", errors.New("failed to generate valid key") } id, err := GenerateRandomID() @@ -259,7 +259,7 @@ func (w *Whisper) NewKeyPair() (string, error) { defer w.keyMu.Unlock() if w.privateKeys[id] != nil { - return "", fmt.Errorf("failed to generate unique ID") + return "", errors.New("failed to generate unique ID") } w.privateKeys[id] = key return id, nil @@ -305,7 +305,7 @@ func (w *Whisper) GetPrivateKey(id string) (*ecdsa.PrivateKey, error) { defer w.keyMu.RUnlock() key := w.privateKeys[id] if key == nil { - return nil, fmt.Errorf("invalid id") + return nil, errors.New("invalid id") } return key, nil } @@ -318,7 +318,7 @@ func (w *Whisper) GenerateSymKey() (string, error) { if err != nil { return "", err } else if !validateSymmetricKey(key) { - return "", fmt.Errorf("error in GenerateSymKey: crypto/rand failed to generate random data") + return "", errors.New("error in GenerateSymKey: crypto/rand failed to generate random data") } id, err := GenerateRandomID() @@ -330,7 +330,7 @@ func (w *Whisper) GenerateSymKey() (string, error) { defer w.keyMu.Unlock() if w.symKeys[id] != nil { - return "", fmt.Errorf("failed to generate unique ID") + return "", errors.New("failed to generate unique ID") } w.symKeys[id] = key return id, nil @@ -351,7 +351,7 @@ func (w *Whisper) AddSymKeyDirect(key []byte) (string, error) { defer w.keyMu.Unlock() if w.symKeys[id] != nil { - return "", fmt.Errorf("failed to generate unique ID") + return "", errors.New("failed to generate unique ID") } w.symKeys[id] = key return id, nil @@ -364,7 +364,7 @@ func (w *Whisper) AddSymKeyFromPassword(password string) (string, error) { return "", fmt.Errorf("failed to generate ID: %s", err) } if w.HasSymKey(id) { - return "", fmt.Errorf("failed to generate unique ID") + return "", errors.New("failed to generate unique ID") } derived, err := deriveKeyMaterial([]byte(password), EnvelopeVersion) @@ -377,7 +377,7 @@ func (w *Whisper) AddSymKeyFromPassword(password string) (string, error) { // double check is necessary, because deriveKeyMaterial() is very slow if w.symKeys[id] != nil { - return "", fmt.Errorf("critical error: failed to generate unique ID") + return "", errors.New("critical error: failed to generate unique ID") } w.symKeys[id] = derived return id, nil @@ -409,7 +409,7 @@ func (w *Whisper) GetSymKey(id string) ([]byte, error) { if w.symKeys[id] != nil { return w.symKeys[id], nil } - return nil, fmt.Errorf("non-existent key ID") + return nil, errors.New("non-existent key ID") } // Subscribe installs a new message handler used for filtering, decrypting @@ -427,7 +427,7 @@ func (w *Whisper) GetFilter(id string) *Filter { func (w *Whisper) Unsubscribe(id string) error { ok := w.filters.Uninstall(id) if !ok { - return fmt.Errorf("Unsubscribe: Invalid ID") + return errors.New("Unsubscribe: Invalid ID") } return nil } @@ -440,7 +440,7 @@ func (w *Whisper) Send(envelope *Envelope) error { return err } if !ok { - return fmt.Errorf("failed to add envelope") + return errors.New("failed to add envelope") } return err } @@ -576,7 +576,7 @@ func (wh *Whisper) add(envelope *Envelope) (bool, error) { if envelope.Expiry < now { if envelope.Expiry+SynchAllowance*2 < now { - return false, fmt.Errorf("very old message") + return false, errors.New("very old message") } else { log.Debug("expired envelope dropped", "hash", envelope.Hash().Hex()) return false, nil // drop envelope without error @@ -851,7 +851,7 @@ func GenerateRandomID() (id string, err error) { return "", err } if !validateSymmetricKey(buf) { - return "", fmt.Errorf("error in generateRandomID: crypto/rand failed to generate random data") + return "", errors.New("error in generateRandomID: crypto/rand failed to generate random data") } id = common.Bytes2Hex(buf) return id, err diff --git a/whisper/whisperv6/api.go b/whisper/whisperv6/api.go index 95106ee167..8711d6b195 100644 --- a/whisper/whisperv6/api.go +++ b/whisper/whisperv6/api.go @@ -275,8 +275,7 @@ func (api *PublicWhisperAPI) Post(ctx context.Context, req NewMessage) (bool, er // Set asymmetric key that is used to encrypt the message if pubKeyGiven { - params.Dst = crypto.ToECDSAPub(req.PublicKey) - if !ValidatePublicKey(params.Dst) { + if params.Dst, err = crypto.UnmarshalPubkey(req.PublicKey); err != nil { return false, ErrInvalidPublicKey } } @@ -352,8 +351,7 @@ func (api *PublicWhisperAPI) Messages(ctx context.Context, crit Criteria) (*rpc. } if len(crit.Sig) > 0 { - filter.Src = crypto.ToECDSAPub(crit.Sig) - if !ValidatePublicKey(filter.Src) { + if filter.Src, err = crypto.UnmarshalPubkey(crit.Sig); err != nil { return nil, ErrInvalidSigningPubKey } } @@ -492,7 +490,7 @@ func (api *PublicWhisperAPI) GetFilterMessages(id string) ([]*Message, error) { f := api.w.GetFilter(id) if f == nil { api.mu.Unlock() - return nil, fmt.Errorf("filter not found") + return nil, errors.New("filter not found") } api.lastUsed[id] = time.Now() api.mu.Unlock() @@ -536,8 +534,7 @@ func (api *PublicWhisperAPI) NewMessageFilter(req Criteria) (string, error) { } if len(req.Sig) > 0 { - src = crypto.ToECDSAPub(req.Sig) - if !ValidatePublicKey(src) { + if src, err = crypto.UnmarshalPubkey(req.Sig); err != nil { return "", ErrInvalidSigningPubKey } } diff --git a/whisper/whisperv6/filter.go b/whisper/whisperv6/filter.go index aae1de8480..801954c7c3 100644 --- a/whisper/whisperv6/filter.go +++ b/whisper/whisperv6/filter.go @@ -18,6 +18,7 @@ package whisperv6 import ( "crypto/ecdsa" + "errors" "fmt" "sync" @@ -65,7 +66,7 @@ func NewFilters(w *Whisper) *Filters { // Install will add a new filter to the filter collection func (fs *Filters) Install(watcher *Filter) (string, error) { if watcher.KeySym != nil && watcher.KeyAsym != nil { - return "", fmt.Errorf("filters must choose between symmetric and asymmetric keys") + return "", errors.New("filters must choose between symmetric and asymmetric keys") } if watcher.Messages == nil { @@ -81,7 +82,7 @@ func (fs *Filters) Install(watcher *Filter) (string, error) { defer fs.mutex.Unlock() if fs.watchers[id] != nil { - return "", fmt.Errorf("failed to generate unique ID") + return "", errors.New("failed to generate unique ID") } if watcher.expectsSymmetricEncryption() { diff --git a/whisper/whisperv6/whisper.go b/whisper/whisperv6/whisper.go index 53cf11a909..0b83c89d13 100644 --- a/whisper/whisperv6/whisper.go +++ b/whisper/whisperv6/whisper.go @@ -379,7 +379,7 @@ func (whisper *Whisper) NewKeyPair() (string, error) { return "", err } if !validatePrivateKey(key) { - return "", fmt.Errorf("failed to generate valid key") + return "", errors.New("failed to generate valid key") } id, err := GenerateRandomID() @@ -391,7 +391,7 @@ func (whisper *Whisper) NewKeyPair() (string, error) { defer whisper.keyMu.Unlock() if whisper.privateKeys[id] != nil { - return "", fmt.Errorf("failed to generate unique ID") + return "", errors.New("failed to generate unique ID") } whisper.privateKeys[id] = key return id, nil @@ -437,7 +437,7 @@ func (whisper *Whisper) GetPrivateKey(id string) (*ecdsa.PrivateKey, error) { defer whisper.keyMu.RUnlock() key := whisper.privateKeys[id] if key == nil { - return nil, fmt.Errorf("invalid id") + return nil, errors.New("invalid id") } return key, nil } @@ -449,7 +449,7 @@ func (whisper *Whisper) GenerateSymKey() (string, error) { if err != nil { return "", err } else if !validateDataIntegrity(key, aesKeyLength) { - return "", fmt.Errorf("error in GenerateSymKey: crypto/rand failed to generate random data") + return "", errors.New("error in GenerateSymKey: crypto/rand failed to generate random data") } id, err := GenerateRandomID() @@ -461,7 +461,7 @@ func (whisper *Whisper) GenerateSymKey() (string, error) { defer whisper.keyMu.Unlock() if whisper.symKeys[id] != nil { - return "", fmt.Errorf("failed to generate unique ID") + return "", errors.New("failed to generate unique ID") } whisper.symKeys[id] = key return id, nil @@ -482,7 +482,7 @@ func (whisper *Whisper) AddSymKeyDirect(key []byte) (string, error) { defer whisper.keyMu.Unlock() if whisper.symKeys[id] != nil { - return "", fmt.Errorf("failed to generate unique ID") + return "", errors.New("failed to generate unique ID") } whisper.symKeys[id] = key return id, nil @@ -495,7 +495,7 @@ func (whisper *Whisper) AddSymKeyFromPassword(password string) (string, error) { return "", fmt.Errorf("failed to generate ID: %s", err) } if whisper.HasSymKey(id) { - return "", fmt.Errorf("failed to generate unique ID") + return "", errors.New("failed to generate unique ID") } // kdf should run no less than 0.1 seconds on an average computer, @@ -510,7 +510,7 @@ func (whisper *Whisper) AddSymKeyFromPassword(password string) (string, error) { // double check is necessary, because deriveKeyMaterial() is very slow if whisper.symKeys[id] != nil { - return "", fmt.Errorf("critical error: failed to generate unique ID") + return "", errors.New("critical error: failed to generate unique ID") } whisper.symKeys[id] = derived return id, nil @@ -542,7 +542,7 @@ func (whisper *Whisper) GetSymKey(id string) ([]byte, error) { if whisper.symKeys[id] != nil { return whisper.symKeys[id], nil } - return nil, fmt.Errorf("non-existent key ID") + return nil, errors.New("non-existent key ID") } // Subscribe installs a new message handler used for filtering, decrypting @@ -581,7 +581,7 @@ func (whisper *Whisper) GetFilter(id string) *Filter { func (whisper *Whisper) Unsubscribe(id string) error { ok := whisper.filters.Uninstall(id) if !ok { - return fmt.Errorf("Unsubscribe: Invalid ID") + return errors.New("Unsubscribe: Invalid ID") } return nil } @@ -591,7 +591,7 @@ func (whisper *Whisper) Unsubscribe(id string) error { func (whisper *Whisper) Send(envelope *Envelope) error { ok, err := whisper.add(envelope, false) if err == nil && !ok { - return fmt.Errorf("failed to add envelope") + return errors.New("failed to add envelope") } return err } @@ -762,7 +762,7 @@ func (whisper *Whisper) add(envelope *Envelope, isP2P bool) (bool, error) { if envelope.Expiry < now { if envelope.Expiry+DefaultSyncAllowance*2 < now { - return false, fmt.Errorf("very old message") + return false, errors.New("very old message") } log.Debug("expired envelope dropped", "hash", envelope.Hash().Hex()) return false, nil // drop envelope without error @@ -1009,7 +1009,7 @@ func GenerateRandomID() (id string, err error) { return "", err } if !validateDataIntegrity(buf, keyIDSize) { - return "", fmt.Errorf("error in generateRandomID: crypto/rand failed to generate random data") + return "", errors.New("error in generateRandomID: crypto/rand failed to generate random data") } id = common.Bytes2Hex(buf) return id, err