Skip to content

Commit

Permalink
feat: simplified input to just have "prompt", removed unused code
Browse files Browse the repository at this point in the history
  • Loading branch information
TimPietrusky committed Oct 11, 2023
1 parent c04de4d commit 0c3ccda
Showing 1 changed file with 52 additions and 41 deletions.
93 changes: 52 additions & 41 deletions src/rp_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import urllib.request
import urllib.parse
import time
import random
import os
import requests

Expand All @@ -18,9 +17,10 @@
COMFY_POLLING_MAX_RETRIES = 100
# Host where ComfyUI is running
COMFY_HOST = "127.0.0.1:8188"
# The path where ComfyUI stores it generated images
# The path where ComfyUI stores the generated images
COMFY_OUTPUT_PATH = "/comfyui/output"


def check_server(url, retries=50, delay=500):
"""
Check if a server is reachable via HTTP GET request
Expand All @@ -39,18 +39,21 @@ def check_server(url, retries=50, delay=500):
response = requests.get(url)
# If the response status code is 200, the server is up and running
if response.status_code == 200:
print(f'runpod-worker-comfy - API is reachable')
print(f"runpod-worker-comfy - API is reachable")
return True
except requests.RequestException as e:
# If an exception occurs, the server may not be ready
pass

# Wait for the specified delay before retrying
time.sleep(delay / 1000)

print(f'runpod-worker-comfy - Failed to connect to server at {url} after {retries} attempts.')
print(
f"runpod-worker-comfy - Failed to connect to server at {url} after {retries} attempts."
)
return False


def queue_prompt(prompt):
"""
Queue a prompt to be processed by ComfyUI
Expand All @@ -61,10 +64,11 @@ def queue_prompt(prompt):
Returns:
dict: The JSON response from ComfyUI after processing the prompt
"""
data = json.dumps(prompt).encode('utf-8')
req = urllib.request.Request(f"http://{COMFY_HOST}/prompt", data=data)
data = json.dumps(prompt).encode("utf-8")
req = urllib.request.Request(f"http://{COMFY_HOST}/prompt", data=data)
return json.loads(urllib.request.urlopen(req).read())


def get_history(prompt_id):
"""
Retrieve the history of a given prompt using its ID
Expand All @@ -78,6 +82,7 @@ def get_history(prompt_id):
with urllib.request.urlopen(f"http://{COMFY_HOST}/history/{prompt_id}") as response:
return json.loads(response.read())


def handler(job):
"""
The main function that handles a job of generating an image.
Expand All @@ -92,53 +97,53 @@ def handler(job):
dict: A dictionary containing either an error message or a success status with generated images.
"""
job_input = job["input"]
prompt_text = job_input.get("prompt")

# Make sure that the ComfyUI API is available
check_server(f"http://{COMFY_HOST}", COMFY_API_AVAILABLE_MAX_RETRIES, COMFY_API_AVAILABLE_INTERVAL_MS)
check_server(
f"http://{COMFY_HOST}",
COMFY_API_AVAILABLE_MAX_RETRIES,
COMFY_API_AVAILABLE_INTERVAL_MS,
)

# Validate input
if prompt_text is None:
if job_input is None:
return {"error": "Please provide the 'prompt'"}

# Is JSON?
if isinstance(prompt_text, dict):
prompt = prompt_text
if isinstance(job_input, dict):
prompt = job_input
# Is String?
elif isinstance(prompt_text, str):
elif isinstance(job_input, str):
try:
prompt = json.loads(prompt_text)
prompt = json.loads(job_input)
except json.JSONDecodeError:
return {"error": "Invalid JSON format in 'prompt'"}
else:
return {"error": "'prompt' must be a JSON object or a JSON-encoded string"}

# TODO: REMOVE
# prompt["prompt"]["3"]["inputs"]["seed"] = random.randint(1, 10000000000)

# Queue the prompt
try:
queued_prompt = queue_prompt(prompt)
prompt_id = queued_prompt['prompt_id']
print(f'runpod-worker-comfy - queued prompt with ID {prompt_id}')
prompt_id = queued_prompt["prompt_id"]
print(f"runpod-worker-comfy - queued prompt with ID {prompt_id}")
except Exception as e:
return {"error": f"Error queuing prompt: {str(e)}"}

# Poll for completion
print(f'runpod-worker-comfy - wait until image generation is complete')
print(f"runpod-worker-comfy - wait until image generation is complete")
retries = 0
try:
while retries < COMFY_POLLING_MAX_RETRIES:
history = get_history(prompt_id)

# Exit the loop if we have found the history
if prompt_id in history and history[prompt_id].get('outputs'):
break
else:
# Wait before trying again
time.sleep(COMFY_POLLING_INTERVAL_MS / 1000)
retries += 1
else:
history = get_history(prompt_id)

# Exit the loop if we have found the history
if prompt_id in history and history[prompt_id].get("outputs"):
break
else:
# Wait before trying again
time.sleep(COMFY_POLLING_INTERVAL_MS / 1000)
retries += 1
else:
return {"error": "Max retries reached while waiting for image generation"}
except Exception as e:
return {"error": f"Error waiting for image generation: {str(e)}"}
Expand All @@ -149,20 +154,26 @@ def handler(job):
outputs = history[prompt_id].get("outputs")

for node_id, node_output in outputs.items():
if 'images' in node_output:
images_output = []
for image in node_output['images']:
output_images = image['filename']
if "images" in node_output:
for image in node_output["images"]:
output_images = image["filename"]

print(f'runpod-worker-comfy - image generation is done')
print(f"runpod-worker-comfy - image generation is done")

# The image is in the output folder
if os.path.exists(f"{COMFY_OUTPUT_PATH}/{output_images}"):
print("runpod-worker-comfy - the image exists in the output folder")
image_url = rp_upload.upload_image(job['id'], f"{COMFY_OUTPUT_PATH}/{output_images}")
return {"status": "success", "message": f'{image_url}'}
print("runpod-worker-comfy - the image exists in the output folder")
image_url = rp_upload.upload_image(
job["id"], f"{COMFY_OUTPUT_PATH}/{output_images}"
)
return {"status": "success", "message": f"{image_url}"}
else:
print("runpod-worker-comfy - the image does not exist in the output folder")
return {"status": "error", "message": f'the image does not exist in the specified output folder: {COMFY_OUTPUT_PATH}/{output_images}'}
print("runpod-worker-comfy - the image does not exist in the output folder")
return {
"status": "error",
"message": f"the image does not exist in the specified output folder: {COMFY_OUTPUT_PATH}/{output_images}",
}


# Start the serverless function with the defined handler.
# Start the handler
runpod.serverless.start({"handler": handler})

0 comments on commit 0c3ccda

Please sign in to comment.