Merge branch 'main' of https://github.com/guywaldman/orch into develop
guywaldman committed Jul 20, 2024
2 parents a0ef90a + eb892d6 commit aa4432b
Showing 17 changed files with 926 additions and 0 deletions.
22 changes: 22 additions & 0 deletions examples/embeddings.rs
@@ -0,0 +1,22 @@
//! This example demonstrates how to use the `Executor` to generate embeddings from the LLM.
//! We construct an `Ollama` instance and use it to generate embeddings.
//!
use orch::{Executor, OllamaBuilder};

#[tokio::main]
async fn main() {
    let text = "Lorem ipsum";

    println!("Text: {text}");
    println!("---");

    let ollama = OllamaBuilder::new().build();
    let executor = Executor::new(&ollama);
    let embedding = executor
        .generate_embedding(text)
        .await
        .expect("Execution failed");

    println!("Embedding:");
    println!("{:?}", embedding);
}
24 changes: 24 additions & 0 deletions examples/text_generation.rs
@@ -0,0 +1,24 @@
//! This example demonstrates how to use the `Executor` to generate a response from the LLM.
//! We construct an `Ollama` instance and use it to generate a response.
use orch::{Executor, OllamaBuilder};

#[tokio::main]
async fn main() {
    let prompt = "What is 2+2?";
    let system_prompt = "You are a helpful assistant";

    println!("Prompt: {prompt}");
    println!("System prompt: {system_prompt}");
    println!("---");

    let ollama = OllamaBuilder::new().build();
    let executor = Executor::new(&ollama);
    let response = executor
        .text_complete(prompt, system_prompt)
        .await
        .expect("Execution failed");

    println!("Response:");
    println!("{}", response.text);
}
34 changes: 34 additions & 0 deletions examples/text_generation_stream.rs
@@ -0,0 +1,34 @@
//! This example demonstrates how to use the `Executor` to generate a streaming response from the LLM.
//! We construct an `Ollama` instance and use it to generate a streaming response.
use orch::{Executor, OllamaBuilder};
use tokio_stream::StreamExt;

#[tokio::main]
async fn main() {
    let prompt = "What is 2+2?";
    let system_prompt = "You are a helpful assistant";

    println!("Prompt: {prompt}");
    println!("System prompt: {system_prompt}");
    println!("---");

    let ollama = OllamaBuilder::new().build();
    let executor = Executor::new(&ollama);
    let mut response = executor
        .text_complete_stream(prompt, system_prompt)
        .await
        .expect("Execution failed");

    println!("Response:");
    while let Some(chunk) = response.stream.next().await {
        match chunk {
            Ok(chunk) => print!("{chunk}"),
            Err(e) => {
                println!("Error: {e}");
                break;
            }
        }
    }
    println!();
}
3 changes: 3 additions & 0 deletions src/core/mod.rs
@@ -0,0 +1,3 @@
mod net;

pub use net::*;
4 changes: 4 additions & 0 deletions src/core/net/mod.rs
@@ -0,0 +1,4 @@
/// Module for working with Server-Sent Events.
mod sse;

pub use sse::*;
28 changes: 28 additions & 0 deletions src/core/net/sse.rs
@@ -0,0 +1,28 @@
use async_gen::AsyncIter;
use reqwest::{header, Client};
use tokio_stream::Stream;

/// A client for working with Server-Sent Events.
pub struct SseClient;

impl SseClient {
    /// Sends a POST request to `url` and returns a stream of the raw response chunks,
    /// decoded as UTF-8 strings.
    pub fn post(url: &str, body: Option<String>) -> impl Stream<Item = String> {
        let client = Client::new();
        // Build the request with the headers a Server-Sent Events endpoint expects.
        let mut req = client
            .post(url)
            .header(header::ACCEPT, "text/event-stream")
            .header(header::CACHE_CONTROL, "no-cache")
            .header(header::CONNECTION, "keep-alive")
            .header(header::CONTENT_TYPE, "application/json");
        if let Some(body) = body {
            req = req.body(body);
        }
        let req = req.build().unwrap();

        // Lazily yield each chunk of the response body as it arrives.
        AsyncIter::from(async_gen::gen! {
            let mut conn = client.execute(req).await.unwrap();
            while let Some(event) = conn.chunk().await.unwrap() {
                yield std::str::from_utf8(&event).unwrap().to_owned();
            }
        })
    }
}
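For context, here is a minimal sketch of consuming the stream returned by `SseClient::post` (which is re-exported at the crate root via the `pub use` chain above). The endpoint URL and JSON body target a local Ollama server and are illustrative assumptions only.

use orch::SseClient;
use tokio_stream::StreamExt;

#[tokio::main]
async fn main() {
    // Hypothetical request body; adjust to whatever the target server expects.
    let body = Some(r#"{"model": "codestral:latest", "prompt": "What is 2+2?"}"#.to_string());
    let stream = SseClient::post("http://localhost:11434/api/generate", body);
    // Pin the stream on the stack so `StreamExt::next` can be called on it.
    tokio::pin!(stream);
    while let Some(chunk) = stream.next().await {
        print!("{chunk}");
    }
    println!();
}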
110 changes: 110 additions & 0 deletions src/executor.rs
@@ -0,0 +1,110 @@
use std::pin::Pin;

use thiserror::Error;
use tokio_stream::Stream;

use crate::{Llm, LlmError, TextCompleteOptions, TextCompleteStreamOptions};

pub struct Executor<'a, L: Llm> {
    llm: &'a L,
}

#[derive(Debug, Error)]
pub enum ExecutorError {
    #[error("LLM error: {0}")]
    Llm(LlmError),
}

impl<'a, L: Llm> Executor<'a, L> {
    /// Creates a new `Executor` instance.
    ///
    /// # Arguments
    /// * `llm` - The LLM to use for the execution.
    pub fn new(llm: &'a L) -> Self {
        Self { llm }
    }

    /// Generates a response from the LLM (non-streaming).
    ///
    /// # Arguments
    /// * `prompt` - The prompt to generate a response for.
    /// * `system_prompt` - The system prompt to use for the generation.
    ///
    /// # Returns
    /// A [Result] containing the response from the LLM or an error if there was a problem.
    pub async fn text_complete(
        &self,
        prompt: &str,
        system_prompt: &str,
    ) -> Result<ExecutorTextCompleteResponse, ExecutorError> {
        let options = TextCompleteOptions {
            ..Default::default()
        };
        let response = self
            .llm
            .text_complete(prompt, system_prompt, options)
            .await
            .map_err(ExecutorError::Llm)?;
        Ok(ExecutorTextCompleteResponse {
            text: response.text,
            context: ExecutorContext {},
        })
    }

    /// Generates a streaming response from the LLM.
    ///
    /// # Arguments
    /// * `prompt` - The prompt to generate a response for.
    /// * `system_prompt` - The system prompt to use for the generation.
    ///
    /// # Returns
    /// A [Result] containing the streaming response from the LLM or an error if there was a problem.
    pub async fn text_complete_stream(
        &self,
        prompt: &str,
        system_prompt: &str,
    ) -> Result<ExecutorTextCompleteStreamResponse, ExecutorError> {
        let options = TextCompleteStreamOptions {
            ..Default::default()
        };
        let response = self
            .llm
            .text_complete_stream(prompt, system_prompt, options)
            .await
            .map_err(ExecutorError::Llm)?;
        Ok(ExecutorTextCompleteStreamResponse {
            stream: response.stream,
            context: ExecutorContext {},
        })
    }

    /// Generates an embedding from the LLM.
    ///
    /// # Arguments
    /// * `prompt` - The text to generate an embedding for.
    ///
    /// # Returns
    /// A [Result] containing the embedding or an error if there was a problem.
    pub async fn generate_embedding(&self, prompt: &str) -> Result<Vec<f32>, ExecutorError> {
        let response = self
            .llm
            .generate_embedding(prompt)
            .await
            .map_err(ExecutorError::Llm)?;
        Ok(response)
    }
}

// TODO: Support context for completions (e.g., IDs of past conversations in Ollama).
pub struct ExecutorContext;

pub struct ExecutorTextCompleteResponse {
    pub text: String,
    pub context: ExecutorContext,
}

pub struct ExecutorTextCompleteStreamResponse {
    pub stream: Pin<Box<dyn Stream<Item = Result<String, LlmError>> + Send>>,
    pub context: ExecutorContext,
}
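As a small illustration of the error type above, here is a sketch that matches on `ExecutorError` instead of calling `.expect(...)` as the examples do; it reuses the same `Ollama` setup shown in the example files.

use orch::{Executor, ExecutorError, OllamaBuilder};

#[tokio::main]
async fn main() {
    let ollama = OllamaBuilder::new().build();
    let executor = Executor::new(&ollama);
    match executor
        .text_complete("What is 2+2?", "You are a helpful assistant")
        .await
    {
        Ok(response) => println!("{}", response.text),
        // `ExecutorError` currently wraps only `LlmError`.
        Err(ExecutorError::Llm(e)) => eprintln!("LLM error: {e}"),
    }
}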
8 changes: 8 additions & 0 deletions src/lib.rs
@@ -0,0 +1,8 @@
mod core;
mod executor;
mod llm;

// TODO: Narrow the scope of the use statements.
pub use core::*;
pub use executor::*;
pub use llm::*;
51 changes: 51 additions & 0 deletions src/llm/error.rs
@@ -0,0 +1,51 @@
use thiserror::Error;

use crate::{LlmProvider, OllamaError};

#[derive(Debug, Error)]
pub enum LlmProviderError {
    #[error("Invalid LLM provider: {0}")]
    InvalidValue(String),
}

impl std::fmt::Display for LlmProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            LlmProvider::Ollama => write!(f, "ollama"),
            LlmProvider::OpenAi => write!(f, "openai"),
        }
    }
}

impl Default for LlmProvider {
    fn default() -> Self {
        Self::Ollama
    }
}

impl TryFrom<&str> for LlmProvider {
    type Error = LlmProviderError;

    fn try_from(value: &str) -> Result<Self, Self::Error> {
        match value {
            "ollama" => Ok(LlmProvider::Ollama),
            "openai" => Ok(LlmProvider::OpenAi),
            _ => Err(LlmProviderError::InvalidValue(value.to_string())),
        }
    }
}

#[derive(Debug, Error)]
pub enum LlmError {
    #[error("Text generation error: {0}")]
    TextGeneration(String),

    #[error("Embedding generation error: {0}")]
    EmbeddingGeneration(String),

    #[error("Configuration error: {0}")]
    Configuration(String),

    #[error("Ollama error: {0}")]
    Ollama(#[from] OllamaError),
}
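A short sketch of the `TryFrom<&str>` and `Default` impls above in use, e.g. when reading a provider name from a CLI flag or config value; it assumes `LlmProvider` and `LlmProviderError` are re-exported at the crate root.

use orch::{LlmProvider, LlmProviderError};

fn provider_from_flag(value: &str) -> LlmProvider {
    match LlmProvider::try_from(value) {
        Ok(provider) => provider,
        Err(LlmProviderError::InvalidValue(v)) => {
            eprintln!("Unknown provider '{v}', falling back to the default");
            LlmProvider::default()
        }
    }
}

fn main() {
    // Prints "ollama" (the `Default` variant) for an unrecognized input.
    println!("{}", provider_from_flag("something-else"));
}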
4 changes: 4 additions & 0 deletions src/llm/llm_provider/mod.rs
@@ -0,0 +1,4 @@
mod ollama;
mod openai;

pub use ollama::*;
18 changes: 18 additions & 0 deletions src/llm/llm_provider/ollama/config.rs
@@ -0,0 +1,18 @@
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaConfig {
    pub base_url: Option<String>,
    pub model: Option<String>,
    pub embedding_model: Option<String>,
}

impl Default for OllamaConfig {
    fn default() -> Self {
        Self {
            base_url: Some("http://localhost:11434".to_string()),
            model: Some("codestral:latest".to_string()),
            embedding_model: Some("nomic-embed-text:latest".to_string()),
        }
    }
}
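A minimal sketch of overriding a single `OllamaConfig` field while keeping the other defaults via struct update syntax; it assumes `OllamaConfig` is re-exported at the crate root, and the model name is an arbitrary example.

use orch::OllamaConfig;

fn main() {
    let config = OllamaConfig {
        model: Some("llama3:latest".to_string()),
        ..Default::default()
    };
    // `base_url` and `embedding_model` keep their default values.
    println!("{config:?}");
}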