Feature/recurrent and multiple trainer MAPPO #326

Merged
merged 72 commits into develop from feature/recurrent-mappo on Apr 21, 2022
Changes from 20 commits
Commits
72 commits
a535e3d
Add rough recurrent code for MAPPO.
DriesSmit Sep 27, 2021
0c6405b
Save progress.
DriesSmit Sep 27, 2021
71f4dc6
Save recurrent PPO progress.
DriesSmit Sep 27, 2021
cc42e98
Recurrent PPO is running.
DriesSmit Sep 28, 2021
1ae02b5
Small fixes.
DriesSmit Sep 28, 2021
ac46176
Recurrent MAPPO trains on the debugging environment!
DriesSmit Sep 29, 2021
1e4aaae
Small fix.
DriesSmit Sep 29, 2021
fa2b43c
Save changes.
DriesSmit Sep 30, 2021
a2bbc86
Add code to MAD4PG.
DriesSmit Oct 3, 2021
345fe1b
Ready to run 2 vs 2 xray_attention.
DriesSmit Oct 4, 2021
a25e977
Decrease queue size.
DriesSmit Oct 5, 2021
3c8a32a
Merge develop.
DriesSmit Oct 29, 2021
792ea3c
PPO seems to be training and running.
DriesSmit Oct 29, 2021
f1006c1
Add multiple trainers PPO example.
DriesSmit Oct 29, 2021
cbf9cda
Merge branch 'develop' into feature/recurrent-mappo
DriesSmit Nov 19, 2021
5b44c5b
Merge develop.
DriesSmit Dec 3, 2021
c33e32c
Fix PPO example.
DriesSmit Dec 3, 2021
ef77441
Fix embed_spec bug.
DriesSmit Dec 3, 2021
fb1c43c
Fix mypy issues.
DriesSmit Dec 3, 2021
58b68f9
Fix mypy issue.
DriesSmit Dec 3, 2021
153e0e5
Merge branch 'develop' into feature/recurrent-mappo
KaleabTessera Dec 13, 2021
e4a343e
Address some of the PR comments.
DriesSmit Dec 14, 2021
d7f8ba5
Add termination condition to MA-PPO.
DriesSmit Dec 14, 2021
05c1650
Merge branch 'develop' into feature/recurrent-mappo
DriesSmit Dec 14, 2021
473d40d
Merge branch 'develop' into feature/recurrent-mappo
DriesSmit Jan 6, 2022
f709231
Merge branch 'develop' into feature/recurrent-mappo
KaleabTessera Jan 6, 2022
aef20d2
Address PR comments.
DriesSmit Jan 12, 2022
92abb30
Add the capability for MAPPO to use continuous action spaces.
DriesSmit Jan 13, 2022
7797e91
Merge branch 'develop' into feature/recurrent-mappo
arnupretorius Jan 14, 2022
eba009c
Merge branch 'develop' into feature/recurrent-mappo
arnupretorius Jan 14, 2022
c89ce9f
fix: Small fixes.
DriesSmit Mar 9, 2022
276933c
fix: mypy.
DriesSmit Mar 9, 2022
ad2f300
Merge branch 'develop' into feature/recurrent-mappo
DriesSmit Mar 10, 2022
c5ffe70
Merge branch 'develop' into feature/recurrent-mappo
DriesSmit Mar 23, 2022
1967146
fix: Test small fix.
DriesSmit Mar 23, 2022
1fcaf0c
fix: Change writer back.
DriesSmit Mar 23, 2022
75ba274
fix: Change back.
DriesSmit Mar 23, 2022
21b376b
fix: Distributional head inside networks.
DriesSmit Mar 23, 2022
232a540
fix: Merge dev.
DriesSmit Mar 24, 2022
315a151
fix: Change static unroll function to a manual unroll function.
DriesSmit Mar 24, 2022
c0e3127
fix: Change setting in PPO sequence adder. Remove custom adder code.
DriesSmit Mar 25, 2022
9217172
fix: Small comment updates.
DriesSmit Mar 25, 2022
66757a2
bugfix: architecture type typo fix
RuanJohn Mar 28, 2022
b8ccfe0
fix: Baseline cost default.
DriesSmit Mar 28, 2022
acf592f
Merge branch 'feature/recurrent-mappo' of github.com:instadeepai/Mava…
DriesSmit Mar 28, 2022
f0186b6
fix: Update default sequence length.
DriesSmit Mar 28, 2022
091d73a
Fix entropy term in trainer for continuous action space environments.
DriesSmit Mar 29, 2022
7ce71c7
fix: Small fix to entropy loss.
DriesSmit Mar 29, 2022
64cb86a
fix: Small fix to variable spelling.
DriesSmit Mar 29, 2022
d797241
fix: Fix continuous action space PPO by creating custom clipped Gauss…
DriesSmit Mar 31, 2022
b34c707
fix: Small fixes.
DriesSmit Mar 31, 2022
ec31d19
fix: Replace clip to spec with tanh to spec.
DriesSmit Mar 31, 2022
527d8df
fix: Remove comment.
DriesSmit Mar 31, 2022
40a430e
fix: Remove comment.
DriesSmit Mar 31, 2022
86147a2
feature: Update Gaussian head's settings control the possible distrib…
DriesSmit Apr 1, 2022
91fc6a0
feature: Small update.
DriesSmit Apr 1, 2022
0adc026
Merge branch 'develop' into feature/recurrent-mappo
DriesSmit Apr 4, 2022
df45468
fix: Remove redundant statement.
DriesSmit Apr 12, 2022
49f8534
fix: Remove redundant statement.
DriesSmit Apr 12, 2022
34c538c
feat: Small improvement.
DriesSmit Apr 12, 2022
7d411dc
fix: PPO training for networks with Categorical heads.
DriesSmit Apr 13, 2022
25f3ea1
fix: Small fix to dataset shuffler.
DriesSmit Apr 13, 2022
dbb5797
fix: Remove print statement.
DriesSmit Apr 13, 2022
a257819
Small fixes to trainer variable client and Hyperparameter settings.
DriesSmit Apr 13, 2022
8e25670
feat: added multiple network fix
sash-a Apr 13, 2022
a88de00
feat: Small updates to hyperparameters. Moving system closer to devel…
DriesSmit Apr 13, 2022
072771b
merge: Merge changes.
DriesSmit Apr 13, 2022
2851235
fix: Small training fixes.
DriesSmit Apr 14, 2022
1674154
fix: Big bugfix in MAPPO trainer code setup.
DriesSmit Apr 14, 2022
b8d2738
Remove variable update inside mappo trainer _step code.
DriesSmit Apr 19, 2022
7aac022
Merge branch 'develop' into feature/recurrent-mappo
DriesSmit Apr 20, 2022
60eb054
fix: Address PR comments.
DriesSmit Apr 21, 2022
2 changes: 1 addition & 1 deletion docs/images/focus_fire.html
@@ -428,4 +428,4 @@
 MTAw
 ">
 Your browser does not support the video tag.
-</video>
+</video>
2 changes: 1 addition & 1 deletion docs/images/runaway.html
@@ -1272,4 +1272,4 @@
 dAAAACWpdG9vAAAAHWRhdGEAAAABAAAAAExhdmY1OC4yOS4xMDA=
 ">
 Your browser does not support the video tag.
-</video>
+</video>
@@ -0,0 +1,112 @@
# python3
# Copyright 2021 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Example running feedforward mappo on debug MPE environments.
NB: Using multiple trainers with non-shared weights is still in its
experimental phase of development. This feature will become faster and
more stable in future Mava updates."""

import functools
from datetime import datetime
from typing import Any

import launchpad as lp
import sonnet as snt
from absl import app, flags

from mava.systems.tf import mappo
from mava.systems.tf.mappo import make_default_networks
from mava.utils import enums, lp_utils
from mava.utils.environments import debugging_utils
from mava.utils.loggers import logger_utils

FLAGS = flags.FLAGS
flags.DEFINE_string(
    "env_name",
    "simple_spread",
    "Debugging environment name (str).",
)
flags.DEFINE_string(
    "action_space",
    "discrete",
    "Environment action space type (str).",
)
flags.DEFINE_string(
    "mava_id",
    str(datetime.now()),
    "Experiment identifier that can be used to continue experiments.",
)
flags.DEFINE_string("base_dir", "~/mava/", "Base dir to store experiments.")


def main(_: Any) -> None:

    # Environment.
    environment_factory = functools.partial(
        debugging_utils.make_environment,
        env_name=FLAGS.env_name,
        action_space=FLAGS.action_space,
    )

    # Networks.
    network_factory = lp_utils.partial_kwargs(make_default_networks)

    # Checkpointer appends "Checkpoints" to checkpoint_dir.
    checkpoint_dir = f"{FLAGS.base_dir}/{FLAGS.mava_id}"

    # Log every [log_every] seconds.
    log_every = 10
    logger_factory = functools.partial(
        logger_utils.make_logger,
        directory=FLAGS.base_dir,
        to_terminal=True,
        to_tensorboard=True,
        time_stamp=FLAGS.mava_id,
        time_delta=log_every,
    )

    # Distributed program.
    # NB: Using multiple trainers with non-shared weights is still in its
    # experimental phase of development. This feature will become faster and
    # more stable in future Mava updates.
    program = mappo.MAPPO(
        environment_factory=environment_factory,
        network_factory=network_factory,
        logger_factory=logger_factory,
        num_executors=2,
        shared_weights=False,
        trainer_networks=enums.Trainer.one_trainer_per_network,
        network_sampling_setup=enums.NetworkSampler.fixed_agent_networks,
        policy_optimizer=snt.optimizers.Adam(learning_rate=1e-4),
        critic_optimizer=snt.optimizers.Adam(learning_rate=1e-4),
        checkpoint_subpath=checkpoint_dir,
        max_gradient_norm=40.0,
    ).build()

    # Ensure only the trainer runs on GPU, while other processes run on CPU.
    local_resources = lp_utils.to_device(
        program_nodes=program.groups.keys(), nodes_on_gpu=["trainer"]
    )

    # Launch.
    lp.launch(
        program,
        lp.LaunchType.LOCAL_MULTI_PROCESSING,
        terminal="current_terminal",
        local_resources=local_resources,
    )


if __name__ == "__main__":
app.run(main)
121 changes: 121 additions & 0 deletions examples/debugging/simple_spread/recurrent/state_based/run_mappo.py
@@ -0,0 +1,121 @@
# python3
# Copyright 2021 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Example running MAPPO on debug MPE environments."""

import functools
from datetime import datetime
from typing import Any

import launchpad as lp
import sonnet as snt
from absl import app, flags

from mava.components.tf import architectures
from mava.systems.tf import mappo
from mava.utils import lp_utils
from mava.utils.enums import ArchitectureType
from mava.utils.environments import debugging_utils
from mava.utils.loggers import logger_utils

FLAGS = flags.FLAGS
flags.DEFINE_string(
    "env_name",
    "simple_spread",
    "Debugging environment name (str).",
)
flags.DEFINE_string(
    "action_space",
    "discrete",
    "Environment action space type (str).",
)

flags.DEFINE_string(
    "mava_id",
    str(datetime.now()),
    "Experiment identifier that can be used to continue experiments.",
)
flags.DEFINE_string("base_dir", "~/mava", "Base dir to store experiments.")


def main(_: Any) -> None:

    recurrent_test = False
    recurrent_ppo = True

    # Environment.
    environment_factory = functools.partial(
        debugging_utils.make_environment,
        env_name=FLAGS.env_name,
        action_space=FLAGS.action_space,
        return_state_info=True,
        recurrent_test=recurrent_test,
    )

    # Networks.
    network_factory = lp_utils.partial_kwargs(
        mappo.make_default_networks,
        architecture_type=ArchitectureType.recurrent
        if recurrent_ppo
        else ArchitectureType.feedforward,
    )

    # Checkpointer appends "Checkpoints" to checkpoint_dir.
    checkpoint_dir = f"{FLAGS.base_dir}/{FLAGS.mava_id}"

    # Log every [log_every] seconds.
    log_every = 10
    logger_factory = functools.partial(
        logger_utils.make_logger,
        directory=FLAGS.base_dir,
        to_terminal=True,
        to_tensorboard=True,
        time_stamp=FLAGS.mava_id,
        time_delta=log_every,
    )

    # Distributed program.
    program = mappo.MAPPO(
        environment_factory=environment_factory,
        network_factory=network_factory,
        logger_factory=logger_factory,
        num_executors=1,
        policy_optimizer=snt.optimizers.Adam(learning_rate=1e-4),
        critic_optimizer=snt.optimizers.Adam(learning_rate=1e-4),
        checkpoint_subpath=checkpoint_dir,
        max_gradient_norm=40.0,
        executor_fn=mappo.MAPPORecurrentExecutor
        if recurrent_ppo
        else mappo.MAPPOFeedForwardExecutor,
        architecture=architectures.StateBasedValueActorCritic,
        trainer_fn=mappo.StateBasedMAPPOTrainer,
    ).build()

    # Ensure only the trainer runs on GPU, while other processes run on CPU.
    local_resources = lp_utils.to_device(
        program_nodes=program.groups.keys(), nodes_on_gpu=["trainer"]
    )

    # Launch.
    lp.launch(
        program,
        lp.LaunchType.LOCAL_MULTI_PROCESSING,
        terminal="current_terminal",
        local_resources=local_resources,
    )


if __name__ == "__main__":
app.run(main)
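Commit 315a151 in this PR swaps a static unroll for a manual unroll when running recurrent networks. The pattern itself is simple: step the recurrent core through the time axis in an explicit loop, threading the state along. A minimal, self-contained sketch of that pattern, using plain-Python stand-ins rather than Mava's actual networks:

```python
# Manual-unroll sketch: the core here is a toy running-sum "RNN", not
# Mava's real recurrent network; names are illustrative only.
from typing import Callable, List, Tuple


def manual_unroll(
    core: Callable[[float, float], Tuple[float, float]],
    inputs: List[float],
    initial_state: float,
) -> Tuple[List[float], float]:
    """Apply `core` once per timestep, threading the state through."""
    state = initial_state
    outputs = []
    for x in inputs:  # one explicit step per timestep
        out, state = core(x, state)
        outputs.append(out)
    return outputs, state


def sum_core(x: float, state: float) -> Tuple[float, float]:
    # Toy recurrent core: output and next state are both the running sum.
    new_state = state + x
    return new_state, new_state


outputs, final_state = manual_unroll(sum_core, [1.0, 2.0, 3.0], 0.0)
print(outputs, final_state)  # [1.0, 3.0, 6.0] 6.0
```

The explicit loop trades a little speed for easier debugging and for cores whose step logic is awkward to express in a static unroll.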
1 change: 1 addition & 0 deletions mava/components/tf/architectures/__init__.py
@@ -41,4 +41,5 @@
     StateBasedQValueActorCritic,
     StateBasedQValueCritic,
     StateBasedQValueSingleActionCritic,
+    StateBasedValueActorCritic,
 )
8 changes: 5 additions & 3 deletions mava/components/tf/architectures/centralised.py
@@ -99,9 +99,10 @@ def _get_critic_specs(

         for agent_type, agents in agents_by_type.items():
             agent_key = agents[0]
-            critic_obs_shape = list(copy.copy(self._embed_specs[agent_key].shape))
+            net_key = self._agent_net_keys[agent_key]
+            critic_obs_shape = list(copy.copy(self._embed_specs[net_key].shape))
             critic_obs_shape.insert(0, len(agents))
-            obs_specs_per_type[agent_type] = tf.TensorSpec(
+            obs_specs_per_type[net_key] = tf.TensorSpec(
                 shape=critic_obs_shape,
                 dtype=tf.dtypes.float32,
             )
@@ -143,11 +144,12 @@ def _get_critic_specs(

         for agent_type, agents in agents_by_type.items():
             agent_key = agents[0]
+            net_key = self._agent_net_keys[agent_key]

             # TODO (dries): Add a check to see if all
             # self._embed_specs[agent_key].shape are of the same shape

-            critic_obs_shape = list(copy.copy(self._embed_specs[agent_key].shape))
+            critic_obs_shape = list(copy.copy(self._embed_specs[net_key].shape))
             critic_obs_shape.insert(0, len(agents))
             obs_specs_per_type[agent_type] = tf.TensorSpec(
                 shape=critic_obs_shape,
                 dtype=tf.dtypes.float32,
             )
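The centralised-critic hunks above change the key used for the embedding-spec lookup from the agent key to the agent's network key, then stack one embedding per agent into the critic's observation shape. A minimal sketch of the fixed logic in plain Python (no TensorFlow; `embed_spec_shapes` and friends are illustrative stand-ins for Mava's internals):

```python
# Per-network embedding shapes, agent-to-network mapping, and agents
# grouped by type -- stand-ins for self._embed_specs, self._agent_net_keys,
# and agents_by_type in the diff above.
embed_spec_shapes = {"network_0": (64,)}
agent_net_keys = {"agent_0": "network_0", "agent_1": "network_0"}
agents_by_type = {"agent": ["agent_0", "agent_1"]}

obs_specs_per_type = {}
for agent_type, agents in agents_by_type.items():
    agent_key = agents[0]
    # The fix: look up the spec by the agent's *network* key, not the agent key.
    net_key = agent_net_keys[agent_key]
    critic_obs_shape = list(embed_spec_shapes[net_key])
    critic_obs_shape.insert(0, len(agents))  # stack one embedding per agent
    obs_specs_per_type[net_key] = tuple(critic_obs_shape)

print(obs_specs_per_type)  # {'network_0': (2, 64)}
```

Keying by network key matters once weights are shared or multiple trainers are used, because several agents then map onto one network entry.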
11 changes: 4 additions & 7 deletions mava/components/tf/architectures/decentralised.py
@@ -157,7 +157,7 @@ def create_actor_variables(self) -> Dict[str, Dict[str, snt.Module]]:
             emb_spec = tf2_utils.create_variables(
                 self._observation_networks[agent_net_key], [obs_spec]
             )
-            self._embed_specs[agent_key] = emb_spec
+            self._embed_specs[agent_net_key] = emb_spec

             # Create variables.
             tf2_utils.create_variables(self._policy_networks[agent_net_key], [emb_spec])
@@ -267,7 +267,7 @@ def create_actor_variables(self) -> Dict[str, Dict[str, snt.Module]]:
             emb_spec = tf2_utils.create_variables(
                 self._observation_networks[net_key], [obs_spec]
             )
-            self._embed_specs[agent_key] = emb_spec
+            self._embed_specs[net_key] = emb_spec

             # Create variables.
             tf2_utils.create_variables(self._policy_networks[net_key], [emb_spec])
@@ -279,7 +279,6 @@ def create_actor_variables(self) -> Dict[str, Dict[str, snt.Module]]:
             tf2_utils.create_variables(
                 self._target_observation_networks[net_key], [obs_spec]
             )
-
         actor_networks: Dict[str, Dict[str, snt.Module]] = {
             "policies": self._policy_networks,
             "observations": self._observation_networks,
@@ -295,10 +294,8 @@ def create_critic_variables(self) -> Dict[str, Dict[str, snt.Module]]:

         # create critics
         for net_key in self._net_keys:
-            agent_key = self._net_spec_keys[net_key]
-
             # get specs
-            emb_spec = embed_specs[agent_key]
+            emb_spec = embed_specs[net_key]

             # Create variables.
             tf2_utils.create_variables(self._critic_networks[net_key], [emb_spec])
@@ -372,7 +369,7 @@ def create_critic_variables(self) -> Dict[str, Dict[str, snt.Module]]:
             agent_key = self._net_spec_keys[net_key]

             # get specs
-            emb_spec = embed_specs[agent_key]
+            emb_spec = embed_specs[net_key]
             act_spec = act_specs[agent_key]

             # Create variables.
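The decentralised-architecture hunks apply the same rekeying: embedding specs are stored and fetched by network key instead of agent key. A simplified sketch of why the old keying breaks — with shared weights, several agents map to one network, so a critic loop that iterates network keys never finds specs stored under agent keys. All names below are stand-ins for Mava's internals, and the failure mode is deliberately reduced to a dict lookup:

```python
# Two agents sharing one network, as under shared_weights=True.
agent_net_keys = {"agent_0": "network_0", "agent_1": "network_0"}

# Buggy version: the actor loop keys specs by agent key.
buggy_specs = {}
for agent_key, net_key in agent_net_keys.items():
    buggy_specs[agent_key] = (64,)
# The critic loop iterates network keys, so its lookup misses.
assert "network_0" not in buggy_specs

# Fixed version (as in the diff): key specs by network key.
fixed_specs = {}
for agent_key, net_key in agent_net_keys.items():
    fixed_specs[net_key] = (64,)

emb_spec = fixed_specs["network_0"]  # critic loop finds its spec
print(emb_spec)  # (64,)
```

In the real code the lookup went through `self._net_spec_keys`, which papered over the mismatch in simple setups but broke once multiple trainers and non-shared weights were introduced.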