From a5a3e57125d7c3f06b5a63d6079434428f7371ad Mon Sep 17 00:00:00 2001
From: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Date: Sat, 20 Jul 2024 03:07:07 +0200
Subject: [PATCH] Add `torch.float8_e4m3fn` format `dtype_byte_size` (#2945)

* add new format

* check torch version

* style
---
 src/accelerate/utils/modeling.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py
index d130a6c044d..8b4e04a1bc5 100644
--- a/src/accelerate/utils/modeling.py
+++ b/src/accelerate/utils/modeling.py
@@ -43,7 +43,7 @@
 from .memory import clear_device_cache
 from .offload import load_offloaded_weight, offload_weight, save_offload_index
 from .tqdm import is_tqdm_available, tqdm
-from .versions import compare_versions
+from .versions import compare_versions, is_torch_version


 if is_npu_available(check_device=False):
@@ -163,6 +163,8 @@ def dtype_byte_size(dtype: torch.dtype):
         return 1 / 2
     elif dtype == CustomDtype.FP8:
         return 1
+    elif is_torch_version(">=", "2.1.0") and dtype == torch.float8_e4m3fn:
+        return 1
     bit_search = re.search(r"[^\d](\d+)$", str(dtype))
     if bit_search is None:
         raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
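
Not part of the patch itself — a minimal usage sketch of the behavior this change enables, assuming torch >= 2.1.0 (where `torch.float8_e4m3fn` exists) and an accelerate build that includes this patch:

    import torch
    from accelerate.utils.modeling import dtype_byte_size

    # With this patch, float8_e4m3fn parameters report one byte per element,
    # matching the existing CustomDtype.FP8 branch.
    print(dtype_byte_size(torch.float8_e4m3fn))  # 1

    # Dtypes whose names end in a bit width still fall through to the
    # regex branch: 16 bits / 8 = 2 bytes.
    print(dtype_byte_size(torch.float16))  # 2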