From 4485ad5d811600fa32f9d8a45d1a2e0fdc90624c Mon Sep 17 00:00:00 2001
From: David Jones
Date: Sat, 27 Jan 2024 21:33:59 -0600
Subject: [PATCH] one fix for NaNs in gradient

---
 saltshaker/training/optimizers/gradientdescent.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/saltshaker/training/optimizers/gradientdescent.py b/saltshaker/training/optimizers/gradientdescent.py
index 1dffb7c3..ec7284d2 100644
--- a/saltshaker/training/optimizers/gradientdescent.py
+++ b/saltshaker/training/optimizers/gradientdescent.py
@@ -465,7 +465,8 @@ def rpropiter(self,X, Xprev, prevloss,prevsign, learningrates,*args,**kwargs):
         """
         lossval,grad= self.lossfunction(X,*args,**kwargs, diff='valueandgrad')
 
-        sign=jnp.sign(grad)
+        # if gradient is NaN, jax had some trouble...
+        sign=jnp.nan_to_num(jnp.sign(grad))
         indicatorvector= prevsign *sign
         greater= indicatorvector >0
 
@@ -478,8 +479,8 @@ def rpropiter(self,X, Xprev, prevloss,prevsign, learningrates,*args,**kwargs):
         Xnew=jnp.select( [less,greatereq],
             [ lax.cond(lossval>prevloss, lambda x,y:x , lambda x,y: y, Xprev, X),
             X-(sign *learningrates) ])
-        if len(Xnew[Xnew != Xnew]):
-            import pdb; pdb.set_trace()
+        #if len(Xnew[Xnew != Xnew]):
+        #    import pdb; pdb.set_trace()
         #Set sign to 0 after a previous change
         sign= (sign * greatereq)
         return jnp.clip(Xnew,*self.Xbounds), lossval, sign, grad, learningrates
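
Note (not part of the patch): a minimal sketch of why wrapping jnp.sign in
jnp.nan_to_num helps here. jnp.sign propagates NaN, so a NaN anywhere in the
gradient would poison the Rprop sign comparison (prevsign * sign);
jnp.nan_to_num maps that NaN sign to 0, which the existing select logic treats
as "no sign change", leaving the affected parameter where it is for that
iteration. The array values below are illustrative assumptions, not values
taken from SALTShaker.

    import jax.numpy as jnp

    # Toy gradient with one NaN entry (assumed values, for illustration only).
    grad = jnp.array([0.5, -2.0, jnp.nan])

    raw_sign = jnp.sign(grad)             # -> [ 1., -1., nan]; the NaN propagates
    safe_sign = jnp.nan_to_num(raw_sign)  # -> [ 1., -1.,  0.]; NaN replaced with 0

    prevsign = jnp.array([1.0, 1.0, -1.0])
    indicatorvector = prevsign * safe_sign
    # The NaN entry contributes 0, so it is neither "greater" (>0) nor "less" (<0):
    # with sign == 0 the update X - sign*learningrates reduces to X, i.e. that
    # parameter is simply held fixed on this iteration instead of turning into NaN.
    print(raw_sign, safe_sign, indicatorvector)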