diff --git a/src/fastcs/attributes.py b/src/fastcs/attributes.py index 0cb871e7..de33dbda 100644 --- a/src/fastcs/attributes.py +++ b/src/fastcs/attributes.py @@ -4,7 +4,7 @@ from enum import Enum from typing import Any, Generic, Protocol, runtime_checkable -from .datatypes import ATTRIBUTE_TYPES, AttrCallback, DataType, T, validate_value +from .datatypes import ATTRIBUTE_TYPES, AttrCallback, DataType, T class AttrMode(Enum): @@ -126,7 +126,9 @@ def __init__( allowed_values=allowed_values, # type: ignore description=description, ) - self._value: T = datatype.dtype() if initial_value is None else initial_value + self._value: T = ( + datatype.initial_value if initial_value is None else initial_value + ) self._update_callback: AttrCallback[T] | None = None self._updater = handler @@ -134,7 +136,7 @@ def get(self) -> T: return self._value async def set(self, value: T) -> None: - self._value = self._datatype.dtype(validate_value(self._datatype, value)) + self._value = self._datatype.cast(value) if self._update_callback is not None: await self._update_callback(self._value) @@ -177,11 +179,11 @@ async def process(self, value: T) -> None: async def process_without_display_update(self, value: T) -> None: if self._process_callback is not None: - await self._process_callback(self._datatype.dtype(value)) + await self._process_callback(self._datatype.cast(value)) async def update_display_without_process(self, value: T) -> None: if self._write_display_callback is not None: - await self._write_display_callback(self._datatype.dtype(value)) + await self._write_display_callback(self._datatype.cast(value)) def set_process_callback(self, callback: AttrCallback[T] | None) -> None: self._process_callback = callback @@ -221,6 +223,6 @@ def __init__( ) async def process(self, value: T) -> None: - await self.set(validate_value(self._datatype, value)) + await self.set(value) await super().process(value) # type: ignore diff --git a/src/fastcs/datatypes.py b/src/fastcs/datatypes.py index 7ffb8157..e3fd843b 100644 --- a/src/fastcs/datatypes.py +++ b/src/fastcs/datatypes.py @@ -3,7 +3,7 @@ from abc import abstractmethod from collections.abc import Awaitable, Callable from dataclasses import dataclass -from typing import Generic, TypeVar +from typing import Any, Generic, TypeVar T = TypeVar("T", int, float, bool, str) ATTRIBUTE_TYPES: tuple[type] = T.__constraints__ # type: ignore @@ -21,6 +21,18 @@ class DataType(Generic[T]): def dtype(self) -> type[T]: # Using property due to lack of Generic ClassVars pass + @abstractmethod + def cast(self, value: T) -> Any: + """Cast a value to a more primative datatype for `Attribute` push. + + Also validate it against fields in the datatype. + """ + pass + + @property + def initial_value(self) -> T: + return self.dtype() + T_Numerical = TypeVar("T_Numerical", int, float) @@ -33,6 +45,13 @@ class _Numerical(DataType[T_Numerical]): min_alarm: int | None = None max_alarm: int | None = None + def cast(self, value: T_Numerical) -> T_Numerical: + if self.min is not None and value < self.min: + raise ValueError(f"Value {value} is less than minimum {self.min}") + if self.max is not None and value > self.max: + raise ValueError(f"Value {value} is greater than maximum {self.max}") + return value + @dataclass(frozen=True) class Int(_Numerical[int]): @@ -65,6 +84,9 @@ class Bool(DataType[bool]): def dtype(self) -> type[bool]: return bool + def cast(self, value: bool) -> bool: + return value + @dataclass(frozen=True) class String(DataType[str]): @@ -74,14 +96,5 @@ class String(DataType[str]): def dtype(self) -> type[str]: return str - -def validate_value(datatype: DataType[T], value: T) -> T: - """Validate a value against a datatype.""" - - if isinstance(datatype, (Int | Float)): - assert isinstance(value, (int | float)), f"Value {value} is not a number" - if datatype.min is not None and value < datatype.min: - raise ValueError(f"Value {value} is less than minimum {datatype.min}") - if datatype.max is not None and value > datatype.max: - raise ValueError(f"Value {value} is greater than maximum {datatype.max}") - return value + def cast(self, value: str) -> str: + return value