diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
index e0fecf6d353f..d67a7f894886 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
@@ -33,6 +33,11 @@
 
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.preprocess
 def preprocess(image):
+    warnings.warn(
+        "The preprocess method is deprecated and will be removed in a future version. Please"
+        " use VaeImageProcessor.preprocess instead",
+        FutureWarning,
+    )
     if isinstance(image, torch.Tensor):
         return image
     elif isinstance(image, PIL.Image.Image):
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
index 6bb463a6a65f..4c4f3998cb91 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
@@ -21,6 +21,7 @@
 import torch
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
+from ...image_processor import VaeImageProcessor
 from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.attention_processor import AttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor
@@ -34,6 +35,11 @@
 
 
 def preprocess(image):
+    warnings.warn(
+        "The preprocess method is deprecated and will be removed in a future version. Please"
+        " use VaeImageProcessor.preprocess instead",
+        FutureWarning,
+    )
     if isinstance(image, torch.Tensor):
         return image
     elif isinstance(image, PIL.Image.Image):
@@ -125,6 +131,8 @@ def __init__(
             watermarker=watermarker,
             feature_extractor=feature_extractor,
         )
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic")
         self.register_to_config(max_noise_level=max_noise_level)
 
     def enable_sequential_cpu_offload(self, gpu_id=0):
@@ -432,14 +440,15 @@ def check_inputs(
         if (
             not isinstance(image, torch.Tensor)
             and not isinstance(image, PIL.Image.Image)
+            and not isinstance(image, np.ndarray)
             and not isinstance(image, list)
         ):
             raise ValueError(
-                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or `list` but is {type(image)}"
+                f"`image` has to be of type `torch.Tensor`, `np.ndarray`, `PIL.Image.Image` or `list` but is {type(image)}"
             )
 
-        # verify batch size of prompt and image are same if image is a list or tensor
-        if isinstance(image, list) or isinstance(image, torch.Tensor):
+        # verify batch size of prompt and image are same if image is a list or tensor or numpy array
+        if isinstance(image, list) or isinstance(image, torch.Tensor) or isinstance(image, np.ndarray):
             if isinstance(prompt, str):
                 batch_size = 1
             else:
@@ -483,7 +492,14 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
-        image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None,
+        image: Union[
+            torch.FloatTensor,
+            PIL.Image.Image,
+            np.ndarray,
+            List[torch.FloatTensor],
+            List[PIL.Image.Image],
+            List[np.ndarray],
+        ] = None,
         num_inference_steps: int = 75,
         guidance_scale: float = 9.0,
         noise_level: int = 20,
@@ -506,7 +522,7 @@ def __call__(
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
-            image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`):
+            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                 `Image`, or tensor representing an image batch which will be upscaled. *
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -627,7 +643,7 @@ def __call__(
         )
 
         # 4. Preprocess image
-        image = preprocess(image)
+        image = self.image_processor.preprocess(image)
         image = image.to(dtype=prompt_embeds.dtype, device=device)
 
         # 5. set timesteps
@@ -723,25 +739,25 @@ def __call__(
         else:
             latents = latents.float()
 
-        # 11. Convert to PIL
-        if output_type == "pil":
-            image = self.decode_latents(latents)
-
+        # post-processing
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
             image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype)
-
-            image = self.numpy_to_pil(image)
-
-            # 11. Apply watermark
-            if self.watermarker is not None:
-                image = self.watermarker.apply_watermark(image)
-        elif output_type == "pt":
-            latents = 1 / self.vae.config.scaling_factor * latents
-            image = self.vae.decode(latents).sample
-            has_nsfw_concept = None
         else:
-            image = self.decode_latents(latents)
+            image = latents
             has_nsfw_concept = None
 
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+        # 11. Apply watermark
+        if output_type == "pil" and self.watermarker is not None:
+            image = self.watermarker.apply_watermark(image)
+
         # Offload last model to CPU
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.final_offload_hook.offload()
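
For reviewers, a minimal sketch of how the new code paths in `StableDiffusionUpscalePipeline` could be exercised after this change: PIL input (the existing path, now routed through `VaeImageProcessor.preprocess`), the newly accepted `np.ndarray` input, and the new `output_type="latent"` early exit. The checkpoint id and image URL are illustrative assumptions, not part of this diff.

```python
import numpy as np
import torch

from diffusers import StableDiffusionUpscalePipeline
from diffusers.utils import load_image

# Assumed checkpoint for illustration; any StableDiffusionUpscalePipeline
# checkpoint should behave the same way.
pipe = StableDiffusionUpscalePipeline.from_pretrained(
    "stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16
).to("cuda")

# Placeholder image source; substitute a real low-resolution input.
low_res = load_image("https://example.com/low_res.png").resize((128, 128))

# PIL input: preprocessed by self.image_processor.preprocess instead of the
# now-deprecated module-level preprocess().
upscaled = pipe(prompt="a white cat", image=low_res).images[0]

# np.ndarray input: newly accepted by check_inputs; VaeImageProcessor expects
# float arrays scaled to [0, 1].
array = np.asarray(low_res).astype(np.float32) / 255.0
upscaled = pipe(prompt="a white cat", image=array).images[0]

# output_type="latent": skips VAE decoding, the safety checker, and the
# watermarker entirely, returning raw latents.
latents = pipe(prompt="a white cat", image=low_res, output_type="latent").images
```

Note that the watermark is now applied only for `output_type="pil"`, and `do_denormalize` lets `postprocess` leave flagged NSFW outputs (already blacked out by the safety checker) un-denormalized on a per-image basis.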