[core] TorchAO Quantizer #10009

Merged

39 commits merged on Dec 16, 2024

Commits (39):
- 64cbf11 torchao quantizer (a-r-r-o-w, Nov 24, 2024)
- b78a36c make style (a-r-r-o-w, Nov 24, 2024)
- 355509e update (a-r-r-o-w, Nov 24, 2024)
- cbb0da4 update (a-r-r-o-w, Nov 24, 2024)
- ee084a5 cuda capability check (a-r-r-o-w, Nov 24, 2024)
- 748a002 update (a-r-r-o-w, Nov 24, 2024)
- bc006f2 fix (a-r-r-o-w, Nov 24, 2024)
- 956f3bf fix (a-r-r-o-w, Nov 25, 2024)
- 2c6beef Merge branch 'main' into torchao-quantizer (a-r-r-o-w, Nov 25, 2024)
- cfdb94f update (a-r-r-o-w, Nov 25, 2024)
- 8e214e2 Merge branch 'main' into torchao-quantizer (a-r-r-o-w, Nov 25, 2024)
- 1d9f832 Merge branch 'main' into torchao-quantizer (a-r-r-o-w, Nov 28, 2024)
- 01b2b42 update tests (a-r-r-o-w, Nov 28, 2024)
- b17cf35 device map changes (a-r-r-o-w, Nov 28, 2024)
- 250ccf4 update; apply suggestions from review (a-r-r-o-w, Dec 1, 2024)
- 50946a9 fix (a-r-r-o-w, Dec 1, 2024)
- edae34b Merge branch 'main' into torchao-quantizer (a-r-r-o-w, Dec 1, 2024)
- 8f09bdf remove slow marker (a-r-r-o-w, Dec 1, 2024)
- 7c79b8e remove pytest deprecation warnings (a-r-r-o-w, Dec 1, 2024)
- 820ac88 Merge branch 'main' into torchao-quantizer (sayakpaul, Dec 5, 2024)
- f9f1535 Merge branch 'main' into torchao-quantizer (a-r-r-o-w, Dec 5, 2024)
- 747bd7d apply review suggestions (a-r-r-o-w, Dec 5, 2024)
- 25d3cf8 add torch compile test (a-r-r-o-w, Dec 5, 2024)
- 10deb16 add more tests; add expected slices (a-r-r-o-w, Dec 5, 2024)
- f3771a8 fix (a-r-r-o-w, Dec 5, 2024)
- 55d6155 Merge branch 'main' into torchao-quantizer (a-r-r-o-w, Dec 5, 2024)
- de97a51 improve test check (a-r-r-o-w, Dec 5, 2024)
- 101d10c update docs (a-r-r-o-w, Dec 5, 2024)
- edd98db bnb device map check (a-r-r-o-w, Dec 5, 2024)
- 2677e0c Merge branch 'main' into torchao-quantizer (a-r-r-o-w, Dec 5, 2024)
- cc70887 update docs (a-r-r-o-w, Dec 5, 2024)
- 5f75db2 Apply suggestions from code review (a-r-r-o-w, Dec 6, 2024)
- 9704daa Merge branch 'main' into torchao-quantizer (a-r-r-o-w, Dec 6, 2024)
- b227189 Merge branch 'main' into torchao-quantizer (a-r-r-o-w, Dec 9, 2024)
- 7d9d1dc address review comments (a-r-r-o-w, Dec 9, 2024)
- e9fccb6 update (a-r-r-o-w, Dec 9, 2024)
- bc874fc add nightly marker for torch.compile test (a-r-r-o-w, Dec 9, 2024)
- 29ec905 Merge branch 'main' into torchao-quantizer (a-r-r-o-w, Dec 12, 2024)
- 7ca64fd Merge branch 'main' into torchao-quantizer (a-r-r-o-w, Dec 15, 2024)
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -157,6 +157,8 @@
title: Getting Started
- local: quantization/bitsandbytes
title: bitsandbytes
- local: quantization/torchao
title: torchao
title: Quantization Methods
- sections:
- local: optimization/fp16
4 changes: 4 additions & 0 deletions docs/source/en/api/quantization.md
@@ -28,6 +28,10 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui

[[autodoc]] BitsAndBytesConfig

## TorchAoConfig

[[autodoc]] TorchAoConfig

## DiffusersQuantizer

[[autodoc]] quantizers.base.DiffusersQuantizer
2 changes: 1 addition & 1 deletion docs/source/en/quantization/overview.md
@@ -32,4 +32,4 @@ If you are new to the quantization field, we recommend you to check out these be

## When to use what?

-This section will be expanded once Diffusers has multiple quantization backends. Currently, we only support `bitsandbytes`. [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
+Diffusers supports [bitsandbytes](https://huggingface.co/docs/bitsandbytes/main/en/index) and [torchao](https://github.com/pytorch/ao). Refer to this [table](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) to help you determine which quantization backend to use.
92 changes: 92 additions & 0 deletions docs/source/en/quantization/torchao.md
@@ -0,0 +1,92 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# torchao

[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training, featuring composability with native PyTorch features like [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html), FullyShardedDataParallel (FSDP), and more.

Before you begin, make sure you have PyTorch 2.5+ and TorchAO installed.

Reviewer comment:

Indeed, it seems PyTorch 2.5+ is required, because there is an import of torch.uint1 (and others) which is not available in earlier torch versions. However, diffusers seems to require torch>=1.4 (ref), so this seems inconsistent. Am I missing something?

Member Author (a-r-r-o-w):
TorchAO will not be imported or usable unless PyTorch 2.5 or above is available. Some Diffusers models can run with torch 1.4 as well, which is why that's the minimum required version.
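
For illustration, here is a minimal sketch (not the actual diffusers code) of how version-dependent torch attributes can be guarded so that importing the package does not fail on older torch versions:

```python
import torch
from packaging import version

# torch.uint1 ... torch.uint7 only exist in newer torch releases, so any
# reference to them is gated behind a runtime version check.
_TORCH_HAS_SUB_BYTE_DTYPES = version.parse(torch.__version__).release >= (2, 3)

if _TORCH_HAS_SUB_BYTE_DTYPES:
    SUPPORTED_DTYPES = (torch.int8, torch.uint1, torch.uint2)
else:
    SUPPORTED_DTYPES = (torch.int8,)
```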

@fjeremic (Dec 24, 2024):
I'm running into the same issue with the torch.uint1 import. It seems the TorchAO import is not guarded, according to the backtrace. The following backtrace stems from this import line:

`from diffusers import StableDiffusionXLPipeline`

Here is the trace, and the pip list:

```
Traceback (most recent call last):
  File "/github/home/.local/lib/python3.10/site-packages/diffusers/utils/import_utils.py", line 920, in _get_module
    return importlib.import_module("." + module_name, self.__name__)
  File "/usr/local/lib/python3.10/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "<frozen importlib._bootstrap>", line 1050, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1027, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1006, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 688, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 883, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/github/home/.local/lib/python3.10/site-packages/diffusers/loaders/single_file.py", line 24, in <module>
    from .single_file_utils import (
  File "/github/home/.local/lib/python3.10/site-packages/diffusers/loaders/single_file_utils.py", line 28, in <module>
    from ..models.modeling_utils import load_state_dict
  File "/github/home/.local/lib/python3.10/site-packages/diffusers/models/modeling_utils.py", line 35, in <module>
    from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
  File "/github/home/.local/lib/python3.10/site-packages/diffusers/quantizers/__init__.py", line 15, in <module>
    from .auto import DiffusersAutoQuantizer
  File "/github/home/.local/lib/python3.10/site-packages/diffusers/quantizers/auto.py", line 31, in <module>
    from .torchao import TorchAoHfQuantizer
  File "/github/home/.local/lib/python3.10/site-packages/diffusers/quantizers/torchao/__init__.py", line 15, in <module>
    from .torchao_quantizer import TorchAoHfQuantizer
  File "/github/home/.local/lib/python3.10/site-packages/diffusers/quantizers/torchao/torchao_quantizer.py", line 45, in <module>
    torch.uint1,
  File "/github/home/.local/lib/python3.10/site-packages/torch/__init__.py", line 1938, in __getattr__
    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
AttributeError: module 'torch' has no attribute 'uint1'
```

And the pip list:

```
pip list -v
Package                  Version     Editable project location    Location                                         Installer
------------------------ ----------- ---------------------------- ------------------------------------------------ ---------
certifi                  2024.12.14                               /github/home/.local/lib/python3.10/site-packages pip
charset-normalizer       3.4.0                                    /github/home/.local/lib/python3.10/site-packages pip
colorama                 0.4.6                                    /github/home/.local/lib/python3.10/site-packages pip
coloredlogs              15.0.1                                   /github/home/.local/lib/python3.10/site-packages pip
colorlog                 6.9.0                                    /github/home/.local/lib/python3.10/site-packages pip
coverage                 7.6.9                                    /github/home/.local/lib/python3.10/site-packages pip
diffusers                0.32.0                                   /github/home/.local/lib/python3.10/site-packages pip
exceptiongroup           1.2.2                                    /github/home/.local/lib/python3.10/site-packages pip
execnet                  2.1.1                                    /github/home/.local/lib/python3.10/site-packages pip
filelock                 3.16.1                                   /github/home/.local/lib/python3.10/site-packages pip
flatbuffers              24.12.23                                 /github/home/.local/lib/python3.10/site-packages pip
fsspec                   2024.12.0                                /github/home/.local/lib/python3.10/site-packages pip
huggingface-hub          0.27.0                                   /github/home/.local/lib/python3.10/site-packages pip
humanfriendly            10.0                                     /github/home/.local/lib/python3.10/site-packages pip
idna                     3.10                                     /github/home/.local/lib/python3.10/site-packages pip
importlib_metadata       8.5.0                                    /github/home/.local/lib/python3.10/site-packages pip
iniconfig                2.0.0                                    /github/home/.local/lib/python3.10/site-packages pip
Jinja2                   3.1.5                                    /github/home/.local/lib/python3.10/site-packages pip
markdown-it-py           3.0.0                                    /github/home/.local/lib/python3.10/site-packages pip
MarkupSafe               3.0.2                                    /github/home/.local/lib/python3.10/site-packages pip
mdurl                    0.1.2                                    /github/home/.local/lib/python3.10/site-packages pip
mpmath                   1.3.0                                    /github/home/.local/lib/python3.10/site-packages pip
networkx                 3.4.2                                    /github/home/.local/lib/python3.10/site-packages pip
numpy                    1.26.4                                   /github/home/.local/lib/python3.10/site-packages pip
nvidia-cublas-cu12       12.1.3.1                                 /github/home/.local/lib/python3.10/site-packages pip
nvidia-cuda-cupti-cu12   12.1.105                                 /github/home/.local/lib/python3.10/site-packages pip
nvidia-cuda-nvrtc-cu12   12.1.105                                 /github/home/.local/lib/python3.10/site-packages pip
nvidia-cuda-runtime-cu12 12.1.105                                 /github/home/.local/lib/python3.10/site-packages pip
nvidia-cudnn-cu12        8.9.2.26                                 /github/home/.local/lib/python3.10/site-packages pip
nvidia-cufft-cu12        11.0.2.54                                /github/home/.local/lib/python3.10/site-packages pip
nvidia-curand-cu12       10.3.2.106                               /github/home/.local/lib/python3.10/site-packages pip
nvidia-cusolver-cu12     11.4.5.107                               /github/home/.local/lib/python3.10/site-packages pip
nvidia-cusparse-cu12     12.1.0.106                               /github/home/.local/lib/python3.10/site-packages pip
nvidia-nccl-cu12         2.19.3                                   /github/home/.local/lib/python3.10/site-packages pip
nvidia-nvjitlink-cu12    12.6.85                                  /github/home/.local/lib/python3.10/site-packages pip
nvidia-nvtx-cu12         12.1.105                                 /github/home/.local/lib/python3.10/site-packages pip
onnx                     1.17.0                                   /github/home/.local/lib/python3.10/site-packages pip
onnx2torch               1.5.15                                   /github/home/.local/lib/python3.10/site-packages pip
onnxruntime              1.20.1                                   /github/home/.local/lib/python3.10/site-packages pip
onnxsim                  0.4.36                                   /github/home/.local/lib/python3.10/site-packages pip
packaging                24.2                                     /github/home/.local/lib/python3.10/site-packages pip
pandas                   2.2.3                                    /github/home/.local/lib/python3.10/site-packages pip
pillow                   11.0.0                                   /github/home/.local/lib/python3.10/site-packages pip
pip                      22.0.4                                   /usr/local/lib/python3.10/site-packages          pip
pluggy                   1.5.0                                    /github/home/.local/lib/python3.10/site-packages pip
protobuf                 5.29.2                                   /github/home/.local/lib/python3.10/site-packages pip
Pygments                 2.18.0                                   /github/home/.local/lib/python3.10/site-packages pip
pytest                   8.3.4                                    /github/home/.local/lib/python3.10/site-packages pip
pytest-xdist             3.6.1                                    /github/home/.local/lib/python3.10/site-packages pip
python-dateutil          2.9.0.post0                              /github/home/.local/lib/python3.10/site-packages pip
python-dotenv            1.0.1                                    /github/home/.local/lib/python3.10/site-packages pip
python-json-logger       3.2.1                                    /github/home/.local/lib/python3.10/site-packages pip
pytz                     2024.2                                   /github/home/.local/lib/python3.10/site-packages pip
PyYAML                   6.0.2                                    /github/home/.local/lib/python3.10/site-packages pip
regex                    2024.11.6                                /github/home/.local/lib/python3.10/site-packages pip
requests                 2.32.3                                   /github/home/.local/lib/python3.10/site-packages pip
rich                     13.9.4                                   /github/home/.local/lib/python3.10/site-packages pip
safetensors              0.4.5                                    /github/home/.local/lib/python3.10/site-packages pip
scipy                    1.14.1                                   /github/home/.local/lib/python3.10/site-packages pip
sentencepiece            0.2.0                                    /github/home/.local/lib/python3.10/site-packages pip
setuptools               58.1.0                                   /usr/local/lib/python3.10/site-packages          pip
six                      1.17.0                                   /github/home/.local/lib/python3.10/site-packages pip
sympy                    1.13.3                                   /github/home/.local/lib/python3.10/site-packages pip
tabulate                 0.9.0                                    /github/home/.local/lib/python3.10/site-packages pip
tokenizers               0.15.2                                   /github/home/.local/lib/python3.10/site-packages pip
tomli                    2.2.1                                    /github/home/.local/lib/python3.10/site-packages pip
torch                    2.2.2                                    /github/home/.local/lib/python3.10/site-packages pip
torchvision              0.17.2                                   /github/home/.local/lib/python3.10/site-packages pip
tqdm                     4.67.1                                   /github/home/.local/lib/python3.10/site-packages pip
transformers             4.38.2                                   /github/home/.local/lib/python3.10/site-packages pip
triton                   2.2.0                                    /github/home/.local/lib/python3.10/site-packages pip
typing_extensions        4.12.2                                   /github/home/.local/lib/python3.10/site-packages pip
urllib3                  2.3.0                                    /github/home/.local/lib/python3.10/site-packages pip
wheel                    0.37.1                                   /usr/local/lib/python3.10/site-packages          pip
zipp                     3.21.0                                   /github/home/.local/lib/python3.10/site-packages pip
```

Member Author (a-r-r-o-w):
Thanks a lot for reporting, @fjeremic! We were able to reproduce this for torch <= 2.2; it does not cause the import errors for torch >= 2.3. We will be doing a patch release soon to fix this behaviour. Sorry for the inconvenience!

Reviewer (@BeckerFelix):

Thanks for providing a quick fix!
For completeness, I was running into the import error with torch 2.2.2 when importing `AutoencoderKL`.

Member Author (a-r-r-o-w):
@BeckerFelix @fjeremic The patch release is out! Hope it fixes any problems you were facing in torch < 2.3


```bash
pip install -U torch torchao
```


Quantize a model by passing [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`] (you can also load pre-quantized models). This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.

The example below only quantizes the weights to int8.

```python
import torch

from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

model_id = "black-forest-labs/Flux.1-Dev"
dtype = torch.bfloat16

quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=dtype,
)
pipe = FluxPipeline.from_pretrained(
    model_id,
    transformer=transformer,
    torch_dtype=dtype,
)
pipe.to("cuda")

prompt = "A cat holding a sign that says hello world"
image = pipe(prompt, num_inference_steps=28, guidance_scale=0.0).images[0]
image.save("output.png")
```

TorchAO is fully compatible with [torch.compile](./optimization/torch2.0#torchcompile), setting it apart from other quantization methods. This makes it easy to speed up inference with just one line of code.

```python
# In the above code, add the following after initializing the transformer
transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
```

For speed and memory benchmarks on Flux and CogVideoX, refer to the table [here](https://github.com/huggingface/diffusers/pull/10009#issue-2688781450). You can also find torchao [benchmark numbers](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks) for various hardware.

torchao also supports an automatic quantization API through [autoquant](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#autoquantization). Autoquantization determines the best quantization strategy applicable to a model by comparing the performance of each technique on chosen input types and shapes. Currently, this can be used directly on the underlying modeling components. Diffusers will also expose an autoquant configuration option in the future.
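
A minimal sketch of what autoquant usage can look like today, reusing the `transformer` from the example above and assuming torchao's documented `autoquant` entry point (the future Diffusers-level option may differ):

```python
import torch
import torchao

# autoquant profiles candidate quantization techniques on the shapes seen
# during the first inference calls and picks the fastest option per layer.
transformer = torchao.autoquant(torch.compile(transformer, mode="max-autotune"))
```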

The `TorchAoConfig` class accepts three parameters:
- `quant_type`: A string specifying one of the quantization types below.
- `modules_to_not_convert`: A list of full or partial module names that should not be quantized. For example, to skip quantizing the first block of [`FluxTransformer2DModel`], specify `modules_to_not_convert=["single_transformer_blocks.0"]`.
- `kwargs`: A dict of keyword arguments to pass to the underlying quantization method, which is selected based on `quant_type`.
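
For example, a sketch combining these parameters (`group_size` is an assumption based on torchao's `int4_weight_only` signature and may vary across torchao versions):

```python
from diffusers import TorchAoConfig

quantization_config = TorchAoConfig(
    "int4wo",
    # Leave the first single-transformer block unquantized.
    modules_to_not_convert=["single_transformer_blocks.0"],
    # Forwarded as a kwarg to the underlying torchao method (assumed).
    group_size=64,
)
```
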
Comment on lines +63 to +66 (Member):

Probably not necessary to have this since it's already in the API docs.

## Supported quantization types

torchao supports both weight-only quantization and combined weight and dynamic-activation quantization, for int8, float3 through float8, and uint1 through uint7.

Weight-only quantization stores the model weights in a specific low-bit data type but performs computation with a higher-precision data type, like `bfloat16`. This lowers the memory requirements from model weights but retains the memory peaks for activation computation.

Dynamic activation quantization stores the model weights in a low-bit dtype while also quantizing the activations on-the-fly to save additional memory. This lowers the memory requirements from model weights and the memory overhead from activation computations. However, it can sometimes reduce output quality, so it is recommended to test models thoroughly.
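
As a quick illustration of the two modes, using shorthands from the table below:

```python
from diffusers import TorchAoConfig

# Weight-only: int8 weights, computation in a higher-precision dtype.
weight_only_config = TorchAoConfig("int8wo")

# Dynamic-activation: int8 weights, with activations quantized on the fly.
dynamic_activation_config = TorchAoConfig("int8dq")
```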

The quantization methods supported are as follows:

| **Category** | **Full Function Names** | **Shorthands** |
|--------------|-------------------------|----------------|
| **Integer quantization** | `int4_weight_only`, `int8_dynamic_activation_int4_weight`, `int8_weight_only`, `int8_dynamic_activation_int8_weight` | `int4wo`, `int4dq`, `int8wo`, `int8dq` |
| **Floating point 8-bit quantization** | `float8_weight_only`, `float8_dynamic_activation_float8_weight`, `float8_static_activation_float8_weight` | `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`, `float8_e4m3_tensor`, `float8_e4m3_row` |
| **Floating point X-bit quantization** | `fpx_weight_only` | `fpX_eAwB` where `X` is the number of bits (1-7), `A` is exponent bits, and `B` is mantissa bits. Constraint: `X == A + B + 1` |
| **Unsigned Integer quantization** | `uintx_weight_only` | `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo` |

Some quantization methods are aliases (for example, `int8wo` is the commonly used shorthand for `int8_weight_only`). This allows using the quantization methods described in the torchao docs as-is, while also making it convenient to remember their shorthand notations.

Refer to the official torchao documentation for a deeper understanding of the available quantization methods and the exhaustive list of configuration options.

## Resources

- [TorchAO Quantization API](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md)
- [Diffusers-TorchAO examples](https://github.com/sayakpaul/diffusers-torchao)
4 changes: 2 additions & 2 deletions src/diffusers/__init__.py
@@ -31,7 +31,7 @@
"loaders": ["FromOriginalModelMixin"],
"models": [],
"pipelines": [],
"quantizers.quantization_config": ["BitsAndBytesConfig"],
"quantizers.quantization_config": ["BitsAndBytesConfig", "TorchAoConfig"],
"schedulers": [],
"utils": [
"OptionalDependencyNotAvailable",
@@ -562,7 +562,7 @@

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .configuration_utils import ConfigMixin
-    from .quantizers.quantization_config import BitsAndBytesConfig
+    from .quantizers.quantization_config import BitsAndBytesConfig, TorchAoConfig

try:
if not is_onnx_available():
6 changes: 2 additions & 4 deletions src/diffusers/models/model_loading_utils.py
@@ -25,7 +25,6 @@
import torch
from huggingface_hub.utils import EntryNotFoundError

-from ..quantizers.quantization_config import QuantizationMethod
from ..utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFETENSORS_FILE_EXTENSION,
@@ -182,7 +181,6 @@ def load_model_dict_into_meta(
device = device or torch.device("cpu")
dtype = dtype or torch.float32
is_quantized = hf_quantizer is not None
-    is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES

accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
empty_state_dict = model.state_dict()
@@ -215,12 +213,12 @@
# bnb params are flattened.
if empty_state_dict[param_name].shape != param.shape:
if (
-                is_quant_method_bnb
+                is_quantized
and hf_quantizer.pre_quantized
and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
):
hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name].shape, param.shape)
-            elif not is_quant_method_bnb:
+            else:
model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
raise ValueError(
f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
11 changes: 5 additions & 6 deletions src/diffusers/models/modeling_utils.py
@@ -700,10 +700,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
hf_quantizer = None

if hf_quantizer is not None:
-            if device_map is not None:
+            is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes"
+            if is_bnb_quantization_method and device_map is not None:
raise NotImplementedError(
"Currently, `device_map` is automatically inferred for quantized models. Support for providing `device_map` as an input will be added in the future."
"Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future."
)

hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)

@@ -858,13 +860,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
if device_map is None and not is_sharded:
# `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None.
# It would error out during the `validate_environment()` call above in the absence of cuda.
-                is_quant_method_bnb = (
-                    getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
-                )
if hf_quantizer is None:
param_device = "cpu"
# TODO (sayakpaul, SunMarc): remove this after model loading refactor
-                elif is_quant_method_bnb:
+                else:
param_device = torch.device(torch.cuda.current_device())
state_dict = load_state_dict(model_file, variant=variant)
model._convert_deprecated_attention_blocks(state_dict)
5 changes: 4 additions & 1 deletion src/diffusers/quantizers/auto.py
@@ -19,17 +19,20 @@
from typing import Dict, Optional, Union

from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
-from .quantization_config import BitsAndBytesConfig, QuantizationConfigMixin, QuantizationMethod
+from .quantization_config import BitsAndBytesConfig, QuantizationConfigMixin, QuantizationMethod, TorchAoConfig
+from .torchao import TorchAoHfQuantizer


AUTO_QUANTIZER_MAPPING = {
"bitsandbytes_4bit": BnB4BitDiffusersQuantizer,
"bitsandbytes_8bit": BnB8BitDiffusersQuantizer,
"torchao": TorchAoHfQuantizer,
}

AUTO_QUANTIZATION_CONFIG_MAPPING = {
"bitsandbytes_4bit": BitsAndBytesConfig,
"bitsandbytes_8bit": BitsAndBytesConfig,
"torchao": TorchAoConfig,
}

