Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cosmos #10660

Draft
wants to merge 19 commits into
base: main
Choose a base branch
from
Draft

Cosmos #10660

wants to merge 19 commits into from

Conversation

a-r-r-o-w
Copy link
Member

@a-r-r-o-w a-r-r-o-w commented Jan 27, 2025

The cosmos is within us. We are made of star-stuff. We are a way for the universe to know itself.

WIP.

test attention
from typing import Optional
from einops import rearrange

import torch
import torch.nn as nn


class RMSNorm(torch.nn.Module):
    def __init__(
        self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None
    ):
        super().__init__()
        self.eps = eps
        self.learnable_scale = elementwise_affine
        if self.learnable_scale:
            self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
        else:
            self.register_parameter("weight", None)

    def forward(self, x):
        r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
        if self.weight is None:
            return r
        else:
            return r * self.weight.to(dtype=x.dtype, device=x.device)


def get_normalization(name: str, channels: int):
    if name == "I":
        return nn.Identity()
    elif name == "R":
    #     return te.pytorch.RMSNorm(channels, eps=1e-6)
        return RMSNorm(channels, eps=1e-6)
    else:
        raise ValueError(f"Normalization {name} not found")


class Attention(nn.Module):
    def __init__(
        self,
        query_dim: int,
        context_dim=None,
        heads=8,
        dim_head=64,
        dropout=0.0,
        qkv_bias: bool = False,
        out_bias: bool = False,
        qkv_norm: str = "SSI",
        qkv_norm_mode: str = "per_head",
        backend: str = "transformer_engine",
        qkv_format: str = "bshd",
    ) -> None:
        super().__init__()

        self.is_selfattn = context_dim is None  # self attention

        inner_dim = dim_head * heads
        context_dim = query_dim if context_dim is None else context_dim

        self.heads = heads
        self.dim_head = dim_head
        self.qkv_norm_mode = qkv_norm_mode
        self.qkv_format = qkv_format

        if self.qkv_norm_mode == "per_head":
            norm_dim = dim_head
        else:
            raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'")

        self.backend = backend

        self.to_q = nn.Sequential(
            nn.Linear(query_dim, inner_dim, bias=qkv_bias),
            get_normalization(qkv_norm[0], norm_dim),
        )
        self.to_k = nn.Sequential(
            nn.Linear(context_dim, inner_dim, bias=qkv_bias),
            get_normalization(qkv_norm[1], norm_dim),
        )
        self.to_v = nn.Sequential(
            nn.Linear(context_dim, inner_dim, bias=qkv_bias),
            get_normalization(qkv_norm[2], norm_dim),
        )

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim, bias=out_bias),
            nn.Dropout(dropout),
        )

    def cal_qkv(
        self, x, context=None, mask=None, rope_emb=None, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        q = self.to_q[0](x)
        context = x if context is None else context
        k = self.to_k[0](context)
        v = self.to_v[0](context)
        q, k, v = map(
            # lambda t: rearrange(t, "b ... (n c) -> b ... n c", n=self.heads, c=self.dim_head),
            lambda t: rearrange(t, "s b (n c) -> b n s c", n=self.heads, c=self.dim_head),
            (q, k, v),
        )

        q = self.to_q[1](q)
        k = self.to_k[1](k)
        v = self.to_v[1](v)
        if self.is_selfattn and rope_emb is not None:  # only apply to self-attention!
            print("here")
            # q = apply_rotary_pos_emb(q, rope_emb, tensor_format=self.qkv_format, fused=True)
            # k = apply_rotary_pos_emb(k, rope_emb, tensor_format=self.qkv_format, fused=True)
            # apply_rotary_pos_emb inlined
            q_shape = q.shape
            q = q.reshape(*q.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
            q = torch.cat([rope_emb[..., 0] * q[..., 0], rope_emb[..., 1] * q[..., 1]], dim=-1)
            # q = rope_emb[..., 0] * q[..., 0] + rope_emb[..., 1] * q[..., 1]
            q = q.movedim(-1, -2).reshape(*q_shape).to(x.dtype)

            # apply_rotary_pos_emb inlined
            k_shape = k.shape
            k = k.reshape(*k.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
            k = torch.cat([rope_emb[..., 0] * k[..., 0], rope_emb[..., 1] * k[..., 1]], dim=-1)
            # k = rope_emb[..., 0] * k[..., 0] + rope_emb[..., 1] * k[..., 1]
            k = k.movedim(-1, -2).reshape(*k_shape).to(x.dtype)
        return q, k, v

    def cal_attn(self, q, k, v, mask=None):
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        out = rearrange(out, "b n s c -> s b (n c)")
        out = self.to_out(out)
        return out

    def forward(
        self,
        x,
        context=None,
        mask=None,
        rope_emb=None,
        **kwargs,
    ):
        """
        Args:
            x (Tensor): The query tensor of shape [B, Mq, K]
            context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
        """
        q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
        return self.cal_attn(q, k, v, mask)


@torch.no_grad()
def match_rms_norm():
    from diffusers.models.normalization import RMSNorm as DiffusersRMSNorm

    theirs_rmsnorm = RMSNorm(128, elementwise_affine=True, eps=1e-6)
    ours_rmsnorm = DiffusersRMSNorm(128, eps=1e-6, elementwise_affine=True)
    ours_rmsnorm.weight.data.copy_(theirs_rmsnorm.weight.data)

    input = torch.randn(1, 128)
    theirs_output = theirs_rmsnorm(input)
    ours_output = ours_rmsnorm(input)

    print(sum(p.numel() for p in theirs_rmsnorm.parameters()))
    print(sum(p.numel() for p in ours_rmsnorm.parameters()))
    print(torch.allclose(theirs_output, ours_output))


@torch.no_grad()
def match_attention():
    from diffusers.models.attention import Attention as DiffusersAttention

    theirs_attention = Attention(128, 128, heads=8, dim_head=16, qkv_bias=False, out_bias=False, qkv_norm="RRI")
    ours_attention = DiffusersAttention(128, 128, heads=8, dim_head=16, qk_norm="rms_norm", out_bias=False, elementwise_affine=False)
    ours_attention.to_q.weight.data.copy_(theirs_attention.to_q[0].weight.data)
    ours_attention.to_k.weight.data.copy_(theirs_attention.to_k[0].weight.data)
    ours_attention.to_v.weight.data.copy_(theirs_attention.to_v[0].weight.data)
    ours_attention.to_out[0].weight.data.copy_(theirs_attention.to_out[0].weight.data)

    input = torch.randn(1, 42, 128)
    theirs_output = rearrange(theirs_attention(rearrange(input, "b s c -> s b c")), "s b c -> b s c")
    ours_output = ours_attention(input)

    print(sum(p.numel() for p in theirs_attention.parameters()))
    print(sum(p.numel() for p in ours_attention.parameters()))
    print(torch.allclose(theirs_output, ours_output, atol=1e-3))


match_rms_norm()
match_attention()
test ff
from typing import Optional

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class FeedForward(nn.Module):
    def __init__(
        self,
        d_model: int,
        d_ff: int,
        dropout: float = 0.1,
        activation=nn.ReLU(),
        is_gated: bool = False,
        bias: bool = False,
    ) -> None:
        super().__init__()

        self.layer1 = nn.Linear(d_model, d_ff, bias=bias)
        self.layer2 = nn.Linear(d_ff, d_model, bias=bias)

        self.dropout = nn.Dropout(dropout)
        self.activation = activation
        self.is_gated = is_gated
        if is_gated:
            self.linear_gate = nn.Linear(d_model, d_ff, bias=False)

    def forward(self, x: torch.Tensor):
        g = self.activation(self.layer1(x))
        if self.is_gated:
            x = g * self.linear_gate(x)
        else:
            x = g
        assert self.dropout.p == 0.0, "we skip dropout"
        return self.layer2(x)


class GPT2FeedForward(FeedForward):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1, bias: bool = False):
        super().__init__(
            d_model=d_model,
            d_ff=d_ff,
            dropout=dropout,
            activation=nn.GELU(),
            is_gated=False,
            bias=bias,
        )

    def forward(self, x: torch.Tensor):
        assert self.dropout.p == 0.0, "we skip dropout"

        x = self.layer1(x)

        def activation_layer2_forward(x):
            x = self.activation(x)
            x = self.layer2(x)
            return x

        x = checkpoint(activation_layer2_forward, x, use_reentrant=False)
        return x


@torch.no_grad()
def match_ff():
    from diffusers.models.attention import FeedForward as DiffusersFeedForward

    theirs_ff = FeedForward(128, 512, 0.0, activation=nn.GELU(), is_gated=True, bias=False)
    ours_ff = DiffusersFeedForward(128, mult=4, dropout=0.0, activation_fn="geglu", bias=False)
    ours_ff.net[0].proj.weight.data[:512, :].copy_(theirs_ff.linear_gate.weight.data)
    ours_ff.net[0].proj.weight.data[512:, :].copy_(theirs_ff.layer1.weight.data)
    ours_ff.net[2].weight.data.copy_(theirs_ff.layer2.weight.data)

    input = torch.randn(1, 128)
    theirs_output = theirs_ff(input)
    ours_output = ours_ff(input)

    print(sum(p.numel() for p in theirs_ff.parameters()))
    print(sum(p.numel() for p in ours_ff.parameters()))
    print(torch.allclose(theirs_output, ours_output))


match_ff()
test timesteps
import itertools
import math

import torch
import torch.nn as nn

class Timesteps(nn.Module):
    def __init__(self, num_channels):
        super().__init__()
        self.num_channels = num_channels

    def forward(self, timesteps):
        in_dype = timesteps.dtype
        half_dim = self.num_channels // 2
        exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
        exponent = exponent / (half_dim - 0.0)

        emb = torch.exp(exponent)
        emb = timesteps[:, None].float() * emb[None, :]

        sin_emb = torch.sin(emb)
        cos_emb = torch.cos(emb)
        emb = torch.cat([cos_emb, sin_emb], dim=-1)

        # return emb.to(in_dype)
        return emb


class TimestepEmbedding(nn.Module):
    def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False):
        super().__init__()
        self.linear_1 = nn.Linear(in_features, out_features, bias=not use_adaln_lora)
        self.activation = nn.SiLU()
        self.use_adaln_lora = use_adaln_lora
        if use_adaln_lora:
            self.linear_2 = nn.Linear(out_features, 3 * out_features, bias=False)
        else:
            self.linear_2 = nn.Linear(out_features, out_features, bias=True)

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        sample = sample.to(self.linear_1.weight.dtype)
        emb = self.linear_1(sample)
        emb = self.activation(emb)
        emb = self.linear_2(emb)

        if self.use_adaln_lora:
            emb_B_D = sample
            adaln_lora_B_3D = emb
        else:
            emb_B_D = emb
            adaln_lora_B_3D = None

        return emb_B_D, adaln_lora_B_3D


class CosmosTimestepEmbedding(nn.Module):
    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        self.linear_1 = nn.Linear(in_features, out_features, bias=False)
        self.activation = nn.SiLU()
        self.linear_2 = nn.Linear(out_features, 3 * out_features, bias=False)
    
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        emb = self.linear_1(hidden_states)
        emb = self.activation(emb)
        emb = self.linear_2(emb)
        return hidden_states, emb


@torch.no_grad()
def match_timestep():
    from diffusers.models.embeddings import Timesteps as DiffusersTimesteps

    theirs_timesteps = Timesteps(256)
    ours_timesteps = DiffusersTimesteps(256, flip_sin_to_cos=True, downscale_freq_shift=0.0)

    input = torch.tensor([1000.0], dtype=torch.float32)
    theirs_output = theirs_timesteps(input)
    ours_output = ours_timesteps(input)

    print(torch.allclose(theirs_output, ours_output))


@torch.no_grad()
def match_timestep_embedding():
    theirs_temb = TimestepEmbedding(256, 256, use_adaln_lora=True)
    ours_temb = CosmosTimestepEmbedding(256, 256)
    ours_temb.linear_1.weight.data.copy_(theirs_temb.linear_1.weight.data)
    ours_temb.linear_2.weight.data.copy_(theirs_temb.linear_2.weight.data)

    input = torch.randn(1, 256)
    theirs_output = theirs_temb(input)
    ours_output = ours_temb(input)

    print(sum(p.numel() for p in theirs_temb.parameters()))
    print(sum(p.numel() for p in ours_temb.parameters()))
    print(torch.allclose(theirs_output[0], ours_output[0]))
    print(torch.allclose(theirs_output[1], ours_output[1]))


@torch.no_grad()
def match_timestep_embedding_2():
    from diffusers.models.transformers.transformer_cosmos import CosmosTimestepEmbedding
    theirs_temb = TimestepEmbedding(256, 256, use_adaln_lora=True)
    ours_temb = CosmosTimestepEmbedding(256, 256)
    ours_temb.linear_1.weight.data.copy_(theirs_temb.linear_1.weight.data)
    ours_temb.linear_2.weight.data.copy_(theirs_temb.linear_2.weight.data)

    input = torch.randn(1, 256)
    theirs_output = theirs_temb(input)
    ours_output = ours_temb(input)

    print(sum(p.numel() for p in theirs_temb.parameters()))
    print(sum(p.numel() for p in ours_temb.parameters()))
    print(torch.allclose(theirs_output[1], ours_output))


@torch.no_grad()
def match_timestep_prepare_embedding():
    from diffusers.models.transformers.transformer_cosmos import CosmosEmbedding
    from diffusers.models.normalization import RMSNorm
    theirs_t_embedder = nn.Sequential(
        Timesteps(4096),
        TimestepEmbedding(4096, 4096, use_adaln_lora=True),
    )
    theirs_norm = RMSNorm(4096, 1e-6, True)
    ours_t_embedder = CosmosEmbedding(4096, 4096)
    ours_t_embedder.t_embedder.linear_1.weight.data.copy_(theirs_t_embedder[1].linear_1.weight.data)
    ours_t_embedder.t_embedder.linear_2.weight.data.copy_(theirs_t_embedder[1].linear_2.weight.data)
    ours_t_embedder.norm.weight.data.copy_(theirs_norm.weight.data)

    hidden_states = torch.randn(1, 1, 4096)
    input = torch.randint(0, 1000, (1,)).long()
    theirs_output = theirs_t_embedder(input)
    ours_output = ours_t_embedder(hidden_states, input)

    print(sum(p.numel() for p in itertools.chain(theirs_t_embedder.parameters(), theirs_norm.parameters())))
    print(sum(p.numel() for p in ours_t_embedder.parameters()))
    print(torch.allclose(theirs_output[1], ours_output[0]))
    print(torch.allclose(theirs_norm(theirs_output[0]), ours_output[1]))


match_timestep()
print()

match_timestep_embedding()
print()

match_timestep_embedding_2()
print()

match_timestep_prepare_embedding()
print()
test patch embed
import torch
import torch.nn as nn

from einops.layers.torch import Rearrange

class PatchEmbed(nn.Module):
    def __init__(
        self,
        spatial_patch_size,
        temporal_patch_size,
        in_channels=3,
        out_channels=768,
        bias=True,
    ):
        super().__init__()
        self.spatial_patch_size = spatial_patch_size
        self.temporal_patch_size = temporal_patch_size

        self.proj = nn.Sequential(
            Rearrange(
                "b c (t r) (h m) (w n) -> b t h w (c r m n)",
                r=temporal_patch_size,
                m=spatial_patch_size,
                n=spatial_patch_size,
            ),
            nn.Linear(
                in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=bias
            ),
        )
        self.out = nn.Identity()

    def forward(self, x):
        assert x.dim() == 5
        _, _, T, H, W = x.shape
        assert H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0
        assert T % self.temporal_patch_size == 0
        x = self.proj(x)
        return self.out(x)


@torch.no_grad()
def match_patch_embed():
    from diffusers.models.transformers.transformer_cosmos import CosmosPatchEmbed

    theirs_patch_embed = PatchEmbed(2, 1, 16, 4096, bias=False)
    ours_patch_embed = CosmosPatchEmbed(16, 4096, (1, 2, 2), bias=False)

    ours_patch_embed.proj.weight.data.copy_(theirs_patch_embed.proj[1].weight.data)

    input = torch.randn(1, 16, 128, 240, 240)

    theirs_output = theirs_patch_embed(input)
    ours_output = ours_patch_embed(input)

    print(torch.allclose(theirs_output, ours_output))

match_patch_embed()
test positional embed
import math
from typing import Optional, List

import numpy as np
import torch
from einops import rearrange, repeat


def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0) -> torch.Tensor:
    if dim is None:
        dim = list(range(1, x.ndim))
    norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
    norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel()))
    return x / norm.to(x.dtype)


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.0))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)



class VideoPositionEmb(torch.nn.Module):
    def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor]) -> torch.Tensor:
        """
        It delegates the embedding generation to generate_embeddings function.
        """
        B_T_H_W_C = x_B_T_H_W_C.shape
        embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps)

        return embeddings

    def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]):
        raise NotImplementedError


class VideoRopePosition3DEmb(VideoPositionEmb):
    def __init__(
        self,
        *,  # enforce keyword arguments
        head_dim: int,
        len_h: int,
        len_w: int,
        len_t: int,
        base_fps: int = 24,
        h_extrapolation_ratio: float = 1.0,
        w_extrapolation_ratio: float = 1.0,
        t_extrapolation_ratio: float = 1.0,
        **kwargs,  # used for compatibility with other positional embeddings; unused in this class
    ):
        del kwargs
        super().__init__()
        self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float))
        self.base_fps = base_fps
        self.max_h = len_h
        self.max_w = len_w

        dim = head_dim
        dim_h = dim // 6 * 2
        dim_w = dim_h
        dim_t = dim - 2 * dim_h
        assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}"
        self.register_buffer(
            "dim_spatial_range",
            # torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().cuda() / dim_h,
            torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h,
            persistent=False,
        )
        self.register_buffer(
            "dim_temporal_range",
            # torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().cuda() / dim_t,
            torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t,
            persistent=False,
        )

        self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2))
        self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2))
        self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2))

    def generate_embeddings(
        self,
        B_T_H_W_C: torch.Size,
        fps: Optional[torch.Tensor] = None,
        h_ntk_factor: Optional[float] = None,
        w_ntk_factor: Optional[float] = None,
        t_ntk_factor: Optional[float] = None,
    ):
        """
        Generate embeddings for the given input size.

        Args:
            B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels).
            fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None.
            h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor.
            w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor.
            t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor.

        Returns:
            Not specified in the original code snippet.
        """
        h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor
        w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor
        t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor

        h_theta = 10000.0 * h_ntk_factor
        w_theta = 10000.0 * w_ntk_factor
        t_theta = 10000.0 * t_ntk_factor

        h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range)
        w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range)
        temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range)

        B, T, H, W, _ = B_T_H_W_C
        uniform_fps = (fps is None) or (fps.min() == fps.max())
        assert (
            uniform_fps or B == 1 or T == 1
        ), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1"
        assert (
            H <= self.max_h and W <= self.max_w
        ), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w})"
        half_emb_h = torch.outer(self.seq[:H], h_spatial_freqs)
        half_emb_w = torch.outer(self.seq[:W], w_spatial_freqs)

        # apply sequence scaling in temporal dimension
        if fps is None:  # image case
            assert T == 1, "T should be 1 for image batch."
            half_emb_t = torch.outer(self.seq[:T], temporal_freqs)
        else:
            half_emb_t = torch.outer(self.seq[:T] / fps[:1] * self.base_fps, temporal_freqs)

        em_T_H_W_D = torch.cat(
            [
                repeat(half_emb_t, "t d -> t h w d", h=H, w=W),
                repeat(half_emb_h, "h d -> t h w d", t=T, w=W),
                repeat(half_emb_w, "w d -> t h w d", t=T, h=H),
            ]
            * 2,
            dim=-1,
        )

        return rearrange(em_T_H_W_D, "t h w d -> (t h w) 1 1 d").float()


class LearnablePosEmbAxis(VideoPositionEmb):
    def __init__(
        self,
        *,  # enforce keyword arguments
        interpolation: str,
        model_channels: int,
        len_h: int,
        len_w: int,
        len_t: int,
        **kwargs,
    ):
        """
        Args:
            interpolation (str): we curretly only support "crop", ideally when we need extrapolation capacity, we should adjust frequency or other more advanced methods. they are not implemented yet.
        """
        del kwargs  # unused
        super().__init__()
        self.interpolation = interpolation
        assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}"

        self.pos_emb_h = torch.nn.Parameter(torch.zeros(len_h, model_channels))
        self.pos_emb_w = torch.nn.Parameter(torch.zeros(len_w, model_channels))
        self.pos_emb_t = torch.nn.Parameter(torch.zeros(len_t, model_channels))

        trunc_normal_(self.pos_emb_h, std=0.02)
        trunc_normal_(self.pos_emb_w, std=0.02)
        trunc_normal_(self.pos_emb_t, std=0.02)

    def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]) -> torch.Tensor:
        B, T, H, W, _ = B_T_H_W_C
        if self.interpolation == "crop":
            emb_h_H = self.pos_emb_h[:H]
            emb_w_W = self.pos_emb_w[:W]
            emb_t_T = self.pos_emb_t[:T]
            emb = (
                repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W)
                + repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W)
                + repeat(emb_w_W, "w d-> b t h w d", b=B, t=T, h=H)
            )
            assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}"
        else:
            raise ValueError(f"Unknown interpolation method {self.interpolation}")

        return normalize(emb, dim=-1, eps=1e-6)


@torch.no_grad()
def match_rope():
    from diffusers.models.transformers.transformer_cosmos import CosmosRotaryPosEmbed

    theirs_rope = VideoRopePosition3DEmb(head_dim=128, len_h=240 // 2, len_w=240 // 2, len_t=128 // 1, base_fps=24, h_extrapolation_ratio=1.0, w_extrapolation_ratio=1.0, t_extrapolation_ratio=2.0)
    ours_rope = CosmosRotaryPosEmbed(hidden_size=128, max_size=(128, 240, 240), patch_size=(1, 2, 2), base_fps=24, rope_scale=(2.0, 1.0, 1.0))

    hidden_states = torch.randn(2, 2, 32, 32, 16)
    fps = 30

    theirs_output = theirs_rope(hidden_states[:, :, :16, :16, :], fps=torch.tensor([fps]))  # the input slicing is to replicate patchification operation
    ours_output = ours_rope(hidden_states.permute(0, 4, 1, 2, 3), fps=fps)

    theirs_cos, theirs_sin = torch.cos(theirs_output), torch.sin(theirs_output)
    print(torch.allclose(ours_output[0][:, None, None, :], theirs_cos))
    print(torch.allclose(ours_output[1][:, None, None, :], theirs_sin))


@torch.no_grad()
def match_learnable_pe():
    from diffusers.models.transformers.transformer_cosmos import CosmosLearnablePositionalEmbed

    theirs_pe = LearnablePosEmbAxis(interpolation="crop", model_channels=4096, len_h=240 // 2, len_w=240 // 2, len_t=128 // 1)
    ours_pe = CosmosLearnablePositionalEmbed(4096, max_size=(128, 240, 240), patch_size=(1, 2, 2), eps=1e-6)

    ours_pe.pos_emb_t.data.copy_(theirs_pe.pos_emb_t.data)
    ours_pe.pos_emb_h.data.copy_(theirs_pe.pos_emb_h.data)
    ours_pe.pos_emb_w.data.copy_(theirs_pe.pos_emb_w.data)

    hidden_states = torch.randn(2, 2, 32, 32, 16)

    theirs_output = theirs_pe(hidden_states[:, :, :16, :16, :])
    ours_output = ours_pe(hidden_states.permute(0, 4, 1, 2, 3))

    theirs_output = theirs_output.flatten(1, 3)
    print(torch.allclose(ours_output, theirs_output))


# match_rope()
match_learnable_pe()
test transformer block
import sys
sys.path.append("/raid/aryan/cosmos-code/")

import torch
from cosmos1.models.diffusion.module.blocks import GeneralDITTransformerBlock


@torch.no_grad()
def match_transformer_block():
    from diffusers.models.transformers.transformer_cosmos import CosmosTransformerBlock

    theirs_transformer_block = GeneralDITTransformerBlock(
        x_dim=4096,
        context_dim=1024,
        num_heads=32,
        block_config="FA-CA-MLP",
        mlp_ratio=4.0,
        x_format="BTHWD",
        use_adaln_lora=True,
        adaln_lora_dim=256,
    )

    ours_transformer_block = CosmosTransformerBlock(
        num_attention_heads=32,
        attention_head_dim=128,
        cross_attention_dim=1024,
        mlp_ratio=4,
        adaln_lora_dim=256,
        qk_norm="rms_norm",
        out_bias=False,
    )

    ours_transformer_block.norm1.linear_1.weight.data.copy_(theirs_transformer_block.blocks[0].adaLN_modulation[1].weight.data)
    ours_transformer_block.norm1.linear_2.weight.data.copy_(theirs_transformer_block.blocks[0].adaLN_modulation[2].weight.data)
    
    ours_transformer_block.attn1.to_q.weight.data.copy_(theirs_transformer_block.blocks[0].block.attn.to_q[0].weight.data)
    ours_transformer_block.attn1.to_k.weight.data.copy_(theirs_transformer_block.blocks[0].block.attn.to_k[0].weight.data)
    ours_transformer_block.attn1.to_v.weight.data.copy_(theirs_transformer_block.blocks[0].block.attn.to_v[0].weight.data)
    ours_transformer_block.attn1.to_out[0].weight.data.copy_(theirs_transformer_block.blocks[0].block.attn.to_out[0].weight.data)
    ours_transformer_block.attn1.norm_q.weight.data.copy_(theirs_transformer_block.blocks[0].block.attn.to_q[1].weight.data)
    ours_transformer_block.attn1.norm_k.weight.data.copy_(theirs_transformer_block.blocks[0].block.attn.to_k[1].weight.data)

    ours_transformer_block.norm2.linear_1.weight.data.copy_(theirs_transformer_block.blocks[1].adaLN_modulation[1].weight.data)
    ours_transformer_block.norm2.linear_2.weight.data.copy_(theirs_transformer_block.blocks[1].adaLN_modulation[2].weight.data)

    ours_transformer_block.attn2.to_q.weight.data.copy_(theirs_transformer_block.blocks[1].block.attn.to_q[0].weight.data)
    ours_transformer_block.attn2.to_k.weight.data.copy_(theirs_transformer_block.blocks[1].block.attn.to_k[0].weight.data)
    ours_transformer_block.attn2.to_v.weight.data.copy_(theirs_transformer_block.blocks[1].block.attn.to_v[0].weight.data)
    ours_transformer_block.attn2.to_out[0].weight.data.copy_(theirs_transformer_block.blocks[1].block.attn.to_out[0].weight.data)
    ours_transformer_block.attn2.norm_q.weight.data.copy_(theirs_transformer_block.blocks[1].block.attn.to_q[1].weight.data)
    ours_transformer_block.attn2.norm_k.weight.data.copy_(theirs_transformer_block.blocks[1].block.attn.to_k[1].weight.data)

    ours_transformer_block.norm3.linear_1.weight.data.copy_(theirs_transformer_block.blocks[2].adaLN_modulation[1].weight.data)
    ours_transformer_block.norm3.linear_2.weight.data.copy_(theirs_transformer_block.blocks[2].adaLN_modulation[2].weight.data)

    ours_transformer_block.ff.net[0].proj.weight.data.copy_(theirs_transformer_block.blocks[2].block.layer1.weight.data)
    ours_transformer_block.ff.net[2].weight.data.copy_(theirs_transformer_block.blocks[2].block.layer2.weight.data)

    # ============
    batch_size = 1
    latent_num_frames = 2
    latent_height = 16
    latent_width = 16
    embedding_dim = 4096
    encoder_seq_length = 64
    encoder_dim = 1024
    

    hidden_states = torch.randn(batch_size, latent_num_frames, latent_height, latent_width, embedding_dim)
    temb = torch.randn(batch_size, embedding_dim)
    encoder_hidden_states = torch.randn(batch_size, encoder_seq_length, encoder_dim)
    attention_mask = None
    freqs = torch.randn(1, 1, latent_num_frames * latent_height * latent_width, 128)
    embedded_timestep = torch.randn(batch_size, 3 * embedding_dim)
    extra_per_block_emb = torch.randn(batch_size, latent_num_frames, latent_height, latent_width, embedding_dim)

    theirs_output = theirs_transformer_block(
        x=hidden_states.flatten(1, 3).permute(1, 0, 2),
        emb_B_D=temb,
        crossattn_emb=encoder_hidden_states.permute(1, 0, 2),
        crossattn_mask=attention_mask,
        rope_emb_L_1_1_D=freqs.permute(2, 0, 1, 3),
        adaln_lora_B_3D=embedded_timestep,
        extra_per_block_pos_emb=extra_per_block_emb.flatten(1, 3).permute(1, 0, 2),
    )
    ours_output = ours_transformer_block(
        hidden_states=hidden_states.flatten(1, 3),
        encoder_hidden_states=encoder_hidden_states,
        temb=temb,
        embedded_timestep=embedded_timestep,
        image_rotary_emb=(torch.cos(freqs.flatten(0, 2)), torch.sin(freqs.flatten(0, 2))),
        extra_pos_emb=extra_per_block_emb.flatten(1, 3),
        attention_mask=attention_mask,
    )

    theirs_output = theirs_output.flatten(0, 2).permute(1, 0, 2)
    print(sum(p.numel() for p in theirs_transformer_block.parameters()))
    print(sum(p.numel() for p in ours_transformer_block.parameters()))
    print(torch.allclose(theirs_output.flatten(), ours_output.flatten(), atol=1e-4))


match_transformer_block()

# GeneralDITTransformerBlock(
#   (blocks): ModuleList(
#     (0): DITBuildingBlock(
#       (block): VideoAttn(
#         (attn): Attention(
#           (to_q): Sequential(
#             (0): Linear(in_features=4096, out_features=4096, bias=False)
#             (1): RMSNorm()
#           )
#           (to_k): Sequential(
#             (0): Linear(in_features=4096, out_features=4096, bias=False)
#             (1): RMSNorm()
#           )
#           (to_v): Sequential(
#             (0): Linear(in_features=4096, out_features=4096, bias=False)
#             (1): Identity()
#           )
#           (to_out): Sequential(
#             (0): Linear(in_features=4096, out_features=4096, bias=False)
#             (1): Dropout(p=0.0, inplace=False)
#           )
#         )
#       )
#       (norm_state): LayerNorm((4096,), eps=1e-06, elementwise_affine=False)
#       (adaLN_modulation): Sequential(
#         (0): SiLU()
#         (1): Linear(in_features=4096, out_features=256, bias=False)
#         (2): Linear(in_features=256, out_features=12288, bias=False)
#       )
#     )
#     (1): DITBuildingBlock(
#       (block): VideoAttn(
#         (attn): Attention(
#           (to_q): Sequential(
#             (0): Linear(in_features=4096, out_features=4096, bias=False)
#             (1): RMSNorm()
#           )
#           (to_k): Sequential(
#             (0): Linear(in_features=1204, out_features=4096, bias=False)
#             (1): RMSNorm()
#           )
#           (to_v): Sequential(
#             (0): Linear(in_features=1204, out_features=4096, bias=False)
#             (1): Identity()
#           )
#           (to_out): Sequential(
#             (0): Linear(in_features=4096, out_features=4096, bias=False)
#             (1): Dropout(p=0.0, inplace=False)
#           )
#         )
#       )
#       (norm_state): LayerNorm((4096,), eps=1e-06, elementwise_affine=False)
#       (adaLN_modulation): Sequential(
#         (0): SiLU()
#         (1): Linear(in_features=4096, out_features=256, bias=False)
#         (2): Linear(in_features=256, out_features=12288, bias=False)
#       )
#     )
#     (2): DITBuildingBlock(
#       (block): GPT2FeedForward(
#         (layer1): Linear(in_features=4096, out_features=16384, bias=False)
#         (layer2): Linear(in_features=16384, out_features=4096, bias=False)
#         (dropout): Dropout(p=0.0, inplace=False)
#         (activation): GELU(approximate='none')
#       )
#       (norm_state): LayerNorm((4096,), eps=1e-06, elementwise_affine=False)
#       (adaLN_modulation): Sequential(
#         (0): SiLU()
#         (1): Linear(in_features=4096, out_features=256, bias=False)
#         (2): Linear(in_features=256, out_features=12288, bias=False)
#       )
#     )
#   )
# )
# CosmosTransformerBlock(
#   (norm1): CosmosAdaLayerNormZero(
#     (norm): LayerNorm((4096,), eps=1e-06, elementwise_affine=False)
#     (activation): SiLU()
#     (linear_1): Linear(in_features=4096, out_features=256, bias=False)
#     (linear_2): Linear(in_features=256, out_features=12288, bias=False)
#   )
#   (attn1): Attention(
#     (norm_q): RMSNorm()
#     (norm_k): RMSNorm()
#     (to_q): Linear(in_features=4096, out_features=4096, bias=False)
#     (to_k): Linear(in_features=4096, out_features=4096, bias=False)
#     (to_v): Linear(in_features=4096, out_features=4096, bias=False)
#     (to_out): ModuleList(
#       (0): Linear(in_features=4096, out_features=4096, bias=False)
#       (1): Dropout(p=0.0, inplace=False)
#     )
#   )
#   (norm2): CosmosAdaLayerNormZero(
#     (norm): LayerNorm((4096,), eps=1e-06, elementwise_affine=False)
#     (activation): SiLU()
#     (linear_1): Linear(in_features=4096, out_features=256, bias=False)
#     (linear_2): Linear(in_features=256, out_features=12288, bias=False)
#   )
#   (attn2): Attention(
#     (norm_q): RMSNorm()
#     (norm_k): RMSNorm()
#     (to_q): Linear(in_features=4096, out_features=4096, bias=False)
#     (to_k): Linear(in_features=1024, out_features=4096, bias=False)
#     (to_v): Linear(in_features=1024, out_features=4096, bias=False)
#     (to_out): ModuleList(
#       (0): Linear(in_features=4096, out_features=4096, bias=False)
#       (1): Dropout(p=0.0, inplace=False)
#     )
#   )
#   (norm3): CosmosAdaLayerNormZero(
#     (norm): LayerNorm((4096,), eps=1e-06, elementwise_affine=False)
#     (activation): SiLU()
#     (linear_1): Linear(in_features=4096, out_features=256, bias=False)
#     (linear_2): Linear(in_features=256, out_features=12288, bias=False)
#   )
#   (ff): FeedForward(
#     (net): ModuleList(
#       (0): GELU(
#         (proj): Linear(in_features=4096, out_features=16384, bias=False)
#       )
#       (1): Dropout(p=0.0, inplace=False)
#       (2): Linear(in_features=16384, out_features=4096, bias=False)
#     )
#   )
# )
test transformer
import sys
sys.path.append("/raid/aryan/cosmos-code/")

import torch
from cosmos1.models.diffusion.networks.general_dit import GeneralDIT


@torch.no_grad()
def match_transformer():
    from diffusers.models.transformers.transformer_cosmos import CosmosTransformer3DModel

    theirs_transformer = GeneralDIT(
        max_img_h=240,
        max_img_w=240,
        max_frames=128,
        in_channels=16,
        out_channels=16,
        patch_spatial=2,
        patch_temporal=1,
        concat_padding_mask=True,
        block_config="FA-CA-MLP",
        model_channels=4096,
        num_blocks=2,
        num_heads=32,
        mlp_ratio=4,
        block_x_format="THWBD",
        crossattn_emb_channels=1024,
        use_cross_attn_mask=False,
        pos_emb_cls="rope3d",
        pos_emb_learnable=True,
        pos_emb_interpolation="crop",
        affline_emb_norm=True,
        use_adaln_lora=True,
        adaln_lora_dim=256,
        rope_h_extrapolation_ratio=1.0,
        rope_w_extrapolation_ratio=1.0,
        rope_t_extrapolation_ratio=2.0,
        extra_per_block_abs_pos_emb=True,
        extra_per_block_abs_pos_emb_type="learnable",
    )

    ours_transformer = CosmosTransformer3DModel(
        in_channels=16,
        out_channels=16,
        num_attention_heads=32,
        attention_head_dim=128,
        num_layers=2,
        mlp_ratio=4,
        text_embed_dim=1024,
        adaln_lora_dim=256,
        max_size=(128, 240, 240),
        patch_size=(1, 2, 2),
        rope_scale=(2.0, 1.0, 1.0),
        concat_padding_mask=True,
        extra_pos_embed_type="learnable",
    )

    # Patch embedding
    ours_transformer.patch_embed.proj.weight.data.copy_(theirs_transformer.x_embedder.proj[1].weight.data)

    # Timestep embedding
    ours_t_embedder = ours_transformer.time_embed
    theirs_t_embedder = theirs_transformer.t_embedder
    theirs_norm = theirs_transformer.affline_norm
    ours_t_embedder.t_embedder.linear_1.weight.data.copy_(theirs_t_embedder[1].linear_1.weight.data)
    ours_t_embedder.t_embedder.linear_2.weight.data.copy_(theirs_t_embedder[1].linear_2.weight.data)
    ours_t_embedder.norm.weight.data.copy_(theirs_norm.weight.data)

    # Learnable position embedding
    ours_pe = ours_transformer.learnable_pos_embed
    theirs_pe = theirs_transformer.extra_pos_embedder
    ours_pe.pos_emb_t.data.copy_(theirs_pe.pos_emb_t.data)
    ours_pe.pos_emb_h.data.copy_(theirs_pe.pos_emb_h.data)
    ours_pe.pos_emb_w.data.copy_(theirs_pe.pos_emb_w.data)

    # Transformer blocks
    for i in range(2):
        ours_transformer_block = ours_transformer.transformer_blocks[i]
        theirs_transformer_block = theirs_transformer.blocks[f"block{i}"]
        
        ours_transformer_block.norm1.linear_1.weight.data.copy_(theirs_transformer_block.blocks[0].adaLN_modulation[1].weight.data)
        ours_transformer_block.norm1.linear_2.weight.data.copy_(theirs_transformer_block.blocks[0].adaLN_modulation[2].weight.data)
            
        ours_transformer_block.attn1.to_q.weight.data.copy_(theirs_transformer_block.blocks[0].block.attn.to_q[0].weight.data)
        ours_transformer_block.attn1.to_k.weight.data.copy_(theirs_transformer_block.blocks[0].block.attn.to_k[0].weight.data)
        ours_transformer_block.attn1.to_v.weight.data.copy_(theirs_transformer_block.blocks[0].block.attn.to_v[0].weight.data)
        ours_transformer_block.attn1.to_out[0].weight.data.copy_(theirs_transformer_block.blocks[0].block.attn.to_out[0].weight.data)
        ours_transformer_block.attn1.norm_q.weight.data.copy_(theirs_transformer_block.blocks[0].block.attn.to_q[1].weight.data)
        ours_transformer_block.attn1.norm_k.weight.data.copy_(theirs_transformer_block.blocks[0].block.attn.to_k[1].weight.data)

        ours_transformer_block.norm2.linear_1.weight.data.copy_(theirs_transformer_block.blocks[1].adaLN_modulation[1].weight.data)
        ours_transformer_block.norm2.linear_2.weight.data.copy_(theirs_transformer_block.blocks[1].adaLN_modulation[2].weight.data)

        ours_transformer_block.attn2.to_q.weight.data.copy_(theirs_transformer_block.blocks[1].block.attn.to_q[0].weight.data)
        ours_transformer_block.attn2.to_k.weight.data.copy_(theirs_transformer_block.blocks[1].block.attn.to_k[0].weight.data)
        ours_transformer_block.attn2.to_v.weight.data.copy_(theirs_transformer_block.blocks[1].block.attn.to_v[0].weight.data)
        ours_transformer_block.attn2.to_out[0].weight.data.copy_(theirs_transformer_block.blocks[1].block.attn.to_out[0].weight.data)
        ours_transformer_block.attn2.norm_q.weight.data.copy_(theirs_transformer_block.blocks[1].block.attn.to_q[1].weight.data)
        ours_transformer_block.attn2.norm_k.weight.data.copy_(theirs_transformer_block.blocks[1].block.attn.to_k[1].weight.data)

        ours_transformer_block.norm3.linear_1.weight.data.copy_(theirs_transformer_block.blocks[2].adaLN_modulation[1].weight.data)
        ours_transformer_block.norm3.linear_2.weight.data.copy_(theirs_transformer_block.blocks[2].adaLN_modulation[2].weight.data)

        ours_transformer_block.ff.net[0].proj.weight.data.copy_(theirs_transformer_block.blocks[2].block.layer1.weight.data)
        ours_transformer_block.ff.net[2].weight.data.copy_(theirs_transformer_block.blocks[2].block.layer2.weight.data)
    
    # Output layers
    ours_transformer.norm_out.linear_1.weight.data.copy_(theirs_transformer.final_layer.adaLN_modulation[1].weight.data)
    ours_transformer.norm_out.linear_2.weight.data.copy_(theirs_transformer.final_layer.adaLN_modulation[2].weight.data)
    ours_transformer.proj_out.weight.data.copy_(theirs_transformer.final_layer.linear.weight.data)

    for name, param in theirs_transformer.named_parameters():
        if "bias" in name:
            print(name, param.shape)
    for name, param in ours_transformer.named_parameters():
        if "bias" in name:
            print(name, param.shape)


    # ============
    batch_size = 1
    latent_num_frames = 2
    latent_height = 16
    latent_width = 16
    encoder_seq_length = 64
    encoder_dim = 1024
    fps = 30.0
    
    hidden_states = torch.randn(batch_size, latent_num_frames, latent_height, latent_width, 16)
    timestep = torch.randint(0, 1000, (batch_size,)).float()
    encoder_hidden_states = torch.randn(batch_size, encoder_seq_length, encoder_dim)
    attention_mask = None
    padding_mask = torch.zeros((1, 1, latent_height * 8, latent_width * 8))

    theirs_output = theirs_transformer(
        x=hidden_states.permute(0, 4, 1, 2, 3),
        timesteps=timestep,
        crossattn_emb=encoder_hidden_states,
        crossattn_mask=attention_mask,
        fps=torch.tensor([fps]),
        padding_mask=padding_mask,
    )
    print()
    ours_output = ours_transformer(
        hidden_states=hidden_states.permute(0, 4, 1, 2, 3),
        timestep=timestep.long(),
        encoder_hidden_states=encoder_hidden_states,
        attention_mask=attention_mask,
        fps=fps,
        padding_mask=padding_mask,
    )[0]

    print(torch.allclose(theirs_output, ours_output, atol=1e-4))

match_transformer()

Inference code (at the moment) (does not work as expected due to possible differences in scheduler):

import os
from typing import Any, Dict

import torch
from diffusers import CosmosTransformer3DModel, CosmosPipeline, EDMEulerScheduler
from diffusers.utils import export_to_video
from transformers import T5EncoderModel, T5TokenizerFast


def remove_keys_(key: str, state_dict: Dict[str, Any]):
    state_dict.pop(key)


def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
    state_dict[new_key] = state_dict.pop(old_key)


def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
    block_index = int(key.split(".")[1].removeprefix("block"))
    new_key = key

    old_prefix = f"blocks.block{block_index}"
    new_prefix = f"transformer_blocks.{block_index}"
    new_key = new_prefix + new_key.removeprefix(old_prefix)
    
    state_dict[new_key] = state_dict.pop(key)


TRANSFORMER_KEYS_RENAME_DICT = {
    "t_embedder.1": "time_embed.t_embedder",
    "affline_norm": "time_embed.norm",
    ".blocks.0.block.attn": ".attn1",
    ".blocks.1.block.attn": ".attn2",
    ".blocks.2.block": ".ff",
    ".blocks.0.adaLN_modulation.1": ".norm1.linear_1",
    ".blocks.0.adaLN_modulation.2": ".norm1.linear_2",
    ".blocks.1.adaLN_modulation.1": ".norm2.linear_1",
    ".blocks.1.adaLN_modulation.2": ".norm2.linear_2",
    ".blocks.2.adaLN_modulation.1": ".norm3.linear_1",
    ".blocks.2.adaLN_modulation.2": ".norm3.linear_2",
    "to_q.0": "to_q",
    "to_q.1": "norm_q",
    "to_k.0": "to_k",
    "to_k.1": "norm_k",
    "to_v.0": "to_v",
    "layer1": "net.0.proj",
    "layer2": "net.2",
    "proj.1": "proj",
    "x_embedder": "patch_embed",
    "extra_pos_embedder": "learnable_pos_embed",
    "final_layer.adaLN_modulation.1": "norm_out.linear_1",
    "final_layer.adaLN_modulation.2": "norm_out.linear_2",
    "final_layer.linear": "proj_out",
}

TRANSFORMER_SPECIAL_KEYS_REMAP = {
    "blocks.block": rename_transformer_blocks_,
    "logvar.0.freqs": remove_keys_,
    "logvar.0.phases": remove_keys_,
    "logvar.1.weight": remove_keys_,
    "pos_embedder.seq": remove_keys_,
}


def convert_transformer(state_dict):
    PREFIX_KEY = "net."
    for key in list(state_dict.keys()):
        new_key = key[:]
        if new_key.startswith(PREFIX_KEY):
            new_key = key[len(PREFIX_KEY) :]
        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_inplace(state_dict, key, new_key)

    for key in list(state_dict.keys()):
        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, state_dict)
    
    return state_dict


torch.manual_seed(0)
device = "cuda"
dtype = torch.bfloat16

with torch.no_grad():
    with torch.device("meta"):
        transformer = CosmosTransformer3DModel()
    num_parameters = sum(p.numel() for p in transformer.parameters())
    print(f"{num_parameters=}")

    checkpoint_file = "/raid/aryan/cosmos-code/checkpoints/Cosmos-1.0-Diffusion-7B-Text2World/model.pt"
    checkpoint = torch.load(checkpoint_file, map_location="cpu", weights_only=True)
    checkpoint = convert_transformer(checkpoint)
    transformer.load_state_dict(checkpoint, strict=True, assign=True)

    text_encoder = T5EncoderModel.from_pretrained("google-t5/t5-11b", torch_dtype=dtype, cache_dir="/raid/aryan/cosmos-code/checkpoints")
    tokenizer = T5TokenizerFast.from_pretrained("google-t5/t5-11b", cache_dir="/raid/aryan/cosmos-code/checkpoints")

    vae_dir = "/raid/aryan/cosmos-code/checkpoints/Cosmos-1.0-Tokenizer-CV8x8x8"
    decoder = torch.jit.load(os.path.join(vae_dir, "decoder.jit")).to(device=device, dtype=dtype)

    scheduler = EDMEulerScheduler()
    
    pipe = CosmosPipeline(text_encoder, tokenizer, transformer, vae=None, scheduler=scheduler)
    pipe.to(device)

    prompt = "A sleek, humanoid robot stands in a vast warehouse filled with neatly stacked cardboard boxes on industrial shelves. The robot's metallic body gleams under the bright, even lighting, highlighting its futuristic design and intricate joints. A glowing blue light emanates from its chest, adding a touch of advanced technology. The background is dominated by rows of boxes, suggesting a highly organized storage system. The floor is lined with wooden pallets, enhancing the industrial setting. The camera remains static, capturing the robot's poised stance amidst the orderly environment, with a shallow depth of field that keeps the focus on the robot while subtly blurring the background for a cinematic effect."
    latents = pipe(
        prompt=prompt,
        height=704,
        # width=1280,
        width=960,
        num_frames=121,
        num_inference_steps=35,
        output_type="latent",
    ).frames
    output = decoder(latents)

    video = pipe.video_processor.postprocess_video(output, output_type="pil")[0]

    export_to_video(video, "output.mp4", fps=30)

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@a-r-r-o-w a-r-r-o-w added the roadmap Add to current release roadmap label Feb 4, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
roadmap Add to current release roadmap
Projects
Status: In Progress
Development

Successfully merging this pull request may close these issues.

2 participants