Ollama-rs integration (#156)
* wrap ollama-rs with embedder

* ollama openai config added, todo streams

* update example, tiny refactors

* added streaming

* handle unwraps in stream

* `StreamData` value is now `data` in `stream`
erhant authored May 25, 2024
1 parent ff058b4 commit a0a8078
Showing 14 changed files with 360 additions and 100 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -2,3 +2,4 @@ target/
Cargo.lock
.DS_Store
.fastembed_cache
.vscode
54 changes: 27 additions & 27 deletions Cargo.toml
@@ -7,12 +7,7 @@ publish = true
repository = "https://github.com/Abraxas-365/langchain-rust"
license = "MIT"
description = "LangChain for Rust, the easiest way to write LLM-based programs in Rust"
keywords = [
"chain",
"chatgpt",
"llm",
"langchain",
] # List of keywords related to your crate
keywords = ["chain", "chatgpt", "llm", "langchain"]
documentation = "https://langchain-rust.sellie.tech/get-started/quickstart"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

@@ -32,16 +27,16 @@ async-openai = "0.21.0"
mockito = "1.4.0"
tiktoken-rs = "0.5.8"
sqlx = { version = "0.7.4", default-features = false, features = [
"postgres",
"sqlite",
"runtime-tokio-native-tls",
"json",
"uuid",
"postgres",
"sqlite",
"runtime-tokio-native-tls",
"json",
"uuid",
], optional = true }
uuid = { version = "1.8.0", features = ["v4"], optional = true }
pgvector = { version = "0.3.2", features = [
"postgres",
"sqlx",
"postgres",
"sqlx",
], optional = true }
text-splitter = { version = "0.13", features = ["tiktoken-rs", "markdown"] }
surrealdb = { version = "1.4.2", optional = true, default-features = false }
@@ -58,13 +53,13 @@ url = "2.5.0"
fastembed = "3"
flume = { version = "0.11.0", optional = true }
gix = { version = "0.62.0", default-features = false, optional = true, features = [
"parallel",
"revision",
"serde",
"parallel",
"revision",
"serde",
] }
opensearch = { version = "2", optional = true, features = ["aws-auth"] }
aws-config = { version = "1.2", optional = true, features = [
"behavior-version-latest",
"behavior-version-latest",
] }
glob = "0.3.1"
strum_macros = "0.26.2"
@@ -77,27 +72,32 @@ tree-sitter-c = { version = "0.21", optional = true }
tree-sitter-go = { version = "0.21", optional = true }
tree-sitter-python = { version = "0.21", optional = true }
tree-sitter-typescript = { version = "0.21", optional = true }
qdrant-client = {version = "1.8.0", optional = true }
qdrant-client = { version = "1.8.0", optional = true }
ollama-rs = { version = "0.1.9", optional = true, features = [
"stream",
"chat-history",
] }

[features]
default = []
postgres = ["pgvector", "sqlx", "uuid"]
tree-sitter = [
"cc",
"dep:tree-sitter",
"dep:tree-sitter-rust",
"dep:tree-sitter-cpp",
"dep:tree-sitter-javascript",
"dep:tree-sitter-c",
"dep:tree-sitter-go",
"dep:tree-sitter-python",
"dep:tree-sitter-typescript",
"cc",
"dep:tree-sitter",
"dep:tree-sitter-rust",
"dep:tree-sitter-cpp",
"dep:tree-sitter-javascript",
"dep:tree-sitter-c",
"dep:tree-sitter-go",
"dep:tree-sitter-python",
"dep:tree-sitter-typescript",
]
surrealdb = ["dep:surrealdb"]
sqlite = ["sqlx"]
git = ["gix", "flume"]
opensearch = ["dep:opensearch", "aws-config"]
qdrant = ["qdrant-client", "uuid"]
ollama = ['ollama-rs']

[dev-dependencies]
tokio-test = "0.4.4"
4 changes: 2 additions & 2 deletions README.md
@@ -20,14 +20,14 @@ This is the Rust language implementation of [LangChain](https://github.com/langc

- [x] [OpenAi](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/llm_openai.rs)
- [x] [Azure OpenAi](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/llm_azure_open_ai.rs)
- [x] [Ollama and Compatible Api](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/llm_ollama.rs)
- [x] [Ollama](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/llm_ollama.rs)
- [x] [Anthropic Claude](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/llm_anthropic_claude.rs)

- Embeddings

- [x] [OpenAi](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/embedding_openai.rs)
- [x] [Ollama](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/embedding_ollama.rs)
- [x] [Azure OpenAi](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/embedding_azure_open_ai.rs)
- [x] [Ollama](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/embedding_ollama.rs)
- [x] [Local FastEmbed](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/embedding_fastembed.rs)

- VectorStores
2 changes: 1 addition & 1 deletion examples/embedding_ollama.rs
@@ -6,7 +6,7 @@ use langchain_rust::embedding::{
async fn main() {
let ollama = OllamaEmbedder::default().with_model("nomic-embed-text");

let response = ollama.embed_query("What is the sky blue?").await.unwrap();
let response = ollama.embed_query("Why is the sky blue?").await.unwrap();

println!("{:?}", response);
}
14 changes: 2 additions & 12 deletions examples/llm_ollama.rs
@@ -1,18 +1,8 @@
use langchain_rust::llm::OpenAIConfig;

use langchain_rust::{language_models::llm::LLM, llm::openai::OpenAI};
use langchain_rust::{language_models::llm::LLM, llm::ollama::client::Ollama};

#[tokio::main]
async fn main() {
//Since Ollmama is OpenAi compatible
//You can call Ollama this way:
let ollama = OpenAI::default()
.with_config(
OpenAIConfig::default()
.with_api_base("http://localhost:11434/v1")
.with_api_key("ollama"),
)
.with_model("llama2");
let ollama = Ollama::default().with_model("llama3");

let response = ollama.invoke("hola").await.unwrap();
println!("{}", response);
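The rewritten example above calls `invoke`; this commit also adds streaming for the Ollama client. A minimal sketch of how that streaming path might be consumed, assuming the crate's existing `LLM::stream` signature, `Message::new_human_message`, and a `content` field on `StreamData` (none of which appear in this diff), plus a `futures` dependency for `StreamExt`:

```rust
// Hedged sketch, not part of this commit's diff. Method and field names on
// LLM/StreamData are assumptions based on the crate's existing trait.
use futures::StreamExt;
use langchain_rust::{
    language_models::llm::LLM, llm::ollama::client::Ollama, schemas::Message,
};

#[tokio::main]
async fn main() {
    let ollama = Ollama::default().with_model("llama3");

    let mut stream = ollama
        .stream(&[Message::new_human_message("Why is the sky blue?")])
        .await
        .unwrap();

    // Print tokens as they arrive instead of waiting for the full response.
    while let Some(chunk) = stream.next().await {
        match chunk {
            Ok(data) => print!("{}", data.content),
            Err(e) => eprintln!("stream error: {e}"),
        }
    }
}
```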
6 changes: 6 additions & 0 deletions src/embedding/error.rs
@@ -1,4 +1,6 @@
use async_openai::error::OpenAIError;
#[cfg(feature = "ollama")]
use ollama_rs::error::OllamaError;
use reqwest::{Error as ReqwestError, StatusCode};
use thiserror::Error;

@@ -21,4 +23,8 @@ pub enum EmbedderError {

#[error("FastEmbed error: {0}")]
FastEmbedError(String),

#[cfg(feature = "ollama")]
#[error("Ollama error: {0}")]
OllamaError(#[from] OllamaError),
}
8 changes: 7 additions & 1 deletion src/embedding/mod.rs
@@ -1,7 +1,13 @@
mod error;

pub mod embedder_trait;
pub use embedder_trait::*;
mod error;

#[cfg(feature = "ollama")]
pub mod ollama;
#[cfg(feature = "ollama")]
pub use ollama::*;

pub mod openai;
pub use error::*;

101 changes: 46 additions & 55 deletions src/embedding/ollama/ollama_embedder.rs
@@ -1,27 +1,29 @@
#![allow(dead_code)]
use std::sync::Arc;

use crate::embedding::{embedder_trait::Embedder, EmbedderError};
use async_trait::async_trait;
use reqwest::{Client, Url};
use serde::Deserialize;
use serde_json::json;

#[derive(Debug, Deserialize)]
struct EmbeddingResponse {
embedding: Vec<f64>,
}
use ollama_rs::{generation::options::GenerationOptions, Ollama as OllamaClient};

#[derive(Debug)]
pub struct OllamaEmbedder {
pub(crate) client: Arc<OllamaClient>,
pub(crate) model: String,
pub(crate) base_url: String,
pub(crate) options: Option<GenerationOptions>,
}

/// [nomic-embed-text](https://ollama.com/library/nomic-embed-text) is a 137M parameters, 274MB model.
const DEFAULT_MODEL: &str = "nomic-embed-text";

impl OllamaEmbedder {
pub fn new<S: Into<String>>(model: S, base_url: S) -> Self {
OllamaEmbedder {
pub fn new<S: Into<String>>(
client: Arc<OllamaClient>,
model: S,
options: Option<GenerationOptions>,
) -> Self {
Self {
client,
model: model.into(),
base_url: base_url.into(),
options,
}
}

@@ -30,74 +32,63 @@ impl OllamaEmbedder {
self
}

pub fn with_api_base<S: Into<String>>(mut self, base_url: S) -> Self {
self.base_url = base_url.into();
pub fn with_options(mut self, options: GenerationOptions) -> Self {
self.options = Some(options);
self
}
}

impl Default for OllamaEmbedder {
fn default() -> Self {
let model = String::from("nomic-embed-text");
let base_url = String::from("http://localhost:11434");
OllamaEmbedder::new(model, base_url)
let client = Arc::new(OllamaClient::default());
Self::new(client, String::from(DEFAULT_MODEL), None)
}
}

#[async_trait]
impl Embedder for OllamaEmbedder {
async fn embed_documents(&self, documents: &[String]) -> Result<Vec<Vec<f64>>, EmbedderError> {
log::debug!("Embedding documents: {:?}", documents);
let client = Client::new();
let url = Url::parse(&format!("{}{}", self.base_url, "/api/embeddings"))?;

let mut embeddings = Vec::with_capacity(documents.len());

for doc in documents {
let res = client
.post(url.clone())
.json(&json!({
"prompt": doc,
"model": &self.model,
}))
.send()
let res = self
.client
.generate_embeddings(self.model.clone(), doc.clone(), self.options.clone())
.await?;
if res.status() != 200 {
log::error!("Error from OLLAMA: {}", &res.status());
return Err(EmbedderError::HttpError {
status_code: res.status(),
error_message: format!("Received non-200 response: {}", res.status()),
});
}
let data: EmbeddingResponse = res.json().await?;
embeddings.push(data.embedding);

embeddings.push(res.embeddings);
}

Ok(embeddings)
}

async fn embed_query(&self, text: &str) -> Result<Vec<f64>, EmbedderError> {
log::debug!("Embedding query: {:?}", text);
let client = Client::new();
let url = Url::parse(&format!("{}{}", self.base_url, "/api/embeddings"))?;

let res = client
.post(url)
.json(&json!({
"prompt": text,
"model": &self.model,
}))
.send()

let res = self
.client
.generate_embeddings(self.model.clone(), text.to_string(), self.options.clone())
.await?;

if res.status() != 200 {
log::error!("Error from OLLAMA: {}", &res.status());
return Err(EmbedderError::HttpError {
status_code: res.status(),
error_message: format!("Received non-200 response: {}", res.status()),
});
}
let data: EmbeddingResponse = res.json().await?;
Ok(data.embedding)
Ok(res.embeddings)
}
}

#[cfg(test)]
mod tests {
use super::*;

#[tokio::test]
#[ignore]
async fn test_ollama_embed() {
let ollama = OllamaEmbedder::default()
.with_model("nomic-embed-text")
.with_options(GenerationOptions::default().temperature(0.5));

let response = ollama.embed_query("Why is the sky blue?").await.unwrap();

assert_eq!(response.len(), 768);
}
}
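With `with_api_base` removed, a non-default Ollama endpoint is now configured on the wrapped `ollama-rs` client that gets passed to the new constructor. A minimal sketch, assuming `ollama-rs` 0.1.x exposes `Ollama::new(host, port)` and that the re-exports match the existing `embedding_ollama` example (host and port below are illustrative):

```rust
// Hedged sketch, not part of this commit's diff.
use std::sync::Arc;

use langchain_rust::embedding::{embedder_trait::Embedder, OllamaEmbedder};
use ollama_rs::Ollama as OllamaClient;

#[tokio::main]
async fn main() {
    // Build the ollama-rs client against a custom endpoint, then hand it to
    // the embedder through the constructor shown in the diff above.
    let client = Arc::new(OllamaClient::new("http://localhost".to_string(), 11434));
    let embedder = OllamaEmbedder::new(client, "nomic-embed-text", None);

    let vectors = embedder
        .embed_documents(&["first document".to_string(), "second document".to_string()])
        .await
        .unwrap();

    println!("embedded {} documents", vectors.len());
}
```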
6 changes: 6 additions & 0 deletions src/language_models/error.rs
@@ -1,4 +1,6 @@
use async_openai::error::OpenAIError;
#[cfg(feature = "ollama")]
use ollama_rs::error::OllamaError;
use reqwest::Error as ReqwestError;
use serde_json::Error as SerdeJsonError;
use thiserror::Error;
@@ -14,6 +16,10 @@ pub enum LLMError {
#[error("Anthropic error: {0}")]
AnthropicError(#[from] AnthropicError),

#[cfg(feature = "ollama")]
#[error("Ollama error: {0}")]
OllamaError(#[from] OllamaError),

#[error("Network request failed: {0}")]
RequestError(#[from] ReqwestError),

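Because the new variant uses `#[from]`, an `ollama_rs::error::OllamaError` can be propagated with `?` inside feature-gated code. A small illustration, assuming ollama-rs's completion API (`GenerationRequest`, `Ollama::generate`); the helper function itself is hypothetical and not part of this commit:

```rust
// Hypothetical crate-internal helper; assumes LLMError is in scope.
#[cfg(feature = "ollama")]
async fn complete(client: &ollama_rs::Ollama) -> Result<String, LLMError> {
    use ollama_rs::generation::completion::request::GenerationRequest;

    let request = GenerationRequest::new("llama3".into(), "Why is the sky blue?".into());
    // `?` converts OllamaError into LLMError::OllamaError via the #[from] added above.
    let response = client.generate(request).await?;
    Ok(response.response)
}
```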
4 changes: 2 additions & 2 deletions src/llm/claude/client.rs
@@ -178,8 +178,8 @@ impl LLM for Claude {
.build()?;

// Instead of sending the request directly, return a stream wrapper
let stream = client.execute(request).await?.bytes_stream();

let stream = client.execute(request).await?;
let stream = stream.bytes_stream();
// Process each chunk as it arrives
let processed_stream = stream.then(move |result| {
async move {
3 changes: 3 additions & 0 deletions src/llm/mod.rs
@@ -3,3 +3,6 @@ pub use openai::*;

pub mod claude;
pub use claude::*;

pub mod ollama;
pub use ollama::*;