
Commit 7fa90fb

rupertmenneer authored and patrickvonplaten committed
StableDiffusionInpaintingPipeline - resize image w.r.t height and width (huggingface#3322)
* StableDiffusionInpaintingPipeline now resizes input images and masks with respect to the passed height and width (the default is already 512). This addresses the common tensor-mismatch error. Also moved the type check into the relevant function to keep the main pipeline body tidy.

* Fixed StableDiffusionInpaintingPrepareMaskAndMaskedImageTests. After the previous commit these tests were failing, as height and width now need to be passed into the prepare_mask_and_masked_image function. I have updated the code and added a height/width variable per unit test, as that seemed more appropriate than the existing hard-coded solution.

* Added a resolution test to StableDiffusionInpaintPipelineSlowTests. This unit test takes the input, resizes it to a size that would previously fail (e.g. throw a tensor-mismatch error / not a multiple of 8), then passes it through the pipeline and verifies that the output has the correct dimensions with respect to the passed height and width.

---------

Co-authored-by: Patrick von Platen <[email protected]>
1 parent af6e35a commit 7fa90fb
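A minimal usage sketch of the behaviour this commit enables. The model id, file names and prompt below are placeholders and not part of the change; the point is only that PIL inputs whose size differs from the requested height/width are now resized inside the pipeline.

import torch
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline

# load the inpainting pipeline (fp16 on GPU is just an example configuration)
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
).to("cuda")

# hypothetical inputs whose size does not match the requested resolution;
# after this commit they are resized w.r.t. the passed height and width
image = Image.open("photo.png")       # e.g. 127x127
mask_image = Image.open("mask.png")   # e.g. 127x127

result = pipe(
    prompt="a red couch in a living room",
    image=image,
    mask_image=mask_image,
    height=512,
    width=512,
).images[0]
print(result.size)  # (512, 512)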

File tree

2 files changed: +104 -57 lines changed


src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 13 additions & 10 deletions
@@ -36,7 +36,7 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-def prepare_mask_and_masked_image(image, mask):
+def prepare_mask_and_masked_image(image, mask, height, width):
     """
     Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
     converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
@@ -64,6 +64,13 @@ def prepare_mask_and_masked_image(image, mask):
         tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
         dimensions: ``batch x channels x height x width``.
     """
+
+    if image is None:
+        raise ValueError("`image` input cannot be undefined.")
+
+    if mask is None:
+        raise ValueError("`mask_image` input cannot be undefined.")
+
     if isinstance(image, torch.Tensor):
         if not isinstance(mask, torch.Tensor):
             raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")
@@ -111,8 +118,9 @@ def prepare_mask_and_masked_image(image, mask):
     # preprocess image
     if isinstance(image, (PIL.Image.Image, np.ndarray)):
         image = [image]
-
     if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+        # resize all images w.r.t passed height and width
+        image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
         image = [np.array(i.convert("RGB"))[None, :] for i in image]
         image = np.concatenate(image, axis=0)
     elif isinstance(image, list) and isinstance(image[0], np.ndarray):
@@ -126,6 +134,7 @@ def prepare_mask_and_masked_image(image, mask):
         mask = [mask]
 
     if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
+        mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
         mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
         mask = mask.astype(np.float32) / 255.0
     elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
@@ -799,12 +808,6 @@ def __call__(
             negative_prompt_embeds,
         )
 
-        if image is None:
-            raise ValueError("`image` input cannot be undefined.")
-
-        if mask_image is None:
-            raise ValueError("`mask_image` input cannot be undefined.")
-
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -830,8 +833,8 @@ def __call__(
             negative_prompt_embeds=negative_prompt_embeds,
         )
 
-        # 4. Preprocess mask and image
-        mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
+        # 4. Preprocess mask and image - resizes image and mask w.r.t height and width
+        mask, masked_image = prepare_mask_and_masked_image(image, mask_image, height, width)
 
         # 5. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
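For reference, a small sketch of how the updated helper behaves with PIL inputs. The 127x127 size is chosen only to illustrate a case that previously raised a tensor-mismatch error; the import path follows the file shown above.

import numpy as np
from PIL import Image
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import (
    prepare_mask_and_masked_image,
)

height, width = 128, 128

# PIL inputs that do not match the target resolution
image = Image.fromarray(np.random.randint(0, 255, (127, 127, 3), dtype=np.uint8))
mask = Image.fromarray(
    (np.random.randint(0, 255, (127, 127), dtype=np.uint8) > 127).astype(np.uint8) * 255
)

# both are resized to (width, height) before being converted to tensors
mask_t, masked_image_t = prepare_mask_and_masked_image(image, mask, height, width)
print(tuple(mask_t.shape))          # (1, 1, 128, 128)
print(tuple(masked_image_t.shape))  # (1, 3, 128, 128)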

tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py

Lines changed: 91 additions & 47 deletions
@@ -303,6 +303,25 @@ def test_inpaint_compile(self):
         assert np.abs(expected_slice - image_slice).max() < 1e-4
         assert np.abs(expected_slice - image_slice).max() < 1e-3
 
+    def test_stable_diffusion_inpaint_pil_input_resolution_test(self):
+        pipe = StableDiffusionInpaintPipeline.from_pretrained(
+            "runwayml/stable-diffusion-inpainting", safety_checker=None
+        )
+        pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        pipe.enable_attention_slicing()
+
+        inputs = self.get_inputs(torch_device)
+        # change input image to a random size (one that would cause a tensor mismatch error)
+        inputs['image'] = inputs['image'].resize((127, 127))
+        inputs['mask_image'] = inputs['mask_image'].resize((127, 127))
+        inputs['height'] = 128
+        inputs['width'] = 128
+        image = pipe(**inputs).images
+        # verify that the returned image has the same height and width as the input height and width
+        assert image.shape == (1, inputs['height'], inputs['width'], 3)
+
 
 @nightly
 @require_torch_gpu
@@ -400,21 +419,22 @@ def test_inpaint_dpm(self):
 
 class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase):
     def test_pil_inputs(self):
-        im = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8)
+        height, width = 32, 32
+        im = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8)
         im = Image.fromarray(im)
-        mask = np.random.randint(0, 255, (32, 32), dtype=np.uint8) > 127.5
+        mask = np.random.randint(0, 255, (height, width), dtype=np.uint8) > 127.5
         mask = Image.fromarray((mask * 255).astype(np.uint8))
 
-        t_mask, t_masked = prepare_mask_and_masked_image(im, mask)
+        t_mask, t_masked = prepare_mask_and_masked_image(im, mask, height, width)
 
         self.assertTrue(isinstance(t_mask, torch.Tensor))
         self.assertTrue(isinstance(t_masked, torch.Tensor))
 
         self.assertEqual(t_mask.ndim, 4)
         self.assertEqual(t_masked.ndim, 4)
 
-        self.assertEqual(t_mask.shape, (1, 1, 32, 32))
-        self.assertEqual(t_masked.shape, (1, 3, 32, 32))
+        self.assertEqual(t_mask.shape, (1, 1, height, width))
+        self.assertEqual(t_masked.shape, (1, 3, height, width))
 
         self.assertTrue(t_mask.dtype == torch.float32)
         self.assertTrue(t_masked.dtype == torch.float32)
@@ -427,141 +447,165 @@ def test_pil_inputs(self):
         self.assertTrue(t_mask.sum() > 0.0)
 
     def test_np_inputs(self):
-        im_np = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8)
+        height, width = 32, 32
+
+        im_np = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8)
         im_pil = Image.fromarray(im_np)
-        mask_np = np.random.randint(0, 255, (32, 32), dtype=np.uint8) > 127.5
+        mask_np = np.random.randint(0, 255, (height, width,), dtype=np.uint8) > 127.5
         mask_pil = Image.fromarray((mask_np * 255).astype(np.uint8))
 
-        t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)
-        t_mask_pil, t_masked_pil = prepare_mask_and_masked_image(im_pil, mask_pil)
+        t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
+        t_mask_pil, t_masked_pil = prepare_mask_and_masked_image(im_pil, mask_pil, height, width)
 
         self.assertTrue((t_mask_np == t_mask_pil).all())
         self.assertTrue((t_masked_np == t_masked_pil).all())
 
     def test_torch_3D_2D_inputs(self):
-        im_tensor = torch.randint(0, 255, (3, 32, 32), dtype=torch.uint8)
-        mask_tensor = torch.randint(0, 255, (32, 32), dtype=torch.uint8) > 127.5
+        height, width = 32, 32
+
+        im_tensor = torch.randint(0, 255, (3, height, width,), dtype=torch.uint8)
+        mask_tensor = torch.randint(0, 255, (height, width,), dtype=torch.uint8) > 127.5
         im_np = im_tensor.numpy().transpose(1, 2, 0)
         mask_np = mask_tensor.numpy()
 
-        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
-        t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)
+        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width)
+        t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
 
         self.assertTrue((t_mask_tensor == t_mask_np).all())
         self.assertTrue((t_masked_tensor == t_masked_np).all())
 
     def test_torch_3D_3D_inputs(self):
-        im_tensor = torch.randint(0, 255, (3, 32, 32), dtype=torch.uint8)
-        mask_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8) > 127.5
+        height, width = 32, 32
+
+        im_tensor = torch.randint(0, 255, (3, height, width,), dtype=torch.uint8)
+        mask_tensor = torch.randint(0, 255, (1, height, width,), dtype=torch.uint8) > 127.5
         im_np = im_tensor.numpy().transpose(1, 2, 0)
         mask_np = mask_tensor.numpy()[0]
 
-        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
-        t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)
+        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width)
+        t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
 
         self.assertTrue((t_mask_tensor == t_mask_np).all())
         self.assertTrue((t_masked_tensor == t_masked_np).all())
 
     def test_torch_4D_2D_inputs(self):
-        im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8)
-        mask_tensor = torch.randint(0, 255, (32, 32), dtype=torch.uint8) > 127.5
+        height, width = 32, 32
+
+        im_tensor = torch.randint(0, 255, (1, 3, height, width,), dtype=torch.uint8)
+        mask_tensor = torch.randint(0, 255, (height, width,), dtype=torch.uint8) > 127.5
         im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
         mask_np = mask_tensor.numpy()
 
-        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
-        t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)
+        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width)
+        t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
 
         self.assertTrue((t_mask_tensor == t_mask_np).all())
         self.assertTrue((t_masked_tensor == t_masked_np).all())
 
     def test_torch_4D_3D_inputs(self):
-        im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8)
-        mask_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8) > 127.5
+        height, width = 32, 32
+
+        im_tensor = torch.randint(0, 255, (1, 3, height, width,), dtype=torch.uint8)
+        mask_tensor = torch.randint(0, 255, (1, height, width,), dtype=torch.uint8) > 127.5
         im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
         mask_np = mask_tensor.numpy()[0]
 
-        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
-        t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)
+        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width)
+        t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
 
         self.assertTrue((t_mask_tensor == t_mask_np).all())
         self.assertTrue((t_masked_tensor == t_masked_np).all())
 
     def test_torch_4D_4D_inputs(self):
-        im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8)
-        mask_tensor = torch.randint(0, 255, (1, 1, 32, 32), dtype=torch.uint8) > 127.5
+        height, width = 32, 32
+
+        im_tensor = torch.randint(0, 255, (1, 3, height, width,), dtype=torch.uint8)
+        mask_tensor = torch.randint(0, 255, (1, 1, height, width,), dtype=torch.uint8) > 127.5
         im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
         mask_np = mask_tensor.numpy()[0][0]
 
-        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
-        t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)
+        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width)
+        t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
 
         self.assertTrue((t_mask_tensor == t_mask_np).all())
         self.assertTrue((t_masked_tensor == t_masked_np).all())
 
     def test_torch_batch_4D_3D(self):
-        im_tensor = torch.randint(0, 255, (2, 3, 32, 32), dtype=torch.uint8)
-        mask_tensor = torch.randint(0, 255, (2, 32, 32), dtype=torch.uint8) > 127.5
+        height, width = 32, 32
+
+        im_tensor = torch.randint(0, 255, (2, 3, height, width,), dtype=torch.uint8)
+        mask_tensor = torch.randint(0, 255, (2, height, width,), dtype=torch.uint8) > 127.5
 
         im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor]
         mask_nps = [mask.numpy() for mask in mask_tensor]
 
-        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
-        nps = [prepare_mask_and_masked_image(i, m) for i, m in zip(im_nps, mask_nps)]
+        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width)
+        nps = [prepare_mask_and_masked_image(i, m, height, width) for i, m in zip(im_nps, mask_nps)]
         t_mask_np = torch.cat([n[0] for n in nps])
         t_masked_np = torch.cat([n[1] for n in nps])
 
         self.assertTrue((t_mask_tensor == t_mask_np).all())
         self.assertTrue((t_masked_tensor == t_masked_np).all())
 
     def test_torch_batch_4D_4D(self):
-        im_tensor = torch.randint(0, 255, (2, 3, 32, 32), dtype=torch.uint8)
-        mask_tensor = torch.randint(0, 255, (2, 1, 32, 32), dtype=torch.uint8) > 127.5
+        height, width = 32, 32
+
+        im_tensor = torch.randint(0, 255, (2, 3, height, width,), dtype=torch.uint8)
+        mask_tensor = torch.randint(0, 255, (2, 1, height, width,), dtype=torch.uint8) > 127.5
 
         im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor]
         mask_nps = [mask.numpy()[0] for mask in mask_tensor]
 
-        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
-        nps = [prepare_mask_and_masked_image(i, m) for i, m in zip(im_nps, mask_nps)]
+        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width)
+        nps = [prepare_mask_and_masked_image(i, m, height, width) for i, m in zip(im_nps, mask_nps)]
         t_mask_np = torch.cat([n[0] for n in nps])
         t_masked_np = torch.cat([n[1] for n in nps])
 
         self.assertTrue((t_mask_tensor == t_mask_np).all())
         self.assertTrue((t_masked_tensor == t_masked_np).all())
 
     def test_shape_mismatch(self):
+        height, width = 32, 32
+
         # test height and width
         with self.assertRaises(AssertionError):
-            prepare_mask_and_masked_image(torch.randn(3, 32, 32), torch.randn(64, 64))
+            prepare_mask_and_masked_image(torch.randn(3, height, width,), torch.randn(64, 64), height, width)
         # test batch dim
         with self.assertRaises(AssertionError):
-            prepare_mask_and_masked_image(torch.randn(2, 3, 32, 32), torch.randn(4, 64, 64))
+            prepare_mask_and_masked_image(torch.randn(2, 3, height, width,), torch.randn(4, 64, 64), height, width)
         # test batch dim
         with self.assertRaises(AssertionError):
-            prepare_mask_and_masked_image(torch.randn(2, 3, 32, 32), torch.randn(4, 1, 64, 64))
+            prepare_mask_and_masked_image(torch.randn(2, 3, height, width,), torch.randn(4, 1, 64, 64), height, width)
 
     def test_type_mismatch(self):
+        height, width = 32, 32
+
         # test tensors-only
         with self.assertRaises(TypeError):
-            prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.rand(3, 32, 32).numpy())
+            prepare_mask_and_masked_image(torch.rand(3, height, width,), torch.rand(3, height, width,).numpy(), height, width)
         # test tensors-only
         with self.assertRaises(TypeError):
-            prepare_mask_and_masked_image(torch.rand(3, 32, 32).numpy(), torch.rand(3, 32, 32))
+            prepare_mask_and_masked_image(torch.rand(3, height, width,).numpy(), torch.rand(3, height, width,), height, width)
 
     def test_channels_first(self):
+        height, width = 32, 32
+
         # test channels first for 3D tensors
         with self.assertRaises(AssertionError):
-            prepare_mask_and_masked_image(torch.rand(32, 32, 3), torch.rand(3, 32, 32))
+            prepare_mask_and_masked_image(torch.rand(height, width, 3), torch.rand(3, height, width,), height, width)
 
     def test_tensor_range(self):
+        height, width = 32, 32
+
         # test im <= 1
         with self.assertRaises(ValueError):
-            prepare_mask_and_masked_image(torch.ones(3, 32, 32) * 2, torch.rand(32, 32))
+            prepare_mask_and_masked_image(torch.ones(3, height, width,) * 2, torch.rand(height, width,), height, width)
         # test im >= -1
         with self.assertRaises(ValueError):
-            prepare_mask_and_masked_image(torch.ones(3, 32, 32) * (-2), torch.rand(32, 32))
+            prepare_mask_and_masked_image(torch.ones(3, height, width,) * (-2), torch.rand(height, width,), height, width)
         # test mask <= 1
         with self.assertRaises(ValueError):
-            prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.ones(32, 32) * 2)
+            prepare_mask_and_masked_image(torch.rand(3, height, width,), torch.ones(height, width,) * 2, height, width)
         # test mask >= 0
         with self.assertRaises(ValueError):
-            prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.ones(32, 32) * -1)
+            prepare_mask_and_masked_image(torch.rand(3, height, width,), torch.ones(height, width,) * -1, height, width)
