Skip to content

Commit

Permalink
Even simpler form to set NdArray types
Browse files Browse the repository at this point in the history
  • Loading branch information
PaulVanSchayck committed Nov 15, 2024
1 parent 2ae9ec5 commit 329d0f7
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 78 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ from datetime import datetime, timezone
from pydantic import AwareDatetime
from covjson_pydantic.coverage import Coverage
from covjson_pydantic.domain import Domain, Axes, ValuesAxis, DomainType
from covjson_pydantic.ndarray import NdArray
from covjson_pydantic.ndarray import NdArrayFloat

c = Coverage(
domain=Domain(
Expand All @@ -51,7 +51,7 @@ c = Coverage(
)
),
ranges={
"temperature": NdArray[float](axisNames=["x", "y", "t"], shape=[1, 1, 1], values=[42.0])
"temperature": NdArrayFloat(axisNames=["x", "y", "t"], shape=[1, 1, 1], values=[42.0])
}
)

Expand Down
6 changes: 2 additions & 4 deletions example.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from covjson_pydantic.domain import Domain
from covjson_pydantic.domain import DomainType
from covjson_pydantic.domain import ValuesAxis
from covjson_pydantic.ndarray import NdArray
from covjson_pydantic.ndarray import NdArrayFloat
from pydantic import AwareDatetime

c = Coverage(
Expand All @@ -18,7 +18,5 @@
t=ValuesAxis[AwareDatetime](values=[datetime.now(tz=timezone.utc)]),
),
),
ranges={"temperature": NdArray[float](axisNames=["x", "y", "t"], shape=[1, 1, 1], values=[42.0])},
ranges={"temperature": NdArrayFloat(axisNames=["x", "y", "t"], shape=[1, 1, 1], values=[42.0])},
)

print(c.model_dump_json(exclude_none=True, indent=4))
6 changes: 4 additions & 2 deletions src/covjson_pydantic/coverage.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
from .base_models import CovJsonBaseModel
from .domain import Domain
from .domain import DomainType
from .ndarray import NdArray
from .ndarray import NdArrayFloat
from .ndarray import NdArrayInt
from .ndarray import NdArrayStr
from .ndarray import TiledNdArray
from .parameter import Parameter
from .parameter import ParameterGroup
Expand All @@ -22,7 +24,7 @@ class Coverage(CovJsonBaseModel, extra="allow"):
domain: Domain
parameters: Optional[Dict[str, Parameter]] = None
parameterGroups: Optional[List[ParameterGroup]] = None # noqa: N815
ranges: Dict[str, Union[NdArray[float], NdArray[int], NdArray[str], TiledNdArray, AnyUrl]]
ranges: Dict[str, Union[NdArrayFloat, NdArrayInt, NdArrayStr, TiledNdArray, AnyUrl]]


class CoverageCollection(CovJsonBaseModel, extra="allow"):
Expand Down
77 changes: 24 additions & 53 deletions src/covjson_pydantic/ndarray.py
Original file line number Diff line number Diff line change
@@ -1,69 +1,25 @@
import math
import typing
from enum import Enum
from typing import List
from typing import Literal
from typing import Optional
from typing import Union

from pydantic import model_validator
from typing_extensions import Generic
from typing_extensions import TypeVar

from .base_models import CovJsonBaseModel


class DataType(str, Enum):
float = "float"
str = "string"
int = "integer"


NdArrayTypeT = TypeVar("NdArrayTypeT")


class NdArray(CovJsonBaseModel, Generic[NdArrayTypeT], extra="allow"):
class NdArray(CovJsonBaseModel, extra="allow"):
type: Literal["NdArray"] = "NdArray"
dataType: Union[DataType, None] = None # noqa: N815
dataType: str # Kept here to ensure order of output in JSON # noqa: N815
axisNames: Optional[List[str]] = None # noqa: N815
shape: Optional[List[int]] = None
values: List[Optional[NdArrayTypeT]] = []

@model_validator(mode="before")
@classmethod
def set_data_type(cls, v):
if type(v) is not dict:
return v

if "dataType" in v:
v["dataType"] = DataType(v["dataType"])
return v

t = typing.get_args(cls.model_fields["values"].annotation)[0]
if t == typing.Optional[float]:
v["dataType"] = DataType.float
elif t == typing.Optional[int]:
v["dataType"] = DataType.int
elif t == typing.Optional[str]:
v["dataType"] = DataType.str
else:
raise ValueError(f"Unsupported NdArray type: {t}")
return v

@model_validator(mode="after")
def check_data_type(self):
t = typing.get_args(self.model_fields["values"].annotation)[0]
if t == typing.Optional[NdArrayTypeT]:
given_type = self.dataType.name if isinstance(self.dataType, DataType) else ""
raise ValueError(f"No NdArray type given, please specify as NdArray[{given_type}]")
if self.dataType == DataType.float and not t == typing.Optional[float]:
raise ValueError("dataType and NdArray type must both be float.")
if self.dataType == DataType.str and not t == typing.Optional[str]:
raise ValueError("dataType and NdArray type must both be string.")
if self.dataType == DataType.int and not t == typing.Optional[int]:
raise ValueError("dataType and NdArray type must both be integer.")

return self
def __new__(cls, *args, **kwargs):
if cls is NdArray:
raise TypeError(
"NdArray cannot be instantiated directly, please use a NdArrayFloat, NdArrayInt or NdArrayStr"
)
return super().__new__(cls)

@model_validator(mode="after")
def check_field_dependencies(self):
Expand All @@ -87,6 +43,21 @@ def check_field_dependencies(self):
return self


class NdArrayFloat(NdArray):
dataType: Literal["float"] = "float" # noqa: N815
values: List[Optional[float]]


class NdArrayInt(NdArray):
dataType: Literal["integer"] = "integer" # noqa: N815
values: List[Optional[int]]


class NdArrayStr(NdArray):
dataType: Literal["string"] = "string" # noqa: N815
values: List[Optional[str]]


class TileSet(CovJsonBaseModel):
tileShape: List[Optional[int]] # noqa: N815
urlTemplate: str # noqa: N815
Expand All @@ -95,7 +66,7 @@ class TileSet(CovJsonBaseModel):
# TODO: Validation of field dependencies
class TiledNdArray(CovJsonBaseModel, extra="allow"):
type: Literal["TiledNdArray"] = "TiledNdArray"
dataType: DataType = DataType.float # noqa: N815
dataType: Literal["float"] = "float" # noqa: N815
axisNames: List[str] # noqa: N815
shape: List[int]
tileSets: List[TileSet] # noqa: N815
31 changes: 14 additions & 17 deletions tests/test_coverage.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
from covjson_pydantic.coverage import CoverageCollection
from covjson_pydantic.domain import Axes
from covjson_pydantic.domain import Domain
from covjson_pydantic.ndarray import DataType
from covjson_pydantic.ndarray import NdArray
from covjson_pydantic.ndarray import NdArrayFloat
from covjson_pydantic.ndarray import NdArrayInt
from covjson_pydantic.ndarray import NdArrayStr
from covjson_pydantic.ndarray import TiledNdArray
from covjson_pydantic.parameter import Parameter
from covjson_pydantic.parameter import ParameterGroup
Expand All @@ -34,10 +36,10 @@
("spec-domain-multipoint-series.json", Domain),
("spec-domain-multipoint.json", Domain),
("spec-domain-trajectory.json", Domain),
("ndarray-float.json", NdArray[float]),
("ndarray-string.json", NdArray[str]),
("ndarray-integer.json", NdArray[int]),
("spec-ndarray.json", NdArray[float]),
("ndarray-float.json", NdArrayFloat),
("ndarray-string.json", NdArrayStr),
("ndarray-integer.json", NdArrayInt),
("spec-ndarray.json", NdArrayFloat),
("spec-tiled-ndarray.json", TiledNdArray),
("continuous-data-parameter.json", Parameter),
("categorical-data-parameter.json", Parameter),
Expand Down Expand Up @@ -69,11 +71,10 @@ def test_happy_cases(file_name, object_type):
("point-series-domain-no-t.json", Domain, r"A 'PointSeries' must have a 't'-axis."),
("mixed-type-axes.json", Axes, r"Input should be a valid number"),
("mixed-type-axes-2.json", Axes, r"Input should be a valid string"),
("mixed-type-ndarray-1.json", NdArray[float], r"Input should be a valid number"),
("mixed-type-ndarray-1.json", NdArray, r"No NdArray type given, please specify as NdArray\[float\]"),
("mixed-type-ndarray-2.json", NdArray[str], r"dataType and NdArray type must both be float"),
("mixed-type-ndarray-3.json", NdArray[int], r"Input should be a valid integer"),
("mixed-type-ndarray-3.json", NdArray[float], r"dataType and NdArray type must both be integer"),
("mixed-type-ndarray-1.json", NdArrayFloat, r"Input should be a valid number"),
("mixed-type-ndarray-2.json", NdArrayStr, r"Input should be 'string'"),
("mixed-type-ndarray-3.json", NdArrayInt, r"Input should be a valid integer"),
("mixed-type-ndarray-3.json", NdArrayFloat, r"Input should be 'float'"),
]


Expand All @@ -89,10 +90,6 @@ def test_error_cases(file_name, object_type, error_message):
object_type.model_validate_json(json_string)


def test_ndarray_set_datetype():
nd = NdArray[float](axisNames=["x", "y", "t"], shape=[1, 1, 1], values=[42.0])
assert nd.dataType == DataType.float
nd = NdArray[int](axisNames=["x", "y", "t"], shape=[1, 1, 1], values=[42])
assert nd.dataType == DataType.int
nd = NdArray[str](axisNames=["x", "y", "t"], shape=[1, 1, 1], values=["foo"])
assert nd.dataType == DataType.str
def test_ndarray_directly():
with pytest.raises(TypeError, match="NdArray cannot be instantiated directly"):
NdArray(axisNames=["x", "y", "t"], shape=[1, 1, 1], values=[42.0])

0 comments on commit 329d0f7

Please sign in to comment.