Skip to content

Commit

Permalink
ENH: single plane plotting framework
Browse files Browse the repository at this point in the history
  • Loading branch information
ken-lauer committed Aug 15, 2024
1 parent e8c66ef commit 93658ca
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 1 deletion.
130 changes: 129 additions & 1 deletion pmd_beamphysics/wavefront.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@
import copy
import dataclasses
import logging
import pathlib
from typing import Any, List, Optional, Sequence, Tuple, Union

import matplotlib
import matplotlib.axes
import matplotlib.pyplot as plt
import numpy as np
import scipy.constants
import scipy.fft
Expand All @@ -16,6 +20,8 @@

_fft_workers = -1
Ranges = Sequence[Tuple[float, float]]
AnyPath = Union[str, pathlib.Path]
Plane = Union[str, Tuple[int, int]]


def get_num_fft_workers() -> int:
Expand Down Expand Up @@ -842,7 +848,7 @@ def ranges(self):

def focus(
self,
plane: Union[str, Tuple[int, int]],
plane: Plane,
focus: Tuple[float, float],
*,
inplace: bool = False,
Expand Down Expand Up @@ -927,3 +933,125 @@ def propagate(
# Invalidate the real space data
self._field_rspace = None
return self

def plot(
self,
plane: Plane,
*,
rspace: bool = True,
show_real: bool = True,
show_imaginary: bool = True,
show_abs: bool = True,
show_phase: bool = True,
axs: Optional[List[matplotlib.axes.Axes]] = None,
cmap: str = "viridis",
figsize: Optional[Tuple[int, int]] = None,
nrows: int = 2,
ncols: int = 2,
xlim: Optional[Tuple[int, int]] = None,
ylim: Optional[Tuple[int, int]] = None,
tight_layout: bool = True,
save: Optional[AnyPath] = None,
):
"""
Plot the projection onto the given plane.
Parameters
----------
plane : str or (int, int)
Plane to plot, e.g., "xy" or (1, 2).
rspace : bool, default=True
Plot the real/cartesian space data.
show_real : bool
Show the projection of the real portion of the data.
show_imaginary : bool
Show the projection of the imaginary portion of the data.
show_abs : bool
Show the projection of the absolute value of the data.
show_phase : bool
Show the projection of the phase of the data.
figsize : (int, int), optional
Figure size for the axes.
Defaults to Matplotlib's `rcParams["figure.figsize"]``.
axs : List[matplotlib.axes.Axes], optional
Plot the data in the provided matplotlib Axes.
Creates a new figure and Axes if not specified.
cmap : str, default="viridis"
Color map to use.
nrows : int, default=2
Number of rows for the plot.
ncols : int, default=2
Number of columns for the plot.
save : pathlib.Path or str, optional
Save the plot to the given filename.
xlim : (float, float), optional
X axis limits.
ylim : (float, float), optional
Y axis limits.
tight_layout : bool, default=True
Set a tight layout.
Returns
-------
Figure
list of Axes
"""
if rspace:
data = self.field_rspace
else:
data = self.field_kspace

sum_axis = {
# TODO: when standardized, this will be xyz instead of txy
"xy": 0,
(1, 2): 0,
}[plane]

if axs is None:
fig, gs = plt.subplots(
nrows=nrows,
ncols=ncols,
sharex=True,
sharey=True,
squeeze=False,
figsize=figsize,
)
axs = list(gs.flatten())
fig.suptitle(f"{plane}")
else:
fig = axs[0].get_figure()
assert fig is not None

remaining_axes = list(axs)

def plot(dat, title: str):
ax = remaining_axes.pop(0)
ax.imshow(np.sum(dat, axis=sum_axis), cmap=cmap)
if xlim is not None:
ax.set_xlim(xlim)
if ylim is not None:
ax.set_ylim(ylim)
if not ax.get_title():
ax.set_title(title)

if show_real:
plot(np.real(data), title="Real")

if show_imaginary:
plot(np.imag(data), title="Imaginary")

if show_abs:
plot(np.abs(data), f"|{plane}|")

if show_phase:
plot(np.angle(data), title="Phase")

if fig is not None:
if tight_layout:
fig.tight_layout()

if save:
logger.info(f"Saving plot to {save!r}")
fig.savefig(save)

return fig, axs
35 changes: 35 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import pathlib
import re

import matplotlib
import matplotlib.pyplot as plt
import pytest

matplotlib.use("Agg")

test_root = pathlib.Path(__file__).parent.resolve()
test_artifacts = test_root / "artifacts"


@pytest.fixture(autouse=True, scope="function")
def _plot_show_to_savefig(
request: pytest.FixtureRequest,
monkeypatch: pytest.MonkeyPatch,
):
index = 0

def savefig():
nonlocal index
test_artifacts.mkdir(exist_ok=True)
for fignum in plt.get_fignums():
plt.figure(fignum)
name = re.sub(r"[/\\]", "_", request.node.name)
filename = test_artifacts / f"{name}_{index}.png"
print(f"Saving figure (_plot_show_to_savefig fixture) to {filename}")
plt.savefig(filename)
index += 1
plt.close("all")

monkeypatch.setattr(plt, "show", savefig)
yield
plt.show()
16 changes: 16 additions & 0 deletions tests/test_wavefront.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import copy

import matplotlib.pyplot as plt
import numpy as np
import pytest

from pmd_beamphysics import Wavefront
from pmd_beamphysics.wavefront import (
Plane,
WavefrontPadding,
get_num_fft_workers,
set_num_fft_workers,
Expand All @@ -30,6 +32,13 @@ def wavefront() -> Wavefront:
)


@pytest.fixture(
params=["xy"],
)
def projection_plane(request: pytest.FixtureRequest) -> str:
return request.param


def test_smoke_propagate_z_in_place(wavefront: Wavefront) -> None:
# Implicitly calculates the FFT:
wavefront.propagate(direction="z", distance=0.0, inplace=True)
Expand Down Expand Up @@ -99,3 +108,10 @@ def test_deepcopy(wavefront: Wavefront) -> None:
assert copied is not wavefront
assert copied.field_rspace is not wavefront.field_rspace
assert copied.field_kspace is not wavefront.field_kspace


def test_plot_projection(wavefront: Wavefront, projection_plane: Plane) -> None:
wavefront.plot(projection_plane, rspace=True)
plt.suptitle(f"rspace - {projection_plane}")
wavefront.plot(projection_plane, rspace=False)
plt.suptitle(f"kspace - {projection_plane}")

0 comments on commit 93658ca

Please sign in to comment.