From a0a807837704e644bc29916ea5bc03712ae7e313 Mon Sep 17 00:00:00 2001
From: Erhan
Date: Sat, 25 May 2024 18:58:00 +0300
Subject: [PATCH] Ollama-rs integration (#156)

* wrap ollama-rs with embedder
* ollama openai config added, todo streams
* update example, tiny refactors
* added streaming
* handle unwraps in stream
* `StreamData` value is now `data` in `stream`
---
 .gitignore                              |   1 +
 Cargo.toml                              |  54 ++++----
 README.md                               |   4 +-
 examples/embedding_ollama.rs            |   2 +-
 examples/llm_ollama.rs                  |  14 +-
 src/embedding/error.rs                  |   6 +
 src/embedding/mod.rs                    |   8 +-
 src/embedding/ollama/ollama_embedder.rs | 101 +++++++--------
 src/language_models/error.rs            |   6 +
 src/llm/claude/client.rs                |   4 +-
 src/llm/mod.rs                          |   3 +
 src/llm/ollama/client.rs                | 164 ++++++++++++++++++++++++
 src/llm/ollama/mod.rs                   |   4 +
 src/llm/ollama/openai.rs                |  89 +++++++++++++
 14 files changed, 360 insertions(+), 100 deletions(-)
 create mode 100644 src/llm/ollama/client.rs
 create mode 100644 src/llm/ollama/mod.rs
 create mode 100644 src/llm/ollama/openai.rs

diff --git a/.gitignore b/.gitignore
index 12876e1a..e70adbe6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,3 +2,4 @@ target/
 Cargo.lock
 .DS_Store
 .fastembed_cache
+.vscode
diff --git a/Cargo.toml b/Cargo.toml
index e2589e69..f43746b8 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -7,12 +7,7 @@ publish = true
 repository = "https://github.com/Abraxas-365/langchain-rust"
 license = "MIT"
 description = "LangChain for Rust, the easiest way to write LLM-based programs in Rust"
-keywords = [
-    "chain",
-    "chatgpt",
-    "llm",
-    "langchain",
-] # List of keywords related to your crate
+keywords = ["chain", "chatgpt", "llm", "langchain"]
 documentation = "https://langchain-rust.sellie.tech/get-started/quickstart"
 
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@@ -32,16 +27,16 @@ async-openai = "0.21.0"
 mockito = "1.4.0"
 tiktoken-rs = "0.5.8"
 sqlx = { version = "0.7.4", default-features = false, features = [
-    "postgres",
-    "sqlite",
-    "runtime-tokio-native-tls",
-    "json",
-    "uuid",
+    "postgres",
+    "sqlite",
+    "runtime-tokio-native-tls",
+    "json",
+    "uuid",
 ], optional = true }
 uuid = { version = "1.8.0", features = ["v4"], optional = true }
 pgvector = { version = "0.3.2", features = [
-    "postgres",
-    "sqlx",
+    "postgres",
+    "sqlx",
 ], optional = true }
 text-splitter = { version = "0.13", features = ["tiktoken-rs", "markdown"] }
 surrealdb = { version = "1.4.2", optional = true, default-features = false }
@@ -58,13 +53,13 @@ url = "2.5.0"
 fastembed = "3"
 flume = { version = "0.11.0", optional = true }
 gix = { version = "0.62.0", default-features = false, optional = true, features = [
-    "parallel",
-    "revision",
-    "serde",
+    "parallel",
+    "revision",
+    "serde",
 ] }
 opensearch = { version = "2", optional = true, features = ["aws-auth"] }
 aws-config = { version = "1.2", optional = true, features = [
-    "behavior-version-latest",
+    "behavior-version-latest",
 ] }
 glob = "0.3.1"
 strum_macros = "0.26.2"
@@ -77,27 +72,32 @@ tree-sitter-c = { version = "0.21", optional = true }
 tree-sitter-go = { version = "0.21", optional = true }
 tree-sitter-python = { version = "0.21", optional = true }
 tree-sitter-typescript = { version = "0.21", optional = true }
-qdrant-client = {version = "1.8.0", optional = true }
+qdrant-client = { version = "1.8.0", optional = true }
+ollama-rs = { version = "0.1.9", optional = true, features = [
+    "stream",
+    "chat-history",
+] }
 
 [features]
 default = []
 postgres = ["pgvector", "sqlx", "uuid"]
 tree-sitter = [
-    "cc",
-    "dep:tree-sitter",
-    "dep:tree-sitter-rust",
-    "dep:tree-sitter-cpp",
-    "dep:tree-sitter-javascript",
-    "dep:tree-sitter-c",
-    "dep:tree-sitter-go",
-    "dep:tree-sitter-python",
-    "dep:tree-sitter-typescript",
+    "cc",
+    "dep:tree-sitter",
+    "dep:tree-sitter-rust",
+    "dep:tree-sitter-cpp",
+    "dep:tree-sitter-javascript",
+    "dep:tree-sitter-c",
+    "dep:tree-sitter-go",
+    "dep:tree-sitter-python",
+    "dep:tree-sitter-typescript",
 ]
 surrealdb = ["dep:surrealdb"]
 sqlite = ["sqlx"]
 git = ["gix", "flume"]
 opensearch = ["dep:opensearch", "aws-config"]
 qdrant = ["qdrant-client", "uuid"]
+ollama = ['ollama-rs']
 
 [dev-dependencies]
 tokio-test = "0.4.4"
diff --git a/README.md b/README.md
index 4f3e5b1d..118ccc81 100644
--- a/README.md
+++ b/README.md
@@ -20,14 +20,14 @@ This is the Rust language implementation of [LangChain](https://github.com/langc
 
   - [x] [OpenAi](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/llm_openai.rs)
   - [x] [Azure OpenAi](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/llm_azure_open_ai.rs)
-  - [x] [Ollama and Compatible Api](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/llm_ollama.rs)
+  - [x] [Ollama](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/llm_ollama.rs)
   - [x] [Anthropic Claude](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/llm_anthropic_claude.rs)
 
 - Embeddings
 
   - [x] [OpenAi](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/embedding_openai.rs)
-  - [x] [Ollama](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/embedding_ollama.rs)
   - [x] [Azure OpenAi](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/embedding_azure_open_ai.rs)
+  - [x] [Ollama](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/embedding_ollama.rs)
   - [x] [Local FastEmbed](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/embedding_fastembed.rs)
 
 - VectorStores
diff --git a/examples/embedding_ollama.rs b/examples/embedding_ollama.rs
index feb719d2..0fff1332 100644
--- a/examples/embedding_ollama.rs
+++ b/examples/embedding_ollama.rs
@@ -6,7 +6,7 @@ use langchain_rust::embedding::{
 async fn main() {
     let ollama = OllamaEmbedder::default().with_model("nomic-embed-text");
 
-    let response = ollama.embed_query("What is the sky blue?").await.unwrap();
+    let response = ollama.embed_query("Why is the sky blue?").await.unwrap();
 
     println!("{:?}", response);
 }
diff --git a/examples/llm_ollama.rs b/examples/llm_ollama.rs
index 9bd92f44..24830a5c 100644
--- a/examples/llm_ollama.rs
+++ b/examples/llm_ollama.rs
@@ -1,18 +1,8 @@
-use langchain_rust::llm::OpenAIConfig;
-
-use langchain_rust::{language_models::llm::LLM, llm::openai::OpenAI};
+use langchain_rust::{language_models::llm::LLM, llm::ollama::client::Ollama};
 
 #[tokio::main]
 async fn main() {
-    //Since Ollmama is OpenAi compatible
-    //You can call Ollama this way:
-    let ollama = OpenAI::default()
-        .with_config(
-            OpenAIConfig::default()
-                .with_api_base("http://localhost:11434/v1")
-                .with_api_key("ollama"),
-        )
-        .with_model("llama2");
+    let ollama = Ollama::default().with_model("llama3");
 
     let response = ollama.invoke("hola").await.unwrap();
     println!("{}", response);
diff --git a/src/embedding/error.rs b/src/embedding/error.rs
index e2aa4cdf..3a4e16b0 100644
--- a/src/embedding/error.rs
+++ b/src/embedding/error.rs
@@ -1,4 +1,6 @@
 use async_openai::error::OpenAIError;
+#[cfg(feature = "ollama")]
+use ollama_rs::error::OllamaError;
 use reqwest::{Error as ReqwestError, StatusCode};
 use thiserror::Error;
 
@@ -21,4 +23,8 @@ pub enum EmbedderError {
 
     #[error("FastEmbed error: {0}")]
     FastEmbedError(String),
+
+    #[cfg(feature = "ollama")]
+    #[error("Ollama error: {0}")]
+    OllamaError(#[from] OllamaError),
 }
diff --git a/src/embedding/mod.rs b/src/embedding/mod.rs
index da49278d..8b88344c 100644
--- a/src/embedding/mod.rs
+++ b/src/embedding/mod.rs
@@ -1,7 +1,13 @@
+mod error;
+
 pub mod embedder_trait;
 pub use embedder_trait::*;
-mod error;
+
+#[cfg(feature = "ollama")]
 pub mod ollama;
+#[cfg(feature = "ollama")]
+pub use ollama::*;
+
 pub mod openai;
 
 pub use error::*;
diff --git a/src/embedding/ollama/ollama_embedder.rs b/src/embedding/ollama/ollama_embedder.rs
index 37eea29e..2619d4c9 100644
--- a/src/embedding/ollama/ollama_embedder.rs
+++ b/src/embedding/ollama/ollama_embedder.rs
@@ -1,27 +1,29 @@
-#![allow(dead_code)]
+use std::sync::Arc;
 
 use crate::embedding::{embedder_trait::Embedder, EmbedderError};
 use async_trait::async_trait;
-use reqwest::{Client, Url};
-use serde::Deserialize;
-use serde_json::json;
-
-#[derive(Debug, Deserialize)]
-struct EmbeddingResponse {
-    embedding: Vec<f64>,
-}
+use ollama_rs::{generation::options::GenerationOptions, Ollama as OllamaClient};
 
 #[derive(Debug)]
 pub struct OllamaEmbedder {
+    pub(crate) client: Arc<OllamaClient>,
     pub(crate) model: String,
-    pub(crate) base_url: String,
+    pub(crate) options: Option<GenerationOptions>,
 }
 
+/// [nomic-embed-text](https://ollama.com/library/nomic-embed-text) is a 137M-parameter, 274MB model.
+const DEFAULT_MODEL: &str = "nomic-embed-text";
+
 impl OllamaEmbedder {
-    pub fn new<S: Into<String>>(model: S, base_url: S) -> Self {
-        OllamaEmbedder {
+    pub fn new<S: Into<String>>(
+        client: Arc<OllamaClient>,
+        model: S,
+        options: Option<GenerationOptions>,
+    ) -> Self {
+        Self {
+            client,
             model: model.into(),
-            base_url: base_url.into(),
+            options,
         }
     }
 
@@ -30,17 +32,16 @@ impl OllamaEmbedder {
         self
     }
 
-    pub fn with_api_base<S: Into<String>>(mut self, base_url: S) -> Self {
-        self.base_url = base_url.into();
+    pub fn with_options(mut self, options: GenerationOptions) -> Self {
+        self.options = Some(options);
         self
     }
 }
 
 impl Default for OllamaEmbedder {
     fn default() -> Self {
-        let model = String::from("nomic-embed-text");
-        let base_url = String::from("http://localhost:11434");
-        OllamaEmbedder::new(model, base_url)
+        let client = Arc::new(OllamaClient::default());
+        Self::new(client, String::from(DEFAULT_MODEL), None)
     }
 }
 
@@ -48,29 +49,16 @@ impl Default for OllamaEmbedder {
 impl Embedder for OllamaEmbedder {
     async fn embed_documents(&self, documents: &[String]) -> Result<Vec<Vec<f64>>, EmbedderError> {
         log::debug!("Embedding documents: {:?}", documents);
-        let client = Client::new();
-        let url = Url::parse(&format!("{}{}", self.base_url, "/api/embeddings"))?;
 
         let mut embeddings = Vec::with_capacity(documents.len());
 
         for doc in documents {
-            let res = client
-                .post(url.clone())
-                .json(&json!({
-                    "prompt": doc,
-                    "model": &self.model,
-                }))
-                .send()
+            let res = self
+                .client
+                .generate_embeddings(self.model.clone(), doc.clone(), self.options.clone())
                 .await?;
-            if res.status() != 200 {
-                log::error!("Error from OLLAMA: {}", &res.status());
-                return Err(EmbedderError::HttpError {
-                    status_code: res.status(),
-                    error_message: format!("Received non-200 response: {}", res.status()),
-                });
-            }
-            let data: EmbeddingResponse = res.json().await?;
-            embeddings.push(data.embedding);
+
+            embeddings.push(res.embeddings);
         }
 
         Ok(embeddings)
@@ -78,26 +66,29 @@ impl Embedder for OllamaEmbedder {
 
     async fn embed_query(&self, text: &str) -> Result<Vec<f64>, EmbedderError> {
         log::debug!("Embedding query: {:?}", text);
-        let client = Client::new();
-        let url = Url::parse(&format!("{}{}", self.base_url, "/api/embeddings"))?;
-
-        let res = client
-            .post(url)
-            .json(&json!({
-                "prompt": text,
-                "model": &self.model,
-            }))
-            .send()
+
+        let res = self
+            .client
+            .generate_embeddings(self.model.clone(), text.to_string(), self.options.clone())
            .await?;
 
-        if res.status() != 200 {
-            log::error!("Error from OLLAMA: {}", &res.status());
-            return Err(EmbedderError::HttpError {
-                status_code: res.status(),
-                error_message: format!("Received non-200 response: {}", res.status()),
-            });
-        }
-        let data: EmbeddingResponse = res.json().await?;
-        Ok(data.embedding)
+        Ok(res.embeddings)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[tokio::test]
+    #[ignore]
+    async fn test_ollama_embed() {
+        let ollama = OllamaEmbedder::default()
+            .with_model("nomic-embed-text")
+            .with_options(GenerationOptions::default().temperature(0.5));
+
+        let response = ollama.embed_query("Why is the sky blue?").await.unwrap();
+
+        assert_eq!(response.len(), 768);
     }
 }
diff --git a/src/language_models/error.rs b/src/language_models/error.rs
index b4b64827..3127c4bd 100644
--- a/src/language_models/error.rs
+++ b/src/language_models/error.rs
@@ -1,4 +1,6 @@
 use async_openai::error::OpenAIError;
+#[cfg(feature = "ollama")]
+use ollama_rs::error::OllamaError;
 use reqwest::Error as ReqwestError;
 use serde_json::Error as SerdeJsonError;
 use thiserror::Error;
@@ -14,6 +16,10 @@ pub enum LLMError {
     #[error("Anthropic error: {0}")]
     AnthropicError(#[from] AnthropicError),
 
+    #[cfg(feature = "ollama")]
+    #[error("Ollama error: {0}")]
+    OllamaError(#[from] OllamaError),
+
     #[error("Network request failed: {0}")]
     RequestError(#[from] ReqwestError),
 
diff --git a/src/llm/claude/client.rs b/src/llm/claude/client.rs
index eac555d5..6979eb7b 100644
--- a/src/llm/claude/client.rs
+++ b/src/llm/claude/client.rs
@@ -178,8 +178,8 @@ impl LLM for Claude {
             .build()?;
 
         // Instead of sending the request directly, return a stream wrapper
-        let stream = client.execute(request).await?.bytes_stream();
-
+        let stream = client.execute(request).await?;
+        let stream = stream.bytes_stream();
         // Process each chunk as it arrives
         let processed_stream = stream.then(move |result| {
             async move {
diff --git a/src/llm/mod.rs b/src/llm/mod.rs
index b106bcf6..43dd90cd 100644
--- a/src/llm/mod.rs
+++ b/src/llm/mod.rs
@@ -3,3 +3,6 @@ pub use openai::*;
 
 pub mod claude;
 pub use claude::*;
+
+pub mod ollama;
+pub use ollama::*;
diff --git a/src/llm/ollama/client.rs b/src/llm/ollama/client.rs
new file mode 100644
index 00000000..11377b5a
--- /dev/null
+++ b/src/llm/ollama/client.rs
@@ -0,0 +1,164 @@
+use crate::{
+    language_models::{llm::LLM, GenerateResult, LLMError, TokenUsage},
+    schemas::{Message, MessageType, StreamData},
+};
+use async_trait::async_trait;
+use futures::Stream;
+use ollama_rs::{
+    error::OllamaError,
+    generation::{
+        chat::{request::ChatMessageRequest, ChatMessage, MessageRole},
+        options::GenerationOptions,
+    },
+    Ollama as OllamaClient,
+};
+use std::pin::Pin;
+use std::sync::Arc;
+use tokio_stream::StreamExt;
+
+#[derive(Debug, Clone)]
+pub struct Ollama {
+    pub(crate) client: Arc<OllamaClient>,
+    pub(crate) model: String,
+    pub(crate) options: Option<GenerationOptions>,
+}
+
+/// [llama3](https://ollama.com/library/llama3) is an 8B-parameter, 4.7GB model.
+const DEFAULT_MODEL: &str = "llama3";
+
+impl Ollama {
+    pub fn new<S: Into<String>>(
+        client: Arc<OllamaClient>,
+        model: S,
+        options: Option<GenerationOptions>,
+    ) -> Self {
+        Ollama {
+            client,
+            model: model.into(),
+            options,
+        }
+    }
+
+    pub fn with_model<S: Into<String>>(mut self, model: S) -> Self {
+        self.model = model.into();
+        self
+    }
+
+    pub fn with_options(mut self, options: GenerationOptions) -> Self {
+        self.options = Some(options);
+        self
+    }
+
+    fn generate_request(&self, messages: &[Message]) -> ChatMessageRequest {
+        let mapped_messages = messages.iter().map(|message| message.into()).collect();
+        ChatMessageRequest::new(self.model.clone(), mapped_messages)
+    }
+}
+
+impl From<&Message> for ChatMessage {
+    fn from(message: &Message) -> Self {
+        ChatMessage {
+            content: message.content.clone(),
+            images: None,
+            role: message.message_type.clone().into(),
+        }
+    }
+}
+
+impl From<MessageType> for MessageRole {
+    fn from(message_type: MessageType) -> Self {
+        match message_type {
+            MessageType::AIMessage => MessageRole::Assistant,
+            MessageType::ToolMessage => MessageRole::Assistant,
+            MessageType::SystemMessage => MessageRole::System,
+            MessageType::HumanMessage => MessageRole::User,
+        }
+    }
+}
+
+impl Default for Ollama {
+    fn default() -> Self {
+        let client = Arc::new(OllamaClient::default());
+        Ollama::new(client, String::from(DEFAULT_MODEL), None)
+    }
+}
+
+#[async_trait]
+impl LLM for Ollama {
+    async fn generate(&self, messages: &[Message]) -> Result<GenerateResult, LLMError> {
+        let request = self.generate_request(messages);
+        let result = self.client.send_chat_messages(request).await?;
+
+        let generation = match result.message {
+            Some(message) => message.content,
+            None => return Err(OllamaError::from("No message in response".to_string()).into()),
+        };
+
+        let tokens = result.final_data.map(|final_data| {
+            let prompt_tokens = final_data.prompt_eval_count as u32;
+            let completion_tokens = final_data.eval_count as u32;
+            TokenUsage {
+                prompt_tokens,
+                completion_tokens,
+                total_tokens: prompt_tokens + completion_tokens,
+            }
+        });
+
+        Ok(GenerateResult { tokens, generation })
+    }
+
+    async fn stream(
+        &self,
+        messages: &[Message],
+    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamData, LLMError>> + Send>>, LLMError> {
+        let request = self.generate_request(messages);
+        let result = self.client.send_chat_messages_stream(request).await?;
+
+        let stream = result.map(|data| match data {
+            Ok(data) => match data.message.clone() {
+                Some(message) => Ok(StreamData::new(
+                    serde_json::to_value(data).unwrap_or_default(),
+                    message.content,
+                )),
+                // TODO: no need to return error, see https://github.com/Abraxas-365/langchain-rust/issues/140
+                None => Err(LLMError::ContentNotFound(
+                    "No message in response".to_string(),
+                )),
+            },
+            Err(_) => Err(OllamaError::from("Stream error".to_string()).into()),
+        });
+
+        Ok(Box::pin(stream))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use tokio::io::AsyncWriteExt;
+    use tokio_stream::StreamExt;
+
+    #[tokio::test]
+    #[ignore]
+    async fn test_generate() {
+        let ollama = Ollama::default().with_model("llama3");
+        let response = ollama.invoke("Hey Macarena, ay").await.unwrap();
+        println!("{}", response);
+    }
+
+    #[tokio::test]
+    #[ignore]
+    async fn test_stream() {
+        let ollama = Ollama::default().with_model("llama3");
+
+        let message = Message::new_human_message("Why does water boil at 100 degrees?");
+        let mut stream = ollama.stream(&vec![message]).await.unwrap();
+        let mut stdout = tokio::io::stdout();
+        while let Some(res) = stream.next().await {
+            let data = res.unwrap();
+            stdout.write(data.content.as_bytes()).await.unwrap();
+        }
+        stdout.write(b"\n").await.unwrap();
+        stdout.flush().await.unwrap();
+    }
+}
diff --git a/src/llm/ollama/mod.rs b/src/llm/ollama/mod.rs
new file mode 100644
index 00000000..1d73fa14
--- /dev/null
+++ b/src/llm/ollama/mod.rs
@@ -0,0 +1,4 @@
+#[cfg(feature = "ollama")]
+pub mod client;
+
+pub mod openai;
diff --git a/src/llm/ollama/openai.rs b/src/llm/ollama/openai.rs
new file mode 100644
index 00000000..24119b88
--- /dev/null
+++ b/src/llm/ollama/openai.rs
@@ -0,0 +1,89 @@
+use async_openai::config::Config;
+use reqwest::header::HeaderMap;
+use secrecy::Secret;
+use serde::Deserialize;
+
+const OLLAMA_API_BASE: &str = "http://localhost:11434/v1";
+
+/// Ollama has [OpenAI compatibility](https://ollama.com/blog/openai-compatibility), meaning that you can use it as an OpenAI API.
+///
+/// This struct implements the `Config` trait of OpenAI, and has the necessary setup for OpenAI configurations for you to use Ollama.
+///
+/// ## Example
+///
+/// ```rs
+/// let ollama = OpenAI::new(OllamaConfig::default()).with_model("llama3");
+/// let response = ollama.invoke("Say hello!").await.unwrap();
+/// ```
+#[derive(Clone, Debug, Deserialize)]
+#[serde(default)]
+pub struct OllamaConfig {
+    api_key: Secret<String>,
+}
+
+impl OllamaConfig {
+    pub fn new() -> Self {
+        Self::default()
+    }
+}
+
+impl Config for OllamaConfig {
+    fn api_key(&self) -> &Secret<String> {
+        &self.api_key
+    }
+
+    fn api_base(&self) -> &str {
+        OLLAMA_API_BASE
+    }
+
+    fn headers(&self) -> HeaderMap {
+        HeaderMap::default()
+    }
+
+    fn query(&self) -> Vec<(&str, &str)> {
+        vec![]
+    }
+
+    fn url(&self, path: &str) -> String {
+        format!("{}{}", self.api_base(), path)
+    }
+}
+
+impl Default for OllamaConfig {
+    fn default() -> Self {
+        Self {
+            api_key: Secret::new("ollama".to_string()),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::{language_models::llm::LLM, llm::openai::OpenAI, schemas::Message};
+    use tokio::io::AsyncWriteExt;
+    use tokio_stream::StreamExt;
+
+    #[tokio::test]
+    #[ignore]
+    async fn test_ollama_openai() {
+        let ollama = OpenAI::new(OllamaConfig::default()).with_model("llama2");
+        let response = ollama.invoke("hola").await.unwrap();
+        println!("{}", response);
+    }
+
+    #[tokio::test]
+    #[ignore]
+    async fn test_ollama_openai_stream() {
+        let ollama = OpenAI::new(OllamaConfig::default()).with_model("phi3");
+
+        let message = Message::new_human_message("Why does water boil at 100 degrees?");
+        let mut stream = ollama.stream(&vec![message]).await.unwrap();
+        let mut stdout = tokio::io::stdout();
+        while let Some(res) = stream.next().await {
+            let data = res.unwrap();
+            stdout.write(data.content.as_bytes()).await.unwrap();
+        }
+        stdout.flush().await.unwrap();
+    }
+}
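
For reviewers trying this out: a minimal end-to-end sketch of how the new `ollama` feature is meant to be used, pieced together from `examples/llm_ollama.rs`, `examples/embedding_ollama.rs`, and the tests in this patch. It assumes the crate is built with the `ollama` feature enabled and that a local Ollama server is running on the default port with the `llama3` and `nomic-embed-text` models already pulled; the exact re-export paths are as they appear in `src/embedding/mod.rs` and `src/llm/mod.rs` and may differ slightly if those change.

```rust
use langchain_rust::{
    embedding::{Embedder, OllamaEmbedder},
    language_models::llm::LLM,
    llm::ollama::client::Ollama,
};

#[tokio::main]
async fn main() {
    // Chat generation through the native ollama-rs client (src/llm/ollama/client.rs).
    let ollama = Ollama::default().with_model("llama3");
    let response = ollama.invoke("Why is the sky blue?").await.unwrap();
    println!("{}", response);

    // Embeddings against the same local server (src/embedding/ollama/ollama_embedder.rs).
    let embedder = OllamaEmbedder::default().with_model("nomic-embed-text");
    let embedding = embedder.embed_query("Why is the sky blue?").await.unwrap();
    println!("embedding dimensions: {}", embedding.len());
}
```

The OpenAI-compatible route from `src/llm/ollama/openai.rs` (`OpenAI::new(OllamaConfig::default())`) remains available for anyone who prefers to keep using the OpenAI client against a local Ollama server.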