@@ -30,9 +30,10 @@ class DDPMPipelineFastTests(unittest.TestCase):
3030 def dummy_uncond_unet (self ):
3131 torch .manual_seed (0 )
3232 model = UNet2DModel (
33- block_out_channels = (32 , 64 ),
34- layers_per_block = 2 ,
35- sample_size = 32 ,
33+ block_out_channels = (4 , 8 ),
34+ layers_per_block = 1 ,
35+ norm_num_groups = 4 ,
36+ sample_size = 8 ,
3637 in_channels = 3 ,
3738 out_channels = 3 ,
3839 down_block_types = ("DownBlock2D" , "AttnDownBlock2D" ),
@@ -58,10 +59,8 @@ def test_fast_inference(self):
5859 image_slice = image [0 , - 3 :, - 3 :, - 1 ]
5960 image_from_tuple_slice = image_from_tuple [0 , - 3 :, - 3 :, - 1 ]
6061
61- assert image .shape == (1 , 32 , 32 , 3 )
62- expected_slice = np .array (
63- [9.956e-01 , 5.785e-01 , 4.675e-01 , 9.930e-01 , 0.0 , 1.000 , 1.199e-03 , 2.648e-04 , 5.101e-04 ]
64- )
62+ assert image .shape == (1 , 8 , 8 , 3 )
63+ expected_slice = np .array ([0.0 , 0.9996672 , 0.00329116 , 1.0 , 0.9995991 , 1.0 , 0.0060907 , 0.00115037 , 0.0 ])
6564
6665 assert np .abs (image_slice .flatten () - expected_slice ).max () < 1e-2
6766 assert np .abs (image_from_tuple_slice .flatten () - expected_slice ).max () < 1e-2
@@ -83,7 +82,7 @@ def test_inference_predict_sample(self):
8382 image_slice = image [0 , - 3 :, - 3 :, - 1 ]
8483 image_eps_slice = image_eps [0 , - 3 :, - 3 :, - 1 ]
8584
86- assert image .shape == (1 , 32 , 32 , 3 )
85+ assert image .shape == (1 , 8 , 8 , 3 )
8786 tolerance = 1e-2 if torch_device != "mps" else 3e-2
8887 assert np .abs (image_slice .flatten () - image_eps_slice .flatten ()).max () < tolerance
8988
0 commit comments