-
Notifications
You must be signed in to change notification settings - Fork 251
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Flux] Port Flux Core Model #1864
Merged
divyashreepathihalli
merged 74 commits into
keras-team:master
from
DavidLandup0:feature/flux
Nov 13, 2024
Merged
Changes from all commits
Commits
Show all changes
74 commits
Select commit
Hold shift + click to select a range
286f4b2
starter commit - ported time embeddings to keras ops
DavidLandup0 244f013
add mlpembedder
DavidLandup0 480ad24
add RMS Norm re-implementation
DavidLandup0 2782242
add qknorm reimplementation
DavidLandup0 48c82e6
add rope, scaled dot product attention and self attention
DavidLandup0 513e370
modulation layer
DavidLandup0 8ccbb26
fix typing
DavidLandup0 c88c949
add double stream block
DavidLandup0 2bc150e
adjustments to doublestreamblock
DavidLandup0 969d508
add signle stream layer@
DavidLandup0 77c9297
update layers and add flux core model
DavidLandup0 35769ab
functions to layers
DavidLandup0 13d46c4
refactor layer usage
DavidLandup0 c00c6a5
refactor layer usage
DavidLandup0 05a1e3f
position math args in call()
DavidLandup0 f076006
name arguments
DavidLandup0 f9fc4a4
fix arg name
DavidLandup0 f2f2c96
start adding conversion script utils
DavidLandup0 311d342
change reshape into rearrange
DavidLandup0 db14c01
add rest of weight conversion and remove redundant shape extraction
DavidLandup0 c5b37c6
fix mlpembedder arg
DavidLandup0 8d3a385
remove redundant args
DavidLandup0 fa5379e
fix params. to self.
DavidLandup0 34e2477
add license
DavidLandup0 cdd397a
add einops
DavidLandup0 8169aa4
fix default arg
DavidLandup0 b1caa7f
expand docstrings
DavidLandup0 76eae83
tanh to gelu
DavidLandup0 c0236ac
refactor weight conversion into tools
DavidLandup0 b418659
update weight conversion
DavidLandup0 99839af
add stand-in presets until weights are uploaded
DavidLandup0 ac5c4b1
set float32 to t.dtype in timestep embedding
DavidLandup0 89dc08c
update more float32s into dynamic types
DavidLandup0 d3de26b
dtype
DavidLandup0 9d4aa22
dtype
DavidLandup0 dbddde7
enable float16 mode
DavidLandup0 b3c75a9
update conversion script to not require flux repo
DavidLandup0 4333bab
add build() methods to avoid running dummy input through model
DavidLandup0 199ba1c
update build call
DavidLandup0 a8de665
fix build calls
DavidLandup0 efe993a
style
DavidLandup0 ff118bb
change dummy call into build() call
DavidLandup0 da78707
Merge branch 'master' into feature/flux
DavidLandup0 a3ccf6d
reference einops issue
DavidLandup0 f88e1e9
address docstring comments in flux layers
DavidLandup0 6e2c320
address docstring comments in flux maths
DavidLandup0 b407ffc
remove numpy
DavidLandup0 ac43081
add docstrings for flux model
DavidLandup0 4b585a0
qkv bias -> use_bias
DavidLandup0 a2facb2
docstring updates
DavidLandup0 bd2ebe2
remove type hints
DavidLandup0 f48bbd2
all img->image, txt->text
DavidLandup0 cbad326
functional subclassing model
DavidLandup0 eeb8e0d
shape fixes
DavidLandup0 330ed70
format
DavidLandup0 9233411
self.hidden_size -> self.dim
DavidLandup0 ed2badc
einops rearrange
DavidLandup0 a65424b
remove build method
DavidLandup0 cb11e28
ops to rearrange
DavidLandup0 f478f39
remove build
DavidLandup0 3b5cb4d
rearrange -> symbolic_rearrange
DavidLandup0 40178e1
turn timesteps and guidance into inputs
DavidLandup0 078459d
basic preprocessor flow
DavidLandup0 0003b08
refactor layer names in conversion script
DavidLandup0 71b564f
add backbone tests
DavidLandup0 7aa93a2
raise not implemented on encode, encode_text, etc. methods
DavidLandup0 b05c94b
styling
DavidLandup0 94f9ffb
fix shape hack with a cleaner alternative
DavidLandup0 adeb842
remove unused attributes, fix tests
DavidLandup0 e97909c
change list into tuple for the expected shape
DavidLandup0 bc1879a
Merge branch 'master' into feature/flux
DavidLandup0 c6e20f6
address comments
DavidLandup0 dda8ec3
save mdel on conversion
DavidLandup0 446ed90
Merge branch 'master' into feature/flux
DavidLandup0 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import keras | ||
from keras import ops | ||
|
||
from keras_hub.src.api_export import keras_hub_export | ||
|
||
|
||
@keras_hub_export("keras_hub.layers.RMSNormalization") | ||
class RMSNormalization(keras.layers.Layer): | ||
""" | ||
Root Mean Square (RMS) Normalization layer. | ||
This layer normalizes the input tensor based on its RMS value and applies | ||
a learned scaling factor. | ||
Args: | ||
input_dim: int. The dimensionality of the input tensor. | ||
""" | ||
|
||
def __init__(self, input_dim): | ||
super().__init__() | ||
self.scale = self.add_weight( | ||
name="scale", shape=(input_dim,), initializer="ones" | ||
) | ||
|
||
def call(self, x): | ||
""" | ||
Applies RMS normalization to the input tensor. | ||
Args: | ||
x: KerasTensor. Input tensor of shape (batch_size, input_dim). | ||
Returns: | ||
KerasTensor: The RMS-normalized tensor of the same shape (batch_size, input_dim), | ||
scaled by the learned `scale` parameter. | ||
""" | ||
x = ops.cast(x, float) | ||
rrms = ops.rsqrt(ops.mean(ops.square(x), axis=-1, keepdims=True) + 1e-6) | ||
return (x * rrms) * self.scale |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from keras_hub.src.models.flux.flux_model import FluxBackbone | ||
from keras_hub.src.models.flux.flux_presets import presets | ||
from keras_hub.src.utils.preset_utils import register_presets | ||
|
||
register_presets(presets, FluxBackbone) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import pytest | ||
from keras import ops | ||
|
||
from keras_hub.src.models.clip.clip_text_encoder import CLIPTextEncoder | ||
from keras_hub.src.models.flux.flux_model import FluxBackbone | ||
from keras_hub.src.models.vae.vae_backbone import VAEBackbone | ||
from keras_hub.src.tests.test_case import TestCase | ||
|
||
|
||
class FluxBackboneTest(TestCase): | ||
def setUp(self): | ||
vae = VAEBackbone( | ||
[32, 32, 32, 32], | ||
[1, 1, 1, 1], | ||
[32, 32, 32, 32], | ||
[1, 1, 1, 1], | ||
# Use `mode` generate a deterministic output. | ||
sampler_method="mode", | ||
name="vae", | ||
) | ||
clip_l = CLIPTextEncoder( | ||
20, 32, 32, 2, 2, 64, "quick_gelu", -2, name="clip_l" | ||
) | ||
self.init_kwargs = { | ||
"input_channels": 256, | ||
"hidden_size": 1024, | ||
"mlp_ratio": 2.0, | ||
"num_heads": 8, | ||
"depth": 4, | ||
"depth_single_blocks": 8, | ||
"axes_dim": [16, 56, 56], | ||
"theta": 10_000, | ||
"use_bias": True, | ||
"guidance_embed": True, | ||
"image_shape": (32, 256), | ||
"text_shape": (32, 256), | ||
"image_ids_shape": (32, 3), | ||
"text_ids_shape": (32, 3), | ||
"y_shape": (256,), | ||
} | ||
|
||
self.pipeline_models = { | ||
"vae": vae, | ||
"clip_l": clip_l, | ||
} | ||
|
||
self.input_data = { | ||
"image": ops.ones((1, 32, 256)), | ||
"image_ids": ops.ones((1, 32, 3)), | ||
"text": ops.ones((1, 32, 256)), | ||
"text_ids": ops.ones((1, 32, 3)), | ||
"y": ops.ones((1, 256)), | ||
"timesteps": ops.ones((1)), | ||
"guidance": ops.ones((1)), | ||
} | ||
|
||
def test_backbone_basics(self): | ||
self.run_backbone_test( | ||
cls=FluxBackbone, | ||
init_kwargs=self.init_kwargs, | ||
input_data=self.input_data, | ||
expected_output_shape=(1, 32, 256), | ||
run_mixed_precision_check=False, | ||
run_quantization_check=False, | ||
) | ||
|
||
@pytest.mark.large | ||
def test_saved_model(self): | ||
self.run_model_saving_test( | ||
cls=FluxBackbone, | ||
init_kwargs=self.init_kwargs, | ||
input_data=self.input_data, | ||
) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will be part of the generation pipeline so these are added preemptively and unused for now