Skip to content

Commit 017ee16

Browse files
yiyixuxuyiyixuxu
andauthored
refactor Image processor for x4 upscaler (#3692)
* refactor x4 upscaler * style * copies --------- Co-authored-by: yiyixuxu <yixu310@gmail,com>
1 parent 8669e83 commit 017ee16

File tree

2 files changed

+42
-21
lines changed

2 files changed

+42
-21
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,11 @@
3333

3434
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.preprocess
3535
def preprocess(image):
36+
warnings.warn(
37+
"The preprocess method is deprecated and will be removed in a future version. Please"
38+
" use VaeImageProcessor.preprocess instead",
39+
FutureWarning,
40+
)
3641
if isinstance(image, torch.Tensor):
3742
return image
3843
elif isinstance(image, PIL.Image.Image):

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py

Lines changed: 37 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import torch
2222
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
2323

24+
from ...image_processor import VaeImageProcessor
2425
from ...loaders import TextualInversionLoaderMixin
2526
from ...models import AutoencoderKL, UNet2DConditionModel
2627
from ...models.attention_processor import AttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor
@@ -34,6 +35,11 @@
3435

3536

3637
def preprocess(image):
38+
warnings.warn(
39+
"The preprocess method is deprecated and will be removed in a future version. Please"
40+
" use VaeImageProcessor.preprocess instead",
41+
FutureWarning,
42+
)
3743
if isinstance(image, torch.Tensor):
3844
return image
3945
elif isinstance(image, PIL.Image.Image):
@@ -125,6 +131,8 @@ def __init__(
125131
watermarker=watermarker,
126132
feature_extractor=feature_extractor,
127133
)
134+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
135+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic")
128136
self.register_to_config(max_noise_level=max_noise_level)
129137

130138
def enable_sequential_cpu_offload(self, gpu_id=0):
@@ -432,14 +440,15 @@ def check_inputs(
432440
if (
433441
not isinstance(image, torch.Tensor)
434442
and not isinstance(image, PIL.Image.Image)
443+
and not isinstance(image, np.ndarray)
435444
and not isinstance(image, list)
436445
):
437446
raise ValueError(
438-
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or `list` but is {type(image)}"
447+
f"`image` has to be of type `torch.Tensor`, `np.ndarray`, `PIL.Image.Image` or `list` but is {type(image)}"
439448
)
440449

441-
# verify batch size of prompt and image are same if image is a list or tensor
442-
if isinstance(image, list) or isinstance(image, torch.Tensor):
450+
# verify batch size of prompt and image are same if image is a list or tensor or numpy array
451+
if isinstance(image, list) or isinstance(image, torch.Tensor) or isinstance(image, np.ndarray):
443452
if isinstance(prompt, str):
444453
batch_size = 1
445454
else:
@@ -483,7 +492,14 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
483492
def __call__(
484493
self,
485494
prompt: Union[str, List[str]] = None,
486-
image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None,
495+
image: Union[
496+
torch.FloatTensor,
497+
PIL.Image.Image,
498+
np.ndarray,
499+
List[torch.FloatTensor],
500+
List[PIL.Image.Image],
501+
List[np.ndarray],
502+
] = None,
487503
num_inference_steps: int = 75,
488504
guidance_scale: float = 9.0,
489505
noise_level: int = 20,
@@ -506,7 +522,7 @@ def __call__(
506522
prompt (`str` or `List[str]`, *optional*):
507523
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
508524
instead.
509-
image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`):
525+
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
510526
`Image`, or tensor representing an image batch which will be upscaled. *
511527
num_inference_steps (`int`, *optional*, defaults to 50):
512528
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -627,7 +643,7 @@ def __call__(
627643
)
628644

629645
# 4. Preprocess image
630-
image = preprocess(image)
646+
image = self.image_processor.preprocess(image)
631647
image = image.to(dtype=prompt_embeds.dtype, device=device)
632648

633649
# 5. set timesteps
@@ -723,25 +739,25 @@ def __call__(
723739
else:
724740
latents = latents.float()
725741

726-
# 11. Convert to PIL
727-
if output_type == "pil":
728-
image = self.decode_latents(latents)
729-
742+
# post-processing
743+
if not output_type == "latent":
744+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
730745
image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype)
731-
732-
image = self.numpy_to_pil(image)
733-
734-
# 11. Apply watermark
735-
if self.watermarker is not None:
736-
image = self.watermarker.apply_watermark(image)
737-
elif output_type == "pt":
738-
latents = 1 / self.vae.config.scaling_factor * latents
739-
image = self.vae.decode(latents).sample
740-
has_nsfw_concept = None
741746
else:
742-
image = self.decode_latents(latents)
747+
image = latents
743748
has_nsfw_concept = None
744749

750+
if has_nsfw_concept is None:
751+
do_denormalize = [True] * image.shape[0]
752+
else:
753+
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
754+
755+
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
756+
757+
# 11. Apply watermark
758+
if output_type == "pil" and self.watermarker is not None:
759+
image = self.watermarker.apply_watermark(image)
760+
745761
# Offload last model to CPU
746762
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
747763
self.final_offload_hook.offload()

0 commit comments

Comments
 (0)