diff --git a/pymdp/agent.py b/pymdp/agent.py index 2bec1a16..94d093e2 100644 --- a/pymdp/agent.py +++ b/pymdp/agent.py @@ -46,7 +46,7 @@ def __init__( control_fac_idx=None, policies=None, gamma=16.0, - alpha = 16.0, + alpha=16.0, use_utility=True, use_states_info_gain=True, use_param_info_gain=False, @@ -394,7 +394,7 @@ def get_future_qs(self): return future_qs_seq - def infer_states(self, observation): + def infer_states(self, observation, distr_obs = False): """ Update approximate posterior over hidden states by solving variational inference problem, given an observation. @@ -414,7 +414,7 @@ def infer_states(self, observation): at timepoint ``t_idx``. """ - observation = tuple(observation) + observation = tuple(observation) if not distr_obs else observation if not hasattr(self, "qs"): self.reset()