From f20170dc39e3e2664625ab4b8a945971e1f1773b Mon Sep 17 00:00:00 2001 From: Faris Masad Date: Fri, 31 May 2024 13:13:16 -0700 Subject: [PATCH] [protocolv2] Init always exists (#159) * I -> Input, O -> Output, E -> Err * Update types * Update implementation * Update tests * Misc changes * fix checking init instead of input --- PROTOCOL.md | 6 +- __tests__/bandwidth.bench.ts | 2 +- __tests__/cleanup.test.ts | 9 +- __tests__/disconnects.test.ts | 4 +- __tests__/e2e.test.ts | 13 +- __tests__/fixtures/services.ts | 53 +++- __tests__/handler.test.ts | 15 +- __tests__/serialize.test.ts | 100 ++++++- __tests__/typescript-stress.test.ts | 22 +- router/client.ts | 147 +++------ router/procedures.ts | 450 +++++++++++----------------- router/result.ts | 22 +- router/server.ts | 306 ++++++++----------- router/services.ts | 54 ++-- util/testHelpers.ts | 135 ++++----- 15 files changed, 613 insertions(+), 725 deletions(-) diff --git a/PROTOCOL.md b/PROTOCOL.md index 55113322..d266cd31 100644 --- a/PROTOCOL.md +++ b/PROTOCOL.md @@ -95,9 +95,9 @@ interface BaseError { The `Result` type MUST conform to: ```ts -type Result = - | { ok: true; payload: T } - | { ok: false; payload: E }; +type Result = + | { ok: true; payload: SuccessPayload } + | { ok: false; payload: ErrorPayload }; ``` The messages in either direction must also contain additional information so that the receiving party knows where to route the message payload. This wrapper message is referred to as a `TransportMessage` and its payload can be a `Control`, a `Result`, an `Init`, an `Input`, or an `Output`. The schema for the transport message is as follows: diff --git a/__tests__/bandwidth.bench.ts b/__tests__/bandwidth.bench.ts index 2c2b44bd..546f2ad4 100644 --- a/__tests__/bandwidth.bench.ts +++ b/__tests__/bandwidth.bench.ts @@ -52,7 +52,7 @@ describe('bandwidth', async () => { { time: BENCH_DURATION }, ); - const [inputWriter, outputReader] = await client.test.echo.stream(); + const [inputWriter, outputReader] = await client.test.echo.stream({}); bench( `${name} -- stream`, async () => { diff --git a/__tests__/cleanup.test.ts b/__tests__/cleanup.test.ts index 95081727..0bfeda71 100644 --- a/__tests__/cleanup.test.ts +++ b/__tests__/cleanup.test.ts @@ -182,8 +182,9 @@ describe.each(testMatrix())( clientTransport.eventDispatcher.numberOfListeners('message'); // start procedure - const [inputWriter, outputReader, close] = - await client.test.echo.stream(); + const [inputWriter, outputReader, close] = await client.test.echo.stream( + {}, + ); inputWriter.write({ msg: '1', ignore: false, end: undefined }); inputWriter.write({ msg: '2', ignore: false, end: true }); @@ -318,7 +319,7 @@ describe.each(testMatrix())( // start procedure const [inputWriter, addResult] = - await client.uploadable.addMultiple.upload(); + await client.uploadable.addMultiple.upload({}); inputWriter.write({ n: 1 }); inputWriter.write({ n: 2 }); inputWriter.close(); @@ -368,7 +369,7 @@ describe.each(testMatrix())( }); // start a stream - const [inputWriter, outputReader] = await client.test.echo.stream(); + const [inputWriter, outputReader] = await client.test.echo.stream({}); inputWriter.write({ msg: '1', ignore: false }); const outputIterator = getIteratorFromStream(outputReader); diff --git a/__tests__/disconnects.test.ts b/__tests__/disconnects.test.ts index ca2cfae4..c56f4a8e 100644 --- a/__tests__/disconnects.test.ts +++ b/__tests__/disconnects.test.ts @@ -91,7 +91,7 @@ describe.each(testMatrix())( }); // start procedure - const [inputWriter, outputReader] = await client.test.echo.stream(); + const [inputWriter, outputReader] = await client.test.echo.stream({}); const outputIterator = getIteratorFromStream(outputReader); inputWriter.write({ msg: 'abc', ignore: false }); @@ -236,7 +236,7 @@ describe.each(testMatrix())( // start procedure const [inputWriter, addResult] = - await client.uploadable.addMultiple.upload(); + await client.uploadable.addMultiple.upload({}); inputWriter.write({ n: 1 }); inputWriter.write({ n: 2 }); // end procedure diff --git a/__tests__/e2e.test.ts b/__tests__/e2e.test.ts index c89c98f8..ef09dec2 100644 --- a/__tests__/e2e.test.ts +++ b/__tests__/e2e.test.ts @@ -151,8 +151,9 @@ describe.each(testMatrix())( }); // test - const [inputWriter, outputReader, close] = - await client.test.echo.stream(); + const [inputWriter, outputReader, close] = await client.test.echo.stream( + {}, + ); const outputIterator = getIteratorFromStream(outputReader); inputWriter.write({ msg: 'abc', ignore: false }); @@ -241,7 +242,7 @@ describe.each(testMatrix())( // test const [inputWriter, outputReader, close] = - await client.fallible.echo.stream(); + await client.fallible.echo.stream({}); const outputIterator = getIteratorFromStream(outputReader); inputWriter.write({ msg: 'abc', throwResult: false, throwError: false }); const result1 = await iterNext(outputIterator); @@ -332,7 +333,7 @@ describe.each(testMatrix())( // test const [inputWriter, addResult] = - await client.uploadable.addMultiple.upload(); + await client.uploadable.addMultiple.upload({}); inputWriter.write({ n: 1 }); inputWriter.write({ n: 2 }); inputWriter.close(); @@ -478,7 +479,7 @@ describe.each(testMatrix())( // test const openStreams = []; for (let i = 0; i < CONCURRENCY; i++) { - const streamHandle = await client.test.echo.stream(); + const streamHandle = await client.test.echo.stream({}); const inputWriter = streamHandle[0]; inputWriter.write({ msg: `${i}-1`, ignore: false }); inputWriter.write({ msg: `${i}-2`, ignore: false }); @@ -690,7 +691,7 @@ describe.each(testMatrix())( const services = { test: ServiceSchema.define({ getData: Procedure.rpc({ - input: Type.Object({}), + init: Type.Object({}), output: Type.Object({ data: Type.String(), extra: Type.Number(), diff --git a/__tests__/fixtures/services.ts b/__tests__/fixtures/services.ts index 27ff7b67..e0e44220 100644 --- a/__tests__/fixtures/services.ts +++ b/__tests__/fixtures/services.ts @@ -17,7 +17,7 @@ const TestServiceScaffold = ServiceSchema.scaffold({ const testServiceProcedures = TestServiceScaffold.procedures({ add: Procedure.rpc({ - input: Type.Object({ n: Type.Number() }), + init: Type.Object({ n: Type.Number() }), output: Type.Object({ result: Type.Number() }), async handler(ctx, { n }) { ctx.state.count += n; @@ -26,7 +26,7 @@ const testServiceProcedures = TestServiceScaffold.procedures({ }), array: Procedure.rpc({ - input: Type.Object({ n: Type.Number() }), + init: Type.Object({ n: Type.Number() }), output: Type.Array(Type.Number()), async handler(ctx, { n }) { ctx.state.count += n; @@ -35,9 +35,10 @@ const testServiceProcedures = TestServiceScaffold.procedures({ }), arrayStream: Procedure.stream({ + init: Type.Object({}), input: Type.Object({ n: Type.Number() }), output: Type.Array(Type.Number()), - async handler(_, msgStream, returnStream) { + async handler(_, _init, msgStream, returnStream) { for await (const msg of msgStream) { returnStream.write(Ok([msg.n])); } @@ -45,9 +46,10 @@ const testServiceProcedures = TestServiceScaffold.procedures({ }), echo: Procedure.stream({ + init: Type.Object({}), input: EchoRequest, output: EchoResponse, - async handler(_ctx, msgStream, returnStream) { + async handler(_ctx, _init, msgStream, returnStream) { for await (const { ignore, msg, end } of msgStream) { if (!ignore) { returnStream.write(Ok({ response: msg })); @@ -75,7 +77,7 @@ const testServiceProcedures = TestServiceScaffold.procedures({ echoUnion: Procedure.rpc({ description: 'Echos back whatever we sent', - input: Type.Union([ + init: Type.Union([ Type.Object( { a: Type.Number({ description: 'A number' }) }, { description: 'A' }, @@ -99,6 +101,23 @@ const testServiceProcedures = TestServiceScaffold.procedures({ return Ok(input); }, }), + + unimplementedUpload: Procedure.upload({ + init: Type.Object({}), + input: Type.Object({}), + output: Type.Object({}), + async handler() { + throw new Error('Not implemented'); + }, + }), + + unimplementedSubscription: Procedure.subscription({ + init: Type.Object({}), + output: Type.Object({}), + async handler() { + throw new Error('Not implemented'); + }, + }), }); export const TestServiceSchema = TestServiceScaffold.finalize({ @@ -109,7 +128,7 @@ export const OrderingServiceSchema = ServiceSchema.define( { initializeState: () => ({ msgs: [] as Array }) }, { add: Procedure.rpc({ - input: Type.Object({ n: Type.Number() }), + init: Type.Object({ n: Type.Number() }), output: Type.Object({ n: Type.Number() }), async handler(ctx, { n }) { ctx.state.msgs.push(n); @@ -118,7 +137,7 @@ export const OrderingServiceSchema = ServiceSchema.define( }), getAll: Procedure.rpc({ - input: Type.Object({}), + init: Type.Object({}), output: Type.Object({ msgs: Type.Array(Type.Number()) }), async handler(ctx, _msg) { return Ok({ msgs: ctx.state.msgs }); @@ -129,7 +148,7 @@ export const OrderingServiceSchema = ServiceSchema.define( export const BinaryFileServiceSchema = ServiceSchema.define({ getFile: Procedure.rpc({ - input: Type.Object({ file: Type.String() }), + init: Type.Object({ file: Type.String() }), output: Type.Object({ contents: Type.Uint8Array() }), async handler(_ctx, { file }) { const bytes: Uint8Array = Buffer.from(`contents for file ${file}`); @@ -143,7 +162,7 @@ export const STREAM_ERROR = 'STREAM_ERROR'; export const FallibleServiceSchema = ServiceSchema.define({ divide: Procedure.rpc({ - input: Type.Object({ a: Type.Number(), b: Type.Number() }), + init: Type.Object({ a: Type.Number(), b: Type.Number() }), output: Type.Object({ result: Type.Number() }), errors: Type.Union([ Type.Object({ @@ -166,6 +185,7 @@ export const FallibleServiceSchema = ServiceSchema.define({ }), echo: Procedure.stream({ + init: Type.Object({}), input: Type.Object({ msg: Type.String(), throwResult: Type.Boolean(), @@ -176,7 +196,7 @@ export const FallibleServiceSchema = ServiceSchema.define({ code: Type.Literal(STREAM_ERROR), message: Type.String(), }), - async handler(_ctx, msgStream, returnStream) { + async handler(_ctx, _init, msgStream, returnStream) { for await (const { msg, throwError, throwResult } of msgStream) { if (throwError) { throw new Error('some message'); @@ -199,7 +219,7 @@ export const SubscribableServiceSchema = ServiceSchema.define( { initializeState: () => ({ count: new Observable(0) }) }, { add: Procedure.rpc({ - input: Type.Object({ n: Type.Number() }), + init: Type.Object({ n: Type.Number() }), output: Type.Object({ result: Type.Number() }), async handler(ctx, { n }) { ctx.state.count.set((prev) => prev + n); @@ -208,7 +228,7 @@ export const SubscribableServiceSchema = ServiceSchema.define( }), value: Procedure.subscription({ - input: Type.Object({}), + init: Type.Object({}), output: Type.Object({ result: Type.Number() }), async handler(ctx, _msg, returnStream) { return ctx.state.count.observe((count) => { @@ -221,9 +241,10 @@ export const SubscribableServiceSchema = ServiceSchema.define( export const UploadableServiceSchema = ServiceSchema.define({ addMultiple: Procedure.upload({ + init: Type.Object({}), input: Type.Object({ n: Type.Number() }), output: Type.Object({ result: Type.Number() }), - async handler(_ctx, msgStream) { + async handler(_ctx, _init, msgStream) { let result = 0; for await (const { n } of msgStream) { result += n; @@ -256,7 +277,7 @@ const RecursivePayload = Type.Recursive((This) => export const NonObjectSchemas = ServiceSchema.define({ add: Procedure.rpc({ - input: Type.Number(), + init: Type.Number(), output: Type.Number(), async handler(_ctx, n) { return Ok(n + 1); @@ -264,7 +285,7 @@ export const NonObjectSchemas = ServiceSchema.define({ }), echoRecursive: Procedure.rpc({ - input: RecursivePayload, + init: RecursivePayload, output: RecursivePayload, async handler(_ctx, msg) { return Ok(msg); @@ -277,7 +298,7 @@ export function SchemaWithDisposableState(dispose: () => void) { { initializeState: () => ({ [Symbol.dispose]: dispose }) }, { add: Procedure.rpc({ - input: Type.Number(), + init: Type.Number(), output: Type.Number(), async handler(_ctx, n) { return Ok(n + 1); diff --git a/__tests__/handler.test.ts b/__tests__/handler.test.ts index 648060f9..aabaf398 100644 --- a/__tests__/handler.test.ts +++ b/__tests__/handler.test.ts @@ -18,7 +18,7 @@ import { import { UNCAUGHT_ERROR } from '../router/result'; import { Observable } from './fixtures/observable'; -describe.skip('server-side test', () => { +describe('server-side test', () => { const service = TestServiceSchema.instantiate(); test('rpc basic', async () => { @@ -73,7 +73,12 @@ describe.skip('server-side test', () => { assert(result2.ok); expect(result2.payload).toStrictEqual({ response: 'ghi' }); - expect(outputIterator.next()).toEqual({ done: true, value: undefined }); + await outputReader.requestClose(); + + expect(await outputIterator.next()).toEqual({ + done: true, + value: undefined, + }); }); test('stream with initialization', async () => { @@ -96,8 +101,12 @@ describe.skip('server-side test', () => { const result2 = await iterNext(outputIterator); assert(result2.ok); expect(result2.payload).toStrictEqual({ response: 'test ghi' }); + await outputReader.requestClose(); - expect(outputIterator.next()).toEqual({ done: true, value: undefined }); + expect(await outputIterator.next()).toEqual({ + done: true, + value: undefined, + }); }); test('fallible stream', async () => { diff --git a/__tests__/serialize.test.ts b/__tests__/serialize.test.ts index 0c7fb0e2..733cfb1a 100644 --- a/__tests__/serialize.test.ts +++ b/__tests__/serialize.test.ts @@ -26,7 +26,7 @@ describe('serialize server to jsonschema', () => { test: { procedures: { add: { - input: { + init: { properties: { n: { type: 'number' }, }, @@ -49,7 +49,7 @@ describe('serialize server to jsonschema', () => { errors: { not: {}, }, - input: { + init: { properties: { n: { type: 'number', @@ -70,6 +70,10 @@ describe('serialize server to jsonschema', () => { errors: { not: {}, }, + init: { + properties: {}, + type: 'object', + }, input: { properties: { n: { @@ -88,6 +92,10 @@ describe('serialize server to jsonschema', () => { type: 'stream', }, echo: { + init: { + properties: {}, + type: 'object', + }, input: { properties: { msg: { type: 'string' }, @@ -153,7 +161,7 @@ describe('serialize server to jsonschema', () => { errors: { not: {}, }, - input: { + init: { anyOf: [ { description: 'A', @@ -207,6 +215,38 @@ describe('serialize server to jsonschema', () => { }, type: 'rpc', }, + unimplementedSubscription: { + errors: { + not: {}, + }, + init: { + properties: {}, + type: 'object', + }, + output: { + properties: {}, + type: 'object', + }, + type: 'subscription', + }, + unimplementedUpload: { + errors: { + not: {}, + }, + init: { + properties: {}, + type: 'object', + }, + input: { + properties: {}, + type: 'object', + }, + output: { + properties: {}, + type: 'object', + }, + type: 'upload', + }, }, }, }, @@ -219,7 +259,7 @@ describe('serialize service to jsonschema', () => { expect(TestServiceSchema.serialize()).toStrictEqual({ procedures: { add: { - input: { + init: { properties: { n: { type: 'number' }, }, @@ -242,7 +282,7 @@ describe('serialize service to jsonschema', () => { errors: { not: {}, }, - input: { + init: { properties: { n: { type: 'number', @@ -263,6 +303,10 @@ describe('serialize service to jsonschema', () => { errors: { not: {}, }, + init: { + properties: {}, + type: 'object', + }, input: { properties: { n: { @@ -281,6 +325,10 @@ describe('serialize service to jsonschema', () => { type: 'stream', }, echo: { + init: { + properties: {}, + type: 'object', + }, input: { properties: { msg: { type: 'string' }, @@ -346,7 +394,7 @@ describe('serialize service to jsonschema', () => { errors: { not: {}, }, - input: { + init: { anyOf: [ { description: 'A', @@ -400,6 +448,38 @@ describe('serialize service to jsonschema', () => { }, type: 'rpc', }, + unimplementedSubscription: { + errors: { + not: {}, + }, + init: { + properties: {}, + type: 'object', + }, + output: { + properties: {}, + type: 'object', + }, + type: 'subscription', + }, + unimplementedUpload: { + errors: { + not: {}, + }, + init: { + properties: {}, + type: 'object', + }, + input: { + properties: {}, + type: 'object', + }, + output: { + properties: {}, + type: 'object', + }, + type: 'upload', + }, }, }); }); @@ -411,7 +491,7 @@ describe('serialize service to jsonschema', () => { errors: { not: {}, }, - input: { + init: { properties: { file: { type: 'string', @@ -439,7 +519,7 @@ describe('serialize service to jsonschema', () => { expect(FallibleServiceSchema.serialize()).toStrictEqual({ procedures: { divide: { - input: { + init: { properties: { a: { type: 'number' }, b: { type: 'number' }, @@ -487,6 +567,10 @@ describe('serialize service to jsonschema', () => { required: ['code', 'message'], type: 'object', }, + init: { + properties: {}, + type: 'object', + }, input: { properties: { msg: { diff --git a/__tests__/typescript-stress.test.ts b/__tests__/typescript-stress.test.ts index 4edaea9b..313e1de0 100644 --- a/__tests__/typescript-stress.test.ts +++ b/__tests__/typescript-stress.test.ts @@ -41,7 +41,7 @@ const fnBody = Procedure.rpc< typeof output, typeof errors >({ - input, + init: input, output, errors, async handler(_state, msg) { @@ -210,30 +210,32 @@ describe("ensure typescript doesn't give up trying to infer the types for large const services = { test: ServiceSchema.define({ rpc: Procedure.rpc({ - input: Type.Object({ n: Type.Number() }), + init: Type.Object({ n: Type.Number() }), output: Type.Object({ n: Type.Number() }), async handler(_, { n }) { return Ok({ n }); }, }), stream: Procedure.stream({ + init: Type.Object({}), input: Type.Object({ n: Type.Number() }), output: Type.Object({ n: Type.Number() }), - async handler(_c, _in, output) { - output.push(Ok({ n: 1 })); + async handler(_c, _init, _input, output) { + output.write(Ok({ n: 1 })); }, }), subscription: Procedure.subscription({ - input: Type.Object({ n: Type.Number() }), + init: Type.Object({ n: Type.Number() }), output: Type.Object({ n: Type.Number() }), - async handler(_c, _in, output) { - output.push(Ok({ n: 1 })); + async handler(_c, _init, output) { + output.write(Ok({ n: 1 })); }, }), upload: Procedure.upload({ + init: Type.Object({}), input: Type.Object({ n: Type.Number() }), output: Type.Object({ n: Type.Number() }), - async handler(_c, _in) { + async handler(_c, _init, _input) { return Ok({ n: 1 }); }, }), @@ -267,7 +269,7 @@ describe('Output<> type', () => { // Then void client.test.stream - .stream() + .stream({}) .then(([_in, outputReader, _close]) => iterNext(getIteratorFromStream(outputReader)), ) @@ -301,7 +303,7 @@ describe('Output<> type', () => { // Then void client.test.upload - .upload() + .upload({}) .then(([_input, result]) => result) .then(acceptOutput); expect(client).toBeTruthy(); diff --git a/router/client.ts b/router/client.ts index 08637943..4ea58bbb 100644 --- a/router/client.ts +++ b/router/client.ts @@ -2,7 +2,6 @@ import { ClientTransport } from '../transport/transport'; import { AnyService, ProcErrors, - ProcHasInit, ProcInit, ProcInput, ProcOutput, @@ -46,7 +45,7 @@ type ServiceClient = { > extends 'rpc' ? { rpc: ( - input: Static>, + init: Static>, ) => Promise< Result< Static>, @@ -55,66 +54,37 @@ type ServiceClient = { >; } : ProcType extends 'upload' - ? ProcHasInit extends true - ? { - upload: (init: Static>) => Promise< - [ - WriteStream>>, // input - Promise< - Result< - Static>, - Static> - > - >, // output - ] - >; - } - : { - upload: () => Promise< - [ - WriteStream>>, // input - Promise< - Result< - Static>, - Static> - > - >, // output - ] - >; - } + ? { + upload: (init: Static>) => Promise< + [ + WriteStream>>, // input + Promise< + Result< + Static>, + Static> + > + >, // output + ] + >; + } : ProcType extends 'stream' - ? ProcHasInit extends true - ? { - stream: (init: Static>) => Promise< - [ - WriteStream>>, // input - ReadStream< - Result< - Static>, - Static> - > - >, // output - () => void, // close handle - ] - >; - } - : { - stream: () => Promise< - [ - WriteStream>>, // input - ReadStream< - Result< - Static>, - Static> - > - >, // output - () => void, // close handle - ] - >; - } + ? { + stream: (init: Static>) => Promise< + [ + WriteStream>>, // input + ReadStream< + Result< + Static>, + Static> + > + >, // output + () => void, // close handle + ] + >; + } : ProcType extends 'subscription' ? { - subscribe: (input: Static>) => Promise< + subscribe: (init: Static>) => Promise< [ ReadStream< Result< @@ -332,9 +302,7 @@ function handleStream( procedureName, streamId, ); - let firstMessage = true; let healthyClose = true; - const inputWriter = new WriteStreamImpl( (rawIn: unknown) => { const m: PartialTransportMessage = { @@ -343,14 +311,6 @@ function handleStream( controlFlags: 0, }; - if (firstMessage) { - m.serviceName = serviceName; - m.procedureName = procedureName; - m.tracing = getPropagationContext(ctx); - m.controlFlags |= ControlFlags.StreamOpenBit; - firstMessage = false; - } - transport.send(serverId, m); }, () => { @@ -362,18 +322,14 @@ function handleStream( const readStreamRequestCloseNotImplemented = () => undefined; const outputReader = new ReadStreamImpl(readStreamRequestCloseNotImplemented); - if (init) { - transport.send(serverId, { - streamId, - serviceName, - procedureName, - tracing: getPropagationContext(ctx), - payload: init, - controlFlags: ControlFlags.StreamOpenBit, - }); - - firstMessage = false; - } + transport.send(serverId, { + streamId, + serviceName, + procedureName, + tracing: getPropagationContext(ctx), + payload: init, + controlFlags: ControlFlags.StreamOpenBit, + }); // transport -> output function onMessage(msg: OpaqueTransportMessage) { @@ -512,7 +468,6 @@ function handleUpload( streamId, ); - let firstMessage = true; let healthyClose = true; const inputWriter = new WriteStreamImpl( @@ -523,14 +478,6 @@ function handleUpload( controlFlags: 0, }; - if (firstMessage) { - m.serviceName = serviceName; - m.procedureName = procedureName; - m.tracing = getPropagationContext(ctx); - m.controlFlags |= ControlFlags.StreamOpenBit; - firstMessage = false; - } - transport.send(serverId, m); }, () => { @@ -540,18 +487,14 @@ function handleUpload( }, ); - if (init) { - transport.send(serverId, { - streamId, - serviceName, - procedureName, - tracing: getPropagationContext(ctx), - payload: init, - controlFlags: ControlFlags.StreamOpenBit, - }); - - firstMessage = false; - } + transport.send(serverId, { + streamId, + serviceName, + procedureName, + tracing: getPropagationContext(ctx), + payload: init, + controlFlags: ControlFlags.StreamOpenBit, + }); const responsePromise = new Promise((resolve) => { // on disconnect, set a timer to return an error diff --git a/router/procedures.ts b/router/procedures.ts index 216bbb44..97f356c6 100644 --- a/router/procedures.ts +++ b/router/procedures.ts @@ -21,11 +21,11 @@ export type Unbranded = T extends Branded ? U : never; export type ValidProcType = // Single message in both directions (1:1). | 'rpc' - // Client-stream (potentially preceded by an initialization message), single message from server (n:1). + // Client-stream single message from server (n:1). | 'upload' // Single message from client, stream from server (1:n). | 'subscription' - // Bidirectional stream (potentially preceded by an initialization message) (n:n). + // Bidirectional stream (n:n). | 'stream'; /** @@ -38,33 +38,33 @@ export type PayloadType = TSchema; * from a single message. */ export type ProcedureResult< - O extends PayloadType, - E extends RiverError, -> = Result, Static | Static>; + Output extends PayloadType, + Err extends RiverError, +> = Result, Static | Static>; /** * Procedure for a single message in both directions (1:1). * * @template State - The context state object. - * @template I - The TypeBox schema of the input object. - * @template O - The TypeBox schema of the output object. - * @template E - The TypeBox schema of the error object. + * @template Init - The TypeBox schema of the initialization object. + * @template Output - The TypeBox schema of the output object. + * @template Err - The TypeBox schema of the error object. */ export interface RpcProcedure< State, - I extends PayloadType, - O extends PayloadType, - E extends RiverError, + Init extends PayloadType, + Output extends PayloadType, + Err extends RiverError, > { type: 'rpc'; - input: I; - output: O; - errors: E; + init: Init; + output: Output; + errors: Err; description?: string; handler( context: ServiceContextWithTransportInfo, - input: Static, - ): Promise>; + init: Static, + ): Promise>; } /** @@ -72,66 +72,54 @@ export interface RpcProcedure< * single message from server (n:1). * * @template State - The context state object. - * @template I - The TypeBox schema of the input object. - * @template O - The TypeBox schema of the output object. - * @template E - The TypeBox schema of the error object. - * @template Init - The TypeBox schema of the input initialization object, if any. + * @template Init - The TypeBox schema of the initialization object. + * @template Input - The TypeBox schema of the input object. + * @template Output - The TypeBox schema of the output object. + * @template Err - The TypeBox schema of the error object. */ -export type UploadProcedure< +export interface UploadProcedure< State, - I extends PayloadType, - O extends PayloadType, - E extends RiverError, - Init extends PayloadType | null = null, -> = Init extends PayloadType - ? { - type: 'upload'; - init: Init; - input: I; - output: O; - errors: E; - description?: string; - handler( - context: ServiceContextWithTransportInfo, - init: Static, - input: ReadStream>, - ): Promise>; - } - : { - type: 'upload'; - input: I; - output: O; - errors: E; - description?: string; - handler( - context: ServiceContextWithTransportInfo, - input: ReadStream>, - ): Promise>; - }; + Init extends PayloadType, + Input extends PayloadType, + Output extends PayloadType, + Err extends RiverError, +> { + type: 'upload'; + init: Init; + input: Input; + output: Output; + errors: Err; + description?: string; + handler( + context: ServiceContextWithTransportInfo, + init: Static, + input: ReadStream>, + ): Promise>; +} /** * Procedure for a single message from client, stream from server (1:n). * * @template State - The context state object. - * @template I - The TypeBox schema of the input object. - * @template O - The TypeBox schema of the output object. - * @template E - The TypeBox schema of the error object. + * @template Init - The TypeBox schema of the initialization object. + * @template Output - The TypeBox schema of the output object. + * @template Err - The TypeBox schema of the error object. */ export interface SubscriptionProcedure< State, - I extends PayloadType, - O extends PayloadType, - E extends RiverError, + Init extends PayloadType, + Output extends PayloadType, + Err extends RiverError, > { type: 'subscription'; - input: I; - output: O; - errors: E; + init: Init; + output: Output; + errors: Err; description?: string; handler( context: ServiceContextWithTransportInfo, - input: Static, - output: WriteStream>, + init: Static, + output: WriteStream>, ): Promise<(() => void) | void>; } @@ -140,44 +128,31 @@ export interface SubscriptionProcedure< * (n:n). * * @template State - The context state object. - * @template I - The TypeBox schema of the input object. - * @template O - The TypeBox schema of the output object. - * @template E - The TypeBox schema of the error object. - * @template Init - The TypeBox schema of the input initialization object, if any. + * @template Init - The TypeBox schema of the initialization object. + * @template Input - The TypeBox schema of the input object. + * @template Output - The TypeBox schema of the output object. + * @template Err - The TypeBox schema of the error object. */ -export type StreamProcedure< +export interface StreamProcedure< State, - I extends PayloadType, - O extends PayloadType, - E extends RiverError, - Init extends PayloadType | null = null, -> = Init extends PayloadType - ? { - type: 'stream'; - init: Init; - input: I; - output: O; - errors: E; - description?: string; - handler( - context: ServiceContextWithTransportInfo, - init: Static, - input: ReadStream>, - output: WriteStream>, - ): Promise<(() => void) | void>; - } - : { - type: 'stream'; - input: I; - output: O; - errors: E; - description?: string; - handler( - context: ServiceContextWithTransportInfo, - input: ReadStream>, - output: WriteStream>, - ): Promise<(() => void) | void>; - }; + Init extends PayloadType, + Input extends PayloadType, + Output extends PayloadType, + Err extends RiverError, +> { + type: 'stream'; + init: Init; + input: Input; + output: Output; + errors: Err; + description?: string; + handler( + context: ServiceContextWithTransportInfo, + init: Static, + input: ReadStream>, + output: WriteStream>, + ): Promise<(() => void) | void>; +} /** * Defines a Procedure type that can be a: @@ -190,29 +165,28 @@ export type StreamProcedure< * * @template State - The TypeBox schema of the state object. * @template Ty - The type of the procedure. - * @template I - The TypeBox schema of the input object. - * @template O - The TypeBox schema of the output object. + * @template Input - The TypeBox schema of the input object. * @template Init - The TypeBox schema of the input initialization object, if any. + * @template Output - The TypeBox schema of the output object. */ -// prettier-ignore export type Procedure< State, Ty extends ValidProcType, - I extends PayloadType, - O extends PayloadType, - E extends RiverError, - Init extends PayloadType | null = null, -> = { type: Ty } & ( - Init extends PayloadType - ? Ty extends 'upload' ? UploadProcedure - : Ty extends 'stream' ? StreamProcedure - : never - : Ty extends 'rpc' ? RpcProcedure - : Ty extends 'upload' ? UploadProcedure - : Ty extends 'subscription' ? SubscriptionProcedure - : Ty extends 'stream' ? StreamProcedure - : never -); + Init extends PayloadType, + Input extends PayloadType | null, + Output extends PayloadType, + Err extends RiverError, +> = { type: Ty } & (Input extends PayloadType + ? Ty extends 'upload' + ? UploadProcedure + : Ty extends 'stream' + ? StreamProcedure + : never + : Ty extends 'rpc' + ? RpcProcedure + : Ty extends 'subscription' + ? SubscriptionProcedure + : never); /** * Represents any {@link Procedure} type. @@ -224,9 +198,9 @@ export type AnyProcedure = Procedure< State, ValidProcType, PayloadType, + PayloadType | null, PayloadType, - RiverError, - PayloadType | null + RiverError >; /** @@ -245,37 +219,37 @@ export type ProcedureMap = Record>; * Creates an {@link RpcProcedure}. */ // signature: default errors -function rpc(def: { - input: I; - output: O; +function rpc(def: { + init: Init; + output: Output; errors?: never; description?: string; - handler: RpcProcedure['handler']; -}): Branded>; + handler: RpcProcedure['handler']; +}): Branded>; // signature: explicit errors function rpc< State, - I extends PayloadType, - O extends PayloadType, - E extends RiverError, + Init extends PayloadType, + Output extends PayloadType, + Err extends RiverError, >(def: { - input: I; - output: O; - errors: E; + init: Init; + output: Output; + errors: Err; description?: string; - handler: RpcProcedure['handler']; -}): Branded>; + handler: RpcProcedure['handler']; +}): Branded>; // implementation function rpc({ - input, + init, output, errors = Type.Never(), description, handler, }: { - input: PayloadType; + init: PayloadType; output: PayloadType; errors?: RiverError; description?: string; @@ -289,7 +263,7 @@ function rpc({ return { ...(description ? { description } : {}), type: 'rpc', - input, + init, output, errors, handler, @@ -302,58 +276,33 @@ function rpc({ // signature: init with default errors function upload< State, - I extends PayloadType, - O extends PayloadType, Init extends PayloadType, + Input extends PayloadType, + Output extends PayloadType, >(def: { init: Init; - input: I; - output: O; + input: Input; + output: Output; errors?: never; description?: string; - handler: UploadProcedure['handler']; -}): Branded>; + handler: UploadProcedure['handler']; +}): Branded>; // signature: init with explicit errors function upload< State, - I extends PayloadType, - O extends PayloadType, - E extends RiverError, Init extends PayloadType, + Input extends PayloadType, + Output extends PayloadType, + Err extends RiverError, >(def: { init: Init; - input: I; - output: O; - errors: E; - description?: string; - handler: UploadProcedure['handler']; -}): Branded>; - -// signature: no init with default errors -function upload(def: { - init?: never; - input: I; - output: O; - errors?: never; - description?: string; - handler: UploadProcedure['handler']; -}): Branded>; - -// signature: no init with explicit errors -function upload< - State, - I extends PayloadType, - O extends PayloadType, - E extends RiverError, ->(def: { - init?: never; - input: I; - output: O; - errors: E; + input: Input; + output: Output; + errors: Err; description?: string; - handler: UploadProcedure['handler']; -}): Branded>; + handler: UploadProcedure['handler']; +}): Branded>; // implementation function upload({ @@ -364,7 +313,7 @@ function upload({ description, handler, }: { - init?: PayloadType | null; + init: PayloadType; input: PayloadType; output: PayloadType; errors?: RiverError; @@ -373,28 +322,19 @@ function upload({ object, PayloadType, PayloadType, - RiverError, - PayloadType | null + PayloadType, + RiverError >['handler']; }) { - return init !== undefined && init !== null - ? { - type: 'upload', - ...(description ? { description } : {}), - init, - input, - output, - errors, - handler, - } - : { - type: 'upload', - ...(description ? { description } : {}), - input, - output, - errors, - handler, - }; + return { + type: 'upload', + ...(description ? { description } : {}), + init, + input, + output, + errors, + handler, + }; } /** @@ -403,39 +343,39 @@ function upload({ // signature: default errors function subscription< State, - I extends PayloadType, - O extends PayloadType, + Init extends PayloadType, + Output extends PayloadType, >(def: { - input: I; - output: O; + init: Init; + output: Output; errors?: never; description?: string; - handler: SubscriptionProcedure['handler']; -}): Branded>; + handler: SubscriptionProcedure['handler']; +}): Branded>; // signature: explicit errors function subscription< State, - I extends PayloadType, - O extends PayloadType, - E extends RiverError, + Init extends PayloadType, + Output extends PayloadType, + Err extends RiverError, >(def: { - input: I; - output: O; - errors: E; + init: Init; + output: Output; + errors: Err; description?: string; - handler: SubscriptionProcedure['handler']; -}): Branded>; + handler: SubscriptionProcedure['handler']; +}): Branded>; // implementation function subscription({ - input, + init, output, errors = Type.Never(), description, handler, }: { - input: PayloadType; + init: PayloadType; output: PayloadType; errors?: RiverError; description?: string; @@ -449,7 +389,7 @@ function subscription({ return { type: 'subscription', ...(description ? { description } : {}), - input, + init, output, errors, handler, @@ -459,61 +399,36 @@ function subscription({ /** * Creates a {@link StreamProcedure}, optionally with an initialization message. */ -// signature: init with default errors +// signature: with default errors function stream< State, - I extends PayloadType, - O extends PayloadType, Init extends PayloadType, + Input extends PayloadType, + Output extends PayloadType, >(def: { init: Init; - input: I; - output: O; + input: Input; + output: Output; errors?: never; description?: string; - handler: StreamProcedure['handler']; -}): Branded>; + handler: StreamProcedure['handler']; +}): Branded>; -// signature: init with explicit errors +// signature: explicit errors function stream< State, - I extends PayloadType, - O extends PayloadType, - E extends RiverError, Init extends PayloadType, + Input extends PayloadType, + Output extends PayloadType, + Err extends RiverError, >(def: { init: Init; - input: I; - output: O; - errors: E; - description?: string; - handler: StreamProcedure['handler']; -}): Branded>; - -// signature: no init with default errors -function stream(def: { - init?: never; - input: I; - output: O; - errors?: never; - description?: string; - handler: StreamProcedure['handler']; -}): Branded>; - -// signature: no init with explicit errors -function stream< - State, - I extends PayloadType, - O extends PayloadType, - E extends RiverError, ->(def: { - init?: never; - input: I; - output: O; - errors: E; + input: Input; + output: Output; + errors: Err; description?: string; - handler: StreamProcedure['handler']; -}): Branded>; + handler: StreamProcedure['handler']; +}): Branded>; // implementation function stream({ @@ -524,7 +439,7 @@ function stream({ description, handler, }: { - init?: PayloadType | null; + init: PayloadType; input: PayloadType; output: PayloadType; errors?: RiverError; @@ -533,28 +448,19 @@ function stream({ object, PayloadType, PayloadType, - RiverError, - PayloadType | null + PayloadType, + RiverError >['handler']; }) { - return init !== undefined && init !== null - ? { - type: 'stream', - ...(description ? { description } : {}), - init, - input, - output, - errors, - handler, - } - : { - type: 'stream', - ...(description ? { description } : {}), - input, - output, - errors, - handler, - }; + return { + type: 'stream', + ...(description ? { description } : {}), + init, + input, + output, + errors, + handler, + }; } /** diff --git a/router/result.ts b/router/result.ts index 21e10a78..984f3789 100644 --- a/router/result.ts +++ b/router/result.ts @@ -40,29 +40,31 @@ export const RiverUncaughtSchema = Type.Object({ message: Type.String(), }); -export type Result = +export type Result = | { ok: true; payload: T; } | { ok: false; - payload: E; + payload: Err; }; -export function Ok, const E>(p: T): Result; -export function Ok, const E>( +export function Ok, const Err>( p: T, -): Result; -export function Ok(payload: T): Result; -export function Ok(payload: T): Result { +): Result; +export function Ok, const Err>( + p: T, +): Result; +export function Ok(payload: T): Result; +export function Ok(payload: T): Result { return { ok: true, payload, }; } -export function Err(error: E): Result { +export function Err(error: Err): Result { return { ok: false, payload: error, @@ -79,8 +81,8 @@ export type ResultUnwrapOk = R extends Result /** * Refine a {@link Result} type to its error payload. */ -export type ResultUnwrapErr = R extends Result - ? E +export type ResultUnwrapErr = R extends Result + ? Err : never; /** diff --git a/router/server.ts b/router/server.ts index 42e33a56..73f462fe 100644 --- a/router/server.ts +++ b/router/server.ts @@ -114,7 +114,7 @@ class RiverServer { } let procStream = this.streamMap.get(message.streamId); - const isFirstMessage = !procStream; + const isInit = !procStream; // create a proc stream if it doesnt exist procStream ||= this.createNewProcStream(message); @@ -123,7 +123,7 @@ class RiverServer { return; } - await this.pushToStream(procStream, message, isFirstMessage); + await this.pushToStream(procStream, message, isInit); }; // cleanup streams on session close @@ -162,58 +162,58 @@ class RiverServer { } } - createNewProcStream(message: OpaqueTransportMessage) { - if (!isStreamOpen(message.controlFlags)) { + createNewProcStream(initMessage: OpaqueTransportMessage) { + if (!isStreamOpen(initMessage.controlFlags)) { log?.error( `can't create a new procedure stream from a message that doesn't have the stream open bit set`, { clientId: this.transport.clientId, - transportMessage: message, + transportMessage: initMessage, tags: ['invariant-violation'], }, ); return; } - if (!message.procedureName || !message.serviceName) { + if (!initMessage.procedureName || !initMessage.serviceName) { log?.warn(`missing procedure or service name in stream open message`, { clientId: this.transport.clientId, - transportMessage: message, + transportMessage: initMessage, }); return; } - if (!(message.serviceName in this.services)) { - log?.warn(`couldn't find service ${message.serviceName}`, { + if (!(initMessage.serviceName in this.services)) { + log?.warn(`couldn't find service ${initMessage.serviceName}`, { clientId: this.transport.clientId, - transportMessage: message, + transportMessage: initMessage, }); return; } - const service = this.services[message.serviceName]; - const serviceContext = this.getContext(service, message.serviceName); - if (!(message.procedureName in service.procedures)) { + const service = this.services[initMessage.serviceName]; + const serviceContext = this.getContext(service, initMessage.serviceName); + if (!(initMessage.procedureName in service.procedures)) { log?.warn( - `couldn't find a matching procedure for ${message.serviceName}.${message.procedureName}`, + `couldn't find a matching procedure for ${initMessage.serviceName}.${initMessage.procedureName}`, { clientId: this.transport.clientId, - transportMessage: message, + transportMessage: initMessage, }, ); return; } - const session = this.transport.sessions.get(message.from); + const session = this.transport.sessions.get(initMessage.from); if (!session) { - log?.warn(`couldn't find session for ${message.from}`, { + log?.warn(`couldn't find session for ${initMessage.from}`, { clientId: this.transport.clientId, - transportMessage: message, + transportMessage: initMessage, }); return; } - const procedure = service.procedures[message.procedureName]; + const procedure = service.procedures[initMessage.procedureName]; const readStreamRequestCloseNotImplemented = () => void 0; const incoming: ProcStream['incoming'] = new ReadStreamImpl( readStreamRequestCloseNotImplemented, @@ -224,16 +224,16 @@ class RiverServer { const outgoing: ProcStream['outgoing'] = new WriteStreamImpl( (response) => { this.transport.send(session.to, { - streamId: message.streamId, + streamId: initMessage.streamId, controlFlags: needsClose ? 0 : ControlFlags.StreamClosedBit, payload: response, }); }, () => { - if (needsClose && !this.disconnectedSessions.has(message.from)) { + if (needsClose && !this.disconnectedSessions.has(initMessage.from)) { // we ended, send a close bit back to the client // also, if the client has disconnected, we don't need to send a close - this.transport.sendCloseStream(session.to, message.streamId); + this.transport.sendCloseStream(session.to, initMessage.streamId); } // call disposables returned from handlers disposables.forEach((d) => d()); @@ -243,7 +243,7 @@ class RiverServer { const errorHandler = (err: unknown, span: Span) => { const errorMsg = coerceErrorString(err); log?.error( - `procedure ${message.serviceName}.${message.procedureName} threw an uncaught error: ${errorMsg}`, + `procedure ${initMessage.serviceName}.${initMessage.procedureName} threw an uncaught error: ${errorMsg}`, session.loggingMetadata, ); @@ -268,13 +268,12 @@ class RiverServer { // pump incoming message stream -> handler -> outgoing message stream let inputHandler: Promise; - const procHasInitMessage = 'init' in procedure; const serviceContextWithTransportInfo: ServiceContextWithTransportInfo = { ...serviceContext, - to: message.to, - from: message.from, - streamId: message.streamId, + to: initMessage.to, + from: initMessage.from, + streamId: initMessage.streamId, session, metadata: sessionMeta, }; @@ -283,15 +282,15 @@ class RiverServer { case 'rpc': inputHandler = createHandlerSpan( procedure.type, - message, + initMessage, async (span) => { - if (!Value.Check(procedure.input, message.payload)) { - log?.error('subscription input failed validation', { + if (!Value.Check(procedure.init, initMessage.payload)) { + log?.error('rpc init failed validation', { clientId: this.transport.clientId, - transportMessage: message, + transportMessage: initMessage, }); - errorHandler('subscription input failed validation', span); + errorHandler('rpc init failed validation', span); span.end(); return; @@ -300,7 +299,7 @@ class RiverServer { try { const outputMessage = await procedure.handler( serviceContextWithTransportInfo, - message.payload, + initMessage.payload, ); outgoing.write(outputMessage); } catch (err) { @@ -312,82 +311,60 @@ class RiverServer { ); break; case 'stream': - if (procHasInitMessage) { - inputHandler = createHandlerSpan( - procedure.type, - message, - async (span) => { - if (!Value.Check(procedure.init, message.payload)) { - log?.error( - 'procedure requires init, but first message failed validation', - { - clientId: this.transport.clientId, - transportMessage: message, - }, - ); - - errorHandler( - 'procedure requires init, but first message failed validation', - span, - ); - span.end(); - - return; - } + inputHandler = createHandlerSpan( + procedure.type, + initMessage, + async (span) => { + if (!Value.Check(procedure.init, initMessage.payload)) { + log?.error( + 'procedure requires init, but first message failed validation', + { + clientId: this.transport.clientId, + transportMessage: initMessage, + }, + ); - try { - const dispose = await procedure.handler( - serviceContextWithTransportInfo, - message.payload, - incoming, - outgoing, - ); - - if (dispose) { - disposables.push(dispose); - } - } catch (err) { - errorHandler(err, span); - } finally { - span.end(); - } - }, - ); - } else { - inputHandler = createHandlerSpan( - procedure.type, - message, - async (span) => { - try { - const dispose = await procedure.handler( - serviceContextWithTransportInfo, - incoming, - outgoing, - ); - if (dispose) { - disposables.push(dispose); - } - } catch (err) { - errorHandler(err, span); - } finally { - span.end(); + errorHandler( + 'procedure requires init, but first message failed validation', + span, + ); + span.end(); + + return; + } + + try { + const dispose = await procedure.handler( + serviceContextWithTransportInfo, + initMessage.payload, + incoming, + outgoing, + ); + + if (dispose) { + disposables.push(dispose); } - }, - ); - } + } catch (err) { + errorHandler(err, span); + } finally { + span.end(); + } + }, + ); + break; case 'subscription': inputHandler = createHandlerSpan( procedure.type, - message, + initMessage, async (span) => { - if (!Value.Check(procedure.input, message.payload)) { - log?.error('subscription input failed validation', { + if (!Value.Check(procedure.init, initMessage.payload)) { + log?.error('subscription init failed validation', { clientId: this.transport.clientId, - transportMessage: message, + transportMessage: initMessage, }); - errorHandler('subscription input failed validation', span); + errorHandler('subscription init failed validation', span); span.end(); return; @@ -396,7 +373,7 @@ class RiverServer { try { const dispose = await procedure.handler( serviceContextWithTransportInfo, - message.payload, + initMessage.payload, outgoing, ); @@ -412,68 +389,45 @@ class RiverServer { ); break; case 'upload': - if (procHasInitMessage) { - inputHandler = createHandlerSpan( - procedure.type, - message, - async (span) => { - if (!Value.Check(procedure.init, message.payload)) { - log?.error( - 'procedure requires init, but first message failed validation', - { - clientId: this.transport.clientId, - transportMessage: message, - }, - ); - - errorHandler( - 'procedure requires init, but first message failed validation', - span, - ); - span.end(); - - return; - } + inputHandler = createHandlerSpan( + procedure.type, + initMessage, + async (span) => { + if (!Value.Check(procedure.init, initMessage.payload)) { + log?.error( + 'procedure requires init, but first message failed validation', + { + clientId: this.transport.clientId, + transportMessage: initMessage, + }, + ); - try { - const outputMessage = await procedure.handler( - serviceContextWithTransportInfo, - message.payload, - incoming, - ); - - if (!this.disconnectedSessions.has(message.from)) { - outgoing.write(outputMessage); - } - } catch (err) { - errorHandler(err, span); - } finally { - span.end(); - } - }, - ); - } else { - inputHandler = createHandlerSpan( - procedure.type, - message, - async (span) => { - try { - const outputMessage = await procedure.handler( - serviceContextWithTransportInfo, - incoming, - ); - - if (!this.disconnectedSessions.has(message.from)) { - outgoing.write(outputMessage); - } - } catch (err) { - errorHandler(err, span); - } finally { - span.end(); + errorHandler( + 'procedure requires init, but first message failed validation', + span, + ); + span.end(); + + return; + } + + try { + const outputMessage = await procedure.handler( + serviceContextWithTransportInfo, + initMessage.payload, + incoming, + ); + + if (!this.disconnectedSessions.has(initMessage.from)) { + outgoing.write(outputMessage); } - }, - ); - } + } catch (err) { + errorHandler(err, span); + } finally { + span.end(); + } + }, + ); break; default: @@ -482,28 +436,28 @@ class RiverServer { log?.warn( `got request for invalid procedure type ${ (procedure as AnyProcedure).type - } at ${message.serviceName}.${message.procedureName}`, - { ...session.loggingMetadata, transportMessage: message }, + } at ${initMessage.serviceName}.${initMessage.procedureName}`, + { ...session.loggingMetadata, transportMessage: initMessage }, ); return; } const procStream: ProcStream = { - id: message.streamId, + id: initMessage.streamId, incoming, outgoing, - serviceName: message.serviceName, - procedureName: message.procedureName, + serviceName: initMessage.serviceName, + procedureName: initMessage.procedureName, promises: { inputHandler }, }; - this.streamMap.set(message.streamId, procStream); + this.streamMap.set(initMessage.streamId, procStream); // add this stream to ones from that client so we can clean it up in the case of a disconnect without close const streamsFromThisClient = - this.clientStreams.get(message.from) ?? new Set(); - streamsFromThisClient.add(message.streamId); - this.clientStreams.set(message.from, streamsFromThisClient); + this.clientStreams.get(initMessage.from) ?? new Set(); + streamsFromThisClient.add(initMessage.streamId); + this.clientStreams.set(initMessage.from, streamsFromThisClient); return procStream; } @@ -511,15 +465,17 @@ class RiverServer { async pushToStream( procStream: ProcStream, message: OpaqueTransportMessage, - isFirstMessage?: boolean, + isInit?: boolean, ) { const { serviceName, procedureName } = procStream; const procedure = this.services[serviceName].procedures[procedureName]; - const procHasInitMessage = 'init' in procedure; - if (!isFirstMessage || !procHasInitMessage) { - // Init message is consumed during stream instantiation - if (Value.Check(procedure.input, message.payload)) { + // Init message is consumed during stream instantiation + if (!isInit) { + if ( + 'input' in procedure && + Value.Check(procedure.input, message.payload) + ) { procStream.incoming.pushValue(message.payload as PayloadType); } else if (!Value.Check(ControlMessagePayloadSchema, message.payload)) { // whelp we got a message that isn't a control message and doesn't match the procedure input @@ -530,7 +486,7 @@ class RiverServer { clientId: this.transport.clientId, transportMessage: message, validationErrors: [ - ...Value.Errors(procedure.input, message.payload), + ...Value.Errors(procedure.init, message.payload), ], }, ); diff --git a/router/services.ts b/router/services.ts index d1a1bafc..b138393f 100644 --- a/router/services.ts +++ b/router/services.ts @@ -59,16 +59,6 @@ export type ProcHandler< ProcName extends keyof S['procedures'], > = S['procedures'][ProcName]['handler']; -/** - * Helper to get whether the type definition for the procedure contains an init type. - * @template S - The service. - * @template ProcName - The name of the procedure. - */ -export type ProcHasInit< - S extends AnyService, - ProcName extends keyof S['procedures'], -> = S['procedures'][ProcName] extends { init: PayloadType } ? true : false; - /** * Helper to get the type definition for the procedure init type of a service. * @template S - The service. @@ -77,9 +67,7 @@ export type ProcHasInit< export type ProcInit< S extends AnyService, ProcName extends keyof S['procedures'], -> = S['procedures'][ProcName] extends { init: PayloadType } - ? S['procedures'][ProcName]['init'] - : never; +> = S['procedures'][ProcName]['init']; /** * Helper to get the type definition for the procedure input of a service. @@ -89,7 +77,9 @@ export type ProcInit< export type ProcInput< S extends AnyService, ProcName extends keyof S['procedures'], -> = S['procedures'][ProcName]['input']; +> = S['procedures'][ProcName] extends { input: PayloadType } + ? S['procedures'][ProcName]['input'] + : never; /** * Helper to get the type definition for the procedure output of a service. @@ -141,11 +131,11 @@ export interface SerializedServiceSchema { procedures: Record< string, { - input: PayloadType; + init: PayloadType; + input?: PayloadType; output: PayloadType; errors?: RiverError; type: 'rpc' | 'subscription' | 'upload' | 'stream'; - init?: PayloadType; } >; } @@ -236,10 +226,10 @@ export class ServiceSchema< * * const incrementProcedures = MyServiceScaffold.procedures({ * increment: Procedure.rpc({ - * input: Type.Object({ amount: Type.Number() }), + * init: Type.Object({ amount: Type.Number() }), * output: Type.Object({ current: Type.Number() }), - * async handler(ctx, input) { - * ctx.state.count += input.amount; + * async handler(ctx, init) { + * ctx.state.count += init.amount; * return Ok({ current: ctx.state.count }); * } * }), @@ -263,10 +253,10 @@ export class ServiceSchema< * .scaffold({ initializeState: () => ({ count: 0 }) }) * .finalize({ * increment: Procedure.rpc({ - * input: Type.Object({ amount: Type.Number() }), + * init: Type.Object({ amount: Type.Number() }), * output: Type.Object({ current: Type.Number() }), - * async handler(ctx, input) { - * ctx.state.count += input.amount; + * async handler(ctx, init) { + * ctx.state.count += init.amount; * return Ok({ current: ctx.state.count }); * } * }), @@ -296,10 +286,10 @@ export class ServiceSchema< * { initializeState: () => ({ count: 0 }) }, * { * increment: Procedure.rpc({ - * input: Type.Object({ amount: Type.Number() }), + * init: Type.Object({ amount: Type.Number() }), * output: Type.Object({ current: Type.Number() }), - * async handler(ctx, input) { - * ctx.state.count += input.amount; + * async handler(ctx, init) { + * ctx.state.count += init.amount; * return Ok({ current: ctx.state.count }); * } * }), @@ -331,10 +321,10 @@ export class ServiceSchema< * ``` * const service = ServiceSchema.define({ * add: Procedure.rpc({ - * input: Type.Object({ a: Type.Number(), b: Type.Number() }), + * init: Type.Object({ a: Type.Number(), b: Type.Number() }), * output: Type.Object({ result: Type.Number() }), - * async handler(ctx, input) { - * return Ok({ result: input.a + input.b }); + * async handler(ctx, init) { + * return Ok({ result: init.a + init.b }); * } * }), * }); @@ -382,7 +372,7 @@ export class ServiceSchema< Object.entries(this.procedures).map(([procName, procDef]) => [ procName, { - input: Type.Strict(procDef.input), + init: Type.Strict(procDef.init), output: Type.Strict(procDef.output), // Only add `description` field if the type declares it. ...('description' in procDef @@ -395,10 +385,10 @@ export class ServiceSchema< } : {}), type: procDef.type, - // Only add the `init` field if the type declares it. - ...('init' in procDef + // Only add the `input` field if the type declares it. + ...('input' in procDef ? { - init: Type.Strict(procDef.init), + input: Type.Strict(procDef.input), } : {}), }, diff --git a/util/testHelpers.ts b/util/testHelpers.ts index 38763d70..aef37d17 100644 --- a/util/testHelpers.ts +++ b/util/testHelpers.ts @@ -187,22 +187,21 @@ function dummyCtx( export function asClientRpc< State extends object, - I extends PayloadType, - O extends PayloadType, - E extends RiverError, - Init extends PayloadType | null = null, + Init extends PayloadType, + Output extends PayloadType, + Err extends RiverError, >( state: State, - proc: Procedure, + proc: Procedure, extendedContext?: Omit, session: Session = dummySession(), ) { return async ( - msg: Static, + msg: Static, ): Promise< - Result, Static | Static> + Result, Static | Static> > => { - return await proc + return proc .handler(dummyCtx(state, session, extendedContext), msg) .catch(catchProcError); }; @@ -226,70 +225,53 @@ function createPipe(): { reader: ReadStream; writer: WriteStream } { export function asClientStream< State extends object, - I extends PayloadType, - O extends PayloadType, - E extends RiverError, - Init extends PayloadType | null = null, + Init extends PayloadType, + Input extends PayloadType, + Output extends PayloadType, + Err extends RiverError, >( state: State, - proc: Procedure, - init?: Init extends PayloadType ? Static : null, + proc: Procedure, + init?: Static, extendedContext?: Omit, session: Session = dummySession(), -): [WriteStream>, ReadStream>] { - const inputPipe = createPipe>(); - const outputPipe = createPipe>(); +): [WriteStream>, ReadStream>] { + const inputPipe = createPipe>(); + const outputPipe = createPipe>(); - void (async () => { - if (init) { - const _proc = proc as Procedure; - await _proc - - .handler( - dummyCtx(state, session, extendedContext), - init, - inputPipe.reader, - outputPipe.writer, - ) - .catch((err: unknown) => outputPipe.writer.write(catchProcError(err))); - } else { - const _proc = proc as Procedure; - await _proc - .handler( - dummyCtx(state, session, extendedContext), - inputPipe.reader, - outputPipe.writer, - ) - .catch((err: unknown) => outputPipe.writer.write(catchProcError(err))); - } - })(); + void proc + .handler( + dummyCtx(state, session, extendedContext), + init ?? {}, + inputPipe.reader, + outputPipe.writer, + ) + .catch((err: unknown) => outputPipe.writer.write(catchProcError(err))); return [inputPipe.writer, outputPipe.reader]; } export function asClientSubscription< State extends object, - I extends PayloadType, - O extends PayloadType, - E extends RiverError, + Init extends PayloadType, + Output extends PayloadType, + Err extends RiverError, >( state: State, - proc: Procedure, + proc: Procedure, extendedContext?: Omit, session: Session = dummySession(), -): (msg: Static) => ReadStream> { - const outputPipe = createPipe>(); +): (msg: Static) => ReadStream> { + const outputPipe = createPipe>(); - return (msg: Static) => { - void (async () => { - await proc - .handler( - dummyCtx(state, session, extendedContext), - msg, - outputPipe.writer, - ) - .catch((err: unknown) => outputPipe.writer.write(catchProcError(err))); - })(); + return (msg: Static) => { + void proc + .handler( + dummyCtx(state, session, extendedContext), + msg, + outputPipe.writer, + ) + .catch((err: unknown) => outputPipe.writer.write(catchProcError(err))); return outputPipe.reader; }; @@ -297,35 +279,26 @@ export function asClientSubscription< export function asClientUpload< State extends object, - I extends PayloadType, - O extends PayloadType, - E extends RiverError, - Init extends PayloadType | null = null, + Init extends PayloadType, + Input extends PayloadType, + Output extends PayloadType, + Err extends RiverError, >( state: State, - proc: Procedure, - init?: Init extends PayloadType ? Static : null, + proc: Procedure, + init?: Static, extendedContext?: Omit, session: Session = dummySession(), -): [WriteStream>, Promise>] { - const inputPipe = createPipe>(); - if (init) { - const _proc = proc as Procedure; - const result = _proc - .handler( - dummyCtx(state, session, extendedContext), - init, - inputPipe.reader, - ) - .catch(catchProcError); - return [inputPipe.writer, result]; - } else { - const _proc = proc as Procedure; - const result = _proc - .handler(dummyCtx(state, session, extendedContext), inputPipe.reader) - .catch(catchProcError); - return [inputPipe.writer, result]; - } +): [WriteStream>, Promise>] { + const inputPipe = createPipe>(); + const result = proc + .handler( + dummyCtx(state, session, extendedContext), + init ?? {}, + inputPipe.reader, + ) + .catch(catchProcError); + return [inputPipe.writer, result]; } export const getUnixSocketPath = () => {