diff --git a/node/endpoints.go b/node/endpoints.go index 277a6644bf..234b49e73d 100644 --- a/node/endpoints.go +++ b/node/endpoints.go @@ -26,14 +26,14 @@ import ( ) // StartHTTPEndpoint starts the HTTP RPC endpoint. -func StartHTTPEndpoint(endpoint string, timeouts rpc.HTTPTimeouts, handler http.Handler) (net.Listener, error) { +func StartHTTPEndpoint(endpoint string, timeouts rpc.HTTPTimeouts, handler http.Handler) (*http.Server, net.Addr, error) { // start the HTTP listener var ( listener net.Listener err error ) if listener, err = net.Listen("tcp", endpoint); err != nil { - return nil, err + return nil, nil, err } // Bundle and start the HTTP server httpSrv := &http.Server{ @@ -44,22 +44,22 @@ func StartHTTPEndpoint(endpoint string, timeouts rpc.HTTPTimeouts, handler http. } log.Info("StartHTTPEndpoint", "ReadTimeout", timeouts.ReadTimeout, "WriteTimeout", timeouts.WriteTimeout, "IdleTimeout", timeouts.IdleTimeout) go httpSrv.Serve(listener) - return listener, err + return httpSrv, listener.Addr(), err } // startWSEndpoint starts a websocket endpoint. -func startWSEndpoint(endpoint string, handler http.Handler) (net.Listener, error) { +func startWSEndpoint(endpoint string, handler http.Handler) (*http.Server, net.Addr, error) { // start the HTTP listener var ( listener net.Listener err error ) if listener, err = net.Listen("tcp", endpoint); err != nil { - return nil, err + return nil, nil, err } wsSrv := &http.Server{Handler: handler} go wsSrv.Serve(listener) - return listener, err + return wsSrv, listener.Addr(), err } // checkModuleAvailability checks that all names given in modules are actually diff --git a/node/node.go b/node/node.go index b53ab4a855..dfa261640c 100644 --- a/node/node.go +++ b/node/node.go @@ -17,9 +17,11 @@ package node import ( + "context" "errors" "fmt" "net" + "net/http" "os" "path/filepath" "reflect" @@ -60,14 +62,16 @@ type Node struct { ipcListener net.Listener // IPC RPC listener socket to serve API requests ipcHandler *rpc.Server // IPC RPC request handler to process the API requests - httpEndpoint string // HTTP endpoint (interface + port) to listen at (empty = HTTP disabled) - httpWhitelist []string // HTTP RPC modules to allow through this endpoint - httpListener net.Listener // HTTP RPC listener socket to server API requests - httpHandler *rpc.Server // HTTP RPC request handler to process the API requests + httpEndpoint string // HTTP endpoint (interface + port) to listen at (empty = HTTP disabled) + httpWhitelist []string // HTTP RPC modules to allow through this endpoint + httpListenerAddr net.Addr // Address of HTTP RPC listener socket serving API requests + httpServer *http.Server // HTTP RPC HTTP server + httpHandler *rpc.Server // HTTP RPC request handler to process the API requests - wsEndpoint string // Websocket endpoint (interface + port) to listen at (empty = websocket disabled) - wsListener net.Listener // Websocket RPC listener socket to server API requests - wsHandler *rpc.Server // Websocket RPC request handler to process the API requests + wsEndpoint string // WebSocket endpoint (interface + port) to listen at (empty = WebSocket disabled) + wsListenerAddr net.Addr // Address of WebSocket RPC listener socket serving API requests + wsHTTPServer *http.Server // WebSocket RPC HTTP server + wsHandler *rpc.Server // WebSocket RPC request handler to process the API requests stop chan struct{} // Channel to wait for termination notifications lock sync.RWMutex @@ -364,23 +368,24 @@ func (n *Node) startHTTP(endpoint string, apis []rpc.API, modules []string, cors return err } handler := NewHTTPHandlerStack(srv, cors, vhosts, &timeouts) - // wrap handler in websocket handler only if websocket port is the same as http rpc + // wrap handler in WebSocket handler only if WebSocket port is the same as http rpc if n.httpEndpoint == n.wsEndpoint { handler = NewWebsocketUpgradeHandler(handler, srv.WebsocketHandler(wsOrigins)) } - listener, err := StartHTTPEndpoint(endpoint, timeouts, handler) + httpServer, addr, err := StartHTTPEndpoint(endpoint, timeouts, handler) if err != nil { return err } - n.log.Info("HTTP endpoint opened", "url", fmt.Sprintf("http://%v/", listener.Addr()), + n.log.Info("HTTP endpoint opened", "url", fmt.Sprintf("http://%v/", addr), "cors", strings.Join(cors, ","), "vhosts", strings.Join(vhosts, ",")) if n.httpEndpoint == n.wsEndpoint { - n.log.Info("WebSocket endpoint opened", "url", fmt.Sprintf("ws://%v", listener.Addr())) + n.log.Info("WebSocket endpoint opened", "url", fmt.Sprintf("ws://%v", addr)) } // All listeners booted successfully n.httpEndpoint = endpoint - n.httpListener = listener + n.httpListenerAddr = addr + n.httpServer = httpServer n.httpHandler = srv return nil @@ -388,11 +393,10 @@ func (n *Node) startHTTP(endpoint string, apis []rpc.API, modules []string, cors // stopHTTP terminates the HTTP RPC endpoint. func (n *Node) stopHTTP() { - if n.httpListener != nil { - url := fmt.Sprintf("http://%v/", n.httpListener.Addr()) - n.httpListener.Close() - n.httpListener = nil - n.log.Info("HTTP endpoint closed", "url", url) + if n.httpServer != nil { + // Don't bother imposing a timeout here. + n.httpServer.Shutdown(context.Background()) + n.log.Info("HTTP endpoint closed", "url", fmt.Sprintf("http://%v/", n.httpListenerAddr)) } if n.httpHandler != nil { n.httpHandler.Stop() @@ -400,7 +404,7 @@ func (n *Node) stopHTTP() { } } -// startWS initializes and starts the websocket RPC endpoint. +// startWS initializes and starts the WebSocket RPC endpoint. func (n *Node) startWS(endpoint string, apis []rpc.API, modules []string, wsOrigins []string, exposeAll bool) error { // Short circuit if the WS endpoint isn't being exposed if endpoint == "" { @@ -413,26 +417,26 @@ func (n *Node) startWS(endpoint string, apis []rpc.API, modules []string, wsOrig if err != nil { return err } - listener, err := startWSEndpoint(endpoint, handler) + httpServer, addr, err := startWSEndpoint(endpoint, handler) if err != nil { return err } - n.log.Info("WebSocket endpoint opened", "url", fmt.Sprintf("ws://%s", listener.Addr())) + n.log.Info("WebSocket endpoint opened", "url", fmt.Sprintf("ws://%v", addr)) // All listeners booted successfully n.wsEndpoint = endpoint - n.wsListener = listener + n.wsListenerAddr = addr + n.wsHTTPServer = httpServer n.wsHandler = srv return nil } -// stopWS terminates the websocket RPC endpoint. +// stopWS terminates the WebSocket RPC endpoint. func (n *Node) stopWS() { - if n.wsListener != nil { - n.wsListener.Close() - n.wsListener = nil - - n.log.Info("WebSocket endpoint closed", "url", fmt.Sprintf("ws://%s", n.wsEndpoint)) + if n.wsHTTPServer != nil { + // Don't bother imposing a timeout here. + n.wsHTTPServer.Shutdown(context.Background()) + n.log.Info("WebSocket endpoint closed", "url", fmt.Sprintf("ws://%v", n.wsListenerAddr)) } if n.wsHandler != nil { n.wsHandler.Stop() @@ -599,8 +603,8 @@ func (n *Node) HTTPEndpoint() string { n.lock.Lock() defer n.lock.Unlock() - if n.httpListener != nil { - return n.httpListener.Addr().String() + if n.httpListenerAddr != nil { + return n.httpListenerAddr.String() } return n.httpEndpoint } @@ -610,8 +614,8 @@ func (n *Node) WSEndpoint() string { n.lock.Lock() defer n.lock.Unlock() - if n.wsListener != nil { - return n.wsListener.Addr().String() + if n.wsListenerAddr != nil { + return n.wsListenerAddr.String() } return n.wsEndpoint }