-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' of https://github.com/guywaldman/orch into develop
- Loading branch information
Showing
17 changed files
with
926 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
//! This example demonstrates how to use the `Executor` to generate embeddings from the LLM. | ||
//! We construct an `Ollama` instance and use it to generate embeddings. | ||
//! | ||
use orch::{Executor, OllamaBuilder}; | ||
|
||
#[tokio::main] | ||
async fn main() { | ||
let text = "Lorem ipsum"; | ||
|
||
println!("Text: {text}"); | ||
println!("---"); | ||
|
||
let ollama = OllamaBuilder::new().build(); | ||
let executor = Executor::new(&ollama); | ||
let embedding = executor | ||
.generate_embedding(text) | ||
.await | ||
.expect("Execution failed"); | ||
|
||
println!("Embedding:"); | ||
println!("{:?}", embedding); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
//! This example demonstrates how to use the `Executor` to generate a response from the LLM. | ||
//! We construct an `Ollama` instance and use it to generate a response. | ||
use orch::{Executor, OllamaBuilder}; | ||
|
||
#[tokio::main] | ||
async fn main() { | ||
let prompt = "What is 2+2?"; | ||
let system_prompt = "You are a helpful assistant"; | ||
|
||
println!("Prompt: {prompt}"); | ||
println!("System prompt: {system_prompt}"); | ||
println!("---"); | ||
|
||
let ollama = OllamaBuilder::new().build(); | ||
let executor = Executor::new(&ollama); | ||
let response = executor | ||
.text_complete(prompt, system_prompt) | ||
.await | ||
.expect("Execution failed"); | ||
|
||
println!("Response:"); | ||
println!("{}", response.text); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
//! This example demonstrates how to use the `Executor` to generate a streaming response from the LLM. | ||
//! We construct an `Ollama` instance and use it to generate a streaming response. | ||
use orch::{Executor, OllamaBuilder}; | ||
use tokio_stream::StreamExt; | ||
|
||
#[tokio::main] | ||
async fn main() { | ||
let prompt = "What is 2+2?"; | ||
let system_prompt = "You are a helpful assistant"; | ||
|
||
println!("Prompt: {prompt}"); | ||
println!("System prompt: {system_prompt}"); | ||
println!("---"); | ||
|
||
let ollama = OllamaBuilder::new().build(); | ||
let executor = Executor::new(&ollama); | ||
let mut response = executor | ||
.text_complete_stream(prompt, system_prompt) | ||
.await | ||
.expect("Execution failed"); | ||
|
||
println!("Response:"); | ||
while let Some(chunk) = response.stream.next().await { | ||
match chunk { | ||
Ok(chunk) => print!("{chunk}"), | ||
Err(e) => { | ||
println!("Error: {e}"); | ||
break; | ||
} | ||
} | ||
} | ||
println!(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
/// Networking utilities module.
mod net;

// Re-export the module contents at this level.
pub use net::*;
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
/// Module for working with Server-Sent Events.
mod sse;

// Re-export the SSE types at this level.
pub use sse::*;
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
use async_gen::AsyncIter; | ||
use reqwest::{header, Client}; | ||
use tokio_stream::Stream; | ||
|
||
/// A client for working with Server-Sent Events. | ||
pub struct SseClient; | ||
|
||
impl SseClient { | ||
pub fn post(url: &str, body: Option<String>) -> impl Stream<Item = String> { | ||
let client = Client::new(); | ||
let mut req = Client::post(&client, url) | ||
.header(header::ACCEPT, "text/event-stream") | ||
.header(header::CACHE_CONTROL, "no-cache") | ||
.header(header::CONNECTION, "keep-alive") | ||
.header(header::CONTENT_TYPE, "application/json"); | ||
if let Some(body) = body { | ||
req = req.body(body); | ||
} | ||
let req = req.build().unwrap(); | ||
|
||
AsyncIter::from(async_gen::gen! { | ||
let mut conn = client.execute(req).await.unwrap(); | ||
while let Some(event) = conn.chunk().await.unwrap() { | ||
yield std::str::from_utf8(&event).unwrap().to_owned(); | ||
} | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
use std::pin::Pin; | ||
|
||
use thiserror::Error; | ||
use tokio_stream::Stream; | ||
|
||
use crate::{Llm, LlmError, TextCompleteOptions, TextCompleteStreamOptions}; | ||
|
||
/// Drives LLM operations (text completion, streaming completion, embedding
/// generation) against a borrowed LLM implementation.
///
/// The `L: Llm` bound is intentionally left off the struct definition and
/// stated only on the `impl` block, per Rust convention (bounds on struct
/// definitions are discouraged); this is a backward-compatible loosening.
pub struct Executor<'a, L> {
    llm: &'a L,
}
|
||
#[derive(Debug, Error)] | ||
pub enum ExecutorError { | ||
#[error("LLM error: {0}")] | ||
Llm(LlmError), | ||
} | ||
|
||
impl<'a, L: Llm> Executor<'a, L> { | ||
/// Creates a new `Executor` instance. | ||
/// | ||
/// # Arguments | ||
/// * `llm` - The LLM to use for the execution. | ||
pub fn new(llm: &'a L) -> Self { | ||
Self { llm } | ||
} | ||
|
||
/// Generates a response from the LLM (non-streaming). | ||
/// | ||
/// # Arguments | ||
/// * `prompt` - The prompt to generate a response for. | ||
/// * `system_prompt` - The system prompt to use for the generation. | ||
/// | ||
/// # Returns | ||
/// A [Result] containing the response from the LLM or an error if there was a problem. | ||
pub async fn text_complete( | ||
&self, | ||
prompt: &str, | ||
system_prompt: &str, | ||
) -> Result<ExecutorTextCompleteResponse, ExecutorError> { | ||
let options = TextCompleteOptions { | ||
..Default::default() | ||
}; | ||
let response = self | ||
.llm | ||
.text_complete(prompt, system_prompt, options) | ||
.await | ||
.map_err(ExecutorError::Llm)?; | ||
Ok(ExecutorTextCompleteResponse { | ||
text: response.text, | ||
context: ExecutorContext {}, | ||
}) | ||
} | ||
|
||
/// Generates a streaming response from the LLM. | ||
/// | ||
/// # Arguments | ||
/// * `prompt` - The prompt to generate a response for. | ||
/// * `system_prompt` - The system prompt to use for the generation. | ||
/// | ||
/// # Returns | ||
/// A [Result] containing the response from the LLM or an error if there was a problem. | ||
pub async fn text_complete_stream( | ||
&self, | ||
prompt: &str, | ||
system_prompt: &str, | ||
) -> Result<ExecutorTextCompleteStreamResponse, ExecutorError> { | ||
let options = TextCompleteStreamOptions { | ||
..Default::default() | ||
}; | ||
let response = self | ||
.llm | ||
.text_complete_stream(prompt, system_prompt, options) | ||
.await | ||
.map_err(ExecutorError::Llm)?; | ||
Ok(ExecutorTextCompleteStreamResponse { | ||
stream: response.stream, | ||
context: ExecutorContext {}, | ||
}) | ||
} | ||
|
||
/// Generates an embedding from the LLM. | ||
/// | ||
/// # Arguments | ||
/// * `prompt` - The item to generate an embedding for. | ||
/// | ||
/// # Returns | ||
/// | ||
/// A [Result] containing the embedding or an error if there was a problem. | ||
pub async fn generate_embedding(&self, prompt: &str) -> Result<Vec<f32>, ExecutorError> { | ||
let response = self | ||
.llm | ||
.generate_embedding(prompt) | ||
.await | ||
.map_err(ExecutorError::Llm)?; | ||
Ok(response) | ||
} | ||
} | ||
|
||
// TODO: Support context for completions (e.g., IDs of past conversations in Ollama).
/// Placeholder context returned alongside executor responses.
///
/// Public types should carry at least `Debug`; the additional derives are
/// free and backward-compatible on a unit struct.
#[derive(Debug, Default, Clone, Copy)]
pub struct ExecutorContext;
|
||
/// Result of a non-streaming text completion (see `Executor::text_complete`).
pub struct ExecutorTextCompleteResponse {
    /// The full response text from the LLM.
    pub text: String,
    /// Execution context associated with this response.
    pub context: ExecutorContext,
}
|
||
/// Result of a streaming text completion (see `Executor::text_complete_stream`).
pub struct ExecutorTextCompleteStreamResponse {
    /// Stream of response chunks; each item is either a text fragment or an LLM error.
    pub stream: Pin<Box<dyn Stream<Item = Result<String, LlmError>> + Send>>,
    /// Execution context associated with this response.
    pub context: ExecutorContext,
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
// Crate module layout: core utilities, the executor, and LLM providers.
mod core;
mod executor;
mod llm;

// TODO: Narrow the scope of the use statements.
pub use core::*;
pub use executor::*;
pub use llm::*;
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
use thiserror::Error; | ||
|
||
use crate::{LlmProvider, OllamaError}; | ||
|
||
/// Errors related to selecting an LLM provider.
#[derive(Debug, Error)]
pub enum LlmProviderError {
    /// The given string does not name a known provider.
    #[error("Invalid LLM provider: {0}")]
    InvalidValue(String),
}
|
||
impl std::fmt::Display for LlmProvider { | ||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||
match self { | ||
LlmProvider::Ollama => write!(f, "ollama"), | ||
LlmProvider::OpenAi => write!(f, "openai"), | ||
} | ||
} | ||
} | ||
|
||
impl Default for LlmProvider {
    /// Ollama is the default provider.
    fn default() -> Self {
        Self::Ollama
    }
}
|
||
impl TryFrom<&str> for LlmProvider { | ||
type Error = LlmProviderError; | ||
|
||
fn try_from(value: &str) -> Result<Self, Self::Error> { | ||
match value { | ||
"ollama" => Ok(LlmProvider::Ollama), | ||
"openai" => Ok(LlmProvider::OpenAi), | ||
_ => Err(LlmProviderError::InvalidValue(value.to_string())), | ||
} | ||
} | ||
} | ||
|
||
/// Errors produced by LLM operations.
#[derive(Debug, Error)]
pub enum LlmError {
    /// Text generation (completion) failed.
    #[error("Text generation error: {0}")]
    TextGeneration(String),

    /// Embedding generation failed.
    #[error("Embedding generation error: {0}")]
    EmbeddingGeneration(String),

    /// Invalid or missing configuration.
    #[error("Configuration error: {0}")]
    Configuration(String),

    /// Error from the Ollama provider, converted automatically via `#[from]`.
    #[error("Ollama error: {0}")]
    Ollama(#[from] OllamaError),
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
mod ollama;
// NOTE(review): `openai` is declared but not re-exported below — presumably
// work in progress; confirm before exposing it publicly.
mod openai;

pub use ollama::*;
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
use serde::{Deserialize, Serialize}; | ||
|
||
/// Configuration for the Ollama provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaConfig {
    /// Base URL of the Ollama server (defaults to `http://localhost:11434`).
    pub base_url: Option<String>,
    /// Model name (defaults to `codestral:latest`).
    pub model: Option<String>,
    /// Embedding model name (defaults to `nomic-embed-text:latest`).
    pub embedding_model: Option<String>,
}
|
||
impl Default for OllamaConfig { | ||
fn default() -> Self { | ||
Self { | ||
base_url: Some("http://localhost:11434".to_string()), | ||
model: Some("codestral:latest".to_string()), | ||
embedding_model: Some("nomic-embed-text:latest".to_string()), | ||
} | ||
} | ||
} |
Oops, something went wrong.