From 1bdcb6bcdf9d5d72a37acc77847ab3803fee4545 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=A1clav=20Kubern=C3=A1t?= Date: Wed, 25 Sep 2024 17:23:38 +0200 Subject: [PATCH] Parse RpcMessages with Zod --- package.json | 2 +- src/chainpack.ts | 2 +- src/cpon.ts | 2 +- src/rpcmessage.ts | 237 +++++++++++----------------------------------- src/rpcvalue.ts | 8 +- src/ws-client.ts | 128 +++++++++++++++++-------- src/zod.ts | 30 +++++- 7 files changed, 182 insertions(+), 227 deletions(-) diff --git a/package.json b/package.json index b8e23ec..4715d56 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "libshv-js", - "version": "3.3.6", + "version": "3.4.0", "description": "Typescript implementation of libshv", "scripts": { "test": "echo \"Error: no test specified\" && exit 1", diff --git a/src/chainpack.ts b/src/chainpack.ts index e96d8bf..40d3ed6 100644 --- a/src/chainpack.ts +++ b/src/chainpack.ts @@ -100,7 +100,7 @@ class ChainPackReader { } const implReturn = (x: RpcValueType) => { - const ret = meta !== undefined ? new RpcValueWithMetaData(x, meta) : x; + const ret = meta !== undefined ? new RpcValueWithMetaData(meta, x) : x; return ret; }; diff --git a/src/cpon.ts b/src/cpon.ts index b0331e9..41a37da 100644 --- a/src/cpon.ts +++ b/src/cpon.ts @@ -71,7 +71,7 @@ class CponReader { } const implReturn = (x: RpcValueType) => { - const ret = meta !== undefined ? new RpcValueWithMetaData(x, meta) : x; + const ret = meta !== undefined ? new RpcValueWithMetaData(meta, x) : x; return ret; }; diff --git a/src/rpcmessage.ts b/src/rpcmessage.ts index fed1cd5..f0b975c 100644 --- a/src/rpcmessage.ts +++ b/src/rpcmessage.ts @@ -1,19 +1,13 @@ -import {type IMap, type Int, isIMap, makeIMap, makeMetaMap, type MetaMap, type RpcValue, RpcValueWithMetaData} from './rpcvalue'; -import {toCpon} from './cpon'; -import {toChainPack} from './chainpack'; +import * as z from './zod'; -enum RpcMessageTag { - RequestId = 8, - ShvPath = 9, - Method = 10, - CallerIds = 11, -} +export const RPC_MESSAGE_REQUEST_ID = 8; +export const RPC_MESSAGE_SHV_PATH = 9; +export const RPC_MESSAGE_METHOD = 10; +export const RPC_MESSAGE_CALLER_IDS = 11; -enum RpcMessageKey { - Params = 1, - Result = 2, - Error = 3, -} +export const RPC_MESSAGE_PARAMS = 1; +export const RPC_MESSAGE_RESULT = 2; +export const RPC_MESSAGE_ERROR = 3; export const ERROR_CODE = 1; export const ERROR_MESSAGE = 2; @@ -34,169 +28,52 @@ enum ErrorCode { NotImplemented = 12, } -export type ErrorMap = { - [ERROR_CODE]: ErrorCode; - [ERROR_MESSAGE]?: string; - [ERROR_DATA]?: RpcValue; -}; - -class RpcError extends Error { - constructor(private readonly err_info: ErrorMap) { - super(err_info[ERROR_MESSAGE] ?? 'Unknown RpcError'); - } - - data() { - return this.err_info[ERROR_DATA]; - } -} - -export class ProtocolError extends Error {} - -export class InvalidRequest extends RpcError {} -export class MethodNotFound extends RpcError {} -export class InvalidParams extends RpcError {} -export class InternalError extends RpcError {} -export class ParseError extends RpcError {} -export class MethodCallTimeout extends RpcError {} -export class MethodCallCancelled extends RpcError {} -export class MethodCallException extends RpcError {} -export class Unknown extends RpcError {} -export class LoginRequired extends RpcError {} -export class UserIDRequired extends RpcError {} -export class NotImplemented extends RpcError {} - -class RpcMessage { - value: IMap; - meta: MetaMap; - constructor(rpc_val?: RpcValue) { - if (rpc_val === undefined) { - this.value = makeIMap({}); - this.meta = makeMetaMap({}); - return; - } - - if (!(rpc_val instanceof RpcValueWithMetaData && isIMap(rpc_val.value))) { - throw new TypeError(`RpcMessage initialized with a non-IMap: ${toCpon(rpc_val)}`); - } - - this.value = rpc_val.value; - this.meta = rpc_val.meta; - } - - isValid() { - return this.shvPath() && (this.isRequest() || this.isResponse || this.isSignal()); - } - - isRequest(): boolean { - return this.requestId() !== undefined && this.method() !== undefined; - } - - isResponse(): boolean { - return this.requestId() !== undefined && this.method() === undefined; - } - - isSignal(): boolean { - return this.requestId() === undefined && this.method() !== undefined; - } - - requestId(): Int | undefined { - return (this.meta[RpcMessageTag.RequestId] as Int); - } - - setRequestId(id: number) { - this.meta[RpcMessageTag.RequestId] = id; - } - - callerIds(): RpcValue[] | undefined { - return this.meta[RpcMessageTag.CallerIds] as RpcValue[]; - } - - setCallerIds(ids: RpcValue[]) { - this.meta[RpcMessageTag.CallerIds] = ids; - } - - shvPath(): string | undefined { - return (this.meta[RpcMessageTag.ShvPath] as string); - } - - setShvPath(val: string) { - this.meta[RpcMessageTag.ShvPath] = val; - } - - method(): string | undefined { - return (this.meta[RpcMessageTag.Method] as string); - } - - setMethod(val: string) { - this.meta[RpcMessageTag.Method] = val; - } - - params() { - return this.value[RpcMessageKey.Params] as RpcValue; - } - - setParams(params: RpcValue) { - this.value[RpcMessageKey.Params] = params; - } - - resultOrError() { - if (Object.hasOwn(this.value, RpcMessageKey.Error)) { - if (!isIMap(this.value[RpcMessageKey.Error])) { - return new ProtocolError('Response had an error, but this error was not a map'); - } - - const errorMap = this.value[RpcMessageKey.Error]; - if (typeof errorMap[ERROR_CODE] !== 'number') { - return new ProtocolError('Response had an error, but this error did not contain at least an error code'); - } - - const code = errorMap[ERROR_CODE] as unknown; - - const ErrorTypeCtor = (() => { - switch (code) { - case ErrorCode.InvalidRequest: return InvalidRequest; - case ErrorCode.MethodNotFound: return MethodNotFound; - case ErrorCode.InvalidParams: return InvalidParams; - case ErrorCode.InternalError: return InternalError; - case ErrorCode.ParseError: return ParseError; - case ErrorCode.MethodCallTimeout: return MethodCallTimeout; - case ErrorCode.MethodCallCancelled: return MethodCallCancelled; - case ErrorCode.MethodCallException: return MethodCallException; - case ErrorCode.Unknown: return Unknown; - case ErrorCode.LoginRequired: return LoginRequired; - case ErrorCode.UserIDRequired: return UserIDRequired; - case ErrorCode.NotImplemented: return NotImplemented; - default: return Unknown; - } - })(); - - return new ErrorTypeCtor(this.value[RpcMessageKey.Error] as IMap); - } - - if (Object.hasOwn(this.value, RpcMessageKey.Result)) { - return this.value[RpcMessageKey.Result] as RpcValue; - } - - return new ProtocolError('Response included neither result nor error'); - } - - setResult(result: RpcValue) { - this.value[RpcMessageKey.Result] = result; - } - - setError(error: string) { - this.value[RpcMessageKey.Error] = error; - } - - toCpon() { - return toCpon(new RpcValueWithMetaData(this.value, this.meta)); - } - - toChainPack() { - return toChainPack(new RpcValueWithMetaData(this.value, this.meta)); - } -} - -export type RpcResponse = T | Error; - -export {RpcMessage, RpcError, ErrorCode}; +const ErrorMapZod = z.imap({ + [ERROR_CODE]: z.number(), + [ERROR_MESSAGE]: z.string().optional(), + [ERROR_DATA]: z.rpcvalue().optional(), +}); + +export type ErrorMap = z.infer; + +const RpcRequestMetaZod = z.metamap({ + [RPC_MESSAGE_REQUEST_ID]: z.number(), + [RPC_MESSAGE_METHOD]: z.string(), + [RPC_MESSAGE_SHV_PATH]: z.string(), +}); + +const RpcRequestValueZod = z.imap({ + [RPC_MESSAGE_PARAMS]: z.rpcvalue().optional(), +}); + +const RpcResponseMetaZod = z.metamap({ + [RPC_MESSAGE_REQUEST_ID]: z.number(), +}); +const RpcResponseValueZod = z.imap({ + [RPC_MESSAGE_RESULT]: z.rpcvalue(), +}).or(z.imap({ + [RPC_MESSAGE_ERROR]: ErrorMapZod, +})); + +const RpcSignalMetaZod = z.metamap({ + [RPC_MESSAGE_SHV_PATH]: z.string(), + [RPC_MESSAGE_METHOD]: z.string(), +}); +const RpcSignalValueZod = z.imap({ + [RPC_MESSAGE_PARAMS]: z.rpcvalue().optional(), +}); + +const RpcRequestZod = z.withMeta(RpcRequestMetaZod, RpcRequestValueZod); +const RpcResponseZod = z.withMeta(RpcResponseMetaZod, RpcResponseValueZod); +const RpcSignalZod = z.withMeta(RpcSignalMetaZod, RpcSignalValueZod); +const RpcMessageZod = z.union([RpcRequestZod, RpcResponseZod, RpcSignalZod]); +export type RpcRequest = z.infer; +export type RpcResponse = z.infer; +export type RpcSignal = z.infer; +export type RpcMessage = z.infer; + +export const isSignal = (message: RpcMessage): message is RpcSignal => !(RPC_MESSAGE_REQUEST_ID in message.meta) && RPC_MESSAGE_METHOD in message.meta; +export const isRequest = (message: RpcMessage): message is RpcRequest => RPC_MESSAGE_REQUEST_ID in message.meta && RPC_MESSAGE_METHOD in message.meta; +export const isResponse = (message: RpcMessage): message is RpcResponse => RPC_MESSAGE_REQUEST_ID in message.meta && !(RPC_MESSAGE_METHOD in message.meta); + +export {RpcMessageZod, ErrorCode}; diff --git a/src/rpcvalue.ts b/src/rpcvalue.ts index ae09ac4..69c1b96 100644 --- a/src/rpcvalue.ts +++ b/src/rpcvalue.ts @@ -97,6 +97,8 @@ const isShvMap = (x: unknown): x is ShvMap => typeof x === 'object' && (x as Shv const isIMap = (x: unknown): x is IMap => typeof x === 'object' && (x as IMap)[shvMapType] === 'imap'; +const isMetaMap = (x: unknown): x is IMap => typeof x === 'object' && (x as MetaMap)[shvMapType] === 'metamap'; + const makeMetaMap = = Record, U extends Record = Omit>(x: U = {} as U): MetaMap => ({ ...x, [shvMapType]: 'metamap', @@ -112,10 +114,10 @@ const makeMap = = Record, [shvMapType]: 'map', }); -class RpcValueWithMetaData { - constructor(public value: RpcValueType, public meta: MetaMap) {} +class RpcValueWithMetaData { + constructor(public meta: MetaSchema, public value: ValueSchema) {} } export type RpcValue = RpcValueType | RpcValueWithMetaData; -export {shvMapType, Decimal, Double, type IMap, type MetaMap, RpcValueWithMetaData, type ShvMap, UInt, withOffset, makeMap, makeIMap, makeMetaMap, isIMap, isShvMap}; +export {shvMapType, Decimal, Double, type IMap, type MetaMap, RpcValueWithMetaData, type ShvMap, UInt, withOffset, makeMap, makeIMap, makeMetaMap, isIMap, isMetaMap, isShvMap}; diff --git a/src/ws-client.ts b/src/ws-client.ts index 6280e24..e68226e 100644 --- a/src/ws-client.ts +++ b/src/ws-client.ts @@ -1,7 +1,7 @@ -import {ChainPackReader, CHAINPACK_PROTOCOL_TYPE, ChainPackWriter} from './chainpack'; -import {type CponReader, CPON_PROTOCOL_TYPE} from './cpon'; -import {ERROR_MESSAGE, ErrorCode, ERROR_CODE, RpcMessage, type RpcResponse, MethodCallTimeout} from './rpcmessage'; -import {type RpcValue, type Null, type Int, type IMap, type ShvMap, makeMap, makeIMap} from './rpcvalue'; +import {ChainPackReader, CHAINPACK_PROTOCOL_TYPE, ChainPackWriter, toChainPack} from './chainpack'; +import {type CponReader, CPON_PROTOCOL_TYPE, toCpon} from './cpon'; +import {ERROR_MESSAGE, ErrorCode, ERROR_CODE, RpcMessageZod, type RpcMessage, isSignal, isRequest, type RpcRequest, isResponse, ERROR_DATA, type ErrorMap, RPC_MESSAGE_METHOD, RPC_MESSAGE_SHV_PATH, RPC_MESSAGE_REQUEST_ID, RPC_MESSAGE_PARAMS, RPC_MESSAGE_ERROR, RPC_MESSAGE_RESULT} from './rpcmessage'; +import {type RpcValue, type Null, type Int, type IMap, type ShvMap, makeMap, makeIMap, RpcValueWithMetaData, makeMetaMap} from './rpcvalue'; const DEFAULT_TIMEOUT = 5000; const DEFAULT_PING_INTERVAL = 30 * 1000; @@ -25,7 +25,9 @@ const dataToRpcValue = (buff: ArrayBuffer) => { type SubscriptionCallback = (path: string, method: string, param?: RpcValue) => void; -type RpcResponseResolver = (rpc_msg: RpcResponse) => void; +type ResultOrError = T | Error; + +type RpcResponseResolver = (rpc_msg: ResultOrError) => void; type Subscription = { subscriber: string; @@ -46,7 +48,7 @@ type WsClientOptions = { onConnected: () => void; onConnectionFailure: (error: Error) => void; onDisconnected: () => void; - onRequest: (rpc_msg: RpcMessage) => void; + onRequest: (rpc_msg: RpcRequest) => void; }; type LsResult = string[]; @@ -75,6 +77,31 @@ type DirResult = Array>; +class RpcError extends Error { + constructor(private readonly err_info: ErrorMap) { + super(err_info[ERROR_MESSAGE] ?? 'Unknown RpcError'); + } + + data() { + return this.err_info[ERROR_DATA]; + } +} + +class ProtocolError extends Error {} + +class InvalidRequest extends RpcError {} +class MethodNotFound extends RpcError {} +class InvalidParams extends RpcError {} +class InternalError extends RpcError {} +class ParseError extends RpcError {} +class MethodCallTimeout extends RpcError {} +class MethodCallCancelled extends RpcError {} +class MethodCallException extends RpcError {} +class Unknown extends RpcError {} +class LoginRequired extends RpcError {} +class UserIDRequired extends RpcError {} +class NotImplemented extends RpcError {} + class WsClient { requestId = 1; pingTimerId = -1; @@ -171,30 +198,56 @@ class WsClient { this.websocket.addEventListener('message', (evt: MessageEvent) => { const rpcVal = dataToRpcValue(evt.data); - const rpcMsg = new RpcMessage(rpcVal); - this.logDebug(`message received: ${rpcMsg.toCpon()}`); + const rpcMsg = RpcMessageZod.parse(rpcVal); + this.logDebug(`message received: ${toCpon(rpcMsg)}`); - if (rpcMsg.isSignal()) { + if (isSignal(rpcMsg)) { for (const sub of this.subscriptions) { - const shvPath = rpcMsg.shvPath(); - const method = rpcMsg.method(); + const shvPath = rpcMsg.meta[RPC_MESSAGE_SHV_PATH]; + const method = rpcMsg.meta[RPC_MESSAGE_METHOD]; - if (shvPath?.startsWith(sub.path) && method === sub.method) { - sub.callback(shvPath, method, rpcMsg.params()); + if (shvPath.startsWith(sub.path) && method === sub.method) { + sub.callback(shvPath, method, rpcMsg.value[RPC_MESSAGE_PARAMS]); } } - } else if (rpcMsg.isRequest()) { + } else if (isRequest(rpcMsg)) { this.onRequest(rpcMsg); - } else if (rpcMsg.isResponse()) { - const requestId = rpcMsg.requestId(); - if (requestId === undefined) { - throw new Error('got RpcResponse without requestId'); - } + } else if (isResponse(rpcMsg)) { + const requestId = rpcMsg.meta[RPC_MESSAGE_REQUEST_ID]; if (this.rpcHandlers[Number(requestId)] !== undefined) { const handler = this.rpcHandlers[Number(requestId)]; clearTimeout(handler.timeout_handle); - handler.resolve(rpcMsg.resultOrError()); + handler.resolve((() => { + if (RPC_MESSAGE_ERROR in rpcMsg.value) { + const code = rpcMsg.value[RPC_MESSAGE_ERROR][ERROR_CODE] as unknown; + const ErrorTypeCtor = (() => { + switch (code) { + case ErrorCode.InvalidRequest: return InvalidRequest; + case ErrorCode.MethodNotFound: return MethodNotFound; + case ErrorCode.InvalidParams: return InvalidParams; + case ErrorCode.InternalError: return InternalError; + case ErrorCode.ParseError: return ParseError; + case ErrorCode.MethodCallTimeout: return MethodCallTimeout; + case ErrorCode.MethodCallCancelled: return MethodCallCancelled; + case ErrorCode.MethodCallException: return MethodCallException; + case ErrorCode.Unknown: return Unknown; + case ErrorCode.LoginRequired: return LoginRequired; + case ErrorCode.UserIDRequired: return UserIDRequired; + case ErrorCode.NotImplemented: return NotImplemented; + default: return Unknown; + } + })(); + + return new ErrorTypeCtor(rpcMsg.value[RPC_MESSAGE_ERROR]); + } + + if (RPC_MESSAGE_RESULT in rpcMsg.value) { + return rpcMsg.value[RPC_MESSAGE_RESULT]; + } + + return new ProtocolError('Response included neither result nor error'); + })()); // eslint-disable-next-line @typescript-eslint/no-array-delete, @typescript-eslint/no-dynamic-delete delete this.rpcHandlers[Number(requestId)]; } @@ -207,30 +260,29 @@ class WsClient { }); } - callRpcMethod(shv_path: '.broker/currentClient', method: 'accessGrantForMethodCall', params: [string, string]): Promise>; - callRpcMethod(shv_path: string | undefined, method: 'dir', params?: RpcValue): Promise>; - callRpcMethod(shv_path: string | undefined, method: 'ls', params?: RpcValue): Promise>; - callRpcMethod(shv_path: string | undefined, method: string, params?: RpcValue): Promise; - callRpcMethod(shv_path: string | undefined, method: string, params?: RpcValue) { - const rq = new RpcMessage(); + callRpcMethod(shv_path: '.broker/currentClient', method: 'accessGrantForMethodCall', params: [string, string]): Promise>; + callRpcMethod(shv_path: string | undefined, method: 'dir', params?: RpcValue): Promise>; + callRpcMethod(shv_path: string | undefined, method: 'ls', params?: RpcValue): Promise>; + callRpcMethod(shv_path: string | undefined, method: string, params?: RpcValue): Promise; + callRpcMethod(shv_path: string | undefined, method: string, params?: RpcValue): Promise { const rqId = this.requestId++; - rq.setRequestId(rqId); - if (shv_path !== undefined) { - rq.setShvPath(shv_path); - } - - rq.setMethod(method); - if (params !== undefined) { - rq.setParams(params); - } + const rq: RpcRequest = new RpcValueWithMetaData(makeMetaMap({ + [RPC_MESSAGE_REQUEST_ID]: rqId, + [RPC_MESSAGE_METHOD]: method ?? '', + [RPC_MESSAGE_SHV_PATH]: shv_path ?? '', + }), makeIMap({ + [RPC_MESSAGE_PARAMS]: params, + }), + ); this.sendRpcMessage(rq); - const promise = new Promise(resolve => { + const promise = new Promise(resolve => { this.rpcHandlers[rqId] = {resolve, timeout_handle: self.setTimeout(() => { resolve(new MethodCallTimeout(makeIMap({ [ERROR_CODE]: ErrorCode.MethodCallTimeout, [ERROR_MESSAGE]: `Shv call timeout after: ${this.timeout} msec.`, + [ERROR_DATA]: undefined, }))); }, this.timeout)}; }); @@ -240,8 +292,8 @@ class WsClient { sendRpcMessage(rpc_msg: RpcMessage) { if (this.websocket && this.websocket.readyState === WebSocket.OPEN) { - this.logDebug('sending rpc message:', rpc_msg.toCpon()); - const msgData = new Uint8Array(rpc_msg.toChainPack()); + this.logDebug('sending rpc message:', toCpon(rpc_msg)); + const msgData = new Uint8Array(toChainPack(rpc_msg)); const wr = new ChainPackWriter(); wr.writeUIntData(msgData.length + 1); diff --git a/src/zod.ts b/src/zod.ts index 7284937..621d547 100644 --- a/src/zod.ts +++ b/src/zod.ts @@ -1,11 +1,35 @@ import {z, type ZodType} from 'zod'; -import {type IMap, isIMap, isShvMap, type RpcValue, type ShvMap, UInt} from './rpcvalue'; +import {Decimal, Double, type IMap, isIMap, isMetaMap, isShvMap, type MetaMap, type RpcValue, type RpcValueType, RpcValueWithMetaData, type ShvMap, UInt} from './rpcvalue'; -export const map = >>(schema: T) => z.custom}>>((data: RpcValue) => isShvMap(data) && z.object(schema).safeParse(data).success); +export const map = >>(schema?: T) => z.custom}>>((data: RpcValue) => isShvMap(data) && (schema === undefined || z.object(schema).safeParse(data).success)); export const recmap = >(schema: T) => z.custom>>>((data: RpcValue) => isShvMap(data) && z.record(z.string(), schema).safeParse(data).success); -export const imap = >>(schema: T) => z.custom}>>((data: RpcValue) => isIMap(data) && z.object(schema).safeParse(data).success); +export const imap = >>(schema?: T) => z.custom}>>((data: RpcValue) => isIMap(data) && (schema === undefined || z.object(schema).safeParse(data).success)); +export const metamap = >>(schema?: T) => z.custom}>>((data: RpcValue) => isMetaMap(data) && (schema === undefined || z.object(schema).safeParse(data).success)); export const int = () => z.number(); // eslint-disable-next-line @typescript-eslint/no-unnecessary-type-arguments -- Zod needs the default argument, otherwise it'll infer as UInt export const uint = () => z.instanceof(UInt); +const withMetaInstanceParser = z.instanceof(RpcValueWithMetaData); +export const rpcvalue = () => z.union([ + z.undefined(), + z.boolean(), + z.number(), + uint(), + z.instanceof(Double), + z.instanceof(Decimal), + z.instanceof(ArrayBuffer), + z.string(), + z.date(), + z.array(z.any()), + map(), + imap(), + withMetaInstanceParser, +]); + +export const withMeta = (metaParser: ZodType, valueParser: ZodType) => + z.custom, z.infer>>((data: any) => withMetaInstanceParser.and(z.object({ + meta: metaParser, + value: valueParser, + })).safeParse(data).success); + export * from 'zod';