diff --git a/pymdp/jax/control.py b/pymdp/jax/control.py index b3cc0ad2..3f42ca9f 100644 --- a/pymdp/jax/control.py +++ b/pymdp/jax/control.py @@ -204,10 +204,8 @@ def compute_info_gain(qs, qo, A, A_dependencies): """ def compute_info_gain_for_modality(qo_m, A_m, m): - H_qo = - xlogy(qo_m, qo_m).sum() - # H_qo = - (qo_m * log_stable(qo_m)).sum() - H_A_m = - xlogy(A_m, A_m).sum(0) - # H_A_m = - (A_m * log_stable(A_m)).sum(0) + H_qo = stable_entropy(qo_m) + H_A_m = - stable_xlogx(A_m).sum(0) deps = A_dependencies[m] relevant_factors = [qs[idx] for idx in deps] qs_H_A_m = factor_dot(H_A_m, relevant_factors) @@ -217,23 +215,14 @@ def compute_info_gain_for_modality(qo_m, A_m, m): return jtu.tree_reduce(lambda x,y: x+y, info_gains_per_modality) -# qs_H_A = 0 # expected entropy of the likelihood, under Q(s) -# H_qo = 0 # marginal entropy of Q(o) -# for a, o, deps in zip(A, qo, A_dependencies): -# relevant_factors = jtu.tree_map(lambda idx: qs[idx], deps) -# qs_joint_relevant = relevant_factors[0] -# for q in relevant_factors[1:]: -# qs_joint_relevant = jnp.expand_dims(qs_joint_relevant, -1) * q -# H_A_m = -(a * log_stable(a)).sum(0) -# qs_H_A += (H_A_m * qs_joint_relevant).sum() - -# H_qo -= (o * log_stable(o)).sum() - -def compute_expected_utility(qo, C): +def compute_expected_utility(t, qo, C): util = 0. for o_m, C_m in zip(qo, C): - util += (o_m * C_m).sum() + if C_m.ndim > 1: + util += (o_m * C_m[t]).sum() + else: + util += (o_m * C_m).sum() return util @@ -258,13 +247,17 @@ def calc_pA_info_gain(pA, qo, qs, A_dependencies): Surprise (about Dirichlet parameters) expected for the pair of posterior predictive distributions ``qo`` and ``qs`` """ - wA = lambda pa: spm_wnorm(pa) * (pa > 0.) - fd = lambda x, i: factor_dot(x, [s for f, s in enumerate(qs) if f in A_dependencies[i]], keep_dims=(0,))[..., None] + def infogain_per_modality(pa_m, qo_m, m): + wa_m = spm_wnorm(pa_m) * (pa_m > 0.) + fd = factor_dot(wa_m, [s for f, s in enumerate(qs) if f in A_dependencies[m]], keep_dims=(0,))[..., None] + return qo_m.dot(fd) + pA_infogain_per_modality = jtu.tree_map( - lambda pa, qo, m: qo.dot(fd( wA(pa), m)), pA, qo, list(range(len(qo))) + infogain_per_modality, pA, qo, list(range(len(qo))) ) - infogain_pA = jtu.tree_reduce(lambda x, y: x + y, pA_infogain_per_modality)[0] - return infogain_pA + + infogain_pA = jtu.tree_reduce(lambda x, y: x + y, pA_infogain_per_modality) + return infogain_pA.squeeze(-1) def calc_pB_info_gain(pB, qs_t, qs_t_minus_1, B_dependencies, u_t_minus_1): """ @@ -338,7 +331,7 @@ def scan_body(carry, t): info_gain = compute_info_gain(qs_next, qo, A, A_dependencies) if use_states_info_gain else 0. - utility = compute_expected_utility(qo, C) if use_utility else 0. + utility = compute_expected_utility(t, qo, C) if use_utility else 0. inductive_value = calc_inductive_value_t(qs_init, qs_next, I, epsilon=inductive_epsilon) if use_inductive else 0. diff --git a/pymdp/jax/maths.py b/pymdp/jax/maths.py index 213b519f..ac8173be 100644 --- a/pymdp/jax/maths.py +++ b/pymdp/jax/maths.py @@ -2,11 +2,21 @@ from functools import partial from typing import Optional, Tuple, List -from jax import tree_util, nn, jit +from jax import tree_util, nn, jit, vmap, lax +from jax.scipy.special import xlogy from opt_einsum import contract MINVAL = jnp.finfo(float).eps +def stable_xlogx(x): + return xlogy(x, jnp.clip(x, MINVAL)) + +def stable_entropy(x): + return - stable_xlogx(x).sum() + +def stable_cross_entropy(x, y): + return - xlogy(x, y).sum() + def log_stable(x): return jnp.log(jnp.clip(x, min=MINVAL)) @@ -51,15 +61,19 @@ def factor_dot_flex(M, xs, dims: List[Tuple[int]], keep_dims: Optional[Tuple[int args += [keep_dims] return contract(*args, backend='jax') -def compute_log_likelihood_single_modality(o_m, A_m, distr_obs=True): - """ Compute observation likelihood for a single modality (observation and likelihood)""" +def get_likelihood_single_modality(o_m, A_m, distr_obs=True): + """Return observation likelihood for a single observation modality m""" if distr_obs: expanded_obs = jnp.expand_dims(o_m, tuple(range(1, A_m.ndim))) likelihood = (expanded_obs * A_m).sum(axis=0) else: likelihood = A_m[o_m] - - return log_stable(likelihood) + + return likelihood + +def compute_log_likelihood_single_modality(o_m, A_m, distr_obs=True): + """Compute observation log-likelihood for a single modality""" + return log_stable(get_likelihood_single_modality(o_m, A_m, distr_obs=distr_obs)) def compute_log_likelihood(obs, A, distr_obs=True): """ Compute likelihood over hidden states across observations from different modalities """ @@ -77,7 +91,7 @@ def compute_log_likelihood_per_modality(obs, A, distr_obs=True): def compute_accuracy(qs, obs, A): """ Compute the accuracy portion of the variational free energy (expected log likelihood under the variational posterior) """ - ll = compute_log_likelihood(obs, A) + log_likelihood = compute_log_likelihood(obs, A) x = qs[0] for q in qs[1:]: @@ -98,8 +112,8 @@ def compute_free_energy(qs, prior, obs, A): vfe = 0.0 # initialize variational free energy for q, p in zip(qs, prior): - negH_qs = q.dot(log_stable(q)) - xH_qp = -q.dot(log_stable(p)) + negH_qs = - stable_entropy(q) + xH_qp = stable_cross_entropy(q, p) vfe += (negH_qs + xH_qp) vfe -= compute_accuracy(qs, obs, A)