diff --git a/flamedisx/lxe_blocks/s2_loss.py b/flamedisx/lxe_blocks/s2_loss.py index df4f2e73c..342932319 100644 --- a/flamedisx/lxe_blocks/s2_loss.py +++ b/flamedisx/lxe_blocks/s2_loss.py @@ -24,9 +24,9 @@ def _compute(self, data_tensor, ptensor, data_tensor=data_tensor, ptensor=ptensor)[:, o, o] # s2_raw_after_loss distributed as Binom(s2_raw, p=s2_survival_probability) - s2_raw_after_loss = tf.clip_by_value(s2_raw_after_loss, 0, tf.int32.max) + s2_raw_after_loss = tf.clip_by_value(s2_raw_after_loss, 1e-15, tf.float32.max) result = tfp.distributions.Binomial( - total_count=tf.clip_by_value(tf.cast(s2_raw, dtype=fd.int_type()),1e-15, tf.int32.max), + total_count=tf.clip_by_value(tf.cast(s2_raw, dtype=fd.int_type()),0, tf.int32.max), probs=tf.clip_by_value(tf.cast(s2_survival_probability, dtype=fd.float_type()), 0.,1.) ).prob(s2_raw_after_loss)