Skip to content

Commit

Permalink
fix: update prompts
Browse files Browse the repository at this point in the history
  • Loading branch information
varshith15 committed Feb 4, 2025
1 parent aa209f0 commit b85d01d
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 87 deletions.
2 changes: 1 addition & 1 deletion server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ async def on_message(message):
if "prompts" not in params:
logger.warning("[Control] Missing prompt in update_prompt message")
return
pipeline.set_prompts(params["prompts"])
await pipeline.update_prompts(params["prompts"])
response = {
"type": "prompts_updated",
"success": True
Expand Down
180 changes: 94 additions & 86 deletions src/comfystream/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,27 @@ def __init__(self, max_workers: int = 1, **kwargs):
config = Configuration(**kwargs)
# TODO: Need to handle cleanup for EmbeddedComfyClient if not using async context manager?
self.comfy_client = EmbeddedComfyClient(config, max_workers=max_workers)
self.running_prompts = []
self.running_prompts = {} # To be used for cancelling tasks
self.current_prompts = []

async def set_prompts(self, prompts: List[PromptDictInput]):
await self.cancel_running_tasks()
for prompt in [convert_prompt(prompt) for prompt in prompts]:
task = asyncio.create_task(self.run_prompt(prompt))
self.running_prompts.append({"task": task, "prompt": prompt})
self.current_prompts = [convert_prompt(prompt) for prompt in prompts]
for idx in range(self.current_prompts):
task = asyncio.create_task(self.run_prompt(idx))
self.running_prompts[idx] = task

async def cancel_running_tasks(self):
while self.running_prompts:
task = self.running_prompts.pop()
task["task"].cancel()
await task["task"]
async def update_prompts(self, prompts: List[PromptDictInput]):
# TODO: currently under the assumption that only already running prompts are updated
if len(prompts) != len(self.current_prompts):
raise ValueError(
"Number of updated prompts must match the number of currently running prompts."
)
self.current_prompts = [convert_prompt(prompt) for prompt in prompts]

async def run_prompt(self, prompt: PromptDictInput):
async def run_prompt(self, prompt_index: int):
while True:
try:
await self.comfy_client.queue_prompt(prompt)
await self.comfy_client.queue_prompt(self.current_prompts[prompt_index])
except Exception as e:
logger.error(f"Error running prompt: {str(e)}")
logger.error(f"Error type: {type(e)}")
Expand All @@ -61,87 +64,92 @@ async def get_available_nodes(self):
try:
from comfy.nodes.package import import_all_nodes_in_workspace
nodes = import_all_nodes_in_workspace()

all_prompts_nodes_info = {}

# Get set of class types we need metadata for, excluding LoadTensor and SaveTensor
needed_class_types = {
node.get('class_type')
for node in self.prompt.values()
if node.get('class_type') not in ('LoadTensor', 'SaveTensor')
}
remaining_nodes = {
node_id
for node_id, node in self.prompt.items()
if node.get('class_type') not in ('LoadTensor', 'SaveTensor')
}
nodes_info = {}

# Only process nodes until we've found all the ones we need
for class_type, node_class in nodes.NODE_CLASS_MAPPINGS.items():
if not remaining_nodes: # Exit early if we've found all needed nodes
break

if class_type not in needed_class_types:
continue

# Get metadata for this node type (same as original get_node_metadata)
input_data = node_class.INPUT_TYPES() if hasattr(node_class, 'INPUT_TYPES') else {}
input_info = {}

# Process required inputs
if 'required' in input_data:
for name, value in input_data['required'].items():
if isinstance(value, tuple) and len(value) == 2:
input_type, config = value
input_info[name] = {
'type': input_type,
'required': True,
'min': config.get('min', None),
'max': config.get('max', None),
'widget': config.get('widget', None)
}
else:
logger.error(f"Unexpected structure for required input {name}: {value}")

# Process optional inputs
if 'optional' in input_data:
for name, value in input_data['optional'].items():
if isinstance(value, tuple) and len(value) == 2:
input_type, config = value
input_info[name] = {
'type': input_type,
'required': False,
'min': config.get('min', None),
'max': config.get('max', None),
'widget': config.get('widget', None)
}
else:
logger.error(f"Unexpected structure for optional input {name}: {value}")
for prompt_index, prompt in enumerate(self.current_prompts):
# Get set of class types we need metadata for, excluding LoadTensor and SaveTensor
needed_class_types = {
node.get('class_type')
for node in prompt.values()
if node.get('class_type') not in ('LoadTensor', 'SaveTensor')
}
remaining_nodes = {
node_id
for node_id, node in prompt.items()
if node.get('class_type') not in ('LoadTensor', 'SaveTensor')
}
nodes_info = {}

# Now process any nodes in our prompt that use this class_type
for node_id in list(remaining_nodes):
node = self.prompt[node_id]
if node.get('class_type') != class_type:
# Only process nodes until we've found all the ones we need
for class_type, node_class in nodes.NODE_CLASS_MAPPINGS.items():
if not remaining_nodes: # Exit early if we've found all needed nodes
break

if class_type not in needed_class_types:
continue

node_info = {
'class_type': class_type,
'inputs': {}
}
# Get metadata for this node type (same as original get_node_metadata)
input_data = node_class.INPUT_TYPES() if hasattr(node_class, 'INPUT_TYPES') else {}
input_info = {}

# Process required inputs
if 'required' in input_data:
for name, value in input_data['required'].items():
if isinstance(value, tuple) and len(value) == 2:
input_type, config = value
input_info[name] = {
'type': input_type,
'required': True,
'min': config.get('min', None),
'max': config.get('max', None),
'widget': config.get('widget', None)
}
else:
logger.error(f"Unexpected structure for required input {name}: {value}")

if 'inputs' in node:
for input_name, input_value in node['inputs'].items():
node_info['inputs'][input_name] = {
'value': input_value,
'type': input_info.get(input_name, {}).get('type', 'unknown'),
'min': input_info.get(input_name, {}).get('min', None),
'max': input_info.get(input_name, {}).get('max', None),
'widget': input_info.get(input_name, {}).get('widget', None)
}
# Process optional inputs
if 'optional' in input_data:
for name, value in input_data['optional'].items():
if isinstance(value, tuple) and len(value) == 2:
input_type, config = value
input_info[name] = {
'type': input_type,
'required': False,
'min': config.get('min', None),
'max': config.get('max', None),
'widget': config.get('widget', None)
}
else:
logger.error(f"Unexpected structure for optional input {name}: {value}")

nodes_info[node_id] = node_info
remaining_nodes.remove(node_id)
# Now process any nodes in our prompt that use this class_type
for node_id in list(remaining_nodes):
node = self.prompt[node_id]
if node.get('class_type') != class_type:
continue

node_info = {
'class_type': class_type,
'inputs': {}
}

if 'inputs' in node:
for input_name, input_value in node['inputs'].items():
node_info['inputs'][input_name] = {
'value': input_value,
'type': input_info.get(input_name, {}).get('type', 'unknown'),
'min': input_info.get(input_name, {}).get('min', None),
'max': input_info.get(input_name, {}).get('max', None),
'widget': input_info.get(input_name, {}).get('widget', None)
}

nodes_info[node_id] = node_info
remaining_nodes.remove(node_id)

all_prompts_nodes_info[prompt_index] = nodes_info

return nodes_info
return all_prompts_nodes_info

except Exception as e:
logger.error(f"Error getting node info: {str(e)}")
Expand Down

0 comments on commit b85d01d

Please sign in to comment.