
Commit

Fixed requirements
AlexKoff88 committed Oct 11, 2024
1 parent aa0237a commit 1567b23
Showing 2 changed files with 3 additions and 7 deletions.
2 changes: 1 addition & 1 deletion modules/token_merging/setup.py
@@ -13,7 +13,7 @@
     author="Alexander Kozlov",
     url="https://github.com/openvinotoolkit/openvino_contrib/tree/master/modules/token_merging",
     description="Token Merging for OpenVINO",
-    install_requires=["torch~=1.13.1", "torchvision~=0.14.1"],
+    install_requires=["torch~=2.4", "torchvision~=0.19.1"],
     dependency_links=["https://download.pytorch.org/whl/cpu"],
     extras_require=EXTRAS_REQUIRE,
     packages=find_packages(exclude=("examples", "build")),
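Note (not part of the commit): the updated pins are PEP 440 compatible-release specifiers, so torch~=2.4 accepts any 2.x release from 2.4 onward and torchvision~=0.19.1 accepts 0.19.x patch releases, while the previously pinned torch 1.13.1 no longer satisfies the requirement. A minimal sketch with the packaging library illustrates how the new specifiers resolve:

    # Illustration only, not part of the commit: how the new pins resolve under PEP 440.
    from packaging.specifiers import SpecifierSet

    torch_pin = SpecifierSet("~=2.4")           # equivalent to >=2.4, ==2.*
    torchvision_pin = SpecifierSet("~=0.19.1")  # equivalent to >=0.19.1, ==0.19.*

    print("2.4.1" in torch_pin)         # True: newer 2.x releases still satisfy the pin
    print("1.13.1" in torch_pin)        # False: the old pinned version is rejected
    print("0.19.2" in torchvision_pin)  # True
    print("0.20.0" in torchvision_pin)  # False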
8 changes: 2 additions & 6 deletions modules/token_merging/tomeov/stable_diffusion.py
@@ -93,7 +93,6 @@ def forward(
             class_labels=None,
         ) -> torch.Tensor:
             # (1) ToMe
-            #print(self._tome_info)
             m_a, m_c, m_m, u_a, u_c, u_m = compute_merge(hidden_states, self._tome_info)
 
             if self.use_ada_layer_norm:
@@ -237,7 +236,6 @@ def patch_stable_diffusion(
     for _, module in diffusion_model.named_modules():
         # If for some reason this has a different name, create an issue and I'll fix it
         if isinstance_str(module, "BasicTransformerBlock"):
-            print("Patch module")
             make_tome_block_fn = make_diffusers_tome_block if is_diffusers else make_tome_block
             module.__class__ = make_tome_block_fn(module.__class__)
             module._tome_info = diffusion_model._tome_info
@@ -246,7 +244,7 @@ def patch_stable_diffusion(
             if not hasattr(module, "disable_self_attn") and not is_diffusers:
                 module.disable_self_attn = False
 
-    if optimize_image_encoder and hasattr(model, "image_encoder") and hasattr(model, "vae_encoder"):
+    if optimize_image_encoder and hasattr(model, "vae_encoder"):
         image_encoder = model.vae_encoder
 
         image_encoder._tome_info = {
@@ -265,11 +263,9 @@
         }
         hook_tome_model(image_encoder)
 
-        for name, module in image_encoder.named_modules():
-            print(name, module.__class__)
+        for _, module in image_encoder.named_modules():
             # If for some reason this has a different name, create an issue and I'll fix it
             if isinstance_str(module, "BasicTransformerBlock"):
-                print("Patch module")
                 make_tome_block_fn = make_diffusers_tome_block if is_diffusers else make_tome_block
                 module.__class__ = make_tome_block_fn(module.__class__)
                 module._tome_info = image_encoder._tome_info
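Note (not part of the commit): the loops above swap the class of every module whose type name matches "BasicTransformerBlock" for a ToMe-patched variant and attach the shared _tome_info dict, with the image-encoder path now guarded only by hasattr(model, "vae_encoder"). A hypothetical usage sketch follows; only patch_stable_diffusion and the tomeov/stable_diffusion.py module appear in this commit, so the package-level import and the ratio keyword are assumptions to verify against the module's README:

    # Hypothetical sketch: the entry point name comes from the hunk headers above,
    # but the import style and keyword arguments are assumptions, not shown in this diff.
    from diffusers import StableDiffusionPipeline
    from tomeov import stable_diffusion as tome

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

    # Patch BasicTransformerBlock modules so redundant tokens are merged during inference;
    # the merge ratio argument is assumed.
    tome.patch_stable_diffusion(pipe, ratio=0.3)

    image = pipe("a photo of an astronaut riding a horse").images[0]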
