diff --git a/tests/__init__.py b/tests/__init__.py index 433f183896dee..fc634e6b73fec 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -19,8 +19,8 @@ _TEST_ROOT = os.path.dirname(__file__) _PROJECT_ROOT = os.path.dirname(_TEST_ROOT) _TEMP_PATH = os.path.join(_PROJECT_ROOT, 'test_temp') -DATASETS_PATH = os.path.join(_PROJECT_ROOT, 'Datasets') -LEGACY_PATH = os.path.join(_PROJECT_ROOT, 'legacy') +PATH_DATASETS = os.path.join(_PROJECT_ROOT, 'Datasets') +PATH_LEGACY = os.path.join(_PROJECT_ROOT, 'legacy') # todo: this setting `PYTHONPATH` may not be used by other evns like Conda for import packages if _PROJECT_ROOT not in os.getenv('PYTHONPATH', ""): diff --git a/tests/base/model_template.py b/tests/base/model_template.py index 1ec2df7865caa..86578fef4c699 100644 --- a/tests/base/model_template.py +++ b/tests/base/model_template.py @@ -18,6 +18,7 @@ import torch.nn.functional as F from pytorch_lightning.core.lightning import LightningModule +from tests import PATH_DATASETS from tests.base.model_optimizers import ConfigureOptimizersPool from tests.base.model_test_dataloaders import TestDataloaderVariations from tests.base.model_test_epoch_ends import TestEpochEndVariations @@ -28,7 +29,7 @@ from tests.base.model_valid_dataloaders import ValDataloaderVariations from tests.base.model_valid_epoch_ends import ValidationEpochEndVariations from tests.base.model_valid_steps import ValidationStepVariations -from tests.helpers.datasets import PATH_DATASETS, TrialMNIST +from tests.helpers.datasets import TrialMNIST class EvalModelTemplate( diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py index 0b47e25c7ce23..325cc4925f4f4 100644 --- a/tests/checkpointing/test_legacy_checkpoints.py +++ b/tests/checkpointing/test_legacy_checkpoints.py @@ -18,9 +18,9 @@ import pytest from pytorch_lightning import Trainer -from tests import LEGACY_PATH +from tests import PATH_LEGACY -LEGACY_CHECKPOINTS_PATH = os.path.join(LEGACY_PATH, 'checkpoints') +LEGACY_CHECKPOINTS_PATH = os.path.join(PATH_LEGACY, 'checkpoints') CHECKPOINT_EXTENSION = ".ckpt" diff --git a/tests/helpers/datasets.py b/tests/helpers/datasets.py index e7bdad0f1538c..312cbb32b162e 100644 --- a/tests/helpers/datasets.py +++ b/tests/helpers/datasets.py @@ -22,10 +22,7 @@ from torch import Tensor from torch.utils.data import Dataset -from tests import _PROJECT_ROOT - -#: local path to test datasets -PATH_DATASETS = os.path.join(_PROJECT_ROOT, 'Datasets') +from tests import PATH_DATASETS class MNIST(Dataset):