
Commit 5877138
Merge pull request #326 from instadeepai/feature/recurrent-mappo
Feature/recurrent and multiple trainer MAPPO
DriesSmit authored Apr 21, 2022
2 parents fe5389d + 60eb054
Showing 29 changed files with 1,662 additions and 535 deletions.
@@ -85,7 +85,6 @@ def main(_: Any) -> None:
         logger_factory=logger_factory,
         num_executors=1,
         checkpoint_subpath=checkpoint_dir,
-        num_epochs=15,
     ).build()
 
     # Ensure only trainer runs on gpu, while other processes run on cpu.
@@ -0,0 +1,112 @@
# python3
# Copyright 2021 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Example running feedforward mappo on debug MPE environments.
NB: Using multiple trainers with non-shared weights is still in its
experimental phase of development. This feature will become faster and
more stable in future Mava updates."""

import functools
from datetime import datetime
from typing import Any

import launchpad as lp
import sonnet as snt
from absl import app, flags

from mava.systems.tf import mappo
from mava.systems.tf.mappo import make_default_networks
from mava.utils import enums, lp_utils
from mava.utils.environments import debugging_utils
from mava.utils.loggers import logger_utils

FLAGS = flags.FLAGS
flags.DEFINE_string(
    "env_name",
    "simple_spread",
    "Debugging environment name (str).",
)
flags.DEFINE_string(
    "action_space",
    "discrete",
    "Environment action space type (str).",
)
flags.DEFINE_string(
    "mava_id",
    str(datetime.now()),
    "Experiment identifier that can be used to continue experiments.",
)
flags.DEFINE_string("base_dir", "~/mava/", "Base dir to store experiments.")


def main(_: Any) -> None:

    # Environment.
    environment_factory = functools.partial(
        debugging_utils.make_environment,
        env_name=FLAGS.env_name,
        action_space=FLAGS.action_space,
    )

    # Networks.
    network_factory = lp_utils.partial_kwargs(make_default_networks)

    # Checkpointer appends "Checkpoints" to checkpoint_dir.
    checkpoint_dir = f"{FLAGS.base_dir}/{FLAGS.mava_id}"

    # Log every [log_every] seconds.
    log_every = 10
    logger_factory = functools.partial(
        logger_utils.make_logger,
        directory=FLAGS.base_dir,
        to_terminal=True,
        to_tensorboard=True,
        time_stamp=FLAGS.mava_id,
        time_delta=log_every,
    )

    # Distributed program.
    # NB: Using multiple trainers with non-shared weights is still in its
    # experimental phase of development. This feature will become faster and
    # more stable in future Mava updates.
    program = mappo.MAPPO(
        environment_factory=environment_factory,
        network_factory=network_factory,
        logger_factory=logger_factory,
        num_executors=2,
        shared_weights=False,
        trainer_networks=enums.Trainer.one_trainer_per_network,
        network_sampling_setup=enums.NetworkSampler.fixed_agent_networks,
        policy_optimizer=snt.optimizers.Adam(learning_rate=1e-4),
        critic_optimizer=snt.optimizers.Adam(learning_rate=1e-4),
        checkpoint_subpath=checkpoint_dir,
        max_gradient_norm=40.0,
    ).build()

    # Ensure only trainer runs on gpu, while other processes run on cpu.
    local_resources = lp_utils.to_device(
        program_nodes=program.groups.keys(), nodes_on_gpu=["trainer"]
    )

    lp.launch(
        program,
        lp.LaunchType.LOCAL_MULTI_PROCESSING,
        terminal="current_terminal",
        local_resources=local_resources,
    )


if __name__ == "__main__":
    app.run(main)
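The example above combines two settings introduced in this PR: shared_weights=False and trainer_networks=enums.Trainer.one_trainer_per_network. A minimal sketch of the agent-to-network-to-trainer mapping this combination implies, using plain dictionaries and hypothetical names rather than Mava's actual internals:

# Hypothetical illustration only; not Mava's actual bookkeeping.
agents = ["agent_0", "agent_1", "agent_2"]

# fixed_agent_networks with shared_weights=False: each agent keeps its
# own network rather than sharing one set of weights.
agent_net_keys = {agent: f"network_{i}" for i, agent in enumerate(agents)}

# one_trainer_per_network: every distinct network gets a dedicated trainer,
# so this configuration would launch three trainer processes instead of one.
trainer_networks = {
    f"trainer_{i}": [net_key]
    for i, net_key in enumerate(sorted(set(agent_net_keys.values())))
}

print(agent_net_keys)    # {'agent_0': 'network_0', 'agent_1': 'network_1', ...}
print(trainer_networks)  # {'trainer_0': ['network_0'], 'trainer_1': ['network_1'], ...}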
examples/tf/debugging/simple_spread/recurrent/state_based/run_mappo.py (113 additions, 0 deletions)
@@ -0,0 +1,113 @@
# python3
# Copyright 2021 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Example running MAPPO on debug MPE environments."""

import functools
from datetime import datetime
from typing import Any

import launchpad as lp
import sonnet as snt
from absl import app, flags

from mava.components.tf import architectures
from mava.systems.tf import mappo
from mava.utils import lp_utils
from mava.utils.enums import ArchitectureType
from mava.utils.environments import debugging_utils
from mava.utils.loggers import logger_utils

FLAGS = flags.FLAGS
flags.DEFINE_string(
    "env_name",
    "simple_spread",
    "Debugging environment name (str).",
)
flags.DEFINE_string(
    "action_space",
    "discrete",
    "Environment action space type (str).",
)
flags.DEFINE_string(
    "mava_id",
    str(datetime.now()),
    "Experiment identifier that can be used to continue experiments.",
)
flags.DEFINE_string("base_dir", "~/mava", "Base dir to store experiments.")


def main(_: Any) -> None:
    # Environment.
    environment_factory = functools.partial(
        debugging_utils.make_environment,
        env_name=FLAGS.env_name,
        action_space=FLAGS.action_space,
        return_state_info=True,
        recurrent_test=True,
    )

    # Networks.
    network_factory = lp_utils.partial_kwargs(
        mappo.make_default_networks,
        architecture_type=ArchitectureType.recurrent,
    )

    # Checkpointer appends "Checkpoints" to checkpoint_dir.
    checkpoint_dir = f"{FLAGS.base_dir}/{FLAGS.mava_id}"

    # Log every [log_every] seconds.
    log_every = 10
    logger_factory = functools.partial(
        logger_utils.make_logger,
        directory=FLAGS.base_dir,
        to_terminal=True,
        to_tensorboard=True,
        time_stamp=FLAGS.mava_id,
        time_delta=log_every,
    )

    # Distributed program.
    program = mappo.MAPPO(
        environment_factory=environment_factory,
        network_factory=network_factory,
        logger_factory=logger_factory,
        num_executors=1,
        policy_optimizer=snt.optimizers.Adam(learning_rate=1e-4),
        critic_optimizer=snt.optimizers.Adam(learning_rate=1e-4),
        checkpoint_subpath=checkpoint_dir,
        max_gradient_norm=40.0,
        executor_fn=mappo.MAPPORecurrentExecutor,
        architecture=architectures.StateBasedValueActorCritic,
        trainer_fn=mappo.StateBasedMAPPOTrainer,
    ).build()

    # Ensure only trainer runs on gpu, while other processes run on cpu.
    local_resources = lp_utils.to_device(
        program_nodes=program.groups.keys(), nodes_on_gpu=["trainer"]
    )

    # Launch.
    lp.launch(
        program,
        lp.LaunchType.LOCAL_MULTI_PROCESSING,
        terminal="current_terminal",
        local_resources=local_resources,
    )


if __name__ == "__main__":
    app.run(main)
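Two pieces make this example recurrent and state based: executor_fn=mappo.MAPPORecurrentExecutor threads an RNN core state through every environment step, and architecture=architectures.StateBasedValueActorCritic gives the critic the environment state exposed by return_state_info=True. A standalone sketch of the step-by-step core-state handling, assuming only a Sonnet recurrent module and not Mava's executor code:

# Hypothetical illustration of carrying an RNN state across steps.
import sonnet as snt
import tensorflow as tf

# A recurrent policy core, similar in spirit to the recurrent default networks.
policy_core = snt.DeepRNN([snt.Linear(64), snt.LSTM(64), snt.Linear(5)])

batch_size = 1
observation = tf.ones([batch_size, 10])  # placeholder observation

# The executor keeps a core state per agent and passes it back in each step.
state = policy_core.initial_state(batch_size)
for _ in range(3):  # three environment steps
    logits, state = policy_core(observation, state)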
examples/tf/sisl/multiwalker/feedforward/decentralised/run_mappo.py (104 additions, 0 deletions)
@@ -0,0 +1,104 @@
# python3
# Copyright 2021 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Example running continous MAPPO on pettinzoo SISL environments."""

import functools
from datetime import datetime
from typing import Any

import launchpad as lp
import sonnet as snt
from absl import app, flags

from mava.systems.tf import mappo
from mava.utils import lp_utils
from mava.utils.environments import pettingzoo_utils
from mava.utils.loggers import logger_utils

FLAGS = flags.FLAGS

flags.DEFINE_string(
    "env_class",
    "sisl",
    "PettingZoo environment class, e.g. atari (str).",
)
flags.DEFINE_string(
    "env_name",
    "multiwalker_v7",
    "PettingZoo environment name, e.g. pong (str).",
)
flags.DEFINE_string(
    "mava_id",
    str(datetime.now()),
    "Experiment identifier that can be used to continue experiments.",
)
flags.DEFINE_string("base_dir", "~/mava", "Base dir to store experiments.")


def main(_: Any) -> None:
    # Environment.
    environment_factory = functools.partial(
        pettingzoo_utils.make_environment,
        env_class=FLAGS.env_class,
        env_name=FLAGS.env_name,
        remove_on_fall=False,
    )

    # Networks.
    network_factory = lp_utils.partial_kwargs(mappo.make_default_networks)

    # Checkpointer appends "Checkpoints" to checkpoint_dir.
    checkpoint_dir = f"{FLAGS.base_dir}/{FLAGS.mava_id}"

    # Log every [log_every] seconds.
    log_every = 10
    logger_factory = functools.partial(
        logger_utils.make_logger,
        directory=FLAGS.base_dir,
        to_terminal=True,
        to_tensorboard=True,
        time_stamp=FLAGS.mava_id,
        time_delta=log_every,
    )

    # Distributed program.
    program = mappo.MAPPO(
        environment_factory=environment_factory,
        network_factory=network_factory,
        logger_factory=logger_factory,
        num_executors=1,
        optimizer=snt.optimizers.Adam(learning_rate=1e-4),
        checkpoint_subpath=checkpoint_dir,
        max_gradient_norm=40.0,
    ).build()

    # Ensure only trainer runs on gpu, while other processes run on cpu.
    local_resources = lp_utils.to_device(
        program_nodes=program.groups.keys(), nodes_on_gpu=["trainer"]
    )

    # Launch.
    lp.launch(
        program,
        lp.LaunchType.LOCAL_MULTI_PROCESSING,
        terminal="current_terminal",
        local_resources=local_resources,
    )


if __name__ == "__main__":
    app.run(main)
mava/adders/reverb/base.py (0 additions, 2 deletions)
@@ -382,7 +382,6 @@ def add_first(
 
         if self._use_next_extras:
             add_dict["extras"] = extras
-
         self._writer.append(
             add_dict,
             partial_step=True,
@@ -421,7 +420,6 @@ def add(
 
         if self._use_next_extras:
             next_step["extras"] = next_extras
-
         self._writer.append(
             next_step,
             partial_step=True,
mava/components/tf/architectures/__init__.py (1 addition, 0 deletions)
@@ -41,4 +41,5 @@
     StateBasedQValueActorCritic,
     StateBasedQValueCritic,
     StateBasedQValueSingleActionCritic,
+    StateBasedValueActorCritic,
 )
mava/components/tf/architectures/centralised.py (5 additions, 8 deletions)
@@ -99,9 +99,10 @@ def _get_critic_specs(
 
         for agent_type, agents in agents_by_type.items():
             agent_key = agents[0]
-            critic_obs_shape = list(copy.copy(self._embed_specs[agent_key].shape))
+            net_key = self._agent_net_keys[agent_key]
+            critic_obs_shape = list(copy.copy(self._embed_specs[net_key].shape))
             critic_obs_shape.insert(0, len(agents))
-            obs_specs_per_type[agent_type] = tf.TensorSpec(
+            obs_specs_per_type[net_key] = tf.TensorSpec(
                 shape=critic_obs_shape,
                 dtype=tf.dtypes.float32,
             )
@@ -143,11 +144,9 @@ def _get_critic_specs(
 
         for agent_type, agents in agents_by_type.items():
             agent_key = agents[0]
+            net_key = self._agent_net_keys[agent_key]
 
-            # TODO (dries): Add a check to see if all
-            # self._embed_specs[agent_key].shape are of the same shape
-
-            critic_obs_shape = list(copy.copy(self._embed_specs[agent_key].shape))
+            critic_obs_shape = list(copy.copy(self._embed_specs[net_key].shape))
             critic_obs_shape.insert(0, len(agents))
             obs_specs_per_type[agent_type] = tf.TensorSpec(
                 shape=critic_obs_shape,
@@ -157,8 +156,6 @@
             critic_act_shape = list(
                 copy.copy(self._agent_specs[agents[0]].actions.shape)
             )
-            # TODO (dries): Add a check to see if all
-            # self._agent_specs[agents[0]].actions.shape are of the same shape
 
             critic_act_shape.insert(0, len(agents))
             action_specs_per_type[agent_type] = tf.TensorSpec(
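The centralised architecture change above re-keys critic observation specs by the agent's network key (net_key) rather than the agent key, which matters once weights are unshared and specs are stored per network. A toy illustration of the lookup being fixed, with hypothetical dictionaries rather than Mava's internals:

# Hypothetical illustration only.
agent_net_keys = {"agent_0": "network_0", "agent_1": "network_1"}

# Embed specs stored per network key, as the surrounding code suggests.
embed_specs = {"network_0": "spec_0", "network_1": "spec_1"}

agent_key = "agent_0"
net_key = agent_net_keys[agent_key]

# embed_specs[agent_key] would raise KeyError('agent_0') here;
# embed_specs[net_key] is the working lookup.
print(embed_specs[net_key])  # spec_0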
