[Distributed] Improve sharding example #937

Merged. 2 commits merged on Sep 25, 2024.
41 changes: 30 additions & 11 deletions tutorials/developer_api_guide/tensor_parallel.py
@@ -1,8 +1,9 @@
import os
import torch
import torch.distributed as dist
from typing import Sequence
from torch.distributed import DeviceMesh
from torch.distributed._tensor import DTensor, Replicate, Shard
from torch.distributed.tensor import DTensor, Replicate, Shard, Placement
from torch.utils._python_dispatch import return_and_correct_aliasing
from my_dtype_tensor_subclass import MyDTypeTensor, fill_defaults

@@ -101,18 +102,40 @@ def quantize(m: torch.nn.Module) -> torch.nn.Module:
    )
    return m

def shard(
    full_tensor: torch.Tensor,
    device_mesh: DeviceMesh,
    placements: Sequence[Placement],
) -> DTensor:
    """
    Shard a full tensor into a DTensor according to the indicated placements.
    This helper simplifies both colwise_shard and rowwise_shard. The eventual
    goal is to move it onto DTensor as a static method, e.g.
        dtensor = DTensor.shard(full_tensor, device_mesh, placements)
    """
    from torch.distributed.tensor._utils import compute_local_shape_and_global_offset

    # Compute the shape of this rank's shard and its offset into the full tensor
    shape, offset = compute_local_shape_and_global_offset(
        full_tensor.shape, device_mesh, placements
    )
    slices = [
        slice(cur_offset, cur_offset + cur_shape)
        for cur_shape, cur_offset in zip(shape, offset)
    ]
    # Cut out the local shard and wrap it in a DTensor with the given placements
    local_tensor = full_tensor[slices]
    return DTensor.from_local(
        local_tensor, device_mesh, placements
    )
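# Illustration (assumed shapes, not part of the tutorial script): on a 1-D mesh of
# 2 devices, sharding a (16, 16) weight with [Shard(0)] gives each rank a local
# shape of (8, 16). compute_local_shape_and_global_offset returns offset (0, 0) on
# rank 0 and (8, 0) on rank 1, so the slices above select rows 0:8 and 8:16
# respectively:
#   dtensor = shard(torch.randn(16, 16), mesh, [Shard(0)])
#   assert dtensor.to_local().shape == torch.Size([8, 16])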

def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
"""
Shard linear layer of the model in column-wise fashion
"""
# Column-wise is wrt to A^T, so for A it is row-wise.
# Number of rows per rank
orig_weight = m.linear.weight
n_local_rows = orig_weight.size(0) // mesh.size()
rank = mesh.get_local_rank()
local_shard = orig_weight[rank * n_local_rows : (rank + 1) * n_local_rows, :]
# Construct DTensor from local shard
dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)])
dtensor = shard(orig_weight, mesh, [Shard(0)])
# Replace parameter in module
m.linear.weight = torch.nn.Parameter(
dtensor, requires_grad=False
@@ -124,13 +147,9 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
    Shard linear layer of the model in row-wise fashion
    """
    # Row-wise is wrt to A^T, so for A it is column-wise.
    # Number of rows per rank
    orig_weight = m.linear.weight
    n_local_cols = orig_weight.size(1) // mesh.size()
    rank = mesh.get_local_rank()
    local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols]
    # Construct DTensor from local shard
    dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)])
    dtensor = shard(orig_weight, mesh, [Shard(1)])
    # Replace parameter in module
    m.linear.weight = torch.nn.Parameter(
        dtensor, requires_grad=False
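# Recap of the convention above (illustrative, with assumed shapes): nn.Linear
# computes y = x @ A^T, where A = m.linear.weight has shape (out, in).
#   - colwise_shard uses [Shard(0)]: splitting A along dim 0 splits A^T by columns,
#     so each rank produces its own slice of the output features (no reduction needed).
#   - rowwise_shard uses [Shard(1)]: splitting A along dim 1 splits A^T by rows,
#     so each rank consumes a slice of the input features and the partial results
#     are summed across ranks.
# E.g. a (16, 16) weight on 2 ranks leaves each rank an (8, 16) local weight under
# colwise_shard and a (16, 8) local weight under rowwise_shard.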