[checkpoint] adapted inference to hf pretrained weights (#143)
FrankLeeeee authored Jun 17, 2024
1 parent a86a7c6 commit adba114
Showing 3 changed files with 20 additions and 18 deletions.
configs/opensora-v1-2/inference/sample.py (1 addition, 1 deletion)
@@ -26,7 +26,7 @@
 
 model = dict(
     type="STDiT3-XL/2",
-    from_pretrained="/mnt/jfs/sora_checkpoints/042-STDiT3-XL-2/epoch0-global_step7200/ema.pt",
+    from_pretrained="hpcai-tech/OpenSora-STDiT-v3",
     qk_norm=True,
     enable_flash_attn=True,
     enable_layernorm_kernel=True,
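
For orientation (not part of this diff): the model dict's type field selects the STDiT3-XL/2 builder registered in opensora/models/stdit/stdit3.py, which this same commit updates below, and the remaining keys are forwarded to it as keyword arguments. A hedged sketch of the equivalent direct call, assuming the opensora package from this commit is importable:

from opensora.models.stdit.stdit3 import STDiT3_XL_2

# The Hub repo id replaces the old local ema.pt checkpoint path; the other
# keyword arguments mirror the config entry above.
model = STDiT3_XL_2(
    from_pretrained="hpcai-tech/OpenSora-STDiT-v3",
    qk_norm=True,
    enable_flash_attn=True,
    enable_layernorm_kernel=True,
)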
gradio/app.py (3 additions, 9 deletions)
@@ -25,10 +25,7 @@
     "v1.2-stage3": "configs/opensora-v1-2/inference/sample.py",
 }
 HF_STDIT_MAP = {
-    "v1.2-stage3": {
-        "ema": "/mnt/jfs-hdd/sora/checkpoints/outputs/042-STDiT3-XL-2/epoch1-global_step16200/ema.pt",
-        "model": "/mnt/jfs-hdd/sora/checkpoints/outputs/042-STDiT3-XL-2/epoch1-global_step16200/model"
-    }
+    "v1.2-stage3": "hpcai-tech/OpenSora-STDiT-v3"
 }
 
 # ============================
@@ -104,11 +101,8 @@ def build_models(model_type, config, enable_optimization=False):
     # build stdit
     # we load model from HuggingFace directly so that we don't need to
     # handle model download logic in HuggingFace Space
-    from opensora.models.stdit.stdit3 import STDiT3, STDiT3Config
-    stdit3_config = STDiT3Config.from_pretrained(HF_STDIT_MAP[model_type]['model'])
-    stdit = STDiT3(stdit3_config)
-    ckpt = torch.load(HF_STDIT_MAP[model_type]['ema'])
-    stdit.load_state_dict(ckpt)
+    from opensora.models.stdit.stdit3 import STDiT3
+    stdit = STDiT3.from_pretrained(HF_STDIT_MAP[model_type])
     stdit = stdit.cuda()
 
     # build scheduler
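
With the weights published as a Hugging Face model repo, the Gradio app no longer needs the manual config/state-dict plumbing. A minimal sketch of the new loading path, assuming the hpcai-tech/OpenSora-STDiT-v3 repo is reachable on the Hub and a CUDA device is available:

from opensora.models.stdit.stdit3 import STDiT3

# One call replaces the previous four steps (STDiT3Config.from_pretrained,
# STDiT3(config), torch.load, load_state_dict).
stdit = STDiT3.from_pretrained("hpcai-tech/OpenSora-STDiT-v3")
stdit = stdit.cuda()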
opensora/models/stdit/stdit3.py (16 additions, 8 deletions)
@@ -1,3 +1,4 @@
+import os
 import numpy as np
 import torch
 import torch.nn as nn
@@ -444,17 +445,24 @@ def unpatchify(self, x, N_t, N_h, N_w, R_t, R_h, R_w):
 
 @MODELS.register_module("STDiT3-XL/2")
 def STDiT3_XL_2(from_pretrained=None, **kwargs):
-    config = STDiT3Config(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs)
-    model = STDiT3(config)
-    if from_pretrained is not None:
-        load_checkpoint(model, from_pretrained)
+    if from_pretrained is not None and not os.path.isdir(from_pretrained):
+        model = STDiT3.from_pretrained(from_pretrained, **kwargs)
+    else:
+        config = STDiT3Config(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs)
+        model = STDiT3(config)
+        if from_pretrained is not None:
+            load_checkpoint(model, from_pretrained)
     return model
 
 
 @MODELS.register_module("STDiT3-3B/2")
 def STDiT3_3B_2(from_pretrained=None, **kwargs):
-    config = STDiT3Config(depth=28, hidden_size=1872, patch_size=(1, 2, 2), num_heads=26, **kwargs)
-    model = STDiT3(config)
-    if from_pretrained is not None:
-        load_checkpoint(model, from_pretrained)
+    # check if from_pretrained is a path
+    if from_pretrained is not None and not os.path.isdir(from_pretrained):
+        model = STDiT3.from_pretrained(from_pretrained, **kwargs)
+    else:
+        config = STDiT3Config(depth=28, hidden_size=1872, patch_size=(1, 2, 2), num_heads=26, **kwargs)
+        model = STDiT3(config)
+        if from_pretrained is not None:
+            load_checkpoint(model, from_pretrained)
     return model
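
Both builders now branch on os.path.isdir: any value that is not an existing local directory (for example a Hub repo id) is routed to STDiT3.from_pretrained, while a local directory, or None, keeps the original STDiT3Config plus load_checkpoint path. A short sketch of the two branches; the local directory below is a hypothetical placeholder:

from opensora.models.stdit.stdit3 import STDiT3_XL_2

# Hub repo id: not a local directory, so the new STDiT3.from_pretrained branch runs.
model = STDiT3_XL_2(from_pretrained="hpcai-tech/OpenSora-STDiT-v3", qk_norm=True)

# Existing local directory (placeholder path): takes the else branch, building the
# config manually and loading weights with load_checkpoint as before.
model = STDiT3_XL_2(from_pretrained="./checkpoints/STDiT3-XL-2", qk_norm=True)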
