rpc: add SetWebsocketReadLimit in Server (#32279)
Some checks are pending
/ Linux Build (push) Waiting to run
/ Linux Build (arm) (push) Waiting to run
/ Windows Build (push) Waiting to run
/ Docker Image (push) Waiting to run

Exposing the public method to setReadLimits for Websocket RPC to
prevent OOM.

Current, Geth Server is using a default 32MB max read limit (message
size) for websocket, which is prune to being attacked for OOM. Any one
can easily launch a client to send a bunch of concurrent large request
to cause the node to crash for OOM. One example of such script that can
easily crash a Geth node running websocket server is like this:

ec830979ac/poc.go

---------

Co-authored-by: Felix Lange <fjl@twurst.com>
This commit is contained in:
Yiming Zang 2025-08-18 23:32:59 -07:00 committed by GitHub
parent 42bf4844d8
commit d93f820358
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 98 additions and 1 deletions

View file

@ -54,6 +54,7 @@ type Server struct {
batchItemLimit int batchItemLimit int
batchResponseLimit int batchResponseLimit int
httpBodyLimit int httpBodyLimit int
wsReadLimit int64
} }
// NewServer creates a new server instance with no registered handlers. // NewServer creates a new server instance with no registered handlers.
@ -62,6 +63,7 @@ func NewServer() *Server {
idgen: randomIDGenerator(), idgen: randomIDGenerator(),
codecs: make(map[ServerCodec]struct{}), codecs: make(map[ServerCodec]struct{}),
httpBodyLimit: defaultBodyLimit, httpBodyLimit: defaultBodyLimit,
wsReadLimit: wsDefaultReadLimit,
} }
server.run.Store(true) server.run.Store(true)
// Register the default service providing meta information about the RPC service such // Register the default service providing meta information about the RPC service such
@ -89,6 +91,13 @@ func (s *Server) SetHTTPBodyLimit(limit int) {
s.httpBodyLimit = limit s.httpBodyLimit = limit
} }
// SetWebsocketReadLimit sets the limit for max message size for Websocket requests.
//
// This method should be called before processing any requests via Websocket server.
func (s *Server) SetWebsocketReadLimit(limit int64) {
s.wsReadLimit = limit
}
// RegisterName creates a service for the given receiver type under the given name. When no // RegisterName creates a service for the given receiver type under the given name. When no
// methods on the given receiver match the criteria to be either an RPC method or a // methods on the given receiver match the criteria to be either an RPC method or a
// subscription an error is returned. Otherwise a new service is created and added to the // subscription an error is returned. Otherwise a new service is created and added to the

View file

@ -19,13 +19,18 @@ package rpc
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"context"
"errors"
"io" "io"
"net" "net"
"net/http/httptest"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/gorilla/websocket"
) )
func TestServerRegisterName(t *testing.T) { func TestServerRegisterName(t *testing.T) {
@ -202,3 +207,86 @@ func TestServerBatchResponseSizeLimit(t *testing.T) {
} }
} }
} }
func TestServerWebsocketReadLimit(t *testing.T) {
t.Parallel()
// Test different read limits
testCases := []struct {
name string
readLimit int64
testSize int
shouldFail bool
}{
{
name: "limit with small request - should succeed",
readLimit: 4096, // generous limit to comfortably allow JSON overhead
testSize: 256, // reasonably small payload
shouldFail: false,
},
{
name: "limit with large request - should fail",
readLimit: 256, // tight limit to trigger server-side read limit
testSize: 1024, // payload that will exceed the limit including JSON overhead
shouldFail: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Create server and set read limits
srv := newTestServer()
srv.SetWebsocketReadLimit(tc.readLimit)
defer srv.Stop()
// Start HTTP server with WebSocket handler
httpsrv := httptest.NewServer(srv.WebsocketHandler([]string{"*"}))
defer httpsrv.Close()
wsURL := "ws:" + strings.TrimPrefix(httpsrv.URL, "http:")
// Connect WebSocket client
client, err := DialOptions(context.Background(), wsURL)
if err != nil {
t.Fatalf("can't dial: %v", err)
}
defer client.Close()
// Create large request data - this is what will be limited
largeString := strings.Repeat("A", tc.testSize)
// Send the large string as a parameter in the request
var result echoResult
err = client.Call(&result, "test_echo", largeString, 42, &echoArgs{S: "test"})
if tc.shouldFail {
// Expecting an error due to read limit exceeded
if err == nil {
t.Fatalf("expected error for request size %d with limit %d, but got none", tc.testSize, tc.readLimit)
}
// Be tolerant about the exact error surfaced by gorilla/websocket.
// Prefer a CloseError with code 1009, but accept ErrReadLimit or an error string containing 1009/message too big.
var cerr *websocket.CloseError
if errors.As(err, &cerr) {
if cerr.Code != websocket.CloseMessageTooBig {
t.Fatalf("unexpected websocket close code: have %d want %d (err=%v)", cerr.Code, websocket.CloseMessageTooBig, err)
}
} else if !errors.Is(err, websocket.ErrReadLimit) &&
!strings.Contains(strings.ToLower(err.Error()), "1009") &&
!strings.Contains(strings.ToLower(err.Error()), "message too big") {
// Not the error we expect from exceeding the message size limit.
t.Fatalf("unexpected error for read limit violation: %v", err)
}
} else {
// Expecting success
if err != nil {
t.Fatalf("unexpected error for request size %d with limit %d: %v", tc.testSize, tc.readLimit, err)
}
// Verify the response is correct - the echo should return our string
if result.String != largeString {
t.Fatalf("expected echo result to match input")
}
}
})
}
}

View file

@ -60,7 +60,7 @@ func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
log.Debug("WebSocket upgrade failed", "err", err) log.Debug("WebSocket upgrade failed", "err", err)
return return
} }
codec := newWebsocketCodec(conn, r.Host, r.Header, wsDefaultReadLimit) codec := newWebsocketCodec(conn, r.Host, r.Header, s.wsReadLimit)
s.ServeCodec(codec, 0) s.ServeCodec(codec, 0)
}) })
} }