Skip to content

Commit

Permalink
Merge pull request #44 from ChristopherMayes/plot_limits
Browse files Browse the repository at this point in the history
Plot limits
  • Loading branch information
ChristopherMayes authored Aug 8, 2023
2 parents 0b589b7 + 8bd6eed commit 5759911
Show file tree
Hide file tree
Showing 5 changed files with 292 additions and 200 deletions.
40 changes: 20 additions & 20 deletions docs/examples/particle_examples.ipynb

Large diffs are not rendered by default.

168 changes: 27 additions & 141 deletions docs/examples/plot_examples.ipynb

Large diffs are not rendered by default.

80 changes: 75 additions & 5 deletions pmd_beamphysics/particles.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,17 +821,81 @@ def write(self, h5, name=None):

# Plotting
# --------
# TODO: more general plotting
def plot(self, key1='x', key2=None, bins=None, return_figure=False,
tex=True, **kwargs):
def plot(self, key1='x', key2=None,
bins=None,
*,
xlim=None,
ylim=None,
return_figure=False,
tex=True, nice=True,
**kwargs):
"""
1d or 2d density plot.
If one key is given, this will plot the density of that key.
Example:
.plot('x')
If two keys arg given, this will plot a 2d marginal plot.
Example:
.plot('x', 'px')
Parameters
----------
particle_group: ParticleGroup
The object to plot
key1: str, default = 't'
Key to bin on the x-axis
key2: str, default = None
Key to bin on the y-axis.
bins: int, default = None
Number of bins. If None, this will use a heuristic: bins = sqrt(n_particle/4)
xlim: tuple, default = None
Manual setting of the x-axis limits. Note that these are in raw, unscaled units.
ylim: tuple, default = None
Manual setting of the y-axis limits. Note that these are in raw, unscaled units.
tex: bool, default = True
Use TEX for labels
nice: bool, default = True
Scale to nice units
return_figure: bool, default = False
If true, return a matplotlib.figure.Figure object
**kwargs
Any additional kwargs to send to the the plot in: plt.subplots(**kwargs)
Returns
-------
None or fig: matplotlib.figure.Figure
This only returns a figure object if return_figure=T, otherwise returns None
"""

if not key2:
fig = density_plot(self, key=key1, bins=bins, tex=tex, **kwargs)
fig = density_plot(self, key=key1,
bins=bins,
xlim=xlim,
tex=tex,
nice=nice,
**kwargs)
else:
fig = marginal_plot(self, key1=key1, key2=key2, bins=bins, tex=tex, **kwargs)
fig = marginal_plot(self, key1=key1, key2=key2,
bins=bins,
xlim=xlim,
ylim=ylim,
tex=tex,
nice=nice,
**kwargs)

if return_figure:
return fig
Expand All @@ -840,7 +904,10 @@ def slice_plot(self, key='sigma_x',
n_slice=100,
slice_key=None,
tex=True,
nice=True,
return_figure=False,
xlim=None,
ylim=None,
**kwargs):
"""
Slice statistics plot.
Expand All @@ -857,6 +924,9 @@ def slice_plot(self, key='sigma_x',
n_slice=n_slice,
slice_key=slice_key,
tex=tex,
nice=nice,
xlim=xlim,
ylim=ylim,
**kwargs)

if return_figure:
Expand Down
133 changes: 113 additions & 20 deletions pmd_beamphysics/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

from pmd_beamphysics.units import nice_array, nice_scale_prefix
from pmd_beamphysics.units import nice_array, plottable_array, nice_scale_prefix
from pmd_beamphysics.labels import mathlabel


Expand Down Expand Up @@ -41,10 +41,39 @@ def slice_plot(particle_group,
stat_key='sigma_x',
n_slice=40,
slice_key='z',
xlim=None,
ylim=None,
tex=True,
nice=True,
**kwargs):
"""
Complete slice plotting routine. Will plot the density of the slice key on the right axis.
Parameters
----------
particle_group: ParticleGroup
The object to plot
stat_key: str, default = 'sigma_x'
Key to calculate the statistics
n_slice: int, default = 40
Number of slices
slice_key: str, default = 'z'
Should be 'z' or 't'
ylim: tuple, default = None
Manual setting of the y-axis limits.
tex: bool, defaul = True
Use TEX for labels
Returns
-------
fig: matplotlib.figure.Figure
"""

x_key = 'mean_'+slice_key
Expand All @@ -58,9 +87,10 @@ def slice_plot(particle_group,
fig, ax = plt.subplots(**kwargs)

# Get nice arrays
x, _, prex = nice_array(slice_dat[x_key])
y, _, prey = nice_array(slice_dat[y_key])
y2, _, prey2 = nice_array(slice_dat[y2_key])
x, f1, prex, xmin, xmax = plottable_array(slice_dat[x_key], nice=nice, lim=xlim)
y, f2, prey, ymin, ymax = plottable_array(slice_dat[y_key], nice=nice, lim=ylim)
# Density on r.h.s
y2, _, prey2, _, _ = plottable_array(slice_dat[y2_key], nice=nice, lim=None)

x_units = f'{prex}{particle_group.units(x_key)}'
y_units = f'{prey}{particle_group.units(y_key)}'
Expand All @@ -72,8 +102,6 @@ def slice_plot(particle_group,
y2_units = prey2+y2_units

# Labels


labelx = mathlabel(slice_key, units=x_units, tex=tex)
labely = mathlabel(y_key, units=y_units, tex=tex)
labely2 = mathlabel(y2_key, units=y2_units, tex=tex)
Expand All @@ -83,18 +111,30 @@ def slice_plot(particle_group,

# Main plot
ax.plot(x, y, color = 'black')

#ax.set_ylim(0, 1.1*ymax )

# rhs plot
ax2 = ax.twinx()
ax2.set_ylabel(labely2)
ax2.fill_between(x, 0, y2, color='black', alpha = 0.2)
ax2.set_ylim(0, None)

# Actual plot limits, considering scaling
if xlim:
ax.set_xlim( xmin/f1, xmax/f1)
if ylim:
ax.set_ylim( ymin/f2, ymax/f2)

return fig



def density_plot(particle_group, key='x', bins=None, tex=True, **kwargs):
def density_plot(particle_group, key='x',
bins=None,
*,
xlim=None,
tex=True,
nice=True,
**kwargs):
"""
1D density plot. Also see: marginal_plot
Expand All @@ -109,7 +149,7 @@ def density_plot(particle_group, key='x', bins=None, tex=True, **kwargs):
bins = int(n/100)

# Scale to nice units and get the factor, unit prefix
x, f1, p1 = nice_array(particle_group[key])
x, f1, p1, xmin, xmax = plottable_array(particle_group[key], nice=nice, lim=xlim)
w = particle_group['weight']
u1 = particle_group.units(key).unitSymbol
ux = p1+u1
Expand All @@ -134,26 +174,72 @@ def density_plot(particle_group, key='x', bins=None, tex=True, **kwargs):

ax.set_xlabel(labelx)

# Limits
if xlim:
ax.set_xlim(xmin/f1, xmax/f1)

return fig

def marginal_plot(particle_group, key1='t', key2='p', bins=None, tex=True, **kwargs):
def marginal_plot(particle_group, key1='t', key2='p',
bins=None,
*,
xlim=None,
ylim=None,
tex=True,
nice=True,
**kwargs):
"""
Density plot and projections
Example:
marginal_plot(P, 't', 'energy', bins=200)
Parameters
----------
particle_group: ParticleGroup
The object to plot
"""
key1: str, default = 't'
Key to bin on the x-axis
key2: str, default = 'p'
Key to bin on the y-axis
bins: int, default = None
Number of bins. If None, this will use a heuristic: bins = sqrt(n_particle/4)
xlim: tuple, default = None
Manual setting of the x-axis limits.
ylim: tuple, default = None
Manual setting of the y-axis limits.
tex: bool, default = True
Use TEX for labels
nice: bool, default = True
Returns
-------
fig: matplotlib.figure.Figure
"""
if not bins:
n = len(particle_group)
bins = int(np.sqrt(n/4) )

# Scale to nice units and get the factor, unit prefix
x, f1, p1 = nice_array(particle_group[key1])
y, f2, p2 = nice_array(particle_group[key2])

x = particle_group[key1]
y = particle_group[key2]

# Form nice arrays
x, f1, p1, xmin, xmax = plottable_array(x, nice=nice, lim=xlim)
y, f2, p2, ymin, ymax = plottable_array(y, nice=nice, lim=ylim)

w = particle_group['weight']

u1 = particle_group.units(key1).unitSymbol
Expand Down Expand Up @@ -183,8 +269,6 @@ def marginal_plot(particle_group, key1='t', key2='p', bins=None, tex=True, **kwa
#extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
#ax_joint.imshow(H.T, cmap=cmap, vmin=1e-16, origin='lower', extent=extent, aspect='auto')



# Top histogram
# Old method:
#dx = x.ptp()/bins
Expand Down Expand Up @@ -220,8 +304,17 @@ def marginal_plot(particle_group, key1='t', key2='p', bins=None, tex=True, **kwa
# Set labels on joint
ax_joint.set_xlabel(labelx)
ax_joint.set_ylabel(labely)

# Actual plot limits, considering scaling
if xlim:
ax_joint.set_xlim( xmin/f1, xmax/f1)
ax_marg_x.set_xlim(xmin/f1, xmax/f1)

if ylim:
ax_joint.set_ylim( ymin/f2, ymax/f2)
ax_marg_y.set_ylim(ymin/f2, ymax/f2)

return fig
return fig


def density_and_slice_plot(particle_group, key1='t', key2='p', stat_keys=['norm_emit_x', 'norm_emit_y'], bins=100, n_slice=30, tex=True):
Expand All @@ -235,8 +328,8 @@ def density_and_slice_plot(particle_group, key1='t', key2='p', stat_keys=['norm_
"""

# Scale to nice units and get the factor, unit prefix
x, f1, p1 = nice_array(particle_group[key1])
y, f2, p2 = nice_array(particle_group[key2])
x, f1, p1, xmin, xmax = plottable_array(particle_group[key1])
y, f2, p2, ymin, ymax = plottable_array(particle_group[key2])
w = particle_group['weight']

u1 = particle_group.units(key1).unitSymbol
Expand Down
Loading

0 comments on commit 5759911

Please sign in to comment.