Skip to content

Commit ca1e407

Browse files
stable diffusion depth batching fix (#2757)
1 parent b33bd91 commit ca1e407

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,8 @@ def prepare_depth_map(self, image, depth_map, batch_size, do_classifier_free_gui
474474

475475
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
476476
if depth_map.shape[0] < batch_size:
477-
depth_map = depth_map.repeat(batch_size, 1, 1, 1)
477+
repeat_by = batch_size // depth_map.shape[0]
478+
depth_map = depth_map.repeat(repeat_by, 1, 1, 1)
478479

479480
depth_map = torch.cat([depth_map] * 2) if do_classifier_free_guidance else depth_map
480481
return depth_map

tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.Te
6464
test_save_load_optional_components = False
6565
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
6666
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
67-
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS - {"image"}
67+
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
6868

6969
def get_dummy_components(self):
7070
torch.manual_seed(0)

0 commit comments

Comments
 (0)