Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add color correction toggle for img2img #936

Merged
merged 1 commit into from
Sep 9, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 49 additions & 24 deletions scripts/webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,15 +784,35 @@ def classToArrays( items, seed, n_iter ):

return all_seeds, n_iter, prompt_matrix_parts, all_prompts, needrows


def perform_color_correction(img_rgb, correction_target_lab, do_color_correction):
try:
from skimage import exposure
except:
print("Install scikit-image to perform color correction")
return img_rgb

if not do_color_correction: return img_rgb
if correction_target_lab is None: return img_rgb

return (
Image.fromarray(cv2.cvtColor(exposure.match_histograms(
cv2.cvtColor(
np.asarray(img_rgb),
cv2.COLOR_RGB2LAB
),
correction_target_lab,
channel_axis=2
), cv2.COLOR_LAB2RGB).astype("uint8")
)
)

def process_images(
outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size,
n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, realesrgan_model_name,
fp, ddim_eta=0.0, do_not_save_grid=False, normalize_prompt_weights=True, init_img=None, init_mask=None,
keep_mask=False, mask_blur_strength=3, denoising_strength=0.75, resize_mode=None, uses_loopback=False,
uses_random_seed_loopback=False, sort_samples=True, write_info_files=True, write_sample_info_to_log_file=False, jpg_sample=False,
variant_amount=0.0, variant_seed=None,imgProcessorTask=False, job_info: JobInfo = None):
variant_amount=0.0, variant_seed=None,imgProcessorTask=False, job_info: JobInfo = None, do_color_correction=False, correction_target=None):
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
prompt = prompt or ''
torch_gc()
Expand Down Expand Up @@ -991,6 +1011,7 @@ def process_images(
cfg_scale=cfg_scale, normalize_prompt_weights=normalize_prompt_weights, denoising_strength=denoising_strength,
GFPGAN=use_GFPGAN )
image = Image.fromarray(x_sample)
image = perform_color_correction(image, correction_target, do_color_correction)
ImageMetadata.set_on_image(image, metadata)

original_sample = x_sample
Expand All @@ -1001,6 +1022,7 @@ def process_images(
cropped_faces, restored_faces, restored_img = GFPGAN.enhance(original_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
gfpgan_sample = restored_img[:,:,::-1]
gfpgan_image = Image.fromarray(gfpgan_sample)
gfpgan_image = perform_color_correction(gfpgan_image, correction_target, do_color_correction)
gfpgan_metadata = copy.copy(metadata)
gfpgan_metadata.GFPGAN = True
ImageMetadata.set_on_image( gfpgan_image, gfpgan_metadata )
Expand All @@ -1018,6 +1040,7 @@ def process_images(
esrgan_filename = original_filename + '-esrgan4x'
esrgan_sample = output[:,:,::-1]
esrgan_image = Image.fromarray(esrgan_sample)
esrgan_image = perform_color_correction(esrgan_image, correction_target, do_color_correction)
ImageMetadata.set_on_image( esrgan_image, metadata )
save_sample(esrgan_image, sample_path_i, esrgan_filename, jpg_sample, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save,
skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=False)
Expand All @@ -1034,6 +1057,7 @@ def process_images(
gfpgan_esrgan_filename = original_filename + '-gfpgan-esrgan4x'
gfpgan_esrgan_sample = output[:,:,::-1]
gfpgan_esrgan_image = Image.fromarray(gfpgan_esrgan_sample)
gfpgan_esrgan_image = perform_color_correction(gfpgan_esrgan_image, correction_target, do_color_correction)
ImageMetadata.set_on_image(gfpgan_esrgan_image, metadata)
save_sample(gfpgan_esrgan_image, sample_path_i, gfpgan_esrgan_filename, jpg_sample, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback,
skip_save, skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=False)
Expand Down Expand Up @@ -1129,6 +1153,10 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, toggles: List[int],
jpg_sample = 7 in toggles
use_GFPGAN = 8 in toggles
use_RealESRGAN = 9 in toggles

do_color_correction = False
correction_target = None

ModelLoader(['model'],True,False)
if use_GFPGAN and not use_RealESRGAN:
ModelLoader(['GFPGAN'],True,False)
Expand Down Expand Up @@ -1194,6 +1222,8 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
variant_amount=variant_amount,
variant_seed=variant_seed,
job_info=job_info,
do_color_correction=do_color_correction,
correction_target=correction_target
)

del sampler
Expand Down Expand Up @@ -1283,8 +1313,9 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren
write_info_files = 7 in toggles
write_sample_info_to_log_file = 8 in toggles
jpg_sample = 9 in toggles
use_GFPGAN = 10 in toggles
use_RealESRGAN = 11 in toggles
do_color_correction = 10 in toggles
use_GFPGAN = 11 in toggles
use_RealESRGAN = 12 in toggles
ModelLoader(['model'],True,False)
if use_GFPGAN and not use_RealESRGAN:
ModelLoader(['GFPGAN'],True,False)
Expand Down Expand Up @@ -1461,19 +1492,15 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
return samples_ddim



correction_target = None
if loopback:
output_images, info = None, None
history = []
initial_seed = None

do_color_correction = False
try:
from skimage import exposure
do_color_correction = True
except:
print("Install scikit-image to perform color correction on loopback")

# turn on color correction for loopback to prevent known issue of color drift
do_color_correction = True

for i in range(n_iter):
if do_color_correction and i == 0:
correction_target = cv2.cvtColor(np.asarray(init_img.copy()), cv2.COLOR_RGB2LAB)
Expand Down Expand Up @@ -1512,24 +1539,16 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
write_info_files=write_info_files,
write_sample_info_to_log_file=write_sample_info_to_log_file,
jpg_sample=jpg_sample,
job_info=job_info
job_info=job_info,
do_color_correction=do_color_correction,
correction_target=correction_target
)

if initial_seed is None:
initial_seed = seed

init_img = output_images[0]

if do_color_correction and correction_target is not None:
init_img = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
cv2.cvtColor(
np.asarray(init_img),
cv2.COLOR_RGB2LAB
),
correction_target,
channel_axis=2
), cv2.COLOR_LAB2RGB).astype("uint8"))

if not random_seed_loopback:
seed = seed + 1
else:
Expand All @@ -1548,6 +1567,9 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
seed = initial_seed

else:
if do_color_correction:
correction_target = cv2.cvtColor(np.asarray(init_img.copy()), cv2.COLOR_RGB2LAB)

output_images, seed, info, stats = process_images(
outpath=outpath,
func_init=init,
Expand Down Expand Up @@ -1580,7 +1602,9 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
write_info_files=write_info_files,
write_sample_info_to_log_file=write_sample_info_to_log_file,
jpg_sample=jpg_sample,
job_info=job_info
job_info=job_info,
do_color_correction=do_color_correction,
correction_target=correction_target
)

del sampler
Expand Down Expand Up @@ -2157,6 +2181,7 @@ def run_RealESRGAN(image, model_name: str):
'Write sample info files',
'Write sample info to one file',
'jpg samples',
'Color correction (always enabled on loopback mode)'
]
# removed for now becuase of Image Lab implementation
if GFPGAN is not None:
Expand Down