Skip to content

Commit

Permalink
Remove duplicated test classes (#14122)
Browse files Browse the repository at this point in the history
Remove duplicated classes
  • Loading branch information
carmocca authored Aug 10, 2022
1 parent 4e87a44 commit 9b61b1c
Show file tree
Hide file tree
Showing 10 changed files with 20 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ProgressBarBase, RichProgressBar
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
from tests_pytorch.helpers.datasets import RandomIterableDataset
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
from tests_pytorch.helpers.runif import RunIf


Expand Down
3 changes: 1 addition & 2 deletions tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,9 @@

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import StochasticWeightAveraging
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
from pytorch_lightning.strategies import DDPSpawnStrategy, Strategy
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests_pytorch.helpers.datasets import RandomIterableDataset
from tests_pytorch.helpers.runif import RunIf


Expand Down
39 changes: 1 addition & 38 deletions tests/tests_pytorch/helpers/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing import Optional, Sequence, Tuple

import torch
from torch.utils.data import Dataset, IterableDataset
from torch.utils.data import Dataset


class MNIST(Dataset):
Expand Down Expand Up @@ -212,40 +212,3 @@ def __getitem__(self, idx):

def __len__(self):
return len(self.y)


class RandomDictDataset(Dataset):
def __init__(self, size: int, length: int):
self.len = length
self.data = torch.randn(length, size)

def __getitem__(self, index):
a = self.data[index]
b = a + 2
return {"a": a, "b": b}

def __len__(self):
return self.len


class RandomIterableDataset(IterableDataset):
def __init__(self, size: int, count: int):
self.count = count
self.size = size

def __iter__(self):
for _ in range(self.count):
yield torch.randn(self.size)


class RandomIterableDatasetWithLen(IterableDataset):
def __init__(self, size: int, count: int):
self.count = count
self.size = size

def __iter__(self):
for _ in range(len(self)):
yield torch.randn(self.size)

def __len__(self):
return self.count
3 changes: 1 addition & 2 deletions tests/tests_pytorch/strategies/test_deepspeed_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,12 @@

from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
from pytorch_lightning.plugins import DeepSpeedPrecisionPlugin
from pytorch_lightning.strategies import DeepSpeedStrategy
from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE, LightningDeepSpeedModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests_pytorch.helpers.datamodules import ClassifDataModule
from tests_pytorch.helpers.datasets import RandomIterableDataset
from tests_pytorch.helpers.runif import RunIf

if _DEEPSPEED_AVAILABLE:
Expand Down
3 changes: 1 addition & 2 deletions tests/tests_pytorch/trainer/flags/test_val_check_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@
import pytest
from torch.utils.data import DataLoader

from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
from pytorch_lightning.trainer.trainer import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests_pytorch.helpers.datasets import RandomIterableDataset


@pytest.mark.parametrize("max_epochs", [1, 2, 3])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,8 @@
from pytorch_lightning import callbacks, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar
from pytorch_lightning.core.module import LightningModule
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomDictDataset
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests_pytorch.helpers.datasets import RandomDictDataset
from tests_pytorch.helpers.runif import RunIf


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,10 @@

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.demos.boring_classes import BoringModel, RandomIterableDataset
from pytorch_lightning.strategies.ipu import IPUStrategy
from pytorch_lightning.utilities import device_parser
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests_pytorch.helpers.datasets import RandomIterableDataset
from tests_pytorch.helpers.runif import RunIf


Expand Down
8 changes: 6 additions & 2 deletions tests/tests_pytorch/trainer/test_dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,16 @@

from pytorch_lightning import Callback, seed_everything, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
from pytorch_lightning.demos.boring_classes import (
BoringModel,
RandomDataset,
RandomIterableDataset,
RandomIterableDatasetWithLen,
)
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.data import _auto_add_worker_init_fn, has_iterable_dataset, has_len_all_ranks
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests_pytorch.helpers.dataloaders import CustomInfDataloader, CustomNotImplementedErrorDataloader
from tests_pytorch.helpers.datasets import RandomIterableDataset, RandomIterableDatasetWithLen
from tests_pytorch.helpers.runif import RunIf


Expand Down
8 changes: 6 additions & 2 deletions tests/tests_pytorch/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,12 @@
from pytorch_lightning.callbacks.fault_tolerance import _FaultToleranceCheckpoint
from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
from pytorch_lightning.demos.boring_classes import (
BoringModel,
RandomDataset,
RandomIterableDataset,
RandomIterableDatasetWithLen,
)
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper, UnrepeatedDistributedSampler
from pytorch_lightning.strategies import (
Expand All @@ -60,7 +65,6 @@
from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE, _TORCH_GREATER_EQUAL_1_12
from pytorch_lightning.utilities.seed import seed_everything
from tests_pytorch.helpers.datamodules import ClassifDataModule
from tests_pytorch.helpers.datasets import RandomIterableDataset, RandomIterableDatasetWithLen
from tests_pytorch.helpers.runif import RunIf
from tests_pytorch.helpers.simple_models import ClassificationModel

Expand Down
3 changes: 1 addition & 2 deletions tests/tests_pytorch/utilities/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler

from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.data import (
Expand All @@ -23,7 +23,6 @@
warning_cache,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests_pytorch.helpers.datasets import RandomIterableDataset
from tests_pytorch.helpers.utils import no_warning_call


Expand Down

0 comments on commit 9b61b1c

Please sign in to comment.