Skip to content

Commit

Permalink
fix bug huggingface#172
Browse files Browse the repository at this point in the history
  • Loading branch information
reppy4620 committed Sep 26, 2021
1 parent 1b1463f commit 6f2e3eb
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/accelerate/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def _convert_to_fp32(tensor):
def _is_fp16_tensor(tensor):
return hasattr(tensor, "dtype") and tensor.dtype == torch.float16

return recursively_apply(_is_fp16_tensor, tensor, test_type=_is_fp16_tensor)
return recursively_apply(_convert_to_fp32, tensor, test_type=_is_fp16_tensor)


def convert_outputs_to_fp32(model_forward):
Expand Down

0 comments on commit 6f2e3eb

Please sign in to comment.