diff --git a/docs/whats_new.rst b/docs/whats_new.rst index 68bbbaa0..bc2b790b 100644 --- a/docs/whats_new.rst +++ b/docs/whats_new.rst @@ -20,6 +20,7 @@ Documentation Performance ~~~~~~~~~~~ +- Improved performance of :py:class:`pysd.py_backend.output.DataFrameHandler` by creating the dataframe at the end of the run (:issue:`374`). (`@easyas314159 `_ and `@enekomartinmartinez `_) Internal Changes ~~~~~~~~~~~~~~~~ diff --git a/pysd/py_backend/output.py b/pysd/py_backend/output.py index abfd8451..18ed23b1 100644 --- a/pysd/py_backend/output.py +++ b/pysd/py_backend/output.py @@ -10,6 +10,7 @@ from csv import QUOTE_NONE from pathlib import Path +from collections import defaultdict import regex as re @@ -46,8 +47,7 @@ def __init__(self, out_file=None): self.handler = DataFrameHandler(DatasetHandler(None)).handle(out_file) def set_capture_elements(self, capture_elements): - self.handler.capture_elements_step = capture_elements["step"] + \ - ["time"] + self.handler.capture_elements_step = capture_elements["step"] self.handler.capture_elements_run = capture_elements["run"] def initialize(self, model): @@ -214,23 +214,9 @@ def __init__(self, next): super().__init__(next) self.out_file = None self.ds = None - self._step = 0 + self.__step = 0 self.nc = __import__("netCDF4") - @property - def step(self): - """ - Used as time index for the output Dataset. Increases by one - at each iteration. - """ - return self._step - - def __update_step(self): - """ - Increases the _step attribute by 1 at each model iteration. 
- """ - self._step = self.step + 1 - def process_output(self, out_file): """ If out_file can be handled by this concrete handler, it returns @@ -267,6 +253,7 @@ def initialize(self, model): None """ + self.__step = 0 self.ds = self.nc.Dataset(self.out_file, "w") # defining global attributes @@ -295,7 +282,7 @@ def initialize(self, model): # creating the time dimension as unlimited self.ds.createDimension("time", None) # creating variables - self.__create_ds_vars(model, self.capture_elements_step) + self.__create_ds_vars(model, self.capture_elements_step + ['time']) def update(self, model): """ @@ -312,19 +299,20 @@ def update(self, model): None """ + self.ds['time'][self.__step] = model.time.round() for key in self.capture_elements_step: comp = model[key] if isinstance(comp, xr.DataArray): - self.ds[key][self.step, :] = comp.values + self.ds[key][self.__step, :] = comp.values else: - self.ds[key][self.step] = comp + self.ds[key][self.__step] = comp - self.__update_step() + self.__step += 1 def __update_run_elements(self, model): """ - Writes values of cache run elements from the cature_elements set - in the netCDF4 Dataset. + Writes values of cache run elements from the capture_elements + set in the netCDF4 Dataset. Cache run elements do not have the time dimension. Parameters @@ -429,6 +417,7 @@ def __init__(self, next): super().__init__(next) self.ds = None self.out_file = None + self.__step = 0 def process_output(self, out_file): """ @@ -457,7 +446,7 @@ def process_output(self, out_file): def initialize(self, model): """ - Creates a pandas DataFrame and adds model variables as columns. + Creates an empty dictionary to save the outputs. Parameters ---------- @@ -469,12 +458,12 @@ def initialize(self, model): None """ - self.ds = pd.DataFrame(columns=self.capture_elements_step) + self.ds = defaultdict(list) + self.__step = 0 def update(self, model): """ - Add a row to the results pandas DataFrame with the values of the - variables listed in capture_elements. 
+ Add new values to the data dictionary. Parameters ---------- @@ -486,13 +475,15 @@ def update(self, model): None """ - self.ds.loc[model.time.round()] = [ - getattr(model.components, key)() - for key in self.capture_elements_step] + self.ds['time'].append(model.time.round()) + for key in self.capture_elements_step: + self.ds[key].append(getattr(model.components, key)()) + + self.__step += 1 def postprocess(self, **kwargs): """ - Delete time column from the pandas DataFrame and flatten + Convert the output dictionary to a pandas DataFrame and flatten xarrays if required. Returns @@ -501,15 +492,15 @@ def postprocess(self, **kwargs): Simulation results stored as a pandas DataFrame. """ - # delete time column as it was created only for avoiding errors - # of appending data. See previous TODO. - del self.ds["time"] + # create the dataframe + df = pd.DataFrame.from_dict(self.ds) + df.set_index('time', inplace=True) # enforce flattening if df is to be saved to csv or tab file flatten = True if self.out_file else kwargs.get("flatten", None) df = DataFrameHandler.make_flat_df( - self.ds, kwargs["return_addresses"], flatten + df, kwargs["return_addresses"], flatten ) if self.out_file: NCFile.df_to_text_file(df, self.out_file) @@ -518,7 +509,7 @@ def postprocess(self, **kwargs): def add_run_elements(self, model): """ - Adds constant elements to a dataframe. + Adds constant elements to the output data dictionary. 
Parameters ---------- @@ -530,9 +521,8 @@ def add_run_elements(self, model): None """ - nx = len(self.ds.index) - for element in self.capture_elements_run: - self.ds[element] = [getattr(model.components, element)()] * nx + for key in self.capture_elements_run: + self.ds[key] = [getattr(model.components, key)()] * self.__step @staticmethod def make_flat_df(df, return_addresses, flatten=False): diff --git a/tests/pytest_pysd/pytest_output.py b/tests/pytest_pysd/pytest_output.py index 341524bf..36c4fce8 100644 --- a/tests/pytest_pysd/pytest_output.py +++ b/tests/pytest_pysd/pytest_output.py @@ -312,15 +312,6 @@ def test_dataset_handler_step_setter(self, tmp_path, model): output.set_capture_elements(capture_elements) output.initialize(model) - # Dataset handler step cannot be modified from the outside - with pytest.raises(AttributeError): - output.handler.step = 5 - - with pytest.raises(AttributeError): - output.handler.__update_step() - - assert output.handler.step == 0 - def test_make_flat_df(self): df = pd.DataFrame(index=[1], columns=['elem1']) diff --git a/tests/pytest_pysd/pytest_pysd.py b/tests/pytest_pysd/pytest_pysd.py index 3dfd434c..d0996f97 100644 --- a/tests/pytest_pysd/pytest_pysd.py +++ b/tests/pytest_pysd/pytest_pysd.py @@ -1228,9 +1228,9 @@ def test__integrate(self, tmp_path, model): model.output = output model._integrate() res = model.output.handler.ds - assert isinstance(res, pd.DataFrame) + assert isinstance(res, dict) assert 'teacup_temperature' in res - assert all(res.index.values == list(range(0, 5, 2))) + assert np.array_equal(res['time'], list(range(0, 5, 2))) model.reload() model.time.add_return_timestamps(list(range(0, 5, 2)))