diff --git a/src/diffusers/quantizers/bitsandbytes/utils.py b/src/diffusers/quantizers/bitsandbytes/utils.py
index 247d0e71bb26..a9771b368a86 100644
--- a/src/diffusers/quantizers/bitsandbytes/utils.py
+++ b/src/diffusers/quantizers/bitsandbytes/utils.py
@@ -153,8 +153,8 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
     return model
 
 
-# Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41
-def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
+# Adapted from PEFT: https://github.com/huggingface/peft/blob/6d458b300fc2ed82e19f796b53af4c97d03ea604/src/peft/utils/integrations.py#L81
+def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None, dtype: "torch.dtype" = None):
     """
     Helper function to dequantize 4bit or 8bit bnb weights.
 
@@ -177,13 +177,16 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
     if state.SCB is None:
         state.SCB = weight.SCB
 
-    im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device)
-    im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im)
-    im, Sim = bnb.functional.transform(im, "col32")
-    if state.CxB is None:
-        state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)
-    out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
-    return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t()
+    if hasattr(bnb.functional, "int8_vectorwise_dequant"):
+        # Use bitsandbytes API if available (requires v0.45.0+)
+        dequantized = bnb.functional.int8_vectorwise_dequant(weight.data, state.SCB)
+    else:
+        # Multiply by (scale/127) to dequantize.
+        dequantized = weight.data * state.SCB.view(-1, 1) * 7.874015718698502e-3
+
+    if dtype:
+        dequantized = dequantized.to(dtype)
+    return dequantized
 
 
 def _create_accelerate_new_hook(old_hook):
@@ -205,6 +208,7 @@ def _create_accelerate_new_hook(old_hook):
 
 def _dequantize_and_replace(
     model,
+    dtype,
     modules_to_not_convert=None,
     current_key_name=None,
     quantization_config=None,
@@ -244,7 +248,7 @@ def _dequantize_and_replace(
                 else:
                     state = None
 
-                new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state))
+                new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state, dtype))
 
                 if bias is not None:
                     new_module.bias = bias
@@ -263,9 +267,10 @@ def _dequantize_and_replace(
         if len(list(module.children())) > 0:
             _, has_been_replaced = _dequantize_and_replace(
                 module,
-                modules_to_not_convert,
-                current_key_name,
-                quantization_config,
+                dtype=dtype,
+                modules_to_not_convert=modules_to_not_convert,
+                current_key_name=current_key_name,
+                quantization_config=quantization_config,
                 has_been_replaced=has_been_replaced,
             )
             # Remove the last key for recursion
@@ -280,6 +285,7 @@ def dequantize_and_replace(
 ):
     model, has_been_replaced = _dequantize_and_replace(
         model,
+        dtype=model.dtype,
         modules_to_not_convert=modules_to_not_convert,
         quantization_config=quantization_config,
     )
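
Note on the int8 path above: the new `int8_vectorwise_dequant` call (bitsandbytes v0.45.0+) and the manual fallback compute the same LLM.int8() row-wise dequantization. Each int8 weight row is rescaled by its per-row absmax statistic `SCB` divided by 127; the constant `7.874015718698502e-3` is 1/127 rounded to float32. Below is a minimal round-trip sketch of that math, using a synthetic weight and hand-rolled row-wise quantization rather than a real bitsandbytes `Int8Params` tensor:

import torch

# Toy stand-in for an Int8Params weight: quantize a float matrix row-wise,
# keeping the per-row absmax as SCB (mirroring the state bitsandbytes stores).
torch.manual_seed(0)
fp_weight = torch.randn(4, 8)
SCB = fp_weight.abs().amax(dim=1)  # per-row absmax statistic
int8_weight = torch.round(fp_weight * (127.0 / SCB.view(-1, 1))).to(torch.int8)

# Dequantize exactly as the fallback branch in the diff does:
# weight * SCB.view(-1, 1) * (1/127), with 1/127 written as a float32 literal.
dequantized = int8_weight * SCB.view(-1, 1) * 7.874015718698502e-3

# Round-trip error is bounded by half a quantization step per row (~SCB/254).
assert (dequantized - fp_weight).abs().max() <= (SCB / 254).max() + 1e-6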