diff --git a/Cargo.lock b/Cargo.lock index ba1f331..f611c66 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -711,6 +711,7 @@ dependencies = [ "async-stream", "async-trait", "base64", + "http", "log", "ollama-rs", "regex", diff --git a/Cargo.toml b/Cargo.toml index 6b32046..e88e168 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,18 +22,21 @@ scraper = { version = "0.19.0", optional = true } text-splitter = { version = "0.13.1", optional = true } regex = { version = "1.9.3", optional = true } async-stream = "0.3.5" +http = {version = "1.1.0", optional = true } [features] default = ["reqwest/default-tls"] stream = ["tokio-stream", "reqwest/stream", "tokio"] rustls = ["reqwest/rustls-tls"] chat-history = [] +headers = ["http"] function-calling = ["scraper", "text-splitter", "regex", "chat-history"] [dev-dependencies] tokio = { version = "1", features = ["full"] } ollama-rs = { path = ".", features = [ "stream", + "headers", "chat-history", "function-calling", ] } diff --git a/src/generation/chat/mod.rs b/src/generation/chat/mod.rs index 09d628e..bf511f0 100644 --- a/src/generation/chat/mod.rs +++ b/src/generation/chat/mod.rs @@ -32,9 +32,12 @@ impl Ollama { let serialized = serde_json::to_string(&request) .map_err(|e| e.to_string()) .unwrap(); - let res = self - .reqwest_client - .post(url) + let builder = self.reqwest_client.post(url); + + #[cfg(feature = "headers")] + let builder = builder.headers(self.request_headers.clone()); + + let res = builder .body(serialized) .send() .await @@ -75,9 +78,12 @@ impl Ollama { let url = format!("{}api/chat", self.url_str()); let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?; - let res = self - .reqwest_client - .post(url) + let builder = self.reqwest_client.post(url); + + #[cfg(feature = "headers")] + let builder = builder.headers(self.request_headers.clone()); + + let res = builder .body(serialized) .send() .await diff --git a/src/generation/completion/mod.rs b/src/generation/completion/mod.rs index a7ba8cd..99382d0 100644 --- a/src/generation/completion/mod.rs +++ b/src/generation/completion/mod.rs @@ -32,9 +32,12 @@ impl Ollama { let url = format!("{}api/generate", self.url_str()); let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?; - let res = self - .reqwest_client - .post(url) + let builder = self.reqwest_client.post(url); + + #[cfg(feature = "headers")] + let builder = builder.headers(self.request_headers.clone()); + + let res = builder .body(serialized) .send() .await @@ -70,9 +73,12 @@ impl Ollama { let url = format!("{}api/generate", self.url_str()); let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?; - let res = self - .reqwest_client - .post(url) + let builder = self.reqwest_client.post(url); + + #[cfg(feature = "headers")] + let builder = builder.headers(self.request_headers.clone()); + + let res = builder .body(serialized) .send() .await diff --git a/src/generation/embeddings/mod.rs b/src/generation/embeddings/mod.rs index b19b579..3c4918f 100644 --- a/src/generation/embeddings/mod.rs +++ b/src/generation/embeddings/mod.rs @@ -16,9 +16,12 @@ impl Ollama { ) -> crate::error::Result { let url = format!("{}api/embed", self.url_str()); let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?; - let res = self - .reqwest_client - .post(url) + let builder = self.reqwest_client.post(url); + + #[cfg(feature = "headers")] + let builder = builder.headers(self.request_headers.clone()); + + let res = builder .body(serialized) .send() .await diff --git a/src/headers.rs b/src/headers.rs new file mode 100644 index 0000000..40d49dd --- /dev/null +++ b/src/headers.rs @@ -0,0 +1,22 @@ +use crate::{IntoUrl, Ollama}; + +pub use http::header::*; + +impl Ollama { + /// # Panics + /// + /// Panics if the host is not a valid URL or if the URL cannot have a port. + pub fn new_with_request_headers(host: impl IntoUrl, port: u16, headers: HeaderMap) -> Self { + let mut ollama = Self::new(host, port); + ollama.set_headers(Some(headers)); + + ollama + } + + pub fn set_headers(&mut self, headers: Option) { + match headers { + Some(h) => self.request_headers = h, + None => self.request_headers = HeaderMap::new(), + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 6e3369c..3dfaf7a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,6 +4,8 @@ use url::Url; pub mod error; pub mod generation; +#[cfg(feature = "headers")] +pub mod headers; #[cfg(feature = "chat-history")] pub mod history; pub mod models; @@ -69,6 +71,8 @@ impl IntoUrlSealed for String { pub struct Ollama { pub(crate) url: Url, pub(crate) reqwest_client: reqwest::Client, + #[cfg(feature = "headers")] + pub(crate) request_headers: reqwest::header::HeaderMap, #[cfg(feature = "chat-history")] pub(crate) messages_history: Option, } @@ -145,6 +149,8 @@ impl Default for Ollama { Self { url: Url::parse("http://127.0.0.1:11434").unwrap(), reqwest_client: reqwest::Client::new(), + #[cfg(feature = "headers")] + request_headers: reqwest::header::HeaderMap::new(), #[cfg(feature = "chat-history")] messages_history: None, } diff --git a/src/models/copy.rs b/src/models/copy.rs index 65b7c16..a868a92 100644 --- a/src/models/copy.rs +++ b/src/models/copy.rs @@ -16,9 +16,12 @@ impl Ollama { let url = format!("{}api/copy", self.url_str()); let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?; - let res = self - .reqwest_client - .post(url) + let builder = self.reqwest_client.post(url); + + #[cfg(feature = "headers")] + let builder = builder.headers(self.request_headers.clone()); + + let res = builder .body(serialized) .send() .await diff --git a/src/models/create.rs b/src/models/create.rs index eb580a7..7bc75df 100644 --- a/src/models/create.rs +++ b/src/models/create.rs @@ -23,9 +23,12 @@ impl Ollama { let url = format!("{}api/create", self.url_str()); let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?; - let res = self - .reqwest_client - .post(url) + let builder = self.reqwest_client.post(url); + + #[cfg(feature = "headers")] + let builder = builder.headers(self.request_headers.clone()); + + let res = builder .body(serialized) .send() .await @@ -65,9 +68,12 @@ impl Ollama { ) -> crate::error::Result { let url = format!("{}api/create", self.url_str()); let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?; - let res = self - .reqwest_client - .post(url) + let builder = self.reqwest_client.post(url); + + #[cfg(feature = "headers")] + let builder = builder.headers(self.request_headers.clone()); + + let res = builder .body(serialized) .send() .await diff --git a/src/models/delete.rs b/src/models/delete.rs index ad25d76..d76786f 100644 --- a/src/models/delete.rs +++ b/src/models/delete.rs @@ -9,9 +9,12 @@ impl Ollama { let url = format!("{}api/delete", self.url_str()); let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?; - let res = self - .reqwest_client - .delete(url) + let builder = self.reqwest_client.delete(url); + + #[cfg(feature = "headers")] + let builder = builder.headers(self.request_headers.clone()); + + let res = builder .body(serialized) .send() .await diff --git a/src/models/list_local.rs b/src/models/list_local.rs index f0c86db..6dd86cb 100644 --- a/src/models/list_local.rs +++ b/src/models/list_local.rs @@ -7,12 +7,12 @@ use super::LocalModel; impl Ollama { pub async fn list_local_models(&self) -> crate::error::Result> { let url = format!("{}api/tags", self.url_str()); - let res = self - .reqwest_client - .get(url) - .send() - .await - .map_err(|e| e.to_string())?; + let builder = self.reqwest_client.get(url); + + #[cfg(feature = "headers")] + let builder = builder.headers(self.request_headers.clone()); + + let res = builder.send().await.map_err(|e| e.to_string())?; if !res.status().is_success() { return Err(res.text().await.unwrap_or_else(|e| e.to_string()).into()); diff --git a/src/models/pull.rs b/src/models/pull.rs index b764a19..201ee3a 100644 --- a/src/models/pull.rs +++ b/src/models/pull.rs @@ -30,9 +30,12 @@ impl Ollama { let url = format!("{}api/pull", self.url_str()); let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?; - let res = self - .reqwest_client - .post(url) + let builder = self.reqwest_client.post(url); + + #[cfg(feature = "headers")] + let builder = builder.headers(self.request_headers.clone()); + + let res = builder .body(serialized) .send() .await @@ -81,9 +84,12 @@ impl Ollama { let url = format!("{}api/pull", self.url_str()); let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?; - let res = self - .reqwest_client - .post(url) + let builder = self.reqwest_client.post(url); + + #[cfg(feature = "headers")] + let builder = builder.headers(self.request_headers.clone()); + + let res = builder .body(serialized) .send() .await diff --git a/src/models/push.rs b/src/models/push.rs index 9ec592e..e3a3264 100644 --- a/src/models/push.rs +++ b/src/models/push.rs @@ -30,9 +30,12 @@ impl Ollama { let url = format!("{}api/push", self.url_str()); let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?; - let res = self - .reqwest_client - .post(url) + let builder = self.reqwest_client.post(url); + + #[cfg(feature = "headers")] + let builder = builder.headers(self.request_headers.clone()); + + let res = builder .body(serialized) .send() .await @@ -82,9 +85,12 @@ impl Ollama { let url = format!("{}api/push", self.url_str()); let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?; - let res = self - .reqwest_client - .post(url) + let builder = self.reqwest_client.post(url); + + #[cfg(feature = "headers")] + let builder = builder.headers(self.request_headers.clone()); + + let res = builder .body(serialized) .send() .await diff --git a/src/models/show_info.rs b/src/models/show_info.rs index e32083f..34db163 100644 --- a/src/models/show_info.rs +++ b/src/models/show_info.rs @@ -10,9 +10,12 @@ impl Ollama { let url = format!("{}api/show", self.url_str()); let serialized = serde_json::to_string(&ModelInfoRequest { model_name }).map_err(|e| e.to_string())?; - let res = self - .reqwest_client - .post(url) + let builder = self.reqwest_client.post(url); + + #[cfg(feature = "headers")] + let builder = builder.headers(self.request_headers.clone()); + + let res = builder .body(serialized) .send() .await