Skip to content

Commit

Permalink
binomial float fixing bug
Browse files Browse the repository at this point in the history
  • Loading branch information
cecilia-ferrari committed Jan 10, 2024
1 parent 5d8c14b commit a89366a
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions flamedisx/lxe_blocks/s2_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ def _compute(self, data_tensor, ptensor,
data_tensor=data_tensor, ptensor=ptensor)[:, o, o]

# s2_raw_after_loss distributed as Binom(s2_raw, p=s2_survival_probability)
s2_raw_after_loss = tf.clip_by_value(s2_raw_after_loss, 0, tf.int32.max)
s2_raw_after_loss = tf.clip_by_value(s2_raw_after_loss, 1e-15, tf.float32.max)
result = tfp.distributions.Binomial(
total_count=tf.clip_by_value(tf.cast(s2_raw, dtype=fd.int_type()),1e-15, tf.int32.max),
total_count=tf.clip_by_value(tf.cast(s2_raw, dtype=fd.int_type()),0, tf.int32.max),
probs=tf.clip_by_value(tf.cast(s2_survival_probability, dtype=fd.float_type()), 0.,1.)
).prob(s2_raw_after_loss)

Expand Down

0 comments on commit a89366a

Please sign in to comment.