diff --git a/engine/language_client_python/src/parse_py_type.rs b/engine/language_client_python/src/parse_py_type.rs index e4c0416d1..569bf1e24 100644 --- a/engine/language_client_python/src/parse_py_type.rs +++ b/engine/language_client_python/src/parse_py_type.rs @@ -5,7 +5,7 @@ use baml_types::{BamlMap, BamlMedia, BamlValue}; use pyo3::{ exceptions::{PyRuntimeError, PyTypeError}, prelude::{PyAnyMethods, PyTypeMethods}, - types::{PyBool, PyBoolMethods, PyList}, + types::{PyBool, PyBoolMethods, PyDict, PyList, PyString}, PyErr, PyObject, PyResult, Python, ToPyObject, }; @@ -249,25 +249,40 @@ pub fn parse_py_type( } }) .unwrap_or("".to_string()); - let fields = match t + let mut fields = HashMap::new(); + // Get regular fields + if let Ok(model_fields) = t .getattr("model_fields")? .extract::>() { - Ok(fields) => fields - .keys() - .filter_map(|k| { - let v = any.getattr(py, k.as_str()); - if let Ok(v) = v { - Some((k.clone(), v)) - } else { - None + for (key, _) in model_fields { + if let Ok(value) = any.getattr(py, key.as_str()) { + fields.insert(key, value.to_object(py)); + } + } + } + + // Get extra fields (like if this is a @@dynamic class) + if let Ok(extra) = any.getattr(py, "__pydantic_extra__") { + if let Ok(extra_dict) = extra.downcast::(py) { + for (key, value) in extra_dict.iter() { + if let (Ok(key), value) = (key.extract::(), value) { + fields.insert(key, value.to_object(py)); } - }) - .collect::>(), - Err(_) => { - bail!("model_fields is not a dict") + } } - }; + } + + // Log the fields + // log::info!("Fields of {}:", name); + // for (key, value) in &fields { + // let repr = py + // .import_bound("builtins")? + // .getattr("repr")? + // .call1((value,))?; + // let repr_str = repr.extract::()?; + // log::info!(" {}: {}", key, repr_str); + // } Ok(MappedPyType::Class(name, fields)) // use downcast only } else if let Ok(list) = any.downcast_bound::(py) { diff --git a/integ-tests/python/tests/test_functions.py b/integ-tests/python/tests/test_functions.py index 4a0dfd567..e8b50919a 100644 --- a/integ-tests/python/tests/test_functions.py +++ b/integ-tests/python/tests/test_functions.py @@ -10,7 +10,7 @@ from ..baml_client.globals import ( DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_RUNTIME, ) -from ..baml_client.types import NamedArgsSingleEnumList, NamedArgsSingleClass +from ..baml_client.types import NamedArgsSingleEnumList, NamedArgsSingleClass, DynInputOutput from ..baml_client.tracing import trace, set_tags, flush, on_log_event from ..baml_client.type_builder import TypeBuilder import datetime @@ -555,6 +555,41 @@ async def test_stream_dynamic_class_output(): assert final.hair_color == "black" +@pytest.mark.asyncio +async def test_dynamic_inputs_list2(): + tb = TypeBuilder() + tb.DynInputOutput.add_property("new_key", tb.string().optional()) + custom_class = tb.add_class("MyBlah") + custom_class.add_property("nestedKey1", tb.string()) + tb.DynInputOutput.add_property("blah", custom_class.type()) + + res = await b.DynamicListInputOutput( + [ + DynInputOutput(**{ + "new_key": "hi1", + "testKey": "myTest", + "blah": { + "nestedKey1": "nestedVal", + }, + }), + { + "new_key": "hi", + "testKey": "myTest", + "blah": { + "nestedKey1": "nestedVal", + }, + }, + ], + {"tb": tb}, + ) + assert res[0].new_key == "hi1" + assert res[0].testKey == "myTest" + assert res[0].blah["nestedKey1"] == "nestedVal" + assert res[1].new_key == "hi" + assert res[1].testKey == "myTest" + assert res[1].blah["nestedKey1"] == "nestedVal" + + @pytest.mark.asyncio async def test_dynamic_inputs_list(): tb = TypeBuilder()