Skip to content

Add image_processor #2617

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 46 commits into from
Mar 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
50615d3
add image_processor
Mar 8, 2023
d82730d
Apply suggestions from code review
yiyixuxu Mar 9, 2023
d0d1437
add more tests
Mar 9, 2023
da62e8d
make style
Mar 9, 2023
98146d0
fix
Mar 9, 2023
d223e8e
update img2mg
Mar 9, 2023
5eb7592
style
Mar 9, 2023
af21a0d
fix
Mar 9, 2023
803c93e
apply feedbacks
Mar 12, 2023
5c6de08
fix style
Mar 13, 2023
e07a9be
remove fixed copies on img2img preprocess
Mar 13, 2023
cd2721f
fix
Mar 13, 2023
2c702f1
Update src/diffusers/image_processor.py
yiyixuxu Mar 14, 2023
dc508d6
Update src/diffusers/image_processor.py
yiyixuxu Mar 14, 2023
3475dec
Update src/diffusers/image_processor.py
yiyixuxu Mar 14, 2023
e3a0b13
Update src/diffusers/image_processor.py
yiyixuxu Mar 14, 2023
63b2418
Update src/diffusers/image_processor.py
yiyixuxu Mar 14, 2023
2847d4b
Update src/diffusers/image_processor.py
yiyixuxu Mar 14, 2023
f6e5af0
Update src/diffusers/image_processor.py
yiyixuxu Mar 14, 2023
771f6c0
Update src/diffusers/image_processor.py
yiyixuxu Mar 14, 2023
26e9514
Update src/diffusers/image_processor.py
yiyixuxu Mar 14, 2023
7c9b9f7
fix typos
Mar 14, 2023
f009e97
add back preprocess function
Mar 14, 2023
e2f7cf4
Revert "remove fixed copies on img2img preprocess"
Mar 14, 2023
c1569be
Revert "fix"
Mar 14, 2023
9cf2c0b
revert change in expected slice
Mar 14, 2023
1fe112c
fix img2img tests
Mar 14, 2023
2f4cade
make style
Mar 14, 2023
90e0539
remov #fixed copy on img2img init method
Mar 14, 2023
983f4e9
remove #copy on img2img decode_latents
Mar 14, 2023
8ab5015
update alt_img2img
Mar 14, 2023
4cc2d0e
style
Mar 14, 2023
d919e69
deprecate preprocess
Mar 14, 2023
daa3d32
style + copy
Mar 14, 2023
cd83878
style again
Mar 14, 2023
1c893ab
Merge branch 'main' into image-processor
patrickvonplaten Mar 14, 2023
3dbb862
update error message for using resize with torch tensor or numpy array
Mar 14, 2023
ef8582f
fix
Mar 14, 2023
0cec737
remove deprecation warning for preprocess function + fix copies
Mar 14, 2023
be5fcdc
remove comment
Mar 14, 2023
f3a2676
Apply suggestions from code review
yiyixuxu Mar 14, 2023
419cabb
update error message
Mar 14, 2023
c844d2c
Update src/diffusers/__init__.py
patrickvonplaten Mar 15, 2023
bf513f1
Apply suggestions from code review
patrickvonplaten Mar 15, 2023
3054135
fix import
Mar 15, 2023
89921a9
fix copies
Mar 15, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 177 additions & 0 deletions src/diffusers/image_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import Union

import numpy as np
import PIL
import torch
from PIL import Image

from .configuration_utils import ConfigMixin, register_to_config
from .utils import CONFIG_NAME, PIL_INTERPOLATION


class VaeImageProcessor(ConfigMixin):
    """
    Image Processor for VAE

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
        vae_scale_factor (`int`, *optional*, defaults to `8`):
            VAE scale factor. If `do_resize` is True, the image will be automatically resized to multiples of this
            factor.
        resample (`str`, *optional*, defaults to `lanczos`):
            Resampling filter to use when resizing the image.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image to [-1,1]
    """

    # NOTE(review): the methods below read `self.do_resize`, `self.vae_scale_factor`, `self.resample` and
    # `self.do_normalize` directly; this presumably relies on `register_to_config`/`ConfigMixin` exposing the
    # registered arguments as attributes - confirm against `configuration_utils`.
    config_name = CONFIG_NAME

    @register_to_config
    def __init__(
        self,
        do_resize: bool = True,
        vae_scale_factor: int = 8,
        resample: str = "lanczos",
        do_normalize: bool = True,
    ):
        super().__init__()

    @staticmethod
    def numpy_to_pil(images):
        """
        Convert a numpy image or a batch of images to a list of PIL images.

        A 3-dimensional array is treated as a single image and gets a leading batch axis. Values are scaled by 255
        and rounded to `uint8` before conversion.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            # special case for grayscale (single channel) images
            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
        else:
            pil_images = [Image.fromarray(image) for image in images]

        return pil_images

    @staticmethod
    def numpy_to_pt(images):
        """
        Convert a batch of numpy images to a pytorch tensor.

        A 3-dimensional array gets a trailing axis appended; the axes are then reordered with
        `transpose(0, 3, 1, 2)` (last axis moved to position 1).
        """
        if images.ndim == 3:
            images = images[..., None]

        images = torch.from_numpy(images.transpose(0, 3, 1, 2))
        return images

    @staticmethod
    def pt_to_numpy(images):
        """
        Convert a pytorch tensor to a numpy image.

        The tensor is moved to CPU, its axes are reordered with `permute(0, 2, 3, 1)` (axis 1 moved last), and it is
        cast to float32 before conversion.
        """
        images = images.cpu().permute(0, 2, 3, 1).float().numpy()
        return images

    @staticmethod
    def normalize(images):
        """
        Normalize an image array to [-1,1] (the input is expected to be in [0,1]).
        """
        return 2.0 * images - 1.0

    def resize(self, images: PIL.Image.Image) -> PIL.Image.Image:
        """
        Resize a PIL image. Both height and width are rounded down to the nearest integer multiple of
        `vae_scale_factor`.
        """
        w, h = images.size
        # drop the remainder so that each side becomes an integer multiple of vae_scale_factor
        w, h = (x - x % self.vae_scale_factor for x in (w, h))
        images = images.resize((w, h), resample=PIL_INTERPOLATION[self.resample])
        return images

    def preprocess(
        self,
        image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
    ) -> torch.Tensor:
        """
        Preprocess the image input, accepted formats are PIL images, numpy arrays or pytorch tensors (or a
        non-empty list of them).

        PIL images are optionally resized (`do_resize`), scaled to [0,1] and converted to a tensor. Numpy arrays and
        tensors are only batched/converted; they are never resized, so their height and width must already be
        divisible by `vae_scale_factor` when `do_resize` is set. The result is normalized to [-1,1] when
        `do_normalize` is set and the input does not already contain negative values.

        Raises:
            ValueError: if the input has an unsupported type, is an empty list, or is a numpy array / tensor whose
                size is not divisible by `vae_scale_factor` while `do_resize` is set.
        """
        supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
        if isinstance(image, supported_formats):
            image = [image]
        elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)):
            # Build the message from type names: `str.join` only accepts strings, and a non-list input may not
            # even be iterable.
            received = [type(i) for i in image] if isinstance(image, list) else type(image)
            supported_names = ", ".join(f"{fmt.__module__}.{fmt.__name__}" for fmt in supported_formats)
            raise ValueError(
                f"Input is in incorrect format: {received}. Currently, we only support {supported_names}"
            )
        elif not image:
            # An empty list passes the `all(...)` check above but has no first element to dispatch on.
            raise ValueError("Input is an empty list. Please pass at least one image.")

        # The type of the first element decides how the whole list is handled.
        if isinstance(image[0], PIL.Image.Image):
            if self.do_resize:
                image = [self.resize(i) for i in image]
            image = [np.array(i).astype(np.float32) / 255.0 for i in image]
            image = np.stack(image, axis=0)  # to np
            image = self.numpy_to_pt(image)  # to pt

        elif isinstance(image[0], np.ndarray):
            # 4-dimensional arrays are already batched and get concatenated; anything else is stacked
            image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
            image = self.numpy_to_pt(image)
            _, _, height, width = image.shape
            if self.do_resize and (height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0):
                raise ValueError(
                    f"Currently we only support resizing for PIL image - please resize your numpy array to be divisible by {self.vae_scale_factor}, "
                    f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
                )

        elif isinstance(image[0], torch.Tensor):
            # 4-dimensional tensors are already batched and get concatenated; anything else is stacked
            image = torch.cat(image, dim=0) if image[0].ndim == 4 else torch.stack(image, dim=0)
            _, _, height, width = image.shape
            if self.do_resize and (height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0):
                raise ValueError(
                    f"Currently we only support resizing for PIL image - please resize your pytorch tensor to be divisible by {self.vae_scale_factor}, "
                    f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
                )

        # expected range [0,1], normalize to [-1,1]
        do_normalize = self.do_normalize
        if image.min() < 0:
            # negative values mean the input is treated as already being in [-1,1]: warn and skip normalization
            warnings.warn(
                "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
                f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
                FutureWarning,
            )
            do_normalize = False

        if do_normalize:
            image = self.normalize(image)

        return image

    def postprocess(
        self,
        image,
        output_type: str = "pil",
    ):
        """
        Convert an image tensor to the requested output format.

        Args:
            image: the image batch; a `torch.Tensor` is returned unchanged when `output_type` is `"pt"`, otherwise
                it is converted with `pt_to_numpy`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                One of `"pt"`, `"np"` or `"pil"`.

        Raises:
            ValueError: if `output_type` is not one of the supported values.
        """
        if isinstance(image, torch.Tensor) and output_type == "pt":
            return image

        image = self.pt_to_numpy(image)

        if output_type == "np":
            return image
        elif output_type == "pil":
            return self.numpy_to_pil(image)
        else:
            raise ValueError(f"Unsupported output_type {output_type}.")
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from diffusers.utils import is_accelerate_available, is_accelerate_version

from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor, replace_example_docstring
Expand Down Expand Up @@ -192,7 +193,6 @@ def __init__(
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)

self.register_modules(
vae=vae,
text_encoder=text_encoder,
Expand All @@ -203,7 +203,11 @@ def __init__(
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)

self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(
requires_safety_checker=requires_safety_checker,
)

def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Expand Down Expand Up @@ -415,21 +419,17 @@ def _encode_prompt(
return prompt_embeds

def run_safety_checker(self, image, device, dtype):
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
else:
has_nsfw_concept = None
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
return image, has_nsfw_concept

def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image

def prepare_extra_step_kwargs(self, generator, eta):
Expand Down Expand Up @@ -663,7 +663,7 @@ def __call__(
)

# 4. Preprocess image
image = preprocess(image)
image = self.image_processor.preprocess(image)

# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
Expand Down Expand Up @@ -703,15 +703,26 @@ def __call__(
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

# 9. Post-processing
if output_type not in ["latent", "pt", "np", "pil"]:
deprecation_message = (
f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: "
"`pil`, `np`, `pt`, `latent`"
)
deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
output_type = "np"

if output_type == "latent":
image = latents
has_nsfw_concept = None

image = self.decode_latents(latents)

# 10. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
if self.safety_checker is not None:
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
has_nsfw_concept = False

# 11. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
image = self.image_processor.postprocess(image, output_type=output_type)

# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
Expand Down Expand Up @@ -119,7 +120,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
"""
_optional_components = ["safety_checker", "feature_extractor"]

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__
def __init__(
self,
vae: AutoencoderKL,
Expand Down Expand Up @@ -196,7 +196,6 @@ def __init__(
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)

self.register_modules(
vae=vae,
text_encoder=text_encoder,
Expand All @@ -207,7 +206,11 @@ def __init__(
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)

self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(
requires_safety_checker=requires_safety_checker,
)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
def enable_sequential_cpu_offload(self, gpu_id=0):
Expand Down Expand Up @@ -422,24 +425,18 @@ def _encode_prompt(

return prompt_embeds

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
else:
has_nsfw_concept = None
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
return image, has_nsfw_concept

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
Expand Down Expand Up @@ -674,7 +671,7 @@ def __call__(
)

# 4. Preprocess image
image = preprocess(image)
image = self.image_processor.preprocess(image)

# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
Expand Down Expand Up @@ -714,15 +711,26 @@ def __call__(
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

# 9. Post-processing
if output_type not in ["latent", "pt", "np", "pil"]:
deprecation_message = (
f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: "
"`pil`, `np`, `pt`, `latent`"
)
deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
output_type = "np"

if output_type == "latent":
image = latents
has_nsfw_concept = None

image = self.decode_latents(latents)

# 10. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
if self.safety_checker is not None:
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
has_nsfw_concept = False

# 11. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
image = self.image_processor.postprocess(image, output_type=output_type)

# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
Expand Down
Loading