Constrained Policy Optimization (CPO) is a safe reinforcement learning algorithm for constrained Markov decision processes (CMDPs): it maximizes the expected return while keeping the expected cumulative cost below a given threshold. Our implementation is a port of the original OpenAI implementation to JAX.
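For reference, the original Constrained Policy Optimization paper (Achiam et al., 2017) states the CMDP objective and the approximate per-iteration update roughly as follows. J is the expected return, J_C the expected cumulative cost, d the cost limit, g and b the gradients of the reward and cost objectives at the current parameters θ_k, c = J_C(π_k) − d, H the Fisher information matrix (Hessian of the average KL divergence), and δ the trust-region size. This is the single-constraint form from the paper, not a line-by-line description of this port; consult the code for its exact update.

```latex
\max_{\pi} \; J(\pi) \quad \text{s.t.} \quad J_C(\pi) \le d

\theta_{k+1} = \arg\max_{\theta} \; g^\top (\theta - \theta_k)
\quad \text{s.t.} \quad c + b^\top (\theta - \theta_k) \le 0,
\quad \tfrac{1}{2} (\theta - \theta_k)^\top H \, (\theta - \theta_k) \le \delta
```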
First, make sure Python 3.10.12 is installed. Then run:
poetry install
See the optional installation groups in pyproject.toml for additional functionality.
Alternatively, install with pip. You have two options: clone the repository (for example, for local development and hacking) or install it as is, directly from GitHub.
- Clone:
  git clone https://github.com/lasgroup/jax-cpo.git
  cd jax-cpo
  pip install -e .
- Install directly from GitHub:
  pip install git+https://github.com/lasgroup/jax-cpo
Via Trainer class
This is the easier entry point for running experiments. A usage example is available here.
If you want to use our implementation with a different training/evaluation setup, you can use the CPO class directly. Its only required interface is the __call__(observation: np.ndarray, train: bool) -> np.ndarray method, which does the following:
- Observes the state (provided by the environment) and stores it in an episodic buffer for the next policy update.
- At each time step, uses the current policy to return an action.
- Whenever the train flag is true and the buffer is full, triggers a policy update.
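The contract described for CPO.__call__ can be illustrated with a toy stand-in. This sketch is hypothetical: ToyAgent is not the jax-cpo CPO class, its policy is a fixed random linear map, update() only counts calls instead of performing a CPO step, and buffering observations only while train is true is an assumption. A real agent also needs rewards and costs for its update; how jax-cpo receives them is not described here, so the sketch omits them.

```python
import numpy as np


class ToyAgent:
    """Hypothetical stand-in that mimics the described __call__ contract.

    Not the jax-cpo CPO class: the policy is a fixed random linear map and
    update() only counts calls instead of running a CPO step.
    """

    def __init__(self, obs_dim: int, act_dim: int, buffer_size: int, seed: int = 0):
        rng = np.random.default_rng(seed)
        self.weights = rng.normal(size=(obs_dim, act_dim))  # stand-in policy parameters
        self.buffer_size = buffer_size  # hypothetical: steps collected per update
        self.buffer: list[np.ndarray] = []
        self.num_updates = 0

    def __call__(self, observation: np.ndarray, train: bool) -> np.ndarray:
        # 1. Store the observation for the next policy update
        #    (assumption: only while training).
        if train:
            self.buffer.append(observation)
        # 2. Use the current policy to return an action in (-1, 1).
        action = np.tanh(observation @ self.weights)
        # 3. If training and the buffer is full, trigger a policy update.
        if train and len(self.buffer) >= self.buffer_size:
            self.update()
        return action

    def update(self) -> None:
        # Placeholder for the CPO step: record it and start a fresh buffer.
        self.num_updates += 1
        self.buffer.clear()


# Driver loop with random observations in place of a real environment.
agent = ToyAgent(obs_dim=3, act_dim=2, buffer_size=100)
rng = np.random.default_rng(1)
for _ in range(250):
    action = agent(rng.normal(size=3), train=True)

print(agent.num_updates, len(agent.buffer))  # 2 50
```

The driver loop only touches __call__(observation, train), which is what lets a custom training/evaluation setup wrap the agent without depending on anything else; calling it with train=False returns actions without buffering or updating.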
Consult configs.yaml for hyperparameters.