Add cogvideox-2b-img2vid CogVideoXModelLoader support
Fix for the transformer model's patch_embed.pos_embedding dtype.
Alternatively, add the following line at ComfyUI-CogVideoXWrapper/embeddings.py:129:
pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
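
For context, a minimal sketch of where that alternative cast would sit, assuming an embeddings.py forward pass in which a precomputed pos_embedding buffer is added to the patch embeddings (the helper name and surrounding structure below are illustrative, not the wrapper's actual code):

import torch

def add_pos_embedding(embeds: torch.Tensor, pos_embedding: torch.Tensor) -> torch.Tensor:
    # Without the cast, pos_embedding can remain float32 on the CPU while embeds
    # is already (b)float16 on the GPU, causing a dtype/device mismatch at runtime.
    pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
    return embeds + pos_embedding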
glide-the committed Dec 6, 2024
1 parent 729a648 commit d9d30f2
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions model_loading.py
@@ -727,6 +727,8 @@ def loadmodel(self, model, base_precision, load_device, enable_sequential_cpu_of
             model_type = "5b_I2V_1_5"
         elif sd["patch_embed.proj.weight"].shape == (1920, 33, 2, 2):
             model_type = "fun_2b"
+        elif sd["patch_embed.proj.weight"].shape == (1920, 32, 2, 2):
+            model_type = "cogvideox-2b-img2vid"
         elif sd["patch_embed.proj.weight"].shape == (1920, 16, 2, 2):
             model_type = "2b"
         elif sd["patch_embed.proj.weight"].shape == (3072, 32, 2, 2):
@@ -748,7 +750,7 @@ def loadmodel(self, model, base_precision, load_device, enable_sequential_cpu_of
         with open(transformer_config_path) as f:
             transformer_config = json.load(f)
 
-        if model_type in ["I2V", "I2V_5b", "fun_5b_pose", "5b_I2V_1_5"]:
+        if model_type in ["I2V", "I2V_5b", "fun_5b_pose", "5b_I2V_1_5", "cogvideox-2b-img2vid"]:
             transformer_config["in_channels"] = 32
             if "1_5" in model_type:
                 transformer_config["ofs_embed_dim"] = 512
@@ -774,6 +776,10 @@ def loadmodel(self, model, base_precision, load_device, enable_sequential_cpu_of
             #dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype
             set_module_tensor_to_device(transformer, name, device=transformer_load_device, dtype=base_dtype, value=sd[name])
         del sd
+        # TODO: fix for the transformer model's patch_embed.pos_embedding dtype;
+        # alternatively, add the following line at ComfyUI-CogVideoXWrapper/embeddings.py:129:
+        # pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
+        transformer = transformer.to(base_dtype).to(transformer_load_device)
 
         #scheduler
         with open(scheduler_config_path) as f:
@@ -797,7 +803,8 @@ def loadmodel(self, model, base_precision, load_device, enable_sequential_cpu_of
             dtype=base_dtype,
             is_fun_inpaint="fun" in model.lower() and not ("pose" in model.lower() or "control" in model.lower())
         )
-
+        if "cogvideox-2b-img2vid" == model_type:
+            pipe.input_with_padding = False
         if enable_sequential_cpu_offload:
             pipe.enable_sequential_cpu_offload()
 
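As a quick sanity check of the shape-based detection in the first hunk, the sketch below inspects a checkpoint's patch_embed.proj.weight key directly (the filename is hypothetical; the key name and the (1920, 32, 2, 2) shape are taken from the diff):

from safetensors.torch import load_file

sd = load_file("cogvideox-2b-img2vid.safetensors")  # hypothetical checkpoint path
shape = tuple(sd["patch_embed.proj.weight"].shape)
print(shape)  # (1920, 32, 2, 2) selects model_type = "cogvideox-2b-img2vid"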
