Support for unsharded parameters in state_dict APIs (#2023)
RdoubleA authored Nov 19, 2024
1 parent 1814feb commit 352cf4b
Showing 1 changed file with 39 additions and 30 deletions.
torchtune/training/_distributed.py (39 additions & 30 deletions)
@@ -207,6 +207,9 @@ def load_from_full_model_state_dict(
                 requires_grad=sharded_meta_param.requires_grad,
             )
 
+        elif not hasattr(sharded_meta_param, "device_mesh"):
+            # In cases where parts of the model aren't sharded, some parameters will be plain tensors
+            sharded_tensor = full_tensor
         else:
             sharded_tensor = distribute_tensor(
                 full_tensor,
@@ -220,6 +223,30 @@ def load_from_full_model_state_dict(
     return model.load_state_dict(sharded_sd, strict=strict, assign=True)
 
 
+def _gather_nf4_tensor(sharded_param: nn.Parameter) -> nn.Parameter:
+    """
+    Manually gather NF4Tensor parameter since it does not support all_gather
+    """
+    mesh = sharded_param.device_mesh
+    nf4_tensor = sharded_param._local_tensor
+    quant_params, metadata = nf4_tensor.fsdp_pre_all_gather(mesh)
+    full_quant_params = []
+    for quant_param in quant_params:
+        d0, *dn = quant_param.shape
+        shape = (d0 * mesh.get_group().size(), *dn)
+        full_quant_param = torch.empty(
+            shape, device=quant_param.device, dtype=quant_param.dtype
+        )
+        dist.all_gather_into_tensor(
+            full_quant_param, quant_param, mesh.get_group(), async_op=False
+        )
+        full_quant_params.append(full_quant_param)
+    full_param, _ = nf4_tensor.fsdp_post_all_gather(
+        full_quant_params, metadata, nf4_tensor.dtype
+    )
+    return full_param
+
+
 def gather_cpu_state_dict(
     sharded_sd: Dict[str, DTensor],  # noqa
     is_rank_zero: bool,
@@ -238,39 +265,21 @@ def gather_cpu_state_dict(
         Dict[str, Any]: State dict on CPU
     """
     cpu_state_dict = {}
-    for param_name, sharded_param in sharded_sd.items():
-        if sharded_param.is_cpu:
+    for param_name, param in sharded_sd.items():
+        if param.is_cpu:
             # Move back to device if offloaded to CPU
-            sharded_param = sharded_param.to(device)
-        if isinstance(sharded_param._local_tensor, NF4Tensor):
-            # NF4Tensor does not support all_gather from DTensor
-            # so we need to manually all_gather
-            mesh = sharded_param.device_mesh
-            nf4_tensor = sharded_param._local_tensor
-            quant_params, metadata = nf4_tensor.fsdp_pre_all_gather(mesh)
-            full_quant_params = []
-            for quant_param in quant_params:
-                d0, *dn = quant_param.shape
-                shape = (d0 * mesh.get_group().size(), *dn)
-                full_quant_param = torch.empty(
-                    shape, device=quant_param.device, dtype=quant_param.dtype
-                )
-                dist.all_gather_into_tensor(
-                    full_quant_param, quant_param, mesh.get_group(), async_op=False
-                )
-                full_quant_params.append(full_quant_param)
-            full_param, _ = nf4_tensor.fsdp_post_all_gather(
-                full_quant_params, metadata, nf4_tensor.dtype
-            )
+            param = param.to(device)
+        if hasattr(param, "_local_tensor"):
+            if isinstance(param._local_tensor, NF4Tensor):
+                param = _gather_nf4_tensor(param)
+            else:
+                # Gather DTensor
+                param = param.full_tensor()
+        if isinstance(param, NF4Tensor):
             # upcasting NF4 to original dtype
-            full_param = full_param.to(full_param.dtype)
-        else:
-            # Gather DTensor
-            full_param = sharded_param.full_tensor()
+            param = param.to(param.dtype)
         if is_rank_zero:
-            cpu_state_dict[param_name] = full_param.cpu()
-        else:
-            del full_param
+            cpu_state_dict[param_name] = param.cpu()
         torch.distributed.barrier()
     return cpu_state_dict
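For readers skimming the diff, below is a minimal usage sketch (not part of the commit) of the scenario this change enables: gathering a full CPU state dict from a model in which only some submodules are FSDP-sharded, so the remaining parameters are plain tensors rather than DTensors. The `device` keyword of `gather_cpu_state_dict` and the FSDP2 `fully_shard` import path are assumptions based on the surrounding code and the PyTorch version of the time, not something this diff confirms.

# Illustrative sketch only; run with torchrun. Names not visible in the diff are assumed.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard  # FSDP2 API (assumed import path)
from torchtune.training._distributed import gather_cpu_state_dict

dist.init_process_group("nccl")
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
is_rank_zero = dist.get_rank() == 0

model = nn.Sequential(
    nn.Linear(16, 16, device=device),
    nn.Linear(16, 16, device=device),
)
# Shard only the first layer; the second layer's parameters stay plain tensors,
# which is the unsharded case this commit adds support for.
fully_shard(model[0])

sharded_sd = model.state_dict()  # mix of DTensors and plain tensors
cpu_sd = gather_cpu_state_dict(sharded_sd, is_rank_zero, device=device)
if is_rank_zero:
    torch.save(cpu_sd, "checkpoint.pt")

dist.destroy_process_group()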
