[bitsandbytes] Simplify bnb int8 dequant (#10401)
* fix dequantization for latest bnb.

* smol fixes.

* fix type annotation

* update peft link

* updates
sayakpaul authored Feb 4, 2025
1 parent 3e35f56 commit 5e8e6cb
Showing 1 changed file with 19 additions and 13 deletions.
32 changes: 19 additions & 13 deletions src/diffusers/quantizers/bitsandbytes/utils.py
@@ -153,8 +153,8 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
     return model


-# Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41
-def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
+# Adapted from PEFT: https://github.com/huggingface/peft/blob/6d458b300fc2ed82e19f796b53af4c97d03ea604/src/peft/utils/integrations.py#L81
+def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None, dtype: "torch.dtype" = None):
     """
     Helper function to dequantize 4bit or 8bit bnb weights.
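This first hunk gives the helper an optional `dtype` argument so callers can cast the dequantized tensor in the same call. A hypothetical use, assuming `linear` is an already-quantized `bnb.nn.Linear8bitLt` module (its `.state` carries the `SCB` scales the helper needs):

import torch

from diffusers.quantizers.bitsandbytes.utils import dequantize_bnb_weight

# `linear` is assumed to be a quantized bnb.nn.Linear8bitLt module.
fp16_weight = dequantize_bnb_weight(linear.weight, state=linear.state, dtype=torch.float16)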
@@ -177,13 +177,16 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
     if state.SCB is None:
         state.SCB = weight.SCB
 
-    im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device)
-    im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im)
-    im, Sim = bnb.functional.transform(im, "col32")
-    if state.CxB is None:
-        state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)
-    out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
-    return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t()
+    if hasattr(bnb.functional, "int8_vectorwise_dequant"):
+        # Use bitsandbytes API if available (requires v0.45.0+)
+        dequantized = bnb.functional.int8_vectorwise_dequant(weight.data, state.SCB)
+    else:
+        # Multiply by (scale/127) to dequantize.
+        dequantized = weight.data * state.SCB.view(-1, 1) * 7.874015718698502e-3
+
+    if dtype:
+        dequantized = dequantized.to(dtype)
+    return dequantized
 
 
 def _create_accelerate_new_hook(old_hook):
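The fallback branch above is plain row-wise absmax dequantization: each int8 value is multiplied by its row's scale over 127 (the magic constant is just ≈ 1/127 precomputed). A minimal, self-contained sketch of that round trip, assuming bitsandbytes' row-wise absmax int8 scheme; the tensor names here are illustrative, not the library's:

import torch

# Row-wise absmax int8 quantization round trip (illustrative sketch,
# not the bitsandbytes internals).
w = torch.randn(4, 8)

# "Quantize": SCB holds each row's absmax scale; values land in [-127, 127].
SCB = w.abs().amax(dim=1)
w_int8 = torch.round(w / SCB.view(-1, 1) * 127).to(torch.int8)

# "Dequantize": multiply by scale / 127, which is what the fallback branch does.
w_dequant = w_int8 * SCB.view(-1, 1) * 7.874015718698502e-3

# Equal up to quantization error (at most half a step per row).
assert torch.allclose(w, w_dequant, atol=SCB.max().item() / 254 + 1e-6)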
@@ -205,6 +208,7 @@ def _create_accelerate_new_hook(old_hook):
 
 def _dequantize_and_replace(
     model,
+    dtype,
     modules_to_not_convert=None,
     current_key_name=None,
     quantization_config=None,
@@ -244,7 +248,7 @@ def _dequantize_and_replace(
             else:
                 state = None
 
-            new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state))
+            new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state, dtype))
 
             if bias is not None:
                 new_module.bias = bias
@@ -263,9 +267,10 @@ def _dequantize_and_replace(
         if len(list(module.children())) > 0:
             _, has_been_replaced = _dequantize_and_replace(
                 module,
-                modules_to_not_convert,
-                current_key_name,
-                quantization_config,
+                dtype=dtype,
+                modules_to_not_convert=modules_to_not_convert,
+                current_key_name=current_key_name,
+                quantization_config=quantization_config,
                 has_been_replaced=has_been_replaced,
             )
         # Remove the last key for recursion
@@ -280,6 +285,7 @@ def dequantize_and_replace(
 ):
     model, has_been_replaced = _dequantize_and_replace(
         model,
+        dtype=model.dtype,
        modules_to_not_convert=modules_to_not_convert,
        quantization_config=quantization_config,
    )
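End to end, `dequantize_and_replace` is the helper behind diffusers' `dequantize()` entry point on quantized models, and it now forwards `model.dtype` so restored weights come back in the model's compute dtype rather than whatever the dequant kernel emits. A sketch following the usual diffusers bnb loading pattern (model and checkpoint chosen for illustration):

import torch
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

# Load the transformer in 8-bit.
model = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.float16,
)

# Restore dense weights; dequantize_and_replace() passes model.dtype down,
# so the recovered Linear weights are float16 here.
model.dequantize()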
