From b135da2eac9bd3beb043f4b11418c09034fb9cc3 Mon Sep 17 00:00:00 2001 From: Matus Kysel Date: Mon, 5 May 2025 14:43:47 +0200 Subject: [PATCH] rpc: add method name length limit (#31711) This change adds a limit for RPC method names to prevent potential abuse where large method names could lead to large response sizes. The limit is enforced in: - handleCall for regular RPC method calls - handleSubscribe for subscription method calls Added tests in websocket_test.go to verify the length limit functionality for both regular method calls and subscriptions. --------- Co-authored-by: Felix Lange --- rpc/handler.go | 9 ++++++ rpc/json.go | 1 + rpc/websocket_test.go | 74 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 84 insertions(+) diff --git a/rpc/handler.go b/rpc/handler.go index f23b544b58..45558d5821 100644 --- a/rpc/handler.go +++ b/rpc/handler.go @@ -501,6 +501,10 @@ func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage if msg.isUnsubscribe() { callb = h.unsubscribeCb } else { + // Check method name length + if len(msg.Method) > maxMethodNameLength { + return msg.errorResponse(&invalidRequestError{fmt.Sprintf("method name too long: %d > %d", len(msg.Method), maxMethodNameLength)}) + } callb = h.reg.callback(msg.Method) } if callb == nil { @@ -536,6 +540,11 @@ func (h *handler) handleSubscribe(cp *callProc, msg *jsonrpcMessage) *jsonrpcMes return msg.errorResponse(ErrNotificationsUnsupported) } + // Check method name length + if len(msg.Method) > maxMethodNameLength { + return msg.errorResponse(&invalidRequestError{fmt.Sprintf("subscription name too long: %d > %d", len(msg.Method), maxMethodNameLength)}) + } + // Subscription method name is first argument. name, err := parseSubscriptionName(msg.Params) if err != nil { diff --git a/rpc/json.go b/rpc/json.go index e932389d17..fcd801fc95 100644 --- a/rpc/json.go +++ b/rpc/json.go @@ -35,6 +35,7 @@ const ( subscribeMethodSuffix = "_subscribe" unsubscribeMethodSuffix = "_unsubscribe" notificationMethodSuffix = "_subscription" + maxMethodNameLength = 2048 defaultWriteTimeout = 10 * time.Second // used if context has no deadline ) diff --git a/rpc/websocket_test.go b/rpc/websocket_test.go index 10a998b351..a8d8624900 100644 --- a/rpc/websocket_test.go +++ b/rpc/websocket_test.go @@ -391,3 +391,77 @@ func wsPingTestHandler(t *testing.T, conn *websocket.Conn, shutdown, sendPing <- } } } + +func TestWebsocketMethodNameLengthLimit(t *testing.T) { + t.Parallel() + + var ( + srv = newTestServer() + httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"})) + wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") + ) + defer srv.Stop() + defer httpsrv.Close() + + client, err := DialWebsocket(context.Background(), wsURL, "") + if err != nil { + t.Fatalf("can't dial: %v", err) + } + defer client.Close() + + // Test cases + tests := []struct { + name string + method string + params []interface{} + expectedError string + isSubscription bool + }{ + { + name: "valid method name", + method: "test_echo", + params: []interface{}{"test", 1}, + expectedError: "", + isSubscription: false, + }, + { + name: "method name too long", + method: "test_" + string(make([]byte, maxMethodNameLength+1)), + params: []interface{}{"test", 1}, + expectedError: "method name too long", + isSubscription: false, + }, + { + name: "valid subscription", + method: "nftest_subscribe", + params: []interface{}{"someSubscription", 1, 2}, + expectedError: "", + isSubscription: true, + }, + { + name: "subscription name too long", + method: string(make([]byte, maxMethodNameLength+1)) + "_subscribe", + params: []interface{}{"newHeads"}, + expectedError: "subscription name too long", + isSubscription: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var result interface{} + err := client.Call(&result, tt.method, tt.params...) + if tt.expectedError == "" { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + } else { + if err == nil { + t.Error("expected error, got nil") + } else if !strings.Contains(err.Error(), tt.expectedError) { + t.Errorf("expected error containing %q, got %q", tt.expectedError, err.Error()) + } + } + }) + } +}