Adding fixes to tfp_files to avoid NaN gradients/Hessians
josh0-jrg committed Oct 25, 2023
1 parent c40a8c1 commit 6b820f6
Showing 2 changed files with 9 additions and 8 deletions.
11 changes: 6 additions & 5 deletions flamedisx/tfp_files/skew_gaussian.py
@@ -152,9 +152,10 @@ def owensT1(h, a, terms):
exp_hs = tf.math.exp(hs)

ci = -1 + exp_hs
- val = tf.math.atan(a) * tf.ones_like(hs)
-
- for i in range(terms):
+ val = tf.math.atan(a) * tf.ones_like(hs) + ci * a
+ ci = -ci + hs / tf.exp(tf.math.lgamma(tf.cast(2, 'float32'))) * exp_hs
+ # For a -> 0, tf.math.pow(a, 0.) can cause divergent gradients. It would be more efficient not to calculate this term at all.
+ for i in range(1, terms):
    val += ci * tf.math.pow(a, 2*tf.cast(i, 'float32')+1) / (2*tf.cast(i, 'float32')+1)
    ci = -ci + tf.math.pow(hs, tf.cast(i+1, 'float32')) / tf.exp(tf.math.lgamma(tf.cast(i+2, 'float32'))) * exp_hs

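This hunk peels the i = 0 term out of the series loop and writes it as `ci * a` instead of going through `tf.math.pow`. TensorFlow differentiates pow(a, b) as b * pow(a, b - 1), so the second derivative of pow(a, 1.) at a = 0 evaluates 0 * pow(a, -1) = 0 * inf = NaN, the Hessian failure the commit title mentions. A minimal sketch of the failure mode (variable names are illustrative, not from flamedisx):

```python
import tensorflow as tf

a = tf.Variable(0.0)

# The i = 0 term written via pow: the first derivative is fine, but the
# Hessian goes through d/da pow(a, 0) = 0 * pow(a, -1) = NaN at a = 0.
with tf.GradientTape() as outer:
    with tf.GradientTape() as inner:
        term = tf.math.pow(a, 1.0)
    grad = inner.gradient(term, a)    # 1 * pow(a, 0) -> 1.0
print(outer.gradient(grad, a))        # tf.Tensor(nan, shape=(), dtype=float32)

# The same term written without pow has a harmless Hessian.
with tf.GradientTape() as outer:
    with tf.GradientTape() as inner:
        term = a * 1.0
    grad = inner.gradient(term, a)    # constant 1.0
print(outer.gradient(grad, a))        # None (constant gradient) -- no NaN
```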
@@ -170,9 +171,9 @@ def _cdf(self, x):
a = tf.cast(skewness,'float32')

owens_t_eval = 0.5 * normal.Normal(loc=0.,scale=1.).cdf(h) + 0.5 * normal.Normal(loc=0.,scale=1.).cdf(a*h) - normal.Normal(loc=0.,scale=1.).cdf(h) * normal.Normal(loc=0.,scale=1.).cdf(a*h)
-
+ # If TensorFlow calculates a value, like 1/a, it calculates the gradients as well.
return 0.5 * (1. + tf.math.erf(1./(np.sqrt(2.)*scale) * (x - self.loc))) - \
-     tf.cast(tf.where(a > tf.ones_like(a), 2. * (owens_t_eval - self.owensT1(a*h, 1./a, self.owens_t_terms)), 2. * self.owensT1(h, a, self.owens_t_terms)), 'float32')
+     tf.cast(tf.where(a > tf.ones_like(a), 2. * (owens_t_eval - self.owensT1(a*h, tf.math.divide_no_nan(1., a), self.owens_t_terms)), 2. * self.owensT1(h, a, self.owens_t_terms)), 'float32')

def _parameter_control_dependencies(self, is_init):
assertions = []
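The divide_no_nan change works around the well-known tf.where gradient trap: tf.where evaluates both branches, so even when a <= 1 the untaken 1./a branch still contributes d(1/a)/da = -1/a**2 = -inf at a = 0, and the zero upstream gradient from tf.where turns that into 0 * inf = NaN. tf.math.divide_no_nan returns 0 with a zero gradient at a = 0, so only the taken branch survives. A minimal sketch (illustrative names, not repository code):

```python
import tensorflow as tf

a = tf.Variable(0.0)

with tf.GradientTape() as tape:
    # The untaken 1/a branch is still evaluated; its -1/a**2 gradient is
    # -inf at a = 0, and the where-gradient's zero upstream gives 0 * inf = NaN.
    naive = tf.where(a > 1.0, 1.0 / a, a)
print(tape.gradient(naive, a))   # tf.Tensor(nan, shape=(), dtype=float32)

with tf.GradientTape() as tape:
    # divide_no_nan is 0 at a = 0 and its gradient there is 0, not -inf.
    safe = tf.where(a > 1.0, tf.math.divide_no_nan(1.0, a), a)
print(tape.gradient(safe, a))    # tf.Tensor(1.0, shape=(), dtype=float32)
```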
6 changes: 3 additions & 3 deletions flamedisx/tfp_files/truncated_skew_gaussian_cc.py
@@ -146,15 +146,15 @@ def _log_prob(self, x):
cdf_lower = skew_gauss.cdf(x-0.5)

minus_inf = dtype_util.as_numpy_dtype(x.dtype)(-np.inf)
-
+ # _log_prob is consumed as exp(log_prob); the gradient is (1/x) * e^log(x), which diverges at 0.
bounded_log_prob = tf.where((x > limit),
                            minus_inf,
-                           tf.math.log(cdf_upper - cdf_lower))
+                           tf.math.log(cdf_upper - cdf_lower + 1e-11))
bounded_log_prob = tf.where(tf.math.is_nan(bounded_log_prob),
minus_inf,
bounded_log_prob)
dumping_log_prob = tf.where((x == limit),
-                           tf.math.log(1 - cdf_lower),
+                           tf.math.log(1 - cdf_lower + 1e-11),
                            bounded_log_prob)

return dumping_log_prob
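The 1e-11 terms implement the rationale in the new comment: the log-prob is later consumed as exp(log_prob), whose gradient is e^log(x) * (1/x); once cdf_upper - cdf_lower underflows to exactly 0, that is 0 * inf = NaN. Keeping the log argument strictly positive keeps both the value and the gradient finite. A minimal sketch, with diff standing in for cdf_upper - cdf_lower (illustrative, not repository code):

```python
import tensorflow as tf

diff = tf.Variable(0.0)   # stands in for cdf_upper - cdf_lower after underflow
EPS = 1e-11

with tf.GradientTape() as tape:
    p = tf.math.exp(tf.math.log(diff))        # downstream use of log_prob
print(tape.gradient(p, diff))   # exp(log 0) * (1/0) = 0 * inf = NaN

with tf.GradientTape() as tape:
    p = tf.math.exp(tf.math.log(diff + EPS))  # the committed fix
print(tape.gradient(p, diff))   # exp(log EPS) / EPS = 1.0 -- finite
```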
