Skip to content

Commit

Permalink
fix reset
Browse files Browse the repository at this point in the history
  • Loading branch information
Tim Verbelen committed Oct 2, 2024
1 parent 18f6903 commit a958270
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion pymdp/envs/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def cat_sample(key, p):
if p.ndim > 1:
choice = lambda key, p: jr.choice(key, a, p=p)
keys = jr.split(key, len(p))
print(keys.shape)
return vmap(choice)(keys, p)

return jr.choice(key, a, p=p)
Expand All @@ -42,7 +43,7 @@ def __init__(self, params: Dict, dependencies: Dict, init_state: List[Array] = N

@vmap
def reset(self, key: Optional[PRNGKeyArray], state: Optional[List[Array]] = None):
if state is None:
if state is not None:
state = self.state
else:
probs = self.params["D"]
Expand Down

0 comments on commit a958270

Please sign in to comment.