diff --git a/engine/Cargo.lock b/engine/Cargo.lock index 633e6eccf..28bf23172 100644 --- a/engine/Cargo.lock +++ b/engine/Cargo.lock @@ -120,6 +120,12 @@ version = "1.0.86" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" +[[package]] +name = "arrayvec" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23b62fc65de8e4e7f52534fb52b0f3ed04746ae267519eef2a83941e8085068b" + [[package]] name = "askama" version = "0.12.1" @@ -2713,6 +2719,7 @@ dependencies = [ name = "internal-baml-schema-ast" version = "0.69.0" dependencies = [ + "anyhow", "baml-types", "bstd", "either", @@ -2720,9 +2727,11 @@ dependencies = [ "log", "pest", "pest_derive", + "pretty", "serde", "serde_json", "test-log", + "unindent", ] [[package]] @@ -3654,6 +3663,17 @@ dependencies = [ "termtree", ] +[[package]] +name = "pretty" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b55c4d17d994b637e2f4daf6e5dc5d660d209d5642377d675d7a1c3ab69fa579" +dependencies = [ + "arrayvec", + "typed-arena", + "unicode-width", +] + [[package]] name = "pretty_assertions" version = "1.4.0" @@ -5066,6 +5086,12 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "typed-arena" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6af6ae20167a9ece4bcb41af5b80f8a1f1df981f6391189ce00fd257af04126a" + [[package]] name = "typenum" version = "1.17.0" diff --git a/engine/baml-lib/schema-ast/Cargo.toml b/engine/baml-lib/schema-ast/Cargo.toml index 984d40645..fe64fbffc 100644 --- a/engine/baml-lib/schema-ast/Cargo.toml +++ b/engine/baml-lib/schema-ast/Cargo.toml @@ -18,13 +18,16 @@ bstd.workspace = true log = "0.4.20" serde_json.workspace = true serde.workspace = true - +anyhow.workspace = true pest = "2.1.3" pest_derive = "2.1.0" either = "1.8.1" test-log = "0.2.16" +pretty = "0.12.3" +[dev-dependencies] +unindent = "0.2.3" [features] debug_parser = [] diff --git a/engine/baml-lib/schema-ast/src/formatter/mod.rs b/engine/baml-lib/schema-ast/src/formatter/mod.rs new file mode 100644 index 000000000..29050da5e --- /dev/null +++ b/engine/baml-lib/schema-ast/src/formatter/mod.rs @@ -0,0 +1,315 @@ +use std::{ + borrow::BorrowMut, + cell::{RefCell, RefMut}, + rc::Rc, + sync::Arc, +}; + +use crate::parser::{BAMLParser, Rule}; +use anyhow::{anyhow, Result}; +use pest::{ + iterators::{Pair, Pairs}, + Parser, +}; +use pretty::RcDoc; + +const INDENT_WIDTH: isize = 4; +pub fn format_schema(source: &str) -> Result { + let mut schema = BAMLParser::parse(Rule::schema, source)?; + let schema_pair = schema.next().ok_or(anyhow!("Expected a schema"))?; + if schema_pair.as_rule() != Rule::schema { + return Err(anyhow!("Expected a schema")); + } + + let doc = schema_to_doc(schema_pair.into_inner())?; + let mut w = Vec::new(); + doc.render(10, &mut w) + .map_err(|_| anyhow!("Failed to render doc"))?; + Ok(String::from_utf8(w).map_err(|_| anyhow!("Failed to convert to string"))?) +} + +macro_rules! next_pair { + ($pairs:ident, $rule:expr) => {{ + match $pairs.peek() { + Some(pair) => { + if pair.as_rule() != $rule { + Err(anyhow!( + "Expected a {:?}, got a {:?} ({}:{})", + $rule, + pair.as_rule(), + file!(), + line!() + )) + } else { + $pairs.next(); + Ok(pair) + } + } + None => Err(anyhow!("Expected a {}", stringify!($rule))), + } + }}; + + ($pairs:ident, $rule:expr, optional) => {{ + match $pairs.peek() { + Some(pair) => { + if pair.as_rule() == $rule { + $pairs.next() + } else { + None + } + } + None => None, + } + }}; +} + +fn schema_to_doc(mut pairs: Pairs<'_, Rule>) -> Result> { + let mut doc = RcDoc::nil(); + + for pair in &mut pairs { + match pair.as_rule() { + Rule::type_expression_block => { + doc = doc.append(type_expression_block_to_doc(pair.into_inner())?); + } + Rule::EOI => { + // skip + } + _ => { + panic!("Unhandled rule: {:?}", pair.as_rule()); + } + } + } + Ok(doc) +} + +fn type_expression_block_to_doc(mut pairs: Pairs<'_, Rule>) -> Result> { + let class_or_enum = next_pair!(pairs, Rule::identifier)?; + let ident = next_pair!(pairs, Rule::identifier)?; + next_pair!(pairs, Rule::named_argument_list, optional); + next_pair!(pairs, Rule::BLOCK_OPEN)?; + let contents = next_pair!(pairs, Rule::type_expression_contents)?; + next_pair!(pairs, Rule::BLOCK_CLOSE)?; + + Ok(RcDoc::nil() + .append(pair_to_doc_text(class_or_enum)) + .append(RcDoc::space()) + .append(pair_to_doc_text(ident)) + .append(RcDoc::space()) + .append(RcDoc::text("{")) + .append( + type_expression_contents_to_doc(contents.into_inner())? + .nest(INDENT_WIDTH) + .group(), + ) + .append(RcDoc::text("}"))) +} + +fn type_expression_contents_to_doc(mut pairs: Pairs<'_, Rule>) -> Result> { + let mut content_docs = vec![]; + + for pair in &mut pairs { + match pair.as_rule() { + Rule::type_expression => { + content_docs.push(type_expression_to_doc(pair.into_inner())?); + } + Rule::block_attribute => { + content_docs.push(pair_to_doc_text(pair)); + } + Rule::comment_block => { + content_docs.push(pair_to_doc_text(pair)); + } + Rule::empty_lines => { + // skip + } + _ => { + panic!("Unhandled rule: {:?}", pair.as_rule()); + } + } + } + + let doc = if content_docs.len() > 0 { + content_docs + .into_iter() + .fold(RcDoc::hardline(), |acc, doc| { + acc.append(doc).append(RcDoc::hardline()) + }) + } else { + RcDoc::nil() + }; + + Ok(doc) +} + +fn type_expression_to_doc(mut pairs: Pairs<'_, Rule>) -> Result> { + let ident = next_pair!(pairs, Rule::identifier)?; + let field_type_chain = next_pair!(pairs, Rule::field_type_chain)?; + + let mut doc = RcDoc::nil() + .append(pair_to_doc_text(ident)) + .append(RcDoc::space()) + .append(field_type_chain_to_doc(field_type_chain.into_inner())?); + + for pair in pairs { + match pair.as_rule() { + Rule::NEWLINE => { + // skip + } + Rule::field_attribute => { + doc = doc.append(pair_to_doc_text(pair).nest(INDENT_WIDTH).group()); + } + Rule::trailing_comment => { + doc = doc.append(pair_to_doc_text(pair).nest(INDENT_WIDTH).group()); + } + _ => { + panic!("Unhandled rule: {:?}", pair.as_rule()); + } + } + } + + Ok(doc) +} + +fn field_type_chain_to_doc(pairs: Pairs<'_, Rule>) -> Result> { + let mut docs = vec![]; + + for pair in pairs { + match pair.as_rule() { + Rule::field_type_with_attr => { + docs.push(field_type_with_attr_to_doc(pair.into_inner())?); + } + Rule::field_operator => { + docs.push(RcDoc::text("|")); + } + _ => { + panic!("Unhandled rule: {:?}", pair.as_rule()); + } + } + } + + Ok(RcDoc::intersperse(docs, RcDoc::space()) + .nest(INDENT_WIDTH) + .group()) +} + +fn field_type_with_attr_to_doc(mut pairs: Pairs<'_, Rule>) -> Result> { + let mut docs = vec![]; + + for pair in &mut pairs { + match pair.as_rule() { + Rule::field_type => { + docs.push(field_type_to_doc(pair.into_inner())?); + } + Rule::field_attribute | Rule::trailing_comment => { + docs.push(pair_to_doc_text(pair)); + } + _ => { + panic!("Unhandled rule: {:?}", pair.as_rule()); + } + } + } + + Ok(RcDoc::intersperse(docs, RcDoc::space()) + .nest(INDENT_WIDTH) + .group()) +} + +fn field_type_to_doc(pairs: Pairs<'_, Rule>) -> Result> { + let mut docs = vec![]; + field_type_to_doc_impl(pairs, &mut docs)?; + Ok(docs + .into_iter() + .fold(RcDoc::nil(), |acc, doc| acc.append(doc))) +} + +fn field_type_to_doc_impl<'a>(pairs: Pairs<'a, Rule>, docs: &mut Vec>) -> Result<()> { + for pair in pairs { + match pair.as_rule() { + Rule::field_type | Rule::union => { + field_type_to_doc_impl(pair.into_inner(), docs)?; + } + Rule::field_operator => { + docs.push(RcDoc::space()); + docs.push(RcDoc::text("|")); + docs.push(RcDoc::space()); + } + Rule::base_type_with_attr | Rule::non_union => { + docs.push(pair_to_doc_text(pair)); + } + _ => { + panic!("Unhandled rule: {:?}", pair.as_rule()); + } + } + } + + Ok(()) +} + +fn pair_to_doc_text<'a>(pair: Pair<'a, Rule>) -> RcDoc<'a, ()> { + RcDoc::text(pair.as_str().trim()) +} + +#[cfg(test)] +mod tests { + use super::*; + use unindent::Unindent as _; + + #[track_caller] + fn assert_format_eq(schema: &str, expected: &str) -> Result<()> { + let formatted = format_schema(&schema.unindent().trim_end())?; + assert_eq!(expected.unindent().trim_end(), formatted); + Ok(()) + } + + #[test] + fn test_format_schema() -> anyhow::Result<()> { + assert_format_eq( + r#" + class Foo { + } + "#, + r#" + class Foo {} + "#, + )?; + + assert_format_eq( + r#" + class Foo { field1 string } + "#, + r#" + class Foo { + field1 string + } + "#, + )?; + + assert_format_eq( + r#" + class Foo { + + field1 string + } + "#, + r#" + class Foo { + field1 string + } + "#, + )?; + + assert_format_eq( + r#" + class Foo { + field1 string|int + } + "#, + r#" + class Foo { + field1 string | int + } + "#, + )?; + + Ok(()) + } +} diff --git a/engine/baml-lib/schema-ast/src/lib.rs b/engine/baml-lib/schema-ast/src/lib.rs index 241803f73..d6532dcfe 100644 --- a/engine/baml-lib/schema-ast/src/lib.rs +++ b/engine/baml-lib/schema-ast/src/lib.rs @@ -9,8 +9,11 @@ pub use self::parser::parse_schema; /// source span information. pub mod ast; +mod formatter; mod parser; +pub use formatter::format_schema; + /// Transform the input string into a valid (quoted and escaped) PSL string literal. /// /// PSL string literals have the exact same grammar as [JSON string diff --git a/engine/cli/src/commands.rs b/engine/cli/src/commands.rs index 1a8fc861d..ec907a0fb 100644 --- a/engine/cli/src/commands.rs +++ b/engine/cli/src/commands.rs @@ -35,6 +35,9 @@ pub(crate) enum Commands { #[command(about = "Deploy a BAML project to Boundary Cloud")] Deploy(crate::deploy::DeployArgs), + + #[command(about = "Format BAML source files")] + Format(crate::format::FormatArgs), } impl RuntimeCli { @@ -64,6 +67,7 @@ impl RuntimeCli { args.from = BamlRuntime::parse_baml_src_path(&args.from)?; t.block_on(async { args.run_async().await }) } + Commands::Format(args) => args.run(), } } } diff --git a/engine/cli/src/format.rs b/engine/cli/src/format.rs new file mode 100644 index 000000000..0a111b5dc --- /dev/null +++ b/engine/cli/src/format.rs @@ -0,0 +1,26 @@ +use std::{fs, path::PathBuf}; + +use anyhow::Result; +use clap::Args; +use internal_baml_core::internal_baml_schema_ast::format_schema; + +#[derive(Args, Debug)] +pub struct FormatArgs { + #[arg(long, help = "path/to/baml_src", default_value = "./baml_src")] + pub from: PathBuf, +} + +impl FormatArgs { + pub fn run(&self) -> Result<()> { + let source = fs::read_to_string(&self.from)?; + let formatted = format_schema(&source)?; + + let mut to = self.from.clone(); + to.set_extension("formatted.baml"); + fs::write(&to, formatted)?; + + log::info!("Formatted {} to {}", self.from.display(), to.display()); + + Ok(()) + } +} diff --git a/engine/cli/src/lib.rs b/engine/cli/src/lib.rs index 9f7168093..6ce4c4025 100644 --- a/engine/cli/src/lib.rs +++ b/engine/cli/src/lib.rs @@ -3,6 +3,7 @@ pub(crate) mod auth; pub(crate) mod colordiff; pub(crate) mod commands; pub(crate) mod deploy; +pub(crate) mod format; pub(crate) mod propelauth; pub(crate) mod tui;