Skip to content

Commit

Permalink
Formatted soft_inpainting.
Browse files Browse the repository at this point in the history
  • Loading branch information
CodeHatchling authored and ruchej committed Sep 30, 2024
1 parent 26e4519 commit 28130bf
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions scripts/soft_inpainting.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def get_modified_nmask(settings, nmask, sigma):


def apply_adaptive_masks(
settings:SoftInpaintingSettings,
settings: SoftInpaintingSettings,
nmask,
latent_orig,
latent_processed,
Expand All @@ -137,10 +137,10 @@ def apply_adaptive_masks(
# TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control.
latent_mask = nmask[0].float()
# convert the original mask into a form we use to scale distances for thresholding
mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (settings.mask_blend_scale / 2))
mask_scalar = (0.5 * (1-settings.composite_mask_influence)
mask_scalar = 1 - (torch.clamp(latent_mask, min=0, max=1) ** (settings.mask_blend_scale / 2))
mask_scalar = (0.5 * (1 - settings.composite_mask_influence)
+ mask_scalar * settings.composite_mask_influence)
mask_scalar = mask_scalar / (1.00001-mask_scalar)
mask_scalar = mask_scalar / (1.00001 - mask_scalar)
mask_scalar = mask_scalar.cpu().numpy()

latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1)
Expand All @@ -152,9 +152,9 @@ def apply_adaptive_masks(
for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, overlay_images)):
converted_mask = distance_map.float().cpu().numpy()
converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center,
percentile_min=0.9, percentile_max=1, min_width=1)
percentile_min=0.9, percentile_max=1, min_width=1)
converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center,
percentile_min=0.25, percentile_max=0.75, min_width=1)
percentile_min=0.25, percentile_max=0.75, min_width=1)

# The distance at which opacity of original decreases to 50%
half_weighted_distance = settings.composite_difference_threshold * mask_scalar
Expand Down Expand Up @@ -276,6 +276,7 @@ class WeightedElement:
An element of the histogram, its weight
and bounds.
"""

def __init__(self, value, weight):
self.value: float = value
self.weight: float = weight
Expand Down Expand Up @@ -355,13 +356,15 @@ def sort_key(x: WeightedElement):

return img_out


def smoothstep(x):
"""
The smoothstep function, input should be clamped to 0-1 range.
Turns a diagonal line (f(x) = x) into a sigmoid-like curve.
"""
return x * x * (3 - 2 * x)


def smootherstep(x):
"""
The smootherstep function, input should be clamped to 0-1 range.
Expand All @@ -385,6 +388,7 @@ def get_gaussian_kernel(stddev_radius=1.0, max_radius=2):
Returns:
(nparray, nparray): A kernel array (shape: (N, N)), its center coordinate (shape: (2))
"""

# Evaluates a 0-1 normalized gaussian function for a given square distance from the mean.
def gaussian(sqr_mag):
return math.exp(-sqr_mag / (stddev_radius * stddev_radius))
Expand Down Expand Up @@ -656,7 +660,8 @@ def process(self, p, enabled, power, scale, detail_preservation, mask_inf, dif_t
# p.extra_generation_params["Mask rounding"] = False
settings.add_generation_params(p.extra_generation_params)

def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr):
def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation, mask_inf,
dif_thresh, dif_contr):
if not enabled:
return

Expand All @@ -675,7 +680,8 @@ def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, de
mba.current_latent,
get_modified_nmask(settings, mba.nmask, mba.sigma[0]))

def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr):
def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation, mask_inf,
dif_thresh, dif_contr):
if not enabled:
return

Expand Down Expand Up @@ -723,8 +729,8 @@ def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, deta
height=p.height,
paste_to=p.paste_to)


def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr):
def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale,
detail_preservation, mask_inf, dif_thresh, dif_contr):
if not enabled:
return

Expand Down

0 comments on commit 28130bf

Please sign in to comment.