diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index aaef43ddb..891a3a729 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -34,6 +34,9 @@ RUN curl https://mise.run | sh \ # Install Rust RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y +# Install WASM tools +RUN cargo install wasm-bindgen-cli@0.2.92 wasm-pack + # Install Infisical RUN curl -1sLf 'https://dl.cloudsmith.io/public/infisical/infisical-cli/setup.deb.sh' | sudo -E bash \ && sudo apt update && sudo apt install -y infisical diff --git a/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs b/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs index f357144e9..8cb95463c 100644 --- a/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs +++ b/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs @@ -177,7 +177,9 @@ impl ArgCoercer { (FieldType::Enum(name), _) => match value { BamlValue::String(s) => { if let Ok(e) = ir.find_enum(name) { - if e.walk_values().any(|v| v.item.elem.0 == *s) { + if e.walk_values().any(|v| v.item.elem.0 == *s) + || e.item.attributes.get("dynamic_type").is_some() + { Ok(BamlValue::Enum(name.to_string(), s.to_string())) } else { scope.push_error(format!( diff --git a/engine/baml-lib/baml-core/src/ir/repr.rs b/engine/baml-lib/baml-core/src/ir/repr.rs index affd6872b..1f8140630 100644 --- a/engine/baml-lib/baml-core/src/ir/repr.rs +++ b/engine/baml-lib/baml-core/src/ir/repr.rs @@ -2,8 +2,8 @@ use std::collections::HashSet; use anyhow::{anyhow, Result}; use baml_types::{ - Constraint, ConstraintLevel, FieldType, JinjaExpression, Resolvable, StreamingBehavior, StringOr, - UnresolvedValue, + Constraint, ConstraintLevel, FieldType, JinjaExpression, Resolvable, StreamingBehavior, + StringOr, UnresolvedValue, }; use either::Either; use indexmap::{IndexMap, IndexSet}; @@ -15,7 +15,9 @@ use internal_baml_parser_database::{ Attributes, ParserDatabase, PromptAst, RetryPolicyStrategy, TypeWalker, }; -use 
internal_baml_schema_ast::ast::{self, Attribute, FieldArity, SubType, ValExpId, WithName, WithSpan}; +use internal_baml_schema_ast::ast::{ + self, Attribute, FieldArity, SubType, ValExpId, WithName, WithSpan, +}; use internal_llm_client::{ClientProvider, ClientSpec, UnresolvedClientProperty}; use serde::Serialize; @@ -179,6 +181,9 @@ impl IntermediateRepr { db: &ParserDatabase, configuration: Configuration, ) -> Result { + // TODO: We're iterating over the AST tops once for every property in + // the IR. Easy performance optimization here by iterating only one time + // and distributing the tops to the appropriate IR properties. let mut repr = IntermediateRepr { enums: db .walk_enums() @@ -347,10 +352,7 @@ fn to_ir_attributes( }); let streaming_done = streaming_done.as_ref().and_then(|v| { if *v { - Some(( - "stream.done".to_string(), - UnresolvedValue::Bool(true, ()), - )) + Some(("stream.done".to_string(), UnresolvedValue::Bool(true, ()))) } else { None } @@ -594,7 +596,6 @@ impl WithRepr for ast::FieldType { ), }; - let use_metadata = has_constraints || has_special_streaming_behavior; let with_constraints = if use_metadata { FieldType::WithMetadata { @@ -609,30 +610,6 @@ impl WithRepr for ast::FieldType { } } -// #[derive(serde::Serialize, Debug)] -// pub enum Identifier { -// /// Starts with env.* -// ENV(String), -// /// The path to a Local Identifer + the local identifer. Separated by '.' -// #[allow(dead_code)] -// Ref(Vec), -// /// A string without spaces or '.' Always starts with a letter. May contain numbers -// Local(String), -// /// Special types (always lowercase). 
-// Primitive(baml_types::TypeValue), -// } - -// impl Identifier { -// pub fn name(&self) -> String { -// match self { -// Identifier::ENV(k) => k.clone(), -// Identifier::Ref(r) => r.join("."), -// Identifier::Local(l) => l.clone(), -// Identifier::Primitive(p) => p.to_string(), -// } -// } -// } - type TemplateStringId = String; #[derive(Debug)] @@ -717,7 +694,15 @@ impl WithRepr for EnumWalker<'_> { fn repr(&self, db: &ParserDatabase) -> Result { Ok(Enum { - name: self.name().to_string(), + // TODO: #1343 Temporary solution until we implement scoping in the AST. + name: if self.ast_type_block().is_dynamic_type_def { + self.name() + .strip_prefix(ast::DYNAMIC_TYPE_NAME_PREFIX) + .unwrap() + .to_string() + } else { + self.name().to_string() + }, values: self .values() .map(|w| { @@ -803,7 +788,15 @@ impl WithRepr for ClassWalker<'_> { fn repr(&self, db: &ParserDatabase) -> Result { Ok(Class { - name: self.name().to_string(), + // TODO: #1343 Temporary solution until we implement scoping in the AST. + name: if self.ast_type_block().is_dynamic_type_def { + self.name() + .strip_prefix(ast::DYNAMIC_TYPE_NAME_PREFIX) + .unwrap() + .to_string() + } else { + self.name().to_string() + }, static_fields: self .static_fields() .map(|e| e.node(db)) @@ -1118,6 +1111,21 @@ impl WithRepr for ConfigurationWalker<'_> { } } +// TODO: #1343 Temporary solution until we implement scoping in the AST. +#[derive(Debug)] +pub enum TypeBuilderEntry { + Enum(Node), + Class(Node), + TypeAlias(Node), +} + +// TODO: #1343 Temporary solution until we implement scoping in the AST. 
+#[derive(Debug)] +pub struct TestTypeBuilder { + pub entries: Vec, + pub structural_recursive_alias_cycles: Vec>, +} + #[derive(serde::Serialize, Debug)] pub struct TestCaseFunction(String); @@ -1133,6 +1141,7 @@ pub struct TestCase { pub functions: Vec>, pub args: IndexMap>, pub constraints: Vec, + pub type_builder: TestTypeBuilder, } impl WithRepr for (&ConfigurationWalker<'_>, usize) { @@ -1180,6 +1189,69 @@ impl WithRepr for ConfigurationWalker<'_> { let functions = (0..self.test_case().functions.len()) .map(|i| (self, i).node(db)) .collect::>>()?; + + // TODO: #1343 Temporary solution until we implement scoping in the AST. + let enums = self + .test_case() + .type_builder_scoped_db + .walk_enums() + .filter(|e| { + self.test_case().type_builder_scoped_db.ast()[e.id].is_dynamic_type_def + || db.find_type_by_str(e.name()).is_none() + }) + .map(|e| e.node(&self.test_case().type_builder_scoped_db)) + .collect::>>>()?; + let classes = self + .test_case() + .type_builder_scoped_db + .walk_classes() + .filter(|c| { + self.test_case().type_builder_scoped_db.ast()[c.id].is_dynamic_type_def + || db.find_type_by_str(c.name()).is_none() + }) + .map(|c| c.node(&self.test_case().type_builder_scoped_db)) + .collect::>>>()?; + let type_aliases = self + .test_case() + .type_builder_scoped_db + .walk_type_aliases() + .filter(|a| db.find_type_by_str(a.name()).is_none()) + .map(|a| a.node(&self.test_case().type_builder_scoped_db)) + .collect::>>>()?; + let mut type_builder_entries = Vec::new(); + + for e in enums { + type_builder_entries.push(TypeBuilderEntry::Enum(e)); + } + for c in classes { + type_builder_entries.push(TypeBuilderEntry::Class(c)); + } + for a in type_aliases { + type_builder_entries.push(TypeBuilderEntry::TypeAlias(a)); + } + + let mut recursive_aliases = vec![]; + for cycle in self + .test_case() + .type_builder_scoped_db + .recursive_alias_cycles() + { + let mut component = IndexMap::new(); + for id in cycle { + let alias = 
&self.test_case().type_builder_scoped_db.ast()[*id]; + // Those are global cycles, skip. + if db.find_type_by_str(alias.name()).is_some() { + continue; + } + // Cycles defined in the scoped test type builder block. + component.insert( + alias.name().to_string(), + alias.value.repr(&self.test_case().type_builder_scoped_db)?, + ); + } + recursive_aliases.push(component); + } + Ok(TestCase { name: self.name().to_string(), args: self @@ -1195,9 +1267,14 @@ impl WithRepr for ConfigurationWalker<'_> { .constraints .into_iter() .collect::>(), + type_builder: TestTypeBuilder { + entries: type_builder_entries, + structural_recursive_alias_cycles: recursive_aliases, + }, }) } } + #[derive(Debug, Clone, Serialize)] pub enum Prompt { // The prompt stirng, and a list of input replacer keys (raw key w/ magic string, and key to replace with) @@ -1440,7 +1517,6 @@ mod tests { let alias = class.find_field("field").unwrap(); assert_eq!(*alias.r#type(), FieldType::Primitive(TypeValue::Int)); - } #[test] @@ -1461,7 +1537,10 @@ mod tests { let class = ir.find_class("Test").unwrap(); let alias = class.find_field("field").unwrap(); - let FieldType::WithMetadata { base, constraints, .. } = alias.r#type() else { + let FieldType::WithMetadata { + base, constraints, .. 
+ } = alias.r#type() + else { panic!( "expected resolved constrained type, found {:?}", alias.r#type() diff --git a/engine/baml-lib/baml-core/src/ir/walker.rs b/engine/baml-lib/baml-core/src/ir/walker.rs index 624b224ce..57bc1d0cd 100644 --- a/engine/baml-lib/baml-core/src/ir/walker.rs +++ b/engine/baml-lib/baml-core/src/ir/walker.rs @@ -9,8 +9,8 @@ use internal_llm_client::ClientSpec; use std::collections::{HashMap, HashSet}; use super::{ - repr::{self, FunctionConfig, WithRepr}, - Class, Client, Enum, EnumValue, Field, FunctionNode, IRHelper, Impl, RetryPolicy, + repr::{self, FunctionConfig, TypeBuilderEntry, WithRepr}, + Class, Client, Enum, EnumValue, Field, FieldType, FunctionNode, IRHelper, Impl, RetryPolicy, TemplateString, TestCase, TypeAlias, Walker, }; use crate::ir::jinja_helpers::render_expression; @@ -224,6 +224,21 @@ impl<'a> Walker<'a, (&'a FunctionNode, &'a TestCase)> { .collect() } + // TODO: #1343 Temporary solution until we implement scoping in the AST. + pub fn type_builder_contents(&self) -> &[TypeBuilderEntry] { + &self.item.1.elem.type_builder.entries + } + + // TODO: #1343 Temporary solution until we implement scoping in the AST. 
+ pub fn type_builder_recursive_aliases(&self) -> &[IndexMap] { + &self + .item + .1 + .elem + .type_builder + .structural_recursive_alias_cycles + } + pub fn function(&'a self) -> Walker<'a, &'a FunctionNode> { Walker { db: self.db, diff --git a/engine/baml-lib/baml-core/src/lib.rs b/engine/baml-lib/baml-core/src/lib.rs index f6936829e..4563f7d76 100644 --- a/engine/baml-lib/baml-core/src/lib.rs +++ b/engine/baml-lib/baml-core/src/lib.rs @@ -3,10 +3,13 @@ #![allow(clippy::derive_partial_eq_without_eq)] pub use internal_baml_diagnostics; +use internal_baml_parser_database::TypeWalker; pub use internal_baml_parser_database::{self}; +use internal_baml_schema_ast::ast::{Identifier, WithName}; pub use internal_baml_schema_ast::{self, ast}; +use ir::repr::WithRepr; use rayon::prelude::*; use std::{ path::{Path, PathBuf}, @@ -98,6 +101,9 @@ pub fn validate(root_path: &Path, files: Vec) -> ValidatedSchema { // Some last linker stuff can only happen post validation. db.finalize(&mut diagnostics); + // TODO: #1343 Temporary solution until we implement scoping in the AST. + validate_type_builder_blocks(&mut diagnostics, &mut db, &configuration); + ValidatedSchema { db, diagnostics, @@ -105,6 +111,163 @@ pub fn validate(root_path: &Path, files: Vec) -> ValidatedSchema { } } +/// TODO: This is a very ugly hack to implement scoping for type builder blocks +/// in test cases. Type builder blocks support all the type definitions (class, +/// enum, type alias), and all these definitions have access to both the global +/// and local scope but not the scope of other test cases. +/// +/// This codebase was not designed with scoping in mind, so there's no simple +/// way of implementing scopes in the AST and IR. +/// +/// # Hack Explanation +/// +/// For every single type_builder block within a test we are creating a separate +/// instance of [`internal_baml_parser_database::ParserDatabase`] that includes +/// both the global type defs and the local type builder defs in the same AST. 
+/// That way we can run all the validation logic that we normally execute for +/// the global scope but including the local scope as well, and it doesn't +/// become too complicated to figure out stuff like name resolution, alias +/// resolution, dependencies, etc. +/// +/// However, this increases memory usage significantly since we create a copy +/// of the entire AST for every single type builder block. We implemented it +/// this way because we wanted to ship the feature and delay AST refactoring, +/// since it would take much longer to refactor the AST and include scoping than +/// it would take to ship this hack. +fn validate_type_builder_blocks( + diagnostics: &mut Diagnostics, + db: &mut internal_baml_parser_database::ParserDatabase, + configuration: &Configuration, +) { + let mut test_case_scoped_dbs = Vec::new(); + for test in db.walk_test_cases() { + let mut scoped_db = internal_baml_parser_database::ParserDatabase::new(); + scoped_db.add_ast(db.ast().to_owned()); + + let Some(type_builder) = test.test_case().type_builder.as_ref() else { + continue; + }; + + let mut local_ast = ast::SchemaAst::new(); + for type_def in &type_builder.entries { + local_ast.tops.push(match type_def { + ast::TypeBuilderEntry::Class(c) => { + if c.attributes.iter().any(|attr| attr.name.name() == "dynamic") { + diagnostics.push_error(DatamodelError::new_validation_error( + "The `@@dynamic` attribute is not allowed in type_builder blocks", + c.span.to_owned(), + )); + continue; + } + + ast::Top::Class(c.to_owned()) + }, + ast::TypeBuilderEntry::Enum(e) => { + if e.attributes.iter().any(|attr| attr.name.name() == "dynamic") { + diagnostics.push_error(DatamodelError::new_validation_error( + "The `@@dynamic` attribute is not allowed in type_builder blocks", + e.span.to_owned(), + )); + continue; + } + + ast::Top::Enum(e.to_owned()) + }, + ast::TypeBuilderEntry::Dynamic(d) => { + if d.attributes.iter().any(|attr| attr.name.name() == "dynamic") { + 
diagnostics.push_error(DatamodelError::new_validation_error( + "Dynamic type definitions cannot contain the `@@dynamic` attribute", + d.span.to_owned(), + )); + continue; + } + + let mut dyn_type = d.to_owned(); + + // TODO: Extremely ugly hack to avoid collisions in the name + // interner. We use syntax that is not normally allowed by + // BAML for type names. + dyn_type.name = Identifier::Local( + format!("{}{}", ast::DYNAMIC_TYPE_NAME_PREFIX, dyn_type.name()), + dyn_type.span.to_owned(), + ); + + dyn_type.is_dynamic_type_def = true; + + // Resolve dynamic definition. It either appends to a + // @@dynamic class or enum. + match db.find_type_by_str(d.name()) { + Some(t) => match t { + TypeWalker::Class(cls) => { + if !cls.ast_type_block().attributes.iter().any(|attr| attr.name.name() == "dynamic") { + diagnostics.push_error(DatamodelError::new_validation_error( + &format!( + "Type '{}' does not contain the `@@dynamic` attribute so it cannot be modified in a type builder block", + cls.name() + ), + dyn_type.span.to_owned(), + )); + continue; + } + + ast::Top::Class(dyn_type) + }, + TypeWalker::Enum(enm) => { + if !enm.ast_type_block().attributes.iter().any(|attr| attr.name.name() == "dynamic") { + diagnostics.push_error(DatamodelError::new_validation_error( + &format!( + "Type '{}' does not contain the `@@dynamic` attribute so it cannot be modified in a type builder block", + enm.name() + ), + dyn_type.span.to_owned(), + )); + continue; + } + + ast::Top::Enum(dyn_type) + }, + TypeWalker::TypeAlias(_) => { + diagnostics.push_error(DatamodelError::new_validation_error( + &format!("The `dynamic` keyword only works on classes and enums, but type '{}' is a type alias", d.name()), + d.span.to_owned(), + )); + continue; + }, + }, + None => { + diagnostics.push_error(DatamodelError::new_validation_error( + &format!("Type '{}' not found", dyn_type.name()), + dyn_type.span.to_owned(), + )); + continue; + } + } + } + ast::TypeBuilderEntry::TypeAlias(assignment) => { +
ast::Top::TypeAlias(assignment.to_owned()) + }, + }); + } + + scoped_db.add_ast(local_ast); + + if let Err(d) = scoped_db.validate(diagnostics) { + diagnostics.push(d); + continue; + } + validate::validate(&scoped_db, configuration.preview_features(), diagnostics); + if diagnostics.has_errors() { + continue; + } + scoped_db.finalize(diagnostics); + + test_case_scoped_dbs.push((test.id.0, scoped_db)); + } + for (test_id, scoped_db) in test_case_scoped_dbs.into_iter() { + db.add_test_case_db(test_id, scoped_db); + } +} + /// Loads all configuration blocks from a datamodel using the built-in source definitions. pub fn validate_single_file( root_path: &Path, diff --git a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs index 9e2bf6d83..0f57a6b31 100644 --- a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs +++ b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs @@ -7,7 +7,7 @@ use std::{ use internal_baml_diagnostics::DatamodelError; use internal_baml_parser_database::{Tarjan, TypeWalker}; use internal_baml_schema_ast::ast::{ - FieldType, SchemaAst, TypeAliasId, TypeExpId, WithName, WithSpan, + self, FieldType, SchemaAst, TypeAliasId, TypeExpId, WithName, WithSpan, }; use crate::validate::validation_pipeline::context::Context; @@ -133,7 +133,17 @@ where for component in &components { let cycle = component .iter() - .map(|id| ctx.db.ast()[*id].name().to_string()) + .map(|id| { + // TODO: #1343 Temporary solution until we implement scoping in the AST. 
+ let name = ctx.db.ast()[*id].name().to_string(); + if name.starts_with(ast::DYNAMIC_TYPE_NAME_PREFIX) { + name.strip_prefix(ast::DYNAMIC_TYPE_NAME_PREFIX) + .map(ToOwned::to_owned) + .unwrap() + } else { + name + } + }) .collect::>() .join(" -> "); @@ -167,6 +177,17 @@ fn insert_required_class_deps( match ctx.db.find_type_by_str(ident.name()) { Some(TypeWalker::Class(class)) => { deps.insert(class.id); + + // TODO: #1343 Temporary solution until we implement scoping in the AST. + if !class.name().starts_with(ast::DYNAMIC_TYPE_NAME_PREFIX) { + let dyn_def_name = + format!("{}{}", ast::DYNAMIC_TYPE_NAME_PREFIX, class.name()); + if let Some(TypeWalker::Class(dyn_def)) = + ctx.db.find_type_by_str(&dyn_def_name) + { + deps.insert(dyn_def.id); + } + } } Some(TypeWalker::TypeAlias(alias)) => { // This code runs after aliases are already resolved. diff --git a/engine/baml-lib/baml/tests/validation_files/tests/dynamic_types.baml b/engine/baml-lib/baml/tests/validation_files/tests/dynamic_types.baml new file mode 100644 index 000000000..2df0d9280 --- /dev/null +++ b/engine/baml-lib/baml/tests/validation_files/tests/dynamic_types.baml @@ -0,0 +1,53 @@ +class Resume { + name string + education Education[] + skills string[] + @@dynamic // This class is marked as @@dynamic. +} + +class Education { + school string + degree string + year int +} + +// This function returns the dynamic class defined above. +function ExtractResume(from_text: string) -> Resume { + client "openai/gpt-4o-mini" + prompt #"Hello"# +} + +test ReturnDynamicClassTest { + functions [ExtractResume] + // New type_builder block used to define types and inject dynamic props. + type_builder { + // Defines a new type available only within this test block. + class Experience { + title string + company string + start_date string + end_date string + } + + // This `dynamic` block is used to inject new properties into the + // `@@dynamic` part of the Resume class. 
+ dynamic Resume { + experience Experience[] + } + } + args { + from_text #" + John Doe + + Education + - University of California, Berkeley, B.S. in Computer Science, 2020 + + Experience + - Software Engineer, Boundary, Sep 2022 - Sep 2023 + + Skills + - Python + - Java + "# + } +} diff --git a/engine/baml-lib/baml/tests/validation_files/tests/dynamic_types_external_cycle_errors.baml b/engine/baml-lib/baml/tests/validation_files/tests/dynamic_types_external_cycle_errors.baml new file mode 100644 index 000000000..d0bf67621 --- /dev/null +++ b/engine/baml-lib/baml/tests/validation_files/tests/dynamic_types_external_cycle_errors.baml @@ -0,0 +1,53 @@ +function TypeBuilderFn() -> string { + client "openai/gpt-4o-mini" + prompt #"Hello"# +} + +class DynamicClass { + a string + b string + @@dynamic +} + +test AttemptToIntroduceInfiniteCycle { + functions [TypeBuilderFn] + type_builder { + class A { + p B + } + class B { + p C + } + class C { + p A + } + + dynamic DynamicClass { + cycle A + } + } + args { + from_text "Test" + } +} + +test AttemptToMakeClassInfinitelyRecursive { + functions [TypeBuilderFn] + type_builder { + dynamic DynamicClass { + cycle DynamicClass + } + } + args { + from_text "Test" + } +} + +// error: Error validating: These classes form a dependency cycle: A -> B -> C +// --> tests/dynamic_types_external_cycle_errors.baml:15 +// | +// 14 | type_builder { +// 15 | class A { +// 16 | p B +// 17 | } +// | diff --git a/engine/baml-lib/baml/tests/validation_files/tests/dynamic_types_internal_cycle_errors.baml b/engine/baml-lib/baml/tests/validation_files/tests/dynamic_types_internal_cycle_errors.baml new file mode 100644 index 000000000..f3fb4ed7a --- /dev/null +++ b/engine/baml-lib/baml/tests/validation_files/tests/dynamic_types_internal_cycle_errors.baml @@ -0,0 +1,31 @@ +function TypeBuilderFn() -> string { + client "openai/gpt-4o-mini" + prompt #"Hello"# +} + +class DynamicClass { + a string + b string + @@dynamic +} + +test 
AttemptToMakeClassInfinitelyRecursive { + functions [TypeBuilderFn] + type_builder { + dynamic DynamicClass { + cycle DynamicClass + } + } + args { + from_text "Test" + } +} + +// error: Error validating: These classes form a dependency cycle: DynamicClass +// --> tests/dynamic_types_internal_cycle_errors.baml:15 +// | +// 14 | type_builder { +// 15 | dynamic DynamicClass { +// 16 | cycle DynamicClass +// 17 | } +// | diff --git a/engine/baml-lib/baml/tests/validation_files/tests/dynamic_types_parser_errors.baml b/engine/baml-lib/baml/tests/validation_files/tests/dynamic_types_parser_errors.baml new file mode 100644 index 000000000..2f1f4e03f --- /dev/null +++ b/engine/baml-lib/baml/tests/validation_files/tests/dynamic_types_parser_errors.baml @@ -0,0 +1,83 @@ +function TypeBuilderFn(from_text: string) -> Resume { + client "openai/gpt-4o-mini" + prompt #"Hello"# + type_builder { + class Foo { + foo string + } + dynamic Bar { + bar int + } + } +} + +test MultipleTypeBuilderBlocks { + functions [TypeBuilderFn] + type_builder { + class Foo { + foo string + } + dynamic Bar { + bar int + } + } + type_builder { + class A { + a string + } + dynamic B { + b int + } + } + args { + from_text "Test" + } +} + +test IncompleteSyntax { + functions [TypeBuilderFn] + type_builder { + type + + dynamic Bar { + bar int + } + } + args { + from_text "Test" + } +} + +// error: Error validating: Only tests may have a type_builder block. 
+// --> tests/dynamic_types_parser_errors.baml:4 +// | +// 3 | prompt #"Hello"# +// 4 | type_builder { +// 5 | class Foo { +// 6 | foo string +// 7 | } +// 8 | dynamic Bar { +// 9 | bar int +// 10 | } +// 11 | } +// | +// error: Error validating: Definition of multiple `type_builder` blocks in the same parent block +// --> tests/dynamic_types_parser_errors.baml:24 +// | +// 23 | } +// 24 | type_builder { +// 25 | class A { +// 26 | a string +// 27 | } +// 28 | dynamic B { +// 29 | b int +// 30 | } +// 31 | } +// | +// error: Error validating: Syntax error in type builder block +// --> tests/dynamic_types_parser_errors.baml:40 +// | +// 39 | type_builder { +// 40 | type +// 41 | +// | diff --git a/engine/baml-lib/baml/tests/validation_files/tests/dynamic_types_validation_errors.baml b/engine/baml-lib/baml/tests/validation_files/tests/dynamic_types_validation_errors.baml new file mode 100644 index 000000000..e0b136448 --- /dev/null +++ b/engine/baml-lib/baml/tests/validation_files/tests/dynamic_types_validation_errors.baml @@ -0,0 +1,151 @@ +function TypeBuilderFn() -> string { + client "openai/gpt-4o-mini" + prompt #"Hello"# +} + +class NonDynamic { + a string + b string +} + +test AttemptToModifyNonDynamicClass { + functions [TypeBuilderFn] + type_builder { + dynamic NonDynamic { + c string + } + } + args { + from_text "Test" + } +} + +type SomeAlias = NonDynamic + +test AttemptToModifyTypeAlias { + functions [TypeBuilderFn] + type_builder { + dynamic SomeAlias { + c string + } + } + args { + from_text "Test" + } +} + +class DynamicClass { + a string + b string + @@dynamic +} + +test AttemptToAddDynamicAttrInDyanmicDef { + functions [TypeBuilderFn] + type_builder { + class NotAllowedHere { + a string + @@dynamic + } + enum StillNotAllowed { + A + @@dynamic + } + dynamic DynamicClass { + c string + @@dynamic + } + } + args { + from_text "Test" + } +} + +test AttemptToModifySameDynamicMultipleTimes { + functions [TypeBuilderFn] + type_builder { + dynamic DynamicClass 
{ + c string + } + dynamic DynamicClass { + d string + } + } + args { + from_text "Test" + } +} + +test NameAlreadyExists { + functions [TypeBuilderFn] + type_builder { + class NonDynamic { + a string + b string + } + dynamic DynamicClass { + non_dynamic NonDynamic + } + } + args { + from_text "Test" + } +} + +// error: Error validating: Type 'NonDynamic' does not contain the `@@dynamic` attribute so it cannot be modified in a type builder block +// --> tests/dynamic_types_validation_errors.baml:14 +// | +// 13 | type_builder { +// 14 | dynamic NonDynamic { +// 15 | c string +// 16 | } +// | +// error: Error validating: The `dynamic` keyword only works on classes and enums, but type 'SomeAlias' is a type alias +// --> tests/dynamic_types_validation_errors.baml:28 +// | +// 27 | type_builder { +// 28 | dynamic SomeAlias { +// 29 | c string +// 30 | } +// | +// error: Error validating: The `@@dynamic` attribute is not allowed in type_builder blocks +// --> tests/dynamic_types_validation_errors.baml:46 +// | +// 45 | type_builder { +// 46 | class NotAllowedHere { +// 47 | a string +// 48 | @@dynamic +// 49 | } +// | +// error: Error validating: The `@@dynamic` attribute is not allowed in type_builder blocks +// --> tests/dynamic_types_validation_errors.baml:50 +// | +// 49 | } +// 50 | enum StillNotAllowed { +// 51 | A +// 52 | @@dynamic +// 53 | } +// | +// error: Error validating: Dynamic type definitions cannot contain the `@@dynamic` attribute +// --> tests/dynamic_types_validation_errors.baml:54 +// | +// 53 | } +// 54 | dynamic DynamicClass { +// 55 | c string +// 56 | @@dynamic +// 57 | } +// | +// error: Error validating: Multiple dynamic definitions for type `DynamicClass` +// --> tests/dynamic_types_validation_errors.baml:70 +// | +// 69 | } +// 70 | dynamic DynamicClass { +// 71 | d string +// 72 | } +// | +// error: The class "NonDynamic" cannot be defined because a class with that name already exists. 
+// --> tests/dynamic_types_validation_errors.baml:82 +// | +// 81 | type_builder { +// 82 | class NonDynamic { +// | diff --git a/engine/baml-lib/parser-database/src/lib.rs b/engine/baml-lib/parser-database/src/lib.rs index 53841d4ce..9e724ecbf 100644 --- a/engine/baml-lib/parser-database/src/lib.rs +++ b/engine/baml-lib/parser-database/src/lib.rs @@ -40,7 +40,7 @@ use std::collections::{HashMap, HashSet, VecDeque}; pub use coerce_expression::{coerce, coerce_array, coerce_opt}; pub use internal_baml_schema_ast::ast; -use internal_baml_schema_ast::ast::{FieldType, SchemaAst, WithName}; +use internal_baml_schema_ast::ast::{FieldType, SchemaAst, ValExpId, WithName}; pub use tarjan::Tarjan; pub use types::{ Attributes, ClientProperties, ContantDelayStrategy, ExponentialBackoffStrategy, PrinterType, @@ -96,6 +96,15 @@ impl ParserDatabase { } } + /// TODO: #1343 Temporary solution until we implement scoping in the AST. + pub fn add_test_case_db(&mut self, test_cases_id: ValExpId, scoped_db: Self) { + self.types + .test_cases + .get_mut(&test_cases_id) + .unwrap() + .type_builder_scoped_db = scoped_db; + } + /// See the docs on [ParserDatabase](/struct.ParserDatabase.html). pub fn add_ast(&mut self, ast: SchemaAst) { self.ast.tops.extend(ast.tops); diff --git a/engine/baml-lib/parser-database/src/names/mod.rs b/engine/baml-lib/parser-database/src/names/mod.rs index 4d9446e27..df5a9412a 100644 --- a/engine/baml-lib/parser-database/src/names/mod.rs +++ b/engine/baml-lib/parser-database/src/names/mod.rs @@ -216,6 +216,24 @@ fn insert_name( if let Some(existing) = namespace.insert(name, top_id) { let current_type = top.get_type(); + + // TODO: #1343 Temporary solution until we implement scoping in the AST. 
+ if ctx.ast[existing] + .name() + .starts_with(ast::DYNAMIC_TYPE_NAME_PREFIX) + { + return ctx.push_error(DatamodelError::new_validation_error( + &format!( + "Multiple dynamic definitions for type `{}`", + ctx.ast[existing] + .name() + .strip_prefix(ast::DYNAMIC_TYPE_NAME_PREFIX) + .unwrap() + ), + top.span().to_owned(), + )); + } + if current_type != "impl" && current_type != "impl" { ctx.push_error(duplicate_top_error(&ctx.ast[existing], top)); } diff --git a/engine/baml-lib/parser-database/src/types/configurations.rs b/engine/baml-lib/parser-database/src/types/configurations.rs index 5ff685bc2..7dd8b4290 100644 --- a/engine/baml-lib/parser-database/src/types/configurations.rs +++ b/engine/baml-lib/parser-database/src/types/configurations.rs @@ -198,6 +198,7 @@ fn visit_strategy( } } +// TODO: Are test cases "configurations"? pub(crate) fn visit_test_case<'db>( idx: ValExpId, config: &'db ValueExprBlock, @@ -289,6 +290,8 @@ pub(crate) fn visit_test_case<'db>( args, args_field_span: args_field_span.clone(), constraints, + type_builder: config.type_builder.clone(), + type_builder_scoped_db: Default::default(), }, ); } diff --git a/engine/baml-lib/parser-database/src/types/mod.rs b/engine/baml-lib/parser-database/src/types/mod.rs index acab429e9..413999d78 100644 --- a/engine/baml-lib/parser-database/src/types/mod.rs +++ b/engine/baml-lib/parser-database/src/types/mod.rs @@ -13,8 +13,8 @@ use indexmap::IndexMap; use internal_baml_diagnostics::{Diagnostics, Span}; use internal_baml_prompt_parser::ast::{ChatBlock, PrinterBlock, Variable}; use internal_baml_schema_ast::ast::{ - self, Expression, FieldId, FieldType, RawString, TypeAliasId, ValExpId, WithIdentifier, - WithName, WithSpan, + self, BlockArgs, Expression, FieldId, FieldType, RawString, TypeAliasId, TypeBuilderBlock, + ValExpId, WithIdentifier, WithName, WithSpan, }; use internal_llm_client::{ClientProvider, PropertyHandler, UnresolvedClientProperty}; @@ -176,6 +176,9 @@ pub struct TestCase { pub args: 
IndexMap)>, pub args_field_span: Span, pub constraints: Vec<(Constraint, Span, Span)>, + pub type_builder: Option, + // TODO: #1343 Temporary solution until we implement scoping in the AST. + pub type_builder_scoped_db: ParserDatabase, } #[derive(Debug, Clone)] @@ -409,10 +412,10 @@ fn visit_class<'db>( let mut used_types = class .iter_fields() - .flat_map(|(_, f)| f.expr.iter().flat_map(|e| e.flat_idns())) + .flat_map(|(_, f)| f.expr.iter().flat_map(FieldType::flat_idns)) .map(|id| id.name().to_string()) .collect::>(); - let input_deps = class.input().map(|f| f.flat_idns()).unwrap_or_default(); + let input_deps = class.input().map(BlockArgs::flat_idns).unwrap_or_default(); ctx.types.class_dependencies.insert(class_id, { used_types.extend(input_deps.iter().map(|id| id.name().to_string())); diff --git a/engine/baml-lib/parser-database/src/walkers/configuration.rs b/engine/baml-lib/parser-database/src/walkers/configuration.rs index b317a10dc..e21d750ed 100644 --- a/engine/baml-lib/parser-database/src/walkers/configuration.rs +++ b/engine/baml-lib/parser-database/src/walkers/configuration.rs @@ -2,11 +2,15 @@ use internal_baml_schema_ast::ast::{self, WithIdentifier, WithSpan}; use crate::types::{RetryPolicy, TestCase}; -/// A `class` declaration in the Prisma schema. +/// Subset of [`ast::ValueExprBlock`] that represents a configuration. +/// +/// Only retry policies are configurations, test cases are not really +/// "configurations" but we'll keep the old [`ConfigurationWalker`] name for +/// now. pub type ConfigurationWalker<'db> = super::Walker<'db, (ast::ValExpId, &'static str)>; impl ConfigurationWalker<'_> { - /// Get the AST node for this class. + /// Get the AST node for this [`ast::ValExpId`] (usually used for classes). 
pub fn ast_node(&self) -> &ast::ValueExprBlock { &self.db.ast[self.id.0] } diff --git a/engine/baml-lib/parser-database/src/walkers/enum.rs b/engine/baml-lib/parser-database/src/walkers/enum.rs index 8b74af42c..06a187da5 100644 --- a/engine/baml-lib/parser-database/src/walkers/enum.rs +++ b/engine/baml-lib/parser-database/src/walkers/enum.rs @@ -7,15 +7,11 @@ pub type EnumWalker<'db> = Walker<'db, ast::TypeExpId>; pub type EnumValueWalker<'db> = Walker<'db, (ast::TypeExpId, ast::FieldId)>; impl<'db> EnumWalker<'db> { - /// The name of the enum. - /// The values of the enum. pub fn values(self) -> impl ExactSizeIterator> { self.ast_type_block() .iter_fields() - .map(move |(valid_id, _)| self.walk((self.id, valid_id))) - .collect::>() - .into_iter() + .map(move |(field_id, _)| self.walk((self.id, field_id))) } /// Find a value by name. diff --git a/engine/baml-lib/parser-database/src/walkers/mod.rs b/engine/baml-lib/parser-database/src/walkers/mod.rs index ccc84d4be..dd166949b 100644 --- a/engine/baml-lib/parser-database/src/walkers/mod.rs +++ b/engine/baml-lib/parser-database/src/walkers/mod.rs @@ -67,10 +67,11 @@ pub enum TypeWalker<'db> { impl<'db> crate::ParserDatabase { /// Find an enum by name. 
pub fn find_enum(&'db self, idn: &Identifier) -> Option> { - self.find_type(idn).and_then(|either| match either { - TypeWalker::Enum(enm) => Some(enm), - _ => None, - }) + self.find_type(idn) + .and_then(|type_walker| match type_walker { + TypeWalker::Enum(enm) => Some(enm), + _ => None, + }) } fn find_top_by_str(&'db self, name: &str) -> Option<&'db TopId> { @@ -176,29 +177,25 @@ impl<'db> crate::ParserDatabase { let mut names: Vec = self.walk_classes().map(|c| c.name().to_string()).collect(); names.extend(self.walk_enums().map(|e| e.name().to_string())); // Add primitive types - names.extend( - vec!["string", "int", "float", "bool", "true", "false"] - .into_iter() - .map(String::from), - ); + names.extend(["string", "int", "float", "bool", "true", "false"].map(String::from)); names } - /// Get all the types that are valid in the schema. (including primitives) + /// Get all the valid functions in the schema. pub fn valid_function_names(&self) -> Vec { self.walk_functions() .map(|c| c.name().to_string()) .collect::>() } - /// Get all the types that are valid in the schema. (including primitives) + /// Get all the valid retry policies in the schema. pub fn valid_retry_policy_names(&self) -> Vec { self.walk_retry_policies() .map(|c| c.name().to_string()) .collect() } - /// Get all the types that are valid in the schema. (including primitives) + /// Get all the valid client names in the schema. pub fn valid_client_names(&self) -> Vec { self.walk_clients().map(|c| c.name().to_string()).collect() } @@ -236,7 +233,7 @@ impl<'db> crate::ParserDatabase { }) } - /// Walk all template strings in the schema. + /// Walk all templates strings in the schema. pub fn walk_templates(&self) -> impl Iterator> { self.ast() .iter_tops() @@ -247,7 +244,7 @@ impl<'db> crate::ParserDatabase { }) } - /// Walk all classes in the schema. + /// Walk all functions in the schema. 
pub fn walk_functions(&self) -> impl Iterator> { self.ast() .iter_tops() @@ -258,7 +255,7 @@ impl<'db> crate::ParserDatabase { }) } - /// Walk all classes in the schema. + /// Walk all clients in the schema. pub fn walk_clients(&self) -> impl Iterator> { self.ast() .iter_tops() @@ -269,7 +266,7 @@ impl<'db> crate::ParserDatabase { }) } - /// Walk all classes in the schema. + /// Walk all retry policies in the schema. pub fn walk_retry_policies(&self) -> impl Iterator> { self.ast() .iter_tops() @@ -280,7 +277,7 @@ impl<'db> crate::ParserDatabase { }) } - /// Walk all classes in the schema. + /// Walk all test cases in the schema. pub fn walk_test_cases(&self) -> impl Iterator> { self.ast() .iter_tops() diff --git a/engine/baml-lib/schema-ast/src/ast.rs b/engine/baml-lib/schema-ast/src/ast.rs index d34b3ecad..a350937ae 100644 --- a/engine/baml-lib/schema-ast/src/ast.rs +++ b/engine/baml-lib/schema-ast/src/ast.rs @@ -15,6 +15,7 @@ mod newline_type; mod template_string; mod top; mod traits; +mod type_builder_block; mod type_expression_block; mod value_expression_block; pub(crate) use self::comment::Comment; @@ -32,6 +33,7 @@ pub use newline_type::NewlineType; pub use template_string::TemplateString; pub use top::Top; pub use traits::{WithAttributes, WithDocumentation, WithIdentifier, WithName, WithSpan}; +pub use type_builder_block::{TypeBuilderBlock, TypeBuilderEntry, DYNAMIC_TYPE_NAME_PREFIX}; pub use type_expression_block::{FieldId, SubType, TypeExpressionBlock}; pub use value_expression_block::{BlockArg, BlockArgs, ValueExprBlock, ValueExprBlockType}; @@ -45,7 +47,7 @@ pub use value_expression_block::{BlockArg, BlockArgs, ValueExprBlock, ValueExprB /// node is annotated with its location in the text representation. /// Basically, the AST is an object oriented representation of the datamodel's /// text. 
Schema = Datamodel + Generators + Datasources -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct SchemaAst { /// All models, enums, composite types, datasources, generators and type aliases. pub tops: Vec, diff --git a/engine/baml-lib/schema-ast/src/ast/assignment.rs b/engine/baml-lib/schema-ast/src/ast/assignment.rs index bfa675645..a91270834 100644 --- a/engine/baml-lib/schema-ast/src/ast/assignment.rs +++ b/engine/baml-lib/schema-ast/src/ast/assignment.rs @@ -2,10 +2,7 @@ //! //! As of right now the only supported "assignments" are type aliases. -use super::{ - traits::WithAttributes, Attribute, BlockArgs, Comment, Field, FieldType, Identifier, Span, - WithDocumentation, WithIdentifier, WithSpan, -}; +use super::{FieldType, Identifier, Span, WithIdentifier, WithSpan}; /// Assignment expression. `left = right`. #[derive(Debug, Clone)] diff --git a/engine/baml-lib/schema-ast/src/ast/type_builder_block.rs b/engine/baml-lib/schema-ast/src/ast/type_builder_block.rs new file mode 100644 index 000000000..26f9257a8 --- /dev/null +++ b/engine/baml-lib/schema-ast/src/ast/type_builder_block.rs @@ -0,0 +1,34 @@ +use internal_baml_diagnostics::Span; + +use super::{Assignment, TypeExpressionBlock}; + +// TODO: #1343 Temporary solution until we implement scoping in the AST. +pub const DYNAMIC_TYPE_NAME_PREFIX: &str = "Dynamic::"; + +/// Blocks allowed in `type_builder` blocks. +#[derive(Debug, Clone)] +pub enum TypeBuilderEntry { + /// An enum declaration. + Enum(TypeExpressionBlock), + /// A class declaration. + Class(TypeExpressionBlock), + /// Type alias expression. + TypeAlias(Assignment), + /// Dynamic block. + Dynamic(TypeExpressionBlock), +} + +/// The `type_builder` block. 
+/// +/// ```ignore +/// test SomeTest { +/// type_builder { +/// // Contents +/// } +/// } +/// ``` +#[derive(Debug, Clone)] +pub struct TypeBuilderBlock { + pub entries: Vec, + pub span: Span, +} diff --git a/engine/baml-lib/schema-ast/src/ast/type_expression_block.rs b/engine/baml-lib/schema-ast/src/ast/type_expression_block.rs index 621f5b2a4..42716b0ce 100644 --- a/engine/baml-lib/schema-ast/src/ast/type_expression_block.rs +++ b/engine/baml-lib/schema-ast/src/ast/type_expression_block.rs @@ -27,6 +27,7 @@ impl std::ops::Index for TypeExpressionBlock { pub enum SubType { Enum, Class, + Dynamic, Other(String), } @@ -85,6 +86,8 @@ pub struct TypeExpressionBlock { /// This is used to distinguish between enums and classes. pub sub_type: SubType, + /// TODO: #1343 Temporary solution until we implement scoping in the AST. + pub is_dynamic_type_def: bool, } impl TypeExpressionBlock { diff --git a/engine/baml-lib/schema-ast/src/ast/value_expression_block.rs b/engine/baml-lib/schema-ast/src/ast/value_expression_block.rs index 8962001b9..43e13edbf 100644 --- a/engine/baml-lib/schema-ast/src/ast/value_expression_block.rs +++ b/engine/baml-lib/schema-ast/src/ast/value_expression_block.rs @@ -1,4 +1,5 @@ use super::argument::ArgumentId; +use super::type_builder_block::TypeBuilderBlock; use super::{ traits::WithAttributes, Attribute, Comment, Expression, Field, FieldType, Identifier, Span, WithDocumentation, WithIdentifier, WithSpan, @@ -155,6 +156,17 @@ pub struct ValueExprBlock { pub(crate) span: Span, pub fields: Vec>, + /// Type builder block. 
+ /// + /// ```ignore + /// test Example { + /// type_builder { + /// // Contents + /// } + /// } + /// ``` + pub type_builder: Option, + pub block_type: ValueExprBlockType, } diff --git a/engine/baml-lib/schema-ast/src/parser/datamodel.pest b/engine/baml-lib/schema-ast/src/parser/datamodel.pest index d5ff707ae..2161aa25c 100644 --- a/engine/baml-lib/schema-ast/src/parser/datamodel.pest +++ b/engine/baml-lib/schema-ast/src/parser/datamodel.pest @@ -23,10 +23,19 @@ field_type_with_attr = { field_type ~ (NEWLINE? ~ (field_attribute | trailing_co value_expression_keyword = { FUNCTION_KEYWORD | TEST_KEYWORD | CLIENT_KEYWORD | RETRY_POLICY_KEYWORD | GENERATOR_KEYWORD } value_expression_block = { value_expression_keyword ~ identifier ~ named_argument_list? ~ ARROW? ~ field_type_chain? ~ SPACER_TEXT ~ BLOCK_OPEN ~ value_expression_contents ~ BLOCK_CLOSE } value_expression_contents = { - (value_expression | comment_block | block_attribute | empty_lines | BLOCK_LEVEL_CATCH_ALL)* + (type_builder_block | value_expression | comment_block | block_attribute | empty_lines | BLOCK_LEVEL_CATCH_ALL)* } value_expression = { identifier ~ expression? ~ (NEWLINE? ~ field_attribute)* ~ trailing_comment? } +// ###################################### +// Type builder +// ###################################### + +type_builder_block = { + TYPE_BUILDER_KEYWORD ~ BLOCK_OPEN ~ type_builder_contents ~ BLOCK_CLOSE +} +type_builder_contents = { (type_expression_block | type_alias | comment_block | empty_lines | BLOCK_LEVEL_CATCH_ALL)* } + // ###################################### ARROW = { SPACER_TEXT ~ "->" ~ SPACER_TEXT } @@ -181,6 +190,7 @@ CATCH_ALL = { (!NEWLINE ~ ANY)+ ~ NEWLINE? 
} FUNCTION_KEYWORD = { "function" } TEMPLATE_KEYWORD = { "template_string" | "string_template" } TEST_KEYWORD = { "test" } +TYPE_BUILDER_KEYWORD = { "type_builder" } CLIENT_KEYWORD = { "client" | "client" } GENERATOR_KEYWORD = { "generator" } RETRY_POLICY_KEYWORD = { "retry_policy" } diff --git a/engine/baml-lib/schema-ast/src/parser/mod.rs b/engine/baml-lib/schema-ast/src/parser/mod.rs index 5f9bf74de..be7e1df4d 100644 --- a/engine/baml-lib/schema-ast/src/parser/mod.rs +++ b/engine/baml-lib/schema-ast/src/parser/mod.rs @@ -9,6 +9,7 @@ mod parse_identifier; mod parse_named_args_list; mod parse_schema; mod parse_template_string; +mod parse_type_builder_block; mod parse_type_expression_block; mod parse_types; mod parse_value_expression_block; diff --git a/engine/baml-lib/schema-ast/src/parser/parse_type_builder_block.rs b/engine/baml-lib/schema-ast/src/parser/parse_type_builder_block.rs new file mode 100644 index 000000000..15acdef54 --- /dev/null +++ b/engine/baml-lib/schema-ast/src/parser/parse_type_builder_block.rs @@ -0,0 +1,175 @@ +use super::{ + helpers::{parsing_catch_all, Pair}, + Rule, +}; + +use crate::{ + assert_correct_parser, + ast::*, + parser::{ + parse_assignment::parse_assignment, + parse_type_expression_block::parse_type_expression_block, + }, +}; +use internal_baml_diagnostics::{DatamodelError, Diagnostics}; + +pub(crate) fn parse_type_builder_block( + pair: Pair<'_>, + diagnostics: &mut Diagnostics, +) -> Result { + assert_correct_parser!(pair, Rule::type_builder_block); + + let span = diagnostics.span(pair.as_span()); + let mut entries = Vec::new(); + + for current in pair.into_inner() { + match current.as_rule() { + // First token is the `type_builder` keyword. + Rule::TYPE_BUILDER_KEYWORD => {} + + // Second token is opening bracket. + Rule::BLOCK_OPEN => {} + + // Block content. 
+ Rule::type_builder_contents => { + let mut pending_block_comment = None; + + for nested in current.into_inner() { + match nested.as_rule() { + Rule::comment_block => pending_block_comment = Some(nested), + + Rule::type_expression_block => { + let type_expr = parse_type_expression_block( + nested, + pending_block_comment.take(), + diagnostics, + ); + + match type_expr.sub_type { + SubType::Class => entries.push(TypeBuilderEntry::Class(type_expr)), + SubType::Enum => entries.push(TypeBuilderEntry::Enum(type_expr)), + SubType::Dynamic => { + entries.push(TypeBuilderEntry::Dynamic(type_expr)) + } + _ => {} // may need to save other somehow for error propagation + } + } + + Rule::type_alias => { + let assignment = parse_assignment(nested, diagnostics); + entries.push(TypeBuilderEntry::TypeAlias(assignment)); + } + + Rule::BLOCK_LEVEL_CATCH_ALL => { + diagnostics.push_error(DatamodelError::new_validation_error( + "Syntax error in type builder block", + diagnostics.span(nested.as_span()), + )) + } + + _ => parsing_catch_all(nested, "type_builder_contents"), + } + } + } + + // Last token, closing bracket. 
+ Rule::BLOCK_CLOSE => {} + + _ => parsing_catch_all(current, "type_builder_block"), + } + } + + Ok(TypeBuilderBlock { entries, span }) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::parser::{BAMLParser, Rule}; + use internal_baml_diagnostics::{Diagnostics, SourceFile}; + use pest::Parser; + + #[test] + fn parse_block() { + let root_path = "test_file.baml"; + + let input = r#"type_builder { + class Example { + a string + b int + } + + enum Bar { + A + B + } + + /// Some doc + /// comment + dynamic Cls { + e Example + s string + } + + dynamic Enm { + C + D + } + + type Alias = Example + }"#; + + let source = SourceFile::new_static(root_path.into(), input); + let mut diagnostics = Diagnostics::new(root_path.into()); + + diagnostics.set_source(&source); + + let parsed = BAMLParser::parse(Rule::type_builder_block, input) + .unwrap() + .next() + .unwrap(); + + let type_buider_block = parse_type_builder_block(parsed, &mut diagnostics).unwrap(); + + assert_eq!(type_buider_block.entries.len(), 5); + + let TypeBuilderEntry::Class(example) = &type_buider_block.entries[0] else { + panic!( + "Expected class Example, got {:?}", + type_buider_block.entries[0] + ); + }; + + let TypeBuilderEntry::Enum(bar) = &type_buider_block.entries[1] else { + panic!("Expected enum Bar, got {:?}", type_buider_block.entries[1]); + }; + + let TypeBuilderEntry::Dynamic(cls) = &type_buider_block.entries[2] else { + panic!( + "Expected dynamic Cls, got {:?}", + type_buider_block.entries[2] + ); + }; + + let TypeBuilderEntry::Dynamic(enm) = &type_buider_block.entries[3] else { + panic!( + "Expected dynamic Enm, got {:?}", + type_buider_block.entries[3] + ); + }; + + let TypeBuilderEntry::TypeAlias(alias) = &type_buider_block.entries[4] else { + panic!( + "Expected type Alias, got {:?}", + type_buider_block.entries[4] + ); + }; + + assert_eq!(example.name(), "Example"); + assert_eq!(bar.name(), "Bar"); + assert_eq!(cls.name(), "Cls"); + assert_eq!(cls.documentation(), Some("Some 
doc\ncomment")); + assert_eq!(enm.name(), "Enm"); + assert_eq!(alias.name(), "Alias"); + } +} diff --git a/engine/baml-lib/schema-ast/src/parser/parse_type_expression_block.rs b/engine/baml-lib/schema-ast/src/parser/parse_type_expression_block.rs index e70681a0a..1bf7549fe 100644 --- a/engine/baml-lib/schema-ast/src/parser/parse_type_expression_block.rs +++ b/engine/baml-lib/schema-ast/src/parser/parse_type_expression_block.rs @@ -36,8 +36,9 @@ pub(crate) fn parse_type_expression_block( if sub_type.is_none() { // First identifier is the type of block (e.g. class, enum). match current.as_str() { - "class" => sub_type = Some(SubType::Class.clone()), - "enum" => sub_type = Some(SubType::Enum.clone()), + "class" => sub_type = Some(SubType::Class), + "enum" => sub_type = Some(SubType::Enum), + "dynamic" => sub_type = Some(SubType::Dynamic), // Report this as an error, otherwise the syntax will be // correct but the type will not be registered and the @@ -79,6 +80,7 @@ pub(crate) fn parse_type_expression_block( sub_type.clone().map(|st| match st { SubType::Enum => "Enum", SubType::Class => "Class", + SubType::Dynamic => "Dynamic", SubType::Other(_) => "Other", }).unwrap_or(""), item, @@ -87,11 +89,9 @@ pub(crate) fn parse_type_expression_block( sub_type_is_enum, ); match sub_type_expression { - Ok(field) => { - fields.push(field); - }, - Err(err) => diagnostics.push_error(err), - } + Ok(field) => fields.push(field), + Err(err) => diagnostics.push_error(err), + } } Rule::comment_block => pending_field_comment = Some(item), Rule::BLOCK_LEVEL_CATCH_ALL => { @@ -123,6 +123,7 @@ pub(crate) fn parse_type_expression_block( sub_type: sub_type .clone() .unwrap_or(SubType::Other("Subtype not found".to_string())), + is_dynamic_type_def: matches!(sub_type, Some(SubType::Dynamic)), }, _ => panic!("Encountered impossible type_expression declaration during parsing",), } diff --git a/engine/baml-lib/schema-ast/src/parser/parse_value_expression_block.rs 
b/engine/baml-lib/schema-ast/src/parser/parse_value_expression_block.rs index d7c0ef720..d8e3f1761 100644 --- a/engine/baml-lib/schema-ast/src/parser/parse_value_expression_block.rs +++ b/engine/baml-lib/schema-ast/src/parser/parse_value_expression_block.rs @@ -5,6 +5,7 @@ use super::{ parse_field::parse_value_expr, parse_identifier::parse_identifier, parse_named_args_list::{parse_function_arg, parse_named_argument_list}, + parse_type_builder_block::parse_type_builder_block, Rule, }; @@ -21,6 +22,7 @@ pub(crate) fn parse_value_expression_block( let mut attributes: Vec = Vec::new(); let mut input = None; let mut output = None; + let mut type_builder = None; let mut fields: Vec> = vec![]; let mut sub_type: Option = None; let mut has_arrow = false; @@ -35,9 +37,7 @@ pub(crate) fn parse_value_expression_block( "generator" => sub_type = Some(ValueExprBlockType::Generator), _ => panic!("Unexpected value expression keyword: {}", current.as_str()), }, - Rule::ARROW => { - has_arrow = true; - } + Rule::ARROW => has_arrow = true, Rule::identifier => name = Some(parse_identifier(current, diagnostics)), Rule::named_argument_list => match parse_named_argument_list(current, diagnostics) { Ok(arg) => input = Some(arg), @@ -87,6 +87,19 @@ pub(crate) fn parse_value_expression_block( pending_field_comment = None; } + Rule::type_builder_block => { + let block = parse_type_builder_block(item, diagnostics)?; + + match type_builder { + None => type_builder = Some(block), + + Some(_) => diagnostics.push_error(DatamodelError::new_validation_error( + "Definition of multiple `type_builder` blocks in the same parent block", + block.span + )), + } + } + Rule::comment_block => pending_field_comment = Some(item), Rule::block_attribute => { let span = item.as_span(); @@ -127,59 +140,74 @@ pub(crate) fn parse_value_expression_block( } } - let response = match name { - Some(name) => { - let msg = if has_arrow { - match (input.is_some(), output.is_some()) { - (true, true) => { - return 
Ok(ValueExprBlock { - name, - input, - output, - attributes, - fields, - documentation: doc_comment.and_then(parse_comment_block), - span: diagnostics.span(pair_span), - block_type: sub_type.unwrap_or(ValueExprBlockType::Function), // Unwrap or provide a default - }); - } - (true, false) => "No return type specified.", - (false, true) => "No input parameters specified.", - _ => "Invalid syntax: missing input parameters and return type.", - } - } else { - return Ok(ValueExprBlock { - name, - input, - output, - attributes, - fields, - documentation: doc_comment.and_then(parse_comment_block), - span: diagnostics.span(pair_span), - block_type: sub_type.unwrap_or(ValueExprBlockType::Function), // Unwrap or provide a default - }); - }; - - (msg, Some(name.name().to_string())) + // Block has no name. Functions, test, clients and generators have names. + let Some(name) = name else { + return Err(value_expr_block_syntax_error( + "Invalid syntax: missing name.", + None, + diagnostics.span(pair_span), + )); + }; + + // Only test blocks can have `type_builder` blocks in them. This is not a + // "syntax" error so we won't fail yet. + if let Some(ref t) = type_builder { + if sub_type != Some(ValueExprBlockType::Test) { + diagnostics.push_error(DatamodelError::new_validation_error( + "Only tests may have a type_builder block.", + t.span.to_owned(), + )); } - None => ("Invalid syntax: missing name.", None), }; - Err(DatamodelError::new_model_validation_error( - format!( - r##"{} Valid function syntax is + // No arrow means it's not a function. If it's a function then check params + // and return type. If any of the conditions are met then we're ok. 
+ if !has_arrow || (input.is_some() && output.is_some()) { + return Ok(ValueExprBlock { + name, + input, + output, + attributes, + fields, + documentation: doc_comment.and_then(parse_comment_block), + span: diagnostics.span(pair_span), + type_builder, + block_type: sub_type.unwrap_or(ValueExprBlockType::Function), + }); + } + + // If we reach this code, we're dealing with a malformed function. + let message = match (input, output) { + (Some(_), None) => "No return type specified.", + (None, Some(_)) => "No input parameters specified.", + _ => "Invalid syntax: missing input parameters and return type.", + }; + + Err(value_expr_block_syntax_error( + message, + Some(name.name()), + diagnostics.span(pair_span), + )) +} + +fn value_expr_block_syntax_error(message: &str, name: Option<&str>, span: Span) -> DatamodelError { + let function_name = name.unwrap_or("MyFunction"); + + // TODO: Different block types (test, client, generator). + let correct_syntax = format!( + r##"{message} Valid function syntax is ``` -function {}(param1: String, param2: String) -> ReturnType {{ +function {function_name}(param1: String, param2: String) -> ReturnType {{ client SomeClient prompt #"..."# }} -```"##, - response.0, - response.1.as_deref().unwrap_or("MyFunction") - ) - .as_str(), +```"## + ); + + DatamodelError::new_model_validation_error( + &correct_syntax, "value expression", - response.1.as_deref().unwrap_or(""), - diagnostics.span(pair_span), - )) + name.unwrap_or(""), + span, + ) } diff --git a/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs b/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs index bc0562efe..d120409a1 100644 --- a/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs +++ b/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs @@ -140,7 +140,6 @@ fn find_existing_class_field( } } - let name = Name::new_with_alias(field_name.to_string(), alias.value()); let desc = desc.value(); 
let r#type = field_walker.r#type(); @@ -229,6 +228,8 @@ fn relevant_data_models<'a>( let mut structural_recursive_aliases = IndexMap::new(); let mut start: Vec = vec![output.clone()]; + // start.extend(ctx.type_alias_overrides.values().cloned()); + let eval_ctx = ctx.eval_ctx(false); while let Some(output) = start.pop() { @@ -398,6 +399,15 @@ fn relevant_data_models<'a>( } } } + + // Overrides. + for cycle in &ctx.recursive_type_alias_overrides { + if cycle.contains_key(name) { + for (alias, target) in cycle.iter() { + structural_recursive_aliases.insert(alias.to_owned(), target.clone()); + } + } + } } (FieldType::Literal(_), _) => {} (FieldType::Primitive(_), _) => {} diff --git a/engine/baml-runtime/src/lib.rs b/engine/baml-runtime/src/lib.rs index 5dc493846..3ffc6eb8e 100644 --- a/engine/baml-runtime/src/lib.rs +++ b/engine/baml-runtime/src/lib.rs @@ -207,9 +207,8 @@ impl BamlRuntime { ctx: &RuntimeContext, strict: bool, ) -> Result> { - let (params, _) = - self.get_test_params_and_constraints(function_name, test_name, ctx, strict)?; - Ok(params) + self.inner + .get_test_params(function_name, test_name, ctx, strict) } pub async fn run_test( @@ -224,12 +223,17 @@ impl BamlRuntime { { let span = self.tracer.start_span(test_name, ctx, &Default::default()); + let type_builder = self + .inner + .get_test_type_builder(function_name, test_name, ctx) + .unwrap(); + let run_to_response = || async { - let rctx = ctx.create_ctx(None, None)?; + let rctx = ctx.create_ctx(type_builder.as_ref(), None)?; let (params, constraints) = self.get_test_params_and_constraints(function_name, test_name, &rctx, true)?; log::info!("params: {:#?}", params); - let rctx_stream = ctx.create_ctx(None, None)?; + let rctx_stream = ctx.create_ctx(type_builder.as_ref(), None)?; let mut stream = self.inner.stream_function_impl( function_name.into(), ¶ms, diff --git a/engine/baml-runtime/src/runtime/runtime_interface.rs b/engine/baml-runtime/src/runtime/runtime_interface.rs index deec73e29..f6c1cd3f2 
100644 --- a/engine/baml-runtime/src/runtime/runtime_interface.rs +++ b/engine/baml-runtime/src/runtime/runtime_interface.rs @@ -3,6 +3,8 @@ use std::{collections::HashMap, path::PathBuf, sync::Arc}; use super::InternalBamlRuntime; use crate::internal::llm_client::traits::WithClientProperties; use crate::internal::llm_client::LLMResponse; +use crate::type_builder::TypeBuilder; +use crate::RuntimeContextManager; use crate::{ client_registry::ClientProperty, internal::{ @@ -25,6 +27,7 @@ use crate::{ }; use anyhow::{Context, Result}; use baml_types::{BamlMap, BamlValue, Constraint, EvaluationContext}; +use internal_baml_core::ir::repr::TypeBuilderEntry; use internal_baml_core::{ internal_baml_diagnostics::SourceFile, ir::{repr::IntermediateRepr, ArgCoercer, FunctionWalker, IRHelper}, @@ -285,6 +288,60 @@ impl InternalRuntimeInterface for InternalBamlRuntime { let walker = self.ir().find_test(&func, test_name)?; Ok(walker.item.1.elem.constraints.clone()) } + + fn get_test_type_builder( + &self, + function_name: &str, + test_name: &str, + ctx: &RuntimeContextManager, + ) -> Result> { + let func = self.get_function(function_name, &ctx.create_ctx(None, None)?)?; + let test = self.ir().find_test(&func, test_name)?; + + if test.type_builder_contents().is_empty() { + return Ok(None); + } + + let type_builder = TypeBuilder::new(); + + for entry in test.type_builder_contents() { + match entry { + TypeBuilderEntry::Class(cls) => { + let mutex = type_builder.class(&cls.elem.name); + let class_builder = mutex.lock().unwrap(); + for f in &cls.elem.static_fields { + class_builder + .property(&f.elem.name) + .lock() + .unwrap() + .r#type(f.elem.r#type.elem.to_owned()); + } + } + + TypeBuilderEntry::Enum(enm) => { + let mutex = type_builder.r#enum(&enm.elem.name); + let enum_builder = mutex.lock().unwrap(); + for (variant, _) in &enm.elem.values { + enum_builder.value(&variant.elem.0).lock().unwrap(); + } + } + + TypeBuilderEntry::TypeAlias(alias) => { + let mutex = 
type_builder.type_alias(&alias.elem.name); + let alias_builder = mutex.lock().unwrap(); + alias_builder.target(alias.elem.r#type.elem.to_owned()); + } + } + } + + type_builder + .recursive_type_aliases() + .lock() + .unwrap() + .extend(test.type_builder_recursive_aliases().iter().cloned()); + + Ok(Some(type_builder)) + } } impl RuntimeConstructor for InternalBamlRuntime { diff --git a/engine/baml-runtime/src/runtime_interface.rs b/engine/baml-runtime/src/runtime_interface.rs index 0a43d488d..4d117b979 100644 --- a/engine/baml-runtime/src/runtime_interface.rs +++ b/engine/baml-runtime/src/runtime_interface.rs @@ -9,6 +9,7 @@ use std::{collections::HashMap, sync::Arc}; use crate::internal::llm_client::llm_provider::LLMProvider; use crate::internal::llm_client::orchestrator::{OrchestrationScope, OrchestratorNode}; use crate::tracing::{BamlTracer, TracingSpan}; +use crate::type_builder::TypeBuilder; use crate::types::on_log_event::LogEventCallbackSync; use crate::{ internal::{ir_features::IrFeatures, llm_client::retry_policy::CallablePolicy}, @@ -166,4 +167,11 @@ pub trait InternalRuntimeInterface { test_name: &str, ctx: &RuntimeContext, ) -> Result>; + + fn get_test_type_builder( + &self, + function_name: &str, + test_name: &str, + ctx: &RuntimeContextManager, + ) -> Result>; } diff --git a/engine/baml-runtime/src/type_builder/mod.rs b/engine/baml-runtime/src/type_builder/mod.rs index 72fabece2..acf9528a8 100644 --- a/engine/baml-runtime/src/type_builder/mod.rs +++ b/engine/baml-runtime/src/type_builder/mod.rs @@ -133,6 +133,26 @@ impl EnumBuilder { } } +pub struct TypeAliasBuilder { + target: Arc>>, + meta: MetaData, +} +impl_meta!(TypeAliasBuilder); + +impl TypeAliasBuilder { + pub fn new() -> Self { + Self { + target: Default::default(), + meta: Arc::new(Mutex::new(Default::default())), + } + } + + pub fn target(&self, target: FieldType) -> &Self { + *self.target.lock().unwrap() = Some(target); + self + } +} + impl std::fmt::Debug for TypeBuilder { fn fmt(&self, f: 
&mut std::fmt::Formatter<'_>) -> std::fmt::Result { // Start the debug printout with the struct name @@ -170,6 +190,8 @@ impl std::fmt::Debug for TypeBuilder { pub struct TypeBuilder { classes: Arc>>>>, enums: Arc>>>>, + type_aliases: Arc>>>>, + recursive_type_aliases: Arc>>>, } impl Default for TypeBuilder { @@ -183,6 +205,8 @@ impl TypeBuilder { Self { classes: Default::default(), enums: Default::default(), + type_aliases: Default::default(), + recursive_type_aliases: Default::default(), } } @@ -206,11 +230,27 @@ impl TypeBuilder { ) } + pub fn type_alias(&self, name: &str) -> Arc> { + Arc::clone( + self.type_aliases + .lock() + .unwrap() + .entry(name.to_string()) + .or_insert_with(|| Arc::new(Mutex::new(TypeAliasBuilder::new()))), + ) + } + + pub fn recursive_type_aliases(&self) -> Arc>>> { + Arc::clone(&self.recursive_type_aliases) + } + pub fn to_overrides( &self, ) -> ( IndexMap, IndexMap, + IndexMap, + Vec>, ) { log::debug!("Converting types to overrides"); let cls = self @@ -283,12 +323,29 @@ impl TypeBuilder { ) }) .collect(); + + let aliases = self + .type_aliases + .lock() + .unwrap() + .iter() + .map(|(name, builder)| { + let mutex = builder.lock().unwrap(); + let target = mutex.target.lock().unwrap(); + // TODO: target.unwrap() might not be guaranteed here. 
+ (name.clone(), target.to_owned().unwrap()) + }) + .collect(); + log::debug!( "Dynamic types: \n {:#?} \n Dynamic enums\n {:#?} enums", cls, enm ); - (cls, enm) + + let recursive_aliases = self.recursive_type_aliases.lock().unwrap().clone(); + + (cls, enm, aliases, recursive_aliases) } } diff --git a/engine/baml-runtime/src/types/context_manager.rs b/engine/baml-runtime/src/types/context_manager.rs index 001236103..f0cea3038 100644 --- a/engine/baml-runtime/src/types/context_manager.rs +++ b/engine/baml-runtime/src/types/context_manager.rs @@ -105,8 +105,8 @@ impl RuntimeContextManager { pub fn create_ctx( &self, - tb: Option<&TypeBuilder>, - cb: Option<&ClientRegistry>, + type_builder: Option<&TypeBuilder>, + client_registry: Option<&ClientRegistry>, ) -> Result { let mut tags = self.global_tags.lock().unwrap().clone(); let ctx_tags = { @@ -125,7 +125,9 @@ impl RuntimeContextManager { ctx.map(|(.., tags)| tags).cloned().unwrap_or_default() }; - let (cls, enm) = tb.map(|tb| tb.to_overrides()).unwrap_or_default(); + let (cls, enm, als, rec_als) = type_builder + .map(TypeBuilder::to_overrides) + .unwrap_or_default(); let mut ctx = RuntimeContext::new( self.baml_src_reader.clone(), @@ -134,18 +136,18 @@ impl RuntimeContextManager { Default::default(), cls, enm, + als, + rec_als, ); - let client_overrides = match cb { - Some(cb) => Some( - cb.to_clients(&ctx) + ctx.client_overrides = match client_registry { + Some(cr) => Some( + cr.to_clients(&ctx) .with_context(|| "Failed to create clients from client_registry")?, ), None => None, }; - ctx.client_overrides = client_overrides; - Ok(ctx) } @@ -159,6 +161,8 @@ impl RuntimeContextManager { Default::default(), Default::default(), Default::default(), + Default::default(), + Default::default(), ) } diff --git a/engine/baml-runtime/src/types/runtime_context.rs b/engine/baml-runtime/src/types/runtime_context.rs index 8abd3ea7a..3bc15b251 100644 --- a/engine/baml-runtime/src/types/runtime_context.rs +++ 
b/engine/baml-runtime/src/types/runtime_context.rs @@ -56,6 +56,8 @@ pub struct RuntimeContext { pub client_overrides: Option<(Option, HashMap>)>, pub class_override: IndexMap, pub enum_overrides: IndexMap, + pub type_alias_overrides: IndexMap, + pub recursive_type_alias_overrides: Vec>, } impl RuntimeContext { @@ -78,6 +80,8 @@ impl RuntimeContext { client_overrides: Option<(Option, HashMap>)>, class_override: IndexMap, enum_overrides: IndexMap, + type_alias_overrides: IndexMap, + recursive_type_alias_overrides: Vec>, ) -> RuntimeContext { RuntimeContext { baml_src, @@ -86,6 +90,8 @@ impl RuntimeContext { client_overrides, class_override, enum_overrides, + type_alias_overrides, + recursive_type_alias_overrides, } } diff --git a/engine/baml-runtime/tests/test_runtime.rs b/engine/baml-runtime/tests/test_runtime.rs index 97e9d07ac..2902e088c 100644 --- a/engine/baml-runtime/tests/test_runtime.rs +++ b/engine/baml-runtime/tests/test_runtime.rs @@ -129,7 +129,7 @@ mod internal_tests { files.insert( "main.baml", r##" - + class Education { school string | null @description(#" 111 @@ -146,15 +146,15 @@ mod internal_tests { api_key env.OPENAI_API_KEY } } - - + + function Extract(input: string) -> Education { client GPT4Turbo prompt #" - + {{ ctx.output_format }} "# - } + } test Test { functions [Extract] @@ -206,7 +206,7 @@ mod internal_tests { files.insert( "main.baml", r##" - + class Education { // school string | (null | int) @description(#" // 111 @@ -226,15 +226,15 @@ mod internal_tests { api_key env.OPENAI_API_KEY } } - - + + function Extract(input: string) -> Education { client GPT4Turbo prompt #" - + {{ ctx.output_format }} "# - } + } test Test { functions [Extract] @@ -285,7 +285,13 @@ mod internal_tests { BamlRuntime::from_file_content( "baml_src", &files, - [("OPENAI_API_KEY", "OPENAI_API_KEY")].into(), + [( + "OPENAI_API_KEY", + // Use this to test with a real API key. 
+ // option_env!("OPENAI_API_KEY").unwrap_or("NO_API_KEY"), + "OPENAI_API_KEY", + )] + .into(), ) } @@ -306,7 +312,7 @@ class Item { price float quantity int @description("If not specified, assume 1") } - + // This is our LLM function we can call in Python or Typescript // the receipt can be an image OR text here! function ExtractReceipt(receipt: image | string) -> Receipt { @@ -458,7 +464,7 @@ function BuildTree(input: BinaryNode) -> Tree { INPUT: {{ input }} - {{ ctx.output_format }} + {{ ctx.output_format }} "# } @@ -600,4 +606,351 @@ test RecursiveAliasCycle { Ok(()) } + + struct TypeBuilderBlockTest { + function_name: &'static str, + test_name: &'static str, + baml: &'static str, + } + + fn run_type_builder_block_test( + TypeBuilderBlockTest { + function_name, + test_name, + baml, + }: TypeBuilderBlockTest, + ) -> anyhow::Result<()> { + // Use this and RUST_LOG=debug to see the rendered prompt in the + // terminal. + env_logger::init(); + + let runtime = make_test_runtime(baml)?; + + let ctx = runtime.create_ctx_manager(BamlValue::String("test".to_string()), None); + + let run_test_future = runtime.run_test(function_name, test_name, &ctx, Some(|r| {})); + let (res, span) = runtime.async_runtime.block_on(run_test_future); + + Ok(()) + } + + #[test] + fn test_type_builder_block_with_dynamic_class() -> anyhow::Result<()> { + run_type_builder_block_test(TypeBuilderBlockTest { + function_name: "ExtractResume", + test_name: "ReturnDynamicClassTest", + baml: r##" + class Resume { + name string + education Education[] + skills string[] + @@dynamic + } + + class Education { + school string + degree string + year int + } + + function ExtractResume(from_text: string) -> Resume { + client "openai/gpt-4o" + prompt #" + Extract the resume information from the given text. 
+ + {{ from_text }} + + {{ ctx.output_format }} + "# + } + + test ReturnDynamicClassTest { + functions [ExtractResume] + type_builder { + class Experience { + title string + company string + start_date string + end_date string + } + + dynamic Resume { + experience Experience[] + } + } + args { + from_text #" + John Doe + + Education + - University of California, Berkeley, B.S. in Computer Science, 2020 + + Experience + - Software Engineer, Boundary, Sep 2022 - Sep 2023 + + Skills + - Python + - Java + "# + } + } + "##, + }) + } + + #[test] + fn test_type_builder_block_with_dynamic_enum() -> anyhow::Result<()> { + run_type_builder_block_test(TypeBuilderBlockTest { + function_name: "ClassifyMessage", + test_name: "ReturnDynamicEnumTest", + baml: r##" + enum Category { + Refund + CancelOrder + AccountIssue + @@dynamic + } + + // Function that returns the dynamic enum. + function ClassifyMessage(message: string) -> Category { + client "openai/gpt-4o" + prompt #" + Classify this message: + + {{ message }} + + {{ ctx.output_format }} + "# + } + + test ReturnDynamicEnumTest { + functions [ClassifyMessage] + type_builder { + dynamic Category { + Question + Feedback + TechnicalSupport + } + } + args { + message "I think the product is great!" + } + } + "##, + }) + } + + #[test] + fn test_type_builder_block_mixed_enums_and_classes() -> anyhow::Result<()> { + run_type_builder_block_test(TypeBuilderBlockTest { + function_name: "ExtractResume", + test_name: "ReturnDynamicClassTest", + baml: r##" + class Resume { + name string + education Education[] + skills string[] + @@dynamic + } + + class Education { + school string + degree string + year int + } + + enum Role { + SoftwareEngineer + DataScientist + @@dynamic + } + + function ExtractResume(from_text: string) -> Resume { + client "openai/gpt-4o" + prompt #" + Extract the resume information from the given text. 
+ + {{ from_text }} + + {{ ctx.output_format }} + "# + } + + test ReturnDynamicClassTest { + functions [ExtractResume] + type_builder { + class Experience { + title string + company string + start_date string + end_date string + } + + enum Industry { + Tech + Finance + Healthcare + } + + dynamic Role { + ProductManager + Sales + } + + dynamic Resume { + experience Experience[] + role Role + industry Industry + } + } + args { + from_text #" + John Doe + + Education + - University of California, Berkeley, B.S. in Computer Science, 2020 + + Experience + - Software Engineer, Boundary, Sep 2022 - Sep 2023 + + Skills + - Python + - Java + "# + } + } + "##, + }) + } + + #[test] + fn test_type_builder_block_type_aliases() -> anyhow::Result<()> { + run_type_builder_block_test(TypeBuilderBlockTest { + function_name: "ExtractResume", + test_name: "ReturnDynamicClassTest", + baml: r##" + class Resume { + name string + education Education[] + skills string[] + @@dynamic + } + + class Education { + school string + degree string + year int + } + + function ExtractResume(from_text: string) -> Resume { + client "openai/gpt-4o" + prompt #" + Extract the resume information from the given text. + + {{ from_text }} + + {{ ctx.output_format }} + "# + } + + test ReturnDynamicClassTest { + functions [ExtractResume] + type_builder { + class Experience { + title string + company string + start_date string + end_date string + } + + type ExpAlias = Experience + + dynamic Resume { + experience ExpAlias + } + } + args { + from_text #" + John Doe + + Education + - University of California, Berkeley, B.S. 
in Computer Science, 2020 + + Experience + - Software Engineer, Boundary, Sep 2022 - Sep 2023 + + Skills + - Python + - Java + "# + } + } + "##, + }) + } + + #[test] + fn test_type_builder_block_recursive_type_aliases() -> anyhow::Result<()> { + run_type_builder_block_test(TypeBuilderBlockTest { + function_name: "ExtractResume", + test_name: "ReturnDynamicClassTest", + baml: r##" + class Resume { + name string + education Education[] + skills string[] + @@dynamic + } + + class Education { + school string + degree string + year int + } + + class WhatTheFuck { + j JsonValue + } + + type JsonValue = int | float | bool | string | JsonValue[] | map + + function ExtractResume(from_text: string) -> Resume { + client "openai/gpt-4o" + prompt #" + Extract the resume information from the given text. + + {{ from_text }} + + {{ ctx.output_format }} + "# + } + + test ReturnDynamicClassTest { + functions [ExtractResume] + type_builder { + type JSON = int | float | bool | string | JSON[] | map + + dynamic Resume { + experience JSON + } + } + args { + from_text #" + John Doe + + Education + - University of California, Berkeley, B.S. 
in Computer Science, 2020 + + Experience + - Software Engineer, Boundary, Sep 2022 - Sep 2023 + + Skills + - Python + - Java + "# + } + } + "##, + }) + } } diff --git a/engine/baml-schema-wasm/src/runtime_wasm/mod.rs b/engine/baml-schema-wasm/src/runtime_wasm/mod.rs index d53d68e44..6ee78b002 100644 --- a/engine/baml-schema-wasm/src/runtime_wasm/mod.rs +++ b/engine/baml-schema-wasm/src/runtime_wasm/mod.rs @@ -522,8 +522,8 @@ fn serialize_value_counting_checks(value: &ResponseBamlValue) -> (serde_json::Va .collect::>(); let sub_check_count: usize = value.0.iter().map(|node| node.meta().1.len()).sum(); - let json_value: serde_json::Value = - serde_json::to_value(value.serialize_final()).unwrap_or("Error converting value to JSON".into()); + let json_value: serde_json::Value = serde_json::to_value(value.serialize_final()) + .unwrap_or("Error converting value to JSON".into()); let check_count = checks.len() + sub_check_count; @@ -1507,13 +1507,20 @@ impl WasmFunction { wasm_call_context: &WasmCallContext, get_baml_src_cb: js_sys::Function, ) -> JsResult { - let ctx = rt + let context_manager = rt.runtime.create_ctx_manager( + BamlValue::String("wasm".to_string()), + js_fn_to_baml_src_reader(get_baml_src_cb), + ); + + let test_type_builder = rt .runtime - .create_ctx_manager( - BamlValue::String("wasm".to_string()), - js_fn_to_baml_src_reader(get_baml_src_cb), - ) - .create_ctx_with_default(); + .internal() + .get_test_type_builder(&self.name, &test_name, &context_manager) + .map_err(|e| JsError::new(format!("{e:?}").as_str()))?; + + let ctx = context_manager + .create_ctx(test_type_builder.as_ref(), None) + .map_err(|e| JsError::new(format!("{e:?}").as_str()))?; let params = rt .runtime @@ -1553,13 +1560,20 @@ impl WasmFunction { expand_images: bool, get_baml_src_cb: js_sys::Function, ) -> Result { - let ctx = rt + let context_manager = rt.runtime.create_ctx_manager( + BamlValue::String("wasm".to_string()), + js_fn_to_baml_src_reader(get_baml_src_cb), + ); + + let 
test_type_builder = rt .runtime - .create_ctx_manager( - BamlValue::String("wasm".to_string()), - js_fn_to_baml_src_reader(get_baml_src_cb), - ) - .create_ctx_with_default(); + .internal() + .get_test_type_builder(&self.name, &test_name, &context_manager) + .map_err(|e| JsError::new(format!("{e:?}").as_str()))?; + + let ctx = context_manager + .create_ctx(test_type_builder.as_ref(), None) + .map_err(|e| JsError::new(format!("{e:?}").as_str()))?; let params = rt .runtime diff --git a/engine/baml-schema-wasm/tests/test_file_manager.rs b/engine/baml-schema-wasm/tests/test_file_manager.rs index 6069e6091..f4abf9748 100644 --- a/engine/baml-schema-wasm/tests/test_file_manager.rs +++ b/engine/baml-schema-wasm/tests/test_file_manager.rs @@ -21,27 +21,27 @@ mod tests { /// Sample BAML content for testing. fn sample_baml_content() -> String { r##" - - + + class Email { subject string body string from_address string } - + enum OrderStatus { ORDERED SHIPPED DELIVERED CANCELLED } - + class OrderInfo { order_status OrderStatus tracking_number string? estimated_arrival_date string? 
} - + client GPT4Turbo { provider baml-openai-chat options { @@ -49,7 +49,7 @@ mod tests { api_key env.OPENAI_API_KEY } } - + function GetOrderInfo(input: string) -> OrderInfo { client GPT4Turbo prompt #" diff --git a/typescript/codemirror-lang-baml/src/index.ts b/typescript/codemirror-lang-baml/src/index.ts index 11119c692..a625155da 100644 --- a/typescript/codemirror-lang-baml/src/index.ts +++ b/typescript/codemirror-lang-baml/src/index.ts @@ -44,6 +44,9 @@ export const BAMLLanguage = LRLanguage.define({ TestDecl: t.keyword, 'TestDecl/IdentifierDecl': t.typeName, + TypeBuilderDecl: t.keyword, + TypeBuilderKeyword: t.keyword, + EnumDecl: t.keyword, 'EnumDecl/IdentifierDecl': t.typeName, 'EnumDecl/EnumValueDecl/IdentifierDecl': t.propertyName, @@ -54,6 +57,9 @@ export const BAMLLanguage = LRLanguage.define({ TypeAliasDecl: t.keyword, 'TypeAliasDecl/IdentifierDecl': t.typeName, + DynamicDecl: t.keyword, + 'DynamicDecl/IdentifierDecl': t.typeName, + ClientDecl: t.keyword, 'ClientDecl/IdentifierDecl': t.typeName, diff --git a/typescript/codemirror-lang-baml/src/syntax.grammar b/typescript/codemirror-lang-baml/src/syntax.grammar index 4bbdd0b88..bbd213768 100644 --- a/typescript/codemirror-lang-baml/src/syntax.grammar +++ b/typescript/codemirror-lang-baml/src/syntax.grammar @@ -1,6 +1,6 @@ @top Baml { Decl* } -Decl { ClassDecl | EnumDecl | FunctionDecl | ClientDecl | TemplateStringDecl | TestDecl | TypeAliasDecl } +Decl { ClassDecl | EnumDecl | FunctionDecl | ClientDecl | TemplateStringDecl | TestDecl | TypeAliasDecl | TypeBuilderDecl } ClassDecl { "class" IdentifierDecl "{" (BlockAttribute | ClassField)* "}" } @@ -12,6 +12,12 @@ TypeAliasDecl { "type" IdentifierDecl "=" TypeExpr } +// TODO: Can't easily disambiguate between dynamic class and dynamic enum +// definitions. Dynamic enums are highlighted as if they were classes. 
+DynamicDecl { "dynamic" IdentifierDecl "{" ClassField* "}" } + +TypeBuilderDecl { TypeBuilderKeyword "{" (ClassDecl | EnumDecl | DynamicDecl | TypeAliasDecl)* "}" } + TypeExpr { ComplexTypeExpr | UnionTypeExpr } ComplexTypeExpr { "(" UnionTypeExpr ")" | @@ -56,7 +62,7 @@ ClientDecl { } TestDecl { - "test" IdentifierDecl "{" (BlockAttribute | TupleValue)* "}" + "test" IdentifierDecl "{" (BlockAttribute | TupleValue | TypeBuilderDecl)* "}" } LiteralDecl { NumericLiteral | QuotedString | UnquotedString } @@ -71,9 +77,10 @@ TemplateStringDecl { NumericLiteral, QuotedString, UnquotedString, UnquotedAttributeValue } @precedence { - MapIdentifier, IdentifierDecl, UnquotedAttributeValue + MapIdentifier, TypeBuilderKeyword, IdentifierDecl, UnquotedAttributeValue } + TypeBuilderKeyword { "type_builder" } MapIdentifier { "map" } IdentifierDecl { $[A-Za-z0-9-_]+ } diff --git a/typescript/vscode-ext/packages/syntaxes/baml.tmLanguage.json b/typescript/vscode-ext/packages/syntaxes/baml.tmLanguage.json index aa3009544..773afabe5 100644 --- a/typescript/vscode-ext/packages/syntaxes/baml.tmLanguage.json +++ b/typescript/vscode-ext/packages/syntaxes/baml.tmLanguage.json @@ -14,7 +14,8 @@ { "include": "#function" }, { "include": "#language_block_python" }, { "include": "#language_block_ts" }, - { "include": "#language_block_jinja" } + { "include": "#language_block_jinja" }, + { "include": "#type_builder" } ] }, "comment": { @@ -77,6 +78,31 @@ { "include": "#block_attribute" } ] }, + "dynamic_declaration": { + "begin": "(dynamic)\\s+(\\w+)\\s*\\{", + "beginCaptures": { + "1": { "name": "storage.type.dynamic" }, + "2": { "name": "entity.name.type.dynamic" } + }, + "end": "\\}", + "patterns": [ + { "include": "#comment" }, + { + "comment": "Property + Type", + "begin": "(\\w+)", + "beginCaptures": { + "1": { "name": "variable.other.readwrite.interface" } + }, + "end": "(?=$|\\n|@|\\}|/)", + "patterns": [{ "include": "#type_definition" }] + }, + { + "name": "variable.other.field", + 
"match": "\\b[A-Za-z_][A-Za-z0-9_]*\\b" + }, + { "include": "#block_attribute" } + ] + }, "template_string_declaration": { "begin": "(template_string)\\s+(\\w+)", "beginCaptures": { @@ -327,6 +353,7 @@ "patterns": [ { "include": "#comment" }, { "include": "#block_attribute" }, + { "include": "#type_builder" }, { "include": "#property_assignment_expression" } ] }, @@ -731,7 +758,7 @@ "end": "(?=$|\\n)", "patterns": [ { "include": "#comment" }, - { + { "begin": "(?<=\\=)\\s*", "end": "(?=//|$|\\n)", "patterns": [ @@ -748,6 +775,20 @@ "string_literal": { "match": "\"[^\"]*\"", "name": "string.quoted.double" + }, + "type_builder": { + "begin": "(type_builder)\\s*\\{", + "beginCaptures": { + "1": { "name": "keyword.control.type_builder" } + }, + "end": "\\}", + "patterns": [ + { "include": "#interface_declaration" }, + { "include": "#enum_declaration" }, + { "include": "#dynamic_declaration" }, + { "include": "#type_alias" }, + { "include": "#comment" } + ] } }, "scopeName": "source.baml"