From 23b4658dca71dea7c53da1acacbb5a1c258e50ff Mon Sep 17 00:00:00 2001 From: Tomas Pilar Date: Tue, 24 Oct 2023 13:22:20 +0200 Subject: [PATCH] feat(chat): add chat support to the client (#57) Signed-off-by: Tomas Pilar --- examples/chat.ts | 82 +++++++++++++++++ package.json | 3 +- src/api-types.ts | 40 ++++++++- src/client/client.ts | 128 +++++++++++++++++++++++---- src/client/types.ts | 21 +++++ src/helpers/common.ts | 23 ++++- src/helpers/types.ts | 7 ++ src/tests/e2e/client.test.ts | 97 +++++++++++++++++++- src/tests/integration/client.test.ts | 28 +++++- src/tests/mocks/handlers.ts | 32 +++++++ 10 files changed, 437 insertions(+), 24 deletions(-) create mode 100644 examples/chat.ts diff --git a/examples/chat.ts b/examples/chat.ts new file mode 100644 index 0000000..1217790 --- /dev/null +++ b/examples/chat.ts @@ -0,0 +1,82 @@ +import { Client } from '../src/index.js'; + +const client = new Client({ + apiKey: process.env.GENAI_API_KEY, +}); + +const model_id = 'google/ul2'; + +{ + // Start a conversation + const { + conversation_id, + result: { generated_text: answer1 }, + } = await client.chat({ + model_id, + messages: [ + { + role: 'system', + content: 'Answer yes or no', + }, + { + role: 'user', + content: 'Hello, are you a robot?', + }, + ], + }); + console.log(answer1); + + // Continue the conversation + const { + result: { generated_text: answer2 }, + } = await client.chat({ + conversation_id, + model_id, + messages: [ + { + role: 'user', + content: 'Are you sure?', + }, + ], + }); + console.log(answer2); +} + +{ + // Chat inteface has the same promise, streaming and callback variants as generate interface + + // Promise + const data = await client.chat({ + model_id, + messages: [{ role: 'user', content: 'How are you?' }], + }); + console.log(data.result.generated_text); + // Callback + client.chat( + { model_id, messages: [{ role: 'user', content: 'How are you?' }] }, + (err, data) => { + if (err) console.error(err); + else console.log(data.result.generated_text); + }, + ); + // Stream + for await (const chunk of client.chat( + { model_id, messages: [{ role: 'user', content: 'How are you?' }] }, + { stream: true }, + )) { + console.log(chunk.result.generated_text); + } + // Streaming callbacks + client.chat( + { + model_id: 'google/ul2', + messages: [{ role: 'user', content: 'How are you?' }], + }, + { stream: true }, + (err, data) => { + if (err) console.error(err); + else if (data) console.log(data.result.generated_text); + else console.log('EOS'); + }, + ); +} diff --git a/package.json b/package.json index b6f550b..f0972f7 100644 --- a/package.json +++ b/package.json @@ -60,7 +60,8 @@ "example:generate": "yarn run example:run examples/generate.ts", "example:tune": "yarn run example:run examples/tune.ts", "example:prompt-template": "yarn run example:run examples/prompt-templates.ts", - "example:file": "yarn run example:run examples/file.ts" + "example:file": "yarn run example:run examples/file.ts", + "example:chat": "yarn run example:run examples/chat.ts" }, "peerDependencies": { "langchain": ">=0.0.155" diff --git a/src/api-types.ts b/src/api-types.ts index 2342f8c..9d53f1b 100644 --- a/src/api-types.ts +++ b/src/api-types.ts @@ -61,11 +61,13 @@ export interface UserGenerateDefaultOutput { // GENERATE +const ParametersSchema = z.record(z.any()); + export const GenerateInputSchema = z.object({ model_id: z.string().nullish(), prompt_id: z.string().nullish(), inputs: z.array(z.string()), - parameters: z.optional(z.record(z.any())), + parameters: z.optional(ParametersSchema), use_default: z.optional(z.boolean()), }); export type GenerateInput = z.infer; @@ -399,3 +401,39 @@ export const FilesOutputSchema = PaginationOutputSchema.extend({ results: z.array(SingleFileOutputSchema), }); export type FilesOutput = z.output; + +// CHAT + +export const ChatRoleSchema = z.enum(['user', 'system', 'assistant']); +export type ChatRole = z.infer; + +export const ChatInputSchema = z.object({ + model_id: z.string(), + messages: z.array( + z.object({ + role: ChatRoleSchema, + content: z.string(), + }), + ), + conversation_id: z.string().nullish(), + parent_id: z.string().nullish(), + prompt_id: z.string().nullish(), + parameters: ParametersSchema.nullish(), +}); +export type ChatInput = z.input; +export const ChatOutputSchema = z.object({ + conversation_id: z.string(), + results: z.array( + z + .object({ + generated_text: z.string(), + }) + .partial(), + ), +}); +export type ChatOutput = z.output; + +export const ChatStreamInputSchema = ChatInputSchema; +export type ChatStreamInput = z.input; +export const ChatStreamOutputSchema = ChatOutputSchema; +export type ChatStreamOutput = z.output; diff --git a/src/client/client.ts b/src/client/client.ts index c07af46..b0b2c15 100644 --- a/src/client/client.ts +++ b/src/client/client.ts @@ -1,6 +1,6 @@ import http, { IncomingMessage } from 'node:http'; import https from 'node:https'; -import { Transform, TransformCallback } from 'stream'; +import { Transform, TransformCallback } from 'node:stream'; import axios, { AxiosError } from 'axios'; import FormData from 'form-data'; @@ -36,10 +36,13 @@ import { handleGenerator, paginator, isEmptyObject, + callbackifyStream, + callbackifyPromise, } from '../helpers/common.js'; import { TypedReadable } from '../utils/stream.js'; import { lookupApiKey, lookupEndpoint } from '../helpers/config.js'; import { RETRY_ATTEMPTS_DEFAULT } from '../constants.js'; +import { Callback } from '../helpers/types.js'; import { GenerateConfigInput, @@ -92,6 +95,12 @@ import { FilesInput, FileDeleteOutput, PromptTemplateDeleteOutput, + ChatInput, + ChatOutput, + ChatOptions, + ChatStreamOptions, + ChatStreamInput, + ChatStreamOutput, } from './types.js'; import { CacheDiscriminator, generateCacheKey } from './cache.js'; @@ -116,10 +125,6 @@ export interface Configuration { retries?: HttpHandlerOptions['retries']; } -type ErrorCallback = (err: unknown) => void; -type DataCallback = (err: unknown, result: T) => void; -export type Callback = ErrorCallback | DataCallback; - export class Client { readonly #client: AxiosCacheInstance; readonly #options: Required; @@ -484,12 +489,7 @@ export class Client { return stream; } - stream.on('data', (data) => callback(null, data)); - stream.on('error', (err) => (callback as ErrorCallback)(err)); - stream.on('finish', () => - (callback as DataCallback)(null, null), - ); - + callbackifyStream(stream)(callback); return; } @@ -549,12 +549,7 @@ export class Client { }); if (callback) { - promises.forEach((promise) => - promise.then( - (data) => callback(null as never, data), - (err) => (callback as ErrorCallback)(err), - ), - ); + promises.forEach((promise) => callbackifyPromise(promise)(callback)); } else { return Array.isArray(input) ? promises : promises[0]; } @@ -1320,4 +1315,103 @@ export class Client { return transformOutput(result); }); } + + chat(input: ChatInput, callback: Callback): void; + chat( + input: ChatInput, + options: ChatOptions, + callback: Callback, + ): void; + chat( + input: ChatStreamInput, + options: ChatStreamOptions, + callback: Callback, + ): void; + chat(input: ChatInput, options?: ChatOptions): Promise; + chat( + input: ChatStreamInput, + options?: ChatStreamOptions, + ): TypedReadable; + chat( + input: ChatInput | ChatStreamInput, + optionsOrCallback?: + | ChatOptions + | ChatStreamOptions + | Callback + | Callback, + callback?: Callback, + ): TypedReadable | Promise | void { + const { callback: cb, options } = parseFunctionOverloads( + undefined, + optionsOrCallback, + callback, + ); + + if (options?.stream) { + const stream = new Transform({ + autoDestroy: true, + objectMode: true, + transform( + chunk: ApiTypes.ChatStreamOutput, + encoding: BufferEncoding, + callback: TransformCallback, + ) { + const { results, ...rest } = chunk; + callback(null, { + ...rest, + result: results[0], + } as ChatStreamOutput); + }, + }); + this.#fetcher({ + ...options, + method: 'POST', + url: '/v0/generate/chat', + data: { + ...input, + parameters: { + ...input.parameters, + stream: true, + }, + }, + stream: true, + }) + .on('error', (err) => stream.emit('error', errorTransformer(err))) + .pipe(stream); + + if (cb) { + callbackifyStream(stream)(cb); + return; + } else { + return stream; + } + } else { + const promise = (async () => { + const { results, ...rest } = await this.#fetcher< + ApiTypes.ChatOutput, + ApiTypes.ChatInput + >( + { + ...options, + method: 'POST', + url: '/v0/generate/chat', + data: input, + stream: false, + }, + ApiTypes.ChatOutputSchema, + ); + if (results.length !== 1) { + throw new InternalError('Unexpected number of results'); + } + return { ...rest, result: results[0] }; + })(); + + if (cb) { + callbackifyPromise(promise)(cb); + return; + } else { + return promise; + } + } + } } diff --git a/src/client/types.ts b/src/client/types.ts index 1a93c33..6b28f9e 100644 --- a/src/client/types.ts +++ b/src/client/types.ts @@ -263,3 +263,24 @@ export type FileDeleteOptions = HttpHandlerOptions & FlagOption<'delete', true>; export const FilesOutputSchema = ApiTypes.FilesOutputSchema.shape.results.element; export type FilesOutput = z.output; + +// CHAT + +export const ChatInputSchema = z.union([ + ApiTypes.ChatInputSchema, + ApiTypes.ChatStreamInputSchema, +]); +export type ChatInput = z.input; +export type ChatOptions = HttpHandlerNoStreamOptions; +export const ChatOutputSchema = ApiTypes.ChatOutputSchema.omit({ + results: true, +}).extend({ result: ApiTypes.ChatOutputSchema.shape.results.element }); +export type ChatOutput = z.output; + +export const ChatStreamInputSchema = ApiTypes.ChatStreamInputSchema; +export type ChatStreamInput = z.input; +export type ChatStreamOptions = HttpHandlerStreamOptions; +export const ChatStreamOutputSchema = ApiTypes.ChatStreamOutputSchema.omit({ + results: true, +}).extend({ result: ApiTypes.ChatOutputSchema.shape.results.element }); +export type ChatStreamOutput = z.output; diff --git a/src/helpers/common.ts b/src/helpers/common.ts index abd99ec..7d2e63b 100644 --- a/src/helpers/common.ts +++ b/src/helpers/common.ts @@ -1,10 +1,10 @@ import { callbackify } from 'node:util'; import { URLSearchParams } from 'node:url'; +import { Readable } from 'node:stream'; import { z } from 'zod'; -export type FalsyValues = false | '' | 0 | null | undefined; -export type Truthy = T extends FalsyValues ? never : T; +import { ErrorCallback, DataCallback, Truthy, Callback } from './types.js'; export function isTruthy(value: T): value is Truthy { return Boolean(value); @@ -151,6 +151,25 @@ export function callbackifyGenerator(generatorFn: () => AsyncGenerator) { }; } +export function callbackifyStream(stream: Readable) { + return (callbackFn: Callback) => { + stream.on('data', (data) => callbackFn(null, data)); + stream.on('error', (err) => (callbackFn as ErrorCallback)(err)); + stream.on('finish', () => + (callbackFn as DataCallback)(null, null), + ); + }; +} + +export function callbackifyPromise(promise: Promise) { + return (callbackFn: Callback) => { + promise.then( + (data) => callbackFn(null, data), + (err) => (callbackFn as ErrorCallback)(err), + ); + }; +} + export async function* paginator( executor: (searchParams: URLSearchParams) => Promise<{ results: T[]; diff --git a/src/helpers/types.ts b/src/helpers/types.ts index c3be50e..3325fa5 100644 --- a/src/helpers/types.ts +++ b/src/helpers/types.ts @@ -4,3 +4,10 @@ export type RequiredPartial = Required> & export type FlagOption = T extends true ? { [k in Key]: true } : { [k in Key]?: false }; + +export type FalsyValues = false | '' | 0 | null | undefined; +export type Truthy = T extends FalsyValues ? never : T; + +export type ErrorCallback = (err: unknown) => void; +export type DataCallback = (err: unknown, result: T) => void; +export type Callback = ErrorCallback | DataCallback; diff --git a/src/tests/e2e/client.test.ts b/src/tests/e2e/client.test.ts index 2f3f297..95eff85 100644 --- a/src/tests/e2e/client.test.ts +++ b/src/tests/e2e/client.test.ts @@ -1,4 +1,8 @@ -import { GenerateInput, GenerateOutput } from '../../client/types.js'; +import { + ChatOutput, + GenerateInput, + GenerateOutput, +} from '../../client/types.js'; import { Client } from '../../client/client.js'; import { RequestCanceledError } from '../../errors.js'; @@ -180,6 +184,97 @@ describe('client', () => { }); }); + describe('chat', () => { + describe('streaming', () => { + const makeValidStream = () => + client.chat( + { + model_id: 'google/ul2', + messages: [{ role: 'user', content: 'Hello World!' }], + }, + { + stream: true, + }, + ); + + const validateStreamChunk = (chunk: ChatOutput) => { + expect(chunk).toBeObject(); + expect(chunk).toHaveProperty('conversation_id'); + expect(chunk).toHaveProperty('result'); + }; + + test('should return valid stream', async () => { + const stream = makeValidStream(); + + const chunks: ChatOutput[] = []; + for await (const chunk of stream) { + validateStreamChunk(chunk); + chunks.push(chunk); + } + + expect(chunks.length).toBeGreaterThan(0); + }, 15_000); + + test('should handle callback approach', async () => { + const chunks = await new Promise((resolve, reject) => { + const chunks: ChatOutput[] = []; + client.chat( + { + model_id: 'google/ul2', + messages: [{ role: 'user', content: 'Hello World!' }], + }, + { + stream: true, + }, + (err, data) => { + if (err) { + console.info(data); + reject(err); + return; + } + if (data === null) { + resolve(chunks); + return; + } + chunks.push(data); + }, + ); + }); + + expect(chunks.length).toBeGreaterThan(0); + for (const chunk of chunks) { + validateStreamChunk(chunk); + } + }, 15_000); + + test('should handle errors', async () => { + const stream = client.chat( + { + model_id: 'XXX/XXX', + messages: [{ role: 'user', content: 'Hello World!' }], + }, + { + stream: true, + }, + ); + + await expect( + new Promise((_, reject) => { + stream.on('error', reject); + }), + ).rejects.toMatchObject({ + code: 'ERR_NON_2XX_3XX_RESPONSE', + statusCode: 404, + message: 'Model not found', + extensions: { + code: 'NOT_FOUND', + state: { model_id: 'XXX/XXX' }, + }, + }); + }, 5_000); + }); + }); + describe('error handling', () => { test('should reject with extended error for invalid model', async () => { await expect( diff --git a/src/tests/integration/client.test.ts b/src/tests/integration/client.test.ts index ae47a48..93ec1b6 100644 --- a/src/tests/integration/client.test.ts +++ b/src/tests/integration/client.test.ts @@ -144,8 +144,8 @@ describe('client', () => { }); describe('tokenize', () => { - test('should return tokenize info', () => { - expect( + test('should return tokenize info', async () => { + await expect( client.tokenize({ input: 'Hello, how are you? Are you okay?', model_id: 'google/flan-t5-xl', @@ -154,6 +154,30 @@ describe('client', () => { }); }); + describe('chat', () => { + test('should start a conversation', async () => { + await expect( + client.chat({ + model_id: 'google/flan-t5-xl', + messages: [ + { role: 'system', content: 'foo' }, + { role: 'user', content: 'bar' }, + ], + }), + ).resolves.toHaveProperty('conversation_id'); + }); + + test('should continue an existing conversation', async () => { + await expect( + client.chat({ + model_id: 'google/flan-t5-xl', + conversation_id: 'foo', + messages: [{ role: 'user', content: 'bar' }], + }), + ).resolves.toHaveProperty('conversation_id', 'foo'); + }); + }); + describe('models', () => { test('should return some models', async () => { const models = await client.models(); diff --git a/src/tests/mocks/handlers.ts b/src/tests/mocks/handlers.ts index 5fea17c..2be4c44 100644 --- a/src/tests/mocks/handlers.ts +++ b/src/tests/mocks/handlers.ts @@ -130,11 +130,22 @@ export const resetHistoryStore = () => { })); }; +export const chatStore = new Map(); +export const resetChatStore = () => { + chatStore.clear(); + chatStore.set(randomUUID(), [ + { role: 'system', content: 'instruction' }, + { role: 'user', content: 'hello' }, + { role: 'assistant', content: 'hi' }, + ]); +}; + export const resetStores = () => { resetGenerateConfigStore(); resetTunesStore(); resetPromptTemplateStore(); resetHistoryStore(); + resetChatStore(); }; resetStores(); @@ -401,4 +412,25 @@ export const handlers: RestHandler>[] = [ }), ); }), + + // Chat + rest.post(`${MOCK_ENDPOINT}/v0/generate/chat`, async (req, res, ctx) => { + const body = await req.json(); + const conversation_id = body.conversation_id ?? randomUUID(); + if (!chatStore.has(conversation_id)) { + chatStore.set(conversation_id, body.messages); + } else { + chatStore.get(conversation_id)?.push(...body.messages); + } + const conversation = chatStore.get(conversation_id); + return res( + ctx.status(200), + ctx.json({ + conversation_id, + results: conversation + ?.slice(-1) + .map(({ role, content }) => ({ role, generated_text: content })), + }), + ); + }), ];