From d71f9da2f39b2fc3b15cdceccfaab81fe734e67c Mon Sep 17 00:00:00 2001 From: lhchavez Date: Wed, 13 Dec 2023 11:19:46 -0800 Subject: [PATCH] fix: Don't require _every_ message to contain procedure information (#32) Previously we had the requirement that every single message in a stream had to have the service name + procedure. But that's very wasteful and also kind of awkward to be saving in other implementations. This change does a little refactor so that the server no longer relies on having this information in every message in a stream. In doing this refactor, we also now can reason about what the first message is, and can now correctly validate that the initialization message is received first and then all future messages only comply with the regular input type. Also fixed a bug where an overly zealous listener would incorrectly close all streams upon the first close message. --- package.json | 2 +- router/client.ts | 72 +++------ router/server.ts | 290 ++++++++++++++++++---------------- transport/impls/ws/ws.test.ts | 13 +- transport/message.test.ts | 12 +- transport/message.ts | 37 +++-- transport/transport.ts | 15 +- util/testHelpers.ts | 2 +- 8 files changed, 232 insertions(+), 211 deletions(-) diff --git a/package.json b/package.json index 4feefc8f..2119b09e 100644 --- a/package.json +++ b/package.json @@ -2,7 +2,7 @@ "name": "@replit/river", "sideEffects": false, "description": "It's like tRPC but... with JSON Schema Support, duplex streaming and support for service multiplexing. Transport agnostic!", - "version": "0.9.1", + "version": "0.9.2", "type": "module", "exports": { ".": "./dist/router/index.js", diff --git a/router/client.ts b/router/client.ts index 2e955f81..7e46f974 100644 --- a/router/client.ts +++ b/router/client.ts @@ -189,11 +189,7 @@ export const createClient = >>( const streamId = nanoid(); function belongsToSameStream(msg: OpaqueTransportMessage) { - return ( - msg.serviceName === serviceName && - msg.procedureName === procName && - msg.streamId === streamId - ); + return msg.streamId === streamId; } if (procType === 'stream') { @@ -205,10 +201,10 @@ export const createClient = >>( const m = msg( transport.clientId, serverId, - serviceName, - procName, streamId, input as object, + serviceName, + procName, ); // first message needs the open bit. @@ -224,13 +220,13 @@ export const createClient = >>( const m = msg( transport.clientId, serverId, - serviceName, - procName, streamId, rawIn as object, ); if (firstMessage) { + m.serviceName = serviceName; + m.procedureName = procName; m.controlFlags |= ControlFlags.StreamOpenBit; firstMessage = false; } @@ -241,9 +237,13 @@ export const createClient = >>( // transport -> output const listener = (msg: OpaqueTransportMessage) => { + if (!belongsToSameStream(msg)) { + return; + } + if (isStreamClose(msg.controlFlags)) { outputStream.end(); - } else if (belongsToSameStream(msg)) { + } else { outputStream.push(msg.payload); } }; @@ -252,15 +252,7 @@ export const createClient = >>( const closeHandler = () => { inputStream.end(); outputStream.end(); - transport.send( - closeStream( - transport.clientId, - serverId, - serviceName, - procName, - streamId, - ), - ); + transport.send(closeStream(transport.clientId, serverId, streamId)); transport.removeEventListener('message', listener); }; @@ -269,10 +261,10 @@ export const createClient = >>( const m = msg( transport.clientId, serverId, - serviceName, - procName, streamId, input as object, + serviceName, + procName, ); // rpc is a stream open + close @@ -284,10 +276,10 @@ export const createClient = >>( const m = msg( transport.clientId, serverId, - serviceName, - procName, streamId, input as object, + serviceName, + procName, ); m.controlFlags |= ControlFlags.StreamOpenBit; transport.send(m); @@ -295,27 +287,21 @@ export const createClient = >>( // transport -> output const outputStream = pushable({ objectMode: true }); const listener = (msg: OpaqueTransportMessage) => { - if (belongsToSameStream(msg)) { - outputStream.push(msg.payload); + if (!belongsToSameStream(msg)) { + return; } if (isStreamClose(msg.controlFlags)) { outputStream.end(); + } else { + outputStream.push(msg.payload); } }; transport.addEventListener('message', listener); const closeHandler = () => { outputStream.end(); - transport.send( - closeStream( - transport.clientId, - serverId, - serviceName, - procName, - streamId, - ), - ); + transport.send(closeStream(transport.clientId, serverId, streamId)); transport.removeEventListener('message', listener); }; @@ -328,10 +314,10 @@ export const createClient = >>( const m = msg( transport.clientId, serverId, - serviceName, - procName, streamId, input as object, + serviceName, + procName, ); // first message needs the open bit. @@ -347,29 +333,21 @@ export const createClient = >>( const m = msg( transport.clientId, serverId, - serviceName, - procName, streamId, rawIn as object, ); if (firstMessage) { m.controlFlags |= ControlFlags.StreamOpenBit; + m.serviceName = serviceName; + m.procedureName = procName; firstMessage = false; } transport.send(m); } - transport.send( - closeStream( - transport.clientId, - serverId, - serviceName, - procName, - streamId, - ), - ); + transport.send(closeStream(transport.clientId, serverId, streamId)); })(); return [inputStream, waitForMessage(transport, belongsToSameStream)]; diff --git a/router/server.ts b/router/server.ts index 33dba4fb..ea445316 100644 --- a/router/server.ts +++ b/router/server.ts @@ -34,6 +34,10 @@ export interface Server { } interface ProcStream { + id: string; + serviceName: string; + procedureName: string; + procedure: AnyProcedure; incoming: Pushable; outgoing: Pushable< TransportMessage, Static>> @@ -66,13 +70,15 @@ export async function createServer>( async function cleanupStream(id: string) { const stream = streamMap.get(id); - if (stream) { - stream.incoming.end(); - await stream.promises.inputHandler; - stream.outgoing.end(); - await stream.promises.outputHandler; - streamMap.delete(id); + if (!stream) { + return; } + + stream.incoming.end(); + await stream.promises.inputHandler; + stream.outgoing.end(); + await stream.promises.outputHandler; + streamMap.delete(id); } function getContext(service: AnyService) { @@ -99,7 +105,35 @@ export async function createServer>( return; } - if (!(message.serviceName in services)) { + const streamIdx = message.streamId; + const procStream = streamMap.get(streamIdx); + if (procStream) { + // If the stream is a continuation, we do not admit the init messages. + if (Value.Check(procStream.procedure.input, message.payload)) { + procStream.incoming.push(message as TransportMessage); + } else if (!Value.Check(ControlMessagePayloadSchema, message.payload)) { + log?.error( + `${transport.clientId} -- procedure ${procStream.serviceName}.${ + procStream.procedureName + } received invalid payload: ${JSON.stringify(message.payload)}`, + ); + } + + if (isStreamClose(message.controlFlags)) { + await cleanupStream(streamIdx); + } + + return; + } + + if (!isStreamOpen(message.controlFlags)) { + log?.warn( + `${transport.clientId} -- couldn't find a matching procedure stream for ${message.serviceName}.${message.procedureName}:${message.streamId}`, + ); + return; + } + + if (!message.serviceName || !(message.serviceName in services)) { log?.warn( `${transport.clientId} -- couldn't find service ${message.serviceName}`, ); @@ -108,177 +142,161 @@ export async function createServer>( const service = services[message.serviceName]; const serviceContext = getContext(service); - if (!(message.procedureName in service.procedures)) { + if ( + !message.procedureName || + !(message.procedureName in service.procedures) + ) { log?.warn( `${transport.clientId} -- couldn't find a matching procedure for ${message.serviceName}.${message.procedureName}`, ); return; } - const procedure = service.procedures[message.procedureName] as AnyProcedure; - const streamIdx = `${message.serviceName}.${message.procedureName}:${message.streamId}`; - if (isStreamOpen(message.controlFlags) && !streamMap.has(streamIdx)) { - const incoming: ProcStream['incoming'] = pushable({ objectMode: true }); - const outgoing: ProcStream['outgoing'] = pushable({ objectMode: true }); - const outputHandler: Promise = - // sending outgoing messages back to client - (async () => { - for await (const response of outgoing) { - transport.send(response); - } + const procHasInitMessage = 'init' in procedure; + const incoming: ProcStream['incoming'] = pushable({ objectMode: true }); + const outgoing: ProcStream['outgoing'] = pushable({ objectMode: true }); + const outputHandler: Promise = + // sending outgoing messages back to client + (async () => { + for await (const response of outgoing) { + transport.send(response); + } - // we ended, send a close bit back to the client - // only subscriptions and streams have streams the - // handler can close - if ( - procedure.type === 'subscription' || - procedure.type === 'stream' - ) { - transport.send( - closeStream( - transport.clientId, - message.from, - message.serviceName, - message.procedureName, - message.streamId, - ), - ); + // we ended, send a close bit back to the client + // only subscriptions and streams have streams the + // handler can close + if (procedure.type === 'subscription' || procedure.type === 'stream') { + transport.send( + closeStream(transport.clientId, message.from, message.streamId), + ); + } + })(); + + function errorHandler(err: unknown) { + const errorMsg = + err instanceof Error ? err.message : `[coerced to error] ${err}`; + log?.error( + `${transport.clientId} -- procedure ${message.serviceName}.${message.procedureName}:${message.streamId} threw an error: ${errorMsg}`, + ); + outgoing.push( + reply( + message, + Err({ + code: UNCAUGHT_ERROR, + message: errorMsg, + } satisfies Static), + ), + ); + } + + // pump incoming message stream -> handler -> outgoing message stream + let inputHandler: Promise; + if (procedure.type === 'stream') { + if (procHasInitMessage) { + inputHandler = (async () => { + const initMessage = await incoming.next(); + if (initMessage.done) { + return; } - })(); - function errorHandler(err: unknown) { - const errorMsg = - err instanceof Error ? err.message : `[coerced to error] ${err}`; - log?.error( - `${transport.clientId} -- procedure ${message.serviceName}.${message.procedureName}:${message.streamId} threw an error: ${errorMsg}`, - ); - outgoing.push( - reply( - message, - Err({ - code: UNCAUGHT_ERROR, - message: errorMsg, - } satisfies Static), - ), - ); + return procedure + .handler(serviceContext, initMessage.value, incoming, outgoing) + .catch(errorHandler); + })(); + } else { + inputHandler = procedure + .handler(serviceContext, incoming, outgoing) + .catch(errorHandler); } + } else if (procedure.type === 'rpc') { + inputHandler = (async () => { + const inputMessage = await incoming.next(); + if (inputMessage.done) { + return; + } - // pump incoming message stream -> handler -> outgoing message stream - let inputHandler: Promise; - if (procedure.type === 'stream') { - if ('init' in procedure) { - inputHandler = (async () => { - const initMessage = await incoming.next(); - if (initMessage.done) { - return; - } + try { + const outputMessage = await procedure.handler( + serviceContext, + inputMessage.value, + ); + outgoing.push(outputMessage); + } catch (err) { + errorHandler(err); + } + })(); + } else if (procedure.type === 'subscription') { + inputHandler = (async () => { + const inputMessage = await incoming.next(); + if (inputMessage.done) { + return; + } - return procedure - .handler(serviceContext, initMessage.value, incoming, outgoing) - .catch(errorHandler); - })(); - } else { - inputHandler = procedure - .handler(serviceContext, incoming, outgoing) - .catch(errorHandler); + try { + await procedure.handler(serviceContext, inputMessage.value, outgoing); + } catch (err) { + errorHandler(err); } - } else if (procedure.type === 'rpc') { + })(); + } else if (procedure.type === 'upload') { + if (procHasInitMessage) { inputHandler = (async () => { - const inputMessage = await incoming.next(); - if (inputMessage.done) { + const initMessage = await incoming.next(); + if (initMessage.done) { return; } try { const outputMessage = await procedure.handler( serviceContext, - inputMessage.value, + initMessage.value, + incoming, ); outgoing.push(outputMessage); } catch (err) { errorHandler(err); } })(); - } else if (procedure.type === 'subscription') { + } else { inputHandler = (async () => { - const inputMessage = await incoming.next(); - if (inputMessage.done) { - return; - } - try { - await procedure.handler( + const outputMessage = await procedure.handler( serviceContext, - inputMessage.value, - outgoing, + incoming, ); + outgoing.push(outputMessage); } catch (err) { errorHandler(err); } })(); - } else if (procedure.type === 'upload') { - if ('init' in procedure) { - inputHandler = (async () => { - const initMessage = await incoming.next(); - if (initMessage.done) { - return; - } - - try { - const outputMessage = await procedure.handler( - serviceContext, - initMessage.value, - incoming, - ); - outgoing.push(outputMessage); - } catch (err) { - errorHandler(err); - } - })(); - } else { - inputHandler = (async () => { - try { - const outputMessage = await procedure.handler( - serviceContext, - incoming, - ); - outgoing.push(outputMessage); - } catch (err) { - errorHandler(err); - } - })(); - } - } else { - // procedure is inferred to be never here as this is not a valid procedure type - // we cast just to log - log?.warn( - `${transport.clientId} -- got request for invalid procedure type ${ - (procedure as AnyProcedure).type - } at ${message.serviceName}.${message.procedureName}`, - ); - return; } - - streamMap.set(streamIdx, { - incoming, - outgoing, - promises: { inputHandler, outputHandler }, - }); - } - - const procStream = streamMap.get(streamIdx); - if (!procStream) { + } else { + // procedure is inferred to be never here as this is not a valid procedure type + // we cast just to log log?.warn( - `${transport.clientId} -- couldn't find a matching procedure stream for ${message.serviceName}.${message.procedureName}:${message.streamId}`, + `${transport.clientId} -- got request for invalid procedure type ${ + (procedure as AnyProcedure).type + } at ${message.serviceName}.${message.procedureName}`, ); return; } + streamMap.set(streamIdx, { + id: message.streamId, + incoming, + outgoing, + serviceName: message.serviceName, + procedureName: message.procedureName, + procedure, + promises: { inputHandler, outputHandler }, + }); + + // This is the first message, so we parse is as the initialization message, if supplied. if ( - Value.Check(procedure.input, message.payload) || - ('init' in procedure && Value.Check(procedure.init, message.payload)) + (!procHasInitMessage && Value.Check(procedure.input, message.payload)) || + (procHasInitMessage && Value.Check(procedure.init, message.payload)) ) { - procStream.incoming.push(message as TransportMessage); + incoming.push(message as TransportMessage); } else if (!Value.Check(ControlMessagePayloadSchema, message.payload)) { log?.error( `${transport.clientId} -- procedure ${message.serviceName}.${ diff --git a/transport/impls/ws/ws.test.ts b/transport/impls/ws/ws.test.ts index 26f3024a..a12c0e0f 100644 --- a/transport/impls/ws/ws.test.ts +++ b/transport/impls/ws/ws.test.ts @@ -38,9 +38,16 @@ describe('sending and receiving across websockets works', async () => { test('sending respects to/from fields', async () => { const makeDummyMessage = (from: string, to: string, message: string) => { - return msg(from, to, 'service', 'proc', 'stream', { - msg: message, - }); + return msg( + from, + to, + 'stream', + { + msg: message, + }, + 'service', + 'proc', + ); }; const clientId1 = 'client1'; diff --git a/transport/message.test.ts b/transport/message.test.ts index c58b3152..58996812 100644 --- a/transport/message.test.ts +++ b/transport/message.test.ts @@ -10,7 +10,7 @@ import { describe, test, expect } from 'vitest'; describe('message helpers', () => { test('ack', () => { - const m = msg('a', 'b', 'svc', 'proc', 'stream', { test: 1 }); + const m = msg('a', 'b', 'stream', { test: 1 }, 'svc', 'proc'); m.controlFlags |= ControlFlags.AckBit; expect(m).toHaveProperty('controlFlags'); expect(isAck(m.controlFlags)).toBe(true); @@ -19,7 +19,7 @@ describe('message helpers', () => { }); test('streamOpen', () => { - const m = msg('a', 'b', 'svc', 'proc', 'stream', { test: 1 }); + const m = msg('a', 'b', 'stream', { test: 1 }, 'svc', 'proc'); m.controlFlags |= ControlFlags.StreamOpenBit; expect(m).toHaveProperty('controlFlags'); expect(isAck(m.controlFlags)).toBe(false); @@ -28,7 +28,7 @@ describe('message helpers', () => { }); test('streamClose', () => { - const m = msg('a', 'b', 'svc', 'proc', 'stream', { test: 1 }); + const m = msg('a', 'b', 'stream', { test: 1 }, 'svc', 'proc'); m.controlFlags |= ControlFlags.StreamClosedBit; expect(m).toHaveProperty('controlFlags'); expect(isAck(m.controlFlags)).toBe(false); @@ -37,7 +37,7 @@ describe('message helpers', () => { }); test('reply', () => { - const m = msg('a', 'b', 'svc', 'proc', 'stream', { test: 1 }); + const m = msg('a', 'b', 'stream', { test: 1 }, 'svc', 'proc'); const payload = { cool: 2 }; const resp = reply(m, payload); expect(resp.id).not.toBe(m.id); @@ -47,14 +47,14 @@ describe('message helpers', () => { }); test('default message has no control flags set', () => { - const m = msg('a', 'b', 'svc', 'proc', 'stream', { test: 1 }); + const m = msg('a', 'b', 'stream', { test: 1 }, 'svc', 'proc'); expect(isAck(m.controlFlags)).toBe(false); expect(isStreamOpen(m.controlFlags)).toBe(false); expect(isStreamClose(m.controlFlags)).toBe(false); }); test('combining control flags works', () => { - const m = msg('a', 'b', 'svc', 'proc', 'stream', { test: 1 }); + const m = msg('a', 'b', 'stream', { test: 1 }, 'svc', 'proc'); m.controlFlags |= ControlFlags.StreamOpenBit; expect(isStreamOpen(m.controlFlags)).toBe(true); expect(isStreamClose(m.controlFlags)).toBe(false); diff --git a/transport/message.ts b/transport/message.ts index cfc89207..a18c2502 100644 --- a/transport/message.ts +++ b/transport/message.ts @@ -25,8 +25,8 @@ export const TransportMessageSchema = (t: T) => id: Type.String(), from: Type.String(), to: Type.String(), - serviceName: Type.String(), - procedureName: Type.String(), + serviceName: Type.Optional(Type.Union([Type.String(), Type.Null()])), + procedureName: Type.Optional(Type.Union([Type.String(), Type.Null()])), streamId: Type.String(), controlFlags: Type.Integer(), payload: t, @@ -59,6 +59,15 @@ export const OpaqueTransportMessageSchema = TransportMessageSchema( /** * Represents a transport message. This is the same type as {@link TransportMessageSchema} but * we can't statically infer generics from generic Typebox schemas so we have to define it again here. + * + * TypeScript can't enforce types when a bitmask is involved, so these are the semantics of + * `controlFlags`: + * * If `controlFlags & StreamOpenBit == StreamOpenBit`, `streamId` must be set to a unique value + * (suggestion: use `nanoid`). + * * `serviceName` and `procedureName` must be set only when `controlFlags & StreamOpenBit == + * StreamOpenBit`. + * * If `controlFlags & StreamClosedBit` is set and the kind is `stream` or `subscription`, + * `payload` can be a control message. * @template Payload The type of the payload. */ export type TransportMessage< @@ -67,8 +76,8 @@ export type TransportMessage< id: string; from: string; to: string; - serviceName: string; - procedureName: string; + serviceName?: string; + procedureName?: string; streamId: string; controlFlags: number; payload: Payload; @@ -97,18 +106,18 @@ export type TransportClientId = string; export function msg( from: string, to: string, - service: string, - proc: string, - stream: string, + streamId: string, payload: Payload, + serviceName?: string, + procedureName?: string, ): TransportMessage { return { id: nanoid(), to, from, - serviceName: service, - procedureName: proc, - streamId: stream, + serviceName, + procedureName, + streamId, controlFlags: 0, payload, }; @@ -125,9 +134,9 @@ export function reply( response: Payload, ): TransportMessage { return { - ...msg, - controlFlags: 0, id: nanoid(), + streamId: msg.streamId, + controlFlags: 0, to: msg.from, from: msg.to, payload: response, @@ -144,11 +153,9 @@ export function reply( export function closeStream( from: TransportClientId, to: TransportClientId, - service: string, - proc: string, stream: string, ) { - const closeMessage = msg(from, to, service, proc, stream, { + const closeMessage = msg(from, to, stream, { type: 'CLOSE' as const, } satisfies Static); closeMessage.controlFlags |= ControlFlags.StreamClosedBit; diff --git a/transport/transport.ts b/transport/transport.ts index 07f5fc0d..f943dc01 100644 --- a/transport/transport.ts +++ b/transport/transport.ts @@ -210,10 +210,21 @@ export abstract class Transport { } if (Value.Check(OpaqueTransportMessageSchema, parsedMsg)) { - return parsedMsg; + // JSON can't express the difference between `undefined` and `null`, so we need to patch that. + return { + ...parsedMsg, + serviceName: + parsedMsg.serviceName === null ? undefined : parsedMsg.serviceName, + procedureName: + parsedMsg.procedureName === null + ? undefined + : parsedMsg.procedureName, + }; } else { log?.warn( - `${this.clientId} -- received invalid msg: ${JSON.stringify(msg)}`, + `${this.clientId} -- received invalid msg: ${JSON.stringify( + parsedMsg, + )}`, ); return null; } diff --git a/util/testHelpers.ts b/util/testHelpers.ts index 754ebe30..8ccf0f6e 100644 --- a/util/testHelpers.ts +++ b/util/testHelpers.ts @@ -483,7 +483,7 @@ export function payloadToTransportMessage( from: TransportClientId = 'client', to: TransportClientId = 'SERVER', ): TransportMessage { - return msg(from, to, 'service', 'procedure', streamId ?? 'stream', payload); + return msg(from, to, streamId ?? 'stream', payload, 'service', 'procedure'); } /**