Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add back deterministic support in accelerator_connector #11999

Merged
merged 6 commits into from
Feb 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Changed default logger name to `lightning_logs` for consistency ([#11762](https://github.com/PyTorchLightning/pytorch-lightning/pull/11762))


- Rewrote `accelerator_connector` ([#11448](https://github.com/PyTorchLightning/pytorch-lightning/pull/11448))

### Deprecated

- Deprecated `training_type_plugin` property in favor of `strategy` in `Trainer` and updated the references ([#11141](https://github.com/PyTorchLightning/pytorch-lightning/pull/11141))
Expand Down
22 changes: 21 additions & 1 deletion pytorch_lightning/trainer/connectors/accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,12 @@
rank_zero_warn,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _HOROVOD_AVAILABLE, _IPU_AVAILABLE, _TPU_AVAILABLE
from pytorch_lightning.utilities.imports import (
_HOROVOD_AVAILABLE,
_IPU_AVAILABLE,
_TORCH_GREATER_EQUAL_1_8,
_TPU_AVAILABLE,
)

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -141,6 +146,7 @@ def __init__(
torch.backends.cudnn.benchmark = benchmark
self.replace_sampler_ddp = replace_sampler_ddp
self.sync_batchnorm = sync_batchnorm
self._init_deterministic(deterministic)

# 1. Parsing flags
# Get registered strategies, built-in accelerators and precision plugins
Expand Down Expand Up @@ -196,6 +202,20 @@ def __init__(
# 6. Instantiate Strategy - Part 2
self._lazy_init_strategy()

def _init_deterministic(self, deterministic: bool) -> None:
self.deterministic = deterministic
if _TORCH_GREATER_EQUAL_1_8:
torch.use_deterministic_algorithms(deterministic)
else:
torch.set_deterministic(deterministic)
if deterministic:
# fixing non-deterministic part of horovod
# https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383
os.environ["HOROVOD_FUSION_THRESHOLD"] = "0"

# https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

def _check_config_and_set_final_flags(
self,
strategy: Optional[Union[str, Strategy]],
Expand Down
9 changes: 9 additions & 0 deletions tests/accelerators/test_accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,3 +947,12 @@ def test_passing_zero_and_empty_list_to_devices_flag():

with pytest.warns(UserWarning, match=r"switching to `cpu` accelerator"):
Trainer(accelerator="gpu", devices=[])


@pytest.mark.parametrize("deterministic", [True, False])
def test_deterministic_init(deterministic):
trainer = Trainer(accelerator="auto", deterministic=deterministic)
assert trainer._accelerator_connector.deterministic == deterministic
if deterministic:
assert os.environ.get("CUBLAS_WORKSPACE_CONFIG") == ":4096:8"
assert os.environ.get("HOROVOD_FUSION_THRESHOLD") == "0"