From 20b6dc255872496cd8aa30bee7cb6d8f56e2d1d2 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Wed, 10 Apr 2024 09:48:51 -0400 Subject: [PATCH] add load funcs & io updates --- specparam/core/io.py | 119 +++++++++++++++++++++++++++++ specparam/tests/core/test_io.py | 38 ++++++--- specparam/tests/objs/test_model.py | 8 +- specparam/tests/utils/test_io.py | 2 +- 4 files changed, 153 insertions(+), 14 deletions(-) diff --git a/specparam/core/io.py b/specparam/core/io.py index ebd06045..b13bd07d 100644 --- a/specparam/core/io.py +++ b/specparam/core/io.py @@ -236,6 +236,125 @@ def save_event(event, file_name, file_path=None, append=False, save_settings=save_settings, save_data=save_data) +def load_model(file_name, file_path=None, regenerate=True, model=None): + """Load a SpectralModel object. + + Parameters + ---------- + Parameters + ---------- + file_name : str + File(s) to load data from. + file_path : str, optional + Path to directory to load from. If None, loads from current directory. + regenerate : bool, optional, default: True + Whether to regenerate the model fit from the loaded data, if data is available. + model : SpectralModel + xx + + Returns + ------- + model : SpectralModel + Loaded model object with data from file. + """ + + # Check for model object, import (avoid circular) and initialize if not + if not model: + from specparam.objs import SpectralModel + model = SpectralModel() + + model.load(file_name, file_path, regenerate) + + return model + + +def load_group(file_name, file_path=None, group=None): + """Load a SpectralGroupModel object. + + Parameters + ---------- + file_name : str + File(s) to load data from. + file_path : str, optional + Path to directory to load from. If None, loads from current directory. + group : SpectralGroupModel + xx + + Returns + ------- + group : SpectralGroupModel + Loaded model object with data from file. + """ + + # Check for model object, import (avoid circular) and initialize if not + if not group: + from specparam.objs import SpectralGroupModel + group = SpectralGroupModel() + + group.load(file_name, file_path) + + return group + + +def load_time(file_name, file_path=None, peak_org=None, time=None): + """Load a SpectralTimeModel object. + + Parameters + ---------- + file_name : str + File(s) to load data from. + file_path : str, optional + Path to directory to load from. If None, loads from current directory. + peak_org : int or Bands, optional + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + + Returns + ------- + time : SpectralTimeModel + Loaded model object with data from file. + """ + + # Check for model object, import (avoid circular) and initialize if not + if not time: + from specparam.objs import SpectralTimeModel + time = SpectralTimeModel() + + time.load(file_name, file_path, peak_org) + + return time + +def load_event(file_name, file_path=None, peak_org=None, event=None): + """Load a SpectralTimeEventModel object. + + Parameters + ---------- + file_name : str + File(s) to load data from. + file_path : str, optional + Path to directory to load from. If None, loads from current directory. + peak_org : int or Bands, optional + How to organize peaks. + If int, extracts the first n peaks. + If Bands, extracts peaks based on band definitions. + + Returns + ------- + event : SpectralTimeEventModel + Loaded model object with data from file. + """ + + # Check for model object, import (avoid circular) and initialize if not + if not event: + from specparam.objs import SpectralTimeEventModel + event = SpectralTimeEventModel() + + event.load(file_name, file_path, peak_org) + + return event + + def load_json(file_name, file_path): """Load json file. diff --git a/specparam/tests/core/test_io.py b/specparam/tests/core/test_io.py index 3a3798e8..e597f9ff 100644 --- a/specparam/tests/core/test_io.py +++ b/specparam/tests/core/test_io.py @@ -41,9 +41,9 @@ def test_save_model_str(tfm): """Check saving model object data, with file specifiers as strings.""" # Test saving out each set of save elements - file_name_res = 'test_res' - file_name_set = 'test_set' - file_name_dat = 'test_dat' + file_name_res = 'test_model_res' + file_name_set = 'test_model_set' + file_name_dat = 'test_model_dat' save_model(tfm, file_name_res, TEST_DATA_PATH, False, True, False, False) save_model(tfm, file_name_set, TEST_DATA_PATH, False, False, True, False) @@ -54,14 +54,14 @@ def test_save_model_str(tfm): assert os.path.exists(TEST_DATA_PATH / (file_name_dat + '.json')) # Test saving out all save elements - file_name_all = 'test_all' + file_name_all = 'test_model_all' save_model(tfm, file_name_all, TEST_DATA_PATH, False, True, True, True) assert os.path.exists(TEST_DATA_PATH / (file_name_all + '.json')) def test_save_model_append(tfm): """Check saving fm data, appending to a file.""" - file_name = 'test_append' + file_name = 'test_model_append' save_model(tfm, file_name, TEST_DATA_PATH, True, True, True, True) save_model(tfm, file_name, TEST_DATA_PATH, True, True, True, True) @@ -71,7 +71,7 @@ def test_save_model_append(tfm): def test_save_model_fobj(tfm): """Check saving fm data, with file object file specifier.""" - file_name = 'test_fileobj' + file_name = 'test_model_fileobj' # Save, using file-object: three successive lines with three possible save settings with open(TEST_DATA_PATH / (file_name + '.json'), 'w') as f_obj: @@ -163,12 +163,32 @@ def test_save_event(tfe): for ind in range(len(tfe)): assert os.path.exists(TEST_DATA_PATH / (file_name_all + '_' + str(ind) + '.json')) +def test_load_model(): + + tmodel = load_model('test_model_all', TEST_DATA_PATH) + assert tmodel + +def test_load_group(): + + tgroup = load_group('test_group_all', TEST_DATA_PATH) + assert tgroup + +def test_load_time(): + + ttime = load_time('test_time_all', TEST_DATA_PATH) + assert ttime + +def test_load_event(): + + tevent = load_event('test_event_all', TEST_DATA_PATH) + assert tevent + def test_load_json_str(): """Test loading JSON file, with str file specifier. Loads files from test_save_model_str. """ - file_name = 'test_all' + file_name = 'test_model_all' data = load_json(file_name, TEST_DATA_PATH) @@ -179,7 +199,7 @@ def test_load_json_fobj(): Loads files from test_save_model_str. """ - file_name = 'test_all' + file_name = 'test_model_all' with open(TEST_DATA_PATH / (file_name + '.json'), 'r') as f_obj: data = load_json(f_obj, '') @@ -201,7 +221,7 @@ def test_load_file_contents(): Note that is this test fails, it likely stems from an issue from saving. """ - file_name = 'test_all' + file_name = 'test_model_all' loaded_data = load_json(file_name, TEST_DATA_PATH) # Check settings diff --git a/specparam/tests/objs/test_model.py b/specparam/tests/objs/test_model.py index a6757382..eb9f9a58 100644 --- a/specparam/tests/objs/test_model.py +++ b/specparam/tests/objs/test_model.py @@ -182,7 +182,7 @@ def test_load(): # Test loading just results tfm = SpectralModel(verbose=False) - file_name_res = 'test_res' + file_name_res = 'test_model_res' tfm.load(file_name_res, TEST_DATA_PATH) # Check that result attributes get filled for result in OBJ_DESC['results']: @@ -196,7 +196,7 @@ def test_load(): # Test loading just settings tfm = SpectralModel(verbose=False) - file_name_set = 'test_set' + file_name_set = 'test_model_set' tfm.load(file_name_set, TEST_DATA_PATH) for setting in OBJ_DESC['settings']: assert getattr(tfm, setting) is not None @@ -207,7 +207,7 @@ def test_load(): # Test loading just data tfm = SpectralModel(verbose=False) - file_name_dat = 'test_dat' + file_name_dat = 'test_model_dat' tfm.load(file_name_dat, TEST_DATA_PATH) assert tfm.power_spectrum is not None # Test that settings and results are None @@ -218,7 +218,7 @@ def test_load(): # Test loading all elements tfm = SpectralModel(verbose=False) - file_name_all = 'test_all' + file_name_all = 'test_model_all' tfm.load(file_name_all, TEST_DATA_PATH) for result in OBJ_DESC['results']: assert not np.all(np.isnan(getattr(tfm, result))) diff --git a/specparam/tests/utils/test_io.py b/specparam/tests/utils/test_io.py index 36f1c9a6..1b73798f 100644 --- a/specparam/tests/utils/test_io.py +++ b/specparam/tests/utils/test_io.py @@ -15,7 +15,7 @@ def test_load_model(): - file_name = 'test_all' + file_name = 'test_model_all' tfm = load_model(file_name, TEST_DATA_PATH)