From 93b393ded048817fdb7ffef65cb698f9edb14764 Mon Sep 17 00:00:00 2001 From: Greg Hale Date: Fri, 22 Nov 2024 15:41:01 -0800 Subject: [PATCH] Validate fieldnames and types when using pydantic codegen (#1189) Python and pydantic do not allow arbitrary identifiers to be used as fields in classes. This PR adds checks to the BAML grammar, which run conditionally when the user includes a python/pydantic code generator block: - field names must not be Python keywords. - field names must not be lexicographically equal to the field type, or the base of an optional type. E.g. rule 1: ```python # Not ok class Foo(BaseModel): if string ``` E.g. rule 2: ```python class ETA(BaseModel): time: string # Not ok class Foo(BaseModel): ETA: ETA ``` These rules are now checked during validation of the syntax tree prior to construction of the IR, and if they are violated we push an error to `Diagnostics`. Bonus: There are a few changes in the PR not related to the issue - they are little cleanups to reduce the number of unnecessary `rustc` warnings. ---- > [!IMPORTANT] > Add validation for field names in BAML classes to prevent Python keyword and type name conflicts when using Pydantic code generation. > > - **Validation**: > - Add `assert_no_field_name_collisions()` in `classes.rs` to check field names against Python keywords and type names when using Pydantic. > - Use `reserved_names()` to map keywords to target languages. > - **Diagnostics**: > - Update `new_field_validation_error()` in `error.rs` to accept `String` for error messages. > - **Miscellaneous**: > - Remove unused code and features in `lib.rs` and `build.rs` to reduce rustc warnings. > - Add tests `generator_keywords1.baml` and `generator_keywords2.baml` to validate new rules. > > This description was created by [Ellipsis](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral) for 49d31fb6f91433530e8f67eadeec8994f86616f3. It will automatically update as commits are pushed.
--- .../validation_pipeline/validations.rs | 13 ++ .../validations/classes.rs | 117 +++++++++++++++++- engine/baml-lib/baml-types/src/generator.rs | 9 ++ .../class/generator_keywords1.baml | 29 +++++ .../class/generator_keywords2.baml | 17 +++ engine/baml-lib/diagnostics/src/error.rs | 2 +- engine/baml-lib/parser-database/build.rs | 3 - engine/baml-runtime/src/lib.rs | 9 +- engine/baml-runtime/src/macros.rs | 9 -- .../baml-schema-wasm/src/runtime_wasm/mod.rs | 6 +- .../src/runtime_wasm/runtime_prompt.rs | 5 +- 11 files changed, 190 insertions(+), 29 deletions(-) create mode 100644 engine/baml-lib/baml/tests/validation_files/class/generator_keywords1.baml create mode 100644 engine/baml-lib/baml/tests/validation_files/class/generator_keywords2.baml delete mode 100644 engine/baml-runtime/src/macros.rs diff --git a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations.rs b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations.rs index b595c7f96..c5c6acb3f 100644 --- a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations.rs +++ b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations.rs @@ -7,8 +7,14 @@ mod functions; mod template_strings; mod types; +use baml_types::GeneratorOutputType; + +use crate::{configuration::Generator, validate::generator_loader::load_generators_from_ast}; + use super::context::Context; +use std::collections::HashSet; + pub(super) fn validate(ctx: &mut Context<'_>) { enums::validate(ctx); classes::validate(ctx); @@ -17,6 +23,13 @@ pub(super) fn validate(ctx: &mut Context<'_>) { template_strings::validate(ctx); configurations::validate(ctx); + let generators = load_generators_from_ast(ctx.db.ast(), ctx.diagnostics); + let codegen_targets: HashSet = generators.into_iter().filter_map(|generator| match generator { + Generator::Codegen(gen) => Some(gen.output_type), + Generator::BoundaryCloud(_) => None + }).collect::>(); + classes::assert_no_field_name_collisions(ctx, 
&codegen_targets); + if !ctx.diagnostics.has_errors() { cycle::validate(ctx); } diff --git a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/classes.rs b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/classes.rs index 43f08f0f5..4e1301386 100644 --- a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/classes.rs +++ b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/classes.rs @@ -1,9 +1,13 @@ -use internal_baml_schema_ast::ast::{WithName, WithSpan}; +use baml_types::GeneratorOutputType; +use internal_baml_schema_ast::ast::{Field, FieldType, WithName, WithSpan}; use super::types::validate_type; use crate::validate::validation_pipeline::context::Context; use internal_baml_diagnostics::DatamodelError; +use itertools::join; +use std::collections::{HashMap, HashSet}; + pub(super) fn validate(ctx: &mut Context<'_>) { let mut defined_types = internal_baml_jinja_types::PredefinedTypes::default( internal_baml_jinja_types::JinjaContext::Prompt, @@ -45,3 +49,114 @@ pub(super) fn validate(ctx: &mut Context<'_>) { defined_types.errors_mut().clear(); } } + +/// Enforce that keywords in the user's requested target languages +/// do not appear as field names in BAML classes, and that field +/// names are not equal to type names when using Pydantic. +pub(super) fn assert_no_field_name_collisions( + ctx: &mut Context<'_>, + generator_output_types: &HashSet, +) { + // The list of reserved words for all user-requested codegen targets. + let reserved = reserved_names(generator_output_types); + + for cls in ctx.db.walk_classes() { + for c in cls.static_fields() { + let field: &Field = c.ast_field(); + + // Check for keyword in field name. 
+ if let Some(langs) = reserved.get(field.name()) { + let msg = match langs.as_slice() { + [lang] => format!("Field name is a reserved word in generated {lang} clients."), + _ => format!( + "Field name is a reserved word in language clients: {}.", + join(langs, ", ") + ), + }; + ctx.push_error(DatamodelError::new_field_validation_error( + msg, + "class", + c.name(), + field.name(), + field.span.clone(), + )) + } + + // Check for collision between field name and type name when using Pydantic. + if generator_output_types.contains(&GeneratorOutputType::PythonPydantic) { + let type_name = field + .expr + .as_ref() + .map_or("".to_string(), |r#type| r#type.name()); + if field.name() == type_name { + ctx.push_error(DatamodelError::new_field_validation_error( + "When using the python/pydantic generator, a field name must not be exactly equal to the type name. Consider changing the field name and using an alias.".to_string(), + "class", + c.name(), + field.name(), + field.span.clone() + )) + } + } + } + } +} + +/// For a given set of target languages, construct a map from keyword to the +/// list of target languages in which that identifier is a keyword. +/// +/// This will be used later to make error messages like, "Could not use name +/// `continue` becase that is a keyword in Python", "Could not use the name +/// `return` because that is a keyword in Python and Typescript". 
+fn reserved_names( + generator_output_types: &HashSet, +) -> HashMap<&'static str, Vec> { + let mut keywords: HashMap<&str, Vec> = HashMap::new(); + + let language_keywords: Vec<(&str, GeneratorOutputType)> = [ + if generator_output_types.contains(&GeneratorOutputType::PythonPydantic) { + RESERVED_NAMES_PYTHON + .into_iter() + .map(|name| (*name, GeneratorOutputType::PythonPydantic)) + .collect() + } else { + Vec::new() + }, + if generator_output_types.contains(&GeneratorOutputType::Typescript) { + RESERVED_NAMES_TYPESCRIPT + .into_iter() + .map(|name| (*name, GeneratorOutputType::Typescript)) + .collect() + } else { + Vec::new() + }, + ] + .iter() + .flatten() + .cloned() + .collect(); + + language_keywords + .into_iter() + .for_each(|(keyword, generator_output_type)| { + keywords + .entry(keyword) + .and_modify(|types| (*types).push(generator_output_type)) + .or_insert(vec![generator_output_type]); + }); + + keywords +} + +// This list of keywords was copied from +// https://www.w3schools.com/python/python_ref_keywords.asp +// . +const RESERVED_NAMES_PYTHON: &[&str] = &[ + "False", "None", "True", "and", "as", "assert", "async", "await", "break", "class", "continue", + "def", "del", "elif", "else", "except", "finally", "for", "from", "global", "if", "import", + "in", "is", "lambda", "nonlocal", "not", "or", "pass", "raise", "return", "try", "while", + "with", "yield", +]; + +// Typescript is much more flexible in the key names it allows. 
+const RESERVED_NAMES_TYPESCRIPT: &[&str] = &[]; diff --git a/engine/baml-lib/baml-types/src/generator.rs b/engine/baml-lib/baml-types/src/generator.rs index 3f78d6c49..51897e43e 100644 --- a/engine/baml-lib/baml-types/src/generator.rs +++ b/engine/baml-lib/baml-types/src/generator.rs @@ -8,6 +8,8 @@ strum::VariantArray, strum::VariantNames, )] + +#[derive(PartialEq, Eq)] pub enum GeneratorOutputType { #[strum(serialize = "rest/openapi")] OpenApi, @@ -22,6 +24,13 @@ pub enum GeneratorOutputType { RubySorbet, } +impl std::hash::Hash for GeneratorOutputType { + fn hash(&self, state: &mut H) { + core::mem::discriminant(self).hash(state); + } +} + + impl GeneratorOutputType { pub fn default_client_mode(&self) -> GeneratorDefaultClientMode { match self { diff --git a/engine/baml-lib/baml/tests/validation_files/class/generator_keywords1.baml b/engine/baml-lib/baml/tests/validation_files/class/generator_keywords1.baml new file mode 100644 index 000000000..4547adcd9 --- /dev/null +++ b/engine/baml-lib/baml/tests/validation_files/class/generator_keywords1.baml @@ -0,0 +1,29 @@ +generator lang_python { + output_type python/pydantic + output_dir "../python" + version "0.68.0" +} + +class ETA { + thing string +} + +class Foo { + if string + ETA ETA? +} + +// error: Error validating field `if` in class `if`: Field name is a reserved word in generated python/pydantic clients. +// --> class/generator_keywords1.baml:12 +// | +// 11 | class Foo { +// 12 | if string +// 13 | ETA ETA? +// | +// error: Error validating field `ETA` in class `ETA`: When using the python/pydantic generator, a field name must not be exactly equal to the type name. Consider changing the field name and using an alias. +// --> class/generator_keywords1.baml:13 +// | +// 12 | if string +// 13 | ETA ETA? 
+// 14 | } +// | diff --git a/engine/baml-lib/baml/tests/validation_files/class/generator_keywords2.baml b/engine/baml-lib/baml/tests/validation_files/class/generator_keywords2.baml new file mode 100644 index 000000000..5b2c8eeb7 --- /dev/null +++ b/engine/baml-lib/baml/tests/validation_files/class/generator_keywords2.baml @@ -0,0 +1,17 @@ +// This file is just like generator_keywords1.baml, except that the fieldname +// has been changed in order to not collide with the field type `ETA`, and an +// alias is used to render that field as `ETA` in prompts. + +generator lang_python { + output_type python/pydantic + output_dir "../python" + version "0.68.0" +} + +class ETA { + thing string +} + +class Foo { + eta ETA? @alias("ETA") +} \ No newline at end of file diff --git a/engine/baml-lib/diagnostics/src/error.rs b/engine/baml-lib/diagnostics/src/error.rs index e1f1386ce..e79830925 100644 --- a/engine/baml-lib/diagnostics/src/error.rs +++ b/engine/baml-lib/diagnostics/src/error.rs @@ -346,7 +346,7 @@ impl DatamodelError { } pub fn new_field_validation_error( - message: &str, + message: String, container_type: &str, container_name: &str, field: &str, diff --git a/engine/baml-lib/parser-database/build.rs b/engine/baml-lib/parser-database/build.rs index 52355c4a9..f79c691f0 100644 --- a/engine/baml-lib/parser-database/build.rs +++ b/engine/baml-lib/parser-database/build.rs @@ -1,5 +1,2 @@ fn main() { - // If you have an existing build.rs file, just add this line to it. - #[cfg(feature = "use-pyo3")] - pyo3_build_config::use_pyo3_cfgs(); } diff --git a/engine/baml-runtime/src/lib.rs b/engine/baml-runtime/src/lib.rs index e72536a1e..d4659ead4 100644 --- a/engine/baml-runtime/src/lib.rs +++ b/engine/baml-runtime/src/lib.rs @@ -1,10 +1,4 @@ -#[cfg(all(test, feature = "no_wasm"))] -mod tests; - -// #[cfg(all(feature = "wasm", feature = "no_wasm"))] -// compile_error!( -// "The features 'wasm' and 'no_wasm' are mutually exclusive. You can only use one at a time." 
-// ); +// mod tests; #[cfg(feature = "internal")] pub mod internal; @@ -15,7 +9,6 @@ pub(crate) mod internal; pub mod cli; pub mod client_registry; pub mod errors; -mod macros; pub mod request; mod runtime; pub mod runtime_interface; diff --git a/engine/baml-runtime/src/macros.rs b/engine/baml-runtime/src/macros.rs deleted file mode 100644 index 4c9dc6f8f..000000000 --- a/engine/baml-runtime/src/macros.rs +++ /dev/null @@ -1,9 +0,0 @@ -#[macro_use] -macro_rules! internal_feature { - () => { - #[cfg(feature = "internal")] - { pub } - #[cfg(not(feature = "internal"))] - { pub(crate) } - }; -} diff --git a/engine/baml-schema-wasm/src/runtime_wasm/mod.rs b/engine/baml-schema-wasm/src/runtime_wasm/mod.rs index de6f8678c..1e72eb67b 100644 --- a/engine/baml-schema-wasm/src/runtime_wasm/mod.rs +++ b/engine/baml-schema-wasm/src/runtime_wasm/mod.rs @@ -1232,7 +1232,7 @@ impl WasmRuntime { if span.file_path.as_str().ends_with(file_name) && ((span.start + 1)..=(span.end + 1)).contains(&cursor_idx) { - if let Some(parent_function) = + if let Some(_parent_function) = tc.parent_functions.iter().find(|f| f.name == selected_func) { return functions.into_iter().find(|f| f.name == selected_func); @@ -1251,7 +1251,7 @@ impl WasmRuntime { if span.file_path.as_str().ends_with(file_name) && ((span.start + 1)..=(span.end + 1)).contains(&cursor_idx) { - if let Some(parent_function) = + if let Some(_parent_function) = tc.parent_functions.iter().find(|f| f.name == selected_func) { return functions.into_iter().find(|f| f.name == selected_func); @@ -1441,7 +1441,7 @@ fn js_fn_to_baml_src_reader(get_baml_src_cb: js_sys::Function) -> BamlSrcReader } #[wasm_bindgen] -struct WasmCallContext { +pub struct WasmCallContext { /// Index of the orchestration graph node to use for the call /// Defaults to 0 when unset node_index: Option, diff --git a/engine/baml-schema-wasm/src/runtime_wasm/runtime_prompt.rs b/engine/baml-schema-wasm/src/runtime_wasm/runtime_prompt.rs index b0cd13e68..c193ac890 100644 
--- a/engine/baml-schema-wasm/src/runtime_wasm/runtime_prompt.rs +++ b/engine/baml-schema-wasm/src/runtime_wasm/runtime_prompt.rs @@ -7,16 +7,13 @@ use baml_runtime::{ }, ChatMessagePart, RenderedPrompt, }; -use serde::Serialize; use serde_json::json; use crate::runtime_wasm::ToJsValue; -use baml_types::{BamlMedia, BamlMediaContent, BamlMediaType, MediaBase64}; +use baml_types::{BamlMediaContent, BamlMediaType, MediaBase64}; use serde_wasm_bindgen::to_value; use wasm_bindgen::prelude::*; -use super::WasmFunction; - #[wasm_bindgen(getter_with_clone)] pub struct WasmScope { scope: OrchestrationScope,