From 96697c4bc529bd4c2db586ce12d90a86dbed073f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zolt=C3=A1n=20D=C3=B3cs?= Date: Sat, 28 Dec 2024 11:30:04 +0100 Subject: [PATCH] serve workflow templates from custom_nodes (#6193) * add GET /workflow_templates * serve workflow templates from custom_nodes * refactor into custom_node_manager, add test * remove unused import * revert changes in folder_paths * Remove trailing whitespace. * account for multiple custom_nodes paths --- app/custom_node_manager.py | 34 ++++++++++++++++ nodes.py | 5 +++ server.py | 4 ++ .../app_test/custom_node_manager_test.py | 40 +++++++++++++++++++ 4 files changed, 83 insertions(+) create mode 100644 app/custom_node_manager.py create mode 100644 tests-unit/app_test/custom_node_manager_test.py diff --git a/app/custom_node_manager.py b/app/custom_node_manager.py new file mode 100644 index 00000000000..32f15aa99e1 --- /dev/null +++ b/app/custom_node_manager.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import os +import folder_paths +import glob +from aiohttp import web + +class CustomNodeManager: + """ + Placeholder to refactor the custom node management features from ComfyUI-Manager. + Currently it only contains the custom workflow templates feature. + """ + def add_routes(self, routes, webapp, loadedModules): + + @routes.get("/workflow_templates") + async def get_workflow_templates(request): + """Returns a web response that contains the map of custom_nodes names and their associated workflow templates. The ones without templates are omitted.""" + files = [ + file + for folder in folder_paths.get_folder_paths("custom_nodes") + for file in glob.glob(os.path.join(folder, '*/example_workflows/*.json')) + ] + workflow_templates_dict = {} # custom_nodes folder name -> example workflow names + for file in files: + custom_nodes_name = os.path.basename(os.path.dirname(os.path.dirname(file))) + workflow_name = os.path.splitext(os.path.basename(file))[0] + workflow_templates_dict.setdefault(custom_nodes_name, []).append(workflow_name) + return web.json_response(workflow_templates_dict) + + # Serve workflow templates from custom nodes. + for module_name, module_dir in loadedModules: + workflows_dir = os.path.join(module_dir, 'example_workflows') + if os.path.exists(workflows_dir): + webapp.add_routes([web.static('/api/workflow_templates/' + module_name, workflows_dir)]) \ No newline at end of file diff --git a/nodes.py b/nodes.py index 89cecc48061..513d8a25693 100644 --- a/nodes.py +++ b/nodes.py @@ -2047,6 +2047,9 @@ def expand_image(self, image, left, top, right, bottom, feathering): EXTENSION_WEB_DIRS = {} +# Dictionary of successfully loaded module names and associated directories. +LOADED_MODULE_DIRS = {} + def get_module_name(module_path: str) -> str: """ @@ -2088,6 +2091,8 @@ def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes sys.modules[module_name] = module module_spec.loader.exec_module(module) + LOADED_MODULE_DIRS[module_name] = os.path.abspath(module_dir) + if hasattr(module, "WEB_DIRECTORY") and getattr(module, "WEB_DIRECTORY") is not None: web_dir = os.path.abspath(os.path.join(module_dir, getattr(module, "WEB_DIRECTORY"))) if os.path.isdir(web_dir): diff --git a/server.py b/server.py index 8fbfaa89eb5..2dc53b9d430 100644 --- a/server.py +++ b/server.py @@ -30,6 +30,7 @@ from app.frontend_management import FrontendManager from app.user_manager import UserManager from app.model_manager import ModelFileManager +from app.custom_node_manager import CustomNodeManager from typing import Optional from api_server.routes.internal.internal_routes import InternalRoutes @@ -153,6 +154,7 @@ def __init__(self, loop): self.user_manager = UserManager() self.model_file_manager = ModelFileManager() + self.custom_node_manager = CustomNodeManager() self.internal_routes = InternalRoutes(self) self.supports = ["custom_nodes_from_web"] self.prompt_queue = None @@ -697,6 +699,7 @@ async def setup(self): def add_routes(self): self.user_manager.add_routes(self.routes) self.model_file_manager.add_routes(self.routes) + self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items()) self.app.add_subapp('/internal', self.internal_routes.get_app()) # Prefix every route with /api for easier matching for delegation. @@ -713,6 +716,7 @@ def add_routes(self): self.app.add_routes(api_routes) self.app.add_routes(self.routes) + # Add routes from web extensions. for name, dir in nodes.EXTENSION_WEB_DIRS.items(): self.app.add_routes([web.static('/extensions/' + name, dir)]) diff --git a/tests-unit/app_test/custom_node_manager_test.py b/tests-unit/app_test/custom_node_manager_test.py new file mode 100644 index 00000000000..89598de8492 --- /dev/null +++ b/tests-unit/app_test/custom_node_manager_test.py @@ -0,0 +1,40 @@ +import pytest +from aiohttp import web +from unittest.mock import patch +from app.custom_node_manager import CustomNodeManager + +pytestmark = ( + pytest.mark.asyncio +) # This applies the asyncio mark to all test functions in the module + +@pytest.fixture +def custom_node_manager(): + return CustomNodeManager() + +@pytest.fixture +def app(custom_node_manager): + app = web.Application() + routes = web.RouteTableDef() + custom_node_manager.add_routes(routes, app, [("ComfyUI-TestExtension1", "ComfyUI-TestExtension1")]) + app.add_routes(routes) + return app + +async def test_get_workflow_templates(aiohttp_client, app, tmp_path): + client = await aiohttp_client(app) + # Setup temporary custom nodes file structure with 1 workflow file + custom_nodes_dir = tmp_path / "custom_nodes" + example_workflows_dir = custom_nodes_dir / "ComfyUI-TestExtension1" / "example_workflows" + example_workflows_dir.mkdir(parents=True) + template_file = example_workflows_dir / "workflow1.json" + template_file.write_text('') + + with patch('folder_paths.folder_names_and_paths', { + 'custom_nodes': ([str(custom_nodes_dir)], None) + }): + response = await client.get('/workflow_templates') + assert response.status == 200 + workflows_dict = await response.json() + assert isinstance(workflows_dict, dict) + assert "ComfyUI-TestExtension1" in workflows_dict + assert isinstance(workflows_dict["ComfyUI-TestExtension1"], list) + assert workflows_dict["ComfyUI-TestExtension1"][0] == "workflow1"