From 4673b5efd98adce398da46e883458a5e9d5a9673 Mon Sep 17 00:00:00 2001 From: Tomas Dvorak Date: Fri, 13 Dec 2024 18:49:38 +0100 Subject: [PATCH] feat(adapters): add embedding support for Groq Ref: #176 Signed-off-by: Tomas Dvorak --- src/adapters/groq/chat.ts | 15 ++++++++++++--- tests/e2e/adapters/groq/chat.test.ts | 16 ++++++++++++++-- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/src/adapters/groq/chat.ts b/src/adapters/groq/chat.ts index 45d8ad5e..f3e1e58d 100644 --- a/src/adapters/groq/chat.ts +++ b/src/adapters/groq/chat.ts @@ -34,7 +34,6 @@ import { GetRunContext } from "@/context.js"; import { Serializer } from "@/serializer/serializer.js"; import { getPropStrict } from "@/internals/helpers/object.js"; import { ChatCompletionCreateParams } from "groq-sdk/resources/chat/completions"; -import { NotImplementedError } from "@/errors.js"; type Parameters = Omit; type Response = Omit; @@ -148,9 +147,19 @@ export class GroqChatLLM extends ChatLLM { }; } - // eslint-disable-next-line unused-imports/no-unused-vars async embed(input: BaseMessage[][], options?: EmbeddingOptions): Promise { - throw new NotImplementedError(); + const { data } = await this.client.embeddings.create( + { + model: this.modelId, + input: input.flatMap((msgs) => msgs.map((msg) => msg.text)) as string[], + encoding_format: "float", + }, + { + signal: options?.signal, + stream: false, + }, + ); + return { embeddings: data.map(({ embedding }) => embedding as number[]) }; } async tokenize(input: BaseMessage[]): Promise { diff --git a/tests/e2e/adapters/groq/chat.test.ts b/tests/e2e/adapters/groq/chat.test.ts index 350b8be2..7ef16530 100644 --- a/tests/e2e/adapters/groq/chat.test.ts +++ b/tests/e2e/adapters/groq/chat.test.ts @@ -20,9 +20,9 @@ import { GroqChatLLM } from "@/adapters/groq/chat.js"; const apiKey = process.env.GROQ_API_KEY; describe.runIf(Boolean(apiKey))("Adapter Groq Chat LLM", () => { - const createChatLLM = () => { + const createChatLLM = (modelId = "llama3-8b-8192") => { const model = new GroqChatLLM({ - modelId: "llama3-8b-8192", + modelId, parameters: { temperature: 0, max_tokens: 1024, @@ -69,4 +69,16 @@ describe.runIf(Boolean(apiKey))("Adapter Groq Chat LLM", () => { ); } }); + + // Embedding model does not available right now + it.skip("Embeds", async () => { + const llm = createChatLLM("nomic-embed-text-v1_5"); + const response = await llm.embed([ + [BaseMessage.of({ role: "user", text: `Hello world!` })], + [BaseMessage.of({ role: "user", text: `Hello family!` })], + ]); + expect(response.embeddings.length).toBe(2); + expect(response.embeddings[0].length).toBe(1024); + expect(response.embeddings[1].length).toBe(1024); + }); });