@@ -66,32 +66,6 @@ def test_fast_inference(self):
66
66
assert np .abs (image_slice .flatten () - expected_slice ).max () < 1e-2
67
67
assert np .abs (image_from_tuple_slice .flatten () - expected_slice ).max () < 1e-2
68
68
69
- def test_full_inference (self ):
70
- device = "cpu"
71
- unet = self .dummy_uncond_unet
72
- scheduler = DDPMScheduler ()
73
-
74
- ddpm = DDPMPipeline (unet = unet , scheduler = scheduler )
75
- ddpm .to (device )
76
- ddpm .set_progress_bar_config (disable = None )
77
-
78
- generator = torch .Generator (device = device ).manual_seed (0 )
79
- image = ddpm (generator = generator , output_type = "numpy" ).images
80
-
81
- generator = torch .Generator (device = device ).manual_seed (0 )
82
- image_from_tuple = ddpm (generator = generator , output_type = "numpy" , return_dict = False )[0 ]
83
-
84
- image_slice = image [0 , - 3 :, - 3 :, - 1 ]
85
- image_from_tuple_slice = image_from_tuple [0 , - 3 :, - 3 :, - 1 ]
86
-
87
- assert image .shape == (1 , 32 , 32 , 3 )
88
- expected_slice = np .array (
89
- [1.0 , 3.495e-02 , 2.939e-01 , 9.821e-01 , 9.448e-01 , 6.261e-03 , 7.998e-01 , 8.9e-01 , 1.122e-02 ]
90
- )
91
-
92
- assert np .abs (image_slice .flatten () - expected_slice ).max () < 1e-2
93
- assert np .abs (image_from_tuple_slice .flatten () - expected_slice ).max () < 1e-2
94
-
95
69
def test_inference_predict_sample (self ):
96
70
unet = self .dummy_uncond_unet
97
71
scheduler = DDPMScheduler (prediction_type = "sample" )
0 commit comments