From ccacb4fbb9700638b569f06dabfeb762f3ea0fd6 Mon Sep 17 00:00:00 2001 From: Brandon Talamini Date: Sun, 19 Nov 2023 06:12:39 -0800 Subject: [PATCH 01/17] Remove gneral matrix routines from tensor math library and put in own module --- optimism/LinAlg.py | 184 ++++++++++++++++++++ optimism/TensorMath.py | 277 +++++++------------------------ optimism/test/test_LinAlg.py | 138 +++++++++++++++ optimism/test/test_TensorMath.py | 131 +-------------- 4 files changed, 386 insertions(+), 344 deletions(-) create mode 100644 optimism/LinAlg.py create mode 100644 optimism/test/test_LinAlg.py diff --git a/optimism/LinAlg.py b/optimism/LinAlg.py new file mode 100644 index 00000000..2e532f86 --- /dev/null +++ b/optimism/LinAlg.py @@ -0,0 +1,184 @@ +import jax +import jax.numpy as np + +from optimism.JaxConfig import if_then_else +from optimism.QuadratureRule import create_padded_quadrature_rule_1D + +@jax.custom_jvp +def sqrtm(A): + sqrtA,_ = sqrtm_dbp(A) + return sqrtA + + +@sqrtm.defjvp +def jvp_sqrtm(primals, tangents): + A, = primals + H, = tangents + sqrtA = sqrtm(A) + dim = A.shape[0] + # TODO(brandon): Use a stable algorithm for solving a Sylvester equation. + # See https://en.wikipedia.org/wiki/Bartels%E2%80%93Stewart_algorithm + # The following will only reliably work for small matrices. + I = np.identity(dim) + M = np.kron(sqrtA.T, I) + np.kron(I, sqrtA) + Hvec = H.T.ravel() + return sqrtA, (np.linalg.solve(M, Hvec)).reshape((dim,dim)).T + + +def sqrtm_dbp(A): + """ Matrix square root by product form of Denman-Beavers iteration. + + Translated from the Matrix Function Toolbox + http://www.ma.man.ac.uk/~higham/mftoolbox + Nicholas J. Higham, Functions of Matrices: Theory and Computation, + SIAM, Philadelphia, PA, USA, 2008. ISBN 978-0-898716-46-7, + """ + dim = A.shape[0] + tol = 0.5 * np.sqrt(dim) * np.finfo(np.dtype("float64")).eps + maxIters = 32 + scaleTol = 0.01 + + def scaling(M): + d = np.abs(np.linalg.det(M))**(1.0/(2.0*dim)) + g = 1.0 / d + return g + + def cond_f(loopData): + _,_,error,k,_ = loopData + p = np.array([k < maxIters, error > tol], dtype=bool) + return np.all(p) + + def body_f(loopData): + X, M, error, k, diff = loopData + g = np.where(diff >= scaleTol, + scaling(M), + 1.0) + + X *= g + M *= g * g + + Y = X + N = np.linalg.inv(M) + I = np.identity(dim) + X = 0.5 * X @ (I + N) + M = 0.5 * (I + 0.5 * (M + N)) + error = np.linalg.norm(M - I, 'fro') + diff = np.linalg.norm(X - Y, 'fro') / np.linalg.norm(X, 'fro') + k += 1 + return (X, M, error, k, diff) + + X0 = A + M0 = A + error0 = np.finfo(np.dtype("float64")).max + k0 = 0 + diff0 = 2.0*scaleTol # want to force scaling on first iteration + loopData0 = (X0, M0, error0, k0, diff0) + + X,_,_,k,_ = jax.lax.while_loop(cond_f, body_f, loopData0) + + return X,k + + +@jax.custom_jvp +def logm_iss(A): + X,k,m = _logm_iss(A) + return (1 << k) * log_pade_pf(X - np.identity(A.shape[0]), m) + + +@logm_iss.defjvp +def logm_jvp(primals, tangents): + A, = primals + H, = tangents + logA = logm_iss(A) + DexpLogA = jax.jacfwd(jax.scipy.linalg.expm)(logA) + dim = A.shape[0] + JVP = np.linalg.solve(DexpLogA.reshape(dim*dim,-1), H.ravel()) + return logA, JVP.reshape(dim,dim) + + +def _logm_iss(A): + """Logarithmic map by inverse scaling and squaring and Padé approximants + + Translated from the Matrix Function Toolbox + http://www.ma.man.ac.uk/~higham/mftoolbox + Nicholas J. Higham, Functions of Matrices: Theory and Computation, + SIAM, Philadelphia, PA, USA, 2008. 
ISBN 978-0-898716-46-7, + """ + dim = A.shape[0] + c15 = log_pade_coefficients[15] + + def cond_f(loopData): + _,_,k,_,_,converged = loopData + conditions = np.array([~converged, k < 16], dtype = bool) + return conditions.all() + + def compute_pade_degree(diff, j, itk): + j += 1 + # Manually force the return type of searchsorted to be 64-bit int, because it + # returns 32-bit ints, ignoring the global `jax_enable_x64` flag. This looks + # like a bug. I filed an issue (#11375) with Jax to correct this. + # If they fix it, the conversions on p and q can be removed. + p = np.searchsorted(log_pade_coefficients[2:16], diff, side='right').astype(np.int64) + p += 2 + q = np.searchsorted(log_pade_coefficients[2:16], diff/2.0, side='right').astype(np.int64) + q += 2 + m,j,converged = if_then_else((2 * (p - q) // 3 < itk) | (j == 2), + (p+1,j,True), (0,j,False)) + return m,j,converged + + def body_f(loopData): + X,j,k,m,itk,converged = loopData + diff = np.linalg.norm(X - np.identity(dim), ord=1) + m,j,converged = if_then_else(diff < c15, + compute_pade_degree(diff, j, itk), + (m, j, converged)) + X,itk = sqrtm_dbp(X) + k += 1 + return X,j,k,m,itk,converged + + X = A + j = 0 + k = 0 + m = 0 + itk = 5 + converged = False + X,j,k,m,itk,converged = jax.lax.while_loop(cond_f, body_f, (X,j,k,m,itk,converged)) + return X,k,m + + +def log_pade_pf(A, n): + """Logarithmic map by Padé approximant and partial fractions + """ + I = np.identity(A.shape[0]) + X = np.zeros_like(A) + quadPrec = 2*n - 1 + xs,ws = create_padded_quadrature_rule_1D(quadPrec) + + def get_log_inc(A, x, w): + B = I + x*A + dXT = w*np.linalg.solve(B.T, A.T) + return dXT + + dXsTransposed = jax.vmap(get_log_inc, (None, 0, 0))(A, xs, ws) + X = np.sum(dXsTransposed, axis=0).T + + return X + + +log_pade_coefficients = np.array([ + 1.100343044625278e-05, 1.818617533662554e-03, 1.620628479501567e-02, 5.387353263138127e-02, + 1.135280226762866e-01, 1.866286061354130e-01, 2.642960831111435e-01, 3.402172331985299e-01, + 4.108235000556820e-01, 4.745521256007768e-01, 5.310667521178455e-01, 5.806887133441684e-01, + 6.240414344012918e-01, 6.618482563071411e-01, 6.948266172489354e-01, 7.236382701437292e-01, + 7.488702930926310e-01, 7.710320825151814e-01, 7.905600074925671e-01, 8.078252198050853e-01, + 8.231422814010787e-01, 8.367774696147783e-01, 8.489562661576765e-01, 8.598698723737197e-01, + 8.696807597657327e-01, 8.785273397512191e-01, 8.865278635527148e-01, 8.937836659824918e-01, + 9.003818585631236e-01, 9.063975647545747e-01, 9.118957765024351e-01, 9.169328985287867e-01, + 9.215580354375991e-01, 9.258140669835052e-01, 9.297385486977516e-01, 9.333644683151422e-01, + 9.367208829050256e-01, 9.398334570841484e-01, 9.427249190039424e-01, 9.454154478075423e-01, + 9.479230038146050e-01, 9.502636107090112e-01, 9.524515973891873e-01, 9.544998058228285e-01, + 9.564197701703862e-01, 9.582218715590143e-01, 9.599154721638511e-01, 9.615090316568806e-01, + 9.630102085912245e-01, 9.644259488813590e-01, 9.657625632018019e-01, 9.670257948457799e-01, + 9.682208793510226e-01, 9.693525970039069e-01, 9.704253191689650e-01, 9.714430492527785e-01, + 9.724094589950460e-01, 9.733279206814576e-01, 9.742015357899175e-01, 9.750331605111618e-01, + 9.758254285248543e-01, 9.765807713611383e-01, 9.773014366339591e-01, 9.779895043950849e-01 ]) diff --git a/optimism/TensorMath.py b/optimism/TensorMath.py index ebb2f284..cf6a4594 100644 --- a/optimism/TensorMath.py +++ b/optimism/TensorMath.py @@ -1,61 +1,78 @@ -from jax import custom_jvp -from jax.lax import while_loop -from 
jax.scipy import linalg +from functools import partial +import jax +import jax.numpy as np -from optimism.JaxConfig import * from optimism import Math -from optimism.QuadratureRule import create_padded_quadrature_rule_1D - - -def compute_deviatoric_tensor(strain): - dil = np.trace(strain) - return strain - (dil/3.)*np.identity(3) +def trace(A): + return A[0, 0] + A[1, 1] + A[2, 2] + +def det(A): + return A[0, 0]*A[1, 1]*A[2, 2] + A[0, 1]*A[1, 2]*A[2, 0] + A[0, 2]*A[1, 0]*A[2, 1] \ + - A[0, 0]*A[1, 2]*A[2, 1] - A[0, 1]*A[1, 0]*A[2, 2] - A[0, 2]*A[1, 1]*A[2, 0] + +def inv(A): + invA00 = A[1, 1]*A[2, 2] - A[1, 2]*A[2, 1] + invA01 = A[0, 2]*A[2, 1] - A[0, 1]*A[2, 2] + invA02 = A[0, 1]*A[1, 2] - A[0, 2]*A[1, 1] + invA10 = A[1, 2]*A[2, 0] - A[1, 0]*A[2, 2] + invA11 = A[0, 0]*A[2, 2] - A[0, 2]*A[2, 0] + invA12 = A[0, 2]*A[1, 0] - A[0, 0]*A[1, 2] + invA20 = A[1, 0]*A[2, 1] - A[1, 1]*A[2, 0] + invA21 = A[0, 1]*A[2, 0] - A[0, 0]*A[2, 1] + invA22 = A[0, 0]*A[1, 1] - A[0, 1]*A[1, 0] + invA = (1.0/det(A)) * np.array([[invA00, invA01, invA02], + [invA10, invA11, invA12], + [invA20, invA21, invA22]]) + return invA + +def compute_deviatoric_tensor(A): + dil = trace(A) + return A - (dil/3.)*np.identity(3) def dev(strain): return compute_deviatoric_tensor(strain) +def sym(A): + return 0.5*(A + A.T) -def tensor_norm(tensor): - return np.linalg.norm( tensor, ord='fro' ) +def skw(A): + return 0.5*(A - A.T) +def norm(A): + return Math.safe_sqrt(A.ravel() @ A.ravel()) def norm_of_deviator_squared(tensor): dev = compute_deviatoric_tensor(tensor) return np.tensordot(dev,dev) - def norm_of_deviator(tensor): - return tensor_norm( compute_deviatoric_tensor(tensor) ) - + return norm( compute_deviatoric_tensor(tensor) ) def mises_equivalent_stress(stress): return np.sqrt(1.5)*norm_of_deviator(stress) - def triaxiality(A): - mean_normal = np.trace(A)/3.0 + mean_normal = trace(A)/3.0 mises_norm = mises_equivalent_stress(A) # avoid division by zero in case of spherical tensor mises_norm += np.finfo(np.dtype("float64")).eps return mean_normal/mises_norm - -def sym(A): - return 0.5*(A + A.T) - - -def logh(A): - d,V = linalg.eigh(A) - return logh_from_eigen(d,V) - - -def logh_from_eigen(eVals, eVecs): - return eVecs@np.diag(np.log(eVals))@eVecs.T - - def tensor_2D_to_3D(H): return np.zeros((3,3)).at[ 0:H.shape[0], 0:H.shape[1] ].set(H) +def gradient_2D_to_axisymmetric(dudX_2D, u, X): + dudX = tensor_2D_to_3D(dudX_2D) + dudX = dudX.at[2, 2].set(u[0]/X[0]) + return dudX + +# BT 11/2023 +# We should probably replace this with np.where and avoid duplication +def if_then_else(cond, val1, val2): + return jax.lax.cond(cond, + lambda x: val1, + lambda x: val2, + None) # Compute eigen values and vectors of a symmetric 3x3 tensor # Note, returned eigen vectors may not be unit length @@ -284,7 +301,7 @@ def cos_of_acos_divided_by_3(x): return numer/denom -@custom_jvp +@jax.custom_jvp def mtk_log_sqrt(A): lam,V = eigen_sym33_unit(A) return V @ np.diag(0.5*np.log(lam)) @ V.T @@ -349,7 +366,7 @@ def mtk_log_sqrt_jvp(Cpack, Hpack): return logSqrtC, sol -@partial(custom_jvp, nondiff_argnums=(1,)) +@partial(jax.custom_jvp, nondiff_argnums=(1,)) def mtk_pow(A,m): lam,V = eigen_sym33_unit(A) return V @ np.diag(np.power(lam,m)) @ V.T @@ -440,20 +457,30 @@ def relative_log_difference(lam1, lam2): relative_log_difference_no_tolerance_check(lam1, lamFake), relative_log_difference_taylor(lam1, lam2)) +# +# We should consider deprecating the following functions and use the mtk ones exclusively +# + +def logh(A): + d,V = np.linalg.eigh(A) + return 
logh_from_eigen(d,V) + + +def logh_from_eigen(eVals, eVecs): + return eVecs@np.diag(np.log(eVals))@eVecs.T # C must be symmetric! -@custom_jvp +@jax.custom_jvp def log_sqrt(C): return 0.5*logh(C) - @log_sqrt.defjvp def log_jvp(Cpack, Hpack): C, = Cpack H, = Hpack logSqrtC = log_sqrt(C) - lam,V = linalg.eigh(C) + lam,V = np.linalg.eigh(C) lam1 = lam[0] lam2 = lam[1] @@ -506,181 +533,3 @@ def log_jvp(Cpack, Hpack): return logSqrtC, sol -@custom_jvp -def sqrtm(A): - sqrtA,_ = sqrtm_dbp(A) - return sqrtA - - -@sqrtm.defjvp -def jvp_sqrtm(primals, tangents): - A, = primals - H, = tangents - sqrtA = sqrtm(A) - dim = A.shape[0] - # TODO(brandon): Use a stable algorithm for solving a Sylvester equation. - # See https://en.wikipedia.org/wiki/Bartels%E2%80%93Stewart_algorithm - # The following will only reliably work for small matrices. - I = np.identity(dim) - M = np.kron(sqrtA.T, I) + np.kron(I, sqrtA) - Hvec = H.T.ravel() - return sqrtA, (linalg.solve(M, Hvec)).reshape((dim,dim)).T - - -def sqrtm_dbp(A): - """ Matrix square root by product form of Denman-Beavers iteration. - - Translated from the Matrix Function Toolbox - http://www.ma.man.ac.uk/~higham/mftoolbox - Nicholas J. Higham, Functions of Matrices: Theory and Computation, - SIAM, Philadelphia, PA, USA, 2008. ISBN 978-0-898716-46-7, - """ - dim = A.shape[0] - tol = 0.5 * np.sqrt(dim) * np.finfo(np.dtype("float64")).eps - maxIters = 32 - scaleTol = 0.01 - - def scaling(M): - d = np.abs(linalg.det(M))**(1.0/(2.0*dim)) - g = 1.0 / d - return g - - def cond_f(loopData): - _,_,error,k,_ = loopData - p = np.array([k < maxIters, error > tol], dtype=bool) - return np.all(p) - - def body_f(loopData): - X, M, error, k, diff = loopData - g = np.where(diff >= scaleTol, - scaling(M), - 1.0) - - X *= g - M *= g * g - - Y = X - N = linalg.inv(M) - I = np.identity(dim) - X = 0.5 * X @ (I + N) - M = 0.5 * (I + 0.5 * (M + N)) - error = np.linalg.norm(M - I, 'fro') - diff = np.linalg.norm(X - Y, 'fro') / np.linalg.norm(X, 'fro') - k += 1 - return (X, M, error, k, diff) - - X0 = A - M0 = A - error0 = np.finfo(np.dtype("float64")).max - k0 = 0 - diff0 = 2.0*scaleTol # want to force scaling on first iteration - loopData0 = (X0, M0, error0, k0, diff0) - - X,_,_,k,_ = while_loop(cond_f, body_f, loopData0) - - return X,k - - -@custom_jvp -def logm_iss(A): - X,k,m = _logm_iss(A) - return (1 << k) * log_pade_pf(X - np.identity(A.shape[0]), m) - - -@logm_iss.defjvp -def logm_jvp(primals, tangents): - A, = primals - H, = tangents - logA = logm_iss(A) - DexpLogA = jacfwd(linalg.expm)(logA) - dim = A.shape[0] - JVP = linalg.solve(DexpLogA.reshape(dim*dim,-1), H.ravel()) - return logA, JVP.reshape(dim,dim) - - -def _logm_iss(A): - """Logarithmic map by inverse scaling and squaring and Padé approximants - - Translated from the Matrix Function Toolbox - http://www.ma.man.ac.uk/~higham/mftoolbox - Nicholas J. Higham, Functions of Matrices: Theory and Computation, - SIAM, Philadelphia, PA, USA, 2008. ISBN 978-0-898716-46-7, - """ - dim = A.shape[0] - c15 = log_pade_coefficients[15] - - def cond_f(loopData): - _,_,k,_,_,converged = loopData - conditions = np.array([~converged, k < 16], dtype = bool) - return conditions.all() - - def compute_pade_degree(diff, j, itk): - j += 1 - # Manually force the return type of searchsorted to be 64-bit int, because it - # returns 32-bit ints, ignoring the global `jax_enable_x64` flag. This looks - # like a bug. I filed an issue (#11375) with Jax to correct this. - # If they fix it, the conversions on p and q can be removed. 
- p = np.searchsorted(log_pade_coefficients[2:16], diff, side='right').astype(np.int64) - p += 2 - q = np.searchsorted(log_pade_coefficients[2:16], diff/2.0, side='right').astype(np.int64) - q += 2 - m,j,converged = if_then_else((2 * (p - q) // 3 < itk) | (j == 2), - (p+1,j,True), (0,j,False)) - return m,j,converged - - def body_f(loopData): - X,j,k,m,itk,converged = loopData - diff = np.linalg.norm(X - np.identity(dim), ord=1) - m,j,converged = if_then_else(diff < c15, - compute_pade_degree(diff, j, itk), - (m, j, converged)) - X,itk = sqrtm_dbp(X) - k += 1 - return X,j,k,m,itk,converged - - X = A - j = 0 - k = 0 - m = 0 - itk = 5 - converged = False - X,j,k,m,itk,converged = while_loop(cond_f, body_f, (X,j,k,m,itk,converged)) - return X,k,m - - -def log_pade_pf(A, n): - """Logarithmic map by Padé approximant and partial fractions - """ - I = np.identity(A.shape[0]) - X = np.zeros_like(A) - quadPrec = 2*n - 1 - xs,ws = create_padded_quadrature_rule_1D(quadPrec) - - def get_log_inc(A, x, w): - B = I + x*A - dXT = w*linalg.solve(B.T, A.T) - return dXT - - dXsTransposed = vmap(get_log_inc, (None, 0, 0))(A, xs, ws) - X = np.sum(dXsTransposed, axis=0).T - - return X - - -log_pade_coefficients = np.array([ - 1.100343044625278e-05, 1.818617533662554e-03, 1.620628479501567e-02, 5.387353263138127e-02, - 1.135280226762866e-01, 1.866286061354130e-01, 2.642960831111435e-01, 3.402172331985299e-01, - 4.108235000556820e-01, 4.745521256007768e-01, 5.310667521178455e-01, 5.806887133441684e-01, - 6.240414344012918e-01, 6.618482563071411e-01, 6.948266172489354e-01, 7.236382701437292e-01, - 7.488702930926310e-01, 7.710320825151814e-01, 7.905600074925671e-01, 8.078252198050853e-01, - 8.231422814010787e-01, 8.367774696147783e-01, 8.489562661576765e-01, 8.598698723737197e-01, - 8.696807597657327e-01, 8.785273397512191e-01, 8.865278635527148e-01, 8.937836659824918e-01, - 9.003818585631236e-01, 9.063975647545747e-01, 9.118957765024351e-01, 9.169328985287867e-01, - 9.215580354375991e-01, 9.258140669835052e-01, 9.297385486977516e-01, 9.333644683151422e-01, - 9.367208829050256e-01, 9.398334570841484e-01, 9.427249190039424e-01, 9.454154478075423e-01, - 9.479230038146050e-01, 9.502636107090112e-01, 9.524515973891873e-01, 9.544998058228285e-01, - 9.564197701703862e-01, 9.582218715590143e-01, 9.599154721638511e-01, 9.615090316568806e-01, - 9.630102085912245e-01, 9.644259488813590e-01, 9.657625632018019e-01, 9.670257948457799e-01, - 9.682208793510226e-01, 9.693525970039069e-01, 9.704253191689650e-01, 9.714430492527785e-01, - 9.724094589950460e-01, 9.733279206814576e-01, 9.742015357899175e-01, 9.750331605111618e-01, - 9.758254285248543e-01, 9.765807713611383e-01, 9.773014366339591e-01, 9.779895043950849e-01 ]) diff --git a/optimism/test/test_LinAlg.py b/optimism/test/test_LinAlg.py new file mode 100644 index 00000000..0e44a39b --- /dev/null +++ b/optimism/test/test_LinAlg.py @@ -0,0 +1,138 @@ +import jax +from jax import numpy as np +from jax.test_util import check_grads +from scipy.spatial.transform import Rotation +import unittest + +from optimism import LinAlg +from optimism.test.TestFixture import TestFixture + +def generate_n_random_symmetric_matrices(n, minval=0.0, maxval=1.0): + key = jax.random.PRNGKey(0) + As = jax.random.uniform(key, (n,3,3), minval=minval, maxval=maxval) + return jax.vmap(lambda A: np.dot(A.T,A), (0,))(As) + +sqrtm_jit = jax.jit(LinAlg.sqrtm) +logm_iss_jit = jax.jit(LinAlg.logm_iss) + +class TestLinAlg(TestFixture): + def setUp(self): + self.sym_mat = generate_n_random_symmetric_matrices(1)[0] 
+ # make a matrix with 2 identical eigenvalues + R = Rotation.random(random_state=41).as_matrix() + eigvals = np.array([2., 0.5, 2.]) + self.sym_mat_double_degeneracy = R@np.diag(eigvals)@R.T + + ### sqrtm ### + + def test_sqrtm_jit(self): + sqrtC = sqrtm_jit(self.sym_mat) + self.assertTrue(not np.isnan(sqrtC).any()) + + + def test_sqrtm(self): + mats = generate_n_random_symmetric_matrices(100) + sqrtMats = jax.vmap(sqrtm_jit, (0,))(mats) + shouldBeMats = jax.vmap(lambda A: np.dot(A,A), (0,))(sqrtMats) + self.assertArrayNear(shouldBeMats, mats, 10) + + + def test_sqrtm_fwd_mode_derivative(self): + check_grads(LinAlg.sqrtm, (self.sym_mat,), order=2, modes=["fwd"]) + + + def test_sqrtm_rev_mode_derivative(self): + check_grads(LinAlg.sqrtm, (self.sym_mat,), order=2, modes=["rev"]) + + + def test_sqrtm_on_degenerate_eigenvalues(self): + C = self.sym_mat_double_degeneracy + sqrtC = LinAlg.sqrtm(C) + shouldBeC = np.dot(sqrtC, sqrtC) + self.assertArrayNear(shouldBeC, C, 12) + check_grads(LinAlg.sqrtm, (C,), order=2, modes=["rev"]) + + + def test_sqrtm_on_10x10(self): + key = jax.random.PRNGKey(0) + F = jax.random.uniform(key, (10,10), minval=1e-8, maxval=10.0) + C = F.T@F + sqrtC = LinAlg.sqrtm(C) + shouldBeC = np.dot(sqrtC,sqrtC) + self.assertArrayNear(shouldBeC, C, 11) + + + def test_sqrtm_derivatives_on_10x10(self): + key = jax.random.PRNGKey(0) + F = jax.random.uniform(key, (10,10), minval=1e-8, maxval=10.0) + C = F.T@F + check_grads(LinAlg.sqrtm, (C,), order=1, modes=["fwd", "rev"]) + + ### sqrtm ### + + def test_logm_iss_on_matrix_near_identity(self): + key = jax.random.PRNGKey(0) + id_perturbation = 1.0 + jax.random.uniform(key, (3,), minval=1e-8, maxval=0.01) + A = np.diag(id_perturbation) + logA = LinAlg.logm_iss(A) + self.assertArrayNear(logA, np.diag(np.log(id_perturbation)), 12) + + + def test_logm_iss_on_double_degenerate_eigenvalues(self): + C = self.sym_mat_double_degeneracy + logC = LinAlg.logm_iss(C) + explogC = jax.scipy.linalg.expm(logC) + self.assertArrayNear(C, explogC, 8) + + + def test_logm_iss_on_triple_degenerate_eigvalues(self): + A = 4.0*np.identity(3) + logA = LinAlg.logm_iss(A) + self.assertArrayNear(logA, np.log(4.0)*np.identity(3), 12) + + + def test_logm_iss_jit(self): + C = generate_n_random_symmetric_matrices(1)[0] + logC = logm_iss_jit(C) + self.assertFalse(np.isnan(logC).any()) + + + def test_logm_iss_on_full_3x3s(self): + mats = generate_n_random_symmetric_matrices(1000) + logMats = jax.vmap(logm_iss_jit, (0,))(mats) + shouldBeMats = jax.vmap(lambda A: jax.scipy.linalg.expm(A), (0,))(logMats) + self.assertArrayNear(shouldBeMats, mats, 7) + + + def test_logm_iss_fwd_mode_derivative(self): + check_grads(logm_iss_jit, (self.sym_mat,), order=1, modes=['fwd']) + + + def test_logm_iss_rev_mode_derivative(self): + check_grads(logm_iss_jit, (self.sym_mat,), order=1, modes=['rev']) + + + def test_logm_iss_hessian_on_double_degenerate_eigenvalues(self): + C = self.sym_mat_double_degeneracy + check_grads(jax.jacrev(LinAlg.logm_iss), (C,), order=1, modes=['fwd'], rtol=1e-9, atol=1e-9, eps=1e-5) + + + def test_logm_iss_derivatives_on_double_degenerate_eigenvalues(self): + C = self.sym_mat_double_degeneracy + check_grads(LinAlg.logm_iss, (C,), order=1, modes=['fwd']) + check_grads(LinAlg.logm_iss, (C,), order=1, modes=['rev']) + + + def test_logm_iss_derivatives_on_triple_degenerate_eigenvalues(self): + A = 4.0*np.identity(3) + check_grads(LinAlg.logm_iss, (A,), order=1, modes=['fwd']) + check_grads(LinAlg.logm_iss, (A,), order=1, modes=['rev']) + + + def 
test_logm_iss_on_10x10(self): + key = jax.random.PRNGKey(0) + F = jax.random.uniform(key, (10,10), minval=1e-8, maxval=10.0) + C = F.T@F + logC = LinAlg.logm_iss(C) + explogC = jax.scipy.linalg.expm(logC) + self.assertArrayNear(explogC, C, 8) \ No newline at end of file diff --git a/optimism/test/test_TensorMath.py b/optimism/test/test_TensorMath.py index ad302803..d2e56293 100644 --- a/optimism/test/test_TensorMath.py +++ b/optimism/test/test_TensorMath.py @@ -4,7 +4,6 @@ import jax from jax import numpy as np from jax.test_util import check_grads -from jax.scipy import linalg from optimism.test.TestFixture import TestFixture from optimism import TensorMath @@ -26,18 +25,10 @@ def lam(A): return lam -def generate_n_random_symmetric_matrices(n, minval=0.0, maxval=1.0): - key = jax.random.PRNGKey(0) - As = jax.random.uniform(key, (n,3,3), minval=minval, maxval=maxval) - return jax.vmap(lambda A: np.dot(A.T,A), (0,))(As) - - class TensorMathFixture(TestFixture): def setUp(self): self.log_squared = lambda A: np.tensordot(TensorMath.log_sqrt(A), TensorMath.log_sqrt(A)) - self.sqrtm_jit = jax.jit(TensorMath.sqrtm) - self.logm_iss_jit = jax.jit(TensorMath.logm_iss) def test_log_sqrt_tensor_jvp_0(self): @@ -205,127 +196,7 @@ def pow_squared(A): check_grads(pow_squared, (C,), order=1) - ### sqrtm ### - - - def test_sqrtm_jit(self): - C = generate_n_random_symmetric_matrices(1)[0] - sqrtC = self.sqrtm_jit(C) - self.assertFalse(np.isnan(sqrtC).any()) - - - def test_sqrtm(self): - mats = generate_n_random_symmetric_matrices(100) - sqrtMats = jax.vmap(self.sqrtm_jit, (0,))(mats) - shouldBeMats = jax.vmap(lambda A: np.dot(A,A), (0,))(sqrtMats) - self.assertArrayNear(shouldBeMats, mats, 10) - - - def test_sqrtm_fwd_mode_derivative(self): - C = generate_n_random_symmetric_matrices(1)[0] - check_grads(TensorMath.sqrtm, (C,), order=2, modes=["fwd"]) - - - def test_sqrtm_rev_mode_derivative(self): - C = generate_n_random_symmetric_matrices(1)[0] - check_grads(TensorMath.sqrtm, (C,), order=2, modes=["rev"]) - - - def test_sqrtm_on_degenerate_eigenvalues(self): - C = R@np.diag(np.array([2., 0.5, 2]))@R.T - sqrtC = TensorMath.sqrtm(C) - shouldBeC = np.dot(sqrtC, sqrtC) - self.assertArrayNear(shouldBeC, C, 12) - check_grads(TensorMath.sqrtm, (C,), order=2, modes=["rev"]) - - - def test_sqrtm_on_10x10(self): - key = jax.random.PRNGKey(0) - F = jax.random.uniform(key, (10,10), minval=1e-8, maxval=10.0) - C = F.T@F - sqrtC = TensorMath.sqrtm(C) - shouldBeC = np.dot(sqrtC,sqrtC) - self.assertArrayNear(shouldBeC, C, 11) - - - def test_sqrtm_derivatives_on_10x10(self): - key = jax.random.PRNGKey(0) - F = jax.random.uniform(key, (10,10), minval=1e-8, maxval=10.0) - C = F.T@F - check_grads(TensorMath.sqrtm, (C,), order=1, modes=["fwd", "rev"]) - - - def test_logm_iss_on_matrix_near_identity(self): - key = jax.random.PRNGKey(0) - id_perturbation = 1.0 + jax.random.uniform(key, (3,), minval=1e-8, maxval=0.01) - A = np.diag(id_perturbation) - logA = TensorMath.logm_iss(A) - self.assertArrayNear(logA, np.diag(np.log(id_perturbation)), 12) - - - def test_logm_iss_on_double_degenerate_eigenvalues(self): - eigvals = np.array([2., 0.5, 2.]) - C = R@np.diag(eigvals)@R.T - logC = TensorMath.logm_iss(C) - logCSpectral = R@np.diag(np.log(eigvals))@R.T - self.assertArrayNear(logC, logCSpectral, 12) - - - def test_logm_iss_on_triple_degenerate_eigvalues(self): - A = 4.0*np.identity(3) - logA = TensorMath.logm_iss(A) - self.assertArrayNear(logA, np.log(4.0)*np.identity(3), 12) - - - def test_logm_iss_jit(self): - C = 
generate_n_random_symmetric_matrices(1)[0] - logC = self.logm_iss_jit(C) - self.assertFalse(np.isnan(logC).any()) - - - def test_logm_iss_on_full_3x3s(self): - mats = generate_n_random_symmetric_matrices(1000) - logMats = jax.vmap(self.logm_iss_jit, (0,))(mats) - shouldBeMats = jax.vmap(lambda A: linalg.expm(A), (0,))(logMats) - self.assertArrayNear(shouldBeMats, mats, 7) - - - def test_logm_iss_fwd_mode_derivative(self): - C = generate_n_random_symmetric_matrices(1)[0] - check_grads(self.logm_iss_jit, (C,), order=1, modes=['fwd']) - - - def test_logm_iss_rev_mode_derivative(self): - C = generate_n_random_symmetric_matrices(1)[0] - check_grads(self.logm_iss_jit, (C,), order=1, modes=['rev']) - - - def test_logm_iss_hessian_on_double_degenerate_eigenvalues(self): - eigvals = np.array([2., 0.5, 2.]) - C = R@np.diag(eigvals)@R.T - check_grads(jax.jacrev(TensorMath.logm_iss), (C,), order=1, modes=['fwd'], rtol=1e-9, atol=1e-9, eps=1e-5) - - - def test_logm_iss_derivatives_on_double_degenerate_eigenvalues(self): - eigvals = np.array([2., 0.5, 2.]) - C = R@np.diag(eigvals)@R.T - check_grads(TensorMath.logm_iss, (C,), order=1, modes=['fwd']) - check_grads(TensorMath.logm_iss, (C,), order=1, modes=['rev']) - - - def test_logm_iss_derivatives_on_triple_degenerate_eigenvalues(self): - A = 4.0*np.identity(3) - check_grads(TensorMath.logm_iss, (A,), order=1, modes=['fwd']) - check_grads(TensorMath.logm_iss, (A,), order=1, modes=['rev']) - - - def test_logm_iss_on_10x10(self): - key = jax.random.PRNGKey(0) - F = jax.random.uniform(key, (10,10), minval=1e-8, maxval=10.0) - C = F.T@F - logC = TensorMath.logm_iss(C) - logCSpectral = TensorMath.logh(C) - self.assertArrayNear(logC, logCSpectral, 12) + if __name__ == '__main__': From 046b8213e6630e47a461ff05a769b484f908c4b0 Mon Sep 17 00:00:00 2001 From: Brandon Talamini Date: Sun, 19 Nov 2023 07:18:35 -0800 Subject: [PATCH 02/17] Clean up format and give a few operations better names --- optimism/TensorMath.py | 31 +++++++++++++------------------ 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/optimism/TensorMath.py b/optimism/TensorMath.py index cf6a4594..7b8bf028 100644 --- a/optimism/TensorMath.py +++ b/optimism/TensorMath.py @@ -1,7 +1,10 @@ +"""Provide differentiable operations on 3x3 tensors.""" + from functools import partial import jax import jax.numpy as np +from optimism.JaxConfig import if_then_else from optimism import Math def trace(A): @@ -48,12 +51,12 @@ def norm_of_deviator_squared(tensor): def norm_of_deviator(tensor): return norm( compute_deviatoric_tensor(tensor) ) -def mises_equivalent_stress(stress): +def mises_invariant(stress): return np.sqrt(1.5)*norm_of_deviator(stress) def triaxiality(A): mean_normal = trace(A)/3.0 - mises_norm = mises_equivalent_stress(A) + mises_norm = mises_invariant(A) # avoid division by zero in case of spherical tensor mises_norm += np.finfo(np.dtype("float64")).eps return mean_normal/mises_norm @@ -66,23 +69,15 @@ def gradient_2D_to_axisymmetric(dudX_2D, u, X): dudX = dudX.at[2, 2].set(u[0]/X[0]) return dudX -# BT 11/2023 -# We should probably replace this with np.where and avoid duplication -def if_then_else(cond, val1, val2): - return jax.lax.cond(cond, - lambda x: val1, - lambda x: val2, - None) - -# Compute eigen values and vectors of a symmetric 3x3 tensor -# Note, returned eigen vectors may not be unit length -# -# Note, this routine involves high powers of the input tensor (~M^8). 
-# Thus results can start to denormalize when the infinity norm of the input -# tensor falls outside the range 1.0e-40 to 1.0e+40. -# -# Outside this range use eigen_sym33_unit def eigen_sym33_non_unit(tensor): + """Compute eigen values and vectors of a symmetric 3x3 tensor. + + Note, returned eigen vectors may not be unit length + Note, this routine involves high powers of the input tensor (~M^8). + Thus results can start to denormalize when the infinity norm of the input + tensor falls outside the range 1.0e-40 to 1.0e+40. + Outside this range use eigen_sym33_unit + """ cxx = tensor[0,0] cyy = tensor[1,1] czz = tensor[2,2] From 144b310cf4655a4fc1b12fa8ab93e4c0e161e2fa Mon Sep 17 00:00:00 2001 From: Brandon Talamini Date: Sun, 19 Nov 2023 07:19:56 -0800 Subject: [PATCH 03/17] Give another tensor operator a more succinct and common name --- optimism/TensorMath.py | 8 ++++---- optimism/phasefield/PhaseFieldClassic.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/optimism/TensorMath.py b/optimism/TensorMath.py index 7b8bf028..111848ed 100644 --- a/optimism/TensorMath.py +++ b/optimism/TensorMath.py @@ -29,11 +29,11 @@ def inv(A): [invA20, invA21, invA22]]) return invA -def compute_deviatoric_tensor(A): +def deviator(A): dil = trace(A) return A - (dil/3.)*np.identity(3) -def dev(strain): return compute_deviatoric_tensor(strain) +def dev(strain): return deviator(strain) def sym(A): return 0.5*(A + A.T) @@ -45,11 +45,11 @@ def norm(A): return Math.safe_sqrt(A.ravel() @ A.ravel()) def norm_of_deviator_squared(tensor): - dev = compute_deviatoric_tensor(tensor) + dev = deviator(tensor) return np.tensordot(dev,dev) def norm_of_deviator(tensor): - return norm( compute_deviatoric_tensor(tensor) ) + return norm( deviator(tensor) ) def mises_invariant(stress): return np.sqrt(1.5)*norm_of_deviator(stress) diff --git a/optimism/phasefield/PhaseFieldClassic.py b/optimism/phasefield/PhaseFieldClassic.py index fe1ac579..cbf0edc8 100644 --- a/optimism/phasefield/PhaseFieldClassic.py +++ b/optimism/phasefield/PhaseFieldClassic.py @@ -24,7 +24,7 @@ def degradation(phase): def intact_strain_energy_density(props, strain): dil = np.trace(strain) - dev = compute_deviatoric_tensor(strain).ravel() + dev = deviator(strain).ravel() return 0.5*props['kappa'] * dil**2 + props['mu'] * np.dot(dev,dev) From 48169d5ab2e09fb7d874b15cc8adb4c38c19ca19 Mon Sep 17 00:00:00 2001 From: Brandon Talamini Date: Sun, 19 Nov 2023 08:05:31 -0800 Subject: [PATCH 04/17] Add high precision determinant plus identity function and determinant tests --- optimism/TensorMath.py | 10 ++++++++- optimism/test/test_TensorMath.py | 37 +++++++++++++++++++++++++++----- 2 files changed, 41 insertions(+), 6 deletions(-) diff --git a/optimism/TensorMath.py b/optimism/TensorMath.py index 111848ed..da6d0f9e 100644 --- a/optimism/TensorMath.py +++ b/optimism/TensorMath.py @@ -10,10 +10,18 @@ def trace(A): return A[0, 0] + A[1, 1] + A[2, 2] +def I2(A): + trA = np.trace(A) + return 0.5*(trA*trA - A.ravel()@A.T.ravel()) + def det(A): return A[0, 0]*A[1, 1]*A[2, 2] + A[0, 1]*A[1, 2]*A[2, 0] + A[0, 2]*A[1, 0]*A[2, 1] \ - A[0, 0]*A[1, 2]*A[2, 1] - A[0, 1]*A[1, 0]*A[2, 2] - A[0, 2]*A[1, 1]*A[2, 0] +def detpIm1(A): + """Compute det(A + I) - 1 while preserving precision when A is small compared to the identity.""" + return trace(A) + I2(A) + det(A) + def inv(A): invA00 = A[1, 1]*A[2, 2] - A[1, 2]*A[2, 1] invA01 = A[0, 2]*A[2, 1] - A[0, 1]*A[2, 2] @@ -31,7 +39,7 @@ def inv(A): def deviator(A): dil = trace(A) - return A - (dil/3.)*np.identity(3) 
+ return A - (dil/3)*np.identity(3) def dev(strain): return deviator(strain) diff --git a/optimism/test/test_TensorMath.py b/optimism/test/test_TensorMath.py index d2e56293..e90f4e1f 100644 --- a/optimism/test/test_TensorMath.py +++ b/optimism/test/test_TensorMath.py @@ -183,7 +183,6 @@ def pow_squared(A): return np.tensordot(lg, lg) check_grads(pow_squared, (C,), order=1) - def test_pow_squared_grad_rand(self): key = jax.random.PRNGKey(0) F = jax.random.uniform(key, (3,3), minval=1e-8, maxval=10.0) @@ -194,10 +193,38 @@ def pow_squared(A): lg = TensorMath.mtk_pow(A, m) return np.tensordot(lg, lg) check_grads(pow_squared, (C,), order=1) - - - - + def test_determinant(self): + A = np.array([[5/9, 4/7, 2/11], + [7/9, 4/9, 1/5], + [1/3, 3/7, 17/18]]) + self.assertEqual(TensorMath.det(A), -45583/280665) + + def test_detpIm1(self): + A = np.array([[-8.7644781692191447986e-7, -0.00060943437636452272438, 0.0006160110345770283824], + [0.00059197095431573693372, -0.00032421698142571543644, -0.00075031460538177354586], + [-0.00057095032376313107833, 0.00042675236045286923589, -0.00029239794707394684004]]) + exact = -0.00061636368316760725654 # computed with exact arithmetic in Mathematica and truncated + val = TensorMath.detpIm1(A) + self.assertAlmostEqual(exact, val, 15) + + def test_determinant_precision(self): + eps = 1e-8 + A = np.diag(np.array([eps, eps, eps])) + # det(A + I) - 1 + exact = eps**3 + 3*eps**2 + 3*eps + + # straightforward approach loses precision + Jm1 = TensorMath.det(A + np.identity(3)) - 1 + error = np.abs((Jm1 - exact)/exact) + self.assertGreater(error, 1e-9) + + # special function retains precision + Jm1 = TensorMath.detpIm1(A) + error = np.abs((Jm1 - exact)/exact) + self.assertEqual(error, 0) + + + if __name__ == '__main__': unittest.main() From da454deffa72008384b6e166f4286457ca6d5550 Mon Sep 17 00:00:00 2001 From: Brandon Talamini Date: Mon, 20 Nov 2023 06:38:43 -0800 Subject: [PATCH 05/17] Add polar decomposition and matrix square root --- optimism/TensorMath.py | 22 +++++++++++++++++++++- optimism/test/test_TensorMath.py | 18 ++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/optimism/TensorMath.py b/optimism/TensorMath.py index da6d0f9e..af7f363a 100644 --- a/optimism/TensorMath.py +++ b/optimism/TensorMath.py @@ -69,6 +69,23 @@ def triaxiality(A): mises_norm += np.finfo(np.dtype("float64")).eps return mean_normal/mises_norm +def right_polar_decomposition(F): + """Compute the right polar decomposition of a tensor. + + Computes the factors R and U such that R@U = F, where R is an + orthogonal matrix and U is symmetric positive semi-definite. 
+ + Parameters: F : 3x3 matrix + + Returns: a tuple of the following arrays + R : orthogonal matrix + U : right stretch matrix + """ + C = F.T@F + U = mtk_sqrt(C) + R = F@inv(U) + return R, U + def tensor_2D_to_3D(H): return np.zeros((3,3)).at[ 0:H.shape[0], 0:H.shape[1] ].set(H) @@ -535,4 +552,7 @@ def log_jvp(Cpack, Hpack): return logSqrtC, sol - +def mtk_sqrt(A): + """Square root of a symmetric positive semi-definite tensor.""" + lam, V = eigen_sym33_unit(A) + return V @ np.diag(Math.safe_sqrt(lam)) @ V.T diff --git a/optimism/test/test_TensorMath.py b/optimism/test/test_TensorMath.py index e90f4e1f..ab8ff474 100644 --- a/optimism/test/test_TensorMath.py +++ b/optimism/test/test_TensorMath.py @@ -224,6 +224,24 @@ def test_determinant_precision(self): error = np.abs((Jm1 - exact)/exact) self.assertEqual(error, 0) + def test_right_polar_decomp(self): + key = jax.random.PRNGKey(0) + F = jax.random.uniform(key, (3,3), minval=1e-8, maxval=10.0) + R, U = TensorMath.right_polar_decomposition(F) + # R is orthogonal + self.assertArrayNear(R@R.T, np.identity(3), 14) + self.assertArrayNear(R.T@R, np.identity(3), 14) + # U is symmetric + self.assertArrayNear(U, TensorMath.sym(U), 14) + # RU = F + self.assertArrayNear(R@U, F, 14) + + def test_tensor_sqrt(self): + eigvals = np.array([2., 0.5, 2.]) + C = R@np.diag(eigvals)@R.T + U = TensorMath.mtk_sqrt(C) + self.assertArrayNear(U, TensorMath.sym(U), 14) + self.assertArrayNear(U@U, C, 14) if __name__ == '__main__': From e257f497587689dc2bcf5209fa090daa37941e67 Mon Sep 17 00:00:00 2001 From: Brandon Talamini Date: Wed, 22 Nov 2023 16:14:05 -0800 Subject: [PATCH 06/17] Add way to make a tensor function out of any scalar function, implment matrix log, exp, and sqrt --- optimism/TensorMath.py | 146 +++++++++++++------------- optimism/test/test_TensorMath.py | 174 ++++++++++++++++++++++++------- 2 files changed, 210 insertions(+), 110 deletions(-) diff --git a/optimism/TensorMath.py b/optimism/TensorMath.py index af7f363a..dbb9c313 100644 --- a/optimism/TensorMath.py +++ b/optimism/TensorMath.py @@ -82,7 +82,7 @@ def right_polar_decomposition(F): U : right stretch matrix """ C = F.T@F - U = mtk_sqrt(C) + U = sqrt_symm(C) R = F@inv(U) return R, U @@ -385,13 +385,17 @@ def mtk_log_sqrt_jvp(Cpack, Hpack): return logSqrtC, sol - @partial(jax.custom_jvp, nondiff_argnums=(1,)) def mtk_pow(A,m): lam,V = eigen_sym33_unit(A) return V @ np.diag(np.power(lam,m)) @ V.T +# BT 11/22/2023 +# This implementation is wrong - it's reusing the relative_log_difference +# function where it should be using one particular to the power function. +# I don't know how to compute that while avoiding catastrophic +# cancellation errors. Someone should fix this if they know how. 
@mtk_pow.defjvp def mtk_pow_jvp(m, Cpack, Hpack): C, = Cpack @@ -477,82 +481,84 @@ def relative_log_difference(lam1, lam2): relative_log_difference_no_tolerance_check(lam1, lamFake), relative_log_difference_taylor(lam1, lam2)) -# -# We should consider deprecating the following functions and use the mtk ones exclusively -# -def logh(A): - d,V = np.linalg.eigh(A) - return logh_from_eigen(d,V) +def symmetric_matrix_function(A, func): + """Create a function on symmetric matrices from a scalar function.""" + lam, V = eigen_sym33_unit(A) + return V@np.diag(func(lam))@V.T + +def _symmetric_matrix_function_jvp_helper(func, relative_difference, primals, tangents): + C, = primals + Cdot, = tangents + + lam, V = eigen_sym33_unit(C) + primal_out = V@np.diag(func(lam))@V.T + + df = jax.jacfwd(func) + h_diag = jax.vmap(df)(lam) + def rd(x1, x2): + x2_safe = np.where(x2 == x1, x1 + 1.0, x2) + return np.where(x2 == x1, df(x1), relative_difference(x1, x2_safe)) + h12 = rd(lam[0], lam[1]) + h23 = rd(lam[1], lam[2]) + h31 = rd(lam[2], lam[0]) + h = np.array([[h_diag[0], h12, h31], + [h12, h_diag[1], h23], + [h31, h23, h_diag[2]]]) + W = V.T@sym(Cdot)@V + h = h*W + + t00 = V[0].T@h@V[0] + t11 = V[1].T@h@V[1] + t22 = V[2].T@h@V[2] + t01 = V[0].T@h@V[1] + t12 = V[1].T@h@V[2] + t20 = V[2].T@h@V[0] + sol = np.array([ [t00, t01, t20], + [t01, t11, t12], + [t20, t12, t22] ]) -def logh_from_eigen(eVals, eVecs): - return eVecs@np.diag(np.log(eVals))@eVecs.T + return primal_out, sol -# C must be symmetric! @jax.custom_jvp -def log_sqrt(C): - return 0.5*logh(C) +def sqrt_symm(A): + """Square root of a symmetric positive semi-definite tensor.""" + return symmetric_matrix_function(A, Math.safe_sqrt) -@log_sqrt.defjvp -def log_jvp(Cpack, Hpack): - C, = Cpack - H, = Hpack +def _sqrt_relative_difference(lam1, lam2): + return 1/(np.sqrt(lam1) + np.sqrt(lam2)) - logSqrtC = log_sqrt(C) - lam,V = np.linalg.eigh(C) - - lam1 = lam[0] - lam2 = lam[1] - lam3 = lam[2] - - e1 = V[:,0] - e2 = V[:,1] - e3 = V[:,2] - - hHat = 0.5 * (V.T @ H @ V) +@sqrt_symm.defjvp +def sqrt_symm_jvp(primals, tangents): + return _symmetric_matrix_function_jvp_helper(Math.safe_sqrt, _sqrt_relative_difference, primals, tangents) - l1111 = hHat[0,0] / lam1 - l2222 = hHat[1,1] / lam2 - l3333 = hHat[2,2] / lam3 - - l1212 = 0.5*(hHat[0,1]+hHat[1,0]) * relative_log_difference(lam1, lam2) - l2323 = 0.5*(hHat[1,2]+hHat[2,1]) * relative_log_difference(lam2, lam3) - l3131 = 0.5*(hHat[2,0]+hHat[0,2]) * relative_log_difference(lam3, lam1) +@jax.custom_jvp +def exp_symm(A): + """Compute the matrix exponential of a symmetric matrix.""" + return symmetric_matrix_function(A, np.exp) - t00 = l1111 * e1[0] * e1[0] + l2222 * e2[0] * e2[0] + l3333 * e3[0] * e3[0] + \ - 2 * l1212 * e1[0] * e2[0] + \ - 2 * l2323 * e2[0] * e3[0] + \ - 2 * l3131 * e3[0] * e1[0] - t11 = l1111 * e1[1] * e1[1] + l2222 * e2[1] * e2[1] + l3333 * e3[1] * e3[1] + \ - 2 * l1212 * e1[1] * e2[1] + \ - 2 * l2323 * e2[1] * e3[1] + \ - 2 * l3131 * e3[1] * e1[1] - t22 = l1111 * e1[2] * e1[2] + l2222 * e2[2] * e2[2] + l3333 * e3[2] * e3[2] + \ - 2 * l1212 * e1[2] * e2[2] + \ - 2 * l2323 * e2[2] * e3[2] + \ - 2 * l3131 * e3[2] * e1[2] +def _exp_relative_difference(lam1, lam2): + arg = lam1 - lam2 + return np.exp(lam2)*np.expm1(arg)/arg - t01 = l1111 * e1[0] * e1[1] + l2222 * e2[0] * e2[1] + l3333 * e3[0] * e3[1] + \ - l1212 * (e1[0] * e2[1] + e2[0] * e1[1]) + \ - l2323 * (e2[0] * e3[1] + e3[0] * e2[1]) + \ - l3131 * (e3[0] * e1[1] + e1[0] * e3[1]) - t12 = l1111 * e1[1] * e1[2] + l2222 * e2[1] * e2[2] + l3333 * 
e3[1] * e3[2] + \ - l1212 * (e1[1] * e2[2] + e2[1] * e1[2]) + \ - l2323 * (e2[1] * e3[2] + e3[1] * e2[2]) + \ - l3131 * (e3[1] * e1[2] + e1[1] * e3[2]) - t20 = l1111 * e1[2] * e1[0] + l2222 * e2[2] * e2[0] + l3333 * e3[2] * e3[0] + \ - l1212 * (e1[2] * e2[0] + e2[2] * e1[0]) + \ - l2323 * (e2[2] * e3[0] + e3[2] * e2[0]) + \ - l3131 * (e3[2] * e1[0] + e1[2] * e3[0]) - - sol = np.array([ [t00, t01, t20], - [t01, t11, t12], - [t20, t12, t22] ]) - - return logSqrtC, sol +@exp_symm.defjvp +def exp_symm_jvp(primals, tangents): + return _symmetric_matrix_function_jvp_helper(np.exp, _exp_relative_difference, primals, tangents) -def mtk_sqrt(A): - """Square root of a symmetric positive semi-definite tensor.""" - lam, V = eigen_sym33_unit(A) - return V @ np.diag(Math.safe_sqrt(lam)) @ V.T +@jax.custom_jvp +def log_symm(A): + """Compute the matrix logarithm of a symmetric positive definite matrix.""" + return symmetric_matrix_function(A, np.log) + +def _log_relative_difference(lam1, lam2): + arg = lam1/lam2 - 1 + return (np.log1p(arg)/arg)/lam2 + +@log_symm.defjvp +def log_symm_jvp(primals, tangents): + return _symmetric_matrix_function_jvp_helper(np.log, _log_relative_difference, primals, tangents) + +def log_sqrt_symm(A): + """Compute matrix logarithm of the square root of a symmetric positive definite matrix.""" + return 0.5*log_symm(A) \ No newline at end of file diff --git a/optimism/test/test_TensorMath.py b/optimism/test/test_TensorMath.py index ab8ff474..1e7e4421 100644 --- a/optimism/test/test_TensorMath.py +++ b/optimism/test/test_TensorMath.py @@ -28,40 +28,12 @@ def lam(A): class TensorMathFixture(TestFixture): def setUp(self): + key = jax.random.PRNGKey(1) + self.R = jax.random.orthogonal(key, 3) + self.assertGreater(np.linalg.det(self.R), 0) # make sure this is a rotation and not a reflection self.log_squared = lambda A: np.tensordot(TensorMath.log_sqrt(A), TensorMath.log_sqrt(A)) - def test_log_sqrt_tensor_jvp_0(self): - A = np.array([ [2.0, 0.0, 0.0], - [0.0, 1.2, 0.0], - [0.0, 0.0, 2.0] ]) - - check_grads(self.log_squared, (A,), order=1) - - - def test_log_sqrt_tensor_jvp_1(self): - A = np.array([ [2.0, 0.0, 0.0], - [0.0, 1.2, 0.0], - [0.0, 0.0, 3.0] ]) - - check_grads(self.log_squared, (A,), order=1) - - - def test_log_sqrt_tensor_jvp_2(self): - A = np.array([ [2.0, 0.0, 0.2], - [0.0, 1.2, 0.1], - [0.2, 0.1, 3.0] ]) - - check_grads(self.log_squared, (A,), order=1) - - - @unittest.expectedFailure - def test_log_sqrt_hessian_on_double_degenerate_eigenvalues(self): - eigvals = np.array([2., 0.5, 2.]) - C = R@np.diag(eigvals)@R.T - check_grads(jax.jacrev(TensorMath.log_sqrt), (C,), order=1, modes=['fwd'], rtol=1e-9, atol=1e-9, eps=1e-5) - - def test_eigen_sym33_non_unit(self): key = jax.random.PRNGKey(0) F = jax.random.uniform(key, (3,3), minval=1e-8, maxval=10.0) @@ -132,8 +104,137 @@ def log_squared(A): lg = TensorMath.mtk_log_sqrt(A) return np.tensordot(lg, lg) check_grads(log_squared, (C,), order=1) - - + + # log_symm tests + + def test_log_symm_scaled_identity(self): + val = 1.2 + C = np.diag(np.array([val, val, val])) + logVal = np.log(val) + self.assertArrayNear(TensorMath.log_symm(C), np.diag(np.array([logVal, logVal, logVal])), 12) + + def test_log_symm_double_eigs(self): + val1 = 2.0 + val2 = 0.5 + C = self.R@np.diag(np.array([val1, val2, val1]))@self.R.T + + log1 = np.log(val1) + log2 = np.log(val2) + diagLog = np.diag(np.array([log1, log2, log1])) + + logCExpected = self.R@diagLog@self.R.T + self.assertArrayNear(TensorMath.log_symm(C), logCExpected, 12) + + def 
test_log_symm_gradient_scaled_identity(self): + val = 1.2 + C = np.diag(np.array([val, val, val])) + check_grads(TensorMath.log_symm, (C,), order=1) + + def test_log_symm_gradient_double_eigs(self): + val1 = 2.0 + val2 = 0.5 + C = self.R@np.diag(np.array([val1, val2, val1]))@self.R.T + check_grads(TensorMath.log_symm, (C,), order=1) + + def test_log_symm_gradient_distinct_eigenvalues(self): + key = jax.random.PRNGKey(0) + F = jax.random.uniform(key, (3,3), minval=1e-8, maxval=10.0) + C = F.T@F + check_grads(TensorMath.log_symm, (C,), order=1) + + def test_log_symm_gradient_almost_double_degenerate(self): + C = self.R@np.diag(np.array([2.1, 2.1 + 1e-8, 3.0]))@self.R.T + check_grads(TensorMath.log_symm, (C,), order=1, eps=1e-10) + + # sqrt_symm_tests + + def test_sqrt_symm(self): + key = jax.random.PRNGKey(0) + F = jax.random.uniform(key, (3,3), minval=1e-8, maxval=10.0) + C = F.T@F + U = TensorMath.sqrt_symm(C) + self.assertArrayNear(U@U, C, 12) + + def test_sqrt_symm_scaled_identity(self): + val = 1.2 + C = np.diag(np.array([val, val, val])) + sqrtVal = np.sqrt(val) + self.assertArrayNear(TensorMath.sqrt_symm(C), np.diag(np.array([sqrtVal, sqrtVal, sqrtVal])), 12) + + def test_sqrt_symm_double_eigs(self): + val1 = 2.0 + val2 = 0.5 + C = self.R@np.diag(np.array([val1, val2, val1]))@self.R.T + sqrt1 = np.sqrt(val1) + sqrt = np.sqrt(val2) + diagSqrt = np.diag(np.array([sqrt1, sqrt, sqrt1])) + + sqrtCExpected = self.R@diagSqrt@self.R.T + self.assertArrayNear(TensorMath.sqrt_symm(C), sqrtCExpected, 12) + + def test_sqrt_symm_gradient_scaled_identity(self): + val = 1.2 + C = np.diag(np.array([val, val, val])) + check_grads(TensorMath.sqrt_symm, (C,), order=1) + + def test_sqrt_symm_gradient_double_eigs(self): + val1 = 2.0 + val2 = 0.5 + C = self.R@np.diag(np.array([val1, val2, val1]))@self.R.T + check_grads(TensorMath.sqrt_symm, (C,), order=1) + + def test_sqrt_symm_gradient_distinct_eigenvalues(self): + key = jax.random.PRNGKey(0) + F = jax.random.uniform(key, (3,3), minval=1e-8, maxval=10.0) + C = F.T@F + check_grads(TensorMath.sqrt_symm, (C,), order=1) + + def test_sqrt_symm_gradient_almost_double_degenerate(self): + C = self.R@np.diag(np.array([2.1, 2.1 + 1e-8, 3.0]))@self.R.T + check_grads(TensorMath.sqrt_symm, (C,), order=1, eps=1e-10) + + ### exp_symm tests + def test_exp_symm_at_identity(self): + I = TensorMath.exp_symm(np.zeros((3, 3))) + self.assertArrayNear(I, np.identity(3), 12) + + def test_exp_symm_scaled_identity(self): + val = 1.2 + C = np.diag(np.array([val, val, val])) + expVal = np.exp(val) + self.assertArrayNear(TensorMath.exp_symm(C), np.diag(np.array([expVal, expVal, expVal])), 12) + + def test_exp_symm_double_eigs(self): + val1 = 2.0 + val2 = 0.5 + C = self.R@np.diag(np.array([val1, val2, val1]))@self.R.T + exp1 = np.exp(val1) + exp2 = np.exp(val2) + diagExp = np.diag(np.array([exp1, exp2, exp1])) + expCExpected = self.R@diagExp@self.R.T + self.assertArrayNear(TensorMath.exp_symm(C), expCExpected, 12) + + def test_exp_symm_gradient_scaled_identity(self): + val = 1.2 + C = np.diag(np.array([val, val, val])) + check_grads(TensorMath.exp_symm, (C,), order=1) + + def test_exp_symm_gradient_double_eigs(self): + val1 = 2.0 + val2 = 0.5 + C = self.R@np.diag(np.array([val1, val2, val1]))@self.R.T + check_grads(TensorMath.exp_symm, (C,), order=1) + + def test_exp_symm_gradient_distinct_eigenvalues(self): + key = jax.random.PRNGKey(0) + F = jax.random.uniform(key, (3,3), minval=1e-8, maxval=10.0) + C = F.T@F + check_grads(TensorMath.exp_symm, (C,), order=1) + + def 
test_sqrt_symm_gradient_almost_double_degenerate(self): + C = self.R@np.diag(np.array([2.1, 2.1 + 1e-8, 3.0]))@self.R.T + check_grads(TensorMath.exp_symm, (C,), order=1, eps=1e-10) + ### mtk_pow tests ### @@ -236,13 +337,6 @@ def test_right_polar_decomp(self): # RU = F self.assertArrayNear(R@U, F, 14) - def test_tensor_sqrt(self): - eigvals = np.array([2., 0.5, 2.]) - C = R@np.diag(eigvals)@R.T - U = TensorMath.mtk_sqrt(C) - self.assertArrayNear(U, TensorMath.sym(U), 14) - self.assertArrayNear(U@U, C, 14) - if __name__ == '__main__': unittest.main() From 50e8b7fc23f92a84f1c88a2f18e14ca0072b881c Mon Sep 17 00:00:00 2001 From: Brandon Talamini Date: Wed, 22 Nov 2023 17:55:46 -0800 Subject: [PATCH 07/17] Write more comments --- optimism/TensorMath.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/optimism/TensorMath.py b/optimism/TensorMath.py index dbb9c313..bd49a357 100644 --- a/optimism/TensorMath.py +++ b/optimism/TensorMath.py @@ -487,6 +487,12 @@ def symmetric_matrix_function(A, func): lam, V = eigen_sym33_unit(A) return V@np.diag(func(lam))@V.T +# Helper function to define the JVP for any matrix function created from a +# scalar function func. +# To use, you must provide a function +# relative_difference: lam1, lam2 -> (func(lam1) - func(lam2))/(lam1 - lam2) +# Ideally, this should be formulated such that it does not suffer from cancellation +# error as lam1 -> lam2. def _symmetric_matrix_function_jvp_helper(func, relative_difference, primals, tangents): C, = primals Cdot, = tangents From f8273741e0e392e9432d3c128e4e4138a0286929 Mon Sep 17 00:00:00 2001 From: Brandon Talamini Date: Tue, 28 Nov 2023 11:37:09 -0800 Subject: [PATCH 08/17] Sort eigenvalues in log relative difference functino so that it is exactly symmetric --- optimism/TensorMath.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/optimism/TensorMath.py b/optimism/TensorMath.py index bd49a357..daef1da9 100644 --- a/optimism/TensorMath.py +++ b/optimism/TensorMath.py @@ -489,7 +489,7 @@ def symmetric_matrix_function(A, func): # Helper function to define the JVP for any matrix function created from a # scalar function func. -# To use, you must provide a function +# To use, you must provide the function # relative_difference: lam1, lam2 -> (func(lam1) - func(lam2))/(lam1 - lam2) # Ideally, this should be formulated such that it does not suffer from cancellation # error as lam1 -> lam2. 
@@ -503,7 +503,7 @@ def _symmetric_matrix_function_jvp_helper(func, relative_difference, primals, ta df = jax.jacfwd(func) h_diag = jax.vmap(df)(lam) def rd(x1, x2): - x2_safe = np.where(x2 == x1, x1 + 1.0, x2) + x2_safe = np.where(x2 == x1, x2 + 1.0, x2) return np.where(x2 == x1, df(x1), relative_difference(x1, x2_safe)) h12 = rd(lam[0], lam[1]) h23 = rd(lam[1], lam[2]) @@ -512,14 +512,14 @@ def rd(x1, x2): [h12, h_diag[1], h23], [h31, h23, h_diag[2]]]) W = V.T@sym(Cdot)@V - h = h*W + h *= W t00 = V[0].T@h@V[0] t11 = V[1].T@h@V[1] t22 = V[2].T@h@V[2] - t01 = V[0].T@h@V[1] - t12 = V[1].T@h@V[2] - t20 = V[2].T@h@V[0] + t01 = 0.5*(V[0].T@h@V[1] + V[1].T@h@V[0]) + t12 = 0.5*(V[1].T@h@V[2] + V[2].T@h@V[1]) + t20 = 0.5*(V[2].T@h@V[0] + V[0].T@h@V[2]) sol = np.array([ [t00, t01, t20], [t01, t11, t12], @@ -558,8 +558,10 @@ def log_symm(A): return symmetric_matrix_function(A, np.log) def _log_relative_difference(lam1, lam2): - arg = lam1/lam2 - 1 - return (np.log1p(arg)/arg)/lam2 + lams = np.array([lam1, lam2]) + i = np.argsort(np.abs(lams)) + arg = lams[i[0]]/lams[i[1]] - 1 + return (np.log1p(arg)/arg)/lams[i[1]] @log_symm.defjvp def log_symm_jvp(primals, tangents): From a0aa600779826b47de4a2bab6edaa5b901b81c64 Mon Sep 17 00:00:00 2001 From: Brandon Talamini Date: Tue, 28 Nov 2023 16:18:49 -0800 Subject: [PATCH 09/17] Fix higher order derivatives of matrix functions Before: the primal output was computed in-line in the custom jvp function. The advantage is that this avoids a second call to the eigendecomposition. The downside is that this in-line computation doesn't itself have a custom jvp, so its derivative can be wrong. After: I re-compute the primal value through the base function (e.g., log_symm), which has the custom jvp defined on it. The eigendecomposition is repeated. We can refactor to eliminate this later if profiling reveals it to be a performance bottleneck. --- optimism/TensorMath.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/optimism/TensorMath.py b/optimism/TensorMath.py index daef1da9..1fd18ce3 100644 --- a/optimism/TensorMath.py +++ b/optimism/TensorMath.py @@ -497,8 +497,12 @@ def _symmetric_matrix_function_jvp_helper(func, relative_difference, primals, ta C, = primals Cdot, = tangents + # it is tempting to compute the primal output here as + # V@np.diag(func(lam))@V.T + # and avoid the cost of doing the eigendecomp twice. + # Hoever, this will not attach the custom jvp to the primal output + # computation, making higher order derivatives wrong! 
lam, V = eigen_sym33_unit(C) - primal_out = V@np.diag(func(lam))@V.T df = jax.jacfwd(func) h_diag = jax.vmap(df)(lam) @@ -525,7 +529,7 @@ def rd(x1, x2): [t01, t11, t12], [t20, t12, t22] ]) - return primal_out, sol + return sol @jax.custom_jvp def sqrt_symm(A): @@ -537,7 +541,8 @@ def _sqrt_relative_difference(lam1, lam2): @sqrt_symm.defjvp def sqrt_symm_jvp(primals, tangents): - return _symmetric_matrix_function_jvp_helper(Math.safe_sqrt, _sqrt_relative_difference, primals, tangents) + primal_out = sqrt_symm(*primals) + return primal_out, _symmetric_matrix_function_jvp_helper(Math.safe_sqrt, _sqrt_relative_difference, primals, tangents) @jax.custom_jvp def exp_symm(A): @@ -550,7 +555,8 @@ def _exp_relative_difference(lam1, lam2): @exp_symm.defjvp def exp_symm_jvp(primals, tangents): - return _symmetric_matrix_function_jvp_helper(np.exp, _exp_relative_difference, primals, tangents) + primal_out = exp_symm(*primals) + return primal_out, _symmetric_matrix_function_jvp_helper(np.exp, _exp_relative_difference, primals, tangents) @jax.custom_jvp def log_symm(A): @@ -565,7 +571,8 @@ def _log_relative_difference(lam1, lam2): @log_symm.defjvp def log_symm_jvp(primals, tangents): - return _symmetric_matrix_function_jvp_helper(np.log, _log_relative_difference, primals, tangents) + primal_out = log_symm(*primals) + return primal_out, _symmetric_matrix_function_jvp_helper(np.log, _log_relative_difference, primals, tangents) def log_sqrt_symm(A): """Compute matrix logarithm of the square root of a symmetric positive definite matrix.""" From 70530e97d136e3c94aba85af6476f8a1c1e208fd Mon Sep 17 00:00:00 2001 From: Brandon Talamini Date: Wed, 29 Nov 2023 06:46:04 -0800 Subject: [PATCH 10/17] Add test that shows jvp of matrix power function is wrong --- optimism/test/test_TensorMath.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/optimism/test/test_TensorMath.py b/optimism/test/test_TensorMath.py index 1e7e4421..ab7b78a8 100644 --- a/optimism/test/test_TensorMath.py +++ b/optimism/test/test_TensorMath.py @@ -237,6 +237,11 @@ def test_sqrt_symm_gradient_almost_double_degenerate(self): ### mtk_pow tests ### + def test_pow_symm_gradient_distinct_eigenvalues(self): + key = jax.random.PRNGKey(0) + F = jax.random.uniform(key, (3,3), minval=1e-8, maxval=10.0) + C = F.T@F + check_grads(TensorMath.mtk_pow, (C, 0.25), order=1) def test_pow_scaled_identity(self): m = 0.25 From 165262799fb4dc4858a51349ca78cb6cf72b66ad Mon Sep 17 00:00:00 2001 From: Brandon Talamini Date: Thu, 30 Nov 2023 06:31:06 -0800 Subject: [PATCH 11/17] Add a new matrix power function with a correct gradient --- optimism/TensorMath.py | 37 +++++++++++++- optimism/test/test_TensorMath.py | 82 +++++++++++++------------------- 2 files changed, 68 insertions(+), 51 deletions(-) diff --git a/optimism/TensorMath.py b/optimism/TensorMath.py index 1fd18ce3..f6fa4869 100644 --- a/optimism/TensorMath.py +++ b/optimism/TensorMath.py @@ -576,4 +576,39 @@ def log_symm_jvp(primals, tangents): def log_sqrt_symm(A): """Compute matrix logarithm of the square root of a symmetric positive definite matrix.""" - return 0.5*log_symm(A) \ No newline at end of file + return 0.5*log_symm(A) + +@jax.custom_jvp +def pow_symm(A, m): + """Raise a symmetric matrix to a power. + + Compute m-fold iterated matrix multiplication of A. Works correctly with + negative powers, but recall that the matrix must be invertible + (or else a matrix of inf or nan will result). This function is not + differentiable in the `m` argument. 
+ + Arguments: + A: (array) symmetric matrix to raise to a power + m: (float or int) power + + .. note:: The derivative of this function is inaccurate on matrices + with nearly degenerate eigenvalues. We lack a high-quality implementation + of the relative difference of the eigenvalue function. (The derivative of + matrices with exactly equal eigenvalues is computed correctly). + """ + return symmetric_matrix_function(A, lambda x: np.power(x, m)) + +# This function loses precision when lam1 -> lam2. +# Please replace with a numerically stable implmentation if you know how! +def _pow_relative_difference(lam1, lam2, m): + lams = np.array([lam1, lam2]) + i = np.argsort(np.abs(lams)) + lam_small, lam_big = lams[i] + arg = lam_small/lam_big + return lam_big**(m-1)*(arg**m - 1)/(arg - 1) + +@pow_symm.defjvp +def pow_symm_jvp(primals, tangents): + A, m = primals + dA, dm = tangents + return pow_symm(A, m), _symmetric_matrix_function_jvp_helper(lambda x: np.power(x, m), lambda l1, l2: _pow_relative_difference(l1, l2, m), (A,), (dA,)) diff --git a/optimism/test/test_TensorMath.py b/optimism/test/test_TensorMath.py index ab7b78a8..169fb740 100644 --- a/optimism/test/test_TensorMath.py +++ b/optimism/test/test_TensorMath.py @@ -144,7 +144,7 @@ def test_log_symm_gradient_distinct_eigenvalues(self): def test_log_symm_gradient_almost_double_degenerate(self): C = self.R@np.diag(np.array([2.1, 2.1 + 1e-8, 3.0]))@self.R.T - check_grads(TensorMath.log_symm, (C,), order=1, eps=1e-10) + check_grads(TensorMath.log_symm, (C,), order=1, atol=1e-16, eps=1e-10) # sqrt_symm_tests @@ -234,71 +234,53 @@ def test_exp_symm_gradient_distinct_eigenvalues(self): def test_sqrt_symm_gradient_almost_double_degenerate(self): C = self.R@np.diag(np.array([2.1, 2.1 + 1e-8, 3.0]))@self.R.T check_grads(TensorMath.exp_symm, (C,), order=1, eps=1e-10) - - ### mtk_pow tests ### - - def test_pow_symm_gradient_distinct_eigenvalues(self): - key = jax.random.PRNGKey(0) - F = jax.random.uniform(key, (3,3), minval=1e-8, maxval=10.0) - C = F.T@F - check_grads(TensorMath.mtk_pow, (C, 0.25), order=1) - def test_pow_scaled_identity(self): - m = 0.25 - val = 1.2 - C = np.diag(np.array([val, val, val])) + # pow_symm tests + def test_pow_symm_scaled_identity(self): + val = 1.2 + C = val*np.identity(3) + m = 3 powVal = np.power(val, m) - self.assertArrayNear(TensorMath.mtk_pow(C,m), np.diag(np.array([powVal, powVal, powVal])), 12) - + self.assertArrayNear(TensorMath.pow_symm(C, m), np.diag(np.array([powVal, powVal, powVal])), 12) - def test_pow_double_eigs(self): + def test_pow_symm_double_eigs(self): + val1 = 2.0 + val2 = 0.5 + C = self.R@np.diag(np.array([val1, val2, val1]))@self.R.T m = 0.25 - val1 = 2.1 - val2 = 0.6 - C = R@np.diag(np.array([val1, val2, val1]))@R.T + pow1 = np.power(val1, m) + pow2 = np.power(val2, m) + diagPow = np.diag(np.array([pow1, pow2, pow1])) + powCExpected = self.R@diagPow@self.R.T + self.assertArrayNear(TensorMath.pow_symm(C, m), powCExpected, 12) - powVal1 = np.power(val1, m) - powVal2 = np.power(val2, m) - diagLogSqrt = np.diag(np.array([powVal1, powVal2, powVal1])) - - logSqrtCExpected = R@diagLogSqrt@R.T - - self.assertArrayNear(TensorMath.mtk_pow(C,m), logSqrtCExpected, 12) - - - def test_pow_squared_grad_scaled_identity(self): + def test_pow_symm_gradient_scaled_identity(self): val = 1.2 C = np.diag(np.array([val, val, val])) + m = 3 + check_grads(lambda A: TensorMath.pow_symm(A, m), (C,), order=1) - def pow_squared(A): - m = 0.25 - lg = TensorMath.mtk_pow(A, m) - return np.tensordot(lg, lg) - 
check_grads(pow_squared, (C,), order=1)
-
-
-    def test_pow_squared_grad_double_eigs(self):
+    def test_pow_symm_gradient_double_eigs(self):
         val1 = 2.0
         val2 = 0.5
-        C = R@np.diag(np.array([val1, val2, val1]))@R.T
-
-        def pow_squared(A):
-            m=0.25
-            lg = TensorMath.mtk_pow(A, m)
-            return np.tensordot(lg, lg)
-        check_grads(pow_squared, (C,), order=1)
+        C = self.R@np.diag(np.array([val1, val2, val1]))@self.R.T
+        m = 3
+        check_grads(lambda A: TensorMath.pow_symm(A, m), (C,), order=1)
 
-    def test_pow_squared_grad_rand(self):
+    def test_pow_symm_gradient_distinct_eigenvalues(self):
        key = jax.random.PRNGKey(0)
        F = jax.random.uniform(key, (3,3), minval=1e-8, maxval=10.0)
        C = F.T@F
+        m = 0.25
+        check_grads(lambda A: TensorMath.pow_symm(A, m), (C,), order=1)
+
+    @unittest.expectedFailure
+    def test_pow_symm_gradient_almost_double_degenerate(self):
+        C = self.R@np.diag(np.array([2.1, 2.1 + 1e-8, 3.0]))@self.R.T
+        m = 0.25
+        check_grads(lambda A: TensorMath.pow_symm(A, m), (C,), order=1, atol=1e-16, eps=1e-10)
 
-        def pow_squared(A):
-            m=0.25
-            lg = TensorMath.mtk_pow(A, m)
-            return np.tensordot(lg, lg)
-        check_grads(pow_squared, (C,), order=1)
 
     def test_determinant(self):
         A = np.array([[5/9, 4/7, 2/11],

From 46bef827cf948e301a94edf4b04dac2e6483adc0 Mon Sep 17 00:00:00 2001
From: Brandon Talamini
Date: Thu, 30 Nov 2023 09:12:26 -0800
Subject: [PATCH 12/17] Remove the broken mtk_power function

---
 optimism/TensorMath.py         | 73 +---------------------------------
 optimism/material/J2Plastic.py |  2 +-
 2 files changed, 3 insertions(+), 72 deletions(-)

diff --git a/optimism/TensorMath.py b/optimism/TensorMath.py
index f6fa4869..84cb41bf 100644
--- a/optimism/TensorMath.py
+++ b/optimism/TensorMath.py
@@ -307,8 +307,8 @@ def eigen_sym33_unit(tensor):
 # 6 and 5 indicate the polynomial order in the numerator and denominator.
 def cos_of_acos_divided_by_3(x):
-    x2 = x*x;
-    x4 = x2*x2;
+    x2 = x*x
+    x4 = x2*x2
 
     numer = 0.866025403784438713 + 2.12714890259493060 * x + \
         ( ( 1.89202064815951569 + 0.739603278343401613 * x ) * x2 + \
@@ -385,75 +385,6 @@ def mtk_log_sqrt_jvp(Cpack, Hpack):
     return logSqrtC, sol
 
 
-@partial(jax.custom_jvp, nondiff_argnums=(1,))
-def mtk_pow(A,m):
-    lam,V = eigen_sym33_unit(A)
-    return V @ np.diag(np.power(lam,m)) @ V.T
-
-
-# BT 11/22/2023
-# This implementation is wrong - it's reusing the relative_log_difference
-# function where it should be using one particular to the power function.
-# I don't know how to compute that while avoiding catastrophic
-# cancellation errors. Someone should fix this if they know how.
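For reference, the divided difference that this derivative needs is (lam1**m - lam2**m)/(lam1 - lam2). One way to evaluate it without the cancellation as lam1 -> lam2 is to factor out the larger eigenvalue and lean on expm1/log1p, which recovers the coincident-eigenvalue limit m*lam**(m - 1) smoothly. A sketch of that idea, assuming positive eigenvalues (true for the SPD matrices used in these tests) and jax.numpy imported as np; the zero-difference guard value is an arbitrary placeholder:

    def _pow_relative_difference_stable(lam1, lam2, m):
        # divided difference (lam1**m - lam2**m) / (lam1 - lam2) for lam1, lam2 > 0
        lams = np.array([lam1, lam2])
        lam_small, lam_big = lams[np.argsort(np.abs(lams))]
        d = lam_small/lam_big - 1.0                   # small when the eigenvalues are close
        d_safe = np.where(d == 0.0, 1.0, d)           # dodge 0/0; that branch is discarded below
        ratio = np.expm1(m*np.log1p(d_safe))/d_safe   # (r**m - 1)/(r - 1) with r = lam_small/lam_big
        return lam_big**(m - 1.0)*np.where(d == 0.0, m, ratio)

It keeps the same calling convention as _pow_relative_difference, so it could be swapped in directly if it holds up under the degenerate-eigenvalue tests.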
-@mtk_pow.defjvp
-def mtk_pow_jvp(m, Cpack, Hpack):
-    C, = Cpack
-    H, = Hpack
-
-    powC = mtk_pow(C,m)
-    lam,V = eigen_sym33_unit(C)
-
-    lam1 = lam[0]
-    lam2 = lam[1]
-    lam3 = lam[2]
-
-    e1 = V[:,0]
-    e2 = V[:,1]
-    e3 = V[:,2]
-
-    hHat = m * (V.T @ H @ V)
-
-    l1111 = hHat[0,0] * np.power(lam1, m-1)
-    l2222 = hHat[1,1] * np.power(lam2, m-1)
-    l3333 = hHat[2,2] * np.power(lam3, m-1)
-
-    l1212 = 0.5*(hHat[0,1]+hHat[1,0]) * relative_log_difference(lam1, lam2)
-    l2323 = 0.5*(hHat[1,2]+hHat[2,1]) * relative_log_difference(lam2, lam3)
-    l3131 = 0.5*(hHat[2,0]+hHat[0,2]) * relative_log_difference(lam3, lam1)
-
-    t00 = l1111 * e1[0] * e1[0] + l2222 * e2[0] * e2[0] + l3333 * e3[0] * e3[0] + \
-        2 * l1212 * e1[0] * e2[0] + \
-        2 * l2323 * e2[0] * e3[0] + \
-        2 * l3131 * e3[0] * e1[0]
-    t11 = l1111 * e1[1] * e1[1] + l2222 * e2[1] * e2[1] + l3333 * e3[1] * e3[1] + \
-        2 * l1212 * e1[1] * e2[1] + \
-        2 * l2323 * e2[1] * e3[1] + \
-        2 * l3131 * e3[1] * e1[1]
-    t22 = l1111 * e1[2] * e1[2] + l2222 * e2[2] * e2[2] + l3333 * e3[2] * e3[2] + \
-        2 * l1212 * e1[2] * e2[2] + \
-        2 * l2323 * e2[2] * e3[2] + \
-        2 * l3131 * e3[2] * e1[2]
-
-    t01 = l1111 * e1[0] * e1[1] + l2222 * e2[0] * e2[1] + l3333 * e3[0] * e3[1] + \
-        l1212 * (e1[0] * e2[1] + e2[0] * e1[1]) + \
-        l2323 * (e2[0] * e3[1] + e3[0] * e2[1]) + \
-        l3131 * (e3[0] * e1[1] + e1[0] * e3[1])
-    t12 = l1111 * e1[1] * e1[2] + l2222 * e2[1] * e2[2] + l3333 * e3[1] * e3[2] + \
-        l1212 * (e1[1] * e2[2] + e2[1] * e1[2]) + \
-        l2323 * (e2[1] * e3[2] + e3[1] * e2[2]) + \
-        l3131 * (e3[1] * e1[2] + e1[1] * e3[2])
-    t20 = l1111 * e1[2] * e1[0] + l2222 * e2[2] * e2[0] + l3333 * e3[2] * e3[0] + \
-        l1212 * (e1[2] * e2[0] + e2[2] * e1[0]) + \
-        l2323 * (e2[2] * e3[0] + e3[2] * e2[0]) + \
-        l3131 * (e3[2] * e1[0] + e1[2] * e3[0])
-
-    sol = np.array([ [t00, t01, t20],
-                     [t01, t11, t12],
-                     [t20, t12, t22] ])
-
-    return powC, sol
-
 
 def relative_log_difference_taylor(lam1, lam2):
     # Compute a more accurate (mtk::log(lam1) - log(lam2)) / (lam1-lam2) as lam1 -> lam2
diff --git a/optimism/material/J2Plastic.py b/optimism/material/J2Plastic.py
index a7458cfe..d76a015c 100644
--- a/optimism/material/J2Plastic.py
+++ b/optimism/material/J2Plastic.py
@@ -237,6 +237,6 @@ def compute_elastic_linear_strain(dispGrad, state):
 def compute_elastic_seth_hill_strain(dispGrad, state):
     m=0.25
     C = dispGrad.T@dispGrad
-    strain = (TensorMath.mtk_pow(C,m) - np.identity(3)) / (2*m)
+    strain = (TensorMath.pow_symm(C,m) - np.identity(3)) / (2*m)
     plasticStrain = state[PLASTIC_STRAIN].reshape((3,3))
     return strain - plasticStrain

From fa2f3e193402837bdc3de0ba0a282033c9cda5c7 Mon Sep 17 00:00:00 2001
From: Brandon Talamini
Date: Thu, 30 Nov 2023 09:17:50 -0800
Subject: [PATCH 13/17] Hide internal helper functions

Put a leading underscore on functions meant for internal use. Most
Python tools will ignore these when reporting contents of a module.
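As a concrete illustration of that convention (a sketch with made-up names, not the module's real contents): a wildcard import only picks up the un-prefixed names unless __all__ says otherwise, and help()/pydoc hide the underscored ones by default.

    # demo_module.py (hypothetical)
    def log_symm(A):           # public: exported by `from demo_module import *`
        return A

    def _log_symm_jvp(p, t):   # internal: skipped by the wildcard import and hidden by help()
        return p, t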
--- optimism/TensorMath.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/optimism/TensorMath.py b/optimism/TensorMath.py index 84cb41bf..566344c4 100644 --- a/optimism/TensorMath.py +++ b/optimism/TensorMath.py @@ -328,7 +328,7 @@ def mtk_log_sqrt(A): @mtk_log_sqrt.defjvp -def mtk_log_sqrt_jvp(Cpack, Hpack): +def _mtk_log_sqrt_jvp(Cpack, Hpack): C, = Cpack H, = Hpack @@ -349,9 +349,9 @@ def mtk_log_sqrt_jvp(Cpack, Hpack): l2222 = hHat[1,1] / lam2 l3333 = hHat[2,2] / lam3 - l1212 = 0.5*(hHat[0,1]+hHat[1,0]) * relative_log_difference(lam1, lam2) - l2323 = 0.5*(hHat[1,2]+hHat[2,1]) * relative_log_difference(lam2, lam3) - l3131 = 0.5*(hHat[2,0]+hHat[0,2]) * relative_log_difference(lam3, lam1) + l1212 = 0.5*(hHat[0,1]+hHat[1,0]) * _relative_log_difference(lam1, lam2) + l2323 = 0.5*(hHat[1,2]+hHat[2,1]) * _relative_log_difference(lam2, lam3) + l3131 = 0.5*(hHat[2,0]+hHat[0,2]) * _relative_log_difference(lam3, lam1) t00 = l1111 * e1[0] * e1[0] + l2222 * e2[0] * e2[0] + l3333 * e3[0] * e3[0] + \ 2 * l1212 * e1[0] * e2[0] + \ @@ -386,7 +386,7 @@ def mtk_log_sqrt_jvp(Cpack, Hpack): return logSqrtC, sol -def relative_log_difference_taylor(lam1, lam2): +def _relative_log_difference_taylor(lam1, lam2): # Compute a more accurate (mtk::log(lam1) - log(lam2)) / (lam1-lam2) as lam1 -> lam2 third2 = 2.0 / 3.0 fifth2 = 2.0 / 5.0 @@ -401,16 +401,16 @@ def relative_log_difference_taylor(lam1, lam2): return (2.0 + third2 * frac2 + fifth2 * frac4 + seventh2 * frac4 * frac2 + ninth2 * frac4 * frac4) / (lam1 + lam2) -def relative_log_difference_no_tolerance_check(lam1, lam2): +def _relative_log_difference_no_tolerance_check(lam1, lam2): return np.log(lam1 / lam2) / (lam1 - lam2) -def relative_log_difference(lam1, lam2): +def _relative_log_difference(lam1, lam2): haveLargeDiff = np.abs(lam1 - lam2) > 0.05 * np.minimum(lam1, lam2) lamFake = np.where(haveLargeDiff, lam2, 2.0*lam2) return np.where(haveLargeDiff, - relative_log_difference_no_tolerance_check(lam1, lamFake), - relative_log_difference_taylor(lam1, lam2)) + _relative_log_difference_no_tolerance_check(lam1, lamFake), + _relative_log_difference_taylor(lam1, lam2)) def symmetric_matrix_function(A, func): @@ -471,7 +471,7 @@ def _sqrt_relative_difference(lam1, lam2): return 1/(np.sqrt(lam1) + np.sqrt(lam2)) @sqrt_symm.defjvp -def sqrt_symm_jvp(primals, tangents): +def _sqrt_symm_jvp(primals, tangents): primal_out = sqrt_symm(*primals) return primal_out, _symmetric_matrix_function_jvp_helper(Math.safe_sqrt, _sqrt_relative_difference, primals, tangents) @@ -485,7 +485,7 @@ def _exp_relative_difference(lam1, lam2): return np.exp(lam2)*np.expm1(arg)/arg @exp_symm.defjvp -def exp_symm_jvp(primals, tangents): +def _exp_symm_jvp(primals, tangents): primal_out = exp_symm(*primals) return primal_out, _symmetric_matrix_function_jvp_helper(np.exp, _exp_relative_difference, primals, tangents) @@ -501,7 +501,7 @@ def _log_relative_difference(lam1, lam2): return (np.log1p(arg)/arg)/lams[i[1]] @log_symm.defjvp -def log_symm_jvp(primals, tangents): +def _log_symm_jvp(primals, tangents): primal_out = log_symm(*primals) return primal_out, _symmetric_matrix_function_jvp_helper(np.log, _log_relative_difference, primals, tangents) @@ -539,7 +539,7 @@ def _pow_relative_difference(lam1, lam2, m): return lam_big**(m-1)*(arg**m - 1)/(arg - 1) @pow_symm.defjvp -def pow_symm_jvp(primals, tangents): +def _pow_symm_jvp(primals, tangents): A, m = primals dA, dm = tangents return pow_symm(A, m), _symmetric_matrix_function_jvp_helper(lambda x: 
np.power(x, m), lambda l1, l2: _pow_relative_difference(l1, l2, m), (A,), (dA,)) From 0c84e5522b6cc0be02c9f7d219d9862821b6ab52 Mon Sep 17 00:00:00 2001 From: Brandon Talamini Date: Thu, 30 Nov 2023 15:37:33 -0800 Subject: [PATCH 14/17] Remove mtk_log in favor of new implementation Replace all calls except in the new viscoelastic model. Changes are about to merge there and I want to handle the conflicts separately. --- examples/adjoint_with_ivs/parameterized_j2.py | 2 +- .../parameterized_linear_elastic.py | 2 +- optimism/TensorMath.py | 66 +------------------ optimism/material/J2Plastic.py | 21 +++--- optimism/material/LinearElastic.py | 12 ++-- .../phasefield/PhaseFieldLorentzPlastic.py | 19 ++---- optimism/phasefield/PhaseFieldThreshold.py | 10 +-- .../phasefield/PhaseFieldThresholdPlastic.py | 1 - optimism/test/test_TensorMath.py | 56 ---------------- 9 files changed, 32 insertions(+), 157 deletions(-) diff --git a/examples/adjoint_with_ivs/parameterized_j2.py b/examples/adjoint_with_ivs/parameterized_j2.py index ff1fb778..0c71cb66 100644 --- a/examples/adjoint_with_ivs/parameterized_j2.py +++ b/examples/adjoint_with_ivs/parameterized_j2.py @@ -220,7 +220,7 @@ def compute_elastic_logarithmic_strain(dispGrad, state): Je = np.linalg.det(FeT) # = J since this model is isochoric plasticity traceEe = np.log(Je) CeIso = Je**(-2./3.)*FeT@FeT.T - EeDev = TensorMath.mtk_log_sqrt(CeIso) + EeDev = TensorMath.log_sqrt_symm(CeIso) return EeDev + traceEe/3.0*np.identity(3) diff --git a/examples/adjoint_with_ivs/parameterized_linear_elastic.py b/examples/adjoint_with_ivs/parameterized_linear_elastic.py index 11c022f1..3e10ef46 100644 --- a/examples/adjoint_with_ivs/parameterized_linear_elastic.py +++ b/examples/adjoint_with_ivs/parameterized_linear_elastic.py @@ -67,5 +67,5 @@ def log_strain(dispGrad): J = np.linalg.det(F) traceStrain = np.log(J) CIso = J**(-2.0/3.0)*F.T@F - devStrain = TensorMath.mtk_log_sqrt(CIso) + devStrain = TensorMath.log_sqrt_symm(CIso) return devStrain + traceStrain/3.0*np.identity(3) diff --git a/optimism/TensorMath.py b/optimism/TensorMath.py index 566344c4..12037db9 100644 --- a/optimism/TensorMath.py +++ b/optimism/TensorMath.py @@ -321,70 +321,8 @@ def cos_of_acos_divided_by_3(x): return numer/denom -@jax.custom_jvp -def mtk_log_sqrt(A): - lam,V = eigen_sym33_unit(A) - return V @ np.diag(0.5*np.log(lam)) @ V.T - - -@mtk_log_sqrt.defjvp -def _mtk_log_sqrt_jvp(Cpack, Hpack): - C, = Cpack - H, = Hpack - - logSqrtC = mtk_log_sqrt(C) - lam,V = eigen_sym33_unit(C) - - lam1 = lam[0] - lam2 = lam[1] - lam3 = lam[2] - - e1 = V[:,0] - e2 = V[:,1] - e3 = V[:,2] - - hHat = 0.5 * (V.T @ H @ V) - - l1111 = hHat[0,0] / lam1 - l2222 = hHat[1,1] / lam2 - l3333 = hHat[2,2] / lam3 - - l1212 = 0.5*(hHat[0,1]+hHat[1,0]) * _relative_log_difference(lam1, lam2) - l2323 = 0.5*(hHat[1,2]+hHat[2,1]) * _relative_log_difference(lam2, lam3) - l3131 = 0.5*(hHat[2,0]+hHat[0,2]) * _relative_log_difference(lam3, lam1) - - t00 = l1111 * e1[0] * e1[0] + l2222 * e2[0] * e2[0] + l3333 * e3[0] * e3[0] + \ - 2 * l1212 * e1[0] * e2[0] + \ - 2 * l2323 * e2[0] * e3[0] + \ - 2 * l3131 * e3[0] * e1[0] - t11 = l1111 * e1[1] * e1[1] + l2222 * e2[1] * e2[1] + l3333 * e3[1] * e3[1] + \ - 2 * l1212 * e1[1] * e2[1] + \ - 2 * l2323 * e2[1] * e3[1] + \ - 2 * l3131 * e3[1] * e1[1] - t22 = l1111 * e1[2] * e1[2] + l2222 * e2[2] * e2[2] + l3333 * e3[2] * e3[2] + \ - 2 * l1212 * e1[2] * e2[2] + \ - 2 * l2323 * e2[2] * e3[2] + \ - 2 * l3131 * e3[2] * e1[2] - - t01 = l1111 * e1[0] * e1[1] + l2222 * e2[0] * e2[1] + l3333 * e3[0] * 
e3[1] + \ - l1212 * (e1[0] * e2[1] + e2[0] * e1[1]) + \ - l2323 * (e2[0] * e3[1] + e3[0] * e2[1]) + \ - l3131 * (e3[0] * e1[1] + e1[0] * e3[1]) - t12 = l1111 * e1[1] * e1[2] + l2222 * e2[1] * e2[2] + l3333 * e3[1] * e3[2] + \ - l1212 * (e1[1] * e2[2] + e2[1] * e1[2]) + \ - l2323 * (e2[1] * e3[2] + e3[1] * e2[2]) + \ - l3131 * (e3[1] * e1[2] + e1[1] * e3[2]) - t20 = l1111 * e1[2] * e1[0] + l2222 * e2[2] * e2[0] + l3333 * e3[2] * e3[0] + \ - l1212 * (e1[2] * e2[0] + e2[2] * e1[0]) + \ - l2323 * (e2[2] * e3[0] + e3[2] * e2[0]) + \ - l3131 * (e3[2] * e1[0] + e1[2] * e3[0]) - - sol = np.array([ [t00, t01, t20], - [t01, t11, t12], - [t20, t12, t22] ]) - - return logSqrtC, sol - +# This is an alternative way of computing the relative difference in the log of the eigenvalues. +# We're using a different method, but let's keep these functions for reference. def _relative_log_difference_taylor(lam1, lam2): # Compute a more accurate (mtk::log(lam1) - log(lam2)) / (lam1-lam2) as lam1 -> lam2 diff --git a/optimism/material/J2Plastic.py b/optimism/material/J2Plastic.py index d76a015c..58dad33d 100644 --- a/optimism/material/J2Plastic.py +++ b/optimism/material/J2Plastic.py @@ -1,6 +1,5 @@ import jax import jax.numpy as np -from jax.scipy import linalg from optimism.material.MaterialModel import MaterialModel from optimism.material import Hardening @@ -107,7 +106,7 @@ def compute_state_new_finite_deformations(dispGrad, stateOld, dt, props, hardeni stateInc = compute_state_increment(elasticTrialStrain, stateOld, dt, props, hardening_model) eqpsNew = stateOld[EQPS] + stateInc[EQPS] FpOld = np.reshape(stateOld[PLASTIC_DISTORTION], (3,3)) - FpNew = linalg.expm(stateInc[PLASTIC_DISTORTION].reshape((3,3)))@FpOld + FpNew = TensorMath.exp_symm(stateInc[PLASTIC_DISTORTION].reshape((3,3)))@FpOld return np.hstack((eqpsNew, FpNew.ravel())) @@ -212,19 +211,17 @@ def update_state(elasticTrialStrain, stateOld, dt, props, hardening_model): def compute_elastic_logarithmic_strain(dispGrad, state): - F = dispGrad + np.eye(3) - Fp = state[PLASTIC_DISTORTION].reshape((3,3)) - FeT = linalg.solve(Fp.T, F.T) - # Compute the deviatoric and spherical parts separately # to preserve the sign of J. Want to let solver sense and # deal with inverted elements. - - Je = np.linalg.det(FeT) # = J since this model is isochoric plasticity - traceEe = np.log(Je) - CeIso = Je**(-2./3.)*FeT@FeT.T - EeDev = TensorMath.mtk_log_sqrt(CeIso) - return EeDev + traceEe/3.0*np.identity(3) + Je_minus_1 = TensorMath.detpIm1(dispGrad) # J = Je since this model is isochoric plasticity + traceEe = np.log1p(Je_minus_1) + F = dispGrad + np.eye(3) + Fp = state[PLASTIC_DISTORTION].reshape((3,3)) + Fe = F@TensorMath.inv(Fp) + Ce = Fe.T@Fe + Ee = TensorMath.log_sqrt_symm(Ce) + return TensorMath.dev(Ee) + traceEe/3.0*np.identity(3) def compute_elastic_linear_strain(dispGrad, state): diff --git a/optimism/material/LinearElastic.py b/optimism/material/LinearElastic.py index 9e2f6c0d..290cbb7d 100644 --- a/optimism/material/LinearElastic.py +++ b/optimism/material/LinearElastic.py @@ -76,9 +76,11 @@ def linear_strain(dispGrad): def log_strain(dispGrad): + # Compute the deviatoric and spherical parts separately + # to preserve the sign of J. 
+ Jm1 = TensorMath.detpIm1(dispGrad) + traceStrain = np.log1p(Jm1) F = dispGrad + np.eye(3) - J = np.linalg.det(F) - traceStrain = np.log(J) - CIso = J**(-2.0/3.0)*F.T@F - devStrain = TensorMath.mtk_log_sqrt(CIso) - return devStrain + traceStrain/3.0*np.identity(3) + C = F.T@F + strain = TensorMath.log_sqrt_symm(C) + return TensorMath.dev(strain) + traceStrain/3.0*np.identity(3) diff --git a/optimism/phasefield/PhaseFieldLorentzPlastic.py b/optimism/phasefield/PhaseFieldLorentzPlastic.py index 2a83b12f..139d32ed 100644 --- a/optimism/phasefield/PhaseFieldLorentzPlastic.py +++ b/optimism/phasefield/PhaseFieldLorentzPlastic.py @@ -1,6 +1,3 @@ -from sys import float_info -from jax.scipy.linalg import solve, expm - from optimism.JaxConfig import * from optimism.phasefield.PhaseFieldMaterialModel import MaterialModel from optimism.material.J2Plastic import compute_flow_direction @@ -136,7 +133,7 @@ def compute_state_new_finite_deformations(dispGrad, phase, phaseGrad, stateOld, stateInc = compute_state_increment(elasticTrialStrain, phase, stateOld, dt, props, hardeningModel) eqpsNew = stateOld[STATE_EQPS] + stateInc[STATE_EQPS] FpOld = np.reshape(stateOld[STATE_PLASTIC_STRAIN], (3,3)) - FpNew = expm(stateInc[STATE_PLASTIC_STRAIN].reshape((3,3)))@FpOld + FpNew = TensorMath.exp_symm(stateInc[STATE_PLASTIC_STRAIN].reshape((3,3)))@FpOld elasticStrainNew = elasticTrialStrain - stateInc[STATE_PLASTIC_STRAIN].reshape((3,3)) return np.hstack((eqpsNew, FpNew.ravel())) @@ -277,17 +274,15 @@ def compute_elastic_linear_strain(dispGrad, plasticStrain): def compute_elastic_logarithmic_strain(dispGrad, Fp): - F = dispGrad + np.eye(3) - FeT = solve(Fp.T, F.T) - # Compute the deviatoric and spherical parts separately # to preserve the sign of J. Want to let solver sense and # deal with inverted elements. 
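The switch to detpIm1 and log1p a few lines below is a floating-point accuracy measure: for small displacement gradients, det(F) is 1 plus a tiny number, so forming det(F) first and then taking its log discards several significant digits of J - 1 before the logarithm is even applied. A small illustration (a sketch, assuming jax.numpy as np and that TensorMath.detpIm1 builds det(I + H) - 1 directly from the entries of H, as its usage here suggests):

    H = 1e-10*np.array([[1.0,  0.3, 0.0],
                        [0.3, -2.0, 0.5],
                        [0.0,  0.5, 4.0]])                  # tiny displacement gradient; log(det(I + H)) ~ trace(H) = 3e-10
    log_naive  = np.log(np.linalg.det(np.identity(3) + H))  # det rounds to 1 + 3e-10, so digits of J - 1 are lost up front
    log_stable = np.log1p(TensorMath.detpIm1(H))            # works with J - 1 itself, keeping relative accuracy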
- - Je = np.linalg.det(FeT) # = J since this model is isochoric plasticity - traceEe = np.log(Je) - CeIso = Je**(-2./3.)*FeT@FeT.T - EeDev = TensorMath.mtk_log_sqrt(CeIso) + Je_minus_1 = TensorMath.detpIm1(dispGrad) # J = Je since this model is isochoric plasticity + traceEe = np.log1p(Je_minus_1) + F = dispGrad + np.eye(3) + Fe = F@TensorMath.inv(Fp) + Ce = Fe.T@Fe + EeDev = TensorMath.dev(TensorMath.log_sqrt_symm(Ce)) return EeDev + traceEe/3.0*np.identity(3) diff --git a/optimism/phasefield/PhaseFieldThreshold.py b/optimism/phasefield/PhaseFieldThreshold.py index aafcec3a..1482c15a 100644 --- a/optimism/phasefield/PhaseFieldThreshold.py +++ b/optimism/phasefield/PhaseFieldThreshold.py @@ -140,8 +140,8 @@ def compute_linear_strain(dispGrad): def compute_logarithmic_strain(dispGrad): F = dispGrad + np.identity(3) - J = np.linalg.det(F) - traceE = np.log(J) - CIso = J**(-2.0/3.0)*F.T@F - devE = TensorMath.mtk_log_sqrt(CIso) - return devE + traceE/3.0*np.identity(3) + Jm1 = TensorMath.detpIm1(dispGrad) + traceE = np.log1p(Jm1) + C = F.T@F + E = TensorMath.log_sqrt_symm(C) + return TensorMath.dev(E) + traceE/3.0*np.identity(3) diff --git a/optimism/phasefield/PhaseFieldThresholdPlastic.py b/optimism/phasefield/PhaseFieldThresholdPlastic.py index f44e00a1..9c6b2a7a 100644 --- a/optimism/phasefield/PhaseFieldThresholdPlastic.py +++ b/optimism/phasefield/PhaseFieldThresholdPlastic.py @@ -1,4 +1,3 @@ -from sys import float_info from jax.lax import while_loop from jax import custom_jvp diff --git a/optimism/test/test_TensorMath.py b/optimism/test/test_TensorMath.py index 169fb740..67033dc9 100644 --- a/optimism/test/test_TensorMath.py +++ b/optimism/test/test_TensorMath.py @@ -49,62 +49,6 @@ def test_eigen_sym33_non_unit_degenerate_case(self): self.assertArrayNear(C, vecs@np.diag(d)@vecs.T, 13) self.assertArrayNear(vecs@vecs.T, np.identity(3), 13) - - ### mtk_log_sqrt tests ### - - - def test_log_sqrt_scaled_identity(self): - val = 1.2 - C = np.diag(np.array([val, val, val])) - - logSqrtVal = np.log(np.sqrt(val)) - self.assertArrayNear(TensorMath.mtk_log_sqrt(C), np.diag(np.array([logSqrtVal, logSqrtVal, logSqrtVal])), 12) - - - def test_log_sqrt_double_eigs(self): - val1 = 2.0 - val2 = 0.5 - C = R@np.diag(np.array([val1, val2, val1]))@R.T - - logSqrt1 = np.log(np.sqrt(val1)) - logSqrt2 = np.log(np.sqrt(val2)) - diagLogSqrt = np.diag(np.array([logSqrt1, logSqrt2, logSqrt1])) - - logSqrtCExpected = R@diagLogSqrt@R.T - self.assertArrayNear(TensorMath.mtk_log_sqrt(C), logSqrtCExpected, 12) - - - def test_log_sqrt_squared_grad_scaled_identity(self): - val = 1.2 - C = np.diag(np.array([val, val, val])) - - def log_squared(A): - lg = TensorMath.mtk_log_sqrt(A) - return np.tensordot(lg, lg) - check_grads(log_squared, (C,), order=1) - - - def test_log_sqrt_squared_grad_double_eigs(self): - val1 = 2.0 - val2 = 0.5 - C = R@np.diag(np.array([val1, val2, val1]))@R.T - - def log_squared(A): - lg = TensorMath.mtk_log_sqrt(A) - return np.tensordot(lg, lg) - check_grads(log_squared, (C,), order=1) - - - def test_log_sqrt_squared_grad_rand(self): - key = jax.random.PRNGKey(0) - F = jax.random.uniform(key, (3,3), minval=1e-8, maxval=10.0) - C = F.T@F - - def log_squared(A): - lg = TensorMath.mtk_log_sqrt(A) - return np.tensordot(lg, lg) - check_grads(log_squared, (C,), order=1) - # log_symm tests def test_log_symm_scaled_identity(self): From d81daeea21ac5fcfa5e7bc1bce52f29aec47b624 Mon Sep 17 00:00:00 2001 From: Brandon Talamini Date: Fri, 1 Dec 2023 15:07:21 -0800 Subject: [PATCH 15/17] Comments and formatting 
--- optimism/TensorMath.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/optimism/TensorMath.py b/optimism/TensorMath.py index 12037db9..928e8bd1 100644 --- a/optimism/TensorMath.py +++ b/optimism/TensorMath.py @@ -1,6 +1,5 @@ """Provide differentiable operations on 3x3 tensors.""" -from functools import partial import jax import jax.numpy as np @@ -59,8 +58,8 @@ def norm_of_deviator_squared(tensor): def norm_of_deviator(tensor): return norm( deviator(tensor) ) -def mises_invariant(stress): - return np.sqrt(1.5)*norm_of_deviator(stress) +def mises_invariant(S): + return np.sqrt(1.5)*norm_of_deviator(S) def triaxiality(A): mean_normal = trace(A)/3.0 @@ -413,6 +412,7 @@ def _sqrt_symm_jvp(primals, tangents): primal_out = sqrt_symm(*primals) return primal_out, _symmetric_matrix_function_jvp_helper(Math.safe_sqrt, _sqrt_relative_difference, primals, tangents) + @jax.custom_jvp def exp_symm(A): """Compute the matrix exponential of a symmetric matrix.""" @@ -427,6 +427,7 @@ def _exp_symm_jvp(primals, tangents): primal_out = exp_symm(*primals) return primal_out, _symmetric_matrix_function_jvp_helper(np.exp, _exp_relative_difference, primals, tangents) + @jax.custom_jvp def log_symm(A): """Compute the matrix logarithm of a symmetric positive definite matrix.""" @@ -443,23 +444,26 @@ def _log_symm_jvp(primals, tangents): primal_out = log_symm(*primals) return primal_out, _symmetric_matrix_function_jvp_helper(np.log, _log_relative_difference, primals, tangents) + def log_sqrt_symm(A): """Compute matrix logarithm of the square root of a symmetric positive definite matrix.""" return 0.5*log_symm(A) + @jax.custom_jvp def pow_symm(A, m): """Raise a symmetric matrix to a power. Compute m-fold iterated matrix multiplication of A. Works correctly with negative powers, but recall that the matrix must be invertible - (or else a matrix of inf or nan will result). This function is not - differentiable in the `m` argument. + (or else a matrix of inf or nan will result). Arguments: A: (array) symmetric matrix to raise to a power m: (float or int) power + .. note:: This function is not differentiable in the `m` argument. + .. note:: The derivative of this function is inaccurate on matrices with nearly degenerate eigenvalues. We lack a high-quality implementation of the relative difference of the eigenvalue function. 
(The derivative of From 3d0cee7e34b08534b33b5d4b0bc9ae10cb0d251c Mon Sep 17 00:00:00 2001 From: Brandon Talamini Date: Wed, 17 Jan 2024 06:03:31 -0800 Subject: [PATCH 16/17] Replace calls to tensor log sqrt with new function name --- optimism/J2PlasticPhaseField.py | 2 +- optimism/material/HyperViscoelastic.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/optimism/J2PlasticPhaseField.py b/optimism/J2PlasticPhaseField.py index 248ccc6b..cb90e4fa 100644 --- a/optimism/J2PlasticPhaseField.py +++ b/optimism/J2PlasticPhaseField.py @@ -84,7 +84,7 @@ def compute_logarithmic_elastic_strain(dispGrad, state): Fp = state[PLASTIC_STRAIN].reshape((3,3)) FeT = solve(Fp.T, F.T) Ce = FeT@FeT.T - return TensorMath.mtk_log_sqrt(Ce) + return TensorMath.log_sqrt_symm(Ce) def compute_state_increment(elasticTrialStrain, stateOld, props): diff --git a/optimism/material/HyperViscoelastic.py b/optimism/material/HyperViscoelastic.py index 38c6eea9..47334e87 100644 --- a/optimism/material/HyperViscoelastic.py +++ b/optimism/material/HyperViscoelastic.py @@ -105,4 +105,4 @@ def _compute_elastic_logarithmic_strain(dispGrad, stateOld): Fv_old = stateOld.reshape((3, 3)) Fe_trial = F @ np.linalg.inv(Fv_old) - return TensorMath.mtk_log_sqrt(Fe_trial.T @ Fe_trial) \ No newline at end of file + return TensorMath.log_sqrt_symm(Fe_trial.T @ Fe_trial) \ No newline at end of file From ceaea061d81fde23a375a73a244f710b5f0597ce Mon Sep 17 00:00:00 2001 From: Brandon Talamini Date: Wed, 17 Jan 2024 10:33:19 -0800 Subject: [PATCH 17/17] Fix another call to a function that I renamed --- optimism/material/HyperViscoelastic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimism/material/HyperViscoelastic.py b/optimism/material/HyperViscoelastic.py index 47334e87..9380da97 100644 --- a/optimism/material/HyperViscoelastic.py +++ b/optimism/material/HyperViscoelastic.py @@ -97,7 +97,7 @@ def _compute_state_increment(elasticStrain, dt, props): tau = props[PROPS_TAU] integration_factor = 1. / (1. + dt / tau) - Ee_dev = TensorMath.compute_deviatoric_tensor(elasticStrain) + Ee_dev = TensorMath.dev(elasticStrain) return dt * integration_factor * Ee_dev / tau # dt * D def _compute_elastic_logarithmic_strain(dispGrad, stateOld):
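With those renames in place, the finite-deformation models in this series all compute the Hencky (logarithmic) strain through the same entry point. A usage sketch (assuming the usual `from optimism import TensorMath` import and jax.numpy as np):

    dispGrad = np.array([[0.10,  0.02, 0.00],
                         [0.00, -0.05, 0.03],
                         [0.00,  0.00, 0.08]])
    F = dispGrad + np.identity(3)
    C = F.T@F
    E = TensorMath.log_sqrt_symm(C)     # Hencky strain log(U), with U = sqrt(C)
    E_alt = 0.5*TensorMath.log_symm(C)  # identical, since log(sqrt(C)) = 0.5*log(C)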