diff --git a/rpc/server.go b/rpc/server.go index 5c1670e16f..da614b242e 100644 --- a/rpc/server.go +++ b/rpc/server.go @@ -19,10 +19,10 @@ package rpc import ( "context" "io" + "sync" "sync/atomic" "github.com/XinFinOrg/XDPoSChain/log" - mapset "github.com/deckarep/golang-set/v2" ) const MetadataApi = "rpc" @@ -44,13 +44,19 @@ const ( type Server struct { services serviceRegistry idgen func() ID - run int32 - codecs mapset.Set[*ServerCodec] + + mutex sync.Mutex + codecs map[ServerCodec]struct{} + run int32 } // NewServer creates a new server instance with no registered handlers. func NewServer() *Server { - server := &Server{idgen: randomIDGenerator(), codecs: mapset.NewSet[*ServerCodec](), run: 1} + server := &Server{ + idgen: randomIDGenerator(), + codecs: make(map[ServerCodec]struct{}), + run: 1, + } // Register the default service providing meta information about the RPC service such // as the services and methods it offers. rpcService := &RPCService{server} @@ -74,20 +80,34 @@ func (s *Server) RegisterName(name string, receiver interface{}) error { func (s *Server) ServeCodec(codec ServerCodec, options CodecOption) { defer codec.close() - // Don't serve if server is stopped. - if atomic.LoadInt32(&s.run) == 0 { + if !s.trackCodec(codec) { return } - - // Add the codec to the set so it can be closed by Stop. - s.codecs.Add(&codec) - defer s.codecs.Remove(&codec) + defer s.untrackCodec(codec) c := initClient(codec, s.idgen, &s.services) <-codec.closed() c.Close() } +func (s *Server) trackCodec(codec ServerCodec) bool { + s.mutex.Lock() + defer s.mutex.Unlock() + + if atomic.LoadInt32(&s.run) == 0 { + return false // Don't serve if server is stopped. + } + s.codecs[codec] = struct{}{} + return true +} + +func (s *Server) untrackCodec(codec ServerCodec) { + s.mutex.Lock() + defer s.mutex.Unlock() + + delete(s.codecs, codec) +} + // serveSingleRequest reads and processes a single RPC request from the given codec. This // is used to serve HTTP connections. Subscriptions and reverse calls are not allowed in // this mode. @@ -119,12 +139,14 @@ func (s *Server) serveSingleRequest(ctx context.Context, codec ServerCodec) { // requests to finish, then closes all codecs which will cancel pending requests and // subscriptions. func (s *Server) Stop() { + s.mutex.Lock() + defer s.mutex.Unlock() + if atomic.CompareAndSwapInt32(&s.run, 1, 0) { log.Debug("RPC server shutting down") - s.codecs.Each(func(c *ServerCodec) bool { - (*c).close() - return true - }) + for codec := range s.codecs { + codec.close() + } } }