diff --git a/node/api.go b/node/api.go index 3c6528e4d5..362880ceaa 100644 --- a/node/api.go +++ b/node/api.go @@ -187,7 +187,7 @@ func (api *PrivateAdminAPI) StartRPC(host *string, port *int, cors *string, apis } } - if err := api.node.startHTTP(fmt.Sprintf("%s:%d", *host, *port), api.node.rpcAPIs, modules, allowedOrigins, allowedVHosts, api.node.config.HTTPTimeouts); err != nil { + if err := api.node.startHTTP(fmt.Sprintf("%s:%d", *host, *port), api.node.rpcAPIs, modules, allowedOrigins, allowedVHosts, api.node.config.HTTPTimeouts, api.node.config.WSOrigins); err != nil { return false, err } return true, nil diff --git a/node/endpoints.go b/node/endpoints.go new file mode 100644 index 0000000000..277a6644bf --- /dev/null +++ b/node/endpoints.go @@ -0,0 +1,98 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package node + +import ( + "net" + "net/http" + "time" + + "github.com/XinFinOrg/XDPoSChain/log" + "github.com/XinFinOrg/XDPoSChain/rpc" +) + +// StartHTTPEndpoint starts the HTTP RPC endpoint. +func StartHTTPEndpoint(endpoint string, timeouts rpc.HTTPTimeouts, handler http.Handler) (net.Listener, error) { + // start the HTTP listener + var ( + listener net.Listener + err error + ) + if listener, err = net.Listen("tcp", endpoint); err != nil { + return nil, err + } + // Bundle and start the HTTP server + httpSrv := &http.Server{ + Handler: handler, + ReadTimeout: timeouts.ReadTimeout, + WriteTimeout: timeouts.WriteTimeout, + IdleTimeout: timeouts.IdleTimeout, + } + log.Info("StartHTTPEndpoint", "ReadTimeout", timeouts.ReadTimeout, "WriteTimeout", timeouts.WriteTimeout, "IdleTimeout", timeouts.IdleTimeout) + go httpSrv.Serve(listener) + return listener, err +} + +// startWSEndpoint starts a websocket endpoint. +func startWSEndpoint(endpoint string, handler http.Handler) (net.Listener, error) { + // start the HTTP listener + var ( + listener net.Listener + err error + ) + if listener, err = net.Listen("tcp", endpoint); err != nil { + return nil, err + } + wsSrv := &http.Server{Handler: handler} + go wsSrv.Serve(listener) + return listener, err +} + +// checkModuleAvailability checks that all names given in modules are actually +// available API services. It assumes that the MetadataApi module ("rpc") is always available; +// the registration of this "rpc" module happens in NewServer() and is thus common to all endpoints. +func checkModuleAvailability(modules []string, apis []rpc.API) (bad, available []string) { + availableSet := make(map[string]struct{}) + for _, api := range apis { + if _, ok := availableSet[api.Namespace]; !ok { + availableSet[api.Namespace] = struct{}{} + available = append(available, api.Namespace) + } + } + for _, name := range modules { + if _, ok := availableSet[name]; !ok && name != rpc.MetadataApi { + bad = append(bad, name) + } + } + return bad, available +} + +// CheckTimeouts ensures that timeout values are meaningful +func CheckTimeouts(timeouts *rpc.HTTPTimeouts) { + if timeouts.ReadTimeout < time.Second { + log.Warn("Sanitizing invalid HTTP read timeout", "provided", timeouts.ReadTimeout, "updated", rpc.DefaultHTTPTimeouts.ReadTimeout) + timeouts.ReadTimeout = rpc.DefaultHTTPTimeouts.ReadTimeout + } + if timeouts.WriteTimeout < time.Second { + log.Warn("Sanitizing invalid HTTP write timeout", "provided", timeouts.WriteTimeout, "updated", rpc.DefaultHTTPTimeouts.WriteTimeout) + timeouts.WriteTimeout = rpc.DefaultHTTPTimeouts.WriteTimeout + } + if timeouts.IdleTimeout < time.Second { + log.Warn("Sanitizing invalid HTTP idle timeout", "provided", timeouts.IdleTimeout, "updated", rpc.DefaultHTTPTimeouts.IdleTimeout) + timeouts.IdleTimeout = rpc.DefaultHTTPTimeouts.IdleTimeout + } +} diff --git a/node/node.go b/node/node.go index 11813d8865..b53ab4a855 100644 --- a/node/node.go +++ b/node/node.go @@ -269,17 +269,21 @@ func (n *Node) startRPC(services map[reflect.Type]Service) error { n.stopInProc() return err } - if err := n.startHTTP(n.httpEndpoint, apis, n.config.HTTPModules, n.config.HTTPCors, n.config.HTTPVirtualHosts, n.config.HTTPTimeouts); err != nil { + if err := n.startHTTP(n.httpEndpoint, apis, n.config.HTTPModules, n.config.HTTPCors, n.config.HTTPVirtualHosts, n.config.HTTPTimeouts, n.config.WSOrigins); err != nil { n.stopIPC() n.stopInProc() return err } - if err := n.startWS(n.wsEndpoint, apis, n.config.WSModules, n.config.WSOrigins, n.config.WSExposeAll); err != nil { - n.stopHTTP() - n.stopIPC() - n.stopInProc() - return err + // if endpoints are not the same, start separate servers + if n.httpEndpoint != n.wsEndpoint { + if err := n.startWS(n.wsEndpoint, apis, n.config.WSModules, n.config.WSOrigins, n.config.WSExposeAll); err != nil { + n.stopHTTP() + n.stopIPC() + n.stopInProc() + return err + } } + // All API endpoints started successfully n.rpcAPIs = apis return nil @@ -348,22 +352,36 @@ func (n *Node) stopIPC() { } // startHTTP initializes and starts the HTTP RPC endpoint. -func (n *Node) startHTTP(endpoint string, apis []rpc.API, modules []string, cors []string, vhosts []string, timeouts rpc.HTTPTimeouts) error { +func (n *Node) startHTTP(endpoint string, apis []rpc.API, modules []string, cors []string, vhosts []string, timeouts rpc.HTTPTimeouts, wsOrigins []string) error { // Short circuit if the HTTP endpoint isn't being exposed if endpoint == "" { return nil } - listener, handler, err := rpc.StartHTTPEndpoint(endpoint, apis, modules, cors, vhosts, timeouts) + // register apis and create handler stack + srv := rpc.NewServer() + err := RegisterApisFromWhitelist(apis, modules, srv, false) + if err != nil { + return err + } + handler := NewHTTPHandlerStack(srv, cors, vhosts, &timeouts) + // wrap handler in websocket handler only if websocket port is the same as http rpc + if n.httpEndpoint == n.wsEndpoint { + handler = NewWebsocketUpgradeHandler(handler, srv.WebsocketHandler(wsOrigins)) + } + listener, err := StartHTTPEndpoint(endpoint, timeouts, handler) if err != nil { return err } n.log.Info("HTTP endpoint opened", "url", fmt.Sprintf("http://%v/", listener.Addr()), "cors", strings.Join(cors, ","), "vhosts", strings.Join(vhosts, ",")) + if n.httpEndpoint == n.wsEndpoint { + n.log.Info("WebSocket endpoint opened", "url", fmt.Sprintf("ws://%v", listener.Addr())) + } // All listeners booted successfully n.httpEndpoint = endpoint n.httpListener = listener - n.httpHandler = handler + n.httpHandler = srv return nil } @@ -388,7 +406,14 @@ func (n *Node) startWS(endpoint string, apis []rpc.API, modules []string, wsOrig if endpoint == "" { return nil } - listener, handler, err := rpc.StartWSEndpoint(endpoint, apis, modules, wsOrigins, exposeAll) + + srv := rpc.NewServer() + handler := srv.WebsocketHandler(wsOrigins) + err := RegisterApisFromWhitelist(apis, modules, srv, exposeAll) + if err != nil { + return err + } + listener, err := startWSEndpoint(endpoint, handler) if err != nil { return err } @@ -396,7 +421,7 @@ func (n *Node) startWS(endpoint string, apis []rpc.API, modules []string, wsOrig // All listeners booted successfully n.wsEndpoint = endpoint n.wsListener = listener - n.wsHandler = handler + n.wsHandler = srv return nil } @@ -636,3 +661,25 @@ func (n *Node) apis() []rpc.API { }, } } + +// RegisterApisFromWhitelist checks the given modules' availability, generates a whitelist based on the allowed modules, +// and then registers all of the APIs exposed by the services. +func RegisterApisFromWhitelist(apis []rpc.API, modules []string, srv *rpc.Server, exposeAll bool) error { + if bad, available := checkModuleAvailability(modules, apis); len(bad) > 0 { + log.Error("Unavailable modules in HTTP API list", "unavailable", bad, "available", available) + } + // Generate the whitelist based on the allowed modules + whitelist := make(map[string]bool) + for _, module := range modules { + whitelist[module] = true + } + // Register all the APIs exposed by the services + for _, api := range apis { + if exposeAll || whitelist[api.Namespace] || (len(whitelist) == 0 && api.Public) { + if err := srv.RegisterName(api.Namespace, api.Service); err != nil { + return err + } + } + } + return nil +} diff --git a/node/node_test.go b/node/node_test.go index 1dee63a72e..4b4eec464d 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -18,6 +18,7 @@ package node import ( "errors" + "net/http" "os" "reflect" "testing" @@ -26,6 +27,7 @@ import ( "github.com/XinFinOrg/XDPoSChain/crypto" "github.com/XinFinOrg/XDPoSChain/p2p" "github.com/XinFinOrg/XDPoSChain/rpc" + "github.com/stretchr/testify/assert" ) var ( @@ -332,7 +334,7 @@ func TestServiceStartupAbortion(t *testing.T) { } // Tests that even if a registered service fails to shut down cleanly, it does -// not influece the rest of the shutdown invocations. +// not influence the rest of the shutdown invocations. func TestServiceTerminationGuarantee(t *testing.T) { stack, err := New(testNodeConfig()) if err != nil { @@ -572,3 +574,58 @@ func TestAPIGather(t *testing.T) { } } } + +func TestWebsocketHTTPOnSamePort_WebsocketRequest(t *testing.T) { + node := startHTTP(t) + defer node.stopHTTP() + + wsReq, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:7453", nil) + if err != nil { + t.Error("could not issue new http request ", err) + } + wsReq.Header.Set("Connection", "upgrade") + wsReq.Header.Set("Upgrade", "websocket") + wsReq.Header.Set("Sec-WebSocket-Version", "13") + wsReq.Header.Set("Sec-Websocket-Key", "SGVsbG8sIHdvcmxkIQ==") + + resp := doHTTPRequest(t, wsReq) + assert.Equal(t, "websocket", resp.Header.Get("Upgrade")) +} + +func TestWebsocketHTTPOnSamePort_HTTPRequest(t *testing.T) { + node := startHTTP(t) + defer node.stopHTTP() + + httpReq, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:7453", nil) + if err != nil { + t.Error("could not issue new http request ", err) + } + httpReq.Header.Set("Accept-Encoding", "gzip") + + resp := doHTTPRequest(t, httpReq) + assert.Equal(t, "gzip", resp.Header.Get("Content-Encoding")) +} + +func startHTTP(t *testing.T) *Node { + conf := &Config{HTTPPort: 7453, WSPort: 7453} + node, err := New(conf) + if err != nil { + t.Error("could not create a new node ", err) + } + + err = node.startHTTP("127.0.0.1:7453", []rpc.API{}, []string{}, []string{}, []string{}, rpc.HTTPTimeouts{}, []string{}) + if err != nil { + t.Error("could not start http service on node ", err) + } + + return node +} + +func doHTTPRequest(t *testing.T, req *http.Request) *http.Response { + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + t.Error("could not issue a GET request to the given endpoint", err) + } + return resp +} diff --git a/node/rpcstack.go b/node/rpcstack.go new file mode 100644 index 0000000000..0c6ecc4485 --- /dev/null +++ b/node/rpcstack.go @@ -0,0 +1,170 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package node + +import ( + "compress/gzip" + "io" + "net" + "net/http" + "strings" + "sync" + "time" + + "github.com/XinFinOrg/XDPoSChain/log" + "github.com/XinFinOrg/XDPoSChain/rpc" + "github.com/rs/cors" +) + +// NewHTTPHandlerStack returns wrapped http-related handlers +func NewHTTPHandlerStack(srv http.Handler, cors []string, vhosts []string, timeouts *rpc.HTTPTimeouts) http.Handler { + // Wrap the CORS-handler within a host-handler + handler := newCorsHandler(srv, cors) + handler = newVHostHandler(vhosts, handler) + handler = newGzipHandler(handler) + + // make sure timeout values are meaningful + CheckTimeouts(timeouts) + + // PR #469: register timeout handler before WebSocket and HTTP + handler = http.TimeoutHandler(handler, timeouts.WriteTimeout, `{"error":"http server timeout"}`) + + // add 1 second to let TimeoutHandler works first + timeouts.WriteTimeout = timeouts.WriteTimeout + time.Second + return handler +} + +func newCorsHandler(srv http.Handler, allowedOrigins []string) http.Handler { + // disable CORS support if user has not specified a custom CORS configuration + if len(allowedOrigins) == 0 { + return srv + } + c := cors.New(cors.Options{ + AllowedOrigins: allowedOrigins, + AllowedMethods: []string{http.MethodPost, http.MethodGet}, + MaxAge: 600, + AllowedHeaders: []string{"*"}, + }) + return c.Handler(srv) +} + +// virtualHostHandler is a handler which validates the Host-header of incoming requests. +// Using virtual hosts can help prevent DNS rebinding attacks, where a 'random' domain name points to +// the service ip address (but without CORS headers). By verifying the targeted virtual host, we can +// ensure that it's a destination that the node operator has defined. +type virtualHostHandler struct { + vhosts map[string]struct{} + next http.Handler +} + +func newVHostHandler(vhosts []string, next http.Handler) http.Handler { + vhostMap := make(map[string]struct{}) + for _, allowedHost := range vhosts { + vhostMap[strings.ToLower(allowedHost)] = struct{}{} + } + return &virtualHostHandler{vhostMap, next} +} + +// ServeHTTP serves JSON-RPC requests over HTTP, implements http.Handler +func (h *virtualHostHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // if r.Host is not set, we can continue serving since a browser would set the Host header + if r.Host == "" { + h.next.ServeHTTP(w, r) + return + } + host, _, err := net.SplitHostPort(r.Host) + if err != nil { + // Either invalid (too many colons) or no port specified + host = r.Host + } + if ipAddr := net.ParseIP(host); ipAddr != nil { + // It's an IP address, we can serve that + h.next.ServeHTTP(w, r) + return + + } + // Not an IP address, but a hostname. Need to validate + if _, exist := h.vhosts["*"]; exist { + h.next.ServeHTTP(w, r) + return + } + if _, exist := h.vhosts[host]; exist { + h.next.ServeHTTP(w, r) + return + } + http.Error(w, "invalid host specified", http.StatusForbidden) +} + +var gzPool = sync.Pool{ + New: func() interface{} { + w := gzip.NewWriter(io.Discard) + return w + }, +} + +type gzipResponseWriter struct { + io.Writer + http.ResponseWriter +} + +func (w *gzipResponseWriter) WriteHeader(status int) { + w.Header().Del("Content-Length") + w.ResponseWriter.WriteHeader(status) +} + +func (w *gzipResponseWriter) Write(b []byte) (int, error) { + return w.Writer.Write(b) +} + +func newGzipHandler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { + next.ServeHTTP(w, r) + return + } + + w.Header().Set("Content-Encoding", "gzip") + + gz := gzPool.Get().(*gzip.Writer) + defer gzPool.Put(gz) + + gz.Reset(w) + defer gz.Close() + + next.ServeHTTP(&gzipResponseWriter{ResponseWriter: w, Writer: gz}, r) + }) +} + +// NewWebsocketUpgradeHandler returns a websocket handler that serves an incoming request only if it contains an upgrade +// request to the websocket protocol. If not, serves the the request with the http handler. +func NewWebsocketUpgradeHandler(h http.Handler, ws http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if isWebsocket(r) { + ws.ServeHTTP(w, r) + log.Debug("serving websocket request") + return + } + + h.ServeHTTP(w, r) + }) +} + +// isWebsocket checks the header of an http request for a websocket upgrade request. +func isWebsocket(r *http.Request) bool { + return strings.ToLower(r.Header.Get("Upgrade")) == "websocket" && + strings.ToLower(r.Header.Get("Connection")) == "upgrade" +} diff --git a/node/rpcstack_test.go b/node/rpcstack_test.go new file mode 100644 index 0000000000..28bb786858 --- /dev/null +++ b/node/rpcstack_test.go @@ -0,0 +1,38 @@ +package node + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/XinFinOrg/XDPoSChain/rpc" + "github.com/stretchr/testify/assert" +) + +func TestNewWebsocketUpgradeHandler_websocket(t *testing.T) { + srv := rpc.NewServer() + + handler := NewWebsocketUpgradeHandler(nil, srv.WebsocketHandler([]string{})) + ts := httptest.NewServer(handler) + defer ts.Close() + + responses := make(chan *http.Response) + go func(responses chan *http.Response) { + client := &http.Client{} + + req, _ := http.NewRequest(http.MethodGet, ts.URL, nil) + req.Header.Set("Connection", "upgrade") + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Sec-WebSocket-Version", "13") + req.Header.Set("Sec-Websocket-Key", "SGVsbG8sIHdvcmxkIQ==") + + resp, err := client.Do(req) + if err != nil { + t.Error("could not issue a GET request to the test http server", err) + } + responses <- resp + }(responses) + + response := <-responses + assert.Equal(t, "websocket", response.Header.Get("Upgrade")) +} diff --git a/rpc/endpoints.go b/rpc/endpoints.go index eb14e99f6f..0b543102fc 100644 --- a/rpc/endpoints.go +++ b/rpc/endpoints.go @@ -23,89 +23,6 @@ import ( "github.com/XinFinOrg/XDPoSChain/log" ) -// checkModuleAvailability checks that all names given in modules are actually -// available API services. It assumes that the MetadataApi module ("rpc") is always available; -// the registration of this "rpc" module happens in NewServer() and is thus common to all endpoints. -func checkModuleAvailability(modules []string, apis []API) (bad, available []string) { - availableSet := make(map[string]struct{}) - for _, api := range apis { - if _, ok := availableSet[api.Namespace]; !ok { - availableSet[api.Namespace] = struct{}{} - available = append(available, api.Namespace) - } - } - for _, name := range modules { - if _, ok := availableSet[name]; !ok && name != MetadataApi { - bad = append(bad, name) - } - } - return bad, available -} - -// StartHTTPEndpoint starts the HTTP RPC endpoint, configured with cors/vhosts/modules. -func StartHTTPEndpoint(endpoint string, apis []API, modules []string, cors []string, vhosts []string, timeouts HTTPTimeouts) (net.Listener, *Server, error) { - if bad, available := checkModuleAvailability(modules, apis); len(bad) > 0 { - log.Error("Unavailable modules in HTTP API list", "unavailable", bad, "available", available) - } - // Generate the whitelist based on the allowed modules - whitelist := make(map[string]bool) - for _, module := range modules { - whitelist[module] = true - } - // Register all the APIs exposed by the services - handler := NewServer() - for _, api := range apis { - if whitelist[api.Namespace] || (len(whitelist) == 0 && api.Public) { - if err := handler.RegisterName(api.Namespace, api.Service); err != nil { - return nil, nil, err - } - log.Debug("HTTP registered", "namespace", api.Namespace) - } - } - // All APIs registered, start the HTTP listener - var ( - listener net.Listener - err error - ) - if listener, err = net.Listen("tcp", endpoint); err != nil { - return nil, nil, err - } - go NewHTTPServer(cors, vhosts, timeouts, handler).Serve(listener) - return listener, handler, err -} - -// StartWSEndpoint starts a websocket endpoint. -func StartWSEndpoint(endpoint string, apis []API, modules []string, wsOrigins []string, exposeAll bool) (net.Listener, *Server, error) { - if bad, available := checkModuleAvailability(modules, apis); len(bad) > 0 { - log.Error("Unavailable modules in WS API list", "unavailable", bad, "available", available) - } - // Generate the whitelist based on the allowed modules - whitelist := make(map[string]bool) - for _, module := range modules { - whitelist[module] = true - } - // Register all the APIs exposed by the services - handler := NewServer() - for _, api := range apis { - if exposeAll || whitelist[api.Namespace] || (len(whitelist) == 0 && api.Public) { - if err := handler.RegisterName(api.Namespace, api.Service); err != nil { - return nil, nil, err - } - log.Debug("WebSocket registered", "service", api.Service, "namespace", api.Namespace) - } - } - // All APIs registered, start the HTTP listener - var ( - listener net.Listener - err error - ) - if listener, err = net.Listen("tcp", endpoint); err != nil { - return nil, nil, err - } - go NewWSServer(wsOrigins, handler).Serve(listener) - return listener, handler, err -} - // StartIPCEndpoint starts an IPC endpoint. func StartIPCEndpoint(ipcEndpoint string, apis []API) (net.Listener, *Server, error) { // Register all the APIs exposed by the services. @@ -119,7 +36,7 @@ func StartIPCEndpoint(ipcEndpoint string, apis []API) (net.Listener, *Server, er log.Info("IPC registration failed", "namespace", api.Namespace, "error", err) return nil, nil, err } - log.Debug("IPC registered", "service", api.Service, "namespace", api.Namespace) + log.Debug("IPC registered", "namespace", api.Namespace) if _, ok := regMap[api.Namespace]; !ok { registered = append(registered, api.Namespace) regMap[api.Namespace] = struct{}{} diff --git a/rpc/http.go b/rpc/http.go index 2416d66753..ee23c79a1a 100644 --- a/rpc/http.go +++ b/rpc/http.go @@ -24,15 +24,10 @@ import ( "fmt" "io" "mime" - "net" "net/http" "net/url" - "strings" "sync" "time" - - "github.com/XinFinOrg/XDPoSChain/log" - "github.com/rs/cors" ) const ( @@ -237,41 +232,6 @@ func (t *httpServerConn) RemoteAddr() string { // SetWriteDeadline does nothing and always returns nil. func (t *httpServerConn) SetWriteDeadline(time.Time) error { return nil } -// NewHTTPServer creates a new HTTP RPC server around an API provider. -// -// Deprecated: Server implements http.Handler -func NewHTTPServer(cors []string, vhosts []string, timeouts HTTPTimeouts, srv *Server) *http.Server { - // Wrap the CORS-handler within a host-handler - handler := newCorsHandler(srv, cors) - handler = newVHostHandler(vhosts, handler) - - // Make sure timeout values are meaningful - if timeouts.ReadTimeout < time.Second { - log.Warn("Sanitizing invalid HTTP read timeout", "provided", timeouts.ReadTimeout, "updated", DefaultHTTPTimeouts.ReadTimeout) - timeouts.ReadTimeout = DefaultHTTPTimeouts.ReadTimeout - } - if timeouts.WriteTimeout < time.Second { - log.Warn("Sanitizing invalid HTTP write timeout", "provided", timeouts.WriteTimeout, "updated", DefaultHTTPTimeouts.WriteTimeout) - timeouts.WriteTimeout = DefaultHTTPTimeouts.WriteTimeout - } - if timeouts.IdleTimeout < time.Second { - log.Warn("Sanitizing invalid HTTP idle timeout", "provided", timeouts.IdleTimeout, "updated", DefaultHTTPTimeouts.IdleTimeout) - timeouts.IdleTimeout = DefaultHTTPTimeouts.IdleTimeout - } - - // PR #469: return http code 503 and error message to client when timeout - handler = http.TimeoutHandler(handler, timeouts.WriteTimeout, `{"error":"http server timeout"}`) - log.Info("NewHTTPServer", "writeTimeout", timeouts.WriteTimeout) - - // Bundle and start the HTTP server - return &http.Server{ - Handler: handler, - ReadTimeout: timeouts.ReadTimeout, - WriteTimeout: timeouts.WriteTimeout + time.Second, - IdleTimeout: timeouts.IdleTimeout, - } -} - // ServeHTTP serves JSON-RPC requests over HTTP. func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Permit dumb empty requests for remote health-checks (AWS) @@ -328,64 +288,3 @@ func validateRequest(r *http.Request) (int, error) { err := fmt.Errorf("invalid content type, only %s is supported", contentType) return http.StatusUnsupportedMediaType, err } - -func newCorsHandler(srv *Server, allowedOrigins []string) http.Handler { - // disable CORS support if user has not specified a custom CORS configuration - if len(allowedOrigins) == 0 { - return srv - } - c := cors.New(cors.Options{ - AllowedOrigins: allowedOrigins, - AllowedMethods: []string{http.MethodPost, http.MethodGet}, - MaxAge: 600, - AllowedHeaders: []string{"*"}, - }) - return c.Handler(srv) -} - -// virtualHostHandler is a handler which validates the Host-header of incoming requests. -// The virtualHostHandler can prevent DNS rebinding attacks, which do not utilize CORS-headers, -// since they do in-domain requests against the RPC api. Instead, we can see on the Host-header -// which domain was used, and validate that against a whitelist. -type virtualHostHandler struct { - vhosts map[string]struct{} - next http.Handler -} - -func newVHostHandler(vhosts []string, next http.Handler) http.Handler { - vhostMap := make(map[string]struct{}) - for _, allowedHost := range vhosts { - vhostMap[strings.ToLower(allowedHost)] = struct{}{} - } - return &virtualHostHandler{vhostMap, next} -} - -// ServeHTTP serves JSON-RPC requests over HTTP, implements http.Handler -func (h *virtualHostHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // if r.Host is not set, we can continue serving since a browser would set the Host header - if r.Host == "" { - h.next.ServeHTTP(w, r) - return - } - host, _, err := net.SplitHostPort(r.Host) - if err != nil { - // Either invalid (too many colons) or no port specified - host = r.Host - } - if ipAddr := net.ParseIP(host); ipAddr != nil { - // It's an IP address, we can serve that - h.next.ServeHTTP(w, r) - return - - } - // Not an ip address, but a hostname. Need to validate - if _, exist := h.vhosts["*"]; exist { - h.next.ServeHTTP(w, r) - return - } - if _, exist := h.vhosts[host]; exist { - h.next.ServeHTTP(w, r) - return - } - http.Error(w, "invalid host specified", http.StatusForbidden) -} diff --git a/rpc/websocket.go b/rpc/websocket.go index fdc1d05e9f..7c7cae4554 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -64,13 +64,6 @@ func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler { }) } -// NewWSServer creates a new websocket RPC server around an API provider. -// -// Deprecated: use Server.WebsocketHandler -func NewWSServer(allowedOrigins []string, srv *Server) *http.Server { - return &http.Server{Handler: srv.WebsocketHandler(allowedOrigins)} -} - // wsHandshakeValidator returns a handler that verifies the origin during the // websocket upgrade process. When a '*' is specified as an allowed origins all // connections are accepted.