diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 3365526b3..831586201 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -20,7 +20,7 @@ from jax.flatten_util import ravel_pytree from blackjax.diagnostics import effective_sample_size -from blackjax.util import incremental_value_update, pytree_size +from blackjax.util import generate_unit_vector, incremental_value_update, pytree_size class MCLMCAdaptationState(NamedTuple): @@ -147,6 +147,8 @@ def predictor(previous_state, params, adaptive_state, rng_key): time, x_average, step_size_max = adaptive_state + rng_key, nan_key = jax.random.split(rng_key) + # dynamics next_state, info = kernel(params.sqrt_diag_cov)( rng_key=rng_key, @@ -162,6 +164,7 @@ def predictor(previous_state, params, adaptive_state, rng_key): params.step_size, step_size_max, info.energy_change, + nan_key, ) # Warning: var = 0 if there were nans, but we will give it a very small weight @@ -203,7 +206,7 @@ def step(iteration_state, weight_and_key): streaming_avg = incremental_value_update( expectation=jnp.array([x, jnp.square(x)]), incremental_val=streaming_avg, - weight=(1 - mask) * success * params.step_size, + weight=mask * success * params.step_size, ) return (state, params, adaptive_state, streaming_avg), None @@ -233,7 +236,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): ) # we use the last num_steps2 to compute the diagonal preconditioner - mask = 1 - jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2))) + mask = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2))) # run the steps state, params, _, (_, average) = run_steps( @@ -298,7 +301,9 @@ def step(state, key): return adaptation_L -def handle_nans(previous_state, next_state, step_size, step_size_max, kinetic_change): +def handle_nans( + previous_state, next_state, step_size, step_size_max, kinetic_change, key +): """if there are nans, let's reduce the stepsize, and not update the state. The function returns the old state in this case.""" @@ -311,4 +316,13 @@ def handle_nans(previous_state, next_state, step_size, step_size_max, kinetic_ch (next_state, step_size_max, kinetic_change), (previous_state, step_size * reduced_step_size, 0.0), ) + + state = jax.lax.cond( + jnp.isnan(next_state.logdensity), + lambda: state._replace( + momentum=generate_unit_vector(key, previous_state.position) + ), + lambda: state, + ) + return nonans, state, step_size, kinetic_change