From 715b2c25319411f0ef0653395b4500fd8b48ba14 Mon Sep 17 00:00:00 2001 From: pepperoni21 Date: Wed, 7 Feb 2024 11:36:42 +0100 Subject: [PATCH] Added new keep_alive parameter --- src/generation.rs | 2 +- src/generation/chat/request.rs | 2 +- src/generation/completion/request.rs | 14 +++++++- src/generation/format.rs | 8 ----- src/generation/parameters.rs | 51 ++++++++++++++++++++++++++++ 5 files changed, 66 insertions(+), 11 deletions(-) delete mode 100644 src/generation/format.rs create mode 100644 src/generation/parameters.rs diff --git a/src/generation.rs b/src/generation.rs index 1169523..e739bcd 100644 --- a/src/generation.rs +++ b/src/generation.rs @@ -1,6 +1,6 @@ pub mod chat; pub mod completion; pub mod embeddings; -pub mod format; pub mod images; pub mod options; +pub mod parameters; diff --git a/src/generation/chat/request.rs b/src/generation/chat/request.rs index bbc9d8b..184afd0 100644 --- a/src/generation/chat/request.rs +++ b/src/generation/chat/request.rs @@ -1,6 +1,6 @@ use serde::Serialize; -use crate::generation::{format::FormatType, options::GenerationOptions}; +use crate::generation::{options::GenerationOptions, parameters::FormatType}; use super::ChatMessage; diff --git a/src/generation/completion/request.rs b/src/generation/completion/request.rs index 10ccbe4..d306d94 100644 --- a/src/generation/completion/request.rs +++ b/src/generation/completion/request.rs @@ -1,6 +1,10 @@ use serde::Serialize; -use crate::generation::{format::FormatType, images::Image, options::GenerationOptions}; +use crate::generation::{ + images::Image, + options::GenerationOptions, + parameters::{FormatType, KeepAlive}, +}; use super::GenerationContext; @@ -16,6 +20,7 @@ pub struct GenerationRequest { pub template: Option, pub context: Option, pub format: Option, + pub keep_alive: Option, pub(crate) stream: bool, } @@ -30,6 +35,7 @@ impl GenerationRequest { template: None, context: None, format: None, + keep_alive: None, // Stream value will be overwritten by Ollama::generate_stream() and Ollama::generate() methods stream: false, } @@ -76,4 +82,10 @@ impl GenerationRequest { self.format = Some(format); self } + + /// Used to control how long a model stays loaded in memory, by default models are unloaded after 5 minutes of inactivity + pub fn keep_alive(mut self, keep_alive: KeepAlive) -> Self { + self.keep_alive = Some(keep_alive); + self + } } diff --git a/src/generation/format.rs b/src/generation/format.rs deleted file mode 100644 index 65bcb09..0000000 --- a/src/generation/format.rs +++ /dev/null @@ -1,8 +0,0 @@ -use serde::{Deserialize, Serialize}; - -/// The format to return a response in. Currently the only accepted value is `json` -#[derive(Debug, Serialize, Deserialize, Clone)] -#[serde(rename_all = "lowercase")] -pub enum FormatType { - Json, -} diff --git a/src/generation/parameters.rs b/src/generation/parameters.rs new file mode 100644 index 0000000..53f53cf --- /dev/null +++ b/src/generation/parameters.rs @@ -0,0 +1,51 @@ +use serde::{Deserialize, Serialize}; + +/// The format to return a response in. Currently the only accepted value is `json` +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "lowercase")] +pub enum FormatType { + Json, +} + +/// Used to control how long a model stays loaded in memory, by default models are unloaded after 5 minutes of inactivity +#[derive(Debug, Clone)] +pub enum KeepAlive { + Indefinitely, + UnloadOnCompletion, + Until { time: u64, unit: TimeUnit }, +} + +impl Serialize for KeepAlive { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + match self { + KeepAlive::Indefinitely => serializer.serialize_i8(-1), + KeepAlive::UnloadOnCompletion => serializer.serialize_i8(0), + KeepAlive::Until { time, unit } => { + let mut s = String::new(); + s.push_str(&time.to_string()); + s.push_str(unit.to_symbol()); + serializer.serialize_str(&s) + } + } + } +} + +#[derive(Debug, Clone)] +pub enum TimeUnit { + Seconds, + Minutes, + Hours, +} + +impl TimeUnit { + pub fn to_symbol(&self) -> &'static str { + match self { + TimeUnit::Seconds => "s", + TimeUnit::Minutes => "m", + TimeUnit::Hours => "hr", + } + } +}