Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(ci): skip more tests on GPU CI #4200

Merged
merged 2 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions source/tests/consistent/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import itertools
import os
import sys
import unittest
from abc import (
ABC,
abstractmethod,
Expand Down Expand Up @@ -33,6 +34,11 @@
Backend,
)

from ..utils import (
CI,
TEST_DEVICE,
)

INSTALLED_TF = Backend.get_backend("tensorflow")().is_available()
INSTALLED_PT = Backend.get_backend("pytorch")().is_available()
INSTALLED_JAX = Backend.get_backend("jax")().is_available()
Expand Down Expand Up @@ -340,6 +346,7 @@ def test_tf_self_consistent(self):
np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol)
assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}"

@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
def test_dp_consistent_with_ref(self):
"""Test whether DP and reference are consistent."""
if self.skip_dp:
Expand All @@ -358,6 +365,7 @@ def test_dp_consistent_with_ref(self):
np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol)
assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}"

@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
def test_dp_self_consistent(self):
"""Test whether DP is self consistent."""
if self.skip_dp:
Expand Down Expand Up @@ -447,6 +455,7 @@ def test_jax_self_consistent(self):
else:
self.assertEqual(rr1, rr2)

@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
def test_array_api_strict_consistent_with_ref(self):
"""Test whether array_api_strict and reference are consistent."""
if self.skip_array_api_strict:
Expand All @@ -465,6 +474,7 @@ def test_array_api_strict_consistent_with_ref(self):
np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol)
assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}"

@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
def test_array_api_strict_self_consistent(self):
"""Test whether array_api_strict is self consistent."""
if self.skip_array_api_strict:
Expand Down
13 changes: 7 additions & 6 deletions source/tests/universal/common/cases/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
GLOBAL_SEED,
)
from .....utils import (
CI,
TEST_DEVICE,
)

Expand Down Expand Up @@ -327,7 +328,7 @@ def test_zero_forward(self):
continue
np.testing.assert_allclose(rr1, rr2, atol=aprec)

@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.")
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
def test_permutation(self):
"""Test permutation."""
if getattr(self, "skip_test_permutation", False):
Expand Down Expand Up @@ -413,7 +414,7 @@ def test_permutation(self):
else:
raise RuntimeError(f"Unknown output key: {kk}")

@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.")
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
def test_trans(self):
"""Test translation."""
if getattr(self, "skip_test_trans", False):
Expand Down Expand Up @@ -482,7 +483,7 @@ def test_trans(self):
else:
raise RuntimeError(f"Unknown output key: {kk}")

@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.")
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
def test_rot(self):
"""Test rotation."""
if getattr(self, "skip_test_rot", False):
Expand Down Expand Up @@ -672,7 +673,7 @@ def test_rot(self):
else:
raise RuntimeError(f"Unknown output key: {kk}")

@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.")
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
def test_smooth(self):
"""Test smooth."""
if getattr(self, "skip_test_smooth", False):
Expand Down Expand Up @@ -779,7 +780,7 @@ def test_smooth(self):
else:
raise RuntimeError(f"Unknown output key: {kk}")

@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.")
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
def test_autodiff(self):
"""Test autodiff."""
if getattr(self, "skip_test_autodiff", False):
Expand Down Expand Up @@ -919,7 +920,7 @@ def ff_cell(bb):
# not support virial by far
pass

@unittest.skipIf(TEST_DEVICE == "cpu", "Skip test on CPU.")
@unittest.skipIf(TEST_DEVICE == "cpu" and CI, "Skip test on CPU.")
def test_device_consistence(self):
"""Test forward consistency between devices."""
test_spin = getattr(self, "test_spin", False)
Expand Down
13 changes: 7 additions & 6 deletions source/tests/universal/dpmodel/atomc_model/test_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
parameterized,
)
from ....utils import (
CI,
TEST_DEVICE,
)
from ...common.cases.atomic_model.atomic_model import (
Expand Down Expand Up @@ -98,7 +99,7 @@
), # fitting_class_param & class
),
)
@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.")
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
class TestEnergyAtomicModelDP(unittest.TestCase, EnerAtomicModelTest, DPTestCase):
@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -165,7 +166,7 @@ def setUpClass(cls):
), # fitting_class_param & class
),
)
@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.")
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
class TestDosAtomicModelDP(unittest.TestCase, DosAtomicModelTest, DPTestCase):
@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -227,7 +228,7 @@ def setUpClass(cls):
), # fitting_class_param & class
),
)
@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.")
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
class TestDipoleAtomicModelDP(unittest.TestCase, DipoleAtomicModelTest, DPTestCase):
@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -290,7 +291,7 @@ def setUpClass(cls):
), # fitting_class_param & class
),
)
@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.")
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
class TestPolarAtomicModelDP(unittest.TestCase, PolarAtomicModelTest, DPTestCase):
@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -351,7 +352,7 @@ def setUpClass(cls):
), # fitting_class_param & class
),
)
@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.")
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
class TestZBLAtomicModelDP(unittest.TestCase, ZBLAtomicModelTest, DPTestCase):
@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -429,7 +430,7 @@ def setUpClass(cls):
), # fitting_class_param & class
),
)
@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.")
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
class TestPropertyAtomicModelDP(unittest.TestCase, PropertyAtomicModelTest, DPTestCase):
@classmethod
def setUpClass(cls):
Expand Down
3 changes: 2 additions & 1 deletion source/tests/universal/dpmodel/descriptor/test_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
GLOBAL_SEED,
)
from ....utils import (
CI,
TEST_DEVICE,
)
from ...common.cases.descriptor.descriptor import (
Expand Down Expand Up @@ -519,7 +520,7 @@ def DescriptorParamHybridMixedTTebd(ntypes, rcut, rcut_smth, sel, type_map, **kw
(DescriptorParamHybridMixedTTebd, DescrptHybrid),
) # class_param & class
)
@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.")
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
class TestDescriptorDP(unittest.TestCase, DescriptorTest, DPTestCase):
def setUp(self):
DescriptorTest.setUp(self)
Expand Down
3 changes: 2 additions & 1 deletion source/tests/universal/dpmodel/fitting/test_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
GLOBAL_SEED,
)
from ....utils import (
CI,
TEST_DEVICE,
)
from ...common.cases.fitting.fitting import (
Expand Down Expand Up @@ -236,7 +237,7 @@ def FittingParamProperty(
), # class_param & class
(True, False), # mixed_types
)
@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.")
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
class TestFittingDP(unittest.TestCase, FittingTest, DPTestCase):
def setUp(self):
((FittingParam, Fitting), self.mixed_types) = self.param
Expand Down
5 changes: 3 additions & 2 deletions source/tests/universal/dpmodel/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
parameterized,
)
from ....utils import (
CI,
TEST_DEVICE,
)
from ...common.cases.model.model import (
Expand Down Expand Up @@ -112,7 +113,7 @@ def skip_model_tests(test_obj):
), # fitting_class_param & class
),
)
@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.")
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
class TestEnergyModelDP(unittest.TestCase, EnerModelTest, DPTestCase):
@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -200,7 +201,7 @@ def setUpClass(cls):
), # fitting_class_param & class
),
)
@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.")
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
class TestSpinEnergyModelDP(unittest.TestCase, SpinEnerModelTest, DPTestCase):
@classmethod
def setUpClass(cls):
Expand Down
3 changes: 2 additions & 1 deletion source/tests/universal/dpmodel/utils/test_type_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
)

from ....utils import (
CI,
TEST_DEVICE,
)
from ...common.cases.utils.type_embed import (
Expand All @@ -16,7 +17,7 @@
)


@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.")
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
class TestTypeEmbd(unittest.TestCase, TypeEmbdTest, DPTestCase):
def setUp(self):
TypeEmbdTest.setUp(self)
Expand Down
3 changes: 3 additions & 0 deletions source/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@
TEST_DEVICE = "cpu"
else:
TEST_DEVICE = "cuda"

# see https://docs.github.com/en/actions/writing-workflows/choosing-what-your-workflow-does/store-information-in-variables#default-environment-variables
CI = os.environ.get("CI") == "true"