Skip to content

Commit

Permalink
introduce state for task skipping/stopping
Browse files Browse the repository at this point in the history
  • Loading branch information
mashb1t committed Nov 18, 2023
1 parent 83efbbd commit f5c17fb
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 17 deletions.
7 changes: 6 additions & 1 deletion modules/async_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ def __init__(self, args):
self.yields = []
self.results = []
self.last_stop = False
self.processing = False


async_tasks = []
Expand Down Expand Up @@ -115,7 +116,9 @@ def build_image_wall(async_task):
@torch.inference_mode()
def handler(async_task):
execution_start_time = time.perf_counter()
async_task.processing = True

execution_start_time = time.perf_counter()
args = async_task.args
args.reverse()

Expand Down Expand Up @@ -660,6 +663,8 @@ def callback(step, x0, x, total_steps, y):
execution_start_time = time.perf_counter()

try:
if async_task.last_stop is not False:
fcbh.model_management.interrupt_current_processing()
positive_cond, negative_cond = task['c'], task['uc']

if 'cn' in goals:
Expand Down Expand Up @@ -732,7 +737,7 @@ def callback(step, x0, x, total_steps, y):

execution_time = time.perf_counter() - execution_start_time
print(f'Generating and saving time: {execution_time:.2f} seconds')

async_task.processing = False
return

while True:
Expand Down
39 changes: 23 additions & 16 deletions webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@
from modules.ui_gradio_extensions import reload_javascript
from modules.auth import auth_enabled, check_auth

currentTask = gr.State()
def get_task(*args):
args = list(args)
currentTask = args.pop(0)
currentTask = worker.AsyncTask(args=args)
return currentTask

def generate_clicked(*args):
def generate_clicked(task):
# outputs=[progress_html, progress_window, progress_gallery, gallery]

execution_start_time = time.perf_counter()
currentTask.value = task = worker.AsyncTask(args=list(args))
finished = False

yield gr.update(visible=True, value=modules.html.make_progress_html(1, 'Waiting for task to start ...')), \
Expand Down Expand Up @@ -82,6 +84,7 @@ def generate_clicked(*args):
css=modules.html.css).queue()

with shared.gradio_root:
currentTask = gr.State(worker.AsyncTask(args=[]))
with gr.Row():
with gr.Column(scale=2):
with gr.Row():
Expand All @@ -108,21 +111,24 @@ def generate_clicked(*args):
skip_button = gr.Button(label="Skip", value="Skip", elem_classes='type_row_half', visible=False)
stop_button = gr.Button(label="Stop", value="Stop", elem_classes='type_row_half', elem_id='stop_button', visible=False)

def stop_clicked():
def stop_clicked(currentTask):
import fcbh.model_management as model_management
currentTask.value.last_stop = 'stop'
model_management.interrupt_current_processing()
return [gr.update(interactive=False)] * 2
currentTask.last_stop = 'stop'
if (currentTask.processing):
model_management.interrupt_current_processing()
return [gr.update(interactive=False)] * 2, currentTask

def skip_clicked():
def skip_clicked(currentTask):
import fcbh.model_management as model_management
currentTask.value.last_stop = 'skip'
model_management.interrupt_current_processing()
return
currentTask.last_stop = 'skip'
if (currentTask.processing):
model_management.interrupt_current_processing()
return currentTask

stop_button.click(stop_clicked, outputs=[skip_button, stop_button],
stop_button.click(stop_clicked, inputs=currentTask, outputs=[skip_button, stop_button, currentTask],
queue=False, show_progress=False, _js='cancelGenerateForever')
skip_button.click(skip_clicked, queue=False, show_progress=False)
skip_button.click(skip_clicked, inputs=currentTask, outputs=currentTask,
queue=False, show_progress=False)
with gr.Row(elem_classes='advanced_check_row'):
input_image_checkbox = gr.Checkbox(label='Input Image', value=False, container=False, elem_classes='min_check')
advanced_checkbox = gr.Checkbox(label='Advanced', value=modules.config.default_advanced_checkbox, container=False, elem_classes='min_check')
Expand Down Expand Up @@ -428,7 +434,7 @@ def model_refresh_clicked():
.then(fn=lambda: None, _js='refresh_grid_delayed', queue=False, show_progress=False)

ctrls = [
prompt, negative_prompt, style_selections,
currentTask, prompt, negative_prompt, style_selections,
performance_selection, aspect_ratios_selection, image_number, image_seed, sharpness, guidance_scale
]

Expand All @@ -441,7 +447,8 @@ def model_refresh_clicked():
generate_button.click(lambda: (gr.update(visible=True, interactive=True), gr.update(visible=True, interactive=True), gr.update(visible=False), []), outputs=[stop_button, skip_button, generate_button, gallery]) \
.then(fn=refresh_seed, inputs=[seed_random, image_seed], outputs=image_seed) \
.then(advanced_parameters.set_all_advanced_parameters, inputs=adps) \
.then(fn=generate_clicked, inputs=ctrls, outputs=[progress_html, progress_window, progress_gallery, gallery]) \
.then(fn=get_task, inputs=ctrls, outputs=currentTask) \
.then(fn=generate_clicked, inputs=currentTask, outputs=[progress_html, progress_window, progress_gallery, gallery]) \
.then(lambda: (gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)), outputs=[generate_button, stop_button, skip_button]) \
.then(fn=lambda: None, _js='playNotification').then(fn=lambda: None, _js='refresh_grid_delayed')

Expand Down

0 comments on commit f5c17fb

Please sign in to comment.