Commit c1d1a53
update DTensor imports to use public module (#153)
epwalsh authored Jan 30, 2025
1 parent 4594231 commit c1d1a53
Showing 9 changed files with 9 additions and 12 deletions.
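
Every file swaps an import from the private `torch.distributed._tensor` module for the public `torch.distributed.tensor` module, and the separate `Placement` import from `torch.distributed.tensor.placement_types` is folded into the same public import. A minimal, hypothetical sketch of the pattern after the change (not part of the commit; assumes a recent PyTorch where the public DTensor module is available, a distributed process group is already initialized, and a world size of 4):

import torch
from torch.distributed.tensor import DTensor, Shard, distribute_tensor, init_device_mesh

# Build a 1-D device mesh over 4 ranks (hypothetical world size).
mesh = init_device_mesh("cuda", (4,))

# Shard a regular tensor along dim 0; the result is a DTensor from the public module.
full = torch.randn(8, 16)
sharded = distribute_tensor(full, mesh, placements=[Shard(0)])
assert isinstance(sharded, DTensor)
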
3 changes: 1 addition & 2 deletions src/olmo_core/distributed/parallel/tensor_parallel.py
@@ -6,9 +6,8 @@
 import torch
 import torch.nn as nn
 from torch.distributed import DeviceMesh
-from torch.distributed._tensor import Shard, distribute_module
+from torch.distributed.tensor import Placement, Shard, distribute_module
 from torch.distributed.tensor.parallel import SequenceParallel as _SequenceParallel
-from torch.distributed.tensor.placement_types import Placement
 
 from olmo_core.config import Config
2 changes: 1 addition & 1 deletion src/olmo_core/distributed/utils.py
@@ -10,8 +10,8 @@
 
 import torch
 import torch.distributed as dist
-from torch.distributed._tensor import DTensor
 from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor import DTensor
 
 from ..exceptions import OLMoEnvironmentError
 from ..utils import logging_configured, move_to_device, set_env_var
3 changes: 1 addition & 2 deletions src/olmo_core/nn/attention.py
@@ -6,9 +6,8 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.distributed import DeviceMesh
-from torch.distributed._tensor import Shard
+from torch.distributed.tensor import Placement, Shard
 from torch.distributed.tensor.parallel import parallelize_module
-from torch.distributed.tensor.placement_types import Placement
 
 from ..config import Config, DType, StrEnum
 from ..distributed.parallel.tensor_parallel import SequenceParallel
2 changes: 1 addition & 1 deletion src/olmo_core/nn/lm_head.py
@@ -5,7 +5,7 @@
 import torch
 import torch.nn as nn
 from torch.distributed import DeviceMesh
-from torch.distributed._tensor import Replicate, Shard
+from torch.distributed.tensor import Replicate, Shard
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     ParallelStyle,
3 changes: 1 addition & 2 deletions src/olmo_core/nn/transformer/block.py
@@ -6,9 +6,8 @@
 import torch
 import torch.nn as nn
 from torch.distributed import DeviceMesh
-from torch.distributed._tensor import Replicate, Shard
+from torch.distributed.tensor import Placement, Replicate, Shard
 from torch.distributed.tensor.parallel import parallelize_module
-from torch.distributed.tensor.placement_types import Placement
 
 from olmo_core.config import Config, StrEnum
 from olmo_core.distributed.parallel.tensor_parallel import SequenceParallel
2 changes: 1 addition & 1 deletion src/olmo_core/nn/transformer/model.py
@@ -241,7 +241,7 @@ def apply_tp(
         :param loss_parallel: Set to ``True`` if parallelizing the loss function as well.
         :param float8_enabled: Set this to ``True`` if training with float8 linear layers.
         """
-        from torch.distributed._tensor import Replicate
+        from torch.distributed.tensor import Replicate
         from torch.distributed.tensor.parallel import (
             PrepareModuleInput,
             RowwiseParallel,
2 changes: 1 addition & 1 deletion src/test/distributed/checkpoint/filesystem_test.py
@@ -2,7 +2,7 @@
 import torch
 import torch.distributed as dist
 import torch.distributed.checkpoint as distcp
-from torch.distributed._tensor import Shard, distribute_tensor, init_device_mesh
+from torch.distributed.tensor import Shard, distribute_tensor, init_device_mesh
 
 from olmo_core.distributed.checkpoint.filesystem import (
     RemoteFileSystemReader,
2 changes: 1 addition & 1 deletion src/test/distributed/checkpoint_test.py
@@ -4,7 +4,7 @@
 import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.distributed._tensor import init_device_mesh
+from torch.distributed.tensor import init_device_mesh
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     RowwiseParallel,
2 changes: 1 addition & 1 deletion src/test/nn/transformer/model_test.py
@@ -3,7 +3,7 @@
 import pytest
 import torch
 import torch.nn as nn
-from torch.distributed._tensor import DTensor, init_device_mesh
+from torch.distributed.tensor import DTensor, init_device_mesh
 
 from olmo_core.distributed.checkpoint import (
     load_model_and_optim_state,
