diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py
index 223f8a236efa..66706c806a81 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py
@@ -11,6 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 import inspect
 import warnings
 from typing import Any, Callable, Dict, List, Optional, Union
@@ -21,7 +22,7 @@
 from ...image_processor import VaeImageProcessor
 from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
-from ...schedulers import DDIMScheduler, PNDMScheduler
+from ...schedulers import DDIMScheduler
 from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring
 from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionPipelineOutput
@@ -96,9 +97,6 @@ def __init__(
     ):
         super().__init__()
 
-        if isinstance(scheduler, PNDMScheduler):
-            logger.error("PNDMScheduler for this pipeline is currently not supported.")
-
         if safety_checker is None and requires_safety_checker:
             logger.warning(
                 f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
@@ -612,7 +610,7 @@ def __call__(
 
         # 6. Define panorama grid and initialize views for synthesis.
         views = self.get_views(height, width)
-        blocks_model_outputs = [None] * len(views)
+        views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(views)
         count = torch.zeros_like(latents)
         value = torch.zeros_like(latents)
 
@@ -637,6 +635,9 @@ def __call__(
                     # get the latents corresponding to the current view coordinates
                     latents_for_view = latents[:, :, h_start:h_end, w_start:w_end]
 
+                    # rematch block's scheduler status
+                    self.scheduler.__dict__.update(views_scheduler_status[j])
+
                     # expand the latents if we are doing classifier free guidance
                     latent_model_input = (
                         torch.cat([latents_for_view] * 2) if do_classifier_free_guidance else latents_for_view
@@ -657,21 +658,13 @@ def __call__(
                         noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
                     # compute the previous noisy sample x_t -> x_t-1
-                    if hasattr(self.scheduler, "model_outputs"):
-                        # rematch model_outputs in each block
-                        if i >= 1:
-                            self.scheduler.model_outputs = blocks_model_outputs[j]
-                        latents_view_denoised = self.scheduler.step(
-                            noise_pred, t, latents_for_view, **extra_step_kwargs
-                        ).prev_sample
-                        # collect model_outputs
-                        blocks_model_outputs[j] = [
-                            output if output is not None else None for output in self.scheduler.model_outputs
-                        ]
-                    else:
-                        latents_view_denoised = self.scheduler.step(
-                            noise_pred, t, latents_for_view, **extra_step_kwargs
-                        ).prev_sample
+                    latents_view_denoised = self.scheduler.step(
+                        noise_pred, t, latents_for_view, **extra_step_kwargs
+                    ).prev_sample
+
+                    # save views scheduler status after sample
+                    views_scheduler_status[j] = copy.deepcopy(self.scheduler.__dict__)
+
                     value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
                     count[:, :, h_start:h_end, w_start:w_end] += 1
 
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py
index 02a15b2a29dc..021065416838 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py
@@ -174,15 +174,22 @@ def test_stable_diffusion_panorama_euler(self):
     def test_stable_diffusion_panorama_pndm(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
-        components["scheduler"] = PNDMScheduler()
+        components["scheduler"] = PNDMScheduler(
+            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
+        )
         sd_pipe = StableDiffusionPanoramaPipeline(**components)
         sd_pipe = sd_pipe.to(device)
         sd_pipe.set_progress_bar_config(disable=None)
 
         inputs = self.get_dummy_inputs(device)
-        # the pipeline does not expect pndm so test if it raises error.
-        with self.assertRaises(ValueError):
-            _ = sd_pipe(**inputs).images
+        image = sd_pipe(**inputs).images
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 64, 64, 3)
+
+        expected_slice = np.array([0.6391, 0.6291, 0.4861, 0.5134, 0.5552, 0.4578, 0.5032, 0.5023, 0.4539])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 
 
 @slow
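With this patch the panorama pipeline no longer rejects stateful schedulers: each view keeps its own snapshot of the scheduler state (a deepcopy of self.scheduler.__dict__) that is restored before and saved after every step, so per-view denoising does not corrupt the shared scheduler history. A minimal usage sketch of what this enables follows; the checkpoint id and generation settings are illustrative assumptions, not taken from this PR.

import torch
from diffusers import PNDMScheduler, StableDiffusionPanoramaPipeline

# Assumed checkpoint for illustration; any Stable Diffusion checkpoint should work.
model_ckpt = "stabilityai/stable-diffusion-2-base"
pipe = StableDiffusionPanoramaPipeline.from_pretrained(model_ckpt, torch_dtype=torch.float16)

# Swap in PNDM, which keeps an internal history of model outputs across step() calls.
# skip_prk_steps=True mirrors the configuration used in the updated test.
pipe.scheduler = PNDMScheduler.from_config(pipe.scheduler.config, skip_prk_steps=True)
pipe = pipe.to("cuda")

# The per-view scheduler snapshots keep PNDM's history consistent for each crop.
image = pipe("a photo of the dolomites", height=512, width=2048).images[0]
image.save("panorama.png")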