diff --git a/src/accelerate/test_utils/testing.py b/src/accelerate/test_utils/testing.py index cce7347f684..6a9e66d6a5a 100644 --- a/src/accelerate/test_utils/testing.py +++ b/src/accelerate/test_utils/testing.py @@ -13,8 +13,14 @@ # limitations under the License. import asyncio +import os +import shutil import sys +import tempfile import unittest +from pathlib import Path +from typing import List, Union +from unittest import mock import torch @@ -22,6 +28,69 @@ from ..utils import gather, is_tensorflow_available +class TempDirTestCase(unittest.TestCase): + """ + A TestCase class that keeps a single `tempfile.TemporaryDirectory` open for the duration of the class, wipes its + data at the start of a test, and then destroyes it at the end of the TestCase. + + Useful for when a class or API requires a single constant folder throughout it's use, such as Weights and Biases + + The temporary directory location will be stored in `self.tmpdir` + """ + + @classmethod + def setUpClass(cls): + "Creates a `tempfile.TemporaryDirectory` and stores it in `cls.tmpdir`" + cls.tmpdir = tempfile.mkdtemp() + + @classmethod + def tearDownClass(cls): + "Remove `cls.tmpdir` after test suite has finished" + if os.path.exists(cls.tmpdir): + shutil.rmtree(cls.tmpdir) + + def setUp(self): + "Destroy all contents in `self.tmpdir`, but not `self.tmpdir`" + for path in Path(self.tmpdir).glob("**/*"): + if path.is_file(): + path.unlink() + elif path.is_dir(): + shutil.rmtree(path) + + +class MockingTestCase(unittest.TestCase): + """ + A TestCase class designed to dynamically add various mockers that should be used in every test, mimicking the + behavior of a class-wide mock when defining one normally will not do. + + Useful when a mock requires specific information available only initialized after `TestCase.setUpClass`, such as + setting an environment variable with that information. + + The `add_mocks` function should be ran at the end of a `TestCase`'s `setUp` function, after a call to + `super().setUp()` such as: + ```python + def setUp(self): + super().setUp() + mocks = mock.patch.dict(os.environ, {"SOME_ENV_VAR", "SOME_VALUE"}) + self.add_mocks(mocks) + ``` + """ + + def add_mocks(self, mocks: Union[mock.Mock, List[mock.Mock]]): + """ + Add custom mocks for tests that should be repeated on each test. Should be called during + `MockingTestCase.setUp`, after `super().setUp()`. + + Args: + mocks (`mock.Mock` or list of `mock.Mock`): + Mocks that should be added to the `TestCase` after `TestCase.setUpClass` has been run + """ + self.mocks = mocks if isinstance(mocks, (tuple, list)) else [mocks] + for m in self.mocks: + m.start() + self.addCleanup(m.stop) + + def are_the_same_tensors(tensor): state = AcceleratorState() tensor = tensor[None].clone().to(state.device) diff --git a/tests/test_tracking.py b/tests/test_tracking.py index 8d5ebc6bc08..b1f66b1ae99 100644 --- a/tests/test_tracking.py +++ b/tests/test_tracking.py @@ -15,7 +15,6 @@ import logging import os import re -import shutil import tempfile import unittest from pathlib import Path @@ -23,7 +22,7 @@ # We use TF to parse the logs from accelerate import Accelerator -from accelerate.test_utils.testing import require_tensorflow +from accelerate.test_utils.testing import MockingTestCase, TempDirTestCase, require_tensorflow from accelerate.utils import is_tensorflow_available @@ -43,13 +42,11 @@ def test_init_trackers(self): hps = None project_name = "test_project_with_config" with tempfile.TemporaryDirectory() as dirpath: - oldpwd = os.getcwd() - os.chdir(dirpath) - accelerator = Accelerator(log_with="tensorboard") + accelerator = Accelerator(log_with="tensorboard", logging_dir=dirpath) config = {"num_iterations": 12, "learning_rate": 1e-2, "some_boolean": False, "some_string": "some_value"} accelerator.init_trackers(project_name, config) accelerator.end_training() - for child in Path(project_name).glob("*/**"): + for child in Path(f"{dirpath}/{project_name}").glob("*/**"): log = list(filter(lambda x: x.is_file(), child.iterdir()))[0] # The config log is stored one layer deeper in the logged directory # And names are randomly generated each time @@ -61,7 +58,6 @@ def test_init_trackers(self): plugin_data = plugin_data_pb2.HParamsPluginData.FromString(proto_bytes) if plugin_data.HasField("session_start_info"): hps = dict(plugin_data.session_start_info.hparams) - os.chdir(oldpwd) self.assertTrue(isinstance(hps, dict)) keys = list(hps.keys()) @@ -77,16 +73,14 @@ def test_log(self): step = None project_name = "test_project_with_log" with tempfile.TemporaryDirectory() as dirpath: - oldpwd = os.getcwd() - os.chdir(dirpath) - accelerator = Accelerator(log_with="tensorboard") + accelerator = Accelerator(log_with="tensorboard", logging_dir=dirpath) accelerator.init_trackers(project_name) values = {"total_loss": 0.1, "iteration": 1, "my_text": "some_value"} accelerator.log(values, step=0) accelerator.end_training() # Logged values are stored in the outermost-tfevents file and can be read in as a TFRecord # Names are randomly generated each time - log = list(filter(lambda x: x.is_file(), Path(project_name).iterdir()))[0] + log = list(filter(lambda x: x.is_file(), Path(f"{dirpath}/{project_name}").iterdir()))[0] serialized_examples = tf.data.TFRecordDataset(log) for e in serialized_examples: event = event_pb2.Event.FromString(e.numpy()) @@ -99,14 +93,18 @@ def test_log(self): iteration = value.simple_value elif value.tag == "my_text/text_summary": # Append /text_summary to the key my_text = value.tensor.string_val[0].decode() - os.chdir(oldpwd) self.assertAlmostEqual(total_loss, values["total_loss"]) self.assertEqual(iteration, values["iteration"]) self.assertEqual(my_text, values["my_text"]) @mock.patch.dict(os.environ, {"WANDB_MODE": "offline"}) -class WandBTrackingTest(unittest.TestCase): +class WandBTrackingTest(TempDirTestCase, MockingTestCase): + def setUp(self): + super().setUp() + # wandb let's us override where logs are stored to via the WANDB_DIR env var + self.add_mocks(mock.patch.dict(os.environ, {"WANDB_DIR": self.tmpdir})) + @staticmethod def get_value_from_log(key: str, log: str, key_occurance: int = 0): """ @@ -126,7 +124,7 @@ def test_init_trackers(self): accelerator.init_trackers(project_name, config) accelerator.end_training() # The latest offline log is stored at wandb/latest-run/*.wandb - for child in Path("wandb/latest-run").glob("*"): + for child in Path(f"{self.tmpdir}/wandb/latest-run").glob("*"): logger.info(child) if child.is_file() and child.suffix == ".wandb": with open(child, "rb") as f: @@ -148,7 +146,7 @@ def test_log(self): accelerator.log(values, step=0) accelerator.end_training() # The latest offline log is stored at wandb/latest-run/*.wandb - for child in Path("wandb/latest-run").glob("*"): + for child in Path(f"{self.tmpdir}/wandb/latest-run").glob("*"): if child.is_file() and child.suffix == ".wandb": with open(child, "rb") as f: content = f.read() @@ -159,18 +157,3 @@ def test_log(self): self.assertEqual(self.get_value_from_log("iteration", cleaned_log), "1") self.assertEqual(self.get_value_from_log("my_text", cleaned_log), "some_value") self.assertEqual(self.get_value_from_log("_step", cleaned_log), "0") - - def setUp(self): - os.mkdir(".wandb_tests") - os.chdir(".wandb_tests") - - def tearDown(self): - if os.getcwd().endswith(".wandb_tests"): - os.chdir("..") - if os.path.exists(".wandb_tests"): - shutil.rmtree(".wandb_tests") - - @classmethod - def setUpClass(cls): - if os.path.exists(".wandb_tests"): - shutil.rmtree(".wandb_tests")