From 47ca35835e9a4e22b270ef51d7d1452a10b4e8c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Mon, 8 May 2023 02:19:34 +0800 Subject: [PATCH 01/13] Update train_util.py for DAdaptLion --- library/train_util.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 9a4218082..e5f23710c 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1885,7 +1885,7 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): "--optimizer_type", type=str, default="", - help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdam), DAdaptAdaGrad, DAdaptAdan, DAdaptSGD, AdaFactor", + help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdam), DAdaptAdaGrad, DAdaptAdan, DAdaptLion, DAdaptSGD, AdaFactor", ) # backward compatibility @@ -2478,7 +2478,7 @@ def task(): def get_optimizer(args, trainable_params): - # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, Lion8bit, DAdaptation, DAdaptation(DAdaptAdam), DAdaptAdaGrad, DAdaptAdan, DAdaptSGD, Adafactor" + # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, Lion8bit, DAdaptation, DAdaptation(DAdaptAdam), DAdaptAdaGrad, DAdaptAdan, DAdaptLion, DAdaptSGD, Adafactor" optimizer_type = args.optimizer_type if args.use_8bit_adam: @@ -2619,6 +2619,9 @@ def get_optimizer(args, trainable_params): elif optimizer_type == "DAdaptAdan".lower(): optimizer_class = dadaptation.DAdaptAdan print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}") + elif optimizer_type == "DAdaptLion".lower(): + optimizer_class = dadaptation.DAdaptLion + print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptSGD".lower(): optimizer_class = dadaptation.DAdaptSGD print(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}") From 7ea8da5b0579041d5af1f0e2caf657fb3c30a9d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Mon, 8 May 2023 02:23:52 +0800 Subject: [PATCH 02/13] Update train_README-zh.md for dadaptlion --- docs/train_README-zh.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/train_README-zh.md b/docs/train_README-zh.md index dbd266060..220a9dd16 100644 --- a/docs/train_README-zh.md +++ b/docs/train_README-zh.md @@ -550,8 +550,12 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b - Lion : https://github.com/lucidrains/lion-pytorch - 与过去版本中指定的 --use_lion_optimizer 相同 - SGDNesterov : [torch.optim.SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html), nesterov=True - - SGDNesterov8bit : 引数同上 + - SGDNesterov8bit : 参数同上 - DAdaptation : https://github.com/facebookresearch/dadaptation + - DAdaptAdaGrad : 参数同上 + - DAdaptAdan : 参数同上 + - DAdaptLion : 参数同上 + - DAdaptSGD : 参数同上 - AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules) - 任何优化器 From 9040b11815cf93cca0e2fd66f210ab1be23a4687 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Mon, 8 May 2023 02:24:42 +0800 Subject: [PATCH 03/13] Update train_README-ja.md for DAdaptLion --- docs/train_README-ja.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/train_README-ja.md b/docs/train_README-ja.md index f27c5c654..5b3cf7723 100644 --- 
a/docs/train_README-ja.md +++ b/docs/train_README-ja.md @@ -618,6 +618,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b - DAdaptation(DAdaptAdam) : https://github.com/facebookresearch/dadaptation - DAdaptAdaGrad : 引数は同上 - DAdaptAdan : 引数は同上 + - DAdaptLion : 引数は同上 - DAdaptSGD : 引数は同上 - AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules) - 任意のオプティマイザ From 8e0ebc7a74131eedea8b5a882ce1f1528edfdff0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Mon, 8 May 2023 11:20:04 +0800 Subject: [PATCH 04/13] add DAdatpt V3 --- library/train_util.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index e5f23710c..7ee4c73c3 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1885,7 +1885,7 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): "--optimizer_type", type=str, default="", - help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdam), DAdaptAdaGrad, DAdaptAdan, DAdaptLion, DAdaptSGD, AdaFactor", + help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor", ) # backward compatibility @@ -2478,7 +2478,7 @@ def task(): def get_optimizer(args, trainable_params): - # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, Lion8bit, DAdaptation, DAdaptation(DAdaptAdam), DAdaptAdaGrad, DAdaptAdan, DAdaptLion, DAdaptSGD, Adafactor" + # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, Lion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor" optimizer_type = args.optimizer_type if args.use_8bit_adam: @@ -2610,15 +2610,21 @@ def get_optimizer(args, trainable_params): ) # set optimizer - if optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdam".lower(): - optimizer_class = dadaptation.DAdaptAdam - print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}") + if optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdamPreprint".lower(): + optimizer_class = dadaptation.DAdaptAdamPreprint + print(f"use D-Adaptation AdamPreprint optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptAdaGrad".lower(): optimizer_class = dadaptation.DAdaptAdaGrad print(f"use D-Adaptation AdaGrad optimizer | {optimizer_kwargs}") + elif optimizer_type == "DAdaptAdam".lower(): + optimizer_class = dadaptation.DAdaptAdam + print(f"use D-Adaptation DAdaptAdam optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptAdan".lower(): optimizer_class = dadaptation.DAdaptAdan print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}") + elif optimizer_type == "DAdaptAdanIP".lower(): + optimizer_class = dadaptation.DAdaptAdanIP + print(f"use D-Adaptation DAdaptAdanIP optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptLion".lower(): optimizer_class = dadaptation.DAdaptLion print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}") From 557fd11bd0db97c3e74bb38b0ee1619dc8addcf9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Mon, 8 May 2023 11:21:30 +0800 Subject: [PATCH 05/13] Alignment --- library/train_util.py | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 7ee4c73c3..2146cd0ba 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2627,7 +2627,7 @@ def get_optimizer(args, trainable_params): print(f"use D-Adaptation DAdaptAdanIP optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptLion".lower(): optimizer_class = dadaptation.DAdaptLion - print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}") + print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptSGD".lower(): optimizer_class = dadaptation.DAdaptSGD print(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}") From a47bfc0a3b811e4355517b13e7371cf58fe69e57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Mon, 8 May 2023 17:24:09 +0800 Subject: [PATCH 06/13] Update train_util.py for experimental --- library/train_util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 2146cd0ba..fc96705c1 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2611,7 +2611,7 @@ def get_optimizer(args, trainable_params): # set optimizer if optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdamPreprint".lower(): - optimizer_class = dadaptation.DAdaptAdamPreprint + optimizer_class = dadaptation.experimental.DAdaptAdamPreprint print(f"use D-Adaptation AdamPreprint optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptAdaGrad".lower(): optimizer_class = dadaptation.DAdaptAdaGrad @@ -2623,7 +2623,7 @@ def get_optimizer(args, trainable_params): optimizer_class = dadaptation.DAdaptAdan print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptAdanIP".lower(): - optimizer_class = dadaptation.DAdaptAdanIP + optimizer_class = dadaptation.experimental.DAdaptAdanIP print(f"use D-Adaptation DAdaptAdanIP optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptLion".lower(): optimizer_class = dadaptation.DAdaptLion From 3a984e5be3ad1c5ccfd419f8c439204dfa76b8a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Tue, 9 May 2023 00:43:12 +0800 Subject: [PATCH 07/13] Update train_util.py V3 --- library/train_util.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index fc96705c1..ef5e2e8e4 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2586,6 +2586,7 @@ def get_optimizer(args, trainable_params): # check dadaptation is installed try: import dadaptation + import dadaptation.experimental as experimental except ImportError: raise ImportError("No dadaptation / dadaptation がインストールされていないようです") @@ -2611,7 +2612,7 @@ def get_optimizer(args, trainable_params): # set optimizer if optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdamPreprint".lower(): - optimizer_class = dadaptation.experimental.DAdaptAdamPreprint + optimizer_class = experimental.DAdaptAdamPreprint print(f"use D-Adaptation AdamPreprint optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptAdaGrad".lower(): optimizer_class = dadaptation.DAdaptAdaGrad @@ -2623,7 +2624,7 @@ def get_optimizer(args, trainable_params): optimizer_class = dadaptation.DAdaptAdan print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptAdanIP".lower(): - optimizer_class = dadaptation.experimental.DAdaptAdanIP + optimizer_class = 
experimental.DAdaptAdanIP print(f"use D-Adaptation DAdaptAdanIP optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptLion".lower(): optimizer_class = dadaptation.DAdaptLion From a2b3f8591b72d55c0b847d51deeab49197f56d02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Mon, 22 May 2023 23:25:39 +0800 Subject: [PATCH 08/13] Update train_README-zh.md --- docs/train_README-zh.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/train_README-zh.md b/docs/train_README-zh.md index 220a9dd16..678832d2b 100644 --- a/docs/train_README-zh.md +++ b/docs/train_README-zh.md @@ -551,9 +551,11 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b - 与过去版本中指定的 --use_lion_optimizer 相同 - SGDNesterov : [torch.optim.SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html), nesterov=True - SGDNesterov8bit : 参数同上 - - DAdaptation : https://github.com/facebookresearch/dadaptation + - DAdaptation(DAdaptAdamPreprint) : https://github.com/facebookresearch/dadaptation + - DAdaptAdam : 参数同上 - DAdaptAdaGrad : 参数同上 - DAdaptAdan : 参数同上 + - DAdaptAdanIP : 引数は同上 - DAdaptLion : 参数同上 - DAdaptSGD : 参数同上 - AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules) From 8be737472a332262692a72f19871a1a198813c97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Mon, 22 May 2023 23:25:46 +0800 Subject: [PATCH 09/13] Update train_README-ja.md --- docs/train_README-ja.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/train_README-ja.md b/docs/train_README-ja.md index 5b3cf7723..b64b18082 100644 --- a/docs/train_README-ja.md +++ b/docs/train_README-ja.md @@ -615,9 +615,11 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b - Lion8bit : 引数は同上 - SGDNesterov : [torch.optim.SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html), nesterov=True - SGDNesterov8bit : 引数は同上 - - DAdaptation(DAdaptAdam) : https://github.com/facebookresearch/dadaptation + - DAdaptation(DAdaptAdamPreprint) : https://github.com/facebookresearch/dadaptation + - DAdaptAdam : 引数は同上 - DAdaptAdaGrad : 引数は同上 - DAdaptAdan : 引数は同上 + - DAdaptAdanIP : 引数は同上 - DAdaptLion : 引数は同上 - DAdaptSGD : 引数は同上 - AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules) From 342e253b5c993c10e96050967e360073ddc83755 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Wed, 24 May 2023 02:01:28 +0800 Subject: [PATCH 10/13] Update train_util.py fix --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 03f5469f3..04a980272 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2695,7 +2695,7 @@ def get_optimizer(args, trainable_params): print(f"use D-Adaptation DAdaptAdanIP optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptLion".lower(): optimizer_class = dadaptation.DAdaptLion - print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}") + print(f"use D-Adaptation Lion optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptSGD".lower(): optimizer_class = dadaptation.DAdaptSGD print(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}") From 522eeea5233b060484d690d28aea5ead614519e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Wed, 24 
May 2023 02:02:52 +0800 Subject: [PATCH 11/13] Update train_util.py --- library/train_util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 04a980272..b3968c431 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2686,13 +2686,13 @@ def get_optimizer(args, trainable_params): print(f"use D-Adaptation AdaGrad optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptAdam".lower(): optimizer_class = dadaptation.DAdaptAdam - print(f"use D-Adaptation DAdaptAdam optimizer | {optimizer_kwargs}") + print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptAdan".lower(): optimizer_class = dadaptation.DAdaptAdan print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptAdanIP".lower(): optimizer_class = experimental.DAdaptAdanIP - print(f"use D-Adaptation DAdaptAdanIP optimizer | {optimizer_kwargs}") + print(f"use D-Adaptation AdanIP optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptLion".lower(): optimizer_class = dadaptation.DAdaptLion print(f"use D-Adaptation Lion optimizer | {optimizer_kwargs}") From 7e185f8d5e54cda502edc8fb87640f259ec77179 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Tue, 13 Jun 2023 03:21:45 +0800 Subject: [PATCH 12/13] support Prodigy --- docs/train_README-ja.md | 1 + docs/train_README-zh.md | 3 ++- fine_tune.py | 2 +- library/train_util.py | 32 ++++++++++++++++++++++++++++++++ train_db.py | 2 +- train_network.py | 4 ++-- train_textual_inversion.py | 2 +- train_textual_inversion_XTI.py | 2 +- 8 files changed, 41 insertions(+), 7 deletions(-) diff --git a/docs/train_README-ja.md b/docs/train_README-ja.md index b64b18082..158363b39 100644 --- a/docs/train_README-ja.md +++ b/docs/train_README-ja.md @@ -622,6 +622,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b - DAdaptAdanIP : 引数は同上 - DAdaptLion : 引数は同上 - DAdaptSGD : 引数は同上 + - Prodigy : https://github.com/konstmish/prodigy - AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules) - 任意のオプティマイザ diff --git a/docs/train_README-zh.md b/docs/train_README-zh.md index 678832d2b..454d54561 100644 --- a/docs/train_README-zh.md +++ b/docs/train_README-zh.md @@ -555,9 +555,10 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b - DAdaptAdam : 参数同上 - DAdaptAdaGrad : 参数同上 - DAdaptAdan : 参数同上 - - DAdaptAdanIP : 引数は同上 + - DAdaptAdanIP : 参数同上 - DAdaptLion : 参数同上 - DAdaptSGD : 参数同上 + - Prodigy : https://github.com/konstmish/prodigy - AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules) - 任何优化器 diff --git a/fine_tune.py b/fine_tune.py index 201d49525..b8695bc13 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -393,7 +393,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず if args.logging_dir is not None: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} - if args.optimizer_type.lower().startswith("DAdapt".lower()): # tracking d*lr value + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy": # tracking d*lr value logs["lr/d*lr"] = ( lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] ) diff --git a/library/train_util.py b/library/train_util.py index 844faca75..72f5973bb 100644 --- 
a/library/train_util.py +++ b/library/train_util.py @@ -2732,6 +2732,38 @@ def get_optimizer(args, trainable_params): optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + elif optimizer_type == "Prodigy".lower(): + # Prodigy + # check Prodigy is installed + try: + import prodigyopt + except ImportError: + raise ImportError("No Prodigy / Prodigy がインストールされていないようです") + + # check lr and lr_count, and print warning + actual_lr = lr + lr_count = 1 + if type(trainable_params) == list and type(trainable_params[0]) == dict: + lrs = set() + actual_lr = trainable_params[0].get("lr", actual_lr) + for group in trainable_params: + lrs.add(group.get("lr", actual_lr)) + lr_count = len(lrs) + + if actual_lr <= 0.1: + print( + f"learning rate is too low. If using Prodigy, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}" + ) + print("recommend option: lr=1.0 / 推奨は1.0です") + if lr_count > 1: + print( + f"when multiple learning rates are specified with Prodigy (e.g. for Text Encoder and U-Net), only the first one will take effect / Prodigyで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}" + ) + + print(f"use Prodigy optimizer | {optimizer_kwargs}") + optimizer_class = prodigyopt.Prodigy + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + elif optimizer_type == "Adafactor".lower(): # 引数を確認して適宜補正する if "relative_step" not in optimizer_kwargs: diff --git a/train_db.py b/train_db.py index c81a092de..03dbe8dd5 100644 --- a/train_db.py +++ b/train_db.py @@ -380,7 +380,7 @@ def train(args): current_loss = loss.detach().item() if args.logging_dir is not None: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} - if args.optimizer_type.lower().startswith("DAdapt".lower()): # tracking d*lr value + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy": # tracking d*lr value logs["lr/d*lr"] = ( lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] ) diff --git a/train_network.py b/train_network.py index b62aef7ee..48e39faa1 100644 --- a/train_network.py +++ b/train_network.py @@ -57,7 +57,7 @@ def generate_step_logs( logs["lr/textencoder"] = float(lrs[0]) logs["lr/unet"] = float(lrs[-1]) # may be same to textencoder - if args.optimizer_type.lower().startswith("DAdapt".lower()): # tracking d*lr value of unet. + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy": # tracking d*lr value of unet. 
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] else: idx = 0 @@ -67,7 +67,7 @@ def generate_step_logs( for i in range(idx, len(lrs)): logs[f"lr/group{i}"] = float(lrs[i]) - if args.optimizer_type.lower().startswith("DAdapt".lower()): + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy": logs[f"lr/d*lr/group{i}"] = ( lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] ) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 8be0703d6..cdaed648b 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -473,7 +473,7 @@ def remove_model(old_ckpt_name): current_loss = loss.detach().item() if args.logging_dir is not None: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} - if args.optimizer_type.lower().startswith("DAdapt".lower()): # tracking d*lr value + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy": # tracking d*lr value logs["lr/d*lr"] = ( lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] ) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 7b734f283..d685dcc4a 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -506,7 +506,7 @@ def remove_model(old_ckpt_name): current_loss = loss.detach().item() if args.logging_dir is not None: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} - if args.optimizer_type.lower().startswith("DAdapt".lower()): # tracking d*lr value + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy": # tracking d*lr value logs["lr/d*lr"] = ( lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] ) From acf8a1f2e032795ada1f6b76f83b451e9953023a Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Tue, 13 Jun 2023 04:10:21 +0800 Subject: [PATCH 13/13] add lower --- train_db.py | 2 +- train_network.py | 4 ++-- train_textual_inversion.py | 2 +- train_textual_inversion_XTI.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/train_db.py b/train_db.py index 03dbe8dd5..dabef6247 100644 --- a/train_db.py +++ b/train_db.py @@ -380,7 +380,7 @@ def train(args): current_loss = loss.detach().item() if args.logging_dir is not None: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} - if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy": # tracking d*lr value + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): # tracking d*lr value logs["lr/d*lr"] = ( lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] ) diff --git a/train_network.py b/train_network.py index 48e39faa1..ec580314d 100644 --- a/train_network.py +++ b/train_network.py @@ -57,7 +57,7 @@ def generate_step_logs( logs["lr/textencoder"] = float(lrs[0]) logs["lr/unet"] = float(lrs[-1]) # may be same to textencoder - if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy": # tracking d*lr value of unet. + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): # tracking d*lr value of unet. 
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] else: idx = 0 @@ -67,7 +67,7 @@ def generate_step_logs( for i in range(idx, len(lrs)): logs[f"lr/group{i}"] = float(lrs[i]) - if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy": + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): logs[f"lr/d*lr/group{i}"] = ( lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] ) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index cdaed648b..5af8aa0f8 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -473,7 +473,7 @@ def remove_model(old_ckpt_name): current_loss = loss.detach().item() if args.logging_dir is not None: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} - if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy": # tracking d*lr value + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): # tracking d*lr value logs["lr/d*lr"] = ( lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] ) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index d685dcc4a..ca0c6f9b6 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -506,7 +506,7 @@ def remove_model(old_ckpt_name): current_loss = loss.detach().item() if args.logging_dir is not None: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} - if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy": # tracking d*lr value + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): # tracking d*lr value logs["lr/d*lr"] = ( lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] )