Skip to content

Commit dbf78f5

Browse files
clean up
1 parent 64d1727 commit dbf78f5

File tree

2 files changed

+18
-26
lines changed

2 files changed

+18
-26
lines changed

test_corrections.txt

Lines changed: 0 additions & 4 deletions
This file was deleted.

tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
)
4242
from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
4343
from diffusers.utils.import_utils import is_accelerate_available
44-
from diffusers.utils.testing_utils import require_torch_gpu, print_tensor_test
44+
from diffusers.utils.testing_utils import require_torch_gpu
4545

4646
from ...test_pipelines_common import PipelineTesterMixin
4747

@@ -284,12 +284,11 @@ def test_stable_diffusion_depth2img_default_case(self):
284284
image = pipe(**inputs).images
285285
image_slice = image[0, -3:, -3:, -1]
286286

287-
# assert image.shape == (1, 32, 32, 3)
288-
# if torch_device == "mps":
289-
# expected_slice = np.array([0.6071, 0.5035, 0.4378, 0.5776, 0.5753, 0.4316, 0.4513, 0.5263, 0.4546])
290-
# else:
291-
print_tensor_test(image_slice)
292-
expected_slice = np.array([0.6312, 0.4984, 0.4154, 0.4788, 0.5535, 0.4599, 0.4017, 0.5359, 0.4716])
287+
assert image.shape == (1, 32, 32, 3)
288+
if torch_device == "mps":
289+
expected_slice = np.array([0.6071, 0.5035, 0.4378, 0.5776, 0.5753, 0.4316, 0.4513, 0.5263, 0.4546])
290+
else:
291+
expected_slice = np.array([0.6312, 0.4984, 0.4154, 0.4788, 0.5535, 0.4599, 0.4017, 0.5359, 0.4716])
293292

294293
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
295294

@@ -307,11 +306,10 @@ def test_stable_diffusion_depth2img_negative_prompt(self):
307306
image_slice = image[0, -3:, -3:, -1]
308307

309308
assert image.shape == (1, 32, 32, 3)
310-
# if torch_device == "mps":
311-
# expected_slice = np.array([0.5825, 0.5135, 0.4095, 0.5452, 0.6059, 0.4211, 0.3994, 0.5177, 0.4335])
312-
# else:
313-
print_tensor_test(image_slice)
314-
expected_slice = np.array([0.6296, 0.5125, 0.3890, 0.4456, 0.5955, 0.4621, 0.3810, 0.5310, 0.4626])
309+
if torch_device == "mps":
310+
expected_slice = np.array([0.5825, 0.5135, 0.4095, 0.5452, 0.6059, 0.4211, 0.3994, 0.5177, 0.4335])
311+
else:
312+
expected_slice = np.array([0.6296, 0.5125, 0.3890, 0.4456, 0.5955, 0.4621, 0.3810, 0.5310, 0.4626])
315313

316314
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
317315

@@ -330,11 +328,10 @@ def test_stable_diffusion_depth2img_multiple_init_images(self):
330328

331329
assert image.shape == (2, 32, 32, 3)
332330

333-
# if torch_device == "mps":
334-
# expected_slice = np.array([0.6501, 0.5150, 0.4939, 0.6688, 0.5437, 0.5758, 0.5115, 0.4406, 0.4551])
335-
# else:
336-
print_tensor_test(image_slice)
337-
expected_slice = np.array([0.6267, 0.5232, 0.6001, 0.6738, 0.5029, 0.6429, 0.5364, 0.4159, 0.4674])
331+
if torch_device == "mps":
332+
expected_slice = np.array([0.6501, 0.5150, 0.4939, 0.6688, 0.5437, 0.5758, 0.5115, 0.4406, 0.4551])
333+
else:
334+
expected_slice = np.array([0.6267, 0.5232, 0.6001, 0.6738, 0.5029, 0.6429, 0.5364, 0.4159, 0.4674])
338335

339336
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
340337

@@ -386,11 +383,10 @@ def test_stable_diffusion_depth2img_pil(self):
386383
image = pipe(**inputs).images
387384
image_slice = image[0, -3:, -3:, -1]
388385

389-
# if torch_device == "mps":
390-
# expected_slice = np.array([0.53232, 0.47015, 0.40868, 0.45651, 0.4891, 0.4668, 0.4287, 0.48822, 0.47439])
391-
# else:
392-
expected_slice = np.array([0.6312, 0.4984, 0.4154, 0.4788, 0.5535, 0.4599, 0.4017, 0.5359, 0.4716])
393-
print_tensor_test(image_slice)
386+
if torch_device == "mps":
387+
expected_slice = np.array([0.53232, 0.47015, 0.40868, 0.45651, 0.4891, 0.4668, 0.4287, 0.48822, 0.47439])
388+
else:
389+
expected_slice = np.array([0.6312, 0.4984, 0.4154, 0.4788, 0.5535, 0.4599, 0.4017, 0.5359, 0.4716])
394390

395391
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
396392

0 commit comments

Comments
 (0)