
Commit

Sync updates with draft PR #386.

- Added pytensor.function for bfgs_sample
aphc14 committed Nov 7, 2024
1 parent 1fd7a11 commit ef2956f
Showing 1 changed file with 89 additions and 31 deletions.
120 changes: 89 additions & 31 deletions pymc_experimental/inference/pathfinder/pathfinder.py
@@ -302,7 +302,8 @@ def bfgs_sample(
     alpha,
     beta,
     gamma,
-    random_seed: RandomSeed | None = None,
+    # random_seed: RandomSeed | None = None,
+    rng,
 ):
     # batch: L = 8
     # alpha_l: (N,) => (L, N)
@@ -315,7 +316,7 @@
     # logdensity: (M,) => (L, M)
     # theta: (J, N)

-    rng = pytensor.shared(np.random.default_rng(seed=random_seed))
+    # rng = pytensor.shared(np.random.default_rng(seed=random_seed))
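     # rng is now a caller-supplied shared Generator rather than one built here
     # from random_seed, so the sampling graph can be compiled once and reseeded
     # per run (see the swap/set_value calls in single_pathfinder below).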

     def batched(x, g, alpha, beta, gamma):
         var_list = [x, g, alpha, beta, gamma]
@@ -380,6 +381,64 @@ def compute_logp(logp_func, arr):
     return np.where(np.isnan(logP), -np.inf, logP)


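# Module-level symbolic graphs: placeholders are declared once, run through
# alpha_recover, inverse_hessian_factors, and bfgs_sample, then compiled below
# into reusable pytensor functions. Note that _alpha, _beta, and _gamma are
# rebound from input placeholders to graph outputs along the way.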
+_x = pt.matrix("_x", dtype="float64")
+_g = pt.matrix("_g", dtype="float64")
+_alpha = pt.matrix("_alpha", dtype="float64")
+_beta = pt.tensor3("_beta", dtype="float64")
+_gamma = pt.tensor3("_gamma", dtype="float64")
+_epsilon = pt.scalar("_epsilon", dtype="float64")
+_maxcor = pt.iscalar("_maxcor")
+_alpha, _S, _Z, _update_mask = alpha_recover(_x, _g, epsilon=_epsilon)
+_beta, _gamma = inverse_hessian_factors(_alpha, _S, _Z, _update_mask, J=_maxcor)
+
+_num_elbo_draws = pt.iscalar("_num_elbo_draws")
+_dummy_rng = pytensor.shared(np.random.default_rng(), name="_dummy_rng")
+_phi, _logQ_phi = bfgs_sample(
+    num_samples=_num_elbo_draws,
+    x=_x,
+    g=_g,
+    alpha=_alpha,
+    beta=_beta,
+    gamma=_gamma,
+    rng=_dummy_rng,
+)
+
+_num_draws = pt.iscalar("_num_draws")
+_x_lstar = pt.dvector("_x_lstar")
+_g_lstar = pt.dvector("_g_lstar")
+_alpha_lstar = pt.dvector("_alpha_lstar")
+_beta_lstar = pt.dmatrix("_beta_lstar")
+_gamma_lstar = pt.dmatrix("_gamma_lstar")
+
+
+_psi, _logQ_psi = bfgs_sample(
+    num_samples=_num_draws,
+    x=_x_lstar,
+    g=_g_lstar,
+    alpha=_alpha_lstar,
+    beta=_beta_lstar,
+    gamma=_gamma_lstar,
+    rng=_dummy_rng,
+)
+
+alpha_recover_compiled = pytensor.function(
+    inputs=[_x, _g, _epsilon],
+    outputs=[_alpha, _S, _Z, _update_mask],
+)
+inverse_hessian_factors_compiled = pytensor.function(
+    inputs=[_alpha, _S, _Z, _update_mask, _maxcor],
+    outputs=[_beta, _gamma],
+)
+bfgs_sample_compiled = pytensor.function(
+    inputs=[_num_elbo_draws, _x, _g, _alpha, _beta, _gamma],
+    outputs=[_phi, _logQ_phi],
+)
+bfgs_sample_lstar_compiled = pytensor.function(
+    inputs=[_num_draws, _x_lstar, _g_lstar, _alpha_lstar, _beta_lstar, _gamma_lstar],
+    outputs=[_psi, _logQ_psi],
+)
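# The functions above are compiled once at import time; per-run randomness is
# injected by replacing _dummy_rng with a seeded shared RNG via
# fn.copy(swap={_dummy_rng: rng}) inside single_pathfinder.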


def single_pathfinder(
    model,
    num_draws: int,
@@ -423,47 +482,46 @@ def neg_dlogp_func(x):
         maxls=maxls,
     )

-    # x_full, g_full: (L+1, N)
-    x_full = pt.as_tensor(lbfgs_history.x, dtype="float64")
-    g_full = pt.as_tensor(lbfgs_history.g, dtype="float64")
+    # x, g: (L+1, N)
+    x = lbfgs_history.x
+    g = lbfgs_history.g
+    alpha, S, Z, update_mask = alpha_recover_compiled(x, g, epsilon)
+    beta, gamma = inverse_hessian_factors_compiled(alpha, S, Z, update_mask, maxcor)
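     # alpha, S, Z, update_mask, beta, gamma are concrete ndarrays here (the
     # compiled functions return NumPy values), so no .eval() is needed on the
     # results below.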

     # ignore initial point - x, g: (L, N)
-    x = x_full[1:]
-    g = g_full[1:]
-
-    alpha, S, Z, update_mask = alpha_recover(x_full, g_full, epsilon=epsilon)
-    beta, gamma = inverse_hessian_factors(alpha, S, Z, update_mask, J=maxcor)
-
-    phi, logQ_phi = bfgs_sample(
-        num_samples=num_elbo_draws,
-        x=x,
-        g=g,
-        alpha=alpha,
-        beta=beta,
-        gamma=gamma,
-        random_seed=pathfinder_seed,
+    x = x[1:]
+    g = g[1:]
+
+    rng = pytensor.shared(np.random.default_rng(pathfinder_seed), borrow=True)
+    phi, logQ_phi = bfgs_sample_compiled.copy(swap={_dummy_rng: rng})(
+        num_elbo_draws,
+        x,
+        g,
+        alpha,
+        beta,
+        gamma,
     )

     # .vectorize is slower than apply_along_axis
-    logP_phi = compute_logp(logp_func, phi.eval())
-    logQ_phi = logQ_phi.eval()
+    logP_phi = compute_logp(logp_func, phi)
+    # logQ_phi = logQ_phi.eval()
     elbo = (logP_phi - logQ_phi).mean(axis=-1)
     lstar = np.argmax(elbo)

     # BUG: elbo may be -inf for every l in L, in which case np.argmax(elbo)
     # returns 0, which is wrong. This won't affect the posterior samples in the
     # multipath Pathfinder scenario thanks to the PSIS/PSIR step, but the user
     # is left unaware of a failed Pathfinder run.
     # TODO: handle this case, e.g. by warning of a failed Pathfinder run and
     # skipping the following bfgs_sample step to save time.
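     # One possible guard (an illustrative sketch, not part of this commit;
     # assumes `warnings` is imported at module level):
     #
     #     if not np.isfinite(elbo).any():
     #         warnings.warn("ELBO is -inf for every l; this Pathfinder run failed.")
     #         # skip the final bfgs_sample step and signal failure to the caller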

-    psi, logQ_psi = bfgs_sample(
-        num_samples=num_draws,
-        x=x[lstar],
-        g=g[lstar],
-        alpha=alpha[lstar],
-        beta=beta[lstar],
-        gamma=gamma[lstar],
-        random_seed=sample_seed,
+    rng.set_value(np.random.default_rng(sample_seed), borrow=True)
+    psi, logQ_psi = bfgs_sample_lstar_compiled.copy(swap={_dummy_rng: rng})(
+        num_draws,
+        x[lstar],
+        g[lstar],
+        alpha[lstar],
+        beta[lstar],
+        gamma[lstar],
     )
-    psi = psi.eval()
-    logQ_psi = logQ_psi.eval()
+    # psi = psi.eval()
+    # logQ_psi = logQ_psi.eval()
     logP_psi = compute_logp(logp_func, psi)
     # psi: (1, M, N)
     # logP_psi: (1, M)
