Skip to content

Commit

Permalink
Rewrite all code to make more user friendly
Browse files Browse the repository at this point in the history
  • Loading branch information
JStehouwer committed Oct 7, 2020
1 parent 848e1a1 commit 4205b6b
Show file tree
Hide file tree
Showing 32 changed files with 663 additions and 3,328 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
__pycache__/
network/xception-b5690688.pth
models/
xception-b5690688.pth
Binary file added MCT/template0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added MCT/template1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added MCT/template2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added MCT/template3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added MCT/template4.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added MCT/template5.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added MCT/template6.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added MCT/template7.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added MCT/template8.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added MCT/template9.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
95 changes: 95 additions & 0 deletions dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from imageio import imread
import os
import random
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

DATABASE = '/home/jstay/FFD/data/'

DATASETS = {
'Real': 0,
'Fake': 1
}

class DatasetInstance:
def __init__(self, label_name, label, datatype, img_size, map_size, norm, seed, bs, drop_last):
self.img_size = img_size
self.map_size = map_size


self.transform = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize(img_size),
transforms.ToTensor(),
transforms.Normalize(*norm)
])

self.transform_mask = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize(map_size),
transforms.Grayscale(num_output_channels=1),
transforms.ToTensor()
])

self.label_name = label_name
self.label = label
self.datatype = datatype
self.data_dir = '{0}{1}/{2}/'.format(DATABASE, self.datatype, self.label_name)
files = os.listdir(self.data_dir)
if self.datatype != 'test':
random.Random(seed).shuffle(files)
self.images = ['{0}/{1}'.format(self.data_dir, _) for _ in files]
self.loader = DataLoader(self, num_workers=8, batch_size=bs, shuffle=(self.datatype != 'test'), drop_last=drop_last, pin_memory=True)
self.generator = self.get_batch()

print('Constructed Dataset `{0}` of size `{1}`'.format(self.data_dir, self.__len__()))

def load_image(self, path):
return self.transform(imread(path))

def load_mask(self, path):
return self.transform_mask(imread(path))

def __getitem__(self, index):
im_name = self.images[index]
img = self.load_image(im_name)
if self.label_name == 'Real':
msk = torch.zeros(1,19,19)
else:
msk = self.load_mask(im_name.replace('Fake/', 'Mask/'))
return {'img': img, 'msk': msk, 'lab': self.label, 'im_name': im_name}

def __len__(self):
return len(self.images)

def get_batch(self):
if self.datatype == 'test':
for batch in self.loader:
yield batch
else:
while True:
for batch in self.loader:
yield batch

class Dataset:
def __init__(self, datatype, bs, img_size, map_size, norm, seed):
drop_last = datatype == 'train'
datasets = [DatasetInstance(_, DATASETS[_], datatype, img_size, map_size, norm, seed, bs, drop_last) for _ in DATASETS]
drop_last = datatype == 'train' or datatype == 'eval'
self.datasets = datasets

def get_batch(self, index = -1):
batch = None
if index == -1:
batch = [next(_.generator, None) for _ in self.datasets]
else:
batch = [next(self.datasets[index].generator, None)]
if any([_ is None for _ in batch]):
return None
img = torch.cat([_['img'] for _ in batch], dim=0).cuda()
msk = torch.cat([_['msk'] for _ in batch], dim=0).cuda()
lab = torch.cat([_['lab'] for _ in batch], dim=0).cuda()
#im_name = torch.cat([_['im_name'] for _ in batch], dim=0)
return { 'img': img, 'msk': msk, 'lab': lab }

71 changes: 71 additions & 0 deletions eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import glob
import matplotlib.pyplot as plt
import numpy as np
from scipy.io import loadmat
from sklearn import metrics
from sklearn.metrics import auc

EPOCH = '80'
RESDIR = './models/xcp_reg/results/' + EPOCH + '/'
RESFILENAMES = glob.glob(RESDIR + '*.mat')
MASK_THRESHOLD = 0.5

print('{0} result files'.format(len(RESFILENAMES)))

def compute_result_file(rfn):
rf = loadmat(rfn)
res = {}
for r in ['lab', 'msk', 'score', 'pred', 'mask']:
res[r] = rf[r].squeeze()
return res

# Compile the results into a single variable for processing
TOTAL_RESULTS = {}
for rfn in RESFILENAMES:
rf = compute_result_file(rfn)
for r in rf:
if r not in TOTAL_RESULTS:
TOTAL_RESULTS[r] = rf[r]
else:
TOTAL_RESULTS[r] = np.concatenate([TOTAL_RESULTS[r], rf[r]], axis=0)

print('Found {0} total images with scores.'.format(TOTAL_RESULTS['lab'].shape[0]))
print(' {0} results are real images'.format((TOTAL_RESULTS['lab'] == 0).sum()))
print(' {0} results are fake images'.format((TOTAL_RESULTS['lab'] == 1).sum()))
#for r in TOTAL_RESULTS:
# print('{0} has shape {1}'.format(r, TOTAL_RESULTS[r].shape))

# Compute the performance numbers
PRED_ACC = (TOTAL_RESULTS['lab'] == TOTAL_RESULTS['pred']).astype(np.float32).mean()
MASK_ACC = ((TOTAL_RESULTS['mask'] >= MASK_THRESHOLD) == (TOTAL_RESULTS['msk'] >= MASK_THRESHOLD)).astype(np.float32).mean()

FPR, TPR, THRESH = metrics.roc_curve(TOTAL_RESULTS['lab'], TOTAL_RESULTS['score'][:,1], drop_intermediate=False)
AUC = auc(FPR, TPR)
FNR = 1 - TPR
EER = FNR[np.argmin(np.absolute(FNR - FPR))]
TPR_AT_FPR_NOT_0 = TPR[FPR != 0].min()
TPR_AT_FPR_THRESHOLDS = {}
for t in range(-1, -7, -1):
thresh = 10**t
TPR_AT_FPR_THRESHOLDS[thresh] = TPR[FPR <= thresh].max()

# Print out the performance numbers
print('Prediction Accuracy: {0:.4f}'.format(PRED_ACC))
print('Mask Accuracy: {0:.4f}'.format(MASK_ACC))
print('AUC: {0:.4f}'.format(AUC))
print('EER: {0:.4f}'.format(EER))
print('Minimum TPR at FPR != 0: {0:.4f}'.format(TPR_AT_FPR_NOT_0))

print('TPR at FPR Thresholds:')
for t in TPR_AT_FPR_THRESHOLDS:
print(' {0:.10f} TPR at {1:.10f} FPR'.format(TPR_AT_FPR_THRESHOLDS[t], t))

fig = plt.figure()
plt.plot(FPR, TPR)
plt.xlabel('FPR (%)')
plt.ylabel('TPR (%)')
plt.xscale('log')
plt.xlim([10e-8,1])
plt.ylim([0, 1])
plt.grid()
plt.show()
4 changes: 0 additions & 4 deletions network/utils.py

This file was deleted.

183 changes: 0 additions & 183 deletions network/vgg.py

This file was deleted.

Loading

0 comments on commit 4205b6b

Please sign in to comment.