From b7a6bf5d94a609c0093b4f57fe0fa79194c503e4 Mon Sep 17 00:00:00 2001 From: William Berman Date: Mon, 20 Mar 2023 16:26:37 -0700 Subject: [PATCH] stable diffusion depth batching fix --- .../stable_diffusion/pipeline_stable_diffusion_depth2img.py | 3 ++- .../stable_diffusion_2/test_stable_diffusion_depth.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py index 9087064ae0b8..b66cfe9b437e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -474,7 +474,8 @@ def prepare_depth_map(self, image, depth_map, batch_size, do_classifier_free_gui # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method if depth_map.shape[0] < batch_size: - depth_map = depth_map.repeat(batch_size, 1, 1, 1) + repeat_by = batch_size // depth_map.shape[0] + depth_map = depth_map.repeat(repeat_by, 1, 1, 1) depth_map = torch.cat([depth_map] * 2) if do_classifier_free_guidance else depth_map return depth_map diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py index 110dbbd7f80c..12e7113399a8 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py @@ -64,7 +64,7 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.Te test_save_load_optional_components = False params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"} required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"} - batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS - {"image"} + batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS def get_dummy_components(self): torch.manual_seed(0)