-
Notifications
You must be signed in to change notification settings - Fork 94
/
Copy pathtransf_qmix.py
1062 lines (912 loc) · 44 KB
/
transf_qmix.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
TODO: refactor this code to use the new qlearning scripts, see qmix_rnn.py
End-to-End JAX Implementation of TransfQMix.
The implementation closely follows the original one https://github.com/mttga/pymarl_transformers with some additional features:
- The embeddings can be normalized with batch norm in order to stabilize the self-attention gradients.
- Optionally performs $n$ training updates of the network at each update step.
Currently supports only MPE_spread and SMAX. To use the transformers in your own environment, you need
to reshape the observations and states into matrices (see jaxmarl.wrappers.transformers).
"""
import os
import copy
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial
from typing import NamedTuple, Dict, Union
import chex
import optax
import flax.linen as nn
from flax.linen.initializers import constant, orthogonal
from flax.training.train_state import TrainState
import flashbax as fbx
import wandb
import hydra
from omegaconf import OmegaConf
from safetensors.flax import save_file
from flax.traverse_util import flatten_dict
from jaxmarl import make
from jaxmarl.wrappers.baselines import LogWrapper, MPELogWrapper, SMAXLogWrapper
from jaxmarl.wrappers.transformers import TransformersCTRolloutManager
from jaxmarl.environments.smax import map_name_to_scenario
from typing import Any
class EncoderBlock(nn.Module):
    """Post-norm transformer encoder layer: multi-head self-attention followed by a two-layer MLP.

    Attributes:
        hidden_dim: embedding size. Input and output sizes are equal because both
            sub-layers are wrapped in residual connections.
        num_heads: number of attention heads.
        dim_feedforward: hidden width of the MLP.
        init_scale: kept for a uniform constructor signature; not used by this layer
            (all kernels use Xavier-uniform initialization).
        use_fast_attention: if True, the attention function is built by
            ``utils.fast_attention.make_fast_generalized_attention``. Masks are not
            supported in this mode and are ignored.
        dropout_prob: dropout rate of the attention layer and of the ``dropout`` layer.
    """
    hidden_dim : int  # Input dimension is needed here since it is equal to the output dimension (residual connection)
    num_heads : int
    dim_feedforward : int
    init_scale: float
    use_fast_attention: bool
    dropout_prob : float = 0.

    def setup(self):
        # Attention layer: the fast variant only swaps in a different attention function.
        attention_kwargs = dict(
            num_heads=self.num_heads,
            dropout_rate=self.dropout_prob,
            kernel_init=nn.initializers.xavier_uniform(),
            use_bias=False,
        )
        if self.use_fast_attention:
            # imported lazily so utils.fast_attention is only required when it is enabled
            from utils.fast_attention import make_fast_generalized_attention
            attention_kwargs["attention_fn"] = make_fast_generalized_attention(
                self.hidden_dim // self.num_heads,
                renormalize_attention=True,
                nb_features=self.hidden_dim,
                unidirectional=False,
            )
        self.self_attn = nn.MultiHeadDotProductAttention(**attention_kwargs)

        # Two-layer MLP
        self.linear = [
            nn.Dense(self.dim_feedforward, kernel_init=nn.initializers.xavier_uniform(), bias_init=constant(0.0)),
            nn.Dense(self.hidden_dim, kernel_init=nn.initializers.xavier_uniform(), bias_init=constant(0.0)),
        ]

        # Layers to apply in between the main layers
        self.norm1 = nn.LayerNorm()
        self.norm2 = nn.LayerNorm()
        self.dropout = nn.Dropout(self.dropout_prob)

    def __call__(self, x, mask=None, deterministic=True):
        """Apply the layer to token embeddings ``x`` of shape ``(..., seq_len, hidden_dim)``.

        Args:
            x: token embeddings.
            mask: optional ``(..., seq_len)`` array, truthy for tokens that may take part
                in attention (Flax convention: False positions are masked out). It is
                expanded here to a pairwise, per-head mask. Ignored with fast attention.
            deterministic: disables dropout when True.

        Returns:
            Array with the same shape as ``x``.
        """
        if self.use_fast_attention:
            # masking is not compatible with fast self attention: never forward a mask to it
            mask = None
        elif mask is not None:
            # (..., n) -> pairwise (..., 1, n, n), then one copy per attention head
            mask = jnp.repeat(nn.make_attention_mask(mask, mask), self.num_heads, axis=-3)

        # Attention part
        attended = self.self_attn(inputs_q=x, inputs_kv=x, mask=mask, deterministic=deterministic)
        x = self.norm1(attended + x)
        # NOTE(review): this adds dropout(x) to x itself; when dropout is inactive it simply
        # doubles x. A standard post-norm layer would use norm1(x + dropout(attended)).
        # Kept unchanged so previously trained checkpoints reproduce exactly.
        x = x + self.dropout(x, deterministic=deterministic)

        # MLP part
        feedforward = self.linear[0](x)
        feedforward = nn.relu(feedforward)
        feedforward = self.linear[1](feedforward)
        x = self.norm2(feedforward + x)
        # NOTE(review): same self-residual dropout pattern as above.
        x = x + self.dropout(x, deterministic=deterministic)
        return x
class Embedder(nn.Module):
    """Project entity feature vectors to ``hidden_dim`` embeddings.

    Pipeline: [BatchNorm on the raw inputs] -> Dense -> [ReLU] -> BatchNorm.
    The trailing BatchNorm is always present, so the module always owns a
    ``batch_stats`` collection.

    Call args:
        x: input features; the last axis is projected.
        train: when True the BatchNorm layers normalize with batch statistics,
            otherwise with their running averages.
    """
    hidden_dim: int
    init_scale: float
    scale_inputs: bool = True
    activation: bool = False

    @nn.compact
    def __call__(self, x, train:bool):
        use_running_avg = not train
        # optional input normalization (created before the Dense so module names match)
        h = nn.BatchNorm(use_running_average=use_running_avg)(x) if self.scale_inputs else x
        h = nn.Dense(
            self.hidden_dim,
            kernel_init=orthogonal(self.init_scale),
            bias_init=constant(0.0),
        )(h)
        if self.activation:
            h = nn.relu(h)
        return nn.BatchNorm(use_running_average=use_running_avg)(h)
class ScannedTransformer(nn.Module):
    """A stack of EncoderBlocks scanned over the leading (time) axis with a recurrent token.

    At every step the carried hidden-state token is prepended (position 0) to the
    step's embeddings and the sequence goes through all encoder layers; the first
    output token becomes the next carry. Wherever ``done`` is True the carry is
    replaced by zeros before it is used.

    The per-step output is the full token sequence when ``return_embeddings`` is
    True, otherwise just the new hidden-state token.
    """
    hidden_dim: int
    init_scale: float
    transf_num_layers: int
    transf_num_heads: int
    transf_dim_feedforward: int
    transf_dropout_prob: float = 0
    deterministic: bool = True
    return_embeddings: bool = False
    use_fast_attention: bool = False

    def setup(self):
        self.encoders = [
            EncoderBlock(
                self.hidden_dim,
                self.transf_num_heads,
                self.transf_dim_feedforward,
                self.init_scale,
                self.use_fast_attention,
                self.transf_dropout_prob,
            )
            for _ in range(self.transf_num_layers)
        ]

    @partial(
        nn.scan,
        variable_broadcast="params",
        in_axes=0,
        out_axes=0,
        split_rngs={"params": False},
    )
    def __call__(self, carry, x):
        tokens, mask, done = x
        # zero the recurrent token of every sequence whose episode was reset
        zero_hs = self.initialize_carry(self.hidden_dim, *done.shape, 1)
        prev_hs = jnp.where(done[:, np.newaxis, np.newaxis], zero_hs, carry)
        # the recurrent token always sits at position 0 of the sequence
        tokens = jnp.concatenate((prev_hs, tokens), axis=-2)
        for encoder in self.encoders:
            tokens = encoder(tokens, mask=mask, deterministic=self.deterministic)
        new_hs = tokens[..., 0:1, :]
        # full sequence for the mixer (or policy decomposition), otherwise only the agent token
        step_output = tokens if self.return_embeddings else new_hs
        return new_hs, step_output

    @staticmethod
    def initialize_carry(hidden_size, *batch_size):
        """Return a zero carry of shape ``(*batch_size, hidden_size)``."""
        return jnp.zeros((*batch_size, hidden_size))
class TransformerAgent(nn.Module):
    """Recurrent transformer agent whose Q-values are read from its hidden-state token.

    Call args:
        hs: carried hidden state, e.g. ``ScannedTransformer.initialize_carry(hidden_dim, batch, 1)``.
        x: tuple ``(ins, resets)`` with ``ins`` of shape ``(time, batch, n_entities, obs_size)``
            and ``resets`` of shape ``(time, batch)``.
        train: selects batch vs running statistics in the embedder's BatchNorm.
        return_all_hs: additionally return the hidden state of every time step.

    Returns:
        ``(last_hs, q_vals)`` or ``(last_hs, (hidden_states, q_vals))``.

    NOTE(review): ``transf_dropout_prob`` and ``deterministic`` are accepted but not
    forwarded; the transformer always runs deterministically without dropout.
    """
    action_dim: int
    hidden_dim: int
    init_scale_emb: float
    init_scale_transf: float
    init_scale_q: float
    transf_num_layers: int
    transf_num_heads: int
    transf_dim_feedforward: int
    transf_dropout_prob: float
    deterministic: bool
    use_fast_attention: bool = True
    scale_inputs: bool = True
    relu_emb: bool = True

    @nn.compact
    def __call__(self, hs, x, train=True, return_all_hs=False):
        obs_matrix, resets = x
        entity_embs = Embedder(
            self.hidden_dim,
            init_scale=self.init_scale_emb,
            scale_inputs=self.scale_inputs,
            activation=self.relu_emb,
        )(obs_matrix, train)
        transformer = ScannedTransformer(
            hidden_dim=self.hidden_dim,
            init_scale=self.init_scale_transf,
            transf_num_layers=self.transf_num_layers,
            transf_num_heads=self.transf_num_heads,
            transf_dim_feedforward=self.transf_dim_feedforward,
            use_fast_attention=self.use_fast_attention,
            deterministic=True,
            return_embeddings=False,
        )
        # no attention mask: every entity token is attended
        final_hs, hs_sequence = transformer(hs, (entity_embs, None, resets))
        q_head = nn.Dense(self.action_dim, kernel_init=orthogonal(self.init_scale_q), bias_init=constant(0.0))
        q_vals = q_head(hs_sequence)
        outputs = (hs_sequence, q_vals) if return_all_hs else q_vals
        return final_hs, outputs
class TransformerAgentSmax(nn.Module):
    """Transformer agent for SMAX using policy decomposition.

    Movement Q-values are computed from the agent's hidden-state token, while the
    Q-value of attacking enemy ``i`` is computed from the output embedding of that
    enemy's token.
    """
    # variation of transformer agent which uses policy decomposition to
    # compute the q-values of attacking an enemy from the embedding of that enemy
    action_dim: int
    hidden_dim: int
    init_scale_emb: float
    init_scale_transf: float
    init_scale_q: float
    transf_num_layers: int
    transf_num_heads: int
    transf_dim_feedforward: int
    transf_dropout_prob: float
    deterministic: bool = True
    num_movement_actions: int = 5
    use_fast_attention: bool = False
    scale_inputs: bool = True
    relu_emb: bool = True

    @nn.compact
    def __call__(self, hs, x, train=True, return_all_hs=False):
        """Run the agent over a sequence of observations.

        Args:
            hs: carried hidden-state token, ``(batch, 1, hidden_dim)``.
            x: tuple ``(ins, resets)``: ``ins`` is ``(time, batch, n_entities, obs_size)``,
                ``resets`` is ``(time, batch)`` and marks where the hidden state is zeroed.
            train: selects batch vs running statistics in the embedder's BatchNorm.
            return_all_hs: also return the per-step hidden states.

        Returns:
            ``(last_hs, q_vals)`` or ``(last_hs, (hidden_states, q_vals))``, with
            ``q_vals`` of shape ``(time, batch, 1, action_dim)``.

        NOTE(review): ``transf_dropout_prob`` and ``deterministic`` are not forwarded;
        the transformer always runs deterministically without dropout.
        """
        ins, resets = x
        # Flax attention masks KEEP positions where the mask is True. Entities whose
        # observation row is all zeros are dead/invisible and must not be attended.
        visible = jnp.any(ins != 0, axis=-1)
        # ScannedTransformer prepends the agent's hidden-state token at position 0;
        # that token is always attendable.
        mask = jnp.concatenate(
            (jnp.ones((*visible.shape[:-1], 1), dtype=bool), visible),
            axis=-1,
        )
        embeddings = Embedder(
            self.hidden_dim,
            init_scale=self.init_scale_emb,
            scale_inputs=self.scale_inputs,
            activation=self.relu_emb,
        )(ins, train)
        last_hs, embeddings = ScannedTransformer(
            hidden_dim=self.hidden_dim,
            init_scale=self.init_scale_transf,
            transf_num_layers=self.transf_num_layers,
            transf_num_heads=self.transf_num_heads,
            transf_dim_feedforward=self.transf_dim_feedforward,
            use_fast_attention=self.use_fast_attention,
            deterministic=True,
            return_embeddings=True,
        )(hs, (embeddings, mask, resets))

        # q_vals for the movement actions are computed from agents hidden states
        hidden_states = embeddings[..., 0:1, :]
        q_mov = nn.Dense(
            self.num_movement_actions,
            kernel_init=orthogonal(self.init_scale_q),
            bias_init=constant(0.0),
        )(hidden_states)  # time_step, batch_size, 1, 5

        # q_vals for attacking an enemy is computed from attacking that enemy
        n_enemies = self.action_dim - self.num_movement_actions
        enemy_embeddings = embeddings[..., -n_enemies-1:-1, :]  # last embedding is 'self', just before are the enemies
        q_attack = nn.Dense(
            1,
            kernel_init=orthogonal(self.init_scale_q),
            bias_init=constant(0.0),
        )(enemy_embeddings)  # time_step, batch_size, n_enemies, 1
        q_vals = jnp.concatenate((q_mov, jnp.swapaxes(q_attack, -1, -2)), axis=-1)

        if return_all_hs:
            return last_hs, (hidden_states, q_vals)
        return last_hs, q_vals
class TransformerMixer(nn.Module):
    """Monotonic QMIX-style mixer whose hypernetwork is a transformer.

    The state-entity embeddings and the agents' hidden states are processed by a
    ScannedTransformer with its own recurrent token. The agent-token outputs give the
    first mixing weights; the mixer's recurrent-token output gives the first bias, the
    second mixing weights and (through a Dense layer) the second bias. Mixing weights
    are passed through ``abs`` so Q_tot is monotonic in every agent's Q-value.
    """
    hidden_dim: int
    init_scale: float
    transf_num_layers: int
    transf_num_heads: int
    transf_dim_feedforward: int
    scale_inputs: bool = True
    use_fast_attention: bool = True
    relu_emb: bool = True

    @nn.compact
    def __call__(self, q_vals, hs_agents, states, done, train=True):
        """Mix per-agent Q-values into Q_tot.

        Args:
            q_vals: ``(n_agents, time_steps, batch_size)`` per-agent Q-values.
            hs_agents: agents' hidden states holding ``time_steps * n_agents * batch_size *
                hidden_dim`` elements, ordered agent-major after the time axis (all of agent
                0's batch, then agent 1's, ...), which is how homogeneous_pass concatenates
                the agents.
            states: ``(time_steps, batch_size, n_entities, state_size)`` state matrix.
            done: ``(time_steps, batch_size)`` resets for the mixer's recurrent token.
            train: selects batch vs running statistics in the embedder's BatchNorm.

        Returns:
            ``(time_steps, batch_size)`` array of Q_tot.
        """
        n_agents, time_steps, batch_size = q_vals.shape
        q_vals = jnp.transpose(q_vals, (1, 2, 0))  # (time_steps, batch_size, n_agents)

        # the embeddings consist in the state-matrix embeddings and the hidden state of the agents.
        # hs_agents is agent-major, so unflatten as (agent, batch) and then move agents after batch
        # (a direct reshape to (batch, agent) would mix hidden states of different environments).
        hs_agents = hs_agents.reshape(time_steps, n_agents, batch_size, self.hidden_dim)
        hs_agents = jnp.swapaxes(hs_agents, 1, 2)  # (time_steps, batch_size, n_agents, hidden_dim)
        mixer_embs = Embedder(
            self.hidden_dim,
            init_scale=self.init_scale,
            scale_inputs=self.scale_inputs,
            activation=self.relu_emb,
        )(states, train)
        mixer_embs = jnp.concatenate((
            mixer_embs,
            hs_agents,
        ), axis=-2)

        hs_mixer = ScannedTransformer.initialize_carry(self.hidden_dim, batch_size, 1)
        _, hyp_emb = ScannedTransformer(
            hidden_dim=self.hidden_dim,
            init_scale=self.init_scale,
            transf_num_layers=self.transf_num_layers,
            transf_num_heads=self.transf_num_heads,
            transf_dim_feedforward=self.transf_dim_feedforward,
            deterministic=True,
            return_embeddings=True,
            use_fast_attention=self.use_fast_attention,
        )(hs_mixer, (mixer_embs, None, done))  # for now the mixer doesn't mask the embeddings

        # monotonicity and reshaping
        main_emb = hyp_emb[..., 0:1, :]  # main embedding is the hs of the mixer
        w_1 = jnp.abs(hyp_emb[..., -n_agents:, :].reshape(time_steps, batch_size, n_agents, self.hidden_dim))  # w1 is a transformation of the agents' hs
        b_1 = main_emb.reshape(time_steps, batch_size, 1, self.hidden_dim)
        w_2 = jnp.abs(main_emb.reshape(time_steps, batch_size, self.hidden_dim, 1))
        b_2 = nn.Dense(1, kernel_init=orthogonal(self.init_scale), bias_init=constant(0.))(nn.relu(main_emb))
        b_2 = b_2.reshape(time_steps, batch_size, 1, 1)

        # mix
        hidden = nn.elu(jnp.matmul(q_vals[:, :, None, :], w_1) + b_1)  # (time_steps, batch_size, 1, hidden_dim)
        q_tot = jnp.matmul(hidden, w_2) + b_2  # (time_steps, batch_size, 1, 1)
        # drop only the two singleton mixing axes so time/batch axes of size 1 survive
        return q_tot[..., 0, 0]  # (time_steps, batch_size)
class EpsilonGreedy:
    """Epsilon-greedy action selection with a linearly annealed epsilon.

    Epsilon moves linearly from ``start_e`` (at t=0) to ``end_e`` (at t=``duration``)
    and then stays at ``end_e``. Both decreasing and increasing schedules are
    supported; a non-positive ``duration`` means "use ``end_e`` from the start".
    """

    def __init__(self, start_e: float, end_e: float, duration: int):
        self.start_e = start_e
        self.end_e = end_e
        self.duration = duration
        # guard against division by zero for an instantaneous schedule
        self.slope = (end_e - start_e) / duration if duration > 0 else 0.0

    @partial(jax.jit, static_argnums=0)
    def get_epsilon(self, t: int):
        """Return epsilon at time step ``t``."""
        if self.duration <= 0:
            return jnp.asarray(self.end_e)
        e = self.slope * t + self.start_e
        # clamp between the two endpoints, whichever direction the schedule goes
        lo, hi = min(self.start_e, self.end_e), max(self.start_e, self.end_e)
        return jnp.clip(e, lo, hi)

    @partial(jax.jit, static_argnums=0)
    def choose_actions(self, q_vals: dict, t: int, rng: jax.Array):
        """Pick an action per agent: random with probability epsilon(t), else argmax Q.

        Args:
            q_vals: dict mapping agent name to a Q-value array (actions on the last axis).
            t: current time step, used to compute epsilon.
            rng: PRNG key; one sub-key is derived per agent.

        Returns:
            dict mapping agent name to an integer action array (``q.shape[:-1]``).
        """
        def explore(q, eps, key):
            key_a, key_e = jax.random.split(key, 2)  # a key for sampling random actions and one for picking
            greedy_actions = jnp.argmax(q, axis=-1)  # get the greedy actions
            random_actions = jax.random.randint(key_a, shape=greedy_actions.shape, minval=0, maxval=q.shape[-1])  # sample random actions
            pick_random = jax.random.uniform(key_e, greedy_actions.shape) < eps  # pick which actions should be random
            return jnp.where(pick_random, random_actions, greedy_actions)

        eps = self.get_epsilon(t)
        keys = dict(zip(q_vals.keys(), jax.random.split(rng, len(q_vals))))  # get a key for each agent
        return jax.tree.map(lambda q, k: explore(q, eps, k), q_vals, keys)
# One environment step as collected during rollouts and stored in the replay buffer.
# obs / actions / rewards / dones are dicts keyed by agent name (obs, rewards and dones
# also carry an "__all__" entry for the global state / team values); infos holds the
# environment's info dict. Being a NamedTuple, it is a JAX pytree.
Transition = NamedTuple(
    "Transition",
    [
        ("obs", dict),
        ("actions", dict),
        ("rewards", dict),
        ("dones", dict),
        ("infos", dict),
    ],
)
def tree_mean(tree):
    """Return the average of the per-leaf means of a pytree.

    Every leaf contributes equally, regardless of how many elements it has.
    Uses the ``jax.tree`` namespace; the old ``jax.tree_leaves`` alias is deprecated
    and removed in recent JAX releases.
    """
    leaf_means = [jnp.mean(leaf) for leaf in jax.tree.leaves(tree)]
    return jnp.array(leaf_means).mean()
def make_train(config, env):
config["NUM_UPDATES"] = (
config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
)
def train(rng):
# INIT ENV
rng, _rng = jax.random.split(rng)
wrapped_env = TransformersCTRolloutManager(env, batch_size=config["NUM_ENVS"])
test_env = TransformersCTRolloutManager(env, batch_size=config["NUM_TEST_EPISODES"]) # batched env for testing (has different batch size)
init_obs, env_state = wrapped_env.batch_reset(_rng)
init_dones = {agent:jnp.zeros((config["NUM_ENVS"]), dtype=bool) for agent in env.agents+['__all__']}
# INIT BUFFER
# to initalize the buffer is necessary to sample a trajectory to know its strucutre
def _env_sample_step(env_state, unused):
rng, key_a, key_s = jax.random.split(jax.random.PRNGKey(0), 3) # use a dummy rng here
key_a = jax.random.split(key_a, env.num_agents)
actions = {agent: wrapped_env.batch_sample(key_a[i], agent) for i, agent in enumerate(env.agents)}
obs, env_state, rewards, dones, infos = wrapped_env.batch_step(key_s, env_state, actions)
transition = Transition(obs, actions, rewards, dones, infos)
return env_state, transition
_, sample_traj = jax.lax.scan(
_env_sample_step, env_state, None, config["NUM_STEPS"]
)
sample_traj_unbatched = jax.tree.map(lambda x: x[:, 0], sample_traj) # remove the NUM_ENV dim
buffer = fbx.make_trajectory_buffer(
max_length_time_axis=config['BUFFER_SIZE']//config['NUM_ENVS'],
min_length_time_axis=config['BUFFER_BATCH_SIZE'],
sample_batch_size=config['BUFFER_BATCH_SIZE'],
add_batch_size=config['NUM_ENVS'],
sample_sequence_length=1,
period=1,
)
buffer_state = buffer.init(sample_traj_unbatched)
# INIT NETWORK
# init agent
if 'smax' in env.name.lower(): # smax agent
agent_class = TransformerAgentSmax
n_entities = wrapped_env._env.num_allies+wrapped_env._env.num_enemies # must be explicit for the n_entities if using policy decoupling
init_x = (
jnp.zeros((1, 1, n_entities, sample_traj.obs[env.agents[0]].shape[-1])), # (time_step, batch_size, n_entities, obs_size)
jnp.zeros((1, 1)) # (time_step, batch size)
)
else:
agent_class = TransformerAgent
init_x = (
jnp.zeros((1, 1, 1, sample_traj.obs[env.agents[0]].shape[-1])), # (time_step, batch_size, n_entities, obs_size)
jnp.zeros((1, 1)) # (time_step, batch size)
)
agent = agent_class(
action_dim=wrapped_env.max_action_space,
hidden_dim=config['AGENT_HIDDEN_DIM'],
init_scale_emb=config['AGENT_INIT_SCALE'],
init_scale_transf=config['AGENT_INIT_SCALE'],
init_scale_q=config['AGENT_INIT_SCALE'],
transf_num_layers=config['AGENT_TRANSF_NUM_LAYERS'],
transf_num_heads=config['AGENT_TRANSF_NUM_HEADS'],
transf_dim_feedforward=config['AGENT_TRANSF_DIM_FF'],
use_fast_attention=config['USE_FAST_ATTENTION'],
scale_inputs=config['SCALE_INPUTS'],
relu_emb=config['EMBEDDER_USE_RELU'],
transf_dropout_prob=0.,
deterministic=True,
)
rng, _rng = jax.random.split(rng)
init_hs = ScannedTransformer.initialize_carry(config['AGENT_HIDDEN_DIM'], 1, 1) # (batch_size, hidden_dim)
agent_params = agent.init(_rng, init_hs, init_x, train=False)
# init mixer
rng, _rng = jax.random.split(rng)
state_size = sample_traj.obs['__all__'].shape[-1] # get the state shape from the buffer
init_x = (
jnp.zeros((len(env.agents), 1, 1)), # q_vals: n_agents, time_steps, batch_size
ScannedTransformer.initialize_carry(config['AGENT_HIDDEN_DIM'], len(env.agents), 1), # hs_agents: time_step, n_agents*batch_size, hidden_dim
jnp.zeros((1, 1, 1, state_size)), # states: time_step, batch_size, n_entities, state_size
jnp.zeros((1, 1)), # done: (time_step, batch size)
False, # train
)
mixer = TransformerMixer(
hidden_dim=config['AGENT_HIDDEN_DIM'],
init_scale=config['MIXER_INIT_SCALE'],
transf_num_layers=config['MIXER_TRANSF_NUM_LAYERS'],
transf_num_heads=config['MIXER_TRANSF_NUM_HEADS'],
transf_dim_feedforward=config['MIXER_TRANSF_DIM_FF'],
scale_inputs=config['SCALE_INPUTS'],
relu_emb=config['EMBEDDER_USE_RELU'],
use_fast_attention=config['USE_FAST_ATTENTION'],
)
mixer_params = mixer.init(_rng, *init_x)
# init optimizer
network_params = {'agent':agent_params['params'],'mixer':mixer_params['params']}
network_stats = {'agent':agent_params['batch_stats'],'mixer':mixer_params['batch_stats']}
# print number of params
agent_params = sum(x.size for x in jax.tree_leaves(network_params['agent']))
mixer_params = sum(x.size for x in jax.tree_leaves(network_params['mixer']))
jax.debug.print("Number of agent params: {x}", x=agent_params)
jax.debug.print("Number of mixer params: {x}", x=mixer_params)
# INIT TRAIN STATE AND OPTIMIZER
def linear_schedule(count):
    """Learning rate decaying linearly from config["LR"] to 0.

    ``count`` is the optimizer's step counter. Gradients are applied once per mini
    update (``N_MINI_UPDATES`` times per update step), so the decay spans
    ``NUM_UPDATES * N_MINI_UPDATES`` optimizer steps.
    """
    total_steps = config["NUM_UPDATES"] * config['N_MINI_UPDATES']
    # divide by the full step budget and clamp so the learning rate never turns negative
    frac = jnp.maximum(1.0 - count / total_steps, 0.0)
    return config["LR"] * frac
def exponential_schedule(count):
return config["LR"] * (1-config['LR_EXP_DECAY_RATE'])**count
decay_type = config.get('LR_DECAY_TYPE', False)
if decay_type == 'cos':
lr = optax.warmup_cosine_decay_schedule(
init_value=0.0,
peak_value=config["LR"],
warmup_steps=config['LR_WARMUP'],
decay_steps=config["NUM_UPDATES"],
end_value=0.0
)
elif decay_type == 'exp':
lr = exponential_schedule
elif 'linear':
lr = linear_schedule
else:
lr = config['LR']
tx = optax.chain(
optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
optax.adam(learning_rate=lr, eps=config['EPS_ADAM']),
)
# to include the batch normalization stats
class TrainState_(TrainState):
batch_stats: Any
train_state = TrainState_.create(
apply_fn=agent.apply,
params=network_params,
batch_stats=network_stats,
tx=tx,
)
# target network params
copy_tree = lambda tree: jax.tree.map(lambda x: jnp.copy(x), tree)
target_network_state = {'params':copy_tree(train_state.params), 'batch_stats':copy_tree(train_state.batch_stats)}
# INIT EXPLORATION STRATEGY
explorer = EpsilonGreedy(
start_e=config["EPSILON_START"],
end_e=config["EPSILON_FINISH"],
duration=config["EPSILON_ANNEAL_TIME"]
)
def homogeneous_pass(params, batch_stats, hidden_state, obs, dones, return_all_hs=False, train=True):
# concatenate agents and parallel envs to process them in one batch
agents, flatten_agents_obs = zip(*obs.items())
original_shape = flatten_agents_obs[0].shape # assumes obs shape is the same for all agents
batched_input = (
jnp.concatenate(flatten_agents_obs, axis=1), # (time_step, n_agents*n_envs, n_entities, obs_size)
jnp.concatenate([dones[agent] for agent in agents], axis=1), # ensure to not pass other keys (like __all__)
)
# if train, the outs contain the update of the batch norm
if train:
outs, batch_norm_update = agent.apply(
{'params':params,'batch_stats':batch_stats},
hidden_state,
batched_input,
return_all_hs=return_all_hs,
train=True,
mutable=['batch_stats']
)
else:
batch_norm_update = None
outs = agent.apply(
{'params':params,'batch_stats':batch_stats},
hidden_state,
batched_input,
return_all_hs=return_all_hs,
train=False
)
# if return all hs, the outs contain all the hidden states of the agents per each time-step
if return_all_hs:
hidden_state, (h_states, q_vals) = outs
else:
hidden_state, q_vals = outs
q_vals = q_vals.reshape(original_shape[0], len(agents), *original_shape[1:-2], -1) # (time_steps, n_agents, n_envs, action_dim)
q_vals = {a:q_vals[:,i] for i,a in enumerate(agents)}
if return_all_hs:
return batch_norm_update, hidden_state, h_states, q_vals
else:
return batch_norm_update, hidden_state, q_vals
# TRAINING LOOP
def _update_step(runner_state, unused):
train_state, target_network_state, env_state, buffer_state, time_state, init_obs, init_dones, test_metrics, rng = runner_state
# EPISODE STEP
env_params = train_state.params['agent']
env_batch_norm = train_state.batch_stats['agent']
def _env_step(step_state, unused):
env_state, last_obs, last_dones, hstate, rng, t = step_state
# prepare rngs for actions and step
rng, key_a, key_s = jax.random.split(rng, 3)
# SELECT ACTION
# add a dummy time_step dimension to the agent input
obs_ = {a:last_obs[a] for a in env.agents} # ensure to not pass the global state (obs["__all__"]) to the network
obs_ = jax.tree.map(lambda x: x[np.newaxis, :], obs_)
dones_ = jax.tree.map(lambda x: x[np.newaxis, :], last_dones)
# get the q_values from the agent netwoek
_, hstate, q_vals = homogeneous_pass(env_params, env_batch_norm, hstate, obs_, dones_, train=False)
# remove the dummy time_step dimension and index qs by the valid actions of each agent
valid_q_vals = jax.tree.map(lambda q, valid_idx: q.squeeze(0)[..., valid_idx], q_vals, wrapped_env.valid_actions)
# explore with epsilon greedy_exploration
actions = explorer.choose_actions(valid_q_vals, t, key_a)
# STEP ENV
obs, env_state, rewards, dones, infos = wrapped_env.batch_step(key_s, env_state, actions)
# reward scaling
rewards = jax.tree.map(lambda x:config.get("REW_SCALE", 1)*x, rewards)
transition = Transition(last_obs, actions, rewards, dones, infos)
step_state = (env_state, obs, dones, hstate, rng, t+1)
return step_state, transition
# prepare the step state and collect the episode trajectory
rng, _rng = jax.random.split(rng)
hstate = ScannedTransformer.initialize_carry(config['AGENT_HIDDEN_DIM'], len(env.agents)*config["NUM_ENVS"], 1) # (n_agents*n_envs, hs_size)
step_state = (
env_state,
init_obs,
init_dones,
hstate,
_rng,
time_state['timesteps'] # t is needed to compute epsilon
)
step_state, traj_batch = jax.lax.scan(
_env_step, step_state, None, config["NUM_STEPS"]
)
# BUFFER UPDATE: save the collected trajectory in the buffer
buffer_traj_batch = jax.tree.map(
lambda x:jnp.swapaxes(x, 0, 1)[:, np.newaxis], # put the batch dim first and add a dummy sequence dim
traj_batch
) # (num_envs, 1, time_steps, ...)
buffer_state = buffer.add(buffer_state, buffer_traj_batch)
# LEARN PHASE
def q_of_action(q, u):
"""index the q_values with action indices"""
q_u = jnp.take_along_axis(q, jnp.expand_dims(u, axis=-1), axis=-1)
return jnp.squeeze(q_u, axis=-1)
def _network_update(carry, unused):
train_state, rng = carry
# sample a batched trajectory from the buffer and set the time step dim in first axis
rng, _rng = jax.random.split(rng)
learn_traj = buffer.sample(buffer_state, _rng).experience # (batch_size, 1, max_time_steps, ...)
learn_traj = jax.tree.map(
lambda x: jnp.swapaxes(x[:, 0], 0, 1), # remove the dummy sequence dim (1) and swap batch and temporal dims
learn_traj
) # (max_time_steps, batch_size, ...)
init_hs = ScannedTransformer.initialize_carry(config['AGENT_HIDDEN_DIM'], len(env.agents)*config["BUFFER_BATCH_SIZE"], 1) # (n_agents*batch_size, hs_size)
def _loss_fn(params, init_hs, learn_traj):
obs_ = {a:learn_traj.obs[a] for a in env.agents} # ensure to not pass the global state (obs["__all__"]) to the network
updates_agent, _, hs_agents, q_vals = homogeneous_pass(
params['agent'],
train_state.batch_stats['agent'],
init_hs,
obs_,
learn_traj.dones,
return_all_hs=True,
train=True
)
_, _, hs_target_agents, target_q_vals = homogeneous_pass(
target_network_state['params']['agent'],
train_state.batch_stats['agent'],
init_hs,
obs_,
learn_traj.dones,
return_all_hs=True,
train=False
)
# stop the gradient from passing with the hidden states between agents and mixer
hs_agents = jax.lax.stop_gradient(hs_agents)
hs_target_agents = jax.lax.stop_gradient(hs_target_agents)
# get the q_vals of the taken actions (with exploration) for each agent
chosen_action_qvals = jax.tree.map(
lambda q, u: q_of_action(q, u)[:-1], # avoid last timestep
q_vals,
learn_traj.actions
)
# get the target q value of the greedy actions for each agent
valid_q_vals = jax.tree.map(lambda q, valid_idx: q[..., valid_idx], q_vals, wrapped_env.valid_actions)
target_max_qvals = jax.tree.map(
lambda t_q, q: q_of_action(t_q, jnp.argmax(q, axis=-1))[1:], # avoid first timestep
target_q_vals,
jax.lax.stop_gradient(valid_q_vals)
)
# compute q_tot with the mixer network
chosen_action_qvals_mix, updates_mixer = mixer.apply(
{'params':params['mixer'],'batch_stats':train_state.batch_stats['mixer']},
jnp.stack(list(chosen_action_qvals.values())),
hs_agents[:-1], # hs of agents, avoiding last timestep
learn_traj.obs['__all__'][:-1], # avoid last timestep
learn_traj.dones['__all__'][:-1], # avoid last timestep
train=True,
mutable=['batch_stats'],
)
target_max_qvals_mix = mixer.apply(
{'params':target_network_state['params']['mixer'],'batch_stats':train_state.batch_stats['mixer']},
jnp.stack(list(target_max_qvals.values())),
hs_target_agents[1:], # hs of target agents, avoiding first timestep
learn_traj.obs['__all__'][1:], # avoid first timestep
learn_traj.dones['__all__'][1:], # avoid last timestep
train=False,
)
# compute target
if config.get('TD_LAMBDA_LOSS', True):
# time difference loss
def _td_lambda_target(ret, values):
reward, done, target_qs = values
ret = jnp.where(
done,
target_qs,
ret*config['TD_LAMBDA']*config['GAMMA']
+ reward
+ (1-config['TD_LAMBDA'])*config['GAMMA']*(1-done)*target_qs
)
return ret, ret
ret = target_max_qvals_mix[-1] * (1-learn_traj.dones['__all__'][-1])
ret, td_targets = jax.lax.scan(
_td_lambda_target,
ret,
(learn_traj.rewards['__all__'][-2::-1], learn_traj.dones['__all__'][-2::-1], target_max_qvals_mix[-1::-1])
)
targets = td_targets[::-1]
loss = jnp.mean(0.5*((chosen_action_qvals_mix - jax.lax.stop_gradient(targets))**2))
else:
# standard DQN loss
targets = (
learn_traj.rewards['__all__'][:-1]
+ config['GAMMA']*(1-learn_traj.dones['__all__'][:-1])*target_max_qvals_mix
)
loss = jnp.mean((chosen_action_qvals_mix - jax.lax.stop_gradient(targets))**2)
batch_norm_update = {'agent':updates_agent['batch_stats'], 'mixer':updates_mixer['batch_stats']}
return loss, (targets, batch_norm_update)
# compute loss and optimize grad
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
(loss, (targets, batch_norm_update)), grads = grad_fn(train_state.params, init_hs, learn_traj)
train_state = train_state.apply_gradients(grads=grads)
train_state = train_state.replace(batch_stats=batch_norm_update)
update_info = {'loss':loss, 'targets':targets.mean(), 'grad':tree_mean(grads)}
return (train_state, rng), update_info
# perform n updates over the network
rng, _rng = jax.random.split(rng)
update_info_zero = dict(zip(['loss', 'targets', 'grad'], [jnp.zeros(config['N_MINI_UPDATES'])]*3)) # default update info when cannot sample
(train_state, rng), update_info = jax.lax.cond(
buffer.can_sample(buffer_state),
lambda train_state, rng: jax.lax.scan(_network_update, (train_state, rng), None, config['N_MINI_UPDATES']),
lambda train_state, rng: ((train_state, rng), update_info_zero), # do nothing
train_state,
_rng
)
# UPDATE THE VARIABLES AND RETURN
# reset the environment
rng, _rng = jax.random.split(rng)
init_obs, env_state = wrapped_env.batch_reset(_rng)
init_dones = {agent:jnp.zeros((config["NUM_ENVS"]), dtype=bool) for agent in env.agents+['__all__']}
# update the states
time_state['timesteps'] = step_state[-1]
time_state['updates'] = time_state['updates'] + 1
# update the target network if necessary
target_network_state = jax.lax.cond(
time_state['updates'] % config['TARGET_UPDATE_INTERVAL'] == 0,
lambda _: {'params':copy_tree(train_state.params), 'batch_stats':copy_tree(train_state.batch_stats)},
lambda _: target_network_state,
operand=None
)
# update the greedy rewards
rng, _rng = jax.random.split(rng)
test_metrics = jax.lax.cond(
time_state['updates'] % (config["TEST_INTERVAL"] // config["NUM_STEPS"] // config["NUM_ENVS"]) == 0,
lambda _: get_greedy_metrics(_rng, train_state, time_state),
lambda _: test_metrics,
operand=None
)
# update the returning metrics
metrics = {
'running_metrics':{
'env_step': time_state['timesteps']*config['NUM_ENVS'],
'num_updates' : time_state['updates'],
'loss': update_info['loss'].mean(),
'returns': traj_batch.rewards['__all__'].sum(axis=0).mean(), # mean of sum accross timesteps
'targets_mean': update_info['targets'].mean(),
'grad_mean':update_info['grad'].mean(),
'params_agent_mean':tree_mean(train_state.params['agent']),
'params_mixer_mean':tree_mean(train_state.params['mixer']),
},
'test_metrics': test_metrics
}
if config.get('WANDB_ONLINE_REPORT', False):
def callback(metrics, infos):
info_metrics = {
k:v[...,0][infos["returned_episode"][..., 0]].mean()
for k,v in infos.items() if k!="returned_episode"
}
wandb.log(
{
**metrics['running_metrics'],
**info_metrics,
**{k:v.mean() for k, v in metrics['test_metrics'].items()}
}
)
jax.debug.callback(callback, metrics, traj_batch.infos)
runner_state = (
train_state,
target_network_state,
env_state,
buffer_state,
time_state,
init_obs,
init_dones,
test_metrics,
rng
)
if config.get('WANDB_ONLINE_REPORT', False):
return runner_state, None # don't return metrics if you're using wandb to save memory
else:
return runner_state, metrics
def get_greedy_metrics(rng, train_state, time_state):
    """Help function to test the greedy policy during training.

    Rolls out ``test_env`` for ``config["NUM_STEPS"]`` steps with
    ``config["NUM_TEST_EPISODES"]`` parallel envs. At every step each agent
    takes the argmax Q-value (restricted to ``test_env.valid_actions``) from
    ``homogeneous_pass`` with ``train=False``. For each parallel env, rewards
    and infos are then summed up to and including the first step where
    ``dones['__all__']`` is set; if it is never set, the whole rollout is summed.

    Args:
        rng: PRNG key consumed by the test reset and the rollout.
        train_state: only ``params['agent']`` and ``batch_stats['agent']``
            are read.
        time_state: only ``time_state['timesteps']`` is read, for the
            optional VERBOSE print.

    Returns:
        dict with ``'test_returns'`` (per-env first-episode sum of
        ``rewards['__all__']``) plus one ``'test_<key>'`` entry per info key.
    """
    # NOTE: despite the names, these are the agent network's params and
    # batch-norm statistics, not environment parameters.
    env_params = train_state.params['agent']
    env_batch_norm = train_state.batch_stats['agent']

    def _greedy_env_step(step_state, unused):
        """One scan step: act greedily on the last observation, then step test_env."""
        env_state, last_obs, last_dones, hstate, rng = step_state
        rng, key_s = jax.random.split(rng)
        # keep only per-agent observations and add a leading axis of length 1
        # (squeezed out again below after the network pass)
        obs_ = {a: last_obs[a] for a in env.agents}
        obs_ = jax.tree.map(lambda x: x[np.newaxis, :], obs_)
        dones_ = jax.tree.map(lambda x: x[np.newaxis, :], last_dones)
        _, hstate, q_vals = homogeneous_pass(env_params, env_batch_norm, hstate, obs_, dones_, train=False)
        # argmax over the selected valid-action columns. This yields a position
        # inside valid_idx, which equals the action id only if valid_actions are
        # contiguous ranges starting at 0 -- NOTE(review): confirm.
        actions = jax.tree.map(
            lambda q, valid_idx: jnp.argmax(q.squeeze(0)[..., valid_idx], axis=-1),
            q_vals,
            test_env.valid_actions,
        )
        obs, env_state, rewards, dones, infos = test_env.batch_step(key_s, env_state, actions)
        step_state = (env_state, obs, dones, hstate, rng)
        return step_state, (rewards, dones, infos)

    rng, _rng = jax.random.split(rng)
    init_obs, env_state = test_env.batch_reset(_rng)
    init_dones = {agent: jnp.zeros((config["NUM_TEST_EPISODES"]), dtype=bool) for agent in env.agents + ['__all__']}
    rng, _rng = jax.random.split(rng)
    hstate = ScannedTransformer.initialize_carry(config['AGENT_HIDDEN_DIM'], len(env.agents) * config["NUM_TEST_EPISODES"], 1)  # (n_agents*n_envs, hs_size)
    step_state = (
        env_state,
        init_obs,
        init_dones,
        hstate,
        _rng,
    )
    # scan stacks the per-step outputs along a new leading (time) axis
    step_state, (rewards, dones, infos) = jax.lax.scan(
        _greedy_env_step, step_state, None, config["NUM_STEPS"]
    )

    def first_episode_returns(rewards, dones):
        """Sum ``rewards`` up to and including the first True in ``dones``.

        Called through ``vmap(in_axes=1)``, so each call sees one slice along
        the scan (time) axis. If ``dones`` is never True, everything is summed.
        """
        # argmax alone cannot distinguish "no done" from "done at step 0"
        # (both give 0), so detect the no-done case explicitly.
        first_done = jnp.where(dones.any(), jnp.argmax(dones), dones.size)
        first_episode_mask = jnp.arange(dones.size) <= first_done
        return jnp.where(first_episode_mask, rewards, 0.).sum()

    # compute the metrics of the first episode that is done for each parallel env
    all_dones = dones['__all__']
    first_returns = jax.tree.map(lambda r: jax.vmap(first_episode_returns, in_axes=1)(r, all_dones), rewards)
    # infos leaves carry a trailing axis; only index 0 is aggregated
    first_infos = jax.tree.map(lambda i: jax.vmap(first_episode_returns, in_axes=1)(i[..., 0], all_dones), infos)
    metrics = {
        'test_returns': first_returns['__all__'],  # episode returns
        **{'test_' + k: v for k, v in first_infos.items()},
    }
    if config.get('VERBOSE', False):
        def callback(timestep, val):
            print(f"Timestep: {timestep}, return: {val}")
        jax.debug.callback(callback, time_state['timesteps'] * config['NUM_ENVS'], first_returns['__all__'].mean())
    return metrics
time_state = {
'timesteps':jnp.array(0),
'updates': jnp.array(0)
}
rng, _rng = jax.random.split(rng)
test_metrics = get_greedy_metrics(_rng, train_state, time_state) # initial greedy metrics
# train
rng, _rng = jax.random.split(rng)
runner_state = (
train_state,
target_network_state,
env_state,
buffer_state,
time_state,
init_obs,
init_dones,
test_metrics,
_rng
)
runner_state, metrics = jax.lax.scan(
_update_step, runner_state, None, config["NUM_UPDATES"]
)
return {'runner_state':runner_state, 'metrics':metrics}
return train
def env_from_config(config):
    """Create the log-wrapped environment described by ``config``.

    SMAX environments get ``config["ENV_KWARGS"]["scenario"]`` filled in from
    ``config["MAP_NAME"]``. This writes into the caller's dict, so callers that
    want to keep their config untouched should pass a copy. SMAX envs are
    wrapped in ``SMAXLogWrapper`` and MPE envs in ``MPELogWrapper``.

    Returns:
        ``(env, env_name)``. For SMAX, ``env_name`` is ``"<ENV_NAME>_<MAP_NAME>"``;
        otherwise it is ``ENV_NAME`` unchanged.

    Raises:
        NotImplementedError: if ``ENV_NAME`` matches neither SMAX nor MPE.
    """
    base_name = config["ENV_NAME"]
    lowered = base_name.lower()

    if "smax" in lowered:
        # the SMAX constructor requires a scenario object
        config["ENV_KWARGS"]["scenario"] = map_name_to_scenario(config["MAP_NAME"])
        full_name = f"{config['ENV_NAME']}_{config['MAP_NAME']}"
        smax_env = make(config["ENV_NAME"], **config["ENV_KWARGS"])
        return SMAXLogWrapper(smax_env), full_name

    if "mpe" in lowered:
        mpe_env = make(config["ENV_NAME"], **config["ENV_KWARGS"])
        return MPELogWrapper(mpe_env), base_name

    raise NotImplementedError(f"Environment {base_name} not implemented.")
def single_run(config):
    """Train ``NUM_SEEDS`` vmapped seeds for one config and optionally save params.

    ``config["alg"]`` is merged over the top-level config (alg keys win). The
    environment is built from a deep copy of the merged config, because
    ``env_from_config`` may write into ``config["ENV_KWARGS"]``. If
    ``SAVE_PATH`` is set, the merged config is saved as YAML and each seed's
    agent params are saved as a safetensors file under
    ``SAVE_PATH/<env_name>/``.
    """
    config = {**config, **config["alg"]}  # merge the alg config with the main config
    print("Config:\n", OmegaConf.to_yaml(config))

    alg_name = config.get("ALG_NAME", "transf_qmix")
    # deep copy: env_from_config may mutate config["ENV_KWARGS"]
    env, env_name = env_from_config(copy.deepcopy(config))

    wandb.init(
        entity=config["ENTITY"],
        project=config["PROJECT"],
        tags=[
            alg_name.upper(),
            env_name.upper(),
            f"jax_{jax.__version__}",
        ],
        name=f"{alg_name}_{env_name}",
        config=config,
        mode=config["WANDB_MODE"],
    )

    rng = jax.random.PRNGKey(config["SEED"])
    rngs = jax.random.split(rng, config["NUM_SEEDS"])
    train_vjit = jax.jit(jax.vmap(make_train(config, env)))
    outs = jax.block_until_ready(train_vjit(rngs))

    # save params
    if config.get("SAVE_PATH", None) is not None:
        from jaxmarl.wrappers.baselines import save_params

        # first element of the runner state is the train state
        model_state = outs["runner_state"][0]
        save_dir = os.path.join(config["SAVE_PATH"], env_name)
        os.makedirs(save_dir, exist_ok=True)
        OmegaConf.save(
            config,
            os.path.join(
                save_dir, f'{alg_name}_{env_name}_seed{config["SEED"]}_config.yaml'
            ),
        )
        # params carry a leading vmap (seed) axis: write one file per seed
        for i in range(len(rngs)):
            params = jax.tree.map(lambda x: x[i], model_state.params)
            save_path = os.path.join(
                save_dir,
                f'{alg_name}_{env_name}_seed{config["SEED"]}_vmap{i}.safetensors',
            )
            save_params(params, save_path)
def tune(default_config):