Device updates for TPU Pod #7243

Merged: 5 commits, Apr 30, 2021
Changes from 3 commits

pytorch_lightning/core/step_result.py (2 changes: 1 addition & 1 deletion)

@@ -109,7 +109,7 @@ def log(
         if sync_dist and isinstance(value, (torch.Tensor, numbers.Number)):
             is_dist_initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
             # TODO: Find a way to make the reduction only once, so we don't need to clone.
-            if (is_dist_initialized or tpu_distributed) and isinstance(value, torch.Tensor):
+            if (is_dist_initialized or tpu_distributed()) and isinstance(value, torch.Tensor):
                 value = value.clone()
             else:
                 value = torch.tensor(value, device=device, dtype=torch.float)
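
The call parentheses matter here: a bare function reference is always truthy in Python, so the old `tpu_distributed` check passed even when no TPU process group existed. A minimal sketch of that failure mode, with a hypothetical `maybe_distributed` standing in for `tpu_distributed`:

    def maybe_distributed() -> bool:
        # Pretend no TPU/distributed setup is active.
        return False

    if maybe_distributed:       # old, buggy: the function object is always truthy
        print("branch taken even though nothing is distributed")

    if maybe_distributed():     # fixed: the check is actually evaluated
        print("never reached in this sketch")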

pytorch_lightning/plugins/training_type/tpu_spawn.py (6 changes: 4 additions & 2 deletions)

@@ -35,6 +35,7 @@
 from pytorch_lightning.utilities.seed import reset_seed

 if _TPU_AVAILABLE:
+    import torch_xla.core.xla_env_vars as xenv
     import torch_xla.core.xla_model as xm
     import torch_xla.distributed.xla_multiprocessing as xmp
     from torch_xla.core.xla_model import rendezvous
@@ -58,7 +59,7 @@ def __init__(self, parallel_devices: Optional[List[int]] = None, debug: bool = False

     @property
     def global_rank(self) -> int:
-        return self.tpu_local_core_rank
+        return self.tpu_global_core_rank

     @property
     def local_rank(self) -> int:
@@ -175,7 +176,8 @@ def model_to_device(self) -> None:
         self.model = self.wrapped_model.to(self.device)

     def barrier(self, name: Optional[str] = None) -> None:
-        if tpu_distributed():
+        # HOST_WORLD_SIZE is None outside the xmp.spawn process
+        if os.getenv(xenv.HOST_WORLD_SIZE, None) and tpu_distributed():
             rendezvous(name)

     def transfer_distrib_spawn_state_on_fit_end(self, results):
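
The guard makes `barrier` a no-op outside the spawned workers. A standalone sketch of that behavior, assuming `xenv.HOST_WORLD_SIZE` names the `XRT_HOST_WORLD_SIZE` environment variable (the same one the new test below asserts on), with the real `rendezvous` replaced by a print:

    import os

    HOST_WORLD_SIZE = "XRT_HOST_WORLD_SIZE"  # what xenv.HOST_WORLD_SIZE points at

    def barrier(name: str) -> None:
        # Set only inside processes launched by xmp.spawn; in the parent
        # process the variable is absent, so we skip the rendezvous instead
        # of hanging on a sync that no peer will ever join.
        if os.getenv(HOST_WORLD_SIZE, None):
            print(f"rendezvous({name!r})")
        else:
            print(f"skipping barrier {name!r}: not an xmp.spawn worker")

    barrier("setup")  # prints the skip message when run outside spawn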

pytorch_lightning/utilities/xla_device.py (8 changes: 7 additions & 1 deletion)

@@ -66,7 +66,13 @@ def _is_device_tpu() -> bool:
         Return:
             A boolean value indicating if TPU devices are available
         """
-
+        # For the TPU Pod training process, for example, with a
+        # TPU v3-32 across 4 VMs the world size is 4. Multi-VM runs
+        # must go through `torch_xla.distributed.xla_dist`, where
+        # TPU_CONFIG is not available, so running
+        # `xm.get_xla_supported_devices("TPU")` is not possible.
+        if xm.xrt_world_size() > 1:
+            return True
         return len(xm.get_xla_supported_devices("TPU")) > 0

     @staticmethod
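
The new early return avoids querying devices on Pods entirely. A hedged sketch of the resulting control flow, with the two `xm` calls stubbed out so it runs without TPU hardware:

    # Stand-ins for xm.xrt_world_size() and xm.get_xla_supported_devices().
    def xrt_world_size() -> int:
        return 4  # e.g. a TPU v3-32 Pod spread over 4 VMs

    def get_xla_supported_devices(kind: str) -> list:
        raise RuntimeError("not queryable under xla_dist")

    def is_device_tpu() -> bool:
        if xrt_world_size() > 1:  # Pod case: trust the world size
            return True
        return len(get_xla_supported_devices("TPU")) > 0

    print(is_device_tpu())  # True, without touching get_xla_supported_devices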

tests/models/test_tpu.py (27 changes: 27 additions & 0 deletions)

@@ -470,3 +470,30 @@ def teardown(self, stage):

     model = DebugModel()
     tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)
+
+
+@RunIf(tpu=True)
+@pl_multi_process_test
+def test_tpu_host_world_size(tmpdir):
+    """Test Host World size env setup on TPU."""
+
+    class DebugModel(BoringModel):
+
+        def on_train_start(self):
+            assert os.environ.get("XRT_HOST_WORLD_SIZE") == str(1)
+
+        def teardown(self, stage):
+            assert "XRT_HOST_WORLD_SIZE" not in os.environ
+
+    tutils.reset_seed()
+    trainer_options = dict(
+        default_root_dir=tmpdir,
+        progress_bar_refresh_rate=0,
+        max_epochs=4,
+        tpu_cores=8,
+        limit_train_batches=0.4,
+        limit_val_batches=0.4,
+    )
+
+    model = DebugModel()
+    tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)
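
The test pins down the variable's lifecycle on a single host: inside the spawned workers `XRT_HOST_WORLD_SIZE` is "1", and after teardown it must be absent again, which is what keeps the `barrier` guard above a no-op outside `xmp.spawn`. A hedged sketch of that lifecycle, with a hypothetical `fake_spawn` standing in for xmp.spawn's environment handling:

    import os

    def fake_spawn(fn) -> None:
        os.environ["XRT_HOST_WORLD_SIZE"] = "1"   # single host -> host world size 1
        try:
            fn()
        finally:
            del os.environ["XRT_HOST_WORLD_SIZE"]  # gone again after teardown

    fake_spawn(lambda: print(os.environ.get("XRT_HOST_WORLD_SIZE")))  # prints: 1
    print("XRT_HOST_WORLD_SIZE" in os.environ)  # prints: False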