Commit 0e27148: train function
MatthieuCourbariaux committed Sep 2, 2015
1 parent 11f9e33 commit 0e27148
Showing 2 changed files with 110 additions and 112 deletions.
94 changes: 93 additions & 1 deletion binary_connect.py
@@ -1,4 +1,6 @@

import time

from collections import OrderedDict

import numpy as np
@@ -114,4 +116,94 @@ def get_output_for(self, input, deterministic=False, **kwargs):

if self.b is not None:
activation = activation + self.b.dimshuffle('x', 0)
return self.nonlinearity(activation)
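
Aside, not part of the diff: in the context lines above, self.b.dimshuffle('x', 0) is Theano's way of turning the bias vector of shape (num_units,) into a (1, num_units) tensor with a broadcastable leading axis, so it is added to every row of the (batch_size, num_units) activation. A numpy analogue of the same broadcast, with illustrative shapes (only batch_size = 200 comes from mnist_mlp.py):

import numpy as np

batch_size, num_units = 200, 1024              # illustrative shapes
activation = np.zeros((batch_size, num_units)) # pre-activation of one layer
b = np.ones(num_units)                         # bias vector
out = activation + b[np.newaxis, :]            # numpy analogue of b.dimshuffle('x', 0)
print(out.shape)                               # (200, 1024)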

def train(train_fn,val_fn,
batch_size,
LR_start,LR_decay,
num_epochs,
X_train,y_train,
X_val,y_val,
X_test,y_test):

def shuffle(X,y):

shuffled_range = range(len(X))
np.random.shuffle(shuffled_range)
# print(shuffled_range[0:10])

new_X = np.copy(X)
new_y = np.copy(y)

for i in range(len(X)):

new_X[i] = X[shuffled_range[i]]
new_y[i] = y[shuffled_range[i]]

return new_X,new_y

def train_epoch(X,y,LR):

loss = 0
batches = len(X)/batch_size

for i in range(batches):
loss += train_fn(X[i*batch_size:(i+1)*batch_size],y[i*batch_size:(i+1)*batch_size],LR)

loss/=batches

return loss

def val_epoch(X,y):

err = 0
loss = 0
batches = len(X)/batch_size

for i in range(batches):
new_loss, new_err = val_fn(X[i*batch_size:(i+1)*batch_size], y[i*batch_size:(i+1)*batch_size])
err += new_err
loss += new_loss

err = err / batches * 100
loss /= batches

return err, loss

# shuffle the train set
X_train,y_train = shuffle(X_train,y_train)
best_val_err = 100
best_epoch = 1
LR = LR_start

# We iterate over epochs:
for epoch in range(num_epochs):

start_time = time.time()

train_loss = train_epoch(X_train,y_train,LR)
X_train,y_train = shuffle(X_train,y_train)
LR *= LR_decay

val_err, val_loss = val_epoch(X_val,y_val)

# test if validation error went down
if val_err <= best_val_err:

best_val_err = val_err
best_epoch = epoch+1

test_err, test_loss = val_epoch(X_test,y_test)

epoch_duration = time.time() - start_time

# Then we print the results for this epoch:
print("Epoch "+str(epoch + 1)+" of "+str(num_epochs)+" took "+str(epoch_duration)+"s")
print(" LR: "+str(LR))
print(" training loss: "+str(train_loss))
print(" validation loss: "+str(val_loss))
print(" validation error rate: "+str(val_err)+"%")
print(" best epoch: "+str(best_epoch))
print(" best validation error rate: "+str(best_val_err)+"%")
print(" test loss: "+str(test_loss))
print(" test error rate: "+str(test_err)+"%")
128 changes: 17 additions & 111 deletions mnist_mlp.py
@@ -22,10 +22,10 @@
import binary_connect

if __name__ == "__main__":
# def MNIST_exp(dropout_in,dropout_hidden,binary,stochastic,H):

# BN parameters
batch_size = 200
# alpha is the exponential moving average factor
# alpha = .1 # for a minibatch of size 50
# alpha = .2 # for a minibatch of size 100
alpha = .33 # for a minibatch of size 200
@@ -36,7 +36,7 @@
n_hidden_layers = 3

# Training parameters
num_epochs = 100
num_epochs = 1000

# Dropout parameters
dropout_in = 0.
@@ -51,8 +51,8 @@
# H = 1.

# LR decay
LR_start = .001
LR_fin = .00001
LR_start = 3.
LR_fin = .1
LR_decay = (LR_fin/LR_start)**(1./num_epochs)
# BTW, LR decay is good for the moving average...
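
Aside, not part of the diff: the decay factor above gives a geometric learning-rate schedule. Since binary_connect.train() multiplies LR by LR_decay once per epoch, LR moves from LR_start to LR_fin over exactly num_epochs epochs, which a few lines can confirm:

# quick check of the schedule implied by LR_decay = (LR_fin/LR_start)**(1./num_epochs)
LR_start, LR_fin, num_epochs = 3., .1, 1000    # the new values from this file
LR_decay = (LR_fin / LR_start) ** (1. / num_epochs)
LR = LR_start
for epoch in range(num_epochs):
    LR *= LR_decay                             # same per-epoch update as in binary_connect.train()
print(LR)                                      # ~0.1, i.e. LR_fin up to rounding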

@@ -159,14 +159,15 @@

if binary:
grads = binary_connect.compute_grads(loss,mlp)
updates = lasagne.updates.adam(loss_or_grads=grads, params=params, learning_rate=LR)
# updates = lasagne.updates.sgd(grads, params, learning_rate=.3)
# updates = lasagne.updates.adam(loss_or_grads=grads, params=params, learning_rate=LR)
updates = lasagne.updates.sgd(loss_or_grads=grads, params=params, learning_rate=LR)
# updates = binary_connect.weights_clipping(updates,H)
updates = binary_connect.weights_clipping(updates,mlp)
# using 2H instead of H with stochastic yields about 20% relative worse results

else:
updates = lasagne.updates.adam(loss_or_grads=loss, params=params)
# updates = lasagne.updates.adam(loss_or_grads=loss, params=params, learning_rate=LR)
updates = lasagne.updates.sgd(loss_or_grads=loss, params=params, learning_rate=LR)
# updates = lasagne.updates.nesterov_momentum(loss, params, learning_rate=0.01, momentum=0.9)

test_output = lasagne.layers.get_output(mlp, deterministic=True)
@@ -183,87 +184,14 @@

print('Training...')

def shuffle(X,y):

shuffled_range = range(len(X))
np.random.shuffle(shuffled_range)
# print(shuffled_range[0:10])

new_X = np.copy(X)
new_y = np.copy(y)

for i in range(len(X)):

new_X[i] = X[shuffled_range[i]]
new_y[i] = y[shuffled_range[i]]

return new_X,new_y

def train_epoch(X,y,batch_size,LR):

loss = 0
batches = len(X)/batch_size

for i in range(batches):
loss += train_fn(X[i*batch_size:(i+1)*batch_size],y[i*batch_size:(i+1)*batch_size],LR)

loss/=batches

return loss

def val_epoch(X,y,batch_size):

err = 0
loss = 0
batches = len(X)/batch_size

for i in range(batches):
new_loss, new_err = val_fn(X[i*batch_size:(i+1)*batch_size], y[i*batch_size:(i+1)*batch_size])
err += new_err
loss += new_loss

err = err / batches * 100
loss /= batches

return err, loss

# shuffle the train set
X_train,y_train = shuffle(X_train,y_train)
best_val_err = 100
best_epoch = 1
LR = LR_start

# We iterate over epochs:
for epoch in range(num_epochs):

start_time = time.time()

train_loss = train_epoch(X_train,y_train,batch_size,LR)
X_train,y_train = shuffle(X_train,y_train)
LR *= LR_decay

val_err, val_loss = val_epoch(X_val,y_val,batch_size)

# test if validation error went down
if val_err <= best_val_err:

best_val_err = val_err
best_epoch = epoch+1

test_err, test_loss = val_epoch(X_test,y_test,batch_size)

epoch_duration = time.time() - start_time

# Then we print the results for this epoch:
print("Epoch "+str(epoch + 1)+" of "+str(num_epochs)+" took "+str(epoch_duration)+"s")
print(" LR: "+str(LR))
print(" training loss: "+str(train_loss))
print(" validation loss: "+str(val_loss))
print(" validation error rate: "+str(val_err)+"%")
print(" best epoch: "+str(best_epoch))
print(" best validation error rate: "+str(best_val_err)+"%")
print(" test loss: "+str(test_loss))
print(" test error rate: "+str(test_err)+"%")
binary_connect.train(
train_fn,val_fn,
batch_size,
LR_start,LR_decay,
num_epochs,
X_train,y_train,
X_val,y_val,
X_test,y_test)

# print("display histogram")

@@ -275,26 +203,4 @@ def val_epoch(X,y,batch_size):
# np.savetxt(str(dropout_hidden)+str(binary)+str(stochastic)+str(H)+"_hist1.csv", histogram[1], delimiter=",")

# Optionally, you could now dump the network weights to a file like this:
# np.savez('model.npz', lasagne.layers.get_all_param_values(network))

# if __name__ == "__main__":
if False:

# baselines
MNIST_exp(dropout_in=0.,dropout_hidden=0.,binary=False,stochastic=False,H=1.)
MNIST_exp(dropout_in=0.2,dropout_hidden=0.5,binary=False,stochastic=False,H=1.)

# stochastic BC
MNIST_exp(dropout_in=0.,dropout_hidden=0.,binary=True,stochastic=True,H=1.)
MNIST_exp(dropout_in=0.,dropout_hidden=0.,binary=True,stochastic=True,H=1./(1<<2))
MNIST_exp(dropout_in=0.,dropout_hidden=0.,binary=True,stochastic=True,H=1./(1<<4))
MNIST_exp(dropout_in=0.,dropout_hidden=0.,binary=True,stochastic=True,H=1./(1<<6))
MNIST_exp(dropout_in=0.,dropout_hidden=0.,binary=True,stochastic=True,H=1./(1<<8))

# deterministic BC
MNIST_exp(dropout_in=0.,dropout_hidden=0.,binary=True,stochastic=False,H=1.)
MNIST_exp(dropout_in=0.,dropout_hidden=0.,binary=True,stochastic=False,H=1./(1<<2))
MNIST_exp(dropout_in=0.,dropout_hidden=0.,binary=True,stochastic=False,H=1./(1<<4))
MNIST_exp(dropout_in=0.,dropout_hidden=0.,binary=True,stochastic=False,H=1./(1<<6))
MNIST_exp(dropout_in=0.,dropout_hidden=0.,binary=True,stochastic=False,H=1./(1<<8))

# np.savez('model.npz', lasagne.layers.get_all_param_values(network))
