Test for min torch version + fix all issues (#638)
* Test for min torch

Co-authored-by: Sylvain Gugger <[email protected]>
muellerzr and sgugger authored Sep 2, 2022
1 parent 3ab4651 commit 60d6807
Showing 12 changed files with 45 additions and 26 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/test.yml
@@ -11,6 +11,10 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
pytorch-version: [
latest,
minimum
]
test-kind: [
test_prod,
test_core,
@@ -43,6 +47,7 @@ jobs:
if [[ ${{ matrix.test-kind }} = test_prod ]]; then pip install -e .[test_prod]; fi
if [[ ${{ matrix.test-kind }} != test_prod ]]; then pip install -e .[testing,test_trackers]; fi
if [[ ${{ matrix.test-kind }} = test_rest ]]; then pip uninstall comet_ml -y; fi
if [[ ${{ matrix.pytorch-version }} = minimum ]]; then pip install torch==1.6.0; fi
- name: Run Tests
run: |
6 changes: 2 additions & 4 deletions src/accelerate/accelerator.py
@@ -279,9 +279,7 @@ def __init__(
self.native_amp = False
err = "{mode} mixed precision requires {requirement}"
if self.state.mixed_precision == "fp16":
self.native_amp = is_torch_version(">=", "1.6")
if not self.native_amp:
raise ValueError(err.format(mode="fp16", requirement="PyTorch >= 1.6"))
self.native_amp = True
if not torch.cuda.is_available() and not parse_flag_from_env("USE_MPS_DEVICE"):
raise ValueError(err.format(mode="fp16", requirement="a GPU"))
kwargs = self.scaler_handler.to_kwargs() if self.scaler_handler is not None else {}
@@ -314,7 +312,7 @@ def __init__(
# RNG Types
self.rng_types = rng_types
if self.rng_types is None:
self.rng_types = ["torch"] if is_torch_version("<=", "1.5.1") else ["generator"]
self.rng_types = ["generator"]

@property
def use_distributed(self):
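With torch >= 1.6 now the supported floor, native AMP can be assumed unconditionally, which is why the `fp16` branch simply sets `native_amp = True`. The pattern the Accelerator relies on looks roughly like the following (a minimal sketch in plain torch, assuming a CUDA device; not the Accelerator's exact code):

```python
import torch

# Native mixed precision (available since torch 1.6): autocast for the forward
# pass, GradScaler to keep fp16 gradients from underflowing.
model = torch.nn.Linear(8, 8).cuda()                      # hypothetical model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)  # hypothetical optimizer
scaler = torch.cuda.amp.GradScaler()

inputs = torch.randn(4, 8, device="cuda")                 # hypothetical batch
with torch.cuda.amp.autocast():
    loss = model(inputs).float().mean()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
```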
11 changes: 11 additions & 0 deletions src/accelerate/big_modeling.py
@@ -29,6 +29,7 @@
load_checkpoint_in_model,
offload_state_dict,
)
from .utils.versions import is_torch_version


@contextmanager
@@ -59,6 +60,8 @@ def init_empty_weights(include_buffers: bool = False):
</Tip>
"""
if not is_torch_version(">=", "1.9.0"):
raise NotImplementedError("Initializing empty weights to a meta device requires torch >= 1.9.0")
old_register_parameter = nn.Module.register_parameter
if include_buffers:
old_register_buffer = nn.Module.register_buffer
@@ -114,6 +117,8 @@ def cpu_offload(
called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
`dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
"""
if not is_torch_version(">=", "1.9.0"):
raise NotImplementedError("CPU offloading requires torch >= 1.9.0")
if execution_device is None:
execution_device = next(iter(model.parameters())).device
if state_dict is None:
@@ -157,6 +162,8 @@ def disk_offload(
called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
`dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
"""
if not is_torch_version(">=", "1.9.0"):
raise NotImplementedError("Disk offloading requires torch >= 1.9.0")
if not os.path.isdir(offload_dir) or not os.path.isfile(os.path.join(offload_dir, "index.json")):
offload_state_dict(offload_dir, model.state_dict())
if execution_device is None:
@@ -208,6 +215,8 @@ def dispatch_model(
called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
`dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
"""
if not is_torch_version(">=", "1.9.0"):
raise NotImplementedError("Model dispatching requires torch >= 1.9.0")
# Error early if the device map is incomplete.
check_device_map(model, device_map)

@@ -304,6 +313,8 @@ def load_checkpoint_and_dispatch(
called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
`dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
"""
if not is_torch_version(">=", "1.9.0"):
raise NotImplementedError("Loading and dispatching requires torch >= 1.9.0")
if isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
raise ValueError(
"If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or "
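The 1.9.0 floor in these new guards comes from the `meta` device, which the big-model utilities build on. A quick sketch of what becomes available at that version (plain torch, assuming torch >= 1.9; not accelerate code):

```python
import torch

# The "meta" device (torch >= 1.9) records only shape and dtype, allocating no storage.
# init_empty_weights / dispatch_model rely on it to instantiate huge models "for free"
# and materialize the real weights later on the execution device.
weight = torch.empty((50_000, 50_000), device="meta")
print(weight.shape, weight.device)  # torch.Size([50000, 50000]) meta

# No data lives behind a meta tensor, so .item(), .tolist(), etc. fail;
# it must be re-created on a concrete device before use.
```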
4 changes: 2 additions & 2 deletions src/accelerate/data_loader.py
@@ -75,11 +75,11 @@ def total_dataset_length(self):
"timeout": 0,
"worker_init_fn": None,
"multiprocessing_context": None,
"generator": None,
}

# kwargs added after by version
_PYTORCH_DATALOADER_ADDITIONAL_KWARGS = {
"1.6.0": {"generator": None},
"1.7.0": {"prefetch_factor": 2, "persistent_workers": False},
}

@@ -412,7 +412,7 @@ def __init__(self, dataset, split_batches: bool = False, _drop_last: bool = Fals
self.split_batches = split_batches
if is_torch_version("<", "1.8.0"):
raise ImportError(
"Using `DataLoaderDispatcher` requires PyTorch 1.8.0 minimum. You have {torch.__version__}."
f"Using `DataLoaderDispatcher` requires PyTorch 1.8.0 minimum. You have {torch.__version__}."
)
if shuffle:
torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)
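For readers following the kwargs change: `generator` moves into the always-available defaults, while later additions stay version-gated. The merge logic works roughly like this (a hypothetical helper for illustration, not the file's exact code):

```python
from accelerate.utils.versions import is_torch_version

def merge_dataloader_kwargs(base_kwargs, additional_kwargs):
    # Hypothetical sketch: start from defaults valid for the minimum supported torch,
    # then layer in kwargs that only newer torch releases understand.
    kwargs = dict(base_kwargs)
    for min_version, extras in additional_kwargs.items():
        if is_torch_version(">=", min_version):
            kwargs.update(extras)
    return kwargs
```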
14 changes: 1 addition & 13 deletions src/accelerate/launchers.py
@@ -20,7 +20,7 @@
import torch

from .state import AcceleratorState
from .utils import PrecisionType, PrepareForLaunch, is_torch_version, patch_environment
from .utils import PrecisionType, PrepareForLaunch, patch_environment


def notebook_launcher(function, args=(), num_processes=None, use_fp16=False, mixed_precision="no", use_port="29500"):
@@ -90,12 +90,6 @@ def notebook_launcher(function, args=(), num_processes=None, use_fp16=False, mix

if num_processes > 1:
# Multi-GPU launch
if is_torch_version("<", "1.5.0"):
raise ImportError(
"Using `notebook_launcher` for distributed training on GPUs require torch >= 1.5.0, got "
f"{torch.__version__}."
)

from torch.multiprocessing import start_processes

if len(AcceleratorState._shared_state) > 0:
@@ -154,12 +148,6 @@ def debug_launcher(function, args=(), num_processes=2):
num_processes (`int`, *optional*, defaults to 2):
The number of processes to use for training.
"""
if is_torch_version("<", "1.5.0"):
raise ImportError(
"Using `debug_launcher` for distributed training on GPUs require torch >= 1.5.0, got "
f"{torch.__version__}."
)

from torch.multiprocessing import start_processes

with tempfile.NamedTemporaryFile() as tmp_file:
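Both launchers can now use `start_processes` without a version check, since every supported torch ships it. As a reminder of the primitive they wrap (a minimal standalone sketch, not the launchers' actual wiring):

```python
from torch.multiprocessing import start_processes

def _worker(index, greeting):
    # start_processes passes the process index first, then the provided args.
    print(f"{greeting} from process {index}")

if __name__ == "__main__":
    # Hypothetical two-process launch; notebook_launcher/debug_launcher wrap the
    # user's function similarly and set up the distributed environment around it.
    start_processes(_worker, args=("hello",), nprocs=2, start_method="fork")
```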
1 change: 1 addition & 0 deletions src/accelerate/test_utils/__init__.py
@@ -10,6 +10,7 @@
require_huggingface_suite,
require_multi_gpu,
require_single_gpu,
require_torch_min_version,
require_tpu,
skip,
slow,
9 changes: 4 additions & 5 deletions src/accelerate/test_utils/scripts/test_script.py
@@ -46,10 +46,9 @@ def rng_sync_check():
if state.distributed_type == DistributedType.MULTI_GPU:
synchronize_rng_states(["cuda"])
assert are_the_same_tensors(torch.cuda.get_rng_state()), "RNG states improperly synchronized on GPU."
if is_torch_version(">=", "1.6.0"):
generator = torch.Generator()
synchronize_rng_states(["generator"], generator=generator)
assert are_the_same_tensors(generator.get_state()), "RNG states improperly synchronized in generator."
generator = torch.Generator()
synchronize_rng_states(["generator"], generator=generator)
assert are_the_same_tensors(generator.get_state()), "RNG states improperly synchronized in generator."

if state.local_process_index == 0:
print("All rng are properly synched.")
@@ -339,7 +338,7 @@ def main():
if state.local_process_index == 0:
print("\n**DataLoader integration test**")
dl_preparation_check()
if state.distributed_type != DistributedType.TPU:
if state.distributed_type != DistributedType.TPU and is_torch_version(">=", "1.8.0"):
central_dl_preparation_check()

# Trainings are not exactly the same in DeepSpeed and CPU mode
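The generator-based check above no longer needs a version guard. As a reminder of the primitive it exercises (a plain-torch sketch, not accelerate code): a `torch.Generator`'s state can be captured and restored, which is what `synchronize_rng_states(["generator"], ...)` broadcasts across processes.

```python
import torch

generator = torch.Generator()
generator.manual_seed(42)
state = generator.get_state()  # a ByteTensor snapshot of the RNG state

other = torch.Generator()
other.set_state(state)  # after this, both generators produce identical streams
assert torch.equal(torch.rand(3, generator=generator), torch.rand(3, generator=other))
```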
11 changes: 11 additions & 0 deletions src/accelerate/test_utils/testing.py
@@ -20,6 +20,7 @@
import tempfile
import unittest
from distutils.util import strtobool
from functools import partial
from pathlib import Path
from typing import List, Union
from unittest import mock
@@ -132,6 +133,16 @@ def require_fsdp(test_case):
return unittest.skipUnless(is_torch_version(">=", "1.12.0"), "test requires torch version >= 1.12.0")(test_case)


def require_torch_min_version(test_case=None, version=None):
"""
Decorator marking that a test requires a particular torch version to be tested. These tests are skipped when an
installed torch version is less than the required one.
"""
if test_case is None:
return partial(require_torch_min_version, version=version)
return unittest.skipUnless(is_torch_version(">=", version), f"test requires torch version >= {version}")(test_case)


def require_tensorboard(test_case):
"""
Decorator marking a test that requires tensorboard installed. These tests are skipped when tensorboard isn't
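A usage sketch of the new decorator, mirroring how the test files below apply it; because it resolves to `unittest.skipUnless`, it can mark a whole test class or a single test method (hypothetical test class for illustration):

```python
import unittest

from accelerate.test_utils import require_torch_min_version

@require_torch_min_version(version="1.9.0")
class ExampleTester(unittest.TestCase):
    def test_meta_device(self):
        ...
```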
3 changes: 2 additions & 1 deletion tests/test_big_modeling.py
@@ -27,7 +27,7 @@
load_checkpoint_and_dispatch,
)
from accelerate.hooks import remove_hook_from_submodules
from accelerate.test_utils import require_cuda, require_multi_gpu, slow
from accelerate.test_utils import require_cuda, require_multi_gpu, require_torch_min_version, slow
from accelerate.utils import offload_state_dict
from transformers import AutoModelForCausalLM, AutoTokenizer

@@ -79,6 +79,7 @@ def forward(self, x):
return self.linear4(self.linear3(self.batchnorm(self.linear2(self.linear1(x)))))


@require_torch_min_version(version="1.9.0")
class BigModelingTester(unittest.TestCase):
def test_init_empty_weights(self):
# base use
3 changes: 2 additions & 1 deletion tests/test_hooks.py
@@ -27,7 +27,7 @@
remove_hook_from_module,
remove_hook_from_submodules,
)
from accelerate.test_utils import require_multi_gpu
from accelerate.test_utils import require_multi_gpu, require_torch_min_version


class ModelForTest(nn.Module):
@@ -51,6 +51,7 @@ def post_forward(self, module, output):
return output + 1


@require_torch_min_version(version="1.9.0")
class HooksModelTester(unittest.TestCase):
def test_add_and_remove_hooks(self):
test_model = ModelForTest()
2 changes: 2 additions & 0 deletions tests/test_metrics.py
@@ -26,11 +26,13 @@
require_huggingface_suite,
require_multi_gpu,
require_single_gpu,
require_torch_min_version,
)
from accelerate.utils import get_launch_prefix, patch_environment


@require_huggingface_suite
@require_torch_min_version(version="1.8.0")
class MetricTester(unittest.TestCase):
def setUp(self):
mod_file = inspect.getfile(accelerate.test_utils)
2 changes: 2 additions & 0 deletions tests/test_modeling_utils.py
@@ -21,6 +21,7 @@
import torch.nn as nn

from accelerate.test_utils import require_cuda, require_multi_gpu
from accelerate.test_utils.testing import require_torch_min_version
from accelerate.utils.modeling import (
check_device_map,
clean_device_map,
@@ -45,6 +46,7 @@ def forward(self, x):
return self.linear2(self.batchnorm(self.linear1(x)))


@require_torch_min_version(version="1.9.0")
class ModelingUtilsTester(unittest.TestCase):
def check_set_module_tensor_for_device(self, model, device1, device2):
self.assertEqual(model.linear1.weight.device, torch.device(device1))