diff --git a/__tests__/fixtures/cleanup.ts b/__tests__/fixtures/cleanup.ts index 6662d6fc..347242f3 100644 --- a/__tests__/fixtures/cleanup.ts +++ b/__tests__/fixtures/cleanup.ts @@ -17,9 +17,9 @@ export async function ensureTransportIsClean(t: Transport) { `transport ${t.clientId} should not have open connections after the test`, ).toStrictEqual(new Map()); expect( - t.messageHandlers, + t.eventDispatcher.numberOfListeners('message'), `transport ${t.clientId} should not have open message handlers after the test`, - ).toStrictEqual(new Set()); + ).equal(0); } export async function waitUntil( diff --git a/__tests__/invariants.test.ts b/__tests__/invariants.test.ts index 1b1745d7..7779990b 100644 --- a/__tests__/invariants.test.ts +++ b/__tests__/invariants.test.ts @@ -86,16 +86,22 @@ describe('procedures should leave no trace after finishing', async () => { const server = await createServer(serverTransport, serviceDefs); const client = createClient(clientTransport); - let serverListeners = serverTransport.messageHandlers.size; - let clientListeners = clientTransport.messageHandlers.size; + let serverListeners = + serverTransport.eventDispatcher.numberOfListeners('message'); + let clientListeners = + clientTransport.eventDispatcher.numberOfListeners('message'); // start procedure await client.test.add.rpc({ n: 3 }); // end procedure // number of message handlers shouldn't increase after rpc - expect(serverTransport.messageHandlers.size).toEqual(serverListeners); - expect(clientTransport.messageHandlers.size).toEqual(clientListeners); + expect( + serverTransport.eventDispatcher.numberOfListeners('message'), + ).toEqual(serverListeners); + expect( + clientTransport.eventDispatcher.numberOfListeners('message'), + ).toEqual(clientListeners); // check number of connections expect(serverTransport.connections.size).toEqual(1); @@ -113,8 +119,10 @@ describe('procedures should leave no trace after finishing', async () => { const server = await createServer(serverTransport, serviceDefs); const client = createClient(clientTransport); - let serverListeners = serverTransport.messageHandlers.size; - let clientListeners = clientTransport.messageHandlers.size; + let serverListeners = + serverTransport.eventDispatcher.numberOfListeners('message'); + let clientListeners = + clientTransport.eventDispatcher.numberOfListeners('message'); // start procedure const [input, output, close] = await client.test.echo.stream(); @@ -138,8 +146,12 @@ describe('procedures should leave no trace after finishing', async () => { // end procedure // number of message handlers shouldn't increase after stream ends - expect(serverTransport.messageHandlers.size).toEqual(serverListeners); - expect(clientTransport.messageHandlers.size).toEqual(clientListeners); + expect( + serverTransport.eventDispatcher.numberOfListeners('message'), + ).toEqual(serverListeners); + expect( + clientTransport.eventDispatcher.numberOfListeners('message'), + ).toEqual(clientListeners); // check number of connections expect(serverTransport.connections.size).toEqual(1); @@ -157,8 +169,10 @@ describe('procedures should leave no trace after finishing', async () => { const server = await createServer(serverTransport, serviceDefs); const client = createClient(clientTransport); - let serverListeners = serverTransport.messageHandlers.size; - let clientListeners = clientTransport.messageHandlers.size; + let serverListeners = + serverTransport.eventDispatcher.numberOfListeners('message'); + let clientListeners = + clientTransport.eventDispatcher.numberOfListeners('message'); // start procedure const [subscription, close] = await client.test.value.subscribe({}); @@ -177,8 +191,12 @@ describe('procedures should leave no trace after finishing', async () => { // end procedure // number of message handlers shouldn't increase after stream ends - expect(serverTransport.messageHandlers.size).toEqual(serverListeners); - expect(clientTransport.messageHandlers.size).toEqual(clientListeners); + expect( + serverTransport.eventDispatcher.numberOfListeners('message'), + ).toEqual(serverListeners); + expect( + clientTransport.eventDispatcher.numberOfListeners('message'), + ).toEqual(clientListeners); // check number of connections expect(serverTransport.connections.size).toEqual(1); diff --git a/router/client.ts b/router/client.ts index 4d03b52c..599c3bca 100644 --- a/router/client.ts +++ b/router/client.ts @@ -187,7 +187,7 @@ export const createClient = >>( } }; - transport.addMessageListener(listener); + transport.addEventListener('message', listener); const closeHandler = () => { inputStream.end(); outputStream.end(); @@ -200,7 +200,7 @@ export const createClient = >>( streamId, ), ); - transport.removeMessageListener(listener); + transport.removeEventListener('message', listener); }; return [inputStream, outputStream, closeHandler]; @@ -243,7 +243,7 @@ export const createClient = >>( } }; - transport.addMessageListener(listener); + transport.addEventListener('message', listener); const closeHandler = () => { outputStream.end(); transport.send( @@ -255,7 +255,7 @@ export const createClient = >>( streamId, ), ); - transport.removeMessageListener(listener); + transport.removeEventListener('message', listener); }; return [outputStream, closeHandler]; diff --git a/router/server.ts b/router/server.ts index daa3aedb..5a85bf4c 100644 --- a/router/server.ts +++ b/router/server.ts @@ -244,12 +244,12 @@ export async function createServer>( } }; - transport.addMessageListener(handler); + transport.addEventListener('message', handler); return { services, streams: streamMap, async close() { - transport.removeMessageListener(handler); + transport.removeEventListener('message', handler); for (const streamIdx of streamMap.keys()) { await cleanupStream(streamIdx); } diff --git a/transport/events.ts b/transport/events.ts new file mode 100644 index 00000000..b9f12555 --- /dev/null +++ b/transport/events.ts @@ -0,0 +1,45 @@ +import { OpaqueTransportMessage } from './message'; +import { Connection } from './transport'; + +export interface EventMap { + message: OpaqueTransportMessage; + connectionStatus: { + status: 'connect' | 'disconnect'; + conn: Connection; + }; +} + +export type EventTypes = keyof EventMap; +export type EventHandler = (event: EventMap[K]) => void; + +export class EventDispatcher { + private eventListeners: { [K in T]?: Set> } = {}; + + numberOfListeners(eventType: K) { + return this.eventListeners[eventType]?.size ?? 0; + } + + addEventListener(eventType: K, handler: EventHandler) { + if (!this.eventListeners[eventType]) { + this.eventListeners[eventType] = new Set(); + } + + this.eventListeners[eventType]?.add(handler); + } + + removeEventListener(eventType: K, handler: EventHandler) { + const handlers = this.eventListeners[eventType]; + if (handlers) { + this.eventListeners[eventType]?.delete(handler); + } + } + + dispatchEvent(eventType: K, event: EventMap[K]) { + const handlers = this.eventListeners[eventType]; + if (handlers) { + for (const handler of handlers) { + handler(event); + } + } + } +} diff --git a/transport/index.ts b/transport/index.ts index 0a83fcda..76a91a3b 100644 --- a/transport/index.ts +++ b/transport/index.ts @@ -33,12 +33,12 @@ export async function waitForMessage( function onMessage(msg: OpaqueTransportMessage) { if (!filter || filter?.(msg)) { resolve(msg.payload); - t.removeMessageListener(onMessage); + t.removeEventListener('message', onMessage); } else if (rejectMismatch) { reject(new Error('message didnt match the filter')); } } - t.addMessageListener(onMessage); + t.addEventListener('message', onMessage); }); } diff --git a/transport/transport.ts b/transport/transport.ts index c2142840..07f5fc0d 100644 --- a/transport/transport.ts +++ b/transport/transport.ts @@ -11,6 +11,7 @@ import { reply, } from './message'; import { log } from '../logging'; +import { EventDispatcher, EventHandler, EventTypes } from './events'; /** * A 1:1 connection between two transports. Once this is created, @@ -88,11 +89,6 @@ export abstract class Transport { */ clientId: TransportClientId; - /** - * The set of message handlers registered with this transport. - */ - messageHandlers: Set<(msg: OpaqueTransportMessage) => void>; - /** * An array of message IDs that are waiting to be sent over the WebSocket connection. * This builds up if the WebSocket is down for a period of time. @@ -109,13 +105,18 @@ export abstract class Transport { */ connections: Map; + /** + * The event dispatcher for handling events of type EventTypes. + */ + eventDispatcher: EventDispatcher; + /** * Creates a new Transport instance. * @param codec The codec used to encode and decode messages. * @param clientId The client ID of this transport. */ constructor(codec: Codec, clientId: TransportClientId) { - this.messageHandlers = new Set(); + this.eventDispatcher = new EventDispatcher(); this.sendBuffer = new Map(); this.sendQueue = new Map(); this.connections = new Map(); @@ -146,6 +147,11 @@ export abstract class Transport { log?.info(`${this.clientId} -- new connection to ${conn.connectedTo}`); this.connections.set(conn.connectedTo, conn); + this.eventDispatcher.dispatchEvent('connectionStatus', { + status: 'connect', + conn, + }); + // send outstanding const outstanding = this.sendQueue.get(conn.connectedTo); if (!outstanding) { @@ -175,6 +181,10 @@ export abstract class Transport { log?.info(`${this.clientId} -- disconnect from ${conn.connectedTo}`); conn.close(); this.connections.delete(conn.connectedTo); + this.eventDispatcher.dispatchEvent('connectionStatus', { + status: 'disconnect', + conn, + }); } /** @@ -232,9 +242,7 @@ export abstract class Transport { return; } - for (const handler of this.messageHandlers) { - handler(msg); - } + this.eventDispatcher.dispatchEvent('message', msg); if (!isAck(msg.controlFlags)) { const ackMsg = reply(msg, { ack: msg.id }); @@ -247,19 +255,27 @@ export abstract class Transport { } /** - * Adds a message listener to this transport. + * Adds a listener to this transport. + * @param the type of event to listen for * @param handler The message handler to add. */ - addMessageListener(handler: (msg: OpaqueTransportMessage) => void): void { - this.messageHandlers.add(handler); + addEventListener>( + type: K, + handler: T, + ): void { + this.eventDispatcher.addEventListener(type, handler); } /** - * Removes a message listener from this transport. + * Removes a listener from this transport. + * @param the type of event to unlisten on * @param handler The message handler to remove. */ - removeMessageListener(handler: (msg: OpaqueTransportMessage) => void): void { - this.messageHandlers.delete(handler); + removeEventListener>( + type: K, + handler: T, + ): void { + this.eventDispatcher.removeEventListener(type, handler); } /** diff --git a/tsconfig.json b/tsconfig.json index 6c89d5ff..e8da252e 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -27,8 +27,8 @@ "noImplicitAny": true /* Raise error on expressions and declarations with an implied 'any' type. */, "strictNullChecks": true /* Enable strict null checks. */, "strictFunctionTypes": true /* Enable strict checking of function types. */, - // "strictBindCallApply": true, /* Enable strict 'bind', 'call', and 'apply' methods on functions. */ - // "strictPropertyInitialization": true, /* Enable strict checking of property initialization in classes. */ + "strictBindCallApply": true /* Enable strict 'bind', 'call', and 'apply' methods on functions. */, + "strictPropertyInitialization": true /* Enable strict checking of property initialization in classes. */, "noImplicitThis": true /* Raise error on 'this' expressions with an implied 'any' type. */, "alwaysStrict": true /* Parse in strict mode and emit "use strict" for each source file. */, /* Additional Checks */