Skip to content

Commit

Permalink
Add TypeAlias walker
Browse files Browse the repository at this point in the history
  • Loading branch information
antoniosarosi committed Nov 14, 2024
1 parent ad299e6 commit 8786e3a
Show file tree
Hide file tree
Showing 10 changed files with 81 additions and 61 deletions.
8 changes: 4 additions & 4 deletions engine/baml-lib/baml-core/src/ir/repr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@ use std::collections::HashSet;

use anyhow::{anyhow, Result};
use baml_types::{Constraint, ConstraintLevel, FieldType};
use either::Either;
use indexmap::{IndexMap, IndexSet};
use internal_baml_parser_database::{
walkers::{
ClassWalker, ClientSpec as AstClientSpec, ClientWalker, ConfigurationWalker,
EnumValueWalker, EnumWalker, FieldWalker, FunctionWalker, TemplateStringWalker,
},
Attributes, ParserDatabase, PromptAst, RetryPolicyStrategy,
Attributes, ParserDatabase, PromptAst, RetryPolicyStrategy, TypeWalker,
};
use internal_baml_schema_ast::ast::SubType;

Expand Down Expand Up @@ -413,7 +412,7 @@ impl WithRepr<FieldType> for ast::FieldType {
}
ast::FieldType::Symbol(arity, idn, ..) => type_with_arity(
match db.find_type(idn) {
Some(Either::Left(class_walker)) => {
Some(TypeWalker::Class(class_walker)) => {
let base_class = FieldType::Class(class_walker.name().to_string());
let maybe_constraints = class_walker.get_constraints(SubType::Class);
match maybe_constraints {
Expand All @@ -424,7 +423,7 @@ impl WithRepr<FieldType> for ast::FieldType {
_ => base_class,
}
}
Some(Either::Right(enum_walker)) => {
Some(TypeWalker::Enum(enum_walker)) => {
let base_type = FieldType::Enum(enum_walker.name().to_string());
let maybe_constraints = enum_walker.get_constraints(SubType::Enum);
match maybe_constraints {
Expand All @@ -435,6 +434,7 @@ impl WithRepr<FieldType> for ast::FieldType {
_ => base_type,
}
}
Some(TypeWalker::TypeAlias(type_alias_walker)) => todo!(),
None => return Err(anyhow!("Field type uses unresolvable local identifier")),
},
arity,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use std::collections::{HashMap, HashSet};

use either::Either;
use internal_baml_diagnostics::DatamodelError;
use internal_baml_parser_database::Tarjan;
use internal_baml_parser_database::{Tarjan, TypeWalker};
use internal_baml_schema_ast::ast::{FieldType, TypeExpId, WithName, WithSpan};

use crate::validate::validation_pipeline::context::Context;
Expand Down Expand Up @@ -67,7 +66,7 @@ fn insert_required_deps(
) {
match field {
FieldType::Symbol(arity, ident, _) if arity.is_required() => {
if let Some(Either::Left(class)) = ctx.db.find_type_by_str(ident.name()) {
if let Some(TypeWalker::Class(class)) = ctx.db.find_type_by_str(ident.name()) {
deps.insert(class.id);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ use std::collections::HashSet;

use crate::validate::validation_pipeline::context::Context;

use either::Either;
use internal_baml_diagnostics::{DatamodelError, DatamodelWarning, Span};

use internal_baml_parser_database::TypeWalker;
use internal_baml_schema_ast::ast::{FieldType, TypeExpId, WithIdentifier, WithName, WithSpan};

use super::types::validate_type;
Expand Down Expand Up @@ -246,7 +246,7 @@ impl<'c> NestedChecks<'c> {

match field_type {
FieldType::Symbol(_, id, ..) => match self.ctx.db.find_type(id) {
Some(Either::Left(class_walker)) => {
Some(TypeWalker::Class(class_walker)) => {
// Stop recursion when dealing with recursive types.
if !self.visited.insert(class_walker.id) {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use std::collections::HashSet;

use crate::validate::validation_pipeline::context::Context;

use either::Either;
use internal_baml_diagnostics::{DatamodelError, DatamodelWarning, Span};

use internal_baml_schema_ast::ast::{FieldType, TypeExpId, WithIdentifier, WithName, WithSpan};
Expand Down
26 changes: 17 additions & 9 deletions engine/baml-lib/parser-database/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@ mod types;
use std::collections::{HashMap, HashSet};

pub use coerce_expression::{coerce, coerce_array, coerce_opt};
use either::Either;
pub use internal_baml_schema_ast::ast;
use internal_baml_schema_ast::ast::SchemaAst;
pub use tarjan::Tarjan;
pub use types::{
Attributes, ContantDelayStrategy, ExponentialBackoffStrategy, PrinterType, PromptAst,
PromptVariable, RetryPolicy, RetryPolicyStrategy, StaticType,
};
pub use walkers::TypeWalker;

use self::{context::Context, interner::StringId, types::Types};
use internal_baml_diagnostics::{DatamodelError, Diagnostics};
Expand Down Expand Up @@ -157,8 +157,10 @@ impl ParserDatabase {
let deps =
HashSet::from_iter(deps.iter().filter_map(
|dep| match self.find_type_by_str(dep) {
Some(Either::Left(cls)) => Some(cls.id),
Some(Either::Right(_)) => None,
Some(TypeWalker::Class(cls)) => Some(cls.id),
Some(TypeWalker::Enum(_)) => None,
// TODO: Does this interfere with recursive types?
Some(TypeWalker::TypeAlias(_)) => todo!(),
None => panic!("Unknown class `{dep}`"),
},
));
Expand All @@ -173,8 +175,8 @@ impl ParserDatabase {
.map(|cycle| cycle.into_iter().collect())
.collect();

// Additionally ensure the same thing for functions, but since we've already handled classes,
// this should be trivial.
// Additionally ensure the same thing for functions, but since we've
// already handled classes, this should be trivial.
let extends = self
.types
.function
Expand All @@ -184,8 +186,11 @@ impl ParserDatabase {
let input_deps = input
.iter()
.filter_map(|f| match self.find_type_by_str(f) {
Some(Either::Left(walker)) => Some(walker.dependencies().iter().cloned()),
Some(Either::Right(_)) => None,
Some(TypeWalker::Class(walker)) => {
Some(walker.dependencies().iter().cloned())
}
Some(TypeWalker::Enum(_)) => None,
Some(TypeWalker::TypeAlias(_)) => None,
_ => panic!("Unknown class `{}`", f),
})
.flatten()
Expand All @@ -194,8 +199,11 @@ impl ParserDatabase {
let output_deps = output
.iter()
.filter_map(|f| match self.find_type_by_str(f) {
Some(Either::Left(walker)) => Some(walker.dependencies().iter().cloned()),
Some(Either::Right(_)) => None,
Some(TypeWalker::Class(walker)) => {
Some(walker.dependencies().iter().cloned())
}
Some(TypeWalker::Enum(_)) => None,
Some(TypeWalker::TypeAlias(_)) => todo!(),
_ => panic!("Unknown class `{}`", f),
})
.flatten()
Expand Down
10 changes: 5 additions & 5 deletions engine/baml-lib/parser-database/src/names/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,15 @@ pub(super) struct Names {
/// - Generators
/// - Model fields for each model
pub(super) fn resolve_names(ctx: &mut Context<'_>) {
let mut enum_value_names: HashSet<&str> = HashSet::default(); // throwaway container for duplicate checking
let mut tmp_names: HashSet<&str> = HashSet::default(); // throwaway container for duplicate checking
let mut names = Names::default();

for (top_id, top) in ctx.ast.iter_tops() {
assert_is_not_a_reserved_scalar_type(top.identifier(), ctx);

let namespace = match (top_id, top) {
(_, ast::Top::Enum(ast_enum)) => {
enum_value_names.clear();
tmp_names.clear();
validate_enum_name(ast_enum, ctx.diagnostics);
validate_attribute_identifiers(ast_enum, ctx);

Expand All @@ -50,7 +50,7 @@ pub(super) fn resolve_names(ctx: &mut Context<'_>) {

validate_attribute_identifiers(value, ctx);

if !enum_value_names.insert(value.name()) {
if !tmp_names.insert(value.name()) {
ctx.push_error(DatamodelError::new_duplicate_enum_value_error(
ast_enum.name.name(),
value.name(),
Expand Down Expand Up @@ -147,13 +147,13 @@ pub(super) fn resolve_names(ctx: &mut Context<'_>) {

(_, ast::Top::Generator(generator)) => {
validate_generator_name(generator, ctx.diagnostics);
check_for_duplicate_properties(top, generator.fields(), &mut enum_value_names, ctx);
check_for_duplicate_properties(top, generator.fields(), &mut tmp_names, ctx);
Some(either::Left(&mut names.generators))
}

(ast::TopId::TestCase(testcase_id), ast::Top::TestCase(testcase)) => {
validate_test(testcase, ctx.diagnostics);
check_for_duplicate_properties(top, testcase.fields(), &mut enum_value_names, ctx);
check_for_duplicate_properties(top, testcase.fields(), &mut tmp_names, ctx);

// TODO: I think we should do this later after all parsing, as duplication
// would work best as a validation error with walkers.
Expand Down
5 changes: 5 additions & 0 deletions engine/baml-lib/parser-database/src/walkers/alias.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
use super::TypeWalker;
use internal_baml_schema_ast::ast::{self, Identifier};

/// A `class` declaration in the Prisma schema.
pub type TypeAliasWalker<'db> = super::Walker<'db, ast::TypeExpId>;
24 changes: 11 additions & 13 deletions engine/baml-lib/parser-database/src/walkers/class.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::collections::HashSet;

use super::TypeWalker;
use super::{field::FieldWalker, EnumWalker};
use crate::types::Attributes;
use baml_types::Constraint;
Expand Down Expand Up @@ -42,9 +43,8 @@ impl<'db> ClassWalker<'db> {
self.db.types.class_dependencies[&self.class_id()]
.iter()
.filter_map(|f| match self.db.find_type_by_str(f) {
Some(Either::Left(_cls)) => None,
Some(Either::Right(walker)) => Some(walker),
None => None,
Some(TypeWalker::Enum(walker)) => Some(walker),
_ => None,
})
}

Expand All @@ -53,9 +53,8 @@ impl<'db> ClassWalker<'db> {
self.db.types.class_dependencies[&self.class_id()]
.iter()
.filter_map(|f| match self.db.find_type_by_str(f) {
Some(Either::Left(walker)) => Some(walker),
Some(Either::Right(_enm)) => None,
None => None,
Some(TypeWalker::Class(walker)) => Some(walker),
_ => None,
})
}

Expand Down Expand Up @@ -92,7 +91,8 @@ impl<'db> ClassWalker<'db> {

/// Get the constraints of a class or an enum.
pub fn get_constraints(&self, sub_type: SubType) -> Option<Vec<Constraint>> {
self.get_default_attributes(sub_type).map(|attrs| attrs.constraints.clone())
self.get_default_attributes(sub_type)
.map(|attrs| attrs.constraints.clone())
}

/// Arguments of the function.
Expand Down Expand Up @@ -166,9 +166,8 @@ impl<'db> ArgWalker<'db> {
input
.iter()
.filter_map(|f| match self.db.find_type_by_str(f) {
Some(Either::Left(_cls)) => None,
Some(Either::Right(walker)) => Some(walker),
None => None,
Some(TypeWalker::Enum(walker)) => Some(walker),
_ => None,
})
}

Expand All @@ -178,9 +177,8 @@ impl<'db> ArgWalker<'db> {
input
.iter()
.filter_map(|f| match self.db.find_type_by_str(f) {
Some(Either::Left(walker)) => Some(walker),
Some(Either::Right(_enm)) => None,
None => None,
Some(TypeWalker::Class(walker)) => Some(walker),
_ => None,
})
}
}
Expand Down
17 changes: 9 additions & 8 deletions engine/baml-lib/parser-database/src/walkers/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
types::FunctionType,
};

use super::{ClassWalker, ConfigurationWalker, EnumWalker, Walker};
use super::{ClassWalker, ConfigurationWalker, EnumWalker, TypeWalker, Walker};

use std::iter::ExactSizeIterator;

Expand Down Expand Up @@ -143,7 +143,10 @@ impl<'db> FunctionWalker<'db> {
match client.0.split_once("/") {
// TODO: do this in a more robust way
// actually validate which clients are and aren't allowed
Some((provider, model)) => Ok(ClientSpec::Shorthand(provider.to_string(), model.to_string())),
Some((provider, model)) => Ok(ClientSpec::Shorthand(
provider.to_string(),
model.to_string(),
)),
None => match self.db.find_client(client.0.as_str()) {
Some(client) => Ok(ClientSpec::Named(client.name().to_string())),
None => {
Expand Down Expand Up @@ -217,9 +220,8 @@ impl<'db> ArgWalker<'db> {
if self.id.1 { input } else { output }
.iter()
.filter_map(|f| match self.db.find_type_by_str(f) {
Some(Either::Left(_cls)) => None,
Some(Either::Right(walker)) => Some(walker),
None => None,
Some(TypeWalker::Enum(walker)) => Some(walker),
_ => None,
})
}

Expand All @@ -229,9 +231,8 @@ impl<'db> ArgWalker<'db> {
if self.id.1 { input } else { output }
.iter()
.filter_map(|f| match self.db.find_type_by_str(f) {
Some(Either::Left(walker)) => Some(walker),
Some(Either::Right(_enm)) => None,
None => None,
Some(TypeWalker::Class(walker)) => Some(walker),
_ => None,
})
}
}
Expand Down
Loading

0 comments on commit 8786e3a

Please sign in to comment.