Skip to content

Commit

Permalink
Generate client docstrings from BAML source docstrings (#1177)
Browse files Browse the repository at this point in the history
Pass BAML docstrings through to generated code.

## Examples

For this BAML code:
```
/// A Martian organism with an age.
/// Such a nice type.
class Martian {
  /// The age of the Martian in Mars years.
  /// So many Mars years.
  age int @check(young_enough, {{ this < 30 }})
}

/// An enum with three values,
/// ONE, TWO and THREE.
enum EnumOutput {

  /// The first enum.
  ONE

  /// The second enum.
  TWO
  THREE

  @@alias("VALUE_ENUM")
}
```

We get this Python code:

```
class Martian(BaseModel):
    """A Martian organism with an age.
    Such a nice type."""
    age: Checked[int,Literal["young_enough"]]
    """The age of the Martian in Mars years.
    So many Mars years."""

class EnumOutput(str, Enum):
    """An enum with three values,
    ONE, TWO and THREE."""
    
    ONE = "ONE"
    """The first enum."""
    TWO = "TWO"
    """The second enum."""
    THREE = "THREE"
```

This TypeScript code:
```
/**
 * A Martian organism with an age.
 * Such a nice type.
 */
export interface Martian {
  /**
   * The age of the Martian in Mars years.
   * So many Mars years.
   */
  age: Checked<number,"young_enough">
  
}

/**
 * An enum with three values,
 * ONE, TWO and THREE.
 */
export enum EnumOutput {
  /**
   * The first enum.
   */
  ONE = "ONE",
  /**
   * The second enum.
   */
  TWO = "TWO",
  THREE = "THREE",
}
```

And this Ruby code:
```
    # A Martian organism with an age.
    # Such a nice type.
    class Martian < T::Struct
      include Baml::Sorbet::Struct
      # The age of the Martian in Mars years.
      # So many Mars years.
      const :age, Baml::Checked[Integer]

      def initialize(props)
        super(
          age: props[:age],
        )

        @props = props
      end
    end

    class EnumOutput < T::Enum
      # An enum with three values,
      # ONE, TWO and THREE.
      enums do
        ONE = new("ONE")
        TWO = new("TWO")
        THREE = new("THREE")
      end
    end
```

Testing done:
- Python, TypeScript, and Ruby integration tests.
- Hovering the mouse over a {class,field,enum,variant} gives nice IntelliSense in
VSCode.

<!-- ELLIPSIS_HIDDEN -->


> [!IMPORTANT]
> This PR generates client docstrings from BAML source docstrings for
Python, Ruby, and TypeScript by extracting and rendering them in
generated code.
> 
>   - **Behavior**:
> - Extracts docstrings from BAML source and includes them in generated
client code for Python, Ruby, and TypeScript.
> - Updates `Field` and `Class` structures in `repr.rs` to include
`docstring` field.
> - Adds `get_documentation()` method to `ClassWalker` and `FieldWalker`
in `class.rs` and `field.rs`.
>   - **Templates**:
> - Updates Python templates (`types.py.j2`, `partial_types.py.j2`) to
render docstrings.
> - Updates Ruby templates (`types.rb.j2`, `partial-types.rb.j2`) to
render docstrings.
> - Updates TypeScript templates (`types.ts.j2`) to render docstrings.
>   - **Tests**:
> - Adds tests in `repr.rs` and `parse_schema.rs` to verify docstring
extraction and rendering.
>     - Updates integration test files to include docstring examples.
>   - **Misc**:
>     - Makes `Comment` struct public in `comment.rs`.
>     - Updates `comments.mdx` to reflect docstring support.
> 
> <sup>This description was created by </sup>[<img alt="Ellipsis"
src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup>
for 50a607d. It will automatically
update as commits are pushed.</sup>


<!-- ELLIPSIS_HIDDEN -->
  • Loading branch information
imalsogreg authored Nov 16, 2024
1 parent 7689ce7 commit 170ece9
Show file tree
Hide file tree
Showing 34 changed files with 416 additions and 1,139 deletions.
2 changes: 1 addition & 1 deletion engine/baml-lib/baml-core/src/ir/json_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ impl WithJsonSchema for Walker<'_, &Enum> {
"enum": self.elem().values
.iter()
.map(|v| json!({
"const": v.elem.0.clone()
"const": v.0.elem.0.clone()
}))
.collect::<Vec<_>>(),

Expand Down
82 changes: 77 additions & 5 deletions engine/baml-lib/baml-core/src/ir/repr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,7 @@ impl WithRepr<TemplateString> for TemplateStringWalker<'_> {
.map(|f| Field {
name: id.name().to_string(),
r#type: f,
docstring: None,
})
.ok()
})
Expand All @@ -638,7 +639,9 @@ pub struct EnumValue(pub String);
#[derive(serde::Serialize, Debug)]
pub struct Enum {
pub name: EnumId,
pub values: Vec<Node<EnumValue>>,
pub values: Vec<(Node<EnumValue>, Option<Docstring>)>,
/// Docstring.
pub docstring: Option<Docstring>,
}

impl WithRepr<EnumValue> for EnumValueWalker<'_> {
Expand Down Expand Up @@ -673,18 +676,20 @@ impl WithRepr<Enum> for EnumWalker<'_> {
fn repr(&self, db: &ParserDatabase) -> Result<Enum> {
Ok(Enum {
name: self.name().to_string(),
values: self
.values()
.map(|v| v.node(db))
.collect::<Result<Vec<_>>>()?,
values: self.values().map(|w| (w.node(db).map(|v| (v, w.documentation().map(|s| Docstring(s.to_string())))))).collect::<Result<Vec<_>,_>>()?,
docstring: self.get_documentation().map(|s| Docstring(s))
})
}
}

/// Documentation text attached to an IR item (a class, a field, an enum, or
/// an enum value). Wraps the raw docstring as an owned `String`.
#[derive(serde::Serialize, Debug)]
pub struct Docstring(pub String);

/// A named, typed field in the IR, together with its optional docstring.
#[derive(serde::Serialize, Debug)]
pub struct Field {
/// Name of the field.
pub name: String,
/// Type of the field, wrapped in a `Node`.
pub r#type: Node<FieldType>,
/// Docstring attached to the field, if any.
pub docstring: Option<Docstring>,
}

impl WithRepr<Field> for FieldWalker<'_> {
Expand Down Expand Up @@ -714,8 +719,10 @@ impl WithRepr<Field> for FieldWalker<'_> {
.repr(db)?,
attributes: self.attributes(db),
},
docstring: self.get_documentation().map(|s| Docstring(s)),
})
}

}

type ClassId = String;
Expand All @@ -731,6 +738,9 @@ pub struct Class {

/// Parameters to the class definition.
pub inputs: Vec<(String, FieldType)>,

/// Docstring.
pub docstring: Option<Docstring>,
}

impl WithRepr<Class> for ClassWalker<'_> {
Expand Down Expand Up @@ -764,6 +774,7 @@ impl WithRepr<Class> for ClassWalker<'_> {
.collect::<Result<Vec<_>>>()?,
None => Vec::new(),
},
docstring: self.get_documentation().map(|s| Docstring(s))
})
}
}
Expand Down Expand Up @@ -1204,3 +1215,64 @@ pub fn make_test_ir(source_code: &str) -> anyhow::Result<IntermediateRepr> {
)?;
Ok(ir)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::ir_helpers::IRHelper;

    /// Returns the text of a docstring, panicking when the docstring is absent.
    fn doc_text(doc: Option<&Docstring>) -> &str {
        doc.unwrap().0.as_str()
    }

    /// Docstrings written on classes, class fields, enums and enum values in
    /// BAML source must be carried over into the intermediate representation.
    #[test]
    fn test_docstrings() {
        let ir = make_test_ir(r#"
/// Foo class.
class Foo {
/// Bar field.
bar string
/// Baz field.
baz int
}
/// Test enum.
enum TestEnum {
/// First variant.
FIRST
/// Second variant.
SECOND
THIRD
}
"#).unwrap();

        // Class docstring, then one docstring per field.
        let foo = ir.find_class("Foo").as_ref().unwrap().clone().elem();
        assert_eq!(doc_text(foo.docstring.as_ref()), "Foo class.");
        if let [bar, baz] = foo.static_fields.as_slice() {
            assert_eq!(doc_text(bar.elem.docstring.as_ref()), "Bar field.");
            assert_eq!(doc_text(baz.elem.docstring.as_ref()), "Baz field.");
        } else {
            panic!("Expected 2 fields");
        }

        // Enum docstring, then each value paired with its optional docstring.
        let test_enum = ir.find_enum("TestEnum").as_ref().unwrap().clone().elem();
        assert_eq!(doc_text(test_enum.docstring.as_ref()), "Test enum.");
        if let [(first, first_doc), (second, second_doc), (third, third_doc)] =
            test_enum.values.as_slice()
        {
            assert_eq!(first.elem.0, "FIRST");
            assert_eq!(doc_text(first_doc.as_ref()), "First variant.");
            assert_eq!(second.elem.0, "SECOND");
            assert_eq!(doc_text(second_doc.as_ref()), "Second variant.");
            assert_eq!(third.elem.0, "THIRD");
            // The last value carries no doc comment in the source above.
            assert!(third_doc.is_none());
        } else {
            panic!("Expected 3 enum values");
        }
    }
}
6 changes: 3 additions & 3 deletions engine/baml-lib/baml-core/src/ir/walker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ impl<'a> Walker<'a, &'a Enum> {
pub fn walk_values(&'a self) -> impl Iterator<Item = Walker<'a, &'a EnumValue>> {
self.item.elem.values.iter().map(|v| Walker {
db: self.db,
item: v,
item: &v.0,
})
}

Expand All @@ -142,10 +142,10 @@ impl<'a> Walker<'a, &'a Enum> {
.elem
.values
.iter()
.find(|v| v.elem.0 == name)
.find(|v| v.0.elem.0 == name)
.map(|v| Walker {
db: self.db,
item: v,
item: &v.0,
})
}

Expand Down
6 changes: 6 additions & 0 deletions engine/baml-lib/parser-database/src/walkers/class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ impl<'db> ClassWalker<'db> {
})
}


/// Returns an owned copy of the class's docstring text, or `None` when the
/// class block has no documentation attached.
pub fn get_documentation(&self) -> Option<String> {
    let block = self.ast_type_block();
    let comment = block.documentation.as_ref()?;
    Some(comment.text.clone())
}

/// Registers this class with the given set of predefined Jinja types.
pub fn add_to_types(self, types: &mut internal_baml_jinja_types::PredefinedTypes) {
types.add_class(
Expand Down
5 changes: 5 additions & 0 deletions engine/baml-lib/parser-database/src/walkers/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ impl<'db> FieldWalker<'db> {

result
}

/// Returns an owned copy of the field's docstring text, or `None` when the
/// field has no documentation attached.
pub fn get_documentation(&self) -> Option<String> {
    let field = self.ast_field();
    let comment = field.documentation.as_ref()?;
    Some(comment.text.clone())
}
}

impl<'db> WithName for FieldWalker<'db> {
Expand Down
2 changes: 1 addition & 1 deletion engine/baml-lib/schema-ast/src/ast/comment.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct Comment {
pub struct Comment {
pub text: String,
}
2 changes: 1 addition & 1 deletion engine/baml-lib/schema-ast/src/ast/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub struct Field<T> {
/// ^^^^^^^^^^^
/// name String @id @default("lol")
/// ```
pub(crate) documentation: Option<Comment>,
pub documentation: Option<Comment>,
/// The attributes of this field.
///
/// ```ignore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ pub struct TypeExpressionBlock {
/// Value2
/// }
/// ```
pub(crate) documentation: Option<Comment>,
pub documentation: Option<Comment>,
/// The location of this enum in the text representation.
pub span: Span,

Expand Down
101 changes: 100 additions & 1 deletion engine/baml-lib/schema-ast/src/parser/parse_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,9 @@ pub fn parse_schema(
// Some(Rule::enum_declaration) => {
// pending_block_comment = Some(current);
// }
_ => (),
_ => {
pending_block_comment = Some(current);
}
}
}
// We do nothing here.
Expand Down Expand Up @@ -244,6 +246,103 @@ mod tests {
let result = parse_schema(&root_path.into(), &source).unwrap();
assert_eq!(result.1.errors().len(), 0);
}

/// Doc comments (`///`) on classes, enums, fields and enum variants must be
/// attached to the matching AST nodes, with multi-line comments joined by `\n`.
#[test]
fn test_comments() {
    // Extracts the text of an optional doc comment, panicking if it is missing.
    macro_rules! doc_text {
        ($doc:expr) => {
            $doc.as_ref().unwrap().text.as_str()
        };
    }

    let input = r##"
/// Doc comment for Foo
/// has multiple lines
class Foo {
/// A nice bar.
bar int
/// Followed by a
/// multiline baz.
baz string
}
/// Documented enum.
enum E {
/// Documented variant.
EFoo
/// Another documented variant.
EBar
EBaz
}
"##;
    let root_path = "a.baml";
    let source = SourceFile::new_static(root_path.into(), input);
    let schema = parse_schema(&root_path.into(), &source).unwrap().0;
    let mut tops = schema.iter_tops();

    // First top-level item: the documented class.
    let foo_top = tops.next().unwrap().1;
    if let Top::Class(TypeExpressionBlock {
        name,
        fields,
        documentation,
        ..
    }) = foo_top
    {
        assert_eq!(name.to_string().as_str(), "Foo");
        assert_eq!(
            doc_text!(documentation),
            "Doc comment for Foo\nhas multiple lines"
        );
        if let [bar, baz] = fields.as_slice() {
            assert_eq!(doc_text!(bar.documentation), "A nice bar.");
            assert_eq!(
                doc_text!(baz.documentation),
                "Followed by a\nmultiline baz."
            );
        } else {
            panic!("Expected exactly 2 fields");
        }
    } else {
        panic!("Expected class.")
    }

    // Second top-level item: the documented enum.
    let e_top = tops.next().unwrap().1;
    if let Top::Enum(TypeExpressionBlock {
        name,
        fields,
        documentation,
        ..
    }) = e_top
    {
        assert_eq!(name.to_string().as_str(), "E");
        assert_eq!(doc_text!(documentation), "Documented enum.");
        if let [e_foo, e_bar, e_baz] = fields.as_slice() {
            assert_eq!(doc_text!(e_foo.documentation), "Documented variant.");
            assert_eq!(
                doc_text!(e_bar.documentation),
                "Another documented variant."
            );
            // The last variant has no doc comment in the input.
            assert!(e_baz.documentation.is_none());
        } else {
            panic!("Expected exactly 3 enum variants");
        }
    } else {
        panic!("Expected enum. got {e_top:?}")
    }
}
}

fn get_expected_from_error(positives: &[Rule]) -> String {
Expand Down
2 changes: 1 addition & 1 deletion engine/language_client_codegen/src/openapi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ impl<'ir> TryFrom<EnumWalker<'ir>> for TypeSpecWithMeta {
.elem
.values
.iter()
.map(|v| v.elem.0.to_string())
.map(|v| v.0.elem.0.to_string())
.collect(),
),
r#const: None,
Expand Down
Loading

0 comments on commit 170ece9

Please sign in to comment.