@@ -20,7 +20,7 @@
 
 from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
 from diffusers.utils import deprecate
-from diffusers.utils.testing_utils import require_torch_gpu, slow, torch_device
+from diffusers.utils.testing_utils import require_torch, slow, torch_device
 
 from ...test_pipelines_common import PipelineTesterMixin
 
@@ -44,18 +44,21 @@ def dummy_uncond_unet(self):
         return model
 
     def test_inference(self):
-        device = "cpu"
         unet = self.dummy_uncond_unet
         scheduler = DDPMScheduler()
 
         ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
-        ddpm.to(device)
+        ddpm.to(torch_device)
         ddpm.set_progress_bar_config(disable=None)
 
-        generator = torch.Generator(device=device).manual_seed(0)
+        # Warmup pass when using mps (see #372)
+        if torch_device == "mps":
+            _ = ddpm(num_inference_steps=1)
+
+        generator = torch.manual_seed(0)
         image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
 
-        generator = torch.Generator(device=device).manual_seed(0)
+        generator = torch.manual_seed(0)
         image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]
 
         image_slice = image[0, -3:, -3:, -1]
@@ -65,8 +68,9 @@ def test_inference(self):
         expected_slice = np.array(
             [5.589e-01, 7.089e-01, 2.632e-01, 6.841e-01, 1.000e-04, 9.999e-01, 1.973e-01, 1.000e-04, 8.010e-02]
         )
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+        tolerance = 1e-2 if torch_device != "mps" else 3e-2
+        assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
+        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance
 
     def test_inference_predict_epsilon(self):
         deprecate("remove this test", "0.10.0", "remove")
@@ -80,16 +84,14 @@ def test_inference_predict_epsilon(self):
         # Warmup pass when using mps (see #372)
         if torch_device == "mps":
             _ = ddpm(num_inference_steps=1)
+
         if torch_device == "mps":
             # device type MPS is not supported for torch.Generator() api.
             generator = torch.manual_seed(0)
         else:
             generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
 
-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            _ = ddpm(num_inference_steps=1)
         generator = generator.manual_seed(0)
         image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0]
@@ -102,7 +104,7 @@ def test_inference_predict_epsilon(self):
 
 
 @slow
-@require_torch_gpu
+@require_torch
 class DDPMPipelineIntegrationTests(unittest.TestCase):
     def test_inference_cifar10(self):
         model_id = "google/ddpm-cifar10-32"
@@ -114,11 +116,11 @@ def test_inference_cifar10(self):
         ddpm.to(torch_device)
         ddpm.set_progress_bar_config(disable=None)
 
-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.manual_seed(0)
         image = ddpm(generator=generator, output_type="numpy").images
 
         image_slice = image[0, -3:, -3:, -1]
 
         assert image.shape == (1, 32, 32, 3)
-        expected_slice = np.array([0.4454, 0.2025, 0.0315, 0.3023, 0.2575, 0.1031, 0.0953, 0.1604, 0.2020])
+        expected_slice = np.array([0.41995, 0.35885, 0.19385, 0.38475, 0.3382, 0.2647, 0.41545, 0.3582, 0.33845])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
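
Note (editor's sketch, not part of the diff): the change above standardizes on `torch.manual_seed(0)`, which seeds and returns the default CPU `torch.Generator`. That generator is accepted by pipelines on every backend, while `torch.Generator(device="mps")` fails because, per the comment kept in the diff, the MPS device type is not supported by the `torch.Generator()` API. A minimal sketch of the resulting pattern, assuming a diffusers version contemporary with this diff; the model id and call arguments are copied from the integration test above:

```python
# Device-agnostic seeded sampling, mirroring the pattern this diff adopts.
import torch

from diffusers import DDPMPipeline
from diffusers.utils.testing_utils import torch_device

ddpm = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")
ddpm.to(torch_device)

# Warmup pass when using mps (see #372): discard one short run so the
# seeded run below produces stable values on first invocation.
if torch_device == "mps":
    _ = ddpm(num_inference_steps=1)

# torch.manual_seed() returns the default CPU generator, so this one line
# works unchanged on cpu, cuda, and mps.
generator = torch.manual_seed(0)
image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
```

The same reasoning explains the other edits: `@require_torch` in place of `@require_torch_gpu` lets the integration test run on any available backend, and the widened `3e-2` tolerance plus the refreshed `expected_slice` account for mps producing slightly different numerics than the CUDA values the old slice was recorded against.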