fixed numerical issues with gradient computations
dimarkov committed Jul 12, 2024
1 parent 1958e4d commit 799581a
Showing 2 changed files with 39 additions and 32 deletions.
pymdp/jax/control.py (41 changes: 17 additions & 24 deletions)
@@ -204,10 +204,8 @@ def compute_info_gain(qs, qo, A, A_dependencies):
     """

     def compute_info_gain_for_modality(qo_m, A_m, m):
-        H_qo = - xlogy(qo_m, qo_m).sum()
-        # H_qo = - (qo_m * log_stable(qo_m)).sum()
-        H_A_m = - xlogy(A_m, A_m).sum(0)
-        # H_A_m = - (A_m * log_stable(A_m)).sum(0)
+        H_qo = stable_entropy(qo_m)
+        H_A_m = - stable_xlogx(A_m).sum(0)
         deps = A_dependencies[m]
         relevant_factors = [qs[idx] for idx in deps]
         qs_H_A_m = factor_dot(H_A_m, relevant_factors)
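For intuition, the two quantities above are the ingredients of the usual state-information-gain (epistemic value) term: the entropy of the predicted observations minus the expected ambiguity of the likelihood under the state posterior (the remainder of the function, which combines them, is collapsed in this view). A minimal sketch for one modality and one state factor; the toy A_m, qs_m and qo_m arrays are invented for illustration, and stable_entropy / stable_xlogx are re-defined inline so the snippet runs on its own:

import jax.numpy as jnp
from jax.scipy.special import xlogy

MINVAL = jnp.finfo(float).eps

def stable_xlogx(x):
    # clip the log argument so the value and its gradient stay finite at exact zeros
    return xlogy(x, jnp.clip(x, MINVAL))

def stable_entropy(x):
    return - stable_xlogx(x).sum()

# toy likelihood P(o|s): 2 outcomes x 2 states, and a state posterior Q(s)
A_m = jnp.array([[0.9, 0.2],
                 [0.1, 0.8]])
qs_m = jnp.array([0.5, 0.5])

qo_m = A_m @ qs_m                        # predicted observations Q(o) = E_Q(s)[P(o|s)]
H_qo = stable_entropy(qo_m)              # entropy of predicted observations
H_A_m = - stable_xlogx(A_m).sum(0)       # per-state ambiguity H[P(o|s)]
expected_ambiguity = (H_A_m * qs_m).sum()
info_gain = H_qo - expected_ambiguity    # non-negative epistemic value for this modality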
@@ -217,23 +215,14 @@ def compute_info_gain_for_modality(qo_m, A_m, m):

     return jtu.tree_reduce(lambda x,y: x+y, info_gains_per_modality)

-    # qs_H_A = 0 # expected entropy of the likelihood, under Q(s)
-    # H_qo = 0 # marginal entropy of Q(o)
-    # for a, o, deps in zip(A, qo, A_dependencies):
-    # relevant_factors = jtu.tree_map(lambda idx: qs[idx], deps)
-    # qs_joint_relevant = relevant_factors[0]
-    # for q in relevant_factors[1:]:
-    # qs_joint_relevant = jnp.expand_dims(qs_joint_relevant, -1) * q
-    # H_A_m = -(a * log_stable(a)).sum(0)
-    # qs_H_A += (H_A_m * qs_joint_relevant).sum()
-
-    # H_qo -= (o * log_stable(o)).sum()
-
-def compute_expected_utility(qo, C):
+def compute_expected_utility(t, qo, C):

     util = 0.
     for o_m, C_m in zip(qo, C):
-        util += (o_m * C_m).sum()
+        if C_m.ndim > 1:
+            util += (o_m * C_m[t]).sum()
+        else:
+            util += (o_m * C_m).sum()

     return util
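The new t argument lets preferences vary over the planning horizon: when a modality's C_m has a leading time dimension it is indexed with t, otherwise the flat preference vector is used as before. A small self-contained sketch of that dispatch, with invented arrays (the function body mirrors the updated code above):

import jax.numpy as jnp

def compute_expected_utility(t, qo, C):
    # mirrors the updated control.py logic shown in the hunk above
    util = 0.
    for o_m, C_m in zip(qo, C):
        if C_m.ndim > 1:
            util += (o_m * C_m[t]).sum()   # time-varying preferences, shape (T, num_obs)
        else:
            util += (o_m * C_m).sum()      # static preferences, shape (num_obs,)
    return util

qo = [jnp.array([0.7, 0.3])]                       # predicted observations, one modality
C_static = [jnp.array([1.0, -1.0])]                # static log-preferences
C_timed = [jnp.array([[0.0, 0.0], [1.0, -1.0]])]   # preferences that switch on at t = 1

print(compute_expected_utility(0, qo, C_static))   # 0.4
print(compute_expected_utility(0, qo, C_timed))    # 0.0
print(compute_expected_utility(1, qo, C_timed))    # 0.4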

@@ -258,13 +247,17 @@ def calc_pA_info_gain(pA, qo, qs, A_dependencies):
     Surprise (about Dirichlet parameters) expected for the pair of posterior predictive distributions ``qo`` and ``qs``
     """

-    wA = lambda pa: spm_wnorm(pa) * (pa > 0.)
-    fd = lambda x, i: factor_dot(x, [s for f, s in enumerate(qs) if f in A_dependencies[i]], keep_dims=(0,))[..., None]
+    def infogain_per_modality(pa_m, qo_m, m):
+        wa_m = spm_wnorm(pa_m) * (pa_m > 0.)
+        fd = factor_dot(wa_m, [s for f, s in enumerate(qs) if f in A_dependencies[m]], keep_dims=(0,))[..., None]
+        return qo_m.dot(fd)

     pA_infogain_per_modality = jtu.tree_map(
-        lambda pa, qo, m: qo.dot(fd( wA(pa), m)), pA, qo, list(range(len(qo)))
+        infogain_per_modality, pA, qo, list(range(len(qo)))
     )
-    infogain_pA = jtu.tree_reduce(lambda x, y: x + y, pA_infogain_per_modality)[0]
-    return infogain_pA
+    infogain_pA = jtu.tree_reduce(lambda x, y: x + y, pA_infogain_per_modality)
+    return infogain_pA.squeeze(-1)

 def calc_pB_info_gain(pB, qs_t, qs_t_minus_1, B_dependencies, u_t_minus_1):
     """
@@ -338,7 +331,7 @@ def scan_body(carry, t):

         info_gain = compute_info_gain(qs_next, qo, A, A_dependencies) if use_states_info_gain else 0.

-        utility = compute_expected_utility(qo, C) if use_utility else 0.
+        utility = compute_expected_utility(t, qo, C) if use_utility else 0.

         inductive_value = calc_inductive_value_t(qs_init, qs_next, I, epsilon=inductive_epsilon) if use_inductive else 0.
pymdp/jax/maths.py (30 changes: 22 additions & 8 deletions)
@@ -2,11 +2,21 @@

 from functools import partial
 from typing import Optional, Tuple, List
-from jax import tree_util, nn, jit
+from jax import tree_util, nn, jit, vmap, lax
 from jax.scipy.special import xlogy
 from opt_einsum import contract

 MINVAL = jnp.finfo(float).eps

+def stable_xlogx(x):
+    return xlogy(x, jnp.clip(x, MINVAL))
+
+def stable_entropy(x):
+    return - stable_xlogx(x).sum()
+
+def stable_cross_entropy(x, y):
+    return - xlogy(x, y).sum()
+
 def log_stable(x):
     return jnp.log(jnp.clip(x, min=MINVAL))
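These helpers are the heart of the fix named in the commit message: xlogy(x, x) is safe in the forward pass (it returns 0 at x = 0), but its derivative involves log(x) and x / x, which are non-finite at exact zeros, and those NaNs and Infs then propagate through the gradients of the expected free energy. Clipping the second argument to machine epsilon keeps the derivative finite while leaving the value of a proper probability vector essentially unchanged. A small illustrative check (the entropy functions below are re-declared locally; only the clipping trick is taken from this hunk):

import jax
import jax.numpy as jnp
from jax.scipy.special import xlogy

MINVAL = jnp.finfo(float).eps

def entropy_naive(p):
    return -xlogy(p, p).sum()

def entropy_stable(p):
    # same value for p > 0, but the log argument is clipped away from 0
    return -xlogy(p, jnp.clip(p, MINVAL)).sum()

p = jnp.array([1.0, 0.0, 0.0])      # a deterministic distribution, with exact zeros

print(entropy_naive(p), entropy_stable(p))   # both 0.0 in the forward pass
print(jax.grad(entropy_naive)(p))            # non-finite entries (nan / -inf) at the zeros
print(jax.grad(entropy_stable)(p))           # finite everywhere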

@@ -51,15 +61,19 @@ def factor_dot_flex(M, xs, dims: List[Tuple[int]], keep_dims: Optional[Tuple[int
         args += [keep_dims]
     return contract(*args, backend='jax')

-def compute_log_likelihood_single_modality(o_m, A_m, distr_obs=True):
-    """ Compute observation likelihood for a single modality (observation and likelihood)"""
+def get_likelihood_single_modality(o_m, A_m, distr_obs=True):
+    """Return observation likelihood for a single observation modality m"""
     if distr_obs:
         expanded_obs = jnp.expand_dims(o_m, tuple(range(1, A_m.ndim)))
         likelihood = (expanded_obs * A_m).sum(axis=0)
     else:
         likelihood = A_m[o_m]

-    return log_stable(likelihood)
+    return likelihood
+
+def compute_log_likelihood_single_modality(o_m, A_m, distr_obs=True):
+    """Compute observation log-likelihood for a single modality"""
+    return log_stable(get_likelihood_single_modality(o_m, A_m, distr_obs=distr_obs))

 def compute_log_likelihood(obs, A, distr_obs=True):
     """ Compute likelihood over hidden states across observations from different modalities """
@@ -77,7 +91,7 @@ def compute_log_likelihood_per_modality(obs, A, distr_obs=True):
 def compute_accuracy(qs, obs, A):
     """ Compute the accuracy portion of the variational free energy (expected log likelihood under the variational posterior) """

-    ll = compute_log_likelihood(obs, A)
+    log_likelihood = compute_log_likelihood(obs, A)

     x = qs[0]
     for q in qs[1:]:
@@ -98,8 +112,8 @@ def compute_free_energy(qs, prior, obs, A):

     vfe = 0.0 # initialize variational free energy
     for q, p in zip(qs, prior):
-        negH_qs = q.dot(log_stable(q))
-        xH_qp = -q.dot(log_stable(p))
+        negH_qs = - stable_entropy(q)
+        xH_qp = stable_cross_entropy(q, p)
         vfe += (negH_qs + xH_qp)

     vfe -= compute_accuracy(qs, obs, A)
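In this loop the variational free energy is accumulated per state factor as the negative entropy of the posterior plus the cross-entropy with the prior, which together form a factorised KL divergence, and the accuracy term is subtracted afterwards. With the stable helpers that identity can be checked directly on toy vectors (illustrative values only; the helpers are re-declared inline):

import jax.numpy as jnp
from jax.scipy.special import xlogy

MINVAL = jnp.finfo(float).eps
stable_xlogx = lambda x: xlogy(x, jnp.clip(x, MINVAL))
stable_entropy = lambda x: -stable_xlogx(x).sum()
stable_cross_entropy = lambda x, y: -xlogy(x, y).sum()

q = jnp.array([0.7, 0.3, 0.0])   # posterior over one state factor (note the exact zero)
p = jnp.array([1/3, 1/3, 1/3])   # prior over the same factor

kl = -stable_entropy(q) + stable_cross_entropy(q, p)   # negH_qs + xH_qp from the loop above
print(kl)                                               # KL[q || p], finite despite q[2] == 0
print((xlogy(q, q) - xlogy(q, p)).sum())                # same value, direct KL computation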