
[RLlib] Add support for multi-agent off-policy algorithms in the new API stack. #45182

Merged
41 commits
baa1398
wip
sven1977 Apr 29, 2024
a1eb1f9
wip
sven1977 Apr 29, 2024
6538b58
fixes
sven1977 Apr 29, 2024
683f515
Merge branch 'master' of https://github.com/ray-project/ray into chan…
sven1977 Apr 30, 2024
366a4b9
wip
sven1977 Apr 30, 2024
a8b2d0c
wip
sven1977 Apr 30, 2024
f76628a
merge
sven1977 May 3, 2024
81421d9
Fixed a bug with 'TERMINATEDS/TRUNCATEDS' in replay buffer sampling t…
simonsays1980 May 3, 2024
bd54d5a
LINTER.
simonsays1980 May 3, 2024
6ee006f
Added docs to new 'sample' method and removed old sample methods.
simonsays1980 May 6, 2024
a345d09
Merge branch 'master' into change_episode_buffers_to_return_episode_l…
simonsays1980 May 6, 2024
b77fd5a
Replaced 'td_error' by 'TD_ERROR_KEY'.
simonsays1980 May 6, 2024
6e11ff6
Needed to define 'TD_ERROR_KEY' in 'replay_buffer.utils' b/c import e…
simonsays1980 May 6, 2024
b39b9a8
Fixed a small bug in test code.
simonsays1980 May 7, 2024
e6cf4f7
Merge branch 'master' into change_episode_buffers_to_return_episode_l…
simonsays1980 May 7, 2024
eebc04d
Interchanged 'new_obs' with our constant 'Columns.NEXT_OBS' for bette…
simonsays1980 May 7, 2024
d12f16f
Added new sampling method in 'MultiAgentEpisodeReplayBuffer' for 'ind…
simonsays1980 May 7, 2024
2247c02
Changed 'truncated/terminated' logic in 'MultiEnv' and 'MultiAgentEpi…
simonsays1980 May 8, 2024
827adda
Switched back to 'pid'.
simonsays1980 May 10, 2024
1e67ccf
Commented out NaN metrics b/c they produced hundreds of warnings.
simonsays1980 May 10, 2024
c748df8
Changed comment.
simonsays1980 May 10, 2024
fc35faa
Little changes here and there to clean up sample logic and multi-…
simonsays1980 May 10, 2024
c336ac8
Added suggestions from @sven1977's review.
simonsays1980 May 10, 2024
6409007
Merge branch 'master' of https://github.com/ray-project/ray
simonsays1980 May 13, 2024
81c3893
Merged master
simonsays1980 May 13, 2024
c522597
Modified multi-agent buffer tests to correspond to the changes in '_s…
simonsays1980 May 13, 2024
b8fbe19
Changed 'MultiAgentEpisode' and 'MultiEnv' back to master.
simonsays1980 May 13, 2024
feafb6b
Apply suggestions from code review
sven1977 May 14, 2024
d2f9030
Merge branch 'master' of https://github.com/ray-project/ray
simonsays1980 May 14, 2024
2fd7717
Added slots to 'MultiAgentEpisode' which should help reducing memory …
simonsays1980 May 15, 2024
a3416a8
Merge branch 'master' of https://github.com/ray-project/ray
simonsays1980 May 15, 2024
2296cfc
Changed multi-agent SAC example such that at a minimum 2 agents are u…
simonsays1980 May 16, 2024
8582ad9
Merge branch 'master' of https://github.com/ray-project/ray
simonsays1980 May 16, 2024
c8d72fa
Merge branch 'master' into change_ma_buffer_to_use_list_of_episodes
simonsays1980 May 16, 2024
ffbf3de
Multiple performance tunings that bring the multi-agent buffer into d…
simonsays1980 May 16, 2024
47888a4
LINTER.
simonsays1980 May 16, 2024
7d6497e
Merge branch 'change_ma_buffer_to_use_list_of_episodes' of github.com…
simonsays1980 May 16, 2024
cccd48d
Merge branch 'master' of https://github.com/ray-project/ray into chan…
sven1977 May 17, 2024
e96b9ce
test BAZEL printout
sven1977 May 17, 2024
9d409dd
Commented out off-policy multi-agent examples that were not learning.
simonsays1980 May 17, 2024
41d0b18
Merge branch 'change_ma_buffer_to_use_list_of_episodes' of github.com…
simonsays1980 May 17, 2024
25 changes: 10 additions & 15 deletions rllib/algorithms/dqn/dqn.py
@@ -24,7 +24,7 @@
from ray.rllib.execution.rollout_ops import (
synchronous_parallel_sample,
)
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.execution.train_ops import (
train_one_step,
multi_gpu_train_one_step,
@@ -656,24 +656,20 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
self.learner_group.foreach_learner(lambda lrnr: lrnr._reset_noise())
# Run multiple sample-from-buffer and update iterations.
for _ in range(sample_and_train_weight):
# Sample training batch from replay_buffer.
# TODO (simon): Use sample_with_keys() here.
# Sample a list of episodes used for learning from the replay buffer.
with self.metrics.log_time((TIMERS, REPLAY_BUFFER_SAMPLE_TIMER)):
train_dict = self.local_replay_buffer.sample(
episodes = self.local_replay_buffer.sample(
num_items=self.config.train_batch_size,
n_step=self.config.n_step,
gamma=self.config.gamma,
beta=self.config.replay_buffer_config["beta"],
beta=self.config.replay_buffer_config.get("beta"),
)
train_batch = SampleBatch(train_dict)
# Convert to multi-agent batch as `LearnerGroup` depends on it.
# TODO (sven, simon): Remove this conversion once the `LearnerGroup`
# supports dict.
train_batch = train_batch.as_multi_agent()

# Perform an update on the buffer-sampled train batch.
with self.metrics.log_time((TIMERS, LEARNER_UPDATE_TIMER)):
learner_results = self.learner_group.update_from_batch(train_batch)
learner_results = self.learner_group.update_from_episodes(
episodes=episodes,
)
# Isolate TD-errors from result dicts (we should not log these to
# disk or WandB, they might be very large).
td_errors = defaultdict(list)
@@ -704,6 +700,7 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
},
reduce="sum",
)

# TODO (sven): Uncomment this once agent steps are available in the
# Learner stats.
# self.metrics.log_dict(self.metrics.peek(
@@ -713,10 +710,8 @@
# Update replay buffer priorities.
with self.metrics.log_time((TIMERS, REPLAY_BUFFER_UPDATE_PRIOS_TIMER)):
update_priorities_in_episode_replay_buffer(
self.local_replay_buffer,
self.config,
train_batch,
td_errors,
replay_buffer=self.local_replay_buffer,
td_errors=td_errors,
)

# Update the target networks, if necessary.
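For readers skimming the diff: the key change above is that the replay buffer now returns a list of episodes and the `LearnerGroup` consumes them directly, so the old `SampleBatch` / `as_multi_agent()` round-trip goes away. Below is a minimal sketch of that loop, assuming a local replay buffer, learner group, and config object as held by the Algorithm; the helper name and the omission of metrics/timers are ours, not part of the PR.

```python
def sample_and_update(replay_buffer, learner_group, config, num_iters=1):
    """Hypothetical helper mirroring the new episode-based update loop above."""
    results = []
    for _ in range(num_iters):
        # The buffer returns a list of episodes instead of a column dict, so no
        # SampleBatch/MultiAgentBatch conversion is needed anymore.
        episodes = replay_buffer.sample(
            num_items=config.train_batch_size,
            n_step=config.n_step,
            gamma=config.gamma,
            beta=config.replay_buffer_config.get("beta"),
        )
        # The LearnerGroup builds the train batch itself through its learner
        # connector pipeline (which now also adds NEXT_OBS, see below).
        results.append(learner_group.update_from_episodes(episodes=episodes))
    return results
```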
24 changes: 23 additions & 1 deletion rllib/algorithms/dqn/dqn_rainbow_learner.py
@@ -3,7 +3,16 @@
from typing import TYPE_CHECKING

from ray.rllib.core.learner.learner import Learner
from ray.rllib.utils.annotations import override
from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import (
AddObservationsFromEpisodesToBatch,
)
from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import ( # noqa
AddNextObservationsFromEpisodesToTrainBatch,
)
from ray.rllib.utils.annotations import (
override,
OverrideToImplementCustomLogic_CallToSuperRecommended,
)
from ray.rllib.utils.metrics import LAST_TARGET_UPDATE_TS, NUM_TARGET_UPDATES
from ray.rllib.utils.typing import ModuleID

@@ -28,6 +37,19 @@


class DQNRainbowLearner(Learner):
@OverrideToImplementCustomLogic_CallToSuperRecommended
@override(Learner)
def build(self) -> None:
super().build()
# Prepend a NEXT_OBS from episodes to train batch connector piece (right
# after the observation default piece).

if self.config.add_default_connectors_to_learner_pipeline:
self._learner_connector.insert_after(
AddObservationsFromEpisodesToBatch,
AddNextObservationsFromEpisodesToTrainBatch(),
)

@override(Learner)
def additional_update_for_module(
self, *, module_id: ModuleID, config: "DQNConfig", timestep: int, **kwargs
4 changes: 2 additions & 2 deletions rllib/algorithms/dqn/torch/dqn_rainbow_torch_learner.py
@@ -121,7 +121,7 @@ def compute_loss_for_module(
r_tau = torch.clamp(
batch[Columns.REWARDS].unsqueeze(dim=-1)
+ (
config.gamma ** batch["n_steps"]
config.gamma ** batch["n_step"]
* (1.0 - batch[Columns.TERMINATEDS].float())
).unsqueeze(dim=-1)
* z,
@@ -171,7 +171,7 @@ def compute_loss_for_module(
# backpropagate through the target network when optimizing the Q loss.
q_selected_target = (
batch[Columns.REWARDS]
+ (config.gamma ** batch["n_steps"]) * q_next_best_masked
+ (config.gamma ** batch["n_step"]) * q_next_best_masked
).detach()

# Choose the requested loss function. Note, in case of the Huber loss
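Side note on the `"n_steps"` → `"n_step"` rename: the episode buffer records the actual n used for each sampled transition (it can vary, e.g. near episode ends), and that per-item tensor is what gets exponentiated here. A toy illustration of the target these lines compute; all values and shapes below are made up for the example and are not from the PR.

```python
import torch

gamma = 0.99
rewards = torch.tensor([1.0, 0.5])       # n-step discounted reward sums
n_step = torch.tensor([3.0, 2.0])        # actual n per sampled transition
terminateds = torch.tensor([0.0, 1.0])   # 1.0 => no bootstrapping
q_next_best = torch.tensor([2.0, 4.0])   # target-net Q-value of best next action

# y = R^(n) + gamma**n * (1 - terminated) * max_a' Q_target(s_{t+n}, a')
q_selected_target = (
    rewards + gamma ** n_step * (1.0 - terminateds) * q_next_best
).detach()
print(q_selected_target)  # tensor([2.9406, 0.5000])
```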
1 change: 1 addition & 0 deletions rllib/algorithms/ppo/ppo_learner.py
@@ -62,6 +62,7 @@ def _update_from_batch_or_episodes(
# episodes).
if self.config.enable_env_runner_and_connector_v2:
batch, episodes = self._compute_gae_from_episodes(episodes=episodes)

# Now that GAE (advantages and value targets) have been added to the train
# batch, we can proceed normally (calling super method) with the update step.
return super()._update_from_batch_or_episodes(
1 change: 1 addition & 0 deletions rllib/algorithms/sac/sac.py
@@ -352,6 +352,7 @@ def validate(self) -> None:
] not in [
"EpisodeReplayBuffer",
"PrioritizedEpisodeReplayBuffer",
"MultiAgentEpisodeReplayBuffer",
]:
raise ValueError(
"When using the new `EnvRunner API` the replay buffer must be of type "
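With `MultiAgentEpisodeReplayBuffer` now whitelisted by `validate()`, a multi-agent SAC setup can select it via its replay-buffer config. A minimal, untuned sketch (environment, policy mapping, and the new-API-stack flags are omitted; the capacity value is arbitrary):

```python
from ray.rllib.algorithms.sac import SACConfig

config = SACConfig().training(
    replay_buffer_config={
        "type": "MultiAgentEpisodeReplayBuffer",
        "capacity": 100_000,
    },
)
```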
9 changes: 4 additions & 5 deletions rllib/algorithms/sac/torch/sac_torch_learner.py
@@ -18,7 +18,6 @@
TD_ERROR_KEY,
SACLearner,
)
from ray.rllib.core import DEFAULT_MODULE_ID
from ray.rllib.core.columns import Columns
from ray.rllib.core.learner.learner import (
POLICY_LOSS_KEY,
@@ -204,7 +203,7 @@ def compute_loss_for_module(
# Detach this node from the computation graph as we do not want to
# backpropagate through the target network when optimizing the Q loss.
q_selected_target = (
batch[Columns.REWARDS] + (config.gamma ** batch["n_steps"]) * q_next_masked
batch[Columns.REWARDS] + (config.gamma ** batch["n_step"]) * q_next_masked
).detach()

# Calculate the TD-error. Note, this is needed for the priority weights in
@@ -317,13 +316,13 @@ def compute_gradients(
for component in (
["qf", "policy", "alpha"] + ["qf_twin"] if config.twin_q else []
):
self.metrics.peek(DEFAULT_MODULE_ID, component + "_loss").backward(
self.metrics.peek(module_id, component + "_loss").backward(
retain_graph=True
)
grads.update(
{
pid: p.grad
for pid, p in self.filter_param_dict_for_optimizer(
mid: p.grad
for mid, p in self.filter_param_dict_for_optimizer(
self._params, self.get_optimizer(module_id, component)
).items()
}
2 changes: 1 addition & 1 deletion rllib/connectors/common/agent_to_module_mapping.py
@@ -134,7 +134,7 @@ def __call__(
**kwargs,
) -> Any:
# This Connector should only be used in a multi-agent setting.
assert not episodes or isinstance(episodes[0], MultiAgentEpisode)
# assert not episodes or isinstance(episodes[0], MultiAgentEpisode)

# Current agent to module mapping function.
# agent_to_module_mapping_fn = shared_data.get("agent_to_module_mapping_fn")
5 changes: 4 additions & 1 deletion rllib/connectors/common/batch_individual_items.py
@@ -33,7 +33,10 @@ def __call__(
# to a batch structure of:
# [module_id] -> [col0] -> [list of items]
if is_marl_module and column in rl_module:
assert is_multi_agent
# assert is_multi_agent
# TODO (simon, sven): Check, if we need for other cases this check.
Contributor comment:
This is a good point. There are still some "weird" assumptions left in some connectors' logic.
We should comb these out and make the logic when to go into what loop with SA- or MAEps more clear.

Some of this stuff has to do with the fact that EnvRunners can either have a SingleAgentRLModule OR a MultiAgentRLModule, but Learners always(!) have a MultiAgentModule. Maybe we should have Learners that operate on SingleAgentRLModules for simplicity and more transparency. It shouldn't be too hard to fix that on the Learner side.

# If MA Off-Policy and independent sampling we need to overcome
# this check.
module_data = column_data
for col, col_data in module_data.copy().items():
if isinstance(col_data, list) and col != Columns.INFOS:
4 changes: 4 additions & 0 deletions rllib/connectors/learner/__init__.py
@@ -10,12 +10,16 @@
from ray.rllib.connectors.learner.add_columns_from_episodes_to_train_batch import (
AddColumnsFromEpisodesToTrainBatch,
)
from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import ( # noqa
AddNextObservationsFromEpisodesToTrainBatch,
)
from ray.rllib.connectors.learner.learner_connector_pipeline import (
LearnerConnectorPipeline,
)

__all__ = [
"AddColumnsFromEpisodesToTrainBatch",
"AddNextObservationsFromEpisodesToTrainBatch",
"AddObservationsFromEpisodesToBatch",
"AddStatesFromEpisodesToBatch",
"AgentToModuleMapping",
116 changes: 116 additions & 0 deletions rllib/connectors/learner/add_next_observations_from_episodes_to_train_batch.py
@@ -0,0 +1,116 @@
from typing import Any, List, Optional

import gymnasium as gym

from ray.rllib.core.columns import Columns
from ray.rllib.connectors.connector_v2 import ConnectorV2
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import EpisodeType


class AddNextObservationsFromEpisodesToTrainBatch(ConnectorV2):
"""Adds the NEXT_OBS column with the correct episode observations to train batch.

- Operates on a list of Episode objects.
- Gets all observation(s) from all the given episodes (except the very first ones)
and adds them to the batch under construction in the NEXT_OBS column (as a list of
individual observations).
- Does NOT alter any observations (or other data) in the given episodes.
- Can be used in Learner connector pipelines.

.. testcode::

import gymnasium as gym
import numpy as np

from ray.rllib.connectors.learner import (
AddNextObservationsFromEpisodesToTrainBatch
)
from ray.rllib.core.columns import Columns
from ray.rllib.env.single_agent_episode import SingleAgentEpisode
from ray.rllib.utils.test_utils import check

# Create two dummy SingleAgentEpisodes, each containing 3 observations,
# 2 actions and 2 rewards (both episodes are length=2).
obs_space = gym.spaces.Box(-1.0, 1.0, (2,), np.float32)
act_space = gym.spaces.Discrete(2)

episodes = [SingleAgentEpisode(
observations=[obs_space.sample(), obs_space.sample(), obs_space.sample()],
actions=[act_space.sample(), act_space.sample()],
rewards=[1.0, 2.0],
len_lookback_buffer=0,
) for _ in range(2)]
eps_1_next_obses = episodes[0].get_observations([1, 2])
eps_2_next_obses = episodes[1].get_observations([1, 2])
print(f"1st Episode's next obses are {eps_1_next_obses}")
print(f"2nd Episode's next obses are {eps_2_next_obses}")

# Create an instance of this class.
connector = AddNextObservationsFromEpisodesToTrainBatch()

# Call the connector with the two created episodes.
# Note that this particular connector works without an RLModule, so we
# simplify here for the sake of this example.
output_data = connector(
rl_module=None,
data={},
episodes=episodes,
explore=True,
shared_data={},
)
# The output data should now contain the last observations of both episodes,
# in a "per-episode organized" fashion.
check(
output_data,
{
Columns.NEXT_OBS: {
(episodes[0].id_,): eps_1_next_obses,
(episodes[1].id_,): eps_2_next_obses,
},
},
)
"""

def __init__(
self,
input_observation_space: Optional[gym.Space] = None,
input_action_space: Optional[gym.Space] = None,
**kwargs,
):
"""Initializes a AddNextObservationsFromEpisodesToTrainBatch instance."""
super().__init__(
input_observation_space=input_observation_space,
input_action_space=input_action_space,
**kwargs,
)

@override(ConnectorV2)
def __call__(
self,
*,
rl_module: RLModule,
data: Optional[Any],
episodes: List[EpisodeType],
explore: Optional[bool] = None,
shared_data: Optional[dict] = None,
**kwargs,
) -> Any:
# If "obs" already in data, early out.
if Columns.NEXT_OBS in data:
return data

for sa_episode in self.single_agent_episode_iterator(
# This is a Learner-only connector -> Get all episodes (for train batch).
episodes,
agents_that_stepped_only=False,
):
self.add_n_batch_items(
data,
Columns.NEXT_OBS,
items_to_add=sa_episode.get_observations(slice(1, len(sa_episode) + 1)),
num_items=len(sa_episode),
single_agent_episode=sa_episode,
)
return data
2 changes: 1 addition & 1 deletion rllib/core/learner/learner.py
@@ -1287,7 +1287,7 @@ def _update_from_batch_or_episodes(
# Call the learner connector pipeline.
batch = self._learner_connector(
rl_module=self.module,
data=batch,
data=batch if batch is not None else {},
episodes=episodes,
shared_data={},
)