Skip to content

Commit

Permalink
add base3d
Browse files Browse the repository at this point in the history
  • Loading branch information
TomDonoghue committed Apr 9, 2024
1 parent 97e0531 commit 371383f
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 5 deletions.
109 changes: 104 additions & 5 deletions specparam/objs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
from specparam.core.utils import unlog
from specparam.core.items import OBJ_DESC
from specparam.core.errors import NoDataError
from specparam.core.io import save_model, load_json
from specparam.core.io import save_group, load_jsonlines
from specparam.core.io import (save_model, save_group, save_event,
load_json, load_jsonlines, get_files)
from specparam.core.modutils import copy_doc_func_to_method
from specparam.objs.fit import BaseFit, BaseFit2D, BaseFit2DT
from specparam.objs.data import BaseData, BaseData2D, BaseData2DT
from specparam.plts.event import plot_event_model
from specparam.objs.fit import BaseFit, BaseFit2D, BaseFit2DT, BaseFit3D
from specparam.objs.data import BaseData, BaseData2D, BaseData2DT, BaseData3D

###################################################################################################
###################################################################################################
Expand Down Expand Up @@ -240,7 +241,7 @@ def add_data(self, freqs, power_spectra, freq_range=None, clear_results=True):
self._reset_data_results(True, True, True, True)
self._reset_group_results()

super().add_data(freqs, power_spectra, freq_range=None)
super().add_data(freqs, power_spectra, freq_range=freq_range)


@copy_doc_func_to_method(save_group)
Expand Down Expand Up @@ -345,3 +346,101 @@ def load(self, file_name, file_path=None, peak_org=None):
super().load(file_name, file_path=file_path)
if peak_org is not False and self.group_results:
self.convert_results(peak_org)


class BaseObject3D(BaseObject2DT, BaseFit3D, BaseData3D):
"""Define Base object for fitting models to 3D data."""

def __init__(self, aperiodic_mode=None, periodic_mode=None, debug_mode=False, verbose=True):

BaseObject2DT.__init__(self)
BaseData3D.__init__(self)
BaseFit3D.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode,
debug_mode=debug_mode, verbose=verbose)


def add_data(self, freqs, spectrograms, freq_range=None, clear_results=True):
"""Add data (frequencies and spectrograms) to the current object.
Parameters
----------
freqs : 1d array
Frequency values for the power spectra, in linear space.
spectrograms : 3d array or list of 2d array
Matrix of power values, in linear space.
If a list of 2d arrays, each should be have the same shape of [n_freqs, n_time_windows].
If a 3d array, should have shape [n_events, n_freqs, n_time_windows].
freq_range : list of [float, float], optional
Frequency range to restrict power spectra to. If not provided, keeps the entire range.
clear_results : bool, optional, default: True
Whether to clear prior results, if any are present in the object.
This should only be set to False if data for the current results are being re-added.
Notes
-----
If called on an object with existing data and/or results these will be cleared
by this method call, unless explicitly set not to.
"""

if clear_results:
self._reset_event_results()

super().add_data(freqs, spectrograms, freq_range=freq_range)


@copy_doc_func_to_method(save_event)
def save(self, file_name, file_path=None, append=False,
save_results=False, save_settings=False, save_data=False):

save_event(self, file_name, file_path, append, save_results, save_settings, save_data)


def load(self, file_name, file_path=None, peak_org=None):
"""Load data from file(s).
Parameters
----------
file_name : str
File(s) to load data from.
file_path : str, optional
Path to directory to load from. If None, loads from current directory.
peak_org : int or Bands, optional
How to organize peaks.
If int, extracts the first n peaks.
If Bands, extracts peaks based on band definitions.
"""

files = get_files(file_path, select=file_name)
spectrograms = []
for file in files:
super().load(file, file_path, peak_org=False)
if self.group_results:
self.add_results(self.group_results, append=True)
if np.all(self.power_spectra):
spectrograms.append(self.spectrogram)
self.spectrograms = np.array(spectrograms) if spectrograms else None

self._reset_group_results()
if peak_org is not False and self.event_group_results:
self.convert_results(peak_org)


# TO CHECK - DOES THIS GO HERE?
def _reset_data_results(self, clear_freqs=False, clear_spectrum=False,
clear_results=False, clear_spectra=False):
"""Set, or reset, data & results attributes to empty.
Parameters
----------
clear_freqs : bool, optional, default: False
Whether to clear frequency attributes.
clear_spectrum : bool, optional, default: False
Whether to clear power spectrum attribute.
clear_results : bool, optional, default: False
Whether to clear model results attributes.
clear_spectra : bool, optional, default: False
Whether to clear power spectra attribute.
"""

self._reset_data(clear_freqs, clear_spectrum, clear_spectra)
self._reset_results(clear_results)
13 changes: 13 additions & 0 deletions specparam/tests/objs/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,21 @@ def test_base2d():

## 2DT Base Object

def test_base2dt():

tobj2dt = BaseObject2DT()
assert isinstance(tobj2dt, CommonBase)
assert isinstance(tobj2dt, BaseObject2DT)
assert isinstance(tobj2dt, BaseFit2DT)
assert isinstance(tobj2dt, BaseObject2DT)

## 3D Base Object

def test_base3d():

tobj3d = BaseObject3D()
assert isinstance(tobj3d, CommonBase)
assert isinstance(tobj3d, BaseObject2DT)
assert isinstance(tobj3d, BaseFit2DT)
assert isinstance(tobj3d, BaseObject2DT)
assert isinstance(tobj3d, BaseObject3D)

0 comments on commit 371383f

Please sign in to comment.