Skip to content

Commit

Permalink
build: use pydantic-compat library (#124)
Browse files Browse the repository at this point in the history
* build: use pydantic-compat

* more usage

* more changes

* use config support

* add dep

* fix typing
  • Loading branch information
tlambert03 authored Aug 1, 2023
1 parent 91fb059 commit 6fa9f31
Show file tree
Hide file tree
Showing 15 changed files with 78 additions and 230 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,5 @@ repos:
additional_dependencies:
- types-PyYAML
- pydantic>=2
- pydantic-compat
files: "^src/"
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ classifiers = [
"Typing :: Typed",
]
dynamic = ["version"]
dependencies = ["pydantic >=1.7", "numpy"]
dependencies = ["pydantic >=1.7", "numpy", "pydantic-compat >=0.0.1"]

# extras
# https://peps.python.org/pep-0621/#dependencies-optional-dependencies
Expand Down
6 changes: 2 additions & 4 deletions src/useq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@
"ZTopBottom",
]

from useq._pydantic_compat import model_rebuild

model_rebuild(MDAEvent, MDASequence=MDASequence)
model_rebuild(Position, MDASequence=MDASequence)
del model_rebuild
MDAEvent.model_rebuild(MDASequence=MDASequence)
Position.model_rebuild(MDASequence=MDASequence)
81 changes: 16 additions & 65 deletions src/useq/_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@

import numpy as np
from pydantic import BaseModel

from useq._pydantic_compat import PYDANTIC2, model_dump, model_fields
from pydantic_compat import PydanticCompatMixin

if TYPE_CHECKING:
from pydantic import ConfigDict

ReprArgs = Sequence[Tuple[Optional[str], Any]]
IncEx = set[int] | set[str] | dict[int, Any] | dict[str, Any] | None

Expand All @@ -31,21 +32,13 @@
_Y = TypeVar("_Y", bound="UseqModel")


class FrozenModel(BaseModel):
if PYDANTIC2:
model_config = {
"populate_by_name": True,
"extra": "ignore",
"frozen": True,
}

else:

class Config:
allow_population_by_field_name = True
extra = "ignore"
frozen = True
json_encoders: ClassVar[dict] = {MappingProxyType: dict}
class FrozenModel(PydanticCompatMixin, BaseModel):
model_config: ClassVar[ConfigDict] = {
"populate_by_name": True,
"extra": "ignore",
"frozen": True,
"json_encoders": {MappingProxyType: dict},
}

def replace(self: _T, **kwargs: Any) -> _T:
"""Return a new instance replacing specified kwargs with new values.
Expand All @@ -58,63 +51,22 @@ def replace(self: _T, **kwargs: Any) -> _T:
will perform validation and casting on the new values, whereas `copy` assumes
that all objects are valid and will not perform any validation or casting.
"""
state = model_dump(self, exclude={"uid"})
state = self.model_dump(exclude={"uid"})
return type(self)(**{**state, **kwargs})

if PYDANTIC2:
# retain pydantic1's json method
def json(
self,
*,
indent: int | None = None, # type: ignore
include: IncEx = None,
exclude: IncEx = None, # type: ignore
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False, # type: ignore
round_trip: bool = False,
warnings: bool = True,
) -> str:
return super().model_dump_json(
indent=indent,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
round_trip=round_trip,
warnings=warnings,
)

# we let this one be deprecated
# def dict()

elif not TYPE_CHECKING:
# Backport pydantic2 methods so that useq-0.1.0 can be used with pydantic1

def model_dump_json(self, **kwargs: Any) -> str:
"""Backport of pydantic2's model_dump_json method."""
return self.json(**kwargs)

def model_dump(self, **kwargs: Any) -> dict[str, Any]:
"""Backport of pydantic2's model_dump_json method."""
return self.dict(**kwargs)


class UseqModel(FrozenModel):
def __repr_args__(self) -> ReprArgs:
"""Only show fields that are not None or equal to their default value."""
return [
(k, val)
for k, val in super().__repr_args__()
if k in model_fields(self)
if k in self.model_fields
and val
!= (
factory()
if (factory := model_fields(self)[k].default_factory) is not None
else model_fields(self)[k].default
if (factory := self.model_fields[k].default_factory) is not None
else self.model_fields[k].default
)
]

Expand All @@ -133,7 +85,7 @@ def from_file(cls: Type[_Y], path: Union[str, Path]) -> _Y:
else: # pragma: no cover
raise ValueError(f"Unknown file type: {path.suffix}")

return cls.model_validate(obj) if PYDANTIC2 else cls.parse_obj(obj)
return cls.model_validate(obj)

@classmethod
def parse_file(cls: Type[_Y], path: Union[str, Path], **kwargs: Any) -> _Y:
Expand Down Expand Up @@ -180,8 +132,7 @@ def yaml(
np.floating, lambda dumper, d: dumper.represent_float(float(d))
)

data = model_dump(
self,
data = self.model_dump(
include=include,
exclude=exclude,
by_alias=by_alias,
Expand Down
25 changes: 15 additions & 10 deletions src/useq/_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,24 @@
import math
from enum import Enum
from functools import partial
from typing import Any, Callable, Iterator, NamedTuple, Optional, Sequence, Tuple, Union
from typing import (
Any,
Callable,
ClassVar,
Iterator,
NamedTuple,
Optional,
Sequence,
Tuple,
Union,
)

import numpy as np
from pydantic import Field
from pydantic import ConfigDict, Field
from pydantic_compat import field_validator

from useq._base_model import FrozenModel
from useq._pydantic_compat import FROZEN, PYDANTIC2, field_validator
from useq._pydantic_compat import FROZEN


class RelativeTo(Enum):
Expand Down Expand Up @@ -123,13 +134,7 @@ class _GridPlan(FrozenModel):
"""

# Overriding FrozenModel to make fov_width and fov_height mutable.
if PYDANTIC2:
model_config = {"validate_assignment": True, "frozen": False}
else:

class Config:
validate_assignment = True
frozen = False
model_config: ClassVar[ConfigDict] = {"validate_assignment": True, "frozen": False}

overlap: Tuple[float, float] = Field((0.0, 0.0), **FROZEN) # type: ignore
mode: OrderMode = Field(OrderMode.row_wise_snake, **FROZEN) # type: ignore
Expand Down
3 changes: 1 addition & 2 deletions src/useq/_hardware_autofocus.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from useq._actions import HardwareAutofocus
from useq._base_model import FrozenModel
from useq._mda_event import MDAEvent
from useq._pydantic_compat import model_copy


class AutoFocusPlan(FrozenModel):
Expand Down Expand Up @@ -46,7 +45,7 @@ def event(self, event: MDAEvent) -> Optional[MDAEvent]:
if zplan and zplan.is_relative and "z" in event.index:
updates["z_pos"] = event.z_pos - list(zplan)[event.index["z"]]

return model_copy(event, update=updates)
return event.model_copy(update=updates)

def should_autofocus(self, event: MDAEvent) -> bool:
"""Method that must be implemented by a subclass.
Expand Down
13 changes: 6 additions & 7 deletions src/useq/_iter_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
)
from useq._mda_event import Channel as EventChannel
from useq._mda_event import MDAEvent
from useq._pydantic_compat import model_construct, model_copy
from useq._utils import AXES, Axis, _has_axes
from useq._z import AnyZPlan # noqa: TCH001 # noqa: TCH001

Expand Down Expand Up @@ -100,7 +99,7 @@ def iter_sequence(sequence: MDASequence) -> Iterator[MDAEvent]:
for axis, idx in this_e.index.items()
if idx != next_e.index[axis]
):
this_e = model_copy(this_e, update={"keep_shutter_open": True})
this_e = this_e.model_copy(update={"keep_shutter_open": True})
yield this_e
this_e = next_e
yield this_e
Expand Down Expand Up @@ -169,8 +168,8 @@ def _iter_sequence(
if position and position.name:
event_kwargs["pos_name"] = position.name
if channel:
event_kwargs["channel"] = model_construct(
EventChannel, config=channel.config, group=channel.group
event_kwargs["channel"] = EventChannel.model_construct(
config=channel.config, group=channel.group
)
if channel.exposure is not None:
event_kwargs["exposure"] = channel.exposure
Expand Down Expand Up @@ -205,8 +204,8 @@ def _iter_sequence(
# if the sub-sequence doe not have an autofocus plan, we override it
# with the parent sequence's autofocus plan
if not sub_seq.autofocus_plan:
sub_seq = model_copy(
sub_seq, update={"autofocus_plan": autofocus_plan}
sub_seq = sub_seq.model_copy(
update={"autofocus_plan": autofocus_plan}
)

# recurse into the sub-sequence
Expand All @@ -223,7 +222,7 @@ def _iter_sequence(
elif position.sequence is not None and position.sequence.autofocus_plan:
autofocus_plan = position.sequence.autofocus_plan

event = model_construct(MDAEvent, **event_kwargs)
event = MDAEvent.model_construct(**event_kwargs)
if autofocus_plan:
af_event = autofocus_plan.event(event)
if af_event:
Expand Down
8 changes: 6 additions & 2 deletions src/useq/_mda_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@

from useq._actions import AcquireImage, AnyAction
from useq._base_model import UseqModel
from useq._pydantic_compat import PYDANTIC2, field_serializer

try:
from pydantic import field_serializer
except ImportError:
field_serializer = None # type: ignore

if TYPE_CHECKING:
from useq._mda_sequence import MDASequence
Expand Down Expand Up @@ -166,7 +170,7 @@ def to_pycromanager(self) -> "PycroManagerEvent":

return to_pycromanager(self)

if PYDANTIC2:
if field_serializer is not None:
_si = field_serializer("index", mode="plain")(lambda v: dict(v))
_sx = field_serializer("x_pos", mode="plain")(_float_or_none)
_sy = field_serializer("y_pos", mode="plain")(_float_or_none)
Expand Down
35 changes: 12 additions & 23 deletions src/useq/_mda_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,14 @@

import numpy as np
from pydantic import Field, PrivateAttr
from pydantic_compat import field_validator, model_validator

from useq._base_model import UseqModel
from useq._channel import Channel
from useq._grid import AnyGridPlan, GridPosition # noqa: TCH001
from useq._hardware_autofocus import AnyAutofocusPlan, AxesBasedAF
from useq._iter_sequence import iter_sequence
from useq._position import Position
from useq._pydantic_compat import (
field_validator,
model_construct,
model_dump,
model_validator,
pydantic_1_style_root_dict,
)
from useq._time import AnyTimePlan # noqa: TCH001
from useq._utils import AXES, Axis, TimeEstimate, estimate_sequence_duration
from useq._z import AnyZPlan # noqa: TCH001
Expand Down Expand Up @@ -176,7 +170,7 @@ def _validate_channels(cls, value: Any) -> Tuple[Channel, ...]:
if isinstance(v, Channel):
channels.append(v)
elif isinstance(v, str):
channels.append(model_construct(Channel, config=v))
channels.append(Channel.model_construct(config=v))
elif isinstance(v, dict):
channels.append(Channel(**v))
else: # pragma: no cover
Expand Down Expand Up @@ -230,22 +224,17 @@ def _validate_axis_order(cls, v: Any) -> str:
@model_validator(mode="after")
@classmethod
def _validate_mda(cls, values: Any) -> Any:
# this strange bit here is to deal with the fact that in pydantic1
# root_validator after returned a dict of {field_name -> validated_value}
# but in pydantic2 it returns the complete validated model instance
_values = pydantic_1_style_root_dict(cls, values)

if "axis_order" in _values:
if values.axis_order:
cls._check_order(
_values["axis_order"],
z_plan=_values.get("z_plan"),
stage_positions=_values.get("stage_positions", ()),
channels=_values.get("channels", ()),
grid_plan=_values.get("grid_plan"),
autofocus_plan=_values.get("autofocus_plan"),
values.axis_order,
z_plan=values.z_plan,
stage_positions=values.stage_positions,
channels=values.channels,
grid_plan=values.grid_plan,
autofocus_plan=values.autofocus_plan,
)
if "stage_positions" in _values:
for p in _values["stage_positions"]:
if values.stage_positions:
for p in values.stage_positions:
if hasattr(p, "sequence") and getattr(
p.sequence, "keep_shutter_open_across", None
): # pragma: no cover
Expand All @@ -259,7 +248,7 @@ def __eq__(self, other: Any) -> bool:
"""Return `True` if two `MDASequences` are equal (uid is excluded)."""
if isinstance(other, MDASequence):
return bool(
model_dump(self, exclude={"uid"}) == model_dump(other, exclude={"uid"})
self.model_dump(exclude={"uid"}) == other.model_dump(exclude={"uid"})
)
else:
return False
Expand Down
Loading

0 comments on commit 6fa9f31

Please sign in to comment.