From c14067f12958a21ab9c587a9266e6f896a8c473c Mon Sep 17 00:00:00 2001 From: Tomas Pilar Date: Fri, 15 Dec 2023 21:50:51 +0100 Subject: [PATCH] fixup! sse streaming client Signed-off-by: Tomas Pilar --- src/api/event-client.ts | 121 ++++++++++++++++++++++++++++++++++++++++ src/client.ts | 31 +++++----- src/index.ts | 1 - src/utils/stream.ts | 109 ------------------------------------ 4 files changed, 136 insertions(+), 126 deletions(-) create mode 100644 src/api/event-client.ts diff --git a/src/api/event-client.ts b/src/api/event-client.ts new file mode 100644 index 0000000..fd5cae3 --- /dev/null +++ b/src/api/event-client.ts @@ -0,0 +1,121 @@ +import { + EventStreamContentType, + fetchEventSource, +} from '@ai-zen/node-fetch-event-source'; + +import { TypedReadable } from '../utils/stream.js'; +import { BaseError, HttpError, InternalError } from '../errors.js'; +import { safeParseJson } from '../helpers/common.js'; +import { RawHeaders } from '../client.js'; + +export interface ApiEventClient { + stream: (opts: { + url: string; + headers?: RawHeaders; + body?: any; + signal?: AbortSignal; + }) => TypedReadable; +} + +export function createApiEventClient(clientOptions: { + baseUrl?: string; + headers?: RawHeaders; +}): ApiEventClient { + return { + stream: function fetchSSE({ + url, + headers, + body, + signal, + }: Parameters[0]) { + const outputStream = new TypedReadable({ + autoDestroy: true, + objectMode: true, + signal: signal, + }); + + const onClose = () => { + if (outputStream.readable) { + outputStream.push(null); + } + }; + + const delegatedController = new AbortController(); + if (signal) { + signal.addEventListener( + 'abort', + () => { + delegatedController.abort(); + }, + { + once: true, + }, + ); + } + + const onError = (e: unknown) => { + const err = + e instanceof BaseError + ? e + : new InternalError('Unexpected error', { cause: e }); + + delegatedController.abort(); + if (outputStream.readable) { + outputStream.emit('error', err); + throw err; + } + onClose(); + }; + fetchEventSource(new URL(url, clientOptions.baseUrl).toString(), { + method: 'POST', + body: JSON.stringify(body), + headers: { + ...clientOptions.headers, + ...headers, + 'Content-Type': 'application/json', + }, + signal: delegatedController.signal, + onclose: onClose, + async onopen(response) { + const contentType = response.headers.get('content-type') || ''; + + if (response.ok && contentType === EventStreamContentType) { + return; + } + + const responseData = contentType.startsWith('application/json') + ? await response.json().catch(() => null) + : null; + + onError(new HttpError(responseData)); + }, + onmessage(message) { + if (message.event === 'close') { + onClose(); + return; + } + if (message.data === '') { + return; + } + + const result = safeParseJson(message.data); + if (result === null) { + onError( + new InternalError( + `Failed to parse message "${JSON.stringify(message)}"`, + ), + ); + return; + } + + outputStream.push(result); + }, + onerror: onError, + }).catch(() => { + /* Prevent uncaught exception (errors are handled inside the stream) */ + }); + + return outputStream; + }, + }; +} diff --git a/src/client.ts b/src/client.ts index 7429279..dd383a2 100644 --- a/src/client.ts +++ b/src/client.ts @@ -12,9 +12,9 @@ import { ApiClientResponse, createApiClient, } from './api/client.js'; -import { fetchSSE } from './utils/stream.js'; import { clientErrorWrapper } from './utils/errors.js'; import { OmitVersion } from './utils/types.js'; +import { ApiEventClient, createApiEventClient } from './api/event-client.js'; export type RawHeaders = Record; @@ -28,8 +28,7 @@ export type Options = { signal?: AbortSignal }; export class Client { readonly #client: ApiClient; - readonly #endpoint: string; - readonly #headers: RawHeaders; + readonly #eventClient: ApiEventClient; constructor(config: Configuration = {}) { const endpoint = config.endpoint ?? lookupEndpoint(); @@ -44,8 +43,7 @@ export class Client { const agent = version ? `node-sdk/${version}` : 'node-sdk'; - this.#endpoint = endpoint; - this.#headers = { + const headers = { 'User-Agent': agent, 'X-Request-Origin': agent, ...config.headers, @@ -53,10 +51,14 @@ export class Client { Authorization: `Bearer ${apiKey}`, }; this.#client = createApiClient({ - baseUrl: this.#endpoint, - headers: this.#headers, + baseUrl: endpoint, + headers, fetch: fetchRetry(fetch) as any, // https://github.com/jonbern/fetch-retry/issues/89 }); + this.#eventClient = createApiEventClient({ + baseUrl: endpoint, + headers, + }); } async models( @@ -137,15 +139,12 @@ export class Client { }, }); - fetchSSE({ - url: new URL( - `/v2/text/generation_stream?version=2023-11-22`, - this.#endpoint, - ), - headers: this.#headers, - body: input, - signal: opts?.signal, - }) + this.#eventClient + .stream({ + url: '/v2/text/generation_stream?version=2023-11-22', + body: input, + signal: opts?.signal, + }) .on('error', (err) => stream.emit('error', err)) .pipe(stream); diff --git a/src/index.ts b/src/index.ts index 3268530..8583a02 100644 --- a/src/index.ts +++ b/src/index.ts @@ -2,4 +2,3 @@ export * from './client.js'; export * from './errors.js'; export * from './buildInfo.js'; -export * from './constants.js'; diff --git a/src/utils/stream.ts b/src/utils/stream.ts index 7ba7209..25972d3 100644 --- a/src/utils/stream.ts +++ b/src/utils/stream.ts @@ -1,14 +1,5 @@ import { Readable } from 'stream'; -import { - EventStreamContentType, - fetchEventSource, -} from '@ai-zen/node-fetch-event-source'; - -import { BaseError, HttpError, InternalError } from '../errors.js'; -import { safeParseJson } from '../helpers/common.js'; -import { RawHeaders } from '../client.js'; - export class TypedReadable extends Readable { // eslint-disable-next-line @typescript-eslint/no-unused-vars _read(size: number) { @@ -44,103 +35,3 @@ export class TypedReadable extends Readable { return super[Symbol.asyncIterator](); } } - -export function fetchSSE({ - url, - headers, - body, - signal, -}: { - url: URL; - headers: RawHeaders; - body: any; - signal?: AbortSignal; -}) { - const outputStream = new TypedReadable({ - autoDestroy: true, - objectMode: true, - signal: signal, - }); - - const onClose = () => { - if (outputStream.readable) { - outputStream.push(null); - } - }; - - const delegatedController = new AbortController(); - if (signal) { - signal.addEventListener( - 'abort', - () => { - delegatedController.abort(); - }, - { - once: true, - }, - ); - } - - const onError = (e: unknown) => { - const err = - e instanceof BaseError - ? e - : new InternalError('Unexpected error', { cause: e }); - - delegatedController.abort(); - if (outputStream.readable) { - outputStream.emit('error', err); - throw err; - } - onClose(); - }; - fetchEventSource(url.toString(), { - method: 'POST', - body: JSON.stringify(body), - headers: { - ...headers, - 'Content-Type': 'application/json', - }, - signal: delegatedController.signal, - onclose: onClose, - async onopen(response) { - const contentType = response.headers.get('content-type') || ''; - - if (response.ok && contentType === EventStreamContentType) { - return; - } - - const responseData = contentType.startsWith('application/json') - ? await response.json().catch(() => null) - : null; - - onError(new HttpError(responseData)); - }, - onmessage(message) { - if (message.event === 'close') { - onClose(); - return; - } - if (message.data === '') { - return; - } - - const result = safeParseJson(message.data); - if (result === null) { - onError( - new InternalError( - `Failed to parse message "${JSON.stringify(message)}"`, - ), - ); - return; - } - - outputStream.push(result); - }, - onerror: onError, - }).catch(() => { - /* Prevent uncaught exception (errors are handled inside the stream) */ - }); - - return outputStream; -}