diff --git a/api/config.go b/api/config.go index f937750c7c..60aa28c67e 100644 --- a/api/config.go +++ b/api/config.go @@ -25,6 +25,8 @@ type Config struct { BatchRequestLimit int `yaml:"batchRequestLimit"` // WebsocketRateLimit is the maximum number of messages per second per client. WebsocketRateLimit int `yaml:"websocketRateLimit"` + // ListenerLimit is the maximum number of listeners. + ListenerLimit int `yaml:"listenerLimit"` } // DefaultConfig is the default config @@ -38,4 +40,5 @@ var DefaultConfig = Config{ RangeQueryLimit: 1000, BatchRequestLimit: _defaultBatchRequestLimit, WebsocketRateLimit: 5, + ListenerLimit: 5000, } diff --git a/api/context.go b/api/context.go new file mode 100644 index 0000000000..d772ad6644 --- /dev/null +++ b/api/context.go @@ -0,0 +1,48 @@ +package api + +import ( + "context" + "sync" +) + +type ( + streamContextKey struct{} + + StreamContext struct { + listenerIDs map[string]struct{} + mutex sync.Mutex + } +) + +func (sc *StreamContext) AddListener(id string) { + sc.mutex.Lock() + defer sc.mutex.Unlock() + sc.listenerIDs[id] = struct{}{} +} + +func (sc *StreamContext) RemoveListener(id string) { + sc.mutex.Lock() + defer sc.mutex.Unlock() + delete(sc.listenerIDs, id) +} + +func (sc *StreamContext) ListenerIDs() []string { + sc.mutex.Lock() + defer sc.mutex.Unlock() + ids := make([]string, 0, len(sc.listenerIDs)) + for id := range sc.listenerIDs { + ids = append(ids, id) + } + return ids +} + +func WithStreamContext(ctx context.Context) context.Context { + return context.WithValue(ctx, streamContextKey{}, &StreamContext{ + listenerIDs: make(map[string]struct{}), + }) +} + +func StreamFromContext(ctx context.Context) (*StreamContext, bool) { + sc, ok := ctx.Value(streamContextKey{}).(*StreamContext) + return sc, ok +} diff --git a/api/coreservice.go b/api/coreservice.go index efda2b4c6d..402c105c4b 100644 --- a/api/coreservice.go +++ b/api/coreservice.go @@ -279,7 +279,7 @@ func newCoreService( ap: actPool, cfg: cfg, registry: registry, - chainListener: NewChainListener(500), + chainListener: NewChainListener(cfg.ListenerLimit), gs: gasstation.NewGasStation(chain, dao, cfg.GasStation), readCache: NewReadCache(), getBlockTime: getBlockTime, diff --git a/api/grpcserver.go b/api/grpcserver.go index bbff0696ac..89321ccd9c 100644 --- a/api/grpcserver.go +++ b/api/grpcserver.go @@ -573,15 +573,17 @@ func (svr *gRPCHandler) StreamBlocks(_ *iotexapi.StreamBlocksRequest, stream iot errChan := make(chan error) defer close(errChan) chainListener := svr.coreService.ChainListener() - if _, err := chainListener.AddResponder(NewGRPCBlockListener( + id, err := chainListener.AddResponder(NewGRPCBlockListener( func(resp interface{}) (int, error) { return 0, stream.Send(resp.(*iotexapi.StreamBlocksResponse)) }, errChan, - )); err != nil { + )) + if err != nil { return status.Error(codes.Internal, err.Error()) } - err := <-errChan + err = <-errChan + chainListener.RemoveResponder(id) if err != nil { return status.Error(codes.Aborted, err.Error()) } @@ -596,16 +598,18 @@ func (svr *gRPCHandler) StreamLogs(in *iotexapi.StreamLogsRequest, stream iotexa errChan := make(chan error) defer close(errChan) chainListener := svr.coreService.ChainListener() - if _, err := chainListener.AddResponder(NewGRPCLogListener( + id, err := chainListener.AddResponder(NewGRPCLogListener( logfilter.NewLogFilter(in.GetFilter()), func(in interface{}) (int, error) { return 0, stream.Send(in.(*iotexapi.StreamLogsResponse)) }, errChan, - )); err != nil { + )) + if err != nil { return status.Error(codes.Internal, err.Error()) } - err := <-errChan + err = <-errChan + chainListener.RemoveResponder(id) if err != nil { return status.Error(codes.Aborted, err.Error()) } diff --git a/api/grpcserver_test.go b/api/grpcserver_test.go index 73d1ea85f8..4f9e10266d 100644 --- a/api/grpcserver_test.go +++ b/api/grpcserver_test.go @@ -358,6 +358,9 @@ func TestGrpcServer_StreamBlocks(t *testing.T) { }() return "", nil }) + listener.EXPECT().RemoveResponder(gomock.Any()).DoAndReturn(func(string) (bool, error) { + return true, nil + }) core.EXPECT().ChainListener().Return(listener) err := grpcSvr.StreamBlocks(&iotexapi.StreamBlocksRequest{}, nil) require.NoError(err) @@ -390,6 +393,9 @@ func TestGrpcServer_StreamLogs(t *testing.T) { }() return "", nil }) + listener.EXPECT().RemoveResponder(gomock.Any()).DoAndReturn(func(string) (bool, error) { + return true, nil + }) core.EXPECT().ChainListener().Return(listener) err := grpcSvr.StreamLogs(&iotexapi.StreamLogsRequest{Filter: &iotexapi.LogsFilter{}}, nil) require.NoError(err) diff --git a/api/listener.go b/api/listener.go index 8770163dd1..042bedcd45 100644 --- a/api/listener.go +++ b/api/listener.go @@ -63,6 +63,7 @@ func (cl *chainListener) Stop() error { return nil }) cl.streamMap.Reset() + apiLimitMtcs.WithLabelValues("listener").Set(float64(cl.streamMap.Count())) return nil } @@ -105,6 +106,7 @@ func (cl *chainListener) AddResponder(responder apitypes.Responder) (string, err } cl.streamMap.Set(listenerID, responder) + apiLimitMtcs.WithLabelValues("listener").Set(float64(cl.streamMap.Count())) return listenerID, nil } @@ -122,6 +124,7 @@ func (cl *chainListener) RemoveResponder(listenerID string) (bool, error) { return false, errListenerNotFound } r.Exit() + apiLimitMtcs.WithLabelValues("listener").Set(float64(cl.streamMap.Count() - 1)) return cl.streamMap.Delete(listenerID), nil } diff --git a/api/metrics.go b/api/metrics.go new file mode 100644 index 0000000000..98352870aa --- /dev/null +++ b/api/metrics.go @@ -0,0 +1,14 @@ +package api + +import "github.com/prometheus/client_golang/prometheus" + +var ( + apiLimitMtcs = prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Name: "iotex_api_limit_metrics", + Help: "api limit metrics.", + }, []string{"limit"}) +) + +func init() { + prometheus.MustRegister(apiLimitMtcs) +} diff --git a/api/serverV2.go b/api/serverV2.go index f40e65586f..728a437f2a 100644 --- a/api/serverV2.go +++ b/api/serverV2.go @@ -67,7 +67,7 @@ func NewServerV2( wrappedWeb3Handler := otelhttp.NewHandler(newHTTPHandler(web3Handler), "web3.jsonrpc") limiter := rate.NewLimiter(rate.Limit(cfg.WebsocketRateLimit), 1) - wrappedWebsocketHandler := otelhttp.NewHandler(NewWebsocketHandler(web3Handler, limiter), "web3.websocket") + wrappedWebsocketHandler := otelhttp.NewHandler(NewWebsocketHandler(coreAPI, web3Handler, limiter), "web3.websocket") return &ServerV2{ core: coreAPI, diff --git a/api/serverV2_test.go b/api/serverV2_test.go index aa6d84e7dd..c3df0a9e77 100644 --- a/api/serverV2_test.go +++ b/api/serverV2_test.go @@ -28,7 +28,7 @@ func TestServerV2(t *testing.T) { core: core, grpcServer: NewGRPCServer(core, testutil.RandomPort()), httpSvr: NewHTTPServer("", testutil.RandomPort(), newHTTPHandler(web3Handler)), - websocketSvr: NewHTTPServer("", testutil.RandomPort(), NewWebsocketHandler(web3Handler, nil)), + websocketSvr: NewHTTPServer("", testutil.RandomPort(), NewWebsocketHandler(core, web3Handler, nil)), } ctx := context.Background() diff --git a/api/web3server.go b/api/web3server.go index 18eeae6c26..0a746581bc 100644 --- a/api/web3server.go +++ b/api/web3server.go @@ -82,6 +82,7 @@ var ( errInvalidBlock = errors.New("invalid block") errUnsupportedAction = errors.New("the type of action is not supported") errMsgBatchTooLarge = errors.New("batch too large") + errHTTPNotSupported = errors.New("http not supported") _pendingBlockNumber = "pending" _latestBlockNumber = "latest" @@ -224,7 +225,11 @@ func (svr *web3Handler) handleWeb3Req(ctx context.Context, web3Req *gjson.Result case "eth_newBlockFilter": res, err = svr.newBlockFilter() case "eth_subscribe": - res, err = svr.subscribe(web3Req, writer) + sc, ok := StreamFromContext(ctx) + if !ok { + return errHTTPNotSupported + } + res, err = svr.subscribe(sc, web3Req, writer) case "eth_unsubscribe": res, err = svr.unsubscribe(web3Req) //TODO: enable debug api after archive mode is supported @@ -924,35 +929,36 @@ func (svr *web3Handler) getFilterLogs(in *gjson.Result) (interface{}, error) { return svr.getLogsWithFilter(from, to, filterObj.Address, filterObj.Topics) } -func (svr *web3Handler) subscribe(in *gjson.Result, writer apitypes.Web3ResponseWriter) (interface{}, error) { +func (svr *web3Handler) subscribe(ctx *StreamContext, in *gjson.Result, writer apitypes.Web3ResponseWriter) (interface{}, error) { subscription := in.Get("params.0") if !subscription.Exists() { return nil, errInvalidFormat } switch subscription.String() { case "newHeads": - return svr.streamBlocks(writer) + return svr.streamBlocks(ctx, writer) case "logs": filter, err := parseLogRequest(in.Get("params.1")) if err != nil { return nil, err } - return svr.streamLogs(filter, writer) + return svr.streamLogs(ctx, filter, writer) default: return nil, errInvalidFormat } } -func (svr *web3Handler) streamBlocks(writer apitypes.Web3ResponseWriter) (interface{}, error) { +func (svr *web3Handler) streamBlocks(ctx *StreamContext, writer apitypes.Web3ResponseWriter) (interface{}, error) { chainListener := svr.coreService.ChainListener() streamID, err := chainListener.AddResponder(NewWeb3BlockListener(writer.Write)) if err != nil { return nil, err } + ctx.AddListener(streamID) return streamID, nil } -func (svr *web3Handler) streamLogs(filterObj *filterObject, writer apitypes.Web3ResponseWriter) (interface{}, error) { +func (svr *web3Handler) streamLogs(ctx *StreamContext, filterObj *filterObject, writer apitypes.Web3ResponseWriter) (interface{}, error) { filter, err := newLogFilterFrom(filterObj.Address, filterObj.Topics) if err != nil { return nil, err @@ -962,6 +968,7 @@ func (svr *web3Handler) streamLogs(filterObj *filterObject, writer apitypes.Web3 if err != nil { return nil, err } + ctx.AddListener(streamID) return streamID, nil } diff --git a/api/web3server_test.go b/api/web3server_test.go index 4327f1a4e9..a94624a637 100644 --- a/api/web3server_test.go +++ b/api/web3server_test.go @@ -1125,34 +1125,39 @@ func TestSubscribe(t *testing.T) { t.Run("newHeads subscription", func(t *testing.T) { in := gjson.Parse(`{"params":["newHeads"]}`) - ret, err := web3svr.subscribe(&in, writer) + sc, _ := StreamFromContext(WithStreamContext(context.Background())) + ret, err := web3svr.subscribe(sc, &in, writer) require.NoError(err) require.Equal("streamid_1", ret.(string)) }) t.Run("logs subscription", func(t *testing.T) { in := gjson.Parse(`{"params":["logs",{"fromBlock":"1","fromBlock":"2","address":["0x0000000000000000000000000000000000000001"],"topics":[["0x5f746f70696331"]]}]}`) - ret, err := web3svr.subscribe(&in, writer) + sc, _ := StreamFromContext(WithStreamContext(context.Background())) + ret, err := web3svr.subscribe(sc, &in, writer) require.NoError(err) require.Equal("streamid_1", ret.(string)) }) t.Run("logs topic not array", func(t *testing.T) { in := gjson.Parse(`{"params":["logs",{"fromBlock":"1","fromBlock":"2","address":["0x0000000000000000000000000000000000000001"],"topics":["0x5f746f70696331"]}]}`) - ret, err := web3svr.subscribe(&in, writer) + sc, _ := StreamFromContext(WithStreamContext(context.Background())) + ret, err := web3svr.subscribe(sc, &in, writer) require.NoError(err) require.Equal("streamid_1", ret.(string)) }) t.Run("nil params", func(t *testing.T) { inNil := gjson.Parse(`{"params":[]}`) - _, err := web3svr.subscribe(&inNil, writer) + sc, _ := StreamFromContext(WithStreamContext(context.Background())) + _, err := web3svr.subscribe(sc, &inNil, writer) require.EqualError(err, errInvalidFormat.Error()) }) t.Run("nil logs", func(t *testing.T) { inNil := gjson.Parse(`{"params":["logs"]}`) - _, err := web3svr.subscribe(&inNil, writer) + sc, _ := StreamFromContext(WithStreamContext(context.Background())) + _, err := web3svr.subscribe(sc, &inNil, writer) require.EqualError(err, errInvalidFormat.Error()) }) } diff --git a/api/websocket.go b/api/websocket.go index 784f0982df..13476f58c2 100644 --- a/api/websocket.go +++ b/api/websocket.go @@ -31,8 +31,9 @@ const ( // WebsocketHandler handles requests from websocket protocol type WebsocketHandler struct { - msgHandler Web3Handler - limiter *rate.Limiter + coreService CoreService + msgHandler Web3Handler + limiter *rate.Limiter } var upgrader = websocket.Upgrader{ @@ -75,14 +76,15 @@ func (c *safeWebsocketConn) SetWriteDeadline(t time.Time) error { } // NewWebsocketHandler creates a new websocket handler -func NewWebsocketHandler(web3Handler Web3Handler, limiter *rate.Limiter) *WebsocketHandler { +func NewWebsocketHandler(coreService CoreService, web3Handler Web3Handler, limiter *rate.Limiter) *WebsocketHandler { if limiter == nil { // set the limiter to the maximum possible rate limiter = rate.NewLimiter(rate.Limit(math.MaxFloat64), 1) } return &WebsocketHandler{ - msgHandler: web3Handler, - limiter: limiter, + msgHandler: web3Handler, + limiter: limiter, + coreService: coreService, } } @@ -112,10 +114,18 @@ func (wsSvr *WebsocketHandler) handleConnection(ctx context.Context, ws *websock return nil }) - ctx, cancel := context.WithCancel(ctx) + ctx, cancel := context.WithCancel(WithStreamContext(ctx)) safeWs := &safeWebsocketConn{ws: ws} go ping(ctx, safeWs, cancel) + defer func() { + // clean up the stream context + sc, _ := StreamFromContext(ctx) + for _, id := range sc.ListenerIDs() { + wsSvr.coreService.ChainListener().RemoveResponder(id) + } + }() + for { select { case <-ctx.Done():