-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Rewrite all code to make more user friendly
- Loading branch information
1 parent
848e1a1
commit 4205b6b
Showing
32 changed files
with
663 additions
and
3,328 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
__pycache__/ | ||
network/xception-b5690688.pth | ||
models/ | ||
xception-b5690688.pth |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
from imageio import imread | ||
import os | ||
import random | ||
import torch | ||
from torch.utils.data import DataLoader | ||
from torchvision import transforms | ||
|
||
DATABASE = '/home/jstay/FFD/data/' | ||
|
||
DATASETS = { | ||
'Real': 0, | ||
'Fake': 1 | ||
} | ||
|
||
class DatasetInstance: | ||
def __init__(self, label_name, label, datatype, img_size, map_size, norm, seed, bs, drop_last): | ||
self.img_size = img_size | ||
self.map_size = map_size | ||
|
||
|
||
self.transform = transforms.Compose([ | ||
transforms.ToPILImage(), | ||
transforms.Resize(img_size), | ||
transforms.ToTensor(), | ||
transforms.Normalize(*norm) | ||
]) | ||
|
||
self.transform_mask = transforms.Compose([ | ||
transforms.ToPILImage(), | ||
transforms.Resize(map_size), | ||
transforms.Grayscale(num_output_channels=1), | ||
transforms.ToTensor() | ||
]) | ||
|
||
self.label_name = label_name | ||
self.label = label | ||
self.datatype = datatype | ||
self.data_dir = '{0}{1}/{2}/'.format(DATABASE, self.datatype, self.label_name) | ||
files = os.listdir(self.data_dir) | ||
if self.datatype != 'test': | ||
random.Random(seed).shuffle(files) | ||
self.images = ['{0}/{1}'.format(self.data_dir, _) for _ in files] | ||
self.loader = DataLoader(self, num_workers=8, batch_size=bs, shuffle=(self.datatype != 'test'), drop_last=drop_last, pin_memory=True) | ||
self.generator = self.get_batch() | ||
|
||
print('Constructed Dataset `{0}` of size `{1}`'.format(self.data_dir, self.__len__())) | ||
|
||
def load_image(self, path): | ||
return self.transform(imread(path)) | ||
|
||
def load_mask(self, path): | ||
return self.transform_mask(imread(path)) | ||
|
||
def __getitem__(self, index): | ||
im_name = self.images[index] | ||
img = self.load_image(im_name) | ||
if self.label_name == 'Real': | ||
msk = torch.zeros(1,19,19) | ||
else: | ||
msk = self.load_mask(im_name.replace('Fake/', 'Mask/')) | ||
return {'img': img, 'msk': msk, 'lab': self.label, 'im_name': im_name} | ||
|
||
def __len__(self): | ||
return len(self.images) | ||
|
||
def get_batch(self): | ||
if self.datatype == 'test': | ||
for batch in self.loader: | ||
yield batch | ||
else: | ||
while True: | ||
for batch in self.loader: | ||
yield batch | ||
|
||
class Dataset: | ||
def __init__(self, datatype, bs, img_size, map_size, norm, seed): | ||
drop_last = datatype == 'train' | ||
datasets = [DatasetInstance(_, DATASETS[_], datatype, img_size, map_size, norm, seed, bs, drop_last) for _ in DATASETS] | ||
drop_last = datatype == 'train' or datatype == 'eval' | ||
self.datasets = datasets | ||
|
||
def get_batch(self, index = -1): | ||
batch = None | ||
if index == -1: | ||
batch = [next(_.generator, None) for _ in self.datasets] | ||
else: | ||
batch = [next(self.datasets[index].generator, None)] | ||
if any([_ is None for _ in batch]): | ||
return None | ||
img = torch.cat([_['img'] for _ in batch], dim=0).cuda() | ||
msk = torch.cat([_['msk'] for _ in batch], dim=0).cuda() | ||
lab = torch.cat([_['lab'] for _ in batch], dim=0).cuda() | ||
#im_name = torch.cat([_['im_name'] for _ in batch], dim=0) | ||
return { 'img': img, 'msk': msk, 'lab': lab } | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import glob | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
from scipy.io import loadmat | ||
from sklearn import metrics | ||
from sklearn.metrics import auc | ||
|
||
EPOCH = '80' | ||
RESDIR = './models/xcp_reg/results/' + EPOCH + '/' | ||
RESFILENAMES = glob.glob(RESDIR + '*.mat') | ||
MASK_THRESHOLD = 0.5 | ||
|
||
print('{0} result files'.format(len(RESFILENAMES))) | ||
|
||
def compute_result_file(rfn): | ||
rf = loadmat(rfn) | ||
res = {} | ||
for r in ['lab', 'msk', 'score', 'pred', 'mask']: | ||
res[r] = rf[r].squeeze() | ||
return res | ||
|
||
# Compile the results into a single variable for processing | ||
TOTAL_RESULTS = {} | ||
for rfn in RESFILENAMES: | ||
rf = compute_result_file(rfn) | ||
for r in rf: | ||
if r not in TOTAL_RESULTS: | ||
TOTAL_RESULTS[r] = rf[r] | ||
else: | ||
TOTAL_RESULTS[r] = np.concatenate([TOTAL_RESULTS[r], rf[r]], axis=0) | ||
|
||
print('Found {0} total images with scores.'.format(TOTAL_RESULTS['lab'].shape[0])) | ||
print(' {0} results are real images'.format((TOTAL_RESULTS['lab'] == 0).sum())) | ||
print(' {0} results are fake images'.format((TOTAL_RESULTS['lab'] == 1).sum())) | ||
#for r in TOTAL_RESULTS: | ||
# print('{0} has shape {1}'.format(r, TOTAL_RESULTS[r].shape)) | ||
|
||
# Compute the performance numbers | ||
PRED_ACC = (TOTAL_RESULTS['lab'] == TOTAL_RESULTS['pred']).astype(np.float32).mean() | ||
MASK_ACC = ((TOTAL_RESULTS['mask'] >= MASK_THRESHOLD) == (TOTAL_RESULTS['msk'] >= MASK_THRESHOLD)).astype(np.float32).mean() | ||
|
||
FPR, TPR, THRESH = metrics.roc_curve(TOTAL_RESULTS['lab'], TOTAL_RESULTS['score'][:,1], drop_intermediate=False) | ||
AUC = auc(FPR, TPR) | ||
FNR = 1 - TPR | ||
EER = FNR[np.argmin(np.absolute(FNR - FPR))] | ||
TPR_AT_FPR_NOT_0 = TPR[FPR != 0].min() | ||
TPR_AT_FPR_THRESHOLDS = {} | ||
for t in range(-1, -7, -1): | ||
thresh = 10**t | ||
TPR_AT_FPR_THRESHOLDS[thresh] = TPR[FPR <= thresh].max() | ||
|
||
# Print out the performance numbers | ||
print('Prediction Accuracy: {0:.4f}'.format(PRED_ACC)) | ||
print('Mask Accuracy: {0:.4f}'.format(MASK_ACC)) | ||
print('AUC: {0:.4f}'.format(AUC)) | ||
print('EER: {0:.4f}'.format(EER)) | ||
print('Minimum TPR at FPR != 0: {0:.4f}'.format(TPR_AT_FPR_NOT_0)) | ||
|
||
print('TPR at FPR Thresholds:') | ||
for t in TPR_AT_FPR_THRESHOLDS: | ||
print(' {0:.10f} TPR at {1:.10f} FPR'.format(TPR_AT_FPR_THRESHOLDS[t], t)) | ||
|
||
fig = plt.figure() | ||
plt.plot(FPR, TPR) | ||
plt.xlabel('FPR (%)') | ||
plt.ylabel('TPR (%)') | ||
plt.xscale('log') | ||
plt.xlim([10e-8,1]) | ||
plt.ylim([0, 1]) | ||
plt.grid() | ||
plt.show() |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.