From a29cde327e6eff2c0f97adbb593d675b5f25ebfd Mon Sep 17 00:00:00 2001 From: Tomas Pilar Date: Thu, 4 Apr 2024 16:44:49 +0200 Subject: [PATCH] fix(langchain): genAIModel moderations (#100) Signed-off-by: Tomas Pilar --- examples/langchain/llm.ts | 10 +- src/errors.ts | 2 - src/langchain/llm.ts | 282 +++++++++++++++----------------- tests/e2e/langchain/llm.test.ts | 24 +-- 4 files changed, 147 insertions(+), 171 deletions(-) diff --git a/examples/langchain/llm.ts b/examples/langchain/llm.ts index cb9237b..00594ee 100644 --- a/examples/langchain/llm.ts +++ b/examples/langchain/llm.ts @@ -1,10 +1,9 @@ import { Client } from '../../src/index.js'; import { GenAIModel } from '../../src/langchain/index.js'; -const makeClient = (stream?: boolean) => +const makeClient = () => new GenAIModel({ - modelId: 'google/flan-t5-xl', - stream, + model_id: 'google/flan-t5-xl', client: new Client({ endpoint: process.env.ENDPOINT, apiKey: process.env.API_KEY, @@ -15,6 +14,9 @@ const makeClient = (stream?: boolean) => max_new_tokens: 25, repetition_penalty: 1.5, }, + moderations: { + hap: true, + }, }); { @@ -40,7 +42,7 @@ const makeClient = (stream?: boolean) => { console.info('---Streaming Example---'); - const chat = makeClient(true); + const chat = makeClient(); const prompt = 'What is a molecule?'; console.info(`Request: ${prompt}`); diff --git a/src/errors.ts b/src/errors.ts index 5aa3706..bf47890 100644 --- a/src/errors.ts +++ b/src/errors.ts @@ -1,5 +1,3 @@ -import { AbortError } from 'p-queue-compat'; - import { ApiError } from './api/client.js'; export class BaseError extends Error {} diff --git a/src/langchain/llm.ts b/src/langchain/llm.ts index 14fc6d1..cfb3f6a 100644 --- a/src/langchain/llm.ts +++ b/src/langchain/llm.ts @@ -1,134 +1,91 @@ -import { BaseLLM, BaseLLMParams } from '@langchain/core/language_models/llms'; +import { + BaseLLM, + BaseLLMCallOptions, + BaseLLMParams, +} from '@langchain/core/language_models/llms'; import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager'; -import type { LLMResult, Generation } from '@langchain/core/outputs'; +import type { LLMResult } from '@langchain/core/outputs'; import { GenerationChunk } from '@langchain/core/outputs'; +import merge from 'lodash/merge.js'; import { Client, Configuration } from '../client.js'; -import { - isNotEmptyArray, - concatUnique, - isNullish, - asyncGeneratorToArray, -} from '../helpers/common.js'; +import { concatUnique, isNullish } from '../helpers/common.js'; import { TextGenerationCreateInput, - TextGenerationCreateOutput, + TextGenerationCreateStreamInput, } from '../schema.js'; -type BaseGenAIModelOptions = { - stream?: boolean; - parameters?: Record; - timeout?: number; -} & ( - | { client: Client; configuration?: never } - | { client?: never; configuration: Configuration } -); - -export type GenAIModelOptions = - | (BaseGenAIModelOptions & { modelId?: string; promptId?: never }) - | (BaseGenAIModelOptions & { modelId?: never; promptId: string }); - -export class GenAIModel extends BaseLLM { - #client: Client; - - protected modelId?: string; - protected promptId?: string; - protected isStreaming: boolean; - protected timeout: number | undefined; - protected parameters: Record; +type TextGenerationInput = TextGenerationCreateInput & + TextGenerationCreateStreamInput; + +export type GenAIModelParams = BaseLLMParams & + Pick< + TextGenerationInput, + 'model_id' | 'prompt_id' | 'parameters' | 'moderations' + > & { + model_id: NonNullable; + } & ( + | { client: Client; configuration?: never } + | { client?: 
never; configuration: Configuration } + ); +export type GenAIModelOptions = BaseLLMCallOptions & + Partial>; + +export class GenAIModel extends BaseLLM { + protected readonly client: Client; + + public readonly modelId: GenAIModelParams['model_id']; + public readonly promptId: GenAIModelParams['prompt_id']; + public readonly parameters: GenAIModelParams['parameters']; + public readonly moderations: GenAIModelParams['moderations']; constructor({ - modelId, - promptId, - stream = false, + model_id, + prompt_id, parameters, - timeout, + moderations, client, configuration, - ...baseParams - }: GenAIModelOptions & BaseLLMParams) { - super(baseParams ?? {}); - - this.modelId = modelId; - this.promptId = promptId; - this.timeout = timeout; - this.isStreaming = Boolean(stream); - this.parameters = parameters || {}; - this.#client = client ?? new Client(configuration); - } - - #createPayload( - prompts: string[], - options: this['ParsedCallOptions'], - ): TextGenerationCreateInput[] { - const stopSequences = concatUnique(this.parameters.stop, options.stop); - - return prompts.map((input) => ({ - ...(!isNullish(this.promptId) - ? { - prompt_id: this.promptId, - } - : !isNullish(this.modelId) - ? { - model_id: this.modelId, - } - : {}), - input, - parameters: { - ...this.parameters, - stop_sequences: isNotEmptyArray(stopSequences) - ? stopSequences - : undefined, - }, - })); - } - - async #execute( - prompts: string[], - options: this['ParsedCallOptions'], - ): Promise { - return await Promise.all( - this.#createPayload(prompts, options).map((input) => - this.#client.text.generation.create(input, { - signal: options.signal, - }), - ), - ); + ...options + }: GenAIModelParams) { + super(options); + + this.modelId = model_id; + this.promptId = prompt_id; + this.parameters = parameters; + this.moderations = moderations; + this.client = client ?? new Client(configuration); } async _generate( - prompts: string[], + inputs: string[], options: this['ParsedCallOptions'], runManager?: CallbackManagerForLLMRun, ): Promise { - const response: TextGenerationCreateOutput[] = []; - if (this.isStreaming) { - const { output } = await asyncGeneratorToArray( - this._streamResponseChunks(prompts[0], options, runManager), - ); - response.push(output); - } else { - const outputs = await this.#execute(prompts, options); - response.push(...outputs); - } - - const contentResponses = response.flatMap( - (res) => res.results?.at(0) ?? [], + const outputs = await Promise.all( + inputs.map((input) => + this.client.text.generation.create( + this._prepareRequest(input, options), + { + signal: options.signal, + }, + ), + ), ); - const generations: Generation[][] = contentResponses.map( - ({ generated_text: text, ...generationInfo }) => [ - { - text, - generationInfo, - }, - ], + const generations = outputs.map((output) => + output.results.map((result) => { + const { generated_text, ...generationInfo } = result; + return { text: generated_text, generationInfo }; + }), ); - const llmOutput = await contentResponses.reduce( + const llmOutput = generations.flat().reduce( (acc, generation) => { - acc.generated_token_count += generation.generated_token_count; - acc.input_token_count += generation.input_token_count ?? 0; + acc.generated_token_count += + generation.generationInfo.generated_token_count; + acc.input_token_count += + generation.generationInfo.input_token_count ?? 
0; return acc; }, { @@ -137,63 +94,80 @@ export class GenAIModel extends BaseLLM { }, ); - return { generations, llmOutput }; + return { + generations, + llmOutput, + }; } async *_streamResponseChunks( - _input: string, - _options: this['ParsedCallOptions'], - _runManager?: CallbackManagerForLLMRun, + input: string, + options: this['ParsedCallOptions'], + runManager?: CallbackManagerForLLMRun, ): AsyncGenerator { - const [payload] = this.#createPayload([_input], _options); - const stream = await this.#client.text.generation.create_stream(payload, { - signal: _options.signal, - }); - - const fullOutput = { - id: '', - model_id: '', - created_at: '', - results: [ - { - generated_text: '', - stop_reason: - 'not_finished' as TextGenerationCreateOutput['results'][number]['stop_reason'], - input_token_count: 0, - generated_token_count: 0, - }, - ], - } satisfies TextGenerationCreateOutput; + const stream = await this.client.text.generation.create_stream( + this._prepareRequest(input, options), + { + signal: options.signal, + }, + ); for await (const response of stream) { - fullOutput.id = response.id ?? fullOutput.id; - fullOutput.model_id = response.model_id; - fullOutput.created_at = response.created_at ?? fullOutput.created_at; - - const results = response.results; - if (!results) continue; - - const { generated_text, ...chunk } = results[0]; - const generation = new GenerationChunk({ - text: generated_text, - generationInfo: chunk, - }); - yield generation; - void _runManager?.handleLLMNewToken(generated_text); - - fullOutput.results[0].generated_text += generation.text; - if (chunk.stop_reason) { - fullOutput.results[0].stop_reason = chunk.stop_reason; + if (response.results) { + for (const { generated_text, ...generationInfo } of response.results) { + yield new GenerationChunk({ + text: generated_text, + generationInfo, + }); + void runManager?.handleText(generated_text); + } + } + if (response.moderation) { + yield new GenerationChunk({ + text: '', + generationInfo: { + moderation: response.moderation, + }, + }); + void runManager?.handleText(''); } - fullOutput.results[0].input_token_count += chunk.input_token_count ?? 0; - fullOutput.results[0].generated_token_count += - chunk.generated_token_count; } - return fullOutput; + } + + private _prepareRequest( + input: string, + options: this['ParsedCallOptions'], + ): TextGenerationInput { + const stop_sequences = concatUnique( + options.stop, + options.parameters?.stop_sequences, + ); + const { model_id, prompt_id, ...rest } = merge( + { + model_id: this.modelId, + prompt_id: this.promptId, + moderations: this.moderations, + parameters: this.parameters, + }, + { + model_id: options.model_id, + prompt_id: options.prompt_id, + moderations: options.moderations, + parameters: { + ...options.parameters, + stop_sequences, + }, + }, + { input }, + ); + return { + ...(prompt_id ? { prompt_id } : { model_id }), + ...rest, + }; } async getNumTokens(input: string): Promise { - const result = await this.#client.text.tokenization.create({ + const result = await this.client.text.tokenization.create({ ...(!isNullish(this.modelId) && { model_id: this.modelId, }), @@ -209,7 +183,7 @@ export class GenAIModel extends BaseLLM { } _modelType(): string { - return this.modelId ?? 
'default'; + return this.modelId; } _llmType(): string { diff --git a/tests/e2e/langchain/llm.test.ts b/tests/e2e/langchain/llm.test.ts index 388da3a..c2903e0 100644 --- a/tests/e2e/langchain/llm.test.ts +++ b/tests/e2e/langchain/llm.test.ts @@ -5,10 +5,9 @@ import { GenAIModel } from '../../../src/langchain/llm.js'; import { Client } from '../../../src/client.js'; describe('Langchain', () => { - const makeClient = (modelId?: string, stream?: boolean) => + const makeClient = (modelId: string) => new GenAIModel({ - modelId, - stream, + model_id: modelId, client: new Client({ endpoint: process.env.ENDPOINT, apiKey: process.env.API_KEY, @@ -38,7 +37,7 @@ describe('Langchain', () => { describe('generate', () => { // TODO: enable once we will set default model for the test account test.skip('should handle empty modelId', async () => { - const client = makeClient(); + const client = makeClient('google/flan-ul2'); const data = await client.invoke('Who are you?'); expectIsString(data); @@ -101,24 +100,27 @@ describe('Langchain', () => { }); test('streaming', async () => { - const client = makeClient('google/flan-t5-xl', true); + const client = makeClient('google/flan-t5-xl'); const tokens: string[] = []; - const handleNewToken = vi.fn((token: string) => { + const handleText = vi.fn((token: string) => { tokens.push(token); }); - const output = await client.invoke('Tell me a joke.', { + const stream = await client.stream('Tell me a joke.', { callbacks: [ { - handleLLMNewToken: handleNewToken, + handleText, }, ], }); - expect(handleNewToken).toHaveBeenCalled(); - expectIsString(output); - expect(tokens.join('')).toStrictEqual(output); + const outputs = []; + for await (const output of stream) { + outputs.push(output); + } + expect(handleText).toHaveBeenCalledTimes(outputs.length); + expect(tokens.join('')).toStrictEqual(outputs.join('')); }, 15_000); });
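
Usage sketch for the reworked GenAIModel surface in this patch: construction now takes the snake_case request fields (model_id, parameters, moderations) instead of modelId/stream, the same fields can be overridden per call through GenAIModelOptions, and streaming goes through LangChain's stream() with handleText callbacks rather than a constructor-level stream flag. The import paths, Client configuration, parameters, and the hap moderation flag mirror examples/langchain/llm.ts; the per-call parameter override is an assumption inferred from the GenAIModelOptions type, and the callback wiring follows the updated e2e test. This is an illustrative sketch, not part of the patch itself.

import { Client } from '../../src/index.js';
import { GenAIModel } from '../../src/langchain/index.js';

const model = new GenAIModel({
  model_id: 'google/flan-t5-xl',
  client: new Client({
    endpoint: process.env.ENDPOINT,
    apiKey: process.env.API_KEY,
  }),
  parameters: {
    max_new_tokens: 25,
    repetition_penalty: 1.5,
  },
  // Moderation settings now travel with the model and are forwarded on every request.
  moderations: {
    hap: true,
  },
});

{
  // Single completion; call options (GenAIModelOptions) can override the
  // constructor defaults, e.g. a larger max_new_tokens for this call only
  // (override shape assumed from the GenAIModelOptions type above).
  const text = await model.invoke('What is a molecule?', {
    parameters: { max_new_tokens: 50 },
  });
  console.info(text);
}

{
  // Streaming replaces the old constructor-level `stream` flag; chunks arrive
  // as plain strings, and handleText fires once per chunk, mirroring
  // tests/e2e/langchain/llm.test.ts.
  const tokens: string[] = [];
  const stream = await model.stream('Tell me a joke.', {
    callbacks: [
      {
        handleText: (token: string) => {
          tokens.push(token);
        },
      },
    ],
  });
  for await (const chunk of stream) {
    console.info(chunk);
  }
  console.info(tokens.join(''));
}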