diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 1be09fdda0d6..2c0d94fcc16b 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -106,6 +106,8 @@
       title: DDIM
     - local: api/pipelines/ddpm
       title: DDPM
+    - local: api/pipelines/dit
+      title: DiT
     - local: api/pipelines/latent_diffusion
       title: Latent Diffusion
    - local: api/pipelines/paint_by_example
diff --git a/docs/source/en/api/pipelines/dit.mdx b/docs/source/en/api/pipelines/dit.mdx
new file mode 100644
index 000000000000..d7ab18e2ea76
--- /dev/null
+++ b/docs/source/en/api/pipelines/dit.mdx
@@ -0,0 +1,59 @@
+
+
+# [Scalable Diffusion Models with Transformers](https://www.wpeebles.com/DiT) (DiT)
+
+## Overview
+
+[Scalable Diffusion Models with Transformers](https://arxiv.org/abs/2212.09748) (DiT) by William Peebles and Saining Xie.
+
+The abstract of the paper is the following:
+
+*We explore a new class of diffusion models based on the transformer architecture. We train latent diffusion models of images, replacing the commonly-used U-Net backbone with a transformer that operates on latent patches. We analyze the scalability of our Diffusion Transformers (DiTs) through the lens of forward pass complexity as measured by Gflops. We find that DiTs with higher Gflops -- through increased transformer depth/width or increased number of input tokens -- consistently have lower FID. In addition to possessing good scalability properties, our largest DiT-XL/2 models outperform all prior diffusion models on the class-conditional ImageNet 512x512 and 256x256 benchmarks, achieving a state-of-the-art FID of 2.27 on the latter.*
+
+The original codebase of this paper can be found here: [facebookresearch/dit](https://github.com/facebookresearch/dit).
+
+## Available Pipelines:
+
+| Pipeline | Tasks | Colab
+|---|---|:---:|
+| [pipeline_dit.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/dit/pipeline_dit.py) | *Conditional Image Generation* | - |
+
+
+## Usage example
+
+```python
+from diffusers import DiTPipeline, DPMSolverMultistepScheduler
+import torch
+
+pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=torch.float16)
+pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+pipe = pipe.to("cuda")
+
+# `pipe.labels` holds all available ImageNet class labels
+
+# pick words that exist in ImageNet
+words = ["white shark", "umbrella"]
+
+class_ids = pipe.get_label_ids(words)
+
+generator = torch.manual_seed(33)
+output = pipe(class_labels=class_ids, num_inference_steps=25, generator=generator)
+
+image = output.images[0]  # label 'white shark'
+```
+
+## DiTPipeline
+[[autodoc]] DiTPipeline
+	- all
+	- __call__
diff --git a/docs/source/en/api/schedulers/overview.mdx b/docs/source/en/api/schedulers/overview.mdx
index 7e139d152b4b..d27fbe10c528 100644
--- a/docs/source/en/api/schedulers/overview.mdx
+++ b/docs/source/en/api/schedulers/overview.mdx
@@ -37,6 +37,7 @@ To this end, the design of schedulers is such that:
 
 - Schedulers can be used interchangeably between diffusion models in inference to find the preferred trade-off between speed and generation quality.
 - Schedulers are currently by default in PyTorch, but are designed to be framework independent (partial Jax support currently exists).
+- Many diffusion pipelines, such as [`StableDiffusionPipeline`] and [`DiTPipeline`], can use any of the [`KarrasDiffusionSchedulers`].
 
 ## Schedulers Summary
@@ -80,4 +81,6 @@ The class [`SchedulerOutput`] contains the outputs from any schedulers `step(...
 
 [[autodoc]] schedulers.scheduling_utils.SchedulerOutput
+
+### KarrasDiffusionSchedulers
+[[autodoc]] schedulers.scheduling_utils.KarrasDiffusionSchedulers
diff --git a/scripts/convert_dit_to_diffusers.py b/scripts/convert_dit_to_diffusers.py
new file mode 100644
index 000000000000..e14b4ad2a7bc
--- /dev/null
+++ b/scripts/convert_dit_to_diffusers.py
@@ -0,0 +1,162 @@
+import argparse
+import os
+
+import torch
+
+from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, Transformer2DModel
+from torchvision.datasets.utils import download_url
+
+
+pretrained_models = {512: "DiT-XL-2-512x512.pt", 256: "DiT-XL-2-256x256.pt"}
+
+
+def download_model(model_name):
+    """
+    Downloads a pre-trained DiT model from the web.
+    """
+    local_path = f"pretrained_models/{model_name}"
+    if not os.path.isfile(local_path):
+        os.makedirs("pretrained_models", exist_ok=True)
+        web_path = f"https://dl.fbaipublicfiles.com/DiT/models/{model_name}"
+        download_url(web_path, "pretrained_models")
+    model = torch.load(local_path, map_location=lambda storage, loc: storage)
+    return model
+
+
+def main(args):
+    state_dict = download_model(pretrained_models[args.image_size])
+
+    state_dict["pos_embed.proj.weight"] = state_dict["x_embedder.proj.weight"]
+    state_dict["pos_embed.proj.bias"] = state_dict["x_embedder.proj.bias"]
+    state_dict.pop("x_embedder.proj.weight")
+    state_dict.pop("x_embedder.proj.bias")
+
+    for depth in range(28):
+        # the timestep and class embedders are shared across blocks in the original
+        # model, so their weights are copied into every block's `norm1` embedding
+        state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_1.weight"] = state_dict[
+            "t_embedder.mlp.0.weight"
+        ]
+        state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_1.bias"] = state_dict[
+            "t_embedder.mlp.0.bias"
+        ]
+        state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_2.weight"] = state_dict[
+            "t_embedder.mlp.2.weight"
+        ]
+        state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_2.bias"] = state_dict[
+            "t_embedder.mlp.2.bias"
+        ]
+        state_dict[f"transformer_blocks.{depth}.norm1.emb.class_embedder.embedding_table.weight"] = state_dict[
+            "y_embedder.embedding_table.weight"
+        ]
+
+        state_dict[f"transformer_blocks.{depth}.norm1.linear.weight"] = state_dict[
+            f"blocks.{depth}.adaLN_modulation.1.weight"
+        ]
+        state_dict[f"transformer_blocks.{depth}.norm1.linear.bias"] = state_dict[
+            f"blocks.{depth}.adaLN_modulation.1.bias"
+        ]
+
+        # the original checkpoint stores self-attention as a single fused qkv projection;
+        # split it into the separate to_q/to_k/to_v projections used by diffusers
+        q, k, v = torch.chunk(state_dict[f"blocks.{depth}.attn.qkv.weight"], 3, dim=0)
+        q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{depth}.attn.qkv.bias"], 3, dim=0)
+
+        state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
+        state_dict[f"transformer_blocks.{depth}.attn1.to_q.bias"] = q_bias
+        state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
+        state_dict[f"transformer_blocks.{depth}.attn1.to_k.bias"] = k_bias
+        state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
+        state_dict[f"transformer_blocks.{depth}.attn1.to_v.bias"] = v_bias
+
+        state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict[
+            f"blocks.{depth}.attn.proj.weight"
+        ]
+        state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict[f"blocks.{depth}.attn.proj.bias"]
+
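+        # the original DiT blocks use timm-style `mlp.fc1`/`mlp.fc2` weights, which
+        # map onto diffusers' `FeedForward` as `ff.net.0.proj`/`ff.net.2`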
+        state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.weight"] = state_dict[f"blocks.{depth}.mlp.fc1.weight"]
+        state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.bias"] = state_dict[f"blocks.{depth}.mlp.fc1.bias"]
+        state_dict[f"transformer_blocks.{depth}.ff.net.2.weight"] = state_dict[f"blocks.{depth}.mlp.fc2.weight"]
+        state_dict[f"transformer_blocks.{depth}.ff.net.2.bias"] = state_dict[f"blocks.{depth}.mlp.fc2.bias"]
+
+        state_dict.pop(f"blocks.{depth}.attn.qkv.weight")
+        state_dict.pop(f"blocks.{depth}.attn.qkv.bias")
+        state_dict.pop(f"blocks.{depth}.attn.proj.weight")
+        state_dict.pop(f"blocks.{depth}.attn.proj.bias")
+        state_dict.pop(f"blocks.{depth}.mlp.fc1.weight")
+        state_dict.pop(f"blocks.{depth}.mlp.fc1.bias")
+        state_dict.pop(f"blocks.{depth}.mlp.fc2.weight")
+        state_dict.pop(f"blocks.{depth}.mlp.fc2.bias")
+        state_dict.pop(f"blocks.{depth}.adaLN_modulation.1.weight")
+        state_dict.pop(f"blocks.{depth}.adaLN_modulation.1.bias")
+
+    state_dict.pop("t_embedder.mlp.0.weight")
+    state_dict.pop("t_embedder.mlp.0.bias")
+    state_dict.pop("t_embedder.mlp.2.weight")
+    state_dict.pop("t_embedder.mlp.2.bias")
+    state_dict.pop("y_embedder.embedding_table.weight")
+
+    state_dict["proj_out_1.weight"] = state_dict["final_layer.adaLN_modulation.1.weight"]
+    state_dict["proj_out_1.bias"] = state_dict["final_layer.adaLN_modulation.1.bias"]
+    state_dict["proj_out_2.weight"] = state_dict["final_layer.linear.weight"]
+    state_dict["proj_out_2.bias"] = state_dict["final_layer.linear.bias"]
+
+    state_dict.pop("final_layer.linear.weight")
+    state_dict.pop("final_layer.linear.bias")
+    state_dict.pop("final_layer.adaLN_modulation.1.weight")
+    state_dict.pop("final_layer.adaLN_modulation.1.bias")
+
+    # DiT XL/2
+    transformer = Transformer2DModel(
+        sample_size=args.image_size // 8,
+        num_layers=28,
+        attention_head_dim=72,
+        in_channels=4,
+        out_channels=8,
+        patch_size=2,
+        attention_bias=True,
+        num_attention_heads=16,
+        activation_fn="gelu-approximate",
+        num_embeds_ada_norm=1000,
+        norm_type="ada_norm_zero",
+        norm_elementwise_affine=False,
+    )
+    transformer.load_state_dict(state_dict, strict=True)
+
+    scheduler = DDIMScheduler(
+        num_train_timesteps=1000,
+        beta_schedule="linear",
+        prediction_type="epsilon",
+        clip_sample=False,
+    )
+
+    vae = AutoencoderKL.from_pretrained(args.vae_model)
+
+    pipeline = DiTPipeline(transformer=transformer, vae=vae, scheduler=scheduler)
+
+    if args.save:
+        pipeline.save_pretrained(args.checkpoint_path)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--image_size",
+        default=256,
+        type=int,
+        required=False,
+        help="Image size of pretrained model, either 256 or 512.",
+    )
+    parser.add_argument(
+        "--vae_model",
+        default="stabilityai/sd-vae-ft-ema",
+        type=str,
+        required=False,
+        help="Path to pretrained VAE model, either stabilityai/sd-vae-ft-mse or stabilityai/sd-vae-ft-ema.",
+    )
+    parser.add_argument(
+        "--save",
+        default=True,
+        # parse boolean-ish strings explicitly; `type=bool` would treat any non-empty string, including "False", as True
+        type=lambda x: str(x).lower() not in ("0", "false", "no"),
+        required=False,
+        help="Whether to save the converted pipeline or not.",
+    )
+    parser.add_argument(
+        "--checkpoint_path", default=None, type=str, required=True, help="Path to the output pipeline."
+    )
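+
+    # example invocation (the output path is illustrative):
+    #   python scripts/convert_dit_to_diffusers.py --image_size 256 --checkpoint_path ./dit-xl-2-256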
+
+    args = parser.parse_args()
+    main(args)
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 8988cdf14e1f..12ba3d270167 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -56,6 +56,7 @@
         DDIMPipeline,
         DDPMPipeline,
         DiffusionPipeline,
+        DiTPipeline,
         ImagePipelineOutput,
         KarrasVePipeline,
         LDMPipeline,
diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py
index 1ef1edc14629..7fc779fc543e 100644
--- a/src/diffusers/dependency_versions_table.py
+++ b/src/diffusers/dependency_versions_table.py
@@ -4,7 +4,7 @@
 deps = {
     "Pillow": "Pillow",
     "accelerate": "accelerate>=0.11.0",
-    "black": "black==22.8",
+    "black": "black==22.12",
     "datasets": "datasets",
     "filelock": "filelock",
     "flake8": "flake8>=3.8.3",
diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 85dcc800fd1e..08263875d0c2 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -20,6 +20,7 @@
 
 from ..utils.import_utils import is_xformers_available
 from .cross_attention import CrossAttention
+from .embeddings import CombinedTimestepLabelEmbeddings
 
 
 if is_xformers_available():
@@ -196,10 +197,21 @@ def __init__(
         attention_bias: bool = False,
         only_cross_attention: bool = False,
         upcast_attention: bool = False,
+        norm_elementwise_affine: bool = True,
+        norm_type: str = "layer_norm",
+        final_dropout: bool = False,
     ):
         super().__init__()
         self.only_cross_attention = only_cross_attention
-        self.use_ada_layer_norm = num_embeds_ada_norm is not None
+
+        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
+        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+
+        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
+            raise ValueError(
+                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
+                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
+            )
 
         # 1. Self-Attn
         self.attn1 = CrossAttention(
@@ -212,7 +224,7 @@ def __init__(
             upcast_attention=upcast_attention,
         )
 
-        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
+        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
 
         # 2. Cross-Attn
         if cross_attention_dim is not None:
@@ -228,15 +240,27 @@ def __init__(
         else:
             self.attn2 = None
 
-        self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+        if self.use_ada_layer_norm:
+            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
+        elif self.use_ada_layer_norm_zero:
+            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
+        else:
+            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
 
         if cross_attention_dim is not None:
-            self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+            # We currently only use AdaLayerNormZero for self attention, where there will only be one attention block,
+            # i.e. the number of modulation chunks returned by AdaLayerNormZero would not make sense if returned during
+            # the second cross attention block.
+            self.norm2 = (
+                AdaLayerNorm(dim, num_embeds_ada_norm)
+                if self.use_ada_layer_norm
+                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+            )
         else:
             self.norm2 = None
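+
+        # note: when `norm_type` is "ada_norm_zero", `norm1` additionally returns the
+        # gate/shift/scale tensors (the adaLN-Zero conditioning from the DiT paper)
+        # that modulate the attention and feed-forward branches in `forward` below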
 
         # 3. Feed-forward
-        self.norm3 = nn.LayerNorm(dim)
+        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
 
     def forward(
         self,
@@ -245,11 +269,18 @@ def forward(
         timestep=None,
         attention_mask=None,
         cross_attention_kwargs=None,
+        class_labels=None,
     ):
+        if self.use_ada_layer_norm:
+            norm_hidden_states = self.norm1(hidden_states, timestep)
+        elif self.use_ada_layer_norm_zero:
+            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+            )
+        else:
+            norm_hidden_states = self.norm1(hidden_states)
+
         # 1. Self-Attention
-        norm_hidden_states = (
-            self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
-        )
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
         attn_output = self.attn1(
             norm_hidden_states,
@@ -257,13 +288,16 @@ def forward(
             attention_mask=attention_mask,
             **cross_attention_kwargs,
         )
+        if self.use_ada_layer_norm_zero:
+            attn_output = gate_msa.unsqueeze(1) * attn_output
         hidden_states = attn_output + hidden_states
 
         if self.attn2 is not None:
-            # 2. Cross-Attention
             norm_hidden_states = (
                 self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
             )
+
+            # 2. Cross-Attention
            attn_output = self.attn2(
                 norm_hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
@@ -273,7 +307,17 @@ def forward(
             hidden_states = attn_output + hidden_states
 
         # 3. Feed-forward
-        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
+        norm_hidden_states = self.norm3(hidden_states)
+
+        if self.use_ada_layer_norm_zero:
+            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+        ff_output = self.ff(norm_hidden_states)
+
+        if self.use_ada_layer_norm_zero:
+            ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+        hidden_states = ff_output + hidden_states
 
         return hidden_states
 
@@ -288,6 +332,7 @@ class FeedForward(nn.Module):
         mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
         dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
         activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+        final_dropout (`bool`, *optional*, defaults to False): Apply a final dropout.
     """
 
     def __init__(
@@ -297,6 +342,7 @@ def __init__(
         mult: int = 4,
         dropout: float = 0.0,
         activation_fn: str = "geglu",
+        final_dropout: bool = False,
     ):
         super().__init__()
         inner_dim = int(dim * mult)
@@ -304,6 +350,8 @@ def __init__(
 
         if activation_fn == "gelu":
             act_fn = GELU(dim, inner_dim)
+        elif activation_fn == "gelu-approximate":
+            act_fn = GELU(dim, inner_dim, approximate="tanh")
         elif activation_fn == "geglu":
             act_fn = GEGLU(dim, inner_dim)
         elif activation_fn == "geglu-approximate":
@@ -316,6 +364,9 @@ def __init__(
         self.net.append(nn.Dropout(dropout))
         # project out
         self.net.append(nn.Linear(inner_dim, dim_out))
+        # FFs as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
+        if final_dropout:
+            self.net.append(nn.Dropout(dropout))
 
     def forward(self, hidden_states):
         for module in self.net:
@@ -325,18 +376,19 @@ def forward(self, hidden_states):
 
 class GELU(nn.Module):
     r"""
-    GELU activation function
+    GELU activation function with tanh approximation support with `approximate="tanh"`.
     """
 
-    def __init__(self, dim_in: int, dim_out: int):
+    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
         super().__init__()
         self.proj = nn.Linear(dim_in, dim_out)
+        self.approximate = approximate
 
     def gelu(self, gate):
         if gate.device.type != "mps":
-            return F.gelu(gate)
+            return F.gelu(gate, approximate=self.approximate)
         # mps: gelu is not implemented for float16
-        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+        return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
 
     def forward(self, hidden_states):
         hidden_states = self.proj(hidden_states)
@@ -344,7 +396,6 @@ def forward(self, hidden_states):
         return hidden_states
 
 
-# feedforward
 class GEGLU(nn.Module):
     r"""
     A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
@@ -402,3 +453,24 @@ def forward(self, x, timestep):
         scale, shift = torch.chunk(emb, 2)
         x = self.norm(x) * (1 + scale) + shift
         return x
+
+
+class AdaLayerNormZero(nn.Module):
+    """
+    Norm layer adaptive layer norm zero (adaLN-Zero).
+    """
+
+    def __init__(self, embedding_dim, num_embeddings):
+        super().__init__()
+
+        self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
+
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
+        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+
+    def forward(self, x, timestep, class_labels, hidden_dtype=None):
+        emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
+        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 0221d891f171..fc6cae43c16b 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -61,6 +61,96 @@ def get_timestep_embedding(
     return emb
 
 
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
+    """
+    grid_size: int of the grid height and width
+    return: pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+    """
+    grid_h = np.arange(grid_size, dtype=np.float32)
+    grid_w = np.arange(grid_size, dtype=np.float32)
+    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = np.stack(grid, axis=0)
+
+    grid = grid.reshape([2, 1, grid_size, grid_size])
+    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+    if cls_token and extra_tokens > 0:
+        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+    return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+    if embed_dim % 2 != 0:
+        raise ValueError("embed_dim must be divisible by 2")
+
+    # use half of dimensions to encode grid_h
+    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+    return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+    """
+    embed_dim: output dimension for each position
+    pos: a list of positions to be encoded: size (M,)
+    out: (M, D)
+    """
+    if embed_dim % 2 != 0:
+        raise ValueError("embed_dim must be divisible by 2")
+
+    omega = np.arange(embed_dim // 2, dtype=np.float64)
+    omega /= embed_dim / 2.0
+    omega = 1.0 / 10000**omega  #
(D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding""" + + def __init__( + self, + height=224, + width=224, + patch_size=16, + in_channels=3, + embed_dim=768, + layer_norm=False, + flatten=True, + bias=True, + ): + super().__init__() + + num_patches = (height // patch_size) * (width // patch_size) + self.flatten = flatten + self.layer_norm = layer_norm + + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias + ) + if layer_norm: + self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) + else: + self.norm = None + + pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5)) + self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False) + + def forward(self, latent): + latent = self.proj(latent) + if self.flatten: + latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC + if self.layer_norm: + latent = self.norm(latent) + return latent + self.pos_embed + + class TimestepEmbedding(nn.Module): def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None): super().__init__() @@ -198,3 +288,58 @@ def forward(self, index): emb = emb + pos_emb[:, : emb.shape[1], :] return emb + + +class LabelEmbedding(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + + Args: + num_classes (`int`): The number of classes. + hidden_size (`int`): The size of the vector embeddings. + dropout_prob (`float`): The probability of dropping a label. + """ + + def __init__(self, num_classes, hidden_size, dropout_prob): + super().__init__() + use_cfg_embedding = dropout_prob > 0 + self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) + self.num_classes = num_classes + self.dropout_prob = dropout_prob + + def token_drop(self, labels, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. 
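+        When `force_drop_ids` is given, entries equal to 1 are dropped regardless of `dropout_prob`.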
+        """
+        if force_drop_ids is None:
+            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
+        else:
+            drop_ids = torch.tensor(force_drop_ids == 1)
+        labels = torch.where(drop_ids, self.num_classes, labels)
+        return labels
+
+    def forward(self, labels, force_drop_ids=None):
+        use_dropout = self.dropout_prob > 0
+        if (self.training and use_dropout) or (force_drop_ids is not None):
+            labels = self.token_drop(labels, force_drop_ids)
+        embeddings = self.embedding_table(labels)
+        return embeddings
+
+
+class CombinedTimestepLabelEmbeddings(nn.Module):
+    def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
+        super().__init__()
+
+        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
+        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+        self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)
+
+    def forward(self, timestep, class_labels, hidden_dtype=None):
+        timesteps_proj = self.time_proj(timestep)
+        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)
+
+        class_labels = self.class_embedder(class_labels)  # (N, D)
+
+        conditioning = timesteps_emb + class_labels  # (N, D)
+
+        return conditioning
diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py
index def4486932b6..57dd424aa4c6 100644
--- a/src/diffusers/models/transformer_2d.py
+++ b/src/diffusers/models/transformer_2d.py
@@ -20,8 +20,9 @@
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..models.embeddings import ImagePositionalEmbeddings
-from ..utils import BaseOutput
+from ..utils import BaseOutput, deprecate
 from .attention import BasicTransformerBlock
+from .embeddings import PatchEmbed
 from .modeling_utils import ModelMixin
 
@@ -81,6 +82,7 @@ def __init__(
         num_attention_heads: int = 16,
         attention_head_dim: int = 88,
         in_channels: Optional[int] = None,
+        out_channels: Optional[int] = None,
         num_layers: int = 1,
         dropout: float = 0.0,
         norm_num_groups: int = 32,
@@ -88,11 +90,14 @@ def __init__(
         attention_bias: bool = False,
         sample_size: Optional[int] = None,
         num_vector_embeds: Optional[int] = None,
+        patch_size: Optional[int] = None,
         activation_fn: str = "geglu",
         num_embeds_ada_norm: Optional[int] = None,
         use_linear_projection: bool = False,
         only_cross_attention: bool = False,
         upcast_attention: bool = False,
+        norm_type: str = "layer_norm",
+        norm_elementwise_affine: bool = True,
     ):
         super().__init__()
         self.use_linear_projection = use_linear_projection
@@ -102,18 +107,35 @@ def __init__(
 
         # 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
         # Define whether input is continuous or discrete depending on configuration
-        self.is_input_continuous = in_channels is not None
+        self.is_input_continuous = (in_channels is not None) and (patch_size is None)
         self.is_input_vectorized = num_vector_embeds is not None
+        self.is_input_patches = in_channels is not None and patch_size is not None
+
+        if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
+            deprecation_message = (
+                f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
+                " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
+                " Please make sure to update the config accordingly, as leaving `norm_type` unset might lead to"
+                " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face"
+                " Hub, it would be very nice if you could open a Pull request for the `transformer/config.json` file"
+            )
+            deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
+            norm_type = "ada_norm"
 
         if self.is_input_continuous and self.is_input_vectorized:
             raise ValueError(
                 f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
                 " sure that either `in_channels` or `num_vector_embeds` is None."
             )
-        elif not self.is_input_continuous and not self.is_input_vectorized:
+        elif self.is_input_vectorized and self.is_input_patches:
+            raise ValueError(
+                f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
+                " sure that either `num_vector_embeds` or `patch_size` is None."
+            )
+        elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
             raise ValueError(
-                f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make"
-                " sure that either `in_channels` or `num_vector_embeds` is not None."
+                f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or"
+                f" `patch_size`: {patch_size}. Make sure that `in_channels`, `num_vector_embeds`, or `patch_size` is"
+                " not None."
             )
 
         # 2. Define input layers
@@ -137,6 +159,20 @@ def __init__(
             self.latent_image_embedding = ImagePositionalEmbeddings(
                 num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
             )
+        elif self.is_input_patches:
+            assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
+
+            self.height = sample_size
+            self.width = sample_size
+
+            self.patch_size = patch_size
+            self.pos_embed = PatchEmbed(
+                height=sample_size,
+                width=sample_size,
+                patch_size=patch_size,
+                in_channels=in_channels,
+                embed_dim=inner_dim,
+            )
 
         # 3. Define transformers blocks
         self.transformer_blocks = nn.ModuleList(
             [
@@ -152,13 +188,17 @@ def __init__(
                     attention_bias=attention_bias,
                     only_cross_attention=only_cross_attention,
                     upcast_attention=upcast_attention,
+                    norm_type=norm_type,
+                    norm_elementwise_affine=norm_elementwise_affine,
                 )
                 for d in range(num_layers)
             ]
         )
 
         # 4. Define output layers
+        self.out_channels = in_channels if out_channels is None else out_channels
         if self.is_input_continuous:
+            # TODO: should use out_channels for continuous projections
             if use_linear_projection:
                 self.proj_out = nn.Linear(in_channels, inner_dim)
             else:
@@ -166,12 +206,17 @@ def __init__(
         elif self.is_input_vectorized:
             self.norm_out = nn.LayerNorm(inner_dim)
             self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
+        elif self.is_input_patches:
+            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
+            self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
+            self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
 
     def forward(
         self,
         hidden_states,
         encoder_hidden_states=None,
         timestep=None,
+        class_labels=None,
         cross_attention_kwargs=None,
         return_dict: bool = True,
     ):
@@ -185,6 +230,9 @@ def forward(
                 self-attention.
             timestep ( `torch.long`, *optional*):
                 Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
+            class_labels ( `torch.LongTensor` of shape `(batch size,)`, *optional*):
+                Optional class labels to be applied as an embedding in AdaLayerNormZero. Used to indicate class-label
+                conditioning.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
 
@@ -195,7 +243,7 @@ def forward(
         """
         # 1. Input
         if self.is_input_continuous:
-            batch, channel, height, width = hidden_states.shape
+            batch, _, height, width = hidden_states.shape
             residual = hidden_states
 
             hidden_states = self.norm(hidden_states)
@@ -209,6 +257,8 @@ def forward(
             hidden_states = self.proj_in(hidden_states)
         elif self.is_input_vectorized:
             hidden_states = self.latent_image_embedding(hidden_states)
+        elif self.is_input_patches:
+            hidden_states = self.pos_embed(hidden_states)
 
         # 2. Blocks
         for block in self.transformer_blocks:
@@ -217,6 +267,7 @@ def forward(
                 encoder_hidden_states=encoder_hidden_states,
                 timestep=timestep,
                 cross_attention_kwargs=cross_attention_kwargs,
+                class_labels=class_labels,
             )
 
         # 3. Output
@@ -237,6 +288,24 @@ def forward(
             # log(p(x_0))
             output = F.log_softmax(logits.double(), dim=1).float()
+        elif self.is_input_patches:
+            # TODO: cleanup!
+            conditioning = self.transformer_blocks[0].norm1.emb(
+                timestep, class_labels, hidden_dtype=hidden_states.dtype
+            )
+            shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
+            hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
+            hidden_states = self.proj_out_2(hidden_states)
+
+            # unpatchify
+            height = width = int(hidden_states.shape[1] ** 0.5)
+            hidden_states = hidden_states.reshape(
+                shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
+            )
+            hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+            output = hidden_states.reshape(
+                shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
+            )
 
         if not return_dict:
             return (output,)
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 5b461ba879c5..f0a6db712345 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -18,6 +18,7 @@
     from .dance_diffusion import DanceDiffusionPipeline
     from .ddim import DDIMPipeline
     from .ddpm import DDPMPipeline
+    from .dit import DiTPipeline
     from .latent_diffusion import LDMSuperResolutionPipeline
     from .latent_diffusion_uncond import LDMPipeline
     from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput
diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
index 5166cbb294c6..6978ab8e28b2 100644
--- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
+++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
@@ -23,14 +23,7 @@
 
 from ...configuration_utils import FrozenDict
 from ...models import AutoencoderKL, UNet2DConditionModel
-from ...schedulers import (
-    DDIMScheduler,
-    DPMSolverMultistepScheduler,
-    EulerAncestralDiscreteScheduler,
-    EulerDiscreteScheduler,
-    LMSDiscreteScheduler,
-    PNDMScheduler,
-)
+from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import deprecate, logging, randn_tensor, replace_example_docstring
 from ..pipeline_utils import DiffusionPipeline
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -91,14 +84,7 @@ def __init__(
         text_encoder: RobertaSeriesModelWithTransformation,
         tokenizer:
XLMRobertaTokenizer, unet: UNet2DConditionModel, - scheduler: Union[ - DDIMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - ], + scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, requires_safety_checker: bool = True, diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index 3bd7b3e75be6..67c1d693ef5d 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -25,14 +25,7 @@ from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import ( - DDIMScheduler, - DPMSolverMultistepScheduler, - EulerAncestralDiscreteScheduler, - EulerDiscreteScheduler, - LMSDiscreteScheduler, - PNDMScheduler, -) +from ...schedulers import KarrasDiffusionSchedulers from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker @@ -129,14 +122,7 @@ def __init__( text_encoder: RobertaSeriesModelWithTransformation, tokenizer: XLMRobertaTokenizer, unet: UNet2DConditionModel, - scheduler: Union[ - DDIMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - ], + scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, requires_safety_checker: bool = True, diff --git a/src/diffusers/pipelines/dit/__init__.py b/src/diffusers/pipelines/dit/__init__.py new file mode 100644 index 000000000000..4ef0729cb490 --- /dev/null +++ b/src/diffusers/pipelines/dit/__init__.py @@ -0,0 +1 @@ +from .pipeline_dit import DiTPipeline diff --git a/src/diffusers/pipelines/dit/pipeline_dit.py b/src/diffusers/pipelines/dit/pipeline_dit.py new file mode 100644 index 000000000000..ea372036f907 --- /dev/null +++ b/src/diffusers/pipelines/dit/pipeline_dit.py @@ -0,0 +1,199 @@ +# Attribution-NonCommercial 4.0 International (CC BY-NC 4.0) +# William Peebles and Saining Xie +# +# Copyright (c) 2021 OpenAI +# MIT License +# +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Optional, Tuple, Union + +import torch + +from ...models import AutoencoderKL, Transformer2DModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +class DiTPipeline(DiffusionPipeline): + r""" + This pipeline inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+    Parameters:
+        transformer ([`Transformer2DModel`]):
+            A class-conditioned `Transformer2DModel` to denoise the encoded image latents.
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+        scheduler ([`DDIMScheduler`]):
+            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+    """
+
+    def __init__(
+        self,
+        transformer: Transformer2DModel,
+        vae: AutoencoderKL,
+        scheduler: KarrasDiffusionSchedulers,
+        id2label: Optional[Dict[int, str]] = None,
+    ):
+        super().__init__()
+        self.register_modules(transformer=transformer, vae=vae, scheduler=scheduler)
+
+        # create an ImageNet label -> id dictionary for easier use
+        self.labels = {}
+        if id2label is not None:
+            for key, value in id2label.items():
+                for label in value.split(","):
+                    self.labels[label.strip()] = int(key)
+            self.labels = dict(sorted(self.labels.items()))
+
+    def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
+        r"""
+
+        Map label strings, *e.g.* from ImageNet, to corresponding class ids.
+
+        Parameters:
+            label (`str` or `list` of `str`): label strings to be mapped to class ids.
+
+        Returns:
+            `list` of `int`: Class ids to be processed by pipeline.
+        """
+
+        if not isinstance(label, list):
+            # wrap a single label string in a list (`list(label)` would split it into characters)
+            label = [label]
+
+        for l in label:
+            if l not in self.labels:
+                raise ValueError(
+                    f"{l} does not exist. Please make sure to select one of the following labels: \n {self.labels}."
+                )
+
+        return [self.labels[l] for l in label]
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        class_labels: List[int],
+        guidance_scale: float = 4.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        num_inference_steps: int = 50,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+    ) -> Union[ImagePipelineOutput, Tuple]:
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            class_labels (List[int]):
+                List of ImageNet class labels for the images to be generated.
+            guidance_scale (`float`, *optional*, defaults to 4.0):
+                Scale of the guidance signal.
+            generator (`torch.Generator`, *optional*):
+                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+                deterministic.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple.
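+
+        Examples:
+
+        A minimal sketch (it assumes the converted `facebook/DiT-XL-2-256` checkpoint from the docs above; any
+        scheduler in [`KarrasDiffusionSchedulers`] can be swapped in via `from_config`):
+
+        ```py
+        >>> import torch
+        >>> from diffusers import DiTPipeline, DPMSolverMultistepScheduler
+
+        >>> pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=torch.float16)
+        >>> pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+        >>> pipe = pipe.to("cuda")
+
+        >>> class_ids = pipe.get_label_ids(["white shark"])
+        >>> image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
+        ```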
+ """ + + batch_size = len(class_labels) + latent_size = self.transformer.config.sample_size + latent_channels = self.transformer.config.in_channels + + latents = randn_tensor( + shape=(batch_size, latent_channels, latent_size, latent_size), + generator=generator, + device=self.device, + dtype=self.transformer.dtype, + ) + latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1 else latents + + class_labels = torch.tensor(class_labels, device=self.device).reshape(-1) + class_null = torch.tensor([1000] * batch_size, device=self.device) + class_labels_input = torch.cat([class_labels, class_null], 0) if guidance_scale > 1 else class_labels + + # set step values + self.scheduler.set_timesteps(num_inference_steps) + + for t in self.progress_bar(self.scheduler.timesteps): + if guidance_scale > 1: + half = latent_model_input[: len(latent_model_input) // 2] + latent_model_input = torch.cat([half, half], dim=0) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + timesteps = t + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = latent_model_input.device.type == "mps" + if isinstance(timesteps, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=latent_model_input.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(latent_model_input.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(latent_model_input.shape[0]) + # predict noise model_output + noise_pred = self.transformer( + latent_model_input, timestep=timesteps, class_labels=class_labels_input + ).sample + + # perform guidance + if guidance_scale > 1: + eps, rest = noise_pred[:, :latent_channels], noise_pred[:, latent_channels:] + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + + half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + + noise_pred = torch.cat([eps, rest], dim=1) + + # learned sigma + if self.transformer.config.out_channels // 2 == latent_channels: + model_output, _ = torch.split(noise_pred, latent_channels, dim=1) + else: + model_output = noise_pred + + # compute previous image: x_t -> x_t-1 + latent_model_input = self.scheduler.step(model_output, t, latent_model_input).prev_sample + + if guidance_scale > 1: + latents, _ = latent_model_input.chunk(2, dim=0) + else: + latents = latent_model_input + + latents = 1 / 0.18215 * latents + samples = self.vae.decode(latents).sample + + samples = (samples / 2 + 0.5).clamp(0, 1) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + samples = samples.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + samples = self.numpy_to_pil(samples) + + if not return_dict: + return (samples,) + + return ImagePipelineOutput(images=samples) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index c3b4b905e0d2..b38ca866d58d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -22,14 +22,7 @@ from ...configuration_utils import 
FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import ( - DDIMScheduler, - DPMSolverMultistepScheduler, - EulerAncestralDiscreteScheduler, - EulerDiscreteScheduler, - LMSDiscreteScheduler, - PNDMScheduler, -) +from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, is_accelerate_available, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput @@ -88,14 +81,7 @@ def __init__( text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: Union[ - DDIMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - ], + scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, requires_safety_checker: bool = True, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py index 7e876f49c68f..fca9cb9e3732 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -25,14 +25,7 @@ from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import ( - DDIMScheduler, - DPMSolverMultistepScheduler, - EulerAncestralDiscreteScheduler, - EulerDiscreteScheduler, - LMSDiscreteScheduler, - PNDMScheduler, -) +from ...schedulers import KarrasDiffusionSchedulers from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput @@ -91,14 +84,7 @@ def __init__( text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: Union[ - DDIMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - ], + scheduler: KarrasDiffusionSchedulers, depth_estimator: DPTForDepthEstimation, feature_extractor: DPTFeatureExtractor, ): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py index e9ca167707bc..37d4a50efc45 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py @@ -23,14 +23,7 @@ from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import ( - DDIMScheduler, - DPMSolverMultistepScheduler, - EulerAncestralDiscreteScheduler, - EulerDiscreteScheduler, - LMSDiscreteScheduler, - PNDMScheduler, -) +from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline from . 
import StableDiffusionPipelineOutput @@ -73,14 +66,7 @@ def __init__( vae: AutoencoderKL, image_encoder: CLIPVisionModelWithProjection, unet: UNet2DConditionModel, - scheduler: Union[ - DDIMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - ], + scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, requires_safety_checker: bool = True, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 2ec26748408f..fceb45e75727 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -24,14 +24,7 @@ from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import ( - DDIMScheduler, - DPMSolverMultistepScheduler, - EulerAncestralDiscreteScheduler, - EulerDiscreteScheduler, - LMSDiscreteScheduler, - PNDMScheduler, -) +from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( PIL_INTERPOLATION, deprecate, @@ -133,14 +126,7 @@ def __init__( text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: Union[ - DDIMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - ], + scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, requires_safety_checker: bool = True, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 1eb710937578..140aa8da2a77 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -24,7 +24,7 @@ from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline from . 
import StableDiffusionPipelineOutput @@ -173,7 +173,7 @@ def __init__( text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, requires_safety_checker: bool = True, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index 1f0be3ac0b28..588682b4ce33 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -24,14 +24,7 @@ from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import ( - DDIMScheduler, - DPMSolverMultistepScheduler, - EulerAncestralDiscreteScheduler, - EulerDiscreteScheduler, - LMSDiscreteScheduler, - PNDMScheduler, -) +from ...schedulers import KarrasDiffusionSchedulers from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput @@ -100,14 +93,7 @@ def __init__( text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: Union[ - DDIMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - ], + scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, requires_safety_checker: bool = True, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index d5eb63ca5db8..af4caa320278 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -22,7 +22,7 @@ from transformers import CLIPTextModel, CLIPTokenizer from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers from ...utils import is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput @@ -84,7 +84,7 @@ def __init__( tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, low_res_scheduler: DDPMScheduler, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + scheduler: KarrasDiffusionSchedulers, max_noise_level: int = 350, ): super().__init__() diff --git a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py index b2bed9d20892..ff4b41a9dc63 100644 --- a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +++ b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py @@ -10,14 +10,7 @@ from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import ( - DDIMScheduler, - DPMSolverMultistepScheduler, - EulerAncestralDiscreteScheduler, - EulerDiscreteScheduler, - 
LMSDiscreteScheduler, - PNDMScheduler, -) +from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionSafePipelineOutput @@ -65,14 +58,7 @@ def __init__( text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: Union[ - DDIMScheduler, - DPMSolverMultistepScheduler, - EulerAncestralDiscreteScheduler, - EulerDiscreteScheduler, - LMSDiscreteScheduler, - PNDMScheduler, - ], + scheduler: KarrasDiffusionSchedulers, safety_checker: SafeStableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, requires_safety_checker: bool = True, diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py index 88e7e4b6a49f..ec8be907bb7a 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py @@ -7,7 +7,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...schedulers import KarrasDiffusionSchedulers from ...utils import logging from ..pipeline_utils import DiffusionPipeline from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline @@ -53,7 +53,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline): image_unet: UNet2DConditionModel text_unet: UNet2DConditionModel vae: AutoencoderKL - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + scheduler: KarrasDiffusionSchedulers def __init__( self, @@ -64,7 +64,7 @@ def __init__( image_unet: UNet2DConditionModel, text_unet: UNet2DConditionModel, vae: AutoencoderKL, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + scheduler: KarrasDiffusionSchedulers, ): super().__init__() diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py index 71bfe56b034d..460244854222 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py @@ -28,7 +28,7 @@ ) from ...models import AutoencoderKL, DualTransformer2DModel, Transformer2DModel, UNet2DConditionModel -from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .modeling_text_unet import UNetFlatConditionModel @@ -62,7 +62,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline): image_unet: UNet2DConditionModel text_unet: UNetFlatConditionModel vae: AutoencoderKL - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + scheduler: KarrasDiffusionSchedulers _optional_components = ["text_unet"] @@ -75,7 +75,7 @@ def __init__( image_unet: UNet2DConditionModel, text_unet: UNetFlatConditionModel, vae: AutoencoderKL, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + scheduler: KarrasDiffusionSchedulers, ): super().__init__() 
self.register_modules( diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py index 56b532010c3a..b08d9bb143a0 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py @@ -23,7 +23,7 @@ from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput @@ -53,7 +53,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline): image_encoder: CLIPVisionModelWithProjection image_unet: UNet2DConditionModel vae: AutoencoderKL - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + scheduler: KarrasDiffusionSchedulers def __init__( self, @@ -61,7 +61,7 @@ def __init__( image_encoder: CLIPVisionModelWithProjection, image_unet: UNet2DConditionModel, vae: AutoencoderKL, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + scheduler: KarrasDiffusionSchedulers, ): super().__init__() self.register_modules( diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py index ac0adf5e7abf..06d8773eaf4a 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py @@ -21,7 +21,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModelWithProjection, CLIPTokenizer from ...models import AutoencoderKL, Transformer2DModel, UNet2DConditionModel -from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .modeling_text_unet import UNetFlatConditionModel @@ -54,7 +54,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): image_unet: UNet2DConditionModel text_unet: UNetFlatConditionModel vae: AutoencoderKL - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + scheduler: KarrasDiffusionSchedulers _optional_components = ["text_unet"] @@ -65,7 +65,7 @@ def __init__( image_unet: UNet2DConditionModel, text_unet: UNetFlatConditionModel, vae: AutoencoderKL, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + scheduler: KarrasDiffusionSchedulers, ): super().__init__() self.register_modules( diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 298bbb9ef4d1..3746acd5b576 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -39,7 +39,7 @@ from .scheduling_sde_ve import ScoreSdeVeScheduler from .scheduling_sde_vp import ScoreSdeVpScheduler from .scheduling_unclip import UnCLIPScheduler - from .scheduling_utils import SchedulerMixin + from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin from .scheduling_vq_diffusion import 
diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py
index 298bbb9ef4d1..3746acd5b576 100644
--- a/src/diffusers/schedulers/__init__.py
+++ b/src/diffusers/schedulers/__init__.py
@@ -39,7 +39,7 @@
     from .scheduling_sde_ve import ScoreSdeVeScheduler
     from .scheduling_sde_vp import ScoreSdeVpScheduler
     from .scheduling_unclip import UnCLIPScheduler
-    from .scheduling_utils import SchedulerMixin
+    from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
     from .scheduling_vq_diffusion import VQDiffusionScheduler

 try:
@@ -55,7 +55,12 @@
     from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
     from .scheduling_pndm_flax import FlaxPNDMScheduler
     from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
-    from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
+    from .scheduling_utils_flax import (
+        FlaxKarrasDiffusionSchedulers,
+        FlaxSchedulerMixin,
+        FlaxSchedulerOutput,
+        broadcast_to_shape_from_left,
+    )

 try:
diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py
index 32a16071c66d..6a9fe29c6299 100644
--- a/src/diffusers/schedulers/scheduling_ddim.py
+++ b/src/diffusers/schedulers/scheduling_ddim.py
@@ -23,8 +23,8 @@
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate, randn_tensor
-from .scheduling_utils import SchedulerMixin
+from ..utils import BaseOutput, deprecate, randn_tensor
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin


 @dataclass
@@ -112,7 +112,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
             https://imagen.research.google/video/paper.pdf)
     """

-    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
     _deprecated_kwargs = ["predict_epsilon"]
     order = 1
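
The `_compatibles` list is now derived from the enum rather than a hand-maintained string constant. Since `Enum.name` is the member's identifier, the comprehension reproduces exactly the class-name strings the deleted constant held. A quick sketch:

```python
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers

# Enum member names are the scheduler class names, so this yields the same
# strings the removed _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS constant listed.
compatible_names = [e.name for e in KarrasDiffusionSchedulers]
print(compatible_names[:3])  # ['DDIMScheduler', 'DDPMScheduler', 'PNDMScheduler']
```
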
""" - _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] _deprecated_kwargs = ["predict_epsilon"] dtype: jnp.dtype diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 12c17fd16948..b58ed8338280 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -22,8 +22,8 @@ import torch from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config -from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate, randn_tensor -from .scheduling_utils import SchedulerMixin +from ..utils import BaseOutput, deprecate, randn_tensor +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin @dataclass @@ -105,7 +105,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): https://imagen.research.google/video/paper.pdf) """ - _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _compatibles = [e.name for e in KarrasDiffusionSchedulers] _deprecated_kwargs = ["predict_epsilon"] order = 1 diff --git a/src/diffusers/schedulers/scheduling_ddpm_flax.py b/src/diffusers/schedulers/scheduling_ddpm_flax.py index ed83ae8df231..8223b340cb7b 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_flax.py +++ b/src/diffusers/schedulers/scheduling_ddpm_flax.py @@ -24,8 +24,8 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import deprecate from .scheduling_utils_flax import ( - _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, CommonSchedulerState, + FlaxKarrasDiffusionSchedulers, FlaxSchedulerMixin, FlaxSchedulerOutput, add_noise_common, @@ -85,7 +85,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): the `dtype` used for params and computation. 
""" - _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] _deprecated_kwargs = ["predict_epsilon"] dtype: jnp.dtype diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index 528bef9e09b3..1ad5480b7878 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -22,8 +22,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): @@ -106,7 +105,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): """ - _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 2a003920ecf4..8acb87d78a4c 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -21,8 +21,8 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, deprecate -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from ..utils import deprecate +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): @@ -117,7 +117,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): """ - _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _compatibles = [e.name for e in KarrasDiffusionSchedulers] _deprecated_kwargs = ["predict_epsilon"] order = 1 diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py index 0aa121b59dec..ed2ed5f5e5a4 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py @@ -24,8 +24,8 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import deprecate from .scheduling_utils_flax import ( - _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, CommonSchedulerState, + FlaxKarrasDiffusionSchedulers, FlaxSchedulerMixin, FlaxSchedulerOutput, add_noise_common, @@ -140,7 +140,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): the `dtype` used for params and computation. 
""" - _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] _deprecated_kwargs = ["predict_epsilon"] dtype: jnp.dtype diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index d016711b59df..0225d8027bc3 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -21,8 +21,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): @@ -116,7 +115,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): """ - _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index 9976235b75f6..45f939aafe70 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -19,8 +19,8 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging, randn_tensor -from .scheduling_utils import SchedulerMixin +from ..utils import BaseOutput, logging, randn_tensor +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -71,7 +71,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): """ - _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 10f277f7e090..02e5c2cd99fd 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -19,8 +19,8 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging, randn_tensor -from .scheduling_utils import SchedulerMixin +from ..utils import BaseOutput, logging, randn_tensor +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -72,7 +72,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): """ - _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py index 4f40a24050b4..0dea944b6fef 100644 --- a/src/diffusers/schedulers/scheduling_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -18,8 +18,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS -from .scheduling_utils import 
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


 class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
@@ -48,7 +47,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
             https://imagen.research.google/video/paper.pdf)
     """

-    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
     order = 2

     @register_to_config
diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
index 370a078704d8..175f338b929e 100644
--- a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
+++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
@@ -18,8 +18,8 @@
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, randn_tensor
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from ..utils import randn_tensor
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


 class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
@@ -49,7 +49,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
             https://imagen.research.google/video/paper.pdf)
     """

-    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
     order = 2

     @register_to_config
diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
index 8aee346c574c..18dd97671636 100644
--- a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
+++ b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
@@ -18,8 +18,7 @@
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


 class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
@@ -49,7 +48,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
             https://imagen.research.google/video/paper.pdf)
     """

-    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
     order = 2

     @register_to_config
diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py
index 28bc9bd0c608..f2c474ffe11c 100644
--- a/src/diffusers/schedulers/scheduling_lms_discrete.py
+++ b/src/diffusers/schedulers/scheduling_lms_discrete.py
@@ -21,8 +21,8 @@
 from scipy import integrate

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
-from .scheduling_utils import SchedulerMixin
+from ..utils import BaseOutput
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin


 @dataclass
@@ -70,7 +70,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
             https://imagen.research.google/video/paper.pdf)
     """

-    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
     order = 1

     @register_to_config
diff --git a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py
index fde18f2653d6..e105ded997d2 100644
--- a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py
+++ b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py
@@ -21,8 +21,8 @@
 from ..configuration_utils import ConfigMixin, register_to_config
 from .scheduling_utils_flax import (
-    _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
     CommonSchedulerState,
+    FlaxKarrasDiffusionSchedulers,
     FlaxSchedulerMixin,
     FlaxSchedulerOutput,
     broadcast_to_shape_from_left,
@@ -82,7 +82,7 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
             the `dtype` used for params and computation.
     """

-    _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]

     dtype: jnp.dtype
diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py
index c3ac5fdf75fc..065a07e955f8 100644
--- a/src/diffusers/schedulers/scheduling_pndm.py
+++ b/src/diffusers/schedulers/scheduling_pndm.py
@@ -21,8 +21,7 @@
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


 def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
@@ -92,7 +91,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
     """

-    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
     order = 1

     @register_to_config
diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py
index 25c0db934617..572da534643b 100644
--- a/src/diffusers/schedulers/scheduling_pndm_flax.py
+++ b/src/diffusers/schedulers/scheduling_pndm_flax.py
@@ -23,8 +23,8 @@
 from ..configuration_utils import ConfigMixin, register_to_config
 from .scheduling_utils_flax import (
-    _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
     CommonSchedulerState,
+    FlaxKarrasDiffusionSchedulers,
     FlaxSchedulerMixin,
     FlaxSchedulerOutput,
     add_noise_common,
@@ -110,7 +110,7 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
             the `dtype` used for params and computation.
""" - _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] dtype: jnp.dtype pndm_order: int diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index 90ab674e38a4..f4103d4d62cc 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -14,6 +14,7 @@ import importlib import os from dataclasses import dataclass +from enum import Enum from typing import Any, Dict, Optional, Union import torch @@ -24,6 +25,21 @@ SCHEDULER_CONFIG_NAME = "scheduler_config.json" +class KarrasDiffusionSchedulers(Enum): + DDIMScheduler = 1 + DDPMScheduler = 2 + PNDMScheduler = 3 + LMSDiscreteScheduler = 4 + EulerDiscreteScheduler = 5 + HeunDiscreteScheduler = 6 + EulerAncestralDiscreteScheduler = 7 + DPMSolverMultistepScheduler = 8 + DPMSolverSinglestepScheduler = 9 + KDPM2DiscreteScheduler = 10 + KDPM2AncestralDiscreteScheduler = 11 + DEISMultistepScheduler = 12 + + @dataclass class SchedulerOutput(BaseOutput): """ diff --git a/src/diffusers/schedulers/scheduling_utils_flax.py b/src/diffusers/schedulers/scheduling_utils_flax.py index 889c0f25bc2b..9708c0883760 100644 --- a/src/diffusers/schedulers/scheduling_utils_flax.py +++ b/src/diffusers/schedulers/scheduling_utils_flax.py @@ -15,16 +15,24 @@ import math import os from dataclasses import dataclass +from enum import Enum from typing import Any, Dict, Optional, Tuple, Union import flax import jax.numpy as jnp -from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput +from ..utils import BaseOutput SCHEDULER_CONFIG_NAME = "scheduler_config.json" -_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = ["Flax" + c for c in _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS] + + +class FlaxKarrasDiffusionSchedulers(Enum): + FlaxDDIMScheduler = 1 + FlaxDDPMScheduler = 2 + FlaxPNDMScheduler = 3 + FlaxLMSDiscreteScheduler = 4 + FlaxDPMSolverMultistepScheduler = 5 @dataclass diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 3d059f3f944e..61b1f2ca8dda 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -19,7 +19,6 @@ from .. 
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index 3d059f3f944e..61b1f2ca8dda 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -19,7 +19,6 @@
 from .. import __version__
 from .constants import (
-    _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
     CONFIG_NAME,
     DIFFUSERS_CACHE,
     DIFFUSERS_DYNAMIC_MODULE_NAME,
diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py
index 35efff392cbd..0edb4c57f076 100644
--- a/src/diffusers/utils/constants.py
+++ b/src/diffusers/utils/constants.py
@@ -30,18 +30,3 @@
 DIFFUSERS_CACHE = default_cache_path
 DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
 HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
-
-_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = [
-    "DDIMScheduler",
-    "DDPMScheduler",
-    "PNDMScheduler",
-    "LMSDiscreteScheduler",
-    "EulerDiscreteScheduler",
-    "HeunDiscreteScheduler",
-    "EulerAncestralDiscreteScheduler",
-    "DPMSolverMultistepScheduler",
-    "DPMSolverSinglestepScheduler",
-    "KDPM2DiscreteScheduler",
-    "KDPM2AncestralDiscreteScheduler",
-    "DEISMultistepScheduler",
-]
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 62c2bbc2732d..1e7c0a46a2b2 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -227,6 +227,21 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch"])


+class DiTPipeline(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class ImagePipelineOutput(metaclass=DummyObject):
     _backends = ["torch"]
diff --git a/tests/pipelines/dit/__init__.py b/tests/pipelines/dit/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
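
The dummy `DiTPipeline` above follows the placeholder pattern diffusers uses for optional backends: the class stays importable even when torch is absent, and any attempt to instantiate or load it raises an informative error instead of failing at import time. A minimal sketch of the same pattern, assuming `DummyObject` and `requires_backends` are importable from `diffusers.utils` as the dummy modules themselves do:

```python
from diffusers.utils import DummyObject, requires_backends


class MyTorchOnlyPipeline(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        # raises a helpful "torch is required" error if torch is missing,
        # rather than an opaque ImportError at module import time
        requires_backends(self, ["torch"])
```
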
diff --git a/tests/pipelines/dit/test_dit.py b/tests/pipelines/dit/test_dit.py
new file mode 100644
index 000000000000..ab41f9751c22
--- /dev/null
+++ b/tests/pipelines/dit/test_dit.py
@@ -0,0 +1,135 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import unittest
+
+import numpy as np
+import torch
+
+from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, DPMSolverMultistepScheduler, Transformer2DModel
+from diffusers.utils import load_numpy, slow
+from diffusers.utils.testing_utils import require_torch_gpu
+
+from ...test_pipelines_common import PipelineTesterMixin
+
+
+torch.backends.cuda.matmul.allow_tf32 = False
+
+
+class DiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+    pipeline_class = DiTPipeline
+    test_cpu_offload = False
+
+    def get_dummy_components(self):
+        torch.manual_seed(0)
+        transformer = Transformer2DModel(
+            sample_size=4,
+            num_layers=2,
+            patch_size=2,
+            attention_head_dim=2,
+            num_attention_heads=2,
+            in_channels=4,
+            out_channels=8,
+            attention_bias=True,
+            activation_fn="gelu-approximate",
+            num_embeds_ada_norm=1000,
+            norm_type="ada_norm_zero",
+            norm_elementwise_affine=False,
+        )
+        vae = AutoencoderKL()
+        scheduler = DDIMScheduler()
+        components = {"transformer": transformer.eval(), "vae": vae.eval(), "scheduler": scheduler}
+        return components
+
+    def get_dummy_inputs(self, device, seed=0):
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator(device=device).manual_seed(seed)
+        inputs = {
+            "class_labels": [1],
+            "generator": generator,
+            "num_inference_steps": 2,
+            "output_type": "numpy",
+        }
+        return inputs
+
+    def test_inference(self):
+        device = "cpu"
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        image = pipe(**inputs).images
+        image_slice = image[0, -3:, -3:, -1]
+
+        self.assertEqual(image.shape, (1, 4, 4, 3))
+        expected_slice = np.array(
+            [0.44405967, 0.33592293, 0.6093237, 0.48981372, 0.79098296, 0.7504172, 0.59413105, 0.49462673, 0.35190058]
+        )
+        max_diff = np.abs(image_slice.flatten() - expected_slice).max()
+        self.assertLessEqual(max_diff, 1e-3)
+
+    def test_inference_batch_single_identical(self):
+        self._test_inference_batch_single_identical(relax_max_difference=True)
+
+
+@require_torch_gpu
+@slow
+class DiTPipelineIntegrationTests(unittest.TestCase):
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def test_dit_256(self):
+        generator = torch.manual_seed(0)
+
+        pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256")
+        pipe.to("cuda")
+
+        words = ["vase", "umbrella", "white shark", "white wolf"]
+        ids = pipe.get_label_ids(words)
+
+        images = pipe(ids, generator=generator, num_inference_steps=40, output_type="np").images
+
+        for word, image in zip(words, images):
+            expected_image = load_numpy(
+                f"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/dit/{word}.npy"
+            )
+            assert np.abs((expected_image - image).sum()) < 1e-3
+
+    def test_dit_512_fp16(self):
+        generator = torch.manual_seed(0)
+
+        pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-512", torch_dtype=torch.float16)
+        pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+        pipe.to("cuda")
+
+        words = ["vase", "umbrella", "white shark", "white wolf"]
+        ids = pipe.get_label_ids(words)
+
+        images = pipe(ids, generator=generator, num_inference_steps=25, output_type="np").images
+
+        for word, image in zip(words, images):
+            expected_image = load_numpy(
+                "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+                f"/dit/{word}_fp16.npy"
+            )
+            assert np.abs((expected_image - image).sum()) < 1e-3
diff --git a/tests/test_pipelines_common.py b/tests/test_pipelines_common.py
index 7babec588805..08f13b89607c 100644
--- a/tests/test_pipelines_common.py
+++ b/tests/test_pipelines_common.py
@@ -36,7 +36,7 @@ class PipelineTesterMixin:
     equivalence of dict and tuple outputs, etc.
     """

-    allowed_required_args = ["source_prompt", "prompt", "image", "mask_image", "example_image"]
+    allowed_required_args = ["source_prompt", "prompt", "image", "mask_image", "example_image", "class_labels"]
     required_optional_params = ["generator", "num_inference_steps", "return_dict"]
     num_inference_steps_args = ["num_inference_steps"]

@@ -194,8 +194,8 @@ def _test_inference_batch_single_identical(
     ):
         if self.pipeline_class.__name__ in ["CycleDiffusionPipeline", "RePaintPipeline"]:
             # RePaint can hardly be made deterministic since the scheduler is currently always
-            # indeterministic
-            # CycleDiffusion is also slighly undeterministic
+            # nondeterministic
+            # CycleDiffusion is also slightly nondeterministic
             return

         if test_max_difference is None:
@@ -515,7 +515,7 @@ def test_cpu_offload_forward_pass(self):
         torch_device != "cuda" or not is_xformers_available(),
         reason="XFormers attention is only available with CUDA and `xformers` installed",
     )
-    def test_xformers_attention_forward_pass(self):
+    def test_xformers_attention_forwardGenerator_pass(self):
         if not self.test_xformers_attention:
             return
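
The `get_dummy_inputs` helper in the test file above branches on the device because `mps` cannot host a device-local `torch.Generator`, so the fast tests fall back to seeding the global CPU generator there. The same pattern, pulled out as a hypothetical standalone helper (the function name is illustrative, not part of this PR):

```python
import torch


def make_generator(device: str, seed: int = 0) -> torch.Generator:
    # mps has no device-local generator support, so seed the global CPU
    # generator there; everywhere else, keep the generator on-device for
    # reproducible per-device sampling
    if str(device).startswith("mps"):
        return torch.manual_seed(seed)
    return torch.Generator(device=device).manual_seed(seed)
```
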