Replace types ParallelAdder with ReverbParallelAdder #356

Merged: 1 commit, Jan 21, 2022
2 changes: 1 addition & 1 deletion mava/adders/__init__.py
@@ -16,4 +16,4 @@
"""Adders for sending data from actors to replay buffers."""


-from mava.adders.base import ParallelAdder
+from mava.adders.reverb.base import ReverbParallelAdder
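The practical effect of this one-line change is that the package-level export is now the reverb-specific adder type. A minimal sketch of how downstream code annotates against it after this PR (the factory function below is hypothetical, not part of this diff):

```python
from typing import Optional

from mava import adders


def make_executor(
    adder: Optional[adders.ReverbParallelAdder] = None,
) -> None:
    """Hypothetical component: annotate against the reverb-specific adder."""
    if adder is not None:
        # The adder forwards experience from the executor to a replay buffer.
        ...
```

Since ReverbParallelAdder also mixes in ParallelAdder (see the base.py diff below), instances continue to satisfy code written against the generic ParallelAdder interface.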
14 changes: 9 additions & 5 deletions mava/adders/reverb/base.py
@@ -39,6 +39,7 @@
from acme.adders.reverb.base import ReverbAdder

from mava import types as mava_types
+from mava.adders.base import ParallelAdder
from mava.utils.sort_utils import sort_str_num

DEFAULT_PRIORITY_TABLE = "priority_table"
@@ -93,13 +94,14 @@ def get_trajectory_net_agents(
trajectory: Union[Trajectory, mava_types.Transition],
trajectory_net_keys: Dict[str, str],
) -> Tuple[List, Dict[str, List]]:
"""Returns a dictionary that maps network_keys to a list of agents using that specific
network.
"""Returns a dictionary that maps network_keys to a list of agents using that
specific network.

Args:
trajectory: Episode experience recorded by
the adders.
trajectory_net_keys: The network_keys used by each agent in the trajectory.
+
Returns:
agents: A sorted list of all the agent_keys.
agents_per_network: A dictionary that maps network_keys to
@@ -113,7 +115,7 @@
return agents, agents_per_network


-class ReverbParallelAdder(ReverbAdder):
+class ReverbParallelAdder(ReverbAdder, ParallelAdder):
"""Base reverb class."""

def __init__(
@@ -159,11 +161,13 @@ def write_experience_to_tables( # noqa
trajectory: Union[Trajectory, mava_types.Transition],
table_priorities: Dict[str, Any],
) -> None:
"""Write an episode experience (trajectory) to the reverb tables. Each
table represents experience used by each of the trainers. Therefore
"""Write an episode experience (trajectory) to the reverb tables.

Each table represents experience used by each of the trainers. Therefore
this function dynamically determines to which table(s) to write
parts of the trajectory based on what networks where used by
the agents in the episode run.
+
Args:
trajectory: Trajectory to be
written to the reverb tables.
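As an aside, the grouping behaviour documented on get_trajectory_net_agents can be sketched in a few lines of standalone Python. This is an illustrative re-implementation, not the code under review: it takes only the agent-to-network mapping as input and uses plain sorted() where Mava uses sort_str_num:

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def group_agents_by_network(
    trajectory_net_keys: Dict[str, str],
) -> Tuple[List[str], Dict[str, List[str]]]:
    """Map each network key to the sorted list of agents that used it."""
    agents = sorted(trajectory_net_keys)  # sorted list of all agent_keys
    agents_per_network: Dict[str, List[str]] = defaultdict(list)
    for agent in agents:
        agents_per_network[trajectory_net_keys[agent]].append(agent)
    return agents, dict(agents_per_network)


# Example: two agents share "network_0", a third uses "network_1".
agents, per_net = group_agents_by_network(
    {"agent_0": "network_0", "agent_1": "network_0", "agent_2": "network_1"}
)
assert agents == ["agent_0", "agent_1", "agent_2"]
assert per_net == {"network_0": ["agent_0", "agent_1"], "network_1": ["agent_2"]}
```

write_experience_to_tables relies on exactly this kind of mapping to route each slice of a trajectory to the table of the trainer whose networks generated it.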
6 changes: 3 additions & 3 deletions mava/systems/tf/dial/builder.py
@@ -119,7 +119,7 @@ def make_executor( # type: ignore[override]
],
action_selectors: Dict[str, Any],
communication_module: BaseCommunicationModule,
-adder: Optional[adders.ParallelAdder] = None,
+adder: Optional[adders.ReverbParallelAdder] = None,
variable_source: Optional[core.VariableSource] = None,
trainer: Optional[training.MADQNRecurrentCommTrainer] = None,
evaluator: bool = False,
@@ -135,8 +135,8 @@
epsilon greedy.
communication_module (BaseCommunicationModule): module for enabling
communication protocols between agents.
-adder (Optional[adders.ParallelAdder], optional): adder to send data to
-a replay buffer. Defaults to None.
+adder (Optional[adders.ReverbParallelAdder], optional): adder to send data
+to a replay buffer. Defaults to None.
variable_source (Optional[core.VariableSource], optional): variables server.
Defaults to None.
trainer (Optional[training.MADQNRecurrentCommTrainer], optional):
6 changes: 3 additions & 3 deletions mava/systems/tf/dial/execution.py
@@ -40,7 +40,7 @@ def __init__(
action_selectors: Dict[str, snt.Module],
communication_module: BaseCommunicationModule,
agent_net_keys: Dict[str, str],
-adder: Optional[adders.ParallelAdder] = None,
+adder: Optional[adders.ReverbParallelAdder] = None,
variable_client: Optional[tf2_variable_utils.VariableClient] = None,
store_recurrent_state: bool = True,
trainer: MADQNTrainer = None,
@@ -59,8 +59,8 @@
communication protocols between agents.
agent_net_keys: (dict, optional): specifies what network each agent uses.
Defaults to {}.
-adder (Optional[adders.ParallelAdder], optional): adder which sends data
-to a replay buffer. Defaults to None.
+adder (Optional[adders.ReverbParallelAdder], optional): adder which sends
+data to a replay buffer. Defaults to None.
variable_client (Optional[tf2_variable_utils.VariableClient], optional):
client to copy weights from the trainer. Defaults to None.
store_recurrent_state (bool, optional): boolean to store the recurrent
18 changes: 9 additions & 9 deletions mava/systems/tf/executors.py
@@ -40,7 +40,7 @@ def __init__(
self,
policy_networks: Dict[str, snt.Module],
agent_net_keys: Dict[str, str],
-adder: Optional[adders.ParallelAdder] = None,
+adder: Optional[adders.ReverbParallelAdder] = None,
variable_client: Optional[tf2_variable_utils.VariableClient] = None,
):
"""Initialise the system executor
@@ -50,8 +50,8 @@
the system.
agent_net_keys: (dict, optional): specifies what network each agent uses.
Defaults to {}.
-adder (Optional[adders.ParallelAdder], optional): adder which sends data
-to a replay buffer. Defaults to None.
+adder (Optional[adders.ReverbParallelAdder], optional): adder which sends
+data to a replay buffer. Defaults to None.
variable_client (Optional[tf2_variable_utils.VariableClient], optional):
client to copy weights from the trainer. Defaults to None.
"""
@@ -200,7 +200,7 @@ def __init__(
self,
policy_networks: Dict[str, snt.RNNCore],
agent_net_keys: Dict[str, str],
-adder: Optional[adders.ParallelAdder] = None,
+adder: Optional[adders.ReverbParallelAdder] = None,
variable_client: Optional[tf2_variable_utils.VariableClient] = None,
store_recurrent_state: bool = True,
):
@@ -211,8 +211,8 @@
the system.
agent_net_keys: (dict, optional): specifies what network each agent uses.
Defaults to {}.
-adder (Optional[adders.ParallelAdder], optional): adder which sends data
-to a replay buffer. Defaults to None.
+adder (Optional[adders.ReverbParallelAdder], optional): adder which sends
+data to a replay buffer. Defaults to None.
variable_client (Optional[tf2_variable_utils.VariableClient], optional):
client to copy weights from the trainer. Defaults to None.
store_recurrent_state (bool, optional): boolean to store the recurrent
@@ -421,7 +421,7 @@ def __init__(
policy_networks: Dict[str, snt.RNNCore],
communication_module: BaseCommunicationModule,
agent_net_keys: Dict[str, str],
-adder: Optional[adders.ParallelAdder] = None,
+adder: Optional[adders.ReverbParallelAdder] = None,
variable_client: Optional[tf2_variable_utils.VariableClient] = None,
store_recurrent_state: bool = True,
):
@@ -434,8 +434,8 @@
communication protocols between agents.
agent_net_keys: (dict, optional): specifies what network each agent uses.
Defaults to {}.
-adder (Optional[adders.ParallelAdder], optional): adder which sends data
-to a replay buffer. Defaults to None.
+adder (Optional[adders.ReverbParallelAdder], optional): adder which sends
+data to a replay buffer. Defaults to None.
variable_client (Optional[tf2_variable_utils.VariableClient], optional):
client to copy weights from the trainer. Defaults to None.
store_recurrent_state (bool, optional): boolean to store the recurrent
4 changes: 2 additions & 2 deletions mava/systems/tf/mad4pg/execution.py
@@ -41,7 +41,7 @@ def __init__(
network_sampling_setup: List,
net_keys_to_ids: Dict[str, int],
evaluator: bool = False,
-adder: Optional[adders.ParallelAdder] = None,
+adder: Optional[adders.ReverbParallelAdder] = None,
counts: Optional[Dict[str, Any]] = None,
variable_client: Optional[tf2_variable_utils.VariableClient] = None,
interval: Optional[dict] = None,
@@ -94,7 +94,7 @@ def __init__(
network_sampling_setup: List,
net_keys_to_ids: Dict[str, int],
evaluator: bool = False,
-adder: Optional[adders.ParallelAdder] = None,
+adder: Optional[adders.ReverbParallelAdder] = None,
counts: Optional[Dict[str, Any]] = None,
variable_client: Optional[tf2_variable_utils.VariableClient] = None,
interval: Optional[dict] = None,
4 changes: 2 additions & 2 deletions mava/systems/tf/maddpg/builder.py
@@ -336,7 +336,7 @@ def make_dataset_iterator(
def make_adder(
self,
replay_client: reverb.Client,
-) -> Optional[adders.ParallelAdder]:
+) -> Optional[adders.ReverbParallelAdder]:
"""Create an adder which records data generated by the executor/environment.
Args:
replay_client: Reverb Client which points to the
@@ -431,7 +431,7 @@ def make_executor(
self,
networks: Dict[str, snt.Module],
policy_networks: Dict[str, snt.Module],
-adder: Optional[adders.ParallelAdder] = None,
+adder: Optional[adders.ReverbParallelAdder] = None,
variable_source: Optional[MavaVariableSource] = None,
evaluator: bool = False,
) -> core.Executor:
4 changes: 2 additions & 2 deletions mava/systems/tf/maddpg/execution.py
@@ -52,7 +52,7 @@ def __init__(
network_sampling_setup: List,
net_keys_to_ids: Dict[str, int],
evaluator: bool = False,
-adder: Optional[adders.ParallelAdder] = None,
+adder: Optional[adders.ReverbParallelAdder] = None,
counts: Optional[Dict[str, Any]] = None,
variable_client: Optional[tf2_variable_utils.VariableClient] = None,
interval: Optional[dict] = None,
@@ -240,7 +240,7 @@ def __init__(
network_sampling_setup: List,
net_keys_to_ids: Dict[str, int],
evaluator: bool = False,
-adder: Optional[adders.ParallelAdder] = None,
+adder: Optional[adders.ReverbParallelAdder] = None,
counts: Optional[Dict[str, Any]] = None,
variable_client: Optional[tf2_variable_utils.VariableClient] = None,
store_recurrent_state: bool = True,
11 changes: 6 additions & 5 deletions mava/systems/tf/madqn/builder.py
@@ -235,7 +235,7 @@ def make_dataset_iterator(

def make_adder(
self, replay_client: reverb.Client
-) -> Optional[adders.ParallelAdder]:
+) -> Optional[adders.ReverbParallelAdder]:
"""Create an adder which records data generated by the executor/environment.

Args:
@@ -246,7 +246,8 @@ def make_adder(
NotImplementedError: unknown executor type.

Returns:
-Optional[adders.ParallelAdder]: adder which sends data to a replay buffer.
+Optional[adders.ReverbParallelAdder]: adder which sends data to a
+replay buffer.
"""

# Select adder
@@ -281,7 +282,7 @@ def make_executor(
ConstantScheduler,
],
],
-adder: Optional[adders.ParallelAdder] = None,
+adder: Optional[adders.ReverbParallelAdder] = None,
variable_source: Optional[core.VariableSource] = None,
trainer: Optional[training.MADQNTrainer] = None,
communication_module: Optional[BaseCommunicationModule] = None,
@@ -296,8 +297,8 @@
action_selectors (Dict[str, Any]): policy action selector method, e.g.
epsilon greedy.
exploration_schedules: epsilon decay scheduler per agent.
-adder (Optional[adders.ParallelAdder], optional): adder to send data to
-a replay buffer. Defaults to None.
+adder (Optional[adders.ReverbParallelAdder], optional): adder to send data
+to a replay buffer. Defaults to None.
variable_source (Optional[core.VariableSource], optional): variables server.
Defaults to None.
trainer (Optional[training.MADQNRecurrentCommTrainer], optional):
18 changes: 9 additions & 9 deletions mava/systems/tf/madqn/execution.py
@@ -102,7 +102,7 @@ def __init__(
action_selectors: Dict[str, snt.Module],
trainer: MADQNTrainer,
agent_net_keys: Dict[str, str],
-adder: Optional[adders.ParallelAdder] = None,
+adder: Optional[adders.ReverbParallelAdder] = None,
variable_client: Optional[tf2_variable_utils.VariableClient] = None,
communication_module: Optional[BaseCommunicationModule] = None,
fingerprint: bool = False,
@@ -119,8 +119,8 @@
trainer (MADQNTrainer, optional): system trainer.
agent_net_keys: (dict, optional): specifies what network each agent uses.
Defaults to {}.
-adder (Optional[adders.ParallelAdder], optional): adder which sends data
-to a replay buffer. Defaults to None.
+adder (Optional[adders.ReverbParallelAdder], optional): adder which sends
+data to a replay buffer. Defaults to None.
variable_client (Optional[tf2_variable_utils.VariableClient], optional):
client to copy weights from the trainer. Defaults to None.
communication_module (BaseCommunicationModule): module for enabling
@@ -307,7 +307,7 @@ def __init__(
q_networks: Dict[str, snt.Module],
action_selectors: Dict[str, snt.Module],
agent_net_keys: Dict[str, str],
-adder: Optional[adders.ParallelAdder] = None,
+adder: Optional[adders.ReverbParallelAdder] = None,
variable_client: Optional[tf2_variable_utils.VariableClient] = None,
store_recurrent_state: bool = True,
trainer: MADQNTrainer = None,
@@ -326,8 +326,8 @@
agent_net_keys: (dict, optional): specifies what network each agent uses.
Defaults to {}.
agent_net_keys (Dict[str, Any]): specifies what network each agent uses.
-adder (Optional[adders.ParallelAdder], optional): adder which sends data
-to a replay buffer. Defaults to None.
+adder (Optional[adders.ReverbParallelAdder], optional): adder which sends
+data to a replay buffer. Defaults to None.
variable_client (Optional[tf2_variable_utils.VariableClient], optional):
client to copy weights from the trainer. Defaults to None.
store_recurrent_state (bool, optional): boolean to store the recurrent
@@ -452,7 +452,7 @@ def __init__(
action_selectors: Dict[str, snt.Module],
communication_module: BaseCommunicationModule,
agent_net_keys: Dict[str, str],
-adder: Optional[adders.ParallelAdder] = None,
+adder: Optional[adders.ReverbParallelAdder] = None,
variable_client: Optional[tf2_variable_utils.VariableClient] = None,
store_recurrent_state: bool = True,
trainer: MADQNTrainer = None,
@@ -471,8 +471,8 @@
communication protocols between agents.
agent_net_keys: (dict, optional): specifies what network each agent uses.
Defaults to {}.
-adder (Optional[adders.ParallelAdder], optional): adder which sends data
-to a replay buffer. Defaults to None.
+adder (Optional[adders.ReverbParallelAdder], optional): adder which sends
+data to a replay buffer. Defaults to None.
variable_client (Optional[tf2_variable_utils.VariableClient], optional):
client to copy weights from the trainer. Defaults to None.
store_recurrent_state (bool, optional): boolean to store the recurrent
7 changes: 4 additions & 3 deletions mava/systems/tf/mappo/builder.py
@@ -187,7 +187,7 @@ def make_dataset_iterator(
def make_adder(
self,
replay_client: reverb.Client,
-) -> Optional[adders.ParallelAdder]:
+) -> Optional[adders.ReverbParallelAdder]:
"""Create an adder which records data generated by the executor/environment.

Args:
@@ -198,7 +198,8 @@ def make_adder(
NotImplementedError: unknown executor type.

Returns:
-Optional[adders.ParallelAdder]: adder which sends data to a replay buffer.
+Optional[adders.ReverbParallelAdder]: adder which sends data to a
+replay buffer.
"""

return reverb_adders.ParallelSequenceAdder(
@@ -211,7 +212,7 @@ def make_adder(
def make_executor(
self,
policy_networks: Dict[str, snt.Module],
-adder: Optional[adders.ParallelAdder] = None,
+adder: Optional[adders.ReverbParallelAdder] = None,
variable_source: Optional[core.VariableSource] = None,
evaluator: bool = False,
) -> core.Executor:
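For context, make_adder here returns a ParallelSequenceAdder whose constructor arguments are collapsed in the diff view above. A hedged sketch of the construction, assuming acme-style SequenceAdder arguments (client, sequence_length, period); the actual argument list in Mava may differ:

```python
import reverb

from mava.adders import reverb as reverb_adders

# Assumed setup, mirroring acme's SequenceAdder; not taken from this diff.
replay_client = reverb.Client("localhost:8000")  # address of the replay server
adder = reverb_adders.ParallelSequenceAdder(
    client=replay_client,  # reverb client the adder writes through
    sequence_length=20,    # timesteps per stored sequence
    period=10,             # stride between starts of consecutive sequences
)
```

The new return annotation type-checks because ParallelSequenceAdder derives from ReverbParallelAdder, which the annotation change in this hunk itself implies.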
6 changes: 3 additions & 3 deletions mava/systems/tf/mappo/execution.py
@@ -40,7 +40,7 @@ def __init__(
self,
policy_networks: Dict[str, snt.Module],
agent_net_keys: Dict[str, str],
-adder: Optional[adders.ParallelAdder] = None,
+adder: Optional[adders.ReverbParallelAdder] = None,
variable_client: Optional[tf2_variable_utils.VariableClient] = None,
evaluator: bool = False,
interval: Optional[dict] = None,
@@ -52,8 +52,8 @@
the system.
agent_net_keys: (dict, optional): specifies what network each agent uses.
Defaults to {}.
-adder (Optional[adders.ParallelAdder], optional): adder which sends data
-to a replay buffer. Defaults to None.
+adder (Optional[adders.ReverbParallelAdder], optional): adder which sends
+data to a replay buffer. Defaults to None.
variable_client (Optional[tf2_variable_utils.VariableClient], optional):
client to copy weights from the trainer. Defaults to None.
evaluator (bool, optional): whether the executor will be used for
6 changes: 3 additions & 3 deletions mava/systems/tf/qmix/execution.py
@@ -37,7 +37,7 @@ def __init__(
action_selectors: Dict[str, snt.Module],
trainer: MADQNTrainer,
agent_net_keys: Dict[str, str],
-adder: Optional[adders.ParallelAdder] = None,
+adder: Optional[adders.ReverbParallelAdder] = None,
variable_client: Optional[tf2_variable_utils.VariableClient] = None,
communication_module: Optional[BaseCommunicationModule] = None,
fingerprint: bool = False,
@@ -54,8 +54,8 @@
trainer (MADQNTrainer, optional): system trainer.
agent_net_keys: (dict, optional): specifies what network each agent uses.
Defaults to {}.
-adder (Optional[adders.ParallelAdder], optional): adder which sends data
-to a replay buffer. Defaults to None.
+adder (Optional[adders.ReverbParallelAdder], optional): adder which sends
+data to a replay buffer. Defaults to None.
variable_client (Optional[tf2_variable_utils.VariableClient], optional):
client to copy weights from the trainer. Defaults to None.
communication_module (BaseCommunicationModule): module for enabling
6 changes: 3 additions & 3 deletions mava/systems/tf/vdn/execution.py
@@ -37,7 +37,7 @@ def __init__(
action_selectors: Dict[str, snt.Module],
trainer: MADQNTrainer,
agent_net_keys: Dict[str, str],
-adder: Optional[adders.ParallelAdder] = None,
+adder: Optional[adders.ReverbParallelAdder] = None,
variable_client: Optional[tf2_variable_utils.VariableClient] = None,
communication_module: Optional[BaseCommunicationModule] = None,
fingerprint: bool = False,
@@ -54,8 +54,8 @@
trainer (MADQNTrainer, optional): system trainer.
agent_net_keys: (dict, optional): specifies what network each agent uses.
Defaults to {}.
-adder (Optional[adders.ParallelAdder], optional): adder which sends data
-to a replay buffer. Defaults to None.
+adder (Optional[adders.ReverbParallelAdder], optional): adder which sends
+data to a replay buffer. Defaults to None.
variable_client (Optional[tf2_variable_utils.VariableClient], optional):
client to copy weights from the trainer. Defaults to None.
communication_module (BaseCommunicationModule): module for enabling