From d738fa8b78a74a2d9c2e5f817f2ee0d4295c42d0 Mon Sep 17 00:00:00 2001
From: Dmitry Baranchuk
Date: Thu, 28 Jul 2022 04:16:03 +0300
Subject: [PATCH] Support bfloat16 for autograd (#499)

(cherry picked from commit 28261470e44f2ae4157d08b563b4d2771f3a9549)
---
 hivemind/moe/server/module_backend.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/hivemind/moe/server/module_backend.py b/hivemind/moe/server/module_backend.py
index f6260371a..5688b8ecc 100644
--- a/hivemind/moe/server/module_backend.py
+++ b/hivemind/moe/server/module_backend.py
@@ -118,9 +118,7 @@ def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
 
         with torch.enable_grad():
             args = [
-                tensor.detach().requires_grad_(True)
-                if tensor.dtype in (torch.half, torch.float, torch.double)
-                else tensor.detach()
+                tensor.detach().requires_grad_(True) if tensor.is_floating_point() else tensor.detach()
                 for tensor in args
            ]
             kwargs = {
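
Note (not part of the patch): a minimal sketch of why the change matters. The old dtype whitelist excluded torch.bfloat16, so bfloat16 inputs were detached without requires_grad and produced no gradients in backward; Tensor.is_floating_point() covers bfloat16 along with half, float, and double.

    import torch

    t = torch.randn(2, 2, dtype=torch.bfloat16)

    # Old check: bfloat16 is not in the whitelist, so the tensor would be
    # detached without enabling gradients.
    print(t.dtype in (torch.half, torch.float, torch.double))  # False

    # New check: is_floating_point() is True for bfloat16 as well,
    # so the detached copy correctly gets requires_grad=True.
    print(t.is_floating_point())  # True
    print(t.detach().requires_grad_(True).requires_grad)  # True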