Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Newtonian model #25

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmarks/2d/uniaxial_nodal_forces/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import jax.numpy as jnp

from diffmpm import MPM
from diffmpm.mpm import MPM


def test_benchmarks():
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/2d/uniaxial_particle_traction/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import jax.numpy as jnp

from diffmpm import MPM
from diffmpm.mpm import MPM


def test_benchmarks():
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/2d/uniaxial_stress/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import jax.numpy as jnp

from diffmpm import MPM
from diffmpm.mpm import MPM


def test_benchmarks():
Expand Down
44 changes: 1 addition & 43 deletions diffmpm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,5 @@
from importlib.metadata import version
from pathlib import Path

import diffmpm.writers as writers
from diffmpm.io import Config
from diffmpm.solver import MPMExplicit

__all__ = ["MPM", "__version__"]
__all__ = ["__version__"]

__version__ = version("diffmpm")


class MPM:
def __init__(self, filepath):
self._config = Config(filepath)
mesh = self._config.parse()
out_dir = Path(self._config.parsed_config["output"]["folder"]).joinpath(
self._config.parsed_config["meta"]["title"],
)

write_format = self._config.parsed_config["output"].get("format", None)
if write_format is None or write_format.lower() == "none":
writer_func = None
elif write_format == "npz":
writer_func = writers.NPZWriter().write
else:
raise ValueError(f"Specified output format not supported: {write_format}")

if self._config.parsed_config["meta"]["type"] == "MPMExplicit":
self.solver = MPMExplicit(
mesh,
self._config.parsed_config["meta"]["dt"],
velocity_update=self._config.parsed_config["meta"]["velocity_update"],
sim_steps=self._config.parsed_config["meta"]["nsteps"],
out_steps=self._config.parsed_config["output"]["step_frequency"],
out_dir=out_dir,
writer_func=writer_func,
)
else:
raise ValueError("Wrong type of solver specified.")

def solve(self):
"""Solve the MPM simulation using JIT solver."""
arrays = self.solver.solve_jit(
self._config.parsed_config["external_loading"]["gravity"],
)
return arrays
2 changes: 1 addition & 1 deletion diffmpm/cli/mpm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import click

from diffmpm import MPM
from diffmpm.mpm import MPM


@click.command() # type: ignore
Expand Down
3 changes: 1 addition & 2 deletions diffmpm/io.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import json
import tomllib as tl
from collections import namedtuple

import jax.numpy as jnp

from diffmpm import element as mpel
from diffmpm import material as mpmat
from diffmpm import materials as mpmat
from diffmpm import mesh as mpmesh
from diffmpm.constraint import Constraint
from diffmpm.forces import NodalForce, ParticleTraction
Expand Down
4 changes: 4 additions & 0 deletions diffmpm/materials/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from diffmpm.materials._base import _Material
from diffmpm.materials.simple import SimpleMaterial
from diffmpm.materials.linear_elastic import LinearElastic
from diffmpm.materials.newtonian import Newtonian
49 changes: 49 additions & 0 deletions diffmpm/materials/_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import abc
from typing import Tuple


class _Material(abc.ABC):
"""Base material class."""

_props: Tuple[str, ...]
properties: dict

def __init__(self, material_properties):
"""Initialize material properties.

Parameters
----------
material_properties: dict
A key-value map for various material properties.
"""
self.properties = material_properties

# @abc.abstractmethod
def tree_flatten(self):
"""Flatten this class as PyTree Node."""
return (tuple(), self.properties)

# @abc.abstractmethod
@classmethod
def tree_unflatten(cls, aux_data, children):
"""Unflatten this class as PyTree Node."""
del children
return cls(aux_data)

@abc.abstractmethod
def __repr__(self):
"""Repr for Material class."""
...

@abc.abstractmethod
def compute_stress(self, particles):
"""Compute stress for the material."""
...

def validate_props(self, material_properties):
for key in self._props:
if key not in material_properties:
raise KeyError(
f"'{key}' should be present in `material_properties` "
f"for {self.__class__.__name__} materials."
)
71 changes: 5 additions & 66 deletions diffmpm/material.py → diffmpm/materials/linear_elastic.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,15 @@
import abc
from typing import Tuple

import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


class Material(abc.ABC):
"""Base material class."""

_props: Tuple[str, ...]

def __init__(self, material_properties):
"""Initialize material properties.

Parameters
----------
material_properties: dict
A key-value map for various material properties.
"""
self.properties = material_properties

# @abc.abstractmethod
def tree_flatten(self):
"""Flatten this class as PyTree Node."""
return (tuple(), self.properties)

# @abc.abstractmethod
@classmethod
def tree_unflatten(cls, aux_data, children):
"""Unflatten this class as PyTree Node."""
del children
return cls(aux_data)

@abc.abstractmethod
def __repr__(self):
"""Repr for Material class."""
...

@abc.abstractmethod
def compute_stress(self):
"""Compute stress for the material."""
...

def validate_props(self, material_properties):
for key in self._props:
if key not in material_properties:
raise KeyError(
f"'{key}' should be present in `material_properties` "
f"for {self.__class__.__name__} materials."
)
from ._base import _Material


@register_pytree_node_class
class LinearElastic(Material):
class LinearElastic(_Material):
"""Linear Elastic Material."""

_props = ("density", "youngs_modulus", "poisson_ratio")
state_vars = ()

def __init__(self, material_properties):
"""Create a Linear Elastic material.
Expand Down Expand Up @@ -110,22 +64,7 @@ def _compute_elastic_tensor(self):
]
)

def compute_stress(self, dstrain):
def compute_stress(self, particles):
"""Compute material stress."""
dstress = self.de @ dstrain
dstress = self.de @ particles.dstrain
return dstress


@register_pytree_node_class
class SimpleMaterial(Material):
_props = ("E", "density")

def __init__(self, material_properties):
self.validate_props(material_properties)
self.properties = material_properties

def __repr__(self):
return f"SimpleMaterial(props={self.properties})"

def compute_stress(self, dstrain):
return dstrain * self.properties["E"]
114 changes: 114 additions & 0 deletions diffmpm/materials/newtonian.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import jax.numpy as jnp
from jax import Array, lax
from jax.typing import ArrayLike

from ._base import _Material


class Newtonian(_Material):
"""Newtonian fluid material model."""

_props = ("density", "bulk_modulus", "dynamic_viscosity")
state_vars = ("pressure",)

def __init__(self, material_properties: dict):
"""Create a Newtonian material.

Parameters
----------
material_properties: dict
Dictionary with material properties. For newtonian
materials, `density`, `bulk_modulus` and `dynamic_viscosity`
are required keys.
"""
self.validate_props(material_properties)
compressibility = 1

if material_properties.get("incompressible", False):
compressibility = 0

self.properties = {
**material_properties,
"compressibility": compressibility,
}

def __repr__(self):
return f"Newtonian(props={self.properties})"

def initialize_state_variables(self, nparticles: int) -> dict:
"""Return initial state variables dictionary.

Parameters
----------
nparticles : int
Number of particles being simulated with this material.

Returns
-------
dict
Dictionary of state variables initialized with values
decided by material type.
"""
state_vars_dict = {var: jnp.zeros((nparticles, 1)) for var in self.state_vars}
return state_vars_dict

def _thermodynamic_pressure(self, volumetric_strain: ArrayLike) -> Array:
return -self.properties["bulk_modulus"] * volumetric_strain

def compute_stress(self, particles):
"""Compute material stress."""
ndim = particles.loc.shape[-1]
if ndim not in {2, 3}:
raise ValueError(f"Cannot compute stress for {ndim}-d Newotonian material.")
volumetric_strain_rate = (
particles.strain_rate[:, 0] + particles.strain_rate[:, 1]
)
particles.state_vars["pressure"] = (
particles.state_vars["pressure"]
.at[:]
.add(
self.properties["compressibility"]
* self._thermodynamic_pressure(particles.dvolumetric_strain)
)
)

volumetric_stress_component = self.properties["compressibility"] * (
-particles.state_vars["pressure"]
- (2 * self.properties["dynamic_viscosity"] * volumetric_strain_rate / 3)
)

stress = jnp.zeros_like(particles.stress)
stress = stress.at[:, 0].set(
volumetric_stress_component
+ 2 * self.properties["dynamic_viscosity"] * particles.strain_rate[:, 0]
)
stress = stress.at[:, 1].set(
volumetric_stress_component
+ 2 * self.properties["dynamic_viscosity"] * particles.strain_rate[:, 1]
)

extra_component_2 = lax.select(
ndim == 3,
2 * self.properties["dynamic_viscosity"] * particles.strain_rate[:, 2],
jnp.zeros_like(particles.strain_rate[:, 2]),
)
stress = stress.at[:, 2].set(volumetric_stress_component + extra_component_2)

stress = stress.at[:, 3].set(
self.properties["dynamic_viscosity"] * particles.strain_rate[:, 3]
)

component_4 = lax.select(
ndim == 3,
self.properties["dynamic_viscosity"] * particles.strain_rate[:, 4],
jnp.zeros_like(particles.strain_rate[:, 4]),
)
stress = stress.at[:, 4].set(component_4)
component_5 = lax.select(
ndim == 3,
self.properties["dynamic_viscosity"] * particles.strain_rate[:, 5],
jnp.zeros_like(particles.strain_rate[:, 5]),
)
stress = stress.at[:, 5].set(component_5)

return stress
19 changes: 19 additions & 0 deletions diffmpm/materials/simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from jax.tree_util import register_pytree_node_class

from ._base import _Material


@register_pytree_node_class
class SimpleMaterial(_Material):
_props = ("E", "density")
state_vars = ()

def __init__(self, material_properties):
self.validate_props(material_properties)
self.properties = material_properties

def __repr__(self):
return f"SimpleMaterial(props={self.properties})"

def compute_stress(self, particles):
return particles.dstrain * self.properties["E"]
Loading