diff --git a/examples/embeddings.rs b/examples/embeddings.rs
new file mode 100644
index 0000000..d523950
--- /dev/null
+++ b/examples/embeddings.rs
@@ -0,0 +1,22 @@
+//! This example demonstrates how to use the `Executor` to generate embeddings from the LLM.
+//! We construct an `Ollama` instance and use it to generate embeddings.
+//!
+use orch::{Executor, OllamaBuilder};
+
+#[tokio::main]
+async fn main() {
+    let text = "Lorem ipsum";
+
+    println!("Text: {text}");
+    println!("---");
+
+    let ollama = OllamaBuilder::new().build();
+    let executor = Executor::new(&ollama);
+    let embedding = executor
+        .generate_embedding(text)
+        .await
+        .expect("Execution failed");
+
+    println!("Embedding:");
+    println!("{:?}", embedding);
+}
diff --git a/examples/text_generation.rs b/examples/text_generation.rs
new file mode 100644
index 0000000..5848b35
--- /dev/null
+++ b/examples/text_generation.rs
@@ -0,0 +1,24 @@
+//! This example demonstrates how to use the `Executor` to generate a response from the LLM.
+//! We construct an `Ollama` instance and use it to generate a response.
+
+use orch::{Executor, OllamaBuilder};
+
+#[tokio::main]
+async fn main() {
+    let prompt = "What is 2+2?";
+    let system_prompt = "You are a helpful assistant";
+
+    println!("Prompt: {prompt}");
+    println!("System prompt: {system_prompt}");
+    println!("---");
+
+    let ollama = OllamaBuilder::new().build();
+    let executor = Executor::new(&ollama);
+    let response = executor
+        .text_complete(prompt, system_prompt)
+        .await
+        .expect("Execution failed");
+
+    println!("Response:");
+    println!("{}", response.text);
+}
diff --git a/examples/text_generation_stream.rs b/examples/text_generation_stream.rs
new file mode 100644
index 0000000..264b059
--- /dev/null
+++ b/examples/text_generation_stream.rs
@@ -0,0 +1,34 @@
+//! This example demonstrates how to use the `Executor` to generate a streaming response from the LLM.
+//! We construct an `Ollama` instance and use it to generate a streaming response.
+
+use orch::{Executor, OllamaBuilder};
+use tokio_stream::StreamExt;
+
+#[tokio::main]
+async fn main() {
+    let prompt = "What is 2+2?";
+    let system_prompt = "You are a helpful assistant";
+
+    println!("Prompt: {prompt}");
+    println!("System prompt: {system_prompt}");
+    println!("---");
+
+    let ollama = OllamaBuilder::new().build();
+    let executor = Executor::new(&ollama);
+    let mut response = executor
+        .text_complete_stream(prompt, system_prompt)
+        .await
+        .expect("Execution failed");
+
+    println!("Response:");
+    while let Some(chunk) = response.stream.next().await {
+        match chunk {
+            Ok(chunk) => print!("{chunk}"),
+            Err(e) => {
+                println!("Error: {e}");
+                break;
+            }
+        }
+    }
+    println!();
+}
diff --git a/src/core/mod.rs b/src/core/mod.rs
new file mode 100644
index 0000000..6aada65
--- /dev/null
+++ b/src/core/mod.rs
@@ -0,0 +1,3 @@
+mod net;
+
+pub use net::*;
diff --git a/src/core/net/mod.rs b/src/core/net/mod.rs
new file mode 100644
index 0000000..4ca4b1c
--- /dev/null
+++ b/src/core/net/mod.rs
@@ -0,0 +1,4 @@
+/// Module for working with Server-Sent Events.
+mod sse;
+
+pub use sse::*;
diff --git a/src/core/net/sse.rs b/src/core/net/sse.rs
new file mode 100644
index 0000000..5f7f259
--- /dev/null
+++ b/src/core/net/sse.rs
@@ -0,0 +1,28 @@
+use async_gen::AsyncIter;
+use reqwest::{header, Client};
+use tokio_stream::Stream;
+
+/// A client for working with Server-Sent Events.
+pub struct SseClient;
+
+impl SseClient {
+    pub fn post(url: &str, body: Option<String>) -> impl Stream<Item = String> {
+        let client = Client::new();
+        let mut req = Client::post(&client, url)
+            .header(header::ACCEPT, "text/event-stream")
+            .header(header::CACHE_CONTROL, "no-cache")
+            .header(header::CONNECTION, "keep-alive")
+            .header(header::CONTENT_TYPE, "application/json");
+        if let Some(body) = body {
+            req = req.body(body);
+        }
+        let req = req.build().unwrap();
+
+        AsyncIter::from(async_gen::gen! {
+            let mut conn = client.execute(req).await.unwrap();
+            while let Some(event) = conn.chunk().await.unwrap() {
+                yield std::str::from_utf8(&event).unwrap().to_owned();
+            }
+        })
+    }
+}
diff --git a/src/executor.rs b/src/executor.rs
new file mode 100644
index 0000000..1bc66a7
--- /dev/null
+++ b/src/executor.rs
@@ -0,0 +1,110 @@
+use std::pin::Pin;
+
+use thiserror::Error;
+use tokio_stream::Stream;
+
+use crate::{Llm, LlmError, TextCompleteOptions, TextCompleteStreamOptions};
+
+pub struct Executor<'a, L: Llm> {
+    llm: &'a L,
+}
+
+#[derive(Debug, Error)]
+pub enum ExecutorError {
+    #[error("LLM error: {0}")]
+    Llm(LlmError),
+}
+
+impl<'a, L: Llm> Executor<'a, L> {
+    /// Creates a new `Executor` instance.
+    ///
+    /// # Arguments
+    /// * `llm` - The LLM to use for the execution.
+    pub fn new(llm: &'a L) -> Self {
+        Self { llm }
+    }
+
+    /// Generates a response from the LLM (non-streaming).
+    ///
+    /// # Arguments
+    /// * `prompt` - The prompt to generate a response for.
+    /// * `system_prompt` - The system prompt to use for the generation.
+    ///
+    /// # Returns
+    /// A [Result] containing the response from the LLM or an error if there was a problem.
+    pub async fn text_complete(
+        &self,
+        prompt: &str,
+        system_prompt: &str,
+    ) -> Result<ExecutorTextCompleteResponse, ExecutorError> {
+        let options = TextCompleteOptions {
+            ..Default::default()
+        };
+        let response = self
+            .llm
+            .text_complete(prompt, system_prompt, options)
+            .await
+            .map_err(ExecutorError::Llm)?;
+        Ok(ExecutorTextCompleteResponse {
+            text: response.text,
+            context: ExecutorContext {},
+        })
+    }
+
+    /// Generates a streaming response from the LLM.
+    ///
+    /// # Arguments
+    /// * `prompt` - The prompt to generate a response for.
+    /// * `system_prompt` - The system prompt to use for the generation.
+    ///
+    /// # Returns
+    /// A [Result] containing the response from the LLM or an error if there was a problem.
+    pub async fn text_complete_stream(
+        &self,
+        prompt: &str,
+        system_prompt: &str,
+    ) -> Result<ExecutorTextCompleteStreamResponse, ExecutorError> {
+        let options = TextCompleteStreamOptions {
+            ..Default::default()
+        };
+        let response = self
+            .llm
+            .text_complete_stream(prompt, system_prompt, options)
+            .await
+            .map_err(ExecutorError::Llm)?;
+        Ok(ExecutorTextCompleteStreamResponse {
+            stream: response.stream,
+            context: ExecutorContext {},
+        })
+    }
+
+    /// Generates an embedding from the LLM.
+    ///
+    /// # Arguments
+    /// * `prompt` - The item to generate an embedding for.
+    ///
+    /// # Returns
+    ///
+    /// A [Result] containing the embedding or an error if there was a problem.
+    pub async fn generate_embedding(&self, prompt: &str) -> Result<Vec<f32>, ExecutorError> {
+        let response = self
+            .llm
+            .generate_embedding(prompt)
+            .await
+            .map_err(ExecutorError::Llm)?;
+        Ok(response)
+    }
+}
+
+// TODO: Support context for completions (e.g., IDs of past conversations in Ollama).
+pub struct ExecutorContext;
+
+pub struct ExecutorTextCompleteResponse {
+    pub text: String,
+    pub context: ExecutorContext,
+}
+
+pub struct ExecutorTextCompleteStreamResponse {
+    pub stream: Pin<Box<dyn Stream<Item = Result<String, LlmError>> + Send>>,
+    pub context: ExecutorContext,
+}
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..c272657
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,8 @@
+mod core;
+mod executor;
+mod llm;
+
+// TODO: Narrow the scope of the use statements.
+pub use core::*;
+pub use executor::*;
+pub use llm::*;
diff --git a/src/llm/error.rs b/src/llm/error.rs
new file mode 100644
index 0000000..7f068d2
--- /dev/null
+++ b/src/llm/error.rs
@@ -0,0 +1,51 @@
+use thiserror::Error;
+
+use crate::{LlmProvider, OllamaError};
+
+#[derive(Debug, Error)]
+pub enum LlmProviderError {
+    #[error("Invalid LLM provider: {0}")]
+    InvalidValue(String),
+}
+
+impl std::fmt::Display for LlmProvider {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            LlmProvider::Ollama => write!(f, "ollama"),
+            LlmProvider::OpenAi => write!(f, "openai"),
+        }
+    }
+}
+
+impl Default for LlmProvider {
+    fn default() -> Self {
+        Self::Ollama
+    }
+}
+
+impl TryFrom<&str> for LlmProvider {
+    type Error = LlmProviderError;
+
+    fn try_from(value: &str) -> Result<Self, Self::Error> {
+        match value {
+            "ollama" => Ok(LlmProvider::Ollama),
+            "openai" => Ok(LlmProvider::OpenAi),
+            _ => Err(LlmProviderError::InvalidValue(value.to_string())),
+        }
+    }
+}
+
+#[derive(Debug, Error)]
+pub enum LlmError {
+    #[error("Text generation error: {0}")]
+    TextGeneration(String),
+
+    #[error("Embedding generation error: {0}")]
+    EmbeddingGeneration(String),
+
+    #[error("Configuration error: {0}")]
+    Configuration(String),
+
+    #[error("Ollama error: {0}")]
+    Ollama(#[from] OllamaError),
+}
diff --git a/src/llm/llm_provider/mod.rs b/src/llm/llm_provider/mod.rs
new file mode 100644
index 0000000..6b034ef
--- /dev/null
+++ b/src/llm/llm_provider/mod.rs
@@ -0,0 +1,4 @@
+mod ollama;
+mod openai;
+
+pub use ollama::*;
diff --git a/src/llm/llm_provider/ollama/config.rs b/src/llm/llm_provider/ollama/config.rs
new file mode 100644
index 0000000..36dbae1
--- /dev/null
+++ b/src/llm/llm_provider/ollama/config.rs
@@ -0,0 +1,18 @@
+use serde::{Deserialize, Serialize};
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct OllamaConfig {
+    pub base_url: Option<String>,
+    pub model: Option<String>,
+    pub embedding_model: Option<String>,
+}
+
+impl Default for OllamaConfig {
+    fn default() -> Self {
+        Self {
+            base_url: Some("http://localhost:11434".to_string()),
+            model: Some("codestral:latest".to_string()),
+            embedding_model: Some("nomic-embed-text:latest".to_string()),
+        }
+    }
+}
diff --git a/src/llm/llm_provider/ollama/llm.rs b/src/llm/llm_provider/ollama/llm.rs
new file mode 100644
index 0000000..73d0cdf
--- /dev/null
+++ b/src/llm/llm_provider/ollama/llm.rs
@@ -0,0 +1,272 @@
+use thiserror::Error;
+use tokio_stream::StreamExt;
+
+use crate::*;
+
+pub mod ollama_model {
+    pub const CODESTRAL: &str = "codestral:latest";
+}
+
+pub mod ollama_embedding_model {
+    pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text:latest";
+}
+
+#[derive(Debug, Clone)]
+pub struct Ollama<'a> {
+    base_url: &'a str,
+    pub model: Option<&'a str>,
+    pub embeddings_model: Option<&'a str>,
+}
+
+impl Default for Ollama<'_> {
+    fn default() -> Self {
+        Self {
+            base_url: "http://localhost:11434",
+            model: Some(ollama_model::CODESTRAL),
+            embeddings_model: Some(ollama_embedding_model::NOMIC_EMBED_TEXT),
+        }
+    }
+}
+
+pub struct OllamaBuilder<'a> {
+    base_url: &'a str,
+    model: Option<&'a str>,
+    embeddings_model: Option<&'a str>,
+}
+
+impl Default for OllamaBuilder<'_> {
+    fn default() -> Self {
+        let ollama = Ollama::default();
+        Self {
+            base_url: ollama.base_url,
+            model: ollama.model,
+            embeddings_model: ollama.embeddings_model,
+        }
+    }
+}
+
+impl<'a> OllamaBuilder<'a> {
+    pub fn new() -> Self {
+        Default::default()
+    }
+
+    pub fn with_base_url(mut self, base_url: &'a str) -> Self {
+        self.base_url = base_url;
+        self
+    }
+
+    pub fn with_model(mut self, model: &'a str) -> Self {
+        self.model = Some(model);
+        self
+    }
+
+    pub fn with_embeddings_model(mut self, embeddings_model: &'a str) -> Self {
+        self.embeddings_model = Some(embeddings_model);
+        self
+    }
+
+    pub fn build(self) -> Ollama<'a> {
+        Ollama {
+            base_url: self.base_url,
+            model: self.model,
+            embeddings_model: self.embeddings_model,
+        }
+    }
+}
+
+#[derive(Error, Debug)]
+pub enum OllamaError {
+    #[error("Unexpected response from API. Error: {0}")]
+    Api(String),
+
+    #[error("Unexpected error when parsing response from Ollama. Error: {0}")]
+    Parsing(String),
+
+    #[error("Configuration error: {0}")]
+    Configuration(String),
+
+    #[error("Serialization error: {0}")]
+    Serialization(String),
+
+    #[error(
+        "Ollama API is not available. Please check if Ollama is running in the specified port. Error: {0}"
+    )]
+    ApiUnavailable(String),
+}
+
+impl<'a> Ollama<'a> {
+    /// Lists the running models in the Ollama API.
+    ///
+    /// # Returns
+    ///
+    /// A [Result] containing the list of running models or an error if there was a problem.
+    ///
+    #[allow(dead_code)]
+    pub(crate) fn list_running_models(&self) -> Result<OllamaApiModelsMetadata, OllamaError> {
+        let response = self.get_from_ollama_api("api/ps")?;
+        let parsed_response = Self::parse_models_response(&response)?;
+        Ok(parsed_response)
+    }
+
+    // /// Lists the local models in the Ollama API.
+    // ///
+    // /// # Returns
+    // ///
+    // /// A [Result] containing the list of local models or an error if there was a problem.
+    #[allow(dead_code)]
+    pub fn list_local_models(&self) -> Result<OllamaApiModelsMetadata, OllamaError> {
+        let response = self.get_from_ollama_api("api/tags")?;
+        let parsed_response = Self::parse_models_response(&response)?;
+        Ok(parsed_response)
+    }
+
+    fn parse_models_response(response: &str) -> Result<OllamaApiModelsMetadata, OllamaError> {
+        let models: OllamaApiModelsMetadata =
+            serde_json::from_str(response).map_err(|e| OllamaError::Parsing(e.to_string()))?;
+        Ok(models)
+    }
+
+    fn get_from_ollama_api(&self, url: &str) -> Result<String, OllamaError> {
+        let url = format!("{}/{}", self.base_url()?, url);
+
+        let client = reqwest::blocking::Client::new();
+        let response = client
+            .get(url)
+            .send()
+            .map_err(|e| OllamaError::ApiUnavailable(e.to_string()))?;
+        let response_text = response
+            .text()
+            .map_err(|e| OllamaError::Api(e.to_string()))?;
+        Ok(response_text)
+    }
+
+    fn base_url(&self) -> Result<String, OllamaError> {
+        Ok(self.base_url.to_string())
+    }
+
+    fn model(&self) -> Result<String, OllamaError> {
+        self.model
+            .map(|s| s.to_owned())
+            .ok_or_else(|| OllamaError::Configuration("Model not set".to_string()))
+    }
+
+    fn embedding_model(&self) -> Result<String, OllamaError> {
+        self.embeddings_model
+            .map(|s| s.to_owned())
+            .ok_or_else(|| OllamaError::Configuration("Embedding model not set".to_string()))
+    }
+}
+
+impl<'a> Llm for Ollama<'a> {
+    async fn text_complete(
+        &self,
+        prompt: &str,
+        system_prompt: &str,
+        _options: TextCompleteOptions,
+    ) -> Result<TextCompleteResponse, LlmError> {
+        let body = OllamaGenerateRequest {
+            model: self
+                .model()
+                .map_err(|_e| LlmError::Configuration("Model not set".to_string()))?,
+            prompt: prompt.to_string(),
+            system: Some(system_prompt.to_string()),
+            ..Default::default()
+        };
+
+        let client = reqwest::Client::new();
+        let url = format!(
+            "{}/api/generate",
+            self.base_url()
+                .map_err(|_e| LlmError::Configuration("Base URL not set".to_string()))?
+        );
+        let response = client
+            .post(url)
+            .body(serde_json::to_string(&body).unwrap())
+            .send()
+            .await
+            .map_err(|e| LlmError::Ollama(OllamaError::ApiUnavailable(e.to_string())))?;
+        let body = response
+            .text()
+            .await
+            .map_err(|e| LlmError::Ollama(OllamaError::Api(e.to_string())))?;
+        let ollama_response: OllamaGenerateResponse = serde_json::from_str(&body)
+            .map_err(|e| LlmError::Ollama(OllamaError::Parsing(e.to_string())))?;
+        let response = TextCompleteResponse {
+            text: ollama_response.response,
+            context: ollama_response.context,
+        };
+        Ok(response)
+    }
+
+    async fn text_complete_stream(
+        &self,
+        prompt: &str,
+        system_prompt: &str,
+        options: TextCompleteStreamOptions,
+    ) -> Result<TextCompleteStreamResponse, LlmError> {
+        let body = OllamaGenerateRequest {
+            model: self.model()?,
+            prompt: prompt.to_string(),
+            stream: Some(true),
+            format: None,
+            images: None,
+            system: Some(system_prompt.to_string()),
+            keep_alive: Some("5m".to_string()),
+            context: options.context,
+        };
+
+        let url = format!("{}/api/generate", self.base_url()?);
+        let stream = SseClient::post(&url, Some(serde_json::to_string(&body).unwrap()));
+        let stream = stream.map(|event| {
+            let parsed_message = serde_json::from_str::<OllamaGenerateStreamItemResponse>(&event);
+            match parsed_message {
+                Ok(message) => Ok(message.response),
+                Err(e) => Err(LlmError::Ollama(OllamaError::Parsing(e.to_string()))),
+            }
+        });
+        let response = TextCompleteStreamResponse {
+            stream: Box::pin(stream),
+        };
+        Ok(response)
+    }
+
+    async fn generate_embedding(&self, prompt: &str) -> Result<Vec<f32>, LlmError> {
+        let client = reqwest::Client::new();
+        let url = format!("{}/api/embeddings", self.base_url()?);
+        let body = OllamaEmbeddingsRequest {
+            model: self.embedding_model()?,
+            prompt: prompt.to_string(),
+        };
+        let response = client
+            .post(url)
+            .body(
+                serde_json::to_string(&body)
+                    .map_err(|e| OllamaError::Serialization(e.to_string()))?,
+            )
+            .send()
+            .await
+            .map_err(|e| OllamaError::ApiUnavailable(e.to_string()))?;
+        let body = response
+            .text()
+            .await
+            .map_err(|e| OllamaError::Api(e.to_string()))?;
+        let response: OllamaEmbeddingsResponse =
+            serde_json::from_str(&body).map_err(|e| OllamaError::Parsing(e.to_string()))?;
+
+        Ok(response.embedding)
+    }
+
+    fn provider(&self) -> LlmProvider {
+        LlmProvider::Ollama
+    }
+
+    fn text_completion_model_name(&self) -> String {
+        self.model().expect("Model not set").to_string()
+    }
+
+    fn embedding_model_name(&self) -> String {
+        self.embedding_model()
+            .expect("Embedding model not set")
+            .to_string()
+    }
+}
diff --git a/src/llm/llm_provider/ollama/mod.rs b/src/llm/llm_provider/ollama/mod.rs
new file mode 100644
index 0000000..596eee7
--- /dev/null
+++ b/src/llm/llm_provider/ollama/mod.rs
@@ -0,0 +1,6 @@
+mod config;
+mod llm;
+mod models;
+
+pub use llm::*;
+pub use models::*;
diff --git a/src/llm/llm_provider/ollama/models.rs b/src/llm/llm_provider/ollama/models.rs
new file mode 100644
index 0000000..9cc41f3
--- /dev/null
+++ b/src/llm/llm_provider/ollama/models.rs
@@ -0,0 +1,158 @@
+use serde::{Deserialize, Serialize};
+
+use crate::ollama_model;
+
+/// Response from the Ollama API for obtaining information about local models.
+/// Referenced from the Ollama API documentation [here](https://github.com/ollama/ollama/blob/fedf71635ec77644f8477a86c6155217d9213a11/docs/api.md#list-running-models).
+#[derive(Debug, Serialize, Deserialize)]
+pub struct OllamaApiModelsMetadata {
+    pub models: Vec<OllamaApiModelMetadata>,
+}
+
+/// Response item from the Ollama API for obtaining information about local models.
+///
+/// Referenced from the Ollama API documentation [here](https://github.com/ollama/ollama/blob/fedf71635ec77644f8477a86c6155217d9213a11/docs/api.md#response-22).
+#[allow(dead_code)]
+#[derive(Debug, Serialize, Deserialize)]
+pub struct OllamaApiModelMetadata {
+    /// The name of the model (e.g., "mistral:latest")
+    pub name: String,
+
+    /// The Ollama identifier of the model (e.g., "mistral:latest")
+    pub model: String,
+
+    /// Size of the model in bytes
+    pub size: usize,
+
+    /// Digest of the model using SHA256 (e.g., "2ae6f6dd7a3dd734790bbbf58b8909a606e0e7e97e94b7604e0aa7ae4490e6d8")
+    pub digest: String,
+
+    /// Model expiry time in ISO 8601 format (e.g., "2024-06-04T14:38:31.83753-07:00")
+    pub expires_at: Option<String>,
+
+    /// More details about the model
+    pub details: OllamaApiModelDetails,
+}
+
+/// Details about a running model in the API for listing running models (`GET /api/ps`).
+///
+/// Referenced from the Ollama API documentation [here](https://github.com/ollama/ollama/blob/fedf71635ec77644f8477a86c6155217d9213a11/docs/api.md#response-22).
+#[allow(dead_code)]
+#[derive(Debug, Serialize, Deserialize)]
+pub struct OllamaApiModelDetails {
+    /// Model identifier that this model is based on
+    pub parent_model: String,
+
+    /// Format that this model is stored in (e.g., "gguf")
+    pub format: String,
+
+    /// Model family (e.g., "ollama")
+    pub family: String,
+
+    /// Parameters of the model (e.g., "7.2B")
+    pub parameter_size: String,
+
+    /// Quantization level of the model (e.g., "Q4_0" for 4-bit quantization)
+    pub quantization_level: String,
+}
+
+/// Request for generating a response from the Ollama API.
+///
+/// Referenced from the Ollama API documentation [here](https://github.com/ollama/ollama/blob/fedf71635ec77644f8477a86c6155217d9213a11/docs/api.md#generate-a-completion).
+#[allow(dead_code)]
+#[derive(Debug, Serialize, Deserialize)]
+pub struct OllamaGenerateRequest {
+    /// Model identifier (e.g., "mistral:latest")
+    pub model: String,
+
+    /// The prompt to generate a response for (e.g., "List all Kubernetes pods")
+    pub prompt: String,
+
+    /// The context parameter returned from a previous request to /generate; this can be used to keep a short conversational memory
+    pub context: Option<Vec<i32>>,
+
+    /// Optional list of base64-encoded images (for multimodal models such as `llava`)
+    pub images: Option<Vec<String>>,
+
+    /// Optional format to use for the response (currently only "json" is supported)
+    pub format: Option<String>,
+
+    /// Optional flag that controls whether the response is streamed or not (defaults to true).
+    /// If `false`, the response will be returned as a single response object, rather than a stream of objects
+    pub stream: Option<bool>,
+
+    /// System message (overrides what is defined in the Modelfile)
+    pub system: Option<String>,
+
+    /// Controls how long the model will stay loaded into memory following the request (default: 5m)
+    pub keep_alive: Option<String>,
+}
+
+impl Default for OllamaGenerateRequest {
+    fn default() -> Self {
+        Self {
+            model: ollama_model::CODESTRAL.to_string(),
+            prompt: "".to_string(),
+            stream: Some(false),
+            format: None,
+            images: None,
+            system: Some("You are a helpful assistant".to_string()),
+            keep_alive: Some("5m".to_string()),
+            context: None,
+        }
+    }
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[allow(dead_code)]
+pub struct OllamaGenerateResponse {
+    /// Model identifier (e.g., "mistral:latest")
+    pub model: String,
+
+    /// Time at which the response was generated (ISO 8601 format)
+    pub created_at: String,
+
+    /// The response to the prompt
+    pub response: String,
+
+    /// The encoding of the conversation used in this response; this can be sent in the next request to keep a conversational memory
+    pub context: Option<Vec<i32>>,
+
+    /// The duration of the response in nanoseconds
+    pub total_duration: usize,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct OllamaGenerateStreamItemResponse {
+    /// Model identifier (e.g., "mistral:latest")
+    pub model: String,
+
+    /// Time at which the response was generated (ISO 8601 format)
+    pub created_at: String,
+
+    /// The response to the prompt
+    pub response: String,
+}
+
+/// Request for generating an embedding from the Ollama API.
+/// Referenced from the Ollama API documentation [here](https://github.com/ollama/ollama/blob/fedf71635ec77644f8477a86c6155217d9213a11/docs/api.md#generate-embeddings).
+///
+#[allow(dead_code)]
+#[derive(Debug, Serialize, Deserialize, Default)]
+pub struct OllamaEmbeddingsRequest {
+    /// The string to generate an embedding for.
+    pub prompt: String,
+
+    /// The model to use for the embedding generation.
+    pub model: String,
+}
+
+/// Response from the Ollama API for generating an embedding.
+/// Referenced from the Ollama API documentation [here](https://github.com/ollama/ollama/blob/fedf71635ec77644f8477a86c6155217d9213a11/docs/api.md#generate-embeddings).
+///
+#[allow(dead_code)]
+#[derive(Debug, Serialize, Deserialize)]
+pub struct OllamaEmbeddingsResponse {
+    /// The embedding for the prompt.
+    pub embedding: Vec<f32>,
+}
diff --git a/src/llm/llm_provider/openai.rs b/src/llm/llm_provider/openai.rs
new file mode 100644
index 0000000..f2dcb1e
--- /dev/null
+++ b/src/llm/llm_provider/openai.rs
@@ -0,0 +1,57 @@
+// use async_trait::async_trait;
+// use openai_api_rs::v1::{
+//     api::OpenAIClient,
+//     chat_completion::{self, ChatCompletionRequest},
+//     common::{GPT3_5_TURBO, GPT4, GPT4_O},
+// };
+
+// pub mod openai_model {
+//     pub const GPT35_TURBO: &str = GPT3_5_TURBO;
+//     pub const GPT4: &str = GPT4;
+//     pub const GPT40: &str = GPT4_O;
+// }
+
+// pub struct OpenAi<'a> {
+//     pub model: &'a str,
+//     api_key: &'a str,
+// }
+
+// impl<'a> OpenAi<'a> {
+//     pub fn new(api_key: &'a str, model: &'a str) -> Self {
+//         Self { api_key, model }
+//     }
+// }
+
+// #[async_trait]
+// impl<'a> TextCompletionLlm for OpenAi<'a> {
+//     async fn complete(
+//         &self,
+//         system_prompts: &[String],
+//     ) -> Result<String, Box<dyn std::error::Error>> {
+//         let client = OpenAIClient::new(self.api_key.to_owned());
+//         let system_msgs = system_prompts
+//             .iter()
+//             .map(|p| chat_completion::ChatCompletionMessage {
+//                 role: chat_completion::MessageRole::system,
+//                 content: chat_completion::Content::Text(p.to_owned()),
+//                 name: None,
+//                 tool_calls: None,
+//                 tool_call_id: None,
+//             })
+//             .collect::<Vec<_>>();
+//         let mut req = ChatCompletionRequest::new(self.model.to_owned(), system_msgs);
+//         req.max_tokens = Some(self.config.max_tokens as i64);
+//         req.temperature = Some(self.config.temperature);
+
+//         let result = client.chat_completion(req).await?;
+//         let completion = result
+//             .choices
+//             .first()
+//             .unwrap()
+//             .message
+//             .content
+//             .clone()
+//             .unwrap();
+//         Ok(completion)
+//     }
+// }
diff --git a/src/llm/mod.rs b/src/llm/mod.rs
new file mode 100644
index 0000000..446b0ef
--- /dev/null
+++ b/src/llm/mod.rs
@@ -0,0 +1,7 @@
+mod error;
+mod llm_provider;
+mod models;
+
+pub use error::*;
+pub use llm_provider::*;
+pub use models::*;
diff --git a/src/llm/models.rs b/src/llm/models.rs
new file mode 100644
index 0000000..fa9bf6b
--- /dev/null
+++ b/src/llm/models.rs
@@ -0,0 +1,120 @@
+#![allow(dead_code)]
+
+use std::pin::Pin;
+
+use dyn_clone::DynClone;
+use serde::{Deserialize, Serialize};
+use tokio_stream::Stream;
+
+use super::error::LlmError;
+
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
+pub enum LlmProvider {
+    #[serde(rename = "ollama")]
+    Ollama,
+    #[serde(rename = "openai")]
+    OpenAi,
+}
+
+/// A trait for LLM providers which implement text completion, embeddings, etc.
+///
+/// > `DynClone` is used so that there can be dynamic dispatch of the `Llm` trait,
+/// > especially needed for [magic-cli](https://github.com/guywaldman/magic-cli).
+pub trait Llm: DynClone {
+    /// Generates a response from the LLM.
+    ///
+    /// # Arguments
+    /// * `prompt` - The prompt to generate a response for.
+    /// * `system_prompt` - The system prompt to use for the generation.
+    /// * `options` - The options for the generation.
+    ///
+    /// # Returns
+    /// A [Result] containing the response from the LLM or an error if there was a problem.
+    ///
+    fn text_complete(
+        &self,
+        prompt: &str,
+        system_prompt: &str,
+        options: TextCompleteOptions,
+    ) -> impl std::future::Future<Output = Result<TextCompleteResponse, LlmError>> + Send;
+
+    /// Generates a streaming response from the LLM.
+    ///
+    /// # Arguments
+    /// * `prompt` - The prompt to generate a response for.
+    /// * `system_prompt` - The system prompt to use for the generation.
+    /// * `options` - The options for the generation.
+    ///
+    /// # Returns
+    /// A [Result] containing the response from the LLM or an error if there was a problem.
+    ///
+    fn text_complete_stream(
+        &self,
+        prompt: &str,
+        system_prompt: &str,
+        options: TextCompleteStreamOptions,
+    ) -> impl std::future::Future<Output = Result<TextCompleteStreamResponse, LlmError>> + Send;
+
+    /// Generates an embedding from the LLM.
+    ///
+    /// # Arguments
+    /// * `prompt` - The item to generate an embedding for.
+    ///
+    /// # Returns
+    ///
+    /// A [Result] containing the embedding or an error if there was a problem.
+    fn generate_embedding(
+        &self,
+        prompt: &str,
+    ) -> impl std::future::Future<Output = Result<Vec<f32>, LlmError>> + Send;
+
+    /// Returns the provider of the LLM.
+    fn provider(&self) -> LlmProvider;
+
+    /// Returns the name of the model used for text completions.
+    fn text_completion_model_name(&self) -> String;
+
+    /// Returns the name of the model used for embeddings.
+    fn embedding_model_name(&self) -> String;
+}
+
+#[derive(Debug, Clone, Default)]
+pub struct TextCompleteOptions {
+    /// An encoding of the conversation used in this response; this can be sent in the next request to keep a conversational memory.
+    /// This should be as returned from the previous response.
+    pub context: Option<Vec<i32>>,
+}
+
+#[derive(Debug, Clone, Default)]
+pub struct TextCompleteStreamOptions {
+    pub context: Option<Vec<i32>>,
+}
+
+#[derive(Debug, Clone)]
+pub struct TextCompleteResponse {
+    pub text: String,
+    // TODO: This is specific to Ollama; context looks different for other LLM providers.
+    pub context: Option<Vec<i32>>,
+}
+
+pub struct TextCompleteStreamResponse {
+    pub stream: Pin<Box<dyn Stream<Item = Result<String, LlmError>> + Send>>,
+    // TODO: Handle context with streaming response.
+    // pub context: Vec<i32>,
+}
+
+#[derive(Debug)]
+pub(crate) struct SystemPromptResponseOption {
+    pub scenario: String,
+    pub type_name: String,
+    pub response: String,
+    pub schema: Vec<SystemPromptCommandSchemaField>,
+}
+
+#[derive(Debug)]
+pub(crate) struct SystemPromptCommandSchemaField {
+    pub name: String,
+    pub description: String,
+    pub typ: String,
+    pub example: String,
+}