-
Notifications
You must be signed in to change notification settings - Fork 91
/
vector_search_cohere.rs
80 lines (70 loc) · 2.85 KB
/
vector_search_cohere.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
use std::env;
use rig::{
embeddings::EmbeddingsBuilder,
providers::cohere::{Client, EMBED_ENGLISH_V3},
vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex},
Embed,
};
use serde::{Deserialize, Serialize};
/// A dictionary entry to be indexed for retrieval-augmented generation (RAG).
///
/// Only `definitions` carries `#[embed]`, so it is the field whose text is sent
/// to the embedding model. `id` and `word` travel along as stored metadata and
/// come back with search results.
#[derive(Embed, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)]
struct WordDefinition {
    /// Identifier for the entry; the example uses it as the vector-store document id.
    id: String,
    /// The word being defined.
    word: String,
    /// Definitions of `word`; these strings are the embedded content
    /// (presumably one embedding per entry — confirm against rig's `Embed` impl for `Vec`).
    #[embed]
    definitions: Vec<String>,
}
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    // Report a missing API key as an ordinary error instead of panicking, so it
    // flows out through `main`'s `Result` like every other failure below.
    let cohere_api_key = env::var("COHERE_API_KEY")
        .map_err(|err| anyhow::anyhow!("COHERE_API_KEY not set: {err}"))?;
    let cohere_client = Client::new(&cohere_api_key);

    // Two models over the same Cohere engine, differing only in input type:
    // stored documents use "search_document", the query uses "search_query".
    let document_model = cohere_client.embedding_model(EMBED_ENGLISH_V3, "search_document");
    let search_model = cohere_client.embedding_model(EMBED_ENGLISH_V3, "search_query");

    // The builder takes ownership of the document model, which is not needed
    // afterwards, so it is moved in rather than cloned.
    let embeddings = EmbeddingsBuilder::new(document_model)
        .documents(sample_word_definitions())?
        .build()
        .await?;

    // Key each stored document by its own `id` field.
    let vector_store =
        InMemoryVectorStore::from_documents_with_id_f(embeddings, |doc| doc.id.clone());

    // Queries against this index are embedded with the query-side model.
    let index = vector_store.index(search_model);

    let query = "Which instrument is found in the Nebulon Mountain Ranges?";
    let results = index
        .top_n::<WordDefinition>(query, 1)
        .await?
        .into_iter()
        // Keep the score and id, but only the matched word from each document.
        .map(|(score, id, doc)| (score, id, doc.word))
        .collect::<Vec<_>>();

    println!("Results: {:?}", results);
    Ok(())
}

/// Builds the small, fixed corpus of made-up words that this example indexes.
fn sample_word_definitions() -> Vec<WordDefinition> {
    vec![
        WordDefinition {
            id: "doc0".to_string(),
            word: "flurbo".to_string(),
            definitions: vec![
                "A green alien that lives on cold planets.".to_string(),
                "A fictional digital currency that originated in the animated series Rick and Morty.".to_string(),
            ],
        },
        WordDefinition {
            id: "doc1".to_string(),
            word: "glarb-glarb".to_string(),
            definitions: vec![
                "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(),
                "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string(),
            ],
        },
        WordDefinition {
            id: "doc2".to_string(),
            word: "linglingdong".to_string(),
            definitions: vec![
                "A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(),
                "A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string(),
            ],
        },
    ]
}