Skip to content

Commit

Permalink
ENH: initial round-tripping of pmd wavefront format
Browse files Browse the repository at this point in the history
  • Loading branch information
ken-lauer committed Dec 16, 2024
1 parent 6eacca7 commit fb43d19
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 21 deletions.
69 changes: 59 additions & 10 deletions pmd_beamphysics/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,41 @@
import getpass
import platform
from collections.abc import Sequence
from typing import TypeVar

from typing_extensions import Literal

import h5py
import numpy as np

from . import tools
from .types import Dataclass
from .units import pmd_unit, known_unit

PolarizationDirection = Literal["x", "y", "z"]


def get_pmd_metadata_dict(
obj: Dataclass,
attrs: Sequence[str],
) -> dict[str, str | float | None]:
def python_attrs_to_pmd_keys(obj: Dataclass | type[Dataclass]) -> dict[str, str]:
assert dataclasses.is_dataclass(obj)
attr_to_field = {
return {
fld.name: fld.metadata.get("pmd_key", fld.name)
for fld in dataclasses.fields(obj)
}


def hdf_to_python_attrs(obj: Dataclass | type[Dataclass]) -> dict[str, str]:
assert dataclasses.is_dataclass(obj)
return {
fld.metadata.get("pmd_key", fld.name): fld.name
for fld in dataclasses.fields(obj)
}


def get_pmd_metadata_dict(
obj: Dataclass,
attrs: Sequence[str],
) -> dict[str, str | float | None]:
attr_to_field = python_attrs_to_pmd_keys(obj)
return {
attr_to_field[attr]: getattr(obj, attr)
for attr in attrs
Expand All @@ -35,6 +51,27 @@ def _key(pmd_key: str):
return {"pmd_key": pmd_key}


_T = TypeVar("_T", bound=Dataclass)


def _dataclass_from_hdf5(cls: type[_T], h5: h5py.Group) -> _T:
hdf_key_to_attr = hdf_to_python_attrs(cls)

def maybe_decode(value):
if isinstance(value, bytes):
return value.decode()
if isinstance(value, np.ndarray):
return tuple(value.tolist())
return value

values = {
attr: maybe_decode(h5.attrs[hdf_key])
for hdf_key, attr in hdf_key_to_attr.items()
if hdf_key in h5.attrs
}
return cls(**values)


@dataclasses.dataclass
class BaseMetadata:
"""Base metadata for OpenPMD spec files."""
Expand Down Expand Up @@ -170,7 +207,6 @@ class WavefrontMetadata:
base: BaseMetadata = dataclasses.field(default_factory=BaseMetadata)
iteration: IterationMetadata = dataclasses.field(default_factory=IterationMetadata)
mesh: MeshMetadata = dataclasses.field(default_factory=MeshMetadata)

polarization: PolarizationDirection = dataclasses.field(default="x")
beamline: str = dataclasses.field(default="")
radius_of_curvature_x: float | None = dataclasses.field(
Expand All @@ -190,9 +226,9 @@ class WavefrontMetadata:
metadata=_key("deltaRadiusOfCurvatureY"),
)
z_coordinate: float = dataclasses.field(default=0.0, metadata=_key("zCoordinate"))
pads: tuple[int, ...] = dataclasses.field(
default_factory=tuple, metadata=_key("pads")
)
# pads: tuple[int, ...] = dataclasses.field(
# default_factory=tuple, metadata=_key("pads")
# )

@property
def attrs(self) -> dict[str, str | float | None]:
Expand All @@ -205,7 +241,7 @@ def attrs(self) -> dict[str, str | float | None]:
"radius_of_curvature_y",
"delta_radius_of_curvature_x",
"delta_radius_of_curvature_y",
"pads",
# "pads",
],
)

Expand Down Expand Up @@ -235,3 +271,16 @@ def from_dict(cls, md: dict) -> WavefrontMetadata:
iteration=IterationMetadata(**iteration_md),
**md,
)

@classmethod
def from_hdf5(cls, base_h5: h5py.Group, field_h5: h5py.Group, rmesh_h5: h5py.Group):
md = _dataclass_from_hdf5(cls, field_h5)
md.base = _dataclass_from_hdf5(BaseMetadata, base_h5.require_group("/"))
md.iteration = _dataclass_from_hdf5(IterationMetadata, base_h5)
md.mesh = _dataclass_from_hdf5(MeshMetadata, rmesh_h5)
if isinstance(md.base.date, str):
try:
md.base.date = datetime.datetime.fromisoformat(md.base.date)
except Exception:
md.base.date = tools.current_date_with_tzinfo()
return md
69 changes: 58 additions & 11 deletions pmd_beamphysics/wavefront.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import scipy.fft
from mpl_toolkits.axes_grid1 import make_axes_locatable

from . import writers
from . import readers, writers
from .metadata import PolarizationDirection, WavefrontMetadata
from .units import known_unit, nice_array

Expand Down Expand Up @@ -731,6 +731,12 @@ def calculate_phasors(
)


def wavelength_to_photon_energy(wavelength: float) -> float:
h = scipy.constants.value("Planck constant in eV/Hz") / (2 * np.pi)
freq = scipy.constants.speed_of_light / wavelength
return h * freq


class Wavefront:
"""
Particle field wavefront.
Expand Down Expand Up @@ -1202,9 +1208,7 @@ def k0(self) -> float:
@property
def photon_energy(self) -> float:
"""Photon energy [eV]."""
h = scipy.constants.value("Planck constant in eV/Hz") / (2 * np.pi)
freq = scipy.constants.speed_of_light / self.wavelength
return h * freq
return wavelength_to_photon_energy(self.wavelength)

@property
def grid_spacing(self) -> tuple[float, ...]:
Expand Down Expand Up @@ -1937,8 +1941,51 @@ def write_genesis4(self, h5: h5py.File | pathlib.Path | str) -> None:
field_file.write_genesis4(h5)

@classmethod
def _from_h5_file(cls, h5: h5py.File) -> Wavefront:
return cls()
def _from_h5_file(cls, h5: h5py.File, identifier: int) -> Wavefront:
base_path = h5.attrs["basePath"].decode()
# data_type = h5.attrs["dataType"]
openpmd_extension = h5.attrs["openPMDextension"].decode()
if "Wavefront" not in openpmd_extension:
raise ValueError(
f"Wavefront extension not enabled in file."
f"Extensions configured: {openpmd_extension}"
)

iteration_path = base_path.replace("%T", str(identifier))
wavefront_field_path = h5.attrs["wavefrontFieldPath"].decode()

group = h5[iteration_path][wavefront_field_path]
if not isinstance(group, h5py.Group):
raise ValueError(
f"Key {group} expected to be a group, but is a {type(group)}"
)

efield = group["electricField"]

photon_energy = efield.attrs["photonEnergy"]
assert isinstance(photon_energy, float)
# efield["photonEnergyUnitSI"]
# efield["photonEnergyUnitDimension"]
# efield["temporalDomain"]
# efield["spatialDomain"]
for polarization in "xyz":
try:
rmesh_group = efield[polarization]
except KeyError:
pass
else:
break
else:
raise ValueError("No supported polarization direction group found")

metadata = WavefrontMetadata.from_hdf5(h5, efield, rmesh_group)

rmesh = readers.component_data(rmesh_group)
return cls(
rmesh=rmesh,
wavelength=wavelength_to_photon_energy(photon_energy),
metadata=metadata,
)

@classmethod
def from_file(
Expand All @@ -1948,9 +1995,9 @@ def from_file(
) -> Wavefront:
"""Load a Wavefront from a file in the OpenPMD format."""
if isinstance(h5, h5py.File):
return cls._from_h5_file(h5)
return cls._from_h5_file(h5, identifier=identifier)
with h5py.File(h5) as h5p:
return cls._from_h5_file(h5p)
return cls._from_h5_file(h5p, identifier=identifier)

# names = get_wavefront_names_from_file("something.h5")
# for name in names:
Expand All @@ -1963,10 +2010,10 @@ def _write_file(self, h5: h5py.File):
md = self.metadata
base_path_template = "/data/%T/"
if md.index is not None:
wavefront_base_path_template = "/wavefront/%T/"
wavefront_base_path_template = "wavefront/%T/"
else:
# For us, at least, second %T doesn't make much sense:
wavefront_base_path_template = "/wavefront"
wavefront_base_path_template = "wavefront"
wavefront_base_path = wavefront_base_path_template.replace("%T", str(md.index))
base_path = base_path_template.replace("%T", str(md.iteration.iteration))

Expand Down Expand Up @@ -1996,7 +2043,7 @@ def _write_file(self, h5: h5py.File):
base_group = h5.create_group(base_path)
writers.write_attrs(base_group, md.iteration.attrs)

electric_field_path = wavefront_path + "electricField/"
electric_field_path = wavefront_path + "/electricField/"
efield_group = h5.create_group(electric_field_path)
self.write_group(efield_group)

Expand Down
30 changes: 30 additions & 0 deletions tests/test_wavefront.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,33 @@ def test_write_and_read_genesis4(
# assert np.allclose(wavefront.rmesh.imag, loaded.rmesh.imag)
# loaded._rmesh = wavefront.rmesh
# assert wavefront == loaded


def test_write_and_read_openpmd(
wavefront: Wavefront,
tmp_path: pathlib.Path,
request: pytest.FixtureRequest,
):
fn = tmp_path / f"{request.node.name}.h5"
wavefront.metadata.mesh.grid_global_offset = (0.0, 0.0, 0.0)

wavefront.write(fn)
loaded = wavefront.from_file(fn).with_padding(wavefront.pad)

# check these individually before testing full equality, so we don't get just a final failure
assert wavefront.grid == loaded.grid
assert np.all(wavefront._rmesh == loaded._rmesh)
assert np.all(wavefront._kmesh == loaded._kmesh)
assert wavefront.wavelength == loaded.wavelength
# TODO we don't store padding
assert wavefront.pad == loaded.pad

# TODO: we don't store microseconds
loaded.metadata.base.date = loaded.metadata.base.date.replace(
microsecond=wavefront.metadata.base.date.microsecond
)
assert wavefront.metadata.base == loaded.metadata.base
assert wavefront.metadata.iteration == loaded.metadata.iteration
assert wavefront.metadata.mesh == loaded.metadata.mesh
assert wavefront.metadata == loaded.metadata
assert wavefront == loaded

0 comments on commit fb43d19

Please sign in to comment.