diff --git a/indexer/services/socks/__tests__/lib/invalid-message.test.ts b/indexer/services/socks/__tests__/lib/invalid-message.test.ts new file mode 100644 index 0000000000..f9eb8c2b05 --- /dev/null +++ b/indexer/services/socks/__tests__/lib/invalid-message.test.ts @@ -0,0 +1,57 @@ +import { InvalidMessageHandler } from '../../src/lib/invalid-message'; +import { RateLimiter } from '../../src/lib/rate-limit'; +import { sendMessage } from '../../src/helpers/wss'; +import WebSocket from 'ws'; +import { WS_CLOSE_CODE_POLICY_VIOLATION } from '../../src/lib/constants'; +import { Connection, WebsocketEvents } from '../../src/types'; + +jest.mock('../../src/lib/rate-limit'); +jest.mock('../../src/helpers/wss', () => ({ + sendMessage: jest.fn(), +})); + +describe('InvalidMessageHandler', () => { + let invalidMessageHandler: InvalidMessageHandler; + let mockConnection: Connection; + const connectionId = 'testConnectionId'; + const responseMessage = 'Test response message'; + + beforeEach(() => { + (RateLimiter as jest.Mock).mockImplementation(() => ({ + rateLimit: jest.fn().mockReturnValue(0), + removeConnection: jest.fn(), + })); + + mockConnection = { + ws: { + close: jest.fn(), + removeAllListeners: jest.fn(), + } as unknown as WebSocket, + messageId: 1, + }; + }); + + test('should send normal response message if not rate-limited', () => { + invalidMessageHandler = new InvalidMessageHandler(); + invalidMessageHandler.handleInvalidMessage(responseMessage, mockConnection, connectionId); + + expect(sendMessage).toHaveBeenCalled(); + expect(mockConnection.ws.close).not.toHaveBeenCalled(); + }); + + test('should rate limit, close connection, remove all event listeners for messages if over limit', () => { + (RateLimiter as jest.Mock).mockImplementation(() => ({ + rateLimit: jest.fn().mockReturnValue(1000), + removeConnection: jest.fn(), + })); + invalidMessageHandler = new InvalidMessageHandler(); + invalidMessageHandler.handleInvalidMessage(responseMessage, mockConnection, connectionId); + + expect(sendMessage).toHaveBeenCalled(); + expect(mockConnection.ws.close).toHaveBeenCalledWith( + WS_CLOSE_CODE_POLICY_VIOLATION, + JSON.stringify({ message: 'Rate limited' }), + ); + expect(mockConnection.ws.removeAllListeners).toHaveBeenCalledWith(WebsocketEvents.MESSAGE); + }); +}); diff --git a/indexer/services/socks/__tests__/websocket/index.test.ts b/indexer/services/socks/__tests__/websocket/index.test.ts index 092f8dd664..8491fd306f 100644 --- a/indexer/services/socks/__tests__/websocket/index.test.ts +++ b/indexer/services/socks/__tests__/websocket/index.test.ts @@ -13,14 +13,12 @@ import { WebsocketEvents, } from '../../src/types'; import { InvalidMessageHandler } from '../../src/lib/invalid-message'; -import { PingHandler } from '../../src/lib/ping'; import { COUNTRY_HEADER_KEY } from '@dydxprotocol-indexer/compliance'; jest.mock('uuid'); jest.mock('../../src/helpers/wss'); jest.mock('../../src/lib/subscription'); jest.mock('../../src/lib/invalid-message'); -jest.mock('../../src/lib/ping'); describe('Index', () => { let index: Index; @@ -30,8 +28,8 @@ describe('Index', () => { let mockConnect: (ws: WebSocket, req: IncomingMessage) => void; let wsOnSpy: jest.SpyInstance; let wsPingSpy: jest.SpyInstance; + let wsPongSpy: jest.SpyInstance; let invalidMsgHandlerSpy: jest.SpyInstance; - let pingHandlerSpy: jest.SpyInstance; const connectionId: string = 'conId'; const countryCode: string = 'AR'; @@ -54,6 +52,7 @@ describe('Index', () => { websocket = new WebSocket(null); wsOnSpy = jest.spyOn(websocket, 'on'); wsPingSpy = jest.spyOn(websocket, 'ping').mockImplementation(jest.fn()); + wsPongSpy = jest.spyOn(websocket, 'pong').mockImplementation(jest.fn()); mockWss.onConnection = jest.fn().mockImplementation( (cb: (ws: WebSocket, req: IncomingMessage) => void) => { mockConnect = cb; @@ -61,7 +60,6 @@ describe('Index', () => { ); mockSub = new Subscriptions(); invalidMsgHandlerSpy = jest.spyOn(InvalidMessageHandler.prototype, 'handleInvalidMessage'); - pingHandlerSpy = jest.spyOn(PingHandler.prototype, 'handlePing'); index = new Index(mockWss, mockSub); }); @@ -76,11 +74,12 @@ describe('Index', () => { expect(index.connections[connectionId].messageId).toEqual(0); // Test that handlers are attached. - expect(wsOnSpy).toHaveBeenCalledTimes(4); + expect(wsOnSpy).toHaveBeenCalledTimes(5); expect(wsOnSpy).toHaveBeenCalledWith(WebsocketEvents.MESSAGE, expect.anything()); expect(wsOnSpy).toHaveBeenCalledWith(WebsocketEvents.CLOSE, expect.anything()); expect(wsOnSpy).toHaveBeenCalledWith(WebsocketEvents.ERROR, expect.anything()); expect(wsOnSpy).toHaveBeenCalledWith(WebsocketEvents.PONG, expect.anything()); + expect(wsOnSpy).toHaveBeenCalledWith(WebsocketEvents.PING, expect.anything()); // Test that a connection messages is sent. expect(sendMessage).toHaveBeenCalledTimes(1); @@ -127,25 +126,6 @@ describe('Index', () => { ); }); - it('handles ping message', () => { - const pingMessage: IncomingMessage = createIncomingMessage( - { type: IncomingMessageType.PING }, - ); - websocket.emit(WebsocketEvents.MESSAGE, JSON.stringify(pingMessage)); - - expect(pingHandlerSpy).toHaveBeenCalledTimes(1); - expect(pingHandlerSpy).toHaveBeenCalledWith( - expect.objectContaining({ - type: IncomingMessageType.PING, - }), - expect.objectContaining({ - ws: websocket, - messageId: index.connections[connectionId].messageId, - }), - connectionId, - ); - }); - // Nested parameterized test of invalid subscribe and unsubscribe message handling. for (const type of [IncomingMessageType.SUBSCRIBE, IncomingMessageType.UNSUBSCRIBE]) { it.each([ @@ -254,12 +234,14 @@ describe('Index', () => { describe('close', () => { it('disconnects connection on close', () => { + jest.spyOn(websocket, 'removeAllListeners').mockImplementation(jest.fn()); jest.spyOn(websocket, 'terminate').mockImplementation(jest.fn()); websocket.emit(WebsocketEvents.CLOSE); // Run timers for heartbeat. jest.runAllTimers(); expect(wsPingSpy).not.toHaveBeenCalled(); + expect(websocket.removeAllListeners).toHaveBeenCalledTimes(1); expect(websocket.terminate).toHaveBeenCalledTimes(1); expect(mockSub.remove).toHaveBeenCalledWith(connectionId); expect(index.connections[connectionId]).toBeUndefined(); @@ -277,6 +259,15 @@ describe('Index', () => { }); }); + describe('ping', () => { + it('sends pong on receiving ping', () => { + (v4 as unknown as jest.Mock).mockReturnValueOnce(connectionId); + mockConnect(websocket, new IncomingMessage(new Socket())); + websocket.emit(WebsocketEvents.PING); + expect(wsPongSpy).toHaveBeenCalledTimes(2); + }); + }); + describe('pong', () => { it('removes delayed disconnect on pong', () => { // Run pending timers to start heartbeat to attach delayed disconnect. diff --git a/indexer/services/socks/src/lib/invalid-message.ts b/indexer/services/socks/src/lib/invalid-message.ts index 3c1944646c..efcfee7ac4 100644 --- a/indexer/services/socks/src/lib/invalid-message.ts +++ b/indexer/services/socks/src/lib/invalid-message.ts @@ -3,7 +3,7 @@ import { logger } from '@dydxprotocol-indexer/base'; import config from '../config'; import { createErrorMessage } from '../helpers/message'; import { sendMessage } from '../helpers/wss'; -import { Connection } from '../types'; +import { Connection, WebsocketEvents } from '../types'; import { WS_CLOSE_CODE_POLICY_VIOLATION } from './constants'; import { RateLimiter } from './rate-limit'; @@ -42,6 +42,7 @@ export class InvalidMessageHandler { WS_CLOSE_CODE_POLICY_VIOLATION, JSON.stringify({ message: 'Rate limited' }), ); + connection.ws.removeAllListeners(WebsocketEvents.MESSAGE); logger.info({ at: 'invalid-message#handleInvalidMessage', diff --git a/indexer/services/socks/src/lib/ping.ts b/indexer/services/socks/src/lib/ping.ts deleted file mode 100644 index 85d8a5fc04..0000000000 --- a/indexer/services/socks/src/lib/ping.ts +++ /dev/null @@ -1,73 +0,0 @@ -import { logger } from '@dydxprotocol-indexer/base'; - -import config from '../config'; -import { createErrorMessage, createPongMessage } from '../helpers/message'; -import { sendMessage } from '../helpers/wss'; -import { - Connection, - PingMessage, -} from '../types'; -import { WS_CLOSE_CODE_POLICY_VIOLATION } from './constants'; -import { RateLimiter } from './rate-limit'; - -export class PingHandler { - private rateLimiter: RateLimiter; - - constructor() { - this.rateLimiter = new RateLimiter({ - points: config.RATE_LIMIT_PING_POINTS, - durationMs: config.RATE_LIMIT_PING_DURATION_MS, - }); - } - - public handlePing( - pingMessage: PingMessage, - connection: Connection, - connectionId: string, - ): void { - const duration: number = this.rateLimiter.rateLimit({ - connectionId, - key: 'ping', - }); - if (duration > 0) { - sendMessage( - connection.ws, - connectionId, - createErrorMessage( - 'Too many ping messages. Please reconnect and try again.', - connectionId, - connection.messageId, - ), - ); - - // Violated rate-limit; disconnect. - connection.ws.close( - WS_CLOSE_CODE_POLICY_VIOLATION, - JSON.stringify({ message: 'Rate limited' }), - ); - - logger.info({ - at: 'ping#handlePing', - message: 'Connection closed due to violating rate limit', - connectionId, - }); - return; - } - - sendMessage( - connection.ws, - connectionId, - createPongMessage( - connectionId, - connection.messageId, - pingMessage.id, - ), - ); - } - - public handleDisconnect( - connectionId: string, - ): void { - this.rateLimiter.removeConnection(connectionId); - } -} diff --git a/indexer/services/socks/src/types.ts b/indexer/services/socks/src/types.ts index bb49eda969..c4074bbda9 100644 --- a/indexer/services/socks/src/types.ts +++ b/indexer/services/socks/src/types.ts @@ -149,4 +149,5 @@ export enum WebsocketEvents { LISTENING = 'listening', MESSAGE = 'message', PONG = 'pong', + PING = 'ping', } diff --git a/indexer/services/socks/src/websocket/index.ts b/indexer/services/socks/src/websocket/index.ts index 003532c2ee..ebdf4120d8 100644 --- a/indexer/services/socks/src/websocket/index.ts +++ b/indexer/services/socks/src/websocket/index.ts @@ -1,30 +1,24 @@ import { - stats, logger, safeJsonStringify, InfoObject, + InfoObject, logger, safeJsonStringify, stats, } from '@dydxprotocol-indexer/base'; import { v4 as uuidv4 } from 'uuid'; import WebSocket from 'ws'; import config from '../config'; import { getCountry } from '../helpers/header-utils'; -import { - createErrorMessage, - createConnectedMessage, - createUnsubscribedMessage, -} from '../helpers/message'; -import { Wss, sendMessage } from '../helpers/wss'; +import { createConnectedMessage, createErrorMessage, createUnsubscribedMessage } from '../helpers/message'; +import { sendMessage, Wss } from '../helpers/wss'; import { ERR_INVALID_WEBSOCKET_FRAME, WS_CLOSE_CODE_SERVICE_RESTART } from '../lib/constants'; import { InvalidMessageHandler } from '../lib/invalid-message'; -import { PingHandler } from '../lib/ping'; import { Subscriptions } from '../lib/subscription'; import { - IncomingMessageType, + ALL_CHANNELS, Channel, + Connection, IncomingMessage, + IncomingMessageType, SubscribeMessage, UnsubscribeMessage, - Connection, - PingMessage, - ALL_CHANNELS, WebsocketEvents, } from '../types'; @@ -40,14 +34,12 @@ export class Index { // Subscriptions tracking object (see lib/subscriptions.ts). private subscriptions: Subscriptions; // Handlers for pings and invalid messages. - private pingHandler: PingHandler; private invalidMessageHandler: InvalidMessageHandler; constructor(wss: Wss, subscriptions: Subscriptions) { this.wss = wss; this.connections = {}; this.subscriptions = subscriptions; - this.pingHandler = new PingHandler(); this.invalidMessageHandler = new InvalidMessageHandler(); // Attach the new connection handler to the websocket server. @@ -163,14 +155,8 @@ export class Index { HEARTBEAT_INTERVAL_MS, ); - // Attach handler for pongs (response to heartbeat pings) from connection. + // Attach handler for pongs (response to heartbeat [ping]s) from connection. this.connections[connectionId].ws.on(WebsocketEvents.PONG, () => { - logger.info({ - at: 'index#onPong', - message: 'Received pong', - connectionId, - }); - // Clear the delayed disconnect set by the heartbeat handler when a pong is received. if (this.connections[connectionId].disconnect) { clearTimeout(this.connections[connectionId].disconnect); @@ -178,6 +164,10 @@ export class Index { } }); + this.connections[connectionId].ws.on(WebsocketEvents.PING, (data: Buffer) => { + ws.pong(data); + }); + // Attach handler for close events from the connection. this.connections[connectionId].ws.on(WebsocketEvents.CLOSE, (code: number, reason: Buffer) => { logger.info({ @@ -315,12 +305,8 @@ export class Index { ); break; } + // TODO: Consider custom ping messages as invalid after publishing updated documentation. case IncomingMessageType.PING: { - this.pingHandler.handlePing( - parsed as PingMessage, - this.connections[connectionId], - connectionId, - ); break; } default: { @@ -392,11 +378,11 @@ export class Index { if (this.connections[connectionId].heartbeat) { clearInterval(this.connections[connectionId].heartbeat); } + this.connections[connectionId].ws.removeAllListeners(); this.connections[connectionId].ws.terminate(); // Delete subscription data. this.subscriptions.remove(connectionId); - this.pingHandler.handleDisconnect(connectionId); this.invalidMessageHandler.handleDisconnect(connectionId); delete this.connections[connectionId]; } catch (error) {