diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index 71f32c2c5..1262e0671 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -22,9 +22,10 @@ New Features:
 ^^^^^^^^^^^^^
 - Improved error message of the ``env_checker`` for env wrongly detected as GoalEnv (``compute_reward()`` is defined)
 - Improved error message when mixing Gym API with VecEnv API (see GH#1694)
+- Added capability to log on a training-iteration-based interval (instead of an episode-based one) in ``OffPolicyAlgorithm`` (@tobiabir)
 - Add support for setting ``options`` at reset with VecEnv via the ``set_options()`` method. Same as seeds logic, options are reset at the end of an episode (@ReHoss)
 - Added ``rollout_buffer_class`` and ``rollout_buffer_kwargs`` arguments to on-policy algorithms (A2C and PPO)
 
 
 Bug Fixes:
 ^^^^^^^^^^
@@ -1501,3 +1502,4 @@ And all the contributors:
 @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @ReHoss
 @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
 @lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger
+@tobiabir
diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py
index 5e8759990..f5b972bfe 100644
--- a/stable_baselines3/common/base_class.py
+++ b/stable_baselines3/common/base_class.py
@@ -523,7 +523,7 @@ def learn(
 
         :param total_timesteps: The total number of samples (env steps) to train on
         :param callback: callback(s) called at every step with state of the algorithm.
-        :param log_interval: The number of episodes before logging.
+        :param log_interval: The number of training iterations (rollout collection + agent updates) between logging events.
         :param tb_log_name: the name of the run for TensorBoard logging
         :param reset_num_timesteps: whether or not to reset the current timestep number (used in logging)
         :param progress_bar: Display a progress bar using tqdm and rich.
diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py
index c460d0236..4710562b1 100644
--- a/stable_baselines3/common/off_policy_algorithm.py
+++ b/stable_baselines3/common/off_policy_algorithm.py
@@ -311,6 +311,8 @@ def learn(
         reset_num_timesteps: bool = True,
         progress_bar: bool = False,
     ) -> SelfOffPolicyAlgorithm:
+        iteration = 0
+
         total_timesteps, callback = self._setup_learn(
             total_timesteps,
             callback,
@@ -332,7 +334,6 @@ def learn(
                 callback=callback,
                 learning_starts=self.learning_starts,
                 replay_buffer=self.replay_buffer,
-                log_interval=log_interval,
             )
             if not rollout.continue_training:
                 break
@@ -346,6 +347,11 @@ def learn(
                 if gradient_steps > 0:
                     self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)
 
+            iteration += 1
+
+            if log_interval is not None and iteration % log_interval == 0:
+                self._dump_logs()
+
         callback.on_training_end()
 
         return self
@@ -511,7 +517,6 @@ def collect_rollouts(
         replay_buffer: ReplayBuffer,
         action_noise: Optional[ActionNoise] = None,
         learning_starts: int = 0,
-        log_interval: Optional[int] = None,
     ) -> RolloutReturn:
         """
         Collect experiences and store them into a ``ReplayBuffer``.
@@ -529,7 +534,6 @@ def collect_rollouts(
             in addition to the stochastic policy for SAC.
         :param learning_starts: Number of steps before learning for the warm-up phase.
         :param replay_buffer:
-        :param log_interval: Log data every ``log_interval`` episodes
         :return:
         """
         # Switch to eval mode (this affects batch norm / dropout)
@@ -592,9 +596,6 @@ def collect_rollouts(
                         kwargs = dict(indices=[idx]) if env.num_envs > 1 else {}
                         action_noise.reset(**kwargs)
 
-                    # Log training infos
-                    if log_interval is not None and self._episode_num % log_interval == 0:
-                        self._dump_logs()
         callback.on_rollout_end()
 
         return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training)
diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py
index ddd0f8de2..e60d2e53c 100644
--- a/stable_baselines3/common/on_policy_algorithm.py
+++ b/stable_baselines3/common/on_policy_algorithm.py
@@ -279,9 +279,12 @@ def learn(
             if not continue_training:
                 break
 
-            iteration += 1
             self._update_current_progress_remaining(self.num_timesteps, total_timesteps)
 
+            self.train()
+
+            iteration += 1
+
             # Display training infos
             if log_interval is not None and iteration % log_interval == 0:
                 assert self.ep_info_buffer is not None
@@ -296,8 +299,6 @@ def learn(
                 self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
                 self.logger.dump(step=self.num_timesteps)
 
-            self.train()
-
         callback.on_training_end()
 
         return self
diff --git a/stable_baselines3/ddpg/ddpg.py b/stable_baselines3/ddpg/ddpg.py
index c311b2357..7c666e876 100644
--- a/stable_baselines3/ddpg/ddpg.py
+++ b/stable_baselines3/ddpg/ddpg.py
@@ -115,7 +115,7 @@ def learn(
         self: SelfDDPG,
         total_timesteps: int,
         callback: MaybeCallback = None,
-        log_interval: int = 4,
+        log_interval: int = 1000,
         tb_log_name: str = "DDPG",
         reset_num_timesteps: bool = True,
         progress_bar: bool = False,
diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py
index 42e3d0df0..d4ed512b1 100644
--- a/stable_baselines3/dqn/dqn.py
+++ b/stable_baselines3/dqn/dqn.py
@@ -259,7 +259,7 @@ def learn(
         self: SelfDQN,
         total_timesteps: int,
         callback: MaybeCallback = None,
-        log_interval: int = 4,
+        log_interval: int = 1000,
         tb_log_name: str = "DQN",
         reset_num_timesteps: bool = True,
         progress_bar: bool = False,
diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py
index bf0fa5028..ec3d63181 100644
--- a/stable_baselines3/sac/sac.py
+++ b/stable_baselines3/sac/sac.py
@@ -299,7 +299,7 @@ def learn(
         self: SelfSAC,
         total_timesteps: int,
         callback: MaybeCallback = None,
-        log_interval: int = 4,
+        log_interval: int = 1000,
         tb_log_name: str = "SAC",
         reset_num_timesteps: bool = True,
         progress_bar: bool = False,
diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py
index a06ce67e0..e15f18e60 100644
--- a/stable_baselines3/td3/td3.py
+++ b/stable_baselines3/td3/td3.py
@@ -214,7 +214,7 @@ def learn(
         self: SelfTD3,
         total_timesteps: int,
         callback: MaybeCallback = None,
-        log_interval: int = 4,
+        log_interval: int = 1000,
         tb_log_name: str = "TD3",
         reset_num_timesteps: bool = True,
         progress_bar: bool = False,
diff --git a/tests/test_run.py b/tests/test_run.py
index 31c7b956e..8bc43dce1 100644
--- a/tests/test_run.py
+++ b/tests/test_run.py
@@ -229,7 +229,7 @@ def test_ppo_warnings():
     # in that case
     with pytest.warns(UserWarning, match="there will be a truncated mini-batch of size 1"):
         model = PPO("MlpPolicy", "Pendulum-v1", n_steps=64, batch_size=63, verbose=1)
-        model.learn(64)
+        model.learn(64, log_interval=2)
     loss = model.logger.name_to_value["train/loss"]
     assert loss > 0
 
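
Usage note: a minimal sketch of the new logging behaviour, assuming this patch is applied on top of stable-baselines3 (the environment and hyperparameters below are illustrative, not part of the patch). With this change, ``log_interval`` for off-policy algorithms counts training iterations (one ``collect_rollouts()`` call plus the associated gradient updates) instead of completed episodes, so SAC's new default of ``log_interval=1000`` dumps logs roughly every 1000 environment steps with a single environment and ``train_freq=1``.

from stable_baselines3 import SAC

# Illustrative only: with this patch applied, logs are dumped every
# `log_interval` training iterations (rollout collection + gradient updates),
# not every `log_interval` completed episodes as before.
model = SAC("MlpPolicy", "Pendulum-v1", train_freq=1, verbose=1)
model.learn(total_timesteps=10_000, log_interval=1000)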