
Commit a607772

yiyixuxu, patrickvonplaten, pcuenca, and sayakpaul authored
[WIP] Vae preprocessor refactor (PR1) (huggingface#3557)
VaeImageProcessor.preprocess refactor

* refactored VaeImageProcessor:
  - allow passing optional height and width arguments to resize()
  - add convert_to_rgb
* refactored the prepare_latents method for img2img pipelines so that if latents are passed directly as the image input, they will not be encoded again
* added a test in test_pipelines_common.py to test latents as image inputs
* refactored the img2img pipelines that accept latents as image: controlnet img2img, stable diffusion img2img, instruct_pix2pix

---------

Co-authored-by: yiyixuxu <yixu310@gmail.com>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
1 parent 6f7bc4e commit a607772

15 files changed: +427 −312 lines
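
Before the per-file diffs, a minimal sketch of what this refactor enables (not part of the commit; the model ID, shapes, and prompt below are illustrative placeholders): a 4-channel tensor passed as `image` to an img2img pipeline is now treated as latents and is not re-encoded by the VAE.

# Illustrative sketch only, assuming the post-refactor behavior described
# in this commit; model ID and shapes are placeholders.
import torch
from diffusers import StableDiffusionImg2ImgPipeline

pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# A (batch, 4, h, w) tensor is treated as latents and skips VAE encoding.
latents = torch.randn(1, 4, 64, 64, dtype=torch.float16, device="cuda")
result = pipe(prompt="a fantasy landscape", image=latents, strength=0.75)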

image_processor.py

Lines changed: 56 additions & 13 deletions
@@ -30,14 +30,17 @@ class VaeImageProcessor(ConfigMixin):
 
     Args:
         do_resize (`bool`, *optional*, defaults to `True`):
-            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
+            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
+            `height` and `width` arguments from the `preprocess` method.
         vae_scale_factor (`int`, *optional*, defaults to `8`):
             VAE scale factor. If `do_resize` is True, the image will be automatically resized to multiples of this
             factor.
         resample (`str`, *optional*, defaults to `lanczos`):
             Resampling filter to use when resizing the image.
         do_normalize (`bool`, *optional*, defaults to `True`):
             Whether to normalize the image to [-1, 1].
+        do_convert_rgb (`bool`, *optional*, defaults to `False`):
+            Whether to convert the images to RGB format.
     """
 
     config_name = CONFIG_NAME
@@ -49,11 +52,12 @@ def __init__(
         vae_scale_factor: int = 8,
         resample: str = "lanczos",
         do_normalize: bool = True,
+        do_convert_rgb: bool = False,
     ):
         super().__init__()
 
     @staticmethod
-    def numpy_to_pil(images):
+    def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image:
         """
         Convert a numpy image or a batch of images to a PIL image.
         """
@@ -69,7 +73,19 @@ def numpy_to_pil(images):
         return pil_images
 
     @staticmethod
-    def numpy_to_pt(images):
+    def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
+        """
+        Convert a PIL image or a list of PIL images to numpy arrays.
+        """
+        if not isinstance(images, list):
+            images = [images]
+        images = [np.array(image).astype(np.float32) / 255.0 for image in images]
+        images = np.stack(images, axis=0)
+
+        return images
+
+    @staticmethod
+    def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
         """
         Convert a numpy image to a pytorch tensor
         """
@@ -80,7 +96,7 @@ def numpy_to_pt(images):
         return images
 
     @staticmethod
-    def pt_to_numpy(images):
+    def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
         """
         Convert a pytorch tensor to a numpy image
         """
@@ -101,18 +117,39 @@ def denormalize(images):
         """
         return (images / 2 + 0.5).clamp(0, 1)
 
-    def resize(self, images: PIL.Image.Image) -> PIL.Image.Image:
+    @staticmethod
+    def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
+        """
+        Converts an image to RGB format.
+        """
+        image = image.convert("RGB")
+        return image
+
+    def resize(
+        self,
+        image: PIL.Image.Image,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+    ) -> PIL.Image.Image:
         """
         Resize a PIL image. Both height and width will be downscaled to the next integer multiple of `vae_scale_factor`
         """
-        w, h = images.size
-        w, h = (x - x % self.config.vae_scale_factor for x in (w, h))  # resize to integer multiple of vae_scale_factor
-        images = images.resize((w, h), resample=PIL_INTERPOLATION[self.config.resample])
-        return images
+        if height is None:
+            height = image.height
+        if width is None:
+            width = image.width
+
+        width, height = (
+            x - x % self.config.vae_scale_factor for x in (width, height)
+        )  # resize to integer multiple of vae_scale_factor
+        image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
+        return image
 
     def preprocess(
         self,
         image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
+        height: Optional[int] = None,
+        width: Optional[int] = None,
     ) -> torch.Tensor:
         """
         Preprocess the image input. Accepted formats are PIL images, numpy arrays or pytorch tensors.
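
A usage sketch for the new optional height/width arguments (illustrative values): explicit dimensions are still rounded down to the nearest multiple of vae_scale_factor, so 516 becomes 512 with the default factor of 8.

import PIL.Image
from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(vae_scale_factor=8, do_convert_rgb=True)
img = PIL.Image.new("RGB", (600, 600))
tensor = processor.preprocess(img, height=516, width=516)  # -> shape (1, 3, 512, 512)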
@@ -126,10 +163,11 @@ def preprocess(
             )
 
         if isinstance(image[0], PIL.Image.Image):
+            if self.config.do_convert_rgb:
+                image = [self.convert_to_rgb(i) for i in image]
             if self.config.do_resize:
-                image = [self.resize(i) for i in image]
-            image = [np.array(i).astype(np.float32) / 255.0 for i in image]
-            image = np.stack(image, axis=0)  # to np
+                image = [self.resize(i, height, width) for i in image]
+            image = self.pil_to_numpy(image)  # to np
             image = self.numpy_to_pt(image)  # to pt
 
         elif isinstance(image[0], np.ndarray):
@@ -146,7 +184,12 @@ def preprocess(
 
         elif isinstance(image[0], torch.Tensor):
             image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
-            _, _, height, width = image.shape
+            _, channel, height, width = image.shape
+
+            # don't need any preprocess if the image is latents
+            if channel == 4:
+                return image
+
             if self.config.do_resize and (
                 height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
             ):
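
The latents shortcut can be checked in isolation (a hedged sketch of the behavior added above):

import torch
from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor()
latents = torch.randn(1, 4, 64, 64)
out = processor.preprocess(latents)  # returned unchanged: no resize, no normalization
assert torch.equal(out, latents)

Note the heuristic: any 4-channel tensor takes this path, so an RGBA image tensor would also be treated as latents.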

pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py

Lines changed: 33 additions & 15 deletions
@@ -69,6 +69,11 @@
 
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
 def preprocess(image):
+    warnings.warn(
+        "The preprocess method is deprecated and will be removed in a future version. Please"
+        " use VaeImageProcessor.preprocess instead",
+        FutureWarning,
+    )
     if isinstance(image, torch.Tensor):
         return image
     elif isinstance(image, PIL.Image.Image):
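
To observe the new deprecation path (a sketch; it assumes the module-level preprocess shown above remains importable from the pipeline module, which may change as the deprecation proceeds):

import warnings
import torch
from diffusers.pipelines.alt_diffusion.pipeline_alt_diffusion_img2img import preprocess

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _ = preprocess(torch.randn(1, 3, 64, 64))
assert any(issubclass(w.category, FutureWarning) for w in caught)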
@@ -538,21 +543,26 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
         image = image.to(device=device, dtype=dtype)
 
         batch_size = batch_size * num_images_per_prompt
-        if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-            )
 
-        if isinstance(generator, list):
-            init_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
-            ]
-            init_latents = torch.cat(init_latents, dim=0)
+        if image.shape[1] == 4:
+            init_latents = image
+
         else:
-            init_latents = self.vae.encode(image).latent_dist.sample(generator)
+            if isinstance(generator, list) and len(generator) != batch_size:
+                raise ValueError(
+                    f"You have passed a list of generators of length {len(generator)}, but requested an effective"
+                    f" batch size of {batch_size}. Make sure the batch size matches the length of the generators."
+                )
+
+            elif isinstance(generator, list):
+                init_latents = [
+                    self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+                ]
+                init_latents = torch.cat(init_latents, dim=0)
+            else:
+                init_latents = self.vae.encode(image).latent_dist.sample(generator)
 
-        init_latents = self.vae.config.scaling_factor * init_latents
+            init_latents = self.vae.config.scaling_factor * init_latents
 
         if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
             # expand init_latents for batch_size
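
The new control flow can be summarized in a standalone sketch (simplified; encode_or_passthrough is a hypothetical name, and the real method also handles generator-list batching, batch expansion, and noise addition):

def encode_or_passthrough(vae, image, generator):
    # 4-channel input is already latents: use as-is, no scaling.
    if image.shape[1] == 4:
        return image
    # Otherwise encode with the VAE and apply the scaling factor.
    latents = vae.encode(image).latent_dist.sample(generator)
    return vae.config.scaling_factor * latents

Caller-supplied latents skip the scaling_factor multiply, so they are expected to already live in the scaled latent space.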
@@ -586,7 +596,14 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
-        image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+        image: Union[
+            torch.FloatTensor,
+            PIL.Image.Image,
+            np.ndarray,
+            List[torch.FloatTensor],
+            List[PIL.Image.Image],
+            List[np.ndarray],
+        ] = None,
         strength: float = 0.8,
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
@@ -609,9 +626,10 @@ def __call__(
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                 instead.
-            image (`torch.FloatTensor` or `PIL.Image.Image`):
+            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                 `Image`, or tensor representing an image batch, that will be used as the starting point for the
-                process.
+                process. Can also accept image latents as `image`; if latents are passed directly, they will not be
+                encoded again.
             strength (`float`, *optional*, defaults to 0.8):
                 Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
                 will be used as a starting point, adding more noise to it the larger the `strength`. The number of
