@@ -30,14 +30,17 @@ class VaeImageProcessor(ConfigMixin):
30
30
31
31
Args:
32
32
do_resize (`bool`, *optional*, defaults to `True`):
33
- Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
33
+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
34
+ `height` and `width` arguments from the `preprocess` method.
34
35
vae_scale_factor (`int`, *optional*, defaults to `8`):
35
36
VAE scale factor. If `do_resize` is True, the image will be automatically resized to multiples of this
36
37
factor.
37
38
resample (`str`, *optional*, defaults to `lanczos`):
38
39
Resampling filter to use when resizing the image.
39
40
do_normalize (`bool`, *optional*, defaults to `True`):
40
41
Whether to normalize the image to [-1,1]
42
+ do_convert_rgb (`bool`, *optional*, defaults to `False`):
43
+ Whether to convert the images to RGB format.
41
44
"""
42
45
43
46
config_name = CONFIG_NAME
@@ -49,11 +52,12 @@ def __init__(
49
52
vae_scale_factor : int = 8 ,
50
53
resample : str = "lanczos" ,
51
54
do_normalize : bool = True ,
55
+ do_convert_rgb : bool = False ,
52
56
):
53
57
super ().__init__ ()
54
58
55
59
@staticmethod
56
- def numpy_to_pil (images ) :
60
+ def numpy_to_pil (images : np . ndarray ) -> PIL . Image . Image :
57
61
"""
58
62
Convert a numpy image or a batch of images to a PIL image.
59
63
"""
@@ -69,7 +73,19 @@ def numpy_to_pil(images):
69
73
return pil_images
70
74
71
75
@staticmethod
def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
    """
    Convert a PIL image or a list of PIL images to a batched numpy array.

    A single image is treated as a one-element list; pixel values are
    rescaled from [0, 255] to [0, 1] float32 and the images are stacked
    along a new leading batch axis.
    """
    batch = images if isinstance(images, list) else [images]
    converted = []
    for img in batch:
        converted.append(np.array(img).astype(np.float32) / 255.0)
    return np.stack(converted, axis=0)
86
+
87
+ @staticmethod
88
+ def numpy_to_pt (images : np .ndarray ) -> torch .FloatTensor :
73
89
"""
74
90
Convert a numpy image to a pytorch tensor
75
91
"""
@@ -80,7 +96,7 @@ def numpy_to_pt(images):
80
96
return images
81
97
82
98
@staticmethod
83
- def pt_to_numpy (images ) :
99
+ def pt_to_numpy (images : torch . FloatTensor ) -> np . ndarray :
84
100
"""
85
101
Convert a pytorch tensor to a numpy image
86
102
"""
@@ -101,18 +117,39 @@ def denormalize(images):
101
117
"""
102
118
return (images / 2 + 0.5 ).clamp (0 , 1 )
103
119
104
- def resize (self , images : PIL .Image .Image ) -> PIL .Image .Image :
120
@staticmethod
def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
    """
    Convert a PIL image to RGB format and return the converted image.
    """
    return image.convert("RGB")
127
+
128
def resize(
    self,
    image: PIL.Image.Image,
    height: Optional[int] = None,
    width: Optional[int] = None,
) -> PIL.Image.Image:
    """
    Resize a PIL image so both dimensions are integer multiples of `vae_scale_factor`.

    When `height`/`width` are omitted, the image's own dimensions are used as
    the starting point; each target dimension is then rounded down to the
    nearest multiple of `self.config.vae_scale_factor` before resizing with
    the configured resampling filter.
    """
    factor = self.config.vae_scale_factor
    target_height = image.height if height is None else height
    target_width = image.width if width is None else width

    # Snap both dimensions down to the nearest multiple of the scale factor.
    # NOTE(review): a dimension smaller than `factor` rounds down to 0 here,
    # which PIL cannot resize to — confirm callers never pass such images.
    target_width -= target_width % factor
    target_height -= target_height % factor

    return image.resize((target_width, target_height), resample=PIL_INTERPOLATION[self.config.resample])
112
147
113
148
def preprocess (
114
149
self ,
115
150
image : Union [torch .FloatTensor , PIL .Image .Image , np .ndarray ],
151
+ height : Optional [int ] = None ,
152
+ width : Optional [int ] = None ,
116
153
) -> torch .Tensor :
117
154
"""
118
155
Preprocess the image input, accepted formats are PIL images, numpy arrays or pytorch tensors"
@@ -126,10 +163,11 @@ def preprocess(
126
163
)
127
164
128
165
if isinstance (image [0 ], PIL .Image .Image ):
166
+ if self .config .do_convert_rgb :
167
+ image = [self .convert_to_rgb (i ) for i in image ]
129
168
if self .config .do_resize :
130
- image = [self .resize (i ) for i in image ]
131
- image = [np .array (i ).astype (np .float32 ) / 255.0 for i in image ]
132
- image = np .stack (image , axis = 0 ) # to np
169
+ image = [self .resize (i , height , width ) for i in image ]
170
+ image = self .pil_to_numpy (image ) # to np
133
171
image = self .numpy_to_pt (image ) # to pt
134
172
135
173
elif isinstance (image [0 ], np .ndarray ):
@@ -146,7 +184,12 @@ def preprocess(
146
184
147
185
elif isinstance (image [0 ], torch .Tensor ):
148
186
image = torch .cat (image , axis = 0 ) if image [0 ].ndim == 4 else torch .stack (image , axis = 0 )
149
- _ , _ , height , width = image .shape
187
+ _ , channel , height , width = image .shape
188
+
189
+ # don't need any preprocess if the image is latents
190
+ if channel == 4 :
191
+ return image
192
+
150
193
if self .config .do_resize and (
151
194
height % self .config .vae_scale_factor != 0 or width % self .config .vae_scale_factor != 0
152
195
):
0 commit comments