From b85d01d5e39b9e9cbb65b0cc6eaaab391e373314 Mon Sep 17 00:00:00 2001
From: Varshith B
Date: Tue, 4 Feb 2025 17:29:17 +0530
Subject: [PATCH] fix: update prompts

---
 server/app.py             |   2 +-
 src/comfystream/client.py | 180 ++++++++++++++++++++------------------
 2 files changed, 95 insertions(+), 87 deletions(-)

diff --git a/server/app.py b/server/app.py
index 170acaa8..04ac04e7 100644
--- a/server/app.py
+++ b/server/app.py
@@ -155,7 +155,7 @@ async def on_message(message):
                     if "prompts" not in params:
                         logger.warning("[Control] Missing prompt in update_prompt message")
                         return
-                    pipeline.set_prompts(params["prompts"])
+                    await pipeline.update_prompts(params["prompts"])
                     response = {
                         "type": "prompts_updated",
                         "success": True
diff --git a/src/comfystream/client.py b/src/comfystream/client.py
index dff26bc1..40019cae 100644
--- a/src/comfystream/client.py
+++ b/src/comfystream/client.py
@@ -17,24 +17,27 @@ def __init__(self, max_workers: int = 1, **kwargs):
         config = Configuration(**kwargs)
         # TODO: Need to handle cleanup for EmbeddedComfyClient if not using async context manager?
         self.comfy_client = EmbeddedComfyClient(config, max_workers=max_workers)
-        self.running_prompts = []
+        self.running_prompts = {}  # To be used for cancelling tasks
+        self.current_prompts = []
 
     async def set_prompts(self, prompts: List[PromptDictInput]):
-        await self.cancel_running_tasks()
-        for prompt in [convert_prompt(prompt) for prompt in prompts]:
-            task = asyncio.create_task(self.run_prompt(prompt))
-            self.running_prompts.append({"task": task, "prompt": prompt})
+        self.current_prompts = [convert_prompt(prompt) for prompt in prompts]
+        for idx in range(len(self.current_prompts)):
+            task = asyncio.create_task(self.run_prompt(idx))
+            self.running_prompts[idx] = task
 
-    async def cancel_running_tasks(self):
-        while self.running_prompts:
-            task = self.running_prompts.pop()
-            task["task"].cancel()
-            await task["task"]
+    async def update_prompts(self, prompts: List[PromptDictInput]):
+        # TODO: currently under the assumption that only already running prompts are updated
+        if len(prompts) != len(self.current_prompts):
+            raise ValueError(
+                "Number of updated prompts must match the number of currently running prompts."
+            )
+        self.current_prompts = [convert_prompt(prompt) for prompt in prompts]
 
-    async def run_prompt(self, prompt: PromptDictInput):
+    async def run_prompt(self, prompt_index: int):
         while True:
             try:
-                await self.comfy_client.queue_prompt(prompt)
+                await self.comfy_client.queue_prompt(self.current_prompts[prompt_index])
             except Exception as e:
                 logger.error(f"Error running prompt: {str(e)}")
                 logger.error(f"Error type: {type(e)}")
@@ -61,87 +64,92 @@ async def get_available_nodes(self):
         try:
             from comfy.nodes.package import import_all_nodes_in_workspace
             nodes = import_all_nodes_in_workspace()
+
+            all_prompts_nodes_info = {}
 
-            # Get set of class types we need metadata for, excluding LoadTensor and SaveTensor
-            needed_class_types = {
-                node.get('class_type')
-                for node in self.prompt.values()
-                if node.get('class_type') not in ('LoadTensor', 'SaveTensor')
-            }
-            remaining_nodes = {
-                node_id
-                for node_id, node in self.prompt.items()
-                if node.get('class_type') not in ('LoadTensor', 'SaveTensor')
-            }
-            nodes_info = {}
-
-            # Only process nodes until we've found all the ones we need
-            for class_type, node_class in nodes.NODE_CLASS_MAPPINGS.items():
-                if not remaining_nodes:  # Exit early if we've found all needed nodes
-                    break
-
-                if class_type not in needed_class_types:
-                    continue
-
-                # Get metadata for this node type (same as original get_node_metadata)
-                input_data = node_class.INPUT_TYPES() if hasattr(node_class, 'INPUT_TYPES') else {}
-                input_info = {}
-
-                # Process required inputs
-                if 'required' in input_data:
-                    for name, value in input_data['required'].items():
-                        if isinstance(value, tuple) and len(value) == 2:
-                            input_type, config = value
-                            input_info[name] = {
-                                'type': input_type,
-                                'required': True,
-                                'min': config.get('min', None),
-                                'max': config.get('max', None),
-                                'widget': config.get('widget', None)
-                            }
-                        else:
-                            logger.error(f"Unexpected structure for required input {name}: {value}")
-
-                # Process optional inputs
-                if 'optional' in input_data:
-                    for name, value in input_data['optional'].items():
-                        if isinstance(value, tuple) and len(value) == 2:
-                            input_type, config = value
-                            input_info[name] = {
-                                'type': input_type,
-                                'required': False,
-                                'min': config.get('min', None),
-                                'max': config.get('max', None),
-                                'widget': config.get('widget', None)
-                            }
-                        else:
-                            logger.error(f"Unexpected structure for optional input {name}: {value}")
+            for prompt_index, prompt in enumerate(self.current_prompts):
+                # Get set of class types we need metadata for, excluding LoadTensor and SaveTensor
+                needed_class_types = {
+                    node.get('class_type')
+                    for node in prompt.values()
+                    if node.get('class_type') not in ('LoadTensor', 'SaveTensor')
+                }
+                remaining_nodes = {
+                    node_id
+                    for node_id, node in prompt.items()
+                    if node.get('class_type') not in ('LoadTensor', 'SaveTensor')
+                }
+                nodes_info = {}
 
-                # Now process any nodes in our prompt that use this class_type
-                for node_id in list(remaining_nodes):
-                    node = self.prompt[node_id]
-                    if node.get('class_type') != class_type:
+                # Only process nodes until we've found all the ones we need
+                for class_type, node_class in nodes.NODE_CLASS_MAPPINGS.items():
+                    if not remaining_nodes:  # Exit early if we've found all needed nodes
+                        break
+
+                    if class_type not in needed_class_types:
                         continue
 
-                    node_info = {
-                        'class_type': class_type,
-                        'inputs': {}
-                    }
+                    # Get metadata for this node type (same as original get_node_metadata)
+                    input_data = node_class.INPUT_TYPES() if hasattr(node_class, 'INPUT_TYPES') else {}
+                    input_info = {}
+
+                    # Process required inputs
+                    if 'required' in input_data:
+                        for name, value in input_data['required'].items():
+                            if isinstance(value, tuple) and len(value) == 2:
+                                input_type, config = value
+                                input_info[name] = {
+                                    'type': input_type,
+                                    'required': True,
+                                    'min': config.get('min', None),
+                                    'max': config.get('max', None),
+                                    'widget': config.get('widget', None)
+                                }
+                            else:
+                                logger.error(f"Unexpected structure for required input {name}: {value}")
 
-                    if 'inputs' in node:
-                        for input_name, input_value in node['inputs'].items():
-                            node_info['inputs'][input_name] = {
-                                'value': input_value,
-                                'type': input_info.get(input_name, {}).get('type', 'unknown'),
-                                'min': input_info.get(input_name, {}).get('min', None),
-                                'max': input_info.get(input_name, {}).get('max', None),
-                                'widget': input_info.get(input_name, {}).get('widget', None)
-                            }
+                    # Process optional inputs
+                    if 'optional' in input_data:
+                        for name, value in input_data['optional'].items():
+                            if isinstance(value, tuple) and len(value) == 2:
+                                input_type, config = value
+                                input_info[name] = {
+                                    'type': input_type,
+                                    'required': False,
+                                    'min': config.get('min', None),
+                                    'max': config.get('max', None),
+                                    'widget': config.get('widget', None)
+                                }
+                            else:
+                                logger.error(f"Unexpected structure for optional input {name}: {value}")
 
-                    nodes_info[node_id] = node_info
-                    remaining_nodes.remove(node_id)
+                    # Now process any nodes in our prompt that use this class_type
+                    for node_id in list(remaining_nodes):
+                        node = prompt[node_id]
+                        if node.get('class_type') != class_type:
+                            continue
+
+                        node_info = {
+                            'class_type': class_type,
+                            'inputs': {}
+                        }
+
+                        if 'inputs' in node:
+                            for input_name, input_value in node['inputs'].items():
+                                node_info['inputs'][input_name] = {
+                                    'value': input_value,
+                                    'type': input_info.get(input_name, {}).get('type', 'unknown'),
+                                    'min': input_info.get(input_name, {}).get('min', None),
+                                    'max': input_info.get(input_name, {}).get('max', None),
+                                    'widget': input_info.get(input_name, {}).get('widget', None)
+                                }
+
+                        nodes_info[node_id] = node_info
+                        remaining_nodes.remove(node_id)
+
+                all_prompts_nodes_info[prompt_index] = nodes_info
 
-            return nodes_info
+            return all_prompts_nodes_info
         except Exception as e:
             logger.error(f"Error getting node info: {str(e)}")
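
Reviewer note (illustrative, not part of the diff): the new flow is that
set_prompts() converts each incoming prompt and spawns one long-running
run_prompt() task per prompt index, while update_prompts() only swaps
self.current_prompts in place, so each already-running loop picks up its
updated prompt on the next queue_prompt() iteration; the server's control
handler accordingly awaits pipeline.update_prompts() instead of cancelling
and respawning tasks. A minimal usage sketch, assuming the class defined in
client.py is ComfyStreamClient and that real ComfyUI workflow dicts are
supplied (the placeholder workflows below are hypothetical):

    import asyncio
    from comfystream.client import ComfyStreamClient

    async def main():
        workflow_a = {}  # placeholder; a real ComfyUI prompt dict goes here
        workflow_b = {}  # placeholder; count must match the running prompts

        client = ComfyStreamClient(max_workers=1)

        # Spawns one run_prompt() loop per prompt, keyed by index in
        # running_prompts so it can be cancelled later.
        await client.set_prompts([workflow_a])

        # Swaps current_prompts in place; raises ValueError if the number of
        # prompts differs from the number currently running.
        await client.update_prompts([workflow_b])

    asyncio.run(main())

Because asyncio tasks run on a single thread, reassigning self.current_prompts
is effectively atomic with respect to the run_prompt() loops, so the swap
needs no locking.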