diff --git a/cleanrl/c51_atari.py b/cleanrl/c51_atari.py index 97b790759..53f29c62c 100755 --- a/cleanrl/c51_atari.py +++ b/cleanrl/c51_atari.py @@ -294,7 +294,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int): run_name=f"{run_name}-eval", Model=QNetwork, device=device, - epsilon=0.05, + epsilon=args.end_e, ) for idx, episodic_return in enumerate(episodic_returns): writer.add_scalar("eval/episodic_return", episodic_return, idx) diff --git a/cleanrl/c51_atari_jax.py b/cleanrl/c51_atari_jax.py index 8cd46e855..f6e74f36b 100644 --- a/cleanrl/c51_atari_jax.py +++ b/cleanrl/c51_atari_jax.py @@ -334,7 +334,7 @@ def get_action(q_state, obs): eval_episodes=10, run_name=f"{run_name}-eval", Model=QNetwork, - epsilon=0.05, + epsilon=args.end_e, ) for idx, episodic_return in enumerate(episodic_returns): writer.add_scalar("eval/episodic_return", episodic_return, idx) diff --git a/cleanrl/dqn_atari.py b/cleanrl/dqn_atari.py index a23b84391..dae131811 100644 --- a/cleanrl/dqn_atari.py +++ b/cleanrl/dqn_atari.py @@ -263,7 +263,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int): run_name=f"{run_name}-eval", Model=QNetwork, device=device, - epsilon=0.05, + epsilon=args.end_e, ) for idx, episodic_return in enumerate(episodic_returns): writer.add_scalar("eval/episodic_return", episodic_return, idx) diff --git a/cleanrl/dqn_atari_jax.py b/cleanrl/dqn_atari_jax.py index 383ceeef8..2aa563bd7 100644 --- a/cleanrl/dqn_atari_jax.py +++ b/cleanrl/dqn_atari_jax.py @@ -292,7 +292,7 @@ def mse_loss(params): eval_episodes=10, run_name=f"{run_name}-eval", Model=QNetwork, - epsilon=0.05, + epsilon=args.end_e, ) for idx, episodic_return in enumerate(episodic_returns): writer.add_scalar("eval/episodic_return", episodic_return, idx) diff --git a/cleanrl/qdagger_dqn_atari_impalacnn.py b/cleanrl/qdagger_dqn_atari_impalacnn.py index 6cde11c99..20370707e 100644 --- a/cleanrl/qdagger_dqn_atari_impalacnn.py +++ b/cleanrl/qdagger_dqn_atari_impalacnn.py @@ -263,7 +263,7 @@ def kl_divergence_with_logits(target_logits, prediction_logits): eval_episodes=args.teacher_eval_episodes, run_name=f"{run_name}-teacher-eval", Model=TeacherModel, - epsilon=0.05, + epsilon=args.end_e, capture_video=False, ) writer.add_scalar("charts/teacher/avg_episodic_return", np.mean(teacher_episodic_returns), 0) @@ -342,7 +342,7 @@ def kl_divergence_with_logits(target_logits, prediction_logits): run_name=f"{run_name}-eval", Model=QNetwork, device=device, - epsilon=0.05, + epsilon=args.end_e, ) print(episodic_returns) writer.add_scalar("charts/offline/avg_episodic_return", np.mean(episodic_returns), global_step) @@ -459,7 +459,7 @@ def kl_divergence_with_logits(target_logits, prediction_logits): run_name=f"{run_name}-eval", Model=QNetwork, device=device, - epsilon=0.05, + epsilon=args.end_e, ) for idx, episodic_return in enumerate(episodic_returns): writer.add_scalar("eval/episodic_return", episodic_return, idx) diff --git a/cleanrl/qdagger_dqn_atari_jax_impalacnn.py b/cleanrl/qdagger_dqn_atari_jax_impalacnn.py index 7ecbb5c47..46c354263 100644 --- a/cleanrl/qdagger_dqn_atari_jax_impalacnn.py +++ b/cleanrl/qdagger_dqn_atari_jax_impalacnn.py @@ -262,7 +262,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int): eval_episodes=args.teacher_eval_episodes, run_name=f"{run_name}-teacher-eval", Model=TeacherModel, - epsilon=0.05, + epsilon=args.end_e, capture_video=False, ) writer.add_scalar("charts/teacher/avg_episodic_return", np.mean(teacher_episodic_returns), 0) @@ -361,7 +361,7 @@ def loss(params, td_target, teacher_q_values, distill_coeff): eval_episodes=10, run_name=f"{run_name}-eval", Model=QNetwork, - epsilon=0.05, + epsilon=args.end_e, ) print(episodic_returns) writer.add_scalar("charts/offline/avg_episodic_return", np.mean(episodic_returns), global_step) @@ -469,7 +469,7 @@ def loss(params, td_target, teacher_q_values, distill_coeff): eval_episodes=10, run_name=f"{run_name}-eval", Model=QNetwork, - epsilon=0.05, + epsilon=args.end_e, ) for idx, episodic_return in enumerate(episodic_returns): writer.add_scalar("eval/episodic_return", episodic_return, idx)