Skip to content

Commit

Permalink
BAML formatter implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
dymk committed Dec 4, 2024
1 parent c17e0da commit f770bc8
Show file tree
Hide file tree
Showing 7 changed files with 379 additions and 1 deletion.
26 changes: 26 additions & 0 deletions engine/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion engine/baml-lib/schema-ast/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@ bstd.workspace = true
log = "0.4.20"
serde_json.workspace = true
serde.workspace = true

anyhow.workspace = true

pest = "2.1.3"
pest_derive = "2.1.0"
either = "1.8.1"
test-log = "0.2.16"
pretty = "0.12.3"

[dev-dependencies]
unindent = "0.2.3"

[features]
debug_parser = []
Expand Down
315 changes: 315 additions & 0 deletions engine/baml-lib/schema-ast/src/formatter/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,315 @@
use std::{
borrow::BorrowMut,
cell::{RefCell, RefMut},
rc::Rc,
sync::Arc,
};

use crate::parser::{BAMLParser, Rule};
use anyhow::{anyhow, Result};
use pest::{
iterators::{Pair, Pairs},
Parser,
};
use pretty::RcDoc;

const INDENT_WIDTH: isize = 4;
pub fn format_schema(source: &str) -> Result<String> {
let mut schema = BAMLParser::parse(Rule::schema, source)?;
let schema_pair = schema.next().ok_or(anyhow!("Expected a schema"))?;
if schema_pair.as_rule() != Rule::schema {
return Err(anyhow!("Expected a schema"));
}

let doc = schema_to_doc(schema_pair.into_inner())?;
let mut w = Vec::new();
doc.render(10, &mut w)
.map_err(|_| anyhow!("Failed to render doc"))?;
Ok(String::from_utf8(w).map_err(|_| anyhow!("Failed to convert to string"))?)
}

macro_rules! next_pair {
($pairs:ident, $rule:expr) => {{
match $pairs.peek() {
Some(pair) => {
if pair.as_rule() != $rule {
Err(anyhow!(
"Expected a {:?}, got a {:?} ({}:{})",
$rule,
pair.as_rule(),
file!(),
line!()
))
} else {
$pairs.next();
Ok(pair)
}
}
None => Err(anyhow!("Expected a {}", stringify!($rule))),
}
}};

($pairs:ident, $rule:expr, optional) => {{
match $pairs.peek() {
Some(pair) => {
if pair.as_rule() == $rule {
$pairs.next()
} else {
None
}
}
None => None,
}
}};
}

fn schema_to_doc(mut pairs: Pairs<'_, Rule>) -> Result<RcDoc<'_, ()>> {
let mut doc = RcDoc::nil();

for pair in &mut pairs {
match pair.as_rule() {
Rule::type_expression_block => {
doc = doc.append(type_expression_block_to_doc(pair.into_inner())?);
}
Rule::EOI => {
// skip
}
_ => {
panic!("Unhandled rule: {:?}", pair.as_rule());
}
}
}
Ok(doc)
}

fn type_expression_block_to_doc(mut pairs: Pairs<'_, Rule>) -> Result<RcDoc<'_, ()>> {
let class_or_enum = next_pair!(pairs, Rule::identifier)?;
let ident = next_pair!(pairs, Rule::identifier)?;
next_pair!(pairs, Rule::named_argument_list, optional);
next_pair!(pairs, Rule::BLOCK_OPEN)?;
let contents = next_pair!(pairs, Rule::type_expression_contents)?;
next_pair!(pairs, Rule::BLOCK_CLOSE)?;

Ok(RcDoc::nil()
.append(pair_to_doc_text(class_or_enum))
.append(RcDoc::space())
.append(pair_to_doc_text(ident))
.append(RcDoc::space())
.append(RcDoc::text("{"))
.append(
type_expression_contents_to_doc(contents.into_inner())?
.nest(INDENT_WIDTH)
.group(),
)
.append(RcDoc::text("}")))
}

fn type_expression_contents_to_doc(mut pairs: Pairs<'_, Rule>) -> Result<RcDoc<'_, ()>> {
let mut content_docs = vec![];

for pair in &mut pairs {
match pair.as_rule() {
Rule::type_expression => {
content_docs.push(type_expression_to_doc(pair.into_inner())?);
}
Rule::block_attribute => {
content_docs.push(pair_to_doc_text(pair));
}
Rule::comment_block => {
content_docs.push(pair_to_doc_text(pair));
}
Rule::empty_lines => {
// skip
}
_ => {
panic!("Unhandled rule: {:?}", pair.as_rule());
}
}
}

let doc = if content_docs.len() > 0 {
content_docs
.into_iter()
.fold(RcDoc::hardline(), |acc, doc| {
acc.append(doc).append(RcDoc::hardline())
})
} else {
RcDoc::nil()
};

Ok(doc)
}

fn type_expression_to_doc(mut pairs: Pairs<'_, Rule>) -> Result<RcDoc<'_, ()>> {
let ident = next_pair!(pairs, Rule::identifier)?;
let field_type_chain = next_pair!(pairs, Rule::field_type_chain)?;

let mut doc = RcDoc::nil()
.append(pair_to_doc_text(ident))
.append(RcDoc::space())
.append(field_type_chain_to_doc(field_type_chain.into_inner())?);

for pair in pairs {
match pair.as_rule() {
Rule::NEWLINE => {
// skip
}
Rule::field_attribute => {
doc = doc.append(pair_to_doc_text(pair).nest(INDENT_WIDTH).group());
}
Rule::trailing_comment => {
doc = doc.append(pair_to_doc_text(pair).nest(INDENT_WIDTH).group());
}
_ => {
panic!("Unhandled rule: {:?}", pair.as_rule());
}
}
}

Ok(doc)
}

fn field_type_chain_to_doc(pairs: Pairs<'_, Rule>) -> Result<RcDoc<'_, ()>> {
let mut docs = vec![];

for pair in pairs {
match pair.as_rule() {
Rule::field_type_with_attr => {
docs.push(field_type_with_attr_to_doc(pair.into_inner())?);
}
Rule::field_operator => {
docs.push(RcDoc::text("|"));
}
_ => {
panic!("Unhandled rule: {:?}", pair.as_rule());
}
}
}

Ok(RcDoc::intersperse(docs, RcDoc::space())
.nest(INDENT_WIDTH)
.group())
}

fn field_type_with_attr_to_doc(mut pairs: Pairs<'_, Rule>) -> Result<RcDoc<'_, ()>> {
let mut docs = vec![];

for pair in &mut pairs {
match pair.as_rule() {
Rule::field_type => {
docs.push(field_type_to_doc(pair.into_inner())?);
}
Rule::field_attribute | Rule::trailing_comment => {
docs.push(pair_to_doc_text(pair));
}
_ => {
panic!("Unhandled rule: {:?}", pair.as_rule());
}
}
}

Ok(RcDoc::intersperse(docs, RcDoc::space())
.nest(INDENT_WIDTH)
.group())
}

fn field_type_to_doc(pairs: Pairs<'_, Rule>) -> Result<RcDoc<'_, ()>> {
let mut docs = vec![];
field_type_to_doc_impl(pairs, &mut docs)?;
Ok(docs
.into_iter()
.fold(RcDoc::nil(), |acc, doc| acc.append(doc)))
}

fn field_type_to_doc_impl<'a>(pairs: Pairs<'a, Rule>, docs: &mut Vec<RcDoc<'a, ()>>) -> Result<()> {
for pair in pairs {
match pair.as_rule() {
Rule::field_type | Rule::union => {
field_type_to_doc_impl(pair.into_inner(), docs)?;
}
Rule::field_operator => {
docs.push(RcDoc::space());
docs.push(RcDoc::text("|"));
docs.push(RcDoc::space());
}
Rule::base_type_with_attr | Rule::non_union => {
docs.push(pair_to_doc_text(pair));
}
_ => {
panic!("Unhandled rule: {:?}", pair.as_rule());
}
}
}

Ok(())
}

fn pair_to_doc_text<'a>(pair: Pair<'a, Rule>) -> RcDoc<'a, ()> {
RcDoc::text(pair.as_str().trim())
}

#[cfg(test)]
mod tests {
use super::*;
use unindent::Unindent as _;

#[track_caller]
fn assert_format_eq(schema: &str, expected: &str) -> Result<()> {
let formatted = format_schema(&schema.unindent().trim_end())?;
assert_eq!(expected.unindent().trim_end(), formatted);
Ok(())
}

#[test]
fn test_format_schema() -> anyhow::Result<()> {
assert_format_eq(
r#"
class Foo {
}
"#,
r#"
class Foo {}
"#,
)?;

assert_format_eq(
r#"
class Foo { field1 string }
"#,
r#"
class Foo {
field1 string
}
"#,
)?;

assert_format_eq(
r#"
class Foo {
field1 string
}
"#,
r#"
class Foo {
field1 string
}
"#,
)?;

assert_format_eq(
r#"
class Foo {
field1 string|int
}
"#,
r#"
class Foo {
field1 string | int
}
"#,
)?;

Ok(())
}
}
Loading

0 comments on commit f770bc8

Please sign in to comment.