Skip to content

Commit

Permalink
Merge branch 'pepperoni21:master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
andthattoo authored Sep 26, 2024
2 parents 00c67cf + a40ea0c commit 1182725
Show file tree
Hide file tree
Showing 14 changed files with 122 additions and 48 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,21 @@ scraper = { version = "0.19.0", optional = true }
text-splitter = { version = "0.13.1", optional = true }
regex = { version = "1.9.3", optional = true }
async-stream = "0.3.5"
http = {version = "1.1.0", optional = true }

[features]
default = ["reqwest/default-tls"]
stream = ["tokio-stream", "reqwest/stream", "tokio"]
rustls = ["reqwest/rustls-tls"]
chat-history = []
headers = ["http"]
function-calling = ["scraper", "text-splitter", "regex", "chat-history"]

[dev-dependencies]
tokio = { version = "1", features = ["full"] }
ollama-rs = { path = ".", features = [
"stream",
"headers",
"chat-history",
"function-calling",
] }
Expand Down
18 changes: 12 additions & 6 deletions src/generation/chat/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,12 @@ impl Ollama {
let serialized = serde_json::to_string(&request)
.map_err(|e| e.to_string())
.unwrap();
let res = self
.reqwest_client
.post(url)
let builder = self.reqwest_client.post(url);

#[cfg(feature = "headers")]
let builder = builder.headers(self.request_headers.clone());

let res = builder
.body(serialized)
.send()
.await
Expand Down Expand Up @@ -75,9 +78,12 @@ impl Ollama {

let url = format!("{}api/chat", self.url_str());
let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?;
let res = self
.reqwest_client
.post(url)
let builder = self.reqwest_client.post(url);

#[cfg(feature = "headers")]
let builder = builder.headers(self.request_headers.clone());

let res = builder
.body(serialized)
.send()
.await
Expand Down
18 changes: 12 additions & 6 deletions src/generation/completion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,12 @@ impl Ollama {

let url = format!("{}api/generate", self.url_str());
let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?;
let res = self
.reqwest_client
.post(url)
let builder = self.reqwest_client.post(url);

#[cfg(feature = "headers")]
let builder = builder.headers(self.request_headers.clone());

let res = builder
.body(serialized)
.send()
.await
Expand Down Expand Up @@ -70,9 +73,12 @@ impl Ollama {

let url = format!("{}api/generate", self.url_str());
let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?;
let res = self
.reqwest_client
.post(url)
let builder = self.reqwest_client.post(url);

#[cfg(feature = "headers")]
let builder = builder.headers(self.request_headers.clone());

let res = builder
.body(serialized)
.send()
.await
Expand Down
9 changes: 6 additions & 3 deletions src/generation/embeddings/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@ impl Ollama {
) -> crate::error::Result<GenerateEmbeddingsResponse> {
let url = format!("{}api/embed", self.url_str());
let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?;
let res = self
.reqwest_client
.post(url)
let builder = self.reqwest_client.post(url);

#[cfg(feature = "headers")]
let builder = builder.headers(self.request_headers.clone());

let res = builder
.body(serialized)
.send()
.await
Expand Down
22 changes: 22 additions & 0 deletions src/headers.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use crate::{IntoUrl, Ollama};

pub use http::header::*;

impl Ollama {
/// # Panics
///
/// Panics if the host is not a valid URL or if the URL cannot have a port.
pub fn new_with_request_headers(host: impl IntoUrl, port: u16, headers: HeaderMap) -> Self {
let mut ollama = Self::new(host, port);
ollama.set_headers(Some(headers));

ollama
}

pub fn set_headers(&mut self, headers: Option<HeaderMap>) {
match headers {
Some(h) => self.request_headers = h,
None => self.request_headers = HeaderMap::new(),
}
}
}
6 changes: 6 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use url::Url;

pub mod error;
pub mod generation;
#[cfg(feature = "headers")]
pub mod headers;
#[cfg(feature = "chat-history")]
pub mod history;
pub mod models;
Expand Down Expand Up @@ -69,6 +71,8 @@ impl IntoUrlSealed for String {
pub struct Ollama {
pub(crate) url: Url,
pub(crate) reqwest_client: reqwest::Client,
#[cfg(feature = "headers")]
pub(crate) request_headers: reqwest::header::HeaderMap,
#[cfg(feature = "chat-history")]
pub(crate) messages_history: Option<WrappedMessageHistory>,
}
Expand Down Expand Up @@ -145,6 +149,8 @@ impl Default for Ollama {
Self {
url: Url::parse("http://127.0.0.1:11434").unwrap(),
reqwest_client: reqwest::Client::new(),
#[cfg(feature = "headers")]
request_headers: reqwest::header::HeaderMap::new(),
#[cfg(feature = "chat-history")]
messages_history: None,
}
Expand Down
9 changes: 6 additions & 3 deletions src/models/copy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@ impl Ollama {

let url = format!("{}api/copy", self.url_str());
let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?;
let res = self
.reqwest_client
.post(url)
let builder = self.reqwest_client.post(url);

#[cfg(feature = "headers")]
let builder = builder.headers(self.request_headers.clone());

let res = builder
.body(serialized)
.send()
.await
Expand Down
18 changes: 12 additions & 6 deletions src/models/create.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,12 @@ impl Ollama {

let url = format!("{}api/create", self.url_str());
let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?;
let res = self
.reqwest_client
.post(url)
let builder = self.reqwest_client.post(url);

#[cfg(feature = "headers")]
let builder = builder.headers(self.request_headers.clone());

let res = builder
.body(serialized)
.send()
.await
Expand Down Expand Up @@ -65,9 +68,12 @@ impl Ollama {
) -> crate::error::Result<CreateModelStatus> {
let url = format!("{}api/create", self.url_str());
let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?;
let res = self
.reqwest_client
.post(url)
let builder = self.reqwest_client.post(url);

#[cfg(feature = "headers")]
let builder = builder.headers(self.request_headers.clone());

let res = builder
.body(serialized)
.send()
.await
Expand Down
9 changes: 6 additions & 3 deletions src/models/delete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@ impl Ollama {

let url = format!("{}api/delete", self.url_str());
let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?;
let res = self
.reqwest_client
.delete(url)
let builder = self.reqwest_client.delete(url);

#[cfg(feature = "headers")]
let builder = builder.headers(self.request_headers.clone());

let res = builder
.body(serialized)
.send()
.await
Expand Down
12 changes: 6 additions & 6 deletions src/models/list_local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ use super::LocalModel;
impl Ollama {
pub async fn list_local_models(&self) -> crate::error::Result<Vec<LocalModel>> {
let url = format!("{}api/tags", self.url_str());
let res = self
.reqwest_client
.get(url)
.send()
.await
.map_err(|e| e.to_string())?;
let builder = self.reqwest_client.get(url);

#[cfg(feature = "headers")]
let builder = builder.headers(self.request_headers.clone());

let res = builder.send().await.map_err(|e| e.to_string())?;

if !res.status().is_success() {
return Err(res.text().await.unwrap_or_else(|e| e.to_string()).into());
Expand Down
18 changes: 12 additions & 6 deletions src/models/pull.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,12 @@ impl Ollama {

let url = format!("{}api/pull", self.url_str());
let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?;
let res = self
.reqwest_client
.post(url)
let builder = self.reqwest_client.post(url);

#[cfg(feature = "headers")]
let builder = builder.headers(self.request_headers.clone());

let res = builder
.body(serialized)
.send()
.await
Expand Down Expand Up @@ -81,9 +84,12 @@ impl Ollama {

let url = format!("{}api/pull", self.url_str());
let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?;
let res = self
.reqwest_client
.post(url)
let builder = self.reqwest_client.post(url);

#[cfg(feature = "headers")]
let builder = builder.headers(self.request_headers.clone());

let res = builder
.body(serialized)
.send()
.await
Expand Down
18 changes: 12 additions & 6 deletions src/models/push.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,12 @@ impl Ollama {

let url = format!("{}api/push", self.url_str());
let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?;
let res = self
.reqwest_client
.post(url)
let builder = self.reqwest_client.post(url);

#[cfg(feature = "headers")]
let builder = builder.headers(self.request_headers.clone());

let res = builder
.body(serialized)
.send()
.await
Expand Down Expand Up @@ -82,9 +85,12 @@ impl Ollama {

let url = format!("{}api/push", self.url_str());
let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?;
let res = self
.reqwest_client
.post(url)
let builder = self.reqwest_client.post(url);

#[cfg(feature = "headers")]
let builder = builder.headers(self.request_headers.clone());

let res = builder
.body(serialized)
.send()
.await
Expand Down
9 changes: 6 additions & 3 deletions src/models/show_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@ impl Ollama {
let url = format!("{}api/show", self.url_str());
let serialized =
serde_json::to_string(&ModelInfoRequest { model_name }).map_err(|e| e.to_string())?;
let res = self
.reqwest_client
.post(url)
let builder = self.reqwest_client.post(url);

#[cfg(feature = "headers")]
let builder = builder.headers(self.request_headers.clone());

let res = builder
.body(serialized)
.send()
.await
Expand Down

0 comments on commit 1182725

Please sign in to comment.