
Commit

Run CUDA-related functions inside the test function instead of calling them at the top level
llmpros committed Jul 6, 2024
1 parent e011cde commit d9de3ea
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions tests/distributed/test_distributed_gptq_marlin.py
@@ -13,13 +13,12 @@
 import pytest

 from tests.models.test_gptq_marlin import MODELS, run_test
-from tests.quantization.utils import is_quant_method_supported
+from tests.quantization.utils import (cuda_device_count_stateless,
+                                      is_quant_method_supported)


 @pytest.mark.parametrize("tensor_parallel_size", [2])
 @pytest.mark.flaky(reruns=3)
-@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
-                    reason="gptq_marlin is not supported on this GPU type.")
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half", "bfloat16"])
 @pytest.mark.parametrize("max_tokens", [32])
@@ -28,6 +27,13 @@ def test_models(vllm_runner, example_prompts, model, dtype: str,
                 max_tokens: int, num_logprobs: int,
                 tensor_parallel_size: int) -> None:

+    if cuda_device_count_stateless() < tensor_parallel_size:
+        pytest.skip(
+            f"Need at least {tensor_parallel_size} GPUs to run the test.")
+
+    if not is_quant_method_supported("gptq_marlin"):
+        pytest.skip("gptq_marlin is not supported on this GPU type.")
+
     distributed_executor_backend = os.getenv("DISTRIBUTED_EXECUTOR_BACKEND")
     run_test(vllm_runner,
              example_prompts,
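The pattern behind this change, as a minimal standalone sketch (hypothetical file and names, not vLLM code): the condition of a skipif decorator is evaluated when pytest imports the module, so a CUDA-touching check there runs during collection; moving the check into the test body defers it until the test actually executes.

# Hypothetical illustration of deferring CUDA checks into the test body;
# has_enough_gpus and test_needs_multiple_gpus are made-up names.
import pytest
import torch


def has_enough_gpus(n: int) -> bool:
    # Invoked only from inside a test, so importing this module during
    # pytest collection never touches the CUDA runtime.
    return torch.cuda.is_available() and torch.cuda.device_count() >= n


@pytest.mark.parametrize("tensor_parallel_size", [2])
def test_needs_multiple_gpus(tensor_parallel_size: int) -> None:
    # Runtime skip instead of a decorator-level skipif, whose condition
    # would already run at import time.
    if not has_enough_gpus(tensor_parallel_size):
        pytest.skip(
            f"Need at least {tensor_parallel_size} GPUs to run the test.")
    # ... multi-GPU test logic would go here ...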
