Skip to content

Commit

Permalink
Fix subtype bug with aliases and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
antoniosarosi committed Nov 28, 2024
1 parent cde729a commit b9de7ed
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 1,829 deletions.
8 changes: 4 additions & 4 deletions engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,17 +290,17 @@ impl IRHelper for IntermediateRepr {

if !map_type.is_subtype_of(&field_type) {
anyhow::bail!("Could not unify {:?} with {:?}", map_type, field_type);
} else {
let mapped_fields: BamlMap<String, BamlValueWithMeta<FieldType>> =
}

let mapped_fields: BamlMap<String, BamlValueWithMeta<FieldType>> =
pairs
.into_iter()
.map(|(key, val)| {
let sub_value = self.distribute_type(val, item_type.clone())?;
Ok((key, sub_value))
})
.collect::<anyhow::Result<BamlMap<String,BamlValueWithMeta<FieldType>>>>()?;
Ok(BamlValueWithMeta::Map(mapped_fields, field_type))
}
Ok(BamlValueWithMeta::Map(mapped_fields, field_type))
}
None => Ok(BamlValueWithMeta::Map(BamlMap::new(), field_type)),
}
Expand Down
3 changes: 2 additions & 1 deletion engine/baml-lib/baml-types/src/field_type/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,8 @@ impl FieldType {
}

match (self, other) {
(FieldType::Alias { resolution, .. }, _) => resolution.is_subtype_of(other),
(_, FieldType::Alias { resolution, .. }) => self.is_subtype_of(resolution),
(FieldType::Primitive(TypeValue::Null), FieldType::Optional(_)) => true,
(FieldType::Optional(self_item), FieldType::Optional(other_item)) => {
self_item.is_subtype_of(other_item)
Expand Down Expand Up @@ -245,7 +247,6 @@ impl FieldType {
.zip(other_items)
.all(|(self_item, other_item)| self_item.is_subtype_of(other_item))
}
(FieldType::Alias { resolution, .. }, _) => resolution.is_subtype_of(other),
(FieldType::Tuple(_), _) => false,
(FieldType::Primitive(_), _) => false,
(FieldType::Enum(_), _) => false,
Expand Down
18 changes: 18 additions & 0 deletions integ-tests/python/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,24 @@ async def test_single_literal_string_key_in_map(self):
res = await b.InOutSingleLiteralStringMapKey({"key": "1"})
assert res["key"] == "1"

@pytest.mark.asyncio
async def test_primitive_union_alias(self):
res = await b.PrimitiveAlias("test")
assert res == "test"

@pytest.mark.asyncio
async def test_map_alias(self):
res = await b.MapAlias({"A": ["B", "C"], "B": [], "C": []})
assert res == {"A": ["B", "C"], "B": [], "C": []}

@pytest.mark.asyncio
async def test_alias_union(self):
res = await b.NestedAlias("test")
assert res == "test"

res = await b.NestedAlias({"A": ["B", "C"], "B": [], "C": []})
assert res == {"A": ["B", "C"], "B": [], "C": []}


class MyCustomClass(NamedArgsSingleClass):
date: datetime.datetime
Expand Down
12 changes: 12 additions & 0 deletions integ-tests/ruby/test_functions.rb
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,18 @@

res = b.InOutSingleLiteralStringMapKey(m: {"key" => "1"})
assert_equal res['key'], "1"

res = b.PrimitiveAlias(p: "test")
assert_equal res, "test"

res = b.MapAlias(m: {"A" => ["B", "C"], "B" => [], "C" => []})
assert_equal res, {"A" => ["B", "C"], "B" => [], "C" => []}

res = b.NestedAlias(c: "test")
assert_equal res, "test"

res = b.NestedAlias(c: {"A" => ["B", "C"], "B" => [], "C" => []})
assert_equal res, {"A" => ["B", "C"], "B" => [], "C" => []}
end

it "accepts subclass of baml type" do
Expand Down
1,881 changes: 57 additions & 1,824 deletions integ-tests/typescript/test-report.html

Large diffs are not rendered by default.

18 changes: 18 additions & 0 deletions integ-tests/typescript/tests/integ-tests.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,24 @@ describe('Integ tests', () => {
const res = await b.InOutSingleLiteralStringMapKey({ key: '1' })
expect(res).toHaveProperty('key', '1')
})

it('primitive union alias', async () => {
const res = await b.PrimitiveAlias('test')
expect(res).toEqual('test')
})

it('map alias', async () => {
const res = await b.MapAlias({ A: ['B', 'C'], B: [], C: [] })
expect(res).toEqual({ A: ['B', 'C'], B: [], C: [] })
})

it('alias union', async () => {
let res = await b.NestedAlias('test')
expect(res).toEqual('test')

res = await b.NestedAlias({ A: ['B', 'C'], B: [], C: [] })
expect(res).toEqual({ A: ['B', 'C'], B: [], C: [] })
})
})

it('should work for all outputs', async () => {
Expand Down

0 comments on commit b9de7ed

Please sign in to comment.