git clone git@github.com:porthok/RL_Simple_Shapes.git && cd RL_Simple_Shapes
pip install -e .
pip install git+https://github.com/porthok/RL_Simple_Shapes.git
This environment was implemented for the paper "Zero-shot cross-modal transfer of Reinforcement Learning policies through a Global Workspace", presented at the Reinforcement Learning Conference 2024.
The environment consists of a single shape contained in a 32x32 image. The shape is one of a triangle, an "egg", or a "diamond", and has the following attributes:
- a size
- a location (x, y)
- a rotation
- a color
The goal of the agent is to place the shape at the center of the image pointing to the top.
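As an illustration of the goal stated above, here is a minimal sketch of a success check. It is not taken from the repository: the target position (16, 16), the convention that an angle of 0 means "pointing to the top", the tolerances, and the function names are all assumptions made for this example.

```python
import math

IMAGE_SIZE = 32
CENTER = (IMAGE_SIZE / 2, IMAGE_SIZE / 2)  # assumed target position
TARGET_ANGLE = 0.0  # assumed to mean "pointing to the top"


def angular_distance(a, b):
    """Smallest absolute difference between two angles, in radians."""
    diff = (a - b) % (2 * math.pi)
    return min(diff, 2 * math.pi - diff)


def goal_reached(x, y, angle, pos_tol=0.5, angle_tol=math.pi / 64):
    """True when the shape is centred and upright within the tolerances.

    Both tolerances are hypothetical; the environment may use other values.
    """
    at_center = math.hypot(x - CENTER[0], y - CENTER[1]) <= pos_tol
    upright = angular_distance(angle, TARGET_ANGLE) <= angle_tol
    return at_center and upright
```

Comparing angles through `angular_distance` avoids treating a rotation just below a full turn as far from the target.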
The shape can only translate or rotate at each time step. This leads to 6 different possible actions:
- Move up by one pixel
- Move down by one pixel
- Move right by one pixel
- Move left by one pixel
- Rotate right by an angle of $\frac{\pi}{32}$
- Rotate left by an angle of $\frac{\pi}{32}$
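The six actions can be sketched as a simple state update. This is an illustrative sketch, not the environment's code: the integer encoding of the actions, the image-coordinate convention (y grows downward), and the sign chosen for "rotate right" are assumptions, and `apply_action` is a hypothetical helper.

```python
import math

# Hypothetical action indices: the README lists the six actions but not
# their integer encoding, so this ordering is an assumption.
UP, DOWN, RIGHT, LEFT, ROT_RIGHT, ROT_LEFT = range(6)
ANGLE_STEP = math.pi / 32  # rotation increment stated in the list above


def apply_action(x, y, angle, action):
    """Return the new (x, y, angle) after one action.

    Assumes image coordinates (moving "up" decreases y) and that
    "rotate right" decreases the angle; both conventions are guesses.
    """
    if action == UP:
        y -= 1
    elif action == DOWN:
        y += 1
    elif action == RIGHT:
        x += 1
    elif action == LEFT:
        x -= 1
    elif action == ROT_RIGHT:
        angle -= ANGLE_STEP
    elif action == ROT_LEFT:
        angle += ANGLE_STEP
    else:
        raise ValueError(f"unknown action: {action}")
    # Wrap the angle into [0, 2*pi) so repeated rotations stay bounded.
    return x, y, angle % (2 * math.pi)
```

With a step of $\frac{\pi}{32}$, 64 rotations in the same direction make a full turn.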
This environment is designed to be multimodal. By default, it returns a dictionary composed of two modalities: {'attr': attributes, 'v': image}. It can also return only one of the two if specified.
spaces.Dict({ 'attr': spaces.Box(low=np.array([0, -32, -32, 7, -1, -1, 0, 0, 0]), high=np.array([2, 32, 32, 14, 1, 1, 255, 255, 255])), 'v': spaces.Box(low=0, high=255, shape=(32,32,3))})
The attributes represent the state of the environment. They form a vector of 9 floats describing the shape:
- shape's class: 0 for diamond, 1 for egg, 2 for triangle
- x position $\in[0,32]$
- y position $\in[0,32]$
- size $\in[0,14]$
- cosine of the rotation $\in[-1,1]$
- sine of the rotation $\in[-1,1]$
- the RGB color $\in[0,255]^{3}$
spaces.Box(low=np.array([0, -32, -32, 7, -1, -1, 0, 0, 0]), high=np.array([2, 32, 32, 14, 1, 1, 255, 255, 255]))
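To make the layout of the attribute vector concrete, here is a small sketch that packs a shape's state into the 9-float format and checks it against the bounds of the Box above. The helper names (`make_attr_vector`, `in_bounds`, `decode_angle`) are hypothetical, and the rotation is assumed to be given in radians; only the component order and the bounds come from this README.

```python
import math

# Bounds copied from the spaces.Box definition above.
ATTR_LOW = [0, -32, -32, 7, -1, -1, 0, 0, 0]
ATTR_HIGH = [2, 32, 32, 14, 1, 1, 255, 255, 255]


def make_attr_vector(shape_class, x, y, size, angle, rgb):
    """Pack a shape's state into the 9-float layout described above.

    `angle` is assumed to be in radians.
    """
    r, g, b = rgb
    return [
        float(shape_class), float(x), float(y), float(size),
        math.cos(angle), math.sin(angle),
        float(r), float(g), float(b),
    ]


def in_bounds(attr):
    """Check every component against the Box limits."""
    return len(attr) == 9 and all(
        lo <= value <= hi
        for lo, value, hi in zip(ATTR_LOW, attr, ATTR_HIGH)
    )


def decode_angle(attr):
    """Recover the rotation angle from its cosine and sine components."""
    return math.atan2(attr[5], attr[4])
```

Storing the rotation as a (cosine, sine) pair keeps the representation continuous across a full turn; `math.atan2` recovers the angle when it is needed.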
The visual part is a 32x32 RGB image with pixel values between 0 and 255.
spaces.Box(low=0, high=255, shape=(32,32,3))
This environment was created to evaluate the model introduced by the paper "Semi-supervised multimodal representation learning through a global workspace".
This model is inspired by the Global Workspace introduced by Baars in cognitive science. It has been adapted to link multimodal data in a central space. This implementation has two modalities: vision and attributes. The raw data are encoded to unimodal latent vectors (
- Semi-supervised multimodal representation learning through a global workspace, B. Devillers, L. Maytie, R. VanRullen, IEEE Transactions on Neural Networks and Learning Systems, 2024
or
- Zero-shot cross-modal transfer of Reinforcement Learning policies through a Global Workspace, L. Maytie, B. Devillers, A. Arnold, R. VanRullen, Reinforcement Learning Journal 2024
The goal was to test whether zero-shot cross-modal policy transfer is possible using this model. The paper uses several models trained with different amounts of data and different losses. The checkpoints from these training runs are available in the Simple_Shapes_RL/checkpoints/ folder.
The checkpoints used in the paper "Zero-shot cross-modal transfer of Reinforcement Learning policies through a Global Workspace" are organized as follows in Simple_Shapes_RL/checkpoints/:
- VAE_ckpt/: contains the checkpoint of the VAE used to convert raw images to the unimodal latent vector $o^v$
- GW_5000_ckpt/: contains the checkpoint of the GW trained with all the losses $(L_{tr}+L_{cont}+L_{dcy}+L_{cy})$ and with 5000 paired image-attribute examples
- GW_ckpt/: contains the checkpoint of the GW trained with all the losses $(L_{tr}+L_{cont}+L_{dcy}+L_{cy})$ and with 500 000 paired image-attribute examples
- GW_cont_5000_ckpt/: contains the checkpoint of the GW trained with the contrastive loss only $(L_{cont})$ and with 5000 paired image-attribute examples
- GW_cont_ckpt/: contains the checkpoint of the GW trained with the contrastive loss only $(L_{cont})$ and with 500 000 paired image-attribute examples
- GW_trad_cont_5000_ckpt/: contains the checkpoint of the GW trained with the contrastive and translation losses only $(L_{tr}+L_{cont})$ and with 5000 paired image-attribute examples
- GW_trad_cont_ckpt/: contains the checkpoint of the GW trained with the contrastive and translation losses only $(L_{tr}+L_{cont})$ and with 500 000 paired image-attribute examples
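When scripting experiments over several checkpoints, the folder naming above can be captured in a small lookup. This is only a sketch: the folder names come from this README, but the loss labels ("all", "cont", "tr+cont") and the helper `gw_checkpoint_dir` are hypothetical, and the files inside each folder are not described here, so the function stops at the directory level.

```python
from pathlib import Path

CHECKPOINT_ROOT = Path("Simple_Shapes_RL/checkpoints")

# (loss label, number of paired examples) -> folder name.
# The loss labels are made up for this example.
GW_CHECKPOINTS = {
    ("all", 5000): "GW_5000_ckpt",
    ("all", 500000): "GW_ckpt",
    ("cont", 5000): "GW_cont_5000_ckpt",
    ("cont", 500000): "GW_cont_ckpt",
    ("tr+cont", 5000): "GW_trad_cont_5000_ckpt",
    ("tr+cont", 500000): "GW_trad_cont_ckpt",
}


def gw_checkpoint_dir(losses, n_paired, root=CHECKPOINT_ROOT):
    """Return the folder of the GW checkpoint for a training setup."""
    try:
        return Path(root) / GW_CHECKPOINTS[(losses, n_paired)]
    except KeyError:
        raise ValueError(
            f"no checkpoint for losses={losses!r}, n_paired={n_paired}"
        ) from None
```

For example, `gw_checkpoint_dir("cont", 5000)` points to the GW trained with the contrastive loss only on 5000 pairs.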
Several scripts are provided to experiment with this environment:
- train.py: trains an RL algorithm using PPO (Stable Baselines3) and the Global Workspace
- test.py: tests a model
- visualisation.py: visualizes the policy learnt by PPO
Please ensure proper citations when incorporating this work into your projects.
@article{maytie2024zero,
title={Zero-shot cross-modal transfer of Reinforcement Learning policies through a Global Workspace},
author={Mayti{\'{e}}, L{\'{e}}opold and Devillers, Benjamin and Arnold, Alexandre and VanRullen, Rufin},
journal={Reinforcement Learning Journal},
volume={3},
pages={1410--1426},
year={2024}
}