Skip to content

Commit

Permalink
feat(chat): add chat support to the client (#57)
Browse files Browse the repository at this point in the history
Signed-off-by: Tomas Pilar <[email protected]>
  • Loading branch information
pilartomas authored Oct 24, 2023
1 parent e992686 commit 23b4658
Show file tree
Hide file tree
Showing 10 changed files with 437 additions and 24 deletions.
82 changes: 82 additions & 0 deletions examples/chat.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import { Client } from '../src/index.js';

const client = new Client({
apiKey: process.env.GENAI_API_KEY,
});

const model_id = 'google/ul2';

{
// Start a conversation
const {
conversation_id,
result: { generated_text: answer1 },
} = await client.chat({
model_id,
messages: [
{
role: 'system',
content: 'Answer yes or no',
},
{
role: 'user',
content: 'Hello, are you a robot?',
},
],
});
console.log(answer1);

// Continue the conversation
const {
result: { generated_text: answer2 },
} = await client.chat({
conversation_id,
model_id,
messages: [
{
role: 'user',
content: 'Are you sure?',
},
],
});
console.log(answer2);
}

{
// Chat inteface has the same promise, streaming and callback variants as generate interface

// Promise
const data = await client.chat({
model_id,
messages: [{ role: 'user', content: 'How are you?' }],
});
console.log(data.result.generated_text);
// Callback
client.chat(
{ model_id, messages: [{ role: 'user', content: 'How are you?' }] },
(err, data) => {
if (err) console.error(err);
else console.log(data.result.generated_text);
},
);
// Stream
for await (const chunk of client.chat(
{ model_id, messages: [{ role: 'user', content: 'How are you?' }] },
{ stream: true },
)) {
console.log(chunk.result.generated_text);
}
// Streaming callbacks
client.chat(
{
model_id: 'google/ul2',
messages: [{ role: 'user', content: 'How are you?' }],
},
{ stream: true },
(err, data) => {
if (err) console.error(err);
else if (data) console.log(data.result.generated_text);
else console.log('EOS');
},
);
}
3 changes: 2 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@
"example:generate": "yarn run example:run examples/generate.ts",
"example:tune": "yarn run example:run examples/tune.ts",
"example:prompt-template": "yarn run example:run examples/prompt-templates.ts",
"example:file": "yarn run example:run examples/file.ts"
"example:file": "yarn run example:run examples/file.ts",
"example:chat": "yarn run example:run examples/chat.ts"
},
"peerDependencies": {
"langchain": ">=0.0.155"
Expand Down
40 changes: 39 additions & 1 deletion src/api-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,13 @@ export interface UserGenerateDefaultOutput {

// GENERATE

const ParametersSchema = z.record(z.any());

export const GenerateInputSchema = z.object({
model_id: z.string().nullish(),
prompt_id: z.string().nullish(),
inputs: z.array(z.string()),
parameters: z.optional(z.record(z.any())),
parameters: z.optional(ParametersSchema),
use_default: z.optional(z.boolean()),
});
export type GenerateInput = z.infer<typeof GenerateInputSchema>;
Expand Down Expand Up @@ -399,3 +401,39 @@ export const FilesOutputSchema = PaginationOutputSchema.extend({
results: z.array(SingleFileOutputSchema),
});
export type FilesOutput = z.output<typeof FilesOutputSchema>;

// CHAT

export const ChatRoleSchema = z.enum(['user', 'system', 'assistant']);
export type ChatRole = z.infer<typeof ChatRoleSchema>;

export const ChatInputSchema = z.object({
model_id: z.string(),
messages: z.array(
z.object({
role: ChatRoleSchema,
content: z.string(),
}),
),
conversation_id: z.string().nullish(),
parent_id: z.string().nullish(),
prompt_id: z.string().nullish(),
parameters: ParametersSchema.nullish(),
});
export type ChatInput = z.input<typeof ChatInputSchema>;
export const ChatOutputSchema = z.object({
conversation_id: z.string(),
results: z.array(
z
.object({
generated_text: z.string(),
})
.partial(),
),
});
export type ChatOutput = z.output<typeof ChatOutputSchema>;

export const ChatStreamInputSchema = ChatInputSchema;
export type ChatStreamInput = z.input<typeof ChatStreamInputSchema>;
export const ChatStreamOutputSchema = ChatOutputSchema;
export type ChatStreamOutput = z.output<typeof ChatStreamOutputSchema>;
128 changes: 111 additions & 17 deletions src/client/client.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import http, { IncomingMessage } from 'node:http';
import https from 'node:https';
import { Transform, TransformCallback } from 'stream';
import { Transform, TransformCallback } from 'node:stream';

import axios, { AxiosError } from 'axios';
import FormData from 'form-data';
Expand Down Expand Up @@ -36,10 +36,13 @@ import {
handleGenerator,
paginator,
isEmptyObject,
callbackifyStream,
callbackifyPromise,
} from '../helpers/common.js';
import { TypedReadable } from '../utils/stream.js';
import { lookupApiKey, lookupEndpoint } from '../helpers/config.js';
import { RETRY_ATTEMPTS_DEFAULT } from '../constants.js';
import { Callback } from '../helpers/types.js';

import {
GenerateConfigInput,
Expand Down Expand Up @@ -92,6 +95,12 @@ import {
FilesInput,
FileDeleteOutput,
PromptTemplateDeleteOutput,
ChatInput,
ChatOutput,
ChatOptions,
ChatStreamOptions,
ChatStreamInput,
ChatStreamOutput,
} from './types.js';
import { CacheDiscriminator, generateCacheKey } from './cache.js';

Expand All @@ -116,10 +125,6 @@ export interface Configuration {
retries?: HttpHandlerOptions['retries'];
}

type ErrorCallback = (err: unknown) => void;
type DataCallback<T> = (err: unknown, result: T) => void;
export type Callback<T> = ErrorCallback | DataCallback<T>;

export class Client {
readonly #client: AxiosCacheInstance;
readonly #options: Required<Configuration>;
Expand Down Expand Up @@ -484,12 +489,7 @@ export class Client {
return stream;
}

stream.on('data', (data) => callback(null, data));
stream.on('error', (err) => (callback as ErrorCallback)(err));
stream.on('finish', () =>
(callback as DataCallback<GenerateOutput | null>)(null, null),
);

callbackifyStream<GenerateOutput>(stream)(callback);
return;
}

Expand Down Expand Up @@ -549,12 +549,7 @@ export class Client {
});

if (callback) {
promises.forEach((promise) =>
promise.then(
(data) => callback(null as never, data),
(err) => (callback as ErrorCallback)(err),
),
);
promises.forEach((promise) => callbackifyPromise(promise)(callback));
} else {
return Array.isArray(input) ? promises : promises[0];
}
Expand Down Expand Up @@ -1320,4 +1315,103 @@ export class Client {
return transformOutput(result);
});
}

chat(input: ChatInput, callback: Callback<ChatOutput>): void;
chat(
input: ChatInput,
options: ChatOptions,
callback: Callback<ChatOutput>,
): void;
chat(
input: ChatStreamInput,
options: ChatStreamOptions,
callback: Callback<ChatStreamOutput | null>,
): void;
chat(input: ChatInput, options?: ChatOptions): Promise<ChatOutput>;
chat(
input: ChatStreamInput,
options?: ChatStreamOptions,
): TypedReadable<ChatStreamOutput>;
chat(
input: ChatInput | ChatStreamInput,
optionsOrCallback?:
| ChatOptions
| ChatStreamOptions
| Callback<ChatOutput>
| Callback<ChatStreamOutput>,
callback?: Callback<ChatOutput>,
): TypedReadable<ChatStreamOutput> | Promise<ChatOutput> | void {
const { callback: cb, options } = parseFunctionOverloads(
undefined,
optionsOrCallback,
callback,
);

if (options?.stream) {
const stream = new Transform({
autoDestroy: true,
objectMode: true,
transform(
chunk: ApiTypes.ChatStreamOutput,
encoding: BufferEncoding,
callback: TransformCallback,
) {
const { results, ...rest } = chunk;
callback(null, {
...rest,
result: results[0],
} as ChatStreamOutput);
},
});
this.#fetcher<ApiTypes.ChatStreamOutput, ApiTypes.ChatStreamInput>({
...options,
method: 'POST',
url: '/v0/generate/chat',
data: {
...input,
parameters: {
...input.parameters,
stream: true,
},
},
stream: true,
})
.on('error', (err) => stream.emit('error', errorTransformer(err)))
.pipe(stream);

if (cb) {
callbackifyStream<ChatStreamOutput>(stream)(cb);
return;
} else {
return stream;
}
} else {
const promise = (async () => {
const { results, ...rest } = await this.#fetcher<
ApiTypes.ChatOutput,
ApiTypes.ChatInput
>(
{
...options,
method: 'POST',
url: '/v0/generate/chat',
data: input,
stream: false,
},
ApiTypes.ChatOutputSchema,
);
if (results.length !== 1) {
throw new InternalError('Unexpected number of results');
}
return { ...rest, result: results[0] };
})();

if (cb) {
callbackifyPromise(promise)(cb);
return;
} else {
return promise;
}
}
}
}
21 changes: 21 additions & 0 deletions src/client/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -263,3 +263,24 @@ export type FileDeleteOptions = HttpHandlerOptions & FlagOption<'delete', true>;
export const FilesOutputSchema =
ApiTypes.FilesOutputSchema.shape.results.element;
export type FilesOutput = z.output<typeof FilesOutputSchema>;

// CHAT

export const ChatInputSchema = z.union([
ApiTypes.ChatInputSchema,
ApiTypes.ChatStreamInputSchema,
]);
export type ChatInput = z.input<typeof ChatInputSchema>;
export type ChatOptions = HttpHandlerNoStreamOptions;
export const ChatOutputSchema = ApiTypes.ChatOutputSchema.omit({
results: true,
}).extend({ result: ApiTypes.ChatOutputSchema.shape.results.element });
export type ChatOutput = z.output<typeof ChatOutputSchema>;

export const ChatStreamInputSchema = ApiTypes.ChatStreamInputSchema;
export type ChatStreamInput = z.input<typeof ChatStreamInputSchema>;
export type ChatStreamOptions = HttpHandlerStreamOptions;
export const ChatStreamOutputSchema = ApiTypes.ChatStreamOutputSchema.omit({
results: true,
}).extend({ result: ApiTypes.ChatOutputSchema.shape.results.element });
export type ChatStreamOutput = z.output<typeof ChatStreamOutputSchema>;
23 changes: 21 additions & 2 deletions src/helpers/common.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import { callbackify } from 'node:util';
import { URLSearchParams } from 'node:url';
import { Readable } from 'node:stream';

import { z } from 'zod';

export type FalsyValues = false | '' | 0 | null | undefined;
export type Truthy<T> = T extends FalsyValues ? never : T;
import { ErrorCallback, DataCallback, Truthy, Callback } from './types.js';

export function isTruthy<T>(value: T): value is Truthy<T> {
return Boolean(value);
Expand Down Expand Up @@ -151,6 +151,25 @@ export function callbackifyGenerator<T>(generatorFn: () => AsyncGenerator<T>) {
};
}

export function callbackifyStream<T>(stream: Readable) {
return (callbackFn: Callback<T>) => {
stream.on('data', (data) => callbackFn(null, data));
stream.on('error', (err) => (callbackFn as ErrorCallback)(err));
stream.on('finish', () =>
(callbackFn as DataCallback<T | null>)(null, null),
);
};
}

export function callbackifyPromise<T>(promise: Promise<T>) {
return (callbackFn: Callback<T>) => {
promise.then(
(data) => callbackFn(null, data),
(err) => (callbackFn as ErrorCallback)(err),
);
};
}

export async function* paginator<T>(
executor: (searchParams: URLSearchParams) => Promise<{
results: T[];
Expand Down
Loading

0 comments on commit 23b4658

Please sign in to comment.