diff --git a/common/countdown/countdown.go b/common/countdown/countdown.go index 860381e110..5f738abb02 100644 --- a/common/countdown/countdown.go +++ b/common/countdown/countdown.go @@ -15,7 +15,7 @@ type CountdownTimer struct { initilised bool timeoutDuration time.Duration // Triggered when the countdown timer timeout for the `timeoutDuration` period, it will pass current timestamp to the callback function - OnTimeoutFn func(time time.Time) error + OnTimeoutFn func(time time.Time, i interface{}) error } func NewCountDown(duration time.Duration) *CountdownTimer { @@ -35,17 +35,17 @@ func (t *CountdownTimer) StopTimer() { } // Reset will start the countdown timer if it's already stopped, or simply reset the countdown time back to the defual `duration` -func (t *CountdownTimer) Reset() { +func (t *CountdownTimer) Reset(i interface{}) { if !t.isInitilised() { t.setInitilised(true) - go t.startTimer() + go t.startTimer(i) } else { t.resetc <- 0 } } // A long running process that -func (t *CountdownTimer) startTimer() { +func (t *CountdownTimer) startTimer(i interface{}) { // Make sure we mark Initilised to false when we quit the countdown defer t.setInitilised(false) timer := time.NewTimer(t.timeoutDuration) @@ -58,7 +58,7 @@ func (t *CountdownTimer) startTimer() { return case <-timer.C: log.Debug("Countdown time reached!") - err := t.OnTimeoutFn(time.Now()) + err := t.OnTimeoutFn(time.Now(), i) if err != nil { log.Error("OnTimeoutFn error", err) } diff --git a/common/countdown/countdown_test.go b/common/countdown/countdown_test.go index 758287e5e1..3697783f79 100644 --- a/common/countdown/countdown_test.go +++ b/common/countdown/countdown_test.go @@ -9,23 +9,24 @@ import ( ) func TestCountdownWillCallback(t *testing.T) { - + var fakeI interface{} called := make(chan int) - OnTimeoutFn := func(time time.Time) error { + OnTimeoutFn := func(time.Time, interface{}) error { called <- 1 return nil } countdown := NewCountDown(1000 * time.Millisecond) countdown.OnTimeoutFn = OnTimeoutFn - countdown.Reset() + countdown.Reset(fakeI) <-called t.Log("Times up, successfully called OnTimeoutFn") } func TestCountdownShouldReset(t *testing.T) { + var fakeI interface{} called := make(chan int) - OnTimeoutFn := func(time time.Time) error { + OnTimeoutFn := func(time.Time, interface{}) error { called <- 1 return nil } @@ -34,7 +35,7 @@ func TestCountdownShouldReset(t *testing.T) { countdown.OnTimeoutFn = OnTimeoutFn // Check countdown did not start assert.False(t, countdown.isInitilised()) - countdown.Reset() + countdown.Reset(fakeI) // Now the countdown should already started assert.True(t, countdown.isInitilised()) expectedCalledTime := time.Now().Add(9000 * time.Millisecond) @@ -53,7 +54,7 @@ firstReset: } break firstReset case <-resetTimer.C: - countdown.Reset() + countdown.Reset(fakeI) } } @@ -71,8 +72,9 @@ firstReset: } func TestCountdownShouldResetEvenIfErrored(t *testing.T) { + var fakeI interface{} called := make(chan int) - OnTimeoutFn := func(time time.Time) error { + OnTimeoutFn := func(time.Time, interface{}) error { called <- 1 return fmt.Errorf("ERROR!") } @@ -81,7 +83,7 @@ func TestCountdownShouldResetEvenIfErrored(t *testing.T) { countdown.OnTimeoutFn = OnTimeoutFn // Check countdown did not start assert.False(t, countdown.isInitilised()) - countdown.Reset() + countdown.Reset(fakeI) // Now the countdown should already started assert.True(t, countdown.isInitilised()) expectedCalledTime := time.Now().Add(9000 * time.Millisecond) @@ -100,7 +102,7 @@ firstReset: } break firstReset case <-resetTimer.C: - countdown.Reset() + countdown.Reset(fakeI) } } @@ -118,8 +120,9 @@ firstReset: } func TestCountdownShouldBeAbleToStop(t *testing.T) { + var fakeI interface{} called := make(chan int) - OnTimeoutFn := func(time time.Time) error { + OnTimeoutFn := func(time.Time, interface{}) error { called <- 1 return nil } @@ -128,7 +131,7 @@ func TestCountdownShouldBeAbleToStop(t *testing.T) { countdown.OnTimeoutFn = OnTimeoutFn // Check countdown did not start assert.False(t, countdown.isInitilised()) - countdown.Reset() + countdown.Reset(fakeI) // Now the countdown should already started assert.True(t, countdown.isInitilised()) // Try manually stop the timer before it triggers the callback diff --git a/consensus/XDPoS/engines/engine_v2/engine.go b/consensus/XDPoS/engines/engine_v2/engine.go index 4f170327d9..cff912706d 100644 --- a/consensus/XDPoS/engines/engine_v2/engine.go +++ b/consensus/XDPoS/engines/engine_v2/engine.go @@ -169,7 +169,7 @@ func (x *XDPoS_v2) Initial(chain consensus.ChainReader, masternodes []common.Add }() // Kick-off the countdown timer - x.timeoutWorker.Reset() + x.timeoutWorker.Reset(chain) log.Info("[Initial] finish initialisation") return nil @@ -682,7 +682,7 @@ func (x *XDPoS_v2) SyncInfoHandler(chain consensus.ChainReader, syncInfo *utils. if err != nil { return err } - return x.processTC(syncInfo.HighestTimeoutCert) + return x.processTC(chain, syncInfo.HighestTimeoutCert) } /* @@ -801,7 +801,7 @@ func (x *XDPoS_v2) onVotePoolThresholdReached(chain consensus.ChainReader, poole log.Error("Error while processing QC in the Vote handler after reaching pool threshold, ", err) return err } - log.Info("🗳 Successfully processed the vote and produced QC!") + log.Info("Successfully processed the vote and produced QC!", "QcRound", quorumCert.ProposedBlockInfo.Round, "QcNumOfSig", len(quorumCert.Signatures), "QcHash", quorumCert.ProposedBlockInfo.Hash, "QcNumber", quorumCert.ProposedBlockInfo.Number.Uint64()) // clean up vote at the same poolKey. and pookKey is proposed block hash x.votePool.ClearPoolKeyByObj(currentVoteMsg) return nil @@ -822,7 +822,10 @@ func (x *XDPoS_v2) onVotePoolThresholdReached(chain consensus.ChainReader, poole func (x *XDPoS_v2) VerifyTimeoutMessage(chain consensus.ChainReader, timeoutMsg *utils.Timeout) (bool, error) { masternodes := x.GetMasternodesAtRound(chain, timeoutMsg.Round, chain.CurrentHeader()) - return x.verifyMsgSignature(utils.TimeoutSigHash(&timeoutMsg.Round), timeoutMsg.Signature, masternodes) + return x.verifyMsgSignature(utils.TimeoutSigHash(&utils.TimeoutForSign{ + Round: timeoutMsg.Round, + GapNumber: timeoutMsg.GapNumber, + }), timeoutMsg.Signature, masternodes) } /* @@ -831,13 +834,13 @@ func (x *XDPoS_v2) VerifyTimeoutMessage(chain consensus.ChainReader, timeoutMsg 2. Collect timeout 3. Once timeout pool reached threshold, it will trigger the call to the function "onTimeoutPoolThresholdReached" */ -func (x *XDPoS_v2) TimeoutHandler(timeout *utils.Timeout) error { +func (x *XDPoS_v2) TimeoutHandler(blockChainReader consensus.ChainReader, timeout *utils.Timeout) error { x.lock.Lock() defer x.lock.Unlock() - return x.timeoutHandler(timeout) + return x.timeoutHandler(blockChainReader, timeout) } -func (x *XDPoS_v2) timeoutHandler(timeout *utils.Timeout) error { +func (x *XDPoS_v2) timeoutHandler(blockChainReader consensus.ChainReader, timeout *utils.Timeout) error { // 1. checkRoundNumber if timeout.Round != x.currentRound { return &utils.ErrIncomingMessageRoundNotEqualCurrentRound{ @@ -851,12 +854,12 @@ func (x *XDPoS_v2) timeoutHandler(timeout *utils.Timeout) error { // Threshold reached if isThresholdReached { log.Info(fmt.Sprintf("Timeout pool threashold reached: %v, number of items in the pool: %v", isThresholdReached, numberOfTimeoutsInPool)) - err := x.onTimeoutPoolThresholdReached(pooledTimeouts, timeout) + err := x.onTimeoutPoolThresholdReached(blockChainReader, pooledTimeouts, timeout, timeout.GapNumber) if err != nil { return err } - // clean up timeout message at the same poolKey. and pookKey is proposed block hash - x.timeoutPool.ClearPoolKeyByObj(timeout) + // clean up timeout message, regardless its GapNumber or round + x.timeoutPool.Clear() } return nil } @@ -868,7 +871,7 @@ func (x *XDPoS_v2) timeoutHandler(timeout *utils.Timeout) error { 2. processTC() 3. generateSyncInfo() */ -func (x *XDPoS_v2) onTimeoutPoolThresholdReached(pooledTimeouts map[common.Hash]utils.PoolObj, currentTimeoutMsg utils.PoolObj) error { +func (x *XDPoS_v2) onTimeoutPoolThresholdReached(blockChainReader consensus.ChainReader, pooledTimeouts map[common.Hash]utils.PoolObj, currentTimeoutMsg utils.PoolObj, gapNumber uint64) error { signatures := []utils.Signature{} for _, v := range pooledTimeouts { signatures = append(signatures, v.(*utils.Timeout).Signature) @@ -877,18 +880,19 @@ func (x *XDPoS_v2) onTimeoutPoolThresholdReached(pooledTimeouts map[common.Hash] timeoutCert := &utils.TimeoutCert{ Round: currentTimeoutMsg.(*utils.Timeout).Round, Signatures: signatures, + GapNumber: gapNumber, } // Process TC - err := x.processTC(timeoutCert) + err := x.processTC(blockChainReader, timeoutCert) if err != nil { - log.Error("Error while processing TC in the Timeout handler after reaching pool threshold, ", err.Error()) + log.Error("Error while processing TC in the Timeout handler after reaching pool threshold", "TcRound", timeoutCert.Round, "NumberOfTcSig", len(timeoutCert.Signatures), "GapNumber", gapNumber, "Error", err) return err } // Generate and broadcast syncInfo syncInfo := x.getSyncInfo() x.broadcastToBftChannel(syncInfo) - log.Info("⏰ Successfully processed the timeout message and produced TC & SyncInfo!") + log.Info("Successfully processed the timeout message and produced TC & SyncInfo!", "TcRound", timeoutCert.Round, "NumberOfTcSig", len(timeoutCert.Signatures)) return nil } @@ -1041,10 +1045,10 @@ func (x *XDPoS_v2) verifyQC(blockChainReader consensus.ChainReader, quorumCert * return x.VerifyBlockInfo(blockChainReader, quorumCert.ProposedBlockInfo) } -// TODO: Unhold, wait till proposal finalise func (x *XDPoS_v2) verifyTC(timeoutCert *utils.TimeoutCert) error { /* - 1. Get epoch master node list by round/number with chain's current header + 1. Get epoch master node list by gapNumber + 2. Check number of signatures > threshold, as well as it's format. (Same as verifyQC) 2. Verify signer signature: (List of signatures) - Use ecRecover to get the public key - Use the above public key to find out the xdc address @@ -1087,7 +1091,7 @@ func (x *XDPoS_v2) processQC(blockChainReader consensus.ChainReader, quorumCert } // 4. Set new round if quorumCert.ProposedBlockInfo.Round >= x.currentRound { - err := x.setNewRound(quorumCert.ProposedBlockInfo.Round + 1) + err := x.setNewRound(blockChainReader, quorumCert.ProposedBlockInfo.Round+1) if err != nil { log.Error("[processQC] Fail to setNewRound", "new round to set", quorumCert.ProposedBlockInfo.Round+1) return err @@ -1101,12 +1105,12 @@ func (x *XDPoS_v2) processQC(blockChainReader consensus.ChainReader, quorumCert 1. Update highestTC 2. Check TC round >= node's currentRound. If yes, call setNewRound */ -func (x *XDPoS_v2) processTC(timeoutCert *utils.TimeoutCert) error { +func (x *XDPoS_v2) processTC(blockChainReader consensus.ChainReader, timeoutCert *utils.TimeoutCert) error { if timeoutCert.Round > x.highestTimeoutCert.Round { x.highestTimeoutCert = timeoutCert } if timeoutCert.Round >= x.currentRound { - err := x.setNewRound(timeoutCert.Round + 1) + err := x.setNewRound(blockChainReader, timeoutCert.Round+1) if err != nil { return err } @@ -1119,10 +1123,10 @@ func (x *XDPoS_v2) processTC(timeoutCert *utils.TimeoutCert) error { 2. Reset timer 3. Reset vote and timeout Pools */ -func (x *XDPoS_v2) setNewRound(round utils.Round) error { +func (x *XDPoS_v2) setNewRound(blockChainReader consensus.ChainReader, round utils.Round) error { x.currentRound = round //TODO: tell miner now it's a new round and start mine if it's leader - x.timeoutWorker.Reset() + x.timeoutWorker.Reset(blockChainReader) //TODO: vote pools x.timeoutPool.Clear() return nil @@ -1197,18 +1201,44 @@ func (x *XDPoS_v2) sendVote(chainReader consensus.ChainReader, blockInfo *utils. 2. Sign the signature 3. send to broadcast channel */ -func (x *XDPoS_v2) sendTimeout() error { - signedHash, err := x.signSignature(utils.TimeoutSigHash(&x.currentRound)) +func (x *XDPoS_v2) sendTimeout(chain consensus.ChainReader) error { + // Construct the gapNumber + var gapNumber uint64 + currentBlockHeader := chain.CurrentHeader() + isEpochSwitch, epochNum, err := x.IsEpochSwitchAtRound(x.currentRound, currentBlockHeader) if err != nil { - log.Error("signSignature when sending out TC", "Error", err) + log.Error("[sendTimeout] Error while checking if the currentBlock is epoch switch", "currentRound", x.currentRound, "currentBlockNum", currentBlockHeader.Number, "currentBlockHash", currentBlockHeader.Hash(), "epochNum", epochNum) + return err + } + if isEpochSwitch { + // Notice this +1 is because we expect a block whos is the child of currentHeader + currentNumber := currentBlockHeader.Number.Uint64() + 1 + gapNumber := currentNumber - currentNumber%x.config.Epoch - x.config.Gap + log.Debug("[sendTimeout] is epoch switch when sending out timeout message", "currentNumber", currentNumber, "gapNumber", gapNumber) + } else { + epochSwitchInfo, err := x.getEpochSwitchInfo(chain, currentBlockHeader, currentBlockHeader.Hash()) + if err != nil { + log.Error("[sendTimeout] Error when trying to get current epoch switch info for a non-epoch block", "currentRound", x.currentRound, "currentBlockNum", currentBlockHeader.Number, "currentBlockHash", currentBlockHeader.Hash(), "epochNum", epochNum) + } + gapNumber := epochSwitchInfo.EpochSwitchBlockInfo.Number.Uint64() - epochSwitchInfo.EpochSwitchBlockInfo.Number.Uint64()%x.config.Epoch - x.config.Gap + log.Debug("[sendTimeout] non-epoch-switch block found its epoch block and calculated the gapNumber", "epochSwitchInfo.EpochSwitchBlockInfo.Number", epochSwitchInfo.EpochSwitchBlockInfo.Number.Uint64(), "gapNumber", gapNumber) + } + + signedHash, err := x.signSignature(utils.TimeoutSigHash(&utils.TimeoutForSign{ + Round: x.currentRound, + GapNumber: gapNumber, + })) + if err != nil { + log.Error("[sendTimeout] signSignature when sending out TC", "Error", err) return err } timeoutMsg := &utils.Timeout{ Round: x.currentRound, Signature: signedHash, + GapNumber: gapNumber, } - - err = x.timeoutHandler(timeoutMsg) + log.Info("[sendTimeout] Timeout message generated, ready to send!", "timeoutMsgRound", timeoutMsg.Round, "timeoutMsgGapNumber", timeoutMsg.GapNumber) + err = x.timeoutHandler(chain, timeoutMsg) if err != nil { log.Error("TimeoutHandler error", "TimeoutRound", timeoutMsg.Round, "Error", err) return err @@ -1254,11 +1284,11 @@ func (x *XDPoS_v2) verifyMsgSignature(signedHashToBeVerified common.Hash, signat Function that will be called by timer when countdown reaches its threshold. In the engine v2, we would need to broadcast timeout messages to other peers */ -func (x *XDPoS_v2) OnCountdownTimeout(time time.Time) error { +func (x *XDPoS_v2) OnCountdownTimeout(time time.Time, chain interface{}) error { x.lock.Lock() defer x.lock.Unlock() - err := x.sendTimeout() + err := x.sendTimeout(chain.(consensus.ChainReader)) if err != nil { log.Error("Error while sending out timeout message at time: ", time) return err @@ -1321,7 +1351,7 @@ func (x *XDPoS_v2) commitBlocks(blockChainReader consensus.ChainReader, proposed Hash: grandParentBlock.Hash(), Round: decodedExtraField.Round, } - log.Debug("👴 Successfully committed block", "Committed block Hash", x.highestCommitBlock.Hash, "Committed round", x.highestCommitBlock.Round) + log.Debug("Successfully committed block", "Committed block Hash", x.highestCommitBlock.Hash, "Committed round", x.highestCommitBlock.Round) return true, nil } // Everything else, fail to commit @@ -1352,12 +1382,12 @@ func (x *XDPoS_v2) isExtendingFromAncestor(blockChainReader consensus.ChainReade Testing tools */ -func (x *XDPoS_v2) SetNewRoundFaker(newRound utils.Round, resetTimer bool) { +func (x *XDPoS_v2) SetNewRoundFaker(blockChainReader consensus.ChainReader, newRound utils.Round, resetTimer bool) { x.lock.Lock() defer x.lock.Unlock() // Reset a bunch of things if resetTimer { - x.timeoutWorker.Reset() + x.timeoutWorker.Reset(blockChainReader) } x.currentRound = newRound } diff --git a/consensus/XDPoS/utils/pool.go b/consensus/XDPoS/utils/pool.go index 43f0a09d4a..949cff3cac 100644 --- a/consensus/XDPoS/utils/pool.go +++ b/consensus/XDPoS/utils/pool.go @@ -35,6 +35,7 @@ func (p *Pool) Add(obj PoolObj) (bool, int, map[common.Hash]PoolObj) { } return false, numOfItems, objListKeyed } + func (p *Pool) Size(obj PoolObj) int { poolKey := obj.PoolKey() objListKeyed, ok := p.objList[poolKey] diff --git a/consensus/XDPoS/utils/types.go b/consensus/XDPoS/utils/types.go index 8b56459f20..aec39132ee 100644 --- a/consensus/XDPoS/utils/types.go +++ b/consensus/XDPoS/utils/types.go @@ -1,6 +1,7 @@ package utils import ( + "fmt" "math/big" "time" @@ -81,6 +82,7 @@ type Vote struct { type Timeout struct { Round Round Signature Signature + GapNumber uint64 } // BFT Sync Info message in XDPoS 2.0 @@ -99,6 +101,7 @@ type QuorumCert struct { type TimeoutCert struct { Round Round Signatures []Signature + GapNumber uint64 } // The parsed extra fields in block header in XDPoS 2.0 (excluding the version byte) @@ -147,7 +150,12 @@ func VoteSigHash(m *BlockInfo) common.Hash { return rlpHash(m) } -func TimeoutSigHash(m *Round) common.Hash { +type TimeoutForSign struct { + Round Round + GapNumber uint64 +} + +func TimeoutSigHash(m *TimeoutForSign) common.Hash { return rlpHash(m) } @@ -157,6 +165,6 @@ func (m *Vote) PoolKey() string { } func (m *Timeout) PoolKey() string { - // return a default pool key string - return "0" + // timeout pool key is round:gapNumber + return fmt.Sprint(m.Round, ":", m.GapNumber) } diff --git a/consensus/XDPoS/utils/types_test.go b/consensus/XDPoS/utils/types_test.go index a50836bac8..764eb69587 100644 --- a/consensus/XDPoS/utils/types_test.go +++ b/consensus/XDPoS/utils/types_test.go @@ -62,7 +62,13 @@ func TestHashAndSigHash(t *testing.T) { t.Fatalf("SigHash of two block info shouldn't equal") } round2 := Round(999) - if TimeoutSigHash(&round) == TimeoutSigHash(&round2) { + if TimeoutSigHash(&TimeoutForSign{ + Round: round, + GapNumber: 450, + }) == TimeoutSigHash(&TimeoutForSign{ + Round: round2, + GapNumber: 450, + }) { t.Fatalf("SigHash of two round shouldn't equal") } } diff --git a/consensus/tests/authorised_masternode_test.go b/consensus/tests/authorised_masternode_test.go index fe86f3bca1..ae046013c7 100644 --- a/consensus/tests/authorised_masternode_test.go +++ b/consensus/tests/authorised_masternode_test.go @@ -127,7 +127,7 @@ func TestIsYourTurnConsensusV2(t *testing.T) { blockchain.InsertBlock(currentBlock) time.Sleep(time.Duration(minePeriod) * time.Second) - adaptor.EngineV2.SetNewRoundFaker(2, false) + adaptor.EngineV2.SetNewRoundFaker(blockchain, 2, false) isYourTurn, _ = adaptor.YourTurn(blockchain, currentBlock.Header(), common.HexToAddress("xdc0D3ab14BBaD3D99F4203bd7a11aCB94882050E7e")) assert.False(t, isYourTurn) diff --git a/consensus/tests/countdown_test.go b/consensus/tests/countdown_test.go index c7bd5f4ea7..4e687a1b53 100644 --- a/consensus/tests/countdown_test.go +++ b/consensus/tests/countdown_test.go @@ -10,15 +10,17 @@ import ( ) func TestCountdownTimeoutToSendTimeoutMessage(t *testing.T) { - blockchain, _, _, _, _, _ := PrepareXDCTestBlockChainForV2Engine(t, 11, params.TestXDPoSMockChainConfig, 0) + blockchain, _, _, _, _, _ := PrepareXDCTestBlockChainForV2Engine(t, 901, params.TestXDPoSMockChainConfig, 0) engineV2 := blockchain.Engine().(*XDPoS.XDPoS).EngineV2 - engineV2.SetNewRoundFaker(utils.Round(1), true) + engineV2.SetNewRoundFaker(blockchain, utils.Round(1), true) timeoutMsg := <-engineV2.BroadcastCh poolSize := engineV2.GetTimeoutPoolSize(timeoutMsg.(*utils.Timeout)) assert.Equal(t, poolSize, 1) assert.NotNil(t, timeoutMsg) + assert.Equal(t, uint64(0), timeoutMsg.(*utils.Timeout).GapNumber) + assert.Equal(t, utils.Round(1), timeoutMsg.(*utils.Timeout).Round) valid, err := engineV2.VerifyTimeoutMessage(blockchain, timeoutMsg.(*utils.Timeout)) // We can only test valid = false for now as the implementation for getCurrentRoundMasterNodes is not complete diff --git a/consensus/tests/penalty_test.go b/consensus/tests/penalty_test.go index cae4b82687..504dc4f1dd 100644 --- a/consensus/tests/penalty_test.go +++ b/consensus/tests/penalty_test.go @@ -94,7 +94,7 @@ func TestHookPenaltyV2Jump(t *testing.T) { masternodes := adaptor.GetMasternodesFromCheckpointHeader(header901) assert.Equal(t, 4, len(masternodes)) header6285 := blockchain.GetHeaderByNumber(uint64(end)) - adaptor.EngineV2.SetNewRoundFaker(utils.Round(config.XDPoS.Epoch*7), false) + adaptor.EngineV2.SetNewRoundFaker(blockchain, utils.Round(config.XDPoS.Epoch*7), false) // round 6285-6300 miss blocks, penalty should work as usual penalty, err := adaptor.EngineV2.HookPenalty(blockchain, header6285.Number, header6285.ParentHash, masternodes) assert.Nil(t, err) diff --git a/consensus/tests/proposed_block_test.go b/consensus/tests/proposed_block_test.go index 882f7a095d..d03bc28a43 100644 --- a/consensus/tests/proposed_block_test.go +++ b/consensus/tests/proposed_block_test.go @@ -175,7 +175,7 @@ func TestProposedBlockMessageHandlerSuccessfullyGenerateVote(t *testing.T) { engineV2 := blockchain.Engine().(*XDPoS.XDPoS).EngineV2 // Set current round to 5 - engineV2.SetNewRoundFaker(utils.Round(5), false) + engineV2.SetNewRoundFaker(blockchain, utils.Round(5), false) var extraField utils.ExtraFields_v2 err := utils.DecodeBytesExtraFields(currentBlock.Extra(), &extraField) @@ -205,7 +205,7 @@ func TestShouldNotSetNewRound(t *testing.T) { engineV2 := blockchain.Engine().(*XDPoS.XDPoS).EngineV2 // Set current round to 6 - engineV2.SetNewRoundFaker(utils.Round(6), false) + engineV2.SetNewRoundFaker(blockchain, utils.Round(6), false) var extraField utils.ExtraFields_v2 err := utils.DecodeBytesExtraFields(currentBlock.Extra(), &extraField) @@ -229,7 +229,7 @@ func TestShouldNotSendVoteMessageIfAlreadyVoteForThisRound(t *testing.T) { engineV2 := blockchain.Engine().(*XDPoS.XDPoS).EngineV2 // Set current round to 5 - engineV2.SetNewRoundFaker(utils.Round(5), false) + engineV2.SetNewRoundFaker(blockchain, utils.Round(5), false) err := engineV2.ProposedBlockHandler(blockchain, currentBlock.Header()) if err != nil { @@ -267,7 +267,7 @@ func TestShouldNotSendVoteMsgIfBlockInfoRoundNotEqualCurrentRound(t *testing.T) engineV2 := blockchain.Engine().(*XDPoS.XDPoS).EngineV2 // Set current round to 8 - engineV2.SetNewRoundFaker(utils.Round(8), false) + engineV2.SetNewRoundFaker(blockchain, utils.Round(8), false) var extraField utils.ExtraFields_v2 err := utils.DecodeBytesExtraFields(currentBlock.Extra(), &extraField) @@ -316,7 +316,7 @@ func TestShouldNotSendVoteMsgIfBlockNotExtendedFromAncestor(t *testing.T) { // Find the first forked block at block 14th firstForkedBlock := blockchain.GetBlockByHash(blockchain.GetBlockByHash(forkedBlock.ParentHash()).ParentHash()) - engineV2.SetNewRoundFaker(utils.Round(7), false) + engineV2.SetNewRoundFaker(blockchain, utils.Round(7), false) err = engineV2.ProposedBlockHandler(blockchain, firstForkedBlock.Header()) if err != nil { t.Fatal("Fail propose proposedBlock handler", err) diff --git a/consensus/tests/timeout_test.go b/consensus/tests/timeout_test.go index 8f332047a8..55a9f8719c 100644 --- a/consensus/tests/timeout_test.go +++ b/consensus/tests/timeout_test.go @@ -15,32 +15,47 @@ func TestTimeoutMessageHandlerSuccessfullyGenerateTCandSyncInfo(t *testing.T) { engineV2 := blockchain.Engine().(*XDPoS.XDPoS).EngineV2 // Set round to 1 - engineV2.SetNewRoundFaker(utils.Round(1), false) + engineV2.SetNewRoundFaker(blockchain, utils.Round(1), false) // Create two timeout message which will not reach timeout pool threshold timeoutMsg := &utils.Timeout{ Round: utils.Round(1), Signature: []byte{1}, + GapNumber: 450, } - err := engineV2.TimeoutHandler(timeoutMsg) + err := engineV2.TimeoutHandler(blockchain, timeoutMsg) assert.Nil(t, err) currentRound, _, _, _, _ := engineV2.GetProperties() assert.Equal(t, utils.Round(1), currentRound) timeoutMsg = &utils.Timeout{ Round: utils.Round(1), Signature: []byte{2}, + GapNumber: 450, } - err = engineV2.TimeoutHandler(timeoutMsg) + err = engineV2.TimeoutHandler(blockchain, timeoutMsg) assert.Nil(t, err) currentRound, _, _, _, _ = engineV2.GetProperties() assert.Equal(t, utils.Round(1), currentRound) - // Create a timeout message that should trigger timeout pool hook + + // Send a timeout with different gap number, it shall not trigger timeout pool hook timeoutMsg = &utils.Timeout{ Round: utils.Round(1), Signature: []byte{3}, + GapNumber: 1350, + } + err = engineV2.TimeoutHandler(blockchain, timeoutMsg) + assert.Nil(t, err) + currentRound, _, _, _, _ = engineV2.GetProperties() + assert.Equal(t, utils.Round(1), currentRound) + + // Create a timeout message that should trigger timeout pool hook + timeoutMsg = &utils.Timeout{ + Round: utils.Round(1), + Signature: []byte{4}, + GapNumber: 450, } - err = engineV2.TimeoutHandler(timeoutMsg) + err = engineV2.TimeoutHandler(blockchain, timeoutMsg) assert.Nil(t, err) syncInfoMsg := <-engineV2.BroadcastCh @@ -56,7 +71,9 @@ func TestTimeoutMessageHandlerSuccessfullyGenerateTCandSyncInfo(t *testing.T) { tc := syncInfoMsg.(*utils.SyncInfo).HighestTimeoutCert assert.NotNil(t, tc) assert.Equal(t, tc.Round, utils.Round(1)) - sigatures := []utils.Signature{[]byte{1}, []byte{2}, []byte{3}} + assert.Equal(t, uint64(450), tc.GapNumber) + // The signatures shall not include the byte{3} from a different gap number + sigatures := []utils.Signature{[]byte{1}, []byte{2}, []byte{4}} assert.ElementsMatch(t, tc.Signatures, sigatures) assert.Equal(t, utils.Round(2), currentRound) } @@ -66,20 +83,20 @@ func TestThrowErrorIfTimeoutMsgRoundNotEqualToCurrentRound(t *testing.T) { engineV2 := blockchain.Engine().(*XDPoS.XDPoS).EngineV2 // Set round to 3 - engineV2.SetNewRoundFaker(utils.Round(3), false) + engineV2.SetNewRoundFaker(blockchain, utils.Round(3), false) timeoutMsg := &utils.Timeout{ Round: utils.Round(2), Signature: []byte{1}, } - err := engineV2.TimeoutHandler(timeoutMsg) + err := engineV2.TimeoutHandler(blockchain, timeoutMsg) assert.NotNil(t, err) // Timeout msg round > currentRound assert.Equal(t, "timeout message round number: 2 does not match currentRound: 3", err.Error()) // Set round to 1 - engineV2.SetNewRoundFaker(utils.Round(1), false) - err = engineV2.TimeoutHandler(timeoutMsg) + engineV2.SetNewRoundFaker(blockchain, utils.Round(1), false) + err = engineV2.TimeoutHandler(blockchain, timeoutMsg) assert.NotNil(t, err) // Timeout msg round < currentRound assert.Equal(t, "timeout message round number: 2 does not match currentRound: 1", err.Error()) diff --git a/consensus/tests/vote_test.go b/consensus/tests/vote_test.go index bdd3e0a2bc..404fcfcfa9 100644 --- a/consensus/tests/vote_test.go +++ b/consensus/tests/vote_test.go @@ -27,7 +27,7 @@ func TestVoteMessageHandlerSuccessfullyGeneratedAndProcessQCForFistV2Round(t *te voteSigningHash := utils.VoteSigHash(blockInfo) // Set round to 5 - engineV2.SetNewRoundFaker(utils.Round(1), false) + engineV2.SetNewRoundFaker(blockchain, utils.Round(1), false) // Create two vote messages which will not reach vote pool threshold signedHash, err := signFn(accounts.Account{Address: signer}, voteSigningHash.Bytes()) assert.Nil(t, err) @@ -89,7 +89,7 @@ func TestVoteMessageHandlerSuccessfullyGeneratedAndProcessQC(t *testing.T) { voteSigningHash := utils.VoteSigHash(blockInfo) // Set round to 5 - engineV2.SetNewRoundFaker(utils.Round(5), false) + engineV2.SetNewRoundFaker(blockchain, utils.Round(5), false) // Create two vote messages which will not reach vote pool threshold signedHash, err := signFn(accounts.Account{Address: signer}, voteSigningHash.Bytes()) assert.Nil(t, err) @@ -168,7 +168,7 @@ func TestThrowErrorIfVoteMsgRoundIsMoreThanOneRoundAwayFromCurrentRound(t *testi } // Set round to 7 - engineV2.SetNewRoundFaker(utils.Round(7), false) + engineV2.SetNewRoundFaker(blockchain, utils.Round(7), false) voteMsg := &utils.Vote{ ProposedBlockInfo: blockInfo, Signature: []byte{1}, @@ -180,11 +180,11 @@ func TestThrowErrorIfVoteMsgRoundIsMoreThanOneRoundAwayFromCurrentRound(t *testi assert.Equal(t, "vote message round number: 6 is too far away from currentRound: 7", err.Error()) // Set round to 5, it's 1 round away, should not trigger failure - engineV2.SetNewRoundFaker(utils.Round(5), false) + engineV2.SetNewRoundFaker(blockchain, utils.Round(5), false) err = engineV2.VoteHandler(blockchain, voteMsg) assert.Nil(t, err) - engineV2.SetNewRoundFaker(utils.Round(4), false) + engineV2.SetNewRoundFaker(blockchain, utils.Round(4), false) err = engineV2.VoteHandler(blockchain, voteMsg) assert.NotNil(t, err) assert.Equal(t, "vote message round number: 6 is too far away from currentRound: 4", err.Error()) @@ -196,7 +196,7 @@ func TestProcessVoteMsgThenTimeoutMsg(t *testing.T) { engineV2 := blockchain.Engine().(*XDPoS.XDPoS).EngineV2 // Set round to 5 - engineV2.SetNewRoundFaker(utils.Round(5), false) + engineV2.SetNewRoundFaker(blockchain, utils.Round(5), false) // Start with vote messages blockInfo := &utils.BlockInfo{ @@ -254,7 +254,7 @@ func TestProcessVoteMsgThenTimeoutMsg(t *testing.T) { Signature: []byte{1}, } - err = engineV2.TimeoutHandler(timeoutMsg) + err = engineV2.TimeoutHandler(blockchain, timeoutMsg) assert.NotNil(t, err) assert.Equal(t, "timeout message round number: 5 does not match currentRound: 6", err.Error()) @@ -264,7 +264,7 @@ func TestProcessVoteMsgThenTimeoutMsg(t *testing.T) { Signature: []byte{1}, } - err = engineV2.TimeoutHandler(timeoutMsg) + err = engineV2.TimeoutHandler(blockchain, timeoutMsg) assert.Nil(t, err) currentRound, _, _, _, _ = engineV2.GetProperties() assert.Equal(t, utils.Round(6), currentRound) @@ -272,7 +272,7 @@ func TestProcessVoteMsgThenTimeoutMsg(t *testing.T) { Round: utils.Round(6), Signature: []byte{2}, } - err = engineV2.TimeoutHandler(timeoutMsg) + err = engineV2.TimeoutHandler(blockchain, timeoutMsg) assert.Nil(t, err) currentRound, _, _, _, _ = engineV2.GetProperties() assert.Equal(t, utils.Round(6), currentRound) @@ -283,7 +283,7 @@ func TestProcessVoteMsgThenTimeoutMsg(t *testing.T) { Signature: []byte{3}, } - err = engineV2.TimeoutHandler(timeoutMsg) + err = engineV2.TimeoutHandler(blockchain, timeoutMsg) assert.Nil(t, err) syncInfoMsg := <-engineV2.BroadcastCh @@ -321,7 +321,7 @@ func TestVoteMessageShallNotThrowErrorIfBlockNotYetExist(t *testing.T) { voteSigningHash := utils.VoteSigHash(blockInfo) // Set round to 6 - engineV2.SetNewRoundFaker(utils.Round(6), false) + engineV2.SetNewRoundFaker(blockchain, utils.Round(6), false) // Create two vote messages which will not reach vote pool threshold voteMsg := &utils.Vote{ ProposedBlockInfo: blockInfo, diff --git a/eth/bft/bft_hander_test.go b/eth/bft/bft_hander_test.go index 935b87b820..5de638a82c 100644 --- a/eth/bft/bft_hander_test.go +++ b/eth/bft/bft_hander_test.go @@ -160,7 +160,7 @@ func TestTimeoutHandler(t *testing.T) { return nil } - tester.bfter.consensus.timeoutHandler = func(timeout *utils.Timeout) error { + tester.bfter.consensus.timeoutHandler = func(chain consensus.ChainReader, timeout *utils.Timeout) error { atomic.AddUint32(&handlerCounter, 1) return nil } @@ -190,7 +190,7 @@ func TestTimeoutHandlerRoundNotEqual(t *testing.T) { return nil } - tester.bfter.consensus.timeoutHandler = func(timeout *utils.Timeout) error { + tester.bfter.consensus.timeoutHandler = func(chain consensus.ChainReader, timeout *utils.Timeout) error { return &utils.ErrIncomingMessageRoundNotEqualCurrentRound{ Type: "timeout", IncomingRound: utils.Round(1), diff --git a/eth/bft/bft_handler.go b/eth/bft/bft_handler.go index b911ff75de..a443460b98 100644 --- a/eth/bft/bft_handler.go +++ b/eth/bft/bft_handler.go @@ -34,11 +34,11 @@ type Bfter struct { } type ConsensusFns struct { - verifyVote func(chain consensus.ChainReader, vote *utils.Vote) (bool, error) + verifyVote func(consensus.ChainReader, *utils.Vote) (bool, error) voteHandler func(consensus.ChainReader, *utils.Vote) error verifyTimeout func(*utils.Timeout) error - timeoutHandler func(*utils.Timeout) error + timeoutHandler func(consensus.ChainReader, *utils.Timeout) error verifySyncInfo func(*utils.SyncInfo) error syncInfoHandler func(consensus.ChainReader, *utils.SyncInfo) error @@ -123,7 +123,7 @@ func (b *Bfter) Timeout(timeout *utils.Timeout) error { } b.broadcastCh <- timeout - err = b.consensus.timeoutHandler(timeout) + err = b.consensus.timeoutHandler(b.blockChainReader, timeout) if err != nil { if _, ok := err.(*utils.ErrIncomingMessageRoundNotEqualCurrentRound); ok { log.Warn("timeout round not equal", "error", err)