Use image_shape for SD3 #1979

Merged
merged 1 commit into from
Oct 26, 2024
@@ -10,7 +10,7 @@
"\n",
"**Author:** [Hongyu Chiu](https://github.com/james77777778), [fchollet](https://twitter.com/fchollet), [lukewood](https://twitter.com/luke_wood_ml), [divamgupta](https://github.com/divamgupta)<br>\n",
"**Date created:** 2024/10/09<br>\n",
-"**Last modified:** 2024/10/09<br>\n",
+"**Last modified:** 2024/10/24<br>\n",
"**Description:** Image generation using KerasHub's Stable Diffusion 3 model."
]
},
@@ -96,7 +96,7 @@
"That will automatically load and configure trained `backbone` and `preprocessor`\n",
"for you.\n",
"\n",
-"Note that in this guide, we'll use `height=512` and `width=512` for faster\n",
+"Note that in this guide, we'll use `image_shape=(512, 512, 3)` for faster\n",
"image generation. For higher-quality output, it's recommended to use the default\n",
"size of `1024`. Since the entire backbone has about 3 billion parameters, which\n",
"can be challenging to fit into a consumer-level GPU, we set `dtype=\"float16\"` to\n",
@@ -148,7 +148,7 @@
"\n",
"\n",
"backbone = keras_hub.models.StableDiffusion3Backbone.from_preset(\n",
- \"stable_diffusion_3_medium\", height=512, width=512, dtype=\"float16\"\n",
+ \"stable_diffusion_3_medium\", image_shape=(512, 512, 3), dtype=\"float16\"\n",
")\n",
"preprocessor = keras_hub.models.StableDiffusion3TextToImagePreprocessor.from_preset(\n",
" \"stable_diffusion_3_medium\"\n",
6 changes: 3 additions & 3 deletions guides/keras_hub/stable_diffusion_3_in_keras_hub.py
@@ -2,7 +2,7 @@
Title: Stable Diffusion 3 in KerasHub!
Author: [Hongyu Chiu](https://github.com/james77777778), [fchollet](https://twitter.com/fchollet), [lukewood](https://twitter.com/luke_wood_ml), [divamgupta](https://github.com/divamgupta)
Date created: 2024/10/09
-Last modified: 2024/10/09
+Last modified: 2024/10/24
Description: Image generation using KerasHub's Stable Diffusion 3 model.
Accelerator: GPU
"""
@@ -63,7 +63,7 @@
That will automatically load and configure trained `backbone` and `preprocessor`
for you.

-Note that in this guide, we'll use `height=512` and `width=512` for faster
+Note that in this guide, we'll use `image_shape=(512, 512, 3)` for faster
image generation. For higher-quality output, it's recommended to use the default
size of `1024`. Since the entire backbone has about 3 billion parameters, which
can be challenging to fit into a consumer-level GPU, we set `dtype="float16"` to
@@ -107,7 +107,7 @@ def display_generated_images(images):


backbone = keras_hub.models.StableDiffusion3Backbone.from_preset(
-    "stable_diffusion_3_medium", height=512, width=512, dtype="float16"
+    "stable_diffusion_3_medium", image_shape=(512, 512, 3), dtype="float16"
Contributor:

Why does the channel axis need to be specified? Other APIs that take an image_shape arg only use height/width.

Contributor Author (@james77777778, Oct 25, 2024):

Is that true? I think most of the backbones in KerasHub expect the (h, w, c) format for image_shape.
https://github.com/search?q=repo%3Akeras-team%2Fkeras-hub%20image_shape&type=code

This change was requested by @divyashreepathihalli, and I agree that it is more consistent with the other backbone APIs.

Additionally, even though most users won’t do this, it is still valid to train a diffusion model on images that are not standard RGB.

EDITED:
We need to specify the channel axis to correctly instantiate the VAE image encoder.

Contributor:

Ok, sounds good
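
For callers that still pass the old `height`/`width` keyword arguments, the switch discussed in this thread could be sketched as below. This is only an illustration: `upgrade_backbone_kwargs` is a hypothetical helper, not part of KerasHub, and the default of 3 channels is an assumption that matches the RGB value used in this PR.

```python
def upgrade_backbone_kwargs(kwargs, channels=3):
    """Fold legacy `height`/`width` entries into a single `image_shape`.

    Hypothetical helper for illustration only; not a KerasHub API.
    `channels=3` assumes RGB input, as in `image_shape=(512, 512, 3)`.
    """
    upgraded = dict(kwargs)  # copy so the caller's dict is left untouched
    has_height = "height" in upgraded
    has_width = "width" in upgraded
    if has_height != has_width:
        raise ValueError("Pass both `height` and `width`, or neither.")
    if has_height:
        if "image_shape" in upgraded:
            raise ValueError("Pass either `image_shape` or `height`/`width`.")
        height = upgraded.pop("height")
        width = upgraded.pop("width")
        # Channels-last (h, w, c) layout, as described in the thread above.
        upgraded["image_shape"] = (height, width, channels)
    return upgraded


legacy = {"height": 512, "width": 512, "dtype": "float16"}
print(upgrade_backbone_kwargs(legacy))
```

The resulting dictionary could then be unpacked into `StableDiffusion3Backbone.from_preset(...)` in place of the old arguments. A non-RGB setup would pass a different `channels` value.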

)
preprocessor = keras_hub.models.StableDiffusion3TextToImagePreprocessor.from_preset(
    "stable_diffusion_3_medium"
6 changes: 3 additions & 3 deletions guides/md/keras_hub/stable_diffusion_3_in_keras_hub.md
@@ -2,7 +2,7 @@

**Author:** [Hongyu Chiu](https://github.com/james77777778), [fchollet](https://twitter.com/fchollet), [lukewood](https://twitter.com/luke_wood_ml), [divamgupta](https://github.com/divamgupta)<br>
**Date created:** 2024/10/09<br>
-**Last modified:** 2024/10/09<br>
+**Last modified:** 2024/10/24<br>
**Description:** Image generation using KerasHub's Stable Diffusion 3 model.


@@ -71,7 +71,7 @@ text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
That will automatically load and configure trained `backbone` and `preprocessor`
for you.

-Note that in this guide, we'll use `height=512` and `width=512` for faster
+Note that in this guide, we'll use `image_shape=(512, 512, 3)` for faster
image generation. For higher-quality output, it's recommended to use the default
size of `1024`. Since the entire backbone has about 3 billion parameters, which
can be challenging to fit into a consumer-level GPU, we set `dtype="float16"` to
@@ -116,7 +116,7 @@ def display_generated_images(images):


backbone = keras_hub.models.StableDiffusion3Backbone.from_preset(
-    "stable_diffusion_3_medium", height=512, width=512, dtype="float16"
+    "stable_diffusion_3_medium", image_shape=(512, 512, 3), dtype="float16"
)
preprocessor = keras_hub.models.StableDiffusion3TextToImagePreprocessor.from_preset(
    "stable_diffusion_3_medium"
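
The guide text touched by this diff sets `dtype="float16"` because the backbone has about 3 billion parameters. A rough back-of-the-envelope sketch of why that matters is below. It counts weight storage only (no activations, optimizer state, or framework overhead). The 3-billion figure is the guide's approximation, and `weight_memory_gib` is a hypothetical helper written for this note.

```python
# Bytes needed to store one parameter in each floating-point format.
BYTES_PER_PARAM = {"float32": 4, "float16": 2}


def weight_memory_gib(num_params, dtype):
    """Approximate weight storage in GiB; ignores activations and overhead."""
    return num_params * BYTES_PER_PARAM[dtype] / 1024**3


num_params = 3_000_000_000  # "about 3 billion", per the guide text
for dtype in ("float32", "float16"):
    print(f"{dtype}: {weight_memory_gib(num_params, dtype):.1f} GiB")
```

Under these assumptions, half precision halves the weight footprint, which is the motivation the guide gives for using it on consumer-level GPUs.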