Skip to content

Commit

Permalink
DDUF parser v0.1 (#2692)
Browse files Browse the repository at this point in the history
* First draft for a DDUF parser

* write before read

* comments and lint

* forbid nested directoroes

* gguf typo

* export_from_entries

* some docs

* Update docs/source/en/package_reference/serialization.md

Co-authored-by: Marc Sun <[email protected]>

* Update src/huggingface_hub/serialization/_dduf.py

Co-authored-by: Marc Sun <[email protected]>

* compute data offset without private arg

* type annotations

* enforce 1 level of directory only

* raise correct error DDUFInvalidEntryNameError

* add tests

* note

* test uncompress

* required model_index.json

* Apply suggestions from code review

Co-authored-by: Célina <[email protected]>

* use f-string in logs

* Update docs/source/en/package_reference/serialization.md

Co-authored-by: Pedro Cuenca <[email protected]>

* remove add_entry_to_dduf

* new rules: folders in model_index.json + config files in folders

* add arg

* add arg

* Update docs/source/en/package_reference/serialization.md

Co-authored-by: Sayak Paul <[email protected]>

* Update docs/source/en/package_reference/serialization.md

Co-authored-by: Célina <[email protected]>

* add scheduler config

* scheduler_config

* style

* preprocessor_config.json

---------

Co-authored-by: Marc Sun <[email protected]>
Co-authored-by: Célina <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Marc Sun <[email protected]>
  • Loading branch information
6 people authored Dec 13, 2024
1 parent b75f8d9 commit 4b0b179
Show file tree
Hide file tree
Showing 5 changed files with 769 additions and 3 deletions.
116 changes: 113 additions & 3 deletions docs/source/en/package_reference/serialization.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,117 @@ rendered properly in your Markdown viewer.

`huggingface_hub` provides helpers to save and load ML model weights in a standardized way. This part of the library is still under development and will be improved in future releases. The goal is to harmonize how weights are saved and loaded across the Hub, both to remove code duplication across libraries and to establish consistent conventions.

## Saving
## DDUF file format

DDUF is a file format designed for diffusion models. It allows saving all the information to run a model in a single file. This work is inspired by the [GGUF](https://github.com/ggerganov/ggml/blob/master/docs/gguf.md) format. `huggingface_hub` provides helpers to save and load DDUF files, ensuring the file format is respected.

<Tip warning={true}>

This is a very early version of the parser. The API and implementation can evolve in the near future.

The parser currently does very little validation. For more details about the file format, check out https://github.com/huggingface/huggingface.js/tree/main/packages/dduf.

</Tip>

### How to write a DDUF file?

Here is how to export a folder containing different parts of a diffusion model using [`export_folder_as_dduf`]:

```python
# Export a folder as a DDUF file
>>> from huggingface_hub import export_folder_as_dduf
>>> export_folder_as_dduf("FLUX.1-dev.dduf", folder_path="path/to/FLUX.1-dev")
```

For more flexibility, you can use [`export_entries_as_dduf`] and pass a list of files to include in the final DDUF file:

```python
# Export specific files from the local disk.
>>> from huggingface_hub import export_entries_as_dduf
>>> export_entries_as_dduf(
... dduf_path="stable-diffusion-v1-4-FP16.dduf",
... entries=[ # List entries to add to the DDUF file (here, only FP16 weights)
... ("model_index.json", "path/to/model_index.json"),
... ("vae/config.json", "path/to/vae/config.json"),
... ("vae/diffusion_pytorch_model.fp16.safetensors", "path/to/vae/diffusion_pytorch_model.fp16.safetensors"),
... ("text_encoder/config.json", "path/to/text_encoder/config.json"),
... ("text_encoder/model.fp16.safetensors", "path/to/text_encoder/model.fp16.safetensors"),
... # ... add more entries here
... ]
... )
```

The `entries` parameter also supports passing an iterable of paths or bytes. This can prove useful if you have a loaded model and want to serialize it directly into a DDUF file instead of having to serialize each component to disk first and then as a DDUF file. Here is an example of how a `StableDiffusionPipeline` can be serialized as DDUF:


```python
# Export state_dicts one by one from a loaded pipeline
>>> from diffusers import DiffusionPipeline
>>> from typing import Generator, Tuple
>>> import safetensors.torch
>>> from huggingface_hub import export_entries_as_dduf
>>> pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
... # ... do some work with the pipeline

>>> def as_entries(pipe: DiffusionPipeline) -> Generator[Tuple[str, bytes], None, None]:
... # Build a generator that yields the entries to add to the DDUF file.
... # The first element of the tuple is the filename in the DDUF archive (must use UNIX separator!). The second element is the content of the file.
... # Entries will be evaluated lazily when the DDUF file is created (only 1 entry is loaded in memory at a time)
... yield "vae/config.json", pipe.vae.to_json_string().encode()
... yield "vae/diffusion_pytorch_model.safetensors", safetensors.torch.save(pipe.vae.state_dict())
... yield "text_encoder/config.json", pipe.text_encoder.config.to_json_string().encode()
... yield "text_encoder/model.safetensors", safetensors.torch.save(pipe.text_encoder.state_dict())
... # ... add more entries here

>>> export_entries_as_dduf(dduf_path="stable-diffusion-v1-4.dduf", entries=as_entries(pipe))
```

**Note:** in practice, `diffusers` provides a method to directly serialize a pipeline in a DDUF file. The snippet above is only meant as an example.

### How to read a DDUF file?

```python
>>> import json
>>> import safetensors.torch
>>> from huggingface_hub import read_dduf_file

# Read DDUF metadata
>>> dduf_entries = read_dduf_file("FLUX.1-dev.dduf")

# Returns a mapping filename <> DDUFEntry
>>> dduf_entries["model_index.json"]
DDUFEntry(filename='model_index.json', offset=66, length=587)

# Load model index as JSON
>>> json.loads(dduf_entries["model_index.json"].read_text())
{'_class_name': 'FluxPipeline', '_diffusers_version': '0.32.0.dev0', '_name_or_path': 'black-forest-labs/FLUX.1-dev', 'scheduler': ['diffusers', 'FlowMatchEulerDiscreteScheduler'], 'text_encoder': ['transformers', 'CLIPTextModel'], 'text_encoder_2': ['transformers', 'T5EncoderModel'], 'tokenizer': ['transformers', 'CLIPTokenizer'], 'tokenizer_2': ['transformers', 'T5TokenizerFast'], 'transformer': ['diffusers', 'FluxTransformer2DModel'], 'vae': ['diffusers', 'AutoencoderKL']}

# Load VAE weights using safetensors
>>> with dduf_entries["vae/diffusion_pytorch_model.safetensors"].as_mmap() as mm:
... state_dict = safetensors.torch.load(mm)
```

### Helpers

[[autodoc]] huggingface_hub.export_entries_as_dduf

[[autodoc]] huggingface_hub.export_folder_as_dduf

[[autodoc]] huggingface_hub.read_dduf_file

[[autodoc]] huggingface_hub.DDUFEntry

### Errors

[[autodoc]] huggingface_hub.errors.DDUFError

[[autodoc]] huggingface_hub.errors.DDUFCorruptedFileError

[[autodoc]] huggingface_hub.errors.DDUFExportError

[[autodoc]] huggingface_hub.errors.DDUFInvalidEntryNameError

## Saving tensors

The main helper of the `serialization` module takes a torch `nn.Module` as input and saves it to disk. It handles the logic to save shared tensors (see [safetensors explanation](https://huggingface.co/docs/safetensors/torch_shared_tensors)) as well as logic to split the state dictionary into shards, using [`split_torch_state_dict_into_shards`] under the hood. At the moment, only `torch` framework is supported.

Expand Down Expand Up @@ -37,7 +147,7 @@ This is the underlying factory from which each framework-specific helper is deri

[[autodoc]] huggingface_hub.split_state_dict_into_shards_factory

## Loading
## Loading tensors

The loading helpers support both single-file and sharded checkpoints in either safetensors or pickle format. [`load_torch_model`] takes a `nn.Module` and a checkpoint path (either a single file or a directory) as input and load the weights into the model.

Expand All @@ -50,7 +160,7 @@ The loading helpers support both single-file and sharded checkpoints in either s
[[autodoc]] huggingface_hub.load_state_dict_from_file


## Helpers
## Tensors helpers

### get_torch_storage_id

Expand Down
12 changes: 12 additions & 0 deletions src/huggingface_hub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,12 @@
"split_tf_state_dict_into_shards",
"split_torch_state_dict_into_shards",
],
"serialization._dduf": [
"DDUFEntry",
"export_entries_as_dduf",
"export_folder_as_dduf",
"read_dduf_file",
],
"utils": [
"CacheNotFound",
"CachedFileInfo",
Expand Down Expand Up @@ -997,6 +1003,12 @@ def __dir__():
split_tf_state_dict_into_shards, # noqa: F401
split_torch_state_dict_into_shards, # noqa: F401
)
from .serialization._dduf import (
DDUFEntry, # noqa: F401
export_entries_as_dduf, # noqa: F401
export_folder_as_dduf, # noqa: F401
read_dduf_file, # noqa: F401
)
from .utils import (
CachedFileInfo, # noqa: F401
CachedRepoInfo, # noqa: F401
Expand Down
19 changes: 19 additions & 0 deletions src/huggingface_hub/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,3 +308,22 @@ class BadRequestError(HfHubHTTPError, ValueError):
huggingface_hub.utils._errors.BadRequestError: Bad request for check endpoint: {details} (Request ID: XXX)
```
"""


# DDUF file format ERROR


class DDUFError(Exception):
"""Base exception for errors related to the DDUF format."""


class DDUFCorruptedFileError(DDUFError):
"""Exception thrown when the DDUF file is corrupted."""


class DDUFExportError(DDUFError):
"""Base exception for errors during DDUF export."""


class DDUFInvalidEntryNameError(DDUFExportError):
"""Exception thrown when the entry name is invalid."""
Loading

0 comments on commit 4b0b179

Please sign in to comment.