diff --git a/docs/guide/callbacks.rst b/docs/guide/callbacks.rst index 472f42114..fa2bd78b3 100644 --- a/docs/guide/callbacks.rst +++ b/docs/guide/callbacks.rst @@ -143,6 +143,7 @@ Stable Baselines provides you with a set of common callbacks for: - evaluating the model periodically and saving the best one (:ref:`EvalCallback`) - chaining callbacks (:ref:`CallbackList`) - triggering callback on events (:ref:`EventCallback`, :ref:`EveryNTimesteps`) +- logging data every N timesteps (:ref:`LogEveryNTimesteps`) - stopping the training early based on a reward threshold (:ref:`StopTrainingOnRewardThreshold `) @@ -313,7 +314,7 @@ An :ref:`EventCallback` that will trigger its child callback every ``n_steps`` t .. note:: - Because of the way ``PPO1`` and ``TRPO`` work (they rely on MPI), ``n_steps`` is a lower bound between two events. + Because of the way ``VecEnv`` works, ``n_steps`` is a lower bound between two events when using multiple environments. .. code-block:: python @@ -330,7 +331,30 @@ An :ref:`EventCallback` that will trigger its child callback every ``n_steps`` t model = PPO("MlpPolicy", "Pendulum-v1", verbose=1) - model.learn(int(2e4), callback=event_callback) + model.learn(20_000, callback=event_callback) + +.. _LogEveryNTimesteps: + +LogEveryNTimesteps +^^^^^^^^^^^^^^^^^^ + +A callback derived from :ref:`EveryNTimesteps` that will dump the logged data every ``n_steps`` timesteps. + + +.. code-block:: python + + import gymnasium as gym + + from stable_baselines3 import PPO + from stable_baselines3.common.callbacks import LogEveryNTimesteps + + event_callback = LogEveryNTimesteps(n_steps=1_000) + + model = PPO("MlpPolicy", "Pendulum-v1", verbose=1) + + # Disable auto-logging by passing `log_interval=None` + model.learn(10_000, callback=event_callback, log_interval=None) + ..
_StopTrainingOnMaxEpisodes: diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 0a7b21db5..cd2bc23f7 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.6.0a0 (WIP) +Release 2.6.0a1 (WIP) -------------------------- @@ -13,6 +13,7 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ - Added ``has_attr`` method for ``VecEnv`` to check if an attribute exists +- Added ``LogEveryNTimesteps`` callback to dump logs every N timesteps (note: you need to pass ``log_interval=None`` to avoid any interference) Bug Fixes: ^^^^^^^^^^ @@ -29,15 +30,17 @@ Bug Fixes: Deprecations: ^^^^^^^^^^^^^ +- ``algo._dump_logs()`` is deprecated in favor of ``algo.dump_logs()`` and will be removed in SB3 v2.7.0 Others: ^^^^^^^ - Updated black from v24 to v25 +- Improved error messages when checking Box space equality (loading ``VecNormalize``) Documentation: ^^^^^^^^^^^^^^ - Clarify the use of Gym wrappers with ``make_vec_env`` in the section on Vectorized Environments (@pstahlhofen) - +- Updated callback doc for ``EveryNTimesteps`` Release 2.5.0 (2025-01-27) -------------------------- diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 412f9dda2..434898bf0 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -865,3 +865,13 @@ def save( params_to_save = self.get_parameters() save_to_zip_file(path, data=data, params=params_to_save, pytorch_variables=pytorch_variables) + + def dump_logs(self) -> None: + """ + Write log data. (Implemented by OffPolicyAlgorithm and OnPolicyAlgorithm) + """ + raise NotImplementedError() + + def _dump_logs(self, *args) -> None: + warnings.warn("algo._dump_logs() is deprecated in favor of algo.dump_logs(). 
It will be removed in SB3 v2.7.0") + self.dump_logs(*args) diff --git a/stable_baselines3/common/callbacks.py b/stable_baselines3/common/callbacks.py index 0e7387911..3beb17b72 100644 --- a/stable_baselines3/common/callbacks.py +++ b/stable_baselines3/common/callbacks.py @@ -591,6 +591,21 @@ def _on_step(self) -> bool: return True +class LogEveryNTimesteps(EveryNTimesteps): + """ + Log data every ``n_steps`` timesteps + + :param n_steps: Number of timesteps between two triggers. + """ + + def __init__(self, n_steps: int): + super().__init__(n_steps, callback=ConvertCallback(self._log_data)) + + def _log_data(self, _locals: dict[str, Any], _globals: dict[str, Any]) -> bool: + self.model.dump_logs() + return True + + class StopTrainingOnMaxEpisodes(BaseCallback): """ Stop the training once a maximum number of episodes are played. diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index c3e1c6662..b778480d4 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -406,9 +406,9 @@ def _sample_action( action = buffer_action return action, buffer_action - def _dump_logs(self) -> None: + def dump_logs(self) -> None: """ - Write log. + Write log data.
""" assert self.ep_info_buffer is not None assert self.ep_success_buffer is not None @@ -594,7 +594,7 @@ def collect_rollouts( # Log training infos if log_interval is not None and self._episode_num % log_interval == 0: - self._dump_logs() + self.dump_logs() callback.on_rollout_end() return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training) diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index ac4c0970c..0db5ce5d5 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -274,7 +274,7 @@ def train(self) -> None: """ raise NotImplementedError - def _dump_logs(self, iteration: int) -> None: + def dump_logs(self, iteration: int = 0) -> None: """ Write log. @@ -285,7 +285,8 @@ def _dump_logs(self, iteration: int) -> None: time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon) fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed) - self.logger.record("time/iterations", iteration, exclude="tensorboard") + if iteration > 0: + self.logger.record("time/iterations", iteration, exclude="tensorboard") if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0: self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer])) self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer])) @@ -331,7 +332,7 @@ def learn( # Display training infos if log_interval is not None and iteration % log_interval == 0: assert self.ep_info_buffer is not None - self._dump_logs(iteration) + self.dump_logs(iteration) self.train() diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index 7caef0501..3790fe356 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -244,12 +244,14 @@ def check_shape_equal(space1: spaces.Space, space2: 
spaces.Space) -> None: :param space2: Other space """ if isinstance(space1, spaces.Dict): - assert isinstance(space2, spaces.Dict), "spaces must be of the same type" - assert space1.spaces.keys() == space2.spaces.keys(), "spaces must have the same keys" + assert isinstance(space2, spaces.Dict), f"spaces must be of the same type: {type(space1)} != {type(space2)}" + assert ( + space1.spaces.keys() == space2.spaces.keys() + ), f"spaces must have the same keys: {list(space1.spaces.keys())} != {list(space2.spaces.keys())}" for key in space1.spaces.keys(): check_shape_equal(space1.spaces[key], space2.spaces[key]) elif isinstance(space1, spaces.Box): - assert space1.shape == space2.shape, "spaces must have the same shape" + assert space1.shape == space2.shape, f"spaces must have the same shape: {space1.shape} != {space2.shape}" def is_vectorized_box_observation(observation: np.ndarray, observation_space: spaces.Box) -> bool: diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 3d87ca93f..5809eab2a 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.6.0a0 +2.6.0a1 diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index ffc37320f..81d84acbc 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -13,6 +13,7 @@ CheckpointCallback, EvalCallback, EveryNTimesteps, + LogEveryNTimesteps, StopTrainingOnMaxEpisodes, StopTrainingOnNoModelImprovement, StopTrainingOnRewardThreshold, @@ -62,11 +63,12 @@ def test_callbacks(tmp_path, model_class): checkpoint_on_event = CheckpointCallback(save_freq=1, save_path=log_folder, name_prefix="event") event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event) + log_callback = LogEveryNTimesteps(n_steps=250) # Stop training if max number of episodes is reached callback_max_episodes = StopTrainingOnMaxEpisodes(max_episodes=100, verbose=1) - callback = CallbackList([checkpoint_callback, eval_callback, event_callback, 
callback_max_episodes]) + callback = CallbackList([checkpoint_callback, eval_callback, event_callback, log_callback, callback_max_episodes]) model.learn(500, callback=callback) # Check access to local variables