From 78b8e8ab0388f6ec6fbc4717430afbfe1d746050 Mon Sep 17 00:00:00 2001 From: Butch78 Date: Wed, 1 Nov 2023 18:23:09 +0100 Subject: [PATCH] feat: Concurrent Embeddings --- orca/Cargo.toml | 1 + orca/src/llm/openai.rs | 28 +++++++++++++++++++++------- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/orca/Cargo.toml b/orca/Cargo.toml index 2418eb5..dc846ba 100644 --- a/orca/Cargo.toml +++ b/orca/Cargo.toml @@ -36,6 +36,7 @@ candle-nn = { git = "https://github.com/huggingface/candle" } tracing-chrome = "0.7.1" tracing-subscriber = "0.3.17" log = "0.4.20" +futures = "0.3.29" [dev-dependencies] base64 = "0.21.4" diff --git a/orca/src/llm/openai.rs b/orca/src/llm/openai.rs index 2013cc1..c680e4d 100644 --- a/orca/src/llm/openai.rs +++ b/orca/src/llm/openai.rs @@ -5,6 +5,7 @@ use crate::{ prompt::{chat::Message, Prompt}, }; use anyhow::Result; +use futures::{stream::FuturesUnordered, TryFutureExt, TryStreamExt}; use reqwest::Client; use serde::{Deserialize, Serialize}; @@ -273,15 +274,20 @@ impl EmbeddingTrait for OpenAI { Ok(res.into()) } - /// TODO: Concurrent - async fn generate_embeddings(&self, prompt: Vec<Box<dyn Prompt>>) -> Result<EmbeddingResponse> { + async fn generate_embeddings(&self, prompts: Vec<Box<dyn Prompt>>) -> Result<EmbeddingResponse> { let mut embeddings = Vec::new(); - for prompt in prompt { + let mut futures = FuturesUnordered::new(); + + for prompt in prompts { let req = self.generate_embedding_request(&prompt.to_string())?; - let res = self.client.execute(req).await?; - let res = res.json::<OpenAIEmbeddingResponse>().await?; + let fut = self.client.execute(req).and_then(|res| res.json::<OpenAIEmbeddingResponse>()); + futures.push(fut); + } + + while let Some(res) = futures.try_next().await? 
{ embeddings.push(res); } + Ok(EmbeddingResponse::OpenAI(embeddings)) } } @@ -289,9 +295,9 @@ impl EmbeddingTrait for OpenAI { #[cfg(test)] mod test { use super::*; - use crate::prompt; use crate::prompt::TemplateEngine; use crate::template; + use crate::{prompt, prompts}; use std::collections::HashMap; #[tokio::test] @@ -322,10 +328,18 @@ mod test { } #[tokio::test] - async fn test_embeddings() { + async fn test_embedding() { let client = OpenAI::new(); let content = prompt!("This is a test"); let res = client.generate_embedding(content).await.unwrap(); assert!(res.to_vec2().unwrap().len() > 0); } + + #[tokio::test] + async fn test_embeddings() { + let client = OpenAI::new(); + let content = prompts!("This is a test", "This is another test", "This is a third test"); + let res = client.generate_embeddings(content).await.unwrap(); + assert!(res.to_vec2().unwrap().len() > 0); + } }