Skip to content

Commit

Permalink
RandomAnnotation support for model_dump(mode="json") (#179)
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesbraza authored Jan 22, 2025
1 parent e18c399 commit 1c5cf12
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 16 deletions.
10 changes: 6 additions & 4 deletions src/aviary/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,12 @@ def val_func(state: list) -> random.Random:
)
}
return cs.json_or_python_schema(
python_schema=cs.union_schema([
cs.is_instance_schema(source),
plain_val_schema,
]),
python_schema=cs.union_schema(
choices=[cs.is_instance_schema(source), plain_val_schema],
serialization=cs.plain_serializer_function_ser_schema(
lambda inst: inst.getstate(), when_used="json"
),
),
json_schema=plain_val_schema_json,
)

Expand Down
34 changes: 22 additions & 12 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,22 +126,32 @@ class SomeModel(BaseModel):

model = SomeModel(rng=random.Random(5))
assert isinstance(model.rng, random.Random)
serialized = model.model_dump_json()

# 1. Manually check serialized RNG is expected
deserialized_manual = json.loads(serialized)
rng_serialized = deserialized_manual.pop("rng")
assert not deserialized_manual, "Expected only one key in the serialized model"
version, internal_state, gauss_next = rng_serialized
assert isinstance(version, int)
assert isinstance(internal_state, list)
assert isinstance(gauss_next, float | None)
for deserialized in (
json.loads(model.model_dump_json()), # JSON str
model.model_dump(mode="json"), # JSON dict
):
rng_serialized = deserialized.pop("rng")
assert not deserialized, "Expected only one key in the serialized model"
version, internal_state, gauss_next = rng_serialized
assert isinstance(version, int)
assert isinstance(internal_state, list)
assert isinstance(gauss_next, float | None)

# 2. Check deserialized RNG behaves as original RNG
deserialized_model = SomeModel.model_validate_json(serialized)
sampled_original = model.rng.sample(list(range(10)), k=6)
sampled_deserialized = deserialized_model.rng.sample(list(range(10)), k=6)
assert sampled_original == sampled_deserialized, "Deserialization seeding failed"
for i, deserialized_model in enumerate((
SomeModel.model_validate_json(model.model_dump_json()), # JSON str
SomeModel.model_validate(model.model_dump(mode="json")), # JSON dict
)):
if i == 0:
# Sample original model once so RNG aligns for both deserialized
# models in the `for` loop
sampled_original = model.rng.sample(list(range(10)), k=6)
sampled_deserialized = deserialized_model.rng.sample(list(range(10)), k=6)
assert sampled_original == sampled_deserialized, (
"Deserialization seeding failed"
)


class TestLitQAEvaluation:
Expand Down

0 comments on commit 1c5cf12

Please sign in to comment.