From 0989eca746ada5a2439010ffb60b17efdc378270 Mon Sep 17 00:00:00 2001
From: Alex Xiao
Date: Fri, 28 Aug 2020 17:49:48 -0700
Subject: [PATCH] add more detailed logging for fp16 diverging

Summary:
We often get a generic "minimum loss scale reached" when fp16 training
diverges. It would be useful to have a breakdown of where exactly the
gradient norm becomes too big.

Reviewed By: myleott

Differential Revision: D23297774

fbshipit-source-id: 69da1cca1be22f15af633f8efe4e7b491cf4f6f9
---
 fairseq/nan_detector.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/fairseq/nan_detector.py b/fairseq/nan_detector.py
index 789169d2b0..df4e28ec89 100644
--- a/fairseq/nan_detector.py
+++ b/fairseq/nan_detector.py
@@ -19,6 +19,7 @@ def __init__(self, model, forward=True, backward=True):
         self.fhooks = []
         self.forward = forward
         self.backward = backward
+        self.model = model
         self.reset()
 
         for name, mod in model.named_modules():
@@ -29,6 +30,19 @@ def __enter__(self):
         return self
 
     def __exit__(self, exc_type, exc_value, exc_traceback):
+        # Dump out all model gnorms to enable better debugging
+        norm = {}
+        gradients = {}
+        for name, param in self.model.named_parameters():
+            grad_norm = torch.norm(param.grad.data, p=2, dtype=torch.float32)
+            norm[name] = grad_norm.item()
+            if torch.isnan(grad_norm).any() or torch.isinf(grad_norm).any():
+                gradients[name] = param.grad.data
+        if len(gradients) > 0:
+            logger.info("Detected nan/inf grad norm, dumping norms...")
+            logger.info(f"norms: {norm}")
+            logger.info(f"gradients: {gradients}")
+
        self.close()
 
     def add_hooks(self, module):
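
Usage note (not part of the patch): a minimal sketch of how the new dump is
exercised, assuming the class defined in fairseq/nan_detector.py is NanDetector
with the __init__ signature shown in the hunk above; the toy model, batch, and
forced inf below are placeholders for illustration only.

    # Sketch: wrap forward/backward in NanDetector so that, on leaving the
    # block, __exit__ walks model.named_parameters(), computes an fp32 L2 norm
    # per gradient, and logs all norms plus the offending gradients if any
    # norm is nan/inf.
    import logging

    import torch

    from fairseq.nan_detector import NanDetector

    logging.basicConfig(level=logging.INFO)  # make logger.info output visible

    model = torch.nn.Linear(8, 2)  # placeholder model
    x = torch.randn(4, 8)          # placeholder batch

    with NanDetector(model):
        # Force an inf into the loss so the gradients become inf/nan and the
        # new "Detected nan/inf grad norm, dumping norms..." path triggers.
        loss = model(x).sum() * float("inf")
        loss.backward()

Any exception raised inside the block still propagates, since __exit__ only
logs, calls self.close(), and returns None rather than True.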