Commit 85fe437

Use the ollama-rs history

AlexisTM committed Nov 1, 2024
1 parent 218d8b4 commit 85fe437
Showing 7 changed files with 54 additions and 57 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "persona-ai"
-version = "1.0.0"
+version = "1.1.0"
 edition = "2021"

 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
16 changes: 9 additions & 7 deletions src/commands/chat.rs
@@ -21,9 +21,17 @@ pub async fn run(ctx: &Context, command: &CommandInteraction, persona: Arc<RwLoc
         ..
     }) = command.data.options().first()
     {
+        let history_id = command.channel_id.to_string();
         let _ = command.defer(&ctx.http).await;
         let prompt = { persona.read().await.get_prompt(&author_name, prompt_slice) };
-        let response = { persona.read().await.brain.request(&prompt).await };
+        let response = {
+            persona
+                .write()
+                .await
+                .brain
+                .request(&prompt, &history_id)
+                .await
+        };
         if let Some(response) = response {
             let content = format!(
                 "\nFrom **{author_name}:**```{prompt_slice}```**{}:**```{}```",
@@ -33,12 +41,6 @@ pub async fn run(ctx: &Context, command: &CommandInteraction, persona: Arc<RwLoc
             let builder = EditInteractionResponse::new().content(content);
             if let Err(why) = command.edit_response(&ctx.http, builder).await {
                 println!("Cannot respond to slash command: {why}");
-            } else {
-                persona.write().await.set_prompt_response(
-                    &author_name,
-                    prompt_slice,
-                    &response.content,
-                );
             }
         } else {
             println!("Error with ollama");
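A note on the locking change in this hunk: `OllamaAI::request` now takes `&mut self`, because the ollama-rs client stores the running conversation itself, so the handler has to take the `RwLock` write guard where a read guard used to be enough. A minimal sketch of that pattern, with a hypothetical `Brain` type standing in for `OllamaAI` and tokio's async `RwLock` (the lock type serenity re-exports):

```rust
use std::sync::Arc;
use tokio::sync::RwLock;

struct Brain;

impl Brain {
    // Mirrors the new OllamaAI::request signature: mutable because the
    // client appends to a stored history on every call.
    async fn request(&mut self, prompt: &str, history_id: &str) -> Option<String> {
        Some(format!("[{history_id}] echo: {prompt}"))
    }
}

#[tokio::main]
async fn main() {
    let brain = Arc::new(RwLock::new(Brain));
    let history_id = "1234567890".to_string(); // e.g. command.channel_id.to_string()

    // A read() guard would not compile here; a &mut self method needs write().
    let response = brain.write().await.request("hello", &history_id).await;
    println!("{response:?}");
}
```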
3 changes: 2 additions & 1 deletion src/commands/clear.rs
@@ -7,7 +7,8 @@ use serenity::prelude::RwLock;
 use crate::persona::Persona;

 pub async fn run(ctx: &Context, command: &CommandInteraction, persona: Arc<RwLock<Persona>>) {
-    persona.write().await.clear();
+    let history_id = command.channel_id.to_string();
+    persona.write().await.clear(&history_id);
     if let Err(why) = command
         .create_response(
             &ctx.http,
17 changes: 9 additions & 8 deletions src/main.rs
@@ -97,18 +97,19 @@ impl EventHandler for Handler {

             let prompt = { persona.read().await.get_prompt(&author_name, prompt_slice) };

-            let response = { persona.read().await.brain.request(&prompt).await };
+            let history_id = key.to_string();
+            let response = {
+                persona
+                    .write()
+                    .await
+                    .brain
+                    .request(&prompt, &history_id)
+                    .await
+            };
             if let Some(response) = response {
                 if let Err(why) = msg.channel_id.say(&ctx.http, &response.content).await {
                     println!("Error sending message: {:?}", why);
                 }
-                {
-                    persona.write().await.set_prompt_response(
-                        &author_name,
-                        prompt_slice,
-                        &response.content,
-                    );
-                }
             }
         }

16 changes: 12 additions & 4 deletions src/ollama.rs
@@ -8,29 +8,37 @@ use ollama_rs::{

 #[derive(Debug)]
 pub struct OllamaAI {
-    ollama: Ollama,
+    pub ollama: Ollama,
     options: GenerationOptions,
     pub model: String,
 }

 impl OllamaAI {
     pub fn new(model: &str, options: GenerationOptions) -> Self {
         Self {
-            ollama: Ollama::default(),
+            ollama: Ollama::new_default_with_history(100),
            options,
            model: model.to_owned(),
        }
    }

-    pub async fn request(&self, messages: &[ChatMessage]) -> Option<ChatMessage> {
+    pub async fn request(
+        &mut self,
+        messages: &[ChatMessage],
+        history_id: &str,
+    ) -> Option<ChatMessage> {
         let request = ChatMessageRequest::new(self.model.clone(), messages.to_owned());
         let response = self
             .ollama
-            .send_chat_messages(request.options(self.options.clone()))
+            .send_chat_messages_with_history(request.options(self.options.clone()), history_id)
             .await;
         if let Ok(response) = response {
             return response.message;
         }
         None
     }
+
+    pub fn clear(&mut self, history_id: &str) {
+        self.ollama.clear_messages_for_id(history_id);
+    }
 }
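For reference, this is roughly how the per-id history API used above behaves on its own. The module paths and the model name are assumptions (they depend on the ollama-rs version pinned in Cargo.lock); the calls themselves, `new_default_with_history`, `send_chat_messages_with_history`, `get_messages_history`, and `clear_messages_for_id`, are the ones the new `OllamaAI` relies on:

```rust
use ollama_rs::{
    generation::chat::{request::ChatMessageRequest, ChatMessage},
    Ollama,
};

#[tokio::main]
async fn main() {
    // Keep up to 100 messages per conversation, as in OllamaAI::new above.
    let mut ollama = Ollama::new_default_with_history(100);
    let history_id = "discord-channel-42"; // any stable key, e.g. a channel id

    let request = ChatMessageRequest::new(
        "llama3".to_string(), // placeholder model name
        vec![ChatMessage::user("Hello there!".to_string())],
    );

    // The client records the user message and the assistant reply under
    // `history_id` and replays that history on the next request with the same id.
    if let Ok(response) = ollama
        .send_chat_messages_with_history(request, history_id)
        .await
    {
        if let Some(reply) = response.message {
            println!("assistant: {}", reply.content);
        }
    }

    // Recollections can be read back or dropped per id.
    let _recollections = ollama.get_messages_history(history_id);
    ollama.clear_messages_for_id(history_id);
}
```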
55 changes: 20 additions & 35 deletions src/persona.rs
@@ -10,8 +10,6 @@ use std::clone::Clone;
 use std::collections::HashMap;
 use std::sync::Arc;

-const MAX_RECOLLECTIONS: usize = 20;
-
 // The nursery allows to find the persona we are interested in, in all those servers
 pub struct Nursery;
 impl TypeMapKey for Nursery {
@@ -52,8 +50,6 @@ impl Default for PersonaConfig {
 pub struct Persona {
     pub brain: OllamaAI,
     pub config: PersonaConfig,
-    // The actual live memory of the bot.
-    recollections: Vec<ChatMessage>,
 }

 impl Default for Persona {
@@ -65,22 +61,7 @@ impl Default for Persona {

 impl Persona {
     pub fn get_prompt(&self, author: &str, prompt: &str) -> Vec<ChatMessage> {
-        let mut prompts = self.recollections.clone();
-        prompts.push(ChatMessage::user(format!("{author}: {prompt}").to_owned()));
-        prompts
-    }
-
-    pub fn set_prompt_response(&mut self, author: &str, prompt: &str, response: &str) {
-        self.recollections.push(ChatMessage::user(
-            format!("{author}: {}", prompt).to_owned(),
-        ));
-        self.recollections
-            .push(ChatMessage::assistant(response.to_owned()));
-
-        if self.recollections.len() > (MAX_RECOLLECTIONS * 2) {
-            self.recollections.remove(0);
-            self.recollections.remove(0);
-        }
+        vec![ChatMessage::user(format!("{author}: {prompt}").to_owned())]
     }

     pub fn set_botname(&mut self, name: &str) {
@@ -92,21 +73,19 @@ impl Persona {
     }

     // Remove recollections
-    pub fn clear(&mut self) {
-        self.recollections.clear();
+    pub fn clear(&mut self, history_id: &str) {
+        self.brain.ollama.clear_messages_for_id(history_id);
     }

     pub fn from_config(config: PersonaConfig) -> Persona {
         Persona {
             brain: OllamaAI::new(&config.model, config.options.clone()),
-            recollections: Vec::new(),
             config,
         }
     }

     pub fn update_from_config(&mut self, config: PersonaConfig) {
         self.brain = OllamaAI::new(&config.model, config.options.clone());
-        self.recollections = Vec::new();
         self.config = config;
     }

@@ -121,24 +100,30 @@ impl Persona {
             None
         }
     }
-    pub fn get_config(&self) -> String {
-        let recollections: String = self
-            .recollections
-            .iter()
-            .map(|x| match x.role {
-                MessageRole::System => format!("System: {}\\nn", x.content),
-                MessageRole::Assistant => format!("bot: {}\n", x.content),
-                MessageRole::User => format!("{}\n", x.content),
-            })
-            .collect();
+    pub fn get_config(&mut self, history_id: &str) -> String {
+        let recollections = self.brain.ollama.get_messages_history(history_id);
+        let recollections_str = if let Some(recollections) = recollections {
+            let recollections: String = recollections
+                .iter()
+                .map(|x| match x.role {
+                    MessageRole::System => format!("System: {}\\nn", x.content),
+                    MessageRole::Assistant => format!("bot: {}\n", x.content),
+                    MessageRole::User => format!("{}\n", x.content),
+                })
+                .collect();
+            recollections
+        } else {
+            "".to_owned()
+        };
+
         format!(
             "{botname} config.
 ===========
 Recollections
 ---------------
 {recollections}\n",
             botname = self.config.botname,
-            recollections = recollections,
+            recollections = recollections_str,
         )
     }
 }
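With the recollections gone from `Persona`, the only place the history is materialized is the rendering in `get_config`. A standalone sketch of that formatting, using a hypothetical `render_history` helper, the same `MessageRole` match, and the `\\nn` escape simplified to a plain newline (module path assumed for the pinned ollama-rs version):

```rust
use ollama_rs::generation::chat::{ChatMessage, MessageRole};

// Renders a per-channel history one line per recollection, in the spirit of
// Persona::get_config above.
fn render_history(recollections: &[ChatMessage]) -> String {
    recollections
        .iter()
        .map(|x| match x.role {
            MessageRole::System => format!("System: {}\n", x.content),
            MessageRole::Assistant => format!("bot: {}\n", x.content),
            MessageRole::User => format!("{}\n", x.content),
        })
        .collect()
}

fn main() {
    let history = vec![
        ChatMessage::user("alice: hi there".to_string()),
        ChatMessage::assistant("hello alice".to_string()),
    ];
    print!("{}", render_history(&history));
}
```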
