From 131a953e43bcccc60ce5bc180cb76bd17ade2d76 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 17 Nov 2023 21:44:45 -0500 Subject: [PATCH] Cleanup debug code --- daam/heatmap.py | 30 --------------------- daam/trace.py | 71 ++----------------------------------------------- daam/utils.py | 23 +++++++++++----- 3 files changed, 18 insertions(+), 106 deletions(-) diff --git a/daam/heatmap.py b/daam/heatmap.py index 4949130..b4c7577 100644 --- a/daam/heatmap.py +++ b/daam/heatmap.py @@ -110,17 +110,10 @@ def plot_overlay_heat_map( plt_ = ax with auto_autocast(dtype=torch.float32): - print("im size", im.size) - im.save("testing-img-size.png") - tensor2img(heat_map).save("heat_map.png") im = np.array(im) - print("nmpy im", im.shape) - print("heatmap to plot", heat_map.size()) heat_map = heat_map.permute(1, 0) # swap width/height to match numpy # shape height, width - tensor2img(heat_map).save("heat_map-after.png") - tensor2img(torch.from_numpy(im).float() / 255).save("img-after-numpy.png") if crop is not None: heat_map = heat_map[crop:-crop, crop:-crop] @@ -133,32 +126,10 @@ def plot_overlay_heat_map( plt_.imshow(heat_map.cpu().numpy(), cmap="jet", vmin=0.0, vmax=1.0) im = torch.from_numpy(im).float() / 255 - print( - f"im {im.size()} heat_map {heat_map.size()} {heat_map.unsqueeze(-1).size()}" - ) - - # print(im.permute(1, 0, 2).size(), heat_map.permute(1, 0, 2).size()) im = torch.cat((im, (1 - heat_map.unsqueeze(-1))), dim=-1) - print("final image size", im.size(), im.numpy().shape) - - # tensor2img(im).save("catted-img.png") plt_.imshow(im) - # fig = plt_.figure() - # size = fig.get_size_inches()*fig.dpi - # - # print("figure-size", size) - - # disable_fig_axis(plt_=plt_) - - img = fig2img(fig=plt_.gcf()) - - # print("fig2img result", img.size) - # - print("fig to pil img", img) - img.save("x-o.png") - if word is not None: if ax is None: plt.title(word) @@ -183,7 +154,6 @@ def plot_overlay( self, image, out_file=None, color_normalize=True, ax=None, **expand_kwargs ): # type: (PIL.Image.Image | np.ndarray, Path, bool, plt.Axes, Dict[str, Any]) -> None - print("generated pil image to overlay", image.size) plot_overlay_heat_map( image, self.expand_as(image, **expand_kwargs), diff --git a/daam/trace.py b/daam/trace.py index b82de1e..c8b0baa 100644 --- a/daam/trace.py +++ b/daam/trace.py @@ -17,7 +17,7 @@ ObjectHooker, UNetCrossAttentionLocator, ) -from .utils import auto_autocast, cache_dir, tensor2img +from .utils import auto_autocast, cache_dir, tensor2img, get_max_tensor_width_height __all__ = ["trace", "DiffusionHeatMapHooker", "GlobalHeatMap"] @@ -163,75 +163,20 @@ def compute_global_heat_map( if (head_idx is None or head_idx == head) and ( layer_idx is None or layer_idx == layer ): - # h = self.img_height // self.sample_size - # w = self.img_width // self.sample_size - # w = self.img_width // 8 - # h = self.img_height // 8 - # # h = self.img_height - # # w = self.img_width - # # shape 77, 1, 48, 80 - # print("compute_global_heat_map") - # print("heatmap size", heat_map.size()) - # print("permuting the image", heat_map.size()) - # heat_map = heat_map.unsqueeze(1).permute(0, 1, 3, 2).clone() - # # heat_map = heat_map.unsqueeze(1).permute(0, 1, 2, 3).clone() - # print("permuting post", heat_map.size()) - # - # # The clamping fixes undershoot. - # heat_map = F.interpolate( - # heat_map.unsqueeze(0).permute(1, 0, 3, 2), size=(w, h), mode="bicubic" - # ).clamp_(min=0) - # # - # print("post interpolation", heat_map.size()) - - # for i, map in enumerate(heat_map): - # plt.imshow(map.squeeze().cpu().numpy(), cmap="jet") - # plt.title(f"layer {layer:02d} head {head:02d} blk {i:02d}") - # plt.savefig(f"./tmp/heamp-{layer:02d}-{head:02d}-{i:02d}.png") - # plt.clf() - all_merges.append(heat_map) - def get_max_tensor_width_height(tensors): - maxes = {(sum(m.size()), m.size(1), m.size(2)) for m in tensors} - - max_max = -1 - max_w = -1 - max_h = -1 - for max, w, h in maxes: - if max > max_max: - max_max = max - - max_w = w - max_h = h - - assert max_w is not -1 - assert max_h is not -1 - - return max_w, max_h - w, h = get_max_tensor_width_height(all_merges) # we want to interpolate the dimensions so they are all the same size for i, merge in enumerate(all_merges): # The clamping fixes undershoot. - print('merge', merge.unsqueeze(0).permute(1, 0, 2, 3).size()) heat_map = F.interpolate( merge.unsqueeze(0).permute(1, 0, 3, 2), size=(w, h), mode="bicubic", ).clamp_(min=0) all_merges[i] = heat_map - # if layer == 1 and head == 1: - # [ - # tensor2img(hm.squeeze()).save( - # f"./heatmaps/{i}-{ihm}-{w}-{h}-merged.png" - # ) - # for ihm, hm in enumerate(heat_map) - # ] - - # x = set() - print({m.size() for m in all_merges}) + try: maps = torch.stack(all_merges, dim=0) except RuntimeError: @@ -265,9 +210,6 @@ def _hooked_decode( ): output = hk_self.monkey_super("decode", z, *args, **kwargs) - print("vae decoded", [img.size() for img in output]) - - # print("permuted", [img.squeeze().permute(2, 1, 0).size() for img in output]) images = [ to_pil_image(img.permute(1, 2, 0).cpu(), do_rescale=True) if len(img.size()) == 2 @@ -275,9 +217,6 @@ def _hooked_decode( for img in output ] - print("post vae to pil image", [img.size for img in images]) - [img.save(f"post-vae-img-{i:02d}.png") for i, img in enumerate(images)] - hk_self.parent_trace.last_image = images[len(images) - 1] return output @@ -375,14 +314,8 @@ def _unravel_attn(self, x): with auto_autocast(dtype=torch.float32): for i, map_ in enumerate(x): - print("pre view size", map_.size()) map_ = map_.view(map_.size(0), w, h) map_ = map_[map_.size(0) // 2 :] # Filter out unconditional - print(f"view w, h {map_.size()}") - - # print(map_.unsqueeze(1).size()) - # to_pil_image(map_.unsqueeze(1).cpu(), do_rescale=True).save("heatmap_{i}.png") - maps.append(map_) maps = torch.stack(maps, 0) # shape: (tokens, heads, height, width) diff --git a/daam/utils.py b/daam/utils.py index 4b3d8ff..7c90467 100644 --- a/daam/utils.py +++ b/daam/utils.py @@ -3,12 +3,11 @@ import sys from functools import lru_cache from pathlib import Path -from typing import TypeVar +from typing import TypeVar, Tuple import matplotlib.pyplot as plt import numpy as np import PIL.Image -from PIL import Image import spacy import torch import torch.nn.functional as F @@ -141,10 +140,7 @@ def expand_image( h = image.size[1] # shape 77, 1, 48, 80 - # print(heatmap.shape) heatmap = heatmap.unsqueeze(0).unsqueeze(0) - # heatmap = heatmap - # print(heatmap.shape) # The clamping fixes undershoot. im = F.interpolate(heatmap, size=(w, h), mode="bicubic").clamp_(min=0) @@ -162,8 +158,6 @@ def expand_image( im = im.cpu().detach().squeeze() - print(f"expanded as {im.size()}") - return im @@ -182,3 +176,18 @@ def tensor2img(tensor): print(f"invalid size to tensor2img {tensor.size()}") return to_pil_image(tensor, do_rescale=True) + + +# Measure the dimensions of the tensor and get the largest one by dimension size +# tensor.Size([77, 29, 40]) tensor.Size([77, 22, 38]) tensor.Size([77, 29, 18]) +def get_max_tensor_width_height(tensors) -> Tuple[int, int]: + maxes = {(sum(m.size()), m.size(1), m.size(2)) for m in tensors} + + max_sum = -1 + max = (-1, -1) + for m, w, h in maxes: + if m > max_sum: + max = (w, h) + max_sum = m + + return max