Feat: sebulba ff_ippo #1088

Merged: 150 commits, Nov 13, 2024
Changes from all commits
Commits (150)
adc2114
feat: gym wrapper
Louay-Ben-nessir Jun 10, 2024
ce86d09
chore : pre-commit hooks
Louay-Ben-nessir Jun 10, 2024
d5edf45
fix: merged the observations and action mask
Louay-Ben-nessir Jun 14, 2024
f891be5
fix: Create the gym wrappers directly
Louay-Ben-nessir Jun 14, 2024
15f4867
chore: pre-commit
Louay-Ben-nessir Jun 14, 2024
82ea827
fix: fixed the async env creation
Louay-Ben-nessir Jun 14, 2024
4e94df5
fix: gymV26 compatibility wrapper
Louay-Ben-nessir Jun 14, 2024
8a86be9
fix: various minor fixes
Louay-Ben-nessir Jun 15, 2024
1da5c15
fix: handling rware reset function
Louay-Ben-nessir Jun 15, 2024
4466044
feat: async env wrapper, changed the gym wrapper to rware wrapper
Louay-Ben-nessir Jun 16, 2024
24d8aae
fix: fixed the async env wrapper
Louay-Ben-nessir Jun 16, 2024
a6deae2
fix: info only contains the action_mask and reformatted (n_agents, n_e…
Louay-Ben-nessir Jun 18, 2024
1475bd0
chore: removed async gym wrapper
Louay-Ben-nessir Jun 22, 2024
9fce9c6
feat: gym metric tracker wrapper
Louay-Ben-nessir Jun 22, 2024
055a326
feat: init sebulba ippo
Louay-Ben-nessir Jun 10, 2024
a435a0a
feat: initial learner / training loop
Louay-Ben-nessir Jun 13, 2024
7e80d7b
fix: changes the env creation
Louay-Ben-nessir Jun 14, 2024
b961336
fix: fixed function calls
Louay-Ben-nessir Jun 15, 2024
502730d
fix: fixed the training and added training logger
Louay-Ben-nessir Jun 22, 2024
1985729
fix: changed the anakin ppo type import
Louay-Ben-nessir Jun 22, 2024
89ed246
feat: full sebulba functional
Louay-Ben-nessir Jun 25, 2024
7f43a33
fix: logging and added LBF
Louay-Ben-nessir Jul 2, 2024
8a87258
fix: batch size calc for multiple devices
Louay-Ben-nessir Jul 4, 2024
7f0acd9
fix: num_updates and code refactoring
Louay-Ben-nessir Jul 5, 2024
3e352cf
chore : code cleanup + comments + added checkpoint save
Louay-Ben-nessir Jul 8, 2024
bcdaa38
feat: mappo + removed sebulba specific types and made the rware wra…
Louay-Ben-nessir Jul 8, 2024
7044fbe
fix: removed the sebulba specific types
Louay-Ben-nessir Jul 8, 2024
9433f2e
feat: ff_mappo and rec_ippo in sebulba
Louay-Ben-nessir Jul 10, 2024
627215d
fix: removed the lbf import/wrapper
Louay-Ben-nessir Jul 10, 2024
c3b405d
chore: clean up & updated the code to match the sebulba-ff-ippo branch
Louay-Ben-nessir Jul 10, 2024
e40c5d4
chore : pre-commits and some comments
Louay-Ben-nessir Jul 10, 2024
4b17c15
chore: removed unused config file
Louay-Ben-nessir Jul 10, 2024
9ec6b16
feat: sebulba ff_ippo
Louay-Ben-nessir Jul 10, 2024
e5dd71b
chore: pre-commits
Louay-Ben-nessir Jul 10, 2024
af24082
fix: fix the num_updates_in_eval in the last eval
Louay-Ben-nessir Jul 13, 2024
32ac389
fix: fixed the num evals calcs
Louay-Ben-nessir Jul 16, 2024
45ca587
chore : pre commit
Louay-Ben-nessir Jul 16, 2024
d694498
chore: created the anakin and sebulba folders
Louay-Ben-nessir Jul 16, 2024
cb8111f
fix: imports and config paths in systems
Louay-Ben-nessir Jul 16, 2024
d842375
fix: allow for reproducibility
Louay-Ben-nessir Jul 16, 2024
0a1ffd0
chore: pre-commits
Louay-Ben-nessir Jul 16, 2024
f1adc31
chore: pre-commits
Louay-Ben-nessir Jul 16, 2024
3850591
feat: LBF and reproducibility
Louay-Ben-nessir Jul 16, 2024
0a2ee08
feat : lbf
Louay-Ben-nessir Jul 16, 2024
dc92065
fix: sync neptune logging for sebulba to avoid stalling
Louay-Ben-nessir Jul 17, 2024
133a250
fix: added missing lbf import
Louay-Ben-nessir Jul 17, 2024
b938c83
fix: seeds need to be python arrays not np arrays
Louay-Ben-nessir Jul 17, 2024
a368476
fix: config and imports for anakin q_learning and sac
Louay-Ben-nessir Jul 17, 2024
32433ff
chore: arch_name for anakin
Louay-Ben-nessir Jul 17, 2024
a68c8e9
fix: sum the rewards when using a shared reward
Louay-Ben-nessir Jul 17, 2024
8cee7ac
fix: configs revamp
Louay-Ben-nessir Jul 17, 2024
e199f3a
chore: pre-commits
Louay-Ben-nessir Jul 17, 2024
2b71d3b
fix: more config changes
Louay-Ben-nessir Jul 17, 2024
e87ad28
chore: pre-commits
Louay-Ben-nessir Jul 17, 2024
2b587c0
chore: renamed arch_name to architecture_name
Louay-Ben-nessir Jul 18, 2024
5ad4d2f
chore: config files rename
Louay-Ben-nessir Jul 18, 2024
432071e
fix: moved from gym to gymnasium
Louay-Ben-nessir Jul 18, 2024
77e6e12
feat: generic gym wrapper
Louay-Ben-nessir Jul 18, 2024
43511fd
feat: using gymnasium async worker
Louay-Ben-nessir Jul 18, 2024
eaf9a1c
chore: pre-commits and annotations
Louay-Ben-nessir Jul 18, 2024
16c0ac3
fix: config file fixes
Louay-Ben-nessir Jul 18, 2024
18b928d
fix: rware import
Louay-Ben-nessir Jul 18, 2024
19a7765
fix: better agent ids wrapper?
Louay-Ben-nessir Jul 18, 2024
c4a05d6
chore: bunch of minor changes
Louay-Ben-nessir Jul 18, 2024
5595818
chore : annotation
Louay-Ben-nessir Jul 18, 2024
29b1303
chore: comments
Louay-Ben-nessir Jul 19, 2024
669dfbd
feat: restructured the folders
Louay-Ben-nessir Jul 19, 2024
d1f8364
update the gym wrappers
Louay-Ben-nessir Jul 19, 2024
dc641c6
folder re-structuring
Louay-Ben-nessir Jul 19, 2024
0881d2f
fix: removed deprecated jax call
Louay-Ben-nessir Jul 19, 2024
b60cefe
fix: env wrappers fix
Louay-Ben-nessir Jul 19, 2024
21aafbf
fix: config changes
Louay-Ben-nessir Jul 19, 2024
e09fd60
chore: pre-commits
Louay-Ben-nessir Jul 19, 2024
2a6452d
fix: config file fixes
Louay-Ben-nessir Jul 19, 2024
e2f36f9
fix: LBF import
Louay-Ben-nessir Jul 19, 2024
29396c9
fix: Async worker auto-resetting
Louay-Ben-nessir Jul 19, 2024
6de0b1e
chore: minor changes
Louay-Ben-nessir Jul 19, 2024
7584ce5
fixed: annotations and add agent id spaces
Louay-Ben-nessir Jul 22, 2024
e638e9f
fix: fixed the logging deadlock for sebulba
Louay-Ben-nessir Jul 22, 2024
81b0a89
Merge pull request #4 from Louay-Ben-nessir/feat-sebulba-gym-wrapper
Louay-Ben-nessir Jul 22, 2024
0860518
Merge pull request #1090 from Louay-Ben-nessir/chore--anakin-and-sebu…
sash-a Jul 23, 2024
4c0acdc
Merge remote-tracking branch 'upstream/develop' into chore--sebulba-a…
Louay-Ben-nessir Jul 23, 2024
a85aa2f
chore: pre-commits
Louay-Ben-nessir Jul 23, 2024
e504b47
pre-commit
Louay-Ben-nessir Jul 23, 2024
6a1fad4
Merge pull request #1094 from Louay-Ben-nessir/chore--sebulba-arch-up…
OmaymaMahjoub Jul 23, 2024
0cae539
Merge remote-tracking branch 'upstream/feat/sebulba_arch' into seb-ff…
Louay-Ben-nessir Jul 23, 2024
a19056b
feat: major code restructure, non-blocking evaluators
Louay-Ben-nessir Jul 25, 2024
fc80b91
chore: code cleanup and sps calcs and learner threads
Louay-Ben-nessir Jul 26, 2024
18ec08f
feat: shared time steps checker
Louay-Ben-nessir Jul 29, 2024
38e7229
chore: removed unused eval type
Louay-Ben-nessir Jul 29, 2024
5a5e542
chore: config file changes
Louay-Ben-nessir Jul 29, 2024
dcff2a1
fix: fixed stalling at the end of training
Louay-Ben-nessir Jul 29, 2024
d926c54
chore: code cleanup
Louay-Ben-nessir Jul 29, 2024
7e4698a
chore : various changes
Louay-Ben-nessir Jul 29, 2024
6dac8c3
fix: prevent the pipeline from stalling and a lot of cleanup
Louay-Ben-nessir Jul 30, 2024
23b582c
chore: better error messages
Louay-Ben-nessir Jul 30, 2024
c71dad8
fix: changed the timestep discount
Louay-Ben-nessir Jul 30, 2024
bfea3aa
chore: very nitpicky clean ups
sash-a Jul 30, 2024
de92f5a
feat: pass timestep instead of obs and done and fix potential race co…
sash-a Jul 30, 2024
1465133
fix: deadlock in pipeline
sash-a Jul 30, 2024
6689c49
fix: wasting samples
Louay-Ben-nessir Aug 11, 2024
c506da3
chore: loss unpacking
Louay-Ben-nessir Aug 11, 2024
b24ac34
fix: updated to work with the latest gymnasium
Louay-Ben-nessir Oct 10, 2024
1dfb241
fix: jumanji
Louay-Ben-nessir Oct 10, 2024
fd8aece
fix: removed deprecated gymnasium import
Louay-Ben-nessir Oct 10, 2024
ae53415
feat: minor refactor to sebulba utils
sash-a Oct 10, 2024
724d2dc
chore: a few minor changes to code style
sash-a Oct 10, 2024
fa8a996
Merge branch 'develop' into feat/sebulba_arch
sash-a Oct 11, 2024
0a36fdf
Merge branch 'feat/sebulba_arch' into seb-ff-ippo-only
sash-a Oct 11, 2024
47b8e03
fix: update configs to match latest mava
sash-a Oct 11, 2024
8be8037
fix: reshape with multiple learners and system name
sash-a Oct 11, 2024
4748636
fix: safer pipeline.clear()
sash-a Oct 11, 2024
5593bde
feat: avoid unnecessary host-device transfers
sash-a Oct 14, 2024
133ea1a
chore: remove some more device transfers
sash-a Oct 14, 2024
9260e9b
chore: better graceful exit
sash-a Oct 14, 2024
d61dcfb
fix: create envs in main thread to avoid deadlocks
sash-a Oct 15, 2024
105d796
chore: use original rware and lbf
Louay-Ben-nessir Oct 15, 2024
f292bf3
fix: possible off by one fix
sash-a Oct 16, 2024
d42d732
fix: change to using gym.make to create envs and fix StepType
sash-a Oct 16, 2024
d4359c1
feat: learner env accumulation
Louay-Ben-nessir Oct 17, 2024
7c78478
feat: jit evaluation on cpu
sash-a Oct 17, 2024
aa49c6f
Merge branch 'seb-ff-ippo-only' of github.com:Louay-Ben-nessir/Mava i…
sash-a Oct 17, 2024
c252ffe
fix: timestep calculation with accumulation
Louay-Ben-nessir Oct 17, 2024
fd7a025
feat: shardmap almost working
sash-a Oct 17, 2024
4013a22
feat: shard_map working
sash-a Oct 18, 2024
0e559d9
fix: key use in actor loss
sash-a Oct 19, 2024
0a6bd49
fix: align gym config with other configs
sash-a Oct 19, 2024
641a548
feat: better env creation and safer sharding
sash-a Oct 19, 2024
c0c88bc
chore: minor env typing fixes
sash-a Oct 19, 2024
354159a
Merge branch 'develop' into seb-ff-ippo-only
sash-a Oct 19, 2024
6b2d01c
fix: start actors simultaneously to avoid deadlocks
Louay-Ben-nessir Oct 21, 2024
a13ab65
feat: support for smac
Louay-Ben-nessir Oct 23, 2024
bc55375
chore: pre-commits
Louay-Ben-nessir Oct 23, 2024
c6d460f
fix: random segfault
Louay-Ben-nessir Oct 27, 2024
659a837
fix: give each learner a unique random key
Louay-Ben-nessir Nov 4, 2024
7deb75b
chore: bunch of minor changes and fixes
Louay-Ben-nessir Nov 5, 2024
c024b71
chore: removed learner accumulation
Louay-Ben-nessir Nov 6, 2024
db378b9
fix: Metric tracking more aligned with Jumanji
Louay-Ben-nessir Nov 7, 2024
3d3cec8
fix: removed axis swaping & wrapper rename
Louay-Ben-nessir Nov 8, 2024
a7665f9
chore: pre-commits
Louay-Ben-nessir Nov 8, 2024
0c4e83b
chore: bunch of minor changes
Louay-Ben-nessir Nov 8, 2024
245aecc
fix: smaclite win rate tracking
Louay-Ben-nessir Nov 12, 2024
649b93b
Squashed commit of the following:
Louay-Ben-nessir Nov 12, 2024
a723392
fix: sebulba-compatible get_action_head
Louay-Ben-nessir Nov 13, 2024
a75b2a2
chore: pre-commits
Louay-Ben-nessir Nov 13, 2024
3fce221
fix: action_head parameters for all systems
Louay-Ben-nessir Nov 13, 2024
b6712e2
Merge branch 'develop' into seb-ff-ippo-only
Louay-Ben-nessir Nov 13, 2024
acf1830
chore: pre-commits
Louay-Ben-nessir Nov 13, 2024
7da5968
fix: rec_qmix import
Louay-Ben-nessir Nov 13, 2024
097df80
Merge branch 'develop' into seb-ff-ippo-only
sash-a Nov 13, 2024
2 changes: 1 addition & 1 deletion mava/advanced_usage/ff_ippo_store_experience.py
@@ -361,7 +361,7 @@ def learner_setup(

# Define network and optimiser.
actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
action_head, _ = get_action_head(env)
action_head, _ = get_action_head(env.action_spec())
actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim)
critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)

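This change, repeated across the anakin systems further down, passes the environment's action spec to get_action_head instead of the environment itself. A minimal sketch of the idea, assuming a helper that picks a head configuration from the spec type; the hydra target paths and names below are illustrative, not Mava's actual implementation.

from typing import Any, Dict, Tuple

from jumanji import specs


def get_action_head_sketch(action_spec: specs.Spec) -> Tuple[Dict[str, Any], str]:
    """Illustrative only: choose an action head from the action spec alone,
    so callers no longer need to hand over the whole environment object."""
    if isinstance(action_spec, (specs.DiscreteArray, specs.MultiDiscreteArray)):
        # Hypothetical hydra target; the real head lives in Mava's network modules.
        return {"_target_": "mava.networks.heads.DiscreteActionHead"}, "discrete"
    return {"_target_": "mava.networks.heads.ContinuousActionHead"}, "continuous"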
1 change: 1 addition & 0 deletions mava/configs/arch/anakin.yaml
@@ -1,4 +1,5 @@
# --- Anakin config ---
architecture_name: anakin

# --- Training ---
num_envs: 16 # Number of vectorised environments per device.
25 changes: 25 additions & 0 deletions mava/configs/arch/sebulba.yaml
@@ -0,0 +1,25 @@
# --- Sebulba config ---
architecture_name: sebulba

# --- Training ---
num_envs: 32 # number of environments per thread.

# --- Evaluation ---
evaluation_greedy: False # Evaluate the policy greedily. If True the policy will select
# an action which corresponds to the greatest logit. If false, the policy will sample
# from the logits.
num_eval_episodes: 32 # Number of episodes to evaluate per evaluation.
num_evaluation: 100 # Number of evenly spaced evaluations to perform during training.
num_absolute_metric_eval_episodes: 320 # Number of episodes to evaluate the absolute metric (the final evaluation).
absolute_metric: True # Whether the absolute metric should be computed. For more details
# on the absolute metric please see: https://arxiv.org/abs/2209.10485

# --- Sebulba devices config ---
n_threads_per_executor: 2 # num of different threads/env batches per actor
actor_device_ids: [0] # ids of actor devices
learner_device_ids: [0] # ids of learner devices
rollout_queue_size : 5
# The size of the pipeline queue determines the extent of off-policy training allowed. A larger value permits more off-policy training.
# Too large a value with too many actors leads to updates being wasted on stale episodes.
# Too small a value and the utility of having multiple actors is lost.
# A value of 1 with a single actor leads to almost strictly on-policy training.
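The rollout_queue_size comment above describes a bounded producer/consumer pipeline between actor and learner threads. Below is a self-contained sketch of that idea using Python's queue module; the rollouts and updates are stand-ins, and Mava's actual pipeline class in its sebulba utilities is more involved.

import queue
import threading
import time

# maxsize plays the role of rollout_queue_size: a larger value lets actors run
# further ahead of the learner (more off-policy data); maxsize=1 with a single
# actor is close to strictly on-policy training.
rollout_queue: "queue.Queue[int]" = queue.Queue(maxsize=5)


def actor(stop: threading.Event) -> None:
    step = 0
    while not stop.is_set():
        rollout_queue.put(step)  # blocks while the queue is full, throttling this actor
        step += 1


def learner(n_updates: int) -> None:
    for _ in range(n_updates):
        rollout = rollout_queue.get()  # blocks until an actor has produced a rollout
        del rollout  # a real learner would run a parameter update on this batch
        time.sleep(0.01)


stop_event = threading.Event()
threading.Thread(target=actor, args=(stop_event,), daemon=True).start()
learner(n_updates=100)
stop_event.set()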
11 changes: 11 additions & 0 deletions mava/configs/default/ff_ippo_sebulba.yaml
@@ -0,0 +1,11 @@
defaults:
- logger: logger
- arch: sebulba
- system: ppo/ff_ippo
- network: mlp # [mlp, continuous_mlp, cnn]
- env: lbf_gym # [rware_gym, lbf_gym, smaclite_gym]
- _self_

hydra:
searchpath:
- file://mava/configs
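A defaults list like this composes with Hydra from the command line or via its Python API. The snippet below is a hypothetical sketch using hydra.compose; the config module path is assumed from the searchpath above, and real runs would go through the system's entry point instead.

from hydra import compose, initialize_config_module

# Hypothetical: compose the sebulba ff_ippo defaults and override the env group,
# mirroring a command-line override such as `env=rware_gym arch.num_envs=16`.
with initialize_config_module(config_module="mava.configs", version_base=None):
    cfg = compose(
        config_name="default/ff_ippo_sebulba",
        overrides=["env=rware_gym", "arch.num_envs=16"],
    )

print(cfg.arch.architecture_name)  # sebulba
print(cfg.env.scenario.task_name)  # rware-tiny-2ag-v2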
25 changes: 25 additions & 0 deletions mava/configs/env/lbf_gym.yaml
@@ -0,0 +1,25 @@
# ---Environment Configs---
defaults:
- _self_

env_name: LevelBasedForaging # Used for logging purposes.
scenario:
name: lbforaging
task_name: Foraging-8x8-2p-1f-v3

# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
eval_metric: episode_return

# Whether the environment observations encode implicit agent IDs. If True, the AgentID wrapper is not used.
# This should not be changed.
implicit_agent_id: False
# Whether or not to log the winrate of this environment. This should not be changed as not all
# environments have a winrate metric.
log_win_rate: False

# Whether or not to sum the returned rewards over all of the agents.
use_shared_rewards: True

kwargs:
max_episode_steps: 100
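As a rough illustration of how the scenario block maps onto an environment id, a single Level-Based Foraging instance could be created directly with Gymnasium, assuming lbforaging's v3 release registers its tasks on import; Mava's vectorisation and wrappers are omitted here.

import gymnasium as gym
import lbforaging  # noqa: F401  # importing registers the Foraging-* task ids

# task_name and max_episode_steps mirror the config above.
env = gym.make("Foraging-8x8-2p-1f-v3", max_episode_steps=100)
obs, info = env.reset(seed=0)
print(env.action_space)  # a tuple of per-agent discrete action spaces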
25 changes: 25 additions & 0 deletions mava/configs/env/rware_gym.yaml
@@ -0,0 +1,25 @@
# ---Environment Configs---
defaults:
- _self_

env_name: RobotWarehouse # Used for logging purposes.
scenario:
name: rware
task_name: rware-tiny-2ag-v2 # [rware-tiny-2ag-v2, rware-tiny-4ag-v2, rware-tiny-4ag-easy-v2, rware-small-4ag-v2]

# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
eval_metric: episode_return

# Whether the environment observations encode implicit agent IDs. If True, the AgentID wrapper is not used.
# This should not be changed.
implicit_agent_id: False
# Whether or not to log the winrate of this environment. This should not be changed as not all
# environments have a winrate metric.
log_win_rate: False

# Whether or not to sum the returned rewards over all of the agents.
use_shared_rewards: True

kwargs:
max_episode_steps: 500
25 changes: 25 additions & 0 deletions mava/configs/env/smaclite_gym.yaml
@@ -0,0 +1,25 @@
# ---Environment Configs---
defaults:
- _self_

env_name: SMACLite # Used for logging purposes.
scenario:
name: smaclite
task_name: smaclite/2s3z-v0 # smaclite/ + ['10m_vs_11m-v0', '27m_vs_30m-v0', '3s5z_vs_3s6z-v0', '2s3z-v0', '3s5z-v0', '2c_vs_64zg-v0', '2s_vs_1sc-v0', '3s_vs_5z-v0']

# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
eval_metric: episode_return

# Whether the environment observations encode implicit agent IDs. If True, the AgentID wrapper is not used.
# This should not be changed.
implicit_agent_id: False
# Whether or not to log the winrate of this environment. This should not be changed as not all
# environments have a winrate metric.
log_win_rate: True

# Whether or not to sum the returned rewards over all of the agents.
use_shared_rewards: True

kwargs:
max_episode_steps: 500
115 changes: 115 additions & 0 deletions mava/evaluator.py
@@ -19,6 +19,7 @@

import jax
import jax.numpy as jnp
import numpy as np
from chex import Array, PRNGKey
from flax.core.frozen_dict import FrozenDict
from jax import tree
@@ -36,6 +37,7 @@
RecActorApply,
State,
)
from mava.wrappers.gym import GymToJumanji

# Optional extras that are passed out of the actor and then into the actor in the next step
ActorState: TypeAlias = Dict[str, Any]
@@ -207,3 +209,116 @@ def eval_act_fn(
return action.squeeze(0), {_hidden_state: hidden_state}

return eval_act_fn


def get_sebulba_eval_fn(
env_maker: Callable[[int, int], GymToJumanji],
act_fn: EvalActFn,
config: DictConfig,
np_rng: np.random.Generator,
absolute_metric: bool,
) -> Tuple[EvalFn, Any]:
"""Creates a function that can be used to evaluate agents on a given environment.

Args:
----
env_maker: A function to create the environment instances.
act_fn: A function that takes in params, timestep, key and optionally a state
and returns actions and optionally a state (see `EvalActFn`).
config: The system config.
np_rng: Random number generator for seeding environment.
absolute_metric: Whether or not this evaluator calculates the absolute_metric.
This determines how many evaluation episodes it does.
"""
n_devices = jax.device_count()
eval_episodes = (
config.arch.num_absolute_metric_eval_episodes
if absolute_metric
else config.arch.num_eval_episodes
)

n_parallel_envs = min(eval_episodes, config.arch.num_envs)
episode_loops = math.ceil(eval_episodes / n_parallel_envs)
env = env_maker(config, n_parallel_envs)

act_fn = jax.jit(
act_fn, device=jax.local_devices()[config.arch.actor_device_ids[0]]
) # Evaluate using the first actor device

# Warn if the number of eval episodes is not divisible by the number of parallel envs.
if eval_episodes % n_parallel_envs != 0:
warnings.warn(
f"Number of evaluation episodes ({eval_episodes}) is not divisible by `num_envs` * "
f"`num_devices` ({n_parallel_envs} * {n_devices}). Some extra evaluations will be "
f"executed. New number of evaluation episodes = {episode_loops * n_parallel_envs}",
stacklevel=2,
)

def eval_fn(params: FrozenDict, key: PRNGKey, init_act_state: ActorState) -> Metrics:
"""Evaluates the given params on an environment and returns relevent metrics.

Metrics are collected by the `RecordEpisodeMetrics` wrapper: episode return and length,
also win rate for environments that support it.

Returns: Dict[str, Array] - dictionary of metric name to metric values for each episode.
"""

def _episode(key: PRNGKey) -> Tuple[PRNGKey, Metrics]:
"""Simulates `num_envs` episodes."""

# Generate a list of random seeds within the 32-bit integer range, using a seeded RNG.
seeds = np_rng.integers(np.iinfo(np.int32).max, size=n_parallel_envs).tolist()
ts = env.reset(seed=seeds)

timesteps_array = [ts]

actor_state = init_act_state
finished_eps = ts.last()

while not finished_eps.all():
key, act_key = jax.random.split(key)
action, actor_state = act_fn(params, ts, act_key, actor_state)
cpu_action = jax.device_get(action)
ts = env.step(cpu_action)
timesteps_array.append(ts)

finished_eps = np.logical_or(finished_eps, ts.last())

timesteps = jax.tree.map(lambda *x: np.stack(x), *timesteps_array)

metrics = timesteps.extras["episode_metrics"]
if config.env.log_win_rate:
metrics["won_episode"] = timesteps.extras["won_episode"]

# Find the first instance of done to get the metrics at that timestep; we don't
# care about subsequent steps because we only want the results from the first episode.
done_idx = np.argmax(timesteps.last(), axis=0)
metrics = jax.tree_map(lambda m: m[done_idx, np.arange(n_parallel_envs)], metrics)
del metrics["is_terminal_step"]  # unneeded for logging

return key, metrics

# This loop is important because we don't want too many parallel envs.
# So in evaluation we have num_envs parallel envs and loop enough times
# so that we do at least `eval_episodes` number of episodes.
metrics_array = []
for _ in range(episode_loops):
key, metric = _episode(key)
metrics_array.append(metric)

# flatten metrics
metrics: Metrics = jax.tree_map(lambda *x: np.array(x).reshape(-1), *metrics_array)
return metrics

def timed_eval_fn(params: FrozenDict, key: PRNGKey, init_act_state: ActorState) -> Metrics:
"""Wrapper around eval function to time it and add in steps per second metric."""
start_time = time.time()

metrics = eval_fn(params, key, init_act_state)

end_time = time.time()
total_timesteps = jnp.sum(metrics["episode_length"])
metrics["steps_per_second"] = total_timesteps / (end_time - start_time)
return metrics

return timed_eval_fn, env
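The per-episode metric extraction in eval_fn relies on np.argmax returning the index of the first True along the time axis, which is then paired with each environment's column. A small self-contained illustration of that indexing pattern, with made-up values of shape [time, env]:

import numpy as np

# dones[t, e] is True from the step at which env e first finishes an episode.
dones = np.array(
    [
        [False, False, True],
        [True, False, True],
        [True, True, True],
    ]
)
episode_return = np.array(
    [
        [0.0, 0.0, 3.0],
        [5.0, 0.0, 9.0],
        [5.0, 2.0, 9.0],
    ]
)

first_done = np.argmax(dones, axis=0)  # index of the first True per env -> [1, 2, 0]
final_returns = episode_return[first_done, np.arange(dones.shape[1])]
print(final_returns)  # [5. 2. 3.]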
13 changes: 0 additions & 13 deletions mava/systems/__init__.py

This file was deleted.

4 changes: 2 additions & 2 deletions mava/systems/mat/anakin/mat.py
@@ -41,14 +41,14 @@
)
from mava.utils import make_env as environments
from mava.utils.checkpointing import Checkpointer
from mava.utils.config import check_total_timesteps
from mava.utils.jax_utils import (
merge_leading_dims,
unreplicate_batch_dim,
unreplicate_n_dims,
)
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.network_utils import get_action_head
from mava.utils.total_timestep_checker import check_total_timesteps
from mava.utils.training import make_learning_rate
from mava.wrappers.episode_metrics import get_final_step_metrics

@@ -355,7 +355,7 @@ def learner_setup(
init_x = env.observation_spec().generate_value()
init_x = tree.map(lambda x: x[None, ...], init_x)

_, action_space_type = get_action_head(env)
_, action_space_type = get_action_head(env.action_spec())

if action_space_type == "discrete":
init_action = jnp.zeros((1, config.system.num_agents), dtype=jnp.int32)
4 changes: 2 additions & 2 deletions mava/systems/ppo/anakin/ff_ippo.py
@@ -36,14 +36,14 @@
from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn, MarlEnv
from mava.utils import make_env as environments
from mava.utils.checkpointing import Checkpointer
from mava.utils.config import check_total_timesteps
from mava.utils.jax_utils import (
merge_leading_dims,
unreplicate_batch_dim,
unreplicate_n_dims,
)
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.network_utils import get_action_head
from mava.utils.total_timestep_checker import check_total_timesteps
from mava.utils.training import make_learning_rate
from mava.wrappers.episode_metrics import get_final_step_metrics

@@ -362,7 +362,7 @@ def learner_setup(

# Define network and optimiser.
actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
action_head, _ = get_action_head(env)
action_head, _ = get_action_head(env.action_spec())
actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim)
critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)

4 changes: 2 additions & 2 deletions mava/systems/ppo/anakin/ff_mappo.py
@@ -35,10 +35,10 @@
from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn, MarlEnv
from mava.utils import make_env as environments
from mava.utils.checkpointing import Checkpointer
from mava.utils.config import check_total_timesteps
from mava.utils.jax_utils import merge_leading_dims, unreplicate_batch_dim, unreplicate_n_dims
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.network_utils import get_action_head
from mava.utils.total_timestep_checker import check_total_timesteps
from mava.utils.training import make_learning_rate
from mava.wrappers.episode_metrics import get_final_step_metrics

@@ -346,7 +346,7 @@ def learner_setup(

# Define network and optimiser.
actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
action_head, _ = get_action_head(env)
action_head, _ = get_action_head(env.action_spec())
actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim)
critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)

4 changes: 2 additions & 2 deletions mava/systems/ppo/anakin/rec_ippo.py
@@ -49,10 +49,10 @@
)
from mava.utils import make_env as environments
from mava.utils.checkpointing import Checkpointer
from mava.utils.config import check_total_timesteps
from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.network_utils import get_action_head
from mava.utils.total_timestep_checker import check_total_timesteps
from mava.utils.training import make_learning_rate
from mava.wrappers.episode_metrics import get_final_step_metrics

@@ -457,7 +457,7 @@ def learner_setup(
# Define network and optimisers.
actor_pre_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
actor_post_torso = hydra.utils.instantiate(config.network.actor_network.post_torso)
action_head, _ = get_action_head(env)
action_head, _ = get_action_head(env.action_spec())
actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim)
critic_pre_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)
critic_post_torso = hydra.utils.instantiate(config.network.critic_network.post_torso)
4 changes: 2 additions & 2 deletions mava/systems/ppo/anakin/rec_mappo.py
@@ -49,10 +49,10 @@
)
from mava.utils import make_env as environments
from mava.utils.checkpointing import Checkpointer
from mava.utils.config import check_total_timesteps
from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.network_utils import get_action_head
from mava.utils.total_timestep_checker import check_total_timesteps
from mava.utils.training import make_learning_rate
from mava.wrappers.episode_metrics import get_final_step_metrics

@@ -452,7 +452,7 @@ def learner_setup(
# Define network and optimiser.
actor_pre_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
actor_post_torso = hydra.utils.instantiate(config.network.actor_network.post_torso)
action_head, _ = get_action_head(env)
action_head, _ = get_action_head(env.action_spec())
actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim)
critic_pre_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)
critic_post_torso = hydra.utils.instantiate(config.network.critic_network.post_torso)