diff --git a/lzero/entry/train_unizero.py b/lzero/entry/train_unizero.py
index edd67d6f8..715e2df51 100644
--- a/lzero/entry/train_unizero.py
+++ b/lzero/entry/train_unizero.py
@@ -239,7 +239,9 @@ def train_unizero(
                 if cfg.policy.use_wandb:
                     policy.set_train_iter_env_step(learner.train_iter, collector.envstep)

-                train_data.append({'train_which_component': 'transformer'})
+                # train_data.append({'train_which_component': 'transformer'})
+                train_data.append(learner.train_iter)
+
                 log_vars = learner.train(train_data, collector.envstep)
                 if cfg.policy.use_priority:
                     replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])
diff --git a/lzero/model/common.py b/lzero/model/common.py
index 53b61e185..dc8e44160 100644
--- a/lzero/model/common.py
+++ b/lzero/model/common.py
@@ -323,8 +323,8 @@ def __init__(self, url: str = 'google-bert/bert-base-uncased', embedding_size: i

         self.sim_norm = SimNorm(simnorm_dim=group_size)

-    # def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
-    def forward(self, x: torch.Tensor, no_grad: bool = False) -> torch.Tensor:  # TODO ======
+    def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:  # TODO: train projection ======
+    # def forward(self, x: torch.Tensor, no_grad: bool = False) -> torch.Tensor:  # TODO: train encoder ======
         """
         Forward pass to obtain the language representation of the input sequence.
diff --git a/lzero/policy/unizero.py b/lzero/policy/unizero.py
index 90af62ec0..88028e856 100644
--- a/lzero/policy/unizero.py
+++ b/lzero/policy/unizero.py
@@ -220,6 +220,8 @@ class UniZeroPolicy(MuZeroPolicy):
         policy_loss_weight=1,
         # (float) The weight of ssl (self-supervised learning) loss.
         ssl_loss_weight=0,
+        # (bool) Whether to use the cosine learning rate decay.
+        cos_lr_scheduler=False,
         # (bool) Whether to use piecewise constant learning rate decay.
         # i.e. lr: 0.2 -> 0.02 -> 0.002
         piecewise_decay_lr_scheduler=False,
@@ -300,6 +302,11 @@ def _init_learn(self) -> None:
             betas=(0.9, 0.95),
         )

+        if self._cfg.cos_lr_scheduler:
+            from torch.optim.lr_scheduler import CosineAnnealingLR
+            # TODO: check the total training steps
+            self.lr_scheduler = CosineAnnealingLR(self._optimizer_world_model, 1e5, eta_min=0, last_epoch=-1)
+
         # use model_wrapper for specialized demands of different modes
         self._target_model = copy.deepcopy(self._model)
         # Ensure that the installed torch version is greater than or equal to 2.0
@@ -335,6 +342,9 @@ def _init_learn(self) -> None:
             # TODO: add the model to wandb
             wandb.watch(self._learn_model.representation_network, log="all")

+        # TODO: ========
+        self.accumulation_steps = 4  # number of gradient accumulation steps
+
     # @profile
     def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
         """
@@ -352,7 +362,9 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
         self._learn_model.train()
         self._target_model.train()

-        current_batch, target_batch, _ = data
+        # current_batch, target_batch, _ = data
+        current_batch, target_batch, train_iter = data
+
         obs_batch_ori, action_batch, target_action_batch, mask_batch, indices, weights, make_time = current_batch
         target_reward, target_value, target_policy = target_batch
@@ -386,14 +398,14 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
         # print(f'transformed_target_value:{transformed_target_value}')
         # print("self.value_support:", self.value_support)

-        try:
-            target_value_categorical = phi_transform(self.value_support, transformed_target_value)
-        except Exception as e:
-            print('='*20)
-            print(e)
-            # print(f'transformed_target_value:{transformed_target_value}')
-            # print("self.value_support:", self.value_support)
-            print('='*20)
+        # try:
+        target_value_categorical = phi_transform(self.value_support, transformed_target_value)
+        # except Exception as e:
+        #     print('='*20)
+        #     print(e)
+        #     # print(f'transformed_target_value:{transformed_target_value}')
+        #     # print("self.value_support:", self.value_support)
+        #     print('='*20)

         # target_value_categorical = phi_transform(self.value_support, transformed_target_value)
@@ -455,7 +467,12 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
         # assert not torch.isinf(losses.loss_total).any(), "Loss contains Inf values"

         # Core learn model update step
-        self._optimizer_world_model.zero_grad()
+        if train_iter % self.accumulation_steps == 0:  # update the parameters once every accumulation_steps steps
+            # print(f'train_iter:{train_iter}')
+            self._optimizer_world_model.zero_grad()
+
+        weighted_total_loss = weighted_total_loss / self.accumulation_steps  # scale the loss when accumulating gradients
+
         weighted_total_loss.backward()

         # ========== for debugging ==========
@@ -471,15 +488,23 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
         total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_(self._learn_model.world_model.parameters(),
                                                                         self._cfg.grad_clip_value)

-        if self._cfg.multi_gpu:
-            self.sync_gradients(self._learn_model)
-        self._optimizer_world_model.step()
-        if self._cfg.piecewise_decay_lr_scheduler:
-            self.lr_scheduler.step()
+        if train_iter % self.accumulation_steps == 0:  # update the parameters once every accumulation_steps steps
+            # print(f'pos 2 train_iter:{train_iter}')
+
+            if self._cfg.multi_gpu:
+                self.sync_gradients(self._learn_model)
+
+            self._optimizer_world_model.step()
+
+            if self._cfg.cos_lr_scheduler or self._cfg.piecewise_decay_lr_scheduler:
+                self.lr_scheduler.step()
+
+            # Core target model update step
+            self._target_model.update(self._learn_model.state_dict())

-        # Core target model update step
-        self._target_model.update(self._learn_model.state_dict())
+        if self.accumulation_steps > 1:
+            torch.cuda.empty_cache()

         if torch.cuda.is_available():
             torch.cuda.synchronize()
diff --git a/zoo/jericho/configs/jericho_ppo_config.py b/zoo/jericho/configs/jericho_ppo_config.py
new file mode 100644
index 000000000..cdbd619b0
--- /dev/null
+++ b/zoo/jericho/configs/jericho_ppo_config.py
@@ -0,0 +1,102 @@
+from easydict import EasyDict
+import torch.nn as nn
+
+action_space_size = 10
+max_steps = 50
+model_name = 'BAAI/bge-base-en-v1.5'
+env_id = 'detective.z5'
+evaluator_env_num = 2
+
+# proj train
+collector_env_num = 4
+batch_size = 32
+
+# all train
+# collector_env_num = 2
+# n_episode = 2
+# evaluator_env_num = 2
+# batch_size = 4
+# num_unroll_steps = 5
+# infer_context_length = 2
+jericho_ppo_config = dict(
+    # exp_name=f"data_ppo_detective/jericho_ppo_projtrain_bs{batch_size}_seed0",
+    exp_name=f"data_ppo_detective/jericho_add-loc-inv_ppo_projtrain_bs{batch_size}_seed0",
+    env=dict(
+        remove_stuck_actions=False,
+        # remove_stuck_actions=True,
+        # add_location_and_inventory=True,
+        add_location_and_inventory=False,
+
+        stop_value=int(1e6),
+        observation_shape=512,
+        max_steps=max_steps,
+        max_action_num=action_space_size,
+        tokenizer_path=model_name,
+        # tokenizer_path="/mnt/afs/zhangshenghan/.cache/huggingface/hub/models--google-bert--bert-base-uncased/snapshots/86b5e0934494bd15c9632b12f734a8a67f723594",
+        max_seq_len=512,
+        # game_path="z-machine-games-master/jericho-game-suite/" + env_id,
+        game_path="/mnt/afs/niuyazhe/code/LightZero/zoo/jericho/envs/z-machine-games-master/jericho-game-suite/" + env_id,
+        collector_env_num=collector_env_num,
+        evaluator_env_num=evaluator_env_num,
+        n_evaluator_episode=evaluator_env_num,
+        manager=dict(shared_memory=False, )
+    ),
+    policy=dict(
+        cuda=True,
+        multi_agent=True,
+        action_space='discrete',
+        model=dict(
+            obs_shape=(26, 5, 4),  # has no effect here
+            action_shape=action_space_size,
+            action_space='discrete',
+            encoder_hidden_size_list=[512],  # encoder_hidden_size_list[-1] is the input dimension of the heads
+            actor_head_hidden_size=512,
+            critic_head_hidden_size=512,
+        ),
+        learn=dict(
+            epoch_per_collect=4,
+            batch_size=batch_size,
+            learning_rate=0.0005,
+            # entropy_weight=0.01,
+            entropy_weight=0.05,
+            value_norm=True,
+            grad_clip_value=10,
+        ),
+        collect=dict(
+            # n_sample=1024,
+            n_sample=320,  # TODO: DEBUG
+            discount_factor=0.99,
+            gae_lambda=0.95,
+        ),
+        eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=5000, )),
+    ),
+)
+jericho_ppo_config = EasyDict(jericho_ppo_config)
+main_config = jericho_ppo_config
+cartpole_ppo_create_config = dict(
+    env=dict(
+        type='jericho',
+        import_names=['zoo.jericho.envs.jericho_env'],
+    ),
+    # env_manager=dict(type='subprocess'),
+    env_manager=dict(type='base'),
+    policy=dict(type='ppo'),
+)
+cartpole_ppo_create_config = EasyDict(cartpole_ppo_create_config)
+create_config = cartpole_ppo_create_config
+
+
+if __name__ == "__main__":
+    from ding.entry import serial_pipeline_onpolicy
+    from ding.model.template import VAC
+    m = main_config.policy.model
+    import os
+    os.environ['TOKENIZERS_PARALLELISM'] = 'false'
+
+    from lzero.model.common import HFLanguageRepresentationNetwork
+    encoder = HFLanguageRepresentationNetwork(url=model_name, embedding_size=512)
+
+    model = VAC(obs_shape=m.obs_shape, action_shape=m.action_shape, action_space=m.action_space, encoder_hidden_size_list=m.encoder_hidden_size_list,
+                actor_head_hidden_size=m.actor_head_hidden_size,
+                critic_head_hidden_size=m.critic_head_hidden_size, encoder=encoder)
+    serial_pipeline_onpolicy([main_config, create_config], seed=0, model=model)
diff --git a/zoo/jericho/configs/jericho_unizero_config.py b/zoo/jericho/configs/jericho_unizero_config.py
index a7dca5678..651afdd53 100644
--- a/zoo/jericho/configs/jericho_unizero_config.py
+++ b/zoo/jericho/configs/jericho_unizero_config.py
@@ -18,18 +18,20 @@ def main(env_id='detective.z5', seed=0):
     num_simulations = 50
     max_env_step = int(10e6)

-    # collector_env_num = 4
-    # n_episode = 4
-    # batch_size = 16  # proj train
-    # num_unroll_steps = 10
-    # infer_context_length = 4
-
-    collector_env_num = 2
-    n_episode = 2
-    evaluator_env_num = 2
-    batch_size = 4  # all train
-    num_unroll_steps = 5
-    infer_context_length = 2
+    # proj train
+    collector_env_num = 4
+    n_episode = 4
+    batch_size = 16
+    num_unroll_steps = 10
+    infer_context_length = 4
+
+    # all train
+    # collector_env_num = 2
+    # n_episode = 2
+    # evaluator_env_num = 2
+    # batch_size = 4
+    # num_unroll_steps = 5
+    # infer_context_length = 2

     # batch_size = 16
     # num_unroll_steps = 5
@@ -52,19 +54,14 @@ def main(env_id='detective.z5', seed=0):
     # model_name = 'google-bert/bert-base-uncased'

     # =========== TODO: only for debug ===========
-    # collector_env_num = 2
-    # num_segments = 2
-    # game_segment_length = 20
-    # evaluator_env_num = 2
     # max_env_step = int(5e5)
     # batch_size = 10
-    # num_simulations = 5
+    # num_simulations = 2
     # num_unroll_steps = 5
     # infer_context_length = 2
     # max_steps = 10
     # num_layers = 1
     # replay_ratio = 0.05
-    # embed_dim = 768

     # TODO: the action_space inside MCTS is restricted to the legal actions of the root node
     # ==============================================================
@@ -72,7 +69,9 @@ def main(env_id='detective.z5', seed=0):
     # ==============================================================
     jericho_unizero_config = dict(
         env=dict(
-            remove_stuck_actions=False,
+            # remove_stuck_actions=False,
+            remove_stuck_actions=True,
+
             stop_value=int(1e6),
             observation_shape=512,
             max_steps=max_steps,
@@ -168,7 +167,10 @@ def main(env_id='detective.z5', seed=0):

     main_config = jericho_unizero_config
     create_config = jericho_unizero_create_config
-    main_config.exp_name = f'data_unizero_detective_20250107/{model_name}/{env_id[:8]}_ms{max_steps}_action-space-{action_space_size}_all-train_uz_nlayer{num_layers}_embed512_rr{replay_ratio}-upc{update_per_collect}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
+    main_config.exp_name = f'data_unizero_detective_20250107/{model_name}/{env_id[:8]}_ms{max_steps}_action-space-{action_space_size}-remove-novalid_proj-train-accstep4_uz_nlayer{num_layers}_embed512_rr{replay_ratio}-upc{update_per_collect}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
+    # main_config.exp_name = f'data_unizero_detective_20250107/{model_name}/{env_id[:8]}_ms{max_steps}_action-space-{action_space_size}_proj-train-accstep4_uz_nlayer{num_layers}_embed512_rr{replay_ratio}-upc{update_per_collect}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
+
+    # main_config.exp_name = f'data_unizero_detective_20250107/{model_name}/{env_id[:8]}_ms{max_steps}_action-space-{action_space_size}_all-train_uz_nlayer{num_layers}_embed512_rr{replay_ratio}-upc{update_per_collect}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
     # main_config.exp_name = f'data_unizero_detective_20250107/{model_name}/{env_id[:8]}_ms{max_steps}_action-space-{action_space_size}_remove-novalid-action_uz_nlayer{num_layers}_embed512_rr{replay_ratio}-upc{update_per_collect}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
     from lzero.entry import train_unizero
     train_unizero([main_config, create_config], seed=seed,
diff --git a/zoo/jericho/envs/jericho_env.py b/zoo/jericho/envs/jericho_env.py
index a17dd90d4..8266e9520 100644
--- a/zoo/jericho/envs/jericho_env.py
+++ b/zoo/jericho/envs/jericho_env.py
@@ -37,6 +37,7 @@ def __init__(self, cfg):

         # New: whether to enable removing invalid (stuck) actions
         self.remove_stuck_actions = cfg.get('remove_stuck_actions', False)
+        self.add_location_and_inventory = cfg.get('add_location_and_inventory', False)

         if JerichoEnv.tokenizer is None:
             # Only rank 0 downloads the model
@@ -48,7 +49,7 @@ def __init__(self, cfg):
             if self.rank != 0:
                 # Non-rank-0 processes load from the local cache
                 JerichoEnv.tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer_path)
-        self._env = FrotzEnv(self.game_path)
+        self._env = FrotzEnv(self.game_path, 0)
         self._action_list = None
         self.finished = False
         self._init_flag = False
@@ -69,12 +70,22 @@ def prepare_obs(self, obs, return_str: bool = False):
        # Adjust the available action list according to whether removing invalid actions is enabled
         if self.remove_stuck_actions:
             available_actions = [a for a in self._action_list if a not in self.blocked_actions]
+            if len(available_actions) < 1 and len(self._action_list) > 0:
+                # TODO=========
+                # import ipdb;ipdb.set_trace()
+                # if self._action_list is None
+                available_actions = [self._action_list[0]]
             self._action_list = available_actions
         else:
             available_actions = self._action_list
-
-        full_obs = obs + "\nValid actions: " + str(available_actions)
+        # import ipdb;ipdb.set_trace()
+        if self.add_location_and_inventory:
+            look = self._env.get_player_location()
+            inv = self._env.get_inventory()
+            full_obs = "Location: " + str(look) + "\nInventory: " + str(inv) + obs + "\nValid actions: " + str(available_actions)
+        else:
+            full_obs = obs + "\nValid actions: " + str(available_actions)

         if not return_str:
             full_obs = JerichoEnv.tokenizer(
@@ -90,10 +101,13 @@ def prepare_obs(self, obs, return_str: bool = False):

         action_mask = np.array(action_mask, dtype=np.int8)

-        if return_str:
-            return {'observation': full_obs, 'action_mask': action_mask, 'to_play': -1}
+        if return_str:  # TODO===============
+            # return {'observation': full_obs, 'action_mask': action_mask, 'to_play': -1}
+            return {'observation': full_obs, 'action_mask': action_mask}
         else:
-            return {'observation': full_obs, 'obs_attn_mask': obs_attn_mask, 'action_mask': action_mask, 'to_play': -1}
+            # return {'observation': full_obs, 'obs_attn_mask': obs_attn_mask, 'action_mask': action_mask, 'to_play': -1}
+            return {'observation': full_obs, 'obs_attn_mask': obs_attn_mask, 'action_mask': action_mask}
+

     def reset(self, return_str: bool = False):
         initial_observation, info = self._env.reset()
@@ -102,6 +116,7 @@ def reset(self, return_str: bool = False):
         self._action_list = None
         self.episode_return = 0
         self.env_step = 0
+        self.timestep = 0

         # Set the initial last_observation
         if self.remove_stuck_actions:
@@ -129,19 +144,28 @@ def __repr__(self) -> str:
         return "LightZero Jericho Env"

     def step(self, action: int, return_str: bool = False):
+        # import ipdb;ipdb.set_trace()
+        # print(action)
         self.blocked_actions = set()
         if isinstance(action, str):
             action_str = action
         else:
+            if isinstance(action, np.ndarray):
+                action = int(action)
             try:
                 action_str = self._action_list[action]
             except Exception as e:
                 # TODO: why do illegal actions occur?
                 print('='*20)
                 print(e, f'rank {self.rank}, action {action} is illegal now we randomly choose a legal action from {self._action_list}!')
-                action = np.random.choice(len(self._action_list))
-                action_str = self._action_list[action]
+
+                if len(self._action_list) > 0:
+                    action = np.random.choice(len(self._action_list))
+                    action_str = self._action_list[action]
+                else:
+                    action_str = 'go'
+                    print(f'rank {self.rank}, len(self._action_list) == 0, self._env.get_valid_actions():{self._env.get_valid_actions()}')

        # Record the previous observation
         if self.remove_stuck_actions and self.last_observation is not None:
@@ -151,6 +175,13 @@ def step(self, action: int, return_str: bool = False):

         # Execute the action
         observation, reward, done, info = self._env.step(action_str)
+
+        self.timestep += 1
+        # print(f'step: {self.timestep}, [OBS]:{observation} self._action_list:{self._action_list}')
+
+        # TODO: for PPO
+        reward = np.array([float(reward)])
+
         self.env_step += 1
         self.episode_return += reward
         self._action_list = None
@@ -160,7 +191,7 @@ def step(self, action: int, return_str: bool = False):
             if observation == previous_obs:
                 # The action had no effect, so remove it
                 self.blocked_actions.add(action_str)
-                print(f'[Removing action] "{action_str}" as it did not change the observation.')
+                # print(f'[Removing action] "{action_str}" as it did not change the observation.')

        # Update the previous observation
         if self.remove_stuck_actions:
@@ -211,7 +242,8 @@ def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
             max_env_step=100,
             tokenizer_path="google-bert/bert-base-uncased",
             max_seq_len=512,
-            remove_stuck_actions=True  # enable removing invalid actions
+            remove_stuck_actions=True,  # enable removing invalid actions
+            add_location_and_inventory=True
         )
     )
     env = JerichoEnv(env_cfg)
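
For reference, below is a minimal standalone sketch of the gradient-accumulation update pattern that the unizero.py changes above introduce. It is written in the conventional form (zero the gradients at the start of each window, step at its end); the model, optimizer, loss and data here are placeholders, not LightZero components.

import torch
import torch.nn as nn

# Placeholder model/optimizer/loss; in the patch these roles are played by the
# world model, self._optimizer_world_model and weighted_total_loss.
model = nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
loss_fn = nn.MSELoss()
accumulation_steps = 4  # same value the patch hard-codes in _init_learn

optimizer.zero_grad()
for train_iter in range(16):
    x, y = torch.randn(32, 8), torch.randn(32, 1)
    # Scale the loss so the accumulated gradient matches one large-batch update.
    loss = loss_fn(model(x), y) / accumulation_steps
    loss.backward()
    if (train_iter + 1) % accumulation_steps == 0:
        # Apply the accumulated gradients once per window, then reset them.
        optimizer.step()
        optimizer.zero_grad()

Note that the patch instead gates both zero_grad() and step() on train_iter % self.accumulation_steps == 0, with train_iter forwarded from the learner through train_data.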