From e97d67a68121df2ec57270d131c76ec8cb2e312d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= <865105819@qq.com>
Date: Thu, 15 Jun 2023 20:12:53 +0800
Subject: [PATCH] Support for Prodigy (DAdapt variant for DyLoRA) (#585)

* Update train_util.py for DAdaptLion
* Update train_README-zh.md for DAdaptLion
* Update train_README-ja.md for DAdaptLion
* add DAdapt V3
* Alignment
* Update train_util.py for experimental
* Update train_util.py V3
* Update train_README-zh.md
* Update train_README-ja.md
* Update train_util.py fix
* Update train_util.py
* support Prodigy
* add lower
---
 docs/train_README-ja.md        |  1 +
 docs/train_README-zh.md        |  3 ++-
 fine_tune.py                   |  2 +-
 library/train_util.py          | 32 ++++++++++++++++++++++++++++++++
 train_db.py                    |  2 +-
 train_network.py               |  4 ++--
 train_textual_inversion.py     |  2 +-
 train_textual_inversion_XTI.py |  2 +-
 8 files changed, 41 insertions(+), 7 deletions(-)

diff --git a/docs/train_README-ja.md b/docs/train_README-ja.md
index b64b18082..158363b39 100644
--- a/docs/train_README-ja.md
+++ b/docs/train_README-ja.md
@@ -622,6 +622,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
     - DAdaptAdanIP : 引数は同上
     - DAdaptLion : 引数は同上
     - DAdaptSGD : 引数は同上
+    - Prodigy : https://github.com/konstmish/prodigy
     - AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules)
     - 任意のオプティマイザ
 
diff --git a/docs/train_README-zh.md b/docs/train_README-zh.md
index 678832d2b..454d54561 100644
--- a/docs/train_README-zh.md
+++ b/docs/train_README-zh.md
@@ -555,9 +555,10 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
     - DAdaptAdam : 参数同上
     - DAdaptAdaGrad : 参数同上
     - DAdaptAdan : 参数同上
-    - DAdaptAdanIP : 引数は同上
+    - DAdaptAdanIP : 参数同上
     - DAdaptLion : 参数同上
     - DAdaptSGD : 参数同上
+    - Prodigy : https://github.com/konstmish/prodigy
     - AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules)
     - 任何优化器
 
diff --git a/fine_tune.py b/fine_tune.py
index 308f90ef1..d0013d538 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -397,7 +397,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             current_loss = loss.detach().item()  # 平均なのでbatch sizeは関係ないはず
             if args.logging_dir is not None:
                 logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
-                if args.optimizer_type.lower().startswith("DAdapt".lower()):  # tracking d*lr value
+                if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():  # tracking d*lr value
                     logs["lr/d*lr"] = (
                         lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
                     )
diff --git a/library/train_util.py b/library/train_util.py
index 4a25e00d8..5b5d99ac3 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -2808,6 +2808,38 @@ def get_optimizer(args, trainable_params):
 
         optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
 
+    elif optimizer_type == "Prodigy".lower():
+        # Prodigy
+        # check Prodigy is installed
+        try:
+            import prodigyopt
+        except ImportError:
+            raise ImportError("No Prodigy / Prodigy がインストールされていないようです")
+
+        # check lr and lr_count, and print warning
+        actual_lr = lr
+        lr_count = 1
+        if type(trainable_params) == list and type(trainable_params[0]) == dict:
+            lrs = set()
+            actual_lr = trainable_params[0].get("lr", actual_lr)
+            for group in trainable_params:
+                lrs.add(group.get("lr", actual_lr))
+            lr_count = len(lrs)
+
+        if actual_lr <= 0.1:
+            print(
+                f"learning rate is too low. If using Prodigy, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}"
+            )
+            print("recommend option: lr=1.0 / 推奨は1.0です")
+        if lr_count > 1:
+            print(
+                f"when multiple learning rates are specified with Prodigy (e.g. for Text Encoder and U-Net), only the first one will take effect / Prodigyで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}"
+            )
+
+        print(f"use Prodigy optimizer | {optimizer_kwargs}")
+        optimizer_class = prodigyopt.Prodigy
+        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
     elif optimizer_type == "Adafactor".lower():
         # 引数を確認して適宜補正する
         if "relative_step" not in optimizer_kwargs:
diff --git a/train_db.py b/train_db.py
index 115855c13..927e79dea 100644
--- a/train_db.py
+++ b/train_db.py
@@ -384,7 +384,7 @@ def train(args):
             current_loss = loss.detach().item()
             if args.logging_dir is not None:
                 logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
-                if args.optimizer_type.lower().startswith("DAdapt".lower()):  # tracking d*lr value
+                if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():  # tracking d*lr value
                     logs["lr/d*lr"] = (
                         lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
                     )
diff --git a/train_network.py b/train_network.py
index abec3d419..da0ca1c9c 100644
--- a/train_network.py
+++ b/train_network.py
@@ -57,7 +57,7 @@ def generate_step_logs(
             logs["lr/textencoder"] = float(lrs[0])
             logs["lr/unet"] = float(lrs[-1])  # may be same to textencoder
 
-        if args.optimizer_type.lower().startswith("DAdapt".lower()):  # tracking d*lr value of unet.
+        if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():  # tracking d*lr value of unet.
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] else: idx = 0 @@ -67,7 +67,7 @@ def generate_step_logs( for i in range(idx, len(lrs)): logs[f"lr/group{i}"] = float(lrs[i]) - if args.optimizer_type.lower().startswith("DAdapt".lower()): + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): logs[f"lr/d*lr/group{i}"] = ( lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] ) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 48713fc10..d08251e12 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -476,7 +476,7 @@ def remove_model(old_ckpt_name): current_loss = loss.detach().item() if args.logging_dir is not None: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} - if args.optimizer_type.lower().startswith("DAdapt".lower()): # tracking d*lr value + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): # tracking d*lr value logs["lr/d*lr"] = ( lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] ) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index bf7d5bb0f..f44d565cc 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -515,7 +515,7 @@ def remove_model(old_ckpt_name): current_loss = loss.detach().item() if args.logging_dir is not None: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} - if args.optimizer_type.lower().startswith("DAdapt".lower()): # tracking d*lr value + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): # tracking d*lr value logs["lr/d*lr"] = ( lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] )