Skip to content

Commit

Permalink
simplified code
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangkaihuo committed Sep 26, 2023
1 parent 4d41af7 commit 5422fd7
Showing 1 changed file with 5 additions and 9 deletions.
14 changes: 5 additions & 9 deletions bmtrain/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,12 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
for name, param in self._parameters.items():
if param is not None:
if isinstance(param, DistributedParameter):#and not param._in_block:
if config['save_param_to_cpu']:
if param._in_block:
destination[prefix + name] = param.tp_gather().detach().cpu() # sync operation
else:
destination[prefix + name] = param.gather_all().detach().cpu() # sync operation
if param._in_block:
destination[prefix + name] = param.tp_gather().detach() # sync operation
else:
if param._in_block:
destination[prefix + name] = param.tp_gather().detach() # sync operation
else:
destination[prefix + name] = param.gather_all().detach() # sync operation
destination[prefix + name] = param.gather_all().detach() # sync operation
if config['save_param_to_cpu']:
destination[prefix + name] = destination[prefix + name].cpu()
else:
if config['save_param_to_cpu']:
destination[prefix + name] = param if keep_vars else param.detach().cpu()
Expand Down

0 comments on commit 5422fd7

Please sign in to comment.