SAC jax #300

Open · wants to merge 32 commits into master

Commits (the changes shown below are from 28 of the 32 commits)
08cdd3c  Remove unused params (araffin, Oct 23, 2022)
9c8e642  Add SAC Jax version (araffin, Oct 23, 2022)
9c5da79  Upgrade SB3 (araffin, Oct 23, 2022)
d0bf6a8  Revert SB3 upgrade (gym incompatible) (araffin, Oct 23, 2022)
2ace8b5  Add progress bar (araffin, Oct 23, 2022)
1d7e36d  Fix import (araffin, Oct 23, 2022)
799eadb  Remove unused code (araffin, Oct 23, 2022)
61be9db  Update requirements (araffin, Oct 23, 2022)
8fce3a8  Update lock file (araffin, Oct 23, 2022)
631d90a  Add test for jax (araffin, Oct 23, 2022)
b2481a4  Display FPS only when needed (araffin, Oct 23, 2022)
1048689  update lock files (vwxyzjn, Oct 24, 2022)
39a20d0  fix test cases, use the same naming convention (vwxyzjn, Oct 24, 2022)
ecd66c8  Add constant ent coef support and improve types (araffin, Oct 24, 2022)
e6958a2  Add deterministic evaluation (araffin, Oct 24, 2022)
101089c  Use deterministic eval (araffin, Oct 24, 2022)
fe2295f  Rescale actions (araffin, Oct 24, 2022)
a82de8d  Use `RLTrainState` (vwxyzjn, Nov 20, 2022)
e285f99  remove `ReplayBufferSamplesNp` (vwxyzjn, Nov 20, 2022)
cac42cc  format docstring (vwxyzjn, Nov 20, 2022)
5e8cd88  reorganize code (vwxyzjn, Nov 20, 2022)
385a7ac  (docs) the reparameterization trick (vwxyzjn, Nov 20, 2022)
30258c5  Merge branch 'master' into feat/sac-jax (vwxyzjn, Nov 21, 2022)
668ea1d  update reference (vwxyzjn, Nov 21, 2022)
0cf0e9e  remove tensorflow_probability (vwxyzjn, Nov 21, 2022)
9082f57  add benchmark script (vwxyzjn, Nov 21, 2022)
9f2fd81  properly implement log probability (vwxyzjn, Nov 22, 2022)
be38473  add unit test and fix log prob calc (vwxyzjn, Nov 22, 2022)
15c30c8  Fix critic loss (araffin, Nov 28, 2022)
623da0f  Merge branch 'master' into feat/sac-jax (araffin, Jun 15, 2023)
f0ee601  Revert poetry lock to master (araffin, Jun 15, 2023)
bbec22d  Re-add tf proba (araffin, Jun 15, 2023)
benchmark/sac.sh (10 additions, 1 deletion)

```diff
@@ -4,4 +4,13 @@ OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
     --env-ids HalfCheetah-v2 Walker2d-v2 Hopper-v2 \
     --command "poetry run python cleanrl/sac_continuous_action.py --track --capture-video" \
     --num-seeds 3 \
-    --workers 3
+    --workers 3
+
+poetry install --with mujoco,pybullet,jax
+poetry run pip install --upgrade "jax[cuda]==0.3.17" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+poetry run python -c "import mujoco_py"
+poetry run python -m cleanrl_utils.benchmark \
+    --env-ids HalfCheetah-v2 Walker2d-v2 Hopper-v2 \
+    --command "poetry run python cleanrl/sac_continuous_action_jax.py --track" \
+    --num-seeds 3 \
+    --workers 1
```
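The `poetry run python -c "import mujoco_py"` step presumably just forces mujoco-py's one-time build on first import so it does not happen inside the benchmark runs; the Jax variant is launched with a single worker, unlike the three used for the PyTorch script.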
cleanrl/ddpg_continuous_action_jax.py (8 additions, 8 deletions)

```diff
@@ -108,8 +108,8 @@ def __call__(self, x):
         return x


-class TrainState(TrainState):
-    target_params: flax.core.FrozenDict
+class RLTrainState(TrainState):
+    target_params: flax.core.FrozenDict = None


 if __name__ == "__main__":
@@ -163,13 +163,13 @@ class TrainState(TrainState):
         action_bias=action_bias,
     )
     qf1 = QNetwork()
-    actor_state = TrainState.create(
+    actor_state = RLTrainState.create(
         apply_fn=actor.apply,
         params=actor.init(actor_key, obs),
         target_params=actor.init(actor_key, obs),
         tx=optax.adam(learning_rate=args.learning_rate),
     )
-    qf1_state = TrainState.create(
+    qf1_state = RLTrainState.create(
         apply_fn=qf1.apply,
         params=qf1.init(qf1_key, obs, envs.action_space.sample()),
         target_params=qf1.init(qf1_key, obs, envs.action_space.sample()),
@@ -180,8 +180,8 @@ class TrainState(TrainState):

     @jax.jit
     def update_critic(
-        actor_state: TrainState,
-        qf1_state: TrainState,
+        actor_state: RLTrainState,
+        qf1_state: RLTrainState,
         observations: np.ndarray,
         actions: np.ndarray,
         next_observations: np.ndarray,
@@ -202,8 +202,8 @@ def mse_loss(params):

     @jax.jit
     def update_actor(
-        actor_state: TrainState,
-        qf1_state: TrainState,
+        actor_state: RLTrainState,
+        qf1_state: RLTrainState,
         observations: np.ndarray,
     ):
        def actor_loss(params):
```
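For context, `RLTrainState` just extends Flax's `TrainState` with a `target_params` field so the online and target parameters travel in one object. A minimal runnable sketch of the pattern, with a Polyak (soft) target update via `optax.incremental_update`; the toy network, observation shape, and `tau` value are placeholders for illustration, not the PR's exact code:

```python
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
from flax.training.train_state import TrainState


class RLTrainState(TrainState):
    # Target-network parameters stored next to the online parameters
    target_params: flax.core.FrozenDict = None


class QNetwork(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(1)(x)


q_key = jax.random.PRNGKey(0)
obs = jnp.zeros((1, 4))  # placeholder observation batch
q_network = QNetwork()
params = q_network.init(q_key, obs)

q_state = RLTrainState.create(
    apply_fn=q_network.apply,
    params=params,
    target_params=params,
    tx=optax.adam(learning_rate=3e-4),
)

# Polyak (soft) update of the target parameters with smoothing coefficient tau
tau = 0.005
q_state = q_state.replace(
    target_params=optax.incremental_update(q_state.params, q_state.target_params, tau)
)
```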
cleanrl/dqn_atari_jax.py (3 additions, 3 deletions)

```diff
@@ -119,8 +119,8 @@ def __call__(self, x):
         return x


-class TrainState(TrainState):
-    target_params: flax.core.FrozenDict
+class RLTrainState(TrainState):
+    target_params: flax.core.FrozenDict = None


 def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
@@ -163,7 +163,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):

     q_network = QNetwork(action_dim=envs.single_action_space.n)

-    q_state = TrainState.create(
+    q_state = RLTrainState.create(
         apply_fn=q_network.apply,
         params=q_network.init(q_key, obs),
         target_params=q_network.init(q_key, obs),
```
cleanrl/dqn_jax.py (3 additions, 3 deletions)

```diff
@@ -92,8 +92,8 @@ def __call__(self, x: jnp.ndarray):
         return x


-class TrainState(TrainState):
-    target_params: flax.core.FrozenDict
+class RLTrainState(TrainState):
+    target_params: flax.core.FrozenDict = None


 def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
@@ -136,7 +136,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):

     q_network = QNetwork(action_dim=envs.single_action_space.n)

-    q_state = TrainState.create(
+    q_state = RLTrainState.create(
         apply_fn=q_network.apply,
         params=q_network.init(q_key, obs),
         target_params=q_network.init(q_key, obs),
```
cleanrl/sac_continuous_action.py (0 additions, 4 deletions)

```diff
@@ -49,8 +49,6 @@ def parse_args():
         help="target smoothing coefficient (default: 0.005)")
     parser.add_argument("--batch-size", type=int, default=256,
         help="the batch size of sample from the reply memory")
-    parser.add_argument("--exploration-noise", type=float, default=0.1,
-        help="the scale of exploration noise")
     parser.add_argument("--learning-starts", type=int, default=5e3,
         help="timestep to start learning")
     parser.add_argument("--policy-lr", type=float, default=3e-4,
@@ -61,8 +59,6 @@ def parse_args():
         help="the frequency of training policy (delayed)")
     parser.add_argument("--target-network-frequency", type=int, default=1, # Denis Yarats' implementation delays this by 2.
         help="the frequency of updates for the target nerworks")
-    parser.add_argument("--noise-clip", type=float, default=0.5,
-        help="noise clip parameter of the Target Policy Smoothing Regularization")
    parser.add_argument("--alpha", type=float, default=0.2,
         help="Entropy regularization coefficient.")
     parser.add_argument("--autotune", type=lambda x:bool(strtobool(x)), default=True, nargs="?", const=True,
```
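On the removed flags: `--exploration-noise` and `--noise-clip` are DDPG/TD3-style knobs that SAC never used, since SAC explores through its stochastic tanh-squashed Gaussian policy. That sampling step is also what the "(docs) the reparameterization trick" and "properly implement log probability" commits appear to address. A minimal JAX sketch of reparameterized sampling with the tanh change-of-variables correction; this is illustrative only and not the PR's exact implementation (the PR re-adds tensorflow_probability for its distribution):

```python
import jax
import jax.numpy as jnp


def sample_action_and_log_prob(mean, log_std, key):
    """Reparameterized sample from a tanh-squashed diagonal Gaussian policy."""
    std = jnp.exp(log_std)
    noise = jax.random.normal(key, mean.shape)
    gaussian_action = mean + std * noise  # reparameterization trick
    action = jnp.tanh(gaussian_action)
    # Diagonal Gaussian log-density, summed over action dimensions
    log_prob = -0.5 * (((gaussian_action - mean) / std) ** 2 + 2 * log_std + jnp.log(2 * jnp.pi))
    log_prob = log_prob.sum(axis=-1)
    # Change-of-variables correction for the tanh squashing
    log_prob -= jnp.log(1.0 - jnp.square(action) + 1e-6).sum(axis=-1)
    return action, log_prob


key = jax.random.PRNGKey(0)
mean = jnp.zeros((1, 6))          # placeholder policy mean for a 6-dim action space
log_std = jnp.full((1, 6), -1.0)  # placeholder log standard deviation
action, log_prob = sample_action_and_log_prob(mean, log_std, key)
```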