Skip to content

Commit

Permalink
Merge attrs and allow only checks and asserts
Browse files Browse the repository at this point in the history
  • Loading branch information
antoniosarosi committed Dec 9, 2024
1 parent 2b8cce6 commit bad63bf
Show file tree
Hide file tree
Showing 27 changed files with 806 additions and 20 deletions.
65 changes: 64 additions & 1 deletion engine/baml-lib/baml-core/src/ir/repr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1112,7 +1112,7 @@ pub fn make_test_ir(source_code: &str) -> anyhow::Result<IntermediateRepr> {
#[cfg(test)]
mod tests {
use super::*;
use crate::ir::ir_helpers::IRHelper;
use crate::ir::{ir_helpers::IRHelper, TypeValue};

#[test]
fn test_docstrings() {
Expand Down Expand Up @@ -1205,4 +1205,67 @@ mod tests {
let walker = ir.find_test(&function, "Foo").unwrap();
assert_eq!(walker.item.1.elem.constraints.len(), 1);
}

#[test]
fn test_resolve_type_alias() {
let ir = make_test_ir(
r##"
type One = int
type Two = One
type Three = Two
class Test {
field Three
}
"##,
)
.unwrap();

let class = ir.find_class("Test").unwrap();
let alias = class.find_field("field").unwrap();

let FieldType::Alias { resolution, .. } = alias.r#type() else {
panic!("expected alias type, found {:?}", alias.r#type());
};

assert_eq!(**resolution, FieldType::Primitive(TypeValue::Int));
}

#[test]
fn test_merge_type_alias_attributes() {
let ir = make_test_ir(
r##"
type One = int @check(gt_ten, {{ this > 10 }})
type Two = One @check(lt_twenty, {{ this < 20 }})
type Three = Two @assert({{ this != 15 }})
class Test {
field Three
}
"##,
)
.unwrap();

let class = ir.find_class("Test").unwrap();
let alias = class.find_field("field").unwrap();

let FieldType::Alias { resolution, .. } = alias.r#type() else {
panic!("expected alias type, found {:?}", alias.r#type());
};

let FieldType::Constrained { base, constraints } = &**resolution else {
panic!("expected resolved constrained type, found {:?}", resolution);
};

assert_eq!(constraints.len(), 3);

assert_eq!(constraints[0].level, ConstraintLevel::Assert);
assert_eq!(constraints[0].label, None);

assert_eq!(constraints[1].level, ConstraintLevel::Check);
assert_eq!(constraints[1].label, Some("lt_twenty".to_string()));

assert_eq!(constraints[2].level, ConstraintLevel::Check);
assert_eq!(constraints[2].label, Some("gt_ten".to_string()));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
type DescNotAllowed = string @description("This is not allowed")

type AliasNotAllowed = float @alias("Alias not allowed")

type SkipNotAllowed = float @skip

// error: Error validating: type aliases may only have check and assert attributes
// --> class/invalid_attrs_on_type_alias.baml:1
// |
// |
// 1 | type DescNotAllowed = string @description("This is not allowed")
// |
// error: Error validating: type aliases may only have check and assert attributes
// --> class/invalid_attrs_on_type_alias.baml:3
// |
// 2 |
// 3 | type AliasNotAllowed = float @alias("Alias not allowed")
// |
// error: Error validating: type aliases may only have check and assert attributes
// --> class/invalid_attrs_on_type_alias.baml:5
// |
// 4 |
// 5 | type SkipNotAllowed = float @skip
// |
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@ type Graph = map<string, string[]>

type Combination = Primitive | List | Graph

// Alias with attrs.
type Currency = int @check(gt_ten, {{ this > 10 }})
type Amount = Currency @assert ({{ this > 0 }})

class MergeAttrs {
amount Amount @description("In USD")
}

function PrimitiveAlias(p: Primitive) -> Primitive {
client "openai/gpt-4o"
prompt r#"
Expand Down Expand Up @@ -34,3 +42,14 @@ function NestedAlias(c: Combination) -> Combination {
{{ ctx.output_format }}
"#
}

function MergeAliasAttributes(money: int) -> MergeAttrs {
client "openai/gpt-4o"
prompt r#"
Return the given integer in the specified format:

{{ money }}

{{ ctx.output_format }}
"#
}
46 changes: 44 additions & 2 deletions engine/baml-lib/parser-database/src/attributes/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use internal_baml_diagnostics::Span;
use internal_baml_schema_ast::ast::{Top, TopId, TypeExpId, TypeExpressionBlock};
use internal_baml_diagnostics::{DatamodelError, Span};
use internal_baml_schema_ast::ast::{
Assignment, Top, TopId, TypeAliasId, TypeExpId, TypeExpressionBlock,
};

mod alias;
pub mod constraint;
Expand Down Expand Up @@ -79,6 +81,9 @@ pub(super) fn resolve_attributes(ctx: &mut Context<'_>) {
(TopId::Enum(enum_id), Top::Enum(ast_enum)) => {
resolve_type_exp_block_attributes(enum_id, ast_enum, ctx, SubType::Enum)
}
(TopId::TypeAlias(alias_id), Top::TypeAlias(assignment)) => {
resolve_type_alias_attributes(alias_id, assignment, ctx)
}
_ => (),
}
}
Expand Down Expand Up @@ -132,3 +137,40 @@ fn resolve_type_exp_block_attributes<'db>(
_ => (),
}
}

/// Quick hack to validate type alias attributes.
///
/// Unlike classes and enums, type aliases only support checks and asserts.
/// Everything else is reported as an error. On top of that, checks and asserts
/// must be merged when aliases point to other aliases. We do this recursively
/// when resolving the type alias to its final "virtual" type at
/// [`crate::types::resolve_type_alias`].
///
/// Then checks and asserts are collected from the virtual type and stored in
/// the IR at `engine/baml-lib/baml-core/src/ir/repr.rs`, so there's no need to
/// store them in separate classes like [`ClassAttributes`] or similar, at least
/// for now.
fn resolve_type_alias_attributes<'db>(
alias_id: TypeAliasId,
assignment: &'db Assignment,
ctx: &mut Context<'db>,
) {
ctx.assert_all_attributes_processed(alias_id.into());
let type_alias_attributes = to_string_attribute::visit(ctx, assignment.value.span(), false);
ctx.validate_visited_attributes();

// Some additional specific validation for type alias attributes.
if let Some(attrs) = &type_alias_attributes {
if attrs.dynamic_type().is_some()
|| attrs.alias().is_some()
|| attrs.skip().is_some()
|| attrs.description().is_some()
{
ctx.diagnostics
.push_error(DatamodelError::new_validation_error(
"type aliases may only have check and assert attributes",
assignment.span.clone(),
));
}
}
}
15 changes: 15 additions & 0 deletions engine/baml-lib/parser-database/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -600,4 +600,19 @@ mod test {

Ok(())
}

#[test]
fn merged_alias_attrs() -> Result<(), Diagnostics> {
#[rustfmt::skip]
let db = parse(r#"
type One = int @assert({{ this < 5 }})
type Two = One @assert({{ this > 0 }})
"#)?;

let resolved = db.resolved_type_alias_by_name("Two").unwrap();

assert_eq!(resolved.attributes().len(), 2);

Ok(())
}
}
21 changes: 17 additions & 4 deletions engine/baml-lib/parser-database/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ fn visit_class<'db>(
pub fn resolve_type_alias(field_type: &FieldType, db: &ParserDatabase) -> FieldType {
match field_type {
// For symbols we need to check if we're dealing with aliases.
FieldType::Symbol(arity, ident, span) => {
FieldType::Symbol(arity, ident, attrs) => {
let Some(string_id) = db.interner.lookup(ident.name()) else {
unreachable!(
"Attempting to resolve alias `{ident}` that does not exist in the interner"
Expand All @@ -406,7 +406,7 @@ pub fn resolve_type_alias(field_type: &FieldType, db: &ParserDatabase) -> FieldT

match top_id {
ast::TopId::TypeAlias(alias_id) => {
let resolved = match db.types.resolved_type_aliases.get(alias_id) {
let mut resolved = match db.types.resolved_type_aliases.get(alias_id) {
// Check if we can avoid deeper recursion.
Some(already_resolved) => already_resolved.to_owned(),
// No luck, recurse.
Expand All @@ -419,11 +419,24 @@ pub fn resolve_type_alias(field_type: &FieldType, db: &ParserDatabase) -> FieldT
// type AliasTwo = AliasOne
//
// AliasTwo resolves to an "optional" type.
if resolved.is_optional() || arity.is_optional() {
//
// TODO: Add a `set_arity` function or something and avoid
// this clone.
resolved = if resolved.is_optional() || arity.is_optional() {
resolved.to_nullable()
} else {
resolved
}
};

// Merge attributes.
resolved.set_attributes({
let mut merged_attrs = Vec::from(field_type.attributes());
merged_attrs.extend(resolved.attributes().to_owned());

merged_attrs
});

resolved
}

// Class or enum. Already "resolved", pop off the stack.
Expand Down
8 changes: 8 additions & 0 deletions engine/baml-lib/schema-ast/src/ast/attribute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ pub enum AttributeContainer {
ClassField(super::TypeExpId, super::FieldId),
Enum(super::TypeExpId),
EnumValue(super::TypeExpId, super::FieldId),
TypeAlias(super::TypeAliasId),
}

impl From<super::TypeExpId> for AttributeContainer {
Expand All @@ -79,6 +80,12 @@ impl From<(super::TypeExpId, super::FieldId)> for AttributeContainer {
}
}

impl From<super::TypeAliasId> for AttributeContainer {
fn from(v: super::TypeAliasId) -> Self {
Self::TypeAlias(v)
}
}

/// An attribute (@ or @@) node in the AST.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub struct AttributeId(AttributeContainer, u32);
Expand All @@ -102,6 +109,7 @@ impl Index<AttributeContainer> for super::SchemaAst {
AttributeContainer::EnumValue(enum_id, value_idx) => {
&self[enum_id][value_idx].attributes
}
AttributeContainer::TypeAlias(alias_id) => &self[alias_id].value.attributes(),
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion engine/baml-lib/schema-ast/src/parser/datamodel.pest
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ single_word = @{ ASCII_ALPHA ~ (ASCII_ALPHANUMERIC | "_" | "-")* }
// ######################################
// Type Alias
// ######################################
type_alias = { identifier ~ identifier ~ assignment ~ field_type }
type_alias = { identifier ~ identifier ~ assignment ~ field_type_with_attr }

// ######################################
// Arguments
Expand Down
12 changes: 11 additions & 1 deletion engine/baml-lib/schema-ast/src/parser/parse_assignment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@ use super::{
Rule,
};

use crate::{assert_correct_parser, ast::*, parser::parse_types::parse_field_type};
use crate::{
assert_correct_parser,
ast::*,
parser::{parse_field::parse_field_type_with_attr, parse_types::parse_field_type},
};

use internal_baml_diagnostics::{DatamodelError, Diagnostics};

Expand Down Expand Up @@ -45,8 +49,14 @@ pub(crate) fn parse_assignment(pair: Pair<'_>, diagnostics: &mut Diagnostics) ->

Rule::assignment => {} // Ok, equal sign.

// TODO: We probably only need field_type_with_attr since that's how
// the PEST syntax is defined.
Rule::field_type => field_type = parse_field_type(current, diagnostics),

Rule::field_type_with_attr => {
field_type = parse_field_type_with_attr(current, false, diagnostics)
}

_ => parsing_catch_all(current, "type_alias"),
}
}
Expand Down
30 changes: 30 additions & 0 deletions integ-tests/baml_src/test-files/functions/output/type-aliases.baml
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,33 @@ function NestedAlias(c: Combination) -> Combination {
{{ ctx.output_format }}
"#
}

// Test attribute merging.
type Currency = int @check(gt_ten, {{ this > 10 }})
type Amount = Currency @assert ({{ this > 0 }})

class MergeAttrs {
amount Amount @description("In USD")
}

function MergeAliasAttributes(money: int) -> MergeAttrs {
client "openai/gpt-4o"
prompt r#"
Return the given integer in the specified format:

{{ money }}

{{ ctx.output_format }}
"#
}

function ReturnAliasWithMergedAttributes(money: Amount) -> Amount {
client "openai/gpt-4o"
prompt r#"
Return the given integer without additional context:

{{ money }}

{{ ctx.output_format }}
"#
}
Loading

0 comments on commit bad63bf

Please sign in to comment.