Make gradient accumulation work with dispatched dataloaders #510

Merged (2 commits) on Jul 12, 2022
6 changes: 1 addition & 5 deletions src/accelerate/accelerator.py
@@ -240,10 +240,6 @@ def __init__(
raise NotImplementedError(
"Gradient accumulation on TPU is not supported. Pass in `gradient_accumulation_steps=1`"
)
if dispatch_batches:
raise NotImplementedError(
"Gradient accumulation with dispatched dataloaders is not supported. Pass in `gradient_accumulation_steps=1` or `dispatch_batches=False`"
)

self.gradient_accumulation_steps = gradient_accumulation_steps
self.device_placement = device_placement
@@ -397,7 +393,7 @@ def _do_sync(self):
self.gradient_state._set_sync_gradients(True)
else:
self.step += 1
self.gradient_state._set_sync_gradients((self.step % self.gradient_accumulation_steps) == 0)
self.gradient_state._set_sync_gradients((self.step % self.gradient_accumulation_steps) == 0)

@property
def sync_gradients(self):
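With the `dispatch_batches` guard removed above, `gradient_accumulation_steps > 1` can now be combined with dispatched dataloaders. A minimal usage sketch, assuming a toy model and dataset that are purely illustrative and not part of this PR:

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

# Toy data and model, for illustration only.
dataset = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
dataloader = DataLoader(dataset, batch_size=8)
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# After this change, dispatch_batches=True no longer raises NotImplementedError
# when gradient_accumulation_steps > 1.
accelerator = Accelerator(gradient_accumulation_steps=2, dispatch_batches=True)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for inputs, targets in dataloader:
    # Gradients are synchronized only every 2 steps (and on the last batch).
    with accelerator.accumulate(model):
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

The `with accelerator.accumulate(...)` block follows the same pattern exercised in `test_sync.py` below.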
113 changes: 63 additions & 50 deletions src/accelerate/data_loader.py
@@ -392,74 +392,87 @@ def __init__(self, dataset, split_batches: bool = False, **kwargs):
self.gradient_state = GradientState()
self.state = AcceleratorState()

def __iter__(self):
state = AcceleratorState()
if state.process_index == 0:
# We only iterate through the DataLoader on process 0.
main_iterator = super().__iter__()
stop_iteration = False
first_batch = None
while not stop_iteration:
# On process 0, we gather the batch to dispatch.
if state.process_index == 0:
try:
if self.split_batches:
# One batch of the main iterator is dispatched and split.
batch = next(main_iterator)
else:
# num_processes batches of the main iterator are concatenated then dispatched and split.
# We add the batches one by one so we have the remainder available when drop_last=False.
batches = []
for _ in range(state.num_processes):
batches.append(next(main_iterator))
batch = concatenate(batches, dim=0)
# In both cases, we need to get the structure of the batch that we will broadcast on other
# processes to initialize the tensors with the right shape.
# data_structure, stop_iteration
def _fetch_batches(self, iterator):
batches, batch = None, None
# On process 0, we gather the batch to dispatch.
if self.state.process_index == 0:
try:
if self.split_batches:
# One batch of the main iterator is dispatched and split.
batch = next(iterator)
else:
# num_processes batches of the main iterator are concatenated then dispatched and split.
# We add the batches one by one so we have the remainder available when drop_last=False.
batches = []
for _ in range(self.state.num_processes):
batches.append(next(iterator))
batch = concatenate(batches, dim=0)
# In both cases, we need to get the structure of the batch that we will broadcast on other
# processes to initialize the tensors with the right shape.
# data_structure, stop_iteration
batch_info = [get_data_structure(batch), False]
except StopIteration:
batch_info = [None, True]
else:
batch_info = [None, self._stop_iteration]
# This is inplace, so after this instruction, every process has the same `batch_info` as process 0.
broadcast_object_list(batch_info)
self._stop_iteration = batch_info[1]
if self._stop_iteration:
# If drop_last is False and split_batches is False, we may have a remainder to take care of.
if not self.split_batches and not self.drop_last:
if self.state.process_index == 0 and len(batches) > 0:
batch = concatenate(batches, dim=0)
batch_info = [get_data_structure(batch), False]
except StopIteration:
else:
batch_info = [None, True]
broadcast_object_list(batch_info)
if batch_info[1]:
return batch, batch_info, True
else:
batch_info = [None, stop_iteration]

# This is inplace, so after this instruction, every process has the same `batch_info` as process 0.
broadcast_object_list(batch_info)
stop_iteration = batch_info[1]
if stop_iteration:
# If drop_last is False and split_batches is False, we may have a remainder to take care of.
if not self.split_batches and not self.drop_last:
if state.process_index == 0 and len(batches) > 0:
batch = concatenate(batches, dim=0)
batch_info = [get_data_structure(batch), False]
else:
batch_info = [None, True]
broadcast_object_list(batch_info)
if batch_info[1]:
continue
else:
continue
return batch, batch_info, True
return batch, batch_info, False

if state.process_index != 0:
def __iter__(self):
self.gradient_state._set_end_of_dataloader(False)
main_iterator = None
if self.state.process_index == 0:
# We only iterate through the DataLoader on process 0.
main_iterator = super().__iter__()
self._stop_iteration = False
first_batch = None
batch, batch_info, skip = self._fetch_batches(main_iterator)
while True:
if skip:
continue
if self.state.process_index != 0:
# Initialize tensors on other processes than process 0.
batch = initialize_tensors(batch_info[0])
batch = send_to_device(batch, state.device)
batch = send_to_device(batch, self.state.device)
# Broadcast the batch before splitting it.
batch = broadcast(batch, from_process=0)

if not self.drop_last and first_batch is None:
# We keep at least num processes elements of the first batch to be able to complete the last batch
first_batch = slice_tensors(batch, slice(0, state.num_processes))
first_batch = slice_tensors(batch, slice(0, self.state.num_processes))

observed_batch_size = find_batch_size(batch)
batch_size = observed_batch_size // state.num_processes
batch_size = observed_batch_size // self.state.num_processes

if not self.drop_last and stop_iteration and observed_batch_size % state.num_processes != 0:
if not self.drop_last and self._stop_iteration and observed_batch_size % self.state.num_processes != 0:
# If the last batch is not complete, let's add the first batch to it.
batch = concatenate([batch, first_batch], dim=0)
batch_size += 1

data_slice = slice(state.process_index * batch_size, (state.process_index + 1) * batch_size)
yield slice_tensors(batch, data_slice)
data_slice = slice(self.state.process_index * batch_size, (self.state.process_index + 1) * batch_size)
next_batch, next_batch_info, next_skip = self._fetch_batches(main_iterator)
if not self._stop_iteration:
yield slice_tensors(batch, data_slice)
batch, batch_info, skip = next_batch, next_batch_info, next_skip
else:
self.gradient_state._set_end_of_dataloader(True)
yield slice_tensors(batch, data_slice)
break

def __len__(self):
whole_length = super().__len__()
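The reworked `__iter__` above prefetches the next batch through `_fetch_batches` before yielding the current one, so `end_of_dataloader` can be set to `True` on the very last yield and `_do_sync` can force a gradient sync there. A simplified, self-contained sketch of that look-ahead pattern (the names below are stand-ins, not the actual class):

def lookahead(iterable):
    # Yield (item, is_last) pairs by fetching one element ahead, mirroring how
    # the dispatcher's __iter__ calls _fetch_batches before yielding.
    iterator = iter(iterable)
    current = next(iterator)
    while True:
        try:
            upcoming = next(iterator)
        except StopIteration:
            # Nothing follows: flag the end together with the final yield, which
            # is when gradient_state._set_end_of_dataloader(True) is called.
            yield current, True
            return
        yield current, False
        current = upcoming

print(list(lookahead(range(3))))  # [(0, False), (1, False), (2, True)]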
66 changes: 47 additions & 19 deletions src/accelerate/test_utils/scripts/test_sync.py
@@ -25,20 +25,20 @@
from accelerate.utils import DistributedType, set_seed


def check_model_parameters(model_a, model_b, did_step):
def check_model_parameters(model_a, model_b, did_step, iteration):
for param, grad_param in zip(model_a.parameters(), model_b.parameters()):
if not param.requires_grad:
continue
if not did_step:
# Grads should not be in sync
assert (
torch.allclose(param.grad, grad_param.grad) is False
), f"Gradients in sync when they should not be:\nmodel_a grad ({param.grad}) == model_b grad ({grad_param.grad})"
), f"Gradients in sync when they should not be at iteration {iteration}:\nmodel_a grad ({param.grad}) == model_b grad ({grad_param.grad})"
else:
# Grads should be in sync
assert (
torch.allclose(param.grad, grad_param.grad) is True
), f"Gradients not in sync when they should be:\nmodel_a grad ({param.grad}) != model_b grad ({grad_param.grad})"
), f"Gradients not in sync when they should be at iteration {iteration}:\nmodel_a grad ({param.grad}) != model_b grad ({grad_param.grad})"


def step_model(model, input, target, accelerator, do_backward=True):
@@ -96,7 +96,7 @@ def test_noop_sync(accelerator):
step_model(ddp_model, ddp_input, ddp_target, accelerator)

# Since `no_sync` is a noop, `ddp_model` and `model` grads should always be in sync
check_model_parameters(model, ddp_model, True)
check_model_parameters(model, ddp_model, True, iteration)
for param, ddp_param in zip(model.parameters(), ddp_model.parameters()):
if not param.requires_grad:
continue
@@ -106,7 +106,7 @@ def test_noop_sync(accelerator):

# Shuffle ddp_input on each iteration
torch.manual_seed(1337 + iteration)
ddp_input = ddp_input[torch.randperm(16)]
ddp_input = ddp_input[torch.randperm(len(ddp_input))]


def test_distributed_sync(accelerator):
@@ -146,11 +146,13 @@ def test_distributed_sync(accelerator):

# Shuffle ddp_input on each iteration
torch.manual_seed(1337 + iteration)
ddp_input = ddp_input[torch.randperm(16)]
ddp_input = ddp_input[torch.randperm(len(ddp_input))]


def test_gradient_accumulation():
accelerator = Accelerator(gradient_accumulation_steps=2)
def test_gradient_accumulation(split_batches=False, dispatch_batches=False):
accelerator = Accelerator(
gradient_accumulation_steps=2, split_batches=split_batches, dispatch_batches=dispatch_batches
)
# Test that context manager behaves properly
model, ddp_model, dataloader = get_training_setup(accelerator)
for iteration, batch in enumerate(dataloader):
@@ -181,11 +183,13 @@ def test_gradient_accumulation():

# Shuffle ddp_input on each iteration
torch.manual_seed(1337 + iteration)
ddp_input = ddp_input[torch.randperm(16)]
ddp_input = ddp_input[torch.randperm(len(ddp_input))]


def test_gradient_accumulation_with_opt_and_scheduler():
accelerator = Accelerator(gradient_accumulation_steps=2)
def test_gradient_accumulation_with_opt_and_scheduler(split_batches=False, dispatch_batches=False):
accelerator = Accelerator(
gradient_accumulation_steps=2, split_batches=split_batches, dispatch_batches=dispatch_batches
)
# Test that context manager behaves properly
model, opt, sched, dataloader, ddp_model, ddp_opt, ddp_sched = get_training_setup(accelerator, True)
for iteration, batch in enumerate(dataloader):
@@ -198,8 +202,11 @@ def test_gradient_accumulation_with_opt_and_scheduler():
ddp_model.train()
step_model(model, input, target, accelerator, False)
opt.step()
for _ in range(accelerator.num_processes):
if split_batches:
sched.step()
else:
for _ in range(accelerator.num_processes):
sched.step()
opt.zero_grad()
# Perform gradient accumulation under wrapper
with accelerator.accumulate(ddp_model):
@@ -209,10 +216,12 @@ def test_gradient_accumulation_with_opt_and_scheduler():
ddp_opt.zero_grad()

# Learning rates should be the same
assert opt.param_groups[0]["lr"] == ddp_opt.param_groups[0]["lr"]
did_step = (((iteration + 1) % 2) == 0) or (iteration == (len(dataloader) - 1))
assert (
opt.param_groups[0]["lr"] == ddp_opt.param_groups[0]["lr"]
), f'Learning rates found in each optimizer did not align\nopt: {opt.param_groups[0]["lr"]}\nDDP opt: {ddp_opt.param_groups[0]["lr"]}\n'
did_step = (((iteration + 1) % 2) == 0) or ((iteration + 1) == len(dataloader))
if accelerator.num_processes > 1:
check_model_parameters(model, ddp_model, did_step)
check_model_parameters(model, ddp_model, did_step, iteration)
# Shuffle ddp_input on each iteration
torch.manual_seed(1337 + iteration)

@@ -229,12 +238,31 @@ def main():
print("**Test Distributed `no_sync` context manager**")
test_distributed_sync(accelerator)
if state.distributed_type == DistributedType.MULTI_GPU:
if state.local_process_index == 0:
print("**Test `accumulate` gradient accumulation**")
test_gradient_accumulation()
for split_batch in [True, False]:
for dispatch_batches in [True, False]:
if state.local_process_index == 0:
print(
"**Test `accumulate` gradient accumulation, ",
f"`split_batches={split_batch}` and `dispatch_batches={dispatch_batches}`**",
)
test_gradient_accumulation(split_batch)
if state.local_process_index == 0:
print("**Test `accumulate` gradient accumulation with optimizer and scheduler**")
print(
"**Test `accumulate` gradient accumulation with optimizer and scheduler, ",
"`split_batches=False`, `dispatch_batches=False`**",
)
test_gradient_accumulation_with_opt_and_scheduler()
if state.distributed_type == DistributedType.MULTI_GPU:
for split_batch in [True, False]:
for dispatch_batches in [True, False]:
if not split_batch and not dispatch_batches:
continue
if state.local_process_index == 0:
print(
"**Test `accumulate` gradient accumulation with optimizer and scheduler, ",
f"`split_batches={split_batch}` and `dispatch_batches={dispatch_batches}`**",
)
test_gradient_accumulation_with_opt_and_scheduler(split_batch, dispatch_batches)


def _mp_fn(index):
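The `did_step` expression in `test_gradient_accumulation_with_opt_and_scheduler` above encodes when the tests expect gradients to be synchronized: every `gradient_accumulation_steps` batches, plus the final batch of the dataloader. A small illustrative helper (hypothetical, not part of the test script):

def expects_sync(iteration, num_batches, accumulation_steps=2):
    # Mirrors did_step = (((iteration + 1) % 2) == 0) or ((iteration + 1) == len(dataloader))
    return ((iteration + 1) % accumulation_steps == 0) or ((iteration + 1) == num_batches)

# For a 3-batch dataloader with accumulation_steps=2, a sync is expected on the
# second batch (step boundary) and on the third batch (end of dataloader).
assert [expects_sync(i, 3) for i in range(3)] == [False, True, True]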