Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ollama-rs integration #156

Merged
merged 6 commits into from
May 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ target/
Cargo.lock
.DS_Store
.fastembed_cache
.vscode
54 changes: 27 additions & 27 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,7 @@ publish = true
repository = "https://github.com/Abraxas-365/langchain-rust"
license = "MIT"
description = "LangChain for Rust, the easiest way to write LLM-based programs in Rust"
keywords = [
"chain",
"chatgpt",
"llm",
"langchain",
] # List of keywords related to your crate
keywords = ["chain", "chatgpt", "llm", "langchain"]
documentation = "https://langchain-rust.sellie.tech/get-started/quickstart"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

Expand All @@ -32,16 +27,16 @@ async-openai = "0.21.0"
mockito = "1.4.0"
tiktoken-rs = "0.5.8"
sqlx = { version = "0.7.4", default-features = false, features = [
"postgres",
"sqlite",
"runtime-tokio-native-tls",
"json",
"uuid",
"postgres",
"sqlite",
"runtime-tokio-native-tls",
"json",
"uuid",
], optional = true }
uuid = { version = "1.8.0", features = ["v4"], optional = true }
pgvector = { version = "0.3.2", features = [
"postgres",
"sqlx",
"postgres",
"sqlx",
], optional = true }
text-splitter = { version = "0.13", features = ["tiktoken-rs", "markdown"] }
surrealdb = { version = "1.4.2", optional = true, default-features = false }
Expand All @@ -58,13 +53,13 @@ url = "2.5.0"
fastembed = "3"
flume = { version = "0.11.0", optional = true }
gix = { version = "0.62.0", default-features = false, optional = true, features = [
"parallel",
"revision",
"serde",
"parallel",
"revision",
"serde",
] }
opensearch = { version = "2", optional = true, features = ["aws-auth"] }
aws-config = { version = "1.2", optional = true, features = [
"behavior-version-latest",
"behavior-version-latest",
] }
glob = "0.3.1"
strum_macros = "0.26.2"
Expand All @@ -77,27 +72,32 @@ tree-sitter-c = { version = "0.21", optional = true }
tree-sitter-go = { version = "0.21", optional = true }
tree-sitter-python = { version = "0.21", optional = true }
tree-sitter-typescript = { version = "0.21", optional = true }
qdrant-client = {version = "1.8.0", optional = true }
qdrant-client = { version = "1.8.0", optional = true }
ollama-rs = { version = "0.1.9", optional = true, features = [
"stream",
"chat-history",
] }

[features]
default = []
postgres = ["pgvector", "sqlx", "uuid"]
tree-sitter = [
"cc",
"dep:tree-sitter",
"dep:tree-sitter-rust",
"dep:tree-sitter-cpp",
"dep:tree-sitter-javascript",
"dep:tree-sitter-c",
"dep:tree-sitter-go",
"dep:tree-sitter-python",
"dep:tree-sitter-typescript",
"cc",
"dep:tree-sitter",
"dep:tree-sitter-rust",
"dep:tree-sitter-cpp",
"dep:tree-sitter-javascript",
"dep:tree-sitter-c",
"dep:tree-sitter-go",
"dep:tree-sitter-python",
"dep:tree-sitter-typescript",
]
surrealdb = ["dep:surrealdb"]
sqlite = ["sqlx"]
git = ["gix", "flume"]
opensearch = ["dep:opensearch", "aws-config"]
qdrant = ["qdrant-client", "uuid"]
ollama = ['ollama-rs']

[dev-dependencies]
tokio-test = "0.4.4"
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ This is the Rust language implementation of [LangChain](https://github.com/langc

- [x] [OpenAi](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/llm_openai.rs)
- [x] [Azure OpenAi](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/llm_azure_open_ai.rs)
- [x] [Ollama and Compatible Api](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/llm_ollama.rs)
- [x] [Ollama](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/llm_ollama.rs)
- [x] [Anthropic Claude](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/llm_anthropic_claude.rs)

- Embeddings

- [x] [OpenAi](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/embedding_openai.rs)
- [x] [Ollama](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/embedding_ollama.rs)
- [x] [Azure OpenAi](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/embedding_azure_open_ai.rs)
- [x] [Ollama](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/embedding_ollama.rs)
- [x] [Local FastEmbed](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/embedding_fastembed.rs)

- VectorStores
Expand Down
2 changes: 1 addition & 1 deletion examples/embedding_ollama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use langchain_rust::embedding::{
async fn main() {
let ollama = OllamaEmbedder::default().with_model("nomic-embed-text");

let response = ollama.embed_query("What is the sky blue?").await.unwrap();
let response = ollama.embed_query("Why is the sky blue?").await.unwrap();

println!("{:?}", response);
}
14 changes: 2 additions & 12 deletions examples/llm_ollama.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,8 @@
use langchain_rust::llm::OpenAIConfig;

use langchain_rust::{language_models::llm::LLM, llm::openai::OpenAI};
use langchain_rust::{language_models::llm::LLM, llm::ollama::client::Ollama};

#[tokio::main]
async fn main() {
//Since Ollmama is OpenAi compatible
//You can call Ollama this way:
let ollama = OpenAI::default()
.with_config(
OpenAIConfig::default()
.with_api_base("http://localhost:11434/v1")
.with_api_key("ollama"),
)
.with_model("llama2");
let ollama = Ollama::default().with_model("llama3");

let response = ollama.invoke("hola").await.unwrap();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably worth saying Hello instead of hola.

println!("{}", response);
Expand Down
6 changes: 6 additions & 0 deletions src/embedding/error.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use async_openai::error::OpenAIError;
#[cfg(feature = "ollama")]
use ollama_rs::error::OllamaError;
use reqwest::{Error as ReqwestError, StatusCode};
use thiserror::Error;

Expand All @@ -21,4 +23,8 @@ pub enum EmbedderError {

#[error("FastEmbed error: {0}")]
FastEmbedError(String),

#[cfg(feature = "ollama")]
#[error("Ollama error: {0}")]
OllamaError(#[from] OllamaError),
}
8 changes: 7 additions & 1 deletion src/embedding/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
mod error;

pub mod embedder_trait;
pub use embedder_trait::*;
mod error;

#[cfg(feature = "ollama")]
pub mod ollama;
#[cfg(feature = "ollama")]
pub use ollama::*;

pub mod openai;
pub use error::*;

Expand Down
101 changes: 46 additions & 55 deletions src/embedding/ollama/ollama_embedder.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,29 @@
#![allow(dead_code)]
use std::sync::Arc;

use crate::embedding::{embedder_trait::Embedder, EmbedderError};
use async_trait::async_trait;
use reqwest::{Client, Url};
use serde::Deserialize;
use serde_json::json;

#[derive(Debug, Deserialize)]
struct EmbeddingResponse {
embedding: Vec<f64>,
}
use ollama_rs::{generation::options::GenerationOptions, Ollama as OllamaClient};

#[derive(Debug)]
pub struct OllamaEmbedder {
pub(crate) client: Arc<OllamaClient>,
pub(crate) model: String,
pub(crate) base_url: String,
pub(crate) options: Option<GenerationOptions>,
}

/// [nomic-embed-text](https://ollama.com/library/nomic-embed-text) is a 137M parameters, 274MB model.
const DEFAULT_MODEL: &str = "nomic-embed-text";

impl OllamaEmbedder {
pub fn new<S: Into<String>>(model: S, base_url: S) -> Self {
OllamaEmbedder {
pub fn new<S: Into<String>>(
client: Arc<OllamaClient>,
model: S,
options: Option<GenerationOptions>,
) -> Self {
Self {
client,
model: model.into(),
base_url: base_url.into(),
options,
}
}

Expand All @@ -30,74 +32,63 @@ impl OllamaEmbedder {
self
}

pub fn with_api_base<S: Into<String>>(mut self, base_url: S) -> Self {
self.base_url = base_url.into();
pub fn with_options(mut self, options: GenerationOptions) -> Self {
self.options = Some(options);
self
}
}

impl Default for OllamaEmbedder {
fn default() -> Self {
let model = String::from("nomic-embed-text");
let base_url = String::from("http://localhost:11434");
OllamaEmbedder::new(model, base_url)
let client = Arc::new(OllamaClient::default());
Self::new(client, String::from(DEFAULT_MODEL), None)
}
}

#[async_trait]
impl Embedder for OllamaEmbedder {
async fn embed_documents(&self, documents: &[String]) -> Result<Vec<Vec<f64>>, EmbedderError> {
log::debug!("Embedding documents: {:?}", documents);
let client = Client::new();
let url = Url::parse(&format!("{}{}", self.base_url, "/api/embeddings"))?;

let mut embeddings = Vec::with_capacity(documents.len());

for doc in documents {
let res = client
.post(url.clone())
.json(&json!({
"prompt": doc,
"model": &self.model,
}))
.send()
let res = self
.client
.generate_embeddings(self.model.clone(), doc.clone(), self.options.clone())
.await?;
if res.status() != 200 {
log::error!("Error from OLLAMA: {}", &res.status());
return Err(EmbedderError::HttpError {
status_code: res.status(),
error_message: format!("Received non-200 response: {}", res.status()),
});
}
let data: EmbeddingResponse = res.json().await?;
embeddings.push(data.embedding);

embeddings.push(res.embeddings);
}

Ok(embeddings)
}

async fn embed_query(&self, text: &str) -> Result<Vec<f64>, EmbedderError> {
log::debug!("Embedding query: {:?}", text);
let client = Client::new();
let url = Url::parse(&format!("{}{}", self.base_url, "/api/embeddings"))?;

let res = client
.post(url)
.json(&json!({
"prompt": text,
"model": &self.model,
}))
.send()

let res = self
.client
.generate_embeddings(self.model.clone(), text.to_string(), self.options.clone())
.await?;

if res.status() != 200 {
log::error!("Error from OLLAMA: {}", &res.status());
return Err(EmbedderError::HttpError {
status_code: res.status(),
error_message: format!("Received non-200 response: {}", res.status()),
});
}
let data: EmbeddingResponse = res.json().await?;
Ok(data.embedding)
Ok(res.embeddings)
}
}

#[cfg(test)]
mod tests {
use super::*;

#[tokio::test]
#[ignore]
async fn test_ollama_embed() {
let ollama = OllamaEmbedder::default()
.with_model("nomic-embed-text")
.with_options(GenerationOptions::default().temperature(0.5));

let response = ollama.embed_query("Why is the sky blue?").await.unwrap();

assert_eq!(response.len(), 768);
}
}
6 changes: 6 additions & 0 deletions src/language_models/error.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use async_openai::error::OpenAIError;
#[cfg(feature = "ollama")]
use ollama_rs::error::OllamaError;
use reqwest::Error as ReqwestError;
use serde_json::Error as SerdeJsonError;
use thiserror::Error;
Expand All @@ -14,6 +16,10 @@ pub enum LLMError {
#[error("Anthropic error: {0}")]
AnthropicError(#[from] AnthropicError),

#[cfg(feature = "ollama")]
#[error("Ollama error: {0}")]
OllamaError(#[from] OllamaError),

#[error("Network request failed: {0}")]
RequestError(#[from] ReqwestError),

Expand Down
4 changes: 2 additions & 2 deletions src/llm/claude/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,8 @@ impl LLM for Claude {
.build()?;

// Instead of sending the request directly, return a stream wrapper
let stream = client.execute(request).await?.bytes_stream();

let stream = client.execute(request).await?;
let stream = stream.bytes_stream();
// Process each chunk as it arrives
let processed_stream = stream.then(move |result| {
async move {
Expand Down
3 changes: 3 additions & 0 deletions src/llm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,6 @@ pub use openai::*;

pub mod claude;
pub use claude::*;

pub mod ollama;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should these just be mod ollama; instead of pub mod ollama;.

Since this already exists probably can be fixed separately.

pub use ollama::*;
Loading
Loading