rpc: add limit for batch request and response size #26681 (#998)

This commit is contained in:
Daniel Liu 2025-04-28 17:00:30 +08:00 committed by GitHub
parent 57c40154be
commit c75623ace7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 671 additions and 279 deletions

View file

@ -166,6 +166,8 @@ var (
utils.IPCPathFlag,
utils.RPCGlobalTxFeeCap,
utils.AllowUnprotectedTxs,
utils.BatchRequestLimit,
utils.BatchResponseMaxSize,
}
metricsFlags = []cli.Flag{

View file

@ -570,10 +570,22 @@ var (
Category: flags.APICategory,
}
AllowUnprotectedTxs = &cli.BoolFlag{
Name: "rpc.allow-unprotected-txs",
Name: "rpc-allow-unprotected-txs",
Usage: "Allow for unprotected (non EIP155 signed) transactions to be submitted via RPC",
Category: flags.APICategory,
}
BatchRequestLimit = &cli.IntFlag{
Name: "rpc-batch-request-limit",
Usage: "Maximum number of requests in a batch",
Value: node.DefaultConfig.BatchRequestLimit,
Category: flags.APICategory,
}
BatchResponseMaxSize = &cli.IntFlag{
Name: "rpc-batch-response-max-size",
Usage: "Maximum number of bytes returned from a batched call",
Value: node.DefaultConfig.BatchResponseMaxSize,
Category: flags.APICategory,
}
// Network Settings
MaxPeersFlag = &cli.IntFlag{
@ -1027,15 +1039,22 @@ func setHTTP(ctx *cli.Context, cfg *node.Config) {
if ctx.IsSet(HTTPPortFlag.Name) {
cfg.HTTPPort = ctx.Int(HTTPPortFlag.Name)
}
if ctx.IsSet(AuthHostFlag.Name) {
cfg.AuthHost = ctx.String(AuthHostFlag.Name)
}
if ctx.IsSet(AuthPortFlag.Name) {
cfg.AuthPort = ctx.Int(AuthPortFlag.Name)
}
cfg.HTTPCors = SplitAndTrim(ctx.String(HTTPCORSDomainFlag.Name))
cfg.HTTPModules = SplitAndTrim(ctx.String(HTTPApiFlag.Name))
cfg.HTTPVirtualHosts = SplitAndTrim(ctx.String(HTTPVirtualHostsFlag.Name))
if ctx.IsSet(HTTPPathPrefixFlag.Name) {
cfg.HTTPPathPrefix = ctx.String(HTTPPathPrefixFlag.Name)
}
if ctx.IsSet(HTTPReadTimeoutFlag.Name) {
cfg.HTTPTimeouts.ReadTimeout = ctx.Duration(HTTPReadTimeoutFlag.Name)
}
@ -1048,12 +1067,17 @@ func setHTTP(ctx *cli.Context, cfg *node.Config) {
if ctx.IsSet(HTTPIdleTimeoutFlag.Name) {
cfg.HTTPTimeouts.IdleTimeout = ctx.Duration(HTTPIdleTimeoutFlag.Name)
}
if ctx.IsSet(AllowUnprotectedTxs.Name) {
cfg.AllowUnprotectedTxs = ctx.Bool(AllowUnprotectedTxs.Name)
}
cfg.HTTPCors = SplitAndTrim(ctx.String(HTTPCORSDomainFlag.Name))
cfg.HTTPModules = SplitAndTrim(ctx.String(HTTPApiFlag.Name))
cfg.HTTPVirtualHosts = SplitAndTrim(ctx.String(HTTPVirtualHostsFlag.Name))
if ctx.IsSet(BatchRequestLimit.Name) {
cfg.BatchRequestLimit = ctx.Int(BatchRequestLimit.Name)
}
if ctx.IsSet(BatchResponseMaxSize.Name) {
cfg.BatchResponseMaxSize = ctx.Int(BatchResponseMaxSize.Name)
}
}
// setWS creates the WebSocket RPC listener interface string from the set

View file

@ -178,6 +178,10 @@ func (api *privateAdminAPI) StartRPC(host *string, port *int, cors *string, apis
CorsAllowedOrigins: api.node.config.HTTPCors,
Vhosts: api.node.config.HTTPVirtualHosts,
Modules: api.node.config.HTTPModules,
rpcEndpointConfig: rpcEndpointConfig{
batchItemLimit: api.node.config.BatchRequestLimit,
batchResponseSizeLimit: api.node.config.BatchResponseMaxSize,
},
}
if cors != nil {
config.CorsAllowedOrigins = nil
@ -237,6 +241,10 @@ func (api *privateAdminAPI) StartWS(host *string, port *int, allowedOrigins *str
config := wsConfig{
Modules: api.node.config.WSModules,
Origins: api.node.config.WSOrigins,
rpcEndpointConfig: rpcEndpointConfig{
batchItemLimit: api.node.config.BatchRequestLimit,
batchResponseSizeLimit: api.node.config.BatchResponseMaxSize,
},
}
if apis != nil {

View file

@ -185,6 +185,12 @@ type Config struct {
// AllowUnprotectedTxs allows non EIP-155 protected transactions to be send over RPC.
AllowUnprotectedTxs bool `toml:",omitempty"`
// BatchRequestLimit is the maximum number of requests in a batch.
BatchRequestLimit int `toml:",omitempty"`
// BatchResponseMaxSize is the maximum number of bytes returned from a batched rpc call.
BatchResponseMaxSize int `toml:",omitempty"`
// JWTSecret is the path to the hex-encoded jwt secret.
JWTSecret string `toml:",omitempty"`
}

View file

@ -46,15 +46,17 @@ var (
// DefaultConfig contains reasonable default settings.
var DefaultConfig = Config{
DataDir: DefaultDataDir(),
HTTPPort: DefaultHTTPPort,
AuthHost: DefaultAuthHost,
AuthPort: DefaultAuthPort,
HTTPModules: []string{"net", "web3"},
HTTPVirtualHosts: []string{"localhost"},
HTTPTimeouts: rpc.DefaultHTTPTimeouts,
WSPort: DefaultWSPort,
WSModules: []string{"net", "web3"},
DataDir: DefaultDataDir(),
HTTPPort: DefaultHTTPPort,
AuthAddr: DefaultAuthHost,
AuthPort: DefaultAuthPort,
HTTPModules: []string{"net", "web3"},
HTTPVirtualHosts: []string{"localhost"},
HTTPTimeouts: rpc.DefaultHTTPTimeouts,
WSPort: DefaultWSPort,
WSModules: []string{"net", "web3"},
BatchRequestLimit: 1000,
BatchResponseMaxSize: 25 * 1000 * 1000,
P2P: p2p.Config{
ListenAddr: ":30303",
MaxPeers: 50,

View file

@ -103,10 +103,11 @@ func New(conf *Config) (*Node, error) {
if strings.HasSuffix(conf.Name, ".ipc") {
return nil, errors.New(`Config.Name cannot end in ".ipc"`)
}
server := rpc.NewServer()
server.SetBatchLimits(conf.BatchRequestLimit, conf.BatchResponseMaxSize)
node := &Node{
config: conf,
inprocHandler: rpc.NewServer(),
inprocHandler: server,
eventmux: new(event.TypeMux),
log: conf.Logger,
stop: make(chan struct{}),
@ -412,6 +413,11 @@ func (n *Node) startRPC() error {
open, all = n.GetAPIs()
)
rpcConfig := rpcEndpointConfig{
batchItemLimit: n.config.BatchRequestLimit,
batchResponseSizeLimit: n.config.BatchResponseMaxSize,
}
initHttp := func(server *httpServer, apis []rpc.API, port int) error {
if err := server.setListenAddr(n.config.HTTPHost, port); err != nil {
return err
@ -421,21 +427,24 @@ func (n *Node) startRPC() error {
Vhosts: n.config.HTTPVirtualHosts,
Modules: n.config.HTTPModules,
prefix: n.config.HTTPPathPrefix,
rpcEndpointConfig: rpcConfig,
}); err != nil {
return err
}
servers = append(servers, server)
return nil
}
initWS := func(port int) error {
server := n.wsServerForPort(port, false)
if err := server.setListenAddr(n.config.WSHost, port); err != nil {
return err
}
if err := server.enableWS(n.rpcAPIs, wsConfig{
Modules: n.config.WSModules,
Origins: n.config.WSOrigins,
prefix: n.config.WSPathPrefix,
Modules: n.config.WSModules,
Origins: n.config.WSOrigins,
prefix: n.config.WSPathPrefix,
rpcEndpointConfig: rpcConfig,
}); err != nil {
return err
}
@ -449,26 +458,29 @@ func (n *Node) startRPC() error {
if err := server.setListenAddr(n.config.AuthHost, port); err != nil {
return err
}
sharedConfig := rpcConfig
sharedConfig.jwtSecret = secret
if err := server.enableRPC(apis, httpConfig{
CorsAllowedOrigins: DefaultAuthCors,
Vhosts: DefaultAuthVhosts,
Modules: DefaultAuthModules,
prefix: DefaultAuthPrefix,
jwtSecret: secret,
rpcEndpointConfig: sharedConfig,
}); err != nil {
return err
}
servers = append(servers, server)
// Enable auth via WS
server = n.wsServerForPort(port, true)
if err := server.setListenAddr(n.config.AuthHost, port); err != nil {
return err
}
if err := server.enableWS(apis, wsConfig{
Modules: DefaultAuthModules,
Origins: DefaultAuthOrigins,
prefix: DefaultAuthPrefix,
jwtSecret: secret,
Modules: DefaultAuthModules,
Origins: DefaultAuthOrigins,
prefix: DefaultAuthPrefix,
rpcEndpointConfig: sharedConfig,
}); err != nil {
return err
}

View file

@ -40,15 +40,21 @@ type httpConfig struct {
CorsAllowedOrigins []string
Vhosts []string
prefix string // path prefix on which to mount http handler
jwtSecret []byte // optional JWT secret
rpcEndpointConfig
}
// wsConfig is the JSON-RPC/Websocket configuration
type wsConfig struct {
Origins []string
Modules []string
prefix string // path prefix on which to mount ws handler
jwtSecret []byte // optional JWT secret
Origins []string
Modules []string
prefix string // path prefix on which to mount ws handler
rpcEndpointConfig
}
type rpcEndpointConfig struct {
jwtSecret []byte // optional JWT secret
batchItemLimit int
batchResponseSizeLimit int
}
type rpcHandler struct {
@ -282,6 +288,7 @@ func (h *httpServer) enableRPC(apis []rpc.API, config httpConfig) error {
// Create RPC server and handler.
srv := rpc.NewServer()
srv.SetBatchLimits(config.batchItemLimit, config.batchResponseSizeLimit)
if err := RegisterApisFromWhitelist(apis, config.Modules, srv); err != nil {
return err
}
@ -314,6 +321,7 @@ func (h *httpServer) enableWS(apis []rpc.API, config wsConfig) error {
// Create RPC server and handler.
srv := rpc.NewServer()
srv.SetBatchLimits(config.batchItemLimit, config.batchResponseSizeLimit)
if err := RegisterApisFromWhitelist(apis, config.Modules, srv); err != nil {
return err
}

View file

@ -309,61 +309,110 @@ func TestJWT(t *testing.T) {
ss, _ := jwt.NewWithClaims(method, testClaim(input)).SignedString(secret)
return ss
}
expOk := []string{
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})),
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix() + 4})),
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix() - 4})),
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{
"iat": time.Now().Unix(),
"exp": time.Now().Unix() + 2,
})),
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{
"iat": time.Now().Unix(),
"bar": "baz",
})),
}
expFail := []string{
// future
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix() + int64(jwtExpiryTimeout.Seconds()) + 1})),
// stale
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix() - int64(jwtExpiryTimeout.Seconds()) - 1})),
// wrong algo
fmt.Sprintf("Bearer %v", issueToken(secret, jwt.SigningMethodHS512, testClaim{"iat": time.Now().Unix() + 4})),
// expired
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix(), "exp": time.Now().Unix()})),
// missing mandatory iat
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{})),
// wrong secret
fmt.Sprintf("Bearer %v", issueToken([]byte("wrong"), nil, testClaim{"iat": time.Now().Unix()})),
fmt.Sprintf("Bearer %v", issueToken([]byte{}, nil, testClaim{"iat": time.Now().Unix()})),
fmt.Sprintf("Bearer %v", issueToken(nil, nil, testClaim{"iat": time.Now().Unix()})),
// Various malformed syntax
fmt.Sprintf("%v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})),
fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})),
fmt.Sprintf("bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})),
fmt.Sprintf("Bearer: %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})),
fmt.Sprintf("Bearer:%v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})),
fmt.Sprintf("Bearer\t%v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})),
fmt.Sprintf("Bearer \t%v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})),
}
srv := createAndStartServer(t, &httpConfig{jwtSecret: []byte("secret")},
true, &wsConfig{Origins: []string{"*"}, jwtSecret: []byte("secret")}, nil)
cfg := rpcEndpointConfig{jwtSecret: []byte("secret")}
httpcfg := &httpConfig{rpcEndpointConfig: cfg}
wscfg := &wsConfig{Origins: []string{"*"}, rpcEndpointConfig: cfg}
srv := createAndStartServer(t, httpcfg, true, wscfg, nil)
wsUrl := fmt.Sprintf("ws://%v", srv.listenAddr())
htUrl := fmt.Sprintf("http://%v", srv.listenAddr())
for i, token := range expOk {
expOk := []func() string{
func() string {
return fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()}))
},
func() string {
return fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix() + 4}))
},
func() string {
return fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix() - 4}))
},
func() string {
return fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{
"iat": time.Now().Unix(),
"exp": time.Now().Unix() + 2,
}))
},
func() string {
return fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{
"iat": time.Now().Unix(),
"bar": "baz",
}))
},
}
for i, tokenFn := range expOk {
token := tokenFn()
if err := wsRequest(t, wsUrl, "Authorization", token); err != nil {
t.Errorf("test %d-ws, token '%v': expected ok, got %v", i, token, err)
}
token = tokenFn()
if resp := rpcRequest(t, htUrl, testMethod, "Authorization", token); resp.StatusCode != 200 {
t.Errorf("test %d-http, token '%v': expected ok, got %v", i, token, resp.StatusCode)
}
}
for i, token := range expFail {
expFail := []func() string{
// future
func() string {
return fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix() + int64(jwtExpiryTimeout.Seconds()) + 1}))
},
// stale
func() string {
return fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix() - int64(jwtExpiryTimeout.Seconds()) - 1}))
},
// wrong algo
func() string {
return fmt.Sprintf("Bearer %v", issueToken(secret, jwt.SigningMethodHS512, testClaim{"iat": time.Now().Unix() + 4}))
},
// expired
func() string {
return fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix(), "exp": time.Now().Unix()}))
},
// missing mandatory iat
func() string {
return fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{}))
},
// wrong secret
func() string {
return fmt.Sprintf("Bearer %v", issueToken([]byte("wrong"), nil, testClaim{"iat": time.Now().Unix()}))
},
func() string {
return fmt.Sprintf("Bearer %v", issueToken([]byte{}, nil, testClaim{"iat": time.Now().Unix()}))
},
func() string {
return fmt.Sprintf("Bearer %v", issueToken(nil, nil, testClaim{"iat": time.Now().Unix()}))
},
// Various malformed syntax
func() string {
return fmt.Sprintf("%v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()}))
},
func() string {
return fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()}))
},
func() string {
return fmt.Sprintf("bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()}))
},
func() string {
return fmt.Sprintf("Bearer: %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()}))
},
func() string {
return fmt.Sprintf("Bearer:%v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()}))
},
func() string {
return fmt.Sprintf("Bearer\t%v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()}))
},
func() string {
return fmt.Sprintf("Bearer \t%v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()}))
},
}
for i, tokenFn := range expFail {
token := tokenFn()
if err := wsRequest(t, wsUrl, "Authorization", token); err == nil {
t.Errorf("tc %d-ws, token '%v': expected not to allow, got ok", i, token)
}
if resp := rpcRequest(t, htUrl, testMethod, "Authorization", token); resp.StatusCode != 403 {
token = tokenFn()
resp := rpcRequest(t, htUrl, testMethod, "Authorization", token)
if resp.StatusCode != http.StatusForbidden {
t.Errorf("tc %d-http, token '%v': expected not to allow, got %v", i, token, resp.StatusCode)
}
}

View file

@ -34,14 +34,15 @@ import (
var (
ErrBadResult = errors.New("bad result in JSON-RPC response")
ErrClientQuit = errors.New("client is closed")
ErrNoResult = errors.New("no result in JSON-RPC response")
ErrNoResult = errors.New("JSON-RPC response has no result")
ErrMissingBatchResponse = errors.New("response batch did not contain a response to this call")
ErrSubscriptionQueueOverflow = errors.New("subscription queue overflow")
errClientReconnected = errors.New("client reconnected")
errDead = errors.New("connection lost")
)
// Timeouts
const (
// Timeouts
defaultDialTimeout = 10 * time.Second // used if context has no deadline
subscribeTimeout = 10 * time.Second // overall timeout eth_subscribe, rpc_modules calls
)
@ -84,6 +85,10 @@ type Client struct {
// This function, if non-nil, is called when the connection is lost.
reconnectFunc reconnectFunc
// config fields
batchItemLimit int
batchResponseMaxSize int
// writeConn is used for writing to the connection on the caller's goroutine. It should
// only be accessed outside of dispatch, with the write lock held. The write lock is
// taken by sending on reqInit and released by sending on reqSent.
@ -113,7 +118,7 @@ type clientConn struct {
func (c *Client) newClientConn(conn ServerCodec) *clientConn {
ctx := context.WithValue(context.Background(), clientContextKey{}, c)
ctx = context.WithValue(ctx, peerInfoContextKey{}, conn.peerInfo())
handler := newHandler(ctx, conn, c.idgen, c.services)
handler := newHandler(ctx, conn, c.idgen, c.services, c.batchItemLimit, c.batchResponseMaxSize)
return &clientConn{conn, handler}
}
@ -127,14 +132,17 @@ type readOp struct {
batch bool
}
// requestOp represents a pending request. This is used for both batch and non-batch
// requests.
type requestOp struct {
ids []json.RawMessage
err error
resp chan *jsonrpcMessage // receives up to len(ids) responses
sub *ClientSubscription // only set for EthSubscribe requests
ids []json.RawMessage
err error
resp chan []*jsonrpcMessage // receives up to len(ids) responses
sub *ClientSubscription // only set for EthSubscribe requests
hadResponse bool // true when the request was responded to
}
func (op *requestOp) wait(ctx context.Context, c *Client) (*jsonrpcMessage, error) {
func (op *requestOp) wait(ctx context.Context, c *Client) ([]*jsonrpcMessage, error) {
select {
case <-ctx.Done():
// Send the timeout to dispatch so it can remove the request IDs.
@ -166,7 +174,8 @@ func Dial(rawurl string) (*Client, error) {
return DialOptions(context.Background(), rawurl)
}
// DialContext creates a new RPC client, just like Dial.
// DialOptions creates a new RPC client for the given URL. You can supply any of the
// pre-defined client options to configure the underlying transport.
//
// The context is used to cancel or time out the initial connection establishment. It does
// not affect subsequent interactions with the client.
@ -210,7 +219,7 @@ func DialOptions(ctx context.Context, rawurl string, options ...ClientOption) (*
return nil, fmt.Errorf("no known transport for URL scheme %q", u.Scheme)
}
return newClient(ctx, reconnect)
return newClient(ctx, cfg, reconnect)
}
// ClientFromContext retrieves the client from the context, if any. This can be used to perform
@ -220,33 +229,42 @@ func ClientFromContext(ctx context.Context) (*Client, bool) {
return client, ok
}
func newClient(initctx context.Context, connect reconnectFunc) (*Client, error) {
func newClient(initctx context.Context, cfg *clientConfig, connect reconnectFunc) (*Client, error) {
conn, err := connect(initctx)
if err != nil {
return nil, err
}
c := initClient(conn, randomIDGenerator(), new(serviceRegistry))
c := initClient(conn, new(serviceRegistry), cfg)
c.reconnectFunc = connect
return c, nil
}
func initClient(conn ServerCodec, idgen func() ID, services *serviceRegistry) *Client {
func initClient(conn ServerCodec, services *serviceRegistry, cfg *clientConfig) *Client {
_, isHTTP := conn.(*httpConn)
c := &Client{
idgen: idgen,
isHTTP: isHTTP,
services: services,
writeConn: conn,
close: make(chan struct{}),
closing: make(chan struct{}),
didClose: make(chan struct{}),
reconnected: make(chan ServerCodec),
readOp: make(chan readOp),
readErr: make(chan error),
reqInit: make(chan *requestOp),
reqSent: make(chan error, 1),
reqTimeout: make(chan *requestOp),
isHTTP: isHTTP,
services: services,
idgen: cfg.idgen,
batchItemLimit: cfg.batchItemLimit,
batchResponseMaxSize: cfg.batchResponseLimit,
writeConn: conn,
close: make(chan struct{}),
closing: make(chan struct{}),
didClose: make(chan struct{}),
reconnected: make(chan ServerCodec),
readOp: make(chan readOp),
readErr: make(chan error),
reqInit: make(chan *requestOp),
reqSent: make(chan error, 1),
reqTimeout: make(chan *requestOp),
}
// Set defaults.
if c.idgen == nil {
c.idgen = randomIDGenerator()
}
// Launch the main loop.
if !isHTTP {
go c.dispatch(conn)
}
@ -324,7 +342,10 @@ func (c *Client) CallContext(ctx context.Context, result interface{}, method str
if err != nil {
return err
}
op := &requestOp{ids: []json.RawMessage{msg.ID}, resp: make(chan *jsonrpcMessage, 1)}
op := &requestOp{
ids: []json.RawMessage{msg.ID},
resp: make(chan []*jsonrpcMessage, 1),
}
if c.isHTTP {
err = c.sendHTTP(ctx, op, msg)
@ -336,9 +357,12 @@ func (c *Client) CallContext(ctx context.Context, result interface{}, method str
}
// dispatch has accepted the request and will close the channel when it quits.
switch resp, err := op.wait(ctx, c); {
case err != nil:
batchresp, err := op.wait(ctx, c)
if err != nil {
return err
}
resp := batchresp[0]
switch {
case resp.Error != nil:
return resp.Error
case len(resp.Result) == 0:
@ -357,11 +381,17 @@ func (c *Client) CallContext(ctx context.Context, result interface{}, method str
// The result must be a pointer so that package json can unmarshal into it. You
// can also pass nil, in which case the result is ignored.
func (c *Client) GetResultCallContext(ctx context.Context, result interface{}, method string, args ...interface{}) (json.RawMessage, error) {
if result != nil && reflect.TypeOf(result).Kind() != reflect.Ptr {
return nil, fmt.Errorf("call result parameter must be pointer or nil interface: %v", result)
}
msg, err := c.newMessage(method, args...)
if err != nil {
return nil, err
}
op := &requestOp{ids: []json.RawMessage{msg.ID}, resp: make(chan *jsonrpcMessage, 1)}
op := &requestOp{
ids: []json.RawMessage{msg.ID},
resp: make(chan []*jsonrpcMessage, 1),
}
if c.isHTTP {
err = c.sendHTTP(ctx, op, msg)
@ -373,15 +403,21 @@ func (c *Client) GetResultCallContext(ctx context.Context, result interface{}, m
}
// dispatch has accepted the request and will close the channel it when it quits.
switch resp, err := op.wait(ctx, c); {
case err != nil:
batchresp, err := op.wait(ctx, c)
if err != nil {
return nil, err
}
resp := batchresp[0]
switch {
case resp.Error != nil:
return nil, resp.Error
case len(resp.Result) == 0:
return nil, ErrNoResult
default:
return resp.Result, json.Unmarshal(resp.Result, &result)
if result == nil {
return nil, nil
}
return resp.Result, json.Unmarshal(resp.Result, result)
}
}
@ -413,7 +449,7 @@ func (c *Client) BatchCallContext(ctx context.Context, b []BatchElem) error {
)
op := &requestOp{
ids: make([]json.RawMessage, len(b)),
resp: make(chan *jsonrpcMessage, len(b)),
resp: make(chan []*jsonrpcMessage, 1),
}
for i, elem := range b {
msg, err := c.newMessage(elem.Method, elem.Args...)
@ -431,28 +467,48 @@ func (c *Client) BatchCallContext(ctx context.Context, b []BatchElem) error {
} else {
err = c.send(ctx, op, msgs)
}
if err != nil {
return err
}
batchresp, err := op.wait(ctx, c)
if err != nil {
return err
}
// Wait for all responses to come back.
for n := 0; n < len(b) && err == nil; n++ {
var resp *jsonrpcMessage
resp, err = op.wait(ctx, c)
if err != nil {
break
for n := 0; n < len(batchresp); n++ {
resp := batchresp[n]
if resp == nil {
// Ignore null responses. These can happen for batches sent via HTTP.
continue
}
// Find the element corresponding to this response.
// The element is guaranteed to be present because dispatch
// only sends valid IDs to our channel.
elem := &b[byID[string(resp.ID)]]
if resp.Error != nil {
index, ok := byID[string(resp.ID)]
if !ok {
continue
}
delete(byID, string(resp.ID))
// Assign result and error.
elem := &b[index]
switch {
case resp.Error != nil:
elem.Error = resp.Error
continue
}
if len(resp.Result) == 0 {
case resp.Result == nil:
elem.Error = ErrNoResult
continue
default:
elem.Error = json.Unmarshal(resp.Result, elem.Result)
}
elem.Error = json.Unmarshal(resp.Result, elem.Result)
}
// Check that all expected responses have been received.
for _, index := range byID {
elem := &b[index]
elem.Error = ErrMissingBatchResponse
}
return err
}
@ -513,7 +569,7 @@ func (c *Client) Subscribe(ctx context.Context, namespace string, channel interf
}
op := &requestOp{
ids: []json.RawMessage{msg.ID},
resp: make(chan *jsonrpcMessage),
resp: make(chan []*jsonrpcMessage, 1),
sub: newClientSubscription(c, namespace, chanVal),
}

View file

@ -28,11 +28,18 @@ type ClientOption interface {
}
type clientConfig struct {
// HTTP settings
httpClient *http.Client
httpHeaders http.Header
httpAuth HTTPAuth
// WebSocket options
wsDialer *websocket.Dialer
// RPC handler options
idgen func() ID
batchItemLimit int
batchResponseLimit int
}
func (cfg *clientConfig) initHeaders() {
@ -104,3 +111,25 @@ func WithHTTPAuth(a HTTPAuth) ClientOption {
// Usually, HTTPAuth functions will call h.Set("authorization", "...") to add
// auth information to the request.
type HTTPAuth func(h http.Header) error
// WithBatchItemLimit changes the maximum number of items allowed in batch requests.
//
// Note: this option applies when processing incoming batch requests. It does not affect
// batch requests sent by the client.
func WithBatchItemLimit(limit int) ClientOption {
return optionFunc(func(cfg *clientConfig) {
cfg.batchItemLimit = limit
})
}
// WithBatchResponseSizeLimit changes the maximum number of response bytes that can be
// generated for batch requests. When this limit is reached, further calls in the batch
// will not be processed.
//
// Note: this option applies when processing incoming batch requests. It does not affect
// batch requests sent by the client.
func WithBatchResponseSizeLimit(sizeLimit int) ClientOption {
return optionFunc(func(cfg *clientConfig) {
cfg.batchResponseLimit = sizeLimit
})
}

View file

@ -169,10 +169,12 @@ func TestClientBatchRequest(t *testing.T) {
}
}
// This checks that, for HTTP connections, the length of batch responses is validated to
// match the request exactly.
func TestClientBatchRequest_len(t *testing.T) {
b, err := json.Marshal([]jsonrpcMessage{
{Version: "2.0", ID: json.RawMessage("1"), Method: "foo", Result: json.RawMessage(`"0x1"`)},
{Version: "2.0", ID: json.RawMessage("2"), Method: "bar", Result: json.RawMessage(`"0x2"`)},
{Version: "2.0", ID: json.RawMessage("1"), Result: json.RawMessage(`"0x1"`)},
{Version: "2.0", ID: json.RawMessage("2"), Result: json.RawMessage(`"0x2"`)},
})
if err != nil {
t.Fatal("failed to encode jsonrpc message:", err)
@ -185,37 +187,102 @@ func TestClientBatchRequest_len(t *testing.T) {
}))
t.Cleanup(s.Close)
client, err := Dial(s.URL)
if err != nil {
t.Fatal("failed to dial test server:", err)
}
defer client.Close()
t.Run("too-few", func(t *testing.T) {
client, err := Dial(s.URL)
if err != nil {
t.Fatal("failed to dial test server:", err)
}
defer client.Close()
batch := []BatchElem{
{Method: "foo"},
{Method: "bar"},
{Method: "baz"},
{Method: "foo", Result: new(string)},
{Method: "bar", Result: new(string)},
{Method: "baz", Result: new(string)},
}
ctx, cancelFn := context.WithTimeout(context.Background(), time.Second)
defer cancelFn()
if err := client.BatchCallContext(ctx, batch); !errors.Is(err, ErrBadResult) {
t.Errorf("expected %q but got: %v", ErrBadResult, err)
if err := client.BatchCallContext(ctx, batch); err != nil {
t.Fatal("error:", err)
}
for i, elem := range batch[:2] {
if elem.Error != nil {
t.Errorf("expected no error for batch element %d, got %q", i, elem.Error)
}
}
for i, elem := range batch[2:] {
if elem.Error != ErrMissingBatchResponse {
t.Errorf("wrong error %q for batch element %d", elem.Error, i+2)
}
}
})
t.Run("too-many", func(t *testing.T) {
client, err := Dial(s.URL)
if err != nil {
t.Fatal("failed to dial test server:", err)
}
defer client.Close()
batch := []BatchElem{
{Method: "foo"},
{Method: "foo", Result: new(string)},
}
ctx, cancelFn := context.WithTimeout(context.Background(), time.Second)
defer cancelFn()
if err := client.BatchCallContext(ctx, batch); !errors.Is(err, ErrBadResult) {
t.Errorf("expected %q but got: %v", ErrBadResult, err)
if err := client.BatchCallContext(ctx, batch); err != nil {
t.Fatal("error:", err)
}
for i, elem := range batch[:1] {
if elem.Error != nil {
t.Errorf("expected no error for batch element %d, got %q", i, elem.Error)
}
}
for i, elem := range batch[1:] {
if elem.Error != ErrMissingBatchResponse {
t.Errorf("wrong error %q for batch element %d", elem.Error, i+2)
}
}
})
}
// This checks that the client can handle the case where the server doesn't
// respond to all requests in a batch.
func TestClientBatchRequestLimit(t *testing.T) {
server := newTestServer()
defer server.Stop()
server.SetBatchLimits(2, 100000)
client := DialInProc(server)
batch := []BatchElem{
{Method: "foo"},
{Method: "bar"},
{Method: "baz"},
}
err := client.BatchCall(batch)
if err != nil {
t.Fatal("unexpected error:", err)
}
// Check that the first response indicates an error with batch size.
var err0 Error
if !errors.As(batch[0].Error, &err0) {
t.Log("error zero:", batch[0].Error)
t.Fatalf("batch elem 0 has wrong error type: %T", batch[0].Error)
} else {
if err0.ErrorCode() != -32600 || err0.Error() != errMsgBatchTooLarge {
t.Fatalf("wrong error on batch elem zero: %v", err0)
}
}
// Check that remaining response batch elements are reported as absent.
for i, elem := range batch[1:] {
if elem.Error != ErrMissingBatchResponse {
t.Fatalf("batch elem %d has unexpected error: %v", i+1, elem.Error)
}
}
}
func TestClientNotify(t *testing.T) {
server := newTestServer()
defer server.Stop()
@ -310,7 +377,7 @@ func testClientCancel(transport string, t *testing.T) {
_, hasDeadline := ctx.Deadline()
t.Errorf("no error for call with %v wait time (deadline: %v)", timeout, hasDeadline)
// default:
// t.Logf("got expected error with %v wait time: %v", timeout, err)
// t.Logf("got expected error with %v wait time: %v", timeout, err)
}
cancel()
}
@ -487,7 +554,8 @@ func TestClientSubscriptionUnsubscribeServer(t *testing.T) {
defer srv.Stop()
// Create the client on the other end of the pipe.
client, _ := newClient(context.Background(), func(context.Context) (ServerCodec, error) {
cfg := new(clientConfig)
client, _ := newClient(context.Background(), cfg, func(context.Context) (ServerCodec, error) {
return NewCodec(p2), nil
})
defer client.Close()

View file

@ -61,12 +61,15 @@ const (
errcodeDefault = -32000
errcodeNotificationsUnsupported = -32001
errcodeTimeout = -32002
errcodeResponseTooLarge = -32003
errcodePanic = -32603
errcodeMarshalError = -32603
)
const (
errMsgTimeout = "request timed out"
errMsgTimeout = "request timed out"
errMsgResponseTooLarge = "response too large"
errMsgBatchTooLarge = "batch too large"
)
type methodNotFoundError struct{ method string }

View file

@ -49,17 +49,19 @@ import (
// h.removeRequestOp(op) // timeout, etc.
// }
type handler struct {
reg *serviceRegistry
unsubscribeCb *callback
idgen func() ID // subscription ID generator
respWait map[string]*requestOp // active client requests
clientSubs map[string]*ClientSubscription // active client subscriptions
callWG sync.WaitGroup // pending call goroutines
rootCtx context.Context // canceled by close()
cancelRoot func() // cancel function for rootCtx
conn jsonWriter // where responses will be sent
log log.Logger
allowSubscribe bool
reg *serviceRegistry
unsubscribeCb *callback
idgen func() ID // subscription ID generator
respWait map[string]*requestOp // active client requests
clientSubs map[string]*ClientSubscription // active client subscriptions
callWG sync.WaitGroup // pending call goroutines
rootCtx context.Context // canceled by close()
cancelRoot func() // cancel function for rootCtx
conn jsonWriter // where responses will be sent
log log.Logger
allowSubscribe bool
batchRequestLimit int
batchResponseMaxSize int
subLock sync.Mutex
serverSubs map[ID]*Subscription
@ -70,19 +72,21 @@ type callProc struct {
notifiers []*Notifier
}
func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *serviceRegistry) *handler {
func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *serviceRegistry, batchRequestLimit, batchResponseMaxSize int) *handler {
rootCtx, cancelRoot := context.WithCancel(connCtx)
h := &handler{
reg: reg,
idgen: idgen,
conn: conn,
respWait: make(map[string]*requestOp),
clientSubs: make(map[string]*ClientSubscription),
rootCtx: rootCtx,
cancelRoot: cancelRoot,
allowSubscribe: true,
serverSubs: make(map[ID]*Subscription),
log: log.Root(),
reg: reg,
idgen: idgen,
conn: conn,
respWait: make(map[string]*requestOp),
clientSubs: make(map[string]*ClientSubscription),
rootCtx: rootCtx,
cancelRoot: cancelRoot,
allowSubscribe: true,
serverSubs: make(map[ID]*Subscription),
log: log.Root(),
batchRequestLimit: batchRequestLimit,
batchResponseMaxSize: batchResponseMaxSize,
}
if conn.remoteAddr() != "" {
h.log = h.log.New("conn", conn.remoteAddr())
@ -134,16 +138,15 @@ func (b *batchCallBuffer) write(ctx context.Context, conn jsonWriter) {
b.doWrite(ctx, conn, false)
}
// timeout sends the responses added so far. For the remaining unanswered call
// messages, it sends a timeout error response.
func (b *batchCallBuffer) timeout(ctx context.Context, conn jsonWriter) {
// respondWithError sends the responses added so far. For the remaining unanswered call
// messages, it responds with the given error.
func (b *batchCallBuffer) respondWithError(ctx context.Context, conn jsonWriter, err error) {
b.mutex.Lock()
defer b.mutex.Unlock()
for _, msg := range b.calls {
if !msg.isNotification() {
resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout})
b.resp = append(b.resp, resp)
b.resp = append(b.resp, msg.errorResponse(err))
}
}
b.doWrite(ctx, conn, true)
@ -171,17 +174,24 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) {
})
return
}
// Handle non-call messages first:
calls := make([]*jsonrpcMessage, 0, len(msgs))
for _, msg := range msgs {
if handled := h.handleImmediate(msg); !handled {
calls = append(calls, msg)
}
// Apply limit on total number of requests.
if h.batchRequestLimit != 0 && len(msgs) > h.batchRequestLimit {
h.startCallProc(func(cp *callProc) {
h.respondWithBatchTooLarge(cp, msgs)
})
return
}
// Handle non-call messages first.
// Here we need to find the requestOp that sent the request batch.
calls := make([]*jsonrpcMessage, 0, len(msgs))
h.handleResponses(msgs, func(msg *jsonrpcMessage) {
calls = append(calls, msg)
})
if len(calls) == 0 {
return
}
// Process calls on a goroutine because they may block indefinitely:
h.startCallProc(func(cp *callProc) {
var (
@ -199,10 +209,12 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) {
if timeout, ok := ContextRequestTimeout(cp.ctx); ok {
timer = time.AfterFunc(timeout, func() {
cancel()
callBuffer.timeout(cp.ctx, h.conn)
err := &internalServerError{errcodeTimeout, errMsgTimeout}
callBuffer.respondWithError(cp.ctx, h.conn, err)
})
}
responseBytes := 0
for {
// No need to handle rest of calls if timed out.
if cp.ctx.Err() != nil {
@ -214,61 +226,88 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) {
}
resp := h.handleCallMsg(cp, msg)
callBuffer.pushResponse(resp)
if resp != nil && h.batchResponseMaxSize != 0 {
responseBytes += len(resp.Result)
if responseBytes > h.batchResponseMaxSize {
err := &internalServerError{errcodeResponseTooLarge, errMsgResponseTooLarge}
callBuffer.respondWithError(cp.ctx, h.conn, err)
break
}
}
}
if timer != nil {
timer.Stop()
}
callBuffer.write(cp.ctx, h.conn)
h.addSubscriptions(cp.notifiers)
callBuffer.write(cp.ctx, h.conn)
for _, n := range cp.notifiers {
n.activate()
}
})
}
func (h *handler) respondWithBatchTooLarge(cp *callProc, batch []*jsonrpcMessage) {
resp := errorMessage(&invalidRequestError{errMsgBatchTooLarge})
// Find the first call and add its "id" field to the error.
// This is the best we can do, given that the protocol doesn't have a way
// of reporting an error for the entire batch.
for _, msg := range batch {
if msg.isCall() {
resp.ID = msg.ID
break
}
}
h.conn.writeJSON(cp.ctx, []*jsonrpcMessage{resp}, true)
}
// handleMsg handles a single message.
func (h *handler) handleMsg(msg *jsonrpcMessage) {
if ok := h.handleImmediate(msg); ok {
return
}
h.startCallProc(func(cp *callProc) {
var (
responded sync.Once
timer *time.Timer
cancel context.CancelFunc
)
cp.ctx, cancel = context.WithCancel(cp.ctx)
defer cancel()
// Cancel the request context after timeout and send an error response. Since the
// running method might not return immediately on timeout, we must wait for the
// timeout concurrently with processing the request.
if timeout, ok := ContextRequestTimeout(cp.ctx); ok {
timer = time.AfterFunc(timeout, func() {
cancel()
responded.Do(func() {
resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout})
h.conn.writeJSON(cp.ctx, resp, true)
})
})
}
answer := h.handleCallMsg(cp, msg)
if timer != nil {
timer.Stop()
}
h.addSubscriptions(cp.notifiers)
if answer != nil {
responded.Do(func() {
h.conn.writeJSON(cp.ctx, answer, false)
})
}
for _, n := range cp.notifiers {
n.activate()
}
msgs := []*jsonrpcMessage{msg}
h.handleResponses(msgs, func(msg *jsonrpcMessage) {
h.startCallProc(func(cp *callProc) {
h.handleNonBatchCall(cp, msg)
})
})
}
func (h *handler) handleNonBatchCall(cp *callProc, msg *jsonrpcMessage) {
var (
responded sync.Once
timer *time.Timer
cancel context.CancelFunc
)
cp.ctx, cancel = context.WithCancel(cp.ctx)
defer cancel()
// Cancel the request context after timeout and send an error response. Since the
// running method might not return immediately on timeout, we must wait for the
// timeout concurrently with processing the request.
if timeout, ok := ContextRequestTimeout(cp.ctx); ok {
timer = time.AfterFunc(timeout, func() {
cancel()
responded.Do(func() {
resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout})
h.conn.writeJSON(cp.ctx, resp, true)
})
})
}
answer := h.handleCallMsg(cp, msg)
if timer != nil {
timer.Stop()
}
h.addSubscriptions(cp.notifiers)
if answer != nil {
responded.Do(func() {
h.conn.writeJSON(cp.ctx, answer, false)
})
}
for _, n := range cp.notifiers {
n.activate()
}
}
// close cancels all requests except for inflightReq and waits for
// call goroutines to shut down.
func (h *handler) close(err error, inflightReq *requestOp) {
@ -349,23 +388,60 @@ func (h *handler) startCallProc(fn func(*callProc)) {
}()
}
// handleImmediate executes non-call messages. It returns false if the message is a
// call or requires a reply.
func (h *handler) handleImmediate(msg *jsonrpcMessage) bool {
start := time.Now()
switch {
case msg.isNotification():
if strings.HasSuffix(msg.Method, notificationMethodSuffix) {
h.handleSubscriptionResult(msg)
return true
// handleResponse processes method call responses.
func (h *handler) handleResponses(batch []*jsonrpcMessage, handleCall func(*jsonrpcMessage)) {
var resolvedops []*requestOp
handleResp := func(msg *jsonrpcMessage) {
op := h.respWait[string(msg.ID)]
if op == nil {
h.log.Debug("Unsolicited RPC response", "reqid", idForLog{msg.ID})
return
}
return false
case msg.isResponse():
h.handleResponse(msg)
h.log.Trace("Handled RPC response", "reqid", idForLog{msg.ID}, "duration", time.Since(start))
return true
default:
return false
resolvedops = append(resolvedops, op)
delete(h.respWait, string(msg.ID))
// For subscription responses, start the subscription if the server
// indicates success. EthSubscribe gets unblocked in either case through
// the op.resp channel.
if op.sub != nil {
if msg.Error != nil {
op.err = msg.Error
} else {
op.err = json.Unmarshal(msg.Result, &op.sub.subid)
if op.err == nil {
go op.sub.run()
h.clientSubs[op.sub.subid] = op.sub
}
}
}
if !op.hadResponse {
op.hadResponse = true
op.resp <- batch
}
}
for _, msg := range batch {
start := time.Now()
switch {
case msg.isResponse():
handleResp(msg)
h.log.Trace("Handled RPC response", "reqid", idForLog{msg.ID}, "duration", time.Since(start))
case msg.isNotification():
if strings.HasSuffix(msg.Method, notificationMethodSuffix) {
h.handleSubscriptionResult(msg)
continue
}
handleCall(msg)
default:
handleCall(msg)
}
}
for _, op := range resolvedops {
h.removeRequestOp(op)
}
}
@ -381,33 +457,6 @@ func (h *handler) handleSubscriptionResult(msg *jsonrpcMessage) {
}
}
// handleResponse processes method call responses.
func (h *handler) handleResponse(msg *jsonrpcMessage) {
op := h.respWait[string(msg.ID)]
if op == nil {
h.log.Debug("Unsolicited RPC response", "reqid", idForLog{msg.ID})
return
}
delete(h.respWait, string(msg.ID))
// For normal responses, just forward the reply to Call/BatchCall.
if op.sub == nil {
op.resp <- msg
return
}
// For subscription responses, start the subscription if the server
// indicates success. EthSubscribe gets unblocked in either case through
// the op.resp channel.
defer close(op.resp)
if msg.Error != nil {
op.err = msg.Error
return
}
if op.err = json.Unmarshal(msg.Result, &op.sub.subid); op.err == nil {
go op.sub.run()
h.clientSubs[op.sub.subid] = op.sub
}
}
// handleCallMsg executes a call message and returns the answer.
func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMessage {
start := time.Now()
@ -416,6 +465,7 @@ func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMess
h.handleCall(ctx, msg)
h.log.Debug("Served "+msg.Method, "duration", time.Since(start))
return nil
case msg.isCall():
resp := h.handleCall(ctx, msg)
var ctx []interface{}
@ -430,8 +480,10 @@ func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMess
h.log.Debug("Served "+msg.Method, ctx...)
}
return resp
case msg.hasValidID():
return msg.errorResponse(&invalidRequestError{"invalid request"})
default:
return errorMessage(&invalidRequestError{"invalid request"})
}
@ -451,12 +503,14 @@ func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage
if callb == nil {
return msg.errorResponse(&methodNotFoundError{method: msg.Method})
}
args, err := parsePositionalArguments(msg.Params, callb.argTypes)
if err != nil {
return msg.errorResponse(&invalidParamsError{err.Error()})
}
start := time.Now()
answer := h.runMethod(cp.ctx, msg, callb, args)
// Collect the statistics for RPC calls if metrics is enabled.
// We only care about pure rpc call. Filter out subscription.
if callb != h.unsubscribeCb {
@ -469,6 +523,7 @@ func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage
rpcServingTimer.UpdateSince(start)
updateServeTimeHistogram(msg.Method, answer.Error == nil, time.Since(start))
}
return answer
}

View file

@ -139,7 +139,7 @@ func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) {
var cfg clientConfig
cfg.httpClient = client
fn := newClientTransportHTTP(endpoint, &cfg)
return newClient(context.Background(), fn)
return newClient(context.Background(), &cfg, fn)
}
func newClientTransportHTTP(endpoint string, cfg *clientConfig) reconnectFunc {
@ -176,11 +176,12 @@ func (c *Client) sendHTTP(ctx context.Context, op *requestOp, msg interface{}) e
}
defer respBody.Close()
var respmsg jsonrpcMessage
if err := json.NewDecoder(respBody).Decode(&respmsg); err != nil {
var resp jsonrpcMessage
batch := [1]*jsonrpcMessage{&resp}
if err := json.NewDecoder(respBody).Decode(&resp); err != nil {
return err
}
op.resp <- &respmsg
op.resp <- batch[:]
return nil
}
@ -191,16 +192,12 @@ func (c *Client) sendBatchHTTP(ctx context.Context, op *requestOp, msgs []*jsonr
return err
}
defer respBody.Close()
var respmsgs []jsonrpcMessage
var respmsgs []*jsonrpcMessage
if err := json.NewDecoder(respBody).Decode(&respmsgs); err != nil {
return err
}
if len(respmsgs) != len(msgs) {
return fmt.Errorf("batch has %d requests but response has %d: %w", len(msgs), len(respmsgs), ErrBadResult)
}
for i := 0; i < len(respmsgs); i++ {
op.resp <- &respmsgs[i]
}
op.resp <- respmsgs
return nil
}

View file

@ -24,7 +24,8 @@ import (
// DialInProc attaches an in-process connection to the given RPC server.
func DialInProc(handler *Server) *Client {
initctx := context.Background()
c, _ := newClient(initctx, func(context.Context) (ServerCodec, error) {
cfg := new(clientConfig)
c, _ := newClient(initctx, cfg, func(context.Context) (ServerCodec, error) {
p1, p2 := net.Pipe()
go handler.ServeCodec(NewCodec(p1), 0)
return NewCodec(p2), nil

View file

@ -24,17 +24,17 @@ import (
"github.com/XinFinOrg/XDPoSChain/p2p/netutil"
)
// ServeListener accepts connections on l, serving IPC-RPC on them.
// ServeListener accepts connections on l, serving JSON-RPC on them.
func (s *Server) ServeListener(l net.Listener) error {
for {
conn, err := l.Accept()
if netutil.IsTemporaryError(err) {
log.Warn("IPC accept error", "err", err)
log.Warn("RPC accept error", "err", err)
continue
} else if err != nil {
return err
}
log.Trace("IPC accepted connection")
log.Trace("Accepted RPC connection", "conn", conn.RemoteAddr())
go s.ServeCodec(NewCodec(conn), 0)
}
}
@ -46,7 +46,8 @@ func (s *Server) ServeListener(l net.Listener) error {
// The context is used for the initial connection establishment. It does not
// affect subsequent interactions with the client.
func DialIPC(ctx context.Context, endpoint string) (*Client, error) {
return newClient(ctx, newClientTransportIPC(endpoint))
cfg := new(clientConfig)
return newClient(ctx, cfg, newClientTransportIPC(endpoint))
}
func newClientTransportIPC(endpoint string) reconnectFunc {

View file

@ -46,9 +46,11 @@ type Server struct {
services serviceRegistry
idgen func() ID
mutex sync.Mutex
codecs map[ServerCodec]struct{}
run atomic.Bool
mutex sync.Mutex
codecs map[ServerCodec]struct{}
run atomic.Bool
batchItemLimit int
batchResponseLimit int
}
// NewServer creates a new server instance with no registered handlers.
@ -65,6 +67,17 @@ func NewServer() *Server {
return server
}
// SetBatchLimits sets limits applied to batch requests. There are two limits: 'itemLimit'
// is the maximum number of items in a batch. 'maxResponseSize' is the maximum number of
// response bytes across all requests in a batch.
//
// This method should be called before processing any requests via ServeCodec, ServeHTTP,
// ServeListener etc.
func (s *Server) SetBatchLimits(itemLimit, maxResponseSize int) {
s.batchItemLimit = itemLimit
s.batchResponseLimit = maxResponseSize
}
// RegisterName creates a service for the given receiver type under the given name. When no
// methods on the given receiver match the criteria to be either a RPC method or a
// subscription an error is returned. Otherwise a new service is created and added to the
@ -86,7 +99,12 @@ func (s *Server) ServeCodec(codec ServerCodec, options CodecOption) {
}
defer s.untrackCodec(codec)
c := initClient(codec, s.idgen, &s.services)
cfg := &clientConfig{
idgen: s.idgen,
batchItemLimit: s.batchItemLimit,
batchResponseLimit: s.batchResponseLimit,
}
c := initClient(codec, &s.services, cfg)
<-codec.closed()
c.Close()
}
@ -118,7 +136,7 @@ func (s *Server) serveSingleRequest(ctx context.Context, codec ServerCodec) {
return
}
h := newHandler(ctx, codec, s.idgen, &s.services)
h := newHandler(ctx, codec, s.idgen, &s.services, s.batchItemLimit, s.batchResponseLimit)
h.allowSubscribe = false
defer h.close(io.EOF, nil)

View file

@ -70,6 +70,7 @@ func TestServer(t *testing.T) {
func runTestScript(t *testing.T, file string) {
server := newTestServer()
server.SetBatchLimits(4, 100000)
content, err := os.ReadFile(file)
if err != nil {
t.Fatal(err)
@ -152,3 +153,41 @@ func TestServerShortLivedConn(t *testing.T) {
}
}
}
func TestServerBatchResponseSizeLimit(t *testing.T) {
server := newTestServer()
defer server.Stop()
server.SetBatchLimits(100, 60)
var (
batch []BatchElem
client = DialInProc(server)
)
for i := 0; i < 5; i++ {
batch = append(batch, BatchElem{
Method: "test_echo",
Args: []any{"x", 1},
Result: new(echoResult),
})
}
if err := client.BatchCall(batch); err != nil {
t.Fatal("error sending batch:", err)
}
for i := range batch {
// We expect the first two queries to be ok, but after that the size limit takes effect.
if i < 2 {
if batch[i].Error != nil {
t.Fatalf("batch elem %d has unexpected error: %v", i, batch[i].Error)
}
continue
}
// After two, we expect an error.
re, ok := batch[i].Error.(Error)
if !ok {
t.Fatalf("batch elem %d has wrong error: %v", i, batch[i].Error)
}
wantedCode := errcodeResponseTooLarge
if re.ErrorCode() != wantedCode {
t.Errorf("batch elem %d wrong error code, have %d want %d", i, re.ErrorCode(), wantedCode)
}
}
}

View file

@ -32,7 +32,8 @@ func DialStdIO(ctx context.Context) (*Client, error) {
// DialIO creates a client which uses the given IO channels
func DialIO(ctx context.Context, in io.Reader, out io.Writer) (*Client, error) {
return newClient(ctx, newClientTransportIO(in, out))
cfg := new(clientConfig)
return newClient(ctx, cfg, newClientTransportIO(in, out))
}
func newClientTransportIO(in io.Reader, out io.Writer) reconnectFunc {

13
rpc/testdata/invalid-batch-toolarge.js vendored Normal file
View file

@ -0,0 +1,13 @@
// This file checks the behavior of the batch item limit code.
// In tests, the batch item limit is set to 4. So to trigger the error,
// all batches in this file have 5 elements.
// For batches that do not contain any calls, a response message with "id" == null
// is returned.
--> [{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]}]
<-- [{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"batch too large"}}]
// For batches with at least one call, the call's "id" is used.
--> [{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","id":3,"method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]},{"jsonrpc":"2.0","method":"test_echo","params":["x",99]}]
<-- [{"jsonrpc":"2.0","id":3,"error":{"code":-32600,"message":"batch too large"}}]

View file

@ -197,7 +197,7 @@ func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, diale
if err != nil {
return nil, err
}
return newClient(ctx, connect)
return newClient(ctx, cfg, connect)
}
// DialWebsocket creates a new RPC client that communicates with a JSON-RPC server
@ -214,7 +214,7 @@ func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error
if err != nil {
return nil, err
}
return newClient(ctx, connect)
return newClient(ctx, cfg, connect)
}
func newClientTransportWS(endpoint string, cfg *clientConfig) (reconnectFunc, error) {