Skip to content
This repository was archived by the owner on Nov 8, 2024. It is now read-only.

Commit

Permalink
Fix some type issues
Browse files: browse the repository at this point in the history
  • Loading branch information
jl-wynen committed Oct 23, 2024
1 parent 6e98775 commit 8e9ba48
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 37 deletions.
8 changes: 5 additions & 3 deletions src/sqomega/_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,11 @@ def create(self) -> Generator[Sqw, None, None]:
descriptor = block_descriptors[name]
match descriptor.block_type:
case SqwDataBlockType.regular:
sqw_io.write_raw(buffer)
# Type guaranteed by _serialize_data_blocks
sqw_io.write_raw(buffer) # type: ignore[arg-type]
case SqwDataBlockType.pix:
self._pix_placeholder.write(sqw_io)
# Type guaranteed by _serialize_data_blocks
self._pix_placeholder.write(sqw_io) # type: ignore[union-attr]
sqw_io.seek(descriptor.position + descriptor.size)
case _:
raise NotImplementedError(
Expand Down Expand Up @@ -175,7 +177,7 @@ def _serialize_data_blocks(
dict[DataBlockName, SqwDataBlockDescriptor],
]:
data_blocks = self._prepare_data_blocks()
buffers = {}
buffers: dict[DataBlockName, memoryview | None] = {}
descriptors = {}
for name, data_block in data_blocks.items():
buffer = BytesIO()
Expand Down
9 changes: 6 additions & 3 deletions src/sqomega/_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ class ObjectArray:
@dataclass(kw_only=True)
class CellArray:
    """Cell array whose items are stored as nested arrays.

    Every item is wrapped in its own nested array so that the type of each
    item is encoded together with the item. An item may itself be a cell
    array.
    """

    # Shape of the cell array.
    shape: tuple[int, ...]
    # Nested object array to encode types of each item.
    data: list[ObjectArray | CellArray]
    # Type tag shared by all instances; not a dataclass field.
    ty: ClassVar[TypeTag] = TypeTag.cell


Expand Down Expand Up @@ -119,7 +120,7 @@ class Datetime:

class Serializable(ABC):
@abstractmethod
def _serialize_to_dict(self) -> dict[str, Object]: ...
def _serialize_to_dict(self) -> dict[str, Object | ObjectArray | CellArray]: ...

def serialize_to_ir(self) -> Struct:
fields = self._serialize_to_dict()
Expand All @@ -135,7 +136,9 @@ def prepare_for_serialization(self: _T, filename: str, filepath: str) -> _T: #
return self


def _serialize_field(field: Object) -> ObjectArray:
def _serialize_field(
field: Object | ObjectArray | CellArray,
) -> ObjectArray | CellArray:
if isinstance(field, ObjectArray | CellArray):
return field
if isinstance(field, Datetime):
Expand Down
62 changes: 44 additions & 18 deletions src/sqomega/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
import enum
from dataclasses import dataclass, field, replace
from datetime import datetime, timezone
from typing import ClassVar
from typing import ClassVar, TypeAlias

import numpy as np
import scipp as sc

from . import _ir as ir

DataBlockName = tuple[str, str]
DataBlockName: TypeAlias = tuple[str, str]


class SqwFileType(enum.Enum):
Expand Down Expand Up @@ -54,7 +54,9 @@ class SqwMainHeader(ir.Serializable):
serial_name: ClassVar[str] = "main_header_cl"
version: ClassVar[float] = 2.0

def _serialize_to_dict(self) -> dict[str, ir.Object]:
def _serialize_to_dict(
self,
) -> dict[str, ir.Object | ir.ObjectArray | ir.CellArray]:
return {
"serial_name": ir.String(self.serial_name),
"version": ir.F64(self.version),
Expand Down Expand Up @@ -86,7 +88,9 @@ class SqwLineAxes(ir.Serializable):
serial_name: ClassVar[str] = "line_axes"
version: ClassVar[float] = 7.0

def _serialize_to_dict(self) -> dict[str, ir.Object]:
def _serialize_to_dict(
self,
) -> dict[str, ir.Object | ir.ObjectArray | ir.CellArray]:
units = ['1/angstrom'] * 3 + ['meV'] # depends on SqwLineProj.type

return {
Expand Down Expand Up @@ -127,7 +131,9 @@ class SqwLineProj(ir.Serializable):
serial_name: ClassVar[str] = "line_proj"
version: ClassVar[float] = 7.0

def _serialize_to_dict(self) -> dict[str, ir.Object]:
def _serialize_to_dict(
self,
) -> dict[str, ir.Object | ir.ObjectArray | ir.CellArray]:
if self.type != "aaa":
raise NotImplementedError(f"Projection type not supported: {self.type}")
units = ['1/angstrom'] * 3 + ['meV'] # depends on SqwLineProj.type
Expand Down Expand Up @@ -164,7 +170,9 @@ class SqwDndMetadata(ir.Serializable):
serial_name: ClassVar[str] = "dnd_metadata"
version: ClassVar[float] = 1.0

def _serialize_to_dict(self) -> dict[str, ir.Object]:
def _serialize_to_dict(
self,
) -> dict[str, ir.Object | ir.ObjectArray | ir.CellArray]:
axes = self.axes.serialize_to_ir()
proj = self.proj.serialize_to_ir()

Expand Down Expand Up @@ -201,7 +209,9 @@ class SqwPixelMetadata(ir.Serializable):
serial_name: ClassVar[str] = "pix_metadata"
version: ClassVar[float] = 1.0

def _serialize_to_dict(self) -> dict[str, ir.Object]:
def _serialize_to_dict(
self,
) -> dict[str, ir.Object | ir.ObjectArray | ir.CellArray]:
return {
"serial_name": ir.String(self.serial_name),
"version": ir.F64(self.version),
Expand All @@ -218,7 +228,9 @@ class SqwPixWrap(ir.Serializable):
n_rows: int = 9
n_pixels: int

def _serialize_to_dict(self) -> dict[str, ir.Object]:
def _serialize_to_dict(
self,
) -> dict[str, ir.Object | ir.ObjectArray | ir.CellArray]:
return {
"n_rows": ir.U32(self.n_rows),
"n_pixels": ir.U64(self.n_pixels),
Expand All @@ -234,7 +246,9 @@ class SqwIXSource(ir.Serializable):
serial_name: ClassVar[str] = "IX_source"
version: ClassVar[float] = 2.0

def _serialize_to_dict(self) -> dict[str, ir.Object]:
def _serialize_to_dict(
self,
) -> dict[str, ir.Object | ir.ObjectArray | ir.CellArray]:
return {
"serial_name": ir.String(self.serial_name),
"version": ir.F64(self.version),
Expand All @@ -252,7 +266,9 @@ class SqwIXNullInstrument(ir.Serializable):
serial_name: ClassVar[str] = "IX_null_inst"
version: ClassVar[float] = 2.0

def _serialize_to_dict(self) -> dict[str, ir.Object]:
def _serialize_to_dict(
self,
) -> dict[str, ir.Object | ir.ObjectArray | ir.CellArray]:
return {
"serial_name": ir.String(self.serial_name),
"version": ir.F64(self.version),
Expand All @@ -275,7 +291,9 @@ class SqwIXSample(ir.Serializable):
serial_name: ClassVar[str] = "IX_samp"
version: ClassVar[float] = 0.0

def _serialize_to_dict(self) -> dict[str, ir.Object]:
def _serialize_to_dict(
self,
) -> dict[str, ir.Object | ir.ObjectArray | ir.CellArray]:
return {
"serial_name": ir.String(self.serial_name),
"version": ir.F64(self.version),
Expand Down Expand Up @@ -307,7 +325,9 @@ class SqwIXExperiment(ir.Serializable):
serial_name: ClassVar[str] = "IX_experiment"
version: ClassVar[float] = 3.0

def _serialize_to_dict(self) -> dict[str, ir.Object]:
def _serialize_to_dict(
self,
) -> dict[str, ir.Object | ir.ObjectArray | ir.CellArray]:
en = (
self.en.to(unit='meV', dtype='float64', copy=False)
.broadcast(sizes={'_': 1, 'energy_transfer': self.en.shape[0]})
Expand Down Expand Up @@ -342,7 +362,9 @@ class SqwMultiIXExperiment(ir.Serializable):
serial_name: ClassVar[str] = "IX_experiment"
version: ClassVar[float] = 3.0

def _serialize_to_dict(self) -> dict[str, ir.Object]:
def _serialize_to_dict(
self,
) -> dict[str, ir.Object | ir.ObjectArray | ir.CellArray]:
return {
"serial_name": ir.String(self.serial_name),
"version": ir.F64(self.version),
Expand All @@ -362,7 +384,9 @@ class UniqueRefContainer(ir.Serializable):
serial_name: ClassVar[str] = "unique_references_container"
version: ClassVar[float] = 1.0

def _serialize_to_dict(self) -> dict[str, ir.Object]:
def _serialize_to_dict(
self,
) -> dict[str, ir.Object | ir.ObjectArray | ir.CellArray]:
return {
"serial_name": ir.String(self.serial_name),
"version": ir.F64(self.version),
Expand All @@ -381,7 +405,9 @@ class UniqueObjContainer(ir.Serializable):
serial_name: ClassVar[str] = "unique_objects_container"
version: ClassVar[float] = 1.0

def _serialize_to_dict(self) -> dict[str, ir.Object]:
def _serialize_to_dict(
self,
) -> dict[str, ir.Object | ir.ObjectArray | ir.CellArray]:
return {
"serial_name": ir.String(self.serial_name),
"version": ir.F64(self.version),
Expand All @@ -400,7 +426,7 @@ def _serialize_to_dict(self) -> dict[str, ir.Object]:


def _angle_value(x: sc.Variable) -> float:
    """Return the value of ``x`` after converting it to radians as float64.

    The conversion is requested with ``copy=False``.
    """
    # The ignore silences the type checker's no-any-return complaint about
    # returning ``.value`` from a function annotated as returning float.
    return x.to(unit='rad', dtype='float64', copy=False).value  # type: ignore[no-any-return]


def _serialize_str_array(strings: list[str]) -> ir.CellArray:
Expand All @@ -414,10 +440,10 @@ def _serialize_str_array(strings: list[str]) -> ir.CellArray:


def _serialize_multi_unit_array(data: list[sc.Variable], units: list[str]) -> ir.Array:
    """Stack variables into a single float64 array, each converted to its own unit.

    ``data`` and ``units`` are paired element-wise. They must have the same
    length; ``zip(..., strict=True)`` raises ``ValueError`` otherwise.
    """
    # Use a separate name for the stacked result so the ``list[sc.Variable]``
    # parameter is not rebound to a NumPy array.
    stacked = np.stack(
        [d.to(unit=u, dtype='float64').values for d, u in zip(data, units, strict=True)]
    )
    return ir.Array(stacked, ty=ir.TypeTag.f64)


def _variable_to_float_array(var: sc.Variable, unit: str | None) -> ir.Array:
Expand Down
8 changes: 5 additions & 3 deletions src/sqomega/_read_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

"""Implementations of readers and writers for SQW object types."""

from collections.abc import Callable
from collections.abc import Callable, Sequence
from typing import Any, Generic, TypeVar

import numpy as np
Expand All @@ -14,7 +14,9 @@

# Array shape as a tuple of dimension sizes.
_Shape = tuple[int, ...]
_T = TypeVar("_T")
# Collections of objects handled by the readers and writers below.
_AnyObjectList = (
    Sequence[ir.Object] | list[ir.ObjectArray | ir.CellArray] | npt.NDArray[Any]
)
# A reader takes the low-level SQW I/O object and a shape and returns the objects;
# a writer takes the low-level SQW I/O object and the objects to write.
_ObjectReader = Callable[[LowLevelSqw, _Shape], _AnyObjectList]
_ObjectWriter = Callable[[LowLevelSqw, _AnyObjectList], None]

Expand Down Expand Up @@ -105,7 +107,7 @@ def _write_cell(sqw_io: LowLevelSqw, objects: _AnyObjectList) -> None:
# Arrays of struct are encoded with both the shape of the object array and the shape of
# the child cell array. Note the check of the shape.
@_READERS.add(ir.TypeTag.struct)
def _read_struct(sqw_io: LowLevelSqw, shape: _Shape) -> list[ir.Object]:
def _read_struct(sqw_io: LowLevelSqw, shape: _Shape) -> Sequence[ir.Object]:
position = sqw_io.position
if not shape:
return []
Expand Down
16 changes: 8 additions & 8 deletions src/sqomega/_sqw.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,10 @@ def _normalize_data_block_name(
name: DataBlockName | str, level2_name: str | None
) -> DataBlockName:
match (name, level2_name):
case (_, _) as n, None:
return n
case (str(n1), str(n2)), None:
return DataBlockName((n1, n2)) # type: ignore[no-any-return, operator]
case str(n1), str(n2):
return DataBlockName((n1, n2))
return DataBlockName((n1, n2)) # type: ignore[no-any-return, operator]
raise TypeError(
"Data block name must be given either as a tuple of two strings or two"
f"separate strings. Got {name!r} and {level2_name!r}."
Expand Down Expand Up @@ -251,16 +251,16 @@ def _get_scalar_struct_field(struct: ir.Struct, name: str) -> Any:
shape = field.shape[1:] if field.ty == ir.TypeTag.char else field.shape
if shape not in ((1,), ()):
raise AbortParse(f"Field '{name}' has non-scalar shape: {shape}")
if isinstance(field.data[0], ir.Struct):
raise AbortParse(f"Field '{name}' contains a nested struct")
if isinstance(field.data[0], ir.Struct | ir.ObjectArray | ir.CellArray):
raise AbortParse(f"Field '{name}' contains a nested struct or cell array")
return field.data[0].value


def _unpack_cell_array(cell_array: ir.CellArray) -> list[Any]:
    """Unpack the items of a cell array into plain values.

    This does not support general cell arrays. It is specialized for 1d cell
    arrays of strings. An item holding exactly one object yields that
    object's ``value``; any other item yields a list of its objects' values.
    """
    data = (obj.data for obj in cell_array.data)
    # The ignore is needed because items are typed as a union and the checker
    # reports a union-attr error on this line.
    return [d[0].value if len(d) == 1 else [x.value for x in d] for d in data]  # type: ignore[union-attr]


def _parse_main_header_cl_2_0(struct: ir.Struct) -> SqwMainHeader:
Expand All @@ -275,8 +275,8 @@ def _parse_main_header_cl_2_0(struct: ir.Struct) -> SqwMainHeader:
def _parse_dnd_metadata_1_0(struct: ir.Struct) -> SqwDndMetadata:
(axes_struct,) = _get_struct_field(struct, "axes").data
(proj_struct,) = _get_struct_field(struct, "proj").data
proj, units = _parse_line_proj_7_0(proj_struct)
axes = _parse_line_axes_7_0(axes_struct, units)
proj, units = _parse_line_proj_7_0(proj_struct) # type: ignore[arg-type]
axes = _parse_line_axes_7_0(axes_struct, units) # type: ignore[arg-type]
return SqwDndMetadata(
axes=axes,
proj=proj,
Expand Down
4 changes: 2 additions & 2 deletions tests/write_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def test_create_writes_main_header(

with Sqw.open(buffer) as sqw:
main_header = sqw.read_data_block(("", "main_header"))
assert main_header.full_filename == "" # because we use a buffer
assert main_header.full_filename == "in_memory" # because we use a buffer
assert main_header.title == "my title"
assert main_header.nfiles == 0
assert (main_header.creation_date - datetime.now(tz=timezone.utc)) < timedelta(
Expand All @@ -123,7 +123,7 @@ def test_register_pixel_data_writes_pix_metadata(

with Sqw.open(buffer) as sqw:
pix_metadata = sqw.read_data_block(("pix", "metadata"))
assert pix_metadata.full_filename == "" # because we use a buffer
assert pix_metadata.full_filename == "in_memory" # because we use a buffer
assert pix_metadata.npix == 13
assert pix_metadata.data_range.shape == (9, 2)

Expand Down

0 comments on commit 8e9ba48

Please sign in to comment.