Skip to content

Commit

Permalink
Cleanup debug code
Browse files Browse the repository at this point in the history
  • Loading branch information
rockerBOO committed Nov 18, 2023
1 parent 2fdacf7 commit 131a953
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 106 deletions.
30 changes: 0 additions & 30 deletions daam/heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,17 +110,10 @@ def plot_overlay_heat_map(
plt_ = ax

with auto_autocast(dtype=torch.float32):
print("im size", im.size)
im.save("testing-img-size.png")
tensor2img(heat_map).save("heat_map.png")
im = np.array(im)
print("nmpy im", im.shape)

print("heatmap to plot", heat_map.size())
heat_map = heat_map.permute(1, 0) # swap width/height to match numpy
# shape height, width
tensor2img(heat_map).save("heat_map-after.png")
tensor2img(torch.from_numpy(im).float() / 255).save("img-after-numpy.png")

if crop is not None:
heat_map = heat_map[crop:-crop, crop:-crop]
Expand All @@ -133,32 +126,10 @@ def plot_overlay_heat_map(
plt_.imshow(heat_map.cpu().numpy(), cmap="jet", vmin=0.0, vmax=1.0)

im = torch.from_numpy(im).float() / 255
print(
f"im {im.size()} heat_map {heat_map.size()} {heat_map.unsqueeze(-1).size()}"
)

# print(im.permute(1, 0, 2).size(), heat_map.permute(1, 0, 2).size())
im = torch.cat((im, (1 - heat_map.unsqueeze(-1))), dim=-1)
print("final image size", im.size(), im.numpy().shape)

# tensor2img(im).save("catted-img.png")

plt_.imshow(im)

# fig = plt_.figure()
# size = fig.get_size_inches()*fig.dpi
#
# print("figure-size", size)

# disable_fig_axis(plt_=plt_)

img = fig2img(fig=plt_.gcf())

# print("fig2img result", img.size)
#
print("fig to pil img", img)
img.save("x-o.png")

if word is not None:
if ax is None:
plt.title(word)
Expand All @@ -183,7 +154,6 @@ def plot_overlay(
self, image, out_file=None, color_normalize=True, ax=None, **expand_kwargs
):
# type: (PIL.Image.Image | np.ndarray, Path, bool, plt.Axes, Dict[str, Any]) -> None
print("generated pil image to overlay", image.size)
plot_overlay_heat_map(
image,
self.expand_as(image, **expand_kwargs),
Expand Down
71 changes: 2 additions & 69 deletions daam/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
ObjectHooker,
UNetCrossAttentionLocator,
)
from .utils import auto_autocast, cache_dir, tensor2img
from .utils import auto_autocast, cache_dir, tensor2img, get_max_tensor_width_height

__all__ = ["trace", "DiffusionHeatMapHooker", "GlobalHeatMap"]

Expand Down Expand Up @@ -163,75 +163,20 @@ def compute_global_heat_map(
if (head_idx is None or head_idx == head) and (
layer_idx is None or layer_idx == layer
):
# h = self.img_height // self.sample_size
# w = self.img_width // self.sample_size
# w = self.img_width // 8
# h = self.img_height // 8
# # h = self.img_height
# # w = self.img_width
# # shape 77, 1, 48, 80
# print("compute_global_heat_map")
# print("heatmap size", heat_map.size())
# print("permuting the image", heat_map.size())
# heat_map = heat_map.unsqueeze(1).permute(0, 1, 3, 2).clone()
# # heat_map = heat_map.unsqueeze(1).permute(0, 1, 2, 3).clone()
# print("permuting post", heat_map.size())
#
# # The clamping fixes undershoot.
# heat_map = F.interpolate(
# heat_map.unsqueeze(0).permute(1, 0, 3, 2), size=(w, h), mode="bicubic"
# ).clamp_(min=0)
# #
# print("post interpolation", heat_map.size())

# for i, map in enumerate(heat_map):
# plt.imshow(map.squeeze().cpu().numpy(), cmap="jet")
# plt.title(f"layer {layer:02d} head {head:02d} blk {i:02d}")
# plt.savefig(f"./tmp/heamp-{layer:02d}-{head:02d}-{i:02d}.png")
# plt.clf()

all_merges.append(heat_map)

def get_max_tensor_width_height(tensors):
maxes = {(sum(m.size()), m.size(1), m.size(2)) for m in tensors}

max_max = -1
max_w = -1
max_h = -1
for max, w, h in maxes:
if max > max_max:
max_max = max

max_w = w
max_h = h

assert max_w is not -1
assert max_h is not -1

return max_w, max_h

w, h = get_max_tensor_width_height(all_merges)

# we want to interpolate the dimensions so they are all the same size
for i, merge in enumerate(all_merges):
# The clamping fixes undershoot.
print('merge', merge.unsqueeze(0).permute(1, 0, 2, 3).size())
heat_map = F.interpolate(
merge.unsqueeze(0).permute(1, 0, 3, 2),
size=(w, h),
mode="bicubic",
).clamp_(min=0)
all_merges[i] = heat_map
# if layer == 1 and head == 1:
# [
# tensor2img(hm.squeeze()).save(
# f"./heatmaps/{i}-{ihm}-{w}-{h}-merged.png"
# )
# for ihm, hm in enumerate(heat_map)
# ]

# x = set()
print({m.size() for m in all_merges})

try:
maps = torch.stack(all_merges, dim=0)
except RuntimeError:
Expand Down Expand Up @@ -265,19 +210,13 @@ def _hooked_decode(
):
output = hk_self.monkey_super("decode", z, *args, **kwargs)

print("vae decoded", [img.size() for img in output])

# print("permuted", [img.squeeze().permute(2, 1, 0).size() for img in output])
images = [
to_pil_image(img.permute(1, 2, 0).cpu(), do_rescale=True)
if len(img.size()) == 2
else to_pil_image(img.squeeze().cpu(), do_rescale=True)
for img in output
]

print("post vae to pil image", [img.size for img in images])
[img.save(f"post-vae-img-{i:02d}.png") for i, img in enumerate(images)]

hk_self.parent_trace.last_image = images[len(images) - 1]
return output

Expand Down Expand Up @@ -375,14 +314,8 @@ def _unravel_attn(self, x):

with auto_autocast(dtype=torch.float32):
for i, map_ in enumerate(x):
print("pre view size", map_.size())
map_ = map_.view(map_.size(0), w, h)
map_ = map_[map_.size(0) // 2 :] # Filter out unconditional
print(f"view w, h {map_.size()}")

# print(map_.unsqueeze(1).size())
# to_pil_image(map_.unsqueeze(1).cpu(), do_rescale=True).save("heatmap_{i}.png")

maps.append(map_)

maps = torch.stack(maps, 0) # shape: (tokens, heads, height, width)
Expand Down
23 changes: 16 additions & 7 deletions daam/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
import sys
from functools import lru_cache
from pathlib import Path
from typing import TypeVar
from typing import TypeVar, Tuple

import matplotlib.pyplot as plt
import numpy as np
import PIL.Image
from PIL import Image
import spacy
import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -141,10 +140,7 @@ def expand_image(
h = image.size[1]

# shape 77, 1, 48, 80
# print(heatmap.shape)
heatmap = heatmap.unsqueeze(0).unsqueeze(0)
# heatmap = heatmap
# print(heatmap.shape)

# The clamping fixes undershoot.
im = F.interpolate(heatmap, size=(w, h), mode="bicubic").clamp_(min=0)
Expand All @@ -162,8 +158,6 @@ def expand_image(

im = im.cpu().detach().squeeze()

print(f"expanded as {im.size()}")

return im


Expand All @@ -182,3 +176,18 @@ def tensor2img(tensor):
print(f"invalid size to tensor2img {tensor.size()}")

return to_pil_image(tensor, do_rescale=True)


# Measure the dimensions of the tensor and get the largest one by dimension size
# tensor.Size([77, 29, 40]) tensor.Size([77, 22, 38]) tensor.Size([77, 29, 18])
def get_max_tensor_width_height(tensors) -> Tuple[int, int]:
maxes = {(sum(m.size()), m.size(1), m.size(2)) for m in tensors}

max_sum = -1
max = (-1, -1)
for m, w, h in maxes:
if m > max_sum:
max = (w, h)
max_sum = m

return max

0 comments on commit 131a953

Please sign in to comment.