Skip to content

Commit 13994b2

Browse files
authored
RePaint fast tests and API conforming (#1701)
* add fast tests * better tests and fp16 * batch fixes * Reuse preprocessing * quickfix
1 parent ea90bf2 commit 13994b2

File tree

3 files changed

+126
-28
lines changed

3 files changed

+126
-28
lines changed

src/diffusers/pipelines/repaint/pipeline_repaint.py

Lines changed: 55 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,33 +13,61 @@
1313
# limitations under the License.
1414

1515

16-
from typing import Optional, Tuple, Union
16+
from typing import List, Optional, Tuple, Union
1717

1818
import numpy as np
1919
import torch
2020

2121
import PIL
22-
from tqdm.auto import tqdm
2322

2423
from ...models import UNet2DModel
2524
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
2625
from ...schedulers import RePaintScheduler
26+
from ...utils import PIL_INTERPOLATION, deprecate, logging
2727

2828

29-
def _preprocess_image(image: PIL.Image.Image):
30-
image = np.array(image.convert("RGB"))
31-
image = image[None].transpose(0, 3, 1, 2)
32-
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
29+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
30+
31+
32+
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
# (diverges from it by forcing PIL inputs to RGB, so it is no longer a verbatim copy).
def _preprocess_image(image: Union[List, PIL.Image.Image, torch.Tensor]):
    """Convert the input image(s) into a single batched tensor.

    Args:
        image: A tensor (returned unchanged), a single PIL image, or a
            non-empty list whose items are all PIL images or all tensors.

    Returns:
        For PIL input, a float32 tensor laid out as (batch, channel, height,
        width) with pixel values rescaled from [0, 255] to [-1, 1]. For a list
        of tensors, their concatenation along dim 0. Any other input is
        returned unchanged.
    """
    if isinstance(image, torch.Tensor):
        return image
    if isinstance(image, PIL.Image.Image):
        image = [image]

    if isinstance(image[0], PIL.Image.Image):
        # Every image is resized to the first image's size, rounded down to a multiple of 32.
        w, h = (x - x % 32 for x in image[0].size)

        # Force three channels: grayscale inputs have no channel axis and RGBA inputs
        # have four, either of which breaks the NHWC -> NCHW transpose or the
        # channel count produced below.
        image = [
            np.array(i.convert("RGB").resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :]
            for i in image
        ]
        image = np.concatenate(image, axis=0).astype(np.float32) / 255.0
        image = image.transpose(0, 3, 1, 2)  # NHWC -> NCHW
        image = 2.0 * image - 1.0  # [0, 1] -> [-1, 1]
        image = torch.from_numpy(image)
    elif isinstance(image[0], torch.Tensor):
        image = torch.cat(image, dim=0)
    return image
3452

3553

36-
def _preprocess_mask(mask: PIL.Image.Image):
37-
mask = np.array(mask.convert("L"))
38-
mask = mask.astype(np.float32) / 255.0
39-
mask = mask[None, None]
40-
mask[mask < 0.5] = 0
41-
mask[mask >= 0.5] = 1
42-
mask = torch.from_numpy(mask)
54+
def _preprocess_mask(mask: Union[List, PIL.Image.Image, torch.Tensor]):
    """Convert the input mask(s) into a single batched, binarised tensor.

    Args:
        mask: A tensor (returned unchanged), a single PIL image, or a
            non-empty list whose items are all PIL images or all tensors.

    Returns:
        For PIL input, a float32 tensor of shape (batch, 1, height, width)
        whose values are exactly 0.0 or 1.0. For a list of tensors, their
        concatenation along dim 0. Any other input is returned unchanged.
    """
    if isinstance(mask, torch.Tensor):
        return mask
    if isinstance(mask, PIL.Image.Image):
        mask = [mask]

    if isinstance(mask[0], PIL.Image.Image):
        # Every mask is resized to the first mask's size, rounded down to a multiple of 32.
        w, h = (x - x % 32 for x in mask[0].size)

        # `[None, None]` adds both a batch and a singleton channel axis. Without the
        # channel axis a batch of B masks would have shape (B, H, W), which only
        # broadcasts against (B, C, H, W) images when B == 1.
        mask = [
            np.array(m.convert("L").resize((w, h), resample=PIL_INTERPOLATION["nearest"]))[None, None]
            for m in mask
        ]
        mask = np.concatenate(mask, axis=0)
        mask = mask.astype(np.float32) / 255.0
        # Binarise: everything below mid-grey becomes 0, the rest becomes 1.
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1
        mask = torch.from_numpy(mask)
    elif isinstance(mask[0], torch.Tensor):
        mask = torch.cat(mask, dim=0)
    return mask
4472

4573

@@ -54,19 +82,20 @@ def __init__(self, unet, scheduler):
5482
@torch.no_grad()
5583
def __call__(
5684
self,
57-
original_image: Union[torch.FloatTensor, PIL.Image.Image],
58-
mask_image: Union[torch.FloatTensor, PIL.Image.Image],
85+
image: Union[torch.Tensor, PIL.Image.Image],
86+
mask_image: Union[torch.Tensor, PIL.Image.Image],
5987
num_inference_steps: int = 250,
6088
eta: float = 0.0,
6189
jump_length: int = 10,
6290
jump_n_sample: int = 10,
6391
generator: Optional[torch.Generator] = None,
6492
output_type: Optional[str] = "pil",
6593
return_dict: bool = True,
94+
**kwargs,
6695
) -> Union[ImagePipelineOutput, Tuple]:
6796
r"""
6897
Args:
69-
original_image (`torch.FloatTensor` or `PIL.Image.Image`):
98+
image (`torch.FloatTensor` or `PIL.Image.Image`):
7099
The original image to inpaint on.
71100
mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
72101
The mask_image where 0.0 values define which part of the original image to inpaint (change).
@@ -97,27 +126,29 @@ def __call__(
97126
generated images.
98127
"""
99128

100-
if not isinstance(original_image, torch.FloatTensor):
101-
original_image = _preprocess_image(original_image)
102-
original_image = original_image.to(self.device)
103-
if not isinstance(mask_image, torch.FloatTensor):
104-
mask_image = _preprocess_mask(mask_image)
105-
mask_image = mask_image.to(self.device)
129+
message = "Please use `image` instead of `original_image`."
130+
original_image = deprecate("original_image", "0.15.0", message, take_from=kwargs)
131+
original_image = original_image or image
132+
133+
original_image = _preprocess_image(original_image)
134+
original_image = original_image.to(device=self.device, dtype=self.unet.dtype)
135+
mask_image = _preprocess_mask(mask_image)
136+
mask_image = mask_image.to(device=self.device, dtype=self.unet.dtype)
106137

107138
# sample gaussian noise to begin the loop
108139
image = torch.randn(
109140
original_image.shape,
110141
generator=generator,
111142
device=self.device,
112143
)
113-
image = image.to(self.device)
144+
image = image.to(device=self.device, dtype=self.unet.dtype)
114145

115146
# set step values
116147
self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, self.device)
117148
self.scheduler.eta = eta
118149

119150
t_last = self.scheduler.timesteps[0] + 1
120-
for i, t in enumerate(tqdm(self.scheduler.timesteps)):
151+
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
121152
if t < t_last:
122153
# predict the noise residual
123154
model_output = self.unet(image, t).sample

src/diffusers/schedulers/scheduling_repaint.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -270,9 +270,13 @@ def step(
270270
# been observed.
271271

272272
# 5. Add noise
273-
noise = torch.randn(
274-
model_output.shape, dtype=model_output.dtype, generator=generator, device=model_output.device
275-
)
273+
device = model_output.device
274+
if device.type == "mps":
275+
# randn does not work reproducibly on mps
276+
noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
277+
noise = noise.to(device)
278+
else:
279+
noise = torch.randn(model_output.shape, generator=generator, device=device, dtype=model_output.dtype)
276280
std_dev_t = self.eta * self._get_variance(timestep) ** 0.5
277281

278282
variance = 0
@@ -305,7 +309,12 @@ def undo_step(self, sample, timestep, generator=None):
305309

306310
for i in range(n):
307311
beta = self.betas[timestep + i]
308-
noise = torch.randn(sample.shape, generator=generator, device=sample.device)
312+
if sample.device.type == "mps":
313+
# randn does not work reproducibly on mps
314+
noise = torch.randn(sample.shape, dtype=sample.dtype, generator=generator)
315+
noise = noise.to(sample.device)
316+
else:
317+
noise = torch.randn(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
309318

310319
# 10. Algorithm 1 Line 10 https://arxiv.org/pdf/2201.09865.pdf
311320
sample = (1 - beta) ** 0.5 * sample + beta**0.5 * noise

tests/pipelines/repaint/test_repaint.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,68 @@
2121
from diffusers import RePaintPipeline, RePaintScheduler, UNet2DModel
2222
from diffusers.utils.testing_utils import load_image, load_numpy, require_torch_gpu, slow, torch_device
2323

24+
from ...test_pipelines_common import PipelineTesterMixin
25+
2426

2527
torch.backends.cuda.matmul.allow_tf32 = False
2628

2729

30+
class RepaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    """Fast checks for `RePaintPipeline` built around a tiny, randomly initialised UNet."""

    pipeline_class = RePaintPipeline
    # NOTE(review): presumably read by PipelineTesterMixin to skip its CPU-offload test; confirm there.
    test_cpu_offload = False

    def get_dummy_components(self):
        """Return the pipeline components: a small seeded `UNet2DModel` and a default `RePaintScheduler`."""
        # Seed once so the UNet weights are identical on every run.
        torch.manual_seed(0)
        unet = UNet2DModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=3,
            out_channels=3,
            down_block_types=("DownBlock2D", "AttnDownBlock2D"),
            up_block_types=("AttnUpBlock2D", "UpBlock2D"),
        )
        scheduler = RePaintScheduler()
        return {"unet": unet, "scheduler": scheduler}

    def get_dummy_inputs(self, device, seed=0):
        """Return keyword arguments for one short pipeline call on `device`.

        The image is a seeded standard-normal (1, 3, 32, 32) tensor and the
        mask is 1.0 wherever that image is positive, 0.0 elsewhere.
        """
        if str(device).startswith("mps"):
            # No device-bound generator is created for mps; the seeded default generator is used.
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        image = np.random.RandomState(seed).standard_normal((1, 3, 32, 32))
        image = torch.from_numpy(image).to(device=device, dtype=torch.float32)
        mask = (image > 0).to(device=device, dtype=torch.float32)
        return {
            "image": image,
            "mask_image": mask,
            "generator": generator,
            "num_inference_steps": 5,
            "eta": 0.0,
            "jump_length": 2,
            "jump_n_sample": 2,
            "output_type": "numpy",
        }

    def test_repaint(self):
        """Run the pipeline on CPU and compare a 3x3 corner of the last channel to reference values."""
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        pipe = RePaintPipeline(**self.get_dummy_components())
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        image = pipe(**self.get_dummy_inputs(device)).images
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 32, 32, 3)
        expected_slice = np.array([1.0000, 0.5426, 0.5497, 0.2200, 1.0000, 1.0000, 0.5623, 1.0000, 0.6274])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
85+
2886
@slow
2987
@require_torch_gpu
3088
class RepaintPipelineIntegrationTests(unittest.TestCase):

0 commit comments

Comments
 (0)