diff --git a/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs b/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs index df0484f00..ffe4d27c7 100644 --- a/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs +++ b/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs @@ -269,15 +269,21 @@ impl ArgCoercer { } FieldType::Map(k, v) => { if let BamlValue::Map(kv) = value { + let mut map = BamlMap::new(); for (key, value) in kv { - let mut key_scope = ScopeStack::new(); - let _ = - self.coerce_arg(ir, k, &BamlValue::String(key.clone()), &mut key_scope); + scope.push("".to_string()); + let k = self.coerce_arg(ir, k, &BamlValue::String(key.clone()), scope); + scope.pop(false); - let mut value_scope = ScopeStack::new(); - let _ = self.coerce_arg(ir, v, value, &mut value_scope); + if k.is_ok() { + scope.push(key.to_string()); + if let Ok(v) = self.coerce_arg(ir, v, value, scope) { + map.insert(key.clone(), v); + } + scope.pop(false); + } } - Ok(value.clone()) + Ok(BamlValue::Map(map)) } else { scope.push_error(format!("Expected map, got `{}`", value)); Err(()) diff --git a/engine/baml-lib/jinja/src/output_format/types.rs b/engine/baml-lib/jinja/src/output_format/types.rs index f0a16e515..fd0ae74cc 100644 --- a/engine/baml-lib/jinja/src/output_format/types.rs +++ b/engine/baml-lib/jinja/src/output_format/types.rs @@ -238,7 +238,7 @@ impl OutputFormatContent { FieldType::List(_) => Some("Answer with a JSON Array using this schema:\n"), FieldType::Union(_) => Some("Answer in JSON using any of these schemas:\n"), FieldType::Optional(_) => Some("Answer in JSON using this schema:\n"), - FieldType::Map(_, _) => None, + FieldType::Map(_, _) => Some("Answer in JSON using this schema:\n"), FieldType::Tuple(_) => None, }, } diff --git a/integ-tests/python/tests/test_functions.py b/integ-tests/python/tests/test_functions.py index bd6e04914..7d80c8d21 100644 --- a/integ-tests/python/tests/test_functions.py +++ b/integ-tests/python/tests/test_functions.py @@ -1020,4 +1020,9 @@ async def test_arg_exceptions(): ) with pytest.raises(errors.BamlValidationError): - await b.DummyOutputFunction("dummy input") \ No newline at end of file + await b.DummyOutputFunction("dummy input") + +@pytest.mark.asyncio +async def test_map_as_param(): + with pytest.raises(errors.BamlInvalidArgumentError): + _ = await b.TestFnNamedArgsSingleMapStringToMap({ "a" : "b"}) # intentionally passing the wrong type diff --git a/integ-tests/python/tests/test_hi.py b/integ-tests/python/tests/test_hi.py deleted file mode 100644 index d79e7926f..000000000 --- a/integ-tests/python/tests/test_hi.py +++ /dev/null @@ -1,26 +0,0 @@ -import time -from typing import List -import pytest -from assertpy import assert_that -from dotenv import load_dotenv -from .base64_test_data import image_b64, audio_b64 - -load_dotenv() -import baml_py -from ..baml_client import b -from ..baml_client.sync_client import b as sync_b - -from ..baml_client import partial_types - -from ..baml_client.tracing import trace, set_tags, flush, on_log_event -from ..baml_client.type_builder import TypeBuilder -import datetime -import concurrent.futures -import asyncio -import random - - -@pytest.mark.asyncio -async def test_accepts_subclass_of_baml_type(): - print("calling with class") - _ = await b.ExtractResume("hello")