Skip to content

refactor Image processor for x4 upscaler #3692

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.preprocess
def preprocess(image):
warnings.warn(
"The preprocess method is deprecated and will be removed in a future version. Please"
" use VaeImageProcessor.preprocess instead",
FutureWarning,
)
if isinstance(image, torch.Tensor):
return image
elif isinstance(image, PIL.Image.Image):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from ...image_processor import VaeImageProcessor
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import AttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor
Expand All @@ -34,6 +35,11 @@


def preprocess(image):
warnings.warn(
"The preprocess method is deprecated and will be removed in a future version. Please"
" use VaeImageProcessor.preprocess instead",
FutureWarning,
)
if isinstance(image, torch.Tensor):
return image
elif isinstance(image, PIL.Image.Image):
Expand Down Expand Up @@ -125,6 +131,8 @@ def __init__(
watermarker=watermarker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic")
self.register_to_config(max_noise_level=max_noise_level)

def enable_sequential_cpu_offload(self, gpu_id=0):
Expand Down Expand Up @@ -432,14 +440,15 @@ def check_inputs(
if (
not isinstance(image, torch.Tensor)
and not isinstance(image, PIL.Image.Image)
and not isinstance(image, np.ndarray)
and not isinstance(image, list)
):
raise ValueError(
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or `list` but is {type(image)}"
f"`image` has to be of type `torch.Tensor`, `np.ndarray`, `PIL.Image.Image` or `list` but is {type(image)}"
)

# verify batch size of prompt and image are same if image is a list or tensor
if isinstance(image, list) or isinstance(image, torch.Tensor):
# verify batch size of prompt and image are same if image is a list or tensor or numpy array
if isinstance(image, list) or isinstance(image, torch.Tensor) or isinstance(image, np.ndarray):
if isinstance(prompt, str):
batch_size = 1
else:
Expand Down Expand Up @@ -483,7 +492,14 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
def __call__(
self,
prompt: Union[str, List[str]] = None,
image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None,
image: Union[
torch.FloatTensor,
PIL.Image.Image,
np.ndarray,
List[torch.FloatTensor],
List[PIL.Image.Image],
List[np.ndarray],
] = None,
num_inference_steps: int = 75,
guidance_scale: float = 9.0,
noise_level: int = 20,
Expand All @@ -506,7 +522,7 @@ def __call__(
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`):
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, or tensor representing an image batch which will be upscaled. *
num_inference_steps (`int`, *optional*, defaults to 75):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
Expand Down Expand Up @@ -627,7 +643,7 @@ def __call__(
)

# 4. Preprocess image
image = preprocess(image)
image = self.image_processor.preprocess(image)
image = image.to(dtype=prompt_embeds.dtype, device=device)

# 5. set timesteps
Expand Down Expand Up @@ -723,25 +739,25 @@ def __call__(
else:
latents = latents.float()

# 11. Convert to PIL
if output_type == "pil":
image = self.decode_latents(latents)

# post-processing
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype)

image = self.numpy_to_pil(image)

# 11. Apply watermark
if self.watermarker is not None:
image = self.watermarker.apply_watermark(image)
elif output_type == "pt":
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
has_nsfw_concept = None
else:
image = self.decode_latents(latents)
image = latents
has_nsfw_concept = None

if has_nsfw_concept is None:
do_denormalize = [True] * image.shape[0]
else:
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

# 11. Apply watermark
if output_type == "pil" and self.watermarker is not None:
image = self.watermarker.apply_watermark(image)

# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
Expand Down