@@ -102,8 +102,12 @@ def test_download_no_safety_checker(self):
102102
103103 pipe_2 = StableDiffusionPipeline .from_pretrained ("hf-internal-testing/tiny-stable-diffusion-torch" )
104104 pipe_2 = pipe_2 .to (torch_device )
105- generator_2 = generator .manual_seed (0 )
106- out_2 = pipe_2 (prompt , num_inference_steps = 2 , generator = generator_2 , output_type = "numpy" ).images
105+ if torch_device == "mps" :
106+ # device type MPS is not supported for torch.Generator() api.
107+ generator = torch .manual_seed (0 )
108+ else :
109+ generator = torch .Generator (device = torch_device ).manual_seed (0 )
110+ out_2 = pipe_2 (prompt , num_inference_steps = 2 , generator = generator , output_type = "numpy" ).images
107111
108112 assert np .max (np .abs (out - out_2 )) < 1e-3
109113
@@ -124,8 +128,14 @@ def test_load_no_safety_checker_explicit_locally(self):
124128 pipe .save_pretrained (tmpdirname )
125129 pipe_2 = StableDiffusionPipeline .from_pretrained (tmpdirname , safety_checker = None )
126130 pipe_2 = pipe_2 .to (torch_device )
127- generator_2 = generator .manual_seed (0 )
128- out_2 = pipe_2 (prompt , num_inference_steps = 2 , generator = generator_2 , output_type = "numpy" ).images
131+
132+ if torch_device == "mps" :
133+ # device type MPS is not supported for torch.Generator() api.
134+ generator = torch .manual_seed (0 )
135+ else :
136+ generator = torch .Generator (device = torch_device ).manual_seed (0 )
137+
138+ out_2 = pipe_2 (prompt , num_inference_steps = 2 , generator = generator , output_type = "numpy" ).images
129139
130140 assert np .max (np .abs (out - out_2 )) < 1e-3
131141
@@ -144,8 +154,14 @@ def test_load_no_safety_checker_default_locally(self):
144154 pipe .save_pretrained (tmpdirname )
145155 pipe_2 = StableDiffusionPipeline .from_pretrained (tmpdirname )
146156 pipe_2 = pipe_2 .to (torch_device )
147- generator_2 = generator .manual_seed (0 )
148- out_2 = pipe_2 (prompt , num_inference_steps = 2 , generator = generator_2 , output_type = "numpy" ).images
157+
158+ if torch_device == "mps" :
159+ # device type MPS is not supported for torch.Generator() api.
160+ generator = torch .manual_seed (0 )
161+ else :
162+ generator = torch .Generator (device = torch_device ).manual_seed (0 )
163+
164+ out_2 = pipe_2 (prompt , num_inference_steps = 2 , generator = generator , output_type = "numpy" ).images
149165
150166 assert np .max (np .abs (out - out_2 )) < 1e-3
151167
0 commit comments