Add half-streaming agents #274

Open · wants to merge 2 commits into base: main
213 changes: 203 additions & 10 deletions src/agent/executor.rs
@@ -1,11 +1,14 @@
use std::pin::Pin;
use std::{collections::HashMap, sync::Arc};

use async_stream::stream;
use async_trait::async_trait;
use serde_json::json;
use futures::Stream;
use serde_json::{json, Value};
use tokio::sync::Mutex;

use super::{agent::Agent, AgentError};
use crate::schemas::{LogTools, Message};
use crate::schemas::{LogTools, Message, StreamData};
use crate::{
chain::{chain_trait::Chain, ChainError},
language_models::GenerateResult,
@@ -137,16 +140,20 @@ where

let mut tools_ai_message_seen: HashMap<String, ()> = HashMap::default();
for (action, observation) in steps {
let LogTools { tool_id, tools } = serde_json::from_str(&action.log)?;
let tools_value: serde_json::Value = serde_json::from_str(&tools)?;
if tools_ai_message_seen.insert(tools, ()).is_none() {
memory.add_message(
Message::new_ai_message("").with_tool_calls(tools_value),
);
if let Ok(LogTools { tool_id, tools }) = serde_json::from_str(action.log.trim_matches('`').trim_start_matches("json")) {
if let Ok(tools_value) = serde_json::from_str(&tools){
if tools_ai_message_seen.insert(tools, ()).is_none() {
memory.add_message(
Message::new_ai_message("").with_tool_calls(tools_value),
);
}
} // The else isn't really coverable.
memory.add_message(Message::new_tool_message(observation, tool_id));
} else {
log::debug!("LogTools deserialization failed, expecting non-OpenAI tool call and falling back to System message");
memory.add_message(Message::new_system_message(observation));
}
memory.add_message(Message::new_tool_message(observation, tool_id));
}

memory.add_ai_message(&finish.output);
}
return Ok(GenerateResult {
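
For context, a minimal sketch (not part of the diff, with a made-up tool_id value, assuming only the serde_json crate) of the fenced input that the trim_matches('`')/trim_start_matches("json") cleanup above is meant to handle:

    fn main() {
        // An action log as an LLM might emit it: JSON wrapped in a Markdown fence.
        let log = "```json\n{\"tool_id\":\"call_1\",\"tools\":\"[]\"}\n```";
        // The same cleanup the executor applies before deserializing LogTools.
        let cleaned = log.trim_matches('`').trim_start_matches("json");
        let value: serde_json::Value = serde_json::from_str(cleaned).unwrap();
        assert_eq!(value["tool_id"], "call_1");
    }
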
@@ -171,4 +178,190 @@ where
let result = self.call(input_variables).await?;
Ok(result.generation)
}

async fn stream<'life>(
&'life self,
input_variables: PromptArgs,
) -> Result<
Pin<Box<dyn Stream<Item = Result<StreamData, ChainError>> + Send + 'life>>,
ChainError,
> {
let mut input_variables = input_variables.clone();
let name_to_tools = self.get_name_to_tools();
let mut steps: Vec<(AgentAction, String)> = Vec::new();
log::debug!("steps: {:?}", steps);
if let Some(memory) = &self.memory {
let memory = memory.lock().await;
input_variables.insert("chat_history".to_string(), json!(memory.messages()));
} else {
input_variables.insert(
"chat_history".to_string(),
json!(SimpleMemory::new().messages()),
);
}

// let my_agent = self.agent.clone();

let main_stream = stream! {

// pin_mut!(steps);
// let input_variables = pin!(input_variables);
// let name_to_tools = pin!(name_to_tools);

loop {
let agent_event = self
.agent
.plan(&steps, input_variables.clone())
.await
.map_err(|e| ChainError::AgentError(format!("Error in agent planning: {}", e)))?;
match agent_event {
AgentEvent::Action(actions) => {
for action in actions {
log::debug!("Action: {:?}", action.tool_input);
let tool = name_to_tools
.get(&action.tool)
.ok_or_else(|| {
AgentError::ToolError(format!("Tool {} not found", action.tool))
})
.map_err(|e| ChainError::AgentError(e.to_string()))?;

let observation_result = tool.call(&action.tool_input).await;

match observation_result.map_err(|e| Box::new(e.to_string())) {
Ok(result) => {
let observation = result;
steps.push((action, observation));
}
Err(err_str) => {
log::info!(
"The tool return the following error: {}",
err_str
);
if self.break_if_error {
let intermed_err = AgentError::ToolError(*err_str).to_string();
yield Err(ChainError::AgentError(
intermed_err,
));
return;
} else {
                                        let observation = format!("The tool returned the following error: {}", err_str); //TODO: add a clause to yield here
steps.push((action, observation));
}
}
}

}
}
AgentEvent::Finish(finish) => {
if let Some(memory) = &self.memory { //FIXME: This would be a problem if the lifetime of memory is not 'self_life
let mut memory = memory.lock().await;

memory.add_user_message(match &input_variables["input"] {
// This avoids adding extra quotes to the user input in the history.
serde_json::Value::String(s) => s,
                            x => x, // this is the JSON-encoded value.
});

let mut tools_ai_message_seen: HashMap<String, ()> = HashMap::default();
for (action, observation) in steps {
if let Ok(LogTools { tool_id, tools }) = serde_json::from_str(action.log.trim_matches('`').trim_start_matches("json")) {
if let Ok(tools_value) = serde_json::from_str(&tools){
if tools_ai_message_seen.insert(tools, ()).is_none() {
memory.add_message(
Message::new_ai_message("").with_tool_calls(tools_value),
);
}
} // The else isn't really coverable.

memory.add_message(Message::new_tool_message(observation, tool_id));
} else {
log::debug!("LogTools deserialization failed, expecting non-OpenAI tool call and falling back to System message");
memory.add_message(Message::new_system_message(observation));
}
}

memory.add_ai_message(&finish.output);
}


// yield Ok(GenerateResult {
// generation: finish.output,
// ..Default::default()
// });
yield Ok(StreamData {
value: json!({"generation": finish.output.clone()}), //TODO: this might be a problem
content: finish.output.clone(),
tokens: None,
});
return;
}
}

if let Some(max_iterations) = self.max_iterations {
if steps.len() >= max_iterations as usize {
// yield Ok(GenerateResult {
// generation: "Max iterations reached".to_string(),
// ..Default::default()
// });
yield Ok(StreamData {
value: Value::String("Max iterations reached".to_string()), //TODO: this might be a problem
content: "Max iterations reached".to_string(),
tokens: None,
});
return;
}
}
}
};

return Ok(Box::pin(main_stream));
}
}

#[cfg(feature = "ollama")]
#[cfg(test)]
mod test {
use std::sync::Arc;

use futures_util::StreamExt;

use crate::agent::{AgentExecutor, ConversationalAgentBuilder};
use crate::chain::options::ChainCallOptions;
use crate::chain::Chain;
use crate::prompt_args;

use crate::tools::CommandExecutor;

#[cfg(feature = "ollama")]
use crate::{llm::client::Ollama, memory::SimpleMemory};
#[cfg(feature = "ollama")]
#[tokio::test]
async fn streaming_agent() {
let llm = Ollama::default().with_model("llama3.2");
let memory = SimpleMemory::new();
let command_executor = CommandExecutor::default();
let agent = ConversationalAgentBuilder::new()
.tools(&[Arc::new(command_executor)])
.options(ChainCallOptions::default().with_max_tokens(1000))
.build(llm)
.expect("Failed to build agent");

let executor = AgentExecutor::from_agent(agent).with_memory(memory.into());

let input_variables = prompt_args! {
"input" => "What is the name of the current directory? Do not add any backticks when using the command executor.",
};

let mut result_stream = executor
.stream(input_variables.clone())
.await
.expect("Failed to execute agent");

println!("Created stream");

while let Some(content) = result_stream.next().await {
println!("\n\ncontent: {:?}\n", content);
}
println!("Finished streaming agent");
}
}
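
Note that, as written, the stream yields at most one successful StreamData per run: the final output at AgentEvent::Finish, or the "Max iterations reached" message. It does not forward token-level deltas from the underlying LLM, which appears to be the "half" in half-streaming.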
6 changes: 3 additions & 3 deletions src/chain/chain_trait.rs
@@ -171,10 +171,10 @@ pub trait Chain: Sync + Send {
/// # };
/// ```
///
async fn stream(
&self,
async fn stream<'self_life>(
&'self_life self,
_input_variables: PromptArgs,
) -> Result<Pin<Box<dyn Stream<Item = Result<StreamData, ChainError>> + Send>>, ChainError>
) -> Result<Pin<Box<dyn Stream<Item = Result<StreamData, ChainError>> + Send + 'self_life>>, ChainError>
{
log::warn!("stream not implemented for this chain");
unimplemented!()
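
The signature change above ties the boxed stream's lifetime to the borrow of &self, which is what lets the agent executor's stream! block capture self. A minimal sketch of the same pattern, assuming the futures and async_stream crates are available (Counter is a made-up type for illustration):

    use std::pin::Pin;

    use async_stream::stream;
    use futures::Stream;

    struct Counter {
        upto: u32,
    }

    impl Counter {
        // The stream borrows `self`, so the boxed trait object must carry
        // the `'a` lifetime; this is the role `'self_life` plays above.
        fn values<'a>(&'a self) -> Pin<Box<dyn Stream<Item = u32> + Send + 'a>> {
            Box::pin(stream! {
                for i in 0..self.upto {
                    yield i;
                }
            })
        }
    }
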
6 changes: 3 additions & 3 deletions src/chain/conversational/mod.rs
@@ -79,10 +79,10 @@ impl Chain for ConversationalChain {
Ok(result)
}

async fn stream(
&self,
async fn stream<'self_life>(
&'self_life self,
input_variables: PromptArgs,
) -> Result<Pin<Box<dyn Stream<Item = Result<StreamData, ChainError>> + Send>>, ChainError>
) -> Result<Pin<Box<dyn Stream<Item = Result<StreamData, ChainError>> + Send + 'self_life>>, ChainError>
{
let input_variable = &input_variables
.get(&self.input_key)
(file header not captured; the hunk below modifies ConversationalRetrieverChain)
@@ -148,10 +148,10 @@ impl Chain for ConversationalRetrieverChain {
Ok(result)
}

async fn stream(
&self,
async fn stream<'self_life>(
&'self_life self,
input_variables: PromptArgs,
) -> Result<Pin<Box<dyn Stream<Item = Result<StreamData, ChainError>> + Send>>, ChainError>
) -> Result<Pin<Box<dyn Stream<Item = Result<StreamData, ChainError>> + Send + 'self_life>>, ChainError>
{
let input_variable = &input_variables
.get(&self.input_key)
6 changes: 3 additions & 3 deletions src/chain/llm_chain.rs
@@ -120,10 +120,10 @@ impl Chain for LLMChain {
Ok(output)
}

async fn stream(
&self,
async fn stream<'self_life>(
&'self_life self,
input_variables: PromptArgs,
) -> Result<Pin<Box<dyn Stream<Item = Result<StreamData, ChainError>> + Send>>, ChainError>
) -> Result<Pin<Box<dyn Stream<Item = Result<StreamData, ChainError>> + Send + 'self_life>>, ChainError>
{
let prompt = self.prompt.format_prompt(input_variables.clone())?;
log::debug!("Prompt: {:?}", prompt);
6 changes: 3 additions & 3 deletions src/chain/question_answering.rs
@@ -81,10 +81,10 @@ impl Chain for CondenseQuestionGeneratorChain {
self.chain.call(input_variables).await
}

async fn stream(
&self,
async fn stream<'self_life>(
&'self_life self,
input_variables: PromptArgs,
) -> Result<Pin<Box<dyn Stream<Item = Result<StreamData, ChainError>> + Send>>, ChainError>
) -> Result<Pin<Box<dyn Stream<Item = Result<StreamData, ChainError>> + Send + 'self_life>>, ChainError>
{
self.chain.stream(input_variables).await
}
6 changes: 3 additions & 3 deletions src/chain/sql_datbase/chain.rs
@@ -188,10 +188,10 @@ impl Chain for SQLDatabaseChain {
Ok(result.generation)
}

async fn stream(
&self,
async fn stream<'self_life>(
&'self_life self,
input_variables: PromptArgs,
) -> Result<Pin<Box<dyn Stream<Item = Result<StreamData, ChainError>> + Send>>, ChainError>
) -> Result<Pin<Box<dyn Stream<Item = Result<StreamData, ChainError>> + Send + 'self_life>>, ChainError>
{
let (llm_inputs, _) = self.call_builder_chains(&input_variables).await?;

6 changes: 3 additions & 3 deletions src/chain/stuff_documents/chain.rs
@@ -137,10 +137,10 @@ impl Chain for StuffDocument {
self.llm_chain.call(input_values).await
}

async fn stream(
&self,
async fn stream<'self_life>(
&'self_life self,
input_variables: PromptArgs,
) -> Result<Pin<Box<dyn Stream<Item = Result<StreamData, ChainError>> + Send>>, ChainError>
) -> Result<Pin<Box<dyn Stream<Item = Result<StreamData, ChainError>> + Send + 'self_life>>, ChainError>
{
let docs = input_variables
.get(&self.input_key)
4 changes: 2 additions & 2 deletions src/llm/ollama/openai.rs
@@ -77,15 +77,15 @@ mod tests {
use tokio_stream::StreamExt;

#[tokio::test]
#[ignore]
// #[ignore]
async fn test_ollama_openai() {
let ollama = OpenAI::new(OllamaConfig::default()).with_model("llama2");
let response = ollama.invoke("hola").await.unwrap();
println!("{}", response);
}

#[tokio::test]
#[ignore]
// #[ignore]
async fn test_ollama_openai_stream() {
let ollama = OpenAI::new(OllamaConfig::default()).with_model("phi3");

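
Note that commenting out #[ignore] makes these Ollama-backed tests run by default, which presumably requires a reachable local Ollama instance wherever the test suite runs.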