diff --git a/.github/workflows/make.yml b/.github/workflows/make.yml index 68693c627..0ff99ec37 100644 --- a/.github/workflows/make.yml +++ b/.github/workflows/make.yml @@ -16,7 +16,7 @@ on: jobs: test-build-release: # commit the word "build" to the commit message to enable this job - if: contains(${{ github.event.head_commit.message }}, 'build') + if: contains(github.event.head_commit.message, 'build') runs-on: ${{ matrix.os }} strategy: diff --git a/.github/workflows/test-backend.yml b/.github/workflows/test-backend.yml index d432d0e28..dc601592c 100644 --- a/.github/workflows/test-backend.yml +++ b/.github/workflows/test-backend.yml @@ -49,7 +49,6 @@ jobs: - uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - - run: npx ts-node ./scripts/install-required-deps.ts - run: python ./backend/src/run.py --no-run env: TYPE_CHECK_LEVEL: 'error' diff --git a/backend/src/api.py b/backend/src/api.py index e6bea2f41..c657ecaa4 100644 --- a/backend/src/api.py +++ b/backend/src/api.py @@ -19,6 +19,10 @@ typeValidateSchema, ) +KB = 1024**1 +MB = 1024**2 +GB = 1024**3 + def _process_inputs(base_inputs: Iterable[Union[BaseInput, NestedGroup]]): inputs: List[BaseInput] = [] @@ -174,11 +178,34 @@ def toDict(self): } +@dataclass +class Dependency: + display_name: str + pypi_name: str + version: str + size_estimate: int | float + auto_update: bool = False + extra_index_url: str | None = None + + import_name: str | None = None + + def toDict(self): + return { + "displayName": self.display_name, + "pypiName": self.pypi_name, + "version": self.version, + "sizeEstimate": int(self.size_estimate), + "autoUpdate": self.auto_update, + "findLink": self.extra_index_url, + } + + @dataclass class Package: where: str name: str - dependencies: List[str] = field(default_factory=list) + description: str + dependencies: List[Dependency] = field(default_factory=list) categories: List[Category] = field(default_factory=list) def add_category( @@ -200,6 +227,12 @@ def add_category( self.categories.append(result) return result + def add_dependency( + self, + dependency: Dependency, + ): + self.dependencies.append(dependency) + def _iter_py_files(directory: str): for root, _, files in os.walk(directory): @@ -271,5 +304,7 @@ def _refresh_nodes(self): registry = PackageRegistry() -def add_package(where: str, name: str, dependencies: List[str]) -> Package: - return registry.add(Package(where, name, dependencies)) +def add_package( + where: str, name: str, description: str, dependencies: List[Dependency] +) -> Package: + return registry.add(Package(where, name, description, dependencies)) diff --git a/backend/src/dependencies/__init__.py b/backend/src/dependencies/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/src/dependencies/install_core_deps.py b/backend/src/dependencies/install_core_deps.py new file mode 100644 index 000000000..2477889de --- /dev/null +++ b/backend/src/dependencies/install_core_deps.py @@ -0,0 +1,54 @@ +from typing import List + +from system import is_arm_mac, is_windows + +from .store import DependencyInfo +from .versioned_dependency_helpers import install_version_checked_dependencies + +# I'm leaving this here in case I can use the Dependency class in the future, so I don't lose the extra info + +# dependencies=[ +# Dependency("OpenCV", "opencv-python", "4.7.0.68", 30 * MB, import_name="cv2"), +# Dependency("NumPy", "numpy", "1.23.2", 15 * MB), +# Dependency("Pillow (PIL)", "Pillow", "9.2.0", 3 * MB, import_name="PIL"), +# Dependency("appdirs", 
"appdirs", "1.4.4", 13.5 * KB), +# Dependency("FFMPEG", "ffmpeg-python", "0.2.0", 25 * KB, import_name="ffmpeg"), +# Dependency("Requests", "requests", "2.28.2", 452 * KB), +# Dependency("re2", "google-re2", "1.0", 275 * KB, import_name="re2"), +# Dependency("scipy", "scipy", "1.9.3", 42 * MB), +# ], + +# if is_arm_mac: +# package.add_dependency(Dependency("Pasteboard", "pasteboard", "0.3.3", 19 * KB)) +# elif is_windows: +# package.add_dependency( +# Dependency("Pywin32", "pywin32", "304", 12 * MB, import_name="win32clipboard") +# ) + +deps: List[DependencyInfo] = [ + { + "package_name": "appdirs", + "version": "1.4.4", + }, + { + "package_name": "ffmpeg-python", + "version": "0.2.0", + }, + { + "package_name": "requests", + "version": "2.28.2", + }, + { + "package_name": "scipy", + "version": "1.9.3", + }, + {"package_name": "pynvml", "version": "11.5.0"}, + {"package_name": "typing_extensions", "version": "4.6.2"}, +] + +if is_arm_mac: + deps.append({"package_name": "pasteboard", "version": "0.3.3"}) +elif is_windows: + deps.append({"package_name": "pywin32", "version": None}) + +install_version_checked_dependencies(deps) diff --git a/backend/src/dependencies/install_essential_deps.py b/backend/src/dependencies/install_essential_deps.py new file mode 100644 index 000000000..b92c7cdb8 --- /dev/null +++ b/backend/src/dependencies/install_essential_deps.py @@ -0,0 +1,15 @@ +from typing import List + +from .store import DependencyInfo, install_dependencies + +# These are the dependencies we _absolutely need_ to install before we can do anything else +deps: List[DependencyInfo] = [ + { + "package_name": "semver", + "version": "3.0.0", + }, +] + + +# Note: We can't be sure we have semver yet so we can't use it to compare versions +install_dependencies(deps) diff --git a/backend/src/dependencies/install_server_deps.py b/backend/src/dependencies/install_server_deps.py new file mode 100644 index 000000000..e330063b8 --- /dev/null +++ b/backend/src/dependencies/install_server_deps.py @@ -0,0 +1,43 @@ +import subprocess +from json import loads as json_parse +from typing import List + +from .store import DependencyInfo, installed_packages, python_path +from .versioned_dependency_helpers import install_version_checked_dependencies + +# Get the list of installed packages +# We can't rely on using the package's __version__ attribute because not all packages actually have it +try: + pip_list = subprocess.check_output( + [python_path, "-m", "pip", "list", "--format=json"] + ) + for p in json_parse(pip_list): + installed_packages[p["name"]] = p["version"] +except Exception as e: + print(f"Failed to get installed packages: {e}") + + +deps: List[DependencyInfo] = [ + { + "package_name": "sanic", + "version": "23.3.0", + }, + { + "package_name": "Sanic-Cors", + "version": "2.2.0", + }, + { + "package_name": "numpy", + "version": "1.23.2", + }, + { + "package_name": "opencv-python", + "version": "4.7.0.68", + }, + { + "package_name": "Pillow", + "version": "9.2.0", + }, +] + +install_version_checked_dependencies(deps) diff --git a/backend/src/dependencies/store.py b/backend/src/dependencies/store.py new file mode 100644 index 000000000..97e2f9466 --- /dev/null +++ b/backend/src/dependencies/store.py @@ -0,0 +1,47 @@ +import subprocess +import sys +from typing import Dict, List, TypedDict, Union + +python_path = sys.executable + +installed_packages: Dict[str, Union[str, None]] = {} + + +class DependencyInfo(TypedDict): + package_name: str + version: Union[str, None] + + +def pin(package_name: str, version: 
Union[str, None]) -> str: + if version is None: + return package_name + return f"{package_name}=={version}" + + +def install_dependencies(dependency_info_array: List[DependencyInfo]): + subprocess.check_call( + [ + python_path, + "-m", + "pip", + "install", + *[ + pin(dep_info["package_name"], dep_info["version"]) + for dep_info in dependency_info_array + ], + "--disable-pip-version-check", + "--no-warn-script-location", + ] + ) + for dep_info in dependency_info_array: + package_name = dep_info["package_name"] + version = dep_info["version"] + installed_packages[package_name] = version + + +__all__ = [ + "DependencyInfo", + "python_path", + "install_dependencies", + "installed_packages", +] diff --git a/backend/src/dependencies/versioned_dependency_helpers.py b/backend/src/dependencies/versioned_dependency_helpers.py new file mode 100644 index 000000000..a480122fb --- /dev/null +++ b/backend/src/dependencies/versioned_dependency_helpers.py @@ -0,0 +1,36 @@ +import re +from typing import List + +from semver.version import Version + +from .store import DependencyInfo, install_dependencies, installed_packages + + +def coerce_semver(version: str) -> Version: + try: + return Version.parse(version, True) + except Exception: + regex = r"(\d+)\.(\d+)\.(\d+)" + match = re.search(regex, version) + if match: + return Version( + int(match.group(1)), + int(match.group(2)), + int(match.group(3)), + ) + return Version(0, 0, 0) + + +def install_version_checked_dependencies(dependencies: List[DependencyInfo]): + dependencies_to_install = [] + for dependency in dependencies: + version = installed_packages.get(dependency["package_name"], None) + if dependency["version"] and version: + installed_version = coerce_semver(version) + dep_version = coerce_semver(dependency["version"]) + if installed_version < dep_version: + dependencies_to_install.append(dependency) + elif not version: + dependencies_to_install.append(dependency) + if len(dependencies_to_install) > 0: + install_dependencies(dependencies_to_install) diff --git a/backend/src/events.py b/backend/src/events.py index 40afaa531..becefe752 100644 --- a/backend/src/events.py +++ b/backend/src/events.py @@ -47,6 +47,11 @@ class IteratorProgressUpdateData(TypedDict): running: Optional[List[NodeId]] +class BackendStatusData(TypedDict): + message: str + percent: float + + class FinishEvent(TypedDict): event: Literal["finish"] data: FinishData @@ -67,11 +72,23 @@ class IteratorProgressUpdateEvent(TypedDict): data: IteratorProgressUpdateData +class BackendStatusEvent(TypedDict): + event: Literal["backend-status"] + data: BackendStatusData + + +class BackendStateEvent(TypedDict): + event: Union[Literal["backend-ready"], Literal["backend-started"]] + data: None + + Event = Union[ FinishEvent, ExecutionErrorEvent, NodeFinishEvent, IteratorProgressUpdateEvent, + BackendStatusEvent, + BackendStateEvent, ] @@ -84,3 +101,14 @@ async def get(self) -> Event: async def put(self, event: Event) -> None: await self.queue.put(event) + + async def wait_until_empty(self, timeout: float) -> None: + while timeout > 0: + if self.queue.empty(): + return + await asyncio.sleep(0.01) + timeout -= 0.01 + + async def put_and_wait(self, event: Event, timeout: float = float("inf")) -> None: + await self.queue.put(event) + await self.wait_until_empty(timeout) diff --git a/backend/src/gpu.py b/backend/src/gpu.py new file mode 100644 index 000000000..aa2929770 --- /dev/null +++ b/backend/src/gpu.py @@ -0,0 +1,49 @@ +import pynvml as nv +from sanic.log import logger + +nvidia_is_available = 
False + +try: + nv.nvmlInit() + nvidia_is_available = True + nv.nvmlShutdown() +except nv.NVMLError as e: + logger.info("No Nvidia GPU found, or invalid driver installed.") +except Exception as e: + logger.info(f"Unknown error occurred when trying to initialize Nvidia GPU: {e}") + + +class NvidiaHelper: + def __init__(self): + nv.nvmlInit() + + self.__num_gpus = nv.nvmlDeviceGetCount() + + self.__gpus = [] + for i in range(self.__num_gpus): + handle = nv.nvmlDeviceGetHandleByIndex(i) + self.__gpus.append( + { + "name": nv.nvmlDeviceGetName(handle), + "uuid": nv.nvmlDeviceGetUUID(handle), + "index": i, + "handle": handle, + } + ) + + def __del__(self): + nv.nvmlShutdown() + + def list_gpus(self): + return self.__gpus + + def get_current_vram_usage(self, gpu_index=0): + info = nv.nvmlDeviceGetMemoryInfo(self.__gpus[gpu_index]["handle"]) + + return info.total, info.used, info.free + + +__all__ = [ + "nvidia_is_available", + "NvidiaHelper", +] diff --git a/backend/src/nodes/impl/pytorch/architecture/SPSR.py b/backend/src/nodes/impl/pytorch/architecture/SPSR.py index 6f5ac458c..c3cefff19 100644 --- a/backend/src/nodes/impl/pytorch/architecture/SPSR.py +++ b/backend/src/nodes/impl/pytorch/architecture/SPSR.py @@ -60,7 +60,6 @@ def __init__( self.out_nc: int = self.state["f_HR_conv1.0.bias"].shape[0] self.scale = self.get_scale(4) - print(self.scale) self.num_filters: int = self.state["model.0.weight"].shape[0] self.supports_fp16 = True diff --git a/backend/src/packages/chaiNNer_external/__init__.py b/backend/src/packages/chaiNNer_external/__init__.py index ef77d6839..55bb0caf6 100644 --- a/backend/src/packages/chaiNNer_external/__init__.py +++ b/backend/src/packages/chaiNNer_external/__init__.py @@ -2,7 +2,12 @@ from api import add_package -package = add_package(__file__, name="chaiNNer_external", dependencies=[]) +package = add_package( + __file__, + name="External", + description="Interact with an external Stable Diffusion API", + dependencies=[], +) external_stable_diffusion_category = package.add_category( name="Stable Diffusion (External)", diff --git a/backend/src/packages/chaiNNer_ncnn/__init__.py b/backend/src/packages/chaiNNer_ncnn/__init__.py index cd2e39d67..f8fd75e18 100644 --- a/backend/src/packages/chaiNNer_ncnn/__init__.py +++ b/backend/src/packages/chaiNNer_ncnn/__init__.py @@ -1,8 +1,23 @@ from sanic.log import logger -from api import add_package +from api import MB, Dependency, add_package +from system import is_mac -package = add_package(__file__, name="chaiNNer_ncnn", dependencies=[]) +package = add_package( + __file__, + name="NCNN", + description="NCNN uses .bin/.param models to upscale images. NCNN uses Vulkan for GPU acceleration, meaning it supports any modern GPU. 
Models can be converted from PyTorch to NCNN.", + dependencies=[ + Dependency( + display_name="NCNN", + pypi_name="ncnn-vulkan", + version="2022.9.12", + size_estimate=7 * MB if is_mac else 4 * MB, + auto_update=True, + import_name="ncnn_vulkan", + ), + ], +) ncnn_category = package.add_category( name="NCNN", diff --git a/backend/src/packages/chaiNNer_onnx/__init__.py b/backend/src/packages/chaiNNer_onnx/__init__.py index b60bd1a78..1b2d94b7b 100644 --- a/backend/src/packages/chaiNNer_onnx/__init__.py +++ b/backend/src/packages/chaiNNer_onnx/__init__.py @@ -1,8 +1,84 @@ from sanic.log import logger -from api import add_package +from api import KB, MB, Dependency, add_package +from gpu import nvidia_is_available +from system import is_arm_mac -package = add_package(__file__, name="chaiNNer_onnx", dependencies=[]) + +def get_onnx_runtime(): + if is_arm_mac: + return Dependency( + display_name="ONNX Runtime (Silicon)", + pypi_name="onnxruntime-silicon", + version="1.13.1", + size_estimate=6 * MB, + import_name="onnxruntime", + ) + elif nvidia_is_available: + return Dependency( + display_name="ONNX Runtime (GPU)", + pypi_name="onnxruntime-gpu", + version="1.13.1", + size_estimate=110 * MB, + import_name="onnxruntime", + ) + else: + return Dependency( + display_name="ONNX Runtime", + pypi_name="onnxruntime", + version="1.13.1", + size_estimate=5 * MB, + ) + + +def get_onnx_optimizer(): + if is_arm_mac: + return [] + + return [ + Dependency( + display_name="ONNX Optimizer", + pypi_name="onnxoptimizer", + version="0.3.6", + size_estimate=300 * KB, + ) + ] + + +package = add_package( + __file__, + name="ONNX", + description="ONNX uses .onnx models to upscale images. It also helps to convert between PyTorch and NCNN. It is fastest when CUDA is supported. If TensorRT is installed on the system, it can also be configured to use that.", + dependencies=[ + Dependency( + display_name="ONNX", + pypi_name="onnx", + version="1.13.0", + size_estimate=12 * MB, + ), + *get_onnx_optimizer(), + get_onnx_runtime(), + Dependency( + display_name="Protobuf", + pypi_name="protobuf", + version="3.20.2", + size_estimate=500 * KB, + ), + Dependency( + display_name="Numba", + pypi_name="numba", + version="0.56.3", + size_estimate=2.5 * MB, + ), + Dependency( + display_name="re2", + pypi_name="google-re2", + version="1.0", + size_estimate=275 * KB, + import_name="re2", + ), + ], +) onnx_category = package.add_category( name="ONNX", diff --git a/backend/src/packages/chaiNNer_pytorch/__init__.py b/backend/src/packages/chaiNNer_pytorch/__init__.py index 316f5dc52..9127f4472 100644 --- a/backend/src/packages/chaiNNer_pytorch/__init__.py +++ b/backend/src/packages/chaiNNer_pytorch/__init__.py @@ -1,8 +1,80 @@ +import sys + from sanic.log import logger -from api import add_package +from api import GB, KB, MB, Dependency, add_package +from gpu import nvidia_is_available + +python_version = sys.version_info + + +def get_pytorch(): + if python_version.minor < 10: + # <= 3.9 + return [ + Dependency( + display_name="PyTorch", + pypi_name="torch", + version="1.10.2+cu113" if nvidia_is_available else "1.10.2", + size_estimate=2 * GB if nvidia_is_available else 140 * MB, + extra_index_url="https://download.pytorch.org/whl/cu113" + if nvidia_is_available + else None, + ), + Dependency( + display_name="TorchVision", + pypi_name="torchvision", + version="0.11.3+cu113" if nvidia_is_available else "0.11.3", + size_estimate=2 * MB if nvidia_is_available else 800 * KB, + extra_index_url="https://download.pytorch.org/whl/cu113" + if 
nvidia_is_available + else None, + ), + ] + else: + # >= 3.10 + return [ + Dependency( + display_name="PyTorch", + pypi_name="torch", + version="1.12.1+cu116" if nvidia_is_available else "1.12.1", + size_estimate=2 * GB if nvidia_is_available else 140 * MB, + extra_index_url="https://download.pytorch.org/whl/cu116" + if nvidia_is_available + else None, + ), + Dependency( + display_name="TorchVision", + pypi_name="torchvision", + version="0.13.1+cu116" if nvidia_is_available else "0.13.1", + size_estimate=2 * MB if nvidia_is_available else 800 * KB, + extra_index_url="https://download.pytorch.org/whl/cu116" + if nvidia_is_available + else None, + ), + ] -package = add_package(__file__, name="chaiNNer_pytorch", dependencies=[]) + +package = add_package( + __file__, + name="PyTorch", + description="PyTorch uses .pth models to upscale images, and is fastest when CUDA is supported (Nvidia GPU). If CUDA is unsupported, it will install with CPU support (which is very slow).", + dependencies=[ + *get_pytorch(), + Dependency( + display_name="FaceXLib", + pypi_name="facexlib", + version="0.2.5", + size_estimate=1.1 * MB, + ), + Dependency( + display_name="Einops", + pypi_name="einops", + version="0.5.0", + size_estimate=36.5 * KB, + ), + ], +) pytorch_category = package.add_category( name="PyTorch", diff --git a/backend/src/packages/chaiNNer_standard/__init__.py b/backend/src/packages/chaiNNer_standard/__init__.py index 33f0584b3..ddc003e00 100644 --- a/backend/src/packages/chaiNNer_standard/__init__.py +++ b/backend/src/packages/chaiNNer_standard/__init__.py @@ -2,7 +2,12 @@ from api import add_package -package = add_package(__file__, name="chaiNNer_standard", dependencies=[]) +package = add_package( + __file__, + name="chaiNNer_standard", + description="The standard set of nodes for chaiNNer.", + dependencies=[], +) image_category = package.add_category( name="Image", diff --git a/backend/src/run.py b/backend/src/run.py index 43fd4fb27..91086f760 100644 --- a/backend/src/run.py +++ b/backend/src/run.py @@ -1,397 +1,11 @@ -import asyncio -import functools -import gc import importlib -import logging -import sys -import traceback -from concurrent.futures import ThreadPoolExecutor -from json import dumps as stringify -from typing import Dict, List, Optional, TypedDict -# pylint: disable-next=unused-import -import cv2 # type: ignore -from sanic import Sanic -from sanic.log import access_logger, logger -from sanic.request import Request -from sanic.response import json -from sanic_cors import CORS +# Install absolutely required dependencies -- aka anything we need to install other dependencies properly (e.g. 
semver) +importlib.import_module("dependencies.install_essential_deps") -import api -from base_types import NodeId -from chain.cache import OutputCache -from chain.json import JsonNode, parse_json -from chain.optimize import optimize -from events import EventQueue, ExecutionErrorData -from nodes.group import Group -from nodes.utils.exec_options import ( - JsonExecutionOptions, - parse_execution_options, - set_execution_options, -) -from process import ( - Executor, - NodeExecutionError, - Output, - compute_broadcast, - run_node, - timed_supplier, -) -from progress_controller import Aborted -from response import ( - alreadyRunningResponse, - errorResponse, - noExecutorResponse, - successResponse, -) +# Install server dependencies (now we can use semver for version checking) +importlib.import_module("dependencies.install_server_deps") - -class AppContext: - def __init__(self): - self.executor: Optional[Executor] = None - self.cache: Dict[NodeId, Output] = dict() - # This will be initialized by setup_queue. - # This is necessary because we don't know Sanic's event loop yet. - self.queue: EventQueue = None # type: ignore - self.pool = ThreadPoolExecutor(max_workers=4) - - @staticmethod - def get(app_instance: Sanic) -> "AppContext": - assert isinstance(app_instance.ctx, AppContext) - return app_instance.ctx - - -app = Sanic("chaiNNer", ctx=AppContext()) -app.config.REQUEST_TIMEOUT = sys.maxsize -app.config.RESPONSE_TIMEOUT = sys.maxsize -CORS(app) - -# Manually import built-in packages to get ordering correct -# Using importlib here so we don't have to ignore that it isn't used -importlib.import_module("packages.chaiNNer_standard") -importlib.import_module("packages.chaiNNer_pytorch") -importlib.import_module("packages.chaiNNer_ncnn") -importlib.import_module("packages.chaiNNer_onnx") -importlib.import_module("packages.chaiNNer_external") - -# in the future, for external packages dir, scan & import -# for package in os.listdir(packages_dir): -# # logger.info(package) -# importlib.import_module(package) - -api.registry.load_nodes(__file__) - - -class SSEFilter(logging.Filter): - def filter(self, record): - return not (record.request.endswith("/sse") and record.status == 200) # type: ignore - - -class ZeroCounter: - def __init__(self) -> None: - self.count = 0 - - async def wait_zero(self) -> None: - while self.count != 0: - await asyncio.sleep(0.01) - - def __enter__(self): - self.count += 1 - - def __exit__(self, _exc_type, _exc_value, _exc_traceback): - self.count -= 1 - - -runIndividualCounter = ZeroCounter() - - -access_logger.addFilter(SSEFilter()) - - -@app.route("/nodes") -async def nodes(_): - """Gets a list of all nodes as well as the node information""" - logger.debug(api.registry.categories) - - node_list = [] - for node, sub in api.registry.nodes.values(): - node_dict = { - "schemaId": node.schema_id, - "name": node.name, - "category": sub.category.name, - "inputs": [x.toDict() for x in node.inputs], - "outputs": [x.toDict() for x in node.outputs], - "groupLayout": [ - g.toDict() if isinstance(g, Group) else g for g in node.group_layout - ], - "description": node.description, - "icon": node.icon, - "subcategory": sub.name, - "nodeType": node.type, - "hasSideEffects": node.side_effects, - "deprecated": node.deprecated, - "defaultNodes": node.default_nodes, - } - node_list.append(node_dict) - - return json( - { - "nodes": node_list, - "categories": [x.toDict() for x in api.registry.categories], - "categoriesMissingNodes": [], - } - ) - - -class RunRequest(TypedDict): - data: 
List[JsonNode] - options: JsonExecutionOptions - sendBroadcastData: bool - - -@app.route("/run", methods=["POST"]) -async def run(request: Request): - """Runs the provided nodes""" - ctx = AppContext.get(request.app) - - if ctx.executor: - message = "Cannot run another executor while the first one is still running." - logger.warning(message) - return json(alreadyRunningResponse(message), status=500) - - try: - # wait until all previews are done - await runIndividualCounter.wait_zero() - - full_data: RunRequest = dict(request.json) # type: ignore - logger.debug(full_data) - chain, inputs = parse_json(full_data["data"]) - optimize(chain) - - logger.info("Running new executor...") - exec_opts = parse_execution_options(full_data["options"]) - set_execution_options(exec_opts) - logger.debug(f"Using device: {exec_opts.full_device}") - executor = Executor( - chain, - inputs, - full_data["sendBroadcastData"], - app.loop, - ctx.queue, - ctx.pool, - parent_cache=OutputCache(static_data=ctx.cache.copy()), - ) - try: - ctx.executor = executor - await executor.run() - except Aborted: - pass - finally: - ctx.executor = None - gc.collect() - - await ctx.queue.put( - {"event": "finish", "data": {"message": "Successfully ran nodes!"}} - ) - return json(successResponse("Successfully ran nodes!"), status=200) - except Exception as exception: - logger.error(exception, exc_info=True) - logger.error(traceback.format_exc()) - - error: ExecutionErrorData = { - "message": "Error running nodes!", - "source": None, - "exception": str(exception), - } - if isinstance(exception, NodeExecutionError): - error["source"] = { - "nodeId": exception.node_id, - "schemaId": exception.node_data.schema_id, - "inputs": exception.inputs, - } - - await ctx.queue.put({"event": "execution-error", "data": error}) - return json(errorResponse("Error running nodes!", exception), status=500) - - -class RunIndividualRequest(TypedDict): - id: NodeId - inputs: List[object] - schemaId: str - options: JsonExecutionOptions - - -@app.route("/run/individual", methods=["POST"]) -async def run_individual(request: Request): - """Runs a single node""" - ctx = AppContext.get(request.app) - try: - full_data: RunIndividualRequest = dict(request.json) # type: ignore - node_id = full_data["id"] - if ctx.cache.get(node_id, None) is not None: - del ctx.cache[node_id] - logger.debug(full_data) - exec_opts = parse_execution_options(full_data["options"]) - set_execution_options(exec_opts) - logger.debug(f"Using device: {exec_opts.full_device}") - # Create node based on given category/name information - node_instance = api.registry.get_node(full_data["schemaId"]) - - with runIndividualCounter: - # Run the node and pass in inputs as args - output, execution_time = await app.loop.run_in_executor( - None, - timed_supplier( - functools.partial( - run_node, node_instance, full_data["inputs"], node_id - ) - ), - ) - # Cache the output of the node - ctx.cache[node_id] = output - - # Broadcast the output from the individual run - node_outputs = node_instance.outputs - if len(node_outputs) > 0: - data, types = compute_broadcast(output, node_outputs) - await ctx.queue.put( - { - "event": "node-finish", - "data": { - "finished": [], - "nodeId": node_id, - "executionTime": execution_time, - "data": data, - "types": types, - "progressPercent": None, - }, - } - ) - gc.collect() - return json({"success": True, "data": None}) - except Exception as exception: - logger.error(exception, exc_info=True) - return json({"success": False, "error": str(exception)}) - - 
-@app.route("/clearcache/individual", methods=["POST"]) -async def clear_cache_individual(request: Request): - ctx = AppContext.get(request.app) - try: - full_data = dict(request.json) # type: ignore - if ctx.cache.get(full_data["id"], None) is not None: - del ctx.cache[full_data["id"]] - return json({"success": True, "data": None}) - except Exception as exception: - logger.error(exception, exc_info=True) - return json({"success": False, "error": str(exception)}) - - -@app.get("/sse") -async def sse(request: Request): - ctx = AppContext.get(request.app) - headers = {"Cache-Control": "no-cache"} - response = await request.respond(headers=headers, content_type="text/event-stream") - while True: - message = await ctx.queue.get() - if response is not None: - await response.send(f"event: {message['event']}\n") - await response.send(f"data: {stringify(message['data'])}\n\n") - - -@app.after_server_start -async def setup_queue(sanic_app: Sanic, _): - AppContext.get(sanic_app).queue = EventQueue() - - -@app.route("/pause", methods=["POST"]) -async def pause(request: Request): - """Pauses the current execution""" - ctx = AppContext.get(request.app) - - if not ctx.executor: - message = "No executor to pause" - logger.warning(message) - return json(noExecutorResponse(message), status=400) - - try: - logger.info("Executor found. Attempting to pause...") - ctx.executor.pause() - return json(successResponse("Successfully paused execution!"), status=200) - except Exception as exception: - logger.log(2, exception, exc_info=True) - return json(errorResponse("Error pausing execution!", exception), status=500) - - -@app.route("/resume", methods=["POST"]) -async def resume(request: Request): - """Pauses the current execution""" - ctx = AppContext.get(request.app) - - if not ctx.executor: - message = "No executor to resume" - logger.warning(message) - return json(noExecutorResponse(message), status=400) - - try: - logger.info("Executor found. Attempting to resume...") - ctx.executor.resume() - return json(successResponse("Successfully resumed execution!"), status=200) - except Exception as exception: - logger.log(2, exception, exc_info=True) - return json(errorResponse("Error resuming execution!", exception), status=500) - - -@app.route("/kill", methods=["POST"]) -async def kill(request: Request): - """Kills the current execution""" - ctx = AppContext.get(request.app) - - if not ctx.executor: - message = "No executor to kill" - logger.warning("No executor to kill") - return json(noExecutorResponse(message), status=400) - - try: - logger.info("Executor found. 
Attempting to kill...") - ctx.executor.kill() - while ctx.executor: - await asyncio.sleep(0.0001) - return json(successResponse("Successfully killed execution!"), status=200) - except Exception as exception: - logger.log(2, exception, exc_info=True) - return json(errorResponse("Error killing execution!", exception), status=500) - - -@app.route("/listgpus/ncnn", methods=["GET"]) -async def list_ncnn_gpus(_request: Request): - """Lists the available GPUs for NCNN""" - try: - # pylint: disable=import-outside-toplevel - from ncnn_vulkan import ncnn - - result = [] - for i in range(ncnn.get_gpu_count()): - result.append(ncnn.get_gpu_info(i).device_name()) - return json(result) - except Exception as exception: - logger.error(exception, exc_info=True) - return json([]) - - -@app.route("/python-info", methods=["GET"]) -async def python_info(_request: Request): - version = ( - f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" - ) - return json({"python": sys.executable, "version": version}) - - -if __name__ == "__main__": - try: - port = int(sys.argv[1]) or 8000 - except: - port = 8000 - - if sys.argv[1] != "--no-run": - app.run(port=port, single_process=True) +# Now we can start the server since we have sanic installed +main_server = importlib.import_module("server") +main_server.main() diff --git a/backend/src/server.py b/backend/src/server.py new file mode 100644 index 000000000..677b97ca9 --- /dev/null +++ b/backend/src/server.py @@ -0,0 +1,544 @@ +import asyncio +import functools +import gc +import importlib +import logging +import sys +import traceback +from concurrent.futures import ThreadPoolExecutor +from json import dumps as stringify +from typing import Dict, List, Optional, TypedDict + +from sanic import Sanic +from sanic.log import access_logger, logger +from sanic.request import Request +from sanic.response import json +from sanic_cors import CORS + +import api +from base_types import NodeId +from chain.cache import OutputCache +from chain.json import JsonNode, parse_json +from chain.optimize import optimize +from dependencies.store import DependencyInfo, installed_packages +from dependencies.versioned_dependency_helpers import ( + install_version_checked_dependencies, +) +from events import EventQueue, ExecutionErrorData +from nodes.group import Group +from nodes.utils.exec_options import ( + JsonExecutionOptions, + parse_execution_options, + set_execution_options, +) +from process import ( + Executor, + NodeExecutionError, + Output, + compute_broadcast, + run_node, + timed_supplier, +) +from progress_controller import Aborted +from response import ( + alreadyRunningResponse, + errorResponse, + noExecutorResponse, + successResponse, +) + + +class AppContext: + def __init__(self): + self.executor: Optional[Executor] = None + self.cache: Dict[NodeId, Output] = dict() + # This will be initialized by after_server_start. + # This is necessary because we don't know Sanic's event loop yet. 
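+        # Roughly, the after_server_start listener at the bottom of this file runs:
+        #     AppContext.get(sanic_app).queue = EventQueue()
+        #     AppContext.get(sanic_app).setup_queue = EventQueue()
+        # once Sanic's loop exists, so routes may assume both queues are set by
+        # the time requests arrive.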
+ self.queue: EventQueue = None # type: ignore + self.setup_queue: EventQueue = None # type: ignore + self.pool = ThreadPoolExecutor(max_workers=4) + self.registry = api.registry + + @staticmethod + def get(app_instance: Sanic) -> "AppContext": + assert isinstance(app_instance.ctx, AppContext) + return app_instance.ctx + + +app = Sanic("chaiNNer", ctx=AppContext()) +app.config.REQUEST_TIMEOUT = sys.maxsize +app.config.RESPONSE_TIMEOUT = sys.maxsize +CORS(app) + + +class SSEFilter(logging.Filter): + def filter(self, record): + return not ((record.request.endswith("/sse") or record.request.endswith("/setup-sse")) and record.status == 200) # type: ignore + + +class ZeroCounter: + def __init__(self) -> None: + self.count = 0 + + async def wait_zero(self) -> None: + while self.count != 0: + await asyncio.sleep(0.01) + + def __enter__(self): + self.count += 1 + + def __exit__(self, _exc_type, _exc_value, _exc_traceback): + self.count -= 1 + + +runIndividualCounter = ZeroCounter() + +setup_task = None + + +async def nodes_available(): + if setup_task is not None: + await setup_task + + +access_logger.addFilter(SSEFilter()) + + +@app.route("/nodes") +async def nodes(_request: Request): + """Gets a list of all nodes as well as the node information""" + await nodes_available() + logger.debug(api.registry.categories) + + node_list = [] + for node, sub in api.registry.nodes.values(): + node_dict = { + "schemaId": node.schema_id, + "name": node.name, + "category": sub.category.name, + "inputs": [x.toDict() for x in node.inputs], + "outputs": [x.toDict() for x in node.outputs], + "groupLayout": [ + g.toDict() if isinstance(g, Group) else g for g in node.group_layout + ], + "description": node.description, + "icon": node.icon, + "subcategory": sub.name, + "nodeType": node.type, + "hasSideEffects": node.side_effects, + "deprecated": node.deprecated, + "defaultNodes": node.default_nodes, + } + node_list.append(node_dict) + + return json( + { + "nodes": node_list, + "categories": [x.toDict() for x in api.registry.categories], + "categoriesMissingNodes": [], + } + ) + + +class RunRequest(TypedDict): + data: List[JsonNode] + options: JsonExecutionOptions + sendBroadcastData: bool + + +@app.route("/run", methods=["POST"]) +async def run(request: Request): + """Runs the provided nodes""" + await nodes_available() + ctx = AppContext.get(request.app) + + if ctx.executor: + message = "Cannot run another executor while the first one is still running." 
+ logger.warning(message) + return json(alreadyRunningResponse(message), status=500) + + try: + # wait until all previews are done + await runIndividualCounter.wait_zero() + + full_data: RunRequest = dict(request.json) # type: ignore + logger.debug(full_data) + chain, inputs = parse_json(full_data["data"]) + optimize(chain) + + logger.info("Running new executor...") + exec_opts = parse_execution_options(full_data["options"]) + set_execution_options(exec_opts) + logger.debug(f"Using device: {exec_opts.full_device}") + executor = Executor( + chain, + inputs, + full_data["sendBroadcastData"], + app.loop, + ctx.queue, + ctx.pool, + parent_cache=OutputCache(static_data=ctx.cache.copy()), + ) + try: + ctx.executor = executor + await executor.run() + except Aborted: + pass + finally: + ctx.executor = None + gc.collect() + + await ctx.queue.put( + {"event": "finish", "data": {"message": "Successfully ran nodes!"}} + ) + return json(successResponse("Successfully ran nodes!"), status=200) + except Exception as exception: + logger.error(exception, exc_info=True) + logger.error(traceback.format_exc()) + + error: ExecutionErrorData = { + "message": "Error running nodes!", + "source": None, + "exception": str(exception), + } + if isinstance(exception, NodeExecutionError): + error["source"] = { + "nodeId": exception.node_id, + "schemaId": exception.node_data.schema_id, + "inputs": exception.inputs, + } + + await ctx.queue.put({"event": "execution-error", "data": error}) + return json(errorResponse("Error running nodes!", exception), status=500) + + +class RunIndividualRequest(TypedDict): + id: NodeId + inputs: List[object] + schemaId: str + options: JsonExecutionOptions + + +@app.route("/run/individual", methods=["POST"]) +async def run_individual(request: Request): + """Runs a single node""" + await nodes_available() + ctx = AppContext.get(request.app) + try: + full_data: RunIndividualRequest = dict(request.json) # type: ignore + node_id = full_data["id"] + if ctx.cache.get(node_id, None) is not None: + del ctx.cache[node_id] + logger.debug(full_data) + exec_opts = parse_execution_options(full_data["options"]) + set_execution_options(exec_opts) + logger.debug(f"Using device: {exec_opts.full_device}") + # Create node based on given category/name information + node_instance = api.registry.get_node(full_data["schemaId"]) + + with runIndividualCounter: + # Run the node and pass in inputs as args + output, execution_time = await app.loop.run_in_executor( + None, + timed_supplier( + functools.partial( + run_node, node_instance, full_data["inputs"], node_id + ) + ), + ) + # Cache the output of the node + ctx.cache[node_id] = output + + # Broadcast the output from the individual run + node_outputs = node_instance.outputs + if len(node_outputs) > 0: + data, types = compute_broadcast(output, node_outputs) + await ctx.queue.put( + { + "event": "node-finish", + "data": { + "finished": [], + "nodeId": node_id, + "executionTime": execution_time, + "data": data, + "types": types, + "progressPercent": None, + }, + } + ) + gc.collect() + return json({"success": True, "data": None}) + except Exception as exception: + logger.error(exception, exc_info=True) + return json({"success": False, "error": str(exception)}) + + +@app.route("/clearcache/individual", methods=["POST"]) +async def clear_cache_individual(request: Request): + await nodes_available() + ctx = AppContext.get(request.app) + try: + full_data = dict(request.json) # type: ignore + if ctx.cache.get(full_data["id"], None) is not None: + del 
ctx.cache[full_data["id"]]
+        return json({"success": True, "data": None})
+    except Exception as exception:
+        logger.error(exception, exc_info=True)
+        return json({"success": False, "error": str(exception)})
+
+
+@app.route("/pause", methods=["POST"])
+async def pause(request: Request):
+    """Pauses the current execution"""
+    await nodes_available()
+    ctx = AppContext.get(request.app)
+
+    if not ctx.executor:
+        message = "No executor to pause"
+        logger.warning(message)
+        return json(noExecutorResponse(message), status=400)
+
+    try:
+        logger.info("Executor found. Attempting to pause...")
+        ctx.executor.pause()
+        return json(successResponse("Successfully paused execution!"), status=200)
+    except Exception as exception:
+        logger.log(2, exception, exc_info=True)
+        return json(errorResponse("Error pausing execution!", exception), status=500)
+
+
+@app.route("/resume", methods=["POST"])
+async def resume(request: Request):
+    """Resumes the current execution"""
+    await nodes_available()
+    ctx = AppContext.get(request.app)
+
+    if not ctx.executor:
+        message = "No executor to resume"
+        logger.warning(message)
+        return json(noExecutorResponse(message), status=400)
+
+    try:
+        logger.info("Executor found. Attempting to resume...")
+        ctx.executor.resume()
+        return json(successResponse("Successfully resumed execution!"), status=200)
+    except Exception as exception:
+        logger.log(2, exception, exc_info=True)
+        return json(errorResponse("Error resuming execution!", exception), status=500)
+
+
+@app.route("/kill", methods=["POST"])
+async def kill(request: Request):
+    """Kills the current execution"""
+    await nodes_available()
+    ctx = AppContext.get(request.app)
+
+    if not ctx.executor:
+        message = "No executor to kill"
+        logger.warning(message)
+        return json(noExecutorResponse(message), status=400)
+
+    try:
+        logger.info("Executor found. 
Attempting to kill...") + ctx.executor.kill() + while ctx.executor: + await asyncio.sleep(0.0001) + return json(successResponse("Successfully killed execution!"), status=200) + except Exception as exception: + logger.log(2, exception, exc_info=True) + return json(errorResponse("Error killing execution!", exception), status=500) + + +@app.route("/listgpus/ncnn", methods=["GET"]) +async def list_ncnn_gpus(_request: Request): + """Lists the available GPUs for NCNN""" + await nodes_available() + try: + # pylint: disable=import-outside-toplevel + from ncnn_vulkan import ncnn + + result = [] + for i in range(ncnn.get_gpu_count()): + result.append(ncnn.get_gpu_info(i).device_name()) + return json(result) + except Exception as exception: + logger.error(exception, exc_info=True) + return json([]) + + +@app.route("/python-info", methods=["GET"]) +async def python_info(_request: Request): + version = ( + f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" + ) + return json({"python": sys.executable, "version": version}) + + +@app.route("/dependencies", methods=["GET"]) +async def get_dependencies(_request: Request): + await nodes_available() + all_dependencies = [] + for package in api.registry.packages.values(): + pkg_dependencies = [x.toDict() for x in package.dependencies] + if package.name == "chaiNNer_standard": + continue + else: + all_dependencies.append( + { + "name": package.name, + "dependencies": pkg_dependencies, + "description": package.description, + } + ) + return json(all_dependencies) + + +def import_packages(): + def install_deps(dependencies: List[api.Dependency]): + try: + dep_info: List[DependencyInfo] = [ + { + "package_name": dep.pypi_name, + "version": dep.version, + } + for dep in dependencies + ] + install_version_checked_dependencies(dep_info) + except Exception as ex: + logger.error(f"Error installing dependencies: {ex}") + + # Manually import built-in packages to get ordering correct + # Using importlib here so we don't have to ignore that it isn't used + importlib.import_module("packages.chaiNNer_standard") + importlib.import_module("packages.chaiNNer_pytorch") + importlib.import_module("packages.chaiNNer_ncnn") + importlib.import_module("packages.chaiNNer_onnx") + importlib.import_module("packages.chaiNNer_external") + + logger.info("Checking dependencies...") + + # For these, do the same as the above, but only if auto_update is true + for package in api.registry.packages.values(): + logger.info(f"Checking dependencies for {package.name}...") + if package.name == "chaiNNer_standard": + continue + auto_update_deps = [ + dep + for dep in package.dependencies + if dep.auto_update + and installed_packages.get(dep.pypi_name, None) is not None + ] + if len(auto_update_deps) > 0: + install_deps(auto_update_deps) + + logger.info("Done checking dependencies...") + + # TODO: in the future, for external packages dir, scan & import + # for package in os.listdir(packages_dir): + # importlib.import_module(package) + + api.registry.load_nodes(__file__) + + +@app.get("/sse") +async def sse(request: Request): + ctx = AppContext.get(request.app) + headers = {"Cache-Control": "no-cache"} + response = await request.respond(headers=headers, content_type="text/event-stream") + while True: + message = await ctx.queue.get() + if response is not None: + await response.send(f"event: {message['event']}\n") + await response.send(f"data: {stringify(message['data'])}\n\n") + + +@app.get("/setup-sse") +async def setup_sse(request: Request): + ctx = 
AppContext.get(request.app) + headers = {"Cache-Control": "no-cache"} + response = await request.respond(headers=headers, content_type="text/event-stream") + while True: + message = await ctx.setup_queue.get() + if response is not None: + await response.send(f"event: {message['event']}\n") + await response.send(f"data: {stringify(message['data'])}\n\n") + + +async def setup(sanic_app: Sanic): + logger.info("Starting setup...") + await AppContext.get(sanic_app).setup_queue.put_and_wait( + { + "event": "backend-started", + "data": None, + }, + timeout=1, + ) + + await AppContext.get(sanic_app).setup_queue.put_and_wait( + { + "event": "backend-status", + "data": {"message": "Installing dependencies...", "percent": 0.0}, + }, + timeout=1, + ) + + # Now we can install the other dependencies + importlib.import_module("dependencies.install_core_deps") + + await AppContext.get(sanic_app).setup_queue.put_and_wait( + { + "event": "backend-status", + "data": {"message": "Loading Nodes...", "percent": 0.75}, + }, + timeout=1, + ) + + logger.info("Loading nodes...") + + # Now we can load all the nodes + # TODO: Pass in a callback func for updating progress + import_packages() + + logger.info("Sending backend ready...") + + await AppContext.get(sanic_app).setup_queue.put_and_wait( + { + "event": "backend-status", + "data": {"message": "Loading Nodes...", "percent": 1}, + }, + timeout=1, + ) + + await AppContext.get(sanic_app).setup_queue.put_and_wait( + { + "event": "backend-ready", + "data": None, + }, + timeout=1, + ) + + logger.info("Done.") + + +@app.after_server_start +async def after_server_start(sanic_app: Sanic, loop: asyncio.AbstractEventLoop): + # pylint: disable=global-statement + global setup_task + AppContext.get(sanic_app).queue = EventQueue() + AppContext.get(sanic_app).setup_queue = EventQueue() + setup_task = loop.create_task(setup(sanic_app)) + + +def main(): + try: + port = int(sys.argv[1]) or 8000 + except: + port = 8000 + print(sys.argv) + if len(sys.argv) > 1 and sys.argv[1] == "--no-run": + sys.exit() + app.run(port=port, single_process=True) + + +if __name__ == "__main__": + main() diff --git a/backend/src/system.py b/backend/src/system.py new file mode 100644 index 000000000..a0ee609a9 --- /dev/null +++ b/backend/src/system.py @@ -0,0 +1,7 @@ +import platform +import sys + +is_mac = sys.platform == "darwin" +is_arm_mac = is_mac and platform.machine() == "arm64" +is_windows = sys.platform == "win32" +is_linux = sys.platform == "linux" diff --git a/requirements.txt b/requirements.txt index b5c7fcc94..ea2c51597 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,6 +23,8 @@ pywin32 ; sys_platform=="win32" requests sanic==23.3.0 sanic_cors==2.2.0 +semver==3.0.0 scipy==1.9.3 +pynvml==11.5.0 # torch # torchvision diff --git a/scripts/install-required-deps.ts b/scripts/install-required-deps.ts deleted file mode 100644 index 97c28dc00..000000000 --- a/scripts/install-required-deps.ts +++ /dev/null @@ -1,41 +0,0 @@ -/* eslint-disable no-console */ -import { spawn } from 'child_process'; -import { requiredDependencies } from '../src/common/dependencies'; -import { sanitizedEnv } from '../src/common/env'; - -try { - const command = spawn( - 'python', - [ - '-m', - 'pip', - 'install', - ...requiredDependencies - .map((d) => d.packages.map((p) => `${p.packageName}==${p.version}`)) - .flat(), - '--disable-pip-version-check', - ], - { - env: sanitizedEnv, - } - ); - - command.stdout.on('data', (data: unknown) => { - console.log(String(data)); - }); - - command.stderr.on('data', (data: 
unknown) => { - console.error(String(data)); - }); - - command.on('error', (error) => { - console.error(error); - process.exit(1); - }); - - command.on('close', (code) => { - process.exit(code ?? 1); - }); -} catch (error) { - console.error(String(error)); -} diff --git a/src/common/Backend.ts b/src/common/Backend.ts index 184b35bbe..fee806215 100644 --- a/src/common/Backend.ts +++ b/src/common/Backend.ts @@ -9,6 +9,7 @@ import { PythonInfo, SchemaId, } from './common-types'; +import { Package } from './dependencies'; export interface BackendSuccessResponse { type: 'success'; @@ -173,6 +174,10 @@ export class Backend { pythonInfo(): Promise { return this.fetchJson('/python-info', 'GET'); } + + dependencies(): Promise { + return this.fetchJson('/dependencies', 'GET'); + } } const backendCache = new Map(); @@ -225,4 +230,9 @@ export interface BackendEventMap { iteratorId: string; running?: string[] | null; }; + 'backend-status': { + message: string; + percent: number; + }; + 'backend-ready': null; } diff --git a/src/common/dependencies.ts b/src/common/dependencies.ts index 5dd247e5c..311630bf7 100644 --- a/src/common/dependencies.ts +++ b/src/common/dependencies.ts @@ -1,213 +1,18 @@ -import semver from 'semver'; import { Version } from './common-types'; -import { isArmMac, isMac, isWindows } from './env'; - -const KB = 1024 ** 1; -const MB = 1024 ** 2; -const GB = 1024 ** 3; export interface PyPiPackage { - packageName: string; + displayName: string; + pypiName: string; version: Version; - findLink?: string; + findLink?: string | null; /** * A size estimate (in bytes) for the whl file to download. */ sizeEstimate: number; - autoUpdate?: boolean; + autoUpdate: boolean; } -export interface Dependency { +export interface Package { name: string; - packages: PyPiPackage[]; + dependencies: PyPiPackage[]; description?: string; } - -const getOnnxRuntime = (canCuda: boolean): PyPiPackage => { - if (isArmMac) { - return { - packageName: 'onnxruntime-silicon', - sizeEstimate: 6 * MB, - version: '1.13.1', - }; - } - if (canCuda) { - return { - packageName: 'onnxruntime-gpu', - sizeEstimate: 110 * MB, - version: '1.13.1', - }; - } - return { - packageName: 'onnxruntime', - sizeEstimate: 5 * MB, - version: '1.13.1', - }; -}; - -const getPytorch = (canCuda: boolean, pythonVersion: Version): PyPiPackage[] => { - if (semver.gte(pythonVersion, '3.8.0') && semver.lt(pythonVersion, '3.10.0')) { - return [ - { - packageName: 'torch', - version: `1.10.2${canCuda ? '+cu113' : ''}`, - findLink: canCuda ? 'https://download.pytorch.org/whl/cu113' : undefined, - sizeEstimate: canCuda ? 2 * GB : 140 * MB, - }, - { - packageName: 'torchvision', - version: `0.11.3${canCuda ? '+cu113' : ''}`, - findLink: canCuda ? 'https://download.pytorch.org/whl/cu113' : undefined, - sizeEstimate: canCuda ? 2 * MB : 800 * KB, - }, - ]; - } - if (semver.gte(pythonVersion, '3.10.0')) { - return [ - { - packageName: 'torch', - version: `1.12.1${canCuda ? '+cu116' : ''}`, - findLink: canCuda ? 'https://download.pytorch.org/whl/cu116' : undefined, - sizeEstimate: canCuda ? 2 * GB : 140 * MB, - }, - { - packageName: 'torchvision', - version: `0.13.1${canCuda ? '+cu116' : ''}`, - findLink: canCuda ? 'https://download.pytorch.org/whl/cu116' : undefined, - sizeEstimate: canCuda ? 
2 * MB : 800 * KB, - }, - ]; - } - throw new Error('Unsupported Python version'); -}; - -export const getOptionalDependencies = ( - isNvidiaAvailable: boolean, - pythonVersion: Version -): Dependency[] => { - const canCuda = isNvidiaAvailable && !isMac; - return [ - { - name: 'PyTorch', - packages: [ - ...getPytorch(canCuda, pythonVersion), - { - packageName: 'facexlib', - version: '0.2.5', - sizeEstimate: 1.1 * MB, - }, - { - packageName: 'einops', - version: '0.5.0', - sizeEstimate: 36.5 * KB, - }, - ], - description: - 'PyTorch uses .pth models to upscale images, and is fastest when CUDA is supported (Nvidia GPU). If CUDA is unsupported, it will install with CPU support (which is very slow).', - }, - { - name: 'NCNN', - packages: [ - { - packageName: 'ncnn-vulkan', - version: '2022.9.12', - sizeEstimate: isMac ? 7 * MB : 4 * MB, - autoUpdate: true, - }, - ], - description: - 'NCNN uses .bin/.param models to upscale images. NCNN uses Vulkan for GPU acceleration, meaning it supports any modern GPU. Models can be converted from PyTorch to NCNN.', - }, - { - name: 'ONNX', - packages: [ - { - packageName: 'onnx', - version: '1.13.0', - sizeEstimate: 12 * MB, - }, - ...(!isArmMac - ? ([ - { - packageName: 'onnxoptimizer', - version: '0.3.6', - sizeEstimate: 300 * KB, - }, - ] as PyPiPackage[]) - : []), - getOnnxRuntime(canCuda), - { - packageName: 'protobuf', - version: '3.20.2', - sizeEstimate: 500 * KB, - }, - { - packageName: 'scipy', - version: '1.9.3', - sizeEstimate: 42 * MB, - }, - { - packageName: 'numba', - version: '0.56.3', - sizeEstimate: 2.5 * MB, - }, - ], - description: - 'ONNX uses .onnx models to upscale images. It also helps to convert between PyTorch and NCNN. It is fastest when CUDA is supported. If TensorRT is installed on the system, it can also be configured to use that.', - }, - ]; -}; - -export const requiredDependencies: Dependency[] = [ - { - name: 'Sanic', - packages: [{ packageName: 'sanic', version: '23.3.0', sizeEstimate: 200 * KB }], - }, - { - name: 'Sanic Cors', - packages: [{ packageName: 'Sanic-Cors', version: '2.2.0', sizeEstimate: 17 * KB }], - }, - { - name: 'OpenCV', - packages: [{ packageName: 'opencv-python', version: '4.7.0.68', sizeEstimate: 30 * MB }], - }, - { - name: 'NumPy', - packages: [{ packageName: 'numpy', version: '1.23.2', sizeEstimate: 15 * MB }], - }, - { - name: 'Pillow (PIL)', - packages: [{ packageName: 'Pillow', version: '9.2.0', sizeEstimate: 3 * MB }], - }, - { - name: 'appdirs', - packages: [{ packageName: 'appdirs', version: '1.4.4', sizeEstimate: 13.5 * KB }], - }, - { - name: 'FFMPEG', - packages: [{ packageName: 'ffmpeg-python', version: '0.2.0', sizeEstimate: 25 * KB }], - }, - { - name: 'Requests', - packages: [{ packageName: 'requests', version: '2.28.2', sizeEstimate: 452 * KB }], - }, - { - name: 're2', - packages: [{ packageName: 'google-re2', version: '1.0.0', sizeEstimate: 275 * KB }], - }, - { - name: 'scipy', - packages: [{ packageName: 'scipy', version: '1.9.3', sizeEstimate: 42 * MB }], - }, -]; - -if (isMac && !isArmMac) { - requiredDependencies.push({ - name: 'Pasteboard', - packages: [{ packageName: 'pasteboard', version: '0.3.3', sizeEstimate: 19 * KB }], - }); -} else if (isWindows) { - requiredDependencies.push({ - name: 'Pywin32', - packages: [{ packageName: 'pywin32', version: '304' as Version, sizeEstimate: 12 * MB }], - }); -} diff --git a/src/common/locales/en/translation.json b/src/common/locales/en/translation.json index 477fde186..ccf03f620 100644 --- a/src/common/locales/en/translation.json +++ 
b/src/common/locales/en/translation.json @@ -57,7 +57,6 @@ } }, "splash": { - "checkingDeps": "Checking dependencies...", "checkingFfmpeg": "Checking system environment for Ffmpeg...", "checkingPort": "Checking for available port...", "checkingPython": "Checking system environment for valid Python...", @@ -65,11 +64,9 @@ "downloadingPython": "Downloading Integrated Python...", "extractingFfmpeg": "Extracting downloaded files...", "extractingPython": "Extracting downloaded files...", - "installingDeps": "Installing required dependencies...", "loading": "Loading...", "loadingApp": "Loading main application...", - "startingBackend": "Starting up backend process...", - "updatingDeps": "Updating dependencies..." + "startingBackend": "Starting up backend process..." }, "typeTags": { "optional": "optional" diff --git a/src/common/pip.ts b/src/common/pip.ts index 35ff723bf..39f9ca78e 100644 --- a/src/common/pip.ts +++ b/src/common/pip.ts @@ -1,7 +1,7 @@ /* eslint-disable no-param-reassign */ import { spawn } from 'child_process'; import { PythonInfo, Version } from './common-types'; -import { Dependency } from './dependencies'; +import { Package } from './dependencies'; import { sanitizedEnv } from './env'; import { log } from './log'; import { pipInstallWithProgress } from './pipInstallWithProgress'; @@ -72,10 +72,10 @@ export const runPipList = async (info: PythonInfo, onStdio?: OnStdio): Promise
<PipList> => {
     return Object.fromEntries(list.map((e) =>
[e.name, e.version])); }; -const getFindLinks = (dependencies: readonly Dependency[]): string[] => { +const getFindLinks = (dependencies: readonly Package[]): string[] => { const links = new Set(); for (const d of dependencies) { - for (const p of d.packages) { + for (const p of d.dependencies) { if (p.findLink) { links.add(p.findLink); } @@ -86,7 +86,7 @@ const getFindLinks = (dependencies: readonly Dependency[]): string[] => { export const runPipInstall = async ( info: PythonInfo, - dependencies: readonly Dependency[], + dependencies: readonly Package[], onProgress?: (percentage: number) => void, onStdio?: OnStdio ): Promise => { @@ -94,7 +94,7 @@ export const runPipInstall = async ( if (onProgress === undefined) { // TODO: implement progress via this method (if possible) const deps = dependencies - .map((d) => d.packages.map((p) => `${p.packageName}==${p.version}`)) + .map((d) => d.dependencies.map((p) => `${p.pypiName}==${p.version}`)) .flat(); const findLinks = getFindLinks(dependencies).flatMap((l) => ['--extra-index-url', l]); @@ -102,7 +102,7 @@ export const runPipInstall = async ( } else { const { python } = info; for (const dep of dependencies) { - for (const pkg of dep.packages) { + for (const pkg of dep.dependencies) { // eslint-disable-next-line no-await-in-loop await pipInstallWithProgress(python, pkg, onProgress, onStdio); } @@ -113,12 +113,12 @@ export const runPipInstall = async ( export const runPipUninstall = async ( info: PythonInfo, - dependencies: readonly Dependency[], + dependencies: readonly Package[], onProgress?: (percentage: number) => void, onStdio?: OnStdio ): Promise => { onProgress?.(10); - const deps = dependencies.map((d) => d.packages.map((p) => p.packageName)).flat(); + const deps = dependencies.map((d) => d.dependencies.map((p) => p.pypiName)).flat(); onProgress?.(25); await runPip(info, ['uninstall', '-y', ...deps], onStdio); onProgress?.(100); diff --git a/src/common/pipInstallWithProgress.ts b/src/common/pipInstallWithProgress.ts index 792c7f20f..3093cfbc4 100644 --- a/src/common/pipInstallWithProgress.ts +++ b/src/common/pipInstallWithProgress.ts @@ -91,7 +91,7 @@ export const pipInstallWithProgress = async ( let args = [ 'install', '--upgrade', - `${pkg.packageName}==${pkg.version}`, + `${pkg.pypiName}==${pkg.version}`, '--disable-pip-version-check', ]; if (pkg.findLink) { @@ -123,44 +123,38 @@ export const pipInstallWithProgress = async ( resolve(); }); } else { - const req = https.get( - `https://pypi.org/pypi/${pkg.packageName}/json`, - (res) => { - let output = ''; + const req = https.get(`https://pypi.org/pypi/${pkg.pypiName}/json`, (res) => { + let output = ''; - res.on('data', (d) => { - output += String(d); - }); + res.on('data', (d) => { + output += String(d); + }); - res.on('close', () => { - if (output) { - const releaseData = JSON.parse(output) as { - releases: Record< - string, - { filename: string; url: string }[] - >; - }; - const releases = Array.from(releaseData.releases[pkg.version]); - const find = releases.find( - (file) => file.filename === wheelFileName + res.on('close', () => { + if (output) { + const releaseData = JSON.parse(output) as { + releases: Record; + }; + const releases = Array.from(releaseData.releases[pkg.version]); + const find = releases.find( + (file) => file.filename === wheelFileName + ); + if (!find) + throw new Error( + `Unable for find correct file for ${pkg.pypiName}==${pkg.version}` ); - if (!find) - throw new Error( - `Unable for find correct file for ${pkg.packageName}==${pkg.version}` - ); - const { 
diff --git a/src/common/safeIpc.ts b/src/common/safeIpc.ts
index 264caca85..cad07198f 100644
--- a/src/common/safeIpc.ts
+++ b/src/common/safeIpc.ts
@@ -56,6 +56,7 @@ export interface InvokeChannels {
 export interface SendChannels {
     'splash-setup-progress': SendChannelInfo<[progress: Progress]>;
     'backend-ready': SendChannelInfo;
+    'backend-started': SendChannelInfo;
     'file-new': SendChannelInfo;
     'file-open': SendChannelInfo<[FileOpenResult]>;
     'file-save-as': SendChannelInfo;
diff --git a/src/main/backend/process.ts b/src/main/backend/process.ts
index 636fc966a..a8224566a 100644
--- a/src/main/backend/process.ts
+++ b/src/main/backend/process.ts
@@ -6,7 +6,7 @@ import { getBackend } from '../../common/Backend';
 import { PythonInfo } from '../../common/common-types';
 import { sanitizedEnv } from '../../common/env';
 import { log } from '../../common/log';
-import { lazy } from '../../common/util';
+import { delay, lazy } from '../../common/util';
 
 const getBackendPath = lazy((): string => {
     const candidates: string[] = [
@@ -186,7 +186,26 @@ export class BorrowedBackendProcess implements BaseBackendProcess {
     static async fromPort(port: number): Promise<BorrowedBackendProcess> {
         const backend = getBackend(port);
-        const python = await backend.pythonInfo();
+        let python: PythonInfo | undefined;
+        // try a few times to get python info, in case backend is still starting up
+        const maxTries = 50;
+        const startSleep = 1;
+        const maxSleep = 250;
+
+        for (let i = 0; i < maxTries; i += 1) {
+            try {
+                // eslint-disable-next-line no-await-in-loop
+                python = await backend.pythonInfo();
+                break;
+            } catch {
+                // exponential backoff, capped at maxSleep
+                // eslint-disable-next-line no-await-in-loop
+                await delay(Math.min(maxSleep, startSleep * 2 ** i));
+            }
+        }
+        if (!python) {
+            throw new Error('Unable to get python info from backend');
+        }
         return new BorrowedBackendProcess(port, python);
     }
 }
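The retry loop above amounts to a capped exponential backoff: the loop stops as soon as `pythonInfo()` succeeds, and `Math.min` keeps each sleep at or below `maxSleep`. For clarity, the same pattern as a standalone helper; this is a sketch only, and `delay` is assumed to be the promise-based sleep exported from common/util:

// Sketch of the backoff used by BorrowedBackendProcess.fromPort above.
const delay = (ms: number) =>
    new Promise<void>((resolve) => {
        setTimeout(resolve, ms);
    });

const retryWithBackoff = async <T>(
    fn: () => Promise<T>,
    maxTries = 50,
    startSleep = 1, // ms
    maxSleep = 250 // ms
): Promise<T> => {
    for (let i = 0; ; i += 1) {
        try {
            return await fn();
        } catch (error) {
            if (i + 1 >= maxTries) throw error;
            // 1ms, 2ms, 4ms, ... doubling each attempt, never more than maxSleep
            await delay(Math.min(maxSleep, startSleep * 2 ** i));
        }
    }
};

// e.g. const python = await retryWithBackoff(() => backend.pythonInfo());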
diff --git a/src/main/backend/setup.ts b/src/main/backend/setup.ts
index 3eeb2dac8..5d98976a0 100644
--- a/src/main/backend/setup.ts
+++ b/src/main/backend/setup.ts
@@ -2,16 +2,9 @@ import { t } from 'i18next';
 import path from 'path';
 import portfinder from 'portfinder';
 import { FfmpegInfo, PythonInfo } from '../../common/common-types';
-import {
-    Dependency,
-    getOptionalDependencies,
-    requiredDependencies,
-} from '../../common/dependencies';
 import { log } from '../../common/log';
-import { runPipInstall, runPipList } from '../../common/pip';
 import { CriticalError } from '../../common/ui/error';
 import { ProgressToken } from '../../common/ui/progress';
-import { versionGt } from '../../common/version';
 import { getIntegratedFfmpeg, hasSystemFfmpeg } from '../ffmpeg/ffmpeg';
 import { checkPythonPaths } from '../python/checkPythonPaths';
 import { getIntegratedPython } from '../python/integratedPython';
@@ -170,78 +163,6 @@ const getFfmpegInfo = async (token: ProgressToken, rootDir: string) => {
     return ffmpegInfo;
 };
 
-const ensurePythonDeps = async (
-    token: ProgressToken,
-    pythonInfo: PythonInfo,
-    hasNvidia: boolean
-) => {
-    log.info('Attempting to check Python deps...');
-
-    try {
-        const pipList = await runPipList(pythonInfo);
-        const installedPackages = new Set(Object.keys(pipList));
-
-        const requiredPackages = requiredDependencies.flatMap((dep) => dep.packages);
-        const optionalPackages = getOptionalDependencies(hasNvidia, pythonInfo.version).flatMap(
-            (dep) => dep.packages
-        );
-
-        // CASE 1: A package isn't installed
-        const missingRequiredPackages = requiredPackages.filter(
-            (packageInfo) => !installedPackages.has(packageInfo.packageName)
-        );
-
-        // CASE 2: A required package is installed but not the latest version
-        const outOfDateRequiredPackages = requiredPackages.filter((packageInfo) => {
-            const installedVersion = pipList[packageInfo.packageName];
-            if (!installedVersion) {
-                return false;
-            }
-            return versionGt(packageInfo.version, installedVersion);
-        });
-
-        // CASE 3: An optional package is installed, set to auto update, and is not the latest version
-        const outOfDateOptionalPackages = optionalPackages.filter((packageInfo) => {
-            const installedVersion = pipList[packageInfo.packageName];
-            if (!installedVersion) {
-                return false;
-            }
-            return packageInfo.autoUpdate && versionGt(packageInfo.version, installedVersion);
-        });
-
-        const allPackagesThatNeedToBeInstalled = [
-            ...missingRequiredPackages,
-            ...outOfDateRequiredPackages,
-            ...outOfDateOptionalPackages,
-        ];
-
-        if (allPackagesThatNeedToBeInstalled.length > 0) {
-            const isInstallingRequired = missingRequiredPackages.length > 0;
-            const isUpdating =
-                outOfDateRequiredPackages.length > 0 || outOfDateOptionalPackages.length > 0;
-
-            const onlyUpdating = isUpdating && !isInstallingRequired;
-            token.submitProgress({
-                status: onlyUpdating
-                    ? t('splash.updatingDeps', 'Updating dependencies...')
-                    : t('splash.installingDeps', 'Installing required dependencies...'),
-                totalProgress: 0.7,
-            });
-
-            // Try to update/install deps
-            log.info('Installing/Updating dependencies...');
-            await runPipInstall(pythonInfo, [
-                {
-                    name: 'All Packages That Need To Be Installed',
-                    packages: allPackagesThatNeedToBeInstalled,
-                },
-            ] as Dependency[]);
-        }
-    } catch (error) {
-        log.error(error);
-    }
-};
-
 const spawnBackend = (port: number, pythonInfo: PythonInfo, ffmpegInfo: FfmpegInfo) => {
     try {
         const backend = OwnedBackendProcess.spawn(port, pythonInfo, {
@@ -281,12 +202,6 @@ const setupOwnedBackend = async (
     });
     const ffmpegInfo = await getFfmpegInfo(token, rootDir);
 
-    token.submitProgress({
-        status: t('splash.checkingDeps', 'Checking dependencies...'),
-        totalProgress: 0.6,
-    });
-    await ensurePythonDeps(token, pythonInfo, await hasNvidia());
-
     token.submitProgress({
         status: t('splash.startingBackend', 'Starting up backend process...'),
         totalProgress: 0.8,
diff --git a/src/main/gui/main-window.ts b/src/main/gui/main-window.ts
index d391ac109..f284d89ec 100644
--- a/src/main/gui/main-window.ts
+++ b/src/main/gui/main-window.ts
@@ -1,5 +1,6 @@
 import { ChildProcessWithoutNullStreams } from 'child_process';
 import { BrowserWindow, app, dialog, nativeTheme, powerSaveBlocker, shell } from 'electron';
+import EventSource from 'eventsource';
 import { t } from 'i18next';
 import { Version, WindowSize } from '../../common/common-types';
 import { log } from '../../common/log';
@@ -342,19 +343,37 @@ export const createMainWindow = async (args: OpenArguments) => {
     try {
         registerEventHandlerPreSetup(mainWindow, args);
-        const backend = await createBackend(SubProgress.slice(progressController, 0, 0.9), args);
+        const backend = await createBackend(SubProgress.slice(progressController, 0, 0.5), args);
         registerEventHandlerPostSetup(mainWindow, backend);
 
-        progressController.submitProgress({
-            status: t('splash.loadingApp', 'Loading main application...'),
+        const sse = new EventSource(`http://127.0.0.1:${backend.port}/setup-sse`, {
+            withCredentials: true,
         });
 
-        if (mainWindow.isDestroyed()) {
-            return;
-        }
+        sse.addEventListener('backend-started', () => {
+            mainWindow.webContents.send('backend-started');
+        });
 
-        ipcMain.once('backend-ready', () => {
-            progressController.submitProgress({ totalProgress: 1 });
+        sse.onerror = (e) => {
+            log.error(e);
+        };
+
+        const backendStatusProgressSlice = SubProgress.slice(progressController, 0.5, 0.95);
+        sse.addEventListener('backend-status', (e: MessageEvent) => {
+            if (e.data) {
+                const data = JSON.parse(e.data) as { message: string; percent: number };
+                backendStatusProgressSlice.submitProgress({
+                    status: data.message,
+                    totalProgress: data.percent,
+                });
+            }
+        });
+
+        sse.addEventListener('backend-ready', () => {
+            progressController.submitProgress({
+                totalProgress: 1,
+                status: t('splash.loadingApp', 'Loading main application...'),
+            });
 
             if (mainWindow.isDestroyed()) {
                 dialog.showMessageBoxSync({
@@ -376,6 +395,10 @@ export const createMainWindow = async (args: OpenArguments) => {
             }
         });
 
+        if (mainWindow.isDestroyed()) {
+            return;
+        }
+
         // and load the index.html of the app.
         mainWindow.loadURL(MAIN_WINDOW_WEBPACK_ENTRY).catch(log.error);
     } catch (error) {
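For reference, the splash screen now receives its progress from the backend over server-sent events instead of computing it locally. The event names and the `backend-status` payload shape below are taken from this diff; the type is only a sketch of that contract, not a real declaration:

// Sketch of the /setup-sse contract consumed above.
type SetupSseEvent =
    // the HTTP server is up, but nodes/deps are still loading
    | { event: 'backend-started'; data?: unknown }
    // granular progress, mapped into the 0.5..0.95 slice of splash progress
    | { event: 'backend-status'; data: { message: string; percent: number } }
    // everything is loaded; the splash can hand off to the main window
    | { event: 'backend-ready'; data?: unknown };

// The slices line up end to end: 0..0.5 is consumed by createBackend,
// 0.5..0.95 by backend-status events, and 1 is submitted on backend-ready.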
diff --git a/src/renderer/contexts/DependencyContext.tsx b/src/renderer/contexts/DependencyContext.tsx
index 60b51e708..13f27aa1e 100644
--- a/src/renderer/contexts/DependencyContext.tsx
+++ b/src/renderer/contexts/DependencyContext.tsx
@@ -34,7 +34,7 @@ import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
 import { BsQuestionCircle, BsTerminalFill } from 'react-icons/bs';
 import { createContext, useContext } from 'use-context-selector';
 import { Version } from '../../common/common-types';
-import { Dependency, PyPiPackage, getOptionalDependencies } from '../../common/dependencies';
+import { Package, PyPiPackage } from '../../common/dependencies';
 import { Integration, externalIntegrations } from '../../common/externalIntegrations';
 import { log } from '../../common/log';
 import { OnStdio, PipList, runPipInstall, runPipList, runPipUninstall } from '../../common/pip';
@@ -95,7 +95,7 @@ const FeaturePackage = memo(
         return (
-            {pkg.packageName}
+            {pkg.pypiName}
             {!!tagText && {tagText}}
             {versionString}
@@ -123,7 +123,7 @@ const Feature = memo(
         onUninstall,
         onUpdate,
     }: {
-        dep: Dependency;
+        dep: Package;
         pipList: PipList;
         isRunningShell: boolean;
         progress?: number;
@@ -131,9 +131,9 @@
         onUninstall: () => void;
         onUpdate: () => void;
     }) => {
-        const missingPackages = dep.packages.filter((p) => !pipList[p.packageName]);
-        const outdatedPackages = dep.packages.filter((p) => {
-            const installedVersion = pipList[p.packageName];
+        const missingPackages = dep.dependencies.filter((p) => !pipList[p.pypiName]);
+        const outdatedPackages = dep.dependencies.filter((p) => {
+            const installedVersion = pipList[p.pypiName];
             return installedVersion && versionGt(p.version, installedVersion);
         });
@@ -157,8 +157,8 @@
                     textAlign="left"
                     w="full"
                 >
-                    {dep.name} ({dep.packages.length} package
-                    {dep.packages.length === 1 ? '' : 's'})
+                    {dep.name} ({dep.dependencies.length} package
+                    {dep.dependencies.length === 1 ? '' : 's'})
-                {dep.packages.map((p) => (
+                {dep.dependencies.map((p) => (
                 ))}
@@ -275,7 +275,7 @@ export const DependencyProvider = memo(({ children }: React.PropsWithChildren<unknown>) => {
-    const [availableDeps, setAvailableDeps] = useState<Dependency[]>([]);
+    const [availableDeps, setAvailableDeps] = useState<Package[]>([]);
     useAsyncEffect(
         () => ({
             supplier: async () => {
-                const nvidiaGpu = await ipcRenderer.invoke('get-nvidia-gpu-name');
-                const isNvidiaAvailable = nvidiaGpu !== null;
-                return getOptionalDependencies(isNvidiaAvailable, pythonInfo.version);
+                const res = await backend.dependencies();
+                return res.filter((d) => d.dependencies.length > 0);
             },
             successEffect: setAvailableDeps,
         }),
-        [pythonInfo.version]
+        [backend]
     );
 
-    const [installingPackage, setInstallingPackage] = useState<Dependency | null>(null);
-    const [uninstallingPackage, setUninstallingPackage] = useState<Dependency | null>(null);
+    const [installingPackage, setInstallingPackage] = useState<Package | null>(null);
+    const [uninstallingPackage, setUninstallingPackage] = useState<Package | null>(null);
 
     const consoleRef = useRef(null);
     const [shellOutput, setShellOutput] = useState('');
@@ -352,14 +351,14 @@
 
-    const installPackage = (dep: Dependency) => {
+    const installPackage = (dep: Package) => {
         setInstallingPackage(dep);
         changePackages(() =>
             runPipInstall(pythonInfo, [dep], usePipDirectly ? undefined : setProgress, onStdio)
         );
     };
 
-    const uninstallPackage = (dep: Dependency) => {
+    const uninstallPackage = (dep: Package) => {
         setUninstallingPackage(dep);
         changePackages(() =>
             runPipUninstall(pythonInfo, [dep], usePipDirectly ? undefined : setProgress, onStdio)
@@ -377,12 +376,12 @@
-        return availableDeps.filter(({ packages }) =>
-            packages.some(({ packageName, version }) => {
+        return availableDeps.filter(({ dependencies }) =>
+            dependencies.some(({ pypiName, version }) => {
                 if (!pipList) {
                     return false;
                 }
-                const installedVersion = pipList[packageName];
+                const installedVersion = pipList[pypiName];
                 if (!installedVersion) {
                     return true;
                 }
diff --git a/src/renderer/main.tsx b/src/renderer/main.tsx
index eeee755cd..5b77405ba 100644
--- a/src/renderer/main.tsx
+++ b/src/renderer/main.tsx
@@ -143,12 +143,16 @@ export const Main = memo(({ port }: MainProps) => {
                 return prev;
             });
         }
+    }, [response, data, loading, error, backendReady, sendAlert, t]);
 
+    useIpcRendererListener('backend-ready', () => {
+        // Refresh the nodes once the backend is ready
+        setNodesRefreshCounter((prev) => prev + 1);
         if (!backendReady) {
             setBackendReady(true);
             ipcRenderer.send('backend-ready');
         }
-    }, [response, data, loading, error, backendReady, sendAlert, t]);
+    });
 
     useLastWindowSize();
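One behavioral detail worth spelling out: a feature appears in availableUpdates when any of its PyPI packages is either missing from pip's list or older than the version the backend advertises. Condensed, the per-package predicate looks like the sketch below; `versionGt` is the existing semver helper (true when its first argument is newer), and the function name here is hypothetical:

// Condensed sketch of the per-package check behind availableUpdates.
declare const versionGt: (a: string, b: string) => boolean; // existing helper
type PipList = Readonly<Record<string, string>>;

const packageNeedsAttention = (pipList: PipList, pypiName: string, version: string): boolean => {
    const installed = pipList[pypiName];
    if (!installed) return true; // not installed at all
    return versionGt(version, installed); // installed, but behind the advertised version
};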