
Add scheduler #85

Merged
merged 11 commits into main from add_scheduler
Jan 22, 2022
Conversation

@rentainhe (Contributor) commented Jan 6, 2022

TODO

Add more schedulers

  • PlateauLRScheduler (ReduceLROnPlateau argument passing needs a fix)
  • PolyLRScheduler
  • TanhLRScheduler
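For reference, the decay rule behind PolyLRScheduler can be sketched in a few lines. The formula below follows timm's polynomial-decay convention; the function name `poly_lr` is illustrative, not flowvision's API.

```python
# Minimal sketch of polynomial LR decay, the rule behind PolyLRScheduler.
# Follows timm's convention; illustration only, not the flowvision code.
def poly_lr(base_lr, t, t_max, power=1.0, lr_min=0.0):
    """Decay base_lr polynomially toward lr_min over t_max steps."""
    frac = min(t, t_max) / t_max
    return lr_min + (base_lr - lr_min) * (1.0 - frac) ** power

# base_lr=0.1, linear decay (power=1.0) over 100 epochs
print(poly_lr(0.1, 0, 100))    # start of training
print(poly_lr(0.1, 50, 100))   # halfway through the schedule
print(poly_lr(0.1, 100, 100))  # fully decayed
```

With `power=1.0` this is a straight line from `base_lr` to `lr_min`; larger powers front-load the decay.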

OneFlow

  • ReduceLROnPlateau's step does not accept an epoch argument

param_group['lr'] = self.restore_lr[i]  # restore the lr saved before the warmup override
self.restore_lr = None

self.lr_scheduler.step(metric, epoch)  # step the base scheduler
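For context, the two fragments above belong to a plateau-style warmup wrapper: during warmup the wrapper overrides each param group's lr, and afterwards it restores the saved lrs before delegating to the wrapped plateau scheduler. A minimal self-contained sketch of that pattern (all class and attribute names here are illustrative stand-ins, not flowvision's actual code):

```python
# Sketch of the restore-then-step pattern quoted above. Illustrative only.
class DummyPlateau:
    """Stand-in for ReduceLROnPlateau: halves lr when the metric stalls."""
    def __init__(self, optimizer):
        self.optimizer = optimizer
        self.best = float("inf")

    def step(self, metric, epoch=None):
        if metric >= self.best:  # no improvement -> decay
            for group in self.optimizer.param_groups:
                group["lr"] *= 0.5
        else:
            self.best = metric

class PlateauWrapper:
    def __init__(self, optimizer, warmup_t=3, warmup_lr_init=0.001):
        self.optimizer = optimizer
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.restore_lr = None
        self.lr_scheduler = DummyPlateau(optimizer)

    def step(self, epoch, metric):
        if epoch < self.warmup_t:
            if self.restore_lr is None:
                # remember the real lrs once, then override them for warmup
                self.restore_lr = [g["lr"] for g in self.optimizer.param_groups]
            for g in self.optimizer.param_groups:
                g["lr"] = self.warmup_lr_init
        else:
            if self.restore_lr is not None:
                # restore the lrs saved before the warmup override
                for i, g in enumerate(self.optimizer.param_groups):
                    g["lr"] = self.restore_lr[i]
                self.restore_lr = None
            self.lr_scheduler.step(metric, epoch)  # step the base scheduler

class DummyOpt:
    """Minimal optimizer-shaped object for the demo."""
    def __init__(self):
        self.param_groups = [{"lr": 0.1}]

opt = DummyOpt()
sched = PlateauWrapper(opt)
sched.step(0, metric=1.0)  # warmup: lr overridden to warmup_lr_init
sched.step(3, metric=1.0)  # warmup over: lr restored, plateau scheduler steps
```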
@rentainhe (Contributor, Author):

flow.optim.lr_scheduler.ReduceLROnPlateau's step function is missing an epoch parameter; it should be aligned with the corresponding argument in torch.
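For reference, the torch signature being discussed was `ReduceLROnPlateau.step(metrics, epoch=None)` at the time of this PR (torch later deprecated the explicit `epoch` argument). A tiny stand-in class showing just that calling convention, purely illustrative and not oneflow's implementation:

```python
# Illustrative stand-in for the step(metrics, epoch=None) convention the
# comment asks oneflow to align with; not the real scheduler logic.
class ReduceLROnPlateauLike:
    def __init__(self):
        self.last_epoch = 0

    def step(self, metrics, epoch=None):
        # without an explicit epoch, advance the internal counter;
        # an explicit epoch overrides it (the behavior torch deprecated)
        self.last_epoch = self.last_epoch + 1 if epoch is None else epoch
        return self.last_epoch
```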

A Collaborator commented:

The ReduceLROnPlateau in the oneflow repository also needs to be changed; let's fix that first before adding this, otherwise it will raise an error.

A Collaborator commented:

Any feature we add must be verified to actually run.

@Ldpe2G Ldpe2G requested a review from oneflow-ci-bot January 7, 2022 01:38
@rentainhe (Contributor, Author):

Plot the LR scheduler curve

# Plot the learning-rate curve produced by MultiStepLRScheduler
from matplotlib import pyplot as plt
import oneflow as flow
from flowvision.scheduler import MultiStepLRScheduler
from flowvision.models import ModelCreator

model = ModelCreator.create_model("alexnet")
optimizer = flow.optim.SGD(model.parameters())

num_epoch = 100
scheduler = MultiStepLRScheduler(
    optimizer,
    decay_t=[30, 60, 90],   # decay the lr at epochs 30, 60, 90
    decay_rate=0.1,
    warmup_lr_init=0.00001,
    warmup_t=10,            # 10 warmup epochs
)

# record the lr actually set on the optimizer at each epoch
lr_per_epoch = []
for epoch in range(num_epoch):
    scheduler.step(epoch)
    lr_per_epoch.append(optimizer.param_groups[0]["lr"])

plt.plot(range(num_epoch), lr_per_epoch)
plt.savefig("./test_multi_step.jpg")

@Ldpe2G Ldpe2G merged commit 356207d into main Jan 22, 2022
@Ldpe2G Ldpe2G deleted the add_scheduler branch January 22, 2022 04:25