fixed bug where constraint was added to log likelihood for each batch
josh0-jrg committed Jan 30, 2024
1 parent 0e42675 commit 7c712b8
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions flamedisx/likelihood.py
@@ -521,10 +521,14 @@ def _log_likelihood(self,
                                 0.)
         if dsetname == self.dsetnames[0]:
             if constraint_extra_args is None:
-                ll += self.log_constraint(**params_unstacked)
+                ll += tf.where( tf.equal(i_batch, tf.constant(0, dtype=fd.int_type())),
+                                self.log_constraint(**params_unstacked),
+                                0.)
             else:
                 kwargs = {**params_unstacked, **constraint_extra_args}
-                ll += self.log_constraint(**kwargs)
+                ll += tf.where( tf.equal(i_batch, tf.constant(0, dtype=fd.int_type())),
+                                self.log_constraint(**kwargs),
+                                0.)
 
         # Autodifferentiation. This is why we use tensorflow:
         grad = tf.gradients(ll, grad_par_stack)[0]
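Why this fixes the bug: _log_likelihood is evaluated once per batch and the per-batch results are summed, so an unconditional "ll += self.log_constraint(...)" added the constraint term once per batch instead of once overall. Gating the term with tf.where on i_batch == 0 restricts it to the first batch. Below is a minimal, self-contained sketch of that pattern, not flamedisx code: the names batch_ll, per_batch_ll, and log_constraint are hypothetical, and tf.int32 stands in for whatever dtype fd.int_type() returns in the real code.

import tensorflow as tf

def batch_ll(i_batch, per_batch_ll, log_constraint):
    # Add the constraint term only on batch 0; every other batch adds 0.
    return per_batch_ll + tf.where(
        tf.equal(i_batch, tf.constant(0, dtype=tf.int32)),
        log_constraint,
        0.)

# Summing over three batches now counts the constraint exactly once:
total = tf.add_n([
    batch_ll(tf.constant(i), tf.constant(-5.0), tf.constant(-1.0))
    for i in range(3)])
# total == -16.0  (3 * -5.0 from the data, plus a single constraint of -1.0)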
