
Commit

fix transformers issue for text_model.embeddings.position_ids
wkpark committed Feb 27, 2024
1 parent dca49d4 commit 71d82b0
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions library/model_util.py
@@ -9,6 +9,7 @@
init_ipex()

import diffusers
import importlib.metadata
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel
from safetensors.torch import load_file, save_file
@@ -572,6 +573,17 @@ def convert_ldm_clip_checkpoint_v1(checkpoint):
        if key.startswith("cond_stage_model.transformer"):
            text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]

    try:
        vers = importlib.metadata.version("transformers").split(".")
    except Exception:
        vers = None

    if vers is not None and tuple(vers) <= ('4', '30', '2'):
        # support checkpoint without position_ids (invalid checkpoint)
        if "text_model.embeddings.position_ids" not in text_model_dict:
            text_model_dict["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0)  # 77 is the max length of the text
        return text_model_dict

    # remove position_ids for newer transformer, which causes error :(
    if "text_model.embeddings.position_ids" in text_model_dict:
        text_model_dict.pop("text_model.embeddings.position_ids")

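Note on the version gate above: importlib.metadata.version("transformers").split(".") yields strings, so tuple(vers) <= ('4', '30', '2') compares the components lexicographically (for example, a minor version of '9' sorts after '30'). Below is a minimal sketch of the same check done with integer components; it is not part of the commit, it assumes a plain major.minor.patch version string, and the helper name is illustrative only.

import importlib.metadata

def transformers_needs_position_ids() -> bool:
    # True when the installed transformers is <= 4.30.2 and therefore
    # still expects text_model.embeddings.position_ids in the state dict.
    try:
        parts = importlib.metadata.version("transformers").split(".")
        major, minor, patch = (int(p) for p in parts[:3])
    except Exception:
        # unknown or non-numeric version: treat it as a newer transformers
        return False
    return (major, minor, patch) <= (4, 30, 2)

With a helper like this, the fallback position_ids tensor (torch.arange(77).unsqueeze(0), shape (1, 77) for CLIP's 77-token text length) would only be injected when the check returns True.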