diff --git a/rpc/websocket_test.go b/rpc/websocket_test.go index a8d8624900..3b7d5a9da0 100644 --- a/rpc/websocket_test.go +++ b/rpc/websocket_test.go @@ -121,56 +121,49 @@ func TestWebsocketLargeRead(t *testing.T) { srv = newTestServer() httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"})) wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") + buffer = 64 ) defer srv.Stop() defer httpsrv.Close() - testLimit := func(limit *int64) { - opts := []ClientOption{} - expLimit := int64(wsDefaultReadLimit) - if limit != nil && *limit >= 0 { - opts = append(opts, WithWebsocketMessageSizeLimit(*limit)) - if *limit > 0 { - expLimit = *limit // 0 means infinite + for _, tt := range []struct { + size int + limit int + err bool + }{ + {200, 200, false}, // Small, successful request and limit + {2048, 1024, true}, // Normal, failed request + {wsDefaultReadLimit + buffer, 0, false}, // Large, successful request, infinite limit + } { + func() { + if tt.limit != 0 { + // Some buffer is added to the limit to account for JSON encoding. It's + // skipped when the limit is zero since the intention is for the limit + // to be infinite. + tt.limit += buffer } - } - client, err := DialOptions(context.Background(), wsURL, opts...) - if err != nil { - t.Fatalf("can't dial: %v", err) - } - defer client.Close() - // Remove some bytes for json encoding overhead. - underLimit := int(expLimit - 128) - overLimit := expLimit + 1 - if expLimit == wsDefaultReadLimit { - // No point trying the full 32MB in tests. Just sanity-check that - // it's not obviously limited. - underLimit = 1024 - overLimit = -1 - } - var res string - // Check under limit - if err = client.Call(&res, "test_repeat", "A", underLimit); err != nil { - t.Fatalf("unexpected error with limit %d: %v", expLimit, err) - } - if len(res) != underLimit || strings.Count(res, "A") != underLimit { - t.Fatal("incorrect data") - } - // Check over limit - if overLimit > 0 { - err = client.Call(&res, "test_repeat", "A", expLimit+1) - if err == nil || err != websocket.ErrReadLimit { - t.Fatalf("wrong error with limit %d: %v expecting %v", expLimit, err, websocket.ErrReadLimit) + opts := []ClientOption{WithWebsocketMessageSizeLimit(int64(tt.limit))} + client, err := DialOptions(context.Background(), wsURL, opts...) + if err != nil { + t.Fatalf("failed to dial test server: %v", err) } - } - } - ptr := func(v int64) *int64 { return &v } + defer client.Close() - testLimit(ptr(-1)) // Should be ignored (use default) - testLimit(ptr(0)) // Should be ignored (use default) - testLimit(nil) // Should be ignored (use default) - testLimit(ptr(200)) - testLimit(ptr(wsDefaultReadLimit * 2)) + var res string + err = client.Call(&res, "test_repeat", "A", tt.size) + if tt.err && err == nil { + t.Fatalf("expected error, got none") + } + if !tt.err { + if err != nil { + t.Fatalf("unexpected error with limit %d: %v", tt.limit, err) + } + if strings.Count(res, "A") != tt.size { + t.Fatal("incorrect data") + } + } + }() + } } func TestWebsocketPeerInfo(t *testing.T) {