
Commit

one fix for NaNs in gradient
David Jones committed Jan 28, 2024
1 parent 62393cc commit 4485ad5
Showing 1 changed file with 4 additions and 3 deletions.
saltshaker/training/optimizers/gradientdescent.py (7 changes: 4 additions & 3 deletions)
@@ -465,7 +465,8 @@ def rpropiter(self,X, Xprev, prevloss,prevsign, learningrates,*args,**kwargs):
         """
         lossval,grad= self.lossfunction(X,*args,**kwargs, diff='valueandgrad')
 
-        sign=jnp.sign(grad)
+        # if gradient is NaN, jax had some trouble...
+        sign=jnp.nan_to_num(jnp.sign(grad))
         indicatorvector= prevsign *sign
 
         greater= indicatorvector >0
@@ -478,8 +479,8 @@ def rpropiter(self,X, Xprev, prevloss,prevsign, learningrates,*args,**kwargs):
         Xnew=jnp.select( [less,greatereq], [ lax.cond(lossval>prevloss, lambda x,y:x , lambda x,y: y, Xprev, X),
                          X-(sign *learningrates)
                          ])
-        if len(Xnew[Xnew != Xnew]):
-            import pdb; pdb.set_trace()
+        #if len(Xnew[Xnew != Xnew]):
+        #    import pdb; pdb.set_trace()
         #Set sign to 0 after a previous change
         sign= (sign * greatereq)
         return jnp.clip(Xnew,*self.Xbounds), lossval, sign, grad, learningrates
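
Note: the snippet below is not part of the commit; it is a minimal sketch of why zeroing the NaN sign helps. jnp.sign propagates NaN, so a single NaN gradient component would otherwise flow straight into the Rprop step X - sign * learningrates; wrapping it in jnp.nan_to_num maps that NaN sign to 0, so the affected parameter is simply held at its current value for the iteration. The example arrays are illustrative, not taken from the repository.

    import jax.numpy as jnp

    # hypothetical gradient with one NaN component from a bad jax evaluation
    grad = jnp.array([0.5, -2.0, jnp.nan])
    X = jnp.array([1.0, 1.0, 1.0])
    learningrates = jnp.array([0.1, 0.1, 0.1])

    raw_sign = jnp.sign(grad)                   # [ 1., -1., nan]  NaN propagates
    safe_sign = jnp.nan_to_num(jnp.sign(grad))  # [ 1., -1.,  0.]  NaN mapped to 0

    print(X - raw_sign * learningrates)         # [0.9, 1.1, nan]  the update is poisoned
    print(X - safe_sign * learningrates)        # [0.9, 1.1, 1. ]  that parameter is left unchanged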
