diff --git a/hivemind/moe/server/module_backend.py b/hivemind/moe/server/module_backend.py index f6260371a..5688b8ecc 100644 --- a/hivemind/moe/server/module_backend.py +++ b/hivemind/moe/server/module_backend.py @@ -118,9 +118,7 @@ def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]: with torch.enable_grad(): args = [ - tensor.detach().requires_grad_(True) - if tensor.dtype in (torch.half, torch.float, torch.double) - else tensor.detach() + tensor.detach().requires_grad_(True) if tensor.is_floating_point() else tensor.detach() for tensor in args ] kwargs = {