-
Hi, I would like to send a PR to blackjax with this example from numpyro, without using numpyro's primitives:

```python
import jax

jax.config.update('jax_platform_name', 'cpu')

from collections import namedtuple

import jax.numpy as jnp
from jax.experimental.ode import odeint
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions

params = namedtuple('params', ['uinit', 'vinit', 'alpha', 'betta',
                               'gamma', 'delta', 'u_sigma', 'v_sigma'])

# Priors over the initial populations, the ODE parameters, and the
# observation noise scales (all constrained to be positive).
uinit_prior = tfd.LogNormal(jnp.log(10.0), 1.0)
vinit_prior = tfd.LogNormal(jnp.log(10.0), 1.0)
alpha_prior = tfd.TruncatedNormal(1.00, 0.50, 0.0, jnp.inf)
betta_prior = tfd.TruncatedNormal(0.05, 0.05, 0.0, jnp.inf)
gamma_prior = tfd.TruncatedNormal(1.00, 0.50, 0.0, jnp.inf)
delta_prior = tfd.TruncatedNormal(0.05, 0.05, 0.0, jnp.inf)
u_sigma_prior = tfd.LogNormal(-1.0, 1.0)
v_sigma_prior = tfd.LogNormal(-1.0, 1.0)


def lotka_volterra_ODE(z, t, alpha, betta, gamma, delta):
    # Predator-prey dynamics: u is the prey population, v the predator.
    u, v = z[0], z[1]
    du_dt = (+alpha - betta * v) * u
    dv_dt = (-gamma + delta * u) * v
    return jnp.stack([du_dt, dv_dt])


def Model(ODE, rtol=1e-6, atol=1e-5, max_steps=1000):
    def apply(z_init, year_args, alpha, betta, gamma, delta):
        z = odeint(ODE, z_init, year_args, alpha, betta, gamma, delta,
                   rtol=rtol, atol=atol, mxstep=max_steps)
        return z
    return apply


def target_log_prob(model, data):
    year_args = jnp.arange(len(data), dtype=jnp.float32)

    def apply(params):
        # Log-density of the priors at the current position.
        uinit_log_prob = uinit_prior.log_prob(params.uinit)
        vinit_log_prob = vinit_prior.log_prob(params.vinit)
        alpha_log_prob = alpha_prior.log_prob(params.alpha)
        betta_log_prob = betta_prior.log_prob(params.betta)
        gamma_log_prob = gamma_prior.log_prob(params.gamma)
        delta_log_prob = delta_prior.log_prob(params.delta)
        u_sigma_log_prob = u_sigma_prior.log_prob(params.u_sigma)
        v_sigma_log_prob = v_sigma_prior.log_prob(params.v_sigma)

        # Solve the ODE and score the observations with a LogNormal likelihood.
        zinit = jnp.array([params.uinit, params.vinit])
        args = (params.alpha, params.betta, params.gamma, params.delta)
        z = model(zinit, year_args, *args)
        sigmas = jnp.array([params.u_sigma, params.v_sigma])
        log_likelihood = tfd.LogNormal(jnp.log(z), sigmas).log_prob(data)

        return (uinit_log_prob + vinit_log_prob + alpha_log_prob +
                betta_log_prob + gamma_log_prob + delta_log_prob +
                u_sigma_log_prob + v_sigma_log_prob + log_likelihood.sum())
    return apply


def sample(key):
    # Draw an initial position from the priors, one fresh key per parameter.
    keys = jax.random.split(key, 8)
    uinit = uinit_prior.sample(seed=keys[0])
    vinit = vinit_prior.sample(seed=keys[1])
    alpha = alpha_prior.sample(seed=keys[2])
    betta = betta_prior.sample(seed=keys[3])
    gamma = gamma_prior.sample(seed=keys[4])
    delta = delta_prior.sample(seed=keys[5])
    u_sigma = u_sigma_prior.sample(seed=keys[6])
    v_sigma = v_sigma_prior.sample(seed=keys[7])
    return params(uinit, vinit, alpha, betta, gamma, delta, u_sigma, v_sigma)


def inference_loop(rng_key, kernel, initial_state, num_samples):
    @jax.jit
    def one_step(state, rng_key):
        state, info = kernel(rng_key, state)
        return state, (state, info)

    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)
    return states, infos


if __name__ == "__main__":
    from numpyro.examples.datasets import LYNXHARE, load_dataset

    import blackjax

    fetch = load_dataset(LYNXHARE, shuffle=False)[1]
    year, data = fetch()

    key = jax.random.PRNGKey(1)
    key_init, key_warm, key_loop, key_pred = jax.random.split(key, 4)

    warmup_steps = 1000
    num_samples = 1000

    model = Model(lotka_volterra_ODE)
    log_prob = target_log_prob(model, data)
    position = sample(key_init)

    adapt = blackjax.window_adaptation(
        algorithm=blackjax.nuts,
        logprob_fn=log_prob,
        is_mass_matrix_diagonal=False,
        initial_step_size=1.0,
        progress_bar=True)
    state, kernel, kernel_params = adapt.run(key_warm, position, warmup_steps)
    print('Kernel params', kernel_params)

    states, infos = inference_loop(key_loop, kernel, state, num_samples)
```
-
Thank you for your interest in contributing. Before trying to run the example locally, I have a few questions:
It would be interesting to understand what's happening, but then we probably should reparametrize the model anyway.
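For reference, a minimal sketch of what such a reparametrization could look like, assuming we sample all eight parameters in log space and add the change-of-variables correction. The `unconstrained_log_prob` wrapper is hypothetical and reuses the `params` namedtuple and logdensity from the snippet above:

```python
import jax.numpy as jnp

# Hypothetical sketch: wrap the existing logdensity so the sampler works in
# log space, where every parameter is unconstrained. For theta = exp(x),
# log|d theta / dx| = x, so the Jacobian correction is just the sum of the
# unconstrained coordinates.
def unconstrained_log_prob(log_prob_fn):
    def apply(unconstrained):
        constrained = params(*(jnp.exp(x) for x in unconstrained))
        log_jacobian = sum(unconstrained)
        return log_prob_fn(constrained) + log_jacobian
    return apply
```

Sampling in log space keeps all parameters strictly positive by construction and removes the hard boundary at zero that the truncated normal priors create for the sampler.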
-
Thank you for the fast response! I found an error that I fixed while posting this. The error was an RNG key being used twice. I thought that this would not be too problematic, but now I don't seem to get any close to zero …
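For anyone reading along, key reuse silently produces identical draws rather than an error, which is why the bug is easy to miss; a minimal illustration:

```python
import jax

key = jax.random.PRNGKey(0)
# The same key always produces the same draw, so two priors "sampled"
# with one key receive identical values instead of independent ones.
x = jax.random.normal(key)
y = jax.random.normal(key)
assert x == y
# Splitting yields distinct, independent streams.
k1, k2 = jax.random.split(key)
assert jax.random.normal(k1) != jax.random.normal(k2)
```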
-
Ah, yes, that could explain it. Which one? (I re-read the code and can't find it.)
Yeah, it seems we can't do it with the high-level interface. That's a problem and we should open an issue for that.
Posterior predictive? Yeah, you'd need to implement the forward-sampling version of your model and …
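A minimal sketch of what that forward-sampling version could look like, assuming we reuse `model`, `tfd`, and `jnp` from the snippet above and draw observation noise from the same LogNormal likelihood (the `posterior_predictive` name is hypothetical):

```python
def posterior_predictive(key, model, year_args, p):
    # Run the ODE forward at a single posterior draw `p`, then sample
    # observations from the LogNormal observation model.
    zinit = jnp.array([p.uinit, p.vinit])
    z = model(zinit, year_args, p.alpha, p.betta, p.gamma, p.delta)
    sigmas = jnp.array([p.u_sigma, p.v_sigma])
    return tfd.LogNormal(jnp.log(z), sigmas).sample(seed=key)
```

Mapping this over the posterior draws (e.g. with `jax.vmap` over `states.position` and a batch of keys) then gives posterior predictive trajectories.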
-
It was this line:

```python
v_sigma = v_sigma_prior.sample(seed=keys[7])
```

I had used `keys[6]`, repeating it from `u_sigma`, but I fixed it while posting this. Yes, the posterior predictive; I will update this issue in case I encounter another problem! Thank you.
-
Np. Looking forward to the PR :) |
-
Hi! I found this example quite useful and was able to adapt it to construct a log-likelihood for a related problem (ODE fitting with generalised least squares). Thanks for showing how to use closures to obtain a logdensity function! If I'm not mistaken, this example requires a change in the inference loop since #501, after which …
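If it helps anyone else, my understanding of the post-#501 flow is roughly the following (a sketch; exact names may differ between versions): `run` now returns the last state together with the adapted parameters, and you construct the kernel yourself:

```python
# Sketch of the adaptation + sampling flow after #501: `run` returns
# (last_state, parameters), and the NUTS kernel is built explicitly
# from the adapted step size and mass matrix.
adapt = blackjax.window_adaptation(blackjax.nuts, log_prob)
(state, parameters), _ = adapt.run(key_warm, position, num_steps=warmup_steps)
kernel = blackjax.nuts(log_prob, **parameters).step
states, infos = inference_loop(key_loop, kernel, state, num_samples)
```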