Skip to content
This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Commit

Permalink
Skip some tests in Python < 3.10 (#257)
Browse files Browse the repository at this point in the history
This PR skips the `mosaicml/mpt-7b` model in the
`tests/models/test_big_models.py::test_models` test when using Python
3.8 since that model has custom code that does not work in Python 3.8.

It also skips the clean shutdown test in Python < 3.10 as that was
failing for some reason in our CI (running the file through pytest
locally passes).
  • Loading branch information
dbarbuzzi authored May 22, 2024
1 parent c57b35d commit 79bb15a
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 0 deletions.
7 changes: 7 additions & 0 deletions tests/engine/test_multiproc_workers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import asyncio
# UPSTREAM SYNC
import sys
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from time import sleep
Expand Down Expand Up @@ -100,6 +102,11 @@ def execute_workers(worker_input: str) -> None:
def test_local_workers_clean_shutdown() -> None:
"""Test clean shutdown"""

# UPSTREAM SYNC
pytest.mark.skipif(sys.version_info < (3, 10),
reason="This test is inexplicably failing in CI "
"on Python < 3.10")

workers, worker_monitor = _start_workers()

assert worker_monitor.is_alive()
Expand Down
12 changes: 12 additions & 0 deletions tests/models/test_big_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
Run `pytest tests/models/test_big_models.py`.
"""
# UPSTREAM SYNC
import sys

import pytest

MODELS = [
Expand All @@ -27,6 +30,11 @@
"EleutherAI/gpt-j-6b",
]

# UPSTREAM SYNC
SKIPPED_MODELS_PY38 = [
"mosaicml/mpt-7b",
]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
Expand All @@ -45,6 +53,10 @@ def test_models(
if model in SKIPPED_MODELS_OOM:
pytest.skip(reason="These models cause OOM issue on the CPU"
"because it is a fp32 checkpoint.")
# UPSTREAM SYNC
if model in SKIPPED_MODELS_PY38 and sys.version_info < (3, 9):
pytest.skip(reason="This model has custom code that does not "
"support Python 3.8")

hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
Expand Down

0 comments on commit 79bb15a

Please sign in to comment.