forked from sandialabs/optimism
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request sandialabs#69 from sandialabs/ralberd/path_depende…
…nt_adjoint ralberd/path dependent adjoint
- Loading branch information
Showing
7 changed files
with
1,000 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from optimism import FunctionSpace | ||
from optimism import Interpolants | ||
from optimism import Mesh | ||
from optimism.FunctionSpace import compute_element_volumes | ||
from optimism.FunctionSpace import compute_element_volumes_axisymmetric | ||
from optimism.FunctionSpace import map_element_shape_grads | ||
from jax import vmap | ||
|
||
def construct_function_space_for_adjoint(coords, mesh, quadratureRule, mode2D='cartesian'): | ||
|
||
shapeOnRef = Interpolants.compute_shapes(mesh.parentElement, quadratureRule.xigauss) | ||
|
||
shapes = vmap(lambda elConns, elShape: elShape, (0, None))(mesh.conns, shapeOnRef.values) | ||
|
||
shapeGrads = vmap(map_element_shape_grads, (None, 0, None, None))(coords, mesh.conns, mesh.parentElement, shapeOnRef.gradients) | ||
|
||
if mode2D == 'cartesian': | ||
el_vols = compute_element_volumes | ||
isAxisymmetric = False | ||
elif mode2D == 'axisymmetric': | ||
el_vols = compute_element_volumes_axisymmetric | ||
isAxisymmetric = True | ||
vols = vmap(el_vols, (None, 0, None, 0, None))(coords, mesh.conns, mesh.parentElement, shapes, quadratureRule.wgauss) | ||
|
||
# unpack mesh and remake a mesh to make sure we get all the AD | ||
mesh = Mesh.Mesh(coords=coords, conns=mesh.conns, simplexNodesOrdinals=mesh.simplexNodesOrdinals, | ||
parentElement=mesh.parentElement, parentElement1d=mesh.parentElement1d, blocks=mesh.blocks, | ||
nodeSets=mesh.nodeSets, sideSets=mesh.sideSets) | ||
|
||
return FunctionSpace.FunctionSpace(shapes, vols, shapeGrads, mesh, quadratureRule, isAxisymmetric) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
from collections import namedtuple | ||
|
||
from optimism.JaxConfig import * | ||
from optimism import FunctionSpace | ||
from optimism import Interpolants | ||
from optimism.TensorMath import tensor_2D_to_3D | ||
|
||
IvsUpdateInverseFunctions = namedtuple('IvsUpdateInverseFunctions', | ||
['ivs_update_jac_ivs_prev', | ||
'ivs_update_jac_disp_vjp', | ||
'ivs_update_jac_coords_vjp']) | ||
|
||
PathDependentResidualInverseFunctions = namedtuple('PathDependentResidualInverseFunctions', | ||
['residual_jac_ivs_prev_vjp', | ||
'residual_jac_coords_vjp']) | ||
|
||
ResidualInverseFunctions = namedtuple('ResidualInverseFunctions', | ||
['residual_jac_coords_vjp']) | ||
|
||
def _compute_element_field_gradient(U, elemShapeGrads, elemConnectivity, modify_element_gradient): | ||
elemNodalDisps = U[elemConnectivity] | ||
elemGrads = vmap(FunctionSpace.compute_quadrature_point_field_gradient, (None, 0))(elemNodalDisps, elemShapeGrads) | ||
elemGrads = modify_element_gradient(elemGrads) | ||
return elemGrads | ||
|
||
def _compute_field_gradient(shapeGrads, conns, nodalField, modify_element_gradient): | ||
return vmap(_compute_element_field_gradient, (None,0,0,None))(nodalField, shapeGrads, conns, modify_element_gradient) | ||
|
||
def _compute_updated_internal_variables_gradient(dispGrads, states, dt, compute_state_new, output_shape): | ||
dgQuadPointRavel = dispGrads.reshape(dispGrads.shape[0]*dispGrads.shape[1],*dispGrads.shape[2:]) | ||
stQuadPointRavel = states.reshape(states.shape[0]*states.shape[1],*states.shape[2:]) | ||
statesNew = vmap(compute_state_new, (0, 0, None))(dgQuadPointRavel, stQuadPointRavel, dt) | ||
return statesNew.reshape(output_shape) | ||
|
||
|
||
def create_ivs_update_inverse_functions(functionSpace, mode2D, materialModel, pressureProjectionDegree=None): | ||
fs = functionSpace | ||
shapeOnRef = Interpolants.compute_shapes(fs.mesh.parentElement, fs.quadratureRule.xigauss) | ||
|
||
if mode2D == 'plane strain': | ||
grad_2D_to_3D = vmap(tensor_2D_to_3D) | ||
elif mode2D == 'axisymmetric': | ||
raise NotImplementedError | ||
|
||
modify_element_gradient = grad_2D_to_3D | ||
if pressureProjectionDegree is not None: | ||
raise NotImplementedError | ||
|
||
def compute_partial_ivs_update_partial_ivs_prev(U, stateVariables, dt=0.0): | ||
dispGrads = _compute_field_gradient(fs.shapeGrads, fs.mesh.conns, U, modify_element_gradient) | ||
update_gradient = jacfwd(materialModel.compute_state_new, argnums=1) | ||
grad_shape = stateVariables.shape + (stateVariables.shape[2],) | ||
return _compute_updated_internal_variables_gradient(dispGrads, stateVariables, dt,\ | ||
update_gradient, grad_shape) | ||
|
||
def compute_ivs_update_parameterized(U, stateVariables, coords, dt=0.0): | ||
shapeGrads = vmap(FunctionSpace.map_element_shape_grads, (None, 0, None, None))(coords, fs.mesh.conns, fs.mesh.parentElement, shapeOnRef.gradients) | ||
dispGrads = _compute_field_gradient(shapeGrads, fs.mesh.conns, U, modify_element_gradient) | ||
update_func = materialModel.compute_state_new | ||
output_shape = stateVariables.shape | ||
return _compute_updated_internal_variables_gradient(dispGrads, stateVariables, dt,\ | ||
update_func, output_shape) | ||
|
||
compute_partial_ivs_update_partial_coords = jit(lambda u, ivs, x, av, dt=0.0: | ||
vjp(lambda z: compute_ivs_update_parameterized(u, ivs, z, dt), x)[1](av)[0]) | ||
|
||
def compute_ivs_update(U, stateVariables, dt=0.0): | ||
dispGrads = _compute_field_gradient(fs.shapeGrads, fs.mesh.conns, U, modify_element_gradient) | ||
update_func = materialModel.compute_state_new | ||
output_shape = stateVariables.shape | ||
return _compute_updated_internal_variables_gradient(dispGrads, stateVariables, dt,\ | ||
update_func, output_shape) | ||
|
||
compute_partial_ivs_update_partial_disp = jit(lambda x, ivs, av, dt=0.0: | ||
vjp(lambda z: compute_ivs_update(z, ivs, dt), x)[1](av)[0]) | ||
|
||
return IvsUpdateInverseFunctions(jit(compute_partial_ivs_update_partial_ivs_prev), | ||
compute_partial_ivs_update_partial_disp, | ||
compute_partial_ivs_update_partial_coords | ||
) | ||
|
||
def create_path_dependent_residual_inverse_functions(energyFunction): | ||
|
||
compute_partial_residual_partial_ivs_prev = jit(lambda u, q, iv, x, vx: | ||
vjp(lambda z: grad(energyFunction, 0)(u, q, z, x), iv)[1](vx)[0]) | ||
|
||
compute_partial_residual_partial_coords = jit(lambda u, q, iv, x, vx: | ||
vjp(lambda z: grad(energyFunction, 0)(u, q, iv, z), x)[1](vx)[0]) | ||
|
||
return PathDependentResidualInverseFunctions(compute_partial_residual_partial_ivs_prev, | ||
compute_partial_residual_partial_coords | ||
) | ||
|
||
def create_residual_inverse_functions(energyFunction): | ||
|
||
compute_partial_residual_partial_coords = jit(lambda u, q, x, vx: | ||
vjp(lambda z: grad(energyFunction, 0)(u, q, z), x)[1](vx)[0]) | ||
|
||
return ResidualInverseFunctions(compute_partial_residual_partial_coords) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
from optimism.test.MeshFixture import MeshFixture | ||
from collections import namedtuple | ||
import numpy as onp | ||
|
||
class FiniteDifferenceFixture(MeshFixture): | ||
def assertFiniteDifferenceCheckHasVShape(self, errors, tolerance=1e-6): | ||
minError = min(errors) | ||
self.assertLess(minError, tolerance, "Smallest finite difference error not less than tolerance.") | ||
self.assertLess(minError, errors[0], "Finite difference error does not decrease from initial step size.") | ||
self.assertLess(minError, errors[-1], "Finite difference error does not increase after reaching minimum. Try more finite difference steps.") | ||
|
||
def build_direction_vector(self, numDesignVars, seed=123): | ||
|
||
onp.random.seed(seed) | ||
directionVector = onp.random.uniform(-1.0, 1.0, numDesignVars) | ||
normVector = directionVector / onp.linalg.norm(directionVector) | ||
|
||
return onp.array(normVector) | ||
|
||
def compute_finite_difference_error(self, stepSize, initialParameters): | ||
storedState = self.forward_solve(initialParameters) | ||
originalObjective = self.compute_objective_function(storedState, initialParameters) | ||
gradient = self.compute_gradient(storedState, initialParameters) | ||
|
||
directionVector = self.build_direction_vector(initialParameters.shape[0]) | ||
directionalDerivative = onp.tensordot(directionVector, gradient, axes=1) | ||
|
||
perturbedParameters = initialParameters + stepSize * directionVector | ||
storedState = self.forward_solve(perturbedParameters) | ||
perturbedObjective = self.compute_objective_function(storedState, perturbedParameters) | ||
|
||
fd_value = (perturbedObjective - originalObjective) / stepSize | ||
error = abs(directionalDerivative - fd_value) | ||
|
||
return error | ||
|
||
def compute_finite_difference_errors(self, stepSize, steps, initialParameters, printOutput=True): | ||
storedState = self.forward_solve(initialParameters) | ||
originalObjective = self.compute_objective_function(storedState, initialParameters) | ||
gradient = self.compute_gradient(storedState, initialParameters) | ||
|
||
directionVector = self.build_direction_vector(initialParameters.shape[0]) | ||
directionalDerivative = onp.tensordot(directionVector, gradient, axes=1) | ||
|
||
fd_values = [] | ||
errors = [] | ||
for i in range(0, steps): | ||
perturbedParameters = initialParameters + stepSize * directionVector | ||
storedState = self.forward_solve(perturbedParameters) | ||
perturbedObjective = self.compute_objective_function(storedState, perturbedParameters) | ||
|
||
fd_value = (perturbedObjective - originalObjective) / stepSize | ||
fd_values.append(fd_value) | ||
|
||
error = abs(directionalDerivative - fd_value) | ||
errors.append(error) | ||
|
||
stepSize *= 1e-1 | ||
|
||
if printOutput: | ||
print("\n grad'*dir | FD approx | abs error") | ||
print("--------------------------------------------------------------------------------") | ||
for i in range(0, steps): | ||
print(f" {directionalDerivative} | {fd_values[i]} | {errors[i]}") | ||
|
||
return errors | ||
|
Oops, something went wrong.