From daeb5d2b0c5abafbd8e60822aa49d318a7c187ef Mon Sep 17 00:00:00 2001 From: Guy Waldman <6546430+guywaldman@users.noreply.github.com> Date: Fri, 19 Jul 2024 06:31:17 +0300 Subject: [PATCH 1/2] Release 0.0.2 (#5) * Add basics for agent & tools * Add basics of text completion * Basic tool orchestration (#1) * Basic example with HTTP tool * Add a PDF tool * Add web search tool * Rename janus to orch * Update OAI dep & fix lint warnings * Clean up & async APIs (#4) * Add SSE and basic example * Cleanup * Update fn signatures & basic example * Reorganize * Update Cargo.toml * Add license * Update README * Add embeddings * Add GitHub Actions * Remove rustdoc lints * Update release.yml * Release: 0.0.2 * Update build.yml --- .github/workflows/build.yml | 92 +++++++++ .github/workflows/release.yml | 21 ++ .gitignore | 16 ++ .rusfmt.toml | 1 + .vscode/settings.json | 5 + CHANGELOG.md | 10 + Cargo.toml | 21 ++ LICENSE.md | 21 ++ README.md | 20 +- examples/embeddings.rs | 22 +++ examples/text_generation.rs | 24 +++ examples/text_generation_stream.rs | 34 ++++ src/core/mod.rs | 3 + src/core/net/mod.rs | 4 + src/core/net/sse.rs | 28 +++ src/executor.rs | 110 +++++++++++ src/lib.rs | 8 + src/llm/error.rs | 51 +++++ src/llm/llm_provider/mod.rs | 4 + src/llm/llm_provider/ollama/config.rs | 18 ++ src/llm/llm_provider/ollama/llm.rs | 272 ++++++++++++++++++++++++++ src/llm/llm_provider/ollama/mod.rs | 6 + src/llm/llm_provider/ollama/models.rs | 158 +++++++++++++++ src/llm/llm_provider/openai.rs | 57 ++++++ src/llm/mod.rs | 7 + src/llm/models.rs | 120 ++++++++++++ 26 files changed, 1132 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/build.yml create mode 100644 .github/workflows/release.yml create mode 100644 .gitignore create mode 100644 .rusfmt.toml create mode 100644 .vscode/settings.json create mode 100644 CHANGELOG.md create mode 100644 Cargo.toml create mode 100644 LICENSE.md create mode 100644 examples/embeddings.rs create mode 100644 examples/text_generation.rs create mode 100644 examples/text_generation_stream.rs create mode 100644 src/core/mod.rs create mode 100644 src/core/net/mod.rs create mode 100644 src/core/net/sse.rs create mode 100644 src/executor.rs create mode 100644 src/lib.rs create mode 100644 src/llm/error.rs create mode 100644 src/llm/llm_provider/mod.rs create mode 100644 src/llm/llm_provider/ollama/config.rs create mode 100644 src/llm/llm_provider/ollama/llm.rs create mode 100644 src/llm/llm_provider/ollama/mod.rs create mode 100644 src/llm/llm_provider/ollama/models.rs create mode 100644 src/llm/llm_provider/openai.rs create mode 100644 src/llm/mod.rs create mode 100644 src/llm/models.rs diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..1650ddf --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,92 @@ +name: Build +on: + pull_request: + branches: [main] + push: + branches: [develop] + +jobs: + check: + name: Check + runs-on: ubuntu-latest + steps: + - name: Checkout sources + uses: actions/checkout@v2 + + - name: Install stable toolchain + uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: stable + override: true + + - uses: Swatinem/rust-cache@v1 + + - name: Run cargo check + uses: actions-rs/cargo@v1 + with: + command: check + + test: + name: Test Suite + strategy: + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + rust: [stable] + runs-on: ${{ matrix.os }} + steps: + - name: Checkout sources + uses: actions/checkout@v2 + + - name: Install stable toolchain + uses: 
actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: ${{ matrix.rust }} + override: true + + - uses: Swatinem/rust-cache@v1 + + - name: Run cargo test + uses: actions-rs/cargo@v1 + with: + command: test + + lints: + name: Lints + runs-on: ubuntu-latest + steps: + - name: Checkout sources + uses: actions/checkout@v2 + with: + submodules: true + + - name: Install stable toolchain + uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: stable + override: true + components: rustfmt, clippy + + - uses: Swatinem/rust-cache@v1 + + - name: Run cargo fmt + uses: actions-rs/cargo@v1 + with: + command: fmt + args: --all -- --check + + - name: Run cargo clippy + uses: actions-rs/cargo@v1 + with: + command: clippy + args: -- -D warnings + + # - name: Run rustdoc lints + # uses: actions-rs/cargo@v1 + # env: + # RUSTDOCFLAGS: "-D missing_docs -D rustdoc::missing_doc_code_examples" + # with: + # command: doc + # args: --workspace --all-features --no-deps --document-private-items diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..10a2965 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,21 @@ +name: Build +on: + push: + branches: [main] + +jobs: + release: + name: Release + runs-on: ubuntu-latest + steps: + - name: Checkout sources + uses: actions/checkout@v2 + + - name: Install stable toolchain + uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: stable + override: true + - name: Publish to crates.io + run: cargo publish --token ${{ secrets.CRATES_IO_TOKEN }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a1835db --- /dev/null +++ b/.gitignore @@ -0,0 +1,16 @@ +# Generated by Cargo +# will have compiled files and executables +debug/ +target/ + +# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries +# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html +Cargo.lock + +# These are backup files generated by rustfmt +**/*.rs.bk + +# MSVC Windows builds of rustc generate these, which store debugging information +*.pdb + +.env \ No newline at end of file diff --git a/.rusfmt.toml b/.rusfmt.toml new file mode 100644 index 0000000..8449be0 --- /dev/null +++ b/.rusfmt.toml @@ -0,0 +1 @@ +max_width = 140 diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..01e9a8d --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,5 @@ +{ + "rust-analyzer.linkedProjects": [ + "./Cargo.toml" + ] +} \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..001a52a --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,10 @@ + +# Changelog + +## Version 0.0.2 + +No functional changes. + +## Version 0.0.1 + +Initial release. 
\ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..e819030 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "orch" +version = "0.0.2" +edition = "2021" +license = "MIT" +description = "LLM orchestration library" +homepage = "https://github.com/guywaldman/orch" +repository = "https://github.com/guywaldman/orch" +keywords = ["llm", "openai", "ollama", "rust"] + +[dependencies] +async-gen = "0.2.3" +dotenv = "0.15.0" +dyn-clone = "1.0.17" +openai-api-rs = "5.0.2" +reqwest = { version = "0.12.5", features = ["blocking"] } +serde = { version = "1.0.164", features = ["derive"] } +serde_json = "1.0.97" +thiserror = "1.0.63" +tokio = { version = "1.28.2", features = ["rt", "macros"] } +tokio-stream = "0.1.15" diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000..8aa2645 --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) [year] [fullname] + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 2adbc28..c4cc4a2 100644 --- a/README.md +++ b/README.md @@ -1 +1,19 @@ -# janus \ No newline at end of file +# orch + +Orch (stands for "orchestrator") is a library for building LLM-powered applications and agents for the Rust programming language. +It was primarily built for usage in [magic-cli](https://github.com/guywaldman/magic-cli), but can be used in other contexts as well. + +> [!NOTE] +> +> If the project gains traction, this can be compiled as an addon to other languages such as Python or a standalone WebAssembly module. + +There is currently support for text generation with `ollama` (either stream or non-stream) and embedding generation. +Originally this contained agents and tools as well, but this was removed for now. + +See [/examples](examples) for usage examples. + +## Roadmap + +- [ ] Support for text generation with `openai` +- [ ] Embedding generation +- [ ] Agents and tools diff --git a/examples/embeddings.rs b/examples/embeddings.rs new file mode 100644 index 0000000..d523950 --- /dev/null +++ b/examples/embeddings.rs @@ -0,0 +1,22 @@ +//! This example demonstrates how to use the `Executor` to generate embeddings from the LLM. +//! We construct an `Ollama` instance and use it to generate embeddings. +//! 
+use orch::{Executor, OllamaBuilder}; + +#[tokio::main] +async fn main() { + let text = "Lorem ipsum"; + + println!("Text: {text}"); + println!("---"); + + let ollama = OllamaBuilder::new().build(); + let executor = Executor::new(&ollama); + let embedding = executor + .generate_embedding(text) + .await + .expect("Execution failed"); + + println!("Embedding:"); + println!("{:?}", embedding); +} diff --git a/examples/text_generation.rs b/examples/text_generation.rs new file mode 100644 index 0000000..5848b35 --- /dev/null +++ b/examples/text_generation.rs @@ -0,0 +1,24 @@ +//! This example demonstrates how to use the `Executor` to generate a response from the LLM. +//! We construct an `Ollama` instance and use it to generate a response. + +use orch::{Executor, OllamaBuilder}; + +#[tokio::main] +async fn main() { + let prompt = "What is 2+2?"; + let system_prompt = "You are a helpful assistant"; + + println!("Prompt: {prompt}"); + println!("System prompt: {system_prompt}"); + println!("---"); + + let ollama = OllamaBuilder::new().build(); + let executor = Executor::new(&ollama); + let response = executor + .text_complete(prompt, system_prompt) + .await + .expect("Execution failed"); + + println!("Response:"); + println!("{}", response.text); +} diff --git a/examples/text_generation_stream.rs b/examples/text_generation_stream.rs new file mode 100644 index 0000000..264b059 --- /dev/null +++ b/examples/text_generation_stream.rs @@ -0,0 +1,34 @@ +//! This example demonstrates how to use the `Executor` to generate a streaming response from the LLM. +//! We construct an `Ollama` instance and use it to generate a streaming response. + +use orch::{Executor, OllamaBuilder}; +use tokio_stream::StreamExt; + +#[tokio::main] +async fn main() { + let prompt = "What is 2+2?"; + let system_prompt = "You are a helpful assistant"; + + println!("Prompt: {prompt}"); + println!("System prompt: {system_prompt}"); + println!("---"); + + let ollama = OllamaBuilder::new().build(); + let executor = Executor::new(&ollama); + let mut response = executor + .text_complete_stream(prompt, system_prompt) + .await + .expect("Execution failed"); + + println!("Response:"); + while let Some(chunk) = response.stream.next().await { + match chunk { + Ok(chunk) => print!("{chunk}"), + Err(e) => { + println!("Error: {e}"); + break; + } + } + } + println!(); +} diff --git a/src/core/mod.rs b/src/core/mod.rs new file mode 100644 index 0000000..6aada65 --- /dev/null +++ b/src/core/mod.rs @@ -0,0 +1,3 @@ +mod net; + +pub use net::*; diff --git a/src/core/net/mod.rs b/src/core/net/mod.rs new file mode 100644 index 0000000..4ca4b1c --- /dev/null +++ b/src/core/net/mod.rs @@ -0,0 +1,4 @@ +/// Module for working with Server-Sent Events. +mod sse; + +pub use sse::*; diff --git a/src/core/net/sse.rs b/src/core/net/sse.rs new file mode 100644 index 0000000..5f7f259 --- /dev/null +++ b/src/core/net/sse.rs @@ -0,0 +1,28 @@ +use async_gen::AsyncIter; +use reqwest::{header, Client}; +use tokio_stream::Stream; + +/// A client for working with Server-Sent Events. +pub struct SseClient; + +impl SseClient { + pub fn post(url: &str, body: Option) -> impl Stream { + let client = Client::new(); + let mut req = Client::post(&client, url) + .header(header::ACCEPT, "text/event-stream") + .header(header::CACHE_CONTROL, "no-cache") + .header(header::CONNECTION, "keep-alive") + .header(header::CONTENT_TYPE, "application/json"); + if let Some(body) = body { + req = req.body(body); + } + let req = req.build().unwrap(); + + AsyncIter::from(async_gen::gen! 
{ + let mut conn = client.execute(req).await.unwrap(); + while let Some(event) = conn.chunk().await.unwrap() { + yield std::str::from_utf8(&event).unwrap().to_owned(); + } + }) + } +} diff --git a/src/executor.rs b/src/executor.rs new file mode 100644 index 0000000..1bc66a7 --- /dev/null +++ b/src/executor.rs @@ -0,0 +1,110 @@ +use std::pin::Pin; + +use thiserror::Error; +use tokio_stream::Stream; + +use crate::{Llm, LlmError, TextCompleteOptions, TextCompleteStreamOptions}; + +pub struct Executor<'a, L: Llm> { + llm: &'a L, +} + +#[derive(Debug, Error)] +pub enum ExecutorError { + #[error("LLM error: {0}")] + Llm(LlmError), +} + +impl<'a, L: Llm> Executor<'a, L> { + /// Creates a new `Executor` instance. + /// + /// # Arguments + /// * `llm` - The LLM to use for the execution. + pub fn new(llm: &'a L) -> Self { + Self { llm } + } + + /// Generates a response from the LLM (non-streaming). + /// + /// # Arguments + /// * `prompt` - The prompt to generate a response for. + /// * `system_prompt` - The system prompt to use for the generation. + /// + /// # Returns + /// A [Result] containing the response from the LLM or an error if there was a problem. + pub async fn text_complete( + &self, + prompt: &str, + system_prompt: &str, + ) -> Result { + let options = TextCompleteOptions { + ..Default::default() + }; + let response = self + .llm + .text_complete(prompt, system_prompt, options) + .await + .map_err(ExecutorError::Llm)?; + Ok(ExecutorTextCompleteResponse { + text: response.text, + context: ExecutorContext {}, + }) + } + + /// Generates a streaming response from the LLM. + /// + /// # Arguments + /// * `prompt` - The prompt to generate a response for. + /// * `system_prompt` - The system prompt to use for the generation. + /// + /// # Returns + /// A [Result] containing the response from the LLM or an error if there was a problem. + pub async fn text_complete_stream( + &self, + prompt: &str, + system_prompt: &str, + ) -> Result { + let options = TextCompleteStreamOptions { + ..Default::default() + }; + let response = self + .llm + .text_complete_stream(prompt, system_prompt, options) + .await + .map_err(ExecutorError::Llm)?; + Ok(ExecutorTextCompleteStreamResponse { + stream: response.stream, + context: ExecutorContext {}, + }) + } + + /// Generates an embedding from the LLM. + /// + /// # Arguments + /// * `prompt` - The item to generate an embedding for. + /// + /// # Returns + /// + /// A [Result] containing the embedding or an error if there was a problem. + pub async fn generate_embedding(&self, prompt: &str) -> Result, ExecutorError> { + let response = self + .llm + .generate_embedding(prompt) + .await + .map_err(ExecutorError::Llm)?; + Ok(response) + } +} + +// TODO: Support context for completions (e.g., IDs of past conversations in Ollama). +pub struct ExecutorContext; + +pub struct ExecutorTextCompleteResponse { + pub text: String, + pub context: ExecutorContext, +} + +pub struct ExecutorTextCompleteStreamResponse { + pub stream: Pin> + Send>>, + pub context: ExecutorContext, +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..c272657 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,8 @@ +mod core; +mod executor; +mod llm; + +// TODO: Narrow the scope of the use statements. 
+pub use core::*; +pub use executor::*; +pub use llm::*; diff --git a/src/llm/error.rs b/src/llm/error.rs new file mode 100644 index 0000000..7f068d2 --- /dev/null +++ b/src/llm/error.rs @@ -0,0 +1,51 @@ +use thiserror::Error; + +use crate::{LlmProvider, OllamaError}; + +#[derive(Debug, Error)] +pub enum LlmProviderError { + #[error("Invalid LLM provider: {0}")] + InvalidValue(String), +} + +impl std::fmt::Display for LlmProvider { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + LlmProvider::Ollama => write!(f, "ollama"), + LlmProvider::OpenAi => write!(f, "openai"), + } + } +} + +impl Default for LlmProvider { + fn default() -> Self { + Self::Ollama + } +} + +impl TryFrom<&str> for LlmProvider { + type Error = LlmProviderError; + + fn try_from(value: &str) -> Result { + match value { + "ollama" => Ok(LlmProvider::Ollama), + "openai" => Ok(LlmProvider::OpenAi), + _ => Err(LlmProviderError::InvalidValue(value.to_string())), + } + } +} + +#[derive(Debug, Error)] +pub enum LlmError { + #[error("Text generation error: {0}")] + TextGeneration(String), + + #[error("Embedding generation error: {0}")] + EmbeddingGeneration(String), + + #[error("Configuration error: {0}")] + Configuration(String), + + #[error("Ollama error: {0}")] + Ollama(#[from] OllamaError), +} diff --git a/src/llm/llm_provider/mod.rs b/src/llm/llm_provider/mod.rs new file mode 100644 index 0000000..6b034ef --- /dev/null +++ b/src/llm/llm_provider/mod.rs @@ -0,0 +1,4 @@ +mod ollama; +mod openai; + +pub use ollama::*; diff --git a/src/llm/llm_provider/ollama/config.rs b/src/llm/llm_provider/ollama/config.rs new file mode 100644 index 0000000..36dbae1 --- /dev/null +++ b/src/llm/llm_provider/ollama/config.rs @@ -0,0 +1,18 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OllamaConfig { + pub base_url: Option, + pub model: Option, + pub embedding_model: Option, +} + +impl Default for OllamaConfig { + fn default() -> Self { + Self { + base_url: Some("http://localhost:11434".to_string()), + model: Some("codestral:latest".to_string()), + embedding_model: Some("nomic-embed-text:latest".to_string()), + } + } +} diff --git a/src/llm/llm_provider/ollama/llm.rs b/src/llm/llm_provider/ollama/llm.rs new file mode 100644 index 0000000..73d0cdf --- /dev/null +++ b/src/llm/llm_provider/ollama/llm.rs @@ -0,0 +1,272 @@ +use thiserror::Error; +use tokio_stream::StreamExt; + +use crate::*; + +pub mod ollama_model { + pub const CODESTRAL: &str = "codestral:latest"; +} + +pub mod ollama_embedding_model { + pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text:latest"; +} + +#[derive(Debug, Clone)] +pub struct Ollama<'a> { + base_url: &'a str, + pub model: Option<&'a str>, + pub embeddings_model: Option<&'a str>, +} + +impl Default for Ollama<'_> { + fn default() -> Self { + Self { + base_url: "http://localhost:11434", + model: Some(ollama_model::CODESTRAL), + embeddings_model: Some(ollama_embedding_model::NOMIC_EMBED_TEXT), + } + } +} + +pub struct OllamaBuilder<'a> { + base_url: &'a str, + model: Option<&'a str>, + embeddings_model: Option<&'a str>, +} + +impl Default for OllamaBuilder<'_> { + fn default() -> Self { + let ollama = Ollama::default(); + Self { + base_url: ollama.base_url, + model: ollama.model, + embeddings_model: ollama.embeddings_model, + } + } +} + +impl<'a> OllamaBuilder<'a> { + pub fn new() -> Self { + Default::default() + } + + pub fn with_base_url(mut self, base_url: &'a str) -> Self { + self.base_url = base_url; + self + } + + pub fn 
with_model(mut self, model: &'a str) -> Self { + self.model = Some(model); + self + } + + pub fn with_embeddings_model(mut self, embeddings_model: &'a str) -> Self { + self.embeddings_model = Some(embeddings_model); + self + } + + pub fn build(self) -> Ollama<'a> { + Ollama { + base_url: self.base_url, + model: self.model, + embeddings_model: self.embeddings_model, + } + } +} + +#[derive(Error, Debug)] +pub enum OllamaError { + #[error("Unexpected response from API. Error: {0}")] + Api(String), + + #[error("Unexpected error when parsing response from Ollama. Error: {0}")] + Parsing(String), + + #[error("Configuration error: {0}")] + Configuration(String), + + #[error("Serialization error: {0}")] + Serialization(String), + + #[error( + "Ollama API is not available. Please check if Ollama is running in the specified port. Error: {0}" + )] + ApiUnavailable(String), +} + +impl<'a> Ollama<'a> { + /// Lists the running models in the Ollama API. + /// + /// # Returns + /// + /// A [Result] containing the list of running models or an error if there was a problem. + /// + #[allow(dead_code)] + pub(crate) fn list_running_models(&self) -> Result { + let response = self.get_from_ollama_api("api/ps")?; + let parsed_response = Self::parse_models_response(&response)?; + Ok(parsed_response) + } + + // /// Lists the local models in the Ollama API. + // /// + // /// # Returns + // /// + // /// A [Result] containing the list of local models or an error if there was a problem. + #[allow(dead_code)] + pub fn list_local_models(&self) -> Result { + let response = self.get_from_ollama_api("api/tags")?; + let parsed_response = Self::parse_models_response(&response)?; + Ok(parsed_response) + } + + fn parse_models_response(response: &str) -> Result { + let models: OllamaApiModelsMetadata = + serde_json::from_str(response).map_err(|e| OllamaError::Parsing(e.to_string()))?; + Ok(models) + } + + fn get_from_ollama_api(&self, url: &str) -> Result { + let url = format!("{}/{}", self.base_url()?, url); + + let client = reqwest::blocking::Client::new(); + let response = client + .get(url) + .send() + .map_err(|e| OllamaError::ApiUnavailable(e.to_string()))?; + let response_text = response + .text() + .map_err(|e| OllamaError::Api(e.to_string()))?; + Ok(response_text) + } + + fn base_url(&self) -> Result { + Ok(self.base_url.to_string()) + } + + fn model(&self) -> Result { + self.model + .map(|s| s.to_owned()) + .ok_or_else(|| OllamaError::Configuration("Model not set".to_string())) + } + + fn embedding_model(&self) -> Result { + self.embeddings_model + .map(|s| s.to_owned()) + .ok_or_else(|| OllamaError::Configuration("Embedding model not set".to_string())) + } +} + +impl<'a> Llm for Ollama<'a> { + async fn text_complete( + &self, + prompt: &str, + system_prompt: &str, + _options: TextCompleteOptions, + ) -> Result { + let body = OllamaGenerateRequest { + model: self + .model() + .map_err(|_e| LlmError::Configuration("Model not set".to_string()))?, + prompt: prompt.to_string(), + system: Some(system_prompt.to_string()), + ..Default::default() + }; + + let client = reqwest::Client::new(); + let url = format!( + "{}/api/generate", + self.base_url() + .map_err(|_e| LlmError::Configuration("Base URL not set".to_string()))? 
+ ); + let response = client + .post(url) + .body(serde_json::to_string(&body).unwrap()) + .send() + .await + .map_err(|e| LlmError::Ollama(OllamaError::ApiUnavailable(e.to_string())))?; + let body = response + .text() + .await + .map_err(|e| LlmError::Ollama(OllamaError::Api(e.to_string())))?; + let ollama_response: OllamaGenerateResponse = serde_json::from_str(&body) + .map_err(|e| LlmError::Ollama(OllamaError::Parsing(e.to_string())))?; + let response = TextCompleteResponse { + text: ollama_response.response, + context: ollama_response.context, + }; + Ok(response) + } + + async fn text_complete_stream( + &self, + prompt: &str, + system_prompt: &str, + options: TextCompleteStreamOptions, + ) -> Result { + let body = OllamaGenerateRequest { + model: self.model()?, + prompt: prompt.to_string(), + stream: Some(true), + format: None, + images: None, + system: Some(system_prompt.to_string()), + keep_alive: Some("5m".to_string()), + context: options.context, + }; + + let url = format!("{}/api/generate", self.base_url()?); + let stream = SseClient::post(&url, Some(serde_json::to_string(&body).unwrap())); + let stream = stream.map(|event| { + let parsed_message = serde_json::from_str::(&event); + match parsed_message { + Ok(message) => Ok(message.response), + Err(e) => Err(LlmError::Ollama(OllamaError::Parsing(e.to_string()))), + } + }); + let response = TextCompleteStreamResponse { + stream: Box::pin(stream), + }; + Ok(response) + } + + async fn generate_embedding(&self, prompt: &str) -> Result, LlmError> { + let client = reqwest::Client::new(); + let url = format!("{}/api/embeddings", self.base_url()?); + let body = OllamaEmbeddingsRequest { + model: self.embedding_model()?, + prompt: prompt.to_string(), + }; + let response = client + .post(url) + .body( + serde_json::to_string(&body) + .map_err(|e| OllamaError::Serialization(e.to_string()))?, + ) + .send() + .await + .map_err(|e| OllamaError::ApiUnavailable(e.to_string()))?; + let body = response + .text() + .await + .map_err(|e| OllamaError::Api(e.to_string()))?; + let response: OllamaEmbeddingsResponse = + serde_json::from_str(&body).map_err(|e| OllamaError::Parsing(e.to_string()))?; + + Ok(response.embedding) + } + + fn provider(&self) -> LlmProvider { + LlmProvider::Ollama + } + + fn text_completion_model_name(&self) -> String { + self.model().expect("Model not set").to_string() + } + + fn embedding_model_name(&self) -> String { + self.embedding_model() + .expect("Embedding model not set") + .to_string() + } +} diff --git a/src/llm/llm_provider/ollama/mod.rs b/src/llm/llm_provider/ollama/mod.rs new file mode 100644 index 0000000..596eee7 --- /dev/null +++ b/src/llm/llm_provider/ollama/mod.rs @@ -0,0 +1,6 @@ +mod config; +mod llm; +mod models; + +pub use llm::*; +pub use models::*; diff --git a/src/llm/llm_provider/ollama/models.rs b/src/llm/llm_provider/ollama/models.rs new file mode 100644 index 0000000..9cc41f3 --- /dev/null +++ b/src/llm/llm_provider/ollama/models.rs @@ -0,0 +1,158 @@ +use serde::{Deserialize, Serialize}; + +use crate::ollama_model; + +/// Response from the Ollama API for obtaining information about local models. +/// Referenced from the Ollama API documentation [here](https://github.com/ollama/ollama/blob/fedf71635ec77644f8477a86c6155217d9213a11/docs/api.md#list-running-models). +#[derive(Debug, Serialize, Deserialize)] +pub struct OllamaApiModelsMetadata { + pub models: Vec, +} + +/// Response item from the Ollama API for obtaining information about local models. 
+/// +/// Referenced from the Ollama API documentation [here](https://github.com/ollama/ollama/blob/fedf71635ec77644f8477a86c6155217d9213a11/docs/api.md#response-22). +#[allow(dead_code)] +#[derive(Debug, Serialize, Deserialize)] +pub struct OllamaApiModelMetadata { + /// The name of the model (e.g., "mistral:latest") + pub name: String, + + /// The Ollama identifier of the model (e.g., "mistral:latest") + pub model: String, + + /// Size of the model in bytes + pub size: usize, + + /// Digest of the model using SHA256 (e.g., "2ae6f6dd7a3dd734790bbbf58b8909a606e0e7e97e94b7604e0aa7ae4490e6d8") + pub digest: String, + + /// Model expiry time in ISO 8601 format (e.g., "2024-06-04T14:38:31.83753-07:00") + pub expires_at: Option, + + /// More details about the model + pub details: OllamaApiModelDetails, +} + +/// Details about a running model in the API for listing running models (`GET /api/ps`). +/// +/// Referenced from the Ollama API documentation [here](https://github.com/ollama/ollama/blob/fedf71635ec77644f8477a86c6155217d9213a11/docs/api.md#response-22). +#[allow(dead_code)] +#[derive(Debug, Serialize, Deserialize)] +pub struct OllamaApiModelDetails { + /// Model identifier that this model is based on + pub parent_model: String, + + /// Format that this model is stored in (e.g., "gguf") + pub format: String, + + /// Model family (e.g., "ollama") + pub family: String, + + /// Parameters of the model (e.g., "7.2B") + pub parameter_size: String, + + /// Quantization level of the model (e.g., "Q4_0" for 4-bit quantization) + pub quantization_level: String, +} + +/// Request for generating a response from the Ollama API. +/// +/// Referenced from the Ollama API documentation [here](https://github.com/ollama/ollama/blob/fedf71635ec77644f8477a86c6155217d9213a11/docs/api.md#generate-a-completion). +#[allow(dead_code)] +#[derive(Debug, Serialize, Deserialize)] +pub struct OllamaGenerateRequest { + /// Model identifier (e.g., "mistral:latest") + pub model: String, + + /// The prompt to generate a response for (e.g., "List all Kubernetes pods") + pub prompt: String, + + /// The context parameter returned from a previous request to /generate, this can be used to keep a short conversational memory + pub context: Option>, + + /// Optional list of base64-encoded images (for multimodal models such as `llava`) + pub images: Option>, + + /// Optional format to use for the response (currently only "json" is supported) + pub format: Option, + + /// Optional flag that controls whether the response is streamed or not (defaults to true). 
+ /// If `false`` the response will be returned as a single response object, rather than a stream of objects + pub stream: Option, + + // System message (overrides what is defined in the Modelfile) + pub system: Option, + + /// Controls how long the model will stay loaded into memory following the request (default: 5m) + pub keep_alive: Option, +} + +impl Default for OllamaGenerateRequest { + fn default() -> Self { + Self { + model: ollama_model::CODESTRAL.to_string(), + prompt: "".to_string(), + stream: Some(false), + format: None, + images: None, + system: Some("You are a helpful assistant".to_string()), + keep_alive: Some("5m".to_string()), + context: None, + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +#[allow(dead_code)] +pub struct OllamaGenerateResponse { + /// Model identifier (e.g., "mistral:latest") + pub model: String, + + /// Time at which the response was generated (ISO 8601 format) + pub created_at: String, + + /// The response to the prompt + pub response: String, + + /// The encoding of the conversation used in this response, this can be sent in the next request to keep a conversational memory + pub context: Option>, + + /// The duration of the response in nanoseconds + pub total_duration: usize, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct OllamaGenerateStreamItemResponse { + /// Model identifier (e.g., "mistral:latest") + pub model: String, + + /// Time at which the response was generated (ISO 8601 format) + pub created_at: String, + + /// The response to the prompt + pub response: String, +} + +/// Request for generating an embedding from the Ollama API. +/// Referenced from the Ollama API documentation [here](https://github.com/ollama/ollama/blob/fedf71635ec77644f8477a86c6155217d9213a11/docs/api.md#generate-embeddings). +/// +#[allow(dead_code)] +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct OllamaEmbeddingsRequest { + /// The string to generate an embedding for. + pub prompt: String, + + /// The model to use for the embedding generation. + pub model: String, +} + +/// Response from the Ollama API for generating an embedding. +/// Referenced from the Ollama API documentation [here](https://github.com/ollama/ollama/blob/fedf71635ec77644f8477a86c6155217d9213a11/docs/api.md#generate-embeddings). +/// +#[allow(dead_code)] +#[derive(Debug, Serialize, Deserialize)] +pub struct OllamaEmbeddingsResponse { + /// The embedding for the prompt. 
+ pub embedding: Vec, +} diff --git a/src/llm/llm_provider/openai.rs b/src/llm/llm_provider/openai.rs new file mode 100644 index 0000000..f2dcb1e --- /dev/null +++ b/src/llm/llm_provider/openai.rs @@ -0,0 +1,57 @@ +// use async_trait::async_trait; +// use openai_api_rs::v1::{ +// api::OpenAIClient, +// chat_completion::{self, ChatCompletionRequest}, +// common::{GPT3_5_TURBO, GPT4, GPT4_O}, +// }; + +// pub mod openai_model { +// pub const GPT35_TURBO: &str = GPT35_TURBO; +// pub const GPT4: &str = GPT4; +// pub const GPT40: &str = GPT40; +// } + +// pub struct OpenAi<'a> { +// pub model: &'a str, +// api_key: &'a str, +// } + +// impl<'a> OpenAi<'a> { +// pub fn new(api_key: &'a str, model: &'a str) -> Self { +// Self { api_key, model } +// } +// } + +// #[async_trait] +// impl<'a> TextCompletionLlm for OpenAi<'a> { +// async fn complete( +// &self, +// system_prompts: &[String], +// ) -> Result> { +// let client = OpenAIClient::new(self.api_key.to_owned()); +// let system_msgs = system_prompts +// .iter() +// .map(|p| chat_completion::ChatCompletionMessage { +// role: chat_completion::MessageRole::system, +// content: chat_completion::Content::Text(p.to_owned()), +// name: None, +// tool_calls: None, +// tool_call_id: None, +// }) +// .collect::>(); +// let mut req = ChatCompletionRequest::new(self.model.to_owned(), system_msgs); +// req.max_tokens = Some(self.config.max_tokens as i64); +// req.temperature = Some(self.config.temperature); + +// let result = client.chat_completion(req).await?; +// let completion = result +// .choices +// .first() +// .unwrap() +// .message +// .content +// .clone() +// .unwrap(); +// Ok(completion) +// } +// } diff --git a/src/llm/mod.rs b/src/llm/mod.rs new file mode 100644 index 0000000..446b0ef --- /dev/null +++ b/src/llm/mod.rs @@ -0,0 +1,7 @@ +mod error; +mod llm_provider; +mod models; + +pub use error::*; +pub use llm_provider::*; +pub use models::*; diff --git a/src/llm/models.rs b/src/llm/models.rs new file mode 100644 index 0000000..fa9bf6b --- /dev/null +++ b/src/llm/models.rs @@ -0,0 +1,120 @@ +#![allow(dead_code)] + +use std::pin::Pin; + +use dyn_clone::DynClone; +use serde::{Deserialize, Serialize}; +use tokio_stream::Stream; + +use super::error::LlmError; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub enum LlmProvider { + #[serde(rename = "ollama")] + Ollama, + #[serde(rename = "openai")] + OpenAi, +} + +/// A trait for LLM providers which implements text completion, embeddings, etc. +/// +/// > `DynClone` is used so that there can be dynamic dispatch of the `Llm` trait, +/// > especially needed for [magic-cli](https://github.com/guywaldman/magic-cli). +pub trait Llm: DynClone { + /// Generates a response from the LLM. + /// + /// # Arguments + /// * `prompt` - The prompt to generate a response for. + /// * `system_prompt` - The system prompt to use for the generation. + /// * `options` - The options for the generation. + /// + /// # Returns + /// A [Result] containing the response from the LLM or an error if there was a problem. + /// + fn text_complete( + &self, + prompt: &str, + system_prompt: &str, + options: TextCompleteOptions, + ) -> impl std::future::Future> + Send; + + /// Generates a streaming response from the LLM. + /// + /// # Arguments + /// * `prompt` - The prompt to generate a response for. + /// * `system_prompt` - The system prompt to use for the generation. + /// * `options` - The options for the generation. 
+ /// + /// # Returns + /// A [Result] containing the response from the LLM or an error if there was a problem. + /// + fn text_complete_stream( + &self, + prompt: &str, + system_prompt: &str, + options: TextCompleteStreamOptions, + ) -> impl std::future::Future> + Send; + + /// Generates an embedding from the LLM. + /// + /// # Arguments + /// * `prompt` - The item to generate an embedding for. + /// + /// # Returns + /// + /// A [Result] containing the embedding or an error if there was a problem. + fn generate_embedding( + &self, + prompt: &str, + ) -> impl std::future::Future, LlmError>> + Send; + + /// Returns the provider of the LLM. + fn provider(&self) -> LlmProvider; + + /// Returns the name of the model used for text completions. + fn text_completion_model_name(&self) -> String; + + /// Returns the name of the model used for embeddings. + fn embedding_model_name(&self) -> String; +} + +#[derive(Debug, Clone, Default)] +pub struct TextCompleteOptions { + /// An encoding of the conversation used in this response, this can be sent in the next request to keep a conversational memory. + /// This should be as returned from the previous response. + pub context: Option>, +} + +#[derive(Debug, Clone, Default)] +pub struct TextCompleteStreamOptions { + pub context: Option>, +} + +#[derive(Debug, Clone)] +pub struct TextCompleteResponse { + pub text: String, + // TODO: This is specific to Ollama, context looks differently for other LLM providers. + pub context: Option>, +} + +pub struct TextCompleteStreamResponse { + pub stream: Pin> + Send>>, + // TODO: Handle context with streaming response. + // pub context: Vec, +} + +#[derive(Debug)] +pub(crate) struct SystemPromptResponseOption { + pub scenario: String, + pub type_name: String, + pub response: String, + pub schema: Vec, +} + +#[derive(Debug)] +pub(crate) struct SystemPromptCommandSchemaField { + pub name: String, + pub description: String, + pub typ: String, + pub example: String, +} From eb892d67e70a9fb3e0fc7a9ebcfa8ef1354678c2 Mon Sep 17 00:00:00 2001 From: Guy Waldman <6546430+guywaldman@users.noreply.github.com> Date: Fri, 19 Jul 2024 06:32:15 +0300 Subject: [PATCH 2/2] Fix crates.io API token name in GH Actions --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 10a2965..ee845f7 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -18,4 +18,4 @@ jobs: toolchain: stable override: true - name: Publish to crates.io - run: cargo publish --token ${{ secrets.CRATES_IO_TOKEN }} + run: cargo publish --token ${{ secrets.CRATES_IO_API_TOKEN }}
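
Usage sketch (illustrative only, not part of either patch above): the examples bundled in PATCH 1/2 all rely on the default `Ollama` configuration, so the snippet below shows how the `OllamaBuilder` setters introduced in `src/llm/llm_provider/ollama/llm.rs` might be combined with the `Executor` API. The base URL and model names are placeholder assumptions, not values taken from the patches.

// Sketch only: assumes the public API added in PATCH 1/2 (`OllamaBuilder`, `Executor`).
use orch::{Executor, OllamaBuilder};

// Current-thread flavor matches the `rt` + `macros` tokio features declared in Cargo.toml.
#[tokio::main(flavor = "current_thread")]
async fn main() {
    // Placeholder endpoint and model names; the defaults from `Ollama::default()`
    // ("http://localhost:11434", "codestral:latest", "nomic-embed-text:latest")
    // apply whenever these setters are omitted.
    let ollama = OllamaBuilder::new()
        .with_base_url("http://127.0.0.1:11434")
        .with_model("mistral:latest")
        .with_embeddings_model("nomic-embed-text:latest")
        .build();

    let executor = Executor::new(&ollama);
    let response = executor
        .text_complete("What is 2+2?", "You are a helpful assistant")
        .await
        .expect("Execution failed");
    println!("{}", response.text);
}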