In this project, you will train a JEPA world model on a set of pre-collected trajectories from a toy environment involving an agent in two rooms.
Joint Embedding Predictive Architecture (JEPA) is an energy-based architecture for self-supervised learning, first proposed by LeCun (2022). Essentially, it works by asking the model to predict its own representations of future observations.
More formally, in the context of this problem, given an agent trajectory $\tau$, i.e. an observation-action sequence $\tau = (o_0, u_0, o_1, u_1, \ldots, u_{N-1}, o_N)$, a recurrent JEPA predicts the representations of future observations as:

$$
\begin{aligned}
\text{Encoder:}\quad &\tilde{s}_0 = s_0 = \text{Enc}_\theta(o_0) \\
\text{Predictor:}\quad &\tilde{s}_n = \text{Pred}_\phi(\tilde{s}_{n-1}, u_{n-1})
\end{aligned}
$$

Where $\tilde{s}_n$ is the predicted representation at time step $n$, $s_n$ is the encoder output at time step $n$, and $u_n$ is the action taken at time step $n$.

The architecture may also be teacher-forcing (non-recurrent):

$$
\begin{aligned}
\text{Encoder:}\quad &s_n = \text{Enc}_\theta(o_n) \\
\text{Predictor:}\quad &\tilde{s}_n = \text{Pred}_\phi(s_{n-1}, u_{n-1})
\end{aligned}
$$

The JEPA training objective is to minimize the energy for the observation-action sequence $\tau$, given by the sum of distances between predicted and target representations:

$$
E(\tau) = \sum_{n=1}^{N} D(\tilde{s}_n, s'_n), \qquad s'_n = \text{TargetEnc}_\psi(o_n)
$$

Where the Target Encoder $\text{TargetEnc}_\psi$ may be identical to the Encoder $\text{Enc}_\theta$, or a variant of it (for example, an exponential moving average of its weights, as in BYOL), and $D(\cdot, \cdot)$ is a distance function such as mean squared error. Note that naively minimizing this energy admits a collapsed solution in which every observation maps to the same representation, which is why a collapse-prevention mechanism is required (see the constraints below).
Here's a diagram illustrating a recurrent JEPA for 4 timesteps:
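To make the recurrence concrete, below is a minimal PyTorch sketch of a recurrent JEPA rollout. The layer sizes and names (`RecurrentJEPA`, `state_dim`) are illustrative assumptions, not a prescribed architecture; it only shows how $\tilde{s}_0 = \text{Enc}_\theta(o_0)$ and $\tilde{s}_n = \text{Pred}_\phi(\tilde{s}_{n-1}, u_{n-1})$ compose.

```python
import torch
import torch.nn as nn

class RecurrentJEPA(nn.Module):
    """Minimal sketch: encoder embeds o_0, predictor rolls forward with actions."""

    def __init__(self, state_dim=256, action_dim=2):
        super().__init__()
        # Illustrative encoder: 2-channel 64x64 observation -> state vector.
        self.encoder = nn.Sequential(
            nn.Conv2d(2, 32, 4, stride=2, padding=1), nn.ReLU(),   # 64 -> 32
            nn.Conv2d(32, 64, 4, stride=2, padding=1), nn.ReLU(),  # 32 -> 16
            nn.Conv2d(64, 128, 4, stride=2, padding=1), nn.ReLU(), # 16 -> 8
            nn.Flatten(),
            nn.Linear(128 * 8 * 8, state_dim),
        )
        # Illustrative predictor: (s_{n-1}, u_{n-1}) -> s_n.
        self.predictor = nn.Sequential(
            nn.Linear(state_dim + action_dim, state_dim), nn.ReLU(),
            nn.Linear(state_dim, state_dim),
        )

    def forward(self, obs_0, actions):
        # obs_0: (B, 2, 64, 64); actions: (B, N, 2)
        s = self.encoder(obs_0)  # s_0 = Enc(o_0)
        preds = [s]
        for n in range(actions.shape[1]):
            # Feed the model's own previous prediction back in (recurrent JEPA).
            s = self.predictor(torch.cat([s, actions[:, n]], dim=-1))
            preds.append(s)
        return torch.stack(preds, dim=1)  # (B, N+1, state_dim)
```

A teacher-forcing variant would instead encode every observation $o_{n-1}$ and feed the encoder output $s_{n-1}$ to the predictor, rather than the model's own previous prediction.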
The dataset consists of random trajectories collected from a toy environment: an agent (a dot) in two rooms separated by a wall, with a door in that wall. The agent cannot travel through the border wall or the middle wall (except through the door). Different trajectories may have different wall and door positions, so your JEPA model needs to be able to perceive and distinguish environment layouts. Two training trajectories with different layouts are depicted below.
Your task is to implement and train a JEPA architecture on a dataset of 2.5M frames of exploratory trajectories (see the images above). Your model will then be evaluated on how well its predicted representations capture the agent's true $(x, y)$ location.
Here are the constraints:
- It has to be a JEPA architecture: you must train it by minimizing the distance between predictions and targets in representation space, while preventing collapse.
- You can try various methods of preventing collapse, except image reconstruction. That is, you cannot reconstruct target images as part of your objective, as is done in MAE. (One non-reconstructive option is sketched below.)
- You have to rely only on the data provided in the folder `/scratch/DL24FA/train`. However, you are allowed to apply image augmentation.
Failing to meet the above constraints will result in deducted points or even zero points for the project.
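As one example of a non-reconstructive way to prevent collapse (referenced in the constraints above), here is a sketch of a VICReg-style variance-covariance regularizer (Bardes et al., 2022). The function name and the `gamma` and `eps` values are illustrative assumptions; this is one option among several, not a required component.

```python
import torch

def vicreg_var_cov_penalty(s, gamma=1.0, eps=1e-4):
    """VICReg-style regularizer on a (batch, dim) batch of representations.

    The variance term pushes each dimension's std above gamma (anti-collapse);
    the covariance term decorrelates dimensions so information spreads out.
    """
    s = s - s.mean(dim=0)
    std = torch.sqrt(s.var(dim=0) + eps)
    var_loss = torch.relu(gamma - std).mean()       # hinge on per-dimension std
    cov = (s.T @ s) / (s.shape[0] - 1)              # (dim, dim) covariance matrix
    off_diag = cov - torch.diag(torch.diag(cov))
    cov_loss = off_diag.pow(2).sum() / s.shape[1]   # penalize off-diagonal terms
    return var_loss, cov_loss
```

In a full objective, these two terms would be added, with tuned weights, to the prediction distance $D(\tilde{s}_n, s'_n)$.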
How do we evaluate the quality of our encoded and predicted representations?
One way to do this is through probing: we can see how well certain ground-truth information can be extracted from the learned representations. In this particular setting, we will unroll the JEPA world model recurrently from the first observation using the ground-truth action sequence, then train a prober on top of the predicted representations to extract the agent's $(x, y)$ location.
The smaller the MSE loss on the probing validation dataset, the better our learned representations are at capturing the particular information we care about, in this case the agent's location. (We could also probe for other things, such as wall or door locations, but here we focus only on agent location.)
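For intuition, the probing procedure might look like the sketch below: freeze the trained world model, unroll it with the ground-truth actions, and train a small head to regress the agent's $(x, y)$ coordinates with MSE. The prober architecture and the data-loader interface here are illustrative assumptions; the actual prober is defined by the provided evaluation code.

```python
import torch
import torch.nn as nn

def probe_agent_location(model, loader, state_dim=256, epochs=10, device="cuda"):
    """Train a small MLP to regress (x, y) from frozen predicted representations."""
    prober = nn.Sequential(
        nn.Linear(state_dim, 128), nn.ReLU(), nn.Linear(128, 2)
    ).to(device)
    opt = torch.optim.Adam(prober.parameters(), lr=1e-3)
    model.eval()
    for _ in range(epochs):
        for obs_0, actions, locations in loader:  # locations: (B, N+1, 2) ground truth
            with torch.no_grad():                 # the world model stays frozen
                preds = model(obs_0.to(device), actions.to(device))
            loss = nn.functional.mse_loss(prober(preds), locations.to(device))
            opt.zero_grad()
            loss.backward()
            opt.step()
    return prober
```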
The evaluation code is already implemented, so you just need to plug in your trained model to run it.
The evaluation script will train the prober on 170k frames of agent trajectories loaded from the folder `/scratch/DL24FA/probe_normal/train`, and evaluate it on validation sets, reporting the mean squared error between probed and true global agent coordinates. There are two known validation sets, loaded from `/scratch/DL24FA/probe_normal/val` and `/scratch/DL24FA/probe_wall/val`. The first contains trajectories similar to those in the training set, while the second consists of trajectories in which the agent runs straight towards the wall (and sometimes the door); this tests how well your model learns the dynamics of stopping at the wall.
There are two other validation sets that are not released; they will be used to test how good your model is at long-horizon prediction, and how well it generalizes to unseen novel layouts (detail: during training, the wall never appears within a certain range of x positions, and we want to see how well your model performs when the wall is placed in that range).
Each team will be evaluated on:
- MSE error on `probe_normal`. Weight 1
- MSE error on `probe_wall`. Weight 1
- MSE error on the long-horizon probing test. Weight 1
- MSE error on the out-of-domain wall probing test. Weight 1
- Parameter count of your model (fewer parameters --> more points). Weight 0.25
The teams are first scored on each criterion separately; the final ranking combines the per-criterion ranks using the weights above.
To get started:
- Follow the instructions to set up HPC and Singularity.
- Clone the repo, `cd` into it, and run `pip install -r requirements.txt`.
The training data can be found in `/scratch/DL24FA/train/states.npy` and `/scratch/DL24FA/train/actions.npy`. States have shape (num_trajectories, trajectory_length, 2, 64, 64). Each observation is a two-channel image: the first channel represents the agent, and the second represents the border and walls. Actions have shape (num_trajectories, trajectory_length - 1, 2); each action is a (delta x, delta y) vector specifying the agent's position shift from its previous global position.
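Given those shapes, inspecting the arrays might look like the sketch below; using `mmap_mode="r"` is an optional assumption so that the 2.5M frames are not read into RAM at once.

```python
import numpy as np

# Memory-map the arrays so the full dataset is not loaded into RAM at once.
states = np.load("/scratch/DL24FA/train/states.npy", mmap_mode="r")
actions = np.load("/scratch/DL24FA/train/actions.npy", mmap_mode="r")

print(states.shape)   # (num_trajectories, trajectory_length, 2, 64, 64)
print(actions.shape)  # (num_trajectories, trajectory_length - 1, 2)

# actions[i, n] is the (dx, dy) taken between states[i, n] and states[i, n + 1].
first_traj_obs = states[0]    # one trajectory: (trajectory_length, 2, 64, 64)
first_action = actions[0, 0]  # (dx, dy) between frames 0 and 1
```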
The probing train dataset can be found in `/scratch/DL24FA/probe_normal/train`.
The probing val datasets can be found in `/scratch/DL24FA/probe_normal/val` and `/scratch/DL24FA/probe_wall/val`.
Please implement your own training script and model architecture as a part of this existing codebase.
The probing evaluation is already implemented for you inside `main.py`. You just need to change some code marked by `#TODO`s, namely initializing and loading your model. You can also change how your model handles the forward pass, marked by `#TODO`s inside `evaluator.py`. DO NOT change any other parts of `main.py` and `evaluator.py`. Just run `python main.py` to evaluate your model.
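For illustration only, the plug-in step might look roughly like the following sketch. The function name `load_model` and its placement are hypothetical; follow the actual `#TODO` markers in the provided `main.py` rather than this sketch.

```python
import torch

def load_model(device="cuda"):
    """Hypothetical #TODO block: build your architecture and load trained weights."""
    model = RecurrentJEPA(state_dim=256, action_dim=2)  # e.g., the sketch from earlier
    state = torch.load("model_weights.pth", map_location=device)
    model.load_state_dict(state)
    return model.to(device).eval()
```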
Create the zipped version of the following folder for submission:
DL_Final_Proj/
├── main.py
├── evaluator.py
├── ... (other files including your new ones)
├── model_weights.pth
├── requirements.txt
├── metrics.txt
└── team_name.txt
Make sure `main.py` is runnable with your trained model, including the code that loads your model weights.
If your code requires additional packages, add them to `requirements.txt`. There is no need to include `torch` unless your code requires a specific version.
`metrics.txt` contains the following lines:
probe_normal val loss: {your loss}
probe_wall val loss: {your loss}
probe_wall_other val loss: {your loss}
probe_expert val loss: {your loss}
# number of trainable parameters of your model
# training command to replicate your submitted model
`team_name.txt` contains the name of your team and members' NetIDs.
Either email the zipped file directly (if the size permits) or upload it to any cloud storage and email the download link to TA [email protected] with the subject line "DL Final Project Submission. {Team Name}".
The TA should be able to execute the following commands to validate your submission:
- Download the file directly or `wget {link}`
- Unzip the folder
- Install dependencies: `pip install -r requirements.txt` (if applicable)
- Run the evaluation script: `python main.py`
- Run the training script given in `metrics.txt`
If the evaluation script or training script fails to execute correctly on the first try, your team's rank will be reduced by 3 positions.
Every subsequent failure after follow-up with the team will result in an additional reduction of 3 positions, with a maximum of 3 attempts; after that, the team is placed at the bottom of the ranking.
It is recommended to test your submission in a clean environment before submitting, to ensure all required dependencies are included and that your scripts run as expected.
The submission deadline is 12/16, 9:00 AM. Winners will be picked and asked to present their work on the last class day, 12/18.