diff --git a/specparam/objs/data.py b/specparam/objs/data.py index 4006d1ab..7823542f 100644 --- a/specparam/objs/data.py +++ b/specparam/objs/data.py @@ -379,3 +379,64 @@ def add_data(self, freqs, spectrogram, freq_range=None): if np.any(self.freqs): self._reset_time_results() super().add_data(freqs, spectrogram, freq_range) + + +class BaseData3D(BaseData2DT): + """Base object for managing data for spectral parameterization - for 3D data.""" + + def __init__(self): + + BaseData2DT.__init__(self) + + self.spectrograms = None + + + @property + def has_data(self): + """Redefine has_data marker to reflect the spectrograms attribute.""" + + return bool(np.any(self.spectrograms)) + + + @property + def n_time_windows(self): + """How many time windows are included in the model object.""" + + return self.spectrograms[0].shape[1] if self.has_data else 0 + + + @property + def n_events(self): + """How many events are included in the model object.""" + + return len(self.spectrograms) + + + def add_data(self, freqs, spectrograms, freq_range=None): + """Add data (frequencies and spectrograms) to the current object. + + Parameters + ---------- + freqs : 1d array + Frequency values for the power spectra, in linear space. + spectrograms : 3d array or list of 2d array + Matrix of power values, in linear space. + If a list of 2d arrays, each should be have the same shape of [n_freqs, n_time_windows]. + If a 3d array, should have shape [n_events, n_freqs, n_time_windows]. + freq_range : list of [float, float], optional + Frequency range to restrict power spectra to. If not provided, keeps the entire range. + """ + + # If given a list of spectrograms, convert to 3d array + if isinstance(spectrograms, list): + spectrograms = np.array(spectrograms) + + # If is a 3d array, add to object as spectrograms + if spectrograms.ndim == 3: + + self.freqs, self.spectrograms, self.freq_range, self.freq_res = \ + self._prepare_data(freqs, spectrograms, freq_range, 3) + + # Otherwise, pass through 2d array to underlying object method + else: + super().add_data(freqs, spectrograms, freq_range) diff --git a/specparam/tests/objs/test_data.py b/specparam/tests/objs/test_data.py index 58adc205..63e887f6 100644 --- a/specparam/tests/objs/test_data.py +++ b/specparam/tests/objs/test_data.py @@ -93,3 +93,23 @@ def test_base_data2dt_add_data(): assert tdata2dt.has_data assert np.all(tdata2dt.spectrogram) assert tdata2dt.n_time_windows + +## 3D Data Object + +def test_base_data3d(): + + tdata3d = BaseData3D() + assert tdata3d + assert isinstance(tdata3d, BaseData) + assert isinstance(tdata3d, BaseData2D) + assert isinstance(tdata3d, BaseData2DT) + assert isinstance(tdata3d, BaseData3D) + +def test_base_data3d_add_data(): + + tdata3d = BaseData3D() + freqs, pows = np.array([1, 2, 3]), np.array([[10, 10, 10], [20, 20, 20]]).T + tdata3d.add_data(freqs, np.array([pows, pows])) + assert tdata3d.has_data + assert np.all(tdata3d.spectrograms) + assert tdata3d.n_events