Skip to content

Commit

Permalink
[Image Generation] Inpainting for FLUX
Browse files Browse the repository at this point in the history
  • Loading branch information
likholat committed Feb 6, 2025
1 parent 523ea7c commit f0f09ad
Show file tree
Hide file tree
Showing 8 changed files with 163 additions and 9 deletions.
2 changes: 1 addition & 1 deletion SUPPORTED_MODELS.md
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ The pipeline can work with other similar topologies produced by `optimum-intel`
<td><code>Flux</code></td>
<td>Supported</td>
<td>Supported</td>
<td>Not supported</td>
<td>Supported</td>
<td>Not supported</td>
<td>
<ul>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,14 @@ class OPENVINO_GENAI_EXPORTS InpaintingPipeline {
const UNet2DConditionModel& unet,
const AutoencoderKL& vae);

// creates Flux pipeline from building blocks
static InpaintingPipeline flux(
const std::shared_ptr<Scheduler>& scheduler,
const CLIPTextModel& clip_text_model,
const T5EncoderModel& t5_text_encoder,
const FluxTransformer2DModel& transformer,
const AutoencoderKL& vae);

ImageGenerationConfig get_generation_config() const;
void set_generation_config(const ImageGenerationConfig& generation_config);

Expand Down
2 changes: 1 addition & 1 deletion src/cpp/src/image_generation/diffusion_pipeline.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class DiffusionPipeline {
virtual void check_inputs(const ImageGenerationConfig& generation_config, ov::Tensor initial_image) const = 0;

void blend_latents(ov::Tensor image_latent, ov::Tensor noise, ov::Tensor mask, ov::Tensor latent, size_t inference_step) {
OPENVINO_ASSERT(m_pipeline_type == PipelineType::INPAINTING, "'prepare_mask_latents' can be called for inpainting pipeline only");
OPENVINO_ASSERT(m_pipeline_type == PipelineType::INPAINTING, "'blend_latents' can be called for inpainting pipeline only");
OPENVINO_ASSERT(image_latent.get_shape() == latent.get_shape(), "Shapes for current", latent.get_shape(), "and initial image latents ", image_latent.get_shape(), " must match");

ov::Tensor noised_image_latent(image_latent.get_element_type(), {});
Expand Down
134 changes: 129 additions & 5 deletions src/cpp/src/image_generation/flux_pipeline.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,13 @@ class FluxPipeline : public DiffusionPipeline {
m_image_processor = std::make_shared<ImageProcessor>(device, do_normalize, do_binarize, gray_scale_source);
m_image_resizer = std::make_shared<ImageResizer>(device, ov::element::u8, "NHWC", ov::op::v11::Interpolate::InterpolateMode::BICUBIC_PILLOW);
}

if (m_pipeline_type == PipelineType::INPAINTING) {
bool do_normalize = false, do_binarize = true;
m_mask_processor_rgb = std::make_shared<ImageProcessor>(device, do_normalize, do_binarize, false);
m_mask_processor_gray = std::make_shared<ImageProcessor>(device, do_normalize, do_binarize, true);
m_mask_resizer = std::make_shared<ImageResizer>(device, ov::element::f32, "NCHW", ov::op::v11::Interpolate::InterpolateMode::NEAREST);
}
}

FluxPipeline(PipelineType pipeline_type, const std::filesystem::path& root_dir) : FluxPipeline(pipeline_type) {
Expand Down Expand Up @@ -322,16 +329,25 @@ class FluxPipeline : public DiffusionPipeline {
num_channels_latents,
height,
width};
ov::Tensor latent(ov::element::f32, {}), proccesed_image, image_latents, noise;
ov::Tensor latent, noise, proccesed_image, image_latents;

if (initial_image) {
proccesed_image = m_image_resizer->execute(initial_image, generation_config.height, generation_config.width);
proccesed_image = m_image_processor->execute(proccesed_image);

image_latents = m_vae->encode(proccesed_image, generation_config.generator);
noise = generation_config.generator->randn_tensor(latent_shape);
m_scheduler->scale_noise(image_latents, m_latent_timestep, noise);
latent = pack_latents(image_latents, generation_config.num_images_per_prompt, num_channels_latents, height, width);

latent = ov::Tensor(image_latents.get_element_type(), image_latents.get_shape());
image_latents.copy_to(latent);

m_scheduler->scale_noise(latent, m_latent_timestep, noise);
latent = pack_latents(latent, generation_config.num_images_per_prompt, num_channels_latents, height, width);

if (m_pipeline_type == PipelineType::INPAINTING) {
noise = pack_latents(noise, generation_config.num_images_per_prompt, num_channels_latents, height, width);
image_latents = pack_latents(image_latents, generation_config.num_images_per_prompt, num_channels_latents, height, width);
}
} else {
noise = generation_config.generator->randn_tensor(latent_shape);
latent = pack_latents(noise, generation_config.num_images_per_prompt, num_channels_latents, height, width);
Expand All @@ -344,6 +360,79 @@ class FluxPipeline : public DiffusionPipeline {
OPENVINO_THROW("LORA adapters are not implemented for FLUX pipeline yet");
}

std::tuple<ov::Tensor, ov::Tensor> prepare_mask_latents(ov::Tensor mask_image, ov::Tensor processed_image, const ImageGenerationConfig& generation_config) {
OPENVINO_ASSERT(m_pipeline_type == PipelineType::INPAINTING, "'prepare_mask_latents' can be called for inpainting pipeline only");

const size_t vae_scale_factor = m_vae->get_vae_scale_factor();
ov::Shape target_shape = processed_image.get_shape();

// Prepare mask latent variables
ov::Tensor mask_condition = m_image_resizer->execute(mask_image, generation_config.height, generation_config.width);
std::shared_ptr<IImageProcessor> mask_processor = mask_condition.get_shape()[3] == 1 ? m_mask_processor_gray : m_mask_processor_rgb;
mask_condition = mask_processor->execute(mask_condition);

size_t num_channels_latents = m_transformer->get_config().in_channels / 4;
size_t height = generation_config.height / vae_scale_factor;
size_t width = generation_config.width / vae_scale_factor;

// resize mask to shape of latent space
ov::Tensor mask = m_mask_resizer->execute(mask_condition, height, width);
mask = numpy_utils::repeat(mask, generation_config.num_images_per_prompt);

// Create masked image:
// masked_image = init_image * (mask_condition < 0.5)
ov::Tensor masked_image(ov::element::f32, processed_image.get_shape());
const float * mask_condition_data = mask_condition.data<const float>();
const float * processed_image_data = processed_image.data<const float>();
float * masked_image_data = masked_image.data<float>();
for (size_t i = 0, plane_size = mask_condition.get_shape()[2] * mask_condition.get_shape()[3]; i < mask_condition.get_size(); ++i) {
masked_image_data[i + 0 * plane_size] = mask_condition_data[i] < 0.5f ? processed_image_data[i + 0 * plane_size] : 0.0f;
masked_image_data[i + 1 * plane_size] = mask_condition_data[i] < 0.5f ? processed_image_data[i + 1 * plane_size] : 0.0f;
masked_image_data[i + 2 * plane_size] = mask_condition_data[i] < 0.5f ? processed_image_data[i + 2 * plane_size] : 0.0f;
}

ov::Tensor masked_image_latent;
// TODO: support is_inpainting_model() == true
// masked_image_latent = m_vae->encode(masked_image, generation_config.generator);
// // masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
// float * masked_image_latent_data = masked_image_latent.data<float>();
// for (size_t i = 0; i < masked_image_latent.get_size(); ++i) {
// masked_image_latent_data[i] = (masked_image_latent_data[i] - m_vae->get_config().shift_factor) * m_vae->get_config().scaling_factor;
// }
// masked_image_latent = pack_latents(masked_image_latent, generation_config.num_images_per_prompt, num_channels_latents, height, width);

// mask.repeat(1, num_channels_latents, 1, 1)
auto repeat_mask = [](const ov::Tensor& mask, size_t num_channels_latents) -> ov::Tensor {
const ov::Shape& mask_shape = mask.get_shape();
OPENVINO_ASSERT(mask_shape.size() == 4 && mask_shape[1] == 1, "Mask must have shape (batch_size, 1, height, width)");

size_t batch_size = mask_shape[0], height = mask_shape[2], width = mask_shape[3];
size_t spatial_size = height * width;

ov::Shape target_shape = {batch_size, num_channels_latents, height, width};
ov::Tensor repeated_mask(mask.get_element_type(), target_shape);

const float* src_data = mask.data<float>();
float* dst_data = repeated_mask.data<float>();

for (size_t b = 0; b < batch_size; ++b) {
const float* src_batch = src_data + b * spatial_size; // Pointer to batch start
float* dst_batch = dst_data + b * num_channels_latents * spatial_size;

for (size_t c = 0; c < num_channels_latents; ++c) {
std::memcpy(dst_batch + c * spatial_size, src_batch, spatial_size * sizeof(float));
}
}

return repeated_mask;
};

ov::Tensor repeated_mask = repeat_mask(mask, num_channels_latents);
ov::Tensor mask_packed = pack_latents(repeated_mask, generation_config.num_images_per_prompt, num_channels_latents, height, width);

return std::make_tuple(mask_packed, masked_image_latent);
}

ov::Tensor generate(const std::string& positive_prompt,
ov::Tensor initial_image,
ov::Tensor mask_image,
Expand Down Expand Up @@ -380,14 +469,20 @@ class FluxPipeline : public DiffusionPipeline {
std::vector<float> timesteps;
if (m_pipeline_type == PipelineType::TEXT_2_IMAGE) {
timesteps = m_scheduler->get_float_timesteps();
m_latent_timestep = timesteps[0];
} else {
timesteps = get_timesteps(m_custom_generation_config.num_inference_steps, m_custom_generation_config.strength);
}
m_latent_timestep = timesteps[0];

ov::Tensor latents, processed_image, image_latent, noise;
std::tie(latents, processed_image, image_latent, noise) = prepare_latents(initial_image, m_custom_generation_config);

// prepare mask latents
ov::Tensor mask, masked_image_latent;
if (m_pipeline_type == PipelineType::INPAINTING) {
std::tie(mask, masked_image_latent) = prepare_mask_latents(mask_image, processed_image, m_custom_generation_config);
}

// 6. Denoising loop
ov::Tensor timestep(ov::element::f32, {1});
float* timestep_data = timestep.data<float>();
Expand All @@ -400,13 +495,16 @@ class FluxPipeline : public DiffusionPipeline {
auto scheduler_step_result = m_scheduler->step(noise_pred_tensor, latents, inference_step, m_custom_generation_config.generator);
latents = scheduler_step_result["latent"];

if (m_pipeline_type == PipelineType::INPAINTING) {
blend_latents(latents, image_latent, mask, noise, inference_step);
}

if (callback && callback(inference_step, timesteps.size(), latents)) {
return ov::Tensor(ov::element::u8, {});
}
}

latents = unpack_latents(latents, m_custom_generation_config.height, m_custom_generation_config.width, vae_scale_factor);

return m_vae->decode(latents);
}

Expand Down Expand Up @@ -488,6 +586,32 @@ class FluxPipeline : public DiffusionPipeline {
}
}

void blend_latents(ov::Tensor latents,
const ov::Tensor image_latent,
const ov::Tensor mask,
const ov::Tensor noise,
size_t inference_step) {
OPENVINO_ASSERT(m_pipeline_type == PipelineType::INPAINTING, "'blend_latents' can be called for inpainting pipeline only");
OPENVINO_ASSERT(image_latent.get_shape() == latents.get_shape(),
"Shapes for current", latents.get_shape(), "and initial image latents ", image_latent.get_shape(), " must match");

ov::Tensor init_latents_proper(image_latent.get_element_type(), image_latent.get_shape());
image_latent.copy_to(init_latents_proper);

std::vector<float> timesteps = m_scheduler->get_float_timesteps();
if (inference_step < timesteps.size() - 1) {
float noise_timestep = timesteps[inference_step + 1];
m_scheduler->scale_noise(init_latents_proper, noise_timestep, noise);
}

float * latents_data = latents.data<float>();
const float * mask_data = mask.data<float>();
const float * init_latents_proper_data = init_latents_proper.data<float>();
for (size_t i = 0; i < latents.get_size(); ++i) {
latents_data[i] = (1.0f - mask_data[i]) * init_latents_proper_data[i] + mask_data[i] * latents_data[i];
}
}

std::shared_ptr<FluxTransformer2DModel> m_transformer = nullptr;
std::shared_ptr<CLIPTextModel> m_clip_text_encoder = nullptr;
std::shared_ptr<T5EncoderModel> m_t5_text_encoder = nullptr;
Expand Down
19 changes: 19 additions & 0 deletions src/cpp/src/image_generation/inpainting_pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "image_generation/stable_diffusion_pipeline.hpp"
#include "image_generation/stable_diffusion_xl_pipeline.hpp"
#include "image_generation/flux_pipeline.hpp"

#include "utils.hpp"

Expand All @@ -25,6 +26,8 @@ InpaintingPipeline::InpaintingPipeline(const std::filesystem::path& root_dir) {
m_impl = std::make_shared<StableDiffusionPipeline>(PipelineType::INPAINTING, root_dir);
} else if (class_name == "StableDiffusionXLPipeline" || class_name == "StableDiffusionXLInpaintPipeline") {
m_impl = std::make_shared<StableDiffusionXLPipeline>(PipelineType::INPAINTING, root_dir);
} else if (class_name == "FluxPipeline" || class_name == "FluxInpaintPipeline") {
m_impl = std::make_shared<FluxPipeline>(PipelineType::INPAINTING, root_dir);
} else {
OPENVINO_THROW("Unsupported inpainting pipeline '", class_name, "'");
}
Expand All @@ -39,6 +42,8 @@ InpaintingPipeline::InpaintingPipeline(const std::filesystem::path& root_dir, co
m_impl = std::make_shared<StableDiffusionPipeline>(PipelineType::INPAINTING, root_dir, device, properties);
} else if (class_name == "StableDiffusionXLPipeline" || class_name == "StableDiffusionXLInpaintPipeline") {
m_impl = std::make_shared<StableDiffusionXLPipeline>(PipelineType::INPAINTING, root_dir, device, properties);
} else if (class_name == "FluxPipeline" || class_name == "FluxInpaintPipeline") {
m_impl = std::make_shared<FluxPipeline>(PipelineType::INPAINTING, root_dir, device, properties);
} else {
OPENVINO_THROW("Unsupported inpainting pipeline '", class_name, "'");
}
Expand Down Expand Up @@ -99,6 +104,20 @@ InpaintingPipeline InpaintingPipeline::stable_diffusion_xl(
return InpaintingPipeline(impl);
}

InpaintingPipeline InpaintingPipeline::flux(
const std::shared_ptr<Scheduler>& scheduler,
const CLIPTextModel& clip_text_model,
const T5EncoderModel& t5_text_encoder,
const FluxTransformer2DModel& transformer,
const AutoencoderKL& vae) {
auto impl = std::make_shared<FluxPipeline>(PipelineType::INPAINTING, clip_text_model, t5_text_encoder, transformer, vae);

assert(scheduler != nullptr);
impl->set_scheduler(scheduler);

return InpaintingPipeline(impl);
}

ImageGenerationConfig InpaintingPipeline::get_generation_config() const {
return m_impl->get_generation_config();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ size_t FlowMatchEulerDiscreteScheduler::_index_for_timestep(float timestep) {
}

void FlowMatchEulerDiscreteScheduler::scale_noise(ov::Tensor sample, float timestep, ov::Tensor noise) {
OPENVINO_ASSERT(timestep == -1, "Timestep is not computed yet");
OPENVINO_ASSERT(timestep > 0, "Timestep is not computed yet");

size_t index_for_timestep;
if (m_begin_index == -1) {
Expand All @@ -180,7 +180,6 @@ void FlowMatchEulerDiscreteScheduler::scale_noise(ov::Tensor sample, float times
for (size_t i = 0; i < sample.get_size(); ++i) {
sample_data[i] = sigma * noise_data[i] + (1.0f - sigma) * sample_data[i];
}

}

void FlowMatchEulerDiscreteScheduler::set_timesteps_with_sigma(std::vector<float> sigma, float mu) {
Expand Down
3 changes: 3 additions & 0 deletions src/python/openvino_genai/py_openvino_genai.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,9 @@ class InpaintingPipeline:
This class is used for generation with inpainting models.
"""
@staticmethod
def flux(scheduler: Scheduler, clip_text_model: CLIPTextModel, t5_encoder_model: T5EncoderModel, transformer: FluxTransformer2DModel, vae: AutoencoderKL) -> InpaintingPipeline:
...
@staticmethod
def latent_consistency_model(scheduler: Scheduler, clip_text_model: CLIPTextModel, unet: UNet2DConditionModel, vae: AutoencoderKL) -> InpaintingPipeline:
...
@staticmethod
Expand Down
1 change: 1 addition & 0 deletions src/python/py_image_generation_pipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ void init_image_generation_pipelines(py::module_& m) {
.def_static("stable_diffusion", &ov::genai::InpaintingPipeline::stable_diffusion, py::arg("scheduler"), py::arg("clip_text_model"), py::arg("unet"), py::arg("vae"))
.def_static("latent_consistency_model", &ov::genai::InpaintingPipeline::latent_consistency_model, py::arg("scheduler"), py::arg("clip_text_model"), py::arg("unet"), py::arg("vae"))
.def_static("stable_diffusion_xl", &ov::genai::InpaintingPipeline::stable_diffusion_xl, py::arg("scheduler"), py::arg("clip_text_model"), py::arg("clip_text_model_with_projection"), py::arg("unet"), py::arg("vae"))
.def_static("flux", &ov::genai::InpaintingPipeline::flux, py::arg("scheduler"), py::arg("clip_text_model"), py::arg("t5_encoder_model"), py::arg("transformer"), py::arg("vae"))
.def(
"compile",
[](ov::genai::InpaintingPipeline& pipe,
Expand Down

0 comments on commit f0f09ad

Please sign in to comment.