Skip to content

Commit cd8b750

Browse files
authored
speed up attend-and-excite fast tests (#3079)
1 parent 3b641ea commit cd8b750

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def get_dummy_components(self):
4444
torch.manual_seed(0)
4545
unet = UNet2DConditionModel(
4646
block_out_channels=(32, 64),
47-
layers_per_block=2,
47+
layers_per_block=1,
4848
sample_size=32,
4949
in_channels=4,
5050
out_channels=4,
@@ -111,7 +111,7 @@ def get_dummy_inputs(self, device, seed=0):
111111
"prompt": "a cat and a frog",
112112
"token_indices": [2, 5],
113113
"generator": generator,
114-
"num_inference_steps": 2,
114+
"num_inference_steps": 1,
115115
"guidance_scale": 6.0,
116116
"output_type": "numpy",
117117
"max_iter_to_alter": 2,
@@ -132,13 +132,18 @@ def test_inference(self):
132132
image_slice = image[0, -3:, -3:, -1]
133133

134134
self.assertEqual(image.shape, (1, 64, 64, 3))
135-
expected_slice = np.array([0.5743, 0.6081, 0.4975, 0.5021, 0.5441, 0.4699, 0.4988, 0.4841, 0.4851])
135+
expected_slice = np.array(
136+
[0.63905364, 0.62897307, 0.48599017, 0.5133624, 0.5550048, 0.45769516, 0.50326973, 0.5023139, 0.45384496]
137+
)
136138
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
137139
self.assertLessEqual(max_diff, 1e-3)
138140

139141
def test_inference_batch_consistent(self):
140142
# NOTE: Larger batch sizes cause this test to timeout, only test on smaller batches
141-
self._test_inference_batch_consistent(batch_sizes=[2, 4])
143+
self._test_inference_batch_consistent(batch_sizes=[1, 2])
144+
145+
def test_inference_batch_single_identical(self):
146+
self._test_inference_batch_single_identical(batch_size=2)
142147

143148

144149
@require_torch_gpu

0 commit comments

Comments
 (0)