From 9c0544557f92966802d5e97c4e13530ca92decca Mon Sep 17 00:00:00 2001 From: conorheins Date: Thu, 8 Dec 2022 13:14:55 +0100 Subject: [PATCH 1/2] Added two new environments to visual_foraging.py environment module: scene construction and random dot motion --- pymdp/envs/visual_foraging.py | 328 +++++++++++++++++++++++++++++++++- 1 file changed, 327 insertions(+), 1 deletion(-) diff --git a/pymdp/envs/visual_foraging.py b/pymdp/envs/visual_foraging.py index b3cf8996..4b670371 100644 --- a/pymdp/envs/visual_foraging.py +++ b/pymdp/envs/visual_foraging.py @@ -8,12 +8,13 @@ """ from pymdp.envs import Env +from pymdp import utils, maths import numpy as np +from itertools import permutations, product LOCATION_ID = 0 SCENE_ID = 1 - class VisualForagingEnv(Env): """ Implementation of the visual foraging environment used for scene construction simulations """ @@ -146,3 +147,328 @@ def state(self): @property def true_scene(self): return self._true_scene + +scene_names = ["UP_RIGHT", "RIGHT_DOWN", "DOWN_LEFT", "LEFT_UP"] # possible scenes +quadrant_names = ['1','2','3','4'] +choice_names = ['choose_UP_RIGHT','choose_RIGHT_DOWN','choose_DOWN_LEFT', 'choose_LEFT_UP'] # possible choices +config_names = list(permutations([1,2,3,4], 2)) +all_scenes_all_configs = list(product(scene_names, config_names)) + +motion_dir = ['null','UP','RIGHT','DOWN','LEFT'] +n_states = len(motion_dir) +sampling_states = ['sample', 'break'] + +class SceneConstruction(Env): + + def __init__(self, starting_loc = 'start', scene_name = 'UP_RIGHT', config = "1_2"): + + pos1, pos2 = config.split("_") + config_tuple = (int(pos1), int(pos2)) + + assert scene_name in scene_names, f"{scene_name} is not a possible scene! please choose from {scene_names[0]}, {scene_names[1]}, {scene_names[2]}, or {scene_names[3]}\n" + assert config_tuple in config_names, f"{config} is not a possible spatial configuration! Please choose an appropriate 2x2 spatial configuration\n" + + self.current_location = starting_loc + self.scene_name = scene_name + self.config = config + self._create_visual_array() + + print(f'Starting location is {self.current_location}, Scene is {self.scene_name}, Configuration is {self.config}\n') + + def step(self,action_label): + + location = self.current_location + + if action_label == 'start': + + new_location = 'start' + what_obs = 'null' + + elif action_label in quadrant_names: + + what_obs = self.vis_array_flattened[int(action_label)-1] + new_location = action_label + + elif action_label in choice_names: + new_location = action_label + + chosen_scene_name = new_location.split('_')[1] + '_' + new_location.split('_')[2] + + if chosen_scene_name== self.scene_name: + what_obs = 'correct!' + else: + what_obs = 'incorrect!' + + self.current_location = new_location # store the new grid location + + return what_obs, self.current_location + + def reset(self): + self.current_location = "start" + print(f'Re-initialized location to Start location') + what_obs = 'null' + + return what_obs, self.current_location + + def _create_visual_array(self): + """ Create scene array """ + + vis_array_flattened = np.array(['null', 'null', 'null', 'null'],dtype="