From 9e62dcec099c64d864eb49834f47e7a07f2ac3b9 Mon Sep 17 00:00:00 2001 From: Daniel Regado Date: Tue, 4 Feb 2025 16:22:25 +0000 Subject: [PATCH 1/2] Added IP-Adapter to --- src/diffusers/pipelines/flux/pipeline_flux.py | 27 +-- .../flux/pipeline_flux_control_img2img.py | 1 - .../flux/pipeline_flux_control_inpaint.py | 1 - .../pipelines/flux/pipeline_flux_img2img.py | 213 +++++++++++++++++- tests/pipelines/flux/test_pipeline_flux.py | 80 +++---- .../flux/test_pipeline_flux_img2img.py | 58 ++++- 6 files changed, 318 insertions(+), 62 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index aa02dc1de5da..e6f9de3ac302 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -645,13 +645,13 @@ def __call__( generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, ip_adapter_image: Optional[PipelineImageInput] = None, ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, negative_ip_adapter_image: Optional[PipelineImageInput] = None, negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, joint_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -707,28 +707,29 @@ def __call__( prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. - ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not provided, embeddings are computed from the `ip_adapter_image` input argument. - negative_ip_adapter_image: - (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + negative_ip_adapter_image (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not provided, embeddings are computed from the `ip_adapter_image` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` - input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py index c386f41c8827..72f2298fa643 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py @@ -438,7 +438,6 @@ def get_timesteps(self, num_inference_steps, strength, device): return timesteps, num_inference_steps - t_start - # Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.check_inputs def check_inputs( self, prompt, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py index 192b690f69e5..cd029b8efbc9 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py @@ -477,7 +477,6 @@ def get_timesteps(self, num_inference_steps, strength, device): return timesteps, num_inference_steps - t_start - # Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.check_inputs def check_inputs( self, prompt, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py index a63ecdadbd0c..94737187a3ef 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py @@ -17,10 +17,17 @@ import numpy as np import torch -from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionModelWithProjection, + T5EncoderModel, + T5TokenizerFast, +) from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler @@ -159,7 +166,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): +class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin, FluxIPAdapterMixin): r""" The Flux pipeline for image inpainting. @@ -184,10 +191,14 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile tokenizer_2 (`T5TokenizerFast`): Second Tokenizer of class [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + image_encoder (`CLIPImageProcessor`, *optional*): + Pre-trained Vision Model for IP Adapter. + feature_extractor (`CLIPVisionModelWithProjection`, *optional*): + Image processor for IP Adapter. """ - model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" - _optional_components = [] + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae" + _optional_components = ["image_encoder", "feature_extractor"] _callback_tensor_inputs = ["latents", "prompt_embeds"] def __init__( @@ -199,6 +210,8 @@ def __init__( text_encoder_2: T5EncoderModel, tokenizer_2: T5TokenizerFast, transformer: FluxTransformer2DModel, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, ): super().__init__() @@ -210,6 +223,8 @@ def __init__( tokenizer_2=tokenizer_2, transformer=transformer, scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible @@ -422,6 +437,50 @@ def get_timesteps(self, num_inference_steps, strength, device): return timesteps, num_inference_steps - t_start + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + return image_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt + ): + image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers + ): + single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1) + + image_embeds.append(single_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + def check_inputs( self, prompt, @@ -429,8 +488,12 @@ def check_inputs( strength, height, width, + negative_prompt=None, + negative_prompt_2=None, prompt_embeds=None, + negative_prompt_embeds=None, pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, callback_on_step_end_tensor_inputs=None, max_sequence_length=None, ): @@ -468,10 +531,33 @@ def check_inputs( elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_embeds is not None and pooled_prompt_embeds is None: raise ValueError( "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." ) + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) if max_sequence_length is not None and max_sequence_length > 512: raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") @@ -586,6 +672,9 @@ def __call__( self, prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt: Union[str, List[str]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + true_cfg_scale: float = 1.0, image: PipelineImageInput = None, height: Optional[int] = None, width: Optional[int] = None, @@ -597,7 +686,13 @@ def __call__( generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + negative_ip_adapter_image: Optional[PipelineImageInput] = None, + negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, joint_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -614,7 +709,16 @@ def __call__( instead. prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - will be used instead + will be used instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + true_cfg_scale (`float`, *optional*, defaults to 1.0): + When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance. image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list @@ -656,9 +760,29 @@ def __call__( prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + negative_ip_adapter_image (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -697,14 +821,18 @@ def __call__( strength, height, width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, max_sequence_length=max_sequence_length, ) self._guidance_scale = guidance_scale - self._joint_attention_kwargs = joint_attention_kwargs + self._joint_attention_kwargs = joint_attention_kwargs or {} self._interrupt = False # 2. Preprocess image @@ -724,6 +852,10 @@ def __call__( lora_scale = ( self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None ) + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt ( prompt_embeds, pooled_prompt_embeds, @@ -738,6 +870,21 @@ def __call__( max_sequence_length=max_sequence_length, lora_scale=lora_scale, ) + if do_true_cfg: + ( + negative_prompt_embeds, + negative_pooled_prompt_embeds, + _, + ) = self.encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) # 4.Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas @@ -791,12 +938,48 @@ def __call__( else: guidance = None + # handle image prompt + any_ip_adapter_input = ip_adapter_image is not None or ip_adapter_image_embeds is not None + any_negative_ip_adapter_input = ( + negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None + ) + + if any_ip_adapter_input and not any_negative_ip_adapter_input: + negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + elif any_negative_ip_adapter_input and not any_ip_adapter_input: + ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + + image_embeds = ( + self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + if any_ip_adapter_input + else None + ) + + negative_image_embeds = ( + self.prepare_ip_adapter_image_embeds( + negative_ip_adapter_image, + negative_ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + if any_negative_ip_adapter_input + else None + ) + # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue + if image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) noise_pred = self.transformer( @@ -811,6 +994,22 @@ def __call__( return_dict=False, )[0] + if do_true_cfg: + if negative_image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds + neg_noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=negative_pooled_prompt_embeds, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index bab343a5954c..cd10ecf563c6 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -333,6 +333,43 @@ def tearDown(self): gc.collect() torch.cuda.empty_cache() + def get_expected_slice(self): + return np.array( + [ + 0.1855, + 0.1680, + 0.1406, + 0.1953, + 0.1699, + 0.1465, + 0.2012, + 0.1738, + 0.1484, + 0.2051, + 0.1797, + 0.1523, + 0.2012, + 0.1719, + 0.1445, + 0.2070, + 0.1777, + 0.1465, + 0.2090, + 0.1836, + 0.1484, + 0.2129, + 0.1875, + 0.1523, + 0.2090, + 0.1816, + 0.1484, + 0.2110, + 0.1836, + 0.1543, + ], + dtype=np.float32, + ) + def get_inputs(self, device, seed=0): if str(device).startswith("mps"): generator = torch.manual_seed(seed) @@ -340,12 +377,14 @@ def get_inputs(self, device, seed=0): generator = torch.Generator(device="cpu").manual_seed(seed) prompt_embeds = torch.load( - hf_hub_download(repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/prompt_embeds.pt") + hf_hub_download(repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/prompt_embeds.pt"), + weights_only=True ) pooled_prompt_embeds = torch.load( hf_hub_download( repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/pooled_prompt_embeds.pt" - ) + ), + weights_only=True ) negative_prompt_embeds = torch.zeros_like(prompt_embeds) negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) @@ -381,42 +420,7 @@ def test_flux_ip_adapter_inference(self): image = pipe(**inputs).images[0] image_slice = image[0, :10, :10] - expected_slice = np.array( - [ - 0.1855, - 0.1680, - 0.1406, - 0.1953, - 0.1699, - 0.1465, - 0.2012, - 0.1738, - 0.1484, - 0.2051, - 0.1797, - 0.1523, - 0.2012, - 0.1719, - 0.1445, - 0.2070, - 0.1777, - 0.1465, - 0.2090, - 0.1836, - 0.1484, - 0.2129, - 0.1875, - 0.1523, - 0.2090, - 0.1816, - 0.1484, - 0.2110, - 0.1836, - 0.1543, - ], - dtype=np.float32, - ) - + expected_slice = self.get_expected_slice() max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten()) assert max_diff < 1e-4, f"{image_slice} != {expected_slice}" diff --git a/tests/pipelines/flux/test_pipeline_flux_img2img.py b/tests/pipelines/flux/test_pipeline_flux_img2img.py index a1336fabdb89..1ea4ac46bbb9 100644 --- a/tests/pipelines/flux/test_pipeline_flux_img2img.py +++ b/tests/pipelines/flux/test_pipeline_flux_img2img.py @@ -12,13 +12,14 @@ torch_device, ) -from ..test_pipelines_common import PipelineTesterMixin +from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin +from .test_pipeline_flux import FluxIPAdapterPipelineSlowTests enable_full_determinism() -class FluxImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin): +class FluxImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin): pipeline_class = FluxImg2ImgPipeline params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) batch_params = frozenset(["prompt"]) @@ -85,6 +86,8 @@ def get_dummy_components(self): "tokenizer_2": tokenizer_2, "transformer": transformer, "vae": vae, + "image_encoder": None, + "feature_extractor": None, } def get_dummy_inputs(self, device, seed=0): @@ -161,3 +164,54 @@ def test_flux_image_output_shape(self): image = pipe(**inputs).images[0] output_height, output_width, _ = image.shape assert (output_height, output_width) == (expected_height, expected_width) + + +class FluxIPAdapterImg2ImgPipelineSlowTests(FluxIPAdapterPipelineSlowTests): + """ + Same test as in FluxPipeline, only with inital `image` and `strength` parameters. + """ + pipeline_class = FluxImg2ImgPipeline + + def get_expected_slice(self): + return np.array( + [ + 0.3125, + 0.2812, + 0.2285, + 0.3125, + 0.2734, + 0.2207, + 0.3125, + 0.2734, + 0.2187, + 0.3125, + 0.2734, + 0.2226, + 0.3203, + 0.2832, + 0.2285, + 0.3164, + 0.2812, + 0.2207, + 0.3125, + 0.2792, + 0.2187, + 0.3085, + 0.2734, + 0.2128, + 0.3027, + 0.2675, + 0.2089, + 0.3144, + 0.2773, + 0.2226, + ], + dtype=np.float32, + ) + + def get_inputs(self, device, seed=0): + return { + "image": floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device), + "strength": 0.8, + **super().get_inputs(device, seed), + } \ No newline at end of file From 3d662a254eb81b64b7c00b2f1b5a9c09a4d4097f Mon Sep 17 00:00:00 2001 From: Daniel Regado Date: Tue, 4 Feb 2025 16:29:28 +0000 Subject: [PATCH 2/2] make style && make quality --- tests/pipelines/flux/test_pipeline_flux.py | 4 ++-- tests/pipelines/flux/test_pipeline_flux_img2img.py | 7 +++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index cd10ecf563c6..a85c1daf0d88 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -378,13 +378,13 @@ def get_inputs(self, device, seed=0): prompt_embeds = torch.load( hf_hub_download(repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/prompt_embeds.pt"), - weights_only=True + weights_only=True, ) pooled_prompt_embeds = torch.load( hf_hub_download( repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/pooled_prompt_embeds.pt" ), - weights_only=True + weights_only=True, ) negative_prompt_embeds = torch.zeros_like(prompt_embeds) negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) diff --git a/tests/pipelines/flux/test_pipeline_flux_img2img.py b/tests/pipelines/flux/test_pipeline_flux_img2img.py index 1ea4ac46bbb9..2d141a318f0f 100644 --- a/tests/pipelines/flux/test_pipeline_flux_img2img.py +++ b/tests/pipelines/flux/test_pipeline_flux_img2img.py @@ -167,9 +167,8 @@ def test_flux_image_output_shape(self): class FluxIPAdapterImg2ImgPipelineSlowTests(FluxIPAdapterPipelineSlowTests): - """ - Same test as in FluxPipeline, only with inital `image` and `strength` parameters. - """ + """Same test as in FluxIPAdapterPipelineSlowTests, only with inital `image` and `strength` parameters.""" + pipeline_class = FluxImg2ImgPipeline def get_expected_slice(self): @@ -214,4 +213,4 @@ def get_inputs(self, device, seed=0): "image": floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device), "strength": 0.8, **super().get_inputs(device, seed), - } \ No newline at end of file + }