support for dadaptation V3 #530

Merged 13 commits on May 25, 2023
docs/train_README-ja.md: 4 additions, 1 deletion
@@ -615,9 +615,12 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
- Lion8bit : Same arguments as above
- SGDNesterov : [torch.optim.SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html), nesterov=True
- SGDNesterov8bit : Same arguments as above
- DAdaptation(DAdaptAdam) : https://github.com/facebookresearch/dadaptation
- DAdaptation(DAdaptAdamPreprint) : https://github.com/facebookresearch/dadaptation
- DAdaptAdam : Same arguments as above
- DAdaptAdaGrad : Same arguments as above
- DAdaptAdan : Same arguments as above
- DAdaptAdanIP : Same arguments as above
- DAdaptLion : Same arguments as above
- DAdaptSGD : Same arguments as above
- AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules)
- Any optimizer
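The D-Adaptation variants documented above all come from the dadaptation package (V3). As a hedged illustration of where each class lives, the sketch below instantiates a few of them against a placeholder model; the split between the top-level package and the experimental submodule mirrors the imports added to library/train_util.py later in this PR, while the toy model and lr=1.0 (the value commonly suggested for D-Adaptation) are assumptions for the example only.

```python
# Illustrative sketch (not part of this PR): constructing the D-Adaptation V3
# optimizers that the README now lists. The placeholder model and lr=1.0 are
# assumptions for demonstration.
import torch
import dadaptation
import dadaptation.experimental as experimental

model = torch.nn.Linear(16, 4)  # placeholder model

# Stable variants are exposed at the package top level ...
opt_adam = dadaptation.DAdaptAdam(model.parameters(), lr=1.0)
opt_lion = dadaptation.DAdaptLion(model.parameters(), lr=1.0)
opt_sgd = dadaptation.DAdaptSGD(model.parameters(), lr=1.0)

# ... while the preprint/IP variants live in the experimental submodule.
opt_adam_preprint = experimental.DAdaptAdamPreprint(model.parameters(), lr=1.0)
opt_adan_ip = experimental.DAdaptAdanIP(model.parameters(), lr=1.0)
```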
docs/train_README-zh.md: 8 additions, 2 deletions
@@ -550,8 +550,14 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
- Lion : https://github.com/lucidrains/lion-pytorch
- Same as specifying --use_lion_optimizer in earlier versions
- SGDNesterov : [torch.optim.SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html), nesterov=True
- SGDNesterov8bit : Same arguments as above
- DAdaptation : https://github.com/facebookresearch/dadaptation
- SGDNesterov8bit : Same parameters as above
- DAdaptation(DAdaptAdamPreprint) : https://github.com/facebookresearch/dadaptation
- DAdaptAdam : Same parameters as above
- DAdaptAdaGrad : Same parameters as above
- DAdaptAdan : Same parameters as above
- DAdaptAdanIP : Same parameters as above
- DAdaptLion : Same parameters as above
- DAdaptSGD : Same parameters as above
- AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules)
- Any optimizer

library/train_util.py: 15 additions, 5 deletions
@@ -1940,7 +1940,7 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
"--optimizer_type",
type=str,
default="",
help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdam), DAdaptAdaGrad, DAdaptAdan, DAdaptSGD, AdaFactor",
help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor",
)

# backward compatibility
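As an aside, the new optimizer names added to the --optimizer_type help string above are ordinary string choices; the matching in get_optimizer further down appears to be case-insensitive, since the comparisons are made against lower-cased names. A small standalone parse, with the argument definition copied from the hunk above and everything else illustrative:

```python
# Standalone illustration: parsing --optimizer_type with one of the new values.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--optimizer_type",
    type=str,
    default="",
    help="Optimizer to use: AdamW (default), ..., DAdaptAdam, DAdaptAdanIP, DAdaptLion, DAdaptSGD, ...",
)

args = parser.parse_args(["--optimizer_type", "DAdaptLion"])
print(args.optimizer_type.lower())  # -> "dadaptlion", the lower-cased form compared in get_optimizer()
```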
@@ -2545,7 +2545,7 @@ def task():


def get_optimizer(args, trainable_params):
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, Lion8bit, DAdaptation, DAdaptation(DAdaptAdam), DAdaptAdaGrad, DAdaptAdan, DAdaptSGD, Adafactor"
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, Lion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"

optimizer_type = args.optimizer_type
if args.use_8bit_adam:
@@ -2653,6 +2653,7 @@ def get_optimizer(args, trainable_params):
# check dadaptation is installed
try:
import dadaptation
import dadaptation.experimental as experimental
except ImportError:
raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
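Since DAdaptLion and the experimental submodule are only present in dadaptation V3, a quick capability check can tell an outdated install apart from a missing one. The sketch below is illustrative, based on the classes this PR references rather than any official version API:

```python
# Hedged sketch: check whether the installed dadaptation exposes the V3 pieces
# this PR relies on (DAdaptLion and dadaptation.experimental).
try:
    import dadaptation
except ImportError as e:
    raise ImportError("No dadaptation / please `pip install dadaptation`") from e

try:
    import dadaptation.experimental  # noqa: F401  # only available in V3
    has_v3 = hasattr(dadaptation, "DAdaptLion")
except ImportError:
    has_v3 = False

if not has_v3:
    print("dadaptation looks older than V3; DAdaptAdamPreprint, DAdaptAdanIP and DAdaptLion will be unavailable")
```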

@@ -2677,15 +2678,24 @@
)

# set optimizer
if optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdam".lower():
optimizer_class = dadaptation.DAdaptAdam
print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
if optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdamPreprint".lower():
optimizer_class = experimental.DAdaptAdamPreprint
print(f"use D-Adaptation AdamPreprint optimizer | {optimizer_kwargs}")
elif optimizer_type == "DAdaptAdaGrad".lower():
optimizer_class = dadaptation.DAdaptAdaGrad
print(f"use D-Adaptation AdaGrad optimizer | {optimizer_kwargs}")
elif optimizer_type == "DAdaptAdam".lower():
optimizer_class = dadaptation.DAdaptAdam
print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
elif optimizer_type == "DAdaptAdan".lower():
optimizer_class = dadaptation.DAdaptAdan
print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}")
elif optimizer_type == "DAdaptAdanIP".lower():
optimizer_class = experimental.DAdaptAdanIP
print(f"use D-Adaptation AdanIP optimizer | {optimizer_kwargs}")
elif optimizer_type == "DAdaptLion".lower():
optimizer_class = dadaptation.DAdaptLion
print(f"use D-Adaptation Lion optimizer | {optimizer_kwargs}")
elif optimizer_type == "DAdaptSGD".lower():
optimizer_class = dadaptation.DAdaptSGD
print(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}")
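The if/elif chain added above amounts to a name-to-class table keyed by the lower-cased --optimizer_type value. The condensed restatement below is hypothetical, not the PR's actual code, and the construction call assumes the (params, lr, **kwargs) pattern used for the other optimizers in get_optimizer:

```python
# Condensed, hypothetical restatement of the dispatch above: lower-cased
# --optimizer_type value -> D-Adaptation class.
import dadaptation
import dadaptation.experimental as experimental

DADAPT_OPTIMIZERS = {
    "dadaptation": experimental.DAdaptAdamPreprint,       # legacy name kept as an alias
    "dadaptadampreprint": experimental.DAdaptAdamPreprint,
    "dadaptadagrad": dadaptation.DAdaptAdaGrad,
    "dadaptadam": dadaptation.DAdaptAdam,
    "dadaptadan": dadaptation.DAdaptAdan,
    "dadaptadanip": experimental.DAdaptAdanIP,
    "dadaptlion": dadaptation.DAdaptLion,
    "dadaptsgd": dadaptation.DAdaptSGD,
}

def build_dadapt_optimizer(optimizer_type, trainable_params, lr=1.0, **optimizer_kwargs):
    """Illustrative helper; the (params, lr, **kwargs) call pattern is an assumption."""
    optimizer_class = DADAPT_OPTIMIZERS[optimizer_type.lower()]
    print(f"use D-Adaptation {optimizer_class.__name__} | {optimizer_kwargs}")
    return optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
```

Read this way, the backward-compatibility change is visible at a glance: the old "DAdaptation" value now resolves to DAdaptAdamPreprint rather than DAdaptAdam, matching the first branch of the chain above.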