Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cherry-pick changes #1303

Merged
merged 2 commits into from
Apr 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions indexer/services/socks/__tests__/lib/invalid-message.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import { InvalidMessageHandler } from '../../src/lib/invalid-message';
import { RateLimiter } from '../../src/lib/rate-limit';
import { sendMessage } from '../../src/helpers/wss';
import WebSocket from 'ws';
import { WS_CLOSE_CODE_POLICY_VIOLATION } from '../../src/lib/constants';
import { Connection, WebsocketEvents } from '../../src/types';

jest.mock('../../src/lib/rate-limit');
jest.mock('../../src/helpers/wss', () => ({
sendMessage: jest.fn(),
}));

describe('InvalidMessageHandler', () => {
let invalidMessageHandler: InvalidMessageHandler;
let mockConnection: Connection;
const connectionId = 'testConnectionId';
const responseMessage = 'Test response message';

beforeEach(() => {
(RateLimiter as jest.Mock).mockImplementation(() => ({
rateLimit: jest.fn().mockReturnValue(0),
removeConnection: jest.fn(),
}));

mockConnection = {
ws: {
close: jest.fn(),
removeAllListeners: jest.fn(),
} as unknown as WebSocket,
messageId: 1,
};
});

test('should send normal response message if not rate-limited', () => {
invalidMessageHandler = new InvalidMessageHandler();
invalidMessageHandler.handleInvalidMessage(responseMessage, mockConnection, connectionId);

expect(sendMessage).toHaveBeenCalled();
expect(mockConnection.ws.close).not.toHaveBeenCalled();
});

test('should rate limit, close connection, remove all event listeners for messages if over limit', () => {
(RateLimiter as jest.Mock).mockImplementation(() => ({
rateLimit: jest.fn().mockReturnValue(1000),
removeConnection: jest.fn(),
}));
invalidMessageHandler = new InvalidMessageHandler();
invalidMessageHandler.handleInvalidMessage(responseMessage, mockConnection, connectionId);

expect(sendMessage).toHaveBeenCalled();
expect(mockConnection.ws.close).toHaveBeenCalledWith(
WS_CLOSE_CODE_POLICY_VIOLATION,
JSON.stringify({ message: 'Rate limited' }),
);
expect(mockConnection.ws.removeAllListeners).toHaveBeenCalledWith(WebsocketEvents.MESSAGE);
});
});
39 changes: 15 additions & 24 deletions indexer/services/socks/__tests__/websocket/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,12 @@ import {
WebsocketEvents,
} from '../../src/types';
import { InvalidMessageHandler } from '../../src/lib/invalid-message';
import { PingHandler } from '../../src/lib/ping';
import { COUNTRY_HEADER_KEY } from '@dydxprotocol-indexer/compliance';

jest.mock('uuid');
jest.mock('../../src/helpers/wss');
jest.mock('../../src/lib/subscription');
jest.mock('../../src/lib/invalid-message');
jest.mock('../../src/lib/ping');

describe('Index', () => {
let index: Index;
Expand All @@ -30,8 +28,8 @@ describe('Index', () => {
let mockConnect: (ws: WebSocket, req: IncomingMessage) => void;
let wsOnSpy: jest.SpyInstance;
let wsPingSpy: jest.SpyInstance;
let wsPongSpy: jest.SpyInstance;
let invalidMsgHandlerSpy: jest.SpyInstance;
let pingHandlerSpy: jest.SpyInstance;

const connectionId: string = 'conId';
const countryCode: string = 'AR';
Expand All @@ -54,14 +52,14 @@ describe('Index', () => {
websocket = new WebSocket(null);
wsOnSpy = jest.spyOn(websocket, 'on');
wsPingSpy = jest.spyOn(websocket, 'ping').mockImplementation(jest.fn());
wsPongSpy = jest.spyOn(websocket, 'pong').mockImplementation(jest.fn());
mockWss.onConnection = jest.fn().mockImplementation(
(cb: (ws: WebSocket, req: IncomingMessage) => void) => {
mockConnect = cb;
},
);
mockSub = new Subscriptions();
invalidMsgHandlerSpy = jest.spyOn(InvalidMessageHandler.prototype, 'handleInvalidMessage');
pingHandlerSpy = jest.spyOn(PingHandler.prototype, 'handlePing');
index = new Index(mockWss, mockSub);
});

Expand All @@ -76,11 +74,12 @@ describe('Index', () => {
expect(index.connections[connectionId].messageId).toEqual(0);

// Test that handlers are attached.
expect(wsOnSpy).toHaveBeenCalledTimes(4);
expect(wsOnSpy).toHaveBeenCalledTimes(5);
expect(wsOnSpy).toHaveBeenCalledWith(WebsocketEvents.MESSAGE, expect.anything());
expect(wsOnSpy).toHaveBeenCalledWith(WebsocketEvents.CLOSE, expect.anything());
expect(wsOnSpy).toHaveBeenCalledWith(WebsocketEvents.ERROR, expect.anything());
expect(wsOnSpy).toHaveBeenCalledWith(WebsocketEvents.PONG, expect.anything());
expect(wsOnSpy).toHaveBeenCalledWith(WebsocketEvents.PING, expect.anything());

// Test that a connection messages is sent.
expect(sendMessage).toHaveBeenCalledTimes(1);
Expand Down Expand Up @@ -127,25 +126,6 @@ describe('Index', () => {
);
});

it('handles ping message', () => {
const pingMessage: IncomingMessage = createIncomingMessage(
{ type: IncomingMessageType.PING },
);
websocket.emit(WebsocketEvents.MESSAGE, JSON.stringify(pingMessage));

expect(pingHandlerSpy).toHaveBeenCalledTimes(1);
expect(pingHandlerSpy).toHaveBeenCalledWith(
expect.objectContaining({
type: IncomingMessageType.PING,
}),
expect.objectContaining({
ws: websocket,
messageId: index.connections[connectionId].messageId,
}),
connectionId,
);
});

// Nested parameterized test of invalid subscribe and unsubscribe message handling.
for (const type of [IncomingMessageType.SUBSCRIBE, IncomingMessageType.UNSUBSCRIBE]) {
it.each([
Expand Down Expand Up @@ -254,12 +234,14 @@ describe('Index', () => {

describe('close', () => {
it('disconnects connection on close', () => {
jest.spyOn(websocket, 'removeAllListeners').mockImplementation(jest.fn());
jest.spyOn(websocket, 'terminate').mockImplementation(jest.fn());
websocket.emit(WebsocketEvents.CLOSE);
// Run timers for heartbeat.
jest.runAllTimers();

expect(wsPingSpy).not.toHaveBeenCalled();
expect(websocket.removeAllListeners).toHaveBeenCalledTimes(1);
expect(websocket.terminate).toHaveBeenCalledTimes(1);
expect(mockSub.remove).toHaveBeenCalledWith(connectionId);
expect(index.connections[connectionId]).toBeUndefined();
Expand All @@ -277,6 +259,15 @@ describe('Index', () => {
});
});

describe('ping', () => {
it('sends pong on receiving ping', () => {
(v4 as unknown as jest.Mock).mockReturnValueOnce(connectionId);
mockConnect(websocket, new IncomingMessage(new Socket()));
websocket.emit(WebsocketEvents.PING);
expect(wsPongSpy).toHaveBeenCalledTimes(2);
});
});

describe('pong', () => {
it('removes delayed disconnect on pong', () => {
// Run pending timers to start heartbeat to attach delayed disconnect.
Expand Down
3 changes: 2 additions & 1 deletion indexer/services/socks/src/lib/invalid-message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { logger } from '@dydxprotocol-indexer/base';
import config from '../config';
import { createErrorMessage } from '../helpers/message';
import { sendMessage } from '../helpers/wss';
import { Connection } from '../types';
import { Connection, WebsocketEvents } from '../types';
import { WS_CLOSE_CODE_POLICY_VIOLATION } from './constants';
import { RateLimiter } from './rate-limit';

Expand Down Expand Up @@ -42,6 +42,7 @@ export class InvalidMessageHandler {
WS_CLOSE_CODE_POLICY_VIOLATION,
JSON.stringify({ message: 'Rate limited' }),
);
connection.ws.removeAllListeners(WebsocketEvents.MESSAGE);

logger.info({
at: 'invalid-message#handleInvalidMessage',
Expand Down
73 changes: 0 additions & 73 deletions indexer/services/socks/src/lib/ping.ts

This file was deleted.

1 change: 1 addition & 0 deletions indexer/services/socks/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -149,4 +149,5 @@ export enum WebsocketEvents {
LISTENING = 'listening',
MESSAGE = 'message',
PONG = 'pong',
PING = 'ping',
}
52 changes: 23 additions & 29 deletions indexer/services/socks/src/websocket/index.ts
Original file line number Diff line number Diff line change
@@ -1,30 +1,24 @@
import {
stats, logger, safeJsonStringify, InfoObject,
InfoObject, logger, safeJsonStringify, stats,
} from '@dydxprotocol-indexer/base';
import { v4 as uuidv4 } from 'uuid';
import WebSocket from 'ws';

import config from '../config';
import { getCountry } from '../helpers/header-utils';
import {
createErrorMessage,
createConnectedMessage,
createUnsubscribedMessage,
} from '../helpers/message';
import { Wss, sendMessage } from '../helpers/wss';
import { createConnectedMessage, createErrorMessage, createUnsubscribedMessage } from '../helpers/message';
import { sendMessage, Wss } from '../helpers/wss';
import { ERR_INVALID_WEBSOCKET_FRAME, WS_CLOSE_CODE_SERVICE_RESTART } from '../lib/constants';
import { InvalidMessageHandler } from '../lib/invalid-message';
import { PingHandler } from '../lib/ping';
import { Subscriptions } from '../lib/subscription';
import {
IncomingMessageType,
ALL_CHANNELS,
Channel,
Connection,
IncomingMessage,
IncomingMessageType,
SubscribeMessage,
UnsubscribeMessage,
Connection,
PingMessage,
ALL_CHANNELS,
WebsocketEvents,
} from '../types';

Expand All @@ -40,14 +34,12 @@ export class Index {
// Subscriptions tracking object (see lib/subscriptions.ts).
private subscriptions: Subscriptions;
// Handlers for pings and invalid messages.
private pingHandler: PingHandler;
private invalidMessageHandler: InvalidMessageHandler;

constructor(wss: Wss, subscriptions: Subscriptions) {
this.wss = wss;
this.connections = {};
this.subscriptions = subscriptions;
this.pingHandler = new PingHandler();
this.invalidMessageHandler = new InvalidMessageHandler();

// Attach the new connection handler to the websocket server.
Expand Down Expand Up @@ -163,21 +155,19 @@ export class Index {
HEARTBEAT_INTERVAL_MS,
);

// Attach handler for pongs (response to heartbeat pings) from connection.
// Attach handler for pongs (response to heartbeat [ping]s) from connection.
this.connections[connectionId].ws.on(WebsocketEvents.PONG, () => {
logger.info({
at: 'index#onPong',
message: 'Received pong',
connectionId,
});

// Clear the delayed disconnect set by the heartbeat handler when a pong is received.
if (this.connections[connectionId].disconnect) {
clearTimeout(this.connections[connectionId].disconnect);
delete this.connections[connectionId].disconnect;
}
});

this.connections[connectionId].ws.on(WebsocketEvents.PING, (data: Buffer) => {
ws.pong(data);
});

// Attach handler for close events from the connection.
this.connections[connectionId].ws.on(WebsocketEvents.CLOSE, (code: number, reason: Buffer) => {
logger.info({
Expand Down Expand Up @@ -224,7 +214,11 @@ export class Index {
* @returns
*/
private onMessage(connectionId: string, message: WebSocket.Data): void {
stats.increment(`${config.SERVICE_NAME}.on_message`, 1);
stats.increment(
`${config.SERVICE_NAME}.on_message`,
1,
config.MESSAGE_FORWARDER_STATSD_SAMPLE_RATE,
);
if (!this.connections[connectionId]) {
logger.info({
at: 'index#onMessage',
Expand Down Expand Up @@ -315,12 +309,8 @@ export class Index {
);
break;
}
// TODO: Consider custom ping messages as invalid after publishing updated documentation.
case IncomingMessageType.PING: {
this.pingHandler.handlePing(
parsed as PingMessage,
this.connections[connectionId],
connectionId,
);
break;
}
default: {
Expand All @@ -332,7 +322,11 @@ export class Index {
return;
}
}
stats.increment(`${config.SERVICE_NAME}.message_received_${parsed.type}`, 1);
stats.increment(
`${config.SERVICE_NAME}.message_received_${parsed.type}`,
1,
config.MESSAGE_FORWARDER_STATSD_SAMPLE_RATE,
);
}

/**
Expand Down Expand Up @@ -392,11 +386,11 @@ export class Index {
if (this.connections[connectionId].heartbeat) {
clearInterval(this.connections[connectionId].heartbeat);
}
this.connections[connectionId].ws.removeAllListeners();
this.connections[connectionId].ws.terminate();

// Delete subscription data.
this.subscriptions.remove(connectionId);
this.pingHandler.handleDisconnect(connectionId);
this.invalidMessageHandler.handleDisconnect(connectionId);
delete this.connections[connectionId];
} catch (error) {
Expand Down
Loading