-
Notifications
You must be signed in to change notification settings - Fork 6.1k
Add image_processor #2617
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+447
−45
Merged
Add image_processor #2617
Changes from all commits
Commits
Show all changes
46 commits
Select commit
Hold shift + click to select a range
50615d3
add image_processor
d82730d
Apply suggestions from code review
yiyixuxu d0d1437
add more tests
da62e8d
make style
98146d0
fix
d223e8e
update img2mg
5eb7592
style
af21a0d
fix
803c93e
apply feedbacks
5c6de08
fix style
e07a9be
remove fixed copies on img2img preprocess
cd2721f
fix
2c702f1
Update src/diffusers/image_processor.py
yiyixuxu dc508d6
Update src/diffusers/image_processor.py
yiyixuxu 3475dec
Update src/diffusers/image_processor.py
yiyixuxu e3a0b13
Update src/diffusers/image_processor.py
yiyixuxu 63b2418
Update src/diffusers/image_processor.py
yiyixuxu 2847d4b
Update src/diffusers/image_processor.py
yiyixuxu f6e5af0
Update src/diffusers/image_processor.py
yiyixuxu 771f6c0
Update src/diffusers/image_processor.py
yiyixuxu 26e9514
Update src/diffusers/image_processor.py
yiyixuxu 7c9b9f7
fix typos
f009e97
add back preprocess function
e2f7cf4
Revert "remove fixed copies on img2img preprocess"
c1569be
Revert "fix"
9cf2c0b
revert change in expected slice
1fe112c
fix img2img tests
2f4cade
make style
90e0539
remov #fixed copy on img2img init method
983f4e9
remove #copy on img2img decode_latents
8ab5015
update alt_img2img
4cc2d0e
style
d919e69
deprecate preprocess
daa3d32
style + copy
cd83878
style again
1c893ab
Merge branch 'main' into image-processor
patrickvonplaten 3dbb862
update error message for using resize with torch tensor or numpy array
ef8582f
fix
0cec737
remove deprecation warning for preprocess function + fix copies
be5fcdc
remove comment
f3a2676
Apply suggestions from code review
yiyixuxu 419cabb
update error message
c844d2c
Update src/diffusers/__init__.py
patrickvonplaten bf513f1
Apply suggestions from code review
patrickvonplaten 3054135
fix import
89921a9
fix copies
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
# Copyright 2023 The HuggingFace Team. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import warnings | ||
from typing import Union | ||
|
||
import numpy as np | ||
import PIL | ||
import torch | ||
from PIL import Image | ||
|
||
from .configuration_utils import ConfigMixin, register_to_config | ||
from .utils import CONFIG_NAME, PIL_INTERPOLATION | ||
|
||
|
||
class VaeImageProcessor(ConfigMixin): | ||
""" | ||
Image Processor for VAE | ||
|
||
Args: | ||
do_resize (`bool`, *optional*, defaults to `True`): | ||
Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. | ||
vae_scale_factor (`int`, *optional*, defaults to `8`): | ||
VAE scale factor. If `do_resize` is True, the image will be automatically resized to multiples of this | ||
factor. | ||
resample (`str`, *optional*, defaults to `lanczos`): | ||
Resampling filter to use when resizing the image. | ||
do_normalize (`bool`, *optional*, defaults to `True`): | ||
Whether to normalize the image to [-1,1] | ||
""" | ||
|
||
config_name = CONFIG_NAME | ||
|
||
@register_to_config | ||
def __init__( | ||
self, | ||
do_resize: bool = True, | ||
vae_scale_factor: int = 8, | ||
resample: str = "lanczos", | ||
do_normalize: bool = True, | ||
): | ||
super().__init__() | ||
|
||
@staticmethod | ||
def numpy_to_pil(images): | ||
""" | ||
Convert a numpy image or a batch of images to a PIL image. | ||
""" | ||
if images.ndim == 3: | ||
images = images[None, ...] | ||
images = (images * 255).round().astype("uint8") | ||
if images.shape[-1] == 1: | ||
# special case for grayscale (single channel) images | ||
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] | ||
else: | ||
pil_images = [Image.fromarray(image) for image in images] | ||
|
||
return pil_images | ||
|
||
@staticmethod | ||
def numpy_to_pt(images): | ||
""" | ||
Convert a numpy image to a pytorch tensor | ||
""" | ||
if images.ndim == 3: | ||
images = images[..., None] | ||
|
||
images = torch.from_numpy(images.transpose(0, 3, 1, 2)) | ||
return images | ||
|
||
@staticmethod | ||
def pt_to_numpy(images): | ||
""" | ||
Convert a numpy image to a pytorch tensor | ||
""" | ||
images = images.cpu().permute(0, 2, 3, 1).float().numpy() | ||
return images | ||
|
||
@staticmethod | ||
def normalize(images): | ||
""" | ||
Normalize an image array to [-1,1] | ||
""" | ||
return 2.0 * images - 1.0 | ||
|
||
def resize(self, images: PIL.Image.Image) -> PIL.Image.Image: | ||
""" | ||
Resize a PIL image. Both height and width will be downscaled to the next integer multiple of `vae_scale_factor` | ||
""" | ||
w, h = images.size | ||
w, h = map(lambda x: x - x % self.vae_scale_factor, (w, h)) # resize to integer multiple of vae_scale_factor | ||
images = images.resize((w, h), resample=PIL_INTERPOLATION[self.resample]) | ||
return images | ||
|
||
def preprocess( | ||
self, | ||
image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray], | ||
) -> torch.Tensor: | ||
""" | ||
Preprocess the image input, accepted formats are PIL images, numpy arrays or pytorch tensors" | ||
""" | ||
supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor) | ||
if isinstance(image, supported_formats): | ||
image = [image] | ||
elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)): | ||
raise ValueError( | ||
f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(supported_formats)}" | ||
) | ||
|
||
if isinstance(image[0], PIL.Image.Image): | ||
if self.do_resize: | ||
image = [self.resize(i) for i in image] | ||
image = [np.array(i).astype(np.float32) / 255.0 for i in image] | ||
patrickvonplaten marked this conversation as resolved.
Show resolved
Hide resolved
|
||
image = np.stack(image, axis=0) # to np | ||
image = self.numpy_to_pt(image) # to pt | ||
|
||
elif isinstance(image[0], np.ndarray): | ||
yiyixuxu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0) | ||
image = self.numpy_to_pt(image) | ||
_, _, height, width = image.shape | ||
if self.do_resize and (height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0): | ||
raise ValueError( | ||
f"Currently we only support resizing for PIL image - please resize your numpy array to be divisible by {self.vae_scale_factor}" | ||
f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor" | ||
) | ||
|
||
elif isinstance(image[0], torch.Tensor): | ||
image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) | ||
_, _, height, width = image.shape | ||
if self.do_resize and (height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0): | ||
raise ValueError( | ||
f"Currently we only support resizing for PIL image - please resize your pytorch tensor to be divisible by {self.vae_scale_factor}" | ||
f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor" | ||
) | ||
|
||
# expected range [0,1], normalize to [-1,1] | ||
yiyixuxu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
do_normalize = self.do_normalize | ||
if image.min() < 0: | ||
warnings.warn( | ||
patrickvonplaten marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] " | ||
f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]", | ||
FutureWarning, | ||
) | ||
do_normalize = False | ||
|
||
if do_normalize: | ||
image = self.normalize(image) | ||
|
||
return image | ||
|
||
def postprocess( | ||
self, | ||
image, | ||
output_type: str = "pil", | ||
): | ||
if isinstance(image, torch.Tensor) and output_type == "pt": | ||
return image | ||
|
||
image = self.pt_to_numpy(image) | ||
|
||
if output_type == "np": | ||
return image | ||
elif output_type == "pil": | ||
return self.numpy_to_pil(image) | ||
else: | ||
raise ValueError(f"Unsupported output_type {output_type}.") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.