Skip to content

Commit

Permalink
first pass at reasonable plotting for polarized wfs
Browse files Browse the repository at this point in the history
  • Loading branch information
kvangorkom committed Feb 12, 2025
1 parent 7decee4 commit aacbb77
Show file tree
Hide file tree
Showing 4 changed files with 292 additions and 88 deletions.
257 changes: 185 additions & 72 deletions notebooks/Polarization_Demo.ipynb

Large diffs are not rendered by default.

33 changes: 27 additions & 6 deletions poppy/polarized_wavefront.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,34 @@ def stokes_parameters(self):
if self.input_stokes_vector is None:
raise ValueError('Stokes parameters cannot be computed unless input_stokes_vector is supplied!')
return jones_to_stokes(self.wavefront, self.input_stokes_vector)

def display_stokes():
"""TO DO: only one of display_stokes and display_vector is valid, depending on self.pol_type"""
raise NotImplementedError()

def display_vector():
raise NotImplementedError()
def display_tensor(self, *args, **kwargs):
""" Display the vector or tensor field """

if self.pol_type == 'vector':
nrows = 2
indices = [0,1]
else: # tensor
nrows = 4
indices = [(0,0), (0,1), (1,0), (1,1)]

axes = []
for n, idx in enumerate(indices):
ax = super(BasePolarizedWavefront, self).display(
*args,
nrows=nrows,
row=n+1,
tensor_idx=idx,
**kwargs)
title = ax.title
title_text = title.get_text()
title.set_text(str(idx))
axes.append(ax)
fig = ax.get_figure()
fig.suptitle(title_text)
return axes



class PolarizedWavefront(BasePolarizedWavefront, Wavefront):
'''
Expand Down
86 changes: 78 additions & 8 deletions poppy/poppy_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,19 +344,20 @@ def display(self, what='intensity', nrows=1, row=1, showpadding=False,
imagecrop=None, pupilcrop=None,
colorbar=False, crosshairs=False, ax=None, title=None, vmin=None,
vmax=None, vmax_wfe=None, scale=None, use_angular_coordinates=None,
angular_coordinate_unit=u.arcsec):
angular_coordinate_unit=u.arcsec, tensor_idx=None):
"""Display wavefront on screen
Parameters
----------
what : string
What to display. Must be one of {intensity, phase, wfe, best, 'both'}.
What to display. Must be one of {'intensity', 'phase', 'wfe', 'best', 'both', or 'stokes'}.
'intensity' shows the wavefront intensity, 'wfe' shows the wavefront
error in meters or microns, 'phase' is similar to 'wfe' but shows wavefront
phase in radians at the given wavelength.
'Best' implies to display the phase if there is nonzero OPD,
or else display the intensity for a perfect pupil.
'both' will show two panels, for the wavefront intensity and wavefront error.
'stokes' will show four panels, for each of the Stokes parameters.
nrows : int
Number of rows to display in current figure (used for
showing steps in a calculation)
Expand Down Expand Up @@ -408,6 +409,14 @@ def display(self, what='intensity', nrows=1, row=1, showpadding=False,
(Default: None, infer coordinates from planetype)
angular_coordinate_unit : astropy unit
Unit to use for angular coordinates display; default is arcsecond.
tensor_idx : int or tuple, optional
Index of the tensor element to display with Polarized Wavefronts.
If not provided and plotting anything other than intensity or
Stokes parameters, defaults to 0 or (0,0) in the case of vector
and tensor fields, respectively. If plotting intensity and tensor_idx=None,
will plot the total vector intensity or I Stokes parameter; otherwise,
plots the intensity of the specified vector/tensor element.
Ignored if what='stokes'.
Returns
-------
Expand All @@ -419,15 +428,44 @@ def display(self, what='intensity', nrows=1, row=1, showpadding=False,

if row is None:
row = self.current_plane_index

intens = self.intensity.copy()

# handle polarized wavefronts
from poppy.polarized_wavefront import BasePolarizedWavefront
is_polarized = isinstance(self, BasePolarizedWavefront)
has_stokes = is_polarized and (self.input_stokes_vector is not None)
if is_polarized:
# only intensity is well-defined when tensor_idx is not provided.
# for all other cases, fall back to 0 or (0,0) field element.
if (what in ['phase', 'wfe', 'both', 'best']) and (tensor_idx is None):
if has_stokes:
tensor_idx = (0,0)
else:
tensor_idx = 0
_log.warning(f'tensor_idx not provided! Plotting {tensor_idx} element of vector/tensor field.')

# if plotting intensity and tensor_idx is not supplied, then
# plot the I Stokes parameter or total vector intensity
if (what == 'intensity') and (tensor_idx is None):
intens = self.intensity.copy()
else: # not intensity, or tensor_idx is supplied
intens = xp.abs(self.wavefront)[tensor_idx]**2
amp = self.amplitude[tensor_idx].copy()
phase = self.phase[tensor_idx].copy()

# grab the stokes vector if available
if has_stokes:
stokes = self.stokes_parameters.copy()
else:
# non-polarized case
intens = self.intensity.copy()
phase = self.phase.copy()
amp = self.amplitude.copy()

# make a version of the phase where we try to mask out
# areas with particularly low intensity
phase = self.phase.copy()
mean_intens = np.mean(intens[intens != 0])
phase[intens < mean_intens / 100] = np.nan
amp = self.amplitude
if what in ['phase', 'wfe', 'both', 'best']: # only compute this if it might be used
phase[intens < mean_intens / 100] = np.nan

y, x = self.coordinates()
# GPU arrays don't work in matplotlib
Expand All @@ -438,6 +476,8 @@ def display(self, what='intensity', nrows=1, row=1, showpadding=False,
intens = utils.remove_padding(intens, self.oversample)
phase = utils.remove_padding(phase, self.oversample)
amp = utils.remove_padding(amp, self.oversample)
if has_stokes:
stokes = utils.remove_padding(stokes, self.oversample)
y = utils.remove_padding(y, self.oversample)
x = utils.remove_padding(x, self.oversample)

Expand Down Expand Up @@ -475,6 +515,9 @@ def display(self, what='intensity', nrows=1, row=1, showpadding=False,
# what = 'intensity' # show intensity for coronagraphic downstream propagation.
else:
what = 'phase' # for aberrated pupils
# for partially polarized wavefronts, best always shows the stokes parameters
if has_stokes:
what = 'stokes'

# compute plot parameters for the subplot grid
nc = int(np.ceil(np.sqrt(nrows)))
Expand Down Expand Up @@ -509,6 +552,11 @@ def display(self, what='intensity', nrows=1, row=1, showpadding=False,
vmx = np.clip(max(vmax, np.abs(vmin)), -np.pi, np.pi)
norm_phase = matplotlib.colors.Normalize(vmin=-vmx, vmax=vmx)

# norm and colormap for stokes
if has_stokes:
norm_stokes = matplotlib.colors.Normalize(vmin=stokes.min(), vmax=stokes.max())
cmap_stokes = copy.copy(getattr(matplotlib.cm, conf.cmap_diverging))

def wrap_lines_title(title):
# Helper fn to add line breaks in plot titles,
# tweaked to put in particular places for aesthetics
Expand Down Expand Up @@ -633,9 +681,31 @@ def wrap_lines_title(title):
plt.colorbar(ax.images[0], ax=ax, orientation='vertical', shrink=0.8)
plot_axes = [ax]
to_return = ax
elif what == 'stokes':
nstokes = 4
stokes_names = ['I','Q','U','V']

ax = plt.subplot(nrows, 1, row)
if title is None:
title = wrap_lines_title("Stokes " + self.location)
ax.set_title(title)
ax.set_frame_on(False)
ax.axis('off')

plot_axes = to_return = []
for n in range(nstokes):
ax = plt.subplot(nrows, 4, 4*(row - 1) + n + 1)
plt.imshow(stokes[n], extent=extent, cmap=cmap_stokes, norm=norm_stokes, origin='lower')
ax.set_title(stokes_names[n])

ax.set_ylabel(unit_label)
ax.set_xlabel(unit_label)
if colorbar:
plt.colorbar(orientation='vertical', ax=ax, shrink=0.8)
plot_axes.append(ax)
else:
raise ValueError("Invalid value for what to display; must be: "
"'intensity', 'amplitude', 'phase', or 'both'.")
"'intensity', 'amplitude', 'phase', 'stokes', or 'both'.")

# now apply axes cropping and/or overplots, if requested.
for ax in plot_axes:
Expand Down
4 changes: 2 additions & 2 deletions poppy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1250,12 +1250,12 @@ def pad_or_crop_to_shape(array, target_shape):

def remove_padding(array, oversample):
""" Remove zeros around the edge of an array, assuming some integer oversampling padding factor """
npix = array.shape[0] / oversample
npix = array.shape[-1] / oversample
n0 = float(npix) * (oversample - 1) / 2
n1 = n0 + npix
n0 = int(round(n0))
n1 = int(round(n1))
return array[n0:n1, n0:n1].copy()
return array[..., n0:n1, n0:n1].copy()


# Back compatibility alias:
Expand Down

0 comments on commit aacbb77

Please sign in to comment.