Skip to content

Commit

Permalink
4-bit quantization meta device bias loading bug (#2805)
Browse files Browse the repository at this point in the history
* 4-bit quantization meta device bias loading bug: fixes #2742

* move condition

---------

Co-authored-by: mh <[email protected]>
  • Loading branch information
SunMarc and mh authored May 31, 2024
1 parent 86b6dea commit 065e74d
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/accelerate/utils/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,11 @@ def set_module_tensor_to_device(
elif module.bias is None:
# if no bias exists, we can quantize right away
module = module.cuda(device_index)
elif module.__class__.__name__ == "Linear4bit" and getattr(module.weight, "quant_state", None) is None:
elif (
module.__class__.__name__ == "Linear4bit"
and getattr(module.weight, "quant_state", None) is None
and str(module.weight.device) != "meta"
):
# quantize only if necessary
device_index = torch.device(device).index if torch.device(device).type == "cuda" else None
if not getattr(module.weight, "quant_state", None) and device_index is not None:
Expand Down

0 comments on commit 065e74d

Please sign in to comment.