From f6946b61796d24a0e006df1cc24c552f9b45d646 Mon Sep 17 00:00:00 2001 From: Chahak Mehta Date: Sat, 15 Jul 2023 13:22:31 -0700 Subject: [PATCH 1/4] Remove extra example files --- examples/mpm-nodal-forces.toml | 73 --------------------- examples/mpm-uniaxial-stress.toml | 61 ------------------ examples/particles-2d-nodal-force.json | 74 ---------------------- examples/particles-2d-uniaxial-stress.json | 26 -------- 4 files changed, 234 deletions(-) delete mode 100644 examples/mpm-nodal-forces.toml delete mode 100644 examples/mpm-uniaxial-stress.toml delete mode 100644 examples/particles-2d-nodal-force.json delete mode 100644 examples/particles-2d-uniaxial-stress.json diff --git a/examples/mpm-nodal-forces.toml b/examples/mpm-nodal-forces.toml deleted file mode 100644 index cf01f1e..0000000 --- a/examples/mpm-nodal-forces.toml +++ /dev/null @@ -1,73 +0,0 @@ -# The `meta` group contains top level attributes that govern the -# behaviour of the MPM Solver. -# -# Attributes: -# title: The title of the experiment. This is just for the user's -# reference. -# type: The type of simulation to be used. Allowed values are -# {"MPMExplicit"} -# scheme: The MPM Scheme used for simulation. Allowed values are -# {"usl", "usf"} -# dt: Timestep used in the simulation. -# nsteps: Number of steps to run the simulation for. -[meta] -title = "uniaxial-nodal-traction" -type = "MPMExplicit" -dimension = 2 -scheme = "usf" -dt = 0.001 -nsteps = 301 -velocity_update = true - -[output] -type = "hdf5" -file = "results/example_2d_out.hdf5" -step_frequency = 5 - -[mesh] -# type = "file" -# file = "mesh-1d.txt" -# boundary_nodes = "boundary-1d.txt" -# particle_element_ids = "particles-elements.txt" -type = "generator" -nelements = [3, 1] -element_length = [0.1, 0.1] -particle_element_ids = [0] -element = "Quadrilateral4Node" - -[[mesh.constraints]] -node_ids = [0, 4] -dir = 0 -velocity = 0.0 - -[[materials]] -id = 0 -density = 1000 -poisson_ratio = 0 -youngs_modulus = 1000000 -type = "LinearElastic" - -[[particles]] -file = "examples/particles-2d-nodal-force.json" -material_id = 0 -init_velocity = 0.0 - -[external_loading] -gravity = [0, 0] - -[[external_loading.concentrated_nodal_forces]] -node_ids = [3, 7] -math_function_id = 0 -dir = 0 -force = 0.05 - -[[external_loading.particle_surface_traction]] -pset = [1] -dir = 1 -math_function_id = 0 -traction = 10.5 - -[[math_functions]] -type = "Linear" -xvalues = [0.0, 0.5, 1.0] -fxvalues = [0.0, 1.0, 1.0] diff --git a/examples/mpm-uniaxial-stress.toml b/examples/mpm-uniaxial-stress.toml deleted file mode 100644 index 4f8065e..0000000 --- a/examples/mpm-uniaxial-stress.toml +++ /dev/null @@ -1,61 +0,0 @@ -# The `meta` group contains top level attributes that govern the -# behaviour of the MPM Solver. -# -# Attributes: -# title: The title of the experiment. This is just for the user's -# reference. -# type: The type of simulation to be used. Allowed values are -# {"MPMExplicit"} -# scheme: The MPM Scheme used for simulation. Allowed values are -# {"usl", "usf"} -# dt: Timestep used in the simulation. -# nsteps: Number of steps to run the simulation for. -[meta] -title = "uniaxial-stress" -type = "MPMExplicit" -dimension = 2 -scheme = "usf" -dt = 0.01 -nsteps = 10 -velocity_update = false - -[output] -format = "npz" -folder = "results/" -step_frequency = 5 - -[mesh] -# type = "file" -# file = "mesh-1d.txt" -# boundary_nodes = "boundary-1d.txt" -# particle_element_ids = "particles-elements.txt" -type = "generator" -nelements = [1, 1] -element_length = [1, 1] -particle_element_ids = [0] -element = "Quadrilateral4Node" - -[[mesh.constraints]] -node_ids = [0, 1] -dir = 1 -velocity = 0.0 - -[[mesh.constraints]] -node_ids = [2, 3] -dir = 1 -velocity = -0.01 - -[[materials]] -id = 0 -density = 1 -poisson_ratio = 0 -youngs_modulus = 1000 -type = "LinearElastic" - -[[particles]] -file = "examples/particles-2d-uniaxial-stress.json" -material_id = 0 -init_velocity = [1.0, 0.0] - -[external_loading] -gravity = [0, 0] diff --git a/examples/particles-2d-nodal-force.json b/examples/particles-2d-nodal-force.json deleted file mode 100644 index f0143b8..0000000 --- a/examples/particles-2d-nodal-force.json +++ /dev/null @@ -1,74 +0,0 @@ -[ - [ - [ - 0.025, - 0.025 - ] - ], - [ - [ - 0.075, - 0.025 - ] - ], - [ - [ - 0.125, - 0.025 - ] - ], - [ - [ - 0.175, - 0.025 - ] - ], - [ - [ - 0.225, - 0.025 - ] - ], - [ - [ - 0.275, - 0.025 - ] - ], - [ - [ - 0.025, - 0.075 - ] - ], - [ - [ - 0.075, - 0.075 - ] - ], - [ - [ - 0.125, - 0.075 - ] - ], - [ - [ - 0.175, - 0.075 - ] - ], - [ - [ - 0.225, - 0.075 - ] - ], - [ - [ - 0.275, - 0.075 - ] - ] -] \ No newline at end of file diff --git a/examples/particles-2d-uniaxial-stress.json b/examples/particles-2d-uniaxial-stress.json deleted file mode 100644 index 3b22d51..0000000 --- a/examples/particles-2d-uniaxial-stress.json +++ /dev/null @@ -1,26 +0,0 @@ -[ - [ - [ - 0.25, - 0.25 - ] - ], - [ - [ - 0.75, - 0.25 - ] - ], - [ - [ - 0.75, - 0.75 - ] - ], - [ - [ - 0.25, - 0.75 - ] - ] -] \ No newline at end of file From efd78437d5216e624cb4fcef27021f03a2b5f144 Mon Sep 17 00:00:00 2001 From: Chahak Mehta Date: Sat, 15 Jul 2023 14:54:19 -0700 Subject: [PATCH 2/4] Restructure code and add materials subdir To make the imports easier from the materials subdir, also restructured other files. This moves `MPM` to a separate file so as to remove circular imports for materials module. --- .../uniaxial_nodal_forces/test_benchmark.py | 2 +- .../test_benchmark.py | 2 +- .../2d/uniaxial_stress/test_benchmark.py | 2 +- diffmpm/__init__.py | 44 +------------ diffmpm/cli/mpm.py | 2 +- diffmpm/io.py | 3 +- diffmpm/materials/__init__.py | 3 + diffmpm/materials/_base.py | 48 ++++++++++++++ .../linear_elastic.py} | 66 +------------------ diffmpm/materials/newtonian.py | 1 + diffmpm/materials/simple.py | 18 +++++ diffmpm/mpm.py | 42 ++++++++++++ diffmpm/particle.py | 6 +- tests/test_element.py | 2 +- tests/test_material.py | 3 +- tests/test_particle.py | 2 +- 16 files changed, 126 insertions(+), 120 deletions(-) create mode 100644 diffmpm/materials/__init__.py create mode 100644 diffmpm/materials/_base.py rename diffmpm/{material.py => materials/linear_elastic.py} (55%) create mode 100644 diffmpm/materials/newtonian.py create mode 100644 diffmpm/materials/simple.py create mode 100644 diffmpm/mpm.py diff --git a/benchmarks/2d/uniaxial_nodal_forces/test_benchmark.py b/benchmarks/2d/uniaxial_nodal_forces/test_benchmark.py index ae72923..4dcb077 100644 --- a/benchmarks/2d/uniaxial_nodal_forces/test_benchmark.py +++ b/benchmarks/2d/uniaxial_nodal_forces/test_benchmark.py @@ -3,7 +3,7 @@ import jax.numpy as jnp -from diffmpm import MPM +from diffmpm.mpm import MPM def test_benchmarks(): diff --git a/benchmarks/2d/uniaxial_particle_traction/test_benchmark.py b/benchmarks/2d/uniaxial_particle_traction/test_benchmark.py index 356d0a3..995ca16 100644 --- a/benchmarks/2d/uniaxial_particle_traction/test_benchmark.py +++ b/benchmarks/2d/uniaxial_particle_traction/test_benchmark.py @@ -3,7 +3,7 @@ import jax.numpy as jnp -from diffmpm import MPM +from diffmpm.mpm import MPM def test_benchmarks(): diff --git a/benchmarks/2d/uniaxial_stress/test_benchmark.py b/benchmarks/2d/uniaxial_stress/test_benchmark.py index f04e820..0dd6af8 100644 --- a/benchmarks/2d/uniaxial_stress/test_benchmark.py +++ b/benchmarks/2d/uniaxial_stress/test_benchmark.py @@ -3,7 +3,7 @@ import jax.numpy as jnp -from diffmpm import MPM +from diffmpm.mpm import MPM def test_benchmarks(): diff --git a/diffmpm/__init__.py b/diffmpm/__init__.py index faa8316..a138300 100644 --- a/diffmpm/__init__.py +++ b/diffmpm/__init__.py @@ -1,47 +1,5 @@ from importlib.metadata import version -from pathlib import Path -import diffmpm.writers as writers -from diffmpm.io import Config -from diffmpm.solver import MPMExplicit - -__all__ = ["MPM", "__version__"] +__all__ = ["__version__"] __version__ = version("diffmpm") - - -class MPM: - def __init__(self, filepath): - self._config = Config(filepath) - mesh = self._config.parse() - out_dir = Path(self._config.parsed_config["output"]["folder"]).joinpath( - self._config.parsed_config["meta"]["title"], - ) - - write_format = self._config.parsed_config["output"].get("format", None) - if write_format is None or write_format.lower() == "none": - writer_func = None - elif write_format == "npz": - writer_func = writers.NPZWriter().write - else: - raise ValueError(f"Specified output format not supported: {write_format}") - - if self._config.parsed_config["meta"]["type"] == "MPMExplicit": - self.solver = MPMExplicit( - mesh, - self._config.parsed_config["meta"]["dt"], - velocity_update=self._config.parsed_config["meta"]["velocity_update"], - sim_steps=self._config.parsed_config["meta"]["nsteps"], - out_steps=self._config.parsed_config["output"]["step_frequency"], - out_dir=out_dir, - writer_func=writer_func, - ) - else: - raise ValueError("Wrong type of solver specified.") - - def solve(self): - """Solve the MPM simulation using JIT solver.""" - arrays = self.solver.solve_jit( - self._config.parsed_config["external_loading"]["gravity"], - ) - return arrays diff --git a/diffmpm/cli/mpm.py b/diffmpm/cli/mpm.py index aebc4ba..0b4b9d7 100644 --- a/diffmpm/cli/mpm.py +++ b/diffmpm/cli/mpm.py @@ -1,6 +1,6 @@ import click -from diffmpm import MPM +from diffmpm.mpm import MPM @click.command() # type: ignore diff --git a/diffmpm/io.py b/diffmpm/io.py index d6e4573..66447f5 100644 --- a/diffmpm/io.py +++ b/diffmpm/io.py @@ -1,11 +1,10 @@ import json import tomllib as tl -from collections import namedtuple import jax.numpy as jnp from diffmpm import element as mpel -from diffmpm import material as mpmat +from diffmpm import materials as mpmat from diffmpm import mesh as mpmesh from diffmpm.constraint import Constraint from diffmpm.forces import NodalForce, ParticleTraction diff --git a/diffmpm/materials/__init__.py b/diffmpm/materials/__init__.py new file mode 100644 index 0000000..ce35083 --- /dev/null +++ b/diffmpm/materials/__init__.py @@ -0,0 +1,3 @@ +from diffmpm.materials._base import _Material +from diffmpm.materials.simple import SimpleMaterial +from diffmpm.materials.linear_elastic import LinearElastic diff --git a/diffmpm/materials/_base.py b/diffmpm/materials/_base.py new file mode 100644 index 0000000..d30b15b --- /dev/null +++ b/diffmpm/materials/_base.py @@ -0,0 +1,48 @@ +import abc +from typing import Tuple + + +class _Material(abc.ABC): + """Base material class.""" + + _props: Tuple[str, ...] + + def __init__(self, material_properties): + """Initialize material properties. + + Parameters + ---------- + material_properties: dict + A key-value map for various material properties. + """ + self.properties = material_properties + + # @abc.abstractmethod + def tree_flatten(self): + """Flatten this class as PyTree Node.""" + return (tuple(), self.properties) + + # @abc.abstractmethod + @classmethod + def tree_unflatten(cls, aux_data, children): + """Unflatten this class as PyTree Node.""" + del children + return cls(aux_data) + + @abc.abstractmethod + def __repr__(self): + """Repr for Material class.""" + ... + + @abc.abstractmethod + def compute_stress(self): + """Compute stress for the material.""" + ... + + def validate_props(self, material_properties): + for key in self._props: + if key not in material_properties: + raise KeyError( + f"'{key}' should be present in `material_properties` " + f"for {self.__class__.__name__} materials." + ) diff --git a/diffmpm/material.py b/diffmpm/materials/linear_elastic.py similarity index 55% rename from diffmpm/material.py rename to diffmpm/materials/linear_elastic.py index 09230d4..098c10e 100644 --- a/diffmpm/material.py +++ b/diffmpm/materials/linear_elastic.py @@ -1,58 +1,11 @@ -import abc -from typing import Tuple - import jax.numpy as jnp from jax.tree_util import register_pytree_node_class - -class Material(abc.ABC): - """Base material class.""" - - _props: Tuple[str, ...] - - def __init__(self, material_properties): - """Initialize material properties. - - Parameters - ---------- - material_properties: dict - A key-value map for various material properties. - """ - self.properties = material_properties - - # @abc.abstractmethod - def tree_flatten(self): - """Flatten this class as PyTree Node.""" - return (tuple(), self.properties) - - # @abc.abstractmethod - @classmethod - def tree_unflatten(cls, aux_data, children): - """Unflatten this class as PyTree Node.""" - del children - return cls(aux_data) - - @abc.abstractmethod - def __repr__(self): - """Repr for Material class.""" - ... - - @abc.abstractmethod - def compute_stress(self): - """Compute stress for the material.""" - ... - - def validate_props(self, material_properties): - for key in self._props: - if key not in material_properties: - raise KeyError( - f"'{key}' should be present in `material_properties` " - f"for {self.__class__.__name__} materials." - ) +from ._base import _Material @register_pytree_node_class -class LinearElastic(Material): +class LinearElastic(_Material): """Linear Elastic Material.""" _props = ("density", "youngs_modulus", "poisson_ratio") @@ -114,18 +67,3 @@ def compute_stress(self, dstrain): """Compute material stress.""" dstress = self.de @ dstrain return dstress - - -@register_pytree_node_class -class SimpleMaterial(Material): - _props = ("E", "density") - - def __init__(self, material_properties): - self.validate_props(material_properties) - self.properties = material_properties - - def __repr__(self): - return f"SimpleMaterial(props={self.properties})" - - def compute_stress(self, dstrain): - return dstrain * self.properties["E"] diff --git a/diffmpm/materials/newtonian.py b/diffmpm/materials/newtonian.py new file mode 100644 index 0000000..e5a0d9b --- /dev/null +++ b/diffmpm/materials/newtonian.py @@ -0,0 +1 @@ +#!/usr/bin/env python3 diff --git a/diffmpm/materials/simple.py b/diffmpm/materials/simple.py new file mode 100644 index 0000000..d9cf15d --- /dev/null +++ b/diffmpm/materials/simple.py @@ -0,0 +1,18 @@ +from jax.tree_util import register_pytree_node_class + +from ._base import _Material + + +@register_pytree_node_class +class SimpleMaterial(_Material): + _props = ("E", "density") + + def __init__(self, material_properties): + self.validate_props(material_properties) + self.properties = material_properties + + def __repr__(self): + return f"SimpleMaterial(props={self.properties})" + + def compute_stress(self, dstrain): + return dstrain * self.properties["E"] diff --git a/diffmpm/mpm.py b/diffmpm/mpm.py new file mode 100644 index 0000000..b06eaee --- /dev/null +++ b/diffmpm/mpm.py @@ -0,0 +1,42 @@ +from pathlib import Path + +import diffmpm.writers as writers +from diffmpm.io import Config +from diffmpm.solver import MPMExplicit + + +class MPM: + def __init__(self, filepath): + self._config = Config(filepath) + mesh = self._config.parse() + out_dir = Path(self._config.parsed_config["output"]["folder"]).joinpath( + self._config.parsed_config["meta"]["title"], + ) + + write_format = self._config.parsed_config["output"].get("format", None) + if write_format is None or write_format.lower() == "none": + writer_func = None + elif write_format == "npz": + writer_func = writers.NPZWriter().write + else: + raise ValueError(f"Specified output format not supported: {write_format}") + + if self._config.parsed_config["meta"]["type"] == "MPMExplicit": + self.solver = MPMExplicit( + mesh, + self._config.parsed_config["meta"]["dt"], + velocity_update=self._config.parsed_config["meta"]["velocity_update"], + sim_steps=self._config.parsed_config["meta"]["nsteps"], + out_steps=self._config.parsed_config["output"]["step_frequency"], + out_dir=out_dir, + writer_func=writer_func, + ) + else: + raise ValueError("Wrong type of solver specified.") + + def solve(self): + """Solve the MPM simulation using JIT solver.""" + arrays = self.solver.solve_jit( + self._config.parsed_config["external_loading"]["gravity"], + ) + return arrays diff --git a/diffmpm/particle.py b/diffmpm/particle.py index 1bb3d70..586fec0 100644 --- a/diffmpm/particle.py +++ b/diffmpm/particle.py @@ -6,7 +6,7 @@ from jax.typing import ArrayLike from diffmpm.element import _Element -from diffmpm.material import Material +from diffmpm.materials import _Material @register_pytree_node_class @@ -16,7 +16,7 @@ class Particles(Sized): def __init__( self, loc: ArrayLike, - material: Material, + material: _Material, element_ids: ArrayLike, initialized: Optional[bool] = None, data: Optional[Tuple[ArrayLike, ...]] = None, @@ -27,7 +27,7 @@ def __init__( ---------- loc: ArrayLike Location of the particles. Expected shape (nparticles, 1, ndim) - material: diffmpm.material.Material + material: diffmpm.materials._Material Type of material for the set of particles. element_ids: ArrayLike The element ids that the particles belong to. This contains diff --git a/tests/test_element.py b/tests/test_element.py index 50881d9..ff8d92e 100644 --- a/tests/test_element.py +++ b/tests/test_element.py @@ -5,7 +5,7 @@ from diffmpm.element import Quadrilateral4Node from diffmpm.forces import NodalForce from diffmpm.functions import Unit -from diffmpm.material import SimpleMaterial +from diffmpm.materials import SimpleMaterial from diffmpm.particle import Particles diff --git a/tests/test_material.py b/tests/test_material.py index 2e041d7..66cb4dc 100644 --- a/tests/test_material.py +++ b/tests/test_material.py @@ -1,7 +1,6 @@ import jax.numpy as jnp import pytest - -from diffmpm.material import LinearElastic, SimpleMaterial +from diffmpm.materials import LinearElastic, SimpleMaterial material_dstrain_stress_targets = [ ( diff --git a/tests/test_particle.py b/tests/test_particle.py index d7dedaa..d67bc2f 100644 --- a/tests/test_particle.py +++ b/tests/test_particle.py @@ -2,7 +2,7 @@ import pytest from diffmpm.element import Quadrilateral4Node -from diffmpm.material import SimpleMaterial +from diffmpm.materials import SimpleMaterial from diffmpm.particle import Particles From 1ec115649905d10d328b97c0633a4d0fa2dda3c2 Mon Sep 17 00:00:00 2001 From: Chahak Mehta Date: Sat, 15 Jul 2023 16:58:52 -0700 Subject: [PATCH 3/4] Pass `particles` as arg to compute stress --- diffmpm/materials/_base.py | 3 ++- diffmpm/materials/linear_elastic.py | 5 ++-- diffmpm/materials/simple.py | 5 ++-- diffmpm/particle.py | 9 ++++++- tests/test_material.py | 38 +++++++++++++++++++++++------ 5 files changed, 46 insertions(+), 14 deletions(-) diff --git a/diffmpm/materials/_base.py b/diffmpm/materials/_base.py index d30b15b..896b206 100644 --- a/diffmpm/materials/_base.py +++ b/diffmpm/materials/_base.py @@ -6,6 +6,7 @@ class _Material(abc.ABC): """Base material class.""" _props: Tuple[str, ...] + properties: dict def __init__(self, material_properties): """Initialize material properties. @@ -35,7 +36,7 @@ def __repr__(self): ... @abc.abstractmethod - def compute_stress(self): + def compute_stress(self, particles): """Compute stress for the material.""" ... diff --git a/diffmpm/materials/linear_elastic.py b/diffmpm/materials/linear_elastic.py index 098c10e..5a008d4 100644 --- a/diffmpm/materials/linear_elastic.py +++ b/diffmpm/materials/linear_elastic.py @@ -9,6 +9,7 @@ class LinearElastic(_Material): """Linear Elastic Material.""" _props = ("density", "youngs_modulus", "poisson_ratio") + state_vars = () def __init__(self, material_properties): """Create a Linear Elastic material. @@ -63,7 +64,7 @@ def _compute_elastic_tensor(self): ] ) - def compute_stress(self, dstrain): + def compute_stress(self, particles): """Compute material stress.""" - dstress = self.de @ dstrain + dstress = self.de @ particles.dstrain return dstress diff --git a/diffmpm/materials/simple.py b/diffmpm/materials/simple.py index d9cf15d..77b57ca 100644 --- a/diffmpm/materials/simple.py +++ b/diffmpm/materials/simple.py @@ -6,6 +6,7 @@ @register_pytree_node_class class SimpleMaterial(_Material): _props = ("E", "density") + state_vars = () def __init__(self, material_properties): self.validate_props(material_properties) @@ -14,5 +15,5 @@ def __init__(self, material_properties): def __repr__(self): return f"SimpleMaterial(props={self.properties})" - def compute_stress(self, dstrain): - return dstrain * self.properties["E"] + def compute_stress(self, particles): + return particles.dstrain * self.properties["E"] diff --git a/diffmpm/particle.py b/diffmpm/particle.py index 586fec0..04f2581 100644 --- a/diffmpm/particle.py +++ b/diffmpm/particle.py @@ -69,6 +69,11 @@ def __init__( self.reference_loc = jnp.zeros_like(self.loc) self.dvolumetric_strain = jnp.zeros((self.loc.shape[0], 1)) self.volumetric_strain_centroid = jnp.zeros((self.loc.shape[0], 1)) + self.state_vars = {} + if self.material.state_vars: + self.state_vars = self.material.initialize_state_variables( + self.loc.shape[0] + ) else: ( self.mass, @@ -87,6 +92,7 @@ def __init__( self.reference_loc, self.dvolumetric_strain, self.volumetric_strain_centroid, + self.state_vars, ) = data # type: ignore self.initialized = True @@ -112,6 +118,7 @@ def tree_flatten(self): self.reference_loc, self.dvolumetric_strain, self.volumetric_strain_centroid, + self.state_vars, ) aux_data = (self.material,) return (children, aux_data) @@ -319,7 +326,7 @@ def compute_stress(self, *args): particles. The stress calculated by the material is then added to the particles current stress values. """ - self.stress = self.stress.at[:].add(self.material.compute_stress(self.dstrain)) + self.stress = self.stress.at[:].add(self.material.compute_stress(self)) def update_volume(self, *args): """Update volume based on central strain rate.""" diff --git a/tests/test_material.py b/tests/test_material.py index 66cb4dc..f81cfb0 100644 --- a/tests/test_material.py +++ b/tests/test_material.py @@ -1,27 +1,48 @@ import jax.numpy as jnp import pytest from diffmpm.materials import LinearElastic, SimpleMaterial +from diffmpm.particle import Particles -material_dstrain_stress_targets = [ +particles_dstrain_stress_targets = [ ( - SimpleMaterial({"E": 10, "density": 1}), + Particles( + jnp.array([[0.5, 0.5]]).reshape(1, 1, 2), + SimpleMaterial({"E": 10, "density": 1}), + jnp.array([0]), + ), jnp.ones((1, 6, 1)), jnp.ones((1, 6, 1)) * 10, ), ( - LinearElastic({"density": 1, "youngs_modulus": 10, "poisson_ratio": 1}), + Particles( + jnp.array([[0.5, 0.5]]).reshape(1, 1, 2), + LinearElastic({"density": 1, "youngs_modulus": 10, "poisson_ratio": 1}), + jnp.array([0]), + ), jnp.ones((1, 6, 1)), jnp.array([-10, -10, -10, 2.5, 2.5, 2.5]).reshape(1, 6, 1), ), ( - LinearElastic({"density": 1000, "youngs_modulus": 1e7, "poisson_ratio": 0.3}), + Particles( + jnp.array([[0.5, 0.5]]).reshape(1, 1, 2), + LinearElastic( + {"density": 1000, "youngs_modulus": 1e7, "poisson_ratio": 0.3} + ), + jnp.array([0]), + ), jnp.array([0.001, 0.0005, 0, 0, 0, 0]).reshape(1, 6, 1), jnp.array([1.63461538461538e4, 12500, 0.86538461538462e4, 0, 0, 0]).reshape( 1, 6, 1 ), ), ( - LinearElastic({"density": 1000, "youngs_modulus": 1e7, "poisson_ratio": 0.3}), + Particles( + jnp.array([[0.5, 0.5]]).reshape(1, 1, 2), + LinearElastic( + {"density": 1000, "youngs_modulus": 1e7, "poisson_ratio": 0.3} + ), + jnp.array([0]), + ), jnp.array([0.001, 0.0005, 0, 0.00001, 0, 0]).reshape(1, 6, 1), jnp.array( [1.63461538461538e4, 12500, 0.86538461538462e4, 3.84615384615385e01, 0, 0] @@ -30,7 +51,8 @@ ] -@pytest.mark.parametrize("material, dstrain, target", material_dstrain_stress_targets) -def test_compute_stress(material, dstrain, target): - stress = material.compute_stress(dstrain) +@pytest.mark.parametrize("particles, dstrain, target", particles_dstrain_stress_targets) +def test_compute_stress(particles, dstrain, target): + particles.dstrain = dstrain + stress = particles.material.compute_stress(particles) assert jnp.allclose(stress, target) From ef2655a07ee73065177b4e400ecfef3d32caf0be Mon Sep 17 00:00:00 2001 From: Chahak Mehta Date: Sat, 15 Jul 2023 16:59:12 -0700 Subject: [PATCH 4/4] Add Newtonian material --- diffmpm/materials/__init__.py | 1 + diffmpm/materials/newtonian.py | 115 ++++++++++++++++++++++++++++++++- tests/test_newtonian.py | 90 ++++++++++++++++++++++++++ 3 files changed, 205 insertions(+), 1 deletion(-) create mode 100644 tests/test_newtonian.py diff --git a/diffmpm/materials/__init__.py b/diffmpm/materials/__init__.py index ce35083..028e715 100644 --- a/diffmpm/materials/__init__.py +++ b/diffmpm/materials/__init__.py @@ -1,3 +1,4 @@ from diffmpm.materials._base import _Material from diffmpm.materials.simple import SimpleMaterial from diffmpm.materials.linear_elastic import LinearElastic +from diffmpm.materials.newtonian import Newtonian diff --git a/diffmpm/materials/newtonian.py b/diffmpm/materials/newtonian.py index e5a0d9b..558832f 100644 --- a/diffmpm/materials/newtonian.py +++ b/diffmpm/materials/newtonian.py @@ -1 +1,114 @@ -#!/usr/bin/env python3 +import jax.numpy as jnp +from jax import Array, lax +from jax.typing import ArrayLike + +from ._base import _Material + + +class Newtonian(_Material): + """Newtonian fluid material model.""" + + _props = ("density", "bulk_modulus", "dynamic_viscosity") + state_vars = ("pressure",) + + def __init__(self, material_properties: dict): + """Create a Newtonian material. + + Parameters + ---------- + material_properties: dict + Dictionary with material properties. For newtonian + materials, `density`, `bulk_modulus` and `dynamic_viscosity` + are required keys. + """ + self.validate_props(material_properties) + compressibility = 1 + + if material_properties.get("incompressible", False): + compressibility = 0 + + self.properties = { + **material_properties, + "compressibility": compressibility, + } + + def __repr__(self): + return f"Newtonian(props={self.properties})" + + def initialize_state_variables(self, nparticles: int) -> dict: + """Return initial state variables dictionary. + + Parameters + ---------- + nparticles : int + Number of particles being simulated with this material. + + Returns + ------- + dict + Dictionary of state variables initialized with values + decided by material type. + """ + state_vars_dict = {var: jnp.zeros((nparticles, 1)) for var in self.state_vars} + return state_vars_dict + + def _thermodynamic_pressure(self, volumetric_strain: ArrayLike) -> Array: + return -self.properties["bulk_modulus"] * volumetric_strain + + def compute_stress(self, particles): + """Compute material stress.""" + ndim = particles.loc.shape[-1] + if ndim not in {2, 3}: + raise ValueError(f"Cannot compute stress for {ndim}-d Newotonian material.") + volumetric_strain_rate = ( + particles.strain_rate[:, 0] + particles.strain_rate[:, 1] + ) + particles.state_vars["pressure"] = ( + particles.state_vars["pressure"] + .at[:] + .add( + self.properties["compressibility"] + * self._thermodynamic_pressure(particles.dvolumetric_strain) + ) + ) + + volumetric_stress_component = self.properties["compressibility"] * ( + -particles.state_vars["pressure"] + - (2 * self.properties["dynamic_viscosity"] * volumetric_strain_rate / 3) + ) + + stress = jnp.zeros_like(particles.stress) + stress = stress.at[:, 0].set( + volumetric_stress_component + + 2 * self.properties["dynamic_viscosity"] * particles.strain_rate[:, 0] + ) + stress = stress.at[:, 1].set( + volumetric_stress_component + + 2 * self.properties["dynamic_viscosity"] * particles.strain_rate[:, 1] + ) + + extra_component_2 = lax.select( + ndim == 3, + 2 * self.properties["dynamic_viscosity"] * particles.strain_rate[:, 2], + jnp.zeros_like(particles.strain_rate[:, 2]), + ) + stress = stress.at[:, 2].set(volumetric_stress_component + extra_component_2) + + stress = stress.at[:, 3].set( + self.properties["dynamic_viscosity"] * particles.strain_rate[:, 3] + ) + + component_4 = lax.select( + ndim == 3, + self.properties["dynamic_viscosity"] * particles.strain_rate[:, 4], + jnp.zeros_like(particles.strain_rate[:, 4]), + ) + stress = stress.at[:, 4].set(component_4) + component_5 = lax.select( + ndim == 3, + self.properties["dynamic_viscosity"] * particles.strain_rate[:, 5], + jnp.zeros_like(particles.strain_rate[:, 5]), + ) + stress = stress.at[:, 5].set(component_5) + + return stress diff --git a/tests/test_newtonian.py b/tests/test_newtonian.py new file mode 100644 index 0000000..518a246 --- /dev/null +++ b/tests/test_newtonian.py @@ -0,0 +1,90 @@ +import jax.numpy as jnp +import pytest +from diffmpm.constraint import Constraint +from diffmpm.element import Quadrilateral4Node +from diffmpm.materials import Newtonian +from diffmpm.node import Nodes +from diffmpm.particle import Particles + +particles_element_targets = [ + ( + Particles( + jnp.array([[0.5, 0.5]]).reshape(1, 1, 2), + Newtonian( + { + "density": 1000, + "bulk_modulus": 8333333.333333333, + "dynamic_viscosity": 8.9e-4, + } + ), + jnp.array([0]), + ), + Quadrilateral4Node( + (1, 1), + 1, + (4.0, 4.0), + [(0, Constraint(0, 0.02)), (0, Constraint(1, 0.03))], + Nodes(4, jnp.array([-2, -2, 2, -2, -2, 2, 2, 2]).reshape((4, 1, 2))), + ), + jnp.array( + [ + -52083.3333338896, + -52083.3333355583, + -52083.3333305521, + -0.0000041719, + 0, + 0, + ] + ).reshape(1, 6, 1), + ), + ( + Particles( + jnp.array([[0.5, 0.5]]).reshape(1, 1, 2), + Newtonian( + { + "density": 1000, + "bulk_modulus": 8333333.333333333, + "dynamic_viscosity": 8.9e-4, + "incompressible": True, + } + ), + jnp.array([0]), + ), + Quadrilateral4Node( + (1, 1), + 1, + (4.0, 4.0), + [(0, Constraint(0, 0.02)), (0, Constraint(1, 0.03))], + Nodes(4, jnp.array([-2, -2, 2, -2, -2, 2, 2, 2]).reshape((4, 1, 2))), + ), + jnp.array( + [ + -0.0000033375, + -0.00000500625, + 0, + -0.0000041719, + 0, + 0, + ] + ).reshape(1, 6, 1), + ), +] + + +@pytest.mark.parametrize( + "particles, element, target", + particles_element_targets, +) +def test_compute_stress(particles, element, target): + dt = 1 + particles.update_natural_coords(element) + if element.constraints: + element.apply_boundary_constraints() + particles.compute_strain(element, dt) + stress = particles.material.compute_stress(particles) + assert jnp.allclose(stress, target) + + +def test_init(): + with pytest.raises(KeyError): + Newtonian({"dynamic_viscosity": 1, "density": 1})