diff --git a/src/generation/chat/mod.rs b/src/generation/chat/mod.rs index 6971587..8648c1a 100644 --- a/src/generation/chat/mod.rs +++ b/src/generation/chat/mod.rs @@ -207,15 +207,14 @@ impl Ollama { /// Without impact for existing history /// Used to prepare history for request fn get_chat_messages_by_id(&mut self, history_id: impl ToString) -> Vec { - let mut binding = { - let new_history = - std::sync::Arc::new(std::sync::RwLock::new(MessagesHistory::default())); - self.messages_history = Some(new_history); - self.messages_history.clone().unwrap() - }; let chat_history = match self.messages_history.as_mut() { Some(history) => history, - None => &mut binding, + None => { + let new_history = + std::sync::Arc::new(std::sync::RwLock::new(MessagesHistory::default())); + self.messages_history = Some(new_history); + &mut self.messages_history.clone().unwrap() + } }; // Clone the current chat messages to avoid borrowing issues // And not to add message to the history if the request fails diff --git a/tests/chat_history_management.rs b/tests/chat_history_management.rs index a1935bd..1a9901f 100644 --- a/tests/chat_history_management.rs +++ b/tests/chat_history_management.rs @@ -1,4 +1,7 @@ -use ollama_rs::Ollama; +use ollama_rs::{ + generation::chat::{request::ChatMessageRequest, ChatMessage, MessageRole}, + Ollama, +}; #[test] fn test_chat_history_saved_as_should() { @@ -111,3 +114,39 @@ fn test_chat_history_freed_if_limit_exceeded() { assert!(last.is_some()); assert_eq!(last.unwrap().content, "Hi again".to_string()); } + +#[tokio::test] +async fn test_chat_history_accumulated() { + let mut ollama = Ollama::new_default_with_history(30); + let chat_id = "default"; + + assert!(ollama + .send_chat_messages_with_history( + ChatMessageRequest::new( + "granite-code:3b".into(), + vec![ChatMessage::new( + MessageRole::User, + "Why is the sky blue?".into(), + )], + ), + chat_id, + ) + .await + .is_ok()); + + assert!(ollama + .send_chat_messages_with_history( + ChatMessageRequest::new( + "granite-code:3b".into(), + vec![ChatMessage::new( + MessageRole::User, + "But, why is the sky blue?".into() + )] + ), + chat_id + ) + .await + .is_ok()); + + assert_eq!(ollama.get_messages_history(chat_id).unwrap().len(), 4) +}