Add LogEveryNTimesteps callback #2083

Merged · 3 commits · Feb 14, 2025
28 changes: 26 additions & 2 deletions docs/guide/callbacks.rst
@@ -143,6 +143,7 @@ Stable Baselines provides you with a set of common callbacks for:
- evaluating the model periodically and saving the best one (:ref:`EvalCallback`)
- chaining callbacks (:ref:`CallbackList`)
- triggering callback on events (:ref:`EventCallback`, :ref:`EveryNTimesteps`)
- logging data every N timesteps (:ref:`LogEveryNTimesteps`)
- stopping the training early based on a reward threshold (:ref:`StopTrainingOnRewardThreshold <StopTrainingCallback>`)


@@ -313,7 +314,7 @@ An :ref:`EventCallback` that will trigger its child callback every ``n_steps`` timesteps.

.. note::

-Because of the way ``PPO1`` and ``TRPO`` work (they rely on MPI), ``n_steps`` is a lower bound between two events.
+Because of the way ``VecEnv`` works, ``n_steps`` is a lower bound between two events when using multiple environments.


.. code-block:: python
@@ -330,7 +331,30 @@ An :ref:`EventCallback` that will trigger its child callback every ``n_steps`` timesteps.

model = PPO("MlpPolicy", "Pendulum-v1", verbose=1)

-    model.learn(int(2e4), callback=event_callback)
+    model.learn(20_000, callback=event_callback)

.. _LogEveryNTimesteps:

LogEveryNTimesteps
^^^^^^^^^^^^^^^^^^

A callback derived from :ref:`EveryNTimesteps` that will dump the logged data every ``n_steps`` timesteps.


.. code-block:: python

import gymnasium as gym

from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import LogEveryNTimesteps

event_callback = LogEveryNTimesteps(n_steps=1_000)

model = PPO("MlpPolicy", "Pendulum-v1", verbose=1)

# Disable auto-logging by passing `log_interval=None`
model.learn(10_000, callback=event_callback, log_interval=None)



.. _StopTrainingOnMaxEpisodes:
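The "lower bound" behavior mentioned in the updated note can be illustrated without SB3: with a vectorized environment, ``num_timesteps`` advances by ``num_envs`` per step, so an ``EveryNTimesteps``-style trigger can only fire at multiples of ``num_envs``, i.e. at or just past each ``n_steps`` boundary. This is a minimal sketch (not SB3 code; ``trigger_points`` is a hypothetical helper), assuming the real callback fires when ``num_timesteps - last_trigger >= n_steps``:

```python
# Sketch: why n_steps is a lower bound between two events with multiple envs.
# num_timesteps advances by num_envs per vectorized step, so the trigger
# fires at the first reachable timestep at or past each n_steps boundary.

def trigger_points(n_steps: int, num_envs: int, total: int) -> list[int]:
    """Timesteps at which an EveryNTimesteps-style callback would fire."""
    fired = []
    last_trigger = 0
    num_timesteps = 0
    while num_timesteps < total:
        num_timesteps += num_envs  # one step of the vectorized env
        if (num_timesteps - last_trigger) >= n_steps:
            fired.append(num_timesteps)
            last_trigger = num_timesteps
    return fired

# Single env: fires exactly every 500 timesteps.
print(trigger_points(500, num_envs=1, total=2100))  # [500, 1000, 1500, 2000]
# 3 envs: fires at the first multiple of 3 past each boundary.
print(trigger_points(500, num_envs=3, total=2100))  # [501, 1002, 1503, 2004]
```

With ``num_envs=3`` the events land at 501, 1002, ... rather than exact multiples of 500, which is why the doc describes ``n_steps`` as a lower bound.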
7 changes: 5 additions & 2 deletions docs/misc/changelog.rst
@@ -3,7 +3,7 @@
Changelog
==========

-Release 2.6.0a0 (WIP)
+Release 2.6.0a1 (WIP)
--------------------------


@@ -13,6 +13,7 @@ Breaking Changes:
New Features:
^^^^^^^^^^^^^
- Added ``has_attr`` method for ``VecEnv`` to check if an attribute exists
- Added ``LogEveryNTimesteps`` callback to dump logs every N timesteps (note: you need to pass ``log_interval=None`` to avoid any interference)

Bug Fixes:
^^^^^^^^^^
@@ -29,15 +30,17 @@ Bug Fixes:

Deprecations:
^^^^^^^^^^^^^
- ``algo._dump_logs()`` is deprecated in favor of ``algo.dump_logs()`` and will be removed in SB3 v2.7.0

Others:
^^^^^^^
- Updated black from v24 to v25
- Improved error messages when checking Box space equality (loading ``VecNormalize``)

Documentation:
^^^^^^^^^^^^^^
- Clarify the use of Gym wrappers with ``make_vec_env`` in the section on Vectorized Environments (@pstahlhofen)

- Updated callback doc for ``EveryNTimesteps``

Release 2.5.0 (2025-01-27)
--------------------------
10 changes: 10 additions & 0 deletions stable_baselines3/common/base_class.py
@@ -865,3 +865,13 @@ def save(
params_to_save = self.get_parameters()

save_to_zip_file(path, data=data, params=params_to_save, pytorch_variables=pytorch_variables)

def dump_logs(self) -> None:
"""
Write log data. (Implemented by OffPolicyAlgorithm and OnPolicyAlgorithm)
"""
raise NotImplementedError()

def _dump_logs(self, *args) -> None:
warnings.warn("algo._dump_logs() is deprecated in favor of algo.dump_logs(). It will be removed in SB3 v2.7.0")
self.dump_logs(*args)
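The ``_dump_logs`` shim above follows the standard deprecate-and-forward pattern: the old private name warns, then delegates to the new public method. A standalone sketch of the same pattern (class and attribute names hypothetical, not SB3 code):

```python
import warnings

# Sketch of the deprecate-and-forward pattern used in the diff above:
# the old private name emits a warning, then calls the new public method.

class Algo:
    def __init__(self) -> None:
        self.dumped = 0  # counts calls, standing in for real log output

    def dump_logs(self) -> None:
        # New public API: write out accumulated log data.
        self.dumped += 1

    def _dump_logs(self, *args) -> None:
        # Old name kept as a thin shim so existing user code keeps working.
        warnings.warn(
            "algo._dump_logs() is deprecated in favor of algo.dump_logs(). "
            "It will be removed in SB3 v2.7.0",
            stacklevel=2,
        )
        self.dump_logs(*args)

algo = Algo()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    algo._dump_logs()  # warns, then forwards to dump_logs()
assert algo.dumped == 1 and "deprecated" in str(caught[0].message)
```

Keeping the shim for one release cycle gives downstream callers of ``_dump_logs()`` a clear migration window before the v2.7.0 removal.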
15 changes: 15 additions & 0 deletions stable_baselines3/common/callbacks.py
@@ -591,6 +591,21 @@ def _on_step(self) -> bool:
return True


class LogEveryNTimesteps(EveryNTimesteps):
"""
Log data every ``n_steps`` timesteps

:param n_steps: Number of timesteps between two triggers.
"""

def __init__(self, n_steps: int):
super().__init__(n_steps, callback=ConvertCallback(self._log_data))

def _log_data(self, _locals: dict[str, Any], _globals: dict[str, Any]) -> bool:
self.model.dump_logs()
return True


class StopTrainingOnMaxEpisodes(BaseCallback):
"""
Stop the training once a maximum number of episodes are played.
6 changes: 3 additions & 3 deletions stable_baselines3/common/off_policy_algorithm.py
@@ -406,9 +406,9 @@ def _sample_action(
action = buffer_action
return action, buffer_action

-    def _dump_logs(self) -> None:
+    def dump_logs(self) -> None:
        """
-        Write log.
+        Write log data.
"""
assert self.ep_info_buffer is not None
assert self.ep_success_buffer is not None
@@ -594,7 +594,7 @@ def collect_rollouts(

# Log training infos
if log_interval is not None and self._episode_num % log_interval == 0:
-                self._dump_logs()
+                self.dump_logs()
callback.on_rollout_end()

return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training)
7 changes: 4 additions & 3 deletions stable_baselines3/common/on_policy_algorithm.py
@@ -274,7 +274,7 @@ def train(self) -> None:
"""
raise NotImplementedError

-    def _dump_logs(self, iteration: int) -> None:
+    def dump_logs(self, iteration: int = 0) -> None:
"""
Write log.

@@ -285,7 +285,8 @@ def _dump_logs(self, iteration: int) -> None:

time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
-        self.logger.record("time/iterations", iteration, exclude="tensorboard")
+        if iteration > 0:
+            self.logger.record("time/iterations", iteration, exclude="tensorboard")
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
@@ -331,7 +332,7 @@ def learn(
# Display training infos
if log_interval is not None and iteration % log_interval == 0:
assert self.ep_info_buffer is not None
-                self._dump_logs(iteration)
+                self.dump_logs(iteration)

self.train()

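The new ``iteration > 0`` guard exists because ``LogEveryNTimesteps`` calls ``dump_logs()`` from a callback, where no iteration counter is available; the default of ``0`` then suppresses the ``time/iterations`` key instead of recording a misleading value. A minimal sketch of this guard (``records``/``record`` stand in for the real SB3 logger, which this is not):

```python
# Sketch of the default-argument guard: when dump_logs() is called without
# an iteration (the LogEveryNTimesteps path), "time/iterations" is skipped.

records: dict[str, float] = {}

def record(key: str, value: float) -> None:
    records[key] = value  # stand-in for self.logger.record(...)

def dump_logs(iteration: int = 0) -> None:
    if iteration > 0:
        record("time/iterations", iteration)
    record("rollout/ep_rew_mean", -123.4)  # placeholder metric

dump_logs()  # callback path: no iteration key recorded
assert "time/iterations" not in records
dump_logs(iteration=5)  # training-loop path: iteration is recorded
assert records["time/iterations"] == 5
```

This keeps the training-loop output unchanged while letting the callback dump logs at arbitrary timesteps.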
8 changes: 5 additions & 3 deletions stable_baselines3/common/utils.py
@@ -244,12 +244,14 @@ def check_shape_equal(space1: spaces.Space, space2: spaces.Space) -> None:
:param space2: Other space
"""
if isinstance(space1, spaces.Dict):
-        assert isinstance(space2, spaces.Dict), "spaces must be of the same type"
-        assert space1.spaces.keys() == space2.spaces.keys(), "spaces must have the same keys"
+        assert isinstance(space2, spaces.Dict), f"spaces must be of the same type: {type(space1)} != {type(space2)}"
+        assert (
+            space1.spaces.keys() == space2.spaces.keys()
+        ), f"spaces must have the same keys: {list(space1.spaces.keys())} != {list(space2.spaces.keys())}"
for key in space1.spaces.keys():
check_shape_equal(space1.spaces[key], space2.spaces[key])
elif isinstance(space1, spaces.Box):
-        assert space1.shape == space2.shape, "spaces must have the same shape"
+        assert space1.shape == space2.shape, f"spaces must have the same shape: {space1.shape} != {space2.shape}"


def is_vectorized_box_observation(observation: np.ndarray, observation_space: spaces.Box) -> bool:
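The ``check_shape_equal`` change only upgrades the assertion messages to show both mismatching values. A self-contained sketch of the same idea (plain dicts and shape tuples stand in for ``gymnasium`` ``spaces.Dict`` / ``spaces.Box``, so this is not the real function):

```python
# Sketch of the improved error messages: plain dicts/tuples stand in for
# gymnasium spaces; the real function recurses over spaces.Dict and
# compares spaces.Box shapes.

def check_shape_equal(shape1, shape2) -> None:
    if isinstance(shape1, dict):
        assert isinstance(shape2, dict), f"spaces must be of the same type: {type(shape1)} != {type(shape2)}"
        assert shape1.keys() == shape2.keys(), (
            f"spaces must have the same keys: {list(shape1)} != {list(shape2)}"
        )
        for key in shape1:
            check_shape_equal(shape1[key], shape2[key])
    else:
        # Box case: compare shapes directly and show both in the message.
        assert shape1 == shape2, f"spaces must have the same shape: {shape1} != {shape2}"

try:
    check_shape_equal({"obs": (4,)}, {"obs": (6,)})
except AssertionError as err:
    print(err)  # spaces must have the same shape: (4,) != (6,)
```

Including both shapes in the message makes ``VecNormalize`` loading failures much easier to diagnose than the bare "spaces must have the same shape".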
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
@@ -1 +1 @@
-2.6.0a0
+2.6.0a1
4 changes: 3 additions & 1 deletion tests/test_callbacks.py
@@ -13,6 +13,7 @@
CheckpointCallback,
EvalCallback,
EveryNTimesteps,
LogEveryNTimesteps,
StopTrainingOnMaxEpisodes,
StopTrainingOnNoModelImprovement,
StopTrainingOnRewardThreshold,
@@ -62,11 +63,12 @@ def test_callbacks(tmp_path, model_class):
checkpoint_on_event = CheckpointCallback(save_freq=1, save_path=log_folder, name_prefix="event")

event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event)
log_callback = LogEveryNTimesteps(n_steps=250)

# Stop training if max number of episodes is reached
callback_max_episodes = StopTrainingOnMaxEpisodes(max_episodes=100, verbose=1)

-    callback = CallbackList([checkpoint_callback, eval_callback, event_callback, callback_max_episodes])
+    callback = CallbackList([checkpoint_callback, eval_callback, event_callback, log_callback, callback_max_episodes])
model.learn(500, callback=callback)

# Check access to local variables