From b883291eff3c75d2dabc79ead12d0f5193ca3f27 Mon Sep 17 00:00:00 2001
From: Ke Wen
Date: Tue, 24 Sep 2024 15:15:43 -0700
Subject: [PATCH 1/2] [Distributed] Improve sharding example

---
 .../developer_api_guide/tensor_parallel.py | 34 +++++++++++++------
 1 file changed, 23 insertions(+), 11 deletions(-)

diff --git a/tutorials/developer_api_guide/tensor_parallel.py b/tutorials/developer_api_guide/tensor_parallel.py
index a94d84fe05..2849892743 100644
--- a/tutorials/developer_api_guide/tensor_parallel.py
+++ b/tutorials/developer_api_guide/tensor_parallel.py
@@ -1,8 +1,9 @@
 import os
 import torch
 import torch.distributed as dist
+from typing import Sequence
 from torch.distributed import DeviceMesh
-from torch.distributed._tensor import DTensor, Replicate, Shard
+from torch.distributed.tensor import DTensor, Replicate, Shard, Placement
 from torch.utils._python_dispatch import return_and_correct_aliasing
 
 from my_dtype_tensor_subclass import MyDTypeTensor, fill_defaults
@@ -101,18 +102,33 @@ def quantize(m: torch.nn.Module) -> torch.nn.Module:
         )
     return m
 
+def shard(
+    full_tensor: torch.Tensor,
+    device_mesh: DeviceMesh,
+    placements: Sequence[Placement],
+) -> DTensor:
+    from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
+
+    shape, offset = compute_local_shape_and_global_offset(
+        full_tensor.shape, device_mesh, placements
+    )
+    slices = [
+        slice(cur_offset, cur_offset + cur_shape)
+        for cur_shape, cur_offset in zip(shape, offset)
+    ]
+    local_tensor = full_tensor[slices]
+    return DTensor.from_local(
+        local_tensor, device_mesh, placements
+    )
+
 def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
     """
     Shard linear layer of the model in column-wise fashion
     """
     # Column-wise is wrt to A^T, so for A it is row-wise.
-    # Number of rows per rank
     orig_weight = m.linear.weight
-    n_local_rows = orig_weight.size(0) // mesh.size()
-    rank = mesh.get_local_rank()
-    local_shard = orig_weight[rank * n_local_rows : (rank + 1) * n_local_rows, :]
     # Construct DTensor from local shard
-    dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)])
+    dtensor = shard(orig_weight, mesh, [Shard(0)])
     # Replace parameter in module
     m.linear.weight = torch.nn.Parameter(
         dtensor, requires_grad=False
@@ -124,13 +140,9 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
     Shard linear layer of the model in row-wise fashion
     """
     # Row-wise is wrt to A^T, so for A it is column-wise.
-    # Number of rows per rank
     orig_weight = m.linear.weight
-    n_local_cols = orig_weight.size(1) // mesh.size()
-    rank = mesh.get_local_rank()
-    local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols]
     # Construct DTensor from local shard
-    dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)])
+    dtensor = shard(orig_weight, mesh, [Shard(1)])
     # Replace parameter in module
     m.linear.weight = torch.nn.Parameter(
         dtensor, requires_grad=False

From 998b59519b2a3ef34726f8d4c78429cd27a9b1e2 Mon Sep 17 00:00:00 2001
From: Ke Wen
Date: Tue, 24 Sep 2024 15:23:06 -0700
Subject: [PATCH 2/2] Add comment

---
 tutorials/developer_api_guide/tensor_parallel.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/tutorials/developer_api_guide/tensor_parallel.py b/tutorials/developer_api_guide/tensor_parallel.py
index 2849892743..db610a71fa 100644
--- a/tutorials/developer_api_guide/tensor_parallel.py
+++ b/tutorials/developer_api_guide/tensor_parallel.py
@@ -107,6 +107,13 @@ def shard(
     device_mesh: DeviceMesh,
     placements: Sequence[Placement],
 ) -> DTensor:
+    """
+    Shard a full tensor into a DTensor according to the indicated placements.
+    This helper simplifies both colwise_shard and rowwise_shard. The goal is
+    to eventually make shard a static method of DTensor, so that callers can
+    write, e.g.:
+        dtensor = DTensor.shard(full_tensor, device_mesh, placements)
+    """
     from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
 
     shape, offset = compute_local_shape_and_global_offset(
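
---
Usage note: a minimal sketch of how the new shard helper behaves, not part of
the patches above. It assumes a 1-D mesh over 2 ranks launched with
`torchrun --nproc-per-node 2`, and that the helper is importable from the
tutorial module (the module name `tensor_parallel` is an assumption based on
the file path):

    import torch
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import Shard
    from tensor_parallel import shard  # helper added in PATCH 1/2 (assumed import path)

    # init_device_mesh initializes the default process group if needed.
    mesh = init_device_mesh("cuda", (2,))

    # Shard(0) splits dim 0 across the 2 ranks, i.e. column-wise w.r.t. A^T,
    # so each rank ends up holding a (512, 512) local slice.
    full_weight = torch.randn(1024, 512)
    dweight = shard(full_weight, mesh, [Shard(0)])
    print(f"rank {mesh.get_rank()}: local shape {dweight.to_local().shape}")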