Skip to content

Commit

Permalink
Cherry-pick changes (#1303)
Browse files Browse the repository at this point in the history
* Remove custom ping message from socks (#1301)

* sample more metrics (#1304)
  • Loading branch information
dydxwill authored Apr 1, 2024
1 parent d59bd14 commit 6cc1b8c
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 127 deletions.
57 changes: 57 additions & 0 deletions indexer/services/socks/__tests__/lib/invalid-message.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import { InvalidMessageHandler } from '../../src/lib/invalid-message';
import { RateLimiter } from '../../src/lib/rate-limit';
import { sendMessage } from '../../src/helpers/wss';
import WebSocket from 'ws';
import { WS_CLOSE_CODE_POLICY_VIOLATION } from '../../src/lib/constants';
import { Connection, WebsocketEvents } from '../../src/types';

jest.mock('../../src/lib/rate-limit');
jest.mock('../../src/helpers/wss', () => ({
sendMessage: jest.fn(),
}));

describe('InvalidMessageHandler', () => {
let invalidMessageHandler: InvalidMessageHandler;
let mockConnection: Connection;
const connectionId = 'testConnectionId';
const responseMessage = 'Test response message';

beforeEach(() => {
(RateLimiter as jest.Mock).mockImplementation(() => ({
rateLimit: jest.fn().mockReturnValue(0),
removeConnection: jest.fn(),
}));

mockConnection = {
ws: {
close: jest.fn(),
removeAllListeners: jest.fn(),
} as unknown as WebSocket,
messageId: 1,
};
});

test('should send normal response message if not rate-limited', () => {
invalidMessageHandler = new InvalidMessageHandler();
invalidMessageHandler.handleInvalidMessage(responseMessage, mockConnection, connectionId);

expect(sendMessage).toHaveBeenCalled();
expect(mockConnection.ws.close).not.toHaveBeenCalled();
});

test('should rate limit, close connection, remove all event listeners for messages if over limit', () => {
(RateLimiter as jest.Mock).mockImplementation(() => ({
rateLimit: jest.fn().mockReturnValue(1000),
removeConnection: jest.fn(),
}));
invalidMessageHandler = new InvalidMessageHandler();
invalidMessageHandler.handleInvalidMessage(responseMessage, mockConnection, connectionId);

expect(sendMessage).toHaveBeenCalled();
expect(mockConnection.ws.close).toHaveBeenCalledWith(
WS_CLOSE_CODE_POLICY_VIOLATION,
JSON.stringify({ message: 'Rate limited' }),
);
expect(mockConnection.ws.removeAllListeners).toHaveBeenCalledWith(WebsocketEvents.MESSAGE);
});
});
39 changes: 15 additions & 24 deletions indexer/services/socks/__tests__/websocket/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,12 @@ import {
WebsocketEvents,
} from '../../src/types';
import { InvalidMessageHandler } from '../../src/lib/invalid-message';
import { PingHandler } from '../../src/lib/ping';
import { COUNTRY_HEADER_KEY } from '@dydxprotocol-indexer/compliance';

jest.mock('uuid');
jest.mock('../../src/helpers/wss');
jest.mock('../../src/lib/subscription');
jest.mock('../../src/lib/invalid-message');
jest.mock('../../src/lib/ping');

describe('Index', () => {
let index: Index;
Expand All @@ -30,8 +28,8 @@ describe('Index', () => {
let mockConnect: (ws: WebSocket, req: IncomingMessage) => void;
let wsOnSpy: jest.SpyInstance;
let wsPingSpy: jest.SpyInstance;
let wsPongSpy: jest.SpyInstance;
let invalidMsgHandlerSpy: jest.SpyInstance;
let pingHandlerSpy: jest.SpyInstance;

const connectionId: string = 'conId';
const countryCode: string = 'AR';
Expand All @@ -54,14 +52,14 @@ describe('Index', () => {
websocket = new WebSocket(null);
wsOnSpy = jest.spyOn(websocket, 'on');
wsPingSpy = jest.spyOn(websocket, 'ping').mockImplementation(jest.fn());
wsPongSpy = jest.spyOn(websocket, 'pong').mockImplementation(jest.fn());
mockWss.onConnection = jest.fn().mockImplementation(
(cb: (ws: WebSocket, req: IncomingMessage) => void) => {
mockConnect = cb;
},
);
mockSub = new Subscriptions();
invalidMsgHandlerSpy = jest.spyOn(InvalidMessageHandler.prototype, 'handleInvalidMessage');
pingHandlerSpy = jest.spyOn(PingHandler.prototype, 'handlePing');
index = new Index(mockWss, mockSub);
});

Expand All @@ -76,11 +74,12 @@ describe('Index', () => {
expect(index.connections[connectionId].messageId).toEqual(0);

// Test that handlers are attached.
expect(wsOnSpy).toHaveBeenCalledTimes(4);
expect(wsOnSpy).toHaveBeenCalledTimes(5);
expect(wsOnSpy).toHaveBeenCalledWith(WebsocketEvents.MESSAGE, expect.anything());
expect(wsOnSpy).toHaveBeenCalledWith(WebsocketEvents.CLOSE, expect.anything());
expect(wsOnSpy).toHaveBeenCalledWith(WebsocketEvents.ERROR, expect.anything());
expect(wsOnSpy).toHaveBeenCalledWith(WebsocketEvents.PONG, expect.anything());
expect(wsOnSpy).toHaveBeenCalledWith(WebsocketEvents.PING, expect.anything());

// Test that a connection messages is sent.
expect(sendMessage).toHaveBeenCalledTimes(1);
Expand Down Expand Up @@ -127,25 +126,6 @@ describe('Index', () => {
);
});

it('handles ping message', () => {
const pingMessage: IncomingMessage = createIncomingMessage(
{ type: IncomingMessageType.PING },
);
websocket.emit(WebsocketEvents.MESSAGE, JSON.stringify(pingMessage));

expect(pingHandlerSpy).toHaveBeenCalledTimes(1);
expect(pingHandlerSpy).toHaveBeenCalledWith(
expect.objectContaining({
type: IncomingMessageType.PING,
}),
expect.objectContaining({
ws: websocket,
messageId: index.connections[connectionId].messageId,
}),
connectionId,
);
});

// Nested parameterized test of invalid subscribe and unsubscribe message handling.
for (const type of [IncomingMessageType.SUBSCRIBE, IncomingMessageType.UNSUBSCRIBE]) {
it.each([
Expand Down Expand Up @@ -254,12 +234,14 @@ describe('Index', () => {

describe('close', () => {
it('disconnects connection on close', () => {
jest.spyOn(websocket, 'removeAllListeners').mockImplementation(jest.fn());
jest.spyOn(websocket, 'terminate').mockImplementation(jest.fn());
websocket.emit(WebsocketEvents.CLOSE);
// Run timers for heartbeat.
jest.runAllTimers();

expect(wsPingSpy).not.toHaveBeenCalled();
expect(websocket.removeAllListeners).toHaveBeenCalledTimes(1);
expect(websocket.terminate).toHaveBeenCalledTimes(1);
expect(mockSub.remove).toHaveBeenCalledWith(connectionId);
expect(index.connections[connectionId]).toBeUndefined();
Expand All @@ -277,6 +259,15 @@ describe('Index', () => {
});
});

describe('ping', () => {
it('sends pong on receiving ping', () => {
(v4 as unknown as jest.Mock).mockReturnValueOnce(connectionId);
mockConnect(websocket, new IncomingMessage(new Socket()));
websocket.emit(WebsocketEvents.PING);
expect(wsPongSpy).toHaveBeenCalledTimes(2);
});
});

describe('pong', () => {
it('removes delayed disconnect on pong', () => {
// Run pending timers to start heartbeat to attach delayed disconnect.
Expand Down
3 changes: 2 additions & 1 deletion indexer/services/socks/src/lib/invalid-message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { logger } from '@dydxprotocol-indexer/base';
import config from '../config';
import { createErrorMessage } from '../helpers/message';
import { sendMessage } from '../helpers/wss';
import { Connection } from '../types';
import { Connection, WebsocketEvents } from '../types';
import { WS_CLOSE_CODE_POLICY_VIOLATION } from './constants';
import { RateLimiter } from './rate-limit';

Expand Down Expand Up @@ -42,6 +42,7 @@ export class InvalidMessageHandler {
WS_CLOSE_CODE_POLICY_VIOLATION,
JSON.stringify({ message: 'Rate limited' }),
);
connection.ws.removeAllListeners(WebsocketEvents.MESSAGE);

logger.info({
at: 'invalid-message#handleInvalidMessage',
Expand Down
73 changes: 0 additions & 73 deletions indexer/services/socks/src/lib/ping.ts

This file was deleted.

1 change: 1 addition & 0 deletions indexer/services/socks/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -149,4 +149,5 @@ export enum WebsocketEvents {
LISTENING = 'listening',
MESSAGE = 'message',
PONG = 'pong',
PING = 'ping',
}
52 changes: 23 additions & 29 deletions indexer/services/socks/src/websocket/index.ts
Original file line number Diff line number Diff line change
@@ -1,30 +1,24 @@
import {
stats, logger, safeJsonStringify, InfoObject,
InfoObject, logger, safeJsonStringify, stats,
} from '@dydxprotocol-indexer/base';
import { v4 as uuidv4 } from 'uuid';
import WebSocket from 'ws';

import config from '../config';
import { getCountry } from '../helpers/header-utils';
import {
createErrorMessage,
createConnectedMessage,
createUnsubscribedMessage,
} from '../helpers/message';
import { Wss, sendMessage } from '../helpers/wss';
import { createConnectedMessage, createErrorMessage, createUnsubscribedMessage } from '../helpers/message';
import { sendMessage, Wss } from '../helpers/wss';
import { ERR_INVALID_WEBSOCKET_FRAME, WS_CLOSE_CODE_SERVICE_RESTART } from '../lib/constants';
import { InvalidMessageHandler } from '../lib/invalid-message';
import { PingHandler } from '../lib/ping';
import { Subscriptions } from '../lib/subscription';
import {
IncomingMessageType,
ALL_CHANNELS,
Channel,
Connection,
IncomingMessage,
IncomingMessageType,
SubscribeMessage,
UnsubscribeMessage,
Connection,
PingMessage,
ALL_CHANNELS,
WebsocketEvents,
} from '../types';

Expand All @@ -40,14 +34,12 @@ export class Index {
// Subscriptions tracking object (see lib/subscriptions.ts).
private subscriptions: Subscriptions;
// Handlers for pings and invalid messages.
private pingHandler: PingHandler;
private invalidMessageHandler: InvalidMessageHandler;

constructor(wss: Wss, subscriptions: Subscriptions) {
this.wss = wss;
this.connections = {};
this.subscriptions = subscriptions;
this.pingHandler = new PingHandler();
this.invalidMessageHandler = new InvalidMessageHandler();

// Attach the new connection handler to the websocket server.
Expand Down Expand Up @@ -163,21 +155,19 @@ export class Index {
HEARTBEAT_INTERVAL_MS,
);

// Attach handler for pongs (response to heartbeat pings) from connection.
// Attach handler for pongs (response to heartbeat [ping]s) from connection.
this.connections[connectionId].ws.on(WebsocketEvents.PONG, () => {
logger.info({
at: 'index#onPong',
message: 'Received pong',
connectionId,
});

// Clear the delayed disconnect set by the heartbeat handler when a pong is received.
if (this.connections[connectionId].disconnect) {
clearTimeout(this.connections[connectionId].disconnect);
delete this.connections[connectionId].disconnect;
}
});

this.connections[connectionId].ws.on(WebsocketEvents.PING, (data: Buffer) => {
ws.pong(data);
});

// Attach handler for close events from the connection.
this.connections[connectionId].ws.on(WebsocketEvents.CLOSE, (code: number, reason: Buffer) => {
logger.info({
Expand Down Expand Up @@ -224,7 +214,11 @@ export class Index {
* @returns
*/
private onMessage(connectionId: string, message: WebSocket.Data): void {
stats.increment(`${config.SERVICE_NAME}.on_message`, 1);
stats.increment(
`${config.SERVICE_NAME}.on_message`,
1,
config.MESSAGE_FORWARDER_STATSD_SAMPLE_RATE,
);
if (!this.connections[connectionId]) {
logger.info({
at: 'index#onMessage',
Expand Down Expand Up @@ -315,12 +309,8 @@ export class Index {
);
break;
}
// TODO: Consider custom ping messages as invalid after publishing updated documentation.
case IncomingMessageType.PING: {
this.pingHandler.handlePing(
parsed as PingMessage,
this.connections[connectionId],
connectionId,
);
break;
}
default: {
Expand All @@ -332,7 +322,11 @@ export class Index {
return;
}
}
stats.increment(`${config.SERVICE_NAME}.message_received_${parsed.type}`, 1);
stats.increment(
`${config.SERVICE_NAME}.message_received_${parsed.type}`,
1,
config.MESSAGE_FORWARDER_STATSD_SAMPLE_RATE,
);
}

/**
Expand Down Expand Up @@ -392,11 +386,11 @@ export class Index {
if (this.connections[connectionId].heartbeat) {
clearInterval(this.connections[connectionId].heartbeat);
}
this.connections[connectionId].ws.removeAllListeners();
this.connections[connectionId].ws.terminate();

// Delete subscription data.
this.subscriptions.remove(connectionId);
this.pingHandler.handleDisconnect(connectionId);
this.invalidMessageHandler.handleDisconnect(connectionId);
delete this.connections[connectionId];
} catch (error) {
Expand Down

0 comments on commit 6cc1b8c

Please sign in to comment.