
Fix issue with one-cycle logic #728

Merged · 2 commits · Sep 28, 2022
src/accelerate/scheduler.py (4 changes: 3 additions & 1 deletion)

```diff
@@ -69,7 +69,9 @@ def step(self, *args, **kwargs):
         num_processes = AcceleratorState().num_processes
         for _ in range(num_processes):
             # Special case when using OneCycle and `drop_last` was not used
-            if getattr(self.scheduler, "total_steps", 0) <= self.scheduler.last_epoch:
+            if hasattr(self.scheduler, "total_steps") and self.scheduler._step_count <= self.scheduler.total_steps:
+                self.scheduler.step(*args, **kwargs)
+            else:
                 self.scheduler.step(*args, **kwargs)
 
     # Passthroughs
```
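Some context on the failure mode the comment names (my explanation, not text from the PR): PyTorch's `OneCycleLR` raises a `ValueError` once it is stepped more than `total_steps` times, and the old comparison `total_steps <= last_epoch` was effectively inverted for one-cycle schedules: it skipped every in-budget step and only let the scheduler advance once `last_epoch` had already reached the budget. The new guard checks `_step_count` (which PyTorch initializes to 1 when the scheduler is constructed) against `total_steps` instead. A minimal repro sketch of the original crash; illustrative only, the tensor shapes and loader sizes are made up:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# With drop_last=False, 5 samples at batch_size=2 yield 3 batches,
# one more than the 2 steps the schedule below is budgeted for.
dataset = TensorDataset(torch.randn(5, 2), torch.randn(5, 4))
loader = DataLoader(dataset, batch_size=2, drop_last=False)

model = torch.nn.Linear(2, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1.0)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=2, epochs=1)

for x, y in loader:
    optimizer.zero_grad()
    torch.nn.functional.mse_loss(model(x), y).backward()
    optimizer.step()
    scheduler.step()  # the third call exceeds total_steps=2 and raises ValueError
```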
tests/test_scheduler.py (56 changes: 44 additions & 12 deletions)

```diff
@@ -21,12 +21,30 @@
 from accelerate.test_utils import require_cpu
 
 
-def scheduler_test(num_processes=2, step_scheduler_with_optimizer=True, split_batches=False):
+def one_cycle_test(num_processes=2, step_scheduler_with_optimizer=True, split_batches=False):
     accelerator = Accelerator(step_scheduler_with_optimizer=step_scheduler_with_optimizer, split_batches=split_batches)
     model = torch.nn.Linear(2, 4)
     optimizer = torch.optim.AdamW(model.parameters(), lr=1.0)
-    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda n: 1 - n / 10)
+    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=2, epochs=1)
+    model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler)
+
+    # Optimizer has stepped
+    scheduler.step()
+    if step_scheduler_with_optimizer or (num_processes == 1):
+        assert (
+            scheduler.scheduler.last_epoch == num_processes
+        ), f"Last Epoch ({scheduler.scheduler.last_epoch}) != Num Processes ({num_processes})"
+    else:
+        assert (
+            scheduler.scheduler.last_epoch != num_processes
+        ), f"Last Epoch ({scheduler.scheduler.last_epoch}) == Num Processes ({num_processes})"
+
+
+def lambda_test(num_processes=2, step_scheduler_with_optimizer=True, split_batches=False):
+    accelerator = Accelerator(step_scheduler_with_optimizer=step_scheduler_with_optimizer, split_batches=split_batches)
+    model = torch.nn.Linear(2, 4)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=1.0)
+    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda n: 1 - n / 10)
     model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler)
 
     # Optimizer has stepped
```
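A quick sanity check on the numbers in `one_cycle_test` (my walkthrough, not part of the PR): `steps_per_epoch=2, epochs=1` gives `total_steps == 2`, and with `step_scheduler_with_optimizer=True` the prepared scheduler steps once per process, so a single `scheduler.step()` on two processes lands `last_epoch` exactly on 2. The sketch below replays that loop against a bare `OneCycleLR`; the diff to the test class continues after it.

```python
import torch

model = torch.nn.Linear(2, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1.0)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=2, epochs=1)

assert scheduler.total_steps == 2  # steps_per_epoch * epochs
assert scheduler._step_count == 1  # PyTorch steps once during __init__

num_processes = 2
for _ in range(num_processes):  # what the patched AcceleratedScheduler.step() loops over
    if hasattr(scheduler, "total_steps") and scheduler._step_count <= scheduler.total_steps:
        optimizer.step()
        scheduler.step()

assert scheduler.last_epoch == num_processes  # the assertion the test makes
```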
```diff
@@ -49,16 +67,30 @@ def scheduler_test(num_processes=2, step_scheduler_with_optimizer=True, split_batches=False):
 
 @require_cpu
 class SchedulerTester(unittest.TestCase):
-    def test_scheduler_steps_with_optimizer_single_process(self):
-        debug_launcher(partial(scheduler_test, num_processes=1), num_processes=1)
-        debug_launcher(partial(scheduler_test, num_processes=1, split_batches=True), num_processes=1)
+    def test_lambda_scheduler_steps_with_optimizer_single_process(self):
+        debug_launcher(partial(lambda_test, num_processes=1), num_processes=1)
+        debug_launcher(partial(lambda_test, num_processes=1, split_batches=True), num_processes=1)
+
+    def test_one_cycle_scheduler_steps_with_optimizer_single_process(self):
+        debug_launcher(partial(one_cycle_test, num_processes=1), num_processes=1)
+        debug_launcher(partial(one_cycle_test, num_processes=1, split_batches=True), num_processes=1)
+
+    def test_lambda_scheduler_not_step_with_optimizer_single_process(self):
+        debug_launcher(partial(lambda_test, num_processes=1, step_scheduler_with_optimizer=False), num_processes=1)
+
+    def test_one_cycle_scheduler_not_step_with_optimizer_single_process(self):
+        debug_launcher(partial(one_cycle_test, num_processes=1, step_scheduler_with_optimizer=False), num_processes=1)
+
+    def test_lambda_scheduler_steps_with_optimizer_multiprocess(self):
+        debug_launcher(lambda_test)
+        debug_launcher(partial(lambda_test, num_processes=1, split_batches=True), num_processes=1)
 
-    def test_scheduler_not_step_with_optimizer_single_process(self):
-        debug_launcher(partial(scheduler_test, num_processes=1, step_scheduler_with_optimizer=False), num_processes=1)
+    def test_one_cycle_scheduler_steps_with_optimizer_multiprocess(self):
+        debug_launcher(one_cycle_test)
+        debug_launcher(partial(one_cycle_test, num_processes=1, split_batches=True), num_processes=1)
 
-    def test_scheduler_steps_with_optimizer_multiprocess(self):
-        debug_launcher(scheduler_test)
-        debug_launcher(partial(scheduler_test, num_processes=1, split_batches=True), num_processes=1)
+    def test_lambda_scheduler_not_step_with_optimizer_multiprocess(self):
+        debug_launcher(partial(lambda_test, step_scheduler_with_optimizer=False))
 
-    def test_scheduler_not_step_with_optimizer_multiprocess(self):
-        debug_launcher(partial(scheduler_test, step_scheduler_with_optimizer=False))
+    def test_one_cycle_scheduler_not_step_with_optimizer_multiprocess(self):
+        debug_launcher(partial(one_cycle_test, step_scheduler_with_optimizer=False))
```
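These tests are CPU-only (`@require_cpu`) and rely on `debug_launcher` from `accelerate.test_utils`, which runs the given function under an emulated multi-process setup (two processes as used here, matching the tests' `num_processes=2` default). A sketch of driving the new test directly, assuming it is run from the repository root; the `tests.test_scheduler` import path is my assumption and may need adjusting:

```python
from functools import partial

from accelerate.test_utils import debug_launcher
from tests.test_scheduler import one_cycle_test  # assumed import path

# Two emulated processes: one scheduler.step() advances the wrapped
# OneCycleLR twice, so last_epoch should equal num_processes (2).
debug_launcher(one_cycle_test)

# Scheduler decoupled from the optimizer: only one raw step happens,
# so last_epoch != num_processes, per the test's else branch.
debug_launcher(partial(one_cycle_test, step_scheduler_with_optimizer=False))
```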