Skip to content

Commit

Permalink
fix likelihood reg v2
Browse files Browse the repository at this point in the history
  • Loading branch information
jonasrothfuss committed Jan 16, 2024
1 parent f004a84 commit 65f783d
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions sim_transfer/models/bnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,8 +420,7 @@ def _neg_log_posterior(self, params: Dict, x_batch: jnp.ndarray, y_batch: jnp.nd
self.batched_model.flatten_batch(params['nn_params_stacked'])))
log_prior /= self.prior_dist.event_shape[0]
if self.likelihood_reg > 0:
likelihood_penalty = - self.likelihood_reg * \
jnp.sum(self._likelihood_prior_logprob(params['likelihood_std_raw'])**2)
likelihood_penalty = self.likelihood_reg * self._likelihood_prior_logprob(params['likelihood_std_raw'])
log_prior += likelihood_penalty
stats = OrderedDict(train_nll_loss=nll, neg_log_prior=-log_prior)
neg_log_post = nll - log_prior
Expand Down

0 comments on commit 65f783d

Please sign in to comment.