diff --git a/accounts/accounts.go b/accounts/accounts.go index ba57577925..157112c8c1 100644 --- a/accounts/accounts.go +++ b/accounts/accounts.go @@ -18,12 +18,14 @@ package accounts import ( + "fmt" "math/big" ethereum "github.com/XinFinOrg/XDPoSChain" "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/core/types" "github.com/XinFinOrg/XDPoSChain/event" + "golang.org/x/crypto/sha3" ) // Account represents an Ethereum account located at a specific location defined @@ -148,6 +150,34 @@ type Backend interface { Subscribe(sink chan<- WalletEvent) event.Subscription } +// TextHash is a helper function that calculates a hash for the given message that can be +// safely used to calculate a signature from. +// +// The hash is calulcated as +// +// keccak256("\x19Ethereum Signed Message:\n"${message length}${message}). +// +// This gives context to the signed message and prevents signing of transactions. +func TextHash(data []byte) []byte { + hash, _ := TextAndHash(data) + return hash +} + +// TextAndHash is a helper function that calculates a hash for the given message that can be +// safely used to calculate a signature from. +// +// The hash is calulcated as +// +// keccak256("\x19Ethereum Signed Message:\n"${message length}${message}). +// +// This gives context to the signed message and prevents signing of transactions. +func TextAndHash(data []byte) ([]byte, string) { + msg := fmt.Sprintf("\x19Ethereum Signed Message:\n%d%s", len(data), string(data)) + hasher := sha3.NewLegacyKeccak256() + hasher.Write([]byte(msg)) + return hasher.Sum(nil), msg +} + // WalletEventType represents the different event types that can be fired by // the wallet subscription subsystem. type WalletEventType int diff --git a/cmd/wnode/main.go b/cmd/wnode/main.go index 5441428e88..5fa29ab96c 100644 --- a/cmd/wnode/main.go +++ b/cmd/wnode/main.go @@ -139,8 +139,8 @@ func processArgs() { } if *asymmetricMode && len(*argPub) > 0 { - pub = crypto.ToECDSAPub(common.FromHex(*argPub)) - if !isKeyValid(pub) { + var err error + if pub, err = crypto.UnmarshalPubkey(common.FromHex(*argPub)); err != nil { utils.Fatalf("invalid public key") } } @@ -337,9 +337,8 @@ func configureNode() { if b == nil { utils.Fatalf("Error: can not convert hexadecimal string") } - pub = crypto.ToECDSAPub(b) - if !isKeyValid(pub) { - utils.Fatalf("Error: invalid public key") + if pub, err = crypto.UnmarshalPubkey(b); err != nil { + utils.Fatalf("Error: invalid peer public key") } } } diff --git a/core/blockchain.go b/core/blockchain.go index 090a52b4bf..d35800541e 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -58,6 +58,11 @@ var ( blockInsertTimer = metrics.NewRegisteredTimer("chain/inserts", nil) CheckpointCh = make(chan int) ErrNoGenesis = errors.New("Genesis not found in chain") + + blockReorgMeter = metrics.NewRegisteredMeter("chain/reorg/executes", nil) + blockReorgAddMeter = metrics.NewRegisteredMeter("chain/reorg/add", nil) + blockReorgDropMeter = metrics.NewRegisteredMeter("chain/reorg/drop", nil) + blockReorgInvalidatedTx = metrics.NewRegisteredMeter("chain/reorg/invalidTx", nil) ) const ( @@ -2245,6 +2250,9 @@ func (bc *BlockChain) reorg(oldBlock, newBlock *types.Block) error { } logFn("Chain split detected", "number", commonBlock.Number(), "hash", commonBlock.Hash(), "drop", len(oldChain), "dropfrom", oldChain[0].Hash(), "add", len(newChain), "addfrom", newChain[0].Hash()) + blockReorgAddMeter.Mark(int64(len(newChain))) + blockReorgDropMeter.Mark(int64(len(oldChain))) + blockReorgMeter.Mark(1) } else { log.Error("Impossible reorg, please file an issue", "oldnum", oldBlock.Number(), "oldhash", oldBlock.Hash(), "newnum", newBlock.Number(), "newhash", newBlock.Hash()) } diff --git a/core/blockchain_test.go b/core/blockchain_test.go index 9dc4b65117..1542510a8e 100644 --- a/core/blockchain_test.go +++ b/core/blockchain_test.go @@ -1167,7 +1167,7 @@ func TestEIP161AccountRemoval(t *testing.T) { t.Error("account should not exist") } - // account musn't be created post eip 161 + // account mustn't be created post eip 161 if _, err := blockchain.InsertChain(types.Blocks{blocks[2]}); err != nil { t.Fatal(err) } diff --git a/core/tx_list.go b/core/tx_list.go index 2c3c33eb59..5f623806d6 100644 --- a/core/tx_list.go +++ b/core/tx_list.go @@ -24,7 +24,6 @@ import ( "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/core/types" - "github.com/XinFinOrg/XDPoSChain/log" ) // nonceHeap is a heap.Interface implementation over 64bit unsigned integers for @@ -99,7 +98,30 @@ func (m *txSortedMap) Forward(threshold uint64) types.Transactions { // Filter iterates over the list of transactions and removes all of them for which // the specified function evaluates to true. +// Filter, as opposed to 'filter', re-initialises the heap after the operation is done. +// If you want to do several consecutive filterings, it's therefore better to first +// do a .filter(func1) followed by .Filter(func2) or reheap() func (m *txSortedMap) Filter(filter func(*types.Transaction) bool) types.Transactions { + removed := m.filter(filter) + // If transactions were removed, the heap and cache are ruined + if len(removed) > 0 { + m.reheap() + } + return removed +} + +func (m *txSortedMap) reheap() { + *m.index = make([]uint64, 0, len(m.items)) + for nonce := range m.items { + *m.index = append(*m.index, nonce) + } + heap.Init(m.index) + m.cache = nil +} + +// filter is identical to Filter, but **does not** regenerate the heap. This method +// should only be used if followed immediately by a call to Filter or reheap() +func (m *txSortedMap) filter(filter func(*types.Transaction) bool) types.Transactions { var removed types.Transactions // Collect all the transactions to filter out @@ -109,14 +131,7 @@ func (m *txSortedMap) Filter(filter func(*types.Transaction) bool) types.Transac delete(m.items, nonce) } } - // If transactions were removed, the heap and cache are ruined if len(removed) > 0 { - *m.index = make([]uint64, 0, len(m.items)) - for nonce := range m.items { - *m.index = append(*m.index, nonce) - } - heap.Init(m.index) - m.cache = nil } return removed @@ -197,10 +212,7 @@ func (m *txSortedMap) Len() int { return len(m.items) } -// Flatten creates a nonce-sorted slice of transactions based on the loosely -// sorted internal representation. The result of the sorting is cached in case -// it's requested again before any modifications are made to the contents. -func (m *txSortedMap) Flatten() types.Transactions { +func (m *txSortedMap) flatten() types.Transactions { // If the sorting was not cached yet, create and cache it if m.cache == nil { m.cache = make(types.Transactions, 0, len(m.items)) @@ -209,12 +221,27 @@ func (m *txSortedMap) Flatten() types.Transactions { } sort.Sort(types.TxByNonce(m.cache)) } + return m.cache +} + +// Flatten creates a nonce-sorted slice of transactions based on the loosely +// sorted internal representation. The result of the sorting is cached in case +// it's requested again before any modifications are made to the contents. +func (m *txSortedMap) Flatten() types.Transactions { // Copy the cache to prevent accidental modifications - txs := make(types.Transactions, len(m.cache)) - copy(txs, m.cache) + cache := m.flatten() + txs := make(types.Transactions, len(cache)) + copy(txs, cache) return txs } +// LastElement returns the last element of a flattened list, thus, the +// transaction with the highest nonce +func (m *txSortedMap) LastElement() *types.Transaction { + cache := m.flatten() + return cache[len(cache)-1] +} + // txList is a "list" of transactions belonging to an account, sorted by account // nonce. The same type can be used both for storing contiguous transactions for // the executable/pending queue; and for storing gapped transactions for the non- @@ -255,11 +282,15 @@ func (l *txList) Add(tx *types.Transaction, priceBump uint64) (bool, *types.Tran return false, nil } if old != nil { - threshold := new(big.Int).Div(new(big.Int).Mul(old.GasPrice(), big.NewInt(100+int64(priceBump))), big.NewInt(100)) + // threshold = oldGP * (100 + priceBump) / 100 + a := big.NewInt(100 + int64(priceBump)) + a = a.Mul(a, old.GasPrice()) + b := big.NewInt(100) + threshold := a.Div(a, b) // Have to ensure that the new gas price is higher than the old gas // price as well as checking the percentage threshold to ensure that // this is accurate for low (Wei-level) gas price replacements - if old.GasPrice().Cmp(tx.GasPrice()) >= 0 || threshold.Cmp(tx.GasPrice()) > 0 { + if old.GasPriceCmp(tx) >= 0 || tx.GasPriceIntCmp(threshold) < 0 { return false, nil } } @@ -303,24 +334,27 @@ func (l *txList) Filter(costLimit *big.Int, gasLimit uint64, trc21Issuers map[co maximum := costLimit if tx.To() != nil { if feeCapacity, ok := trc21Issuers[*tx.To()]; ok { - return new(big.Int).Add(costLimit, feeCapacity).Cmp(tx.TxCost(number)) < 0 || tx.Gas() > gasLimit + return tx.Gas() > gasLimit || new(big.Int).Add(costLimit, feeCapacity).Cmp(tx.TxCost(number)) < 0 } } - return tx.Cost().Cmp(maximum) > 0 || tx.Gas() > gasLimit + return tx.Gas() > gasLimit || tx.Cost().Cmp(maximum) > 0 }) - // If the list was strict, filter anything above the lowest nonce + if len(removed) == 0 { + return nil, nil + } var invalids types.Transactions - - if l.strict && len(removed) > 0 { + // If the list was strict, filter anything above the lowest nonce + if l.strict { lowest := uint64(math.MaxUint64) for _, tx := range removed { if nonce := tx.Nonce(); lowest > nonce { lowest = nonce } } - invalids = l.txs.Filter(func(tx *types.Transaction) bool { return tx.Nonce() > lowest }) + invalids = l.txs.filter(func(tx *types.Transaction) bool { return tx.Nonce() > lowest }) } + l.txs.reheap() return removed, invalids } @@ -374,6 +408,12 @@ func (l *txList) Flatten() types.Transactions { return l.txs.Flatten() } +// LastElement returns the last element of a flattened list, thus, the +// transaction with the highest nonce +func (l *txList) LastElement() *types.Transaction { + return l.txs.LastElement() +} + // priceHeap is a heap.Interface implementation over transactions for retrieving // price-sorted transactions to discard when the pool fills up. type priceHeap []*types.Transaction @@ -383,7 +423,7 @@ func (h priceHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } func (h priceHeap) Less(i, j int) bool { // Sort primarily by price, returning the cheaper one - switch h[i].GasPrice().Cmp(h[j].GasPrice()) { + switch h[i].GasPriceCmp(h[j]) { case -1: return true case 1: @@ -406,24 +446,29 @@ func (h *priceHeap) Pop() interface{} { } // txPricedList is a price-sorted heap to allow operating on transactions pool -// contents in a price-incrementing way. +// contents in a price-incrementing way. It's built opon the all transactions +// in txpool but only interested in the remote part. It means only remote transactions +// will be considered for tracking, sorting, eviction, etc. type txPricedList struct { - all *txLookup // Pointer to the map of all transactions - items *priceHeap // Heap of prices of all the stored transactions - stales int // Number of stale price points to (re-heap trigger) + all *txLookup // Pointer to the map of all transactions + remotes *priceHeap // Heap of prices of all the stored **remote** transactions + stales int // Number of stale price points to (re-heap trigger) } // newTxPricedList creates a new price-sorted transaction heap. func newTxPricedList(all *txLookup) *txPricedList { return &txPricedList{ - all: all, - items: new(priceHeap), + all: all, + remotes: new(priceHeap), } } // Put inserts a new transaction into the heap. -func (l *txPricedList) Put(tx *types.Transaction) { - heap.Push(l.items, tx) +func (l *txPricedList) Put(tx *types.Transaction, local bool) { + if local { + return + } + heap.Push(l.remotes, tx) } // Removed notifies the prices transaction list that an old transaction dropped @@ -432,100 +477,95 @@ func (l *txPricedList) Put(tx *types.Transaction) { func (l *txPricedList) Removed(count int) { // Bump the stale counter, but exit if still too low (< 25%) l.stales += count - if l.stales <= len(*l.items)/4 { + if l.stales <= len(*l.remotes)/4 { return } // Seems we've reached a critical number of stale transactions, reheap - reheap := make(priceHeap, 0, l.all.Count()) - - l.stales, l.items = 0, &reheap - l.all.Range(func(hash common.Hash, tx *types.Transaction) bool { - *l.items = append(*l.items, tx) - return true - }) - heap.Init(l.items) + l.Reheap() } // Cap finds all the transactions below the given price threshold, drops them -// from the priced list and returs them for further removal from the entire pool. -func (l *txPricedList) Cap(threshold *big.Int, local *accountSet) types.Transactions { +// from the priced list and returns them for further removal from the entire pool. +// +// Note: only remote transactions will be considered for eviction. +func (l *txPricedList) Cap(threshold *big.Int) types.Transactions { drop := make(types.Transactions, 0, 128) // Remote underpriced transactions to drop - save := make(types.Transactions, 0, 64) // Local underpriced transactions to keep - - for len(*l.items) > 0 { + for len(*l.remotes) > 0 { // Discard stale transactions if found during cleanup - tx := heap.Pop(l.items).(*types.Transaction) - if l.all.Get(tx.Hash()) == nil { + cheapest := (*l.remotes)[0] + if l.all.GetRemote(cheapest.Hash()) == nil { // Removed or migrated + heap.Pop(l.remotes) l.stales-- continue } // Stop the discards if we've reached the threshold - if tx.GasPrice().Cmp(threshold) >= 0 { - save = append(save, tx) + if cheapest.GasPriceIntCmp(threshold) >= 0 { break } - // Non stale transaction found, discard unless local - if local.containsTx(tx) { - save = append(save, tx) - } else { - drop = append(drop, tx) - } - } - for _, tx := range save { - heap.Push(l.items, tx) + heap.Pop(l.remotes) + drop = append(drop, cheapest) } return drop } // Underpriced checks whether a transaction is cheaper than (or as cheap as) the -// lowest priced transaction currently being tracked. -func (l *txPricedList) Underpriced(tx *types.Transaction, local *accountSet) bool { - // Local transactions cannot be underpriced - if local.containsTx(tx) { - return false - } +// lowest priced (remote) transaction currently being tracked. +func (l *txPricedList) Underpriced(tx *types.Transaction) bool { // Discard stale price points if found at the heap start - for len(*l.items) > 0 { - head := []*types.Transaction(*l.items)[0] - if l.all.Get(head.Hash()) == nil { + for len(*l.remotes) > 0 { + head := []*types.Transaction(*l.remotes)[0] + if l.all.GetRemote(head.Hash()) == nil { // Removed or migrated l.stales-- - heap.Pop(l.items) + heap.Pop(l.remotes) continue } break } // Check if the transaction is underpriced or not - if len(*l.items) == 0 { - log.Error("Pricing query for empty pool") // This cannot happen, print to catch programming errors - return false + if len(*l.remotes) == 0 { + return false // There is no remote transaction at all. } - cheapest := []*types.Transaction(*l.items)[0] - return cheapest.GasPrice().Cmp(tx.GasPrice()) >= 0 + // If the remote transaction is even cheaper than the + // cheapest one tracked locally, reject it. + cheapest := []*types.Transaction(*l.remotes)[0] + return cheapest.GasPriceCmp(tx) >= 0 } // Discard finds a number of most underpriced transactions, removes them from the // priced list and returns them for further removal from the entire pool. -func (l *txPricedList) Discard(count int, local *accountSet) types.Transactions { - drop := make(types.Transactions, 0, count) // Remote underpriced transactions to drop - save := make(types.Transactions, 0, 64) // Local underpriced transactions to keep - - for len(*l.items) > 0 && count > 0 { +// +// Note local transaction won't be considered for eviction. +func (l *txPricedList) Discard(slots int, force bool) (types.Transactions, bool) { + drop := make(types.Transactions, 0, slots) // Remote underpriced transactions to drop + for len(*l.remotes) > 0 && slots > 0 { // Discard stale transactions if found during cleanup - tx := heap.Pop(l.items).(*types.Transaction) - if l.all.Get(tx.Hash()) == nil { + tx := heap.Pop(l.remotes).(*types.Transaction) + if l.all.GetRemote(tx.Hash()) == nil { // Removed or migrated l.stales-- continue } - // Non stale transaction found, discard unless local - if local.containsTx(tx) { - save = append(save, tx) - } else { - drop = append(drop, tx) - count-- + // Non stale transaction found, discard it + drop = append(drop, tx) + slots -= numSlots(tx) + } + // If we still can't make enough room for the new transaction + if slots > 0 && !force { + for _, tx := range drop { + heap.Push(l.remotes, tx) } + return nil, false } - for _, tx := range save { - heap.Push(l.items, tx) - } - return drop + return drop, true +} + +// Reheap forcibly rebuilds the heap based on the current remote transaction set. +func (l *txPricedList) Reheap() { + reheap := make(priceHeap, 0, l.all.RemoteCount()) + + l.stales, l.remotes = 0, &reheap + l.all.Range(func(hash common.Hash, tx *types.Transaction, local bool) bool { + *l.remotes = append(*l.remotes, tx) + return true + }, false, true) // Only iterate remotes + heap.Init(l.remotes) } diff --git a/core/tx_list_test.go b/core/tx_list_test.go index f0ec8eb8b4..36a0196f1e 100644 --- a/core/tx_list_test.go +++ b/core/tx_list_test.go @@ -17,6 +17,7 @@ package core import ( + "math/big" "math/rand" "testing" @@ -49,3 +50,21 @@ func TestStrictTxListAdd(t *testing.T) { } } } + +func BenchmarkTxListAdd(t *testing.B) { + // Generate a list of transactions to insert + key, _ := crypto.GenerateKey() + + txs := make(types.Transactions, 100000) + for i := 0; i < len(txs); i++ { + txs[i] = transaction(uint64(i), 0, key) + } + // Insert the transactions in a random order + list := newTxList(true) + priceLimit := big.NewInt(int64(DefaultTxPoolConfig.PriceLimit)) + t.ResetTimer() + for _, v := range rand.Perm(len(txs)) { + list.Add(txs[v], DefaultTxPoolConfig.PriceBump) + list.Filter(priceLimit, DefaultTxPoolConfig.PriceBump, nil, nil) + } +} diff --git a/core/tx_pool.go b/core/tx_pool.go index b2c271f016..f2d9fc99a2 100644 --- a/core/tx_pool.go +++ b/core/tx_pool.go @@ -39,9 +39,25 @@ import ( const ( // chainHeadChanSize is the size of channel listening to ChainHeadEvent. chainHeadChanSize = 10 + + // txSlotSize is used to calculate how many data slots a single transaction + // takes up based on its size. The slots are used as DoS protection, ensuring + // that validating a new transaction remains a constant operation (in reality + // O(maxslots), where max slots are 4 currently). + txSlotSize = 32 * 1024 + + // txMaxSize is the maximum size a single transaction can have. This field has + // non-trivial consequences: larger transactions are significantly harder and + // more expensive to propagate; larger transactions also take more resources + // to validate whether they fit into the pool or not. + txMaxSize = 2 * txSlotSize // 64KB, don't bump without EIP-2464 support ) var ( + // ErrAlreadyKnown is returned if the transactions is already contained + // within the pool. + ErrAlreadyKnown = errors.New("already known") + // ErrInvalidSender is returned if the transaction contains an invalid signature. ErrInvalidSender = errors.New("invalid sender") @@ -53,6 +69,10 @@ var ( // configured for the transaction pool. ErrUnderpriced = errors.New("transaction underpriced") + // ErrTxPoolOverflow is returned if the transaction pool is full and can't accpet + // another remote transaction. + ErrTxPoolOverflow = errors.New("txpool is full") + // ErrReplaceUnderpriced is returned if a transaction is attempted to be replaced // with a different one without the required price bump. ErrReplaceUnderpriced = errors.New("replacement transaction underpriced") @@ -69,7 +89,7 @@ var ( // maximum allowance of the current block. ErrGasLimit = errors.New("exceeds block gas limit") - // ErrNegativeValue is a sanity error to ensure noone is able to specify a + // ErrNegativeValue is a sanity error to ensure no one is able to specify a // transaction with a negative value. ErrNegativeValue = errors.New("negative value") @@ -104,15 +124,19 @@ var ( queuedReplaceMeter = metrics.NewRegisteredMeter("txpool/queued/replace", nil) queuedRateLimitMeter = metrics.NewRegisteredMeter("txpool/queued/ratelimit", nil) // Dropped due to rate limiting queuedNofundsMeter = metrics.NewRegisteredMeter("txpool/queued/nofunds", nil) // Dropped due to out-of-funds + queuedEvictionMeter = metrics.NewRegisteredMeter("txpool/queued/eviction", nil) // Dropped due to lifetime // General tx metrics - validMeter = metrics.NewRegisteredMeter("txpool/valid", nil) + knownTxMeter = metrics.NewRegisteredMeter("txpool/known", nil) + validTxMeter = metrics.NewRegisteredMeter("txpool/valid", nil) invalidTxMeter = metrics.NewRegisteredMeter("txpool/invalid", nil) underpricedTxMeter = metrics.NewRegisteredMeter("txpool/underpriced", nil) + overflowedTxMeter = metrics.NewRegisteredMeter("txpool/overflowed", nil) pendingGauge = metrics.NewRegisteredGauge("txpool/pending", nil) queuedGauge = metrics.NewRegisteredGauge("txpool/queued", nil) localGauge = metrics.NewRegisteredGauge("txpool/local", nil) + slotsGauge = metrics.NewRegisteredGauge("txpool/slots", nil) ) // TxStatus is the current status of a transaction as seen by the pool. @@ -369,7 +393,7 @@ func (pool *TxPool) loop() { prevPending, prevQueued, prevStales = pending, queued, stales } - // Handle inactive account transaction eviction + // Handle inactive account transaction eviction case <-evict.C: pool.mu.Lock() for addr := range pool.queue { @@ -379,14 +403,16 @@ func (pool *TxPool) loop() { } // Any non-locals old enough should be removed if time.Since(pool.beats[addr]) > pool.config.Lifetime { - for _, tx := range pool.queue[addr].Flatten() { + list := pool.queue[addr].Flatten() + for _, tx := range list { pool.removeTx(tx.Hash(), true) } + queuedEvictionMeter.Mark(int64(len(list))) } } pool.mu.Unlock() - // Handle local transaction journal rotation + // Handle local transaction journal rotation case <-journal.C: if pool.journal != nil { pool.mu.Lock() @@ -435,7 +461,7 @@ func (pool *TxPool) SetGasPrice(price *big.Int) { defer pool.mu.Unlock() pool.gasPrice = price - for _, tx := range pool.priced.Cap(price, pool.locals) { + for _, tx := range pool.priced.Cap(price) { pool.removeTx(tx.Hash(), false) } log.Info("Transaction pool price threshold updated", "price", price) @@ -539,6 +565,10 @@ func (pool *TxPool) GetSender(tx *types.Transaction) (common.Address, error) { // validateTx checks whether a transaction is valid according to the consensus // rules and adheres to some heuristic limits of the local node (price and size). func (pool *TxPool) validateTx(tx *types.Transaction, local bool) error { + // Reject transactions over defined size to prevent DOS attacks + if uint64(tx.Size()) > txMaxSize { + return ErrOversizedData + } // check if sender is in black list if tx.From() != nil && common.Blacklist[*tx.From()] { return fmt.Errorf("Reject transaction with sender in black-list: %v", tx.From().Hex()) @@ -547,11 +577,6 @@ func (pool *TxPool) validateTx(tx *types.Transaction, local bool) error { if tx.To() != nil && common.Blacklist[*tx.To()] { return fmt.Errorf("Reject transaction with receiver in black-list: %v", tx.To().Hex()) } - - // Heuristic limit, reject transactions over 32KB to prevent DOS attacks - if tx.Size() > 32*1024 { - return ErrOversizedData - } // Transactions can't be negative. This may never happen using RLP decoded // transactions but may occur if you create a transaction using the RPC. if tx.Value().Sign() < 0 { @@ -567,8 +592,7 @@ func (pool *TxPool) validateTx(tx *types.Transaction, local bool) error { return ErrInvalidSender } // Drop non-local transactions under our own minimal accepted gas price - local = local || pool.locals.contains(from) // account may be local even if the transaction arrived from the network - if !local && pool.gasPrice.Cmp(tx.GasPrice()) > 0 { + if !local && tx.GasPriceIntCmp(pool.gasPrice) < 0 { if !tx.IsSpecialTransaction() || (pool.IsSigner != nil && !pool.IsSigner(from)) { return ErrUnderpriced } @@ -659,39 +683,50 @@ func (pool *TxPool) add(tx *types.Transaction, local bool) (replaced bool, err e hash := tx.Hash() if pool.all.Get(hash) != nil { log.Trace("Discarding already known transaction", "hash", hash) - return false, fmt.Errorf("known transaction: %x", hash) + knownTxMeter.Mark(1) + return false, ErrAlreadyKnown } + // Make the local flag. If it's from local source or it's from the network but + // the sender is marked as local previously, treat it as the local transaction. + isLocal := local || pool.locals.containsTx(tx) // If the transaction fails basic validation, discard it - if err := pool.validateTx(tx, local); err != nil { + if err := pool.validateTx(tx, isLocal); err != nil { log.Trace("Discarding invalid transaction", "hash", hash, "err", err) invalidTxMeter.Mark(1) return false, err } - from, _ := types.Sender(pool.signer, tx) // already validated if tx.IsSpecialTransaction() && pool.IsSigner != nil && pool.IsSigner(from) && pool.pendingNonces.get(from) == tx.Nonce() { - return pool.promoteSpecialTx(from, tx) + return pool.promoteSpecialTx(from, tx, isLocal) } - // If the transaction pool is full, discard underpriced transactions if uint64(pool.all.Count()) >= pool.config.GlobalSlots+pool.config.GlobalQueue { log.Debug("Add transaction to pool full", "hash", hash, "nonce", tx.Nonce()) // If the new transaction is underpriced, don't accept it - if !local && pool.priced.Underpriced(tx, pool.locals) { + if !isLocal && pool.priced.Underpriced(tx) { log.Trace("Discarding underpriced transaction", "hash", hash, "price", tx.GasPrice()) underpricedTxMeter.Mark(1) return false, ErrUnderpriced } - // New transaction is better than our worse ones, make room for it - drop := pool.priced.Discard(pool.all.Count()-int(pool.config.GlobalSlots+pool.config.GlobalQueue-1), pool.locals) + // New transaction is better than our worse ones, make room for it. + // If it's a local transaction, forcibly discard all available transactions. + // Otherwise if we can't make enough room for new one, abort the operation. + drop, success := pool.priced.Discard(pool.all.Slots()-int(pool.config.GlobalSlots+pool.config.GlobalQueue)+numSlots(tx), isLocal) + + // Special case, we still can't make the room for the new remote one. + if !isLocal && !success { + log.Trace("Discarding overflown transaction", "hash", hash) + overflowedTxMeter.Mark(1) + return false, ErrTxPoolOverflow + } + // Kick out the underpriced remote transactions. for _, tx := range drop { log.Trace("Discarding freshly underpriced transaction", "hash", tx.Hash(), "price", tx.GasPrice()) underpricedTxMeter.Mark(1) pool.removeTx(tx.Hash(), false) } } - // Try to replace an existing transaction in the pending pool if list := pool.pending[from]; list != nil && list.Overlaps(tx) { // Nonce already pending, check if required price bump is met @@ -706,28 +741,28 @@ func (pool *TxPool) add(tx *types.Transaction, local bool) (replaced bool, err e pool.priced.Removed(1) pendingReplaceMeter.Mark(1) } - pool.all.Add(tx) - pool.priced.Put(tx) + pool.all.Add(tx, isLocal) + pool.priced.Put(tx, isLocal) pool.journalTx(from, tx) pool.queueTxEvent(tx) log.Trace("Pooled new executable transaction", "hash", hash, "from", from, "to", tx.To()) + + // Successful promotion, bump the heartbeat + pool.beats[from] = time.Now() return old != nil, nil } - // New transaction isn't replacing a pending one, push into queue - replaced, err = pool.enqueueTx(hash, tx) + replaced, err = pool.enqueueTx(hash, tx, isLocal, true) if err != nil { return false, err } - // Mark local addresses and journal local transactions - if local { - if !pool.locals.contains(from) { - log.Info("Setting new local account", "address", from) - pool.locals.add(from) - } + if local && !pool.locals.contains(from) { + log.Info("Setting new local account", "address", from) + pool.locals.add(from) + pool.priced.Removed(pool.all.RemoteToLocals(pool.locals)) // Migrate the remotes if it's marked as local first time. } - if local || pool.locals.contains(from) { + if isLocal { localGauge.Inc(1) } pool.journalTx(from, tx) @@ -739,7 +774,7 @@ func (pool *TxPool) add(tx *types.Transaction, local bool) (replaced bool, err e // enqueueTx inserts a new transaction into the non-executable transaction queue. // // Note, this method assumes the pool lock is held! -func (pool *TxPool) enqueueTx(hash common.Hash, tx *types.Transaction) (bool, error) { +func (pool *TxPool) enqueueTx(hash common.Hash, tx *types.Transaction, local bool, addAll bool) (bool, error) { // Try to insert the transaction into the future queue from, _ := types.Sender(pool.signer, tx) // already validated if pool.queue[from] == nil { @@ -760,9 +795,18 @@ func (pool *TxPool) enqueueTx(hash common.Hash, tx *types.Transaction) (bool, er // Nothing was replaced, bump the queued counter queuedGauge.Inc(1) } - if pool.all.Get(hash) == nil { - pool.all.Add(tx) - pool.priced.Put(tx) + // If the transaction isn't in lookup set but it's expected to be there, + // show the error log. + if pool.all.Get(hash) == nil && !addAll { + log.Error("Missing transaction in lookup set, please report the issue", "hash", hash) + } + if addAll { + pool.all.Add(tx, local) + pool.priced.Put(tx, local) + } + // If we never record the heartbeat, do it right now. + if _, exist := pool.beats[from]; !exist { + pool.beats[from] = time.Now() } return old != nil, nil } @@ -803,30 +847,26 @@ func (pool *TxPool) promoteTx(addr common.Address, hash common.Hash, tx *types.T if old != nil { pool.all.Remove(old.Hash()) pool.priced.Removed(1) - pendingReplaceMeter.Mark(1) } else { // Nothing was replaced, bump the pending counter pendingGauge.Inc(1) } - // Failsafe to work around direct pending inserts (tests) - if pool.all.Get(hash) == nil { - pool.all.Add(tx) - pool.priced.Put(tx) - } // Set the potentially new pending nonce and notify any subsystems of the new tx - pool.beats[addr] = time.Now() pool.pendingNonces.set(addr, tx.Nonce()+1) + // Successful promotion, bump the heartbeat + pool.beats[addr] = time.Now() return true } -func (pool *TxPool) promoteSpecialTx(addr common.Address, tx *types.Transaction) (bool, error) { +func (pool *TxPool) promoteSpecialTx(addr common.Address, tx *types.Transaction, isLocal bool) (bool, error) { // Try to insert the transaction into the pending queue if pool.pending[addr] == nil { pool.pending[addr] = newTxList(true) } list := pool.pending[addr] + old := list.txs.Get(tx.Nonce()) if old != nil && old.IsSpecialTransaction() { return false, ErrDuplicateSpecialTransaction @@ -849,7 +889,7 @@ func (pool *TxPool) promoteSpecialTx(addr common.Address, tx *types.Transaction) } // Failsafe to work around direct pending inserts (tests) if pool.all.Get(tx.Hash()) == nil { - pool.all.Add(tx) + pool.all.Add(tx, isLocal) } // Set the potentially new pending nonce and notify any subsystems of the new tx pool.beats[addr] = time.Now() @@ -889,7 +929,7 @@ func (pool *TxPool) AddRemotesSync(txs []*types.Transaction) []error { } // This is like AddRemotes with a single transaction, but waits for pool reorganization. Tests use this method. -func (pool *TxPool) AddRemoteSync(tx *types.Transaction) error { +func (pool *TxPool) addRemoteSync(tx *types.Transaction) error { errs := pool.AddRemotesSync([]*types.Transaction{tx}) return errs[0] } @@ -905,15 +945,48 @@ func (pool *TxPool) AddRemote(tx *types.Transaction) error { // addTxs attempts to queue a batch of transactions if they are valid. func (pool *TxPool) addTxs(txs []*types.Transaction, local, sync bool) []error { - // Cache senders in transactions before obtaining lock (pool.signer is immutable) - for _, tx := range txs { - types.Sender(pool.signer, tx) + // Filter out known ones without obtaining the pool lock or recovering signatures + var ( + errs = make([]error, len(txs)) + news = make([]*types.Transaction, 0, len(txs)) + ) + for i, tx := range txs { + // If the transaction is known, pre-set the error slot + if pool.all.Get(tx.Hash()) != nil { + errs[i] = ErrAlreadyKnown + knownTxMeter.Mark(1) + continue + } + // Exclude transactions with invalid signatures as soon as + // possible and cache senders in transactions before + // obtaining lock + _, err := types.Sender(pool.signer, tx) + if err != nil { + errs[i] = ErrInvalidSender + invalidTxMeter.Mark(1) + continue + } + // Accumulate all unknown transactions for deeper processing + news = append(news, tx) + } + if len(news) == 0 { + return errs } + // Process all the new transaction and merge any errors into the original slice pool.mu.Lock() - errs, dirtyAddrs := pool.addTxsLocked(txs, local) + newErrs, dirtyAddrs := pool.addTxsLocked(news, local) pool.mu.Unlock() + var nilSlot = 0 + for _, err := range newErrs { + for errs[nilSlot] != nil { + nilSlot++ + } + errs[nilSlot] = err + nilSlot++ + } + // Reorg the pool internals if needed and return done := pool.requestPromoteExecutables(dirtyAddrs) if sync { <-done @@ -933,26 +1006,29 @@ func (pool *TxPool) addTxsLocked(txs []*types.Transaction, local bool) ([]error, dirty.addTx(tx) } } - validMeter.Mark(int64(len(dirty.accounts))) + validTxMeter.Mark(int64(len(dirty.accounts))) return errs, dirty } // Status returns the status (unknown/pending/queued) of a batch of transactions // identified by their hashes. func (pool *TxPool) Status(hashes []common.Hash) []TxStatus { - pool.mu.RLock() - defer pool.mu.RUnlock() - status := make([]TxStatus, len(hashes)) for i, hash := range hashes { - if tx := pool.all.Get(hash); tx != nil { - from, _ := types.Sender(pool.signer, tx) // already validated - if pool.pending[from] != nil && pool.pending[from].txs.items[tx.Nonce()] != nil { - status[i] = TxStatusPending - } else { - status[i] = TxStatusQueued - } + tx := pool.Get(hash) + if tx == nil { + continue } + from, _ := types.Sender(pool.signer, tx) // already validated + pool.mu.RLock() + if txList := pool.pending[from]; txList != nil && txList.txs.items[tx.Nonce()] != nil { + status[i] = TxStatusPending + } else if txList := pool.queue[from]; txList != nil && txList.txs.items[tx.Nonce()] != nil { + status[i] = TxStatusQueued + } + // implicit else: the tx may have been included into a block between + // checking pool.Get and obtaining the lock. In that case, TxStatusUnknown is correct + pool.mu.RUnlock() } return status } @@ -962,6 +1038,12 @@ func (pool *TxPool) Get(hash common.Hash) *types.Transaction { return pool.all.Get(hash) } +// Has returns an indicator whether txpool has a transaction cached with the +// given hash. +func (pool *TxPool) Has(hash common.Hash) bool { + return pool.all.Get(hash) != nil +} + // removeTx removes a single transaction from the queue, moving all subsequent // transactions back to the future queue. func (pool *TxPool) removeTx(hash common.Hash, outofbound bool) { @@ -986,11 +1068,11 @@ func (pool *TxPool) removeTx(hash common.Hash, outofbound bool) { // If no more pending transactions are left, remove the list if pending.Empty() { delete(pool.pending, addr) - delete(pool.beats, addr) } // Postpone any invalidated transactions for _, tx := range invalids { - pool.enqueueTx(tx.Hash(), tx) + // Internal shuffle shouldn't touch the lookup set. + pool.enqueueTx(tx.Hash(), tx, false, false) } // Update the account nonce if needed pool.pendingNonces.setIfLower(addr, tx.Nonce()) @@ -1007,6 +1089,7 @@ func (pool *TxPool) removeTx(hash common.Hash, outofbound bool) { } if future.Empty() { delete(pool.queue, addr) + delete(pool.beats, addr) } } } @@ -1118,7 +1201,10 @@ func (pool *TxPool) runReorg(done chan struct{}, reset *txpoolResetRequest, dirt defer close(done) var promoteAddrs []common.Address - if dirtyAccounts != nil { + if dirtyAccounts != nil && reset == nil { + // Only dirty accounts need to be promoted, unless we're resetting. + // For resets, all addresses in the tx queue will be promoted and + // the flatten operation can be avoided. promoteAddrs = dirtyAccounts.flatten() } pool.mu.Lock() @@ -1134,20 +1220,14 @@ func (pool *TxPool) runReorg(done chan struct{}, reset *txpoolResetRequest, dirt } } // Reset needs promote for all addresses - promoteAddrs = promoteAddrs[:0] + promoteAddrs = make([]common.Address, 0, len(pool.queue)) for addr := range pool.queue { promoteAddrs = append(promoteAddrs, addr) } } // Check for pending transactions for every account that sent new ones promoted := pool.promoteExecutables(promoteAddrs) - for _, tx := range promoted { - addr, _ := types.Sender(pool.signer, tx) - if _, ok := events[addr]; !ok { - events[addr] = newTxSortedMap() - } - events[addr].Put(tx) - } + // If a new block appeared, validate the pool of pending transactions. This will // remove any transaction that has been included in the block or was invalidated // because of another transaction (e.g. higher gas price). @@ -1160,12 +1240,19 @@ func (pool *TxPool) runReorg(done chan struct{}, reset *txpoolResetRequest, dirt // Update all accounts to the latest known pending nonce for addr, list := range pool.pending { - txs := list.Flatten() // Heavy but will be cached and is needed by the miner anyway - pool.pendingNonces.set(addr, txs[len(txs)-1].Nonce()+1) + highestPending := list.LastElement() + pool.pendingNonces.set(addr, highestPending.Nonce()+1) } pool.mu.Unlock() // Notify subsystems for newly added transactions + for _, tx := range promoted { + addr, _ := types.Sender(pool.signer, tx) + if _, ok := events[addr]; !ok { + events[addr] = newTxSortedMap() + } + events[addr].Put(tx) + } if len(events) > 0 { var txs []*types.Transaction for _, set := range events { @@ -1200,44 +1287,45 @@ func (pool *TxPool) reset(oldHead, newHead *types.Header) { // head from the chain. // If that is the case, we don't have the lost transactions any more, and // there's nothing to add - if newNum < oldNum { - // If the reorg ended up on a lower number, it's indicative of setHead being the cause - log.Debug("Skipping transaction reset caused by setHead", - "old", oldHead.Hash(), "oldnum", oldNum, "new", newHead.Hash(), "newnum", newNum) - } else { + if newNum >= oldNum { // If we reorged to a same or higher number, then it's not a case of setHead log.Warn("Transaction pool reset with missing oldhead", "old", oldHead.Hash(), "oldnum", oldNum, "new", newHead.Hash(), "newnum", newNum) - } - return - } - for rem.NumberU64() > add.NumberU64() { - discarded = append(discarded, rem.Transactions()...) - if rem = pool.chain.GetBlock(rem.ParentHash(), rem.NumberU64()-1); rem == nil { - log.Error("Unrooted old chain seen by tx pool", "block", oldHead.Number, "hash", oldHead.Hash()) return } - } - for add.NumberU64() > rem.NumberU64() { - included = append(included, add.Transactions()...) - if add = pool.chain.GetBlock(add.ParentHash(), add.NumberU64()-1); add == nil { - log.Error("Unrooted new chain seen by tx pool", "block", newHead.Number, "hash", newHead.Hash()) - return + // If the reorg ended up on a lower number, it's indicative of setHead being the cause + log.Debug("Skipping transaction reset caused by setHead", + "old", oldHead.Hash(), "oldnum", oldNum, "new", newHead.Hash(), "newnum", newNum) + // We still need to update the current state s.th. the lost transactions can be readded by the user + } else { + for rem.NumberU64() > add.NumberU64() { + discarded = append(discarded, rem.Transactions()...) + if rem = pool.chain.GetBlock(rem.ParentHash(), rem.NumberU64()-1); rem == nil { + log.Error("Unrooted old chain seen by tx pool", "block", oldHead.Number, "hash", oldHead.Hash()) + return + } } - } - for rem.Hash() != add.Hash() { - discarded = append(discarded, rem.Transactions()...) - if rem = pool.chain.GetBlock(rem.ParentHash(), rem.NumberU64()-1); rem == nil { - log.Error("Unrooted old chain seen by tx pool", "block", oldHead.Number, "hash", oldHead.Hash()) - return + for add.NumberU64() > rem.NumberU64() { + included = append(included, add.Transactions()...) + if add = pool.chain.GetBlock(add.ParentHash(), add.NumberU64()-1); add == nil { + log.Error("Unrooted new chain seen by tx pool", "block", newHead.Number, "hash", newHead.Hash()) + return + } } - included = append(included, add.Transactions()...) - if add = pool.chain.GetBlock(add.ParentHash(), add.NumberU64()-1); add == nil { - log.Error("Unrooted new chain seen by tx pool", "block", newHead.Number, "hash", newHead.Hash()) - return + for rem.Hash() != add.Hash() { + discarded = append(discarded, rem.Transactions()...) + if rem = pool.chain.GetBlock(rem.ParentHash(), rem.NumberU64()-1); rem == nil { + log.Error("Unrooted old chain seen by tx pool", "block", oldHead.Number, "hash", oldHead.Hash()) + return + } + included = append(included, add.Transactions()...) + if add = pool.chain.GetBlock(add.ParentHash(), add.NumberU64()-1); add == nil { + log.Error("Unrooted new chain seen by tx pool", "block", newHead.Number, "hash", newHead.Hash()) + return + } } + reinject = types.TxDifference(discarded, included) } - reinject = types.TxDifference(discarded, included) } } // Initialize the internal state to the current head @@ -1283,8 +1371,8 @@ func (pool *TxPool) promoteExecutables(accounts []common.Address) []*types.Trans for _, tx := range forwards { hash := tx.Hash() pool.all.Remove(hash) - log.Trace("Removed old queued transaction", "hash", hash) } + log.Trace("Removed old queued transactions", "count", len(forwards)) // Drop all transactions that are too costly (low balance or out of gas) var number *big.Int = nil if pool.chain.CurrentHeader() != nil { @@ -1294,8 +1382,8 @@ func (pool *TxPool) promoteExecutables(accounts []common.Address) []*types.Trans for _, tx := range drops { hash := tx.Hash() pool.all.Remove(hash) - log.Trace("Removed unpayable queued transaction", "hash", hash) } + log.Trace("Removed unpayable queued transactions", "count", len(drops)) queuedNofundsMeter.Mark(int64(len(drops))) // Gather all executable transactions and promote them @@ -1303,10 +1391,10 @@ func (pool *TxPool) promoteExecutables(accounts []common.Address) []*types.Trans for _, tx := range readies { hash := tx.Hash() if pool.promoteTx(addr, hash, tx) { - log.Trace("Promoting queued transaction", "hash", hash) promoted = append(promoted, tx) } } + log.Trace("Promoted queued transactions", "count", len(promoted)) queuedGauge.Dec(int64(len(readies))) // Drop all transactions over the allowed limit @@ -1329,6 +1417,7 @@ func (pool *TxPool) promoteExecutables(accounts []common.Address) []*types.Trans // Delete the entire queue entry if it became empty. if list.Empty() { delete(pool.queue, addr) + delete(pool.beats, addr) } } return promoted @@ -1498,7 +1587,9 @@ func (pool *TxPool) demoteUnexecutables() { for _, tx := range invalids { hash := tx.Hash() log.Trace("Demoting pending transaction", "hash", hash) - pool.enqueueTx(hash, tx) + + // Internal shuffle shouldn't touch the lookup set. + pool.enqueueTx(hash, tx, false, false) } pendingGauge.Dec(int64(len(olds) + len(drops) + len(invalids))) if pool.locals.contains(addr) { @@ -1510,14 +1601,17 @@ func (pool *TxPool) demoteUnexecutables() { for _, tx := range gapped { hash := tx.Hash() log.Warn("Demoting invalidated transaction", "hash", hash) - pool.enqueueTx(hash, tx) + + // Internal shuffle shouldn't touch the lookup set. + pool.enqueueTx(hash, tx, false, false) } pendingGauge.Dec(int64(len(gapped))) + // This might happen in a reorg, so log it to the metering + blockReorgInvalidatedTx.Mark(int64(len(gapped))) } - // Delete the entire queue entry if it became empty. + // Delete the entire pending entry if it became empty. if list.Empty() { delete(pool.pending, addr) - delete(pool.beats, addr) } } } @@ -1561,6 +1655,10 @@ func (as *accountSet) contains(addr common.Address) bool { return exist } +func (as *accountSet) empty() bool { + return len(as.accounts) == 0 +} + // containsTx checks if the sender of a given tx is within the set. If the sender // cannot be derived, this method returns false. func (as *accountSet) containsTx(tx *types.Transaction) bool { @@ -1604,8 +1702,8 @@ func (as *accountSet) merge(other *accountSet) { as.cache = nil } -// txLookup is used internally by TxPool to track transactions while allowing lookup without -// mutex contention. +// txLookup is used internally by TxPool to track transactions while allowing +// lookup without mutex contention. // // Note, although this type is properly protected against concurrent access, it // is **not** a type that should ever be mutated or even exposed outside of the @@ -1613,26 +1711,43 @@ func (as *accountSet) merge(other *accountSet) { // internal mechanisms. The sole purpose of the type is to permit out-of-bound // peeking into the pool in TxPool.Get without having to acquire the widely scoped // TxPool.mu mutex. +// +// This lookup set combines the notion of "local transactions", which is useful +// to build upper-level structure. type txLookup struct { - all map[common.Hash]*types.Transaction - lock sync.RWMutex + slots int + lock sync.RWMutex + locals map[common.Hash]*types.Transaction + remotes map[common.Hash]*types.Transaction } // newTxLookup returns a new txLookup structure. func newTxLookup() *txLookup { return &txLookup{ - all: make(map[common.Hash]*types.Transaction), + locals: make(map[common.Hash]*types.Transaction), + remotes: make(map[common.Hash]*types.Transaction), } } -// Range calls f on each key and value present in the map. -func (t *txLookup) Range(f func(hash common.Hash, tx *types.Transaction) bool) { +// Range calls f on each key and value present in the map. The callback passed +// should return the indicator whether the iteration needs to be continued. +// Callers need to specify which set (or both) to be iterated. +func (t *txLookup) Range(f func(hash common.Hash, tx *types.Transaction, local bool) bool, local bool, remote bool) { t.lock.RLock() defer t.lock.RUnlock() - for key, value := range t.all { - if !f(key, value) { - break + if local { + for key, value := range t.locals { + if !f(key, value, true) { + return + } + } + } + if remote { + for key, value := range t.remotes { + if !f(key, value, false) { + return + } } } } @@ -1642,23 +1757,73 @@ func (t *txLookup) Get(hash common.Hash) *types.Transaction { t.lock.RLock() defer t.lock.RUnlock() - return t.all[hash] + if tx := t.locals[hash]; tx != nil { + return tx + } + return t.remotes[hash] } -// Count returns the current number of items in the lookup. +// GetLocal returns a transaction if it exists in the lookup, or nil if not found. +func (t *txLookup) GetLocal(hash common.Hash) *types.Transaction { + t.lock.RLock() + defer t.lock.RUnlock() + + return t.locals[hash] +} + +// GetRemote returns a transaction if it exists in the lookup, or nil if not found. +func (t *txLookup) GetRemote(hash common.Hash) *types.Transaction { + t.lock.RLock() + defer t.lock.RUnlock() + + return t.remotes[hash] +} + +// Count returns the current number of transactions in the lookup. func (t *txLookup) Count() int { t.lock.RLock() defer t.lock.RUnlock() - return len(t.all) + return len(t.locals) + len(t.remotes) +} + +// LocalCount returns the current number of local transactions in the lookup. +func (t *txLookup) LocalCount() int { + t.lock.RLock() + defer t.lock.RUnlock() + + return len(t.locals) +} + +// RemoteCount returns the current number of remote transactions in the lookup. +func (t *txLookup) RemoteCount() int { + t.lock.RLock() + defer t.lock.RUnlock() + + return len(t.remotes) +} + +// Slots returns the current number of slots used in the lookup. +func (t *txLookup) Slots() int { + t.lock.RLock() + defer t.lock.RUnlock() + + return t.slots } // Add adds a transaction to the lookup. -func (t *txLookup) Add(tx *types.Transaction) { +func (t *txLookup) Add(tx *types.Transaction, local bool) { t.lock.Lock() defer t.lock.Unlock() - t.all[tx.Hash()] = tx + t.slots += numSlots(tx) + slotsGauge.Update(int64(t.slots)) + + if local { + t.locals[tx.Hash()] = tx + } else { + t.remotes[tx.Hash()] = tx + } } // Remove removes a transaction from the lookup. @@ -1666,5 +1831,39 @@ func (t *txLookup) Remove(hash common.Hash) { t.lock.Lock() defer t.lock.Unlock() - delete(t.all, hash) + tx, ok := t.locals[hash] + if !ok { + tx, ok = t.remotes[hash] + } + if !ok { + log.Error("No transaction found to be deleted", "hash", hash) + return + } + t.slots -= numSlots(tx) + slotsGauge.Update(int64(t.slots)) + + delete(t.locals, hash) + delete(t.remotes, hash) +} + +// RemoteToLocals migrates the transactions belongs to the given locals to locals +// set. The assumption is held the locals set is thread-safe to be used. +func (t *txLookup) RemoteToLocals(locals *accountSet) int { + t.lock.Lock() + defer t.lock.Unlock() + + var migrated int + for hash, tx := range t.remotes { + if locals.containsTx(tx) { + t.locals[hash] = tx + delete(t.remotes, hash) + migrated += 1 + } + } + return migrated +} + +// numSlots calculates the number of slots needed for a single transaction. +func numSlots(tx *types.Transaction) int { + return int((tx.Size() + txSlotSize - 1) / txSlotSize) } diff --git a/core/tx_pool_test.go b/core/tx_pool_test.go index b494d93ce5..dbbdd24baa 100644 --- a/core/tx_pool_test.go +++ b/core/tx_pool_test.go @@ -94,10 +94,18 @@ func pricedTransaction(nonce uint64, gaslimit uint64, gasprice *big.Int, key *ec return tx } +func pricedDataTransaction(nonce uint64, gaslimit uint64, gasprice *big.Int, key *ecdsa.PrivateKey, bytes uint64) *types.Transaction { + data := make([]byte, bytes) + rand.Read(data) + + tx, _ := types.SignTx(types.NewTransaction(nonce, common.Address{}, big.NewInt(0), gaslimit, gasprice, data), types.HomesteadSigner{}, key) + return tx +} + func setupTxPool() (*TxPool, *ecdsa.PrivateKey) { diskdb := rawdb.NewMemoryDatabase() statedb, _ := state.New(common.Hash{}, state.NewDatabase(diskdb)) - blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)} + blockchain := &testBlockChain{statedb, 10000000, new(event.Feed)} key, _ := crypto.GenerateKey() pool := NewTxPool(testTxPoolConfig, params.TestChainConfig, blockchain) @@ -115,8 +123,10 @@ func validateTxPoolInternals(pool *TxPool) error { if total := pool.all.Count(); total != pending+queued { return fmt.Errorf("total transaction count %d != %d pending + %d queued", total, pending, queued) } - if priced := pool.priced.items.Len() - pool.priced.stales; priced != pending+queued { - return fmt.Errorf("total priced transaction count %d != %d pending + %d queued", priced, pending, queued) + pool.priced.Reheap() + priced, remote := pool.priced.remotes.Len(), pool.all.RemoteCount() + if priced != remote { + return fmt.Errorf("total priced transaction count %d != %d", priced, remote) } // Ensure the next nonce to assign is the correct one for addr, txs := range pool.pending { @@ -289,7 +299,7 @@ func TestTransactionQueue(t *testing.T) { pool.currentState.AddBalance(from, big.NewInt(1000)) <-pool.requestReset(nil, nil) - pool.enqueueTx(tx.Hash(), tx) + pool.enqueueTx(tx.Hash(), tx, false, true) <-pool.requestPromoteExecutables(newAccountSet(pool.signer, from)) if len(pool.pending) != 1 { t.Error("expected valid txs to be 1 is", len(pool.pending)) @@ -298,7 +308,7 @@ func TestTransactionQueue(t *testing.T) { tx = transaction(1, 100, key) from, _ = deriveSender(tx) pool.currentState.SetNonce(from, 2) - pool.enqueueTx(tx.Hash(), tx) + pool.enqueueTx(tx.Hash(), tx, false, true) <-pool.requestPromoteExecutables(newAccountSet(pool.signer, from)) if _, ok := pool.pending[from].txs.items[tx.Nonce()]; ok { @@ -323,9 +333,9 @@ func TestTransactionQueue2(t *testing.T) { pool.currentState.AddBalance(from, big.NewInt(1000)) pool.reset(nil, nil) - pool.enqueueTx(tx1.Hash(), tx1) - pool.enqueueTx(tx2.Hash(), tx2) - pool.enqueueTx(tx3.Hash(), tx3) + pool.enqueueTx(tx1.Hash(), tx1, false, true) + pool.enqueueTx(tx2.Hash(), tx2, false, true) + pool.enqueueTx(tx3.Hash(), tx3, false, true) pool.promoteExecutables([]common.Address{from}) if len(pool.pending) != 1 { @@ -488,7 +498,7 @@ func TestTransactionDropping(t *testing.T) { pool, key := setupTxPool() defer pool.Stop() - account, _ := deriveSender(transaction(0, 0, key)) + account := crypto.PubkeyToAddress(key.PublicKey) pool.currentState.AddBalance(account, big.NewInt(1000)) // Add some pending and some queued transactions @@ -500,12 +510,21 @@ func TestTransactionDropping(t *testing.T) { tx11 = transaction(11, 200, key) tx12 = transaction(12, 300, key) ) + pool.all.Add(tx0, false) + pool.priced.Put(tx0, false) pool.promoteTx(account, tx0.Hash(), tx0) + + pool.all.Add(tx1, false) + pool.priced.Put(tx1, false) pool.promoteTx(account, tx1.Hash(), tx1) + + pool.all.Add(tx2, false) + pool.priced.Put(tx2, false) pool.promoteTx(account, tx2.Hash(), tx2) - pool.enqueueTx(tx10.Hash(), tx10) - pool.enqueueTx(tx11.Hash(), tx11) - pool.enqueueTx(tx12.Hash(), tx12) + + pool.enqueueTx(tx10.Hash(), tx10, false, true) + pool.enqueueTx(tx11.Hash(), tx11, false, true) + pool.enqueueTx(tx12.Hash(), tx12, false, true) // Check that pre and post validations leave the pool as is if pool.pending[account].Len() != 3 { @@ -698,7 +717,7 @@ func TestTransactionGapFilling(t *testing.T) { pool, key := setupTxPool() defer pool.Stop() - account, _ := deriveSender(transaction(0, 0, key)) + account := crypto.PubkeyToAddress(key.PublicKey) pool.currentState.AddBalance(account, big.NewInt(1000000)) // Keep track of transaction events to ensure all executables get announced @@ -725,7 +744,7 @@ func TestTransactionGapFilling(t *testing.T) { t.Fatalf("pool internal state corrupted: %v", err) } // Fill the nonce gap and ensure all transactions become pending - if err := pool.AddRemoteSync(transaction(1, 100000, key)); err != nil { + if err := pool.addRemoteSync(transaction(1, 100000, key)); err != nil { t.Fatalf("failed to add gapped transaction: %v", err) } pending, queued = pool.Stats() @@ -752,12 +771,12 @@ func TestTransactionQueueAccountLimiting(t *testing.T) { pool, key := setupTxPool() defer pool.Stop() - account, _ := deriveSender(transaction(0, 0, key)) + account := crypto.PubkeyToAddress(key.PublicKey) pool.currentState.AddBalance(account, big.NewInt(1000000)) testTxPoolConfig.AccountQueue = 10 // Keep queuing up transactions and make sure all above a limit are dropped for i := uint64(1); i <= testTxPoolConfig.AccountQueue; i++ { - if err := pool.AddRemoteSync(transaction(i, 100000, key)); err != nil { + if err := pool.addRemoteSync(transaction(i, 100000, key)); err != nil { t.Fatalf("tx %d: failed to add transaction: %v", i, err) } if len(pool.pending) != 0 { @@ -884,7 +903,7 @@ func testTransactionQueueTimeLimiting(t *testing.T, nolocals bool) { common.MinGasPrice = big.NewInt(0) // Reduce the eviction interval to a testable amount defer func(old time.Duration) { evictionInterval = old }(evictionInterval) - evictionInterval = time.Second + evictionInterval = time.Millisecond * 100 // Create the pool to test the non-expiration enforcement db := rawdb.NewMemoryDatabase() @@ -922,6 +941,22 @@ func testTransactionQueueTimeLimiting(t *testing.T, nolocals bool) { if err := validateTxPoolInternals(pool); err != nil { t.Fatalf("pool internal state corrupted: %v", err) } + + // Allow the eviction interval to run + time.Sleep(2 * evictionInterval) + + // Transactions should not be evicted from the queue yet since lifetime duration has not passed + pending, queued = pool.Stats() + if pending != 0 { + t.Fatalf("pending transactions mismatched: have %d, want %d", pending, 0) + } + if queued != 2 { + t.Fatalf("queued transactions mismatched: have %d, want %d", queued, 2) + } + if err := validateTxPoolInternals(pool); err != nil { + t.Fatalf("pool internal state corrupted: %v", err) + } + // Wait a bit for eviction to run and clean up any leftovers, and ensure only the local remains time.Sleep(2 * config.Lifetime) @@ -941,6 +976,72 @@ func testTransactionQueueTimeLimiting(t *testing.T, nolocals bool) { if err := validateTxPoolInternals(pool); err != nil { t.Fatalf("pool internal state corrupted: %v", err) } + + // remove current transactions and increase nonce to prepare for a reset and cleanup + statedb.SetNonce(crypto.PubkeyToAddress(remote.PublicKey), 2) + statedb.SetNonce(crypto.PubkeyToAddress(local.PublicKey), 2) + <-pool.requestReset(nil, nil) + + // make sure queue, pending are cleared + pending, queued = pool.Stats() + if pending != 0 { + t.Fatalf("pending transactions mismatched: have %d, want %d", pending, 0) + } + if queued != 0 { + t.Fatalf("queued transactions mismatched: have %d, want %d", queued, 0) + } + if err := validateTxPoolInternals(pool); err != nil { + t.Fatalf("pool internal state corrupted: %v", err) + } + + // Queue gapped transactions + if err := pool.AddLocal(pricedTransaction(4, 100000, big.NewInt(1), local)); err != nil { + t.Fatalf("failed to add remote transaction: %v", err) + } + if err := pool.addRemoteSync(pricedTransaction(4, 100000, big.NewInt(1), remote)); err != nil { + t.Fatalf("failed to add remote transaction: %v", err) + } + time.Sleep(5 * evictionInterval) // A half lifetime pass + + // Queue executable transactions, the life cycle should be restarted. + if err := pool.AddLocal(pricedTransaction(2, 100000, big.NewInt(1), local)); err != nil { + t.Fatalf("failed to add remote transaction: %v", err) + } + if err := pool.addRemoteSync(pricedTransaction(2, 100000, big.NewInt(1), remote)); err != nil { + t.Fatalf("failed to add remote transaction: %v", err) + } + time.Sleep(6 * evictionInterval) + + // All gapped transactions shouldn't be kicked out + pending, queued = pool.Stats() + if pending != 2 { + t.Fatalf("pending transactions mismatched: have %d, want %d", pending, 2) + } + if queued != 2 { + t.Fatalf("queued transactions mismatched: have %d, want %d", queued, 3) + } + if err := validateTxPoolInternals(pool); err != nil { + t.Fatalf("pool internal state corrupted: %v", err) + } + + // The whole life time pass after last promotion, kick out stale transactions + time.Sleep(2 * config.Lifetime) + pending, queued = pool.Stats() + if pending != 2 { + t.Fatalf("pending transactions mismatched: have %d, want %d", pending, 2) + } + if nolocals { + if queued != 0 { + t.Fatalf("queued transactions mismatched: have %d, want %d", queued, 0) + } + } else { + if queued != 1 { + t.Fatalf("queued transactions mismatched: have %d, want %d", queued, 1) + } + } + if err := validateTxPoolInternals(pool); err != nil { + t.Fatalf("pool internal state corrupted: %v", err) + } } // Tests that even if the transaction count belonging to a single account goes @@ -953,7 +1054,7 @@ func TestTransactionPendingLimiting(t *testing.T) { pool, key := setupTxPool() defer pool.Stop() - account, _ := deriveSender(transaction(0, 0, key)) + account := crypto.PubkeyToAddress(key.PublicKey) pool.currentState.AddBalance(account, big.NewInt(1000000)) testTxPoolConfig.AccountQueue = 10 // Keep track of transaction events to ensure all executables get announced @@ -963,7 +1064,7 @@ func TestTransactionPendingLimiting(t *testing.T) { // Keep queuing up transactions and make sure all above a limit are dropped for i := uint64(0); i < testTxPoolConfig.AccountQueue; i++ { - if err := pool.AddRemoteSync(transaction(i, 100000, key)); err != nil { + if err := pool.addRemoteSync(transaction(i, 100000, key)); err != nil { t.Fatalf("tx %d: failed to add transaction: %v", i, err) } if pool.pending[account].Len() != int(i)+1 { @@ -1033,6 +1134,62 @@ func TestTransactionPendingGlobalLimiting(t *testing.T) { } } +// Test the limit on transaction size is enforced correctly. +// This test verifies every transaction having allowed size +// is added to the pool, and longer transactions are rejected. +func TestTransactionAllowedTxSize(t *testing.T) { + t.Parallel() + + // Create a test account and fund it + pool, key := setupTxPool() + defer pool.Stop() + + account := crypto.PubkeyToAddress(key.PublicKey) + pool.currentState.AddBalance(account, big.NewInt(1000000000)) + + // Compute maximal data size for transactions (lower bound). + // + // It is assumed the fields in the transaction (except of the data) are: + // - nonce <= 32 bytes + // - gasPrice <= 32 bytes + // - gasLimit <= 32 bytes + // - recipient == 20 bytes + // - value <= 32 bytes + // - signature == 65 bytes + // All those fields are summed up to at most 213 bytes. + baseSize := uint64(213) + dataSize := txMaxSize - baseSize + + // Try adding a transaction with maximal allowed size + tx := pricedDataTransaction(0, pool.currentMaxGas, big.NewInt(1), key, dataSize) + if err := pool.addRemoteSync(tx); err != nil { + t.Fatalf("failed to add transaction of size %d, close to maximal: %v", int(tx.Size()), err) + } + // Try adding a transaction with random allowed size + if err := pool.addRemoteSync(pricedDataTransaction(1, pool.currentMaxGas, big.NewInt(1), key, uint64(rand.Intn(int(dataSize))))); err != nil { + t.Fatalf("failed to add transaction of random allowed size: %v", err) + } + // Try adding a transaction of minimal not allowed size + if err := pool.addRemoteSync(pricedDataTransaction(2, pool.currentMaxGas, big.NewInt(1), key, txMaxSize)); err == nil { + t.Fatalf("expected rejection on slightly oversize transaction") + } + // Try adding a transaction of random not allowed size + if err := pool.addRemoteSync(pricedDataTransaction(2, pool.currentMaxGas, big.NewInt(1), key, dataSize+1+uint64(rand.Intn(int(10*txMaxSize))))); err == nil { + t.Fatalf("expected rejection on oversize transaction") + } + // Run some sanity checks on the pool internals + pending, queued := pool.Stats() + if pending != 2 { + t.Fatalf("pending transactions mismatched: have %d, want %d", pending, 2) + } + if queued != 0 { + t.Fatalf("queued transactions mismatched: have %d, want %d", queued, 0) + } + if err := validateTxPoolInternals(pool); err != nil { + t.Fatalf("pool internal state corrupted: %v", err) + } +} + // Tests that if transactions start being capped, transactions are also removed from 'all' func TestTransactionCapClearsFromAll(t *testing.T) { t.Parallel() @@ -1458,7 +1615,7 @@ func TestTransactionPoolStableUnderpricing(t *testing.T) { t.Fatalf("pool internal state corrupted: %v", err) } // Ensure that adding high priced transactions drops a cheap, but doesn't produce a gap - if err := pool.AddRemoteSync(pricedTransaction(0, 100000, big.NewInt(3), keys[1])); err != nil { + if err := pool.addRemoteSync(pricedTransaction(0, 100000, big.NewInt(3), keys[1])); err != nil { t.Fatalf("failed to add well priced transaction: %v", err) } pending, queued = pool.Stats() @@ -1476,6 +1633,71 @@ func TestTransactionPoolStableUnderpricing(t *testing.T) { } } +// Tests that the pool rejects duplicate transactions. +func TestTransactionDeduplication(t *testing.T) { + t.Parallel() + + // Create the pool to test the pricing enforcement with + statedb, _ := state.New(common.Hash{}, state.NewDatabase(rawdb.NewMemoryDatabase())) + blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)} + + pool := NewTxPool(testTxPoolConfig, params.TestChainConfig, blockchain) + defer pool.Stop() + + // Create a test account to add transactions with + key, _ := crypto.GenerateKey() + pool.currentState.AddBalance(crypto.PubkeyToAddress(key.PublicKey), big.NewInt(1000000000)) + + // Create a batch of transactions and add a few of them + txs := make([]*types.Transaction, common.LimitThresholdNonceInQueue) + for i := 0; i < len(txs); i++ { + txs[i] = pricedTransaction(uint64(i), 100000, big.NewInt(1), key) + } + var firsts []*types.Transaction + for i := 0; i < len(txs); i += 2 { + firsts = append(firsts, txs[i]) + } + errs := pool.AddRemotesSync(firsts) + if len(errs) != len(firsts) { + t.Fatalf("first add mismatching result count: have %d, want %d", len(errs), len(firsts)) + } + for i, err := range errs { + if err != nil { + t.Errorf("add %d failed: %v", i, err) + } + } + pending, queued := pool.Stats() + if pending != 1 { + t.Fatalf("pending transactions mismatched: have %d, want %d", pending, 1) + } + if queued != len(txs)/2-1 { + t.Fatalf("queued transactions mismatched: have %d, want %d", queued, len(txs)/2-1) + } + // Try to add all of them now and ensure previous ones error out as knowns + errs = pool.AddRemotesSync(txs) + if len(errs) != len(txs) { + t.Fatalf("all add mismatching result count: have %d, want %d", len(errs), len(txs)) + } + for i, err := range errs { + if i%2 == 0 && err == nil { + t.Errorf("add %d succeeded, should have failed as known", i) + } + if i%2 == 1 && err != nil { + t.Errorf("add %d failed: %v", i, err) + } + } + pending, queued = pool.Stats() + if pending != len(txs) { + t.Fatalf("pending transactions mismatched: have %d, want %d", pending, len(txs)) + } + if queued != 0 { + t.Fatalf("queued transactions mismatched: have %d, want %d", queued, 0) + } + if err := validateTxPoolInternals(pool); err != nil { + t.Fatalf("pool internal state corrupted: %v", err) + } +} + // Tests that the pool rejects replacement transactions that don't meet the minimum // price bump required. func TestTransactionReplacement(t *testing.T) { @@ -1502,7 +1724,7 @@ func TestTransactionReplacement(t *testing.T) { price := int64(100) threshold := (price * (100 + int64(testTxPoolConfig.PriceBump))) / 100 - if err := pool.AddRemoteSync(pricedTransaction(0, 100000, big.NewInt(1), key)); err != nil { + if err := pool.addRemoteSync(pricedTransaction(0, 100000, big.NewInt(1), key)); err != nil { t.Fatalf("failed to add original cheap pending transaction: %v", err) } if err := pool.AddRemote(pricedTransaction(0, 100001, big.NewInt(1), key)); err != ErrReplaceUnderpriced { @@ -1515,7 +1737,7 @@ func TestTransactionReplacement(t *testing.T) { t.Fatalf("cheap replacement event firing failed: %v", err) } - if err := pool.AddRemoteSync(pricedTransaction(0, 100000, big.NewInt(price), key)); err != nil { + if err := pool.addRemoteSync(pricedTransaction(0, 100000, big.NewInt(price), key)); err != nil { t.Fatalf("failed to add original proper pending transaction: %v", err) } if err := pool.AddRemote(pricedTransaction(0, 100001, big.NewInt(threshold-1), key)); err != ErrReplaceUnderpriced { @@ -1606,7 +1828,7 @@ func testTransactionJournaling(t *testing.T, nolocals bool) { if err := pool.AddLocal(pricedTransaction(2, 100000, big.NewInt(1), local)); err != nil { t.Fatalf("failed to add local transaction: %v", err) } - if err := pool.AddRemoteSync(pricedTransaction(0, 100000, big.NewInt(1), remote)); err != nil { + if err := pool.addRemoteSync(pricedTransaction(0, 100000, big.NewInt(1), remote)); err != nil { t.Fatalf("failed to add remote transaction: %v", err) } pending, queued := pool.Stats() @@ -1728,6 +1950,24 @@ func TestTransactionStatusCheck(t *testing.T) { } } +// Test the transaction slots consumption is computed correctly +func TestTransactionSlotCount(t *testing.T) { + t.Parallel() + + key, _ := crypto.GenerateKey() + + // Check that an empty transaction consumes a single slot + smallTx := pricedDataTransaction(0, 0, big.NewInt(0), key, 0) + if slots := numSlots(smallTx); slots != 1 { + t.Fatalf("small transactions slot count mismatch: have %d want %d", slots, 1) + } + // Check that a large transaction consumes the correct number of slots + bigTx := pricedDataTransaction(0, 0, big.NewInt(0), key, uint64(10*txSlotSize)) + if slots := numSlots(bigTx); slots != 11 { + t.Fatalf("big transactions slot count mismatch: have %d want %d", slots, 11) + } +} + // Benchmarks the speed of validating the contents of the pending queue of the // transaction pool. func BenchmarkPendingDemotion100(b *testing.B) { benchmarkPendingDemotion(b, 100) } @@ -1739,7 +1979,7 @@ func benchmarkPendingDemotion(b *testing.B, size int) { pool, key := setupTxPool() defer pool.Stop() - account, _ := deriveSender(transaction(0, 0, key)) + account := crypto.PubkeyToAddress(key.PublicKey) pool.currentState.AddBalance(account, big.NewInt(1000000)) for i := 0; i < size; i++ { @@ -1764,12 +2004,12 @@ func benchmarkFuturePromotion(b *testing.B, size int) { pool, key := setupTxPool() defer pool.Stop() - account, _ := deriveSender(transaction(0, 0, key)) + account := crypto.PubkeyToAddress(key.PublicKey) pool.currentState.AddBalance(account, big.NewInt(1000000)) for i := 0; i < size; i++ { tx := transaction(uint64(1+i), 100000, key) - pool.enqueueTx(tx.Hash(), tx) + pool.enqueueTx(tx.Hash(), tx, false, true) } // Benchmark the speed of pool validation b.ResetTimer() @@ -1779,16 +2019,20 @@ func benchmarkFuturePromotion(b *testing.B, size int) { } // Benchmarks the speed of batched transaction insertion. -func BenchmarkPoolBatchInsert100(b *testing.B) { benchmarkPoolBatchInsert(b, 100) } -func BenchmarkPoolBatchInsert1000(b *testing.B) { benchmarkPoolBatchInsert(b, 1000) } -func BenchmarkPoolBatchInsert10000(b *testing.B) { benchmarkPoolBatchInsert(b, 10000) } +func BenchmarkPoolBatchInsert100(b *testing.B) { benchmarkPoolBatchInsert(b, 100, false) } +func BenchmarkPoolBatchInsert1000(b *testing.B) { benchmarkPoolBatchInsert(b, 1000, false) } +func BenchmarkPoolBatchInsert10000(b *testing.B) { benchmarkPoolBatchInsert(b, 10000, false) } -func benchmarkPoolBatchInsert(b *testing.B, size int) { +func BenchmarkPoolBatchLocalInsert100(b *testing.B) { benchmarkPoolBatchInsert(b, 100, true) } +func BenchmarkPoolBatchLocalInsert1000(b *testing.B) { benchmarkPoolBatchInsert(b, 1000, true) } +func BenchmarkPoolBatchLocalInsert10000(b *testing.B) { benchmarkPoolBatchInsert(b, 10000, true) } + +func benchmarkPoolBatchInsert(b *testing.B, size int, local bool) { // Generate a batch of transactions to enqueue into the pool pool, key := setupTxPool() defer pool.Stop() - account, _ := deriveSender(transaction(0, 0, key)) + account := crypto.PubkeyToAddress(key.PublicKey) pool.currentState.AddBalance(account, big.NewInt(1000000)) batches := make([]types.Transactions, b.N) @@ -1801,6 +2045,45 @@ func benchmarkPoolBatchInsert(b *testing.B, size int) { // Benchmark importing the transactions into the queue b.ResetTimer() for _, batch := range batches { - pool.AddRemotes(batch) + if local { + pool.AddLocals(batch) + } else { + pool.AddRemotes(batch) + } + } +} + +func BenchmarkInsertRemoteWithAllLocals(b *testing.B) { + // Allocate keys for testing + key, _ := crypto.GenerateKey() + account := crypto.PubkeyToAddress(key.PublicKey) + + remoteKey, _ := crypto.GenerateKey() + remoteAddr := crypto.PubkeyToAddress(remoteKey.PublicKey) + + locals := make([]*types.Transaction, 4096+1024) // Occupy all slots + for i := 0; i < len(locals); i++ { + locals[i] = transaction(uint64(i), 100000, key) + } + remotes := make([]*types.Transaction, 1000) + for i := 0; i < len(remotes); i++ { + remotes[i] = pricedTransaction(uint64(i), 100000, big.NewInt(2), remoteKey) // Higher gasprice + } + // Benchmark importing the transactions into the queue + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StopTimer() + pool, _ := setupTxPool() + pool.currentState.AddBalance(account, big.NewInt(100000000)) + for _, local := range locals { + pool.AddLocal(local) + } + b.StartTimer() + // Assign a high enough balance for testing + pool.currentState.AddBalance(remoteAddr, big.NewInt(100000000)) + for i := 0; i < len(remotes); i++ { + pool.AddRemotes([]*types.Transaction{remotes[i]}) + } + pool.Stop() } } diff --git a/core/types/block.go b/core/types/block.go index cbeb565373..042db62323 100644 --- a/core/types/block.go +++ b/core/types/block.go @@ -23,14 +23,16 @@ import ( "io" "math/big" "sort" + "sync" "sync/atomic" "time" "unsafe" "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/common/hexutil" - "github.com/XinFinOrg/XDPoSChain/crypto/sha3" + "github.com/XinFinOrg/XDPoSChain/crypto" "github.com/XinFinOrg/XDPoSChain/rlp" + "golang.org/x/crypto/sha3" ) var ( @@ -155,10 +157,19 @@ func (h *Header) Size() common.StorageSize { return common.StorageSize(unsafe.Sizeof(*h)) + common.StorageSize(len(h.Extra)+(h.Difficulty.BitLen()+h.Number.BitLen()+h.Time.BitLen())/8) } +// hasherPool holds LegacyKeccak hashers. +var hasherPool = sync.Pool{ + New: func() interface{} { + return sha3.NewLegacyKeccak256() + }, +} + func rlpHash(x interface{}) (h common.Hash) { - hw := sha3.NewKeccak256() - rlp.Encode(hw, x) - hw.Sum(h[:0]) + sha := hasherPool.Get().(crypto.KeccakState) + defer hasherPool.Put(sha) + sha.Reset() + rlp.Encode(sha, x) + sha.Read(h[:]) return h } diff --git a/core/types/transaction.go b/core/types/transaction.go index 03f1bbe39a..f5fbad376a 100644 --- a/core/types/transaction.go +++ b/core/types/transaction.go @@ -188,9 +188,15 @@ func (tx *Transaction) UnmarshalJSON(input []byte) error { func (tx *Transaction) Data() []byte { return common.CopyBytes(tx.data.Payload) } func (tx *Transaction) Gas() uint64 { return tx.data.GasLimit } func (tx *Transaction) GasPrice() *big.Int { return new(big.Int).Set(tx.data.Price) } -func (tx *Transaction) Value() *big.Int { return new(big.Int).Set(tx.data.Amount) } -func (tx *Transaction) Nonce() uint64 { return tx.data.AccountNonce } -func (tx *Transaction) CheckNonce() bool { return true } +func (tx *Transaction) GasPriceCmp(other *Transaction) int { + return tx.data.Price.Cmp(other.data.Price) +} +func (tx *Transaction) GasPriceIntCmp(other *big.Int) int { + return tx.data.Price.Cmp(other) +} +func (tx *Transaction) Value() *big.Int { return new(big.Int).Set(tx.data.Amount) } +func (tx *Transaction) Nonce() uint64 { return tx.data.AccountNonce } +func (tx *Transaction) CheckNonce() bool { return true } // To returns the recipient address of the transaction. // It returns nil if the transaction is a contract creation. diff --git a/crypto/crypto.go b/crypto/crypto.go index 2213bf0c14..745f9c909f 100644 --- a/crypto/crypto.go +++ b/crypto/crypto.go @@ -17,57 +17,81 @@ package crypto import ( + "bufio" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" "encoding/hex" "errors" "fmt" + "hash" "io" + "io/ioutil" "math/big" "os" "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/common/math" - "github.com/XinFinOrg/XDPoSChain/crypto/sha3" "github.com/XinFinOrg/XDPoSChain/rlp" + "golang.org/x/crypto/sha3" ) +//SignatureLength indicates the byte length required to carry a signature with recovery id. +const SignatureLength = 64 + 1 // 64 bytes ECDSA signature + 1 byte recovery id + +// RecoveryIDOffset points to the byte offset within the signature that contains the recovery id. +const RecoveryIDOffset = 64 + +// DigestLength sets the signature digest exact length +const DigestLength = 32 + var ( - secp256k1_N, _ = new(big.Int).SetString("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141", 16) - secp256k1_halfN = new(big.Int).Div(secp256k1_N, big.NewInt(2)) + secp256k1N, _ = new(big.Int).SetString("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141", 16) + secp256k1halfN = new(big.Int).Div(secp256k1N, big.NewInt(2)) ) +var errInvalidPubkey = errors.New("invalid secp256k1 public key") + +// KeccakState wraps sha3.state. In addition to the usual hash methods, it also supports +// Read to get a variable amount of data from the hash state. Read is faster than Sum +// because it doesn't copy the internal state, but also modifies the internal state. +type KeccakState interface { + hash.Hash + Read([]byte) (int, error) +} + // Keccak256 calculates and returns the Keccak256 hash of the input data. func Keccak256(data ...[]byte) []byte { - d := sha3.NewKeccak256() + b := make([]byte, 32) + d := sha3.NewLegacyKeccak256().(KeccakState) for _, b := range data { d.Write(b) } - return d.Sum(nil) + d.Read(b) + return b } // Keccak256Hash calculates and returns the Keccak256 hash of the input data, // converting it to an internal Hash data structure. func Keccak256Hash(data ...[]byte) (h common.Hash) { - d := sha3.NewKeccak256() + d := sha3.NewLegacyKeccak256().(KeccakState) for _, b := range data { d.Write(b) } - d.Sum(h[:0]) + d.Read(h[:]) return h } // Keccak512 calculates and returns the Keccak512 hash of the input data. func Keccak512(data ...[]byte) []byte { - d := sha3.NewKeccak512() + d := sha3.NewLegacyKeccak512() for _, b := range data { d.Write(b) } return d.Sum(nil) } -// Creates an ethereum address given the bytes and the nonce +// CreateAddress creates an ethereum address given the bytes and the nonce func CreateAddress(b common.Address, nonce uint64) common.Address { data, _ := rlp.EncodeToBytes([]interface{}{b, nonce}) return common.BytesToAddress(Keccak256(data)[12:]) @@ -104,7 +128,7 @@ func toECDSA(d []byte, strict bool) (*ecdsa.PrivateKey, error) { priv.D = new(big.Int).SetBytes(d) // The priv.D must < N - if priv.D.Cmp(secp256k1_N) >= 0 { + if priv.D.Cmp(secp256k1N) >= 0 { return nil, fmt.Errorf("invalid private key, >=N") } // The priv.D must not be zero or negative. @@ -127,12 +151,13 @@ func FromECDSA(priv *ecdsa.PrivateKey) []byte { return math.PaddedBigBytes(priv.D, priv.Params().BitSize/8) } -func ToECDSAPub(pub []byte) *ecdsa.PublicKey { - if len(pub) == 0 { - return nil - } +// UnmarshalPubkey converts bytes to a secp256k1 public key. +func UnmarshalPubkey(pub []byte) (*ecdsa.PublicKey, error) { x, y := elliptic.Unmarshal(S256(), pub) - return &ecdsa.PublicKey{Curve: S256(), X: x, Y: y} + if x == nil { + return nil, errInvalidPubkey + } + return &ecdsa.PublicKey{Curve: S256(), X: x, Y: y}, nil } func FromECDSAPub(pub *ecdsa.PublicKey) []byte { @@ -145,38 +170,77 @@ func FromECDSAPub(pub *ecdsa.PublicKey) []byte { // HexToECDSA parses a secp256k1 private key. func HexToECDSA(hexkey string) (*ecdsa.PrivateKey, error) { b, err := hex.DecodeString(hexkey) - if err != nil { - return nil, errors.New("invalid hex string") + if byteErr, ok := err.(hex.InvalidByteError); ok { + return nil, fmt.Errorf("invalid hex character %q in private key", byte(byteErr)) + } else if err != nil { + return nil, errors.New("invalid hex data for private key") } return ToECDSA(b) } // LoadECDSA loads a secp256k1 private key from the given file. func LoadECDSA(file string) (*ecdsa.PrivateKey, error) { - buf := make([]byte, 64) fd, err := os.Open(file) if err != nil { return nil, err } defer fd.Close() - if _, err := io.ReadFull(fd, buf); err != nil { + + r := bufio.NewReader(fd) + buf := make([]byte, 64) + n, err := readASCII(buf, r) + if err != nil { + return nil, err + } else if n != len(buf) { + return nil, fmt.Errorf("key file too short, want 64 hex characters") + } + if err := checkKeyFileEnd(r); err != nil { return nil, err } - key, err := hex.DecodeString(string(buf)) - if err != nil { - return nil, err + return HexToECDSA(string(buf)) +} + +// readASCII reads into 'buf', stopping when the buffer is full or +// when a non-printable control character is encountered. +func readASCII(buf []byte, r *bufio.Reader) (n int, err error) { + for ; n < len(buf); n++ { + buf[n], err = r.ReadByte() + switch { + case err == io.EOF || buf[n] < '!': + return n, nil + case err != nil: + return n, err + } + } + return n, nil +} + +// checkKeyFileEnd skips over additional newlines at the end of a key file. +func checkKeyFileEnd(r *bufio.Reader) error { + for i := 0; ; i++ { + b, err := r.ReadByte() + switch { + case err == io.EOF: + return nil + case err != nil: + return err + case b != '\n' && b != '\r': + return fmt.Errorf("invalid character %q at end of key file", b) + case i >= 2: + return errors.New("key file too long, want 64 hex characters") + } } - return ToECDSA(key) } // SaveECDSA saves a secp256k1 private key to the given file with // restrictive permissions. The key data is saved hex-encoded. func SaveECDSA(file string, key *ecdsa.PrivateKey) error { k := hex.EncodeToString(FromECDSA(key)) - return os.WriteFile(file, []byte(k), 0600) + return ioutil.WriteFile(file, []byte(k), 0600) } +// GenerateKey generates a new private key. func GenerateKey() (*ecdsa.PrivateKey, error) { return ecdsa.GenerateKey(S256(), rand.Reader) } @@ -189,11 +253,11 @@ func ValidateSignatureValues(v byte, r, s *big.Int, homestead bool) bool { } // reject upper range of s values (ECDSA malleability) // see discussion in secp256k1/libsecp256k1/include/secp256k1.h - if homestead && s.Cmp(secp256k1_halfN) > 0 { + if homestead && s.Cmp(secp256k1halfN) > 0 { return false } // Frontier: allow s to be in full N range - return r.Cmp(secp256k1_N) < 0 && s.Cmp(secp256k1_N) < 0 && (v == 0 || v == 1) + return r.Cmp(secp256k1N) < 0 && s.Cmp(secp256k1N) < 0 && (v == 0 || v == 1) } func PubkeyToAddress(p ecdsa.PublicKey) common.Address { diff --git a/crypto/crypto_test.go b/crypto/crypto_test.go index f1910e5c26..9e1bb2639b 100644 --- a/crypto/crypto_test.go +++ b/crypto/crypto_test.go @@ -20,11 +20,14 @@ import ( "bytes" "crypto/ecdsa" "encoding/hex" + "io/ioutil" "math/big" "os" + "reflect" "testing" "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/common/hexutil" ) var testAddrHex = "970e8128ab834e8eac17ab8e3812f010678cf791" @@ -55,6 +58,33 @@ func BenchmarkSha3(b *testing.B) { } } +func TestUnmarshalPubkey(t *testing.T) { + key, err := UnmarshalPubkey(nil) + if err != errInvalidPubkey || key != nil { + t.Fatalf("expected error, got %v, %v", err, key) + } + key, err = UnmarshalPubkey([]byte{1, 2, 3}) + if err != errInvalidPubkey || key != nil { + t.Fatalf("expected error, got %v, %v", err, key) + } + + var ( + enc, _ = hex.DecodeString("04760c4460e5336ac9bbd87952a3c7ec4363fc0a97bd31c86430806e287b437fd1b01abc6e1db640cf3106b520344af1d58b00b57823db3e1407cbc433e1b6d04d") + dec = &ecdsa.PublicKey{ + Curve: S256(), + X: hexutil.MustDecodeBig("0x760c4460e5336ac9bbd87952a3c7ec4363fc0a97bd31c86430806e287b437fd1"), + Y: hexutil.MustDecodeBig("0xb01abc6e1db640cf3106b520344af1d58b00b57823db3e1407cbc433e1b6d04d"), + } + ) + key, err = UnmarshalPubkey(enc) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !reflect.DeepEqual(key, dec) { + t.Fatal("wrong result") + } +} + func TestSign(t *testing.T) { key, _ := HexToECDSA(testPrivHex) addr := common.HexToAddress(testAddrHex) @@ -68,7 +98,7 @@ func TestSign(t *testing.T) { if err != nil { t.Errorf("ECRecover error: %s", err) } - pubKey := ToECDSAPub(recoveredPub) + pubKey, _ := UnmarshalPubkey(recoveredPub) recoveredAddr := PubkeyToAddress(*pubKey) if addr != recoveredAddr { t.Errorf("Address mismatch: want: %x have: %x", addr, recoveredAddr) @@ -109,39 +139,82 @@ func TestNewContractAddress(t *testing.T) { checkAddr(t, common.HexToAddress("c9ddedf451bc62ce88bf9292afb13df35b670699"), caddr2) } -func TestLoadECDSAFile(t *testing.T) { - keyBytes := common.FromHex(testPrivHex) - fileName0 := "test_key0" - fileName1 := "test_key1" - checkKey := func(k *ecdsa.PrivateKey) { - checkAddr(t, PubkeyToAddress(k.PublicKey), common.HexToAddress(testAddrHex)) - loadedKeyBytes := FromECDSA(k) - if !bytes.Equal(loadedKeyBytes, keyBytes) { - t.Fatalf("private key mismatch: want: %x have: %x", keyBytes, loadedKeyBytes) +func TestLoadECDSA(t *testing.T) { + tests := []struct { + input string + err string + }{ + // good + {input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"}, + {input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef\n"}, + {input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef\n\r"}, + {input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef\r\n"}, + {input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef\n\n"}, + {input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef\n\r"}, + // bad + { + input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcde", + err: "key file too short, want 64 hex characters", + }, + { + input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcde\n", + err: "key file too short, want 64 hex characters", + }, + { + input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdeX", + err: "invalid hex character 'X' in private key", + }, + { + input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdefX", + err: "invalid character 'X' at end of key file", + }, + { + input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef\n\n\n", + err: "key file too long, want 64 hex characters", + }, + } + + for _, test := range tests { + f, err := ioutil.TempFile("", "loadecdsa_test.*.txt") + if err != nil { + t.Fatal(err) + } + filename := f.Name() + f.WriteString(test.input) + f.Close() + + _, err = LoadECDSA(filename) + switch { + case err != nil && test.err == "": + t.Fatalf("unexpected error for input %q:\n %v", test.input, err) + case err != nil && err.Error() != test.err: + t.Fatalf("wrong error for input %q:\n %v", test.input, err) + case err == nil && test.err != "": + t.Fatalf("LoadECDSA did not return error for input %q", test.input) } } +} - os.WriteFile(fileName0, []byte(testPrivHex), 0600) - defer os.Remove(fileName0) - - key0, err := LoadECDSA(fileName0) +func TestSaveECDSA(t *testing.T) { + f, err := ioutil.TempFile("", "saveecdsa_test.*.txt") if err != nil { t.Fatal(err) } - checkKey(key0) + file := f.Name() + f.Close() + defer os.Remove(file) - // again, this time with SaveECDSA instead of manual save: - err = SaveECDSA(fileName1, key0) + key, _ := HexToECDSA(testPrivHex) + if err := SaveECDSA(file, key); err != nil { + t.Fatal(err) + } + loaded, err := LoadECDSA(file) if err != nil { t.Fatal(err) } - defer os.Remove(fileName1) - - key1, err := LoadECDSA(fileName1) - if err != nil { - t.Fatal(err) + if !reflect.DeepEqual(key, loaded) { + t.Fatal("loaded key not equal to saved key") } - checkKey(key1) } func TestValidateSignatureValues(t *testing.T) { @@ -153,7 +226,7 @@ func TestValidateSignatureValues(t *testing.T) { minusOne := big.NewInt(-1) one := common.Big1 zero := common.Big0 - secp256k1nMinus1 := new(big.Int).Sub(secp256k1_N, common.Big1) + secp256k1nMinus1 := new(big.Int).Sub(secp256k1N, common.Big1) // correct v,r,s check(true, 0, one, one) @@ -180,9 +253,9 @@ func TestValidateSignatureValues(t *testing.T) { // correct sig with max r,s check(true, 0, secp256k1nMinus1, secp256k1nMinus1) // correct v, combinations of incorrect r,s at upper limit - check(false, 0, secp256k1_N, secp256k1nMinus1) - check(false, 0, secp256k1nMinus1, secp256k1_N) - check(false, 0, secp256k1_N, secp256k1_N) + check(false, 0, secp256k1N, secp256k1nMinus1) + check(false, 0, secp256k1nMinus1, secp256k1N) + check(false, 0, secp256k1N, secp256k1N) // current callers ensures r,s cannot be negative, but let's test for that too // as crypto package could be used stand-alone diff --git a/eth/gasprice/gasprice.go b/eth/gasprice/gasprice.go index 8dfaeec43a..5449279037 100644 --- a/eth/gasprice/gasprice.go +++ b/eth/gasprice/gasprice.go @@ -162,7 +162,7 @@ type transactionsByGasPrice []*types.Transaction func (t transactionsByGasPrice) Len() int { return len(t) } func (t transactionsByGasPrice) Swap(i, j int) { t[i], t[j] = t[j], t[i] } -func (t transactionsByGasPrice) Less(i, j int) bool { return t[i].GasPrice().Cmp(t[j].GasPrice()) < 0 } +func (t transactionsByGasPrice) Less(i, j int) bool { return t[i].GasPriceCmp(t[j]) < 0 } // getBlockPrices calculates the lowest transaction gas price in a given block // and sends it to the result channel. If the block is empty, price is nil. diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go index d234802a1b..f2f7f01fbe 100644 --- a/internal/ethapi/api.go +++ b/internal/ethapi/api.go @@ -473,21 +473,19 @@ func (s *PrivateAccountAPI) Sign(ctx context.Context, data hexutil.Bytes, addr c // // https://github.com/XinFinOrg/XDPoSChain/wiki/Management-APIs#personal_ecRecover func (s *PrivateAccountAPI) EcRecover(ctx context.Context, data, sig hexutil.Bytes) (common.Address, error) { - if len(sig) != 65 { - return common.Address{}, fmt.Errorf("signature must be 65 bytes long") + if len(sig) != crypto.SignatureLength { + return common.Address{}, fmt.Errorf("signature must be %d bytes long", crypto.SignatureLength) } - if sig[64] != 27 && sig[64] != 28 { + if sig[crypto.RecoveryIDOffset] != 27 && sig[crypto.RecoveryIDOffset] != 28 { return common.Address{}, fmt.Errorf("invalid Ethereum signature (V is not 27 or 28)") } - sig[64] -= 27 // Transform yellow paper V from 27/28 to 0/1 + sig[crypto.RecoveryIDOffset] -= 27 // Transform yellow paper V from 27/28 to 0/1 - rpk, err := crypto.Ecrecover(signHash(data), sig) + rpk, err := crypto.SigToPub(accounts.TextHash(data), sig) if err != nil { return common.Address{}, err } - pubKey := crypto.ToECDSAPub(rpk) - recoveredAddr := crypto.PubkeyToAddress(*pubKey) - return recoveredAddr, nil + return crypto.PubkeyToAddress(*rpk), nil } // SignAndSendTransaction was renamed to SendTransaction. This method is deprecated diff --git a/p2p/rlpx.go b/p2p/rlpx.go index 66efd41997..ea26b2f2ec 100644 --- a/p2p/rlpx.go +++ b/p2p/rlpx.go @@ -528,9 +528,9 @@ func importPublicKey(pubKey []byte) (*ecies.PublicKey, error) { return nil, fmt.Errorf("invalid public key length %v (expect 64/65)", len(pubKey)) } // TODO: fewer pointless conversions - pub := crypto.ToECDSAPub(pubKey65) - if pub.X == nil { - return nil, fmt.Errorf("invalid public key") + pub, err := crypto.UnmarshalPubkey(pubKey65) + if err != nil { + return nil, err } return ecies.ImportECDSAPublic(pub), nil } diff --git a/swarm/services/swap/swap.go b/swarm/services/swap/swap.go index be595b710b..153f058968 100644 --- a/swarm/services/swap/swap.go +++ b/swarm/services/swap/swap.go @@ -80,7 +80,7 @@ type PayProfile struct { lock sync.RWMutex } -//create params with default values +// create params with default values func NewDefaultSwapParams() *SwapParams { return &SwapParams{ PayProfile: &PayProfile{}, @@ -102,8 +102,8 @@ func NewDefaultSwapParams() *SwapParams { } } -//this can only finally be set after all config options (file, cmd line, env vars) -//have been evaluated +// this can only finally be set after all config options (file, cmd line, env vars) +// have been evaluated func (self *SwapParams) Init(contract common.Address, prvkey *ecdsa.PrivateKey) { pubkey := &prvkey.PublicKey @@ -141,8 +141,12 @@ func NewSwap(local *SwapParams, remote *SwapProfile, backend chequebook.Backend, if !ok { log.Info(fmt.Sprintf("invalid contract %v for peer %v: %v)", remote.Contract.Hex()[:8], proto, err)) } else { + pub, err := crypto.UnmarshalPubkey(common.FromHex(remote.PublicKey)) + if err != nil { + return nil, err + } // remote contract valid, create inbox - in, err = chequebook.NewInbox(local.privateKey, remote.Contract, local.Beneficiary, crypto.ToECDSAPub(common.FromHex(remote.PublicKey)), backend) + in, err = chequebook.NewInbox(local.privateKey, remote.Contract, local.Beneficiary, pub, backend) if err != nil { log.Warn(fmt.Sprintf("unable to set up inbox for chequebook contract %v for peer %v: %v)", remote.Contract.Hex()[:8], proto, err)) } diff --git a/trie/committer.go b/trie/committer.go index 9db314d980..435da51981 100644 --- a/trie/committer.go +++ b/trie/committer.go @@ -22,6 +22,7 @@ import ( "sync" "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/crypto" "github.com/XinFinOrg/XDPoSChain/rlp" "golang.org/x/crypto/sha3" ) @@ -46,7 +47,7 @@ type leaf struct { // processed sequentially - onleaf will never be called in parallel or out of order. type committer struct { tmp sliceBuffer - sha keccakState + sha crypto.KeccakState onleaf LeafCallback leafCh chan *leaf @@ -57,7 +58,7 @@ var committerPool = sync.Pool{ New: func() interface{} { return &committer{ tmp: make(sliceBuffer, 0, 550), // cap is as large as a full FullNode. - sha: sha3.NewLegacyKeccak256().(keccakState), + sha: sha3.NewLegacyKeccak256().(crypto.KeccakState), } }, } diff --git a/trie/hasher.go b/trie/hasher.go index a2b385ad5b..1dc9aae689 100644 --- a/trie/hasher.go +++ b/trie/hasher.go @@ -17,21 +17,13 @@ package trie import ( - "hash" "sync" + "github.com/XinFinOrg/XDPoSChain/crypto" "github.com/XinFinOrg/XDPoSChain/rlp" "golang.org/x/crypto/sha3" ) -// keccakState wraps sha3.state. In addition to the usual hash methods, it also supports -// Read to get a variable amount of data from the hash state. Read is faster than Sum -// because it doesn't copy the internal state, but also modifies the internal state. -type keccakState interface { - hash.Hash - Read([]byte) (int, error) -} - type sliceBuffer []byte func (b *sliceBuffer) Write(data []byte) (n int, err error) { @@ -46,7 +38,7 @@ func (b *sliceBuffer) Reset() { // hasher is a type used for the trie Hash operation. A hasher has some // internal preallocated temp space type hasher struct { - sha keccakState + sha crypto.KeccakState tmp sliceBuffer parallel bool // Whether to use paralallel threads when hashing } @@ -56,7 +48,7 @@ var hasherPool = sync.Pool{ New: func() interface{} { return &hasher{ tmp: make(sliceBuffer, 0, 550), // cap is as large as a full FullNode. - sha: sha3.NewLegacyKeccak256().(keccakState), + sha: sha3.NewLegacyKeccak256().(crypto.KeccakState), } }, } diff --git a/whisper/whisperv5/api.go b/whisper/whisperv5/api.go index 89e9c2860b..37c04e70aa 100644 --- a/whisper/whisperv5/api.go +++ b/whisper/whisperv5/api.go @@ -256,8 +256,7 @@ func (api *PublicWhisperAPI) Post(ctx context.Context, req NewMessage) (bool, er // Set asymmetric key that is used to encrypt the message if pubKeyGiven { - params.Dst = crypto.ToECDSAPub(req.PublicKey) - if !ValidatePublicKey(params.Dst) { + if params.Dst, err = crypto.UnmarshalPubkey(req.PublicKey); err != nil { return false, ErrInvalidPublicKey } } @@ -333,8 +332,7 @@ func (api *PublicWhisperAPI) Messages(ctx context.Context, crit Criteria) (*rpc. } if len(crit.Sig) > 0 { - filter.Src = crypto.ToECDSAPub(crit.Sig) - if !ValidatePublicKey(filter.Src) { + if filter.Src, err = crypto.UnmarshalPubkey(crit.Sig); err != nil { return nil, ErrInvalidSigningPubKey } } @@ -517,8 +515,7 @@ func (api *PublicWhisperAPI) NewMessageFilter(req Criteria) (string, error) { } if len(req.Sig) > 0 { - src = crypto.ToECDSAPub(req.Sig) - if !ValidatePublicKey(src) { + if src, err = crypto.UnmarshalPubkey(req.Sig); err != nil { return "", ErrInvalidSigningPubKey } } diff --git a/whisper/whisperv6/api.go b/whisper/whisperv6/api.go index 95106ee167..0ea7e0fc52 100644 --- a/whisper/whisperv6/api.go +++ b/whisper/whisperv6/api.go @@ -275,8 +275,7 @@ func (api *PublicWhisperAPI) Post(ctx context.Context, req NewMessage) (bool, er // Set asymmetric key that is used to encrypt the message if pubKeyGiven { - params.Dst = crypto.ToECDSAPub(req.PublicKey) - if !ValidatePublicKey(params.Dst) { + if params.Dst, err = crypto.UnmarshalPubkey(req.PublicKey); err != nil { return false, ErrInvalidPublicKey } } @@ -352,8 +351,7 @@ func (api *PublicWhisperAPI) Messages(ctx context.Context, crit Criteria) (*rpc. } if len(crit.Sig) > 0 { - filter.Src = crypto.ToECDSAPub(crit.Sig) - if !ValidatePublicKey(filter.Src) { + if filter.Src, err = crypto.UnmarshalPubkey(crit.Sig); err != nil { return nil, ErrInvalidSigningPubKey } } @@ -536,8 +534,7 @@ func (api *PublicWhisperAPI) NewMessageFilter(req Criteria) (string, error) { } if len(req.Sig) > 0 { - src = crypto.ToECDSAPub(req.Sig) - if !ValidatePublicKey(src) { + if src, err = crypto.UnmarshalPubkey(req.Sig); err != nil { return "", ErrInvalidSigningPubKey } }