Group Orthogonalization Regularization (GOR) is a weight regularization technique that promotes orthonormality between groups of filters within the same layer. GOR complements existing normalization techniques, such as batch normalization (BN) and group normalization (GN), and can be applied to a wide range of deep-learning models. The GOR loss for a single layer's weight tensor can be computed as follows (PyTorch):
import torch

def inter_gor(weight, num_groups):
    c_out = weight.shape[0]
    # Flatten conv filters (c_out x c_in x k x k) into rows; linear weights are already 2-D
    w = weight.reshape(c_out, -1)
    # w is c_out x filter_dim
    reg_loss = 0
    group_size = c_out // num_groups
    eye = torch.eye(group_size, device=w.device, dtype=w.dtype)
    # Iterate over groups and calculate the regularization loss for each
    for ii in range(num_groups):
        w_g = w[ii * group_size: (ii + 1) * group_size]  # group_size x filter_dim
        reg_loss += torch.dist(w_g @ w_g.T, eye) ** 2  # ||W_g W_g^T - I||_F^2
    return reg_loss
A parallelized version is available in weight_regularization.py.
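To illustrate how the per-group loop can be parallelized, here is a minimal sketch that computes every group's Gram matrix with a single batched matrix multiply. It is not the code in weight_regularization.py: the function name `gor_loss_batched` is hypothetical, and the sketch assumes that consecutive filters form a group and that `c_out` is divisible by `num_groups`.

```python
import torch


def gor_loss_batched(weight: torch.Tensor, num_groups: int) -> torch.Tensor:
    """Hypothetical batched GOR penalty: sum over groups of ||W_g W_g^T - I||_F^2.

    Assumes consecutive filters form a group and c_out % num_groups == 0.
    """
    c_out = weight.shape[0]
    assert c_out % num_groups == 0, "c_out must be divisible by num_groups"
    group_size = c_out // num_groups
    # (num_groups, group_size, filter_dim): row-major reshape keeps consecutive filters together
    w = weight.reshape(num_groups, group_size, -1)
    # Per-group Gram matrices W_g W_g^T, shape (num_groups, group_size, group_size)
    gram = torch.bmm(w, w.transpose(1, 2))
    eye = torch.eye(group_size, device=weight.device, dtype=weight.dtype)
    # eye broadcasts over the group dimension; sum squared deviations over all groups
    return ((gram - eye) ** 2).sum()


# Example: penalty for a freshly initialized conv layer with 32 filters in 8 groups
conv = torch.nn.Conv2d(16, 32, kernel_size=3)
penalty = gor_loss_batched(conv.weight, num_groups=8)
```

Because the row-major reshape to `(num_groups, group_size, filter_dim)` groups the same consecutive filters as slicing `w[ii * group_size: (ii + 1) * group_size]`, the result should match the loop up to floating-point rounding.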
To set up the environment, run:
git clone https://github.com/YoavKurtz/GOR
cd GOR
python3 -m venv gor_venv
source gor_venv/bin/activate
pip install -r requirements.txt
We provide training code with our regularization for the following tasks:
- Image classification on CIFAR10.
  ▶️ Train ResNet110 + GN on CIFAR10 with GOR:
  python train_cifar10.py --data-path /path/to/cifar --reg-type inter --norm GN
- Fine-tuning diffusion models with LoRA. We provide our modified training script, based on an example from the Hugging Face repo.
  ▶️ Fine-tune SD 1.5 with LoRA and GOR on Pokemon-BLIP on 2 GPUs by running fine_tune_lora_gor_pokemon.sh.
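As a rough illustration of how a GOR term might enter a training objective (this is not the code in train_cifar10.py), the sketch below adds a weighted penalty to a cross-entropy loss. The helper names `gor_penalty` and `train_step`, the coefficient `reg_coef` and its default value, and the choice to regularize every `Conv2d` whose output channels split evenly into groups are all assumptions made for the example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def gor_penalty(model: nn.Module, num_groups: int) -> torch.Tensor:
    """Hypothetical helper: sum of ||W_g W_g^T - I||_F^2 over groups of each eligible Conv2d."""
    device = next(model.parameters()).device
    total = torch.zeros((), device=device)
    for m in model.modules():
        # Assumption: regularize every Conv2d whose filters split evenly into num_groups groups
        if isinstance(m, nn.Conv2d) and m.out_channels % num_groups == 0:
            group_size = m.out_channels // num_groups
            # Consecutive filters form a group: (num_groups, group_size, filter_dim)
            w = m.weight.reshape(num_groups, group_size, -1)
            gram = w @ w.transpose(1, 2)
            eye = torch.eye(group_size, device=w.device, dtype=w.dtype)
            total = total + ((gram - eye) ** 2).sum()
    return total


def train_step(model, inputs, targets, optimizer, num_groups=4, reg_coef=1e-4):
    """One optimizer step on cross-entropy plus reg_coef * GOR penalty (reg_coef is a placeholder)."""
    optimizer.zero_grad()
    logits = model(inputs)
    loss = F.cross_entropy(logits, targets) + reg_coef * gor_penalty(model, num_groups)
    loss.backward()
    optimizer.step()
    return loss.item()
```

When training with GN, one natural (but assumed) choice is to set `num_groups` equal to the number of GroupNorm groups, e.g. `nn.GroupNorm(4, 16)` paired with `num_groups=4`.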
@article{kurtz2023group,
title={Group Orthogonalization Regularization For Vision Models Adaptation and Robustness},
author={Kurtz, Yoav and Bar, Noga and Giryes, Raja},
journal={arXiv preprint arXiv:2306.10001},
year={2023}
}
TODO:
- Add requirements.txt
- Add commands for running training with GOR
- Add the paper citation and a link to the paper/project page
- Consider adding more examples