Skip to content

Commit ec6eb48

Browse files
prep test file for changes
1 parent cc34d6b commit ec6eb48

File tree

1 file changed

+22
-18
lines changed

1 file changed

+22
-18
lines changed

tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
)
4242
from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
4343
from diffusers.utils.import_utils import is_accelerate_available
44-
from diffusers.utils.testing_utils import require_torch_gpu
44+
from diffusers.utils.testing_utils import require_torch_gpu, print_tensor_test
4545

4646
from ...test_pipelines_common import PipelineTesterMixin
4747

@@ -284,11 +284,12 @@ def test_stable_diffusion_depth2img_default_case(self):
284284
image = pipe(**inputs).images
285285
image_slice = image[0, -3:, -3:, -1]
286286

287-
assert image.shape == (1, 32, 32, 3)
288-
if torch_device == "mps":
289-
expected_slice = np.array([0.6071, 0.5035, 0.4378, 0.5776, 0.5753, 0.4316, 0.4513, 0.5263, 0.4546])
290-
else:
291-
expected_slice = np.array([0.6374, 0.5039, 0.4199, 0.4819, 0.5563, 0.4617, 0.4028, 0.5381, 0.4711])
287+
# assert image.shape == (1, 32, 32, 3)
288+
# if torch_device == "mps":
289+
# expected_slice = np.array([0.6071, 0.5035, 0.4378, 0.5776, 0.5753, 0.4316, 0.4513, 0.5263, 0.4546])
290+
# else:
291+
print_tensor_test(image_slice)
292+
expected_slice = np.array([0.6374, 0.5039, 0.4199, 0.4819, 0.5563, 0.4617, 0.4028, 0.5381, 0.4711])
292293

293294
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
294295

@@ -306,10 +307,11 @@ def test_stable_diffusion_depth2img_negative_prompt(self):
306307
image_slice = image[0, -3:, -3:, -1]
307308

308309
assert image.shape == (1, 32, 32, 3)
309-
if torch_device == "mps":
310-
expected_slice = np.array([0.5825, 0.5135, 0.4095, 0.5452, 0.6059, 0.4211, 0.3994, 0.5177, 0.4335])
311-
else:
312-
expected_slice = np.array([0.6332, 0.5167, 0.3911, 0.4446, 0.5971, 0.4619, 0.3821, 0.5323, 0.4621])
310+
# if torch_device == "mps":
311+
# expected_slice = np.array([0.5825, 0.5135, 0.4095, 0.5452, 0.6059, 0.4211, 0.3994, 0.5177, 0.4335])
312+
# else:
313+
print_tensor_test(image_slice)
314+
expected_slice = np.array([0.6332, 0.5167, 0.3911, 0.4446, 0.5971, 0.4619, 0.3821, 0.5323, 0.4621])
313315

314316
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
315317

@@ -328,10 +330,11 @@ def test_stable_diffusion_depth2img_multiple_init_images(self):
328330

329331
assert image.shape == (2, 32, 32, 3)
330332

331-
if torch_device == "mps":
332-
expected_slice = np.array([0.6501, 0.5150, 0.4939, 0.6688, 0.5437, 0.5758, 0.5115, 0.4406, 0.4551])
333-
else:
334-
expected_slice = np.array([0.6248, 0.5206, 0.6007, 0.6749, 0.5022, 0.6442, 0.5352, 0.4140, 0.4681])
333+
# if torch_device == "mps":
334+
# expected_slice = np.array([0.6501, 0.5150, 0.4939, 0.6688, 0.5437, 0.5758, 0.5115, 0.4406, 0.4551])
335+
# else:
336+
print_tensor_test(image_slice)
337+
expected_slice = np.array([0.6248, 0.5206, 0.6007, 0.6749, 0.5022, 0.6442, 0.5352, 0.4140, 0.4681])
335338

336339
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
337340

@@ -383,10 +386,11 @@ def test_stable_diffusion_depth2img_pil(self):
383386
image = pipe(**inputs).images
384387
image_slice = image[0, -3:, -3:, -1]
385388

386-
if torch_device == "mps":
387-
expected_slice = np.array([0.53232, 0.47015, 0.40868, 0.45651, 0.4891, 0.4668, 0.4287, 0.48822, 0.47439])
388-
else:
389-
expected_slice = np.array([0.6374, 0.5039, 0.4199, 0.4819, 0.5563, 0.4617, 0.4028, 0.5381, 0.4711])
389+
# if torch_device == "mps":
390+
# expected_slice = np.array([0.53232, 0.47015, 0.40868, 0.45651, 0.4891, 0.4668, 0.4287, 0.48822, 0.47439])
391+
# else:
392+
expected_slice = np.array([0.6374, 0.5039, 0.4199, 0.4819, 0.5563, 0.4617, 0.4028, 0.5381, 0.4711])
393+
print_tensor_test(image_slice)
390394

391395
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
392396

0 commit comments

Comments
 (0)