Commit 7190cf5 (1 parent: c261178)
Showing 1 changed file with 86 additions and 108 deletions.
@@ -1,108 +1,86 @@
-// use crate::prompt::context::Context;
-
-// use super::simple::LLMPipeline;
-// use super::{Pipeline, PipelineResult};
-// use anyhow::Result;
-// use serde_json::Value as JsonValue;
-// use std::collections::HashMap;
-// use std::sync::Arc;
-// use tokio::sync::RwLock;
-
-// pub struct SequentialPipeline {
-//     /// The name of the LLMPipeline.
-//     name: String,
-
-//     /// Vector of LLM pipelines used by the SequentialPipeline.
-//     pipelines: Vec<Arc<RwLock<dyn Pipeline>>>,
-
-//     /// The context for for the templates used by the SequentialPipeline.
-//     context: HashMap<String, JsonValue>,
-// }
-
-// impl Default for SequentialPipeline {
-//     fn default() -> Self {
-//         Self {
-//             name: uuid::Uuid::new_v4().to_string(),
-//             pipelines: Vec::new(),
-//             context: HashMap::new(),
-//         }
-//     }
-// }
-
-// impl SequentialPipeline {
-//     /// Initialize a new sequential pipeline.
-//     pub fn new() -> SequentialPipeline {
-//         SequentialPipeline::default()
-//     }
-
-//     /// Add a simple LLM Pipeline to the sequential pipeline.
-//     pub fn link(mut self, pipeline: LLMPipeline) -> SequentialPipeline {
-//         self.pipelines.push(Arc::new(RwLock::new(pipeline)));
-//         self
-//     }
-// }
-
-// #[async_trait::async_trait]
-// impl Pipeline for SequentialPipeline {
-//     async fn execute(&self, target: &str) -> Result<PipelineResult> {
-//         let mut response = String::new();
-//         let mut result: PipelineResult = PipelineResult::new(self.name.to_string()); // initialize result to a default value
-//         for pipeline in &self.pipelines {
-//             if !response.is_empty() {
-//                 pipeline
-//                     .write()
-//                     .await
-//                     .template_engine()
-//                     .add_to_template(target, &format!("{{{{#user}}}}{}{{{{/user}}}}", response));
-//             }
-//             result = pipeline.read().await.execute(target).await?;
-//             response = result.content();
-//         }
-//         Ok(result)
-//     }
-
-//     fn context(&mut self) -> &mut HashMap<String, JsonValue> {
-//         &mut self.context
-//     }
-
-//     async fn load_context(&mut self, context: &Context) {
-//         for pipeline in &mut self.pipelines {
-//             pipeline.write().await.load_context(context).await;
-//         }
-//     }
-// }
-
-// #[cfg(test)]
-// mod test {
-
-//     use super::*;
-//     use crate::llm::openai::OpenAI;
-//     use serde::Serialize;
-
-//     #[derive(Serialize)]
-//     pub struct Data {
-//         play: String,
-//     }
-
-//     #[tokio::test]
-//     async fn test_generate() {
-//         let client = OpenAI::new();
-
-//         let first = "{{#chat}}{{#user}}Give me a summary of {{play}}'s plot.{{/user}}{{/chat}}";
-//         let second = "{{#chat}}{{#system}}You are a professional critic. When given a summary of a play, you must write a review of it. Here is a summary of {{play}}'s plot:{{/system}}{{/chat}}";
-
-//         let mut pipeline = SequentialPipeline::new()
-//             .link(LLMPipeline::new(&client).with_template("review", first).unwrap())
-//             .link(LLMPipeline::new(&client).with_template("review", second).unwrap());
-//         pipeline
-//             .load_context(
-//                 &Context::new(&Data {
-//                     play: "Hamlet".to_string(),
-//                 })
-//                 .unwrap(),
-//             )
-//             .await;
-//         let res = pipeline.execute("review").await;
-//         assert!(res.is_ok());
-//     }
-// }
+use super::{Pipeline, PipelineResult};
+use anyhow::Result;
+use std::sync::Arc;
+use tokio::sync::RwLock;
+
+pub struct SequentialPipeline<P> {
+    /// The name of the SequentialPipeline.
+    name: String,
+
+    /// Vector of pipelines executed in order by the SequentialPipeline.
+    pipelines: Vec<Arc<RwLock<P>>>,
+}
+
+impl<P> Default for SequentialPipeline<P> {
+    fn default() -> Self {
+        Self {
+            name: uuid::Uuid::new_v4().to_string(),
+            pipelines: Vec::new(),
+        }
+    }
+}
+
+impl<P: Pipeline> SequentialPipeline<P> {
+    /// Initialize a new sequential pipeline.
+    pub fn new() -> SequentialPipeline<P> {
+        SequentialPipeline::default()
+    }
+
+    /// Add a pipeline to the sequential pipeline.
+    pub fn link(mut self, pipeline: P) -> SequentialPipeline<P> {
+        self.pipelines.push(Arc::new(RwLock::new(pipeline)));
+        self
+    }
+}
+
+#[async_trait::async_trait]
+impl<P: Pipeline> Pipeline for SequentialPipeline<P> {
+    async fn execute(&self, target: &str) -> Result<PipelineResult> {
+        let mut response = String::new();
+        let mut result: PipelineResult = PipelineResult::new(self.name.to_string()); // initialize result to a default value
+        for pipeline in &self.pipelines {
+            if !response.is_empty() { // feed the previous response into this pipeline's template as a user message
+                pipeline
+                    .write()
+                    .await
+                    .template_engine()
+                    .add_to_template(target, &format!("{{{{#user}}}}{}{{{{/user}}}}", response));
+            }
+            result = pipeline.read().await.execute(target).await?;
+            response = result.content();
+        }
+        Ok(result)
+    }
+}
+
+#[cfg(test)]
+mod test {
+
+    use super::*;
+    use crate::{llm::openai::OpenAI, pipeline::simple::LLMPipeline, prompt::context::Context};
+    use serde::Serialize;
+
+    #[derive(Serialize)]
+    pub struct Data {
+        play: String,
+    }
+
+    #[tokio::test]
+    async fn test_generate() {
+        let client = OpenAI::new();
+
+        let first = "{{#chat}}{{#user}}Give me a summary of {{play}}'s plot.{{/user}}{{/chat}}";
+        let second = "{{#chat}}{{#system}}You are a professional critic. When given a summary of a play, you must write a review of it. Here is a summary of {{play}}'s plot:{{/system}}{{/chat}}";
+
+        let context = &Context::new(&Data {
+            play: "Hamlet".to_string(),
+        })
+        .unwrap();
+        let summary = LLMPipeline::new(&client).load_template("review", first).unwrap().load_context(context).unwrap();
+        let review = LLMPipeline::new(&client).load_template("review", second).unwrap().load_context(context).unwrap();
+
+        let pipeline = SequentialPipeline::new().link(summary).link(review);
+        let res = pipeline.execute("review").await;
+        assert!(res.is_ok());
+    }
+}
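
For orientation, a minimal usage sketch of the reworked SequentialPipeline outside the test module. It only uses calls that already appear in this diff (OpenAI::new, LLMPipeline::new, load_template, load_context, Context::new, link, execute, content); the module paths crate::pipeline::sequential and crate::pipeline::Pipeline, the function name summarize_and_review, and the in-crate placement are assumptions, not part of the commit.

// Hedged sketch, assumed to live inside the same crate so the paths below
// mirror the test module above.
use crate::llm::openai::OpenAI;
use crate::pipeline::sequential::SequentialPipeline; // assumed module path for this file
use crate::pipeline::simple::LLMPipeline;
use crate::pipeline::Pipeline; // assumed trait path; brings execute() into scope
use crate::prompt::context::Context;
use anyhow::Result;
use serde::Serialize;

#[derive(Serialize)]
struct Data {
    play: String,
}

pub async fn summarize_and_review(play: &str) -> Result<String> {
    let client = OpenAI::new();

    // Same two-stage prompt as the test: summarize first, then review the summary.
    let first = "{{#chat}}{{#user}}Give me a summary of {{play}}'s plot.{{/user}}{{/chat}}";
    let second = "{{#chat}}{{#system}}You are a professional critic. When given a summary of a play, you must write a review of it. Here is a summary of {{play}}'s plot:{{/system}}{{/chat}}";

    let context = Context::new(&Data { play: play.to_string() }).unwrap();
    let summary = LLMPipeline::new(&client)
        .load_template("review", first)
        .unwrap()
        .load_context(&context)
        .unwrap();
    let review = LLMPipeline::new(&client)
        .load_template("review", second)
        .unwrap()
        .load_context(&context)
        .unwrap();

    // execute() runs the linked stages in order, feeding each response into the next template.
    let result = SequentialPipeline::new().link(summary).link(review).execute("review").await?;
    Ok(result.content())
}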