14 changes: 14 additions & 0 deletions src/diffusers/pipeline_utils.py
@@ -23,6 +23,7 @@
 
 from huggingface_hub import snapshot_download
 from PIL import Image
+from tqdm.auto import tqdm
 
 from .configuration_utils import ConfigMixin
 from .utils import DIFFUSERS_CACHE, logging
@@ -266,3 +267,16 @@ def numpy_to_pil(images):
         pil_images = [Image.fromarray(image) for image in images]
 
         return pil_images
+
+    def progress_bar(self, iterable):
+        if not hasattr(self, "_progress_bar_config"):
+            self._progress_bar_config = {}
+        elif not isinstance(self._progress_bar_config, dict):
+            raise ValueError(
+                f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
+            )
+
+        return tqdm(iterable, **self._progress_bar_config)
+
+    def set_progress_bar_config(self, **kwargs):
+        self._progress_bar_config = kwargs
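
Taken together, the two new hooks let callers configure or silence the sampling progress bar: set_progress_bar_config stores arbitrary tqdm keyword arguments on the pipeline, and progress_bar applies them whenever a pipeline wraps its timestep loop. A minimal usage sketch (the checkpoint id is illustrative, not part of this PR):

    from diffusers import DDPMPipeline

    ddpm = DDPMPipeline.from_pretrained("google/ddpm-cat-256")  # illustrative checkpoint

    # any tqdm kwargs pass straight through, e.g. a custom description
    ddpm.set_progress_bar_config(desc="Sampling")
    image = ddpm(output_type="numpy")["sample"]

    # or disable the bar entirely
    ddpm.set_progress_bar_config(disable=True)
    image = ddpm(output_type="numpy")["sample"]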
4 changes: 1 addition & 3 deletions src/diffusers/pipelines/ddim/pipeline_ddim.py
@@ -18,8 +18,6 @@
 
 import torch
 
-from tqdm.auto import tqdm
-
 from ...pipeline_utils import DiffusionPipeline
 
 
@@ -56,7 +54,7 @@ def __call__(self, batch_size=1, generator=None, eta=0.0, num_inference_steps=50
         # set step values
         self.scheduler.set_timesteps(num_inference_steps)
 
-        for t in tqdm(self.scheduler.timesteps):
+        for t in self.progress_bar(self.scheduler.timesteps):
             # 1. predict noise model_output
             model_output = self.unet(image, t)["sample"]
 
4 changes: 1 addition & 3 deletions src/diffusers/pipelines/ddpm/pipeline_ddpm.py
@@ -18,8 +18,6 @@
 
 import torch
 
-from tqdm.auto import tqdm
-
 from ...pipeline_utils import DiffusionPipeline
 
 
@@ -53,7 +51,7 @@ def __call__(self, batch_size=1, generator=None, output_type="pil", **kwargs):
         # set step values
         self.scheduler.set_timesteps(1000)
 
-        for t in tqdm(self.scheduler.timesteps):
+        for t in self.progress_bar(self.scheduler.timesteps):
             # 1. predict noise model_output
             model_output = self.unet(image, t)["sample"]
 
@@ -6,7 +6,6 @@
 import torch.nn as nn
 import torch.utils.checkpoint
 
-from tqdm.auto import tqdm
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_outputs import BaseModelOutput
@@ -83,7 +82,7 @@ def __call__(
         if accepts_eta:
             extra_kwargs["eta"] = eta
 
-        for t in tqdm(self.scheduler.timesteps):
+        for t in self.progress_bar(self.scheduler.timesteps):
             if guidance_scale == 1.0:
                 # guidance_scale of 1 means no guidance
                 latents_input = latents
@@ -3,8 +3,6 @@
 
 import torch
 
-from tqdm.auto import tqdm
-
 from ...pipeline_utils import DiffusionPipeline
 
 
@@ -45,7 +43,7 @@ def __call__(self, batch_size=1, generator=None, eta=0.0, num_inference_steps=50
         if accepts_eta:
             extra_kwargs["eta"] = eta
 
-        for t in tqdm(self.scheduler.timesteps):
+        for t in self.progress_bar(self.scheduler.timesteps):
             # predict the noise residual
             noise_prediction = self.unet(latents, t)["sample"]
             # compute the previous noisy sample x_t -> x_t-1
4 changes: 1 addition & 3 deletions src/diffusers/pipelines/pndm/pipeline_pndm.py
@@ -18,8 +18,6 @@
 
 import torch
 
-from tqdm.auto import tqdm
-
 from ...pipeline_utils import DiffusionPipeline
 
 
@@ -54,7 +52,7 @@ def __call__(self, batch_size=1, generator=None, num_inference_steps=50, output_
         image = image.to(self.device)
 
         self.scheduler.set_timesteps(num_inference_steps)
-        for t in tqdm(self.scheduler.timesteps):
+        for t in self.progress_bar(self.scheduler.timesteps):
             model_output = self.unet(image, t)["sample"]
 
             image = self.scheduler.step(model_output, t, image)["prev_sample"]
@@ -4,7 +4,6 @@
 import torch
 
 from diffusers import DiffusionPipeline
-from tqdm.auto import tqdm
 
 
 class ScoreSdeVePipeline(DiffusionPipeline):
@@ -37,7 +36,7 @@ def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, outpu
         self.scheduler.set_timesteps(num_inference_steps)
         self.scheduler.set_sigmas(num_inference_steps)
 
-        for i, t in tqdm(enumerate(self.scheduler.timesteps)):
+        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
             sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device)
 
             # correction step
@@ -4,7 +4,6 @@
 
 import torch
 
-from tqdm.auto import tqdm
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 from ...models import AutoencoderKL, UNet2DConditionModel
@@ -133,7 +132,7 @@ def __call__(
         if accepts_eta:
             extra_step_kwargs["eta"] = eta
 
-        for i, t in tqdm(enumerate(self.scheduler.timesteps)):
+        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
Contributor Author: As for enumerate and tqdm, if the order is enumerate(tqdm(...)), we don't have to pass total.

Contributor: Agree, that's what it states in the tqdm docs: https://pypi.org/project/tqdm/#faq-and-known-issues
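
A short sketch of why the order matters (standalone, not part of the diff): an enumerate object does not implement __len__, so tqdm cannot infer the total when it wraps enumerate directly, whereas wrapping the sequence first lets tqdm read its length:

    from tqdm.auto import tqdm

    timesteps = list(range(50))

    # total is inferred automatically: the wrapped list has a len()
    for i, t in enumerate(tqdm(timesteps)):
        pass

    # total must be passed explicitly, since enumerate objects have no len()
    for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):
        pass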

             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
             if isinstance(self.scheduler, LMSDiscreteScheduler):
@@ -3,8 +3,6 @@
 
 import torch
 
-from tqdm.auto import tqdm
-
 from ...models import UNet2DModel
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import KarrasVeScheduler
@@ -53,7 +51,7 @@ def __call__(self, batch_size=1, num_inference_steps=50, generator=None, output_
 
         self.scheduler.set_timesteps(num_inference_steps)
 
-        for t in tqdm(self.scheduler.timesteps):
+        for t in self.progress_bar(self.scheduler.timesteps):
             # here sigma_t == t_i from the paper
             sigma = self.scheduler.schedule[t]
             sigma_prev = self.scheduler.schedule[t - 1] if t > 0 else 0
23 changes: 23 additions & 0 deletions tests/test_pipelines.py
@@ -44,6 +44,29 @@
 torch.backends.cuda.matmul.allow_tf32 = False
 
 
+def test_progress_bar(capsys):
+    model = UNet2DModel(
+        block_out_channels=(32, 64),
+        layers_per_block=2,
+        sample_size=32,
+        in_channels=3,
+        out_channels=3,
+        down_block_types=("DownBlock2D", "AttnDownBlock2D"),
+        up_block_types=("AttnUpBlock2D", "UpBlock2D"),
+    )
+    scheduler = DDPMScheduler(num_train_timesteps=10)
+
+    ddpm = DDPMPipeline(model, scheduler).to(torch_device)
+    ddpm(output_type="numpy")["sample"]
+    captured = capsys.readouterr()
Member: Nice test, gonna remember this ;)
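
(Context for the assertions below: tqdm writes its bar to sys.stderr by default, so pytest's capsys fixture surfaces the rendered bar in captured.err rather than captured.out.)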

assert "10/10" in captured.err, "Progress bar has to be displayed"

ddpm.set_progress_bar_config(disable=True)
patrickvonplaten (Contributor), Aug 30, 2022: simple way to turn off the progress bar

+    ddpm(output_type="numpy")["sample"]
+    captured = capsys.readouterr()
+    assert captured.err == "", "Progress bar should be disabled"
+
+
 class PipelineTesterMixin(unittest.TestCase):
     def test_from_pretrained_save_pretrained(self):
         # 1. Load models