From cad9d778793b7b0720e2fba0fb92d31af8bbaec0 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Tue, 12 Jul 2022 16:54:09 +0200 Subject: [PATCH 1/2] Make grad accum work with dispatch dl --- src/accelerate/accelerator.py | 6 +- src/accelerate/data_loader.py | 113 ++++++++++-------- .../test_utils/scripts/test_sync.py | 63 +++++++--- 3 files changed, 108 insertions(+), 74 deletions(-) diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 7e99c76efc7..354089c3876 100644 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -240,10 +240,6 @@ def __init__( raise NotImplementedError( "Gradient accumulation on TPU is not supported. Pass in `gradient_accumulation_steps=1`" ) - if dispatch_batches: - raise NotImplementedError( - "Gradient accumulation with dispatched dataloaders is not supported. Pass in `gradient_accumulation_steps=1` or `dispatch_batches=False`" - ) self.gradient_accumulation_steps = gradient_accumulation_steps self.device_placement = device_placement @@ -397,7 +393,7 @@ def _do_sync(self): self.gradient_state._set_sync_gradients(True) else: self.step += 1 - self.gradient_state._set_sync_gradients((self.step % self.gradient_accumulation_steps) == 0) + self.gradient_state._set_sync_gradients((self.step % self.gradient_accumulation_steps) == 0) @property def sync_gradients(self): diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index a2c7a4a9bcc..fddfe895533 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -392,74 +392,87 @@ def __init__(self, dataset, split_batches: bool = False, **kwargs): self.gradient_state = GradientState() self.state = AcceleratorState() - def __iter__(self): - state = AcceleratorState() - if state.process_index == 0: - # We only iterate through the DataLoader on process 0. - main_iterator = super().__iter__() - stop_iteration = False - first_batch = None - while not stop_iteration: - # On process 0, we gather the batch to dispatch. - if state.process_index == 0: - try: - if self.split_batches: - # One batch of the main iterator is dispatched and split. - batch = next(main_iterator) - else: - # num_processes batches of the main iterator are concatenated then dispatched and split. - # We add the batches one by one so we have the remainder available when drop_last=False. - batches = [] - for _ in range(state.num_processes): - batches.append(next(main_iterator)) - batch = concatenate(batches, dim=0) - # In both cases, we need to get the structure of the batch that we will broadcast on other - # processes to initialize the tensors with the right shape. - # data_structure, stop_iteration + def _fetch_batches(self, iterator): + batches, batch = None, None + # On process 0, we gather the batch to dispatch. + if self.state.process_index == 0: + try: + if self.split_batches: + # One batch of the main iterator is dispatched and split. + batch = next(iterator) + else: + # num_processes batches of the main iterator are concatenated then dispatched and split. + # We add the batches one by one so we have the remainder available when drop_last=False. + batches = [] + for _ in range(self.state.num_processes): + batches.append(next(iterator)) + batch = concatenate(batches, dim=0) + # In both cases, we need to get the structure of the batch that we will broadcast on other + # processes to initialize the tensors with the right shape. 
+ # data_structure, stop_iteration + batch_info = [get_data_structure(batch), False] + except StopIteration: + batch_info = [None, True] + else: + batch_info = [None, self._stop_iteration] + # This is inplace, so after this instruction, every process has the same `batch_info` as process 0. + broadcast_object_list(batch_info) + self._stop_iteration = batch_info[1] + if self._stop_iteration: + # If drop_last is False and split_batches is False, we may have a remainder to take care of. + if not self.split_batches and not self.drop_last: + if self.state.process_index == 0 and len(batches) > 0: + batch = concatenate(batches, dim=0) batch_info = [get_data_structure(batch), False] - except StopIteration: + else: batch_info = [None, True] + broadcast_object_list(batch_info) + if batch_info[1]: + return batch, batch_info, True else: - batch_info = [None, stop_iteration] - - # This is inplace, so after this instruction, every process has the same `batch_info` as process 0. - broadcast_object_list(batch_info) - stop_iteration = batch_info[1] - if stop_iteration: - # If drop_last is False and split_batches is False, we may have a remainder to take care of. - if not self.split_batches and not self.drop_last: - if state.process_index == 0 and len(batches) > 0: - batch = concatenate(batches, dim=0) - batch_info = [get_data_structure(batch), False] - else: - batch_info = [None, True] - broadcast_object_list(batch_info) - if batch_info[1]: - continue - else: - continue + return batch, batch_info, True + return batch, batch_info, False - if state.process_index != 0: + def __iter__(self): + self.gradient_state._set_end_of_dataloader(False) + main_iterator = None + if self.state.process_index == 0: + # We only iterate through the DataLoader on process 0. + main_iterator = super().__iter__() + self._stop_iteration = False + first_batch = None + batch, batch_info, skip = self._fetch_batches(main_iterator) + while True: + if skip: + continue + if self.state.process_index != 0: # Initialize tensors on other processes than process 0. batch = initialize_tensors(batch_info[0]) - batch = send_to_device(batch, state.device) + batch = send_to_device(batch, self.state.device) # Broadcast the batch before splitting it. batch = broadcast(batch, from_process=0) if not self.drop_last and first_batch is None: # We keep at least num processes elements of the first batch to be able to complete the last batch - first_batch = slice_tensors(batch, slice(0, state.num_processes)) + first_batch = slice_tensors(batch, slice(0, self.state.num_processes)) observed_batch_size = find_batch_size(batch) - batch_size = observed_batch_size // state.num_processes + batch_size = observed_batch_size // self.state.num_processes - if not self.drop_last and stop_iteration and observed_batch_size % state.num_processes != 0: + if not self.drop_last and self._stop_iteration and observed_batch_size % self.state.num_processes != 0: # If the last batch is not complete, let's add the first batch to it. 
batch = concatenate([batch, first_batch], dim=0) batch_size += 1 - data_slice = slice(state.process_index * batch_size, (state.process_index + 1) * batch_size) - yield slice_tensors(batch, data_slice) + data_slice = slice(self.state.process_index * batch_size, (self.state.process_index + 1) * batch_size) + next_batch, next_batch_info, next_skip = self._fetch_batches(main_iterator) + if not self._stop_iteration: + yield slice_tensors(batch, data_slice) + batch, batch_info, skip = next_batch, next_batch_info, next_skip + else: + self.gradient_state._set_end_of_dataloader(True) + yield slice_tensors(batch, data_slice) + break def __len__(self): whole_length = super().__len__() diff --git a/src/accelerate/test_utils/scripts/test_sync.py b/src/accelerate/test_utils/scripts/test_sync.py index 0b0dc7abdd1..4754e484fb4 100644 --- a/src/accelerate/test_utils/scripts/test_sync.py +++ b/src/accelerate/test_utils/scripts/test_sync.py @@ -25,7 +25,7 @@ from accelerate.utils import DistributedType, set_seed -def check_model_parameters(model_a, model_b, did_step): +def check_model_parameters(model_a, model_b, did_step, iteration): for param, grad_param in zip(model_a.parameters(), model_b.parameters()): if not param.requires_grad: continue @@ -33,12 +33,12 @@ def check_model_parameters(model_a, model_b, did_step): # Grads should not be in sync assert ( torch.allclose(param.grad, grad_param.grad) is False - ), f"Gradients in sync when they should not be:\nmodel_a grad ({param.grad}) == model_b grad ({grad_param.grad})" + ), f"Gradients in sync when they should not be at iteration {iteration}:\nmodel_a grad ({param.grad}) == model_b grad ({grad_param.grad})" else: # Grads should be in sync assert ( torch.allclose(param.grad, grad_param.grad) is True - ), f"Gradients not in sync when they should be:\nmodel_a grad ({param.grad}) != model_b grad ({grad_param.grad})" + ), f"Gradients not in sync when they should be at iteration {iteration}:\nmodel_a grad ({param.grad}) != model_b grad ({grad_param.grad})" def step_model(model, input, target, accelerator, do_backward=True): @@ -96,7 +96,7 @@ def test_noop_sync(accelerator): step_model(ddp_model, ddp_input, ddp_target, accelerator) # Since `no_sync` is a noop, `ddp_model` and `model` grads should always be in sync - check_model_parameters(model, ddp_model, True) + check_model_parameters(model, ddp_model, True, iteration) for param, ddp_param in zip(model.parameters(), ddp_model.parameters()): if not param.requires_grad: continue @@ -106,7 +106,7 @@ def test_noop_sync(accelerator): # Shuffle ddp_input on each iteration torch.manual_seed(1337 + iteration) - ddp_input = ddp_input[torch.randperm(16)] + ddp_input = ddp_input[torch.randperm(len(ddp_input))] def test_distributed_sync(accelerator): @@ -146,11 +146,13 @@ def test_distributed_sync(accelerator): # Shuffle ddp_input on each iteration torch.manual_seed(1337 + iteration) - ddp_input = ddp_input[torch.randperm(16)] + ddp_input = ddp_input[torch.randperm(len(ddp_input))] -def test_gradient_accumulation(): - accelerator = Accelerator(gradient_accumulation_steps=2) +def test_gradient_accumulation(split_batches=False, dispatch_batches=False): + accelerator = Accelerator( + gradient_accumulation_steps=2, split_batches=split_batches, dispatch_batches=dispatch_batches + ) # Test that context manager behaves properly model, ddp_model, dataloader = get_training_setup(accelerator) for iteration, batch in enumerate(dataloader): @@ -181,11 +183,13 @@ def test_gradient_accumulation(): # Shuffle ddp_input on each iteration 
torch.manual_seed(1337 + iteration) - ddp_input = ddp_input[torch.randperm(16)] + ddp_input = ddp_input[torch.randperm(len(ddp_input))] -def test_gradient_accumulation_with_opt_and_scheduler(): - accelerator = Accelerator(gradient_accumulation_steps=2) +def test_gradient_accumulation_with_opt_and_scheduler(split_batches=False, dispatch_batches=False): + accelerator = Accelerator( + gradient_accumulation_steps=2, split_batches=split_batches, dispatch_batches=dispatch_batches + ) # Test that context manager behaves properly model, opt, sched, dataloader, ddp_model, ddp_opt, ddp_sched = get_training_setup(accelerator, True) for iteration, batch in enumerate(dataloader): @@ -198,8 +202,11 @@ def test_gradient_accumulation_with_opt_and_scheduler(): ddp_model.train() step_model(model, input, target, accelerator, False) opt.step() - for _ in range(accelerator.num_processes): + if split_batches: sched.step() + else: + for _ in range(accelerator.num_processes): + sched.step() opt.zero_grad() # Perform gradient accumulation under wrapper with accelerator.accumulate(ddp_model): @@ -209,10 +216,12 @@ def test_gradient_accumulation_with_opt_and_scheduler(): ddp_opt.zero_grad() # Learning rates should be the same - assert opt.param_groups[0]["lr"] == ddp_opt.param_groups[0]["lr"] - did_step = (((iteration + 1) % 2) == 0) or (iteration == (len(dataloader) - 1)) + assert ( + opt.param_groups[0]["lr"] == ddp_opt.param_groups[0]["lr"] + ), f'Learning rates found in each optimizer did not align\nopt: {opt.param_groups[0]["lr"]}\nDDP opt: {ddp_opt.param_groups[0]["lr"]}\n' + did_step = (((iteration + 1) % 2) == 0) or ((iteration + 1) == len(dataloader)) if accelerator.num_processes > 1: - check_model_parameters(model, ddp_model, did_step) + check_model_parameters(model, ddp_model, did_step, iteration) # Shuffle ddp_input on each iteration torch.manual_seed(1337 + iteration) @@ -229,12 +238,28 @@ def main(): print("**Test Distributed `no_sync` context manager**") test_distributed_sync(accelerator) if state.distributed_type == DistributedType.MULTI_GPU: - if state.local_process_index == 0: - print("**Test `accumulate` gradient accumulation**") - test_gradient_accumulation() + for split_batch in [True, False]: + for dispatch_batches in [True, False]: + if state.local_process_index == 0: + print( + f"**Test `accumulate` gradient accumulation, `split_batches={split_batch}` and `dispatch_batches={dispatch_batches}`**" + ) + test_gradient_accumulation(split_batch) if state.local_process_index == 0: - print("**Test `accumulate` gradient accumulation with optimizer and scheduler**") + print( + "**Test `accumulate` gradient accumulation with optimizer and scheduler, `split_batches=False`, `dispatch_batches=False`**" + ) test_gradient_accumulation_with_opt_and_scheduler() + if state.distributed_type == DistributedType.MULTI_GPU: + for split_batch in [True, False]: + for dispatch_batches in [True, False]: + if not split_batch and not dispatch_batches: + continue + if state.local_process_index == 0: + print( + f"**Test `accumulate` gradient accumulation with optimizer and scheduler, `split_batches={split_batch}` and `dispatch_batches={dispatch_batches}`**" + ) + test_gradient_accumulation_with_opt_and_scheduler(split_batch, dispatch_batches) def _mp_fn(index): From eb7e33e19c360534d8f741df68e82d121ce33426 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Tue, 12 Jul 2022 17:09:34 +0200 Subject: [PATCH 2/2] Split print over multiple lines --- src/accelerate/test_utils/scripts/test_sync.py | 9 ++++++--- 1 file changed, 6 
insertions(+), 3 deletions(-) diff --git a/src/accelerate/test_utils/scripts/test_sync.py b/src/accelerate/test_utils/scripts/test_sync.py index 4754e484fb4..ae5a2c65b94 100644 --- a/src/accelerate/test_utils/scripts/test_sync.py +++ b/src/accelerate/test_utils/scripts/test_sync.py @@ -242,12 +242,14 @@ def main(): for dispatch_batches in [True, False]: if state.local_process_index == 0: print( - f"**Test `accumulate` gradient accumulation, `split_batches={split_batch}` and `dispatch_batches={dispatch_batches}`**" + "**Test `accumulate` gradient accumulation, ", + f"`split_batches={split_batch}` and `dispatch_batches={dispatch_batches}`**", ) test_gradient_accumulation(split_batch) if state.local_process_index == 0: print( - "**Test `accumulate` gradient accumulation with optimizer and scheduler, `split_batches=False`, `dispatch_batches=False`**" + "**Test `accumulate` gradient accumulation with optimizer and scheduler, ", + "`split_batches=False`, `dispatch_batches=False`**", ) test_gradient_accumulation_with_opt_and_scheduler() if state.distributed_type == DistributedType.MULTI_GPU: @@ -257,7 +259,8 @@ def main(): continue if state.local_process_index == 0: print( - f"**Test `accumulate` gradient accumulation with optimizer and scheduler, `split_batches={split_batch}` and `dispatch_batches={dispatch_batches}`**" + "**Test `accumulate` gradient accumulation with optimizer and scheduler, ", + f"`split_batches={split_batch}` and `dispatch_batches={dispatch_batches}`**", ) test_gradient_accumulation_with_opt_and_scheduler(split_batch, dispatch_batches)
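
The heart of patch 1/2 is a look-ahead rewrite of `DataLoaderDispatcher.__iter__`: the new `_fetch_batches` helper pulls the next batch before the current one is yielded, so the loop can call `gradient_state._set_end_of_dataloader(True)` ahead of the final yield and `_do_sync` can force a gradient sync on the last batch, which is what allows the `dispatch_batches` guard in `accelerator.py` to be dropped. The sketch below is a generic illustration of that look-ahead pattern, not the library code; `lookahead` is a made-up helper name.

# Generic look-ahead sketch (illustrative only): fetch one item ahead so the
# loop knows it is on the final element *before* yielding it. In the patch,
# this role is played by `_fetch_batches` and `self._stop_iteration`.
def lookahead(iterable):
    iterator = iter(iterable)
    try:
        current = next(iterator)
    except StopIteration:
        return
    for upcoming in iterator:
        yield current, False  # more items are coming
        current = upcoming
    yield current, True  # this is the last item, flag it before yielding


for batch, is_last in lookahead(["batch0", "batch1", "batch2"]):
    print(batch, "<- last batch, gradient sync forced here" if is_last else "")

For completeness, here is a minimal end-to-end sketch of the combination the new tests exercise, assuming the accelerate version this PR targets (where `dispatch_batches` is still an `Accelerator.__init__` argument); the model, dataset, and hyperparameters are placeholders rather than anything taken from the patch.

import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator

# Gradient accumulation together with a dispatched dataloader, the combination
# patch 1/2 enables. `dispatch_batches=True` makes process 0 iterate the
# DataLoader and broadcast slices of each batch to the other processes.
accelerator = Accelerator(gradient_accumulation_steps=2, dispatch_batches=True)

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()
dataset = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
dataloader = DataLoader(dataset, batch_size=8)

model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

model.train()
for inputs, targets in dataloader:
    # As in test_sync.py above: step/zero_grad run every iteration, but under
    # `accumulate` the actual update and gradient sync only happen every
    # `gradient_accumulation_steps` batches and on the final batch the
    # dispatcher flags via end_of_dataloader.
    with accelerator.accumulate(model):
        loss = criterion(model(inputs), targets)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()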