diff --git a/pyproject.toml b/pyproject.toml index c3077e4e..5b29541f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ dependencies = [ "pydantic~=2.1", "bioutils", "requests", + "canonicaljson", ] [project.optional-dependencies] diff --git a/src/ga4gh/core/entity_models.py b/src/ga4gh/core/entity_models.py index 78146b49..f40e9c9d 100644 --- a/src/ga4gh/core/entity_models.py +++ b/src/ga4gh/core/entity_models.py @@ -14,7 +14,7 @@ from typing import Any, Dict, Annotated, Optional, Union, List from enum import Enum -from pydantic import BaseModel, Field, RootModel, StringConstraints, model_serializer, ConfigDict +from pydantic import BaseModel, Field, RootModel, StringConstraints, ConfigDict from ga4gh.core import GA4GH_IR_REGEXP @@ -78,7 +78,6 @@ class IRI(RootModel): def __hash__(self): return self.root.__hash__() - @model_serializer(when_used='json') def ga4gh_serialize(self): m = GA4GH_IR_REGEXP.match(self.root) if m is not None: diff --git a/src/ga4gh/core/identifiers.py b/src/ga4gh/core/identifiers.py index 85ea214a..7b026243 100644 --- a/src/ga4gh/core/identifiers.py +++ b/src/ga4gh/core/identifiers.py @@ -14,7 +14,7 @@ For that reason, they are implemented here in one file. """ - +from canonicaljson import encode_canonical_json import contextvars import re from contextlib import ContextDecorator @@ -194,6 +194,6 @@ def ga4gh_serialize(obj: BaseModel, as_version: PrevVrsVersion | None = None) -> PrevVrsVersion.validate(as_version) if as_version is None: - return obj.model_dump_json().encode("utf-8") + return encode_canonical_json(obj.ga4gh_serialize()) else: return obj.ga4gh_serialize_as_version(as_version) diff --git a/src/ga4gh/vrs/models.py b/src/ga4gh/vrs/models.py index e787694b..cd5848ad 100644 --- a/src/ga4gh/vrs/models.py +++ b/src/ga4gh/vrs/models.py @@ -26,7 +26,8 @@ ) from ga4gh.core.pydantic import get_pydantic_root -from pydantic import BaseModel, Field, RootModel, StringConstraints, model_serializer, ConfigDict +from canonicaljson import encode_canonical_json +from pydantic import BaseModel, Field, RootModel, StringConstraints, ConfigDict from ga4gh.core.pydantic import ( getattr_in @@ -178,7 +179,7 @@ def _recurse_ga4gh_serialize(obj): elif isinstance(obj, _ValueObject): return obj.ga4gh_serialize() elif isinstance(obj, RootModel): - return _recurse_ga4gh_serialize(obj.model_dump(mode='json')) + return _recurse_ga4gh_serialize(obj.model_dump()) elif isinstance(obj, str): return obj elif isinstance(obj, list): @@ -193,9 +194,8 @@ class _ValueObject(DomainEntity, ABC): """ def __hash__(self): - return self.model_dump_json().__hash__() + return encode_canonical_json(self.ga4gh_serialize()).decode("utf-8").__hash__() - @model_serializer(when_used='json') def ga4gh_serialize(self) -> Dict: out = OrderedDict() for k in self.ga4gh.keys: @@ -242,7 +242,7 @@ def compute_digest(self, store=True, as_version: PrevVrsVersion | None = None) - returned following the conventions of the VRS version indicated by ``as_version_``. """ if as_version is None: - digest = sha512t24u(self.model_dump_json().encode("utf-8")) + digest = sha512t24u(encode_canonical_json(self.ga4gh_serialize())) if store: self.digest = digest else: @@ -580,7 +580,6 @@ class CisPhasedBlock(_VariationBase): ) sequenceReference: Optional[SequenceReference] = Field(None, description="An optional Sequence Reference on which all of the in-cis Alleles are found. When defined, this may be used to implicitly define the `sequenceReference` attribute for each of the CisPhasedBlock member Alleles.") - @model_serializer(when_used="json") def ga4gh_serialize(self) -> Dict: out = _ValueObject.ga4gh_serialize(self) out["members"] = sorted(out["members"]) diff --git a/tests/validation/test_models.py b/tests/validation/test_models.py index f09f5db8..63679362 100644 --- a/tests/validation/test_models.py +++ b/tests/validation/test_models.py @@ -24,7 +24,7 @@ def ga4gh_1_3_serialize(*args, **kwargs): return ga4gh_serialize(*args, **kwargs) fxs = { - "ga4gh_serialize": lambda o: ga4gh_serialize(o).decode() if ga4gh_serialize(o) else None, + "ga4gh_serialize": ga4gh_serialize, "ga4gh_digest": ga4gh_digest, "ga4gh_identify": ga4gh_identify, "ga4gh_1_3_digest": ga4gh_1_3_digest, @@ -60,6 +60,8 @@ def flatten_tests(vts): def test_validation(cls, data, fn, exp): o = getattr(models, cls)(**data) fx = fxs[fn] + if fn == "ga4gh_serialize": + exp = exp.encode("utf-8") assert fx(o) == exp