From 72648b20ec4a795334adcab3f751b08a7af24ede Mon Sep 17 00:00:00 2001 From: KulaginVladimir Date: Mon, 7 Oct 2024 02:40:08 +0300 Subject: [PATCH] refactor filtering --- festim/exports/exports.py | 4 +- festim/exports/txt_export.py | 125 ++++++++++++---------- festim/generic_simulation.py | 58 +++++----- test/unit/test_exports/test_txt_export.py | 42 ++++---- 4 files changed, 125 insertions(+), 104 deletions(-) diff --git a/festim/exports/exports.py b/festim/exports/exports.py index 915f44199..c3d61480a 100644 --- a/festim/exports/exports.py +++ b/festim/exports/exports.py @@ -71,7 +71,7 @@ def _validate_export(self, value): return value raise TypeError("festim.Exports must be a list of festim.Export") - def write(self, label_to_function, dx, materials, chemical_pot): + def write(self, label_to_function, dx): """writes to file Args: @@ -129,7 +129,7 @@ def write(self, label_to_function, dx, materials, chemical_pot): label_to_function[export.field], self.V_DG1 ) export.function = label_to_function[export.field] - export.write(self.t, self.final_time, materials, chemical_pot) + export.write(self.t, self.final_time) self.nb_iterations += 1 def initialise_derived_quantities(self, dx, ds, materials): diff --git a/festim/exports/txt_export.py b/festim/exports/txt_export.py index 9144c2b97..7884c69b4 100644 --- a/festim/exports/txt_export.py +++ b/festim/exports/txt_export.py @@ -13,11 +13,20 @@ class TXTExport(festim.Export): field (str): the exported field ("solute", "1", "retention", "T"...) filename (str): the filename (must end with .txt). + write_at_last (bool): if True, the data will be exported at + the last export time. Otherwise, the data will be exported + at each export time. Defaults to False. times (list, optional): if provided, the field will be exported at these timesteps. Otherwise exports at all timesteps. Defaults to None. header_format (str, optional): the format of column headers. Defautls to ".2e". + + Attributes: + data (np.array): the data array of the exported field. The first column + is the mesh vertices. Each next column is the field profile at the specific + export time. + header (str): the header of the exported file. """ def __init__( @@ -31,10 +40,11 @@ def __init__( self.filename = filename self.write_at_last = write_at_last self.header_format = header_format - self._first_time = True self.data = None self.header = None + self._unique_indices = None + self._V = None @property def filename(self): @@ -71,80 +81,81 @@ def is_last(self, current_time, final_time): return True return False - def filter_duplicates(self, data, materials): - x = data[:, 0] - - # Collect all borders - borders = [] - for material in materials: - for border in material.borders: - borders.append(border) - borders = np.unique(borders) - - # Find indices of the closest duplicates to interfaces - border_indx = [] - for border in borders: - closest_indx = np.abs(x - border).argmin() - closest_x = x[closest_indx] - for ind in np.where(x == closest_x)[0]: - border_indx.append(ind) + def initialise_TXTExport(self, mesh, project_to_DG=False, materials=None): - # Find indices of first elements in duplicated pairs and mesh borders - _, unique_indx = np.unique(x, return_index=True) - - # Combine both arrays of indices - combined_indx = np.concatenate([border_indx, unique_indx]) + if project_to_DG: + self._V = f.FunctionSpace(mesh, "DG", 1) + else: + self._V = f.FunctionSpace(mesh, "CG", 1) + + x = f.interpolate(f.Expression("x[0]", degree=1), self._V) + x_column = np.transpose([x.vector()[:]]) + + # if chemical_pot is True or trap_element_type is DG, get indices of duplicates near interfaces + # and indices of the first elements from a pair of duplicates otherwise + if project_to_DG: + # Collect all borders + borders = [] + for material in materials: + if material.borders: + for border in material.borders: + borders.append(border) + borders = np.unique(borders) + + # Find indices of the closest duplicates to interfaces + border_indices = [] + for border in borders: + closest_indx = np.abs(x_column - border).argmin() + closest_x = x_column[closest_indx] + for ind in np.where(x_column == closest_x)[0]: + border_indices.append(ind) + + # Find indices of first elements in duplicated pairs and mesh borders + _, mesh_indices = np.unique(x_column, return_index=True) + + # Get unique indices from both arrays preserving the order in unsorted x-array + unique_indices = [] + for indx in np.argsort(x_column, axis=0)[:, 0]: + if (indx in mesh_indices) or (indx in border_indices): + unique_indices.append(indx) + + self._unique_indices = np.array(unique_indices) - # Sort unique indices to return a slice - combined_indx = sorted(np.unique(combined_indx)) + else: + # Get list of unique indices as integers + self._unique_indices = np.argsort(x_column, axis=0)[:, 0] - return data[combined_indx, :] + self.data = x_column[self._unique_indices] + self.header = "x" - def write(self, current_time, final_time, materials, chemical_pot): - # create a DG1 functionspace if chemical_pot is True - # else create a CG1 functionspace - if chemical_pot: - V = f.FunctionSpace(self.function.function_space().mesh(), "DG", 1) - else: - V = f.FunctionSpace(self.function.function_space().mesh(), "CG", 1) + def write(self, current_time, final_time): - solution = f.project(self.function, V) - solution_column = np.transpose(solution.vector()[:]) if self.is_it_time_to_export(current_time): + solution = f.project(self.function, self._V) + solution_column = np.transpose(solution.vector()[:]) + # if the directory doesn't exist # create it dirname = os.path.dirname(self.filename) if not os.path.exists(dirname): os.makedirs(dirname, exist_ok=True) - # create header if steady or it is the first time to export - # else append new column to the existing file - if final_time is None or self._first_time: - if final_time is None: - self.header = "x,t=steady" - else: - self.header = f"x,t={format(current_time, self.header_format)}s" - - x = f.interpolate(f.Expression("x[0]", degree=1), V) - x_column = np.transpose([x.vector()[:]]) - self.data = np.column_stack([x_column, solution_column]) - self._first_time = False + # if steady, add the corresponding label + # else append new export time to the header + steady = final_time is None + if steady: + self.header += ",t=steady" else: - # Update the header self.header += f",t={format(current_time, self.header_format)}s" - # Add new column - self.data = np.column_stack([self.data, solution_column]) + + # Add new column of filtered and sorted data + self.data = np.column_stack( + [self.data, solution_column[self._unique_indices]] + ) if ( self.write_at_last and self.is_last(current_time, final_time) ) or not self.write_at_last: - if self.is_last(current_time, final_time): - # Sort data by the x-column before the last export time - self.data = self.data[self.data[:, 0].argsort()] - - # Filter duplicates if chemical_pot is True - if chemical_pot: - self.data = self.filter_duplicates(self.data, materials) # Write data np.savetxt( diff --git a/festim/generic_simulation.py b/festim/generic_simulation.py index 96fa71148..2db764373 100644 --- a/festim/generic_simulation.py +++ b/festim/generic_simulation.py @@ -357,11 +357,11 @@ def initialise(self): self.h_transport_problem.initialise(self.mesh, self.materials, self.dt) - # raise warning if the derived quantities don't match the type of mesh - # eg. SurfaceFlux is used with cylindrical mesh for export in self.exports: if isinstance(export, festim.DerivedQuantities): for q in export: + # raise warning if the derived quantities don't match the type of mesh + # eg. SurfaceFlux is used with cylindrical mesh if self.mesh.type not in q.allowed_meshes: warnings.warn( f"{type(q)} may not work as intended for {self.mesh.type} meshes" @@ -381,31 +381,41 @@ def initialise(self): f"SurfaceKinetics boundary condition must be defined on surface {q.surface} to export data with festim.AdsorbedHydrogen" ) - self.exports.initialise_derived_quantities( - self.mesh.dx, self.mesh.ds, self.materials - ) - - # needed to ensure that data is actually exported at TXTExport.times - # see issue 675 - for export in self.exports: - if isinstance(export, festim.TXTExport) and export.times: - if not self.dt.milestones: - self.dt.milestones = [] - for time in export.times: - if time not in self.dt.milestones: - msg = "To ensure that TXTExport exports data at the desired times " - msg += "TXTExport.times are added to milestones" - warnings.warn(msg) - self.dt.milestones.append(time) - self.dt.milestones.sort() - - # set Soret to True for SurfaceFlux quantities - if isinstance(export, festim.DerivedQuantities): - for q in export: + # set Soret to True for SurfaceFlux quantities if isinstance(q, festim.SurfaceFlux): q.soret = self.settings.soret q.T = self.T.T + if isinstance(export, festim.TXTExport): + # pre-process data depending on the chemical potential flag, trap element type, + # and material borders + project_to_DG = ( + self.settings.chemical_pot + or self.settings.traps_element_type == "DG" + ) + export.initialise_TXTExport( + self.mesh.mesh, + project_to_DG, + self.materials, + ) + + # needed to ensure that data is actually exported at TXTExport.times + # see issue 675 + if export.times: + if not self.dt.milestones: + self.dt.milestones = [] + for time in export.times: + if time not in self.dt.milestones: + msg = "To ensure that TXTExport exports data at the desired times " + msg += "TXTExport.times are added to milestones" + warnings.warn(msg) + self.dt.milestones.append(time) + self.dt.milestones.sort() + + self.exports.initialise_derived_quantities( + self.mesh.dx, self.mesh.ds, self.materials + ) + def run(self, completion_tone=False): """Runs the model. @@ -506,8 +516,6 @@ def run_post_processing(self): self.exports.write( self.label_to_function, self.mesh.dx, - self.materials, - self.settings.chemical_pot, ) def update_post_processing_solutions(self): diff --git a/test/unit/test_exports/test_txt_export.py b/test/unit/test_exports/test_txt_export.py index 134d88255..5210bf3f3 100644 --- a/test/unit/test_exports/test_txt_export.py +++ b/test/unit/test_exports/test_txt_export.py @@ -1,4 +1,4 @@ -from festim import TXTExport, Stepsize, Material +from festim import TXTExport, Material import fenics as f import os import pytest @@ -7,6 +7,12 @@ class TestWrite: + @pytest.fixture + def mesh(self): + mesh = f.UnitIntervalMesh(10) + + return mesh + @pytest.fixture def function(self): mesh = f.UnitIntervalMesh(10) @@ -34,31 +40,29 @@ def my_export(self, tmpdir): return my_export - def test_file_exists(self, my_export, function): + def test_file_exists(self, my_export, function, mesh): current_time = 1 my_export.function = function + my_export.initialise_TXTExport(mesh) my_export.write( current_time=current_time, final_time=None, - materials=None, - chemical_pot=False, ) assert os.path.exists(my_export.filename) - def test_file_doesnt_exist(self, my_export, function): + def test_file_doesnt_exist(self, my_export, function, mesh): current_time = 10 my_export.function = function + my_export.initialise_TXTExport(mesh) my_export.write( current_time=current_time, final_time=None, - materials=None, - chemical_pot=False, ) assert not os.path.exists(my_export.filename) - def test_create_folder(self, my_export, function): + def test_create_folder(self, my_export, function, mesh): """Checks that write() creates the folder if it doesn't exist""" current_time = 1 my_export.function = function @@ -68,23 +72,21 @@ def test_create_folder(self, my_export, function): + "/folder2" + my_export.filename[slash_indx:] ) + my_export.initialise_TXTExport(mesh) my_export.write( current_time=current_time, final_time=None, - materials=None, - chemical_pot=False, ) assert os.path.exists(my_export.filename) - def test_subspace(self, my_export, function_subspace): + def test_subspace(self, my_export, function_subspace, mesh): current_time = 1 my_export.function = function_subspace + my_export.initialise_TXTExport(mesh) my_export.write( current_time=current_time, final_time=None, - materials=None, - chemical_pot=False, ) assert os.path.exists(my_export.filename) @@ -97,20 +99,19 @@ def test_error_filename_not_a_str(self, my_export): with pytest.raises(TypeError, match="filename must be a string"): my_export.filename = 2 - def test_sorted_by_x(self, my_export, function): + def test_sorted_by_x(self, my_export, function, mesh): """Checks that the exported data is sorted by x""" current_time = 1 my_export.function = function + my_export.initialise_TXTExport(mesh) my_export.write( current_time=current_time, final_time=None, - materials=None, - chemical_pot=False, ) assert (np.diff(my_export.data[:, 0]) >= 0).all() @pytest.mark.parametrize( - "materials,chemical_pot,export_len", + "materials,project_to_DG,export_len", [ (None, False, 11), ( @@ -123,18 +124,19 @@ def test_sorted_by_x(self, my_export, function): ), ], ) - def test_duplicates(self, materials, chemical_pot, export_len, my_export, function): + def test_duplicates( + self, materials, project_to_DG, export_len, my_export, function, mesh + ): """ Checks that the exported data does not contain duplicates except those near interfaces """ current_time = 1 my_export.function = function + my_export.initialise_TXTExport(mesh, project_to_DG, materials) my_export.write( current_time=current_time, final_time=None, - materials=materials, - chemical_pot=chemical_pot, ) assert len(my_export.data) == export_len