Merge pull request caogang#18 from robotcator/wgan-cifar10
Wgan cifar10
caogang authored Aug 25, 2017
2 parents f6b8552 + c66c66b commit 9d5d407
Showing 10 changed files with 277 additions and 0 deletions.
277 changes: 277 additions & 0 deletions gan_cifar10.py
@@ -0,0 +1,277 @@
import os, sys
sys.path.append(os.getcwd())

import time
import tflib as lib
import tflib.save_images
import tflib.mnist
import tflib.cifar10
import tflib.plot
import tflib.inception_score

import numpy as np


import torch
import torchvision
from torch import nn
from torch import autograd
from torch import optim

# Download CIFAR-10 (Python version) at
# https://www.cs.toronto.edu/~kriz/cifar.html and fill in the path to the
# extracted files here!
DATA_DIR = 'cifar-10-batches-py/'
if len(DATA_DIR) == 0:
    raise Exception('Please specify path to data directory in gan_cifar10.py!')

MODE = 'wgan-gp' # Valid options are dcgan, wgan, or wgan-gp
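# NOTE: MODE is never referenced below; this script always trains with the
# gradient penalty objective (wgan-gp), whatever this is set to.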
DIM = 128 # This overfits substantially; you're probably better off with 64
LAMBDA = 10 # Gradient penalty lambda hyperparameter
CRITIC_ITERS = 5 # How many critic iterations per generator iteration
BATCH_SIZE = 64 # Batch size
ITERS = 200000 # How many generator iterations to train for
OUTPUT_DIM = 3072 # Number of pixels in CIFAR10 (3*32*32)


class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
preprocess = nn.Sequential(
nn.Linear(128, 4 * 4 * 4 * DIM),
            nn.BatchNorm1d(4 * 4 * 4 * DIM),  # 1d: the input here is still a flat (N, C) tensor
nn.ReLU(True),
)

block1 = nn.Sequential(
nn.ConvTranspose2d(4 * DIM, 2 * DIM, 2, stride=2),
nn.BatchNorm2d(2 * DIM),
nn.ReLU(True),
)
block2 = nn.Sequential(
nn.ConvTranspose2d(2 * DIM, DIM, 2, stride=2),
nn.BatchNorm2d(DIM),
nn.ReLU(True),
)
deconv_out = nn.ConvTranspose2d(DIM, 3, 2, stride=2)

self.preprocess = preprocess
self.block1 = block1
self.block2 = block2
self.deconv_out = deconv_out
self.tanh = nn.Tanh()

def forward(self, input):
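        # Shape trace (a sketch): (N, 128) -> (N, 4*4*4*DIM) -> view to
        # (N, 4*DIM, 4, 4) -> (N, 2*DIM, 8, 8) -> (N, DIM, 16, 16) ->
        # (N, 3, 32, 32); each kernel-2, stride-2 transposed conv doubles
        # the spatial size.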
output = self.preprocess(input)
output = output.view(-1, 4 * DIM, 4, 4)
output = self.block1(output)
output = self.block2(output)
output = self.deconv_out(output)
output = self.tanh(output)
return output.view(-1, 3, 32, 32)


class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
main = nn.Sequential(
nn.Conv2d(3, DIM, 3, 2, padding=1),
nn.LeakyReLU(),
nn.Conv2d(DIM, 2 * DIM, 3, 2, padding=1),
nn.LeakyReLU(),
nn.Conv2d(2 * DIM, 4 * DIM, 3, 2, padding=1),
nn.LeakyReLU(),
)

self.main = main
self.linear = nn.Linear(4*4*4*DIM, 1)

def forward(self, input):
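        # Each stride-2 conv halves the 32x32 input (32 -> 16 -> 8 -> 4),
        # so the flattened feature vector fed to self.linear has 4*4*4*DIM entries.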
output = self.main(input)
output = output.view(-1, 4*4*4*DIM)
output = self.linear(output)
return output

netG = Generator()
netD = Discriminator()
print(netG)
print(netD)

use_cuda = torch.cuda.is_available()
gpu = 0
if use_cuda:
    netD = netD.cuda(gpu)
    netG = netG.cuda(gpu)

one = torch.FloatTensor([1])
mone = one * -1
if use_cuda:
one = one.cuda(gpu)
mone = mone.cuda(gpu)

optimizerD = optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9))
optimizerG = optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.9))
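# Adam with a low beta1 (0.5) and beta2 = 0.9; the WGAN-GP paper uses similar
# Adam settings for its CIFAR-10 experiments.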

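# WGAN-GP penalty: LAMBDA * E[(||grad_xhat D(xhat)||_2 - 1)^2], evaluated at
# random interpolates xhat = alpha * x_real + (1 - alpha) * x_fake with
# alpha ~ U[0, 1] (Gulrajani et al., 2017).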
def calc_gradient_penalty(netD, real_data, fake_data):
# print "real_data: ", real_data.size(), fake_data.size()
alpha = torch.rand(BATCH_SIZE, 1)
alpha = alpha.expand(BATCH_SIZE, real_data.nelement()/BATCH_SIZE).contiguous().view(BATCH_SIZE, 3, 32, 32)
alpha = alpha.cuda(gpu) if use_cuda else alpha

interpolates = alpha * real_data + ((1 - alpha) * fake_data)

if use_cuda:
interpolates = interpolates.cuda(gpu)
interpolates = autograd.Variable(interpolates, requires_grad=True)

disc_interpolates = netD(interpolates)

gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
grad_outputs=torch.ones(disc_interpolates.size()).cuda(gpu) if use_cuda else torch.ones(
disc_interpolates.size()),
create_graph=True, retain_graph=True, only_inputs=True)[0]

    # Flatten so the norm runs over each sample's full gradient, as in the
    # paper; taking dim=1 of the 4D tensor would penalize per-pixel norms
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
return gradient_penalty

# For generating samples
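# (assumes the output directory ./tmp/cifar10/ already exists)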
def generate_image(frame, netG):
fixed_noise_128 = torch.randn(128, 128)
if use_cuda:
fixed_noise_128 = fixed_noise_128.cuda(gpu)
noisev = autograd.Variable(fixed_noise_128, volatile=True)
samples = netG(noisev)
samples = samples.view(-1, 3, 32, 32)
samples = samples.mul(0.5).add(0.5)
samples = samples.cpu().data.numpy()

lib.save_images.save_images(samples, './tmp/cifar10/samples_{}.jpg'.format(frame))

# For calculating inception score
def get_inception_score(G):
all_samples = []
    for i in range(10):
samples_100 = torch.randn(100, 128)
if use_cuda:
samples_100 = samples_100.cuda(gpu)
samples_100 = autograd.Variable(samples_100, volatile=True)
all_samples.append(G(samples_100).cpu().data.numpy())

all_samples = np.concatenate(all_samples, axis=0)
all_samples = np.multiply(np.add(np.multiply(all_samples, 0.5), 0.5), 255).astype('int32')
all_samples = all_samples.reshape((-1, 3, 32, 32)).transpose(0, 2, 3, 1)
return lib.inception_score.get_inception_score(list(all_samples))

# Dataset iterator
train_gen, dev_gen = lib.cifar10.load(BATCH_SIZE, data_dir=DATA_DIR)
def inf_train_gen():
while True:
for images, target in train_gen():
# yield images.astype('float32').reshape(BATCH_SIZE, 3, 32, 32).transpose(0, 2, 3, 1)
yield images
gen = inf_train_gen()
preprocess = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
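# ToTensor converts each uint8 HWC image to a float CHW tensor in [0, 1];
# Normalize with mean = std = 0.5 then maps it to [-1, 1], matching the
# generator's tanh output range.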

for iteration in range(ITERS):
start_time = time.time()
############################
# (1) Update D network
###########################
for p in netD.parameters(): # reset requires_grad
p.requires_grad = True # they are set to False below in netG update
    for i in range(CRITIC_ITERS):
        _data = next(gen)
netD.zero_grad()

# train with real
_data = _data.reshape(BATCH_SIZE, 3, 32, 32).transpose(0, 2, 3, 1)
real_data = torch.stack([preprocess(item) for item in _data])

if use_cuda:
real_data = real_data.cuda(gpu)
real_data_v = autograd.Variable(real_data)

# import torchvision
# filename = os.path.join("test_train_data", str(iteration) + str(i) + ".jpg")
# torchvision.utils.save_image(real_data, filename)

D_real = netD(real_data_v)
D_real = D_real.mean()
D_real.backward(mone)
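        # backward(mone) accumulates gradients of -D_real, so the minimizing
        # optimizer step pushes D(real) up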

# train with fake
noise = torch.randn(BATCH_SIZE, 128)
if use_cuda:
noise = noise.cuda(gpu)
noisev = autograd.Variable(noise, volatile=True) # totally freeze netG
fake = autograd.Variable(netG(noisev).data)
inputv = fake
D_fake = netD(inputv)
D_fake = D_fake.mean()
D_fake.backward(one)

# train with gradient penalty
gradient_penalty = calc_gradient_penalty(netD, real_data_v.data, fake.data)
gradient_penalty.backward()

# print "gradien_penalty: ", gradient_penalty

D_cost = D_fake - D_real + gradient_penalty
Wasserstein_D = D_real - D_fake
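        # Wasserstein_D = E[D(real)] - E[D(fake)] is the critic's estimate of
        # the Wasserstein distance (logged below); D_cost is the full critic loss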
optimizerD.step()
############################
# (2) Update G network
###########################
for p in netD.parameters():
p.requires_grad = False # to avoid computation
netG.zero_grad()

noise = torch.randn(BATCH_SIZE, 128)
if use_cuda:
noise = noise.cuda(gpu)
noisev = autograd.Variable(noise)
fake = netG(noisev)
G = netD(fake)
G = G.mean()
G.backward(mone)
G_cost = -G
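    # As with the critic's real term, backward(mone) maximizes E[D(fake)]
    # w.r.t. the generator's parameters; G_cost = -E[D(fake)] is what is minimized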
optimizerG.step()

# Write logs and save samples
lib.plot.plot('./tmp/cifar10/train disc cost', D_cost.cpu().data.numpy())
lib.plot.plot('./tmp/cifar10/time', time.time() - start_time)
lib.plot.plot('./tmp/cifar10/train gen cost', G_cost.cpu().data.numpy())
lib.plot.plot('./tmp/cifar10/wasserstein distance', Wasserstein_D.cpu().data.numpy())

    # Calculate inception score every 1K iters (disabled by default; replace
    # False with True below to enable it)
    if False and iteration % 1000 == 999:
inception_score = get_inception_score(netG)
lib.plot.plot('./tmp/cifar10/inception score', inception_score[0])

# Calculate dev loss and generate samples every 100 iters
if iteration % 100 == 99:
dev_disc_costs = []
for images, _ in dev_gen():
images = images.reshape(BATCH_SIZE, 3, 32, 32).transpose(0, 2, 3, 1)
imgs = torch.stack([preprocess(item) for item in images])

# imgs = preprocess(images)
if use_cuda:
imgs = imgs.cuda(gpu)
imgs_v = autograd.Variable(imgs, volatile=True)

D = netD(imgs_v)
_dev_disc_cost = -D.mean().cpu().data.numpy()
dev_disc_costs.append(_dev_disc_cost)
lib.plot.plot('./tmp/cifar10/dev disc cost', np.mean(dev_disc_costs))

generate_image(iteration, netG)

# Save logs every 100 iters
if (iteration < 5) or (iteration % 100 == 99):
lib.plot.flush()
lib.plot.tick()
Binary file added imgs/cifar10_dev_disc_cost.jpg
Binary file added imgs/cifar10_samples_79699.jpg
Binary file added imgs/cifar10_samples_79799.jpg
Binary file added imgs/cifar10_samples_79899.jpg
Binary file added imgs/cifar10_samples_79999.jpg
Binary file added imgs/cifar10_samples_80099.jpg
Binary file added imgs/cifar10_train_disc_cost.jpg
Binary file added imgs/cifar10_train_gen_cost.jpg
Binary file added imgs/cifar_10wasserstein_distance.jpg
