From 55f4c20b12da9c843773d4af166b2f9978ae918e Mon Sep 17 00:00:00 2001 From: Guy Waldman <6546430+guywaldman@users.noreply.github.com> Date: Sun, 21 Jul 2024 23:13:24 +0300 Subject: [PATCH] Fix proc macro, simplify, refactor (#9) --- README.md | 76 ++-- core/Cargo.toml | 6 +- core/README.md | 143 +------- core/examples/embeddings.rs | 6 +- .../structured_data_generation_basic.rs | 76 ---- .../structured_data_generation_blog.rs | 89 +++-- .../structured_data_generation_capital.rs | 93 +++++ core/examples/text_generation.rs | 6 +- core/examples/text_generation_stream.rs | 6 +- core/examples/variants_derive.rs | 63 ++++ core/src/execution/builder.rs | 24 +- core/src/execution/executor.rs | 39 +- core/src/lib.rs | 127 +------ core/src/lm/lm_provider/models.rs | 25 ++ core/src/lm/lm_provider/ollama/builder.rs | 24 +- core/src/lm/lm_provider/openai/builder.rs | 39 +- core/src/lm/models.rs | 2 +- core/src/response.rs | 2 + response/Cargo.toml | 2 +- response/src/lib.rs | 14 +- response_derive/Cargo.toml | 4 +- response_derive/src/attribute_impl.rs | 6 +- response_derive/src/derive_impl.rs | 336 ++++++++++++------ response_derive/src/lib.rs | 24 +- 24 files changed, 624 insertions(+), 608 deletions(-) delete mode 100644 core/examples/structured_data_generation_basic.rs create mode 100644 core/examples/structured_data_generation_capital.rs create mode 100644 core/examples/variants_derive.rs create mode 100644 core/src/response.rs diff --git a/README.md b/README.md index d691422..706f177 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ orch = "*" # Substitute with the latest version ## Simple Text Generation -```rust +```rust no_run use orch::execution::*; use orch::lm::*; @@ -42,7 +42,7 @@ async fn main() { ## Streaming Text Generation -```rust +```rust no_run use orch::execution::*; use orch::lm::*; use tokio_stream::StreamExt; @@ -67,54 +67,74 @@ async fn main() { ## Structured Data Generation -```rust +```rust no_run use orch::execution::*; use orch::lm::*; -use orch_response_derive::*; +use orch::response::*; -#[derive(OrchResponseOptions)] -pub enum CapitalCityExecutorResponseOptions { - #[response( - scenario = "You know the capital city of the country", - description = "Capital city of the country" - )] +#[derive(Variants, serde::Deserialize)] +pub enum ResponseVariants { + Answer(AnswerResponseVariant), + Fail(FailResponseVariant), +} + +#[derive(Variant, serde::Deserialize)] +#[variant( + variant = "Answer", + scenario = "You know the capital city of the country", + description = "Capital city of the country" +)] +pub struct AnswerResponseVariant { #[schema( - field = "capital", description = "Capital city of the received country", example = "London" )] - Answer { capital: String }, - #[response( - scenario = "You don't know the capital city of the country", - description = "Reason why the capital city is not known" - )] + pub capital: String, +} + +#[derive(Variant, serde::Deserialize)] +#[variant( + variant = "Fail", + scenario = "You don't know the capital city of the country", + description = "Reason why the capital city is not known" +)] +pub struct FailResponseVariant { #[schema( - field = "reason", description = "Reason why the capital city is not known", example = "Country 'foobar' does not exist" )] - Fail { reason: String }, + pub reason: String, } #[tokio::main] async fn main() { - let lm = OllamaBuilder::new().try_build().unwrap(); - let executor = StructuredExecutorBuilder::new() + let lm = OllamaBuilder::new().try_build().unwrap(); + let executor = StructuredExecutorBuilder::new() .with_lm(&lm) .with_preamble("You are a geography expert who helps users understand the capital city of countries around the world.") - .with_options(&options!(CapitalCityExecutorResponseOptions)) + .with_options(&variants!(ResponseVariants)) .try_build() .unwrap(); - let response = executor.execute("What is the capital of Fooland?").await.expect("Execution failed"); - - println!("Response:"); - println!("{:?}", response.content); + let response = executor + .execute("What is the capital of Fooland?") + .await + .expect("Execution failed"); + + println!("Response:"); + match response.content { + ResponseVariants::Answer(answer) => { + println!("Capital city: {}", answer.capital); + } + ResponseVariants::Fail(fail) => { + println!("Model failed to generate a response: {}", fail.reason); + } + } } ``` ## Embedding Generation -```rust +```rust no_run use orch::execution::*; use orch::lm::*; @@ -138,7 +158,3 @@ async fn main() { ## More Examples See the [examples](https://github.com/guywaldman/orch/tree/main/core/examples) directory for usage examples. - -## Roadmap - -- [ ] Agents and tools diff --git a/core/Cargo.toml b/core/Cargo.toml index 10cb971..f4fd065 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "orch" -version = "0.0.11" +version = "0.0.12" edition = "2021" license = "MIT" description = "Language model orchestration library" @@ -9,8 +9,8 @@ repository = "https://github.com/guywaldman/orch" keywords = ["llm", "openai", "ollama", "rust"] [dependencies] -orch_response = { path = "../response", version = "0.0.11" } -orch_response_derive = { path = "../response_derive", version = "0.0.11" } +orch_response = { path = "../response", version = "0.0.12" } +orch_response_derive = { path = "../response_derive", version = "0.0.12" } async-gen = "0.2.3" dotenv = "0.15.0" dyn-clone = "1.0.17" diff --git a/core/README.md b/core/README.md index d691422..f4f0004 100644 --- a/core/README.md +++ b/core/README.md @@ -1,144 +1,3 @@ # orch -![Crates.io Version](https://img.shields.io/crates/v/orch?link=https%3A%2F%2Fcrates.io%2Fcrates%2Forch) -![Crates.io Total Downloads](https://img.shields.io/crates/d/orch?link=https%3A%2F%2Fcrates.io%2Fcrates%2Forch) - -`orch` is a library for building language model powered applications and agents for the Rust programming language. -It was primarily built for usage in [magic-cli](https://github.com/guywaldman/magic-cli), but can be used in other contexts as well. - -> [!NOTE] -> -> If the project gains traction, this can be compiled as an addon to other languages such as Python or a standalone WebAssembly module. - -# Installation - -```shell -cargo add orch -``` - -Alternatively, add `orch as a dependency to your `Cargo.toml` file: - -```toml -[dependencies] -orch = "*" # Substitute with the latest version -``` - -# Basic Usage - -## Simple Text Generation - -```rust -use orch::execution::*; -use orch::lm::*; - -#[tokio::main] -async fn main() { - let lm = OllamaBuilder::new().try_build().unwrap(); - let executor = TextExecutorBuilder::new().with_lm(&lm).try_build().unwrap(); - let response = executor.execute("What is 2+2?").await.expect("Execution failed"); - println!("{}", response.content); -} -``` - -## Streaming Text Generation - -```rust -use orch::execution::*; -use orch::lm::*; -use tokio_stream::StreamExt; - -#[tokio::main] -async fn main() { - let lm = OllamaBuilder::new().try_build().unwrap(); - let executor = TextExecutorBuilder::new().with_lm(&lm).try_build().unwrap(); - let mut response = executor.execute_stream("What is 2+2?").await.expect("Execution failed"); - while let Some(chunk) = response.stream.next().await { - match chunk { - Ok(chunk) => print!("{chunk}"), - Err(e) => { - println!("Error: {e}"); - break; - } - } - } - println!(); -} -``` - -## Structured Data Generation - -```rust -use orch::execution::*; -use orch::lm::*; -use orch_response_derive::*; - -#[derive(OrchResponseOptions)] -pub enum CapitalCityExecutorResponseOptions { - #[response( - scenario = "You know the capital city of the country", - description = "Capital city of the country" - )] - #[schema( - field = "capital", - description = "Capital city of the received country", - example = "London" - )] - Answer { capital: String }, - #[response( - scenario = "You don't know the capital city of the country", - description = "Reason why the capital city is not known" - )] - #[schema( - field = "reason", - description = "Reason why the capital city is not known", - example = "Country 'foobar' does not exist" - )] - Fail { reason: String }, -} - -#[tokio::main] -async fn main() { - let lm = OllamaBuilder::new().try_build().unwrap(); - let executor = StructuredExecutorBuilder::new() - .with_lm(&lm) - .with_preamble("You are a geography expert who helps users understand the capital city of countries around the world.") - .with_options(&options!(CapitalCityExecutorResponseOptions)) - .try_build() - .unwrap(); - let response = executor.execute("What is the capital of Fooland?").await.expect("Execution failed"); - - println!("Response:"); - println!("{:?}", response.content); -} -``` - -## Embedding Generation - -```rust -use orch::execution::*; -use orch::lm::*; - -#[tokio::main] -async fn main() { - let lm = OllamaBuilder::new().try_build().unwrap(); - let executor = TextExecutorBuilder::new() - .with_lm(&lm) - .try_build() - .unwrap(); - let embedding = executor - .generate_embedding("Phrase to generate an embedding for") - .await - .expect("Execution failed"); - - println!("Embedding:"); - println!("{:?}", embedding); -} -``` - -## More Examples - -See the [examples](https://github.com/guywaldman/orch/tree/main/core/examples) directory for usage examples. - -## Roadmap - -- [ ] Agents and tools +See the [main README](../README.md) for more information. diff --git a/core/examples/embeddings.rs b/core/examples/embeddings.rs index 36c2d70..24f7eab 100644 --- a/core/examples/embeddings.rs +++ b/core/examples/embeddings.rs @@ -1,5 +1,5 @@ //! This example demonstrates how to use the `Executor` to generate embeddings from the language model. -//! We construct an `Ollama` instance and use it to generate embeddings. +//! Run like so: `cargo run --example embeddings` use orch::execution::*; use orch::lm::*; @@ -26,13 +26,13 @@ async fn main() { let lm: Box = match provider { LanguageModelProvider::Ollama => Box::new( OllamaBuilder::new() - .with_embeddings_model(ollama_embedding_model::NOMIC_EMBED_TEXT) + .with_embeddings_model(ollama_embedding_model::NOMIC_EMBED_TEXT.to_string()) .try_build() .unwrap(), ), LanguageModelProvider::OpenAi => Box::new( OpenAiBuilder::new() - .with_api_key(&open_ai_api_key) + .with_api_key(open_ai_api_key) .try_build() .unwrap(), ), diff --git a/core/examples/structured_data_generation_basic.rs b/core/examples/structured_data_generation_basic.rs deleted file mode 100644 index 6fc0c37..0000000 --- a/core/examples/structured_data_generation_basic.rs +++ /dev/null @@ -1,76 +0,0 @@ -#![allow(dead_code)] -//! This example demonstrates how to use the `Executor` to generate a structured response from the LLM. - -use orch::execution::*; -use orch::lm::*; -use orch::response::*; - -#[derive(OrchResponseOptions)] -pub enum CapitalCityExecutorResponseOptions { - #[response( - scenario = "You know the capital city of the country", - description = "Capital city of the country" - )] - #[schema( - field = "capital", - description = "Capital city of the received country", - example = "London" - )] - #[schema( - field = "country", - description = "Country of the received capital city", - example = "United Kingdom" - )] - Answer { capital: String, country: String }, - #[response( - scenario = "You don't know the capital city of the country", - description = "Reason why the capital city is not known" - )] - #[schema( - field = "reason", - description = "Reason why the capital city is not known", - example = "Country 'foobar' does not exist" - )] - Fail { reason: String }, -} - -#[tokio::main] -async fn main() { - // ! Change this to use a different provider. - let provider = LanguageModelProvider::Ollama; - - let prompt = "What is the capital of France?"; - - println!("Prompt: {prompt}"); - println!("---"); - - // Use a different language model, per the `provider` variable (feel free to change it). - let open_ai_api_key = { - if provider == LanguageModelProvider::OpenAi { - std::env::var("OPENAI_API_KEY") - .unwrap_or_else(|_| panic!("OPENAI_API_KEY environment variable not set")) - } else { - String::new() - } - }; - let lm: Box = match provider { - LanguageModelProvider::Ollama => Box::new(OllamaBuilder::new().try_build().unwrap()), - LanguageModelProvider::OpenAi => Box::new( - OpenAiBuilder::new() - .with_api_key(&open_ai_api_key) - .try_build() - .unwrap(), - ), - }; - - let executor = StructuredExecutorBuilder::new() - .with_lm(&*lm) - .with_preamble("You are a geography expert who helps users understand the capital city of countries around the world.") - .with_options(&options!(CapitalCityExecutorResponseOptions)) - .try_build() - .unwrap(); - let response = executor.execute(prompt).await.expect("Execution failed"); - - println!("Response:"); - println!("{:?}", response.content); -} diff --git a/core/examples/structured_data_generation_blog.rs b/core/examples/structured_data_generation_blog.rs index 42f23bf..8e7ac85 100644 --- a/core/examples/structured_data_generation_blog.rs +++ b/core/examples/structured_data_generation_blog.rs @@ -1,33 +1,45 @@ -#![allow(dead_code)] - //! This example demonstrates how to use the `Executor` to generate a structured response from the LLM. +//! Run like so: `cargo run --example structured_data_generation_blog -- blog.md` + +#![allow(dead_code)] use orch::execution::*; use orch::lm::*; use orch::response::*; -#[derive(OrchResponseOptions)] -pub enum BlogPostReviewerResponseOption { - #[response( - scenario = "You have reviewed the blog post", - description = "Suggestions for improving the blog post" - )] +#[derive(Variants, serde::Deserialize)] +#[serde(tag = "response_type")] +pub enum ResponseVariants { + Answer(AnswerResponseVariant), + Fail(FailResponseVariant), +} + +#[derive(Variant, serde::Deserialize)] +#[variant( + variant = "Answer", + scenario = "You have reviewed the blog post", + description = "Suggestions for improving the blog post" +)] +pub struct AnswerResponseVariant { #[schema( - field = "suggestions", description = "Suggestions for improving the blog post", example = "[\"You wrote 'excellent' in two consecutive paragraphs in section 'Introduction'\"]" )] - Answer { suggestions: Vec }, - #[response( - scenario = "For some reason you failed to generate suggestions", - description = "Reason why you failed to generate suggestions" - )] + pub suggestions: Vec, +} + +#[derive(Variant, serde::Deserialize)] +#[variant( + variant = "Fail", + scenario = "For some reason you failed to generate suggestions", + description = "Reason why you failed to generate suggestions" +)] +pub struct FailResponseVariant { #[schema( - field = "reason", description = "Reason why you failed to generate suggestions", example = "Content was invalid" )] - Fail { reason: String }, + pub reason: String, } #[tokio::main] @@ -35,13 +47,14 @@ async fn main() { // ! Change this to use a different provider. let provider = LanguageModelProvider::OpenAi; - let prompt = " - # Introduction - Hello, I am Guy. This is my first blog post! - "; + let args = std::env::args().collect::>(); + let blog_file_path = args.get(1).unwrap_or_else(|| { + eprintln!("ERROR: Please provide a path to a blog file"); + std::process::exit(1); + }); + let prompt = std::fs::read_to_string(blog_file_path).expect("Failed to read blog file"); - println!("Prompt: {prompt}"); - println!("---"); + println!("Analyzing blog post at path '{blog_file_path}'..."); // Use a different language model, per the `provider` variable (feel free to change it). let open_ai_api_key = { @@ -55,13 +68,13 @@ async fn main() { let lm: Box = match provider { LanguageModelProvider::Ollama => Box::new( OllamaBuilder::new() - .with_model(ollama_model::PHI3_MINI) + .with_model(ollama_model::PHI3_MINI.to_string()) .try_build() .unwrap(), ), LanguageModelProvider::OpenAi => Box::new( OpenAiBuilder::new() - .with_api_key(&open_ai_api_key) + .with_api_key(open_ai_api_key) .try_build() .unwrap(), ), @@ -69,12 +82,30 @@ async fn main() { let executor = StructuredExecutorBuilder::new() .with_lm(&*lm) - .with_preamble("You are an experienced writer and blog post reviewer who helps users improve their blog posts.") - .with_options(&options!(BlogPostReviewerResponseOption)) + .with_preamble(" + You are an experienced writer and blog post reviewer who helps users improve their blog posts. + You will receive a blog post written in Markdown, and you will need to provide suggestions for improving it. + Provide *specific* suggestions for improving the blog post, these can as nitpicky as you want. + Consider things such as grammar, spelling, clarity, and conciseness. + Even things like mentioning the same phrase too much in one paragraph, etc. + The tone should be personal, friendly and professional at the same time. + + Be very specific and refer to specific sentences, paragraph and sections of the blog post. + ") + .with_options(&variants!(ResponseVariants)) .try_build() .unwrap(); - let response = executor.execute(prompt).await.expect("Execution failed"); + let response = executor.execute(&prompt).await.expect("Execution failed"); - println!("Response:"); - println!("{:?}", response.content); + match response.content { + ResponseVariants::Answer(answer) => { + println!("Suggestions for improving the blog post:"); + for suggestion in answer.suggestions { + println!("- {}", suggestion); + } + } + ResponseVariants::Fail(fail) => { + println!("Model failed to generate a response: {}", fail.reason); + } + } } diff --git a/core/examples/structured_data_generation_capital.rs b/core/examples/structured_data_generation_capital.rs new file mode 100644 index 0000000..72e4440 --- /dev/null +++ b/core/examples/structured_data_generation_capital.rs @@ -0,0 +1,93 @@ +//! This example demonstrates how to use the `Executor` to generate a structured response from the LLM. +//! Run like so: `cargo run --example structured_data_generation_capital -- France` + +#![allow(dead_code)] + +use orch::execution::*; +use orch::lm::*; +use orch::response::*; + +#[derive(Variants, serde::Deserialize)] +pub enum ResponseVariants { + Answer(AnswerResponseVariant), + Fail(FailResponseVariant), +} + +#[derive(Variant, serde::Deserialize)] +#[variant( + variant = "Answer", + scenario = "You know the capital city of the country", + description = "Capital city of the country" +)] +pub struct AnswerResponseVariant { + #[schema( + description = "Capital city of the received country", + example = "London" + )] + pub capital: String, +} + +#[derive(Variant, serde::Deserialize)] +#[variant( + variant = "Fail", + scenario = "You don't know the capital city of the country", + description = "Reason why the capital city is not known" +)] +pub struct FailResponseVariant { + #[schema( + description = "Reason why the capital city is not known", + example = "Country 'foobar' does not exist" + )] + pub reason: String, +} + +#[tokio::main] +async fn main() { + // ! Change this to use a different provider. + let provider = LanguageModelProvider::Ollama; + + let args = std::env::args().collect::>(); + let prompt = args.get(1).unwrap_or_else(|| { + eprintln!("ERROR: Please provide a country name"); + std::process::exit(1); + }); + + // Use a different language model, per the `provider` variable (feel free to change it). + let open_ai_api_key = { + if provider == LanguageModelProvider::OpenAi { + std::env::var("OPENAI_API_KEY") + .unwrap_or_else(|_| panic!("OPENAI_API_KEY environment variable not set")) + } else { + String::new() + } + }; + let lm: Box = match provider { + LanguageModelProvider::Ollama => Box::new(OllamaBuilder::new().try_build().unwrap()), + LanguageModelProvider::OpenAi => Box::new( + OpenAiBuilder::new() + .with_api_key(open_ai_api_key) + .try_build() + .unwrap(), + ), + }; + + let executor = StructuredExecutorBuilder::new() + .with_lm(&*lm) + .with_preamble(" + You are a geography expert who helps users understand the capital city of countries around the world. + You will receive a country name, and you will need to provide the capital city of that country. + ") + .with_options(&variants!(ResponseVariants)) + .try_build() + .unwrap(); + let response = executor.execute(prompt).await.expect("Execution failed"); + + match response.content { + ResponseVariants::Answer(answer) => { + println!("Capital city: {}", answer.capital); + } + ResponseVariants::Fail(fail) => { + println!("Model failed to generate a response: {}", fail.reason); + } + } +} diff --git a/core/examples/text_generation.rs b/core/examples/text_generation.rs index 16a9a86..8f975c1 100644 --- a/core/examples/text_generation.rs +++ b/core/examples/text_generation.rs @@ -1,5 +1,5 @@ //! This example demonstrates how to use the `Executor` to generate a response from the LLM. -//! We construct an `Ollama` instance and use it to generate a response. +//! Run like so: `cargo run --example text_generation` use orch::execution::*; use orch::lm::*; @@ -26,13 +26,13 @@ async fn main() { let lm: Box = match provider { LanguageModelProvider::Ollama => Box::new( OllamaBuilder::new() - .with_model(ollama_model::PHI3_MINI) + .with_model(ollama_model::PHI3_MINI.to_string()) .try_build() .unwrap(), ), LanguageModelProvider::OpenAi => Box::new( OpenAiBuilder::new() - .with_api_key(&open_ai_api_key) + .with_api_key(open_ai_api_key) .try_build() .unwrap(), ), diff --git a/core/examples/text_generation_stream.rs b/core/examples/text_generation_stream.rs index 074b3eb..be6ae7e 100644 --- a/core/examples/text_generation_stream.rs +++ b/core/examples/text_generation_stream.rs @@ -1,5 +1,5 @@ //! This example demonstrates how to use the `Executor` to generate a streaming response from the LLM. -//! We construct an `Ollama` instance and use it to generate a streaming response. +//! Run like so: `cargo run --example text_generation_stream` use orch::execution::*; use orch::lm::*; @@ -27,13 +27,13 @@ async fn main() { let lm: Box = match provider { LanguageModelProvider::Ollama => Box::new( OllamaBuilder::new() - .with_model(ollama_model::PHI3_MINI) + .with_model(ollama_model::PHI3_MINI.to_string()) .try_build() .unwrap(), ), LanguageModelProvider::OpenAi => Box::new( OpenAiBuilder::new() - .with_api_key(&open_ai_api_key) + .with_api_key(open_ai_api_key) .try_build() .unwrap(), ), diff --git a/core/examples/variants_derive.rs b/core/examples/variants_derive.rs new file mode 100644 index 0000000..ae1bc6e --- /dev/null +++ b/core/examples/variants_derive.rs @@ -0,0 +1,63 @@ +//! This example demonstrates how to use the `Variants` derive macro to generate a structured response from the LLM. +//! +//! Run like so: `cargo run --example variants_derive` + +use orch::response::*; + +#[derive(Variants, serde::Deserialize)] +pub enum ResponseOptions { + Answer(AnswerResponseOption), + Fail(FailResponseOption), +} + +#[derive(Variant, serde::Deserialize)] +#[variant( + variant = "Answer", + scenario = "You know the capital city of the country", + description = "Capital city of the country" +)] +pub struct AnswerResponseOption { + #[schema( + description = "Capital city of the received country", + example = "London" + )] + pub capital: String, + #[schema( + description = "Country of the received capital city", + example = "United Kingdom" + )] + pub country: String, +} + +#[derive(Variant, serde::Deserialize)] +#[variant( + variant = "Fail", + scenario = "You don't know the capital city of the country", + description = "Reason why the capital city is not known" +)] +pub struct FailResponseOption { + #[schema( + description = "Reason why the capital city is not known", + example = "Country 'foobar' does not exist" + )] + pub reason: String, +} + +fn main() { + let response = r#" + { + "response_type": "Answer", + "capital": "London", + "country": "United Kingdom" + } + "#; + let parsed_response = variants!(ResponseOptions).parse(response).unwrap(); + match parsed_response { + ResponseOptions::Answer(answer_response) => { + println!("{}", answer_response.capital); + } + ResponseOptions::Fail(fail_response) => { + println!("{}", fail_response.reason); + } + } +} diff --git a/core/src/execution/builder.rs b/core/src/execution/builder.rs index 9774b25..f3fd45d 100644 --- a/core/src/execution/builder.rs +++ b/core/src/execution/builder.rs @@ -1,4 +1,4 @@ -use orch_response::ResponseOptions; +use orch_response::OrchResponseVariants; use thiserror::Error; use crate::lm::LanguageModel; @@ -49,24 +49,18 @@ impl<'a> TextExecutorBuilder<'a> { } #[derive(Default)] -pub struct StructuredExecutorBuilder<'a, T> -where - T: serde::de::DeserializeOwned + Sized, -{ +pub struct StructuredExecutorBuilder<'a, T> { lm: Option<&'a dyn LanguageModel>, preamble: Option<&'a str>, - options: Option<&'a dyn ResponseOptions>, + variants: Option<&'a dyn OrchResponseVariants>, } -impl<'a, T> StructuredExecutorBuilder<'a, T> -where - T: serde::de::DeserializeOwned + Sized, -{ +impl<'a, T> StructuredExecutorBuilder<'a, T> { pub fn new() -> Self { Self { lm: None, preamble: None, - options: None, + variants: None, } } @@ -75,8 +69,8 @@ where self } - pub fn with_options(mut self, options: &'a dyn ResponseOptions) -> Self { - self.options = Some(options); + pub fn with_options(mut self, options: &'a dyn OrchResponseVariants) -> Self { + self.variants = Some(options); self } @@ -91,7 +85,7 @@ where "Language model".to_string(), )); }; - let Some(response_options) = self.options else { + let Some(response_options) = self.variants else { return Err(ExecutorBuilderError::ConfigurationNotSet( "Response options".to_string(), )); @@ -99,7 +93,7 @@ where Ok(StructuredExecutor { lm, preamble: self.preamble, - response_options, + variants: response_options, }) } } diff --git a/core/src/execution/executor.rs b/core/src/execution/executor.rs index 3ba47c8..0eab265 100644 --- a/core/src/execution/executor.rs +++ b/core/src/execution/executor.rs @@ -1,6 +1,6 @@ use std::{cell::OnceCell, pin::Pin}; -use orch_response::{ResponseOption, ResponseOptions, ResponseSchemaField}; +use orch_response::{OrchResponseVariants, ResponseOption, ResponseSchemaField}; use thiserror::Error; use tokio_stream::Stream; @@ -36,7 +36,7 @@ trait Executor<'a> { fn system_prompt(&self) -> String { let cell = OnceCell::new(); cell.get_or_init(|| { - let response_options = self.response_options().unwrap_or_default(); + let response_options = self.variants().unwrap_or_default(); generate_system_prompt( self.format(), self.preamble().unwrap_or(DEFAULT_PREAMBLE), @@ -46,7 +46,7 @@ trait Executor<'a> { .clone() } - fn response_options(&self) -> Option> { + fn variants(&self) -> Option> { None } @@ -137,22 +137,19 @@ impl<'a> TextExecutor<'a> { } } -pub struct StructuredExecutor<'a, T> -where - T: serde::de::DeserializeOwned, -{ +pub struct StructuredExecutor<'a, T> { pub(crate) lm: &'a dyn LanguageModel, pub(crate) preamble: Option<&'a str>, - pub(crate) response_options: &'a dyn ResponseOptions, + pub(crate) variants: &'a dyn OrchResponseVariants, } -impl<'a, T: serde::de::DeserializeOwned> Executor<'a> for StructuredExecutor<'a, T> { +impl<'a, T> Executor<'a> for StructuredExecutor<'a, T> { fn format(&self) -> ResponseFormat { ResponseFormat::Json } - fn response_options(&self) -> Option> { - Some(self.response_options.options()) + fn variants(&self) -> Option> { + Some(self.variants.variants()) } fn lm(&self) -> &'a dyn LanguageModel { @@ -167,10 +164,7 @@ impl<'a, T: serde::de::DeserializeOwned> Executor<'a> for StructuredExecutor<'a, /// Trait for LLM execution. /// This should be implemented for each LLM text generation use-case, where the system prompt /// changes according to the trait implementations. -impl<'a, T> StructuredExecutor<'a, T> -where - T: serde::de::DeserializeOwned, -{ +impl<'a, T> StructuredExecutor<'a, T> { /// Generates a structured response from the LLM (non-streaming). /// /// # Arguments @@ -184,15 +178,12 @@ where prompt: &'a str, ) -> Result, ExecutorError> { let text_result = self.text_complete(prompt).await?; - let result = self - .response_options - .parse(&text_result.content) - .map_err(|e| { - ExecutorError::Parsing(format!( - "Error while parsing response: {e}\nResponse: {:?}", - text_result.content - )) - })?; + let result = self.variants.parse(&text_result.content).map_err(|e| { + ExecutorError::Parsing(format!( + "Error while parsing response: {e}\nResponse: {:?}", + text_result.content + )) + })?; // TODO: Add error correction and handling. Ok(ExecutorTextCompleteResponse { content: result, diff --git a/core/src/lib.rs b/core/src/lib.rs index 365bc2d..806a8c2 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -1,129 +1,6 @@ -//! `orch` is a library for building language model powered applications and agents for the Rust programming language. -//! It was primarily built for usage in [magic-cli](https://github.com/guywaldman/magic-cli), but can be used in other contexts as well. -//! -//! # Examples -//! -//! # Basic Usage -//! -//! ## Simple Text Generation -//! -//! ```no_run -//! use orch::execution::*; -//! use orch::lm::*; -//! -//! #[tokio::main] -//! async fn main() { -//! let lm = OllamaBuilder::new().try_build().unwrap(); -//! let executor = TextExecutorBuilder::new().with_lm(&lm).try_build().unwrap(); -//! let response = executor.execute("What is 2+2?").await.expect("Execution failed"); -//! println!("{}", response.content); -//! } -//! ``` -//! -//! ## Streaming Text Generation -//! -//! ```no_run -//! use orch::execution::*; -//! use orch::lm::*; -//! use tokio_stream::StreamExt; -//! -//! #[tokio::main] -//! async fn main() { -//! let lm = OllamaBuilder::new().try_build().unwrap(); -//! let executor = TextExecutorBuilder::new().with_lm(&lm).try_build().unwrap(); -//! let mut response = executor.execute_stream("What is 2+2?").await.expect("Execution failed"); -//! while let Some(chunk) = response.stream.next().await { -//! match chunk { -//! Ok(chunk) => print!("{chunk}"), -//! Err(e) => { -//! println!("Error: {e}"); -//! break; -//! } -//! } -//! } -//! println!(); -//! } -//! ``` -//! -//! ## Structured Data Generation -//! -//! ```no_run -//! use orch::execution::*; -//! use orch::lm::*; -//! use orch::response::*; -//! -//! #[derive(OrchResponseOptions)] -//! pub enum CapitalCityExecutorResponseOptions { -//! #[response( -//! scenario = "You know the capital city of the country", -//! description = "Capital city of the country" -//! )] -//! #[schema( -//! field = "capital", -//! description = "Capital city of the received country", -//! example = "London" -//! )] -//! Answer { capital: String }, -//! #[response( -//! scenario = "You don't know the capital city of the country", -//! description = "Reason why the capital city is not known" -//! )] -//! #[schema( -//! field = "reason", -//! description = "Reason why the capital city is not known", -//! example = "Country 'foobar' does not exist" -//! )] -//! Fail { reason: String }, -//! } -//! -//! #[tokio::main] -//! async fn main() { -//! let lm = OllamaBuilder::new().try_build().unwrap(); -//! let executor = StructuredExecutorBuilder::new() -//! .with_lm(&lm) -//! .with_preamble("You are a geography expert who helps users understand the capital city of countries around the world.") -//! .with_options(&options!(CapitalCityExecutorResponseOptions)) -//! .try_build() -//! .unwrap(); -//! let response = executor.execute("What is the capital of Fooland?").await.expect("Execution failed"); -//! -//! println!("Response:"); -//! println!("{:?}", response.content); -//! } -//! ``` -//! -//! ## Embedding Generation -//! -//! ```no_run -//! use orch::execution::*; -//! use orch::lm::*; -//! -//! #[tokio::main] -//! async fn main() { -//! let lm = OllamaBuilder::new().try_build().unwrap(); -//! let executor = TextExecutorBuilder::new() -//! .with_lm(&lm) -//! .try_build() -//! .unwrap(); -//! let embedding = executor -//! .generate_embedding("Phrase to generate an embedding for") -//! .await -//! .expect("Execution failed"); -//! -//! println!("Embedding:"); -//! println!("{:?}", embedding); -//! } -//! ``` -//! -//! ## More Examples -//! -//! See the [examples](https://github.com/guywaldman/orch/tree/main/core/examples) directory for usage examples. +#![doc = include_str!("../../README.md")] pub mod execution; pub mod lm; mod net; - -pub mod response { - pub use orch_response::*; - pub use orch_response_derive::*; -} +pub mod response; diff --git a/core/src/lm/lm_provider/models.rs b/core/src/lm/lm_provider/models.rs index 1e96d09..b13ba26 100644 --- a/core/src/lm/lm_provider/models.rs +++ b/core/src/lm/lm_provider/models.rs @@ -1,5 +1,7 @@ use serde::{Deserialize, Serialize}; +use crate::lm::LanguageModel; + use super::{Ollama, OpenAi}; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] @@ -14,3 +16,26 @@ pub enum OrchLanguageModel { Ollama(Ollama), OpenAi(OpenAi), } + +impl OrchLanguageModel { + pub fn provider(&self) -> LanguageModelProvider { + match self { + OrchLanguageModel::Ollama(_) => LanguageModelProvider::Ollama, + OrchLanguageModel::OpenAi(_) => LanguageModelProvider::OpenAi, + } + } + + pub fn text_completion_model_name(&self) -> String { + match self { + OrchLanguageModel::Ollama(lm) => lm.text_completion_model_name(), + OrchLanguageModel::OpenAi(lm) => lm.text_completion_model_name(), + } + } + + pub fn embedding_model_name(&self) -> String { + match self { + OrchLanguageModel::Ollama(lm) => lm.embedding_model_name(), + OrchLanguageModel::OpenAi(lm) => lm.embedding_model_name(), + } + } +} diff --git a/core/src/lm/lm_provider/ollama/builder.rs b/core/src/lm/lm_provider/ollama/builder.rs index 9174efc..fbc8395 100644 --- a/core/src/lm/lm_provider/ollama/builder.rs +++ b/core/src/lm/lm_provider/ollama/builder.rs @@ -10,35 +10,35 @@ pub enum OllamaBuilderError { ConfigurationNotSet(String), } -pub struct OllamaBuilder<'a> { - base_url: Option<&'a str>, - model: Option<&'a str>, - embeddings_model: Option<&'a str>, +pub struct OllamaBuilder { + base_url: Option, + model: Option, + embeddings_model: Option, } -impl<'a> OllamaBuilder<'a> { - pub fn with_base_url(mut self, base_url: &'a str) -> Self { +impl OllamaBuilder { + pub fn with_base_url(mut self, base_url: String) -> Self { self.base_url = Some(base_url); self } - pub fn with_model(mut self, model: &'a str) -> Self { + pub fn with_model(mut self, model: String) -> Self { self.model = Some(model); self } - pub fn with_embeddings_model(mut self, embeddings_model: &'a str) -> Self { + pub fn with_embeddings_model(mut self, embeddings_model: String) -> Self { self.embeddings_model = Some(embeddings_model); self } } -impl<'a> LanguageModelBuilder for OllamaBuilder<'a> { +impl LanguageModelBuilder for OllamaBuilder { fn new() -> Self { Self { - base_url: Some("http://localhost:11434"), - model: Some(ollama_model::CODESTRAL), - embeddings_model: Some(ollama_embedding_model::NOMIC_EMBED_TEXT), + base_url: Some("http://localhost:11434".to_string()), + model: Some(ollama_model::CODESTRAL.to_string()), + embeddings_model: Some(ollama_embedding_model::NOMIC_EMBED_TEXT.to_string()), } } diff --git a/core/src/lm/lm_provider/openai/builder.rs b/core/src/lm/lm_provider/openai/builder.rs index 1cea89f..957e165 100644 --- a/core/src/lm/lm_provider/openai/builder.rs +++ b/core/src/lm/lm_provider/openai/builder.rs @@ -10,51 +10,48 @@ pub enum OpenAiBuilderError { ConfigurationNotSet(String), } -pub struct OpenAiBuilder<'a> { - api_key: Option<&'a str>, - model: Option<&'a str>, - embeddings_model: Option<&'a str>, +pub struct OpenAiBuilder { + api_key: Option, + model: Option, + embeddings_model: Option, embedding_dimensions: Option, } -impl<'a> OpenAiBuilder<'a> { - pub fn with_api_key(mut self, api_key: &'a str) -> Self { +impl OpenAiBuilder { + pub fn with_api_key(mut self, api_key: String) -> Self { self.api_key = Some(api_key); self } - pub fn with_model(mut self, model: &'a str) -> Self { + pub fn with_model(mut self, model: String) -> Self { self.model = Some(model); self } - pub fn with_embeddings_model(mut self, embeddings_model: &'a str) -> Self { - self.embeddings_model = Some(embeddings_model); - match embeddings_model { + pub fn with_embeddings_model(mut self, embeddings_model: String) -> Self { + self.embeddings_model = Some(embeddings_model.clone()); + self.embedding_dimensions = match embeddings_model.as_ref() { openai_embedding_model::TEXT_EMBEDDING_ADA_002 => { - self.embedding_dimensions = - Some(openai_embedding_model::TEXT_EMBEDDING_ADA_002_DIMENSIONS); + Some(openai_embedding_model::TEXT_EMBEDDING_ADA_002_DIMENSIONS) } openai_embedding_model::TEXT_EMBEDDING_3_SMALL => { - self.embedding_dimensions = - Some(openai_embedding_model::TEXT_EMBEDDING_3_SMALL_DIMENSIONS); + Some(openai_embedding_model::TEXT_EMBEDDING_3_SMALL_DIMENSIONS) } openai_embedding_model::TEXT_EMBEDDING_3_LARGE => { - self.embedding_dimensions = - Some(openai_embedding_model::TEXT_EMBEDDING_3_LARGE_DIMENSIONS); + Some(openai_embedding_model::TEXT_EMBEDDING_3_LARGE_DIMENSIONS) } - _ => {} - } + _ => None, + }; self } } -impl<'a> LanguageModelBuilder for OpenAiBuilder<'a> { +impl LanguageModelBuilder for OpenAiBuilder { fn new() -> Self { Self { api_key: None, - model: Some(openai_model::GPT_4O_MINI), - embeddings_model: Some(openai_embedding_model::TEXT_EMBEDDING_ADA_002), + model: Some(openai_model::GPT_4O_MINI.to_string()), + embeddings_model: Some(openai_embedding_model::TEXT_EMBEDDING_ADA_002.to_string()), embedding_dimensions: Some(openai_embedding_model::TEXT_EMBEDDING_ADA_002_DIMENSIONS), } } diff --git a/core/src/lm/models.rs b/core/src/lm/models.rs index dad91db..ab5118e 100644 --- a/core/src/lm/models.rs +++ b/core/src/lm/models.rs @@ -13,7 +13,7 @@ use super::{error::LanguageModelError, LanguageModelProvider}; /// > `DynClone` is used so that there can be dynamic dispatch of the `Llm` trait, /// > especially needed for [magic-cli](https://github.com/guywaldman/magic-cli). #[async_trait] -pub trait LanguageModel: DynClone { +pub trait LanguageModel: DynClone + Send + Sync { /// Generates a response from the LLM. /// /// # Arguments diff --git a/core/src/response.rs b/core/src/response.rs new file mode 100644 index 0000000..05bc7e0 --- /dev/null +++ b/core/src/response.rs @@ -0,0 +1,2 @@ +pub use orch_response::*; +pub use orch_response_derive::*; diff --git a/response/Cargo.toml b/response/Cargo.toml index 88806bf..63f2db9 100644 --- a/response/Cargo.toml +++ b/response/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "orch_response" -version = "0.0.11" +version = "0.0.12" edition = "2021" license = "MIT" description = "Models for orch Executor responses" diff --git a/response/src/lib.rs b/response/src/lib.rs index 8e01318..20d5a70 100644 --- a/response/src/lib.rs +++ b/response/src/lib.rs @@ -24,13 +24,11 @@ pub struct ResponseSchemaField { pub example: String, } -pub trait ResponseOptions -where - T: serde::de::DeserializeOwned, -{ - fn options(&self) -> Vec; +pub trait OrchResponseVariant: Send + Sync { + fn variant() -> ResponseOption; +} - fn parse(&self, response: &str) -> Result { - serde_json::from_str(response) - } +pub trait OrchResponseVariants: Send + Sync { + fn variants(&self) -> Vec; + fn parse(&self, response: &str) -> Result; } diff --git a/response_derive/Cargo.toml b/response_derive/Cargo.toml index d32969a..0fc1ae2 100644 --- a/response_derive/Cargo.toml +++ b/response_derive/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "orch_response_derive" -version = "0.0.11" +version = "0.0.12" edition = "2021" license = "MIT" description = "Derive macros for orch Executor responses" @@ -12,7 +12,7 @@ keywords = ["llm", "openai", "ollama", "rust"] proc-macro = true [dependencies] -orch_response = { path = "../response", version = "0.0.11" } +orch_response = { path = "../response", version = "0.0.12" } darling = "0.20.10" proc-macro2 = "1.0.86" quote = "1.0.36" diff --git a/response_derive/src/attribute_impl.rs b/response_derive/src/attribute_impl.rs index 62a3eee..77c56ed 100644 --- a/response_derive/src/attribute_impl.rs +++ b/response_derive/src/attribute_impl.rs @@ -1,14 +1,16 @@ use darling::FromMeta; +/// #[variant(...)] #[derive(Debug, FromMeta)] -pub(crate) struct ResponseAttribute { +pub(crate) struct VariantAttribute { + pub(crate) variant: String, pub(crate) scenario: String, pub(crate) description: String, } +/// #[schema(...)] #[derive(Debug, FromMeta)] pub(crate) struct SchemaAttribute { - pub(crate) field: String, pub(crate) description: String, pub(crate) example: String, } diff --git a/response_derive/src/derive_impl.rs b/response_derive/src/derive_impl.rs index 65b9229..75f5286 100644 --- a/response_derive/src/derive_impl.rs +++ b/response_derive/src/derive_impl.rs @@ -2,135 +2,273 @@ use darling::FromMeta; use quote::quote; use syn::{parse_macro_input, spanned::Spanned, DeriveInput, PathArguments}; -use crate::attribute_impl::{ResponseAttribute, SchemaAttribute}; +use crate::attribute_impl::{SchemaAttribute, VariantAttribute}; -pub(crate) fn derive_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream { +pub(crate) fn response_variants_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let mut output = quote!(); - // 1. - // Parse the input and validate its type (only an enum is supported). + // Bring traits into scope. + output.extend(quote! { + use ::orch::response::OrchResponseVariant; + use ::serde::de::Error; + }); + let original_enum = parse_macro_input!(input as DeriveInput); - let DeriveInput { data, ident, vis, .. } = original_enum.clone(); + let DeriveInput { data, ident, .. } = original_enum.clone(); let syn::Data::Enum(data) = data else { - panic!("#[derive(OrchResponseOptions)] can only be used with enums"); + panic!("#[derive(OrchResponseVariants)] can only be used with enums"); }; + let original_enum_ident = ident; - // 2. - // Create a new derived struct, such that we later add an impl trait for the original - // enum which returns the derived struct. + // // 2. + // // Create a new derived struct, such that we later add an impl trait for the original + // // enum which returns the derived struct. + + // let derived_enum_ident = syn::Ident::new(&format!("{}Derived", original_enum_ident), original_enum_ident.span()); + // let derived_enum_variant_fields = data.variants.iter().map(|v| { + // let ident = &v.ident; + // let variant_struct_name = syn::Ident::new(&format!("{}{}", original_enum_ident, v.ident), v.ident.span()); + // quote! { #ident(#variant_struct_name), } + // }); + // output.extend(quote! { + // #[derive(Debug, ::serde::Deserialize)] + // #[serde(tag = "response_type")] + // #vis enum #derived_enum_ident { + // #(#derived_enum_variant_fields)* + // } + // }); + + // // 3. Add the new derived enum variant structs. + // for variant in data.variants.iter() { + // let ident = &variant.ident; + // let variant_struct_name = syn::Ident::new(&format!("{}{}", original_enum_ident, ident), ident.span()); + // let fields = variant.fields.iter(); + // output.extend(quote! { + // #[derive(Debug, ::serde::Deserialize)] + // pub struct #variant_struct_name { + // #(#fields),* + // } + // }); + // } + + // // 4. Declare a new struct that will be used to parse the response. + // let parser_struct_ident = syn::Ident::new(&format!("{}Parser", original_enum_ident), original_enum_ident.span()); + // output.extend(quote! { + // #[derive(Debug)] + // pub struct #parser_struct_ident; + // }); + + // // 4. Implement the `ResponseOptions` trait for a new struct that will be used to parse the response. + // let mut options_vec_pushes = quote!(); + // for syn::Variant { ident, attrs, fields, .. } in data.variants.iter() { + // let schema_attrs = attrs + // .iter() + // .filter_map(|attr| SchemaAttribute::from_meta(&attr.meta).ok()) + // .collect::>(); + // if schema_attrs.len() != fields.len() { + // panic!("Expected a single #[schema(...)] attribute for each field of the enum variant"); + // } + // let mut schema_fields = Vec::new(); + // for variant_field in fields.iter() { + // let schema_attr_for_field = schema_attrs + // .iter() + // .find(|attr| *variant_field.ident.as_ref().unwrap() == attr.field) + // .unwrap_or_else(|| { + // panic!( + // "Field {} not found in #[schema(...)] attributes", + // variant_field.ident.as_ref().unwrap() + // ) + // }); + // let SchemaAttribute { + // field, + // description, + // example, + // } = schema_attr_for_field; + // let typ = ast_type_to_str(&variant_field.ty).unwrap_or_else(|_| { + // panic!( + // "Failed to convert type to string for field `{}` of variant `{}`", + // variant_field.ident.as_ref().unwrap(), + // ident + // ) + // }); + // let typ = syn::LitStr::new(&typ, variant_field.span()); + // schema_fields.push(quote! { + // ::orch::response::ResponseSchemaField { + // name: #field.to_string(), + // description: #description.to_string(), + // typ: #typ.to_string(), + // example: #example.to_string(), + // } + // }) + // } + + // // Each enum variant should implement the `OrchResponseOption` trait. + // options_vec_pushes.extend(quote! { + // options.push(#original_enum_ident::#ident::option()) + // }); + // } + + let vec_capacity = data.variants.len(); + // let mut parse_match_arms = quote!(); + // Transform the derived enum variant into the original enum variant. + // for original_variant in data.variants.iter() { + // let ident = &original_variant.ident; + // let fields = original_variant.fields.iter().map(|field| { + // let ident = &field.ident; + // quote! { #ident: parsed_response.#ident } + // }); + // parse_match_arms.extend(quote! { + // #derived_enum_ident::#ident(parsed_response) => Ok(#original_enum_ident::#ident { + // #(#fields),* + // }), + // }); + // } + + let mut options_vec_pushes = quote!(); + for variant in data.variants.iter() { + let ident = syn::Ident::new( + &get_enum_variant_struct_ident(variant).expect("Failed to parse enum variant"), + variant.ident.span(), + ); + + options_vec_pushes.extend(quote! { + options.push(#ident::variant()); + }); + } + + // We construct a new struct that will be used to parse the response. + // NOTE: This is hacky, but a workaround for the fact that the enum cannot be constructed. + let derived_enum_struct_ident = syn::Ident::new(&format!("{}Derived", original_enum_ident), original_enum_ident.span()); - let original_enum_ident = ident; - let derived_enum_ident = syn::Ident::new(&format!("{}Derived", original_enum_ident), original_enum_ident.span()); - let derived_enum_variant_fields = data.variants.iter().map(|v| { - let ident = &v.ident; - let variant_struct_name = syn::Ident::new(&format!("{}{}", original_enum_ident, v.ident), v.ident.span()); - quote! { #ident(#variant_struct_name), } - }); output.extend(quote! { - #[derive(Debug, ::serde::Deserialize)] - #[serde(tag = "response_type")] - #vis enum #derived_enum_ident { - #(#derived_enum_variant_fields)* - } + #[derive(Debug)] + pub struct #derived_enum_struct_ident; }); - // 3. Add the new derived enum variant structs. + // Note: We parse with a dynamic evaluation and looking at the `response_type` field, but this could be done + // by deriving #[serde(tag = "response_type")] on the enum. + let mut response_type_arms = quote!(); for variant in data.variants.iter() { - let ident = &variant.ident; - let variant_struct_name = syn::Ident::new(&format!("{}{}", original_enum_ident, ident), ident.span()); - let fields = variant.fields.iter(); - output.extend(quote! { - #[derive(Debug, ::serde::Deserialize)] - pub struct #variant_struct_name { - #(#fields),* - } + let variant_ident = variant.ident.clone(); + let variant_ident_str = syn::LitStr::new(&variant.ident.to_string(), variant.ident.span()); + let struct_ident = syn::Ident::new( + &get_enum_variant_struct_ident(variant).expect("Failed to parse enum variant"), + variant.ident.span(), + ); + response_type_arms.extend(quote! { + #variant_ident_str => Ok(#original_enum_ident::#variant_ident(serde_json::from_str::<#struct_ident>(response)?)), }); } - // 4. Declare a new struct that will be used to parse the response. - let parser_struct_ident = syn::Ident::new(&format!("{}Parser", original_enum_ident), original_enum_ident.span()); output.extend(quote! { - #[derive(Debug)] - pub struct #parser_struct_ident; + impl ::orch::response::OrchResponseVariants<#original_enum_ident> for #derived_enum_struct_ident { + fn variants(&self) -> Vec<::orch::response::ResponseOption> { + let mut options = Vec::with_capacity(#vec_capacity); + #options_vec_pushes + options + } + + fn parse(&self, response: &str) -> Result<#original_enum_ident, ::serde_json::Error> { + let dynamic_parsed = serde_json::from_str::(response)?; + let response_type = dynamic_parsed.get("response_type").unwrap().as_str().unwrap(); + match response_type { + #response_type_arms + _ => Err(::serde_json::Error::custom("Invalid response type")), + } + } + } }); - // 4. Implement the `ResponseOptions` trait for a new struct that will be used to parse the response. - let mut options_vec_pushes = quote!(); - for syn::Variant { ident, attrs, fields, .. } in data.variants.iter() { - let response_attr = attrs - .iter() - .filter_map(|attr| ResponseAttribute::from_meta(&attr.meta).ok()) - .next() - .expect("#[response] attribute not found on variant field"); - let ResponseAttribute { scenario, description } = response_attr; + output.into() +} + +pub fn response_variant_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let original_struct = parse_macro_input!(input as DeriveInput); + let DeriveInput { data, ident, attrs, .. } = original_struct.clone(); + let syn::Data::Struct(data) = data else { + panic!("#[derive(OrchResponseOption)] can only be used with structs"); + }; + let original_struct_ident = ident.clone(); - let schema_attrs = attrs + let fields = data.fields; + + // Parse the #[variant(...)] attribute. + let variant_attr = attrs + .iter() + .filter_map(|attr| VariantAttribute::from_meta(&attr.meta).ok()) + .next() + .expect("#[variant(...)] attribute not found on variant field"); + let VariantAttribute { + variant, + scenario, + description, + } = variant_attr; + + // Parse the fields used in [`orch::response::OrchResponseVariant`]. + let mut schema_fields = Vec::new(); + for variant_field in fields.iter() { + // Parse the #[schema(...)] attribute. + let schema_attr = variant_field + .attrs .iter() .filter_map(|attr| SchemaAttribute::from_meta(&attr.meta).ok()) .collect::>(); - if schema_attrs.len() != fields.len() { - panic!("Expected a single #[schema(...)] attribute for each field of the enum variant"); - } - let mut schema_fields = Vec::new(); - for variant_field in fields.iter() { - let schema_attr_for_field = schema_attrs - .iter() - .find(|attr| *variant_field.ident.as_ref().unwrap() == attr.field) - .unwrap_or_else(|| { - panic!( - "Field {} not found in #[schema(...)] attributes", - variant_field.ident.as_ref().unwrap() - ) - }); - let SchemaAttribute { - field, - description, - example, - } = schema_attr_for_field; - let typ = ast_type_to_str(&variant_field.ty).unwrap_or_else(|_| { - panic!( - "Failed to convert type to string for field `{}` of variant `{}`", - variant_field.ident.as_ref().unwrap(), - ident - ) - }); - let typ = syn::LitStr::new(&typ, variant_field.span()); - schema_fields.push(quote! { - ::orch::response::ResponseSchemaField { - name: #field.to_string(), - description: #description.to_string(), - typ: #typ.to_string(), - example: #example.to_string(), - } - }) + if schema_attr.len() != 1 { + panic!("Expected a single #[schema(...)] attribute for each field of the enum variant with the correct format and parameters"); } + let SchemaAttribute { description, example } = schema_attr.first().expect("Failed to parse schema attribute"); - let schema_fields = schema_fields.iter(); - let ident_str = syn::LitStr::new(&ident.to_string(), ident.span()); - options_vec_pushes.extend(quote! { - options.push(::orch::response::ResponseOption { - type_name: #ident_str.to_string(), - scenario: #scenario.to_string(), - description: #description.to_string(), - schema: vec![ - #(#schema_fields),* - ] - }); + let typ = ast_type_to_str(&variant_field.ty).unwrap_or_else(|_| { + panic!( + "Failed to convert type to string for field `{}` of variant `{}`", + variant_field.ident.as_ref().unwrap(), + ident + ) }); + let typ = syn::LitStr::new(&typ, variant_field.span()); + let field_ident = syn::LitStr::new(&variant_field.ident.as_ref().unwrap().to_string(), variant_field.span()); + schema_fields.push(quote! { + ::orch::response::ResponseSchemaField { + name: #field_ident.to_string(), + description: #description.to_string(), + typ: #typ.to_string(), + example: #example.to_string(), + } + }) } - let vec_capacity = data.variants.len(); - output.extend(quote! { - impl ::orch::response::ResponseOptions<#derived_enum_ident> for #parser_struct_ident { - fn options(&self) -> Vec<::orch::response::ResponseOption> { - let mut options = Vec::with_capacity(#vec_capacity); - #options_vec_pushes - options + quote! { + impl ::orch::response::OrchResponseVariant for #original_struct_ident { + fn variant() -> ::orch::response::ResponseOption { + ::orch::response::ResponseOption { + type_name: #variant.to_string(), + scenario: #scenario.to_string(), + description: #description.to_string(), + schema: vec![ + #(#schema_fields),* + ] + } } } - }); + } + .into() +} - output.into() +// Parse `Answer(AnswerResponseOption)` into `AnswerResponseOption`. +fn get_enum_variant_struct_ident(variant: &syn::Variant) -> Result { + // We expect the enum variant to look like this: `Answer(AnswerResponseOption)`, + // so we parse the `AnswerResponseOption` struct. + let syn::Fields::Unnamed(fields) = &variant.fields else { + panic!("Expected an unnamed struct for each enum variant"); + }; + let Some(syn::Field { ty, .. }) = fields.unnamed.first() else { + panic!("Expected an unnamed struct for each enum variant"); + }; + let syn::Type::Path(p) = ty else { + panic!("Expected an unnamed struct for each enum variant"); + }; + let ident = &p.path.segments.first().unwrap().ident; + Ok(ident.to_string()) } fn ast_type_to_str(ty: &syn::Type) -> Result { diff --git a/response_derive/src/lib.rs b/response_derive/src/lib.rs index ecda692..e967d49 100644 --- a/response_derive/src/lib.rs +++ b/response_derive/src/lib.rs @@ -1,20 +1,26 @@ -use quote::quote; - mod attribute_impl; mod derive_impl; -/// Used to derive the `ResponseOptions` trait for a given enum. -#[proc_macro_derive(OrchResponseOptions, attributes(response, schema))] -pub fn derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream { - derive_impl::derive_impl(input) +use quote::quote; + +/// Used to derive the `OrchResponseVariants` trait for a given enum +#[proc_macro_derive(Variants)] +pub fn derive_orch_response_variants(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + derive_impl::response_variants_derive(input) +} + +/// Used to derive the `OrchResponseVariant` trait for a given enum. +#[proc_macro_derive(Variant, attributes(variant, schema))] +pub fn derive_orch_response_variant_variant(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + derive_impl::response_variant_derive(input) } -/// Used to create a new `ResponseOptions` instance. +/// Used to construct the identifier of the derived enum. #[proc_macro] -pub fn options(input: proc_macro::TokenStream) -> proc_macro::TokenStream { +pub fn variants(input: proc_macro::TokenStream) -> proc_macro::TokenStream { // Expects the identifier of the derived enum. let enum_ident = syn::parse_macro_input!(input as syn::Ident); - let derived_enum_ident = syn::Ident::new(&format!("{}Parser", enum_ident), enum_ident.span()); + let derived_enum_ident = syn::Ident::new(&format!("{}Derived", enum_ident), enum_ident.span()); quote! { #derived_enum_ident {} }