@@ -303,6 +303,25 @@ def test_inpaint_compile(self):
303
303
assert np .abs (expected_slice - image_slice ).max () < 1e-4
304
304
assert np .abs (expected_slice - image_slice ).max () < 1e-3
305
305
306
+ def test_stable_diffusion_inpaint_pil_input_resolution_test (self ):
307
+ pipe = StableDiffusionInpaintPipeline .from_pretrained (
308
+ "runwayml/stable-diffusion-inpainting" , safety_checker = None
309
+ )
310
+ pipe .scheduler = LMSDiscreteScheduler .from_config (pipe .scheduler .config )
311
+ pipe .to (torch_device )
312
+ pipe .set_progress_bar_config (disable = None )
313
+ pipe .enable_attention_slicing ()
314
+
315
+ inputs = self .get_inputs (torch_device )
316
+ # change input image to a random size (one that would cause a tensor mismatch error)
317
+ inputs ['image' ] = inputs ['image' ].resize ((127 ,127 ))
318
+ inputs ['mask_image' ] = inputs ['mask_image' ].resize ((127 ,127 ))
319
+ inputs ['height' ] = 128
320
+ inputs ['width' ] = 128
321
+ image = pipe (** inputs ).images
322
+ # verify that the returned image has the same height and width as the input height and width
323
+ assert image .shape == (1 , inputs ['height' ], inputs ['width' ], 3 )
324
+
306
325
307
326
@nightly
308
327
@require_torch_gpu
@@ -400,21 +419,22 @@ def test_inpaint_dpm(self):
400
419
401
420
class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests (unittest .TestCase ):
402
421
def test_pil_inputs (self ):
403
- im = np .random .randint (0 , 255 , (32 , 32 , 3 ), dtype = np .uint8 )
422
+ height , width = 32 , 32
423
+ im = np .random .randint (0 , 255 , (height , width , 3 ), dtype = np .uint8 )
404
424
im = Image .fromarray (im )
405
- mask = np .random .randint (0 , 255 , (32 , 32 ), dtype = np .uint8 ) > 127.5
425
+ mask = np .random .randint (0 , 255 , (height , width ), dtype = np .uint8 ) > 127.5
406
426
mask = Image .fromarray ((mask * 255 ).astype (np .uint8 ))
407
427
408
- t_mask , t_masked = prepare_mask_and_masked_image (im , mask )
428
+ t_mask , t_masked = prepare_mask_and_masked_image (im , mask , height , width )
409
429
410
430
self .assertTrue (isinstance (t_mask , torch .Tensor ))
411
431
self .assertTrue (isinstance (t_masked , torch .Tensor ))
412
432
413
433
self .assertEqual (t_mask .ndim , 4 )
414
434
self .assertEqual (t_masked .ndim , 4 )
415
435
416
- self .assertEqual (t_mask .shape , (1 , 1 , 32 , 32 ))
417
- self .assertEqual (t_masked .shape , (1 , 3 , 32 , 32 ))
436
+ self .assertEqual (t_mask .shape , (1 , 1 , height , width ))
437
+ self .assertEqual (t_masked .shape , (1 , 3 , height , width ))
418
438
419
439
self .assertTrue (t_mask .dtype == torch .float32 )
420
440
self .assertTrue (t_masked .dtype == torch .float32 )
@@ -427,141 +447,165 @@ def test_pil_inputs(self):
427
447
self .assertTrue (t_mask .sum () > 0.0 )
428
448
429
449
def test_np_inputs (self ):
430
- im_np = np .random .randint (0 , 255 , (32 , 32 , 3 ), dtype = np .uint8 )
450
+ height , width = 32 , 32
451
+
452
+ im_np = np .random .randint (0 , 255 , (height , width , 3 ), dtype = np .uint8 )
431
453
im_pil = Image .fromarray (im_np )
432
- mask_np = np .random .randint (0 , 255 , (32 , 32 ), dtype = np .uint8 ) > 127.5
454
+ mask_np = np .random .randint (0 , 255 , (height , width , ), dtype = np .uint8 ) > 127.5
433
455
mask_pil = Image .fromarray ((mask_np * 255 ).astype (np .uint8 ))
434
456
435
- t_mask_np , t_masked_np = prepare_mask_and_masked_image (im_np , mask_np )
436
- t_mask_pil , t_masked_pil = prepare_mask_and_masked_image (im_pil , mask_pil )
457
+ t_mask_np , t_masked_np = prepare_mask_and_masked_image (im_np , mask_np , height , width )
458
+ t_mask_pil , t_masked_pil = prepare_mask_and_masked_image (im_pil , mask_pil , height , width )
437
459
438
460
self .assertTrue ((t_mask_np == t_mask_pil ).all ())
439
461
self .assertTrue ((t_masked_np == t_masked_pil ).all ())
440
462
441
463
def test_torch_3D_2D_inputs (self ):
442
- im_tensor = torch .randint (0 , 255 , (3 , 32 , 32 ), dtype = torch .uint8 )
443
- mask_tensor = torch .randint (0 , 255 , (32 , 32 ), dtype = torch .uint8 ) > 127.5
464
+ height , width = 32 , 32
465
+
466
+ im_tensor = torch .randint (0 , 255 , (3 , height , width ,), dtype = torch .uint8 )
467
+ mask_tensor = torch .randint (0 , 255 , (height , width ,), dtype = torch .uint8 ) > 127.5
444
468
im_np = im_tensor .numpy ().transpose (1 , 2 , 0 )
445
469
mask_np = mask_tensor .numpy ()
446
470
447
- t_mask_tensor , t_masked_tensor = prepare_mask_and_masked_image (im_tensor / 127.5 - 1 , mask_tensor )
448
- t_mask_np , t_masked_np = prepare_mask_and_masked_image (im_np , mask_np )
471
+ t_mask_tensor , t_masked_tensor = prepare_mask_and_masked_image (im_tensor / 127.5 - 1 , mask_tensor , height , width )
472
+ t_mask_np , t_masked_np = prepare_mask_and_masked_image (im_np , mask_np , height , width )
449
473
450
474
self .assertTrue ((t_mask_tensor == t_mask_np ).all ())
451
475
self .assertTrue ((t_masked_tensor == t_masked_np ).all ())
452
476
453
477
def test_torch_3D_3D_inputs (self ):
454
- im_tensor = torch .randint (0 , 255 , (3 , 32 , 32 ), dtype = torch .uint8 )
455
- mask_tensor = torch .randint (0 , 255 , (1 , 32 , 32 ), dtype = torch .uint8 ) > 127.5
478
+ height , width = 32 , 32
479
+
480
+ im_tensor = torch .randint (0 , 255 , (3 , height , width ,), dtype = torch .uint8 )
481
+ mask_tensor = torch .randint (0 , 255 , (1 , height , width ,), dtype = torch .uint8 ) > 127.5
456
482
im_np = im_tensor .numpy ().transpose (1 , 2 , 0 )
457
483
mask_np = mask_tensor .numpy ()[0 ]
458
484
459
- t_mask_tensor , t_masked_tensor = prepare_mask_and_masked_image (im_tensor / 127.5 - 1 , mask_tensor )
460
- t_mask_np , t_masked_np = prepare_mask_and_masked_image (im_np , mask_np )
485
+ t_mask_tensor , t_masked_tensor = prepare_mask_and_masked_image (im_tensor / 127.5 - 1 , mask_tensor , height , width )
486
+ t_mask_np , t_masked_np = prepare_mask_and_masked_image (im_np , mask_np , height , width )
461
487
462
488
self .assertTrue ((t_mask_tensor == t_mask_np ).all ())
463
489
self .assertTrue ((t_masked_tensor == t_masked_np ).all ())
464
490
465
491
def test_torch_4D_2D_inputs (self ):
466
- im_tensor = torch .randint (0 , 255 , (1 , 3 , 32 , 32 ), dtype = torch .uint8 )
467
- mask_tensor = torch .randint (0 , 255 , (32 , 32 ), dtype = torch .uint8 ) > 127.5
492
+ height , width = 32 , 32
493
+
494
+ im_tensor = torch .randint (0 , 255 , (1 , 3 , height , width ,), dtype = torch .uint8 )
495
+ mask_tensor = torch .randint (0 , 255 , (height , width ,), dtype = torch .uint8 ) > 127.5
468
496
im_np = im_tensor .numpy ()[0 ].transpose (1 , 2 , 0 )
469
497
mask_np = mask_tensor .numpy ()
470
498
471
- t_mask_tensor , t_masked_tensor = prepare_mask_and_masked_image (im_tensor / 127.5 - 1 , mask_tensor )
472
- t_mask_np , t_masked_np = prepare_mask_and_masked_image (im_np , mask_np )
499
+ t_mask_tensor , t_masked_tensor = prepare_mask_and_masked_image (im_tensor / 127.5 - 1 , mask_tensor , height , width )
500
+ t_mask_np , t_masked_np = prepare_mask_and_masked_image (im_np , mask_np , height , width )
473
501
474
502
self .assertTrue ((t_mask_tensor == t_mask_np ).all ())
475
503
self .assertTrue ((t_masked_tensor == t_masked_np ).all ())
476
504
477
505
def test_torch_4D_3D_inputs (self ):
478
- im_tensor = torch .randint (0 , 255 , (1 , 3 , 32 , 32 ), dtype = torch .uint8 )
479
- mask_tensor = torch .randint (0 , 255 , (1 , 32 , 32 ), dtype = torch .uint8 ) > 127.5
506
+ height , width = 32 , 32
507
+
508
+ im_tensor = torch .randint (0 , 255 , (1 , 3 , height , width ,), dtype = torch .uint8 )
509
+ mask_tensor = torch .randint (0 , 255 , (1 , height , width ,), dtype = torch .uint8 ) > 127.5
480
510
im_np = im_tensor .numpy ()[0 ].transpose (1 , 2 , 0 )
481
511
mask_np = mask_tensor .numpy ()[0 ]
482
512
483
- t_mask_tensor , t_masked_tensor = prepare_mask_and_masked_image (im_tensor / 127.5 - 1 , mask_tensor )
484
- t_mask_np , t_masked_np = prepare_mask_and_masked_image (im_np , mask_np )
513
+ t_mask_tensor , t_masked_tensor = prepare_mask_and_masked_image (im_tensor / 127.5 - 1 , mask_tensor , height , width )
514
+ t_mask_np , t_masked_np = prepare_mask_and_masked_image (im_np , mask_np , height , width )
485
515
486
516
self .assertTrue ((t_mask_tensor == t_mask_np ).all ())
487
517
self .assertTrue ((t_masked_tensor == t_masked_np ).all ())
488
518
489
519
def test_torch_4D_4D_inputs (self ):
490
- im_tensor = torch .randint (0 , 255 , (1 , 3 , 32 , 32 ), dtype = torch .uint8 )
491
- mask_tensor = torch .randint (0 , 255 , (1 , 1 , 32 , 32 ), dtype = torch .uint8 ) > 127.5
520
+ height , width = 32 , 32
521
+
522
+ im_tensor = torch .randint (0 , 255 , (1 , 3 , height , width ,), dtype = torch .uint8 )
523
+ mask_tensor = torch .randint (0 , 255 , (1 , 1 , height , width ,), dtype = torch .uint8 ) > 127.5
492
524
im_np = im_tensor .numpy ()[0 ].transpose (1 , 2 , 0 )
493
525
mask_np = mask_tensor .numpy ()[0 ][0 ]
494
526
495
- t_mask_tensor , t_masked_tensor = prepare_mask_and_masked_image (im_tensor / 127.5 - 1 , mask_tensor )
496
- t_mask_np , t_masked_np = prepare_mask_and_masked_image (im_np , mask_np )
527
+ t_mask_tensor , t_masked_tensor = prepare_mask_and_masked_image (im_tensor / 127.5 - 1 , mask_tensor , height , width )
528
+ t_mask_np , t_masked_np = prepare_mask_and_masked_image (im_np , mask_np , height , width )
497
529
498
530
self .assertTrue ((t_mask_tensor == t_mask_np ).all ())
499
531
self .assertTrue ((t_masked_tensor == t_masked_np ).all ())
500
532
501
533
def test_torch_batch_4D_3D (self ):
502
- im_tensor = torch .randint (0 , 255 , (2 , 3 , 32 , 32 ), dtype = torch .uint8 )
503
- mask_tensor = torch .randint (0 , 255 , (2 , 32 , 32 ), dtype = torch .uint8 ) > 127.5
534
+ height , width = 32 , 32
535
+
536
+ im_tensor = torch .randint (0 , 255 , (2 , 3 , height , width ,), dtype = torch .uint8 )
537
+ mask_tensor = torch .randint (0 , 255 , (2 , height , width ,), dtype = torch .uint8 ) > 127.5
504
538
505
539
im_nps = [im .numpy ().transpose (1 , 2 , 0 ) for im in im_tensor ]
506
540
mask_nps = [mask .numpy () for mask in mask_tensor ]
507
541
508
- t_mask_tensor , t_masked_tensor = prepare_mask_and_masked_image (im_tensor / 127.5 - 1 , mask_tensor )
509
- nps = [prepare_mask_and_masked_image (i , m ) for i , m in zip (im_nps , mask_nps )]
542
+ t_mask_tensor , t_masked_tensor = prepare_mask_and_masked_image (im_tensor / 127.5 - 1 , mask_tensor , height , width )
543
+ nps = [prepare_mask_and_masked_image (i , m , height , width ) for i , m in zip (im_nps , mask_nps )]
510
544
t_mask_np = torch .cat ([n [0 ] for n in nps ])
511
545
t_masked_np = torch .cat ([n [1 ] for n in nps ])
512
546
513
547
self .assertTrue ((t_mask_tensor == t_mask_np ).all ())
514
548
self .assertTrue ((t_masked_tensor == t_masked_np ).all ())
515
549
516
550
def test_torch_batch_4D_4D (self ):
517
- im_tensor = torch .randint (0 , 255 , (2 , 3 , 32 , 32 ), dtype = torch .uint8 )
518
- mask_tensor = torch .randint (0 , 255 , (2 , 1 , 32 , 32 ), dtype = torch .uint8 ) > 127.5
551
+ height , width = 32 , 32
552
+
553
+ im_tensor = torch .randint (0 , 255 , (2 , 3 , height , width ,), dtype = torch .uint8 )
554
+ mask_tensor = torch .randint (0 , 255 , (2 , 1 , height , width ,), dtype = torch .uint8 ) > 127.5
519
555
520
556
im_nps = [im .numpy ().transpose (1 , 2 , 0 ) for im in im_tensor ]
521
557
mask_nps = [mask .numpy ()[0 ] for mask in mask_tensor ]
522
558
523
- t_mask_tensor , t_masked_tensor = prepare_mask_and_masked_image (im_tensor / 127.5 - 1 , mask_tensor )
524
- nps = [prepare_mask_and_masked_image (i , m ) for i , m in zip (im_nps , mask_nps )]
559
+ t_mask_tensor , t_masked_tensor = prepare_mask_and_masked_image (im_tensor / 127.5 - 1 , mask_tensor , height , width )
560
+ nps = [prepare_mask_and_masked_image (i , m , height , width ) for i , m in zip (im_nps , mask_nps )]
525
561
t_mask_np = torch .cat ([n [0 ] for n in nps ])
526
562
t_masked_np = torch .cat ([n [1 ] for n in nps ])
527
563
528
564
self .assertTrue ((t_mask_tensor == t_mask_np ).all ())
529
565
self .assertTrue ((t_masked_tensor == t_masked_np ).all ())
530
566
531
567
def test_shape_mismatch (self ):
568
+ height , width = 32 , 32
569
+
532
570
# test height and width
533
571
with self .assertRaises (AssertionError ):
534
- prepare_mask_and_masked_image (torch .randn (3 , 32 , 32 ), torch .randn (64 , 64 ))
572
+ prepare_mask_and_masked_image (torch .randn (3 , height , width , ), torch .randn (64 , 64 ), height , width )
535
573
# test batch dim
536
574
with self .assertRaises (AssertionError ):
537
- prepare_mask_and_masked_image (torch .randn (2 , 3 , 32 , 32 ), torch .randn (4 , 64 , 64 ))
575
+ prepare_mask_and_masked_image (torch .randn (2 , 3 , height , width , ), torch .randn (4 , 64 , 64 ), height , width )
538
576
# test batch dim
539
577
with self .assertRaises (AssertionError ):
540
- prepare_mask_and_masked_image (torch .randn (2 , 3 , 32 , 32 ), torch .randn (4 , 1 , 64 , 64 ))
578
+ prepare_mask_and_masked_image (torch .randn (2 , 3 , height , width , ), torch .randn (4 , 1 , 64 , 64 ), height , width )
541
579
542
580
def test_type_mismatch (self ):
581
+ height , width = 32 , 32
582
+
543
583
# test tensors-only
544
584
with self .assertRaises (TypeError ):
545
- prepare_mask_and_masked_image (torch .rand (3 , 32 , 32 ), torch .rand (3 , 32 , 32 ).numpy ())
585
+ prepare_mask_and_masked_image (torch .rand (3 , height , width , ), torch .rand (3 , height , width , ).numpy (), height , width )
546
586
# test tensors-only
547
587
with self .assertRaises (TypeError ):
548
- prepare_mask_and_masked_image (torch .rand (3 , 32 , 32 ).numpy (), torch .rand (3 , 32 , 32 ) )
588
+ prepare_mask_and_masked_image (torch .rand (3 , height , width , ).numpy (), torch .rand (3 , height , width ,), height , width )
549
589
550
590
def test_channels_first (self ):
591
+ height , width = 32 , 32
592
+
551
593
# test channels first for 3D tensors
552
594
with self .assertRaises (AssertionError ):
553
- prepare_mask_and_masked_image (torch .rand (32 , 32 , 3 ), torch .rand (3 , 32 , 32 ) )
595
+ prepare_mask_and_masked_image (torch .rand (height , width , 3 ), torch .rand (3 , height , width ,), height , width )
554
596
555
597
def test_tensor_range (self ):
598
+ height , width = 32 , 32
599
+
556
600
# test im <= 1
557
601
with self .assertRaises (ValueError ):
558
- prepare_mask_and_masked_image (torch .ones (3 , 32 , 32 ) * 2 , torch .rand (32 , 32 ) )
602
+ prepare_mask_and_masked_image (torch .ones (3 , height , width , ) * 2 , torch .rand (height , width ,), height , width )
559
603
# test im >= -1
560
604
with self .assertRaises (ValueError ):
561
- prepare_mask_and_masked_image (torch .ones (3 , 32 , 32 ) * (- 2 ), torch .rand (32 , 32 ) )
605
+ prepare_mask_and_masked_image (torch .ones (3 , height , width , ) * (- 2 ), torch .rand (height , width ,), height , width )
562
606
# test mask <= 1
563
607
with self .assertRaises (ValueError ):
564
- prepare_mask_and_masked_image (torch .rand (3 , 32 , 32 ), torch .ones (32 , 32 ) * 2 )
608
+ prepare_mask_and_masked_image (torch .rand (3 , height , width , ), torch .ones (height , width , ) * 2 , height , width )
565
609
# test mask >= 0
566
610
with self .assertRaises (ValueError ):
567
- prepare_mask_and_masked_image (torch .rand (3 , 32 , 32 ), torch .ones (32 , 32 ) * - 1 )
611
+ prepare_mask_and_masked_image (torch .rand (3 , height , width , ), torch .ones (height , width , ) * - 1 , height , width )
0 commit comments