From 9533ee156d5a5addf39a580a8f77e3c403100674 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 18 Jan 2024 19:56:16 +0530 Subject: [PATCH 01/13] =?UTF-8?q?move=20unets=20to=20=20module=20?= =?UTF-8?q?=F0=9F=A6=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/source/en/api/models/unet-motion.md | 2 +- docs/source/en/api/models/unet.md | 2 +- docs/source/en/api/models/unet2d-cond.md | 6 ++-- docs/source/en/api/models/unet2d.md | 2 +- docs/source/en/api/models/unet3d-cond.md | 2 +- .../pipeline_animatediff_controlnet.py | 2 +- .../stable_diffusion_controlnet_reference.py | 2 +- .../community/stable_diffusion_reference.py | 2 +- .../stable_diffusion_xl_reference.py | 2 +- .../controlnetxs/controlnetxs.py | 4 +-- scripts/convert_amused.py | 2 +- scripts/convert_consistency_decoder.py | 2 +- src/diffusers/__init__.py | 2 +- .../experimental/rl/value_guided_sampling.py | 2 +- src/diffusers/models/__init__.py | 21 +++++++------ .../models/autoencoders/autoencoder_kl.py | 10 +++---- .../autoencoder_kl_temporal_decoder.py | 6 ++-- .../autoencoders/consistency_decoder_vae.py | 8 ++--- src/diffusers/models/autoencoders/vae.py | 2 +- src/diffusers/models/controlnet.py | 18 +++++++---- src/diffusers/models/controlnet_flax.py | 2 +- src/diffusers/models/prior_transformer.py | 6 ++-- src/diffusers/models/unets/__init__.py | 9 ++++++ src/diffusers/models/{ => unets}/unet_1d.py | 8 ++--- .../models/{ => unets}/unet_1d_blocks.py | 4 +-- src/diffusers/models/{ => unets}/unet_2d.py | 8 ++--- .../models/{ => unets}/unet_2d_blocks.py | 16 +++++----- .../models/{ => unets}/unet_2d_blocks_flax.py | 4 +-- .../models/{ => unets}/unet_2d_condition.py | 14 ++++----- .../{ => unets}/unet_2d_condition_flax.py | 8 ++--- .../models/{ => unets}/unet_3d_blocks.py | 14 ++++----- .../models/{ => unets}/unet_3d_condition.py | 30 +++++++++---------- .../models/{ => unets}/unet_kandinsky3.py | 10 +++---- .../models/{ => unets}/unet_motion_model.py | 28 ++++++++--------- .../unet_spatio_temporal_condition.py | 14 ++++----- src/diffusers/models/{ => unets}/uvit_2d.py | 22 +++++++------- .../animatediff/pipeline_animatediff.py | 2 +- .../pipelines/audioldm2/modeling_audioldm2.py | 14 ++++----- .../versatile_diffusion/modeling_text_unet.py | 14 ++++----- .../wuerstchen/modeling_wuerstchen_prior.py | 6 ++-- tests/models/test_unet_2d_blocks.py | 2 +- .../controlnet/test_controlnet_sdxl.py | 2 +- 42 files changed, 177 insertions(+), 159 deletions(-) create mode 100644 src/diffusers/models/unets/__init__.py rename src/diffusers/models/{ => unets}/unet_1d.py (97%) rename src/diffusers/models/{ => unets}/unet_1d_blocks.py (99%) rename src/diffusers/models/{ => unets}/unet_2d.py (98%) rename src/diffusers/models/{ => unets}/unet_2d_blocks.py (99%) rename src/diffusers/models/{ => unets}/unet_2d_blocks_flax.py (99%) rename src/diffusers/models/{ => unets}/unet_2d_condition.py (99%) rename src/diffusers/models/{ => unets}/unet_2d_condition_flax.py (98%) rename src/diffusers/models/{ => unets}/unet_3d_blocks.py (99%) rename src/diffusers/models/{ => unets}/unet_3d_condition.py (96%) rename src/diffusers/models/{ => unets}/unet_kandinsky3.py (98%) rename src/diffusers/models/{ => unets}/unet_motion_model.py (97%) rename src/diffusers/models/{ => unets}/unet_spatio_temporal_condition.py (97%) rename src/diffusers/models/{ => unets}/uvit_2d.py (95%) diff --git a/docs/source/en/api/models/unet-motion.md b/docs/source/en/api/models/unet-motion.md index cbc8c30ff64f..af967924dfb3 100644 --- a/docs/source/en/api/models/unet-motion.md +++ b/docs/source/en/api/models/unet-motion.md @@ -22,4 +22,4 @@ The abstract from the paper is: [[autodoc]] UNetMotionModel ## UNet3DConditionOutput -[[autodoc]] models.unet_3d_condition.UNet3DConditionOutput +[[autodoc]] models.unets.unet_3d_condition.UNet3DConditionOutput diff --git a/docs/source/en/api/models/unet.md b/docs/source/en/api/models/unet.md index 66508b469a60..7e6324952b28 100644 --- a/docs/source/en/api/models/unet.md +++ b/docs/source/en/api/models/unet.md @@ -22,4 +22,4 @@ The abstract from the paper is: [[autodoc]] UNet1DModel ## UNet1DOutput -[[autodoc]] models.unet_1d.UNet1DOutput +[[autodoc]] models.unets.unet_1d.UNet1DOutput diff --git a/docs/source/en/api/models/unet2d-cond.md b/docs/source/en/api/models/unet2d-cond.md index ea385ff92426..ec9dbae8f25e 100644 --- a/docs/source/en/api/models/unet2d-cond.md +++ b/docs/source/en/api/models/unet2d-cond.md @@ -22,10 +22,10 @@ The abstract from the paper is: [[autodoc]] UNet2DConditionModel ## UNet2DConditionOutput -[[autodoc]] models.unet_2d_condition.UNet2DConditionOutput +[[autodoc]] models.unets.unet_2d_condition.UNet2DConditionOutput ## FlaxUNet2DConditionModel -[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionModel +[[autodoc]] models.unets.unet_2d_condition_flax.FlaxUNet2DConditionModel ## FlaxUNet2DConditionOutput -[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionOutput +[[autodoc]] models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput diff --git a/docs/source/en/api/models/unet2d.md b/docs/source/en/api/models/unet2d.md index 7669d4a5d75a..d317d14ce744 100644 --- a/docs/source/en/api/models/unet2d.md +++ b/docs/source/en/api/models/unet2d.md @@ -22,4 +22,4 @@ The abstract from the paper is: [[autodoc]] UNet2DModel ## UNet2DOutput -[[autodoc]] models.unet_2d.UNet2DOutput +[[autodoc]] models.unets.unet_2d.UNet2DOutput diff --git a/docs/source/en/api/models/unet3d-cond.md b/docs/source/en/api/models/unet3d-cond.md index 4eea0a6d1cd2..1dc01234dabe 100644 --- a/docs/source/en/api/models/unet3d-cond.md +++ b/docs/source/en/api/models/unet3d-cond.md @@ -22,4 +22,4 @@ The abstract from the paper is: [[autodoc]] UNet3DConditionModel ## UNet3DConditionOutput -[[autodoc]] models.unet_3d_condition.UNet3DConditionOutput +[[autodoc]] models.unets.unet_3d_condition.UNet3DConditionOutput diff --git a/examples/community/pipeline_animatediff_controlnet.py b/examples/community/pipeline_animatediff_controlnet.py index cf0c66bb50d0..b700a6c86b93 100644 --- a/examples/community/pipeline_animatediff_controlnet.py +++ b/examples/community/pipeline_animatediff_controlnet.py @@ -26,7 +26,7 @@ from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel, UNetMotionModel from diffusers.models.lora import adjust_lora_scale_text_encoder -from diffusers.models.unet_motion_model import MotionAdapter +from diffusers.models.unets.unet_motion_model import MotionAdapter from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.schedulers import ( diff --git a/examples/community/stable_diffusion_controlnet_reference.py b/examples/community/stable_diffusion_controlnet_reference.py index 358fc1c6dc67..16f7f589b70b 100644 --- a/examples/community/stable_diffusion_controlnet_reference.py +++ b/examples/community/stable_diffusion_controlnet_reference.py @@ -8,7 +8,7 @@ from diffusers import StableDiffusionControlNetPipeline from diffusers.models import ControlNetModel from diffusers.models.attention import BasicTransformerBlock -from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D +from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.utils import logging diff --git a/examples/community/stable_diffusion_reference.py b/examples/community/stable_diffusion_reference.py index 505470574a0b..88a7febae650 100644 --- a/examples/community/stable_diffusion_reference.py +++ b/examples/community/stable_diffusion_reference.py @@ -7,7 +7,7 @@ from diffusers import StableDiffusionPipeline from diffusers.models.attention import BasicTransformerBlock -from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D +from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg from diffusers.utils import PIL_INTERPOLATION, logging diff --git a/examples/community/stable_diffusion_xl_reference.py b/examples/community/stable_diffusion_xl_reference.py index 5d2b1c771128..fbfb6bdd6160 100644 --- a/examples/community/stable_diffusion_xl_reference.py +++ b/examples/community/stable_diffusion_xl_reference.py @@ -8,7 +8,7 @@ from diffusers import StableDiffusionXLPipeline from diffusers.models.attention import BasicTransformerBlock -from diffusers.models.unet_2d_blocks import ( +from diffusers.models.unets.unet_2d_blocks import ( CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, diff --git a/examples/research_projects/controlnetxs/controlnetxs.py b/examples/research_projects/controlnetxs/controlnetxs.py index 20c8d0fdf0f1..027a853764f8 100644 --- a/examples/research_projects/controlnetxs/controlnetxs.py +++ b/examples/research_projects/controlnetxs/controlnetxs.py @@ -26,7 +26,7 @@ from diffusers.models.autoencoders import AutoencoderKL from diffusers.models.lora import LoRACompatibleConv from diffusers.models.modeling_utils import ModelMixin -from diffusers.models.unet_2d_blocks import ( +from diffusers.models.unets.unet_2d_blocks import ( CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, @@ -36,7 +36,7 @@ UpBlock2D, Upsample2D, ) -from diffusers.models.unet_2d_condition import UNet2DConditionModel +from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from diffusers.utils import BaseOutput, logging diff --git a/scripts/convert_amused.py b/scripts/convert_amused.py index fdddbef7cd65..21be29dfdb99 100644 --- a/scripts/convert_amused.py +++ b/scripts/convert_amused.py @@ -10,7 +10,7 @@ from diffusers import VQModel from diffusers.models.attention_processor import AttnProcessor -from diffusers.models.uvit_2d import UVit2DModel +from diffusers.models.unets.uvit_2d import UVit2DModel from diffusers.pipelines.amused.pipeline_amused import AmusedPipeline from diffusers.schedulers import AmusedScheduler diff --git a/scripts/convert_consistency_decoder.py b/scripts/convert_consistency_decoder.py index 3319f4c4665e..0cb5fc50dd60 100644 --- a/scripts/convert_consistency_decoder.py +++ b/scripts/convert_consistency_decoder.py @@ -14,7 +14,7 @@ from diffusers import AutoencoderKL, ConsistencyDecoderVAE, DiffusionPipeline, StableDiffusionPipeline, UNet2DModel from diffusers.models.autoencoders.vae import Encoder from diffusers.models.embeddings import TimestepEmbedding -from diffusers.models.unet_2d_blocks import ResnetDownsampleBlock2D, ResnetUpsampleBlock2D, UNetMidBlock2D +from diffusers.models.unets.unet_2d_blocks import ResnetDownsampleBlock2D, ResnetUpsampleBlock2D, UNetMidBlock2D args = ArgumentParser() diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 00b660a6d4ac..8637bbda9e1e 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -709,7 +709,7 @@ else: from .models.controlnet_flax import FlaxControlNetModel from .models.modeling_flax_utils import FlaxModelMixin - from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel + from .models.unets.unet_2d_condition_flax import FlaxUNet2DConditionModel from .models.vae_flax import FlaxAutoencoderKL from .pipelines import FlaxDiffusionPipeline from .schedulers import ( diff --git a/src/diffusers/experimental/rl/value_guided_sampling.py b/src/diffusers/experimental/rl/value_guided_sampling.py index af5ee2102163..da78f3b55605 100644 --- a/src/diffusers/experimental/rl/value_guided_sampling.py +++ b/src/diffusers/experimental/rl/value_guided_sampling.py @@ -16,7 +16,7 @@ import torch import tqdm -from ...models.unet_1d import UNet1DModel +from ...models.unets.unet_1d import UNet1DModel from ...pipelines import DiffusionPipeline from ...utils.dummy_pt_objects import DDPMScheduler from ...utils.torch_utils import randn_tensor diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 36dbe14c5053..62b1f5f1f995 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -73,19 +73,22 @@ from .t5_film_transformer import T5FilmDecoder from .transformer_2d import Transformer2DModel from .transformer_temporal import TransformerTemporalModel - from .unet_1d import UNet1DModel - from .unet_2d import UNet2DModel - from .unet_2d_condition import UNet2DConditionModel - from .unet_3d_condition import UNet3DConditionModel - from .unet_kandinsky3 import Kandinsky3UNet - from .unet_motion_model import MotionAdapter, UNetMotionModel - from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel - from .uvit_2d import UVit2DModel + from .unets import ( + Kandinsky3UNet, + MotionAdapter, + UNet1DModel, + UNet2DConditionModel, + UNet2DModel, + UNet3DConditionModel, + UNetMotionModel, + UNetSpatioTemporalConditionModel, + UVit2DModel, + ) from .vq_model import VQModel if is_flax_available(): from .controlnet_flax import FlaxControlNetModel - from .unet_2d_condition_flax import FlaxUNet2DConditionModel + from .unets import FlaxUNet2DConditionModel from .vae_flax import FlaxAutoencoderKL else: diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py index 10a3ae58de9f..a0b23b896d13 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl.py @@ -157,7 +157,7 @@ def disable_slicing(self): self.use_slicing = False @property - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: @@ -181,7 +181,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. @@ -216,7 +216,7 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. @@ -448,7 +448,7 @@ def forward( return DecoderOutput(sample=dec) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections def fuse_qkv_projections(self): """ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, @@ -472,7 +472,7 @@ def fuse_qkv_projections(self): if isinstance(module, Attention): module.fuse_projections(fuse=True) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections def unfuse_qkv_projections(self): """Disables the fused QKV projection if enabled. diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py b/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py index dbafb4571d4a..ab4b16a1931c 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py @@ -23,7 +23,7 @@ from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin -from ..unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder +from ..unets.unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder @@ -242,7 +242,7 @@ def _set_gradient_checkpointing(self, module, value=False): module.gradient_checkpointing = value @property - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: @@ -266,7 +266,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. diff --git a/src/diffusers/models/autoencoders/consistency_decoder_vae.py b/src/diffusers/models/autoencoders/consistency_decoder_vae.py index ca670fec4b28..0013521f4cbb 100644 --- a/src/diffusers/models/autoencoders/consistency_decoder_vae.py +++ b/src/diffusers/models/autoencoders/consistency_decoder_vae.py @@ -31,7 +31,7 @@ AttnProcessor, ) from ..modeling_utils import ModelMixin -from ..unet_2d import UNet2DModel +from ..unets.unet_2d import UNet2DModel from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder @@ -187,7 +187,7 @@ def disable_slicing(self): self.use_slicing = False @property - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: @@ -211,7 +211,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. @@ -246,7 +246,7 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. diff --git a/src/diffusers/models/autoencoders/vae.py b/src/diffusers/models/autoencoders/vae.py index 3f1643bc50ef..3c56f15117ba 100644 --- a/src/diffusers/models/autoencoders/vae.py +++ b/src/diffusers/models/autoencoders/vae.py @@ -22,7 +22,7 @@ from ...utils.torch_utils import randn_tensor from ..activations import get_activation from ..attention_processor import SpatialNorm -from ..unet_2d_blocks import ( +from ..unets.unet_2d_blocks import ( AutoencoderTinyBlock, UNetMidBlock2D, get_down_block, diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index 1102f4f9d36d..65dc0513907b 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -30,8 +30,14 @@ ) from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin -from .unet_2d_blocks import CrossAttnDownBlock2D, DownBlock2D, UNetMidBlock2D, UNetMidBlock2DCrossAttn, get_down_block -from .unet_2d_condition import UNet2DConditionModel +from .unets.unet_2d_blocks import ( + CrossAttnDownBlock2D, + DownBlock2D, + UNetMidBlock2D, + UNetMidBlock2DCrossAttn, + get_down_block, +) +from .unets.unet_2d_condition import UNet2DConditionModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -509,7 +515,7 @@ def from_unet( return controlnet @property - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: @@ -533,7 +539,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. @@ -568,7 +574,7 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. @@ -584,7 +590,7 @@ def set_default_attn_processor(self): self.set_attn_processor(processor) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None: r""" Enable sliced attention computation. diff --git a/src/diffusers/models/controlnet_flax.py b/src/diffusers/models/controlnet_flax.py index 34aaac549f8c..5ab2597a4e79 100644 --- a/src/diffusers/models/controlnet_flax.py +++ b/src/diffusers/models/controlnet_flax.py @@ -23,7 +23,7 @@ from ..utils import BaseOutput from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps from .modeling_flax_utils import FlaxModelMixin -from .unet_2d_blocks_flax import ( +from .unets.unet_2d_blocks_flax import ( FlaxCrossAttnDownBlock2D, FlaxDownBlock2D, FlaxUNetMidBlock2DCrossAttn, diff --git a/src/diffusers/models/prior_transformer.py b/src/diffusers/models/prior_transformer.py index 6b52ea344d41..081d66991faf 100644 --- a/src/diffusers/models/prior_transformer.py +++ b/src/diffusers/models/prior_transformer.py @@ -167,7 +167,7 @@ def __init__( self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim)) @property - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: @@ -191,7 +191,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. @@ -226,7 +226,7 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. diff --git a/src/diffusers/models/unets/__init__.py b/src/diffusers/models/unets/__init__.py new file mode 100644 index 000000000000..44c480178d12 --- /dev/null +++ b/src/diffusers/models/unets/__init__.py @@ -0,0 +1,9 @@ +from .unet_1d import UNet1DModel +from .unet_2d import UNet2DModel +from .unet_2d_condition import UNet2DConditionModel +from .unet_2d_condition_flax import FlaxUNet2DConditionModel +from .unet_3d_condition import UNet3DConditionModel +from .unet_kandinsky3 import Kandinsky3UNet +from .unet_motion_model import MotionAdapter, UNetMotionModel +from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel +from .uvit_2d import UVit2DModel diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unets/unet_1d.py similarity index 97% rename from src/diffusers/models/unet_1d.py rename to src/diffusers/models/unets/unet_1d.py index 5bb5b0818245..131f05f735cd 100644 --- a/src/diffusers/models/unet_1d.py +++ b/src/diffusers/models/unets/unet_1d.py @@ -18,10 +18,10 @@ import torch import torch.nn as nn -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput -from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps -from .modeling_utils import ModelMixin +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput +from ..embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block diff --git a/src/diffusers/models/unet_1d_blocks.py b/src/diffusers/models/unets/unet_1d_blocks.py similarity index 99% rename from src/diffusers/models/unet_1d_blocks.py rename to src/diffusers/models/unets/unet_1d_blocks.py index 74a2f1681ead..3e128bf727c0 100644 --- a/src/diffusers/models/unet_1d_blocks.py +++ b/src/diffusers/models/unets/unet_1d_blocks.py @@ -18,8 +18,8 @@ import torch.nn.functional as F from torch import nn -from .activations import get_activation -from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims +from ..activations import get_activation +from ..resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims class DownResnetBlock1D(nn.Module): diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unets/unet_2d.py similarity index 98% rename from src/diffusers/models/unet_2d.py rename to src/diffusers/models/unets/unet_2d.py index 0531d8aae783..0a4ede51a7fd 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unets/unet_2d.py @@ -17,10 +17,10 @@ import torch import torch.nn as nn -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput -from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps -from .modeling_utils import ModelMixin +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput +from ..embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py similarity index 99% rename from src/diffusers/models/unet_2d_blocks.py rename to src/diffusers/models/unets/unet_2d_blocks.py index 470a021165ac..d933691d89d3 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unets/unet_2d_blocks.py @@ -18,13 +18,13 @@ import torch.nn.functional as F from torch import nn -from ..utils import is_torch_version, logging -from ..utils.torch_utils import apply_freeu -from .activations import get_activation -from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 -from .dual_transformer_2d import DualTransformer2DModel -from .normalization import AdaGroupNorm -from .resnet import ( +from ...utils import is_torch_version, logging +from ...utils.torch_utils import apply_freeu +from ..activations import get_activation +from ..attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 +from ..dual_transformer_2d import DualTransformer2DModel +from ..normalization import AdaGroupNorm +from ..resnet import ( Downsample2D, FirDownsample2D, FirUpsample2D, @@ -34,7 +34,7 @@ ResnetBlockCondNorm2D, Upsample2D, ) -from .transformer_2d import Transformer2DModel +from ..transformer_2d import Transformer2DModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/unet_2d_blocks_flax.py b/src/diffusers/models/unets/unet_2d_blocks_flax.py similarity index 99% rename from src/diffusers/models/unet_2d_blocks_flax.py rename to src/diffusers/models/unets/unet_2d_blocks_flax.py index 8cf2f8eb24b4..447efcd8c138 100644 --- a/src/diffusers/models/unet_2d_blocks_flax.py +++ b/src/diffusers/models/unets/unet_2d_blocks_flax.py @@ -15,8 +15,8 @@ import flax.linen as nn import jax.numpy as jnp -from .attention_flax import FlaxTransformer2DModel -from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D +from ..attention_flax import FlaxTransformer2DModel +from ..resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D class FlaxCrossAttnDownBlock2D(nn.Module): diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py similarity index 99% rename from src/diffusers/models/unet_2d_condition.py rename to src/diffusers/models/unets/unet_2d_condition.py index 7b4f9f5594ea..d8926dbf2273 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -18,11 +18,11 @@ import torch.nn as nn import torch.utils.checkpoint -from ..configuration_utils import ConfigMixin, register_to_config -from ..loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin -from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers -from .activations import get_activation -from .attention_processor import ( +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin +from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ..activations import get_activation +from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, Attention, @@ -30,7 +30,7 @@ AttnAddedKVProcessor, AttnProcessor, ) -from .embeddings import ( +from ..embeddings import ( GaussianFourierProjection, GLIGENTextBoundingboxProjection, ImageHintTimeEmbedding, @@ -42,7 +42,7 @@ TimestepEmbedding, Timesteps, ) -from .modeling_utils import ModelMixin +from ..modeling_utils import ModelMixin from .unet_2d_blocks import ( UNetMidBlock2D, UNetMidBlock2DCrossAttn, diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unets/unet_2d_condition_flax.py similarity index 98% rename from src/diffusers/models/unet_2d_condition_flax.py rename to src/diffusers/models/unets/unet_2d_condition_flax.py index 13f53e16e7ac..5997568361b0 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unets/unet_2d_condition_flax.py @@ -19,10 +19,10 @@ import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict -from ..configuration_utils import ConfigMixin, flax_register_to_config -from ..utils import BaseOutput -from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps -from .modeling_flax_utils import FlaxModelMixin +from ...configuration_utils import ConfigMixin, flax_register_to_config +from ...utils import BaseOutput +from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps +from ..modeling_flax_utils import FlaxModelMixin from .unet_2d_blocks_flax import ( FlaxCrossAttnDownBlock2D, FlaxCrossAttnUpBlock2D, diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unets/unet_3d_blocks.py similarity index 99% rename from src/diffusers/models/unet_3d_blocks.py rename to src/diffusers/models/unets/unet_3d_blocks.py index e9c505c347b0..6c20b1175349 100644 --- a/src/diffusers/models/unet_3d_blocks.py +++ b/src/diffusers/models/unets/unet_3d_blocks.py @@ -17,19 +17,19 @@ import torch from torch import nn -from ..utils import is_torch_version -from ..utils.torch_utils import apply_freeu -from .attention import Attention -from .dual_transformer_2d import DualTransformer2DModel -from .resnet import ( +from ...utils import is_torch_version +from ...utils.torch_utils import apply_freeu +from ..attention import Attention +from ..dual_transformer_2d import DualTransformer2DModel +from ..resnet import ( Downsample2D, ResnetBlock2D, SpatioTemporalResBlock, TemporalConvLayer, Upsample2D, ) -from .transformer_2d import Transformer2DModel -from .transformer_temporal import ( +from ..transformer_2d import Transformer2DModel +from ..transformer_temporal import ( TransformerSpatioTemporalModel, TransformerTemporalModel, ) diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unets/unet_3d_condition.py similarity index 96% rename from src/diffusers/models/unet_3d_condition.py rename to src/diffusers/models/unets/unet_3d_condition.py index fc8695e064b5..b29e2c270ba9 100644 --- a/src/diffusers/models/unet_3d_condition.py +++ b/src/diffusers/models/unets/unet_3d_condition.py @@ -20,20 +20,20 @@ import torch.nn as nn import torch.utils.checkpoint -from ..configuration_utils import ConfigMixin, register_to_config -from ..loaders import UNet2DConditionLoadersMixin -from ..utils import BaseOutput, deprecate, logging -from .activations import get_activation -from .attention_processor import ( +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import UNet2DConditionLoadersMixin +from ...utils import BaseOutput, deprecate, logging +from ..activations import get_activation +from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnAddedKVProcessor, AttnProcessor, ) -from .embeddings import TimestepEmbedding, Timesteps -from .modeling_utils import ModelMixin -from .transformer_temporal import TransformerTemporalModel +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin +from ..transformer_temporal import TransformerTemporalModel from .unet_3d_blocks import ( CrossAttnDownBlock3D, CrossAttnUpBlock3D, @@ -284,7 +284,7 @@ def __init__( ) @property - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: @@ -308,7 +308,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None: r""" Enable sliced attention computation. @@ -374,7 +374,7 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. @@ -449,7 +449,7 @@ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int for module in self.children(): fn_recursive_feed_forward(module, None, 0) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. @@ -469,7 +469,7 @@ def _set_gradient_checkpointing(self, module, value: bool = False) -> None: if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): module.gradient_checkpointing = value - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.enable_freeu + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu def enable_freeu(self, s1, s2, b1, b2): r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. @@ -494,7 +494,7 @@ def enable_freeu(self, s1, s2, b1, b2): setattr(upsample_block, "b1", b1) setattr(upsample_block, "b2", b2) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.disable_freeu + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu def disable_freeu(self): """Disables the FreeU mechanism.""" freeu_keys = {"s1", "s2", "b1", "b2"} @@ -503,7 +503,7 @@ def disable_freeu(self): if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: setattr(upsample_block, k, None) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unload_lora + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unload_lora def unload_lora(self): """Unloads LoRA weights.""" deprecate( diff --git a/src/diffusers/models/unet_kandinsky3.py b/src/diffusers/models/unets/unet_kandinsky3.py similarity index 98% rename from src/diffusers/models/unet_kandinsky3.py rename to src/diffusers/models/unets/unet_kandinsky3.py index eef3287e5d99..b52aace419f0 100644 --- a/src/diffusers/models/unet_kandinsky3.py +++ b/src/diffusers/models/unets/unet_kandinsky3.py @@ -19,11 +19,11 @@ import torch.utils.checkpoint from torch import nn -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, logging -from .attention_processor import Attention, AttentionProcessor, AttnProcessor -from .embeddings import TimestepEmbedding, Timesteps -from .modeling_utils import ModelMixin +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput, logging +from ..attention_processor import Attention, AttentionProcessor, AttnProcessor +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py similarity index 97% rename from src/diffusers/models/unet_motion_model.py rename to src/diffusers/models/unets/unet_motion_model.py index b5f0302b4a43..9654ae508215 100644 --- a/src/diffusers/models/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -17,19 +17,19 @@ import torch.nn as nn import torch.utils.checkpoint -from ..configuration_utils import ConfigMixin, register_to_config -from ..loaders import UNet2DConditionLoadersMixin -from ..utils import logging -from .attention_processor import ( +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import UNet2DConditionLoadersMixin +from ...utils import logging +from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnAddedKVProcessor, AttnProcessor, ) -from .embeddings import TimestepEmbedding, Timesteps -from .modeling_utils import ModelMixin -from .transformer_temporal import TransformerTemporalModel +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin +from ..transformer_temporal import TransformerTemporalModel from .unet_2d_blocks import UNetMidBlock2DCrossAttn from .unet_2d_condition import UNet2DConditionModel from .unet_3d_blocks import ( @@ -524,7 +524,7 @@ def save_motion_modules( ) @property - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: @@ -548,7 +548,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. @@ -583,7 +583,7 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) - # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: """ Sets the attention processor to use [feed forward @@ -613,7 +613,7 @@ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int for module in self.children(): fn_recursive_feed_forward(module, chunk_size, dim) - # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking def disable_forward_chunking(self) -> None: def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): if hasattr(module, "set_chunk_feed_forward"): @@ -625,7 +625,7 @@ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int for module in self.children(): fn_recursive_feed_forward(module, None, 0) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self) -> None: """ Disables custom attention processors and sets the default attention implementation. @@ -645,7 +645,7 @@ def _set_gradient_checkpointing(self, module, value: bool = False) -> None: if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)): module.gradient_checkpointing = value - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.enable_freeu + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu def enable_freeu(self, s1: float, s2: float, b1: float, b2: float) -> None: r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. @@ -670,7 +670,7 @@ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float) -> None: setattr(upsample_block, "b1", b1) setattr(upsample_block, "b2", b2) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.disable_freeu + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu def disable_freeu(self) -> None: """Disables the FreeU mechanism.""" freeu_keys = {"s1", "s2", "b1", "b2"} diff --git a/src/diffusers/models/unet_spatio_temporal_condition.py b/src/diffusers/models/unets/unet_spatio_temporal_condition.py similarity index 97% rename from src/diffusers/models/unet_spatio_temporal_condition.py rename to src/diffusers/models/unets/unet_spatio_temporal_condition.py index 8d0d3e61d879..39a8009d5af9 100644 --- a/src/diffusers/models/unet_spatio_temporal_condition.py +++ b/src/diffusers/models/unets/unet_spatio_temporal_condition.py @@ -4,12 +4,12 @@ import torch import torch.nn as nn -from ..configuration_utils import ConfigMixin, register_to_config -from ..loaders import UNet2DConditionLoadersMixin -from ..utils import BaseOutput, logging -from .attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor -from .embeddings import TimestepEmbedding, Timesteps -from .modeling_utils import ModelMixin +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import UNet2DConditionLoadersMixin +from ...utils import BaseOutput, logging +from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin from .unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block @@ -323,7 +323,7 @@ def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value - # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: """ Sets the attention processor to use [feed forward diff --git a/src/diffusers/models/uvit_2d.py b/src/diffusers/models/unets/uvit_2d.py similarity index 95% rename from src/diffusers/models/uvit_2d.py rename to src/diffusers/models/unets/uvit_2d.py index c0e224562cf2..492c41e4cad4 100644 --- a/src/diffusers/models/uvit_2d.py +++ b/src/diffusers/models/unets/uvit_2d.py @@ -20,20 +20,20 @@ from torch import nn from torch.utils.checkpoint import checkpoint -from ..configuration_utils import ConfigMixin, register_to_config -from ..loaders import PeftAdapterMixin -from .attention import BasicTransformerBlock, SkipFFTransformerBlock -from .attention_processor import ( +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin +from ..attention import BasicTransformerBlock, SkipFFTransformerBlock +from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnAddedKVProcessor, AttnProcessor, ) -from .embeddings import TimestepEmbedding, get_timestep_embedding -from .modeling_utils import ModelMixin -from .normalization import GlobalResponseNorm, RMSNorm -from .resnet import Downsample2D, Upsample2D +from ..embeddings import TimestepEmbedding, get_timestep_embedding +from ..modeling_utils import ModelMixin +from ..normalization import GlobalResponseNorm, RMSNorm +from ..resnet import Downsample2D, Upsample2D class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): @@ -213,7 +213,7 @@ def layer_(*args): return logits @property - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: @@ -237,7 +237,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. @@ -272,7 +272,7 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 0fb4637dab7f..89b3231be762 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -26,7 +26,7 @@ from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel from ...models.lora import adjust_lora_scale_text_encoder -from ...models.unet_motion_model import MotionAdapter +from ...models.unets.unet_motion_model import MotionAdapter from ...schedulers import ( DDIMScheduler, DPMSolverMultistepScheduler, diff --git a/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py b/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py index d39b2c99ddd0..f2d2d48f3520 100644 --- a/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py +++ b/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py @@ -36,8 +36,8 @@ from ...models.modeling_utils import ModelMixin from ...models.resnet import Downsample2D, ResnetBlock2D, Upsample2D from ...models.transformer_2d import Transformer2DModel -from ...models.unet_2d_blocks import DownBlock2D, UpBlock2D -from ...models.unet_2d_condition import UNet2DConditionOutput +from ...models.unets.unet_2d_blocks import DownBlock2D, UpBlock2D +from ...models.unets.unet_2d_condition import UNet2DConditionOutput from ...utils import BaseOutput, is_torch_version, logging @@ -513,7 +513,7 @@ def __init__( ) @property - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: @@ -537,7 +537,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. @@ -572,7 +572,7 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. @@ -588,7 +588,7 @@ def set_default_attn_processor(self): self.set_attn_processor(processor) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice def set_attention_slice(self, slice_size): r""" Enable sliced attention computation. @@ -654,7 +654,7 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel._set_gradient_checkpointing + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel._set_gradient_checkpointing def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py index 504dfa40bac8..eb6ab761ed37 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -33,7 +33,7 @@ ) from ....models.resnet import ResnetBlockCondNorm2D from ....models.transformer_2d import Transformer2DModel -from ....models.unet_2d_condition import UNet2DConditionOutput +from ....models.unets.unet_2d_condition import UNet2DConditionOutput from ....utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ....utils.torch_utils import apply_freeu @@ -268,7 +268,7 @@ def forward( return objs -# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel with UNet2DConditionModel->UNetFlatConditionModel, nn.Conv2d->LinearMultiDim, Block2D->BlockFlat +# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel with UNet2DConditionModel->UNetFlatConditionModel, nn.Conv2d->LinearMultiDim, Block2D->BlockFlat class UNetFlatConditionModel(ModelMixin, ConfigMixin): r""" A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample @@ -1786,7 +1786,7 @@ def custom_forward(*inputs): return hidden_states, output_states -# Copied from diffusers.models.unet_2d_blocks.UpBlock2D with UpBlock2D->UpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim +# Copied from diffusers.models.unets.unet_2d_blocks.UpBlock2D with UpBlock2D->UpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim class UpBlockFlat(nn.Module): def __init__( self, @@ -1897,7 +1897,7 @@ def custom_forward(*inputs): return hidden_states -# Copied from diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D with CrossAttnUpBlock2D->CrossAttnUpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim +# Copied from diffusers.models.unets.unet_2d_blocks.CrossAttnUpBlock2D with CrossAttnUpBlock2D->CrossAttnUpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim class CrossAttnUpBlockFlat(nn.Module): def __init__( self, @@ -2071,7 +2071,7 @@ def custom_forward(*inputs): return hidden_states -# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2D with UNetMidBlock2D->UNetMidBlockFlat, ResnetBlock2D->ResnetBlockFlat +# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2D with UNetMidBlock2D->UNetMidBlockFlat, ResnetBlock2D->ResnetBlockFlat class UNetMidBlockFlat(nn.Module): """ A 2D UNet mid-block [`UNetMidBlockFlat`] with multiple residual blocks and optional attention blocks. @@ -2227,7 +2227,7 @@ def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTe return hidden_states -# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DCrossAttn with UNetMidBlock2DCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat +# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn with UNetMidBlock2DCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat class UNetMidBlockFlatCrossAttn(nn.Module): def __init__( self, @@ -2374,7 +2374,7 @@ def custom_forward(*inputs): return hidden_states -# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn with UNetMidBlock2DSimpleCrossAttn->UNetMidBlockFlatSimpleCrossAttn, ResnetBlock2D->ResnetBlockFlat +# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn with UNetMidBlock2DSimpleCrossAttn->UNetMidBlockFlatSimpleCrossAttn, ResnetBlock2D->ResnetBlockFlat class UNetMidBlockFlatSimpleCrossAttn(nn.Module): def __init__( self, diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py index 8b494fa32476..c752cba606a4 100644 --- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py @@ -66,7 +66,7 @@ def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dro self.set_default_attn_processor() @property - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: @@ -90,7 +90,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. @@ -125,7 +125,7 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. diff --git a/tests/models/test_unet_2d_blocks.py b/tests/models/test_unet_2d_blocks.py index d714b9384860..ef77df8abdfb 100644 --- a/tests/models/test_unet_2d_blocks.py +++ b/tests/models/test_unet_2d_blocks.py @@ -14,7 +14,7 @@ # limitations under the License. import unittest -from diffusers.models.unet_2d_blocks import * # noqa F403 +from diffusers.models.unets.unet_2d_blocks import * # noqa F403 from diffusers.utils.testing_utils import torch_device from .test_unet_blocks_common import UNetBlockTesterMixin diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py index ba129e763c22..88cf254ff6e0 100644 --- a/tests/pipelines/controlnet/test_controlnet_sdxl.py +++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py @@ -28,7 +28,7 @@ StableDiffusionXLControlNetPipeline, UNet2DConditionModel, ) -from diffusers.models.unet_2d_blocks import UNetMidBlock2D +from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetModel from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import enable_full_determinism, load_image, require_torch_gpu, slow, torch_device From 492b9b9681671add82a4c39093c4eccbc6054d92 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 18 Jan 2024 20:00:21 +0530 Subject: [PATCH 02/13] parameterize unet-level import. --- src/diffusers/models/unets/__init__.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/unets/__init__.py b/src/diffusers/models/unets/__init__.py index 44c480178d12..5b1418a608f0 100644 --- a/src/diffusers/models/unets/__init__.py +++ b/src/diffusers/models/unets/__init__.py @@ -1,9 +1,16 @@ -from .unet_1d import UNet1DModel -from .unet_2d import UNet2DModel -from .unet_2d_condition import UNet2DConditionModel -from .unet_2d_condition_flax import FlaxUNet2DConditionModel -from .unet_3d_condition import UNet3DConditionModel -from .unet_kandinsky3 import Kandinsky3UNet -from .unet_motion_model import MotionAdapter, UNetMotionModel -from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel -from .uvit_2d import UVit2DModel +from ...utils import is_flax_available, is_torch_available + + +if is_torch_available(): + from .unet_1d import UNet1DModel + from .unet_2d import UNet2DModel + from .unet_2d_condition import UNet2DConditionModel + from .unet_3d_condition import UNet3DConditionModel + from .unet_kandinsky3 import Kandinsky3UNet + from .unet_motion_model import MotionAdapter, UNetMotionModel + from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel + from .uvit_2d import UVit2DModel + + +if is_flax_available(): + from .unet_2d_condition_flax import FlaxUNet2DConditionModel From 9bec7753b99e92b9f5fc74dfe73dbad2429711f0 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 18 Jan 2024 20:04:19 +0530 Subject: [PATCH 03/13] fix flax unet2dcondition model import --- src/diffusers/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 8637bbda9e1e..af61d4252717 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -381,7 +381,7 @@ else: _import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"] _import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"] - _import_structure["models.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"] + _import_structure["models.unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"] _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"] _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"]) _import_structure["schedulers"].extend( From 9496d72aa034f2ddcb26382d59243f4619404a69 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 18 Jan 2024 20:13:07 +0530 Subject: [PATCH 04/13] models __init__ --- src/diffusers/models/__init__.py | 18 +++++++++--------- src/diffusers/models/controlnet_flax.py | 6 +++--- src/diffusers/models/dual_transformer_2d.py | 2 +- src/diffusers/models/transformer_2d.py | 2 +- src/diffusers/models/transformer_temporal.py | 2 +- .../models/unets/unet_2d_condition.py | 6 +++--- .../models/unets/unet_2d_condition_flax.py | 6 +++--- .../pipelines/audioldm2/modeling_audioldm2.py | 6 +++--- .../versatile_diffusion/modeling_text_unet.py | 6 +++--- .../pipelines/unidiffuser/modeling_uvit.py | 2 +- 10 files changed, 28 insertions(+), 28 deletions(-) diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 62b1f5f1f995..02c94ddbf1de 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -39,19 +39,19 @@ _import_structure["t5_film_transformer"] = ["T5FilmDecoder"] _import_structure["transformer_2d"] = ["Transformer2DModel"] _import_structure["transformer_temporal"] = ["TransformerTemporalModel"] - _import_structure["unet_1d"] = ["UNet1DModel"] - _import_structure["unet_2d"] = ["UNet2DModel"] - _import_structure["unet_2d_condition"] = ["UNet2DConditionModel"] - _import_structure["unet_3d_condition"] = ["UNet3DConditionModel"] - _import_structure["unet_kandinsky3"] = ["Kandinsky3UNet"] - _import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"] - _import_structure["unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"] - _import_structure["uvit_2d"] = ["UVit2DModel"] + _import_structure["unets.unet_1d"] = ["UNet1DModel"] + _import_structure["unets.unet_2d"] = ["UNet2DModel"] + _import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"] + _import_structure["unets.unet_3d_condition"] = ["UNet3DConditionModel"] + _import_structure["unets.unet_kandinsky3"] = ["Kandinsky3UNet"] + _import_structure["unets.unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"] + _import_structure["unets.unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"] + _import_structure["unets.uvit_2d"] = ["UVit2DModel"] _import_structure["vq_model"] = ["VQModel"] if is_flax_available(): _import_structure["controlnet_flax"] = ["FlaxControlNetModel"] - _import_structure["unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"] + _import_structure["unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"] _import_structure["vae_flax"] = ["FlaxAutoencoderKL"] diff --git a/src/diffusers/models/controlnet_flax.py b/src/diffusers/models/controlnet_flax.py index 5ab2597a4e79..1a140cfb94d3 100644 --- a/src/diffusers/models/controlnet_flax.py +++ b/src/diffusers/models/controlnet_flax.py @@ -329,14 +329,14 @@ def __call__( controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditional input tensor conditioning_scale (`float`, *optional*, defaults to `1.0`): the scale factor for controlnet outputs return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a + Whether or not to return a [`models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a plain tuple. train (`bool`, *optional*, defaults to `False`): Use deterministic functions and disable dropout when not training. Returns: - [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`: - [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a + [`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`: + [`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ channel_order = self.controlnet_conditioning_channel_order diff --git a/src/diffusers/models/dual_transformer_2d.py b/src/diffusers/models/dual_transformer_2d.py index 02568298409c..21b135c2eb86 100644 --- a/src/diffusers/models/dual_transformer_2d.py +++ b/src/diffusers/models/dual_transformer_2d.py @@ -120,7 +120,7 @@ def forward( `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. Returns: [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index 128395cc161a..3b219b4f0b37 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -286,7 +286,7 @@ def forward( If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format above. This bias will be added to the cross-attention scores. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. Returns: diff --git a/src/diffusers/models/transformer_temporal.py b/src/diffusers/models/transformer_temporal.py index 26e899a9b908..a18671776baf 100644 --- a/src/diffusers/models/transformer_temporal.py +++ b/src/diffusers/models/transformer_temporal.py @@ -149,7 +149,7 @@ def forward( `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. Returns: diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index d8926dbf2273..87297b5b5d0b 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -890,7 +890,7 @@ def forward( `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, which adds large negative values to the attention scores corresponding to "discard" tokens. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. @@ -906,8 +906,8 @@ def forward( additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) Returns: - [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: - If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise + [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise a `tuple` is returned where the first element is the sample tensor. """ # By default samples have to be AT least a multiple of the overall upsampling factor. diff --git a/src/diffusers/models/unets/unet_2d_condition_flax.py b/src/diffusers/models/unets/unet_2d_condition_flax.py index 5997568361b0..0c17777f1a51 100644 --- a/src/diffusers/models/unets/unet_2d_condition_flax.py +++ b/src/diffusers/models/unets/unet_2d_condition_flax.py @@ -342,14 +342,14 @@ def __call__( mid_block_additional_residual: (`torch.Tensor`, *optional*): A tensor that if specified is added to the residual of the middle unet block. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a + Whether or not to return a [`models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a plain tuple. train (`bool`, *optional*, defaults to `False`): Use deterministic functions and disable dropout when not training. Returns: - [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`: - [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. + [`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`: + [`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ # 1. time diff --git a/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py b/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py index f2d2d48f3520..147dd7a58e7b 100644 --- a/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py +++ b/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py @@ -687,7 +687,7 @@ def forward( `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, which adds large negative values to the attention scores corresponding to "discard" tokens. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. @@ -700,8 +700,8 @@ def forward( which adds large negative values to the attention scores corresponding to "discard" tokens. Returns: - [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: - If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise + [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise a `tuple` is returned where the first element is the sample tensor. """ # By default samples have to be AT least a multiple of the overall upsampling factor. diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py index eb6ab761ed37..20884a15da4d 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -1096,7 +1096,7 @@ def forward( `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, which adds large negative values to the attention scores corresponding to "discard" tokens. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. @@ -1112,8 +1112,8 @@ def forward( additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) Returns: - [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: - If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise + [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise a `tuple` is returned where the first element is the sample tensor. """ # By default samples have to be AT least a multiple of the overall upsampling factor. diff --git a/src/diffusers/pipelines/unidiffuser/modeling_uvit.py b/src/diffusers/pipelines/unidiffuser/modeling_uvit.py index 6e97e0279350..561d8344e746 100644 --- a/src/diffusers/pipelines/unidiffuser/modeling_uvit.py +++ b/src/diffusers/pipelines/unidiffuser/modeling_uvit.py @@ -752,7 +752,7 @@ def forward( cross_attention_kwargs (*optional*): Keyword arguments to supply to the cross attention layers, if used. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. hidden_states_is_embedding (`bool`, *optional*, defaults to `False`): Whether or not hidden_states is an embedding directly usable by the transformer. In this case we will ignore input handling (e.g. continuous, vectorized, etc.) and directly feed hidden_states into the From 85c20aa0448cb78c3e3e932aa3b410191238a9cc Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 19 Jan 2024 09:19:45 +0530 Subject: [PATCH 05/13] mildly depcrecating models.unet_2d_blocks in favor of models.unets.unet_2d_blocks. --- src/diffusers/models/unet_2d_blocks.py | 183 +++++++++++++++++++++++++ 1 file changed, 183 insertions(+) create mode 100644 src/diffusers/models/unet_2d_blocks.py diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py new file mode 100644 index 000000000000..99c11f087ef6 --- /dev/null +++ b/src/diffusers/models/unet_2d_blocks.py @@ -0,0 +1,183 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ..utils import deprecate + + +def get_down_block(): + deprecation_message = "Importing `get_down_block` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import get_down_block`, instead." + deprecate("get_down_block", "0.29", deprecation_message) + + +def get_up_block(): + deprecation_message = "Importing `get_up_block` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import get_up_block`, instead." + deprecate("get_up_block", "0.29", deprecation_message) + + +class AutoencoderTinyBlock: + deprecation_message = "Importing `AutoencoderTinyBlock` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import AutoencoderTinyBlock`, instead." + deprecate("AutoencoderTinyBlock", "0.29", deprecation_message) + from .unets.unet_2d_blocks import AutoencoderTinyBlock + + +class UNetMidBlock2D: + deprecation_message = "Importing `UNetMidBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D`, instead." + deprecate("UNetMidBlock2D", "0.29", deprecation_message) + from .unets.unet_2d_blocks import UNetMidBlock2D + + +class UNetMidBlock2DCrossAttn: + deprecation_message = "Importing `UNetMidBlock2DCrossAttn` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2DCrossAttn`, instead." + deprecate("UNetMidBlock2DCrossAttn", "0.29", deprecation_message) + from .unets.unet_2d_blocks import UNetMidBlock2DCrossAttn + + +class UNetMidBlock2DSimpleCrossAttn: + deprecation_message = "Importing `UNetMidBlock2DSimpleCrossAttn` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2DSimpleCrossAttn`, instead." + deprecate("UNetMidBlock2DSimpleCrossAttn", "0.29", deprecation_message) + from .unets.unet_2d_blocks import UNetMidBlock2DSimpleCrossAttn + + +class AttnDownBlock2D: + deprecation_message = "Importing `AttnDownBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import AttnDownBlock2D`, instead." + deprecate("AttnDownBlock2D", "0.29", deprecation_message) + from .unets.unet_2d_blocks import AttnDownBlock2D + + +class CrossAttnDownBlock2D: + deprecation_message = "Importing `AttnDownBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D`, instead." + deprecate("CrossAttnDownBlock2D", "0.29", deprecation_message) + from .unets.unet_2d_blocks import CrossAttnDownBlock2D + + +class DownBlock2D: + deprecation_message = "Importing `DownBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import DownBlock2D`, instead." + deprecate("DownBlock2D", "0.29", deprecation_message) + from .unets.unet_2d_blocks import DownBlock2D + + +class AttnDownEncoderBlock2D: + deprecation_message = "Importing `AttnDownEncoderBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import AttnDownEncoderBlock2D`, instead." + deprecate("AttnDownEncoderBlock2D", "0.29", deprecation_message) + from .unets.unet_2d_blocks import AttnDownEncoderBlock2D + + +class AttnSkipDownBlock2D: + deprecation_message = "Importing `AttnSkipDownBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import AttnSkipDownBlock2D`, instead." + deprecate("AttnSkipDownBlock2D", "0.29", deprecation_message) + from .unets.unet_2d_blocks import AttnSkipDownBlock2D + + +class SkipDownBlock2D: + deprecation_message = "Importing `SkipDownBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import SkipDownBlock2D`, instead." + deprecate("SkipDownBlock2D", "0.29", deprecation_message) + from .unets.unet_2d_blocks import SkipDownBlock2D + + +class ResnetDownsampleBlock2D: + deprecation_message = "Importing `ResnetDownsampleBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import ResnetDownsampleBlock2D`, instead." + deprecate("ResnetDownsampleBlock2D", "0.29", deprecation_message) + from .unets.unet_2d_blocks import ResnetDownsampleBlock2D + + +class SimpleCrossAttnDownBlock2D: + deprecation_message = "Importing `SimpleCrossAttnDownBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import SimpleCrossAttnDownBlock2D`, instead." + deprecate("SimpleCrossAttnDownBlock2D", "0.29", deprecation_message) + from .unets.unet_2d_blocks import SimpleCrossAttnDownBlock2D + + +class KDownBlock2D: + deprecation_message = "Importing `KDownBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import KDownBlock2D`, instead." + deprecate("KDownBlock2D", "0.29", deprecation_message) + from .unets.unet_2d_blocks import KDownBlock2D + + +class KCrossAttnDownBlock2D: + deprecation_message = "Importing `KCrossAttnDownBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import KCrossAttnDownBlock2D`, instead." + deprecate("KCrossAttnDownBlock2D", "0.29", deprecation_message) + from .unets.unet_2d_blocks import KCrossAttnDownBlock2D + + +class AttnUpBlock2D: + deprecation_message = "Importing `AttnUpBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import AttnUpBlock2D`, instead." + deprecate("AttnUpBlock2D", "0.29", deprecation_message) + from .unets.unet_2d_blocks import AttnUpBlock2D + + +class CrossAttnUpBlock2D: + deprecation_message = "Importing `CrossAttnUpBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import CrossAttnUpBlock2D`, instead." + deprecate("CrossAttnUpBlock2D", "0.29", deprecation_message) + from .unets.unet_2d_blocks import CrossAttnUpBlock2D + + +class UpBlock2D: + deprecation_message = "Importing `UpBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import UpBlock2D`, instead." + deprecate("UpBlock2D", "0.29", deprecation_message) + from .unets.unet_2d_blocks import UpBlock2D + + +class UpDecoderBlock2D: + deprecation_message = "Importing `UpDecoderBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import UpDecoderBlock2D`, instead." + deprecate("UpDecoderBlock2D", "0.29", deprecation_message) + from .unets.unet_2d_blocks import UpDecoderBlock2D + + +class AttnUpDecoderBlock2D: + deprecation_message = "Importing `AttnUpDecoderBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import AttnUpDecoderBlock2D`, instead." + deprecate("AttnUpDecoderBlock2D", "0.29", deprecation_message) + from .unets.unet_2d_blocks import AttnUpDecoderBlock2D + + +class AttnSkipUpBlock2D: + deprecation_message = "Importing `AttnSkipUpBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import AttnSkipUpBlock2D`, instead." + deprecate("AttnSkipUpBlock2D", "0.29", deprecation_message) + from .unets.unet_2d_blocks import AttnSkipUpBlock2D + + +class SkipUpBlock2D: + deprecation_message = "Importing `SkipUpBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import SkipUpBlock2D`, instead." + deprecate("SkipUpBlock2D", "0.29", deprecation_message) + from .unets.unet_2d_blocks import SkipUpBlock2D + + +class ResnetUpsampleBlock2D: + deprecation_message = "Importing `ResnetUpsampleBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import ResnetUpsampleBlock2D`, instead." + deprecate("ResnetUpsampleBlock2D", "0.29", deprecation_message) + from .unets.unet_2d_blocks import ResnetUpsampleBlock2D + + +class SimpleCrossAttnUpBlock2D: + deprecation_message = "Importing `SimpleCrossAttnUpBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import SimpleCrossAttnUpBlock2D`, instead." + deprecate("SimpleCrossAttnUpBlock2D", "0.29", deprecation_message) + from .unets.unet_2d_blocks import SimpleCrossAttnUpBlock2D + + +class KUpBlock2D: + deprecation_message = "Importing `KUpBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import KUpBlock2D`, instead." + deprecate("KUpBlock2D", "0.29", deprecation_message) + from .unets.unet_2d_blocks import KUpBlock2D + + +class KCrossAttnUpBlock2D: + deprecation_message = "Importing `KCrossAttnUpBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import KCrossAttnUpBlock2D`, instead." + deprecate("KCrossAttnUpBlock2D", "0.29", deprecation_message) + from .unets.unet_2d_blocks import KCrossAttnUpBlock2D + + +# can potentially later be renamed to `No-feed-forward` attention +class KAttentionBlock: + deprecation_message = "Importing `KAttentionBlock` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import KAttentionBlock`, instead." + deprecate("KAttentionBlock", "0.29", deprecation_message) + from .unets.unet_2d_blocks import KAttentionBlock From 5dc5886f8d2f5eb30b3f18172cfee9c5a8258e20 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 19 Jan 2024 09:29:38 +0530 Subject: [PATCH 06/13] noqa --- src/diffusers/models/unet_2d_blocks.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 99c11f087ef6..ebecc208364e 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -20,11 +20,15 @@ def get_down_block(): deprecation_message = "Importing `get_down_block` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import get_down_block`, instead." deprecate("get_down_block", "0.29", deprecation_message) + from .unets.unet_2d_blocks import get_down_block # noqa + def get_up_block(): deprecation_message = "Importing `get_up_block` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import get_up_block`, instead." deprecate("get_up_block", "0.29", deprecation_message) + from .unets.unet_2d_blocks import get_up_block # noqa + class AutoencoderTinyBlock: deprecation_message = "Importing `AutoencoderTinyBlock` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import AutoencoderTinyBlock`, instead." From 233ac9804401528a27ade5dca8c37c1cffdec6cd Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 19 Jan 2024 10:00:37 +0530 Subject: [PATCH 07/13] correct depcrecation behaviour --- src/diffusers/models/unet_2d_blocks.py | 123 +++++++++++++++++++++++-- 1 file changed, 117 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index ebecc208364e..ba560372191f 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -12,22 +12,133 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional from ..utils import deprecate -def get_down_block(): +def get_down_block( + down_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + temb_channels: int, + add_downsample: bool, + resnet_eps: float, + resnet_act_fn: str, + transformer_layers_per_block: int = 1, + num_attention_heads: Optional[int] = None, + resnet_groups: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + downsample_padding: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + attention_type: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: float = 1.0, + cross_attention_norm: Optional[str] = None, + attention_head_dim: Optional[int] = None, + downsample_type: Optional[str] = None, + dropout: float = 0.0, +): deprecation_message = "Importing `get_down_block` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import get_down_block`, instead." deprecate("get_down_block", "0.29", deprecation_message) - from .unets.unet_2d_blocks import get_down_block # noqa - - -def get_up_block(): + from .unets.unet_2d_blocks import get_down_block + + return get_down_block( + down_block_type=down_block_type, + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + transformer_layers_per_block=transformer_layers_per_block, + num_attention_heads=num_attention_heads, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim, + downsample_type=downsample_type, + dropout=dropout, + ) + + +def get_up_block( + up_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + add_upsample: bool, + resnet_eps: float, + resnet_act_fn: str, + resolution_idx: Optional[int] = None, + transformer_layers_per_block: int = 1, + num_attention_heads: Optional[int] = None, + resnet_groups: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + attention_type: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: float = 1.0, + cross_attention_norm: Optional[str] = None, + attention_head_dim: Optional[int] = None, + upsample_type: Optional[str] = None, + dropout: float = 0.0, +): deprecation_message = "Importing `get_up_block` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import get_up_block`, instead." deprecate("get_up_block", "0.29", deprecation_message) - from .unets.unet_2d_blocks import get_up_block # noqa + from .unets.unet_2d_blocks import get_up_block + + return get_up_block( + up_block_type=up_block_type, + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resolution_idx=resolution_idx, + transformer_layers_per_block=transformer_layers_per_block, + num_attention_heads=num_attention_heads, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim, + upsample_type=upsample_type, + dropout=dropout, + ) class AutoencoderTinyBlock: From 4e60591136d2e6e190d7c121cda82ac5959473fb Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 19 Jan 2024 10:16:54 +0530 Subject: [PATCH 08/13] inherit from the actual classes. --- src/diffusers/models/unet_2d_blocks.py | 106 +++++++++++++------------ 1 file changed, 54 insertions(+), 52 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index ba560372191f..cbde60bc8ad1 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -15,6 +15,34 @@ from typing import Optional from ..utils import deprecate +from .unets.unet_2d_blocks import ( + AttnDownBlock2D, + AttnDownEncoderBlock2D, + AttnSkipDownBlock2D, + AttnSkipUpBlock2D, + AttnUpBlock2D, + AttnUpDecoderBlock2D, + AutoencoderTinyBlock, + CrossAttnDownBlock2D, + CrossAttnUpBlock2D, + DownBlock2D, + KAttentionBlock, + KCrossAttnDownBlock2D, + KCrossAttnUpBlock2D, + KDownBlock2D, + KUpBlock2D, + ResnetDownsampleBlock2D, + ResnetUpsampleBlock2D, + SimpleCrossAttnDownBlock2D, + SimpleCrossAttnUpBlock2D, + SkipDownBlock2D, + SkipUpBlock2D, + UNetMidBlock2D, + UNetMidBlock2DCrossAttn, + UNetMidBlock2DSimpleCrossAttn, + UpBlock2D, + UpDecoderBlock2D, +) def get_down_block( @@ -141,158 +169,132 @@ def get_up_block( ) -class AutoencoderTinyBlock: +class AutoencoderTinyBlock(AutoencoderTinyBlock): deprecation_message = "Importing `AutoencoderTinyBlock` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import AutoencoderTinyBlock`, instead." deprecate("AutoencoderTinyBlock", "0.29", deprecation_message) - from .unets.unet_2d_blocks import AutoencoderTinyBlock -class UNetMidBlock2D: +class UNetMidBlock2D(UNetMidBlock2D): deprecation_message = "Importing `UNetMidBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D`, instead." deprecate("UNetMidBlock2D", "0.29", deprecation_message) - from .unets.unet_2d_blocks import UNetMidBlock2D -class UNetMidBlock2DCrossAttn: +class UNetMidBlock2DCrossAttn(UNetMidBlock2DCrossAttn): deprecation_message = "Importing `UNetMidBlock2DCrossAttn` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2DCrossAttn`, instead." deprecate("UNetMidBlock2DCrossAttn", "0.29", deprecation_message) - from .unets.unet_2d_blocks import UNetMidBlock2DCrossAttn -class UNetMidBlock2DSimpleCrossAttn: +class UNetMidBlock2DSimpleCrossAttn(UNetMidBlock2DSimpleCrossAttn): deprecation_message = "Importing `UNetMidBlock2DSimpleCrossAttn` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2DSimpleCrossAttn`, instead." deprecate("UNetMidBlock2DSimpleCrossAttn", "0.29", deprecation_message) - from .unets.unet_2d_blocks import UNetMidBlock2DSimpleCrossAttn -class AttnDownBlock2D: +class AttnDownBlock2D(AttnDownBlock2D): deprecation_message = "Importing `AttnDownBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import AttnDownBlock2D`, instead." deprecate("AttnDownBlock2D", "0.29", deprecation_message) - from .unets.unet_2d_blocks import AttnDownBlock2D -class CrossAttnDownBlock2D: +class CrossAttnDownBlock2D(CrossAttnDownBlock2D): deprecation_message = "Importing `AttnDownBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D`, instead." deprecate("CrossAttnDownBlock2D", "0.29", deprecation_message) - from .unets.unet_2d_blocks import CrossAttnDownBlock2D -class DownBlock2D: +class DownBlock2D(DownBlock2D): deprecation_message = "Importing `DownBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import DownBlock2D`, instead." deprecate("DownBlock2D", "0.29", deprecation_message) - from .unets.unet_2d_blocks import DownBlock2D -class AttnDownEncoderBlock2D: +class AttnDownEncoderBlock2D(AttnDownEncoderBlock2D): deprecation_message = "Importing `AttnDownEncoderBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import AttnDownEncoderBlock2D`, instead." deprecate("AttnDownEncoderBlock2D", "0.29", deprecation_message) - from .unets.unet_2d_blocks import AttnDownEncoderBlock2D -class AttnSkipDownBlock2D: +class AttnSkipDownBlock2D(AttnSkipDownBlock2D): deprecation_message = "Importing `AttnSkipDownBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import AttnSkipDownBlock2D`, instead." deprecate("AttnSkipDownBlock2D", "0.29", deprecation_message) - from .unets.unet_2d_blocks import AttnSkipDownBlock2D -class SkipDownBlock2D: +class SkipDownBlock2D(SkipDownBlock2D): deprecation_message = "Importing `SkipDownBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import SkipDownBlock2D`, instead." deprecate("SkipDownBlock2D", "0.29", deprecation_message) - from .unets.unet_2d_blocks import SkipDownBlock2D -class ResnetDownsampleBlock2D: +class ResnetDownsampleBlock2D(ResnetDownsampleBlock2D): deprecation_message = "Importing `ResnetDownsampleBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import ResnetDownsampleBlock2D`, instead." deprecate("ResnetDownsampleBlock2D", "0.29", deprecation_message) - from .unets.unet_2d_blocks import ResnetDownsampleBlock2D -class SimpleCrossAttnDownBlock2D: +class SimpleCrossAttnDownBlock2D(SimpleCrossAttnDownBlock2D): deprecation_message = "Importing `SimpleCrossAttnDownBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import SimpleCrossAttnDownBlock2D`, instead." deprecate("SimpleCrossAttnDownBlock2D", "0.29", deprecation_message) - from .unets.unet_2d_blocks import SimpleCrossAttnDownBlock2D -class KDownBlock2D: +class KDownBlock2D(KDownBlock2D): deprecation_message = "Importing `KDownBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import KDownBlock2D`, instead." deprecate("KDownBlock2D", "0.29", deprecation_message) - from .unets.unet_2d_blocks import KDownBlock2D -class KCrossAttnDownBlock2D: +class KCrossAttnDownBlock2D(KCrossAttnDownBlock2D): deprecation_message = "Importing `KCrossAttnDownBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import KCrossAttnDownBlock2D`, instead." deprecate("KCrossAttnDownBlock2D", "0.29", deprecation_message) - from .unets.unet_2d_blocks import KCrossAttnDownBlock2D -class AttnUpBlock2D: +class AttnUpBlock2D(AttnUpBlock2D): deprecation_message = "Importing `AttnUpBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import AttnUpBlock2D`, instead." deprecate("AttnUpBlock2D", "0.29", deprecation_message) - from .unets.unet_2d_blocks import AttnUpBlock2D -class CrossAttnUpBlock2D: +class CrossAttnUpBlock2D(CrossAttnUpBlock2D): deprecation_message = "Importing `CrossAttnUpBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import CrossAttnUpBlock2D`, instead." deprecate("CrossAttnUpBlock2D", "0.29", deprecation_message) - from .unets.unet_2d_blocks import CrossAttnUpBlock2D -class UpBlock2D: +class UpBlock2D(UpBlock2D): deprecation_message = "Importing `UpBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import UpBlock2D`, instead." deprecate("UpBlock2D", "0.29", deprecation_message) - from .unets.unet_2d_blocks import UpBlock2D -class UpDecoderBlock2D: +class UpDecoderBlock2D(UpDecoderBlock2D): deprecation_message = "Importing `UpDecoderBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import UpDecoderBlock2D`, instead." deprecate("UpDecoderBlock2D", "0.29", deprecation_message) - from .unets.unet_2d_blocks import UpDecoderBlock2D -class AttnUpDecoderBlock2D: +class AttnUpDecoderBlock2D(AttnUpDecoderBlock2D): deprecation_message = "Importing `AttnUpDecoderBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import AttnUpDecoderBlock2D`, instead." deprecate("AttnUpDecoderBlock2D", "0.29", deprecation_message) - from .unets.unet_2d_blocks import AttnUpDecoderBlock2D -class AttnSkipUpBlock2D: +class AttnSkipUpBlock2D(AttnSkipUpBlock2D): deprecation_message = "Importing `AttnSkipUpBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import AttnSkipUpBlock2D`, instead." deprecate("AttnSkipUpBlock2D", "0.29", deprecation_message) - from .unets.unet_2d_blocks import AttnSkipUpBlock2D -class SkipUpBlock2D: +class SkipUpBlock2D(SkipUpBlock2D): deprecation_message = "Importing `SkipUpBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import SkipUpBlock2D`, instead." deprecate("SkipUpBlock2D", "0.29", deprecation_message) - from .unets.unet_2d_blocks import SkipUpBlock2D -class ResnetUpsampleBlock2D: +class ResnetUpsampleBlock2D(ResnetUpsampleBlock2D): deprecation_message = "Importing `ResnetUpsampleBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import ResnetUpsampleBlock2D`, instead." deprecate("ResnetUpsampleBlock2D", "0.29", deprecation_message) - from .unets.unet_2d_blocks import ResnetUpsampleBlock2D -class SimpleCrossAttnUpBlock2D: +class SimpleCrossAttnUpBlock2D(SimpleCrossAttnUpBlock2D): deprecation_message = "Importing `SimpleCrossAttnUpBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import SimpleCrossAttnUpBlock2D`, instead." deprecate("SimpleCrossAttnUpBlock2D", "0.29", deprecation_message) - from .unets.unet_2d_blocks import SimpleCrossAttnUpBlock2D -class KUpBlock2D: +class KUpBlock2D(KUpBlock2D): deprecation_message = "Importing `KUpBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import KUpBlock2D`, instead." deprecate("KUpBlock2D", "0.29", deprecation_message) - from .unets.unet_2d_blocks import KUpBlock2D -class KCrossAttnUpBlock2D: +class KCrossAttnUpBlock2D(KCrossAttnUpBlock2D): deprecation_message = "Importing `KCrossAttnUpBlock2D` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import KCrossAttnUpBlock2D`, instead." deprecate("KCrossAttnUpBlock2D", "0.29", deprecation_message) - from .unets.unet_2d_blocks import KCrossAttnUpBlock2D # can potentially later be renamed to `No-feed-forward` attention -class KAttentionBlock: +class KAttentionBlock(KAttentionBlock): deprecation_message = "Importing `KAttentionBlock` from `diffusers.models.unet_2d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_blocks import KAttentionBlock`, instead." deprecate("KAttentionBlock", "0.29", deprecation_message) - from .unets.unet_2d_blocks import KAttentionBlock From e2033611f8f0fe94790f75562139b1bb702eb9b5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 19 Jan 2024 10:50:41 +0530 Subject: [PATCH 09/13] Empty-Commit From a7adbeaa7b37f35aa825513fc600fbebaa2959be Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 19 Jan 2024 15:14:54 +0530 Subject: [PATCH 10/13] backwards compatibility for unet_2d.py --- src/diffusers/models/unet_2d.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 src/diffusers/models/unet_2d.py diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py new file mode 100644 index 000000000000..006bf4721856 --- /dev/null +++ b/src/diffusers/models/unet_2d.py @@ -0,0 +1,27 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ..utils import deprecate +from .unets.unet_2d import UNet2DModel, UNet2DOutput + + +class UNet2DOutput(UNet2DOutput): + deprecation_message = "Importing `UNet2DOutput` from `diffusers.models.unet_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d import UNet2DOutput`, instead." + deprecate("UNet2DOutput", "0.29", deprecation_message) + + +class UNet2DModel(UNet2DModel): + deprecation_message = "Importing `UNet2DModel` from `diffusers.models.unet_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d import UNet2DModel`, instead." + deprecate("UNet2DModel", "0.29", deprecation_message) From c8b433f582213440fd214503d3550661bd7aa111 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 19 Jan 2024 15:19:49 +0530 Subject: [PATCH 11/13] backward compatibility for unet_2d_condition --- src/diffusers/models/unet_2d_condition.py | 25 +++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 src/diffusers/models/unet_2d_condition.py diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py new file mode 100644 index 000000000000..cc619dd17c4c --- /dev/null +++ b/src/diffusers/models/unet_2d_condition.py @@ -0,0 +1,25 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ..utils import deprecate +from .unets.unet_2d_condition import UNet2DConditionModel, UNet2DConditionOutput + + +class UNet2DConditionOutput(UNet2DConditionOutput): + deprecation_message = "Importing `UNet2DConditionOutput` from `diffusers.models.unet_2d_condition` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput`, instead." + deprecate("UNet2DConditionOutput", "0.29", deprecation_message) + + +class UNet2DConditionModel(UNet2DConditionModel): + deprecation_message = "Importing `UNet2DConditionModel` from `diffusers.models.unet_2d_condition` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel`, instead." + deprecate("UNet2DConditionModel", "0.29", deprecation_message) From 5b0590f1a9fcf3c2f89ff24cb1174fd2cc7b2089 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 19 Jan 2024 15:22:47 +0530 Subject: [PATCH 12/13] bc for unet_1d --- src/diffusers/models/unet_1d.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 src/diffusers/models/unet_1d.py diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py new file mode 100644 index 000000000000..06ff51b17d0d --- /dev/null +++ b/src/diffusers/models/unet_1d.py @@ -0,0 +1,26 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..utils import deprecate +from .unets.unet_1d import UNet1DModel, UNet1DOutput + + +class UNet1DOutput(UNet1DOutput): + deprecation_message = "Importing `UNet1DOutput` from `diffusers.models.unet_1d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d import UNet1DOutput`, instead." + deprecate("UNet1DOutput", "0.29", deprecation_message) + + +class UNet1DModel(UNet1DModel): + deprecation_message = "Importing `UNet1DModel` from `diffusers.models.unet_1d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d import UNet1DModel`, instead." + deprecate("UNet1DModel", "0.29", deprecation_message) From 37f965e7df723c42f0a5dade40be4751a2f5ec53 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 19 Jan 2024 15:40:28 +0530 Subject: [PATCH 13/13] bc for unet_1d_blocks --- src/diffusers/models/unet_1d_blocks.py | 203 +++++++++++++++++++++++++ 1 file changed, 203 insertions(+) create mode 100644 src/diffusers/models/unet_1d_blocks.py diff --git a/src/diffusers/models/unet_1d_blocks.py b/src/diffusers/models/unet_1d_blocks.py new file mode 100644 index 000000000000..772d7f6cfbe4 --- /dev/null +++ b/src/diffusers/models/unet_1d_blocks.py @@ -0,0 +1,203 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..utils import deprecate +from .unets.unet_1d_blocks import ( + AttnDownBlock1D, + AttnUpBlock1D, + DownBlock1D, + DownBlock1DNoSkip, + DownResnetBlock1D, + Downsample1d, + MidResTemporalBlock1D, + OutConv1DBlock, + OutValueFunctionBlock, + ResConvBlock, + SelfAttention1d, + UNetMidBlock1D, + UpBlock1D, + UpBlock1DNoSkip, + UpResnetBlock1D, + Upsample1d, + ValueFunctionMidBlock1D, +) + + +class DownResnetBlock1D(DownResnetBlock1D): + deprecation_message = "Importing `DownResnetBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import DownResnetBlock1D`, instead." + deprecate("DownResnetBlock1D", "0.29", deprecation_message) + + +class UpResnetBlock1D(UpResnetBlock1D): + deprecation_message = "Importing `UpResnetBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import UpResnetBlock1D`, instead." + deprecate("UpResnetBlock1D", "0.29", deprecation_message) + + +class ValueFunctionMidBlock1D(ValueFunctionMidBlock1D): + deprecation_message = "Importing `ValueFunctionMidBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import ValueFunctionMidBlock1D`, instead." + deprecate("ValueFunctionMidBlock1D", "0.29", deprecation_message) + + +class OutConv1DBlock(OutConv1DBlock): + deprecation_message = "Importing `OutConv1DBlock` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import OutConv1DBlock`, instead." + deprecate("OutConv1DBlock", "0.29", deprecation_message) + + +class OutValueFunctionBlock(OutValueFunctionBlock): + deprecation_message = "Importing `OutValueFunctionBlock` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import OutValueFunctionBlock`, instead." + deprecate("OutValueFunctionBlock", "0.29", deprecation_message) + + +class Downsample1d(Downsample1d): + deprecation_message = "Importing `Downsample1d` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import Downsample1d`, instead." + deprecate("Downsample1d", "0.29", deprecation_message) + + +class Upsample1d(Upsample1d): + deprecation_message = "Importing `Upsample1d` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import Upsample1d`, instead." + deprecate("Upsample1d", "0.29", deprecation_message) + + +class SelfAttention1d(SelfAttention1d): + deprecation_message = "Importing `SelfAttention1d` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import SelfAttention1d`, instead." + deprecate("SelfAttention1d", "0.29", deprecation_message) + + +class ResConvBlock(ResConvBlock): + deprecation_message = "Importing `ResConvBlock` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import ResConvBlock`, instead." + deprecate("ResConvBlock", "0.29", deprecation_message) + + +class UNetMidBlock1D(UNetMidBlock1D): + deprecation_message = "Importing `UNetMidBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import UNetMidBlock1D`, instead." + deprecate("UNetMidBlock1D", "0.29", deprecation_message) + + +class AttnDownBlock1D(AttnDownBlock1D): + deprecation_message = "Importing `AttnDownBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import AttnDownBlock1D`, instead." + deprecate("AttnDownBlock1D", "0.29", deprecation_message) + + +class DownBlock1D(DownBlock1D): + deprecation_message = "Importing `DownBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import DownBlock1D`, instead." + deprecate("DownBlock1D", "0.29", deprecation_message) + + +class DownBlock1DNoSkip(DownBlock1DNoSkip): + deprecation_message = "Importing `DownBlock1DNoSkip` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import DownBlock1DNoSkip`, instead." + deprecate("DownBlock1DNoSkip", "0.29", deprecation_message) + + +class AttnUpBlock1D(AttnUpBlock1D): + deprecation_message = "Importing `AttnUpBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import AttnUpBlock1D`, instead." + deprecate("AttnUpBlock1D", "0.29", deprecation_message) + + +class UpBlock1D(UpBlock1D): + deprecation_message = "Importing `UpBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import UpBlock1D`, instead." + deprecate("UpBlock1D", "0.29", deprecation_message) + + +class UpBlock1DNoSkip(UpBlock1DNoSkip): + deprecation_message = "Importing `UpBlock1DNoSkip` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import UpBlock1DNoSkip`, instead." + deprecate("UpBlock1DNoSkip", "0.29", deprecation_message) + + +class MidResTemporalBlock1D(MidResTemporalBlock1D): + deprecation_message = "Importing `MidResTemporalBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import MidResTemporalBlock1D`, instead." + deprecate("MidResTemporalBlock1D", "0.29", deprecation_message) + + +def get_down_block( + down_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + temb_channels: int, + add_downsample: bool, +): + deprecation_message = "Importing `get_down_block` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import get_down_block`, instead." + deprecate("get_down_block", "0.29", deprecation_message) + + from .unets.unet_1d_blocks import get_down_block + + return get_down_block( + down_block_type=down_block_type, + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + ) + + +def get_up_block( + up_block_type: str, num_layers: int, in_channels: int, out_channels: int, temb_channels: int, add_upsample: bool +): + deprecation_message = "Importing `get_up_block` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import get_up_block`, instead." + deprecate("get_up_block", "0.29", deprecation_message) + + from .unets.unet_1d_blocks import get_up_block + + return get_up_block( + up_block_type=up_block_type, + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_upsample=add_upsample, + ) + + +def get_mid_block( + mid_block_type: str, + num_layers: int, + in_channels: int, + mid_channels: int, + out_channels: int, + embed_dim: int, + add_downsample: bool, +): + deprecation_message = "Importing `get_mid_block` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import get_mid_block`, instead." + deprecate("get_mid_block", "0.29", deprecation_message) + + from .unets.unet_1d_blocks import get_mid_block + + return get_mid_block( + mid_block_type=mid_block_type, + num_layers=num_layers, + in_channels=in_channels, + mid_channels=mid_channels, + out_channels=out_channels, + embed_dim=embed_dim, + add_downsample=add_downsample, + ) + + +def get_out_block( + *, out_block_type: str, num_groups_out: int, embed_dim: int, out_channels: int, act_fn: str, fc_dim: int +): + deprecation_message = "Importing `get_out_block` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import get_out_block`, instead." + deprecate("get_out_block", "0.29", deprecation_message) + + from .unets.unet_1d_blocks import get_out_block + + return get_out_block( + out_block_type=out_block_type, + num_groups_out=num_groups_out, + embed_dim=embed_dim, + out_channels=out_channels, + act_fn=act_fn, + fc_dim=fc_dim, + )