-
Notifications
You must be signed in to change notification settings - Fork 6.3k
[WIP]Vae preprocessor refactor (PR1) #3557
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 54 commits
4b15eb4
7cd26ef
d58276c
4c75754
db25d47
f779977
c9ef024
bc66a2a
bf0331d
5c597c5
91423d7
19747b7
96a2a05
f7be498
df73ada
eb23e59
1655b1b
076afc8
706a53e
d63c444
cc7fffc
007deda
188de72
eae044f
0f01f9c
3a95cc6
d410a1b
70904ac
48896b3
a412fac
6c44410
1ad9c35
7bfeee8
0d10b3b
e0b6d96
b1353fb
4a35707
53fb437
958b527
02688a5
7394a66
5a7cae0
13671d3
b0f9e4c
a24e41e
90110ec
1c3457e
76bc363
f8a0705
7af8a1f
0ca3473
c822a27
ee38476
2f13a14
b02ca59
9a5dd1a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,14 +30,17 @@ class VaeImageProcessor(ConfigMixin): | |
|
||
Args: | ||
do_resize (`bool`, *optional*, defaults to `True`): | ||
Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. | ||
Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept | ||
`height` and `width` arguments from `preprocess` method | ||
vae_scale_factor (`int`, *optional*, defaults to `8`): | ||
VAE scale factor. If `do_resize` is True, the image will be automatically resized to multiples of this | ||
factor. | ||
resample (`str`, *optional*, defaults to `lanczos`): | ||
Resampling filter to use when resizing the image. | ||
do_normalize (`bool`, *optional*, defaults to `True`): | ||
Whether to normalize the image to [-1,1] | ||
do_convert_rgb (`bool`, *optional*, defaults to be `False`): | ||
Whether to convert the images to RGB format. | ||
""" | ||
|
||
config_name = CONFIG_NAME | ||
|
@@ -49,11 +52,12 @@ def __init__( | |
vae_scale_factor: int = 8, | ||
resample: str = "lanczos", | ||
do_normalize: bool = True, | ||
do_convert_rgb: bool = False, | ||
): | ||
super().__init__() | ||
|
||
@staticmethod | ||
def numpy_to_pil(images): | ||
def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image: | ||
""" | ||
Convert a numpy image or a batch of images to a PIL image. | ||
""" | ||
|
@@ -69,7 +73,19 @@ def numpy_to_pil(images): | |
return pil_images | ||
|
||
@staticmethod | ||
def numpy_to_pt(images): | ||
def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray: | ||
""" | ||
Convert a PIL image or a list of PIL images to numpy arrays. | ||
""" | ||
if not isinstance(images, list): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a need to also check if |
||
images = [images] | ||
images = [np.array(image).astype(np.float32) / 255.0 for image in images] | ||
images = np.stack(images, axis=0) | ||
|
||
return images | ||
|
||
@staticmethod | ||
def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor: | ||
""" | ||
Convert a numpy image to a pytorch tensor | ||
""" | ||
|
@@ -80,7 +96,7 @@ def numpy_to_pt(images): | |
return images | ||
|
||
@staticmethod | ||
def pt_to_numpy(images): | ||
def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray: | ||
""" | ||
Convert a pytorch tensor to a numpy image | ||
""" | ||
|
@@ -101,18 +117,39 @@ def denormalize(images): | |
""" | ||
return (images / 2 + 0.5).clamp(0, 1) | ||
|
||
def resize(self, images: PIL.Image.Image) -> PIL.Image.Image: | ||
@staticmethod | ||
def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image: | ||
""" | ||
Converts an image to RGB format. | ||
""" | ||
image = image.convert("RGB") | ||
return image | ||
|
||
def resize( | ||
self, | ||
image: PIL.Image.Image, | ||
height: Optional[int] = None, | ||
width: Optional[int] = None, | ||
) -> PIL.Image.Image: | ||
""" | ||
Resize a PIL image. Both height and width will be downscaled to the next integer multiple of `vae_scale_factor` | ||
""" | ||
w, h = images.size | ||
w, h = (x - x % self.config.vae_scale_factor for x in (w, h)) # resize to integer multiple of vae_scale_factor | ||
images = images.resize((w, h), resample=PIL_INTERPOLATION[self.config.resample]) | ||
return images | ||
if height is None: | ||
height = image.height | ||
if width is None: | ||
width = image.width | ||
|
||
width, height = ( | ||
x - x % self.config.vae_scale_factor for x in (width, height) | ||
) # resize to integer multiple of vae_scale_factor | ||
image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample]) | ||
Comment on lines
+142
to
+145
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👌 |
||
return image | ||
|
||
def preprocess( | ||
self, | ||
image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray], | ||
height: Optional[int] = None, | ||
width: Optional[int] = None, | ||
) -> torch.Tensor: | ||
""" | ||
Preprocess the image input, accepted formats are PIL images, numpy arrays or pytorch tensors" | ||
|
@@ -126,10 +163,11 @@ def preprocess( | |
) | ||
|
||
if isinstance(image[0], PIL.Image.Image): | ||
if self.config.do_convert_rgb: | ||
image = [self.convert_to_rgb(i) for i in image] | ||
if self.config.do_resize: | ||
image = [self.resize(i) for i in image] | ||
image = [np.array(i).astype(np.float32) / 255.0 for i in image] | ||
image = np.stack(image, axis=0) # to np | ||
image = [self.resize(i, height, width) for i in image] | ||
image = self.pil_to_numpy(image) # to np | ||
image = self.numpy_to_pt(image) # to pt | ||
|
||
elif isinstance(image[0], np.ndarray): | ||
|
@@ -146,7 +184,12 @@ def preprocess( | |
|
||
elif isinstance(image[0], torch.Tensor): | ||
image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) | ||
_, _, height, width = image.shape | ||
_, channel, height, width = image.shape | ||
|
||
# don't need any preprocess if the image is latents | ||
if channel == 4: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nice! Good for me! |
||
return image | ||
|
||
if self.config.do_resize and ( | ||
height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0 | ||
): | ||
|
Uh oh!
There was an error while loading. Please reload this page.