
Commit beb59ab

remove ddpm test_full_inference (#2291)
* remove ddpm test_full_inference
* style
1 parent 96c2279 commit beb59ab

File tree

1 file changed: +0 -26 lines changed


tests/pipelines/ddpm/test_ddpm.py

Lines changed: 0 additions & 26 deletions
@@ -66,32 +66,6 @@ def test_fast_inference(self):
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
         assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
 
-    def test_full_inference(self):
-        device = "cpu"
-        unet = self.dummy_uncond_unet
-        scheduler = DDPMScheduler()
-
-        ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
-        ddpm.to(device)
-        ddpm.set_progress_bar_config(disable=None)
-
-        generator = torch.Generator(device=device).manual_seed(0)
-        image = ddpm(generator=generator, output_type="numpy").images
-
-        generator = torch.Generator(device=device).manual_seed(0)
-        image_from_tuple = ddpm(generator=generator, output_type="numpy", return_dict=False)[0]
-
-        image_slice = image[0, -3:, -3:, -1]
-        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
-
-        assert image.shape == (1, 32, 32, 3)
-        expected_slice = np.array(
-            [1.0, 3.495e-02, 2.939e-01, 9.821e-01, 9.448e-01, 6.261e-03, 7.998e-01, 8.9e-01, 1.122e-02]
-        )
-
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
-
     def test_inference_predict_sample(self):
         unet = self.dummy_uncond_unet
         scheduler = DDPMScheduler(prediction_type="sample")
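
For context, the deleted test_full_inference passed no num_inference_steps to DDPMPipeline, so it walked the scheduler's full default 1000-step schedule on CPU, unlike the test_fast_inference test that remains. The snippet below is a minimal standalone sketch of that path, not the repository's test code: the UNet2DModel configuration is an assumption standing in for the test class's dummy_uncond_unet fixture, whose definition is not part of this diff.

    import torch

    from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel

    # Hypothetical small unconditional UNet in place of `dummy_uncond_unet`;
    # the real fixture's config lives elsewhere in the test file.
    unet = UNet2DModel(
        sample_size=32,
        in_channels=3,
        out_channels=3,
        layers_per_block=2,
        block_out_channels=(32, 64),
        down_block_types=("DownBlock2D", "AttnDownBlock2D"),
        up_block_types=("AttnUpBlock2D", "UpBlock2D"),
    )
    scheduler = DDPMScheduler()  # defaults to a 1000-step schedule

    ddpm = DDPMPipeline(unet=unet, scheduler=scheduler).to("cpu")
    ddpm.set_progress_bar_config(disable=None)

    generator = torch.Generator(device="cpu").manual_seed(0)
    # No num_inference_steps is given, so the pipeline runs the full schedule;
    # this is the slow CPU path the removed test exercised.
    image = ddpm(generator=generator, output_type="numpy").images

    assert image.shape == (1, 32, 32, 3)
    print(image[0, -3:, -3:, -1].flatten())  # the 3x3 corner slice the test compared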

0 commit comments