Skip to content

Commit

Permalink
ctx.output_format for map is wrong and fix type validation for map pa…
Browse files Browse the repository at this point in the history
…rams

Fixes #861
  • Loading branch information
hellovai committed Sep 3, 2024
1 parent 1c6c9d3 commit 12bf00c
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 34 deletions.
18 changes: 12 additions & 6 deletions engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,15 +269,21 @@ impl ArgCoercer {
}
FieldType::Map(k, v) => {
if let BamlValue::Map(kv) = value {
let mut map = BamlMap::new();
for (key, value) in kv {
let mut key_scope = ScopeStack::new();
let _ =
self.coerce_arg(ir, k, &BamlValue::String(key.clone()), &mut key_scope);
scope.push("<key>".to_string());
let k = self.coerce_arg(ir, k, &BamlValue::String(key.clone()), scope);
scope.pop(false);

let mut value_scope = ScopeStack::new();
let _ = self.coerce_arg(ir, v, value, &mut value_scope);
if k.is_ok() {
scope.push(key.to_string());
if let Ok(v) = self.coerce_arg(ir, v, value, scope) {
map.insert(key.clone(), v);
}
scope.pop(false);
}
}
Ok(value.clone())
Ok(BamlValue::Map(map))
} else {
scope.push_error(format!("Expected map, got `{}`", value));
Err(())
Expand Down
2 changes: 1 addition & 1 deletion engine/baml-lib/jinja/src/output_format/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ impl OutputFormatContent {
FieldType::List(_) => Some("Answer with a JSON Array using this schema:\n"),
FieldType::Union(_) => Some("Answer in JSON using any of these schemas:\n"),
FieldType::Optional(_) => Some("Answer in JSON using this schema:\n"),
FieldType::Map(_, _) => None,
FieldType::Map(_, _) => Some("Answer in JSON using this schema:\n"),
FieldType::Tuple(_) => None,
},
}
Expand Down
7 changes: 6 additions & 1 deletion integ-tests/python/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1020,4 +1020,9 @@ async def test_arg_exceptions():
)

with pytest.raises(errors.BamlValidationError):
await b.DummyOutputFunction("dummy input")
await b.DummyOutputFunction("dummy input")

@pytest.mark.asyncio
async def test_map_as_param():
with pytest.raises(errors.BamlInvalidArgumentError):
_ = await b.TestFnNamedArgsSingleMapStringToMap({ "a" : "b"}) # intentionally passing the wrong type
26 changes: 0 additions & 26 deletions integ-tests/python/tests/test_hi.py

This file was deleted.

0 comments on commit 12bf00c

Please sign in to comment.