Skip to content

Commit

Permalink
WIP: longitudinal drift propagation direction and more
Browse files Browse the repository at this point in the history
  • Loading branch information
ken-lauer committed Sep 11, 2024
1 parent c546e6f commit 9a0ef40
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 372 deletions.
350 changes: 48 additions & 302 deletions docs/examples/wavefront.ipynb

Large diffs are not rendered by default.

138 changes: 68 additions & 70 deletions pmd_beamphysics/wavefront.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from .metadata import PolarizationDirection, WavefrontMetadata
from . import writers
from .units import pmd_unit, known_unit
from .units import known_unit

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -404,52 +404,28 @@ def domains_omega_thx_thy(
)


def domains_kxky(
grids: Sequence[int],
pads: Sequence[int],
deltas: Sequence[float],
):
"""
Transverse Fourier space domain x and y.
In units of the scipy FFT.
Parameters
----------
grids : tuple of ints
Data gridding.
pads : tuple of ints
Number of padding points for each axis
deltas : tuple of floats
Grid delta steps in cartesian space.
"""
assert len(grids) == len(pads) == len(deltas)
return nd_kspace_mesh(
coeffs=(1.0,) * (len(grids) - 1),
sizes=grids[1:],
pads=pads[1:],
steps=deltas[1:],
)


def drift_kernel_z(
domains_kxky: List[np.ndarray],
def drift_kernel_paraxial(
transverse_kspace_grid: List[np.ndarray],
z: float,
wavelength: float,
):
"""Drift transfer function in Z [m] paraxial approximation."""
kx, ky = domains_kxky
kx, ky = transverse_kspace_grid
return np.exp(-1j * z * np.pi * wavelength * (kx**2 + ky**2))


def drift_propagator_z(
def drift_propagator_paraxial(
kmesh: np.ndarray,
domains_kxky: List[np.ndarray],
transverse_kspace_grid: List[np.ndarray],
z: float,
wavelength: float,
):
"""Fresnel propagator in paraxial approximation to distance z [m]."""
return kmesh * drift_kernel_z(domains_kxky=domains_kxky, z=z, wavelength=wavelength)
return kmesh * drift_kernel_paraxial(
transverse_kspace_grid=transverse_kspace_grid,
z=z,
wavelength=wavelength,
)


def thin_lens_kernel_xy(
Expand Down Expand Up @@ -540,7 +516,7 @@ def create_gaussian_pulse_3d_with_q(
return pulse.astype(dtype)


def max_divergence_padding_factor(
def transverse_divergence_padding_factor(
theta_max: float,
drift_distance: float,
beam_size: float,
Expand All @@ -551,7 +527,7 @@ def max_divergence_padding_factor(
Parameters
----------
theta_max : float
Maximum divergence [rad]
Maximum transverse divergence [rad]
drift_distance : float
Drift propagation distance [m]
beam_size : float
Expand All @@ -562,8 +538,7 @@ def max_divergence_padding_factor(
float
Factor to increase the initial number of grid points, per dimension.
"""
# TODO: balticfish
return (theta_max * drift_distance) / beam_size
return 2.0 * (theta_max * drift_distance) / beam_size


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -687,22 +662,21 @@ class Wavefront:
Parameters
----------
rmesh : np.ndarray
Cartesian space field data.
Cartesian space field data. [V/m]
wavelength : float
Wavelength (lambda0) [m].
Wavelength. [m]
grid_spacing : sequence of float
Grid spacing for the corresponding dimensions.
Grid spacing for the corresponding dimensions. [m]
polarization : {"x", "y", "z"}, default="x"
Direction of polarization. The default assumes a planar undulator
with electric field polarization in the X direction.
Circular or generalized polarization is not currently supported.
metadata : Metadata, optional
OpenPMD-specific metadata.
pad : int or tuple of int, optional
Padding for each of the dimensions. Defaults to 40 for the time
dimension and 100 for the remaining dimensions.
unit : pmd_unit, default=V/m
Units of `rmesh`.
pad_theta_max : float, default=5e-5
pad_drift_distance : float, default=1.0
pad_beam_size : float, default=1e-4
See Also
--------
Expand All @@ -729,23 +703,39 @@ def __init__(
polarization: Optional[PolarizationDirection] = None,
axis_labels: Optional[Sequence[str]] = None,
metadata: Optional[Union[WavefrontMetadata, dict]] = None,
units: Optional[pmd_unit] = None,
pad_theta_max: float = 5e-5,
pad_drift_distance: float = 1.0,
pad_beam_size: float = 1e-4,
longitudinal_axis: Optional[str] = None,
) -> None:
self._phasors = None
self._rmesh = rmesh
self._rmesh_shape = rmesh.shape
self._kmesh = None
self._wavelength = wavelength

if pad is None:
pad_factor = transverse_divergence_padding_factor(
theta_max=pad_theta_max,
drift_distance=pad_drift_distance,
beam_size=pad_beam_size,
)
# TODO: current factor is just for transverse, we need to calculate
# for longitudinal as well
if pad_factor < 1:
pad = (0, 0, 0)
else:
pad = tuple((dim * pad_factor) - dim for dim in rmesh.shape)

self._set_metadata(
metadata,
pad=pad,
polarization=polarization,
axis_labels=axis_labels,
grid_spacing=grid_spacing,
units=units,
)
self._check_metadata()
self._longitudinal_axis = longitudinal_axis or self.axis_labels[-1]

def _set_metadata(
self,
Expand All @@ -754,7 +744,7 @@ def _set_metadata(
pad: Optional[Union[int, Sequence[int]]] = None,
polarization: Optional[PolarizationDirection] = None,
axis_labels: Optional[Sequence[str]] = None,
units: Optional[pmd_unit] = None,
# units: Optional[pmd_unit] = None,
) -> None:
if metadata is None:
metadata = WavefrontMetadata()
Expand All @@ -773,7 +763,6 @@ def _set_metadata(
if isinstance(pad, int):
pad = (pad,) * ndim
if pad is None:
# TODO: investigate padding values
if md.pads:
pad = md.pads
else:
Expand All @@ -787,8 +776,8 @@ def _set_metadata(
md.mesh.axis_labels = tuple(axis_labels)
if grid_spacing is not None:
md.mesh.grid_spacing = tuple(grid_spacing)
if units is not None:
md.units = units
# if units is not None:
md.units = known_unit["V/m"]

self._metadata = md

Expand All @@ -813,6 +802,8 @@ def __copy__(self) -> Wavefront:
res._wavelength = self._wavelength
res._pad = self._pad
res._metadata = copy.deepcopy(self._metadata)
# TODO there are more fields now
res._longitudinal_axis = self._longitudinal_axis
return res

def __deepcopy__(self, memo) -> Wavefront:
Expand Down Expand Up @@ -883,6 +874,8 @@ def gaussian_pulse(
wavelength=wavelength,
grid_spacing=grid_spacing,
pad=pad,
axis_labels="zxy",
longitudinal_axis="z",
)

@property
Expand Down Expand Up @@ -1002,8 +995,7 @@ def wavelength(self) -> float:
@property
def photon_energy(self) -> float:
"""Photon energy [eV]."""
# TODO check
h = scipy.constants.value("Planck constant in eV/Hz")
h = scipy.constants.value("Planck constant in eV/Hz") / (2 * np.pi)
freq = scipy.constants.speed_of_light / self.wavelength
return h * freq

Expand All @@ -1023,6 +1015,10 @@ def ranges(self):
dims=self.rmesh.shape,
)

@property
def axis_labels(self):
return self.metadata.mesh.axis_labels

def focus(
self,
plane: Plane,
Expand Down Expand Up @@ -1067,19 +1063,16 @@ def focus(

def drift(
self,
direction: Union[str, int],
distance: float,
*,
inplace: bool = False,
) -> Wavefront:
"""
Drift this Wavefront along `direction` in meters.
Drift this Wavefront along the longitudinal direction in meters.
Parameters
----------
direction : str or (int, int)
Propagation direction dimension name (e.g., "z") or dimension index (e.g., `2`)
z_prop : float
distance : float
Distance in meters.
inplace : bool, default=False
Perform the operation in-place on this wavefront object.
Expand All @@ -1090,20 +1083,25 @@ def drift(
This object if `inplace=True` or a new copy if `inplace=False`.
"""

if direction not in {"z", 2}:
raise NotImplementedError(f"Unsupported propagation direction: {direction}")

if not inplace:
wavefront = copy.copy(self)
return wavefront.drift(direction, distance, inplace=True)
return wavefront.drift(distance, inplace=True)

self._kmesh = drift_propagator_z(
indices = [
idx
for idx, label in enumerate(self.axis_labels)
if label != self._longitudinal_axis
]
transverse_kspace_grid = nd_kspace_mesh(
coeffs=(1.0,) * len(self._rmesh_shape),
sizes=[self._pad.grid[idx] for idx in indices],
pads=[self._pad.pad[idx] for idx in indices],
steps=[self.grid_spacing[idx] for idx in indices],
)

self._kmesh = drift_propagator_paraxial(
kmesh=self.kmesh,
domains_kxky=domains_kxky(
grids=self._pad.grid,
pads=self._pad.pad,
deltas=self.grid_spacing,
),
transverse_kspace_grid=transverse_kspace_grid,
wavelength=self._wavelength,
z=float(distance),
)
Expand Down Expand Up @@ -1293,7 +1291,7 @@ def from_genesis4(
pad=pad,
polarization="x",
axis_labels="xyz",
units=known_unit["V/m"],
longitudinal_axis="z",
)
wf.metadata.mesh.grid_global_offset = (0.0, 0.0, field.param.refposition)
return wf
Expand Down

0 comments on commit 9a0ef40

Please sign in to comment.