diff --git a/optimum/intel/__init__.py b/optimum/intel/__init__.py
index 456be43977..4f49abcfdd 100644
--- a/optimum/intel/__init__.py
+++ b/optimum/intel/__init__.py
@@ -253,6 +253,8 @@ except OptionalDependencyNotAvailable:
         from .utils.dummy_openvino_and_diffusers_objects import (
             OVLatentConsistencyModelPipeline,
+            OVStableDiffusion3Img2ImgPipeline,
+            OVStableDiffusion3Pipeline,
             OVStableDiffusionImg2ImgPipeline,
             OVStableDiffusionInpaintPipeline,
             OVStableDiffusionPipeline,
@@ -262,6 +264,8 @@ else:
         from .openvino import (
             OVLatentConsistencyModelPipeline,
+            OVStableDiffusion3Img2ImgPipeline,
+            OVStableDiffusion3Pipeline,
             OVStableDiffusionImg2ImgPipeline,
             OVStableDiffusionInpaintPipeline,
             OVStableDiffusionPipeline,
diff --git a/optimum/intel/openvino/__init__.py b/optimum/intel/openvino/__init__.py
index c4d19b441b..2aefe9620e 100644
--- a/optimum/intel/openvino/__init__.py
+++ b/optimum/intel/openvino/__init__.py
@@ -82,6 +82,7 @@ if is_diffusers_available():
     from .modeling_diffusion import (
         OVLatentConsistencyModelPipeline,
+        OVStableDiffusion3Img2ImgPipeline,
         OVStableDiffusion3Pipeline,
         OVStableDiffusionImg2ImgPipeline,
         OVStableDiffusionInpaintPipeline,
diff --git a/optimum/intel/openvino/modeling_diffusion.py b/optimum/intel/openvino/modeling_diffusion.py
index 9099751353..6c38e7679b 100644
--- a/optimum/intel/openvino/modeling_diffusion.py
+++ b/optimum/intel/openvino/modeling_diffusion.py
@@ -71,10 +71,23 @@
 if is_diffusers_version(">=", "0.29.0"):
-    from diffusers import StableDiffusion3Img2ImgPipeline, StableDiffusion3InpaintPipeline, StableDiffusion3Pipeline
+    from diffusers import StableDiffusion3Img2ImgPipeline, StableDiffusion3Pipeline
 else:
-    StableDiffusion3Pipeline, StableDiffusion3InpaintPipeline, StableDiffusion3Img2ImgPipeline = None, None, None
+    StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline = None, None
+
+if is_diffusers_version(">=", "0.30.0"):
+    from diffusers import StableDiffusion3InpaintPipeline
+else:
+    StableDiffusion3InpaintPipeline = None
+
+PipelineImageInput = Union[
+    PIL.Image.Image,
+    np.ndarray,
+    torch.Tensor,
+    List[PIL.Image.Image],
+    List[np.ndarray],
+    List[torch.Tensor],
+]
 
 core = Core()
@@ -131,7 +144,7 @@ def __init__(
             self.vae_decoder = OVModelVaeDecoder(vae_decoder, self)
         if unet is not None:
             self.unet = OVModelUnet(unet, self)
-            self.trasnformer = None
+            self.transformer = None
         elif transformer is not None:
             self.unet = None
             self.transformer = OVModelTransformer(transformer, self)
@@ -659,7 +672,7 @@ def reshape(
             self.unet.model = self._reshape_unet(
                 self.unet.model, batch_size, height, width, num_images_per_prompt, tokenizer_max_len
             )
-        if self.transformer is not None:
+        if getattr(self, "transformer", None) is not None:
             self.transformer.model = self._reshape_transformer(
                 self.transformer.model, batch_size, height, width, num_images_per_prompt, tokenizer_max_len
             )
@@ -1767,6 +1780,285 @@ def __call__(
         return StableDiffusionPipelineOutput(image, None)
 
 
+class OVStableDiffusion3Img2ImgPipelineMixin(StableDiffusion3PipelineMixin):
+    def prepare_latents(self, image, timesteps, batch_size, num_images_per_prompt, dtype, generator=None):
+        batch_size = batch_size * num_images_per_prompt
+
+        if image.shape[1] == 4:
+            # the input is already a latent tensor
+            init_latents = image
+        else:
+            # SD3 latents are shifted then scaled; this mirrors the inverse of the decoding step below
+            init_latents = (
+                self.vae_encoder(sample=image)[0] - self.vae_decoder.config.get("shift_factor", 0.0609)
+            ) * self.vae_decoder.config.get("scaling_factor", 1.5305)
+
+        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+            # expand init_latents for batch_size
+            additional_image_per_prompt = batch_size // init_latents.shape[0]
+            init_latents = np.concatenate([init_latents] * additional_image_per_prompt, axis=0)
+        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+            raise ValueError(
+                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+            )
+        else:
+            init_latents = np.concatenate([init_latents], axis=0)
+
+        # add noise to latents using the timesteps
+        if isinstance(generator, np.random.RandomState):
+            noise = generator.randn(*init_latents.shape).astype(dtype)
+        elif isinstance(generator, torch.Generator):
+            noise = torch.randn(*init_latents.shape, generator=generator).numpy().astype(dtype)
+        else:
+            raise ValueError(
+                f"Expected `generator` to be of type `np.random.RandomState` or `torch.Generator`, but got"
+                f" {type(generator)}."
+            )
+
+        # `timesteps` comes from `retrieve_timesteps` and is already a torch tensor
+        if isinstance(timesteps, np.ndarray):
+            timesteps = torch.from_numpy(timesteps)
+        init_latents = self.scheduler.add_noise(
+            torch.from_numpy(init_latents), torch.from_numpy(noise), timesteps
+        )
+
+        return init_latents
+
+    def check_inputs(
+        self,
+        prompt,
+        prompt_2,
+        prompt_3,
+        strength,
+        negative_prompt=None,
+        negative_prompt_2=None,
+        negative_prompt_3=None,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
+        pooled_prompt_embeds=None,
+        negative_pooled_prompt_embeds=None,
+        max_sequence_length=None,
+    ):
+        if strength < 0 or strength > 1:
+            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
+
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt_2 is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt_3 is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt_3`: {prompt_3} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+        elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
+            raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
+
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+        elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+
+        if prompt_embeds is not None and pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+            )
+
+        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+            )
+
+        if max_sequence_length is not None and max_sequence_length > 512:
+            raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+    def get_timesteps(self, num_inference_steps, strength):
+        # get the original timestep using init_timestep
+        init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+        t_start = int(max(num_inference_steps - init_timestep, 0))
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+        if hasattr(self.scheduler, "set_begin_index"):
+            self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+        return timesteps, num_inference_steps - t_start
+
+    def __call__(
+        self,
+        prompt: Union[str, List[str]] = None,
+        prompt_2: Optional[Union[str, List[str]]] = None,
+        prompt_3: Optional[Union[str, List[str]]] = None,
+        image: PipelineImageInput = None,
+        strength: float = 0.6,
+        num_inference_steps: int = 50,
+        timesteps: List[int] = None,
+        guidance_scale: float = 7.0,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
+        negative_prompt_3: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        clip_skip: Optional[int] = None,
+        max_sequence_length: int = 256,
+    ):
+        # 1. Check inputs
+        self.check_inputs(
+            prompt,
+            prompt_2,
+            prompt_3,
+            strength,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
+            negative_prompt_3=negative_prompt_3,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            max_sequence_length=max_sequence_length,
+        )
+
+        self._guidance_scale = guidance_scale
+        self._clip_skip = clip_skip
+        self._interrupt = False
+
+        # 2. Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = self.encode_prompt(
+            prompt=prompt,
+            prompt_2=prompt_2,
+            prompt_3=prompt_3,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
+            negative_prompt_3=negative_prompt_3,
+            do_classifier_free_guidance=self.do_classifier_free_guidance,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            clip_skip=self.clip_skip,
+            num_images_per_prompt=num_images_per_prompt,
+            max_sequence_length=max_sequence_length,
+        )
+
+        if self.do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
+
+        # 3. Preprocess image
+        image = self.image_processor.preprocess(image)
+
+        # 4. Prepare timesteps
+        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps)
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength)
+        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+        # 5. Prepare latent variables
+        if latents is None:
+            latents = self.prepare_latents(
+                image,
+                latent_timestep,
+                batch_size,
+                num_images_per_prompt,
+                prompt_embeds.dtype,
+                generator,
+            )
+
+        if not isinstance(latents, torch.Tensor):
+            latents = torch.from_numpy(latents)
+
+        # 6. Denoising loop
+        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+        self._num_timesteps = len(timesteps)
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                timestep = t.expand(latent_model_input.shape[0])
+
+                noise_pred = self.transformer(
+                    hidden_states=latent_model_input,
+                    timestep=timestep,
+                    encoder_hidden_states=prompt_embeds,
+                    pooled_projections=pooled_prompt_embeds,
+                )[0]
+
+                noise_pred = torch.from_numpy(noise_pred)
+
+                # perform guidance
+                if self.do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+                # update the progress bar
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+
+        if output_type == "latent":
+            image = latents
+        else:
+            # invert the SD3 latent scaling/shift before decoding with the VAE
+            latents = (latents / self.vae_decoder.config.get("scaling_factor", 1.5305)) + self.vae_decoder.config.get(
+                "shift_factor", 0.0609
+            )
+            image = self.vae_decoder(latents)[0]
+            image = self.image_processor.postprocess(image, output_type=output_type)
+
+        if not return_dict:
+            return (image,)
+
+        return StableDiffusionPipelineOutput(image, None)
+
+
 class OVStableDiffusion3PipelineBase(OVStableDiffusionPipelineBase):
     def __init__(self, *args, **kwargs):
         if kwargs.get("transformer") is None:
@@ -1794,7 +2086,7 @@ class OVStableDiffusion3Pipeline(OVStableDiffusion3PipelineBase, StableDiffusion
 
     def __call__(
         self,
-        prompt: Optional[Union[str, List[str]]] = None,
+        prompt: Union[str, List[str]] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 50,
@@ -1848,3 +2140,42 @@ def __call__(
             generator=generator,
             **kwargs,
         )
+
+
+class OVStableDiffusion3Img2ImgPipeline(OVStableDiffusion3PipelineBase, OVStableDiffusion3Img2ImgPipelineMixin):
+    auto_model_class = StableDiffusion3Img2ImgPipeline
+    export_feature = "image-to-image"
+
+    def __call__(
+        self,
+        prompt: Union[str, List[str]] = None,
+        image: PipelineImageInput = None,
+        guidance_scale: float = 7.0,
+        num_images_per_prompt: Optional[int] = 1,
+        **kwargs,
+    ):
+        _height = self.height
+        _width = self.width
+        expected_batch_size = self._batch_size
+
+        if _height != -1 and _width != -1:
+            image = self.image_processor.preprocess(image, height=_height, width=_width).transpose(0, 2, 3, 1)
+
+        if expected_batch_size != -1:
+            if isinstance(prompt, str):
+                batch_size = 1
+            elif isinstance(prompt, list):
+                batch_size = len(prompt)
+            else:
+                batch_size = kwargs.get("prompt_embeds").shape[0]
+
+            _raise_invalid_batch_size(expected_batch_size, batch_size, num_images_per_prompt, guidance_scale)
+
+        return OVStableDiffusion3Img2ImgPipelineMixin.__call__(
+            self,
+            prompt=prompt,
+            image=image,
+            guidance_scale=guidance_scale,
+            num_images_per_prompt=num_images_per_prompt,
+            **kwargs,
+        )
diff --git a/optimum/intel/utils/dummy_openvino_and_diffusers_objects.py b/optimum/intel/utils/dummy_openvino_and_diffusers_objects.py
index b8bb2d81b1..30d2e49d9a 100644
--- a/optimum/intel/utils/dummy_openvino_and_diffusers_objects.py
+++ b/optimum/intel/utils/dummy_openvino_and_diffusers_objects.py
@@ -90,3 +90,14 @@ def __init__(self, *args, **kwargs):
     @classmethod
     def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["openvino", "diffusers"])
+
+
+class OVStableDiffusion3Img2ImgPipeline(metaclass=DummyObject):
+    _backends = ["openvino", "diffusers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["openvino", "diffusers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["openvino", "diffusers"])
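
For reviewers: a minimal usage sketch of the new pipeline, assuming the standard optimum-intel `export=True` conversion flow; the model ID, image URL, and prompt below are placeholders, not part of this PR.

    from diffusers.utils import load_image
    from optimum.intel import OVStableDiffusion3Img2ImgPipeline

    # export=True converts the PyTorch checkpoint to OpenVINO IR on the fly
    pipeline = OVStableDiffusion3Img2ImgPipeline.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers", export=True
    )

    init_image = load_image("https://example.com/input.png")  # placeholder URL
    # strength controls how much of the initial image is preserved (0 = unchanged, 1 = full resample)
    image = pipeline(prompt="a photo of an astronaut", image=init_image, strength=0.6).images[0]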