From 8786e3a62a43b995dc875529623400f567fbe167 Mon Sep 17 00:00:00 2001 From: Antonio Sarosi Date: Thu, 14 Nov 2024 02:51:40 +0000 Subject: [PATCH] Add `TypeAlias` walker --- engine/baml-lib/baml-core/src/ir/repr.rs | 8 ++-- .../validation_pipeline/validations/cycle.rs | 5 +-- .../validations/functions.rs | 4 +- .../validations/template_strings.rs | 1 - engine/baml-lib/parser-database/src/lib.rs | 26 ++++++++---- .../baml-lib/parser-database/src/names/mod.rs | 10 ++--- .../parser-database/src/walkers/alias.rs | 5 +++ .../parser-database/src/walkers/class.rs | 24 +++++------ .../parser-database/src/walkers/function.rs | 17 ++++---- .../parser-database/src/walkers/mod.rs | 42 ++++++++++++------- 10 files changed, 81 insertions(+), 61 deletions(-) create mode 100644 engine/baml-lib/parser-database/src/walkers/alias.rs diff --git a/engine/baml-lib/baml-core/src/ir/repr.rs b/engine/baml-lib/baml-core/src/ir/repr.rs index 8d266626e..3867af340 100644 --- a/engine/baml-lib/baml-core/src/ir/repr.rs +++ b/engine/baml-lib/baml-core/src/ir/repr.rs @@ -2,14 +2,13 @@ use std::collections::HashSet; use anyhow::{anyhow, Result}; use baml_types::{Constraint, ConstraintLevel, FieldType}; -use either::Either; use indexmap::{IndexMap, IndexSet}; use internal_baml_parser_database::{ walkers::{ ClassWalker, ClientSpec as AstClientSpec, ClientWalker, ConfigurationWalker, EnumValueWalker, EnumWalker, FieldWalker, FunctionWalker, TemplateStringWalker, }, - Attributes, ParserDatabase, PromptAst, RetryPolicyStrategy, + Attributes, ParserDatabase, PromptAst, RetryPolicyStrategy, TypeWalker, }; use internal_baml_schema_ast::ast::SubType; @@ -413,7 +412,7 @@ impl WithRepr for ast::FieldType { } ast::FieldType::Symbol(arity, idn, ..) 
=> type_with_arity( match db.find_type(idn) { - Some(Either::Left(class_walker)) => { + Some(TypeWalker::Class(class_walker)) => { let base_class = FieldType::Class(class_walker.name().to_string()); let maybe_constraints = class_walker.get_constraints(SubType::Class); match maybe_constraints { @@ -424,7 +423,7 @@ impl WithRepr for ast::FieldType { _ => base_class, } } - Some(Either::Right(enum_walker)) => { + Some(TypeWalker::Enum(enum_walker)) => { let base_type = FieldType::Enum(enum_walker.name().to_string()); let maybe_constraints = enum_walker.get_constraints(SubType::Enum); match maybe_constraints { @@ -435,6 +434,7 @@ impl WithRepr for ast::FieldType { _ => base_type, } } + Some(TypeWalker::TypeAlias(type_alias_walker)) => todo!(), None => return Err(anyhow!("Field type uses unresolvable local identifier")), }, arity, diff --git a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs index c177d1f8e..25d872143 100644 --- a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs +++ b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs @@ -1,8 +1,7 @@ use std::collections::{HashMap, HashSet}; -use either::Either; use internal_baml_diagnostics::DatamodelError; -use internal_baml_parser_database::Tarjan; +use internal_baml_parser_database::{Tarjan, TypeWalker}; use internal_baml_schema_ast::ast::{FieldType, TypeExpId, WithName, WithSpan}; use crate::validate::validation_pipeline::context::Context; @@ -67,7 +66,7 @@ fn insert_required_deps( ) { match field { FieldType::Symbol(arity, ident, _) if arity.is_required() => { - if let Some(Either::Left(class)) = ctx.db.find_type_by_str(ident.name()) { + if let Some(TypeWalker::Class(class)) = ctx.db.find_type_by_str(ident.name()) { deps.insert(class.id); } } diff --git a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/functions.rs 
b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/functions.rs index 2296e4191..a966beeba 100644 --- a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/functions.rs +++ b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/functions.rs @@ -2,9 +2,9 @@ use std::collections::HashSet; use crate::validate::validation_pipeline::context::Context; -use either::Either; use internal_baml_diagnostics::{DatamodelError, DatamodelWarning, Span}; +use internal_baml_parser_database::TypeWalker; use internal_baml_schema_ast::ast::{FieldType, TypeExpId, WithIdentifier, WithName, WithSpan}; use super::types::validate_type; @@ -246,7 +246,7 @@ impl<'c> NestedChecks<'c> { match field_type { FieldType::Symbol(_, id, ..) => match self.ctx.db.find_type(id) { - Some(Either::Left(class_walker)) => { + Some(TypeWalker::Class(class_walker)) => { // Stop recursion when dealing with recursive types. if !self.visited.insert(class_walker.id) { return false; diff --git a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/template_strings.rs b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/template_strings.rs index d8a33485c..68b6d252f 100644 --- a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/template_strings.rs +++ b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/template_strings.rs @@ -2,7 +2,6 @@ use std::collections::HashSet; use crate::validate::validation_pipeline::context::Context; -use either::Either; use internal_baml_diagnostics::{DatamodelError, DatamodelWarning, Span}; use internal_baml_schema_ast::ast::{FieldType, TypeExpId, WithIdentifier, WithName, WithSpan}; diff --git a/engine/baml-lib/parser-database/src/lib.rs b/engine/baml-lib/parser-database/src/lib.rs index bb3998589..f4fe8cf1a 100644 --- a/engine/baml-lib/parser-database/src/lib.rs +++ b/engine/baml-lib/parser-database/src/lib.rs @@ -39,7 +39,6 @@ mod types; use 
std::collections::{HashMap, HashSet}; pub use coerce_expression::{coerce, coerce_array, coerce_opt}; -use either::Either; pub use internal_baml_schema_ast::ast; use internal_baml_schema_ast::ast::SchemaAst; pub use tarjan::Tarjan; @@ -47,6 +46,7 @@ pub use types::{ Attributes, ContantDelayStrategy, ExponentialBackoffStrategy, PrinterType, PromptAst, PromptVariable, RetryPolicy, RetryPolicyStrategy, StaticType, }; +pub use walkers::TypeWalker; use self::{context::Context, interner::StringId, types::Types}; use internal_baml_diagnostics::{DatamodelError, Diagnostics}; @@ -157,8 +157,10 @@ impl ParserDatabase { let deps = HashSet::from_iter(deps.iter().filter_map( |dep| match self.find_type_by_str(dep) { - Some(Either::Left(cls)) => Some(cls.id), - Some(Either::Right(_)) => None, + Some(TypeWalker::Class(cls)) => Some(cls.id), + Some(TypeWalker::Enum(_)) => None, + // TODO: Does this interfere with recursive types? + Some(TypeWalker::TypeAlias(_)) => todo!(), None => panic!("Unknown class `{dep}`"), }, )); @@ -173,8 +175,8 @@ impl ParserDatabase { .map(|cycle| cycle.into_iter().collect()) .collect(); - // Additionally ensure the same thing for functions, but since we've already handled classes, - // this should be trivial. + // Additionally ensure the same thing for functions, but since we've + // already handled classes, this should be trivial. 
let extends = self .types .function @@ -184,8 +186,11 @@ impl ParserDatabase { let input_deps = input .iter() .filter_map(|f| match self.find_type_by_str(f) { - Some(Either::Left(walker)) => Some(walker.dependencies().iter().cloned()), - Some(Either::Right(_)) => None, + Some(TypeWalker::Class(walker)) => { + Some(walker.dependencies().iter().cloned()) + } + Some(TypeWalker::Enum(_)) => None, + Some(TypeWalker::TypeAlias(_)) => None, _ => panic!("Unknown class `{}`", f), }) .flatten() @@ -194,8 +199,11 @@ impl ParserDatabase { let output_deps = output .iter() .filter_map(|f| match self.find_type_by_str(f) { - Some(Either::Left(walker)) => Some(walker.dependencies().iter().cloned()), - Some(Either::Right(_)) => None, + Some(TypeWalker::Class(walker)) => { + Some(walker.dependencies().iter().cloned()) + } + Some(TypeWalker::Enum(_)) => None, + Some(TypeWalker::TypeAlias(_)) => todo!(), _ => panic!("Unknown class `{}`", f), }) .flatten() diff --git a/engine/baml-lib/parser-database/src/names/mod.rs b/engine/baml-lib/parser-database/src/names/mod.rs index a1a5ef4e9..ada725b22 100644 --- a/engine/baml-lib/parser-database/src/names/mod.rs +++ b/engine/baml-lib/parser-database/src/names/mod.rs @@ -33,7 +33,7 @@ pub(super) struct Names { /// - Generators /// - Model fields for each model pub(super) fn resolve_names(ctx: &mut Context<'_>) { - let mut enum_value_names: HashSet<&str> = HashSet::default(); // throwaway container for duplicate checking + let mut tmp_names: HashSet<&str> = HashSet::default(); // throwaway container for duplicate checking let mut names = Names::default(); for (top_id, top) in ctx.ast.iter_tops() { @@ -41,7 +41,7 @@ pub(super) fn resolve_names(ctx: &mut Context<'_>) { let namespace = match (top_id, top) { (_, ast::Top::Enum(ast_enum)) => { - enum_value_names.clear(); + tmp_names.clear(); validate_enum_name(ast_enum, ctx.diagnostics); validate_attribute_identifiers(ast_enum, ctx); @@ -50,7 +50,7 @@ pub(super) fn resolve_names(ctx: &mut Context<'_>) 
{ validate_attribute_identifiers(value, ctx); - if !enum_value_names.insert(value.name()) { + if !tmp_names.insert(value.name()) { ctx.push_error(DatamodelError::new_duplicate_enum_value_error( ast_enum.name.name(), value.name(), @@ -147,13 +147,13 @@ pub(super) fn resolve_names(ctx: &mut Context<'_>) { (_, ast::Top::Generator(generator)) => { validate_generator_name(generator, ctx.diagnostics); - check_for_duplicate_properties(top, generator.fields(), &mut enum_value_names, ctx); + check_for_duplicate_properties(top, generator.fields(), &mut tmp_names, ctx); Some(either::Left(&mut names.generators)) } (ast::TopId::TestCase(testcase_id), ast::Top::TestCase(testcase)) => { validate_test(testcase, ctx.diagnostics); - check_for_duplicate_properties(top, testcase.fields(), &mut enum_value_names, ctx); + check_for_duplicate_properties(top, testcase.fields(), &mut tmp_names, ctx); // TODO: I think we should do this later after all parsing, as duplication // would work best as a validation error with walkers. diff --git a/engine/baml-lib/parser-database/src/walkers/alias.rs b/engine/baml-lib/parser-database/src/walkers/alias.rs new file mode 100644 index 000000000..832862335 --- /dev/null +++ b/engine/baml-lib/parser-database/src/walkers/alias.rs @@ -0,0 +1,5 @@ +use super::TypeWalker; +use internal_baml_schema_ast::ast::{self, Identifier}; + +/// A type alias declaration in the BAML schema.
+pub type TypeAliasWalker<'db> = super::Walker<'db, ast::TypeExpId>; diff --git a/engine/baml-lib/parser-database/src/walkers/class.rs b/engine/baml-lib/parser-database/src/walkers/class.rs index 5dd2500b9..b9f707aa3 100644 --- a/engine/baml-lib/parser-database/src/walkers/class.rs +++ b/engine/baml-lib/parser-database/src/walkers/class.rs @@ -1,5 +1,6 @@ use std::collections::HashSet; +use super::TypeWalker; use super::{field::FieldWalker, EnumWalker}; use crate::types::Attributes; use baml_types::Constraint; @@ -42,9 +43,8 @@ impl<'db> ClassWalker<'db> { self.db.types.class_dependencies[&self.class_id()] .iter() .filter_map(|f| match self.db.find_type_by_str(f) { - Some(Either::Left(_cls)) => None, - Some(Either::Right(walker)) => Some(walker), - None => None, + Some(TypeWalker::Enum(walker)) => Some(walker), + _ => None, }) } @@ -53,9 +53,8 @@ impl<'db> ClassWalker<'db> { self.db.types.class_dependencies[&self.class_id()] .iter() .filter_map(|f| match self.db.find_type_by_str(f) { - Some(Either::Left(walker)) => Some(walker), - Some(Either::Right(_enm)) => None, - None => None, + Some(TypeWalker::Class(walker)) => Some(walker), + _ => None, }) } @@ -92,7 +91,8 @@ impl<'db> ClassWalker<'db> { /// Get the constraints of a class or an enum. pub fn get_constraints(&self, sub_type: SubType) -> Option> { - self.get_default_attributes(sub_type).map(|attrs| attrs.constraints.clone()) + self.get_default_attributes(sub_type) + .map(|attrs| attrs.constraints.clone()) } /// Arguments of the function. 
@@ -166,9 +166,8 @@ impl<'db> ArgWalker<'db> { input .iter() .filter_map(|f| match self.db.find_type_by_str(f) { - Some(Either::Left(_cls)) => None, - Some(Either::Right(walker)) => Some(walker), - None => None, + Some(TypeWalker::Enum(walker)) => Some(walker), + _ => None, }) } @@ -178,9 +177,8 @@ impl<'db> ArgWalker<'db> { input .iter() .filter_map(|f| match self.db.find_type_by_str(f) { - Some(Either::Left(walker)) => Some(walker), - Some(Either::Right(_enm)) => None, - None => None, + Some(TypeWalker::Class(walker)) => Some(walker), + _ => None, }) } } diff --git a/engine/baml-lib/parser-database/src/walkers/function.rs b/engine/baml-lib/parser-database/src/walkers/function.rs index 1627cfc81..eb4b96978 100644 --- a/engine/baml-lib/parser-database/src/walkers/function.rs +++ b/engine/baml-lib/parser-database/src/walkers/function.rs @@ -7,7 +7,7 @@ use crate::{ types::FunctionType, }; -use super::{ClassWalker, ConfigurationWalker, EnumWalker, Walker}; +use super::{ClassWalker, ConfigurationWalker, EnumWalker, TypeWalker, Walker}; use std::iter::ExactSizeIterator; @@ -143,7 +143,10 @@ impl<'db> FunctionWalker<'db> { match client.0.split_once("/") { // TODO: do this in a more robust way // actually validate which clients are and aren't allowed - Some((provider, model)) => Ok(ClientSpec::Shorthand(provider.to_string(), model.to_string())), + Some((provider, model)) => Ok(ClientSpec::Shorthand( + provider.to_string(), + model.to_string(), + )), None => match self.db.find_client(client.0.as_str()) { Some(client) => Ok(ClientSpec::Named(client.name().to_string())), None => { @@ -217,9 +220,8 @@ impl<'db> ArgWalker<'db> { if self.id.1 { input } else { output } .iter() .filter_map(|f| match self.db.find_type_by_str(f) { - Some(Either::Left(_cls)) => None, - Some(Either::Right(walker)) => Some(walker), - None => None, + Some(TypeWalker::Enum(walker)) => Some(walker), + _ => None, }) } @@ -229,9 +231,8 @@ impl<'db> ArgWalker<'db> { if self.id.1 { input } else { output } 
.iter() .filter_map(|f| match self.db.find_type_by_str(f) { - Some(Either::Left(walker)) => Some(walker), - Some(Either::Right(_enm)) => None, - None => None, + Some(TypeWalker::Class(walker)) => Some(walker), + _ => None, }) } } diff --git a/engine/baml-lib/parser-database/src/walkers/mod.rs b/engine/baml-lib/parser-database/src/walkers/mod.rs index 9d1fd2a92..6f9d5abff 100644 --- a/engine/baml-lib/parser-database/src/walkers/mod.rs +++ b/engine/baml-lib/parser-database/src/walkers/mod.rs @@ -6,6 +6,7 @@ //! - Know about relations. //! - Do not know anything about connectors, they are generic. +mod alias; mod r#class; mod client; mod configuration; @@ -14,16 +15,17 @@ mod field; mod function; mod template_string; +use alias::TypeAliasWalker; use baml_types::TypeValue; pub use client::*; pub use configuration::*; use either::Either; pub use field::*; -pub use function::{FunctionWalker, ClientSpec}; -pub use template_string::TemplateStringWalker; +pub use function::{ClientSpec, FunctionWalker}; use internal_baml_schema_ast::ast::{FieldType, Identifier, TopId, TypeExpId, WithName}; pub use r#class::*; pub use r#enum::*; +pub use template_string::TemplateStringWalker; /// A generic walker. Only walkers intantiated with a concrete ID type (`I`) are useful. #[derive(Clone, Copy)] @@ -50,11 +52,21 @@ where } } +/// Walker kind. +pub enum TypeWalker<'db> { + /// Class walker. + Class(ClassWalker<'db>), + /// Enum walker. + Enum(EnumWalker<'db>), + /// Type alias walker. + TypeAlias(TypeAliasWalker<'db>), +} + impl<'db> crate::ParserDatabase { /// Find an enum by name. pub fn find_enum(&'db self, idn: &Identifier) -> Option> { self.find_type(idn).and_then(|either| match either { - Either::Right(class) => Some(class), + TypeWalker::Enum(enm) => Some(enm), _ => None, }) } @@ -66,22 +78,19 @@ impl<'db> crate::ParserDatabase { } /// Find a type by name. 
- pub fn find_type_by_str( - &'db self, - name: &str, - ) -> Option, EnumWalker<'db>>> { + pub fn find_type_by_str(&'db self, name: &str) -> Option> { self.find_top_by_str(name).and_then(|top_id| match top_id { - TopId::Class(class_id) => Some(Either::Left(self.walk(*class_id))), - TopId::Enum(enum_id) => Some(Either::Right(self.walk(*enum_id))), + TopId::Class(class_id) => Some(TypeWalker::Class(self.walk(*class_id))), + TopId::Enum(enum_id) => Some(TypeWalker::Enum(self.walk(*enum_id))), + TopId::TypeAlias(type_alias_id) => { + Some(TypeWalker::TypeAlias(self.walk(*type_alias_id))) + } _ => None, }) } /// Find a type by name. - pub fn find_type( - &'db self, - idn: &Identifier, - ) -> Option, EnumWalker<'db>>> { + pub fn find_type(&'db self, idn: &Identifier) -> Option> { match idn { Identifier::Local(local, _) => self.find_type_by_str(local), _ => None, @@ -91,7 +100,7 @@ impl<'db> crate::ParserDatabase { /// Find a model by name. pub fn find_class(&'db self, idn: &Identifier) -> Option> { self.find_type(idn).and_then(|either| match either { - Either::Left(class) => Some(class), + TypeWalker::Class(class) => Some(class), _ => None, }) } @@ -255,8 +264,9 @@ impl<'db> crate::ParserDatabase { FieldType::Symbol(arity, idn, ..) => { let mut t = match self.find_type(idn) { None => Type::Undefined, - Some(Either::Left(_)) => Type::ClassRef(idn.to_string()), - Some(Either::Right(_)) => Type::String, + Some(TypeWalker::Class(_)) => Type::ClassRef(idn.to_string()), + Some(TypeWalker::Enum(_)) => Type::String, + Some(TypeWalker::TypeAlias(_)) => Type::String, }; if arity.is_optional() { t = Type::None | t;