Skip to content

Commit

Permalink
WIP: assuredly incorrect genesis4 import
Browse files Browse the repository at this point in the history
  • Loading branch information
ken-lauer committed Aug 20, 2024
1 parent b8f822b commit a879cbe
Showing 1 changed file with 67 additions and 49 deletions.
116 changes: 67 additions & 49 deletions pmd_beamphysics/wavefront.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pathlib
from typing import Any, List, Optional, Sequence, Tuple, Union

import h5py
import matplotlib
import matplotlib.axes
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -232,20 +233,7 @@ def is_odd(value: int) -> bool:
return value % 2 == 1


def _fix_fft_dimension(dim: int):
"""Get a dimension that's efficient for the FFT - and also odd for symmetry."""
while True:
next_dim = scipy.fft.next_fast_len(dim, real=False)
if next_dim is None:
raise ValueError(f"Unable to get the next valid dimension for: {dim}")
dim = next_dim
if is_odd(dim):
break
dim += 1
return dim


def _fix_grid_padding(grid: int, pad: int) -> Tuple[int, int]:
def _fix_grid_padding(grid: int, pad: int) -> int:
"""
Fix gridding and padding values for symmetry and FFT efficiency.
Expand All @@ -260,29 +248,29 @@ def _fix_grid_padding(grid: int, pad: int) -> Tuple[int, int]:
Returns
-------
int
Adjusted data gridding.
int
Adjusted data padding.
"""
# Grid must be odd for us:
if not is_odd(grid):
grid += 1

# Ensure that our FFT dimension is odd and optimal for scipy's FFT:
dim = _fix_fft_dimension(grid + 2 * pad)
assert is_odd(dim), "FFT dimension not odd?"
def is_good(dim: int) -> bool:
return is_odd(dim) and scipy.fft.next_fast_len(dim, real=False) == dim

while not is_good(grid + pad):
dim = scipy.fft.next_fast_len(grid + pad + 1, real=False)
if dim is None:
raise ValueError(
f"Unable to get the next valid FFT length for: {grid=} {pad=}"
)

# Fix padding based on our optimal dimension:
pad = (dim - grid) // 2
assert not is_odd(dim - grid), "End dimension not even as expected?"
return grid, pad
pad = dim - grid

return pad


def fix_padding(
grid: Sequence[int],
pad: Sequence[int],
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
) -> Tuple[int, ...]:
"""
Fix gridding and padding values for symmetry and FFT efficiency.
Expand All @@ -297,8 +285,6 @@ def fix_padding(
Returns
-------
tuple of ints
Adjusted data gridding.
tuple of ints
Adjusted data padding.
"""
Expand All @@ -308,21 +294,20 @@ def fix_padding(
f"Got: {len(grid)} and {len(pad)}"
)

result = [[], []]
final_padding = []
for dim, (dim_grid, dim_pad) in enumerate(zip(grid, pad)):
new_grid, new_pad = _fix_grid_padding(dim_grid, dim_pad)
logger.debug(
"Grid[%d] %d -> %d pad %d -> %d",
dim,
dim_grid,
new_grid,
dim_pad,
new_pad,
)
result[0].append(new_grid)
result[1].append(new_pad)
new_pad = _fix_grid_padding(dim_grid, dim_pad)
if new_pad != dim_pad:
logger.debug(
"Grid[dim=%d] grid=%d pad=%d -> adjusted padding=%d",
dim,
dim_grid,
dim_pad,
new_pad,
)
final_padding.append(new_pad)

return tuple(result[0]), tuple(result[1])
return tuple(final_padding)


def get_shifts(
Expand Down Expand Up @@ -584,19 +569,18 @@ def fix(self) -> WavefrontPadding:
Such that the total array size will be: `grid + 2 * pad`.
"""
grid, pad = fix_padding(self.grid, self.pad)
return WavefrontPadding(grid, pad)
return WavefrontPadding(self.grid, fix_padding(self.grid, self.pad))

def get_padded_shape(self, field_rspace: np.ndarray) -> Tuple[int, ...]:
"""Get the padded shape given a 3D field rspace array."""
nd = len(self.grid)
if field_rspace.ndim != nd:
raise ValueError(f"`field_rspace` is not an {nd}D array")

if not all(is_odd(dim) for dim in field_rspace.shape):
raise ValueError(
f"`field_rspace` dimensions are not all odd numbers: {field_rspace.shape}"
)
# if not all(is_odd(dim) for dim in field_rspace.shape):
# raise ValueError(
# f"`field_rspace` dimensions are not all odd numbers: {field_rspace.shape}"
# )

return tuple(dim + 2 * pad for dim, pad in zip(field_rspace.shape, self.pad))

Expand Down Expand Up @@ -636,6 +620,12 @@ def __init__(
) -> None:
if not pad:
pad = (40,) + (100,) * (field_rspace.ndim - 1)

if len(ranges) != field_rspace.ndim:
raise ValueError(
"'ranges' must have the same number of dimensions as `field_rspace`; "
"each should describe the cartesian range of the corresponding axis."
)
self._phasors = None
self._field_rspace = field_rspace
self._field_rspace_shape = field_rspace.shape
Expand Down Expand Up @@ -1003,7 +993,13 @@ def plot(

sum_axis = {
# TODO: when standardized, this will be xyz instead of txy
"xy": 0,
# "xy": 0,
# (1, 2): 0,
"xy": 2,
(0, 1): 2,
"xz": 1,
(0, 2): 1,
"yz": 0,
(1, 2): 0,
}[plane]

Expand Down Expand Up @@ -1055,3 +1051,25 @@ def plot(dat, title: str):
fig.savefig(save)

return fig, axs

@classmethod
def from_genesis4(
cls, h5: Union[h5py.File, pathlib.Path, str], pad: int = 100
) -> Wavefront:
from genesis.version4 import FieldFile

field = FieldFile.from_file(h5)

_nx, _ny, nz = field.dfl.shape
z_low = field.param.refposition
z_high = z_low + field.param.slicespacing * nz # TODO: off by one?
return cls(
field_rspace=field.dfl,
wavelength=field.param.wavelength,
ranges=[
(-field.param.gridsize, field.param.gridsize),
(-field.param.gridsize, field.param.gridsize),
(z_low, z_high),
],
pad=(pad, pad, pad),
)

0 comments on commit a879cbe

Please sign in to comment.