From 4fa006fa61d2526e205d5043c88efe3a11d0f52b Mon Sep 17 00:00:00 2001 From: Dimitrije Markovic <5038100+dimarkov@users.noreply.github.com> Date: Wed, 3 Jul 2024 15:13:13 +0200 Subject: [PATCH] keep only last state in the environment --- pymdp/jax/task.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pymdp/jax/task.py b/pymdp/jax/task.py index 5de0315e..2cf9eb0f 100644 --- a/pymdp/jax/task.py +++ b/pymdp/jax/task.py @@ -62,9 +62,6 @@ def step(self, key: PRNGKeyArray, actions: Optional[Array] = None): keys = list(jr.split(key_state, len(state_probs))) new_states = jtu.tree_map(cat_sample, keys, state_probs) - - states.append(new_states) - else: new_states = states[-1] @@ -76,4 +73,4 @@ def step(self, key: PRNGKeyArray, actions: Optional[Array] = None): keys = list(jr.split(key_obs, len(obs_probs))) new_obs = jtu.tree_map(cat_sample, keys, obs_probs) - return new_obs, tree_at(lambda x: (x.states), self, states) \ No newline at end of file + return new_obs, tree_at(lambda x: (x.states), self, [new_states]) \ No newline at end of file