# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

"""Script for calculating Frechet Inception Distance (FID)."""

import os
import click
import tqdm
import pickle
import numpy as np
import scipy.linalg
import torch
import dnnlib
from torch_utils import distributed as dist
from training import dataset

#----------------------------------------------------------------------------

def calculate_inception_stats(
    image_path, num_expected=None, seed=0, max_batch_size=64,
    num_workers=11, prefetch_factor=2, device=torch.device('cuda'),
):
    """Compute the mean and covariance of Inception-v3 features for a set of images.

    Must be called on every rank after the process group has been initialized,
    because it uses `torch.distributed.barrier()` and `all_reduce()`.

    Args:
        image_path: Forwarded to `dataset.ImageFolderDataset(path=...)`.
        num_expected: Forwarded as `max_size`; when not None, also the minimum
            number of images that must be found.
        seed: Forwarded to the dataset as `random_seed`.
        max_batch_size: Upper bound on the number of images per batch on each rank.
        num_workers: Number of `DataLoader` worker processes.
        prefetch_factor: `DataLoader` prefetch factor; only used when
            `num_workers > 0`.
        device: Device on which the detector runs and statistics are accumulated.

    Returns:
        Tuple `(mu, sigma)` of float64 NumPy arrays with shapes `[2048]` and
        `[2048, 2048]`: the feature mean and the unbiased feature covariance.

    Raises:
        click.ClickException: If fewer than `num_expected` images (when given)
            or fewer than 2 images are found.
    """
    # Rank 0 goes first.
    if dist.get_rank() != 0:
        torch.distributed.barrier()

    # Load Inception-v3 model.
    # This is a direct PyTorch translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
    dist.print0('Loading Inception-v3 model...')
    detector_url = 'fid-refs/inception-2015-12-05.pkl'
    if not os.path.exists(detector_url):
        detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
    detector_kwargs = dict(return_features=True)
    feature_dim = 2048
    with dnnlib.util.open_url(detector_url, verbose=(dist.get_rank() == 0)) as f:
        # NOTE(security): unpickling executes code embedded in the file, so the
        # local file / URL above must point to a trusted source.
        detector_net = pickle.load(f).to(device)

    # List images.
    dist.print0(f'Loading images from "{image_path}"...')
    dataset_obj = dataset.ImageFolderDataset(path=image_path, max_size=num_expected, random_seed=seed)
    num_images = len(dataset_obj)
    if num_expected is not None and num_images < num_expected:
        raise click.ClickException(f'Found {num_images} images, but expected at least {num_expected}')
    if num_images < 2:
        # The covariance below divides by (num_images - 1).
        raise click.ClickException(f'Found {num_images} images, but need at least 2 to compute statistics')

    # Other ranks follow.
    if dist.get_rank() == 0:
        torch.distributed.barrier()

    # Divide images into batches.
    # The batch count is rounded up to a multiple of the world size so that
    # every rank iterates the same number of batches (some may be empty) and
    # the per-batch barrier below stays in lockstep.
    world_size = dist.get_world_size()
    num_batches = ((num_images - 1) // (max_batch_size * world_size) + 1) * world_size
    all_batches = torch.arange(num_images).tensor_split(num_batches)
    rank_batches = all_batches[dist.get_rank() :: world_size]
    # DataLoader rejects prefetch_factor when no worker processes are used,
    # so only forward it when it is meaningful.
    loader_kwargs = dict(num_workers=num_workers)
    if num_workers > 0:
        loader_kwargs['prefetch_factor'] = prefetch_factor
    data_loader = torch.utils.data.DataLoader(dataset_obj, batch_sampler=rank_batches, **loader_kwargs)

    # Accumulate statistics.
    dist.print0(f'Calculating statistics for {num_images} images...')
    mu = torch.zeros([feature_dim], dtype=torch.float64, device=device)
    sigma = torch.zeros([feature_dim, feature_dim], dtype=torch.float64, device=device)
    # Inference only: make sure no autograd graph is retained across batches.
    with torch.no_grad():
        for images, _labels in tqdm.tqdm(data_loader, unit='batch', disable=(dist.get_rank() != 0)):
            torch.distributed.barrier()
            if images.shape[0] == 0:
                continue
            if images.shape[1] == 1:
                # Single-channel images: replicate to 3 channels for the detector.
                images = images.repeat([1, 3, 1, 1])
            features = detector_net(images.to(device), **detector_kwargs).to(torch.float64)
            mu += features.sum(0)
            sigma += features.T @ features

    # Calculate grand totals.
    torch.distributed.all_reduce(mu)
    torch.distributed.all_reduce(sigma)
    mu /= num_images
    # sum(x x^T) - N * mu mu^T, then the unbiased (N - 1) normalization.
    sigma -= torch.outer(mu, mu) * num_images
    sigma /= num_images - 1
    return mu.cpu().numpy(), sigma.cpu().numpy()

#----------------------------------------------------------------------------

def calculate_fid_from_inception_stats(mu, sigma, mu_ref, sigma_ref, max_retries=10):
    """Compute FID from two sets of feature statistics.

    Evaluates `||mu - mu_ref||^2 + Tr(sigma + sigma_ref - 2 * sqrtm(sigma @ sigma_ref))`.
    If the result is NaN, negative, or larger than 1e5, `1e-8 * I` is added to
    both covariance matrices and the computation is repeated. The inputs are
    not modified in place.

    Args:
        mu: Feature mean of the evaluated images (NumPy array).
        sigma: Feature covariance of the evaluated images (square NumPy array).
        mu_ref: Reference feature mean.
        sigma_ref: Reference feature covariance.
        max_retries: Maximum number of retries with inflated covariances.

    Returns:
        The FID as a Python float (real part of the computed value).

    Raises:
        RuntimeError: If the result is still invalid after `max_retries` retries.
    """
    jitter = np.eye(len(sigma)) * 1e-8
    num_attempts = max(int(max_retries), 0) + 1
    for attempt in range(num_attempts):
        print('mu', mu.mean())
        print('sigma', sigma.mean())
        print('mu_ref', mu_ref.mean())
        print('sigma_ref', sigma_ref.mean())
        m = np.square(mu - mu_ref).sum()
        s, _ = scipy.linalg.sqrtm(np.dot(sigma, sigma_ref), disp=False)
        fid = m + np.trace(sigma + sigma_ref - s * 2)
        # sqrtm may return a complex matrix; range-check the real part rather
        # than ordering a complex number, but keep the complex-aware NaN test.
        fid_real = float(np.real(fid))
        if not np.isnan(fid) and 0. <= fid_real <= 1e5:
            return fid_real
        if attempt + 1 < num_attempts:
            print(f'Got {fid}. Trying again with inflated sigmas.')
            sigma = sigma + jitter
            sigma_ref = sigma_ref + jitter
    # Bounded failure instead of recursing until RecursionError.
    raise RuntimeError(f'FID is still invalid ({fid}) after {num_attempts - 1} retries with inflated sigmas')

#----------------------------------------------------------------------------

@click.group()
def main():
    """Calculate Frechet Inception Distance (FID).

    Examples:

    \b
    # Generate 50000 images and save them as fid-tmp/*/*.png
    torchrun --standalone --nproc_per_node=1 generate.py --outdir=fid-tmp --seeds=0-49999 --subdirs \\
        --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl

    \b
    # Calculate FID
    torchrun --standalone --nproc_per_node=1 fid.py calc --images=fid-tmp \\
        --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz

    \b
    # Compute dataset reference statistics
    python fid.py ref --data=datasets/my-dataset.zip --dest=fid-refs/my-dataset.npz
    """

#----------------------------------------------------------------------------

@main.command()
@click.option('--images', 'image_path', help='Path to the images', metavar='PATH|ZIP', type=str, required=True)
@click.option('--ref', 'ref_path', help='Dataset reference statistics ', metavar='NPZ|URL', type=str, required=True)
@click.option('--num', 'num_expected', help='Number of images to use', metavar='INT', type=click.IntRange(min=2), default=50000, show_default=True)
@click.option('--seed', help='Random seed for selecting the images', metavar='INT', type=int, default=0, show_default=True)
@click.option('--batch', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True)

def calc(image_path, ref_path, num_expected, seed, batch):
    """Calculate FID for a given set of images."""
    torch.multiprocessing.set_start_method('spawn')
    dist.init()

    # Only rank 0 needs the reference statistics; it alone computes the FID.
    dist.print0(f'Loading dataset reference statistics from "{ref_path}"...')
    ref = None
    if dist.get_rank() == 0:
        with dnnlib.util.open_url(ref_path) as f:
            ref = dict(np.load(f))

    mu, sigma = calculate_inception_stats(image_path=image_path, num_expected=num_expected, seed=seed, max_batch_size=batch)
    dist.print0('Calculating FID...')
    if dist.get_rank() == 0:
        fid = calculate_fid_from_inception_stats(mu, sigma, ref['mu'], ref['sigma'])
        print(f'{fid:g}')
    # Keep the other ranks alive until rank 0 has finished.
    torch.distributed.barrier()

#----------------------------------------------------------------------------

@main.command()
@click.option('--data', 'dataset_path', help='Path to the dataset', metavar='PATH|ZIP', type=str, required=True)
@click.option('--dest', 'dest_path', help='Destination .npz file', metavar='NPZ', type=str, required=True)
@click.option('--batch', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True)

def ref(dataset_path, dest_path, batch):
    """Calculate dataset reference statistics needed by 'calc'."""
    torch.multiprocessing.set_start_method('spawn')
    dist.init()

    mu, sigma = calculate_inception_stats(image_path=dataset_path, max_batch_size=batch)
    dist.print0(f'Saving dataset reference statistics to "{dest_path}"...')
    if dist.get_rank() == 0:
        # Create the destination directory only when the path has one.
        if os.path.dirname(dest_path):
            os.makedirs(os.path.dirname(dest_path), exist_ok=True)
        np.savez(dest_path, mu=mu, sigma=sigma)

    torch.distributed.barrier()
    dist.print0('Done.')

#----------------------------------------------------------------------------

if __name__ == "__main__":
    main()

#----------------------------------------------------------------------------