import json
import logging
import os

import torch
from torch.cuda.amp import GradScaler, autocast

import monai
from monai.data import ThreadDataLoader, CacheDataset, create_test_image_3d
from monai.transforms import Compose, LoadImaged, EnsureChannelFirstd
from monai.apps.generation.maisi.networks.diffusion_model_unet_maisi import DiffusionModelUNetMaisi
from generative.networks.schedulers import DDPMScheduler
import nibabel as nib
import numpy as np


def main():
    # Configuration definitions
    config = {
        "noise_scheduler": {
            "num_train_timesteps": 1000,
            "beta_start": 0.0015,
            "beta_end": 0.0195,
            "schedule": "scaled_linear_beta",
            "clip_sample": False,
        },
    }

    # User input
    latent_dim = (128, 128, 128)  # for 80GB GPU
    # latent_dim = (128, 64, 64)  # for 16GB GPU
    # latent_dim = (128, 128, 64)  # for 32GB GPU
    enable_flash_attention = True

    # Initialize device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Set up logging
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("training")

    # Create fake data
    data_dir = "./data"
    os.makedirs(data_dir, exist_ok=True)
    train_data = []
    for i in range(10):
        image, _ = create_test_image_3d(latent_dim[0], latent_dim[1], latent_dim[2])
        image_fpath = os.path.join(data_dir, f"image_{i}.nii.gz")
        nib.save(nib.Nifti1Image(image, affine=np.eye(4)), image_fpath)
        train_data.append({"image": image_fpath})
    with open(os.path.join(data_dir, "train.json"), 'w') as f:
        json.dump({"training": train_data}, f)

    # Load fake data
    train_transforms = Compose([LoadImaged(keys=["image"]), EnsureChannelFirstd(keys=["image"])])
    train_ds = CacheDataset(data=train_data, transform=train_transforms, cache_rate=1.0, num_workers=2)
    train_loader = ThreadDataLoader(train_ds, num_workers=6, batch_size=1, shuffle=True)

    # Define UNet model
    unet = DiffusionModelUNetMaisi(
        spatial_dims=3,
        in_channels=4,
        out_channels=4,
        num_channels=[64, 128, 256, 512],
        attention_levels=[False, False, True, True],
        num_head_channels=[0, 0, 32, 32],
        num_res_blocks=2,
        use_flash_attention=enable_flash_attention,
        include_top_region_index_input=True,
        include_bottom_region_index_input=True,
        include_spacing_input=True,
    ).to(device)

    # Define noise scheduler
    noise_scheduler = DDPMScheduler(
        num_train_timesteps=config["noise_scheduler"]["num_train_timesteps"],
        beta_start=config["noise_scheduler"]["beta_start"],
        beta_end=config["noise_scheduler"]["beta_end"],
        schedule=config["noise_scheduler"]["schedule"],
        clip_sample=config["noise_scheduler"]["clip_sample"],
    )

    # Calculate scale factor
    check_data = next(iter(train_loader))
    z = check_data["image"].to(device)
    scale_factor = 1 / torch.std(z)
    logger.info(f"Scaling factor set to {scale_factor}.")

    # Create optimizer and learning rate scheduler
    optimizer = torch.optim.Adam(params=unet.parameters(), lr=1e-3)
    total_steps = 2 * len(train_loader)
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 1.0 - step / total_steps)

    # Loss function and scaler for mixed precision training
    loss_pt = torch.nn.L1Loss()
    scaler = GradScaler()

    top_region_index_tensor = torch.tensor([0, 1, 0, 0]).unsqueeze(0).half().to(device)
    bottom_region_index_tensor = torch.tensor([0, 1, 0, 0]).unsqueeze(0).half().to(device)
    spacing_tensor = torch.tensor([1, 1, 1]).unsqueeze(0).half().to(device)

    # Training loop
    for epoch in range(2):
        unet.train()
        for train_data in train_loader:
            images = train_data["image"].to(device)
            images = images.repeat(1, 4, 1, 1, 1)
            images = images * scale_factor

            optimizer.zero_grad(set_to_none=True)

            with autocast(enabled=True):
                noise = torch.randn((1, 4, latent_dim[0], latent_dim[1], latent_dim[2]), device=device)
                timesteps = torch.randint(0, 1000, (images.shape[0],), device=images.device).long()
                noisy_latent = noise_scheduler.add_noise(original_samples=images, noise=noise, timesteps=timesteps)
                noise_pred = unet(
                    x=noisy_latent,
                    timesteps=timesteps,
                    top_region_index_tensor=top_region_index_tensor,
                    bottom_region_index_tensor=bottom_region_index_tensor,
                    spacing_tensor=spacing_tensor,
                )
                loss = loss_pt(noise_pred.float(), noise.float())

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            lr_scheduler.step()

        logger.info(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}")


if __name__ == "__main__":
    main()