Skip to content

Commit

Permalink
Fix codegen for Python booleans & Add literals integ tests (#1099)
Browse files Browse the repository at this point in the history
Fixes #1094 & #1095
<!-- ELLIPSIS_HIDDEN -->


----

> [!IMPORTANT]
> Fixes Python boolean code generation and adds integration tests for
literals, updating handling in `generate_types.rs` and adding tests for
various literal scenarios.
> 
>   - **Behavior**:
> - Added `to_python_literal()` in `generate_types.rs` for Python
boolean literals.
> - Updated `FieldType::Literal` handling in `generate_types.rs` and
`mod.rs` to use `to_python_literal()`.
>   - **Tests**:
> - Added integration tests for literals in
`integ-tests/baml_src/test-files/functions/input/named-args/single/` and
`integ-tests/baml_src/test-files/functions/output/`.
>     - Tests cover literal booleans, integers, strings, and unions.
>   - **Misc**:
> - Updated `async_client.py`, `sync_client.py`, and `inlinedbaml.py`
for new literal handling.
> - Added models `LiteralClassHello`, `LiteralClassOne`,
`LiteralClassTwo` in `types.py` and `partial_types.py`.
> 
> <sup>This description was created by </sup>[<img alt="Ellipsis"
src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup>
for 5c13e9b. It will automatically
update as commits are pushed.</sup>


<!-- ELLIPSIS_HIDDEN -->
  • Loading branch information
antoniosarosi authored Oct 26, 2024
1 parent 11efa5e commit 6359762
Show file tree
Hide file tree
Showing 33 changed files with 3,177 additions and 167 deletions.
38 changes: 32 additions & 6 deletions engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::{
},
};
use anyhow::Result;
use baml_types::{BamlMap, BamlValue, BamlValueWithMeta, FieldType, TypeValue};
use baml_types::{BamlMap, BamlValue, BamlValueWithMeta, FieldType, LiteralValue, TypeValue};
pub use to_baml_arg::ArgCoercer;

use super::repr;
Expand Down Expand Up @@ -206,15 +206,32 @@ impl IRHelper for IntermediateRepr {
{
Ok(BamlValueWithMeta::String(s, field_type))
}
BamlValue::String(_) => anyhow::bail!("Could not unify String with {:?}", field_type),

BamlValue::String(s) => {
if let FieldType::Literal(LiteralValue::String(l)) = &field_type {
if s == *l {
return Ok(BamlValueWithMeta::String(s, field_type));
}
}

anyhow::bail!("Could not unify String with {:?}", field_type)
}

BamlValue::Int(i)
if FieldType::Primitive(TypeValue::Int).is_subtype_of(&field_type) =>
{
Ok(BamlValueWithMeta::Int(i, field_type))
}

BamlValue::Int(_) => anyhow::bail!("Could not unify Int with {:?}", field_type),
BamlValue::Int(i) => {
if let FieldType::Literal(LiteralValue::Int(l)) = &field_type {
if i == *l {
return Ok(BamlValueWithMeta::Int(i, field_type));
}
}

anyhow::bail!("Could not unify Int with {:?}", field_type)
}

BamlValue::Float(f)
if FieldType::Primitive(TypeValue::Float).is_subtype_of(&field_type) =>
Expand All @@ -228,7 +245,16 @@ impl IRHelper for IntermediateRepr {
{
Ok(BamlValueWithMeta::Bool(b, field_type))
}
BamlValue::Bool(_) => anyhow::bail!("Could not unify Bool with {:?}", field_type),

BamlValue::Bool(b) => {
if let FieldType::Literal(LiteralValue::Bool(l)) = &field_type {
if b == *l {
return Ok(BamlValueWithMeta::Bool(b, field_type));
}
}

anyhow::bail!("Could not unify Bool with {:?}", field_type)
}

BamlValue::Null if FieldType::Primitive(TypeValue::Null).is_subtype_of(&field_type) => {
Ok(BamlValueWithMeta::Null(field_type))
Expand Down Expand Up @@ -313,13 +339,13 @@ impl IRHelper for IntermediateRepr {
if FieldType::Enum(name.clone()).is_subtype_of(&field_type) {
Ok(BamlValueWithMeta::Enum(name, val, field_type))
} else {
anyhow::bail!("Could not unify Enum {name} with {:?}", field_type)
anyhow::bail!("Could not unify Enum {} with {:?}", name, field_type)
}
}

BamlValue::Class(name, fields) => {
if !FieldType::Class(name.clone()).is_subtype_of(&field_type) {
anyhow::bail!("Could not unify Class {name} with {:?}", field_type);
anyhow::bail!("Could not unify Class {} with {:?}", name, field_type);
} else {
let class_type = &self.find_class(&name)?.item.elem;
let class_fields: BamlMap<String, FieldType> = class_type
Expand Down
66 changes: 44 additions & 22 deletions engine/language_client_codegen/src/python/generate_types.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use anyhow::Result;
use itertools::{Itertools, join};
use baml_types::LiteralValue;
use itertools::Itertools;
use std::borrow::Cow;

use crate::{field_type_attributes, type_check_attributes, TypeCheckAttributes};
Expand All @@ -13,7 +14,7 @@ use internal_baml_core::ir::{
#[template(path = "types.py.j2", escape = "none")]
pub(crate) struct PythonTypes<'ir> {
enums: Vec<PythonEnum<'ir>>,
classes: Vec<PythonClass<'ir>>
classes: Vec<PythonClass<'ir>>,
}

#[derive(askama::Template)]
Expand Down Expand Up @@ -70,8 +71,7 @@ impl<'ir> TryFrom<(&'ir IntermediateRepr, &'_ crate::GeneratorArgs)> for TypeBui
fn try_from(
(ir, _): (&'ir IntermediateRepr, &'_ crate::GeneratorArgs),
) -> Result<TypeBuilder<'ir>> {
let checks_classes =
type_check_attributes(ir)
let checks_classes = type_check_attributes(ir)
.into_iter()
.map(|checks| type_def_for_checks(checks))
.collect::<Vec<_>>();
Expand Down Expand Up @@ -169,18 +169,44 @@ pub fn add_default_value(node: &FieldType, type_str: &String) -> String {
}

pub fn type_name_for_checks(checks: &TypeCheckAttributes) -> String {
let check_names = checks.0.iter().map(|check| format!("\"{check}\"")).sorted().join(", ");
let check_names = checks
.0
.iter()
.map(|check| format!("\"{check}\""))
.sorted()
.join(", ");

format!["Literal[{check_names}]"]
}

fn type_def_for_checks(checks: TypeCheckAttributes) -> PythonClass<'static> {
PythonClass {
name: Cow::Owned(type_name_for_checks(&checks)),
fields: checks.0.into_iter().map(|check_name| (Cow::Owned(check_name), "baml_py.Check".to_string())).collect(),
dynamic: false
fields: checks
.0
.into_iter()
.map(|check_name| (Cow::Owned(check_name), "baml_py.Check".to_string()))
.collect(),
dynamic: false,
}
}

/// Returns the Python `Literal` representation of `self`.
pub fn to_python_literal(literal: &LiteralValue) -> String {
// Python bools are a little special...
let value = match literal {
LiteralValue::Bool(bool) => String::from(match *bool {
true => "True",
false => "False",
}),

// Rest of types match the fmt::Display impl.
other => other.to_string(),
};

format!("Literal[{value}]")
}

trait ToTypeReferenceInTypeDefinition {
fn to_type_ref(&self, ir: &IntermediateRepr) -> String;
fn to_partial_type_ref(&self, ir: &IntermediateRepr, wrapped: bool) -> String;
Expand All @@ -200,7 +226,7 @@ impl ToTypeReferenceInTypeDefinition for FieldType {
format!("\"{name}\"")
}
}
FieldType::Literal(value) => format!("Literal[{}]", value),
FieldType::Literal(value) => to_python_literal(value),
FieldType::Class(name) => format!("\"{name}\""),
FieldType::List(inner) => format!("List[{}]", inner.to_type_ref(ir)),
FieldType::Map(key, value) => {
Expand All @@ -224,17 +250,13 @@ impl ToTypeReferenceInTypeDefinition for FieldType {
.join(", ")
),
FieldType::Optional(inner) => format!("Optional[{}]", inner.to_type_ref(ir)),
FieldType::Constrained{base, ..} => {
match field_type_attributes(self) {
Some(checks) => {
let base_type_ref = base.to_type_ref(ir);
let checks_type_ref = type_name_for_checks(&checks);
format!("baml_py.Checked[{base_type_ref},{checks_type_ref}]")
}
None => {
base.to_type_ref(ir)
}
FieldType::Constrained { base, .. } => match field_type_attributes(self) {
Some(checks) => {
let base_type_ref = base.to_type_ref(ir);
let checks_type_ref = type_name_for_checks(&checks);
format!("baml_py.Checked[{base_type_ref},{checks_type_ref}]")
}
None => base.to_type_ref(ir),
},
}
}
Expand All @@ -259,7 +281,7 @@ impl ToTypeReferenceInTypeDefinition for FieldType {
format!("Optional[types.{name}]")
}
}
FieldType::Literal(value) => format!("Literal[{}]", value),
FieldType::Literal(value) => to_python_literal(value),
FieldType::List(inner) => format!("List[{}]", inner.to_partial_type_ref(ir, true)),
FieldType::Map(key, value) => {
format!(
Expand All @@ -286,17 +308,17 @@ impl ToTypeReferenceInTypeDefinition for FieldType {
.join(", ")
),
FieldType::Optional(inner) => inner.to_partial_type_ref(ir, false),
FieldType::Constrained{base,..} => {
FieldType::Constrained { base, .. } => {
let base_type_ref = base.to_partial_type_ref(ir, false);
match field_type_attributes(self) {
Some(checks) => {
let base_type_ref = base.to_partial_type_ref(ir, false);
let checks_type_ref = type_name_for_checks(&checks);
format!("baml_py.Checked[{base_type_ref},{checks_type_ref}]")
}
None => base_type_ref
None => base_type_ref,
}
},
}
}
}
}
56 changes: 29 additions & 27 deletions engine/language_client_codegen/src/python/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ mod python_language_features;
use std::path::PathBuf;

use anyhow::Result;
use generate_types::type_name_for_checks;
use generate_types::{to_python_literal, type_name_for_checks};
use indexmap::IndexMap;
use internal_baml_core::{
configuration::GeneratorDefaultClientMode,
Expand Down Expand Up @@ -163,7 +163,9 @@ impl TryFrom<(&'_ IntermediateRepr, &'_ crate::GeneratorArgs)> for PythonClient
args: f
.inputs()
.iter()
.map(|(name, r#type)| (name.to_string(), r#type.to_type_ref(ir, false)))
.map(|(name, r#type)| {
(name.to_string(), r#type.to_type_ref(ir, false))
})
.collect(),
})
})
Expand Down Expand Up @@ -198,11 +200,15 @@ impl ToTypeReferenceInClientDefinition for FieldType {
format!("types.{name}")
}
}
FieldType::Literal(value) => format!("Literal[{}]", value),
FieldType::Literal(value) => to_python_literal(value),
FieldType::Class(name) => format!("types.{name}"),
FieldType::List(inner) => format!("List[{}]", inner.to_type_ref(ir, with_checked)),
FieldType::Map(key, value) => {
format!("Dict[{}, {}]", key.to_type_ref(ir, with_checked), value.to_type_ref(ir, with_checked))
format!(
"Dict[{}, {}]",
key.to_type_ref(ir, with_checked),
value.to_type_ref(ir, with_checked)
)
}
FieldType::Primitive(r#type) => r#type.to_python(),
FieldType::Union(inner) => format!(
Expand All @@ -221,18 +227,16 @@ impl ToTypeReferenceInClientDefinition for FieldType {
.collect::<Vec<_>>()
.join(", ")
),
FieldType::Optional(inner) => format!("Optional[{}]", inner.to_type_ref(ir, with_checked)),
FieldType::Constrained{base, ..} => {
match field_type_attributes(self) {
Some(checks) => {
let base_type_ref = base.to_type_ref(ir, with_checked);
let checks_type_ref = type_name_for_checks(&checks);
format!("baml_py.Checked[{base_type_ref},types.{checks_type_ref}]")
}
None => {
base.to_type_ref(ir, with_checked)
}
FieldType::Optional(inner) => {
format!("Optional[{}]", inner.to_type_ref(ir, with_checked))
}
FieldType::Constrained { base, .. } => match field_type_attributes(self) {
Some(checks) => {
let base_type_ref = base.to_type_ref(ir, with_checked);
let checks_type_ref = type_name_for_checks(&checks);
format!("baml_py.Checked[{base_type_ref},types.{checks_type_ref}]")
}
None => base.to_type_ref(ir, with_checked),
},
}
}
Expand All @@ -251,8 +255,10 @@ impl ToTypeReferenceInClientDefinition for FieldType {
}
}
FieldType::Class(name) => format!("partial_types.{name}"),
FieldType::Literal(value) => format!("Literal[{}]", value),
FieldType::List(inner) => format!("List[{}]", inner.to_partial_type_ref(ir, with_checked)),
FieldType::Literal(value) => to_python_literal(value),
FieldType::List(inner) => {
format!("List[{}]", inner.to_partial_type_ref(ir, with_checked))
}
FieldType::Map(key, value) => {
format!(
"Dict[{}, {}]",
Expand All @@ -278,17 +284,13 @@ impl ToTypeReferenceInClientDefinition for FieldType {
.join(", ")
),
FieldType::Optional(inner) => inner.to_partial_type_ref(ir, with_checked),
FieldType::Constrained{base, ..} => {
match field_type_attributes(self) {
Some(checks) => {
let base_type_ref = base.to_partial_type_ref(ir, with_checked);
let checks_type_ref = type_name_for_checks(&checks);
format!("baml_py.Checked[{base_type_ref},types.{checks_type_ref}]")
}
None => {
base.to_partial_type_ref(ir, with_checked)
}
FieldType::Constrained { base, .. } => match field_type_attributes(self) {
Some(checks) => {
let base_type_ref = base.to_partial_type_ref(ir, with_checked);
let checks_type_ref = type_name_for_checks(&checks);
format!("baml_py.Checked[{base_type_ref},types.{checks_type_ref}]")
}
None => base.to_partial_type_ref(ir, with_checked),
},
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
class LiteralClassHello {
prop "hello"
}

function FnLiteralClassInputOutput(input: LiteralClassHello) -> LiteralClassHello {
client GPT4
prompt #"
Return the same object you were given.
{{ ctx.output_format }}
"#
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
class LiteralClassOne {
prop "one"
}

class LiteralClassTwo {
prop "two"
}

function FnLiteralUnionClassInputOutput(input: LiteralClassOne | LiteralClassTwo) -> LiteralClassOne | LiteralClassTwo {
client GPT4
prompt #"
Return the same object you were given.
{{ ctx.output_format }}
"#
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
function TestNamedArgsLiteralBool(myBool: true) -> string {
client GPT35
prompt #"
Return this value back to me: {{myBool}}
"#
}

test TestFnNamedArgsLiteralBool {
functions [TestNamedArgsLiteralBool]
args {
myBool true
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
function TestNamedArgsLiteralInt(myInt: 1) -> string {
client GPT35
prompt #"
Return this value back to me: {{myInt}}
"#
}

test TestFnNamedArgsLiteralInt {
functions [TestNamedArgsLiteralInt]
args {
myInt 1
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
function TestNamedArgsLiteralString(myString: "My String") -> string {
client GPT35
prompt #"
Return this value back to me: {{myString}}
"#
}

test TestFnNamedArgsLiteralString {
functions [TestNamedArgsLiteralString]
args {
myString "My String"
}
}
Loading

0 comments on commit 6359762

Please sign in to comment.