Add dataloader, no grad scope and auto gpu detection (#47)
hadaev8 authored Oct 8, 2020
1 parent 355e745 commit 4d7695b
Showing 1 changed file with 49 additions and 39 deletions.
88 changes: 49 additions & 39 deletions pytorch_fid/fid_score.py
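The commit replaces the hand-rolled `imread`/NumPy batching loop with a `torch.utils.data.Dataset` of image paths fed through a `DataLoader`, wraps the InceptionV3 forward pass in `torch.no_grad()`, and selects the device automatically when `--device` is not given. A minimal, self-contained sketch of the new loading pattern follows; the class name, temporary files, and tiny generated images are illustrative stand-ins, not the module's own code:

```python
import os
import tempfile

import torch
import torchvision.transforms as TF
from PIL import Image


class PathDataset(torch.utils.data.Dataset):
    """Lazily loads one RGB image tensor per file path."""

    def __init__(self, files, transforms=None):
        self.files = files
        self.transforms = transforms

    def __len__(self):
        return len(self.files)

    def __getitem__(self, i):
        img = Image.open(self.files[i]).convert('RGB')
        return self.transforms(img) if self.transforms is not None else img


# Create two tiny dummy images so the sketch runs end to end.
tmp_dir = tempfile.mkdtemp()
files = []
for name in ('a.png', 'b.png'):
    path = os.path.join(tmp_dir, name)
    Image.new('RGB', (8, 8)).save(path)
    files.append(path)

loader = torch.utils.data.DataLoader(PathDataset(files, TF.ToTensor()),
                                     batch_size=2, num_workers=0)
for batch in loader:
    print(batch.shape)  # torch.Size([2, 3, 8, 8])
```

The default collate function stacks the per-image tensors, so this pattern assumes every image in a batch has the same spatial size; the transform is the same `ToTensor()` the diff passes to `ImagesPathDataset`.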
@@ -34,12 +34,13 @@
import os
import pathlib
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
from multiprocessing import cpu_count

import numpy as np
import torch
import torchvision.transforms as TF
from scipy import linalg
from torch.nn.functional import adaptive_avg_pool2d

from PIL import Image

try:
@@ -53,25 +54,34 @@ def tqdm(x): return x
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
parser.add_argument('--batch-size', type=int, default=50,
help='Batch size to use')
parser.add_argument('--device', type=str, default=None,
help='Device to use. Like cuda, cuda:0 or cpu')
parser.add_argument('--dims', type=int, default=2048,
choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
help=('Dimensionality of Inception features to use. '
'By default, uses pool3 features'))
parser.add_argument('-c', '--gpu', default='', type=str,
help='GPU to use (leave blank for CPU only)')
parser.add_argument('path', type=str, nargs=2,
help=('Paths to the generated images or '
'to .npz statistic files'))


def imread(filename):
"""
Loads an image file into a (height, width, 3) uint8 ndarray.
"""
return np.asarray(Image.open(filename), dtype=np.uint8)[..., :3]
class ImagesPathDataset(torch.utils.data.Dataset):
def __init__(self, files, transforms=None):
self.files = files
self.transforms = transforms

def __len__(self):
return len(self.files)

def get_activations(files, model, batch_size=50, dims=2048, cuda=False):
def __getitem__(self, i):
path = self.files[i]
img = Image.open(path).convert('RGB')
if self.transforms is not None:
img = self.transforms(img)
return img


def get_activations(files, model, batch_size=50, dims=2048, device='cpu'):
"""Calculates the activations of the pool_3 layer for all images.
Params:
@@ -83,7 +93,7 @@ def get_activations(files, model, batch_size=50, dims=2048, cuda=False):
behavior is retained to match the original FID score
implementation.
-- dims : Dimensionality of features returned by Inception
-- cuda : If set to True, use GPU
-- device : Device to run calculations
Returns:
-- A numpy array of dimension (num images, dims) that contains the
@@ -97,31 +107,30 @@ def get_activations(files, model, batch_size=50, dims=2048, cuda=False):
'Setting batch size to data size'))
batch_size = len(files)

pred_arr = np.empty((len(files), dims))

for i in tqdm(range(0, len(files), batch_size)):
start = i
end = i + batch_size
ds = ImagesPathDataset(files, transforms=TF.ToTensor())
dl = torch.utils.data.DataLoader(ds, batch_size=batch_size,
drop_last=False, num_workers=cpu_count())

images = np.array([imread(str(f)).astype(np.float32)
for f in files[start:end]])
pred_arr = np.empty((len(files), dims))

# Reshape to (n_images, 3, height, width)
images = images.transpose((0, 3, 1, 2))
images /= 255
start_idx = 0

batch = torch.from_numpy(images).type(torch.FloatTensor)
if cuda:
batch = batch.cuda()
for batch in tqdm(dl):
batch = batch.to(device)

pred = model(batch)[0]
with torch.no_grad():
pred = model(batch)[0]

# If model output is not scalar, apply global spatial average pooling.
# This happens if you choose a dimensionality not equal 2048.
if pred.size(2) != 1 or pred.size(3) != 1:
pred = adaptive_avg_pool2d(pred, output_size=(1, 1))

pred_arr[start:end] = pred.cpu().data.numpy().reshape(pred.size(0), -1)
pred = pred.squeeze(3).squeeze(2).cpu().numpy()

pred_arr[start_idx:start_idx + pred.shape[0]] = pred

start_idx = start_idx + pred.shape[0]

return pred_arr
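The forward pass above now runs inside a `torch.no_grad()` block, so no autograd graph is recorded while activations are extracted and intermediate buffers are freed right away. A small sketch of the effect on a generic module (a toy `Linear` layer stands in for the InceptionV3 model):

```python
import torch

model = torch.nn.Linear(8, 4)
x = torch.randn(2, 8)

y_train = model(x)        # graph recorded: output participates in autograd
with torch.no_grad():
    y_eval = model(x)     # no graph recorded: cheaper pure inference

print(y_train.requires_grad)  # True
print(y_eval.requires_grad)   # False
```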

@@ -183,8 +192,7 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
np.trace(sigma2) - 2 * tr_covmean)


def calculate_activation_statistics(files, model, batch_size=50, dims=2048,
cuda=False):
def calculate_activation_statistics(files, model, batch_size=50, dims=2048, device='cpu'):
"""Calculation of the statistics used by the FID.
Params:
-- files : List of image files paths
@@ -193,21 +201,21 @@ def calculate_activation_statistics(files, model, batch_size=50, dims=2048,
batch size batch_size. A reasonable batch size
depends on the hardware.
-- dims : Dimensionality of features returned by Inception
-- cuda : If set to True, use GPU
-- device : Device to run calculations
Returns:
-- mu : The mean over samples of the activations of the pool_3 layer of
the inception model.
-- sigma : The covariance matrix of the activations of the pool_3 layer of
the inception model.
"""
act = get_activations(files, model, batch_size, dims, cuda)
act = get_activations(files, model, batch_size, dims, device)
mu = np.mean(act, axis=0)
sigma = np.cov(act, rowvar=False)
return mu, sigma


def _compute_statistics_of_path(path, model, batch_size, dims, cuda):
def _compute_statistics_of_path(path, model, batch_size, dims, device):
if path.endswith('.npz'):
f = np.load(path)
m, s = f['mu'][:], f['sigma'][:]
@@ -216,39 +224,41 @@ def _compute_statistics_of_path(path, model, batch_size, dims, cuda):
path = pathlib.Path(path)
files = list(path.glob('*.jpg')) + list(path.glob('*.png'))
m, s = calculate_activation_statistics(files, model, batch_size,
dims, cuda)
dims, device)

return m, s


def calculate_fid_given_paths(paths, batch_size, cuda, dims):
def calculate_fid_given_paths(paths, batch_size, device, dims):
"""Calculates the FID of two paths"""
for p in paths:
if not os.path.exists(p):
raise RuntimeError('Invalid path: %s' % p)

block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]

model = InceptionV3([block_idx])
if cuda:
model.cuda()
model = InceptionV3([block_idx]).to(device)

m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size,
dims, cuda)
dims, device)
m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size,
dims, cuda)
dims, device)
fid_value = calculate_frechet_distance(m1, s1, m2, s2)

return fid_value


def main():
args = parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

if args.device is None:
device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
else:
device = torch.device(args.device)

fid_value = calculate_fid_given_paths(args.path,
args.batch_size,
args.gpu != '',
device,
args.dims)
print('FID: ', fid_value)
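With this change the helpers take a `device` argument in place of the old `cuda` flag, and `calculate_fid_given_paths(paths, batch_size, device, dims)` moves the model to that device itself. A usage sketch, assuming the `pytorch_fid` package at this commit is importable and that the two image directories (hypothetical names) exist:

```python
import torch
from pytorch_fid.fid_score import calculate_fid_given_paths

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

fid = calculate_fid_given_paths(['path/to/real', 'path/to/generated'],
                                batch_size=50, device=device, dims=2048)
print('FID:', fid)
```

From the command line the two paths are passed positionally, with an optional `--device cuda:0`; when the flag is omitted, `main()` falls back to the same automatic CUDA detection shown above.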

