diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py index 99cbc591090b..eb02f6cb321c 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py @@ -111,7 +111,15 @@ def __init__( ) self.register_to_config(requires_safety_checker=requires_safety_checker) - def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + def _encode_prompt( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: Optional[int], + do_classifier_free_guidance: bool, + negative_prompt: Optional[str], + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + ): r""" Encodes the prompt into text encoder hidden states. @@ -125,32 +133,48 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida negative_prompt (`str` or `List[str]`): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - # get prompt text embeddings - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="np", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] - if not np.array_equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" + if prompt_embeds is None: + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids + + if not np.array_equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] - prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: + if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size @@ -170,7 +194,7 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida else: uncond_tokens = negative_prompt - max_length = text_input_ids.shape[-1] + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", @@ -179,6 +203,8 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida return_tensors="np", ) negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] + + if do_classifier_free_guidance: negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) # For classifier free guidance, we need to do two forward passes. @@ -188,9 +214,56 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida return prompt_embeds - def __call__( + def check_inputs( self, prompt: Union[str, List[str]], + height: Optional[int], + width: Optional[int], + callback_steps: int, + negative_prompt: Optional[str] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def __call__( + self, + prompt: Union[str, List[str]] = None, height: Optional[int] = 512, width: Optional[int] = 512, num_inference_steps: Optional[int] = 50, @@ -200,28 +273,86 @@ def __call__( eta: Optional[float] = 0.0, generator: Optional[np.random.RandomState] = None, latents: Optional[np.ndarray] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, np.ndarray], None]] = None, callback_steps: int = 1, ): - if isinstance(prompt, str): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`): + `Image`, or tensor representing an image batch which will be upscaled. * + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`np.random.RandomState`, *optional*): + One or a list of [numpy generator(s)](TODO) to make generation deterministic. + latents (`np.ndarray`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # define call parameters + if prompt is not None and isinstance(prompt, str): batch_size = 1 - elif isinstance(prompt, list): + elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) + batch_size = prompt_embeds.shape[0] if generator is None: generator = np.random @@ -232,7 +363,12 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 prompt_embeds = self._encode_prompt( - prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, ) # get the initial random noise unless the user supplied it diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py index 80c4a8692a05..67d3f44e6d4b 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py @@ -161,7 +161,15 @@ def __init__( self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt - def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + def _encode_prompt( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: Optional[int], + do_classifier_free_guidance: bool, + negative_prompt: Optional[str], + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + ): r""" Encodes the prompt into text encoder hidden states. @@ -175,32 +183,48 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida negative_prompt (`str` or `List[str]`): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - # get prompt text embeddings - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="np", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] - if not np.array_equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" + if prompt_embeds is None: + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids + + if not np.array_equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] - prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: + if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size @@ -220,7 +244,7 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida else: uncond_tokens = negative_prompt - max_length = text_input_ids.shape[-1] + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", @@ -229,6 +253,8 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida return_tensors="np", ) negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] + + if do_classifier_free_guidance: negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) # For classifier free guidance, we need to do two forward passes. @@ -238,6 +264,48 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida return prompt_embeds + def check_inputs( + self, + prompt: Union[str, List[str]], + callback_steps: int, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + def __call__( self, prompt: Union[str, List[str]], @@ -249,6 +317,8 @@ def __call__( num_images_per_prompt: Optional[int] = 1, eta: Optional[float] = 0.0, generator: Optional[np.random.RandomState] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, np.ndarray], None]] = None, @@ -288,6 +358,13 @@ def __call__( [`schedulers.DDIMScheduler`], will be ignored for others. generator (`np.random.RandomState`, *optional*): A np.random.RandomState to make generation deterministic. + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -308,24 +385,21 @@ def __call__( list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ - if isinstance(prompt, str): + + # check inputs. Raise error if not correct + self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # define call parameters + if prompt is not None and isinstance(prompt, str): batch_size = 1 - elif isinstance(prompt, list): + elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + batch_size = prompt_embeds.shape[0] if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - if generator is None: generator = np.random @@ -340,7 +414,12 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 prompt_embeds = self._encode_prompt( - prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, ) latents_dtype = prompt_embeds.dtype diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py index df586d39f648..0bb39c4b1c61 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py @@ -162,7 +162,15 @@ def __init__( self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt - def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + def _encode_prompt( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: Optional[int], + do_classifier_free_guidance: bool, + negative_prompt: Optional[str], + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + ): r""" Encodes the prompt into text encoder hidden states. @@ -176,32 +184,48 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida negative_prompt (`str` or `List[str]`): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - # get prompt text embeddings - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="np", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] - if not np.array_equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" + if prompt_embeds is None: + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids + + if not np.array_equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] - prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: + if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size @@ -221,7 +245,7 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida else: uncond_tokens = negative_prompt - max_length = text_input_ids.shape[-1] + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", @@ -230,6 +254,8 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida return_tensors="np", ) negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] + + if do_classifier_free_guidance: negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) # For classifier free guidance, we need to do two forward passes. @@ -239,6 +265,54 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida return prompt_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt: Union[str, List[str]], + height: Optional[int], + width: Optional[int], + callback_steps: int, + negative_prompt: Optional[str] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + @torch.no_grad() def __call__( self, @@ -254,6 +328,8 @@ def __call__( eta: float = 0.0, generator: Optional[np.random.RandomState] = None, latents: Optional[np.ndarray] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, np.ndarray], None]] = None, @@ -300,6 +376,13 @@ def __call__( Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -321,23 +404,18 @@ def __call__( (nsfw) content, according to the `safety_checker`. """ - if isinstance(prompt, str): + # check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # define call parameters + if prompt is not None and isinstance(prompt, str): batch_size = 1 - elif isinstance(prompt, list): + elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) + batch_size = prompt_embeds.shape[0] if generator is None: generator = np.random @@ -351,7 +429,12 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 prompt_embeds = self._encode_prompt( - prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, ) num_channels_latents = NUM_LATENT_CHANNELS diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py index 5cb3abb4f54e..8ef7a781451c 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py @@ -147,7 +147,15 @@ def __init__( self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt - def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + def _encode_prompt( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: Optional[int], + do_classifier_free_guidance: bool, + negative_prompt: Optional[str], + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + ): r""" Encodes the prompt into text encoder hidden states. @@ -161,32 +169,48 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida negative_prompt (`str` or `List[str]`): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - # get prompt text embeddings - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="np", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] - if not np.array_equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" + if prompt_embeds is None: + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids + + if not np.array_equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] - prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: + if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size @@ -206,7 +230,7 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida else: uncond_tokens = negative_prompt - max_length = text_input_ids.shape[-1] + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", @@ -215,6 +239,8 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida return_tensors="np", ) negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] + + if do_classifier_free_guidance: negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) # For classifier free guidance, we need to do two forward passes. @@ -224,6 +250,48 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida return prompt_embeds + def check_inputs( + self, + prompt, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + def __call__( self, prompt: Union[str, List[str]], @@ -236,6 +304,8 @@ def __call__( num_images_per_prompt: Optional[int] = 1, eta: Optional[float] = 0.0, generator: Optional[np.random.RandomState] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, np.ndarray], None]] = None, @@ -280,6 +350,13 @@ def __call__( [`schedulers.DDIMScheduler`], will be ignored for others. generator (`np.random.RandomState`, *optional*): A np.random.RandomState to make generation deterministic. + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -300,24 +377,21 @@ def __call__( list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ - if isinstance(prompt, str): + + # check inputs. Raise error if not correct + self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # define call parameters + if prompt is not None and isinstance(prompt, str): batch_size = 1 - elif isinstance(prompt, list): + elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + batch_size = prompt_embeds.shape[0] if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - if generator is None: generator = np.random @@ -333,7 +407,12 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 prompt_embeds = self._encode_prompt( - prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, ) latents_dtype = prompt_embeds.dtype diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py index b91262551b0f..8db19c2b9109 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py @@ -70,16 +70,85 @@ def __call__( eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: Optional[int] = 1, ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image (`np.ndarray` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + noise_level TODO + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`np.random.RandomState`, *optional*): + A np.random.RandomState to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 1. Check inputs self.check_inputs(prompt, image, noise_level, callback_steps) # 2. Define call parameters - batch_size = 1 if isinstance(prompt, str) else len(prompt) + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` @@ -88,7 +157,13 @@ def __call__( # 3. Encode input prompt text_embeddings = self._encode_prompt( - prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, ) latents_dtype = ORT_TO_PT_TYPE[str(text_embeddings.dtype)] @@ -199,45 +274,59 @@ def decode_latents(self, latents): image = image.transpose((0, 2, 3, 1)) return image - def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" + def _encode_prompt( + self, + prompt: Union[str, List[str]], + device, + num_images_per_prompt: Optional[int], + do_classifier_free_guidance: bool, + negative_prompt: Optional[str], + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) - # if hasattr(text_inputs, "attention_mask"): - # attention_mask = text_inputs.attention_mask.to(device) - # else: - # attention_mask = None - - # no positional arguments to text_encoder - text_embeddings = self.text_encoder( - input_ids=text_input_ids.int().to(device), - # attention_mask=attention_mask, - ) - text_embeddings = text_embeddings[0] + # no positional arguments to text_encoder + prompt_embeds = self.text_encoder( + input_ids=text_input_ids.int().to(device), + # attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] - bs_embed, seq_len, _ = text_embeddings.shape + bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method - text_embeddings = text_embeddings.repeat(1, num_images_per_prompt) - text_embeddings = text_embeddings.reshape(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.reshape(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: + if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size @@ -277,6 +366,7 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr ) uncond_embeddings = uncond_embeddings[0] + if do_classifier_free_guidance: seq_len = uncond_embeddings.shape[1] # duplicate unconditional embeddings for each generation per prompt, using mps friendly method uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt) @@ -285,6 +375,6 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes - text_embeddings = np.concatenate([uncond_embeddings, text_embeddings]) + prompt_embeds = np.concatenate([uncond_embeddings, prompt_embeds]) - return text_embeddings + return prompt_embeds diff --git a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py index 74783faae421..3a5f9379ae50 100644 --- a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py @@ -133,6 +133,76 @@ def test_pipeline_dpm_multistep(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_stable_diffusion_prompt_embeds(self): + pipe = OnnxStableDiffusionPipeline.from_pretrained(self.hub_checkpoint, provider="CPUExecutionProvider") + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs() + inputs["prompt"] = 3 * [inputs["prompt"]] + + # forward + output = pipe(**inputs) + image_slice_1 = output.images[0, -3:, -3:, -1] + + inputs = self.get_dummy_inputs() + prompt = 3 * [inputs.pop("prompt")] + + text_inputs = pipe.tokenizer( + prompt, + padding="max_length", + max_length=pipe.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_inputs = text_inputs["input_ids"] + + prompt_embeds = pipe.text_encoder(input_ids=text_inputs.astype(np.int32))[0] + + inputs["prompt_embeds"] = prompt_embeds + + # forward + output = pipe(**inputs) + image_slice_2 = output.images[0, -3:, -3:, -1] + + assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4 + + def test_stable_diffusion_negative_prompt_embeds(self): + pipe = OnnxStableDiffusionPipeline.from_pretrained(self.hub_checkpoint, provider="CPUExecutionProvider") + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs() + negative_prompt = 3 * ["this is a negative prompt"] + inputs["negative_prompt"] = negative_prompt + inputs["prompt"] = 3 * [inputs["prompt"]] + + # forward + output = pipe(**inputs) + image_slice_1 = output.images[0, -3:, -3:, -1] + + inputs = self.get_dummy_inputs() + prompt = 3 * [inputs.pop("prompt")] + + embeds = [] + for p in [prompt, negative_prompt]: + text_inputs = pipe.tokenizer( + p, + padding="max_length", + max_length=pipe.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_inputs = text_inputs["input_ids"] + + embeds.append(pipe.text_encoder(input_ids=text_inputs.astype(np.int32))[0]) + + inputs["prompt_embeds"], inputs["negative_prompt_embeds"] = embeds + + # forward + output = pipe(**inputs) + image_slice_2 = output.images[0, -3:, -3:, -1] + + assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4 + @nightly @require_onnxruntime