From 1c5cf1211e7eea26b02967b25f415ff36549b693 Mon Sep 17 00:00:00 2001 From: James Braza Date: Wed, 22 Jan 2025 15:00:00 -0800 Subject: [PATCH] `RandomAnnotation` support for `model_dump(mode="json")` (#179) --- src/aviary/utils.py | 10 ++++++---- tests/test_utils.py | 34 ++++++++++++++++++++++------------ 2 files changed, 28 insertions(+), 16 deletions(-) diff --git a/src/aviary/utils.py b/src/aviary/utils.py index 34ba7fb9..199985f9 100644 --- a/src/aviary/utils.py +++ b/src/aviary/utils.py @@ -235,10 +235,12 @@ def val_func(state: list) -> random.Random: ) } return cs.json_or_python_schema( - python_schema=cs.union_schema([ - cs.is_instance_schema(source), - plain_val_schema, - ]), + python_schema=cs.union_schema( + choices=[cs.is_instance_schema(source), plain_val_schema], + serialization=cs.plain_serializer_function_ser_schema( + lambda inst: inst.getstate(), when_used="json" + ), + ), json_schema=plain_val_schema_json, ) diff --git a/tests/test_utils.py b/tests/test_utils.py index c8dc1378..682bd4f3 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -126,22 +126,32 @@ class SomeModel(BaseModel): model = SomeModel(rng=random.Random(5)) assert isinstance(model.rng, random.Random) - serialized = model.model_dump_json() # 1. Manually check serialized RNG is expected - deserialized_manual = json.loads(serialized) - rng_serialized = deserialized_manual.pop("rng") - assert not deserialized_manual, "Expected only one key in the serialized model" - version, internal_state, gauss_next = rng_serialized - assert isinstance(version, int) - assert isinstance(internal_state, list) - assert isinstance(gauss_next, float | None) + for deserialized in ( + json.loads(model.model_dump_json()), # JSON str + model.model_dump(mode="json"), # JSON dict + ): + rng_serialized = deserialized.pop("rng") + assert not deserialized, "Expected only one key in the serialized model" + version, internal_state, gauss_next = rng_serialized + assert isinstance(version, int) + assert isinstance(internal_state, list) + assert isinstance(gauss_next, float | None) # 2. Check deserialized RNG behaves as original RNG - deserialized_model = SomeModel.model_validate_json(serialized) - sampled_original = model.rng.sample(list(range(10)), k=6) - sampled_deserialized = deserialized_model.rng.sample(list(range(10)), k=6) - assert sampled_original == sampled_deserialized, "Deserialization seeding failed" + for i, deserialized_model in enumerate(( + SomeModel.model_validate_json(model.model_dump_json()), # JSON str + SomeModel.model_validate(model.model_dump(mode="json")), # JSON dict + )): + if i == 0: + # Sample original model once so RNG aligns for both deserialized + # models in the `for` loop + sampled_original = model.rng.sample(list(range(10)), k=6) + sampled_deserialized = deserialized_model.rng.sample(list(range(10)), k=6) + assert sampled_original == sampled_deserialized, ( + "Deserialization seeding failed" + ) class TestLitQAEvaluation: