diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index 68782d1f5f79..17c083914753 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -30,7 +30,8 @@ class VaeImageProcessor(ConfigMixin): Args: do_resize (`bool`, *optional*, defaults to `True`): - Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. + Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept + `height` and `width` arguments from the `preprocess` method. vae_scale_factor (`int`, *optional*, defaults to `8`): VAE scale factor. If `do_resize` is True, the image will be automatically resized to multiples of this factor. @@ -38,6 +39,8 @@ class VaeImageProcessor(ConfigMixin): Resampling filter to use when resizing the image. do_normalize (`bool`, *optional*, defaults to `True`): Whether to normalize the image to [-1,1] + do_convert_rgb (`bool`, *optional*, defaults to `False`): + Whether to convert the images to RGB format. """ config_name = CONFIG_NAME @@ -49,11 +52,12 @@ def __init__( vae_scale_factor: int = 8, resample: str = "lanczos", do_normalize: bool = True, + do_convert_rgb: bool = False, ): super().__init__() @staticmethod - def numpy_to_pil(images): + def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image: """ Convert a numpy image or a batch of images to a PIL image. """ @@ -69,7 +73,19 @@ def numpy_to_pil(images): return pil_images @staticmethod - def numpy_to_pt(images): + def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray: + """ + Convert a PIL image or a list of PIL images to numpy arrays. + """ + if not isinstance(images, list): + images = [images] + images = [np.array(image).astype(np.float32) / 255.0 for image in images] + images = np.stack(images, axis=0) + + return images + + @staticmethod + def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor: """ Convert a numpy image to a pytorch tensor """ @@ -80,7 +96,7 @@ def numpy_to_pt(images): return images @staticmethod - def pt_to_numpy(images): + def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray: """ Convert a pytorch tensor to a numpy image """ @@ -101,18 +117,39 @@ def denormalize(images): """ return (images / 2 + 0.5).clamp(0, 1) - def resize(self, images: PIL.Image.Image) -> PIL.Image.Image: + @staticmethod + def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image: + """ + Converts an image to RGB format. + """ + image = image.convert("RGB") + return image + + def resize( + self, + image: PIL.Image.Image, + height: Optional[int] = None, + width: Optional[int] = None, + ) -> PIL.Image.Image: """ Resize a PIL image. 
Both height and width will be downscaled to the next integer multiple of `vae_scale_factor` """ - w, h = images.size - w, h = (x - x % self.config.vae_scale_factor for x in (w, h)) # resize to integer multiple of vae_scale_factor - images = images.resize((w, h), resample=PIL_INTERPOLATION[self.config.resample]) - return images + if height is None: + height = image.height + if width is None: + width = image.width + + width, height = ( + x - x % self.config.vae_scale_factor for x in (width, height) + ) # resize to integer multiple of vae_scale_factor + image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample]) + return image def preprocess( self, image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray], + height: Optional[int] = None, + width: Optional[int] = None, ) -> torch.Tensor: """ Preprocess the image input, accepted formats are PIL images, numpy arrays or pytorch tensors" @@ -126,10 +163,11 @@ def preprocess( ) if isinstance(image[0], PIL.Image.Image): + if self.config.do_convert_rgb: + image = [self.convert_to_rgb(i) for i in image] if self.config.do_resize: - image = [self.resize(i) for i in image] - image = [np.array(i).astype(np.float32) / 255.0 for i in image] - image = np.stack(image, axis=0) # to np + image = [self.resize(i, height, width) for i in image] + image = self.pil_to_numpy(image) # to np image = self.numpy_to_pt(image) # to pt elif isinstance(image[0], np.ndarray): @@ -146,7 +184,12 @@ def preprocess( elif isinstance(image[0], torch.Tensor): image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) - _, _, height, width = image.shape + _, channel, height, width = image.shape + + # don't need any preprocess if the image is latents + if channel == 4: + return image + if self.config.do_resize and ( height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0 ): diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index b10d85f722eb..f0d4d91ce966 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -69,6 +69,11 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess def preprocess(image): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) if isinstance(image, torch.Tensor): return image elif isinstance(image, PIL.Image.Image): @@ -538,21 +543,26 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt image = image.to(device=device, dtype=dtype) batch_size = batch_size * num_images_per_prompt - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
- ) - if isinstance(generator, list): - init_latents = [ - self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) - ] - init_latents = torch.cat(init_latents, dim=0) + if image.shape[1] == 4: + init_latents = image + else: - init_latents = self.vae.encode(image).latent_dist.sample(generator) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective" + f" batch size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + init_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.vae.encode(image).latent_dist.sample(generator) - init_latents = self.vae.config.scaling_factor * init_latents + init_latents = self.vae.config.scaling_factor * init_latents if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: # expand init_latents for batch_size @@ -586,7 +596,14 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt def __call__( self, prompt: Union[str, List[str]] = None, - image: Union[torch.FloatTensor, PIL.Image.Image] = None, + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, strength: float = 0.8, num_inference_steps: Optional[int] = 50, guidance_scale: Optional[float] = 7.5, @@ -609,9 +626,10 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. - image (`torch.FloatTensor` or `PIL.Image.Image`): + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. + process. Can also accept image latents as `image`; if latents are passed directly, they will not be encoded + again. strength (`float`, *optional*, defaults to 0.8): Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the `strength`. 
The number of diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index 632cd546ed0a..4ac43377c82a 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -29,7 +29,6 @@ from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( - PIL_INTERPOLATION, is_accelerate_available, is_accelerate_version, is_compiled_module, @@ -172,7 +171,10 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing @@ -477,17 +479,12 @@ def check_inputs( self, prompt, image, - height, - width, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, controlnet_conditioning_scale=1.0, ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): @@ -592,21 +589,26 @@ def check_inputs( def check_image(self, image, prompt, prompt_embeds): image_is_pil = isinstance(image, PIL.Image.Image) image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) - if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list: + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): raise TypeError( - "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors" + "image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors" ) if image_is_pil: image_batch_size = 1 - elif image_is_tensor: - image_batch_size = image.shape[0] - elif image_is_pil_list: - image_batch_size = len(image) - elif image_is_tensor_list: + else: image_batch_size = len(image) if prompt is not None and isinstance(prompt, str): @@ -633,29 +635,7 @@ def prepare_image( do_classifier_free_guidance=False, guess_mode=False, ): - if not isinstance(image, torch.Tensor): - if isinstance(image, PIL.Image.Image): - image = [image] - - if isinstance(image[0], PIL.Image.Image): - images = [] - - for image_ in image: - image_ = image_.convert("RGB") - image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]) - image_ = np.array(image_) - image_ = image_[None, :] - images.append(image_) - - image = images - - image = np.concatenate(image, axis=0) - 
image = np.array(image).astype(np.float32) / 255.0 - image = image.transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - elif isinstance(image[0], torch.Tensor): - image = torch.cat(image, dim=0) - + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) image_batch_size = image.shape[0] if image_batch_size == 1: @@ -691,31 +671,6 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype latents = latents * self.scheduler.init_noise_sigma return latents - def _default_height_width(self, height, width, image): - # NOTE: It is possible that a list of images have different - # dimensions for each image, so just checking the first image - # is not _exactly_ correct, but it is simple. - while isinstance(image, list): - image = image[0] - - if height is None: - if isinstance(image, PIL.Image.Image): - height = image.height - elif isinstance(image, torch.Tensor): - height = image.shape[2] - - height = (height // 8) * 8 # round down to nearest multiple of 8 - - if width is None: - if isinstance(image, PIL.Image.Image): - width = image.width - elif isinstance(image, torch.Tensor): - width = image.shape[3] - - width = (width // 8) * 8 # round down to nearest multiple of 8 - - return height, width - # override DiffusionPipeline def save_pretrained( self, @@ -733,7 +688,14 @@ def save_pretrained( def __call__( self, prompt: Union[str, List[str]] = None, - image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None, + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, @@ -760,8 +722,8 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. - image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, - `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`): + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If @@ -837,15 +799,11 @@ def __call__( list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ - # 0. Default height and width to unet - height, width = self._default_height_width(height, width, image) # 1. Check inputs. 
Raise error if not correct self.check_inputs( prompt, image, - height, - width, callback_steps, negative_prompt, prompt_embeds, @@ -903,6 +861,7 @@ def __call__( do_classifier_free_guidance=do_classifier_free_guidance, guess_mode=guess_mode, ) + height, width = image.shape[-2:] elif isinstance(controlnet, MultiControlNetModel): images = [] @@ -922,6 +881,7 @@ def __call__( images.append(image_) image = images + height, width = image[0].shape[-2:] else: assert False diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py index 72b90f334725..6667cf43ce46 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -29,7 +29,6 @@ from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( - PIL_INTERPOLATION, deprecate, is_accelerate_available, is_accelerate_version, @@ -198,7 +197,10 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing @@ -503,17 +505,12 @@ def check_inputs( self, prompt, image, - height, - width, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, controlnet_conditioning_scale=1.0, ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): @@ -615,24 +612,30 @@ def check_inputs( else: assert False + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image def check_image(self, image, prompt, prompt_embeds): image_is_pil = isinstance(image, PIL.Image.Image) image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) - if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list: + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): raise TypeError( - "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors" + "image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors" ) if image_is_pil: image_batch_size = 1 - elif image_is_tensor: - image_batch_size = image.shape[0] - elif image_is_pil_list: - image_batch_size = len(image) - elif image_is_tensor_list: + else: 
image_batch_size = len(image) if prompt is not None and isinstance(prompt, str): @@ -660,29 +663,7 @@ def prepare_control_image( do_classifier_free_guidance=False, guess_mode=False, ): - if not isinstance(image, torch.Tensor): - if isinstance(image, PIL.Image.Image): - image = [image] - - if isinstance(image[0], PIL.Image.Image): - images = [] - - for image_ in image: - image_ = image_.convert("RGB") - image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]) - image_ = np.array(image_) - image_ = image_[None, :] - images.append(image_) - - image = images - - image = np.concatenate(image, axis=0) - image = np.array(image).astype(np.float32) / 255.0 - image = image.transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - elif isinstance(image[0], torch.Tensor): - image = torch.cat(image, dim=0) - + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) image_batch_size = image.shape[0] if image_batch_size == 1: @@ -720,21 +701,26 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt image = image.to(device=device, dtype=dtype) batch_size = batch_size * num_images_per_prompt - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - if isinstance(generator, list): - init_latents = [ - self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) - ] - init_latents = torch.cat(init_latents, dim=0) + if image.shape[1] == 4: + init_latents = image + else: - init_latents = self.vae.encode(image).latent_dist.sample(generator) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) - init_latents = self.vae.config.scaling_factor * init_latents + elif isinstance(generator, list): + init_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.vae.encode(image).latent_dist.sample(generator) + + init_latents = self.vae.config.scaling_factor * init_latents if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: # expand init_latents for batch_size @@ -763,31 +749,6 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt return latents - def _default_height_width(self, height, width, image): - # NOTE: It is possible that a list of images have different - # dimensions for each image, so just checking the first image - # is not _exactly_ correct, but it is simple. 
- while isinstance(image, list): - image = image[0] - - if height is None: - if isinstance(image, PIL.Image.Image): - height = image.height - elif isinstance(image, torch.Tensor): - height = image.shape[2] - - height = (height // 8) * 8 # round down to nearest multiple of 8 - - if width is None: - if isinstance(image, PIL.Image.Image): - width = image.width - elif isinstance(image, torch.Tensor): - width = image.shape[3] - - width = (width // 8) * 8 # round down to nearest multiple of 8 - - return height, width - # override DiffusionPipeline def save_pretrained( self, @@ -805,9 +766,21 @@ def save_pretrained( def __call__( self, prompt: Union[str, List[str]] = None, - image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None, + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, control_image: Union[ - torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image] + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], ] = None, height: Optional[int] = None, width: Optional[int] = None, @@ -836,8 +809,12 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. - image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, - `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`): + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The initial image will be used as the starting point for the image generation process. Can also accept + image latents as `image`; if latents are passed directly, they will not be encoded again. + control_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If @@ -914,15 +891,10 @@ def __call__( list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ - # 0. Default height and width to unet - height, width = self._default_height_width(height, width, image) - # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, control_image, - height, - width, callback_steps, negative_prompt, prompt_embeds, @@ -966,10 +938,10 @@ def __call__( prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) - # 4. Prepare image, and controlnet_conditioning_image - image = prepare_image(image) + # 4. Prepare image + image = self.image_processor.preprocess(image).to(dtype=torch.float32) - # 5. Prepare image + # 5. 
Prepare controlnet_conditioning_image if isinstance(controlnet, ControlNetModel): control_image = self.prepare_control_image( image=control_image, diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index f57d88bd8d8a..e43ee3a1e5ec 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -30,7 +30,6 @@ from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( - PIL_INTERPOLATION, is_accelerate_available, is_accelerate_version, is_compiled_module, @@ -316,6 +315,9 @@ def __init__( ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing @@ -742,24 +744,30 @@ def check_inputs( else: assert False + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image def check_image(self, image, prompt, prompt_embeds): image_is_pil = isinstance(image, PIL.Image.Image) image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) - if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list: + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): raise TypeError( - "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors" + "image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors" ) if image_is_pil: image_batch_size = 1 - elif image_is_tensor: - image_batch_size = image.shape[0] - elif image_is_pil_list: - image_batch_size = len(image) - elif image_is_tensor_list: + else: image_batch_size = len(image) if prompt is not None and isinstance(prompt, str): @@ -787,29 +795,7 @@ def prepare_control_image( do_classifier_free_guidance=False, guess_mode=False, ): - if not isinstance(image, torch.Tensor): - if isinstance(image, PIL.Image.Image): - image = [image] - - if isinstance(image[0], PIL.Image.Image): - images = [] - - for image_ in image: - image_ = image_.convert("RGB") - image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]) - image_ = np.array(image_) - image_ = image_[None, :] - images.append(image_) - - image = images - - image = np.concatenate(image, axis=0) - image = np.array(image).astype(np.float32) / 255.0 - image = image.transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - elif isinstance(image[0], torch.Tensor): - image = torch.cat(image, dim=0) - + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) 
image_batch_size = image.shape[0] if image_batch_size == 1: @@ -982,7 +968,12 @@ def __call__( image: Union[torch.Tensor, PIL.Image.Image] = None, mask_image: Union[torch.Tensor, PIL.Image.Image] = None, control_image: Union[ - torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image] + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], ] = None, height: Optional[int] = None, width: Optional[int] = None, diff --git a/src/diffusers/pipelines/repaint/pipeline_repaint.py b/src/diffusers/pipelines/repaint/pipeline_repaint.py index f4914c46db51..d2aa1d4f1f77 100644 --- a/src/diffusers/pipelines/repaint/pipeline_repaint.py +++ b/src/diffusers/pipelines/repaint/pipeline_repaint.py @@ -13,6 +13,7 @@ # limitations under the License. +import warnings from typing import List, Optional, Tuple, Union import numpy as np @@ -30,6 +31,11 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess def _preprocess_image(image: Union[List, PIL.Image.Image, torch.Tensor]): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) if isinstance(image, torch.Tensor): return image elif isinstance(image, PIL.Image.Image): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py index 8babc6ab0d11..6b6df0945943 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py @@ -40,6 +40,11 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess def preprocess(image): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) if isinstance(image, torch.Tensor): return image elif isinstance(image, PIL.Image.Image): @@ -549,21 +554,26 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt image = image.to(device=device, dtype=dtype) batch_size = image.shape[0] - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - if isinstance(generator, list): - init_latents = [ - self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) - ] - init_latents = torch.cat(init_latents, dim=0) + if image.shape[1] == 4: + init_latents = image + else: - init_latents = self.vae.encode(image).latent_dist.sample(generator) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if isinstance(generator, list): + init_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.vae.encode(image).latent_dist.sample(generator) - init_latents = self.vae.config.scaling_factor * init_latents + init_latents = self.vae.config.scaling_factor * init_latents if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: # expand init_latents for batch_size @@ -599,7 +609,14 @@ def __call__( self, prompt: Union[str, List[str]], source_prompt: Union[str, List[str]], - image: Union[torch.FloatTensor, PIL.Image.Image] = None, + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, strength: float = 0.8, num_inference_steps: Optional[int] = 50, guidance_scale: Optional[float] = 7.5, @@ -619,9 +636,10 @@ def __call__( Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. - image (`torch.FloatTensor` or `PIL.Image.Image`): + image (`torch.FloatTensor`, `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. + process. Can also accept image latents as `image`; if latents are passed directly, they will not be encoded + again. strength (`float`, *optional*, defaults to 0.8): Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the `strength`. The number of @@ -699,7 +717,7 @@ def __call__( ) # 4. Preprocess image - image = preprocess(image) + image = self.image_processor.preprocess(image) # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py index 67d3f44e6d4b..293ed7d981b8 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py @@ -13,6 +13,7 @@ # limitations under the License. import inspect +import warnings from typing import Callable, List, Optional, Union import numpy as np @@ -33,6 +34,13 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess with 8->64 def preprocess(image): + warnings.warn( + ( + "The preprocess method is deprecated and will be removed in a future version. 
Please" + " use VaeImageProcessor.preprocess instead" + ), + FutureWarning, + ) if isinstance(image, torch.Tensor): return image elif isinstance(image, PIL.Image.Image): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py index a5b2a9987fa1..2fd4503a94ce 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -37,6 +37,11 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess def preprocess(image): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) if isinstance(image, torch.Tensor): return image elif isinstance(image, PIL.Image.Image): @@ -423,21 +428,26 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt image = image.to(device=device, dtype=dtype) batch_size = batch_size * num_images_per_prompt - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - if isinstance(generator, list): - init_latents = [ - self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) - ] - init_latents = torch.cat(init_latents, dim=0) + if image.shape[1] == 4: + init_latents = image + else: - init_latents = self.vae.encode(image).latent_dist.sample(generator) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + init_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.vae.encode(image).latent_dist.sample(generator) - init_latents = self.vae.config.scaling_factor * init_latents + init_latents = self.vae.config.scaling_factor * init_latents if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: # expand init_latents for batch_size @@ -474,6 +484,8 @@ def prepare_depth_map(self, image, depth_map, batch_size, do_classifier_free_gui if isinstance(image[0], PIL.Image.Image): width, height = image[0].size + elif isinstance(image[0], np.ndarray): + width, height = image[0].shape[:-1] else: height, width = image[0].shape[-2:] @@ -512,7 +524,14 @@ def prepare_depth_map(self, image, depth_map, batch_size, do_classifier_free_gui def __call__( self, prompt: Union[str, List[str]] = None, - image: Union[torch.FloatTensor, PIL.Image.Image] = None, + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, depth_map: Optional[torch.FloatTensor] = None, strength: float = 0.8, num_inference_steps: Optional[int] = 50, @@ -535,9 +554,12 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. 
- image (`torch.FloatTensor` or `PIL.Image.Image`): + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. + process. Can accept image latents as `image` only if `depth_map` is not `None`. + depth_map (`torch.FloatTensor`, *optional*): + Depth prediction that will be used as additional conditioning for the image generation process. If not + defined, it will automatically predict the depth via `self.depth_estimator`. strength (`float`, *optional*, defaults to 0.8): Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the `strength`. The number of @@ -664,7 +686,7 @@ def __call__( ) # 5. Preprocess image - image = preprocess(image) + image = self.image_processor.preprocess(image) # 6. Set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py index e4fc08b79cfd..3c1ac58bcee4 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py @@ -159,6 +159,11 @@ def kl_divergence(hidden_states): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess def preprocess(image): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) if isinstance(image, torch.Tensor): return image elif isinstance(image, PIL.Image.Image): @@ -799,19 +804,25 @@ def prepare_image_latents(self, image, batch_size, dtype, device, generator=None image = image.to(device=device, dtype=dtype) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) + if image.shape[1] == 4: + latents = image - if isinstance(generator, list): - latents = [self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)] - latents = torch.cat(latents, dim=0) else: - latents = self.vae.encode(image).latent_dist.sample(generator) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if isinstance(generator, list): + latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0) + else: + latents = self.vae.encode(image).latent_dist.sample(generator) - latents = self.vae.config.scaling_factor * latents + latents = self.vae.config.scaling_factor * latents if batch_size != latents.shape[0]: if batch_size % latents.shape[0] == 0: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 258c8000ba63..106b6528a982 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -73,6 +73,11 @@ def preprocess(image): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) if isinstance(image, torch.Tensor): return image elif isinstance(image, PIL.Image.Image): @@ -441,6 +446,7 @@ def _encode_prompt( return prompt_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: has_nsfw_concept = None @@ -455,6 +461,7 @@ def run_safety_checker(self, image, device, dtype): ) return image, has_nsfw_concept + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): warnings.warn( "The decode_latents method is deprecated and will be removed in a future version. Please" @@ -544,21 +551,26 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt image = image.to(device=device, dtype=dtype) batch_size = batch_size * num_images_per_prompt - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - if isinstance(generator, list): - init_latents = [ - self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) - ] - init_latents = torch.cat(init_latents, dim=0) + if image.shape[1] == 4: + init_latents = image + else: - init_latents = self.vae.encode(image).latent_dist.sample(generator) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + elif isinstance(generator, list): + init_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.vae.encode(image).latent_dist.sample(generator) - init_latents = self.vae.config.scaling_factor * init_latents + init_latents = self.vae.config.scaling_factor * init_latents if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: # expand init_latents for batch_size @@ -592,7 +604,14 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt def __call__( self, prompt: Union[str, List[str]] = None, - image: Union[torch.FloatTensor, PIL.Image.Image] = None, + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, strength: float = 0.8, num_inference_steps: Optional[int] = 50, guidance_scale: Optional[float] = 7.5, @@ -615,9 +634,10 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. - image (`torch.FloatTensor` or `PIL.Image.Image`): + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. + process. Can also accept image latents as `image`; if latents are passed directly, they will not be encoded + again. strength (`float`, *optional*, defaults to 0.8): Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the `strength`. The number of diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py index 65ef5617fc68..25102ae7cf4a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py @@ -43,6 +43,11 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess def preprocess(image): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) if isinstance(image, torch.Tensor): return image elif isinstance(image, PIL.Image.Image): @@ -145,7 +150,14 @@ def __init__( def __call__( self, prompt: Union[str, List[str]] = None, - image: Union[torch.FloatTensor, PIL.Image.Image] = None, + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, num_inference_steps: int = 100, guidance_scale: float = 7.5, image_guidance_scale: float = 1.5, @@ -168,8 +180,9 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. - image (`PIL.Image.Image`): - `Image`, or tensor representing an image batch which will be repainted according to `prompt`. 
+ image (`torch.FloatTensor`, `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, or tensor representing an image batch which will be repainted according to `prompt`. Can also + accept image latents as `image`; if latents are passed directly, they will not be encoded again. num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. @@ -290,8 +303,7 @@ def __call__( ) # 3. Preprocess image - image = preprocess(image) - height, width = image.shape[-2:] + image = self.image_processor.preprocess(image) # 4. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) @@ -308,6 +320,10 @@ def __call__( generator, ) + height, width = image_latents.shape[-2:] + height = height * self.vae_scale_factor + width = width * self.vae_scale_factor + # 6. Prepare latent variables num_channels_latents = self.vae.config.latent_channels latents = self.prepare_latents( @@ -746,17 +762,21 @@ def prepare_image_latents( image = image.to(device=device, dtype=dtype) batch_size = batch_size * num_images_per_prompt - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - if isinstance(generator, list): - image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)] - image_latents = torch.cat(image_latents, dim=0) + if image.shape[1] == 4: + image_latents = image else: - image_latents = self.vae.encode(image).latent_dist.mode() + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if isinstance(generator, list): + image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.vae.encode(image).latent_dist.mode() if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: # expand image_latents for batch_size diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py index 664d58dc812f..e0fecf6d353f 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py @@ -94,7 +94,7 @@ def __init__( scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic") def enable_sequential_cpu_offload(self, gpu_id=0): r""" @@ -291,7 +291,14 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype def __call__( self, prompt: Union[str, List[str]], - image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]], + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, num_inference_steps: int = 75, guidance_scale: float = 9.0, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -308,7 +315,7 @@ def __call__( Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image upscaling. - image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`): + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, or tensor representing an image batch which will be upscaled. If it's a tensor, it can be either a latent output from a stable diffusion model, or an image tensor in the range `[-1, 1]`. It will be considered a `latent` if `image.shape[1]` is `4`; otherwise, it will be considered to be an @@ -413,7 +420,7 @@ def __call__( ) # 4. Preprocess image - image = preprocess(image) + image = self.image_processor.preprocess(image) image = image.to(dtype=text_embeddings.dtype, device=device) if image.shape[1] == 3: # encode image if not in latent-space yet diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 3b7c6dc6b513..3332cc89d96c 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -177,6 +177,11 @@ class Pix2PixInversionPipelineOutput(BaseOutput, TextualInversionLoaderMixin): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess def preprocess(image): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. 
Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) if isinstance(image, torch.Tensor): return image elif isinstance(image, PIL.Image.Image): @@ -629,7 +634,6 @@ def prepare_extra_step_kwargs(self, generator, eta): def check_inputs( self, prompt, - image, source_embeds, target_embeds, callback_steps, @@ -727,19 +731,25 @@ def prepare_image_latents(self, image, batch_size, dtype, device, generator=None image = image.to(device=device, dtype=dtype) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) + if image.shape[1] == 4: + latents = image - if isinstance(generator, list): - latents = [self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)] - latents = torch.cat(latents, dim=0) else: - latents = self.vae.encode(image).latent_dist.sample(generator) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) - latents = self.vae.config.scaling_factor * latents + if isinstance(generator, list): + latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0) + else: + latents = self.vae.encode(image).latent_dist.sample(generator) + + latents = self.vae.config.scaling_factor * latents if batch_size != latents.shape[0]: if batch_size % latents.shape[0] == 0: @@ -804,7 +814,6 @@ def kl_divergence(self, hidden_states): def __call__( self, prompt: Optional[Union[str, List[str]]] = None, - image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, source_embeds: torch.Tensor = None, target_embeds: torch.Tensor = None, height: Optional[int] = None, @@ -905,7 +914,6 @@ def __call__( # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, - image, source_embeds, target_embeds, callback_steps, @@ -1085,7 +1093,14 @@ def __call__( def invert( self, prompt: Optional[str] = None, - image: Union[torch.FloatTensor, PIL.Image.Image] = None, + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, num_inference_steps: int = 50, guidance_scale: float = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -1109,8 +1124,9 @@ def invert( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. - image (`PIL.Image.Image`, *optional*): - `Image`, or tensor representing an image batch which will be used for conditioning. + image (`torch.FloatTensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, or tensor representing an image batch which will be used for conditioning. Can also accpet + image latents as `image`, if passing latents directly, it will not be encoded again. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. 
@@ -1179,7 +1195,7 @@ def invert( do_classifier_free_guidance = guidance_scale > 1.0 # 3. Preprocess image - image = preprocess(image) + image = self.image_processor.preprocess(image) # 4. Prepare latent variables latents = self.prepare_image_latents(image, batch_size, self.vae.dtype, device, generator) @@ -1267,16 +1283,13 @@ def invert( inverted_latents = latents.detach().clone() # 8. Post-processing - image = self.decode_latents(latents.detach()) + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() - # 9. Convert to PIL. - if output_type == "pil": - image = self.image_processor.numpy_to_pil(image) - if not return_dict: return (inverted_latents, image) diff --git a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py index 36e5411b4215..ecc457b4cb94 100644 --- a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +++ b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py @@ -1,4 +1,5 @@ import inspect +import warnings from dataclasses import dataclass from typing import Callable, List, Optional, Union @@ -34,6 +35,11 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess def preprocess(image): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) if isinstance(image, torch.Tensor): return image elif isinstance(image, PIL.Image.Image): diff --git a/tests/pipelines/altdiffusion/test_alt_diffusion.py b/tests/pipelines/altdiffusion/test_alt_diffusion.py index 6842d29dc6c0..1344d33a2552 100644 --- a/tests/pipelines/altdiffusion/test_alt_diffusion.py +++ b/tests/pipelines/altdiffusion/test_alt_diffusion.py @@ -40,6 +40,7 @@ class AltDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMix params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/controlnet/test_controlnet.py b/tests/pipelines/controlnet/test_controlnet.py index b2312a4e94d0..9915998be24e 100644 --- a/tests/pipelines/controlnet/test_controlnet.py +++ b/tests/pipelines/controlnet/test_controlnet.py @@ -41,7 +41,9 @@ ) from ..pipeline_params import ( + IMAGE_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_BATCH_PARAMS, + TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS, ) from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin @@ -99,7 +101,8 @@ class ControlNetPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin pipeline_class = StableDiffusionControlNetPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - image_params = frozenset([]) # TO_DO: add image_params once refactored VaeImageProcessor.preprocess + image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/controlnet/test_controlnet_img2img.py b/tests/pipelines/controlnet/test_controlnet_img2img.py index 9d3b10aa8283..de8f578a3cce 100644 --- a/tests/pipelines/controlnet/test_controlnet_img2img.py +++ 
b/tests/pipelines/controlnet/test_controlnet_img2img.py @@ -38,6 +38,7 @@ from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu from ..pipeline_params import ( + IMAGE_TO_IMAGE_IMAGE_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS, ) @@ -51,7 +52,8 @@ class ControlNetImg2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTest pipeline_class = StableDiffusionControlNetImg2ImgPipeline params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"} batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS - image_params = frozenset([]) # TO_DO: add image_params once refactored VaeImageProcessor.preprocess + image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS.union({"control_image"}) + image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint.py b/tests/pipelines/controlnet/test_controlnet_inpaint.py index f8cc881e8650..0f8808bcb728 100644 --- a/tests/pipelines/controlnet/test_controlnet_inpaint.py +++ b/tests/pipelines/controlnet/test_controlnet_inpaint.py @@ -40,6 +40,7 @@ from ..pipeline_params import ( TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS, + TEXT_TO_IMAGE_IMAGE_PARAMS, ) from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin @@ -51,7 +52,8 @@ class ControlNetInpaintPipelineFastTests(PipelineLatentTesterMixin, PipelineTest pipeline_class = StableDiffusionControlNetInpaintPipeline params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS - image_params = frozenset([]) + image_params = frozenset({"control_image"}) # skip `image` and `mask` for now, only test for control_image + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/stable_diffusion/test_cycle_diffusion.py b/tests/pipelines/stable_diffusion/test_cycle_diffusion.py index a1ae3d2d0e7c..9a54c21c0a21 100644 --- a/tests/pipelines/stable_diffusion/test_cycle_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_cycle_diffusion.py @@ -25,7 +25,11 @@ from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, skip_mps -from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS +from ..pipeline_params import ( + IMAGE_TO_IMAGE_IMAGE_PARAMS, + TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, + TEXT_GUIDED_IMAGE_VARIATION_PARAMS, +) from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin @@ -42,7 +46,8 @@ class CycleDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterM } required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"} batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS.union({"source_prompt"}) - image_params = frozenset([]) # TO_DO: add image_params once refactored VaeImageProcessor.preprocess + image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS + image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS def get_dummy_components(self): torch.manual_seed(0) @@ -101,6 +106,7 @@ def get_dummy_components(self): def get_dummy_inputs(self, device, seed=0): image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + image = image / 2 + 0.5 if str(device).startswith("mps"): generator = torch.manual_seed(seed) else: diff --git 
a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 6140bf771e65..a5c36046013f 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -92,6 +92,7 @@ class StableDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTester params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py index c35d84de9802..e16478f06112 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py @@ -47,6 +47,7 @@ class StableDiffusionImageVariationPipelineFastTests( batch_params = IMAGE_VARIATION_BATCH_PARAMS image_params = frozenset([]) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess + image_latents_params = frozenset([]) def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py index 33305d5980be..eefbc83ce9d7 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py @@ -32,7 +32,6 @@ StableDiffusionImg2ImgPipeline, UNet2DConditionModel, ) -from diffusers.image_processor import VaeImageProcessor from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device from diffusers.utils.testing_utils import ( enable_full_determinism, @@ -91,6 +90,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipelin required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"} batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS + image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS def get_dummy_components(self): torch.manual_seed(0) @@ -142,6 +142,7 @@ def get_dummy_components(self): def get_dummy_inputs(self, device, seed=0): image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + image = image / 2 + 0.5 if str(device).startswith("mps"): generator = torch.manual_seed(seed) else: @@ -160,12 +161,10 @@ def test_stable_diffusion_img2img_default_case(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() sd_pipe = StableDiffusionImg2ImgPipeline(**components) - sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=True) sd_pipe = sd_pipe.to(device) sd_pipe.set_progress_bar_config(disable=None) inputs = self.get_dummy_inputs(device) - inputs["image"] = inputs["image"] / 2 + 0.5 image = sd_pipe(**inputs).images image_slice = image[0, -3:, -3:, -1] @@ -178,12 +177,10 @@ def test_stable_diffusion_img2img_negative_prompt(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() sd_pipe = StableDiffusionImg2ImgPipeline(**components) - sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=True) sd_pipe = 
sd_pipe.to(device) sd_pipe.set_progress_bar_config(disable=None) inputs = self.get_dummy_inputs(device) - inputs["image"] = inputs["image"] / 2 + 0.5 negative_prompt = "french fries" output = sd_pipe(**inputs, negative_prompt=negative_prompt) image = output.images @@ -198,14 +195,12 @@ def test_stable_diffusion_img2img_multiple_init_images(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() sd_pipe = StableDiffusionImg2ImgPipeline(**components) - sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=True) sd_pipe = sd_pipe.to(device) sd_pipe.set_progress_bar_config(disable=None) inputs = self.get_dummy_inputs(device) inputs["prompt"] = [inputs["prompt"]] * 2 inputs["image"] = inputs["image"].repeat(2, 1, 1, 1) - inputs["image"] = inputs["image"] / 2 + 0.5 image = sd_pipe(**inputs).images image_slice = image[-1, -3:, -3:, -1] @@ -221,12 +216,10 @@ def test_stable_diffusion_img2img_k_lms(self): beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" ) sd_pipe = StableDiffusionImg2ImgPipeline(**components) - sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=True) sd_pipe = sd_pipe.to(device) sd_pipe.set_progress_bar_config(disable=None) inputs = self.get_dummy_inputs(device) - inputs["image"] = inputs["image"] / 2 + 0.5 image = sd_pipe(**inputs).images image_slice = image[0, -3:, -3:, -1] diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index e355e82e5b35..cbe9aa063915 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -88,6 +88,7 @@ class StableDiffusionInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipelin batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS image_params = frozenset([]) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess + image_latents_params = frozenset([]) def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py index fbff6c554967..691427b1c6eb 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py @@ -31,10 +31,15 @@ StableDiffusionInstructPix2PixPipeline, UNet2DConditionModel, ) +from diffusers.image_processor import VaeImageProcessor from diffusers.utils import floats_tensor, load_image, slow, torch_device from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu -from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS +from ..pipeline_params import ( + IMAGE_TO_IMAGE_IMAGE_PARAMS, + TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, + TEXT_GUIDED_IMAGE_VARIATION_PARAMS, +) from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin @@ -47,9 +52,8 @@ class StableDiffusionInstructPix2PixPipelineFastTests( pipeline_class = StableDiffusionInstructPix2PixPipeline params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width", "cross_attention_kwargs"} batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS - image_params = frozenset( - [] - ) # TO-DO: update image_params once 
pipeline is refactored with VaeImageProcessor.preprocess + image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS + image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS def get_dummy_components(self): torch.manual_seed(0) @@ -163,6 +167,7 @@ def test_stable_diffusion_pix2pix_multiple_init_images(self): image = np.array(inputs["image"]).astype(np.float32) / 255.0 image = torch.from_numpy(image).unsqueeze(0).to(device) + image = image / 2 + 0.5 image = image.permute(0, 3, 1, 2) inputs["image"] = image.repeat(2, 1, 1, 1) @@ -199,6 +204,28 @@ def test_stable_diffusion_pix2pix_euler(self): def test_inference_batch_single_identical(self): super().test_inference_batch_single_identical(expected_max_diff=3e-3) + # Overwrite the default test_latents_input because pix2pix encodes the image differently + def test_latents_input(self): + components = self.get_dummy_components() + pipe = StableDiffusionInstructPix2PixPipeline(**components) + pipe.image_processor = VaeImageProcessor(do_resize=False, do_normalize=False) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + out = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="pt"))[0] + + vae = components["vae"] + inputs = self.get_dummy_inputs_by_type(torch_device, input_image_type="pt") + + for image_param in self.image_latents_params: + if image_param in inputs.keys(): + inputs[image_param] = vae.encode(inputs[image_param]).latent_dist.mode() + + out_latents_inputs = pipe(**inputs)[0] + + max_diff = np.abs(out - out_latents_inputs).max() + self.assertLess(max_diff, 1e-4, "passing latents as image input generates a different result from passing an image") + @slow @require_torch_gpu diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_model_editing.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_model_editing.py index cba20417bca0..f47a70c4ece8 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_model_editing.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_model_editing.py @@ -44,6 +44,7 @@ class StableDiffusionModelEditingPipelineFastTests(PipelineLatentTesterMixin, Pi params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py index 021065416838..c8d2bfa8c59d 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py @@ -45,6 +45,7 @@ class StableDiffusionPanoramaPipelineFastTests(PipelineLatentTesterMixin, Pipeli params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py index 98f5910ab313..6f41d2c43c8e 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py @@ -32,11 +32,16 @@ StableDiffusionPix2PixZeroPipeline, UNet2DConditionModel, ) +from diffusers.image_processor import VaeImageProcessor from diffusers.utils import floats_tensor, load_numpy, slow,
torch_device from diffusers.utils.testing_utils import enable_full_determinism, load_image, load_pt, require_torch_gpu, skip_mps -from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from ..pipeline_params import ( + TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, + TEXT_GUIDED_IMAGE_VARIATION_PARAMS, + TEXT_TO_IMAGE_IMAGE_PARAMS, +) +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin, assert_mean_pixel_difference enable_full_determinism() @@ -45,11 +50,10 @@ @skip_mps class StableDiffusionPix2PixZeroPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): pipeline_class = StableDiffusionPix2PixZeroPipeline - params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS + params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"image"} batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS - image_params = frozenset( - [] - ) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS @classmethod def setUpClass(cls): @@ -130,6 +134,7 @@ def get_dummy_inputs(self, device, seed=0): def get_dummy_inversion_inputs(self, device, seed=0): dummy_image = floats_tensor((2, 3, 32, 32), rng=random.Random(seed)).to(torch_device) + dummy_image = dummy_image / 2 + 0.5 generator = torch.manual_seed(seed) inputs = { @@ -145,6 +150,24 @@ def get_dummy_inversion_inputs(self, device, seed=0): } return inputs + def get_dummy_inversion_inputs_by_type(self, device, seed=0, input_image_type="pt", output_type="np"): + inputs = self.get_dummy_inversion_inputs(device, seed) + + if input_image_type == "pt": + image = inputs["image"] + elif input_image_type == "np": + image = VaeImageProcessor.pt_to_numpy(inputs["image"]) + elif input_image_type == "pil": + image = VaeImageProcessor.pt_to_numpy(inputs["image"]) + image = VaeImageProcessor.numpy_to_pil(image) + else: + raise ValueError(f"unsupported input_image_type {input_image_type}") + + inputs["image"] = image + inputs["output_type"] = output_type + + return inputs + def test_save_load_optional_components(self): if not hasattr(self.pipeline_class, "_optional_components"): return @@ -281,6 +304,41 @@ def test_stable_diffusion_pix2pix_zero_ddpm(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + def test_stable_diffusion_pix2pix_zero_inversion_pt_np_pil_outputs_equivalent(self): + device = torch_device + components = self.get_dummy_components() + sd_pipe = StableDiffusionPix2PixZeroPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + output_pt = sd_pipe.invert(**self.get_dummy_inversion_inputs_by_type(device, output_type="pt")).images + output_np = sd_pipe.invert(**self.get_dummy_inversion_inputs_by_type(device, output_type="np")).images + output_pil = sd_pipe.invert(**self.get_dummy_inversion_inputs_by_type(device, output_type="pil")).images + + max_diff = np.abs(output_pt.cpu().numpy().transpose(0, 2, 3, 1) - output_np).max() + self.assertLess(max_diff, 1e-4, "`output_type=='pt'` generate different results from `output_type=='np'`") + + max_diff = np.abs(np.array(output_pil[0]) - (output_np[0] * 255).round()).max() + self.assertLess(max_diff, 2.0, "`output_type=='pil'` generate different results from `output_type=='np'`") + + def 
test_stable_diffusion_pix2pix_zero_inversion_pt_np_pil_inputs_equivalent(self): + device = torch_device + components = self.get_dummy_components() + sd_pipe = StableDiffusionPix2PixZeroPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + out_input_pt = sd_pipe.invert(**self.get_dummy_inversion_inputs_by_type(device, input_image_type="pt")).images + out_input_np = sd_pipe.invert(**self.get_dummy_inversion_inputs_by_type(device, input_image_type="np")).images + out_input_pil = sd_pipe.invert( + **self.get_dummy_inversion_inputs_by_type(device, input_image_type="pil") + ).images + + max_diff = np.abs(out_input_pt - out_input_np).max() + self.assertLess(max_diff, 1e-4, "`input_type=='pt'` generate different result from `input_type=='np'`") + + assert_mean_pixel_difference(out_input_pil, out_input_np, expected_max_diff=1) + # Non-determinism caused by the scheduler optimizing the latent inputs during inference @unittest.skip("non-deterministic pipeline") def test_inference_batch_single_identical(self): diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_sag.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_sag.py index 2b0f0bfc11a6..91719ce7676f 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_sag.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_sag.py @@ -41,6 +41,7 @@ class StableDiffusionSAGPipelineFastTests(PipelineLatentTesterMixin, PipelineTes params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS test_cpu_offload = False def get_dummy_components(self): diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py index 3f9867783b33..87a960c7d1a4 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py @@ -47,6 +47,7 @@ class StableDiffusion2PipelineFastTests(PipelineLatentTesterMixin, PipelineTeste params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py index 6cec2cce752d..304ddacd2c36 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py @@ -45,6 +45,7 @@ class StableDiffusionAttendAndExcitePipelineFastTests( params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"token_indices"}) image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS # Attend and excite requires being able to run a backward pass at # inference time. 
There's no deterministic backward operator for pad diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py index 08ac29868971..f393967c7de4 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py @@ -51,7 +51,12 @@ ) from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, skip_mps -from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS +from ..pipeline_params import ( + IMAGE_TO_IMAGE_IMAGE_PARAMS, + TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, + TEXT_GUIDED_IMAGE_VARIATION_PARAMS, + TEXT_TO_IMAGE_IMAGE_PARAMS, +) from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin @@ -65,9 +70,8 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"} required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"} batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS - image_params = frozenset( - [] - ) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess + image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py index 8df5b6da846c..1de80d60d8e8 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py @@ -49,6 +49,7 @@ class StableDiffusionDiffEditPipelineFastTests(PipelineLatentTesterMixin, Pipeli image_params = frozenset( [] ) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess + image_latents_params = frozenset([]) def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py index 10d8561f0126..37c254f367f3 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py @@ -40,6 +40,7 @@ class StableDiffusion2InpaintPipelineFastTests(PipelineLatentTesterMixin, Pipeli image_params = frozenset( [] ) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess + image_latents_params = frozenset([]) def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py index 561536a44ea0..b94aaca4258a 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py @@ -52,6 +52,7 @@ class StableDiffusionLatentUpscalePipelineFastTests(PipelineLatentTesterMixin, P image_params = frozenset( [] ) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess + image_latents_params = frozenset([]) test_cpu_offload = True diff --git a/tests/pipelines/stable_unclip/test_stable_unclip.py b/tests/pipelines/stable_unclip/test_stable_unclip.py index 
8b4a065cd4bf..4bbbad757edf 100644 --- a/tests/pipelines/stable_unclip/test_stable_unclip.py +++ b/tests/pipelines/stable_unclip/test_stable_unclip.py @@ -27,6 +27,7 @@ class StableUnCLIPPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMix params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS # TODO(will) Expected attn_bias.stride(1) == 0 to be true, but got false test_xformers_attention = False diff --git a/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py b/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py index 35cae61242c4..741343066133 100644 --- a/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py +++ b/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py @@ -46,6 +46,7 @@ class StableUnCLIPImg2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTe image_params = frozenset( [] ) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess + image_latents_params = frozenset([]) def get_dummy_components(self): embedder_hidden_size = 32 diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 3ddfd35defb7..976e3f5cbb12 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -8,6 +8,7 @@ from typing import Callable, Union import numpy as np +import PIL import torch import diffusers @@ -39,9 +40,28 @@ def image_params(self) -> frozenset: "`image_params` are tested for if all accepted input image types (i.e. `pt`,`pil`,`np`) are producing same results" ) + @property + def image_latents_params(self) -> frozenset: + raise NotImplementedError( + "You need to set the attribute `image_latents_params` in the child test class. 
" + "`image_latents_params` are tested for if passing latents directly are producing same results" + ) + def get_dummy_inputs_by_type(self, device, seed=0, input_image_type="pt", output_type="np"): inputs = self.get_dummy_inputs(device, seed) + def convert_to_pt(image): + if isinstance(image, torch.Tensor): + input_image = image + elif isinstance(image, np.ndarray): + input_image = VaeImageProcessor.numpy_to_pt(image) + elif isinstance(image, PIL.Image.Image): + input_image = VaeImageProcessor.pil_to_numpy(image) + input_image = VaeImageProcessor.numpy_to_pt(input_image) + else: + raise ValueError(f"unsupported input_image_type {type(image)}") + return input_image + def convert_pt_to_type(image, input_image_type): if input_image_type == "pt": input_image = image @@ -56,21 +76,32 @@ def convert_pt_to_type(image, input_image_type): for image_param in self.image_params: if image_param in inputs.keys(): - inputs[image_param] = convert_pt_to_type(inputs[image_param], input_image_type) + inputs[image_param] = convert_pt_to_type( + convert_to_pt(inputs[image_param]).to(device), input_image_type + ) inputs["output_type"] = output_type return inputs def test_pt_np_pil_outputs_equivalent(self, expected_max_diff=1e-4): + self._test_pt_np_pil_outputs_equivalent(expected_max_diff=expected_max_diff) + + def _test_pt_np_pil_outputs_equivalent(self, expected_max_diff=1e-4, input_image_type="pt"): components = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - output_pt = pipe(**self.get_dummy_inputs_by_type(torch_device, output_type="pt"))[0] - output_np = pipe(**self.get_dummy_inputs_by_type(torch_device, output_type="np"))[0] - output_pil = pipe(**self.get_dummy_inputs_by_type(torch_device, output_type="pil"))[0] + output_pt = pipe( + **self.get_dummy_inputs_by_type(torch_device, input_image_type=input_image_type, output_type="pt") + )[0] + output_np = pipe( + **self.get_dummy_inputs_by_type(torch_device, input_image_type=input_image_type, output_type="np") + )[0] + output_pil = pipe( + **self.get_dummy_inputs_by_type(torch_device, input_image_type=input_image_type, output_type="pil") + )[0] max_diff = np.abs(output_pt.cpu().numpy().transpose(0, 2, 3, 1) - output_np).max() self.assertLess( @@ -98,6 +129,31 @@ def test_pt_np_pil_inputs_equivalent(self): max_diff = np.abs(out_input_pil - out_input_np).max() self.assertLess(max_diff, 1e-2, "`input_type=='pt'` generate different result from `input_type=='np'`") + def test_latents_input(self): + if len(self.image_latents_params) == 0: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.image_processor = VaeImageProcessor(do_resize=False, do_normalize=False) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + out = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="pt"))[0] + + vae = components["vae"] + inputs = self.get_dummy_inputs_by_type(torch_device, input_image_type="pt") + generator = inputs["generator"] + for image_param in self.image_latents_params: + if image_param in inputs.keys(): + inputs[image_param] = ( + vae.encode(inputs[image_param]).latent_dist.sample(generator) * vae.config.scaling_factor + ) + out_latents_inputs = pipe(**inputs)[0] + + max_diff = np.abs(out - out_latents_inputs).max() + self.assertLess(max_diff, 1e-4, "passing latents as image input generate different result from passing image") + @require_torch class PipelineTesterMixin: