Skip to content

Commit

Permalink
MAINT: rename field_[rk]space -> [rk]mesh
Browse files Browse the repository at this point in the history
  • Loading branch information
ken-lauer committed Sep 5, 2024
1 parent 214b3c1 commit fdb36a1
Show file tree
Hide file tree
Showing 3 changed files with 371 additions and 113 deletions.
340 changes: 303 additions & 37 deletions docs/examples/wavefront.ipynb

Large diffs are not rendered by default.

120 changes: 57 additions & 63 deletions pmd_beamphysics/wavefront.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,15 +443,13 @@ def drift_kernel_z(


def drift_propagator_z(
field_kspace: np.ndarray,
kmesh: np.ndarray,
domains_kxky: List[np.ndarray],
z: float,
wavelength: float,
):
"""Fresnel propagator in paraxial approximation to distance z [m]."""
return field_kspace * drift_kernel_z(
domains_kxky=domains_kxky, z=z, wavelength=wavelength
)
return kmesh * drift_kernel_z(domains_kxky=domains_kxky, z=z, wavelength=wavelength)


def thin_lens_kernel_xy(
Expand Down Expand Up @@ -618,18 +616,18 @@ def fix(self) -> WavefrontPadding:
"""
return WavefrontPadding(self.grid, fix_padding(self.grid, self.pad))

def get_padded_shape(self, field_rspace: np.ndarray) -> Tuple[int, ...]:
def get_padded_shape(self, rmesh: np.ndarray) -> Tuple[int, ...]:
"""Get the padded shape given a 3D field rspace array."""
nd = len(self.grid)
if field_rspace.ndim != nd:
raise ValueError(f"`field_rspace` is not an {nd}D array")
if rmesh.ndim != nd:
raise ValueError(f"`rmesh` is not an {nd}D array")

# if not all(is_odd(dim) for dim in field_rspace.shape):
# if not all(is_odd(dim) for dim in rmesh.shape):
# raise ValueError(
# f"`field_rspace` dimensions are not all odd numbers: {field_rspace.shape}"
# f"`rmesh` dimensions are not all odd numbers: {rmesh.shape}"
# )

return tuple(dim + 2 * pad for dim, pad in zip(field_rspace.shape, self.pad))
return tuple(dim + 2 * pad for dim, pad in zip(rmesh.shape, self.pad))


def get_range_for_grid_spacing(grid_spacing: float, dim: int) -> Tuple[float, float]:
Expand Down Expand Up @@ -688,7 +686,7 @@ class Wavefront:
Parameters
----------
field_rspace : np.ndarray
rmesh : np.ndarray
Cartesian space field data.
wavelength : float
Wavelength (lambda0) [m].
Expand All @@ -704,15 +702,15 @@ class Wavefront:
Padding for each of the dimensions. Defaults to 40 for the time
dimension and 100 for the remaining dimensions.
unit : pmd_unit, default=V/m
Units of `field_rspace`.
Units of `rmesh`.
See Also
--------
[OpenPMD standard](https://github.com/openPMD/openPMD-standard/blob/upcoming-2.0.0/EXT_Wavefront.md)
"""

_field_rspace: Optional[np.ndarray]
_field_kspace: Optional[np.ndarray]
_rmesh: Optional[np.ndarray]
_kmesh: Optional[np.ndarray]
_phasors: Optional[Tuple[np.ndarray, ...]]
# TODO:
# time snapshot in Z
Expand All @@ -723,7 +721,7 @@ class Wavefront:

def __init__(
self,
field_rspace: np.ndarray,
rmesh: np.ndarray,
*,
wavelength: float,
grid_spacing: Optional[Sequence[float]] = None,
Expand All @@ -734,9 +732,9 @@ def __init__(
units: Optional[pmd_unit] = None,
) -> None:
self._phasors = None
self._field_rspace = field_rspace
self._field_rspace_shape = field_rspace.shape
self._field_kspace = None
self._rmesh = rmesh
self._rmesh_shape = rmesh.shape
self._kmesh = None
self._wavelength = wavelength

self._set_metadata(
Expand Down Expand Up @@ -771,7 +769,7 @@ def _set_metadata(
f"'WavefrontMetadata', 'dict', or None to reset the metadata"
)

ndim = len(self._field_rspace_shape)
ndim = len(self._rmesh_shape)
if isinstance(pad, int):
pad = (pad,) * ndim
if pad is None:
Expand All @@ -781,7 +779,7 @@ def _set_metadata(
else:
pad = (100,) * ndim

self._pad = WavefrontPadding.from_array(self.field_rspace, pad=pad, fix=True)
self._pad = WavefrontPadding.from_array(self.rmesh, pad=pad, fix=True)
md.pads = self._pad.pad
if polarization is not None:
md.polarization = polarization
Expand All @@ -795,36 +793,32 @@ def _set_metadata(
self._metadata = md

def _check_metadata(self) -> None:
if len(self.metadata.mesh.grid_spacing) != len(self._field_rspace_shape):
if len(self.metadata.mesh.grid_spacing) != len(self._rmesh_shape):
raise ValueError(
"'grid_spacing' must have the same number of dimensions as `field_rspace`; "
"'grid_spacing' must have the same number of dimensions as `rmesh`; "
"each should describe the cartesian range of the corresponding axis."
)

if len(self.metadata.mesh.axis_labels) != len(self._field_rspace_shape):
if len(self.metadata.mesh.axis_labels) != len(self._rmesh_shape):
raise ValueError(
"'axis_labels' must have the same number of dimensions as `field_rspace`"
"'axis_labels' must have the same number of dimensions as `rmesh`"
)

def __copy__(self) -> Wavefront:
res = Wavefront.__new__(Wavefront)
res._phasors = self._phasors
res._field_rspace_shape = self._field_rspace_shape
res._field_rspace = self._field_rspace
res._field_kspace = self._field_kspace
res._rmesh_shape = self._rmesh_shape
res._rmesh = self._rmesh
res._kmesh = self._kmesh
res._wavelength = self._wavelength
res._pad = self._pad
res._metadata = copy.deepcopy(self._metadata)
return res

def __deepcopy__(self, memo) -> Wavefront:
res = self.__copy__()
res._field_rspace = (
np.copy(self._field_rspace) if self._field_rspace is not None else None
)
res._field_kspace = (
np.copy(self._field_kspace) if self._field_kspace is not None else None
)
res._rmesh = np.copy(self._rmesh) if self._rmesh is not None else None
res._kmesh = np.copy(self._kmesh) if self._kmesh is not None else None
return res

def __eq__(self, other: Any) -> bool:
Expand All @@ -833,9 +827,9 @@ def __eq__(self, other: Any) -> bool:

return all(
(
self._field_rspace_shape == other._field_rspace_shape,
np.all(self._field_rspace == other._field_rspace),
np.all(self._field_kspace == other._field_kspace),
self._rmesh_shape == other._rmesh_shape,
np.all(self._rmesh == other._rmesh),
np.all(self._kmesh == other._kmesh),
self._wavelength == other._wavelength,
self._pad == other._pad,
self.metadata == other.metadata,
Expand Down Expand Up @@ -885,7 +879,7 @@ def gaussian_pulse(
dtype=dtype,
)
return cls(
field_rspace=pulse,
rmesh=pulse,
wavelength=wavelength,
grid_spacing=grid_spacing,
pad=pad,
Expand All @@ -905,7 +899,7 @@ def _calc_phasors(self) -> Tuple[np.ndarray, ...]:
"""Calculate phasors for each dimension of the cartesian domain."""
coeffs = conversion_coeffs(
wavelength=self._wavelength,
dim=len(self._field_rspace_shape),
dim=len(self._rmesh_shape),
)
shifts = get_shifts(
ranges=self.ranges,
Expand Down Expand Up @@ -933,18 +927,18 @@ def phasors(self) -> Tuple[np.ndarray, ...]:
return self._phasors

@property
def field_rspace(self) -> np.ndarray:
def rmesh(self) -> np.ndarray:
"""Real-space (cartesian space) wavefront field data."""
if self._field_rspace is None:
self._field_rspace = self._ifft()
return self._field_rspace
if self._rmesh is None:
self._rmesh = self._ifft()
return self._rmesh

@property
def field_kspace(self) -> np.ndarray:
def kmesh(self) -> np.ndarray:
"""K-space (reciprocal space) wavefront field data."""
if self._field_kspace is None:
self._field_kspace = self._fft()
return self._field_kspace
if self._kmesh is None:
self._kmesh = self._fft()
return self._kmesh

@property
def polarization(self) -> PolarizationDirection:
Expand All @@ -965,15 +959,15 @@ def _fft(self):
"""
Calculate the FFT (rspace -> kspace) on the user data.
Requires that the `_field_rspace` data is available.
Requires that the `_rmesh` data is available.
This is intended to be handled by the `Wavefront` class itself, such
that the user does not need to pay attention to whether the real-space
or k-space wavefront data is up-to-date.
"""
assert self._field_rspace is not None
assert self._rmesh is not None
workers = get_num_fft_workers()
dfl_pad = _pad_array(self.field_rspace, self._pad.pad_shape)
dfl_pad = _pad_array(self.rmesh, self._pad.pad_shape)
return fft_phased(
dfl_pad,
axes=(0, 1, 2),
Expand All @@ -985,16 +979,16 @@ def _ifft(self):
"""
Calculate the inverse FFT (kspace -> rspace) on the user data.
Requires that the `_field_kspace` data is available.
Requires that the `_kmesh` data is available.
This is intended to be handled by the `Wavefront` class itself, such
that the user does not need to pay attention to whether the real-space
or k-space wavefront data is up-to-date.
"""
assert self._field_kspace is not None
assert self._kmesh is not None
workers = get_num_fft_workers()
return ifft_phased(
self._field_kspace,
self._kmesh,
axes=(0, 1, 2),
phasors=self.phasors,
workers=workers,
Expand Down Expand Up @@ -1026,7 +1020,7 @@ def grid_spacing(self) -> Tuple[float, ...]:
def ranges(self):
return get_ranges_for_grid_spacing(
grid_spacing=self.metadata.mesh.grid_spacing,
dims=self.field_rspace.shape,
dims=self.rmesh.shape,
)

def focus(
Expand Down Expand Up @@ -1060,15 +1054,15 @@ def focus(
wavefront = copy.copy(self)
return wavefront.focus(plane, focus, inplace=True)

self._field_rspace = self.field_rspace * thin_lens_kernel_xy(
self._rmesh = self.rmesh * thin_lens_kernel_xy(
wavelength=self.wavelength,
ranges=self.ranges,
grid=self._pad.grid,
f_lens_x=focus[0],
f_lens_y=focus[1],
)
# Invalidate the spectral data
self._field_kspace = None
self._kmesh = None
return self

def drift(
Expand Down Expand Up @@ -1103,8 +1097,8 @@ def drift(
wavefront = copy.copy(self)
return wavefront.drift(direction, distance, inplace=True)

self._field_kspace = drift_propagator_z(
field_kspace=self.field_kspace,
self._kmesh = drift_propagator_z(
kmesh=self.kmesh,
domains_kxky=domains_kxky(
grids=self._pad.grid,
pads=self._pad.pad,
Expand All @@ -1114,7 +1108,7 @@ def drift(
z=float(distance),
)
# Invalidate the real space data
self._field_rspace = None
self._rmesh = None
return self

def plot(
Expand Down Expand Up @@ -1186,9 +1180,9 @@ def plot(
list of Axes
"""
if rspace:
data = self.field_rspace
data = self.rmesh
else:
data = self.field_kspace
data = self.kmesh

if transpose:
data = data.T
Expand Down Expand Up @@ -1289,7 +1283,7 @@ def from_genesis4(

# field.param.gridsize = 2 * field.dgrid / (ngrid - 1)
wf = cls(
field_rspace=field.dfl * genesis_to_v_over_m,
rmesh=field.dfl * genesis_to_v_over_m,
wavelength=field.param.wavelength,
grid_spacing=(
field.param.gridsize,
Expand Down Expand Up @@ -1323,7 +1317,7 @@ def from_file(
# names = get_wavefront_names_from_file("something.h5")
# for name in names:
# wavefront = Wavefront.from_file("something.h5", identifier=name)
# wavefront.field_rspace # <-- single wavefront
# wavefront.rmesh # <-- single wavefront
#
# wavefront = Wavefront.from_file("something.h5", identifier=5)

Expand Down Expand Up @@ -1394,7 +1388,7 @@ def write_group(self, group: h5py.Group) -> None:
writers.write_component_data(
group,
name=self.polarization,
data=self.field_rspace,
data=self.rmesh,
unit=self.metadata.units,
attrs=self.metadata.mesh.attrs,
)
24 changes: 11 additions & 13 deletions tests/test_wavefront.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_smoke_drift_z_in_place(wavefront: Wavefront) -> None:
# Implicitly calculates the FFT:
wavefront.drift(direction="z", distance=0.0, inplace=True)
# Use the property to calculate the inverse fft:
wavefront.field_rspace
wavefront.rmesh


def test_smoke_drift_z(wavefront: Wavefront) -> None:
Expand Down Expand Up @@ -109,33 +109,31 @@ def test_padding_fix(padding: WavefrontPadding, expected: WavefrontPadding) -> N

def test_smoke_properties(wavefront: Wavefront) -> None:
assert len(wavefront.phasors) == 3
assert wavefront.field_rspace.shape == (11, 21, 21)
assert wavefront.field_kspace.shape == wavefront.pad.get_padded_shape(
wavefront.field_rspace
)
assert wavefront.rmesh.shape == (11, 21, 21)
assert wavefront.kmesh.shape == wavefront.pad.get_padded_shape(wavefront.rmesh)
assert np.isclose(wavefront.wavelength, 1.35e-8)
assert wavefront.pad.grid == (11, 21, 21)
assert wavefront.pad.pad == (44, 100, 100)


def test_copy(wavefront: Wavefront) -> None:
wavefront.field_rspace
wavefront.field_kspace
wavefront.rmesh
wavefront.kmesh
copied = copy.copy(wavefront)
assert copied == wavefront
assert copied is not wavefront
assert copied.field_rspace is wavefront.field_rspace
assert copied.field_kspace is wavefront.field_kspace
assert copied.rmesh is wavefront.rmesh
assert copied.kmesh is wavefront.kmesh


def test_deepcopy(wavefront: Wavefront) -> None:
wavefront.field_rspace
wavefront.field_kspace
wavefront.rmesh
wavefront.kmesh
copied = copy.deepcopy(wavefront)
assert copied == wavefront
assert copied is not wavefront
assert copied.field_rspace is not wavefront.field_rspace
assert copied.field_kspace is not wavefront.field_kspace
assert copied.rmesh is not wavefront.rmesh
assert copied.kmesh is not wavefront.kmesh


def test_plot_projection(wavefront: Wavefront, projection_plane: Plane) -> None:
Expand Down

0 comments on commit fdb36a1

Please sign in to comment.