diff --git a/invokeai/app/services/download/download_default.py b/invokeai/app/services/download/download_default.py index 50cac80d094..843351a2591 100644 --- a/invokeai/app/services/download/download_default.py +++ b/invokeai/app/services/download/download_default.py @@ -224,7 +224,6 @@ def _download_next_item(self) -> None: job.job_started = get_iso_timestamp() self._do_download(job) self._signal_job_complete(job) - except (OSError, HTTPError) as excp: job.error_type = excp.__class__.__name__ + f"({str(excp)})" job.error = traceback.format_exc() diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index 737f62a0649..4f2cdaed8e8 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -28,6 +28,7 @@ class InstallStatus(str, Enum): WAITING = "waiting" # waiting to be dequeued DOWNLOADING = "downloading" # downloading of model files in process + DOWNLOADS_DONE = "downloads_done" # downloading done, waiting to run RUNNING = "running" # being processed COMPLETED = "completed" # finished running ERROR = "error" # terminated with an error message @@ -229,6 +230,11 @@ def downloading(self) -> bool: """Return true if job is downloading.""" return self.status == InstallStatus.DOWNLOADING + @property + def downloads_done(self) -> bool: + """Return true if job's downloads ae done.""" + return self.status == InstallStatus.DOWNLOADS_DONE + @property def running(self) -> bool: """Return true if job is running.""" diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index a48cf92b994..93287a40c64 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -28,7 +28,6 @@ ModelRepoVariant, ModelType, ) -from invokeai.backend.model_manager.hash import FastModelHash from invokeai.backend.model_manager.metadata import ( AnyModelRepoMetadata, CivitaiMetadataFetch, @@ -153,7 +152,6 @@ def install_path( config["source"] = model_path.resolve().as_posix() info: AnyModelConfig = self._probe_model(Path(model_path), config) - old_hash = info.current_hash if preferred_name := config.get("name"): preferred_name = Path(preferred_name).with_suffix(model_path.suffix) @@ -167,8 +165,6 @@ def install_path( raise DuplicateModelException( f"A model named {model_path.name} is already installed at {dest_path.as_posix()}" ) from excp - new_hash = FastModelHash.hash(new_path) - assert new_hash == old_hash, f"{model_path}: Model hash changed during installation, possibly corrupted." return self._register( new_path, @@ -284,7 +280,7 @@ def sync_to_config(self) -> None: def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102 self._cached_model_paths = {Path(x.path) for x in self.record_store.all_models()} callback = self._scan_install if install else self._scan_register - search = ModelSearch(on_model_found=callback) + search = ModelSearch(on_model_found=callback, config=self._app_config) self._models_installed.clear() search.search(scan_dir) return list(self._models_installed) @@ -370,7 +366,7 @@ def _install_next_item(self) -> None: self._signal_job_errored(job) elif ( - job.waiting or job.downloading + job.waiting or job.downloads_done ): # local jobs will be in waiting state, remote jobs will be downloading state job.total_bytes = self._stat_size(job.local_path) job.bytes = job.total_bytes @@ -448,7 +444,7 @@ def _scan_models_directory(self) -> None: installed.update(self.scan_directory(models_dir)) self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered") - def _sync_model_path(self, key: str, ignore_hash_change: bool = False) -> AnyModelConfig: + def _sync_model_path(self, key: str) -> AnyModelConfig: """ Move model into the location indicated by its basetype, type and name. @@ -469,14 +465,7 @@ def _sync_model_path(self, key: str, ignore_hash_change: bool = False) -> AnyMod new_path = models_dir / model.base.value / model.type.value / model.name self._logger.info(f"Moving {model.name} to {new_path}.") new_path = self._move_model(old_path, new_path) - new_hash = FastModelHash.hash(new_path) model.path = new_path.relative_to(models_dir).as_posix() - if model.current_hash != new_hash: - assert ( - ignore_hash_change - ), f"{model.name}: Model hash changed during installation, model is possibly corrupted" - model.current_hash = new_hash - self._logger.info(f"Model has new hash {model.current_hash}, but will continue to be identified by {key}") self.record_store.update_model(key, model) return model @@ -749,8 +738,8 @@ def _download_complete_callback(self, download_job: DownloadJob) -> None: self._download_cache.pop(download_job.source, None) # are there any more active jobs left in this task? - if all(x.complete for x in install_job.download_parts): - # now enqueue job for actual installation into the models directory + if install_job.downloading and all(x.complete for x in install_job.download_parts): + install_job.status = InstallStatus.DOWNLOADS_DONE self._install_queue.put(install_job) # Let other threads know that the number of downloads has changed diff --git a/pyproject.toml b/pyproject.toml index f4608063542..26db5a63c7b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -118,6 +118,7 @@ dependencies = [ "pre-commit", "pytest>6.0.0", "pytest-cov", + "pytest-timeout", "pytest-datadir", "requests_testadapter", "httpx", @@ -186,9 +187,10 @@ version = { attr = "invokeai.version.__version__" } #=== Begin: PyTest and Coverage [tool.pytest.ini_options] -addopts = "--cov-report term --cov-report html --cov-report xml --strict-markers -m \"not slow\"" +addopts = "--cov-report term --cov-report html --cov-report xml --strict-markers --timeout 60 -m \"not slow\"" markers = [ "slow: Marks tests as slow. Disabled by default. To run all tests, use -m \"\". To run only slow tests, use -m \"slow\".", + "timeout: Marks the timeout override." ] [tool.coverage.run] branch = true diff --git a/tests/app/routers/test_images.py b/tests/app/routers/test_images.py index 5cb8cf1c37b..c0da3ec51ca 100644 --- a/tests/app/routers/test_images.py +++ b/tests/app/routers/test_images.py @@ -1,6 +1,8 @@ +import os from pathlib import Path from typing import Any +import pytest from fastapi import BackgroundTasks from fastapi.testclient import TestClient @@ -9,7 +11,11 @@ from invokeai.app.services.board_records.board_records_common import BoardRecord from invokeai.app.services.invoker import Invoker -client = TestClient(app) + +@pytest.fixture(autouse=True, scope="module") +def client(invokeai_root_dir: Path) -> TestClient: + os.environ["INVOKEAI_ROOT"] = invokeai_root_dir.as_posix() + return TestClient(app) class MockApiDependencies(ApiDependencies): @@ -19,7 +25,7 @@ def __init__(self, invoker) -> None: self.invoker = invoker -def test_download_images_from_list(monkeypatch: Any, mock_invoker: Invoker) -> None: +def test_download_images_from_list(monkeypatch: Any, mock_invoker: Invoker, client: TestClient) -> None: prepare_download_images_test(monkeypatch, mock_invoker) response = client.post("/api/v1/images/download", json={"image_names": ["test.png"]}) @@ -28,7 +34,9 @@ def test_download_images_from_list(monkeypatch: Any, mock_invoker: Invoker) -> N assert json_response["bulk_download_item_name"] == "test.zip" -def test_download_images_from_board_id_empty_image_name_list(monkeypatch: Any, mock_invoker: Invoker) -> None: +def test_download_images_from_board_id_empty_image_name_list( + monkeypatch: Any, mock_invoker: Invoker, client: TestClient +) -> None: expected_board_name = "test" def mock_get(*args, **kwargs): @@ -56,7 +64,9 @@ def mock_add_task(*args, **kwargs): monkeypatch.setattr(BackgroundTasks, "add_task", mock_add_task) -def test_download_images_with_empty_image_list_and_no_board_id(monkeypatch: Any, mock_invoker: Invoker) -> None: +def test_download_images_with_empty_image_list_and_no_board_id( + monkeypatch: Any, mock_invoker: Invoker, client: TestClient +) -> None: prepare_download_images_test(monkeypatch, mock_invoker) response = client.post("/api/v1/images/download", json={"image_names": []}) @@ -64,7 +74,7 @@ def test_download_images_with_empty_image_list_and_no_board_id(monkeypatch: Any, assert response.status_code == 400 -def test_get_bulk_download_image(tmp_path: Path, monkeypatch: Any, mock_invoker: Invoker) -> None: +def test_get_bulk_download_image(tmp_path: Path, monkeypatch: Any, mock_invoker: Invoker, client: TestClient) -> None: mock_file: Path = tmp_path / "test.zip" mock_file.write_text("contents") @@ -82,7 +92,7 @@ def mock_add_task(*args, **kwargs): assert response.content == b"contents" -def test_get_bulk_download_image_not_found(monkeypatch: Any, mock_invoker: Invoker) -> None: +def test_get_bulk_download_image_not_found(monkeypatch: Any, mock_invoker: Invoker, client: TestClient) -> None: monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", MockApiDependencies(mock_invoker)) def mock_add_task(*args, **kwargs): @@ -96,7 +106,7 @@ def mock_add_task(*args, **kwargs): def test_get_bulk_download_image_image_deleted_after_response( - monkeypatch: Any, mock_invoker: Invoker, tmp_path: Path + monkeypatch: Any, mock_invoker: Invoker, tmp_path: Path, client: TestClient ) -> None: mock_file: Path = tmp_path / "test.zip" mock_file.write_text("contents") diff --git a/tests/app/services/download/test_download_queue.py b/tests/app/services/download/test_download_queue.py index 9c1826170e6..543703d713a 100644 --- a/tests/app/services/download/test_download_queue.py +++ b/tests/app/services/download/test_download_queue.py @@ -167,6 +167,7 @@ def broken_callback(job: DownloadJob) -> None: queue.stop() +@pytest.mark.timeout(timeout=15, method="thread") def test_cancel(tmp_path: Path, session: Session) -> None: event_bus = TestEventService() @@ -182,6 +183,9 @@ def cancelled_callback(job: DownloadJob) -> None: nonlocal cancelled cancelled = True + def handler(signum, frame): + raise TimeoutError("Join took too long to return") + job = queue.download( source=AnyHttpUrl("http://www.civitai.com/models/12345"), dest=tmp_path, diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py index 80b106c5cb2..7e51e8deb3d 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -196,6 +196,7 @@ def test_delete_register( store.get_model(key) +@pytest.mark.timeout(timeout=20, method="thread") def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: source = URLModelSource(url=Url("https://www.test.foo/download/test_embedding.safetensors")) @@ -221,6 +222,7 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: assert event_names == ["model_install_downloading", "model_install_running", "model_install_completed"] +@pytest.mark.timeout(timeout=20, method="thread") def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo")) diff --git a/tests/backend/model_manager/model_manager_fixtures.py b/tests/backend/model_manager/model_manager_fixtures.py index df54e2f9267..fce72cb04d7 100644 --- a/tests/backend/model_manager/model_manager_fixtures.py +++ b/tests/backend/model_manager/model_manager_fixtures.py @@ -2,6 +2,7 @@ import os import shutil +import time from pathlib import Path from typing import Any, Dict, List @@ -149,6 +150,7 @@ def mm2_installer( def stop_installer() -> None: installer.stop() + time.sleep(0.1) # avoid error message from the logger when it is closed before thread prints final message request.addfinalizer(stop_installer) return installer diff --git a/tests/conftest.py b/tests/conftest.py index a483b7529a1..06d29b05bed 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,8 @@ # We import the model_installer and torch_device fixtures here so that they can be used by all tests. Flake8 does not # play well with fixtures (F401 and F811), so this is cleaner than importing in all files that use these fixtures. import logging +import shutil +from pathlib import Path import pytest @@ -58,3 +60,11 @@ def mock_services() -> InvocationServices: @pytest.fixture() def mock_invoker(mock_services: InvocationServices) -> Invoker: return Invoker(services=mock_services) + + +@pytest.fixture(scope="module") +def invokeai_root_dir(tmp_path_factory) -> Path: + root_template = Path(__file__).parent.resolve() / "backend/model_manager/data/invokeai_root" + temp_dir: Path = tmp_path_factory.mktemp("data") / "invokeai_root" + shutil.copytree(root_template, temp_dir) + return temp_dir