From 43111cf70785edd555cd3b30f2804f7ec1249a24 Mon Sep 17 00:00:00 2001 From: ethansmith2000 <98723285+ethansmith2000@users.noreply.github.com> Date: Sun, 16 Apr 2023 14:54:21 -0400 Subject: [PATCH] give option to turn off progress bar in jupyter notebooks, if running generations sequentially, the page can get filled with progress bars. would be nice to have an option to disable it. --- .../stable_diffusion/pipeline_stable_diffusion.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 689febe3e891..cdb86b58b5c8 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -51,6 +51,12 @@ >>> image = pipe(prompt).images[0] ``` """ +class NullContext: + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_value, traceback): + pass class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): @@ -540,6 +546,7 @@ def __call__( callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, + use_progress_bar: bool = True, ): r""" Function invoked when calling the pipeline for generation. @@ -666,7 +673,8 @@ def __call__( # 7. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: + state = self.progress_bar(total=num_inference_steps) if use_progress_bar else NullContext() + with state as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents @@ -690,7 +698,8 @@ def __call__( # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() + if use_progress_bar: + progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents)