From daea68903eb54b4970d0b04e11abc949767ea7e1 Mon Sep 17 00:00:00 2001
From: Matthieu Courbariaux
Date: Fri, 21 Aug 2015 20:48:48 -0400
Subject: [PATCH] first working version of BinaryConnect

---
 binary_connect.py |  59 +++++++++++++++++
 mnist_mlp.py      | 164 ++++++++++++++++++++++++++++------------------
 2 files changed, 161 insertions(+), 62 deletions(-)
 create mode 100644 binary_connect.py

diff --git a/binary_connect.py b/binary_connect.py
new file mode 100644
index 0000000..e54cc1c
--- /dev/null
+++ b/binary_connect.py
@@ -0,0 +1,59 @@
+
+from collections import OrderedDict
+
+import numpy as np
+
+import theano
+import theano.tensor as T
+
+def weights_clipping(updates):
+
+    params = updates.keys()
+    updates = OrderedDict(updates)
+
+    for param in params:
+        if param.name is not None:
+            if "W" in param.name:
+                # print("ok")
+                updates[param] = T.clip(updates[param], -1, 1)
+
+    return updates
+
+from theano.scalar.basic import UnaryScalarOp, same_out_nocomplex
+from theano.tensor.elemwise import Elemwise
+
+class Binarize(UnaryScalarOp):
+
+    def c_code(self, node, name, (x,), (z,), sub):
+        return "%(z)s = 2*(%(x)s >= 0)-1;" % locals()
+
+    def grad(self, (x, ), (gz, )):
+        return [gz]
+
+binarize = Elemwise(Binarize(same_out_nocomplex, name='binarize'))
+
+import lasagne
+
+class BinaryDenseLayer(lasagne.layers.DenseLayer):
+
+    def __init__(self, incoming, num_units, W=lasagne.init.Uniform((-1,1)), **kwargs):
+
+        super(BinaryDenseLayer, self).__init__(incoming, num_units, W, **kwargs)
+        # self._srng = RandomStreams(lasagne.random.get_rng().randint(1, 2147462579))
+
+    # def get_output_for(self, input, deterministic=False, **kwargs):
+    def get_output_for(self, input, **kwargs):
+
+        if input.ndim > 2:
+            # if the input has more than two dimensions, flatten it into a
+            # batch of feature vectors.
+            input = input.flatten(2)
+
+        # deterministic BinaryConnect
+        # Wb = T.cast(T.switch(T.ge(self.W,0),1,-1), theano.config.floatX)
+        Wb = binarize(self.W)
+
+        activation = T.dot(input,Wb)
+        if self.b is not None:
+            activation = activation + self.b.dimshuffle('x', 0)
+        return self.nonlinearity(activation)
\ No newline at end of file
diff --git a/mnist_mlp.py b/mnist_mlp.py
index 8d75e13..deeaaeb 100644
--- a/mnist_mlp.py
+++ b/mnist_mlp.py
@@ -17,6 +17,8 @@
 
 from batch_norm import BatchNormLayer
 
+from binary_connect import weights_clipping, BinaryDenseLayer
+
 if __name__ == "__main__":
     
     # BN parameters
@@ -52,14 +54,16 @@
     X_val = X_val.reshape((-1, 1, 28, 28))
     X_test = X_test.reshape((-1, 1, 28, 28))
     
+    # without standardization .97%, 1.11%
+    # with standardization 1.06%, 1.34%
     # standardize the dataset
-    def standardize(X):
-        X -= X.mean(axis=0)
-        X /= (X.std(axis=0)+epsilon)
-        return X
-    X_train = standardize(X_train)
-    X_val = standardize(X_val)
-    X_test = standardize(X_test)
+    # def standardize(X):
+    #     X -= X.mean(axis=0)
+    #     X /= (X.std(axis=0)+epsilon)
+    #     return X
+    # X_train = standardize(X_train)
+    # X_val = standardize(X_val)
+    # X_test = standardize(X_test)
     
     # flatten the targets
     y_train = np.hstack(y_train)
@@ -86,16 +90,21 @@ def standardize(X):
             shape=(None, 1, 28, 28),
             input_var=input)
     
-    mlp = lasagne.layers.DropoutLayer(
-            mlp,
-            p=0.2)
+    # mlp = lasagne.layers.DropoutLayer(
+    #         mlp,
+    #         p=0.2)
     
     for k in range(n_hidden_layers):
     
-        mlp = lasagne.layers.DenseLayer(
+        # mlp = lasagne.layers.DenseLayer(
+        #         mlp,
+        #         nonlinearity=lasagne.nonlinearities.identity,
+        #         num_units=num_units)
+        
+        mlp = BinaryDenseLayer(
                 mlp,
                 nonlinearity=lasagne.nonlinearities.identity,
-                num_units=num_units)
+                num_units=num_units)
         
         mlp = BatchNormLayer(
                 mlp,
@@ -103,15 +112,20 @@ def standardize(X):
                 alpha=alpha,
                 nonlinearity=lasagne.nonlinearities.rectify)
     
-        mlp = lasagne.layers.DropoutLayer(
+        # mlp = lasagne.layers.DropoutLayer(
+        #         mlp,
+        #         p=0.5)
+    
+    mlp = BinaryDenseLayer(
             mlp,
-            p=0.5)
-    
-    mlp = lasagne.layers.DenseLayer(
-            mlp,
-            nonlinearity=lasagne.nonlinearities.identity,
-            num_units=10)
-    
+            nonlinearity=lasagne.nonlinearities.identity,
+            num_units=10)
+    
+    # mlp = lasagne.layers.DenseLayer(
+    #         mlp,
+    #         nonlinearity=lasagne.nonlinearities.identity,
+    #         num_units=10)
+    
     mlp = BatchNormLayer(
             mlp,
             epsilon=epsilon,
@@ -124,6 +138,7 @@ def standardize(X):
     params = lasagne.layers.get_all_params(mlp, trainable=True)
     # updates = lasagne.updates.nesterov_momentum(loss, params, learning_rate=0.1, momentum=0.9)
     updates = lasagne.updates.adam(loss, params)
+    updates = weights_clipping(updates)
     
     test_output = lasagne.layers.get_output(mlp, deterministic=True)
     test_loss = T.mean(T.sqr(T.maximum(0.,1.-target*test_output)))
@@ -138,60 +153,85 @@ def standardize(X):
     
     print('Training...')
     
-    def iterate_minibatches(inputs, targets, batchsize, shuffle=False):
-        assert len(inputs) == len(targets)
-        if shuffle:
-            indices = np.arange(len(inputs))
-            np.random.shuffle(indices)
-        for start_idx in range(0, len(inputs) - batchsize + 1, batchsize):
-            if shuffle:
-                excerpt = indices[start_idx:start_idx + batchsize]
-            else:
-                excerpt = slice(start_idx, start_idx + batchsize)
-            yield inputs[excerpt], targets[excerpt]
-
+    def shuffle(X,y):
+        
+        shuffled_range = range(len(X))
+        np.random.shuffle(shuffled_range)
+        # print(shuffled_range[0:10])
+        
+        new_X = np.copy(X)
+        new_y = np.copy(y)
+        
+        for i in range(len(X)):
+            
+            new_X[i] = X[shuffled_range[i]]
+            new_y[i] = y[shuffled_range[i]]
+            
+        return new_X,new_y
+    
+    # shuffle the train set
+    X_train,y_train = shuffle(X_train,y_train)
+    best_val_err = 100
+    best_epoch = 1
+    
     # We iterate over epochs:
     for epoch in range(num_epochs):
         
         # In each epoch, we do a full pass over the training data:
         train_loss = 0
-        train_batches = 0
+        train_batches = len(X_train)/batch_size
         start_time = time.time()
-        for batch in iterate_minibatches(X_train, y_train, batch_size, shuffle=True):
-            inputs, targets = batch
-            train_loss += train_fn(inputs, targets)
-            train_batches += 1
-
+        
+        for i in range(train_batches):
+            train_loss += train_fn(X_train[i*batch_size:(i+1)*batch_size],y_train[i*batch_size:(i+1)*batch_size])
+        
+        train_loss/=train_batches
+        
         # And a full pass over the validation data:
         val_err = 0
        val_loss = 0
-        val_batches = 0
-        for batch in iterate_minibatches(X_val, y_val, batch_size, shuffle=False):
-            inputs, targets = batch
-            loss, err = val_fn(inputs, targets)
+        val_batches = len(X_val)/batch_size
+        
+        for i in range(val_batches):
+            loss, err = val_fn(X_val[i*batch_size:(i+1)*batch_size], y_val[i*batch_size:(i+1)*batch_size])
             val_err += err
             val_loss += loss
-            val_batches += 1
-
+        
+        val_err = val_err / val_batches * 100
+        val_loss /= val_batches
+        
+        # test if validation error went down
+        if val_err <= best_val_err:
+            
+            best_val_err = val_err
+            best_epoch = epoch+1
+            
+            test_err = 0
+            test_loss = 0
+            test_batches = len(X_test)/batch_size
+            
+            for i in range(test_batches):
+                loss, err = val_fn(X_test[i*batch_size:(i+1)*batch_size], y_test[i*batch_size:(i+1)*batch_size])
+                test_err += err
+                test_loss += loss
+            
+            test_err = test_err / test_batches * 100
+            test_loss /= test_batches
+        
+        # shuffle the train set
+        X_train,y_train = shuffle(X_train,y_train)
+        
+        epoch_duration = time.time() - start_time
+        
         # Then we print the results for this epoch:
-        print("Epoch {} of {} took {:.3f}s".format(epoch + 1, num_epochs, time.time() - start_time))
-        print("  training loss:\t\t{:.6f}".format(train_loss / train_batches))
-        print("  validation loss:\t\t{:.6f}".format(val_loss / val_batches))
-        print("  validation error rate:\t{:.2f} %".format(val_err / val_batches * 100))
-
-    # After training, we compute and print the test error:
-    test_err = 0
-    test_loss = 0
-    test_batches = 0
-    for batch in iterate_minibatches(X_test, y_test, batch_size, shuffle=False):
-        inputs, targets = batch
-        loss, err = val_fn(inputs, targets)
-        test_err += err
-        test_loss += loss
-        test_batches += 1
-    print("Final results:")
-    print("  test loss:\t\t\t{:.6f}".format(test_loss / test_batches))
-    print("  test error rate:\t\t{:.2f} %".format(test_err / test_batches * 100))
+        print("Epoch "+str(epoch + 1)+" of "+str(num_epochs)+" took "+str(epoch_duration)+"s")
+        print("  training loss: "+str(train_loss))
+        print("  validation loss: "+str(val_loss))
+        print("  validation error rate: "+str(val_err)+"%")
+        print("  best epoch: "+str(best_epoch))
+        print("  best validation error rate: "+str(best_val_err)+"%")
+        print("  test loss: "+str(test_loss))
+        print("  test error rate: "+str(test_err)+"%")
     
     # Optionally, you could now dump the network weights to a file like this:
     # np.savez('model.npz', lasagne.layers.get_all_param_values(network))
\ No newline at end of file
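
Usage note (not part of the patch): a minimal sketch of how the two pieces this patch introduces, BinaryDenseLayer and weights_clipping, are meant to be wired together, mirroring mnist_mlp.py above. The 784-unit hidden size, the direct rectify nonlinearity (mnist_mlp.py uses identity followed by BatchNormLayer), and the +-1 target matrix for the squared hinge loss are illustrative assumptions.

    import theano
    import theano.tensor as T
    import lasagne

    from binary_connect import BinaryDenseLayer, weights_clipping

    input = T.tensor4('inputs')
    target = T.matrix('targets')  # +-1 encoded targets, one column per class (assumption)

    # two binary layers on top of the MNIST input, as in mnist_mlp.py but without batch norm
    mlp = lasagne.layers.InputLayer(shape=(None, 1, 28, 28), input_var=input)
    mlp = BinaryDenseLayer(mlp, num_units=784, nonlinearity=lasagne.nonlinearities.rectify)
    mlp = BinaryDenseLayer(mlp, num_units=10, nonlinearity=lasagne.nonlinearities.identity)

    output = lasagne.layers.get_output(mlp)
    loss = T.mean(T.sqr(T.maximum(0., 1. - target * output)))  # squared hinge loss, as in mnist_mlp.py

    # Adam updates the real-valued weights; weights_clipping then clips every "W"
    # update into [-1, 1], which keeps the binarized forward pass meaningful.
    params = lasagne.layers.get_all_params(mlp, trainable=True)
    updates = lasagne.updates.adam(loss, params)
    updates = weights_clipping(updates)

    train_fn = theano.function([input, target], loss, updates=updates)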