Skip to content

Commit

Permalink
add flux
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Oct 22, 2024
1 parent be4624d commit f1f9e92
Show file tree
Hide file tree
Showing 13 changed files with 398 additions and 66 deletions.
8 changes: 8 additions & 0 deletions optimum/commands/export/openvino.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,14 @@ def run(self):
from optimum.intel import OVStableDiffusionPipeline

model_cls = OVStableDiffusionPipeline
elif class_name == "StableDiffusion3Pipeline":
from optimum.intel import OVStableDiffusion3Pipeline

model_cls = OVStableDiffusion3Pipeline
elif class_name == "FluxPipeline":
from optimum.intel import OVFluxPipeline

model_cls = OVFluxPipeline
else:
raise NotImplementedError(f"Quantization in hybrid mode isn't supported for class {class_name}.")

Expand Down
97 changes: 95 additions & 2 deletions optimum/exporters/openvino/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,9 +917,19 @@ def get_diffusion_models_for_export_ext(
except ImportError:
is_sd3 = False

if not is_sd3:
try:
from diffusers import FluxPipeline

is_flux = isinstance(pipeline, FluxPipeline)
except ImportError:
is_flux = False

if not is_sd3 and not is_flux:
return None, get_diffusion_models_for_export(pipeline, int_dtype, float_dtype, exporter)
models_for_export = get_sd3_models_for_export(pipeline, exporter, int_dtype, float_dtype)
if is_sd3:
models_for_export = get_sd3_models_for_export(pipeline, exporter, int_dtype, float_dtype)
else:
models_for_export = get_flux_models_for_export(pipeline, exporter, int_dtype, float_dtype)

return None, models_for_export

Expand Down Expand Up @@ -1021,3 +1031,86 @@ def get_sd3_models_for_export(pipeline, exporter, int_dtype, float_dtype):
models_for_export["text_encoder_3"] = (text_encoder_3, export_config)

return models_for_export


def get_flux_models_for_export(pipeline, exporter, int_dtype, float_dtype):
models_for_export = {}

# Text encoder
text_encoder = getattr(pipeline, "text_encoder", None)
if text_encoder is not None:
text_encoder_config_constructor = TasksManager.get_exporter_config_constructor(
model=text_encoder,
exporter=exporter,
library_name="diffusers",
task="feature-extraction",
model_type="clip-text-model",
)
text_encoder_export_config = text_encoder_config_constructor(
pipeline.text_encoder.config, int_dtype=int_dtype, float_dtype=float_dtype
)
models_for_export["text_encoder"] = (text_encoder, text_encoder_export_config)

transformer = pipeline.transformer
transformer.config.text_encoder_projection_dim = transformer.config.joint_attention_dim
transformer.config.requires_aesthetics_score = getattr(pipeline.config, "requires_aesthetics_score", False)
transformer.config.time_cond_proj_dim = None
export_config_constructor = TasksManager.get_exporter_config_constructor(
model=transformer,
exporter=exporter,
library_name="diffusers",
task="semantic-segmentation",
model_type="flux-transformer",
)
transformer_export_config = export_config_constructor(
pipeline.transformer.config, int_dtype=int_dtype, float_dtype=float_dtype
)
models_for_export["transformer"] = (transformer, transformer_export_config)

# VAE Encoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L565
vae_encoder = copy.deepcopy(pipeline.vae)
vae_encoder.forward = lambda sample: {"latent_parameters": vae_encoder.encode(x=sample)["latent_dist"].parameters}
vae_config_constructor = TasksManager.get_exporter_config_constructor(
model=vae_encoder,
exporter=exporter,
library_name="diffusers",
task="semantic-segmentation",
model_type="vae-encoder",
)
vae_encoder_export_config = vae_config_constructor(
vae_encoder.config, int_dtype=int_dtype, float_dtype=float_dtype
)
models_for_export["vae_encoder"] = (vae_encoder, vae_encoder_export_config)

# VAE Decoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L600
vae_decoder = copy.deepcopy(pipeline.vae)
vae_decoder.forward = lambda latent_sample: vae_decoder.decode(z=latent_sample)
vae_config_constructor = TasksManager.get_exporter_config_constructor(
model=vae_decoder,
exporter=exporter,
library_name="diffusers",
task="semantic-segmentation",
model_type="vae-decoder",
)
vae_decoder_export_config = vae_config_constructor(
vae_decoder.config, int_dtype=int_dtype, float_dtype=float_dtype
)
models_for_export["vae_decoder"] = (vae_decoder, vae_decoder_export_config)

text_encoder_2 = getattr(pipeline, "text_encoder_2", None)
if text_encoder_2 is not None:
export_config_constructor = TasksManager.get_exporter_config_constructor(
model=text_encoder_2,
exporter=exporter,
library_name="diffusers",
task="feature-extraction",
model_type="t5-encoder-model",
)
export_config = export_config_constructor(
text_encoder_2.config,
int_dtype=int_dtype,
float_dtype=float_dtype,
)
models_for_export["text_encoder_2"] = (text_encoder_2, export_config)

return models_for_export
125 changes: 120 additions & 5 deletions optimum/exporters/openvino/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,10 @@
from optimum.exporters.tasks import TasksManager
from optimum.utils import DEFAULT_DUMMY_SHAPES
from optimum.utils.input_generators import (
DTYPE_MAPPER,
DummyInputGenerator,
DummyPastKeyValuesGenerator,
DummySeq2SeqDecoderTextInputGenerator,
DummyTextInputGenerator,
DummyTimestepInputGenerator,
DummyVisionInputGenerator,
Expand All @@ -63,6 +65,7 @@
DBRXModelPatcher,
DeciLMModelPatcher,
FalconModelPatcher,
FluxTransfromerModelPatcher,
Gemma2ModelPatcher,
GptNeoxJapaneseModelPatcher,
GptNeoxModelPatcher,
Expand Down Expand Up @@ -96,9 +99,9 @@ def init_model_configs():
"transformers",
"LlavaNextForConditionalGeneration",
)
TasksManager._TRANSFORMERS_TASKS_TO_MODEL_LOADERS[
"image-text-to-text"
] = TasksManager._TRANSFORMERS_TASKS_TO_MODEL_LOADERS["text-generation"]
TasksManager._TRANSFORMERS_TASKS_TO_MODEL_LOADERS["image-text-to-text"] = (
TasksManager._TRANSFORMERS_TASKS_TO_MODEL_LOADERS["text-generation"]
)

supported_model_types = [
"_SUPPORTED_MODEL_TYPE",
Expand Down Expand Up @@ -1576,7 +1579,7 @@ def patch_model_for_export(


class PooledProjectionsDummyInputGenerator(DummyInputGenerator):
SUPPORTED_INPUT_NAMES = "pooled_projections"
SUPPORTED_INPUT_NAMES = ["pooled_projections"]

def __init__(
self,
Expand All @@ -1600,8 +1603,10 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int


class DummyTransformerTimestpsInputGenerator(DummyTimestepInputGenerator):
SUPPORTED_INPUT_NAMES = ("timestep", "text_embeds", "time_ids", "timestep_cond", "guidance")

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
if input_name == "timestep":
if input_name in ["timestep", "guidance"]:
shape = [self.batch_size]
return self.random_float_tensor(shape, max_value=self.vocab_size, framework=framework, dtype=float_dtype)
return super().generate(input_name, framework, int_dtype, float_dtype)
Expand Down Expand Up @@ -1642,3 +1647,113 @@ def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> ModelPatcher:
return ModelPatcher(self, model, model_kwargs=model_kwargs)


class DummyFluxTransformerInputGenerator(DummyVisionInputGenerator):
SUPPORTED_INPUT_NAMES = (
"pixel_values",
"pixel_mask",
"sample",
"latent_sample",
"hidden_states",
"img_ids",
)

def __init__(
self,
task: str,
normalized_config: NormalizedVisionConfig,
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
width: int = DEFAULT_DUMMY_SHAPES["width"],
height: int = DEFAULT_DUMMY_SHAPES["height"],
**kwargs,
):

super().__init__(task, normalized_config, batch_size, num_channels, width, height, **kwargs)
if getattr(normalized_config, "in_channels", None):
self.num_channels = normalized_config.in_channels // 4

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
if input_name in ["hidden_states", "sample"]:
shape = [self.batch_size, (self.height // 2) * (self.width // 2), self.num_channels * 4]
return self.random_float_tensor(shape, framework=framework, dtype=float_dtype)
if input_name == "img_ids":
return self.prepare_image_ids(framework, int_dtype, float_dtype)

return super().generate(input_name, framework, int_dtype, float_dtype)

def prepare_image_ids(self, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
img_ids_height = self.height // 2
img_ids_width = self.width // 2
if framework == "pt":
import torch

latent_image_ids = torch.zeros(img_ids_height, img_ids_width, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(img_ids_height)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(img_ids_width)[None, :]

latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

latent_image_ids = latent_image_ids[None, :].repeat(self.batch_size, 1, 1, 1)
latent_image_ids = latent_image_ids.reshape(
self.batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
latent_image_ids.to(DTYPE_MAPPER.pt(float_dtype))
return latent_image_ids
if framework == "np":
import numpy as np

latent_image_ids = np.zeros(img_ids_height, img_ids_width, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + np.arange(img_ids_height)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + np.arange(img_ids_width)[None, :]

latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

latent_image_ids = np.tile(latent_image_ids[None, :], (self.batch_size, 1, 1, 1))
latent_image_ids = latent_image_ids.reshape(
self.batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
latent_image_ids.astype(DTYPE_MAPPER.np[float_dtype])
return latent_image_ids


class DummyFluxTextInputGenerator(DummySeq2SeqDecoderTextInputGenerator):
SUPPORTED_INPUT_NAMES = (
"decoder_input_ids",
"decoder_attention_mask",
"encoder_outputs",
"encoder_hidden_states",
"txt_ids",
)

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
if input_name == "txt_ids":
return self.constant_tensor([self.batch_size, self.sequence_length, 3], 0, DTYPE_MAPPER.pt(float_dtype))
return super().generate(input_name, framework, int_dtype, float_dtype)


@register_in_tasks_manager("flux-transformer", *["semantic-segmentation"], library_name="diffusers")
class FluxTransformerOpenVINOConfig(SD3TransformerOpenVINOConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (
DummyTransformerTimestpsInputGenerator,
DummyFluxTransformerInputGenerator,
DummyFluxTextInputGenerator,
PooledProjectionsDummyInputGenerator,
)

@property
def inputs(self):
common_inputs = super().inputs
common_inputs.pop("sample", None)
common_inputs["hidden_states"] = {0: "batch_size", 1: "packed_height_width"}
common_inputs["txt_ids"] = {0: "batch_size", 1: "sequence_length"}
common_inputs["img_ids"] = {0: "batch_size", 1: "packed_height_width"}
if getattr(self._normalized_config, "guidance_embeds", False):
common_inputs["guidance"] = {0: "batch_size"}
return common_inputs

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> ModelPatcher:
return FluxTransfromerModelPatcher(self, model, model_kwargs=model_kwargs)
48 changes: 42 additions & 6 deletions optimum/exporters/openvino/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,9 +411,9 @@ def _llama_gemma_update_causal_mask_legacy(self, attention_mask, input_tensor, c
offset = 0
mask_shape = attention_mask.shape
mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
causal_mask[
: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
] = mask_slice
causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
mask_slice
)

if (
self.config._attn_implementation == "sdpa"
Expand Down Expand Up @@ -1979,9 +1979,9 @@ def _dbrx_update_causal_mask_legacy(
offset = 0
mask_shape = attention_mask.shape
mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
causal_mask[
: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
] = mask_slice
causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
mask_slice
)

if (
self.config._attn_implementation == "sdpa"
Expand Down Expand Up @@ -2705,3 +2705,39 @@ def __init__(
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward


def _embednb_forward(self, ids: torch.Tensor) -> torch.Tensor:
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
assert dim % 2 == 0, "The dimension must be even."

scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
omega = 1.0 / (theta**scale)

batch_size, seq_length = pos.shape
out = pos.unsqueeze(-1) * omega.unsqueeze(0).unsqueeze(0)
cos_out = torch.cos(out)
sin_out = torch.sin(out)

stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
return out.float()

n_axes = ids.shape[-1]
emb = torch.cat(
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
dim=-3,
)
return emb.unsqueeze(1)


class FluxTransfromerModelPatcher(ModelPatcher):
def __enter__(self):
super().__enter__()
self._model.pos_embed._orig_forward = self._model.pos_embed.forward
self._model.pos_embed.forward = types.MethodType(_embednb_forward, self._model.pos_embed)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)

self._model.pos_embed.forward = self._model.pos_embed._orig_forward
2 changes: 2 additions & 0 deletions optimum/intel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
"OVStableDiffusion3InpaintPipeline",
"OVLatentConsistencyModelPipeline",
"OVLatentConsistencyModelImg2ImgPipeline",
"OVFluxPipeline",
"OVPipelineForImage2Image",
"OVPipelineForText2Image",
"OVPipelineForInpainting",
Expand All @@ -124,6 +125,7 @@
"OVStableDiffusion3InpaintPipeline",
"OVLatentConsistencyModelPipeline",
"OVLatentConsistencyModelImg2ImgPipeline",
"OVFluxPipeline",
"OVPipelineForImage2Image",
"OVPipelineForText2Image",
"OVPipelineForInpainting",
Expand Down
1 change: 1 addition & 0 deletions optimum/intel/openvino/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
if is_diffusers_available():
from .modeling_diffusion import (
OVDiffusionPipeline,
OVFluxPipeline,
OVLatentConsistencyModelImg2ImgPipeline,
OVLatentConsistencyModelPipeline,
OVPipelineForImage2Image,
Expand Down
Loading

0 comments on commit f1f9e92

Please sign in to comment.