From c1d1a536b9e833bbef6e6b11a1ae48cb2a873bc6 Mon Sep 17 00:00:00 2001
From: Pete Walsh
Date: Thu, 30 Jan 2025 14:25:41 -0800
Subject: [PATCH] update DTensor imports to use public module (#153)

---
 src/olmo_core/distributed/parallel/tensor_parallel.py | 3 +--
 src/olmo_core/distributed/utils.py                    | 2 +-
 src/olmo_core/nn/attention.py                         | 3 +--
 src/olmo_core/nn/lm_head.py                           | 2 +-
 src/olmo_core/nn/transformer/block.py                 | 3 +--
 src/olmo_core/nn/transformer/model.py                 | 2 +-
 src/test/distributed/checkpoint/filesystem_test.py    | 2 +-
 src/test/distributed/checkpoint_test.py               | 2 +-
 src/test/nn/transformer/model_test.py                 | 2 +-
 9 files changed, 9 insertions(+), 12 deletions(-)

diff --git a/src/olmo_core/distributed/parallel/tensor_parallel.py b/src/olmo_core/distributed/parallel/tensor_parallel.py
index 6fc58fc2..14e80f12 100644
--- a/src/olmo_core/distributed/parallel/tensor_parallel.py
+++ b/src/olmo_core/distributed/parallel/tensor_parallel.py
@@ -6,9 +6,8 @@
 import torch
 import torch.nn as nn
 from torch.distributed import DeviceMesh
-from torch.distributed._tensor import Shard, distribute_module
+from torch.distributed.tensor import Placement, Shard, distribute_module
 from torch.distributed.tensor.parallel import SequenceParallel as _SequenceParallel
-from torch.distributed.tensor.placement_types import Placement
 
 from olmo_core.config import Config
 
diff --git a/src/olmo_core/distributed/utils.py b/src/olmo_core/distributed/utils.py
index b70b0953..22801e56 100644
--- a/src/olmo_core/distributed/utils.py
+++ b/src/olmo_core/distributed/utils.py
@@ -10,8 +10,8 @@
 
 import torch
 import torch.distributed as dist
-from torch.distributed._tensor import DTensor
 from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor import DTensor
 
 from ..exceptions import OLMoEnvironmentError
 from ..utils import logging_configured, move_to_device, set_env_var
diff --git a/src/olmo_core/nn/attention.py b/src/olmo_core/nn/attention.py
index d7520ddf..1e302940 100644
--- a/src/olmo_core/nn/attention.py
+++ b/src/olmo_core/nn/attention.py
@@ -6,9 +6,8 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.distributed import DeviceMesh
-from torch.distributed._tensor import Shard
+from torch.distributed.tensor import Placement, Shard
 from torch.distributed.tensor.parallel import parallelize_module
-from torch.distributed.tensor.placement_types import Placement
 
 from ..config import Config, DType, StrEnum
 from ..distributed.parallel.tensor_parallel import SequenceParallel
diff --git a/src/olmo_core/nn/lm_head.py b/src/olmo_core/nn/lm_head.py
index 49559b4b..2da8e01e 100644
--- a/src/olmo_core/nn/lm_head.py
+++ b/src/olmo_core/nn/lm_head.py
@@ -5,7 +5,7 @@
 import torch
 import torch.nn as nn
 from torch.distributed import DeviceMesh
-from torch.distributed._tensor import Replicate, Shard
+from torch.distributed.tensor import Replicate, Shard
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     ParallelStyle,
diff --git a/src/olmo_core/nn/transformer/block.py b/src/olmo_core/nn/transformer/block.py
index db60e677..725d8eaf 100644
--- a/src/olmo_core/nn/transformer/block.py
+++ b/src/olmo_core/nn/transformer/block.py
@@ -6,9 +6,8 @@
 import torch
 import torch.nn as nn
 from torch.distributed import DeviceMesh
-from torch.distributed._tensor import Replicate, Shard
+from torch.distributed.tensor import Placement, Replicate, Shard
 from torch.distributed.tensor.parallel import parallelize_module
-from torch.distributed.tensor.placement_types import Placement
 
 from olmo_core.config import Config, StrEnum
 from olmo_core.distributed.parallel.tensor_parallel import SequenceParallel
diff --git a/src/olmo_core/nn/transformer/model.py b/src/olmo_core/nn/transformer/model.py
index 929a1cb1..49fbfe4d 100644
--- a/src/olmo_core/nn/transformer/model.py
+++ b/src/olmo_core/nn/transformer/model.py
@@ -241,7 +241,7 @@ def apply_tp(
         :param loss_parallel: Set to ``True`` if parallelizing the loss function as well.
         :param float8_enabled: Set this to ``True`` if training with float8 linear layers.
         """
-        from torch.distributed._tensor import Replicate
+        from torch.distributed.tensor import Replicate
         from torch.distributed.tensor.parallel import (
             PrepareModuleInput,
             RowwiseParallel,
diff --git a/src/test/distributed/checkpoint/filesystem_test.py b/src/test/distributed/checkpoint/filesystem_test.py
index c89c627e..65660390 100644
--- a/src/test/distributed/checkpoint/filesystem_test.py
+++ b/src/test/distributed/checkpoint/filesystem_test.py
@@ -2,7 +2,7 @@
 import torch
 import torch.distributed as dist
 import torch.distributed.checkpoint as distcp
-from torch.distributed._tensor import Shard, distribute_tensor, init_device_mesh
+from torch.distributed.tensor import Shard, distribute_tensor, init_device_mesh
 
 from olmo_core.distributed.checkpoint.filesystem import (
     RemoteFileSystemReader,
diff --git a/src/test/distributed/checkpoint_test.py b/src/test/distributed/checkpoint_test.py
index 8efb3d67..bcf93050 100644
--- a/src/test/distributed/checkpoint_test.py
+++ b/src/test/distributed/checkpoint_test.py
@@ -4,7 +4,7 @@
 import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.distributed._tensor import init_device_mesh
+from torch.distributed.tensor import init_device_mesh
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     RowwiseParallel,
diff --git a/src/test/nn/transformer/model_test.py b/src/test/nn/transformer/model_test.py
index 5f2266b9..ab601784 100644
--- a/src/test/nn/transformer/model_test.py
+++ b/src/test/nn/transformer/model_test.py
@@ -3,7 +3,7 @@
 import pytest
 import torch
 import torch.nn as nn
-from torch.distributed._tensor import DTensor, init_device_mesh
+from torch.distributed.tensor import DTensor, init_device_mesh
 
 from olmo_core.distributed.checkpoint import (
     load_model_and_optim_state,