# limitations under the License.
- from typing import Optional, Tuple, Union
+ from typing import List, Optional, Tuple, Union

import numpy as np
import torch

import PIL
- from tqdm.auto import tqdm

from ...models import UNet2DModel
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import RePaintScheduler
+ from ...utils import PIL_INTERPOLATION, deprecate, logging

- def _preprocess_image(image: PIL.Image.Image):
-     image = np.array(image.convert("RGB"))
-     image = image[None].transpose(0, 3, 1, 2)
-     image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
+ def _preprocess_image(image: Union[List, PIL.Image.Image, torch.Tensor]):
+     if isinstance(image, torch.Tensor):
+         return image
+     elif isinstance(image, PIL.Image.Image):
+         image = [image]
+
+     if isinstance(image[0], PIL.Image.Image):
+         w, h = image[0].size
+         w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+
+         image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
+         image = np.concatenate(image, axis=0)
+         image = np.array(image).astype(np.float32) / 255.0
+         image = image.transpose(0, 3, 1, 2)
+         image = 2.0 * image - 1.0
+         image = torch.from_numpy(image)
+     elif isinstance(image[0], torch.Tensor):
+         image = torch.cat(image, dim=0)
    return image


- def _preprocess_mask(mask: PIL.Image.Image):
-     mask = np.array(mask.convert("L"))
-     mask = mask.astype(np.float32) / 255.0
-     mask = mask[None, None]
-     mask[mask < 0.5] = 0
-     mask[mask >= 0.5] = 1
-     mask = torch.from_numpy(mask)
+ def _preprocess_mask(mask: Union[List, PIL.Image.Image, torch.Tensor]):
+     if isinstance(mask, torch.Tensor):
+         return mask
+     elif isinstance(mask, PIL.Image.Image):
+         mask = [mask]
+
+     if isinstance(mask[0], PIL.Image.Image):
+         w, h = mask[0].size
+         w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+         mask = [np.array(m.convert("L").resize((w, h), resample=PIL_INTERPOLATION["nearest"]))[None, :] for m in mask]
+         mask = np.concatenate(mask, axis=0)
+         mask = mask.astype(np.float32) / 255.0
+         mask[mask < 0.5] = 0
+         mask[mask >= 0.5] = 1
+         mask = torch.from_numpy(mask)
+     elif isinstance(mask[0], torch.Tensor):
+         mask = torch.cat(mask, dim=0)
    return mask
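As a rough illustration of the contract the reworked helpers implement (not part of the diff; sizes are arbitrary and the snippet assumes it runs in this module's namespace):

import PIL.Image

# Inputs are resized down to the nearest multiple of 32, so 515x517 becomes 512x512.
img = PIL.Image.new("RGB", (515, 517))
msk = PIL.Image.new("L", (515, 517))
img_t = _preprocess_image(img)  # float32 tensor, shape (1, 3, 512, 512), values in [-1.0, 1.0]
msk_t = _preprocess_mask(msk)   # float32 tensor, shape (1, 512, 512), values binarized to {0.0, 1.0}
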
@@ -54,19 +82,20 @@ def __init__(self, unet, scheduler):
    @torch.no_grad()
    def __call__(
        self,
-         original_image: Union[torch.FloatTensor, PIL.Image.Image],
-         mask_image: Union[torch.FloatTensor, PIL.Image.Image],
+         image: Union[torch.Tensor, PIL.Image.Image],
+         mask_image: Union[torch.Tensor, PIL.Image.Image],
        num_inference_steps: int = 250,
        eta: float = 0.0,
        jump_length: int = 10,
        jump_n_sample: int = 10,
        generator: Optional[torch.Generator] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
+         **kwargs,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        Args:
-             original_image (`torch.FloatTensor` or `PIL.Image.Image`):
+             image (`torch.FloatTensor` or `PIL.Image.Image`):
                The original image to inpaint on.
            mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
                The mask_image where 0.0 values define which part of the original image to inpaint (change).
@@ -97,27 +126,29 @@ def __call__(
            generated images.
        """

-         if not isinstance(original_image, torch.FloatTensor):
-             original_image = _preprocess_image(original_image)
-         original_image = original_image.to(self.device)
-         if not isinstance(mask_image, torch.FloatTensor):
-             mask_image = _preprocess_mask(mask_image)
-         mask_image = mask_image.to(self.device)
+         message = "Please use `image` instead of `original_image`."
+         original_image = deprecate("original_image", "0.15.0", message, take_from=kwargs)
+         original_image = original_image or image
+
+         original_image = _preprocess_image(original_image)
+         original_image = original_image.to(device=self.device, dtype=self.unet.dtype)
+         mask_image = _preprocess_mask(mask_image)
+         mask_image = mask_image.to(device=self.device, dtype=self.unet.dtype)

        # sample gaussian noise to begin the loop
        image = torch.randn(
            original_image.shape,
            generator=generator,
            device=self.device,
        )
-         image = image.to(self.device)
+         image = image.to(device=self.device, dtype=self.unet.dtype)

        # set step values
        self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, self.device)
        self.scheduler.eta = eta

        t_last = self.scheduler.timesteps[0] + 1
-         for i, t in enumerate(tqdm(self.scheduler.timesteps)):
+         for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
            if t < t_last:
                # predict the noise residual
                model_output = self.unet(image, t).sample
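For reference, a minimal usage sketch against the renamed keyword argument. The checkpoint id, file paths, and CUDA device below are illustrative assumptions, not taken from the diff:

import torch
import PIL.Image
from diffusers import RePaintPipeline, RePaintScheduler

original = PIL.Image.open("face_256.png").convert("RGB")     # hypothetical 256x256 input
mask = PIL.Image.open("face_256_mask.png").convert("L")      # hypothetical mask; white = keep, black = inpaint

scheduler = RePaintScheduler.from_pretrained("google/ddpm-ema-celebahq-256")
pipe = RePaintPipeline.from_pretrained("google/ddpm-ema-celebahq-256", scheduler=scheduler).to("cuda")

generator = torch.Generator(device="cuda").manual_seed(0)
output = pipe(
    image=original,            # formerly `original_image=`; the old keyword now routes through `deprecate`
    mask_image=mask,
    num_inference_steps=250,
    eta=0.0,
    jump_length=10,
    jump_n_sample=10,
    generator=generator,
)
output.images[0].save("inpainted.png")

Calls that still pass `original_image=` keep working for now, but emit a deprecation warning and are slated for removal in 0.15.0 per the `deprecate` call in the diff.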