-
Notifications
You must be signed in to change notification settings - Fork 11
/
example.js
64 lines (55 loc) · 2.08 KB
/
example.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import { pipeline } from '@xenova/transformers';
import pg from 'pg';
import pgvector from 'pgvector/pg';
// Open a connection to the local `pgvector_example` database.
const client = new pg.Client({ database: 'pgvector_example' });
await client.connect();

// Enable the vector extension first, then register pgvector's types on this
// client (same order as before; presumably registration needs the type to exist).
await client.query('CREATE EXTENSION IF NOT EXISTS vector');
await pgvector.registerTypes(client);

// Rebuild the `documents` table from scratch on every run. The GIN index is
// built on the same to_tsvector('english', content) expression that the
// keyword half of the search query filters on.
const schemaStatements = [
  'DROP TABLE IF EXISTS documents',
  'CREATE TABLE documents (id bigserial PRIMARY KEY, content text, embedding vector(384))',
  "CREATE INDEX ON documents USING GIN (to_tsvector('english', content))",
];
for (const statement of schemaStatements) {
  await client.query(statement);
}

// Sample sentences that get embedded and inserted into `documents`.
const input = [
  'The dog is barking',
  'The cat is purring',
  'The bear is growling',
];

// Feature-extraction pipeline used to embed both the documents and the query.
// NOTE(review): the column is declared vector(384); presumably that matches
// this model's output size — confirm if the model is changed.
const extractor = await pipeline('feature-extraction', 'Xenova/multi-qa-MiniLM-L6-cos-v1');
/**
 * Embed a piece of text with the module-level `extractor` pipeline.
 *
 * The extractor is asked for mean pooling and normalization; the `data` of
 * its result is then copied into a plain array.
 *
 * @param {string} content - Text to embed; forwarded to `extractor` unchanged.
 * @returns {Promise<Array>} Plain array holding the values of the result's `data`.
 */
async function generateEmbedding(content) {
  const extractorOptions = { pooling: 'mean', normalize: true };
  const { data } = await extractor(content, extractorOptions);
  return Array.from(data);
}
// Embed each sample sentence and store it next to its vector. Rows are
// inserted one at a time, in `input` order.
for (const text of input) {
  const vector = await generateEmbedding(text);
  const params = [text, pgvector.toSql(vector)];
  await client.query('INSERT INTO documents (content, embedding) VALUES ($1, $2)', params);
}
// Hybrid search query. Parameters:
//   $1 - the raw query text (turned into a tsquery by plainto_tsquery)
//   $2 - the query embedding (compared with pgvector's `<=>` distance operator)
//   $3 - the constant added to each rank before taking its reciprocal
//
// `semantic_search` ranks the 20 nearest documents by embedding distance and
// `keyword_search` ranks the 20 best full-text matches by ts_rank_cd. The two
// are combined with a FULL OUTER JOIN so a document found by only one of them
// still scores; a missing side contributes 0.0. The final score is the sum of
// 1.0 / ($3 + rank) over both sides, and the top 5 are returned.
//
// Both CTEs name their RANK() column `rank` explicitly: the outer SELECT
// reads `keyword_search.rank`, so it should not depend on PostgreSQL's
// implicit naming of an unaliased function column.
const sql = `
WITH semantic_search AS (
SELECT id, RANK () OVER (ORDER BY embedding <=> $2) AS rank
FROM documents
ORDER BY embedding <=> $2
LIMIT 20
),
keyword_search AS (
SELECT id, RANK () OVER (ORDER BY ts_rank_cd(to_tsvector('english', content), query) DESC) AS rank
FROM documents, plainto_tsquery('english', $1) query
WHERE to_tsvector('english', content) @@ query
ORDER BY ts_rank_cd(to_tsvector('english', content), query) DESC
LIMIT 20
)
SELECT
COALESCE(semantic_search.id, keyword_search.id) AS id,
COALESCE(1.0 / ($3 + semantic_search.rank), 0.0) +
COALESCE(1.0 / ($3 + keyword_search.rank), 0.0) AS score
FROM semantic_search
FULL OUTER JOIN keyword_search ON semantic_search.id = keyword_search.id
ORDER BY score DESC
LIMIT 5
`;
// Run the hybrid search for a sample query and print every returned row.
const query = 'growling bear';
// Bound to $3 in `sql`: the constant added to each rank in the score formula.
const k = 60;
const embedding = await generateEmbedding(query);
const searchParams = [query, pgvector.toSql(embedding), k];
const { rows } = await client.query(sql, searchParams);
rows.forEach((row) => {
  console.log(row);
});

// Close the database connection once the results have been printed.
await client.end();