From a958270ac8807ae03a0f3dfdc8d48979a19e3c02 Mon Sep 17 00:00:00 2001 From: Tim Verbelen Date: Wed, 2 Oct 2024 20:01:24 +0200 Subject: [PATCH] fix reset --- pymdp/envs/env.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pymdp/envs/env.py b/pymdp/envs/env.py index 2814a2a8..d1cd769e 100644 --- a/pymdp/envs/env.py +++ b/pymdp/envs/env.py @@ -19,6 +19,7 @@ def cat_sample(key, p): if p.ndim > 1: choice = lambda key, p: jr.choice(key, a, p=p) keys = jr.split(key, len(p)) + print(keys.shape) return vmap(choice)(keys, p) return jr.choice(key, a, p=p) @@ -42,7 +43,7 @@ def __init__(self, params: Dict, dependencies: Dict, init_state: List[Array] = N @vmap def reset(self, key: Optional[PRNGKeyArray], state: Optional[List[Array]] = None): - if state is None: + if state is not None: state = self.state else: probs = self.params["D"]