Skip to content

Commit

Permalink
add data3d
Browse files Browse the repository at this point in the history
  • Loading branch information
TomDonoghue committed Apr 9, 2024
1 parent 371383f commit f9f1553
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 0 deletions.
61 changes: 61 additions & 0 deletions specparam/objs/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,3 +379,64 @@ def add_data(self, freqs, spectrogram, freq_range=None):
if np.any(self.freqs):
self._reset_time_results()
super().add_data(freqs, spectrogram, freq_range)


class BaseData3D(BaseData2DT):
"""Base object for managing data for spectral parameterization - for 3D data."""

def __init__(self):

BaseData2DT.__init__(self)

self.spectrograms = None


@property
def has_data(self):
"""Redefine has_data marker to reflect the spectrograms attribute."""

return bool(np.any(self.spectrograms))


@property
def n_time_windows(self):
"""How many time windows are included in the model object."""

return self.spectrograms[0].shape[1] if self.has_data else 0


@property
def n_events(self):
"""How many events are included in the model object."""

return len(self.spectrograms)


def add_data(self, freqs, spectrograms, freq_range=None):
"""Add data (frequencies and spectrograms) to the current object.
Parameters
----------
freqs : 1d array
Frequency values for the power spectra, in linear space.
spectrograms : 3d array or list of 2d array
Matrix of power values, in linear space.
If a list of 2d arrays, each should be have the same shape of [n_freqs, n_time_windows].
If a 3d array, should have shape [n_events, n_freqs, n_time_windows].
freq_range : list of [float, float], optional
Frequency range to restrict power spectra to. If not provided, keeps the entire range.
"""

# If given a list of spectrograms, convert to 3d array
if isinstance(spectrograms, list):
spectrograms = np.array(spectrograms)

# If is a 3d array, add to object as spectrograms
if spectrograms.ndim == 3:

self.freqs, self.spectrograms, self.freq_range, self.freq_res = \
self._prepare_data(freqs, spectrograms, freq_range, 3)

# Otherwise, pass through 2d array to underlying object method
else:
super().add_data(freqs, spectrograms, freq_range)
20 changes: 20 additions & 0 deletions specparam/tests/objs/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,23 @@ def test_base_data2dt_add_data():
assert tdata2dt.has_data
assert np.all(tdata2dt.spectrogram)
assert tdata2dt.n_time_windows

## 3D Data Object

def test_base_data3d():

tdata3d = BaseData3D()
assert tdata3d
assert isinstance(tdata3d, BaseData)
assert isinstance(tdata3d, BaseData2D)
assert isinstance(tdata3d, BaseData2DT)
assert isinstance(tdata3d, BaseData3D)

def test_base_data3d_add_data():

tdata3d = BaseData3D()
freqs, pows = np.array([1, 2, 3]), np.array([[10, 10, 10], [20, 20, 20]]).T
tdata3d.add_data(freqs, np.array([pows, pows]))
assert tdata3d.has_data
assert np.all(tdata3d.spectrograms)
assert tdata3d.n_events

0 comments on commit f9f1553

Please sign in to comment.