-
Notifications
You must be signed in to change notification settings - Fork 0
/
test-embeddings.js
64 lines (54 loc) · 2.04 KB
/
test-embeddings.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import Replicate from 'replicate';
import * as dotenv from 'dotenv';
import * as fs from 'fs';
// Load environment variables from .env file
dotenv.config();

// Fail fast with an actionable message instead of constructing a Replicate
// client with an undefined token.
if (!process.env.REPLICATE_API_TOKEN) {
  throw new Error('REPLICATE_API_TOKEN is not set; add it to your environment or .env file.');
}

// Initialize Replicate with API token
const replicate = new Replicate({
  auth: process.env.REPLICATE_API_TOKEN,
});

// Embedding model on Replicate ("owner/name") and its pinned version hash.
// getEmbedding joins them as `${model}:${version}` when calling replicate.run.
const version = 'b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305';
const model = 'replicate/all-mpnet-base-v2';
/**
 * Generate an embedding for a single string via the Replicate model configured
 * at module level (`model` / `version`).
 *
 * The input is sent as a one-element batch in `text_batch` (a JSON-encoded
 * array of strings), and the first entry of the model output is returned.
 *
 * @param {string} text - Text to embed.
 * @returns {Promise<object>} The first output entry. Presumably an object with
 *   an `embedding` number array; `test` reads `.embedding` from it.
 * @throws {Error} If the model output is not a non-empty array.
 */
async function getEmbedding(text) {
  console.log(`Generating embedding for test query: "${text}"`);
  const input = {
    text_batch: JSON.stringify([text]),
  };
  const output = await replicate.run(`${model}:${version}`, { input });
  // Without this guard, an empty result returns `undefined` and the caller
  // fails later with an opaque "Cannot read properties of undefined".
  if (!Array.isArray(output) || output.length === 0) {
    throw new Error(`Replicate returned no embeddings for "${text}"`);
  }
  return output[0];
}
// Read the pre-computed corpus once at startup. The path is resolved against the
// current working directory; the JSON document must have a top-level
// `embeddings` property, whose entries `test` destructures as { text, embedding }.
const { embeddings } = JSON.parse(
  fs.readFileSync('data/embeddings.json', { encoding: 'utf-8' }),
);
/**
 * Demonstrate embeddings search: embed a query, score it against every stored
 * embedding with cosine similarity, and print the highest-scoring entries.
 *
 * @param {string} [prompt="How do I set a shape's color?"] - Query text.
 * @param {number} [topK=10] - Number of best matches to print.
 * @returns {Promise<Array<{text: string, similarity: number}>>} The printed
 *   matches, best first.
 */
async function test(prompt = "How do I set a shape's color?", topK = 10) {
  console.log(`Test query: "${prompt}"`);
  const inputEmbedding = await getEmbedding(prompt);

  // Score every stored embedding against the query, then order best-first.
  // `map` creates a new array, so the in-place sort never touches `embeddings`.
  const ranked = embeddings
    .map(({ text, embedding }) => ({
      text,
      similarity: cosineSimilarity(inputEmbedding.embedding, embedding),
    }))
    .sort((a, b) => b.similarity - a.similarity);

  // Display the top results (slice tolerates corpora smaller than topK).
  const top = ranked.slice(0, topK);
  console.log(`Top ${topK} Results:`);
  top.forEach((item, index) => {
    console.log(`${index + 1}: ${item.text} (Similarity: ${item.similarity.toFixed(3)})`);
  });
  return top;
}
// Functions to calculate cosine similarity

/**
 * Dot product of two numeric vectors, iterating over the indices of `vecA`.
 * @param {number[]} vecA
 * @param {number[]} vecB
 * @returns {number}
 */
function dotProduct(vecA, vecB) {
  return vecA.reduce((sum, val, i) => sum + val * vecB[i], 0);
}

/**
 * Euclidean (L2) length of a numeric vector; 0 for an empty vector.
 * @param {number[]} vec
 * @returns {number}
 */
function magnitude(vec) {
  return Math.sqrt(vec.reduce((sum, val) => sum + val * val, 0));
}

/**
 * Cosine of the angle between two vectors, in [-1, 1].
 *
 * Returns 0 when either vector has zero magnitude (the cosine is undefined
 * there). The unguarded division would yield NaN, which breaks numeric sort
 * comparators such as the one `test` uses to rank results.
 *
 * @param {number[]} vecA
 * @param {number[]} vecB
 * @returns {number}
 * @throws {RangeError} If the vectors have different lengths, which would
 *   otherwise silently produce NaN or ignore trailing components.
 */
function cosineSimilarity(vecA, vecB) {
  if (vecA.length !== vecB.length) {
    throw new RangeError(`Vector length mismatch: ${vecA.length} vs ${vecB.length}`);
  }
  const denominator = magnitude(vecA) * magnitude(vecB);
  if (denominator === 0) {
    return 0;
  }
  return dotProduct(vecA, vecB) / denominator;
}
// Run the demo. Any rejection from `test` (for example a failed Replicate
// request) is printed to stderr and reported through a non-zero exit code,
// instead of being left as a floating, unhandled promise rejection.
test().catch((err) => {
  console.error(err);
  process.exitCode = 1;
});