diff --git a/node/rpcstack_test.go b/node/rpcstack_test.go index 6602aebd4f..a39ae936b5 100644 --- a/node/rpcstack_test.go +++ b/node/rpcstack_test.go @@ -6,8 +6,6 @@ import ( "io" "net/http" "net/http/httptest" - "net/url" - "strconv" "strings" "testing" "time" @@ -134,14 +132,13 @@ func TestWebsocketOrigins(t *testing.T) { } for _, tc := range tests { srv := createAndStartServer(t, &httpConfig{}, true, &wsConfig{Origins: splitAndTrim(tc.spec)}, nil) - url := fmt.Sprintf("ws://%v", srv.listenAddr()) for _, origin := range tc.expOk { - if err := wsRequest(t, url, "Origin", origin); err != nil { + if err := attemptWebsocketConnectionFromOrigin(t, srv, origin); err != nil { t.Errorf("spec '%v', origin '%v': expected ok, got %v", tc.spec, origin, err) } } for _, origin := range tc.expFail { - if err := wsRequest(t, url, "Origin", origin); err == nil { + if err := attemptWebsocketConnectionFromOrigin(t, srv, origin); err == nil { t.Errorf("spec '%v', origin '%v': expected not to allow, got ok", tc.spec, origin) } } @@ -149,61 +146,6 @@ func TestWebsocketOrigins(t *testing.T) { } } -func Test_checkPath(t *testing.T) { - tests := []struct { - req *http.Request - prefix string - expected bool - }{ - { - req: &http.Request{URL: &url.URL{Path: "/test"}}, - prefix: "/test", - expected: true, - }, - { - req: &http.Request{URL: &url.URL{Path: "/testing"}}, - prefix: "/test", - expected: true, - }, - { - req: &http.Request{URL: &url.URL{Path: "/"}}, - prefix: "/test", - expected: false, - }, - { - req: &http.Request{URL: &url.URL{Path: "/fail"}}, - prefix: "/test", - expected: false, - }, - { - req: &http.Request{URL: &url.URL{Path: "/"}}, - prefix: "", - expected: true, - }, - { - req: &http.Request{URL: &url.URL{Path: "/fail"}}, - prefix: "", - expected: false, - }, - { - req: &http.Request{URL: &url.URL{Path: "/"}}, - prefix: "/", - expected: true, - }, - { - req: &http.Request{URL: &url.URL{Path: "/testing"}}, - prefix: "/", - expected: true, - }, - } - - for i, tt := range tests { - t.Run(strconv.Itoa(i), func(t *testing.T) { - assert.Equal(t, tt.expected, checkPath(tt.req, tt.prefix)) - }) - } -} - func createAndStartServer(t *testing.T, conf *httpConfig, ws bool, wsConf *wsConfig, timeouts *rpc.HTTPTimeouts) *httpServer { t.Helper() @@ -221,6 +163,17 @@ func createAndStartServer(t *testing.T, conf *httpConfig, ws bool, wsConf *wsCon return srv } +func attemptWebsocketConnectionFromOrigin(t *testing.T, srv *httpServer, browserOrigin string) error { + t.Helper() + dialer := websocket.DefaultDialer + _, _, err := dialer.Dial("ws://"+srv.listenAddr(), http.Header{ + "Content-type": []string{"application/json"}, + "Sec-WebSocket-Version": []string{"13"}, + "Origin": []string{browserOrigin}, + }) + return err +} + // wsRequest attempts to open a WebSocket connection to the given URL. func wsRequest(t *testing.T, url string, extraHeaders ...string) error { t.Helper()