Commit 187de44

Fix device on save/load tests
1 parent 7d0c272 commit 187de44

1 file changed: +22 lines, -6 lines

tests/test_pipelines.py

Lines changed: 22 additions & 6 deletions

@@ -102,8 +102,12 @@ def test_download_no_safety_checker(self):
 
         pipe_2 = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
         pipe_2 = pipe_2.to(torch_device)
-        generator_2 = generator.manual_seed(0)
-        out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images
+        if torch_device == "mps":
+            # device type MPS is not supported for torch.Generator() api.
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
+        out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
 
         assert np.max(np.abs(out - out_2)) < 1e-3
 
@@ -124,8 +128,14 @@ def test_load_no_safety_checker_explicit_locally(self):
             pipe.save_pretrained(tmpdirname)
             pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname, safety_checker=None)
             pipe_2 = pipe_2.to(torch_device)
-            generator_2 = generator.manual_seed(0)
-            out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images
+
+            if torch_device == "mps":
+                # device type MPS is not supported for torch.Generator() api.
+                generator = torch.manual_seed(0)
+            else:
+                generator = torch.Generator(device=torch_device).manual_seed(0)
+
+            out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
 
         assert np.max(np.abs(out - out_2)) < 1e-3
 
@@ -144,8 +154,14 @@ def test_load_no_safety_checker_default_locally(self):
             pipe.save_pretrained(tmpdirname)
             pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname)
             pipe_2 = pipe_2.to(torch_device)
-            generator_2 = generator.manual_seed(0)
-            out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images
+
+            if torch_device == "mps":
+                # device type MPS is not supported for torch.Generator() api.
+                generator = torch.manual_seed(0)
+            else:
+                generator = torch.Generator(device=torch_device).manual_seed(0)
+
+            out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
 
         assert np.max(np.abs(out - out_2)) < 1e-3
 
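All three hunks apply the same fix: instead of reusing a pre-seeded generator_2, each test builds a fresh seeded generator on torch_device, with a fallback because torch.Generator() does not accept the MPS device. A minimal standalone sketch of that pattern, using "cpu" as a stand-in for the torch_device value the test suite normally provides:

    import torch

    torch_device = "cpu"  # stand-in; the real tests get torch_device from the shared test utilities

    if torch_device == "mps":
        # torch.Generator() does not support the MPS device, so fall back to seeding the
        # global RNG; torch.manual_seed(0) also returns a torch.Generator.
        generator = torch.manual_seed(0)
    else:
        # Otherwise create a generator on the same device the pipeline runs on.
        generator = torch.Generator(device=torch_device).manual_seed(0)

    # The generator is then passed to the pipeline call, e.g.
    # pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy")

Seeding both pipeline runs from the same fixed value is what lets each test compare the original and re-loaded pipelines with assert np.max(np.abs(out - out_2)) < 1e-3.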
