Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mask_restore to restore images based on mask, fixing #665 #898

Merged
merged 8 commits into from
Sep 9, 2022
22 changes: 16 additions & 6 deletions frontend/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,13 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
value=img2img_mask_modes[img2img_defaults['mask_mode']],
visible=True)

img2img_mask_blur_strength = gr.Slider(minimum=1, maximum=10, step=1,
img2img_mask_restore = gr.Checkbox(label="Restore image by mask",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
img2img_mask_restore = gr.Checkbox(label="Restore image by mask",
img2img_mask_restore = gr.Checkbox(label="Restore masked image",

This probably needs a more user friendly name.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I changed it to "Only modify regenerated parts of image".

value=img2img_defaults['mask_restore'],
visible=True)

img2img_mask_blur_strength = gr.Slider(minimum=1, maximum=100, step=1,
label="How much blurry should the mask be? (to avoid hard edges)",
value=3, visible=False)
value=3, visible=True)

img2img_resize = gr.Radio(label="Resize mode",
choices=["Just resize", "Crop and resize",
Expand Down Expand Up @@ -290,8 +294,14 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
img2img_width,
img2img_height
],
[img2img_image_editor, img2img_image_mask, img2img_btn_editor, img2img_btn_mask,
img2img_painterro_btn, img2img_mask, img2img_mask_blur_strength]
[img2img_image_editor,
img2img_image_mask,
img2img_btn_editor,
img2img_btn_mask,
img2img_painterro_btn,
img2img_mask,
img2img_mask_blur_strength,
img2img_mask_restore]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will break ui_functions.change_image_editor_mode

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There I appended the new img2img_mask_restore component to the outputs list of the change_image_editor_mode function, where the gr.update is also added the end of the list to also update the new img2img_mask_restore component.

)

# img2img_image_editor_mode.change(
Expand Down Expand Up @@ -332,8 +342,8 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
)

img2img_func = img2img
img2img_inputs = [img2img_prompt, img2img_image_editor_mode, img2img_mask,
img2img_mask_blur_strength, img2img_steps, img2img_sampling, img2img_toggles,
img2img_inputs = [img2img_prompt, img2img_image_editor_mode, img2img_mask, img2img_mask_blur_strength,
img2img_mask_restore, img2img_steps, img2img_sampling, img2img_toggles,
img2img_realesrgan_model_name, img2img_batch_count, img2img_cfg,
img2img_denoising, img2img_seed, img2img_height, img2img_width, img2img_resize,
img2img_image_editor, img2img_image_mask, img2img_embeddings]
Expand Down
64 changes: 59 additions & 5 deletions scripts/webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,7 +781,7 @@ def process_images(
outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size,
n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, realesrgan_model_name,
fp, ddim_eta=0.0, do_not_save_grid=False, normalize_prompt_weights=True, init_img=None, init_mask=None,
keep_mask=False, mask_blur_strength=3, denoising_strength=0.75, resize_mode=None, uses_loopback=False,
keep_mask=False, mask_blur_strength=3, mask_restore=False, denoising_strength=0.75, resize_mode=None, uses_loopback=False,
uses_random_seed_loopback=False, sort_samples=True, write_info_files=True, write_sample_info_to_log_file=False, jpg_sample=False,
variant_amount=0.0, variant_seed=None,imgProcessorTask=False, job_info: JobInfo = None):
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
Expand Down Expand Up @@ -1018,6 +1018,26 @@ def process_images(
if imgProcessorTask == True:
output_images.append(image)

if mask_restore and init_mask:
#init_mask = init_mask if keep_mask else ImageOps.invert(init_mask)
init_mask = init_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength))
init_mask = init_mask.convert('L')
init_img = init_img.convert('RGB')
image = image.convert('RGB')

if use_RealESRGAN and RealESRGAN is not None:
if RealESRGAN.model.name != realesrgan_model_name:
try_loading_RealESRGAN(realesrgan_model_name)
output, img_mode = RealESRGAN.enhance(np.array(init_img, dtype=np.uint8))
init_img = Image.fromarray(output)
init_img = init_img.convert('RGB')

output, img_mode = RealESRGAN.enhance(np.array(init_mask, dtype=np.uint8))
init_mask = Image.fromarray(output)
init_mask = init_mask.convert('L')

image = Image.composite(init_img, image, init_mask)

if not skip_save:
save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save,
Expand Down Expand Up @@ -1225,7 +1245,7 @@ def flag(self, flag_data, flag_option=None, flag_index=None, username=None):
print("Logged:", filenames[0])


def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_strength: int, ddim_steps: int, sampler_name: str,
def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_strength: int, mask_restore: bool, ddim_steps: int, sampler_name: str,
toggles: List[int], realesrgan_model_name: str, n_iter: int, cfg_scale: float, denoising_strength: float,
seed: int, height: int, width: int, resize_mode: int, init_info: any = None, init_info_mask: any = None, fp = None, job_info: JobInfo = None):
# print([prompt, image_editor_mode, init_info, init_info_mask, mask_mode,
Expand Down Expand Up @@ -1428,6 +1448,7 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
init_mask=init_mask,
keep_mask=keep_mask,
mask_blur_strength=mask_blur_strength,
mask_restore=mask_restore,
denoising_strength=denoising_strength,
resize_mode=resize_mode,
uses_loopback=loopback,
Expand Down Expand Up @@ -1498,6 +1519,7 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
keep_mask=keep_mask,
mask_blur_strength=mask_blur_strength,
denoising_strength=denoising_strength,
mask_restore=mask_restore,
resize_mode=resize_mode,
uses_loopback=loopback,
sort_samples=sort_samples,
Expand Down Expand Up @@ -1638,6 +1660,7 @@ def processGoBig(image):
init_img = result
init_mask = None
keep_mask = False
mask_restore = False
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'

def init():
Expand Down Expand Up @@ -1784,6 +1807,7 @@ def make_mask_image(r):
keep_mask=False,
mask_blur_strength=None,
denoising_strength=denoising_strength,
mask_restore=mask_restore,
resize_mode=resize_mode,
uses_loopback=False,
sort_samples=True,
Expand Down Expand Up @@ -2086,6 +2110,7 @@ def run_RealESRGAN(image, model_name: str):
'cfg_scale': 5.0,
'denoising_strength': 0.75,
'mask_mode': 0,
'mask_restore': False,
'resize_mode': 0,
'seed': '',
'height': 512,
Expand All @@ -2099,10 +2124,39 @@ def run_RealESRGAN(image, model_name: str):
img2img_toggle_defaults = [img2img_toggles[i] for i in img2img_defaults['toggles']]
img2img_image_mode = 'sketch'

def change_image_editor_mode(choice, cropped_image, resize_mode, width, height):
def change_image_editor_mode(choice, cropped_image, mask, resize_mode, width, height):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: here was an argument missing, I inserted it. It was no issue though, as the arguments are not used anyway.

if choice == "Mask":
return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)]
return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)]
update_image_editor = gr.update(visible=False)
update_image_mask = gr.update(visible=True)
update_btn_editor = gr.update(visible=False)
update_btn_mask = gr.update(visible=True)
update_painterro_btn = gr.update(visible=False)
update_mask = gr.update(visible=False)
update_mask_blur_strength = gr.update(visible=True)
update_mask_restore = gr.update(visible=True)
# unknown = gr.update(visible=True)
else:
update_image_editor = gr.update(visible=True)
update_image_mask = gr.update(visible=False)
update_btn_editor = gr.update(visible=True)
update_btn_mask = gr.update(visible=False)
update_painterro_btn = gr.update(visible=True)
update_mask = gr.update(visible=True)
update_mask_blur_strength = gr.update(visible=False)
update_mask_restore = gr.update(visible=False)
# unknown = gr.update(visible=False)

return [
update_image_editor,
update_image_mask,
update_btn_editor,
update_btn_mask,
update_painterro_btn,
update_mask,
update_mask_blur_strength,
update_mask_restore,
# unknown,
]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this? Has no usages. Please explain.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I appended the gr.update(..) calls to the end of the returned lists. From the code it was totally unclear which entry in the list was for which component, so I gave each gr.update the name indicating the corresponding updated component and then return the list of named updates.

But I do understand that this is maybe a bit too much restructure for this PR. I will revert and just add gr.update at the end of the lists like you said in the earlier comment.


def update_image_mask(cropped_image, resize_mode, width, height):
resized_cropped_image = resize_image(resize_mode, cropped_image, width, height) if cropped_image else None
Expand Down
32 changes: 28 additions & 4 deletions scripts/webui_streamlit.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,7 +913,7 @@ def process_images(
outpath, func_init, func_sample, prompt, seed, sampler_name, save_grid, batch_size,
n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, realesrgan_model_name,
fp=None, ddim_eta=0.0, normalize_prompt_weights=True, init_img=None, init_mask=None,
keep_mask=False, mask_blur_strength=3, denoising_strength=0.75, resize_mode=None, uses_loopback=False,
keep_mask=False, mask_blur_strength=3, mask_restore=False, denoising_strength=0.75, resize_mode=None, uses_loopback=False,
uses_random_seed_loopback=False, sort_samples=True, write_info_files=True, jpg_sample=False,
variant_amount=0.0, variant_seed=None, save_individual_images: bool = True):
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
Expand Down Expand Up @@ -1156,7 +1156,29 @@ def process_images(

if simple_templating:
grid_captions.append( captions[i] + "\ngfpgan_esrgan" )


if mask_restore and init_mask:
#init_mask = init_mask if keep_mask else ImageOps.invert(init_mask)
init_mask = init_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength))
init_mask = init_mask.convert('L')
init_img = init_img.convert('RGB')
image = image.convert('RGB')

if use_RealESRGAN and st.session_state["RealESRGAN"] is not None:
if st.session_state["RealESRGAN"].model.name != realesrgan_model_name:
#try_loading_RealESRGAN(realesrgan_model_name)
load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)

output, img_mode = st.session_state["RealESRGAN"].enhance(np.array(init_img, dtype=np.uint8))
init_img = Image.fromarray(output)
init_img = init_img.convert('RGB')

output, img_mode = st.session_state["RealESRGAN"].enhance(np.array(init_mask, dtype=np.uint8))
init_mask = Image.fromarray(output)
init_mask = init_mask.convert('L')

image = Image.composite(init_img, image, init_mask)

if save_individual_images:
save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback,
Expand Down Expand Up @@ -1257,7 +1279,7 @@ def resize_image(resize_mode, im, width, height):
return res

def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None, mask_mode: int = 0, mask_blur_strength: int = 3,
ddim_steps: int = 50, sampler_name: str = 'DDIM',
mask_restore: bool = False, ddim_steps: int = 50, sampler_name: str = 'DDIM',
n_iter: int = 1, cfg_scale: float = 7.5, denoising_strength: float = 0.8,
seed: int = -1, height: int = 512, width: int = 512, resize_mode: int = 0, fp = None,
variant_amount: float = None, variant_seed: int = None, ddim_eta:float = 0.0,
Expand Down Expand Up @@ -1426,6 +1448,7 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
init_mask=init_mask,
keep_mask=keep_mask,
mask_blur_strength=mask_blur_strength,
mask_restore=mask_restore,
denoising_strength=denoising_strength,
resize_mode=resize_mode,
uses_loopback=loopback,
Expand Down Expand Up @@ -1486,8 +1509,9 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
init_img=init_img,
init_mask=init_mask,
keep_mask=keep_mask,
mask_blur_strength=2,
mask_blur_strength=mask_blur_strength,
denoising_strength=denoising_strength,
mask_restore=mask_restore,
resize_mode=resize_mode,
uses_loopback=loopback,
sort_samples=group_by_prompt,
Expand Down