From 3215a4c043335bb127db776edd710d5ae160fc49 Mon Sep 17 00:00:00 2001
From: Jocelyn Huang
Date: Thu, 14 May 2020 10:13:55 -0700
Subject: [PATCH 1/3] Better NaN/inf loss handling for O0 (skip step across workers)

Signed-off-by: Jocelyn Huang
---
 nemo/backends/pytorch/actions.py | 59 ++++++++------------------------
 nemo/core/neural_factory.py      |  5 +--
 2 files changed, 16 insertions(+), 48 deletions(-)

diff --git a/nemo/backends/pytorch/actions.py b/nemo/backends/pytorch/actions.py
index 172b2131990c..be6e8436ae29 100644
--- a/nemo/backends/pytorch/actions.py
+++ b/nemo/backends/pytorch/actions.py
@@ -1078,22 +1078,6 @@ def deployment_export(module, output: str, d_format: DeploymentFormat, input_exa
             output_example=output_example,
         )
 
-    def _check_nan_or_inf(self, placement_gpu, nan_or_inf, steps_per_nan_check=None):
-        # Note that nan_or_inf only gets set if stop_on_nan loss is True, or if using O0/not using apex.amp.
-        if not placement_gpu:
-            return
-        if steps_per_nan_check is None or self.step % steps_per_nan_check == 0:
-            world_size = dist.get_world_size()
-            # We use dtype=int because nccl backend doesn't support torch.bool
-            nan_inf_tensor = torch.tensor(nan_or_inf, dtype=int).cuda()
-            nan_inf_results = []
-            for _ in range(world_size):
-                nan_inf_results.append(torch.empty_like(nan_inf_tensor))
-            dist.all_gather(nan_inf_results, nan_inf_tensor)
-            for nan_inf in nan_inf_results:
-                if nan_inf:
-                    raise ValueError('Terminating due to previous NaN or inf.')
-
     def train(
         self,
         tensors_to_optimize=None,
@@ -1104,7 +1088,6 @@ def train(
         lr_policy=None,
         batches_per_step=None,
         stop_on_nan_loss=False,
-        steps_per_nan_check=100,
         synced_batchnorm=False,
         synced_batchnorm_groupsize=0,
         gradient_predivide=False,
@@ -1353,8 +1336,6 @@ def train(
         # Do action start callbacks
         self._perform_on_action_start(callbacks=callbacks)
 
-        nan_or_inf = False
-
         # MAIN TRAINING LOOP
         # iteration over epochs
         while num_epochs is None or self.epoch_num < num_epochs:
@@ -1418,26 +1399,22 @@ def train(
                 curr_tensors_to_optimize = training_loop[self.step % len(training_loop)][1]
                 final_loss = 0
                 for tensor in curr_tensors_to_optimize:
-                    if (
-                        torch.isnan(registered_tensors[tensor.unique_name]).any()
-                        or torch.isinf(registered_tensors[tensor.unique_name]).any()
-                    ):
-                        if (
-                            (stop_on_nan_loss)
-                            or (self._optim_level not in AmpOptimizations)
-                            or (self._optim_level == Optimization.mxprO0)
-                        ):
-                            # Set flag here and terminate at next all_gather check.
-                            nan_or_inf = True
-                            logging.warning(
-                                'Loss is NaN or inf at step %d, will terminate within the'
-                                ' next steps_per_nan_check steps',
-                                self.step,
-                            )
-                        else:
-                            logging.warning('Loss is NaN or inf, continuing training')
                     final_loss += registered_tensors[tensor.unique_name]
 
+                # Check for NaN/inf loss (across workers if applicable)
+                loss_nan_inf_checker = final_loss.clone()
+                if placement_gpu:
+                    dist.all_reduce(loss_nan_inf_checker)
+                if torch.isnan(loss_nan_inf_checker).any() or torch.isinf(loss_nan_inf_checker).any():
+                    if stop_on_nan_loss:
+                        raise ValueError('Loss is NaN or inf - exiting')
+                    if self._optim_level in AmpOptimizations and self._optim_level != Optimization.mxprO0:
+                        logging.warning('Loss is NaN or inf.')
+                    else:
+                        # Skip this step across workers if loss is NaN/inf and using fp32
+                        logging.warning('Loss is NaN or inf. Skipping update.')
+                        continue
+
                 if self._optim_level in AmpOptimizations and self._optim_level != Optimization.mxprO0:
                     with amp.scale_loss(final_loss, curr_optimizer, delay_unscale=disable_allreduce) as scaled_loss:
                         if disable_allreduce:
@@ -1460,15 +1437,12 @@ def train(
                         final_loss.backward(bps_scale.to(final_loss.get_device()))
                 # single device (CPU or GPU)
                 else:
-                    # Fix (workaround?) enabling to backpropagate gradiens on CPUs.
+                    # Fix (workaround?) enabling to backpropagate gradients on CPUs.
                     if final_loss.get_device() < 0:
                         final_loss.backward(bps_scale)
                     else:
                         final_loss.backward(bps_scale.to(final_loss.get_device()))
 
-                # Check if we should terminate due to NaN/inf on any workers.
-                self._check_nan_or_inf(placement_gpu, nan_or_inf, steps_per_nan_check=steps_per_nan_check)
-
                 batch_counter += 1
 
                 if batch_counter == batches_per_step:
@@ -1488,9 +1462,6 @@ def train(
             self._perform_on_epoch_end(callbacks=callbacks)
             self.epoch_num += 1
 
-        # Check again if we should stop on NaN/inf
-        self._check_nan_or_inf(placement_gpu, nan_or_inf)
-
         self._perform_on_action_end(callbacks=callbacks)
 
     def infer(
diff --git a/nemo/core/neural_factory.py b/nemo/core/neural_factory.py
index 4402ded7b927..b5c1930d0e06 100644
--- a/nemo/core/neural_factory.py
+++ b/nemo/core/neural_factory.py
@@ -137,8 +137,7 @@ def train(
                 batch_size
             stop_on_nan_loss: (default: False) If set to True, the training
                 will stop if loss=nan or inf. If set to False, the training
-                will continue. Note that if apex.amp is not used, or if
-                optimization level is O0, training will stop regardless.
+                will continue.
 
         Returns:
             None
@@ -573,7 +572,6 @@ def train(
         lr_policy=None,
         batches_per_step=None,
         stop_on_nan_loss=False,
-        steps_per_nan_check=100,
         synced_batchnorm=False,
         synced_batchnorm_groupsize=0,
         gradient_predivide=False,
@@ -591,7 +589,6 @@ def train(
             lr_policy=lr_policy,
             batches_per_step=batches_per_step,
            stop_on_nan_loss=stop_on_nan_loss,
-            steps_per_nan_check=steps_per_nan_check,
            synced_batchnorm=synced_batchnorm,
            synced_batchnorm_groupsize=synced_batchnorm_groupsize,
            gradient_predivide=gradient_predivide,

From ecfbcde94da68847ee03ffbc60150da11d188550 Mon Sep 17 00:00:00 2001
From: Jocelyn Huang
Date: Thu, 14 May 2020 10:24:00 -0700
Subject: [PATCH 2/3] Add entry to changelog

Signed-off-by: Jocelyn Huang
---
 CHANGELOG.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4bb024220a20..93f63a2bc154 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -78,6 +78,7 @@ To release a new version, please update the changelog as followed:
 - Online audio augmentation notebook in ASR examples ([PR #605](https://github.com/NVIDIA/NeMo/pull/605)) - @titu1994
 
 ### Changed
+- Syncs across workers at each step to check for NaN or inf loss. Terminates all workers if stop\_on\_nan\_loss is set (as before), lets Apex deal with it if apex.amp optimization level is O1 or higher, and skips the step across workers otherwise. ([PR #637](https://github.com/NVIDIA/NeMo/pull/637)) - @redoctopus
 
 ### Dependencies Update
 

From 0fe97daa7a7180f2988f58278ca729e7e05a3627 Mon Sep 17 00:00:00 2001
From: Jocelyn Huang
Date: Fri, 15 May 2020 11:57:12 -0700
Subject: [PATCH 3/3] Change NaN/inf all_reduce check to use MAX instead of default SUM

Signed-off-by: Jocelyn Huang
---
 nemo/backends/pytorch/actions.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/nemo/backends/pytorch/actions.py b/nemo/backends/pytorch/actions.py
index be6e8436ae29..d22fa8bc2c13 100644
--- a/nemo/backends/pytorch/actions.py
+++ b/nemo/backends/pytorch/actions.py
@@ -1404,7 +1404,7 @@ def train(
                 # Check for NaN/inf loss (across workers if applicable)
                 loss_nan_inf_checker = final_loss.clone()
                 if placement_gpu:
-                    dist.all_reduce(loss_nan_inf_checker)
+                    dist.all_reduce(loss_nan_inf_checker, torch.distributed.ReduceOp.MAX)
                 if torch.isnan(loss_nan_inf_checker).any() or torch.isinf(loss_nan_inf_checker).any():
                     if stop_on_nan_loss:
                         raise ValueError('Loss is NaN or inf - exiting')
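Note on the pattern introduced above: the check is a plain all_reduce over a copy of the
loss, so every rank reaches the same skip/continue decision and the gradient collectives
stay in sync. A minimal, self-contained sketch of the same idea is shown below; it is
illustrative only (not NeMo code), the helper name and its arguments are invented here,
and it assumes torch.distributed has already been initialized with a GPU-capable backend.

    import torch
    import torch.distributed as dist

    def loss_is_nan_or_inf(loss: torch.Tensor, use_distributed: bool) -> bool:
        # Work on a copy so the reduction does not disturb the loss used for backward.
        checker = loss.detach().clone()
        if use_distributed:
            # Reduce with MAX (as in PATCH 3) rather than the default SUM; unlike SUM,
            # MAX cannot overflow to inf when every per-rank loss is finite.
            dist.all_reduce(checker, op=dist.ReduceOp.MAX)
        return bool(torch.isnan(checker).any() or torch.isinf(checker).any())

    # Hypothetical use inside a training loop:
    #   if loss_is_nan_or_inf(loss, dist.is_initialized()):
    #       logging.warning('Loss is NaN or inf. Skipping update.')
    #       continue  # all ranks skip together, keeping collectives aligned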