feat: Add tests for each LLM service using their cheapest model
danielcampagnolitg committed Aug 19, 2024
1 parent aff5ded commit 41682f1
Showing 4 changed files with 204 additions and 14 deletions.
11 changes: 10 additions & 1 deletion src/llm/base-llm.ts
@@ -1,5 +1,5 @@
import { countTokens } from '#llm/tokens';
import { FunctionResponse, GenerateFunctionOptions, GenerateJsonOptions, GenerateTextOptions, LLM } from './llm';
import { FunctionResponse, GenerateFunctionOptions, GenerateJsonOptions, GenerateTextOptions, LLM, LlmMessage } from './llm';
import { extractJsonResult, extractStringResult, parseFunctionCallsXml } from './responseParsers';

export interface SerializedLLM {
@@ -74,4 +74,13 @@ export abstract class BaseLLM implements LLM {
// defaults to gpt4o token parser
return countTokens(text);
}

  async generateJson2<T>(messages: LlmMessage[], opts?: GenerateJsonOptions): Promise<T> {
    const response = await this.generateText2(messages, opts ? { type: 'json', ...opts } : { type: 'json' });
    return extractJsonResult(response);
  }

  generateText2(messages: LlmMessage[], opts?: GenerateTextOptions): Promise<string> {
    throw new Error('NotImplemented');
  }
}
59 changes: 50 additions & 9 deletions src/llm/llm.ts
@@ -33,7 +33,48 @@ export type GenerateJsonOptions = Omit<GenerateTextOptions, 'type'>;
*/
export type GenerateFunctionOptions = Omit<GenerateTextOptions, 'type'>;

export interface LlmMessage {
  role: 'system' | 'user' | 'assistant';
  text: string;
  /** Set the cache_control flag with Claude models */
  cache: boolean;
}

export function system(text: string, cache = false): LlmMessage {
  return {
    role: 'system',
    text: text,
    cache,
  };
}

export function user(text: string, cache = false): LlmMessage {
  return {
    role: 'user',
    text: text,
    cache,
  };
}

/**
 * Prefill the assistant message to help guide its response
 * @see https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/prefill-claudes-response
 * @param text
 */
export function assistant(text: string): LlmMessage {
  return {
    role: 'assistant',
    text: text,
    cache: false,
  };
}

export interface LLM {
  generateText2(messages: LlmMessage[]): Promise<string>;

  /* Generates a response that is expected to be in JSON format, and returns the object */
  generateJson2<T>(messages: LlmMessage[], opts?: GenerateJsonOptions): Promise<T>;

  /* Generates text from a LLM */
  generateText(userPrompt: string, systemPrompt?: string, opts?: GenerateTextOptions): Promise<string>;

@@ -133,16 +174,16 @@ export function combinePrompts(userPrompt: string, systemPrompt?: string): string
export function logTextGeneration(originalMethod: any, context: ClassMethodDecoratorContext): any {
return async function replacementMethod(this: BaseLLM, ...args: any[]) {
// system prompt
if (args.length > 1 && args[1]) {
logger.info(`= SYSTEM PROMPT ===================================================\n${args[1]}`);
}
logger.info(`= USER PROMPT =================================================================\n${args[0]}`);

const start = Date.now();
// if (args.length > 1 && args[1]) {
// logger.info(`= SYSTEM PROMPT ===================================================\n${args[1]}`);
// }
// logger.info(`= USER PROMPT =================================================================\n${args[0]}`);
//
// const start = Date.now();
const result = await originalMethod.call(this, ...args);
logger.info(`= RESPONSE ${this.model} ==========================================================\n${JSON.stringify(result)}`);
const duration = `${((Date.now() - start) / 1000).toFixed(1)}s`;
logger.info(`${duration} <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<`);
// logger.info(`= RESPONSE ${this.model} ==========================================================\n${JSON.stringify(result)}`);
// const duration = `${((Date.now() - start) / 1000).toFixed(1)}s`;
// logger.info(`${duration} <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<`);
return result;
};
}
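
For illustration, a minimal sketch (not part of this commit) of how the new LlmMessage helpers and the generateText2/generateJson2 methods might be called. The '#llm/llm' import path is assumed from the repo's alias convention, and the sketch assumes the concrete provider class overrides generateText2, since the BaseLLM default throws NotImplemented.

import { Claude3_Haiku } from '#llm/models/anthropic';
import { system, user } from '#llm/llm';

async function example(): Promise<void> {
  // Claude 3 Haiku is the cheapest Anthropic model, as used in the integration tests below
  const llm = Claude3_Haiku();

  // Plain text generation from a message array; system() and user() build LlmMessage objects
  const colour = await llm.generateText2([system('Answer in one word.'), user('What colour is the day sky?')]);

  // JSON generation: generateJson2 delegates to generateText2 with type 'json'
  // and parses the response with extractJsonResult
  const answer = await llm.generateJson2<{ colour: string }>([
    system('Reply with JSON of the form {"colour": "<one word>"}.'),
    user('What colour is the day sky?'),
  ]);

  console.log(colour, answer.colour);
}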
70 changes: 70 additions & 0 deletions src/llm/models/llm.int.ts
@@ -1,5 +1,12 @@
import { expect } from 'chai';
import { Claude3_Haiku } from '#llm/models/anthropic';
import { deepseekChat } from '#llm/models/deepseek';
import { fireworksLlama3_70B } from '#llm/models/fireworks';
import { groqGemma7bIt } from '#llm/models/groq';
import { Ollama_Phi3 } from '#llm/models/ollama';
import { GPT4oMini } from '#llm/models/openai';
import { togetherLlama3_70B } from '#llm/models/together';
import { Gemini_1_5_Flash } from '#llm/models/vertexai';

describe('LLMs', () => {
  const SKY_PROMPT = 'What colour is the day sky? Answer in one word.';
@@ -12,4 +19,67 @@ describe('LLMs', () => {
      expect(response.toLowerCase()).to.include('blue');
    });
  });

  describe('Deepseek', () => {
    const llm = deepseekChat();

    it('should generateText', async () => {
      const response = await llm.generateText(SKY_PROMPT, null, { temperature: 0 });
      expect(response.toLowerCase()).to.include('blue');
    });
  });

  describe('Fireworks', () => {
    const llm = fireworksLlama3_70B();

    it('should generateText', async () => {
      const response = await llm.generateText(SKY_PROMPT, null, { temperature: 0 });
      expect(response.toLowerCase()).to.include('blue');
    });
  });

  describe('Groq', () => {
    const llm = groqGemma7bIt();

    it('should generateText', async () => {
      const response = await llm.generateText(SKY_PROMPT, null, { temperature: 0 });
      expect(response.toLowerCase()).to.include('blue');
    });
  });

  describe('Ollama', () => {
    const llm = Ollama_Phi3();

    it('should generateText', async () => {
      const response = await llm.generateText(SKY_PROMPT, null, { temperature: 0 });
      expect(response.toLowerCase()).to.include('blue');
    });
  });

  describe('OpenAI', () => {
    const llm = GPT4oMini();

    it('should generateText', async () => {
      const response = await llm.generateText(SKY_PROMPT, null, { temperature: 0 });
      expect(response.toLowerCase()).to.include('blue');
    });
  });

  describe('Together', () => {
    const llm = togetherLlama3_70B();

    it('should generateText', async () => {
      const response = await llm.generateText(SKY_PROMPT, null, { temperature: 0 });
      expect(response.toLowerCase()).to.include('blue');
    });
  });

  describe('VertexAI', () => {
    const llm = Gemini_1_5_Flash();

    it('should generateText', async () => {
      const response = await llm.generateText(SKY_PROMPT, null, { temperature: 0 });
      expect(response.toLowerCase()).to.include('blue');
    });
  });
});
78 changes: 74 additions & 4 deletions src/llm/models/vertexai.ts
@@ -1,8 +1,7 @@
import { HarmBlockThreshold, HarmCategory, SafetySetting, VertexAI } from '@google-cloud/vertexai';
import { GenerativeModel, HarmBlockThreshold, HarmCategory, SafetySetting, VertexAI } from '@google-cloud/vertexai';
import axios from 'axios';
import { AgentLLMs, addCost, agentContext } from '#agent/agentContext';
import { CreateLlmRequest, LlmCall } from '#llm/llmCallService/llmCall';
import { CallerId } from '#llm/llmCallService/llmCallService';
import { LlmCall } from '#llm/llmCallService/llmCall';
import { logger } from '#o11y/logger';
import { withActiveSpan } from '#o11y/trace';
import { currentUser } from '#user/userService/userContext';
@@ -29,6 +28,7 @@ export function vertexLLMRegistry(): Record<string, () => LLM> {
    [`${VERTEX_SERVICE}:gemini-experimental`]: Gemini_1_5_Experimental,
    [`${VERTEX_SERVICE}:gemini-1.5-pro`]: Gemini_1_5_Pro,
    [`${VERTEX_SERVICE}:gemini-1.5-flash`]: Gemini_1_5_Flash,
    [`${VERTEX_SERVICE}:Llama3-405b-instruct-maas`]: Vertex_Llama3_405b,
  };
}

@@ -37,6 +37,8 @@ export function vertexLLMRegistry(): Record<string, () => LLM> {
// https://cloud.google.com/vertex-ai/generative-ai/pricing

// gemini-1.5-pro-latest
// gemini-1.5-pro-exp-0801
// exp-0801
export function Gemini_1_5_Pro(version = '001') {
  return new VertexLLM(
    'Gemini 1.5 Pro',
@@ -70,6 +72,74 @@ export function Gemini_1_5_Flash(version = '001') {
  );
}

// async imageToText(urlOrBytes: string | Buffer): Promise<string> {
// return withActiveSpan('imageToText', async (span) => {
// const generativeVisionModel = this.vertex().getGenerativeModel({
// model: this.imageToTextModel,
// }) as GenerativeModel;
//
// let filePart: { fileData?: { fileUri: string; mimeType: string }; inlineData?: { data: string; mimeType: string } };
// if (typeof urlOrBytes === 'string') {
// filePart = {
// fileData: {
// fileUri: urlOrBytes,
// mimeType: 'image/jpeg', // Adjust mime type if needed
// },
// };
// } else if (Buffer.isBuffer(urlOrBytes)) {
// filePart = {
// inlineData: {
// data: urlOrBytes.toString('base64'),
// mimeType: 'image/jpeg', // Adjust mime type if needed
// },
// };
// } else {
// throw new Error('Invalid input: must be a URL string or a Buffer');
// }
//
// const textPart = {
// text: 'Describe the contents of this image',
// };
//
// const request = {
// contents: [
// {
// role: 'user',
// parts: [filePart, textPart],
// },
// ],
// };
//
// try {
// const response = await generativeVisionModel.generateContent(request);
// const fullTextResponse = response.response.candidates[0].content.parts[0].text;
//
// span.setAttributes({
// inputType: typeof urlOrBytes === 'string' ? 'url' : 'buffer',
// outputLength: fullTextResponse.length,
// });
//
// return fullTextResponse;
// } catch (error) {
// logger.error('Error in imageToText:', error);
// span.recordException(error);
// span.setStatus({ code: SpanStatusCode.ERROR, message: error.message });
// throw error;
// }
// });
// }

export function Vertex_Llama3_405b() {
  return new VertexLLM(
    'Llama3 405b (Vertex)',
    VERTEX_SERVICE,
    'Llama3-405b-instruct-maas', // meta/llama3
    100_000,
    (input: string) => 0,
    (output: string) => 0,
  );
}

/**
* Vertex AI models - Gemini
*/
@@ -118,7 +188,7 @@ class VertexLLM extends BaseLLM {
    } else {
      const generativeModel = this.vertex().getGenerativeModel({
        model: this.model,
        systemInstruction: systemPrompt ? { role: 'system', parts: [{ text: systemPrompt }] } : undefined,
        systemInstruction: systemPrompt, // ? { role: 'system', parts: [{ text: systemPrompt }] } : undefined
        generationConfig: {
          maxOutputTokens: 8192,
          temperature: opts?.temperature,
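
For illustration, a minimal smoke-test sketch (not part of this commit) for the newly registered Llama3 405b Vertex model, following the same pattern as the integration tests in llm.int.ts. It assumes VertexLLM.generateText works unchanged against the Llama MaaS model and that '#llm/models/vertexai' is the correct import alias.

import { Vertex_Llama3_405b } from '#llm/models/vertexai';

async function vertexLlamaSmokeTest(): Promise<void> {
  const llm = Vertex_Llama3_405b();
  // Same prompt and options shape as the llm.int.ts tests
  const response = await llm.generateText('What colour is the day sky? Answer in one word.', null, { temperature: 0 });
  console.log(response); // expected to include 'blue'
}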
