diff --git a/tutorials/developer_api_guide/tensor_parallel.py b/tutorials/developer_api_guide/tensor_parallel.py
index a94d84fe05..db610a71fa 100644
--- a/tutorials/developer_api_guide/tensor_parallel.py
+++ b/tutorials/developer_api_guide/tensor_parallel.py
@@ -1,8 +1,9 @@
 import os
 import torch
 import torch.distributed as dist
+from typing import Sequence
 from torch.distributed import DeviceMesh
-from torch.distributed._tensor import DTensor, Replicate, Shard
+from torch.distributed.tensor import DTensor, Replicate, Shard, Placement
 from torch.utils._python_dispatch import return_and_correct_aliasing
 
 from my_dtype_tensor_subclass import MyDTypeTensor, fill_defaults
@@ -101,18 +102,40 @@ def quantize(m: torch.nn.Module) -> torch.nn.Module:
     )
     return m
 
+def shard(
+    full_tensor: torch.Tensor,
+    device_mesh: DeviceMesh,
+    placements: Sequence[Placement],
+) -> DTensor:
+    """
+    Add a shard function to simplify both colwise_shard and rowwise_shard. The
+    shard function accepts a full tensor, and returns a DTensor based on
+    indicated placements. Goal is to move the shard function as a static method
+    of DTensor, e.g.
+        dtensor = DTensor.shard(full_tensor, device_mesh, placement)
+    """
+    from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
+
+    shape, offset = compute_local_shape_and_global_offset(
+        full_tensor.shape, device_mesh, placements
+    )
+    slices = [
+        slice(cur_offset, cur_offset + cur_shape)
+        for cur_shape, cur_offset in zip(shape, offset)
+    ]
+    local_tensor = full_tensor[slices]
+    return DTensor.from_local(
+        local_tensor, device_mesh, placements
+    )
+
 def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
     """
     Shard linear layer of the model in column-wise fashion
     """
     # Column-wise is wrt to A^T, so for A it is row-wise.
-    # Number of rows per rank
     orig_weight = m.linear.weight
-    n_local_rows = orig_weight.size(0) // mesh.size()
-    rank = mesh.get_local_rank()
-    local_shard = orig_weight[rank * n_local_rows : (rank + 1) * n_local_rows, :]
     # Construct DTensor from local shard
-    dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)])
+    dtensor = shard(orig_weight, mesh, [Shard(0)])
     # Replace parameter in module
     m.linear.weight = torch.nn.Parameter(
         dtensor, requires_grad=False
@@ -124,13 +147,9 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
     Shard linear layer of the model in row-wise fashion
     """
     # Row-wise is wrt to A^T, so for A it is column-wise.
-    # Number of rows per rank
     orig_weight = m.linear.weight
-    n_local_cols = orig_weight.size(1) // mesh.size()
-    rank = mesh.get_local_rank()
-    local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols]
     # Construct DTensor from local shard
-    dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)])
+    dtensor = shard(orig_weight, mesh, [Shard(1)])
     # Replace parameter in module
     m.linear.weight = torch.nn.Parameter(
         dtensor, requires_grad=False
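
For reference, a minimal usage sketch of the new shard helper (not part of the patch). It assumes the script is launched with torchrun with 4 ranks on a host with one CUDA device per rank, and it substitutes a plain torch.nn.Linear for the tutorial's quantized model; shard is the function added in this diff.

    import os
    import torch
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import Shard

    # One CUDA device per rank; torchrun provides LOCAL_RANK.
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    mesh = init_device_mesh("cuda", (4,))

    linear = torch.nn.Linear(1024, 1024, bias=False, device="cuda")
    # Column-wise parallel is wrt A^T, so the weight A is sharded along dim 0:
    # with 4 ranks each rank holds a [256, 1024] local shard wrapped in a DTensor.
    dtensor = shard(linear.weight, mesh, [Shard(0)])
    linear.weight = torch.nn.Parameter(dtensor, requires_grad=False)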