[RLlib] Add support for multi-agent off-policy algorithms in the new API stack. (#45182)
simonsays1980 authored May 24, 2024
1 parent 4851df7 commit 7fb0ce1
Showing 18 changed files with 299 additions and 208 deletions.
1 change: 1 addition & 0 deletions .bazelrc
@@ -167,6 +167,7 @@ test:ci --flaky_test_attempts=3
test:ci --nocache_test_results
test:ci --spawn_strategy=local
test:ci --test_output=errors
test:ci --experimental_ui_max_stdouterr_bytes=-1
test:ci --test_verbose_timeout_warnings
test:ci-debug -c dbg
test:ci-debug --copt="-g"
19 changes: 19 additions & 0 deletions rllib/BUILD
@@ -463,6 +463,25 @@ py_test(
args = ["--dir=tuned_examples/sac"]
)

# TODO (simon): These tests are not learning, yet.
# py_test(
# name = "learning_tests_multi_agent_pendulum_sac",
# main = "tuned_examples/sac/multi_agent_pendulum_sac.py",
# tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_pendulum", "learning_tests_continuous"],
# size = "large",
# srcs = ["tuned_examples/sac/multi_agent_pendulum_sac.py"],
# args = ["--enable-new-api-stack", "--num-agents=2"]
# )

# py_test(
# name = "learning_tests_multi_agent_pendulum_sac_multi_gpu",
# main = "tuned_examples/sac/multi_agent_pendulum_sac.py",
# tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_pendulum", "learning_tests_continuous", "multi_gpu"],
# size = "large",
# srcs = ["tuned_examples/sac/multi_agent_pendulum_sac.py"],
# args = ["--enable-new-api-stack", "--num-agents=2", "--num-gpus=2"]
# )

# --------------------------------------------------------------------
# Algorithms (Compilation, Losses, simple functionality tests)
# rllib/algorithms/
5 changes: 3 additions & 2 deletions rllib/algorithms/dqn/dqn.py
@@ -16,7 +16,6 @@

from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
from ray.rllib.algorithms.dqn.dqn_rainbow_learner import TD_ERROR_KEY
from ray.rllib.algorithms.dqn.dqn_tf_policy import DQNTFPolicy
from ray.rllib.algorithms.dqn.dqn_torch_policy import DQNTorchPolicy
from ray.rllib.core.learner import Learner
@@ -64,6 +63,7 @@
REPLAY_BUFFER_UPDATE_PRIOS_TIMER,
SAMPLE_TIMER,
SYNCH_WORKER_WEIGHTS_TIMER,
TD_ERROR_KEY,
TIMERS,
)
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
@@ -662,7 +662,7 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
num_items=self.config.train_batch_size,
n_step=self.config.n_step,
gamma=self.config.gamma,
beta=self.config.replay_buffer_config["beta"],
beta=self.config.replay_buffer_config.get("beta"),
)

# Perform an update on the buffer-sampled train batch.
@@ -700,6 +700,7 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
},
reduce="sum",
)

# TODO (sven): Uncomment this once agent steps are available in the
# Learner stats.
# self.metrics.log_dict(self.metrics.peek(
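Changing replay_buffer_config["beta"] to replay_buffer_config.get("beta") in the sample() call above lets the same code path serve buffers whose config defines no prioritization beta, such as the non-prioritized MultiAgentEpisodeReplayBuffer used in the tuned example added further down. A minimal sketch of the difference (illustrative only, not RLlib code):

# dict.get() returns None when "beta" is absent, instead of raising a KeyError.
replay_buffer_config = {"type": "MultiAgentEpisodeReplayBuffer", "capacity": 100000}
beta = replay_buffer_config.get("beta")  # -> None for non-prioritized buffers
assert beta is None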
6 changes: 4 additions & 2 deletions rllib/algorithms/dqn/dqn_rainbow_learner.py
@@ -13,7 +13,10 @@
override,
OverrideToImplementCustomLogic_CallToSuperRecommended,
)
from ray.rllib.utils.metrics import LAST_TARGET_UPDATE_TS, NUM_TARGET_UPDATES
from ray.rllib.utils.metrics import (
LAST_TARGET_UPDATE_TS,
NUM_TARGET_UPDATES,
)
from ray.rllib.utils.typing import ModuleID

if TYPE_CHECKING:
@@ -32,7 +35,6 @@
QF_TARGET_NEXT_PROBS = "qf_target_next_probs"
QF_PREDS = "qf_preds"
QF_PROBS = "qf_probs"
TD_ERROR_KEY = "td_error"
TD_ERROR_MEAN_KEY = "td_error_mean"


2 changes: 1 addition & 1 deletion rllib/algorithms/dqn/torch/dqn_rainbow_torch_learner.py
@@ -14,13 +14,13 @@
QF_TARGET_NEXT_PROBS,
QF_PREDS,
QF_PROBS,
TD_ERROR_KEY,
TD_ERROR_MEAN_KEY,
)
from ray.rllib.core.columns import Columns
from ray.rllib.core.learner.torch.torch_learner import TorchLearner
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.metrics import TD_ERROR_KEY
from ray.rllib.utils.nested_dict import NestedDict
from ray.rllib.utils.typing import ModuleID, TensorType

1 change: 1 addition & 0 deletions rllib/algorithms/sac/sac.py
@@ -352,6 +352,7 @@ def validate(self) -> None:
] not in [
"EpisodeReplayBuffer",
"PrioritizedEpisodeReplayBuffer",
"MultiAgentEpisodeReplayBuffer",
]:
raise ValueError(
"When using the new `EnvRunner API` the replay buffer must be of type "
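With MultiAgentEpisodeReplayBuffer now accepted by validate() on the new EnvRunner API stack, a multi-agent SAC config can select it explicitly. A minimal sketch, condensed from the tuned example added later in this commit:

from ray.rllib.algorithms.sac import SACConfig

config = (
    SACConfig()
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .training(
        replay_buffer_config={
            "type": "MultiAgentEpisodeReplayBuffer",
            "capacity": 100000,
        },
    )
)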
1 change: 0 additions & 1 deletion rllib/algorithms/sac/sac_learner.py
@@ -20,7 +20,6 @@
QF_TWIN_LOSS_KEY = "qf_twin_loss"
QF_TWIN_PREDS = "qf_twin_preds"
TD_ERROR_MEAN_KEY = "td_error_mean"
TD_ERROR_KEY = "td_error"


class SACLearner(DQNRainbowLearner):
9 changes: 3 additions & 6 deletions rllib/algorithms/sac/torch/sac_torch_learner.py
@@ -15,17 +15,15 @@
QF_TWIN_LOSS_KEY,
QF_TWIN_PREDS,
TD_ERROR_MEAN_KEY,
TD_ERROR_KEY,
SACLearner,
)
from ray.rllib.core import DEFAULT_MODULE_ID
from ray.rllib.core.columns import Columns
from ray.rllib.core.learner.learner import (
POLICY_LOSS_KEY,
)
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.metrics import ALL_MODULES
from ray.rllib.utils.metrics import ALL_MODULES, TD_ERROR_KEY
from ray.rllib.utils.nested_dict import NestedDict
from ray.rllib.utils.typing import ModuleID, ParamDict, TensorType

@@ -221,8 +219,6 @@ def compute_loss_for_module(
# Note further, we use here the Huber loss instead of the mean squared error
# as it improves training performance.
critic_loss = torch.mean(
# TODO (simon): Introduce priority weights when episode buffer is ready.
# batch[PRIO_WEIGHTS] *
batch["weights"]
* torch.nn.HuberLoss(reduction="none", delta=1.0)(
q_selected, q_selected_target
@@ -303,6 +299,7 @@ def compute_loss_for_module(
def compute_gradients(
self, loss_per_module: Dict[str, TensorType], **kwargs
) -> ParamDict:
# Set all grads to `None`.
for optim in self._optimizer_parameters:
optim.zero_grad(set_to_none=True)

@@ -317,7 +314,7 @@ def compute_gradients(
for component in (
["qf", "policy", "alpha"] + ["qf_twin"] if config.twin_q else []
):
self.metrics.peek(DEFAULT_MODULE_ID, component + "_loss").backward(
self.metrics.peek(module_id, component + "_loss").backward(
retain_graph=True
)
grads.update(
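The critic loss above multiplies a per-sample Huber loss by the importance weights delivered in batch["weights"] (the commented-out PRIO_WEIGHTS placeholder is dropped in this commit). A standalone sketch of that weighting pattern, with made-up toy tensors purely for illustration:

import torch

# Hypothetical toy values; in RLlib these come from the sampled train batch.
q_selected = torch.tensor([1.0, 2.0, 3.0])
q_selected_target = torch.tensor([1.5, 1.0, 3.2])
weights = torch.tensor([1.0, 0.5, 2.0])  # per-sample importance weights

# Element-wise Huber loss, then a weighted mean - mirrors the loss term above.
critic_loss = torch.mean(
    weights
    * torch.nn.HuberLoss(reduction="none", delta=1.0)(q_selected, q_selected_target)
)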
3 changes: 0 additions & 3 deletions rllib/connectors/common/agent_to_module_mapping.py
@@ -133,9 +133,6 @@ def __call__(
shared_data: Optional[dict] = None,
**kwargs,
) -> Any:
# This Connector should only be used in a multi-agent setting.
assert not episodes or isinstance(episodes[0], MultiAgentEpisode)

# Current agent to module mapping function.
# agent_to_module_mapping_fn = shared_data.get("agent_to_module_mapping_fn")
# Store in shared data, which module IDs map to which episode/agent, such
5 changes: 4 additions & 1 deletion rllib/connectors/common/batch_individual_items.py
@@ -33,7 +33,10 @@ def __call__(
# to a batch structure of:
# [module_id] -> [col0] -> [list of items]
if is_marl_module and column in rl_module:
assert is_multi_agent
# assert is_multi_agent
# TODO (simon, sven): Check, if we need for other cases this check.
# If MA Off-Policy and independent sampling we need to overcome
# this check.
module_data = column_data
for col, col_data in module_data.copy().items():
if isinstance(col_data, list) and col != Columns.INFOS:
9 changes: 5 additions & 4 deletions rllib/env/multi_agent_env_runner.py
@@ -3,7 +3,6 @@

from collections import defaultdict
from functools import partial
import numpy as np
from typing import DefaultDict, Dict, List, Optional

from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
@@ -623,9 +622,11 @@ def get_metrics(self) -> ResultDict:
module_episode_returns,
)

# If no episodes at all, log NaN stats.
if len(self._done_episodes_for_metrics) == 0:
self._log_episode_metrics(np.nan, np.nan, np.nan)
# TODO (simon): This results in hundreds of warnings in the logs
# b/c reducing over NaNs is not supported.
# # If no episodes at all, log NaN stats.
# if len(self._done_episodes_for_metrics) == 0:
# self._log_episode_metrics(np.nan, np.nan, np.nan)

# Log num episodes counter for this iteration.
self.metrics.log_value(
24 changes: 24 additions & 0 deletions rllib/env/multi_agent_episode.py
@@ -58,6 +58,30 @@ class MultiAgentEpisode:
up to here, b/c there is nothing to learn from these "premature" rewards.
"""

__slots__ = (
"id_",
"agent_to_module_mapping_fn",
"_agent_to_module_mapping",
"observation_space",
"action_space",
"env_t_started",
"env_t",
"agent_t_started",
"env_t_to_agent_t",
"_hanging_actions_end",
"_hanging_extra_model_outputs_end",
"_hanging_rewards_end",
"_hanging_actions_begin",
"_hanging_extra_model_outputs_begin",
"_hanging_rewards_begin",
"is_terminated",
"is_truncated",
"agent_episodes",
"_temporary_timestep_data",
"_start_time",
"_last_step_time",
)

SKIP_ENV_TS_TAG = "S"

def __init__(
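Declaring __slots__ on MultiAgentEpisode fixes the set of instance attributes and removes the per-instance __dict__, which can reduce memory overhead when many episode objects are kept around (e.g., in replay buffers). A generic illustration of the effect (not RLlib code):

class Slotted:
    __slots__ = ("value",)

ep = Slotted()
ep.value = 1    # declared slot: assignment works as usual
# ep.other = 2  # would raise AttributeError, since slotted instances have no __dict__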
79 changes: 79 additions & 0 deletions rllib/tuned_examples/sac/multi_agent_pendulum_sac.py
@@ -0,0 +1,79 @@
from ray.rllib.algorithms.sac import SACConfig
from ray.rllib.examples.envs.classes.multi_agent import MultiAgentPendulum
from ray.rllib.utils.metrics import (
ENV_RUNNER_RESULTS,
EPISODE_RETURN_MEAN,
NUM_ENV_STEPS_SAMPLED_LIFETIME,
)
from ray.tune.registry import register_env

from ray.rllib.utils.test_utils import add_rllib_example_script_args

parser = add_rllib_example_script_args()
# Use `parser` to add your own custom command line options to this script
# and (if needed) use their values to set up `config` below.
args = parser.parse_args()

register_env(
"multi_agent_pendulum",
lambda _: MultiAgentPendulum({"num_agents": args.num_agents or 2}),
)

config = (
SACConfig()
.environment(env="multi_agent_pendulum")
.rl_module(
model_config_dict={
"fcnet_hiddens": [256, 256],
"fcnet_activation": "relu",
"post_fcnet_hiddens": [],
"post_fcnet_activation": None,
"post_fcnet_weights_initializer": "orthogonal_",
"post_fcnet_weights_initializer_config": {"gain": 0.01},
}
)
.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)
.env_runners(
rollout_fragment_length=1,
num_env_runners=2,
num_envs_per_env_runner=1,
)
.training(
initial_alpha=1.001,
lr=3e-4,
target_entropy="auto",
n_step=1,
tau=0.005,
train_batch_size_per_learner=256,
target_network_update_freq=1,
replay_buffer_config={
"type": "MultiAgentEpisodeReplayBuffer",
"capacity": 100000,
},
num_steps_sampled_before_learning_starts=256,
)
.reporting(
metrics_num_episodes_for_smoothing=5,
min_sample_timesteps_per_iteration=1000,
)
)

if args.num_agents:
config.multi_agent(
policy_mapping_fn=lambda aid, *arg, **kw: f"p{aid}",
policies={f"p{i}" for i in range(args.num_agents)},
)

stop = {
NUM_ENV_STEPS_SAMPLED_LIFETIME: 500000,
# `episode_return_mean` is the sum of all agents/policies' returns.
f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": -400.0 * (args.num_agents or 2),
}

if __name__ == "__main__":
from ray.rllib.utils.test_utils import run_rllib_example_script_experiment

run_rllib_example_script_experiment(config, args, stop=stop)
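Assuming a Ray source checkout, this new tuned example could be launched the same way the commented-out BUILD targets above invoke it, for example:

python rllib/tuned_examples/sac/multi_agent_pendulum_sac.py --enable-new-api-stack --num-agents=2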
1 change: 1 addition & 0 deletions rllib/utils/metrics/__init__.py
@@ -91,3 +91,4 @@
# Learner.
LEARNER_STATS_KEY = "learner_stats"
ALL_MODULES = "__all_modules__"
TD_ERROR_KEY = "td_error"
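TD_ERROR_KEY now lives in this shared metrics module; the DQN and SAC learners above import it from here instead of declaring their own copies, e.g.:

from ray.rllib.utils.metrics import TD_ERROR_KEY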