Skip to content

Commit

Permalink
feat(adapters): add embedding support for IBM vLLM (#251)
Browse files Browse the repository at this point in the history
Signed-off-by: Tomas Dvorak <[email protected]>
  • Loading branch information
Tomas2D authored Dec 13, 2024
1 parent 5d0c926 commit 2925dfc
Show file tree
Hide file tree
Showing 17 changed files with 2,603 additions and 57 deletions.
8 changes: 8 additions & 0 deletions examples/llms/providers/ibm-vllm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,11 @@ const client = new Client();
]);
console.info(response.messages);
}

{
console.info("===EMBEDDING===");
const llm = new IBMvLLM({ client, modelId: "baai/bge-large-en-v1.5" });

const response = await llm.embed([`Hello world!`, `Hello family!`]);
console.info(response);
}
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@
"mathjs": "^14.0.0",
"mustache": "^4.2.0",
"object-hash": "^3.0.0",
"p-queue": "^8.0.1",
"p-queue-compat": "^1.0.227",
"p-throttle": "^7.0.0",
"pino": "^9.5.0",
"promise-based-task": "^3.1.1",
Expand Down
8 changes: 4 additions & 4 deletions scripts/ibm_vllm_generate_protos/ibm_vllm_generate_protos.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ GRPC_PROTO_PATH="./src/adapters/ibm-vllm/proto"
GRPC_TYPES_PATH="./src/adapters/ibm-vllm/types.ts"

SCRIPT_DIR="$(dirname "$0")"
OUTPUT_RELATIVE_PATH="dist/generation.d.ts"
GRPC_TYPES_TMP_PATH=types
OUTPUT_RELATIVE_PATH="dist/merged.d.ts"
GRPC_TYPES_TMP_PATH="types"

rm -f "$GRPC_TYPES_PATH"

Expand All @@ -39,7 +39,7 @@ yarn run proto-loader-gen-types \


cd "$SCRIPT_DIR"
tsup --dts-only
ENTRY="$(basename "$OUTPUT_RELATIVE_PATH" ".d.ts")" tsup --dts-only
sed -i.bak '$ d' "$OUTPUT_RELATIVE_PATH"
sed -i.bak -E "s/^interface/export interface/" "$OUTPUT_RELATIVE_PATH"
sed -i.bak -E "s/^type/export type/" "$OUTPUT_RELATIVE_PATH"
Expand All @@ -50,4 +50,4 @@ rm -rf "${SCRIPT_DIR}"/{dist,dts,types}

yarn run lint:fix "${GRPC_TYPES_PATH}"
yarn prettier --write "${GRPC_TYPES_PATH}"
yarn copyright
TARGETS="$GRPC_TYPES_PATH" yarn copyright
2 changes: 1 addition & 1 deletion scripts/ibm_vllm_generate_protos/tsconfig.proto.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"rootDir": ".",
"baseUrl": ".",
"target": "ESNext",
"module": "ES6",
"module": "ESNext",
"outDir": "dist",
"declaration": true,
"emitDeclarationOnly": true,
Expand Down
17 changes: 15 additions & 2 deletions scripts/ibm_vllm_generate_protos/tsup.config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,27 @@
*/

import { defineConfig } from "tsup";
import fs from "node:fs";

if (!process.env.ENTRY) {
throw new Error(`Entry file was not provided!`);
}
const target = `types/${process.env.ENTRY}.ts`;
await fs.promises.writeFile(
target,
[
`export { ProtoGrpcType as A } from "./caikit_runtime_Nlp.js"`,
`export { ProtoGrpcType as B } from "./generation.js"`,
].join("\n"),
);

export default defineConfig({
entry: ["types/generation.ts"],
entry: [target],
tsconfig: "./tsconfig.proto.json",
sourcemap: false,
dts: true,
format: ["esm"],
treeshake: false,
treeshake: true,
legacyOutput: false,
skipNodeModulesBundle: true,
bundle: true,
Expand Down
16 changes: 10 additions & 6 deletions src/adapters/ibm-vllm/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@

import { isFunction, isObjectType } from "remeda";

import { IBMvLLM, IBMvLLMGenerateOptions, IBMvLLMOutput, IBMvLLMParameters } from "./llm.js";
import {
IBMvLLM,
IBMvLLMEmbeddingOptions,
IBMvLLMGenerateOptions,
IBMvLLMOutput,
IBMvLLMParameters,
} from "./llm.js";

import { Cache } from "@/cache/decoratorCache.js";
import { BaseMessage, Role } from "@/llms/primitives/message.js";
Expand All @@ -25,7 +31,6 @@ import { ChatLLM, ChatLLMGenerateEvents, ChatLLMOutput } from "@/llms/chat.js";
import {
AsyncStream,
BaseLLMTokenizeOutput,
EmbeddingOptions,
EmbeddingOutput,
LLMCache,
LLMError,
Expand All @@ -36,7 +41,6 @@ import { shallowCopy } from "@/serializer/utils.js";
import { IBMVllmChatLLMPreset, IBMVllmChatLLMPresetModel } from "@/adapters/ibm-vllm/chatPreset.js";
import { Client } from "./client.js";
import { GetRunContext } from "@/context.js";
import { NotImplementedError } from "@/errors.js";

export class GrpcChatLLMOutput extends ChatLLMOutput {
public readonly raw: IBMvLLMOutput;
Expand Down Expand Up @@ -118,9 +122,9 @@ export class IBMVllmChatLLM extends ChatLLM<GrpcChatLLMOutput> {
return this.llm.meta();
}

// eslint-disable-next-line unused-imports/no-unused-vars
async embed(input: BaseMessage[][], options?: EmbeddingOptions): Promise<EmbeddingOutput> {
throw new NotImplementedError();
async embed(input: BaseMessage[][], options?: IBMvLLMEmbeddingOptions): Promise<EmbeddingOutput> {
const inputs = input.map((messages) => this.messagesToPrompt(messages));
return this.llm.embed(inputs, options);
}

createSnapshot() {
Expand Down
89 changes: 71 additions & 18 deletions src/adapters/ibm-vllm/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,31 +19,39 @@ import grpc, {
ClientOptions as GRPCClientOptions,
ClientReadableStream,
ClientUnaryCall,
Metadata,
} from "@grpc/grpc-js";

import * as R from "remeda";
// eslint-disable-next-line no-restricted-imports
import { UnaryCallback } from "@grpc/grpc-js/build/src/client.js";
import { FrameworkError, ValueError } from "@/errors.js";
import protoLoader from "@grpc/proto-loader";
import protoLoader, { Options } from "@grpc/proto-loader";

import {
BatchedGenerationRequest,
BatchedGenerationResponse__Output,
BatchedTokenizeRequest,
BatchedTokenizeResponse__Output,
type EmbeddingTasksRequest,
GenerationRequest__Output,
ModelInfoRequest,
ModelInfoResponse__Output,
ProtoGrpcType as GenerationProtoGentypes,
ProtoGrpcType$1 as CaikitProtoGentypes,
SingleGenerationRequest,
EmbeddingResults__Output,
type SubtypeConstructor,
} from "@/adapters/ibm-vllm/types.js";
import { parseEnv } from "@/internals/env.js";
import { z } from "zod";
import { Cache } from "@/cache/decoratorCache.js";
import { Serializable } from "@/internals/serializable.js";
import PQueue from "p-queue-compat";
import { getProp } from "@/internals/helpers/object.js";

const GENERATION_PROTO_PATH = new URL("./proto/generation.proto", import.meta.url);
const NLP_PROTO_PATH = new URL("./proto/caikit_runtime_Nlp.proto", import.meta.url);

interface ClientOptions {
modelRouterSubdomain?: string;
Expand All @@ -55,6 +63,11 @@ interface ClientOptions {
};
grpcClientOptions: GRPCClientOptions;
clientShutdownDelay: number;
limits?: {
concurrency?: {
embeddings?: number;
};
};
}

const defaultOptions = {
Expand All @@ -66,18 +79,24 @@ const defaultOptions = {
},
};

const generationPackageObject = grpc.loadPackageDefinition(
protoLoader.loadSync([GENERATION_PROTO_PATH.pathname], {
longs: Number,
enums: String,
arrays: true,
objects: true,
oneofs: true,
keepCase: true,
defaults: true,
}),
const grpcConfig: Options = {
longs: Number,
enums: String,
arrays: true,
objects: true,
oneofs: true,
keepCase: true,
defaults: true,
};

const generationPackage = grpc.loadPackageDefinition(
protoLoader.loadSync([GENERATION_PROTO_PATH.pathname], grpcConfig),
) as unknown as GenerationProtoGentypes;

const embeddingsPackage = grpc.loadPackageDefinition(
protoLoader.loadSync([NLP_PROTO_PATH.pathname], grpcConfig),
) as unknown as CaikitProtoGentypes;

const GRPC_CLIENT_TTL = 15 * 60 * 1000;

type CallOptions = GRPCCallOptions & { signal?: AbortSignal };
Expand All @@ -88,9 +107,12 @@ export class Client extends Serializable {
private usedDefaultCredentials = false;

@Cache({ ttl: GRPC_CLIENT_TTL })
protected getClient(modelId: string) {
protected getClient<T extends { close: () => void }>(
modelId: string,
factory: SubtypeConstructor<typeof grpc.Client, T>,
): T {
const modelSpecificUrl = this.options.url.replace(/{model_id}/, modelId.replaceAll("/", "--"));
const client = new generationPackageObject.fmaas.GenerationService(
const client = new factory(
modelSpecificUrl,
grpc.credentials.createSsl(
Buffer.from(this.options.credentials.rootCert),
Expand Down Expand Up @@ -129,43 +151,64 @@ export class Client extends Serializable {
}

async modelInfo(request: RequiredModel<ModelInfoRequest>, options?: CallOptions) {
const client = this.getClient(request.model_id);
const client = this.getClient(request.model_id, generationPackage.fmaas.GenerationService);
return this.wrapGrpcCall<ModelInfoRequest, ModelInfoResponse__Output>(
client.modelInfo.bind(client),
)(request, options);
}

async generate(request: RequiredModel<BatchedGenerationRequest>, options?: CallOptions) {
const client = this.getClient(request.model_id);
const client = this.getClient(request.model_id, generationPackage.fmaas.GenerationService);
return this.wrapGrpcCall<BatchedGenerationRequest, BatchedGenerationResponse__Output>(
client.generate.bind(client),
)(request, options);
}

async generateStream(request: RequiredModel<SingleGenerationRequest>, options?: CallOptions) {
const client = this.getClient(request.model_id);
const client = this.getClient(request.model_id, generationPackage.fmaas.GenerationService);
return this.wrapGrpcStream<SingleGenerationRequest, GenerationRequest__Output>(
client.generateStream.bind(client),
)(request, options);
}

async tokenize(request: RequiredModel<BatchedTokenizeRequest>, options?: CallOptions) {
const client = this.getClient(request.model_id);
const client = this.getClient(request.model_id, generationPackage.fmaas.GenerationService);
return this.wrapGrpcCall<BatchedTokenizeRequest, BatchedTokenizeResponse__Output>(
client.tokenize.bind(client),
)(request, options);
}

async embed(request: RequiredModel<EmbeddingTasksRequest>, options?: CallOptions) {
const client = this.getClient(
request.model_id,
embeddingsPackage.caikit.runtime.Nlp.NlpService,
);
return this.queues.embeddings.add(
() =>
this.wrapGrpcCall<EmbeddingTasksRequest, EmbeddingResults__Output>(
client.embeddingTasksPredict.bind(client),
)(request, options),
{ throwOnTimeout: true },
);
}

protected wrapGrpcCall<TRequest, TResponse>(
fn: (
request: TRequest,
metadata: Metadata,
options: CallOptions,
callback: UnaryCallback<TResponse>,
) => ClientUnaryCall,
) {
return (request: TRequest, { signal, ...options }: CallOptions = {}): Promise<TResponse> => {
const metadata = new Metadata();
const modelId = getProp(request, ["model_id"]);
if (modelId) {
metadata.add("mm-model-id", modelId);
}

return new Promise<TResponse>((resolve, reject) => {
const call = fn(request, options, (err, response) => {
const call = fn(request, metadata, options, (err, response) => {
signal?.removeEventListener("abort", abortHandler);
if (err) {
reject(err);
Expand Down Expand Up @@ -213,4 +256,14 @@ export class Client extends Serializable {
Object.assign(this, snapshot);
this.options.credentials = this.getDefaultCredentials();
}

@Cache({ enumerable: false })
protected get queues() {
return {
embeddings: new PQueue({
concurrency: this.options.limits?.concurrency?.embeddings ?? 5,
throwOnTimeout: true,
}),
};
}
}
48 changes: 43 additions & 5 deletions src/adapters/ibm-vllm/llm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,12 @@ import {
LLMError,
LLMMeta,
} from "@/llms/base.js";
import { isEmpty, isString } from "remeda";
import type { DecodingParameters, SingleGenerationRequest } from "@/adapters/ibm-vllm/types.js";
import { chunk, isEmpty, isString } from "remeda";
import type {
DecodingParameters,
SingleGenerationRequest,
EmbeddingTasksRequest,
} from "@/adapters/ibm-vllm/types.js";
import { LLM, LLMEvents, LLMInput } from "@/llms/llm.js";
import { Emitter } from "@/emitter/emitter.js";
import { GenerationResponse__Output } from "@/adapters/ibm-vllm/types.js";
Expand All @@ -39,6 +43,7 @@ import { ServiceError } from "@grpc/grpc-js";
import { Client } from "@/adapters/ibm-vllm/client.js";
import { GetRunContext } from "@/context.js";
import { BatchedGenerationRequest } from "./types.js";
import { OmitPrivateKeys } from "@/internals/types.js";

function isGrpcServiceError(err: unknown): err is ServiceError {
return (
Expand Down Expand Up @@ -100,6 +105,12 @@ export type IBMvLLMParameters = NonNullable<

export interface IBMvLLMGenerateOptions extends GenerateOptions {}

export interface IBMvLLMEmbeddingOptions
extends EmbeddingOptions,
Omit<OmitPrivateKeys<EmbeddingTasksRequest>, "texts"> {
chunkSize?: number;
}

export type IBMvLLMEvents = LLMEvents<IBMvLLMOutput>;

export class IBMvLLM extends LLM<IBMvLLMOutput, IBMvLLMGenerateOptions> {
Expand Down Expand Up @@ -128,9 +139,36 @@ export class IBMvLLM extends LLM<IBMvLLMOutput, IBMvLLMGenerateOptions> {
};
}

// eslint-disable-next-line unused-imports/no-unused-vars
async embed(input: LLMInput[], options?: EmbeddingOptions): Promise<EmbeddingOutput> {
throw new NotImplementedError();
async embed(
input: LLMInput[],
{ chunkSize, signal, ...options }: IBMvLLMEmbeddingOptions = {},
): Promise<EmbeddingOutput> {
const results = await Promise.all(
chunk(input, chunkSize ?? 100).map(async (texts) => {
const response = await this.client.embed(
{
model_id: this.modelId,
truncate_input_tokens: options?.truncate_input_tokens ?? 512,
texts,
},
{
signal,
},
);
const embeddings = response.results?.vectors.map((vector) => {
const embedding = vector[vector.data]?.values;
if (!embedding) {
throw new LLMError("Missing embedding");
}
return embedding;
});
if (embeddings?.length !== texts.length) {
throw new LLMError("Missing embedding");
}
return embeddings;
}),
);
return { embeddings: results.flat() };
}

async tokenize(input: LLMInput): Promise<BaseLLMTokenizeOutput> {
Expand Down
Loading

0 comments on commit 2925dfc

Please sign in to comment.