diff --git a/bmtrain/init.py b/bmtrain/init.py
index cd304f89..b7224f94 100644
--- a/bmtrain/init.py
+++ b/bmtrain/init.py
@@ -82,6 +82,7 @@ def init_distributed(
     config["zero_rank"] = config['topology'].get_group_rank("zero")
     config["tp_rank"] = config['topology'].get_group_rank("tp")
     config["tp_zero_rank"] = config['topology'].get_group_rank("tp_zero")
+    config["save_param_to_cpu"] = True
 
     cpus_this_worker = None
     all_available_cpus = sorted(list(os.sched_getaffinity(0)))
diff --git a/bmtrain/layer.py b/bmtrain/layer.py
index cf46814b..e071e01b 100644
--- a/bmtrain/layer.py
+++ b/bmtrain/layer.py
@@ -34,11 +34,17 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
             if param is not None:
                 if isinstance(param, DistributedParameter):#and not param._in_block:
                     if param._in_block:
-                        destination[prefix + name] = param.tp_gather().detach().cpu() # sync operation
+                        destination[prefix + name] = param.tp_gather().detach() # sync operation
                     else:
-                        destination[prefix + name] = param.gather_all().detach().cpu() # sync operation
+                        destination[prefix + name] = param.gather_all().detach() # sync operation
+                    if config['save_param_to_cpu']:
+                        destination[prefix + name] = destination[prefix + name].cpu()
                 else:
-                    destination[prefix + name] = param if keep_vars else param.detach().cpu()
+                    if config['save_param_to_cpu']:
+                        destination[prefix + name] = param if keep_vars else param.detach().cpu()
+                    else:
+                        destination[prefix + name] = param if keep_vars else param.detach()
+
         for name, buf in self._buffers.items():
             if buf is not None and name not in self._non_persistent_buffers_set:
                 destination[prefix + name] = buf if keep_vars else buf.detach()
diff --git a/bmtrain/store.py b/bmtrain/store.py
index 88ed7305..254213bd 100644
--- a/bmtrain/store.py
+++ b/bmtrain/store.py
@@ -3,31 +3,48 @@
 import torch
 from .pipe_layer import PipelineTransformerBlockList
+from .block_layer import TransformerBlockList
 from .global_var import config
 from .block_layer import Block
 from . import nccl
 import io, pickle
 from typing import Mapping
 
-def _save_to_state_dict(model : torch.nn.Module, destination, prefix):
+def _save_to_state_dict(model : torch.nn.Module, rank, destination, prefix):
     if isinstance(model, Block):
-        if config['rank'] != 0:
+        if rank != 0:
             destination = OrderedDict() # creates an temporary ordered dict
             destination._metadata = OrderedDict()
         model.state_dict(destination=destination, prefix=prefix, keep_vars=False)
     else:
-        if config['rank'] != 0:
+        if rank != 0:
             destination = OrderedDict() # creates an temporary ordered dict
             destination._metadata = OrderedDict()
         model._save_to_state_dict(destination, prefix, False)
 
 
+def _save_to_local_rank0(model : torch.nn.Module, destination=None, prefix=''):
+    if destination is None:
+        destination = OrderedDict()
+        destination._metadata = OrderedDict()
+    destination._metadata[prefix[:-1]] = local_metadata = dict(version=model._version)
+    _save_to_state_dict(model, config['local_rank'], destination, prefix)
+    for name, module in model._modules.items():
+        if module is not None:
+            _save_to_local_rank0(module, destination, prefix + name + '.')
+    for hook in model._state_dict_hooks.values():
+        hook_result = hook(model, destination, prefix, local_metadata)
+        if hook_result is not None:
+            destination = hook_result
+    return destination
+
+
 def _save_to_rank0(model : torch.nn.Module, destination=None, prefix=''):
     if destination is None:
         destination = OrderedDict()
         destination._metadata = OrderedDict()
     destination._metadata[prefix[:-1]] = local_metadata = dict(version=model._version)
     if not isinstance(model, PipelineTransformerBlockList):
-        _save_to_state_dict(model, destination, prefix)
+        _save_to_state_dict(model, config['rank'], destination, prefix)
         for name, module in model._modules.items():
             if module is not None:
                 _save_to_rank0(module, destination, prefix + name + '.')
@@ -38,6 +55,30 @@ def _save_to_rank0(model : torch.nn.Module, destination=None, prefix=''):
     else:
         model._save_to_state_dict(destination, prefix, False)
     return destination
+
+def _save_to_infer_model(model : torch.nn.Module, infer_model, destination=None, prefix=''):
+    config['save_param_to_cpu'] = False
+    if destination is None:
+        destination = OrderedDict()
+        destination._metadata = OrderedDict()
+    destination._metadata[prefix[:-1]] = local_metadata = dict(version=model._version)
+    _save_to_state_dict(model, config['local_rank'], destination, prefix)
+    for name, module in model._modules.items():
+        if module is not None:
+            if isinstance(module, TransformerBlockList):
+                for local_name, local_module in module._modules.items():
+                    local_state_dict = _save_to_local_rank0(local_module, None, prefix + name + "." + local_name + '.')
+                    if config['local_rank'] == 0:
+                        infer_model.load_layer_state_dict(local_state_dict)
+            else:
+                _save_to_infer_model(module, infer_model, destination, prefix + name + '.')
+    for hook in model._state_dict_hooks.values():
+        hook_result = hook(model, destination, prefix, local_metadata)
+        if hook_result is not None:
+            destination = hook_result
+
+    if config['local_rank'] == 0:
+        infer_model.load_layer_state_dict(destination)
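
Usage sketch (illustration only, not part of the patch): the example below shows how the pieces added here might be driven. bmt.init_distributed, DistributedModule, DistributedParameter, and bmtrain.global_var.config are existing bmtrain APIs; MyModel and the _EchoInferModel stand-in for an inference engine are hypothetical, and the only interface assumed on the engine is the load_layer_state_dict(state_dict) method that _save_to_infer_model calls.

# Illustrative only: run under a distributed launcher (e.g. torchrun).
import torch
import bmtrain as bmt
from bmtrain.global_var import config
from bmtrain.store import _save_to_infer_model

bmt.init_distributed()  # with this patch, defaults config['save_param_to_cpu'] to True

class MyModel(bmt.DistributedModule):  # hypothetical toy model
    def __init__(self):
        super().__init__()
        self.weight = bmt.DistributedParameter(torch.empty(1024, 1024))

model = MyModel()

# Default path: state_dict() gathers parameters and copies them to host
# memory, as before. Clearing the new flag keeps gathered tensors on the GPU,
# skipping the device-to-host copy.
config['save_param_to_cpu'] = False
state = model.state_dict()  # tensors remain on the current CUDA device

# Streaming path: _save_to_infer_model gathers to *local* rank 0 (one copy
# per node, using config['local_rank'] instead of the global rank) and hands
# each TransformerBlockList layer over individually, so only one layer's
# full weights are materialized at a time.
class _EchoInferModel:  # hypothetical stand-in for a real inference engine
    def load_layer_state_dict(self, sd):
        print({k: tuple(v.shape) for k, v in sd.items()})

_save_to_infer_model(model, _EchoInferModel())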