Better NaN/inf loss handling for O0 (skip step across workers) #637

Merged · 3 commits · May 18, 2020
Changes from 2 commits
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -78,6 +78,7 @@ To release a new version, please update the changelog as followed:
- Online audio augmentation notebook in ASR examples ([PR #605](https://github.com/NVIDIA/NeMo/pull/605)) - @titu1994

### Changed
- Syncs across workers at each step to check for NaN or inf loss. Terminates all workers if stop\_on\_nan\_loss is set (as before), lets Apex deal with it if apex.amp optimization level is O1 or higher, and skips the step across workers otherwise. ([PR #637](https://github.com/NVIDIA/NeMo/pull/637)) - @redoctopus

### Dependencies Update

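The changelog entry above summarizes the new behavior: each step, every worker checks an all-reduced copy of the loss, so all ranks agree on whether a NaN/inf occurred and take the same action (terminate, defer to apex.amp, or skip the step). Below is a minimal standalone sketch of that check; it is not code from this PR, and the helper name `loss_is_nan_or_inf` is illustrative.

```python
import torch
import torch.distributed as dist


def loss_is_nan_or_inf(loss: torch.Tensor, on_gpu: bool) -> bool:
    """Return the same True/False on every rank: True if any rank's loss is NaN/inf."""
    checker = loss.detach().clone()
    if on_gpu and dist.is_available() and dist.is_initialized():
        # Summing the loss across ranks propagates NaN/inf to every rank,
        # so the branch taken afterwards is identical everywhere.
        dist.all_reduce(checker)
    return bool(torch.isnan(checker).any() or torch.isinf(checker).any())
```

Because a sum involving inf or NaN is itself non-finite, a single bad rank is enough to flip the check on all ranks.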
59 changes: 15 additions & 44 deletions nemo/backends/pytorch/actions.py
@@ -1078,22 +1078,6 @@ def deployment_export(module, output: str, d_format: DeploymentFormat, input_exa
output_example=output_example,
)

def _check_nan_or_inf(self, placement_gpu, nan_or_inf, steps_per_nan_check=None):
# Note that nan_or_inf only gets set if stop_on_nan loss is True, or if using O0/not using apex.amp.
if not placement_gpu:
return
if steps_per_nan_check is None or self.step % steps_per_nan_check == 0:
world_size = dist.get_world_size()
# We use dtype=int because nccl backend doesn't support torch.bool
nan_inf_tensor = torch.tensor(nan_or_inf, dtype=int).cuda()
nan_inf_results = []
for _ in range(world_size):
nan_inf_results.append(torch.empty_like(nan_inf_tensor))
dist.all_gather(nan_inf_results, nan_inf_tensor)
for nan_inf in nan_inf_results:
if nan_inf:
raise ValueError('Terminating due to previous NaN or inf.')

def train(
self,
tensors_to_optimize=None,
@@ -1104,7 +1088,6 @@ def train(
lr_policy=None,
batches_per_step=None,
stop_on_nan_loss=False,
steps_per_nan_check=100,
synced_batchnorm=False,
synced_batchnorm_groupsize=0,
gradient_predivide=False,
@@ -1353,8 +1336,6 @@ def train(
# Do action start callbacks
self._perform_on_action_start(callbacks=callbacks)

nan_or_inf = False

# MAIN TRAINING LOOP
# iteration over epochs
while num_epochs is None or self.epoch_num < num_epochs:
@@ -1418,26 +1399,22 @@ def train(
curr_tensors_to_optimize = training_loop[self.step % len(training_loop)][1]
final_loss = 0
for tensor in curr_tensors_to_optimize:
if (
torch.isnan(registered_tensors[tensor.unique_name]).any()
or torch.isinf(registered_tensors[tensor.unique_name]).any()
):
if (
(stop_on_nan_loss)
or (self._optim_level not in AmpOptimizations)
or (self._optim_level == Optimization.mxprO0)
):
# Set flag here and terminate at next all_gather check.
nan_or_inf = True
logging.warning(
'Loss is NaN or inf at step %d, will terminate within the'
' next steps_per_nan_check steps',
self.step,
)
else:
logging.warning('Loss is NaN or inf, continuing training')
final_loss += registered_tensors[tensor.unique_name]

# Check for NaN/inf loss (across workers if applicable)
loss_nan_inf_checker = final_loss.clone()
if placement_gpu:
dist.all_reduce(loss_nan_inf_checker)
if torch.isnan(loss_nan_inf_checker).any() or torch.isinf(loss_nan_inf_checker).any():
if stop_on_nan_loss:
raise ValueError('Loss is NaN or inf - exiting')
if self._optim_level in AmpOptimizations and self._optim_level != Optimization.mxprO0:
logging.warning('Loss is NaN or inf.')
else:
# Skip this step across workers if loss is NaN/inf and using fp32
logging.warning('Loss is NaN or inf. Skipping update.')
continue

if self._optim_level in AmpOptimizations and self._optim_level != Optimization.mxprO0:
with amp.scale_loss(final_loss, curr_optimizer, delay_unscale=disable_allreduce) as scaled_loss:
if disable_allreduce:
@@ -1460,15 +1437,12 @@ def train(
final_loss.backward(bps_scale.to(final_loss.get_device()))
# single device (CPU or GPU)
else:
# Fix (workaround?) enabling to backpropagate gradiens on CPUs.
# Fix (workaround?) enabling to backpropagate gradients on CPUs.
if final_loss.get_device() < 0:
final_loss.backward(bps_scale)
else:
final_loss.backward(bps_scale.to(final_loss.get_device()))

# Check if we should terminate due to NaN/inf on any workers.
self._check_nan_or_inf(placement_gpu, nan_or_inf, steps_per_nan_check=steps_per_nan_check)

batch_counter += 1

if batch_counter == batches_per_step:
@@ -1488,9 +1462,6 @@ def train(
self._perform_on_epoch_end(callbacks=callbacks)
self.epoch_num += 1

# Check again if we should stop on NaN/inf
self._check_nan_or_inf(placement_gpu, nan_or_inf)

self._perform_on_action_end(callbacks=callbacks)

def infer(
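For orientation, here is a hedged sketch of the branch the `actions.py` diff introduces. It is not NeMo code: `loss_is_nan_or_inf` is the helper sketched earlier, and the other names are placeholders. The point is that under pure FP32 (O0 or no apex.amp) every rank reaches the same `continue`, so no rank is left blocking on a gradient all-reduce that another rank skipped.

```python
import logging


def train_steps(batches, model, optimizer, stop_on_nan_loss=False,
                amp_above_O0=False, on_gpu=False):
    for batch in batches:
        optimizer.zero_grad()
        final_loss = model(batch)  # assumed to return a scalar loss tensor
        if loss_is_nan_or_inf(final_loss, on_gpu):
            if stop_on_nan_loss:
                # Previous behavior, still available: terminate everywhere.
                raise ValueError('Loss is NaN or inf - exiting')
            if amp_above_O0:
                # apex.amp's dynamic loss scaler deals with the bad step itself.
                logging.warning('Loss is NaN or inf.')
            else:
                # FP32 / O0: skip backward and the optimizer step on all ranks.
                logging.warning('Loss is NaN or inf. Skipping update.')
                continue
        final_loss.backward()
        optimizer.step()
```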
5 changes: 1 addition & 4 deletions nemo/core/neural_factory.py
@@ -137,8 +137,7 @@ def train(
batch_size
stop_on_nan_loss: (default: False) If set to True, the training
will stop if loss=nan or inf. If set to False, the training
will continue. Note that if apex.amp is not used, or if
optimization level is O0, training will stop regardless.
will continue.

Returns:
None
@@ -573,7 +572,6 @@ def train(
lr_policy=None,
batches_per_step=None,
stop_on_nan_loss=False,
steps_per_nan_check=100,
synced_batchnorm=False,
synced_batchnorm_groupsize=0,
gradient_predivide=False,
@@ -591,7 +589,6 @@ def train(
lr_policy=lr_policy,
batches_per_step=batches_per_step,
stop_on_nan_loss=stop_on_nan_loss,
steps_per_nan_check=steps_per_nan_check,
synced_batchnorm=synced_batchnorm,
synced_batchnorm_groupsize=synced_batchnorm_groupsize,
gradient_predivide=gradient_predivide,
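A hedged usage sketch of the updated `train()` call in `nemo/core/neural_factory.py`: `steps_per_nan_check` is removed, and `stop_on_nan_loss` alone decides between terminating and skipping. The `nf`, `train_loss`, and `callbacks` names and the argument values are placeholders, not taken from this PR.

```python
# stop_on_nan_loss=False (the default): on a NaN/inf loss the step is skipped
# across workers under O0 / no apex.amp, or left to apex.amp under O1 and above.
nf.train(
    tensors_to_optimize=[train_loss],
    callbacks=callbacks,
    batches_per_step=1,
    stop_on_nan_loss=False,
    # steps_per_nan_check=100,  # removed by this PR; no longer accepted
)
```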