4-bit quantization cannot load weights to meta device for bias terms of the linear layer: NotImplementedError: Cannot copy out of meta tensor; no data!
#2742
Closed
MuhammedHasan opened this issue on May 5, 2024 · 0 comments · Fixed by #2805
Tasks
- One of the scripts in the examples/ folder of Accelerate, or an officially supported `no_trainer` script in the examples folder of the transformers repo (such as `run_no_trainer_glue.py`)
- My own task or dataset (give details below)
Reproduction
`accelerate.utils.load_and_quantize_model` cannot load bias terms for a model initialized on the meta device. Run the following as `python main.py`:
```python
import torch
from safetensors.torch import save_model
from accelerate import init_empty_weights
from accelerate.utils import BnbQuantizationConfig, load_and_quantize_model


class Model(torch.nn.Module):
    def __init__(self, bias=False):
        super().__init__()
        self.q = torch.nn.Linear(10, 10, bias=bias)
        self.k = torch.nn.Linear(10, 10, bias=bias)
        self.v = torch.nn.Linear(10, 10, bias=bias)

    def forward(self, x):
        return self.q(x) + self.k(x) + self.v(x)


# Without bias
print('Save and load model without bias')
save_model(Model(bias=False), 'model.safetensors', metadata={'format': 'pt'})
with init_empty_weights():
    qmodel = Model(bias=False)
qmodel = load_and_quantize_model(
    qmodel,
    weights_location='model.safetensors',
    bnb_quantization_config=BnbQuantizationConfig(load_in_4bit=True))
# Works fine!

# With bias
print('Save and load model with bias')
save_model(Model(bias=True), 'model.safetensors', metadata={'format': 'pt'})
with init_empty_weights():
    qmodel = Model(bias=True)
qmodel = load_and_quantize_model(
    qmodel,
    weights_location='model.safetensors',
    bnb_quantization_config=BnbQuantizationConfig(load_in_4bit=True))
```
This throws the following error:
```
Traceback (most recent call last):
  File "main.py", line 37, in <module>
    qmodel = load_and_quantize_model(
  File "/home/mcelik/anaconda3/envs/esm-efficient/lib/python3.8/site-packages/accelerate/utils/bnb.py", line 183, in load_and_quantize_model
    load_checkpoint_in_model(
  File "/home/mcelik/anaconda3/envs/esm-efficient/lib/python3.8/site-packages/accelerate/utils/modeling.py", line 1736, in load_checkpoint_in_model
    set_module_tensor_to_device(
  File "/home/mcelik/anaconda3/envs/esm-efficient/lib/python3.8/site-packages/accelerate/utils/modeling.py", line 449, in set_module_tensor_to_device
    module.weight = module.weight.cuda(device_index)
  File "/home/mcelik/anaconda3/envs/esm-efficient/lib/python3.8/site-packages/bitsandbytes/nn/modules.py", line 304, in cuda
    return self.to(device="cuda" if device is None else device, non_blocking=non_blocking)
  File "/home/mcelik/anaconda3/envs/esm-efficient/lib/python3.8/site-packages/bitsandbytes/nn/modules.py", line 324, in to
    return self._quantize(device)
  File "/home/mcelik/anaconda3/envs/esm-efficient/lib/python3.8/site-packages/bitsandbytes/nn/modules.py", line 288, in _quantize
    w = self.data.contiguous().cuda(device)
NotImplementedError: Cannot copy out of meta tensor; no data!
```
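For context, the underlying limitation is reproducible without accelerate: tensors on the meta device carry no storage, so they cannot be copied to CUDA. A minimal sketch (assuming a CUDA device is available):

```python
import torch

w = torch.empty(10, 10, device="meta")  # meta tensor: shape and dtype only, no data
w.cuda()  # NotImplementedError: Cannot copy out of meta tensor; no data!
```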
When I initialize the model without the `with init_empty_weights():` context, it works, so the problem seems related to the meta device. Also, `load_in_8bit=True` works fine; the issue only appears with 4-bit quantization.
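For reference, here is a minimal sketch of the two variants that do work for me, assuming the same `Model` class and `model.safetensors` checkpoint as above:

```python
# Variant 1: materialize real weights (no meta device), then quantize to 4-bit -- works.
qmodel = Model(bias=True)
qmodel = load_and_quantize_model(
    qmodel,
    weights_location='model.safetensors',
    bnb_quantization_config=BnbQuantizationConfig(load_in_4bit=True))

# Variant 2: keep init_empty_weights(), but quantize to 8-bit instead -- works.
with init_empty_weights():
    qmodel = Model(bias=True)
qmodel = load_and_quantize_model(
    qmodel,
    weights_location='model.safetensors',
    bnb_quantization_config=BnbQuantizationConfig(load_in_8bit=True))
```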
Expected behavior
The models should be initialized with 4-bit weights, including bias terms.
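Concretely, a hypothetical check of what I would expect to pass after the 4-bit load with bias (assuming accelerate replaces the linear layers with bitsandbytes `Linear4bit` modules, as it does in the bias-free case):

```python
import bitsandbytes as bnb

assert isinstance(qmodel.q, bnb.nn.Linear4bit)  # weight replaced by a 4-bit layer
assert qmodel.q.bias is not None
assert not qmodel.q.bias.is_meta                # bias materialized from the checkpoint
```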
MuhammedHasan pushed a commit to MuhammedHasan/accelerate that referenced this issue on May 5, 2024.