from diffusers import LDMPipeline, DDIMScheduler, VQModel
from diffusers.models import UNet2DModel
import torch_pruning as tp
import torch
import torchvision
import os
import utils

import argparse
parser = argparse.ArgumentParser()
#parser.add_argument("--dataset", type=str, required=True)
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--save_path", type=str, required=True)
parser.add_argument("--pruning_ratio", type=float, default=0.3)
parser.add_argument("--batch_size", type=int, default=128)
parser.add_argument("--device", type=str, default='cpu')
# 'taylor' and 'diff-pruning' require a gradient pass over real data (see the
# commented-out data loading below), so only gradient-free pruners are enabled here.
parser.add_argument("--pruner", type=str, default='random', choices=['random', 'magnitude', 'reinit'])

#parser.add_argument("--thr", type=float, default=0.05, help="threshold for diff-pruning")

args = parser.parse_args()
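# Example invocation (script name and paths are placeholders):
#   python prune_ldm.py --model_path CompVis/ldm-celebahq-256 \
#       --save_path run/pruned/ldm_celebahq --pruning_ratio 0.3 --pruner magnitude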


if __name__=='__main__':
    # Data loading for gradient-based pruning (taylor / diff-pruning); kept for
    # reference but disabled since only gradient-free pruners are enabled above.
    #dataset = utils.get_dataset(args.dataset)
    #print(f"Dataset size: {len(dataset)}")
    #train_dataloader = torch.utils.data.DataLoader(
    #    dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, drop_last=True
    #)

    #clean_images = next(iter(train_dataloader))
    #if isinstance(clean_images, (list, tuple)):
    #    clean_images = clean_images[0]
    #clean_images = clean_images.to(args.device)
    #noise = torch.randn(clean_images.shape).to(clean_images.device)

    # Load the pretrained LDM components; model_path may be a local directory or a hub id
    print("Loading pretrained model from {}".format(args.model_path))
    unet = UNet2DModel.from_pretrained(args.model_path, subfolder="unet")
    vqvae = VQModel.from_pretrained(args.model_path, subfolder="vqvae")
    scheduler = DDIMScheduler.from_pretrained(args.model_path, subfolder="scheduler")

    # Resolve the target device, falling back to CPU when CUDA is unavailable
    torch_device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    unet.to(torch_device)
    vqvae.to(torch_device)
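
    # torch-pruning traces the UNet with these example inputs to build its
    # layer dependency graph; the timestep must be an integer (long) tensor.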
    example_inputs = {
        'sample': torch.randn(1, unet.in_channels, unet.sample_size, unet.sample_size).to(torch_device),
        'timestep': torch.ones((1,)).long().to(torch_device),
    }

    if args.pruning_ratio>0:
        if args.pruner == 'taylor':
            imp = tp.importance.TaylorImportance() 
        elif args.pruner == 'random' or args.pruner=='reinit':
            imp = tp.importance.RandomImportance()
        elif args.pruner == 'magnitude':
            imp = tp.importance.MagnitudeImportance()
        elif args.pruner == 'diff-pruning':
            imp = tp.importance.TaylorImportance(multivariable=False) 
        else:
            raise NotImplementedError
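
        # RandomImportance scores channels uniformly at random and MagnitudeImportance
        # ranks them by weight norm; the Taylor variants need gradients (disabled here).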

        ignored_layers = [unet.conv_out]
        from diffusers.models.attention import Attention
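        # Group each attention layer's q/k/v projections by its head count so
        # channels are pruned in head-sized groups and the multi-head split stays valid.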
        channel_groups = {}
        for m in unet.modules():
            if isinstance(m, Attention):
                channel_groups[m.to_q] = m.heads
                channel_groups[m.to_k] = m.heads
                channel_groups[m.to_v] = m.heads
        
        # ch_sparsity sets the global channel pruning ratio; ignored_layers keeps
        # conv_out intact so the UNet's output shape is preserved.
        pruner = tp.pruner.MagnitudePruner(
            unet,
            example_inputs,
            importance=imp,
            iterative_steps=1,
            channel_groups=channel_groups,
            ch_sparsity=args.pruning_ratio,
            ignored_layers=ignored_layers,
        )

        base_macs, base_params = tp.utils.count_ops_and_params(unet, example_inputs)
        unet.zero_grad()
        unet.eval()

        # interactive=True yields one dependency group at a time; g.prune()
        # applies the structural removal to all coupled layers in the group.
        for g in pruner.step(interactive=True):
            g.prune()

        # Update static attributes: Upsample2D/Downsample2D cache their channel
        # counts as plain ints, so sync them with the pruned conv layers.
        from diffusers.models.resnet import Upsample2D, Downsample2D
        for m in unet.modules():
            if isinstance(m, (Upsample2D, Downsample2D)):
                m.channels = m.conv.in_channels
                m.out_channels = m.conv.out_channels

        macs, params = tp.utils.count_ops_and_params(unet, example_inputs)
        print(unet)
        print("#Params: {:.4f} M => {:.4f} M".format(base_params/1e6, params/1e6))
        print("#MACS: {:.4f} G => {:.4f} G".format(base_macs/1e9, macs/1e9))
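        # Note: MACs/params are measured for a single UNet forward pass on example_inputs.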
        unet.zero_grad()
        del pruner

        if args.pruner=='reinit':
            # Baseline: keep the pruned architecture but discard the inherited
            # weights by re-initializing every layer from scratch.
            def reset_parameters(model):
                for m in model.modules():
                    if hasattr(m, 'reset_parameters'):
                        m.reset_parameters()
            reset_parameters(unet)

    
    pipeline = LDMPipeline(
        unet=unet,
        vqvae=vqvae,
        scheduler=scheduler,
    ).to(torch_device)
    pipeline.save_pretrained(args.save_path)
    if args.pruning_ratio>0:
        os.makedirs(os.path.join(args.save_path, "pruned"), exist_ok=True)
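        # Save the whole module rather than a state_dict: the pruned graph no
        # longer matches the original config, so it must be unpickled intact.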
        torch.save(unet, os.path.join(args.save_path, "pruned", "unet_pruned.pth"))
    with torch.no_grad():
        generator = torch.Generator(device=torch_device).manual_seed(0)
        images = pipeline(num_inference_steps=100, batch_size=args.batch_size, generator=generator, output_type="numpy").images
        os.makedirs(os.path.join(args.save_path, 'vis'), exist_ok=True)
        torchvision.utils.save_image(torch.from_numpy(images).permute([0, 3, 1, 2]), "{}/vis/after_pruning.png".format(args.save_path))
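
    # A minimal sketch for reusing the pruned UNet later (assumes the save layout
    # above; the pruned architecture no longer matches the saved config, so the
    # whole pickled module is loaded and swapped into a pipeline):
    #   unet = torch.load(os.path.join(args.save_path, "pruned", "unet_pruned.pth"))
    #   pipeline.unet = unet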