From 67ba203e743da89508877697d2cfc7853c02f84c Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Wed, 28 Sep 2022 19:55:53 +0000
Subject: [PATCH 1/2] Fixed!

---
 src/accelerate/scheduler.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/accelerate/scheduler.py b/src/accelerate/scheduler.py
index 4ae91a657a5..b917c799a1d 100644
--- a/src/accelerate/scheduler.py
+++ b/src/accelerate/scheduler.py
@@ -69,7 +69,7 @@ def step(self, *args, **kwargs):
             num_processes = AcceleratorState().num_processes
             for _ in range(num_processes):
                 # Special case when using OneCycle and `drop_last` was not used
-                if getattr(self.scheduler, "total_steps", 0) <= self.scheduler.last_epoch:
+                if self.scheduler.last_epoch <= getattr(self.scheduler, "total_steps", 0):
                     self.scheduler.step(*args, **kwargs)
 
     # Passthroughs

From d432377ea119af2db35d5da3165529905f801a41 Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Wed, 28 Sep 2022 20:25:35 +0000
Subject: [PATCH 2/2] Fix and write tests

---
 src/accelerate/scheduler.py | 4 ++-
 tests/test_scheduler.py | 56 +++++++++++++++++++++++++++++--------
 2 files changed, 47 insertions(+), 13 deletions(-)

diff --git a/src/accelerate/scheduler.py b/src/accelerate/scheduler.py
index b917c799a1d..835d4e0d961 100644
--- a/src/accelerate/scheduler.py
+++ b/src/accelerate/scheduler.py
@@ -69,7 +69,9 @@ def step(self, *args, **kwargs):
             num_processes = AcceleratorState().num_processes
             for _ in range(num_processes):
                 # Special case when using OneCycle and `drop_last` was not used
-                if self.scheduler.last_epoch <= getattr(self.scheduler, "total_steps", 0):
+                if hasattr(self.scheduler, "total_steps") and self.scheduler._step_count <= self.scheduler.total_steps:
+                    self.scheduler.step(*args, **kwargs)
+                else:
                     self.scheduler.step(*args, **kwargs)
 
     # Passthroughs

diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py
index be4f975fb35..c1ef18f1e66 100644
--- a/tests/test_scheduler.py
+++ b/tests/test_scheduler.py
@@ -21,12 +21,30 @@ from
accelerate.test_utils import require_cpu -def scheduler_test(num_processes=2, step_scheduler_with_optimizer=True, split_batches=False): +def one_cycle_test(num_processes=2, step_scheduler_with_optimizer=True, split_batches=False): accelerator = Accelerator(step_scheduler_with_optimizer=step_scheduler_with_optimizer, split_batches=split_batches) model = torch.nn.Linear(2, 4) optimizer = torch.optim.AdamW(model.parameters(), lr=1.0) - scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda n: 1 - n / 10) + scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=2, epochs=1) + model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler) + # Optimizer has stepped + scheduler.step() + if step_scheduler_with_optimizer or (num_processes == 1): + assert ( + scheduler.scheduler.last_epoch == num_processes + ), f"Last Epoch ({scheduler.scheduler.last_epoch}) != Num Processes ({num_processes})" + else: + assert ( + scheduler.scheduler.last_epoch != num_processes + ), f"Last Epoch ({scheduler.scheduler.last_epoch}) == Num Processes ({num_processes})" + + +def lambda_test(num_processes=2, step_scheduler_with_optimizer=True, split_batches=False): + accelerator = Accelerator(step_scheduler_with_optimizer=step_scheduler_with_optimizer, split_batches=split_batches) + model = torch.nn.Linear(2, 4) + optimizer = torch.optim.AdamW(model.parameters(), lr=1.0) + scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda n: 1 - n / 10) model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler) # Optimizer has stepped @@ -49,16 +67,30 @@ def scheduler_test(num_processes=2, step_scheduler_with_optimizer=True, split_ba @require_cpu class SchedulerTester(unittest.TestCase): - def test_scheduler_steps_with_optimizer_single_process(self): - debug_launcher(partial(scheduler_test, num_processes=1), num_processes=1) - debug_launcher(partial(scheduler_test, num_processes=1, split_batches=True), 
num_processes=1) + def test_lambda_scheduler_steps_with_optimizer_single_process(self): + debug_launcher(partial(lambda_test, num_processes=1), num_processes=1) + debug_launcher(partial(lambda_test, num_processes=1, split_batches=True), num_processes=1) + + def test_one_cycle_scheduler_steps_with_optimizer_single_process(self): + debug_launcher(partial(one_cycle_test, num_processes=1), num_processes=1) + debug_launcher(partial(one_cycle_test, num_processes=1, split_batches=True), num_processes=1) + + def test_lambda_scheduler_not_step_with_optimizer_single_process(self): + debug_launcher(partial(lambda_test, num_processes=1, step_scheduler_with_optimizer=False), num_processes=1) + + def test_one_cycle_scheduler_not_step_with_optimizer_single_process(self): + debug_launcher(partial(one_cycle_test, num_processes=1, step_scheduler_with_optimizer=False), num_processes=1) + + def test_lambda_scheduler_steps_with_optimizer_multiprocess(self): + debug_launcher(lambda_test) + debug_launcher(partial(lambda_test, num_processes=1, split_batches=True), num_processes=1) - def test_scheduler_not_step_with_optimizer_single_process(self): - debug_launcher(partial(scheduler_test, num_processes=1, step_scheduler_with_optimizer=False), num_processes=1) + def test_one_cycle_scheduler_steps_with_optimizer_multiprocess(self): + debug_launcher(one_cycle_test) + debug_launcher(partial(one_cycle_test, num_processes=1, split_batches=True), num_processes=1) - def test_scheduler_steps_with_optimizer_multiprocess(self): - debug_launcher(scheduler_test) - debug_launcher(partial(scheduler_test, num_processes=1, split_batches=True), num_processes=1) + def test_lambda_scheduler_not_step_with_optimizer_multiprocess(self): + debug_launcher(partial(lambda_test, step_scheduler_with_optimizer=False)) - def test_scheduler_not_step_with_optimizer_multiprocess(self): - debug_launcher(partial(scheduler_test, step_scheduler_with_optimizer=False)) + def 
test_one_cycle_scheduler_not_step_with_optimizer_multiprocess(self): + debug_launcher(partial(one_cycle_test, step_scheduler_with_optimizer=False))