Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@
>>> image = pipe(prompt).images[0]
```
"""
class NullContext:
    """A do-nothing context manager.

    Stands in for the pipeline's progress bar when progress reporting is
    disabled, so the denoising loop can always use a ``with`` statement.
    Yields ``None`` on entry and never suppresses exceptions.
    """

    def __enter__(self):
        return None

    def __exit__(self, exc_type, exc_value, traceback):
        return None


class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
Expand Down Expand Up @@ -540,6 +546,7 @@ def __call__(
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
use_progress_bar: bool = True,
):
r"""
Function invoked when calling the pipeline for generation.
Expand Down Expand Up @@ -666,7 +673,8 @@ def __call__(

# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
state = self.progress_bar(total=num_inference_steps) if use_progress_bar else NullContext()
with state as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
Expand All @@ -690,7 +698,8 @@ def __call__(

# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if use_progress_bar:
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

Expand Down