Skip to content

Commit a7be43b

Browse files
committed
Add mask_restore option to give users the option to restore images based on mask, fixing #665.
Before commit c73fdd7 (Implement masking during sampling to improve blending, #308) image mask was applied after sampling, resulting in masked parts that are not regenerated to actually stay the same. Since c73fdd7 the masked img2img will change the whole image, even in masked areas. It gives better looking results at first glance, but will result in image degredation when applied a few times. See issue #665. In the workflow of using repeated masked img2img, users may want to use this options to keep the parts of image they actually want to keep without image degradation. A final masked img2img or whole image img2img with mask_restore disabled will give the better blending of "Implement masking during sampling".
1 parent 90a922c commit a7be43b

File tree

3 files changed

+103
-15
lines changed

3 files changed

+103
-15
lines changed

frontend/frontend.py

+16-6
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,13 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
197197
value=img2img_mask_modes[img2img_defaults['mask_mode']],
198198
visible=True)
199199

200-
img2img_mask_blur_strength = gr.Slider(minimum=1, maximum=10, step=1,
200+
img2img_mask_restore = gr.Checkbox(label="Restore image by mask",
201+
value=img2img_defaults['mask_restore'],
202+
visible=True)
203+
204+
img2img_mask_blur_strength = gr.Slider(minimum=1, maximum=100, step=1,
201205
label="How much blurry should the mask be? (to avoid hard edges)",
202-
value=3, visible=False)
206+
value=3, visible=True)
203207

204208
img2img_resize = gr.Radio(label="Resize mode",
205209
choices=["Just resize", "Crop and resize",
@@ -290,8 +294,14 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
290294
img2img_width,
291295
img2img_height
292296
],
293-
[img2img_image_editor, img2img_image_mask, img2img_btn_editor, img2img_btn_mask,
294-
img2img_painterro_btn, img2img_mask, img2img_mask_blur_strength]
297+
[img2img_image_editor,
298+
img2img_image_mask,
299+
img2img_btn_editor,
300+
img2img_btn_mask,
301+
img2img_painterro_btn,
302+
img2img_mask,
303+
img2img_mask_blur_strength,
304+
img2img_mask_restore]
295305
)
296306

297307
# img2img_image_editor_mode.change(
@@ -332,8 +342,8 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
332342
)
333343

334344
img2img_func = img2img
335-
img2img_inputs = [img2img_prompt, img2img_image_editor_mode, img2img_mask,
336-
img2img_mask_blur_strength, img2img_steps, img2img_sampling, img2img_toggles,
345+
img2img_inputs = [img2img_prompt, img2img_image_editor_mode, img2img_mask, img2img_mask_blur_strength,
346+
img2img_mask_restore, img2img_steps, img2img_sampling, img2img_toggles,
337347
img2img_realesrgan_model_name, img2img_batch_count, img2img_cfg,
338348
img2img_denoising, img2img_seed, img2img_height, img2img_width, img2img_resize,
339349
img2img_image_editor, img2img_image_mask, img2img_embeddings]

scripts/webui.py

+59-5
Original file line numberDiff line numberDiff line change
@@ -781,7 +781,7 @@ def process_images(
781781
outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size,
782782
n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, realesrgan_model_name,
783783
fp, ddim_eta=0.0, do_not_save_grid=False, normalize_prompt_weights=True, init_img=None, init_mask=None,
784-
keep_mask=False, mask_blur_strength=3, denoising_strength=0.75, resize_mode=None, uses_loopback=False,
784+
keep_mask=False, mask_blur_strength=3, mask_restore=False, denoising_strength=0.75, resize_mode=None, uses_loopback=False,
785785
uses_random_seed_loopback=False, sort_samples=True, write_info_files=True, write_sample_info_to_log_file=False, jpg_sample=False,
786786
variant_amount=0.0, variant_seed=None,imgProcessorTask=False, job_info: JobInfo = None):
787787
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
@@ -1018,6 +1018,26 @@ def process_images(
10181018
if imgProcessorTask == True:
10191019
output_images.append(image)
10201020

1021+
if mask_restore and init_mask:
1022+
#init_mask = init_mask if keep_mask else ImageOps.invert(init_mask)
1023+
init_mask = init_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength))
1024+
init_mask = init_mask.convert('L')
1025+
init_img = init_img.convert('RGB')
1026+
image = image.convert('RGB')
1027+
1028+
if use_RealESRGAN and RealESRGAN is not None:
1029+
if RealESRGAN.model.name != realesrgan_model_name:
1030+
try_loading_RealESRGAN(realesrgan_model_name)
1031+
output, img_mode = RealESRGAN.enhance(np.array(init_img, dtype=np.uint8))
1032+
init_img = Image.fromarray(output)
1033+
init_img = init_img.convert('RGB')
1034+
1035+
output, img_mode = RealESRGAN.enhance(np.array(init_mask, dtype=np.uint8))
1036+
init_mask = Image.fromarray(output)
1037+
init_mask = init_mask.convert('L')
1038+
1039+
image = Image.composite(init_img, image, init_mask)
1040+
10211041
if not skip_save:
10221042
save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
10231043
normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save,
@@ -1225,7 +1245,7 @@ def flag(self, flag_data, flag_option=None, flag_index=None, username=None):
12251245
print("Logged:", filenames[0])
12261246

12271247

1228-
def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_strength: int, ddim_steps: int, sampler_name: str,
1248+
def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_strength: int, mask_restore: bool, ddim_steps: int, sampler_name: str,
12291249
toggles: List[int], realesrgan_model_name: str, n_iter: int, cfg_scale: float, denoising_strength: float,
12301250
seed: int, height: int, width: int, resize_mode: int, init_info: any = None, init_info_mask: any = None, fp = None, job_info: JobInfo = None):
12311251
# print([prompt, image_editor_mode, init_info, init_info_mask, mask_mode,
@@ -1428,6 +1448,7 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
14281448
init_mask=init_mask,
14291449
keep_mask=keep_mask,
14301450
mask_blur_strength=mask_blur_strength,
1451+
mask_restore=mask_restore,
14311452
denoising_strength=denoising_strength,
14321453
resize_mode=resize_mode,
14331454
uses_loopback=loopback,
@@ -1498,6 +1519,7 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
14981519
keep_mask=keep_mask,
14991520
mask_blur_strength=mask_blur_strength,
15001521
denoising_strength=denoising_strength,
1522+
mask_restore=mask_restore,
15011523
resize_mode=resize_mode,
15021524
uses_loopback=loopback,
15031525
sort_samples=sort_samples,
@@ -1638,6 +1660,7 @@ def processGoBig(image):
16381660
init_img = result
16391661
init_mask = None
16401662
keep_mask = False
1663+
mask_restore = False
16411664
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
16421665

16431666
def init():
@@ -1784,6 +1807,7 @@ def make_mask_image(r):
17841807
keep_mask=False,
17851808
mask_blur_strength=None,
17861809
denoising_strength=denoising_strength,
1810+
mask_restore=mask_restore,
17871811
resize_mode=resize_mode,
17881812
uses_loopback=False,
17891813
sort_samples=True,
@@ -2086,6 +2110,7 @@ def run_RealESRGAN(image, model_name: str):
20862110
'cfg_scale': 5.0,
20872111
'denoising_strength': 0.75,
20882112
'mask_mode': 0,
2113+
'mask_restore': False,
20892114
'resize_mode': 0,
20902115
'seed': '',
20912116
'height': 512,
@@ -2099,10 +2124,39 @@ def run_RealESRGAN(image, model_name: str):
20992124
img2img_toggle_defaults = [img2img_toggles[i] for i in img2img_defaults['toggles']]
21002125
img2img_image_mode = 'sketch'
21012126

2102-
def change_image_editor_mode(choice, cropped_image, resize_mode, width, height):
2127+
def change_image_editor_mode(choice, cropped_image, mask, resize_mode, width, height):
21032128
if choice == "Mask":
2104-
return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)]
2105-
return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)]
2129+
update_image_editor = gr.update(visible=False)
2130+
update_image_mask = gr.update(visible=True)
2131+
update_btn_editor = gr.update(visible=False)
2132+
update_btn_mask = gr.update(visible=True)
2133+
update_painterro_btn = gr.update(visible=False)
2134+
update_mask = gr.update(visible=False)
2135+
update_mask_blur_strength = gr.update(visible=True)
2136+
update_mask_restore = gr.update(visible=True)
2137+
# unknown = gr.update(visible=True)
2138+
else:
2139+
update_image_editor = gr.update(visible=True)
2140+
update_image_mask = gr.update(visible=False)
2141+
update_btn_editor = gr.update(visible=True)
2142+
update_btn_mask = gr.update(visible=False)
2143+
update_painterro_btn = gr.update(visible=True)
2144+
update_mask = gr.update(visible=True)
2145+
update_mask_blur_strength = gr.update(visible=False)
2146+
update_mask_restore = gr.update(visible=False)
2147+
# unknown = gr.update(visible=False)
2148+
2149+
return [
2150+
update_image_editor,
2151+
update_image_mask,
2152+
update_btn_editor,
2153+
update_btn_mask,
2154+
update_painterro_btn,
2155+
update_mask,
2156+
update_mask_blur_strength,
2157+
update_mask_restore,
2158+
# unknown,
2159+
]
21062160

21072161
def update_image_mask(cropped_image, resize_mode, width, height):
21082162
resized_cropped_image = resize_image(resize_mode, cropped_image, width, height) if cropped_image else None

scripts/webui_streamlit.py

+28-4
Original file line numberDiff line numberDiff line change
@@ -913,7 +913,7 @@ def process_images(
913913
outpath, func_init, func_sample, prompt, seed, sampler_name, save_grid, batch_size,
914914
n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, realesrgan_model_name,
915915
fp=None, ddim_eta=0.0, normalize_prompt_weights=True, init_img=None, init_mask=None,
916-
keep_mask=False, mask_blur_strength=3, denoising_strength=0.75, resize_mode=None, uses_loopback=False,
916+
keep_mask=False, mask_blur_strength=3, mask_restore=False, denoising_strength=0.75, resize_mode=None, uses_loopback=False,
917917
uses_random_seed_loopback=False, sort_samples=True, write_info_files=True, jpg_sample=False,
918918
variant_amount=0.0, variant_seed=None, save_individual_images: bool = True):
919919
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
@@ -1156,7 +1156,29 @@ def process_images(
11561156

11571157
if simple_templating:
11581158
grid_captions.append( captions[i] + "\ngfpgan_esrgan" )
1159-
1159+
1160+
if mask_restore and init_mask:
1161+
#init_mask = init_mask if keep_mask else ImageOps.invert(init_mask)
1162+
init_mask = init_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength))
1163+
init_mask = init_mask.convert('L')
1164+
init_img = init_img.convert('RGB')
1165+
image = image.convert('RGB')
1166+
1167+
if use_RealESRGAN and st.session_state["RealESRGAN"] is not None:
1168+
if st.session_state["RealESRGAN"].model.name != realesrgan_model_name:
1169+
#try_loading_RealESRGAN(realesrgan_model_name)
1170+
load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
1171+
1172+
output, img_mode = st.session_state["RealESRGAN"].enhance(np.array(init_img, dtype=np.uint8))
1173+
init_img = Image.fromarray(output)
1174+
init_img = init_img.convert('RGB')
1175+
1176+
output, img_mode = st.session_state["RealESRGAN"].enhance(np.array(init_mask, dtype=np.uint8))
1177+
init_mask = Image.fromarray(output)
1178+
init_mask = init_mask.convert('L')
1179+
1180+
image = Image.composite(init_img, image, init_mask)
1181+
11601182
if save_individual_images:
11611183
save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
11621184
normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback,
@@ -1257,7 +1279,7 @@ def resize_image(resize_mode, im, width, height):
12571279
return res
12581280

12591281
def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None, mask_mode: int = 0, mask_blur_strength: int = 3,
1260-
ddim_steps: int = 50, sampler_name: str = 'DDIM',
1282+
mask_restore: bool = False, ddim_steps: int = 50, sampler_name: str = 'DDIM',
12611283
n_iter: int = 1, cfg_scale: float = 7.5, denoising_strength: float = 0.8,
12621284
seed: int = -1, height: int = 512, width: int = 512, resize_mode: int = 0, fp = None,
12631285
variant_amount: float = None, variant_seed: int = None, ddim_eta:float = 0.0,
@@ -1426,6 +1448,7 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
14261448
init_mask=init_mask,
14271449
keep_mask=keep_mask,
14281450
mask_blur_strength=mask_blur_strength,
1451+
mask_restore=mask_restore,
14291452
denoising_strength=denoising_strength,
14301453
resize_mode=resize_mode,
14311454
uses_loopback=loopback,
@@ -1486,8 +1509,9 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
14861509
init_img=init_img,
14871510
init_mask=init_mask,
14881511
keep_mask=keep_mask,
1489-
mask_blur_strength=2,
1512+
mask_blur_strength=mask_blur_strength,
14901513
denoising_strength=denoising_strength,
1514+
mask_restore=mask_restore,
14911515
resize_mode=resize_mode,
14921516
uses_loopback=loopback,
14931517
sort_samples=group_by_prompt,

0 commit comments

Comments
 (0)