feature(pu): add UniZero algo. and related configs/utils/envs/models #232

Merged: 184 commits, merged Jul 3, 2024

Commits (184)
36f69d0
feature(pu): add init version of gpt-based muzero
puyuan1996 Nov 3, 2023
344f9f6
sync code
puyuan1996 Nov 6, 2023
f922cfb
sync code
puyuan1996 Nov 7, 2023
6e436c5
sync code
puyuan1996 Nov 7, 2023
e0c2b95
fix(pu): fix reward/value/policy kl loss
puyuan1996 Nov 8, 2023
fc913c9
fix(pu): fix kv_cache used in MCTS search method
puyuan1996 Nov 8, 2023
8779039
Merge branch 'dev-transformer' of https://github.com/opendilab/LightZ…
puyuan1996 Nov 22, 2023
a20deba
feature(pu): add muzero_gpt for atari, polish world_model.past_keys_v…
puyuan1996 Nov 23, 2023
641680c
polish(pu): polish slicer
puyuan1996 Nov 27, 2023
36adc43
polish(pu): polish compute_slice
puyuan1996 Nov 27, 2023
e39d421
fix(pu): fix init latent state, fix past_keys_values_cache in mcts se…
puyuan1996 Nov 28, 2023
80f2239
fix(pu): fix past_keys_values_cache in mcts search
puyuan1996 Nov 28, 2023
3cad2f2
polish(pu): polish world_model for batch processing, work in progress
puyuan1996 Nov 30, 2023
12d4174
feature(pu): add test_slicer and test_slicer_time
puyuan1996 Nov 30, 2023
03dc66f
fix(pu): fix reward/value/policy loss in muzero_gpt
puyuan1996 Nov 30, 2023
26b6255
fix(pu): fix state_action_history bug in mcts_ctree for muzero_gpt
puyuan1996 Nov 30, 2023
fd85f4d
fix(pu): fix self.past_keys_values_cache bug in mcts_ctree for muzero…
puyuan1996 Nov 30, 2023
30e7882
polish(pu): use precomputed slices
puyuan1996 Dec 1, 2023
ff6419f
feature(pu): add target_policy_entropy log
puyuan1996 Dec 1, 2023
19c673f
feature(pu): add train tokenizer related loss
puyuan1996 Dec 7, 2023
91d2e63
polish(pu): polish obs_token_loss_weight
puyuan1996 Dec 9, 2023
18a5e7b
polish(pu): use different sample data for training transformer and to…
puyuan1996 Dec 12, 2023
8a6da1d
fix(pu): use argmax to select the most likely obs_token in recurrent_…
puyuan1996 Dec 14, 2023
20f7c4c
sync code
puyuan1996 Dec 15, 2023
8a1c418
fix(pu): tokenizer adam optimizer use weight_decay=0, set lr=1e-4 for…
puyuan1996 Dec 17, 2023
7be1c32
fix(pu): calculate target value in init_inference using unrolling 5 s…
puyuan1996 Dec 18, 2023
fa2263e
fix(pu): fix init_inference in mcts root state
puyuan1996 Dec 18, 2023
1c1c745
polish(pu): polish muzero_gpt config
puyuan1996 Dec 19, 2023
a22cc88
polish(pu): tokenizer delay update
puyuan1996 Dec 20, 2023
6c21c0f
polish(pu): polish configs
puyuan1996 Dec 21, 2023
08cc4f9
feature(pu): add world_model using continuous embeddings
puyuan1996 Dec 21, 2023
4d6d84a
feature(pu): polish obs_loss, a neg cos similarity loss
puyuan1996 Dec 21, 2023
af15791
polish(pu): polish configs
puyuan1996 Dec 21, 2023
fe4693e
polish(pu): use EZ-style SSL obs loss
jiayilee65 Dec 26, 2023
d681fdf
polish(pu): add support for k=1
jiayilee65 Dec 26, 2023
6907276
fix(pu): fix mask in obs loss
jiayilee65 Dec 27, 2023
06ac85f
fix(pu): use muzero-type representation network, use last linear in r…
jiayilee65 Dec 28, 2023
4604be1
feature(pu): add pong stack4 support, add lunarlander muzero_gpt config
jiayilee65 Dec 28, 2023
31efa2f
feature(pu): add some utils function
Jan 10, 2024
f8d88e6
fix(pu): latent state gradient times 0.2, set grad_clip_value to 0.5,…
Jan 11, 2024
c13fea7
polish(pu): polish unused debug code
Jan 11, 2024
01c8385
fix(pu): fix self.past_keys_values_cache.popitem cuda memory bug, use…
Jan 11, 2024
a538a80
fix(pu): world_model.past_keys_values_cache.clear() per 200 env steps…
Jan 13, 2024
d5a94a1
polish(pu): polish configs for stack=1 and stack=4
Jan 14, 2024
1dcbfec
sync code
Jan 17, 2024
51b80e6
polish(pu): when agent collect env_num segments then train
Jan 17, 2024
24865ef
fix(pu): when the agent collects <env_num> segments or one episode do…
Jan 17, 2024
e9a930e
polish(pu): add latent state soft-target, para reset options
Jan 19, 2024
a004364
polish(pu): use the new mcts visit counts to replace the old in buffer
Jan 20, 2024
b30051b
fix(pu): fix initial_inference in reanalyze
Jan 21, 2024
2c238e3
feature(pu): add init version of multi-task muzero
Jan 21, 2024
ace1d83
polish(pu): polish multi-task muzero
Jan 22, 2024
2562d38
fix(pu): fix multi-task muzero and multi-task muzero-gpt
Jan 22, 2024
67465a6
fix(pu): use normalized visit counts as reanalyzed target policy
jiayilee65 Jan 24, 2024
24c975e
feature(pu): add gtrxl gating option, init_infer not reset option
jiayilee65 Jan 26, 2024
4f9b711
feature(pu): add world_model_envnum8_kvcache-latent-one-env
jiayilee65 Jan 26, 2024
655b045
feature(pu): add envnum8_kv-latent-8-1-env world model
jiayilee65 Jan 28, 2024
7a149d5
fix(pu): fix init_infer kv_cache bug when reanalyze_ratio>0
jiayilee65 Feb 4, 2024
7f80343
fix(pu): fix kv_cache when env_num>1
puyuan1996 Feb 19, 2024
0ad129e
fix(pu): fix kv_cache, use one env in recurrent_inference() in search
puyuan1996 Feb 20, 2024
6eca8f7
fix(pu): use sequential (for loop) exec transformer() to ensure corre…
puyuan1996 Feb 23, 2024
97cf1d2
fix(pu): use sequential (for loop) exec transformer() to ensure corre…
puyuan1996 Feb 23, 2024
6612e64
Merge branch 'dev-xzero' of https://github.com/opendilab/LightZero in…
puyuan1996 Feb 23, 2024
9b50bbb
sync code
puyuan1996 Feb 26, 2024
e30e25e
fix(pu): fix world_model recurrent_inference kv_batch_pad
puyuan1996 Feb 26, 2024
01737ba
polish(pu): polish world_model recurrent_inference kv_batch_pad and k…
puyuan1996 Feb 27, 2024
913cbad
fix(pu): fix keys_values_wm_list cuda cache bug
puyuan1996 Feb 27, 2024
b3b48c0
sync code
puyuan1996 Feb 27, 2024
9f5b941
fix(pu): use world_model_batch_pad_min and optimize kv_cache save me…
puyuan1996 Feb 28, 2024
d9ed770
polish(pu): polish world_model kv_cache statistics
puyuan1996 Feb 28, 2024
2062920
fix(pu): fix collector init_inference kv_cache in root node
puyuan1996 Mar 4, 2024
a3db9fd
polish(pu): polish world_model_batch_pad_min_fixroot_v2
puyuan1996 Mar 5, 2024
84a7ca3
polish(pu): polish world_model head structure and optim configs
puyuan1996 Mar 5, 2024
0beff7c
fix(pu): fix world_model root kv_cache, now almost 100% full context
puyuan1996 Mar 5, 2024
346ff2d
polish(pu): transformer attn calculate qkv in batch
puyuan1996 Mar 6, 2024
2096256
fix(pu): fix attn_mask in pytorch scaled_dot_product_attention
puyuan1996 Mar 6, 2024
c9cc588
polish(pu): polish xzero
puyuan1996 Mar 6, 2024
32aaf2a
feature(pu): add sim_norm, policy_entropy_loss options
puyuan1996 Mar 8, 2024
a7d5d21
polish(pu): polish xzero configs
puyuan1996 Mar 9, 2024
e809cd5
feature(pu): add MemoryEnvLightZero and related tests
puyuan1996 Mar 14, 2024
5fc7a69
feature(pu): add render and save_replay options in MemoryEnvLightZero
puyuan1996 Mar 14, 2024
7ee4e1d
polish(pu): delete unused files, add requirements
puyuan1996 Mar 14, 2024
612aa19
feature(pu): add memory train and eval configs
puyuan1996 Mar 14, 2024
1c13700
polish(pu): polish world model, use latent-groupkl-loss, no latent gr…
puyuan1996 Mar 14, 2024
f333b22
Merge branch 'dev-memory' of https://github.com/opendilab/LightZero i…
puyuan1996 Mar 16, 2024
8354885
fix(pu): fix merge errors
puyuan1996 Mar 16, 2024
6b594f3
polish(pu): polish prepare_obs_stack4_for_gpt
puyuan1996 Mar 16, 2024
b4c04ca
polish(pu): delete unused debug files
puyuan1996 Mar 16, 2024
f448f6d
feature(pu): add memory_xzero_config, polish code
puyuan1996 Mar 16, 2024
f15c85a
polish(pu): polish memory configs
puyuan1996 Mar 16, 2024
670cfa3
polish(pu): use learned-act-embeddings, use latent mse loss, clear-pe…
puyuan1996 Mar 17, 2024
eb850d2
fix(pu): fix memory_lightzero_env return bug
puyuan1996 Mar 19, 2024
bf26548
fix(pu): fix memory_eval.py
puyuan1996 Mar 21, 2024
735f44d
fix(pu): fix memory env obs scale bug
puyuan1996 Mar 21, 2024
c3031b0
polish(pu): polish memory_env eval config
puyuan1996 Mar 22, 2024
950639a
polish(pu): polish memory_env config
puyuan1996 Mar 22, 2024
ec7fa19
polish(pu): polish xzero config
puyuan1996 Mar 23, 2024
aabaa0d
fix(pu): memory_env use rgb_img observation (3,5,5), use cnn encoder …
puyuan1996 Mar 27, 2024
37b9a76
sync code
puyuan1996 Mar 24, 2024
9ecd46a
Merge branch 'dev-xzero-memory-debug' of https://github.com/opendilab…
jiayilee65 Mar 27, 2024
29fe1e0
fix(pu): use full episode to train world model in memory_env
puyuan1996 Mar 28, 2024
2bd1f80
Merge branch 'dev-xzero-memory-debug' of https://github.com/opendilab…
jiayilee65 Mar 28, 2024
ea59840
feature(pu): add temporal gamma discount option
jiayilee65 Mar 28, 2024
cb824c0
fix(pu): use self.past_keys_values_cache_init_infer and self.past_key…
jiayilee65 Apr 1, 2024
9644bdf
polish(pu): polish world_model
jiayilee65 Apr 1, 2024
d500c2c
fix(pu): use different context length for init_infer and recurrent_infer
jiayilee65 Apr 1, 2024
906856b
fix(pu): self.past_keys_values_cache_recurrent_infer.clear() per search
jiayilee65 Apr 2, 2024
4255702
feature(pu): add depth_in_search_path utils
jiayilee65 Apr 3, 2024
f5a928e
fix(pu): always save latest kv_cache for latent_state to tackle POMDP…
jiayilee65 Apr 4, 2024
5e3bd98
fix(pu): fix pos_embedding in kv_cache
jiayilee65 Apr 5, 2024
c0ff6df
polish(pu): polish world_model
jiayilee65 Apr 5, 2024
a343458
polish(pu):polish configs
jiayilee65 Apr 8, 2024
74356a0
fix(pu): fix loss_value bug
jiayilee65 Apr 8, 2024
5fa006c
polish(pu): polish memory env settings and config
jiayilee65 Apr 11, 2024
b7cd384
fix(pu): use polished and bigger encoder/decoder net for memory env
jiayilee65 Apr 13, 2024
cbb3d2c
feature(pu): add visualize_reconstruction utils
jiayilee65 Apr 14, 2024
2d362ff
fix(pu): use batch max_kv_size for memory env
jiayilee65 Apr 16, 2024
6f99cef
fix(pu): fix the context of max_kv_size in batch
jiayilee65 Apr 16, 2024
51af708
fix(pu): fix the context of max_kv_size in batch
jiayilee65 Apr 17, 2024
aab5d8f
fix(pu): use train_episode_length + 5 as context_length for memory env
puyuan1996 Apr 18, 2024
ccbe071
polish(pu): rename muzero_gpt to unizero, add muzero_context variant
puyuan1996 Apr 22, 2024
f05516c
feature(pu): add l2 norm of latent state, grad_norm, dormant_ratio st…
puyuan1996 Apr 22, 2024
1af5f92
feature(pu): add muzero_rnn variant
puyuan1996 Apr 23, 2024
dfc47e1
fix(pu): fix the cuda_cache of cal_dormant_ratio and FeatureAndGradie…
puyuan1996 Apr 23, 2024
b8ae3b8
feature(pu): add atari 100k muzero configs
puyuan1996 Apr 23, 2024
3509e4e
feature(pu): add unizero atari100k 26 games config
puyuan1996 Apr 24, 2024
3da1042
feature(pu): add attention_map visualize utils
puyuan1996 Apr 25, 2024
fd3f7a5
polish(pu): polish attention_map visualize utils
jiayilee65 Apr 27, 2024
3a44e8a
feature(pu): add muzero_rnn_fullobs variants
jiayilee65 Apr 30, 2024
4168d9f
fix(pu): fix muzero_rnn_fullobs variants
jiayilee65 May 1, 2024
959959e
fix(pu): fix muzero_rnn_fullobs variants
jiayilee65 May 2, 2024
8f1e079
sync code
jiayilee65 May 8, 2024
c31c43d
sync code
jiayilee65 May 11, 2024
9c5b647
sync code
jiayilee65 May 18, 2024
29ebad5
fix(pu): fix efficientzero atari (4,64,64) obs configs
jiayilee65 May 19, 2024
8e779d4
sync code
dyyoungg May 31, 2024
bd058c0
polish(pu): polish unizero world models
puyuan1996 Jun 10, 2024
431c8cc
refactor(pu): refactor __init__ and forward method of world model
puyuan1996 Jun 10, 2024
6da6785
polish(pu): polish forward_init_inference of unizero world models
puyuan1996 Jun 11, 2024
e4aa0dd
polish(pu): polish forward_recurrent_inference, kv_cache mechanisms a…
puyuan1996 Jun 11, 2024
fca75f3
polish(pu): polish utils.py, visualize_utils.py and transformer.py
puyuan1996 Jun 11, 2024
fc49cae
polish(pu): polish unizero.py and train_unizero.py
puyuan1996 Jun 11, 2024
5a67892
polish(pu): polish unizero model and tree_search
puyuan1996 Jun 11, 2024
8d89d7f
polish(pu): polish game_buffer_unizero
puyuan1996 Jun 11, 2024
e507b1b
polish(pu): polish cartpole/lunarlander/memory unizero configs
puyuan1996 Jun 11, 2024
dd60665
polish(pu): polish muzero_variants configs/model/policy
puyuan1996 Jun 11, 2024
55c1a64
Merge tag 'main' of https://github.com/opendilab/LightZero into dev-u…
puyuan1996 Jun 11, 2024
7a09c6f
polish(pu): delete some unused backup policy utils
puyuan1996 Jun 11, 2024
ca0b633
polish(pu): polish unizero config
puyuan1996 Jun 11, 2024
d79cf79
polish(pu): add pyecharts requirements
puyuan1996 Jun 12, 2024
6373b36
polish(pu): polish requirements
puyuan1996 Jun 12, 2024
e50c357
polish(pu): polish LN eps in encoder model
puyuan1996 Jun 12, 2024
17ea7ce
polish(pu): polish train_unizero.py
puyuan1996 Jun 12, 2024
e15dcea
fix(pu): fix softmax operation in inverse_scalar_transform
jiayilee65 Jun 12, 2024
3eb00e5
polish(pu): polish LN eps and init
dyyoungg Jun 13, 2024
68a16b7
polish(pu): polish visualize config and utils
dyyoungg Jun 13, 2024
7fb3cb8
feature(pu): add debug logs
dyyoungg Jun 13, 2024
c60073e
polish(pu): polish configs
puyuan1996 Jun 28, 2024
890742a
Merge remote-tracking branch 'origin/main' into dev-unizero
puyuan1996 Jun 28, 2024
b93cf42
polish(pu): polish comments in models
puyuan1996 Jun 28, 2024
17c0cd8
fix(pu): fix ready_env_id in policy
puyuan1996 Jul 1, 2024
f3d7369
fix(pu): fix wrong modifications
puyuan1996 Jul 1, 2024
563a183
polish(pu): polish muzero variants
puyuan1996 Jul 1, 2024
4c90207
polish(pu): make muzero variants inherit from muzero
puyuan1996 Jul 1, 2024
d4eab1e
polish(pu): polish unizero comments
puyuan1996 Jul 1, 2024
6d366e8
polish(pu): polish lunarlander configs, fix muzero_collector env_id o…
puyuan1996 Jul 2, 2024
d5e958b
fix(pu): fix muzero_evaluator env_id operations, polish _reset_colle…
puyuan1996 Jul 2, 2024
428379c
polish(pu): polish configs
puyuan1996 Jul 2, 2024
7305526
polish(pu): polish reset and del method of muzero
puyuan1996 Jul 2, 2024
1dd01c7
polish(pu): polish comments and typelint
puyuan1996 Jul 2, 2024
ee905fb
polish(pu): add line_profiler requirements
puyuan1996 Jul 3, 2024
3ce2fb5
fix(pu): fix muzero target_policy_entropy bug
puyuan1996 Jul 3, 2024
d8c5649
fix(pu): fix device and use gym
dyyoungg Jul 3, 2024
f908378
Merge branch 'dev-unizero' of https://github.com/opendilab/LightZero …
dyyoungg Jul 3, 2024
9813ce4
polish(pu): move initialize_zeros_batch to reset method of unizero
puyuan1996 Jul 3, 2024
fb9594a
polish(pu): polish model_path in configs
puyuan1996 Jul 3, 2024
79a4567
polish(pu): polish code and comments
puyuan1996 Jul 3, 2024
d97128f
polish(pu): polish unizero utils
puyuan1996 Jul 3, 2024
7b08a5e
polish(pu): polish KeysValues to_device method
puyuan1996 Jul 3, 2024
9d739e8
polish(pu): delete bsuite_unizero_config.py
puyuan1996 Jul 3, 2024
c957a2d
polish(pu): polish atari_unizero_configs
puyuan1996 Jul 3, 2024
63507b1
polish(pu): polish code style and comments
puyuan1996 Jul 3, 2024
c1e0228
polish(pu): polish code style and comments
puyuan1996 Jul 3, 2024
3882296
polish(pu): polish comments in model/common.py
puyuan1996 Jul 3, 2024
5 changes: 3 additions & 2 deletions .gitignore
@@ -1421,7 +1421,7 @@ log*
default*
events.*

# DI-engine special key
# LightZero special key
*default_logger.txt
*default_tb_logger
*evaluate.txt
@@ -1444,4 +1444,5 @@ events.*
!/lzero/mcts/**/lib/*.h
**/tb/*
**/mcts/ctree/tests_cpp/*
**/*tmp*
**/*tmp*
lzero/mcts/ctree/ctree_alphazero/pybind11
8 changes: 5 additions & 3 deletions lzero/entry/__init__.py
@@ -1,8 +1,10 @@
from .train_alphazero import train_alphazero
from .eval_alphazero import eval_alphazero
from .train_muzero import train_muzero
from .train_muzero_with_reward_model import train_muzero_with_reward_model
from .eval_muzero import eval_muzero
from .eval_muzero_with_gym_env import eval_muzero_with_gym_env
from .train_alphazero import train_alphazero
from .train_muzero import train_muzero
from .train_muzero_with_gym_env import train_muzero_with_gym_env
from .train_muzero_with_gym_env import train_muzero_with_gym_env
from .train_muzero_with_reward_model import train_muzero_with_reward_model
from .train_rezero import train_rezero
from .train_unizero import train_unizero
14 changes: 11 additions & 3 deletions lzero/entry/eval_muzero.py
@@ -13,6 +13,7 @@
from ding.utils import set_pkg_seed
from ding.worker import BaseLearner
from lzero.worker import MuZeroEvaluator
from lzero.entry.utils import initialize_zeros_batch


def eval_muzero(
@@ -25,7 +26,7 @@ def eval_muzero(
) -> 'Policy': # noqa
"""
Overview:
The eval entry for MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero.
The eval entry for MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero, StochasticMuZero, GumbelMuZero, UniZero, etc.
Arguments:
- input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type.
``Tuple[dict, dict]`` type means [user_config, create_cfg].
@@ -38,8 +39,8 @@ def eval_muzero(
- policy (:obj:`Policy`): Converged policy.
"""
cfg, create_cfg = input_cfg
assert create_cfg.policy.type in ['efficientzero', 'muzero', 'stochastic_muzero', 'gumbel_muzero', 'sampled_efficientzero'], \
"LightZero now only support the following algo.: 'efficientzero', 'muzero', 'stochastic_muzero', 'gumbel_muzero', 'sampled_efficientzero'"
assert create_cfg.policy.type in ['efficientzero', 'muzero', 'muzero_context', 'muzero_rnn_full_obs', 'stochastic_muzero', 'gumbel_muzero', 'sampled_efficientzero', 'unizero'], \
"LightZero now only support the following algo.: 'efficientzero', 'muzero', 'muzero_context', 'muzero_rnn_full_obs', 'stochastic_muzero', 'gumbel_muzero', 'sampled_efficientzero', 'unizero'"

if cfg.policy.cuda and torch.cuda.is_available():
cfg.policy.device = 'cuda'
@@ -85,6 +86,13 @@ def eval_muzero(
# Learner's before_run hook.
learner.call_hook('before_run')

policy.last_batch_obs = initialize_zeros_batch(
cfg.policy.model.observation_shape,
len(evaluator_env_cfg),
cfg.policy.device
)
policy.last_batch_action = [-1 for _ in range(len(evaluator_env_cfg))]

while True:
# ==============================================================
# eval trained model
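For orientation, here is a minimal evaluation sketch assuming the usual LightZero layout: a zoo config module exposing `main_config`/`create_config`, and a `model_path` keyword accepted by `eval_muzero` (both are assumptions for illustration; only `input_cfg` is documented in the diff above).

```python
# Hypothetical usage sketch; the config module path and the model_path keyword are assumptions.
from lzero.entry import eval_muzero
from zoo.classic_control.cartpole.config.cartpole_unizero_config import main_config, create_config

if __name__ == "__main__":
    # 'unizero' is now accepted by the assert shown above.
    policy = eval_muzero(
        [main_config, create_config],  # input_cfg: [user_config, create_cfg]
        seed=0,
        model_path='./exp_name/ckpt/ckpt_best.pth.tar',  # hypothetical checkpoint path
    )
```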
6 changes: 3 additions & 3 deletions lzero/entry/train_muzero.py
@@ -18,7 +18,7 @@
from lzero.policy.random_policy import LightZeroRandomPolicy
from lzero.worker import MuZeroCollector as Collector
from lzero.worker import MuZeroEvaluator as Evaluator
from .utils import random_collect
from .utils import random_collect, initialize_zeros_batch


def train_muzero(
@@ -47,10 +47,10 @@ def train_muzero(
"""

cfg, create_cfg = input_cfg
assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'stochastic_muzero'], \
assert create_cfg.policy.type in ['efficientzero', 'muzero', 'muzero_context', 'muzero_rnn_full_obs', 'sampled_efficientzero', 'gumbel_muzero', 'stochastic_muzero'], \
"train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'stochastic_muzero'"

if create_cfg.policy.type == 'muzero':
if create_cfg.policy.type in ['muzero', 'muzero_context', 'muzero_rnn_full_obs']:
from lzero.mcts import MuZeroGameBuffer as GameBuffer
elif create_cfg.policy.type == 'efficientzero':
from lzero.mcts import EfficientZeroGameBuffer as GameBuffer
207 changes: 207 additions & 0 deletions lzero/entry/train_unizero.py
@@ -0,0 +1,207 @@
import logging
import os
from functools import partial
from typing import Tuple, Optional

import torch
from ding.config import compile_config
from ding.envs import create_env_manager
from ding.envs import get_vec_env_setting
from ding.policy import create_policy
from ding.rl_utils import get_epsilon_greedy_fn
from ding.utils import set_pkg_seed, get_rank
from ding.worker import BaseLearner
from tensorboardX import SummaryWriter
from torch.utils.tensorboard import SummaryWriter

from lzero.entry.utils import log_buffer_memory_usage
from lzero.policy import visit_count_temperature
from lzero.policy.random_policy import LightZeroRandomPolicy
from lzero.worker import MuZeroCollector as Collector
from lzero.worker import MuZeroEvaluator as Evaluator
from .utils import random_collect, initialize_zeros_batch


def train_unizero(
input_cfg: Tuple[dict, dict],
seed: int = 0,
model: Optional[torch.nn.Module] = None,
model_path: Optional[str] = None,
max_train_iter: Optional[int] = int(1e10),
max_env_step: Optional[int] = int(1e10),
) -> 'Policy':
"""
Overview:
The train entry for UniZero, proposed in our paper UniZero: Generalized and Efficient Planning with Scalable Latent World Models.
UniZero aims to enhance the planning capabilities of reinforcement learning agents by addressing the limitations found in MuZero-style algorithms,
particularly in environments requiring the capture of long-term dependencies. More details can be found in https://arxiv.org/abs/2406.10667.
Arguments:
- input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type.
``Tuple[dict, dict]`` type means [user_config, create_cfg].
- seed (:obj:`int`): Random seed.
- model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
- model_path (:obj:`Optional[str]`): The pretrained model path, which should
point to the ckpt file of the pretrained model, and an absolute path is recommended.
In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``.
- max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
- max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
Returns:
- policy (:obj:`Policy`): Converged policy.
"""

cfg, create_cfg = input_cfg

# Ensure the specified policy type is supported
assert create_cfg.policy.type in ['unizero'], "train_unizero entry now only supports the following algo.: 'unizero'"

# Import the correct GameBuffer class based on the policy type
game_buffer_classes = {'unizero': 'UniZeroGameBuffer'}

GameBuffer = getattr(__import__('lzero.mcts', fromlist=[game_buffer_classes[create_cfg.policy.type]]),
game_buffer_classes[create_cfg.policy.type])

# Set device based on CUDA availability
cfg.policy.device = cfg.policy.model.world_model.device if torch.cuda.is_available() else 'cpu'
logging.info(f'cfg.policy.device: {cfg.policy.device}')

# Compile the configuration
cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)

# Create main components: env, policy
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])

collector_env.seed(cfg.seed)
evaluator_env.seed(cfg.seed, dynamic_seed=False)
set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])

# Load pretrained model if specified
if model_path is not None:
logging.info(f'Loading model from {model_path} begin...')
policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))
logging.info(f'Loading model from {model_path} end!')

# Create worker components: learner, collector, evaluator, replay buffer, commander
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)

# MCTS+RL algorithms related core code
policy_config = cfg.policy
replay_buffer = GameBuffer(policy_config)
collector = Collector(env=collector_env, policy=policy.collect_mode, tb_logger=tb_logger, exp_name=cfg.exp_name,
policy_config=policy_config)
evaluator = Evaluator(eval_freq=cfg.policy.eval_freq, n_evaluator_episode=cfg.env.n_evaluator_episode,
stop_value=cfg.env.stop_value, env=evaluator_env, policy=policy.eval_mode,
tb_logger=tb_logger, exp_name=cfg.exp_name, policy_config=policy_config)

# Learner's before_run hook
learner.call_hook('before_run')

# Collect random data before training
if cfg.policy.random_collect_episode_num > 0:
random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer)

policy.last_batch_obs = initialize_zeros_batch(cfg.policy.model.observation_shape, len(evaluator_env_cfg),
cfg.policy.device)
policy.last_batch_action = [-1 for _ in range(len(evaluator_env_cfg))]
batch_size = policy._cfg.batch_size

# TODO: for visualize
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)

while True:
# Log buffer memory usage
log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)

# Set temperature for visit count distributions
collect_kwargs = {
'temperature': visit_count_temperature(
policy_config.manual_temperature_decay,
policy_config.fixed_temperature_value,
policy_config.threshold_training_steps_for_final_temperature,
trained_steps=learner.train_iter
),
'epsilon': 0.0 # Default epsilon value
}

# Configure epsilon for epsilon-greedy exploration
if policy_config.eps.eps_greedy_exploration_in_collect:
epsilon_greedy_fn = get_epsilon_greedy_fn(
start=policy_config.eps.start,
end=policy_config.eps.end,
decay=policy_config.eps.decay,
type_=policy_config.eps.type
)
collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)

# Evaluate policy performance
if evaluator.should_eval(learner.train_iter):
policy.last_batch_obs = initialize_zeros_batch(
cfg.policy.model.observation_shape, len(evaluator_env_cfg), cfg.policy.device
)
policy.last_batch_action = [-1] * len(evaluator_env_cfg)
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
if stop:
break

# Collect new data
policy.last_batch_obs = initialize_zeros_batch(
cfg.policy.model.observation_shape, len(collector_env_cfg), cfg.policy.device
)
policy.last_batch_action = [-1] * len(collector_env_cfg)
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)

# Determine updates per collection
update_per_collect = cfg.policy.update_per_collect
if update_per_collect is None:
collected_transitions_num = sum(len(game_segment) for game_segment in new_data[0])
update_per_collect = int(collected_transitions_num * cfg.policy.model_update_ratio)

# Update replay buffer
replay_buffer.push_game_segments(new_data)
replay_buffer.remove_oldest_data_to_fit()

# Train the policy if sufficient data is available
if collector.envstep > cfg.policy.train_start_after_envsteps:
# data_sufficient = replay_buffer.get_num_of_game_segments() > batch_size # For 'episode' sample type
data_sufficient = replay_buffer.get_num_of_transitions() > batch_size
if not data_sufficient:
logging.warning(
f'The data in replay_buffer is not sufficient to sample a mini-batch: '
f'batch_size: {batch_size}, replay_buffer: {replay_buffer}. Continue to collect now ....'
)
continue

for i in range(update_per_collect):
if data_sufficient:
train_data = replay_buffer.sample(batch_size, policy)
if cfg.policy.reanalyze_ratio > 0 and i % 20 == 0:
# Clear caches
for model in [policy._collect_model, policy._target_model]:
# model.world_model.precompute_pos_emb_diff_kv() # TODO
model.world_model.clear_caches()

torch.cuda.empty_cache()

train_data.append({'train_which_component': 'transformer'})
log_vars = learner.train(train_data, collector.envstep)

if cfg.policy.use_priority:
replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])

# Clear caches and precompute positional embedding matrices
Review comment (Member): move this part to the __del__ method of world_model

Reply (Collaborator, Author): Because it needs to be called once every train epoch, I added a new recompute_pos_emb_diff_and_clear_cache() method in unizero for this.

for model in [policy._collect_model, policy._target_model]:
model.world_model.precompute_pos_emb_diff_kv()
model.world_model.clear_caches()

torch.cuda.empty_cache()

# Check stopping criteria
if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
break

learner.call_hook('after_run')
return policy
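To make the new entry concrete, here is a minimal launch sketch. The keyword arguments follow the signature of `train_unizero` above; the config module path and its `main_config`/`create_config` names are assumptions based on the usual LightZero zoo layout rather than something this diff shows.

```python
# Hypothetical launcher; the config module path is an assumption, the rest follows the signature above.
from lzero.entry import train_unizero
from zoo.atari.config.atari_unizero_config import main_config, create_config

if __name__ == "__main__":
    train_unizero(
        [main_config, create_config],  # input_cfg: [user_config, create_cfg]
        seed=0,
        model_path=None,               # e.g. 'exp_name/ckpt/ckpt_best.pth.tar' to start from a checkpoint
        max_env_step=int(1e6),
    )
```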
22 changes: 22 additions & 0 deletions lzero/entry/utils.py
@@ -4,7 +4,29 @@
import psutil
from pympler.asizeof import asizeof
from tensorboardX import SummaryWriter
from typing import Optional, Callable
import torch


def initialize_zeros_batch(observation_shape, batch_size, device):
"""
Overview:
Initialize a zeros tensor for batch observations based on the shape. This function is used to initialize the UniZero model input.
Arguments:
- observation_shape (:obj:`Union[int, List[int]]`): The shape of the observation tensor.
- batch_size (:obj:`int`): The batch size.
- device (:obj:`str`): The device to store the tensor.
Returns:
- zeros (:obj:`torch.Tensor`): The zeros tensor.
"""
if isinstance(observation_shape, list):
shape = [batch_size, *observation_shape]
elif isinstance(observation_shape, int):
shape = [batch_size, observation_shape]
else:
raise TypeError("observation_shape must be either an int or a list")

return torch.zeros(shape).to(device)

def random_collect(
policy_cfg: 'EasyDict', # noqa
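A quick sanity-check sketch of `initialize_zeros_batch` as defined above; the observation shapes are illustrative only.

```python
import torch

from lzero.entry.utils import initialize_zeros_batch

# List-shaped observations (e.g. image frames): result has shape (batch_size, *observation_shape).
obs = initialize_zeros_batch([3, 64, 64], batch_size=8, device='cpu')
assert obs.shape == torch.Size([8, 3, 64, 64])

# Int-shaped observations (e.g. a 4-dim vector): result has shape (batch_size, observation_shape).
vec = initialize_zeros_batch(4, batch_size=8, device='cpu')
assert vec.shape == torch.Size([8, 4])
```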
1 change: 1 addition & 0 deletions lzero/mcts/buffer/__init__.py
@@ -1,4 +1,5 @@
from .game_buffer_muzero import MuZeroGameBuffer
from .game_buffer_unizero import UniZeroGameBuffer
from .game_buffer_efficientzero import EfficientZeroGameBuffer
from .game_buffer_sampled_efficientzero import SampledEfficientZeroGameBuffer
from .game_buffer_gumbel_muzero import GumbelMuZeroGameBuffer