Skip to content

Commit

Permalink
Merge pull request festim-dev#697 from KulaginVladimir/DerQuant_list-…
Browse files Browse the repository at this point in the history
…and-Traps_list

`F.Traps` and `F.DerivedQuantities` -> subclasses of list
  • Loading branch information
RemDelaporteMathurin authored Feb 6, 2024
2 parents a484d8d + 0d68ce5 commit 49d18e8
Show file tree
Hide file tree
Showing 16 changed files with 881 additions and 196 deletions.
2 changes: 1 addition & 1 deletion festim/concentration/mobile.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def create_diffusion_form(
# add the trapping terms
F_trapping = 0
if traps is not None:
for trap in traps.traps:
for trap in traps:
for i, mat in enumerate(trap.materials):
if type(trap.k_0) is list:
k_0 = trap.k_0[i]
Expand Down
63 changes: 50 additions & 13 deletions festim/concentration/traps/traps.py
Original file line number Diff line number Diff line change
@@ -1,52 +1,89 @@
import festim
import fenics as f
import warnings


class Traps:
def __init__(self, traps=[]) -> None:
self.traps = traps
class Traps(list):
"""
A list of festim.Trap objects
"""

def __init__(self, *args):
# checks that input is list
if not isinstance(*args, list):
raise TypeError("festim.Traps must be a list")
super().__init__(self._validate_trap(item) for item in args[0])

self.F = None
self.extrinsic_formulations = []
self.sub_expressions = []

# add ids if unspecified
for i, trap in enumerate(self.traps, 1):
for i, trap in enumerate(self, 1):
if trap.id is None:
trap.id = i

@property
def traps(self):
return self._traps
warnings.warn(
"The traps attribute will be deprecated in a future release, please use festim.Traps as a list instead",
DeprecationWarning,
)
return self

@traps.setter
def traps(self, value):
warnings.warn(
"The traps attribute will be deprecated in a future release, please use festim.Traps as a list instead",
DeprecationWarning,
)
if isinstance(value, list):
if not all(isinstance(t, festim.Trap) for t in value):
raise TypeError("traps must be a list of festim.Trap")
self._traps = value
super().__init__(value)
else:
raise TypeError("traps must be a list")

def __setitem__(self, index, item):
super().__setitem__(index, self._validate_trap(item))

def insert(self, index, item):
super().insert(index, self._validate_trap(item))

def append(self, item):
super().append(self._validate_trap(item))

def extend(self, other):
if isinstance(other, type(self)):
super().extend(other)
else:
super().extend(self._validate_trap(item) for item in other)

def _validate_trap(self, value):
if isinstance(value, festim.Trap):
return value
raise TypeError("festim.Traps must be a list of festim.Trap")

def make_traps_materials(self, materials):
for trap in self.traps:
for trap in self:
trap.make_materials(materials)

def create_forms(self, mobile, materials, T, dx, dt=None):
self.F = 0
for trap in self.traps:
for trap in self:
trap.create_form(mobile, materials, T, dx, dt=dt)
self.F += trap.F
self.sub_expressions += trap.sub_expressions

def get_trap(self, id):
for trap in self.traps:
for trap in self:
if trap.id == id:
return trap
raise ValueError("Couldn't find trap {}".format(id))

def initialise_extrinsic_traps(self, V):
"""Add functions to ExtrinsicTrapBase objects for density form"""
for trap in self.traps:
for trap in self:
if isinstance(trap, festim.ExtrinsicTrapBase):
trap.density = [f.Function(V)]
trap.density_test_function = f.TestFunction(V)
Expand All @@ -63,14 +100,14 @@ def define_variational_problem_extrinsic_traps(self, dx, dt, T):
"""
self.extrinsic_formulations = []
expressions_extrinsic = []
for trap in self.traps:
for trap in self:
if isinstance(trap, festim.ExtrinsicTrapBase):
trap.create_form_density(dx, dt, T)
self.extrinsic_formulations.append(trap.form_density)
self.sub_expressions.extend(expressions_extrinsic)

def solve_extrinsic_traps(self):
for trap in self.traps:
for trap in self:
if isinstance(trap, festim.ExtrinsicTrapBase):
du_t = f.TrialFunction(trap.density[0].function_space())
J_t = f.derivative(trap.form_density, trap.density[0], du_t)
Expand All @@ -91,6 +128,6 @@ def solve_extrinsic_traps(self):
solver.solve()

def update_extrinsic_traps_density(self):
for trap in self.traps:
for trap in self:
if isinstance(trap, festim.ExtrinsicTrapBase):
trap.density_previous_solution.assign(trap.density[0])
73 changes: 60 additions & 13 deletions festim/exports/derived_quantities/derived_quantities.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
import os
import numpy as np
from typing import Union
import warnings


class DerivedQuantities:
class DerivedQuantities(list):
"""
A list of festim.DerivedQuantity objects
Args:
derived_quantities (list, optional): list of F.DerivedQuantity
object. Defaults to None.
filename (str, optional): the filename (must end with .csv).
If None, the data will not be exported. Defaults to None.
nb_iterations_between_compute (int, optional): number of
Expand All @@ -26,22 +27,68 @@ class DerivedQuantities:

def __init__(
self,
derived_quantities: list = None,
*args,
filename: str = None,
nb_iterations_between_compute: int = 1,
nb_iterations_between_exports: int = None,
) -> None:
# checks that input is list
if not isinstance(*args, list):
raise TypeError("festim.DerivedQuantities must be a list")
super().__init__(self._validate_derived_quantity(item) for item in args[0])

self.filename = filename
self.nb_iterations_between_compute = nb_iterations_between_compute
self.nb_iterations_between_exports = nb_iterations_between_exports

self.derived_quantities = derived_quantities
if derived_quantities is None:
self.derived_quantities = []

self.data = [self.make_header()]
self.t = []

@property
def derived_quantities(self):
warnings.warn(
"The derived_quantities attribute will be deprecated in a future release, please use festim.DerivedQuantities as a list instead",
DeprecationWarning,
)
return self

@derived_quantities.setter
def derived_quantities(self, value):
warnings.warn(
"The derived_quantities attribute will be deprecated in a future release, please use festim.DerivedQuantities as a list instead",
DeprecationWarning,
)
if isinstance(value, list):
if not all(isinstance(t, DerivedQuantity) for t in value):
raise TypeError(
"derived_quantities must be a list of festim.DerivedQuantity"
)
super().__init__(value)
else:
raise TypeError("derived_quantities must be a list")

def __setitem__(self, index, item):
super().__setitem__(index, self._validate_derived_quantity(item))

def insert(self, index, item):
super().insert(index, self._validate_derived_quantity(item))

def append(self, item):
super().append(self._validate_derived_quantity(item))

def extend(self, other):
if isinstance(other, type(self)):
super().extend(other)
else:
super().extend(self._validate_derived_quantity(item) for item in other)

def _validate_derived_quantity(self, value):
if isinstance(value, DerivedQuantity):
return value
raise TypeError(
"festim.DerivedQuantities must be a list of festim.DerivedQuantity"
)

@property
def filename(self):
return self._filename
Expand All @@ -57,13 +104,13 @@ def filename(self, value):

def make_header(self):
header = ["t(s)"]
for quantity in self.derived_quantities:
for quantity in self:
header.append(quantity.title)
return header

def assign_measures_to_quantities(self, dx, ds):
self.volume_markers = dx.subdomain_data()
for quantity in self.derived_quantities:
for quantity in self:
quantity.dx = dx
quantity.ds = ds
quantity.n = f.FacetNormal(dx.subdomain_data().mesh())
Expand All @@ -75,7 +122,7 @@ def assign_properties_to_quantities(self, materials):
Args:
materials (festim.Materials): the materials
"""
for quantity in self.derived_quantities:
for quantity in self:
quantity.D = materials.D
quantity.S = materials.S
quantity.thermal_cond = materials.thermal_cond
Expand All @@ -84,7 +131,7 @@ def assign_properties_to_quantities(self, materials):
def compute(self, t):
# TODO need to support for soret flag in surface flux
row = [t]
for quantity in self.derived_quantities:
for quantity in self:
if isinstance(quantity, (MaximumVolume, MinimumVolume)):
value = quantity.compute(self.volume_markers)
else:
Expand Down Expand Up @@ -180,7 +227,7 @@ def filter(
quantities = []

# iterate through derived_quantities
for quantity in self.derived_quantities:
for quantity in self:
# initialise flags to False
match_surface, match_volume, match_field, match_instance = (
False,
Expand Down
58 changes: 56 additions & 2 deletions festim/exports/exports.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import festim
import fenics as f
import warnings


class Exports(list):
Expand All @@ -8,12 +9,65 @@ class Exports(list):
"""

def __init__(self, *args):
super().__init__(*args)
# checks that input is list
if not isinstance(*args, list):
raise TypeError("festim.Exports must be a list")
super().__init__(self._validate_export(item) for item in args[0])

self.t = None
self.V_DG1 = None
self.final_time = None
self.nb_iterations = 0

@property
def exports(self):
warnings.warn(
"The exports attribute will be deprecated in a future release, please use festim.Exports as a list instead",
DeprecationWarning,
)
return self

@exports.setter
def exports(self, value):
warnings.warn(
"The exports attribute will be deprecated in a future release, please use festim.Exports as a list instead",
DeprecationWarning,
)
if isinstance(value, list):
if not all(
(
isinstance(t, festim.Export)
or isinstance(t, festim.DerivedQuantities)
)
for t in value
):
raise TypeError("exports must be a list of festim.Export")
super().__init__(value)
else:
raise TypeError("exports must be a list")

def __setitem__(self, index, item):
super().__setitem__(index, self._validate_export(item))

def insert(self, index, item):
super().insert(index, self._validate_export(item))

def append(self, item):
super().append(self._validate_export(item))

def extend(self, other):
if isinstance(other, type(self)):
super().extend(other)
else:
super().extend(self._validate_export(item) for item in other)

def _validate_export(self, value):
if isinstance(value, festim.Export) or isinstance(
value, festim.DerivedQuantities
):
return value
raise TypeError("festim.Exports must be a list of festim.Export")

def write(self, label_to_function, dx):
"""writes to file
Expand All @@ -26,7 +80,7 @@ def write(self, label_to_function, dx):
# compute derived quantities
if export.is_compute(self.nb_iterations):
# check if function has to be projected
for quantity in export.derived_quantities:
for quantity in export:
if isinstance(
quantity, (festim.MaximumVolume, festim.MinimumVolume)
):
Expand Down
Loading

0 comments on commit 49d18e8

Please sign in to comment.