diff --git a/js/share_load_controlnet.js b/js/share_load_controlnet.js new file mode 100644 index 00000000..4f6ec5b7 --- /dev/null +++ b/js/share_load_controlnet.js @@ -0,0 +1,43 @@ +import { api } from "../../../scripts/api.js"; +import { app } from "../../scripts/app.js"; +app.registerExtension({ + name: "bizyair.siliconcloud.share.controlnet.loader", + async beforeRegisterNodeDef(nodeType, nodeData, app) { + if (nodeData.name === "BizyAir_SharedControlNetLoader") { + async function onTextChange(share_id, canvas, comfynode) { + console.log("share_id:", share_id); + const response = await api.fetchApi(`/bizyair/modelhost/${share_id}/models/files?type=bizyair/controlnet`, { + method: "GET", + headers: { + "Content-Type": "application/json", + }, + }); + + const { data: controlnets } = await response.json(); + const controlnet_name_widget = comfynode.widgets.find(widget => widget.name === "control_net_name"); + if (controlnets.length > 0) { + controlnet_name_widget.value = controlnets[0]; + controlnet_name_widget.options.values = controlnets; + } else { + console.log("No controlnets found in the response"); + controlnet_name_widget.value = ""; + controlnet_name_widget.options.values = []; + } + } + + function setWigetCallback(){ + const shareid_widget = this.widgets.find(widget => widget.name === "share_id"); + if (shareid_widget) { + shareid_widget.callback = onTextChange; + } else { + console.log("share_id widget not found"); + } + } + const onNodeCreated = nodeType.prototype.onNodeCreated + nodeType.prototype.onNodeCreated = function () { + onNodeCreated?.apply(this, arguments); + setWigetCallback.call(this, arguments); + }; + } + }, +}) diff --git a/nodes.py b/nodes.py index 971d5ec6..a489cc91 100644 --- a/nodes.py +++ b/nodes.py @@ -990,3 +990,28 @@ def INPUT_TYPES(s): # FUNCTION = "set_range" CATEGORY = "advanced/conditioning" + + +class SharedControlNetLoader(BizyAir_ControlNetLoader): + @classmethod + def INPUT_TYPES(s): + ret = super().INPUT_TYPES() + ret["required"]["share_id"] = ("STRING", {"default": "share_id"}) + return ret + + NODE_DISPLAY_NAME = "Shared Load ControlNet Model" + + @classmethod + def VALIDATE_INPUTS(cls, share_id: str, control_net_name: str): + if control_net_name in folder_paths.filename_path_mapping.get("controlnet", {}): + return True + + outs = folder_paths.get_share_filename_list("controlnet", share_id=share_id) + if control_net_name not in outs: + raise ValueError( + f"ControlNet {control_net_name} not found in share {share_id} with {outs}" + ) + return True + + def load_controlnet(self, control_net_name, share_id, **kwargs): + return super().load_controlnet(control_net_name=control_net_name, **kwargs) diff --git a/src/bizyair/path_utils/path_manager.py b/src/bizyair/path_utils/path_manager.py index a636e322..b0b78ce8 100644 --- a/src/bizyair/path_utils/path_manager.py +++ b/src/bizyair/path_utils/path_manager.py @@ -28,6 +28,7 @@ @dataclass class RefreshSettings: loras: bool = True + controlnet: bool = True def get(self, folder_name: str, default: bool = True): return getattr(self, folder_name, default)