diff --git a/src/accelerate/utils/offload.py b/src/accelerate/utils/offload.py index 6e8b34bbc48..750ff9d571c 100644 --- a/src/accelerate/utils/offload.py +++ b/src/accelerate/utils/offload.py @@ -25,8 +25,8 @@ def offload_weight(weight, weight_name, offload_folder, index=None): dtype = None # Check the string instead of the dtype to be compatible with versions of PyTorch that don't have bfloat16. if str(weight.dtype) == "torch.bfloat16": - # Need to convert to FP32 since NumPy does not handle bfloat16s. - weight = weight.float() + # Need to reinterpret the underlying data as int16 since NumPy does not handle bfloat16s. + weight = weight.view(torch.int16) dtype = "bfloat16" array = weight.numpy() tensor_file = os.path.join(offload_folder, f"{weight_name}.dat") @@ -50,8 +50,8 @@ def load_offloaded_weight(weight_file, weight_info): dtype = weight_info["dtype"] if dtype == "bfloat16": - # NumPy does not support bfloat16 so this was saved as a float32 - dtype = "float32" + # NumPy does not support bfloat16 so this was saved as an int16 + dtype = "int16" weight = np.memmap(weight_file, dtype=dtype, shape=shape, mode="r") @@ -59,7 +59,7 @@ def load_offloaded_weight(weight_file, weight_info): weight = weight[0] weight = torch.tensor(weight) if weight_info["dtype"] == "bfloat16": - weight = weight.to(torch.bfloat16) + weight = weight.view(torch.bfloat16) return weight