Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix trainer distributed attributes #5303

Merged
merged 4 commits into from
Dec 31, 2020
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions pytorch_lightning/trainer/deprecated_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class DeprecatedDistDeviceAttributes:
@property
def on_cpu(self) -> bool:
# rank_zero_warn("Internal: `on_cpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
return self._device_type and self._device_type == DeviceType.CPU
return self._device_type is not None and self._device_type == DeviceType.CPU

@on_cpu.setter
def on_cpu(self, val: bool) -> None:
Expand All @@ -35,7 +35,7 @@ def on_cpu(self, val: bool) -> None:
@property
def on_tpu(self) -> bool:
# rank_zero_warn("Internal: `on_tpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
return self._device_type and self._device_type == DeviceType.TPU
return self._device_type is not None and self._device_type == DeviceType.TPU

@on_tpu.setter
def on_tpu(self, val: bool) -> None:
Expand All @@ -52,13 +52,12 @@ def use_tpu(self) -> bool:
@use_tpu.setter
def use_tpu(self, val: bool) -> None:
# rank_zero_warn("Internal: `use_tpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
# todo add logic that it cannot be set if TPU is missing
self.on_tpu = val

@property
def on_gpu(self) -> bool:
# rank_zero_warn("Internal: `on_gpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
return self._device_type and self._device_type == DeviceType.GPU
return self._device_type is not None and self._device_type == DeviceType.GPU

@on_gpu.setter
def on_gpu(self, val: bool) -> None:
Expand All @@ -70,7 +69,7 @@ def on_gpu(self, val: bool) -> None:
@property
def use_dp(self) -> bool:
# rank_zero_warn("Internal: `use_dp` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
return self._device_type and self._distrib_type == DistributedType.DP
return self._distrib_type is not None and self._distrib_type == DistributedType.DP

@use_dp.setter
def use_dp(self, val: bool) -> None:
Expand All @@ -81,7 +80,7 @@ def use_dp(self, val: bool) -> None:
@property
def use_ddp(self) -> bool:
# rank_zero_warn("Internal: `use_ddp` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
return self._device_type and self._distrib_type == DistributedType.DDP
return self._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN)

@use_ddp.setter
def use_ddp(self, val: bool) -> None:
Expand All @@ -92,7 +91,7 @@ def use_ddp(self, val: bool) -> None:
@property
def use_ddp2(self) -> bool:
# rank_zero_warn("Internal: `use_ddp2` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
return self._device_type and self._distrib_type == DistributedType.DDP2
return self._distrib_type is not None and self._distrib_type == DistributedType.DDP2

@use_ddp2.setter
def use_ddp2(self, val: bool) -> None:
Expand All @@ -105,7 +104,7 @@ def use_horovod(self) -> bool:
# rank_zero_warn(
# "Internal: `use_horovod` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning
# )
return self._device_type and self._distrib_type == DistributedType.HOROVOD
return self._distrib_type is not None and self._distrib_type == DistributedType.HOROVOD

@use_horovod.setter
def use_horovod(self, val: bool) -> None:
Expand Down