forked from forks/go-ethereum
rpc: add method name length limit (#31711)
This change adds a limit for RPC method names to prevent potential abuse where large method names could lead to large response sizes. The limit is enforced in: - handleCall for regular RPC method calls - handleSubscribe for subscription method calls Added tests in websocket_test.go to verify the length limit functionality for both regular method calls and subscriptions. --------- Co-authored-by: Felix Lange <fjl@twurst.com>
This commit is contained in:
parent
bca0646ede
commit
b135da2eac
3 changed files with 84 additions and 0 deletions
|
|
@ -501,6 +501,10 @@ func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage
|
|||
if msg.isUnsubscribe() {
|
||||
callb = h.unsubscribeCb
|
||||
} else {
|
||||
// Check method name length
|
||||
if len(msg.Method) > maxMethodNameLength {
|
||||
return msg.errorResponse(&invalidRequestError{fmt.Sprintf("method name too long: %d > %d", len(msg.Method), maxMethodNameLength)})
|
||||
}
|
||||
callb = h.reg.callback(msg.Method)
|
||||
}
|
||||
if callb == nil {
|
||||
|
|
@ -536,6 +540,11 @@ func (h *handler) handleSubscribe(cp *callProc, msg *jsonrpcMessage) *jsonrpcMes
|
|||
return msg.errorResponse(ErrNotificationsUnsupported)
|
||||
}
|
||||
|
||||
// Check method name length
|
||||
if len(msg.Method) > maxMethodNameLength {
|
||||
return msg.errorResponse(&invalidRequestError{fmt.Sprintf("subscription name too long: %d > %d", len(msg.Method), maxMethodNameLength)})
|
||||
}
|
||||
|
||||
// Subscription method name is first argument.
|
||||
name, err := parseSubscriptionName(msg.Params)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ const (
|
|||
subscribeMethodSuffix = "_subscribe"
|
||||
unsubscribeMethodSuffix = "_unsubscribe"
|
||||
notificationMethodSuffix = "_subscription"
|
||||
maxMethodNameLength = 2048
|
||||
|
||||
defaultWriteTimeout = 10 * time.Second // used if context has no deadline
|
||||
)
|
||||
|
|
|
|||
|
|
@ -391,3 +391,77 @@ func wsPingTestHandler(t *testing.T, conn *websocket.Conn, shutdown, sendPing <-
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebsocketMethodNameLengthLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
srv = newTestServer()
|
||||
httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"}))
|
||||
wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:")
|
||||
)
|
||||
defer srv.Stop()
|
||||
defer httpsrv.Close()
|
||||
|
||||
client, err := DialWebsocket(context.Background(), wsURL, "")
|
||||
if err != nil {
|
||||
t.Fatalf("can't dial: %v", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Test cases
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
params []interface{}
|
||||
expectedError string
|
||||
isSubscription bool
|
||||
}{
|
||||
{
|
||||
name: "valid method name",
|
||||
method: "test_echo",
|
||||
params: []interface{}{"test", 1},
|
||||
expectedError: "",
|
||||
isSubscription: false,
|
||||
},
|
||||
{
|
||||
name: "method name too long",
|
||||
method: "test_" + string(make([]byte, maxMethodNameLength+1)),
|
||||
params: []interface{}{"test", 1},
|
||||
expectedError: "method name too long",
|
||||
isSubscription: false,
|
||||
},
|
||||
{
|
||||
name: "valid subscription",
|
||||
method: "nftest_subscribe",
|
||||
params: []interface{}{"someSubscription", 1, 2},
|
||||
expectedError: "",
|
||||
isSubscription: true,
|
||||
},
|
||||
{
|
||||
name: "subscription name too long",
|
||||
method: string(make([]byte, maxMethodNameLength+1)) + "_subscribe",
|
||||
params: []interface{}{"newHeads"},
|
||||
expectedError: "subscription name too long",
|
||||
isSubscription: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var result interface{}
|
||||
err := client.Call(&result, tt.method, tt.params...)
|
||||
if tt.expectedError == "" {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
} else {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
} else if !strings.Contains(err.Error(), tt.expectedError) {
|
||||
t.Errorf("expected error containing %q, got %q", tt.expectedError, err.Error())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue