diff --git a/pmd_beamphysics/wavefront.py b/pmd_beamphysics/wavefront.py index f3055d6..d2222a4 100644 --- a/pmd_beamphysics/wavefront.py +++ b/pmd_beamphysics/wavefront.py @@ -3,7 +3,8 @@ import copy import logging import pathlib -from typing import Any, List, NamedTuple, Optional, Sequence, Tuple, Union +from typing import Any, NamedTuple, Union +from collections.abc import Sequence import h5py import matplotlib @@ -22,9 +23,9 @@ _fft_workers = -1 -Ranges = Sequence[Tuple[float, float]] +Ranges = Sequence[tuple[float, float]] AnyPath = Union[str, pathlib.Path] -Plane = Union[str, Tuple[int, int]] +Plane = Union[str, tuple[int, int]] Z0 = np.pi * 119.9169832 # V^2/W exactly HBAR_EV_M = scipy.constants.hbar / scipy.constants.e * scipy.constants.c # eV-m kspace_labels = { @@ -46,7 +47,7 @@ def set_num_fft_workers(workers: int): logger.info(f"Set number of FFT workers to: {workers}") -def get_axis_index(axis_labels: Sequence[str], axis: Union[str, int]): +def get_axis_index(axis_labels: Sequence[str], axis: str | int): if isinstance(axis, int): if axis >= len(axis_labels): raise ValueError(f"Axis out of bounds: {axis} ({len(axis_labels)}D array)") @@ -55,42 +56,42 @@ def get_axis_index(axis_labels: Sequence[str], axis: Union[str, int]): def get_axis_indices( - axis_labels: Tuple[str, ...], - axes: Union[Sequence[str], Sequence[int]], + axis_labels: tuple[str, ...], + axes: Sequence[str] | Sequence[int], ): return tuple(get_axis_index(axis_labels, axis) for axis in axes) def get_rspace_label( - axis_labels: Tuple[str, ...], - axis: Union[int, str], + axis_labels: tuple[str, ...], + axis: int | str, ): return axis_labels[get_axis_index(axis_labels, axis)] def get_rspace_labels( - axis_labels: Tuple[str, ...], - axes: Union[Sequence[str], Sequence[int]], + axis_labels: tuple[str, ...], + axes: Sequence[str] | Sequence[int], ): return tuple(get_rspace_label(axis_labels, axis) for axis in axes) def get_kspace_label( - axis_labels: Tuple[str, ...], - axis: Union[int, str], + axis_labels: tuple[str, ...], + axis: int | str, ): rspace_label = get_rspace_label(axis_labels, axis) return kspace_labels.get(rspace_label, rspace_label) def get_kspace_labels( - axis_labels: Tuple[str, ...], - axes: Union[Sequence[str], Sequence[int]], + axis_labels: tuple[str, ...], + axes: Sequence[str] | Sequence[int], ): return tuple(get_kspace_label(axis_labels, axis) for axis in axes) -def pad_array(wavefront: np.ndarray, pads: Union[Sequence[int], int]): +def pad_array(wavefront: np.ndarray, pads: Sequence[int] | int): """ Pad an array with complex zero elements. @@ -202,7 +203,7 @@ def nd_kspace_domains( pads: Sequence[int], steps: Sequence[float], shifted: bool = True, -) -> List[np.ndarray]: +) -> list[np.ndarray]: """ Generate reciprocal space domains for given grid sizes and steps. @@ -276,7 +277,7 @@ def nd_kspace_mesh( def nd_space_mesh( - ranges: Sequence[Tuple[float, float]], + ranges: Sequence[tuple[float, float]], sizes: Sequence[int], ): """ @@ -349,7 +350,7 @@ def is_good(dim: int) -> bool: def fix_padding( grid: Sequence[int], pad: Sequence[int], -) -> Tuple[int, ...]: +) -> tuple[int, ...]: """ Fix gridding and padding values for symmetry and FFT efficiency. @@ -398,7 +399,7 @@ def get_shifts( ranges: Ranges, pads: Sequence[int], deltas: Sequence[float], -) -> Tuple[float, ...]: +) -> tuple[float, ...]: """ Effective half sizes with padding in all dimensions. @@ -440,7 +441,7 @@ def calculate_k0(wavelength: float) -> float: return 2.0 * np.pi / wavelength -def conversion_coeffs(wavelength: float, dim: int) -> Tuple[float, ...]: +def conversion_coeffs(wavelength: float, dim: int) -> tuple[float, ...]: """ Conversion coefficients to (radians, radians, eV). @@ -483,7 +484,7 @@ def domains_omega_thx_thy( def drift_kernel_paraxial( - transverse_kspace_grid: List[np.ndarray], + transverse_kspace_grid: list[np.ndarray], z: float, wavelength: float, ): @@ -494,7 +495,7 @@ def drift_kernel_paraxial( def drift_propagator_paraxial( kmesh: np.ndarray, - transverse_kspace_grid: List[np.ndarray], + transverse_kspace_grid: list[np.ndarray], z: float, wavelength: float, ): @@ -620,12 +621,12 @@ def transverse_divergence_padding_factor( return 2.0 * (theta_max * drift_distance) / beam_size -PadShape = Tuple[Tuple[int, int], ...] +PadShape = tuple[tuple[int, int], ...] def get_padded_shape( rmesh_shape: Sequence[int], padding: Sequence[int] -) -> Tuple[int, ...]: +) -> tuple[int, ...]: """Get the padded shape given a 3D field rspace array.""" rmesh_shape = tuple(rmesh_shape) padding = tuple(padding) @@ -638,7 +639,7 @@ def get_padded_shape( return tuple(dim + 2 * pad for dim, pad in zip(rmesh_shape, padding)) -def get_range_for_grid_spacing(grid_spacing: float, dim: int) -> Tuple[float, float]: +def get_range_for_grid_spacing(grid_spacing: float, dim: int) -> tuple[float, float]: """ Given a grid spacing and array dimension, get the range the entire array represents. @@ -669,7 +670,7 @@ def get_range_for_grid_spacing(grid_spacing: float, dim: int) -> Tuple[float, fl def get_ranges_for_grid_spacing( grid_spacing: Sequence[float], dims: Sequence[int], -) -> Sequence[Tuple[float, float]]: +) -> Sequence[tuple[float, float]]: """ Given a grid spacing and array dimension, get the range the entire array represents. @@ -698,7 +699,7 @@ def calculate_phasors( rmesh_grid_spacing: tuple[float, ...], rmesh_grid: tuple[int, ...], pad: tuple[int, ...], -) -> Tuple[np.ndarray, ...]: +) -> tuple[np.ndarray, ...]: """Calculate phasors for each dimension of the cartesian domain.""" ranges = get_ranges_for_grid_spacing( grid_spacing=rmesh_grid_spacing, @@ -752,12 +753,12 @@ class Wavefront: """ # Saved into OpenPMD-compatible file: - _rmesh: Optional[np.ndarray] + _rmesh: np.ndarray | None wavelength: float _metadata: WavefrontMetadata # Internal state for Wavefront: - _kmesh: Optional[np.ndarray] + _kmesh: np.ndarray | None _grid: tuple[int, ...] # rmesh shape _padding: tuple[int, ...] # TODO: @@ -769,12 +770,12 @@ def __init__( rmesh: np.ndarray, *, wavelength: float, - grid_spacing: Optional[Sequence[float]] = None, - polarization: Optional[PolarizationDirection] = None, - axis_labels: Optional[Sequence[str]] = None, - metadata: Optional[Union[WavefrontMetadata, dict]] = None, - longitudinal_axis: Optional[str] = None, - pad: Optional[Union[Sequence[int], int]] = None, + grid_spacing: Sequence[float] | None = None, + polarization: PolarizationDirection | None = None, + axis_labels: Sequence[str] | None = None, + metadata: WavefrontMetadata | dict | None = None, + longitudinal_axis: str | None = None, + pad: Sequence[int] | int | None = None, fix_pad: bool = True, ) -> None: self._rmesh = rmesh @@ -803,9 +804,9 @@ def __init__( def _set_metadata( self, metadata: Any, - grid_spacing: Optional[Sequence[float]] = None, - polarization: Optional[PolarizationDirection] = None, - axis_labels: Optional[Sequence[str]] = None, + grid_spacing: Sequence[float] | None = None, + polarization: PolarizationDirection | None = None, + axis_labels: Sequence[str] | None = None, # units: Optional[pmd_unit] = None, ) -> None: if metadata is None: @@ -880,13 +881,13 @@ def __eq__(self, other: Any) -> bool: @classmethod def gaussian_pulse( cls, - dims: Tuple[int, int, int], + dims: tuple[int, int, int], wavelength: float, nphotons: float, zR: float, sigma_z: float, grid_spacing: Sequence[float], - pad: Optional[Sequence[int]] = None, + pad: Sequence[int] | None = None, dtype=np.complex64, ): """ @@ -942,7 +943,7 @@ def with_rmesh(self, rmesh: np.ndarray) -> Wavefront: def with_padding( self, - pad: Union[int, Sequence[int]], + pad: int | Sequence[int], fix: bool = True, ) -> Wavefront: ndim = len(self._grid) @@ -988,17 +989,17 @@ def with_padding_divergence( return self.with_padding(pad, fix=fix) @property - def grid(self) -> Tuple[int, ...]: + def grid(self) -> tuple[int, ...]: """The rmesh shape, without padding.""" return self._grid @property - def rmesh_shape(self) -> Tuple[int, ...]: + def rmesh_shape(self) -> tuple[int, ...]: """The rmesh shape, without padding.""" return self._grid @property - def pad(self) -> Tuple[int, ...]: + def pad(self) -> tuple[int, ...]: """ Per-dimension padding used in the FFT. @@ -1067,7 +1068,7 @@ def _nice_kspace_domain(self) -> _NiceXYZ: ) @property - def _k_center_indices(self) -> Tuple[int, ...]: + def _k_center_indices(self) -> tuple[int, ...]: return tuple(grid // 2 + pad for grid, pad in zip(self.rmesh.shape, self.pad)) @property @@ -1184,7 +1185,7 @@ def photon_energy(self) -> float: return h * freq @property - def grid_spacing(self) -> Tuple[float, ...]: + def grid_spacing(self) -> tuple[float, ...]: return self.metadata.mesh.grid_spacing @property @@ -1201,7 +1202,7 @@ def axis_labels(self): def focus( self, plane: Plane, - focus: Tuple[float, float], + focus: tuple[float, float], ) -> Wavefront: """ Apply thin lens focusing. @@ -1335,15 +1336,15 @@ def plot( show_power_density: bool = True, show_phase: bool = True, isophase_contour: bool = False, - axs: Optional[List[matplotlib.axes.Axes]] = None, + axs: list[matplotlib.axes.Axes] | None = None, cmap: str = "viridis", - figsize: Optional[Tuple[float, float]] = None, + figsize: tuple[float, float] | None = None, nrows: int = 1, ncols: int = 2, - xlim: Optional[Tuple[float, float]] = None, - ylim: Optional[Tuple[float, float]] = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, tight_layout: bool = True, - save: Optional[AnyPath] = None, + save: AnyPath | None = None, transpose: bool = False, colorbar: bool = True, # contour: bool = True, @@ -1511,12 +1512,12 @@ def plot(dat, title: str): def plot_1d_far_field_spectral_density( self, *, - ax: Optional[matplotlib.axes.Axes] = None, - figsize: Optional[Tuple[float, float]] = None, - xlim: Optional[Tuple[float, float]] = None, - ylim: Optional[Tuple[float, float]] = None, + ax: matplotlib.axes.Axes | None = None, + figsize: tuple[float, float] | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, tight_layout: bool = True, - save: Optional[AnyPath] = None, + save: AnyPath | None = None, ): if ax is not None: fig = ax.get_figure() @@ -1553,12 +1554,12 @@ def plot_1d_far_field_spectral_density( def plot_1d_kmesh_projections( self, *, - axs: Optional[List[matplotlib.axes.Axes]] = None, - figsize: Optional[Tuple[float, float]] = None, - xlim: Optional[Tuple[float, float]] = None, - ylim: Optional[Tuple[float, float]] = None, + axs: list[matplotlib.axes.Axes] | None = None, + figsize: tuple[float, float] | None = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, tight_layout: bool = True, - save: Optional[AnyPath] = None, + save: AnyPath | None = None, ): if axs: (ax1, ax2) = axs @@ -1603,8 +1604,8 @@ def _plot_reciprocal_thy_vs_thx( self, ax: matplotlib.axes.Axes, cmap: str = "viridis", - xlim: Optional[Tuple[float, float]] = None, - ylim: Optional[Tuple[float, float]] = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, ): kdomain = self._nice_kspace_domain @@ -1651,8 +1652,8 @@ def _plot_reciprocal_energy_vs_thetax( self, ax: matplotlib.axes.Axes, cmap: str = "viridis", - xlim: Optional[Tuple[float, float]] = None, - ylim: Optional[Tuple[float, float]] = None, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, ): kdomain = self._nice_kspace_domain extent = ( @@ -1699,15 +1700,15 @@ def _plot_reciprocal_energy_vs_thetax( def plot_reciprocal( self, *, - axs: Optional[List[matplotlib.axes.Axes]] = None, + axs: list[matplotlib.axes.Axes] | None = None, cmap: str = "viridis", - figsize: Optional[Tuple[float, float]] = None, - xlim_theta: Optional[Tuple[float, float]] = None, - ylim_theta: Optional[Tuple[float, float]] = None, - xlim_theta_w: Optional[Tuple[float, float]] = None, - ylim_theta_w: Optional[Tuple[float, float]] = None, + figsize: tuple[float, float] | None = None, + xlim_theta: tuple[float, float] | None = None, + ylim_theta: tuple[float, float] | None = None, + xlim_theta_w: tuple[float, float] | None = None, + ylim_theta_w: tuple[float, float] | None = None, tight_layout: bool = True, - save: Optional[AnyPath] = None, + save: AnyPath | None = None, colorbar: bool = True, ): """ @@ -1783,14 +1784,14 @@ def plot_reciprocal( def from_kmesh( cls, kmesh: np.ndarray, - padding: Optional[Sequence[int]], + padding: Sequence[int] | None, *, wavelength: float, - grid_spacing: Optional[Sequence[float]] = None, - polarization: Optional[PolarizationDirection] = None, - axis_labels: Optional[Sequence[str]] = None, - metadata: Optional[Union[WavefrontMetadata, dict]] = None, - longitudinal_axis: Optional[str] = None, + grid_spacing: Sequence[float] | None = None, + polarization: PolarizationDirection | None = None, + axis_labels: Sequence[str] | None = None, + metadata: WavefrontMetadata | dict | None = None, + longitudinal_axis: str | None = None, ) -> Wavefront: if padding is None: padding = (0,) * kmesh.ndim @@ -1823,8 +1824,8 @@ def from_kmesh( @classmethod def from_genesis4( cls, - h5: Union[h5py.File, pathlib.Path, str], - pad: Union[int, Tuple[int, int, int]] = 100, + h5: h5py.File | pathlib.Path | str, + pad: int | tuple[int, int, int] = 100, ) -> Wavefront: """ Load a Genesis4-format field file as a `Wavefront`. @@ -1882,7 +1883,7 @@ def _from_h5_file(cls, h5: h5py.File) -> Wavefront: @classmethod def from_file( cls, - h5: Union[h5py.File, pathlib.Path, str], + h5: h5py.File | pathlib.Path | str, identifier: int = 0, ) -> Wavefront: """Load a Wavefront from a file in the OpenPMD format.""" @@ -1939,7 +1940,7 @@ def _write_file(self, h5: h5py.File): efield_group = h5.create_group(electric_field_path) self.write_group(efield_group) - def write(self, h5: Union[h5py.File, pathlib.Path, str]) -> None: + def write(self, h5: h5py.File | pathlib.Path | str) -> None: """Write the Wavefront in OpenPMD format.""" if isinstance(h5, h5py.File): return self._write_file(h5)