Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: saving model weights #556

Merged
merged 2 commits into from
Jul 26, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,16 +446,13 @@ def _prepare_one(self, obj, first_pass=False):
if isinstance(obj, torch.utils.data.DataLoader):
return self.prepare_data_loader(obj)
elif isinstance(obj, torch.nn.Module):
self._models.append(obj)
return self.prepare_model(obj)
elif isinstance(obj, torch.optim.Optimizer):
optimizer = self.prepare_optimizer(obj)
self._optimizers.append(optimizer)
return optimizer
# Second pass of preparation: LR scheduler (which need the full list of optimizers)
elif isinstance(obj, torch.optim.lr_scheduler._LRScheduler):
scheduler = self.prepare_scheduler(obj)
self._schedulers.append(scheduler)
return scheduler
# Return the unprocessed object if previous criteria was not met
return obj
Expand Down Expand Up @@ -570,6 +567,7 @@ def prepare(self, *args):
return result if len(result) > 1 else result[0]

def prepare_model(self, model):
self._models.append(model)
if self.device_placement and self.distributed_type != DistributedType.FSDP:
model = model.to(self.device)
if self.distributed_type == DistributedType.MULTI_GPU:
Expand Down Expand Up @@ -782,7 +780,9 @@ def prepare_data_loader(self, data_loader):
)

def prepare_optimizer(self, optimizer):
return AcceleratedOptimizer(optimizer, device_placement=self.device_placement, scaler=self.scaler)
optimizer = AcceleratedOptimizer(optimizer, device_placement=self.device_placement, scaler=self.scaler)
self._optimizers.append(optimizer)
return optimizer

def prepare_scheduler(self, scheduler):
# We try to find the optimizer associated with `scheduler`, the default is the full list.
Expand All @@ -791,13 +791,14 @@ def prepare_scheduler(self, scheduler):
if getattr(scheduler, "optimizer", None) == opt.optimizer:
optimizer = opt
break

return AcceleratedScheduler(
scheduler = AcceleratedScheduler(
scheduler,
optimizer,
step_with_optimizer=self.step_scheduler_with_optimizer,
split_batches=self.split_batches,
)
self._schedulers.append(scheduler)
return scheduler

def backward(self, loss, **kwargs):
"""
Expand Down