add _save_to_infer_model (#170)
zkh2016 authored Sep 27, 2023
1 parent 95417b5 commit 25e3671
Showing 3 changed files with 55 additions and 7 deletions.
1 change: 1 addition & 0 deletions bmtrain/init.py
@@ -82,6 +82,7 @@ def init_distributed(
     config["zero_rank"] = config['topology'].get_group_rank("zero")
     config["tp_rank"] = config['topology'].get_group_rank("tp")
     config["tp_zero_rank"] = config['topology'].get_group_rank("tp_zero")
+    config["save_param_to_cpu"] = True
     cpus_this_worker = None
 
     all_available_cpus = sorted(list(os.sched_getaffinity(0)))
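
The new flag defaults to True, so checkpointing keeps its old behavior of copying gathered parameters to host memory inside state_dict(). A minimal sketch of toggling it, assuming an already-initialized bmtrain setup; MyModel is a hypothetical module built from bmt layers:

import bmtrain as bmt
from bmtrain.global_var import config

bmt.init_distributed()
model = MyModel()  # hypothetical bmt-wrapped module

cpu_state = model.state_dict()       # default: gathered tensors are copied to CPU

config["save_param_to_cpu"] = False  # keep gathered tensors on the GPU,
gpu_state = model.state_dict()       # e.g. to hand them to an inference runtime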
12 changes: 9 additions & 3 deletions bmtrain/layer.py
@@ -34,11 +34,17 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
         if param is not None:
             if isinstance(param, DistributedParameter):#and not param._in_block:
                 if param._in_block:
-                    destination[prefix + name] = param.tp_gather().detach().cpu() # sync operation
+                    destination[prefix + name] = param.tp_gather().detach() # sync operation
                 else:
-                    destination[prefix + name] = param.gather_all().detach().cpu() # sync operation
+                    destination[prefix + name] = param.gather_all().detach() # sync operation
+                if config['save_param_to_cpu']:
+                    destination[prefix + name] = destination[prefix + name].cpu()
             else:
-                destination[prefix + name] = param if keep_vars else param.detach().cpu()
+                if config['save_param_to_cpu']:
+                    destination[prefix + name] = param if keep_vars else param.detach().cpu()
+                else:
+                    destination[prefix + name] = param if keep_vars else param.detach()
+
     for name, buf in self._buffers.items():
         if buf is not None and name not in self._non_persistent_buffers_set:
             destination[prefix + name] = buf if keep_vars else buf.detach()
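
The rewrite separates the collective gather from device placement: the full tensor is always materialized first, and the copy to host happens only when config['save_param_to_cpu'] is set. A simplified standalone sketch of that pattern, where gather_full is a hypothetical stand-in for DistributedParameter.tp_gather()/gather_all():

import torch

def gather_full(param: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for tp_gather()/gather_all(), which in
    # bmtrain is a synchronous all-gather across ranks.
    return param

def save_param(destination: dict, key: str, param: torch.Tensor,
               save_param_to_cpu: bool) -> None:
    full = gather_full(param).detach()  # materialize the full tensor first
    if save_param_to_cpu:
        full = full.cpu()               # old behavior: always offload to host
    destination[key] = full             # new behavior: device follows the flag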
49 changes: 45 additions & 4 deletions bmtrain/store.py
@@ -3,31 +3,48 @@
 import torch
 
 from .pipe_layer import PipelineTransformerBlockList
+from .block_layer import TransformerBlockList
 from .global_var import config
 from .block_layer import Block
 from . import nccl
 import io, pickle
 from typing import Mapping
 
-def _save_to_state_dict(model : torch.nn.Module, destination, prefix):
+def _save_to_state_dict(model : torch.nn.Module, rank, destination, prefix):
     if isinstance(model, Block):
-        if config['rank'] != 0:
+        if rank != 0:
             destination = OrderedDict() # creates an temporary ordered dict
             destination._metadata = OrderedDict()
         model.state_dict(destination=destination, prefix=prefix, keep_vars=False)
     else:
-        if config['rank'] != 0:
+        if rank != 0:
             destination = OrderedDict() # creates an temporary ordered dict
             destination._metadata = OrderedDict()
         model._save_to_state_dict(destination, prefix, False)
 
+def _save_to_local_rank0(model : torch.nn.Module, destination=None, prefix=''):
+    if destination is None:
+        destination = OrderedDict()
+        destination._metadata = OrderedDict()
+    destination._metadata[prefix[:-1]] = local_metadata = dict(version=model._version)
+    _save_to_state_dict(model, config['local_rank'], destination, prefix)
+    for name, module in model._modules.items():
+        if module is not None:
+            _save_to_local_rank0(module, destination, prefix + name + '.')
+    for hook in model._state_dict_hooks.values():
+        hook_result = hook(model, destination, prefix, local_metadata)
+        if hook_result is not None:
+            destination = hook_result
+    return destination
+
+
 def _save_to_rank0(model : torch.nn.Module, destination=None, prefix=''):
     if destination is None:
         destination = OrderedDict()
         destination._metadata = OrderedDict()
     destination._metadata[prefix[:-1]] = local_metadata = dict(version=model._version)
     if not isinstance(model, PipelineTransformerBlockList):
-        _save_to_state_dict(model, destination, prefix)
+        _save_to_state_dict(model, config['rank'], destination, prefix)
     for name, module in model._modules.items():
         if module is not None:
             _save_to_rank0(module, destination, prefix + name + '.')
@@ -38,6 +55,30 @@ def _save_to_rank0(model : torch.nn.Module, destination=None, prefix=''):
     else:
         model._save_to_state_dict(destination, prefix, False)
     return destination
+
+def _save_to_infer_model(model : torch.nn.Module, infer_model, destination=None, prefix=''):
+    config['save_param_to_cpu'] = False
+    if destination is None:
+        destination = OrderedDict()
+        destination._metadata = OrderedDict()
+    destination._metadata[prefix[:-1]] = local_metadata = dict(version=model._version)
+    _save_to_state_dict(model, config['local_rank'], destination, prefix)
+    for name, module in model._modules.items():
+        if module is not None:
+            if isinstance(module, TransformerBlockList):
+                for local_name, local_module in module._modules.items():
+                    local_state_dict = _save_to_local_rank0(local_module, None, prefix + name + "." + local_name + '.')
+                    if config['local_rank'] == 0:
+                        infer_model.load_layer_state_dict(local_state_dict)
+            else:
+                _save_to_infer_model(module, infer_model, destination, prefix + name + '.')
+    for hook in model._state_dict_hooks.values():
+        hook_result = hook(model, destination, prefix, local_metadata)
+        if hook_result is not None:
+            destination = hook_result
+
+    if config['local_rank'] == 0:
+        infer_model.load_layer_state_dict(destination)
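
Taken together, _save_to_infer_model walks the module tree, gathers each layer of a TransformerBlockList on local rank 0 only (keeping tensors on GPU, since save_param_to_cpu is forced to False), and streams the layers into the inference model one at a time, so peak memory stays near one layer's worth of full parameters rather than a whole checkpoint. A hedged usage sketch; the two model builders are hypothetical, and the inference backend is assumed to expose the load_layer_state_dict method this commit calls:

import bmtrain as bmt
from bmtrain.store import _save_to_infer_model

bmt.init_distributed()

train_model = build_train_model()  # hypothetical: a bmt-wrapped training model
infer_model = build_infer_model()  # hypothetical: must provide load_layer_state_dict()

# Each TransformerBlockList layer is gathered on local rank 0 and handed
# to the inference model immediately, layer by layer, instead of first
# materializing a full CPU state dict.
_save_to_infer_model(train_model, infer_model)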
