A reinforcement learning algorithm is characterised by the trajectories it generates during training.
We are interested in "algorithm distillation" - whether trajectories can be modelled by transformers, as studied in the original deepmind algorithm distillation paper.
A particular focus of this repo is to extend prior work to the case where:
- the trajectories have been generated by the TRLx library during RLHF training of language models
- the transformer modelling the trajectories is itself a standard language model
A trajectory is typically defined as a list of (state, action, reward) triples. For training purposes, it is sometimes useful to augment this with logprobs: for each triple (s, a, r), the log-probability of taking action a in state s under the policy that generated the trajectory. We therefore define an RL Format Trajectory as a sequence of (state, action, reward, logprobs) tuples.
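As a concrete illustration, one way to represent this in Python is sketched below. The `Step` and `Trajectory` names are hypothetical, not the repo's actual types:

```python
from typing import NamedTuple, List
import torch

class Step(NamedTuple):
    """One (state, action, reward, logprobs) step of an RL Format Trajectory."""
    state: torch.Tensor     # observation, or prompt token ids for language tasks
    action: torch.Tensor    # action taken, e.g. completion token ids
    reward: float           # scalar reward for this step
    logprobs: torch.Tensor  # log-probability of the chosen action under the behaviour policy

# An RL Format Trajectory is then just an ordered sequence of steps.
Trajectory = List[Step]
```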
The typical way to model these trajectories with a transformer is to map the final hidden state through three separate heads. That is, for a given tuple (s, a, r, l), the shared transformer body produces a hidden state, and separate heads predict the state, action, and reward components. In this repo, this is done via the models in /models/rl.
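A minimal sketch of this three-head setup, assuming a generic PyTorch transformer encoder; this is illustrative only and not the actual code in /models/rl:

```python
import torch
import torch.nn as nn

class ThreeHeadTrajectoryModel(nn.Module):
    """Shared transformer body whose final hidden states are mapped by
    separate heads for the state, action, and reward predictions."""

    def __init__(self, d_model=128, n_actions=4, state_dim=8, n_layers=2, n_heads=4):
        super().__init__()
        layer = nn.TransformerEncoderLayer(d_model, n_heads, batch_first=True)
        self.body = nn.TransformerEncoder(layer, n_layers)
        self.state_head = nn.Linear(d_model, state_dim)   # predicts next state
        self.action_head = nn.Linear(d_model, n_actions)  # logits over actions
        self.reward_head = nn.Linear(d_model, 1)          # scalar reward

    def forward(self, x):  # x: (batch, seq_len, d_model) embedded trajectory tokens
        h = self.body(x)   # one hidden state per position
        return self.state_head(h), self.action_head(h), self.reward_head(h)

# Example usage with random embeddings:
model = ThreeHeadTrajectoryModel()
state_pred, action_logits, reward_pred = model(torch.randn(2, 10, 128))
```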
We are also interested in the ability of standard language models (with language modelling heads) to learn trajectories. To this end, we define a Language Format Trajectory as a trajectory serialised into a string. There are many possible ways to do this, and the optimal one requires investigation. For example, for trajectories generated using TRLx when finetuning a language model on positive sentiment, we can format the trajectory as the string:
```
prompt: Dan went down to the shops.
completion: He smiled as he walked - the sun was shining.
reward: 0.9975
###
```
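A small, hypothetical serialiser for one TRLx rollout into the format above; the field names and the `###` separator simply mirror the example, and the real serialisation format is still under investigation:

```python
def to_language_format(prompt: str, completion: str, reward: float) -> str:
    """Serialise a single (prompt, completion, reward) rollout to a string."""
    return (
        f"prompt: {prompt}\n"
        f"completion: {completion}\n"
        f"reward: {reward:.4f}\n"
        "###\n"
    )

print(to_language_format(
    "Dan went down to the shops.",
    "He smiled as he walked - the sun was shining.",
    0.9975,
))
```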
It's less obvious how to do this when the task is not a language task, such as a moonlander environment. Enumerating the states as coordinates might work, but requires experimentation.
Trajectories in Language Format are learnt by the models in /models/lm.
/models contains the "algorithm distillation models": transformers trained in a supervised fashion to learn RL trajectories. We distinguish between models that operate on RL Format trajectories and those that operate on Language Format trajectories.
/tasks contains the code that produces the RL trajectories the models learn from. Each task can store this data however it likes, but should expose a torch.utils.data.Dataset that can return trajectory data in either RL Format or Language Format.
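A hypothetical sketch of this Dataset contract (class and field names are illustrative, not the repo's actual code):

```python
from torch.utils.data import Dataset

class TrajectoryDataset(Dataset):
    """Serves stored trajectories in RL Format or Language Format,
    depending on which model family will consume them."""

    def __init__(self, trajectories, fmt="rl"):
        assert fmt in ("rl", "language")
        self.trajectories = trajectories  # e.g. loaded from TRLx rollout logs
        self.fmt = fmt

    def __len__(self):
        return len(self.trajectories)

    def __getitem__(self, idx):
        traj = self.trajectories[idx]
        if self.fmt == "rl":
            # list of (state, action, reward, logprobs) tuples for /models/rl
            return traj
        # serialise to a single string for /models/lm
        return "".join(
            f"prompt: {s['prompt']}\ncompletion: {s['completion']}\n"
            f"reward: {s['reward']:.4f}\n###\n"
            for s in traj
        )
```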
I am using my own fork of TRLx that has rollout logging.
Today:
[X] Set up repo structure (just for your language stuff, @H can add in his)
[X] Add train script for models/lm/casuallm
[ ] Clone H's stuff and merge with @H stuff (/models/rl) and (/tasks/rl) (25 mins)
[ ] Proper PR for TRLx (25 mins)
[ ] Post guide and project tasks on discord
Future:
[ ] Add more elegant meta class switching between ...LanguageTrajectories and ...RlTrajectories
[ ] Add online evaluation script for models/lm/casuallm
[ ] Improve train script to include reward accuracy
[ ] Run some preliminary experiments
[ ] Add main file with click CLI interface for running experiments