From da3716c49734ab42578336ed023bd81397e87304 Mon Sep 17 00:00:00 2001
From: conorheins
Date: Wed, 2 Oct 2024 12:06:00 +0200
Subject: [PATCH] - renamed `PyMDPEnv` to `Env` throughout the repo

- started version of `T-Maze` environment in new jax `Env` class

Co-authored-by: Tim Verbelen
---
 pymdp/envs/__init__.py |   2 +-
 pymdp/envs/env.py      |  16 ++++-
 pymdp/envs/rollout.py  |   8 +--
 pymdp/envs/tmaze.py    | 135 +++++++++++++++++++++++++++++++++++++++++
 4 files changed, 154 insertions(+), 7 deletions(-)
 create mode 100644 pymdp/envs/tmaze.py

diff --git a/pymdp/envs/__init__.py b/pymdp/envs/__init__.py
index afac7577..b02c781d 100644
--- a/pymdp/envs/__init__.py
+++ b/pymdp/envs/__init__.py
@@ -1,2 +1,2 @@
-from .env import PyMDPEnv
 from .graph_worlds import GraphEnv
+from .tmaze import TMaze
diff --git a/pymdp/envs/env.py b/pymdp/envs/env.py
index 81c3a062..a23d241a 100644
--- a/pymdp/envs/env.py
+++ b/pymdp/envs/env.py
@@ -24,7 +24,7 @@ def cat_sample(key, p):
     return jr.choice(key, a, p=p)


-class PyMDPEnv(Module):
+class Env(Module):
     params: Dict
     state: List[Array]
     dependencies: Dict = field(static=True)
@@ -46,7 +46,19 @@ def reset(self, key: Optional[PRNGKeyArray] = None):
             keys = list(jr.split(key, len(probs)))
             state = jtu.tree_map(cat_sample, keys, probs)

-        return tree_at(lambda x: x.state, self, state)
+        return tree_at(lambda x: x.state, self, state)  # TODO: change this to return an observation
+
+    def render(self, mode="human"):
+        """
+
+        Returns
+        -------
+        if mode == "human":
+            returns None, renders the environment using matplotlib inside the function
+        elif mode == "rgb_array":
+            A (H, W, 3) uint8 jax.numpy array, with values between 0 and 255
+        """
+        pass

     @vmap
     def step(self, rng_key: PRNGKeyArray, actions: Optional[Array] = None):
diff --git a/pymdp/envs/rollout.py b/pymdp/envs/rollout.py
index 3eb394c9..342f352b 100644
--- a/pymdp/envs/rollout.py
+++ b/pymdp/envs/rollout.py
@@ -4,10 +4,10 @@
 import jax.lax

 from pymdp.jax.agent import Agent
-from pymdp.jax.envs.env import PyMDPEnv
+from pymdp.jax.envs.env import Env


-def rollout(agent: Agent, env: PyMDPEnv, num_timesteps: int, rng_key: jr.PRNGKey, policy_search=None):
+def rollout(agent: Agent, env: Env, num_timesteps: int, rng_key: jr.PRNGKey, policy_search=None):
     """
     Rollout an agent in an environment for a number of timesteps.

@@ -15,7 +15,7 @@ def rollout(agent: Agent, env: PyMDPEnv, num_timesteps: int, rng_key: jr.PRNGKey
     ----------
     agent: ``Agent``
         Agent to interact with the environment
-    env: ``PyMDPEnv`
+    env: ``Env``
         Environment to interact with
     num_timesteps: ``int``
         Number of timesteps to rollout for
@@ -32,7 +32,7 @@ def rollout(agent: Agent, env: PyMDPEnv, num_timesteps: int, rng_key: jr.PRNGKey
         Carry dictionary from the last timestep
     info: ``dict``
         Dictionary containing information about the rollout, i.e. executed actions, observations, beliefs, etc.
-    env: ``PyMDPEnv``
+    env: ``Env``
         Environment state after the rollout
     """
     # get the batch_size of the agent
diff --git a/pymdp/envs/tmaze.py b/pymdp/envs/tmaze.py
new file mode 100644
index 00000000..50f4f07d
--- /dev/null
+++ b/pymdp/envs/tmaze.py
@@ -0,0 +1,135 @@
+import jax.numpy as jnp
+from equinox import field
+
+from pymdp.envs.env import Env
+
+
+class TMaze(Env):
+    """
+    Implementation of the 3-arm T-Maze environment.
+ """ + reward_probability: float = field(static=True) + + def __init__(self, batch_size=1, reward_probability=0.98, reward_condition=None): + self.reward_probability = reward_probability + + A, A_dependencies = self.generate_A() + A = [jnp.broadcast_to(a, (batch_size,) + a.shape) for a in A] + B, B_dependencies = self.generate_B() + B = [jnp.broadcast_to(b, (batch_size,) + b.shape) for b in B] + D = self.generate_D(reward_condition) + + params = { + "A": A, + "B": B, + "D": D, + } + + dependencies = { + "A": A_dependencies, + "B": B_dependencies, + } + + super().__init__(params, dependencies) + + + def generate_A(self): + """ + T-maze has 3 observation modalities: location, reward and cue, + and 2 state factors: agent location [center, left, right, cue] and reward location [left, right] + """ + A = [] + A.append(jnp.eye(4)) + A.append(jnp.zeros([2, 4, 2])) + A.append(jnp.zeros([2, 4, 2])) + + A_dependencies = [[0], [0, 1], [0, 1]] + + # 4 locations : [center, left, right, cue] + for loc in range(4): + # 2 reward conditions: [left, right] + for reward_condition in range(2): + # start location + if loc == 0: + # When in the centre location, reward observation is always 'no reward' + # or the outcome with index 0 + A[1] = A[1].at[0, loc, reward_condition].set(1.0) + + # When in the centre location, cue is totally ambiguous with respect to the reward condition + A[2] = A[2].at[:, loc, reward_condition].set(0.5) + + # The case when loc == 3, or the cue location ('bottom arm') + elif loc == 3: + + # When in the cue location, reward observation is always 'no reward' + # or the outcome with index 0 + A[1] = A[1].at[0, loc, reward_condition].set(1.0) + + # When in the cue location, the cue indicates the reward condition umambiguously + # signals where the reward is located + A[2] = A[2].at[reward_condition, loc, reward_condition].set(1.0) + + # The case when the agent is in one of the (potentially) rewarding arms + else: + + # When location is consistent with reward condition + if loc == (reward_condition + 1): + # Means highest probability is concentrated over reward outcome + high_prob_idx = 1 + # Lower probability on loss outcome + low_prob_idx = 2 + else: + # Means highest probability is concentrated over loss outcome + high_prob_idx = 2 + # Lower probability on reward outcome + low_prob_idx = 1 + + A[1] = A[1].at[high_prob_idx, loc, reward_condition].set(self.reward_probility) + A[1] = A[1].at[low_prob_idx, loc, reward_condition].set(1 - self.reward_probability) + + # Cue is ambiguous when in the reward location + A[2] = A[2].at[:, loc, reward_condition].set(0.5) + + return A, A_dependencies + + + def generate_B(self): + """ + T-maze has 2 state factors: + agent location [center, left, right, cue] and reward location [left, right] + agent can move between locations by teleporting, reward location stays fixed + """ + B = [] + + # agent can teleport to any location + B_loc = jnp.eye(4) + B_loc = B_loc.reshape(4, 4, 1) + B_loc = jnp.tile(B_loc, (1, 1, 4)) + B_loc = B_loc.transpose(1, 2, 0) + B.append(B_loc) + + # reward condition stays fixed + B_reward = jnp.eye(2).reshape(2, 2, 1) + B.append(B_reward) + + B_dependencies = [[0], [1]] + + return B, B_dependencies + + def generate_D(self, reward_condition=None): + """ + Agent starts at center + Reward condition can be set or randomly sampled + """ + D = [] + D_loc = jnp.zeros([4]) + D_loc = D_loc.at[0].set(1.0) + D.append(D_loc) + + if reward_condition is None: + D_reward = jnp.ones(2) * 0.5 + else: + D_reward = jnp.zeros(2) + D_reward = 
+            D_reward = D_reward.at[reward_condition].set(1.0)
+        D.append(D_reward)
+        return D
\ No newline at end of file
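
A quick usage sketch for reviewers (illustrative only, not part of the diff). It relies only on what the patch adds: the `TMaze` export in `pymdp/envs/__init__.py`, the constructor arguments `batch_size`, `reward_probability` and `reward_condition`, and the `params` dict that `TMaze.__init__` hands to `Env.__init__`. The shape comments assume the reward modality `A[1]` has the usual three outcomes (no reward, reward, loss).

    import jax.numpy as jnp

    from pymdp.envs import TMaze

    # Single-batch T-maze; leaving reward_condition=None means the rewarded arm
    # is drawn from a uniform prior D at reset time.
    env = TMaze(batch_size=1, reward_probability=0.98)

    # The generative model lives on the params dict set up in __init__:
    # A[1] is the reward modality over (outcome, location, reward condition),
    # B[0] is the teleporting location transition over (next, current, action).
    print(env.params["A"][1].shape)  # (1, 3, 4, 2) after batch broadcasting
    print(env.params["B"][0].shape)  # (1, 4, 4, 4) after batch broadcasting
    print(jnp.allclose(env.params["B"][0].sum(axis=1), 1.0))  # transition columns normalise to 1

From here, `env.reset(key)` samples the initial location and reward condition from `D`, and a closed-loop run would go through `rollout(agent, env, num_timesteps, rng_key)` from `pymdp/envs/rollout.py` once a matching `Agent` is set up.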