diff --git a/__tests__/typescript-stress.test.ts b/__tests__/typescript-stress.test.ts index 7ae62469..27eb8f23 100644 --- a/__tests__/typescript-stress.test.ts +++ b/__tests__/typescript-stress.test.ts @@ -77,7 +77,7 @@ export class MockTransport extends Transport { super(NaiveJsonCodec, clientId); } - send(msg: OpaqueTransportMessage): MessageId { + async send(msg: OpaqueTransportMessage): Promise { const id = msg.id; return id; } diff --git a/index.ts b/index.ts index 8f5563db..ec49013b 100644 --- a/index.ts +++ b/index.ts @@ -43,6 +43,5 @@ export { onServerReady, createWsTransports, waitForMessage, - waitForSocketReady, - createWebSocketClient, + createLocalWebSocketClient, } from './transport/util'; diff --git a/package-lock.json b/package-lock.json index 5e8feec1..5c49cff6 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@replit/river", - "version": "0.1.10", + "version": "0.2.0", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "@replit/river", - "version": "0.1.10", + "version": "0.2.0", "license": "MIT", "dependencies": { "@sinclair/typebox": "^0.31.8", diff --git a/package.json b/package.json index 99fa708b..c9481009 100644 --- a/package.json +++ b/package.json @@ -2,7 +2,7 @@ "name": "@replit/river", "sideEffects": false, "description": "It's like tRPC but... with JSON Schema Support, duplex streaming and support for service multiplexing. Transport agnostic!", - "version": "0.1.10", + "version": "0.2.0", "type": "module", "main": "index.js", "types": "index.d.ts", diff --git a/router/client.ts b/router/client.ts index 3dccce49..57fb9282 100644 --- a/router/client.ts +++ b/router/client.ts @@ -74,7 +74,7 @@ export const createClient = >>( // this gets cleaned up on i.end() which is called by closeHandler (async () => { for await (const rawIn of i) { - transport.send( + await transport.send( msg( transport.clientId, 'SERVER', @@ -103,7 +103,7 @@ export const createClient = >>( return [i, o, closeHandler]; } else { // rpc case - const id = transport.send( + const id = await transport.send( msg( transport.clientId, 'SERVER', diff --git a/router/server.ts b/router/server.ts index 85512649..7033b97f 100644 --- a/router/server.ts +++ b/router/server.ts @@ -63,7 +63,7 @@ export async function createServer>( // sending outgoing messages back to client (async () => { for await (const response of outgoing) { - transport.send(response); + await transport.send(response); } })(), ]), @@ -101,7 +101,7 @@ export async function createServer>( getContext(service), inputMessage, ); - transport.send(response); + await transport.send(response); return; } else if ( procedure.type === 'stream' && diff --git a/transport/message.ts b/transport/message.ts index 59deea40..80bf324f 100644 --- a/transport/message.ts +++ b/transport/message.ts @@ -32,6 +32,7 @@ export type MessageId = string; export type OpaqueTransportMessage = TransportMessage; export type TransportClientId = 'SERVER' | string; export const TransportAckSchema = Type.Object({ + id: Type.String(), from: Type.String(), ack: Type.String(), }); @@ -63,6 +64,7 @@ export function payloadToTransportMessage( export function ack(msg: OpaqueTransportMessage): TransportMessageAck { return { + id: nanoid(), from: msg.to, ack: msg.id, }; diff --git a/transport/stream.test.ts b/transport/stream.test.ts index 74219f3b..563d629b 100644 --- a/transport/stream.test.ts +++ b/transport/stream.test.ts @@ -24,7 +24,7 @@ describe('sending and receiving across node streams works', () => { }; const p = waitForMessage(serverTransport); - clientTransport.send({ + await clientTransport.send({ id: '1', from: 'client', to: 'SERVER', diff --git a/transport/stream.ts b/transport/stream.ts index de7ee142..a30f0cb8 100644 --- a/transport/stream.ts +++ b/transport/stream.ts @@ -22,7 +22,7 @@ export class StreamTransport extends Transport { rl.on('line', (msg) => this.onMessage(msg)); } - send(msg: OpaqueTransportMessage): string { + async send(msg: OpaqueTransportMessage): Promise { const id = msg.id; this.output.write(this.codec.toStringBuf(msg) + '\n'); return id; diff --git a/transport/types.ts b/transport/types.ts index 3fd36244..d8ecedea 100644 --- a/transport/types.ts +++ b/transport/types.ts @@ -53,6 +53,8 @@ export abstract class Transport { this.handlers.delete(handler); } - abstract send(msg: OpaqueTransportMessage | TransportMessageAck): MessageId; + abstract send( + msg: OpaqueTransportMessage | TransportMessageAck, + ): Promise; abstract close(): Promise; } diff --git a/transport/util.ts b/transport/util.ts index be3a0734..8ea38a7a 100644 --- a/transport/util.ts +++ b/transport/util.ts @@ -18,33 +18,25 @@ export async function onServerReady( }); } +export async function createLocalWebSocketClient(port: number) { + return new WebSocket(`ws://localhost:${port}`); +} + export async function createWsTransports( port: number, wss: WebSocketServer, ): Promise<[Transport, Transport]> { return new Promise((resolve) => { - const clientSockPromise = createWebSocketClient(port); + const clientSockPromise = createLocalWebSocketClient(port); wss.on('connection', async (serverSock) => { resolve([ - new WebSocketTransport(await clientSockPromise, 'client'), - new WebSocketTransport(serverSock, 'SERVER'), + new WebSocketTransport(() => clientSockPromise, 'client'), + new WebSocketTransport(() => Promise.resolve(serverSock), 'SERVER'), ]); }); }); } -export async function waitForSocketReady(socket: WebSocket) { - return new Promise((resolve) => { - socket.addEventListener('open', () => resolve()); - }); -} - -export async function createWebSocketClient(port: number) { - const client = new WebSocket(`ws://localhost:${port}`); - await waitForSocketReady(client); - return client; -} - export async function waitForMessage( t: Transport, filter?: (msg: OpaqueTransportMessage) => boolean, diff --git a/transport/ws.test.ts b/transport/ws.test.ts index 1418f036..67de7a4a 100644 --- a/transport/ws.test.ts +++ b/transport/ws.test.ts @@ -3,7 +3,7 @@ import { WebSocketServer } from 'ws'; import { WebSocketTransport } from './ws'; import { describe, test, expect, beforeAll, afterAll } from 'vitest'; import { - createWebSocketClient, + createLocalWebSocketClient, createWebSocketServer, onServerReady, waitForMessage, @@ -28,18 +28,22 @@ describe('sending and receiving across websockets works', () => { test('basic send/receive', async () => { let serverTransport: WebSocketTransport | undefined; wss.on('connection', (conn) => { - serverTransport = new WebSocketTransport(conn, 'SERVER'); + serverTransport = new WebSocketTransport( + () => Promise.resolve(conn), + 'SERVER', + ); }); - - const clientSoc = await createWebSocketClient(port); - const clientTransport = new WebSocketTransport(clientSoc, 'client'); + const clientTransport = new WebSocketTransport( + () => createLocalWebSocketClient(port), + 'client', + ); const msg = { msg: 'cool', test: 123, }; - clientTransport.send({ + await clientTransport.send({ id: '1', from: 'client', to: 'SERVER', diff --git a/transport/ws.ts b/transport/ws.ts index a6433311..04a58a4b 100644 --- a/transport/ws.ts +++ b/transport/ws.ts @@ -1,4 +1,4 @@ -import type WebSocket from 'isomorphic-ws'; +import WebSocket from 'isomorphic-ws'; import { Transport } from './types'; import { NaiveJsonCodec } from '../codec/json'; import { @@ -13,21 +13,60 @@ import { // - how do we handle forceful client disconnects? (i.e. broken connection, offline) // - how do we handle forceful service disconnects (i.e. a crash)? export class WebSocketTransport extends Transport { - ws: WebSocket; + wsGetter: () => Promise; + ws?: WebSocket; + destroyed: boolean; - constructor(ws: WebSocket, clientId: TransportClientId) { + constructor(wsGetter: () => Promise, clientId: TransportClientId) { super(NaiveJsonCodec, clientId); - this.ws = ws; - ws.onmessage = (msg) => this.onMessage(msg.data.toString()); + this.destroyed = false; + this.wsGetter = wsGetter; + this.waitForSocketReady(); } - send(msg: OpaqueTransportMessage): MessageId { + // postcondition: ws is concretely a WebSocket + private async waitForSocketReady(): Promise { + return new Promise((resolve, reject) => { + if (this.destroyed) { + reject(new Error('ws is destroyed')); + return; + } + + if (this.ws) { + // constructed ws but not open + if (this.ws.readyState === this.ws.OPEN) { + return resolve(this.ws); + } + + // resolve on open + this.ws.onopen = (evt) => { + return resolve(evt.target); + }; + + // reject if borked + this.ws.onerror = (err) => reject(err); + } else { + // not constructed + this.wsGetter().then((ws) => { + this.ws = ws; + return resolve(this.waitForSocketReady()); + }); + } + }).then((ws) => { + ws.onmessage = (msg) => this.onMessage(msg.data.toString()); + return ws; + }); + } + + async send(msg: OpaqueTransportMessage): Promise { const id = msg.id; - this.ws.send(this.codec.toStringBuf(msg)); + const ws = await this.waitForSocketReady(); + ws.send(this.codec.toStringBuf(msg)); return id; } async close() { - return this.ws.close(); + this.destroyed = true; + return this.ws?.close(); } }