Skip to content

Commit

Permalink
Added new keep_alive parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
pepperoni21 committed Feb 7, 2024
1 parent bda42c9 commit 715b2c2
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 11 deletions.
2 changes: 1 addition & 1 deletion src/generation.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
pub mod chat;
pub mod completion;
pub mod embeddings;
pub mod format;
pub mod images;
pub mod options;
pub mod parameters;
2 changes: 1 addition & 1 deletion src/generation/chat/request.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use serde::Serialize;

use crate::generation::{format::FormatType, options::GenerationOptions};
use crate::generation::{options::GenerationOptions, parameters::FormatType};

use super::ChatMessage;

Expand Down
14 changes: 13 additions & 1 deletion src/generation/completion/request.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use serde::Serialize;

use crate::generation::{format::FormatType, images::Image, options::GenerationOptions};
use crate::generation::{
images::Image,
options::GenerationOptions,
parameters::{FormatType, KeepAlive},
};

use super::GenerationContext;

Expand All @@ -16,6 +20,7 @@ pub struct GenerationRequest {
pub template: Option<String>,
pub context: Option<GenerationContext>,
pub format: Option<FormatType>,
pub keep_alive: Option<KeepAlive>,
pub(crate) stream: bool,
}

Expand All @@ -30,6 +35,7 @@ impl GenerationRequest {
template: None,
context: None,
format: None,
keep_alive: None,
// Stream value will be overwritten by Ollama::generate_stream() and Ollama::generate() methods
stream: false,
}
Expand Down Expand Up @@ -76,4 +82,10 @@ impl GenerationRequest {
self.format = Some(format);
self
}

/// Used to control how long a model stays loaded in memory, by default models are unloaded after 5 minutes of inactivity
pub fn keep_alive(mut self, keep_alive: KeepAlive) -> Self {
self.keep_alive = Some(keep_alive);
self
}
}
8 changes: 0 additions & 8 deletions src/generation/format.rs

This file was deleted.

51 changes: 51 additions & 0 deletions src/generation/parameters.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
use serde::{Deserialize, Serialize};

/// The format to return a response in. Currently the only accepted value is `json`.
///
/// Variant names are (de)serialized in lowercase because of `rename_all = "lowercase"`.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "lowercase")]
pub enum FormatType {
/// Serialized as the string `json`.
Json,
}

/// Controls how long a model stays loaded in memory.
///
/// By default, models are unloaded after 5 minutes of inactivity.
/// Serialization is hand-written (see the `Serialize` impl for this type) because
/// the variants map to different JSON types: integers for the two unit variants
/// and a string for `Until`.
#[derive(Debug, Clone)]
pub enum KeepAlive {
/// Serialized as the integer `-1`.
Indefinitely,
/// Serialized as the integer `0`.
UnloadOnCompletion,
/// Serialized as a string made of `time` followed by the unit's symbol (e.g. `5m`).
Until { time: u64, unit: TimeUnit },
}

impl Serialize for KeepAlive {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match self {
KeepAlive::Indefinitely => serializer.serialize_i8(-1),
KeepAlive::UnloadOnCompletion => serializer.serialize_i8(0),
KeepAlive::Until { time, unit } => {
let mut s = String::new();
s.push_str(&time.to_string());
s.push_str(unit.to_symbol());
serializer.serialize_str(&s)
}
}
}
}

/// Unit of time used by `KeepAlive::Until` when building a duration string.
#[derive(Debug, Clone)]
pub enum TimeUnit {
    /// Seconds, rendered with the suffix `s`.
    Seconds,
    /// Minutes, rendered with the suffix `m`.
    Minutes,
    /// Hours, rendered with the suffix `h`.
    Hours,
}

impl TimeUnit {
    /// Returns the duration suffix for this unit.
    ///
    /// The suffixes follow Go's duration syntax (`time.ParseDuration`), in which
    /// the valid units include `s`, `m` and `h`.
    pub fn to_symbol(&self) -> &'static str {
        match self {
            TimeUnit::Seconds => "s",
            TimeUnit::Minutes => "m",
            // `hr` is not a valid Go duration unit; a value such as `5hr` is
            // rejected as an unknown unit, so hours must be spelled `h`.
            TimeUnit::Hours => "h",
        }
    }
}

0 comments on commit 715b2c2

Please sign in to comment.