Merge pull request #2 from CSMMLab/develop
Version 1.1: Added low-rank transformers and refactored code
ScSteffen authored Nov 2, 2022
2 parents 7ba103f + 1f10847 commit 72d9589
Showing 58 changed files with 6,655 additions and 419 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -6,6 +6,7 @@
python/mat/
python/data/
*.pth
*.pyc

# Files generated by invoking Julia with --code-coverage
*.jl.cov
11 changes: 7 additions & 4 deletions README.md
@@ -1,7 +1,6 @@
# DLRANet
## [Low-rank lottery tickets: finding efficient low-rank neural networks via matrix differential equations](https://arxiv.org/abs/2205.13571)

## Codebase for the TensorFlow DLRT implementation with rank adaptation.
The NeurIPS publication to this work can be found at [Low-rank lottery tickets: finding efficient low-rank neural networks via matrix differential equations](https://arxiv.org/abs/2205.13571)
Code supplement for all dense neural network experiments of the arXiv Preprint

### Usage

@@ -21,7 +20,11 @@ The NeurIPS publication to this work can be found at [Low-rank lottery tickets:
4. ``sh run_tests_fixed_rank_train_from_prune.sh`` loads the weights of a conventionally trained network (provided in the
folder "dense_weights"), factorizes each weight matrix, and truncates all but the 20 largest eigenvalues. Fixed
low-rank training is then used to retrain the model. This method is used in Section 7.3 (a sketch follows this list).

5. ``sh run_test_transformer_dlrt.sh`` and ``sh run_test_transformer_big_dlrt.sh`` train a transformer on the
Portuguese-to-English translation task with DLRT.
6. ``sh run_test_transformer_fix_rank.sh`` and ``sh run_test_transformer_big_fix_rank.sh`` train a transformer on
the Portuguese-to-English translation task with fixed-rank DLRT.
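
For step 4, here is a minimal sketch of the factorize-and-truncate idea, assuming the dense weights are stored as `.npy` matrices as described above; the helper name, checkpoint path, and shapes are illustrative, not the repository's exact code:

```python
import numpy as np

def truncate_to_rank(w_dense: np.ndarray, rank: int = 20):
    """Factorize a dense weight matrix and keep only the top `rank` modes."""
    # Thin SVD of the pretrained dense weight matrix.
    u, s, vt = np.linalg.svd(w_dense, full_matrices=False)
    # Discard everything beyond the leading `rank` singular triplets,
    # returning (U, S, V) with W approximated by U @ S @ V.T.
    return u[:, :rank], np.diag(s[:rank]), vt[:rank, :].T

# Illustrative usage on one 784x784 layer from the dense checkpoint.
w = np.load("src/dense_weights/best_model/w_0.npy")
u, s, v = truncate_to_rank(w, rank=20)
print(u.shape, s.shape, v.shape)  # e.g. (784, 20), (20, 20), (784, 20)
```

Fixed low-rank training then updates the factors `U`, `S`, `V` directly instead of the full weight matrix.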

### Useful links

The PyTorch version can be found [here](https://github.com/COMPiLELab/DLRT/tree/efficient_gradient).
1 change: 1 addition & 0 deletions docker/docker_run_interactive.sh
@@ -0,0 +1 @@
docker run --gpus all -i --rm -t -v "$(pwd)/..:/mnt" scsteffen/neural_entropy:latest /bin/bash
1 change: 1 addition & 0 deletions docker/docker_start_training.sh
@@ -0,0 +1 @@
docker run -t --rm -v "$(pwd)/..:/mnt" scsteffen/neural_entropy:latest /bin/bash -c "/mnt/docker/train_on_docker.sh"
9 changes: 9 additions & 0 deletions docker/train_on_docker.sh
@@ -0,0 +1,9 @@
#mkdir /home/neuralEntropy
cp -r /mnt /home/neuralEntropy

cd /home/neuralEntropy/
mkdir models
/usr/bin/python3 -m pip install --upgrade pip
pip install -r requirements.txt

sh rs_test.sh
4 changes: 3 additions & 1 deletion requirements.txt
@@ -2,4 +2,6 @@ numpy==1.22.2
Pillow==9.0.1
tensorflow==2.8.0
pandas

tensorflow_datasets
tensorflow-text==2.8.*
nltk
2 changes: 2 additions & 0 deletions run_scripts/run_test_transformer_big_dlrt.sh
@@ -0,0 +1,2 @@
# DLRT for a 512-d transformer on the Portuguese-English dataset with tau=0.01 for 100 epochs
python src/speech_transformer_big_DLRT.py -t 0.01 -e 100
2 changes: 2 additions & 0 deletions run_scripts/run_test_transformer_big_fix_rank.sh
@@ -0,0 +1,2 @@
# Fixed-rank DLRT (rank 50) for a 512-d transformer on the Portuguese-English dataset for 100 epochs
python src/speech_transformer_big_DLRT_fr.py -r 50 -e 100
2 changes: 2 additions & 0 deletions run_scripts/run_test_transformer_dlrt.sh
@@ -0,0 +1,2 @@
# DLRT for a 128-d transformer on the Portuguese-English dataset with tau=0.01 for 100 epochs
python src/speech_transformer_DLRT.py -t 0.01 -e 100
2 changes: 2 additions & 0 deletions run_scripts/run_test_transformer_fix_rank.sh
@@ -0,0 +1,2 @@
# Fixed-rank DLRT (rank 50) for a 128-d transformer on the Portuguese-English dataset for 100 epochs
python src/speech_transformer_DLRT_fr.py -r 50 -e 100
@@ -1,2 +1,2 @@
# Standard neural network training for a 5-layer network of widths [784,784,784,784,10]
python mnist_reference.py --load 0
python src/mnist_reference.py --load 0
@@ -1,2 +1,2 @@
# Fixed Rank training for a 5-layer network of widths [500,500,500,500,10] with low-ranks [20,20,20,20,10] for 100 epochs. Last layer has fixed rank 10 (since we classify 10 classes)
python mnist_DLRA_fixed_rank.py -s 20 -t 1.0 -l 0 --train 1 -d 500
python src/mnist_DLRA_fixed_rank.py -s 20 -t 1.0 -l 0 --train 1 -d 500
@@ -1,5 +1,5 @@
# Rank-adaptive training for a 5-layer network of widths [500,500,500,500,10] with adaptive low-ranks for 10 epochs. Last layer has fixed rank 10 (since we classify 10 classes)
# Starting rank is set to 300, rank adaptation tolerance is set to 0.17
python mnist_DLRA.py -s 300 -t 0.17 -l 0 -a 1 -d 500
python src/mnist_DLRA.py -s 300 -t 0.17 -l 0 -a 1 -d 500
# Fixed Rank finetuning for 100 epochs (flags -s and -t are set only to navigate into the right save-directory)
python mnist_DLRA_fixed_rank.py -s 300 -t 0.17 -l 1 --train 1 -d 500
python src/mnist_DLRA_fixed_rank.py -s 300 -t 0.17 -l 1 --train 1 -d 500
@@ -1,4 +1,4 @@
# Loads a dense 5-layer neural network with widths [784,784,784,784,10]
# Then, the dense, full-rank weight matrices are factorized using SVD, and we keep the top 20 eigenvalue-eigenvector pairs.
# The decomposed network is first evaluated (first line of the history file), and then retrained using our fixed-rank training algorithm.
python mnist_DLRA_fixed_rank_retrain_from_prune.py -s 20 -l 1 --train 1
python src/mnist_DLRA_fixed_rank_retrain_from_prune.py -s 20 -l 1 --train 1
@@ -1,3 +1,3 @@
# Rank-adaptive training for a 5-layer network of widths [500,500,500,500,10] with adaptive low-ranks for 100 epochs. Last layer has fixed rank 10 (since we classify 10 classes)
# Starting rank is set to 150, rank adaptation tolerance is set to 0.17, and max rank to 300.
python mnist_DLRA.py -s 150 -t 0.17 -l 0 -a 1 -d 500 -m 300 -e 100
python src/mnist_DLRA.py -s 150 -t 0.17 -l 0 -a 1 -d 500 -m 300 -e 100
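
The `-s`, `-t`, and `-m` flags above set the starting rank, the truncation tolerance tau, and the maximum rank. Below is a plausible sketch of the truncation test these flags control, following the usual rank-adaptive DLRA recipe; the function name and the exact criterion are assumptions, not code from this repository:

```python
import numpy as np

def adapt_rank(s_hat: np.ndarray, tau: float, rmax: int):
    """Pick the smallest rank whose discarded tail stays below tau * ||S||."""
    u, sigma, vt = np.linalg.svd(s_hat)
    total = np.linalg.norm(sigma)
    new_rank = len(sigma)
    for r in range(1, len(sigma) + 1):
        # Norm of the singular values that would be thrown away at rank r.
        if np.linalg.norm(sigma[r:]) <= tau * total:
            new_rank = r
            break
    new_rank = min(new_rank, rmax)  # never exceed the -m cap
    return u[:, :new_rank], np.diag(sigma[:new_rank]), vt[:new_rank, :]
```

Under this reading, `-t 0.17` means the discarded tail may carry at most 17% of the singular-value norm at each adaptation step.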
Binary file removed src/dense_weights/best_model/b_0.npy
Binary file removed src/dense_weights/best_model/b_1.npy
Binary file removed src/dense_weights/best_model/b_2.npy
Binary file removed src/dense_weights/best_model/b_3.npy
Binary file removed src/dense_weights/best_model/b_4.npy
Binary file removed src/dense_weights/best_model/w_0.npy
Binary file removed src/dense_weights/best_model/w_1.npy
Binary file removed src/dense_weights/best_model/w_2.npy
Binary file removed src/dense_weights/best_model/w_3.npy
Binary file removed src/dense_weights/best_model/w_4.npy
Binary file removed src/dense_weights/latest_model/b_0.npy
Binary file removed src/dense_weights/latest_model/b_1.npy
Binary file removed src/dense_weights/latest_model/b_2.npy
Binary file removed src/dense_weights/latest_model/b_3.npy
Binary file removed src/dense_weights/latest_model/b_4.npy
Binary file removed src/dense_weights/latest_model/w_0.npy
Binary file removed src/dense_weights/latest_model/w_1.npy
Binary file removed src/dense_weights/latest_model/w_2.npy
Binary file removed src/dense_weights/latest_model/w_3.npy
Binary file removed src/dense_weights/latest_model/w_4.npy
37 changes: 18 additions & 19 deletions src/mnist_DLRA.py → src/mnist_DLRT.py
@@ -1,4 +1,5 @@
from dlranet import DLRANetAdaptive, create_csv_logger_cb
from networks.dense_dlrt_nets import DLRTNetAdaptive
from networks.utils import create_csv_logger_cb

import tensorflow as tf
from tensorflow import keras
@@ -12,9 +13,10 @@ def train(start_rank, tolerance, load_model, dim_layer, rmax, epochs):
epochs = epochs
batch_size = 256

filename = "e2edense_sr" + str(start_rank) + "_v" + str(tolerance)
folder_name = "e2edense_sr" + str(start_rank) + "_v" + str(tolerance) + '/latest_model'
folder_name_best = "e2edense_sr" + str(start_rank) + "_v" + str(tolerance) + '/best_model'
name = "mnist_dense_sr"
filename = name + str(start_rank) + "_v" + str(tolerance)
folder_name = name + str(start_rank) + "_v" + str(tolerance) + '/latest_model'
folder_name_best = name + str(start_rank) + "_v" + str(tolerance) + '/best_model'

# check if dir exists
if not path.exists(folder_name):
@@ -35,8 +37,10 @@ def train(start_rank, tolerance, load_model, dim_layer, rmax, epochs):

dlra_layer_dim = dim_layer

model = DLRANetAdaptive(input_dim=input_dim, output_dim=output_dim, low_rank=starting_rank,
model = DLRTNetAdaptive(input_dim=input_dim, output_dim=output_dim, low_rank=starting_rank,
dlra_layer_dim=dlra_layer_dim, tol=tol, rmax_total=max_rank)
model.build_model()

# Build optimizer
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
# Choose loss
@@ -95,7 +99,6 @@ def train(start_rank, tolerance, load_model, dim_layer, rmax, epochs):
model.dlraBlock2.l_step_preprocessing()
model.dlraBlock3.k_step_preprocessing()
model.dlraBlock3.l_step_preprocessing()


# 1.b) Tape Gradients for K-Step
model.toggle_non_s_step_training()
@@ -106,8 +109,8 @@
# Compute reconstruction loss
loss = loss_fn(batch_train[1], out)
loss += sum(model.losses) # Add KLD regularization loss
if step == 0:

if step == 0:
# Network monitoring and verbosity
loss_metric.update_state(loss)
prediction = tf.math.argmax(out, 1)
@@ -119,7 +122,7 @@

print("step %d: mean loss S-Step = %.4f" % (step, loss_value))
print("Accuracy: " + str(acc_value))
print("Loss: "+ str(loss_value))
print("Loss: " + str(loss_value))
print("Current Rank: " + str(int(model.dlraBlockInput.low_rank)) + " | " + str(
int(model.dlraBlock1.low_rank)) + " | " + str(
int(model.dlraBlock2.low_rank)) + " | " + str(int(model.dlraBlock3.low_rank)) + " )")
@@ -143,7 +146,7 @@ def train(start_rank, tolerance, load_model, dim_layer, rmax, epochs):
acc_metric.update_state(prediction, y_val)
acc_val = acc_metric.result().numpy()
print("Accuracy: " + str(acc_val))
print("Loss: "+ str(loss_val))
print("Loss: " + str(loss_val))
# save current model if it's the best
if acc_val >= best_acc and loss_val <= best_loss:
best_acc = acc_val
@@ -155,9 +158,9 @@ def train(start_rank, tolerance, load_model, dim_layer, rmax, epochs):
# Reset metrics
loss_metric.reset_state()
acc_metric.reset_state()

print("----- Test Metrics (not used for early stopping) ----")

# Test model
out = model(x_test, step=0, training=True)
out = tf.keras.activations.softmax(out)
@@ -169,19 +172,17 @@ def train(start_rank, tolerance, load_model, dim_layer, rmax, epochs):
acc_metric.update_state(prediction, y_test)
acc_test = acc_metric.result().numpy()
print("Accuracy: " + str(acc_test))
print("Loss: "+ str(loss_test))
print("Loss: " + str(loss_test))
# Reset metrics
loss_metric.reset_state()
acc_metric.reset_state()
print("-------------------------------------\n\n")


# Gradient updates for k step
grads_k_step = tape.gradient(loss, model.trainable_weights)
model.set_none_grads_to_zero(grads_k_step, model.trainable_weights)
model.set_dlra_bias_grads_to_zero(grads_k_step)


# 1.b) Tape Gradients for L-Step
with tf.GradientTape() as tape:
out = model(batch_train[0], step=1, training=True)
@@ -198,7 +199,7 @@ def train(start_rank, tolerance, load_model, dim_layer, rmax, epochs):
optimizer.apply_gradients(zip(grads_k_step, model.trainable_weights))
optimizer.apply_gradients(zip(grads_l_step, model.trainable_weights))

# Postprocessing K and L
# Postprocessing K and L (explicitly written out for each layer)
model.dlraBlockInput.k_step_postprocessing_adapt()
model.dlraBlockInput.l_step_postprocessing_adapt()
model.dlraBlock1.k_step_postprocessing_adapt()
@@ -235,8 +236,6 @@ def train(start_rank, tolerance, load_model, dim_layer, rmax, epochs):
model.dlraBlock2.rank_adaption()
model.dlraBlock3.rank_adaption()



# Log Data of current epoch
log_string = str(loss_value) + ";" + str(acc_value) + ";" + str(
loss_val) + ";" + str(acc_val) + ";" + str(
@@ -262,7 +261,7 @@ def normalize_img(image, label):
# --- parse options ---
parser = OptionParser()
parser.add_option("-s", "--start_rank", dest="start_rank", default=10)
parser.add_option("-t", "--tolerance", dest="tolerance", default=10)
parser.add_option("-t", "--tolerance", dest="tolerance", default=0.05)
parser.add_option("-l", "--load_model", dest="load_model", default=1)
parser.add_option("-a", "--train", dest="train", default=1)
parser.add_option("-d", "--dim_layer", dest="dim_layer", default=200)
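
To make the loop in this diff easier to follow, here is a condensed sketch of one DLRT training step, assembled from the method names visible above; `model.dlra_blocks`, `toggle_s_step_training`, and the `step=2` convention for the S-pass are assumptions filled in for illustration, not the file's actual code:

```python
import tensorflow as tf

def dlrt_train_step(model, optimizer, loss_fn, x, y):
    # Prepare K- and L-step quantities for every low-rank block.
    for block in model.dlra_blocks:  # assumed list of the dlraBlock* layers
        block.k_step_preprocessing()
        block.l_step_preprocessing()

    # K- and L-steps: train the basis factors while S stays fixed.
    model.toggle_non_s_step_training()
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, step=0, training=True))  # K-step pass
    grads_k = tape.gradient(loss, model.trainable_weights)
    model.set_none_grads_to_zero(grads_k, model.trainable_weights)
    model.set_dlra_bias_grads_to_zero(grads_k)

    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, step=1, training=True))  # L-step pass
    grads_l = tape.gradient(loss, model.trainable_weights)
    model.set_none_grads_to_zero(grads_l, model.trainable_weights)
    model.set_dlra_bias_grads_to_zero(grads_l)

    optimizer.apply_gradients(zip(grads_k, model.trainable_weights))
    optimizer.apply_gradients(zip(grads_l, model.trainable_weights))

    # Re-orthonormalize and augment the bases (rank-adaptive variant).
    for block in model.dlra_blocks:
        block.k_step_postprocessing_adapt()
        block.l_step_postprocessing_adapt()

    # S-step: train the small coefficient matrix, then adapt the ranks.
    model.toggle_s_step_training()  # assumed counterpart toggle
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, step=2, training=True))  # assumed S pass
    grads_s = tape.gradient(loss, model.trainable_weights)
    model.set_none_grads_to_zero(grads_s, model.trainable_weights)
    optimizer.apply_gradients(zip(grads_s, model.trainable_weights))

    for block in model.dlra_blocks:
        block.rank_adaption()  # truncate each block back to its new rank
    return loss
```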
24 changes: 11 additions & 13 deletions src/mnist_DLRA_fixed_rank.py → src/mnist_DLRT_fr.py
@@ -1,4 +1,5 @@
from dlranet import DLRANet, create_csv_logger_cb
from networks.dense_dlrt_nets import DLRTNet
from networks.utils import create_csv_logger_cb

import tensorflow as tf
from tensorflow import keras
@@ -8,14 +9,15 @@
from os import path, makedirs


def train(start_rank, tolerance, load_model, dim_layer):
def train(start_rank, load_model, dim_layer):
# specify training
epochs = 100
batch_size = 256

filename = "e2edense_sr" + str(start_rank) + "_v" + str(tolerance)
folder_name = "e2edense_sr" + str(start_rank) + "_v" + str(tolerance) + '/latest_model'
folder_name_best = "e2edense_sr" + str(start_rank) + "_v" + str(tolerance) + '/best_model'
name = "mnist_dense_fr_sr"
filename = name + str(start_rank) + "_v"
folder_name = name + str(start_rank) + "_v/latest_model"
folder_name_best = name + str(start_rank) + "_v/best_model"

# check if dir exists
if not path.exists(folder_name):
@@ -30,12 +32,11 @@ def train(start_rank, load_model, dim_layer):
output_dim = 10 # one-hot vector of digits 0-9

starting_rank = start_rank # starting rank of S matrix
tol = tolerance # eigenvalue threshold
max_rank = 350 # maximum rank of S matrix

dlra_layer_dim = dim_layer
model = DLRANet(input_dim=input_dim, output_dim=output_dim, low_rank=starting_rank,
dlra_layer_dim=dlra_layer_dim, tol=tol, rmax_total=max_rank)
model = DLRTNet(input_dim=input_dim, output_dim=output_dim, low_rank=starting_rank, dlra_layer_dim=dlra_layer_dim)
model.build_model()

# Build optimizer
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
# Choose loss
@@ -257,18 +258,15 @@ def normalize_img(image, label):
# --- parse options ---
parser = OptionParser()
parser.add_option("-s", "--start_rank", dest="start_rank", default=10)
parser.add_option("-t", "--tolerance", dest="tolerance", default=10)
parser.add_option("-l", "--load_model", dest="load_model", default=1)
parser.add_option("-a", "--train", dest="train", default=0)
parser.add_option("-d", "--dim_layer", dest="dim_layer", default=200)

(options, args) = parser.parse_args()
options.start_rank = int(options.start_rank)
options.tolerance = float(options.tolerance)
options.load_model = int(options.load_model)
options.train = int(options.train)
options.dim_layer = int(options.dim_layer)

if options.train == 1:
train(start_rank=options.start_rank, tolerance=options.tolerance, load_model=options.load_model,
dim_layer=options.dim_layer)
train(start_rank=options.start_rank, load_model=options.load_model, dim_layer=options.dim_layer)
@@ -1,4 +1,5 @@
from dlranet import DLRANet, create_csv_logger_cb
from networks.dense_dlrt_nets import DLRTNet
from networks.utils import create_csv_logger_cb

import tensorflow as tf
from tensorflow import keras
@@ -13,9 +14,10 @@ def train(start_rank, tolerance, load_model):
epochs = 200
batch_size = 256

filename = "e2edense_sr" + str(start_rank) + "_v" + str(tolerance)
folder_name = "e2edense_sr" + str(start_rank) + "_v" + str(tolerance) + '/latest_model'
folder_name_best = "e2edense_sr" + str(start_rank) + "_v" + str(tolerance) + '/best_model'
name = "mnist_dense_sr"
filename = name + str(start_rank) + "_v" + str(tolerance)
folder_name = name + str(start_rank) + "_v" + str(tolerance) + '/latest_model'
folder_name_best = name + str(start_rank) + "_v" + str(tolerance) + '/best_model'
folder_dense_weights = "dense_weights" + '/best_model'

# check if dir exists
@@ -35,8 +37,7 @@ def train(start_rank, tolerance, load_model):
max_rank = 350 # maximum rank of S matrix

dlra_layer_dim = 784
model = DLRANet(input_dim=input_dim, output_dim=output_dim, low_rank=starting_rank,
dlra_layer_dim=dlra_layer_dim, tol=tol, rmax_total=max_rank)
model = DLRTNet(input_dim=input_dim, output_dim=output_dim, low_rank=starting_rank, dlra_layer_dim=dlra_layer_dim)
# Build optimizer
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
# Choose loss
@@ -327,5 +328,4 @@ def normalize_img(image, label):
options.load_model = int(options.load_model)
options.train = int(options.train)

if options.train == 1:
train(start_rank=options.start_rank, tolerance=options.tolerance, load_model=options.load_model)
train(start_rank=options.start_rank, tolerance=options.tolerance, load_model=options.load_model)
7 changes: 4 additions & 3 deletions src/mnist_reference.py
@@ -1,4 +1,5 @@
from dlranet import ReferenceNet, create_csv_logger_cb
from networks.dense_dlrt_nets import ReferenceNet
from networks.utils import create_csv_logger_cb

import tensorflow as tf
from tensorflow import keras
@@ -13,7 +14,7 @@ def train(load_model=1):
epochs = 250
batch_size = 256

filename = "dense_500x5"
filename = "dense_weights"
folder_name = filename + '/latest_model'
folder_name_best = filename + '/best_model'

@@ -30,7 +31,7 @@

dlra_layer_dim = 784
model = ReferenceNet(input_dim=input_dim, output_dim=output_dim, layer_dim=dlra_layer_dim)

model.build_model()
# Build optimizer
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
# Choose loss
Empty file added src/networks/__init__.py