Skip to content

Commit

Permalink
Merge pull request #124 from MaciejSkrabski/update-tests
Browse files Browse the repository at this point in the history
skip multi-gpu test if not multi-gpu host
  • Loading branch information
WenjieDu authored May 21, 2023
2 parents a49676f + ca47365 commit fe5a358
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion tests/test_training_on_multi_gpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import numpy as np
import pytest

import torch

from pypots.classification import BRITS, GRUD, Raindrop
from pypots.clustering import VaDER, CRLI
from pypots.forecasting import BTTF
Expand All @@ -33,7 +35,12 @@
)

EPOCHS = 5
DEVICES = ["cuda:0", "cuda:1"]

DEVICES = [torch.device(i) for i in range(torch.cuda.device_count())]
LESS_THAN_TWO_DEVICES = len(DEVICES) < 2

# global skip test if less than two cuda-enabled devices
pytestmark = pytest.mark.skipif(LESS_THAN_TWO_DEVICES, reason="not enough cuda devices")


TRAIN_SET = {"X": DATA["train_X"], "y": DATA["train_y"]}
Expand Down

0 comments on commit fe5a358

Please sign in to comment.