Skip to content

Commit

Permalink
Update train_cycleGAN_loss.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Tandon-A authored Nov 22, 2018
1 parent ceff625 commit e915f41
Showing 1 changed file with 35 additions and 34 deletions.
69 changes: 35 additions & 34 deletions train_cycleGAN_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import random
from PIL import Image
import os
import scipy.misc

"""
Import CycleGAN class definition.
Expand All @@ -23,6 +22,25 @@ def get_image_new(image_path,width,height):
image = np.subtract(image,0.5)
image = np.multiply(image,2)
return image

"""
Function to load training images.
"""
def get_data(trainA,trainB,width,height):
tr_A = []
tr_B = []
for i in range(len(trainA)):
tr_A.append(get_image_new(trainA[i],width,height))
if i % 200 == 0:
print ("getting trainA = %r" %(i))
for i in range(len(trainB)):
tr_B.append(get_image_new(trainB[i],width,height))
if i % 200 == 0:
print ("getting trainB = %r" %(i))
tr_A = np.array(tr_A)
tr_B = np.array(tr_B)
print ("Completed loading training data. DomainA = %r , DomainB = %r" %(tr_A.shape,tr_B.shape))
return tr_A,tr_B


"""
Expand All @@ -33,6 +51,7 @@ def save_to_pool(poolA,poolB,gen_A,gen_B,pool_size,num_im):
if num_im < pool_size:
poolA[num_im] = gen_A
poolB[num_im] = gen_B
num_im = num_im + 1

else:
p = random.random()
Expand All @@ -43,14 +62,13 @@ def save_to_pool(poolA,poolB,gen_A,gen_B,pool_size,num_im):
if p > 0.5:
indB = random.randint(0,pool_size-1)
poolB[indB] = gen_B

num_im = num_im + 1

return poolA,poolB,num_im

"""
Function to train the network
"""
def train(cgan_net,max_img,batch_size,trainA,trainB,lr_rate,shape,pool_size,model_dir,image_dir):
def train(cgan_net,max_img,batch_size,trainA,trainB,lr_rate,shape,pool_size,model_dir):
saver = tf.train.Saver(max_to_keep=None)
lenA = len(trainA)
lenB = len(trainB)
Expand All @@ -70,16 +88,16 @@ def train(cgan_net,max_img,batch_size,trainA,trainB,lr_rate,shape,pool_size,mode

if countA >= lenA:
countA = 0
random.shuffle(trainA)
np.random.shuffle(trainA)

if countB >= lenB:
countB = 0
random.shuffle(trainB)
np.random.shuffle(trainB)


imgA = get_image_new(trainA[countA],shape[0],shape[1])
imgA = trainA[countA]
countA = countA + 1
imgB = get_image_new(trainB[countB],shape[0],shape[1])
imgB = trainB[countB]
countB = countB + 1

imgA = np.reshape(imgA,(1,shape[0],shape[1],shape[2]))
Expand All @@ -102,30 +120,15 @@ def train(cgan_net,max_img,batch_size,trainA,trainB,lr_rate,shape,pool_size,mode
feed_dict={cgan_net.input_A:imgA,cgan_net.input_B:imgB,cgan_net.lr_rate:lr_rate,cgan_net.fake_pool_Aimg:fakeA_img,cgan_net.fake_pool_Bimg:fakeB_img})



if step % 50 == 0:
#tlogging training loss details
if step % 50 == 0 and epoch % 5 == 0:
print ("epoch = %r step = %r discA_loss = %r genA_loss = %r discB_loss = %r genB_loss = %r"
%(epoch,step,discA_loss,genA_loss,discB_loss,genB_loss))

if step % 150 == 0:
images = [genA,cyclicB,genB,cyclicA]
img_ind = 0
for img in images:
img = np.reshape(img,(shape[0],shape[1],shape[2]))
if np.array_equal(img.max(),img.min()) == False:
img = (((img - img.min())*255)/(img.max()-img.min())).astype(np.uint8)
else:
img = ((img - img.min())*255).astype(np.uint8)
scipy.misc.toimage(img, cmin=0.0, cmax=...).save(images_dir+"\\img_"+str(img_ind)+"_"+str(epoch)+"_"+str(step)+".jpg")
img_ind = img_ind + 1

if epoch % 10 == 0:
#change the second argument of save function to the path where you want to save model weights
saver.save(sess,model_dir+"try_"+str(epoch)+"\\",write_meta_graph=True)
print ("### Model weights Saved epoch = %r ###" %(epoch))


epoch = epoch + 1

saver.save(sess,model_dir,write_meta_graph=True)
print ("### Model weights Saved epoch = %r ###" %(epoch))



Expand All @@ -137,11 +140,10 @@ def main(_):

if not os.path.exists(FLAGS.model_dir):
os.makedirs(FLAGS.model_dir)
if not os.path.exists(FLAGS.sampled_images_dir):
os.makedirs(FLAGS.sampled_images_dir)

trainA = glob(FLAGS.data_path+"\\trainA\\"+FLAGS.input_fname_pattern)
trainB = glob(FLAGS.data_path+"\\trainB\\"+FLAGS.input_fname_pattern)
trainA = glob(FLAGS.data_path+"//trainA//"+FLAGS.input_fname_pattern)
trainB = glob(FLAGS.data_path+"//trainB//"+FLAGS.input_fname_pattern)
tr_imgA, tr_imgB = get_data(trainA,trainB,128,128)
input_shape = 128,128,3
batch_size = 1
pool_size = 50
Expand All @@ -154,15 +156,14 @@ def main(_):

cgan_net = CycleGAN(batch_size,input_shape,pool_size,beta1,loss_type)

train(cgan_net,max_img,batch_size,trainA,trainB,lr_rate,input_shape,pool_size,FLAGS.model_dir,FLAGS.sampled_images_dir)
train(cgan_net,max_img,batch_size,tr_imgA,tr_imgB,lr_rate,input_shape,pool_size,FLAGS.model_dir+"//")



flags = tf.app.flags
flags.DEFINE_string("data_path",None,"Path to parent directory of trainA and trainB folder")
flags.DEFINE_string("input_fname_pattern","*.jpg","Glob pattern of training images")
flags.DEFINE_string("model_dir","CycleGAN_model","Directory name to save checkpoints")
flags.DEFINE_string("sampled_images_dir","sampled_images","Directory where images sampled from the generator (while training the model) are stored")
flags.DEFINE_string("loss_type","l1","Loss type with which cycleGAN is to be trained")
FLAGS = flags.FLAGS

Expand Down

0 comments on commit e915f41

Please sign in to comment.