Skip to content

Commit

Permalink
Fix tests when PyTorch built with MPS support (#1188)
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart authored Mar 22, 2023
1 parent 0851e6d commit e9113e3
Show file tree
Hide file tree
Showing 8 changed files with 98 additions and 19 deletions.
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ filterwarnings = [
"ignore:.* is deprecated and will be removed in Pillow 10:DeprecationWarning:pretrainedmodels.datasets.utils",
# https://github.com/pytorch/vision/pull/5898
"ignore:.* is deprecated and will be removed in Pillow 10:DeprecationWarning:torchvision.transforms.functional_pil",
"ignore:.* is deprecated and will be removed in Pillow 10:DeprecationWarning:torchvision.transforms._functional_pil",
# https://github.com/rwightman/pytorch-image-models/pull/1256
"ignore:.* is deprecated and will be removed in Pillow 10:DeprecationWarning:timm.data",
# https://github.com/pytorch/pytorch/issues/72906
Expand Down Expand Up @@ -99,8 +100,11 @@ filterwarnings = [
# Expected warnings
# Lightning warns us about using num_workers=0, but it's faster on macOS
"ignore:The dataloader, .*, does not have many workers which may be a bottleneck:UserWarning",
# Lightning warns us about using the CPU when a GPU is available
# Lightning warns us about using the CPU when GPU/MPS is available
"ignore:GPU available but not used.:UserWarning",
"ignore:MPS available but not used.:UserWarning",
# Lightning warns us if TensorBoard is not installed
"ignore:Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package:UserWarning",
# https://github.com/kornia/kornia/pull/1611
"ignore:`ColorJitter` is now following Torchvision implementation.:DeprecationWarning:kornia.augmentation._2d.intensity.color_jitter",
# https://github.com/kornia/kornia/pull/1663
Expand Down
4 changes: 2 additions & 2 deletions tests/datamodules/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ class TestGeoDataModule:
@pytest.fixture(params=[SamplerGeoDataModule, BatchSamplerGeoDataModule])
def datamodule(self, request: SubRequest) -> CustomGeoDataModule:
dm: CustomGeoDataModule = request.param()
dm.trainer = Trainer(max_epochs=1)
dm.trainer = Trainer(accelerator="cpu", max_epochs=1)
return dm

@pytest.mark.parametrize("stage", ["fit", "validate", "test"])
Expand Down Expand Up @@ -145,7 +145,7 @@ class TestNonGeoDataModule:
@pytest.fixture
def datamodule(self) -> CustomNonGeoDataModule:
dm = CustomNonGeoDataModule()
dm.trainer = Trainer(max_epochs=1)
dm.trainer = Trainer(accelerator="cpu", max_epochs=1)
return dm

@pytest.mark.parametrize("stage", ["fit", "validate", "test", "predict"])
Expand Down
2 changes: 1 addition & 1 deletion tests/datamodules/test_oscd.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def datamodule(self, request: SubRequest) -> OSCDDataModule:
num_workers=0,
)
dm.prepare_data()
dm.trainer = Trainer(max_epochs=1)
dm.trainer = Trainer(accelerator="cpu", max_epochs=1)
return dm

def test_train_dataloader(self, datamodule: OSCDDataModule) -> None:
Expand Down
7 changes: 6 additions & 1 deletion tests/trainers/test_byol.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,12 @@ def test_trainer(
model.backbone = SegmentationTestModel(**model_kwargs)

# Instantiate trainer
trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1)
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.fit(model=model, datamodule=datamodule)

@pytest.fixture
Expand Down
42 changes: 36 additions & 6 deletions tests/trainers/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,12 @@ def test_trainer(
model = ClassificationTask(**model_kwargs)

# Instantiate trainer
trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1)
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.fit(model=model, datamodule=datamodule)
try:
trainer.test(model=model, datamodule=datamodule)
Expand Down Expand Up @@ -208,15 +213,25 @@ def test_no_rgb(
root="tests/data/eurosat", batch_size=1, num_workers=0
)
model = ClassificationTask(**model_kwargs)
trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1)
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.validate(model=model, datamodule=datamodule)

def test_predict(self, model_kwargs: Dict[Any, Any], fast_dev_run: bool) -> None:
datamodule = PredictClassificationDataModule(
root="tests/data/eurosat", batch_size=1, num_workers=0
)
model = ClassificationTask(**model_kwargs)
trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1)
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.predict(model=model, datamodule=datamodule)


Expand Down Expand Up @@ -250,7 +265,12 @@ def test_trainer(
model = MultiLabelClassificationTask(**model_kwargs)

# Instantiate trainer
trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1)
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.fit(model=model, datamodule=datamodule)
try:
trainer.test(model=model, datamodule=datamodule)
Expand Down Expand Up @@ -285,13 +305,23 @@ def test_no_rgb(
root="tests/data/bigearthnet", batch_size=1, num_workers=0
)
model = MultiLabelClassificationTask(**model_kwargs)
trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1)
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.validate(model=model, datamodule=datamodule)

def test_predict(self, model_kwargs: Dict[Any, Any], fast_dev_run: bool) -> None:
datamodule = PredictMultiLabelClassificationDataModule(
root="tests/data/bigearthnet", batch_size=1, num_workers=0
)
model = MultiLabelClassificationTask(**model_kwargs)
trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1)
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.predict(model=model, datamodule=datamodule)
21 changes: 18 additions & 3 deletions tests/trainers/test_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,12 @@ def test_trainer(
model = ObjectDetectionTask(**model_kwargs)

# Instantiate trainer
trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1)
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.fit(model=model, datamodule=datamodule)
try:
trainer.test(model=model, datamodule=datamodule)
Expand Down Expand Up @@ -131,13 +136,23 @@ def test_no_rgb(
root="tests/data/nasa_marine_debris", batch_size=1, num_workers=0
)
model = ObjectDetectionTask(**model_kwargs)
trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1)
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.validate(model=model, datamodule=datamodule)

def test_predict(self, model_kwargs: Dict[Any, Any], fast_dev_run: bool) -> None:
datamodule = PredictObjectDetectionDataModule(
root="tests/data/nasa_marine_debris", batch_size=1, num_workers=0
)
model = ObjectDetectionTask(**model_kwargs)
trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1)
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.predict(model=model, datamodule=datamodule)
21 changes: 18 additions & 3 deletions tests/trainers/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,12 @@ def test_trainer(
model.model = RegressionTestModel()

# Instantiate trainer
trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1)
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.fit(model=model, datamodule=datamodule)
try:
trainer.test(model=model, datamodule=datamodule)
Expand Down Expand Up @@ -165,13 +170,23 @@ def test_no_rgb(
root="tests/data/cyclone", batch_size=1, num_workers=0
)
model = RegressionTask(**model_kwargs)
trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1)
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.validate(model=model, datamodule=datamodule)

def test_predict(self, model_kwargs: Dict[Any, Any], fast_dev_run: bool) -> None:
datamodule = PredictRegressionDataModule(
root="tests/data/cyclone", batch_size=1, num_workers=0
)
model = RegressionTask(**model_kwargs)
trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1)
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.predict(model=model, datamodule=datamodule)
14 changes: 12 additions & 2 deletions tests/trainers/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,12 @@ def test_trainer(
model = SemanticSegmentationTask(**model_kwargs)

# Instantiate trainer
trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1)
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.fit(model=model, datamodule=datamodule)
try:
trainer.test(model=model, datamodule=datamodule)
Expand Down Expand Up @@ -161,5 +166,10 @@ def test_no_rgb(
root="tests/data/sen12ms", batch_size=1, num_workers=0
)
model = SemanticSegmentationTask(**model_kwargs)
trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1)
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.validate(model=model, datamodule=datamodule)

0 comments on commit e9113e3

Please sign in to comment.