Skip to content

Commit

Permalink
feat: Concurrent Embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
Butch78 committed Nov 1, 2023
1 parent 54743b0 commit 78b8e8a
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 7 deletions.
1 change: 1 addition & 0 deletions orca/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ candle-nn = { git = "https://github.com/huggingface/candle" }
tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.17"
log = "0.4.20"
futures = "0.3.29"

[dev-dependencies]
base64 = "0.21.4"
Expand Down
28 changes: 21 additions & 7 deletions orca/src/llm/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::{
prompt::{chat::Message, Prompt},
};
use anyhow::Result;
use futures::{stream::FuturesUnordered, TryFutureExt, TryStreamExt};
use reqwest::Client;
use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -273,25 +274,30 @@ impl EmbeddingTrait for OpenAI {
Ok(res.into())
}

/// TODO: Concurrent
async fn generate_embeddings(&self, prompt: Vec<Box<dyn Prompt>>) -> Result<EmbeddingResponse> {
async fn generate_embeddings(&self, prompts: Vec<Box<dyn Prompt>>) -> Result<EmbeddingResponse> {
let mut embeddings = Vec::new();
for prompt in prompt {
let mut futures = FuturesUnordered::new();

for prompt in prompts {
let req = self.generate_embedding_request(&prompt.to_string())?;
let res = self.client.execute(req).await?;
let res = res.json::<OpenAIEmbeddingResponse>().await?;
let fut = self.client.execute(req).and_then(|res| res.json::<OpenAIEmbeddingResponse>());
futures.push(fut);
}

while let Some(res) = futures.try_next().await? {
embeddings.push(res);
}

Ok(EmbeddingResponse::OpenAI(embeddings))
}
}

#[cfg(test)]
mod test {
use super::*;
use crate::prompt;
use crate::prompt::TemplateEngine;
use crate::template;
use crate::{prompt, prompts};
use std::collections::HashMap;

#[tokio::test]
Expand Down Expand Up @@ -322,10 +328,18 @@ mod test {
}

#[tokio::test]
async fn test_embeddings() {
async fn test_embedding() {
let client = OpenAI::new();
let content = prompt!("This is a test");
let res = client.generate_embedding(content).await.unwrap();
assert!(res.to_vec2().unwrap().len() > 0);
}

#[tokio::test]
async fn test_embeddings() {
    // Batch path: three prompts embedded concurrently through
    // `generate_embeddings`, hitting the live OpenAI API.
    let client = OpenAI::new();
    let batch = prompts!("This is a test", "This is another test", "This is a third test");
    let response = client.generate_embeddings(batch).await.unwrap();
    // The flattened result must hold at least one embedding vector.
    assert!(!response.to_vec2().unwrap().is_empty());
}
}

0 comments on commit 78b8e8a

Please sign in to comment.