@@ -42,9 +42,10 @@ class DDIMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
4242 def get_dummy_components (self ):
4343 torch .manual_seed (0 )
4444 unet = UNet2DModel (
45- block_out_channels = (32 , 64 ),
46- layers_per_block = 2 ,
47- sample_size = 32 ,
45+ block_out_channels = (4 , 8 ),
46+ layers_per_block = 1 ,
47+ norm_num_groups = 4 ,
48+ sample_size = 8 ,
4849 in_channels = 3 ,
4950 out_channels = 3 ,
5051 down_block_types = ("DownBlock2D" , "AttnDownBlock2D" ),
@@ -79,10 +80,8 @@ def test_inference(self):
7980 image = pipe (** inputs ).images
8081 image_slice = image [0 , - 3 :, - 3 :, - 1 ]
8182
82- self .assertEqual (image .shape , (1 , 32 , 32 , 3 ))
83- expected_slice = np .array (
84- [1.000e00 , 5.717e-01 , 4.717e-01 , 1.000e00 , 0.000e00 , 1.000e00 , 3.000e-04 , 0.000e00 , 9.000e-04 ]
85- )
83+ self .assertEqual (image .shape , (1 , 8 , 8 , 3 ))
84+ expected_slice = np .array ([0.0 , 9.979e-01 , 0.0 , 9.999e-01 , 9.986e-01 , 9.991e-01 , 7.106e-04 , 0.0 , 0.0 ])
8685 max_diff = np .abs (image_slice .flatten () - expected_slice ).max ()
8786 self .assertLessEqual (max_diff , 1e-3 )
8887
0 commit comments