diff --git a/consensus/XDPoS/utils/pool.go b/consensus/XDPoS/utils/pool.go index 9c09a8ba53..c2d177a068 100644 --- a/consensus/XDPoS/utils/pool.go +++ b/consensus/XDPoS/utils/pool.go @@ -10,6 +10,7 @@ type PoolObj interface { Hash() common.Hash PoolKey() string GetSigner() common.Address + DeepCopy() interface{} } type Pool struct { objList map[string]map[common.Hash]PoolObj @@ -22,7 +23,17 @@ func NewPool() *Pool { } } func (p *Pool) Get() map[string]map[common.Hash]PoolObj { - return p.objList + p.lock.RLock() + defer p.lock.RUnlock() + dataCopy := make(map[string]map[common.Hash]PoolObj, len(p.objList)) + for k1, v1 := range p.objList { + dataCopy[k1] = make(map[common.Hash]PoolObj, len(v1)) + for k2, v2 := range v1 { + dataCopy[k1][k2] = v2.DeepCopy().(PoolObj) + } + } + + return dataCopy } func (p *Pool) Add(obj PoolObj) (int, map[common.Hash]PoolObj) { @@ -36,10 +47,18 @@ func (p *Pool) Add(obj PoolObj) (int, map[common.Hash]PoolObj) { } objListKeyed[obj.Hash()] = obj numOfItems := len(objListKeyed) - return numOfItems, objListKeyed + + dataCopy := make(map[common.Hash]PoolObj, len(objListKeyed)) + for k, v := range objListKeyed { + dataCopy[k] = v.DeepCopy().(PoolObj) + } + + return numOfItems, dataCopy } func (p *Pool) Size(obj PoolObj) int { + p.lock.RLock() + defer p.lock.RUnlock() poolKey := obj.PoolKey() objListKeyed, ok := p.objList[poolKey] if !ok { @@ -84,8 +103,8 @@ func (p *Pool) Clear() { } func (p *Pool) GetObjsByKey(poolKey string) []PoolObj { - p.lock.Lock() - defer p.lock.Unlock() + p.lock.RLock() + defer p.lock.RUnlock() objListKeyed, ok := p.objList[poolKey] if !ok { @@ -94,8 +113,8 @@ func (p *Pool) GetObjsByKey(poolKey string) []PoolObj { objList := make([]PoolObj, len(objListKeyed)) cnt := 0 for _, obj := range objListKeyed { - objList[cnt] = obj - cnt += 1 + objList[cnt] = obj.DeepCopy().(PoolObj) + cnt++ } return objList } diff --git a/core/types/consensus_v2.go b/core/types/consensus_v2.go index 2f573ca354..12f35c9b17 100644 --- a/core/types/consensus_v2.go +++ b/core/types/consensus_v2.go @@ -12,6 +12,12 @@ import ( type Round uint64 type Signature []byte +func (s Signature) DeepCopy() interface{} { + cpy := make([]byte, len(s)) + copy(cpy, s) + return s +} + // Block Info struct in XDPoS 2.0, used for vote message, etc. type BlockInfo struct { Hash common.Hash `json:"hash"` @@ -27,6 +33,20 @@ type Vote struct { GapNumber uint64 `json:"gapNumber"` } +func (v *Vote) DeepCopy() interface{} { + proposedBlockInfoCopy := &BlockInfo{ + Hash: v.ProposedBlockInfo.Hash, + Round: v.ProposedBlockInfo.Round, + Number: new(big.Int).Set(v.ProposedBlockInfo.Number), + } + return &Vote{ + signer: v.signer, + ProposedBlockInfo: proposedBlockInfoCopy, + Signature: v.Signature.DeepCopy().(Signature), + GapNumber: v.GapNumber, + } +} + func (v *Vote) Hash() common.Hash { return rlpHash(v) } @@ -52,6 +72,15 @@ type Timeout struct { GapNumber uint64 } +func (t *Timeout) DeepCopy() interface{} { + return &Timeout{ + signer: t.signer, + Round: t.Round, + Signature: t.Signature.DeepCopy().(Signature), + GapNumber: t.GapNumber, + } +} + func (t *Timeout) Hash() common.Hash { return rlpHash(t) } @@ -75,6 +104,42 @@ type SyncInfo struct { HighestTimeoutCert *TimeoutCert } +func (s *SyncInfo) DeepCopy() interface{} { + var highestQCCopy *QuorumCert + if s.HighestQuorumCert != nil { + sigsCopy := make([]Signature, len(s.HighestQuorumCert.Signatures)) + for i, sig := range s.HighestQuorumCert.Signatures { + sigsCopy[i] = sig.DeepCopy().(Signature) + } + highestQCCopy = &QuorumCert{ + ProposedBlockInfo: &BlockInfo{ + Hash: s.HighestQuorumCert.ProposedBlockInfo.Hash, + Round: s.HighestQuorumCert.ProposedBlockInfo.Round, + Number: new(big.Int).Set(s.HighestQuorumCert.ProposedBlockInfo.Number), + }, + Signatures: sigsCopy, + GapNumber: s.HighestQuorumCert.GapNumber, + } + } + + var highestTimeoutCopy *TimeoutCert + if s.HighestTimeoutCert != nil { + sigsCopy := make([]Signature, len(s.HighestTimeoutCert.Signatures)) + for i, sig := range s.HighestTimeoutCert.Signatures { + sigsCopy[i] = sig.DeepCopy().(Signature) + } + highestTimeoutCopy = &TimeoutCert{ + Round: s.HighestTimeoutCert.Round, + Signatures: sigsCopy, + GapNumber: s.HighestTimeoutCert.GapNumber, + } + } + return &SyncInfo{ + HighestQuorumCert: highestQCCopy, + HighestTimeoutCert: highestTimeoutCopy, + } +} + func (s *SyncInfo) Hash() common.Hash { return rlpHash(s) }