diff --git a/tests/mnle_test.py b/tests/mnle_test.py index 062182480..54471b0a0 100644 --- a/tests/mnle_test.py +++ b/tests/mnle_test.py @@ -243,7 +243,7 @@ def sim_wrapper(theta): # MNLE trainer = MNLE(proposal) - estimator = trainer.append_simulations(theta, x).train(max_num_epochs=1) + estimator = trainer.append_simulations(theta, x).train() potential_fn = MixedLikelihoodBasedPotential(estimator, proposal, x_o) @@ -276,6 +276,7 @@ def sim_wrapper(theta): potential_fn=conditioned_potential_fn, theta_transform=prior_transform, proposal=prior, + **mcmc_kwargs, ) cond_samples = mcmc_posterior.sample((num_samples,), x=x_o)