From 682c4a29c4f169c9f0761fe97d36ae1e93209dbf Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 26 Mar 2024 16:24:40 -0400 Subject: [PATCH] enforce consistent xlims on time & event param plots --- specparam/objs/time.py | 7 +++++++ specparam/plts/event.py | 8 +++++--- specparam/plts/templates.py | 5 ++++- specparam/plts/time.py | 8 +++++--- 4 files changed, 21 insertions(+), 7 deletions(-) diff --git a/specparam/objs/time.py b/specparam/objs/time.py index 98becc19..dddb3583 100644 --- a/specparam/objs/time.py +++ b/specparam/objs/time.py @@ -96,6 +96,13 @@ def n_peaks_(self): if self.has_model else None + @property + def n_time_windows(self): + """How many time windows are included in the model object.""" + + return self.spectrogram.shape[1] if self.has_data else 0 + + def _reset_time_results(self): """Set, or reset, time results to be empty.""" diff --git a/specparam/plts/event.py b/specparam/plts/event.py index 3e37c8ca..3800add1 100644 --- a/specparam/plts/event.py +++ b/specparam/plts/event.py @@ -54,12 +54,14 @@ def plot_event_model(event_model, **plot_kwargs): figsize=plot_kwargs.pop('figsize', [10, 4 + 4 * n_bands])) axes = cycle(axes) + xlim = [0, time_model.n_time_windows] + # 01: aperiodic params alabels = ['offset', 'knee', 'exponent'] if has_knee else ['offset', 'exponent'] for alabel in alabels: plot_param_over_time_yshade(\ None, event_model.event_time_results[alabel], - label=alabel, drop_xticks=True, add_xlabel=False, + label=alabel, drop_xticks=True, add_xlabel=False, xlim=xlim, title='Aperiodic Parameters' if alabel == 'offset' else None, color=PARAM_COLORS[alabel], ax=next(axes)) next(axes).axis('off') @@ -69,7 +71,7 @@ def plot_event_model(event_model, **plot_kwargs): for plabel in ['cf', 'pw', 'bw']: plot_param_over_time_yshade(\ None, event_model.event_time_results[pe_labels[plabel][band_ind]], - label=plabel.upper(), drop_xticks=True, add_xlabel=False, + label=plabel.upper(), drop_xticks=True, add_xlabel=False, xlim=xlim, title='Periodic Parameters - ' + band_labels[band_ind] if plabel == 'cf' else None, color=PARAM_COLORS[plabel], ax=next(axes)) next(axes).axis('off') @@ -81,4 +83,4 @@ def plot_event_model(event_model, **plot_kwargs): drop_xticks=False if glabel == 'r_squared' else True, add_xlabel=True if glabel == 'r_squared' else False, title='Goodness of Fit' if glabel == 'error' else None, - color=PARAM_COLORS[glabel], ax=next(axes)) + color=PARAM_COLORS[glabel], xlim=xlim, ax=next(axes)) diff --git a/specparam/plts/templates.py b/specparam/plts/templates.py index 80024871..1c850280 100644 --- a/specparam/plts/templates.py +++ b/specparam/plts/templates.py @@ -190,7 +190,7 @@ def plot_yshade(x_vals, y_vals, average='mean', shade='std', scale=1., color=Non @check_dependency(plt, 'matplotlib') def plot_param_over_time(times, param, label=None, title=None, add_legend=True, add_xlabel=True, - drop_xticks=False, ax=None, **plot_kwargs): + xlim=None, drop_xticks=False, ax=None, **plot_kwargs): """Plot a parameter over time. Parameters @@ -228,6 +228,9 @@ def plot_param_over_time(times, param, label=None, title=None, add_legend=True, if drop_xticks: ax.set_xticks([], []) + if xlim: + ax.set_xlim(xlim) + if label and add_legend: ax.legend(loc='upper left', framealpha=plot_kwargs.pop('legend_framealpha', 0.9)) diff --git a/specparam/plts/time.py b/specparam/plts/time.py index d507f421..84523d45 100644 --- a/specparam/plts/time.py +++ b/specparam/plts/time.py @@ -52,6 +52,8 @@ def plot_time_model(time_model, **plot_kwargs): figsize=plot_kwargs.pop('figsize', [10, 4 + 2 * n_bands])) axes = cycle(axes) + xlim = [0, time_model.n_time_windows] + # 01: aperiodic parameters ap_params = [time_model.time_results['offset'], time_model.time_results['exponent']] @@ -63,7 +65,7 @@ def plot_time_model(time_model, **plot_kwargs): ap_labels.insert(1, 'Knee') ap_colors.insert(1, PARAM_COLORS['knee']) - plot_params_over_time(None, ap_params, labels=ap_labels, add_xlabel=False, + plot_params_over_time(None, ap_params, labels=ap_labels, add_xlabel=False, xlim=xlim, colors=ap_colors, title='Aperiodic Parameters', ax=next(axes)) # 02: periodic parameters @@ -73,7 +75,7 @@ def plot_time_model(time_model, **plot_kwargs): [time_model.time_results[pe_labels['cf'][band_ind]], time_model.time_results[pe_labels['pw'][band_ind]], time_model.time_results[pe_labels['bw'][band_ind]]], - labels=['CF', 'PW', 'BW'], add_xlabel=False, + labels=['CF', 'PW', 'BW'], add_xlabel=False, xlim=xlim, colors=[PARAM_COLORS['cf'], PARAM_COLORS['pw'], PARAM_COLORS['bw']], title='Periodic Parameters - ' + band_labels[band_ind], ax=next(axes)) @@ -81,6 +83,6 @@ def plot_time_model(time_model, **plot_kwargs): plot_params_over_time(None, [time_model.time_results['error'], time_model.time_results['r_squared']], - labels=['Error', 'R-squared'], + labels=['Error', 'R-squared'], xlim=xlim, colors=[PARAM_COLORS['error'], PARAM_COLORS['r_squared']], title='Goodness of Fit', ax=next(axes))