'''Train a lung-nodule classifier (ConvRes) on pre-cropped LUNA volumes with PyTorch.'''
from __future__ import print_function

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import transforms as transforms
from dataloader import lunanod
import os
import argparse
import time
from models.cnn_res import *
# from utils import progress_bar
from torch.autograd import Variable
import logging
import numpy as np
import ast

# Command-line options for a nodule-classification training run.
# Flag names, types and defaults are part of the script's interface.
parser = argparse.ArgumentParser(description='PyTorch LUNA nodule classification training')
parser.add_argument('--lr', default=0.0002, type=float, help='learning rate')
parser.add_argument('--batch_size', default=1, type=int, help='batch size')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
parser.add_argument('--savemodel', type=str, default='',
                    help='path of a checkpoint file to load: resumed from when --resume is given, '
                         'otherwise used to initialise the matching weights of a new model')
parser.add_argument("--gpuids", type=str, default='all', help="GPU ids to use, e.g. '0,1', or 'all'")

parser.add_argument('--num_epochs', type=int, default=700, help='number of epochs to train')
parser.add_argument('--num_epochs_decay', type=int, default=70,
                    help='number of final epochs during which the learning rate is decayed')

parser.add_argument('--num_workers', type=int, default=24)

parser.add_argument('--beta1', type=float, default=0.5)  # momentum1 in Adam
parser.add_argument('--beta2', type=float, default=0.999)  # momentum2 in Adam
parser.add_argument('--lamb', type=float, default=1, help="lambda for loss2")
parser.add_argument('--fold', type=int, default=5, help='index of the LUNA subset held out for testing')

args = parser.parse_args()

# ---- Run configuration -------------------------------------------------
CROPSIZE = 32
gbtdepth = 1
fold = args.fold
# Series ids (file names without '.npy') that must be skipped, e.g.
# ['1.3.6.1.4.1.14519.5.2.1.6279.6001.121993590721161347818774929286-388',
#  '1.3.6.1.4.1.14519.5.2.1.6279.6001.121993590721161347818774929286-389',
#  '1.3.6.1.4.1.14519.5.2.1.6279.6001.132817748896065918417924920957-660']
blklst = []
logging.basicConfig(filename='log-' + str(fold), level=logging.INFO)

use_cuda = torch.cuda.is_available()
best_acc = 0  # best test accuracy
best_acc_gbt = 0
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# Directory holding one pre-cropped '.npy' volume per nodule.
preprocesspath = '/data/xxx/LUNA/cls/crop_v3/'

# ---- Global pixel mean / std over every non-blacklisted crop ------------
npy_paths = [
    os.path.join(preprocesspath, entry)
    for entry in os.listdir(preprocesspath)
    if entry.endswith('.npy') and entry[:-4] not in blklst
]

# First pass: total intensity and voxel count give the mean.
pixvlu, npix = 0, 0
for npy_path in npy_paths:
    volume = np.load(npy_path)
    pixvlu += np.sum(volume)
    npix += np.prod(volume.shape)
pixmean = pixvlu / float(npix)

# Second pass: accumulate squared deviations from that mean for the std.
pixvlu = 0
for npy_path in npy_paths:
    centred = np.load(npy_path) - pixmean
    pixvlu += np.sum(centred * centred)
pixstd = np.sqrt(pixvlu / float(npix))

print(pixmean, pixstd)
logging.info('mean ' + str(pixmean) + ' std ' + str(pixstd))
# Data transforms
logging.info('==> Preparing data..')

# Training pipeline: the augmentation steps run first, then the sample is
# converted to a tensor and normalised with the scalar pixel statistics.
# (A RandomScale(range(28, 38)) step is currently disabled.)
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomYFlip(),
    transforms.RandomZFlip(),
    transforms.ZeroOut(4),
    transforms.ToTensor(),
    # Normalize comes from the project-local transforms module and is given
    # scalar mean / std values here.
    transforms.Normalize(pixmean, pixstd),
])

# Evaluation pipeline: no augmentation, only conversion and normalisation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(pixmean, pixstd),
])

# Containers for the train / test split: file names, labels and features.
trfnamelst, trlabellst, trfeatlst = [], [], []
tefnamelst, telabellst, tefeatlst = [], [], []
import pandas as pd

dataframe = pd.read_csv('./data/annotationdetclsconvfnl_v3.csv',
                        names=['seriesuid', 'coordX', 'coordY', 'coordZ', 'diameter_mm', 'malignant'])

# Because explicit column names are supplied, the first line of the file is
# parsed as a data row; it is dropped here (presumably the CSV header line).
records = dataframe.iloc[1:]
alllst = records['seriesuid'].tolist()
labellst = records['malignant'].tolist()
crdxlst = records['coordX'].tolist()
crdylst = records['coordY'].tolist()
crdzlst = records['coordZ'].tolist()
dimlst = records['diameter_mm'].tolist()

# Test ids: base names of every '.mhd' file found in the held-out subset.
teidlst = [
    entry[:-4]
    for entry in os.listdir('/data/xxx/LUNA/rowfile/subset' + str(fold) + '/')
    if entry.endswith('.mhd')
]
# Build one CROPSIZE^3 mask feature per annotated nodule and route it to the
# train or test lists depending on whether its series belongs to the test fold.
mxx = mxy = mxz = mxd = 0

# The coordinate grid is the same for every nodule, so it is built once.
grid_y, grid_x, grid_z = np.ogrid[-CROPSIZE / 2:CROPSIZE / 2,
                                  -CROPSIZE / 2:CROPSIZE / 2,
                                  -CROPSIZE / 2:CROPSIZE / 2]
# NOTE(review): this sums signed cubes, so the thresholded region is not a
# Euclidean ball; confirm this is intended (a sphere would use squares).
cubic_grid = abs(grid_y ** 3 + grid_x ** 3 + grid_z ** 3)

# Set membership keeps the per-nodule fold lookup O(1).
test_series = set(teidlst)

for srsid, label, crdx, crdy, crdz, diam in zip(alllst, labellst, crdxlst, crdylst, crdzlst, dimlst):
    # Track the largest absolute coordinate / diameter seen in the annotations.
    mxx = max(abs(float(crdx)), mxx)
    mxy = max(abs(float(crdy)), mxy)
    mxz = max(abs(float(crdz)), mxz)
    mxd = max(abs(float(diam)), mxd)
    if srsid in blklst:
        continue

    # The crop itself is not needed to build the mask, so it is not loaded
    # here; only its presence is checked so a missing file still fails early.
    fname = srsid + '.npy'
    if not os.path.isfile(os.path.join(preprocesspath, fname)):
        raise IOError('missing preprocessed crop: ' + os.path.join(preprocesspath, fname))

    # 1.0 where the grid value is within the diameter-derived threshold, else 0.0.
    feat = (cubic_grid <= abs(float(diam)) ** 3).astype(float)

    # A nodule id looks like '<seriesuid>-<index>'; the fold is decided by the series.
    if srsid.split('-')[0] in test_series:
        tefnamelst.append(fname)
        telabellst.append(int(label))
        tefeatlst.append(feat)
    else:
        trfnamelst.append(fname)
        trlabellst.append(int(label))
        trfeatlst.append(feat)

# NOTE(review): feat is a 3-D mask, so feat[-1] is its last plane along the
# first axis; dividing that plane by the maximum diameter looks like a
# leftover from an earlier flat feature vector whose last entry was the
# diameter. Behaviour is kept as-is -- confirm whether it should be removed.
for feat in trfeatlst:
    feat[-1] /= mxd
for feat in tefeatlst:
    feat[-1] /= mxd
# Datasets and loaders for the train / test split.
# The worker count comes from --num_workers so the command-line flag takes effect.
trainset = lunanod(preprocesspath, trfnamelst, trlabellst, trfeatlst, train=True, download=True,
                   transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True,
                                          num_workers=args.num_workers)

testset = lunanod(preprocesspath, tefnamelst, telabellst, tefeatlst, train=False, download=True,
                  transform=transform_test)
# The test loader is not shuffled, so evaluation order is stable across epochs.
testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=False,
                                         num_workers=args.num_workers)
savemodelpath = './checkpoint-' + str(fold) + '/'
# Without a GPU, remap checkpoint tensors to the CPU so checkpoints saved on a
# CUDA machine can still be loaded; with a GPU the default placement is kept.
ckpt_map_location = None if use_cuda else 'cpu'

# Model
# NOTE: torch.load unpickles arbitrary Python objects -- only load checkpoints
# from a trusted source.
print(args.resume)
if args.resume:
    print('==> Resuming from checkpoint..')
    print(args.savemodel)
    logging.info('==> Resuming from checkpoint..')
    # Resume from the explicit file if one was given, else from this fold's
    # best checkpoint.
    if args.savemodel == '':
        resume_path = savemodelpath + 'ckpt.t7'
    else:
        resume_path = args.savemodel
    # Validate the file that is actually going to be read.
    if not os.path.isfile(resume_path):
        raise IOError('Error: no checkpoint found at ' + resume_path)
    checkpoint = torch.load(resume_path, map_location=ckpt_map_location)
    net = checkpoint['net']
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']
    print(resume_path + " load success")
    print(start_epoch)
else:
    logging.info('==> Building model..')
    logging.info('args.savemodel : ' + args.savemodel)
    net = ConvRes([[64, 64, 64], [128, 128, 256], [256, 256, 256, 512]])
    if args.savemodel != "":
        # Initialise the new network from an existing model: every tensor
        # whose name also exists in the new network is copied over.
        checkpoint = torch.load(args.savemodel, map_location=ckpt_map_location)
        pretrained = checkpoint
        # Checkpoints written by this script are dicts holding the whole
        # module under 'net'; unwrap them so their weights are actually used
        # instead of silently matching no keys.
        wrapped = pretrained.get('net') if isinstance(pretrained, dict) else None
        if isinstance(wrapped, (nn.Module, dict)):
            pretrained = wrapped
        if isinstance(pretrained, nn.Module):
            pretrained = pretrained.state_dict()
        Low_rankmodel_dic = net.state_dict()
        finenet = {k: v for k, v in pretrained.items() if k in Low_rankmodel_dic}
        Low_rankmodel_dic.update(finenet)
        net.load_state_dict(Low_rankmodel_dic)
        # Report how much was transferred so an empty match is visible.
        logging.info('loaded ' + str(len(finenet)) + ' of ' + str(len(Low_rankmodel_dic))
                     + ' tensors from ' + args.savemodel)
        print("net_loaded")

# Current learning rate; decayed in place by get_lr().
lr = args.lr


def get_lr(epoch):
    """Decay the global learning rate during the final epochs of training.

    Despite its name this function returns nothing: once ``epoch + 1``
    exceeds ``args.num_epochs - args.num_epochs_decay`` it lowers the
    module-level ``lr`` by ``lr / args.num_epochs_decay`` and writes the new
    value into every parameter group of the global ``optimizer``.

    Args:
        epoch: index of the epoch that is about to be trained.
    """
    global lr
    # A non-positive decay length means "no decay"; it would otherwise divide
    # by zero (or raise the learning rate) once the threshold is crossed.
    if args.num_epochs_decay <= 0:
        return
    if (epoch + 1) > (args.num_epochs - args.num_epochs_decay):
        lr -= (lr / float(args.num_epochs_decay))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        print('Decay learning rate to lr: {}.'.format(lr))


# Device placement: wrap the network in DataParallel over the selected GPUs.
if use_cuda:
    if args.gpuids == 'all':
        device_ids = list(range(torch.cuda.device_count()))
    else:
        # Must be a real list: DataParallel indexes device_ids[0], which a lazy
        # map object (Python 3) does not support.
        # Every digit character is taken as one GPU id, so only ids 0-9 can be
        # selected this way.
        device_ids = [int(ch) for ch in args.gpuids if ch.isdigit()]
    if not device_ids:
        raise ValueError('no GPU id could be derived from --gpuids ' + repr(args.gpuids))

    print('gpu use' + str(device_ids))
    # DataParallel expects the module's parameters on device_ids[0].
    net.cuda(device_ids[0])
    net = torch.nn.DataParallel(net, device_ids=device_ids)
    cudnn.benchmark = False  # True

# Loss and optimizer; the optimizer is created after the model has been moved
# to its device.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))


# L2Loss = torch.nn.MSELoss()

# Training
def train(epoch):
    """Train the global ``net`` for one epoch and report the training accuracy.

    Applies the learning-rate schedule via ``get_lr``, runs one optimisation
    step per batch of ``trainloader`` and prints/logs the fraction of
    correctly classified samples together with the current learning rate.

    Args:
        epoch: index of the epoch, used for the schedule and in the log line.
    """
    logging.info('\nEpoch: ' + str(epoch))
    net.train()
    get_lr(epoch)
    # Plain Python counters, so the summary below also works when the loader
    # yields no batches.
    correct = 0
    total = 0

    # The third element of each batch (the per-sample feature) is not used here.
    for inputs, targets, _ in trainloader:
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()

        # Predicted class = index of the largest output along dimension 1.
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).sum().item()

    # Report 0.0 instead of dividing by zero when nothing was seen.
    tracc = correct / float(total) if total else 0.0
    summary = 'ep ' + str(epoch) + ' tracc ' + str(tracc) + ' lr ' + str(lr)
    print(summary)
    logging.info(summary)


def test(epoch):
    """Evaluate the global ``net`` on ``testloader`` and write checkpoints.

    Computes accuracy plus true/false positive rates (class 1 is treated as
    the positive class). The model is saved to ``ckpt.t7`` whenever the
    accuracy beats ``best_acc`` and additionally to ``ckpt<epoch>.t7`` every
    50th epoch. Rates whose denominator is zero are reported as 0.0.

    Args:
        epoch: index of the epoch, stored in the checkpoint and used for the
            periodic checkpoint file name.
    """
    epoch_start_time = time.time()
    global best_acc
    net.eval()
    correct = 0
    total = 0
    TP = FP = FN = TN = 0

    # No gradients are needed for evaluation; skipping them saves memory.
    with torch.no_grad():
        # The third element of each batch (the per-sample feature) is not used here.
        for inputs, targets, _ in testloader:
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()

            outputs = net(inputs)
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            TP += ((predicted == 1) & (targets == 1)).sum().item()
            TN += ((predicted == 0) & (targets == 0)).sum().item()
            FN += ((predicted == 0) & (targets == 1)).sum().item()
            FP += ((predicted == 1) & (targets == 0)).sum().item()

    acc = 100. * correct / total if total else 0.0

    # Save checkpoint.
    state = {
        # Store the bare module: under CUDA the network is wrapped in DataParallel.
        'net': net.module if use_cuda else net,
        'acc': acc,
        'epoch': epoch,
    }
    if not os.path.isdir(savemodelpath):
        os.mkdir(savemodelpath)
    if acc > best_acc:
        logging.info('Saving..')
        torch.save(state, savemodelpath + 'ckpt.t7')
        best_acc = acc
    if epoch % 50 == 0:
        logging.info('Saving..')
        torch.save(state, savemodelpath + 'ckpt' + str(epoch) + '.t7')

    # Guard the rates against a test set that lacks one of the classes.
    tpr = 100. * TP / (TP + FN) if (TP + FN) else 0.0
    fpr = 100. * FP / (FP + TN) if (FP + TN) else 0.0

    print('teacc ' + str(acc) + ' bestacc ' + str(best_acc))
    print('tpr ' + str(tpr) + ' fpr ' + str(fpr))
    print('Time Taken: %d sec' % (time.time() - epoch_start_time))
    logging.info(
        'teacc ' + str(acc) + ' bestacc ' + str(best_acc))
    logging.info(
        'tpr ' + str(tpr) + ' fpr ' + str(fpr))


if __name__ == '__main__':
    # Epoch numbering continues after the last completed epoch (0 for a fresh
    # run) and covers args.num_epochs epochs; each one trains, then evaluates.
    first_epoch = start_epoch + 1
    for current_epoch in range(first_epoch, first_epoch + args.num_epochs):
        train(current_epoch)
        test(current_epoch)