diff --git a/test/float8/test_base.py b/test/float8/test_base.py
index db66a206ef..8fb3921f67 100644
--- a/test/float8/test_base.py
+++ b/test/float8/test_base.py
@@ -231,15 +231,6 @@ def test_linear(
         linear_dtype: torch.dtype,
         linear_bias: bool,
     ):
-        if not emulate:
-            if not torch.cuda.is_available():
-                warnings.warn("CUDA not available")
-                pytest.skip()
-            elif torch.cuda.get_device_capability() < (9, 0):
-                warnings.warn(
-                    f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
-                )
-                pytest.skip()
 
         x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
         m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)
@@ -287,16 +278,6 @@ def test_autocast_outputs(
         emulate: bool,
         linear_dtype: torch.dtype,
     ):
-        if not emulate:
-            if not torch.cuda.is_available():
-                warnings.warn("CUDA not available")
-                pytest.skip()
-            elif torch.cuda.get_device_capability() < (9, 0):
-                warnings.warn(
-                    f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
-                )
-                pytest.skip()
-
         m_ref = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
         config = Float8LinearConfig(
             cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
@@ -334,10 +315,6 @@ def test_autocast_outputs(
     @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool):
-        emulate = (
-            not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0)
-        )
-
         m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
         config = Float8LinearConfig(emulate=emulate)
         m = Float8Linear.from_float(copy.deepcopy(m), config)
diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py
index 8a0458bec3..bae62bf77d 100644
--- a/test/float8/test_compile.py
+++ b/test/float8/test_compile.py
@@ -224,7 +224,8 @@ def forward(self, x):
                 return x_hp
             return x_fp8
 
-    @unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with float8 support not available")
+    # TODO(future): figure out why the test below fails on CUDA capability 8.9
+    @unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with capability 9.0 or greater not available")
     def test_float8_with_graph_break_in_the_middle(self):
         """Test that having Float8Tensor object at the boundary of a subgraph"""
         cnts = CompileCounterWithBackend("inductor")