diff --git a/specparam/core/strings.py b/specparam/core/strings.py index 01962822..8d2e8d21 100644 --- a/specparam/core/strings.py +++ b/specparam/core/strings.py @@ -4,6 +4,7 @@ from specparam.core.errors import NoModelError from specparam.data.utils import get_periodic_labels +from specparam.utils.data import compute_presence from specparam.version import __version__ as MODULE_VERSION ################################################################################################### diff --git a/specparam/utils/data.py b/specparam/utils/data.py index 181bab0f..b7759896 100644 --- a/specparam/utils/data.py +++ b/specparam/utils/data.py @@ -108,15 +108,11 @@ def compute_presence(data, average=False, output='ratio'): assert output in ['ratio', 'percent'], 'Setting for output type not understood.' - if data.ndim == 1: - presence = sum(~np.isnan(data)) / len(data) + if data.ndim == 1 or average: + presence = np.sum(~np.isnan(data)) / data.size elif data.ndim == 2: - if average: - presence = compute_presence(data.flatten()) - else: - n_events, n_windows = data.shape - presence = np.sum(~np.isnan(data), 0) / (np.ones(n_windows) * n_events) + presence = np.sum(~np.isnan(data), 0) / (np.ones(data.shape[1]) * data.shape[0]) if output == 'percent': presence *= 100