diff --git a/specparam/plts/utils.py b/specparam/plts/utils.py index f9156282..70fd1372 100644 --- a/specparam/plts/utils.py +++ b/specparam/plts/utils.py @@ -65,7 +65,8 @@ def set_alpha(n_points): return alpha -def add_shades(ax, shades, colors='r', add_center=False, logged=False): +def add_shades(ax, shades, colors='r', alpha=0.2, + add_center=False, center_alpha=0.6, logged=False): """Add shaded regions to a plot. Parameters @@ -76,8 +77,13 @@ def add_shades(ax, shades, colors='r', add_center=False, logged=False): Shaded region(s) to add to plot, defined as [lower_bound, upper_bound]. colors : str or list of string Color(s) to plot shades. + alpha : float or list of float, optional, default: 0.2 + The alpha level to add the shade regions with. + If a list, can specify a separate alpha level per shade. add_center : boolean, default: False Whether to add a line at the center point of the shaded regions. + center_alpha : float, optional, default: 0.6 + The alpha level for the center line, if added. logged : boolean, default: False Whether the shade values should be logged before applying to plot axes. """ @@ -87,16 +93,17 @@ def add_shades(ax, shades, colors='r', add_center=False, logged=False): shades = [shades] colors = repeat(colors) if not isinstance(colors, list) else colors + alphas = repeat(alpha) if not isinstance(alpha, list) else alpha - for shade, color in zip(shades, colors): + for shade, color, alpha in zip(shades, colors, alphas): shade = np.log10(shade) if logged else shade - ax.axvspan(shade[0], shade[1], color=color, alpha=0.2, lw=0) + ax.axvspan(shade[0], shade[1], color=color, alpha=alpha, lw=0) if add_center: center = sum(shade) / 2 - ax.axvspan(center, center, color='k', alpha=0.6) + ax.axvspan(center, center, color='k', alpha=center_alpha) def recursive_plot(data, plot_function, ax, **kwargs): diff --git a/specparam/tests/plts/test_utils.py b/specparam/tests/plts/test_utils.py index edfe80d1..a88accfa 100644 --- a/specparam/tests/plts/test_utils.py +++ b/specparam/tests/plts/test_utils.py @@ -33,6 +33,10 @@ def test_add_shades(skip_if_no_mpl): add_shades(check_ax(None), [4, 8]) +@plot_test +def test_add_shades_multi(skip_if_no_mpl): + add_shades(check_ax(None), [[4, 8], [8, 12], [12, 25]], colors=['b', 'c', 'y'], alpha=0.3) + @plot_test def test_recursive_plot(skip_if_no_mpl):