Skip to content

Commit

Permalink
update data checks for 3D properly
Browse files (browse the repository at this point in the history)
  • Loading branch information
TomDonoghue committed Apr 8, 2024
1 parent 83ae693 commit 97e0531
Showing 1 changed file with 20 additions and 19 deletions.
39 changes: 20 additions & 19 deletions specparam/objs/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,16 +151,16 @@ def _regenerate_freqs(self):
self.freqs = gen_freqs(self.freq_range, self.freq_res)


def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1):
def _prepare_data(self, freqs, powers, freq_range, spectra_dim=1):
"""Prepare input data for adding to current object.
Parameters
----------
freqs : 1d array
Frequency values for the power_spectrum, in linear space.
power_spectrum : 1d or 2d array
Frequency values for `powers`, in linear space.
powers : 1d or 2d or 3d array
Power values, which must be input in linear space.
1d vector, or 2d as [n_power_spectra, n_freqs].
1d vector, or 2d as [n_spectra, n_freqs], or 3d as [n_events, n_freqs, n_spectra].
freq_range : list of [float, float]
Frequency range to restrict power spectrum to.
If None, keeps the entire range.
Expand All @@ -170,10 +170,10 @@ def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1):
Returns
-------
freqs : 1d array
Frequency values for the power_spectrum, in linear space.
power_spectrum : 1d or 2d array
Frequency values for `powers`, in linear space.
powers : 1d or 2d or 3d array
Power spectrum values, in log10 scale.
1d vector, or 2d as [n_power_specta, n_freqs].
1d vector, or 2d as [n_spectra, n_freqs], or 3d as [n_events, n_freqs, n_spectra].
freq_range : list of [float, float]
Minimum and maximum values of the frequency vector.
freq_res : float
Expand All @@ -188,38 +188,39 @@ def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1):
"""

# Check that data are the right types
if not isinstance(freqs, np.ndarray) or not isinstance(power_spectrum, np.ndarray):
if not isinstance(freqs, np.ndarray) or not isinstance(powers, np.ndarray):
raise DataError("Input data must be numpy arrays.")

# Check that data have the right dimensionality
if freqs.ndim != 1 or (power_spectrum.ndim != spectra_dim):
if freqs.ndim != 1 or (powers.ndim != spectra_dim):
raise DataError("Inputs are not the right dimensions.")

# Check that data sizes are compatible
if freqs.shape[-1] != power_spectrum.shape[-1]:
if (spectra_dim < 3 and freqs.shape[-1] != powers.shape[-1]) or \
spectra_dim == 3 and freqs.shape[-1] != powers.shape[1]:
raise InconsistentDataError("The input frequencies and power spectra "
"are not consistent size.")

# Check if power values are complex
if np.iscomplexobj(power_spectrum):
if np.iscomplexobj(powers):
raise DataError("Input power spectra are complex values. "
"Model fitting does not currently support complex inputs.")

# Force data to be dtype of float64
# If they end up as float32, or less, scipy curve_fit fails (sometimes implicitly)
if freqs.dtype != 'float64':
freqs = freqs.astype('float64')
if power_spectrum.dtype != 'float64':
power_spectrum = power_spectrum.astype('float64')
if powers.dtype != 'float64':
powers = powers.astype('float64')

# Check frequency range, trim the power_spectrum range if requested
# Check frequency range, trim the power values range if requested
if freq_range:
freqs, power_spectrum = trim_spectrum(freqs, power_spectrum, freq_range)
freqs, powers = trim_spectrum(freqs, powers, freq_range)

# Check if freqs start at 0 and move up one value if so
# Aperiodic fit gets an inf if freq of 0 is included, which leads to an error
if freqs[0] == 0.0:
freqs, power_spectrum = trim_spectrum(freqs, power_spectrum, [freqs[1], freqs.max()])
freqs, powers = trim_spectrum(freqs, powers, [freqs[1], freqs.max()])
if self.verbose:
print("\nFITTING WARNING: Skipping frequency == 0, "
"as this causes a problem with fitting.")
Expand All @@ -229,7 +230,7 @@ def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1):
freq_res = freqs[1] - freqs[0]

# Log power values
power_spectrum = np.log10(power_spectrum)
powers = np.log10(powers)

## Data checks - run checks on inputs based on check modes

Expand All @@ -241,14 +242,14 @@ def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1):
"The model expects equidistant frequency values in linear space.")
if self._check_data:
# Check if there are any infs / nans, and raise an error if so
if np.any(np.isinf(power_spectrum)) or np.any(np.isnan(power_spectrum)):
if np.any(np.isinf(powers)) or np.any(np.isnan(powers)):
error_msg = ("The input power spectra data, after logging, contains NaNs or Infs. "
"This will cause the fitting to fail. "
"One reason this can happen is if inputs are already logged. "
"Input data should be in linear spacing, not log.")
raise DataError(error_msg)

return freqs, power_spectrum, freq_range, freq_res
return freqs, powers, freq_range, freq_res


class BaseData2D(BaseData):
Expand Down

0 comments on commit 97e0531

Please sign in to comment.