diff --git a/bmtrain/store.py b/bmtrain/store.py
index c00e734f..254213bd 100644
--- a/bmtrain/store.py
+++ b/bmtrain/store.py
@@ -37,20 +37,6 @@ def _save_to_local_rank0(model : torch.nn.Module, destination=None, prefix=''):
             destination = hook_result
     return destination
 
-#model: GPTModel
-def _transform_state_dict_to_libcpm(model : torch.nn.Module, libcpm, destination=None, prefix=''):
-    config['save_param_to_cpu'] = False
-    #state_dict = OrderedDict()
-    for name, module in model._modules.items():
-        if isinstance(model, TransformerBlockList):
-            for layer in model:
-                _transform_state_dict_to_libcpm(layer, libcpm, destination, prefix)
-        layer_state_dict = _save_to_local_rank0(module, destination, prefix + name + '.')
-        #state_dict.update(layer_state_dict)
-        if config['local_rank'] == 0:
-            libcpm.load_layer_state_dict(layer_state_dict)
-    #return state_dict
-    print("after transform state_dict to libcpm")
 
 def _save_to_rank0(model : torch.nn.Module, destination=None, prefix=''):
     if destination is None:
@@ -69,6 +55,30 @@ def _save_to_rank0(model : torch.nn.Module, destination=None, prefix=''):
     else:
         model._save_to_state_dict(destination, prefix, False)
     return destination
+
+def _save_to_infer_model(model : torch.nn.Module, infer_model, destination=None, prefix=''):
+    config['save_param_to_cpu'] = False
+    if destination is None:
+        destination = OrderedDict()
+        destination._metadata = OrderedDict()
+    destination._metadata[prefix[:-1]] = local_metadata = dict(version=model._version)
+    _save_to_state_dict(model, config['local_rank'], destination, prefix)
+    for name, module in model._modules.items():
+        if module is not None:
+            if isinstance(module, TransformerBlockList):
+                for local_name, local_module in module._modules.items():
+                    local_state_dict = _save_to_local_rank0(local_module, None, prefix + name + "." + local_name + '.')
+                    if config['local_rank'] == 0:
+                        infer_model.load_layer_state_dict(local_state_dict)
+            else:
+                _save_to_infer_model(module, infer_model, destination, prefix + name + '.')
+    for hook in model._state_dict_hooks.values():
+        hook_result = hook(model, destination, prefix, local_metadata)
+        if hook_result is not None:
+            destination = hook_result
+
+    if config['local_rank'] == 0:
+        infer_model.load_layer_state_dict(destination)