Skip to content

Commit

Permalink
add _save_to_infer_model
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangkaihuo committed Sep 25, 2023
1 parent e834efd commit 01d5be0
Showing 1 changed file with 24 additions and 14 deletions.
38 changes: 24 additions & 14 deletions bmtrain/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,6 @@ def _save_to_local_rank0(model : torch.nn.Module, destination=None, prefix=''):
destination = hook_result
return destination

# model: e.g. a GPTModel whose submodules may include TransformerBlockList
def _transform_state_dict_to_libcpm(model : torch.nn.Module, libcpm, destination=None, prefix=''):
    """Stream ``model``'s parameters into ``libcpm`` layer by layer.

    Walks the direct children of ``model``. A ``TransformerBlockList`` child is
    recursed into per layer; every other child has its (local-rank-0 gathered)
    state dict pushed to ``libcpm`` via ``load_layer_state_dict``.

    Args:
        model: module whose parameters are being exported.
        libcpm: inference-engine handle exposing ``load_layer_state_dict``.
        destination: optional dict threaded through to ``_save_to_local_rank0``.
        prefix: dotted name prefix accumulated along the module tree.

    Returns:
        None — parameters are handed off to ``libcpm`` as a side effect;
        only ``config['local_rank'] == 0`` performs the actual load.
    """
    # Gathered tensors must stay on device for the libcpm hand-off.
    config['save_param_to_cpu'] = False
    for name, module in model._modules.items():
        # BUG FIX: the original tested `isinstance(model, ...)` and iterated
        # `model` here, i.e. the parent instead of the current child — the
        # recursion never fired for TransformerBlockList children.
        if isinstance(module, TransformerBlockList):
            for layer in module:
                _transform_state_dict_to_libcpm(layer, libcpm, destination, prefix + name + '.')
        layer_state_dict = _save_to_local_rank0(module, destination, prefix + name + '.')
        # Only the designated rank talks to libcpm; other ranks just
        # participate in the gather inside _save_to_local_rank0.
        if config['local_rank'] == 0:
            libcpm.load_layer_state_dict(layer_state_dict)
    print("after transform state_dict to libcpm")

def _save_to_rank0(model : torch.nn.Module, destination=None, prefix=''):
if destination is None:
Expand All @@ -69,6 +55,30 @@ def _save_to_rank0(model : torch.nn.Module, destination=None, prefix=''):
else:
model._save_to_state_dict(destination, prefix, False)
return destination

def _save_to_infer_model(model : torch.nn.Module, infer_model, destination=None, prefix=''):
    """Recursively export ``model``'s state dict into ``infer_model``.

    Mirrors ``torch.nn.Module.state_dict`` traversal: saves this module's own
    parameters, then recurses into children. ``TransformerBlockList`` children
    are special-cased — each layer's state dict is gathered on local rank 0
    and loaded into ``infer_model`` immediately (one layer at a time), instead
    of being accumulated into ``destination``.

    Args:
        model: module to export.
        infer_model: inference-engine handle exposing ``load_layer_state_dict``.
        destination: accumulator dict; created (with torch-style ``_metadata``)
            on the outermost call and shared by recursive calls.
        prefix: dotted name prefix for keys, e.g. ``'encoder.'``.
    """
    # Keep gathered tensors on device; the load below consumes them directly.
    config['save_param_to_cpu'] = False
    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()
    # Same versioned-metadata bookkeeping torch.nn.Module.state_dict performs.
    destination._metadata[prefix[:-1]] = local_metadata = dict(version=model._version)
    _save_to_state_dict(model, config['local_rank'], destination, prefix)
    for name, module in model._modules.items():
        if module is not None:
            if isinstance(module, TransformerBlockList):
                # Per-layer streaming: gather one block's weights to local
                # rank 0 and hand them to the inference model right away,
                # bounding peak memory to a single layer.
                for local_name, local_module in module._modules.items():
                    local_state_dict = _save_to_local_rank0(local_module, None, prefix + name + "." + local_name + '.')
                    if config['local_rank'] == 0:
                        infer_model.load_layer_state_dict(local_state_dict)
            else:
                # Ordinary child: keep accumulating into the shared destination.
                _save_to_infer_model(module, infer_model, destination, prefix + name + '.')
    # Honor any registered state-dict post-hooks, as torch does.
    for hook in model._state_dict_hooks.values():
        hook_result = hook(model, destination, prefix, local_metadata)
        if hook_result is not None:
            destination = hook_result

    # NOTE(review): this runs at every recursion depth with the *shared,
    # partially-filled* destination, so load_layer_state_dict is called
    # repeatedly with growing snapshots of the same dict — presumably the
    # load is idempotent / last-write-wins; confirm, or guard so only the
    # outermost call (destination created here) performs the final load.
    if config['local_rank'] == 0:
        infer_model.load_layer_state_dict(destination)



Expand Down

0 comments on commit 01d5be0

Please sign in to comment.