diff --git a/pymdp/jax/task.py b/pymdp/jax/task.py index 5de0315e..2cf9eb0f 100644 --- a/pymdp/jax/task.py +++ b/pymdp/jax/task.py @@ -62,9 +62,6 @@ def step(self, key: PRNGKeyArray, actions: Optional[Array] = None): keys = list(jr.split(key_state, len(state_probs))) new_states = jtu.tree_map(cat_sample, keys, state_probs) - - states.append(new_states) - else: new_states = states[-1] @@ -76,4 +73,4 @@ def step(self, key: PRNGKeyArray, actions: Optional[Array] = None): keys = list(jr.split(key_obs, len(obs_probs))) new_obs = jtu.tree_map(cat_sample, keys, obs_probs) - return new_obs, tree_at(lambda x: (x.states), self, states) \ No newline at end of file + return new_obs, tree_at(lambda x: (x.states), self, [new_states]) \ No newline at end of file