From 065e74d11afdad8938e6276d0d6bc12d2d67f807 Mon Sep 17 00:00:00 2001
From: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Date: Fri, 31 May 2024 15:26:17 +0200
Subject: [PATCH] 4-bit quantization meta device bias loading bug (#2805)

* 4-bit quantization meta device bias loading bug: fixes #2742

* move condition

---------

Co-authored-by: mh
---
 src/accelerate/utils/modeling.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py
index 650651019e7..d692316653e 100644
--- a/src/accelerate/utils/modeling.py
+++ b/src/accelerate/utils/modeling.py
@@ -442,7 +442,11 @@ def set_module_tensor_to_device(
             elif module.bias is None:
                 # if no bias exists, we can quantize right away
                 module = module.cuda(device_index)
-    elif module.__class__.__name__ == "Linear4bit" and getattr(module.weight, "quant_state", None) is None:
+    elif (
+        module.__class__.__name__ == "Linear4bit"
+        and getattr(module.weight, "quant_state", None) is None
+        and str(module.weight.device) != "meta"
+    ):
         # quantize only if necessary
         device_index = torch.device(device).index if torch.device(device).type == "cuda" else None
         if not getattr(module.weight, "quant_state", None) and device_index is not None:
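
Reviewer note (illustrative, not part of the patch): below is a minimal sketch of
the failure mode this change guards against, assuming bitsandbytes is installed
and a CUDA device is available. The layer sizes, tensor values, and load order
are hypothetical examples, not taken from the patch or from issue #2742.

    import torch
    import bitsandbytes as bnb  # assumption: a bitsandbytes build providing Linear4bit
    from accelerate import init_empty_weights
    from accelerate.utils import set_module_tensor_to_device

    # Instantiate a 4-bit layer with all parameters on the meta device,
    # as big-model loading does before checkpoint weights are materialized.
    with init_empty_weights():
        layer = bnb.nn.Linear4bit(16, 16, bias=True)

    # Checkpoint loading can set the bias before the weight. Before this patch,
    # the bias call below already matched the Linear4bit branch (quant_state is
    # None) and ran module.cuda(device_index), attempting to quantize a weight
    # that was still a meta tensor. The added `str(module.weight.device) != "meta"`
    # condition defers quantization until the real weight has been loaded.
    set_module_tensor_to_device(layer, "bias", 0, value=torch.zeros(16))
    set_module_tensor_to_device(layer, "weight", 0, value=torch.randn(16, 16))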